Install requirements

In [8]:
%pip install -r ../requirements.txt

Note: you may need to restart the kernel to use updated packages.


Import and manage libraries and packages


In [9]:
import torch
from transformer_lens import HookedTransformer
import pandas as pd
import numpy as np
import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import einops
from ivy import to_numpy
import plotly_express as px

# Enable gradient calculations (by default)

torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x326458470>

Load GPT-2 small model from transformer_lens

In [10]:
model = HookedTransformer.from_pretrained("gpt2-small")


Loaded pretrained model gpt2-small into HookedTransformer


Initialise model components as variables

In [11]:
n_layers = model.cfg.n_layers  # Number of transformer layers
d_model = model.cfg.d_model    # Dimension of the model
n_heads = model.cfg.n_heads    # Number of attention heads
d_head = model.cfg.d_head      # Dimension of each attention head
d_mlp = model.cfg.d_mlp        # Dimension of the MLP (Feed Forward network) within the transformer
d_vocab = model.cfg.d_vocab    # Size of the vocabulary

Load and display common words

In [12]:
common_words = open("../common_words.txt", "r").read().split("\n")
print(common_words[:10])

['a', 'aa', 'aaa', 'aaron', 'ab', 'abandoned', 'abc', 'aberdeen', 'abilities', 'ability']


Calculate the number of tokens for each common word

In [13]:
num_tokens = [len(model.to_tokens(" " + word, prepend_bos=False).squeeze(0)) for word in common_words]
print(list(zip(num_tokens, common_words))[:10])

[(1, 'a'), (2, 'aa'), (2, 'aaa'), (2, 'aaron'), (1, 'ab'), (1, 'abandoned'), (2, 'abc'), (2, 'aberdeen'), (1, 'abilities'), (1, 'ability')]


Create a DataFrame of words and their token counts

In [14]:
word_df = pd.DataFrame({"word": common_words, "num_tokens": num_tokens})
word_df = word_df.query('num_tokens < 4')  # Filter words with less than 4 tokens
word_df.value_counts("num_tokens")

num_tokens
1    8025
2    1614
3     335
Name: count, dtype: int64

Define the prefix for context and set parameters

In [15]:
prefix = "The United States Declaration of Independence received its first formal public reading, in Philadelphia.\nWhen"
PREFIX_LENGTH = len(model.to_tokens(prefix, prepend_bos=True).squeeze(0))
NUM_WORDS = 7
MAX_WORD_LENGTH = 3

Split the data into training and testing sets

In [16]:
train_filter = np.random.rand(len(word_df)) < 0.8
train_word_df = word_df.iloc[train_filter]
test_word_df = word_df.iloc[~train_filter]
print(train_word_df.value_counts("num_tokens"))
print(test_word_df.value_counts("num_tokens"))

num_tokens
1    6362
2    1297
3     260
Name: count, dtype: int64
num_tokens
1    1663
2     317
3      75
Name: count, dtype: int64


Group words by their token length

In [17]:
train_word_by_length_array = [np.array([" " + j for j in train_word_df[train_word_df.num_tokens == i].word.values]) for i in range(1, MAX_WORD_LENGTH + 1)]
test_word_by_length_array = [np.array([" " + j for j in test_word_df[test_word_df.num_tokens == i].word.values]) for i in range(1, MAX_WORD_LENGTH + 1)]

Define a function to generate batches of tokenized inputs

In [18]:
def gen_batch(batch_size, word_by_length_array):
    word_lengths = torch.randint(1, MAX_WORD_LENGTH+1, (batch_size, NUM_WORDS))
    words = []
    for i in range(batch_size):
        row = []
        for word_len in word_lengths[i].tolist():
            word = word_by_length_array[word_len-1][np.random.randint(len(word_by_length_array[word_len-1]))]
            row.append(word)
        words.append("".join(row))
    full_tokens = torch.ones((batch_size, PREFIX_LENGTH + MAX_WORD_LENGTH*NUM_WORDS), dtype=torch.int64)
    tokens = model.to_tokens([prefix + word for word in words], prepend_bos=True)
    full_tokens[:, :tokens.shape[-1]] = tokens
    
    first_token_indices = torch.concatenate([
        torch.zeros(batch_size, dtype=int)[:, None], word_lengths.cumsum(dim=-1)[..., :-1]
    ], dim=-1) + PREFIX_LENGTH
    
    last_token_indices = word_lengths.cumsum(dim=-1) - 1 + PREFIX_LENGTH
    return full_tokens, words, word_lengths, first_token_indices, last_token_indices

Generate a batch of tokens and their related information

In [19]:
tokens, words, word_lengths, first_token_indices, last_token_indices = gen_batch(10, train_word_by_length_array)
tokens, words, word_lengths, first_token_indices, last_token_indices

(tensor([[50256,   464,  1578,  1829, 24720,   286, 20153,  2722,   663,   717,
           8766,  1171,  3555,    11,   287,  8857,    13,   198,  2215,   269,
          15356,  2146,  7426,  1808,  5296,   324, 13131,  8333,  2156,   318,
          38006, 15065, 50256, 50256, 50256, 50256, 50256,     1,     1,     1],
         [50256,   464,  1578,  1829, 24720,   286, 20153,  2722,   663,   717,
           8766,  1171,  3555,    11,   287,  8857,    13,   198,  2215,  8574,
          37941,  1453,  1291,  4703,  1573,   285,    85,   267,   805,  6145,
           8984, 50256, 50256, 50256, 50256, 50256, 50256,     1,     1,     1],
         [50256,   464,  1578,  1829, 24720,   286, 20153,  2722,   663,   717,
           8766,  1171,  3555,    11,   287,  8857,    13,   198,  2215,   716,
           5168,   785, 35961,  6799, 25761, 14354,   300,   590,  2853,   332,
            256,  1872,  8149, 50256, 50256, 50256, 50256,     1,     1,     1],
         [50256,   464,  1578,  1829,

Set training parameters

In [20]:
batch_size = 256
epochs = 1000

Collect residuals for tokens across multiple epochs

In [21]:
torch.set_grad_enabled(False)
epochs = 100
all_first_token_residuals = []
all_last_token_residuals = []
                                                        
for i in tqdm.tqdm(range(epochs)):
    tokens, words, word_lengths, first_token_indices, last_token_indices = gen_batch(batch_size, train_word_by_length_array)
    with torch.no_grad():
        # _, cache = model.run_with_cache(tokens.cuda(), names_filter=lambda x: x.endswith("resid_post"))
        _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post")) # Can't run run_with_cache with CUDA on my Mac, just passing 'tokens' instead 
        residuals = cache.stack_activation("resid_post")
        first_token_residuals = residuals[:, torch.arange(len(first_token_indices)).to(residuals.device)[:, None], first_token_indices, :]
        last_token_residuals = residuals[:, torch.arange(len(last_token_indices)).to(residuals.device)[:, None], last_token_indices, :]
        print("Shapes", first_token_residuals.shape, last_token_residuals.shape)
        all_first_token_residuals.append(to_numpy(first_token_residuals))
        all_last_token_residuals.append(to_numpy(last_token_residuals))

  0%|          | 0/100 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  1%|          | 1/100 [00:02<04:47,  2.91s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  2%|▏         | 2/100 [00:06<05:25,  3.32s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  3%|▎         | 3/100 [00:10<05:47,  3.58s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  4%|▍         | 4/100 [00:13<05:40,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  5%|▌         | 5/100 [00:17<05:38,  3.56s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  6%|▌         | 6/100 [00:21<05:35,  3.57s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  7%|▋         | 7/100 [00:24<05:32,  3.58s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  8%|▊         | 8/100 [00:28<05:26,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


  9%|▉         | 9/100 [00:31<05:15,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 10%|█         | 10/100 [00:34<05:12,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 11%|█         | 11/100 [00:38<04:58,  3.36s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 12%|█▏        | 12/100 [00:41<04:57,  3.38s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 13%|█▎        | 13/100 [00:45<05:17,  3.65s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 14%|█▍        | 14/100 [00:49<05:23,  3.76s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 15%|█▌        | 15/100 [00:53<05:14,  3.70s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 16%|█▌        | 16/100 [00:56<05:01,  3.59s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 17%|█▋        | 17/100 [01:00<04:54,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 18%|█▊        | 18/100 [01:03<04:49,  3.53s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 19%|█▉        | 19/100 [01:07<04:46,  3.54s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 20%|██        | 20/100 [01:10<04:43,  3.54s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 21%|██        | 21/100 [01:14<04:40,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 22%|██▏       | 22/100 [01:17<04:36,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 23%|██▎       | 23/100 [01:21<04:43,  3.69s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 24%|██▍       | 24/100 [01:25<04:32,  3.59s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 25%|██▌       | 25/100 [01:28<04:27,  3.57s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 26%|██▌       | 26/100 [01:32<04:21,  3.53s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 27%|██▋       | 27/100 [01:35<04:15,  3.50s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 28%|██▊       | 28/100 [01:39<04:14,  3.54s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 29%|██▉       | 29/100 [01:42<04:03,  3.43s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 30%|███       | 30/100 [01:45<04:02,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 31%|███       | 31/100 [01:49<03:58,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 32%|███▏      | 32/100 [01:52<03:58,  3.51s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 33%|███▎      | 33/100 [01:56<03:49,  3.42s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 34%|███▍      | 34/100 [01:59<03:48,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 35%|███▌      | 35/100 [02:03<03:46,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 36%|███▌      | 36/100 [02:06<03:45,  3.52s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 37%|███▋      | 37/100 [02:10<03:41,  3.52s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 38%|███▊      | 38/100 [02:13<03:31,  3.41s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 39%|███▉      | 39/100 [02:17<03:31,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 40%|████      | 40/100 [02:20<03:28,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 41%|████      | 41/100 [02:24<03:26,  3.50s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 42%|████▏     | 42/100 [02:27<03:21,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 43%|████▎     | 43/100 [02:31<03:17,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 44%|████▍     | 44/100 [02:35<03:23,  3.64s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 45%|████▌     | 45/100 [02:38<03:20,  3.64s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 46%|████▌     | 46/100 [02:42<03:13,  3.59s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 47%|████▋     | 47/100 [02:46<03:21,  3.80s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 48%|████▊     | 48/100 [02:50<03:17,  3.80s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 49%|████▉     | 49/100 [02:53<03:09,  3.72s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 50%|█████     | 50/100 [02:57<03:03,  3.67s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 51%|█████     | 51/100 [03:00<02:57,  3.62s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 52%|█████▏    | 52/100 [03:04<02:54,  3.64s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 53%|█████▎    | 53/100 [03:08<02:49,  3.60s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 54%|█████▍    | 54/100 [03:11<02:44,  3.58s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 55%|█████▌    | 55/100 [03:15<02:39,  3.54s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 56%|█████▌    | 56/100 [03:18<02:36,  3.57s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 57%|█████▋    | 57/100 [03:22<02:30,  3.50s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 58%|█████▊    | 58/100 [03:25<02:27,  3.51s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 59%|█████▉    | 59/100 [03:28<02:21,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 60%|██████    | 60/100 [03:32<02:20,  3.50s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 61%|██████    | 61/100 [03:36<02:18,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 62%|██████▏   | 62/100 [03:39<02:14,  3.53s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 63%|██████▎   | 63/100 [03:43<02:11,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 64%|██████▍   | 64/100 [03:46<02:07,  3.55s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 65%|██████▌   | 65/100 [03:50<02:02,  3.51s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 66%|██████▌   | 66/100 [03:53<01:59,  3.51s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 67%|██████▋   | 67/100 [03:57<01:55,  3.51s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 68%|██████▊   | 68/100 [04:01<01:56,  3.64s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 69%|██████▉   | 69/100 [04:04<01:53,  3.65s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 70%|███████   | 70/100 [04:08<01:49,  3.63s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 71%|███████   | 71/100 [04:11<01:44,  3.59s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 72%|███████▏  | 72/100 [04:15<01:39,  3.56s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 73%|███████▎  | 73/100 [04:19<01:40,  3.71s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 74%|███████▍  | 74/100 [04:22<01:33,  3.59s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 75%|███████▌  | 75/100 [04:26<01:28,  3.53s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 76%|███████▌  | 76/100 [04:29<01:23,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 77%|███████▋  | 77/100 [04:33<01:20,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 78%|███████▊  | 78/100 [04:36<01:15,  3.45s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 79%|███████▉  | 79/100 [04:39<01:12,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 80%|████████  | 80/100 [04:43<01:09,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 81%|████████  | 81/100 [04:46<01:05,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 82%|████████▏ | 82/100 [04:50<01:02,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 83%|████████▎ | 83/100 [04:53<00:59,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 84%|████████▍ | 84/100 [04:57<00:55,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 85%|████████▌ | 85/100 [05:00<00:52,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 86%|████████▌ | 86/100 [05:04<00:48,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 87%|████████▋ | 87/100 [05:07<00:45,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 88%|████████▊ | 88/100 [05:11<00:41,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 89%|████████▉ | 89/100 [05:14<00:38,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 90%|█████████ | 90/100 [05:18<00:34,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 91%|█████████ | 91/100 [05:21<00:31,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 92%|█████████▏| 92/100 [05:25<00:27,  3.47s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 93%|█████████▎| 93/100 [05:28<00:24,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 94%|█████████▍| 94/100 [05:32<00:20,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 95%|█████████▌| 95/100 [05:35<00:17,  3.49s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 96%|█████████▌| 96/100 [05:39<00:13,  3.48s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 97%|█████████▋| 97/100 [05:42<00:10,  3.45s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 98%|█████████▊| 98/100 [05:45<00:06,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


 99%|█████████▉| 99/100 [05:49<00:03,  3.46s/it]

Shapes torch.Size([12, 256, 7, 768]) torch.Size([12, 256, 7, 768])


100%|██████████| 100/100 [05:52<00:00,  3.53s/it]


#### Run from here to avoid repeating 100 epoch runs!

Data preparation for training a Logistic Regression model

In [26]:
LAYER = 3
y = np.array([j for i in range(len(all_first_token_residuals[0])) for j in range(NUM_WORDS)])
layer_data = all_last_token_residuals[LAYER]
X = layer_data[:, :].reshape(-1, d_model)

# Split the dataset into train and test sets

x_indices = to_numpy(torch.randperm(len(X))[:10000])
y_indices = to_numpy(torch.randperm(len(y))[:10000])
common_indices = np.intersect1d(x_indices, y_indices)

X_train, X_test, y_train, y_test = train_test_split(X[common_indices], y[common_indices], test_size=0.1)

Create and train a Logistic Regression model

In [27]:
lr_model = LogisticRegression(multi_class='ovr', solver='saga', random_state=42, max_iter=100, C=1.0)
lr_model.fit(X_train, y_train)

Evaluate the model

In [28]:
y_pred = lr_model.predict(X_train)
print(classification_report(y_train, y_pred))
y_pred = lr_model.predict(X_test)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         7
           1       1.00      1.00      1.00         5
           2       1.00      1.00      1.00         2
           3       1.00      1.00      1.00         5
           4       1.00      1.00      1.00         4
           5       1.00      1.00      1.00         2
           6       1.00      1.00      1.00         5

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       1.00      0.50      0.67         2
           2       1.00      0.50      0.67         2

    accuracy                           0.50         4
   macro avg       0.67      0.33      0.44         4
weighted avg       1.00      0.50      0.67         4



Generate predictions on test batches

In [29]:
test_batches = 10
last_token_predictions_list = []
last_token_abs_indices_list = []

with torch.no_grad():
    for i in tqdm.tqdm(range(test_batches)):
        tokens, words, word_lengths, first_token_indices, last_token_indices = gen_batch(batch_size, test_word_by_length_array)
        _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post"))
        residuals = cache.stack_activation("resid_post")
        first_token_residuals = residuals[:, torch.arange(len(first_token_indices)).to(residuals.device)[:, None], first_token_indices, :]
        last_token_residuals = residuals[:, torch.arange(len(last_token_indices)).to(residuals.device)[:, None], last_token_indices, :]
        last_token_resids = to_numpy(einops.rearrange(last_token_residuals[LAYER], "batch word d_model -> (batch word) d_model"))
        last_token_predictions_list.append(lr_model.predict(last_token_resids))
        last_token_abs_indices_list.append(to_numpy(last_token_indices.flatten()))

100%|██████████| 10/10 [00:16<00:00,  1.62s/it]


Prepare and visualize results

In [30]:
last_token_abs_indices = np.concatenate(last_token_abs_indices_list)
last_token_predictions = np.concatenate(last_token_predictions_list)

df = pd.DataFrame({
    "index": [i for _ in range(batch_size * test_batches) for i in range(NUM_WORDS)],
    "abs_pos": last_token_abs_indices,
    "pred": last_token_predictions
})

Plot histogram of the prediction results

In [31]:
px.histogram(df, x="abs_pos", color="pred", facet_row="index", barnorm="fraction").show()