In [1]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [2]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T
df.drop(columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"], inplace=True)
df[df['model']=='gemma-2-2b'] # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model. 

Unnamed: 0,release,repo_id,model,saes_map
gemma-scope-27b-pt-res,gemma-scope-27b-pt-res,google/gemma-scope-27b-pt-res,gemma-2-2b,{'layer_10/width_131k/average_l0_106': 'layer_...
gemma-scope-2b-pt-att,gemma-scope-2b-pt-att,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/average_l0_104': 'layer_0/...
gemma-scope-2b-pt-att-canonical,gemma-scope-2b-pt-att-canonical,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-2b-pt-mlp,gemma-scope-2b-pt-mlp,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/average_l0_119': 'layer_0/...
gemma-scope-2b-pt-mlp-canonical,gemma-scope-2b-pt-mlp-canonical,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-2b-pt-res,gemma-scope-2b-pt-res,google/gemma-scope-2b-pt-res,gemma-2-2b,{'layer_0/width_16k/average_l0_105': 'layer_0/...
gemma-scope-2b-pt-res-canonical,gemma-scope-2b-pt-res-canonical,google/gemma-scope-2b-pt-res,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-9b-pt-att,gemma-scope-9b-pt-att,google/gemma-scope-9b-pt-att,gemma-2-2b,{'layer_0/width_131k/average_l0_55': 'layer_0/...
gemma-scope-9b-pt-mlp,gemma-scope-9b-pt-mlp,google/gemma-scope-9b-pt-mlp,gemma-2-2b,{'layer_0/width_131k/average_l0_11': 'layer_0/...


In [3]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer
os.environ["HF_TOKEN"] = "<hf token here>"
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device = device)





Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [4]:
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # <- Release name 
    sae_id = "layer_8/width_16k/average_l0_71", # <- SAE id (not always a hook point!)
    device = device
)

In [5]:
model.tokenizer("equal")

{'input_ids': [2, 13819], 'attention_mask': [1, 1]}

In [123]:
from transformer_lens.utils import test_prompt

prompt = "What is the output of 54 plus 32 ? It is "
answer = '8'
# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)

Tokenized prompt: ['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '4', ' plus', ' ', '3', '2', ' ?', ' It', ' is', ' ']
Tokenized answer: [' ', '8']


Top 0th token. Logit: 27.03 Prob: 54.65% Token: |8|
Top 1th token. Logit: 25.23 Prob:  9.05% Token: |1|
Top 2th token. Logit: 25.20 Prob:  8.74% Token: |5|
Top 3th token. Logit: 24.80 Prob:  5.89% Token: |2|
Top 4th token. Logit: 24.41 Prob:  3.97% Token: |3|
Top 5th token. Logit: 24.19 Prob:  3.20% Token: |<strong>|
Top 6th token. Logit: 24.15 Prob:  3.06% Token: |4|
Top 7th token. Logit: 23.99 Prob:  2.62% Token: |6|
Top 8th token. Logit: 23.86 Prob:  2.31% Token: |________________|
Top 9th token. Logit: 23.85 Prob:  2.27% Token: |7|


Top 0th token. Logit: 24.69 Prob: 26.87% Token: |8|
Top 1th token. Logit: 23.93 Prob: 12.58% Token: |1|
Top 2th token. Logit: 23.79 Prob: 10.90% Token: |5|
Top 3th token. Logit: 23.54 Prob:  8.47% Token: |2|
Top 4th token. Logit: 23.26 Prob:  6.44% Token: |<strong>|
Top 5th token. Logit: 23.20 Prob:  6.05% Token: |3|
Top 6th token. Logit: 23.13 Prob:  5.62% Token: |6|
Top 7th token. Logit: 23.10 Prob:  5.47% Token: |________________|
Top 8th token. Logit: 23.06 Prob:  5.28% Token: |4|
Top 9th token. Logit: 22.80 Prob:  4.06% Token: |7|


In [127]:
prompts = ["What’s the output of 54 plus 32? It is ", 
"What result is 54 plus 32? It is ",
"What does 54 plus 32 equal? It is ",
"What’s the sum of 54 and 32? It is ",
"What do you get from 54 plus 32? It is ",
"What does 54 plus 32 give? It is ",
"What is 54 plus 32 equal to? It is "]

answer = '8'

for pr in prompts:
    # test_prompt(pr, answer, model)
    _, cache = model.run_with_cache_with_saes(pr, saes=[sae])
    vals, inds = torch.topk(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    print(inds)
    

tensor([16100,  8857, 13307, 15191,  2024,  5326, 10744,  5927, 14739,  6179,
         4802,  8025,  9213, 15600,  4150], device='cuda:0')
tensor([16100,  2024,  8857, 13307, 10744,  5326, 15191,  8025, 14739, 15600,
        11484,  5927,  6179, 14157,  2121], device='cuda:0')
tensor([16100,  2024, 10744,  8857,  5326, 15191, 13307, 14739,  8025,  6179,
         5927,  3936, 15600,  4150, 10524], device='cuda:0')
tensor([16100,  2024,  8857, 13307, 10350, 10744, 14739,  5927,  8025,  6179,
        15600, 15191, 15379,  4150, 10524], device='cuda:0')
tensor([16100,  8857,  6179, 15191,  2024,  4797, 10744,  5927,  5326, 13307,
         8025, 15600, 11010, 10350, 13549], device='cuda:0')
tensor([16100,  2024,  8857, 10744, 15191,  6179,  5326,  8025, 13307,  5927,
         4797, 13831, 10350, 15600,  3173], device='cuda:0')
tensor([16100,  2024, 10744,  8857, 13307,  5326, 14739,  8025,  5927,  2121,
        15191, 15600,  4150, 14157,  3173], device='cuda:0')


In [128]:
prompts = [
    "What is the output of 23 plus 45? It is ",
    "What’s the output of 37 plus 28? It is ",
    "What result is 61 plus 19? It is ",
    "What does 46 plus 27 equal? It is ",
    "What’s the sum of 34 and 52? It is ",
    "What do you get from 29 plus 44? It is ",
    "What does 18 plus 65 give? It is ",
    "What is 53 plus 30 equal to? It is ",
]
for pr in prompts:
    # test_prompt(pr, answer, model)
    _, cache = model.run_with_cache_with_saes(pr, saes=[sae])
    vals, inds = torch.topk(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    print(inds)

tensor([16100,  8857,  2024, 15191, 13307,  5326, 10744, 14739,  5927,  9213,
         6179, 15600,  8025,  4150,  2121], device='cuda:0')
tensor([16100,  8857, 13307,  2024, 15191, 10744,  5326,  6179,  9213,  5927,
        14739,  4802, 15600,  8025,  4150], device='cuda:0')
tensor([16100,  2024,  8857, 13307, 15191, 10744,  5927,  8025, 15600,  5326,
        14739, 11484,  6179,   302,  2121], device='cuda:0')
tensor([16100,  2024,  8857, 15191, 10744, 13307,  6179,  5326, 14739,  8025,
         5927,  3936, 15600, 10350,  1296], device='cuda:0')
tensor([16100,  8857,  2024, 10350, 13307, 10744, 14739,  5927,  8025, 15600,
        15191, 10524,  6179,  2857, 15379], device='cuda:0')
tensor([16100,  4797,  8857,  2024, 10744, 15191,  6179,   302,  8025,  5927,
        13307, 13549, 15600,  3173, 13831], device='cuda:0')
tensor([16100,  2024,  8857, 10744,  4797, 13831,  6179,  5326, 13329,  8025,
        15191,  3173,  1296,  5927, 13307], device='cuda:0')
tensor([16100,  2024,  8857

In [130]:
import torch
from collections import defaultdict

# Dictionary to accumulate the sums of values for each index
vals_dict = defaultdict(float)
# Dictionary to count occurrences of each index
count_dict = defaultdict(int)

# Run through each prompt and accumulate the values for each index
for pr in prompts:
    _, cache = model.run_with_cache_with_saes(pr, saes=[sae])
    vals, inds = torch.topk(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    
    for val, ind in zip(vals, inds):
        vals_dict[ind.item()] += val.item()
        count_dict[ind.item()] += 1

# Calculate the average value for each index
avg_vals_dict = {ind: vals_dict[ind] / count_dict[ind] for ind in vals_dict}

# Sort the indices by their average values in descending order
sorted_avg_vals = sorted(avg_vals_dict.items(), key=lambda x: x[1], reverse=True)

# Get the top 10 indices with the highest average values
top_10_inds = sorted_avg_vals[:10]

# Print the top 10 indices and their corresponding average values
print("Top 10 feature indices and their average values:")
for ind, avg_val in top_10_inds:
    print(f"Index: {ind}, Average Value: {avg_val}")
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))


Top 10 feature indices and their average values:
Index: 16100, Average Value: 26.569886445999146


Index: 8857, Average Value: 14.54923939704895


Index: 4797, Average Value: 12.907676219940186


Index: 2024, Average Value: 12.378076672554016


Index: 10744, Average Value: 9.643695950508118


Index: 13307, Average Value: 9.091849684715271


Index: 15191, Average Value: 8.494134068489075


Index: 10350, Average Value: 8.450870513916016


Index: 5326, Average Value: 8.189166704813639


Index: 14739, Average Value: 7.221282720565796


In [131]:
import torch
from collections import defaultdict

# Dictionary to accumulate the sums of values for each index
vals_dict = defaultdict(float)
# Dictionary to count occurrences of each index
count_dict = defaultdict(int)

# Run through each prompt and accumulate the values for each index across all token positions
for pr in prompts:
    _, cache = model.run_with_cache_with_saes(pr, saes=[sae])
    
    # Get the entire activation tensor, not just the final token position
    activations = cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, :, :]
    
    # Iterate over all token positions
    for token_pos in range(activations.size(0)):
        vals, inds = torch.topk(activations[token_pos, :], 15)
        
        for val, ind in zip(vals, inds):
            vals_dict[ind.item()] += val.item()
            count_dict[ind.item()] += 1

# Calculate the average value for each index
avg_vals_dict = {ind: vals_dict[ind] / count_dict[ind] for ind in vals_dict}

# Sort the indices by their average values in descending order
sorted_avg_vals = sorted(avg_vals_dict.items(), key=lambda x: x[1], reverse=True)

# Get the top 10 indices with the highest average values
top_10_inds = sorted_avg_vals[:10]

# Print the top 10 indices and their corresponding average values
print("Top 10 feature indices and their average values:")
for ind, avg_val in top_10_inds:
    print(f"Index: {ind}, Average Value: {avg_val}")
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))


Top 10 feature indices and their average values:
Index: 4399, Average Value: 240.1869659423828


Index: 3397, Average Value: 217.4520034790039


Index: 6069, Average Value: 186.01401977539064


Index: 9213, Average Value: 175.6007941489996


Index: 13188, Average Value: 134.60874938964844


Index: 12695, Average Value: 127.2891616821289


Index: 6524, Average Value: 114.64870357513428


Index: 14537, Average Value: 113.84121990203857


Index: 556, Average Value: 109.75512886047363


Index: 13729, Average Value: 109.55464839935303


In [134]:
# constrastive pair 
prompt = "What is the output of 54 plus 32 ? It is "
answer = '8'
# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)


Tokenized prompt: ['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '4', ' plus', ' ', '3', '2', ' ?', ' It', ' is', ' ']
Tokenized answer: [' ', '8']


Top 0th token. Logit: 27.03 Prob: 54.65% Token: |8|
Top 1th token. Logit: 25.23 Prob:  9.05% Token: |1|
Top 2th token. Logit: 25.20 Prob:  8.74% Token: |5|
Top 3th token. Logit: 24.80 Prob:  5.89% Token: |2|
Top 4th token. Logit: 24.41 Prob:  3.97% Token: |3|
Top 5th token. Logit: 24.19 Prob:  3.20% Token: |<strong>|
Top 6th token. Logit: 24.15 Prob:  3.06% Token: |4|
Top 7th token. Logit: 23.99 Prob:  2.62% Token: |6|
Top 8th token. Logit: 23.86 Prob:  2.31% Token: |________________|
Top 9th token. Logit: 23.85 Prob:  2.27% Token: |7|


Top 0th token. Logit: 24.69 Prob: 26.87% Token: |8|
Top 1th token. Logit: 23.93 Prob: 12.58% Token: |1|
Top 2th token. Logit: 23.79 Prob: 10.90% Token: |5|
Top 3th token. Logit: 23.54 Prob:  8.47% Token: |2|
Top 4th token. Logit: 23.26 Prob:  6.44% Token: |<strong>|
Top 5th token. Logit: 23.20 Prob:  6.05% Token: |3|
Top 6th token. Logit: 23.13 Prob:  5.62% Token: |6|
Top 7th token. Logit: 23.10 Prob:  5.47% Token: |________________|
Top 8th token. Logit: 23.06 Prob:  5.28% Token: |4|
Top 9th token. Logit: 22.80 Prob:  4.06% Token: |7|


In [80]:
# hooked SAE Transformer will enable us to get the feature activations from the SAE
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# note there were 11 tokens in our prompt, the residual stream dimension is 768, and the number of SAE features is 768

[('blocks.8.hook_resid_post.hook_sae_input', torch.Size([1, 14, 2304])), ('blocks.8.hook_resid_post.hook_sae_acts_pre', torch.Size([1, 14, 16384])), ('blocks.8.hook_resid_post.hook_sae_acts_post', torch.Size([1, 14, 16384])), ('blocks.8.hook_resid_post.hook_sae_recons', torch.Size([1, 14, 2304])), ('blocks.8.hook_resid_post.hook_sae_output', torch.Size([1, 14, 2304]))]


In [81]:
cache['blocks.8.hook_resid_post.hook_sae_acts_post'].shape

torch.Size([1, 14, 16384])

In [82]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

# Create the line plot
fig = px.line(
    cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
)

# Show the plot (optional, for interactive display)
fig.show()

# Save the figure to a file
fig.write_image("feature_activations.png")  # Save as a PNG image

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} fired {val:.2f}")
    # html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=ind)
    # display(IFrame(html, width=1200, height=300))

Feature 16100 fired 26.37
Feature 2024 fired 12.97
Feature 13307 fired 10.09
Feature 14739 fired 9.11
Feature 8857 fired 8.37
Feature 10744 fired 8.32
Feature 15191 fired 7.78
Feature 5326 fired 7.65
Feature 8025 fired 6.88
Feature 2121 fired 5.81
Feature 15600 fired 5.77
Feature 302 fired 5.42
Feature 4150 fired 5.32
Feature 6179 fired 5.24
Feature 324 fired 5.22


In [None]:
Pos
Feature 16100 fired 27.27
Feature 2024 fired 12.52
Feature 10744 fired 10.61
Feature 8857 fired 9.67
Feature 13307 fired 8.00
Feature 5326 fired 7.58
Feature 8025 fired 7.53
Feature 14739 fired 6.78
Feature 15191 fired 6.25
Feature 15600 fired 6.15
Feature 11484 fired 5.83
Feature 6179 fired 5.61
Feature 2121 fired 5.11
Feature 5927 fired 5.10
Feature 3173 fired 4.64

neg 
Feature 16100 fired 26.37
Feature 2024 fired 12.97
Feature 13307 fired 10.09
Feature 14739 fired 9.11
Feature 8857 fired 8.37
Feature 10744 fired 8.32
Feature 15191 fired 7.78
Feature 5326 fired 7.65
Feature 8025 fired 6.88
Feature 2121 fired 5.81
Feature 15600 fired 5.77
Feature 302 fired 5.42
Feature 4150 fired 5.32
Feature 6179 fired 5.24
Feature 324 fired 5.22

In [116]:
prompt = ["What is 54 plus 32 ? \nIt is ", "What is 54 plus 32 ? \nWhat is "]
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

feature_activation_df = pd.DataFrame(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy(),
                                     index = [f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
feature_activation_df.columns = ["int"]
feature_activation_df["string"] = cache['blocks.8.hook_resid_post.hook_sae_acts_post'][1, -1, :].cpu().numpy()
feature_activation_df["diff"]= feature_activation_df["int"] - feature_activation_df["string"]

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
fig.update_xaxes(showticklabels=False)
fig.show()
fig.write_image("diff_feature_activations.png") 

[('blocks.8.hook_resid_post.hook_sae_input', torch.Size([2, 16, 2304])), ('blocks.8.hook_resid_post.hook_sae_acts_pre', torch.Size([2, 16, 16384])), ('blocks.8.hook_resid_post.hook_sae_acts_post', torch.Size([2, 16, 16384])), ('blocks.8.hook_resid_post.hook_sae_recons', torch.Size([2, 16, 2304])), ('blocks.8.hook_resid_post.hook_sae_output', torch.Size([2, 16, 2304]))]


In [122]:
# let's look at the biggest features in terms of absolute difference

diff = cache['blocks.8.hook_resid_post.hook_sae_acts_post'][1, -1, :].cpu() - cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu()
vals, inds = torch.topk(torch.abs(diff), 15)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} had a difference of {val:.2f}")
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

Feature 6179 had a difference of 19.84


Feature 9213 had a difference of 9.89


Feature 14888 had a difference of 9.73


Feature 8857 had a difference of 8.95


Feature 8025 had a difference of 7.52


Feature 95 had a difference of 7.47


Feature 5326 had a difference of 7.09


Feature 7759 had a difference of 6.80


Feature 4150 had a difference of 6.78


Feature 10906 had a difference of 6.50


Feature 4087 had a difference of 6.44


Feature 5927 had a difference of 5.19


Feature 5293 had a difference of 4.84


Feature 13329 had a difference of 4.71


Feature 11670 had a difference of 4.57


In [53]:
!nvidia-smi

Sun Aug 18 00:13:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:1B:00.0 Off |                    0 |
| N/A   34C    P0              69W / 500W |  32010MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:4E:00.0 Off |  

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [151]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple, Callable

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
import torch.nn.functional as F

# Use the functional API with inplace=False
# feature_acts = self.hook_sae_acts_post(F.relu(hidden_pre, inplace=False))

class SaeReconstructionCache(NamedTuple):
    sae_in: torch.Tensor
    feature_acts: torch.Tensor
    sae_out: torch.Tensor
    sae_error: torch.Tensor


def track_grad(tensor: torch.Tensor) -> None:
    """wrapper around requires_grad and retain_grad"""
    tensor.requires_grad_(True)
    tensor.retain_grad()


@dataclass
class ApplySaesAndRunOutput:
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]

    def zero_grad(self) -> None:
        """Helper to zero grad all tensors in this object."""
        self.model_output.grad = None
        for act in self.model_activations.values():
            act.grad = None
        for cache in self.sae_activations.values():
            cache.sae_in.grad = None
            cache.feature_acts.grad = None
            cache.sae_out.grad = None
            cache.sae_error.grad = None


def apply_saes_and_run(
    model: HookedTransformer,
    saes: dict[str, SAE],
    input: Any,
    include_error_term: bool = True,
    track_model_hooks: list[str] | None = None,
    return_type: Literal["logits", "loss"] = "logits",
    track_grads: bool = False,
) -> ApplySaesAndRunOutput:
    """
    Apply the SAEs to the model at the specific hook points, and run the model.
    By default, this will include a SAE error term which guarantees that the SAE
    will not affect model output. This function is designed to work correctly with
    backprop as well, so it can be used for gradient-based feature attribution.

    Args:
        model: the model to run
        saes: the SAEs to apply
        input: the input to the model
        include_error_term: whether to include the SAE error term to ensure the SAE doesn't affect model output. Default True
        track_model_hooks: a list of hook points to record the activations and gradients. Default None
        return_type: this is passed to the model.run_with_hooks function. Default "logits"
        track_grads: whether to track gradients. Default False
    """

    fwd_hooks = []
    bwd_hooks = []

    sae_activations: dict[str, SaeReconstructionCache] = {}
    model_activations: dict[str, torch.Tensor] = {}

    # this hook just track the SAE input, output, features, and error. If `track_grads=True`, it also ensures
    # that requires_grad is set to True and retain_grad is called for intermediate values.
    def reconstruction_hook(sae_in: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        sae = saes[hook_point]
#         x = sae_in.to(sae.dtype)
#         x = sae.reshape_fn_in(x)
#         x = sae.hook_sae_input(x)
#         x = sae.run_time_activation_norm_fn_in(x)

#         sae_in = x - (sae.b_dec * sae.cfg.apply_b_dec_to_input)

#         # # "... d_in, d_in d_sae -> ... d_sae",
#         hidden_pre = sae.hook_sae_acts_pre(sae_in @ sae.W_enc + sae.b_enc)
#         feature_acts = sae.hook_sae_acts_post(F.relu(hidden_pre, inplace=False))
        feature_acts = sae.encode(sae_in)
        
        
        sae_out = sae.decode(feature_acts)
        sae_error = (sae_in - sae_out).detach().clone()
        if track_grads:
            track_grad(sae_error)
            track_grad(sae_out)
            track_grad(feature_acts)
            track_grad(sae_in)
        sae_activations[hook_point] = SaeReconstructionCache(
            sae_in=sae_in,
            feature_acts=feature_acts,
            sae_out=sae_out,
            sae_error=sae_error,
        )

        if include_error_term:
            return sae_out + sae_error
        return sae_out

    def sae_bwd_hook(output_grads: torch.Tensor, hook: HookPoint):  # noqa: ARG001
        # this just passes the output grads to the input, so the SAE gets the same grads despite the error term hackery
        return (output_grads,)

    # this hook just records model activations, and ensures that intermediate activations have gradient tracking turned on if needed
    def tracking_hook(hook_input: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        model_activations[hook_point] = hook_input
        if track_grads:
            track_grad(hook_input)
        return hook_input

    for hook_point in saes.keys():
        fwd_hooks.append(
            (hook_point, partial(reconstruction_hook, hook_point=hook_point))
        )
        bwd_hooks.append((hook_point, sae_bwd_hook))
    for hook_point in track_model_hooks or []:
        fwd_hooks.append((hook_point, partial(tracking_hook, hook_point=hook_point)))

    # now, just run the model while applying the hooks
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model_output = model(input, return_type=return_type)

    return ApplySaesAndRunOutput(
        model_output=model_output,
        model_activations=model_activations,
        sae_activations=sae_activations,
    )


from dataclasses import dataclass
from transformer_lens.hook_points import HookPoint
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
torch.autograd.set_detect_anomaly(True)
EPS = 1e-8

torch.set_grad_enabled(True)
@dataclass
class AttributionGrads:
    metric: torch.Tensor
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]


@dataclass
class Attribution:
    model_attributions: dict[str, torch.Tensor]
    model_activations: dict[str, torch.Tensor]
    model_grads: dict[str, torch.Tensor]
    sae_feature_attributions: dict[str, torch.Tensor]
    sae_feature_activations: dict[str, torch.Tensor]
    sae_feature_grads: dict[str, torch.Tensor]
    sae_errors_attribution_proportion: dict[str, float]


def calculate_attribution_grads(
    model: HookedSAETransformer,
    prompt: str,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> AttributionGrads:
    """
    Wrapper around apply_saes_and_run that calculates gradients wrt to the metric_fn.
    Tracks grads for both SAE feature and model neurons, and returns them in a structured format.
    """
    output = apply_saes_and_run(
        model,
        saes=include_saes or {},
        input=prompt,
        return_type="logits" if return_logits else "loss",
        track_model_hooks=track_hook_points,
        include_error_term=include_error_term,
        track_grads=True,
    )
    metric = metric_fn(output.model_output)
    output.zero_grad()
    metric.backward()
    return AttributionGrads(
        metric=metric,
        model_output=output.model_output,
        model_activations=output.model_activations,
        sae_activations=output.sae_activations,
    )


def calculate_feature_attribution(
    model: HookedSAETransformer,
    input: Any,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> Attribution:
    """
    Calculate feature attribution for SAE features and model neurons following
    the procedure in https://transformer-circuits.pub/2024/march-update/index.html#feature-heads.
    This include the SAE error term by default, so inserting the SAE into the calculation is
    guaranteed to not affect the model output. This can be disabled by setting `include_error_term=False`.

    Args:
        model: The model to calculate feature attribution for.
        input: The input to the model.
        metric_fn: A function that takes the model output and returns a scalar metric.
        track_hook_points: A list of model hook points to track activations for, if desired
        include_saes: A dictionary of SAEs to include in the calculation. The key is the hook point to apply the SAE to.
        return_logits: Whether to return the model logits or loss. This is passed to TLens, so should match whatever the metric_fn expects (probably logits)
        include_error_term: Whether to include the SAE error term in the calculation. This is recommended, as it ensures that the SAE will not affecting the model output.
    """
    # first, calculate gradients wrt to the metric_fn.
    # these will be multiplied with the activation values to get the attributions
    outputs_with_grads = calculate_attribution_grads(
        model,
        input,
        metric_fn,
        track_hook_points,
        include_saes=include_saes,
        return_logits=return_logits,
        include_error_term=include_error_term,
    )
    model_attributions = {}
    model_activations = {}
    model_grads = {}
    sae_feature_attributions = {}
    sae_feature_activations = {}
    sae_feature_grads = {}
    sae_error_proportions = {}
    # this code is long, but all it's doing is multiplying the grads by the activations
    # and recording grads, acts, and attributions in dictionaries to return to the user
    with torch.no_grad():
        for name, act in outputs_with_grads.model_activations.items():
            assert act.grad is not None
            raw_activation = act.detach().clone()
            model_attributions[name] = (act.grad * raw_activation).detach().clone()
            model_activations[name] = raw_activation
            model_grads[name] = act.grad.detach().clone()
        for name, act in outputs_with_grads.sae_activations.items():
            assert act.feature_acts.grad is not None
            assert act.sae_out.grad is not None
            raw_activation = act.feature_acts.detach().clone()
            sae_feature_attributions[name] = (
                (act.feature_acts.grad * raw_activation).detach().clone()
            )
            sae_feature_activations[name] = raw_activation
            sae_feature_grads[name] = act.feature_acts.grad.detach().clone()
            if include_error_term:
                assert act.sae_error.grad is not None
                error_grad_norm = act.sae_error.grad.norm().item()
            else:
                error_grad_norm = 0
            sae_out_norm = act.sae_out.grad.norm().item()
            sae_error_proportions[name] = error_grad_norm / (
                sae_out_norm + error_grad_norm + EPS
            )
        return Attribution(
            model_attributions=model_attributions,
            model_activations=model_activations,
            model_grads=model_grads,
            sae_feature_attributions=sae_feature_attributions,
            sae_feature_activations=sae_feature_activations,
            sae_feature_grads=sae_feature_grads,
            sae_errors_attribution_proportion=sae_error_proportions,
        )
        
        
# prompt = " Tiger Woods plays the sport of"
# pos_token = model.tokenizer.encode(" golf")[0]
prompt = "What is the output of 54 plus 32 ? It is "
pos_token = [model.tokenizer.encode("8")[1]]
neg_token = [model.tokenizer.encode("1")[1]]
# def metric_fn(logits: torch.tensor, pos_token: torch.tensor =pos_token, neg_token: torch.Tensor=neg_token) -> torch.Tensor:
#     return logits[0,-1,pos_token] - logits[0,-1,neg_token]

def metric_fn(logits: torch.Tensor, pos_token: torch.Tensor = pos_token, neg_token: torch.Tensor = neg_token) -> torch.Tensor:
    return (logits[0, -1, pos_token] - logits[0, -1, neg_token]).sum()


feature_attribution_df = calculate_feature_attribution(
    input = prompt,
    model = model,
    metric_fn = metric_fn,
    include_saes={sae.cfg.hook_name: sae},
    include_error_term=True,
    return_logits=True,
)


In [152]:
model.tokenizer.encode("8")[1]

235321

In [153]:
from transformer_lens.utils import test_prompt
test_prompt(prompt, model.to_string(pos_token), model)

Tokenized prompt: ['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '4', ' plus', ' ', '3', '2', ' ?', ' It', ' is', ' ']
Tokenized answer: [' ', '8']


Top 0th token. Logit: 27.03 Prob: 54.65% Token: |8|
Top 1th token. Logit: 25.23 Prob:  9.05% Token: |1|
Top 2th token. Logit: 25.20 Prob:  8.74% Token: |5|
Top 3th token. Logit: 24.80 Prob:  5.89% Token: |2|
Top 4th token. Logit: 24.41 Prob:  3.97% Token: |3|
Top 5th token. Logit: 24.19 Prob:  3.20% Token: |<strong>|
Top 6th token. Logit: 24.15 Prob:  3.06% Token: |4|
Top 7th token. Logit: 23.99 Prob:  2.62% Token: |6|
Top 8th token. Logit: 23.86 Prob:  2.31% Token: |________________|
Top 9th token. Logit: 23.85 Prob:  2.27% Token: |7|


Top 0th token. Logit: 24.69 Prob: 26.87% Token: |8|
Top 1th token. Logit: 23.93 Prob: 12.58% Token: |1|
Top 2th token. Logit: 23.79 Prob: 10.90% Token: |5|
Top 3th token. Logit: 23.54 Prob:  8.47% Token: |2|
Top 4th token. Logit: 23.26 Prob:  6.44% Token: |<strong>|
Top 5th token. Logit: 23.20 Prob:  6.05% Token: |3|
Top 6th token. Logit: 23.13 Prob:  5.62% Token: |6|
Top 7th token. Logit: 23.10 Prob:  5.47% Token: |________________|
Top 8th token. Logit: 23.06 Prob:  5.28% Token: |4|
Top 9th token. Logit: 22.80 Prob:  4.06% Token: |7|


In [154]:
# Generate the plot
tokens = model.to_str_tokens(prompt)
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

fig = px.bar(x=unique_tokens,
             y=feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0].sum(-1).detach().cpu().numpy())

# Show the plot
fig.show()

# Save the plot as a PNG file
fig.write_image("feature_attribution_plot.png")


In [155]:
def convert_sparse_feature_to_long_df(sparse_tensor: torch.Tensor) -> pd.DataFrame:
    """
    Convert a sparse tensor to a long format pandas DataFrame.
    """
    df = pd.DataFrame(sparse_tensor.detach().cpu().numpy())
    df_long = df.melt(ignore_index=False, var_name='column', value_name='value')
    df_long.columns = ["feature", "attribution"]
    df_long_nonzero = df_long[df_long['attribution'] != 0]
    df_long_nonzero = df_long_nonzero.reset_index().rename(columns={'index': 'position'})
    return df_long_nonzero

df_long_nonzero = convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])
df_long_nonzero.sort_values("attribution", ascending=False).head(15)

Unnamed: 0,position,feature,attribution
7104,16,16100,0.276239
3435,0,7848,0.128525
1153,15,2646,0.118398
5811,0,13188,0.117178
552,0,1289,0.101739
3845,12,8764,0.096654
4336,8,9854,0.095792
2055,13,4781,0.093329
3882,0,8853,0.083462
2140,0,4965,0.080137


In [156]:
for i, v in df_long_nonzero.query("position==12").groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
    print(f"Feature {i} had a total attribution of {v:.2f}")
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=int(i))
    display(IFrame(html, width=1200, height=300))

Feature 8764 had a total attribution of 0.10


Feature 9854 had a total attribution of 0.06


Feature 13914 had a total attribution of 0.02


Feature 12438 had a total attribution of 0.02


Feature 16055 had a total attribution of 0.01


In [146]:
for i, v in df_long_nonzero.groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
    print(f"Feature {i} had a total attribution of {v:.2f}")
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=int(i))
    display(IFrame(html, width=1200, height=300))

Feature 16100 had a total attribution of 0.28


Feature 9854 had a total attribution of 0.20


Feature 2646 had a total attribution of 0.15


Feature 7848 had a total attribution of 0.13


Feature 13188 had a total attribution of 0.12
