In [100]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

### Read in all the words and build vocabulary

In [101]:
words = open("names.txt", "r").read().splitlines()
chars = sorted(list(set("".join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi["."] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(stoi)

### Build the dataset

`X` are the sequences of the characters\
`Y` are the next characters after the sequence

`tr` training set\
`dev` validation set\
`te` test set\

In [None]:
block_size=3

def build_dataset(words):
    block_size = 3 # context length: how many characters do we take to predict the next one? 
    X, Y = [], [] # input, labels
    for w in words:
        context = [0] * block_size
        for ch in w + ".":
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] #crop and append new character

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

In [103]:
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f"{s:15s} | exact: {str(ex):5s} | approx: {str(app):5s} | maxdiff: {maxdiff}")

### Manual Backpropagation

In [105]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [106]:
# forward pass, "chunkated" into smaller steps that are possible to backward on at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # cancatenate the vectors
# Linear Layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre activation
# BatchNorm Layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2.sum(0, keepdim=True)) # Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear Layer 2
logits = h @ W2 + b2 # output layer
# Cross Entropy loss (same as F.corss_entropy(logits, Yb)) 
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1 # if I use (1.0/counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# Pytorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
print("Current loss: " + str(loss.item()))

Current loss: 3.348198175430298


In [107]:
# Backprop through the whole thing one by one

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
cmp("logprobs", dlogprobs, logprobs)

dprobs = (1.0/probs) * dlogprobs
cmp("probs", dprobs, probs)

dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
cmp("counts_sum_inv", dcounts_sum_inv, counts_sum_inv)

dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
cmp("counts_sum", dcounts_sum, counts_sum)

dcounts = counts_sum_inv * dprobs
dcounts += torch.ones_like(counts) * dcounts_sum
cmp("counts", dcounts, counts)

dnorm_logits = counts * dcounts
cmp("norm_logits", dnorm_logits, norm_logits)

dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp("logit_maxes", dlogit_maxes, logit_maxes)

dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
cmp("logits", dlogits, logits)

dh = dlogits @ W2.T
cmp("h", dh, h)

dW2 = h.T @ dlogits
cmp("W2", dW2, W2)

db2 = dlogits.sum(0)
cmp("b2", db2, b2)

dhpreact = (1 - h**2) * dh
cmp("hpreact", dhpreact, hpreact)

dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
cmp("bngain", dbngain, bngain)

dbnraw = bngain * dhpreact
cmp("bnraw", dbnraw, bnraw)

dbnbias = dhpreact.sum(0, keepdim=True)
cmp("bnbias", dbnbias, bnbias)

dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
cmp("bnvar_inv", dbnvar_inv, bnvar_inv)

dbnvar = -0.5*(bnvar + 1e-5)**-1.5 * dbnvar_inv
cmp("bnvar", dbnvar, bnvar)

dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
cmp("bndiff2", dbndiff2, bndiff2)

dbndiff = bnvar_inv * dbnraw
dbndiff += 2*bndiff * dbndiff2
cmp("bndiff", dbndiff, bndiff)

dbnmeani = (-dbndiff).sum(0, keepdim=True)
cmp("bnmeani", dbnmeani, bnmeani)

dhprebn = dbndiff.clone()
dhprebn += (1.0/n) * torch.ones_like(hprebn) * dbnmeani
cmp("hprebn", dhprebn, hprebn)

dembcat = dhprebn @ W1.T
cmp("embcat", dembcat, embcat)

dW1 = embcat.T @ dhprebn
cmp("W1", dW1, W1)

db1 = dhprebn.sum(0) 
cmp("b1", db1, b1)

demb = dembcat.view(emb.shape)
cmp("demb", demb, emb)

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]
cmp("C", dC, C)


logprobs        | exact: True  | approx: True  | maxdiff: 0.0
probs           | exact: True  | approx: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approx: True  | maxdiff: 0.0
counts_sum      | exact: True  | approx: True  | maxdiff: 0.0
counts          | exact: True  | approx: True  | maxdiff: 0.0
norm_logits     | exact: True  | approx: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approx: True  | maxdiff: 0.0
logits          | exact: True  | approx: True  | maxdiff: 0.0
h               | exact: True  | approx: True  | maxdiff: 0.0
W2              | exact: True  | approx: True  | maxdiff: 0.0
b2              | exact: True  | approx: True  | maxdiff: 0.0
hpreact         | exact: True  | approx: True  | maxdiff: 0.0
bngain          | exact: True  | approx: True  | maxdiff: 0.0
bnraw           | exact: True  | approx: True  | maxdiff: 0.0
bnbias          | exact: True  | approx: True  | maxdiff: 0.0
bnvar_inv       | exact: True  | approx: True  | maxdiff: 0.0
bnvar   

In [None]:
#forward pass

# before:

# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdim=True)
# counts_sum_inv = counts_sum ** -1 # if I use (1.0/counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

#now:

loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), "diff:", (loss_fast - loss).item())

In [None]:
# backward pass

# before:

# dlogprobs = torch.zeros_like(logprobs)
# dlogprobs[range(n), Yb] = -1.0 / n
# dprobs = (1.0/probs) * dlogprobs
# dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
# dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
# dcounts = counts_sum_inv * dprobs
# dcounts += torch.ones_like(counts) * dcounts_sum
# dnorm_logits = counts * dcounts
# dlogits = dnorm_logits.clone()
# dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
# dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
cmp("logits", dlogits, logits)


In [None]:
# forward pass 

#before:

# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2.sum(0, keepdim=True)) # Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5) ** -0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias

#now:

hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print("max diff:", (hpreact_fast - hpreact).abs().max().item())

In [None]:
# backward pass

#before
# dbnraw = bngain * dhpreact
# dbnbias = dhpreact.sum(0, keepdim=True)
# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# dbnvar = -0.5*(bnvar + 1e-5)**-1.5 * dbnvar_inv
# dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
# dbndiff = bnvar_inv * dbnraw
# dbndiff += 2*bndiff * dbndiff2
# dbnmeani = (-dbndiff).sum(0, keepdim=True)
# dhprebn = dbndiff.clone()
# dhprebn += (1.0/n) * torch.ones_like(hprebn) * dbnmeani

#now:

dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
cmp("hprebn", dhprebn, hprebn)

# Summary

In [120]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [154]:
#init
n_emd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_emd),                generator=g)
#Layer 1
W1 = torch.randn((n_emd*block_size,n_hidden),       generator=g) * (5/3)/((n_emd*block_size)**0.5)
b1 = torch.randn(n_hidden,                          generator=g) * 0.01
#Layer 2
W2 = torch.randn((n_hidden, vocab_size),            generator=g) * 0.1
b2 = torch.randn(vocab_size,                        generator=g) * 0.1

# Batch Norm Parameters
bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.zeros((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True

#optimalization
max_steps = 200000
batch_size=32
n = batch_size
lossi=[]


with torch.no_grad(): # I am not gonna call backwar => we are doing it manually
    for i in range(max_steps):
        #minibatch construct
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
        Xb, Yb = Xtr[ix], Ytr[ix]

        #forward pass
        emb = C[Xb] # embedd characters into vectors
        embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
        #Linear Layer
        hprebn = embcat @ W1 + b1 # hidden layer pre activation

        #Batchnorm Layer
        bnmean = hprebn.mean(0, keepdim=True)
        bnvar = hprebn.var(0, keepdim=True, unbiased=True)
        bnvar_inv = (bnvar + 1e-5) ** -0.5
        bnraw = (hprebn - bnmean) * bnvar_inv
        hpreact = bngain * bnraw + bnbias

        #Non-linearity
        h = torch.tanh(hpreact) # hidden layer
        logits = h @ W2 + b2 # output layer
        loss = F.cross_entropy(logits, Yb) # loss function

        # backward pass
        for p in parameters:
            p.grad = None
        # loss.backward() => not needed anymore because of manually below

        #manual backprop
        dlogits = F.softmax(logits, 1)
        dlogits[range(n), Yb] -= 1
        dlogits /= n
        # 2nd layer backprop
        dh = dlogits @ W2.T
        dW2 = h.T @ dlogits
        db2 = dlogits.sum(0)
        # tahn
        dhpreact = (1 - h**2) * dh
        # batchnorm backprop
        dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
        dbnbias = dhpreact.sum(0, keepdim=True)
        dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
        # 1st layer
        dembcat = dhprebn @ W1.T
        dW1 = embcat.T @ dhprebn
        db1 = dhprebn.sum(0)
        #embedding
        demb = dembcat.view(emb.shape)
        dC = torch.zeros_like(C)
        for k in range(Xb.shape[0]):
            for j in range(Xb.shape[1]):
                ix = Xb[k, j]
                dC[ix] += demb[k, j]
        grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]

        #update
        lr = 0.1 if i < 100000 else 0.01 # learning rate decay
        for p, grad in zip(parameters, grads):
            # p.data += -lr* p.grad
            p.data += -lr * grad

        # track stats
        if i % 10000 == 0:
            print(f"{i:7d}/{max_steps:7d}: {loss.item():.4f}")
        lossi.append(loss.log10().item())

        # if i <= 100: # using this to check our gradients
        #     break


12297
      0/ 200000: 3.6560
  10000/ 200000: 2.4832
  20000/ 200000: 2.3983
  30000/ 200000: 2.1027
  40000/ 200000: 1.9862
  50000/ 200000: 2.4499
  60000/ 200000: 2.3097
  70000/ 200000: 2.0755
  80000/ 200000: 1.9685
  90000/ 200000: 1.9662
 100000/ 200000: 2.4539
 110000/ 200000: 2.1895
 120000/ 200000: 2.1088
 130000/ 200000: 2.4104
 140000/ 200000: 2.2916
 150000/ 200000: 2.3604
 160000/ 200000: 2.1177
 170000/ 200000: 1.9711
 180000/ 200000: 2.3095
 190000/ 200000: 1.8517


In [None]:
# for p, g in zip(parameters, grads):
#     cmp(str(tuple(p.shape)), g,p)

In [157]:
# calibrate the batch norm at the end

with torch.no_grad():
    emb = C[Xtr] # embedd characters into vectors
    embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
    hpreact = embcat @ W1 + b1 # hidden layer pre activation
    bnmean = hpreact.mean(0, keepdim=True)
    bnvar = hpreact.var(0, keepdim=True, unbiased=True)

In [159]:
@torch.no_grad() # disabled gradient tracking
def split_loss(split):
    x, y = {
        "train": (Xtr, Ytr),
        "val": (Xdev, Ydev),
        "test": (Xte, Yte),
    }[split]
    emb = C[x] # (N, block_size, n_emb)
    embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_emb)
    hpreact = embcat @ W1 + b1 # hidden layer pre-activation
    hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5) ** -0.5 + bnbias
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss("train")
split_loss("val")

train 2.0673887729644775
val 2.1106348037719727


In [160]:
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(20):
    out = []
    context = [0] * block_size
    while True:
        emb = C[torch.tensor([context])]
        embcat = emb.view(emb.shape[0], -1)
        hpreact = embcat @ W1 + b1
        hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5) ** -0.5 + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2

        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break
    print("".join(itos[i] for i in out))


carman.
ambrilli.
kimri.
reet.
khalaysie.
mahnee.
den.
rha.
kaeli.
ner.
kiah.
maiivon.
leigh.
ham.
join.
quint.
sulie.
alianni.
watell.
dearisia.
