# How does GPT-2 Small Predict Capital Letters After Full Stops? - *Exploratory Analysis*
### Matthew Jennings

## Imports

In [1]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

## Setup PyTorch

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [2]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x172d23a0b50>

## Define plotting helper functions

In [3]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Load Model

In [4]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()
print(f"Pytorch device: {device}")

Loaded pretrained model gpt2-small into HookedTransformer
Pytorch device: cuda




## Task performance 

### Simple first example.

Does the model correctly predict capital letters after a full stop?

In [5]:
example_prompt = "This is a sentence."
answer = " It"

utils.test_prompt(example_prompt, answer, model, top_k=100)

Tokenized prompt: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 15.64 Prob: 17.54% Token: |
|
Top 1th token. Logit: 15.14 Prob: 10.67% Token: | It|
Top 2th token. Logit: 14.47 Prob:  5.44% Token: | This|
Top 3th token. Logit: 14.39 Prob:  5.03% Token: | I|
Top 4th token. Logit: 14.37 Prob:  4.94% Token: | The|
Top 5th token. Logit: 14.09 Prob:  3.75% Token: | If|
Top 6th token. Logit: 14.07 Prob:  3.68% Token: | A|
Top 7th token. Logit: 13.80 Prob:  2.79% Token: | You|
Top 8th token. Logit: 13.26 Prob:  1.63% Token: | Please|
Top 9th token. Logit: 12.99 Prob:  1.25% Token: | We|
Top 10th token. Logit: 12.97 Prob:  1.22% Token: | In|
Top 11th token. Logit: 12.87 Prob:  1.10% Token: | There|
Top 12th token. Logit: 12.83 Prob:  1.06% Token: |

|
Top 13th token. Logit: 12.71 Prob:  0.94% Token: | For|
Top 14th token. Logit: 12.66 Prob:  0.90% Token: | An|
Top 15th token. Logit: 12.45 Prob:  0.73% Token: | See|
Top 16th token. Logit: 12.38 Prob:  0.67% Token: |<|endoftext|>|
Top 17th token. Logit: 12.26 Prob:  0.60% Token: | "|
Top

#### Observations
- Interesting! The top result is a newline. Inspecting the rest of the results is a useful reminder that tokens of the form `<space><capital-letter>` *are not the only valid tokens to follow a full stop.* Other examples include:
  - `"\n"`
  - `"\n\n"`
  - `"<|endoftext|>"`
  - Non-alphanumeric characters (preceded by spaces): ` (`, ` "`

- No numeric tokens are present. I expect the logits for numeric tokens to increase if the last character preceding the full stop is numeric. Quickly test below:

### Numeric example

In [6]:
numeric_example_prompt = (
    "The probability of me getting to the bottom of this circuit is not 0."
)
numeric_answer = " It"

utils.test_prompt(numeric_example_prompt, numeric_answer, model, top_k=100)

Tokenized prompt: ['<|endoftext|>', 'The', ' probability', ' of', ' me', ' getting', ' to', ' the', ' bottom', ' of', ' this', ' circuit', ' is', ' not', ' 0', '.']
Tokenized answer: [' It']


Top 0th token. Logit: 16.66 Prob:  8.34% Token: |
|
Top 1th token. Logit: 16.62 Prob:  7.99% Token: |5|
Top 2th token. Logit: 16.19 Prob:  5.20% Token: |1|
Top 3th token. Logit: 15.69 Prob:  3.16% Token: |01|
Top 4th token. Logit: 15.61 Prob:  2.91% Token: | The|
Top 5th token. Logit: 15.59 Prob:  2.87% Token: |001|
Top 6th token. Logit: 15.58 Prob:  2.84% Token: |0001|
Top 7th token. Logit: 15.56 Prob:  2.77% Token: | If|
Top 8th token. Logit: 15.55 Prob:  2.76% Token: |0|
Top 9th token. Logit: 15.33 Prob:  2.20% Token: | It|
Top 10th token. Logit: 15.13 Prob:  1.81% Token: |9|
Top 11th token. Logit: 15.08 Prob:  1.72% Token: | But|
Top 12th token. Logit: 14.95 Prob:  1.51% Token: |000|
Top 13th token. Logit: 14.95 Prob:  1.51% Token: |6|
Top 14th token. Logit: 14.94 Prob:  1.49% Token: |2|
Top 15th token. Logit: 14.87 Prob:  1.39% Token: |25|
Top 16th token. Logit: 14.81 Prob:  1.32% Token: |05|
Top 17th token. Logit: 14.81 Prob:  1.32% Token: |8|
Top 18th token. Logit: 14.81 Prob:  

#### Observations:
- Top token is `5`! No preceding space. (Upon reruns, sometimes it is second after `"\n"`.)
- Of the top 20 predictions, only 5 are non-numeric (one of which is a newline).

- Will need to consider this in circuit exploration
  - Possibly (likely?) a different circuit handles decimal points
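
As a rough check of the count above, the visible predictions can be classified by hand. This is a minimal sketch: the token list is copied from the truncated output above (only ranks 0 to 17 are visible), and `is_numeric_token` is a hypothetical helper, not part of TransformerLens.

```python
# Tokens copied from the truncated `test_prompt` output above (ranks 0-17).
visible_top_tokens = [
    "\n", "5", "1", "01", " The", "001", "0001", " If", "0",
    " It", "9", " But", "000", "6", "2", "25", "05", "8",
]


def is_numeric_token(token: str) -> bool:
    """Hypothetical helper: True if the token consists only of ASCII digits."""
    return token.isascii() and token.isdigit()


numeric = [tok for tok in visible_top_tokens if is_numeric_token(tok)]
non_numeric = [tok for tok in visible_top_tokens if not is_numeric_token(tok)]

print(f"numeric: {len(numeric)}, non-numeric: {len(non_numeric)}")
print("non-numeric tokens:", [repr(tok) for tok in non_numeric])
```

None of the numeric tokens carries a leading space, which fits the reading that the model treats the full stop as a decimal point here.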

## Multiple examples - stick with English prose for now

In [7]:
prompts = [
    "Go.",
    "Hello.",
    "Matthew doesn't know what he is doing.",
    "This is a sentence. This is another sentence.",
    "The enigmatic, silver-haired professor, known for his eccentric lectures on quantum entanglement and the nature of reality, embarked on a perilous journey through the mist-shrouded mountains of Bhutan, seeking an ancient, mystical artifact rumored to hold the key to unlocking the secrets of the universe, while his loyal assistant, a quick-witted and resourceful graduate student with a penchant for solving cryptic puzzles, followed close behind, armed only with a weathered journal, a compass, and an unwavering determination to unravel the mystery that had haunted her mentor for decades.",
]

### Run model and get logits and cache for prompts

In [8]:
model.tokenizer.padding_side = "left"
tokens = model.to_tokens(prompts)

logits, cache = model.run_with_cache(tokens)
print(logits.size())

logits_final = logits[:, -1, :]
print(logits_final.size())

logits_sorted, logits_idx_sorted = logits_final.sort(
    descending=True, stable=True, dim=-1
)
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)
print()
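# NOTE: column 1 of the sorted indices is the second-ranked token per prompt; column 0 is the top-ranked one.
print(f"Top prediction each prompt:\n{model.to_str_tokens(logits_idx_sorted[:, 1])}")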

torch.Size([5, 116, 50257])
torch.Size([5, 50257])
Prompt length: 3
Prompt as tokens: ['<|endoftext|>', 'Go', '.']
Prompt length: 3
Prompt as tokens: ['<|endoftext|>', 'Hello', '.']
Prompt length: 10
Prompt as tokens: ['<|endoftext|>', 'Matthew', ' doesn', "'t", ' know', ' what', ' he', ' is', ' doing', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.', ' This', ' is', ' another', ' sentence', '.']
Prompt length: 116
Prompt as tokens: ['<|endoftext|>', 'The', ' enigmatic', ',', ' silver', '-', 'haired', ' professor', ',', ' known', ' for', ' his', ' eccentric', ' lectures', ' on', ' quantum', ' ent', 'ang', 'lement', ' and', ' the', ' nature', ' of', ' reality', ',', ' embarked', ' on', ' a', ' perilous', ' journey', ' through', ' the', ' mist', '-', 'sh', 'roud', 'ed', ' mountains', ' of', ' Bh', 'utan', ',', ' seeking', ' an', ' ancient', ',', ' mystical', ' artifact', ' rumored', ' to', ' hold', ' the', ' key', ' to', ' unlocking', ' th

### Top answers filtered by '\<space\>\<titleword\>' and corresponding incorrect answers

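- Filtering to `<space><titleword>` tokens could miss something important: as the examples above show, the actual top prediction can be a newline.

- What are the implications of not taking the very top prediction? Hopefully it doesn't break later code...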

In [9]:
import re

space_titleword_pattern = re.compile(r"\s[A-Z]\w*")

def titleword_answer_generator(logits_idx_sorted):
    """Generate a sequence of correct/incorrect answer tuples by sampling
    the first '<space><titleword>' token in `logits_idx_sorted` and
    finding the corresponding incorrect versions. I.e.,
    '<titleword>' (no space), '<space><lowerword>', and '<lowerword>'
    """
    answer_str_tokens = []
    answer_tokens = []

    for logits_sorted_for_prompt in logits_idx_sorted:

        predictions_sorted = model.to_str_tokens(logits_sorted_for_prompt)
        top_space_titleword_prediction = next(
            pred
            for pred in predictions_sorted[:100]
            if space_titleword_pattern.match(pred)
        )
        top_space_titleword_token = model.to_single_token(
            top_space_titleword_prediction
        )

        lowercase_counterpart_str_token = top_space_titleword_prediction.lower()
        lowercase_counterpart_token = model.to_single_token(
            lowercase_counterpart_str_token
        )

        nospace_counterpart_str_token = top_space_titleword_prediction.lstrip()
        nospace_counterpart_token = model.to_single_token(nospace_counterpart_str_token)

        nospace_lowercase_counterpart_str_token = (
            lowercase_counterpart_str_token.lstrip()
        )
        nospace_lowercase_counterpart_token = model.to_single_token(
            nospace_lowercase_counterpart_str_token
        )

        answer_str_tokens.append(
            (
                top_space_titleword_prediction,
                lowercase_counterpart_str_token,
                nospace_counterpart_str_token,
                nospace_lowercase_counterpart_str_token,
            )
        )

        answer_tokens.append(
            (
                top_space_titleword_token,
                lowercase_counterpart_token,
                nospace_counterpart_token,
                nospace_lowercase_counterpart_token,
            )
        )

    answer_tokens = torch.tensor(answer_tokens).to(device)

    return answer_tokens, answer_str_tokens


answer_tokens, answer_str_tokens = titleword_answer_generator(logits_idx_sorted)
print(answer_str_tokens)
print(answer_tokens)

[(' Go', ' go', 'Go', 'go'), (' I', ' i', 'I', 'i'), (' He', ' he', 'He', 'he'), (' This', ' this', 'This', 'this'), (' But', ' but', 'But', 'but')]
tensor([[1514,  467, 5247, 2188],
        [ 314, 1312,   40,   72],
        [ 679,  339, 1544,  258],
        [ 770,  428, 1212, 5661],
        [ 887,  475, 1537, 4360]], device='cuda:0')
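
The behaviour of the filter pattern can be checked in isolation. The sketch below uses only the standard library; the candidate strings are illustrative assumptions, not model output. It also compares `match` (used above, anchors only at the start) with `fullmatch`, a stricter alternative that the notebook does not use.

```python
import re

# Same pattern as the filter above, written as a raw string.
space_titleword_pattern = re.compile(r"\s[A-Z]\w*")

# Illustrative candidates (assumed examples, not model output).
candidates = [" It", "It", " it", "\n", "\nThe", " I'm", " ("]

for token in candidates:
    prefix_hit = space_titleword_pattern.match(token) is not None
    full_hit = space_titleword_pattern.fullmatch(token) is not None
    print(f"{token!r:10} match={prefix_hit!s:5} fullmatch={full_hit}")

prefix_matches = [t for t in candidates if space_titleword_pattern.match(t)]
full_matches = [t for t in candidates if space_titleword_pattern.fullmatch(t)]
```

Two things stand out: `\s` also accepts a newline, so a string such as `"\nThe"` passes the filter, and `match` accepts any string whose prefix fits the pattern. Separately, `next(...)` without a default raises `StopIteration` if none of the top 100 predictions passes the filter.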


#### Calculate logit diffs for each of the three incorrect answer types
- Please forgive the odd formatting in the `logits_to_ave_logit_diff()` cell! I like to use the `black` formatter, but it misbehaves on that particular cell. I don't understand why, and I am mindful of the time sink that troubleshooting it would likely be.

In [10]:
def logits_to_ave_logit_diff_2(
    logits, answer_tokens, per_prompt=False, print_=True, incorrect_idx=1
):
    """A modified version of `logits_to_ave_logit_diff()` from the Exploratory Analysis
    Demo that permits multiple incorrect answers in `answer_tokens`; the caller
    specifies which incorrect-answer column to use for the logit diff.
    """
    final_logits = logits[:, -1, :]

    answer_logits = final_logits.gather(
        dim=-1, index=answer_tokens[:, [0, incorrect_idx]]
    )
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    answer_logit_diff_mean = answer_logit_diff.mean()

    if print_:
        print(
            "Per prompt logit difference:",
            answer_logit_diff.detach().cpu().round(decimals=3),
        )
        print(
            "Average logit difference:",
            round(answer_logit_diff_mean.item(), 3),
        )

    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff_mean

In [12]:
idx_to_incorrect_answer_label = dict(
    (
        (1, "lowercase"),
        (2, "missing space"),
        (3, "lowercase and missing space"),
    )
)
print("Prompts:")
for prompt in prompts:
    print(f"'{prompt}'")

print(f"\nTop prediction per prompt:\n{[row[0] for row in answer_str_tokens]}")
for incorrect_idx in range(1, answer_tokens.size(1)):
    print(
        f"\nLogit diffs for incorrect answer type: '{idx_to_incorrect_answer_label[incorrect_idx]}':"
    )
    logit_diffs = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=True, incorrect_idx=incorrect_idx
    )

Prompts:
'Go.'
'Hello.'
'Matthew doesn't know what he is doing.'
'This is a sentence. This is another sentence.'
'The enigmatic, silver-haired professor, known for his eccentric lectures on quantum entanglement and the nature of reality, embarked on a perilous journey through the mist-shrouded mountains of Bhutan, seeking an ancient, mystical artifact rumored to hold the key to unlocking the secrets of the universe, while his loyal assistant, a quick-witted and resourceful graduate student with a penchant for solving cryptic puzzles, followed close behind, armed only with a weathered journal, a compass, and an unwavering determination to unravel the mystery that had haunted her mentor for decades.'

Top prediction per prompt:
[' Go', ' I', ' He', ' This', ' But']

Logit diffs for incorrect answer type: 'lowercase':
Per prompt logit difference: tensor([5.0080, 5.9630, 7.7480, 6.4700, 7.8960])
Average logit difference: 6.617

Logit diffs for incorrect answer type: 'missing space':
Per pr

#### Interpretation

- Average logit differences are high. The smallest average, 4.266, implies that the correct answer is $e^{4.266} \approx 71 \times$ more likely to be chosen than its incorrect counterpart.

- Missing space logit diff performs worst. Perhaps some of the training data included full-stop-separated words, like URLs?
  - Although curiously missing space AND lowercase is least likely
    - `"go"` following `"Go."` is an outlier here
  - Something is forcing the titlecasing?

- Still very little data, so be cautious with the above results.
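
The conversion used in the first bullet follows from the softmax: the normalising constant cancels in a ratio of two probabilities, so $p_a / p_b = e^{l_a - l_b}$. A minimal sketch, where `logit_diff_to_odds_ratio` is a hypothetical helper name and the numbers are taken from the outputs above:

```python
import math


def logit_diff_to_odds_ratio(logit_diff: float) -> float:
    """Ratio p(correct) / p(incorrect) implied by a logit difference."""
    return math.exp(logit_diff)


# 4.266 is the minimum average quoted above; 6.617 is the lowercase average.
for label, diff in [("minimum average", 4.266), ("lowercase average", 6.617)]:
    ratio = logit_diff_to_odds_ratio(diff)
    print(f"{label}: logit diff {diff} -> odds ratio {ratio:.1f}")

# Sanity check against the first example: newline (15.64, 17.54%) vs " It" (15.14, 10.67%).
ratio_from_logits = logit_diff_to_odds_ratio(15.64 - 15.14)
ratio_from_probs = 17.54 / 10.67
print(f"from logits: {ratio_from_logits:.3f}, from probabilities: {ratio_from_probs:.3f}")
```

The two ratios in the sanity check agree up to the rounding of the printed logits and probabilities.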

### Try Logit Lens

- Start with one incorrect case: lowercase version of same token.

In [13]:
incorrect_idx = 1

original_average_logit_diff = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=False, incorrect_idx=incorrect_idx, print_=False)

answer_residual_directions = model.tokens_to_residual_directions(answer_tokens[:, [0, incorrect_idx]])
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([5, 2, 768])
Logit difference directions shape: torch.Size([5, 768])


#### Verify okay:

In [14]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)

final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([5, 116, 768])
Calculated average logit diff: 7.233
Original logit difference: 6.617
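
For reference, the `einsum` in the cell above takes the dot product of each prompt's scaled residual with that prompt's logit-diff direction, sums over the batch, and divides by the number of prompts. A minimal pure-Python sketch of the same reduction, with made-up toy values and hypothetical helper names:

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def average_projection(residuals: List[Vector], directions: List[Vector]) -> float:
    """Mean over the batch of residuals[i] . directions[i].

    Mirrors einsum("batch d_model, batch d_model -> ", ...) / batch_size.
    """
    assert len(residuals) == len(directions)
    per_prompt = [dot(r, d) for r, d in zip(residuals, directions)]
    return sum(per_prompt) / len(per_prompt)


# Toy values (made up): 2 prompts, d_model = 3.
toy_residuals = [[1.0, 2.0, 0.0], [0.5, -1.0, 2.0]]
toy_directions = [[2.0, 0.0, 1.0], [0.0, 1.0, 1.0]]

toy_average = average_projection(toy_residuals, toy_directions)
print(toy_average)
```

Dividing by the actual batch size, rather than by the length of a separate list, keeps the average correct even if the two ever differ.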


## **I don't understand the difference above. Let's simplify:**

### Redefine `answers` and simpler `logit_to_ave_logit_diff()`. Stick to difference from lowercase version.

In [15]:
answers = [
    (" Go", " go"),
    (" I", " i"),
    (" He", " he"),
    (" This", " this"),
    (" But", " but"),
]

answer_tokens = torch.tensor(
    [[model.to_single_token(ans_row[0]), model.to_single_token(ans_row[1])] for ans_row in answers]
).to(device)
print(answer_tokens)

tensor([[1514,  467],
        [ 314, 1312],
        [ 679,  339],
        [ 770,  428],
        [ 887,  475]], device='cuda:0')


In [16]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

original_average_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)
print("Original average logit diff:", round(original_average_logit_diff.item(), 3))

Original average logit diff: 6.617


Well, that matches above...

### Try Logit lens again

- Code copied from [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)

In [17]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Answer residual directions shape: torch.Size([5, 2, 768])
Logit difference directions shape: torch.Size([5, 768])
Final residual stream shape: torch.Size([5, 116, 768])
Calculated average logit diff: 7.233
Original logit difference: 6.617


#### Still different! (Exactly the same numbers)

- Though it at least appears that the new logic in `logits_to_ave_logit_diff_2()` and the handling of more incorrect answers in `answer_tokens` are not the cause

## Try prompts of same length

- Exploratory analysis demo warned about prompts of varying length

- Perhaps the left padding flag is not enough.
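
Why the padding side matters for `logits[:, -1, :]` can be seen with a toy batch. This is a sketch under stated assumptions: `PAD` and `pad_batch` are hypothetical stand-ins for whatever the tokenizer actually does, and the token lists are copied from earlier outputs.

```python
from typing import List

PAD = "<pad>"  # placeholder padding symbol for this toy example


def pad_batch(batch: List[List[str]], side: str) -> List[List[str]]:
    """Pad every token list to the longest length on the given side."""
    width = max(len(tokens) for tokens in batch)
    padded = []
    for tokens in batch:
        filler = [PAD] * (width - len(tokens))
        padded.append(filler + tokens if side == "left" else tokens + filler)
    return padded


batch = [
    ["<|endoftext|>", "Go", "."],
    ["<|endoftext|>", "This", " is", " a", " sentence", "."],
]

last_if_left = [tokens[-1] for tokens in pad_batch(batch, "left")]
last_if_right = [tokens[-1] for tokens in pad_batch(batch, "right")]
print("left padding :", last_if_left)
print("right padding:", last_if_right)
```

With right padding, position `-1` of the shorter prompt is a padding token rather than the full stop, so the final-position logits would no longer describe the token after the sentence. Equal-length prompts sidestep the question entirely.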

### New prompts

In [29]:
prompts = [
    "Filling up to eleven tokens.\nGo.",
    "Hello. Hello. Hello. Hello. Hello.",
    "Matthew really does not know what he is doing.",
    "This is a sentence. This is another sentence.",
]
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

tokens = model.to_tokens(prompts)

logits, cache = model.run_with_cache(tokens)
logits_final = logits[:, -1, :]
logits_sorted, logits_idx_sorted = logits_final.sort(
    descending=True, stable=True, dim=-1
)
print()
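# NOTE: column 1 of the sorted indices is the second-ranked token per prompt; column 0 is the top-ranked one.
print(f"Top prediction each prompt:\n{model.to_str_tokens(logits_idx_sorted[:, 1])}")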

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'F', 'illing', ' up', ' to', ' eleven', ' tokens', '.', '\n', 'Go', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Hello', '.', ' Hello', '.', ' Hello', '.', ' Hello', '.', ' Hello', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Matthew', ' really', ' does', ' not', ' know', ' what', ' he', ' is', ' doing', '.']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'This', ' is', ' a', ' sentence', '.', ' This', ' is', ' another', ' sentence', '.']

Top prediction each prompt:
[' Go', '\n', '\n', '\n']


### Generate new answers:

In [30]:
answer_tokens, answer_str_tokens = titleword_answer_generator(logits_idx_sorted)
print(answer_str_tokens)
print(answer_tokens)

[(' Go', ' go', 'Go', 'go'), (' Hello', ' hello', 'Hello', 'hello'), (' He', ' he', 'He', 'he'), (' This', ' this', 'This', 'this')]
tensor([[ 1514,   467,  5247,  2188],
        [18435, 23748, 15496, 31373],
        [  679,   339,  1544,   258],
        [  770,   428,  1212,  5661]], device='cuda:0')


### New logit diffs:

In [31]:
idx_to_incorrect_answer_label = dict(
    (
        (1, "lowercase"),
        (2, "missing space"),
        (3, "lowercase and missing space"),
    )
)
print("Prompts:")
for prompt in prompts:
    print(f"'{prompt}'")

print(f"\nTop prediction per prompt:\n{[row[0] for row in answer_str_tokens]}")
for incorrect_idx in range(1, answer_tokens.size(1)):
    print(
        f"\nLogit diffs for incorrect answer type: '{idx_to_incorrect_answer_label[incorrect_idx]}':"
    )
    logit_diffs = logits_to_ave_logit_diff_2(
        logits, answer_tokens, per_prompt=True, incorrect_idx=incorrect_idx
    )



Prompts:
'Filling up to eleven tokens.
Go.'
'Hello. Hello. Hello. Hello. Hello.'
'Matthew really does not know what he is doing.'
'This is a sentence. This is another sentence.'

Top prediction per prompt:
[' Go', ' Hello', ' He', ' This']

Logit diffs for incorrect answer type: 'lowercase':
Per prompt logit difference: tensor([5.5300, 5.4780, 6.6000, 6.4700])
Average logit difference: 6.02

Logit diffs for incorrect answer type: 'missing space':
Per prompt logit difference: tensor([3.0500, 4.3010, 6.1450, 5.2680])
Average logit difference: 4.691

Logit diffs for incorrect answer type: 'lowercase and missing space':
Per prompt logit difference: tensor([ 3.9800,  8.0950,  9.7580, 10.6940])
Average logit difference: 8.132
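
The per-prompt values printed above can be aggregated by hand to see which incorrect-answer type has the smallest margin and how much the prompts disagree. A small sketch using only the standard library; the numbers are copied from the output above.

```python
from statistics import mean

# Per-prompt logit diffs copied from the output above.
logit_diffs_by_type = {
    "lowercase": [5.5300, 5.4780, 6.6000, 6.4700],
    "missing space": [3.0500, 4.3010, 6.1450, 5.2680],
    "lowercase and missing space": [3.9800, 8.0950, 9.7580, 10.6940],
}

averages = {label: mean(diffs) for label, diffs in logit_diffs_by_type.items()}
hardest = min(averages, key=averages.get)

for label, avg in averages.items():
    diffs = logit_diffs_by_type[label]
    spread = max(diffs) - min(diffs)
    print(f"{label:28} mean={avg:.3f} spread={spread:.3f}")
print("Smallest average margin:", hardest)
```

The large spread for "lowercase and missing space" comes from the first prompt, where `"go"` after `"Go."` has a much smaller margin than the other three.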


### Logit Lens

In [34]:
incorrect_idx = 2

original_average_logit_diff = logits_to_ave_logit_diff_2(logits, answer_tokens, per_prompt=False, incorrect_idx=incorrect_idx, print_=False)

In [36]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, incorrect_idx]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Answer residual directions shape: torch.Size([4, 4, 768])
Logit difference directions shape: torch.Size([4, 768])
Final residual stream shape: torch.Size([4, 11, 768])
Calculated average logit diff: 6.871
Original logit difference: 6.02


### Still different. In the interest of time, let's press on.

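- I think a possible cause is that I have not been picking the top answer for each prompt as the correct answer.
  - E.g., there may be some difference in scaling across the prompts that makes comparing an average invalid.
- Note also that the "original" value printed above (6.02) equals the lowercase average (`incorrect_idx=1`), not the missing-space average (4.691) that `incorrect_idx = 2` should give. The cells were probably run out of order (In [34] and In [36]), so this particular comparison may mix two answer types.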



Quick check:

In [48]:
top_answers = [
    (" Go", " go"),
    (" \n", " i"),
    (" \n", " he"),
    (" \n", " this"),
    (" \n", " but"),
]

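# NOTE: this comprehension iterates over `answers` (defined in In [15]), not `top_answers`,
# which is why the printed tokens below match the earlier answer tokens.
top_answer_tokens = torch.tensor(
    [[model.to_single_token(ans_row[0]), model.to_single_token(ans_row[1])] for ans_row in answers]
).to(device)
print(top_answer_tokens)
# NOTE: the model is run on the answer tokens themselves (two-token sequences), not on the prompts.
logits_top_answer, cache_top_answer = model.run_with_cache(top_answer_tokens)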
original_average_logit_diff = logits_to_ave_logit_diff_2(logits_top_answer, top_answer_tokens, per_prompt=False, incorrect_idx=1, print_=False)

answer_residual_directions = model.tokens_to_residual_directions(top_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache_top_answer["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache_top_answer.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)  # NOTE: the batch here has 5 rows; `prompts` has 4 entries if it is still the list from In [29]
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

tensor([[1514,  467],
        [ 314, 1312],
        [ 679,  339],
        [ 770,  428],
        [ 887,  475]], device='cuda:0')
Answer residual directions shape: torch.Size([5, 2, 768])
Logit difference directions shape: torch.Size([5, 768])
Final residual stream shape: torch.Size([5, 2, 768])
Calculated average logit diff: -3.219
Original logit difference: -3.191


Well, this is much closer! But it is still not spot on, as it was in the Exploratory Analysis Demo.
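
One hypothesis that the checks above do not rule out, and that is not verified in this notebook: the direction-based calculation uses only the unembedding directions returned by `tokens_to_residual_directions`, so any per-token unembedding bias is left out. Whether loading with `fold_ln=True` produces such a bias (by folding the final LayerNorm bias into the unembedding) is an assumption about library internals. The toy sketch below, with made-up numbers and hypothetical names, shows the size of the gap a bias would create.

```python
from typing import Dict, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding (made-up numbers): one direction and one bias per token.
W_U: Dict[str, Vector] = {" He": [1.0, 0.0], " he": [0.0, 1.0]}
b_U: Dict[str, float] = {" He": 0.25, " he": 0.75}

x: Vector = [3.0, 1.0]  # stands in for the LayerNorm-scaled final residual


def full_logit(token: str) -> float:
    """Logit including the per-token bias."""
    return dot(x, W_U[token]) + b_U[token]


def true_logit_diff(correct: str, incorrect: str) -> float:
    """Difference of full logits, bias included."""
    return full_logit(correct) - full_logit(incorrect)


def projected_logit_diff(correct: str, incorrect: str) -> float:
    """Projection onto the difference of unembedding directions, bias ignored."""
    direction = [c - i for c, i in zip(W_U[correct], W_U[incorrect])]
    return dot(x, direction)


projected = projected_logit_diff(" He", " he")
true = true_logit_diff(" He", " he")
gap = projected - true
print("projected:", projected)
print("true     :", true)
print("gap      :", gap, "(equals b_U[' he'] - b_U[' He'])")
```

In this toy case the gap equals the bias difference between the two tokens. If the same holds for the real model, adding each answer pair's bias difference to the projected value should close the gap, and a dataset in which every answer pair appears in both orders would see the bias terms cancel in the average.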