# SAE-VIZ demo

This Colab was created to demo my open-source sparse autoencoder visualizer, as can be seen [here](https://www.perfectlynormal.co.uk/blog-sae). The [GitHub readme](https://github.com/callummcdougall/sae_vis) contains a more comprehensive explanation of how it works; this Colab focuses on actually demoing the functions.

In this notebook, we demo two different visualizations:

1. **Feature-centric vis**, where you look at a single feature and see e.g. which sequences in a large dataset this feature fires strongest on.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/sae-demo-1.png" width="1000">

2. **Prompt-centric vis**, where you input a custom prompt and see which features score highest on that prompt, according to a variety of possible metrics.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/sae-demo-2.png" width="750">



# Imports & Installs

In [1]:
try:
    import google.colab # type: ignore

    !git clone https://github.com/jbloomAus/mats_sae_training.git
except ImportError:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

from IPython.display import display, HTML
import torch
from datasets import load_dataset
import pickle
import webbrowser
import os
import sys
from transformer_lens import utils, HookedTransformer
from huggingface_hub import hf_hub_download
from tqdm.notebook import tqdm

from sparsify.scripts.generate_dashboards.sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
from sparsify.scripts.generate_dashboards.sae_vis.data_fetching_fns import get_feature_data, get_prompt_data
from sparsify.scripts.generate_dashboards.sae_vis.data_storing_fns import FeatureVisParams, MultiFeatureData, MultiPromptData
from sparsify.scripts.generate_dashboards.sae_vis.utils_fns import create_vocab_dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_grad_enabled(False);


# Setup

## Autoencoders

<!-- We're being a bit lazy here, and slicing our autoencoder so that we only take the first 2048 features (i.e. `dict_mult = 1`) rather than all 16384 features. This is literally just to avoid OOMs; you can increase the `DICT_MULT` parameter up to 8 if you'd like. -->

We set up our autoencoder here. You can use your own autoencoder, as long as it has the same parameters `W_enc`, `W_dec`, `b_enc` and `b_dec` (used in the same way) and has a `cfg` attribute which itself is a dataclass with attributes `d_mlp` and `dict_mult`. The forward pass method doesn't matter; we only ever use the weights directly in this codebase.

In [None]:
encoder = AutoEncoder.load_from_hf(version="run1")
encoder_B = AutoEncoder.load_from_hf(version="run2")

for k, v in encoder.named_parameters():
    print(f"{k}: {tuple(v.shape)}")
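
If you're plugging in your own autoencoder, a quick structural check can catch a missing weight or config field before you start a long run. The sketch below is a hypothetical helper, not part of `sae_vis`: `missing_autoencoder_parts`, `ToyConfig` and `ToyAutoEncoder` are names made up for illustration. The text above names `d_mlp` and `dict_mult` as the config fields, while the GPT-2 cell later in this notebook builds the config with `d_in`, so the expected field names are left as a parameter here.

```python
from dataclasses import dataclass, is_dataclass

# Weight names taken from the paragraph above.
REQUIRED_WEIGHTS = ("W_enc", "W_dec", "b_enc", "b_dec")


def missing_autoencoder_parts(autoencoder, cfg_fields=("d_mlp", "dict_mult")):
    """Return a list of problems; an empty list means the object looks compatible."""
    problems = [
        f"missing weight: {name}"
        for name in REQUIRED_WEIGHTS
        if not hasattr(autoencoder, name)
    ]
    cfg = getattr(autoencoder, "cfg", None)
    if cfg is None:
        problems.append("missing attribute: cfg")
    elif not is_dataclass(cfg):
        problems.append("cfg is not a dataclass")
    else:
        problems += [
            f"missing cfg field: {name}"
            for name in cfg_fields
            if not hasattr(cfg, name)
        ]
    return problems


# Toy stand-ins (hypothetical), so the check can run without loading real weights.
@dataclass
class ToyConfig:
    d_mlp: int = 2048
    dict_mult: int = 8


class ToyAutoEncoder:
    def __init__(self):
        self.cfg = ToyConfig()
        self.W_enc, self.W_dec, self.b_enc, self.b_dec = [], [], [], []


print(missing_autoencoder_parts(ToyAutoEncoder()))  # → []
```

With a real model you would call `missing_autoencoder_parts(encoder)` and expect an empty list. This only checks that the attributes exist, not that their shapes are right.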

## Models

This library will eventually support non-transformerlens models, but it's not there currently. If you're interested in this, please reach out!

<!-- This library supports non-transformerlens models, provided you apply a wrapper around your model with a few specific methods (e.g. a modified `forward` function which returns a tuple of `(logits, activations, resid)`). However, it's much easier to just use a TransformerLens model in most cases! -->

<!-- The code below loads in our GELU-1l transformer model. You can create your transformer model any way you like; all that matters is that:

* Your model has a `forward` method which takes `tokens` and returns a tuple of `(logits, residual, post_activations)`.
* This forward method has a parameter `return_logits`, which is by default `True`, and when `False` it only returns `(residual, post_activations)`.

Provided this is the case, all other code here (including calculating the effect of ablating certain features) doesn't rely on any specific implementation details of the model.

If you're trying to use a particular model, we recommend **creating a wrapper class around your model which has an altered `forward` method** to match the required behaviour. In the case of this notebook, to make it clear that a `HookedTransformer` model is not necessary, we're using a `DemoTransformer` model (code in this repository), which is a very minimal version of the `HookedTransformer` model lacking the features like hooks, caches, etc. -->

In [None]:
model = HookedTransformer.from_pretrained("gelu-1l")

## Data

You can replace this with your own data-loading code. The end result should be a 2D tensor of token ids, with shape `(n_sequences, seq_len)`.

In [None]:
SEQ_LEN = 128

data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN)
tokenized_data = tokenized_data.shuffle(42)
all_tokens: torch.Tensor = tokenized_data["tokens"]

print(all_tokens.shape)
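
If you're writing your own loader, the core of the job is cutting one long stream of token ids into fixed-length rows. Here's a minimal sketch of that step using plain Python lists; `chunk_token_ids` is a made-up name, and this is not how `tokenize_and_concatenate` is implemented (that function also handles the tokenisation itself and other details left out here).

```python
def chunk_token_ids(token_ids, seq_len):
    """Split one flat list of token ids into rows of exactly `seq_len`, dropping the ragged tail."""
    n_rows = len(token_ids) // seq_len
    return [token_ids[i * seq_len:(i + 1) * seq_len] for i in range(n_rows)]


flat_ids = list(range(10))          # stand-in for the concatenated, tokenized corpus
rows = chunk_token_ids(flat_ids, seq_len=4)
print(rows)                         # → [[0, 1, 2, 3], [4, 5, 6, 7]]
print((len(rows), len(rows[0])))    # → (2, 4), i.e. (n_sequences, seq_len)
```

Wrapping `rows` in `torch.tensor(...)` would then give the 2D tensor that the rest of the notebook expects.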

# Creating visualisations #1 (feature-centric)


Here's an example, which generates data for the first 1024 features, and generates the vis for one of them.

In [None]:
# Create a dataclass for the feature vis parameters (and use the help method to see what all the arguments do)

feature_vis_params = FeatureVisParams(
    hook_point = utils.get_act_name("post", 0),
    features = range(1024),
)
feature_vis_params.help()

# Get the feature data (this should take ~30 seconds on A100, because we're only doing 1024 features and 1024 sequences)

feature_data = get_feature_data(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens[:1024],
    fvp = feature_vis_params,
)

# Get the HTML (in Colab 'webbrowser' won't work, so you'll need to download and open this visualization in your browser)

test_idx = 8
filepath = "feature_vis_demo.html"

html_str = feature_data[test_idx].get_html()
display(HTML(html_str))
with open(filepath, "w") as f:
    f.write(html_str)
result = webbrowser.open(filepath)
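
The save-then-open pattern at the end of that cell comes up several times in this notebook, so you may want to wrap it. Below is a small sketch of such a wrapper. `save_html` is a hypothetical helper (not part of `sae_vis`), and it leaves the browser step off by default, since `webbrowser` won't work in Colab.

```python
import os
import tempfile
import webbrowser


def save_html(html_str, filepath, open_in_browser=False):
    """Write `html_str` to `filepath` and optionally try to open it in the default browser."""
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(html_str)
    if open_in_browser:
        # webbrowser.open returns False when no browser could be launched.
        return webbrowser.open(filepath)
    return False


# Placeholder string standing in for feature_data[test_idx].get_html()
demo_path = os.path.join(tempfile.gettempdir(), "feature_vis_demo.html")
save_html("<p>placeholder for the real visualization</p>", demo_path)
```

When running locally, pass `open_in_browser=True`; in Colab, download the file and open it yourself.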

If you don't care about the activation quantiles, you can also make this function run faster by getting rid of those groups - just set `n_groups=0` in the `FeatureVisParams` dataclass. You can also pass `include_left_tables=False` if you want an even more minimal plot (although this doesn't really save much time, since the left tables are fast to compute: most of the time is taken up by the forward passes & sequence data calculations). This code also demonstrates using `border=False`, which removes the shadow border around the plot.

In [None]:
feature_vis_params = FeatureVisParams(
    hook_point = utils.get_act_name("post", 0),
    features = range(256),
    n_groups = 0,
    first_group_size = 10,
    include_left_tables = False,
    border = False,
)

feature_data = get_feature_data(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens[:1024],
    fvp = feature_vis_params,
)

html_str = feature_data[test_idx].get_html()
display(HTML(html_str))
with open(filepath, "w") as f:
    f.write(html_str)
result = webbrowser.open(filepath)
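
In case the term is unfamiliar: an activation quantile says where one activation sits relative to a reference set of activations. The sketch below shows one common definition (the fraction of reference values at or below the given value) on made-up numbers. It's only meant to illustrate the idea; the way `sae_vis` forms its quantile groups may well differ in the details.

```python
from bisect import bisect_right


def activation_quantile(value, reference_activations):
    """Fraction of reference activations that are <= `value` (between 0.0 and 1.0)."""
    ordered = sorted(reference_activations)
    return bisect_right(ordered, value) / len(ordered)


reference = [0.0, 0.0, 0.1, 0.4, 0.9, 1.5, 2.0, 3.2]   # made-up activations
print(activation_quantile(0.9, reference))   # → 0.625
```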

# Feature-centric visualisations for multi-layer models

So far, we've only worked with a 1-layer model. Let's see what happens when we use a multi-layer model. Thankfully, Joseph Bloom has trained some excellent SAEs on GPT2-small, so we can use one of them.

First, we load the model, and the autoencoder. The autoencoder code is currently a bit hacky because it's not closely integrated with Joseph's library (I'm using my own autoencoder class & config object rather than Joseph's), but this will improve soon.

In [None]:
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

layer = 2
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# Make sure Joseph's lib is in the path, or else the load will fail
if os.getcwd() + "/mats_sae_training" not in sys.path:
    sys.path.append(os.getcwd() + "/mats_sae_training")

obj = torch.load(path, map_location="cpu")
state_dict = obj["state_dict"]
assert set(state_dict.keys()) == {'W_enc', 'b_enc', 'W_dec', 'b_dec'}

cfg = AutoEncoderConfig(
    d_in = obj["cfg"].d_in,
    dict_mult = obj["cfg"].expansion_factor,
)
gpt2_sae = AutoEncoder(cfg)
gpt2_sae.load_state_dict(state_dict);

And now let's get our vis. Feel the force!

In [None]:
feature = 7650

feature_vis_params_gpt = FeatureVisParams(
    hook_point = obj["cfg"].hook_point,
    minibatch_size_tokens = 512,
    n_groups = 0,
    first_group_size = 15,
    features = feature,
    verbose = True,
    include_left_tables = False,
)

feature_data_gpt = get_feature_data(
    encoder = gpt2_sae,
    model = gpt2,
    tokens = all_tokens[:8192],
    fvp = feature_vis_params_gpt,
)

html_str = feature_data_gpt[feature].get_html()
display(HTML(html_str))
with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)

# Creating visualisations #2 (prompt-centric)

First we create our vocab dict, via a helper function which allows us to get nice HTML representations of our tokens (rather than things which mess up our HTML, e.g. actual line breaks). You should do this on your model's tokenizer, since this `vocab_dict` will be used in subsequent functions. I've only worked with the GPT2 tokenizer, so if this code fails in some way for a different tokenizer, please let me know!

In [None]:
vocab_dict = create_vocab_dict(model.tokenizer)

Next, we pick a prompt and generate the data for it. The `get_prompt_data` function requires `feature_data` as input, because it needs things like the max-activating sequences for this feature. Note, we're using the `feature_data` object with `n_groups=0` and `include_left_tables=False` - this is because we don't actually need these for the prompt-centric visualization. If you're only trying to generate the prompt-centric view, it's a good idea to have these parameters set to these values, because it will speed up the process.

We don't have an extra dataclass like `FeatureVisParams` to wrap our arguments in, because there are very few. Some of them (e.g. `first_group_size`) are inherited from the `FeatureVisParams` object which was used to generate the `feature_data` which is supplied. The only important argument we need to use is `num_top_features`, which is the max number of top-scoring features which are displayed for any given prompt & metric. There's also the argument `verbose` (default False) which controls whether progress bars are printed.

In [None]:
prompt = "'first_name': ('django.db.models.fields"

str_toks = model.tokenizer.tokenize(prompt)
print(str_toks)

prompt_data = get_prompt_data(
    encoder = encoder,
    model = model,
    prompt = prompt,
    feature_data = feature_data,
    fvp = feature_vis_params,
    num_top_features = 10,
)
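
To make `num_top_features` concrete: for a given prompt position and metric, every feature gets a score and only the highest-scoring few are shown. Here's a minimal sketch of that selection step on made-up scores. `top_features` is a hypothetical helper, not the library's code, and the real scoring happens inside `get_prompt_data`.

```python
import heapq


def top_features(scores, num_top_features):
    """Return up to `num_top_features` (feature_index, score) pairs, highest score first."""
    return heapq.nlargest(num_top_features, scores.items(), key=lambda pair: pair[1])


made_up_scores = {3: 0.12, 8: 0.95, 41: 0.40, 57: 0.07}   # feature index -> score
print(top_features(made_up_scores, num_top_features=2))   # → [(8, 0.95), (41, 0.4)]
```

If fewer than `num_top_features` features have scores, you simply get all of them back.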

Lastly, from this data we create our visualization. We've chosen to examine the `"loss_effect"` on the `django` token, i.e. showing the features whose contributions most reduce the loss on this token.

In [None]:
str_score = "loss_effect"
seq_pos = str_toks.index("django")

html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)

display(HTML(html_str))

filepath = "prompt_vis_demo.html"
with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)

Alternatively, you can use the `"act_size"` or `"act_quantile"` metrics (we recommend the latter) on the `Ġ('` token, i.e. the token immediately before `django`. Remember, we have to include this `Ġ` character at the front of the token (which represents the space character), although this will depend on what tokenizer your model is using.

In [None]:
str_score = "act_quantile"
seq_pos = str_toks.index("Ġ('")

html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)

display(HTML(html_str))

filepath = "prompt_vis_demo.html"
with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)
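
If you'd rather not remember when the `Ġ` prefix is needed, you can look a token up both with and without it. The helper below is a small sketch (`find_token_position` is a made-up name), and the token list is written by hand for illustration, so print `str_toks` to see what your tokenizer actually produces.

```python
def find_token_position(str_toks, target, space_marker="Ġ"):
    """Index of the first token equal to `target`, with or without the leading-space marker."""
    for pos, tok in enumerate(str_toks):
        if tok == target or tok == space_marker + target:
            return pos
    raise ValueError(f"{target!r} not found in {str_toks}")


# Hand-written token list, not real tokenizer output.
example_toks = ["'", "first", "_", "name", "':", "Ġ('", "django", ".", "db"]
print(find_token_position(example_toks, "('"))       # → 5
print(find_token_position(example_toks, "django"))   # → 6
```

Other tokenizers mark spaces differently, which is why the marker is a parameter.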

# Saving data

Obviously the HTML strings can be saved, either as strings or as regular HTML files. If you want something more compact, you can pickle the dataclasses:

In [None]:
# Save
with open("feature_data.pkl", "wb") as f:
    pickle.dump(feature_data, f)

# Load
with open("feature_data.pkl", "rb") as f:
    feature_data: MultiFeatureData = pickle.load(f)

# Delete
os.remove("feature_data.pkl")

# Visualize the loaded data, to check it works
html_str = feature_data[test_idx].get_html()
display(HTML(html_str))

And for the prompt-centric visualisation:

In [None]:
# Save
with open("prompt_data.pkl", "wb") as f:
    pickle.dump(prompt_data, f)

# Load
with open("prompt_data.pkl", "rb") as f:
    prompt_data: MultiPromptData = pickle.load(f)

# Delete
os.remove("prompt_data.pkl")

# Visualize the loaded data, to check it works
html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)
display(HTML(html_str))
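
Both cells above follow the same save, load, delete pattern, so here it is as one reusable function. This is a sketch with a made-up name (`pickle_roundtrip`), shown on a plain dict rather than the real `feature_data` or `prompt_data` objects. As always with pickle, only load files you trust.

```python
import os
import pickle
import tempfile


def pickle_roundtrip(obj, path):
    """Save `obj` to `path`, load it back, delete the file, and return the loaded copy."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    finally:
        # Clean up even if loading fails.
        os.remove(path)


tmp_path = os.path.join(tempfile.gettempdir(), "sae_vis_roundtrip_demo.pkl")
restored = pickle_roundtrip({"feature": 8, "html_len": 1234}, tmp_path)
print(restored)   # → {'feature': 8, 'html_len': 1234}
```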