In [1]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [41]:
import torch
from collections import defaultdict

In [2]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer
os.environ["HF_TOKEN"] = "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"
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device = device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer


# Making a dataset 

In [22]:
from transformer_lens.utils import test_prompt

prompt = "What is the output of 53 plus 34 ? It is "
answer = '8'
# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)

Tokenized prompt: ['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' plus', ' ', '3', '4', ' ?', ' It', ' is', ' ']
Tokenized answer: [' ', '8']


Top 0th token. Logit: 27.23 Prob: 54.42% Token: |8|
Top 1th token. Logit: 25.61 Prob: 10.80% Token: |5|
Top 2th token. Logit: 25.38 Prob:  8.58% Token: |1|
Top 3th token. Logit: 24.99 Prob:  5.78% Token: |2|
Top 4th token. Logit: 24.60 Prob:  3.92% Token: |3|
Top 5th token. Logit: 24.43 Prob:  3.30% Token: |<strong>|
Top 6th token. Logit: 24.13 Prob:  2.47% Token: |4|
Top 7th token. Logit: 24.03 Prob:  2.23% Token: |________________|
Top 8th token. Logit: 23.97 Prob:  2.09% Token: |6|
Top 9th token. Logit: 23.95 Prob:  2.06% Token: |7|


Top 0th token. Logit: 25.26 Prob: 26.73% Token: |8|
Top 1th token. Logit: 24.54 Prob: 13.10% Token: |1|
Top 2th token. Logit: 24.53 Prob: 12.87% Token: |5|
Top 3th token. Logit: 24.14 Prob:  8.76% Token: |2|
Top 4th token. Logit: 23.82 Prob:  6.35% Token: |<strong>|
Top 5th token. Logit: 23.81 Prob:  6.29% Token: |3|
Top 6th token. Logit: 23.60 Prob:  5.12% Token: |6|
Top 7th token. Logit: 23.47 Prob:  4.46% Token: |4|
Top 8th token. Logit: 23.44 Prob:  4.36% Token: |________________|
Top 9th token. Logit: 23.36 Prob:  4.01% Token: |7|


In [40]:
import random

def generate_number_variations(prompt, num_variations=5):
    variations = []
    for _ in range(num_variations):
        num1 = random.randint(10, 70)
        num2 = random.randint(10, 70)
        new_prompt = prompt.replace("53", str(num1)).replace("34", str(num2))
        tokens = model.to_str_tokens(new_prompt)
        # print(tokens)
        # num1_pos = tokens.index(str(num1)[0])
        # num2_pos = tokens.index(str(num2)[0])
        # operation_pos = tokens.index(' plus')
        # question_pos = tokens.index(' ?')
        variations.append(new_prompt)
    return variations

prompt = "What is the output of 53 plus 34 ? It is "
number_variations = generate_number_variations(prompt, 20)
number_variations
# 7, 8, 9, 11, 12, 13, -1 

['What is the output of 56 plus 17 ? It is ',
 'What is the output of 69 plus 21 ? It is ',
 'What is the output of 16 plus 15 ? It is ',
 'What is the output of 20 plus 25 ? It is ',
 'What is the output of 44 plus 57 ? It is ',
 'What is the output of 67 plus 35 ? It is ',
 'What is the output of 62 plus 57 ? It is ',
 'What is the output of 14 plus 29 ? It is ',
 'What is the output of 29 plus 70 ? It is ',
 'What is the output of 41 plus 51 ? It is ',
 'What is the output of 62 plus 43 ? It is ',
 'What is the output of 23 plus 15 ? It is ',
 'What is the output of 59 plus 45 ? It is ',
 'What is the output of 68 plus 47 ? It is ',
 'What is the output of 47 plus 40 ? It is ',
 'What is the output of 58 plus 22 ? It is ',
 'What is the output of 40 plus 37 ? It is ',
 'What is the output of 13 plus 14 ? It is ',
 'What is the output of 16 plus 54 ? It is ',
 'What is the output of 38 plus 39 ? It is ']

In [35]:
def generate_operation_variations(prompt, operations=["plus", "minus", "times"]):
    variations = []
    for op in operations:
        new_prompt = prompt.replace("plus", op)
        tokens = model.to_str_tokens(new_prompt)
        print(tokens)
        # num1_pos = tokens.index("53")
        # num2_pos = tokens.index("34")
        # operation_pos = tokens.index(op)
        # question_pos = tokens.index("?")
        variations.append(new_prompt)
    return variations

operation_variations = generate_operation_variations(prompt)

['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' plus', ' ', '3', '4', ' ?', ' It', ' is', ' ']
['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' minus', ' ', '3', '4', ' ?', ' It', ' is', ' ']
['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' times', ' ', '3', '4', ' ?', ' It', ' is', ' ']


In [38]:
def generate_phrasing_variations(prompt):
    variations = []
    phrasings = [
        "What do you get when you add 53 and 34? It is ",
        "Calculate the result of 53 plus 34. The answer is ",
        "If you add 53 to 34, the outcome is ",
        "What is the sum of 53 and 34? The result is ",
        "Add 53 and 34 together. The result is "
    ]
    
    for phrasing in phrasings:
        new_prompt = phrasing.replace("53", "53").replace("34", "34")
        tokens = model.to_str_tokens(new_prompt)
        # num1_pos = tokens.index("53")
        # num2_pos = tokens.index("34")
        # operation_pos = tokens.index("plus")
        # question_pos = tokens.index("?")
        variations.append(new_prompt) #, num1_pos, operation_pos, num2_pos, question_pos))
    return variations

phrasing_variations = generate_phrasing_variations(prompt)
phrasing_variations

['What do you get when you add 53 and 34? It is ',
 'Calculate the result of 53 plus 34. The answer is ',
 'If you add 53 to 34, the outcome is ',
 'What is the sum of 53 and 34? The result is ',
 'Add 53 and 34 together. The result is ']

# Naive look up, top k activations 

In [None]:
from sae_lens.toolkit.pretrained_saes_directory import (
    get_norm_scaling_factor,
    get_pretrained_saes_directory,
)
sae_directory = get_pretrained_saes_directory()
sae_directory["gemma-scope-2b-pt-res-canonical"].saes_map

In [47]:
# Dictionary to store top 10 features for each layer
top_features_per_layer = {}

for layer in range(model.cfg.n_layers):
    # Load the SAE for the current layer
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",  # <- Release name 
        sae_id=f"layer_{layer}/width_16k/canonical",  # <- SAE id (not always a hook point!)
        device=device
    )
    
    # Dictionary to accumulate the sums of values for each index
    vals_dict = defaultdict(float)
    # Dictionary to count occurrences of each index
    count_dict = defaultdict(int)

    # Run through each prompt and accumulate the values for each index
    for pr in number_variations:
        _, cache = model.run_with_cache_with_saes(pr, saes=[sae])
        vals, inds = torch.topk(cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)

        for val, ind in zip(vals, inds):
            vals_dict[ind.item()] += val.item()
            count_dict[ind.item()] += 1

    # Calculate the average value for each index
    avg_vals_dict = {ind: vals_dict[ind] / count_dict[ind] for ind in vals_dict}

    # Sort the indices by their average values in descending order
    sorted_avg_vals = sorted(avg_vals_dict.items(), key=lambda x: x[1], reverse=True)

    # Get the top 10 indices with the highest average values
    top_10_inds = sorted_avg_vals[:5]

    # Store the top 10 features for this layer
    top_features_per_layer[layer] = top_10_inds

    # Optionally, print the top 10 indices and their corresponding average values
    print(f"Top 10 feature indices for layer {layer} and their average values:")
    for ind, avg_val in top_10_inds:
        print(f"Index: {ind}, Average Value: {avg_val}")

# Now `top_features_per_layer` contains the top 10 features for each layer

Top 10 feature indices for layer 0 and their average values:
Index: 6186, Average Value: 30.469324111938477
Index: 4845, Average Value: 26.52650318145752
Index: 10435, Average Value: 8.262005710601807
Index: 15091, Average Value: 4.431811332702637
Index: 8920, Average Value: 2.9824002742767335
Top 10 feature indices for layer 1 and their average values:
Index: 8994, Average Value: 46.77613296508789
Index: 5662, Average Value: 9.123183822631836
Index: 15908, Average Value: 7.394703197479248
Index: 12634, Average Value: 7.187231922149659
Index: 1401, Average Value: 5.367806434631348
Top 10 feature indices for layer 2 and their average values:
Index: 6868, Average Value: 38.52335453033447
Index: 2501, Average Value: 11.303512573242188
Index: 13544, Average Value: 10.067531108856201
Index: 8399, Average Value: 7.386288595199585
Index: 15142, Average Value: 5.760507869720459
Top 10 feature indices for layer 3 and their average values:
Index: 12349, Average Value: 32.95688438415527
Index: 53

In [49]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

In [54]:
# Iterate through the dictionary and print the indices and corresponding layer
for layer, top_features in top_features_per_layer.items():
    print(f"Layer {layer}:")
    if layer<21:
        continue
    for ind, avg_val in top_features:
        print(f"  Index: {ind}, Average Value: {avg_val}")
        html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id=f"{layer}-gemmascope-res-16k", feature_idx=ind)
        display(IFrame(html, width=1200, height=300))
    # if layer>20:
    #     break

Layer 0:
Layer 1:
Layer 2:
Layer 3:
Layer 4:
Layer 5:
Layer 6:
Layer 7:
Layer 8:
Layer 9:
Layer 10:
Layer 11:
Layer 12:
Layer 13:
Layer 14:
Layer 15:
Layer 16:
Layer 17:
Layer 18:
Layer 19:
Layer 20:
Layer 21:
  Index: 9753, Average Value: 109.16744499206543


  Index: 9561, Average Value: 97.17663269042968


  Index: 6691, Average Value: 75.3885612487793


  Index: 3461, Average Value: 69.70830173492432


  Index: 2253, Average Value: 63.28204618181501


Layer 22:
  Index: 5112, Average Value: 106.36619758605957


  Index: 2684, Average Value: 98.48553276062012


  Index: 4384, Average Value: 97.07634620666504


  Index: 14240, Average Value: 93.51848983764648


  Index: 5528, Average Value: 72.70002746582031


Layer 23:
  Index: 10559, Average Value: 129.1775161743164


  Index: 15985, Average Value: 110.38838958740234


  Index: 4798, Average Value: 76.111572265625


  Index: 9872, Average Value: 73.2943229675293


  Index: 7493, Average Value: 69.2427749633789


Layer 24:
  Index: 3650, Average Value: 231.72640380859374


  Index: 14448, Average Value: 183.26643905639648


  Index: 793, Average Value: 103.86737823486328


  Index: 2335, Average Value: 101.63099670410156


  Index: 8643, Average Value: 90.91254425048828


Layer 25:
  Index: 11781, Average Value: 321.95685119628905


  Index: 348, Average Value: 247.3959991455078


  Index: 15298, Average Value: 210.3530746459961


  Index: 13749, Average Value: 210.26357955932616


  Index: 2567, Average Value: 106.16500091552734


In [None]:
for 

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # <- Release name 
    sae_id = "layer_8/width_16k/average_l0_71", # <- SAE id (not always a hook point!)
    device = device
)