<a name="top"></a>
#Quick Research Summary
##Preface
*  This is a hackathon-esque notebook demonstrating how I go about analyzing model behaviors and interpreting figures. I often set 10-hour limits when doing such searches, as there may be other behaviors more worthwhile exploring!
*   Hopefully this notebook demonstrates my ability to analyze behaviors in the wild.
*   Note: Every chart linked in this summary has a more thorough interpretation underneath it than what I provide in the summary.


##Initial exploration
  *   In playing around with GPT2-Small I noticed that it would repeat some grammatical patterns. The one I decided to focus on was reusing contractions that previously appeared in the sentence.
      *   "We're helping our friend while they" predicts "'re" instead of "are"
      *   "We are helping our friend while they" predicts "are" instead of "'re"
  *   I also noticed that the model will (mostly) continue to use contractions even if the final contraction is distinct from the previously used one.
    *   "We're helping our friend while he" predicts "'s" instead of "is" despite the previous contraction in the sentence being "'re"

##Research Goal
*   For one type of contraction earlier in the sentence paired with one terminating pronoun, how does the model know to predict the correct contraction?
*   For several different contraction pronoun pairs, can I identify the model's mechanisms used to predict the final contraction?
*   If I can do the above, can I compare the various mechanisms used to predict contractions? Are they the same? Are they different? If so how are they different? Do they use various parts of the same circuit?
*   These questions are super exciting to me yet are probably too grand in scope for this application. I set out to make as much progress in answering these questions as I can while learning the tools of mechanistic interpretability.

##Hypothesis
*   From purely thinking about what the model could be implementing, I think it is possible that the mechanism for this task involves bigram/skip-trigram munging. A bigram is when the model predicts the next token based exclusively on what is most likely to appear after the final token. A skip trigram is when the model sees some pattern "A...B" and then predicts C. I believe it is possible that the model looks at the pronoun at the end of the sentence, boosts verbs and verb contractions, and, from the skip trigram with the earlier contraction token, boosts the correct contraction of the verb for that pronoun.
*   I anticipate that the mechanisms used to do this are slightly different for various contraction pronoun pairs. I am very uncertain about this, but my reasoning is that the model (in some contraction pairings) will be able to copy the previous contraction instead of relying only on the bigram with the pronoun at the end. I wonder if the model will favor copying information from the contraction in some cases or will just stick to the same general mechanism.
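
To make the hypothesis concrete, here is a toy sketch of the bigram plus skip-trigram story in plain Python. Every score is an invented number chosen for illustration, not something read out of GPT2-Small, and `BIGRAM`, `SKIP_TRIGRAM`, and `predict` are hypothetical names.

```python
# Toy illustration of the hypothesis; every score below is invented.

# Bigram: score of the next token given only the final token.
BIGRAM = {
    " they": {"'re": 1.0, " are": 1.5},
    " he": {"'s": 1.0, " is": 1.5},
}

# Skip trigram: an earlier token A plus the final token B boosts token C.
SKIP_TRIGRAM = {
    ("'re", " they"): {"'re": 1.0},
    ("'re", " he"): {"'s": 1.0},
    ("'s", " they"): {"'re": 1.0},
    ("'s", " he"): {"'s": 1.0},
}


def predict(tokens):
    """Return the highest-scoring next token under the toy model."""
    final = tokens[-1]
    scores = dict(BIGRAM[final])  # start from the bigram scores
    for earlier in tokens[:-1]:
        # Add any skip-trigram boost triggered by an earlier token.
        for target, boost in SKIP_TRIGRAM.get((earlier, final), {}).items():
            scores[target] = scores.get(target, 0.0) + boost
    return max(scores, key=scores.get)


print(predict(["We", "'re", " helping", " while", " they"]))
print(predict(["We", " are", " helping", " while", " they"]))
print(predict(["We", "'re", " helping", " while", " he"]))
```

In this toy setup the bigram alone prefers the full verb, and the earlier contraction tips the balance toward the contraction that matches the final pronoun.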

## Experiment 1: Looking at the ('re, they) pairing
*   By decomposing the residual stream layer by layer [we can see](#E1:LA) that the most important parts of the model for this task are layer 7 attention and the layer 8 and layer 10 MLPs.
*   Looking at the [attention heads](#E1:AH) we see that L7 H11 seems critically important to completing the task. Inspecting it further we see that [this head](#E1:HP) attends from the final token back to the contraction token.
*   I then conducted a ROME-like analysis with corrupted prompts to look further at this behavior.
*   By [patching attention layer activations](#E1:ALP) we see that we can recover a good amount of performance on corrupted prompts by patching L7 and L8, which aligns with our previous analysis. Interestingly, we can't fully recover performance, which means these attention layers are not the whole story.
*   By [patching MLP layers](#E1:MLPP) we see that we can recover a surprising amount from MLP layers 1, 3, 8, and 10.
*   As explained [below](#neurExp), I wanted to inspect all the neurons in these layers on Neuroscope but realized this was simply too many neurons to look at. Instead I used [clementneo's](#credit) technique for activation patching of individual neurons. I was hoping this would narrow down my search to neurons that are important on their own as a starting point (recognizing that I would be overlooking other important neurons). Due to time I only did this for [layer 8](#E1:NP8) and [layer 10](#E1:NP10). I found [Layer: 8. Neuron Index: 2744](#E1:NI2744), [Layer: 10. Neuron Index: 1063](#E1:NI1063), and [Layer: 10. Neuron Index: 2193](#E1:NI2193) to be important. Their corresponding Neuroscope pages also made sense for why this would be the case!
*   Lastly I decomposed the heads and found that [value patching](#E1:VPP) recovered performance most significantly in L7 H11, whereas [attention pattern patching](#E1:APP) did not, which contradicted what I thought was going on when I visualized this head. I would want to look more closely into why this is the case! We [can see](#E1:HP) that there are other heads (L10 H1) in the model that attend from the final token to the contraction token, yet maybe they don't copy the same information as L7 H11, which would explain why the value patching is more important than the pattern patching. I will look more in depth into the plausibility of this after the application is due, but can't say this is definitely the case without further analysis.
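
Since value patching and pattern patching are easy to mix up, here is a toy single-head sketch in plain Python. The numbers and the helper `head_output` are invented for illustration and say nothing about what L7 H11 actually computes. The sketch only shows one situation in which patching the values recovers the clean output while patching the pattern does not.

```python
# Toy single attention head at the final position: the output is the
# pattern-weighted sum of per-position value vectors. All numbers are invented.

def head_output(pattern, values):
    """Weighted sum of value vectors (one weight per source position)."""
    d = len(values[0])
    return [sum(w * v[i] for w, v in zip(pattern, values)) for i in range(d)]


# Three source positions; position 1 plays the role of the contraction token.
clean_pattern = [0.1, 0.8, 0.1]
corrupt_pattern = [0.1, 0.8, 0.1]  # the head looks at the same place
clean_values = [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
corrupt_values = [[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # different content there

clean_out = head_output(clean_pattern, clean_values)
corrupt_out = head_output(corrupt_pattern, corrupt_values)

# Pattern patching: clean pattern, corrupted values.
pattern_patched = head_output(clean_pattern, corrupt_values)
# Value patching: corrupted pattern, clean values.
value_patched = head_output(corrupt_pattern, clean_values)

print("pattern patched == corrupted:", pattern_patched == corrupt_out)
print("value patched == clean:", value_patched == clean_out)
```

In this toy case the clean and corrupted runs attend to the same position, so only the values carry the difference. Whether something similar holds for L7 H11 would need to be checked against the cached attention patterns.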

## Experiment 2: Looking at other pairings

*   I won't go as in depth into the analysis itself because it closely follows the analysis I did above, just repeated across multiple experimental groups, where the prompts within a group all share the same contraction pronoun pairing.
*   The various prompts and experimental groups for each pairing can be found [here](#E2:DPA).
*   [Decomposing the residual stream](#E2:LA) layer by layer suggests that for all pronoun pairings the same 3 layers are the most important, yet which of those layers is the most important varies slightly.
*   We can see by looking at the [heads](#E2:HA) that L7 H11 is by far the most important again and the other top heads are largely the same, yet for the ('s, we) and ('s, they) pairings L11 H8 was the third most important head, which is not the case for the other pairings.
*   I then conducted an analysis for each experimental group with corrupted prompts and found that the [same layers could improve performance in each experimental group](#E2:RSP).
*   Finally I patched [Attention Layers](#E2:ALP) and [MLP Layers](#E2:MLPP) from the clean runs of the model.
*   For the Attention Layers I noticed that largely the same layers are important for each task, but there is variation! For the ('re, he) and ('re, she) pairings, L8 is slightly more significant than L7, yet for some pairings like ('s, she) and ('s, he) L8 is less significant than L7. Granted, these differences are < 0.2, which may not be significant, but they are there!
*   For the MLP Layers there is slightly more variation in the later layers between the tasks. The differences between the experimental groups (largely the use of L11 MLP) can mostly be bundled into two groups, categorized by what the predicted contraction should be. Within each group the chart looks largely similar. Again, these differences are within a small range (< 0.2) so may not be significant, but they are definitely interesting!


#Wrap-up
*   As explained above, I believe that some of the observations suggest that there may be some bigram/skip-trigram munging happening to perform this task, but I am far from being able to reject any sort of null hypothesis.
*   This emphasizes to me just how complicated and involved some of the processing this model does actually is. A good example is telling myself to keep looking into MLPs even though it seemed like L7 H11 provided a simple, all-encompassing solution. I am now certain the true circuit is far more involved than a simple attention pattern.



















#Imports + Setup

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import tqdm.auto as tqdm
from jaxtyping import Float, Int


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "vscode"

In [None]:
!pip install torchtyping
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import numpy as np
import einops
from typing import List, Union, Optional
import pysvelte
from IPython.display import HTML
from functools import partial

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtyping
  Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)
Installing collected packages: torchtyping
Successfully installed torchtyping-0.1.4


In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa5837a15d0>

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.has_mps else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
#Importing the model to be used
gpt_small = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)


Loaded pretrained model gpt2-small into HookedTransformer


#Some general observations from prompt experimentation on GPT-Small
*   The model has the tendency to repeat grammatical patterns
*    One such example is that using an apostrophe earlier in the sentence makes the model more likely to predict an apostrophe token later in the sentence. Similarly, not including one encourages the model not to use one later in the sentence
*   Examples of this behavior are shown below
  *   "We're helping our friend while he" predicts "'s" instead of "is"
  *   "We are helping our friend while he" predicts "is" instead of "'s"
*   A second observation is that the model can still perform the task when there are different pronoun tokens
  *   "We're helping our friend while he" ('re vs he) --> predicts 's
  *   "We're helping our friend while she" ('re vs she) --> predicts 's
  *   "We're helping our friend while they" ('re vs they)  --> predicts 're
  *   "We're helping our friend while we" ('re vs we) --> predicts 're
  *   "She's helping our friend while he" ('s vs he)  --> predicts 's
  *   "She's helping our friend while she" ('s vs she) --> predicts 's
  *   "She's helping our friend while they" ('s vs they) --> predicts 're
  *   "She's helping our friend while we" ('s vs we) --> predicts 're


















#Initial hypothesis
*   Since 're and 's are both individual tokens, I hypothesize that GPT2-Small can achieve this task using a skip trigram/bigram mechanism (I am very new to this concept so may be using it incorrectly here). Some part of the model sees the 're or 's token and also detects that the current token is a pronoun, so the next token predicted is 're or 's. Otherwise it predicts that the next token is the standard full word (is, was, were, are, etc.).

#Preliminary questions
*   I want to investigate how the model relates the shortened 's or 're to different pronouns. Does it have separate mechanisms for each pronoun that all do the same thing (linking the 's and 're tokens with the current one)? Does it use a more abstract circuit to tackle all these cases?

*   If I am reasoning about this correctly, I think investigating this question can shed light on some pretty awesome behaviors!
  1.   Do models of this size learn to develop the same mechanism separately for many different tokens, or do they learn to make more generalizable circuits for simple tasks?
  2.   If the former, where do these different mechanisms lie within the model (same layer, same head pos, etc.)?
  3.   If the latter, how could it classify he, she, they as the same grammatical object (pronouns) to be used by the more general circuit that predicts whether it should continue the pattern?


#Experiment 1: Prompt generating

In [None]:
import os
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
example_prompt = "She's at home studying while they"

example_answer = "'s"
utils.test_prompt(example_prompt, example_answer, gpt_small, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'She', "'s", ' at', ' home', ' studying', ' while', ' they']
Tokenized answer: ["'s"]


Top 0th token. Logit: 17.01 Prob: 23.57% Token: |'re|
Top 1th token. Logit: 15.43 Prob:  4.88% Token: | watch|
Top 2th token. Logit: 15.41 Prob:  4.76% Token: | wait|
Top 3th token. Logit: 15.06 Prob:  3.35% Token: | work|
Top 4th token. Logit: 15.00 Prob:  3.16% Token: | play|
Top 5th token. Logit: 14.72 Prob:  2.38% Token: | go|
Top 6th token. Logit: 14.27 Prob:  1.52% Token: | get|
Top 7th token. Logit: 14.25 Prob:  1.49% Token: | take|
Top 8th token. Logit: 13.97 Prob:  1.13% Token: | are|
Top 9th token. Logit: 13.95 Prob:  1.10% Token: | try|


In [None]:
#Generating original prompts for the 're and they combo
#Even indices use the contraction, odd indices use the full verb
orig_prompts = ["We're helping our friend when they",
                "We are helping our friend when they",
                "They're eating dinner together while they",
                "They are eating dinner together while they",
                "We're driving uptown and they",
                "We are driving uptown and they",
                "They're sleeping too soon and they",
                "They are sleeping too soon and they"]

#(correct, incorrect) answer pair for each prompt: the model should repeat the form used earlier
correct_answers = [("'re"," are") if x%2==0 else (" are","'re") for x in range(len(orig_prompts))]

#Tokenizing the answers (each answer must map to exactly one token)
answer_tokens = torch.tensor(
    [[gpt_small.to_single_token(ans) for ans in pair] for pair in correct_answers]
).to(device)

In [None]:
#Running gpt-small on these prompts and recording cache
tokens = gpt_small.to_tokens(orig_prompts, prepend_bos=True)
# Move the tokens to the same device as the model
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, original_cache = gpt_small.run_with_cache(tokens)

In [None]:
#Verifying token count for all the prompts
for prompt in tokens:
    print("Prompt length:", len(prompt))
    print("Prompt as tokens:", prompt.tolist())
    print("Prompt as words:", gpt_small.to_str_tokens(prompt))

Prompt length: 8
Prompt as tokens: [50256, 1135, 821, 5742, 674, 1545, 618, 484]
Prompt as words: ['<|endoftext|>', 'We', "'re", ' helping', ' our', ' friend', ' when', ' they']
Prompt length: 8
Prompt as tokens: [50256, 1135, 389, 5742, 674, 1545, 618, 484]
Prompt as words: ['<|endoftext|>', 'We', ' are', ' helping', ' our', ' friend', ' when', ' they']
Prompt length: 8
Prompt as tokens: [50256, 2990, 821, 6600, 8073, 1978, 981, 484]
Prompt as words: ['<|endoftext|>', 'They', "'re", ' eating', ' dinner', ' together', ' while', ' they']
Prompt length: 8
Prompt as tokens: [50256, 2990, 389, 6600, 8073, 1978, 981, 484]
Prompt as words: ['<|endoftext|>', 'They', ' are', ' eating', ' dinner', ' together', ' while', ' they']
Prompt length: 8
Prompt as tokens: [50256, 1135, 821, 5059, 18529, 593, 290, 484]
Prompt as words: ['<|endoftext|>', 'We', "'re", ' driving', ' upt', 'own', ' and', ' they']
Prompt length: 8
Prompt as tokens: [50256, 1135, 389, 5059, 18529, 593, 290, 484]
Prompt as word

In [None]:
answer_tokens

tensor([[821, 389],
        [389, 821],
        [821, 389],
        [389, 821],
        [821, 389],
        [389, 821],
        [821, 389],
        [389, 821]], device='cuda:0')

In [None]:
#Here we verify the model can perform the task by measuring whether it continues the pattern

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    # Token ids index the vocabulary dimension, so gather picks out the logits of the two answer tokens
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)

    # Difference between the correct token (column 0) and the incorrect token (column 1)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens).item())

Per prompt logit difference: tensor([1.7542, 1.6924, 2.9144, 0.9969, 2.5970, 1.0220, 3.6359, 2.2313],
       device='cuda:0')
Average logit difference: 2.1054983139038086


Here $e^{2.1} \approx 8$, so the model is about 8x more likely to predict the correct token than the incorrect one. It is not as good at this task as it is at IOI, but I am still interested in investigating further.
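
As a quick check of that arithmetic, the sketch below converts the average logit difference printed above into a probability ratio. It relies only on the fact that softmax probabilities are proportional to the exponential of the logits; the variable names are mine.

```python
import math

# Average logit difference reported by the cell above.
average_logit_diff = 2.1054983139038086

# A logit difference d between two tokens corresponds to a probability
# ratio of exp(d), because softmax probabilities are proportional to exp(logit).
odds_ratio = math.exp(average_logit_diff)
print(f"Correct token is about {odds_ratio:.1f}x as likely as the incorrect one")
```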

In [None]:
#Projecting onto a direction in the residual stream and taking differences
answer_residual_directions = gpt_small.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


In [None]:
#Here we are verifying that direct logit attribution works and checking that we get the same results

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = original_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = original_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(orig_prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final residual stream shape: torch.Size([8, 8, 768])
Calculated average logit diff: 2.1054985523223877
Original logit difference: 2.1054983139038086


It works!
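
The two numbers agree because, once the LayerNorm scale is fixed, the unembedding is linear: taking the difference of two logits is the same as projecting the residual vector onto the difference of the two unembedding columns. A minimal sketch in plain Python, with a made-up 3-dimensional residual stream and 4-token vocabulary (none of these values come from the model):

```python
# Tiny stand-in for the unembedding: d_model = 3, vocabulary of 4 tokens.
# W_U[i][t] is the weight from residual dimension i to token t. Values invented.
W_U = [
    [0.5, -1.0, 2.0, 0.0],
    [1.5, 0.5, -0.5, 1.0],
    [-2.0, 1.0, 0.0, 0.5],
]
resid = [1.0, 2.0, -1.0]   # stands in for the LayerNorm-scaled final residual
correct, incorrect = 2, 1  # token indices of the two answers


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def column(matrix, t):
    return [row[t] for row in matrix]


# Route 1: compute the full logits, then take the difference.
logits = [dot(resid, column(W_U, t)) for t in range(4)]
diff_from_logits = logits[correct] - logits[incorrect]

# Route 2: project the residual onto the logit difference direction.
direction = [c - i for c, i in zip(column(W_U, correct), column(W_U, incorrect))]
diff_from_direction = dot(resid, direction)

print(diff_from_logits, diff_from_direction)
```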

#Experiment 1: Logit Lens Analysis

In [None]:
#Projects a stack of residual stream components onto the logit difference direction
#Applied to the accumulated residual stream, this simulates deleting all following layers
#(known as the logit lens technique)
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache, prompts: List[str], logit_diff_directions: Float[torch.Tensor, "batch d_model"]) -> Float[torch.Tensor, "components"]:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)

In [None]:
#Logit difference from the accumulated residual stream at the start of each layer and between attention and MLP (incl_mid=True)
accumulated_residual, labels = original_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, original_cache, orig_prompts, logit_diff_directions)
line(logit_lens_logit_diffs, x=np.arange(gpt_small.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

*   Seems like the model is totally unable to do the task until layer 7



#Experiment 1: Layer Attribution



<a name="E1:LA"></a>

In [None]:
per_layer_residual, labels = original_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, original_cache, orig_prompts, logit_diff_directions)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

*   Looks like after layer 6 the layers become way more important, in particular layer 7 attention, layer 8 MLP, and layer 10 MLP
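
The per-layer chart and the accumulated chart from the logit lens section are two views of the same numbers: because the projection onto the logit difference direction is linear for a fixed LayerNorm scale, the accumulated curve is the running sum of the per-component contributions. A small sketch with invented contributions (the labels are only loosely modelled on the ones TransformerLens returns):

```python
from itertools import accumulate

# Invented per-component contributions to the logit difference, in the order
# the components write to the residual stream.
labels = ["embed", "0_attn_out", "0_mlp_out", "1_attn_out", "1_mlp_out"]
per_component = [0.0, 0.25, -0.5, 1.5, 0.75]

# Running sum of the contributions gives the accumulated curve.
accumulated = list(accumulate(per_component))

for label, part, total in zip(labels, per_component, accumulated):
    print(f"{label:>12}: contributes {part:+.2f}, running total {total:+.2f}")
```

Read this way, a jump in the accumulated chart between two points is exactly the contribution of the component sitting between them.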


#Experiment 1: Analyzing Head Attribution
<a name="E1:AH"></a>

In [None]:
#Get each head's output into the residual stream at the final position
per_head_residual, labels = original_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)

#Project each head's output onto the logit difference direction to get its direct contribution
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, original_cache, orig_prompts, logit_diff_directions)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=gpt_small.cfg.n_layers, head_index=gpt_small.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


*   Here we can see that there are notable heads in three of the layers (L7, L8, L10), interestingly the same layers that were most notable in the layer analysis above. L7 H11 has by far the largest contribution in this chart, which is why layer 7 was the only spike in that graph attributable to an attention layer. Seems like this attention head is crucial to the task!


#Experiment 1: Attention Pattern Analysis

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: Optional[ActivationCache]=None,
    local_tokens: Optional[torch.Tensor]=None,
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = original_cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = tokens[0]

    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        layer = head // gpt_small.cfg.n_heads
        head_index = head % gpt_small.cfg.n_heads

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")

    str_tokens = gpt_small.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)

    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

<a name="E1:HP"></a>

In [None]:
top_k = 3
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

pysvelte components appear to be unbuilt or stale
Running npm install...
Building pysvelte components with webpack...


In [None]:
top_positive_logit_attr_head_values = torch.topk(per_head_logit_diffs.flatten(), k=top_k).values
top_positive_logit_attr_head_values

tensor([0.4068, 0.1754, 0.1146], device='cuda:0')

###Attention pattern observations
*   Two of the top three positive difference heads seem to be doing similar things in different layers (L7 H11, L10 H1)
*   It looks as though these heads are both copying information from the position of the 're token to the last token position
*   Attention Head 11 in layer 7 is by far the most crucial head though, as its attribution value is 0.4068 while the next most important value is 0.1754!
*   Personal note: Must remember that more info is stored in these token positions!
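
A note on indexing: the head heatmap is flattened before `torch.topk`, so the indices it returns are flat indices in [0, n_layers * n_heads), and `visualize_attention_patterns` converts them back with integer division and modulo. The sketch below spells that conversion out; the two helper names are hypothetical.

```python
N_HEADS = 12  # heads per layer in GPT2-Small


def flat_to_layer_head(flat_index, n_heads=N_HEADS):
    """Convert a flat head index back into (layer, head)."""
    return divmod(flat_index, n_heads)


def layer_head_to_flat(layer, head, n_heads=N_HEADS):
    """Convert (layer, head) into the flat index used after .flatten()."""
    return layer * n_heads + head


flat = layer_head_to_flat(7, 11)
print(flat, flat_to_layer_head(flat))
```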






#Experiment 1: Generating corrupted prompts for ROME-like analysis

In [None]:
corrupted_prompts = []
for i in range(0, len(orig_prompts), 2):
    corrupted_prompts.append(orig_prompts[i+1])
    corrupted_prompts.append(orig_prompts[i])
corrupted_tokens = gpt_small.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = gpt_small.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-2.1055, device='cuda:0')
Clean Average Logit Diff tensor(2.1055, device='cuda:0')
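
The corrupted average is exactly the negative of the clean average, and that follows from the construction: each prompt is swapped with its partner while the answer pairs stay in place, and the partner's answer pair is the same two tokens in the opposite order. A plain Python sketch of that bookkeeping, with invented per-prompt logit differences and hypothetical helper names:

```python
def swap_adjacent_pairs(items):
    """Swap items 0<->1, 2<->3, ... (expects an even number of items)."""
    swapped = []
    for i in range(0, len(items), 2):
        swapped.extend([items[i + 1], items[i]])
    return swapped


def mean(xs):
    return sum(xs) / len(xs)


prompts = ["We're ...", "We are ...", "They're ...", "They are ..."]
print(swap_adjacent_pairs(prompts))

# Invented per-prompt logit diffs (correct minus incorrect) on the clean run.
clean_diffs = [1.5, 2.0, 3.0, 1.0]

# Slot i now holds its partner's prompt but is scored with its own answer
# pair, which is the partner's pair reversed, so the sign flips.
corrupted_diffs = [-d for d in swap_adjacent_pairs(clean_diffs)]

print(mean(clean_diffs), mean(corrupted_diffs))
```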


In [None]:
tokens

tensor([[50256,  1135,   821,  5742,   674,  1545,   618,   484],
        [50256,  1135,   389,  5742,   674,  1545,   618,   484],
        [50256,  2990,   821,  6600,  8073,  1978,   981,   484],
        [50256,  2990,   389,  6600,  8073,  1978,   981,   484],
        [50256,  1135,   821,  5059, 18529,   593,   290,   484],
        [50256,  1135,   389,  5059, 18529,   593,   290,   484],
        [50256,  2990,   821, 11029,  1165,  2582,   290,   484],
        [50256,  2990,   389, 11029,  1165,  2582,   290,   484]],
       device='cuda:0')

In [None]:
corrupted_tokens

tensor([[50256,  1135,   389,  5742,   674,  1545,   618,   484],
        [50256,  1135,   821,  5742,   674,  1545,   618,   484],
        [50256,  2990,   389,  6600,  8073,  1978,   981,   484],
        [50256,  2990,   821,  6600,  8073,  1978,   981,   484],
        [50256,  1135,   389,  5059, 18529,   593,   290,   484],
        [50256,  1135,   821,  5059, 18529,   593,   290,   484],
        [50256,  2990,   389, 11029,  1165,  2582,   290,   484],
        [50256,  2990,   821, 11029,  1165,  2582,   290,   484]],
       device='cuda:0')

#Experiment 1: Patching from Residual Stream

In [None]:
#Function to patch residual components
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff, original_average_logit_diff, corrupted_average_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)

patched_residual_stream_diff = torch.zeros(gpt_small.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)

#For each layer
for layer in range(gpt_small.cfg.n_layers):
    #For each position within a layer
    for position in range(tokens.shape[1]):
        #Fix params of patch_residual_component
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=original_cache)

        #Run the model with the patch resid stream
        patched_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("resid_pre", layer),
                hook_fn)],
            return_type="logits"
        )
        #Get logit difference
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        #Normalize them and set to the final residual stream differences
        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff, original_average_logit_diff, corrupted_average_logit_diff)

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt_small.to_str_tokens(tokens[0]))]
imshow(patched_residual_stream_diff, x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})

*    Around layers 7-8 there is a switch from the largest difference occurring at the 're token to the largest difference occurring at the final token position. This suggests that the information from that token was copied around layer 7, which aligns with our head analysis
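
For reading the patching heatmaps, it helps to keep the normalization in mind: 0 means the patch did no better than the corrupted run and 1 means clean performance was fully recovered. The sketch below reproduces the same formula as `normalize_patched_logit_diff` with the rounded averages printed earlier; `normalized_recovery` is a hypothetical name and the patched values are chosen just to hit the reference points.

```python
def normalized_recovery(patched, clean, corrupted):
    """0 = no better than corrupted, 1 = clean performance fully recovered."""
    return (patched - corrupted) / (clean - corrupted)


clean = 2.1055       # average clean logit diff (rounded) from the run above
corrupted = -2.1055  # average corrupted logit diff (rounded) from the run above

for patched in (corrupted, 0.0, clean):
    print(patched, normalized_recovery(patched, clean, corrupted))
```

Because the clean and corrupted averages are symmetric around zero here, a patched logit difference of 0 corresponds to a normalized value of one half.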

#Experiment 1: Patching from Attention and MLP Layers


In [None]:
patched_attn_diff = torch.zeros(gpt_small.cfg.n_layers, tokens.shape[1], device="cuda", dtype=torch.float32)
patched_mlp_diff = torch.zeros(gpt_small.cfg.n_layers, tokens.shape[1], device="cuda", dtype=torch.float32)
for layer in range(gpt_small.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=original_cache)
        patched_attn_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("attn_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(patched_attn_logits, answer_tokens)
        patched_mlp_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("mlp_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(patched_mlp_logits, answer_tokens)

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff,original_average_logit_diff, corrupted_average_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff, original_average_logit_diff, corrupted_average_logit_diff)

<a name="E1:ALP"></a>

In [None]:
imshow(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})

<a name="E1:MLPP"></a>

In [None]:
imshow(patched_mlp_diff, x=prompt_position_labels, title="Logit Difference From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

Here we can see that, in contrast to the IOI task, a few MLP layers seem to matter a lot for prediction. (The color scale may be deceiving because MLP0 is super important, but the patched values are around the same logit difference.) Important MLP layers/positions are ('re L1: 0.229, 're L3: 0.1448, they L8: 0.20344, they L10: 0.1409).

#Experiment 1: Neuron Patching Set Up

In [None]:
def patch_neuron_activation(
    corrupted_residual_component: TT["batch", "pos", "d_mlp"],
    hook,
    neuron,
    clean_cache):
    corrupted_residual_component[:, :, neuron] = clean_cache[hook.name][:, :, neuron]
    return corrupted_residual_component

patched_neurons_normalized_improvement_8 = torch.zeros(gpt_small.cfg.d_mlp, device=device, dtype=torch.float32)
layer = 8
max_neurons = 3072
for neuron in range(gpt_small.cfg.d_mlp)[:max_neurons]:
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=original_cache)
    patched_neuron_logits = gpt_small.run_with_hooks(
        corrupted_tokens,
        fwd_hooks = [("blocks.8.mlp.hook_post", hook_fn)],
        return_type="logits"
    )
    patched_neuron_correct_incorrect_logit_diff = logits_to_ave_logit_diff(patched_neuron_logits, answer_tokens)

    patched_neurons_normalized_improvement_8[neuron] = normalize_patched_logit_diff(patched_neuron_correct_incorrect_logit_diff, original_average_logit_diff, corrupted_average_logit_diff)

In [None]:
patched_neurons_normalized_improvement_10 = torch.zeros(gpt_small.cfg.d_mlp, device=device, dtype=torch.float32)
layer = 10
max_neurons = gpt_small.cfg.d_mlp  # 3072 neurons per MLP layer in GPT-2 small
for neuron in range(gpt_small.cfg.d_mlp)[:max_neurons]:
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=original_cache)
    patched_neuron_logits = gpt_small.run_with_hooks(
        corrupted_tokens,
        fwd_hooks = [("blocks.10.mlp.hook_post", hook_fn)],
        return_type="logits"
    )
    patched_neuron_correct_incorrect_logit_diff_10 = logits_to_ave_logit_diff(patched_neuron_logits, answer_tokens)

    patched_neurons_normalized_improvement_10[neuron] = normalize_patched_logit_diff(patched_neuron_correct_incorrect_logit_diff_10, original_average_logit_diff, corrupted_average_logit_diff)

#Experiment 1: Single Neuron Patching in Layer 8 and Layer 10

<a name="E1:NP8"></a>

In [None]:
line(patched_neurons_normalized_improvement_8[:max_neurons], x=list(range(len(patched_neurons_normalized_improvement_8))), title="Logit Difference From Patched Neurons in MLP Layer 8", labels={"x":"neuron", "y":"Patch Improvement"})

<a name="E1:NP10"></a>

In [None]:
line(patched_neurons_normalized_improvement_10[:max_neurons], x=list(range(len(patched_neurons_normalized_improvement_10))), title="Logit Difference From Patched Neurons in MLP Layer 10", labels={"x":"neuron", "y":"Patch Improvement"})

*   For layer 8, neuron 2744 seems to be important when removed.
*   For layer 10, neuron 1063 seems to be important when removed, and removing neuron 2193 actually improves performance.
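
Reading the outliers off a line chart with thousands of points is error-prone, so a small ranking helper can be useful. The sketch below is plain Python over a toy list; `rank_neurons` and `toy_scores` are hypothetical names and values, not results from this notebook. On the real tensors, `torch.topk` would do the same job.

```python
def rank_neurons(scores, k=2):
    """Return the k largest and k smallest (index, score) pairs."""
    indexed = sorted(enumerate(scores), key=lambda pair: pair[1])
    most_negative = indexed[:k]          # smallest scores first
    most_positive = indexed[-k:][::-1]   # largest scores first
    return most_positive, most_negative


# Toy patch-improvement scores for six neurons (made up for illustration).
toy_scores = [0.01, -0.02, 0.30, 0.00, -0.15, 0.02]

most_positive, most_negative = rank_neurons(toy_scores)
print("largest improvement:", most_positive)
print("largest degradation:", most_negative)
```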



<a name="credit"></a>
#Credit where credit is due
Thanks to clementneo for the code! This is from the mechanistic interpretability hackathon run by Alignment Jams. I stopped by briefly to see the progress of my uni group (umich) and afterwards checked out this really cool exploration, "We Discovered An Neuron". I used their code directly to narrow down which neurons to look at, recognizing that this is probably not the best approach to determining important neurons.

#Experiment 1: Verifying Single Neuron Patching in Layer 8 and Layer 10 results with Neuroscope

In [None]:
from transformer_lens.utils import to_numpy
def get_neuron_acts(text, layer, neuron_index):
    # Hacky way to get state out of a single hook - we create a dict and write to it from within the hook.
    cache = {}
    def caching_hook(act, hook):
        cache["activation"] = act[0, :, neuron_index]

    gpt_small.run_with_hooks(
        text, fwd_hooks=[(f"blocks.{layer}.mlp.hook_post", caching_hook)]
    )
    return to_numpy(cache["activation"])

In [None]:
# This is some CSS (which tells the browser how to style elements) to give each token a thin gray border, making token separation easy to see
style_string = """<style>
    span.token {
        border: 1px solid rgb(123, 123, 123)
        }
    </style>"""

def calculate_color(val, max_val, min_val):
    # Hacky code that takes in a value val in range [min_val, max_val], normalizes it to [0, 1] and returns a color which interpolates between slightly off-white and red (0 = white, 1 = red)
    # We return a string of the form "rgb(240, 240, 240)" which is a color CSS knows
    normalized_val = (val - min_val) / (max_val - min_val)
    return f"rgb(240, {240*(1-normalized_val)}, {240*(1-normalized_val)})"


def basic_neuron_vis(text, layer, neuron_index, max_val=None, min_val=None):
    """
    text: The text to visualize
    layer: The layer index
    neuron_index: The neuron index
    max_val: The top end of our activation range, defaults to the maximum activation
    min_val: The bottom end of our activation range, defaults to the minimum activation

    Returns a string of HTML that displays the text with each token colored according to its activation

    Note: It's useful to be able to input a fixed max_val and min_val, because otherwise the colors will change as you edit the text, which is annoying.
    """
    if layer is None:
        return "Please select a Layer"
    if neuron_index is None:
        return "Please select a Neuron"
    acts = get_neuron_acts(text, layer, neuron_index)
    act_max = acts.max()
    act_min = acts.min()
    # Defaults to the max and min of the activations
    if max_val is None:
        max_val = act_max
    if min_val is None:
        min_val = act_min
    # We want to make a list of HTML strings to concatenate into our final HTML string
    # We first add the style to make each token element have a nice border
    htmls = [style_string]
    # We then add some text to tell us what layer and neuron we're looking at - we're just dealing with strings and can use f-strings as normal
    # h4 means "small heading"
    htmls.append(f"<h4>Layer: <b>{layer}</b>. Neuron Index: <b>{neuron_index}</b></h4>")
    # We then add a line telling us the limits of our range
    htmls.append(
        f"<h4>Max Range: <b>{max_val:.4f}</b>. Min Range: <b>{min_val:.4f}</b></h4>"
    )
    # If we added a custom range, print a line telling us the range of our activations too.
    if act_max != max_val or act_min != min_val:
        htmls.append(
            f"<h4>Custom Range Set. Max Act: <b>{act_max:.4f}</b>. Min Act: <b>{act_min:.4f}</b></h4>"
        )
    # Convert the text to a list of tokens
    str_tokens = gpt_small.to_str_tokens(text)
    for tok, act in zip(str_tokens, acts):
        # A span is an HTML element that lets us style a part of a string (and remains on the same line by default)
        # We set the background color of the span to be the color we calculated from the activation
        # We set the contents of the span to be the token
        htmls.append(
            f"<span class='token' style='background-color:{calculate_color(act, max_val, min_val)}' >{tok}</span>"
        )

    return "".join(htmls)

In [None]:
# The function outputs a string of HTML
default_layer = 8
default_neuron_index = 2744
default_max_val = 4.0
default_min_val = 0.0
default_text = "We're helping our friend when they"
default_html_string = basic_neuron_vis(
    default_text,
    default_layer,
    default_neuron_index,
    max_val=default_max_val,
    min_val=default_min_val,
    )

# IPython lets us display HTML
print("Displayed HTML")
display(HTML(default_html_string))

# We can also print the string directly
print("HTML String - it's just raw HTML code!")
print(default_html_string)

Displayed HTML


HTML String - it's just raw HTML code!
<style> 
    span.token {
        border: 1px solid rgb(123, 123, 123)
        } 
    </style><h4>Layer: <b>8</b>. Neuron Index: <b>2744</b></h4><h4>Max Range: <b>4.0000</b>. Min Range: <b>0.0000</b></h4><h4>Custom Range Set. Max Act: <b>1.9648</b>. Min Act: <b>-0.0916</b></h4><span class='token' style='background-color:rgb(240, 242.71859161555767, 242.71859161555767)' ><|endoftext|></span><span class='token' style='background-color:rgb(240, 212.48052656650543, 212.48052656650543)' >We</span><span class='token' style='background-color:rgb(240, 244.68825057148933, 244.68825057148933)' >'re</span><span class='token' style='background-color:rgb(240, 243.7611535191536, 243.7611535191536)' > helping</span><span class='token' style='background-color:rgb(240, 245.49718603491783, 245.49718603491783)' > our</span><span class='token' style='background-color:rgb(240, 243.67072261869907, 243.67072261869907)' > friend</span><span class='token' style='backgroun

<a name="E1:NI2744"></a>
NI 2744: Corresponding [neuroscope page](https://neuroscope.io/gpt2-small/8/2744.html)
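
The color values in the raw HTML string can be turned back into activations, which is a handy sanity check on the visualization. The sketch below assumes the fixed range used in the cell above (`max_val = 4.0`, `min_val = 0.0`); the helper names `activation_to_channel` and `channel_to_activation` are hypothetical and not part of the notebook.

```python
def activation_to_channel(val, max_val, min_val):
    """Green/blue channel for an activation: 240 at min_val, 0 at max_val."""
    normalized = (val - min_val) / (max_val - min_val)
    return 240 * (1 - normalized)


def channel_to_activation(channel, max_val, min_val):
    """Invert activation_to_channel to recover the activation."""
    normalized = 1 - channel / 240
    return min_val + normalized * (max_val - min_val)


# Channel values copied from the printed HTML for layer 8, neuron 2744.
we_channel = 212.48052656650543   # token "We"
our_channel = 245.49718603491783  # token " our"

print(channel_to_activation(we_channel, 4.0, 0.0))
print(channel_to_activation(our_channel, 4.0, 0.0))
```

Channel values above 240 correspond to activations below `min_val`, which is why some tokens in the output have out-of-range colors. The second value recovers the "Min Act" reported in the header of the HTML.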


In [None]:
# The function outputs a string of HTML
default_layer = 10
default_neuron_index = 1063
default_max_val = 4.0
default_min_val = 0.0
default_text = "We're helping our friend when they"
default_html_string = basic_neuron_vis(
    default_text,
    default_layer,
    default_neuron_index,
    max_val=default_max_val,
    min_val=default_min_val,
    )

# IPython lets us display HTML
print("Displayed HTML")
display(HTML(default_html_string))

# We can also print the string directly
print("HTML String - it's just raw HTML code!")
print(default_html_string)

Displayed HTML


HTML String - it's just raw HTML code!
<style> 
    span.token {
        border: 1px solid rgb(123, 123, 123)
        } 
    </style><h4>Layer: <b>10</b>. Neuron Index: <b>1063</b></h4><h4>Max Range: <b>4.0000</b>. Min Range: <b>0.0000</b></h4><h4>Custom Range Set. Max Act: <b>3.4040</b>. Min Act: <b>-0.1640</b></h4><span class='token' style='background-color:rgb(240, 221.35707557201385, 221.35707557201385)' ><|endoftext|></span><span class='token' style='background-color:rgb(240, 109.4798469543457, 109.4798469543457)' >We</span><span class='token' style='background-color:rgb(240, 248.97878050804138, 248.97878050804138)' >'re</span><span class='token' style='background-color:rgb(240, 249.64562863111496, 249.64562863111496)' > helping</span><span class='token' style='background-color:rgb(240, 243.96182656288147, 243.96182656288147)' > our</span><span class='token' style='background-color:rgb(240, 247.89610415697098, 247.89610415697098)' > friend</span><span class='token' style='backgrou

<a name="E1:NI1063"></a>
NI 1063: Corresponding [neuroscope page](https://neuroscope.io/gpt2-small/10/1063.html)


In [None]:
# The function outputs a string of HTML
default_layer = 10
default_neuron_index = 2193
default_max_val = 4.0
default_min_val = 0.0
default_text = "We're helping our friend when they"
default_html_string = basic_neuron_vis(
    default_text,
    default_layer,
    default_neuron_index,
    max_val=default_max_val,
    min_val=default_min_val,
)

# IPython lets us display HTML
print("Displayed HTML")
display(HTML(default_html_string))

# We can also print the string directly
print("HTML String - it's just raw HTML code!")
print(default_html_string)

Displayed HTML


HTML String - it's just raw HTML code!
<style> 
    span.token {
        border: 1px solid rgb(123, 123, 123)
        } 
    </style><h4>Layer: <b>10</b>. Neuron Index: <b>2193</b></h4><h4>Max Range: <b>4.0000</b>. Min Range: <b>0.0000</b></h4><h4>Custom Range Set. Max Act: <b>1.2683</b>. Min Act: <b>-0.1700</b></h4><span class='token' style='background-color:rgb(240, 246.77837312221527, 246.77837312221527)' ><|endoftext|></span><span class='token' style='background-color:rgb(240, 214.46180284023285, 214.46180284023285)' >We</span><span class='token' style='background-color:rgb(240, 240.1641828753054, 240.1641828753054)' >'re</span><span class='token' style='background-color:rgb(240, 241.43916461616755, 241.43916461616755)' > helping</span><span class='token' style='background-color:rgb(240, 240.1729013537988, 240.1729013537988)' > our</span><span class='token' style='background-color:rgb(240, 250.20109981298447, 250.20109981298447)' > friend</span><span class='token' style='background

<a name="E1:NI2193"></a>
NI 2193: Corresponding [neuroscope page](https://neuroscope.io/gpt2-small/10/2193.html)


<a name="neurExp"></a>
##Observations/Explanation
1.   3072 neurons is a lot to inspect by hand. After using the neuroscope tool for a while, I decided to employ a less rigorous technique: activation patching of single neurons. The hope was that this would narrow down the set of neurons worth investigating in neuroscope. I was surprised to find that a single neuron in layer 8 (2744) and two neurons in layer 10 (1063, 2193) seemed able to recover or destroy performance on their own.
2.   Using the interactive neuroscope tool, we can see these neurons activate on pronouns before ' and on the last token. Looking at the neuroscope pages themselves, this isn't the only thing these neurons do, and they sometimes miss pronouns in the text as well. All of this is to say that the behavior is pretty weird and definitely not the result of this set of neurons alone.
3.   In terms of patching neurons, I am unsure that all the behavior can be explained by these single removals. From the patching technique paired with inspecting neuroscope, these neurons are probably important and behave the way I would expect them to, but I am unconvinced that the entire model behavior can be explained by their activations. I think it is far more likely that a more complex mechanism in this MLP explains the behavior. What if there are redundant neurons that are essential to the mechanism? Removing any one of them would then seem insignificant, yet we would be missing the bigger picture of how the model actually represents things.
4.   I am definitely missing some interesting neurons by doing this, and with more time I would explore more of them.
5.   I would be incredibly excited to brainstorm new, more rigorous techniques for exploring MLPs. Super curious to know what's actually going on!
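
The worry about redundant neurons can be made concrete with a toy example. This is not GPT2-Small: `toy_output` is a hypothetical readout that fires if either of two neurons is active, used only to show why single-neuron interventions can miss a mechanism.

```python
def toy_output(neuron_a, neuron_b):
    """Toy readout that is fully driven by whichever neuron is more active."""
    return max(neuron_a, neuron_b)


clean = toy_output(1.0, 1.0)

# Knock out one neuron at a time: the other one covers for it.
ablate_a = toy_output(0.0, 1.0)
ablate_b = toy_output(1.0, 0.0)
single_effects = [clean - ablate_a, clean - ablate_b]

# Knock out both together: the behavior disappears.
ablate_both = toy_output(0.0, 0.0)
joint_effect = clean - ablate_both

print("single-neuron effects:", single_effects)
print("joint effect:", joint_effect)
```

Each neuron looks unimportant in isolation even though the pair carries the whole behavior, so a sweep over single neurons would report nothing here. Patching groups of neurons together is one way to look for this kind of structure.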



#Experiment 1: Further Head analysis

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache):

    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector

patched_head_z_diff = torch.zeros(gpt_small.cfg.n_layers, gpt_small.cfg.n_heads, device="cuda", dtype=torch.float32)

for layer in range(gpt_small.cfg.n_layers):
    for head_index in range(gpt_small.cfg.n_heads):
        #For every attention head patch clean head attention into dirty one
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=original_cache)
        patched_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("z", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )

        #Measure task performance by comparing difference in likelihood between tokens
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
        #Normalize task performance
        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff,  original_average_logit_diff, corrupted_average_logit_diff)

In [None]:
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", labels={"x":"Head", "y":"Layer"})

*   Again we can see that Layer 8 Head 2 and Layer 7 Head 11 have the biggest differences between logits


#Experiment 1: Head value patching

In [None]:
#Investigating the OV Circuit
patched_head_v_diff = torch.zeros(gpt_small.cfg.n_layers, gpt_small.cfg.n_heads, device="cuda", dtype=torch.float32)

for layer in range(gpt_small.cfg.n_layers):
    for head_index in range(gpt_small.cfg.n_heads):
        #For each head in every layer
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=original_cache)
        #Patch clean value vector into dirty one and see how its logits change
        patched_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("v", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )
        #Measure performance difference when clean is patched
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
        #Normalize that difference
        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff,  original_average_logit_diff, corrupted_average_logit_diff)

In [None]:
imshow(patched_head_v_diff, title="Logit Difference From Patched Head Value", labels={"x":"Head", "y":"Layer"})

<a name="E1:VPP"></a>

<a name="E1:VP"></a>

In [None]:
head_labels = [f"L{l}H{h}" for l in range(gpt_small.cfg.n_layers) for h in range(gpt_small.cfg.n_heads)]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(gpt_small.cfg.n_layers), "layer -> (layer head)", head=gpt_small.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching")

*   Most of the logit difference is likely due to the values: patching a head's values gives results nearly identical to patching its output.
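
One way to read the scatter plot is to ask which heads sit close to the diagonal, where the value patch and the output patch agree. The sketch below does this on made-up numbers; the head labels `A` to `D`, the scores, and the thresholds `tolerance` and `min_effect` are all hypothetical and not taken from this notebook.

```python
def classify_heads(scores, tolerance=0.05, min_effect=0.05):
    """Split heads with a noticeable output-patch effect by whether the value patch matches it.

    scores maps a head label to (value_patch, output_patch).
    """
    value_driven, other = [], []
    for label, (value_patch, output_patch) in scores.items():
        if abs(output_patch) < min_effect:
            continue  # head has no noticeable effect either way
        if abs(value_patch - output_patch) <= tolerance:
            value_driven.append(label)
        else:
            other.append(label)
    return value_driven, other


# Toy (value_patch, output_patch) pairs, made up for illustration.
toy_scores = {
    "A": (0.40, 0.42),
    "B": (0.18, 0.20),
    "C": (0.02, 0.15),
    "D": (0.00, 0.01),
}

value_driven, other = classify_heads(toy_scores)
print("near the diagonal:", value_driven)
print("off the diagonal:", other)
```

Heads that land off the diagonal are the ones where something other than the values, such as the attention pattern, would have to account for the effect.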



#Experiment 1: Head attention pattern patching

##Investigating the QK circuit

In [None]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    hook,
    head_index,
    clean_cache):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][:, head_index, :, :]
    return corrupted_head_pattern

patched_head_attn_diff = torch.zeros(gpt_small.cfg.n_layers, gpt_small.cfg.n_heads, device="cuda", dtype=torch.float32)
for layer in range(gpt_small.cfg.n_layers):
    for head_index in range(gpt_small.cfg.n_heads):
        #For every head in every layer
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=original_cache)
        #Replace that head with corresponding attention pattern in clean one
        patched_logits = gpt_small.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("attn", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )
        #Measure performance
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
        #Normalize performance
        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff, original_average_logit_diff, corrupted_average_logit_diff)


<a name="E1:APP"></a>

In [None]:
imshow(patched_head_attn_diff, title="Logit Difference From Patched Head Pattern", labels={"x":"Head", "y":"Layer"})
head_labels = [f"L{l}H{h}" for l in range(gpt_small.cfg.n_layers) for h in range(gpt_small.cfg.n_heads)]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

*   Given the very small scale of the x-axis, attention pattern patching has much less effect than value patching, so the attention pattern is likely less important to the task than the values.

<a name="E2:DPA"></a>
#Experiment 2: Defining the prompts and answers


In [None]:
re_correct_answers = [("'re"," are") if x%2==0 else (" are","'re") for x in range(8)]
re_correct_answers

[("'re", ' are'),
 (' are', "'re"),
 ("'re", ' are'),
 (' are', "'re"),
 ("'re", ' are'),
 (' are', "'re"),
 ("'re", ' are'),
 (' are', "'re")]

In [None]:
#Prompts where 're is relating to the varying pronouns **Note the above experimentation was 're they
re_she_prompts = ["We're helping our friend when she",
                "We are helping our friend when she",
                "They're eating dinner together while she",
                "They are eating dinner together while she",
                "We're driving uptown and she",
                "We are driving uptown and she",
                "They're sleeping too soon and she",
                "They are sleeping too soon and she"]

re_he_prompts = ["We're helping our friend when he",
                "We are helping our friend when he",
                "They're eating dinner together while he",
                "They are eating dinner together while he",
                "We're driving uptown and he",
                "We are driving uptown and he",
                "They're sleeping too soon and he",
                "They are sleeping too soon and he"]

re_we_prompts = ["We're helping our friend when we",
                "We are helping our friend when we",
                "They're eating dinner together while we",
                "They are eating dinner together while we",
                "We're driving uptown and we",
                "We are driving uptown and we",
                "They're sleeping too soon and we",
                "They are sleeping too soon and we"]

#Prompts where 's is relating to the varying pronouns
s_she_prompts = ["He's helping our friend when she",
                "He is helping our friend when she",
                "She's eating dinner together while she",
                "She is eating dinner together while she",
                "He's driving uptown and she",
                "He is driving uptown and she",
                "She's sleeping too soon and she",
                "She is sleeping too soon and she"]

s_he_prompts = ["He's helping our friend when he",
                "He is helping our friend when he",
                "She's eating dinner together while he",
                "She is eating dinner together while he",
                "He's driving uptown and he",
                "He is driving uptown and he",
                "She's sleeping too soon and he",
                "She is sleeping too soon and he"]

s_they_prompts = ["He's helping our friend when they",
                "He is helping our friend when they",
                "She's eating dinner together while they",
                "She is eating dinner together while they",
                "He's driving uptown and they",
                "He is driving uptown and they",
                "She's sleeping too soon and they",
                "She is sleeping too soon and they"]

s_we_prompts = ["He's helping our friend when we",
                "He is helping our friend when we",
                "She's eating dinner together while we",
                "She is eating dinner together while we",
                "He's driving uptown and we",
                "He is driving uptown and we",
                "She's sleeping too soon and we",
                "She is sleeping too soon and we"]

re_correct_answers = [("'re"," are") if x%2==0 else (" are","'re") for x in range(8)]
s_correct_answers = [("'s"," is") if x%2==0 else (" is","'s") for x in range(8)]

#Tokenizing the re answers
re_answer_tokens = []
for x in re_correct_answers:
  curr_tokens = []
  curr_tokens.append(gpt_small.to_single_token(x[0]))
  curr_tokens.append(gpt_small.to_single_token(x[1]))
  re_answer_tokens.append(curr_tokens)
re_answer_tokens = torch.tensor(re_answer_tokens).cuda()

#Tokenizing the s answers
s_answer_tokens = []
for x in s_correct_answers:
  curr_tokens = []
  curr_tokens.append(gpt_small.to_single_token(x[0]))
  curr_tokens.append(gpt_small.to_single_token(x[1]))
  s_answer_tokens.append(curr_tokens)
s_answer_tokens = torch.tensor(s_answer_tokens).cuda()

#Total experimental tokens
clean_prompts = [re_she_prompts, re_he_prompts, re_we_prompts, s_she_prompts, s_he_prompts, s_we_prompts, s_they_prompts]
clean_answer_tokens = [s_answer_tokens, s_answer_tokens, re_answer_tokens, s_answer_tokens, s_answer_tokens, re_answer_tokens, re_answer_tokens]
group_names = ["Current pairing: 're she", "Current Pairing: 're he", "Current Pairing: 're we", "Current Pairing: 's she", "Current Pairing: 's he",  "Current Pairing: 's we", "Current Pairing: 's they"]

clean_tokens = []
clean_cache = []
clean_logits = []

for group in clean_prompts:
  tokens = gpt_small.to_tokens(group, prepend_bos=True)
  clean_tokens.append(tokens.cuda())
  original_logits, cache = gpt_small.run_with_cache(tokens)
  clean_cache.append(cache)
  clean_logits.append(original_logits)

#Experiment 2: Evaluating performance on the task

In [None]:
group_clean_av_logit_diff = []

for group_ind in range(len(clean_logits)):
  print(group_names[group_ind])
  print("Per prompt logit difference:", logits_to_ave_logit_diff(clean_logits[group_ind], clean_answer_tokens[group_ind], per_prompt=True))
  group_clean_av_logit_diff.append(logits_to_ave_logit_diff(clean_logits[group_ind], clean_answer_tokens[group_ind]))
  print("Average logit difference:", logits_to_ave_logit_diff(clean_logits[group_ind], clean_answer_tokens[group_ind]).item())

Current pairing: 're she
Per prompt logit difference: tensor([ 1.9601,  0.5968,  2.2588,  1.0614,  2.6934, -0.1984,  3.0978,  1.2086],
       device='cuda:0')
Average logit difference: 1.5848093032836914
Current Pairing: 're he
Per prompt logit difference: tensor([ 2.0880,  0.7160,  2.1094,  1.4240,  2.9604, -0.1025,  3.2693,  1.3576],
       device='cuda:0')
Average logit difference: 1.7277706861495972
Current Pairing: 're we
Per prompt logit difference: tensor([1.4513, 1.3573, 2.8709, 0.3862, 3.2182, 0.9171, 2.8223, 1.0932],
       device='cuda:0')
Average logit difference: 1.7645436525344849
Current Pairing: 's she
Per prompt logit difference: tensor([2.6008, 1.3898, 2.9350, 1.9743, 3.5712, 1.7758, 3.8507, 1.6473],
       device='cuda:0')
Average logit difference: 2.4681220054626465
Current Pairing: 's he
Per prompt logit difference: tensor([2.6615, 1.0811, 2.7025, 2.5058, 4.0128, 0.7606, 3.6733, 2.0657],
       device='cuda:0')
Average logit difference: 2.4329159259796143
Current P

*   The model can generally perform the task, aside from two prompts with a negative logit difference (index 5 in their groups): "We are driving uptown and she" (-0.1984) and "We are driving uptown and he" (-0.1025).
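
To see which prompts fail, it helps to line the printed per-prompt values up with the prompt list. The sketch below does this for the 're she group, with the values copied from the output above; the variable name `failing` is a hypothetical helper, not part of the notebook.

```python
re_she_prompts = [
    "We're helping our friend when she",
    "We are helping our friend when she",
    "They're eating dinner together while she",
    "They are eating dinner together while she",
    "We're driving uptown and she",
    "We are driving uptown and she",
    "They're sleeping too soon and she",
    "They are sleeping too soon and she",
]

# Per-prompt logit differences printed for the 're she group.
per_prompt_logit_diff = [1.9601, 0.5968, 2.2588, 1.0614, 2.6934, -0.1984, 3.0978, 1.2086]

# A negative value means the model preferred the wrong form for that prompt.
failing = [
    (index, prompt, diff)
    for index, (prompt, diff) in enumerate(zip(re_she_prompts, per_prompt_logit_diff))
    if diff < 0
]

for index, prompt, diff in failing:
    print(f"prompt {index}: {prompt!r} has logit diff {diff}")
```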



#Experiment 2: Getting directions in activation space

In [None]:
group_logit_diff_directions = []
for clean_tkns in clean_answer_tokens:
  answer_residual_directions = gpt_small.tokens_to_residual_directions(clean_tkns)
  logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
  group_logit_diff_directions.append(logit_diff_directions)

#Experiment 2: Logit Lens

In [None]:
#Checking logits before MLP after Attention Head
for group_ind in range(len(clean_cache)):
  print(group_names[group_ind])
  accumulated_residual, labels = clean_cache[group_ind].accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
  logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, clean_cache[group_ind], clean_prompts[group_ind],group_logit_diff_directions[group_ind] )
  line(logit_lens_logit_diffs, x=np.arange(gpt_small.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


<a name="E2:LA"></a>
#Experiment 2: Layer Attribution

In [None]:
for group_ind in range(len(clean_cache)):
  print(group_names[group_ind])
  per_layer_residual, labels = clean_cache[group_ind].decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
  per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, clean_cache[group_ind], clean_prompts[group_ind], group_logit_diff_directions[group_ind])
  line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


*   Interestingly, the layer 10 MLP seems to matter a lot for he and she, but not for the other pronouns.
*   The three previously observed layers also seem to matter the most across all pairings (L7 attn, L8 MLP, L10 MLP).





<a name="E2:HA"></a>
#Experiment 2: Head Attribution

In [None]:
group_per_head_logit_diffs = []

for group_ind in range(len(clean_cache)):
  print(group_names[group_ind])
  per_head_residual, labels = clean_cache[group_ind].stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
  per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, clean_cache[group_ind], clean_prompts[group_ind], group_logit_diff_directions[group_ind])
  per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=gpt_small.cfg.n_layers, head_index=gpt_small.cfg.n_heads)
  imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")
  group_per_head_logit_diffs.append(per_head_logit_diffs)

Current pairing: 're she
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 're he
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 're we
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 's she
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 's he
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 's we
Tried to stack head results when they weren't cached. Computing head results now


Current Pairing: 's they
Tried to stack head results when they weren't cached. Computing head results now


*   L7 H11 seems to be the most important head for every pairing. Beyond that head, and to a lesser extent L8 H2 and L10 H1, there is a good amount of variation!


#Experiment 2: Attention Analysis

In [None]:
top_k = 3
for group_ind in range(len(group_per_head_logit_diffs)):
  print(group_names[group_ind])
  top_positive_logit_attr_heads = torch.topk(group_per_head_logit_diffs[group_ind].flatten(), k=top_k).indices
  visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads", local_tokens=clean_tokens[group_ind][0])
  top_negative_logit_attr_heads = torch.topk(-group_per_head_logit_diffs[group_ind].flatten(), k=top_k).indices
  visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads", local_tokens=clean_tokens[group_ind][0])

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


###Observations
*   Top three positive heads for 're she, 're he, 're we, 're they, 's she, 's he: L7 H11, L8 H2, L10 H1
*   Top three positive heads for 's we, 's they: L7 H11, L8 H2, L11 H8
*   Why do the layer 11 attention heads matter? Maybe L11 H8 is relatively insignificant compared to L7 H11, as the majority of the work is done in the MLP layers?
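The (layer, head) labels above come from indices into the flattened [layer, head] grid that `torch.topk` returns. A small sketch of that conversion, assuming the row-major layout produced by the `einops.rearrange` call and GPT-2 small's 12 heads per layer (the helper names are hypothetical):

```python
N_HEADS = 12  # GPT-2 small has 12 attention heads per layer

def flat_index_to_layer_head(flat_index, n_heads=N_HEADS):
    """Map an index into the flattened [layer, head] grid back to (layer, head)."""
    layer, head = divmod(flat_index, n_heads)
    return layer, head

def format_heads(flat_indices, n_heads=N_HEADS):
    """Render flat indices as labels like 'L7 H11'."""
    labels = []
    for flat_index in flat_indices:
        layer, head = flat_index_to_layer_head(flat_index, n_heads)
        labels.append(f"L{layer} H{head}")
    return labels

# Flat indices chosen here to correspond to the heads named in the observations
example_labels = format_heads([95, 98, 121, 140])
print(example_labels)
```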




In [None]:
top_k = 3
for group_ind in range(len(group_per_head_logit_diffs)):
  print(group_names[group_ind])
  top_positive_logit_attr_heads = torch.topk(group_per_head_logit_diffs[group_ind].flatten(), k=top_k).values
  print(top_positive_logit_attr_heads)

Current pairing: 're she
tensor([0.1745, 0.1183, 0.0316], device='cuda:0')
Current Pairing: 're he
tensor([0.2091, 0.1181, 0.0367], device='cuda:0')
Current Pairing: 're we
tensor([0.3249, 0.1450, 0.0760], device='cuda:0')
Current Pairing: 's she
tensor([0.3514, 0.0919, 0.0861], device='cuda:0')
Current Pairing: 's he
tensor([0.3927, 0.0699, 0.0627], device='cuda:0')
Current Pairing: 's we
tensor([0.3725, 0.1641, 0.0793], device='cuda:0')
Current Pairing: 's they
tensor([0.5018, 0.1667, 0.0775], device='cuda:0')


*   The third-ranked head (L11 H8 for 's we and 's they) is not totally insignificant, but its contribution (about 0.08) is less than half that of the most important head (0.37 and 0.50 respectively), so I will proceed with the investigation.
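As a quick check on that comparison, here is a short sketch that computes the ratio of the third-ranked head to the top head for each pairing, using the values printed above (copied by hand, so only four decimal places):

```python
# Top-3 positive per-head logit-diff values, copied from the printed output above
top3_values = {
    "'re she": [0.1745, 0.1183, 0.0316],
    "'re he":  [0.2091, 0.1181, 0.0367],
    "'re we":  [0.3249, 0.1450, 0.0760],
    "'s she":  [0.3514, 0.0919, 0.0861],
    "'s he":   [0.3927, 0.0699, 0.0627],
    "'s we":   [0.3725, 0.1641, 0.0793],
    "'s they": [0.5018, 0.1667, 0.0775],
}

# Ratio of the third-ranked head to the top-ranked head, per pairing
third_to_top_ratio = {name: values[2] / values[0] for name, values in top3_values.items()}

for name, ratio in third_to_top_ratio.items():
    print(f"{name}: third head is {ratio:.0%} of the top head")
```

Note that which head is ranked third differs between pairings, so these ratios compare ranks rather than one fixed head.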


#Experiment 2: Corrupted Prompt Generation

In [None]:
group_corrupted_prompts = []
group_corrupted_tokens = []
group_corrupted_av_logit_diff = []

for group_ind in range(len(clean_prompts)):
  print(group_names[group_ind])
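  corrupted_prompts = []
  # Swap each adjacent pair of prompts (assumes an even number of prompts stored in pairs),
  # while clean_answer_tokens keeps its original order
  for i in range(0, len(clean_prompts[group_ind]), 2):
      corrupted_prompts.append(clean_prompts[group_ind][i+1])
      corrupted_prompts.append(clean_prompts[group_ind][i])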
  corrupted_tokens = gpt_small.to_tokens(corrupted_prompts, prepend_bos=True)
  corrupted_logits, corrupted_cache = gpt_small.run_with_cache(corrupted_tokens, return_type="logits")
  corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, clean_answer_tokens[group_ind])
  group_corrupted_prompts.append(corrupted_prompts)
  group_corrupted_tokens.append(corrupted_tokens)
  group_corrupted_av_logit_diff.append(corrupted_average_logit_diff)

  print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
  print("Clean Average Logit Diff", group_clean_av_logit_diff[group_ind])

Current pairing: 're she
Corrupted Average Logit Diff tensor(-1.5848, device='cuda:0')
Clean Average Logit Diff tensor(1.5848, device='cuda:0')
Current Pairing: 're he
Corrupted Average Logit Diff tensor(-1.7278, device='cuda:0')
Clean Average Logit Diff tensor(1.7278, device='cuda:0')
Current Pairing: 're we
Corrupted Average Logit Diff tensor(-1.7645, device='cuda:0')
Clean Average Logit Diff tensor(1.7645, device='cuda:0')
Current Pairing: 's she
Corrupted Average Logit Diff tensor(-2.4681, device='cuda:0')
Clean Average Logit Diff tensor(2.4681, device='cuda:0')
Current Pairing: 's he
Corrupted Average Logit Diff tensor(-2.4329, device='cuda:0')
Clean Average Logit Diff tensor(2.4329, device='cuda:0')
Current Pairing: 's we
Corrupted Average Logit Diff tensor(-2.0183, device='cuda:0')
Clean Average Logit Diff tensor(2.0183, device='cuda:0')
Current Pairing: 's they
Corrupted Average Logit Diff tensor(-2.2866, device='cuda:0')
Clean Average Logit Diff tensor(2.2866, device='cuda:0')
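The corrupted averages are exactly the negatives of the clean ones. Here is a minimal sketch of why that happens, assuming each adjacent pair of prompts has mirrored answer tokens (the correct answer for one prompt is the incorrect answer for the other). The helper names and the toy logits are made up for illustration.

```python
def swap_adjacent_pairs(prompts):
    """Return a copy of prompts with every adjacent pair swapped."""
    if len(prompts) % 2 != 0:
        raise ValueError("expected an even number of prompts stored in pairs")
    swapped = []
    for i in range(0, len(prompts), 2):
        swapped.extend([prompts[i + 1], prompts[i]])
    return swapped

def average_logit_diff(logits_per_prompt, answers):
    """Mean of logit(correct) - logit(incorrect) over the prompts."""
    diffs = [logits[correct] - logits[incorrect]
             for logits, (correct, incorrect) in zip(logits_per_prompt, answers)]
    return sum(diffs) / len(diffs)

toy_prompts = ["We're helping our friend while he",
               "We are helping our friend while he"]
# Made-up final-position logits for the two candidate continuations
toy_logits = {
    toy_prompts[0]: {"'s": 3.0, " is": 1.0},
    toy_prompts[1]: {"'s": 0.5, " is": 2.0},
}
# (correct, incorrect) answers stay aligned with the clean prompt order
toy_answers = [("'s", " is"), (" is", "'s")]

toy_clean = average_logit_diff([toy_logits[p] for p in toy_prompts], toy_answers)
toy_corrupted = average_logit_diff(
    [toy_logits[p] for p in swap_adjacent_pairs(toy_prompts)], toy_answers)
print(toy_clean, toy_corrupted)
```

Each corrupted prompt is scored against its partner's answers, so every individual logit difference flips sign and the average does too.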

#Experiment 2: Residual Stream Patching

In [None]:
group_patched_residual_stream_diff = []

for group_ind in range(len(clean_prompts)):
  patched_residual_stream_diff = torch.zeros(gpt_small.cfg.n_layers, clean_tokens[group_ind].shape[1], device="cuda", dtype=torch.float32)

  # Patch the clean residual stream into the corrupted run, one (layer, position) at a time
  for layer in range(gpt_small.cfg.n_layers):
    for position in range(clean_tokens[group_ind].shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache[group_ind])

        patched_logits = gpt_small.run_with_hooks(
            group_corrupted_tokens[group_ind],
            fwd_hooks = [(utils.get_act_name("resid_pre", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, clean_answer_tokens[group_ind])

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff, group_clean_av_logit_diff[group_ind], group_corrupted_av_logit_diff[group_ind])

  group_patched_residual_stream_diff.append(patched_residual_stream_diff)
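
`patch_residual_component` is defined earlier in the notebook. As a rough picture of what a hook of this kind does, here is a NumPy sketch under my own assumptions: activations shaped [batch, position, d_model], a hypothetical function name `patch_position`, and toy arrays in place of real activations.

```python
import numpy as np

def patch_position(corrupted_activation, clean_activation, pos):
    """Overwrite one position of the corrupted activation with the clean one.

    Both arrays are assumed to be shaped [batch, position, d_model].
    """
    patched = corrupted_activation.copy()
    patched[:, pos, :] = clean_activation[:, pos, :]
    return patched

# Toy activations: 2 prompts, 3 positions, d_model = 4
toy_clean_act = np.ones((2, 3, 4))
toy_corrupted_act = np.zeros((2, 3, 4))

toy_patched_act = patch_position(toy_corrupted_act, toy_clean_act, pos=1)
print(toy_patched_act[0])
```

Only the chosen position carries clean information forward, which is what lets the heatmaps localize where the relevant information sits in the residual stream.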

<a name="E2:RSP"></a>

In [None]:
for group_ind in range(len(clean_prompts)):
  print(group_names[group_ind])
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt_small.to_str_tokens(clean_tokens[group_ind][0]))]
  imshow(group_patched_residual_stream_diff[group_ind], x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


*   Across all of the pairings, the patched residual stream plots look largely the same layer by layer.
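
For reading these heatmaps, it helps to spell out the normalization. `normalize_patched_logit_diff` is defined earlier in the notebook; the sketch below assumes the common convention where 0 means "no better than the corrupted run" and 1 means "fully recovers the clean run". The function name here is hypothetical.

```python
def normalize_patched(patched_logit_diff, clean_logit_diff, corrupted_logit_diff):
    """Rescale so corrupted performance maps to 0 and clean performance maps to 1."""
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

# Average logit diffs for the 're she pairing, copied from the printed output above
clean_avg = 1.5848
corrupted_avg = -1.5848

print(normalize_patched(corrupted_avg, clean_avg, corrupted_avg))  # patch did nothing
print(normalize_patched(clean_avg, clean_avg, corrupted_avg))      # patch restored everything
print(normalize_patched(0.0, clean_avg, corrupted_avg))            # patch recovered half the gap
```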



#Experiment 2: MLP and Attention Layer Patching

In [None]:
group_patched_attn_diff = []
group_patched_mlp_diff = []

for group_ind in range(len(clean_prompts)):
  patched_attn_diff = torch.zeros(gpt_small.cfg.n_layers, clean_tokens[group_ind].shape[1], device="cuda", dtype=torch.float32)
  patched_mlp_diff = torch.zeros(gpt_small.cfg.n_layers, clean_tokens[group_ind].shape[1], device="cuda", dtype=torch.float32)
  for layer in range(gpt_small.cfg.n_layers):
      for position in range(clean_tokens[group_ind].shape[1]):
          hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache[group_ind])
          patched_attn_logits = gpt_small.run_with_hooks(
              group_corrupted_tokens[group_ind],
              fwd_hooks = [(utils.get_act_name("attn_out", layer),
                  hook_fn)],
              return_type="logits"
          )

          patched_attn_logit_diff = logits_to_ave_logit_diff(patched_attn_logits, clean_answer_tokens[group_ind])
          patched_mlp_logits = gpt_small.run_with_hooks(
              group_corrupted_tokens[group_ind],
              fwd_hooks = [(utils.get_act_name("mlp_out", layer),
                  hook_fn)],
              return_type="logits"
          )
          patched_mlp_logit_diff = logits_to_ave_logit_diff(patched_mlp_logits, clean_answer_tokens[group_ind])

          patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff, group_clean_av_logit_diff[group_ind], group_corrupted_av_logit_diff[group_ind])
          patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff, group_clean_av_logit_diff[group_ind], group_corrupted_av_logit_diff[group_ind])
  group_patched_attn_diff.append(patched_attn_diff)
  group_patched_mlp_diff.append(patched_mlp_diff)

<a name="E2:ALP"></a>

In [None]:
for group_ind in range(len(clean_prompts)):
  print(group_names[group_ind])
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt_small.to_str_tokens(clean_tokens[group_ind][0]))]
  imshow(group_patched_attn_diff[group_ind], x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


##Some interesting behaviors
*   For the 're he and 're she pairings, L8 is slightly more significant than L7, yet for some pairings like 's she and 's he L8 is much less significant than L7.
*   Possible explanation: could this be because in the 's she and 's he cases the model is copying the token, since the contraction is the same, so it relies more heavily on one attention head?
*   Actually, I am leaning away from believing this completely: in the 're we case L7 is significantly more important than L8, but the gap between the two is smaller than in the 's he and 's she cases. This could still be what is happening to some degree, but I would like to do further analysis before believing it completely.
*   L10 also varies in value; it seems to be lower for 's they and 's we.











<a name="E2:MLPP"></a>

In [None]:
for group_ind in range(len(clean_prompts)):
  print(group_names[group_ind])
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt_small.to_str_tokens(clean_tokens[group_ind][0]))]
  imshow(group_patched_mlp_diff[group_ind], x=prompt_position_labels, title="Logit Difference From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

Current pairing: 're she


Current Pairing: 're he


Current Pairing: 're we


Current Pairing: 's she


Current Pairing: 's he


Current Pairing: 's we


Current Pairing: 's they


##Some quick observations
*   For the MLP differences, it looks as though the pairings terminating in he/she and the pairings terminating in they/we form two groups. This makes sense, as each group has its own contraction ('s for he/she, 're for they/we), which would explain the differences between the two groups and the similarity within each.
*   This leads me to consider the possibility that these MLPs are mainly boosting certain contractions over others.
*   Again, I would want to make more examples and do some more analysis before I can say this is the case.
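
One possible way to make the grouping less of an eyeball judgment would be to compare the patching heatmaps numerically. The sketch below is not something I ran on the real results: it uses cosine similarity between flattened [layer, position] maps, and the toy maps are made-up numbers that only show the mechanics.

```python
import math

def flatten(grid):
    """Flatten a [layer][position] nested list into one list."""
    return [value for row in grid for value in row]

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

# Made-up 2x2 "heatmaps" standing in for the real MLP patching results
toy_maps = {
    "'s he":   [[0.0, 0.1], [0.6, 0.2]],
    "'s she":  [[0.0, 0.1], [0.5, 0.3]],
    "'s we":   [[0.4, 0.0], [0.1, 0.5]],
    "'s they": [[0.5, 0.0], [0.1, 0.4]],
}

def map_similarity(name_a, name_b):
    return cosine_similarity(flatten(toy_maps[name_a]), flatten(toy_maps[name_b]))

within_he_she = map_similarity("'s he", "'s she")
within_we_they = map_similarity("'s we", "'s they")
across_groups = map_similarity("'s he", "'s we")
print(within_he_she, within_we_they, across_groups)
```

If the grouping is real, within-group similarities on the actual heatmaps should come out clearly higher than across-group ones. Prompts with different token lengths would need aligned positions before flattening.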



[jump to top](#top)