In [1]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import torch
from collections import defaultdict

# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer
os.environ["HF_TOKEN"] = "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"
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device = device)


from transformer_lens.utils import test_prompt

prompt = "What is the output of 53 plus 34 ? It is "
answer = '8'
# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)

Device: cuda




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer
Tokenized prompt: ['<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' plus', ' ', '3', '4', ' ?', ' It', ' is', ' ']
Tokenized answer: [' ', '8']


Top 0th token. Logit: 27.23 Prob: 54.42% Token: |8|
Top 1th token. Logit: 25.61 Prob: 10.80% Token: |5|
Top 2th token. Logit: 25.38 Prob:  8.58% Token: |1|
Top 3th token. Logit: 24.99 Prob:  5.78% Token: |2|
Top 4th token. Logit: 24.60 Prob:  3.92% Token: |3|
Top 5th token. Logit: 24.43 Prob:  3.30% Token: |<strong>|
Top 6th token. Logit: 24.13 Prob:  2.47% Token: |4|
Top 7th token. Logit: 24.03 Prob:  2.23% Token: |________________|
Top 8th token. Logit: 23.97 Prob:  2.09% Token: |6|
Top 9th token. Logit: 23.95 Prob:  2.06% Token: |7|


Top 0th token. Logit: 25.26 Prob: 26.73% Token: |8|
Top 1th token. Logit: 24.54 Prob: 13.10% Token: |1|
Top 2th token. Logit: 24.53 Prob: 12.87% Token: |5|
Top 3th token. Logit: 24.14 Prob:  8.76% Token: |2|
Top 4th token. Logit: 23.82 Prob:  6.35% Token: |<strong>|
Top 5th token. Logit: 23.81 Prob:  6.29% Token: |3|
Top 6th token. Logit: 23.60 Prob:  5.12% Token: |6|
Top 7th token. Logit: 23.47 Prob:  4.46% Token: |4|
Top 8th token. Logit: 23.44 Prob:  4.36% Token: |________________|
Top 9th token. Logit: 23.36 Prob:  4.01% Token: |7|


In [2]:
import re
from collections import defaultdict
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T
df.drop(columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"], inplace=True)
df[df['model']=='gemma-2-2b']
sae_keys = list(df.loc['gemma-scope-2b-pt-res']['saes_map'].keys())
# Dictionary to store the closest string for each layer
closest_strings = {}

# Regular expression to extract the layer number and l0 value
pattern = re.compile(r'layer_(\d+)/width_16k/average_l0_(\d+)')

# Organize strings by layer
layer_dict = defaultdict(list)

for s in sae_keys:
    match = pattern.search(s)
    if match:
        layer = int(match.group(1))
        l0_value = int(match.group(2))
        layer_dict[layer].append((s, l0_value))

# Find the string with l0 value closest to 100 for each layer
for layer, items in layer_dict.items():
    closest_string = min(items, key=lambda x: abs(x[1] - 100))
    closest_strings[layer] = closest_string[0]

# Output the closest string for each layer
for layer in sorted(closest_strings):
    print(f"Layer {layer}: {closest_strings[layer]}")


Layer 0: layer_0/width_16k/average_l0_105
Layer 1: layer_1/width_16k/average_l0_102
Layer 2: layer_2/width_16k/average_l0_141
Layer 3: layer_3/width_16k/average_l0_59
Layer 4: layer_4/width_16k/average_l0_124
Layer 5: layer_5/width_16k/average_l0_68
Layer 6: layer_6/width_16k/average_l0_70
Layer 7: layer_7/width_16k/average_l0_69
Layer 8: layer_8/width_16k/average_l0_71
Layer 9: layer_9/width_16k/average_l0_73
Layer 10: layer_10/width_16k/average_l0_77
Layer 11: layer_11/width_16k/average_l0_80
Layer 12: layer_12/width_16k/average_l0_82
Layer 13: layer_13/width_16k/average_l0_84
Layer 14: layer_14/width_16k/average_l0_84
Layer 15: layer_15/width_16k/average_l0_78
Layer 16: layer_16/width_16k/average_l0_78
Layer 17: layer_17/width_16k/average_l0_77
Layer 18: layer_18/width_16k/average_l0_74
Layer 19: layer_19/width_16k/average_l0_73
Layer 20: layer_20/width_16k/average_l0_71
Layer 21: layer_21/width_16k/average_l0_70
Layer 22: layer_22/width_16k/average_l0_72
Layer 23: layer_23/width_16

In [4]:
import random

def generate_number_variations(prompt, num_variations=5):
    variations = []
    for _ in range(num_variations):
        num1 = random.randint(10, 70)
        num2 = random.randint(10, 70)
        new_prompt = prompt.replace("53", str(num1)).replace("34", str(num2))
        tokens = model.to_str_tokens(new_prompt)
        # print(tokens)
        # num1_pos = tokens.index(str(num1)[0])
        # num2_pos = tokens.index(str(num2)[0])
        # operation_pos = tokens.index(' plus')
        # question_pos = tokens.index(' ?')
        variations.append(new_prompt)
    return variations

prompt = "What is the output of 53 plus 34 ? It is "
number_variations = generate_number_variations(prompt, 20)
number_variations
# 7, 8, 9, 11, 12, 13, -1 

['What is the output of 38 plus 33 ? It is ',
 'What is the output of 58 plus 19 ? It is ',
 'What is the output of 67 plus 44 ? It is ',
 'What is the output of 49 plus 25 ? It is ',
 'What is the output of 32 plus 61 ? It is ',
 'What is the output of 16 plus 68 ? It is ',
 'What is the output of 61 plus 54 ? It is ',
 'What is the output of 48 plus 39 ? It is ',
 'What is the output of 64 plus 40 ? It is ',
 'What is the output of 10 plus 14 ? It is ',
 'What is the output of 70 plus 41 ? It is ',
 'What is the output of 48 plus 61 ? It is ',
 'What is the output of 67 plus 56 ? It is ',
 'What is the output of 28 plus 21 ? It is ',
 'What is the output of 10 plus 56 ? It is ',
 'What is the output of 43 plus 46 ? It is ',
 'What is the output of 38 plus 62 ? It is ',
 'What is the output of 56 plus 42 ? It is ',
 'What is the output of 31 plus 36 ? It is ',
 'What is the output of 58 plus 50 ? It is ']

In [3]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # <- Release name 
    sae_id = "layer_8/width_16k/average_l0_71", # <- SAE id (not always a hook point!)
    device = device
)
pr = 'What is the output of 53 plus 32 ? A: '

# Data

In [4]:
import random

def generate_example_pair():
    # Generate two random numbers between 1 and 3 digits
    num1 = random.randint(10, 99)
    num2 = random.randint(10, 99)
    
    # Create clean and corrupted examples
    clean_example = f'What is the output of {num1} plus {num2} ? '
    corrupted_example = f'What is the output of {num1} and {num2} ? '
    
    return clean_example, corrupted_example

def generate_dataset(N):
    dataset = []
    for _ in range(N):
        clean, corrupted = generate_example_pair()
        dataset.append((clean, corrupted))
    return dataset

# Example usage
N = 100  # Number of pairs to generate
dataset = generate_dataset(N)

# Print the dataset
for i, (clean, corrupted) in enumerate(dataset):
    print(f"Pair {i+1}:")
    print(f"  Clean:     {clean}")
    print(f"  Corrupted: {corrupted}")
    print()
    if i>10:
        break

Pair 1:
  Clean:     What is the output of 54 plus 67 ? 
  Corrupted: What is the output of 54 and 67 ? 

Pair 2:
  Clean:     What is the output of 54 plus 84 ? 
  Corrupted: What is the output of 54 and 84 ? 

Pair 3:
  Clean:     What is the output of 75 plus 20 ? 
  Corrupted: What is the output of 75 and 20 ? 

Pair 4:
  Clean:     What is the output of 26 plus 80 ? 
  Corrupted: What is the output of 26 and 80 ? 

Pair 5:
  Clean:     What is the output of 85 plus 29 ? 
  Corrupted: What is the output of 85 and 29 ? 

Pair 6:
  Clean:     What is the output of 81 plus 33 ? 
  Corrupted: What is the output of 81 and 33 ? 

Pair 7:
  Clean:     What is the output of 87 plus 85 ? 
  Corrupted: What is the output of 87 and 85 ? 

Pair 8:
  Clean:     What is the output of 92 plus 41 ? 
  Corrupted: What is the output of 92 and 41 ? 

Pair 9:
  Clean:     What is the output of 11 plus 95 ? 
  Corrupted: What is the output of 11 and 95 ? 

Pair 10:
  Clean:     What is the output of 54

In [5]:
sae.cfg.hook_name

'blocks.8.hook_resid_post'

In [6]:
clean_pr = []
corr_pr = []
for i, (clean, corrupted) in enumerate(dataset):
    clean_pr.append(clean)
    corr_pr.append(corrupted)

# Code to run a feature from normal cache 

In [25]:
def run_model_till_feature(prompt):
    _, cache = model.run_with_cache(
            prompt, 
            stop_at_layer=sae.cfg.hook_layer + 1, 
            names_filter=[sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()
    return feature_acts[:, 15191][-2:].sum()

In [28]:
for i, (clean, corrupted) in enumerate(dataset):
    print("Clean: ", run_model_till_feature(clean))
    print("Corrupted: ", run_model_till_feature(corrupted))

Clean:  tensor(5.4081, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(6.6722, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(9.2776, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(6.1121, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(8.4368, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(5.7264, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(6.8089, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(9.0535, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(9.7028, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(10.7939, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(8.1387, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(6.3161, device='cuda:0')
Corrupted:  tensor(0., device='cuda:0')
Clean:  tensor(6.7129, device='cuda:0')

In [7]:
sae.cfg.hook_name

'blocks.8.hook_resid_post'

In [14]:
_, cache = model.run_with_cache(
            number_variations[0:5], 
            stop_at_layer=sae.cfg.hook_layer + 1, 
            names_filter=[sae.cfg.hook_name])

In [16]:
cache[sae.cfg.hook_name].shape

torch.Size([5, 17, 2304])

In [18]:
sae_in = cache[sae.cfg.hook_name]
feature_acts = sae.encode(sae_in).squeeze()
feature_acts = feature_acts.flatten(0, 1)
feature_acts.shape

torch.Size([85, 16384])

In [20]:
feature_acts[:, 15191].max().item()

12.202563285827637

In [11]:
sae_in = cache[sae.cfg.hook_name]
feature_acts = sae.encode(sae_in).squeeze()

In [None]:
'<bos>', 'What', ' is', ' the', ' output', ' of', ' ', '5', '3', ' plus', ' ', '3', '4', ' ?', ' It', ' is', ' '

In [13]:
feature_acts[:, 15191] #.shape

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  2.4291,  0.0000,  0.0000,  7.2193,  0.0000,  2.6175,
        11.4568], device='cuda:0')

In [6]:
sae.W_dec[15191]#.to(model.cfg.device)

tensor([ 0.0349,  0.0009,  0.0382,  ..., -0.0309, -0.0299,  0.0056],
       device='cuda:0', requires_grad=True)

# Logit diff replaced by feature activation 

In [10]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.HookedTransformer import HookedTransformer

from __future__ import annotations

import itertools
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union, overload

import einops
import pandas as pd
import torch
from jaxtyping import Float, Int
from tqdm.auto import tqdm
from typing_extensions import Literal

import types
from transformer_lens.utils import Slice, SliceInput

import functools

In [11]:
def run_with_cache_with_extra_hook(
    self,
    *model_args: Any,
    current_activation_name: str,
    current_hook: Any,
    names_filter: NamesFilter = None,
    device: DeviceType = None,
    remove_batch_dim: bool = False,
    incl_bwd: bool = False,
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    pos_slice: Optional[Union[Slice, SliceInput]] = None,
    **model_kwargs: Any,
):
    """
    Runs the model and returns the model output and a Cache object.
    
    Adds an extra forward hook (current_activation_name, current_hook) to the hooks.

    Args:
        *model_args: Positional arguments for the model.
        current_activation_name: The name of the activation to hook.
        current_hook: The hook function to use.
        names_filter (NamesFilter, optional): A filter for which activations to cache.
        device (str or torch.Device, optional): The device to cache activations on.
        remove_batch_dim (bool, optional): If True, removes the batch dimension when caching.
        incl_bwd (bool, optional): If True, caches gradients as well.
        reset_hooks_end (bool, optional): If True, removes all hooks added by this function.
        clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
        pos_slice: The slice to apply to the cache output. Defaults to None.
        **model_kwargs: Keyword arguments for the model.

    Returns:
        tuple: A tuple containing the model output and a Cache object.
    """

    pos_slice = Slice.unwrap(pos_slice)

    # Get the caching hooks
    cache_dict, fwd, bwd = self.get_caching_hooks(
        names_filter,
        incl_bwd,
        device,
        remove_batch_dim=remove_batch_dim,
        pos_slice=pos_slice,
    )

    # Add the extra forward hook
    fwd_hooks = [(current_activation_name, current_hook)] + fwd

    # Run the model with the hooks
    with self.hooks(
        fwd_hooks=fwd_hooks,
        bwd_hooks=bwd,
        reset_hooks_end=reset_hooks_end,
        clear_contexts=clear_contexts,
    ):
        model_out = self(*model_args, **model_kwargs)
        if incl_bwd:
            model_out.backward()

    return model_out, cache_dict


In [12]:

# Attach the new method to the model instance
model.run_with_cache_with_extra_hook = types.MethodType(run_with_cache_with_extra_hook, model)


In [35]:


def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching, will be specialised to specific use cases.

    Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is have two inputs, clean and corrupted, which have two different outputs, and differ in some key detail. Eg "The Eiffel Tower is in" vs "The Colosseum is in". Then to take a cached set of activations from the "clean" run, and a set of corrupted.

    Internally, the key function comes from three things: A list of tuples of indices (eg (layer, position, head_index)), a index_to_act_name function which identifies the right activation for each index, a patch_setter function which takes the corrupted activation, the index and the clean cache, and a metric for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.

    This function then iterates over every tuple of indices, does the relevant patch, and stores it

    Args:
        model: The relevant model
        corrupted_tokens: The input tokens for the corrupted run
        clean_cache: The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        patch_setter: A function which acts on (corrupted_activation, index, clean_cache) to edit the activation and patch in the relevant chunk of the clean activation
        activation_name: The name of the activation being patched
        index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
        index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
        return_index_df: A Boolean flag for whether to return the dataframe of indices too

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index dimension, via index_df set explicitly it is flattened with one element per row.
        index_df *optional*: The dataframe of indices
    """

    if index_df is None:
        assert index_axis_names is not None

        # Get the max range for all possible axes
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": corrupted_tokens.shape[-1],
            "head_index": model.cfg.n_heads,
        }
        max_axis_range["src_pos"] = max_axis_range["pos"]
        max_axis_range["dest_pos"] = max_axis_range["pos"]
        max_axis_range["head"] = max_axis_range["head_index"]

        # Get the max range for each axis we iterate over
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # Get the dataframe where each row is a tuple of indices
        index_df = transformer_lens.patching.make_df_from_ranges(index_axis_max_range, index_axis_names)

        flattened_output = False
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
        assert index_axis_names is None
        index_axis_max_range = index_df.max().to_list()

        flattened_output = True

    # Create an empty tensor to show the patched metric for each patch
    if flattened_output:
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
    else:
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)

    # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
    def patching_hook(corrupted_activation, hook, index, clean_activation):
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every list of indices, and make the appropriate patch!
    for c, index_row in enumerate(tqdm((list(index_df.iterrows())))):
        index = index_row[1].to_list()

        # The current activation name is just the activation name plus the layer (assumed to be the first element of the input)
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )
        
#         incl_bwd = False
#         cache_dict, fwd, bwd = model.get_caching_hooks(
#             incl_bwd=incl_bwd,
#             device=device,
#             names_filter=None
#         )
        
#         fwd_hooks = [(current_activation_name, current_hook)] + fwd
        # Run the model with the patching hook and get the logits!
        # patched_logits, patched_cache = "", ""
        
        patched_logits, patched_cache = model.run_with_cache_with_extra_hook(
            corrupted_tokens, 
            current_activation_name=current_activation_name, 
            current_hook= current_hook
        )
        # print(patched_cache.keys())
        # print(patched_logits.shape)

        # Calculate the patching metric and store
        if flattened_output:
            patched_metric_output[c] = patching_metric(patched_cache).item()
        else:
            patched_metric_output[tuple(index)] = patching_metric(patched_cache).item()

    if return_index_df:
        return patched_metric_output, index_df
    else:
        return patched_metric_output

def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Applies the activation patch where index = [layer, pos]

    Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of everything that is not an attention pattern shaped tensor.
    """
    assert len(index) == 2
    layer, pos = index
    corrupted_activation[:, pos, ...] = clean_activation[:, pos, ...]
    return corrupted_activation
    
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)

In [36]:
def equal_feature_metric(cache):
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in)
    # print(feature_acts.shape)
    feature_acts = feature_acts.squeeze()
    return feature_acts[:, :, 15191][-2:].sum()

In [15]:
clean_tokens = model.to_tokens(clean_pr)
corrupted_tokens = model.to_tokens(corr_pr)

In [16]:
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [None]:
resid_pre_act_patch_results = get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, equal_feature_metric)


In [29]:
resid_pre_act_patch_results

tensor([[ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 23.5642,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 23.5370,  2.7963,  2.8199,  2.8341,  2.7722,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 24.1331,  2.8042,  2.8268,  5.2053,  2.6619,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 23.6086,  2.9299,  5.2871,  3.0597,  2.6450,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 20.0806,  5.5723,  5.6269,  3.1798,  2.6530,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,
          2.8504, 14.6763,  6.5523,  6.2608,  6.0508,  5.4315,  2.8504],
        [ 2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8504,  2.8

In [27]:
from neel_plotly import line, imshow, scatter

In [31]:
fig = imshow(
    resid_pre_act_patch_results, 
    yaxis="Layer", 
    xaxis="Position", 
    x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
    title="resid_pre Activation Patching",
    return_fig=True  # This ensures the figure object is returned
)

fig.write_image("resid_pre_activation_patching.png")


In [32]:
def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer,  head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
    """
    assert len(index) == 2
    layer, head_index = index
    corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]

    return corrupted_activation

get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)

In [None]:
attn_head_out_all_pos_act_patch_results = get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, equal_feature_metric)
fig = imshow(attn_head_out_all_pos_act_patch_results,  
       yaxis="Layer", 
       xaxis="Head", 
       title="attn_head_out Activation Patching (All Pos)", 
        return_fig=True)

fig.write_image("attn_head_out Activation Patching All Pos.png")

In [None]:
t

In [37]:
def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, pos, head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
    """
    assert len(index) == 3
    layer, pos, head_index = index
    corrupted_activation[:, pos, head_index] = clean_activation[:, pos, head_index]
    return corrupted_activation

get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)

In [39]:
DO_SLOW_RUNS = True
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_out_act_patch_results = get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, equal_feature_metric)
    attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")
    fig = imshow(attn_head_out_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos", 
        return_fig=True)
    fig.write_image("attn_head_out_act_patch_results.png")

  0%|          | 0/3120 [00:00<?, ?it/s]

In [42]:
attn_head_out_act_patch_results.shape

torch.Size([208, 15])

In [48]:
# Assuming attn_head_out_act_patch_results is your tensor
sliced_results = attn_head_out_act_patch_results[:72, -7:]
# Adjust the y-axis labels for the first 72 elements
sliced_y_labels = ALL_HEAD_LABELS[:72]

# Adjust the x-axis labels for the last 7 positions
sliced_x_labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))][-7:]

fig = imshow(
    sliced_results, 
    yaxis="Head Label", 
    xaxis="Pos", 
    x=sliced_x_labels,
    y=sliced_y_labels,
    title="attn_head_out Activation Patching By Pos", 
    width=1000,  # Increase the width of the figure
    height=1200,  # Increase the height of the figure
    return_fig=True
)

# Optionally, you can adjust the tickfont size for better readability
fig.update_layout(
    yaxis=dict(tickfont=dict(size=10)),  # Adjust the size as needed
    xaxis=dict(tickfont=dict(size=10))   # Adjust the size as needed
)

# Save the figure
fig.write_image("attn_head_out_act_patch_results_sliced.png")


In [49]:
import torch

# Assuming sliced_results is your sliced tensor
mean_value = sliced_results.mean().item()
std_dev = sliced_results.std().item()

# Calculate the threshold for one standard deviation away from the mean
lower_threshold = mean_value - std_dev
upper_threshold = mean_value + std_dev

# Identify the indices where the values are one standard deviation away from the mean
indices = (sliced_results < lower_threshold) | (sliced_results > upper_threshold)
y_indices, x_indices = torch.where(indices)

# Extract the corresponding y labels, x labels, and values
tuples_list = [
    (sliced_y_labels[y_idx], sliced_x_labels[x_idx], sliced_results[y_idx, x_idx].item())
    for y_idx, x_idx in zip(y_indices, x_indices)
]

# Display the tuples
tuples_list


[('L2H1', '6 11', 5.30517578125),
 ('L2H5', ' plus 9', 6.3087568283081055),
 ('L3H2', ' plus 9', 5.351902008056641),
 ('L3H4', ' plus 9', 5.247307777404785),
 ('L3H6', '6 11', 5.364980697631836),
 ('L4H0', ' ? 13', 5.629123210906982),
 ('L4H1', '  10', 5.518120765686035),
 ('L4H1', '6 11', 5.378449440002441),
 ('L4H1', '7 12', 5.5230021476745605),
 ('L4H2', '7 12', 5.488124370574951),
 ('L4H6', ' ? 13', 5.429757118225098),
 ('L5H0', '7 12', 5.438570976257324),
 ('L5H1', '  10', 5.3226399421691895),
 ('L5H1', '6 11', 5.492075443267822),
 ('L5H1', '7 12', 5.903805732727051),
 ('L5H1', ' ? 13', 5.319910049438477),
 ('L5H4', ' ? 13', 5.545135498046875),
 ('L6H0', ' plus 9', 5.626616954803467),
 ('L6H0', ' ? 13', 0.0),
 ('L6H4', '7 12', 5.299982070922852),
 ('L6H5', '7 12', 5.293205261230469),
 ('L7H0', ' ? 13', 5.8349199295043945),
 ('L7H1', ' ? 13', 5.990413188934326),
 ('L7H3', ' ? 13', 6.422385215759277),
 ('L7H4', '7 12', 5.271982192993164),
 ('L7H7', ' ? 13', 5.491620063781738),
 ('L8

In [63]:
# Function to convert L2H1 format to (2, 1)
def convert_to_tuple(layer_head_str):
    layer = int(layer_head_str[1])
    head = int(layer_head_str[3])
    return (layer, head)

# Convert the first element of each tuple in the list
converted_tuples = [convert_to_tuple(item[0]) for item in tuples_list] #[(convert_to_tuple(item[0]), item[1], item[2]) for item in tuples_list]

# Display the result
converted_tuples

[(2, 1),
 (2, 5),
 (3, 2),
 (3, 4),
 (3, 6),
 (4, 0),
 (4, 1),
 (4, 1),
 (4, 1),
 (4, 2),
 (4, 6),
 (5, 0),
 (5, 1),
 (5, 1),
 (5, 1),
 (5, 1),
 (5, 4),
 (6, 0),
 (6, 0),
 (6, 4),
 (6, 5),
 (7, 0),
 (7, 1),
 (7, 3),
 (7, 4),
 (7, 7),
 (8, 3),
 (8, 5)]

In [51]:
_, clean_cache = model.run_with_cache(clean_tokens)

In [58]:
temp_att_pattern = clean_cache['blocks.2.attn.hook_pattern'][0, 1, :, :] #.shape  #.keys()

In [65]:
def save_relevant_attention_patterns(clean_cache, layer_head_tuples):
    for layer_ind, head_ind in layer_head_tuples:
        temp_att_pattern = clean_cache[f'blocks.{layer_ind}.attn.hook_pattern'][1, head_ind, :, :]
        attention_pattern = temp_att_pattern.detach().cpu().numpy()
        # Define the x and y labels, assuming they correspond to tokens
        tokens = model.to_str_tokens(clean_tokens[0])
        labels = [f"{tok} {i}" for i, tok in enumerate(tokens)]

        # Generate the heatmap
        fig = px.imshow(
            attention_pattern,
            labels=dict(x="Head Position", y="Head Position", color="Attention"),
            x=labels,
            y=labels,
            title=f"Attention Pattern in Layer {layer_ind}, Head {head_ind}",
            color_continuous_scale="Blues"
        )
        # Display the figure
        fig.write_image(f"equal_feature_operator_heads/L{layer_ind}H{head_ind}_atten_pattern.png")
        
save_relevant_attention_patterns(clean_cache, converted_tuples)

In [59]:
import plotly.express as px
import torch

# Assuming attention_pattern is a PyTorch tensor
attention_pattern = temp_att_pattern.detach().cpu().numpy()  # Convert to numpy for plotting

# Define the x and y labels, assuming they correspond to tokens
tokens = model.to_str_tokens(clean_tokens[0])
labels = [f"{tok} {i}" for i, tok in enumerate(tokens)]

# Generate the heatmap
fig = px.imshow(
    attention_pattern,
    labels=dict(x="Head Position", y="Head Position", color="Attention"),
    x=labels,
    y=labels,
    title="Attention Pattern in Block 2, Head 1",
    color_continuous_scale="Blues"
)

# Display the figure
fig.write_image("L2H1_atten_pattern.png")


In [43]:
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
