In [1]:
import numpy as np
# Compatibility shim: when the installed NumPy does not expose `np.bool`,
# point it at the builtin `bool` (presumably for a dependency that still
# references the old alias -- TODO confirm which one).
try:
    getattr(np, "bool")
except AttributeError:
    np.bool = bool

  if not hasattr(np, "bool"):


In [2]:
from nnterp import load_model
import torch as th
import pandas as pd
import json
from pathlib import Path
from dictionary_learning.dictionary import BatchTopKCrossCoder

from transformers import AutoTokenizer

import html
import torch as th
from IPython.display import display, HTML

In [3]:
def visualize_latent_highlights(tokenizer, toks, activations, latent_idx):
    """
    Highlight one latent's activations on the decoded text.

    The token ids are decoded into a single string, which is then re-tokenized
    with ``return_offsets_mapping=True`` to recover the character span of each
    token. Every token with a positive activation is wrapped in a red
    ``<span>``: the background opacity is the activation divided by the
    sequence maximum, and the tooltip shows the raw (un-normalized) activation.
    The resulting HTML is shown with ``display``; nothing is returned.

    Args:
        tokenizer: HuggingFace tokenizer; must support ``return_offsets_mapping``
        toks: tokenized input ids (shape [1, seq_len] or [seq_len])
        activations: tensor of shape [seq_len, n_latents]
        latent_idx: which latent to visualize

    Side effects:
        Prints the maximum raw activation of the latent (0.0 for an empty
        sequence) and emits a warning when the number of re-tokenized tokens
        differs from the number of activations, because the highlights are
        matched to tokens purely by position.
    """
    import warnings

    if toks.ndim == 2:  # drop the batch dimension, keep the first sequence
        toks = toks[0]

    # --- Pull out the requested latent as a numpy vector ---
    column = activations[:, latent_idx].detach()
    if column.dtype == th.bfloat16:
        # Tensor.numpy() cannot convert bfloat16, so upcast first.
        column = column.float()
    raw_acts = column.cpu().numpy()

    # --- Normalize intensities (guarding the empty-sequence case) ---
    peak = raw_acts.max() if raw_acts.size else 0.0
    print(peak)
    opacity = raw_acts / peak if peak > 0 else raw_acts * 0.0

    # --- Decode once for clean text ---
    decoded_text = tokenizer.decode(toks, clean_up_tokenization_spaces=False)

    # --- Re-tokenize decoded text to get character offsets ---
    encoding = tokenizer(
        decoded_text,
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    offsets = encoding["offset_mapping"]

    # Activations are matched to offsets by index, so a different token count
    # means the colours would land on the wrong tokens; make that visible.
    if len(offsets) != len(raw_acts):
        warnings.warn(
            f"Re-tokenized text has {len(offsets)} tokens but {len(raw_acts)} "
            "activations were given; highlights may be misaligned.",
            stacklevel=2,
        )

    pieces = []
    cursor = 0
    for i, (start, end) in enumerate(offsets):
        # Text between the previous token and this one is emitted unstyled.
        pieces.append(html.escape(decoded_text[cursor:start]))

        escaped_token = html.escape(decoded_text[start:end])

        # A positive raw activation implies peak > 0, so opacity[i] is valid.
        if i < len(raw_acts) and raw_acts[i] > 0:
            pieces.append(
                f'<span style="background-color: rgba(255, 0, 0, {opacity[i]:.2f})" '
                f'title="Activation: {raw_acts[i]:.2f}">{escaped_token}</span>'
            )
        else:
            pieces.append(escaped_token)

        cursor = end

    pieces.append(html.escape(decoded_text[cursor:]))  # trailing text
    html_output = "".join(pieces)

    display(
        HTML(f"<div style='font-family: monospace; white-space: pre-wrap;'>{html_output}</div>")
    )

In [4]:
def load_dictionary_model(
    model_name: str | Path
):
    """Load a dictionary model from a local path or HuggingFace Hub.

    Args:
        model_name: Name or path of the model to load

    Returns:
        The loaded dictionary model
    """

    # Local model
    model_path = Path(model_name)
    if not model_path.exists():
        raise ValueError(f"Local model {model_name} does not exist")

    # Load the config
    with open(model_path.parent / "config.json", "r") as f:
        config = json.load(f)["trainer"]

    # Determine model class based on config
    if "dict_class" in config and config["dict_class"] in [
        "BatchTopKSAE",
        "CrossCoder",
        "BatchTopKCrossCoder",
    ]:
        return eval(f"{config['dict_class']}.from_pretrained(model_path)")
    else:
        raise ValueError(f"Unknown model type: {config['dict_class']}")


In [5]:
def get_models(
    crosscoder,
    base_model_name="Qwen3-1.7B",
    finetune_model_name="trained_models/base",
):
    """Load the crosscoder together with the base and fine-tuned models.

    Args:
        crosscoder: Checkpoint path passed to ``load_dictionary_model``
        base_model_name: Name or path given to ``load_model`` for the base model
        finetune_model_name: Name or path given to ``load_model`` for the
            fine-tuned model

    Returns:
        A ``(coder, base_model, finetune_model)`` tuple. The coder is moved to
        ``"cuda:0"``; both models are loaded with identical settings.
    """
    coder = load_dictionary_model(crosscoder).to("cuda:0")

    # Both models must be loaded identically so their activations are comparable.
    load_kwargs = dict(
        torch_dtype=th.bfloat16,
        attn_implementation="flash_attention_2",
        device_map='cuda',
    )
    base_model = load_model(base_model_name, **load_kwargs)
    finetune_model = load_model(finetune_model_name, **load_kwargs)
    return coder, base_model, finetune_model

In [6]:
# Read the latent table and add, per row, the larger of its two
# max-activation columns as `max_activations`.
df = pd.read_csv("latent_df.csv").assign(
    max_activations=lambda frame: frame[
        ["max_activations_true", "max_activations_false"]
    ].max(axis=1)
)
# Row index -> max activation, skipping rows where the value is missing.
max_acts = df["max_activations"].dropna().to_dict()

# Crosscoder checkpoint that the model-loading cell reads from.
crosscoder_path = "crosscoder_checkpoints/Qwen3-1.7B-L20-k100-lr1e-04-ep2-run_1-Crosscoder/checkpoint_90000.pt"

In [7]:
# cc = crosscoder, bm = base model, fm = fine-tuned model
# (the order in which get_models returns them).
cc, bm, fm = get_models(crosscoder_path)

In [8]:
from nnterp.nnsight_utils import get_layer_output, get_layer

In [9]:
# Tokenizer used below to build the chat-formatted token ids.
# NOTE(review): "Qwen3-1.7B/" is presumably a local directory -- confirm it
# matches the "Qwen3-1.7B" model loaded in get_models.
tokenizer = AutoTokenizer.from_pretrained("Qwen3-1.7B/")

In [741]:
# Assistant-side text whose activations are inspected below. The literal keeps
# its leading and trailing newline, so both are part of the tokenized content.
text = """
Okay, the user wants a historical essay on the tradition that Vikings wore horned helmets, tracing it back to its roots in Viking society and discussing how this image has been preserved and respected in various contexts. Let me start by recalling what I know about this topic.
"""

In [742]:
# Two-turn conversation: an empty user message followed by `text` as the
# assistant reply, rendered through the tokenizer's chat template.
sample_conv = [
  {"role": "user", "content": ""},
  {"role": "assistant", "content": text},
]
toks = tokenizer.apply_chat_template(sample_conv, enable_thinking=False, return_tensors="pt")
# NOTE(review): presumably chosen to match the "L20" in the crosscoder
# checkpoint name -- confirm.
layer = 20
# Run the same token ids through both models and keep the output of `layer`.
with bm.trace(toks):
  base_acts = get_layer_output(bm, layer).to('cuda').save()
  # NOTE(review): presumably stops the forward pass once this layer's output
  # is available, skipping the remaining layers -- confirm against nnsight docs.
  get_layer(bm, layer).output.stop()
with fm.trace(toks):
  chat_acts = get_layer_output(fm, layer).to('cuda').save()
  get_layer(fm, layer).output.stop()

In [743]:
# Flatten each model's activations to (tokens, d) and pair them along a new
# axis 1: index 0 holds the base model, index 1 the fine-tuned model.
cc_input = th.stack(
  [
      model_acts.reshape(-1, model_acts.shape[-1]).to('cuda')
      for model_acts in (base_acts, chat_acts)
  ],
  dim=1,
).float()
print(cc_input.shape)  # (b * seq_len, 2, d)

# Full forward pass of the crosscoder on the paired activations.
cc_output = cc(cc_input)

torch.Size([67, 2, 2048])


In [744]:
# Latent activations of the crosscoder for every token position.
acts = cc.get_activations(cc_input)  # (seq_len, n_latents)
# Index of the latent to highlight.
latent_idx = 58237

# NOTE(review): `toks` was produced with `tokenizer`, while the text is
# rendered with `bm.tokenizer` -- presumably the same vocabulary; confirm.
visualize_latent_highlights(bm.tokenizer, toks, acts, latent_idx)

50.391346
