# How does GPT-2 Small Predict Capital Letters After Full Stops? - *Exploratory Analysis*
### Matthew Jennings

## Imports

In [61]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

## Setup PyTorch

We turn automatic differentiation off to save GPU memory, as this notebook focuses on model inference, not model training.

In [62]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x248983edc10>

## Define plotting helper functions

In [63]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Load Model

In [64]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()
print(f"Pytorch device: {device}")

Loaded pretrained model gpt2-small into HookedTransformer
Pytorch device: cuda


## Task performance 

### Simple first example.

Does the model correctly predict capital letters after a full stop?

In [65]:
example_prompt = "This is a sentence."
answer = " It"

utils.test_prompt(example_prompt, answer, model, top_k=100)

Tokenized prompt: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 15.64 Prob: 17.54% Token: |
|
Top 1th token. Logit: 15.14 Prob: 10.67% Token: | It|
Top 2th token. Logit: 14.47 Prob:  5.44% Token: | This|
Top 3th token. Logit: 14.39 Prob:  5.03% Token: | I|
Top 4th token. Logit: 14.37 Prob:  4.94% Token: | The|
Top 5th token. Logit: 14.09 Prob:  3.75% Token: | If|
Top 6th token. Logit: 14.07 Prob:  3.68% Token: | A|
Top 7th token. Logit: 13.80 Prob:  2.79% Token: | You|
Top 8th token. Logit: 13.26 Prob:  1.63% Token: | Please|
Top 9th token. Logit: 12.99 Prob:  1.25% Token: | We|
Top 10th token. Logit: 12.97 Prob:  1.22% Token: | In|
Top 11th token. Logit: 12.87 Prob:  1.10% Token: | There|
Top 12th token. Logit: 12.83 Prob:  1.06% Token: |

|
Top 13th token. Logit: 12.71 Prob:  0.94% Token: | For|
Top 14th token. Logit: 12.66 Prob:  0.90% Token: | An|
Top 15th token. Logit: 12.45 Prob:  0.73% Token: | See|
Top 16th token. Logit: 12.38 Prob:  0.67% Token: |<|endoftext|>|
Top 17th token. Logit: 12.26 Prob:  0.60% Token: | "|
Top

#### Observations
- Interesting! The top result is a newline. Inspecting the rest of the results is a nice reminder that tokens of the form `<space><capital-letter>` *are not the only valid tokens to follow a full stop.* Other examples include:
  - `"\n"`
  - `"\n\n"`
  - `"<|endoftext|>"`
  - Non-alphanumeric characters (preceded by spaces): ` (`, ` "`

- No numeric tokens are present. I expect logits for numeric tokens to increase if the last character preceding the full stop is numeric. Quick test below:

### Numeric example

In [66]:
numeric_example_prompt = (
    "The probability of me getting to the bottom of this circuit is not 0."
)
numeric_answer = " It"

utils.test_prompt(numeric_example_prompt, numeric_answer, model, top_k=100)

Tokenized prompt: ['<|endoftext|>', 'The', ' probability', ' of', ' me', ' getting', ' to', ' the', ' bottom', ' of', ' this', ' circuit', ' is', ' not', ' 0', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 16.66 Prob:  8.34% Token: |
|
Top 1th token. Logit: 16.62 Prob:  7.99% Token: |5|
Top 2th token. Logit: 16.19 Prob:  5.20% Token: |1|
Top 3th token. Logit: 15.69 Prob:  3.16% Token: |01|
Top 4th token. Logit: 15.61 Prob:  2.91% Token: | The|
Top 5th token. Logit: 15.59 Prob:  2.87% Token: |001|
Top 6th token. Logit: 15.58 Prob:  2.84% Token: |0001|
Top 7th token. Logit: 15.56 Prob:  2.77% Token: | If|
Top 8th token. Logit: 15.55 Prob:  2.76% Token: |0|
Top 9th token. Logit: 15.33 Prob:  2.20% Token: | It|
Top 10th token. Logit: 15.13 Prob:  1.81% Token: |9|
Top 11th token. Logit: 15.08 Prob:  1.72% Token: | But|
Top 12th token. Logit: 14.95 Prob:  1.51% Token: |000|
Top 13th token. Logit: 14.95 Prob:  1.51% Token: |6|
Top 14th token. Logit: 14.94 Prob:  1.49% Token: |2|
Top 15th token. Logit: 14.87 Prob:  1.39% Token: |25|
Top 16th token. Logit: 14.81 Prob:  1.32% Token: |05|
Top 17th token. Logit: 14.81 Prob:  1.32% Token: |8|
Top 18th token. Logit: 14.81 Prob:  

#### Observations:
- Top token is `5`! No preceding space. (On reruns it is sometimes second after `"\n"`, as in the output shown above.)
- Of the top 20 predictions, only 5 are non-numeric (one of which is a newline).

- Will need to consider this in circuit exploration
  - Possibly (likely?) a different circuit handles decimal points

In [67]:
uncapitalised_example_prompt = (
    "this is an uncapitalised sentence."
)
uncapitalised_answer = " It"

utils.test_prompt(uncapitalised_example_prompt, uncapitalised_answer, model, top_k=10)

Tokenized prompt: ['<|endoftext|>', 'this', ' is', ' an', ' unc', 'ap', 'ital', 'ised', ' sentence', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 12.34 Prob: 14.65% Token: |
|
Top 1th token. Logit: 11.46 Prob:  6.12% Token: | I|
Top 2th token. Logit: 10.95 Prob:  3.66% Token: |

|
Top 3th token. Logit: 10.69 Prob:  2.81% Token: | It|
Top 4th token. Logit: 10.68 Prob:  2.81% Token: | it|
Top 5th token. Logit: 10.64 Prob:  2.69% Token: | the|
Top 6th token. Logit: 10.49 Prob:  2.32% Token: | i|
Top 7th token. Logit: 10.35 Prob:  2.01% Token: | The|
Top 8th token. Logit: 10.18 Prob:  1.69% Token: | This|
Top 9th token. Logit: 10.07 Prob:  1.52% Token: | if|


## Multiple examples - stick with English prose for now

In [68]:
prompts = [
    "Go.",
    "Hello.",
    "Matthew doesn't know what he is doing.",
    "This is a sentence. This is another sentence.",
    "The enigmatic, silver-haired professor, known for his eccentric lectures on quantum entanglement and the nature of reality, embarked on a perilous journey through the mist-shrouded mountains of Bhutan, seeking an ancient, mystical artifact rumored to hold the key to unlocking the secrets of the universe, while his loyal assistant, a quick-witted and resourceful graduate student with a penchant for solving cryptic puzzles, followed close behind, armed only with a weathered journal, a compass, and an unwavering determination to unravel the mystery that had haunted her mentor for decades.",
]

### Run model and get logits and cache for prompts

In [69]:
model.tokenizer.padding_side = "left"
tokens = model.to_tokens(prompts)

logits, cache = model.run_with_cache(tokens)
print(logits.size())

logits_final = logits[:, -1, :]
print(logits_final.size())

logits_sorted, logits_idx_sorted = logits_final.sort(
    descending=True, stable=True, dim=-1
)
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)
print()
# Column 0 of the descending sort holds the highest-logit token for each prompt
print(f"Top prediction each prompt:\n{model.to_str_tokens(logits_idx_sorted[:, 0])}")

torch.Size([5, 116, 50257])
torch.Size([5, 50257])
Prompt length: 3
Prompt as tokens: ['<|endoftext|>', 'Go', '.']
Prompt length: 3
Prompt as tokens: ['<|endoftext|>', 'Hello', '.']
Prompt length: 10
Prompt as tokens: ['<|endoftext|>', 'Matthew', ' doesn', "'t", ' know', ' what', ' he', ' is', ' doing', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.', ' This', ' is', ' another', ' sentence', '.']
Prompt length: 116
Prompt as tokens: ['<|endoftext|>', 'The', ' enigmatic', ',', ' silver', '-', 'haired', ' professor', ',', ' known', ' for', ' his', ' eccentric', ' lectures', ' on', ' quantum', ' ent', 'ang', 'lement', ' and', ' the', ' nature', ' of', ' reality', ',', ' embarked', ' on', ' a', ' perilous', ' journey', ' through', ' the', ' mist', '-', 'sh', 'roud', 'ed', ' mountains', ' of', ' Bh', 'utan', ',', ' seeking', ' an', ' ancient', ',', ' mystical', ' artifact', ' rumored', ' to', ' hold', ' the', ' key', ' to', ' unlocking', ' th

### Top answers filtered by '\<space\>\<titleword\>' and corresponding incorrect answers

- Could be missing something important here!

- Implications for not taking very top? Hope it doesn't break later code...

In [70]:
import re

# Raw string so that `\s` and `\w` reach the regex engine unescaped
space_titleword_pattern = re.compile(r"\s[A-Z]\w*")

def titleword_answer_generator(logits_idx_sorted):
    """Generate a sequence of correct/incorrect answer tuples by taking
    the first '<space><titleword>' token in `logits_idx_sorted` and
    finding the corresponding incorrect versions, i.e.
    '<space><lowerword>', '<titleword>' (no space), and '<lowerword>'.
    """
    answer_str_tokens = []
    answer_tokens = []

    for logits_sorted_for_prompt in logits_idx_sorted:

        predictions_sorted = model.to_str_tokens(logits_sorted_for_prompt)
        top_space_titleword_prediction = next(
            pred
            for pred in predictions_sorted[:100]
            if space_titleword_pattern.match(pred)
        )
        top_space_titleword_token = model.to_single_token(
            top_space_titleword_prediction
        )

        lowercase_counterpart_str_token = top_space_titleword_prediction.lower()
        lowercase_counterpart_token = model.to_single_token(
            lowercase_counterpart_str_token
        )

        nospace_counterpart_str_token = top_space_titleword_prediction.lstrip()
        nospace_counterpart_token = model.to_single_token(nospace_counterpart_str_token)

        nospace_lowercase_counterpart_str_token = (
            lowercase_counterpart_str_token.lstrip()
        )
        nospace_lowercase_counterpart_token = model.to_single_token(
            nospace_lowercase_counterpart_str_token
        )

        answer_str_tokens.append(
            (
                top_space_titleword_prediction,
                lowercase_counterpart_str_token,
                nospace_counterpart_str_token,
                nospace_lowercase_counterpart_str_token,
            )
        )

        answer_tokens.append(
            (
                top_space_titleword_token,
                lowercase_counterpart_token,
                nospace_counterpart_token,
                nospace_lowercase_counterpart_token,
            )
        )

    answer_tokens = torch.tensor(answer_tokens).to(device)

    return answer_tokens, answer_str_tokens


answer_tokens, answer_str_tokens = titleword_answer_generator(logits_idx_sorted)
print(answer_str_tokens)
print(answer_tokens)

[(' Go', ' go', 'Go', 'go'), (' I', ' i', 'I', 'i'), (' He', ' he', 'He', 'he'), (' This', ' this', 'This', 'this'), (' But', ' but', 'But', 'but')]
tensor([[1514,  467, 5247, 2188],
        [ 314, 1312,   40,   72],
        [ 679,  339, 1544,  258],
        [ 770,  428, 1212, 5661],
        [ 887,  475, 1537, 4360]], device='cuda:0')
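To make the filter concrete, here is a small standalone sketch of which token strings the `\s[A-Z]\w*` pattern accepts. The sample strings are illustrative rather than taken from the model's output, and the helper name `is_space_titleword` is hypothetical.

```python
import re

# Same pattern as in the cell above, written as a raw string.
space_titleword_pattern = re.compile(r"\s[A-Z]\w*")


def is_space_titleword(str_token: str) -> bool:
    """Return True if the token starts with whitespace followed by a capital letter."""
    return space_titleword_pattern.match(str_token) is not None


# Illustrative token strings (not a dump of the model's predictions).
samples = [" It", " it", "It", "\n", ' "', " 5", "\nGo"]
for s in samples:
    print(repr(s), is_space_titleword(s))
```

Two things worth keeping in mind: `\s` matches any whitespace, so a newline followed by a capitalised word would also pass, and `match` only anchors at the start of the string, so whatever follows the word is ignored.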


#### Calculate logit diffs for each of the three incorrect answer types
- Please forgive the odd formatting in the `logits_to_ave_logit_diff_2()` cell! I like to use the `black` formatter, but it is misbehaving on that particular cell. I don't understand why, and I am mindful of the time sink that troubleshooting it would likely be.

In [71]:
def logits_to_ave_logit_diff_2(
    logits, answer_tokens, per_prompt=False, print_=True, incorrect_idx=1
):
    """A modified version of `logits_to_ave_logit_diff()` from the Exploratory Analysis
    Demo that permits multiple incorrect answers in `answer_tokens`, where the user
    can specify an incorrect answer index for which to calculate the logit diff.
    """
    final_logits = logits[:, -1, :]

    answer_logits = final_logits.gather(
        dim=-1, index=answer_tokens[:, [0, incorrect_idx]]
    )
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    answer_logit_diff_mean = answer_logit_diff.mean()

    if print_:
        print(
            "Per prompt logit difference:",
            answer_logit_diff.detach().cpu().round(decimals=3),
        )
        print(
            "Average logit difference:",
            round(answer_logit_diff_mean.item(), 3),
        )

    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff_mean
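As a rough illustration of what the function computes, the following pure-Python sketch mirrors the gather-and-subtract step on made-up numbers. The logits, token ids, and the name `logit_diffs_sketch` are all hypothetical; it is not a replacement for the tensor version.

```python
from statistics import mean


def logit_diffs_sketch(final_logits, answer_tokens, incorrect_idx=1):
    """Per-prompt logit(correct) - logit(incorrect).

    final_logits: one list of vocabulary logits per prompt (final position only).
    answer_tokens: one tuple of token ids per prompt; column 0 is the correct
        answer and columns 1.. are the incorrect variants.
    """
    diffs = []
    for prompt_logits, tokens in zip(final_logits, answer_tokens):
        correct = prompt_logits[tokens[0]]
        incorrect = prompt_logits[tokens[incorrect_idx]]
        diffs.append(correct - incorrect)
    return diffs


# Toy vocabulary of 4 tokens and 2 prompts (made-up values).
toy_logits = [[2.0, 0.5, 1.0, -1.0], [0.0, 3.0, 1.5, 0.5]]
toy_answers = [(0, 1, 2, 3), (1, 0, 2, 3)]

per_prompt = logit_diffs_sketch(toy_logits, toy_answers, incorrect_idx=2)
print(per_prompt, mean(per_prompt))
```

Changing `incorrect_idx` selects a different column of `answer_tokens` as the comparison token, which is the only difference from the two-column version in the demo.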

In [72]:
idx_to_incorrect_answer_label = dict(
    (
        (1, "lowercase"),
        (2, "missing space"),
        (3, "lowercase and missing space"),
    )
)
print("Prompts:")
for prompt in prompts:
    print(f"'{prompt}'")

print(f"\nTop prediction per prompt:\n{[row[0] for row in answer_str_tokens]}")
for incorrect_idx in range(1, answer_tokens.size(1)):
    print(
        f"\nLogit diffs for incorrect answer type: '{idx_to_incorrect_answer_label[incorrect_idx]}':"
    )
    logit_diffs = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=True, incorrect_idx=incorrect_idx
    )

Prompts:
'Go.'
'Hello.'
'Matthew doesn't know what he is doing.'
'This is a sentence. This is another sentence.'
'The enigmatic, silver-haired professor, known for his eccentric lectures on quantum entanglement and the nature of reality, embarked on a perilous journey through the mist-shrouded mountains of Bhutan, seeking an ancient, mystical artifact rumored to hold the key to unlocking the secrets of the universe, while his loyal assistant, a quick-witted and resourceful graduate student with a penchant for solving cryptic puzzles, followed close behind, armed only with a weathered journal, a compass, and an unwavering determination to unravel the mystery that had haunted her mentor for decades.'

Top prediction per prompt:
[' Go', ' I', ' He', ' This', ' But']

Logit diffs for incorrect answer type: 'lowercase':
Per prompt logit difference: tensor([5.0080, 5.9630, 7.7480, 6.4700, 7.8960])
Average logit difference: 6.617

Logit diffs for incorrect answer type: 'missing space':
Per pr

#### Interpretation

- Average logit differences are high. Even the smallest average implies the correct answer is $e^{4.266} \approx 71 \times$ more likely to be chosen than its incorrect counterpart.

- Missing space logit diff performs worst. Perhaps some of the training data included full-stop-separated words, like URLs?
  - Although curiously missing space AND lowercase is least likely
    - `"go"` following `"Go."` is an outlier here
  - Something is forcing the titlecasing?

- Still very little data: be cautious with the above results.
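The conversion from a logit difference to a probability ratio follows from the softmax: the shared normaliser cancels, leaving $e^{\Delta}$ for a logit difference $\Delta$. A minimal check with a made-up three-token logit vector (the `softmax` helper is defined here for illustration only):

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits; index 0 is the "correct" token, index 1 the "incorrect" one.
logits = [10.0, 10.0 - 4.266, 7.0]
probs = softmax(logits)

ratio_from_probs = probs[0] / probs[1]
ratio_from_diff = math.exp(4.266)
print(round(ratio_from_probs, 2), round(ratio_from_diff, 2))
```

Both ratios agree (about 71), regardless of what the other logits in the vocabulary are. The ratio only compares these two tokens with each other; it says nothing about how likely either is in absolute terms.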

### Try Logit Lens

- Start with one incorrect case: lowercase version of same token.

In [73]:
incorrect_idx = 1

original_average_logit_diff = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=False, incorrect_idx=incorrect_idx, print_=False)

answer_residual_directions = model.tokens_to_residual_directions(answer_tokens[:, [0, incorrect_idx]])
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([5, 2, 768])
Logit difference directions shape: torch.Size([5, 768])


#### Verify okay:

In [74]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)

final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([5, 116, 768])
Calculated average logit diff: 7.233
Original logit difference: 6.617


## **I don't understand the difference above. Let's simplify:**

### Redefine `answers` and a simpler `logits_to_ave_logit_diff()`. Stick to difference from lowercase version.

In [75]:
answers = [
    (" Go", " go"),
    (" I", " i"),
    (" He", " he"),
    (" This", " this"),
    (" But", " but"),
]

answer_tokens = torch.tensor(
    [[model.to_single_token(ans_row[0]), model.to_single_token(ans_row[1])] for ans_row in answers]
).to(device)
print(answer_tokens)

tensor([[1514,  467],
        [ 314, 1312],
        [ 679,  339],
        [ 770,  428],
        [ 887,  475]], device='cuda:0')


In [76]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

original_average_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)

Well, that matches above...

### Try Logit lens again

- Code copied from [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)

In [77]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Answer residual directions shape: torch.Size([5, 2, 768])
Logit difference directions shape: torch.Size([5, 768])
Final residual stream shape: torch.Size([5, 116, 768])
Calculated average logit diff: 7.233
Original logit difference: 6.617


#### Still different! (Exactly the same numbers)

- Though it at least appears that the new logic in `logits_to_ave_logit_diff_2()` and the handling of more incorrect answers in `answer_tokens` are not the cause
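One possible contributor, not verified in this notebook, is an unembedding bias: a direction-based calculation of the form `residual · (W_U[:, a] - W_U[:, b])` omits any `b_U[a] - b_U[b]` term that the full logits would include. The toy sketch below uses made-up weights (every name and number is hypothetical) purely to show the kind of gap this would produce; whether it explains the numbers above would need checking against the model's actual unembedding bias.

```python
def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


# Hypothetical tiny model: d_model = 3, vocabulary of 2 tokens.
# W_U_columns[t] is the unembedding direction for token t; b_U[t] is its bias.
W_U_columns = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]]
b_U = [0.25, 0.75]

residual = [2.0, -1.0, 0.5]  # stand-in for the final, already layer-normed residual

# Full logits include the bias term.
logits = [dot(residual, W_U_columns[t]) + b_U[t] for t in range(2)]
diff_from_logits = logits[0] - logits[1]

# Direction-based estimate: project onto the difference of unembedding columns.
direction = [a - b for a, b in zip(W_U_columns[0], W_U_columns[1])]
diff_from_direction = dot(residual, direction)

print(diff_from_logits, diff_from_direction, diff_from_logits - diff_from_direction)
```

In this toy case the gap between the two calculations is exactly `b_U[0] - b_U[1]`, so it would be constant for a given answer pair and independent of the prompt.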

## Try prompts of same length

- Exploratory analysis demo warned about prompts of varying length

- Perhaps the left padding flag is not enough.
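To see why the padding side matters when reading `logits[:, -1, :]`, here is a toy sketch with plain lists. The pad id `0` and the helper `pad_batch` are hypothetical stand-ins, not the tokenizer's actual behaviour.

```python
PAD = 0  # hypothetical pad token id


def pad_batch(sequences, side="left"):
    """Pad variable-length token lists to a common length on the given side."""
    width = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        fill = [PAD] * (width - len(seq))
        padded.append(fill + seq if side == "left" else seq + fill)
    return padded


# Two made-up prompts of different lengths.
batch = [[11, 12, 13], [21, 22, 23, 24, 25]]

left = pad_batch(batch, side="left")
right = pad_batch(batch, side="right")

# With left padding the last position is always a real token;
# with right padding the shorter prompt ends in PAD.
print([row[-1] for row in left])
print([row[-1] for row in right])
```

Left padding only fixes which position holds the final real token. Whether the model also attends to the pad tokens is a separate question that this sketch does not address, which is one reason equal-length prompts are the simpler option.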

### New prompts

In [78]:
prompts = [
    "Filling up to eleven tokens.\nGo.", # Single word sentence
    "Hello. Hello. Hello. Hello. Hello.", # Repeat single word sentence
    "Yeah Matt doesn't know what he is doing.", # More or less standard sentence
    "That Will does not know where he is going.", # More or less standard sentence, different end verb
    "Someone should really help them out, I think.", # More or less standard sentence, not "ing" end verb, has comma
    "It is wonderful to be in Adelaide in March.", # More or less standard sentence, end noun
    "Koalas are cute, but grumpy.", # More or less standard sentence, end in adjective, has comma
    "This is a sentence. This is another sentence.", # Two sentences
]
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

tokens = model.to_tokens(prompts)

logits, cache = model.run_with_cache(tokens)
logits_final = logits[:, -1, :]
logits_sorted, logits_idx_sorted = logits_final.sort(
    descending=True, stable=True, dim=-1
)
print()
# Column 0 of the descending sort holds the highest-logit token for each prompt
print(f"Top prediction each prompt:\n{model.to_str_tokens(logits_idx_sorted[:, 0])}")

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'F', 'illing', ' up', ' to', ' eleven', ' tokens', '.', '\n', 'Go', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Hello', '.', ' Hello', '.', ' Hello', '.', ' Hello', '.', ' Hello', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Yeah', ' Matt', ' doesn', "'t", ' know', ' what', ' he', ' is', ' doing', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'That', ' Will', ' does', ' not', ' know', ' where', ' he', ' is', ' going', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Someone', ' should', ' really', ' help', ' them', ' out', ',', ' I', ' think', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'It', ' is', ' wonderful', ' to', ' be', ' in', ' Adelaide', ' in', ' March', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Ko', 'al', 'as', ' are', ' cute', ',', ' but', ' gr', 'umpy', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'This', ' is', ' a', ' 

### Generate new answers:

In [79]:
answer_tokens, answer_str_tokens = titleword_answer_generator(logits_idx_sorted)
print(answer_str_tokens)
print(answer_tokens)

[(' Go', ' go', 'Go', 'go'), (' Hello', ' hello', 'Hello', 'hello'), (' He', ' he', 'He', 'he'), (' He', ' he', 'He', 'he'), (' I', ' i', 'I', 'i'), (' It', ' it', 'It', 'it'), (' They', ' they', 'They', 'they'), (' This', ' this', 'This', 'this')]
tensor([[ 1514,   467,  5247,  2188],
        [18435, 23748, 15496, 31373],
        [  679,   339,  1544,   258],
        [  679,   339,  1544,   258],
        [  314,  1312,    40,    72],
        [  632,   340,  1026,   270],
        [ 1119,   484,  2990,  9930],
        [  770,   428,  1212,  5661]], device='cuda:0')


### New logit diffs:

In [80]:
idx_to_incorrect_answer_label = dict(
    (
        (1, "lowercase"),
        (2, "missing space"),
        (3, "lowercase and missing space"),
    )
)
print("Prompts:")
for prompt in prompts:
    print(f"'{prompt}'")

print(f"\nTop space-titleword prediction per prompt:\n{[row[0] for row in answer_str_tokens]}")
for incorrect_idx in range(1, answer_tokens.size(1)):
    print(
        f"\nLogit diffs for incorrect answer type: '{idx_to_incorrect_answer_label[incorrect_idx]}':"
    )
    logit_diffs = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=True, incorrect_idx=incorrect_idx
    )

Prompts:
'Filling up to eleven tokens.
Go.'
'Hello. Hello. Hello. Hello. Hello.'
'Yeah Matt doesn't know what he is doing.'
'That Will does not know where he is going.'
'Someone should really help them out, I think.'
'It is wonderful to be in Adelaide in March.'
'Koalas are cute, but grumpy.'
'This is a sentence. This is another sentence.'

Top space-titleword prediction per prompt:
[' Go', ' Hello', ' He', ' He', ' I', ' It', ' They', ' This']

Logit diffs for incorrect answer type: 'lowercase':
Per prompt logit difference: tensor([5.5300, 5.4780, 5.8090, 6.1450, 6.2310, 7.7480, 8.2100, 6.4700])
Average logit difference: 6.453

Logit diffs for incorrect answer type: 'missing space':
Per prompt logit difference: tensor([3.0500, 4.3010, 5.8150, 5.5950, 3.9660, 6.5970, 6.3920, 5.2680])
Average logit difference: 5.123

Logit diffs for incorrect answer type: 'lowercase and missing space':
Per prompt logit difference: tensor([ 3.9800,  8.0950,  9.1950,  9.6280, 10.1990, 11.4920, 11.1240, 10

#### Observations

- All values high now. Avoid sentences that are expected to break standard English prose rules?

### Logit Lens

In [81]:
incorrect_idx = 1

original_average_logit_diff = logits_to_ave_logit_diff_2(logits, answer_tokens, per_prompt=False, incorrect_idx=incorrect_idx, print_=False)

In [82]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, incorrect_idx]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Answer residual directions shape: torch.Size([8, 4, 768])
Logit difference directions shape: torch.Size([8, 768])
Final residual stream shape: torch.Size([8, 11, 768])
Calculated average logit diff: 7.14
Original logit difference: 6.453


### Still different.
- I think a possible cause is that I have not been picking the top answer for each prompt as the correct answer.

  - E.g., there may be some difference in scaling across the prompts that makes comparing an average invalid.



Quick check:

In [83]:
top_answers = [
    (" Go", " go"),
    ("\n", " i"),
    ("\n", " i"),
    ("\n", " i"),
    (" I", " i"),
    (" It", " it"),
    ("\n", " i"),
    ("\n", " i"),
]

top_answer_tokens = torch.tensor(
    [[model.to_single_token(ans_row[0]), model.to_single_token(ans_row[1])] for ans_row in top_answers]
).to(device)
print(top_answer_tokens)
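# NOTE: this runs the model on the answer tokens themselves (sequence length 2),
# not on the prompt `tokens`, so these logits and this cache are not those of the prompts.
logits_top_answer, cache_top_answer = model.run_with_cache(top_answer_tokens)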
original_average_logit_diff = logits_to_ave_logit_diff_2(logits_top_answer, top_answer_tokens, per_prompt=False, incorrect_idx=1, print_=False)

answer_residual_directions = model.tokens_to_residual_directions(top_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

top_answer_logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", top_answer_logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache_top_answer["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache_top_answer.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    top_answer_logit_diff_directions,
) / len(top_answers)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

tensor([[1514,  467],
        [ 198, 1312],
        [ 198, 1312],
        [ 198, 1312],
        [ 314, 1312],
        [ 632,  340],
        [ 198, 1312],
        [ 198, 1312]], device='cuda:0')
Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])
Final residual stream shape: torch.Size([8, 2, 768])
Calculated average logit diff: -1.614
Original logit difference: 0.233


Hmm. Still no good.

## In the interest of time, press on

- In normal circumstances, I would seek advice from a colleague/mentor on cause/importance of this difference in scaling.

- I have a small hunch that the difference does not badly affect how meaningful the results are.

- Press on as if no difference for now

In [84]:
original_average_logit_diff = logits_to_ave_logit_diff_2(logits, answer_tokens, per_prompt=False, incorrect_idx=1, print_=False)

## Logit Lens

In [85]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

In [86]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

### Layer Attribution

In [87]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

#### Observation

- It appears that MLP layers are primarily what matters.

- Especially the later ones (with increasing effect toward the later layers)

- With my limited understanding of transformers, this makes some sense. It seems to me that the most important information for predicting a space-titleword token next is the nature of the *current* position/token: it is a "sentence terminating character". Since MLP layers process information at a single position, I think this info is likely "determined/used" in MLP layers.
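The attribution plot relies on the final residual stream being a sum of the embedding and each layer's output, so a dot product with a fixed direction splits into per-component contributions (once the LayerNorm scale is treated as fixed, which is what `apply_ln_to_stack` is used for here). A toy sketch with made-up vectors; the component names are hypothetical labels:

```python
def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


# Made-up per-component outputs written into a 3-dimensional residual stream.
components = {
    "embed": [0.5, 0.0, 1.0],
    "0_attn_out": [0.0, 0.25, 0.0],
    "0_mlp_out": [1.0, -0.5, 0.5],
    "1_mlp_out": [2.0, 0.0, -1.0],
}
direction = [1.0, 2.0, 0.5]  # stand-in for a logit-diff direction

# The final residual is the elementwise sum of all component outputs.
final_residual = [sum(vals) for vals in zip(*components.values())]

per_component = {name: dot(vec, direction) for name, vec in components.items()}
total = dot(final_residual, direction)

print(per_component)
print(total, sum(per_component.values()))
```

Because the per-component contributions add up to the total, a component with a large contribution (here `1_mlp_out`) can be read as directly responsible for a large share of the logit difference. Indirect effects, where one component changes what a later one writes, are not captured by this decomposition.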

### Head Attribution

- Just out of interest: heads are of course part of the attention layers, which appear relatively unimportant, at least directly, according to the layer attribution graph above.

In [88]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


#### Observations
- `0.1`, `10.7` and `11.0` look most interesting, at least relative to the others.

- Note the colormap scale. Not much effect - at least as compared to IOI!

In [89]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [94]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

#### Observations:

- In the strongest head, `11.0`, token `1` attends strongly (relatively speaking) to token `0` (`<|endoftext|>`). This might contribute to capitalisation of token `1`?

- In `0.1`, the final token (`"."`) attends fairly strongly to token `7`, which is another full stop in at least one of the prompts.

- In `10.7`, attention to the second-to-last token contributes negatively to the logits at the final token.
  - One (major?) flaw in this analysis is that it investigates only the first titleword token. For the selection of prompts I chose, perhaps the token immediately preceding the final full stop contributes to a stronger prediction of some other token (the newline?).
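
The `layer.head` labels used in these observations relate to the flattened indices returned by `torch.topk` through simple integer division. A minimal sketch of that mapping follows. The helper names are hypothetical, and it assumes 12 heads per layer as in GPT-2 Small.

```python
N_HEADS = 12  # GPT-2 Small has 12 attention heads per layer


def flat_index_to_label(flat_index: int, n_heads: int = N_HEADS) -> str:
    """Convert a flattened (layer * n_heads + head) index to a 'layer.head' label."""
    layer, head = divmod(flat_index, n_heads)
    return f"{layer}.{head}"


def label_to_flat_index(label: str, n_heads: int = N_HEADS) -> int:
    """Inverse of flat_index_to_label."""
    layer, head = (int(part) for part in label.split("."))
    return layer * n_heads + head


labels = ["0.1", "10.7", "11.0"]
flat_indices = [label_to_flat_index(label) for label in labels]
round_trip = [flat_index_to_label(i) for i in flat_indices]
print(flat_indices, round_trip)
```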

## Residual Stream Patching

- **Note**: It's difficult for me to know how much wisdom I can draw from this section. The "corrupted" prompts below do significantly increase the logits for the \<space\>\<**lowercase**\> version of the next token, but they do not meaningfully *reduce* the logits of the "clean" token. It is not a neat reversal like the IOI example.

- The problem is that I'm struggling to think of a way to manufacture such a reversal, especially one that uses the *relevant circuitry*. I'm sure I am missing something, but it seems difficult to find an input that tests the "capitalise next word after full stop" circuit and that also has a corrupted counterpart producing an uncapitalised first word.
  - E.g., perhaps I could exclusively use **repeat sequences** of a single word (e.g., clean version:`["Go. Go. Go. Go"]`, corrupted version:`["go. go. go. go"]`), but it seems likely that some *different circuit* (e.g., duplication) produces capitalised output in the clean case.
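
One way to derive corrupted prompts like these from clean ones is sketched below. The helper `lowercase_sentence_starts` is hypothetical and is not what generated the prompts in this notebook. It assumes sentences end in `.`, `!` or `?` followed by whitespace, and it does not guarantee that the clean and corrupted prompts tokenize to the same length, which patching requires.

```python
import re


def lowercase_sentence_starts(text: str) -> str:
    """Lower-case the first letter of every sentence, leaving other capitals alone."""
    return re.sub(
        r"(^|[.!?]\s+)([A-Z])",
        lambda m: m.group(1) + m.group(2).lower(),
        text,
    )


clean_examples = [
    "Go. Go. Go. Go",
    "This is a sentence. This is another sentence.",
    "Someone should really help them out, I think.",
]
corrupted_examples = [lowercase_sentence_starts(p) for p in clean_examples]
for clean, corrupted in zip(clean_examples, corrupted_examples):
    print(f"{clean!r} -> {corrupted!r}")
```

Capitals in the middle of a sentence (proper nouns, "I") are left untouched, so only the sentence-initial capitalisation signal is removed.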

In [95]:
prompts_corrupted = [
    "filling up to eleven tokens.\ngo.", # Single word sentence
    "hello. hello. hello. hello. hello.", # Repeat single word sentence
    "yeah Matt doesn't know what he is doing.", # More or less standard sentence
    "that Will does not know where he is going.", # More or less standard sentence, different end verb
    "someone should really help them out, I think.", # More or less standard sentence, not "ing" end verb, has comma
    "it is wonderful to be in Adelaide in March.", # More or less standard sentence, end noun, has comma
    "koalas are cute, but grumpy.", # More or less standard sentence, end in adjective
    "this is a sentence. this is another sentence.", # Two sentences
]

tokens_corrupted = model.to_tokens(prompts_corrupted)

logits_corrupted, cache_corrupted = model.run_with_cache(tokens_corrupted)
logits_final_corrupted = logits_corrupted[:, -1, :]
logits_sorted_corrupted, logits_idx_sorted_corrupted = logits_final_corrupted.sort(
    descending=True, stable=True, dim=-1
)
average_logit_diff_corrupted = logits_to_ave_logit_diff_2(logits_corrupted, answer_tokens, per_prompt=False, incorrect_idx=1, print_=True)
print("\nCorrupted Average Logit Diff", round(average_logit_diff_corrupted.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))
utils.test_prompt(prompts_corrupted[1], answer_str_tokens[1][0], model, top_k=1)
utils.test_prompt(prompts_corrupted[1], answer_str_tokens[1][1], model, top_k=1)

Per prompt logit difference: tensor([ 1.3570, -2.6480, -1.0820,  3.3580,  2.7450,  2.0550,  2.8190, -2.4030])
Average logit difference: 0.775

Corrupted Average Logit Diff 0.78
Clean Average Logit Diff 6.45
Tokenized prompt: ['<|endoftext|>', 'hello', '.', ' hello', '.', ' hello', '.', ' hello', '.', ' hello', '.']
Tokenized answer: [' Hello']


Top 0th token. Logit: 14.09 Prob: 35.49% Token: | hello|


Tokenized prompt: ['<|endoftext|>', 'hello', '.', ' hello', '.', ' hello', '.', ' hello', '.', ' hello', '.']
Tokenized answer: [' hello']


Top 0th token. Logit: 14.09 Prob: 35.49% Token: | hello|


#### Observation:

- 5 of the 8 prompts did not even reverse (their logit difference stays positive)! The positive logit differences are, however, smaller than in the clean run.

- That said, at least for these prompts, **capitalisation of the next "word-like" token appears to depend fairly strongly on whether or not the *first* token (not token `0`/`<|endoftext|>`) was capitalised.**
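
The counts quoted above can be checked directly from the per-prompt values printed in the cell output. This is only a small sanity check in plain Python, with the numbers copied from that output.

```python
# Per-prompt logit differences, copied from the printed cell output above
per_prompt_logit_diffs = [1.3570, -2.6480, -1.0820, 3.3580, 2.7450, 2.0550, 2.8190, -2.4030]

# A prompt "reverses" when its logit difference becomes negative
n_reversed = sum(diff < 0 for diff in per_prompt_logit_diffs)
n_not_reversed = sum(diff > 0 for diff in per_prompt_logit_diffs)
mean_diff = sum(per_prompt_logit_diffs) / len(per_prompt_logit_diffs)

print(f"reversed: {n_reversed}, not reversed: {n_not_reversed}, mean: {mean_diff:.3f}")
```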

In [92]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract the corrupted logit diff to measure the improvement, then divide by the
    # total gap between clean and corrupted to normalise.
    # 0 means no change, negative means actively made worse, 1 means clean performance
    # was fully recovered, >1 means actively *improved* on clean performance.
    return (patched_logit_diff - average_logit_diff_corrupted) / (
        original_average_logit_diff - average_logit_diff_corrupted
    )


patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            tokens_corrupted,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [93]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

#### Observation

- **With all of the above caveats**
  - It *does* appear that some of the contributory computation is happening on the first token, which is ordinarily capitalised.

  - At around layers 6-8, the information is moved to the final token?
    - Are the initial layers determining a "capitalisation-state" datum for the first token in the sequence, and using this to help determine the token following the full stop?
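
To help read the heatmap values, here is a small worked example of the normalisation used for the patched logit difference. It plugs in the rounded averages printed earlier (0.78 corrupted, 6.45 clean). The patched values are made up purely for illustration.

```python
CLEAN_AVG = 6.45      # "Clean Average Logit Diff" from the printed output (rounded)
CORRUPTED_AVG = 0.78  # "Corrupted Average Logit Diff" from the printed output (rounded)


def normalize(patched: float, clean: float = CLEAN_AVG, corrupted: float = CORRUPTED_AVG) -> float:
    """Fraction of the clean-vs-corrupted gap recovered by a patch."""
    return (patched - corrupted) / (clean - corrupted)


# Hypothetical patched logit differences, chosen only to illustrate the scale
no_change = normalize(0.78)    # patch did nothing
halfway = normalize(3.615)     # patch recovered about half of the gap
full = normalize(6.45)         # patch fully restored clean performance
worse = normalize(0.0)         # patch made things worse than corrupted

print(no_change, halfway, full, worse)
```

A cell near 1 in the heatmap therefore means that patching that layer and position alone is close to sufficient to restore the clean behaviour, while a negative cell means the patch hurt.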