In [9]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Load the names dataset and build the character vocabulary.
# "." (index 0) serves both as the end-of-name marker and as the padding
# token for the initial context window; real characters get indices 1..len(chars).
with open("data/names.txt", "r", encoding="utf-8") as names_file:  # closes the file handle
    words = names_file.read().splitlines()
chars = sorted(list(set("".join(words))))

stoi = {char: i+1 for i, char in enumerate(chars)}
stoi["."] = 0

itos = {i: char for char, i in stoi.items()}
vocab_size = len(itos)

block_size = 3  # context length: number of previous characters used to predict the next one
import random
def build_dataset(words):
    """Turn a list of words into (context, next-character) training pairs.

    Each word is terminated with "." and scanned left to right. For every
    character, the indices of the preceding ``block_size`` characters form
    one row of X. The window is left-padded with 0, the index of ".". The
    character's own index is the matching entry of Y.

    Uses the module-level ``stoi`` and ``block_size``.

    Returns:
        X: integer tensor; for non-empty input its shape is (num_examples, block_size).
        Y: integer tensor of shape (num_examples,).
    """
    contexts, targets = [], []
    for w in words:
        window = [0] * block_size
        for ch in w + ".":
            target = stoi[ch]
            contexts.append(window)
            targets.append(target)
            # slide the window: drop the oldest index, append the newest
            window = window[1:] + [target]
    return torch.tensor(contexts), torch.tensor(targets)


# Shuffle with a fixed seed so the 80/10/10 train/val/test split is reproducible.
# NOTE: this shuffles `words` in place; to reproduce the same split, re-run from
# the data-loading cell rather than re-running only this cell.
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

X_train, Y_train = build_dataset(words[:n1])  # 80%
X_val, Y_val = build_dataset(words[n1:n2])    # 10%
X_test, Y_test = build_dataset(words[n2:])    # 10%


In [25]:
def cmp(s, dt, t):
    """Compare a manually computed gradient against PyTorch's autograd gradient.

    Prints whether ``dt`` matches ``t.grad`` exactly, approximately
    (``torch.allclose``), and the largest absolute elementwise difference.

    Args:
        s: Label printed in the first column.
        dt: Manually derived gradient tensor.
        t: Tensor whose ``.grad`` was populated by ``loss.backward()``.

    Raises:
        ValueError: if ``t.grad`` is None (e.g. ``retain_grad()`` was not
            called on a non-leaf tensor before ``backward()``).
    """
    if t.grad is None:
        raise ValueError(f"{s}: t.grad is None - call retain_grad() before loss.backward()")
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    msg = f"{s:15s} / exact: {str(ex):5s} / approximate: {str(app):5s} / maxdiff: {maxdiff}"
    # Broadcasting lets e.g. a (200,) gradient compare "equal" to a (1, 200) one,
    # so surface shape mismatches explicitly instead of hiding them.
    if dt.shape != t.grad.shape:
        msg += f" / shape mismatch: {tuple(dt.shape)} vs {tuple(t.grad.shape)}"
    print(msg)

In [11]:
# MLP parameters
n_embd = 10     # dimensionality of each character embedding
n_hidden = 200  # number of neurons in the hidden layer

torch.manual_seed(2147483647)  # reproducible parameter init (and later minibatch sampling)
C = torch.randn((vocab_size, n_embd))
W1 = torch.randn((n_embd * block_size, n_hidden)) * ((5/3) / ((n_embd * block_size) ** 0.5))  # Kaiming init for tanh
b1 = torch.randn(n_hidden) * 0.1  # redundant before batchnorm; kept only to exercise its gradient
W2 = torch.randn((n_hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1

# Parameters are deliberately initialized in a non-standard (random, non-zero) way.
# With gain=1 and bias=0, mistakes in the manual backward pass (e.g. dropping
# a bngain factor) would be masked.
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True
sum(p.nelement() for p in parameters)  # total number of parameters

12297


In [12]:
# Draw one random minibatch from the training set.
# `n` is a short alias for the batch size, used throughout the backprop formulas.
n = batch_size = 32
ix = torch.randint(low=0, high=X_train.shape[0], size=(n,))
Xb = X_train[ix]
Yb = Y_train[ix]

In [None]:
# Forward pass, broken into small steps so that every intermediate can be
# backpropagated through manually in the next cell.
emb = C[Xb] # 32x3x10
embcat = emb.view(emb.shape[0], -1) #32x30
hprebn = embcat @ W1 + b1 #32x200
#batchnorm
bnmeani = 1/n * hprebn.sum(0, keepdim=True) # per-neuron mean over the batch
bndiff  = hprebn - bnmeani # broadcast bnmeani and subtract it from hprebn -> centered values
bndiff2 = bndiff**2
bnvar   = 1/(n-1) * (bndiff2).sum(0, keepdim=True) # per-neuron variance with Bessel's correction (n-1)
bnvar_inv = (bnvar + 1e-5)**-0.5 # 1/sqrt(var + eps); eps avoids division by zero
bnraw   = bndiff * bnvar_inv # plain normalization over the batch
hpreact = bngain * bnraw + bnbias # learnable freedom: rescale (gain) and shift (bias)
# Non-linearity
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2 # 32 x 27
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract each row's max so exp() cannot overflow
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv # row-wise softmax probabilities
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # mean negative log-probability of the target characters

#pytorch backward pass
for p in parameters:
    p.grad = None
# retain_grad() makes autograd keep .grad on these non-leaf intermediates,
# so they can be compared with the manually derived gradients.
for i in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits,h,hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    i.retain_grad()
loss.backward()
loss

tensor(3.5762, grad_fn=<NegBackward0>)

### manual backprop

In [39]:
# Manual backward pass, walking the forward pass in reverse.
# Each d<name> is d(loss)/d(<name>) and has the same shape as <name>.

# loss = -mean(logprobs[range(n), Yb]): only the picked entries receive gradient
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
# logprobs = log(probs)
dprobs = (1.0 / probs) * dlogprobs
# probs = counts * counts_sum_inv (counts_sum_inv is broadcast over columns -> sum its gradient)
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
# counts_sum_inv = counts_sum**-1  ->  d/dx x^-1 = -x^-2
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
# counts_sum = counts.sum(1): gradient is copied to every column (second branch into counts)
dcounts += torch.ones_like(counts) * dcounts_sum
# counts = exp(norm_logits)
dnorm_logits = norm_logits.exp() * dcounts
# norm_logits = logits - logit_maxes
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
# logit_maxes = logits.max(1): gradient flows only to the argmax position of each row
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
# logits = h @ W2 + b2
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# h = tanh(hpreact)  ->  tanh' = 1 - tanh^2
dhpreact = (1 - h**2) * dh
# hpreact = bngain * bnraw + bnbias
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)
# bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5)**-0.5
dbnvar = (-0.5 * (bnvar + 1e-5)**-1.5) * dbnvar_inv
# bnvar = 1/(n-1) * bndiff2.sum(0): chain through the upstream gradient dbnvar
dbndiff2 = (1.0 / (n - 1)) * torch.ones_like(bndiff2) * dbnvar
# bndiff2 = bndiff**2 (second branch into bndiff)
dbndiff += (2 * bndiff) * dbndiff2
# bndiff = hprebn - bnmeani
dhprebn = dbndiff.clone()
dbnmeani = (-dbndiff).sum(0, keepdim=True)  # keepdim so the shape matches bnmeani
# bnmeani = 1/n * hprebn.sum(0) (second branch into hprebn)
dhprebn += 1.0 / n * (torch.ones_like(hprebn) * dbnmeani)
# hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
# embcat = emb.view(...): undo the reshape
demb = dembcat.view(emb.shape)
# emb = C[Xb]: accumulate each position's gradient back into its embedding row
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        # dedicated name so the minibatch indices `ix` from the sampling cell are not clobbered
        char_ix = Xb[k, j]
        dC[char_ix] += demb[k, j]

# Check every manual gradient against autograd (requires the forward/backward cell to have run).
for grad_name, manual_grad, tensor in [
    ("logprobs", dlogprobs, logprobs),
    ("probs", dprobs, probs),
    ("counts_sum_inv", dcounts_sum_inv, counts_sum_inv),
    ("counts_sum", dcounts_sum, counts_sum),
    ("counts", dcounts, counts),
    ("norm_logits", dnorm_logits, norm_logits),
    ("logit_maxes", dlogit_maxes, logit_maxes),
    ("logits", dlogits, logits),
    ("h", dh, h),
    ("W2", dW2, W2),
    ("b2", db2, b2),
    ("hpreact", dhpreact, hpreact),
    ("bngain", dbngain, bngain),
    ("bnbias", dbnbias, bnbias),
    ("bnraw", dbnraw, bnraw),
    ("bnvar_inv", dbnvar_inv, bnvar_inv),
    ("bnvar", dbnvar, bnvar),
    ("bndiff2", dbndiff2, bndiff2),
    ("bndiff", dbndiff, bndiff),
    ("bnmeani", dbnmeani, bnmeani),
    ("hprebn", dhprebn, hprebn),
    ("embcat", dembcat, embcat),
    ("W1", dW1, W1),
    ("b1", db1, b1),
    ("emb", demb, emb),
    ("C", dC, C),
]:
    cmp(grad_name, manual_grad, tensor)

### Better option: derive the gradient with math and simplify it

In [None]:
# Fused backward of the cross-entropy loss (log-softmax + mean NLL) w.r.t. logits,
# derived analytically: dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n
# detach(): this is a plain numeric computation, so no autograd graph is needed.
dlogits = F.softmax(logits.detach(), 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
cmp("logits", dlogits, logits)  # should match autograd up to float rounding

#for cross entropy
