<a target="_blank" href="https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Modular Subtraction Grokking Notebook

**Please ensure proper set-up as described in the `README.md` file.**



*   You must install the dependencies described in `requirements.txt`; only **some** of the dependencies are hardcoded here.
*   This notebook should ideally be run on Yale's McCleary or Grace HPC, as Plotly plots will not render correctly on Google Colab.
*   Please run this notebook using a GPU for efficiency.
*   This notebook is **not** compatible with Apple Silicon because of the float64 dtypes used for high precision.

This code is adapted from the work presented in the paper:
- Nanda, N., Chan, L., Lieberum, T., Smith, J., & Steinhardt, J. (2023). "Progress measures for grokking via mechanistic interpretability." arXiv preprint arXiv:2301.05217. Available at: https://arxiv.org/abs/2301.05217

# Setup

In [None]:
# pre-requisites
#!pip install ipython==8.12.3
!pip install datasets==2.19.1
!pip install einops==0.8.0
!pip install fancy_einsum==0.0.3
!pip install matplotlib==3.7.1
!pip install numpy==1.25.2
!pip install pandas==2.0.3
!pip install plotly==5.15.0
!pip install protobuf==3.20.3
!pip install torch==2.1.2
!pip install tqdm==4.66.4
!pip install transformer_lens==1.17.0
!pip install transformers==4.40.2

!pip install kaleido==0.2.1
!pip install git+https://github.com/neelnanda-io/neel-plotly.git

Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-_34ra39x
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-_34ra39x
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... done

In [None]:
# standard library imports
import copy
import dataclasses
from datetime import datetime
import itertools
import os
import random
import warnings
from functools import partial
from pathlib import Path
from typing import List, Optional, Union

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# third party imports
import datasets
import einops
from neel_plotly.plot import line, imshow, melt
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm.auto as tqdm
import transformer_lens
import transformer_lens.utils as utils
from fancy_einsum import einsum
from IPython import get_ipython
from IPython.display import HTML
from plotly.offline import init_notebook_mode, iplot
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_lens.hook_points import HookPoint, HookedRootModule
from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformer, HookedTransformerConfig

In [None]:
# Define constants
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")

p = 113                    # prime
TRAIN_FRAC = 0.3           # fraction of data used for training
LEARNING_RATE = 1e-3       # learning rate for the optimizer
WEIGHT_DECAY = 1.0         # weight decay for regularization
BETAS = (0.9, 0.98)        # betas for optimizer
N_EPOCHS = 25000           # number of training epochs
CHECKPOINT_EVERY = 100     # interval at which to save model checkpoints
SEED = 598                 # seed for data shuffling for reproducibility

pio.renderers.default = "notebook_connected"

try:
    import google.colab
    IN_COLAB = True
    pio.renderers.default = "colab"
    print("Running in a Google Colab environment")

    # required packages for Colab
    packages = ["transformer-lens", "circuitsvis",
                "git+https://github.com/neelnanda-io/PySvelte.git"]
    for package in packages:
        get_ipython().run_line_magic('pip', f'install {package}')

    # Node.js for PySvelte
    get_ipython().system_raw("curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -")
    get_ipython().system_raw("sudo apt-get install -y nodejs")
except ImportError:
    IN_COLAB = False
    print("Running in a Jupyter notebook")

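    ipython = get_ipython()
    if ipython is not None:  # None when not running under IPython
        # run_line_magic replaces the deprecated InteractiveShell.magic
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")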

init_notebook_mode(connected=True)

# Update Plotly layout defaults
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

Device: cuda
Running in a Google Colab environment
Collecting git+https://github.com/neelnanda-io/PySvelte.git
  Cloning https://github.com/neelnanda-io/PySvelte.git to /tmp/pip-req-build-yqq25upp
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/PySvelte.git /tmp/pip-req-build-yqq25upp
  Resolved https://github.com/neelnanda-io/PySvelte.git to commit 582d85ff708947e72b35cfcca05641332b44f5f5
  Preparing metadata (setup.py) ... done

## Data Preparation

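Input format: each example is the three-token sequence `|a|b|=|`, where `a` and `b` range over 0, ..., p-1 and `=` is encoded as the extra token `p` (= 113). The label is (a - b) mod p.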

In [None]:
# Flattened vector of length p*p: [0]*p + [1]*p + ... + [p-1]*p (the a values)
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)

# Flattened vector of length p*p: [0, 1, ..., p-1] tiled p times (the b values)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)

# The "=" token, encoded as index p (one past the largest number token)
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)

# Stack the tensors to form a combined dataset
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(DEVICE)

print("First 5 entries of the dataset:")
print(dataset[:5])

First 5 entries of the dataset:
tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113],
        [  0,   4, 113]], device='cuda:0')


In [None]:
# Compute labels using modular subtraction
labels = (dataset[:, 0] - dataset[:, 1]) % p

print("First 5 labels:", labels[:5].tolist())
print("Last 5 labels:", labels[-5:].tolist())

First 5 labels: [0, 112, 111, 110, 109]
Last 5 labels: [4, 3, 2, 1, 0]
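
The wrap-around in these labels comes from `%` with a positive modulus: in Python, and in PyTorch's tensor `%` (which follows `torch.remainder`), the result takes the sign of the divisor, so 0 - 1 becomes 112 rather than -1. Below is a plain-Python check of the printed values. It assumes the row-major ordering (row index = a·p + b) used to build `dataset`.

```python
p = 113

# Row k of the dataset holds (a, b) = (k // p, k % p); the label is (a - b) % p.
# Python's % takes the sign of the divisor, so negative differences wrap into [0, p).
first_five = [(0 - b) % p for b in range(5)]            # a = 0, b = 0..4
last_five = [(p - 1 - b) % p for b in range(p - 5, p)]  # a = 112, b = 108..112

print(first_five)  # [0, 112, 111, 110, 109]
print(last_five)   # [4, 3, 2, 1, 0]
```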


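We split the data into training and test sets with a 30/70 split: a random 30% of the p² = 12,769 (a, b) pairs is used for training (`TRAIN_FRAC = 0.3`), and the remaining 70% is used for testing.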
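
The split sizes follow directly from `TRAIN_FRAC`. Note that `int()` truncates rather than rounds, so the training set gets 3,830 examples rather than 3,831. Here is a plain-Python sketch of the arithmetic performed in the next cell:

```python
p = 113
TRAIN_FRAC = 0.3

total_size = p * p                     # 12769 (a, b) pairs
cutoff = int(total_size * TRAIN_FRAC)  # int() truncates 3830.7 to 3830
n_test = total_size - cutoff

print(total_size, cutoff, n_test)      # 12769 3830 8939
```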

In [None]:
# Set a fixed seed to ensure reproducible results
torch.manual_seed(SEED)

# Generate a random permutation of indices for splitting the data
total_size = p * p
indices = torch.randperm(total_size)

# Calculate the cutoff index for training and testing data split
cutoff = int(total_size * TRAIN_FRAC)

# Split indices into training and testing sets
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

# Index into the original dataset to create training and testing subsets
train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]

# Print the first few entries of the training data and labels to verify
print("First 5 training data samples:\n", train_data[:5])
print("\nFirst 5 training labels:\n", train_labels[:5].tolist())
print("\nTraining data shape:\n", train_data.shape)

# Print the first few entries of the testing data and labels to verify
print("\nFirst 5 testing data samples:\n", test_data[:5])
print("\nFirst 5 testing labels:\n", test_labels[:5].tolist())
print("\nTesting data shape:\n", test_data.shape)

First 5 training data samples:
 tensor([[ 21,  31, 113],
        [ 30,  98, 113],
        [ 47,  10, 113],
        [ 86,  21, 113],
        [ 99,  83, 113]], device='cuda:0')

First 5 training labels:
 [103, 45, 37, 65, 16]

Training data shape:
 torch.Size([3830, 3])

First 5 testing data samples:
 tensor([[ 43,  40, 113],
        [ 31,  42, 113],
        [ 39,  63, 113],
        [ 35,  61, 113],
        [112, 102, 113]], device='cuda:0')

First 5 testing labels:
 [3, 102, 89, 87, 10]

Testing data shape:
 torch.Size([8939, 3])


## Model Configuration

In [None]:
cfg = HookedTransformerConfig(
    n_layers=1,               # Number of transformer layers
    n_heads=4,                # Number of attention heads in each transformer layer
    d_model=128,              # Width of the residual stream
    d_head=32,                # Dimension of each attention head
    d_mlp=512,                # Hidden width of the MLP (number of neurons)
    act_fn="relu",            # Activation function
    normalization_type=None,  # No LayerNorm, which simplifies interpretation
    d_vocab=p + 1,            # Input vocabulary: numbers 0..p-1 plus the "=" token
    d_vocab_out=p,            # Output vocabulary: answers 0..p-1
    n_ctx=3,                  # Sequence length: [a, b, =]
    init_weights=True,        # Whether to initialize weights
    device=DEVICE,            # Specify the computation device
    seed=999                  # Random seed for reproducibility
)

model = HookedTransformer(cfg)

# Freeze all bias parameters so they stay at their initial values (zeros in
# HookedTransformer), which simplifies interpretation
for name, param in model.named_parameters():
    if "b_" in name:  # bias parameters: b_Q, b_K, b_V, b_O, b_in, b_out, b_U
        param.requires_grad = False  # no gradients, so the optimizer leaves them unchanged

## Optimizer and Loss Configuration

In [None]:
# Initialize the optimizer with the specified parameters and settings
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=BETAS)

def loss_fn(logits, labels):
    """
    Calculate the cross-entropy loss between logits from a model and provided labels.
    """

    # Keep only the logits at the final ("=") position: [batch, seq, vocab] -> [batch, vocab]
    if len(logits.shape) == 3:
        logits = logits[:, -1]

    logits = logits.to(torch.float64)  # float64 for numerical stability
    log_probs = logits.log_softmax(dim=-1)  # Apply log_softmax to logits
    correct_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))[:, 0]  # Gather correct log probabilities using labels
    return -correct_log_probs.mean()  # Return negative log likelihood mean as the loss

Baseline losses before training, compared with the uniform loss log(p) of a model that guesses uniformly among the p answers.

In [None]:
# Calculate and print losses for both training and testing datasets
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(f"Training Loss: {train_loss.item()}")

test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(f"Testing Loss: {test_loss.item()}")

# Print theoretical uniform loss for comparison
uniform_loss = np.log(p)  # Calculate uniform loss as the log of vocabulary size
print("Uniform loss:", uniform_loss)

Training Loss: 4.732243928451223
Testing Loss: 4.7345742216589235
Uniform loss: 4.727387818712341
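
The uniform loss is the cross-entropy of a prediction that assigns probability 1/p to every answer: -log(1/p) = log p = log 113 ≈ 4.727. Both baseline losses sit just above this value, so the freshly initialized model is essentially guessing. The sketch below is a scalar plain-Python re-implementation written for illustration. It is not the notebook's tensor-based `loss_fn`.

```python
import math

def cross_entropy(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    m = max(logits)  # subtract the max before exponentiating, for stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

p = 113
uniform_logits = [0.0] * p  # equal logits -> uniform distribution over p answers
print(cross_entropy(uniform_logits, label=0), math.log(p))  # both ≈ 4.7274
```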


## Model Training

We use full-batch gradient descent (every step uses all 3,830 training examples) rather than mini-batch stochastic gradient descent. This makes training smoother and reduces the number of "slingshots" (sudden spikes in the loss). The **training loop** is in the next cell.

In [None]:
train_losses = []
test_losses = []

model_checkpoints = []
checkpoint_epochs = []

for epoch in tqdm.tqdm(range(N_EPOCHS)):

    # compute logits and loss for training data
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()  # perform backprop
    train_losses.append(train_loss.item())  # store train loss

    optimizer.step()  # update model parameters
    optimizer.zero_grad()  # reset gradients

    # eval model on test data w/o computing gradients
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())  # store test loss

    # save checkpoint
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch+1}: Train Loss {train_loss.item()}, Test Loss {test_loss.item()}")


Epoch 100: Train Loss 3.022262011050979, Test Loss 7.554872957277325
Epoch 200: Train Loss 0.05346864657062493, Test Loss 20.926745070063976
Epoch 300: Train Loss 0.01317649258177606, Test Loss 22.764766772392907
Epoch 400: Train Loss 0.004204733923955457, Test Loss 24.368866459746865
Epoch 500: Train Loss 0.001393430905065373, Test Loss 26.073257790197772
Epoch 600: Train Loss 0.0004701875809480826, Test Loss 27.844731972795806
Epoch 700: Train Loss 0.0001611024743911674, Test Loss 29.637408455160184
Epoch 800: Train Loss 5.5943680009892304e-05, Test Loss 31.421438840528
Epoch 900: Train Loss 1.9781119656853215e-05, Test Loss 33.188836290868174
Epoch 1000: Train Loss 7.225755003799071e-06, Test Loss 34.90650942383287
Epoch 1100: Train Loss 2.8143193040890922e-06, Test Loss 36.482072760803696
Epoch 1200: Train Loss 1.239562522633948e-06, Test Loss 37.828367319324805
Epoch 1300: Train Loss 6.625529458988232e-07, Test Loss 38.78269176957883
Epoch 1400: Train Loss 4.489746084676601e-07, T

# Model Analysis

In [None]:
line([train_losses[::100],
      test_losses[::100]],
     x=np.arange(0, len(train_losses), 100),
     xaxis="Epoch",
     yaxis="Loss",
     log_y=True,
     title="Training Curve for Modular Subtraction",
     line_labels=['train', 'test'],
     toggle_x=True,
     toggle_y=True)

In [None]:
original_logits, cache = model.run_with_cache(dataset)

# Printing total number of elements in the original_logits tensor
print("Number of elements in logits:", original_logits.numel())

# Token embeddings for the numbers 0..p-1 (the last row embeds the "=" token and is dropped)
W_E = model.embed.W_E[:-1]
print("W_E shape:", W_E.shape)

# Per-head OV circuit composed with the MLP input weights: how each number token's
# embedding, if attended to, feeds each neuron (ignores the attention weights and
# positional embeddings). Shape: [head, token, neuron]
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur shape:", W_neur.shape)

# Neuron-to-logit map: each neuron's direct contribution to the output logits
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit shape:", W_logit.shape)

Number of elements in logits: 4328691
W_E shape: torch.Size([113, 128])
W_neur shape: torch.Size([4, 113, 512])
W_logit shape: torch.Size([512, 113])


In [None]:
# Calculating and printing the loss for the original logits
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

Original Loss: 1.2152574445457206e-07


## Attention Heads Analysis

In [None]:
# Attention paid by the final "=" position (query -1) to a (key 0) and b (key 1), for every head
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]  # [batch, head]: attention from "=" to a
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]  # [batch, head]: attention from "=" to b
print("pattern_a shape:", pattern_a.shape)
print("pattern_b shape:", pattern_b.shape)

# Extract MLP layer's post-activations and pre-activations at the last position
neuron_acts = cache["post", 0, "mlp"][:, -1, :]  # Post-activations
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]  # Pre-activations
print("neuron_acts shape:", neuron_acts.shape)
print("neuron_pre_acts shape:", neuron_pre_acts.shape)

# Print the shapes of all cached items to understand what data is being stored
for param_name, param in cache.items():
    print(param_name, param.shape)

pattern_a shape: torch.Size([12769, 4])
pattern_b shape: torch.Size([12769, 4])
neuron_acts shape: torch.Size([12769, 512])
neuron_pre_acts shape: torch.Size([12769, 512])
hook_embed torch.Size([12769, 3, 128])
hook_pos_embed torch.Size([12769, 3, 128])
blocks.0.hook_resid_pre torch.Size([12769, 3, 128])
blocks.0.attn.hook_q torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_k torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_v torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_attn_scores torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_pattern torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_z torch.Size([12769, 3, 4, 32])
blocks.0.hook_attn_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_mid torch.Size([12769, 3, 128])
blocks.0.mlp.hook_pre torch.Size([12769, 3, 512])
blocks.0.mlp.hook_post torch.Size([12769, 3, 512])
blocks.0.hook_mlp_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_post torch.Size([12769, 3, 128])


In [None]:
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :],
       title="Average Attention Pattern per Head",
       xaxis="Source",
       yaxis="Head",
       x=['a', 'b', '='])

In [None]:
imshow(cache["pattern", 0][5][:, -1, :],
       title="Attention Pattern per Head for Example 5 (a=0, b=5)",
       xaxis="Source",
       yaxis="Head",
       x=['a', 'b', '='])

In [None]:
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p),
       title="Attention for Head 0 from a -> =",
       xaxis="b",
       yaxis="a")

In [None]:
imshow(einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
       title="Attention from a -> = for Each Head",
       xaxis="b",
       yaxis="a",
       facet_col=0)

## Singular Value Decomposition

In [None]:
W_E.shape

torch.Size([113, 128])

First, as a baseline/control, we compute the SVD of a RANDOM Gaussian matrix with the same shape as `W_E`. Notice how the singular values decrease roughly linearly, spreading over a wide range with no sharp cutoff.

In [None]:
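# CONTROL - SVD of a RANDOM Gaussian matrix with the same shape as W_E
# (torch.linalg.svd returns U, S, Vh; the deprecated torch.svd returns V, not Vh)
U, S, Vh = torch.linalg.svd(torch.randn_like(W_E), full_matrices=False)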

line(S, title="Singular Values for Random Gaussian Matrix")
imshow(U, title="Principal Components for Random Gaussian Matrix")

In contrast, the SVD of the trained embedding matrix `W_E` yields only ~6 nontrivial singular values!

In [None]:
# SVD of the trained embedding matrix
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)

line(S, title="Singular Values for Embedding Matrix")
imshow(U, title="Principal Components on the Input")

We plot the first 8 principal components (the first 8 columns of `U`) as functions of the input token.

In [None]:
# Extract the first 8 principal components from U
principal_components = U[:, :8].T  # transpose to make each row a PC
line(principal_components,
     title="First 8 Principal Components of the Embedding",
     xaxis='''Input "Vocabulary"''')

## Fourier Basis Analysis

In [None]:
fourier_basis = []
fourier_basis_names = []

# append a constant basis vector
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")

# generate sine and cosine basis vectors
for freq in range(1, p//2+1):
    # append sine components
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")

    # append cosine components
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")

fourier_basis = torch.stack(fourier_basis, dim=0).to(DEVICE)

# normalize each basis vector to have unit norm
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)

imshow(fourier_basis,
       title="Fourier Basis Components (2D)",
       xaxis="Input",
       yaxis="Component",
       y=fourier_basis_names)

Below, we plot slices of the standard Fourier basis for p=113.

In [None]:
# plot the first 5 Fourier components
line(fourier_basis[:5],
     xaxis="Input",
     line_labels=fourier_basis_names[:5],
     title="First 5 Fourier Components")

# plot middle range Fourier components
line(fourier_basis[57:61],
     xaxis="Input",
     line_labels=fourier_basis_names[57:61],
     title="Middle Fourier Components")

# plot last 2 Fourier components
line(fourier_basis[111:113],
     xaxis="Input",
     line_labels=fourier_basis_names[111:113],
     title="Last Fourier Components")

Below, we verify that the normalized Fourier basis vectors are orthonormal: the matrix of pairwise dot products is the identity.

In [None]:
imshow(fourier_basis @ fourier_basis.T,
       title="All Fourier Vectors are Orthogonal")
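
The dot-product matrix is the identity because p is odd. With no Nyquist (k = p/2) term, the constant vector and the p - 1 normalized sine/cosine vectors form an orthonormal basis of p dimensions. The pure-Python sketch below rebuilds the same basis for a small odd prime. The helper `make_fourier_basis` and the name `p_small` are illustrative and chosen so they do not overwrite the notebook's `fourier_basis` and `p`.

```python
import math

def make_fourier_basis(n):
    """Constant, then (Sin k, Cos k) for k = 1..n//2, each scaled to unit norm."""
    rows = [[1.0] * n]
    for k in range(1, n // 2 + 1):
        rows.append([math.sin(2 * math.pi * k * x / n) for x in range(n)])
        rows.append([math.cos(2 * math.pi * k * x / n) for x in range(n)])
    return [[v / math.sqrt(sum(u * u for u in row)) for v in row] for row in rows]

p_small = 13  # small odd prime keeps this check fast
basis = make_fourier_basis(p_small)
gram = [[sum(x * y for x, y in zip(r, s)) for s in basis] for r in basis]
max_err = max(abs(gram[i][j] - (1.0 if i == j else 0.0))
              for i in range(len(basis)) for j in range(len(basis)))
print(len(basis), max_err)  # 13 rows; the error is at floating-point noise level
```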

## Projecting the Embedding Matrix onto the Fourier Basis (to observe frequency)

After projecting the embedding matrix `W_E` onto the Fourier basis, we observe that only **6** of the 113 Fourier components are used meaningfully! The embedding is sparse in the Fourier basis and therefore low rank.

In [None]:
# Compute the projection of the embedding weights onto the Fourier basis
projection = fourier_basis @ W_E

# heatmap where each cell represents the interaction strength between a
# Fourier component and a dimension of the embedding space
imshow(projection,
       yaxis="Fourier Component",
       xaxis="Residual Stream",
       y=fourier_basis_names,
       title="Embedding in Fourier Basis")

Next, we compute the norm of each row of the projection, which measures how strongly each Fourier component is represented in the embedding. The most influential components are **sin(1), cos(1), sin(5), cos(5), sin(34), and cos(34)**.

In [None]:
# calculate the norm of each projection vector
projection_norms = projection.norm(dim=-1)

# Plot the norms of these projections to understand which Fourier components
# have the strongest influence in the embedding space.
line(projection_norms,
     xaxis="Fourier Component",
     x=fourier_basis_names,
     title="Norms of Embedding in Fourier Basis")

Below, we explicitly record the key frequencies used in the rest of the analysis. The projections onto all other frequencies are close to zero, so we discard them.
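
The hard-coded indices in the next cell follow from the ordering of `fourier_basis`: index 0 is the constant term, and frequency k occupies indices 2k - 1 (Sin k) and 2k (Cos k). The helper below is hypothetical and not used elsewhere in the notebook. It derives these indices from a list of frequencies.

```python
def freq_to_basis_indices(freqs):
    """Rows of the Fourier basis are ordered [Constant, Sin 1, Cos 1, Sin 2, Cos 2, ...],
    so frequency k sits at index 2k - 1 (sine) and 2k (cosine)."""
    return [idx for k in freqs for idx in (2 * k - 1, 2 * k)]

print(freq_to_basis_indices([1, 5, 34]))  # [1, 2, 9, 10, 67, 68]
```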

In [None]:
# key frequencies
key_freqs = [1, 5, 34]

# Row indices in the Fourier basis: index 0 is the constant term, and frequency k
# has Sin k at index 2k - 1 and Cos k at index 2k
key_freq_indices = [1, 2, 9, 10, 67, 68]  # Sin/Cos for frequencies 1, 5, and 34

# extract projections corresponding to the key frequencies
key_fourier_embed = projection[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)  # shape for verification

# dot product of the key Fourier embeddings with themselves
# represents the interactions between these key Fourier components
imshow(key_fourier_embed @ key_fourier_embed.T,
       title="Dot Product of Embedding of Key Fourier Terms")

key_fourier_embed torch.Size([6, 128])


## Key Frequencies Visualization

In [None]:
line(fourier_basis[[2, 10, 68]],
     title="Cos of Key Frequencies",
     line_labels=[fourier_basis_names[i] for i in [2, 10, 68]])

line(fourier_basis[[1, 9, 67]],
     title="Sin of Key Frequencies",
     line_labels=[fourier_basis_names[i] for i in [1, 9, 67]])

### Neuron Clusters

In [None]:
# 2D Fourier transform of each neuron's activations over the (a, b) input grid
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T

# zero the constant (0, 0) coefficient, i.e. remove each neuron's mean activation
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

fourier_neuron_acts torch.Size([512, 113, 113])
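
Applied to a neuron's p × p activation table A[a, b], the product `fourier_basis @ A @ fourier_basis.T` is a 2D change of basis. Entry (i, j) measures how much of the activation varies like basis vector i along a and basis vector j along b. For example, a toy table cos(w(a - b)) = cos(wa)cos(wb) + sin(wa)sin(wb) has only two nonzero coefficients: (Cos k, Cos k) and (Sin k, Sin k). The pure-Python sketch below uses a hypothetical small prime (11) and frequency (3), not the trained model's activations.

```python
import math

def make_fourier_basis(n):
    """Constant, then (Sin k, Cos k) for k = 1..n//2, each scaled to unit norm."""
    rows = [[1.0] * n]
    for k in range(1, n // 2 + 1):
        rows.append([math.sin(2 * math.pi * k * x / n) for x in range(n)])
        rows.append([math.cos(2 * math.pi * k * x / n) for x in range(n)])
    return [[v / math.sqrt(sum(u * u for u in row)) for v in row] for row in rows]

def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

n_toy, k_toy = 11, 3                   # hypothetical small prime and frequency
w = 2 * math.pi * k_toy / n_toy
# Toy "neuron": activation over the (a, b) grid equal to cos(w (a - b))
acts = [[math.cos(w * (x - y)) for y in range(n_toy)] for x in range(n_toy)]

F = make_fourier_basis(n_toy)
F_T = [list(col) for col in zip(*F)]
coeffs = matmul(matmul(F, acts), F_T)  # same operation as fourier_basis @ A @ fourier_basis.T

nonzero = [(i, j) for i in range(n_toy) for j in range(n_toy) if abs(coeffs[i][j]) > 1e-9]
print(nonzero)  # [(5, 5), (6, 6)], i.e. (Sin 3, Sin 3) and (Cos 3, Cos 3)
```

The per-frequency score in the next cell sums squared coefficients over the constant, Sin k, and Cos k rows and columns. For this toy table, it would therefore attribute all of the (mean-removed) variance to frequency 3.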


Observe that each neuron's variance is largely explained by a few key frequencies.

In [None]:
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(DEVICE)
for freq in range(0, p//2):
    # include the constant term and the Sin/Cos components of frequency freq + 1 along each axis (a and b)
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]

# show what fraction of the neuron's variance is explained by each frequency.
imshow(neuron_freq_norm,
       xaxis="Neuron",
       yaxis="Freq",
       y=torch.arange(1, p//2+1),
       title="Neuron Frac Explained by Freq")

In [None]:
line(neuron_freq_norm.max(dim=0).values.sort().values,
     xaxis="Neuron",
     title="Max Neuron Frac Explained over Freqs")

## Neuron-Logit Weights Analysis

In [None]:
# Neuron-to-logit map: MLP output weights composed with the unembedding, giving each
# neuron's direct contribution to the p output logits (recomputed here for convenience)
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_logit torch.Size([512, 113])


Projecting the neuron-to-logit matrix `W_logit` onto the Fourier basis (along its output dimension) reveals that the model's outputs consist of linear combinations of sines and cosines of the key frequencies.

In [None]:
line((W_logit @ fourier_basis.T).norm(dim=0),
     x=fourier_basis_names,
     title="W_logit in the Fourier Basis")

# Black Box Methods & Progress Measures

## Setup Code

In [None]:
def test_logits(logits, bias_correction=False, original_logits=None, mode="all"):
    """
    Compute the loss of a set of logits over the (a, b) inputs.

    `logits` may have shape [p*p, p], [p*p, p+1], their transposes, or [p, p, p].
    If `bias_correction` is True, the mean difference between `original_logits`
    and `logits` (averaged over inputs) is added back before computing the loss.
    `mode` selects the "train", "test", or "all" subset of inputs.
    """
    # Put inputs along the first axis
    if logits.shape[1] == p * p:
        logits = logits.T
    # Drop the extra last column if present
    if logits.shape == torch.Size([p * p, p + 1]):
        logits = logits[:, :-1]

    # One row per (a, b) input, one column per candidate answer c
    logits = logits.reshape(p * p, p)

    if bias_correction:
        if original_logits is None:
            raise ValueError("Original logits must be provided for bias correction.")
        original_logits = original_logits.reshape(p * p, p)
        # Mean over inputs of (original - current), shape [p], added back to every row.
        # Out-of-place addition so the caller's tensor is not modified.
        mean_difference = einops.reduce(original_logits - logits, "batch ... -> ...", "mean")
        logits = logits + mean_difference

    # Compute the loss on the requested subset
    if mode == "train":
        return loss_fn(logits[train_indices], labels[train_indices])
    elif mode == "test":
        return loss_fn(logits[test_indices], labels[test_indices])
    elif mode == "all":
        return loss_fn(logits, labels)
    raise ValueError(f"Unknown mode: {mode!r}")

In [None]:
metric_cache = {}

def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    """
    Evaluate `metric_fn(model)` at every saved checkpoint and store the results in
    `metric_cache[name]` as a tensor whose first axis indexes checkpoints.
    Results are cached: the metric is recomputed only if `reset` is True or no
    results exist yet. The final checkpoint is reloaded afterwards.
    """
    if reset or (name not in metric_cache) or (len(metric_cache[name]) == 0):
        metric_cache[name] = []

        for sd in tqdm.tqdm(model_checkpoints):
            model.reset_hooks()
            model.load_state_dict(sd)
            with torch.inference_mode():  # no autograd graph is needed for evaluation
                out = metric_fn(model)

            if isinstance(out, torch.Tensor):
                out = utils.to_numpy(out)

            metric_cache[name].append(out)

        # restore the final trained weights
        model.load_state_dict(model_checkpoints[-1])

        try:
            metric_cache[name] = torch.tensor(metric_cache[name])
        except TypeError:  # Handle cases where the conversion fails
            metric_cache[name] = torch.tensor(np.array(metric_cache[name]))

## Defining Progress Measures

### Loss Curves

These phase boundaries (end of memorization, end of circuit formation, and end of cleanup) are estimated by eye from the loss curve and are used only to annotate plots.

In [None]:
memorization_end_epoch = 1500
circuit_formation_end_epoch = 22500
cleanup_end_epoch = 24000

In [None]:
def add_lines(figure):
    figure.add_vline(memorization_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(circuit_formation_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(cleanup_end_epoch, line_dash="dash", opacity=0.7)
    return figure

In [None]:
fig = line([train_losses[::100],
            test_losses[::100]],
            x=np.arange(0, len(train_losses), 100),
            xaxis="Epoch",
            yaxis="Loss",
            log_y=True,
            title="Training Curve for Modular Subtraction",
            line_labels=['train', 'test'],
            toggle_x=True,
            toggle_y=True,
            return_fig=True)
add_lines(fig)

### Logit Periodicity

In [None]:
all_logits = original_logits[:, -1, :]
print(all_logits.shape)

all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

torch.Size([12769, 113])
torch.Size([113, 113, 113])


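Following our hypothesized formula, the logit for a candidate answer c should be approximately a weighted sum of `cos(w_k (a - b - c))` over the key frequencies k, where `w_k = 2πk / p`. Below, we construct these predicted logit cubes over the full (a, b, c) grid, each normalized to unit norm.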

In [None]:
coses = {}
for freq in key_freqs:
    print("Freq:", freq)
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a - b - c)).to(DEVICE)
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

Freq: 1
Freq: 5
Freq: 34
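
Why should these cubes compute subtraction? For every key frequency k, `cos(w_k (a - b - c))` reaches its maximum of 1 exactly when c ≡ a - b (mod p), because p is prime and k is nonzero mod p. At any other c, every term is below 1, so the sum over the key frequencies peaks at the correct answer. The plain-Python sketch below computes this argmax with equal weights on each frequency. The equal weighting is an assumption made for illustration; the model's actual weights differ.

```python
import math
from typing import Sequence

def predicted_answer(a: int, b: int, p: int, freqs: Sequence[int]) -> int:
    """argmax over c of sum_k cos(2*pi*k*(a - b - c) / p), equal weight per frequency."""
    def score(c: int) -> float:
        return sum(math.cos(2 * math.pi * k * (a - b - c) / p) for k in freqs)
    return max(range(p), key=score)

# Compare against (a - b) mod p for a few inputs from the dataset
for a_val, b_val in [(0, 1), (21, 31), (47, 10), (112, 112)]:
    print(a_val, b_val, predicted_answer(a_val, b_val, 113, [1, 5, 34]), (a_val - b_val) % 113)
```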


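The cell below projects the model's logits onto each normalized cosine cube. Observe that the cosine similarities are substantial (0.4242, 0.7101, and 0.2435 for frequencies 1, 5, and 34). The residual left after subtracting these three components has 0.5065 of the original norm.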

In [None]:
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)

    # represents how much of `all_logits` is in the direction of the cosine pattern.
    coeff = (all_logits * coses[freq]).sum()
    print("Coeff:", coeff)

    # normalized measure of how much the cosine pattern aligns with `all_logits`.
    cosine_sim = coeff / all_logits.norm()
    print("Cosine Sim:", cosine_sim, "\n\n")

    # builds approx by adding contributions from each significant cosine pattern.
    approximated_logits += coeff * coses[freq]

residual = all_logits - approximated_logits
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.norm()/all_logits.norm())

Freq: 1
Coeff: tensor(18928.9023, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.4242, device='cuda:0', grad_fn=<DivBackward0>) 


Freq: 5
Coeff: tensor(31688.7910, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.7101, device='cuda:0', grad_fn=<DivBackward0>) 


Freq: 34
Coeff: tensor(10867.9258, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.2435, device='cuda:0', grad_fn=<DivBackward0>) 


Residual size: tensor(22602.7246, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Residual fraction of norm: tensor(0.5065, device='cuda:0', grad_fn=<DivBackward0>)
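
These printed numbers are consistent with one another. Each `coses[freq]` cube has unit norm, and cubes for distinct frequencies between 1 and p//2 are mutually orthogonal over the full (a, b, c) grid. By Pythagoras, the residual fraction should therefore equal sqrt(1 - Σ cos_sim²). The check below plugs in the rounded similarities:

```python
import math

# Cosine similarities printed above for frequencies 1, 5 and 34 (rounded to 4 d.p.)
cos_sims = [0.4242, 0.7101, 0.2435]

explained = sum(s * s for s in cos_sims)  # fraction of the squared norm explained
residual_fraction = math.sqrt(1 - explained)
print(round(explained, 4), round(residual_fraction, 4))  # 0.7435 0.5065
```

The three key frequencies alone account for roughly 74% of the squared norm of the logits.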


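As a control, the cosine similarity between the logits and a random Gaussian cube of the same shape is essentially zero (-0.0005 below). The similarities above are therefore far from what chance would produce, which corroborates the strength of our approximation.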

In [None]:
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum()/random_logit_cube.norm()/all_logits.norm())

tensor(-0.0005, device='cuda:0', grad_fn=<DivBackward0>)
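
How small is "essentially zero"? For a random Gaussian vector in n dimensions, the cosine similarity with a fixed vector is approximately normal with standard deviation 1/sqrt(n). Here n = p³, the number of entries in the (a, b, c) logit cube:

```python
import math

p = 113
n = p ** 3                           # entries in the (a, b, c) logit cube
typical_random_sim = 1 / math.sqrt(n)
print(n, typical_random_sim)         # 1442897 and roughly 8.3e-04
```

The observed -0.0005 is within about one such standard deviation, as chance predicts. By contrast, 0.2435 (the smallest key-frequency similarity) is roughly 290 standard deviations away.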


Replacing the model's logits with the three-frequency approximation keeps the loss tiny: 7.6e-07, versus 1.2e-07 for the true logits (both computed below). This supports our hypothesized formula.

In [None]:
test_logits(all_logits)

tensor(1.2153e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)

In [None]:
test_logits(approximated_logits)

tensor(7.6232e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)

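We now extend this analysis across training. We build normalized cosine cubes for all p//2 = 56 frequencies and evaluate the same projections at every saved checkpoint.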

In [None]:
cos_cube = []
for freq in range(1, p//2 + 1):
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]

    # calculate cosine values across the 3D grid.
    # cosine wave pattern is based on the formula: cos(freq * 2 * pi / p * (a - b - c))
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a - b - c)).to(DEVICE)

    # Normalize the cosine pattern to have unit norm.
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

torch.Size([56, 113, 113, 113])
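
Note that `cos_cube` is large: it holds 56 × 113³ values on `DEVICE`. The size estimate below assumes float32, PyTorch's default floating dtype, which is what the cell above produces from a float scalar times an integer grid. Each metric function in the following cells also builds a temporary `cos_cube * logits[None]` of the same size, so GPU memory can become a constraint.

```python
p = 113
n_freqs = p // 2                # frequencies 1..56
bytes_per_element = 4           # float32

n_elements = n_freqs * p ** 3
size_bytes = n_elements * bytes_per_element
print(n_elements, size_bytes / 1e6, size_bytes / 2**20)  # ≈ 80.8M elements, ≈ 323 MB (≈ 308 MiB)
```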


In [None]:
def get_cos_coeffs(model):
    """
    Calculate the coefficients of cosine wave patterns for the model's output logits.
    """
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)

    # projects the logits onto the space defined by each cosine pattern,
    # effectively measuring how much each pattern is represented in the logits.
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals

get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print("Cached cosine coefficients shape:", metric_cache["cos_coeffs"].shape)


Cached cosine coefficients shape: torch.Size([250, 56])


In [None]:
def get_cos_sim(model):
    """
    Calculate cosine similarity between the model's output logits and predefined cosine patterns.
    """
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)

    # Project the logits onto each cosine pattern, summing over the (a, b, c) dimensions.
    # Each pattern has unit norm, so dividing by the logit norm gives a cosine similarity.
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals / logits.norm()

get_metrics(model, metric_cache, get_cos_sim, "cos_sim")
print(metric_cache["cos_sim"].shape)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250, 56])
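Because each pattern in `cos_cube` has unit norm, dividing the projection coefficient by `logits.norm()` (as `get_cos_sim` does) yields the ordinary cosine similarity. A small NumPy check, assuming random stand-in logits and a toy prime `p_small = 7` rather than the notebook's model outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
p_small = 7  # toy prime
a = np.arange(p_small)[:, None, None]
b = np.arange(p_small)[None, :, None]
c = np.arange(p_small)[None, None, :]

# One unit-norm cosine pattern (freq = 2) and random stand-in "logits".
pattern = np.cos(2 * 2 * np.pi / p_small * (a - b - c))
pattern /= np.linalg.norm(pattern)
logits = rng.normal(size=(p_small, p_small, p_small))

# What get_cos_sim computes: projection coefficient divided by the logit norm.
coef_over_norm = (pattern * logits).sum() / np.linalg.norm(logits)

# Standard cosine similarity between the flattened tensors.
u, v = pattern.ravel(), logits.ravel()
cos_sim = (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))

print(np.isclose(coef_over_norm, cos_sim))  # True
```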


We confirm that the output increasingly relies on the key frequencies as training progresses. The outputs **gradually** coalesce around these key frequencies, even though grokking feels sudden!

Observe the importance of frequency 5 developing over time, along with the other key frequencies 1 and 34. This provides a leading indicator of the eventual grok!

Further, at the beginning of training the residual line shows that almost nothing can be explained by the cosine basis (the residual is close to 1), but near grokking the residual decreases significantly, because the key frequencies are able to explain the logits well.

In [None]:
def get_residual_cos_sim(model):
    """
    Calculate the fraction of the model's output logits that CANNOT be explained by the cosine basis:
    the norm of the residual after projecting onto every pattern in cos_cube, divided by the logit norm.
    """
    logits = model(dataset)[:, -1] # get last logits
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)

    # Project logits onto each cosine pattern & sum results to get a scalar value for each pattern
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])

    residual = logits - (vals[:, None, None, None] * cos_cube).sum(dim=0)
    return residual.norm() / logits.norm()

get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
print(metric_cache["residual_cos_sim"].shape)

fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]],
            line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"],
            title="Cosine Similarity with Predicted Logits & Residual",
            xaxis="Epoch", x=checkpoint_epochs,
            yaxis="Cosine Similarity",
            return_fig=True)
add_lines(fig)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250])
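For an orthonormal basis, the residual and the cosine similarities are linked by Pythagoras: `residual**2 + sum(cos_sim**2) == 1`. This is why the residual line sits near 1 whenever every cosine similarity is near 0. A NumPy sketch with hypothetical logits dominated by frequency 3; the toy prime 11 and the helper `build_cos_cube` are illustrative assumptions, not notebook code:

```python
import numpy as np


def build_cos_cube(p: int) -> np.ndarray:
    """Unit-norm cos(freq * 2*pi/p * (a - b - c)) patterns for freq = 1..p//2 (hypothetical helper)."""
    a = np.arange(p)[:, None, None]
    b = np.arange(p)[None, :, None]
    c = np.arange(p)[None, None, :]
    patterns = [np.cos(f * 2 * np.pi / p * (a - b - c)) for f in range(1, p // 2 + 1)]
    return np.stack([pat / np.linalg.norm(pat) for pat in patterns], axis=0)


p_small = 11  # toy prime
cube = build_cos_cube(p_small)
rng = np.random.default_rng(1)

# Hypothetical logits: a strong frequency-3 pattern plus small noise.
logits = 5.0 * cube[2] + 0.1 * rng.normal(size=cube.shape[1:])

vals = (cube * logits[None]).sum(axis=(-3, -2, -1))    # projection coefficients
cos_sim = vals / np.linalg.norm(logits)                 # as in get_cos_sim
residual = logits - (vals[:, None, None, None] * cube).sum(axis=0)
residual_frac = np.linalg.norm(residual) / np.linalg.norm(logits)  # as in get_residual_cos_sim

# Pythagoras for an orthonormal basis: residual_frac**2 + sum(cos_sim**2) == 1.
print(np.isclose(residual_frac**2 + (cos_sim**2).sum(), 1.0))  # True
print(int(np.argmax(np.abs(cos_sim))) + 1)                     # 3, the dominant frequency
```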


## Restricted Loss

We evaluate the restricted loss over the course of training. The restricted loss is the test loss obtained when we ablate every non-key frequency (i.e. we remove all frequencies except 1, 5, and 34) and assess how the model performs. If our hypothesis is correct, the model will still perform acceptably (i.e. achieve low loss).

In [None]:
def get_restricted_loss(model):
    logits, cache = model.run_with_cache(dataset)
    logits = logits[:, -1, :]
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]

    approx_neuron_acts = torch.zeros_like(neuron_acts)

    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        # create & normalize a cosine wave matrix
        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a - b)).to(DEVICE)
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")

        # project neuron activations onto cosine basis and add to approximations
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec

        # create & normalize a sine wave matrix
        sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a - b)).to(DEVICE)
        sin_apb_vec /= sin_apb_vec.norm()
        sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")

        # project neuron activations onto sine basis and add to approximations
        approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec

    restricted_logits = approx_neuron_acts @ model.blocks[0].mlp.W_out @ model.unembed.W_U

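    # Add bias term: the cos/sin projections have zero mean over all (a, b) pairs, so shift the
    # restricted logits to match the mean of the true logits.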
    restricted_logits += logits.mean(dim=0, keepdim=True) - restricted_logits.mean(dim=0, keepdim=True)
    return loss_fn(restricted_logits[test_indices], test_labels).item()

print("Restricted Loss:", get_restricted_loss(model))

Restricted Loss: 0.00015509621008755154
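The core of `get_restricted_loss` is an orthogonal projection of each neuron's activations onto cos/sin(w(a - b)) for the key frequencies. The NumPy sketch below uses a hypothetical helper `key_freq_projection`, a toy prime 13, and synthetic activations. It shows that a neuron using a key frequency survives the projection unchanged, a neuron using another frequency is zeroed out, and the projected activations have zero mean over all (a, b), which is why the code above adds back the mean of the true logits as a bias term.

```python
import numpy as np


def key_freq_projection(neuron_acts: np.ndarray, p: int, key_freqs) -> np.ndarray:
    """Project (p*p, d_mlp) activations onto unit-norm cos/sin(freq*2*pi/p*(a - b)).

    Hypothetical NumPy sketch of the loop inside get_restricted_loss.
    """
    a = np.arange(p)[:, None]
    b = np.arange(p)[None, :]
    approx = np.zeros_like(neuron_acts)
    for freq in key_freqs:
        for wave in (np.cos, np.sin):
            vec = wave(freq * 2 * np.pi / p * (a - b))
            vec = (vec / np.linalg.norm(vec)).reshape(p * p, 1)  # same as "a b -> (a b) 1"
            approx += (neuron_acts * vec).sum(axis=0) * vec      # add this direction's component
    return approx


p_small = 13  # toy prime
a = np.arange(p_small)[:, None]
b = np.arange(p_small)[None, :]
theta = 2 * np.pi / p_small * (a - b)

# Synthetic activations: neuron 0 uses frequency 2, neuron 1 uses frequency 5.
acts = np.stack([np.cos(2 * theta).ravel(), np.sin(5 * theta).ravel()], axis=1)

restricted = key_freq_projection(acts, p_small, key_freqs=[2])

print(np.allclose(restricted[:, 0], acts[:, 0]))  # True: key-frequency neuron kept
print(np.allclose(restricted[:, 1], 0))           # True: non-key neuron removed
print(np.allclose(restricted.mean(axis=0), 0))    # True: projection has zero mean
```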


In [None]:
# Free GPU memory held by tensors and metrics that are no longer needed
del fourier_basis
del neuron_freq_norm
del cube_predicted_logits
del metric_cache["cos_coeffs"]
del metric_cache["residual_cos_sim"]
del metric_cache["cos_sim"]

torch.cuda.empty_cache()

In [None]:
with torch.no_grad():
  get_metrics(model, metric_cache, get_restricted_loss, "restricted_loss", reset=True)
  print(metric_cache["restricted_loss"].shape)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250])


As desired, once the grok occurs our restricted loss nearly matches the train and test losses, all close to 0. Before the grok, we expected it to remain relatively constant, because the restricted model can only use the key frequencies. This plot further validates our hypothesis.

In [None]:
fig = line([train_losses[::100],
            test_losses[::100],
            metric_cache["restricted_loss"]],
           x=np.arange(0, len(train_losses), 100),
           xaxis="Epoch", yaxis="Loss",
           log_y=True,
           title="Restricted Loss Curve",
           line_labels=['train', 'test', "restricted_loss"],
           toggle_x=True,
           toggle_y=True,
           return_fig=True)
add_lines(fig)

We expect the ratio of the restricted loss to the test loss to stay relatively stable (close to 1) until the grok. After the grok, the restricted loss becomes larger than the test loss: although the key frequencies explain the model outputs well, they cannot explain them as perfectly as when the model has access to **all** the frequencies.

In [None]:
fig = line([metric_cache["restricted_loss"]/torch.tensor(test_losses[::100])],
           x=np.arange(0, len(train_losses), 100),
           xaxis="Epoch",
           yaxis="Restricted / Test Loss",
           log_y=True,
           title="Restricted Loss to Test Loss Ratio",
           toggle_x=True,
           toggle_y=True,
           return_fig=True)
add_lines(fig)

## Excluded Loss

Now, we compute the excluded loss by ablating ONLY the key frequencies (i.e. frequencies 1, 5, and 34 are removed from the MLP neuron activations) and evaluating on the training set. If our hypothesis is correct, then the model's performance will get *worse* after grokking.

In [None]:
def get_excluded_loss(model):
    logits, cache = model.run_with_cache(dataset)
    logits = logits[:, -1, :]
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]
    approx_neuron_acts = torch.zeros_like(neuron_acts)

    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        # create & normalize a cosine wave matrix
        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a - b)).to(DEVICE)
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")

        # project neuron activations onto cosine basis and add to approximations
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec

        # create & normalize a sine wave matrix
        sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a - b)).to(DEVICE)
        sin_apb_vec /= sin_apb_vec.norm()
        sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")

        # project neuron activations onto sine basis and add to approximations
        approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec

    excluded_neuron_acts = neuron_acts - approx_neuron_acts
    residual_stream_final = excluded_neuron_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
    excluded_logits = residual_stream_final @ model.unembed.W_U

    return loss_fn(excluded_logits[train_indices], train_labels)

print("Excluded Loss:", get_excluded_loss(model).item())

Excluded Loss: 23.777497850680085
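`get_excluded_loss` keeps the complement of the key-frequency projection. Under the same assumptions as a toy check (hypothetical helper `key_freq_projection`, toy prime 13, random activations, and key frequencies `[1, 5]` chosen only for illustration), this NumPy sketch verifies that the restricted and excluded parts split the activations exactly, are orthogonal, and that the excluded part contains no key-frequency component:

```python
import numpy as np


def key_freq_projection(neuron_acts: np.ndarray, p: int, key_freqs) -> np.ndarray:
    """Project (p*p, d_mlp) activations onto unit-norm cos/sin(freq*2*pi/p*(a - b)) (hypothetical helper)."""
    a = np.arange(p)[:, None]
    b = np.arange(p)[None, :]
    approx = np.zeros_like(neuron_acts)
    for freq in key_freqs:
        for wave in (np.cos, np.sin):
            vec = wave(freq * 2 * np.pi / p * (a - b))
            vec = (vec / np.linalg.norm(vec)).reshape(p * p, 1)
            approx += (neuron_acts * vec).sum(axis=0) * vec
    return approx


p_small = 13  # toy prime
rng = np.random.default_rng(2)
acts = rng.normal(size=(p_small * p_small, 4))  # random stand-in activations for 4 neurons

restricted = key_freq_projection(acts, p_small, key_freqs=[1, 5])
excluded = acts - restricted  # mirrors excluded_neuron_acts in get_excluded_loss

print(np.allclose(restricted + excluded, acts))                        # True: exact split
print(np.allclose(key_freq_projection(excluded, p_small, [1, 5]), 0))  # True: no key-frequency part left
print(np.allclose((restricted * excluded).sum(axis=0), 0))             # True: parts are orthogonal
```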


In [None]:
with torch.no_grad():
    get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
    print(metric_cache["excluded_loss"].shape)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250])


As the model begins to generalize, it relies increasingly on the key frequencies. Because the excluded loss ablates exactly these frequencies, we expect (and indeed observe) that it gets worse after the grok: the model's memorization has been "weight decayed away", so once the key frequencies that support the general algorithm are removed, the model has nothing left to rely on.

In [None]:
fig = line([train_losses[::100],
            test_losses[::100],
            metric_cache["excluded_loss"],
            metric_cache["restricted_loss"]],
           x=np.arange(0, len(train_losses), 100),
           xaxis="Epoch",
           yaxis="Loss",
           log_y=True,
           title="Excluded and Restricted Loss Curve",
           line_labels=['train', 'test', "excluded_loss", "restricted_loss"],
           toggle_x=True,
           toggle_y=True,
           return_fig=True)

add_lines(fig)