## Setup

In [3]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

import os
import copy

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
import umap

import einops
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd, itertools
from tqdm.auto import tqdm

from transformer_lens import HookedTransformer, HookedTransformerConfig, utils

# Configure plotly to use static rendering if widgets fail
import plotly.io as pio
pio.renderers.default = "notebook"


In [4]:
def float_formatter(value):
    """Render a float with exactly five digits after the decimal point."""
    return "{:.5f}".format(value)


# Print every float inside numpy arrays with the fixed five-decimal formatter above
np.set_printoptions(formatter={"float_kind": float_formatter})

## Model

In [5]:
# ---------- constants ----------
LIST_LEN = 2 # number of input digits per example: [d1, d2]
SEQ_LEN = LIST_LEN * 2 + 1 # full sequence is [d1, d2, SEP, o1, o2]

N_DIGITS = 100
DIGITS = list(range(N_DIGITS)) # token ids 0 .. N_DIGITS-1 (here 0 to 99)
SEP = DIGITS[-1] + 1 # special separator token placed after the inputs; its id is one past the last digit so it cannot collide with a digit
SPECIAL2 = SEP + 1 # another special token, e.g. for padding or end of sequence. NOTE(review): this id equals VOCAB, so it is outside the embedding table sized by VOCAB below -- confirm it is never fed to the model
VOCAB = len(DIGITS) + 1  # all digits plus the single SEP token (SPECIAL2 is NOT included)

D_MODEL = 8 # residual stream width
N_HEAD = 1 # attention heads per layer
N_LAYER = 3 # number of attention-only layers
USE_LN = False # use layer norm in model
USE_BIAS = False # use bias in model
FREEZE_WV = True # fix W_V to a (sliced) identity and freeze it, so values are copied unchanged
FREEZE_WO = True # fix W_O to a (sliced) identity and freeze it (i.e. attn head can only copy inputs to outputs)
WEIGHT_DECAY = 0.01 # AdamW weight decay used by train()

TRAIN_SPLIT = 0.8 # 80% train, 20% validation
MAX_TRAIN_STEPS = 300_000 # max training steps

# model name for saving and loading
# NOTE: the name is hard-coded, so keep it in sync with N_LAYER / N_DIGITS / D_MODEL above
# MODEL_NAME = f'{N_DIGITS}dig_{D_MODEL}d'
MODEL_NAME = '3layer_100dig_8d'
MODEL_PATH = "models/" + MODEL_NAME + ".pt"

USE_CHECKPOINTING = True # whether train() periodically saves the model to MODEL_PATH

# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU
DEV = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
device = DEV
torch.manual_seed(0) # fixed seed so the data shuffle and model init are reproducible

# ---------- mask ----------
# Additive attention bias for [d1, d2, SEP, o1, o2].
# Rows are queries, columns are keys; 0 = attention allowed, -inf = attention blocked.
#
#        d1    d2    SEP    o1    o2   (keys)
# d1      0   -inf  -inf  -inf  -inf    <- d1 may see itself, so the row is never all -inf
# d2      0   -inf  -inf  -inf  -inf
# SEP     0     0   -inf  -inf  -inf
# o1    -inf  -inf    0   -inf  -inf
# o2    -inf  -inf    0     0   -inf
# (queries)

# Start from -inf on and above the main diagonal, 0 strictly below it
mask_bias = torch.full((SEQ_LEN, SEQ_LEN), float("-inf")).triu()
# A fully masked row would make softmax produce NaNs and break training
mask_bias[0, 0] = 0.0
# Output positions must not look back at the input digits
mask_bias[LIST_LEN + 1:, :LIST_LEN] = float("-inf")
# Add leading batch and head axes -> (1, 1, T, T), broadcastable over both
mask_bias = mask_bias[None, None]

print(mask_bias[0, 0].cpu())


tensor([[0., -inf, -inf, -inf, -inf],
        [0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [-inf, -inf, 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf]])


In [6]:
# ---------- data ----------
# Enumerate every ordered LIST_LEN-tuple of digits (the complete task space)
all_data = list(itertools.product(DIGITS, repeat=LIST_LEN))
n_data = len(all_data)
all_data = torch.tensor(all_data, dtype=torch.int64)

# Targets: [d1, d2, SEP, d1, d2] -- the inputs, a separator, then the inputs repeated
sep_column = torch.full((n_data, 1), SEP)
all_targets = torch.cat([all_data, sep_column, all_data], dim=1)

# Inputs: [d1, d2, SEP, SEP, SEP] -- the output slots are blanked out with SEP
all_inputs = all_targets.clone()
all_inputs[:, LIST_LEN + 1:] = SEP

# Apply one shared permutation so inputs and targets stay aligned
perm = torch.randperm(n_data)
all_inputs, all_targets = all_inputs[perm], all_targets[perm]

# First TRAIN_SPLIT fraction is used for training, the remainder for validation
n_train = int(TRAIN_SPLIT * n_data)
train_ds = TensorDataset(all_inputs[:n_train], all_targets[:n_train])
val_ds = TensorDataset(all_inputs[n_train:], all_targets[n_train:])

# Cap the batch sizes at the dataset sizes so tiny datasets still yield a batch
train_batch_size = min(128, len(train_ds))
val_batch_size = min(256, len(val_ds))
train_dl = DataLoader(train_ds, train_batch_size, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, val_batch_size, drop_last=False)

example_input, example_target = train_ds[0]
print("Input:", example_input)  # Example input: [d1, d2, SEP, SEP, SEP]
print("Target:", example_target) # Example target: [d1, d2, SEP, d1, d2]
len(train_ds), len(val_ds)  # Should be 80% for train and 20% for validation

Input: tensor([ 60,  44, 100, 100, 100])
Target: tensor([ 60,  44, 100,  60,  44])


(8000, 2000)

In [7]:
# ---------- config helper ----------
def attach_custom_mask(model):
    """Permanently add the global additive `mask_bias` to every layer's attention scores.

    The same hook function is registered (as a forward perma-hook) on the
    `hook_attn_scores` hook point of each block in `model.blocks`.
    """
    def _add_mask_bias(scores, hook=None):
        # scores: (batch, heads, Q, K); mask_bias broadcasts over batch and heads
        bias = mask_bias.to(scores.device)
        return scores + bias

    for layer in model.blocks:
        layer.attn.hook_attn_scores.add_perma_hook(_add_mask_bias, dir="fwd")


def strip_bias(m):
    """Zero and freeze every bias parameter this function knows how to find on `m`.

    Three places are covered:
      * any sub-module exposing a tensor attribute named `bias`,
      * the attention biases `b_Q`, `b_K`, `b_V`, `b_O` of each block in `m.blocks`,
      * the unembedding bias `m.unembed.b_U`.

    Each bias found is excluded from gradient updates and set to zero in place.
    Nothing is printed and nothing is returned.

    NOTE(review): biases stored under any other attribute name (for example a
    normalisation layer that calls its bias `b`) are not touched -- confirm
    whether that matters when layer norm is enabled together with use_bias=False.
    """
    def _zero_and_freeze(param):
        # Freeze first, then zero in place so the value can never drift again
        param.requires_grad_(False)
        torch.nn.init.zeros_(param)

    # Generic torch-style modules keep their bias under `.bias`
    for mod in m.modules():
        bias = getattr(mod, "bias", None)
        if isinstance(bias, torch.Tensor):
            _zero_and_freeze(bias)

    # remove biases from attention layers
    attn_biases = ['b_Q', 'b_K', 'b_V', 'b_O']
    for block in m.blocks:
        for bias_name in attn_biases:
            bias = getattr(block.attn, bias_name, None)
            if bias is not None:
                _zero_and_freeze(bias)

    # remove unembed bias (read from the same object that is modified)
    unembed_bias = getattr(getattr(m, "unembed", None), "b_U", None)
    if unembed_bias is not None:
        _zero_and_freeze(unembed_bias)

def set_WV_identity_and_freeze(model, d_model):
    """Overwrite every block's W_V with a sliced identity and stop training it.

    Each head receives the same (d_model, d_head) matrix: the first d_head
    columns of the d_model x d_model identity. The parameter is then marked
    as not requiring gradients.
    """
    n_heads, d_head = model.cfg.n_heads, model.cfg.d_head
    # Rectangular identity: ones on the main diagonal, zeros elsewhere
    per_head = torch.eye(d_model, d_head)
    # Same matrix for every head -> (n_heads, d_model, d_head)
    stacked = per_head.unsqueeze(0).expand(n_heads, -1, -1)

    with torch.no_grad():
        for layer in model.blocks:
            value_weights = layer.attn.W_V
            value_weights.copy_(stacked)
            value_weights.requires_grad_(False)

def set_WO_identity_and_freeze(model, d_model):
    """Overwrite every block's W_O with a sliced identity and stop training it.

    Each head receives the same (d_head, d_model) matrix: the first d_head
    rows of the d_model x d_model identity. The parameter is then marked as
    not requiring gradients.
    """
    n_heads, d_head = model.cfg.n_heads, model.cfg.d_head
    # Rectangular identity: ones on the main diagonal, zeros elsewhere
    per_head = torch.eye(d_head, d_model)
    # Same matrix for every head -> (n_heads, d_head, d_model)
    stacked = per_head.unsqueeze(0).expand(n_heads, -1, -1)

    with torch.no_grad():
        for layer in model.blocks:
            output_weights = layer.attn.W_O
            output_weights.copy_(stacked)
            output_weights.requires_grad_(False)


def make_model(n_layers=N_LAYER, n_heads=N_HEAD, d_model=D_MODEL, ln=USE_LN, use_bias=USE_BIAS, freeze_wv=FREEZE_WV, freeze_wo=FREEZE_WO):
    """Build an attention-only HookedTransformer on DEV with the custom mask attached.

    Args:
        n_layers: number of attention layers.
        n_heads: heads per layer; the head width is d_model // n_heads.
        d_model: residual stream width.
        ln: use "LN" normalization when True, no normalization otherwise.
        use_bias: when False, all biases found by strip_bias are zeroed and frozen.
        freeze_wv: when True, W_V is set to a sliced identity and frozen.
        freeze_wo: when True, W_O is set to a sliced identity and frozen.

    Returns:
        The configured model, with the attention-mask perma-hook registered on every block.
    """
    norm_type = "LN" if ln else None
    head_width = d_model // n_heads
    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        n_heads=n_heads,
        d_model=d_model,
        d_head=head_width,
        n_ctx=SEQ_LEN,
        d_vocab=VOCAB,
        attn_only=True,  # no MLP!
        normalization_type=norm_type,
    )
    model = HookedTransformer(cfg).to(DEV)

    # Optional structural constraints, applied in a fixed order
    if freeze_wv:
        set_WV_identity_and_freeze(model, d_model)
    if freeze_wo:
        set_WO_identity_and_freeze(model, d_model)
    if not use_bias:
        strip_bias(model)

    # The task-specific attention mask is always installed
    attach_custom_mask(model)
    return model

In [8]:
# ----- Model saving / loading helpers ------
def save_model(model, path = MODEL_PATH):
    """Write `model.state_dict()` to `path`, creating the parent directory if needed.

    Args:
        model: any object exposing `state_dict()`.
        path: destination file. A bare filename (no directory part) is saved
            into the current working directory.
    """
    parent_dir = os.path.dirname(path)
    # dirname is "" for a bare filename, and os.makedirs("") raises FileNotFoundError
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(path = MODEL_PATH, device = DEV):
    """Rebuild the default architecture, load weights from `path`, and return it in eval mode.

    Args:
        path: checkpoint file previously written by save_model.
        device: passed to torch.load as `map_location` for the stored tensors.
            NOTE(review): the model itself is created by make_model() with its
            defaults, so it lives on DEV regardless of this argument, and the
            checkpoint must match the default hyper-parameters.

    Returns:
        The model with the loaded weights, switched to eval mode.
    """
    print("Loading model from", path)
    model = make_model()
    # map weights to target device while deserialising
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [9]:
# ---------- utilities ----------
def accuracy(m):
    """Return the fraction of output tokens `m` predicts correctly on `val_dl`.

    Only the positions after the separator (index LIST_LEN + 1 onwards) are
    scored. The model is switched to eval mode and left there.
    """
    m.eval()
    n_correct, n_scored = 0, 0
    output_positions = slice(LIST_LEN + 1, None)
    with torch.no_grad():
        for inputs, targets in val_dl:
            # logits: (batch, n_outputs, vocab) -> predicted ids: (batch, n_outputs)
            predicted = m(inputs.to(DEV))[:, output_positions].argmax(-1)
            expected = targets[:, output_positions].to(DEV)
            n_correct += (predicted == expected).sum().item()
            n_scored += predicted.numel()
    return n_correct / n_scored


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass each time it runs out.

    Unlike itertools.cycle, nothing is cached: every pass re-iterates the
    loader, so a shuffling loader produces a new order each epoch.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard an empty loader would spin forever
            raise ValueError("dataloader produced no batches")


def train(m, max_steps=10_000, early_stop_acc=0.999, checkpoints=False, weight_decay=WEIGHT_DECAY):
    """Train `m` on `train_dl` with AdamW and cross-entropy on the output tokens only.

    Args:
        m: model mapping token ids to logits.
        max_steps: maximum number of optimiser steps.
        early_stop_acc: stop as soon as validation accuracy reaches this value
            (checked every 100 steps).
        checkpoints: when True, save the model to MODEL_PATH every 50_000 steps.
        weight_decay: AdamW weight decay.

    Raises:
        ValueError: if `train_dl` yields no batches.

    The model is left in eval mode (the final accuracy evaluation switches it).
    """
    opt = torch.optim.AdamW(m.parameters(), 1e-3, weight_decay=weight_decay)
    ce = torch.nn.CrossEntropyLoss()
    batches = _endless_batches(train_dl)  # infinite iterator, reshuffled every epoch

    eval_every = 100
    # Progress is logged every 5% of the run, clamped to [1000, 10_000] steps and
    # snapped down to a multiple of eval_every so it always coincides with an evaluation
    update_every = int(max(min(10_000, 0.05 * max_steps), 1000)) // eval_every * eval_every

    m.train()
    for step in tqdm(range(max_steps), desc="Training"):
        inputs, targets = next(batches)
        # get logits/loss for output tokens only
        logits = m(inputs.to(DEV))[:, LIST_LEN+1:].reshape(-1, VOCAB)
        loss = ce(logits, targets[:, LIST_LEN+1:].reshape(-1).to(DEV))
        loss.backward()
        opt.step()
        opt.zero_grad()
        if (step + 1) % eval_every == 0:
            acc = accuracy(m)  # switches the model to eval mode
            m.train()          # so switch back before the next optimisation step
            if acc >= early_stop_acc:
                print(f"Early stopping at step {step + 1} with accuracy {acc:.2%} >= {early_stop_acc:.2%}")
                break
            if (step + 1) % update_every == 0:
                print(f"Step {step + 1}, Loss: {loss.item():.4f}, Accuracy: {acc:.2%}")
            if checkpoints and (step + 1) % 50_000 == 0:
                save_model(m, MODEL_PATH)

    print(f"Final accuracy: {accuracy(m):.2%}")


In [10]:
# Check train set: display the first five (inputs, targets) pairs as a sanity check
train_ds[:5]

(tensor([[ 60,  44, 100, 100, 100],
         [ 28,  90, 100, 100, 100],
         [ 93,  99, 100, 100, 100],
         [ 19,  17, 100, 100, 100],
         [ 49,  19, 100, 100, 100]]),
 tensor([[ 60,  44, 100,  60,  44],
         [ 28,  90, 100,  28,  90],
         [ 93,  99, 100,  93,  99],
         [ 19,  17, 100,  19,  17],
         [ 49,  19, 100,  49,  19]]))

In [None]:
# ---------- experiment grid ----------

# Each spec overrides a subset of the default hyper-parameters; 'name' labels the
# row in the results table and must be unique so rows can be told apart.
specs = [
    {'name': 'd128', 'd_model': 128},
    {'name': 'd128_2layer', 'd_model': 128, 'n_layers': 2},
    {'name': 'd128_2layer_ln_bias', 'd_model': 128, 'n_layers': 2, 'ln': True, 'use_bias': True, 'freeze_wo': True, 'freeze_wv': True},
    {'name': 'd128_3layer', 'd_model': 128, 'n_layers': 3},
    

    # {'name': 'd64', 'd_model': 64},
    
    # {'name': 'd32', 'd_model': 32},
    # {'name': 'd32_ln_bias', 'd_model': 32, 'ln': True, 'use_bias': True},
    # {'name': 'd32_noLN', 'd_model': 32, 'ln': False, 'use_bias': True},
    # {'name': 'd32_noBias', 'd_model': 32, 'ln': True, 'use_bias': False},
    # {'name': 'd32_noLNnoBias', 'd_model': 32, 'ln': False, 'use_bias': False},
    # {'name': 'd32_fwo', 'd_model': 32, 'freeze_wo': True},
    # {'name': 'd32_unfwo', 'd_model': 32, 'freeze_wo': False},

    # {'name': 'd16', 'd_model': 16},
    # {'name': 'd16_ln_bias', 'd_model': 16, 'ln': True, 'use_bias': True},
    # {'name': 'd16_noLN', 'd_model': 16, 'ln': False, 'use_bias': True},
    # {'name': 'd16_noBias', 'd_model': 16, 'ln': True, 'use_bias': False},
    # {'name': 'd16_noLNnoBias', 'd_model': 16, 'ln': False, 'use_bias': False},
    # {'name': 'd16_fwo', 'd_model': 16, 'freeze_wo': True},
    # {'name': 'd16_unfwo', 'd_model': 16, 'freeze_wo': False},

    # {'name': 'd8', 'd_model': 8},
    # {'name': 'd8_ln_bias', 'd_model': 8, 'ln': True, 'use_bias': True},
    # {'name': 'd8_noLN', 'd_model': 8, 'ln': False, 'use_bias': True},
    # {'name': 'd8_noBias', 'd_model': 8, 'ln': True, 'use_bias': False},
    # {'name': 'd8_noLNnoBias', 'd_model': 8, 'ln': False, 'use_bias': False},
    # {'name': 'd8_fwo', 'd_model': 8, 'freeze_wo': True},
    # {'name': 'd8_unfwo', 'd_model': 8, 'freeze_wo': False},

    # {'name': 'd4_ln_bias', 'd_model': 4, 'ln': True, 'use_bias': True},
]

# Defaults shared by every experiment; each spec overrides a subset of them
spec_defaults = {
    'n_layers': N_LAYER,
    'n_heads': N_HEAD,
    'd_model': D_MODEL,
    'ln': USE_LN,
    'use_bias': USE_BIAS,
    'freeze_wv': FREEZE_WV,
    'freeze_wo': FREEZE_WO,
    'weight_decay': WEIGHT_DECAY,
}
# The subset of spec keys that make_model accepts (weight_decay goes to train instead)
model_arg_names = ('n_layers', 'n_heads', 'd_model', 'ln', 'use_bias', 'freeze_wv', 'freeze_wo')

rows = []
for spec in specs:
    # Spec values win over the defaults
    full_spec = {**spec_defaults, **spec}

    print(f"--- Training model: {full_spec['name']} ---")
    model = make_model(**{key: full_spec[key] for key in model_arg_names})

    train(model, max_steps=20_000, weight_decay=full_spec['weight_decay'])

    # One result row = the full configuration plus the measured validation accuracy
    rows.append({**full_spec, 'val_acc': round(accuracy(model), 4)})

df = pd.DataFrame(rows)

# Move 'name' column to the front for better readability
if 'name' in df.columns:
    df = df[['name'] + [col for col in df.columns if col != 'name']]

print(df.to_markdown(index=False))






Bias not needed

In [12]:
# LOAD existing or train and SAVE new model
# NOTE(review): with checkpointing enabled, train() also writes to MODEL_PATH while
# training, so an interrupted run may leave a partial checkpoint that is loaded here.

if not os.path.exists(MODEL_PATH):
    # No saved weights yet: train from scratch, then persist the result
    print("Training model")
    model = make_model()
    train(model, max_steps=MAX_TRAIN_STEPS, early_stop_acc=0.999, checkpoints=USE_CHECKPOINTING)
    save_model(model, MODEL_PATH)
else:
    # Reuse the saved weights
    model = load_model(MODEL_PATH, device=DEV)

# from torchinfo import summary
# summary(model) 

Loading model from models/3layer_100dig_8d.pt
Moving model to device:  cuda


In [13]:
# --- Model Parameters Overview ---
# Lists the trainable parameters only; frozen ones still count towards the total.

RULE = "-" * 80

print("--- Overview of Model Parameters ---")
# Use a formatted string for better alignment
print(f"{'Parameter Name':<40} | {'Shape':<20} | {'Trainable':<10}")
print(RULE)

total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements
        shape_str = str(tuple(param.shape))
        print(f"{name:<40} | {shape_str:<20} | {'Yes':<10}")

print(RULE)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
print(RULE)

--- Overview of Model Parameters ---
Parameter Name                           | Shape                | Trainable 
--------------------------------------------------------------------------------
embed.W_E                                | (101, 8)             | Yes       
pos_embed.W_pos                          | (5, 8)               | Yes       
blocks.0.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.0.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_K                        | (1, 8, 8)            | Yes       
unembed.W_U                              | (8, 101)             | Yes       
--------------------------------------------------------------------------------
Total parameters: 2621
Trainabl

### Model attention

We confirm below that the output positions do not leak attention onto the first two tokens, which are the inputs to the task. By construction of the mask, only the `SEP` position (third token) may attend to both inputs (and `d2` may look back at `d1`); the output positions `o1` and `o2` (fourth and fifth tokens) must place no attention on `d1`/`d2` at all.

In [None]:
# --- Using Plotly for visualization ---

def check_attention(m, dataloader, eps=1e-3):
    """Verify that no output position puts attention on the input tokens.

    For every batch, layer and attention head, the attention mass that the
    output positions (index LIST_LEN + 1 onwards) place on the first LIST_LEN
    key positions must not exceed `eps`.

    Args:
        m: model exposing `run_with_cache` and `cfg.n_layers`.
        dataloader: iterable of (inputs, targets) batches; targets are ignored.
        eps: largest tolerated leaked attention mass per query position.

    Raises:
        ValueError: on the first layer where any head leaks more than `eps`.
    """
    for inputs, _ in dataloader:
        with torch.no_grad():
            _, cache = m.run_with_cache(inputs.to(DEV))
        for l in range(m.cfg.n_layers):
            pat = cache["pattern", l]  # (batch, heads, Q, K) -- all heads are checked
            leak = pat[:, :, LIST_LEN+1:, :LIST_LEN].sum(dim=-1)  # mass on forbidden keys
            if (leak > eps).any():
                raise ValueError(f"❌ Layer {l}: output tokens attend to x₁/x₂ by >{eps:.0e}")
    print("✅ no attention leakage onto x₁/x₂")


# Take the first validation example and record its activations
sample = val_ds[0][0] # Example input sequence
print(f"Sample sequence: {sample.cpu().numpy()}")  # Print the sample sequence for reference
_, cache = model.run_with_cache(sample[None].to(DEV))

# --- Create Plotly visualization ---
# Axis labels: d1..dN, SEP, o1..oN
input_labels = [f'd{i}' for i in range(1, LIST_LEN + 1)]
output_labels = [f'o{i}' for i in range(1, LIST_LEN + 1)]
token_labels = input_labels + ['SEP'] + output_labels
subplot_titles = [f"Layer {l} Attention Pattern" for l in range(model.cfg.n_layers)]

# One heatmap per layer, side by side
fig = make_subplots(
    rows=1,
    cols=model.cfg.n_layers,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08, # Add spacing between plots
)

last_layer = model.cfg.n_layers - 1
for l in range(model.cfg.n_layers):
    # Batch element 0, head 0 -> (Q, K)
    pat = cache["pattern", l][0, 0].cpu().numpy()
    heatmap = go.Heatmap(
        z=pat,
        x=token_labels,
        y=token_labels,
        colorscale="Viridis",
        zmin=0,
        zmax=1,
        showscale=(l == last_layer), # Show colorbar only for the last plot
    )
    fig.add_trace(heatmap, row=1, col=l + 1)

fig.update_layout(
    title_text="Attention Patterns for a Sample Sequence",
    width=1200,
    height=450,
    template="plotly_white",
)

# Apply settings to all axes; the y axis is flipped so the first query is on top
fig.update_xaxes(title_text="Key Position")
fig.update_yaxes(title_text="Query Position", autorange='reversed')

fig.show()

check_attention(model, val_dl)

Sample sequence: [ 80  52 100 100 100]


✅ no attention leakage onto x₁/x₂


In [None]:
# --- Mean Attention Patterns ---

# Single pass over the validation set: collect head 0's attention pattern of every
# example, per layer. Both the "identical?" check and the averages are derived from
# these tensors, so the model is not run a second time.
pats_by_layer = [[] for _ in range(model.cfg.n_layers)]
with torch.no_grad():
    for inputs, _ in val_dl:
        _, cache = model.run_with_cache(inputs.to(DEV))
        for l in range(model.cfg.n_layers):
            pat = cache["pattern", l][:, 0]  # (batch, Q, K)
            pats_by_layer[l].append(pat)
# One (n_examples, Q, K) tensor per layer
all_pats = [torch.cat(pats, dim=0) for pats in pats_by_layer]

# Does the attention pattern depend on the input at all?
for l, pats in enumerate(all_pats):
    identical = torch.allclose(pats, pats[0].expand_as(pats))
    print(f"Layer {l}: all attention patterns identical? {'✅' if identical else '❌'}")

# Number of examples averaged over (name kept from the earlier version of this cell)
n = all_pats[0].shape[0] if all_pats else 0
# Per-layer mean pattern over all validation examples -> (Q, K) each
avg_pats = [pats.mean(dim=0) for pats in all_pats]

# --- Visualize Average Attention Patterns ---
# Axis labels: d1..dN, SEP, o1..oN
input_labels = [f'd{i}' for i in range(1, LIST_LEN + 1)]
output_labels = [f'o{i}' for i in range(1, LIST_LEN + 1)]
token_labels = input_labels + ['SEP'] + output_labels
subplot_titles = [f"Layer {l} Average Attention" for l in range(model.cfg.n_layers)]

# One heatmap per layer, side by side
fig = make_subplots(
    rows=1,
    cols=model.cfg.n_layers,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08,
)

last_layer = model.cfg.n_layers - 1
for l in range(model.cfg.n_layers):
    avg_pat_np = avg_pats[l].cpu().numpy()
    heatmap = go.Heatmap(
        z=avg_pat_np,
        x=token_labels,
        y=token_labels,
        colorscale="Viridis",
        zmin=0,
        zmax=1,
        showscale=(l == last_layer), # Show colorbar only for the last plot
    )
    fig.add_trace(heatmap, row=1, col=l + 1)

fig.update_layout(
    title_text="Average Attention Patterns Across Validation Set",
    width=1200,
    height=450,
    template="plotly_white",
)
# The y axis is flipped so the first query position is on top
fig.update_xaxes(title_text="Key Position")
fig.update_yaxes(title_text="Query Position", autorange='reversed')
fig.show()


# Create a deep copy of the model so the hooks added for the averaged-attention
# experiment do not modify the original `model`
model_with_avg_attn = copy.deepcopy(model)

def mk_hook(avg):
    """Build a hook that replaces attention scores with the log of a fixed pattern.

    `avg` is a (Q, K) attention pattern. The returned hook ignores the incoming
    scores' values and returns log(avg + 1e-12) broadcast to their shape, so the
    following softmax approximately reproduces `avg`; the epsilon keeps zero
    entries finite instead of -inf.
    """
    # Computed once here, then reused on every forward pass
    fixed_scores = torch.log(avg + 1e-12)[None, None]

    def f(scores, hook):
        return fixed_scores.expand_as(scores)

    return f

# Force each layer of the copied model to use its validation-average attention pattern
for block, layer_avg in zip(model_with_avg_attn.blocks, avg_pats):
    block.attn.hook_attn_scores.add_hook(mk_hook(layer_avg), dir="fwd")

# How much accuracy survives once attention no longer depends on the input?
avg_attn_accuracy = accuracy(model_with_avg_attn)
print("Accuracy with avg-attn:", avg_attn_accuracy)

Layer 0: all attention patterns identical? ❌
Layer 1: all attention patterns identical? ❌
Layer 2: all attention patterns identical? ❌


Accuracy with avg-attn: 0.9915
