In [126]:
import torch
from mlp import make_dataset, CharacterLevelMLP, Linear, BatchNormalization1D, Tanh
import matplotlib.pyplot as plt
from torch.optim import Adam
from pathlib import Path
from torch.nn.functional import cross_entropy
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.functional import F
%matplotlib inline

In [2]:
# Location of the raw corpus, relative to the notebook's working directory.
names_file = Path("./names.txt")
# Passed straight to make_dataset as `context_window`; later reused as block_size.
# Presumably the number of preceding characters per example — confirm in mlp.py.
context_window = 3
# NOTE(review): make_dataset lives in mlp.py and is not visible here. From the
# names, X / Y look like model inputs / targets and itos an index-to-string
# lookup — verify against its definition.
X , Y, itos = make_dataset(names_file, context_window=context_window)

In [3]:
# Fixed seed so the train / validation / test partition is the same on every
# Restart & Run All; without it each run trains and evaluates on different rows.
split_seed = 42

# First cut: 80% train, 20% held out.
train_X, eval_X, train_Y, eval_Y = train_test_split(
    X, Y, test_size=0.2, random_state=split_seed
)
# Second cut: halve the held-out 20% into 10% validation and 10% test.
val_X, test_X, val_Y, test_Y = train_test_split(
    eval_X, eval_Y, test_size=0.5, random_state=split_seed
)

In [85]:
def cmp(s: str, dt, t: Tensor) -> None:
    """Print how a hand-derived gradient compares with the one autograd stored on ``t``.

    Args:
        s: Label printed at the start of the report line.
        dt: Manually computed gradient tensor for ``t``.
        t: Tensor whose ``.grad`` (populated by ``backward()``) is the reference.

    Raises:
        ValueError: If ``t.grad`` is ``None`` (e.g. ``retain_grad()`` was not
            called on a non-leaf tensor before ``backward()``).

    The report line contains three checks: bit-exact equality, ``torch.allclose``
    equality, and the largest absolute element-wise difference. If the shapes
    differ, the line reports the mismatch instead of comparing values.
    """
    grad = t.grad
    if grad is None:
        raise ValueError(
            f"{s}: tensor has no .grad - call retain_grad() on it before backward()"
        )

    # The element-wise comparisons below broadcast, so a gradient of the wrong
    # shape (e.g. (64,) vs (1, 64)) would otherwise be reported as an exact match.
    if dt.shape != grad.shape:
        print(f"{s:15s} | exact {str(False):5s} | approximate {str(False):5s} | shape mismatch {tuple(dt.shape)} vs {tuple(grad.shape)}")
        return

    exact_match = torch.all(dt == grad).item()
    approximate_match = torch.allclose(dt, grad)
    max_difference = (dt - grad).abs().max().item()
    print(f"{s:15s} | exact {str(exact_match):5s} | approximate {str(approximate_match):5s} | max_diff {str(max_difference)}")
    

In [86]:
# Hyperparameters of the hand-written MLP built in the following cells.
n_embedding = 10  # width of one character embedding (number of columns of C)
n_hidden = 64  # hidden-layer width (columns of W1, rows of W2)
vocab_size = 27  # rows of C and number of output logits; presumably 26 letters + 1 special token — confirm against make_dataset
block_size = context_window  # characters of context per example; tied to the value the dataset was built with

In [87]:
# Dedicated generator: every parameter below (and the minibatch drawn later)
# comes from this one seeded stream, so a fresh kernel reproduces the same values.
g = torch.Generator().manual_seed(2147483647)

# Embedding table: one n_embedding-wide row per vocabulary entry.
C = torch.randn((vocab_size, n_embedding), generator=g)

# Layer 1. Weights are scaled by (5/3)/sqrt(fan_in), i.e. the tanh gain over the
# square root of the number of inputs, to keep pre-activations in a sane range.
W1 = torch.randn((n_embedding*block_size, n_hidden), generator=g)*(5/3)/((n_embedding*block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g)*0.1

# Layer 2 (hidden -> logits). Small random values rather than zeros.
W2 = torch.randn((n_hidden, vocab_size), generator=g)*0.1
b2 = torch.randn(vocab_size, generator=g)*0.1

# Batch-norm scale and shift, jittered around 1 and 0 respectively.
# They draw from `g` like every other parameter; previously they used the
# unseeded global RNG, which made the whole notebook non-reproducible.
bngain = torch.randn((1,n_hidden), generator=g)*0.1 + 1.0
bnbias = torch.randn((1,n_hidden), generator=g)*0.1

parameters = [C,W1,b1,W2,b2,bngain,bnbias]
print("Total parameters",sum(p.nelement() for p in parameters))
# Turn on gradient tracking only after initialisation, so the scaling
# arithmetic above is not recorded and each parameter stays a leaf tensor.
for p in parameters:
    p.requires_grad = True

Total parameters 4137


In [88]:
# Sample a single random minibatch from the training split.
batch_size = 32
n = batch_size  # short alias used in the batch-norm and loss formulas

# Row indices drawn with replacement from the seeded generator `g`.
ix = torch.randint(low=0, high=train_X.shape[0], size=(batch_size,), generator=g)
Xb = train_X[ix]
Yb = train_Y[ix]

##### Forward pass

In [89]:
# Forward pass spelled out one primitive operation at a time, so that every
# intermediate tensor has a name and its gradient can be checked individually.

# Embed the characters into vectors: one row of C per index in Xb
emb = C[Xb]

# Concatenate the vectors: keep dim 0 (the batch), flatten everything else
embcat = emb.view(emb.shape[0],-1)

# Linear layer 1 (hidden pre-activations, before batch norm)
hprebn = embcat @ W1 + b1

# Batch norm 1 stage, broken into elementary steps
bnmean1 = (1/n)*hprebn.sum(0, keepdim=True)  # mean over dim 0 (the batch), one value per hidden unit
bndiff = hprebn - bnmean1
bndiff2 = bndiff**2
bnvar = (1/(n-1))*(bndiff2).sum(0, keepdim=True)  # note: divides by n-1 (Bessel's correction), not n
bnvar_inv = (bnvar+1e-5)**-0.5  # 1/sqrt(var + eps); the 1e-5 keeps this finite when the variance is 0
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non linearity activation
h = torch.tanh(hpreact)

# Linear layer 2 (hidden -> one logit per vocabulary entry)
logits = h @ W2 + b2

# Cross Entropy Loss, written as an explicit softmax followed by a log.
# The row-wise max is subtracted first: softmax is unchanged by the shift and
# exp() then never sees a positive argument.
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()

# Calculate loss: mean negative log-probability of the target index in each row
loss = -logprobs[range(n), Yb].mean()

# Clear gradients left over from any previous backward pass
for p in parameters:
    p.grad = None
    
# These are non-leaf tensors, so autograd only keeps their .grad when asked to;
# the gradient-comparison cell that follows reads .grad on each of them.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmean1, embcat, emb]:
    t.retain_grad()
loss.backward()
# Bare last expression so the notebook displays the loss value
loss

tensor(3.5357, grad_fn=<NegBackward0>)

In [187]:
# Manual backpropagation: walk the forward pass in reverse, derive each
# intermediate's gradient by hand, and compare it with autograd via cmp().
# Wrapped in no_grad() because these tensors ARE the gradients; autograd does
# not need to record the arithmetic that produces them.
with torch.no_grad():
    # loss = -mean(logprobs[i, Yb[i]]): only the selected entries receive -1/n.
    dlogprobs = torch.zeros_like(logprobs)
    dlogprobs[range(n), Yb] = -1/n
    cmp("dlogprobs",dlogprobs,logprobs)

    # logprobs = log(probs)  ->  d/dprobs = 1/probs
    dprobs = (1/probs)*dlogprobs
    cmp("dprobs",dprobs,probs)

    # probs = counts * counts_sum_inv; counts_sum_inv was broadcast across each
    # row, so its gradient is the row-wise sum.
    dcounts_sum_inv = (counts*dprobs).sum(1, keepdim=True)
    cmp("dcounts_sum_inv",dcounts_sum_inv,counts_sum_inv)

    # counts_sum_inv = counts_sum**-1  ->  derivative is -counts_sum**-2
    dcounts_sum = -(1/counts_sum**2) * dcounts_sum_inv
    cmp("dcounts_sum",dcounts_sum,counts_sum)

    # counts feeds two branches: directly into probs, and into counts_sum.
    dcounts = counts_sum_inv * dprobs
    dcounts += torch.ones_like(counts) * dcounts_sum
    cmp("dcounts",dcounts,counts)

    # counts = exp(norm_logits)  ->  derivative is exp(norm_logits) itself
    dnorms_logits = norm_logits.exp() * dcounts
    cmp("dnorms_logits",dnorms_logits,norm_logits)

    # norm_logits = logits - logit_maxes; logit_maxes was broadcast across each
    # row, so its gradient is minus the row-wise sum.
    dlogit_maxes = -dnorms_logits.sum(1, keepdim=True)
    cmp("dlogit_maxes",dlogit_maxes,logit_maxes)

    # logits also feeds two branches: the subtraction above, and the max().
    # The clone is required because dlogits is updated in place next.
    dlogits = dnorms_logits.clone()
    # max() passes gradient only to the arg-max column of each row; the one-hot
    # mask scatters dlogit_maxes into exactly that position.
    dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
    cmp("dlogits",dlogits,logits)

    # logits = h @ W2 + b2
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)  # b2 was broadcast over the batch dimension
    cmp("dh",dh,h)
    cmp("dW2",dW2,W2)
    cmp("db2",db2,b2)

    # h = tanh(hpreact)  ->  derivative is 1 - tanh**2 = 1 - h**2
    dhpreact = (1 - h**2) * dh
    cmp("dhpreact",dhpreact,hpreact)

    # hpreact = bngain * bnraw + bnbias; bngain is (1, n_hidden) and was
    # broadcast over the batch, so sum over dim 0 and restore the leading axis.
    dbngain = (bnraw * dhpreact).sum(0).unsqueeze(0)
    cmp("dbngain",dbngain,bngain)

dlogprobs       | exact True  | approximate True  | max_diff 0.0
dprobs          | exact True  | approximate True  | max_diff 0.0
dcounts_sum_inv | exact True  | approximate True  | max_diff 0.0
dcounts_sum     | exact True  | approximate True  | max_diff 0.0
dcounts         | exact True  | approximate True  | max_diff 0.0
dnorms_logits   | exact True  | approximate True  | max_diff 0.0
dlogit_maxes    | exact True  | approximate True  | max_diff 0.0
dlogits         | exact True  | approximate True  | max_diff 0.0
dh              | exact True  | approximate True  | max_diff 0.0
dW2             | exact True  | approximate True  | max_diff 0.0
db2             | exact True  | approximate True  | max_diff 0.0
dhpreact        | exact True  | approximate True  | max_diff 0.0
dbngain         | exact True  | approximate True  | max_diff 0.0
