## Imports & Installs

In [1]:
from sae_lens import SAE

sae = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
    device = "cuda"
)

  from .autonotebook import tqdm as notebook_tqdm
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [2]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

In [3]:
# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading

In [4]:
COLAB = False

## Set Up

In [5]:
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [6]:
def display_vis_inline(filename: str, height: int = 850):
    """
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    """
    if not (COLAB):
        webbrowser.open(filename)

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(
            PORT, path=f"/{filename}", height=height, cache_in_notebook=True
        )

        PORT += 1

# Loading a pretrained Sparse Autoencoder

Below we load a Transformerlens model, a pretrained SAE and a dataset from huggingface.

In [7]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",  # won't always be a hook point
    device=device,
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

# (1) Tokenization: Converts raw text into token IDs
# (2) Concatenation: Combines token IDs into a single tensor (one massive long stream of tokens)
# (3) Chunking: Splits the tensor into smaller chunks

token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.metadata.context_size,
    add_bos_token=sae.cfg.metadata.prepend_bos,
)

## Basic Analysis

Let's check some basic stats on this SAE in order to see how some basic functionality in the codebase works.

We'll calculate:

- L0 (the number of features that fire per activation)
    - Refers to the number of non-zero features that are active for a given input.
    - Sparsity Metric: It is the primary way to measure "sparsity." 
    - "Features that Fire": When the code calculates L0, it is counting how many features have an activation value greater than zero.
    - An SAE has thousands of possible features (e.g., 32,768), but for any single token (like the word " the"), only a tiny handful should be active to represent it.
    - Interpretation:
        - A high L0 value indicates that many features are active, which can be a sign of overfitting or noise.
        - A low L0 value indicates that only a few features are active, which is a sign of sparsity.
- The cross entropy loss when the output of the SAE is used in place of the activations.
    - The Baseline (Original Loss): Use the model normally. It processes the text and tries to predict the next token. If it is a good model, it assigns high probability to the correct next token, resulting in low Cross Entropy Loss.
    - The Intervention (Reconstruction Loss): 
        - Take the internal activations (what the model is "thinking" at a specific layer).
        - Pass them through the SAE (Compression -> Sparse Features -> Decompression).
        - Swap the original activations with the SAE's imperfect reconstructed versions.
        - Force the model to continue computing using this reconstructed data.
    - The Comparison
        - If the Reconstruction Loss is effectively the same as the Original Loss, it means the SAE successfully preserved all the "important" information the model needed to predict the next token.
        - If the Loss spikes (gets much higher), it means the SAE "broke" the model's chain of thought, implying it failed to capture critical features.

### L0 Test and Reconstruction Test

In [9]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    # returns a sparse list of features that activated
    feature_acts = sae.encode(cache[sae.cfg.metadata.hook_name])
    # turns the sparse list of features into original dense activations
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    # Slices the data to ignore the first token (position 0) -- First token is usually the "BOS" (Beginning of Sequence)
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    # prints the average sparsity
    print("average l0", l0.mean().item())
    # Draws the a graph shoing distribution of sparsity
    l0_values = l0.flatten().cpu().numpy()
    fig = px.histogram(
        x=l0_values,
        labels={
            "x": "Number of SAE features active per token (L0 count)",
            "y": "Token positions in sampled batch",
        },
        title="Distribution of SAE feature sparsity (L0)",
    )
    fig.show()

average l0 68.78986358642578


Note that while the mean L0 is 64, it varies with the specific activation.
In the histogram above, the x-axis shows the L0 count (number of non-zero SAE features per token position after dropping the BOS token) and the y-axis shows how many token positions fall into each bin.

To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens.

In [10]:
from transformer_lens import utils
from functools import partial


# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out


def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)


print("Orig", model(batch_tokens, return_type="loss").item()) # Calculates the original loss
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.metadata.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
) # Calculates the loss after reconstruction
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.metadata.hook_name, zero_abl_hook)],
    ).item(),
) # Calculates the loss after zeroing out the activations

Orig 3.562199592590332
reconstr 3.764155387878418
Zero 11.146590232849121


## Specific Capability Test

Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks.

In [15]:
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary" # Indirect Object Indentification (IOI) task
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True) # What it does: Runs the standard model on this prompt and prints a report.

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt) # Converts the string prompt into integer tokens (e.g., [50256, 1204, ...]) so passing it to the model later is faster/easier.
sae_out = sae(cache[sae.cfg.metadata.hook_name]) # Passes the specific activations from the hook point (e.g., "Block 8") through the SAE.

# A function that ignores the incoming activations and blindly returns sae_out (our SAE reconstruction). Used to "hot-swap" the brain state.
def reconstr_hook(activations, hook, sae_out):
    return sae_out # replaced activations with SAE reconstruction

# A function that replaces everything with zeros (lobotomy).
def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)

# The exact address of the layer we are modifying (e.g., 'blocks.8.hook_resid_pre').
hook_name = sae.cfg.metadata.hook_name

print("Orig", model(tokens, return_type="loss").item())

# Calculating the loss when we swap in the SAE reconstruction.
# We want to see if the SAE intervention causes the loss to spike for this specific task.
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)

print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_name, zero_abl_hook)],
    ).item(),
)


with model.hooks(
    fwd_hooks=[
        (
            hook_name,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.19 Prob: 69.93% Token: | Mary|
Top 1th token. Logit: 15.82 Prob:  6.49% Token: | them|
Top 2th token. Logit: 15.48 Prob:  4.66% Token: | the|
Top 3th token. Logit: 14.93 Prob:  2.66% Token: | his|
Top 4th token. Logit: 14.86 Prob:  2.49% Token: | John|
Top 5th token. Logit: 14.12 Prob:  1.19% Token: | her|
Top 6th token. Logit: 13.99 Prob:  1.04% Token: | their|
Top 7th token. Logit: 13.70 Prob:  0.78% Token: | a|
Top 8th token. Logit: 13.53 Prob:  0.66% Token: | him|
Top 9th token. Logit: 13.39 Prob:  0.57% Token: | Mrs|


RuntimeError: Inference tensors cannot be saved for backward. Please do not use Tensors created in inference mode in computation tracked by autograd. To work around this, you can make a clone to get a normal tensor and use it in autograd, or use `torch.no_grad()` instead of `torch.inference_mode()`.

# Generating Feature Interfaces

Feature dashboards are an important part of SAE Evaluation. They work by:

- 1. Collecting feature activations over a larger number of examples.
- 2. Aggregating feature specific statistics (such as max activating examples).
- 3. Representing that information in a standardized way

For our feature visualizations, we will use a separate library called SAEDashboard.


In [12]:
# Make sure to install sae-dashboard if not running in colab
# pip install sae-dashboard
# Note: this cell may not work until sae-dashboard is updated to work with the latest version of sae-lens

test_feature_idx_gpt = list(range(10)) + [14057] # first ten features + one random feature

from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner


feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=64,
    minibatch_size_tokens=256,
    verbose=True,
    device=device,
)

# Collect activation stats for selected SAE features on a token sample for dashboard plots.
visualization_data_gpt = SaeVisRunner(
    feature_vis_config_gpt
).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:10000]["tokens"],  # type: ignore
)
# SaeVisData.create(
#     encoder=sae,
#     model=model, # type: ignore
#     tokens=token_dataset[:10000]["tokens"],  # type: ignore
#     cfg=feature_vis_config_gpt,
# )


Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)


Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)


Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an

Forward passes to cache data for vis: 100%|██████████| 40/40 [00:12<00:00,  3.19it/s]
Extracting vis data from cached data: 100%|██████████| 11/11 [00:12<00:00,  1.14s/it]


In [13]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

filename = f"demo_feature_dashboards.html"
save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)

Saving feature-centric vis:   0%|          | 0/11 [00:00<?, ?it/s]

Saving feature-centric vis: 100%|██████████| 11/11 [00:00<00:00, 35.44it/s]


In [14]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# this function should open
neuronpedia_quick_list = get_neuronpedia_quick_list(sae, test_feature_idx_gpt)

if COLAB:
    # If you're on colab, click the link below
    print(neuronpedia_quick_list)