## Setup

In [None]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

import os

import pandas as pd, itertools
from tqdm.auto import tqdm

from transformer_lens import HookedTransformer, HookedTransformerConfig, utils

# Make "notebook" the default plotly renderer for every figure shown in this notebook.
# NOTE(review): presumably chosen so figures still render when widget-based output fails — confirm.
import plotly.io as pio
pio.renderers.default = "notebook"

# Print numpy floats in fixed-point notation with five digits after the decimal point.
float_formatter = "{:.5f}".format
np.set_printoptions(formatter={'float_kind':float_formatter})


## Model

In [None]:
# ---------- constants ----------
# --- task layout ---
LIST_LEN = 2                  # length of the list to copy: [d1, d2]
SEQ_LEN = 2 * LIST_LEN + 1    # whole sequence: [d1, d2, SEP, o1, o2]

# --- vocabulary ---
N_DIGITS = 100
DIGITS = [*range(N_DIGITS)]   # digit tokens 0 .. N_DIGITS-1
SEP = DIGITS[-1] + 1          # separator token; takes the first id after the last digit so it cannot collide with one
VOCAB = SEP + 1               # all digit tokens plus the separator

# --- architecture ---
D_MODEL = 8
N_HEAD = 1 # 1
N_LAYER = 3 # 2
USE_LN = False                # whether the model uses layer norm
USE_BIAS = False              # whether the model keeps trainable biases
FREEZE_WV = True              # pin W_V to an identity slice and freeze it
FREEZE_WO = True              # pin W_O to an identity slice and freeze it (the head can then only copy inputs to outputs)
WEIGHT_DECAY = 0.01           # default 0.01

# --- training ---
TRAIN_SPLIT = 0.8             # fraction of the data used for training; the rest is validation
MAX_TRAIN_STEPS = 300_000     # upper bound on training steps

# --- persistence ---
# MODEL_NAME = f'{N_DIGITS}dig_{D_MODEL}d'
MODEL_NAME = '3layer_100dig_8d'
MODEL_PATH = f"models/{MODEL_NAME}.pt"

USE_CHECKPOINTING = True      # whether training should write checkpoints

# --- device / seed ---
# Prefer CUDA, then Apple MPS, then fall back to the CPU.
if torch.cuda.is_available():
    DEV = "cuda"
elif torch.backends.mps.is_available():
    DEV = "mps"
else:
    DEV = "cpu"
device = DEV
torch.manual_seed(0)

# ---------- mask ----------
# Additive attention bias for [d1, d2, SEP, o1, o2]; rows are queries, columns are keys.
# 0 means "may attend", -inf means "blocked":
#
#         d1    d2    SEP   o1    o2    (keys)
# d1      0    -inf  -inf  -inf  -inf
# d2      0    -inf  -inf  -inf  -inf
# SEP     0     0    -inf  -inf  -inf
# o1    -inf  -inf    0    -inf  -inf
# o2    -inf  -inf    0     0    -inf
# (queries)

# Start fully blocked, then open every position strictly below the diagonal
# (each query may look at strictly earlier keys only).
mask_bias = torch.full((SEQ_LEN, SEQ_LEN), float("-inf")).masked_fill(
    torch.ones(SEQ_LEN, SEQ_LEN).tril(-1).bool(), 0.
)
# The first query has no earlier key: let it see itself, otherwise its row is all -inf
# and the softmax produces NaNs that break training.
mask_bias[0, 0] = 0.
# Output positions must not read the input digits directly.
mask_bias[LIST_LEN+1:, :LIST_LEN] = float("-inf")
# Shape (1, 1, T, T) so the bias broadcasts over batch and heads.
mask_bias = mask_bias[None, None]

print(mask_bias[0, 0].cpu())


tensor([[0., -inf, -inf, -inf, -inf],
        [0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [-inf, -inf, 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf]])


In [32]:
# ---------- data ----------
# Enumerate every ordered tuple of LIST_LEN digits: one row per tuple.
all_data = list(itertools.product(DIGITS, repeat=LIST_LEN))
n_data = len(all_data)
all_data = torch.tensor(all_data, dtype=torch.int64)

# Inputs look like [d1, d2, SEP, SEP, SEP]: digits up front, separator everywhere else.
all_inputs = torch.full((n_data, SEQ_LEN), SEP)
all_inputs[:, :LIST_LEN] = all_data

# Targets look like [d1, d2, SEP, d1, d2]: same prefix, with the digits repeated after SEP.
all_targets = all_inputs.clone()
all_targets[:, LIST_LEN+1:] = all_data

# Apply one shared permutation so inputs and targets stay aligned.
perm = torch.randperm(n_data)
all_inputs, all_targets = all_inputs[perm], all_targets[perm]

# Split point between the training rows (first part) and the validation rows (remainder).
n_train = int(TRAIN_SPLIT*n_data)
train_ds = TensorDataset(all_inputs[:n_train], all_targets[:n_train])
val_ds = TensorDataset(all_inputs[n_train:], all_targets[n_train:])

# Cap the batch sizes at the dataset sizes so tiny datasets still yield a batch.
train_batch_size = min(128, len(train_ds))
val_batch_size = min(256, len(val_ds))
train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, batch_size=val_batch_size, drop_last=False)

# Show one example pair and the resulting split sizes.
example_input, example_target = train_ds[0]
print("Input:", example_input)    # [d1, d2, SEP, SEP, SEP]
print("Target:", example_target)  # [d1, d2, SEP, d1, d2]
len(train_ds), len(val_ds)

Input: tensor([ 60,  44, 100, 100, 100])
Target: tensor([ 60,  44, 100,  60,  44])


(8000, 2000)

In [33]:
# ---------- config helper ----------
def attach_custom_mask(model):
    """Install the module-level ``mask_bias`` as a permanent attention-score hook.

    One and the same hook function is registered (forward direction) on the
    ``hook_attn_scores`` hook point of every block in ``model.blocks``.
    Returns nothing; the model is modified in place.
    """
    def _mask(scores, hook=None):
        # scores: (batch, heads, Q, K); move the bias to wherever the scores live before adding.
        bias = mask_bias.to(scores.device)
        return scores + bias

    for layer in model.blocks:
        layer.attn.hook_attn_scores.add_perma_hook(_mask, dir="fwd")


def strip_bias(m):
    """Zero out and freeze every bias parameter of ``m`` in place.

    Three groups of biases are handled:

    * any sub-module exposing a tensor attribute called ``bias``;
    * the attention biases ``b_Q``, ``b_K``, ``b_V`` and ``b_O`` of each
      ``block.attn`` in ``m.blocks``;
    * the unembedding bias ``m.unembed.b_U``.

    Groups that do not exist on ``m`` are skipped silently, so the function also
    works on models that lack ``blocks`` or ``unembed``.

    Args:
        m: a ``torch.nn.Module`` (the notebook passes the transformer built in
            ``make_model``).

    Returns:
        None. Each affected parameter has ``requires_grad`` set to False and is
        filled with zeros.
    """
    def _zero_and_freeze(param):
        # Freeze first, then overwrite with zeros.
        param.requires_grad_(False)
        torch.nn.init.zeros_(param)

    # Generic modules: only touch `bias` attributes that are actual tensors
    # (some modules keep a plain flag or None under that name).
    for mod in m.modules():
        bias = getattr(mod, "bias", None)
        if isinstance(bias, torch.Tensor):
            _zero_and_freeze(bias)

    # Attention biases are stored directly on the attention module under these names.
    attn_biases = ('b_Q', 'b_K', 'b_V', 'b_O')
    for block in getattr(m, "blocks", ()):
        attn = getattr(block, "attn", None)
        for name in attn_biases:
            bias = getattr(attn, name, None)
            if bias is not None:
                _zero_and_freeze(bias)

    # Unembedding bias: read and modify it through the same path (m.unembed.b_U)
    # instead of checking a top-level alias that not every model provides.
    unembed_bias = getattr(getattr(m, "unembed", None), "b_U", None)
    if unembed_bias is not None:
        _zero_and_freeze(unembed_bias)

def set_WV_identity_and_freeze(model, d_model):
    """Overwrite every block's ``W_V`` with an identity slice and stop training it.

    Each head receives the same (d_model, d_head) matrix: the first ``d_head``
    columns of the d_model x d_model identity. Sizes come from
    ``model.cfg.d_head`` and ``model.cfg.n_heads``. The model is modified in place.
    """
    head_dim, head_count = model.cfg.d_head, model.cfg.n_heads
    # One identity slice, viewed once per head -> (n_heads, d_model, d_head).
    stacked_eye = torch.eye(d_model, head_dim).expand(head_count, d_model, head_dim)

    with torch.no_grad():
        for layer in model.blocks:
            value_weights = layer.attn.W_V
            value_weights.copy_(stacked_eye)
            value_weights.requires_grad_(False)

def set_WO_identity_and_freeze(model, d_model):
    """Overwrite every block's ``W_O`` with an identity slice and stop training it.

    Each head receives the same (d_head, d_model) matrix: the first ``d_head``
    rows of the d_model x d_model identity. Sizes come from
    ``model.cfg.d_head`` and ``model.cfg.n_heads``. The model is modified in place.
    """
    head_dim, head_count = model.cfg.d_head, model.cfg.n_heads
    # One identity slice, viewed once per head -> (n_heads, d_head, d_model).
    stacked_eye = torch.eye(head_dim, d_model).expand(head_count, head_dim, d_model)

    with torch.no_grad():
        for layer in model.blocks:
            output_weights = layer.attn.W_O
            output_weights.copy_(stacked_eye)
            output_weights.requires_grad_(False)


def make_model(n_layers=N_LAYER, n_heads=N_HEAD, d_model=D_MODEL, ln=USE_LN, use_bias=USE_BIAS, freeze_wv=FREEZE_WV, freeze_wo=FREEZE_WO):
    """Build the attention-only transformer used in this notebook and move it to ``DEV``.

    Args:
        n_layers: number of transformer blocks.
        n_heads: attention heads per block; each head gets ``d_model // n_heads`` dims.
        d_model: residual stream width.
        ln: use "LN" normalization when truthy, no normalization otherwise.
        use_bias: when falsy, all biases are zeroed and frozen via ``strip_bias``.
        freeze_wv: when truthy, ``W_V`` is pinned to an identity slice and frozen.
        freeze_wo: when truthy, ``W_O`` is pinned to an identity slice and frozen.

    Returns:
        The configured ``HookedTransformer`` with the custom attention mask attached.
    """
    norm_type = "LN" if ln else None
    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        n_heads=n_heads,
        d_model=d_model,
        d_head=d_model // n_heads,
        n_ctx=SEQ_LEN,
        d_vocab=VOCAB,
        attn_only=True,  # no MLP!
        normalization_type=norm_type,
    )
    model = HookedTransformer(cfg).to(DEV)

    # Optional weight pinning, applied in the order W_V then W_O.
    pinning_steps = (
        (freeze_wv, set_WV_identity_and_freeze),
        (freeze_wo, set_WO_identity_and_freeze),
    )
    for requested, pin_weights in pinning_steps:
        if requested:
            pin_weights(model, d_model)

    if not use_bias:
        strip_bias(model)

    attach_custom_mask(model)
    return model

In [34]:
# ----- Model saving / loading helpers ------
def save_model(model, path = MODEL_PATH):
    """Write ``model.state_dict()`` to ``path``, creating the parent directory if needed.

    Args:
        model: object exposing ``state_dict()``.
        path: destination file; may be a bare filename (saved in the current
            working directory) or include directories.
    """
    parent_dir = os.path.dirname(path)
    # os.path.dirname returns "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(path = MODEL_PATH, device = DEV):
    """Rebuild the default model, load saved weights from ``path`` and return it in eval mode.

    Args:
        path: file previously written by ``save_model``.
        device: passed to ``torch.load`` as ``map_location``. The model itself is
            created by ``make_model()`` with its default arguments, so the saved
            weights must match that default architecture.

    Returns:
        The model with the loaded weights, switched to evaluation mode.
    """
    print("Loading model from", path)
    model = make_model()
    # Map the stored tensors onto the requested device while deserialising.
    saved_state = torch.load(path, map_location=device)
    model.load_state_dict(saved_state)
    model.eval()
    return model

In [35]:
# ---------- utilities ----------
def accuracy(m):
    """Return the fraction of output tokens that ``m`` predicts correctly on ``val_dl``.

    Only the positions after the separator (index ``LIST_LEN + 1`` onwards) are
    scored. The model is switched to eval mode and is left in that mode.
    """
    m.eval()
    n_correct = 0
    n_scored = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in val_dl:
            # Keep the output positions only: (batch, LIST_LEN, vocab).
            output_logits = m(batch_inputs.to(DEV))[:, LIST_LEN+1:]
            predicted = output_logits.argmax(-1)
            expected = batch_targets[:, LIST_LEN+1:].to(DEV)
            n_correct += (predicted == expected).sum().item()
            n_scored += predicted.numel()
    return n_correct / n_scored


def train(m, max_steps=10_000, early_stop_acc=0.999, checkpoints=False, weight_decay=WEIGHT_DECAY, verbose=True):
    """Train ``m`` on ``train_dl`` with AdamW (lr 1e-3) and cross-entropy on the output tokens.

    Every 100 steps the validation accuracy is measured with ``accuracy``;
    training stops early once it reaches ``early_stop_acc``. The final
    validation accuracy is printed when the loop ends.

    Args:
        m: model to train (modified in place).
        max_steps: maximum number of optimisation steps.
        early_stop_acc: validation accuracy at which training stops.
        checkpoints: when truthy, save the model to ``MODEL_PATH`` every 50,000 steps.
        weight_decay: AdamW weight decay.
        verbose: when truthy, print loss and accuracy every ``update_every`` steps.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    def _endless_batches():
        # Re-iterate the DataLoader for every epoch. itertools.cycle would cache the
        # first epoch and replay it forever, which disables reshuffling and keeps
        # every batch in memory.
        while True:
            produced_any = False
            for batch in train_dl:
                produced_any = True
                yield batch
            if not produced_any:
                raise ValueError("train_dl yielded no batches")

    opt = torch.optim.AdamW(m.parameters(), 1e-3, weight_decay=weight_decay)
    ce = torch.nn.CrossEntropyLoss()
    dl = _endless_batches()

    # Logging interval: 5% of the run, clamped to [1000, 10000]. It is computed once
    # with integer arithmetic so the modulo test below is exact.
    update_every = max(min(10_000, max_steps // 20), 1000)

    m.train()
    for step in tqdm(range(max_steps), desc="Training"):
        inputs, targets = next(dl)
        # get logits/loss for output tokens only
        logits = m(inputs.to(DEV))[:, LIST_LEN+1:].reshape(-1, VOCAB)
        loss = ce(logits, targets[:, LIST_LEN+1:].reshape(-1).to(DEV))
        loss.backward()
        opt.step()
        opt.zero_grad()
        if (step + 1) % 100 == 0:
            acc = accuracy(m)
            if acc >= early_stop_acc:
                print(f"Early stopping at step {step + 1} with accuracy {acc:.2%} >= {early_stop_acc:.2%}")
                break
            # accuracy() leaves the model in eval mode; switch back before the next step.
            m.train()
            if verbose and (step+1) % update_every == 0:
                print(f"Step {step + 1}, Loss: {loss.item():.4f}, Accuracy: {acc:.2%}")
            if checkpoints and (step+1) % 50_000 == 0:
                save_model(m, MODEL_PATH)

    print(f"Final accuracy: {accuracy(m):.2%}")


In [36]:
# Check train set: show the first five examples. Slicing the dataset gives a tuple of
# (inputs, targets) tensors whose rows should line up pairwise.
train_ds[:5]

(tensor([[ 60,  44, 100, 100, 100],
         [ 28,  90, 100, 100, 100],
         [ 93,  99, 100, 100, 100],
         [ 19,  17, 100, 100, 100],
         [ 49,  19, 100, 100, 100]]),
 tensor([[ 60,  44, 100,  60,  44],
         [ 28,  90, 100,  28,  90],
         [ 93,  99, 100,  93,  99],
         [ 19,  17, 100,  19,  17],
         [ 49,  19, 100,  49,  19]]))

In [None]:
# ---------- experiment grid ----------
from itertools import product

def make_name(d_model, n_layers, ln, use_bias, freeze_wv, freeze_wo):
    """Build a run name such as ``d128_2L_noLN_noBias_uWV_fWO`` from the settings.

    The fields are, in order: model width, layer count, layer norm on/off,
    bias on/off, W_V frozen ("f") or unfrozen ("u"), W_O frozen or unfrozen.
    """
    ln_tag = "LN" if ln else "noLN"
    bias_tag = "Bias" if use_bias else "noBias"
    wv_tag = "fWV" if freeze_wv else "uWV"  # freeze / unfreeze
    wo_tag = "fWO" if freeze_wo else "uWO"
    return f"d{d_model}_{n_layers}L_{ln_tag}_{bias_tag}_{wv_tag}_{wo_tag}"

# Experiment specifications. Each entry is a dict of overrides for the default
# settings used by the training loop that follows; every entry must provide a
# 'name' key, which that loop reads when announcing and reporting the run.
# NOTE(review): every hand-written entry is currently commented out, so `specs`
# is an empty list and the loop trains nothing until an entry is re-enabled.
specs = [
    # {'name': 'd128', 'd_model': 128},
    # {'name': 'd64', 'd_model': 64},
    
    # {'name': 'd32', 'd_model': 32},
    # {'name': 'd32_ln_bias', 'd_model': 32, 'ln': True, 'use_bias': True},
    # {'name': 'd32_noLN', 'd_model': 32, 'ln': False, 'use_bias': True},
    # {'name': 'd32_noBias', 'd_model': 32, 'ln': True, 'use_bias': False},
    # {'name': 'd32_noLNnoBias', 'd_model': 32, 'ln': False, 'use_bias': False},
    # {'name': 'd32_fwo', 'd_model': 32, 'freeze_wo': True},
    # {'name': 'd32_unfwo', 'd_model': 32, 'freeze_wo': False},

    # {'name': 'd16', 'd_model': 16},
    # {'name': 'd16_ln_bias', 'd_model': 16, 'ln': True, 'use_bias': True},
    # {'name': 'd16_noLN', 'd_model': 16, 'ln': False, 'use_bias': True},
    # {'name': 'd16_noBias', 'd_model': 16, 'ln': True, 'use_bias': False},
    # {'name': 'd16_noLNnoBias', 'd_model': 16, 'ln': False, 'use_bias': False},
    # {'name': 'd16_fwo', 'd_model': 16, 'freeze_wo': True},
    # {'name': 'd16_unfwo', 'd_model': 16, 'freeze_wo': False},

    # {'name': 'd8', 'd_model': 8},
    # {'name': 'd8_ln_bias', 'd_model': 8, 'ln': True, 'use_bias': True},
    # {'name': 'd8_noLN', 'd_model': 8, 'ln': False, 'use_bias': True},
    # {'name': 'd8_noBias', 'd_model': 8, 'ln': True, 'use_bias': False},
    # {'name': 'd8_noLNnoBias', 'd_model': 8, 'ln': False, 'use_bias': False},
    # {'name': 'd8_fwo', 'd_model': 8, 'freeze_wo': True},
    # {'name': 'd8_unfwo', 'd_model': 8, 'freeze_wo': False},

    # {'name': 'd4_ln_bias', 'd_model': 4, 'ln': True, 'use_bias': True},
]

# Alternative (disabled): a full factorial grid at d_model=128 over layer count,
# layer norm, bias, frozen W_V and frozen W_O, with names produced by make_name().
# NOTE(review): the run names in the recorded output follow make_name()'s scheme,
# so this grid was presumably enabled when that output was produced — confirm.
# specs = []
# d_model = 128
# for n_layers, ln, use_bias, freeze_wv, freeze_wo in product(
#     [2, 3],            # layers
#     [False, True],     # ln
#     [False, True],     # use_bias
#     [False, True],     # freeze_wv
#     [False, True],     # freeze_wo
# ):
#     specs.append({
#         "name": make_name(d_model, n_layers, ln, use_bias, freeze_wv, freeze_wo),
#         "d_model": d_model,
#         "n_layers": n_layers,
#         "ln": ln,
#         "use_bias": use_bias,
#         "freeze_wv": freeze_wv,
#         "freeze_wo": freeze_wo,
#     })

# -----------------------
# Settings every run falls back to when its spec does not override them.
default_spec = {
    'n_layers': N_LAYER,
    'n_heads': N_HEAD,
    'd_model': D_MODEL,
    'ln': USE_LN,
    'use_bias': USE_BIAS,
    'freeze_wv': FREEZE_WV,
    'freeze_wo': FREEZE_WO,
    'weight_decay': WEIGHT_DECAY,
}
# The subset of settings that make_model() accepts (weight_decay goes to train()).
model_arg_names = ('n_layers', 'n_heads', 'd_model', 'ln', 'use_bias', 'freeze_wv', 'freeze_wo')

rows = []
for spec in specs:
    # Defaults first, then the spec's own values on top.
    full_spec = {**default_spec, **spec}

    print(f"--- Training model: {full_spec['name']} ---")
    model = make_model(**{arg: full_spec[arg] for arg in model_arg_names})

    train(model, max_steps=50_000, weight_decay=full_spec['weight_decay'], verbose=True)

    # One result row per run: every setting plus the final validation accuracy.
    result = {**full_spec, 'val_acc': round(accuracy(model), 4)}
    rows.append(result)

df = pd.DataFrame(rows)

# Put the run name in the first column so the table reads more easily.
if 'name' in df.columns:
    cols = ['name', *(col for col in df.columns if col != 'name')]
    df = df[cols]

print(df.to_markdown(index=False))

--- Training model: d128_2L_noLN_noBias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7330, Accuracy: 48.40%
Step 5000, Loss: 0.7325, Accuracy: 47.77%
Step 7500, Loss: 0.7095, Accuracy: 48.33%
Step 10000, Loss: 0.7193, Accuracy: 48.27%
Step 12500, Loss: 0.7161, Accuracy: 47.93%
Step 15000, Loss: 0.7122, Accuracy: 47.70%
Step 17500, Loss: 0.7111, Accuracy: 47.83%
Step 20000, Loss: 0.6909, Accuracy: 47.33%
Step 22500, Loss: 0.6990, Accuracy: 47.35%
Step 25000, Loss: 0.6969, Accuracy: 48.45%
Step 27500, Loss: 0.7198, Accuracy: 47.27%
Step 30000, Loss: 0.7127, Accuracy: 47.17%
Step 32500, Loss: 0.7063, Accuracy: 47.08%
Step 35000, Loss: 0.6831, Accuracy: 46.60%
Step 37500, Loss: 0.7150, Accuracy: 47.38%
Step 40000, Loss: 0.7102, Accuracy: 47.02%
Step 42500, Loss: 0.6986, Accuracy: 47.12%
Step 45000, Loss: 0.6993, Accuracy: 46.48%
Step 47500, Loss: 0.7272, Accuracy: 46.77%
Step 50000, Loss: 0.7146, Accuracy: 46.25%
Final accuracy: 46.25%
--- Training model: d128_2L_noLN_noBias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7057, Accuracy: 47.73%
Step 5000, Loss: 0.7009, Accuracy: 47.62%
Step 7500, Loss: 0.6937, Accuracy: 47.10%
Step 10000, Loss: 0.6927, Accuracy: 46.77%
Step 12500, Loss: 0.6991, Accuracy: 47.00%
Step 15000, Loss: 0.6852, Accuracy: 46.23%
Step 17500, Loss: 0.7053, Accuracy: 46.58%
Step 20000, Loss: 0.6888, Accuracy: 46.88%
Step 22500, Loss: 0.6815, Accuracy: 46.65%
Step 25000, Loss: 0.6974, Accuracy: 46.88%
Step 27500, Loss: 0.6881, Accuracy: 47.40%
Step 30000, Loss: 0.6892, Accuracy: 46.65%
Step 32500, Loss: 0.6811, Accuracy: 46.52%
Step 35000, Loss: 0.7078, Accuracy: 46.23%
Step 37500, Loss: 0.6839, Accuracy: 46.25%
Step 40000, Loss: 0.6875, Accuracy: 46.27%
Step 42500, Loss: 0.6727, Accuracy: 46.42%
Step 45000, Loss: 0.6999, Accuracy: 47.38%
Step 47500, Loss: 0.6684, Accuracy: 45.67%
Step 50000, Loss: 0.7015, Accuracy: 48.95%
Final accuracy: 48.95%
--- Training model: d128_2L_noLN_noBias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6987, Accuracy: 47.77%
Step 5000, Loss: 0.6949, Accuracy: 47.30%
Step 7500, Loss: 0.6973, Accuracy: 47.42%
Step 10000, Loss: 0.6991, Accuracy: 47.10%
Step 12500, Loss: 0.6919, Accuracy: 46.83%
Step 15000, Loss: 0.6870, Accuracy: 46.45%
Step 17500, Loss: 0.6872, Accuracy: 46.52%
Step 20000, Loss: 0.7107, Accuracy: 47.40%
Step 22500, Loss: 0.6978, Accuracy: 46.05%
Step 25000, Loss: 0.6946, Accuracy: 47.38%
Step 27500, Loss: 0.6928, Accuracy: 47.05%
Step 30000, Loss: 0.6888, Accuracy: 47.48%
Step 32500, Loss: 0.6975, Accuracy: 47.12%
Step 35000, Loss: 0.6965, Accuracy: 47.02%
Step 37500, Loss: 0.6908, Accuracy: 47.23%
Step 40000, Loss: 0.6989, Accuracy: 46.35%
Step 42500, Loss: 0.6743, Accuracy: 47.02%
Step 45000, Loss: 0.6945, Accuracy: 46.77%
Step 47500, Loss: 0.6898, Accuracy: 46.02%
Step 50000, Loss: 0.6979, Accuracy: 46.30%
Final accuracy: 46.30%
--- Training model: d128_2L_noLN_noBias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6906, Accuracy: 43.88%
Step 5000, Loss: 0.6850, Accuracy: 46.05%
Step 7500, Loss: 0.6943, Accuracy: 44.12%
Step 10000, Loss: 0.6844, Accuracy: 44.35%
Step 12500, Loss: 0.6875, Accuracy: 43.68%
Step 15000, Loss: 0.6860, Accuracy: 44.77%
Step 17500, Loss: 0.6369, Accuracy: 45.90%
Step 20000, Loss: 0.6764, Accuracy: 44.25%
Step 22500, Loss: 0.6808, Accuracy: 44.25%
Step 25000, Loss: 0.6866, Accuracy: 44.27%
Step 27500, Loss: 0.6631, Accuracy: 44.50%
Step 30000, Loss: 0.3340, Accuracy: 84.45%
Step 32500, Loss: 0.2394, Accuracy: 87.30%
Step 35000, Loss: 0.1735, Accuracy: 89.62%
Step 37500, Loss: 0.1456, Accuracy: 90.18%
Step 40000, Loss: 0.1193, Accuracy: 90.95%
Step 42500, Loss: 0.1048, Accuracy: 90.95%
Step 45000, Loss: 0.0786, Accuracy: 91.00%
Step 47500, Loss: 0.0796, Accuracy: 91.25%
Step 50000, Loss: 0.0970, Accuracy: 91.72%
Final accuracy: 91.72%
--- Training model: d128_2L_noLN_Bias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7224, Accuracy: 48.60%
Step 5000, Loss: 0.7144, Accuracy: 47.95%
Step 7500, Loss: 0.7345, Accuracy: 47.67%
Step 10000, Loss: 0.7161, Accuracy: 47.48%
Step 12500, Loss: 0.7216, Accuracy: 48.05%
Step 15000, Loss: 0.7224, Accuracy: 47.20%
Step 17500, Loss: 0.6926, Accuracy: 46.77%
Step 20000, Loss: 0.7084, Accuracy: 47.42%
Step 22500, Loss: 0.1840, Accuracy: 84.70%
Step 25000, Loss: 0.3053, Accuracy: 80.20%
Step 27500, Loss: 0.1903, Accuracy: 85.17%
Step 30000, Loss: 0.1579, Accuracy: 87.17%
Step 32500, Loss: 0.1527, Accuracy: 86.72%
Step 35000, Loss: 0.1315, Accuracy: 87.10%
Step 37500, Loss: 0.1643, Accuracy: 86.33%
Step 40000, Loss: 0.7269, Accuracy: 48.50%
Step 42500, Loss: 0.1380, Accuracy: 84.52%
Step 45000, Loss: 0.0950, Accuracy: 87.15%
Step 47500, Loss: 0.1788, Accuracy: 86.80%
Step 50000, Loss: 0.1278, Accuracy: 86.80%
Final accuracy: 86.80%
--- Training model: d128_2L_noLN_Bias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7032, Accuracy: 47.73%
Step 5000, Loss: 0.7018, Accuracy: 47.40%
Step 7500, Loss: 0.6975, Accuracy: 46.45%
Step 10000, Loss: 0.6997, Accuracy: 46.45%
Step 12500, Loss: 0.6827, Accuracy: 45.73%
Step 15000, Loss: 0.7120, Accuracy: 45.90%
Step 17500, Loss: 0.6843, Accuracy: 46.58%
Step 20000, Loss: 0.6848, Accuracy: 46.08%
Step 22500, Loss: 0.6915, Accuracy: 45.52%
Step 25000, Loss: 0.6997, Accuracy: 46.52%
Step 27500, Loss: 0.6941, Accuracy: 46.17%
Step 30000, Loss: 0.7091, Accuracy: 45.70%
Step 32500, Loss: 0.6701, Accuracy: 45.57%
Step 35000, Loss: 0.1887, Accuracy: 86.85%
Step 37500, Loss: 0.0891, Accuracy: 89.65%
Step 40000, Loss: 0.7625, Accuracy: 47.90%
Step 42500, Loss: 0.1736, Accuracy: 86.25%
Step 45000, Loss: 0.0916, Accuracy: 88.10%
Step 47500, Loss: 0.1224, Accuracy: 88.33%
Step 50000, Loss: 0.1188, Accuracy: 89.45%
Final accuracy: 89.45%
--- Training model: d128_2L_noLN_Bias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6890, Accuracy: 47.33%
Step 5000, Loss: 0.6998, Accuracy: 48.30%
Step 7500, Loss: 0.7076, Accuracy: 47.52%
Step 10000, Loss: 0.6933, Accuracy: 47.33%
Step 12500, Loss: 0.6735, Accuracy: 47.08%
Step 15000, Loss: 0.7062, Accuracy: 47.40%
Step 17500, Loss: 0.7009, Accuracy: 47.15%
Step 20000, Loss: 0.6932, Accuracy: 47.83%
Step 22500, Loss: 0.7002, Accuracy: 46.95%
Step 25000, Loss: 0.6899, Accuracy: 47.10%
Step 27500, Loss: 0.7115, Accuracy: 47.48%
Step 30000, Loss: 0.6883, Accuracy: 47.60%
Step 32500, Loss: 0.6938, Accuracy: 47.42%
Step 35000, Loss: 0.6963, Accuracy: 46.73%
Step 37500, Loss: 0.7013, Accuracy: 46.38%
Step 40000, Loss: 0.6984, Accuracy: 47.40%
Step 42500, Loss: 0.6829, Accuracy: 47.12%
Step 45000, Loss: 0.6942, Accuracy: 46.77%
Step 47500, Loss: 0.7040, Accuracy: 47.08%
Step 50000, Loss: 0.6864, Accuracy: 46.45%
Final accuracy: 46.45%
--- Training model: d128_2L_noLN_Bias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6870, Accuracy: 44.20%
Step 5000, Loss: 0.6983, Accuracy: 45.60%
Step 7500, Loss: 0.6864, Accuracy: 45.32%
Step 10000, Loss: 0.7014, Accuracy: 44.62%
Step 12500, Loss: 0.6864, Accuracy: 45.32%
Step 15000, Loss: 0.6788, Accuracy: 45.00%
Step 17500, Loss: 0.6799, Accuracy: 43.80%
Step 20000, Loss: 0.6637, Accuracy: 44.98%
Step 22500, Loss: 0.2642, Accuracy: 84.67%
Step 25000, Loss: 0.1716, Accuracy: 88.28%
Step 27500, Loss: 0.1228, Accuracy: 89.45%
Step 30000, Loss: 0.1384, Accuracy: 90.42%
Step 32500, Loss: 0.0877, Accuracy: 91.17%
Step 35000, Loss: 0.1052, Accuracy: 90.90%
Step 37500, Loss: 0.1284, Accuracy: 90.92%
Step 40000, Loss: 0.1152, Accuracy: 90.38%
Step 42500, Loss: 0.0888, Accuracy: 91.00%
Step 45000, Loss: 0.0796, Accuracy: 90.75%
Step 47500, Loss: 0.0787, Accuracy: 91.95%
Step 50000, Loss: 0.0739, Accuracy: 91.83%
Final accuracy: 91.83%
--- Training model: d128_2L_LN_noBias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7132, Accuracy: 49.10%
Step 5000, Loss: 0.7055, Accuracy: 48.93%
Step 7500, Loss: 0.7110, Accuracy: 48.30%
Step 10000, Loss: 0.7004, Accuracy: 49.05%
Step 12500, Loss: 0.7186, Accuracy: 48.65%
Step 15000, Loss: 0.7026, Accuracy: 48.88%
Step 17500, Loss: 0.7061, Accuracy: 48.33%
Step 20000, Loss: 0.6962, Accuracy: 48.43%
Step 22500, Loss: 0.7084, Accuracy: 48.43%
Step 25000, Loss: 0.7025, Accuracy: 48.33%
Step 27500, Loss: 0.7091, Accuracy: 48.45%
Step 30000, Loss: 0.7145, Accuracy: 47.88%
Step 32500, Loss: 0.7074, Accuracy: 47.58%
Step 35000, Loss: 0.6918, Accuracy: 48.12%
Step 37500, Loss: 0.7014, Accuracy: 48.18%
Step 40000, Loss: 0.7020, Accuracy: 47.50%
Step 42500, Loss: 0.6960, Accuracy: 47.55%
Step 45000, Loss: 0.6892, Accuracy: 47.20%
Step 47500, Loss: 0.7000, Accuracy: 47.85%
Step 50000, Loss: 0.6985, Accuracy: 47.42%
Final accuracy: 47.42%
--- Training model: d128_2L_LN_noBias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7124, Accuracy: 48.23%
Step 5000, Loss: 0.7081, Accuracy: 49.02%
Step 7500, Loss: 0.6985, Accuracy: 47.83%
Step 10000, Loss: 0.6995, Accuracy: 47.88%
Step 12500, Loss: 0.6801, Accuracy: 48.95%
Step 15000, Loss: 0.7006, Accuracy: 47.85%
Step 17500, Loss: 0.6809, Accuracy: 48.15%
Step 20000, Loss: 0.6960, Accuracy: 47.93%
Step 22500, Loss: 0.6966, Accuracy: 47.98%
Step 25000, Loss: 0.7002, Accuracy: 47.60%
Step 27500, Loss: 0.7001, Accuracy: 47.55%
Step 30000, Loss: 0.7000, Accuracy: 47.83%
Step 32500, Loss: 0.6937, Accuracy: 47.20%
Step 35000, Loss: 0.6912, Accuracy: 47.42%
Step 37500, Loss: 0.6951, Accuracy: 46.90%
Step 40000, Loss: 0.6803, Accuracy: 47.25%
Step 42500, Loss: 0.6934, Accuracy: 46.70%
Step 45000, Loss: 0.6908, Accuracy: 45.73%
Step 47500, Loss: 0.6836, Accuracy: 46.10%
Step 50000, Loss: 0.6979, Accuracy: 46.08%
Final accuracy: 46.08%
--- Training model: d128_2L_LN_noBias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7041, Accuracy: 49.08%
Step 5000, Loss: 0.6938, Accuracy: 48.40%
Step 7500, Loss: 0.7077, Accuracy: 47.80%
Step 10000, Loss: 0.7173, Accuracy: 47.75%
Step 12500, Loss: 0.7030, Accuracy: 47.88%
Step 15000, Loss: 0.7042, Accuracy: 48.38%
Step 17500, Loss: 0.6946, Accuracy: 47.95%
Step 20000, Loss: 0.6936, Accuracy: 47.25%
Step 22500, Loss: 0.6954, Accuracy: 48.62%
Step 25000, Loss: 0.6909, Accuracy: 47.83%
Step 27500, Loss: 0.7022, Accuracy: 47.50%
Step 30000, Loss: 0.6982, Accuracy: 46.45%
Step 32500, Loss: 0.6958, Accuracy: 47.05%
Step 35000, Loss: 0.6928, Accuracy: 47.48%
Step 37500, Loss: 0.6909, Accuracy: 47.30%
Step 40000, Loss: 0.6767, Accuracy: 46.42%
Step 42500, Loss: 0.6890, Accuracy: 46.30%
Step 45000, Loss: 0.6945, Accuracy: 46.55%
Step 47500, Loss: 0.7045, Accuracy: 45.88%
Step 50000, Loss: 0.6837, Accuracy: 46.33%
Final accuracy: 46.33%
--- Training model: d128_2L_LN_noBias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6979, Accuracy: 47.77%
Step 5000, Loss: 0.6892, Accuracy: 47.15%
Step 7500, Loss: 0.6913, Accuracy: 47.70%
Step 10000, Loss: 0.6872, Accuracy: 47.27%
Step 12500, Loss: 0.6999, Accuracy: 46.23%
Step 15000, Loss: 0.6946, Accuracy: 46.35%
Step 17500, Loss: 0.6988, Accuracy: 46.12%
Step 20000, Loss: 0.6780, Accuracy: 46.38%
Step 22500, Loss: 0.6934, Accuracy: 45.30%
Step 25000, Loss: 0.6947, Accuracy: 45.52%
Step 27500, Loss: 0.6942, Accuracy: 45.73%
Step 30000, Loss: 0.6872, Accuracy: 44.62%
Step 32500, Loss: 0.6823, Accuracy: 45.75%
Step 35000, Loss: 0.6925, Accuracy: 46.65%
Step 37500, Loss: 0.6880, Accuracy: 46.38%
Step 40000, Loss: 0.6930, Accuracy: 44.25%
Step 42500, Loss: 0.6984, Accuracy: 45.23%
Step 45000, Loss: 0.6936, Accuracy: 45.25%
Step 47500, Loss: 0.6969, Accuracy: 43.97%
Step 50000, Loss: 0.6804, Accuracy: 44.85%
Final accuracy: 44.85%
--- Training model: d128_2L_LN_Bias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7271, Accuracy: 48.93%
Step 5000, Loss: 0.7258, Accuracy: 48.58%
Step 7500, Loss: 0.7052, Accuracy: 49.00%
Step 10000, Loss: 0.7032, Accuracy: 48.10%
Step 12500, Loss: 0.7116, Accuracy: 49.02%
Step 15000, Loss: 0.6980, Accuracy: 48.85%
Step 17500, Loss: 0.6961, Accuracy: 48.35%
Step 20000, Loss: 0.7033, Accuracy: 47.85%
Step 22500, Loss: 0.7064, Accuracy: 48.45%
Step 25000, Loss: 0.7103, Accuracy: 47.88%
Step 27500, Loss: 0.7014, Accuracy: 48.20%
Step 30000, Loss: 0.6998, Accuracy: 48.25%
Step 32500, Loss: 0.6927, Accuracy: 47.58%
Step 35000, Loss: 0.6952, Accuracy: 47.45%
Step 37500, Loss: 0.7088, Accuracy: 48.43%
Step 40000, Loss: 0.7021, Accuracy: 47.15%
Step 42500, Loss: 0.6880, Accuracy: 47.62%
Step 45000, Loss: 0.6996, Accuracy: 47.42%
Step 47500, Loss: 0.6801, Accuracy: 47.27%
Step 50000, Loss: 0.6926, Accuracy: 47.33%
Final accuracy: 47.33%
--- Training model: d128_2L_LN_Bias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7131, Accuracy: 48.20%
Step 5000, Loss: 0.7094, Accuracy: 48.60%
Step 7500, Loss: 0.6894, Accuracy: 48.02%
Step 10000, Loss: 0.6902, Accuracy: 47.85%
Step 12500, Loss: 0.7085, Accuracy: 47.93%
Step 15000, Loss: 0.6969, Accuracy: 48.20%
Step 17500, Loss: 0.6922, Accuracy: 47.90%
Step 20000, Loss: 0.7090, Accuracy: 47.23%
Step 22500, Loss: 0.6972, Accuracy: 47.12%
Step 25000, Loss: 0.6794, Accuracy: 47.25%
Step 27500, Loss: 0.6982, Accuracy: 46.77%
Step 30000, Loss: 0.6873, Accuracy: 47.25%
Step 32500, Loss: 0.6996, Accuracy: 46.58%
Step 35000, Loss: 0.6890, Accuracy: 46.88%
Step 37500, Loss: 0.6814, Accuracy: 46.88%
Step 40000, Loss: 0.6896, Accuracy: 47.23%
Step 42500, Loss: 0.6688, Accuracy: 46.75%
Step 45000, Loss: 0.6927, Accuracy: 46.55%
Step 47500, Loss: 0.6808, Accuracy: 46.52%
Step 50000, Loss: 0.6876, Accuracy: 46.48%
Final accuracy: 46.48%
--- Training model: d128_2L_LN_Bias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7166, Accuracy: 49.10%
Step 5000, Loss: 0.7174, Accuracy: 49.18%
Step 7500, Loss: 0.6985, Accuracy: 47.25%
Step 10000, Loss: 0.7136, Accuracy: 48.25%
Step 12500, Loss: 0.6925, Accuracy: 47.98%
Step 15000, Loss: 0.7050, Accuracy: 48.45%
Step 17500, Loss: 0.7079, Accuracy: 47.40%
Step 20000, Loss: 0.7000, Accuracy: 48.10%
Step 22500, Loss: 0.6979, Accuracy: 47.67%
Step 25000, Loss: 0.7036, Accuracy: 47.25%
Step 27500, Loss: 0.7014, Accuracy: 47.88%
Step 30000, Loss: 0.6789, Accuracy: 47.45%
Step 32500, Loss: 0.6913, Accuracy: 47.83%
Step 35000, Loss: 0.6880, Accuracy: 47.75%
Step 37500, Loss: 0.6872, Accuracy: 47.45%
Step 40000, Loss: 0.6790, Accuracy: 46.60%
Step 42500, Loss: 0.7044, Accuracy: 47.48%
Step 45000, Loss: 0.7009, Accuracy: 46.80%
Step 47500, Loss: 0.6951, Accuracy: 46.17%
Step 50000, Loss: 0.6882, Accuracy: 47.55%
Final accuracy: 47.55%
--- Training model: d128_2L_LN_Bias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.7094, Accuracy: 50.02%
Step 5000, Loss: 0.6913, Accuracy: 46.40%
Step 7500, Loss: 0.6942, Accuracy: 47.02%
Step 10000, Loss: 0.6891, Accuracy: 46.50%
Step 12500, Loss: 0.6982, Accuracy: 46.67%
Step 15000, Loss: 0.7003, Accuracy: 46.40%
Step 17500, Loss: 0.6972, Accuracy: 46.90%
Step 20000, Loss: 0.6811, Accuracy: 46.50%
Step 22500, Loss: 0.6914, Accuracy: 46.27%
Step 25000, Loss: 0.6666, Accuracy: 46.10%
Step 27500, Loss: 0.6747, Accuracy: 46.83%
Step 30000, Loss: 0.6730, Accuracy: 46.15%
Step 32500, Loss: 0.6968, Accuracy: 46.48%
Step 35000, Loss: 0.6911, Accuracy: 46.08%
Step 37500, Loss: 0.6845, Accuracy: 46.27%
Step 40000, Loss: 0.6966, Accuracy: 44.73%
Step 42500, Loss: 0.6931, Accuracy: 46.30%
Step 45000, Loss: 0.6731, Accuracy: 46.95%
Step 47500, Loss: 0.6909, Accuracy: 45.88%
Step 50000, Loss: 0.6917, Accuracy: 46.02%
Final accuracy: 46.02%
--- Training model: d128_3L_noLN_noBias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_noBias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_noBias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_noBias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6923, Accuracy: 43.45%
Early stopping at step 3300 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_Bias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_Bias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_Bias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_noLN_Bias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 2500, Loss: 0.6869, Accuracy: 46.80%
Early stopping at step 2800 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_noBias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_noBias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_noBias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_noBias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 99.92% >= 99.90%
Final accuracy: 99.92%
--- Training model: d128_3L_LN_Bias_uWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_Bias_uWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_Bias_fWV_uWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 100 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
--- Training model: d128_3L_LN_Bias_fWV_fWO ---
Moving model to device:  cuda


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Early stopping at step 200 with accuracy 100.00% >= 99.90%
Final accuracy: 100.00%
| name                        |   n_layers |   n_heads |   d_model | ln    | use_bias   | freeze_wv   | freeze_wo   |   weight_decay |   val_acc |
|:----------------------------|-----------:|----------:|----------:|:------|:-----------|:------------|:------------|---------------:|----------:|
| d128_2L_noLN_noBias_uWV_uWO |          2 |         1 |       128 | False | False      | False       | False       |           0.01 |    0.4625 |
| d128_2L_noLN_noBias_uWV_fWO |          2 |         1 |       128 | False | False      | False       | True        |           0.01 |    0.4895 |
| d128_2L_noLN_noBias_fWV_uWO |          2 |         1 |       128 | False | False      | True        | False       |           0.01 |    0.463  |
| d128_2L_noLN_noBias_fWV_fWO |          2 |         1 |       128 | False | False      | True        | True        |           0.01 |    0.9173 |
| d128_2L_noLN_Bias_uWV_uWO   |    

| name                        |   n_layers |   n_heads |   d_model | ln    | use_bias   | freeze_wv   | freeze_wo   |   weight_decay |   val_acc |
|:----------------------------|-----------:|----------:|----------:|:------|:-----------|:------------|:------------|---------------:|----------:|
| d128_2L_noLN_noBias_uWV_uWO |          2 |         1 |       128 | False | False      | False       | False       |           0.01 |    0.4625 |
| d128_2L_noLN_noBias_uWV_fWO |          2 |         1 |       128 | False | False      | False       | True        |           0.01 |    0.4895 |
| d128_2L_noLN_noBias_fWV_uWO |          2 |         1 |       128 | False | False      | True        | False       |           0.01 |    0.463  |
| d128_2L_noLN_noBias_fWV_fWO |          2 |         1 |       128 | False | False      | True        | True        |           0.01 |    0.9173 |
| d128_2L_noLN_Bias_uWV_uWO   |          2 |         1 |       128 | False | True       | False       | False       |           0.01 |    0.868  |
| d128_2L_noLN_Bias_uWV_fWO   |          2 |         1 |       128 | False | True       | False       | True        |           0.01 |    0.8945 |
| d128_2L_noLN_Bias_fWV_uWO   |          2 |         1 |       128 | False | True       | True        | False       |           0.01 |    0.4645 |
| d128_2L_noLN_Bias_fWV_fWO   |          2 |         1 |       128 | False | True       | True        | True        |           0.01 |    0.9183 |
| d128_2L_LN_noBias_uWV_uWO   |          2 |         1 |       128 | True  | False      | False       | False       |           0.01 |    0.4743 |
| d128_2L_LN_noBias_uWV_fWO   |          2 |         1 |       128 | True  | False      | False       | True        |           0.01 |    0.4607 |
| d128_2L_LN_noBias_fWV_uWO   |          2 |         1 |       128 | True  | False      | True        | False       |           0.01 |    0.4632 |
| d128_2L_LN_noBias_fWV_fWO   |          2 |         1 |       128 | True  | False      | True        | True        |           0.01 |    0.4485 |
| d128_2L_LN_Bias_uWV_uWO     |          2 |         1 |       128 | True  | True       | False       | False       |           0.01 |    0.4733 |
| d128_2L_LN_Bias_uWV_fWO     |          2 |         1 |       128 | True  | True       | False       | True        |           0.01 |    0.4647 |
| d128_2L_LN_Bias_fWV_uWO     |          2 |         1 |       128 | True  | True       | True        | False       |           0.01 |    0.4755 |
| d128_2L_LN_Bias_fWV_fWO     |          2 |         1 |       128 | True  | True       | True        | True        |           0.01 |    0.4602 |
| d128_3L_noLN_noBias_uWV_uWO |          3 |         1 |       128 | False | False      | False       | False       |           0.01 |    1      |
| d128_3L_noLN_noBias_uWV_fWO |          3 |         1 |       128 | False | False      | False       | True        |           0.01 |    1      |
| d128_3L_noLN_noBias_fWV_uWO |          3 |         1 |       128 | False | False      | True        | False       |           0.01 |    1      |
| d128_3L_noLN_noBias_fWV_fWO |          3 |         1 |       128 | False | False      | True        | True        |           0.01 |    1      |
| d128_3L_noLN_Bias_uWV_uWO   |          3 |         1 |       128 | False | True       | False       | False       |           0.01 |    1      |
| d128_3L_noLN_Bias_uWV_fWO   |          3 |         1 |       128 | False | True       | False       | True        |           0.01 |    1      |
| d128_3L_noLN_Bias_fWV_uWO   |          3 |         1 |       128 | False | True       | True        | False       |           0.01 |    1      |
| d128_3L_noLN_Bias_fWV_fWO   |          3 |         1 |       128 | False | True       | True        | True        |           0.01 |    1      |
| d128_3L_LN_noBias_uWV_uWO   |          3 |         1 |       128 | True  | False      | False       | False       |           0.01 |    1      |
| d128_3L_LN_noBias_uWV_fWO   |          3 |         1 |       128 | True  | False      | False       | True        |           0.01 |    1      |
| d128_3L_LN_noBias_fWV_uWO   |          3 |         1 |       128 | True  | False      | True        | False       |           0.01 |    1      |
| d128_3L_LN_noBias_fWV_fWO   |          3 |         1 |       128 | True  | False      | True        | True        |           0.01 |    0.9992 |
| d128_3L_LN_Bias_uWV_uWO     |          3 |         1 |       128 | True  | True       | False       | False       |           0.01 |    1      |
| d128_3L_LN_Bias_uWV_fWO     |          3 |         1 |       128 | True  | True       | False       | True        |           0.01 |    1      |
| d128_3L_LN_Bias_fWV_uWO     |          3 |         1 |       128 | True  | True       | True        | False       |           0.01 |    1      |
| d128_3L_LN_Bias_fWV_fWO     |          3 |         1 |       128 | True  | True       | True        | True        |           0.01 |    1      |

In [12]:
# Obtain `model`: train and save a fresh one when no file exists at
# MODEL_PATH, otherwise load the saved one.

if not os.path.exists(MODEL_PATH):
    print("Training model")
    model = make_model()
    train(
        model,
        max_steps=MAX_TRAIN_STEPS,
        early_stop_acc=0.999,
        checkpoints=USE_CHECKPOINTING,
    )
    # Persist the freshly trained weights so later runs take the load branch.
    save_model(model, MODEL_PATH)
else:
    model = load_model(MODEL_PATH, device=DEV)

# from torchinfo import summary
# summary(model) 

Loading model from models/3layer_100dig_8d.pt
Moving model to device:  cuda


In [13]:
# --- Model Parameters Overview ---
# Prints a table with one row per *trainable* parameter (name, shape), then
# the total and trainable element counts. Parameters with requires_grad=False
# count toward the total but get no row.

print("--- Overview of Model Parameters ---")

# Fixed-width columns keep the table aligned in plain-text output.
print(f"{'Parameter Name':<40} | {'Shape':<20} | {'Trainable':<10}")
print("-" * 80)

total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    shape_str = str(tuple(param.shape))
    is_trainable = "Yes" if param.requires_grad else "No"

    if param.requires_grad:
        print(f"{name:<40} | {shape_str:<20} | {is_trainable:<10}")
        trainable_params += param.numel()

print(
    "-" * 80,
    f"Total parameters: {total_params}",
    f"Trainable parameters: {trainable_params}",
    "-" * 80,
    sep="\n",
)

--- Overview of Model Parameters ---
Parameter Name                           | Shape                | Trainable 
--------------------------------------------------------------------------------
embed.W_E                                | (101, 8)             | Yes       
pos_embed.W_pos                          | (5, 8)               | Yes       
blocks.0.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.0.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_K                        | (1, 8, 8)            | Yes       
unembed.W_U                              | (8, 101)             | Yes       
--------------------------------------------------------------------------------
Total parameters: 2621
Trainabl

### Model attention

We confirm below that the model does not leak attention onto the first two tokens (`d1`, `d2`), which are the inputs to the task. The model should attend to the first two tokens only at the third (`SEP`) position, and should not attend to them at all at the fourth and fifth positions (`o1`, `o2`), where the outputs are predicted.

In [None]:
# --- Using Plotly for visualization ---

def check_attention(m, dataloader, eps=1e-3):
    """Verify that output positions place (almost) no attention on the input tokens.

    Runs ``m`` on every batch of ``dataloader`` and, for each layer and every
    attention head, sums the attention that query positions after the
    separator (index ``LIST_LEN + 1`` onwards) put on the first ``LIST_LEN``
    key positions.

    Args:
        m: model exposing ``cfg.n_layers`` and ``run_with_cache``.
        dataloader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        eps: largest attention mass on the forbidden keys that is still tolerated.

    Raises:
        ValueError: if any output query in any layer/head exceeds ``eps``.

    Prints a confirmation line when no leak is found.
    """
    for inputs, _ in dataloader:
        with torch.no_grad():
            _, cache = m.run_with_cache(inputs.to(DEV))
        for l in range(m.cfg.n_layers):
            # Keep the head axis so every head is checked, not just head 0.
            pat = cache["pattern", l]  # (batch, head, Q, K)
            leak = pat[..., LIST_LEN+1:, :LIST_LEN].sum(dim=-1)  # mass on forbidden keys
            if (leak > eps).any():
                raise ValueError(f"❌ Layer {l}: output tokens attend to x₁/x₂ by >{eps:.0e}")
    print("✅ no attention leakage onto x₁/x₂")


# Take the first validation example and cache the model's activations on it.
sample = val_ds[0][0]
print(f"Sample sequence: {sample.cpu().numpy()}")  # shown for reference next to the plot
_, cache = model.run_with_cache(sample.unsqueeze(0).to(DEV))

# --- Create Plotly visualization ---
# Tick labels in sequence order: inputs d1..dN, the separator, outputs o1..oN.
token_labels = (
    [f'd{i}' for i in range(1, LIST_LEN + 1)]
    + ['SEP']
    + [f'o{i}' for i in range(1, LIST_LEN + 1)]
)
subplot_titles = [f"Layer {l} Attention Pattern" for l in range(model.cfg.n_layers)]

# One heatmap panel per layer, side by side.
fig = make_subplots(
    rows=1,
    cols=model.cfg.n_layers,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08,  # breathing room between panels
)

for l in range(model.cfg.n_layers):
    # First batch element, first head of this layer's cached pattern.
    pat = cache["pattern", l][0, 0].cpu().numpy()
    heatmap = go.Heatmap(
        z=pat,
        x=token_labels,
        y=token_labels,
        colorscale="Viridis",
        zmin=0,
        zmax=1,
        # A single colorbar, attached to the right-most panel, is enough.
        showscale=(l == model.cfg.n_layers - 1),
    )
    fig.add_trace(heatmap, row=1, col=l + 1)

# Axis settings apply to every panel; the y axis is flipped so the first
# query position is drawn at the top.
fig.update_yaxes(title_text="Query Position", autorange='reversed')
fig.update_xaxes(title_text="Key Position")

fig.update_layout(
    title_text="Attention Patterns for a Sample Sequence",
    width=1200,
    height=450,
    template="plotly_white",
)

fig.show()

# Run the leak check over the whole validation loader.
check_attention(model, val_dl)

Sample sequence: [ 80  52 100 100 100]


✅ no attention leakage onto x₁/x₂


In [None]:
# --- Mean Attention Patterns ---
# Collect the head-0 attention pattern of every validation example per layer,
# report whether the patterns are input-independent, and compute the
# per-layer average pattern. A single pass over val_dl feeds both results;
# the averages are derived from the collected tensors instead of re-running
# the model over the validation set a second time.

all_pats = [[] for _ in range(model.cfg.n_layers)]
n = 0  # number of validation examples seen
with torch.no_grad():
    for inputs, _ in val_dl:
        _, cache = model.run_with_cache(inputs.to(DEV))
        for l in range(model.cfg.n_layers):
            pat = cache["pattern", l][:, 0]  # (batch, Q, K)
            all_pats[l].append(pat)
        n += inputs.shape[0]
all_pats = [torch.cat(pats, dim=0) for pats in all_pats]

for l, pats in enumerate(all_pats):
    # Compare every example's pattern against the first example's.
    identical = torch.allclose(pats, pats[0].expand_as(pats))
    print(f"Layer {l}: all attention patterns identical? {'✅' if identical else '❌'}")

with torch.no_grad():
    # Mean over the example axis -> one (Q, K) matrix per layer.
    avg_pats = [pats.sum(0) / n for pats in all_pats]

# --- Visualize Average Attention Patterns ---
# Tick labels in sequence order: inputs d1..dN, the separator, outputs o1..oN.
token_labels = (
    [f'd{i}' for i in range(1, LIST_LEN + 1)]
    + ['SEP']
    + [f'o{i}' for i in range(1, LIST_LEN + 1)]
)
subplot_titles = [f"Layer {l} Average Attention" for l in range(model.cfg.n_layers)]

# One heatmap panel per layer, side by side.
fig = make_subplots(
    rows=1,
    cols=model.cfg.n_layers,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08,
)

for l in range(model.cfg.n_layers):
    avg_pat_np = avg_pats[l].cpu().numpy()
    heatmap = go.Heatmap(
        z=avg_pat_np,
        x=token_labels,
        y=token_labels,
        colorscale="Viridis",
        zmin=0,
        zmax=1,
        # A single colorbar, attached to the right-most panel, is enough.
        showscale=(l == model.cfg.n_layers - 1),
    )
    fig.add_trace(heatmap, row=1, col=l + 1)

# The y axis is flipped so the first query position is drawn at the top.
fig.update_yaxes(title_text="Query Position", autorange='reversed')
fig.update_xaxes(title_text="Key Position")
fig.update_layout(
    title_text="Average Attention Patterns Across Validation Set",
    width=1200,
    height=450,
    template="plotly_white",
)
fig.show()


# Create a deep copy of the model to avoid modifying the original
model_with_avg_attn = copy.deepcopy(model)

def mk_hook(avg, as_scores=True):
    """Build a forward hook that forces a fixed attention pattern.

    Args:
        avg: (Q, K) tensor of attention probabilities to impose.
        as_scores: if True (default, the original behaviour) the hook is meant
            for ``hook_attn_scores`` and returns ``log(avg + 1e-12)`` so that
            the following softmax is approximately ``avg``; the ε avoids -∞.
            If False the hook is meant for ``hook_pattern`` and returns
            ``avg`` itself.

    Returns:
        A ``f(activation, hook)`` callable that ignores the incoming
        activation and returns the fixed tensor broadcast to its shape.

    Note:
        The log-score variant only reproduces ``avg`` for rows that sum to 1.
        A row that is all zero turns into uniform attention after softmax,
        which is why the ablation below writes the pattern directly.
    """
    target = (avg + 1e-12).log() if as_scores else avg

    def f(activation, hook):
        # Add batch and head axes, then broadcast to the activation's shape.
        return target.unsqueeze(0).unsqueeze(0).expand_as(activation)

    return f

# Overwrite the post-softmax pattern of every layer with its validation-set
# average. Hooking `hook_pattern` (rather than the pre-softmax scores) makes
# the imposed pattern exactly `avg_pats[l]`, including any all-zero rows
# (presumably the fully masked d1 query row — verify against the mask).
for l in range(model_with_avg_attn.cfg.n_layers):
    model_with_avg_attn.blocks[l].attn.hook_pattern.add_hook(
        mk_hook(avg_pats[l], as_scores=False), dir="fwd"
    )

print("Accuracy with avg-attn:", accuracy(model_with_avg_attn))

Layer 0: all attention patterns identical? ❌
Layer 1: all attention patterns identical? ❌
Layer 2: all attention patterns identical? ❌


Accuracy with avg-attn: 0.9915
