# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-27_dioce
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-27_dioce
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit ce82675a8e89b6d5e6229a89620c843c794f3b04
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [3]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7c9a1bccf220>

## Load Model

In [5]:
# Load GPT-2 small as a TransformerLens HookedTransformer with the optional
# weight-processing flags below switched on (see the `from_pretrained`
# documentation for what each flag does to the weights).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


# Generate dataset with multiple prompts

In [85]:
class Dataset:
    """Tokenised collection of prompts plus the token position of each named target.

    Args:
        prompts: non-empty list of dicts. Every dict has a ``"text"`` entry;
            every other key except ``"corr"`` / ``"incorr"`` names a target
            (e.g. ``"S1"``). Target keys are taken from the first prompt.
        pos_dict: maps each target key to its token position. The same
            position is recorded for every prompt.
        tokenizer: callable tokenizer whose result exposes ``input_ids`` and
            which also provides ``tokenize(text)``.
        S1_is_first: unused; kept so existing call sites keep working.

    Attributes:
        prompts, tokenizer: the constructor arguments.
        N: number of prompts.
        max_len: longest prompt length in tokens.
        groups: list of index arrays, one per template (a single template here).
        toks: ``torch.int`` tensor of padded token ids, one row per prompt.
        word_idx: dict mapping each target key, plus ``"end"`` (index of the
            last token of each prompt), to a 1-D tensor with one entry per prompt.

    Raises:
        ValueError: if ``prompts`` is empty.
        KeyError: if a target key of the first prompt is missing from ``pos_dict``.
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        self.max_len = max(len(self.tokenizer(text).input_ids) for text in texts)

        # Only one template is in use, so every prompt belongs to group 0.
        all_ids = [0 for _ in self.prompts]
        all_ids_ar = np.array(all_ids)
        self.groups = [
            np.where(all_ids_ar == template_id)[0] for template_id in set(all_ids)
        ]

        # Build the integer tensor directly instead of round-tripping through float.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # word_idx[target][i] is the token position of `target` in prompt i.
        # Positions come straight from pos_dict, so no per-prompt tokenisation
        # (and no dependency on a global `model`) is needed here.
        non_target_keys = ("text", "corr", "incorr")
        self.word_idx = {}
        for targ in self.prompts[0].keys():
            if targ in non_target_keys:
                continue
            self.word_idx[targ] = torch.tensor([pos_dict[targ]] * self.N)

        # "end" is the index of the last token of each (unpadded) prompt.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(text)) - 1 for text in texts]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [86]:
# Token position of each target key within a prompt. `Dataset` looks up
# `pos_dict[key]` for every non-text key of its prompts and stores the result
# in `word_idx`.
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [87]:
def generate_prompts_list(x ,y):
    """Build one prompt per integer i in [x, y) made of four consecutive numbers.

    Each prompt is a dict with the numbers i .. i+3 as strings under
    'S1' .. 'S4' and the space-separated sequence under 'text'.
    """
    def _prompt_starting_at(start):
        # The four consecutive numbers, already rendered as strings.
        members = [str(start + offset) for offset in range(4)]
        return {
            'S1': members[0],
            'S2': members[1],
            'S3': members[2],
            'S4': members[3],
            'text': " ".join(members),
        }

    return [_prompt_starting_at(i) for i in range(x, y)]

prompts_list = generate_prompts_list(1, 101)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [88]:
import random

def generate_prompts_list_corr(x, y, seed=None):
    """Build one corrupted prompt of four random numbers per integer i in [x, y).

    Every number is drawn uniformly from 1..100. The last pair is redrawn
    until S4 is not the direct successor of S3, so a prompt never ends in a
    consecutive step.

    Args:
        x, y: bounds of the half-open range [x, y). The index i only
            determines how many prompts are built and the 'corr' / 'incorr'
            entries (str(i + 4) and str(i + 3)); it is unrelated to the
            random numbers in 'text'.
        seed: optional seed. When given, a private ``random.Random(seed)`` is
            used so the prompts are reproducible. When None (the default) the
            global ``random`` module state is used, exactly as before.

    Returns:
        List of dicts with keys 'S1'..'S4', 'corr', 'incorr' and 'text'.
    """
    # A private generator keeps seeded runs independent of global random state.
    rng = random if seed is None else random.Random(seed)

    prompts_list = []
    for i in range(x, y):
        r1 = rng.randint(1, 100)
        r2 = rng.randint(1, 100)
        # Rejection-sample the final pair: S4 must not equal S3 + 1.
        while True:
            r3 = rng.randint(1, 100)
            r4 = rng.randint(1, 100)
            if r4 - 1 != r3:
                break
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(i+4),
            'incorr': str(i+3),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 101)
# prompts_list_2 = generate_prompts_list_corr(1, 2)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

In [89]:
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# tests

In [17]:
dataset.toks.size()

torch.Size([100, 4])

In [7]:
import torch as t
from torch import Tensor
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set

In [90]:
# Placeholder tensor shaped (2, n_layers, n_heads).
# NOTE(review): `results` is not read by any later cell visible here — confirm it is still needed.
results = t.zeros((2, model.cfg.n_layers, model.cfg.n_heads)) #, device=device

# Define components from our model (for typechecking, and cleaner code)
embed = model.embed
mlp0 = model.blocks[0].mlp
ln0 = model.blocks[0].ln2  # block 0's `ln2`, applied before `mlp0` in the cells below
unembed = model.unembed
ln_final = model.ln_final

# Get embeddings for the names in our list
# name_tokens: Int[Tensor, "batch 1"] = model.to_tokens(names, prepend_bos=False)
# name_embeddings: Int[Tensor, "batch 1 d_model"] = embed(name_tokens)

# Embed the dataset's token ids directly (only `embed` is applied here).
embeddings = embed(dataset.toks)
embeddings.size()

torch.Size([100, 4, 768])

In [91]:
# Get residual stream after applying block 0's MLP (with its `ln2`) on top of
# the token embeddings; no attention layer is applied.
# NOTE: despite the "mlp1" suffix, `mlp0` / `ln0` come from model.blocks[0].
resid_after_mlp1 = embeddings + mlp0(ln0(embeddings))
resid_after_mlp1.size()

torch.Size([100, 4, 768])

In [92]:
mlp9 = model.blocks[9].mlp
mlp9

MLP(
  (hook_pre): HookPoint()
  (hook_post): HookPoint()
)

In [93]:
ln9 = model.blocks[9].ln2

In [29]:
# Apply block 9's MLP (through its `ln2`, with a residual connection) directly
# to the post-MLP0 stream, skipping every layer in between.
resid_after_mlp9 = resid_after_mlp1 + mlp9(ln9(resid_after_mlp1))
resid_after_mlp9.size()

torch.Size([100, 4, 768])

In [31]:
# Final LayerNorm + unembedding to get vocabulary logits.
# `.squeeze()` drops every size-1 dimension, so the rank of `logits` depends on
# the dataset (the sequence axis disappears for single-token prompts).
logits = unembed(ln_final(resid_after_mlp9)).squeeze()
logits.size()

torch.Size([100, 4, 50257])

In [34]:
# Indices of the k largest logits along the vocabulary axis, per prompt and position.
k=5
topk_logits: Int[Tensor, "batch seq k"] = t.topk(logits, dim=-1, k=k).indices
topk_logits.size()

torch.Size([100, 4, 5])

In [51]:
words = [key for key in dataset.prompts[0].keys() if key != 'text']
words

['S1', 'S2', 'S3', 'S4']

In [55]:
# For every prompt, print the token stored under the last key of `words`
# next to the k highest-logit decoded tokens at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    word = words[-1]  # only the last key is inspected
    target_pos = dataset.word_idx[word][seq_idx]
    top_ids = torch.topk(logits[seq_idx, target_pos], k).indices
    pred_tokens = list(map(model.tokenizer.decode, top_ids))
    print(prompt[word], pred_tokens)

4 ['teenth', 'th', 'teen', 'WD', 'ND']
5 ['th', 'Thirty', '42', '41', '43']
6 ['teenth', 'th', 'teen', '03', '34']
7 ['th', '07', '49', '87', '46']
8 ['192', '98', '07', 'th', '90']
9 ['09', '07', '08', '999', '06']
10 [' minutes', ' percent', ' times', 'bp', ' years']
11 ['87', '10', '03', '11', '07']
12 ['03', '34', '02', '04', '92']
13 ['rd', '37', 'DD', '66', 'th']
14 ['teenth', 'th', '34', '37', '14']
15 ['th', ' minutes', '20', '15', ' years']
16 ['384', 'th', '16', 'burn', ' stitches']
17 ['th', 'rd', '76', '87', '37']
18 ['teenth', '37', '34', 'th', '94']
19 ['th', 'aldi', '61', ' months', 'teenth']
20 [' minutes', ' years', 'th', ' Years', 'nd']
21 ['st', 'nd', '50', 'rd', ' NCT']
22 ['nd', 'ND', ' NCT', ' sts', 'nces']
23 ['rd', 'RD', 'DD', 'nd', '00']
24 [' hrs', 'th', 'ND', ' hours', '34']
25 ['th', 'rd', 'ishing', '%-', 'agher']
26 ['th', 'nd', '26', '37', '31']
27 ['th', 'rd', '00', '37', '26']
28 ['th', '00', 'nd', 'nm', '80']
29 ['th', '29', 'ighth', '00', '79']
30 [' m

In [57]:
# Variant without residual connections: only the MLP outputs are passed on
# (MLP0's output feeds MLP9, whose output alone is unembedded).
after_mlp1 = mlp0(ln0(embeddings))
after_mlp9 = mlp9(ln9(after_mlp1))
logits = unembed(ln_final(after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

4 ['.', ' to', '-', "'", ',']
5 [',', '-', ' to', '.', ' or']
6 [' to', '-', ',', 'z', '/']
7 [' to', ' (', ' or', ',', '.']
8 ['.', ',', ' or', '-', ' to']
9 [' to', ' T', 'to', ' L', ' M']
10 [' or', ',', '.', ' and', ' to']
11 [' to', ' (', '-', ',', '.']
12 ['-', ',', '.', ' (', '–']
13 ['-', ' (', ' which', ',', ' Image']
14 ['-', ' (', ',', ' M', 'to']
15 ['-', ',', ' and', '/', ' or']
16 ['-', ',', ' of', '/', ' stories']
17 ['-', 'to', '+', ' August', '%']
18 ['-', '+', ' M', ' T', ' B']
19 ['+', ' Jr', 'th', '-', ' of']
20 ['+', '%', '/', '.', 'yd']
21 ['st', ' st', 'y', '.', 'DE']
22 ['DE', 'D', 'AD', 'B', 'st']
23 ['DE', 'st', ' Way', ' GA', 'EMA']
24 ['-', '/', ',', '.', ' Just']
25 [' and', 'th', ' form', ' most', ',']
26 [' and', '-', ' recogn', '/', ' still']
27 [' to', ' recogn', 'al', ',', 'to']
28 [' &', ',', '-', 'is', ' which']
29 ['les', 'al', "'s", ' &', ' Ten']
30 ['-', ' or', ' and', ' to', '/']
31 ['-', '.', 'B', 'Image', '+']
32 ['-', ' un', '/', 'MA', '+']
33

In [58]:
# Variant without residual connections AND without the pre-MLP LayerNorms:
# the raw embeddings go straight through MLP0 and then MLP9.
after_mlp1 = mlp0(embeddings)
after_mlp9 = mlp9(after_mlp1)
logits = unembed(ln_final(after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

4 ['th', ' days', 'g', ' of', 'ml']
5 ['cc', 'ici', 'ts', 'ml', 'earch']
6 [' coh', ' ms', ' feet', ' Hearts', ' AA']
7 [' coh', ' county', ' sqor', ' Investigator', 'da']
8 [' envelope', ' district', ' Borough', ' card', ' administrative']
9 [' Miner', 'id', ' magistrate', ' number', ' No']
10 [' coh', ' confir', 'ici', ' mm', ' nm']
11 ['th', ' AA', ' coh', ' day', 'd']
12 ['theless', 'th', 'ths', 'ui', ' gauge']
13 ['antz', 'th', ' coh', 'agara', 'ufact']
14 ['theless', 'th', 'bda', 'tm', ' coh']
15 ['thur', 'esville', ' elim', 'bek', 'bda']
16 ['abc', 'cm', 'db', 'bek', 'antz']
17 ['th', 'bek', 'db', 'd', 'px']
18 ['th', '000', 'mn', 'db', 'ths']
19 ['th', 'dn', 'db', 'nm', 'mn']
20 ['ths', 'thur', '%', 'th', 'mm']
21 ['ala', ' level', ' complex', ' very', '018']
22 ['th', 'b', 'ted', '45', 'D']
23 ['dp', '45', '00', 'bp', 'DS']
24 [' we', 'chy', 'ck', "'s", 'th']
25 ['cc', 'db', 'mm', 'hm', '000000']
26 ['abc', 'dn', '000', 'antz', 'a']
27 ['tnc', 'dn', ' obser', 'abc', '11']
28 [

In [59]:
# Push the post-MLP0 stream through a single head's OV map (head 9.1) and
# unembed the result, for both W_OV and its negation.
layer = 9
head = 1

# Get W_OV matrix
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

# Get residual stream after applying W_OV or -W_OV respectively
# (note, because of bias b_U, it matters that we do sign flip here, not later)
resid_after_OV_pos = resid_after_mlp1 @ W_OV
resid_after_OV_neg = resid_after_mlp1 @ -W_OV

# Get logits from value of residual stream
# (the cells that consume these index them as [prompt, position] before taking
# top-k over the vocabulary, so a sequence axis is present)
logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_OV_pos)).squeeze()
logits_neg: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_OV_neg)).squeeze()

In [60]:
# For every prompt, print the token under the last key of `words` next to the
# k highest-logit decoded tokens of `logits_pos` at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    word = words[-1]  # only the last key is inspected
    target_pos = dataset.word_idx[word][seq_idx]
    top_ids = torch.topk(logits_pos[seq_idx, target_pos], k).indices
    pred_tokens = list(map(model.tokenizer.decode, top_ids))
    print(prompt[word], pred_tokens)

4 [' 5', ' 4', '5', ' 6', ' five']
5 [' 6', '6', ' 5', ' 7', ' six']
6 [' 7', ' 6', '7', ' 8', ' seven']
7 [' 8', ' 7', '8', ' 9', ' 808']
8 [' 9', '9', ' 8', ' 10', ' nine']
9 [' 10', '10', ' 9', ' 11', ' 12']
10 [' 11', ' 12', ' eleven', ' 10', '11']
11 [' 12', ' 13', '12', ' 14', ' 11']
12 [' 13', ' 14', '13', ' 12', ' thirteen']
13 [' 14', '14', ' 13', ' 15', ' 16']
14 [' 15', ' 16', ' 14', ' 18', ' 17']
15 [' 16', ' 17', ' 18', ' 15', '16']
16 [' 17', ' 18', '17', '18', ' 19']
17 [' 18', '18', ' 19', ' 17', ' 1889']
18 [' 19', ' 18', ' 1870', '19', ' 20']
19 [' 21', ' 20', ' 22', ' 19', ' 23']
20 [' 21', ' 22', ' 20', '21', ' 25']
21 [' 22', ' 23', ' 21', '22', ' 53']
22 [' 23', ' 22', ' 24', '23', ' 29']
23 [' 24', ' 23', ' 25', '24', '23']
24 [' 25', ' 26', ' 24', '25', ' 27']
25 [' 26', ' 25', ' 27', ' 30', ' 35']
26 [' 27', ' 28', ' 29', ' 37', ' 31']
27 [' 28', ' 29', ' 34', ' 27', ' 33']
28 [' 29', ' 30', ' 34', ' 39', ' 33']
29 [' 31', ' 30', ' 29', ' 34', ' 39']
30 [' 31',

In [61]:
# Same report as for `logits_pos`, but for the negated OV map (`logits_neg`).
for seq_idx, prompt in enumerate(dataset.prompts):
    word = words[-1]  # only the last key is inspected
    target_pos = dataset.word_idx[word][seq_idx]
    top_ids = torch.topk(logits_neg[seq_idx, target_pos], k).indices
    pred_tokens = list(map(model.tokenizer.decode, top_ids))
    print(prompt[word], pred_tokens)

4 [' both', 'amo', ' latter', ' second', ' Childhood']
5 [' both', 'unda', 'both', 'eral', ' antiv']
6 [' truly', ' Truly', ' really', ' Really', ' genuinely']
7 [' partly', ' secondly', ' truly', ' wholly', ' genuinely']
8 [' Clever', ' given', 'asse', ' Kerr', ' Shields']
9 [' latter', 'sworth', ' either', ' now', ' Either']
10 [' Either', ' needed', ' Both', 'helm', ' somehow']
11 [' second', '360', 'ilated', 'gyn', ' possible']
12 ['ills', ' though', ' ado', 'now', 'let']
13 ['now', ' therefore', 'omin', 'gging', ' secondly']
14 [' disreg', ' effortlessly', 'ered', ' unres', 'ering']
15 [' instead', ' needed', ' redistributed', 'anwhile', 'dayName']
16 [' demanded', ' instead', 'acs', ' urgently', ' required']
17 [' instead', ' Sind', ' required', ' demanded', 'instead']
18 [' somehow', ' though', ' somew', ' finally', ' agrees']
19 [' says', ' finally', ' insepar', ' say', 'bish']
20 [' says', ' warranted', './', ' unavoid', ' ultimately']
21 [' changed', ' continued', ' heartbeat

In [62]:
# Repeat the single-head OV experiment for head 4.4 (positive direction only).
layer = 4
head = 4

# Get W_OV matrix
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

# Get residual stream after applying W_OV or -W_OV respectively
# (note, because of bias b_U, it matters that we do sign flip here, not later)
resid_after_OV_pos = resid_after_mlp1 @ W_OV
resid_after_OV_neg = resid_after_mlp1 @ -W_OV

# Get logits from value of residual stream
logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_OV_pos)).squeeze()
# logits_neg: Float[Tensor, "batch d_vocab"] = unembed(ln_final(resid_after_OV_neg)).squeeze()
# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits_pos[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

4 ['ogens', ' indo', 'transfer', 'wind', 'rich']
5 ['mone', 'ogens', ' Leban', 'embed', ' Gutenberg']
6 ['raine', ' needs', ' indo', 'embed', ' inher']
7 [' reinvest', 'raine', 'uph', 'oint', ' Constantin']
8 [' indo', ' spons', ' inher', ' reinvest', ' embell']
9 ['plates', 'ulet', ' embell', ' Helm', 'embed']
10 [' Leban', ' fingerprints', ' intrig', 'メ', ' Gutenberg']
11 [' indo', ' Leban', ' chances', ' prest', 'onz']
12 ['ール', ' indo', ' Leban', ' privatization', ' fundraising']
13 [' Leban', ' reuse', 'ール', ' prest', ' reusable']
14 ['itiz', 'ogens', ' vetting', 'FY', 'henko']
15 [' Leban', ' confir', 'ogens', 'itiz', 'vernment']
16 [' indo', 'ogens', ' spons', 'unity', ' urgently']
17 [' indo', ' spons', 'ogens', ' vetting', ' sten']
18 [' spons', 'ogens', ' indo', 'ebin', 'itiz']
19 [' Allaah', 'oyal', 'plet', 'wind', 'uph']
20 ['ebin', ' Allaah', 'mone', 'kay', ' guiActiveUn']
21 [' prints', 'prints', 'ettel', ' automate', 'ebin']
22 ['ール', 'apons', ' indo', 'soDeliveryDate', 

In [66]:
# Head 9.1's OV output followed by MLP 9 (with its LayerNorm and a residual
# connection), then unembedded.
layer = 9
head = 1

# Get W_OV matrix
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

# Get residual stream after applying W_OV or -W_OV respectively
# (note, because of bias b_U, it matters that we do sign flip here, not later)
resid_after_OV_pos = resid_after_mlp1 @ W_OV
# resid_after_OV_neg = resid_after_mlp1 @ -W_OV

# NOTE: this overwrites the `resid_after_mlp9` computed from the plain post-MLP0 stream.
resid_after_mlp9 = resid_after_OV_pos + mlp9(ln9(resid_after_OV_pos))

logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits_pos[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

4 [' 5', ' 4', '5', ' 6', ' five']
5 [' 6', '6', ' 5', ' 7', ' six']
6 [' 7', '7', ' 6', ' 8', ' seven']
7 [' 8', ' 7', '8', ' 9', ' eight']
8 [' 9', '9', ' 8', ' 10', ' nine']
9 [' 10', '10', ' 9', ' 11', ' 12']
10 [' 11', ' 12', ' eleven', ' 10', '11']
11 [' 12', ' 13', '12', ' 14', '13']
12 [' 13', ' 14', '13', ' 12', ' thirteen']
13 [' 14', '14', ' 13', ' 15', ' 16']
14 [' 15', ' 16', '15', ' 18', ' 14']
15 [' 16', ' 17', ' 18', '16', ' 15']
16 [' 17', ' 18', '17', '18', ' 19']
17 [' 18', '18', ' 19', ' 17', ' 1889']
18 [' 19', ' 18', '19', ' 1870', ' 20']
19 [' 21', ' 20', ' 22', ' 19', '20']
20 [' 21', ' 22', ' 20', '21', ' 25']
21 [' 22', ' 23', ' 21', '22', ' 52']
22 [' 23', ' 22', '23', ' 24', ' 29']
23 [' 24', ' 23', ' 25', '24', ' 26']
24 [' 25', ' 26', '25', ' 24', ' 27']
25 [' 26', ' 25', ' 27', ' 30', ' 29']
26 [' 27', ' 28', ' 37', ' 29', ' 31']
27 [' 28', ' 29', ' 34', ' 27', ' 33']
28 [' 29', ' 30', ' 39', ' 34', ' 33']
29 [' 30', ' 31', ' 29', ' 34', ' 35']
30 [' 31',

In [68]:
# Compose two OV maps: first head 4.4, then head 9.1 applied to its output,
# followed by MLP 9 and the unembedding.
layer = 4
head = 4
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]
resid_after_OV_pos = resid_after_mlp1 @ W_OV

# `layer`, `head` and `W_OV` are reused for the second head.
layer = 9
head = 1
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]
resid_after_OV_pos = resid_after_OV_pos @ W_OV

resid_after_mlp9 = resid_after_OV_pos + mlp9(ln9(resid_after_OV_pos))

logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits_pos[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

4 [' multiple', ' thousands', ' hundreds', ' millions', ' alike']
5 [' hundreds', ' thousands', ' multiple', ' countless', ' millions']
6 [' again', ' Called', ' believed', ' multiple', 'piring']
7 [' again', ' multiple', ' Again', 'again', ' thousands']
8 [' again', 'again', ' Again', ' three', 'Absolutely']
9 [' again', ' begun', ' hundreds', ' thousands', 'rawn']
10 [' hundreds', ' thousands', ' countless', ' millions', ' Thousands']
11 [' thousands', ' millions', ' hundreds', ' mistaken', ' two']
12 [' hundreds', ' millions', ' thousands', 'lished', ' enough']
13 [' enough', ' hundreds', ' multiple', 'lished', ' thousands']
14 [' needed', ' depended', ' required', ' multiple', ' believed']
15 [' thousands', ' millions', ' hundreds', ' needed', 'Reviewer']
16 [' believed', ' needed', ' meant', ' hundreds', ' enough']
17 [' multiple', ' required', ' thousands', ' needed', ' hundreds']
18 [' hundreds', ' thousands', ' beyond', ' Hundreds', ' Thousands']
19 [' thousands', ' hundreds', 

In [95]:
# Chain the OV maps of every head in every layer (hard-coded 12 x 12), one
# after another, starting from the post-MLP0 stream; then MLP 9 and unembed.
# NOTE(review): the recorded output shows the same token for every prompt, so
# this product of 144 matrices probably carries no prompt-specific signal —
# interpret with care.
resid_after_OV_pos = resid_after_mlp1
for layer in range(12):
    for head in range(12):
        W_OV = model.W_V[layer, head] @ model.W_O[layer, head]
        resid_after_OV_pos = resid_after_OV_pos @ W_OV

resid_after_mlp9 = resid_after_OV_pos + mlp9(ln9(resid_after_OV_pos))

logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits_pos[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

1 ['bugs']
2 ['bugs']
3 ['bugs']
4 ['bugs']
5 ['bugs']
6 ['bugs']
7 ['bugs']
8 ['bugs']
9 ['bugs']
10 ['bugs']
11 ['bugs']
12 ['bugs']
13 ['bugs']
14 ['bugs']
15 ['bugs']
16 ['bugs']
17 ['bugs']
18 ['bugs']
19 ['bugs']
20 ['bugs']
21 ['bugs']
22 ['bugs']
23 ['bugs']
24 ['bugs']
25 ['bugs']
26 ['bugs']
27 ['bugs']
28 ['bugs']
29 ['bugs']
30 ['bugs']
31 ['bugs']
32 ['bugs']
33 ['bugs']
34 ['bugs']
35 ['bugs']
36 ['bugs']
37 ['bugs']
38 ['bugs']
39 ['bugs']
40 ['bugs']
41 ['bugs']
42 ['bugs']
43 ['bugs']
44 ['bugs']
45 ['bugs']
46 ['bugs']
47 ['bugs']
48 ['bugs']
49 ['bugs']
50 ['bugs']
51 ['bugs']
52 ['bugs']
53 ['bugs']
54 ['bugs']
55 ['bugs']
56 ['bugs']
57 ['bugs']
58 ['bugs']
59 ['bugs']
60 ['bugs']
61 ['bugs']
62 ['bugs']
63 ['bugs']
64 ['bugs']
65 ['bugs']
66 ['bugs']
67 ['bugs']
68 ['bugs']
69 ['bugs']
70 ['bugs']
71 ['bugs']
72 ['bugs']
73 ['bugs']
74 ['bugs']
75 ['bugs']
76 ['bugs']
77 ['bugs']
78 ['bugs']
79 ['bugs']
80 ['bugs']
81 ['bugs']
82 ['bugs']
83 ['bugs']
84 ['bugs']
8

What if MLP 9 filters out everything except the output of head 9.1? To check, pass the output of a different head (4.4) through MLP 9.

In [None]:
# Head 4.4's OV output followed by MLP 9 (with its LayerNorm and a residual
# connection), then unembedded — the same pipeline as used for head 9.1.
layer = 4
head = 4

# Get W_OV matrix
W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

# Get residual stream after applying W_OV or -W_OV respectively
# (note, because of bias b_U, it matters that we do sign flip here, not later)
resid_after_OV_pos = resid_after_mlp1 @ W_OV
# resid_after_OV_neg = resid_after_mlp1 @ -W_OV

resid_after_mlp9 = resid_after_OV_pos + mlp9(ln9(resid_after_OV_pos))

logits_pos: Float[Tensor, "batch seq d_vocab"] = unembed(ln_final(resid_after_mlp9)).squeeze()

# Print, per prompt, the token under the last key of `words` and the top-k
# decoded predictions at that token's position.
for seq_idx, prompt in enumerate(dataset.prompts):
    # for word in words:
    word = words[-1]
    pred_tokens = [
        model.tokenizer.decode(token)
        for token in torch.topk(
            logits_pos[seq_idx, dataset.word_idx[word][seq_idx]], k
        ).indices
    ]
    print(prompt[word], pred_tokens)

In [None]:
def get_copying_scores(
    model: HookedTransformer,
    k: int = 5,
    names: Optional[list] = None
) -> Float[Tensor, "2 layer-1 head"]:
    '''
    Gets copying scores (both positive and negative) as described in page 6 of the IOI paper, for every (layer, head) pair in the model.

    Args:
        model: the HookedTransformer whose heads are scored.
        k: a name counts as copied when its token is among the top k logits.
        names: strings to test; assumed to tokenize to a single token each
            (TODO confirm for the list in use). Defaults to the module-level
            NAMES, looked up at call time.

    Returns:
        A 3D tensor of shape (2, n_layers - 1, n_heads). Index 0 of the first
        dimension is the positive score (fraction of names in the top k after
        applying W_OV), index 1 the negative score (same for -W_OV).

    Omits the 0th layer, because this is before MLP0 (which we're claiming acts as an extended embedding),
    so row j of the result describes layer j + 1.
    '''
    if names is None:
        # Resolved here rather than in the signature so that defining this
        # function does not require NAMES to exist yet.
        names = NAMES

    # Define components from our model (for cleaner code)
    embed = model.embed
    mlp0 = model.blocks[0].mlp
    ln0 = model.blocks[0].ln2
    unembed = model.unembed
    ln_final = model.ln_final

    # Get embeddings for the names in our list
    name_tokens: Int[Tensor, "batch 1"] = model.to_tokens(names, prepend_bos=False)
    name_embeddings: Float[Tensor, "batch 1 d_model"] = embed(name_tokens)

    # Get residual stream after applying MLP
    resid_after_mlp1 = name_embeddings + mlp0(ln0(name_embeddings))

    # One row per scored layer (layers 1 .. n_layers-1), allocated on the same
    # device as the activations instead of relying on a global `device`.
    results = t.zeros(
        (2, model.cfg.n_layers - 1, model.cfg.n_heads),
        device=resid_after_mlp1.device,
    )

    # Loop over all (layer, head) pairs
    for layer in range(1, model.cfg.n_layers):
        for head in range(model.cfg.n_heads):

            # Get W_OV matrix
            W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

            # Get residual stream after applying W_OV or -W_OV respectively
            # (note, because of bias b_U, it matters that we do sign flip here, not later)
            resid_after_OV_pos = resid_after_mlp1 @ W_OV
            resid_after_OV_neg = resid_after_mlp1 @ -W_OV

            # Get logits from value of residual stream
            logits_pos: Float[Tensor, "batch d_vocab"] = unembed(ln_final(resid_after_OV_pos)).squeeze()
            logits_neg: Float[Tensor, "batch d_vocab"] = unembed(ln_final(resid_after_OV_neg)).squeeze()

            # Check how many are in top k
            topk_logits: Int[Tensor, "batch k"] = t.topk(logits_pos, dim=-1, k=k).indices
            in_topk = (topk_logits == name_tokens).any(-1)
            # Check how many are in bottom k (top k of the negated map)
            bottomk_logits: Int[Tensor, "batch k"] = t.topk(logits_neg, dim=-1, k=k).indices
            in_bottomk = (bottomk_logits == name_tokens).any(-1)

            # Fill in results (layer 1 goes to row 0, etc.)
            results[:, layer-1, head] = t.tensor([in_topk.float().mean(), in_bottomk.float().mean()])

    return results

# Generate dataset with single-number prompts

In [69]:
def generate_prompts_list(x ,y):
    """Build one single-number prompt per integer in [x, y).

    Each prompt holds the number as a string under both 'S1' and 'text'.
    Note: this redefines the four-number `generate_prompts_list` from earlier
    in the notebook.
    """
    return [{'S1': str(number), 'text': f"{number}"} for number in range(x, y)]

prompts_list = generate_prompts_list(1, 101)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [70]:
# results = t.zeros((2, model.cfg.n_layers, model.cfg.n_heads)) #, device=device

# Define components from our model (for typechecking, and cleaner code)
embed = model.embed
mlp0 = model.blocks[0].mlp
ln0 = model.blocks[0].ln2  # block 0's `ln2`, applied before `mlp0` in the cells below
unembed = model.unembed
ln_final = model.ln_final

# Get embeddings for the names in our list
# name_tokens: Int[Tensor, "batch 1"] = model.to_tokens(names, prepend_bos=False)
# name_embeddings: Int[Tensor, "batch 1 d_model"] = embed(name_tokens)

# Embed the token ids of whichever `dataset` is current in the kernel
# (the single-number dataset when cells are run top to bottom).
embeddings = embed(dataset.toks)
embeddings.size()

torch.Size([100, 1, 768])

In [72]:
# Get residual stream after applying block 0's MLP on top of the embeddings
resid_after_mlp1 = embeddings + mlp0(ln0(embeddings))

# Then apply block 9's MLP (with its `ln2` and a residual connection) and unembed.
mlp9 = model.blocks[9].mlp
ln9 = model.blocks[9].ln2
resid_after_mlp9 = resid_after_mlp1 + mlp9(ln9(resid_after_mlp1))
# `.squeeze()` removes the length-1 sequence axis here, leaving one logit vector per prompt.
logits = unembed(ln_final(resid_after_mlp9)).squeeze()
logits.size()

torch.Size([100, 50257])

In [73]:
k=5
topk_logits: Int[Tensor, "batch k"] = t.topk(logits, dim=-1, k=k).indices
topk_logits.size()

torch.Size([100, 5])

In [74]:
words = [key for key in dataset.prompts[0].keys() if key != 'text']
words

['S1']

In [81]:
logits[0].size()

torch.Size([50257])

In [82]:
# Report only the single highest-logit token for each prompt.
k = 1
for seq_idx, prompt in enumerate(dataset.prompts):
    word = words[-1]  # only the last key is inspected
    # `logits` has one vector per prompt here, so no position index is needed.
    top_ids = torch.topk(logits[seq_idx], k).indices
    pred_tokens = list(map(model.tokenizer.decode, top_ids))
    print(prompt[word], pred_tokens)

1 ['128']
2 ['nd']
3 ['rd']
4 ['teenth']
5 ['th']
6 ['teenth']
7 ['46']
8 ['192']
9 ['999']
10 ['82']
11 ['87']
12 ['02']
13 ['66']
14 ['159']
15 ['50']
16 ['384']
17 ['76']
18 ['650']
19 ['37']
20 ['GW']
21 ['st']
22 ['nd']
23 ['rd']
24 ['89']
25 ['th']
26 ['th']
27 ['th']
28 ['th']
29 ['89']
30 ['30']
31 ['803']
32 ['nd']
33 ['rd']
34 ['68']
35 ['00']
36 ['36']
37 ['37']
38 ['608']
39 ['61']
40 ['40']
41 ['41']
42 ['nd']
43 ['rd']
44 ['00']
45 ['678']
46 ['46']
47 ['00']
48 ['576']
49 ['49']
50 ['50']
51 ['50']
52 ['50']
53 ['rd']
54 ['32']
55 ['55']
56 ['32']
57 ['LM']
58 ['427']
59 ['61']
60 [' Minutes']
61 ['61']
62 ['803']
63 ['rd']
64 ['64']
65 [' ILCS']
66 ['67']
67 ['89']
68 ['68']
69 ['69']
70 [' ILCS']
71 ['002']
72 ['80']
73 ['70']
74 ['69']
75 [' ILCS']
76 ['80']
77 ['77']
78 ['78']
79 ['79']
80 ['80']
81 ['803']
82 ['502']
83 ['rd']
84 ['80']
85 ['67']
86 ['32']
87 ['66']
88 ['889']
89 ['89']
90 ['90']
91 ['504']
92 ['502']
93 ['rd']
94 ['CE']
95 ['95']
96 ['152']
97 ['15