[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Matthew-Jennings/mech-interp-explore/blob/main/capital_letter_follows_full_stop_CONTINUED.ipynb)

# How does GPT-2 Small Predict Capital Letters After Full Stops? - *Exploratory Analysis*
### Matthew Jennings

## Imports

In [1]:
from functools import partial
from pprint import pprint
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

from helpers import cumul_probs_by_capitalisation_type, correct_incorrect_answers_for_top_spacetitleword

## Setup PyTorch

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [2]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x13fcd46b190>

## Define plotting helper functions

In [3]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Load Model

In [4]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()
print(f"Pytorch device: {device}")

Loaded pretrained model gpt2-small into HookedTransformer
Pytorch device: cuda




## Initial Investigations

### Simple first example.

Does the model correctly predict capital letters after a full stop?

In [5]:
example_prompt = "This is a sentence."
answer = " It"

utils.test_prompt(example_prompt, answer, model, top_k=100)

Tokenized prompt: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 15.64 Prob: 17.54% Token: |
|
Top 1th token. Logit: 15.14 Prob: 10.67% Token: | It|
Top 2th token. Logit: 14.47 Prob:  5.44% Token: | This|
Top 3th token. Logit: 14.39 Prob:  5.03% Token: | I|
Top 4th token. Logit: 14.37 Prob:  4.94% Token: | The|
Top 5th token. Logit: 14.09 Prob:  3.75% Token: | If|
Top 6th token. Logit: 14.07 Prob:  3.68% Token: | A|
Top 7th token. Logit: 13.80 Prob:  2.79% Token: | You|
Top 8th token. Logit: 13.26 Prob:  1.63% Token: | Please|
Top 9th token. Logit: 12.99 Prob:  1.25% Token: | We|
Top 10th token. Logit: 12.97 Prob:  1.22% Token: | In|
Top 11th token. Logit: 12.87 Prob:  1.10% Token: | There|
Top 12th token. Logit: 12.83 Prob:  1.06% Token: |

|
Top 13th token. Logit: 12.71 Prob:  0.94% Token: | For|
Top 14th token. Logit: 12.66 Prob:  0.90% Token: | An|
Top 15th token. Logit: 12.45 Prob:  0.73% Token: | See|
Top 16th token. Logit: 12.38 Prob:  0.67% Token: |<|endoftext|>|
Top 17th token. Logit: 12.26 Prob:  0.60% Token: | "|
Top

In [6]:
print("Cumulative probabilities for next token by 'pattern' for typical sentence:")
logits, _ = model.run_with_cache(example_prompt, remove_batch_dim=True)
cumul_probs_by_capitalisation_type(logits[:, -1], model, print_=True)

print("\nCumulative probabilities for next token by 'pattern' for all lowercase sentence:")
logits, _ = model.run_with_cache(example_prompt.lower(), remove_batch_dim=True)
cumul_probs_by_capitalisation_type(logits[:, -1], model, print_=True)

print("\nCumulative probabilities for next token by 'pattern' for all uppercase sentence:")
logits, _ = model.run_with_cache(example_prompt.upper(), remove_batch_dim=True)
cumul_probs_by_capitalisation_type(logits[:, -1], model, print_=True);

Cumulative probabilities for next token by 'pattern' for typical sentence:
Space, First Char Uppercase: 77.78%
Other: 21.19%
No Space, First Char Uppercase: 0.35%
Space, First Char Lowercase: 0.32%
Space, First Char Numeral: 0.23%
No Space, First Char Lowercase: 0.10%
No Space, First Char Numeral: 0.03%

Cumulative probabilities for next token by 'pattern' for all lowercase sentence:
Space, First Char Uppercase: 42.14%
Space, First Char Lowercase: 32.01%
Other: 19.68%
No Space, First Char Lowercase: 4.22%
No Space, First Char Uppercase: 1.20%
Space, First Char Numeral: 0.59%
No Space, First Char Numeral: 0.15%

Cumulative probabilities for next token by 'pattern' for all uppercase sentence:
Space, First Char Uppercase: 75.63%
Other: 21.66%
No Space, First Char Uppercase: 2.00%
Space, First Char Numeral: 0.40%
Space, First Char Lowercase: 0.21%
No Space, First Char Numeral: 0.06%
No Space, First Char Lowercase: 0.03%


#### Observations
- Interesting! The top result is a newline. Inspecting the rest of the results provides a nice reminder that tokens of the form `<space><capital-letter>` *are not the only valid tokens to follow a full stop.* Other examples include:
  - `"\n"`
  - `"\n\n"`
  - `"<|endoftext|>"`
  - Non-alphanumeric chars (preceded by spaces): ` (`, ` "`

- If we treat any token of the form `<space><titleword>` or `Other` as valid, then basic sampling of the next token will produce a valid token $77.8 + 21.2 = 99\%$ of the time.

- Converting the same prompt to all *lowercase* letters (most notably the first letter of the first word) raises the probability of sampling a token of the form `<space><lowercase_word>` from 0.32% to 32%, a 100x increase!

- Converting the same prompt to all *uppercase* letters gives a probability distribution by token pattern similar to that of the original prompt, although tokens of the form `<no_space><titleword>` increase in probability from 0.35% to 2%.
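The pattern classes above come from the `cumul_probs_by_capitalisation_type` helper, whose implementation is not shown in this notebook. As a rough illustration only, a classifier along the following lines could assign a token string to one of the seven classes. The function name `classify_token` and its exact rules (for example, how a bare space is treated) are assumptions, not the helper's actual logic.

```python
def classify_token(token: str) -> str:
    """Assign a token string to a capitalisation-pattern class (hypothetical rules)."""
    has_space = token.startswith(" ")
    body = token[1:] if has_space else token
    if not body:
        # A bare space (or empty string) has no first character to inspect.
        return "Other"

    first = body[0]
    prefix = "Space" if has_space else "No Space"
    if first.isdigit():
        return f"{prefix}, First Char Numeral"
    if first.isalpha() and first.isupper():
        return f"{prefix}, First Char Uppercase"
    if first.isalpha() and first.islower():
        return f"{prefix}, First Char Lowercase"
    # Newlines, punctuation, special tokens and so on.
    return "Other"


for tok in [" It", "\n", "5", " the", "It", ' "', "<|endoftext|>"]:
    print(repr(tok), "->", classify_token(tok))
```

Summing the next-token probabilities of every vocabulary entry within each class would then give cumulative figures of the kind printed above.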

### Numeric example

In [7]:
numeric_example_prompt = (
    "The probability of me getting to the bottom of this circuit is not 0."
)

utils.test_prompt(numeric_example_prompt, answer, model, top_k=100)
print("\nCumulative probabilities for next token by 'pattern' for typical sentence ending in the numeral 0:")
logits, _ = model.run_with_cache(numeric_example_prompt, remove_batch_dim=True)
cumul_probs_by_capitalisation_type(logits[:, -1], model, print_=True);

Tokenized prompt: ['<|endoftext|>', 'The', ' probability', ' of', ' me', ' getting', ' to', ' the', ' bottom', ' of', ' this', ' circuit', ' is', ' not', ' 0', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 16.66 Prob:  8.34% Token: |
|
Top 1th token. Logit: 16.62 Prob:  7.99% Token: |5|
Top 2th token. Logit: 16.19 Prob:  5.20% Token: |1|
Top 3th token. Logit: 15.69 Prob:  3.16% Token: |01|
Top 4th token. Logit: 15.61 Prob:  2.91% Token: | The|
Top 5th token. Logit: 15.59 Prob:  2.87% Token: |001|
Top 6th token. Logit: 15.58 Prob:  2.84% Token: |0001|
Top 7th token. Logit: 15.56 Prob:  2.77% Token: | If|
Top 8th token. Logit: 15.55 Prob:  2.76% Token: |0|
Top 9th token. Logit: 15.33 Prob:  2.20% Token: | It|
Top 10th token. Logit: 15.13 Prob:  1.81% Token: |9|
Top 11th token. Logit: 15.08 Prob:  1.72% Token: | But|
Top 12th token. Logit: 14.95 Prob:  1.51% Token: |000|
Top 13th token. Logit: 14.95 Prob:  1.51% Token: |6|
Top 14th token. Logit: 14.94 Prob:  1.49% Token: |2|
Top 15th token. Logit: 14.87 Prob:  1.39% Token: |25|
Top 16th token. Logit: 14.81 Prob:  1.32% Token: |05|
Top 17th token. Logit: 14.81 Prob:  1.32% Token: |8|
Top 18th token. Logit: 14.81 Prob:  


Cumulative probabilities for next token by 'pattern' for typical sentence ending in the numeral 0:
No Space, First Char Numeral: 61.18%
Space, First Char Uppercase: 27.35%
Other: 10.12%
Space, First Char Numeral: 0.72%
No Space, First Char Uppercase: 0.30%
Space, First Char Lowercase: 0.27%
No Space, First Char Lowercase: 0.06%


#### Observations:
- As with the original sentence, the top predicted token for this new example sentence with a numeric final character is a newline character!

- The next highest prediction is `5`, with no preceding space. Interestingly, a token of the form `<space><titleword>` is *not* more likely.

- Of the top 20 predictions, only 5 are non-numeric (one of which is a newline).

- Valid answers of the form `<no_space><numeral>`, `<space><titleword>` and `Other` will be sampled with just under 99% probability. A numeral immediately following the full stop (no space) is the most likely class, at 61%.
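As a quick arithmetic check on the figures quoted here, the sketch below sums the three classes treated as valid and compares the logit gap between the top two tokens with their probability ratio. It relies only on the numbers printed above and on the softmax property that $p_a / p_b = e^{\,\text{logit}_a - \text{logit}_b}$; the variable names are illustrative.

```python
import math

# Class percentages printed above for the prompt ending in "0."
valid_classes = {
    "No Space, First Char Numeral": 61.18,
    "Space, First Char Uppercase": 27.35,
    "Other": 10.12,
}
valid_total = sum(valid_classes.values())
print(f"Valid classes total: {valid_total:.2f}%")  # 98.65%

# Top two tokens above: newline (logit 16.66, 8.34%) and "5" (logit 16.62, 7.99%).
ratio_from_logits = math.exp(16.66 - 16.62)
ratio_from_probs = 8.34 / 7.99
print(round(ratio_from_logits, 3), round(ratio_from_probs, 3))
```

The two ratios agree to within the rounding of the printed logits, which is a reminder that a small logit gap (0.04 here) means the two tokens are almost equally likely.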

## More Rigorous

### Investigate multiple, very simple prompts.
- Single sentence.
- Same number of tokens.
- No proper nouns.
- No other punctuation.
- No non-alphabetical characters other than spaces.

In [12]:
prompts_simple = [
    "This is a very slightly longer sentence.",
    "The cat loves to scratch my car.",
    "Kangaroos should be ridden.",
    "That car driver drives rather recklessly.",
    "It is wonderful to be alive today.",
    "This is a difficult problem to solve.", 
    "Koalas are grumpy cute.",
]
prompt_token_counts = [len(model.to_str_tokens(prompt)) for prompt in prompts_simple]
print(prompt_token_counts)
assert all(count == prompt_token_counts[0] for count in prompt_token_counts)

[9, 9, 9, 9, 9, 9, 9]


In [13]:
logits, _ = model.run_with_cache(prompts_simple)

logits_final_sorted, logits_final_sorted_idx = logits[:, -1, :].sort(dim=-1, descending=True)

answer_tokens, answer_str_tokens = correct_incorrect_answers_for_top_spacetitleword(logits_final_sorted_idx, model)

print(answer_str_tokens)

[(' I', ' i', 'I', 'i'), (' It', ' it', 'It', 'it'), (' They', ' they', 'They', 'they'), (' He', ' he', 'He', 'he'), (' I', ' i', 'I', 'i'), (' The', ' the', 'The', 'the'), (' They', ' they', 'They', 'they')]


### Generate new answers:

In [None]:
from pprint import pprint
answer_tokens, answer_str_tokens = titleword_answer_generator(logits_idx_sorted)
pprint(answer_str_tokens)
print(answer_tokens)

[' Go', '\n', '\n', ' He', ' I', ' It', '\n', '\n']

### New logit diffs:

In [None]:
idx_to_incorrect_answer_label = dict(
    (
        (1, "lowercase"),
        (2, "missing space"),
        (3, "lowercase and missing space"),
    )
)
print("Prompts:")
for prompt in prompts:
    print(f"'{prompt}'")

print(f"\nTop space-titleword prediction per prompt:\n{[row[0] for row in answer_str_tokens]}")
for incorrect_idx in range(1, answer_tokens.size(1)):
    print(
        f"\nLogit diffs for incorrect answer type: '{idx_to_incorrect_answer_label[incorrect_idx]}':"
    )
    logit_diffs = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=True, incorrect_idx=incorrect_idx
    )

Logit diffs for incorrect answer type: 'lowercase':
Per prompt logit difference: tensor([5.5300, 5.4780, 5.8090, 6.1450, 6.2310, 7.7480, 8.2100, 6.4700])
Average logit difference: 6.453

#### Observations

- All logit differences are high now. Perhaps sentences that are not expected to conform to standard English prose should be avoided?
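`logits_to_ave_logit_diff_2` is called above but is not defined in this notebook. The sketch below shows the quantity it presumably reports: the logit of the correct answer minus the logit of one chosen incorrect variant at the final position, per prompt and averaged. The name `ave_logit_diff_sketch`, its signature, and the use of plain Python lists in place of tensors are all assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def ave_logit_diff_sketch(
    final_logits: Sequence[Sequence[float]],
    answer_tokens: Sequence[Sequence[int]],
    incorrect_idx: int = 1,
) -> Tuple[List[float], float]:
    """Hypothetical stand-in: correct-minus-incorrect logit per prompt, plus the mean.

    final_logits[i][t] is the final-position logit of token id t for prompt i.
    answer_tokens[i][0] is the correct token id; columns 1+ are incorrect variants.
    """
    per_prompt = [
        row[answers[0]] - row[answers[incorrect_idx]]
        for row, answers in zip(final_logits, answer_tokens)
    ]
    return per_prompt, sum(per_prompt) / len(per_prompt)


# Toy example: 2 prompts, vocabulary of 4 token ids (all values made up).
toy_logits = [[2.0, 0.5, 0.1, -1.0], [1.0, 3.0, 0.0, 0.5]]
toy_answers = [(0, 1, 2, 3), (1, 0, 3, 2)]
per_prompt, average = ave_logit_diff_sketch(toy_logits, toy_answers, incorrect_idx=1)
print(per_prompt, average)  # [1.5, 2.0] 1.75

# The average printed above is the mean of the eight per-prompt values.
reported = [5.5300, 5.4780, 5.8090, 6.1450, 6.2310, 7.7480, 8.2100, 6.4700]
print(round(sum(reported) / len(reported), 3))  # 6.453
```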

### Logit Lens

In [None]:
incorrect_idx = 1

original_average_logit_diff = logits_to_ave_logit_diff_2(logits, answer_tokens, per_prompt=False, incorrect_idx=incorrect_idx, print_=False)

In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, incorrect_idx]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))
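The comparison above rests on a linear identity. If `x` is the final-position residual stream after the final LayerNorm, then the logit difference between two tokens equals `x` dotted with the difference of their unembedding columns, plus the difference of their unembedding biases. The toy sketch below checks that identity with made-up numbers in plain Python. One thing that may be worth checking (an assumption, not verified for this notebook) is whether a bias gap is present, since the direction-based calculation omits that term.

```python
def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding with d_model = 3 and two tokens (all values made up).
W_U_cols = {"correct": [0.5, -1.0, 2.0], "incorrect": [0.25, 0.5, -1.0]}
b_U = {"correct": 0.125, "incorrect": -0.375}
x = [1.0, 2.0, 0.5]  # stands in for the LayerNorm-scaled residual stream

# Direct route: compute both logits, then subtract.
logit_correct = dot(x, W_U_cols["correct"]) + b_U["correct"]
logit_incorrect = dot(x, W_U_cols["incorrect"]) + b_U["incorrect"]
direct_diff = logit_correct - logit_incorrect

# Direction route: project onto the difference of unembedding columns.
direction = [c - i for c, i in zip(W_U_cols["correct"], W_U_cols["incorrect"])]
via_direction = dot(x, direction)
bias_gap = b_U["correct"] - b_U["incorrect"]

print(direct_diff, via_direction, bias_gap)  # -0.75 -1.25 0.5
```

In this toy case the two routes differ by exactly the bias gap, so they only agree when that gap is zero or is added back in.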

### Still different.
- I think a possible cause is that I have not been picking the top answer for each prompt as the correct answer.

  - E.g., there may be some difference in scaling across the prompts that makes comparing an average invalid.



Quick check:

In [None]:
top_answers = [
    (" Go", " go"),
    ("\n", " i"),
    ("\n", " i"),
    ("\n", " i"),
    (" I", " i"),
    (" It", " it"),
    ("\n", " i"),
    ("\n", " i"),
]

top_answer_tokens = torch.tensor(
    [[model.to_single_token(ans_row[0]), model.to_single_token(ans_row[1])] for ans_row in top_answers]
).to(device)
print(top_answer_tokens)
logits_top_answer, cache_top_answer = model.run_with_cache(top_answer_tokens)
original_average_logit_diff = logits_to_ave_logit_diff_2(logits_top_answer, top_answer_tokens, per_prompt=False, incorrect_idx=1, print_=False)

answer_residual_directions = model.tokens_to_residual_directions(top_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

top_answer_logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", top_answer_logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache_top_answer["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache_top_answer.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    top_answer_logit_diff_directions,
) / len(top_answers)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Hmm. Still no good.

## In the interest of time, press on

- In normal circumstances, I would seek advice from a colleague/mentor on the cause and importance of this difference in scaling.

- I have a small hunch that the difference may not seriously undermine the meaningfulness of the results.

- Press on as if there were no difference for now.

In [None]:
original_average_logit_diff = logits_to_ave_logit_diff_2(logits, answer_tokens, per_prompt=False, incorrect_idx=1, print_=False)

## Logit Lens

In [None]:
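def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> Float[torch.Tensor, "..."]:
    # Apply the final LayerNorm scaling to every component at the final position
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    # Project onto the logit-difference directions and average over prompts.
    # The result is a tensor with one value per component, not a Python float.
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)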

In [None]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

### Layer Attribution

In [None]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

#### Observation

- It appears that MLP layers are primarily what matters.

- Especially the later ones (with increasing effect toward the later layers)

- With my limited understanding of transformers, this makes some sense. It seems to me that the most important information for predicting a space-titleword token next is the nature of the *current* position/token: it is a "sentence-terminating character". Since MLP layers process information at a single position, I think this information is likely determined and used in MLP layers.

### Head Attribution

- This is just out of interest. Heads are, of course, parts of attention layers, which are apparently relatively unimportant (at least directly) according to the layer attribution graph above.

In [None]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

#### Observations
- `0.1`, `10.7` and `11.0` look most interesting, at least relative to the others.

- Note the colormap scale. Not much effect, at least compared to IOI (indirect object identification)!
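Heads are written as `layer.head` in these notes, while the code in this section also works with flattened head indices (for example, `torch.topk` over the flattened attribution tensor). A minimal sketch of the conversion between the two, assuming 12 heads per layer as in GPT-2 Small; the function names are illustrative and not part of any library:

```python
N_HEADS = 12  # GPT-2 Small has 12 attention heads per layer


def flat_to_layer_head(flat_index: int, n_heads: int = N_HEADS) -> tuple:
    """Convert a flattened head index into (layer, head)."""
    return divmod(flat_index, n_heads)


def layer_head_to_flat(layer: int, head: int, n_heads: int = N_HEADS) -> int:
    """Convert (layer, head) into a flattened head index."""
    return layer * n_heads + head


for label in ["0.1", "10.7", "11.0"]:
    layer, head = map(int, label.split("."))
    flat = layer_head_to_flat(layer, head)
    print(label, "->", flat, "->", flat_to_layer_head(flat))
# 0.1 -> 1 -> (0, 1)
# 10.7 -> 127 -> (10, 7)
# 11.0 -> 132 -> (11, 0)
```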

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [None]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

#### Observations:

- In the strongest head, `11.0`, token `1` attends strongly (relatively speaking) to token `0` (`<|endoftext|>`). This might contribute to capitalisation of token `1`?

- In `0.1`, the final token (`"."`) attends fairly strongly to token `7`, which is another full stop in at least one of the prompts.

- In `10.7`, attention to the second to last token negatively contributes to the logits of the final token.
  - One (major?) flaw in this analysis is investigating only the first titleword token. For the selection of prompts I chose, perhaps the token immediately preceding the final full stop contributes to a stronger prediction of some other token (the newline?)

## Residual Stream Patching

- **Note**: It's difficult for me to know how much wisdom I can draw from this section. The "corrupted" prompts below do significantly increase the logits for the \<space\>\<**lowercase**\> version of the next token, but they do not meaningfully *reduce* the logits of the "clean" token. It is not a neat reversal like the IOI example.

- The problem is: I'm struggling to think of a way to manufacture such a reversal, especially using the *relevant circuitry*. I'm sure I am missing something, but it seems difficult to construct an input that tests the "capitalise next word after full stop" circuit and for which a corrupted counterpart produces an uncapitalised first word.
  - E.g., perhaps I could exclusively use **repeat sequences** of a single word (e.g., clean version:`["Go. Go. Go. Go"]`, corrupted version:`["go. go. go. go"]`), but it seems likely that some *different circuit* (e.g., duplication) produces capitalised output in the clean case.

In [None]:
prompts_corrupted = [
    "filling up to eleven tokens.\ngo.", 
    "hello. hello. hello. hello. hello.", 
    "yeah Matt doesn't know what he is doing.", 
    "that Will does not know where he is going.",
    "someone should really help them out, I think.", 
    "it is wonderful to be in Adelaide in March.",
    "koalas are cute, but grumpy.", 
    "this is a sentence. this is another sentence.",
]

tokens_corrupted = model.to_tokens(prompts_corrupted)

logits_corrupted, cache_corrupted = model.run_with_cache(tokens_corrupted)
logits_final_corrupted = logits_corrupted[:, -1, :]
logits_sorted_corrupted, logits_idx_sorted_corrupted = logits_final_corrupted.sort(
    descending=True, stable=True, dim=-1
)
average_logit_diff_corrupted = logits_to_ave_logit_diff_2(logits_corrupted, answer_tokens, per_prompt=False, incorrect_idx=1, print_=True)
print("\nCorrupted Average Logit Diff", round(average_logit_diff_corrupted.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))
utils.test_prompt(prompts_corrupted[1], answer_str_tokens[1][0], model, top_k=1)
utils.test_prompt(prompts_corrupted[1], answer_str_tokens[1][1], model, top_k=1)

Per prompt logit difference: tensor([ 1.3570, -2.6480, -1.0820,  3.3580,  2.7450,  2.0550,  2.8190, -2.4030])
Average logit difference: 0.775

#### Observation:

- 5 out of 8 did not even reverse! But any positive logit differences are smaller.

- That said, at least for these prompts, ***capitalisation of the next "word-like" token appears to depend fairly strongly on whether or not the *first* token (not token `0`/`<|endoftext|>`) was capitalised.***

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - average_logit_diff_corrupted) / (
        original_average_logit_diff - average_logit_diff_corrupted
    )


patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            tokens_corrupted,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
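        # Use the same metric (and the same incorrect-answer column) as the clean and corrupted baselines
        patched_logit_diff = logits_to_ave_logit_diff_2(
            patched_logits, answer_tokens, per_prompt=False, incorrect_idx=1, print_=False
        )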

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )
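To make the normalisation concrete, here is a small worked example using the two averages printed earlier (clean 6.453, corrupted 0.775). It assumes those are the values held by `original_average_logit_diff` and `average_logit_diff_corrupted` when the patching loop runs. The patched value 3.614 is made up for illustration.

```python
CLEAN = 6.453      # clean average logit diff printed earlier
CORRUPTED = 0.775  # corrupted average logit diff printed earlier


def normalize(patched: float, clean: float = CLEAN, corrupted: float = CORRUPTED) -> float:
    """Fraction of the clean-vs-corrupted gap recovered by a patch."""
    return (patched - corrupted) / (clean - corrupted)


print(normalize(CORRUPTED))        # 0.0: patch changed nothing
print(normalize(CLEAN))            # 1.0: patch fully restored clean behaviour
print(round(normalize(3.614), 3))  # 0.5: hypothetical patch recovering half the gap
```

Because several corrupted prompts kept a positive logit difference, the gap being normalised against here (about 5.68) is smaller than it would be for a full reversal.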

In [None]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

#### Observation

- **With all of the above caveats**
  - It *does* appear that some of the contributory computation is happening on the first token, which is ordinarily capitalised.

  - At around layers 5-8, the information is moved to the final token?
    - Are the initial layers determining a "capitalisation-state" datum for the first token in the sequence, and using this to help determine the token following the full stop?

  - Not much other evidence for this though. Attention heads in these layers seem to weakly affect logits directly. Let's have a squiz anyway:

## Heads in layers 5 - 8

In [None]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

html = ""
for layer in range(5, 9):
    heads = range(layer*model.cfg.n_heads, (layer+1)*model.cfg.n_heads)

    html += visualize_attention_patterns(
        heads,
        cache,
        tokens[0],
        f"Layer {layer} Logit Attribution Heads",
    )

HTML(html)

### Interpretation

- Not seeing any evidence of the full stop attending to the first token (F) in layers 6-8...

- I am missing something.