<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Grokking Demo Notebook

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

# Setup
(No need to read)

In [81]:
# Gate for the training cell below: when False, the training loop is skipped entirely.
TRAIN_MODEL: bool = True

In [82]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Colab (or any non-development run) gets the "colab" renderer; local development gets "notebook_connected".
use_colab_renderer = IN_COLAB or not DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [84]:
# Enlarge the axis-title and figure-title fonts in the default "plotly" template.
plotly_layout = pio.templates['plotly'].layout
for axis in (plotly_layout.xaxis, plotly_layout.yaxis):
    axis.title.font.size = 20
plotly_layout.title.font.size = 30

In [85]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [86]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Plotting helper functions:

In [87]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Render `tensor` as a heatmap with a red/blue colour scale centred on zero."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Render `tensor` as a line plot with the given axis labels."""
    fig = px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Render a scatter plot of `y` against `x`; `caxis` labels the colour dimension."""
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(x=x_np, y=y_np, labels=axis_labels, **kwargs)
    fig.show(renderer)

In [88]:
# Where the trained model checkpoint is stored (relative to the current working directory).
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Ensure the parent directory exists; this is a no-op if it already does.
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Model Training

## Config

In [89]:
# Task: addition modulo p; frac_train of all (a, b) pairs go into the training set.
p = 113
frac_train = 0.3

# Optimizer config (passed to AdamW)
lr = 1e-3
wd = 1.0
betas = (0.9, 0.98)

# Training length (full-batch epochs) and how often to snapshot the model.
num_epochs = 25_000
checkpoint_every = 100

# Seed used for the train/test shuffle.
DATA_SEED = 598

## Define Task
* Define modular addition
* Define the dataset & labels

Input format:
|a|b|=|

In [90]:
# Build every (a, b) pair in row-major order: a_vector = [0,0,...,1,1,...], b_vector = [0,1,...,0,1,...].
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
# The "=" token is the extra vocabulary entry with index p (the model config uses d_vocab = p + 1).
# Derive it from `p` rather than hard-coding 113 so changing the modulus keeps the token consistent.
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)


In [91]:
# One row per input sequence |a|b|=|, placed on the training device.
dataset = torch.stack(
    (a_vector, b_vector, equals_vector),
    dim=1,
).to(device)
print(dataset[:5])
print(dataset.shape)

tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113],
        [  0,   4, 113]], device='cuda:0')
torch.Size([12769, 3])


In [92]:
# Target for each row is (a + b) mod p, summing the first two token columns.
labels = dataset[:, :2].sum(dim=1) % p
print(labels.shape)
print(labels[:5])

torch.Size([12769])
tensor([0, 1, 2, 3, 4], device='cuda:0')


Split this into a train set and a test set, with 30% of the pairs (`frac_train`) in the training set.

In [93]:
# Reproducibly shuffle all p*p pairs, then take the first frac_train of them as training data.
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p * p)
cutoff = int(p * p * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

# Preview each split: first five inputs, first five labels, then the input shape.
for split_data, split_labels in ((train_data, train_labels), (test_data, test_labels)):
    print(split_data[:5])
    print(split_labels[:5])
    print(split_data.shape)

tensor([[ 21,  31, 113],
        [ 30,  98, 113],
        [ 47,  10, 113],
        [ 86,  21, 113],
        [ 99,  83, 113]], device='cuda:0')
tensor([ 52,  15,  57, 107,  69], device='cuda:0')
torch.Size([3830, 3])
tensor([[ 43,  40, 113],
        [ 31,  42, 113],
        [ 39,  63, 113],
        [ 35,  61, 113],
        [112, 102, 113]], device='cuda:0')
tensor([ 83,  73, 102,  96, 101], device='cuda:0')
torch.Size([8939, 3])


## Define Model

In [94]:

# A single-layer transformer with no normalization layers.
# Inputs draw from p number tokens plus one "=" token (d_vocab = p + 1); outputs are the p residues.
cfg = HookedTransformerConfig(
    # architecture
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,
    # vocabulary / context
    d_vocab=p + 1,
    d_vocab_out=p,
    n_ctx=3,
    # initialisation and placement
    init_weights=True,
    device=device,
    seed=999,
)

In [95]:
# Instantiate the (freshly initialised) transformer described by `cfg`.
model = HookedTransformer(cfg=cfg)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [96]:
# Freeze the bias parameters. Assumes TransformerLens names them with a leading "b_" on the
# final component of the dotted parameter name (e.g. "...attn.b_Q") -- TODO confirm for new layers.
# Checking only the leaf name, rather than searching the whole dotted path for the substring "b_",
# avoids accidentally freezing a weight that merely sits under a module whose name contains "b_".
for name, param in model.named_parameters():
    if name.rsplit(".", 1)[-1].startswith("b_"):
        param.requires_grad = False


## Define Optimizer + Loss

In [97]:
# AdamW over all model parameters, including the biases frozen above (they have requires_grad=False).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [98]:
def loss_fn(logits, labels):
    """Mean cross-entropy loss on the final sequence position.

    Args:
        logits: either (batch, pos, vocab) logits, in which case only the last position is
            scored, or already-sliced (batch, vocab) logits.
        labels: (batch,) integer indices of the correct class.

    Returns:
        0-dim float64 tensor: the mean negative log-probability assigned to the correct labels.
    """
    final_logits = logits[:, -1] if logits.dim() == 3 else logits
    # Softmax is taken in float64; presumably for precision, since training losses in the logs
    # reach ~1e-7.
    log_probs = torch.log_softmax(final_logits.double(), dim=-1)
    correct_log_probs = torch.gather(log_probs, -1, labels.unsqueeze(1)).squeeze(1)
    return -correct_log_probs.mean()
# Sanity check before training: losses of the untrained model on both splits.
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(train_loss)
print(test_loss)

tensor(4.7359, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7330, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [99]:
# Baseline: predicting all p residues with equal probability gives a loss of -log(1/p) = log(p).
uniform_loss = np.log(p)
print("Uniform loss:")
print(uniform_loss)

Uniform loss:
4.727387818712341


## Actually Train

**Weird Decision:** We train the model with full-batch gradient descent (the entire training set in every step) rather than stochastic mini-batch gradient descent. We do this to make training smoother and reduce the number of slingshots.

In [100]:
# Full-batch training: each epoch takes one optimizer step on the whole training set, then
# evaluates on the whole test set. Both losses are recorded every epoch, and a deep copy of the
# weights is stored every `checkpoint_every` epochs.
train_losses, test_losses, model_checkpoints, checkpoint_epochs = [], [], [], []
if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        # --- training step ---
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer.step()
        optimizer.zero_grad()

        # --- evaluation, without autograd tracking ---
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        # --- periodic checkpoint + progress log ---
        is_checkpoint_epoch = (epoch + 1) % checkpoint_every == 0
        if is_checkpoint_epoch:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

  0%|          | 113/25000 [00:01<05:32, 74.83it/s]

Epoch 99 Train Loss 2.9216481191413375 Test Loss 7.6559031896872165


  1%|          | 212/25000 [00:03<06:22, 64.83it/s]

Epoch 199 Train Loss 0.030839887431076563 Test Loss 19.93917900257167


  1%|          | 309/25000 [00:04<06:07, 67.27it/s]

Epoch 299 Train Loss 0.009591797266535414 Test Loss 20.63647831249946


  2%|▏         | 410/25000 [00:06<05:42, 71.72it/s]

Epoch 399 Train Loss 0.0030786268437847614 Test Loss 21.75797263942504


  2%|▏         | 513/25000 [00:07<06:19, 64.53it/s]

Epoch 499 Train Loss 0.001020158643209917 Test Loss 22.978755738221764


  2%|▏         | 609/25000 [00:09<05:47, 70.14it/s]

Epoch 599 Train Loss 0.00034355147184966903 Test Loss 24.253354576350326


  3%|▎         | 705/25000 [00:10<06:46, 59.72it/s]

Epoch 699 Train Loss 0.0001173767403831185 Test Loss 25.524252550329265


  3%|▎         | 810/25000 [00:12<05:40, 71.05it/s]

Epoch 799 Train Loss 4.073800109283013e-05 Test Loss 26.799500937218017


  4%|▎         | 911/25000 [00:13<05:55, 67.77it/s]

Epoch 899 Train Loss 1.4425790894557507e-05 Test Loss 28.049518731389735


  4%|▍         | 1013/25000 [00:15<05:44, 69.64it/s]

Epoch 999 Train Loss 5.31517889310081e-06 Test Loss 29.23085039873119


  4%|▍         | 1108/25000 [00:16<05:39, 70.37it/s]

Epoch 1099 Train Loss 2.1192837692262375e-06 Test Loss 30.277232977110355


  5%|▍         | 1208/25000 [00:17<05:30, 72.01it/s]

Epoch 1199 Train Loss 9.775232801675413e-07 Test Loss 31.07915965065423


  5%|▌         | 1312/25000 [00:19<05:31, 71.46it/s]

Epoch 1299 Train Loss 5.579148897536628e-07 Test Loss 31.557401760210823


  6%|▌         | 1414/25000 [00:20<05:41, 68.99it/s]

Epoch 1399 Train Loss 4.046038562594594e-07 Test Loss 31.67050997267148


  6%|▌         | 1512/25000 [00:22<05:46, 67.75it/s]

Epoch 1499 Train Loss 3.5393501466312126e-07 Test Loss 31.506347159399365


  6%|▋         | 1614/25000 [00:23<05:25, 71.91it/s]

Epoch 1599 Train Loss 3.40993478873628e-07 Test Loss 31.206373056481755


  7%|▋         | 1710/25000 [00:24<05:15, 73.84it/s]

Epoch 1699 Train Loss 3.3807053184244073e-07 Test Loss 30.87688152048916


  7%|▋         | 1813/25000 [00:26<05:31, 70.04it/s]

Epoch 1799 Train Loss 3.3693194573610234e-07 Test Loss 30.548944499151943


  8%|▊         | 1910/25000 [00:27<05:05, 75.65it/s]

Epoch 1899 Train Loss 3.3588397158135484e-07 Test Loss 30.226380183732257


  8%|▊         | 2009/25000 [00:29<05:30, 69.50it/s]

Epoch 1999 Train Loss 3.348305685494867e-07 Test Loss 29.912679120351907


  8%|▊         | 2110/25000 [00:30<05:40, 67.22it/s]

Epoch 2099 Train Loss 3.3369852268226277e-07 Test Loss 29.602481788099706


  9%|▉         | 2208/25000 [00:32<05:07, 74.01it/s]

Epoch 2199 Train Loss 3.324701950377364e-07 Test Loss 29.304015114850245


  9%|▉         | 2313/25000 [00:33<05:17, 71.46it/s]

Epoch 2299 Train Loss 3.309335430800161e-07 Test Loss 29.02118656322816


 10%|▉         | 2410/25000 [00:34<05:05, 74.02it/s]

Epoch 2399 Train Loss 3.297539943475546e-07 Test Loss 28.74676506111603


 10%|█         | 2514/25000 [00:36<05:02, 74.31it/s]

Epoch 2499 Train Loss 3.286141160488954e-07 Test Loss 28.476940850660558


 10%|█         | 2612/25000 [00:37<05:05, 73.38it/s]

Epoch 2599 Train Loss 3.27324470552515e-07 Test Loss 28.221826005265832


 11%|█         | 2709/25000 [00:38<05:13, 71.11it/s]

Epoch 2699 Train Loss 3.259463317021575e-07 Test Loss 27.972172758027586


 11%|█         | 2805/25000 [00:40<05:18, 69.62it/s]

Epoch 2799 Train Loss 3.2442126765056877e-07 Test Loss 27.7225893594591


 12%|█▏        | 2908/25000 [00:41<05:25, 67.93it/s]

Epoch 2899 Train Loss 3.2334938813769394e-07 Test Loss 27.477487897206206


 12%|█▏        | 3014/25000 [00:43<04:52, 75.17it/s]

Epoch 2999 Train Loss 3.2205449274052665e-07 Test Loss 27.235222342077545


 12%|█▏        | 3110/25000 [00:44<05:05, 71.67it/s]

Epoch 3099 Train Loss 3.208921914155068e-07 Test Loss 26.995013838946083


 13%|█▎        | 3206/25000 [00:45<05:11, 70.03it/s]

Epoch 3199 Train Loss 3.19709972446294e-07 Test Loss 26.760973984133724


 13%|█▎        | 3310/25000 [00:47<05:11, 69.68it/s]

Epoch 3299 Train Loss 3.185576003037801e-07 Test Loss 26.529345281179115


 14%|█▎        | 3412/25000 [00:48<05:06, 70.37it/s]

Epoch 3399 Train Loss 3.175612231904462e-07 Test Loss 26.295942430726345


 14%|█▍        | 3510/25000 [00:50<04:56, 72.51it/s]

Epoch 3499 Train Loss 3.165030448914142e-07 Test Loss 26.065901034625703


 14%|█▍        | 3607/25000 [00:51<04:52, 73.19it/s]

Epoch 3599 Train Loss 3.1541342171542204e-07 Test Loss 25.834517753033193


 15%|█▍        | 3712/25000 [00:52<04:54, 72.38it/s]

Epoch 3699 Train Loss 3.142346482279651e-07 Test Loss 25.609819571056775


 15%|█▌        | 3809/25000 [00:54<04:46, 74.03it/s]

Epoch 3799 Train Loss 3.1322788293280234e-07 Test Loss 25.388500528234818


 16%|█▌        | 3913/25000 [00:55<04:54, 71.67it/s]

Epoch 3899 Train Loss 3.1207953321692924e-07 Test Loss 25.163632588362645


 16%|█▌        | 4009/25000 [00:56<04:47, 73.04it/s]

Epoch 3999 Train Loss 3.111593320928089e-07 Test Loss 24.94348113915603


 16%|█▋        | 4113/25000 [00:58<04:59, 69.81it/s]

Epoch 4099 Train Loss 3.105580707479683e-07 Test Loss 24.721470214224855


 17%|█▋        | 4208/25000 [00:59<04:59, 69.41it/s]

Epoch 4199 Train Loss 3.0923785468939507e-07 Test Loss 24.49074126091588


 17%|█▋        | 4309/25000 [01:01<05:03, 68.28it/s]

Epoch 4299 Train Loss 3.0783799633945493e-07 Test Loss 24.250407523172843


 18%|█▊        | 4410/25000 [01:02<04:35, 74.77it/s]

Epoch 4399 Train Loss 3.072148520047709e-07 Test Loss 24.014842592070597


 18%|█▊        | 4507/25000 [01:04<05:14, 65.20it/s]

Epoch 4499 Train Loss 3.057576104459069e-07 Test Loss 23.77982098200813


 18%|█▊        | 4612/25000 [01:05<05:07, 66.40it/s]

Epoch 4599 Train Loss 3.0486691302833746e-07 Test Loss 23.535799844467874


 19%|█▉        | 4703/25000 [01:07<05:23, 62.65it/s]

Epoch 4699 Train Loss 3.0367063674072703e-07 Test Loss 23.287001876295058


 19%|█▉        | 4808/25000 [01:08<05:29, 61.24it/s]

Epoch 4799 Train Loss 3.0278754903965605e-07 Test Loss 23.038015727740692


 20%|█▉        | 4912/25000 [01:10<05:27, 61.43it/s]

Epoch 4899 Train Loss 3.0178278277315354e-07 Test Loss 22.782773145667008


 20%|██        | 5008/25000 [01:12<05:03, 65.84it/s]

Epoch 4999 Train Loss 3.007512388114639e-07 Test Loss 22.5239677518203


 20%|██        | 5112/25000 [01:13<04:43, 70.21it/s]

Epoch 5099 Train Loss 2.998458709871907e-07 Test Loss 22.266121408888043


 21%|██        | 5211/25000 [01:15<04:18, 76.44it/s]

Epoch 5199 Train Loss 2.9863796360630285e-07 Test Loss 21.993623033703955


 21%|██        | 5310/25000 [01:16<04:06, 80.00it/s]

Epoch 5299 Train Loss 2.9726681686026594e-07 Test Loss 21.714174381566863


 22%|██▏       | 5414/25000 [01:17<04:11, 77.89it/s]

Epoch 5399 Train Loss 2.958705142247771e-07 Test Loss 21.42145840216053


 22%|██▏       | 5514/25000 [01:18<04:03, 79.98it/s]

Epoch 5499 Train Loss 2.948050034903859e-07 Test Loss 21.125924489155462


 22%|██▏       | 5613/25000 [01:20<04:38, 69.51it/s]

Epoch 5599 Train Loss 2.939480446196377e-07 Test Loss 20.83049673548172


 23%|██▎       | 5711/25000 [01:21<04:37, 69.53it/s]

Epoch 5699 Train Loss 2.928115348728939e-07 Test Loss 20.523285348200922


 23%|██▎       | 5810/25000 [01:22<04:16, 74.87it/s]

Epoch 5799 Train Loss 2.9155715564957914e-07 Test Loss 20.20508661594743


 24%|██▎       | 5915/25000 [01:24<04:13, 75.34it/s]

Epoch 5899 Train Loss 2.90326828707299e-07 Test Loss 19.885323227247678


 24%|██▍       | 6010/25000 [01:25<04:15, 74.22it/s]

Epoch 5999 Train Loss 2.8891293394123086e-07 Test Loss 19.554916389796425


 24%|██▍       | 6108/25000 [01:26<04:05, 76.81it/s]

Epoch 6099 Train Loss 2.8773603126044014e-07 Test Loss 19.21504093138153


 25%|██▍       | 6207/25000 [01:28<04:05, 76.58it/s]

Epoch 6199 Train Loss 2.8680250736171074e-07 Test Loss 18.865010810900515


 25%|██▌       | 6309/25000 [01:29<04:02, 76.96it/s]

Epoch 6299 Train Loss 2.8561176977168374e-07 Test Loss 18.509263768809962


 26%|██▌       | 6412/25000 [01:30<04:00, 77.24it/s]

Epoch 6399 Train Loss 2.8430043957909327e-07 Test Loss 18.14266294413403


 26%|██▌       | 6512/25000 [01:32<04:10, 73.77it/s]

Epoch 6499 Train Loss 2.825312123984012e-07 Test Loss 17.752646188787974


 26%|██▋       | 6610/25000 [01:33<04:08, 73.88it/s]

Epoch 6599 Train Loss 2.8120060344810994e-07 Test Loss 17.35692160698781


 27%|██▋       | 6711/25000 [01:34<03:51, 78.94it/s]

Epoch 6699 Train Loss 2.799708003165862e-07 Test Loss 16.938086006308698


 27%|██▋       | 6814/25000 [01:36<03:51, 78.53it/s]

Epoch 6799 Train Loss 2.7834683238978344e-07 Test Loss 16.499976543039022


 28%|██▊       | 6912/25000 [01:37<04:14, 70.96it/s]

Epoch 6899 Train Loss 2.7678695462429774e-07 Test Loss 16.057800532010884


 28%|██▊       | 7012/25000 [01:38<04:14, 70.63it/s]

Epoch 6999 Train Loss 2.75293252929092e-07 Test Loss 15.60040226400072


 28%|██▊       | 7112/25000 [01:40<04:10, 71.38it/s]

Epoch 7099 Train Loss 2.7390094048988866e-07 Test Loss 15.138417619037766


 29%|██▉       | 7208/25000 [01:41<04:12, 70.57it/s]

Epoch 7199 Train Loss 2.7236218480601836e-07 Test Loss 14.669120005081457


 29%|██▉       | 7307/25000 [01:42<04:19, 68.10it/s]

Epoch 7299 Train Loss 2.7040918294495027e-07 Test Loss 14.190413384000125


 30%|██▉       | 7406/25000 [01:44<04:55, 59.50it/s]

Epoch 7399 Train Loss 2.685292818034479e-07 Test Loss 13.691108463311261


 30%|███       | 7514/25000 [01:46<04:07, 70.60it/s]

Epoch 7499 Train Loss 2.665888551585277e-07 Test Loss 13.176750581505775


 30%|███       | 7613/25000 [01:47<03:51, 74.99it/s]

Epoch 7599 Train Loss 2.646548378079809e-07 Test Loss 12.657864013131134


 31%|███       | 7712/25000 [01:48<03:48, 75.82it/s]

Epoch 7699 Train Loss 2.6286237857043647e-07 Test Loss 12.126953585443946


 31%|███       | 7812/25000 [01:49<03:45, 76.08it/s]

Epoch 7799 Train Loss 2.609783719738329e-07 Test Loss 11.605518060952026


 32%|███▏      | 7915/25000 [01:51<03:42, 76.76it/s]

Epoch 7899 Train Loss 2.5913395491366126e-07 Test Loss 11.082619962444994


 32%|███▏      | 8007/25000 [01:52<04:01, 70.50it/s]

Epoch 7999 Train Loss 2.568198552705304e-07 Test Loss 10.562249148211944


 32%|███▏      | 8110/25000 [01:54<04:13, 66.54it/s]

Epoch 8099 Train Loss 2.545343976441687e-07 Test Loss 10.034570144170395


 33%|███▎      | 8211/25000 [01:55<04:20, 64.44it/s]

Epoch 8199 Train Loss 2.5251339217686106e-07 Test Loss 9.510143946990436


 33%|███▎      | 8309/25000 [01:56<03:56, 70.48it/s]

Epoch 8299 Train Loss 2.5022280399971596e-07 Test Loss 8.976907856383106


 34%|███▎      | 8413/25000 [01:58<03:49, 72.38it/s]

Epoch 8399 Train Loss 2.4793993954203637e-07 Test Loss 8.452081783808374


 34%|███▍      | 8510/25000 [01:59<03:50, 71.65it/s]

Epoch 8499 Train Loss 2.453801195198049e-07 Test Loss 7.932935103859077


 34%|███▍      | 8614/25000 [02:01<03:51, 70.88it/s]

Epoch 8599 Train Loss 2.4306658635680963e-07 Test Loss 7.411623784967409


 35%|███▍      | 8710/25000 [02:02<03:47, 71.75it/s]

Epoch 8699 Train Loss 2.3988241702415037e-07 Test Loss 6.86982147199379


 35%|███▌      | 8809/25000 [02:03<03:31, 76.49it/s]

Epoch 8799 Train Loss 2.3688463971622187e-07 Test Loss 6.307859734933592


 36%|███▌      | 8908/25000 [02:05<03:35, 74.69it/s]

Epoch 8899 Train Loss 2.3293301252916343e-07 Test Loss 5.717811423510985


 36%|███▌      | 9015/25000 [02:06<03:31, 75.65it/s]

Epoch 8999 Train Loss 2.2854774948063292e-07 Test Loss 5.0830992871762986


 36%|███▋      | 9108/25000 [02:07<03:27, 76.65it/s]

Epoch 9099 Train Loss 2.2290116299349286e-07 Test Loss 4.393643574512727


 37%|███▋      | 9212/25000 [02:09<03:44, 70.31it/s]

Epoch 9199 Train Loss 2.15933611100567e-07 Test Loss 3.6397791232939825


 37%|███▋      | 9308/25000 [02:10<03:37, 72.26it/s]

Epoch 9299 Train Loss 2.0720214953051507e-07 Test Loss 2.8408244208145703


 38%|███▊      | 9413/25000 [02:11<03:40, 70.68it/s]

Epoch 9399 Train Loss 1.963372442848923e-07 Test Loss 2.0315605629846574


 38%|███▊      | 9507/25000 [02:13<03:31, 73.13it/s]

Epoch 9499 Train Loss 1.8410614282555882e-07 Test Loss 1.2913801103737972


 38%|███▊      | 9611/25000 [02:14<03:32, 72.26it/s]

Epoch 9599 Train Loss 1.7142552020605566e-07 Test Loss 0.7163435207188931


 39%|███▉      | 9710/25000 [02:16<03:12, 79.33it/s]

Epoch 9699 Train Loss 1.5985169550707725e-07 Test Loss 0.3465020142802194


 39%|███▉      | 9809/25000 [02:17<03:16, 77.37it/s]

Epoch 9799 Train Loss 1.5140893898475456e-07 Test Loss 0.151622138129121


 40%|███▉      | 9914/25000 [02:18<03:22, 74.38it/s]

Epoch 9899 Train Loss 1.4577086384008854e-07 Test Loss 0.06544384842766893


 40%|████      | 10006/25000 [02:19<03:17, 76.06it/s]

Epoch 9999 Train Loss 1.4163126203483524e-07 Test Loss 0.029810638128482465


 40%|████      | 10112/25000 [02:21<03:37, 68.48it/s]

Epoch 10099 Train Loss 1.3821808719115138e-07 Test Loss 0.014722881295615642


 41%|████      | 10211/25000 [02:22<03:13, 76.51it/s]

Epoch 10199 Train Loss 1.3556424041284445e-07 Test Loss 0.007870199362558084


 41%|████      | 10312/25000 [02:24<03:13, 75.80it/s]

Epoch 10299 Train Loss 1.3347068616907907e-07 Test Loss 0.004313626518940813


 42%|████▏     | 10409/25000 [02:25<03:32, 68.78it/s]

Epoch 10399 Train Loss 1.3186765540007605e-07 Test Loss 0.002572146830850322


 42%|████▏     | 10509/25000 [02:26<03:02, 79.28it/s]

Epoch 10499 Train Loss 1.3025341314736683e-07 Test Loss 0.0017420771859647077


 42%|████▏     | 10608/25000 [02:28<03:11, 75.18it/s]

Epoch 10599 Train Loss 1.290244153007497e-07 Test Loss 0.0013475500333582338


 43%|████▎     | 10713/25000 [02:29<03:16, 72.66it/s]

Epoch 10699 Train Loss 1.2802337733204247e-07 Test Loss 0.0011715980514412944


 43%|████▎     | 10813/25000 [02:30<03:06, 75.98it/s]

Epoch 10799 Train Loss 1.271837025456966e-07 Test Loss 0.0010730035310827237


 44%|████▎     | 10908/25000 [02:32<04:09, 56.47it/s]

Epoch 10899 Train Loss 1.2647612032670178e-07 Test Loss 0.0009967391583429832


 44%|████▍     | 11008/25000 [02:33<03:45, 61.99it/s]

Epoch 10999 Train Loss 1.2579448446004914e-07 Test Loss 0.000934499771289036


 44%|████▍     | 11116/25000 [02:35<02:51, 80.86it/s]

Epoch 11099 Train Loss 1.251653654100727e-07 Test Loss 0.0008899215494922001


 45%|████▍     | 11213/25000 [02:36<03:09, 72.57it/s]

Epoch 11199 Train Loss 1.2460620801935315e-07 Test Loss 0.0008433434485563978


 45%|████▌     | 11310/25000 [02:37<03:04, 74.07it/s]

Epoch 11299 Train Loss 1.2402489069942644e-07 Test Loss 0.0007877785085405835


 46%|████▌     | 11415/25000 [02:39<03:04, 73.60it/s]

Epoch 11399 Train Loss 1.2342586255334821e-07 Test Loss 0.0007398541953754036


 46%|████▌     | 11509/25000 [02:40<03:10, 70.80it/s]

Epoch 11499 Train Loss 1.2271071620005568e-07 Test Loss 0.0006930672170713878


 46%|████▋     | 11604/25000 [02:42<03:42, 60.23it/s]

Epoch 11599 Train Loss 1.219844349292144e-07 Test Loss 0.0006337242359489126


 47%|████▋     | 11709/25000 [02:44<03:36, 61.46it/s]

Epoch 11699 Train Loss 1.2111334400098643e-07 Test Loss 0.0005426081581425038


 47%|████▋     | 11807/25000 [02:45<03:23, 64.89it/s]

Epoch 11799 Train Loss 1.2010678292403675e-07 Test Loss 0.00044555674819142386


 48%|████▊     | 11908/25000 [02:46<03:27, 63.23it/s]

Epoch 11899 Train Loss 1.1891361582136235e-07 Test Loss 0.0003390286728517284


 48%|████▊     | 12011/25000 [02:48<03:12, 67.44it/s]

Epoch 11999 Train Loss 1.176419650020845e-07 Test Loss 0.00025133643867330895


 48%|████▊     | 12110/25000 [02:50<03:31, 60.93it/s]

Epoch 12099 Train Loss 1.1639730179668203e-07 Test Loss 0.00018178213186768953


 49%|████▉     | 12212/25000 [02:51<03:12, 66.29it/s]

Epoch 12199 Train Loss 1.1529846177193116e-07 Test Loss 0.00013076478631227396


 49%|████▉     | 12308/25000 [02:52<02:50, 74.58it/s]

Epoch 12299 Train Loss 1.1440970609753753e-07 Test Loss 9.47917630579517e-05


 50%|████▉     | 12409/25000 [02:54<02:45, 76.21it/s]

Epoch 12399 Train Loss 1.1363670289485555e-07 Test Loss 6.93431829453616e-05


 50%|█████     | 12516/25000 [02:55<02:40, 77.58it/s]

Epoch 12499 Train Loss 1.1293777723262341e-07 Test Loss 5.1311239007552394e-05


 50%|█████     | 12614/25000 [02:57<02:45, 74.95it/s]

Epoch 12599 Train Loss 1.1229406965869952e-07 Test Loss 3.990848296937766e-05


 51%|█████     | 12709/25000 [02:58<03:01, 67.90it/s]

Epoch 12699 Train Loss 1.1174869001087836e-07 Test Loss 3.207152849273008e-05


 51%|█████     | 12808/25000 [02:59<02:50, 71.38it/s]

Epoch 12799 Train Loss 1.1123944299885109e-07 Test Loss 2.6101392870859685e-05


 52%|█████▏    | 12911/25000 [03:01<02:44, 73.66it/s]

Epoch 12899 Train Loss 1.1077826039470722e-07 Test Loss 2.1444161092029873e-05


 52%|█████▏    | 13007/25000 [03:02<02:47, 71.39it/s]

Epoch 12999 Train Loss 1.103732492425311e-07 Test Loss 1.7144547143166407e-05


 52%|█████▏    | 13113/25000 [03:03<02:43, 72.61it/s]

Epoch 13099 Train Loss 1.0998566059522159e-07 Test Loss 1.382429945397482e-05


 53%|█████▎    | 13209/25000 [03:05<02:39, 73.73it/s]

Epoch 13199 Train Loss 1.0965001336507203e-07 Test Loss 1.1403264707399515e-05


 53%|█████▎    | 13309/25000 [03:06<02:42, 71.89it/s]

Epoch 13299 Train Loss 1.0931440944540612e-07 Test Loss 9.484272197309544e-06


 54%|█████▎    | 13413/25000 [03:08<02:37, 73.76it/s]

Epoch 13399 Train Loss 1.0904695240001256e-07 Test Loss 8.077506255317985e-06


 54%|█████▍    | 13510/25000 [03:09<02:44, 69.72it/s]

Epoch 13499 Train Loss 1.0879868361377541e-07 Test Loss 7.000329577906635e-06


 54%|█████▍    | 13613/25000 [03:11<02:44, 69.17it/s]

Epoch 13599 Train Loss 1.0856225414782012e-07 Test Loss 6.197768861864824e-06


 55%|█████▍    | 13707/25000 [03:12<02:48, 67.19it/s]

Epoch 13699 Train Loss 1.083469721020195e-07 Test Loss 5.505000338078414e-06


 55%|█████▌    | 13808/25000 [03:13<02:35, 71.99it/s]

Epoch 13799 Train Loss 1.0815931749008464e-07 Test Loss 4.9612574454960334e-06


 56%|█████▌    | 13914/25000 [03:15<02:29, 74.15it/s]

Epoch 13899 Train Loss 1.0798272010497754e-07 Test Loss 4.494111098182459e-06


 56%|█████▌    | 14011/25000 [03:16<02:35, 70.65it/s]

Epoch 13999 Train Loss 1.078404939476727e-07 Test Loss 4.063017899620005e-06


 56%|█████▋    | 14107/25000 [03:17<02:30, 72.36it/s]

Epoch 14099 Train Loss 1.0769575741886897e-07 Test Loss 3.6309727198118585e-06


 57%|█████▋    | 14211/25000 [03:19<02:39, 67.63it/s]

Epoch 14199 Train Loss 1.0757138228882828e-07 Test Loss 3.356028863515324e-06


 57%|█████▋    | 14309/25000 [03:20<02:43, 65.29it/s]

Epoch 14299 Train Loss 1.0746070967677824e-07 Test Loss 3.1496969390816283e-06


 58%|█████▊    | 14412/25000 [03:22<02:18, 76.35it/s]

Epoch 14399 Train Loss 1.0734724542505191e-07 Test Loss 2.984655298122725e-06


 58%|█████▊    | 14508/25000 [03:23<02:19, 75.12it/s]

Epoch 14499 Train Loss 1.0723340970760933e-07 Test Loss 2.8525096606835604e-06


 58%|█████▊    | 14612/25000 [03:25<02:26, 70.75it/s]

Epoch 14599 Train Loss 1.07140001683221e-07 Test Loss 2.7360216174108344e-06


 59%|█████▉    | 14708/25000 [03:26<02:21, 72.50it/s]

Epoch 14699 Train Loss 1.0704951796159072e-07 Test Loss 2.6515578494166534e-06


 59%|█████▉    | 14812/25000 [03:27<02:21, 71.98it/s]

Epoch 14799 Train Loss 1.0697083629594372e-07 Test Loss 2.5892009901776987e-06


 60%|█████▉    | 14907/25000 [03:29<02:27, 68.53it/s]

Epoch 14899 Train Loss 1.0688437302558085e-07 Test Loss 2.529160036430591e-06


 60%|██████    | 15009/25000 [03:30<02:21, 70.51it/s]

Epoch 14999 Train Loss 1.0681454196252889e-07 Test Loss 2.4765433876142283e-06


 60%|██████    | 15114/25000 [03:32<02:09, 76.08it/s]

Epoch 15099 Train Loss 1.0674655967355178e-07 Test Loss 2.416038679123548e-06


 61%|██████    | 15211/25000 [03:33<02:15, 72.44it/s]

Epoch 15199 Train Loss 1.0668263928535006e-07 Test Loss 2.366328484864973e-06


 61%|██████    | 15306/25000 [03:35<02:39, 60.59it/s]

Epoch 15299 Train Loss 1.0662485196025245e-07 Test Loss 2.31794173258559e-06


 62%|██████▏   | 15410/25000 [03:36<02:14, 71.38it/s]

Epoch 15399 Train Loss 1.0656451468986441e-07 Test Loss 2.2676312852523807e-06


 62%|██████▏   | 15509/25000 [03:37<02:06, 74.77it/s]

Epoch 15499 Train Loss 1.0650681165070128e-07 Test Loss 2.2270680890324566e-06


 62%|██████▏   | 15614/25000 [03:39<02:04, 75.57it/s]

Epoch 15599 Train Loss 1.0645732790897216e-07 Test Loss 2.185601392902501e-06


 63%|██████▎   | 15710/25000 [03:40<02:14, 69.25it/s]

Epoch 15699 Train Loss 1.0640451909159229e-07 Test Loss 2.1589563612013092e-06


 63%|██████▎   | 15813/25000 [03:42<02:17, 66.79it/s]

Epoch 15799 Train Loss 1.0635857878523259e-07 Test Loss 2.1377725783266917e-06


 64%|██████▎   | 15914/25000 [03:43<02:11, 69.31it/s]

Epoch 15899 Train Loss 1.0630473686449132e-07 Test Loss 2.1091895490353037e-06


 64%|██████▍   | 16009/25000 [03:44<02:05, 71.72it/s]

Epoch 15999 Train Loss 1.0625621766808185e-07 Test Loss 2.0770450128292105e-06


 64%|██████▍   | 16113/25000 [03:46<02:01, 73.00it/s]

Epoch 16099 Train Loss 1.0621106316890847e-07 Test Loss 2.056481513573184e-06


 65%|██████▍   | 16210/25000 [03:47<02:04, 70.78it/s]

Epoch 16199 Train Loss 1.0617205331302617e-07 Test Loss 2.0456611361087537e-06


 65%|██████▌   | 16308/25000 [03:49<01:54, 76.02it/s]

Epoch 16299 Train Loss 1.0613278086117012e-07 Test Loss 2.039038839913958e-06


 66%|██████▌   | 16413/25000 [03:50<02:00, 70.97it/s]

Epoch 16399 Train Loss 1.0609182172829326e-07 Test Loss 2.027197887664064e-06


 66%|██████▌   | 16511/25000 [03:51<01:59, 71.12it/s]

Epoch 16499 Train Loss 1.0605782533595184e-07 Test Loss 2.01538453591105e-06


 66%|██████▋   | 16608/25000 [03:53<01:55, 72.48it/s]

Epoch 16599 Train Loss 1.0601888843283616e-07 Test Loss 2.0051437449259196e-06


 67%|██████▋   | 16711/25000 [03:54<01:59, 69.46it/s]

Epoch 16699 Train Loss 1.059868548589035e-07 Test Loss 1.9944530999540098e-06


 67%|██████▋   | 16809/25000 [03:55<01:51, 73.64it/s]

Epoch 16799 Train Loss 1.0595166157475801e-07 Test Loss 1.9875930950966446e-06


 68%|██████▊   | 16908/25000 [03:57<01:44, 77.16it/s]

Epoch 16899 Train Loss 1.0591695946480174e-07 Test Loss 1.9821031475569755e-06


 68%|██████▊   | 17013/25000 [03:58<01:44, 76.71it/s]

Epoch 16999 Train Loss 1.0588678700471054e-07 Test Loss 1.972027237628707e-06


 68%|██████▊   | 17112/25000 [04:00<01:41, 78.03it/s]

Epoch 17099 Train Loss 1.0585777672882591e-07 Test Loss 1.9598304243609484e-06


 69%|██████▉   | 17212/25000 [04:01<01:43, 75.26it/s]

Epoch 17199 Train Loss 1.0582835231632938e-07 Test Loss 1.9509139401931413e-06


 69%|██████▉   | 17308/25000 [04:02<01:41, 75.92it/s]

Epoch 17299 Train Loss 1.0580002624642466e-07 Test Loss 1.9464279566245363e-06


 70%|██████▉   | 17412/25000 [04:04<01:47, 70.82it/s]

Epoch 17399 Train Loss 1.0576881102722291e-07 Test Loss 1.93602446697336e-06


 70%|███████   | 17506/25000 [04:05<01:48, 68.83it/s]

Epoch 17499 Train Loss 1.0574051990026566e-07 Test Loss 1.9275133867581e-06


 70%|███████   | 17610/25000 [04:06<01:48, 67.93it/s]

Epoch 17599 Train Loss 1.0571375681092089e-07 Test Loss 1.919191783596013e-06


 71%|███████   | 17714/25000 [04:08<01:43, 70.40it/s]

Epoch 17699 Train Loss 1.0568401442803539e-07 Test Loss 1.9078409896307775e-06


 71%|███████   | 17810/25000 [04:09<01:41, 70.56it/s]

Epoch 17799 Train Loss 1.0565919883331851e-07 Test Loss 1.9018271024278831e-06


 72%|███████▏  | 17913/25000 [04:11<01:42, 69.42it/s]

Epoch 17899 Train Loss 1.0563113638526855e-07 Test Loss 1.8923332067574675e-06


 72%|███████▏  | 18014/25000 [04:12<01:34, 73.83it/s]

Epoch 17999 Train Loss 1.0560597490279126e-07 Test Loss 1.8859727898653432e-06


 72%|███████▏  | 18107/25000 [04:14<01:40, 68.90it/s]

Epoch 18099 Train Loss 1.0558499645129079e-07 Test Loss 1.8766785418496968e-06


 73%|███████▎  | 18210/25000 [04:15<01:40, 67.62it/s]

Epoch 18199 Train Loss 1.0555743810924986e-07 Test Loss 1.8706399460485149e-06


 73%|███████▎  | 18311/25000 [04:16<01:33, 71.22it/s]

Epoch 18299 Train Loss 1.0553770116667843e-07 Test Loss 1.8615875633383227e-06


 74%|███████▎  | 18410/25000 [04:18<01:28, 74.43it/s]

Epoch 18399 Train Loss 1.0551448703508982e-07 Test Loss 1.8512259887207573e-06


 74%|███████▍  | 18514/25000 [04:19<01:29, 72.58it/s]

Epoch 18499 Train Loss 1.0549097722358918e-07 Test Loss 1.8438195879865664e-06


 74%|███████▍  | 18608/25000 [04:21<01:24, 75.72it/s]

Epoch 18599 Train Loss 1.0546891492632813e-07 Test Loss 1.836048597979857e-06


 75%|███████▍  | 18713/25000 [04:22<01:26, 72.95it/s]

Epoch 18699 Train Loss 1.0544843644229375e-07 Test Loss 1.8316857931161636e-06


 75%|███████▌  | 18809/25000 [04:23<01:27, 70.39it/s]

Epoch 18799 Train Loss 1.0543002477135311e-07 Test Loss 1.8241865187741702e-06


 76%|███████▌  | 18913/25000 [04:25<01:24, 72.34it/s]

Epoch 18899 Train Loss 1.0540910699377598e-07 Test Loss 1.821478857798584e-06


 76%|███████▌  | 19011/25000 [04:26<01:25, 70.24it/s]

Epoch 18999 Train Loss 1.0538809948258481e-07 Test Loss 1.8150605304893517e-06


 76%|███████▋  | 19108/25000 [04:27<01:18, 75.26it/s]

Epoch 19099 Train Loss 1.0536838884130109e-07 Test Loss 1.8117611466168893e-06


 77%|███████▋  | 19213/25000 [04:29<01:23, 69.14it/s]

Epoch 19199 Train Loss 1.053506313816076e-07 Test Loss 1.807382612247743e-06


 77%|███████▋  | 19309/25000 [04:30<01:20, 70.96it/s]

Epoch 19299 Train Loss 1.0533267815946008e-07 Test Loss 1.8007979356144699e-06


 78%|███████▊  | 19406/25000 [04:32<01:23, 66.81it/s]

Epoch 19399 Train Loss 1.0531459286957832e-07 Test Loss 1.7925350196325715e-06


 78%|███████▊  | 19509/25000 [04:33<01:13, 74.63it/s]

Epoch 19499 Train Loss 1.0529512599368209e-07 Test Loss 1.7872037786889846e-06


 78%|███████▊  | 19612/25000 [04:34<01:16, 70.72it/s]

Epoch 19599 Train Loss 1.0527673471034028e-07 Test Loss 1.7801835178451171e-06


 79%|███████▉  | 19710/25000 [04:36<01:13, 72.22it/s]

Epoch 19699 Train Loss 1.0525979316701544e-07 Test Loss 1.7738483550882528e-06


 79%|███████▉  | 19814/25000 [04:37<01:09, 74.25it/s]

Epoch 19799 Train Loss 1.0523846409186604e-07 Test Loss 1.768822291026872e-06


 80%|███████▉  | 19914/25000 [04:38<01:04, 78.86it/s]

Epoch 19899 Train Loss 1.052219206229391e-07 Test Loss 1.7609573858216758e-06


 80%|████████  | 20011/25000 [04:40<01:06, 74.47it/s]

Epoch 19999 Train Loss 1.0520298829138245e-07 Test Loss 1.7585725481645539e-06


 80%|████████  | 20109/25000 [04:41<01:07, 72.45it/s]

Epoch 20099 Train Loss 1.0518228520981301e-07 Test Loss 1.7554244485021704e-06


 81%|████████  | 20213/25000 [04:42<01:03, 75.03it/s]

Epoch 20199 Train Loss 1.0516641738782593e-07 Test Loss 1.7524084356865356e-06


 81%|████████  | 20312/25000 [04:44<01:02, 75.43it/s]

Epoch 20299 Train Loss 1.0515050479674872e-07 Test Loss 1.7476090918328e-06


 82%|████████▏ | 20409/25000 [04:45<01:04, 71.46it/s]

Epoch 20399 Train Loss 1.0513406980397708e-07 Test Loss 1.7443616047501409e-06


 82%|████████▏ | 20513/25000 [04:47<01:03, 71.09it/s]

Epoch 20499 Train Loss 1.0511650377758526e-07 Test Loss 1.738905310510711e-06


 82%|████████▏ | 20609/25000 [04:48<01:03, 68.97it/s]

Epoch 20599 Train Loss 1.0510032398817974e-07 Test Loss 1.7378211454573428e-06


 83%|████████▎ | 20713/25000 [04:49<01:01, 69.34it/s]

Epoch 20699 Train Loss 1.0508660701192154e-07 Test Loss 1.7351899083135559e-06


 83%|████████▎ | 20807/25000 [04:51<00:58, 71.84it/s]

Epoch 20799 Train Loss 1.0507403126774261e-07 Test Loss 1.7316185800899264e-06


 84%|████████▎ | 20910/25000 [04:52<00:59, 68.45it/s]

Epoch 20899 Train Loss 1.0505867597879213e-07 Test Loss 1.729279665561259e-06


 84%|████████▍ | 21012/25000 [04:54<00:55, 72.02it/s]

Epoch 20999 Train Loss 1.050442850947373e-07 Test Loss 1.7273777186300687e-06


 84%|████████▍ | 21113/25000 [04:55<00:54, 70.75it/s]

Epoch 21099 Train Loss 1.0503149309870066e-07 Test Loss 1.7249937662777862e-06


 85%|████████▍ | 21210/25000 [04:56<00:52, 72.79it/s]

Epoch 21199 Train Loss 1.0501737289607768e-07 Test Loss 1.719826077140742e-06


 85%|████████▌ | 21313/25000 [04:58<00:52, 69.70it/s]

Epoch 21299 Train Loss 1.050020745306358e-07 Test Loss 1.7164978405008476e-06


 86%|████████▌ | 21408/25000 [04:59<00:48, 74.13it/s]

Epoch 21399 Train Loss 1.049864440543785e-07 Test Loss 1.7189782646910824e-06


 86%|████████▌ | 21514/25000 [05:01<00:47, 73.23it/s]

Epoch 21499 Train Loss 1.0497604756884643e-07 Test Loss 1.7191928232821428e-06


 86%|████████▋ | 21613/25000 [05:02<00:44, 75.75it/s]

Epoch 21599 Train Loss 1.0496420526060853e-07 Test Loss 1.7157500524122355e-06


 87%|████████▋ | 21710/25000 [05:03<00:47, 69.57it/s]

Epoch 21699 Train Loss 1.0495161449218599e-07 Test Loss 1.714596622204201e-06


 87%|████████▋ | 21813/25000 [05:05<00:47, 67.78it/s]

Epoch 21799 Train Loss 1.0493820116210327e-07 Test Loss 1.7126521096898039e-06


 88%|████████▊ | 21912/25000 [05:06<00:42, 72.83it/s]

Epoch 21899 Train Loss 1.0492468286229981e-07 Test Loss 1.7066606354250027e-06


 88%|████████▊ | 22010/25000 [05:08<00:38, 77.00it/s]

Epoch 21999 Train Loss 1.0491211278641289e-07 Test Loss 1.6998846692889873e-06


 88%|████████▊ | 22109/25000 [05:09<00:38, 75.53it/s]

Epoch 22099 Train Loss 1.049008071899873e-07 Test Loss 1.6940616679400983e-06


 89%|████████▉ | 22213/25000 [05:10<00:38, 72.92it/s]

Epoch 22199 Train Loss 1.0489010960755097e-07 Test Loss 1.6905630623699455e-06


 89%|████████▉ | 22311/25000 [05:12<00:35, 76.37it/s]

Epoch 22299 Train Loss 1.0487343337303787e-07 Test Loss 1.6863912313989198e-06


 90%|████████▉ | 22408/25000 [05:13<00:34, 75.96it/s]

Epoch 22399 Train Loss 1.0486224427219556e-07 Test Loss 1.677332843055664e-06


 90%|█████████ | 22513/25000 [05:14<00:31, 78.34it/s]

Epoch 22499 Train Loss 1.0484748522593423e-07 Test Loss 1.6706129445822796e-06


 90%|█████████ | 22610/25000 [05:16<00:32, 73.68it/s]

Epoch 22599 Train Loss 1.0483863708874513e-07 Test Loss 1.6650026826255346e-06


 91%|█████████ | 22709/25000 [05:17<00:29, 78.49it/s]

Epoch 22699 Train Loss 1.0482708019374276e-07 Test Loss 1.6599671474088716e-06


 91%|█████████▏| 22813/25000 [05:18<00:29, 74.89it/s]

Epoch 22799 Train Loss 1.0481763741588545e-07 Test Loss 1.657268989339847e-06


 92%|█████████▏| 22910/25000 [05:20<00:30, 68.85it/s]

Epoch 22899 Train Loss 1.0480466305802181e-07 Test Loss 1.654115113682907e-06


 92%|█████████▏| 23010/25000 [05:21<00:27, 71.67it/s]

Epoch 22999 Train Loss 1.0479471826635084e-07 Test Loss 1.6506546950281893e-06


 92%|█████████▏| 23114/25000 [05:23<00:25, 74.63it/s]

Epoch 23099 Train Loss 1.0478349840021507e-07 Test Loss 1.6483501628939001e-06


 93%|█████████▎| 23211/25000 [05:24<00:24, 73.45it/s]

Epoch 23199 Train Loss 1.0477259746308884e-07 Test Loss 1.6458085273140132e-06


 93%|█████████▎| 23314/25000 [05:25<00:23, 71.54it/s]

Epoch 23299 Train Loss 1.047616536216008e-07 Test Loss 1.6435258398370327e-06


 94%|█████████▎| 23410/25000 [05:27<00:20, 76.14it/s]

Epoch 23399 Train Loss 1.0475178727877976e-07 Test Loss 1.641239683291118e-06


 94%|█████████▍| 23510/25000 [05:28<00:25, 59.41it/s]

Epoch 23499 Train Loss 1.0474242947817992e-07 Test Loss 1.6339793701129578e-06


 94%|█████████▍| 23613/25000 [05:30<00:19, 71.31it/s]

Epoch 23599 Train Loss 1.0473183645846374e-07 Test Loss 1.629125834434261e-06


 95%|█████████▍| 23710/25000 [05:31<00:17, 75.32it/s]

Epoch 23699 Train Loss 1.0472095019724206e-07 Test Loss 1.625003700254825e-06


 95%|█████████▌| 23814/25000 [05:32<00:16, 72.18it/s]

Epoch 23799 Train Loss 1.0470882695472108e-07 Test Loss 1.6205982017950928e-06


 96%|█████████▌| 23914/25000 [05:34<00:14, 77.33it/s]

Epoch 23899 Train Loss 1.0469993864399566e-07 Test Loss 1.6174356818534536e-06


 96%|█████████▌| 24012/25000 [05:35<00:13, 74.55it/s]

Epoch 23999 Train Loss 1.0469068031051029e-07 Test Loss 1.6167986590906887e-06


 96%|█████████▋| 24108/25000 [05:36<00:12, 73.80it/s]

Epoch 24099 Train Loss 1.046813279836813e-07 Test Loss 1.6137556339471372e-06


 97%|█████████▋| 24213/25000 [05:38<00:10, 74.62it/s]

Epoch 24199 Train Loss 1.0467255699257133e-07 Test Loss 1.6144069440504933e-06


 97%|█████████▋| 24310/25000 [05:39<00:09, 74.97it/s]

Epoch 24299 Train Loss 1.046607416912662e-07 Test Loss 1.6139416355792181e-06


 98%|█████████▊| 24411/25000 [05:41<00:07, 77.80it/s]

Epoch 24399 Train Loss 1.046539611574896e-07 Test Loss 1.6135012053780714e-06


 98%|█████████▊| 24508/25000 [05:42<00:06, 76.10it/s]

Epoch 24499 Train Loss 1.0464229104832984e-07 Test Loss 1.6150257920769275e-06


 98%|█████████▊| 24616/25000 [05:43<00:04, 78.85it/s]

Epoch 24599 Train Loss 1.0463224325544333e-07 Test Loss 1.6154187475027898e-06


 99%|█████████▉| 24713/25000 [05:44<00:03, 76.74it/s]

Epoch 24699 Train Loss 1.0462527746497883e-07 Test Loss 1.617940291353104e-06


 99%|█████████▉| 24815/25000 [05:46<00:02, 80.50it/s]

Epoch 24799 Train Loss 1.0461762170173321e-07 Test Loss 1.6180717313785382e-06


100%|█████████▉| 24914/25000 [05:47<00:01, 75.69it/s]

Epoch 24899 Train Loss 1.046074040895412e-07 Test Loss 1.6197475606072706e-06


100%|██████████| 25000/25000 [05:48<00:00, 71.68it/s]

Epoch 24999 Train Loss 1.0459810010931098e-07 Test Loss 1.6177260081173516e-06





In [101]:
# Persist the trained model together with its training history, so a later session
# can set TRAIN_MODEL = False and reload everything from PTH_LOCATION.
# Guarded: when TRAIN_MODEL is False these variables have not been produced by
# training yet, and saving would overwrite (or fail to write) the cached file.
if TRAIN_MODEL:
    torch.save(
        {
            "model": model.state_dict(),
            "config": model.cfg,
            "checkpoints": model_checkpoints,
            "checkpoint_epochs": checkpoint_epochs,
            "test_losses": test_losses,
            "train_losses": train_losses,
            "train_indices": train_indices,
            "test_indices": test_indices,
        },
        PTH_LOCATION,
    )

In [102]:
# Reload the model and training history written by the save cell when training was skipped.
if not TRAIN_MODEL:
    # weights_only=False is required because the file also pickles model.cfg (not a tensor);
    # only load checkpoint files you created yourself, since unpickling can run arbitrary code.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    cached_data = torch.load(PTH_LOCATION, map_location=device, weights_only=False)
    model.load_state_dict(cached_data["model"])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data["test_losses"]
    train_losses = cached_data["train_losses"]
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

## Show Model Training Statistics, Check that it groks!

In [407]:
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line
line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)

/home/burny/projects/ai/mechinterp/.venv/bin/python: No module named pip
Note: you may need to restart the kernel to use updated packages.


# Analysing the Model

## Standard Things to Try

In [408]:
# Run the model on every input in `dataset` once, keeping all activations in `cache`.
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())

4328691


Get key weight matrices:

In [409]:
# Drop the last vocab row (the "=" token, id 113 in `dataset`), leaving one embedding per number.
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)
# Embedding -> each head's OV circuit -> MLP input: one (p, d_mlp) map per head.
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur", W_neur.shape)
# Neuron -> output-logit weights (MLP W_out composed with the unembedding).
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_E torch.Size([113, 128])
W_neur torch.Size([4, 113, 512])
W_logit torch.Size([512, 113])


In [410]:
# Baseline loss of the unmodified model, for comparison with edited logits later.
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

Original Loss: 1.1638705372862433e-06


### Looking at Activations

Helper variables (attention from the final `=` position to `a` and `b`, and MLP activations at that position):

In [411]:
# Attention paid from the final "=" position to the `a` (position 0) and `b` (position 1) tokens, per input and head.
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
# MLP activations after / before the nonlinearity at the final position: (batch, d_mlp).
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

Get all shapes:

In [412]:
# List every cached activation hook with its shape.
for param_name, param in cache.items():
    print(param_name, param.shape)

hook_embed torch.Size([12769, 3, 128])
hook_pos_embed torch.Size([12769, 3, 128])
blocks.0.hook_resid_pre torch.Size([12769, 3, 128])
blocks.0.attn.hook_q torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_k torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_v torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_attn_scores torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_pattern torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_z torch.Size([12769, 3, 4, 32])
blocks.0.hook_attn_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_mid torch.Size([12769, 3, 128])
blocks.0.mlp.hook_pre torch.Size([12769, 3, 512])
blocks.0.mlp.hook_post torch.Size([12769, 3, 512])
blocks.0.hook_mlp_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_post torch.Size([12769, 3, 128])


In [413]:
# Attention from the "=" position to each source token, averaged over all inputs.
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [414]:
# Attention from the "=" position for a single input (dataset[5]), not an average -
# the title previously said "Average", which was misleading.
imshow(cache["pattern", 0][5][:, -1, :], title="Attention Pattern per Head for dataset[5]", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [415]:
# Peek at the first inputs: each row is (a, b, "=") as token ids.
dataset[:4]

tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113]], device='cuda:0')

In [416]:
# Head 0's attention from "=" to `a`, laid out as a p x p grid over (a, b).
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [417]:
# Attention from "=" to `a` for every head (one facet per head), each as a p x p grid
# over (a, b). The title previously named only head 0 although all heads are shown.
imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention from a -> = for each Head", xaxis="b", yaxis="a", facet_col=0)

Plot neuron activations:

In [418]:
# MLP activations for all positions: (batch, position, d_mlp).
cache["post", 0, "mlp"].shape

torch.Size([12769, 3, 512])

In [419]:
# Final-position activations of the first 5 neurons, each shown as a p x p grid over (a, b).
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [420]:
# (p, d_model): one embedding row per number.
W_E.shape

torch.Size([113, 128])

In [421]:
# SVD of the embedding matrix. torch.svd is deprecated and returns V (not V^H) as its
# third output; torch.linalg.svd returns a true Vh, and full_matrices=False matches
# torch.svd's default reduced decomposition.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [422]:
# Control - random Gaussian matrix of the same shape as W_E, to compare its spectrum
# against the embedding's. Uses torch.linalg.svd (torch.svd is deprecated).
U, S, Vh = torch.linalg.svd(torch.randn_like(W_E), full_matrices=False)
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [423]:
# Top 8 left singular vectors of W_E, plotted as functions of the input number.
# torch.linalg.svd replaces the deprecated torch.svd (same U for the reduced SVD).
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [424]:
# Fourier basis over the p inputs: a constant row, then a (Sin k, Cos k) pair for each
# frequency k = 1..p//2, so row 2k-1 is Sin k and row 2k is Cos k. For odd p this gives
# exactly p rows.
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+1):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).to(device)
# Scale every row to unit norm; together with mutual orthogonality this makes the basis orthonormal.
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)

In [425]:
# Plot a few basis vectors: the lowest frequencies, then some from the middle of the spectrum.
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [426]:
# Gram matrix of the basis; it should be (numerically) the identity.
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding in the Fourier Basis

In [427]:
# Embedding expressed in the Fourier basis: rows are Fourier components, columns residual-stream dims.
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [428]:
# Norm of each Fourier component across the residual stream (presumably where `key_freqs` below are read from).
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [429]:
# Key frequencies hard-coded for this particular training run (the commented-out values
# are from a different run); they likely need updating whenever the model is retrained.
key_freqs = [9, 33, 36, 38, 55]
# Basis row indices of Sin f (2f-1) and Cos f (2f) for each key frequency.
key_freq_indices = [i for f in key_freqs for i in (2*f - 1, 2*f)]
# key_freqs = [17, 25, 32, 47]
# key_freq_indices = [33, 34, 49, 50, 63, 64, 93, 94]
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
# Pairwise dot products between the embedding directions of the key Fourier terms.
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

key_fourier_embed torch.Size([10, 128])


### Key Frequencies

In [430]:
# Basis rows of Cos f for each key frequency. The legend now shows the component name
# (e.g. "Cos 9") instead of the raw basis row index (e.g. 18).
cos_indices = [2*f for f in key_freqs]
line(fourier_basis[cos_indices], title="Cos of key freqs", line_labels=[fourier_basis_names[i] for i in cos_indices])

In [431]:
# Mean of the key-frequency cos waves: every one equals its maximum at input 0, so they add up there.
line(fourier_basis[[2*f for f in key_freqs]].mean(0), title="Constructive Interference")

## Analyse Neurons

In [432]:
# Same plot as in the activations section, repeated here as the starting point for the neuron analysis.
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [433]:
# Neuron 0's final-position activation as a p x p grid over (a, b).
imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
    title="First neuron act", xaxis="b", yaxis="a",)

In [434]:
# Reference pattern for comparison with neuron plots: outer product of Cos f with itself
# for the first key frequency.
example_freq = key_freqs[0]
cos_idx = 2 * example_freq
imshow(fourier_basis[cos_idx][None, :] * fourier_basis[cos_idx][:, None], title=f"Cos {example_freq}a * cos {example_freq}b")

In [435]:
# Reference pattern: Cos f multiplied by the constant basis vector.
# NOTE(review): the cos factor varies along dim 1 (columns) here, while the title names `a`,
# and other cells put `a` on rows - confirm which axis is intended.
imshow(fourier_basis[cos_idx][None, :] * fourier_basis[0][:, None], title=f"Cos {example_freq}a * const")

In [436]:
# 2D Fourier transform of neuron 0's (a, b) activation grid: rows index Fourier components of a, columns of b.
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="Fourier Component b", yaxis="Fourier Component a", x=fourier_basis_names, y=fourier_basis_names)

In [437]:
# 2D Fourier transform of neuron 5's (a, b) activation grid: rows index Fourier components of a, columns of b.
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 5", xaxis="Fourier Component b", yaxis="Fourier Component a", x=fourier_basis_names, y=fourier_basis_names)

In [438]:
# Control: 2D Fourier transform of random noise with the same shape as one neuron's activations.
imshow(fourier_basis @ torch.randn_like(neuron_acts[:, 0]).reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of RANDOM", xaxis="Fourier Component b", yaxis="Fourier Component a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [439]:
# 2D Fourier transform of every neuron's (a, b) activation grid: (d_mlp, p, p).
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
# (The (Constant, Constant) entry is proportional to each neuron's mean activation.)
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

fourier_neuron_acts torch.Size([512, 113, 113])


In [440]:
# neuron_freq_norm[f - 1, n]: fraction of neuron n's mean-removed 2D Fourier energy lying in the
# terms built from {Constant, Sin f, Cos f} on both the a and the b axis.
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)
for freq in range(0, p//2):
    # Basis rows for frequency freq + 1: constant (0), sin (2k - 1), cos (2k).
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
# Normalize by each neuron's total Fourier energy.
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [441]:
# Each neuron's best single-frequency fraction, sorted; values near 1 mean one frequency dominates the neuron.
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

## Read Off the Neuron-Logit Weights to Interpret

In [442]:
# Neuron -> output-logit weights (identical to the W_logit computed in "Get key weight matrices").
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_logit torch.Size([512, 113])


In [443]:
# Express W_logit's output-token axis in the Fourier basis, then take the norm over neurons per component.
line((W_logit @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit in the Fourier Basis")

In [444]:
# Boolean mask over neurons whose Fourier energy is > 85% explained by `study_freq` (hand-picked threshold).
study_freq = key_freqs[0]
neurons_freq = neuron_freq_norm[study_freq-1]>0.85
neurons_freq.shape

torch.Size([512])

In [445]:
# Number of neurons assigned to study_freq.
neurons_freq.sum()

tensor(42, device='cuda:0')

In [446]:
# Output weights of only the study_freq neurons, expressed in the Fourier basis.
line((W_logit[neurons_freq] @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title=f"W_logit for freq {study_freq} neurons in the Fourier Basis")

Study the sin component of the first key frequency:

In [447]:
# Output weights of every neuron expressed in the Fourier basis: (d_mlp, p), column k = basis row k.
# This must use fourier_basis.T, as the W_logit Fourier plots above do; multiplying by
# fourier_basis itself treats the output-token axis as the component axis.
W_logit_fourier = W_logit @ fourier_basis.T
# Each neuron's weight on the Sin study_freq output direction (basis row 2f - 1).
neurons_sin_freq = W_logit_fourier[:, 2*study_freq-1]
line(neurons_sin_freq, xaxis="Neuron", title=f"Neuron output weights on Sin {study_freq}")

In [448]:
# (batch, d_mlp) final-position MLP activations.
neuron_acts.shape

torch.Size([12769, 512])

In [449]:
# Contribution of all neurons along the per-neuron weights `neurons_sin_freq`, for each input,
# then viewed as a 2D Fourier heatmap over (a, b).
inputs_sin_freq = neuron_acts @ neurons_sin_freq
imshow(fourier_basis @ inputs_sin_freq.reshape(p, p) @ fourier_basis.T, title=f"Fourier Heatmap over inputs for sin {study_freq} component", x=fourier_basis_names, y=fourier_basis_names)

# Black Box Methods + Progress Measures

## Setup Code

Code to plot embedding freqs

In [450]:
def embed_to_cos_sin(fourier_embed):
    """Split Fourier-basis coefficients into a stacked (sin, cos) pair.

    The constant term at index 0 is dropped. Odd indices (Sin k in this
    notebook's basis ordering) form slot 0 and even indices >= 2 (Cos k) form
    slot 1. A 1-D input is split along dim 0 and yields shape (2, n); for any
    higher-rank input, dim 1 is split and the pair is stacked along dim 1.
    """
    axis = 0 if fourier_embed.dim() == 1 else 1
    leading = (slice(None),) * axis
    sin_part = fourier_embed[leading + (slice(1, None, 2),)]
    cos_part = fourier_embed[leading + (slice(2, None, 2),)]
    return torch.stack([sin_part, cos_part], dim=axis)

from neel_plotly.plot import melt

def plot_embed_bars(
    fourier_embed,
    title="Norm of embedding of each Fourier Component",
    return_fig=False,
    **kwargs
):
    """Grouped bar chart with one sin bar and one cos bar per frequency.

    `fourier_embed` is split with embed_to_cos_sin, melted to long format (the
    melted columns "0" and "1" hold the sin/cos slot and the frequency index),
    and drawn with plotly express. Extra keyword arguments go to px.bar.
    Returns the figure when `return_fig` is True; otherwise shows it and returns None.
    """
    trig_names = {0: "sin", 1: "cos"}
    long_df = melt(embed_to_cos_sin(fourier_embed))
    long_df["Trig"] = long_df["0"].map(trig_names.__getitem__)
    figure = px.bar(
        long_df,
        x="1",
        y="value",
        color="Trig",
        barmode="group",
        labels={"1": "$w_k$", "value": "Norm"},
        title=title,
        **kwargs
    )
    figure.update_layout(legend_title="")
    if not return_fig:
        figure.show()
        return None
    return figure

Code to test a tensor of edited logits

In [451]:
def test_logits(logits, bias_correction=False, original_logits=None, mode="all"):
    """Cross-entropy loss of logits covering all p^2 possible (a, b) inputs.

    The batch dimension is assumed to come first. Some layouts are normalized to
    (p*p, p) before scoring: a (p, p*p) tensor is transposed, a (p*p, p+1)
    tensor has its last column (the "=" token's logit) dropped, and anything
    else with p**3 elements, e.g. an (a, b, c) cube of shape (p, p, p), is
    reshaped.

    Args:
        logits: logits for every input, in one of the layouts above.
        bias_correction: if True, correct for missing input-independent bias
            terms by centering `logits` along the batch dimension and adding the
            batch-mean of `original_logits`.
        original_logits: reference logits, required when `bias_correction` is
            True; must broadcast against the normalized (p*p, p) logits.
        mode: "train", "test" or "all" - which inputs (via the global
            `train_indices` / `test_indices`) the loss is computed on.

    Returns:
        The value of the global `loss_fn` on the selected inputs.

    Raises:
        ValueError: if `mode` is not "train", "test" or "all" (it used to fall
            through and return None), or if `bias_correction` is requested
            without `original_logits`.
    """
    if mode not in ("train", "test", "all"):
        raise ValueError(f"mode must be 'train', 'test' or 'all', got {mode!r}")
    if bias_correction and original_logits is None:
        raise ValueError("bias_correction=True requires original_logits")
    if logits.shape[1] == p * p:
        logits = logits.T
    if logits.shape == torch.Size([p * p, p + 1]):
        logits = logits[:, :-1]
    logits = logits.reshape(p * p, p)
    if bias_correction:
        # mean(original - logits) + logits == logits - mean(logits) + mean(original):
        # swap the edited logits' per-class batch mean for the original one.
        logits = (
            einops.reduce(original_logits - logits, "batch ... -> ...", "mean") + logits
        )
    if mode == "train":
        return loss_fn(logits[train_indices], labels[train_indices])
    if mode == "test":
        return loss_fn(logits[test_indices], labels[test_indices])
    return loss_fn(logits, labels)

Code to run a metric over every checkpoint

In [452]:
# Filled by get_metrics: metric name -> tensor of that metric evaluated at every checkpoint.
metric_cache = {}

In [453]:
def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    """Evaluate `metric_fn` on every training checkpoint and cache the results.

    For each state dict in the global `model_checkpoints`, hooks are reset, the
    weights are loaded into `model`, and `metric_fn(model)` is recorded (tensor
    results are converted to numpy first). The per-checkpoint results are then
    stacked into a tensor stored at `metric_cache[name]`.

    Nothing is recomputed when a non-empty entry already exists, unless `reset`
    is True. The weights `model` had before the call are restored afterwards,
    even if `metric_fn` raises, so later cells keep analysing the same model.
    If `metric_fn` raises, `metric_cache[name]` is left empty so the next call
    recomputes rather than trusting a partial result.

    Args:
        model: model whose weights are swapped for each checkpoint.
        metric_cache: dict that receives the results.
        metric_fn: callable taking the model and returning a tensor/array/number.
        name: key under which the results are cached.
        reset: force recomputation even if cached.

    Returns:
        None; results are written into `metric_cache[name]`.
    """
    already_cached = (
        not reset and name in metric_cache and len(metric_cache[name]) > 0
    )
    if already_cached:
        return
    metric_cache[name] = []
    # load_state_dict copies into the existing parameters, so clone to keep a real snapshot.
    original_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
    results = []
    try:
        for state_dict in tqdm.tqdm(model_checkpoints):
            model.reset_hooks()
            model.load_state_dict(state_dict)
            out = metric_fn(model)
            if isinstance(out, torch.Tensor):
                out = utils.to_numpy(out)
            results.append(out)
    finally:
        model.load_state_dict(original_state)
    try:
        metric_cache[name] = torch.tensor(results)
    except (TypeError, ValueError, RuntimeError):
        # e.g. a list of numpy arrays torch cannot stack directly
        metric_cache[name] = torch.tensor(np.array(results))



## Defining Progress Measures

### Loss Curves

In [454]:
# Epochs separating the training phases (memorization, circuit formation, cleanup), hard-coded
# for this run (presumably read off the loss curves); used by add_lines to draw guide lines.
memorization_end_epoch = 1500
circuit_formation_end_epoch = 13300
cleanup_end_epoch = 16600

In [455]:
def add_lines(figure):
    """Draw a dashed vertical line at each training-phase boundary epoch.

    Reads the module-level `memorization_end_epoch`, `circuit_formation_end_epoch`
    and `cleanup_end_epoch`, and returns the same figure so the call can be the
    last expression of a cell.
    """
    phase_boundaries = (memorization_end_epoch, circuit_formation_end_epoch, cleanup_end_epoch)
    for boundary in phase_boundaries:
        figure.add_vline(boundary, line_dash="dash", opacity=0.7)
    return figure

In [456]:
# Training curve with the phase boundaries overlaid.
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

### Logit Periodicity

In [457]:
# Final-position logits reshaped to an (a, b, c) cube, where c indexes the candidate answer.
all_logits = original_logits[:, -1, :]
print(all_logits.shape)
all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

torch.Size([12769, 113])
torch.Size([113, 113, 113])


In [458]:
# Unit-norm cos(w(a + b - c)) cube for each key frequency w: the logit pattern the Fourier
# algorithm predicts (maximal where c == (a + b) mod p). The a/b/c index grids do not
# depend on the frequency, so they are built once outside the loop.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
coses = {}
for freq in key_freqs:
    print("Freq:", freq)
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

Freq: 9
Freq: 33
Freq: 36
Freq: 38
Freq: 55


In [459]:
# Project the logit cube onto each key frequency's unit-norm cos cube. `coeff` is the dot
# product; dividing by the logit norm gives a cosine similarity. The residual measures how
# much of the logits these few directions fail to explain.
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)
    coeff = (all_logits * coses[freq]).sum()
    print("Coeff:", coeff)
    cosine_sim = coeff / all_logits.norm()
    print("Cosine Sim:", cosine_sim)
    approximated_logits += coeff * coses[freq]
residual = all_logits - approximated_logits
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.norm()/all_logits.norm())

Freq: 9
Coeff: tensor(8933.3936, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.5159, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 33
Coeff: tensor(3771.6543, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.2178, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 36
Coeff: tensor(11305.8154, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.6529, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 38
Coeff: tensor(7328.0176, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.4232, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 55
Coeff: tensor(3708.1748, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.2141, device='cuda:0', grad_fn=<DivBackward0>)
Residual size: tensor(3251.7063, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Residual fraction of norm: tensor(0.1878, device='cuda:0', grad_fn=<DivBackward0>)


In [460]:
# Control: cosine similarity between the logits and a random cube (expected to be ~0).
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum()/random_logit_cube.norm()/all_logits.norm())

tensor(-0.0011, device='cuda:0', grad_fn=<DivBackward0>)


In [461]:
# Loss of the unmodified logit cube.
test_logits(all_logits)

tensor(1.1639e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)

In [462]:
# Loss using only the key-frequency cos approximation of the logits.
test_logits(approximated_logits)

tensor(3.2529e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)

#### Look During Training

In [463]:
# One unit-norm cos(w(a + b - c)) cube per frequency w = 1..p//2, stacked to (p//2, p, p, p).
# The index grids and a + b - c are frequency-independent, so they are computed once.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
a_plus_b_minus_c = a + b - c
cos_cube = []
for freq in range(1, p//2 + 1):
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * a_plus_b_minus_c).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

torch.Size([56, 113, 113, 113])


In [464]:
def get_cos_coeffs(model):
    """Dot products of the model's logit cube with every cos(w(a+b-c)) cube.

    Runs `model` on the global `dataset`, reshapes the final-position logits to
    an (a, b, c) cube using the global `p`, and returns one coefficient per
    frequency in the global `cos_cube` (shape (cos_cube.shape[0],)).
    """
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    # einsum contracts directly; the broadcast `cos_cube * logits[None]` would first
    # materialize a full (num_freqs, p, p, p) temporary on the device.
    return torch.einsum("fabc,abc->f", cos_cube, logits)


# Evaluate get_cos_coeffs at every checkpoint: (num_checkpoints, p//2).
get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print(metric_cache["cos_coeffs"].shape)

100%|██████████| 250/250 [00:02<00:00, 90.50it/s] 

torch.Size([250, 56])





In [465]:
# Coefficient of each frequency's predicted-logit cube over training, with phase boundaries.
fig = line(metric_cache["cos_coeffs"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Coefficients with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Coefficient", return_fig=True)
add_lines(fig)

In [466]:
def get_cos_sim(model):
    """Cosine similarity of the model's logit cube with every cos(w(a+b-c)) cube.

    Same as get_cos_coeffs but divided by the logit cube's norm; since every
    cube in the global `cos_cube` has unit norm, this is a cosine similarity.
    """
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    # einsum avoids materializing the (num_freqs, p, p, p) broadcast product, which is
    # what made this metric need a large GPU.
    coeffs = torch.einsum("fabc,abc->f", cos_cube, logits)
    return coeffs / logits.norm()

# Per-frequency cosine similarity with the predicted logits, tracked across checkpoints.
get_metrics(model, metric_cache, get_cos_sim, "cos_sim") # You may need a big GPU. If you don't have one and can't work around this, raise an issue for help!
print(metric_cache["cos_sim"].shape)

fig = line(metric_cache["cos_sim"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Cosine Sim with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

  0%|          | 0/250 [00:00<?, ?it/s]

100%|██████████| 250/250 [00:02<00:00, 102.25it/s]


torch.Size([250, 56])


In [467]:
def get_residual_cos_sim(model):
    """Fraction of the logit cube's norm not explained by the cos(w(a+b-c)) cubes.

    Projects the model's (a, b, c) logit cube onto all cubes in the global
    `cos_cube` (unit norm), subtracts that reconstruction, and returns
    ||residual|| / ||logits||.
    """
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    # Both contractions via einsum, so no (num_freqs, p, p, p) temporary is allocated.
    coeffs = torch.einsum("fabc,abc->f", cos_cube, logits)
    reconstruction = torch.einsum("f,fabc->abc", coeffs, cos_cube)
    residual = logits - reconstruction
    return residual.norm() / logits.norm()

# Track the unexplained fraction across checkpoints and plot it next to the per-frequency cosine sims.
get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
print(metric_cache["residual_cos_sim"].shape)

fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]], line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"], title="Cosine Sim with Predicted Logits + Residual", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

100%|██████████| 250/250 [00:03<00:00, 75.73it/s]


torch.Size([250])


## Restricted Loss

In [468]:
# (batch, d_mlp) final-position MLP activations.
neuron_acts.shape

torch.Size([12769, 512])

In [469]:
# Neuron activations as an (a, b, neuron) grid (a copy, so neuron_acts is untouched).
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
# Center it
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
# 2D Fourier transform over (a, b) for every neuron; the plot shows the norm across neurons per term.
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)
imshow(neuron_acts_square_fourier.norm(dim=-1), xaxis="Fourier Component b", yaxis="Fourier Component a", title="Norms of neuron activations by Fourier Component", x=fourier_basis_names, y=fourier_basis_names)

In [470]:
# Re-run the model so original_logits, cache and neuron_acts reflect its current weights
# (get_metrics above swaps checkpoint weights in and out of `model`).
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())
neuron_acts = cache["post", 0, "mlp"][:, -1, :]

4328691


In [471]:
# Restricted loss on the current model: keep only each neuron's mean activation plus its
# projection onto the unit-norm cos(w(a+b)) and sin(w(a+b)) directions for every key
# frequency w, map that through W_logit and score it on the test set.
# NOTE(review): unlike get_restricted_loss below, no bias correction is applied here,
# so the printed value differs slightly from that function's.
approx_neuron_acts = torch.zeros_like(neuron_acts)
approx_neuron_acts += neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    # Unit-norm direction over the flattened (a b) batch, as a column vector.
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    # Per-neuron projection coefficient times the direction.
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
restricted_logits = approx_neuron_acts @ W_logit
print(loss_fn(restricted_logits[test_indices], test_labels))

tensor(8.9941e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)


In [472]:
# Loss of the full logit cube flattened back to (batch, p), for comparison with the restricted loss.
print(loss_fn(all_logits.reshape(-1, all_logits.shape[-1]), labels))

tensor(1.1639e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)


### Look During Training

In [473]:
def get_restricted_loss(model):
    """Test loss when the MLP neurons keep only their mean plus key-frequency terms.

    Runs `model` on the global `dataset` and takes the final-position MLP
    activations. It keeps each neuron's mean plus its projections onto the
    unit-norm cos(w(a+b)) and sin(w(a+b)) directions for every w in the global
    `key_freqs`, maps the result through W_out and W_U, and shifts each class
    so its batch mean matches the real logits (bias correction). The loss on
    the test inputs is returned.
    """
    full_logits, run_cache = model.run_with_cache(dataset)
    final_logits = full_logits[:, -1, :]
    acts = run_cache["post", 0, "mlp"][:, -1, :]
    kept_acts = torch.zeros_like(acts)
    kept_acts += acts.mean(dim=0)
    a_grid = torch.arange(p)[:, None]
    b_grid = torch.arange(p)[None, :]
    for freq in key_freqs:
        angle = freq * 2 * torch.pi / p * (a_grid + b_grid)
        for wave in (torch.cos, torch.sin):
            direction = wave(angle).to(device)
            direction /= direction.norm()
            direction = einops.rearrange(direction, "a b -> (a b) 1")
            kept_acts += (acts * direction).sum(dim=0) * direction
    restricted = kept_acts @ model.blocks[0].mlp.W_out @ model.unembed.W_U
    # Bias term: give every class the same batch mean as the real logits.
    restricted += final_logits.mean(dim=0, keepdim=True) - restricted.mean(dim=0, keepdim=True)
    return loss_fn(restricted[test_indices], test_labels)
get_restricted_loss(model)

tensor(7.7158e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)

In [474]:
# Restricted loss at every checkpoint (reset=True forces recomputation even if cached).
get_metrics(model, metric_cache, get_restricted_loss, "restricted_loss", reset=True)
print(metric_cache["restricted_loss"].shape)

100%|██████████| 250/250 [00:03<00:00, 65.39it/s]

torch.Size([250])





In [475]:
# Train/test loss alongside the restricted loss, with phase boundaries.
# NOTE(review): train/test are sampled at epochs 0, 100, ... while restricted_loss has one value
# per checkpoint; this assumes both series have the same length and line up - confirm against checkpoint_epochs.
fig = line([train_losses[::100], test_losses[::100], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss Curve", line_labels=['train', 'test', "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

In [476]:
# Ratio test loss / restricted loss at each checkpoint. Test losses are looked up at the
# exact checkpoint epochs, so both series stay aligned and equally long even when training
# was stopped early (slicing test_losses[::100] broke in that case). The title previously
# named the ratio the other way round.
test_losses_at_ckpts = torch.tensor([test_losses[epoch] for epoch in checkpoint_epochs])
fig = line([test_losses_at_ckpts / metric_cache["restricted_loss"]], x=checkpoint_epochs, xaxis="Epoch", yaxis="Test Loss / Restricted Loss", log_y=True, title="Test Loss to Restricted Loss Ratio", toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

## Excluded Loss

In [477]:
# Excluded loss on the current model: remove each neuron's projection onto the key-frequency
# cos(w(a+b)) / sin(w(a+b)) directions and score what remains on the training set.
# The mean is deliberately not added to approx_neuron_acts, so it stays in the excluded part.
# NOTE(review): this maps only the excluded neuron activations through W_logit and omits the
# resid_mid path that get_excluded_loss below adds back, which is why the two losses differ.
approx_neuron_acts = torch.zeros_like(neuron_acts)
# approx_neuron_acts += neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    # Unit-norm direction over the flattened (a b) batch, as a column vector.
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    # Per-neuron projection coefficient times the direction.
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
excluded_neuron_acts = neuron_acts - approx_neuron_acts
excluded_logits = excluded_neuron_acts @ W_logit
print(loss_fn(excluded_logits[train_indices], train_labels))

tensor(6.5511, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [478]:
def get_excluded_loss(model):
    """Train loss after removing the key-frequency (a+b) components from the MLP neurons.

    Runs `model` on the global `dataset` and takes the final-position MLP
    activations. It subtracts their projection onto the unit-norm cos(w(a+b))
    and sin(w(a+b)) directions for every w in the global `key_freqs`; each
    neuron's mean is kept. It then rebuilds the final residual stream as
    resid_mid + excluded_acts @ W_out + b_out, unembeds it with W_U + b_U, and
    returns the loss on the training inputs.

    b_out and b_U are included so that, with nothing excluded, the rebuilt
    logits equal the model's own logits; the original omitted them and was only
    exact for zero biases. This assumes no final LayerNorm, as the original
    computation did (TODO confirm against model.cfg).
    """
    _, cache = model.run_with_cache(dataset)
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]
    key_freq_acts = torch.zeros_like(neuron_acts)
    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        for trig in (torch.cos, torch.sin):
            # Unit-norm direction over the flattened (a b) batch, as a column vector.
            direction = trig(freq * 2 * torch.pi / p * (a + b)).to(device)
            direction /= direction.norm()
            direction = einops.rearrange(direction, "a b -> (a b) 1")
            key_freq_acts += (neuron_acts * direction).sum(dim=0) * direction
    excluded_neuron_acts = neuron_acts - key_freq_acts
    mlp = model.blocks[0].mlp
    mlp_out = excluded_neuron_acts @ mlp.W_out + mlp.b_out
    residual_stream_final = cache["resid_mid", 0][:, -1, :] + mlp_out
    excluded_logits = residual_stream_final @ model.unembed.W_U + model.unembed.b_U
    return loss_fn(excluded_logits[train_indices], train_labels)
get_excluded_loss(model)

tensor(4.7765, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)

In [479]:
# Excluded loss at every checkpoint (reset=True forces recomputation even if cached).
get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
print(metric_cache["excluded_loss"].shape)

100%|██████████| 250/250 [00:03<00:00, 64.82it/s]

torch.Size([250])





In [480]:
# Train/test loss together with the excluded and restricted losses, with phase boundaries.
fig = line([train_losses[::100], test_losses[::100], metric_cache["excluded_loss"], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Excluded and Restricted Loss Curve", line_labels=['train', 'test', "excluded_loss", "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)

add_lines(fig)