In [1]:
!nvidia-smi

Wed Aug 21 06:23:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:0A:00.0 Off |                    0 |
| N/A   33C    P0              64W / 400W |      0MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:8A:00.0 Off |  

In [3]:
import torch.nn.functional as F

In [34]:
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

In [2]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
# Imports for displaying vis in Colab / notebook

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Other imports are deliberately placed next to the code that uses them, so it
# stays obvious where each name comes from.

# Device preference order: Apple MPS first, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [4]:
from IPython.display import IFrame
# URL pattern of an embeddable Neuronpedia dashboard: release / SAE id / feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="8-gemmascope-res-16k", feature_idx=0):
    """Return the embed URL of the Neuronpedia dashboard for one SAE feature.

    Despite the name, the return value is a URL string (the filled-in
    ``html_template``), not HTML markup.

    Args:
        sae_release: First path segment of the URL.
        sae_id: Second path segment of the URL.
        feature_idx: Third path segment, the feature index.

    Returns:
        str: ``html_template`` with the three placeholders filled in order.
    """
    url_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*url_parts)

In [5]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer
# SECURITY: never hard-code a Hugging Face access token in a notebook. Anything
# written here ends up in version control and in every shared copy of the file.
# The token that used to be assigned on this line must be treated as leaked and
# revoked in the Hugging Face account settings.
# Provide credentials from outside the notebook instead: export HF_TOKEN before
# starting Jupyter, or log in once with `huggingface-cli login`.
if not os.environ.get("HF_TOKEN"):
    print(
        "HF_TOKEN is not set; downloading the model may fail unless you have "
        "already authenticated with `huggingface-cli login`."
    )

# Load Gemma-2-2B wrapped so that SAEs can be spliced into its hook points.
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [6]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# Build a table with one row per SAE "release". A release groups several SAEs
# that may have different configs / match different hook points in a model.
# TODO: Make this nicer.
release_records = {
    release_name: release_info.__dict__
    for release_name, release_info in get_pretrained_saes_directory().items()
}
df = pd.DataFrame.from_records(release_records).T

# Hide the columns that are not needed for this overview.
hidden_columns = ["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"]
df.drop(columns=hidden_columns, inplace=True)

# Show only the releases registered for gemma-2-2b.
gemma_2b_mask = df['model'] == 'gemma-2-2b'
df[gemma_2b_mask]

Unnamed: 0,release,repo_id,model,saes_map
gemma-scope-27b-pt-res,gemma-scope-27b-pt-res,google/gemma-scope-27b-pt-res,gemma-2-2b,{'layer_10/width_131k/average_l0_106': 'layer_...
gemma-scope-2b-pt-att,gemma-scope-2b-pt-att,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/average_l0_104': 'layer_0/...
gemma-scope-2b-pt-att-canonical,gemma-scope-2b-pt-att-canonical,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-2b-pt-mlp,gemma-scope-2b-pt-mlp,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/average_l0_119': 'layer_0/...
gemma-scope-2b-pt-mlp-canonical,gemma-scope-2b-pt-mlp-canonical,google/gemma-scope-2b-pt-mlp,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-2b-pt-res,gemma-scope-2b-pt-res,google/gemma-scope-2b-pt-res,gemma-2-2b,{'layer_0/width_16k/average_l0_105': 'layer_0/...
gemma-scope-2b-pt-res-canonical,gemma-scope-2b-pt-res-canonical,google/gemma-scope-2b-pt-res,gemma-2-2b,{'layer_0/width_16k/canonical': 'layer_0/width...
gemma-scope-9b-pt-att,gemma-scope-9b-pt-att,google/gemma-scope-9b-pt-att,gemma-2-2b,{'layer_0/width_131k/average_l0_55': 'layer_0/...
gemma-scope-9b-pt-mlp,gemma-scope-9b-pt-mlp,google/gemma-scope-9b-pt-mlp,gemma-2-2b,{'layer_0/width_131k/average_l0_11': 'layer_0/...


In [7]:
import re
from collections import defaultdict

# All SAE ids published in the residual-stream release.
sae_keys = list(df.loc['gemma-scope-2b-pt-res']['saes_map'].keys())

# Matches ids such as "layer_3/width_16k/average_l0_59" and captures the layer
# number and the average L0; ids of other widths do not match and are ignored.
pattern = re.compile(r'layer_(\d+)/width_16k/average_l0_(\d+)')

# layer number -> [(sae_id, l0), ...] in the order the ids appear in sae_keys
layer_dict = defaultdict(list)
for sae_key in sae_keys:
    found = pattern.search(sae_key)
    if found is None:
        continue
    layer_number, l0 = (int(group) for group in found.groups())
    layer_dict[layer_number].append((sae_key, l0))

# For every layer keep the id whose average L0 is nearest to 100
# (min() keeps the first candidate on ties).
closest_strings = {
    layer_number: min(candidates, key=lambda pair: abs(pair[1] - 100))[0]
    for layer_number, candidates in layer_dict.items()
}

# Report the choice per layer, in ascending layer order.
for layer_number in sorted(closest_strings):
    print(f"Layer {layer_number}: {closest_strings[layer_number]}")


Layer 0: layer_0/width_16k/average_l0_105
Layer 1: layer_1/width_16k/average_l0_102
Layer 2: layer_2/width_16k/average_l0_141
Layer 3: layer_3/width_16k/average_l0_59
Layer 4: layer_4/width_16k/average_l0_124
Layer 5: layer_5/width_16k/average_l0_68
Layer 6: layer_6/width_16k/average_l0_70
Layer 7: layer_7/width_16k/average_l0_69
Layer 8: layer_8/width_16k/average_l0_71
Layer 9: layer_9/width_16k/average_l0_73
Layer 10: layer_10/width_16k/average_l0_77
Layer 11: layer_11/width_16k/average_l0_80
Layer 12: layer_12/width_16k/average_l0_82
Layer 13: layer_13/width_16k/average_l0_84
Layer 14: layer_14/width_16k/average_l0_84
Layer 15: layer_15/width_16k/average_l0_78
Layer 16: layer_16/width_16k/average_l0_78
Layer 17: layer_17/width_16k/average_l0_77
Layer 18: layer_18/width_16k/average_l0_74
Layer 19: layer_19/width_16k/average_l0_73
Layer 20: layer_20/width_16k/average_l0_71
Layer 21: layer_21/width_16k/average_l0_70
Layer 22: layer_22/width_16k/average_l0_72
Layer 23: layer_23/width_16

In [8]:
# Scan every layer of the model for SAE features that look specific to the
# operator token of a single hand-written prompt.
# layer = 0

prompt = 'What is the output of 53 plus 32 ? A: '

# Token position of the operator (" plus") in `prompt`; see the token listing
# printed further down in the notebook.
ind = 9

# layer index -> list of selected feature indices (stored as 0-dim tensors)
operator_features = {}

for layer in range(model.cfg.n_layers):
    operator_features[layer] = []
    # Load the SAE for the current layer (one load per layer, so this cell is slow).
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res", 
        sae_id=closest_strings[layer],  
        device=device
    )
    # Run the prompt with only this layer's SAE spliced in and grab its feature activations.
    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    x = cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post']
    # The 10 features with the largest activation at the operator position.
    topk_values, topk_indices = torch.topk(x[0, ind, :], 10)
    # Activations of those 10 features at every position of batch element 0
    # (presumably shaped (seq_len, 10) - TODO confirm).
    gathered_values = x[0, :, topk_indices]
    
    # Softmax across positions (dim 0), separately for each gathered feature.
    softmaxed_x = F.softmax(gathered_values, dim=0)
    # NOTE: despite the "top3" name, k=1 here, i.e. only the single strongest position is kept.
    top3_indices_after_softmax = torch.topk(softmaxed_x, 1, dim=0).indices
    # Keep a feature only if the operator position is among its kept position(s).
    for i in range(topk_indices.size(0)):
        if ind in top3_indices_after_softmax[:, i]:
            operator_features[layer].append(topk_indices[i])
    

In [None]:
# Display the per-layer lists of selected feature indices.
operator_features

In [44]:

# Manual inspection of one prompt with the layer-8 residual-stream SAE.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # <- Release name 
    sae_id = "layer_8/width_16k/average_l0_71", # <- SAE id (not always a hook point!)
    device = device
)
pr = 'What is the output of 53 plus 32 ? A: '

# for ind in range(1, 17):
# Position 13 is the " ?" token of `pr` (see the token listing printed below).
ind = 13
_, cache = model.run_with_cache_with_saes(pr, saes=[sae])
x = cache['blocks.8.hook_resid_post.hook_sae_acts_post']
# The 50 features with the largest activation at position `ind` ...
topk_values, topk_indices = torch.topk(x[0, ind, :], 50)
# ... and their activations at every position of batch element 0.
gathered_values = x[0, :, topk_indices]
print(gathered_values.shape)
# softmaxed_values = F.softmax(gathered_values, dim=0)
# print(softmaxed_values)
# vals, inds = torch.topk(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, 13, :], 50)
# inds

torch.Size([17, 50])


In [45]:
# Normalise each gathered feature across positions (dim 0).
softmaxed_x = F.softmax(gathered_values, dim=0)
# NOTE: the name says "top3" but k=2, so this holds the two strongest positions per feature.
top3_indices_after_softmax = torch.topk(softmaxed_x, 2, dim=0).indices
# Single strongest position per feature.
max_indices = torch.argmax(softmaxed_x, dim=0)

In [46]:
# Strongest positions per gathered feature (row 0 = strongest, row 1 = runner-up).
top3_indices_after_softmax

tensor([[13, 13, 13,  0,  9, 13, 16, 13,  9,  3, 13, 13, 13,  0,  4,  3, 13,  0,
         16,  0,  0, 13,  4,  0, 13, 14,  0, 13,  0,  0, 13,  4,  0, 13,  4,  0,
          5,  0,  4, 13, 13, 16,  5,  0, 13, 13, 14, 13,  0, 16],
        [ 0,  6,  0,  1, 10,  0,  0,  0, 10,  4,  4,  2,  9,  4,  3, 13, 14, 13,
         13, 13,  9,  7, 13,  2,  6, 15, 13,  0, 13, 13,  0,  6,  4,  0, 13, 13,
         13, 14,  5,  0,  0,  8, 13, 13, 12,  0,  9, 10, 13, 13]],
       device='cuda:0')

In [47]:
# Print every gathered feature for which position 13 is among its kept top positions.
for i in range(topk_indices.size(0)):
    if 13 in top3_indices_after_softmax[:, i]:
        print(topk_indices[i])

tensor(4909, device='cuda:0')
tensor(7759, device='cuda:0')
tensor(2003, device='cuda:0')
tensor(4781, device='cuda:0')
tensor(4646, device='cuda:0')
tensor(8109, device='cuda:0')
tensor(2165, device='cuda:0')
tensor(10524, device='cuda:0')
tensor(14888, device='cuda:0')
tensor(15075, device='cuda:0')
tensor(778, device='cuda:0')
tensor(15191, device='cuda:0')
tensor(2707, device='cuda:0')
tensor(10585, device='cuda:0')
tensor(2121, device='cuda:0')
tensor(9261, device='cuda:0')
tensor(4978, device='cuda:0')
tensor(8262, device='cuda:0')
tensor(1083, device='cuda:0')
tensor(571, device='cuda:0')
tensor(9188, device='cuda:0')
tensor(9561, device='cuda:0')
tensor(7876, device='cuda:0')
tensor(14993, device='cuda:0')
tensor(2288, device='cuda:0')
tensor(11746, device='cuda:0')
tensor(5864, device='cuda:0')
tensor(11973, device='cuda:0')
tensor(498, device='cuda:0')
tensor(4788, device='cuda:0')
tensor(5821, device='cuda:0')
tensor(10971, device='cuda:0')
tensor(10825, device='cuda:0')
ten

In [39]:
# Position with the largest softmax weight for each gathered feature.
max_indices

tensor([13, 13, 13,  0,  9, 13, 16, 13,  9,  3, 13, 13, 13,  0,  4,  3, 13,  0,
        16,  0], device='cuda:0')

In [31]:
# Activation of layer-8 SAE feature 15191 at every position of batch element 0.
cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, :, 15191]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 6.1795, 0.0000, 0.0000, 7.0834],
       device='cuda:0')

In [33]:
# List the string tokens of `pr` with their positions, to map position indices to tokens.
for i, tk in enumerate(model.to_str_tokens(pr)):
    print(i, tk)

0 <bos>
1 What
2  is
3  the
4  output
5  of
6  
7 5
8 3
9  plus
10  
11 3
12 2
13  ?
14  A
15 :
16  


In [None]:
# Token positions taken from the token listing above.
operator = 9  # the " plus" token
operand = [7, 8, 11, 12]  # the four single-digit tokens of the two operands

In [63]:
import random

def generate_example_pair(low=10, high=99):
    """Build one (clean, corrupted) prompt pair that differs only in the operator word.

    Both prompts use the same two random operands; the clean prompt joins them
    with "plus" and the corrupted prompt with "and".

    Args:
        low: Smallest operand value, inclusive. Defaults to 10.
        high: Largest operand value, inclusive. Defaults to 99.

    Returns:
        tuple[str, str]: ``(clean_example, corrupted_example)``.
    """
    # With the defaults both operands are two-digit numbers. (The previous
    # comment said "between 1 and 3 digits", which never matched randint(10, 99).)
    # NOTE(review): other cells index fixed token positions (e.g. 13), which
    # presumably relies on all prompts tokenising to the same length - keep the
    # digit count constant when changing the bounds.
    num1 = random.randint(low, high)
    num2 = random.randint(low, high)

    # Identical text except for the operator word.
    clean_example = f'What is the output of {num1} plus {num2} ? '
    corrupted_example = f'What is the output of {num1} and {num2} ? '

    return clean_example, corrupted_example

def generate_dataset(N):
    """Return a list of ``N`` freshly generated ``(clean, corrupted)`` prompt pairs."""
    generated = (generate_example_pair() for _ in range(N))
    return [(clean, corrupted) for clean, corrupted in generated]

# Example usage
# NOTE(review): no random seed is set, so the dataset differs on every run.
N = 10  # Number of pairs to generate
dataset = generate_dataset(N)

# Print the dataset
# NOTE: the loop variables `clean` and `corrupted` stay defined after the loop
# (holding the last pair) and later cells use them without assigning them first.
for i, (clean, corrupted) in enumerate(dataset):
    print(f"Pair {i+1}:")
    print(f"  Clean:     {clean}")
    print(f"  Corrupted: {corrupted}")
    print()

Pair 1:
  Clean:     What is the output of 96 plus 43 ? 
  Corrupted: What is the output of 96 and 43 ? 

Pair 2:
  Clean:     What is the output of 56 plus 77 ? 
  Corrupted: What is the output of 56 and 77 ? 

Pair 3:
  Clean:     What is the output of 93 plus 56 ? 
  Corrupted: What is the output of 93 and 56 ? 

Pair 4:
  Clean:     What is the output of 29 plus 37 ? 
  Corrupted: What is the output of 29 and 37 ? 

Pair 5:
  Clean:     What is the output of 29 plus 75 ? 
  Corrupted: What is the output of 29 and 75 ? 

Pair 6:
  Clean:     What is the output of 48 plus 18 ? 
  Corrupted: What is the output of 48 and 18 ? 

Pair 7:
  Clean:     What is the output of 91 plus 12 ? 
  Corrupted: What is the output of 91 and 12 ? 

Pair 8:
  Clean:     What is the output of 33 plus 76 ? 
  Corrupted: What is the output of 33 and 76 ? 

Pair 9:
  Clean:     What is the output of 53 plus 83 ? 
  Corrupted: What is the output of 53 and 83 ? 

Pair 10:
  Clean:     What is the output of 39

In [12]:
# Spot-check the SAE id that was selected for layer 1.
closest_strings[1]

'layer_1/width_16k/average_l0_102'

In [13]:
# Load the selected residual-stream SAE for each of layers 0..8 (inclusive).
saes = []
for layer in range(9):
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res", 
        sae_id=closest_strings[layer],  
        device=device
    )
    # Ask the SAE to add its error term to its output.
    sae.use_error_term = True
    saes.append(sae)
# After the loop `sae`, `cfg_dict` and `sparsity` refer to the layer-8 SAE;
# later cells read them.

In [93]:
# Config dict returned by the most recent SAE.from_pretrained call.
cfg_dict

{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.8.hook_resid_post',
 'hook_layer': 8,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': 'cuda'}

In [31]:
# Run `prompt` with whichever SAE `sae` currently refers to (set by an earlier cell).
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

# List every cache entry created by the SAE together with its shape.
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# Activation of layer-8 feature 15191 at every position of batch element 0.
print(cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, :, 15191])

[('blocks.8.hook_resid_post.hook_sae_input', torch.Size([1, 17, 2304])), ('blocks.8.hook_resid_post.hook_sae_acts_pre', torch.Size([1, 17, 16384])), ('blocks.8.hook_resid_post.hook_sae_acts_post', torch.Size([1, 17, 16384])), ('blocks.8.hook_resid_post.hook_sae_recons', torch.Size([1, 17, 2304])), ('blocks.8.hook_resid_post.hook_sae_error', torch.Size([1, 17, 2304])), ('blocks.8.hook_resid_post.hook_sae_output', torch.Size([1, 17, 2304]))]
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 6.1795, 0.0000, 0.0000, 7.0834],
       device='cuda:0')


In [64]:
# For every pair, print layer-8 feature 15191 at position 13 for the clean
# prompt and for the corrupted prompt.
# NOTE: `clean`, `corrupted`, `clean_cache` and `corr_cache` keep the values of
# the last iteration after the loop finishes.
for i, (clean, corrupted) in enumerate(dataset):
    _, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
    print(f"Pair {i+1}:")
    print(f"  Clean:     {clean}")
    print(clean_cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, :, 15191][13])
    _, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)
    print(f"  Corrupted: {corrupted}")
    print(corr_cache['blocks.8.hook_resid_post.hook_sae_acts_post'][0, :, 15191][13])

Pair 1:
  Clean:     What is the output of 96 plus 43 ? 
tensor(7.0159, device='cuda:0')
  Corrupted: What is the output of 96 and 43 ? 
tensor(0., device='cuda:0')
Pair 2:
  Clean:     What is the output of 56 plus 77 ? 
tensor(5.5384, device='cuda:0')
  Corrupted: What is the output of 56 and 77 ? 
tensor(0., device='cuda:0')
Pair 3:
  Clean:     What is the output of 93 plus 56 ? 
tensor(6.7896, device='cuda:0')
  Corrupted: What is the output of 93 and 56 ? 
tensor(0., device='cuda:0')
Pair 4:
  Clean:     What is the output of 29 plus 37 ? 
tensor(5.6018, device='cuda:0')
  Corrupted: What is the output of 29 and 37 ? 
tensor(0., device='cuda:0')
Pair 5:
  Clean:     What is the output of 29 plus 75 ? 
tensor(6.8103, device='cuda:0')
  Corrupted: What is the output of 29 and 75 ? 
tensor(0., device='cuda:0')
Pair 6:
  Clean:     What is the output of 48 plus 18 ? 
tensor(8.9317, device='cuda:0')
  Corrupted: What is the output of 48 and 18 ? 
tensor(2.8513, device='cuda:0')
Pair 7

In [78]:
# Select, for layers 0..8, the SAE features that look tied to the operator token.
# Uses `clean` as left behind by an earlier loop over `dataset` (hidden state).
ind = 9  # token position of the operator word
operator_features = {}
_, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
for layer in range(9):
    operator_features[layer] = []
    x = clean_cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post']
    # The 50 features with the largest activation at the operator position.
    topk_values, topk_indices = torch.topk(x[0, ind, :], 50)
    # Their activations at every position of batch element 0.
    gathered_values = x[0, :, topk_indices]
    # Softmax across positions (dim 0), separately for each gathered feature.
    softmaxed_x = F.softmax(gathered_values, dim=0)
    # The three strongest positions for each feature.
    top3_indices_after_softmax = torch.topk(softmaxed_x, 3, dim=0).indices
    # Keep the feature if the operator position is one of those three.
    for i in range(topk_indices.size(0)):
        if ind in top3_indices_after_softmax[:, i]:
            operator_features[layer].append(topk_indices[i])

In [79]:
# Total number of selected features, summed over all layers.
total_features_under_comp = sum(len(selected) for selected in operator_features.values())
total_features_under_comp

385

In [80]:
# Print each layer index followed by its list of selected feature indices.
for layer_idx in operator_features:
    print(layer_idx, operator_features[layer_idx])

0 [tensor(14495, device='cuda:0'), tensor(9416, device='cuda:0'), tensor(60, device='cuda:0'), tensor(5039, device='cuda:0'), tensor(12108, device='cuda:0'), tensor(6179, device='cuda:0'), tensor(13928, device='cuda:0'), tensor(15396, device='cuda:0'), tensor(9361, device='cuda:0'), tensor(283, device='cuda:0'), tensor(8441, device='cuda:0'), tensor(5043, device='cuda:0'), tensor(3977, device='cuda:0'), tensor(10297, device='cuda:0'), tensor(1717, device='cuda:0'), tensor(15325, device='cuda:0'), tensor(2337, device='cuda:0'), tensor(4470, device='cuda:0'), tensor(7782, device='cuda:0'), tensor(9091, device='cuda:0'), tensor(9069, device='cuda:0'), tensor(750, device='cuda:0'), tensor(9908, device='cuda:0'), tensor(4275, device='cuda:0'), tensor(15871, device='cuda:0'), tensor(15842, device='cuda:0'), tensor(6772, device='cuda:0'), tensor(15706, device='cuda:0'), tensor(10646, device='cuda:0'), tensor(5874, device='cuda:0'), tensor(13170, device='cuda:0'), tensor(13651, device='cuda:0'

In [73]:
# Metric definition



In [41]:
def patch_activations(corr_activations, clean_cache, layer_name, feature_index):
    """Return a copy of ``corr_activations`` with one feature taken from the clean run.

    Args:
        corr_activations: Activation tensor of the corrupted run; left unmodified.
        clean_cache: Mapping from layer name to the clean run's activation tensor.
        layer_name: Key of the activation to read from ``clean_cache``.
        feature_index: Index along the last dimension that is replaced.

    Returns:
        A new tensor equal to ``corr_activations`` except at ``[..., feature_index]``,
        which holds the clean values.
    """
    clean_slice = clean_cache[layer_name][..., feature_index]
    # Work on a copy so the caller's tensor is never modified.
    result = corr_activations.clone()
    result[..., feature_index] = clean_slice
    return result

def forward_hook(module, input, output, clean_cache, layer_name, feature_index):
    """Forward-hook adapter: return ``output`` with one feature patched from the clean run.

    ``module`` and ``input`` are accepted to fit the forward-hook call signature
    but are not used.
    """
    return patch_activations(
        corr_activations=output,
        clean_cache=clean_cache,
        layer_name=layer_name,
        feature_index=feature_index,
    )


In [74]:
# Baseline runs of the current `clean` / `corrupted` prompts with all loaded SAEs attached:
# `op_*` are the model outputs, `*_cache` the activation caches.
op_clean, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
op_corr, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)

In [75]:
# Clean-minus-corrupted value of the metric feature.
# NOTE: `feature_difference` is defined in a cell further down (hidden-state dependency).
feature_difference(clean_cache, corr_cache)

tensor(7.5431, device='cuda:0')

In [67]:
# NOTE(review): `feature_difference` is called with a single cache here, while the
# definition further down in this notebook takes two. This cell presumably ran
# against an earlier one-argument version and would raise TypeError against the
# current one - confirm and update.
print(op_clean)
print(feature_difference(clean_cache))

print(op_corr)
print(feature_difference(corr_cache))

tensor([[[-24.3121,  -8.7513,  -6.9737,  ..., -18.3960, -17.4268, -24.3171],
         [-17.6520,  -1.1834, -13.3353,  ..., -11.1017,  -4.4314, -17.6057],
         [-15.1261,  -1.6613,  -1.7596,  ...,  -9.2118, -12.0350, -14.9411],
         ...,
         [-13.9915,  18.2728,   2.8354,  ...,   0.2403,   3.8752, -13.9369],
         [-10.2113,  19.6030,   4.2905,  ...,  -1.8026,   5.7783, -10.1475],
         [ -6.0919,  13.5185,   9.8800,  ...,   4.7390,  11.3725,  -6.1564]]],
       device='cuda:0')
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 7.5431, 2.4472], device='cuda:0')
tensor([[[-24.3121,  -8.7513,  -6.9737,  ..., -18.3960, -17.4268, -24.3171],
         [-17.6520,  -1.1834, -13.3353,  ..., -11.1017,  -4.4314, -17.6057],
         [-15.1261,  -1.6613,  -1.7596,  ...,  -9.2118, -12.0350, -14.9411],
         ...,
         [-11.1365,  16.4926,   6.8349,  ...,   2.2972,   5.6023, -11.1145],
         [ -7.3857,  2

In [69]:
# Next-token distribution of the corrupted run: softmax over the vocabulary for
# batch element 0, then take the last position.
probs = op_corr[0].softmax(dim=-1)
token_probs = probs[-1]
# sort() returns (values, indices); `sorted_token_values` therefore holds the
# token ids in order of decreasing probability.
sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
# Janky way to get the index of the token in the sorted list - I couldn't find a better way?
for i in range(10):
    # sorted_token_values
    # Print the 10 most probable next tokens as strings.
    print(model.to_string(sorted_token_values[i]))

3
1



2
4
5
<strong>


6
8


In [99]:
# Reload the selected SAE for each of layers 0..8 (inclusive).
# NOTE(review): this cell is identical to the earlier "LOAD till 8" cell; re-running
# it replaces `saes` with freshly loaded SAE objects.
saes = []
for layer in range(9):
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res", 
        sae_id=closest_strings[layer],  
        device=device
    )
    # Ask the SAE to add its error term to its output.
    sae.use_error_term = True
    saes.append(sae)

In [126]:
def get_deep_attr(obj: Any, path: str):
    """Resolve a dotted path such as ``"blocks.0.attn.hook_z"`` starting from ``obj``.

    Purely numeric components are used as integer indices (``current[int(part)]``);
    every other component is looked up with ``getattr``.

    Args:
        obj: Object to start from, e.g. a HookedTransformer or an SAE.
        path: Dot-separated attribute/index path.

    Returns:
        Whatever object the final path component resolves to.
    """
    current = obj
    for component in path.split("."):
        current = current[int(component)] if component.isdigit() else getattr(current, component)
    return current

def replace_with_zeros(layer, feature_ind, clean_cache):
    """Create a forward hook that zero-ablates one feature.

    Args:
        layer: Unused; kept so the signature matches ``replace_with_clean_activations``.
        feature_ind: Index along the last dimension of the hooked output to zero.
        clean_cache: Unused; kept for signature compatibility.

    Returns:
        A ``hook_fn(module, input, output)`` suitable for ``register_forward_hook``.
    """
    def hook_fn(module, input, output):
        # In-place edit: zero the feature at every position of batch element 0
        # only; any other batch elements are left untouched.
        output[0, :, feature_ind] = 0
        return output

    return hook_fn

# Define the hook function
def replace_with_clean_activations(layer, feature_ind, clean_cache):
    """Create a forward hook that overwrites one feature with values from a cache.

    The values are read from
    ``clean_cache['blocks.{layer}.hook_resid_post.hook_sae_acts_post']``. Whether
    that cache came from the clean or the corrupted run is up to the caller.

    Args:
        layer: Layer number used to build the cache key.
        feature_ind: Index along the last dimension to overwrite.
        clean_cache: Activation cache that supplies the replacement values.

    Returns:
        A ``hook_fn(module, input, output)`` suitable for ``register_forward_hook``.
    """
    cache_key = f'blocks.{layer}.hook_resid_post.hook_sae_acts_post'

    def hook_fn(module, input, output):
        # In-place edit of batch element 0 only: copy the cached activations of
        # this feature over every position of the hooked output.
        output[0, :, feature_ind] = clean_cache[cache_key][0, :, feature_ind]
        return output

    return hook_fn

def subtract_and_mean(tensor1, tensor2):
    """Return the mean of ``tensor1 - tensor2``.

    Args:
        tensor1 (torch.Tensor): Minuend.
        tensor2 (torch.Tensor): Subtrahend; must have the same shape as ``tensor1``.

    Returns:
        torch.Tensor: 0-dim tensor holding the mean element-wise difference.

    Raises:
        AssertionError: If the two shapes differ.
    """
    # Refuse mismatched shapes instead of silently broadcasting.
    assert tensor1.shape == tensor2.shape, "Tensors must have the same shape"
    return (tensor1 - tensor2).mean()



# Defaults of the metric used throughout this notebook: SAE feature 15191 of the
# layer-8 residual-stream SAE, read at token position 13.
METRIC_LAYER = 8
METRIC_FEATURE = 15191
METRIC_POSITION = 13


def _sae_feature_acts(cache, layer, feature_idx):
    """Return one SAE feature's activations at every position of batch element 0."""
    return cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post'][0, :, feature_idx]


def feature_acts(cache, layer=METRIC_LAYER, feature_idx=METRIC_FEATURE, position=METRIC_POSITION):
    """Return the activation of one SAE feature at a single token position.

    Args:
        cache: Activation cache indexed by hook name.
        layer: Layer whose ``hook_resid_post`` SAE is read.
        feature_idx: Feature index inside that SAE.
        position: Token position to read.

    Returns:
        0-dim tensor with the activation.
    """
    return _sae_feature_acts(cache, layer, feature_idx)[position]


def overall_feature(cache, layer=METRIC_LAYER, feature_idx=METRIC_FEATURE):
    """Return the activation of one SAE feature at every position (1-D tensor)."""
    return _sae_feature_acts(cache, layer, feature_idx)


def feature_difference(cache1, cache2, layer=METRIC_LAYER, feature_idx=METRIC_FEATURE, position=METRIC_POSITION):
    """Return ``feature_acts(cache1) - feature_acts(cache2)`` for the same feature and position."""
    first = feature_acts(cache1, layer, feature_idx, position)
    second = feature_acts(cache2, layer, feature_idx, position)
    return first - second


def overall_feature_difference(cache1, cache2, layer=METRIC_LAYER, feature_idx=METRIC_FEATURE):
    """Return the mean over positions of the feature's activation in ``cache1`` minus ``cache2``.

    Delegates to ``subtract_and_mean``, so both caches must cover the same
    number of positions.
    """
    return subtract_and_mean(
        overall_feature(cache1, layer, feature_idx),
        overall_feature(cache2, layer, feature_idx),
    )
# # Specify the layer number and feature index
# layer = 2  # Example: replace activations at layer 2
# feature_ind = 5  # Example: replace activations at feature index 5




# Baseline (unpatched) runs of the current `clean` / `corrupted` prompts.
op_clean, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
op_corr, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)

print("og logit", op_clean.shape)
print("Original feature diff: ", feature_difference(clean_cache, corr_cache))

# Patch clean activations into the corrupted run, one feature at a time, and
# report whenever the metric feature changes.
# NOTE(review): this brute-forces feature indices 0..9999 for every layer and
# ignores the pre-selected lists in `val`; that is one forward pass per
# (layer, feature) and is very slow.
for key, val in operator_features.items():
    layer = key
    # The hook point lives on the SAE object and is the same for every feature of this layer.
    hook_point = get_deep_attr(saes[layer], 'hook_sae_acts_post')
    for ind in range(10000):
        feature_ind = ind
        hook = replace_with_clean_activations(layer, feature_ind, clean_cache)
        hook_handle = hook_point.register_forward_hook(hook)
        try:
            # Run the corrupted prompt with the patch hook active.
            inter_op, corr_cache_intervened = model.run_with_cache_with_saes(corrupted, saes=saes)
        finally:
            # Always detach the hook, even when the run is interrupted or fails;
            # otherwise it stays on the SAE and silently patches every later run.
            hook_handle.remove()

        # NOTE(review): exact comparison with 0 also fires on float noise of
        # magnitude ~1e-7; consider a tolerance.
        if overall_feature_difference(corr_cache_intervened, corr_cache) != 0:
            print("UPPPPPPPP")


og logit torch.Size([1, 15, 256000])
Original feature diff:  tensor(7.5431, device='cuda:0')


KeyboardInterrupt: 

In [None]:
# Print the first selected feature index of each layer as a plain Python number.
for key, val in operator_features.items():
    layer = key
    for feature_ind in val:
        print(feature_ind.item())
        break  # only the first feature of each layer

In [127]:
# Corr -> clean: patch activations from the corrupted run into the clean run,
# one feature at a time, and measure the change in the metric feature.

_, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
_, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)

print("Original feature difference: ", feature_difference(clean_cache, corr_cache))

# One (layer, feature index, mean change of the metric feature) entry per patch tried.
differences = []

for key, val in operator_features.items():
    layer = key
    print("layer: ", layer)
    # The hook point lives on the SAE object and is the same for every feature of this layer.
    hook_point = get_deep_attr(saes[layer], 'hook_sae_acts_post')
    # NOTE(review): only feature indices 0..999 are tried; the pre-selected
    # features in `val` are ignored.
    for feature_ind in range(1000):
        # Source of the patch is the corrupted run's cache.
        hook = replace_with_clean_activations(layer, feature_ind, corr_cache)
        hook_handle = hook_point.register_forward_hook(hook)
        try:
            # Run the clean prompt with the patch hook active.
            _, clean_cache_intervened = model.run_with_cache_with_saes(clean, saes=saes)
        finally:
            # Always detach the hook, even on interrupt or error, so it cannot
            # leak into later runs.
            hook_handle.remove()

        # Mean over positions of (patched - unpatched) for the metric feature.
        diff = overall_feature_difference(clean_cache_intervened, clean_cache)

        # Store a plain float rather than keeping thousands of device tensors alive.
        differences.append((layer, feature_ind, diff.item()))

# Sort the differences to get the top 10.
# NOTE(review): the ranking is by signed value. Patching corrupted activations
# would be expected to lower the feature, so the strongest effects may be the most
# negative entries - consider ranking by abs(). The recorded values are all ~1e-7,
# i.e. float noise; confirm the hook actually changes downstream activations.
top_differences = sorted(differences, key=lambda x: x[2], reverse=True)[:10]

# Print the top 10 features with their differences
print("Top 10 features with the highest overall_feature_difference:")
for layer, feature_ind, diff in top_differences:
    print(f"Layer: {layer}, Feature Index: {feature_ind}, Overall Feature Difference: {diff}")


Original feature difference:  tensor(7.5431, device='cuda:0')
layer:  0
layer:  1
layer:  2
layer:  3
layer:  4
layer:  5
layer:  6
layer:  7
layer:  8
Top 10 features with the highest overall_feature_difference:
Layer: 1, Feature Index: 172, Overall Feature Difference: 5.880991693629767e-07
Layer: 1, Feature Index: 202, Overall Feature Difference: 5.722046125811175e-07
Layer: 5, Feature Index: 697, Overall Feature Difference: 5.722046125811175e-07
Layer: 5, Feature Index: 112, Overall Feature Difference: 5.404154990173993e-07
Layer: 0, Feature Index: 133, Overall Feature Difference: 5.245208853921213e-07
Layer: 2, Feature Index: 616, Overall Feature Difference: 5.245208853921213e-07
Layer: 1, Feature Index: 534, Overall Feature Difference: 5.086263286102621e-07
Layer: 4, Feature Index: 646, Overall Feature Difference: 5.086263286102621e-07
Layer: 2, Feature Index: 498, Overall Feature Difference: 4.92731771828403e-07
Layer: 4, Feature Index: 920, Overall Feature Difference: 4.92731771

In [125]:
# Zero-ablation: switch off each selected operator feature in the clean run and
# report the change in the metric feature.

op_clean, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
op_corr, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)

print("og logit", op_clean.shape)
print("Original feature diff: ", feature_difference(clean_cache, corr_cache))

for key, val in operator_features.items():
    layer = key
    # The hook point lives on the SAE object and is the same for every feature of this layer.
    hook_point = get_deep_attr(saes[layer], 'hook_sae_acts_post')
    for feature_ind in val:
        # Selected indices are stored as 0-dim tensors; convert to a plain int.
        feature_ind = feature_ind.item()
        # The cache argument is ignored by replace_with_zeros.
        hook = replace_with_zeros(layer, feature_ind, corr_cache)
        hook_handle = hook_point.register_forward_hook(hook)
        try:
            # Run the clean prompt with the ablation hook active.
            inter_op, clean_cache_intervened = model.run_with_cache_with_saes(clean, saes=saes)
        finally:
            # Always detach the hook, even on interrupt or error, so it cannot
            # leak into later runs.
            hook_handle.remove()

        # Mean over positions of (ablated - unablated) for the metric feature.
        change = overall_feature_difference(clean_cache_intervened, clean_cache)
        # NOTE(review): exact comparison with 0 also fires on float noise; the
        # recorded outputs are all ~1e-7, so a tolerance is advisable.
        if change != 0:
            print("UPPPPPPPP")
            print(change)

og logit torch.Size([1, 15, 256000])
Original feature diff:  tensor(7.5431, device='cuda:0')
UPPPPPPPP
tensor(5.8810e-07, device='cuda:0')
UPPPPPPPP
tensor(2.3842e-07, device='cuda:0')
UPPPPPPPP
tensor(4.6094e-07, device='cuda:0')
UPPPPPPPP
tensor(6.8347e-07, device='cuda:0')
UPPPPPPPP
tensor(3.0200e-07, device='cuda:0')
UPPPPPPPP
tensor(-1.5895e-08, device='cuda:0')
UPPPPPPPP
tensor(1.5895e-07, device='cuda:0')
UPPPPPPPP
tensor(7.9473e-08, device='cuda:0')
UPPPPPPPP
tensor(2.5431e-07, device='cuda:0')
UPPPPPPPP
tensor(1.5895e-07, device='cuda:0')
UPPPPPPPP
tensor(-6.3578e-08, device='cuda:0')
UPPPPPPPP
tensor(1.1126e-07, device='cuda:0')
UPPPPPPPP
tensor(1.5895e-08, device='cuda:0')
UPPPPPPPP
tensor(1.7484e-07, device='cuda:0')
UPPPPPPPP
tensor(1.5895e-07, device='cuda:0')
UPPPPPPPP
tensor(6.1989e-07, device='cuda:0')
UPPPPPPPP
tensor(-4.7684e-08, device='cuda:0')
UPPPPPPPP
tensor(2.5431e-07, device='cuda:0')
UPPPPPPPP
tensor(3.8147e-07, device='cuda:0')
UPPPPPPPP
tensor(1.2716e-07, d

In [None]:
# Attach every loaded SAE to the model by hand, recording what was attached before.
# NOTE(review): nothing in this cell detaches the SAEs again, so they presumably
# stay attached for all later runs - confirm this is intended.
act_names_to_reset = []
prev_saes = []
for sae in saes:
    print(sae.cfg.hook_name)
    act_names_to_reset.append(sae.cfg.hook_name)
    # SAE currently registered on this hook name, or None if there is none.
    prev_sae = model.acts_to_saes.get(sae.cfg.hook_name, None)
    print(model.acts_to_saes)
    prev_saes.append(prev_sae)
    model.add_sae(sae, use_error_term=True)

In [None]:
# For every pair: cache the clean and corrupted runs, then re-run the corrupted
# prompt with one SAE feature patched in from the clean run.
# NOTE: `layer_name` and `feature_index` are defined in a cell further down;
# run that cell first (hidden-state dependency).
results = []

for i, (clean, corrupted) in enumerate(dataset):
    _, clean_cache = model.run_with_cache_with_saes(clean, saes=saes)
    _, corr_cache = model.run_with_cache_with_saes(corrupted, saes=saes)

    def patch_feature(acts, hook, source_cache=clean_cache):
        """TransformerLens-style hook: patch `feature_index` with this pair's clean activations."""
        return patch_activations(acts, source_cache, layer_name, feature_index)

    # `layer_name` is a hook point inside the SAE, so it only exists while the
    # SAEs are attached. run_with_hooks_with_saes attaches them, applies the
    # hook for this one forward pass, and cleans both up afterwards. It returns
    # the model output directly (no tuple to unpack).
    patched_output = model.run_with_hooks_with_saes(
        corrupted,
        saes=saes,
        fwd_hooks=[(layer_name, patch_feature)],
    )

    # NOTE(review): this metric is computed from the two unpatched caches, so it
    # does not reflect the patched run at all - confirm which metric was intended.
    metric_diff = feature_difference(clean_cache, corr_cache)
    results.append(metric_diff)


In [None]:
def patch_activations(corr_activations, clean_cache, layer_name, feature_index):
    """Return a copy of ``corr_activations`` with one feature taken from the clean run.

    NOTE: this repeats a definition from an earlier cell of the notebook; the
    later definition wins once both cells have run.

    Args:
        corr_activations: Activation tensor of the corrupted run; left unmodified.
        clean_cache: Mapping from layer name to the clean run's activation tensor.
        layer_name: Key of the activation to read from ``clean_cache``.
        feature_index: Index along the last dimension that is replaced.

    Returns:
        A new tensor equal to ``corr_activations`` except at ``[..., feature_index]``,
        which holds the clean values.
    """
    clean_slice = clean_cache[layer_name][..., feature_index]
    # Work on a copy so the caller's tensor is never modified.
    result = corr_activations.clone()
    result[..., feature_index] = clean_slice
    return result

def forward_hook(module, input, output, clean_cache, layer_name, feature_index):
    """Forward-hook adapter: return ``output`` with one feature patched from the clean run.

    ``module`` and ``input`` are accepted to fit the forward-hook call signature
    but are not used. NOTE: this repeats a definition from an earlier cell.
    """
    return patch_activations(
        corr_activations=output,
        clean_cache=clean_cache,
        layer_name=layer_name,
        feature_index=feature_index,
    )


# Example layer name and feature index
layer_name = 'blocks.8.hook_resid_post.hook_sae_acts_post'
feature_index = 15191


def _patch_feature_hook(acts, hook):
    """TransformerLens-style hook: patch `feature_index` of `acts` with the clean run's values."""
    return patch_activations(acts, clean_cache, layer_name, feature_index)


# Run the model on the corrupted input with the patch applied.
# `layer_name` is a hook point inside the SAE, so it only exists while the SAEs
# are attached. run_with_hooks_with_saes attaches them, applies the hook for
# this one forward pass, and removes both afterwards, so no manual hook handle
# is needed. It returns the model output directly (no tuple to unpack).
# Requires `clean_cache`, `corrupted` and `saes` from earlier cells.
patched_output = model.run_with_hooks_with_saes(
    corrupted,
    saes=saes,
    fwd_hooks=[(layer_name, _patch_feature_hook)],
)
