<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Grokking Demo Notebook

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

# Setup
(No need to read)

In [None]:
TRAIN_MODEL = True

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# The "colab" renderer is used on Colab and for any non-development run;
# local development notebooks fall back to "notebook_connected".
use_colab_renderer = IN_COLAB or not DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and figure-title fonts of the 'plotly' template.
plotly_layout = pio.templates["plotly"].layout
for template_axis in (plotly_layout.xaxis, plotly_layout.yaxis):
    template_axis.title.font.size = 20
plotly_layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly heatmap on a red-blue scale centred at zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `.show()`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly line chart.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `.show()`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a Plotly scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy` first.
        y: Vertical coordinates; converted with `utils.to_numpy` first.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to the figure's `.show()`.
        **kwargs: Forwarded unchanged to `px.scatter`.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

In [None]:
# Where the trained model is cached (relative to the current working directory)
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Make sure the containing folder exists before anything is written into it
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Model Training

## Config

In [None]:
# Fraction of the samples that goes into the training split; the remainder is the test split.
frac_train = 0.72

# Optimizer config
# NOTE: `lr` and `wd` are re-assigned in the training cell before the
# optimizer that actually trains the model is built.
lr = 1e-4
wd = 0.1
betas = (0.9, 0.98)

# NOTE(review): `num_epochs` is not read by the training loop in this notebook,
# which uses its own `new_num_epochs` - confirm whether it is still needed.
num_epochs = 10000
# Interval, in epochs, at which the training loop reports progress.
checkpoint_every = 100

# Seed for both the random draw of the k values and the train/test permutation.
data_seed = 598

## Define Task
We want the model to learn to compute $\sin(kx)$ given $k$. Hence, the input is $k$ and the desired output is a discretized version of $\sin(kx)$, with $x$ sampled at 100 evenly spaced points over the interval $[0, 2\pi]$.

We use 1000 samples for $k$, drawn uniformly at random from the interval $[0.1, 10)$.

In [None]:
# Size of the dataset and the range the k values are drawn from
num_samples = 1000
num_x_points = 100
k_min = 0.1
k_max = 10.0

# Model inputs: one k value per sample, rescaled from torch.rand's unit
# interval to [k_min, k_max). Seeding goes through torch.manual_seed, which
# returns the generator that is then handed to torch.rand.
k_vector = (k_max - k_min) * torch.rand(num_samples, generator=torch.manual_seed(data_seed)) + k_min
dataset = k_vector[:, None].to(device)  # Shape: [num_samples, 1]

# Targets: sin(k * x) evaluated on an evenly spaced grid of x values.
x_points = torch.linspace(0, 2 * torch.pi, num_x_points, device=device)  # Shape: [num_x_points]
# Broadcasting [num_samples, 1] against [num_x_points] yields
# [num_samples, num_x_points], i.e. one full sine curve per sample.
labels = torch.sin(dataset * x_points)

# --- Train / test split ---
# Re-seed so that the permutation is reproducible, then cut it at frac_train.
torch.manual_seed(data_seed)
indices = torch.randperm(num_samples)
cutoff = int(num_samples * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

## Define Model

The standard hooked transformer config from TransformerLens is set up for discrete tokens. We are, however, considering continuous inputs and outputs in this regression task. Hence, in order to keep all the useful features of HookedTransformer, we will make a new class that inherits from it and adjusts the input and output layers to allow for continuous inputs and outputs.

In [None]:
class SineTransformer(HookedTransformer):
    """HookedTransformer variant for regressing sin(k x) from a scalar k.

    The token embedding and unembedding created by the parent class are
    removed and replaced with linear layers that map a continuous input into
    the residual stream and map the residual stream back to one float per
    position.
    """

    def __init__(self, cfg):
        """Build the parent model, then swap the discrete-token layers for linear ones.

        Args:
            cfg: Config passed through to `HookedTransformer.__init__`.
                `cfg.d_model` sizes the two linear layers added here.
        """
        super().__init__(cfg)

        # The parent constructor builds `embed` / `unembed` for discrete
        # tokens. They are unused in this regression setup, so drop them.
        del self.embed
        del self.unembed

        # Continuous input: scalar k -> d_model
        self.k_input_projection = nn.Linear(1, self.cfg.d_model)

        # Continuous output: d_model -> one float at every sequence position
        self.output_projection = nn.Linear(self.cfg.d_model, 1)

    def forward(self, k_values):
        """Predict one value per sequence position for each k.

        Args:
            k_values: Tensor of shape [batch_size, 1].

        Returns:
            Tensor of shape [batch_size, n_ctx].
        """
        # Learned positional embeddings, one row per output position: [n_ctx, d_model]
        n_positions = self.cfg.n_ctx
        positional_embeddings = self.pos_embed.W_pos[:n_positions, :]

        # Embed k ([batch_size, d_model]) and add it to every position;
        # broadcasting gives [batch_size, n_ctx, d_model].
        k_embedding = self.k_input_projection(k_values)
        resid = k_embedding.unsqueeze(1) + positional_embeddings

        # Apply each transformer block in turn.
        for block in self.blocks:
            resid = block(resid)

        # Read out a scalar per position and drop the trailing singleton dim.
        return self.output_projection(resid).squeeze(-1)

# Configuration for the sine-regression model
cfg = HookedTransformerConfig(
    # --- Architecture ---
    n_layers=1,
    d_model=128,
    n_heads=4,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type="LN",  # LayerNorm is generally helpful. You can set to None.
    # One sequence position per sampled x-point
    n_ctx=num_x_points,
    # --- Vocabulary sizes ---
    # SineTransformer deletes the embed/unembed layers, so these are dummy
    # values that only exist to satisfy the config.
    d_vocab=10,
    d_vocab_out=10,
    # --- Initialisation and placement ---
    init_weights=True,
    seed=999,
    device=device,
)

# Build the custom model class from this config and place it on `device`
model = SineTransformer(cfg)
model = model.to(device)

Freeze the biases (every parameter whose name contains `b_`) by turning off their gradients, as we don't need them for this task and it makes things easier to interpret.

In [None]:
# Turn off gradients for every parameter whose name contains "b_".
# NOTE(review): the biases of the two nn.Linear layers are named `...bias`,
# which does not contain "b_", so this filter leaves them trainable - confirm
# that this is intended.
for param_name, parameter in model.named_parameters():
    if "b_" not in param_name:
        continue
    parameter.requires_grad_(False)



## Define Optimizer + Loss

In [None]:
# AdamW over all model parameters, using the hyperparameters from the Config cell
optimizer = optim.AdamW(
    params=model.parameters(),
    betas=betas,
    lr=lr,
    weight_decay=wd,
)

In [None]:
# Mean-squared error between the predicted and the true curve
loss_fn = nn.MSELoss()

# Sanity check before training: one forward pass over each split,
# then report the loss of the freshly initialised model.
train_logits, test_logits = model(train_data), model(test_data)
train_loss, test_loss = loss_fn(train_logits, train_labels), loss_fn(test_logits, test_labels)

print(f"Initial Training Loss: {train_loss.item()}")
print(f"Initial Test Loss: {test_loss.item()}")

## Actually Train

**Weird Decision:** We train the model with full-batch gradient descent rather than stochastic (mini-batch) gradient descent. We do this to make training smoother and reduce the number of slingshots.

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# --- Optimizer and scheduler config for the training run ---
# These values replace the `lr` / `wd` assigned in the Config cell.
lr = 1e-3           # Start high again
wd = 0.01           # A much more reasonable weight decay
betas = (0.9, 0.98)

# Re-initialize the optimizer with the new parameters
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

# Initialize the scheduler
# T_max is the total number of epochs of this training session.
# It will decay the LR from 1e-3 down to 0 over this period.
new_num_epochs = 20000
scheduler = CosineAnnealingLR(optimizer, T_max=new_num_epochs)

# Only train when requested. With TRAIN_MODEL = False the histories and
# checkpoints are expected to come from the cached file instead, so they are
# deliberately not created here.
if TRAIN_MODEL:
    # The save cell and the training-curve plot read `train_losses`,
    # `test_losses`, `model_checkpoints` and `checkpoint_epochs`, so the loop
    # has to populate exactly these names.
    train_losses = []
    test_losses = []
    model_checkpoints = []
    checkpoint_epochs = []
    # Kept as aliases of the same lists for code that uses the older names.
    new_train_losses = train_losses
    new_test_losses = test_losses

    for epoch in tqdm.tqdm(range(new_num_epochs)):
        # Full-batch forward/backward pass on the training split
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Crucially, step the scheduler after each epoch
        scheduler.step()

        train_losses.append(train_loss.item())

        # Evaluate on the held-out split without tracking gradients
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            # Snapshot the weights; deepcopy is required because state_dict()
            # returns references to the live parameters, which keep changing.
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))

            # You can get the current learning rate from the scheduler to see it decrease
            current_lr = scheduler.get_last_lr()[0]
            print(f"Epoch {epoch} | Test Loss {test_loss.item():.6f} | LR {current_lr:.6f}")

In [None]:
if TRAIN_MODEL:
    # Cache the freshly trained model together with its training history.
    # Only done after training, so a run with TRAIN_MODEL = False can never
    # overwrite an existing cache before reading it.
    torch.save(
        {
            "model":model.state_dict(),
            "config": model.cfg,
            "checkpoints": model_checkpoints,
            "checkpoint_epochs": checkpoint_epochs,
            "test_losses": test_losses,
            "train_losses": train_losses,
            "train_indices": train_indices,
            "test_indices": test_indices,
        },
        PTH_LOCATION)
else:
    # Restore everything from the cache written by an earlier training run.
    # map_location lets a file saved on a GPU be loaded on a CPU-only machine.
    # weights_only=False is needed because the file also stores the config
    # object, which is not a plain tensor container.
    # SECURITY: this unpickles arbitrary Python objects - only load cache
    # files that you produced yourself.
    cached_data = torch.load(PTH_LOCATION, map_location=device, weights_only=False)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

## Show Model Training Statistics, Check that it groks!

In [None]:
# Install the neel-plotly helper library straight from GitHub and import its `line` plot.
# NOTE: this import rebinds `line`, replacing the `line` helper defined in the
# plotting-helpers cell; calls below use the neel_plotly version.
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
# Plot every 100th train/test loss against its epoch on a log-scaled y axis.
# The title names the task actually trained in this notebook (sin(kx)
# regression); it previously said "Fibonacci", left over from another task.
line(
    [train_losses[::100], test_losses[::100]],
    x=np.arange(0, len(train_losses), 100),
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    title="Training Curve for sin(kx) Regression",
    line_labels=['train', 'test'],
    toggle_x=True,
    toggle_y=True,
)

# Analysing the Model

Compute the initial residual stream (the embedding of $k$ plus the positional embeddings) for a few sample values of $k$:

In [None]:
# Uses the trained `model` and the `device` chosen earlier in the notebook.

# A handful of k values to look at, as a [3, 1] column
k_to_inspect = torch.tensor([1.0, 5.0, 10.0], device=device).unsqueeze(-1)

# Step 1: embed each k with the input projection
k_embedding = model.k_input_projection(k_to_inspect)
print("Shape of k_embedding for 3 sample k's:", k_embedding.shape)

# Step 2: take the positional embedding rows for the x-points
positional_embedding = model.pos_embed.W_pos[:num_x_points]
print("\nShape of positional_embedding:", positional_embedding.shape) # Expected: [100, 128]

# Step 3: repeat each k embedding along the position axis and add the
# positional embedding, which broadcasts over the batch dimension
k_broadcast = k_embedding[:, None, :].expand(-1, num_x_points, -1)
initial_residual_stream = k_broadcast + positional_embedding

print("\nShape of the full initial embedding:", initial_residual_stream.shape)
print("This is the 'effective embedding' that the first transformer block sees.")



In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


# --- 1. An evenly spaced grid of k values covering the training range ---
k_for_pca = torch.linspace(k_min, k_max, 1000, device=device).unsqueeze(-1)

# --- 2. Embed them (evaluation mode, no gradient tracking) ---
model.eval()
with torch.no_grad():
    k_embeddings = model.k_input_projection(k_for_pca)

# sklearn works on CPU numpy arrays
k_embeddings_np = k_embeddings.cpu().numpy()

# --- 3. PCA with all components kept ---
pca = PCA()
pca.fit(k_embeddings_np)

# --- 4. Per-component and cumulative explained variance ---
explained_variance = pca.explained_variance_ratio_
cumulative_variance = np.cumsum(explained_variance)
component_numbers = range(1, len(explained_variance) + 1)

fig, (scree_ax, cumulative_ax) = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: variance explained by each individual component
scree_ax.bar(component_numbers, explained_variance)
scree_ax.set_xlabel("Principal Component")
scree_ax.set_ylabel("Explained Variance Ratio")
scree_ax.set_title("Scree Plot")

# Right panel: running total, with the 99% level marked
cumulative_ax.plot(component_numbers, cumulative_variance, marker='o', linestyle='--')
cumulative_ax.set_xlabel("Number of Components")
cumulative_ax.set_ylabel("Cumulative Explained Variance")
cumulative_ax.set_title("Cumulative Variance Plot")
cumulative_ax.grid(True)
cumulative_ax.set_ylim(0, 1.1)
cumulative_ax.axhline(y=0.99, color='r', linestyle=':', label='99% Variance')
cumulative_ax.legend()

fig.tight_layout()
plt.show()

# --- Interpretation ---
# argmax returns the first index at which the threshold is reached; +1 turns
# that index into a component count.
num_components_for_99_variance = np.argmax(cumulative_variance >= 0.99) + 1
print(f"Number of components to explain 99% of variance: {num_components_for_99_variance}")



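All of the variance is explained by a single component of the 128-dimensional vector. This is expected: the embedding is produced by a single linear layer applied to the scalar $k$, so it is an affine function of $k$ and all embeddings lie on one straight line in the 128-dimensional space. We now investigate what this component looks like.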

In [None]:
# --- 1. Coordinates of every embedding along the first principal component ---
# pca.transform() returns one column per component; column 0 is PC1.
projected_data_pc1 = pca.transform(k_embeddings_np)[:, 0]

# --- 2. The matching k values as a flat CPU numpy array ---
k_values_for_plot = k_for_pca.cpu().numpy().flatten()

# --- 3. PC1 coordinate as a function of k ---
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(k_values_for_plot, projected_data_pc1)
ax.set_xlabel("Input k value")
ax.set_ylabel("Projection onto Principal Component 1")
ax.set_title("Learned Representation of k (Projected onto PC1)")
ax.grid(True)
plt.show()



Now, we want to look at the positional embedding: this gives us an idea of what the model thinks 'space' looks like. We do PCA again:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch.nn.functional as F

# --- 1. Positional embedding matrix as a CPU numpy array ---
# One row per position: [num_x_points, d_model], e.g. [100, 128].
w_pos = model.pos_embed.W_pos.detach().cpu().numpy()

# --- 2. PCA over the positions, showing the ten leading components ---
pca = PCA()
pca.fit(w_pos)

explained_variance = pca.explained_variance_ratio_

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(range(1, 11), explained_variance[:10])
ax.set_xlabel("Principal Component")
ax.set_ylabel("Explained Variance Ratio")
ax.set_title("PCA of Positional Embeddings (W_pos)")
plt.show()

# --- 3. Pairwise cosine similarity between positions ---
# Broadcasting a [N, 1, d] view against a [1, N, d] view compares the vector
# of every position i with the vector of every position j.
w_pos_tensor = torch.from_numpy(w_pos)
cosine_sim = F.cosine_similarity(w_pos_tensor[:, None, :], w_pos_tensor[None, :, :], dim=-1)

fig, ax = plt.subplots(figsize=(7, 6))
heatmap = ax.imshow(cosine_sim.numpy(), cmap='viridis')
fig.colorbar(heatmap, ax=ax, label="Cosine Similarity")
ax.set_xlabel("Position j")
ax.set_ylabel("Position i")
ax.set_title("Cosine Similarity of Positional Embeddings")
plt.show()