<a href="https://colab.research.google.com/github/JonathanRaines/bluedot-intro-to-mech-interp/blob/main/notebooks/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
# Install Poetry fro dependency management
%%capture
!curl -sSL https://install.python-poetry.org | python3 -


In [2]:
# Install Dependencies
%%capture
!poetry install

In [3]:
import numpy as np
import os
import pathlib
import plotly.io as pio
import plotly.graph_objects as go
import torch
import tqdm

from src import const
from src import loss
from src import model
from src import task

import copy

In [4]:
# For developoment. Enables making changes to modules without restarting the runtime.
%%capture
import importlib
importlib.reload(loss)
importlib.reload(model)
importlib.reload(task)

In [5]:
# Relative path of the file the trained model is written to.
# Download to persist/grokking_demo.pth after training
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Make sure the parent folder is there before anything tries to write into it.
pathlib.Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Configuration

In [6]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on runtimes without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Task config
P = const.MOD  # passed as `p` to the model and task builders below
DATA_SEED = 598
TRAINING_FRACTION = 0.3

# Optimizer config
LR = 1e-3
WD = 1.0  # Very large, makes grokking happen faster, encourages a simple model
BETAS = (0.9, 0.98)

# Training-loop config
NUM_EPOCHS = 25_000
CHECKPOINT_EVERY = 100  # snapshot the model weights every this many epochs
EARLY_STOPPING_LOSS = 1e-6  # stop once the test loss drops below this value

# Plotting config
PLOTLY_TEMPLATE = "plotly_dark"

## Create the model

In [7]:
hooked_model = model.get_hooked_transformer(p=P, device=DEVICE)

# The biases are not needed for this task and leaving them out of training makes
# the model easier to interpret, so every parameter whose name contains "b_" is
# excluded from gradient computation.
for bias in (param for name, param in hooked_model.named_parameters() if "b_" in name):
    bias.requires_grad = False

## Define the task

In [8]:
# Build the train/test inputs and labels from the configured modulus, device,
# seed and training fraction.
train_data, train_labels, test_data, test_labels = task.make_train_and_test_data(
    p=P,
    device=DEVICE,
    data_seed=DATA_SEED,
    training_fraction=TRAINING_FRACTION,
)

## Define Optimizer

In [9]:
# AdamW over all model parameters, using the learning rate, weight decay and
# betas from the configuration cell.
optimizer = torch.optim.AdamW(
    hooked_model.parameters(),
    lr=LR,
    betas=BETAS,
    weight_decay=WD,
)

# Train

In [None]:
# Loss histories: one entry per completed epoch.
train_losses: list[float] = []
test_losses: list[float] = []
# Weight snapshots and the (zero-based) epoch index each one was taken at.
model_checkpoints: list[dict] = []
checkpoint_epochs: list[int] = []

for epoch in (pbar := tqdm.trange(NUM_EPOCHS, unit=" epoch")):
    # One optimisation step on the whole training set.
    train_logits = hooked_model(train_data)
    train_loss = loss.mean_log_prob_loss(train_logits, train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())

    optimizer.step()
    optimizer.zero_grad()

    # The test loss is measured after the update, so at a given index it
    # reflects weights one step newer than the recorded train loss.
    with torch.inference_mode():
        test_logits = hooked_model(test_data)
        test_loss = loss.mean_log_prob_loss(test_logits, test_labels)
        test_losses.append(test_loss.item())

    if ((epoch + 1) % CHECKPOINT_EVERY) == 0:
        # `epoch` is zero-based: the first snapshot is recorded as epoch
        # CHECKPOINT_EVERY - 1, i.e. after CHECKPOINT_EVERY updates.
        checkpoint_epochs.append(epoch)
        # Deep copy so later training steps do not modify the stored snapshot.
        model_checkpoints.append(copy.deepcopy(hooked_model.state_dict()))
        pbar.set_description(f"Train Loss {train_losses[-1]:.4f}, Test Loss {test_losses[-1]:.4f}")

    if test_losses[-1] < EARLY_STOPPING_LOSS:
        # `epoch` is zero-based, so epoch + 1 epochs have completed here.
        print(f"\nEarly stopping after {epoch + 1} epochs.")
        break

Train Loss 0.0000, Test Loss 27.5485:  14%|█▎        | 3392/25000 [01:24<08:17, 43.45 epoch/s]

In [None]:
# Gather everything later analysis needs into one dictionary: final weights,
# model config, periodic weight snapshots and the full loss histories.
training_artifacts = {
    "model": hooked_model.state_dict(),
    "config": hooked_model.cfg,
    "checkpoints": model_checkpoints,
    "checkpoint_epochs": checkpoint_epochs,
    "test_losses": test_losses,
    "train_losses": train_losses,
    # Not saved at the moment:
    # "train_indices": train_indices,
    # "test_indices": test_indices,
}

# Write the bundle to the location chosen in the setup section.
torch.save(training_artifacts, PTH_LOCATION)

# Analysis
## Show Model Training Statistics, Check that it groks!

In [None]:
# Only every 100th epoch is drawn, which keeps the figure light.
stride = 100
x = np.arange(0, len(train_losses), stride)
y1 = train_losses[::stride]
y2 = test_losses[::stride]

# Start from an empty figure that carries the layout, then add one line per
# loss series. The y axis is logarithmic.
fig = go.Figure(
    layout={
        "title": f"Training Curve for Base {P} Modular Addition",
        "xaxis": {"title": "Epoch"},
        "yaxis": {"title": "Loss", "type": "log"},
        "template": PLOTLY_TEMPLATE,
    }
)
for series_name, series in (("train", y1), ("test", y2)):
    fig.add_trace(go.Scatter(x=x, y=series, name=series_name))
fig
