# Setup


In [None]:
# Colab form flag. It is not read anywhere in the visible part of this notebook
# (presumably it gates mounting Google Drive further down — TODO confirm).
use_drive = False #@param

In [None]:
!nvidia-smi
!pip install einops

Fri Oct  7 04:55:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    32W /  70W |   1034MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

# from google.colab import drive
from pathlib import Path
import pickle
import os

import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc

# import comet_ml
import itertools

## Defining Transformer

In [None]:
# A helper class to get access to intermediate activations (inspired by Garcon)
# It's a dummy module that is the identity function by default
# I can wrap any intermediate activation in a HookPoint and get a convenient 
# way to add PyTorch hooks
class HookPoint(nn.Module):
    """Identity module that marks an activation so hooks can be attached to it.

    Wrapping an intermediate activation in a HookPoint leaves the value
    unchanged (forward returns its input) but provides a named place where
    PyTorch forward/backward hooks can be registered and later removed.
    """
    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []
        # Filled in by give_name(); initialised so that reading .name (e.g. from
        # a hook attached before naming) does not raise AttributeError.
        self.name = None
    
    def give_name(self, name):
        """Record the name of this hook point; called by the model at initialisation."""
        self.name = name
    
    def add_hook(self, hook, dir='fwd'):
        """Register hook on this hook point.

        Args:
            hook: callable invoked as hook(activation, name=hook_point_name).
                Its return value is passed back to PyTorch as the hook result.
            dir: 'fwd' for a forward hook, 'bwd' for a backward hook.

        Raises:
            ValueError: if dir is neither 'fwd' nor 'bwd'.
        """
        if dir not in ('fwd', 'bwd'):
            raise ValueError(f"Invalid direction {dir}")
        # Adapt fn(activation, name) to PyTorch's three-argument hook format.
        # For a forward hook the third argument is the module output; for a
        # backward hook PyTorch passes the gradient w.r.t. the output there.
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)
        if dir=='fwd':
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        else:
            # register_backward_hook is kept to preserve existing behaviour;
            # newer PyTorch also offers register_full_backward_hook.
            handle = self.register_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
    
    def remove_hooks(self, dir='fwd'):
        """Remove the hooks registered through add_hook.

        Args:
            dir: 'fwd', 'bwd' or 'both'.

        Raises:
            ValueError: if dir is not one of the three accepted values.
        """
        if dir not in ('fwd', 'bwd', 'both'):
            raise ValueError(f"Invalid direction {dir}")
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []
    
    def forward(self, x):
        """Return x unchanged."""
        return x

In [None]:
# Define network architecture
# I defined my own transformer from scratch so I'd fully understand each component 
# - I expect this wasn't necessary or particularly important, and a bunch of this 
# replicates existing PyTorch functionality

# Embed & Unembed
class Embed(nn.Module):
    """Learned token embedding stored as a (d_model, d_vocab) matrix W_E."""
    def __init__(self, d_vocab, d_model):
        super().__init__()
        initial = torch.randn(d_model, d_vocab)/np.sqrt(d_model)
        self.W_E = nn.Parameter(initial)
    
    def forward(self, x):
        # Selecting columns of W_E with a (batch, pos) index tensor gives
        # (d_model, batch, pos); the einsum moves d_model to the last axis.
        selected = self.W_E[:, x]
        return torch.einsum('dbp -> bpd', selected)

class Unembed(nn.Module):
    """Linear map from the residual stream to logits via a (d_model, d_vocab) matrix W_U."""
    def __init__(self, d_vocab, d_model):
        super().__init__()
        initial = torch.randn(d_model, d_vocab)/np.sqrt(d_vocab)
        self.W_U = nn.Parameter(initial)
    
    def forward(self, x):
        # Contracts the last axis of x (d_model) against W_U.
        return torch.matmul(x, self.W_U)

# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned positional embedding added to the input; W_pos is (max_ctx, d_model)."""
    def __init__(self, max_ctx, d_model):
        super().__init__()
        initial = torch.randn(max_ctx, d_model)/np.sqrt(d_model)
        self.W_pos = nn.Parameter(initial)
    
    def forward(self, x):
        # Only the first n_positions rows are used, where n_positions is the
        # size of the second-to-last axis of x.
        n_positions = x.shape[-2]
        return x + self.W_pos[:n_positions]

# LayerNorm
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and bias.

    Args:
        d_model: size of the normalised (last) axis.
        epsilon: constant added to the standard deviation before dividing.
        model: optional one-element list holding the owning model. Passing it
            inside a list keeps it from being registered as a submodule. If
            the owner has use_ln set to a falsy value, forward is the
            identity. With no owner, or an owner without a use_ln attribute,
            normalisation is applied.
    """
    def __init__(self, d_model, epsilon = 1e-4, model=None):
        super().__init__()
        # A fresh holder per instance; a list default argument would be shared
        # between every LayerNorm constructed without an explicit model.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon
    
    def forward(self, x):
        owner = self.model[0]
        if not getattr(owner, 'use_ln', True):
            return x
        x = x - x.mean(axis=-1)[..., None]
        # Tensor.std defaults to the unbiased estimator, and epsilon is added
        # to the standard deviation (not the variance) as in the original code.
        x = x / (x.std(axis=-1)[..., None] + self.epsilon)
        x = x * self.w_ln
        x = x + self.b_ln
        return x

# Attention
class Attention(nn.Module):
    """Multi-head causal self-attention with a HookPoint on every intermediate.

    Weights are stored per head: W_K, W_Q, W_V are (num_heads, d_head, d_model)
    and W_O is (num_heads, d_model, d_head). The per-head outputs are summed
    to give the layer output.
    """
    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        init_scale = np.sqrt(d_model)
        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model)/init_scale)
        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model)/init_scale)
        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model)/init_scale)
        self.W_O = nn.Parameter(torch.randn(num_heads, d_model, d_head)/init_scale)
        # Lower-triangular matrix of ones: entry [q, p] is 1 when p <= q.
        causal_mask = torch.tril(torch.ones((n_ctx, n_ctx)))
        self.register_buffer('mask', causal_mask)
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()
        self.hook_result = HookPoint()

    def forward(self, x):
        seq_len = x.shape[-2]
        # x is indexed (batch, pos, d_model) by the einsum patterns below;
        # k, q, v come out as (batch, head, pos, d_head).
        k = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        q = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        v = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        # Scores are (batch, head, query_pos, key_pos).
        scores = torch.einsum('biph,biqh->biqp', k, q)
        # Zero the upper triangle, then push those entries to -1e10 so that
        # they get (numerically) zero weight after the softmax.
        visible = self.mask[:seq_len, :seq_len]
        masked_scores = torch.tril(scores) - 1e10 * (1 - visible)
        scaled_scores = self.hook_attn_pre(masked_scores/np.sqrt(self.d_head))
        attn_matrix = self.hook_attn(F.softmax(scaled_scores, dim=-1))
        z = self.hook_z(torch.einsum('biph,biqp->biqh', v, attn_matrix))
        result = self.hook_result(torch.einsum('idh,biqh->biqd', self.W_O, z))
        # Sum the per-head results.
        return einops.reduce(result, 
                             'batch index position model->batch position model', 
                             'sum')

# MLP Layers
class MLP(nn.Module):
    """Two-layer feed-forward block with a ReLU or GeLU non-linearity.

    W_in is (d_mlp, d_model) and W_out is (d_model, d_mlp); hook_pre and
    hook_post expose the hidden activations before and after the
    non-linearity.
    """
    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        self.model = model
        init_scale = np.sqrt(d_model)
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/init_scale)
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/init_scale)
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        # self.ln = LayerNorm(d_mlp, model=self.model)
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()
        assert act_type in ['ReLU', 'GeLU']
        
    def forward(self, x):
        hidden = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)
        # Look the non-linearity up by name; an unrecognised name leaves the
        # activations untouched, as in the if/elif form.
        activation = {'ReLU': F.relu, 'GeLU': F.gelu}.get(self.act_type)
        if activation is not None:
            hidden = activation(hidden)
        hidden = self.hook_post(hidden)
        return torch.einsum('dm,bpm->bpd', self.W_out, hidden) + self.b_out

# Transformer Block
class TransformerBlock(nn.Module):
    """One residual block: attention, then (unless attn_only) an MLP.

    Both sub-layers are added onto the residual stream, and HookPoints expose
    the stream before, between and after them as well as each sub-layer's
    output.
    """
    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, attn_only, model):
        super().__init__()
        self.model = model
        # self.ln1 = LayerNorm(d_model, model=self.model)
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        # self.ln2 = LayerNorm(d_model, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_post = HookPoint()
        self.attn_only = attn_only
        if not self.attn_only:
            self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)
            self.hook_mlp_out = HookPoint()
            self.hook_resid_mid = HookPoint()
    
    def forward(self, x):
        # The skip connection adds onto x itself; the hook_resid_* outputs
        # only feed the sub-layers.
        attn_out = self.hook_attn_out(self.attn(self.hook_resid_pre(x)))
        x = x + attn_out
        if self.attn_only:
            return self.hook_resid_post(x)
        mlp_out = self.hook_mlp_out(self.mlp(self.hook_resid_mid(x)))
        x = x + mlp_out
        return self.hook_resid_post(x)

# Full transformer
class Transformer(nn.Module):
    """Transformer without layer norm: embed, optional positions, blocks, unembed.

    Args:
        num_layers: number of TransformerBlocks.
        d_vocab: input vocabulary size.
        d_model: residual stream width.
        d_mlp: hidden width of each MLP.
        d_head: width of each attention head.
        num_heads: attention heads per block.
        n_ctx: maximum sequence length (sizes the positional table and mask).
        act_type: 'ReLU' or 'GeLU', forwarded to the MLPs.
        attn_only: if True the blocks contain no MLP.
        d_vocab_out: output vocabulary size; defaults to d_vocab.
        use_pos: whether to add learned positional embeddings.
    """
    def __init__(self, 
                 num_layers, 
                 d_vocab, 
                 d_model, 
                 d_mlp, 
                 d_head, 
                 num_heads, 
                 n_ctx, 
                 act_type, 
                 attn_only=False,
                 d_vocab_out=None,
                 use_pos=True):
        super().__init__()
        self.cache = {}
        self.attn_only = attn_only

        self.embed = Embed(d_vocab, d_model)
        self.use_pos = use_pos
        if self.use_pos:
            self.pos_embed = PosEmbed(n_ctx, d_model)
        # The model is handed to each block inside a list so that it is not
        # registered as a submodule of its own blocks.
        layers = []
        for _ in range(num_layers):
            layers.append(TransformerBlock(d_model, d_mlp, d_head, num_heads,
                                           n_ctx, act_type, attn_only, model=[self]))
        self.blocks = nn.ModuleList(layers)
        # self.ln = LayerNorm(d_model, model=[self])
        self.unembed = Unembed(d_vocab if d_vocab_out is None else d_vocab_out, d_model)

        # Tell every hook point its dotted module path.
        for module_name, module in self.named_modules():
            if type(module) is HookPoint:
                module.give_name(module_name)
    
    def forward(self, x):
        resid = self.embed(x)
        if self.use_pos:
            resid = self.pos_embed(resid)
        for block in self.blocks:
            resid = block(resid)
        # x = self.ln(x)
        return self.unembed(resid)
    
    def hook_points(self):
        """Return every submodule whose dotted name contains 'hook'."""
        found = []
        for module_name, module in self.named_modules():
            if 'hook' in module_name:
                found.append(module)
        return found

    def remove_all_hooks(self):
        """Remove forward and backward hooks from every hook point."""
        for hp in self.hook_points():
            for direction in ('fwd', 'bwd'):
                hp.remove_hooks(direction)
    
    def cache_all(self, cache, incl_bwd=False):
        """Attach hooks that store each hooked activation in cache.

        Activations are stored detached under the hook point's name. With
        incl_bwd, the first element of what the backward hook receives is
        also stored, under name + '_grad'.
        """
        def save_hook(tensor, name):
            cache[name] = tensor.detach()
        def save_hook_back(tensor, name):
            cache[name+'_grad'] = tensor[0].detach()
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')

## Helper Functions

In [None]:
def cuda_memory():
    """Print torch.cuda.memory_allocated() divided by 1e9 (i.e. in GB)."""
    allocated_gb = torch.cuda.memory_allocated()/1e9
    print(allocated_gb)

def cos(x, y):
    """Cosine similarity of two 1-D tensors: x.y / (|x| |y|)."""
    inner = torch.dot(x, y)
    return inner/x.norm()/y.norm()


In [None]:
#Plotting functions
# This is mostly a bunch of over-engineered mess to hack Plotly into producing 
# the pretty pictures I want, I recommend not reading too closely unless you 
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Args:
        tensor: a torch.Tensor (including subclasses such as nn.Parameter) or
            any other object.
        flat: if True, flatten the tensor to 1-D before converting.

    Returns:
        A NumPy array for tensor input (detached and moved to the CPU first),
        otherwise the input object unchanged.
    """
    # isinstance rather than an exact type comparison, so that nn.Parameter
    # and other Tensor subclasses are converted instead of being returned as is.
    if not isinstance(tensor, torch.Tensor):
        return tensor
    if flat:
        tensor = tensor.flatten()
    return tensor.detach().cpu().numpy()
def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    """Show a heatmap of tensor with plotly express.

    Args:
        tensor: torch tensor or array-like; size-1 axes are squeezed out.
        xaxis, yaxis: axis labels.
        animation_name: label for the animation slider, used when the caller
            passes animation_frame through kwargs.
        **kwargs: forwarded to px.imshow.
    """
    if isinstance(tensor, torch.Tensor):
        image = to_numpy(torch.squeeze(tensor), flat=False)
    else:
        # Also accept NumPy arrays / nested lists, which torch.squeeze rejects.
        image = np.squeeze(np.asarray(tensor))
    # plotly's label key for the slider is 'animation_frame' (the same key
    # imshow_fourier uses); 'animation_name' is not a key plotly reads.
    px.imshow(image, 
              labels={'x':xaxis, 'y':yaxis, 'animation_frame':animation_name}, 
              **kwargs).show()
# Rebind imshow to a version whose default colour scale is 'Blues'. The presets
# below are built on top of this rebound imshow, so this line must run first.
imshow = partial(imshow, color_continuous_scale='Blues')
# Preset for diverging data (both positive and negative values): 'RdBu' scale
# centred on 0.0. The keywords given here override the 'Blues' default above.
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Preset for heatmaps of activations over pairs of inputs: the diverging scale
# as above, with the x axis labelled 'Input 1' and the y axis 'Input 2'.
inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    """Plot a single line with plotly express and show it.

    Tensor inputs are flattened to NumPy arrays; other inputs are passed to
    px.line as they are.
    """
    # to_numpy leaves non-tensor inputs untouched, so it can be applied to
    # both arguments without a separate type check.
    y = to_numpy(y, flat=True)
    x = to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
def scatter(x, y, **kwargs):
    """Show a scatter plot of y against x (tensors are flattened first)."""
    xs = to_numpy(x, flat=True)
    ys = to_numpy(y, flat=True)
    fig = px.scatter(x=xs, y=ys, **kwargs)
    fig.show()

def histogram(x, nbins=100, **kwargs):
    """Show a histogram of x with nbins bins."""
    values = to_numpy(x)
    fig = px.histogram(values, nbins=nbins, **kwargs)
    fig.show()
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    """Plot several series on one plotly figure and show it.

    Args:
        lines_list: list of series, or a tensor whose first axis indexes series.
        x: shared x values; defaults to 0..len(first series)-1.
        mode: plotly trace mode.
        labels: trace names; the series index is used when omitted.
        xaxis, yaxis, title: figure text.
        log_y: use a logarithmic y axis.
        hover: hover text passed to every trace.
        **kwargs: forwarded to go.Scatter.
    """
    if type(lines_list) is torch.Tensor:
        # Split along the first axis into one tensor per series.
        lines_list = list(lines_list.unbind(0))
    if x is None:
        x = np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for idx, series in enumerate(lines_list):
        trace_name = idx if labels is None else labels[idx]
        # to_numpy passes non-tensor series through unchanged.
        fig.add_trace(go.Scatter(x=x, y=to_numpy(series), mode=mode,
                                 name=trace_name, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()
def line_marker(x, **kwargs):
    """Plot a single series with both lines and markers."""
    single_series = [x]
    lines(single_series, mode='lines+markers', **kwargs)
def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', **kwargs):
    """Animate a single line over snapshots.

    lines_list is indexed [snapshot, x] (a list of tensors is stacked along
    a new first axis). One animation frame is produced per snapshot.
    """
    if type(lines_list) is list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        # Repeat the hover labels once per snapshot.
        hover = [label for _ in range(len(snapshot_index)) for label in hover]
    print(lines_list.shape)
    # Long-format table: one row per (snapshot, x position).
    rows = [
        [lines_list[snap][pos], snapshot_index[snap], pos]
        for snap in range(lines_list.shape[0])
        for pos in range(lines_list.shape[1])
    ]
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    y_range = [lines_list.min(), lines_list.max()]
    fig = px.line(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_y=y_range, hover_name=hover,**kwargs)
    fig.show()

def imshow_fourier(tensor, title='', animation_name='snapshot', facet_labels=None, **kwargs):
    """Show a heatmap of a tensor that is already in the 2D Fourier basis.

    Args:
        tensor: values to plot; both axes are labelled with fourier_basis_names.
        title: figure title.
        animation_name: label for the animation slider.
        facet_labels: optional list of facet titles, written over the figure's
            annotations in order. None (the default) leaves them untouched.
        **kwargs: forwarded to px.imshow.
    """
    # NOTE(review): `p` and `unflatten_first` only appear as commented-out code
    # in the visible part of this notebook; unless they are defined elsewhere,
    # this check raises NameError — confirm before relying on this function.
    if tensor.shape[0]==p*p:
        tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    fig=px.imshow(to_numpy(tensor),
            x=fourier_basis_names, 
            y=fourier_basis_names, 
            labels={'x':'x Component', 
                    'y':'y Component', 
                    'animation_frame':animation_name},
            title=title,
            color_continuous_midpoint=0., 
            color_continuous_scale='RdBu', 
            **kwargs)
    fig.update(data=[{'hovertemplate':"%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    # A None default avoids sharing one mutable list between calls.
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    fig.show()

def animate_multi_lines(lines_list, y_index=None, snapshot_index = None, snapshot='snapshot', hover=None, swap_y_animate=False, **kwargs):
    """Animate several lines at once.

    lines_list is indexed [snapshot, line, x] (a list of tensors is stacked
    along a new first axis); swap_y_animate swaps the first two axes. Each
    line becomes a column of the plotted table, named by y_index.
    """
    if type(lines_list) is list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if swap_y_animate:
        lines_list = lines_list.transpose(1, 0, 2)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if y_index is None:
        y_index = [str(line_no) for line_no in range(lines_list.shape[1])]
    if hover is not None:
        # Repeat the hover labels once per snapshot.
        hover = [label for _ in range(len(snapshot_index)) for label in hover]
    print(lines_list.shape)
    # One row per (snapshot, x position): all line values, then the snapshot
    # id and the x position.
    rows = [
        list(lines_list[snap, :, pos]) + [snapshot_index[snap], pos]
        for snap in range(lines_list.shape[0])
        for pos in range(lines_list.shape[2])
    ]
    df = pd.DataFrame(rows, columns=y_index+[snapshot, 'x'])
    y_range = [lines_list.min(), lines_list.max()]
    fig = px.line(df, x='x', y=y_index, animation_frame=snapshot, range_y=y_range, hover_name=hover, **kwargs)
    fig.show()

def animate_scatter(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name = 'color', **kwargs):
    """Animate a scatter plot over snapshots.

    lines_list is indexed [snapshot, 2, point]: index 0 of the middle axis
    holds x coordinates and index 1 holds y coordinates. color may be
    per-point (1-D, repeated for every snapshot) or per snapshot and point.
    """
    if type(lines_list) is list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        # Repeat the hover labels once per snapshot.
        hover = [label for _ in range(len(snapshot_index)) for label in hover]
    if color is None:
        color = np.ones(lines_list.shape[-1])
    if type(color) is torch.Tensor:
        color = to_numpy(color)
    if len(color.shape)==1:
        # Broadcast per-point colours to every snapshot.
        color = einops.repeat(color, 'x -> snapshot x', snapshot=lines_list.shape[0])
    print(lines_list.shape)
    rows = [
        [lines_list[snap, 0, pt].item(), lines_list[snap, 1, pt].item(),
         snapshot_index[snap], color[snap, pt]]
        for snap in range(lines_list.shape[0])
        for pt in range(lines_list.shape[2])
    ]
    # Fixed axis ranges so that the frames are comparable.
    x_range = [lines_list[:, 0].min(), lines_list[:, 0].max()]
    y_range = [lines_list[:, 1].min(), lines_list[:, 1].max()]
    print(x_range)
    print(y_range)
    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])
    fig = px.scatter(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_x=x_range, range_y=y_range, hover_name=hover, color=color_name, **kwargs)
    fig.show()

In [None]:
# def unflatten_first(tensor):
#     if tensor.shape[0]==p*p:
#         return einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)
#     else: 
#         return tensor
# def mod_div(a, b):
#     return (a*pow(b, p-2, p))%p
# def normalize(tensor, axis=0):
#     return tensor/(tensor).pow(2).sum(keepdim=True, axis=axis).sqrt()
# def extract_freq_2d(tensor, freq):
#     # Takes in a pxpx... or batch x ... tensor, returns a 3x3x... tensor of the 
#     # Linear and quadratic terms of frequency freq
#     tensor = unflatten_first(tensor)
#     # Extracts the linear and quadratic terms corresponding to frequency freq
#     index_1d = [0, 2*freq-1, 2*freq]
#     # Some dumb manipulation to use fancy array indexing rules
#     # Gets the rows and columns in index_1d
#     return tensor[[[i]*3 for i in index_1d], [index_1d]*3]
# def get_cov(tensor, norm=True):
#     # Calculate covariance matrix
#     if norm:
#         tensor = normalize(tensor, axis=1)
#     return tensor @ tensor.T
# def is_close(a, b):
#     return ((a-b).pow(2).sum()/(a.pow(2).sum().sqrt())/(b.pow(2).sum().sqrt())).item()

Fourier Transform functions - see Fourier Transform section for intuition

In [None]:
# def fft1d(tensor):
#     # Converts a tensor with dimension p into the Fourier basis
#     return tensor @ fourier_basis.T

# def fourier_2d_basis_term(x_index, y_index):
#     # Returns the 2D Fourier basis term corresponding to the outer product of 
#     # the x_index th component in the x direction and y_index th component in the 
#     # y direction
#     # Returns a 1D vector of length p^2
#     return (fourier_basis[x_index][:, None] * fourier_basis[y_index][None, :]).flatten()

# def fft2d(mat):
#     # Converts a pxpx... or batch x ... tensor into the 2D Fourier basis.
#     # Output has the same shape as the orig
#     shape = mat.shape
#     mat = einops.rearrange(mat, '(x y) ... -> x y (...)', x=p, y=p)
#     fourier_mat = torch.einsum('xyz,fx,Fy->fFz', mat, fourier_basis, fourier_basis)
#     return fourier_mat.reshape(shape)

# def analyse_fourier_2d(tensor, top_k=10):
#     # Processes a (p,p) or (p*p) tensor in the 2D Fourier Basis, showing the 
#     # top_k terms and how large a fraction of the variance they explain
#     values, indices = tensor.flatten().pow(2).sort(descending=True)
#     rows = []
#     total = values.sum().item()
#     for i in range(top_k):
#         rows.append([tensor.flatten()[indices[i]].item(),
#                      values[i].item()/total, 
#                      values[:i+1].sum().item()/total, 
#                      fourier_basis_names[indices[i].item()//p], 
#                      fourier_basis_names[indices[i]%p]])
#     display(pd.DataFrame(rows, columns=['Coefficient', 'Frac explained', 'Cumulative frac explained', 'x', 'y']))

# def get_2d_fourier_component(tensor, x, y):
#     # Takes in a batch x ... tensor and projects it onto the 2D Fourier Component 
#     # (x, y)
#     vec = fourier_2d_basis_term(x, y).flatten()
#     return vec[:, None] @ (vec[None, :] @ tensor)

# def get_component_cos_xpy(tensor, freq, collapse_dim=False):
#     # Gets the component corresponding to cos(freq*(x+y)) in the 2D Fourier basis
#     # This is equivalent to the matrix cos((x+y)*freq*2pi/p)
#     cosx_cosy_direction = fourier_2d_basis_term(2*freq-1, 2*freq-1).flatten()
#     sinx_siny_direction = fourier_2d_basis_term(2*freq, 2*freq).flatten()
#     # Divide by sqrt(2) to ensure it remains normalised
#     cos_xpy_direction = (cosx_cosy_direction - sinx_siny_direction)/np.sqrt(2)
#     # Collapse_dim says whether to project back into R^(p*p) space or not
#     if collapse_dim:
#         return (cos_xpy_direction @ tensor)
#     else:
#         return cos_xpy_direction[:, None] @ (cos_xpy_direction[None, :] @ tensor)

# def get_component_sin_xpy(tensor, freq, collapse_dim=False):
#     # Gets the component corresponding to sin((x+y)*freq*2pi/p) in the 2D Fourier basis
#     sinx_cosy_direction = fourier_2d_basis_term(2*freq, 2*freq-1).flatten()
#     cosx_siny_direction = fourier_2d_basis_term(2*freq-1, 2*freq).flatten()
#     sin_xpy_direction = (sinx_cosy_direction + cosx_siny_direction)/np.sqrt(2)
#     if collapse_dim:
#         return (sin_xpy_direction @ tensor)
#     else:
#         return sin_xpy_direction[:, None] @ (sin_xpy_direction[None, :] @ tensor)

## Training Details

### Architecture
It's a 1 layer transformer, with no layer norm and learned positional embeddings. $d_{model} = 128$, $n_{heads} = 4$, $d_{head}=32$, $d_{mlp}=512$. Input format is `x|y|=`, $d_{vocab}=114$ (integers from $0$ to $p-1$ and $=$).

It was trained with full-batch training, using 0.3 of the total data as training data. The optimizer was AdamW, with $lr=10^{-3}$ and very high weight decay ($wd=1$) - I speculate that weight decay, and possibly using Adam at all, is pretty important for getting grokking to work.

**Aside:** Why a transformer? A 1L MLP with ReLU activations can do modular addition with no equals sign, taking the sum of the embeddings of $x$ and $y$ as its input. This would also be a pretty reasonable thing to analyse! I realised this pretty late into this research, and wanted an excuse to use the transformer circuits equations, so stuck with a 1L transformer. Using a 1L MLP for future research into grokking seems pretty reasonable.

**Aside 2:** Even given that it's a transformer, why have an equals sign token? Mostly because I wanted the model to be able to learn commutativity - if the inputs were just `x|y` then the residual stream + QK circuit inherently distinguish between $x$ and $y$. This isn't a super important choice, though I do think it makes the analysis a bit nicer.

### Hyper-Parameters

In [None]:
# lr=1e-3 #@param
# weight_decay = 1.0 #@param
# p=113 #@param
# d_model = 128 #@param
# fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']
# frac_train = 0.3 #@param
# num_epochs = 50000 #@param
# save_models = False #@param
# save_every = 100 #@param
# # Stop training when test loss is <stopping_thresh
# stopping_thresh = -1 #@param
# seed = 0 #@param

# num_layers = 1
# batch_style = 'full'
# d_vocab = p+1
# n_ctx = 3
# d_mlp = 4*d_model
# num_heads = 4
# assert d_model % num_heads == 0
# d_head = d_model//num_heads
# act_type = 'ReLU' #@param ['ReLU', 'GeLU']
# # batch_size = 512
# use_ln = False
# random_answers = np.random.randint(low=0, high=p, size=(p, p))
# fns_dict = {'add': lambda x,y:(x+y)%p, 'subtract': lambda x,y:(x-y)%p, 'x2xyy2':lambda x,y:(x**2+x*y+y**2)%p, 'rand':lambda x,y:random_answers[x][y]}
# fn = fns_dict[fn_name]

### Model Training Code
Included as a reference, don't run by default

In [None]:
# train_model = False #@param

In [None]:
# def gen_train_test(frac_train, num, seed=0):
#     # Generate train and test split
#     pairs = [(i, j, num) for i in range(num) for j in range(num)]
#     random.seed(seed)
#     random.shuffle(pairs)
#     div = int(frac_train*len(pairs))
#     return pairs[:div], pairs[div:]

# train, test = gen_train_test(frac_train, p, seed)
# print(len(train), len(test))

In [None]:
# # Creates an array of Boolean indices according to whether each data point is in 
# # train or test
# # Used to index into the big batch of all possible data
# is_train = []
# is_test = []
# for x in range(p):
#     for y in range(p):
#         if (x, y, 113) in train:
#             is_train.append(True)
#             is_test.append(False)
#         else:
#             is_train.append(False)
#             is_test.append(True)
# is_train = np.array(is_train)
# is_test = np.array(is_test)

In [None]:
# if train_model:
#     model = Transformer(num_layers=num_layers, d_vocab=d_vocab, d_model=d_model, d_mlp=d_mlp, d_head=d_head, num_heads=num_heads, n_ctx=n_ctx, act_type=act_type, use_cache=False, use_ln=use_ln)
#     model.to('cuda')
#     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.98))
#     scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))
#     run_name = f"grok_{int(time.time())}"
#     print(f'Run name {run_name}')
#     if save_models:
#         os.mkdir(root/run_name)
#         save_dict = {'model':model.state_dict(), 'train_data':train, 'test_data':test}
#         torch.save(save_dict, root/run_name/'init.pth')
#     train_losses = []
#     test_losses = []
#     for epoch in range(num_epochs):
#         train_loss = full_loss(model, train)
#         test_loss = full_loss(model, test)
#         train_losses.append(train_loss.item())
#         test_losses.append(test_loss.item())
#         if epoch%100 == 0: print(f"{epoch}_{np.log(train_loss.item()):.4f}_{np.log(test_loss.item()):.4f}")#_{train_acc.item():.4f}_{test_acc.item():.4f}")
#         train_loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()
#         if test_loss.item() < stopping_thresh:
#             break
#         if (save_models) and (epoch%save_every == 0):
#             if test_loss.item() < stopping_thresh:
#                 break
#             save_dict = {
#                 'model': model.state_dict(),
#                 'optimizer': optimizer.state_dict(),
#                 'scheduler': scheduler.state_dict(),
#                 'train_loss': train_loss,
#                 'test_loss': test_loss,
#                 'epoch': epoch,
#             }
#             torch.save(save_dict, root/run_name/f"{epoch}.pth")
#             print(f"Saved model to {root/run_name/f'{epoch}.pth'}")
#     if not save_models:
#         os.mkdir(root/run_name)
#     save_dict = {
#         'model': model.state_dict(),
#         'optimizer': optimizer.state_dict(),
#         'scheduler': scheduler.state_dict(),
#         'train_loss': train_loss,
#         'test_loss': test_loss,
#         'train_losses': train_losses,
#         'test_losses': test_losses,
#         'epoch': epoch,
#     }
#     torch.save(save_dict, root/run_name/f"final.pth")
#     print(f"Saved model to {root/run_name/f'final.pth'}")
#     lines([train_losses, test_losses], labels=['train', 'test'], log_y=True)

#     # save_models = False

# 5 Digit Addition

In [None]:
# Colab form flag: the infinite-data training cell only runs when this is True.
train_model = True #@param


## Setup

In [None]:
#@markdown Model
num_layers = 1
# Input vocabulary: the digits 0-9 plus the PLUS_INDEX and EQUALS_INDEX tokens.
d_vocab = 12
# Output vocabulary: the model only ever has to predict digits.
d_vocab_out = 10
d_model = 512 #@param
num_heads = 4
d_head = d_model//num_heads
d_mlp = 4 * d_model
# Seeds both the data generator and the model initialisation.
seed = 129000 #@param
#@markdown Data
num_digits =  5#@param
# Sequence layout: num_digits of x, '+', num_digits of y, '=', then
# num_digits + 1 answer digits (the extra one holds the final carry).
n_ctx = 3*num_digits + 3
act_type = 'ReLU'
batch_size = 64 #@param
# True: train on one fixed set of num_data examples and evaluate on fresh
# batches. False: draw a fresh batch from the generator.
is_finite = False #@param
num_data = 750 #@param
#@markdown Optimizer
lr = 1e-4 #@param
weight_decay = 0.1 #@param
num_epochs = 3000 #@param

#@markdown Misc
# Not read in the visible part of the notebook (presumably used by the
# checkpointing code further down — TODO confirm).
checkpoint_models = False #@param
checkpoint_every = 50 #@param


# Token ids of the two non-digit symbols.
PLUS_INDEX = 10
EQUALS_INDEX = 11

In [None]:
def tokens_to_string(tokens, num_digits=5):
    """Format one token sequence as a three-line written addition.

    The sequence is read as num_digits tokens of x, one separator, num_digits
    tokens of y, one separator, and then the answer tokens.

    Args:
        tokens: 1-D sequence (tensor, array or list) of token ids.
        num_digits: digits per operand. The default of 5 reproduces the
            previously hard-coded slices; pass the notebook's num_digits when
            it differs.

    Returns:
        A string of the form "  x\\n +y\\n=z".
    """
    tokens = to_numpy(tokens)
    y_start = num_digits + 1        # skip the '+' token
    z_start = 2*num_digits + 2      # skip the '=' token
    x = "".join(str(tok) for tok in tokens[:num_digits])
    y = "".join(str(tok) for tok in tokens[y_start:y_start+num_digits])
    z = "".join(str(tok) for tok in tokens[z_start:])
    return f"  {x}\n +{y}\n={z}"
def string_to_tokens(string, batch=False):
    """Turn a string of digits, '+' and '=' into a tensor of token ids.

    Spaces and newlines are skipped. With batch=True a leading axis of size 1
    is added.
    """
    lookup = {str(digit): digit for digit in range(10)}
    lookup.update({'+': 10, '=': 11})
    token_ids = torch.tensor([lookup[ch] for ch in string if ch not in '\n '])
    return token_ids[None, :] if batch else token_ids

In [None]:
# Fourier basis over the ten digits: a constant vector, cos/sin pairs for
# frequencies 1-4, and the alternating (+1, -1, ...) vector.
fourier_basis = [torch.ones(10)]
fourier_basis_names = ['const']

for i in range(1, 5):
    fourier_basis += [torch.cos(2*np.pi/10*i*torch.arange(10)),
                      torch.sin(2*np.pi/10*i*torch.arange(10))]
    fourier_basis_names += [f'cos {i}', f'sin {i}']
fourier_basis.append(torch.cos(np.pi*torch.arange(10)))
fourier_basis_names.append('+-1')

fourier_basis = torch.stack(fourier_basis, axis=0).cuda()
# Scale every row (one basis vector per row) to unit norm.
basis_norms = einops.reduce(fourier_basis.pow(2),
                            'vocab fourier -> vocab 1',
                            'sum').sqrt()
fourier_basis = fourier_basis/basis_norms
# Show the basis itself and its Gram matrix.
imshow_div(fourier_basis)
imshow_div(fourier_basis @ fourier_basis.T)

In [None]:
def data_generator(batch_size, num_digits, seed):
    """Yield endless batches of random num_digits-digit addition problems.

    Each row is laid out as: num_digits digits of x, PLUS_INDEX, num_digits
    digits of y, EQUALS_INDEX, then num_digits + 1 digits of x + y (most
    significant first; the leading answer digit is the final carry).

    Args:
        batch_size: number of problems per batch.
        num_digits: digits per operand.
        seed: seed for this generator's random stream.

    Yields:
        (batch, carries) on the GPU: batch is int64 of shape
        (batch_size, 3*num_digits + 3); carries is int64 of shape
        (batch_size, num_digits + 1), where column i is the carry into the
        i-th digit from the right (column 0 is always zero).
    """
    # A private generator gives the same stream as seeding the global one, but
    # does not reset the global torch RNG as a side effect of the first next().
    rng = torch.Generator()
    rng.manual_seed(seed)
    while True:
        batch = torch.zeros((batch_size, 3*num_digits+3)).to(torch.int64)
        x = torch.randint(0, 10, (batch_size, num_digits), generator=rng)
        y = torch.randint(0, 10, (batch_size, num_digits), generator=rng)
        batch[:, :num_digits] = x
        batch[:, num_digits] = PLUS_INDEX
        batch[:, 1+num_digits:1+num_digits*2] = y
        batch[:, 1+num_digits*2] = EQUALS_INDEX
        carries = [torch.zeros((batch_size,)).to(torch.int64)]
        # Schoolbook addition from the least significant digit upwards.
        for i in range(num_digits):
            carry = carries[-1]
            digit_sum = (batch[:, num_digits-1-i]+batch[:, 2*num_digits-i]+carry)
            batch[:, -1-i] = (digit_sum % 10)
            carry = (digit_sum>=10).to(torch.int64)
            carries.append(carry)
        # The leading answer digit is whatever carried out of the top digit.
        batch[:, -1-num_digits] = carries[-1]
        carries = torch.stack(carries, axis=1)
        yield batch.cuda(), carries.cuda()
if not is_finite:
    # Infinite-data regime: one stream of batches.
    ds = data_generator(batch_size, num_digits, seed)
else:
    # Finite-data regime: a single fixed training set of num_data examples,
    # plus a stream of batches to evaluate on.
    # NOTE(review): both generators are given the same seed — confirm that the
    # resulting overlap between the random streams is intended.
    test_ds = data_generator(batch_size, num_digits, seed)
    train_ds = data_generator(num_data, num_digits, seed)
    train_tokens, train_carries = next(train_ds)

In [None]:
# Seed before building the model so that the initial weights are reproducible.
torch.manual_seed(seed)
model = Transformer(num_layers=num_layers, d_vocab=d_vocab, d_model=d_model,
                    d_mlp=d_mlp, d_head=d_head, num_heads=num_heads,
                    n_ctx=n_ctx, act_type=act_type, d_vocab_out=d_vocab_out)
model.to('cuda')
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay,
                        betas=(0.9, 0.98))

def linear_warmup(step):
    # Learning-rate multiplier: ramps linearly from 0 to 1 over the first 10
    # scheduler steps, then stays at 1.
    return min(step/10, 1)

scheduler = optim.lr_scheduler.LambdaLR(optimizer, linear_warmup)

In [None]:

def get_pred_log_probs(logits, tokens):
    """Log-probability the model assigns to each correct answer token.

    The answer is the last num_digits + 1 tokens; the logits that predict
    them sit one position earlier. Returns one value per example and answer
    position, computed in float64.
    """
    n_answer = num_digits + 1
    # Logits at position t predict token t+1, hence the one-step offset.
    answer_logits = logits[:, -(n_answer+1):-1]
    answer_tokens = tokens[:, -n_answer:]
    log_probs = F.log_softmax(answer_logits.to(torch.float64), dim=-1)
    picked = torch.gather(log_probs, -1, answer_tokens.unsqueeze(-1))
    return picked.squeeze(-1)

def loss_fn(logits, tokens):
    """Mean negative log-likelihood over every answer token in the batch."""
    answer_log_probs = get_pred_log_probs(logits, tokens)
    return -answer_log_probs.mean()

## Training

In [None]:
# Finite-data regime: full-batch training on the fixed training set, evaluated on
# one batch from `test_ds` per epoch.
if is_finite:
    train_losses = []
    # "ptl" = per-token loss: one array of 1+num_digits values per epoch
    # (the batch-mean loss at each answer position).
    ptl_train_list = []
    test_losses = []
    ptl_test_list = []
    # per_token_losses_list = []
    # sds=[]
    # epochs = [0]
    # sds.append(model.state_dict())
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_tokens)
        # .mean(0) averages over the batch, leaving one loss per answer position.
        per_token_losses_train = -get_pred_log_probs(train_logits, train_tokens).mean(0)
        ptl_train_list.append(to_numpy(per_token_losses_train))
        train_loss = per_token_losses_train.mean()
        train_loss.backward()
        train_losses.append(train_loss.item())

        # Evaluation happens before optimizer.step(), so train and test losses
        # are measured on the same weights.
        # NOTE(review): this forward pass is not wrapped in torch.no_grad(), so an
        # autograd graph is built and discarded every epoch.
        test_tokens, _ = next(test_ds)
        test_logits = model(test_tokens)
        per_token_losses_test = -get_pred_log_probs(test_logits, test_tokens).mean(0)
        ptl_test_list.append(to_numpy(per_token_losses_test))
        test_loss = per_token_losses_test.mean()
        test_losses.append(test_loss.item())

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if epoch % 100 == 0:
            print(epoch, train_loss.item(), test_loss.item())
        # Every 1000 epochs: plot the overall curves, then per-position curves for train and test.
        # NOTE(review): the plot title hard-codes "5 digit" rather than using num_digits.
        if epoch%1000 ==0 and epoch>0:
            lines([train_losses, test_losses], labels=['train', 'test'])
            lines([[ptl_train_list[j][i] for j in range(len(ptl_train_list))] for i in range(1+num_digits)]+[train_losses]+[[ptl_test_list[j][i] for j in range(len(ptl_train_list))] for i in range(1+num_digits)]+[test_losses],
            labels = [f'tok train {i}' for i in range(1+num_digits)]+['train_loss']+[f'tok test {i}' for i in range(1+num_digits)]+['test_loss'],
            title='Per-digit Loss Curves for 5 digit addition (Finite Data)',
            xaxis='Step',
            yaxis='Loss')

In [None]:
# Infinite-data regime: every step trains on a freshly generated batch.
if not is_finite and train_model:
    # Local import: `copy` is not yet imported at this point in the notebook.
    import copy

    train_losses = []
    # One array of 1+num_digits values per step: the batch-mean loss at each answer position.
    per_token_losses_list = []
    # Checkpoints: sds[k] is the model state after epochs[k] optimisation steps.
    # state_dict() returns references to the live parameter tensors, so every snapshot
    # must be deep-copied; otherwise all entries end up holding the final weights.
    sds = [copy.deepcopy(model.state_dict())]
    epochs = [0]
    for epoch in tqdm.tqdm(range(num_epochs)):
        tokens, carry = next(ds)
        logits = model(tokens)
        per_token_losses = -get_pred_log_probs(logits, tokens).mean(0)
        per_token_losses_list.append(to_numpy(per_token_losses))
        loss = per_token_losses.mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        train_losses.append(loss.item())
        if epoch % 100 == 0:
            print(epoch, loss.item())
        if checkpoint_models and (epoch+1) % (checkpoint_every) == 0:
            sds.append(copy.deepcopy(model.state_dict()))
            epochs.append(epoch+1)
    line(train_losses)

  0%|          | 0/3000 [00:00<?, ?it/s]

0 2.5933517134842354
100 1.9987422520574711
200 1.9970166011907557
300 1.9250046541021848
400 1.8031224091815634
500 1.7581865212820351
600 1.6841518779962343
700 1.630511549688915
800 1.661828079396909
900 1.5942936133183832
1000 1.4743505272202289
1100 1.371478044522961
1200 1.365702081309069
1300 1.2931866080480892
1400 1.285163123156675
1500 1.1130359451275
1600 0.7196021854265029
1700 0.37786698991880685
1800 0.2889489863031449
1900 0.22298177054905016
2000 0.16746928390340665
2100 0.1252498923485676
2200 0.08444272403742815
2300 0.057058135558476566
2400 0.09701944961320949
2500 0.06711638988956252
2600 0.03192600587129353
2700 0.04556934619726692
2800 0.06436082760636397
2900 0.03303631199648917


In [None]:
# Plot the overall training loss, then the loss at each answer position
# (linear y-axis with title, then log y-axis).
line(train_losses)
# Rows are training steps, columns are answer positions.
per_token_losses = np.stack(per_token_losses_list, axis=0)
# NOTE(review): the title hard-codes "5 digit" rather than using num_digits.
lines([per_token_losses[:, i] for i in range(1+num_digits)]+[train_losses],
      labels = [f'tok {i}' for i in range(1+num_digits)]+['train_loss'],
      title='Per-digit Loss Curves for 5 digit addition (Infinite Data)',
      xaxis='Step',
      yaxis='Loss')

lines([per_token_losses[:, i] for i in range(1+num_digits)]+[train_losses],
      labels = [f'tok {i}' for i in range(1+num_digits)]+['train_loss'], log_y=True)

In [None]:
# Overall and per-position loss curves on linear and log y-axes, without title or axis labels.
# NOTE(review): this repeats the plots of the preceding cell; consider removing one of them.
line(train_losses)
# Rows are training steps, columns are answer positions.
per_token_losses = np.stack(per_token_losses_list, axis=0)
lines([per_token_losses[:, i] for i in range(1+num_digits)]+[train_losses],
      labels = [f'tok {i}' for i in range(1+num_digits)]+['train_loss'])

lines([per_token_losses[:, i] for i in range(1+num_digits)]+[train_losses],
      labels = [f'tok {i}' for i in range(1+num_digits)]+['train_loss'], log_y=True)

In [None]:
# NOTE(review): this cell looks like a stale copy of the Skip Trigram finite-data loop.
# It relies on get_pred_lps, train_mask, train_ptls, test_ptls, sds, rand_size and
# copy, none of which are defined before this point in the notebook, so on a fresh
# kernel it fails with NameError whenever is_finite is True. Confirm whether it
# should be removed.
if is_finite:
    train_losses = []
    test_losses = []
    # losses_forced = []
    epochs = []
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Full-batch training step on the fixed training set.
        logits = model(train_tokens)
        train_ptl = get_pred_lps(logits, train_mask)
        train_ptls.append(to_numpy(train_ptl))
        train_loss = train_ptl.mean()
        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())

        # Evaluate on one test batch, with the loss broken down by target class.
        test_tokens, test_mask = next(test_ds)
        test_logits = model(test_tokens)
        test_pred_lps = get_pred_lps(test_logits, test_mask)
        test_ptl = [test_pred_lps[test_mask==i].mean().item() for i in range(rand_size)]
        test_ptls.append(test_ptl)
        test_loss = test_pred_lps.mean()
        test_losses.append(test_loss.item())

        if epoch % 100 == 0:
            print(epoch, train_losses[-1], test_losses[-1])
        if (epoch % 1000 == 0) and epoch>0:
            lines([train_losses, test_losses], labels=['train', 'test'], log_y=True)
        # Snapshot the weights every 5 epochs.
        if epoch%5 == 0:
            epochs.append(epoch)
            sds.append(copy.deepcopy(model.state_dict()))

# Skip Trigram

## Setup

In [None]:
# Hyperparameters for the skip-trigram experiment. Lines tagged `#@param` are Colab form fields.
# Optimisation settings.
lr=1e-3 #@param
weight_decay =  0.1#@param 
d_model = 32 #@param
num_epochs = 3000 #@param
save_models = False #@param
save_every = 10 #@param
# Stop training when test loss is <stopping_thresh
stopping_thresh = -1 #@param
seed =  6798#@param

# Model and data sizes.
num_layers = 1
batch_size = 256 #@param
# Number of distinct target classes, and of distinct background tokens.
rand_size = 10 #@param
# Input vocabulary: tokens 0 and 1 (first/last position), rand_size background
# tokens and rand_size special tokens, as laid out by data_generator.
d_vocab = 2*rand_size+2 #@param
d_vocab_out = rand_size #@param
seq_len = 40 #@param
# Two extra context positions for the tokens at the start and end of each sequence.
n_ctx = seq_len + 2 #@param
d_mlp = 4*d_model
num_heads = 1 #@param
assert d_model % num_heads == 0
# NOTE(review): d_head is set independently of d_model//num_heads; confirm this is intended.
d_head = 8 #@param
act_type = 'ReLU' #@param ['ReLU', 'GeLU']
attn_only = True #@param

# Data regime: False draws a fresh batch every step; True trains on num_data fixed examples.
is_finite = False #@param
num_data =  100#@param

In [None]:
# Make data
def data_generator(batch_size, seq_len, rand_size, seed, force_mask=-1):
    """Endless stream of skip-trigram batches, reproducible from ``seed``.

    Every sequence has ``seq_len+2`` tokens: token 0 first, token 1 last, and
    background tokens drawn from ``[2, rand_size+2)`` in between. One randomly
    chosen interior position is overwritten by the special token
    ``rand_size + 2 + target``; the target is what the model must predict.

    Args:
        batch_size: number of sequences per batch.
        seq_len: number of interior tokens per sequence.
        rand_size: number of target classes (and of background tokens).
        seed: seed for the private CPU random generator.
        force_mask: if greater than -1, every target is set to this value.

    Yields:
        ``(tokens, targets)`` moved to the GPU, with shapes
        ``(batch_size, seq_len+2)`` and ``(batch_size,)``.
    """
    cpu_rng = torch.Generator(device='cpu')
    cpu_rng.manual_seed(seed)
    row_ids = torch.arange(batch_size, dtype=torch.int64)
    while True:
        tokens = torch.zeros((batch_size, seq_len+2), dtype=torch.int64)
        tokens[:, 1:1+seq_len] = torch.randint(2, rand_size+2, (batch_size, seq_len), generator=cpu_rng)
        tokens[:, -1] = 1
        # The targets are drawn even when they are about to be overridden, which keeps
        # the random stream aligned between forced and unforced generators.
        targets = torch.randint(0, rand_size, (batch_size,), generator=cpu_rng)
        if force_mask > -1:
            targets[:] = force_mask
        positions = torch.randint(1, seq_len+1, (batch_size,), generator=cpu_rng)
        tokens[row_ids, positions] = rand_size + 2 + targets
        yield tokens.cuda(), targets.cuda()
# Build the data streams for the selected regime.
# forced_ds = [data_generator(batch_size, seq_len, rand_size, seed, force_mask=i) for i in range(10)]
if not is_finite:
    # Infinite data: every training step draws a fresh batch from `ds`.
    ds = data_generator(batch_size, seq_len, rand_size, seed)
else:
    # Finite data: one fixed training set of `num_data` sequences.
    ds = data_generator(num_data, seq_len, rand_size, seed)
    train_tokens, train_mask = next(ds)
    # Held-out stream. It is seeded differently from the training generator: with the
    # same seed it would replay the random stream that produced the training set.
    # `ds` is rebound to the test stream as well, as before.
    test_ds = ds = data_generator(batch_size, seq_len, rand_size, seed+1)

In [None]:
class SkipTrigramOld(nn.Module):
    """Minimal skip-trigram model: one attention score and one output vector per token.

    Args:
        d_vocab: size of the input vocabulary.
        d_vocab_out: number of output logits.
        seed: value passed to ``torch.manual_seed`` before the parameters are drawn.
    """
    def __init__(self, d_vocab, d_vocab_out, seed):
        super().__init__()
        # Seeds the global torch RNG so the initial weights are reproducible.
        torch.manual_seed(seed)
        init_scale = np.sqrt(d_vocab)
        self.QK = nn.Parameter(torch.randn(d_vocab)/init_scale)
        self.OV = nn.Parameter(torch.randn(d_vocab, d_vocab_out)/init_scale)

    def forward(self, x):
        """Map token ids of shape (batch, position) to logits of shape (batch, d_vocab_out)."""
        # Softmax over positions turns the per-token scores into attention weights.
        weights = F.softmax(self.QK[x], dim=-1)
        per_position_out = self.OV[x]
        # Attention-weighted sum of each position's output vector.
        return torch.einsum('bp,bpv->bv', weights, per_position_out)

In [None]:
# class SkipTrigramSimple(nn.Module):
#     def __init__(self, d_vocab_in, d_vocab_out):
#         super().__init__()
#         self.QK = nn.Parameter(torch.randn((d_vocab_in))/np.sqrt(d_vocab_in))
#         self.OV = nn.Parameter(torch.randn((d_vocab_out, d_vocab_in))/np.sqrt(d_vocab_out))
    
#     def forward(self, x):
#         attn_pre = self.QK[x]
#         attn = F.softmax(attn_pre, dim=-1)
#         v = self.OV[:, x]
#         return torch.einsum('obp,bp->bo', v, attn)
    
# model = SkipTrigramSimple(d_vocab, d_vocab_out).cuda()
# # optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.98))
# optimizer = optim.SGD(model.parameters(), lr=1, weight_decay=0.01)
# # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

In [None]:
# Seed PyTorch's global RNG before building the model so initial weights are reproducible.
torch.manual_seed(seed)
model = Transformer(num_layers=num_layers, 
                    d_vocab=d_vocab, 
                    d_model=d_model, 
                    d_mlp=d_mlp, 
                    d_head=d_head, 
                    num_heads=num_heads, 
                    n_ctx=n_ctx, 
                    act_type=act_type, 
                    attn_only=attn_only,
                    d_vocab_out=d_vocab_out)
model.to('cuda')

# model = SkipTrigramOld(d_vocab, d_vocab_out, seed)
# model.to('cuda')

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.98))
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# Linear warm-up: the learning-rate multiplier ramps from 0 to 1 over the first 10 scheduler steps.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

In [None]:
def get_pred_lps(logits, mask, already_just_last=False):
    """Per-example negative log-probability of the target class.

    Args:
        logits: (batch, position, vocab) model output, or (batch, vocab) when
            ``already_just_last`` is True.
        mask: (batch,) integer target class for each example.
        already_just_last: set when ``logits`` already holds only the final position.

    Returns:
        float64 tensor of shape (batch, 1) holding ``-log p(target)``.
    """
    final_logits = logits if already_just_last else logits[:, -1, :]
    # Softmax in float64 to avoid precision loss at very small losses.
    lps = F.log_softmax(final_logits.to(torch.float64), dim=-1)
    target_lps = lps.gather(-1, mask[:, None])
    # return -(prediction_log_probs*(weight(mask))).mean()/weight(mask).float().mean()
    return -target_lps

In [None]:
import copy

## Training

In [None]:
# Train the skip-trigram model in whichever data regime `is_finite` selects.
if is_finite:
    # "ptl" = per-target loss. train_ptls holds the per-example loss array of each
    # epoch; test_ptls holds one list of rand_size class-mean losses per epoch.
    train_ptls = []
    test_ptls = []
    train_losses = []
    test_losses = []
    # Weight snapshots and the epochs at which they were taken.
    sds=[]
    # losses_forced = []
    epochs = []
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Full-batch training step on the fixed training set.
        logits = model(train_tokens)
        train_ptl = get_pred_lps(logits, train_mask)
        train_ptls.append(to_numpy(train_ptl))
        train_loss = train_ptl.mean()
        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())

        # Evaluate on one test batch, with the loss broken down by target class.
        # NOTE(review): this forward pass is not wrapped in torch.no_grad().
        test_tokens, test_mask = next(test_ds)
        test_logits = model(test_tokens)
        test_pred_lps = get_pred_lps(test_logits, test_mask)
        test_ptl = [test_pred_lps[test_mask==i].mean().item() for i in range(rand_size)]
        test_ptls.append(test_ptl)
        test_loss = test_pred_lps.mean()
        test_losses.append(test_loss.item())

        if epoch % 100 == 0:
            print(epoch, train_losses[-1], test_losses[-1])
        if (epoch % 1000 == 0) and epoch>0:
            lines([train_losses, test_losses], labels=['train', 'test'], log_y=True)
        # Snapshot the weights every 5 epochs (deep copy, so later updates do not alter it).
        if epoch%5 == 0:
            epochs.append(epoch)
            sds.append(copy.deepcopy(model.state_dict()))
    # lines(np.array(losses_forced).T, log_y=True)
else:
    train_losses = []
    sds = []
    epochs = []
    # One list of rand_size class-mean losses per step.
    train_ptls = []
    # losses_forced = []

    # One list of rand_size class-mean attention gradients per step.
    grad_l = []
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Fresh cache and hooks every step; the hooks are removed again below.
        cache = {}
        model.cache_all(cache, incl_bwd=True)
        tokens, mask = next(ds) 
        logits = model(tokens)
        train_pred_lps = get_pred_lps(logits, mask)
        train_loss = train_pred_lps.mean()
        train_ptl = [train_pred_lps[mask==i].mean().item() for i in range(rand_size)]
        train_ptls.append(train_ptl)
        train_loss.backward()

        # tokens.argmax(1) is the position of the special token, which has the largest
        # id in every sequence. The index picks, per example, the cached value for
        # head 0, last position, at that position.
        # Presumably the cache entry is the gradient w.r.t. the attention pattern,
        # indexed (batch, head, query, key) -- TODO confirm against Transformer.cache_all.
        grads = cache['blocks.0.attn.hook_attn_grad'][torch.arange(batch_size).cuda(), 0, -1, tokens.argmax(1).cuda()]
        grad_l.append([grads[mask==i].mean().item() for i in range(rand_size)])
        # if epoch % 100 == 0:
        #     print('GRAD', grad_l[-1])
        model.remove_all_hooks()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())
        if epoch % 100 == 0:
            print(epoch, train_losses[-1])
        if (epoch % 1000 == 0) and epoch>0:
            line(train_losses, log_y=True)
        # Snapshot the weights every 100 steps.
        if epoch%100 == 0:
            epochs.append(epoch)
            sds.append(copy.deepcopy(model.state_dict()))
    # lines(np.array(losses_forced).T, log_y=True)
    # With infinite data there is no separate test set; test_losses is an alias of train_losses.
    test_losses = train_losses

  0%|          | 0/3000 [00:00<?, ?it/s]

0 2.3461269241836282
100 2.248627353300792
200 0.2549234119546281
300 0.014259271087967415
400 0.003011262212211936
500 0.0008407764419176036
600 0.00023982133920360531
700 7.907968477477128e-05
800 2.599014213155428e-05
900 8.425155653442519e-06
1000 2.8579765030970806e-06


1100 9.702836711225614e-07
1200 3.266822605600373e-07
1300 1.1261705951643674e-07
1400 4.104932981609225e-08
1500 1.6251449137513372e-08
1600 6.901606784974924e-09
1700 3.985333961635475e-09
1800 2.6025596963417475e-09
1900 2.1558313304833046e-09
2000 1.8941943377082642e-09


2100 1.835954541154622e-09
2200 1.7681960289446334e-09
2300 1.779933357906981e-09
2400 1.7614561191183232e-09
2500 1.7656281160597172e-09
2600 1.7333743583050018e-09
2700 1.7060766812173324e-09
2800 1.7303718477476938e-09
2900 1.7550454346456562e-09


In [None]:
# Overall loss curves on a log and then a linear y-axis.
if is_finite:
    lines([train_losses, test_losses], log_y=True, labels=['train', 'test'])
    lines([train_losses, test_losses], log_y=False, labels=['train', 'test'])
else:
    # Infinite data: only a training curve exists.
    line(train_losses, log_y=True)
    line(train_losses, log_y=False)

In [None]:
# Loss curves broken down per example (finite train) or per target class.
if is_finite:
    # Stack the per-epoch loss arrays along axis 1 and drop the singleton axis.
    train_ptl = np.stack(train_ptls, axis=1).squeeze()
    # Transpose so each row is one target class and each column one epoch.
    test_ptl = np.array(test_ptls).T
    # [:, ::10] keeps every 10th epoch to thin the plots.
    lines(train_ptl[:, ::10], log_y=False)
    lines(test_ptl[:, ::10], log_y=False)
    lines(train_ptl[:, ::10], log_y=True)
    lines(test_ptl[:, ::10], log_y=True)
else:
    # Transpose so each row is one target class and each column one step.
    train_ptl = np.array(train_ptls).T
    # NOTE(review): range(10) and the label range 11..20 are hard-coded; they match
    # rand_size == 10 only, and data_generator assigns special-token ids
    # rand_size+2 .. 2*rand_size+1 (12..21), so confirm the intended labels.
    lines([train_ptl[i] for i in range(10)]+[train_losses], 
          log_y=False, 
          title='Per token phase changes for skip trigrams (Infinite Data)',
          xaxis='Step',
          yaxis='Loss',
          labels = [f"trigram {i}" for i in range(11, 21)]+['overall loss'])
    lines(train_ptl[:], log_y=True)
    # Per-class mean attention gradient recorded during training.
    lines(np.array(grad_l).T)

In [None]:
# Per-class loss curves without titles or labels.
# NOTE(review): this repeats the plots of the preceding cell; consider removing one of them.
if is_finite:
    train_ptl = np.stack(train_ptls, axis=1).squeeze()
    test_ptl = np.array(test_ptls).T
    # [:, ::10] keeps every 10th epoch to thin the plots.
    lines(train_ptl[:, ::10], log_y=False)
    lines(test_ptl[:, ::10], log_y=False)
    lines(train_ptl[:, ::10], log_y=True)
    lines(test_ptl[:, ::10], log_y=True)
else:
    # Transpose so each row is one target class and each column one step.
    train_ptl = np.array(train_ptls).T
    lines(train_ptl[:], log_y=False)
    lines(train_ptl[:], log_y=True)
    lines(np.array(grad_l).T)

# Induction Heads


Config

In [None]:
# Configuration for the induction-heads experiment: three dicts (data, model, optimiser),
# followed by module-level aliases of their entries for use in later cells.
# Model config
# data_config = dict(
#     seq_len = 129,
#     subseq_len = 16,
#     num_repeats = 4,
#     vocab_size = 128, # First token is BEGIN
#     batch_size = 256,
#     seed = 0,
# )

# Data settings; passed to generate_dataset as keyword arguments.
data_config = dict(
    seq_len = 85,#@param
    subseq_len = 16,#@param
    num_repeats = 3,#@param
    vocab_size = 128, # First token is BEGIN
    batch_size = 128, #@param
    seed = 123, #@param
    is_finite=False,#@param
    num_batches=4, #@param
)

# Model settings; the context length and vocabulary size follow the data settings.
model_config = dict(
    num_layers = 2,
    d_model = 48, #@param
    # d_model = 128,
    num_heads = 4,
    n_ctx = data_config['seq_len'],
    act_type = 'ReLU',
    d_vocab = data_config['vocab_size'],
    pos_embed_type = 'orig',
    seed = data_config['seed']
)

# Derived sizes: heads split d_model evenly.
model_config['d_head'] = model_config['d_model']//model_config['num_heads']
model_config['d_mlp'] = model_config['d_model'] * 4
# Optim config
optim_config = dict(
    lr = 1e-3,
    weight_decay = 1, #@param
    num_epochs = 6000,#@param
)


# Module-level aliases. These overwrite the same-named globals set by earlier sections.
num_layers = model_config['num_layers']
d_model = model_config['d_model']
num_heads = model_config['num_heads']
n_ctx = model_config['n_ctx']
act_type = model_config['act_type']
d_vocab = model_config['d_vocab']
d_head = model_config['d_head']
d_mlp = model_config['d_mlp']
pos_embed_type = model_config['pos_embed_type']


seq_len = data_config['seq_len']
subseq_len = data_config['subseq_len']
num_repeats = data_config['num_repeats']
vocab_size = data_config['vocab_size']
batch_size = data_config['batch_size']

Data generation

In [None]:
def generate_dataset(seq_len, subseq_len, num_repeats, vocab_size, batch_size, seed, **kwargs):
    """Endless stream of random token batches containing a repeated subsequence.

    Position 0 of every sequence is token 0; the rest is noise drawn from
    ``[1, vocab_size)``. The positions after the first are cut into
    ``num_repeats`` equal intervals and one window of ``subseq_len`` tokens is
    chosen in each. The window of the first interval is copied into the window
    of every later interval.

    Args:
        seq_len: tokens per sequence.
        subseq_len: length of the repeated subsequence.
        num_repeats: number of occurrences of the subsequence.
        vocab_size: size of the vocabulary, including token 0.
        batch_size: sequences per batch.
        seed: seed for NumPy's global random state (set on the first ``next()``).
        **kwargs: ignored, so a config dict with extra keys can be splatted in.

    Yields:
        ``(tokens, mask)`` tensors of shape ``(batch_size, seq_len)``. ``mask`` is 1.0
        at every position of a copied window except its last, and 0.0 elsewhere.
    """
    np.random.seed(seed)
    offsets = np.arange(subseq_len)
    rows = np.arange(batch_size)[:, None]
    interval_length = (seq_len - 1)//num_repeats

    def make_batch():
        tokens = np.random.randint(1, vocab_size, (batch_size, seq_len))
        tokens[:, 0] = 0
        # One window start per interval, chosen so the window fits inside its interval.
        starts = []
        for k in range(num_repeats):
            lowest = 1 + k*interval_length
            highest = 1 + (k+1)*interval_length - (subseq_len - 1)
            starts.append(np.random.randint(lowest, highest, (batch_size,)))
        # background_noise[batch_index, indices[0][:, None]+np.arange(subseq_len)] += (vocab_size-1)//2
        # Fancy indexing needs to be done with Numpy not PyTorch - PyTorch crashes the kernel ???
        template = tokens[rows, starts[0][:, None] + offsets]
        mask = np.zeros((batch_size, seq_len))
        for start in starts[1:]:
            columns = start[:, None] + offsets
            tokens[rows, columns] = template
            mask[rows, columns[:, :-1]] = 1.0
        return torch.tensor(tokens), torch.tensor(mask)#, indices

    while True:
        yield make_batch()


Model definition

In [None]:
# Define network architecture

# Embed & Unembed
class Embed(nn.Module):
    """Token embedding: looks up a learned ``d_model``-dimensional vector for each token id."""
    def __init__(self, d_vocab, d_model):
        super().__init__()
        initial = torch.randn(d_model, d_vocab)/np.sqrt(d_model)
        # Stored as (d_model, d_vocab): one column per token.
        self.W_E = nn.Parameter(initial)

    def forward(self, x):
        """Map token ids of shape (batch, position) to embeddings (batch, position, d_model)."""
        columns = self.W_E[:, x]
        # Move the embedding axis from the front to the back.
        return columns.permute(1, 2, 0)

class Unembed(nn.Module):
    """Linear read-out from the residual stream to vocabulary logits (no bias)."""
    def __init__(self, d_vocab, d_model):
        super().__init__()
        init_scale = np.sqrt(d_vocab)
        self.W_U = nn.Parameter(torch.randn(d_model, d_vocab)/init_scale)

    def forward(self, x):
        """Map activations (..., d_model) to logits (..., d_vocab)."""
        return torch.matmul(x, self.W_U)

# Positional Embeddings
class PosEmbedSine(nn.Module):
    """Fixed sinusoidal positional embedding, added to the input.

    Even feature columns hold cosines and odd columns hold sines of
    ``position / 10000**(i/d_model)`` for ``i`` in ``range(d_model//2)``; the whole
    table is scaled by ``1/sqrt(d_model)``. The table is stored as a parameter with
    ``requires_grad=False``.
    """
    def __init__(self, max_ctx, d_model):
        super().__init__()
        positions = torch.arange(max_ctx)[:, None]
        angles = positions/10000**(torch.arange(d_model//2)/d_model)
        table = torch.zeros((max_ctx, d_model))
        table[:, ::2] = torch.cos(angles)
        table[:, 1::2] = torch.sin(angles)
        table /= np.sqrt(d_model)
        self.sinusoidal_pos_embeds = nn.Parameter(table, requires_grad=False)

    def forward(self, x):
        """Add the first ``x.shape[-2]`` rows of the table to ``x`` (..., position, d_model)."""
        n_positions = x.shape[-2]
        return x + self.sinusoidal_pos_embeds[:n_positions]

class PosEmbed(nn.Module):
    """Learned positional embedding: one trainable ``d_model`` vector per position."""
    def __init__(self, max_ctx, d_model):
        super().__init__()
        initial = torch.randn(max_ctx, d_model)/np.sqrt(d_model)
        self.W_pos = nn.Parameter(initial)

    def forward(self, x):
        """Add the embeddings of the first ``x.shape[-2]`` positions to ``x``."""
        n_positions = x.shape[-2]
        return x + self.W_pos[:n_positions]

# LayerNorm
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and shift.

    ``epsilon`` is added to the standard deviation (not the variance), and the
    standard deviation is ``Tensor.std``'s default estimate.
    """
    def __init__(self, d_model, epsilon = 1e-4):
        super().__init__()
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        """Normalise ``x`` (..., d_model) along its last axis, then scale and shift."""
        centred = x - x.mean(dim=-1, keepdim=True)
        normalised = centred / (centred.std(dim=-1, keepdim=True) + self.epsilon)
        return normalised * self.w_ln + self.b_ln

class HookPoint(nn.Module):
    """Identity module that serves as a named attachment point for PyTorch hooks.

    Wrapping an intermediate activation in a HookPoint lets callers register forward
    or backward hooks on it; every hook also receives the point's name through a
    ``name`` keyword argument.
    """
    def __init__(self):
        super().__init__()
        # Handles of the hooks currently registered, kept so they can be removed later.
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Record this point's name. Called by the model at initialisation."""
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        """Register ``hook`` in direction ``dir`` ('fwd' or 'bwd').

        ``hook`` is called as ``hook(module, input, output, name=...)``.

        Raises:
            ValueError: if ``dir`` is neither 'fwd' nor 'bwd'.
        """
        if dir not in ('fwd', 'bwd'):
            raise ValueError(f"Invalid direction {dir}")
        named_hook = partial(hook, name=self.name)
        if dir == 'fwd':
            self.fwd_hooks.append(self.register_forward_hook(named_hook))
        else:
            self.bwd_hooks.append(self.register_backward_hook(named_hook))

    def remove_hooks(self, dir='fwd'):
        """Remove all hooks in direction ``dir`` ('fwd', 'bwd' or 'both').

        Raises:
            ValueError: if ``dir`` is not one of the three accepted values.
        """
        if dir not in ['fwd', 'bwd', 'both']:
            raise ValueError(f"Invalid direction {dir}")
        if dir != 'bwd':
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir != 'fwd':
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def forward(self, x):
        """Return the input unchanged."""
        return x


# Attention
class Attention(nn.Module):
    """Causal multi-head self-attention with a hook point on every intermediate activation.

    Args:
        d_model: width of the residual stream.
        num_heads: number of attention heads.
        d_head: dimension of each head's key/query/value vectors.
        n_ctx: maximum sequence length (size of the causal mask).
        model: reference to the owning model, stored as given.
    """
    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        proj_shape = (num_heads, d_head, d_model)
        proj_scale = np.sqrt(d_head)
        self.W_K = nn.Parameter(torch.randn(*proj_shape)/proj_scale)
        self.W_Q = nn.Parameter(torch.randn(*proj_shape)/proj_scale)
        self.W_V = nn.Parameter(torch.randn(*proj_shape)/proj_scale)
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads)/np.sqrt(d_model))
        # Lower-triangular matrix of ones; stored as a non-trainable parameter.
        lower_triangle = torch.tril(torch.ones((n_ctx, n_ctx)))
        self.mask = nn.Parameter(lower_triangle, requires_grad=False)
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, position, d_model); output has the same shape."""
        n_positions = x.shape[-2]
        # Per-head projections, indexed (batch, head, position, d_head).
        k = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        q = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        v = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        # Scores indexed (batch, head, query position, key position).
        scores = torch.einsum('biph,biqh->biqp', k, q)
        causal = self.mask[:n_positions, :n_positions]
        # A large negative offset on future keys drives their softmax weight to zero.
        masked_scores = torch.tril(scores) - 1e10 * (1 - causal)
        pattern = self.hook_attn(F.softmax(masked_scores/np.sqrt(self.d_head), dim=-1))
        z = self.hook_z(torch.einsum('biph,biqp->biqh', v, pattern))
        # Concatenate the heads along the feature axis before the output projection.
        merged = einops.rearrange(z, 'b i q h -> b q (i h)')
        return torch.einsum('df,bqf->bqd', self.W_O, merged)

# Transformer Block
class AttnBlock(nn.Module):
    """Attention-only residual block: returns the input plus the attention output.

    Hook points expose the block input (``hook_resid_pre``), the attention output
    (``hook_attn_out``) and the block output (``hook_resid_post``). ``act_type`` is
    accepted but not used.
    """
    def __init__(self, d_model, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_post = HookPoint()
        # self.ln = LayerNorm(d_model)

    def forward(self, x):
        """Return ``x + attention(x)`` for ``x`` of shape (batch, position, d_model)."""
        # return self.hook_resid_post(x + self.hook_attn_out(self.attn(self.ln(self.hook_resid_pre(x)))))
        resid_pre = self.hook_resid_pre(x)
        attn_out = self.hook_attn_out(self.attn(resid_pre))
        return self.hook_resid_post(x + attn_out)



# Full transformer
class AttnOnlyTransformer(nn.Module):
    """Attention-only transformer: embed, add positions, attention blocks, unembed.

    Args:
        model_config: dict with keys ``num_layers``, ``d_vocab``, ``d_model``,
            ``d_head``, ``num_heads``, ``n_ctx``, ``act_type`` and
            ``pos_embed_type`` ('sine' or 'orig').

    Raises:
        ValueError: if ``pos_embed_type`` is not 'sine' or 'orig'.
    """
    def __init__(self, model_config):
        super().__init__()

        num_layers = model_config['num_layers']
        d_vocab = model_config['d_vocab']
        d_model = model_config['d_model']
        d_head = model_config['d_head']
        num_heads = model_config['num_heads']
        n_ctx = model_config['n_ctx']
        act_type = model_config['act_type']
        pos_embed_type = model_config['pos_embed_type']

        self.cache = {}

        self.embed = Embed(d_vocab, d_model)
        if pos_embed_type == 'sine':
            self.pos_embed = PosEmbedSine(n_ctx, d_model)
        elif pos_embed_type == 'orig':
            self.pos_embed = PosEmbed(n_ctx, d_model)
        else:
            # Fail here rather than with an AttributeError on the first forward pass.
            raise ValueError(f"Unknown pos_embed_type {pos_embed_type!r}")

        # The model is passed wrapped in a list so it is not registered as a submodule
        # of its own blocks.
        self.attn_blocks = nn.ModuleList([AttnBlock(d_model, d_head, num_heads, n_ctx, act_type, model=[self]) for i in range(num_layers)])
        # self.ln = LayerNorm(d_model)
        self.unembed = Unembed(d_vocab, d_model)

        # Give every hook point its fully qualified module name.
        for name, module in self.named_modules():
            if type(module)==HookPoint:
                module.give_name(name)

    def forward(self, x):
        """Map token ids (batch, position) to logits (batch, position, d_vocab)."""
        x = self.embed(x)
        x = self.pos_embed(x)
        for attn_block in self.attn_blocks:
            # AttnBlock.forward already returns residual + attention output, so its
            # result IS the new residual stream; adding x again would double it.
            x = attn_block(x)
        # x = self.unembed(self.ln(x))
        return self.unembed(x)

    def hook_points(self):
        """Return every submodule whose qualified name contains 'hook'."""
        return [module for name, module in self.named_modules() if 'hook' in name]

    def remove_all_hooks(self):
        """Remove all forward and backward hooks from every hook point."""
        for hp in self.hook_points():
            hp.remove_hooks('fwd')
            hp.remove_hooks('bwd')

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that store every hook point's activation in ``cache``.

        Forward values are stored under the hook point's name. With ``incl_bwd``,
        the first element of the backward hook's last argument is stored under
        ``name + '_grad'``.
        """
        def save_hook(mod, inp, outp, name):
            cache[name] = outp.detach()
        def save_hook_back(mod, inp, outp, name):
            cache[name+'_grad'] = outp[0].detach()
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')

Data creation

In [None]:
# Build the training data: a fixed dataset with a held-out test set (finite) or an
# endless batch generator (infinite).
if data_config['is_finite']:
    # Draw num_batches * batch_size training sequences in one go and serve them in
    # shuffled minibatches.
    data_config_temp = dict(data_config)
    data_config_temp['batch_size'] = data_config['batch_size']*data_config['num_batches']
    train_data, train_mask = next(generate_dataset(**data_config_temp))
    train_data_loader = DataLoader([(train_data[i], train_mask[i]) for i in range(train_data.shape[0])], batch_size=data_config['batch_size'], shuffle=True)
    # Test set: 5 * batch_size sequences generated with a different seed.
    data_config_temp['batch_size'] = data_config['batch_size']*5
    data_config_temp['seed']+=1
    test_data, test_mask = next(generate_dataset(**data_config_temp))
    test_data = test_data.to('cuda')
    test_mask = test_mask.to('cuda')

else:
    ds = generate_dataset(**data_config)
# ds = generate_dataset_just_reps(**data_config)

    # Peek at one batch: shapes, then mask and tokens side by side for the first sequence.
    # This consumes the first batch of `ds`.
    tokens, mask = (next(ds))
    print(tokens.shape, mask.shape)
    print(torch.stack([mask[0], tokens[0]], axis=1))

torch.Size([128, 85]) torch.Size([128, 85])
tensor([[  0.,   0.],
        [  0., 110.],
        [  0., 127.],
        [  0.,  67.],
        [  0.,  93.],
        [  0.,  99.],
        [  0., 103.],
        [  0.,  18.],
        [  0.,  84.],
        [  0., 107.],
        [  0., 124.],
        [  0.,  58.],
        [  0.,  87.],
        [  0.,  98.],
        [  0.,  97.],
        [  0., 114.],
        [  0., 127.],
        [  0.,  48.],
        [  0.,  74.],
        [  0.,  33.],
        [  0.,  47.],
        [  0.,  97.],
        [  0., 112.],
        [  0.,  26.],
        [  0.,  84.],
        [  0.,  79.],
        [  0., 126.],
        [  0.,  37.],
        [  0.,  97.],
        [  0.,  81.],
        [  0.,  69.],
        [  0.,  50.],
        [  0.,  56.],
        [  0.,  68.],
        [  0.,   3.],
        [  0.,  85.],
        [  0.,  40.],
        [  0.,  67.],
        [  1.,  18.],
        [  1.,  84.],
        [  1., 107.],
        [  1., 124.],
        [  1.,  58.],
        [ 

In [None]:
# Seed PyTorch's global RNG before building the model so initial weights are reproducible.
torch.manual_seed(model_config['seed'])
model = AttnOnlyTransformer(model_config)
model.to('cuda')
optimizer = optim.AdamW(model.parameters(), lr = optim_config['lr'], weight_decay = optim_config['weight_decay'])
# Linear warm-up: the learning-rate multiplier ramps from 0 to 1 over the first 100 scheduler steps.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/100, 1))

In [None]:
def cross_entropy(logits, tokens):
    """Mean next-token cross-entropy.

    The logit at position ``p`` is scored against the token at position ``p+1``,
    so the last logit and the first token are dropped.

    Args:
        logits: float tensor of shape (batch, position, vocab).
        tokens: integer tensor of shape (batch, position).

    Returns:
        Scalar tensor: the negative log-probability of the correct next token,
        averaged over all batch elements and positions.
    """
    pred_logits = logits[:, :-1]
    # Targets come from `tokens`, shifted one position to the left.
    targets = tokens[:, 1:]
    log_probs = pred_logits.log_softmax(dim=-1)
    pred_log_probs = log_probs.gather(-1, targets[..., None])[..., 0]
    return -pred_log_probs.mean()

In [None]:
# Smoke test: run the untrained model on a small random batch (5 sequences of 5 tokens).
tokens = torch.randint(low=1, high=vocab_size, size=(5, 5), device='cuda')
# NOTE(review): a bare expression that is not the last line of a cell displays nothing.
tokens
logits = model(tokens)
print(logits)
# print(cross_entropy(logits, tokens))

tensor([[[-1.9238e+00, -7.6283e-01,  5.9798e-01,  ..., -1.5771e+00,
          -3.3662e-02,  3.0776e-01],
         [-8.4826e-01,  9.1209e-01,  6.0311e-01,  ..., -1.6651e+00,
          -6.9614e-01, -5.0488e-01],
         [-9.0552e-01, -1.8423e-01,  1.0147e+00,  ..., -6.6479e-01,
          -5.1940e-01,  2.9453e-01],
         [-6.0241e-01,  5.8092e-01,  3.4893e-01,  ..., -1.2229e+00,
          -9.6616e-02,  6.4790e-01],
         [-9.1785e-01, -5.1596e-01,  1.0497e-02,  ..., -1.9060e+00,
          -9.8223e-01, -4.9816e-02]],

        [[-2.8323e+00,  3.5125e-01,  7.9934e-01,  ..., -3.7355e+00,
          -1.0634e+00,  7.9991e-01],
         [-9.8148e-01,  7.4357e-01, -3.1180e-01,  ..., -9.8548e-01,
          -6.3240e-01,  3.8766e-01],
         [-3.2494e-01, -3.0337e-01, -9.6819e-02,  ..., -2.4075e-01,
          -1.5758e-03,  6.4609e-02],
         [-2.8085e-01,  6.0991e-01, -1.1624e-01,  ..., -2.6191e-01,
          -2.8087e-01,  5.7244e-01],
         [-9.6088e-03, -1.2049e-01,  1.0920e-01,  ...

In [None]:
# Reference loss values for reading the training curves.
# Loss of a uniform guess over the d_vocab-1 tokens that can appear after position 0.
uniform_loss = np.log(d_vocab - 1)
# Uniform loss scaled by the fraction of the seq_len-1 predicted positions that are NOT
# among the (num_repeats-1)*(subseq_len-1) positions inside a repeated subsequence;
# presumably the best average loss if repeats are predicted perfectly and the rest is
# guessed uniformly -- TODO confirm.
optimal_loss = np.log(d_vocab - 1) * (seq_len - 1 -(num_repeats-1)*(subseq_len-1))/(seq_len - 1)
print('Uniform Loss:', uniform_loss)
print('Optimal Loss:', optimal_loss)

Uniform Loss: 4.844187086458591
Optimal Loss: 3.1141202698662376


Loss fn

In [None]:
def cross_entropy_masked(logits, tokens, mask, cutoff=0):
    """Next-token cross-entropy averaged over the masked positions only.

    The logit at position ``p`` is scored against the token at position ``p+1`` and
    weighted by ``mask[:, p]``. Positions before ``cutoff`` and the final position
    are excluded.

    Args:
        logits: float tensor of shape (batch, ctx, vocab).
        tokens: integer tensor of shape (batch, ctx).
        mask: tensor of shape (batch, ctx); non-zero where the prediction counts.
        cutoff: number of leading positions to ignore.

    Returns:
        Scalar tensor: the mask-weighted mean negative log-probability. It is NaN
        when no mask entry inside the scored window is set.
    """
    # Shapes batch x ctx x vocab, batch x ctx, batch x ctx
    l = logits[:, cutoff:-1]
    t = tokens[:, 1+cutoff:]
    # Slice the mask exactly like the logits so the normaliser counts only scored positions.
    m = mask[:, cutoff:-1]
    logprobs = F.log_softmax(l, dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=t[:, :, None], dim=-1)[:, :, 0]
    # Normalise by the mask total inside the window, not by the full mask's
    # numel()/sum(), which scaled the loss by ctx/(ctx-1) or more.
    return -(prediction_logprobs * m).sum() / m.sum()

Infinite data training

In [None]:
# Infinite-data regime: every step trains on a freshly generated batch.
if not data_config['is_finite']:
    run_name = f"grok_induction_infinite_{int(time.time())}"
    # Exponentially weighted moving average of the loss, used for smoother progress logs.
    loss_ewma = 5.2
    loss_beta = 0.99
    train_losses = []
    for step in tqdm.tqdm(range(optim_config['num_epochs'])):
        # Batches are generated on the CPU and moved to the GPU here.
        tokens, mask = (next(ds))
        tokens = tokens.to('cuda')
        mask = mask.to('cuda')
        logits = model(tokens)
        loss = cross_entropy_masked(logits, tokens, mask)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # From here on `loss` is a Python float, not a tensor.
        loss = loss.item()
        loss_ewma = loss_beta * loss_ewma + (1-loss_beta) * loss
        train_losses.append(loss)
        # Log: step, smoothed loss, raw loss.
        if step % 50 == 0:
            print(step, loss_ewma, loss)
        if step % 2000 == 0 and step>0:
            px.line(train_losses).show()

  0%|          | 0/6000 [00:00<?, ?it/s]

0 5.198708530398921 5.070853039892095
50 5.1347850132185 5.017400270183791
100 5.072127880628849 4.950358738694044
150 5.014603326635928 4.910722228480384
200 4.9674440635752255 4.878750908728097
250 4.927089332887405 4.850178688287824
300 4.8894531721051795 4.827229696259434
350 4.849251952818847 4.749837403573717
400 4.798642773539804 4.68058216638146
450 4.746607703525436 4.629260581648273
500 4.704810507264168 4.650207082983032
550 4.668284355653492 4.581873239895757
600 4.633850802806748 4.59181951106866
650 4.596435074057646 4.535093872493784
700 4.558401715007702 4.480438318912453
750 4.511788070651714 4.409600285111025
800 4.389210580496837 3.990467095516838
850 4.043444491309172 3.0440824774965103
900 3.457720465730496 2.170818824663617
950 2.8401044949393066 1.6717066325042889
1000 2.2815556710324576 1.2721886304245835
1050 1.8412405932582459 1.1403186385719413
1100 1.5116249100112649 0.8528238515914093
1150 1.2693568390114163 0.8781049689752685
1200 1.089882488023431 0.69709

2050 0.4521082666438918 0.45571281389725626
2100 0.4453627867729178 0.4273876849885371
2150 0.44037980686457784 0.43640999255581836
2200 0.43709159571859646 0.4642181935512849
2250 0.43348162323397293 0.42180217143127374
2300 0.4257909369082881 0.38648746771863646
2350 0.40415310696485734 0.318581777064077
2400 0.3807080283510711 0.30381543781098547
2450 0.3573892486293462 0.3334900455996973
2500 0.33957055797199304 0.33210879103149266
2550 0.3273058609724144 0.3025668509669586
2600 0.321459611661896 0.34914543619058186
2650 0.31532807936746576 0.2929284849605479
2700 0.3115765960125252 0.28245160915487155
2750 0.3088536226334721 0.2662684967218313
2800 0.30575624930789985 0.31101975478461114
2850 0.3039198442363036 0.3120920102396044
2900 0.3013667445984078 0.2956302539018158
2950 0.30114360580643523 0.3023615758564837
3000 0.3002486621060656 0.2898743129653053
3050 0.298008356633825 0.2988659430471086
3100 0.295798193639824 0.27799332167347396
3150 0.29635491945450376 0.3098831176412

4050 0.27936612539364497 0.24959494443109292
4100 0.2798955905743231 0.34436923763332233
4150 0.2761643581085577 0.25257361825077085
4200 0.2754098231457297 0.28313485312777753
4250 0.27435842290465073 0.28004720825951257
4300 0.27425050346941354 0.24017166303674944
4350 0.2753821821777369 0.2723176951204236
4400 0.2747042210493737 0.27108625156299165
4450 0.27289229920359226 0.29680984165826146
4500 0.2739427759979218 0.2771792474940879
4550 0.2726358336318416 0.26868022461421975
4600 0.2762738847757033 0.30379312523071106
4650 0.27416808933101267 0.25236732244812177
4700 0.27208019509168574 0.26189409819554943
4750 0.2721732537987758 0.2574995136954828
4800 0.27158386926394656 0.28579161649419516
4850 0.27081430343983554 0.23801403563851883
4900 0.2703561479844958 0.26982198321104034
4950 0.27245931849149313 0.2833643743423424
5000 0.27064934142902625 0.2570248705738761
5050 0.2713138347227124 0.2717971341417881
5100 0.2702241425999567 0.2431817078241079
5150 0.26922570916281935 0.26

Finite data training

In [None]:
#Finite data
# Finite-data regime: minibatch training over the fixed dataset, with a full
# test-set evaluation after every epoch and a final checkpoint at the end.
if data_config['is_finite']:
    run_name = f"grok_induction_finite_{int(time.time())}"
    # Exponentially weighted moving average of the training loss, for smoother logs.
    loss_ewma = 5.2
    loss_beta = 0.99
    # train_losses gets one entry per minibatch; test_losses one entry per epoch.
    train_losses = []
    test_losses = []
    for epoch in range(optim_config['num_epochs']):
        # NOTE: this rebinds the global `ds` to the DataLoader iterator.
        ds = iter(train_data_loader)
        for step, (tokens, mask) in enumerate(ds):
            tokens = tokens.to('cuda')
            mask = mask.to('cuda')
            logits = model(tokens)
            loss = cross_entropy_masked(logits, tokens, mask)
            loss.backward()
            optimizer.step()
            # The scheduler advances once per minibatch, not once per epoch.
            scheduler.step()
            optimizer.zero_grad()
            loss = loss.item()
            loss_ewma = loss_beta * loss_ewma + (1-loss_beta) * loss
            train_losses.append(loss)
            # if step %10 == 0:
            #     print(step, loss_ewma, loss)
        # NOTE(review): this evaluation is not wrapped in torch.no_grad().
        test_loss = cross_entropy_masked(model(test_data), test_data, test_mask).item()
        test_losses.append(test_loss)
        # epoch*num_batches is the number of optimisation steps taken before this epoch.
        if (epoch*data_config['num_batches'])%60 == 0 :
            print('Epoch', epoch, loss_ewma, test_loss)
        if (epoch*data_config['num_batches'] % 1000)==0 and epoch>0:
            # Subsample the per-minibatch train losses to one per epoch so both curves align.
            lines([train_losses[::data_config['num_batches']], test_losses])
    # Save the final model, optimiser and scheduler state together with the loss
    # history and all configs.
    # NOTE(review): `root` is not defined anywhere in the visible notebook, and the
    # run directory is not created here -- confirm both before relying on this save.
    torch.save({
    'model':model.state_dict(),
    'optimizer':optimizer.state_dict(),
    'scheduler':scheduler.state_dict(),
    'train_losses':train_losses,
    'test_losses':test_losses,
    'data_config':data_config,
    'optim_config':optim_config,
    'model_config':model_config,
    }, root/run_name/'final.pth')
