# MATS Workshop October 2024: "Hands-On Tutorial with SAELens & Neuronpedia"

This notebook is the companion notebook to a session run for Neel Nanda's MATS scholars.

The slides for this session are [here](https://docs.google.com/presentation/d/13c-oTOz5n_VPIo9H_6RJiGgT75E1sSC1d0MYNfR07Q4/edit?usp=sharing). 

Much of this content is directly taken from the [ARENA tutorials](https://colab.research.google.com/drive/1ePkM8oBHIEZ2kcqAiA3waeAmz8RSdHmq#scrollTo=GQgKfUNekdPv). I've skipped some sections for speed and added some which are less code-y and more about using Neuronpedia. 

## Setup

If you want to install a fresh environment, the following commands might be useful:

```bash
conda create -n sae-lens-workshop python=3.12 -y
conda activate sae-lens-workshop
pip install sae-lens sae-dashboard ipykernel tabulate
```

If you think you are likely to want to edit SAE Lens, then using poetry to install the package is recommended.

In [None]:
# ! pip install sae-lens ipykernel tabulate plotly-express nbformat
! pip install git+https://github.com/callummcdougall/sae_vis.git@callum/v3
! pip install frozendict

In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-vis plotly-express nbformat
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

In [None]:
# For displaying sae-vis inline

import http.server
import socketserver
import threading
import webbrowser
from pathlib import Path  # needed for the `filename: Path` annotation below

from google.colab import output

PORT = 8010


def display_vis_inline(filename: Path, height: int = 850):
    """
    Displays the HTML files in Colab. Uses the global `PORT` variable defined above, so that each
    vis has a unique port without having to define a port within the function.
    """
    global PORT

    def serve(directory):
        os.chdir(directory)
        handler = http.server.SimpleHTTPRequestHandler
        with socketserver.TCPServer(("", PORT), handler) as httpd:
            print(f"Serving files from {directory} on port {PORT}")
            httpd.serve_forever()

    thread = threading.Thread(target=serve, args=("/content",))
    thread.start()

    filename = str(filename).split("/content")[-1]

    output.serve_kernel_port_as_iframe(
        PORT, path=filename, height=height, cache_in_notebook=True
    )

    PORT += 1

## Stage 1: Basics

### 1.1 Open Source SAEs

Core details:
- There are lots of open-source SAEs for many different models.
  - They vary in their architecture, training methods and other details. 
- A subset of them are supported by SAE Lens / Neuronpedia
  - This comes down to whether or not we have the code for doing the forward pass.
  - And whether or not we have the config for them.
- In SAELens, we group SAEs by HuggingFace repo + path within the repo.
  - "release" -> repo
  - "id" -> a path. 
  - "model" -> We store the model (t-lens string), but we don't necessarily need to do this. 
  - "neuronpedia_id" -> Neuronpedia has a different ID system, so we store its IDs here for ease of reference.
- We were previously considering storing information here for testing that they were running correctly but we'll probably back this out (eg: explained variance).

**Which SAEs should you use for your experiments?**
1. Smallest model that can do the task.
2. Try SAEs of different sizes depending on what features you're interested in.
3. Neuronpedia support is nice but not essential.
4. SAE performance should meet at least some minimal bar (we'll talk about this later)

Links:
- Source of ground truth: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
- SAE Lens Docs: https://jbloomaus.github.io/SAELens/sae_table/
- Neuronpedia: https://www.neuronpedia.org/

In [None]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from tabulate import tabulate

print(get_pretrained_saes_directory())

In [None]:
metadata_rows = [
    [data.model, data.release, data.repo_id, len(data.saes_map)] for data in get_pretrained_saes_directory().values()
]

# Print all SAE releases, sorted by base model
print(
    tabulate(
        sorted(metadata_rows, key=lambda x: x[0]),
        headers=["model", "release", "repo_id", "n_saes"],
        tablefmt="simple_outline",
    )
)

In [None]:
def format_value(value):
    return "{{{0!r}: {1!r}, ...}}".format(*next(iter(value.items()))) if isinstance(value, dict) else repr(value)

release = get_pretrained_saes_directory()["gpt2-small-res-jb"]

print(
    tabulate(
        [[k, format_value(v)] for k, v in release.__dict__.items()],
        headers=["Field", "Value"],
        tablefmt="simple_outline",
    )
)

In [None]:
data = [[id, path, release.neuronpedia_id[id]] for id, path in release.saes_map.items()]

print(
    tabulate(
        data,
        headers=["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"],
        tablefmt="simple_outline",
    )
)

### 1.2 Loading SAEs

- We can use from_pretrained to get HookedSAETransformer and SAEs (similar method on each).
- In order to work with SAEs, it's essential to know you are running them correctly:
  - This means that you're doing the forward pass correctly
  - That you're using the correct inputs. (eg: context_size / dataset it was trained on)
  - That you're using the correct model / running it correctly (eg: `from_pretrained_kwargs`)
- Much of the work that we do with SAE Lens is trying to store all of the information required to do this.
- Note that the information to reproduce the SAE training must be sourced from elsewhere (since people train SAEs all sorts of ways)

In [None]:
import torch 
from sae_lens import HookedSAETransformer, SAE

torch.set_grad_enabled(False)

gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

The `gpt2_sae` object is an instance of the `SAE` (Sparse Autoencoder) class. There are many different SAE architectures which may have different weights or activation functions. In order to simplify working with SAEs, SAELens handles most of this complexity for you. You can run the cell below to see each of the SAE config parameters for the one we'll be using.

<details>
<summary>Click to read a description of each of the SAE config parameters.</summary>

1. `architecture`: Specifies the type of SAE architecture being used, in this case, the standard architecture (encoder and decoder with hidden activations, as opposed to a gated SAE).
2. `d_in`: Defines the input dimension of the SAE, which is 768 in this configuration.
3. `d_sae`: Sets the dimension of the SAE's hidden layer, which is 24576 here. This represents the number of possible feature activations.
4. `activation_fn_str`: Specifies the activation function used in the SAE, which is ReLU in this case. TopK is another option that we will not cover here.
5. `apply_b_dec_to_input`: Determines whether to apply the decoder bias to the input, set to True here.
6. `finetuning_scaling_factor`: Indicates whether to use a scaling factor to weight initialization and the forward pass. This is not usually used and was introduced to support a [solution for shrinkage](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes).
7. `context_size`: Defines the size of the context window, which is 128 tokens in this case. It turns out SAEs trained on activations from short prompts [often don't perform well on longer prompts](https://www.lesswrong.com/posts/baJyjpktzmcmRfosq/stitching-saes-of-different-sizes).
8. `model_name`: Specifies the name of the model being used, which is 'gpt2-small' here. [This is a valid model name in TransformerLens](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html).
9. `hook_name`: Indicates the specific hook in the model where the SAE is applied.
10. `hook_layer`: Specifies the layer number where the hook is applied, which is layer 7 in this case.
11. `hook_head_index`: Defines which attention head to hook into; not relevant here since we are looking at a residual stream SAE.
12. `prepend_bos`: Determines whether to prepend the beginning-of-sequence token, set to True.
13. `dataset_path`: Specifies the path to the dataset used for training or evaluation. (Can be local or a huggingface dataset.)
14. `dataset_trust_remote_code`: Indicates whether to trust remote code (from HuggingFace) when loading the dataset, set to True.
15. `normalize_activations`: Specifies how to normalize activations, set to 'none' in this config.
16. `dtype`: Defines the data type for tensor operations, set to 32-bit floating point.
17. `device`: Specifies the computational device to use.
18. `sae_lens_training_version`: Indicates the version of SAE Lens used for training, set to None here.
19. `activation_fn_kwargs`: Allows for additional keyword arguments for the activation function. This would be used if e.g. the `activation_fn_str` was set to `topk`, so that `k` could be specified.

</details>

In [None]:
print(
    tabulate(
        gpt2_sae.cfg.__dict__.items(),
        headers=["name", "value"],
        tablefmt="simple_outline",
    )
)

Optional: Let's look at some examples for other SAEs (especially different architectures)

In [None]:
gpt2_sae_topk_open_ai, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-resid-post-v5-32k",
    sae_id="blocks.0.hook_resid_post",
    device=str(device),
)

print(
    tabulate(
        gpt2_sae_topk_open_ai.cfg.__dict__.items(),
        headers=["name", "value"],
        tablefmt="simple_outline",
    )
)

In [None]:
gemma_scope_16k_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_0/width_16k/canonical",
    device=str(device),
)

print(
    tabulate(
        gemma_scope_16k_sae.cfg.__dict__.items(),
        headers=["name", "value"],
        tablefmt="simple_outline",
    )
)

### 1.3 Running SAEs with SAE Lens

Core details:
- The suggested way to run SAEs is with HookedSAETransformer.
- You can render dashboards from Neuronpedia in your notebook and use our servers to test activations.

Key Links:
- ARENA Exercises: https://colab.research.google.com/drive/1ePkM8oBHIEZ2kcqAiA3waeAmz8RSdHmq 

In [None]:
from transformer_lens import utils

prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
utils.test_prompt(prompt, answer, gpt2)

# Test our prompt, to see what the model says
with gpt2.saes(saes=[gpt2_sae]):
    utils.test_prompt(prompt, answer, gpt2)

# Same thing, done in a different way
gpt2.add_sae(gpt2_sae)
utils.test_prompt(prompt, answer, gpt2)
gpt2.reset_saes()  # Remember to always do this!

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

In [None]:
_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

for name, param in cache.items():
    if "hook_sae" in name:
        print(f"{name:<43}: {tuple(param.shape)}")

In [None]:
import random 
from IPython.display import IFrame

def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    feature_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{feature_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    display(IFrame(url, width=width, height=height))


In [None]:
# Get top activations on final token
_, cache = gpt2.run_with_cache_with_saes(
    prompt,
    saes=[gpt2_sae],
    stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
)
sae_acts_post = cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]

# Plot line chart of feature activations
px.line(
    sae_acts_post.cpu().numpy(),
    title=f"Feature activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
    labels={"index": "Feature", "value": "Activation"},
    width=1000,
).update_layout(showlegend=False).show()

# Print the top 3 features, and inspect their dashboards
for act, ind in zip(*sae_acts_post.topk(3)):
    print(f"Feature {ind} had activation {act:.2f}")
    display_dashboard(feature_idx=ind)

### 1.4 Feature Dashboards and Neuronpedia (see slides)

(Might leave these challenges for after the session)

Latent Search Challenges:
1. Can you find a pokemon related latent using search via explanations? How about via inference?
2. Can you find a latent that fires between open brackets via explanations? How about via inference?
3. Can you find a latent that fires when the next token is likely to start with the letter P? 
4. Extra hard: In GemmaScope Res 16k layer 20, can you find the "cringe" feature? 

Dashboard Quality Challenges:
1. Can you find a dense feature? (firing > 2% of the time). What might it be tracking?
2. Can you find an example of a bad latent explanation? Why is it bad?

Redteaming Challenges:
1. Using the feature absorption app, can you demonstrate examples of feature absorption? Replicate this result just on Neuronpedia.
2. (hard) Using the [meta-SAEs app](https://metasae.streamlit.app/?page=Meta+Feature+Explorer&meta_feature=504), can you come up with any more features which might be subject to absorption?

## Stage 2: Toolkit

### 2.1 Training SAEs (Almost exactly the ARENA Content)

`SAELens` makes training your SAEs relatively straightforward (if you know what the config does!). The code for training is essentially:

```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

runner_cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(runner_cfg)
sae = runner.run()
```

The `LanguageModelSAERunnerConfig` class contains all the parameters necessary to specify how your SAE gets trained. This includes many of the parameters that went into the SAE config (in fact, this class contains a method `get_base_sae_cfg_dict()` which returns a dictionary that can be used to create the associated `SAEConfig` object). However, it also includes many other arguments which are specific to the training process itself. You can see the full set of config parameters in the source code for `LanguageModelSAERunnerConfig`, however for our purposes we can group them into ~7 main categories:

1. **Data generation** - everything to do with how the data we use to train the SAE is generated & batched. Recall that your SAEs are trained on the activations of a base TransformerLens model, so this includes things like the model name (which should point to a model [supported by TransformerLens](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)), hook point in the model, dataset path (which should reference a HuggingFace dataset), etc. We've also included `d_in` in this category since it's uniquely determined by the hook point you're training on.
2. **SAE architecture** - everything to do with the SAE architecture, which isn't implied by the data generation parameters. This includes things like `d_sae` (which you are allowed to use, although in practice we often specify `expansion_factor` and then `d_sae` is defined as `d_in * expansion_factor`), which activation function to use, how our weights are initialized, whether we subtract the decoder bias `b_dec` from the input, etc.
3. **Activations store** - everything to do with how the activations are stored and processed during training. Recall we used the `ActivationsStore` class earlier when we were generating the feature dashboards for our model. During training we also use an instance of this class to store and process batches of activations to feed into our SAE. We need to specify things like how many batches we want to store, how many prompts we want to process at once, etc.
4. **Training hyperparameters (standard)** - these are the standard kinds of parameters you'd expect to see in any other ML training loop: learning rate & learning rate scheduling, betas for Adam optimizer, etc.
5. **Training hyperparameters (SAE-specific)** - these are all the parameters which are specific to your SAE training. This means various coefficients like the $L_1$ penalty (as well as warmups for these coefficients), as well as things like the resampling protocol - more on this later. Certain other architectures (e.g. gated) might also come with additional parameters that need specifying.
6. **Logging / evals** - these control how frequently we log data to Weights & Biases, as well as how often we perform evaluations on our SAE (more on evals in the second half of this exercise set!). Remember that when we talk about evals during training, we're often talking about different kinds of evals than we perform post-training to compare different SAEs to each other (although there's certainly a lot of overlap).
7. **Misc** - this is a catchall for anything else you might want to specify, e.g. how often you save your model checkpoints, random seeds, device & dtype, etc.
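
As a small worked example of the `expansion_factor` relationship from category 2: the GPT-2 small SAE loaded earlier has `d_in = 768` and `d_sae = 24576`, which corresponds to an expansion factor of 32. The helper name below is hypothetical and only illustrates the arithmetic.

```python
def d_sae_from_expansion(d_in: int, expansion_factor: int) -> int:
    """Width of the SAE hidden layer implied by an expansion factor."""
    return d_in * expansion_factor


# Values taken from the GPT-2 small residual stream SAE config described earlier
d_in = 768
d_sae = 24576

# Recover the expansion factor, then check that it round-trips
expansion_factor = d_sae // d_in
print(expansion_factor)
print(d_sae_from_expansion(d_in, expansion_factor))
```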

#### Multi-GPU Training Runs


We inherit multi-GPU support from transformer_lens, so we pass arguments to the `model_from_pretrained_kwargs` to set `n_devices > 1`. 

```python
import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

runner_cfg = LanguageModelSAERunnerConfig(
    ...,
    # spread the model across all but one of the available GPUs
    model_from_pretrained_kwargs={"n_devices": torch.cuda.device_count() - 1},
)
runner = SAETrainingRunner(runner_cfg)
sae = runner.run()
```

#### Pre-generating your activations (enable reuse and possibly better shuffling)

Sometimes you'll be able to reuse activations (eg: when hyperparameter tuning). In those cases it's worth pre-generating them once and caching them to disk:

```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
from sae_lens.cache_activations_runner import CacheActivationsRunner
from sae_lens.config import CacheActivationsRunnerConfig

path = "whatever/you/want"
cfg = CacheActivationsRunnerConfig(new_cached_activations_path=path, ...)
CacheActivationsRunner(cfg).run()

runner_cfg = LanguageModelSAERunnerConfig(
    ..., 
    use_cached_activations=True,
    cached_activations_path=path
)
runner = SAETrainingRunner(runner_cfg)
sae = runner.run()
```

#### Logging, checkpointing & saving SAEs

For any real training run, you should be **logging to Weights and Biases (WandB)**. This will allow you to track your training progress and compare different runs. To enable WandB, set `log_to_wandb=True`. The `wandb_project` parameter in the config controls the project name in WandB. You can also control the logging frequency with `wandb_log_frequency` and `eval_every_n_wandb_logs`. A number of helpful metrics are logged to WandB, including the sparsity of the SAE, the mean squared error (MSE) of the SAE, dead features, and explained variance. These metrics can be used to monitor the training progress and adjust the training parameters. We'll discuss these metrics more in later sections.

**Checkpoints** allow you to save a snapshot of the SAE and sparsity statistics during training. To enable checkpointing, set `n_checkpoints` to a value larger than 0. If WandB logging is enabled, checkpoints will be uploaded as WandB artifacts. To save checkpoints locally, the `checkpoint_path` parameter can be set to a local directory.

Once you have a set of SAEs that you're happy with, your next step is to share them with the world! SAELens has an `upload_saes_to_huggingface()` function which makes this easy to do. You'll need to upload a dictionary where the keys are SAE ids, and the values are either `SAE` objects or paths to an SAE that's been saved using the `sae.save_model()` method (you can use a combination of both in your dictionary). Note that you'll need to be logged in to your HuggingFace account either by running `huggingface-cli login` in the terminal or by setting the `HF_TOKEN` environment variable to your API token (which should have write access to your repo).

```python
from sae_lens import SAE, upload_saes_to_huggingface

# Create a dictionary of SAEs: this will be a single release; keys are the SAE ids
saes_dict = {
    "blocks.0.hook_resid_pre": layer_0_sae,
    "blocks.1.hook_resid_pre": layer_1_sae,
}

# Upload the SAEs to HuggingFace, in your chosen repository
upload_saes_to_huggingface(
    saes_dict,
    hf_repo_id="your-username/your-sae-repo",
)

# Load them back in, using `from_pretrained()`
uploaded_saes = {
    layer: SAE.from_pretrained(
        release="your-username/your-sae-repo",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=str(device)
    )[0]
    for layer in [0, 1]
}
```

#### Metrics

#### Reconstruction vs sparsity metrics

As we've discussed, most metrics you log to WandB are attempts to measure either reconstruction loss or sparsity, in various different ways. The goal is to monitor how these two different objectives are being balanced, and hopefully find Pareto improvements for them! For reconstruction loss, you want to pay particular attention to **MSE loss**, **CE loss recovered** and **explained variance**. For sparsity, you want to look at the L0 and L1 statistics, as well as the activations histogram (more on that below, since it's more nuanced than "brr line go up/down"!).
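
To make these quantities concrete, here is a minimal pure-Python sketch of how MSE, explained variance, L0 and L1 could be computed for a tiny batch. The function names and toy numbers are illustrative assumptions, and the exact normalisation SAELens uses when logging may differ.

```python
def mse(x, x_hat):
    """Mean squared reconstruction error over every element of the batch."""
    n = sum(len(row) for row in x)
    return sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh)) / n


def explained_variance(x, x_hat):
    """1 - residual sum of squares / total sum of squares around the batch mean."""
    d = len(x[0])
    means = [sum(row[j] for row in x) / len(x) for j in range(d)]
    resid = sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh))
    total = sum((row[j] - means[j]) ** 2 for row in x for j in range(d))
    return 1 - resid / total


def l0(acts):
    """Average number of active (non-zero) features per input."""
    return sum(sum(1 for a in row if a > 0) for row in acts) / len(acts)


def l1(acts):
    """Average summed activation magnitude per input."""
    return sum(sum(abs(a) for a in row) for row in acts) / len(acts)


# Toy batch: two inputs of dimension 2, and three SAE features
x = [[1.0, 2.0], [3.0, 4.0]]
x_hat = [[1.0, 2.0], [3.0, 3.0]]
acts = [[0.0, 2.0, 0.0], [1.0, 0.0, 3.0]]

print("MSE:", mse(x, x_hat))
print("Explained variance:", explained_variance(x, x_hat))
print("L0:", l0(acts))
print("L1:", l1(acts))
```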

The L1 coefficient is the primary lever you have for managing the tradeoff between accurate reconstruction and sparsity. Too high and you get lots of dead features (although this is mediated by L1 warmup if using a scheduler - see below), too low and your features will be dense and polysemantic rather than sparse and interpretable.
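
The tradeoff can be written as a single objective. The sketch below assumes the standard formulation (reconstruction MSE plus `l1_coefficient` times the L1 norm of the feature activations); options such as scaling the penalty by decoder norms are not modelled here, and `sae_loss` is a hypothetical name.

```python
def sae_loss(x, x_hat, feature_acts, l1_coefficient):
    """Reconstruction term plus a weighted sparsity penalty, for one input."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    sparsity = sum(abs(f) for f in feature_acts)
    return recon + l1_coefficient * sparsity


# Toy input, reconstruction and feature activations
x = [1.0, 2.0]
x_hat = [1.0, 1.0]
feature_acts = [0.0, 0.5, 0.0, 1.5]

# A larger coefficient makes the sparsity term dominate the objective
for coeff in (0.0, 1.0, 4.0):
    print(coeff, sae_loss(x, x_hat, feature_acts, coeff))
```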

Another really important point when considering the tradeoff between these two metrics - it might be tempting to use a very high L1 coefficient to get nice sparse interpretable features, but if this comes at a cost of high reconstruction loss, then there's a real risk that **you're not actually learning the model's true behaviour.** SAEs are only valuable when they give us a true insight into what the model's representations actually are, and doing interpretability without this risks being just a waste of time. See discussion [here](https://www.lesswrong.com/posts/tEPHGZAb63dfq2v8n/how-useful-is-mechanistic-interpretability) between Neel Nanda, Ryan Greenblatt & Buck Shlegeris for more on this point (note that I don't agree with all the points made in this post, but it raises some highly valuable ideas that SAE researchers would do well to keep in mind).

#### ...but metrics can sometimes be misleading

Although metrics in one of these two groups will often tell a similar story (e.g. explained variance will usually be high when MSE loss is small), they can occasionally detach, and it's important to understand why this might happen. Some examples:

- L0 and L1 both tell you about sparsity, but feature shrinkage makes them detach (it causes smaller L1, not L0)
- MSE loss and KL div / downstream CE both tell you about reconstruction, but they detach because one is myopic and the other is not
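
The L0 / L1 case can be checked with a toy example: uniformly shrinking every activation changes L1 but leaves L0 untouched. The numbers below are made up for illustration.

```python
def count_active(acts):
    """L0: number of features with a non-zero activation."""
    return sum(1 for a in acts if a > 0)


def total_magnitude(acts):
    """L1: summed magnitude of the activations."""
    return sum(abs(a) for a in acts)


acts = [0.0, 2.0, 0.0, 4.0, 1.0]
shrunk = [0.5 * a for a in acts]  # every active feature fires at half strength

print("L0:", count_active(acts), "->", count_active(shrunk))
print("L1:", total_magnitude(acts), "->", total_magnitude(shrunk))
```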

As well as being useful to understand these specific examples, it's valuable to put yourself in a skeptical mindset, and understand why these kinds of problems can arise. New metrics are being developed all the time, and some of them might be improvements over current ones while others might carry entirely new and unforeseen pitfalls!

#### Dead features, resampling & ghost gradients

**Dead features** are ones that never fire on any input. These can be a big problem during training, because they don't receive any gradients and so represent permanently lost capacity in your SAE. Two ways of dealing with dead features, which Anthropic have described in various papers and update posts:

1. **Ghost gradients** - this describes the method of adding an additional term to the loss, which essentially gives dead latents a gradient signal that pushes them in the direction of explaining more of the autoencoder's residual. Full technical details [here](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling).
2. **Resampling** - at various timesteps we take all our dead features and randomly re-initialize them to values which help explain more of the residual (specifically, we will randomly select inputs which the SAE fails to reconstruct, and then set dead features to be the SAE hidden states corresponding to those inputs).

These techniques are both useful, however they've been reassessed in the time after they were initially introduced, and are now not seen to be as critical as they once were (especially [ghost grads](https://transformer-circuits.pub/2024/march-update/index.html#dl-update)). Instead, we use more standard techniques to avoid dead features, specifically a combination of an appropriately small learning rate and an L1 warmup. We generally recommend people use resampling and not ghost gradients, but to take enough care with your LR and L1 warmup to avoid having to rely on resampling.
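
As a rough sketch of what an L1 warmup does, the coefficient can be ramped linearly from zero to its final value over the first `l1_warm_up_steps` steps. The linear shape and the function name are assumptions for illustration; check the SAELens source for the exact schedule.

```python
def l1_coefficient_at(step: int, final_coefficient: float, warm_up_steps: int) -> float:
    """Linearly ramp the sparsity coefficient, then hold it constant."""
    if warm_up_steps <= 0:
        return final_coefficient
    return final_coefficient * min(1.0, step / warm_up_steps)


# Example: a final coefficient of 4 reached after 3,000 warm-up steps
for step in (0, 1_500, 3_000, 10_000):
    print(step, l1_coefficient_at(step, final_coefficient=4.0, warm_up_steps=3_000))
```

Early in training the sparsity pressure is weak, which gives features a chance to start firing before the penalty can push them to zero.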

Assuming we're not using ghost gradients, resampling is controlled by the following 3 parameters:

- `feature_sampling_window`, which is how often we resample neurons
- `dead_feature_window`, which is the size of the window over which we count dead features each time we resample. This should be smaller than `feature_sampling_window`
- `dead_feature_threshold`, which is the threshold below which we consider a feature to be dead & resample it
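
The role of the window and the threshold can be illustrated with a toy dead-feature check. This is a simplified sketch rather than the SAELens implementation: `fire_counts`, `window_tokens` and `find_dead_features` are hypothetical names, and the threshold is treated as a firing frequency.

```python
def find_dead_features(fire_counts, window_tokens, dead_feature_threshold):
    """Indices of features whose firing frequency over the window is below the threshold."""
    return [
        i
        for i, count in enumerate(fire_counts)
        if count / window_tokens < dead_feature_threshold
    ]


# Number of times each of five features fired over a window of one million tokens
fire_counts = [0, 120, 3, 0, 5000]
dead = find_dead_features(
    fire_counts, window_tokens=1_000_000, dead_feature_threshold=1e-6
)
print(dead)
```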

#### Dense latents & learning rates

Dense latents are the opposite problem to dead latents: if your learning rate is too small or L1 penalty too small, you'll fail to train latents to the point of being sparse. A dense latent is one that fires too frequently (e.g. on >1/100 or even >1/10 tokens). These features seem generally uninterpretable, but can help with your reconstruction loss immensely - essentially it's a way for the SAEs to smuggle not-particularly-sparse possibly-nonlinear computation into your SAE.

It can be difficult to balance dense and dead latents during training. Generally you want to drop your learning rate as far down as it will go, without causing your features to be dense and your training to be super slow.

Note - what (if any) the right number of dense or dead latents should be in any situation is very much an open question, and depends on your beliefs about the underlying feature distribution in question. One way we can investigate this question is to try and train SAEs in a simpler domain, where the underlying features are easier to guess about (e.g. OthelloGPT, or TinyStories - both of which we'll discuss later in this chapter).

#### Interpreting the feature density histogram

This section is quoted directly from Joseph Bloom's [excellent post](https://www.lesswrong.com/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream#Why_can_training_Sparse_AutoEncoders_be_difficult__) on training SAEs:

> Feature density histograms are a good measure of SAE quality. We plot the log10 feature sparsity (how often it fires) for all features. In order to make this easier to operationalize, I’ve drawn a diagram that captures my sense of the issues these histograms help you diagnose. Feature density histograms can be broken down into:
> - **Too Dense**:  dense features will occur at a frequency > 1 / 100. Some dense-ish features are likely fine (such as a feature representing that a token begins with a space) but too many is likely an issue.
> - **Too Sparse**: Dead features won’t be sampled so will turn up at log10(epsilon), for epsilon added to avoid logging 0 numbers. Too many of these mean you’re over penalizing with L1.
> - **Just-Right**: Without too many dead or dense features, we see a distribution that has most mass between -5 or -4 and -3 log10 feature sparsity. The exact range can vary depending on the model / SAE size but the dense or dead features tend to stick out.
>
> <img src="https://res.cloudinary.com/lesswrong-2-0/image/upload/f_auto,q_auto/v1/mirroredImages/f9EgfLSurAiqRJySD/loz4zhrj3ps0cue7uike" width="450">
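
The three regimes in the quote can be turned into a rough bucketing rule. The cut-off at a firing frequency of 1/100 comes from the quote above; `EPS` and the function name are assumptions made for this sketch.

```python
import math

EPS = 1e-10  # assumed small constant so that log10 is defined for dead features


def classify_density(density: float) -> str:
    """Bucket a feature by how often it fires (fraction of tokens)."""
    if density == 0:
        return "dead"  # lands at log10(EPS) in the histogram
    if density > 1 / 100:
        return "dense"
    return "ok"


# Made-up firing frequencies for five features
densities = [0.0, 5e-2, 1e-4, 3e-5, 0.2]
labels = [classify_density(d) for d in densities]

for d, label in zip(densities, labels):
    print(f"log10 density = {math.log10(d + EPS):6.2f} -> {label}")
```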

### Training Case Study: TinyStories-1L, MLP-out

In our first training case study, we'll train an SAE on the output of the final (only) MLP layer of a TinyStories model. [TinyStories](https://arxiv.org/abs/2305.07759) is a synthetic dataset consisting of short stories, which contains a vocabulary of ~1500 words (mostly just common words that typical 3-4 year old children can understand). Each story is also relatively short, self-contained, and contains a basic sequence of events which can often be causally inferred from the previous context. Example sequences look like:

> Once upon a time, there was a little girl named Lily. Lily liked to pretend she was a popular princess. She lived in a big castle with her best friends, a cat and a dog. One day, while playing in the castle, Lily found a big cobweb. The cobweb was in the way of her fun game. She wanted to get rid of it, but she was scared of the spider that lived there. Lily asked her friends, the cat and the dog, to help her. They all worked together to clean the cobweb. The spider was sad, but it found a new home outside. Lily, the cat, and the dog were happy they could play without the cobweb in the way. And they all lived happily ever after.

This dataset gives us a useful playground for interpretability analysis, because the kinds of features which it is useful for models to learn in order to minimize predictive loss on this dataset are far narrower and simpler than they would be for models trained on more complex natural language datasets.

Let's load in the model we'll be training our SAE on, and get a sense for how models trained on this dataset behave by generating text from it. This is a useful first step when it comes to thinking about what features the model is likely to have learned.

In [None]:
tinystories_model = HookedSAETransformer.from_pretrained("tiny-stories-1L-21M")

In [None]:
completions = [
    (
        i,
        tinystories_model.generate("Once upon a time", temperature=1, max_new_tokens=50),
    )
    for i in range(5)
]

print(tabulate(completions, tablefmt="simple_grid", maxcolwidths=[None, 100]))

Now we're ready to train our SAE. We'll make a runner config, instantiate the runner and the rest is taken care of for us!

During training, you can use Weights and Biases to check key metrics which indicate how well we are able to optimize the variables we care about. You can reorganize your WandB dashboard to put important metrics like L0, CE loss score, explained variance etc in one section at the top. We also recommend you make a [run comparer](https://docs.wandb.ai/guides/app/features/panels/run-comparer/) for your different runs, whenever performing multiple training runs (e.g. hyperparameter sweeps).

If you've disabled gradients (as we did earlier with `torch.set_grad_enabled(False)`), remember to re-enable them using `torch.set_grad_enabled(True)` before training.

In [None]:
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner

total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = l1_warm_up_steps = total_training_steps // 10  # 10% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

cfg = LanguageModelSAERunnerConfig(
    #
    # Data generation
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,
    d_in=tinystories_model.cfg.d_model,
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # tokenized language dataset on HF for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    prepend_bos=True,  # you should use whatever the base model was trained with
    streaming=True,  # we could pre-download the token dataset if it was small.
    train_batch_size_tokens=batch_size,
    context_size=512,  # larger is better but takes longer (for tutorial we'll use a short one)
    #
    # SAE architecture
    architecture="gated",
    expansion_factor=16,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=True,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    #
    # Activations store
    n_batches_in_buffer=64,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    #
    # Training hyperparameters (standard)
    lr=5e-5,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # controls how the LR warmup / decay works
    lr_warm_up_steps=lr_warm_up_steps,  # avoids large number of initial dead features
    lr_decay_steps=lr_decay_steps,  # helps avoid overfitting
    #
    # Training hyperparameters (SAE-specific)
    l1_coefficient=4,
    l1_warm_up_steps=l1_warm_up_steps,
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=2000,  # how often we resample dead features
    dead_feature_window=1000,  # size of window to assess whether a feature is dead
    dead_feature_threshold=1e-4,  # threshold for classifying feature as dead, over window
    #
    # Logging / evals
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="arena-demos-tinystories",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    #
    # Misc.
    device=str(device),
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype="float32",
)

# # Uncomment this code to train! Otherwise, the cell loads the already-trained SAE from Hugging Face.
# torch.set_grad_enabled(True)
# runner = SAETrainingRunner(cfg)
# sae = runner.run()

hf_repo_id = "callummcdougall/arena-demos-tinystories"
sae_id = cfg.hook_name

# upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)

tinystories_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

Once you've finished training your SAE, you can try using the following code from the `sae_vis` library to visualize your SAE's latents.

(Note - this code comes from a branch of the `sae_vis` library, which soon will be merged into main, and will also be more closely integrated with the rest of `SAELens`. For example, this method currently works by directly taking a batch of tokens, but in the future it will probably take an `ActivationsStore` object to make things easier.)

First, we get a batch of tokens from the dataset:

In [None]:
from datasets import load_dataset

dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2"
total_batch_size = 1024

dataset = load_dataset(dataset_path, streaming=True)

tokens = torch.tensor(
    [x["input_ids"] for i, x in zip(range(total_batch_size), dataset["train"])],
    device=str(device),
)
print(tokens.shape)

Next, we create the visualization and save it (you'll need to download the file and open it in a browser to view). Note that if you get OOM errors then you can reduce the number of features visualized, or decrease either the batch size or context length.

In [None]:
from sae_vis import SaeVisConfig, SaeVisData

sae_vis_data = SaeVisData.create(
    sae=tinystories_sae,
    model=tinystories_model,
    tokens=tokens,
    cfg=SaeVisConfig(hook_point=tinystories_sae.cfg.hook_name, features=range(16)),
    verbose=True,
)
sae_vis_data.save_feature_centric_vis(filename="feature_vis.html")

# display_vis_inline(str(section_dir / "feature_vis.html"))

### Exercise - identify good and bad training curves

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵

You should spend up to 25-40 minutes on this exercise.
```

[Here](https://wandb.ai/callum-mcdougall/arena-demos-tinystories-v3/workspace?nw=nwusercallummcdougall) is a link to a WandB project page with seven training runs. The first one (Run #0) is a "good" training run (at least compared to the others), and can be thought of as a baseline. Each of the other 6 runs (labelled Run #1 - Run #6) has some particular issue with it. Your task will be to identify the issue (i.e. from one or more of the metrics plots), and find the root cause of the issue from looking at the configs (you can compare the configs to each other in the "runs" tab). Note that we recommend trying to identify the issue from the plots first, rather than immediately jumping to the config and looking for a diff between it and the good run. You'll get most value out of the exercises by using the following pattern when assessing each run:

1. Looking at the metrics, and finding some which seem like indications of poor SAE quality
2. Based on these metrics, try and guess what might be going wrong in the config
3. Look at the config, and test whether your guess was correct.

Also, a reminder - you can look at a run's density histogram plot, although you can only see this when you're looking at the page for a single run (as opposed to the project page).

Use the dropdowns below to see the answer for each of the runs. Note, the first few are more important to get right, as the last few are more difficult and don't always have obvious root causes.

<details>
<summary>Run #1</summary>

This run had too small an L1 coefficient, meaning it wasn't learning a sparse solution. The tipoff here should have been the feature sparsity statistics, e.g. L0 being extremely high. Exactly what L0 is ideal varies between different models and hook points (and also depends on things like the SAE width - see the section on feature splitting earlier for more on this), but as an idea, the canonical GemmaScope SAEs were chosen to be those with L0 closest to 100. Much larger than this (e.g. 200+) is almost definitely bad, especially if we're talking about what should fundamentally be a pretty simple dataset (TinyStories, with a 1L model).
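
To see why the L1 coefficient controls sparsity, here is a minimal sketch of a standard L1-penalised SAE objective, with the penalty scaled by decoder norms as in the `scale_sparsity_penalty_by_decoder_norm=True` setting. This is a simplification under stated assumptions: the function name `sae_loss` is hypothetical, the numbers are made up, and the gated architecture used in the config has additional loss terms that are not shown.

```python
def sae_loss(x, x_hat, feature_acts, decoder_norms, l1_coefficient):
    """Reconstruction error plus an L1 penalty on (activation * decoder norm)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    sparsity = sum(act * norm for act, norm in zip(feature_acts, decoder_norms))
    return mse + l1_coefficient * sparsity, mse, sparsity


x = [1.0, 2.0]
x_hat = [0.5, 2.0]
feature_acts = [0.0, 3.0, 1.0]
decoder_norms = [1.0, 0.5, 2.0]

for coeff in (4, 0.04):
    total, mse, sparsity = sae_loss(x, x_hat, feature_acts, decoder_norms, coeff)
    print(f"l1_coefficient={coeff}: mse={mse}, sparsity term={coeff * sparsity}, total={total}")
```

With a tiny coefficient the sparsity term is negligible next to the reconstruction error, so the optimizer has little reason to switch latents off, which is what a very high L0 indicates.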

</details>

<details>
<summary>Run #2</summary>

This run had far too many dead latents. This was as a result of choosing an unnecessarily large expansion factor: 32, rather than an expansion factor of 16 as is used in the baseline run. Note that having a larger expansion factor & dead features isn't necessarily bad, but it does imply a lot of wasted capacity. The fact that resampling was seemingly unable to reduce the number of dead latents is a sign that our `d_sae` was larger than it needed to be.
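
As a rough illustration of what "dead" means here, the sketch below flags latents whose firing frequency over a window of tokens falls below a threshold (the default mirrors `dead_feature_threshold=1e-4` from the config). The function `find_dead_latents` and the counts are hypothetical, and this is a simplified frequency check rather than SAELens's actual bookkeeping.

```python
def find_dead_latents(fire_counts, n_tokens, threshold=1e-4):
    """Return indices of latents that fired on fewer than `threshold` of the tokens."""
    return [i for i, count in enumerate(fire_counts) if count / n_tokens < threshold]


# Made-up firing counts for 6 latents over a window of 100,000 tokens.
fire_counts = [0, 3, 1200, 50_000, 0, 8]
n_tokens = 100_000

dead = find_dead_latents(fire_counts, n_tokens)
print(f"Dead latents: {dead} ({len(dead) / len(fire_counts):.0%} of the dictionary)")
```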

</details>

<details>
<summary>Run #3</summary>

This run had a very low learning rate: `1e-5` vs the baseline value of `5e-5`. Having a learning rate this low isn't inherently bad if you train for longer (and in fact a smaller learning rate and longer training duration is generally better if you have the time for it), but given the same number of training tokens a smaller learning rate can result in poorer end-of-training performance. In this case, we can see that most loss curves are still dropping when the training finishes, suggesting that this model was undertrained.
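
For intuition about how the learning rate evolves over a run, here is a small sketch of a schedule with linear warm-up, a constant middle section and linear decay, using the step counts from the training config (30,000 steps, 10% warm-up, 20% decay). The exact shape SAELens applies is an assumption here, and `lr_at_step` is a hypothetical helper.

```python
def lr_at_step(step, base_lr=5e-5, total_steps=30_000, warm_up_steps=3_000, decay_steps=6_000):
    """Linear warm-up, then constant, then linear decay towards zero (assumed shape)."""
    if step < warm_up_steps:
        return base_lr * (step + 1) / warm_up_steps
    decay_start = total_steps - decay_steps
    if step >= decay_start:
        return base_lr * (total_steps - step) / decay_steps
    return base_lr


for base_lr in (5e-5, 1e-5):  # baseline vs. Run #3
    values = [lr_at_step(s, base_lr=base_lr) for s in (0, 1_500, 15_000, 27_000, 29_999)]
    print(base_lr, [f"{v:.2e}" for v in values])
```

Both runs follow the same shape, but every step of Run #3 is five times smaller, so with the same number of steps it covers much less ground.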

</details>

<details>
<summary>Run #4</summary>

This run had an expansion factor of 1, meaning that the number of learned features couldn't be larger than the dimensionality of the MLP output. This is obviously bad, and will lead to the poor performance seen in the loss curves (both in terms of sparsity and reconstruction loss).

</details>

<details>
<summary>Run #5</summary>

This run had a large number of dead features. Unlike Run #2, the cause wasn't an unnecessarily large expansion factor; instead, it was a combination of:

- High learning rate
- No warmup steps
- No feature resampling

So unlike run #2, there wasn't also a large number of live features, meaning performance was much poorer.

</details>

<details>
<summary>Run #6 (hint)</summary>

The failure mode here is of a different kind than the other five. Try looking at the plot `metrics/ce_loss_without_sae` - what does this tell you? (You can look at the SAELens source code to figure out what this metric means).

</details>

<details>
<summary>Run #6</summary>

The failure mode for this run is different from the other runs in this section. The SAE was trained perfectly fine, but it was trained on activations generated from the wrong input distribution! The dataset from which we generated our model activations was `apollo-research/monology-pile-uncopyrighted-tokenizer-gpt2` - this dataset was designed for a model like GPT2, and not for the tinystories model we're using.

The tipoff here could have come from a few different plots, but in particular the `"metrics/ce_loss_without_sae"` plot - we can see that the model performs much worse (without the SAE even being involved) than it did for any of the other runs in the project. This metrics plot is a useful sanity check to make sure your model is being fed appropriate data!
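
The baseline cross-entropy loss also matters because scores that compare the model with and without the SAE are normalized against it. The sketch below shows one common form of such a score: the fraction of the loss gap between ablating the activation and leaving the model untouched that the SAE reconstruction recovers. I believe this mirrors the general form of the CE loss score, but treat the formula, the function name `ce_loss_score` and the numbers as assumptions, and check the SAELens source for the exact definition.

```python
def ce_loss_score(ce_without_sae, ce_with_sae, ce_with_ablation):
    """Fraction of the (ablation - clean) loss gap recovered by the SAE reconstruction."""
    return (ce_with_ablation - ce_with_sae) / (ce_with_ablation - ce_without_sae)


# Hypothetical losses: clean model, model with SAE spliced in, model with the activation ablated.
score = ce_loss_score(ce_without_sae=2.0, ce_with_sae=2.2, ce_with_ablation=6.0)
print(f"CE loss score: {score:.3f}")
```

If the clean loss itself is unusually high, as in this run, the score can still look reasonable, which is why it is worth checking the raw baseline plot too.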

</details>

### 2.2 Evaluating SAEs (interactive session, unfinished but we can slow down here and do it together)

Together let's:
- Start with code that loops over model activations and generates feature activations.
- Add code to count feature activations / distinct prompts in which features activate.
- Plot a histogram of feature density.
- Plot Feature Density vs Consistent Activation Heuristic. 
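
As a starting point for the counting steps above, here is a minimal sketch on toy data using plain Python lists. In the notebook the activations would come from the SAE; here they are made up, and all variable names are illustrative. The Consistent Activation Heuristic itself is not computed here.

```python
import math
from collections import Counter

# Toy feature activations, indexed as [prompt][position][latent].
feature_acts = [
    [[0.0, 1.2, 0.0], [0.0, 0.4, 0.0], [0.9, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.1, 0.0]],
]
n_latents = 3

token_counts = [0] * n_latents   # number of tokens on which each latent fires
prompt_counts = [0] * n_latents  # number of distinct prompts in which each latent fires
n_tokens = 0

for prompt_acts in feature_acts:
    fired_in_prompt = set()
    for token_acts in prompt_acts:
        n_tokens += 1
        for idx, act in enumerate(token_acts):
            if act > 0:
                token_counts[idx] += 1
                fired_in_prompt.add(idx)
    for idx in fired_in_prompt:
        prompt_counts[idx] += 1

# Feature density: fraction of tokens on which each latent is active.
density = [count / n_tokens for count in token_counts]

# Crude text histogram over log10 density (latents that never fire are skipped).
histogram = Counter(math.floor(math.log10(d)) for d in density if d > 0)
for bucket in sorted(histogram):
    print(f"10^{bucket} <= density < 10^{bucket + 1}: {histogram[bucket]} latents")
```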

Let's re-use the code I wrote to reproduce Josh's non-linear features results.

In [None]:

from sae_lens import HookedSAETransformer, SAE, ActivationsStore

torch.set_grad_enabled(False)

# gpt2
# model_name = "gpt2-small"
# release = "gpt2-small-res-jb"
# sae_id = "blocks.7.hook_resid_pre"


# tinystories
model_name = "tiny-stories-1L-21M"
release = "callummcdougall/arena-demos-tinystories"
sae_id = "blocks.0.hook_mlp_out"



model: HookedSAETransformer = HookedSAETransformer.from_pretrained(model_name, device=device)

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=release,
    sae_id=sae_id,
    device=str(device),
)

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this:
tokens = activation_store.get_batch_tokens()
assert tokens.shape == (activation_store.store_batch_size_prompts, activation_store.context_size)

In [None]:
from sklearn.decomposition import PCA

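# NOTE: these latent indices were found in the GPT-2 Small SAE (release "gpt2-small-res-jb",
# sae_id "blocks.7.hook_resid_pre"), so switch to the gpt2 settings in the previous cell before running this.
day_of_the_week_features = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]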
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_20th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

days_of_the_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]
buffer = 5
seq_len = activation_store.context_size
sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"

all_data = {"recons": [], "context": [], "token": [], "token_group": []}
total_batches = 500

for i in tqdm(range(total_batches)):
    _, cache = model.run_with_cache_with_saes(
        tokens := activation_store.get_batch_tokens(),
        saes=[sae],
        stop_at_layer=sae.cfg.hook_layer + 1,
        names_filter=[sae_acts_post_hook_name],
    )
    acts = cache[sae_acts_post_hook_name][..., day_of_the_week_features].flatten(0, 1)

    any_feature_fired = (acts > 0).any(dim=1)
    acts = acts[any_feature_fired]
    # project the selected latents back into activation space using their decoder directions
    reconstructions = acts @ sae.W_dec[day_of_the_week_features]

    all_data["recons"].append(reconstructions)

    for batch_seq_flat_idx in torch.nonzero(any_feature_fired).squeeze(-1).tolist():
        batch, seq = divmod(batch_seq_flat_idx, seq_len)  # type: ignore

        token = model.tokenizer.decode(tokens[batch, seq])  # type: ignore
        token_group = token.strip() if token.strip() in days_of_the_week else "Other"

        context = model.tokenizer.decode(  # type: ignore
            tokens[batch, max(seq - buffer, 0) : min(seq + buffer + 1, seq_len)]
        )

        all_data["context"].append(context)
        all_data["token"].append(token)
        all_data["token_group"].append(token_group)



In [None]:
pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(torch.concat(all_data.pop("recons")).detach().cpu().numpy())

all_data |= {"PC2": pca_embedding[:, 1], "PC3": pca_embedding[:, 2]}
pca_df = pd.DataFrame(all_data)

px.scatter(
    pca_df,
    x="PC2",
    y="PC3",
    hover_data=["context"],
    hover_name="token",
    height=700,
    width=1000,
    color="token_group",
    color_discrete_sequence=px.colors.sample_colorscale("Viridis", 7) + ["#aaa"],
    title="PCA Subspace Reconstructions",
    labels={"token_group": "Activating token"},
    category_orders={"token_group": days_of_the_week + ["Other"]},
).show()

## Stage 3: Analysing SAEs

In [None]:
# If we want to use Callum's TinyStories SAE (maybe we'll generate CAH scores and use it to find good steering features)
model = HookedSAETransformer.from_pretrained("tiny-stories-1L-21M")
hf_repo_id = "callummcdougall/arena-demos-tinystories"
sae_id = cfg.hook_name
sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

In [None]:
# Alternatively, if we want to use my GPT2 SAE (maybe we'll generate CAH scores and use it to find good steering features)
# Link to Neuronpedia page with CAH stats: https://www.neuronpedia.org/gpt2-small/7-res-jb
model = HookedSAETransformer.from_pretrained("gpt2-small")
hf_repo_id = "gpt2-small-res-jb"
sae_id = "blocks.7.hook_resid_pre"
sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

### Steering / Clamping
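
The cell below steers by adding a scaled decoder direction to the activations at the hook point. As its inline note says, if the feature already fires, this adds on top of the existing activation. Clamping instead moves the feature to a fixed target value. Here is a minimal sketch of the difference on plain Python lists; `add_steering` and `clamp_feature` are hypothetical helper names, and the clamp assumes the feature's current contribution is exactly `current_act * direction`.

```python
def add_steering(resid, direction, max_act, strength):
    """Additive steering: push along the feature direction regardless of its current value."""
    return [r + max_act * strength * d for r, d in zip(resid, direction)]


def clamp_feature(resid, direction, current_act, target_act):
    """Clamping: shift along the direction so the feature's contribution equals target_act."""
    return [r + (target_act - current_act) * d for r, d in zip(resid, direction)]


resid = [1.0, 3.0, -1.0]      # toy activation vector; the feature already contributes 3.0
direction = [0.0, 1.0, 0.0]   # toy unit-norm decoder direction

print(add_steering(resid, direction, max_act=5.0, strength=2.0))
print(clamp_feature(resid, direction, current_act=3.0, target_act=10.0))
```

Additive steering overshoots the target when the feature is already active, while clamping lands on the same value either way.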

In [None]:
from tqdm import tqdm
from functools import partial 

def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    '''
    Find the maximum activation for a given feature index. This is useful for 
    calibrating the right amount of the feature to add.
    '''
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()
        
        _, cache = model.run_with_cache(
            tokens, 
            stop_at_layer=sae.cfg.hook_layer + 1, 
            names_filter=[sae.cfg.hook_name]
        )
        sae_in = cache[sae.cfg.hook_name]
        feature_acts = sae.encode(sae_in).squeeze()

        feature_acts = feature_acts.flatten(0, 1)
        batch_max_activation = feature_acts[:, feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)
        
        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation

def steering(activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0):
    # Note if the feature fires anyway, we'd be adding to that here.
    return activations + max_act * steering_strength * steering_vector

def generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=1.0, max_new_tokens=95):
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)
    
    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)
    
    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act
    )
    
    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos = False if device == "mps" else True,
            prepend_bos = sae.cfg.prepend_bos,
        )
    
    return model.tokenizer.decode(output[0])

# Choose a feature to steer
steering_feature = 100  # Choose a feature to steer towards

# Find the maximum activation for this feature
max_act = find_max_activation(model, sae, activation_store, steering_feature)
print(f"Maximum activation for feature {steering_feature}: {max_act:.4f}")

# note we could also get the max activation from Neuronpedia (https://www.neuronpedia.org/api-doc#tag/lookup/GET/api/feature/{modelId}/{layer}/{index})

# Generate text without steering for comparison
prompt = "Once upon a time"
normal_text = model.generate(
    prompt,
    max_new_tokens=95, 
    stop_at_eos = False if device == "mps" else True,
    prepend_bos = sae.cfg.prepend_bos,
)

print("\nNormal text (without steering):")
print(normal_text)

# Generate text with steering
steered_text = generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=2.0)
print("Steered text:")
print(steered_text)

We can also do steering via the Neuronpedia interface / API.


In [None]:
import requests
import numpy as np

url = "https://www.neuronpedia.org/api/steer"

payload = {
    # "prompt": "A knight in shining",
    # "prompt": "He had to fight back in self-", 
    "prompt": "In the middle of the universe is the galactic",
    # "prompt": "Oh no. We're running on empty. Its time to fill up the car with",
    # "prompt": "Sure, I'm happy to pay. I don't have any cash on me but let me write you a",
    "modelId": "gpt2-small",
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 6770,
            "strength": 8
        }
    ],
    "temperature": 0.2,
    "n_tokens": 2,
    "freq_penalty": 1,
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)

response.json()

In [None]:
import requests
import numpy as np
url = "https://www.neuronpedia.org/api/steer"

payload = {
    "prompt": "I wrote a letter to my girlfiend. It said \"",
    "modelId": "gpt2-small",
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 20115,
            "strength": 4
        }
    ],
    "temperature": 0.7,
    "n_tokens": 120,
    "freq_penalty": 1,
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)

response.json()

### Feature Ablation

Feature ablation is also worth looking at. In a way, it's a special case of steering where the value of the feature is always zeroed out.

Here we do the following:
1. Use `test_prompt` rather than `generate`, to get a more fine-grained view (logits and probabilities for the answer token rather than a single sampled completion).
2. Attach a hook to the SAE feature activations.
3. Zero out a feature at all positions (we know that the default feature fires at the final position).
4. Check whether this ablation is more or less effective if we include the error term (the information our SAE isn't capturing).

Note that the existence of [The Hydra Effect](https://arxiv.org/abs/2307.15771) can make reasoning about ablation experiments difficult.
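
To make the role of the error term concrete, here is a minimal sketch with a toy two-latent SAE written in plain Python. The error is computed from the clean reconstruction and then added back after the ablation, which is my understanding of what `use_error_term = True` is meant to achieve. The helpers `decode` and `ablate_and_reconstruct` are hypothetical and are not the SAELens implementation.

```python
def decode(acts, W_dec, b_dec):
    """Linear decoder: b_dec + sum_i acts[i] * W_dec[i]."""
    return [b_dec[j] + sum(a * W_dec[i][j] for i, a in enumerate(acts)) for j in range(len(b_dec))]


def ablate_and_reconstruct(x, acts, W_dec, b_dec, feature_idx, use_error_term):
    # Error term: whatever the clean (un-ablated) reconstruction fails to capture.
    error = [xi - ri for xi, ri in zip(x, decode(acts, W_dec, b_dec))]
    ablated_acts = list(acts)
    ablated_acts[feature_idx] = 0.0
    out = decode(ablated_acts, W_dec, b_dec)
    if use_error_term:
        out = [o + e for o, e in zip(out, error)]
    return out


W_dec = [[1.0, 0.0], [0.0, 1.0]]
b_dec = [0.0, 0.0]
x = [2.5, 3.0]       # true activation
acts = [2.0, 3.0]    # SAE latents; the reconstruction misses 0.5 in the first dimension

print(ablate_and_reconstruct(x, acts, W_dec, b_dec, feature_idx=0, use_error_term=False))
print(ablate_and_reconstruct(x, acts, W_dec, b_dec, feature_idx=0, use_error_term=True))
```

Without the error term, the model sees both the ablation and the SAE's reconstruction error; with it, the only change relative to the clean run is the ablated feature.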

In [None]:
from transformer_lens.utils import test_prompt
from functools import partial

def test_prompt_with_ablation(model, sae, prompt, answer, ablation_features):
    
    def ablate_feature_hook(feature_activations, hook, feature_ids, position = None):
    
        if position is None:
            feature_activations[:,:,feature_ids] = 0
        else:
            feature_activations[:,position,feature_ids] = 0
        
        return feature_activations
        
    ablation_hook = partial(ablate_feature_hook, feature_ids = ablation_features)
    
    model.add_sae(sae)
    hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'
    model.add_hook(hook_point, ablation_hook, "fwd")
    
    test_prompt(prompt, answer, model)
    
    model.reset_hooks()
    model.reset_saes()

# Example usage in a notebook:

# Assume model and sae are already defined

# Choose a feature to ablate

model.reset_hooks(including_permanent=True)
prompt = "In the beginning, God created the heavens and the"
answer = "earth"
test_prompt(prompt, answer, model)


# Generate text with feature ablation
print("Test Prompt with feature ablation and no error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = False
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)

print("Test Prompt with feature ablation and error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = True
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)

### Feature Attribution
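
The cells below compute attributions as gradient times activation for each SAE latent. Before the full implementation, here is a toy sketch with a linear metric, where the gradients can be written down by hand. All names and numbers are made up for illustration. For a linear metric the attributions (plus the error term's contribution) sum exactly to the metric; for a real model this is only a first-order approximation.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Metric m(x) = w . x, where x = sum_i acts[i] * W_dec[i] + error.
w = [1.0, -2.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
acts = [2.0, 0.0, 0.5]
error = [0.25, 0.0]

# d(metric)/d(acts[i]) = w . W_dec[i], so attribution_i = acts[i] * (w . W_dec[i]).
grads = [dot(w, d) for d in W_dec]
attributions = [a * g for a, g in zip(acts, grads)]
error_attribution = dot(w, error)

x = [sum(a * d[j] for a, d in zip(acts, W_dec)) + error[j] for j in range(len(w))]
metric = dot(w, x)

print("attributions:", attributions)
print("error attribution:", error_attribution)
print("metric:", metric)
```

Inactive latents get zero attribution no matter how large their gradient is, which is why the attribution tensors in the cells below are sparse.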

In [None]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple, Callable

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint


class SaeReconstructionCache(NamedTuple):
    sae_in: torch.Tensor
    feature_acts: torch.Tensor
    sae_out: torch.Tensor
    sae_error: torch.Tensor


def track_grad(tensor: torch.Tensor) -> None:
    """wrapper around requires_grad and retain_grad"""
    tensor.requires_grad_(True)
    tensor.retain_grad()


@dataclass
class ApplySaesAndRunOutput:
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]

    def zero_grad(self) -> None:
        """Helper to zero grad all tensors in this object."""
        self.model_output.grad = None
        for act in self.model_activations.values():
            act.grad = None
        for cache in self.sae_activations.values():
            cache.sae_in.grad = None
            cache.feature_acts.grad = None
            cache.sae_out.grad = None
            cache.sae_error.grad = None


def apply_saes_and_run(
    model: HookedTransformer,
    saes: dict[str, SAE],
    input: Any,
    include_error_term: bool = True,
    track_model_hooks: list[str] | None = None,
    return_type: Literal["logits", "loss"] = "logits",
    track_grads: bool = False,
) -> ApplySaesAndRunOutput:
    """
    Apply the SAEs to the model at the specific hook points, and run the model.
    By default, this will include a SAE error term which guarantees that the SAE
    will not affect model output. This function is designed to work correctly with
    backprop as well, so it can be used for gradient-based feature attribution.

    Args:
        model: the model to run
        saes: the SAEs to apply
        input: the input to the model
        include_error_term: whether to include the SAE error term to ensure the SAE doesn't affect model output. Default True
        track_model_hooks: a list of hook points to record the activations and gradients. Default None
        return_type: this is passed to the model.run_with_hooks function. Default "logits"
        track_grads: whether to track gradients. Default False
    """

    fwd_hooks = []
    bwd_hooks = []

    sae_activations: dict[str, SaeReconstructionCache] = {}
    model_activations: dict[str, torch.Tensor] = {}

    # this hook just tracks the SAE input, output, features, and error. If `track_grads=True`, it also ensures
    # that requires_grad is set to True and retain_grad is called for intermediate values.
    def reconstruction_hook(sae_in: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        sae = saes[hook_point]
        feature_acts = sae.encode(sae_in)
        sae_out = sae.decode(feature_acts)
        sae_error = (sae_in - sae_out).detach().clone()
        if track_grads:
            track_grad(sae_error)
            track_grad(sae_out)
            track_grad(feature_acts)
            track_grad(sae_in)
        sae_activations[hook_point] = SaeReconstructionCache(
            sae_in=sae_in,
            feature_acts=feature_acts,
            sae_out=sae_out,
            sae_error=sae_error,
        )

        if include_error_term:
            return sae_out + sae_error
        return sae_out

    def sae_bwd_hook(output_grads: torch.Tensor, hook: HookPoint):  # noqa: ARG001
        # this just passes the output grads to the input, so the SAE gets the same grads despite the error term hackery
        return (output_grads,)

    # this hook just records model activations, and ensures that intermediate activations have gradient tracking turned on if needed
    def tracking_hook(hook_input: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        model_activations[hook_point] = hook_input
        if track_grads:
            track_grad(hook_input)
        return hook_input

    for hook_point in saes.keys():
        fwd_hooks.append(
            (hook_point, partial(reconstruction_hook, hook_point=hook_point))
        )
        bwd_hooks.append((hook_point, sae_bwd_hook))
    for hook_point in track_model_hooks or []:
        fwd_hooks.append((hook_point, partial(tracking_hook, hook_point=hook_point)))

    # now, just run the model while applying the hooks
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model_output = model(input, return_type=return_type)

    return ApplySaesAndRunOutput(
        model_output=model_output,
        model_activations=model_activations,
        sae_activations=sae_activations,
    )

In [None]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Literal, NamedTuple

import torch
from sae_lens import SAE, HookedSAETransformer
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

EPS = 1e-8

# attribution needs gradients, which we disabled earlier in the notebook
torch.set_grad_enabled(True)


@dataclass
class AttributionGrads:
    metric: torch.Tensor
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]


@dataclass
class Attribution:
    model_attributions: dict[str, torch.Tensor]
    model_activations: dict[str, torch.Tensor]
    model_grads: dict[str, torch.Tensor]
    sae_feature_attributions: dict[str, torch.Tensor]
    sae_feature_activations: dict[str, torch.Tensor]
    sae_feature_grads: dict[str, torch.Tensor]
    sae_errors_attribution_proportion: dict[str, float]


def calculate_attribution_grads(
    model: HookedSAETransformer,
    prompt: str,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> AttributionGrads:
    """
    Wrapper around apply_saes_and_run that calculates gradients with respect to the metric_fn.
    Tracks grads for both SAE features and model neurons, and returns them in a structured format.
    """
    output = apply_saes_and_run(
        model,
        saes=include_saes or {},
        input=prompt,
        return_type="logits" if return_logits else "loss",
        track_model_hooks=track_hook_points,
        include_error_term=include_error_term,
        track_grads=True,
    )
    metric = metric_fn(output.model_output)
    output.zero_grad()
    metric.backward()
    return AttributionGrads(
        metric=metric,
        model_output=output.model_output,
        model_activations=output.model_activations,
        sae_activations=output.sae_activations,
    )


def calculate_feature_attribution(
    model: HookedSAETransformer,
    input: Any,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> Attribution:
    """
    Calculate feature attribution for SAE features and model neurons following
    the procedure in https://transformer-circuits.pub/2024/march-update/index.html#feature-heads.
    This includes the SAE error term by default, so inserting the SAE into the calculation is
    guaranteed to not affect the model output. This can be disabled by setting `include_error_term=False`.

    Args:
        model: The model to calculate feature attribution for.
        input: The input to the model.
        metric_fn: A function that takes the model output and returns a scalar metric.
        track_hook_points: A list of model hook points to track activations for, if desired
        include_saes: A dictionary of SAEs to include in the calculation. The key is the hook point to apply the SAE to.
        return_logits: Whether to return the model logits or loss. This is passed to TLens, so should match whatever the metric_fn expects (probably logits)
        include_error_term: Whether to include the SAE error term in the calculation. This is recommended, as it ensures that the SAE will not affect the model output.
    """
    # first, calculate gradients with respect to the metric_fn.
    # these will be multiplied with the activation values to get the attributions
    outputs_with_grads = calculate_attribution_grads(
        model,
        input,
        metric_fn,
        track_hook_points,
        include_saes=include_saes,
        return_logits=return_logits,
        include_error_term=include_error_term,
    )
    model_attributions = {}
    model_activations = {}
    model_grads = {}
    sae_feature_attributions = {}
    sae_feature_activations = {}
    sae_feature_grads = {}
    sae_error_proportions = {}
    # this code is long, but all it's doing is multiplying the grads by the activations
    # and recording grads, acts, and attributions in dictionaries to return to the user
    with torch.no_grad():
        for name, act in outputs_with_grads.model_activations.items():
            assert act.grad is not None
            raw_activation = act.detach().clone()
            model_attributions[name] = (act.grad * raw_activation).detach().clone()
            model_activations[name] = raw_activation
            model_grads[name] = act.grad.detach().clone()
        for name, act in outputs_with_grads.sae_activations.items():
            assert act.feature_acts.grad is not None
            assert act.sae_out.grad is not None
            raw_activation = act.feature_acts.detach().clone()
            sae_feature_attributions[name] = (
                (act.feature_acts.grad * raw_activation).detach().clone()
            )
            sae_feature_activations[name] = raw_activation
            sae_feature_grads[name] = act.feature_acts.grad.detach().clone()
            if include_error_term:
                assert act.sae_error.grad is not None
                error_grad_norm = act.sae_error.grad.norm().item()
            else:
                error_grad_norm = 0
            sae_out_norm = act.sae_out.grad.norm().item()
            sae_error_proportions[name] = error_grad_norm / (
                sae_out_norm + error_grad_norm + EPS
            )
        return Attribution(
            model_attributions=model_attributions,
            model_activations=model_activations,
            model_grads=model_grads,
            sae_feature_attributions=sae_feature_attributions,
            sae_feature_activations=sae_feature_activations,
            sae_feature_grads=sae_feature_grads,
            sae_errors_attribution_proportion=sae_error_proportions,
        )
        
        
# prompt = " Tiger Woods plays the sport of"
# pos_token = model.tokenizer.encode(" golf")[0]
prompt = "In the beginning, God created the heavens and the"
pos_token = model.tokenizer.encode(" earth")
neg_token = model.tokenizer.encode(" sky")
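def metric_fn(logits: torch.Tensor, pos_token: list[int] = pos_token, neg_token: list[int] = neg_token) -> torch.Tensor:
    # logit difference at the final position; the token ids are lists (as returned by encode),
    # so the result keeps a trailing dimension rather than being a 0-d scalar
    return logits[0, -1, pos_token] - logits[0, -1, neg_token]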

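# note: despite the name, this is an Attribution object (a collection of dicts keyed by hook name), not a DataFrame
feature_attribution_df = calculate_feature_attribution(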
    input = prompt,
    model = model,
    metric_fn = metric_fn,
    include_saes={sae.cfg.hook_name: sae},
    include_error_term=True,
    return_logits=True,
)
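
The attribution computed above is gradient times activation. For a metric that is linear in the activations, this equals the effect of zero-ablating each feature; for a real model it is only a first-order approximation of that effect. Here is a minimal sketch on a toy metric. The names `weights`, `metric`, and `numerical_grad` are hypothetical and not part of SAELens, and the gradient is estimated by finite differences rather than autograd so that the example needs no dependencies:

```python
# Toy illustration of gradient-x-activation attribution (not the notebook's model).
# Hypothetical metric: a weighted sum of three "feature activations".
weights = [0.5, -2.0, 3.0]      # assumed read-out weights, so d(metric)/d(act_i) = weights[i]
activations = [4.0, 1.0, 0.0]   # feature 2 is inactive, as most SAE features are on a given token


def metric(acts):
    return sum(w * a for w, a in zip(weights, acts))


def numerical_grad(fn, acts, i, h=1e-6):
    """Central finite-difference estimate of d fn / d acts[i]."""
    up = list(acts)
    up[i] += h
    down = list(acts)
    down[i] -= h
    return (fn(up) - fn(down)) / (2 * h)


grads = [numerical_grad(metric, activations, i) for i in range(len(activations))]

# attribution = gradient * activation, one value per feature
attributions = [g * a for g, a in zip(grads, activations)]

# exact effect of zero-ablating each feature, for comparison
ablation_effects = []
for i in range(len(activations)):
    ablated = list(activations)
    ablated[i] = 0.0
    ablation_effects.append(metric(activations) - metric(ablated))

for i, (attr, effect) in enumerate(zip(attributions, ablation_effects)):
    print(f"feature {i}: attribution={attr:.3f}, ablation effect={effect:.3f}")
```

An inactive feature always gets zero attribution, however large its gradient, which is why the attribution tensors are mostly zeros.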


In [None]:
tokens = model.to_str_tokens(prompt)
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

px.bar(x = unique_tokens,
       y = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0].sum(-1).detach().cpu().numpy())

In [None]:
def convert_sparse_feature_to_long_df(sparse_tensor: torch.Tensor) -> pd.DataFrame:
    """
    Convert a (position, feature) tensor of mostly-zero values to a long-format
    pandas DataFrame with one row per non-zero entry and the columns
    position, feature, and attribution.
    """
    df = pd.DataFrame(sparse_tensor.detach().cpu().numpy())
    # one row per (position, feature) pair; the original row index (token position) is kept
    df_long = df.melt(ignore_index=False, var_name="feature", value_name="attribution")
    df_long_nonzero = df_long[df_long["attribution"] != 0]
    df_long_nonzero = df_long_nonzero.reset_index().rename(columns={"index": "position"})
    return df_long_nonzero

df_long_nonzero = convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])
df_long_nonzero.sort_values("attribution", ascending=False)
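
To make the reshaping above concrete, here is a plain-Python sketch of the same idea on a made-up 3x4 attribution matrix. The values and the name `toy_attributions` are illustrative only. It uses lists of dicts instead of pandas, so it shows the logic, not the exact DataFrame produced above:

```python
# Toy attribution matrix: rows are token positions, columns are SAE features.
# The values are made up for illustration.
toy_attributions = [
    [0.0, 0.0, 0.3, 0.0],
    [0.0, -1.2, 0.0, 0.0],
    [0.7, 0.0, 0.0, 0.0],
]

# Keep only the non-zero entries as (position, feature, attribution) records.
records = [
    {"position": pos, "feature": feat, "attribution": value}
    for pos, row in enumerate(toy_attributions)
    for feat, value in enumerate(row)
    if value != 0
]

# Sort from most positive to most negative attribution.
records.sort(key=lambda r: r["attribution"], reverse=True)

for r in records:
    print(r)
```

The long format is convenient because most (position, feature) pairs have zero attribution, so only a handful of rows remain to sort, filter, or group.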

In [None]:
from IPython.display import IFrame

# get a random feature from the SAE
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

# html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=feature_idx)
# IFrame(html, width=1200, height=600)

In [None]:
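# top 5 features by summed attribution at token position 8 (positions match the "i/token" labels in the bar chart)
features = (
    df_long_nonzero.query("position==8")
    .groupby("feature")
    .attribution.sum()
    .sort_values(ascending=False)
    .head(5)
    .index.tolist()
)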

In [None]:

# one by one
# for i, v in df_long_nonzero.query("position==8").groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
#     print(f"Feature {i} had a total attribution of {v:.2f}")
#     html = get_dashboard_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(i))
#     display(IFrame(html, width=1200, height=300))

# as a list 
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

get_neuronpedia_quick_list(sae, features)

In [None]:
for i, v in df_long_nonzero.groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
    print(f"Feature {i} had a total attribution of {v:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(i))
    display(IFrame(html, width=1200, height=300))

In [None]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# get_neuronpedia_quick_list expects a list of feature indices, so wrap the single random index
get_neuronpedia_quick_list(sae, [feature_idx])