In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

# Imports for displaying vis in Colab / notebook

# Turn off gradient tracking globally for the rest of the notebook.
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [2]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# Build a table of the pretrained SAE directory: each entry's attribute dict becomes one row
# (from_records gives one column per entry, so the frame is transposed), then the columns
# that are not needed for this overview are dropped.
# TODO: Make this nicer.
df = (
    pd.DataFrame.from_records(
        {release: info.__dict__ for release, info in get_pretrained_saes_directory().items()}
    )
    .T
    .drop(columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"])
)
df # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model. 

Unnamed: 0,release,repo_id,model,saes_map
gemma-2b-it-res-jb,gemma-2b-it-res-jb,jbloom/Gemma-2b-IT-Residual-Stream-SAEs,gemma-2b-it,{'blocks.12.hook_resid_post': 'gemma_2b_it_blo...
gemma-2b-res-jb,gemma-2b-res-jb,jbloom/Gemma-2b-Residual-Stream-SAEs,gemma-2b,{'blocks.0.hook_resid_post': 'gemma_2b_blocks....
gpt2-small-hook-z-kk,gpt2-small-hook-z-kk,ckkissane/attn-saes-gpt2-small-all-layers,gpt2-small,{'blocks.0.hook_z': 'gpt2-small_L0_Hcat_z_lr1....
gpt2-small-mlp-tm,gpt2-small-mlp-tm,tommmcgrath/gpt2-small-mlp-out-saes,gpt2-small,{'blocks.0.hook_mlp_out': 'sae_group_gpt2_bloc...
gpt2-small-res-jb,gpt2-small-res-jb,jbloom/GPT2-Small-SAEs-Reformatted,gpt2-small,{'blocks.0.hook_resid_pre': 'blocks.0.hook_res...
gpt2-small-res-jb-feature-splitting,gpt2-small-res-jb-feature-splitting,jbloom/GPT2-Small-Feature-Splitting-Experiment...,gpt2-small,{'blocks.8.hook_resid_pre_768': 'blocks.8.hook...
gpt2-small-resid-post-v5-128k,gpt2-small-resid-post-v5-128k,jbloom/GPT2-Small-OAI-v5-128k-resid-post-SAEs,gpt2-small,{'blocks.0.hook_resid_post': 'v5_128k_layer_0'...
gpt2-small-resid-post-v5-32k,gpt2-small-resid-post-v5-32k,jbloom/GPT2-Small-OAI-v5-32k-resid-post-SAEs,gpt2-small,{'blocks.0.hook_resid_post': 'v5_32k_layer_0.p...
mistral-7b-res-wg,mistral-7b-res-wg,JoshEngels/Mistral-7B-Residual-Stream-SAEs,mistral-7b,{'blocks.8.hook_resid_pre': 'mistral_7b_layer_...


In [3]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# Load GPT-2 small as a HookedSAETransformer on the selected device.
model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# SAE.from_pretrained returns three things here:
#   - sae:      the SAE itself
#   - cfg_dict: whatever config was stored in the HF repo. This is not the SAE's own config dict
#               (which can be extracted from it), but it may contain useful information for
#               analysing the SAE (eg: instantiating an activation store)
#   - sparsity: the feature sparsities, which are stored in HF for convenience
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",        # <- Release name
    sae_id="blocks.7.hook_resid_pre",   # <- SAE id (not always a hook point!)
    device=device,
)



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


blocks.7.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

In [5]:
# Show every field of the loaded SAE's config as a plain dict.
print(vars(sae.cfg))

{'architecture': 'standard', 'd_in': 768, 'd_sae': 24576, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': True, 'finetuning_scaling_factor': False, 'context_size': 128, 'model_name': 'gpt2-small', 'hook_name': 'blocks.7.hook_resid_pre', 'hook_layer': 7, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'Skylion007/openwebtext', 'dataset_trust_remote_code': True, 'normalize_activations': 'none', 'dtype': 'torch.float32', 'device': 'cuda', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}}


In [6]:
from datasets import load_dataset  
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of the pile-10k dataset (non-streaming).
dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)

# Tokenize with the model's tokenizer, using the SAE config for the sequence length
# (context_size) and for whether a BOS token is prepended (prepend_bos).
# NOTE(review): streaming=True is passed here although the dataset above was loaded with
# streaming=False -- presumably this only affects how tokenize_and_concatenate processes
# the data; confirm against the transformer_lens documentation.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

Downloading readme:   0%|          | 0.00/373 [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/921 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors


In [7]:
from IPython.display import IFrame

# Draw one feature index at random from [0, d_sae) and unwrap it to a Python int.
feature_idx = torch.randint(low=0, high=sae.cfg.d_sae, size=(1,)).item()

# Neuronpedia embed URL; the three placeholders are filled, in order, with the
# release, the SAE id and the feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=0):
    """Return the Neuronpedia embed URL for a single feature.

    Args:
        sae_release: First path segment of the URL.
        sae_id: Second path segment of the URL.
        feature_idx: Third path segment of the URL (the feature index).

    Returns:
        `html_template` with its three placeholders filled in, as a string.
    """
    url_fields = (sae_release, sae_id, feature_idx)
    dashboard_url = html_template.format(*url_fields)
    return dashboard_url

# Build the dashboard URL for the randomly chosen feature and embed it as the cell output.
html = get_dashboard_html(
    sae_release="gpt2-small",
    sae_id="7-res-jb",
    feature_idx=feature_idx,
)
IFrame(src=html, width=1200, height=600)

In [8]:
from transformer_lens.utils import test_prompt

# A prompt and the continuation we expect the model to give it.
prompt, answer = "In the beginning, God created the heavens and the", "earth"

# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'In', ' the', ' beginning', ',', ' God', ' created', ' the', ' heavens', ' and', ' the']
Tokenized answer: [' earth']


Top 0th token. Logit: 27.64 Prob: 99.32% Token: | earth|
Top 1th token. Logit: 22.46 Prob:  0.56% Token: | Earth|
Top 2th token. Logit: 19.20 Prob:  0.02% Token: | planets|
Top 3th token. Logit: 18.80 Prob:  0.01% Token: | moon|
Top 4th token. Logit: 18.07 Prob:  0.01% Token: | heavens|
Top 5th token. Logit: 17.67 Prob:  0.00% Token: | oceans|
Top 6th token. Logit: 17.43 Prob:  0.00% Token: | ten|
Top 7th token. Logit: 17.41 Prob:  0.00% Token: | stars|
Top 8th token. Logit: 17.38 Prob:  0.00% Token: | seas|
Top 9th token. Logit: 17.35 Prob:  0.00% Token: | four|


In [9]:
# SAEs don't reconstruct activations perfectly, so if you attach an SAE and want the model to stay performant, you need to use the error term.
# This is because the SAE will be used to modify the forward pass, and if it doesn't reconstruct the activations well, the outputs may be affected.
# Good SAEs have small error terms, but it's something to be mindful of.

sae.use_error_term # If use_error_term is False, the forward pass is modified by using the SAE.

False

In [10]:
# HookedSAETransformer lets us run the model with the SAE attached and get the feature activations from the SAE in the cache
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

# List the name and shape of every cached activation that belongs to the SAE.
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# Note: there were 11 tokens in our prompt, the residual stream dimension is 768, and the number of SAE features is 24576 (sae.cfg.d_sae)

[('blocks.7.hook_resid_pre.hook_sae_input', torch.Size([1, 11, 768])), ('blocks.7.hook_resid_pre.hook_sae_acts_pre', torch.Size([1, 11, 24576])), ('blocks.7.hook_resid_pre.hook_sae_acts_post', torch.Size([1, 11, 24576])), ('blocks.7.hook_resid_pre.hook_sae_recons', torch.Size([1, 11, 768])), ('blocks.7.hook_resid_pre.hook_sae_output', torch.Size([1, 11, 768]))]


In [11]:
# Let's look at which features fired at the final token position. The SAE is attached at
# blocks.7.hook_resid_pre (hook_layer 7 in the SAE config).

# Post-activation SAE features for the first (only) prompt in the batch, at the last token.
# Looked up once and reused for both the plot and the top-k below.
final_token_acts = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :]

# hover over lines to see the Feature ID.
px.line(
    final_token_acts.cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(final_token_acts, 5)
# .tolist() converts the 0-d tensors to plain Python floats / ints, so the dashboard URL is
# built from an integer feature index rather than from a tensor object.
for val, ind in zip(vals.tolist(), inds.tolist()):
    print(f"Feature {ind} fired {val:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

AttributeError: 'Figure' object has no attribute 'save'