# GROKKING: GENERALIZATION BEYOND OVERFITTING ON SMALL ALGORITHMIC DATASETS

### *deep dive*

### Step0: Imports

In [None]:
# basic
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import TensorDataset,DataLoader


# other

from tqdm import tqdm

### Step1: Generate data

In [None]:
# Binary operation table for modular addition: every ordered pair (x, y)
# with 0 <= x, y < p, labelled with (x + y) mod p.

p = 5
residues = torch.arange(p)
x, y = torch.meshgrid(residues, residues, indexing="ij")

# X has one row per pair, shape (p*p, 2); with "ij" indexing row i*p + j holds (i, j).
X = torch.stack((x.flatten(), y.flatten()), dim=1)
# Y holds the target residue for the matching row of X.
Y = X.sum(dim=1) % p

# Spot-check a single entry of the table (shown as the cell output).
X[14], Y[14]

In [None]:
# Split the full operation table into disjoint training and validation subsets.
f = 0.5  # fraction of all (x, y) pairs used for training
split = int(len(X) * f)

# Shuffle the row indices once and cut the permutation in two. The validation
# indices must be the *remaining* part of the permutation: taking the same
# slice twice would make the validation set identical to the training set and
# validation accuracy would only measure memorisation.
indices = torch.randperm(len(X))
idx_train, idx_val = indices[:split], indices[split:]
X_train, X_val = X[idx_train], X[idx_val]
Y_train, Y_val = Y[idx_train], Y[idx_val]

train_dataset = TensorDataset(X_train, Y_train)
val_dataset = TensorDataset(X_val, Y_val)
# NOTE: no batch_size is passed, so both loaders use DataLoader's default of 1.
trainloader = DataLoader(train_dataset, shuffle=True)
valloader = DataLoader(val_dataset, shuffle=False)

### Step2: The model

In [1]:
class MyNet(nn.Module):
    """Small Transformer-encoder classifier over tokens in ``range(p)``.

    Each input token is embedded, the sequence is passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, the per-token outputs are averaged
    over the sequence dimension, and a linear layer maps the pooled vector to
    ``p`` logits.

    NOTE: no positional encoding is added to the embeddings, so the encoder
    receives no explicit information about which operand came first.

    Args:
        p: Vocabulary size and number of output classes.
        d_model: Embedding / hidden width. Must be divisible by ``nhead``.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Width of the feed-forward sublayer in each layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, p=5, d_model=128, nhead=4, dim_feedforward=128, num_layers=2):
        super().__init__()
        self.embedder = nn.Embedding(p, d_model)
        # Attribute names (`embedder`, `dec`, `to_vocab`) are kept as-is so
        # previously saved state_dicts still load.
        dec_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.dec = nn.TransformerEncoder(dec_layer, num_layers=num_layers)
        self.to_vocab = nn.Linear(d_model, p)

    def forward(self, x):
        """Return class logits for a batch of token-index sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)`` with values in
                ``range(p)``.

        Returns:
            Float tensor of shape ``(batch, p)`` holding unnormalised logits.
        """
        embeddings = self.embedder(x)   # (batch, seq_len, d_model)
        h = self.dec(embeddings)        # (batch, seq_len, d_model)
        h = h.mean(dim=1)               # average over the sequence -> (batch, d_model)
        return self.to_vocab(h)         # (batch, p)

NameError: name 'nn' is not defined

In [None]:
# Smoke test: push a single (x, y) pair through a freshly initialised model.
model = MyNet(p=5)

# One sample made of two operand tokens -> index tensor of shape (1, 2).
inp = torch.tensor([1, 2]).unsqueeze(0)
out = model(inp)

# Show the raw logits as the cell output.
out

In [None]:
# Train the model with AdamW and periodically measure accuracy on the
# validation loader, appending one row per evaluation to `history`.
device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device : ", device)
# Batches are moved to `device` below, so the parameters must live there too;
# otherwise the forward pass fails on a CUDA machine.
model = MyNet(p=5).to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, betas=(0.9, 0.98))
steps = 1_000_000   # total optimisation-step budget
eval_every = 1000   # evaluate roughly every this many optimisation steps
loss_fn = nn.CrossEntropyLoss()

# Without any training batch `step` would never advance and the loop below
# would spin forever.
if len(trainloader) == 0:
    raise ValueError("trainloader is empty; nothing to train on")

step = 0
next_eval = eval_every
history = []  # rows of [step, last-batch loss, train_acc, val_acc]
while step < steps:
    model.train()
    train_correct = 0
    train_total = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        step += 1

        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

        # Accumulate over the whole epoch instead of keeping only the last
        # batch. These predictions are made in train mode (dropout active).
        train_correct += (logits.argmax(dim=1) == y).sum().item()
        train_total += y.size(0)

        if step >= steps:
            break

    train_acc = train_correct / train_total

    # `step` is only inspected at epoch boundaries and may jump past an exact
    # multiple of `eval_every`, so trigger on crossing the next threshold
    # rather than on `step % eval_every == 0`.
    if step >= next_eval:
        next_eval = (step // eval_every + 1) * eval_every
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for x_val, y_val in valloader:
                x_val, y_val = x_val.to(device), y_val.to(device)

                pred = model(x_val).argmax(dim=1)
                correct += (pred == y_val).sum().item()
                # Count samples, not the sum of the label values.
                total += y_val.size(0)
            # An empty validation loader yields NaN instead of a ZeroDivisionError.
            val_acc = correct / total if total else float("nan")
            print("Step : ", step, "val_acc : ", val_acc)
            # Keep the loss as a tensor, but on the CPU so `history` does not
            # hold GPU memory and can be plotted directly.
            history.append([step, loss.detach().cpu(), train_acc, val_acc])


