<a href="https://colab.research.google.com/github/Theosdoor/list-comp/blob/main/list_comp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
    print("NOT running as a Colab notebook")

# Install if in Colab
if IN_COLAB:
    import google.colab
    from google.colab import drive
    drive.mount('/content/drive')
    save_path_prefix = '/content/drive/MyDrive/Colab Notebooks/'

    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    save_path_prefix = ''
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

NOT running as a Colab notebook


In [None]:
import numpy as np
import torch, random
# from sklearn.linear_model import LogisticRegression
# from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import matplotlib.pyplot as plt
import plotly.express as px
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig
import os
from sklearn.decomposition import PCA
# from mpl_toolkits.mplot3d import Axes3D
# import matplotlib.pyplot as plt

# added by gemini
from transformer_lens import utils
import plotly.graph_objects as go
import pandas as pd


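How does this model work?

A minimal two-layer, attention-only transformer (one head per layer, no MLPs). The value matrix $W_V$ is fixed to the identity and frozen, so each head's value vectors are just its residual-stream inputs (up to a bias). The attention mask also stops the output positions from attending to the input digits, so the list has to be routed through the separator token.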

In [None]:
# --- Task & model configuration ---
DIGITS = list(range(10))     # digit tokens: token id == digit value (0-9)
LIST_LEN = 2                 # length of the list to copy
SPECIAL = len(DIGITS)        # separator token id (10), the first id after the digits
VOCAB = len(DIGITS) + 1      # digits + separator = 11 tokens
SEQ_LEN = LIST_LEN * 2 + 1   # [d_1..d_L, SEP, d_1..d_L]
D_MODEL = 16
N_HEAD = 1
N_LAYER = 2                  # 2 layers, each with a single attention head
# Pin W_V to the identity and freeze it (attention heads can then only move
# their inputs around). set_WV_identity_and_freeze copies a D_MODEL x D_MODEL
# identity, so this needs d_head == D_MODEL, i.e. N_HEAD == 1.
FREEZE_WV = True
# The checkpoint name encodes LIST_LEN, so a model trained for a different
# list length (different n_ctx) is never silently loaded.
MODEL_PATH = f"artifacts/len_{LIST_LEN}.pt"

class DigitDataset(Dataset):
    """``n`` random digit lists, each served as a copy-task token sequence.

    Item ``idx`` is the long tensor ``[d_1, ..., d_L, SPECIAL, d_1, ..., d_L]``
    (L = LIST_LEN, total length SEQ_LEN): the list, the separator token, then
    the list again as the copy target. Lists are sampled once in ``__init__``,
    so the dataset is fixed for its lifetime.
    """

    def __init__(self, n):
        # n lists, each of LIST_LEN digits drawn uniformly from 0..9
        self.data = [[random.randint(0, 9) for _ in range(LIST_LEN)] for _ in range(n)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # e.g. LIST_LEN = 2: [d1, d2, SPECIAL, d1, d2]
        seq = self.data[idx]
        tok = seq + [SPECIAL] + seq
        return torch.tensor(tok, dtype=torch.long)

def build_mask(n: int, list_len=None) -> torch.Tensor:
    """Additive attention mask for a length-``n`` sequence ``[d_1..d_L, SEP, o_1..]``.

    Rows are query positions and columns are key positions. An entry is 0 where
    attention is allowed and ``-inf`` where it is forbidden; adding ``-inf`` to a
    score makes it 0 after the softmax. Two rules are applied:

    * causal: a query never attends to a later position;
    * list bottleneck: queries after the separator (positions > L) may not
      attend to the input digits (positions < L), so information about the
      input digits can only reach them through positions >= L (the separator
      and earlier outputs).

    Example for L = 2, n = 5 (rows = queries, "." = 0, "x" = -inf)::

              d1 d2 SEP o1 o2
        d1     .  x  x   x  x
        d2     .  .  x   x  x
        SEP    .  .  .   x  x
        o1     x  x  .   .  x
        o2     x  x  .   .  .

    Args:
        n: sequence length (number of query and key positions).
        list_len: length L of the input list; defaults to the global LIST_LEN.

    Returns:
        float32 tensor of shape (n, n).
    """
    if list_len is None:
        list_len = LIST_LEN
    allowed = torch.ones(n, n, dtype=torch.bool).tril()  # causal: key <= query
    allowed[list_len + 1:, :list_len] = False  # outputs can't see the input digits
    return torch.zeros(n, n).masked_fill(~allowed, float("-inf"))


def make_model(device: str = "cuda") -> "HookedTransformer":
    """Construct an untrained attention-only HookedTransformer for the copy task.

    The architecture comes from the module constants: N_LAYER layers of N_HEAD
    heads, a D_MODEL-wide residual stream, SEQ_LEN context and VOCAB tokens in
    and out. When FREEZE_WV is set, every W_V is pinned to the identity and
    frozen. The custom attention mask is not attached here; train() and
    load_model() do that with attach_custom_mask.
    """
    head_dim = D_MODEL // N_HEAD
    config = HookedTransformerConfig(
        n_layers=N_LAYER,
        n_heads=N_HEAD,
        d_model=D_MODEL,
        d_head=head_dim,
        n_ctx=SEQ_LEN,
        d_vocab=VOCAB,
        d_vocab_out=VOCAB,
        attn_only=True,  # attention sublayers only, no MLPs
    )
    transformer = HookedTransformer(config).to(device)
    if FREEZE_WV:
        set_WV_identity_and_freeze(transformer)
    return transformer


def attach_custom_mask(model: "HookedTransformer") -> None:
    """Permanently add ``build_mask`` to every layer's attention scores.

    On each forward pass the hook builds the mask for the current key length
    and adds it in place to the pre-softmax scores, so forbidden positions get
    -inf and receive zero attention after the softmax.
    """
    def _add_mask(attn_scores, hook=None):
        key_len = attn_scores.size(-1)
        mask = build_mask(key_len).to(attn_scores.device)
        attn_scores += mask
        return attn_scores

    for attn in (block.attn for block in model.blocks):
        attn.hook_attn_scores.add_perma_hook(_add_mask)


def set_WV_identity_and_freeze(model: "HookedTransformer") -> None:
    """Overwrite every layer's W_V with the identity and exclude it from training.

    With W_V = I, each head's value vectors are its residual-stream inputs
    (b_V is left untouched). W_V is indexed as [head, d_model, d_head], and the
    identity is copied into every head. The identity size is read from W_V
    itself rather than from the global D_MODEL.

    Raises:
        ValueError: if W_V is not square per head (d_head != d_model); an
            identity value map is undefined then.
    """
    with torch.no_grad():
        for block in model.blocks:
            W_V = block.attn.W_V
            d_in, d_out = W_V.shape[-2], W_V.shape[-1]
            if d_in != d_out:
                raise ValueError(
                    f"W_V must be square per head to be set to the identity, "
                    f"got shape {tuple(W_V.shape)}"
                )
            identity = torch.eye(d_in, dtype=W_V.dtype, device=W_V.device)
            W_V.copy_(identity.expand_as(W_V))
            W_V.requires_grad = False


def _exact_match_accuracy(model: HookedTransformer, n_lists: int) -> float:
    """Fraction of ``n_lists`` fresh random lists that ``generate`` copies exactly (NaN if none)."""
    if n_lists <= 0:
        return float("nan")
    correct = 0
    for _ in range(n_lists):
        digits = [random.randint(0, 9) for _ in range(LIST_LEN)]
        correct += generate(model, digits) == digits
    return correct / n_lists


def train(
    epochs: int = 10,
    batch: int = 1024,
    size: int = 50000,
    val: int = 1000,
    device="cuda",
    lr: float = 1e-3,
) -> HookedTransformer:
    """Train a fresh model on the copy task and return it.

    Args:
        epochs: passes over the training set.
        batch: batch size.
        size: number of training lists (sampled once, reused every epoch).
        val: number of fresh random lists used for the exact-match accuracy
            printed after each epoch.
        device: torch device for the model and batches.
        lr: AdamW learning rate.

    Returns:
        The trained model, with the custom attention mask attached.
    """
    ds = DigitDataset(size)
    dl = DataLoader(ds, batch, shuffle=True)
    model = make_model(device)
    attach_custom_mask(model)
    # Frozen params (W_V when FREEZE_WV) never receive grads; pass only trainable ones.
    opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in tqdm(range(epochs)):
        for seq in dl:
            seq = seq.to(device)
            # Teacher forcing: predict token t+1 from tokens 0..t. The targets at
            # positions 0..LIST_LEN-2 are independent random input digits, so the
            # mean loss cannot drop much below (LIST_LEN-1)*ln(10)/(2*LIST_LEN).
            logits = model(seq[:, :-1])
            loss = loss_fn(logits.reshape(-1, VOCAB), seq[:, 1:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
        print(f"acc {_exact_match_accuracy(model, val):.2%}")
    return model


@torch.no_grad()
def generate(model: HookedTransformer, digits: list[int]) -> list[int]:
    """Greedily decode the LIST_LEN tokens that follow ``digits + [SPECIAL]``.

    Each step re-runs the model on the prompt plus everything generated so far
    and appends the argmax token at the last position. For a perfect copier
    the result equals ``digits``.
    """
    prompt = digits + [SPECIAL]
    generated: list[int] = []
    while len(generated) < LIST_LEN:
        target_device = next(model.parameters()).device
        context = torch.tensor(prompt + generated, device=target_device)[None, :]
        last_logits = model(context)[:, -1]
        generated.append(last_logits.argmax(-1).item())
    return generated


def save_model(model: HookedTransformer, path: str = MODEL_PATH):
    """Save ``model``'s state dict to ``path``, creating its parent directory if needed."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # `path` has one (a bare filename like "model.pt" has none).
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)


def load_model(
    path: str = MODEL_PATH, device: str = "cuda"
) -> HookedTransformer:
    """Rebuild the model, load weights from ``path``, and prepare it for inference.

    The custom attention mask is re-attached (hooks are not part of the state
    dict), and the model is put in eval mode.
    """
    model = make_model(device)
    # map_location moves the weights to the target device. weights_only=True
    # restricts unpickling to tensors and plain containers, so a tampered
    # checkpoint cannot run arbitrary code on load.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    attach_custom_mask(model)
    model.eval()
    return model

# USAGE: pick a device, then load a cached checkpoint or train (and cache) a model.

# Prefer CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed the Python, NumPy and torch RNGs so data generation, initialisation and
# shuffling are reproducible across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

path = save_path_prefix + MODEL_PATH
if os.path.exists(path):
    print("Loading model from", path)
    model = load_model(path, device)
else:
    print("Training model")
    model = train(epochs=20, device=device)
    save_model(model, path)

Training model
Moving model to device:  cuda


  0%|          | 0/20 [00:00<?, ?it/s]

acc 17.60%
acc 78.80%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 67.20%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%
acc 100.00%


In [None]:
# TODO: visualise the embedding and unembedding matrices with PCA

In [None]:
# Attention pattern (post-softmax weights) of one layer on a single example.

# Sample input list to visualise; the full sequence is [d1, d2, SEP, d1, d2].
sample_list = [8, 3]
full_sequence = sample_list + [SPECIAL] + sample_list
tokens = torch.tensor(full_sequence, device=device).unsqueeze(0)  # add batch dim
token_labels = (
    [f"d{i+1}({d})" for i, d in enumerate(sample_list)]
    + ["SEP"]
    + [f"o{i+1}({d})" for i, d in enumerate(sample_list)]
)

attn_layer = 0
attn_head = 0  # N_HEAD == 1, so head 0 is the only head
attn_hook_name = utils.get_act_name("pattern", attn_layer)

# stop_at_layer skips the layers after the one we inspect, so the first return
# value is not logits and is discarded. remove_batch_dim=True drops the size-1
# batch axis, leaving cached patterns of shape [head, query_pos, key_pos].
_, attn_cache = model.run_with_cache(
    tokens,
    remove_batch_dim=True,
    stop_at_layer=attn_layer + 1,
    names_filter=[attn_hook_name],
)
attention_pattern = attn_cache[attn_hook_name]
print(attention_pattern.shape)
print(attention_pattern)

print(f"Layer {attn_layer} Head {attn_head} Attention Pattern:")
# Index the head explicitly (rather than squeezing) to get a 2D
# [query_pos, key_pos] matrix that also works when N_HEAD > 1.
attention_pattern_2d = attention_pattern[attn_head].cpu().numpy()

print("Generating attention heatmap...")
fig = px.imshow(
    attention_pattern_2d,
    x=token_labels,
    y=token_labels,
    color_continuous_scale="Viridis",
    labels=dict(x="Key (Memory)", y="Query (Current Token)", color="Attention Weight"),
    title=f"Attention Pattern for Layer {attn_layer}, Head {attn_head}",
)

fig.show()

<class 'transformer_lens.ActivationCache.ActivationCache'>
ActivationCache with keys ['blocks.0.attn.hook_pattern']
torch.Size([1, 5, 5])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7594, 0.2406, 0.0000, 0.0000, 0.0000],
         [0.0976, 0.4498, 0.4526, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.8560, 0.1440, 0.0000],
         [0.0000, 0.0000, 0.6075, 0.2026, 0.1900]]], device='cuda:0')
Layer 0 Head Attention Patterns:
Generating attention heatmap...


In [None]:
# Ablation 1: replace the target head's attention pattern with uniform attention
# over the keys each query is allowed to see, and compare the loss.

layer_to_ablate = 0
head_index_to_ablate = 0  # N_HEAD == 1, so head 0 is the only head

# Derive the hook name from layer_to_ablate so the two can't drift apart.
attn_pattern_hook_name = utils.get_act_name("pattern", layer_to_ablate)


def uniform_ablation_hook(
    pattern,  # Float[torch.Tensor, "batch head_index query_pos key_pos"]
    hook,
):
    """Replace the target head's pattern with uniform attention over allowed keys.

    The allowed keys come from ``build_mask``, the same mask the model adds to
    its scores. Using ``pattern > 0`` instead would treat a permitted key whose
    softmax weight underflows to exactly 0 as masked out. Every query may attend
    to itself, so no row has zero allowed keys.
    """
    allowed = (build_mask(pattern.size(-1)) == 0).to(pattern.device)
    uniform_pattern = allowed.to(pattern.dtype) / allowed.sum(dim=-1, keepdim=True)
    # [query_pos, key_pos] broadcasts over the batch dimension.
    pattern[:, head_index_to_ablate, :, :] = uniform_pattern
    return pattern


# Losses on the single sample `tokens` from the attention-pattern cell. The loss
# averages every next-token prediction, including predicting d2 from d1, which
# is unpredictable. Even a perfect copier's expected loss is therefore about
# (LIST_LEN - 1) * ln(10) / (2 * LIST_LEN).
with torch.no_grad():
    original_loss = model(tokens, return_type="loss")
    uniform_ablated_loss = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(attn_pattern_hook_name, uniform_ablation_hook)],
    )

print("\n--- Ablation Results ---")
print(f"✅ Original Loss: {original_loss.item():.3f}")
print(f"➡️  Ablated Loss (Uniform Attention Pattern): {uniform_ablated_loss.item():.3f}")

tensor([[[1.],
         [2.],
         [3.],
         [2.],
         [3.]]], device='cuda:0')
pattern tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5000, 0.5000, 0.0000],
         [0.0000, 0.0000, 0.3333, 0.3333, 0.3333]]], device='cuda:0')
Uniformly distributing attention for L0H0...

--- Ablation Results ---
✅ Original Loss: 0.598
➡️  Ablated Loss (Uniform Attention Pattern): 1.468


In [None]:
# Ablation 2: (a) zero-ablate the target head's value vectors; (b) replace its
# attention pattern with the mean pattern over a dataset. Compare losses.

layer_to_ablate = 0
head_index_to_ablate = 0  # N_HEAD == 1, so head 0 is the only head

# Dataset to average the attention pattern over
dataset_for_avg = DigitDataset(2000)  # a few thousand samples for a stable average
dataloader_for_avg = DataLoader(dataset_for_avg, batch_size=128)

attn_pattern_hook_name = utils.get_act_name("pattern", layer_to_ablate)

# Running sum of the head's patterns, shape [query_pos, key_pos]
accumulated_patterns = torch.zeros((SEQ_LEN, SEQ_LEN), device=device)
num_samples = 0

print(f"Calculating the mean attention pattern for L{layer_to_ablate}H{head_index_to_ablate}...")
with torch.no_grad():
    for tokens_batch in tqdm(dataloader_for_avg):
        tokens_batch = tokens_batch.to(device)
        # Later layers cannot change this layer's pattern, so stop early.
        _, cache = model.run_with_cache(
            tokens_batch,
            names_filter=[attn_pattern_hook_name],
            stop_at_layer=layer_to_ablate + 1,
        )
        # [batch, head_index, query_pos, key_pos] -> [batch, query_pos, key_pos]
        patterns_batch = cache[attn_pattern_hook_name][:, head_index_to_ablate, :, :]
        accumulated_patterns += patterns_batch.sum(dim=0)
        num_samples += len(tokens_batch)

mean_pattern = accumulated_patterns / num_samples
print(f"Mean pattern calculated. Shape: {mean_pattern.shape}\n")


def head_ablation_hook(
    value,  # Float[torch.Tensor, "batch pos head_index d_head"]
    hook,
):
    """Zero the target head's value vectors, so its attention-weighted sum is zero."""
    value[:, :, head_index_to_ablate, :] = 0.0
    return value


def mean_ablation_hook(
    pattern,  # Float[torch.Tensor, "batch head_index query_pos key_pos"]
    hook,
):
    """Replace the target head's pattern with the dataset-mean pattern.

    mean_pattern is [SEQ_LEN, SEQ_LEN], so the input must have SEQ_LEN positions.
    """
    # [query_pos, key_pos] broadcasts over the batch dimension.
    pattern[:, head_index_to_ablate, :, :] = mean_pattern
    return pattern


# Losses on the single sample `tokens` (SEQ_LEN positions). As in Ablation 1,
# the loss includes the unpredictable d1 -> d2 prediction.
with torch.no_grad():
    original_loss = model(tokens, return_type="loss")
    ablated_loss = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(utils.get_act_name("v", layer_to_ablate), head_ablation_hook)],
    )
    # Hook the attention PATTERN this time
    mean_ablated_loss = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(attn_pattern_hook_name, mean_ablation_hook)],
    )
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")
print(f"Ablated Loss (Using Mean Attention Pattern): {mean_ablated_loss.item():.3f}")

Calculating the mean attention pattern for L0H0...


  0%|          | 0/16 [00:00<?, ?it/s]

Mean pattern calculated. Shape: torch.Size([5, 5])

Shape of the value tensor: torch.Size([1, 5, 1, 16])
BEFORE tensor([[[1., 0., 0., 0., 0.]]], device='cuda:0', grad_fn=<SliceBackward0>)
AFTER tensor([[[1., 0., 0., 0., 0.]]], device='cuda:0', grad_fn=<SliceBackward0>)
Original Loss: 0.598
Ablated Loss: 2.882
Ablated Loss (Using Mean Attention Pattern): 0.598


In [None]:
# Ablation ideas to try:
# * set the attention pattern to the average attention pattern
# * remove the residual stream?
# * ablate keys & queries

# Neel Nanda's grokking paper: derive a formula from the model weights that
# represents the learned algorithm, even though the actual d_i are arbitrary


# TODO
# attention pattern (try to break it: means, random values), formula ideas from
# Neel's paper, embedding/unembedding matrices

In [None]:
# --- 3. CORE ANALYSIS: EXTRACTING & VISUALIZING THE REPRESENTATION ---

# The position of the separator token, which acts as our compression point
COMPRESSION_POS = LIST_LEN

# Which layer's residual stream (resid_post) to read at the separator.
# build_mask stops output positions from attending to the input digits, so
# whatever they learn about the input list must come from the separator's
# residual stream as read by layer 1, i.e. resid_post of layer 0. With
# N_LAYER = 2, the separator's layer-1 resid_post is read by no other position;
# it only feeds the separator's own prediction (o1).
# Set to N_LAYER - 1 to plot the final-layer residual instead.
COMPRESSION_LAYER = 0
COMPRESSION_HOOK_NAME = utils.get_act_name("resid_post", COMPRESSION_LAYER)


@torch.no_grad()
def get_compressed_representation(model, digits_list):
    """Return the COMPRESSION_HOOK_NAME activation at the separator for one list.

    The input is only ``digits_list + [SPECIAL]``, because the representation is
    read before any output digit has been generated. Returns a CPU [d_model]
    tensor.
    """
    prompt_tokens = torch.tensor([digits_list + [SPECIAL]], device=device)
    _, cache = model.run_with_cache(prompt_tokens, names_filter=[COMPRESSION_HOOK_NAME])
    # [batch, position, d_model] -> [d_model]
    return cache[COMPRESSION_HOOK_NAME][0, COMPRESSION_POS].cpu()


# Generate a dataset of compressed representations
num_samples = 2000
compressed_vectors = []
all_digit_lists = []  # original lists, for hover text
labels = []  # sum of the digits, used for colouring
print(f"Generating {num_samples} samples...")
for _ in range(num_samples):
    digit_list = [random.randint(0, 9) for _ in range(LIST_LEN)]
    all_digit_lists.append(digit_list)
    compressed_vectors.append(get_compressed_representation(model, digit_list).numpy())
    labels.append(sum(digit_list))

compressed_vectors = np.array(compressed_vectors)

# Use PCA to find the 3 directions of largest variance
print("Running PCA...")
pca = PCA(n_components=3)
compressed_pca = pca.fit_transform(compressed_vectors)
explained = pca.explained_variance_ratio_
print(f"Explained variance ratio: {np.round(explained, 3)}")

# --- 4. VISUALIZATION ---

print("Generating plot...")

fig = go.Figure()

# Scatter of the compressed list vectors in PCA space
fig.add_trace(go.Scatter3d(
    x=compressed_pca[:, 0],
    y=compressed_pca[:, 1],
    z=compressed_pca[:, 2],
    mode='markers',
    marker=dict(
        size=3,
        color=labels,  # colour by the sum of digits
        colorscale='Turbo',
        opacity=0.7,
        colorbar=dict(title='Sum of Digits in List'),
    ),
    name='Compressed Lists',
    hovertext=[f'List: {dl}<br>Sum: {s}' for dl, s in zip(all_digit_lists, labels)],
    hoverinfo='text'
))

fig.update_layout(
    title=(
        f'Compressed Representation at SEP (resid_post, layer {COMPRESSION_LAYER}), '
        'Colored by Sum'
    ),
    scene=dict(
        xaxis_title=f'Principal Component 1 ({explained[0]:.1%})',
        yaxis_title=f'Principal Component 2 ({explained[1]:.1%})',
        zaxis_title=f'Principal Component 3 ({explained[2]:.1%})'
    ),
    margin=dict(r=20, b=10, l=10, t=40)
)

fig.show()

Generating 2000 samples...
Running PCA...
Generating plot...
