# Localizing the underlying circuit via activation patching

Before running this notebook, upgrade Jupyter Notebook:
pip install --upgrade jupyter

In [1]:
import random
from functools import partial
from IPython.display import clear_output

import numpy as np

from plotly_utils import imshow, line, scatter

import torch

from transformer_lens import HookedTransformer
from transformer_lens import utils, patching

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def compute_logit_diff_2(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each of the three letters of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, 3]` is returned, containing the
    logit difference at every iteration for every prompt in the batch.

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`; only the last three
      sequence positions are read.
    - `answer_tokens`: `Tensor[batch_size, 3]` of integer token ids (moved to
      the device of `logits` internally).
    - `average`: if True, return the scalar mean over batch and iterations.

    Returns:
    --------
    A 0-d tensor when `average=True`, otherwise a `Tensor[batch_size, 3]`.

    Notes:
    ------
    The correct letter is excluded from the "incorrect" pool by masking its
    logit to -inf before taking the maximum. If an answer token is not one of
    the capital-letter ids, nothing is masked and the maximum runs over all
    26 capital letters, instead of failing on a ragged reshape.
    """
    # Move answer_tokens onto the same device as logits
    answer_tokens = answer_tokens.to(logits.device)

    # Only the final three positions are scored
    final_logits = logits[:, -3:]

    # Logits of the correct answers (batch_size, 3). squeeze(-1) drops only the
    # trailing axis added for gather, so a batch of size 1 keeps its batch axis.
    correct_logits = final_logits.gather(-1, answer_tokens[..., None]).squeeze(-1)

    # Hard-coded ids 32..57 of the 26 capital-letter tokens, created on the
    # same device as logits.
    # NOTE(review): presumably the GPT-2 vocabulary ids of "A".."Z" - confirm
    # if a different tokenizer is ever used.
    capital_letters_tokens = torch.arange(32, 58, dtype=torch.long, device=logits.device)

    # Logits of every capital letter at the scored positions (batch_size, 3, 26)
    capital_logits = final_logits[..., capital_letters_tokens]

    # True where a capital letter is the correct answer for that position
    is_correct_letter = capital_letters_tokens == answer_tokens[..., None]

    # Largest logit among the incorrect capital letters (batch_size, 3)
    incorrect_logits = capital_logits.masked_fill(is_correct_letter, float("-inf")).max(-1).values

    # Return the mean (or the per-prompt, per-iteration differences)
    logit_diff = correct_logits - incorrect_logits
    return logit_diff.mean() if average else logit_diff

def topk_of_Nd_tensor(tensor, k):
    '''
    Find the positions of the `k` largest entries of a tensor of any rank.

    Plays the role of `tensor.topk(k).indices`, except that the search runs
    over all elements of the flattened tensor instead of a single dimension.
    The result is a plain Python list with `k` entries, each of them a list
    of `tensor.ndim` integer coordinates, in the order `torch.topk` returns
    them.

    Example: given a 2D tensor holding one value per head in each layer, the
    result is a list of `[layer, head]` pairs.
    '''
    flat_positions = tensor.flatten().topk(k).indices
    # Translate flat offsets back into one coordinate array per dimension.
    coords_per_axis = np.unravel_index(utils.to_numpy(flat_positions), tensor.shape)
    # Rows become individual index tuples: shape [k, tensor.ndim].
    return np.transpose(np.array(coords_per_axis)).tolist()

In [3]:
# Load GPT-2 small as a HookedTransformer with all four weight-processing
# options enabled explicitly (see the TransformerLens `from_pretrained`
# documentation for what each option does).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Each line of the dataset is split on ", " into a prompt and its acronym.
with open("data/acronyms.txt", "r", encoding="utf-8") as f:
    prompt_acronym_pairs = [row.split(", ") for row in f.read().splitlines()]

# take a subset of the dataset (we do this because of VRAM limitations)
# you can take as many samples as your GPU can handle
n_samples = 250

# Sample WITHOUT replacement so the subset contains no duplicated prompts
# (random.choices draws with replacement), cap the size at the dataset
# length, and use a dedicated seeded generator so a fresh
# "Restart & Run All" reproduces the same subset.
sampled_pairs = random.Random(0).sample(
    prompt_acronym_pairs, k=min(n_samples, len(prompt_acronym_pairs))
)
prompts, acronyms = map(list, zip(*sampled_pairs))

In [5]:
# Tokenize the prompts and the target acronyms. The answers are tokenized
# with prepend_bos=False so they contain only the acronym's own tokens.
clean_tokens = model.to_tokens(prompts)
answer_tokens = model.to_tokens(acronyms, prepend_bos=False)

# A single forward pass returns both the logits and the activation cache.
# The extra `model(clean_tokens)` call that used to follow recomputed the
# same logits and only cost additional time and memory.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

In [6]:
# Mean logit difference on the clean prompts; shown as the cell output.
clean_logit_diff = compute_logit_diff_2(
    logits=clean_logits,
    answer_tokens=answer_tokens,
    average=True,
)
clean_logit_diff.item()

1.2833060026168823

In [7]:
# Dedicated seeded generator: the corrupted batch (and every patching result
# derived from it) is then identical on each "Restart & Run All".
corruption_rng = torch.Generator().manual_seed(0)
n_prompts = clean_tokens.shape[0]

# Shuffle whole prompts across the batch. Indexing with a permutation tensor
# already returns a new tensor, so clean_tokens is left untouched and no
# explicit clone is needed.
corrupted_tokens = clean_tokens[torch.randperm(n_prompts, generator=corruption_rng)]

# A second, independent shuffle supplies the last two positions.
corrupted_tokens_acronym = clean_tokens[torch.randperm(n_prompts, generator=corruption_rng)]

corrupted_tokens[:, -2:] = corrupted_tokens_acronym[:, -2:]

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [8]:
# Mean logit difference of the corrupted run, scored against the clean answers.
compute_logit_diff_2(corrupted_logits, answer_tokens, average=True).item()

-2.5778794288635254

In [9]:
# Residual-stream activation patching, one sweep per acronym letter.
# For each iteration the model runs on clean_tokens while activations are
# taken from a cache built on a partially corrupted copy of the prompts.
indices = [2, 3, 4]

# Per-prompt, per-letter metric; it does not depend on the loop variables,
# so it is built once up front.
compute_logit_diff_aux = partial(compute_logit_diff_2, answer_tokens=answer_tokens, average=False)  # returns (batch_size, 3)

act_patch_resid_pre_iter = []
for i, j in enumerate(indices):
    # Start from the clean prompts and overwrite a slice with corrupted tokens.
    # NOTE(review): j:j+2 replaces two adjacent positions (j and j+1), not a
    # single one - confirm this is the intended "current word" span.
    corrupted_tokens_i = clean_tokens.clone()
    corrupted_tokens_i[:, j:j+2] = corrupted_tokens[:, j:j+2]
    _, corrupted_cache_i = model.run_with_cache(corrupted_tokens_i)

    def compute_logit_diff_iter(logits):
        """Batch-mean logit difference for letter `i` of the current iteration."""
        return compute_logit_diff_aux(logits)[:, i].mean()

    act_patch_resid_pre = patching.get_act_patch_resid_pre(
        model, clean_tokens, corrupted_cache_i, compute_logit_diff_iter
    )
    act_patch_resid_pre_iter.append(act_patch_resid_pre)

# Stack the per-iteration results along a new leading axis.
act_patch_resid_pre_iter = torch.stack(act_patch_resid_pre_iter, dim=0)

# Drop the progress output produced during the sweeps.
clear_output()

In [10]:
# Per-letter mean logit difference of the unpatched clean run; subtracting it
# expresses every patched value as a change relative to that baseline.
baseline_logit_diff = compute_logit_diff_2(clean_logits, answer_tokens, average=False).mean(0)

facet_labels = ["Iteration 0", "Iteration 1", "Iteration 2"]
labels = ["BOS", "The", "C1", "C2", "C3", " (", "A1", "A2"]

imshow(
    act_patch_resid_pre_iter - baseline_logit_diff[..., None, None],
    title="Residual Stream Patching",
    labels={"x": "Sequence Position", "y": "Layer"},
    x=labels,
    facet_col=0,  # dimension that plotly splits into separate plots
    facet_col_wrap=3,
    facet_labels=facet_labels,  # subtitles of the separate plots
    width=1000,
    height=400,
)

In [None]:
# Attention-head output activation patching (all positions at once), one
# sweep per acronym letter. For each iteration the model runs on clean_tokens
# while head outputs are taken from a cache built on a partially corrupted
# copy of the prompts.
indices = [2, 3, 4]

# Per-prompt, per-letter metric; it does not depend on the loop variables,
# so it is built once up front.
compute_logit_diff_aux = partial(compute_logit_diff_2, answer_tokens=answer_tokens, average=False)  # returns (batch_size, 3)

act_patch_attn_head_out_all_pos_iter = []
for i, j in enumerate(indices):
    # Start from the clean prompts and overwrite a slice with corrupted tokens.
    # NOTE(review): j:j+2 replaces two adjacent positions (j and j+1), not a
    # single one - confirm this is the intended "current word" span.
    corrupted_tokens_i = clean_tokens.clone()
    corrupted_tokens_i[:, j:j+2] = corrupted_tokens[:, j:j+2]
    _, corrupted_cache_i = model.run_with_cache(corrupted_tokens_i)

    def compute_logit_diff_iter(logits):
        """Batch-mean logit difference for letter `i` of the current iteration."""
        return compute_logit_diff_aux(logits)[:, i].mean()

    act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(
        model, clean_tokens, corrupted_cache_i, compute_logit_diff_iter
    )
    act_patch_attn_head_out_all_pos_iter.append(act_patch_attn_head_out_all_pos)

# Stack the per-iteration results along a new leading axis.
act_patch_attn_head_out_all_pos_iter = torch.stack(act_patch_attn_head_out_all_pos_iter, dim=0)

# Drop the progress output produced during the sweeps.
clear_output()

  0%|          | 0/144 [00:00<?, ?it/s]

In [None]:
# Per-letter mean logit difference of the unpatched clean run; subtracting it
# expresses every patched value as a change relative to that baseline.
baseline_logit_diff = compute_logit_diff_2(clean_logits, answer_tokens, average=False).mean(0)

facet_labels = ["Iteration 0", "Iteration 1", "Iteration 2"]
labels = ["BOS", "The", "C1", "C2", "C3", " (", "A1", "A2"]

# One heatmap per iteration; axes are labelled head (x) by layer (y).
imshow(
    act_patch_attn_head_out_all_pos_iter - baseline_logit_diff[..., None, None],
    title="Patching Attention Heads",
    labels={"y": "Layer", "x": "Head"},
    facet_col=0,
    facet_labels=facet_labels,
    width=800,
    height=400,
)

In [None]:
# Number of heads to report per iteration.
k = 5

# Change in the metric relative to the clean baseline, per iteration.
path_patch_head_to_heads_iteration = act_patch_attn_head_out_all_pos_iter - baseline_logit_diff[..., None, None]

# The effects are negated so that top-k selects the entries whose patching
# lowers the logit difference the most.
top_heads_1, top_heads_2, top_heads_3 = [
    topk_of_Nd_tensor(-1*path_patch_head_to_heads_iteration[iteration], k)
    for iteration in range(3)
]

print(f"Top {k} heads on iteration 1: {top_heads_1}")
print(f"Top {k} heads on iteration 2: {top_heads_2}")
print(f"Top {k} heads on iteration 3: {top_heads_3}")