# 1. Sorting Fixed Length Lists with One Head

## Variable hyperparameters

In [1]:
# Fixed length of list to be sorted
LIST_LENGTH = 10

# Size of vocabulary
D_VOCAB = 66

# Should lists have repetitions?
ALLOW_REPETITIONS = False

# Attention only? (False -> model includes MLPs)
ATTN_ONLY = False

# Model dimensions
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 128
D_HEAD = 32
D_MLP = 32

if ATTN_ONLY:
    D_MLP = None

# Default batch size
DEFAULT_BATCH_SIZE = 32

## Prelude

### Install and import

In [2]:
try:
    import transformer_lens
except:
    !pip install git+https://github.com/neelnanda-io/TransformerLens
    !pip install circuitsvis

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime as dt
from itertools import repeat
import os
import pickle
import random
from typing import cast, Generator, Literal

import circuitsvis as cv
from fancy_einsum import einsum
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, tensor, Tensor, TensorType as TT
from torch.nn import functional as F
from transformer_lens import HookedTransformerConfig, HookedTransformer
from tqdm import tqdm
from typing_extensions import Self

cv.examples.hello("You")

### Invariable hyperparameters

In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{DEVICE = }")

# Seeds to generate training, validation, and test data
TRAIN_SEED = 42
VAL_SEED = 66
TEST_SEED = 1729

# Context length: [start, *(unsorted_)list_length, mid, *(sorted_)list_length]
N_CTX = 2 * LIST_LENGTH + 2

# "Real" tokens range from 0 to D_VOCAB - 2 (non-inclusive)
VOCAB_MIN_ID = 0
VOCAB_MAX_ID = D_VOCAB - 2

# START token is D_VOCAB - 2 and MID token is D_VOCAB - 1
START_TOKEN_ID = VOCAB_MAX_ID
MID_TOKEN_ID = D_VOCAB - 1

DEVICE = 'cpu'


### Data generator and datasets

In [5]:
def generate_list(batch_size: int) -> Tensor:
    if ALLOW_REPETITIONS:
        return torch.randint(VOCAB_MIN_ID, VOCAB_MAX_ID, (batch_size, LIST_LENGTH))
    return tensor([
        random.sample(range(VOCAB_MIN_ID, VOCAB_MAX_ID), k=LIST_LENGTH) 
        for _ in range(batch_size)
    ]).to(DEVICE)

# General generator
def make_data_gen(
    *,
    batch_size: int = DEFAULT_BATCH_SIZE,
    dataset: Literal["train", "val", "test"], # probably this arg needs a better name,
) -> Generator[Tensor, None, None]:
    assert dataset in ("train", "val", "test")
    if dataset == "train":
        seed = TRAIN_SEED
    elif dataset == "val":
        seed = VAL_SEED
    else: # test
        seed = TEST_SEED
    torch.manual_seed(seed)
    while True:
        # Generate random numbers
        x = generate_list(batch_size)
        # Sort
        x_sorted = torch.sort(x, dim=1).values
        # START tokens
        x_start = START_TOKEN_ID * torch.ones(batch_size, dtype=torch.int32).reshape(batch_size, -1).to(DEVICE)
        # MID tokens
        x_mid = MID_TOKEN_ID * torch.ones(batch_size, dtype=torch.int32).reshape(batch_size, -1).to(DEVICE)
        yield torch.cat((x_start, x, x_mid, x_sorted), dim=1)


# Training data generator (kinda wrapper)
def make_train_gen() -> Generator[Tensor, None, None]:
    """Make generator of training data"""
    return make_data_gen(batch_size=128, dataset="train")

# Validation and test data

val_data = next(make_data_gen(batch_size=1000, dataset="val"))
test_data = next(make_data_gen(batch_size=1000, dataset="test"))

### Loss function

In [6]:
def loss_fn(
    logits: Tensor, # [batch, pos, d_vocab] 
    tokens: Tensor, # [batch, pos] 
    return_per_token: bool = False
) -> Tensor: # scalar
    """"""
    # 
    sorted_start_pos = LIST_LENGTH + 2
    logits = logits[:, (sorted_start_pos-1):-1]
    tokens = tokens[:, sorted_start_pos : None]
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if return_per_token:
        return -correct_log_probs
    return -correct_log_probs.mean()

### Accuracy and validation

In [7]:
def get_diff_row_inds(
    a: Tensor, # [dim1, dim2]
    b: Tensor  # [dim1, dim2]
) -> Tensor:   # [dim1]
    """Find indices of rows where a and b differ"""
    assert a.shape == b.shape
    return ((a == b).prod(dim=1) == 0).nonzero(as_tuple=True)[0]

def acc_fn(
    logits: Tensor, # [batch, pos, d_vocab]
    tokens: Tensor, # [batch, pos]
    per: Literal["token", "sequence"] = "sequence"
) -> float:
    """Compute accuracy as percentage of correct predictions"""
    sorted_start_pos = LIST_LENGTH + 2
    # Get logits of predictions for position
    logits = logits[:, (sorted_start_pos-1):-1]
    preds = logits.argmax(-1)
    tokens = tokens[:, sorted_start_pos:]
    if per == "sequence":
        return (preds == tokens).prod(dim=1).float().mean().item()
    return (preds == tokens).float().mean().item()

def validate(
    model: HookedTransformer, 
    data: Tensor # [batch, pos]
) -> float:
    """Test this model on `data`"""
    logits = model(data)
    acc = acc_fn(logits, tokens=data)
    return acc

def show_mispreds(
    model: HookedTransformer, 
    data: Tensor # [batch, pos]
) -> None:
    """Test this model on `data` and print mispredictions"""
    logits = model(data)
    sorted_start_pos = LIST_LENGTH + 2
    logits = logits[:, (sorted_start_pos-1):-1]
    tokens = data[:, sorted_start_pos:]
    preds = logits.argmax(-1)
    mispred_inds = get_diff_row_inds(tokens, preds)
    for i in mispred_inds:
        print(f"[{i}] {tokens[i].numpy().tolist()} | {preds[i].numpy().tolist()}")
    print(f"{len(mispred_inds)}/{len(preds)} ({len(mispred_inds) / len(preds) :.2%})")

## Training

### Model

In [8]:
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=N_CTX,
    d_vocab=D_VOCAB,
    act_fn="relu",
    seed=42,
    device=DEVICE,
    attn_only=ATTN_ONLY
)
model = HookedTransformer(cfg, move_to_device=True)

### Training setup

In [9]:
@dataclass(frozen=True)
class TrainingHistory:
    losses: list[float]
    train_accuracies: list[float]
    val_accuracies: list[float]

def converged(val_accs: list[float], n_last: int = 2) -> bool:
    if len(val_accs) < n_last:
        return False
    return len(set(tensor(val_accs[-n_last:]).round(decimals=4).tolist())) == 1

# Number of epochs
n_epochs = 20000

# Optimization
lr = 1e-3
betas = (.9, .999)
optim = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, "min", patience=100)

# Training data generator
train_gen = make_train_gen()

def train_model(model: HookedTransformer) -> TrainingHistory:
    losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(n_epochs):
        tokens = next(train_gen).to(device=DEVICE)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        optim.step()
        optim.zero_grad()
        scheduler.step(loss)
        
        if epoch % 100 == 0:
            losses.append(loss.item())
            train_batch_acc = acc_fn(logits, tokens)
            val_acc = validate(model, val_data)
            val_loss = loss_fn(model(val_data), val_data)

            train_accuracies.append(train_batch_acc)
            val_accuracies.append(val_acc)
            print(
                f"Epoch {epoch}/{n_epochs} ({epoch / n_epochs:.0%}) : "
                f"loss = {loss.item():.4f}; {train_batch_acc=:.3%}; "
                f"{val_acc=:.3%}; lr={scheduler._last_lr[0]}" #type:ignore
            )
            # If last 10 recorded val_accuracies are 100%
            if converged(val_accuracies):
                print(f"\nAchieved consistent perfect validation accuracy after {epoch} epochs")
                break
    return TrainingHistory(losses, train_accuracies, val_accuracies)

def load_model_state(model: HookedTransformer, filename: str) -> None:
    assert os.path.isdir("models"), "Make a directory `models` with model state dicts"
    if not filename.startswith("models/"):
        filename = f"models/{filename}"
    with open(filename, "rb") as f:
        state_dict = pickle.load(f)
    model.load_state_dict(state_dict)

### Train or load model

In [10]:
history = train_model(model)
# load_model_state(model, <filename>)

Epoch 0/20000 (0%) : loss = 4.4698; train_batch_acc=0.000%; val_acc=0.000%; lr=0.001
Epoch 100/20000 (0%) : loss = 0.1415; train_batch_acc=75.781%; val_acc=75.100%; lr=0.001
Epoch 200/20000 (1%) : loss = 0.0188; train_batch_acc=99.219%; val_acc=97.400%; lr=0.001
Epoch 300/20000 (2%) : loss = 0.0131; train_batch_acc=98.438%; val_acc=98.700%; lr=0.001
Epoch 400/20000 (2%) : loss = 0.0081; train_batch_acc=98.438%; val_acc=99.400%; lr=0.001
Epoch 500/20000 (2%) : loss = 0.0082; train_batch_acc=96.875%; val_acc=99.500%; lr=0.001
Epoch 600/20000 (3%) : loss = 0.0023; train_batch_acc=100.000%; val_acc=99.700%; lr=0.001
Epoch 700/20000 (4%) : loss = 0.0034; train_batch_acc=100.000%; val_acc=99.700%; lr=0.001

Achieved consistent perfect validation accuracy after 700 epochs


### Testing post-training

In [11]:
print("Validating on validation data:")
val_acc = validate(model, val_data)
print(f"\t{val_acc=:.3%}\n")
if val_acc < 1:
    show_mispreds(model, val_data)

print("\nValidating on test data:")
test_acc = validate(model, test_data)
print(f"\t{test_acc=:.3%}\n")
if test_acc < 1:
    show_mispreds(model, test_data)

Validating on validation data:
	val_acc=99.700%

[12] [2, 4, 38, 43, 45, 47, 49, 53, 54, 56] | [2, 4, 44, 43, 45, 47, 49, 53, 54, 56]
[97] [4, 5, 32, 33, 37, 43, 48, 50, 54, 60] | [4, 5, 33, 33, 37, 43, 48, 50, 54, 60]
[275] [6, 8, 9, 13, 14, 16, 23, 59, 61, 63] | [6, 8, 9, 13, 14, 16, 23, 23, 61, 63]
3/1000 (0.30%)

Validating on test data:
	test_acc=99.700%

[97] [1, 4, 5, 6, 12, 16, 21, 51, 57, 63] | [1, 4, 5, 6, 12, 16, 21, 57, 57, 63]
[360] [1, 3, 6, 10, 39, 42, 44, 53, 54, 61] | [1, 3, 6, 10, 42, 42, 44, 53, 54, 61]
[711] [0, 12, 13, 47, 51, 53, 54, 55, 57, 62] | [0, 12, 13, 43, 51, 53, 54, 55, 57, 62]
3/1000 (0.30%)


### Saving trained model

In [12]:
def save_model_state_dict(
    model: HookedTransformer, 
    filename: str | None = None
) -> None:
    if not os.path.isdir("models"):
        os.mkdir("models")
    if not filename:
        timestamp = dt.now().isoformat("T", "minutes").replace(":", "-")
        filename = f"model_state_dict_{timestamp}.pkl"
    with open(f"models/{filename}", "wb") as f:
        pickle.dump(model.state_dict(), f)

save_model_state_dict(model)

In [13]:
os.listdir("models")

['model_state_dict_2023-11-10T13-21.pkl']

## Investigate the model

### Attention patterns

In [14]:
# Get one input from test_data
test_input = test_data[3, :]

# Pass through model, get cache and predictions
logits, cache_model = model.run_with_cache(test_input, remove_batch_dim=True) 
preds = logits[:, LIST_LENGTH+1 : -1].argmax(-1)

# Get attention pattern and plot it
attention_pattern = cache_model["pattern", 0, "attn"]
tokens_input = list(map(str, test_input))
print(test_input)
print(preds)

cv.attention.attention_patterns(tokens=tokens_input, attention=attention_pattern)

tensor([64, 40, 33, 41, 29, 11, 43, 55,  3, 50, 16, 65,  3, 11, 16, 29, 33, 40,
        41, 43, 50, 55])
tensor([[ 3, 11, 16, 29, 33, 40, 41, 43, 50, 55]])
