In [1]:
import torch
from mlp import make_dataset, CharacterLevelMLP, Linear, BatchNormalization1D, Tanh
import matplotlib.pyplot as plt
from torch.optim import Adam
from pathlib import Path
from torch.nn.functional import cross_entropy
from sklearn.model_selection import train_test_split
from torch import Tensor
%matplotlib inline

In [2]:
# Context length handed to make_dataset (reused further down as block_size).
context_window = 3
# Corpus file, resolved relative to the notebook's working directory.
names_file = Path("./names.txt")

# make_dataset returns three values; the names suggest inputs, targets and an
# index-to-string lookup -- confirm against mlp.make_dataset.
X, Y, itos = make_dataset(names_file, context_window=context_window)

In [3]:
# 80/10/10 train/validation/test split, done in two stages: first hold out
# 20% of the data, then cut that hold-out set in half.
# A fixed random_state makes the partition identical on every
# Restart & Run All; without it each run would train and evaluate on
# different rows and the reported numbers could not be reproduced.
SPLIT_SEED = 42
train_X, eval_X, train_Y, eval_Y = train_test_split(X, Y, test_size=0.2, random_state=SPLIT_SEED)
val_X, test_X, val_Y, test_Y = train_test_split(eval_X, eval_Y, test_size=0.5, random_state=SPLIT_SEED)

In [4]:
def cmp(s, dt, t:Tensor):
    """Print how a manually derived gradient compares with autograd's.

    Args:
        s: Label printed at the start of the report line.
        dt: Manually computed gradient tensor.
        t: Tensor whose ``.grad`` (filled in by ``backward()``) is the reference.

    Raises:
        ValueError: If ``t.grad`` is ``None``, i.e. no gradient was stored on ``t``.

    The report line states whether the two tensors are bit-for-bit equal,
    whether they agree within ``torch.allclose`` tolerances, and the largest
    absolute element-wise difference. If the shapes differ, a shape-mismatch
    line is printed instead and no values are compared.
    """
    expected = t.grad
    if expected is None:
        raise ValueError(
            f"{s}: tensor has no .grad; call retain_grad() on it before backward()"
        )
    if dt.shape != expected.shape:
        # ==, allclose and - all broadcast, so a gradient of the wrong shape
        # could otherwise be reported as a match.
        print(f"{s:15s} | shape mismatch: got {tuple(dt.shape)}, expected {tuple(expected.shape)}")
        return
    exact_match = torch.all(dt == expected).item()
    approximate_match = torch.allclose(dt, expected)
    max_difference = (dt - expected).abs().max().item()
    print(f"{s:15s} | exact {str(exact_match):5s} | approximate {str(approximate_match):5s} | max_diff {str(max_difference)}")

In [5]:
# Model hyperparameters.
n_embedding, n_hidden = 10, 64  # width of one character embedding, width of the hidden layer
vocab_size = 27                 # rows of the embedding table and columns of the output layer
block_size = context_window     # alias for the context length chosen when building the dataset

In [6]:
# A single seeded generator drives every random draw in this cell so that
# the initial parameters are the same on each fresh run.
g = torch.Generator().manual_seed(2147483647)

# Embedding table: one n_embedding-wide row per vocabulary entry.
C = torch.randn((vocab_size, n_embedding), generator=g)

# Hidden layer. W1 is scaled by (5/3)/sqrt(fan_in) with fan_in = n_embedding*block_size;
# 5/3 is the gain torch.nn.init.calculate_gain reports for tanh.
W1 = torch.randn((n_embedding*block_size, n_hidden), generator=g)*(5/3)/((n_embedding*block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g)*0.1

# Output layer, scaled down by 0.1.
W2 = torch.randn((n_hidden, vocab_size), generator=g)*0.1
b2 = torch.randn(vocab_size, generator=g)*0.1

# Batch-norm scale and shift, initialised near 1 and 0 respectively.
# These must draw from g like the other parameters; an unseeded torch.randn
# would give different values on every run.
bngain = torch.randn((1,n_hidden), generator=g)*0.1 + 1.0
bnbias = torch.randn((1,n_hidden), generator=g)*0.1

parameters = [C,W1,b1,W2,b2,bngain,bnbias]
print("Total parameters",sum(p.nelement() for p in parameters))
# Enable gradient tracking on every parameter.
for p in parameters:
    p.requires_grad = True

Total parameters 4137


In [7]:
# Sample one minibatch: batch_size row indices drawn uniformly from the
# training set using the seeded generator.
batch_size = 32
n = batch_size  # short alias used by the forward-pass formulas
ix = torch.randint(low=0, high=train_X.shape[0], size=(batch_size,), generator=g)
Xb = train_X[ix]
Yb = train_Y[ix]

##### Forward pass

In [13]:
# Forward pass written out one primitive operation per line, so that every
# intermediate tensor has a name and can have its gradient retained below.

# Embed the characters into vectors by indexing the embedding table C with the batch
emb = C[Xb]

# Concatenate the vectors: flatten everything after the batch dimension into one row per example
embcat = emb.view(emb.shape[0],-1)

# Linear layer 1
hprebn = embcat @ W1 + b1

# Batch norm 1 stage
bnmean1 = (1/n)*hprebn.sum(0, keepdim=True)  # per-column mean over the batch
bndiff = hprebn - bnmean1
bndiff2 = bndiff**2
bnvar = (1/(n-1))*(bndiff2).sum(0, keepdim=True)  # variance with Bessel's correction (divides by n-1, not n)
bnvar_inv = (bnvar+1e-5)**-0.5  # 1/sqrt(var + eps); eps = 1e-5 keeps the base strictly positive
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias  # scale and shift by the batch-norm parameters

# Non linearity activation
h = torch.tanh(hpreact)

# Linear layer 2
logits = h @ W2 + b2

# Cross Entropy Loss, expanded into elementary steps (softmax followed by log)
logit_maxes = logits.max(1, keepdim=True).values  # largest logit in each row
norm_logits = logits - logit_maxes  # shift so each row's max is 0: softmax is unchanged and exp() cannot overflow
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv  # each row now sums to 1
logprobs = probs.log()

# Calculate loss: mean negative log-probability at the target index Yb of each row
loss = -logprobs[range(n), Yb].mean()

# Clear gradients left over from any earlier backward pass
for p in parameters:
    p.grad = None
    
# The intermediates are non-leaf tensors, whose .grad autograd does not keep by
# default; retain_grad() makes backward() store it so it can be inspected afterwards
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmean1, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.2687, grad_fn=<NegBackward0>)

In [12]:
# Display the current value of `loss`.
loss

tensor(3.2687, grad_fn=<NegBackward0>)