In [1]:
import os
import json
import argparse
import torch
import numpy as np
import pandas as pd
from functools import partial
from tqdm import tqdm
from collections import defaultdict
import re
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import helpers.utils as utils

In IPython
Set autoreload


In [2]:
def logit_diff_fn(logits, clean_labels, corr_labels, token_wise=False):
    """
    Logit difference between the clean and the corrupted answer token.

    Only the final sequence position of ``logits`` is scored. For every
    example in the batch the logit of its ``corr_labels`` token is subtracted
    from the logit of its ``clean_labels`` token.

    Args:
        logits: tensor indexed as [batch, seq_len, vocab_size].
        clean_labels: per-example vocabulary index of the clean answer.
        corr_labels: per-example vocabulary index of the corrupted answer.
        token_wise: when True, return the per-example differences; otherwise
            return their mean.

    Returns:
        A tensor of per-example differences if ``token_wise`` is True, else a
        scalar tensor holding the batch mean.
    """
    # Keep only the last position: [batch, seq_len, vocab] -> [batch, vocab].
    final_logits = logits[:, -1]
    rows = torch.arange(final_logits.shape[0])
    # Pair row i with label i so each example reads its own answer logits.
    per_example = final_logits[rows, clean_labels] - final_logits[rows, corr_labels]
    return per_example if token_wise else per_example.mean()


In [3]:
# Read local settings; the Hugging Face access token is expected under the
# "huggingface_token" key of config.json.
with open("config.json", 'r') as file:
    config = json.load(file)
token = config.get('huggingface_token', None)
# os.environ only accepts strings: assigning None raises TypeError. Export the
# token only when config.json actually provides one, so a token that is already
# present in the environment is left untouched.
if token:
    os.environ["HF_TOKEN"] = token

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Cache location for downloaded checkpoints; also passed explicitly as
# cache_dir below.
hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load the model
model_name = "google/gemma-2-2b"
# model_name = "google/gemma-2-9b"
model = HookedSAETransformer.from_pretrained(model_name, device=device, cache_dir=hf_cache)

# The model weights are never trained in this notebook: turn off gradient
# tracking for all of them.
for param in model.parameters():
    param.requires_grad_(False)



Device: cuda


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [4]:
# Average-L0 value used further down to choose one SAE per layer.
layer_l0_target = 100
# Stride used when building the list of layers that receive an SAE.
sae_gap = 5

# One row per pretrained SAE release, built from each directory entry's
# attributes; columns that are not needed here are dropped.
sae_directory = get_pretrained_saes_directory()
df = pd.DataFrame.from_records(
    {release: entry.__dict__ for release, entry in sae_directory.items()}
).T
df = df.drop(
    columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"]
)

In [13]:
# Display the layer count from the loaded model's config; it bounds the range
# of layer indices that SAEs are attached to later in the notebook.
model.cfg.n_layers

26

In [19]:
sae_gap = 3
# Group the 16k-width SAE ids of the release by layer, remembering the
# average L0 encoded in each id.
neuronpedia_dict = df.loc['gemma-scope-2b-pt-res']['saes_map']
pattern = re.compile(r'layer_(\d+)/width_16k/average_l0_(\d+)')
layer_dict = defaultdict(list)
for sae_name, sae_entry in neuronpedia_dict.items():
    found = pattern.search(sae_name)
    # Ignore ids of a different width/format and ids mapped to None.
    if not found or sae_entry is None:
        continue
    layer_idx, l0 = (int(group) for group in found.groups())
    layer_dict[layer_idx].append((sae_name, l0))

# Per layer, keep the id whose average L0 is nearest to layer_l0_target
# (min() returns the first candidate on ties).
closest_strings = {
    layer_idx: min(candidates, key=lambda pair: abs(pair[1] - layer_l0_target))[0]
    for layer_idx, candidates in layer_dict.items()
}
print(closest_strings)
# Layers that will get an SAE: every sae_gap-th layer starting at layer 0.
layers = list(range(0, model.cfg.n_layers, sae_gap))
print(layers)

{0: 'layer_0/width_16k/average_l0_105', 1: 'layer_1/width_16k/average_l0_102', 2: 'layer_2/width_16k/average_l0_141', 3: 'layer_3/width_16k/average_l0_59', 4: 'layer_4/width_16k/average_l0_124', 5: 'layer_5/width_16k/average_l0_68', 6: 'layer_6/width_16k/average_l0_70', 7: 'layer_7/width_16k/average_l0_69', 8: 'layer_8/width_16k/average_l0_71', 9: 'layer_9/width_16k/average_l0_73', 10: 'layer_10/width_16k/average_l0_77', 11: 'layer_11/width_16k/average_l0_80', 12: 'layer_12/width_16k/average_l0_82', 13: 'layer_13/width_16k/average_l0_84', 14: 'layer_14/width_16k/average_l0_84', 15: 'layer_15/width_16k/average_l0_78', 16: 'layer_16/width_16k/average_l0_78', 17: 'layer_17/width_16k/average_l0_77', 18: 'layer_18/width_16k/average_l0_74', 19: 'layer_19/width_16k/average_l0_73', 20: 'layer_20/width_16k/average_l0_71', 21: 'layer_21/width_16k/average_l0_70', 22: 'layer_22/width_16k/average_l0_72', 23: 'layer_23/width_16k/average_l0_75', 24: 'layer_24/width_16k/average_l0_73', 25: 'layer_25/w

In [20]:

# Load one pretrained SAE per selected layer, in layer order.
saes = []
for layer in tqdm(layers):
    sae_id = closest_strings.get(layer)
    if sae_id is None:
        # No candidate id was recorded for this layer: report it and move on.
        print(f"Warning: No matching SAE ID found for layer {layer} with L0 target {layer_l0_target}")
        continue
    loaded = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res",
        sae_id=sae_id,
        device=device,
    )
    # from_pretrained is indexed here; element 0 is the SAE object itself.
    saes.append(loaded[0])

  0%|          | 0/9 [00:00<?, ?it/s]

 11%|█         | 1/9 [00:00<00:04,  1.61it/s]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 22%|██▏       | 2/9 [00:09<00:37,  5.29s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 33%|███▎      | 3/9 [00:17<00:40,  6.79s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 44%|████▍     | 4/9 [00:26<00:37,  7.48s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 67%|██████▋   | 6/9 [00:35<00:16,  5.39s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 78%|███████▊  | 7/9 [00:43<00:12,  6.42s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

 89%|████████▉ | 8/9 [00:52<00:07,  7.14s/it]

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 9/9 [01:01<00:00,  6.80s/it]


In [7]:
task = "sva/rc_train"
example_length = 7   # keep only prompts whose tokenized length equals this
N = 3000             # maximum number of examples to keep
batch_size = 16

data_path = f"data/{task}.json"
if not os.path.exists(data_path):
    raise FileNotFoundError(f"Data file {data_path} not found.")

# The file holds one JSON object per line; blank lines (e.g. a trailing
# newline at the end of the file) are skipped instead of crashing json.loads.
with open(data_path, 'r') as f:
    data = [json.loads(line) for line in f if line.strip()]

clean_data = []
corr_data = []
clean_labels = []
corr_labels = []

for entry in data:
    # Stop as soon as N examples are collected: only the first N matches are
    # used, so tokenizing the remaining entries would be wasted work.
    if len(clean_data) >= N:
        break
    clean_len = len(model.tokenizer(entry['clean_prefix']).input_ids)
    corr_len = len(model.tokenizer(entry['patch_prefix']).input_ids)
    # Clean and corrupted prompts must both have exactly example_length tokens
    # so they can be stacked into fixed-shape batches.
    if clean_len == corr_len == example_length:
        clean_data.append(entry['clean_prefix'])
        corr_data.append(entry['patch_prefix'])
        clean_labels.append(entry['clean_answer'])
        corr_labels.append(entry['patch_answer'])

if len(clean_data) < batch_size:
    raise ValueError(
        f"Only {len(clean_data)} examples of length {example_length} found in "
        f"{data_path}; at least batch_size={batch_size} are required."
    )

# Tokenize
clean_tokens = model.to_tokens(clean_data)
corr_tokens = model.to_tokens(corr_data)
clean_label_tokens = model.to_tokens(clean_labels, prepend_bos=False)
corr_label_tokens = model.to_tokens(corr_labels, prepend_bos=False)

# Each answer must be exactly one token. Otherwise squeeze(-1) is a no-op and
# the reshape below would silently spread label tokens across the wrong
# batches instead of failing.
if clean_label_tokens.shape[-1] != 1 or corr_label_tokens.shape[-1] != 1:
    raise ValueError(
        "Answer labels must tokenize to exactly one token each; got widths "
        f"{clean_label_tokens.shape[-1]} (clean) and {corr_label_tokens.shape[-1]} (corrupted)."
    )
clean_label_tokens = clean_label_tokens.squeeze(-1)
corr_label_tokens = corr_label_tokens.squeeze(-1)

# Reshape into batches, dropping the final partial batch
n_batches_total = (len(clean_tokens) // batch_size)
n_usable = n_batches_total * batch_size
clean_tokens = clean_tokens[:n_usable]
corr_tokens = corr_tokens[:n_usable]
clean_label_tokens = clean_label_tokens[:n_usable]
corr_label_tokens = corr_label_tokens[:n_usable]

clean_tokens = clean_tokens.reshape(-1, batch_size, clean_tokens.shape[-1])
corr_tokens = corr_tokens.reshape(-1, batch_size, corr_tokens.shape[-1])
clean_label_tokens = clean_label_tokens.reshape(-1, batch_size)
corr_label_tokens = corr_label_tokens.reshape(-1, batch_size)

print("Number of total batches after reshape:", clean_tokens.shape[0])

Number of total batches after reshape: 187


In [21]:
print("Computing average full model logit diff (clean vs corr)...")
# Upper bound on the number of batches scored by the evaluation loops.
num_batches_eval = 10
# Never index past the batches that actually exist.
n_eval = min(num_batches_eval, clean_tokens.shape[0])
avg_model_diff = 0.0
with torch.no_grad():
    for batch_idx in range(n_eval):
        # Plain forward pass on the clean prompts.
        batch_logits = model(clean_tokens[batch_idx])  # indexed as [batch_size, seq_len, vocab_size]
        avg_model_diff += logit_diff_fn(
            batch_logits, clean_label_tokens[batch_idx], corr_label_tokens[batch_idx]
        )
# Mean of the per-batch means, converted to a Python float.
avg_model_diff = (avg_model_diff / n_eval).item()
print("Average Full Model LD:", avg_model_diff)

Computing average full model logit diff (clean vs corr)...
Average Full Model LD: 3.532318115234375


In [22]:
# Same metric as the full-model baseline, but with the forward pass routed
# through utils.run_sae_hook_fn and use_mean_error=False.
n_eval = min(num_batches_eval, clean_tokens.shape[0])
avg_logit_diff = 0.0
with torch.no_grad():
    for batch_idx in range(n_eval):
        # The helper also hands back the SAE list, which is rebound each step.
        batch_logits, saes = utils.run_sae_hook_fn(
            model, saes, clean_tokens[batch_idx], use_mean_error=False
        )
        avg_logit_diff += logit_diff_fn(
            batch_logits, clean_label_tokens[batch_idx], corr_label_tokens[batch_idx]
        )
avg_logit_diff = (avg_logit_diff / n_eval).item()
print("Average Model + SAEs LD:", avg_logit_diff)

Average Model + SAEs LD: 1.569484829902649


In [10]:
# Variant of the SAE evaluation with calc_error=True and use_error=True.
# NOTE(review): presumably this adds the SAE reconstruction error back in;
# confirm against helpers.utils.
n_eval = min(num_batches_eval, clean_tokens.shape[0])
avg_logit_diff = 0.0
with torch.no_grad():
    for batch_idx in range(n_eval):
        batch_logits, saes = utils.run_sae_hook_fn(
            model,
            saes,
            clean_tokens[batch_idx],
            use_mean_error=False,
            calc_error=True,
            use_error=True,
        )
        avg_logit_diff += logit_diff_fn(
            batch_logits, clean_label_tokens[batch_idx], corr_label_tokens[batch_idx]
        )
avg_logit_diff = (avg_logit_diff / n_eval).item()
print("Average Model + SAEs LD:", avg_logit_diff)

Average Model + SAEs LD: 3.532318115234375


In [23]:
# Give every SAE a mask sized to its dictionary width and the prompt length.
# NOTE(review): the meaning of the 1.0 argument is defined in
# helpers.utils.SparseMask - confirm there.
for sae in saes:
    dictionary_width = sae.cfg.d_sae
    sae_mask = utils.SparseMask(dictionary_width, 1.0, seq_len=example_length)
    sae.mask = sae_mask.to(device)

# Both accumulation passes run over the corrupted prompts with identical settings.
mean_kwargs = dict(total_batches=40, batch_size=16)
saes = utils.get_sae_means(model, saes, corr_tokens, **mean_kwargs)
saes = utils.get_sae_error_means(model, saes, corr_tokens, **mean_kwargs)

Mean Accum Progress:   0%|          | 0/640 [00:00<?, ?it/s]

Mean Accum Progress: 100%|██████████| 640/640 [02:39<00:00,  4.02it/s]
Mean Accum Progress: 100%|██████████| 640/640 [02:38<00:00,  4.04it/s]


In [24]:
# Evaluate the SAE-hooked model with use_mean_error=True.
# NOTE(review): presumably this uses the error means accumulated by
# utils.get_sae_error_means; confirm against helpers.utils.
n_eval = min(num_batches_eval, clean_tokens.shape[0])
avg_logit_err_diff = 0.0
with torch.no_grad():
    for batch_idx in range(n_eval):
        batch_logits, saes = utils.run_sae_hook_fn(
            model, saes, clean_tokens[batch_idx], use_mean_error=True
        )
        avg_logit_err_diff += logit_diff_fn(
            batch_logits, clean_label_tokens[batch_idx], corr_label_tokens[batch_idx]
        )
avg_logit_err_diff = (avg_logit_err_diff / n_eval).item()
print("Average Model + SAEs (Mean Error) LD:", avg_logit_err_diff)

Average Model + SAEs (Mean Error) LD: 2.1713619232177734
