# Setup

In [25]:
import os
from IPython import get_ipython
ipython = get_ipython()
# Code to automatically update the HookedTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

import plotly.io as pio
pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

import warnings

warnings.filterwarnings("ignore")

device = utils.get_device()
print(f"Using device: {device}")

device="cpu"


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using renderer: notebook_connected
Using device: mps


In [26]:
# plotting
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)


In [27]:
# Save dir for model

PTH_LOCATION = "models/grokking_mod.pth"
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)


# Model training

## config

In [28]:
p = 113
train_frac = 0.30

# optimizer configs
lr = 5e-3
wd = 1.
betas = (0.9, 0.98)

num_epochs = 35_000
checkpoint_every = 100

DATA_SEED = 598


## Define task
- define mod addition
- define dataset and labels

### input format: |a|b|=|

In [29]:
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(113), " -> (i j)", i =p, j = p)

dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)

print(dataset[:5])
print(dataset.shape)

labels = (dataset[:, 0] + dataset[:, 1]) % p
print(labels.shape)
print(labels[:5])


tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113],
        [  0,   4, 113]])
torch.Size([12769, 3])
torch.Size([12769])
tensor([0, 1, 2, 3, 4])


### convert to train and test

In [30]:
torch.manual_seed(DATA_SEED)

indices = torch.randperm(p*p)
cutoff = int(p*p*train_frac)

train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]

print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)


tensor([[ 21,  31, 113],
        [ 30,  98, 113],
        [ 47,  10, 113],
        [ 86,  21, 113],
        [ 99,  83, 113]])
tensor([ 52,  15,  57, 107,  69])
torch.Size([3830, 3])
tensor([[ 43,  40, 113],
        [ 31,  42, 113],
        [ 39,  63, 113],
        [ 35,  61, 113],
        [112, 102, 113]])
tensor([ 83,  73, 102,  96, 101])
torch.Size([8939, 3])


## Define Model

In [31]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = "gelu",
    normalization_type=None,
    d_vocab = p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cpu",
    seed=999
)


In [32]:
model = HookedTransformer(cfg)


In [33]:
# disable biases
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False


## Define optimizer + loss

In [34]:
opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)


In [36]:
def loss_fn(logits, labels):
    if len(logits.shape) == 3:
        logits = logits[:, -1]
        # print(logits.shape)
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)


tensor(4.7334, dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7312, dtype=torch.float64, grad_fn=<NegBackward0>)


In [37]:
# uniform loss at beginning
print(np.log(p))


4.727387818712341


## Train model

In [None]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())
    if (epoch+1) % checkpoint_every == 0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch+1} | Train Loss: {train_loss.item()} | Test Loss: {test_loss.item()}")
