# Using a Trained SAE

This demo shows how to use a trained sparse autoencoder (SAE) for mechanistic interpretability
research.

## Setup

### Imports

In [1]:
# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

# Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    %load_ext autoreload
    %autoreload 2

In [10]:
from sparse_autoencoder import (
    ActivationResamplerHyperparameters,
    AutoencoderHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
    PipelineHyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    SweepConfig,
    sweep,
    SparseAutoencoder,
    TensorActivationStore,
    Axis,
    store_activations_hook,
)
from functools import partial
import os

from einops import einsum, rearrange
from jaxtyping import Float, Int
import pandas as pd
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt
from transformers import PreTrainedTokenizerBase

os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Loading the SAE

In [3]:
gpt2 = HookedTransformer.from_pretrained("gpt2")
tokenizer: PreTrainedTokenizerBase = gpt2.tokenizer # type: ignore

Loaded pretrained model gpt2 into HookedTransformer


In [4]:
sae = SparseAutoencoder.load_from_wandb(wandb_artifact_name="alan-cooney/sparse-autoencoder/silvery-sweep-7_final:v0")

sae_cache_names = [f"blocks.{layer}.hook_mlp_out" for layer in range(12)]

sae

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Downloading large artifact silvery-sweep-7_final:v0, 216.18MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.6


SparseAutoencoder(
  (pre_encoder_bias): TiedBias(position=pre_encoder)
  (encoder): LinearEncoder(
    input_features=768, learnt_features=3072, n_components=12
    (activation_function): ReLU()
  )
  (decoder): UnitNormDecoder(learnt_features=3072, decoded_features=768, n_components=12)
  (post_decoder_bias): TiedBias(position=post_decoder)
)
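
The module summary above lists a tied bias before the encoder, a linear encoder with ReLU, a unit-norm decoder, and a tied bias after the decoder. The following is a minimal sketch of what such a forward pass might look like on toy tensors. It is not the library's implementation: the parameter names, the random initialisation, and the assumption that the tied bias is subtracted before encoding and added back after decoding are all illustrative.

```python
import torch

torch.manual_seed(0)

# Toy sizes (the SAE above uses 768 input features, 3072 learnt features, 12 components)
n_batch, n_components, n_input, n_learnt = 4, 3, 8, 32

# Hypothetical parameters, randomly initialised purely for illustration
tied_bias = torch.randn(n_components, n_input)
encoder_weight = torch.randn(n_components, n_learnt, n_input)
encoder_bias = torch.zeros(n_components, n_learnt)
decoder_weight = torch.randn(n_components, n_input, n_learnt)
# Normalise each decoder column (one per learnt feature) to unit norm
decoder_weight = decoder_weight / decoder_weight.norm(dim=1, keepdim=True)


def toy_sae_forward(x):
    """Return (learned_activations, reconstruction) for x of shape [batch, component, input]."""
    centred = x - tied_bias  # assumed behaviour of the pre-encoder tied bias
    pre_activation = torch.einsum("bci,cli->bcl", centred, encoder_weight) + encoder_bias
    learned = torch.relu(pre_activation)  # non-negative hidden activations
    decoded = torch.einsum("bcl,cil->bci", learned, decoder_weight)
    return learned, decoded + tied_bias  # assumed behaviour of the post-decoder tied bias


x = torch.randn(n_batch, n_components, n_input)
learned_activations, reconstruction = toy_sae_forward(x)
```

The learned activations have one extra axis compared to a single-layer SAE because each component (here, each layer's MLP output) has its own encoder and decoder weights.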

Next we'll create a function that gets the SAE learned activations from a forward pass of the
source model followed by the SAE.

In [11]:
def forward_pass_sae_activations(
    sparse_autoencoder: SparseAutoencoder,
    source_model: HookedTransformer,
    input_ids: Int[Tensor, Axis.names(Axis.BATCH, Axis.POSITION)],
    cache_names: list[str],
) -> Float[Tensor, Axis.names(Axis.POSITION, Axis.COMPONENT, Axis.LEARNT_FEATURE)]:
    """Forward pass that gets the hidden SAE activations."""
    # Create the store for source model activations
    number_activations: int = input_ids.numel()
    store = TensorActivationStore(
        number_activations,
        n_neurons=sparse_autoencoder.config.n_input_features,
        n_components=sparse_autoencoder.config.n_components or 1,
    )

    # Add the hook to the model (will automatically store the activations every time the model
    # runs)
    source_model.remove_all_hook_fns()
    for component_idx, cache_name in enumerate(cache_names):
        hook = partial(store_activations_hook, store=store, component_idx=component_idx)
        source_model.add_hook(cache_name, hook)

    # Source model forward pass (to fill the store)
    source_model(input_ids)

    # SAE forward pass (to get SAE hidden activations)
    sae_res = sparse_autoencoder.forward(store[:])

    # Remove the hook
    source_model.remove_all_hook_fns()

    return sae_res.learned_activations
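
The add-hook, run, remove-hook pattern in the function above can be illustrated with plain PyTorch forward hooks. This is a minimal sketch on a hypothetical toy model, not the `store_activations_hook` implementation; the `stored` dictionary and the `make_store_hook` helper are names invented for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a source model: two small linear layers
model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 6))
stored = {}


def make_store_hook(name):
    """Build a forward hook that saves a copy of the module output under `name`."""

    def hook(module, inputs, output):
        stored[name] = output.detach().clone()

    return hook


# Register one hook per layer of interest, keeping the handles for later removal
handles = [
    model[0].register_forward_hook(make_store_hook("layer_0")),
    model[2].register_forward_hook(make_store_hook("layer_2")),
]

# A single forward pass fills the store
with torch.no_grad():
    output = model(torch.randn(5, 4))

# Remove the hooks so later forward passes are unaffected
for handle in handles:
    handle.remove()
```

Removing the hooks afterwards matters for the same reason as in the function above: otherwise every later forward pass would keep writing into the store.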

## Example Analysis: Identifying MLP enrichment of "Michael Jordan" with facts about him

### Introduction - testing the model's capabilities

[] found that MLP layers enrich the subject "Michael Jordan" with many facts that the model knows
about Michael Jordan. It was theorised that these facts are stored in MLP layers, but due to
polysemanticity it was not possible to identify specific neurons that were involved in this process
(i.e. we can't find a "basketball" or "man" MLP neuron that fires on the Jordan token). In this
example we'll use a trained SAE to try to locate specific neurons that are involved in this task.

Here are a couple of examples of the model recalling such facts:

In [6]:
test_prompt("Michael Jordan plays the sport of", answer="basketball", model=gpt2)

Tokenized prompt: ['<|endoftext|>', 'Michael', ' Jordan', ' plays', ' the', ' sport', ' of']
Tokenized answer: [' basketball']


Top 0th token. Logit: 16.17 Prob: 50.01% Token: | basketball|
Top 1th token. Logit: 13.77 Prob:  4.53% Token: | football|
Top 2th token. Logit: 13.72 Prob:  4.30% Token: | golf|
Top 3th token. Logit: 13.08 Prob:  2.27% Token: | tennis|
Top 4th token. Logit: 12.90 Prob:  1.90% Token: | soccer|
Top 5th token. Logit: 12.89 Prob:  1.87% Token: | hockey|
Top 6th token. Logit: 12.55 Prob:  1.33% Token: | boxing|
Top 7th token. Logit: 12.38 Prob:  1.13% Token: | sports|
Top 8th token. Logit: 12.24 Prob:  0.98% Token: | the|
Top 9th token. Logit: 12.16 Prob:  0.90% Token: | baseball|


In [7]:
test_prompt("Michael Jordan is a", answer="man", model=gpt2)

Tokenized prompt: ['<|endoftext|>', 'Michael', ' Jordan', ' is', ' a']
Tokenized answer: [' man']


Top 0th token. Logit: 13.18 Prob:  3.11% Token: | man|
Top 1th token. Logit: 13.14 Prob:  3.00% Token: | former|
Top 2th token. Logit: 12.95 Prob:  2.47% Token: | great|
Top 3th token. Logit: 12.64 Prob:  1.82% Token: | big|
Top 4th token. Logit: 12.27 Prob:  1.25% Token: | very|
Top 5th token. Logit: 12.14 Prob:  1.10% Token: | basketball|
Top 6th token. Logit: 12.07 Prob:  1.03% Token: | good|
Top 7th token. Logit: 11.94 Prob:  0.90% Token: | legend|
Top 8th token. Logit: 11.85 Prob:  0.82% Token: | player|
Top 9th token. Logit: 11.84 Prob:  0.82% Token: | star|
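
As a quick sanity check on how to read these outputs: the probabilities are a softmax over the logits, so the ratio of two tokens' probabilities should equal the exponential of their logit difference. The sketch below checks this using the rounded values printed for " basketball" and " football" in the first example, so agreement is only expected up to rounding.

```python
import math

# Values copied from the first test_prompt output above
logit_basketball, prob_basketball = 16.17, 0.5001
logit_football, prob_football = 13.77, 0.0453

# Under softmax, p_a / p_b = exp(logit_a - logit_b)
ratio_from_logits = math.exp(logit_basketball - logit_football)
ratio_from_probs = prob_basketball / prob_football

print(f"{ratio_from_logits:.2f} vs {ratio_from_probs:.2f}")
```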


### Find Monosemantic Neurons

The simple way to do this is to run some forward passes of the model (with and without the training
data), storing the activations of the MLP output layers. We can then run these activations through
the SAE to get the SAE hidden layer outputs. Finally, we can find which SAE neurons activate for a
given proportion of inputs.

In [12]:
prompt = "Michael Jordan"
input_ids: Int[Tensor, Axis.names(Axis.BATCH, Axis.POSITION)] = tokenizer.encode(prompt, return_tensors="pt")
input_ids.shape

torch.Size([1, 2])
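
The last step described above, finding neurons that activate for a given proportion of inputs, can be sketched on a toy activation matrix. The tensor values, the `threshold` of 0.5, and the variable names here are assumptions chosen for illustration; real inputs would be the learned activations returned by `forward_pass_sae_activations`.

```python
import torch

# Toy learned activations of shape [n_inputs, n_features]
n_inputs, n_features = 6, 5
learned = torch.zeros(n_inputs, n_features)
learned[:, 0] = 1.0   # feature 0 fires on every input
learned[:4, 1] = 0.5  # feature 1 fires on 4 of 6 inputs
learned[0, 3] = 2.0   # feature 3 fires on 1 of 6 inputs

# Fraction of inputs on which each feature has a non-zero activation
firing_fraction = (learned > 0).float().mean(dim=0)

# Keep the features that fire on at least `threshold` of the inputs
threshold = 0.5
frequent = torch.nonzero(firing_fraction >= threshold).flatten().tolist()
```

Here `frequent` contains features 0 and 1, since only those fire on at least half of the toy inputs.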

### Direct logit attribution (DLA)

In [13]:
last_token_sae_activation_cache = forward_pass_sae_activations(sae, gpt2, input_ids, sae_cache_names)[-1]

The cache here is of shape [component x sae_hidden]. To convert each neuron's activation into a
residual direction, we scale that neuron's decoder weight vector by its activation, which gives a
[component x sae_hidden x d_model] tensor:

In [14]:
neuron_residual_directions: Float[
    Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
] = einsum(
    last_token_sae_activation_cache,
    sae.decoder.weight,
    (
        f"{Axis.COMPONENT} {Axis.LEARNT_FEATURE} ,"
        f"{Axis.COMPONENT} {Axis.INPUT_OUTPUT_FEATURE} {Axis.LEARNT_FEATURE} "
        f"-> {Axis.COMPONENT} {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE}"
    ),
)

neuron_residual_directions.shape

torch.Size([12, 3072, 768])
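
To see what this tensor represents, the sketch below repeats the same contraction on toy tensors using `torch.einsum` (rather than `einops`) and checks that summing the per-neuron directions over the SAE hidden axis recovers the decoder's output before any bias is added. The shapes and random values are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

n_components, n_learnt, d_model = 2, 6, 4
activations = torch.relu(torch.randn(n_components, n_learnt))
decoder_weight = torch.randn(n_components, d_model, n_learnt)

# No axis is summed here: each neuron's decoder column is scaled by its activation
directions = torch.einsum("cl,cdl->cld", activations, decoder_weight)

# Summing over the learnt-feature axis gives the usual decoder matrix product
decoded = torch.einsum("cl,cdl->cd", activations, decoder_weight)
summed_directions = directions.sum(dim=1)
```

Keeping the hidden axis unsummed is what allows each neuron's contribution to be attributed separately in the steps that follow.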

In [18]:
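_res, cache = gpt2.run_with_cache(input_ids)

# Scale applied by the final layer norm at the last token position
ln_final_scale = cache["ln_final.hook_scale"][:, -1, :].to(device="cpu")
# Divide by that scale so each direction is expressed as the unembedding sees it
neuron_residual_directions_scaled = neuron_residual_directions / ln_final_scale
neuron_residual_directions_scaled.shape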

torch.Size([12, 3072, 768])
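
Dividing by the cached scale treats the final layer norm as a fixed linear map, which is what allows contributions to be attributed one at a time. Below is a minimal sketch of that idea on toy tensors. It assumes a layer norm with no learned weight or bias, and it mean-centres each contribution explicitly, which the cell above does not do; the names `parts` and `resid` are hypothetical.

```python
import torch

torch.manual_seed(0)

d_model, n_parts, eps = 16, 5, 1e-5

# Hypothetical contributions that sum to the final residual stream vector
parts = torch.randn(n_parts, d_model)
resid = parts.sum(dim=0)

# Layer norm of the full residual: centre, then divide by the scale
centred = resid - resid.mean()
scale = (centred.pow(2).mean() + eps).sqrt()
normed = centred / scale

# With the scale frozen, each centred contribution can be normalised on its own
parts_scaled = (parts - parts.mean(dim=-1, keepdim=True)) / scale
sum_of_scaled_parts = parts_scaled.sum(dim=0)

# Reference value from PyTorch's own layer norm (no affine parameters)
reference = torch.nn.functional.layer_norm(resid, (d_model,), eps=eps)
```

The scaled parts sum to the layer-normed residual, so projecting each one onto a token direction splits that token's logit into per-contribution terms.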

We'll also want the residual direction of the token "basketball" (obtained by unembedding it):

In [21]:
# Use the leading-space token, matching the tokenized answer [' basketball'] shown earlier
basketball_residual_direction = gpt2.tokens_to_residual_directions(" basketball").to("cpu")
basketball_residual_direction.shape

torch.Size([768])

In [24]:
dla = einsum(
    neuron_residual_directions_scaled,
    basketball_residual_direction,
    (
        f"{Axis.COMPONENT} {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} ,"
        f"{Axis.INPUT_OUTPUT_FEATURE} "
        f"-> {Axis.COMPONENT} {Axis.LEARNT_FEATURE}"
    ),
)
dla.shape

torch.Size([12, 3072])
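
The DLA tensor holds, for every component and SAE neuron, the dot product of that neuron's scaled residual direction with the token's residual direction. The sketch below reproduces this on toy tensors with `torch.einsum`, checks that the per-neuron values sum to the projection of the total direction, and shows one possible way to pick out the largest contributors. All shapes, values, and names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

n_components, n_learnt, d_model = 2, 6, 4
directions = torch.randn(n_components, n_learnt, d_model)
token_direction = torch.randn(d_model)  # hypothetical unembedding direction

# Per-neuron attribution: project each direction onto the token direction
dla = torch.einsum("cld,d->cl", directions, token_direction)

# Linearity: the attributions sum to the projection of the summed direction
total = directions.sum(dim=(0, 1)) @ token_direction

# Rank neurons by attribution and recover their (component, feature) indices
top_values, top_flat_idx = dla.flatten().topk(3)
top_component = torch.div(top_flat_idx, n_learnt, rounding_mode="floor")
top_feature = top_flat_idx % n_learnt
```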

In [23]:
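# Sum the attribution over SAE neurons to get one value per component (layer)
layer_dla = dla.sum(dim=1)
# Assumed completion (the original assignment was left unfinished): one row per layer
layer_dla_df = pd.DataFrame({"layer": range(len(layer_dla)), "dla": layer_dla.detach().cpu().numpy()})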

### Ablations

### Patching