In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

In [2]:
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
chars = sorted(list(set(''.join(words))))
stoi = {s: i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
block_size = 3


def build_dataset(words):
    X, Y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y


random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte,  Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
def cmp(s: str, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
n_embd = 10
n_hidden = 64
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)

# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * \
    (5/3)/((n_embd * block_size)**0.5)
# using b1 just for fun, it's useless because of BN
b1 = torch.randn(n_hidden, generator=g) * 0.1

# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# BatchNorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

4137


In [7]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [8]:
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
# Linear layer 1
hprebn = embcat @ W1 + b1
# BatchNorm layer
bnmeani = 1/n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
# note: Bessel's correction (dividing by n-1, not n)
bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2
# Cross entropy loss
logit_maxes = logits.max(1, keepdim=True).values    # shape: [32, 1]
norm_logits = logits - logit_maxes  # subtract max for numerical stability
counts = norm_logits.exp()  # shape: [32, 27]
counts_sum = counts.sum(1, keepdims=True)   # [32, 1]
counts_sum_inv = counts_sum**-1  # [32, 1]
probs = counts * counts_sum_inv  # [32, 27]
logprobs = probs.log()  # [32, 27]
loss = -logprobs[range(n), Yb].mean()
# logprobs[range(n), Yb].shape is [32] and loss is a scalar
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3401, grad_fn=<NegBackward0>)

In [9]:
# Exercise 1

In [10]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = - 1.0 / n
cmp('dlogprobs', dlogprobs, logprobs)
dprobs = dlogprobs * (1.0 / probs)
cmp('dprobs', dprobs, probs)

dlogprobs       | exact: True  | approximate: True  | maxdiff: 0.0
dprobs          | exact: True  | approximate: True  | maxdiff: 0.0


In [11]:
# probs.shape [32, 27]
# counts.shape [32, 27]
# counts_sum_inv.shape [32, 1]
# probs = counts * counts_sum_inv
# a: 3x3, b: 3x1 -> c: 3x3
# a * b
# a00*b0 a01*b0 a02*b0
# a10*b1 a11*b1 a12*b1
# a20*b2 a21*b2 a22*b2

In [12]:
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)

dcounts_sum_inv | exact: True  | approximate: True  | maxdiff: 0.0


In [13]:
dcounts_sum = dcounts_sum_inv * (-1 * counts_sum**-2)
cmp('dcounts_sum', dcounts_sum, counts_sum)

dcounts_sum     | exact: True  | approximate: True  | maxdiff: 0.0


In [14]:
dcounts = dcounts_sum * torch.ones((1, 27)) + dprobs * counts_sum_inv
cmp('dcounts', dcounts, counts)

dcounts         | exact: True  | approximate: True  | maxdiff: 0.0


In [15]:
dnorm_logits = dcounts * norm_logits.exp()
cmp('dnorm_logits', dnorm_logits, norm_logits)

dnorm_logits    | exact: True  | approximate: True  | maxdiff: 0.0


In [16]:
dlogit_maxes = (dnorm_logits * -1.0).sum(1, keepdim=True)
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)

dlogit_maxes    | exact: True  | approximate: True  | maxdiff: 0.0


In [17]:
logits_derivative = torch.zeros_like(logits)
logits_derivative[torch.arange(32), logits.argmax(1)] = 1.0

dlogits = dnorm_logits * 1.0 + dlogit_maxes * logits_derivative
cmp('dlogits', dlogits, logits)

dlogits         | exact: True  | approximate: True  | maxdiff: 0.0


In [18]:
dh = dlogits @ W2.T
cmp('dh', dh, h)
dW2 = h.T @ dlogits
cmp('dW2', dW2, W2)
db2 = dlogits.sum(0)
cmp('db2', db2, b2)

dh              | exact: True  | approximate: True  | maxdiff: 0.0
dW2             | exact: True  | approximate: True  | maxdiff: 0.0
db2             | exact: True  | approximate: True  | maxdiff: 0.0


In [19]:
dhpreact = dh * (1 - h**2)
cmp('dhpreact', dhpreact, hpreact)

dhpreact        | exact: True  | approximate: True  | maxdiff: 0.0


In [20]:
dbnraw = dhpreact * bngain
cmp('dbnraw', dbnraw, bnraw)
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
cmp('dbngain', dbngain, bngain)
dbnbias = dhpreact.sum(0, keepdim=True)
cmp('dbnbias', dbnbias, bnbias)

dbnraw          | exact: True  | approximate: True  | maxdiff: 0.0
dbngain         | exact: True  | approximate: True  | maxdiff: 0.0
dbnbias         | exact: True  | approximate: True  | maxdiff: 0.0


In [21]:
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbnvar = dbnvar_inv * (-0.5 * (bnvar + 1e-5)**-1.5)
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)
cmp('dbnvar', dbnvar, bnvar)

dbnvar_inv      | exact: True  | approximate: True  | maxdiff: 0.0
dbnvar          | exact: True  | approximate: True  | maxdiff: 0.0


In [22]:
dbndiff = dbnraw * bnvar_inv
dbndiff2 = dbnvar * (1.0/(n-1)) * torch.ones_like(bndiff2)
dbndiff += dbndiff2 * 2 * bndiff
cmp('dbndiff', dbndiff, bndiff)
cmp('dbndiff2', dbndiff2, bndiff2)

dbndiff         | exact: True  | approximate: True  | maxdiff: 0.0
dbndiff2        | exact: True  | approximate: True  | maxdiff: 0.0


In [23]:
dbnmeani = (dbndiff * -1.0).sum(0, keepdim=True)
dhprebn = dbndiff * 1.0
dhprebn += dbnmeani * (1.0/n) * torch.ones_like(hprebn)
cmp('dbnmeani', dbnmeani, bnmeani)
cmp('dhprebn', dhprebn, hprebn)

dbnmeani        | exact: True  | approximate: True  | maxdiff: 0.0
dhprebn         | exact: True  | approximate: True  | maxdiff: 0.0


In [24]:
dembcat = dhprebn @ W1.t()
cmp('dembcat', dembcat, embcat)
dW1 = embcat.t() @ dhprebn
cmp('dW1', dW1, W1)
db1 = dhprebn.sum(0)
cmp('db1', db1, b1)

dembcat         | exact: True  | approximate: True  | maxdiff: 0.0
dW1             | exact: True  | approximate: True  | maxdiff: 0.0
db1             | exact: True  | approximate: True  | maxdiff: 0.0


In [25]:
demb = dembcat.reshape([batch_size, block_size, n_embd])
cmp('demb', demb, emb)

demb            | exact: True  | approximate: True  | maxdiff: 0.0


In [26]:
#  Xb: 32 * 3
#   C: 27 * 10
# emb: 32 * 3 * 10
# emb is composed of < n0 * C[0] + n1 * C[1] + ... + n26 * C[26] >
# where n0 is the sum of '0' in Xb ...
# And, n0 + n1 + ... + n26 = 32 * 3 = 96
# And, the coefficient is 1.0

In [27]:
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[i, j]
        dC[ix] += 1.0 * demb[i][j]
cmp('dC', dC, C)

dC              | exact: True  | approximate: True  | maxdiff: 0.0


In [28]:
# Exercise 2

In [29]:
# Cross entropy loss
# logit_maxes = logits.max(1, keepdim=True).values    # shape: [32, 1]
# norm_logits = logits - logit_maxes  # subtract max for numerical stability
# counts = norm_logits.exp()  # shape: [32, 27]
# counts_sum = counts.sum(1, keepdims=True)   # [32, 1]
# counts_sum_inv = counts_sum**-1  # [32, 1]
# probs = counts * counts_sum_inv  # [32, 27]
# logprobs = probs.log()  # [32, 27]
# loss = -logprobs[range(n), Yb].mean()
# logprobs[range(n), Yb].shape is [32] and loss is a scalar

In [30]:
dlogprobs_ = -F.softmax(logits, 1)
dlogprobs_[range(n), Yb] += 1
dlogits_fast = (-1.0/n) * dlogprobs_
cmp('dlogits_fast', dlogits_fast, logits)   # < 6e-9

dlogits_fast    | exact: False | approximate: True  | maxdiff: 7.2177499532699585e-09


In [31]:
# Exercise 3

In [32]:
# BatchNorm layer
# bnmeani = 1/n * hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# note: Bessel's correction (dividing by n-1, not n)
# bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias

In [33]:
coefficient = bngain / n * bnvar_inv
first = -dhpreact.sum(0, keepdim=True)
second = -n/(n-1) * bnraw * (bnraw * dhpreact).sum(0, keepdim=True)
third = n * dhpreact
dhprebn_fast = coefficient * (first + second + third)
cmp('dhprebn_fast', dhprebn_fast, hprebn)  # < 9e-10

dhprebn_fast    | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10


In [34]:
# Exercise 4: putting them all together!

In [35]:
n_embd = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * \
    (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

12297


In [36]:
max_steps = 200000
batch_size = 32
n = batch_size
lossi = []


for k in range(max_steps):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    # Linear layer
    hprebn = embcat @ W1 + b1
    # BatchNorm layer
    # -------------------------------------------------------------
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    # -------------------------------------------------------------
    # Non-linearity
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # Manual backprop!
    # -----------------
    dlogprobs_ = -F.softmax(logits, 1)
    dlogprobs_[range(n), Yb] += 1
    dlogits = (-1.0/n) * dlogprobs_
    # 2nd layer backprop
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    # tanh
    dhpreact = dh * (1 - h**2)
    # batchnorm backprop
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)

    coefficient = bngain / n * bnvar_inv
    first = -dhpreact.sum(0, keepdim=True)
    second = -n/(n-1) * bnraw * (bnraw * dhpreact).sum(0, keepdim=True)
    third = n * dhpreact
    dhprebn = coefficient * (first + second + third)
    # 1st layer
    dembcat = dhprebn @ W1.t()
    dW1 = embcat.t() @ dhprebn
    db1 = dhprebn.sum(0)
    # embedding
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for i in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[i, j]
            dC[ix] += 1.0 * demb[i][j]
    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
    # -----------------

    # update
    lr = 0.1 if i < 100000 else 0.01
    for p, grad in zip(parameters, grads):
        p.data += -lr * p.grad

    # track stats
    if (k == 0) or (k == max_steps - 1) or ((k+1) % 10000 == 0):
        print(f'{k:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

    if k >= 99:
        break

      0/ 200000: 3.8221


In [37]:
for p, g in zip(parameters, grads):
    cmp(str(tuple(p.shape)), g, p)

(27, 10)        | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
(30, 200)       | exact: False | approximate: True  | maxdiff: 1.1175870895385742e-08
(200,)          | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
(200, 27)       | exact: False | approximate: True  | maxdiff: 1.1175870895385742e-08
(27,)           | exact: False | approximate: True  | maxdiff: 4.190951585769653e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 2.7939677238464355e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09


In [38]:
n_embd = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * \
    (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

12297


In [39]:
max_steps = 200000
batch_size = 32
n = batch_size
lossi = []

with torch.no_grad():
    for k in range(max_steps):
        # minibatch construct
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
        Xb, Yb = Xtr[ix], Ytr[ix]

        # forward pass
        emb = C[Xb]
        embcat = emb.view(emb.shape[0], -1)
        # Linear layer
        hprebn = embcat @ W1 + b1
        # BatchNorm layer
        # -------------------------------------------------------------
        bnmean = hprebn.mean(0, keepdim=True)
        bnvar = hprebn.var(0, keepdim=True, unbiased=True)
        bnvar_inv = (bnvar + 1e-5)**-0.5
        bnraw = (hprebn - bnmean) * bnvar_inv
        hpreact = bngain * bnraw + bnbias
        # -------------------------------------------------------------
        # Non-linearity
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        loss = F.cross_entropy(logits, Yb)

        # Manual backprop!
        # -----------------
        dlogprobs_ = -F.softmax(logits, 1)
        dlogprobs_[range(n), Yb] += 1
        dlogits = (-1.0/n) * dlogprobs_
        # 2nd layer backprop
        dh = dlogits @ W2.T
        dW2 = h.T @ dlogits
        db2 = dlogits.sum(0)
        # tanh
        dhpreact = dh * (1 - h**2)
        # batchnorm backprop
        dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
        dbnbias = dhpreact.sum(0, keepdim=True)

        coefficient = bngain / n * bnvar_inv
        first = -dhpreact.sum(0, keepdim=True)
        second = -n/(n-1) * bnraw * (bnraw * dhpreact).sum(0, keepdim=True)
        third = n * dhpreact
        dhprebn = coefficient * (first + second + third)
        # 1st layer
        dembcat = dhprebn @ W1.t()
        dW1 = embcat.t() @ dhprebn
        db1 = dhprebn.sum(0)
        # embedding
        demb = dembcat.view(emb.shape)
        dC = torch.zeros_like(C)
        for i in range(Xb.shape[0]):
            for j in range(Xb.shape[1]):
                ix = Xb[i, j]
                dC[ix] += 1.0 * demb[i][j]
        grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
        # -----------------

        # update
        lr = 0.1 if i < 80000 else 0.01
        for p, grad in zip(parameters, grads):
            p.data += -lr * grad

        # track stats
        if (k == 0) or (k == max_steps - 1) or ((k+1) % 10000 == 0):
            print(f'{k:7d}/{max_steps:7d}: {loss.item():.4f}')
        lossi.append(loss.log10().item())

      0/ 200000: 3.7930
   9999/ 200000: 2.2924
  19999/ 200000: 1.6431
  29999/ 200000: 2.2065
  39999/ 200000: 2.2734
  49999/ 200000: 2.4392
  59999/ 200000: 1.7371
  69999/ 200000: 1.9692
  79999/ 200000: 2.2048
  89999/ 200000: 2.2183
  99999/ 200000: 2.0322
 109999/ 200000: 2.1701
 119999/ 200000: 2.1051
 129999/ 200000: 2.2872
 139999/ 200000: 1.9275
 149999/ 200000: 2.5890
 159999/ 200000: 2.1844
 169999/ 200000: 1.9185
 179999/ 200000: 1.7665
 189999/ 200000: 2.0832
 199999/ 200000: 2.3343


In [40]:
# calibrate the batch norm at the end of training
with torch.no_grad():
    emb = C[Xtr]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    # measure the mean/std over the entire training set
    bnmean = hpreact.mean(0, keepdim=True)
    bnvar = hpreact.var(0, keepdim=True, unbiased=True)

In [41]:
# evaluate train and val loss

@torch.no_grad()
def split_loss(split):
    x, y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    emb = C[x]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())


split_loss('train')
split_loss('val')

train 2.119234323501587
val 2.164980888366699


In [42]:
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(20):

    out = []
    context = [0] * block_size
    while True:
        # forward pass:
        # ------------
        emb = C[torch.tensor([context])]
        embcat = emb.view(emb.shape[0], -1)
        hpreact = embcat @ W1 + b1
        hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        # ------------
        # Sample
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break

    print(''.join(itos[i] for i in out))

mon.
ammyazleee.
mad.
ryla.
renyr.
jendraegradee.
daeli.
jemi.
jen.
eden.
esmanarielle.
malara.
noshubvigahiriel.
kendrethleen.
teron.
ubreyce.
ryyah.
faeh.
yuve.
mys.
