In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
import math
import numpy as np
import torch as t
import torch.utils.data as tdata
import matplotlib.pyplot as plt
from einops import rearrange
from dots.training import *
from dots.trainhooks import *
from dots.models import MLP
from dots.dots import *
from dots.utils import get_device
from dots.plotting import *

In [2]:
import random
import time
from pathlib import Path
import pickle
import os
import copy
from collections import OrderedDict
import functools
import itertools
import gc

import numpy as np
import matplotlib.pyplot as plt
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import pandas as pd



In [3]:

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class Embed(nn.Module):
    """Token embedding: selects one learned column of ``W_E`` per token id."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        # Stored as (d_model, d_vocab); Gaussian noise divided by sqrt(d_model).
        initial = torch.randn(d_model, d_vocab) / np.sqrt(d_model)
        self.W_E = nn.Parameter(initial)

    def forward(self, x):
        """Embed integer token ids.

        Indexing the vocabulary axis of ``W_E`` with a two-dimensional id tensor
        puts the model dimension first; it is moved to the end, so an input of
        shape (batch, pos) produces an output of shape (batch, pos, d_model).
        """
        columns = self.W_E[:, x]
        return columns.permute(1, 2, 0)

class Unembed(nn.Module):
    """Output projection from model-dimension vectors to vocabulary logits."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        # Shape (d_model, d_vocab). Note the init divides by sqrt(d_vocab),
        # unlike Embed, which divides by sqrt(d_model).
        initial = torch.randn(d_model, d_vocab) / np.sqrt(d_vocab)
        self.W_U = nn.Parameter(initial)

    def forward(self, x):
        """Right-multiply ``x`` (last axis of size d_model) by ``W_U``."""
        logits = torch.matmul(x, self.W_U)
        return logits


class PosEmbed(nn.Module):
    """Learned positional embedding added to the input."""

    def __init__(self, max_ctx, d_model):
        super().__init__()
        # One d_model-sized row per position, up to max_ctx positions.
        initial = torch.randn(max_ctx, d_model) / np.sqrt(d_model)
        self.W_pos = nn.Parameter(initial)

    def forward(self, x):
        """Add the first ``x.shape[-2]`` positional rows to ``x``.

        The second-to-last axis of ``x`` is treated as the position axis, so
        sequences shorter than ``max_ctx`` only use a prefix of ``W_pos``.
        """
        seq_len = x.shape[-2]
        positional = self.W_pos[:seq_len]
        return x + positional


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis, switchable through the owning model.

    Args:
        d_model: Size of the last axis; one learned scale and one learned bias
            are kept per feature.
        epsilon: Added to the standard deviation before dividing.
        model: One-element list holding the owning model, whose ``use_ln``
            attribute turns normalisation on or off. Wrapping the owner in a
            list presumably keeps it from being registered as a submodule.
            When omitted, no owner is attached and the layer always normalises.
    """

    def __init__(self, d_model, epsilon = 1e-4, model=None):
        super().__init__()
        # A fresh list per instance: the previous `model=[None]` default was a
        # single list object shared by every LayerNorm built without `model`.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        """Return the normalised, scaled and shifted ``x``, or ``x`` itself when disabled."""
        owner = self.model[0]
        # Without an owner there is no flag to consult; normalise instead of
        # failing on `None.use_ln` as the previous version did.
        if owner is not None and not owner.use_ln:
            return x
        centred = x - x.mean(dim=-1, keepdim=True)
        # torch's default std (Bessel-corrected) is used, as in the original code.
        scaled = centred / (centred.std(dim=-1, keepdim=True) + self.epsilon)
        return scaled * self.w_ln + self.b_ln


class Attention(nn.Module):
    """Causal multi-head self-attention with explicit per-head weight tensors.

    ``W_K``, ``W_Q`` and ``W_V`` have shape (num_heads, d_head, d_model);
    ``W_O`` has shape (d_model, num_heads * d_head). ``mask`` is a
    lower-triangular (n_ctx, n_ctx) buffer of ones. ``model`` is stored but
    not read by this class.
    """

    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        init_scale = np.sqrt(d_model)
        # Parameters are created in K, Q, V, O order.
        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model) / init_scale)
        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model) / init_scale)
        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model) / init_scale)
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads) / init_scale)
        causal_ones = torch.tril(torch.ones((n_ctx, n_ctx)))
        self.register_buffer('mask', causal_ones)
        self.d_head = d_head

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (batch, pos, d_model).

        Einsum letters: b = batch, i = head, h = within-head dimension,
        d = model dimension, p = key position, q = query position.
        """
        seq_len = x.shape[-2]
        keys = torch.einsum('ihd,bpd->biph', self.W_K, x)
        queries = torch.einsum('ihd,bpd->biph', self.W_Q, x)
        values = torch.einsum('ihd,bpd->biph', self.W_V, x)

        # scores[b, i, q, p]: query position q against key position p.
        raw_scores = torch.einsum('biph,biqh->biqp', keys, queries)
        # Entries with p > q are zeroed, then pushed to -1e10, so they receive
        # no weight after the softmax. The 1/sqrt(d_head) scaling is applied
        # after masking.
        visible = self.mask[:seq_len, :seq_len]
        masked_scores = torch.tril(raw_scores) - 1e10 * (1 - visible)
        weights = F.softmax(masked_scores / np.sqrt(self.d_head), dim=-1)

        mixed = torch.einsum('biph,biqp->biqh', values, weights)
        # Concatenate heads along the feature axis before the output projection.
        merged = einops.rearrange(mixed, 'b i q h -> b q (i h)')
        return torch.einsum('df,bqf->bqd', self.W_O, merged)


class MLP(nn.Module):
    """Two-layer feed-forward block: d_model -> d_mlp -> d_model with ReLU or GeLU.

    Note: the first cell of this notebook also imports a name ``MLP`` from
    ``dots.models``; defining this class rebinds that name.
    ``model`` is stored but not read by this class.
    """

    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        self.model = model
        width_scale = np.sqrt(d_model)
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model) / width_scale)
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp) / width_scale)
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        # self.ln = LayerNorm(d_mlp, model=self.model)
        assert act_type in ['ReLU', 'GeLU']

    def forward(self, x):
        """Map ``x`` of shape (batch, pos, d_model) to a tensor of the same shape."""
        hidden = torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in

        # Any other act_type value leaves the hidden activations untouched,
        # exactly as the original if/elif chain did.
        nonlinearity = {'ReLU': F.relu, 'GeLU': F.gelu}.get(self.act_type)
        if nonlinearity is not None:
            hidden = nonlinearity(hidden)

        return torch.einsum('dm,bpm->bpd', self.W_out, hidden) + self.b_out


class TransformerBlock(nn.Module):
    """One residual block: attention followed by an MLP, each added to its input.

    No LayerNorm is applied in this version; the ``ln1``/``ln2`` layers are
    left commented out below.
    """

    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        # self.ln1 = LayerNorm(d_model, model=self.model)
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        # self.ln2 = LayerNorm(d_model, model=self.model)
        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(x)
        after_mlp = after_attn + self.mlp(after_attn)
        return after_mlp


class Transformer(nn.Module):
    """Decoder-style transformer: embed, add positions, run the blocks, unembed.

    ``use_ln`` is stored on the model, but with every LayerNorm commented out
    in this version nothing in the visible code reads it. Each block receives
    the model wrapped in its own one-element list (``model=[self]``).
    """

    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_ln=True):
        super().__init__()
        self.use_ln = use_ln

        # Submodules are registered in the same order as before:
        # embed, pos_embed, blocks, unembed.
        self.embed = Embed(d_vocab, d_model)
        self.pos_embed = PosEmbed(n_ctx, d_model)
        layers = []
        for _ in range(num_layers):
            layers.append(
                TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self])
            )
        self.blocks = nn.ModuleList(layers)
        # self.ln = LayerNorm(d_model, model=[self])
        self.unembed = Unembed(d_vocab, d_model)

    def forward(self, x):
        """Map integer token ids to logits over the vocabulary at every position."""
        resid = self.pos_embed(self.embed(x))
        for layer in self.blocks:
            resid = layer(resid)
        # x = self.ln(x)
        return self.unembed(resid)


In [5]:
def count_parameters(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    n_scalars = 0
    for param in model.parameters():
        n_scalars += param.numel()
    return n_scalars


In [6]:

# Hyperparameters and model dimensions for the modular-arithmetic transformer.
# Several of these (lr, weight_decay, fn_name, NUM_EPOCHS, save_*, stopping_thresh,
# batch_style) are not referenced by any visible cell of this notebook.
lr=1e-3 #@param
weight_decay = 0.1 #@param 
p=113 #@param
d_model = 128 #@param
fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']
frac_train = 0.3 #@param
NUM_EPOCHS = 100000 #@param
save_models = False #@param
save_every = 1000 #@param
# Stop training when test loss is <stopping_thresh
stopping_thresh = -1 #@param
seed = 0 #@param
num_layers = 1
batch_style = 'full'
# One token more than p: the inputs built later use the value p itself as their third token.
d_vocab = p+1 
# Every input is a three-token sequence (i, j, p).
n_ctx = 3
d_mlp = 4*d_model
num_heads = 4
# d_head below uses integer division, so d_model must split evenly across heads.
assert d_model % num_heads == 0
d_head = d_model//num_heads
act_type = 'ReLU' #@param ['ReLU', 'GeLU']
#batch_size = BATCH_SIZE

In [7]:
def gen_train_test(frac_train = 0.3, num = 113, seed=0):
    """Enumerate every (i, j, num) triple for i, j in range(num) and split it.

    The triples are shuffled after seeding Python's global ``random`` module
    with ``seed`` (so the global RNG state is reset as a side effect). The
    first ``int(frac_train * num**2)`` triples form the training list and the
    remainder the test list.

    Returns:
        (train, test): two lists of ``(i, j, num)`` tuples.
    """
    random.seed(seed)
    examples = []
    for first in range(num):
        for second in range(num):
            examples.append((first, second, num))
    random.shuffle(examples)
    n_train = int(frac_train * len(examples))
    train_part = examples[:n_train]
    test_part = examples[n_train:]
    return train_part, test_part

# Build the split from the configured modulus `p` and `seed` rather than repeating
# their values (113 and 0) as literals, so changing the config cell changes the data.
train_x, test_x = gen_train_test(frac_train, p, seed=seed)
print(len(train_x), len(test_x))

3830 8939


In [8]:
# Peek at the first held-out example; each example is an (i, j, num) tuple.
test_x[0]

(50, 67, 113)

In [9]:

def collate_batch(batch):
    """Turn a list of ``(i, j, num)`` tuples into input and label tensors on DEVICE.

    Returns:
        (inputs, labels): ``inputs`` holds one row per example, containing the
        tuple's entries; ``labels`` is an int64 vector of ``(i + j) % num``.
    """
    targets = torch.tensor(
        [(a + b) % modulus for (a, b, modulus) in batch],
        dtype=torch.int64,
    )
    per_example = list(map(torch.tensor, batch))
    # Concatenate all entries into one flat vector, then fold back to one row per example.
    flat = torch.cat(per_example)
    inputs = flat.view(len(targets), -1)
    return inputs.to(DEVICE), targets.to(DEVICE)

# One batch per loader: batch_size equals the dataset length, so each loader
# yields the whole split at once (the printed loader length confirms this).
train_loader = DataLoader(train_x, batch_size=len(train_x), collate_fn=collate_batch)
test_loader = DataLoader(test_x, batch_size=len(test_x), collate_fn=collate_batch)

print(len(train_loader))
# Show the first ten input rows and labels of the training batch as a sanity check.
for x, y in train_loader:
    print(x[:10])
    print(y[:10])

1
tensor([[ 18,  34, 113],
        [ 10,  83, 113],
        [ 55,  28, 113],
        [ 25,  51, 113],
        [ 63,  65, 113],
        [ 14, 104, 113],
        [ 27,  56, 113],
        [ 35,  77, 113],
        [ 41,  72, 113],
        [ 17, 111, 113]], device='cuda:0')
tensor([ 52,  93,  83,  76,  15,   5,  83, 112,   0,  15], device='cuda:0')


In [10]:
# Pull the single full-size training batch and split it into inputs and labels.
batch = next(iter(train_loader))
batch_x = batch[0]
batch_y = batch[1]

In [11]:
# Inspect the shape of the training inputs: one row per example.
batch_x.shape

torch.Size([3830, 3])

In [12]:
# Look at the first training input row.
batch_x[0]

tensor([ 18,  34, 113], device='cuda:0')

In [13]:
# NOTE(review): `load_model` is not defined in any visible cell — presumably it comes
# from one of the `dots.*` star imports in the first cell; confirm where "grokker" is read from.
gk = load_model("grokker")

In [14]:
# Display the module tree of the loaded model.
gk

Transformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (attn): Attention()
      (mlp): MLP()
    )
  )
  (unembed): Unembed()
)

In [15]:
# Total number of scalar parameters in the loaded model.
count_parameters(gk)

226816

In [16]:
# Forward a single example through the loaded model to check the output shape.
gk(batch_x[:1]).shape

torch.Size([1, 3, 114])

# Test transformer init

In [17]:
import dots.models as dm

In [18]:
# NOTE(review): this cell repeats the earlier hyperparameter cell line for line.
# Keeping two copies invites drift — consider defining these values once.
lr=1e-3 #@param
weight_decay = 0.1 #@param 
p=113 #@param
d_model = 128 #@param
fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']
frac_train = 0.3 #@param
NUM_EPOCHS = 100000 #@param
save_models = False #@param
save_every = 1000 #@param
# Stop training when test loss is <stopping_thresh
stopping_thresh = -1 #@param
seed = 0 #@param
num_layers = 1
batch_style = 'full'
d_vocab = p+1 
n_ctx = 3
d_mlp = 4*d_model
num_heads = 4
# d_head below uses integer division, so d_model must split evenly across heads.
assert d_model % num_heads == 0
d_head = d_model//num_heads
act_type = 'ReLU' #@param ['ReLU', 'GeLU']
#batch_size = BATCH_SIZE

In [19]:
# Instantiate the packaged `dots.models.Transformer` (not the class defined in this
# notebook) with the hyperparameters above, and move it to DEVICE.
transformer = dm.Transformer(
    num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_ln=True
).to(DEVICE)

In [20]:
# Parameter count of the packaged model; the recorded value matches count_parameters(gk) above.
transformer.count_params()

226816

In [21]:
# Shape of the tensor returned by get_param_tensor(); the recorded output has one entry per parameter.
transformer.get_param_tensor().shape

torch.Size([226816])

In [22]:
# NOTE(review): per the recorded outputs of the next two cells, this builds an
# 11400 x 226816 tensor of about 10.3 GB for just 100 inputs — expensive in memory.
# `matrix_jacobian` is not defined in a visible cell (presumably a `dots` star import).
mjac = matrix_jacobian(transformer, batch_x[:100])



In [23]:
# Shape of the Jacobian matrix; the second dimension equals the parameter count.
mjac.shape

torch.Size([11400, 226816])

In [24]:
# Memory footprint of the Jacobian in bytes (bytes per element times element count).
mjac.element_size() * mjac.numel()

10342809600

In [26]:
# Record the installed torch version (`t` is the torch alias from the first cell).
t.__version__

'2.0.1+cu117'

In [29]:
# Call the model's jacobian_matrix_rank on the first 20 training inputs.
r = transformer.jacobian_matrix_rank(batch_x[:20])

In [30]:
# Display the value computed in the previous cell.
r

tensor(1275, device='cuda:0')

In [None]:
# NOTE(review): `S` is not assigned in any visible cell and this cell has no execution
# count. Unless a star import provides it, this raises NameError on a fresh kernel —
# define S or remove the cell.
S.shape