Assignment 4: Activation Patching
==========

As described in the class 12 slides, traditional methods such as classifier-based probing, which are designed to identify the layers or attention heads responsible for different phenomena, don't tell us anything about the causal role of those layers or heads. To identify the *computational mechanisms* that lead to the outputs, we make use of a suite of methods that are collectively called *mechanistic interpretability* methods.

Complete this assignment by following all of the steps in this notebook.




In [19]:
model, tokenizer = load_gpt2('gpt2-medium')
model = model.float()

NameError: name 'load_gpt2' is not defined

## Step 1:  Understanding the concepts of *residual streams* and *hooks*


### Residual streams

Before we can begin our activation patching, we'll need to get some conceptual preliminaries out of the way. First, we'll need to understand the *residual stream*, i.e., the "stream" that passes information between the transformer blocks and undergoes linear transformations.

It's important to understand that no block or set of layers individually corresponds to the residual stream in the transformer. Rather, it is a conceptual interpretation of the transformer architecture. Because of the residual connections within the transformer blocks, one can view the flow of information (i.e., the vector representations) as something the transformer blocks *read from* (i.e., they read the results of previous computations) and then *write to*: the results of the attention and MLP calculations are added back to the representations (via linear operations applied to previous representations).
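To make the read/write picture concrete, here is a minimal sketch in plain Python. The two component functions (`toy_attention`, `toy_mlp`) are made-up stand-ins, not real transformer layers. The only point is that each component reads the current stream and adds its output back, so the final stream equals the input embedding plus the sum of everything written along the way.

```python
# Toy residual stream: every component reads the stream and ADDS its output back.
# toy_attention and toy_mlp are hypothetical stand-ins, not real transformer layers.

def toy_attention(stream):
    # Pretend "attention": write a scaled copy of what was read.
    return [0.5 * x for x in stream]

def toy_mlp(stream):
    # Pretend "MLP": write a ReLU of what was read.
    return [max(0.0, x) for x in stream]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def run_block(stream):
    writes = []
    attn_out = toy_attention(stream)   # read
    stream = add(stream, attn_out)     # write (residual connection)
    writes.append(attn_out)
    mlp_out = toy_mlp(stream)          # read the updated stream
    stream = add(stream, mlp_out)      # write (residual connection)
    writes.append(mlp_out)
    return stream, writes

embedding = [1.0, -2.0, 0.5]
stream = embedding
all_writes = []
for _ in range(2):  # two toy blocks
    stream, writes = run_block(stream)
    all_writes.extend(writes)

# Rebuild the final stream as: embedding + every write, in order.
total = embedding
for w in all_writes:
    total = add(total, w)
print(stream)
print(total)
```

Both printed vectors are the same, which is what makes it meaningful to talk about a single stream that components contribute to.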


### Hooks

The code below works with so-called *hooks*, i.e., functions that are "hooked onto" the forward pass through the model and are executed together with the normal computations, when the model is called. They are based on [this](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) native PyTorch functionality.
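The hook idea can be illustrated without any library. The sketch below is a toy in plain Python: `ToyLayer` is a hypothetical class that only imitates the contract of PyTorch's `register_forward_hook`, where a hook is called as `hook(module, inputs, output)` and a non-`None` return value replaces the output. It is not how PyTorch or `transformer_lens` implement hooks internally.

```python
# Toy version of the forward-hook idea in plain Python.
# ToyLayer is hypothetical; it only imitates the hook(module, inputs, output) contract.

class ToyLayer:
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, hook):
        self._hooks.append(hook)

    def forward(self, x):
        return [2 * v for v in x]

    def __call__(self, x):
        output = self.forward(x)
        # Hooks run together with the normal computation, in registration order.
        for hook in self._hooks:
            result = hook(self, (x,), output)
            if result is not None:
                output = result
        return output

cache = {}

def caching_hook(module, inputs, output):
    # Observe only: store a copy of the activation and leave the output unchanged.
    cache["layer_out"] = list(output)

def zeroing_hook(module, inputs, output):
    # Intervene: returning a value replaces the layer's output.
    return [0 for _ in output]

layer = ToyLayer()
layer.register_forward_hook(caching_hook)
observed = layer([1, 2, 3])

layer.register_forward_hook(zeroing_hook)
intervened = layer([1, 2, 3])

print(observed, cache["layer_out"], intervened)
```

Activation patching needs both kinds of hook: one that records activations on one run, and one that overwrites activations on another run.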

## Step 2: Understand the code implementation of activation patching using *TransformerLens*

Activation patching is an interpretability method for *causal intervention* on the computational mechanisms of the model: certain results of the computations are injected, and the effect that these results have on the outcome is observed. Specifically, representations computed in one pass through the model (the *clean run*) are taken and injected into another pass (the *corrupted run*). These injected representations are called the patch.

Patching techniques are quite tricky to work with, since patching something in an intermediate layer affects not only the local layer but also all the downstream computations. To be able to retrieve meaningful results, careful comparison and, e.g., freezing of subsequent representations are required.

Below is an example of using activation patching for the task of identifying which model internals are responsible for recalling landmark information. How does the model complete the prompt “The Colosseum is in” with the answer “Rome”? We'll implement this using the library [`transformer_lens`](https://github.com/TransformerLensOrg/TransformerLens).

Our clean prompt will be "The Colosseum is in" and our corrupted prompt will be "The Champs Elysees is in".

Recall from the slides that our approach to determining which model internals matter is based on a simple metric: the difference between the logit of the correct token and the logit of the incorrect token, compared across different inputs. This metric leverages an old insight from LLM-based modelling: a model can perform a task well if it assigns higher log probabilities (i.e., higher logits) to correct predictions than to incorrect ones.
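The link between logits and log probabilities can be made explicit. Writing `z_t` for the logit of token `t` at the final position (notation introduced here for illustration only), the softmax normalizer is the same for every token on a given input, so it cancels when two tokens are compared:

```latex
\log p(a) - \log p(b)
  = \Bigl(z_a - \log\sum_{t} e^{z_t}\Bigr) - \Bigl(z_b - \log\sum_{t} e^{z_t}\Bigr)
  = z_a - z_b
```

The logit difference between the correct and the incorrect token is therefore the log of the ratio of their probabilities, which is why it can be computed directly from the logits without applying a softmax.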

We then pick a specific model activation, run the model on the corrupted prompt, and intervene on that activation by patching in the value it took on the clean prompt. We then apply the metric and see how much of the clean performance this patch has recovered.
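"How much the patch has recovered" can be expressed as a fraction of the gap between the clean and corrupted logit differences. The helper below is a small sketch of that normalization. The function name and the `eps` guard are assumptions added for illustration, and the example numbers 1.068 and -2.766 are the clean and corrupted logit differences this notebook reports for the landmark prompts.

```python
def normalized_patching_effect(patched, clean, corrupted, eps=1e-6):
    """Fraction of the clean-vs-corrupted gap recovered by a patch.

    0.0 means the patch changed nothing; 1.0 means it fully restored the
    clean logit difference. Values outside [0, 1] are possible.
    """
    gap = clean - corrupted
    if abs(gap) < eps:
        # Hypothetical safeguard: with no gap there is nothing to recover.
        raise ValueError("clean and corrupted logit differences are (almost) equal")
    return (patched - corrupted) / gap

clean, corrupted = 1.068, -2.766
print(normalized_patching_effect(corrupted, clean, corrupted))  # patch had no effect
print(normalized_patching_effect(clean, clean, corrupted))      # patch restored the clean value
print(normalized_patching_effect(-0.849, clean, corrupted))     # made-up patched value at the midpoint
```

The guard matters in practice: when a clean/corrupted pair produces nearly identical logit differences, the denominator is tiny and the normalized values become large and hard to interpret.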


Essentially, we ask: given a corrupted input, if we inject certain representations from the "correct" run, can we "fix" the performance? If we can, we know that the injected representations causally contribute to producing the correct output.
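The whole procedure can be sketched end to end on a toy model before turning to GPT-2. Everything in the block below is hypothetical: the "model" has two layers that each add half of the causal running mean to every position, and the metric is simply the stream value at the last position. What carries over to the real setting is the loop structure: cache the clean run, patch one (layer, position) at a time into the corrupted run, and normalize.

```python
# Toy activation patching in plain Python (hypothetical model, for illustration only).

N_LAYERS = 2

def run(tokens, patch=None, cache=None):
    """Run the toy model; optionally patch the layer input at (layer, position)."""
    stream = [float(t) for t in tokens]
    for layer in range(N_LAYERS):
        if patch is not None and patch["layer"] == layer:
            stream[patch["position"]] = patch["value"]  # the intervention
        if cache is not None:
            cache[layer] = list(stream)                 # store the layer input
        # Causal mixing: position i only sees positions 0..i.
        means = [sum(stream[: i + 1]) / (i + 1) for i in range(len(stream))]
        stream = [x + 0.5 * m for x, m in zip(stream, means)]
    return stream[-1]

clean_tokens, corrupted_tokens = [1, 5, 1], [1, -5, 1]

clean_cache = {}
clean_metric = run(clean_tokens, cache=clean_cache)
corrupted_metric = run(corrupted_tokens)

results = [[0.0] * len(clean_tokens) for _ in range(N_LAYERS)]
for layer in range(N_LAYERS):
    for position in range(len(clean_tokens)):
        patch = {"layer": layer, "position": position,
                 "value": clean_cache[layer][position]}
        patched_metric = run(corrupted_tokens, patch=patch)
        results[layer][position] = (
            (patched_metric - corrupted_metric) / (clean_metric - corrupted_metric)
        )

for layer, row in enumerate(results):
    print(layer, [round(v, 3) for v in row])
```

In this toy, patching the differing token at layer 0 recovers the clean metric completely, while at layer 1 the effect is split between that token's position and the last position, because the first layer has already moved some of the information forward.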

The code below is taken from [this](https://github.com/TransformerLensOrg/TransformerLens/blob/main/demos/Main_Demo.ipynb) demo.


In [2]:
!pip install transformer_lens plotly


In [33]:
from transformer_lens import HookedTransformer
import plotly.express as px
import transformer_lens.utils as utils
import tqdm
from functools import partial
import torch

In [34]:
# Load the model inside the library's wrapper, which makes it easy to access and patch activations

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [35]:
# first, we check if the model can do the task at all
# i.e., we compare the difference in logits for the correct and incorrect answer
# given different inputs without any interventions



# Note that you will get an index error if your corrupted prompt has fewer tokens than your clean prompt
clean_prompt = "The Colosseum is in"
corrupted_prompt = "The Champs Elysees is in"


clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" Rome", incorrect_answer=" Paris"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 1.068
Corrupted logit difference: -2.766
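The note about prompt lengths in the cell above can be turned into an explicit check. The helper below is a sketch with a hypothetical name, `check_patchable`, and it is deliberately stricter than the note: it assumes you want the two prompts to have the same number of tokens, so that position `i` refers to the same slot in both runs. It works on plain lists of token strings, and the example lists are hand-written, not real GPT-2 tokenizations.

```python
def check_patchable(clean_str_tokens, corrupted_str_tokens):
    """Return the positions at which the two token sequences differ.

    Raises ValueError if the sequences have different lengths, because
    position-wise patching assumes aligned positions in both runs.
    """
    if len(clean_str_tokens) != len(corrupted_str_tokens):
        raise ValueError(
            f"token counts differ: {len(clean_str_tokens)} (clean) vs "
            f"{len(corrupted_str_tokens)} (corrupted)"
        )
    return [
        i
        for i, (a, b) in enumerate(zip(clean_str_tokens, corrupted_str_tokens))
        if a != b
    ]

# Illustrative, hand-written token lists (NOT real GPT-2 tokenizations).
clean_example = ["<bos>", "The", " cat", " is", " in"]
corrupted_example = ["<bos>", "The", " dog", " is", " in"]
print(check_patchable(clean_example, corrupted_example))

# With the model loaded above, the inputs could come from the tokenizer:
# check_patchable(model.to_str_tokens(clean_prompt),
#                 model.to_str_tokens(corrupted_prompt))
```

Running a check like this before the patching loop also shows which positions actually differ, which helps when reading the heatmap later.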


In [36]:
# define a helper

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)


In [37]:
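# We define a residual stream patching hook.
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre.
# Arguments: resid_pre is the activation tensor [batch, position, d_model], hook is the HookPoint,
# and position is the token position to patch (fixed via functools.partial below).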
def residual_stream_patching_hook(
    resid_pre,
    hook,
    position
):
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    # NOTE: this is the key step in the patching process
    # where we replace the activations in the residual stream with the same activations from the clean run
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre

# We make a tensor to store the results for each patching run.
# We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)



100%|██████████| 12/12 [00:04<00:00,  2.90it/s]


In [38]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Landmark Task")


## Step 3: Implement 5 examples of activation patching using the above code

Now brainstorm 5 clean/corrupted pairs, run the above code on each of these pairs, and interpret the plotly output. Your answer for each pair should have the following elements:


1.   A description of the concept being investigated. In the above example, the concept was *landmarks*, but your concept could be *Nobel Prize winners*, *number agreement*, etc.

2.   A test of your pair using the *logit_to_logit_diff* function to ensure that the clean and corrupted logit differences differ significantly, together with the reported logit results.

3.   A plotly plot, produced by running the above code, showing the layer-wise normalized logit difference. Interpret your results by identifying which layers were responsible for your concept.
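When interpreting a heatmap, it can help to list the strongest cells numerically rather than reading them off the colors. The helper below is a sketch with a hypothetical name, `top_patching_cells`. It takes a nested list of shape layers x positions, and the example matrix is made up for illustration.

```python
def top_patching_cells(result, k=3):
    """Return the k (layer, position, value) cells with the largest values."""
    cells = [
        (layer, position, value)
        for layer, row in enumerate(result)
        for position, value in enumerate(row)
    ]
    return sorted(cells, key=lambda cell: cell[2], reverse=True)[:k]

# Made-up 3-layer x 4-position matrix, for illustration only.
example_result = [
    [0.0, 0.9, 0.0, 0.1],
    [0.0, 0.6, 0.0, 0.4],
    [0.0, 0.0, 0.0, 1.0],
]
top_cells = top_patching_cells(example_result)
for layer, position, value in top_cells:
    print(f"layer {layer}, position {position}: {value:.2f}")

# With a torch result tensor, the nested list could be obtained via .tolist():
# top_patching_cells(ioi_patching_result.tolist())
```

A pattern like the one in this made-up matrix, with high values at an early token in the lower layers and at the final token in the upper layers, would suggest that the relevant information is moved to the last position partway through the model.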


In [45]:
def logit_to_logit_diff(model, prompt, correct_answer, incorrect_answer):
    # Tokenize input
    tokens = model.to_tokens(prompt)
    logits = model(tokens)
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    last_token_logits = logits[0, -1]
    return (last_token_logits[correct_index] - last_token_logits[incorrect_index]).item()


## Gender Bias in Occupation

In [65]:
clean_prompt = "The nurse said that she would assist the doctor."
corrupted_prompt = "The nurse said that he would assist the doctor."


clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" she", incorrect_answer=" he"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)



Clean logit difference: 2.233
Corrupted logit difference: -0.081


100%|██████████| 12/12 [00:05<00:00,  2.21it/s]


In [66]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Gender Bias Task")


## Nobel Prize winners - Fame Association

In [72]:
clean_prompt="Marie Curie won the Nobel Prize in Physics."

corrupted_prompt="Marie Curie won the Nobel Prize in Biology."


clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" Physics", incorrect_answer=" Biology"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)





Clean logit difference: 1.354
Corrupted logit difference: 1.967


100%|██████████| 12/12 [00:05<00:00,  2.11it/s]


In [73]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Nobel Prize Task")


## Number Agreement

In [74]:
clean_prompt= "The dogs are barking loudly."
corrupted_prompt= "The dogs is barking loudly."

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" are", incorrect_answer=" is"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)



Clean logit difference: -2.147
Corrupted logit difference: -2.508


100%|██████████| 12/12 [00:03<00:00,  3.10it/s]


In [75]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Number Agreement Task")


## Country Capital Knowledge

In [76]:
clean_prompt="The capital of France is Paris."

corrupted_prompt= "The capital of France is Berlin."

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" Paris", incorrect_answer=" Berlin"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)



Clean logit difference: 4.771
Corrupted logit difference: -2.386


100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


In [77]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Country Capital Task")


## Temporal Consistency

In [81]:
clean_prompt="Yesterday, I walked to the store."

corrupted_prompt= "Yesterday, I will walk to the store."



clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" walked", incorrect_answer=" walk"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)



Clean logit difference: 2.723
Corrupted logit difference: -0.867


100%|██████████| 12/12 [00:05<00:00,  2.13it/s]


In [82]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the Temporal Consistency Task")
