### Import libraries

In [1]:
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
# from google import genai
# from google.genai import types
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
# from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

### Load models

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# device = 'cpu'

cuda


In [3]:
sae = SAE.from_pretrained(
    release="anhtu77/sae-tiny-stories-1L-21M",
    sae_id="sae_ex32",
    device=device,
)[0]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [4]:
# Load the model from huggingface hub
model = HookedSAETransformer.from_pretrained(
    # "roneneldan/TinyStories-Instruct-28M",
    "tiny-stories-1L-21M",
    # "gpt2-small",
    device=device,
)  # This will wrap huggingface models and has lots of nice utilities.
# Print out the model architecture
# print(model)

# Optionally, inspect the registered hooks in the model
print("\nRegistered hooks:")
print(list(model.hook_dict.keys()))



Loaded pretrained model tiny-stories-1L-21M into HookedTransformer

Registered hooks:
['hook_embed', 'hook_pos_embed', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_result', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_attn_in', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.hook_mlp_in', 'blocks.0.hook_attn_out', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']


In [5]:
# # here we use generate to get 10 completeions with temperature 1. Feel free to play with the prompt to make it more interesting.
# for i in range(5):
#     display(
#         model.generate(
#             "I am a slave, my skin color is",
#             stop_at_eos=False,  # avoids a bug on MPS
#             temperature=0.1,
#             verbose=False,
#             max_new_tokens=10,
#         )
#     )

In [None]:
# from transformer_lens.utils import test_prompt

# # Test the model with a prompt
# test_prompt(
#     "Once upon a time, there was a little girl named Lily. She lived in a big, happy little town. On her big adventure,",
#     " Lily",
#     model,
#     prepend_space_to_answer=False,
# )

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' town', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 21.00 Prob: 76.18% Token: | she|
Top 1th token. Logit: 18.82 Prob:  8.62% Token: | Lily|
Top 2th token. Logit: 18.16 Prob:  4.45% Token: | there|
Top 3th token. Logit: 17.00 Prob:  1.39% Token: | the|
Top 4th token. Logit: 16.76 Prob:  1.10% Token: | her|
Top 5th token. Logit: 16.61 Prob:  0.94% Token: | all|
Top 6th token. Logit: 16.56 Prob:  0.90% Token: | everyone|
Top 7th token. Logit: 16.04 Prob:  0.53% Token: | things|
Top 8th token. Logit: 16.04 Prob:  0.53% Token: | they|
Top 9th token. Logit: 16.03 Prob:  0.53% Token: | people|


### Logit lens

In [5]:
# Let's start by getting the top 10 logits for each feature
projection_onto_unembed = sae.W_dec @ model.W_U


# get the top 10 logits.
vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)

# get 10 random features
random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,))

# Show the top 10 logits promoted by those features
top_10_logits_df = pd.DataFrame(
    [model.to_str_tokens(i) for i in inds[random_indices]],
    index=random_indices.tolist(),
).T
top_10_logits_df

Unnamed: 0,29362,10235,27700,20744,29533,23730,5573,10167,13510,2605
0,care,fright,again,lot,repay,protective,blocking,Restore,Anne,they
1,mind,surprise,afterward,tattoo,motivate,territorial,useless,She,Tara,he
2,like,excitement,forever,blob,reduce,cruel,shedding,They,Sara,she
3,wear,delight,who,spac,ensure,precious,usable,Their,Ann,the
4,recognize,joy,now,funnel,earn,annoyed,murky,Support,Alice,Lawrence
5,notice,hesitate,.,threat,prevent,stressed,wary,Mom,Mara,Miguel
6,feel,action,with,ghost,protect,mad,resisting,Gener,Abby,it
7,know,anticipation,lier,wizard,prove,strict,scor,But,Clara,Dennis
8,agree,shock,today,monster,commemorate,unhappy,dehyd,Buck,Anna,Paige
9,understand,despair,".""",begg,inspire,cute,struggling,Fix,Daisy,Ginny


### ActivationStore

In [6]:
# instantiate an object to hold activations from a dataset
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=2,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

In [7]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]


# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=model):
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [
        [f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens
    ]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix) : p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(
                    str_tokens[b][p + 1 : min(tokens.shape[1] - 1, p + 1 + len_suffix)]
                )
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(
        dict(
            str_tokens=list_flatten(str_tokens),
            unique_token=list_flatten(unique_token),
            context=context,
            prompt=prompt,
            pos=pos,
            label=label,
        )
    )

### Max activating examples

In [8]:
def get_k_largest_indices(x: Float[Tensor, "batch seq"], k: int, buffer: int = 0) -> Int[Tensor, "k 2"]:
    """
    The indices of the top k elements in the input tensor, i.e. output[i, :] is the (batch, seqpos) value of the i-th
    largest element in x.

    Won't choose any elements within `buffer` from the start or end of their sequence.
    """
    if buffer > 0:
        x = x[:, buffer:-buffer]
    indices = x.flatten().topk(k=k).indices
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer
    return torch.stack((rows, cols), dim=1)


x = torch.arange(40, device=device).reshape((2, 20))
x[0, 10] += 50  # 2nd highest value
x[0, 11] += 100  # highest value
x[1, 1] += 150  # not inside buffer (it's less than 3 from the start of the sequence)
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 11], [0, 10]]


def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Indexes into `x` with `indices` (which should have come from the `get_k_largest_indices` function), and takes a
    +-buffer range around each indexed element. If `indices` are less than `buffer` away from the start of a sequence
    then we just take the first `2*buffer+1` elems (same for at the end of a sequence).

    If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is not None:
        rows = einops.repeat(rows, "k -> k buffer", buffer=buffer * 2 + 1)
        cols[cols < buffer] = buffer
        cols[cols > x.size(1) - buffer - 1] = x.size(1) - buffer - 1
        cols = einops.repeat(cols, "k -> k buffer", buffer=buffer * 2 + 1) + torch.arange(
            -buffer, buffer + 1, device=cols.device
        )
    return x[rows, cols]


x_top_values_with_context = index_with_buffer(x, top_indices, buffer=3)
assert x_top_values_with_context[0].tolist() == [8, 9, 10 + 50, 11 + 100, 12, 13, 14]  # highest value in the middle
assert x_top_values_with_context[1].tolist() == [7, 8, 9, 10 + 50, 11 + 100, 12, 13]  # 2nd highest value in the middle


def display_top_seqs(data: list[tuple[float, list[str], int]]):
    """
    Given a list of (activation: float, str_toks: list[str], seq_pos: int), displays a table of these sequences, with
    the relevant token highlighted.

    We also turn newlines into "\\n", and remove unknown tokens � (usually weird quotation marks) for readability.
    """
    table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
    for act, str_toks, seq_pos in data:
        formatted_seq = (
            "".join([f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok for i, str_tok in enumerate(str_toks)])
            .replace("�", "")
            .replace("\n", "↵")
        )
        table.add_row(f"{act:.3f}", repr(formatted_seq))
    rprint(table)


example_data = [
    (0.5, [" one", " two", " three"], 0),
    (1.5, [" one", " two", " three"], 1),
    (2.5, [" one", " two", " three"], 2),
]
display_top_seqs(example_data)

In [9]:
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
    """
    Returns the max activating examples across a number of batches from the activations store.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"

    # Create list to store the top k activations for each batch. Once we're done,
    # we'll filter this to only contain the top k over all batches
    data = []

    for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples"):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name][..., latent_idx]

        # Get largest indices, get the corresponding max acts, and get the surrounding indices
        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        tokens_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
        str_toks = [model.to_str_tokens(toks) for toks in tokens_with_buffer]
        top_acts = index_with_buffer(acts, k_largest_indices).tolist()
        data.extend(list(zip(top_acts, str_toks, [buffer] * len(str_toks))))

    return sorted(data, key=lambda x: x[0], reverse=True)[:k]

In [12]:
# from tqdm import tqdm
# # finding max activating examples is a bit harder. To do this we need to calculate feature activations for a large number of tokens
# feature_list = torch.randint(0, sae.cfg.d_sae, (100,))
# examples_found = 0
# all_fired_tokens = []
# all_feature_acts = []
# all_reconstructions = []
# all_token_dfs = []

# total_batches = 100
# batch_size_prompts = activation_store.store_batch_size_prompts
# batch_size_tokens = activation_store.context_size * batch_size_prompts
# pbar = tqdm(range(total_batches))
# for i in pbar:
#     tokens = activation_store.get_batch_tokens()
#     tokens_df = make_token_df(tokens)
#     tokens_df["batch"] = i

#     flat_tokens = tokens.flatten()

#     _, cache = model.run_with_cache(
#         tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name]
#     )
#     sae_in = cache[sae.cfg.hook_name]
#     feature_acts = sae.encode(sae_in).squeeze()

#     feature_acts = feature_acts.flatten(0, 1)
#     fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
#     fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
#     reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

#     token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
#     all_token_dfs.append(token_df)
#     all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
#     all_fired_tokens.append(fired_tokens)
#     all_reconstructions.append(reconstruction)

#     examples_found += len(fired_tokens)
#     # print(f"Examples found: {examples_found}")
#     # update description
#     pbar.set_description(f"Examples found: {examples_found}")

# # flatten the list of lists
# all_token_dfs = pd.concat(all_token_dfs)
# all_fired_tokens = list_flatten(all_fired_tokens)
# all_reconstructions = torch.cat(all_reconstructions)
# all_feature_acts = torch.cat(all_feature_acts)



In [13]:
# # Get the indices of the maximum activations for feature 0
# feature_0_activations = all_feature_acts[:, 0]
# max_indices = torch.topk(feature_0_activations, k=5).indices  # Top 5 examples

# # Extract the corresponding rows from all_token_dfs
# max_activation_examples = all_token_dfs.iloc[max_indices.cpu().numpy()]

# # Print the examples
# print(max_activation_examples[['str_tokens', 'context']])

### Autointerp

In [10]:
def create_prompt(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
    """
    Returns the system, user & assistant prompts for autointerp.
    """

    data = fetch_max_activating_examples(model, sae, act_store, latent_idx, total_batches, k, buffer)
    str_formatted_examples = "\n".join(
        f"{i+1}. {''.join(f'<<{tok}>>' if j == buffer else tok for j, tok in enumerate(seq[1]))}"
        for i, seq in enumerate(data)
    )
    return {
        "system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
        "user": f"""The activating documents are given below:\n\n{str_formatted_examples}""",
        "assistant": "this neuron fires on",
    }

In [11]:
def get_autointerp_explanation(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
    n_completions: int = 1,
) -> list[str]:
    """
    Queries OpenAI's API using prompts returned from `create_prompt`, and returns
    a list of the completions.
    """
    client = OpenAI(
        api_key=os.getenv("GENAI_API_KEY"),
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
    )

    prompts = create_prompt(model, sae, act_store, latent_idx, total_batches, k, buffer)

    result = client.chat.completions.create(
        model="gemini-2.0-flash",
        messages=[
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
            {"role": "assistant", "content": prompts["assistant"]},
        ],
        n=n_completions,
        max_tokens=50,
        stream=False,
    )
    return [choice.message.content for choice in result.choices]

completions = get_autointerp_explanation(model, sae, activation_store, latent_idx=13510, n_completions=5)
for i, completion in enumerate(completions):
    print(f"Completion {i+1}: {completion!r}")

Computing activations for max activating examples:   0%|          | 0/100 [00:00<?, ?it/s]

Completion 1: ' end punctuation in quotes or sentences.\n'
Completion 2: ' periods at the end of sentences \n'
Completion 3: ' periods followed by double quotes or question marks followed by double quotes\n'
Completion 4: ' ending punctuation such as periods or question marks\n'
Completion 5: ' periods that end a sentence'
