In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import einops
from tqdm import tqdm

import interp_tools.saes.jumprelu_sae as jumprelu_sae
import interp_tools.saes.base_sae as base_sae
import interp_tools.model_utils as model_utils
import interp_tools.data_utils as data_utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Shared configuration: model checkpoint, compute dtype and target device.
model_name = "google/gemma-2-2b"
dtype = torch.bfloat16
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# device_map="auto" leaves weight placement to the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=dtype,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.06it/s]


In [3]:
# Layer whose output is read and whose SAE is loaded.
layer = 20

# Location of the SAE parameters on the Hugging Face Hub.
repo_id = "google/gemma-scope-2b-pt-res"
filename = f"layer_{layer}/width_16k/average_l0_71/params.npz"

# `model_name`, `device`, `dtype` and `model` are defined once, in the model-loading
# cell above. They are deliberately not re-declared here so the two cells cannot drift.
sae = jumprelu_sae.load_gemma_scope_jumprelu_sae(
    repo_id, filename, layer, model_name, device, dtype
)

# Module of `model` whose activations are fed to the SAE.
submodule = model_utils.get_submodule(model, layer)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Batch geometry, reused by the top-k search further down.
batch_size = 32
context_length = 128

# Arguments for the token loader, collected in one place for readability.
token_loader_kwargs = dict(
    tokenizer=tokenizer,
    model_name=model_name,
    dataset_name="togethercomputer/RedPajama-Data-V2",
    num_tokens=2_000_000,
    batch_size=batch_size,
    device=device,
    context_length=context_length,
)
batched_tokens = data_utils.get_batched_tokens(**token_loader_kwargs)

Loading tokenized dataset from tokens/togethercomputer_RedPajama-Data-V2_2000000_google_gemma-2-2b.pt


In [5]:
# Sanity check: shapes of the two tensors in the first batch.
first_batch = batched_tokens[0]
for tensor_name in ("input_ids", "attention_mask"):
    print(first_batch[tensor_name].shape)

torch.Size([32, 128])
torch.Size([32, 128])


In [6]:
@torch.no_grad()
def get_max_activating_prompts(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    submodule: torch.nn.Module,
    tokenized_inputs_bL: list[dict[str, torch.Tensor]],
    dim_indices: torch.Tensor,
    batch_size: int,
    dictionary: base_sae.BaseSAE,
    context_length: int,
    k: int = 30,
    zero_bos: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Stream over the tokenized batches and keep, for every selected dictionary
    feature, the k sequences with the largest peak activation.

    A sequence is scored per feature by its maximum activation over positions.
    After each batch the running top-k is merged with the new sequences and
    re-selected with `torch.topk`.

    Shape suffixes: B = sequences in a batch, L = positions, D = submodule width,
    F = number of selected features, K = k.

    Args:
        model: Model that is run on each batch; its device hosts the buffers.
        tokenizer: Passed to `data_utils.get_bos_pad_eos_mask` when `zero_bos` is set.
        submodule: Module whose activations are collected.
        tokenized_inputs_bL: Batches, each a dict with "input_ids" and
            "attention_mask" entries.
        dim_indices: Indices of the dictionary features to rank.
        batch_size: Number of sequences per batch. Used only to number sequences
            globally; that numbering is tracked internally and not returned.
        dictionary: Sparse autoencoder whose `encode` maps activations to features.
        context_length: Sequence length L used to size the result buffers.
        k: Number of sequences kept per feature.
        zero_bos: If true, encoded activations are multiplied in place by the mask
            from `data_utils.get_bos_pad_eos_mask` (presumably zeroing BOS / pad /
            EOS positions -- confirm against that helper).

    Returns:
        `(max_tokens_FKL, max_activations_FKL)`: token ids and per-position
        activations of the kept sequences, ordered per feature by decreasing peak.
        Both buffers start zero-filled.
    """

    device = model.device
    num_features = dim_indices.shape[0]
    num_batches = len(tokenized_inputs_bL)

    # Running top-k state, zero-initialised.
    best_seq_ids_FK = torch.zeros((num_features, k), device=device, dtype=torch.int32)
    best_peaks_FK = torch.zeros((num_features, k), device=device, dtype=torch.bfloat16)
    best_tokens_FKL = torch.zeros(
        (num_features, k, context_length), device=device, dtype=torch.int32
    )
    best_acts_FKL = torch.zeros(
        (num_features, k, context_length), device=device, dtype=torch.bfloat16
    )

    # Row selector used to gather the winners for every feature at once.
    feature_rows_F1 = torch.arange(num_features, device=device)[:, None]

    for batch_idx, inputs_BL in tqdm(enumerate(tokenized_inputs_bL), total=num_batches):
        first_seq_id = batch_idx * batch_size
        attention_mask_BL = inputs_BL["attention_mask"]

        # Submodule activations [B, L, D] -> dictionary features [B, L, F_all].
        hidden_BLD = model_utils.collect_activations(model, submodule, inputs_BL)
        encoded_BLF = dictionary.encode(hidden_BLD)
        if zero_bos:
            keep_mask_BL = data_utils.get_bos_pad_eos_mask(
                inputs_BL["input_ids"], tokenizer
            )
            # In-place on purpose: the encoded tensor keeps its dtype.
            encoded_BLF *= keep_mask_BL.unsqueeze(2)

        # Keep only the requested features, then drop masked-out positions.
        selected_BLF = encoded_BLF[:, :, dim_indices]
        selected_BLF = selected_BLF * attention_mask_BL.unsqueeze(2)

        # Feature-major layout and per-sequence peak over positions.
        acts_FBL = einops.rearrange(selected_BLF, "B L F -> F B L")
        peaks_FB = einops.reduce(acts_FBL, "F B L -> F B", "max")

        # Broadcast tokens and global sequence ids across the feature axis.
        tokens_FBL = einops.repeat(
            inputs_BL["input_ids"], "B L -> F B L", F=num_features
        )
        seq_ids_B = torch.arange(first_seq_id, first_seq_id + batch_size, device=device)
        seq_ids_FB = einops.repeat(seq_ids_B, "B -> F B", F=num_features)

        # Candidate pool = previous winners followed by this batch.
        pool_peaks = torch.cat([best_peaks_FK, peaks_FB], dim=1)
        pool_seq_ids = torch.cat([best_seq_ids_FK, seq_ids_FB], dim=1)
        pool_acts = torch.cat([best_acts_FKL, acts_FBL], dim=1)
        pool_tokens = torch.cat([best_tokens_FKL, tokens_FBL], dim=1)

        # Re-select the k best candidates per feature and gather their payloads.
        best_peaks_FK, winner_cols_FK = torch.topk(pool_peaks, k, dim=1)
        best_seq_ids_FK = pool_seq_ids[feature_rows_F1, winner_cols_FK]
        best_acts_FKL = pool_acts[feature_rows_F1, winner_cols_FK]
        best_tokens_FKL = pool_tokens[feature_rows_F1, winner_cols_FK]

    return best_tokens_FKL, best_acts_FKL


In [7]:
# NOTE: This is also available in interp_utils.py

# One index per row of the SAE decoder matrix.
all_feature_indices = torch.arange(sae.W_dec.shape[0])

max_tokens_FKL, max_activations_FKL = get_max_activating_prompts(
    model=model,
    tokenizer=tokenizer,
    submodule=submodule,
    tokenized_inputs_bL=batched_tokens,
    dim_indices=all_feature_indices,
    batch_size=batch_size,
    dictionary=sae,
    context_length=context_length,
    k=30,  # or pass as a parameter if you want
)

100%|██████████| 489/489 [04:03<00:00,  2.01it/s]


In [8]:
# Both results are laid out as [feature, top-k slot, position].
for result_FKL in (max_activations_FKL, max_tokens_FKL):
    print(result_FKL.shape)

torch.Size([16384, 30, 128])
torch.Size([16384, 30, 128])


In [None]:
from circuitsvis.activations import text_neuron_activations
import gc
from IPython.display import clear_output, display


def _list_decode(x: torch.Tensor):
    """
    Turn token ids into per-token strings using the notebook-level `tokenizer`.

    Args:
        x: Token ids, either one sequence (1-D) or a stack of sequences (2-D).

    Returns:
        One list per sequence holding the output of
        `tokenizer.batch_decode` for that sequence's ids (special tokens kept).
        A 1-D input yields a single-element outer list.
    """
    rank = x.ndim
    assert rank == 1 or rank == 2

    # Promote a lone sequence to a batch of one so both cases share a path.
    rows = x.unsqueeze(0) if rank == 1 else x

    decoded_rows = []
    for row_ids in rows.tolist():
        # batch_decode receives a flat list of ids, i.e. one entry per token.
        decoded_rows.append(tokenizer.batch_decode(row_ids, skip_special_tokens=False))
    return decoded_rows


def create_html_activations(
    selected_tokens_FKL: torch.Tensor,
    selected_activations_FKL: torch.Tensor,
    num_display: int = 10,
    k: int = 5,
) -> list:
    """
    Build one circuitsvis rendering per feature, showing its top-k sequences.

    Args:
        selected_tokens_FKL: Token ids indexed as [feature, example, position].
        selected_activations_FKL: Activations with the same layout as the tokens.
        num_display: Number of leading features to render. Clamped to the number
            of features actually present.
        k: Number of examples rendered per feature. Clamped to the number of
            examples actually present.

    Returns:
        A list with one `text_neuron_activations` result per rendered feature,
        in feature order.
    """
    # Never index past what the inputs actually contain.
    num_features = min(num_display, selected_tokens_FKL.shape[0])
    num_examples = min(k, selected_tokens_FKL.shape[1])

    all_html_activations = []

    for feature_idx in range(num_features):
        # One [L, 1, 1] tensor per example; the trailing singleton axes give the
        # 3-D per-example layout that text_neuron_activations takes (see the
        # circuitsvis docs).
        activations_KL11 = [
            selected_activations_FKL[feature_idx, example_idx, :, None, None]
            for example_idx in range(num_examples)
        ]

        # Decode exactly the same examples so tokens and activations line up.
        token_strs_KL = _list_decode(selected_tokens_FKL[feature_idx, :num_examples])

        for example_strs in token_strs_KL:
            # Relabel a leading BOS token. The replacement string is kept as in
            # the original code (presumably chosen so it is not treated as
            # markup -- TODO confirm).
            if example_strs and (
                "<s>" in example_strs[0] or "<bos>" in example_strs[0]
            ):
                example_strs[0] = "BOS>"

        all_html_activations.append(
            text_neuron_activations(token_strs_KL, activations_KL11)
        )

    return all_html_activations

# Features to visualise, given as indices into the feature axis of the results.
top_k_ids = torch.tensor([1000, 10004])

clear_output(wait=True)
gc.collect()

# Index both result tensors with the same CPU copy of the feature ids.
feature_ids_cpu = top_k_ids.cpu()
html_activations = create_html_activations(
    max_tokens_FKL[feature_ids_cpu],
    max_activations_FKL[feature_ids_cpu],
    num_display=len(top_k_ids),
)


torch.Size([2, 30, 128])
torch.Size([2, 30, 128])


In [13]:
# NOTE: Over remote SSH, rendering the HTML activations is often slow or buggy once they cover many thousands of tokens.
# In that case I save the HTML activations with pickle and then either (a) display them locally or (b) open them in a Jupyter browser session on the remote machine.

# Show the rendering at index 1, i.e. the second feature in `top_k_ids`.
display(html_activations[1])