Install requirements

In [3]:
%pip install transformers einops plotly_express ivy torch tqdm scikit-learn pandas numpy transformer_lens datasets nbformat

Note: you may need to restart the kernel to use updated packages.


Import and manage libraries and packages


In [5]:
import torch
from transformer_lens import HookedTransformer
import pandas as pd
import numpy as np
import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import einops
from ivy import to_numpy
import plotly_express as px

# Enable gradient calculations (by default)

torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x323aec3d0>

Load GPT-2 small model from transformer_lens

In [6]:
model = HookedTransformer.from_pretrained("gpt2-small")


Loaded pretrained model gpt2-small into HookedTransformer


Initialise model components as variables

In [7]:
n_layers = model.cfg.n_layers  # Number of transformer layers
d_model = model.cfg.d_model    # Dimension of the model
n_heads = model.cfg.n_heads    # Number of attention heads
d_head = model.cfg.d_head      # Dimension of each attention head
d_mlp = model.cfg.d_mlp        # Dimension of the MLP (Feed Forward network) within the transformer
d_vocab = model.cfg.d_vocab    # Size of the vocabulary

Load and display common words

In [8]:
from datasets import load_dataset

common_words = load_dataset("Alamerton/common-words")
common_words = common_words["train"]['text']
# common_words = open("common_words.txt", "r").read().split("\n")
print(common_words[:10])

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


['a', 'aa', 'aaa', 'aaron', 'ab', 'abandoned', 'abc', 'aberdeen', 'abilities', 'ability']


In [9]:
# Word Pool Creation
from collections import defaultdict

word_pools = defaultdict(list)

for word in common_words:
    tokens = model.to_tokens(" " + word, prepend_bos=False).squeeze(0)
    token_length = len(tokens)
    if 1 <= token_length <= 3:
        word_pools[token_length].append(word)

print("Word pool sizes:")
for length, words in word_pools.items():
    print(f"{length} token(s): {len(words)} words")

Word pool sizes:
1 token(s): 8025 words
2 token(s): 1614 words
3 token(s): 335 words


In [10]:
# Sequence Template Design
import random

def generate_prefix(min_length=15, max_length=25):
    prefix_words = [
        "The", "In", "As", "When", "While", "Although", "Despite",
        "After", "Before", "During", "Since", "Until", "Because",
        "If", "Unless", "Though", "Whether", "Once", "Whenever"
    ]
    prefix = random.choice(prefix_words)
    while len(model.to_tokens(prefix, prepend_bos=True)[0]) < min_length:
        prefix += " " + random.choice(common_words)
    return prefix

In [11]:
# Sequence Generation Algorithm
def generate_balanced_sequence(target_position, target_absolute_position):
    prefix = generate_prefix()
    prefix_tokens = model.to_tokens(prefix, prepend_bos=True)[0]
    
    sequence = [prefix]
    current_position = 1
    current_tokens = len(prefix_tokens)
    
    # Calculate how many tokens we need to add before the target position
    tokens_before_target = target_absolute_position - current_tokens - 3  # -3 to ensure space for target word
    
    if tokens_before_target < 0:
        return None  # Not enough space for the sequence
    
    # Add words before the target position
    while current_position < target_position:
        available_tokens = tokens_before_target - (target_position - current_position - 1)
        if available_tokens <= 0:
            return None  # Not enough space
        token_length = min(random.randint(1, 3), available_tokens)
        word = random.choice(word_pools[token_length])
        sequence.append(word)
        current_tokens += token_length
        tokens_before_target -= token_length
        current_position += 1
    
    # Add the target word
    remaining_tokens = target_absolute_position - current_tokens
    if 1 <= remaining_tokens <= 3:
        target_word = random.choice(word_pools[remaining_tokens])
        sequence.append(target_word)
        current_tokens += remaining_tokens
    else:
        return None  # Invalid sequence
    
    # Fill the rest of the sequence
    while len(sequence) < 7:
        token_length = random.randint(1, 3)
        word = random.choice(word_pools[token_length])
        sequence.append(word)
    
    return " ".join(sequence)

from tqdm import tqdm

def generate_balanced_dataset(n_samples_per_position=1000):
    dataset = []
    for target_position in tqdm(range(2, 6), desc="Generating dataset"):
        position_dataset = []
        attempts = 0
        while len(position_dataset) < n_samples_per_position and attempts < n_samples_per_position * 10:
            target_absolute_position = random.randint(20, 30)
            sequence = generate_balanced_sequence(target_position, target_absolute_position)
            if sequence is not None:
                tokens = model.to_tokens(sequence, prepend_bos=True)[0]
                if len(tokens) >= target_absolute_position:
                    position_dataset.append((sequence, target_position, target_absolute_position))
            attempts += 1
        dataset.extend(position_dataset)
    return dataset

balanced_dataset = generate_balanced_dataset()

balanced_dataset = generate_balanced_dataset()

Generating dataset: 100%|██████████| 4/4 [00:58<00:00, 14.53s/it]
Generating dataset: 100%|██████████| 4/4 [00:53<00:00, 13.33s/it]


In [12]:
# Dataset Validation and Analysis
def analyze_dataset(dataset):
    position_counts = defaultdict(int)
    absolute_position_counts = defaultdict(int)
    
    for _, target_position, target_absolute_position in dataset:
        position_counts[target_position] += 1
        absolute_position_counts[target_absolute_position] += 1
    
    print("Target word position distribution:")
    for pos, count in sorted(position_counts.items()):
        print(f"{pos}th word: {count}")
    
    print("\nAbsolute token position distribution:")
    for pos, count in sorted(absolute_position_counts.items()):
        print(f"Token {pos}: {count}")

analyze_dataset(balanced_dataset)

Target word position distribution:
2th word: 1000
3th word: 1000
4th word: 1000
5th word: 1000

Absolute token position distribution:
Token 20: 813
Token 21: 897
Token 22: 724
Token 23: 537
Token 24: 407
Token 25: 266
Token 26: 184
Token 27: 112
Token 28: 43
Token 29: 12
Token 30: 5


Define a function to generate batches of tokenized inputs

In [13]:
def gen_batch(batch_size, dataset):
    indices = np.random.choice(len(dataset), batch_size, replace=True)
    sequences, target_positions, target_absolute_positions = zip(*[dataset[i] for i in indices])
    
    tokens = model.to_tokens(sequences, prepend_bos=True)
    max_length = tokens.shape[1]
    full_tokens = torch.ones((batch_size, max_length), dtype=torch.int64)
    full_tokens[:, :tokens.shape[1]] = tokens
    
    word_lengths = [len(model.to_tokens(seq.split()[pos-1], prepend_bos=False)[0]) 
                    for seq, pos in zip(sequences, target_positions)]
    
    first_token_indices = torch.tensor([pos - 1 for pos in target_absolute_positions])
    last_token_indices = torch.tensor([pos - 1 + length for pos, length in zip(target_absolute_positions, word_lengths)])
    
    return full_tokens, sequences, torch.tensor(target_positions), first_token_indices, last_token_indices

tokens, sequences, word_positions, first_token_indices, last_token_indices = gen_batch(10, balanced_dataset)
print("Sample batch:")
for seq, pos, first, last in zip(sequences[:5], word_positions[:5], first_token_indices[:5], last_token_indices[:5]):
    print(f"Sequence: {seq}")
    print(f"Target word position: {pos}th word")
    print(f"Target word tokens: {first}-{last}")
    print()

Sample batch:
Sequence: Until trains weird rs cage lance screw stranger inputs australia hearts getting playing wendy mice rx mtv
Target word position: 3th word
Target word tokens: 19-21

Sequence: If ruling nav palm quantum trap scsi tie quantum networks symptoms africa verizon ver holland phase shanghai raymond
Target word position: 3th word
Target word tokens: 20-21

Sequence: During shorts settle smoking vt generating locked started content blue claire church exist seeking jessica arnold ky unwrap
Target word position: 3th word
Target word tokens: 19-21

Sequence: Despite evening state responses junk glad wrestling conducting guestbook serbia academy told thongs hdtv iceland bright zambia encouraging
Target word position: 2th word
Target word tokens: 19-21

Sequence: As menus york intervals freeze po champion fucked lat kodak rentcom zdnet xanax clinton charlotte beastality labeled
Target word position: 2th word
Target word tokens: 21-23



Generate a batch of tokens and their related information

In [14]:
# Test the new gen_batch function
tokens, sequences, word_positions, first_token_indices, last_token_indices = gen_batch(10, balanced_dataset)
print("Sample batch:")
for seq, pos, first, last in zip(sequences[:5], word_positions[:5], first_token_indices[:5], last_token_indices[:5]):
    print(f"Sequence: {seq}")
    print(f"Target word position: {pos}th word")
    print(f"Target word tokens: {first}-{last}")
    print()

Sample batch:
Sequence: During emergency consensus row ordering retro gbp gary animal repairs grows gilbert bradley wing lucy babes rugs europe
Target word position: 3th word
Target word tokens: 23-25

Sequence: Before proposal kill bizrate rip powerseller entire mechanical airport representative signs ethiopia gazette jeremy thomas propecia conscious
Target word position: 2th word
Target word tokens: 20-22

Sequence: Whenever tony fraud collar cooperative rebel george deposit elvis combat hq ago throws element minerals hentai serious
Target word position: 5th word
Target word tokens: 22-24

Sequence: The sept adams vbulletin pieces extras india evaluate fees giant imposed bibliographic netscape midwest
Target word position: 4th word
Target word tokens: 20-24

Sequence: When patient commissions reports offerings rca expect roger jo hope brain music kenya viii asia atlantic msgid frog
Target word position: 4th word
Target word tokens: 24-25



Set training parameters

In [15]:
batch_size = 128
epochs = 100

Collect residuals for tokens across multiple epochs

In [17]:
from tqdm import tqdm

torch.set_grad_enabled(False)
epochs = 100
all_first_token_residuals = []
all_last_token_residuals = []

for i in tqdm(range(epochs)):
    tokens, sequences, word_positions, first_token_indices, last_token_indices = gen_batch(batch_size, balanced_dataset)
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post"))
        residuals = cache.stack_activation("resid_post")
        first_token_residuals = residuals[:, torch.arange(len(first_token_indices)).to(residuals.device)[:, None], first_token_indices, :]
        last_token_residuals = residuals[:, torch.arange(len(last_token_indices)).to(residuals.device)[:, None], last_token_indices, :]
        print("Shapes", first_token_residuals.shape, last_token_residuals.shape)
        all_first_token_residuals.append(to_numpy(first_token_residuals))
        all_last_token_residuals.append(to_numpy(last_token_residuals))

  0%|          | 0/100 [00:00<?, ?it/s]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  1%|          | 1/100 [00:51<1:24:20, 51.11s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  2%|▏         | 2/100 [01:41<1:22:36, 50.57s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  3%|▎         | 3/100 [02:44<1:31:22, 56.52s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  4%|▍         | 4/100 [03:47<1:34:34, 59.11s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  5%|▌         | 5/100 [04:52<1:36:38, 61.04s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  6%|▌         | 6/100 [06:14<1:46:54, 68.24s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  7%|▋         | 7/100 [07:48<1:58:35, 76.51s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  8%|▊         | 8/100 [09:11<2:00:40, 78.70s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


  9%|▉         | 9/100 [10:15<1:52:22, 74.09s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 10%|█         | 10/100 [12:12<2:10:53, 87.26s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 11%|█         | 11/100 [13:36<2:08:14, 86.46s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 12%|█▏        | 12/100 [14:56<2:03:56, 84.50s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 13%|█▎        | 13/100 [16:28<2:05:50, 86.78s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 14%|█▍        | 14/100 [17:42<1:58:31, 82.69s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 15%|█▌        | 15/100 [19:16<2:01:53, 86.04s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 16%|█▌        | 16/100 [20:48<2:03:19, 88.09s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 17%|█▋        | 17/100 [23:48<2:40:04, 115.72s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 18%|█▊        | 18/100 [28:35<3:48:10, 166.96s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 19%|█▉        | 19/100 [30:19<3:20:11, 148.29s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


 20%|██        | 20/100 [31:25<2:44:35, 123.44s/it]

Shapes torch.Size([12, 128, 128, 768]) torch.Size([12, 128, 128, 768])


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1076a6bd0>>
Traceback (most recent call last):
  File "/Users/paulreynolds/Documents/23-24/SAIL/GitHub/New-July/Hidden-Coordinates/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


#### Run from here to avoid repeating 100 epoch runs!

Data preparation for training a Logistic Regression model

In [None]:
LAYER = 0
y = np.array([j for i in range(len(all_first_token_residuals[0])) for j in range(2, 6)])  # 2nd to 5th word
layer_data = all_last_token_residuals[LAYER]
X = layer_data[:, :].reshape(-1, d_model)

# Split the dataset into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y)

Create and train a Logistic Regression model

In [None]:
lr_model = LogisticRegression(multi_class='ovr', solver='saga', random_state=42, max_iter=100, C=1.0)
lr_model.fit(X_train, y_train)

Evaluate the model

In [None]:
y_pred = lr_model.predict(X_train)
print(classification_report(y_train, y_pred))
y_pred = lr_model.predict(X_test)
print(classification_report(y_test, y_pred))

Generate predictions on test batches

In [None]:
test_batches = 10
last_token_predictions_list = []
last_token_abs_indices_list = []

with torch.no_grad():
    for i in tqdm.tqdm(range(test_batches)):
        tokens, sequences, word_positions, first_token_indices, last_token_indices = gen_batch(batch_size, balanced_dataset)
        _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post"))
        residuals = cache.stack_activation("resid_post")
        first_token_residuals = residuals[:, torch.arange(len(first_token_indices)).to(residuals.device)[:, None], first_token_indices, :]
        last_token_residuals = residuals[:, torch.arange(len(last_token_indices)).to(residuals.device)[:, None], last_token_indices, :]
        last_token_resids = to_numpy(einops.rearrange(last_token_residuals[LAYER], "batch word d_model -> (batch word) d_model"))
        last_token_predictions_list.append(lr_model.predict(last_token_resids))
        last_token_abs_indices_list.append(to_numpy(last_token_indices.flatten()))

Prepare and visualize results

In [None]:
last_token_abs_indices = np.concatenate(last_token_abs_indices_list)
last_token_predictions = np.concatenate(last_token_predictions_list)

df = pd.DataFrame({
    "index": [i for _ in range(batch_size * test_batches) for i in range(NUM_WORDS)],
    "abs_pos": last_token_abs_indices,
    "pred": last_token_predictions
})

Plot histogram of the prediction results

In [None]:
px.histogram(df, x="abs_pos", color="pred", facet_row="index", barnorm="fraction").show()