In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# read one name per line; the context manager closes the file handle after reading
with open("names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
words[:8]

In [None]:
# build the character vocabulary and the char <-> int mappings;
# letters are numbered from 1 and index 0 is reserved for the '.' boundary token
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
print(itos, len(chars))

In [None]:
block_size = 3 # context length: how many previous characters are used to predict the next one

def build_dataset(words):
    """Turn a list of words into (context, next-character) training pairs.

    Each word starts from a context of block_size '.' tokens (index 0) and is
    terminated by '.', so the end marker is also a prediction target.
    Prints the tensor shapes and returns (X, Y) as int64 tensors.
    """
    contexts, targets = [], []
    for w in words:
        window = [0] * block_size
        for token in map(stoi.__getitem__, w + '.'):
            contexts.append(window)
            targets.append(token)
            window = window[1:] + [token] # slide the window: drop the oldest, append the newest

    X = torch.tensor(contexts, dtype=torch.int64)
    Y = torch.tensor(targets, dtype=torch.int64)
    print(X.shape, Y.shape)
    return X, Y

import random

# Shuffle a *copy* with a dedicated seeded RNG, so re-running this cell always produces
# the same 80/10/10 split (shuffling `words` in place would reshuffle on every re-run).
# random.Random(42) yields the same permutation as random.seed(42); random.shuffle(...).
shuffled_words = list(words)
random.Random(42).shuffle(shuffled_words)
n1 = int(len(shuffled_words) * 0.8)
n2 = int(len(shuffled_words) * 0.9)

Xtr, Ytr = build_dataset(shuffled_words[:n1])     # 80% training
Xdev, Ydev = build_dataset(shuffled_words[n1:n2]) # 10% dev / validation
Xte, Yte = build_dataset(shuffled_words[n2:])     # 10% test

In [None]:
# helper used below to compare hand-derived gradients with the ones autograd computed
def cmp(s, dt, t):
    """Print one row comparing a manual gradient with autograd's.

    s:  label printed in the first column
    dt: manually computed gradient
    t:  tensor whose .grad was populated by loss.backward()
    """
    reference = t.grad
    exact_match = torch.all(dt == reference).item()
    close_match = torch.allclose(dt, reference)  # tolerant comparison
    largest_gap = (dt - reference).abs().max().item()
    print(f'{s:15s} | exact: {str(exact_match):5s} | approximate: {str(close_match):5s} | max diff: {largest_gap}')

In [None]:
# MLP revised
n_emb = 10 # dimensionality of the character embedding vectors
hidden_layer = 200 # number of neurons in the hidden layer
vocab_size = len(chars) + 1 # +1 for the '.' token

g = torch.Generator().manual_seed(214483647)
C = torch.randn((vocab_size, n_emb), generator=g)
# Kaiming-style init for tanh: gain 5/3 divided by sqrt(fan_in).
# The whole fan_in must be under the square root: (n_emb * block_size) ** 0.5,
# not n_emb * block_size ** 0.5 (where ** binds tighter than *).
W1 = torch.randn((n_emb * block_size, hidden_layer), generator=g) * (5/3) / ((n_emb * block_size) ** 0.5)
# b1 cancels out in the batch-norm mean subtraction; kept non-zero only to exercise its backward pass
b1 = torch.randn((hidden_layer,), generator=g) * 0.1
W2 = torch.randn((hidden_layer, vocab_size), generator=g) * 0.1 # small (but not exactly 0) weights
b2 = torch.randn((vocab_size,), generator=g) * 0.1

# batch norm parameters, drawn from the same seeded generator so the whole init is reproducible
bngain = torch.randn((1, hidden_layer), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, hidden_layer), generator=g) * 0.1

# parameters are randomized rather than set to all zeros / ones, because such special
# values could mask an incorrect implementation of the backward pass

parameters = [C, W1, b1, W2, b2, bngain, bnbias]

total_param_size = sum(p.nelement() for p in parameters) # the number of parameters
print(total_param_size)
for p in parameters:
    p.requires_grad = True

In [None]:
batch_size = 32
n = batch_size # short alias used throughout the manual backprop below
# construct a minibatch: batch_size random row indices into the training set
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb = Xtr[ix] # n x block_size contexts
Yb = Ytr[ix] # n target characters

In [None]:
# forward pass, with smaller steps so that it is possible to backward one step at a time

emb = C[Xb] # n x block_size x n_emb
embcat = emb.view(emb.shape[0], -1) # n x (block_size * n_emb)
# linear layer
hprebn = embcat @ W1 + b1 # hidden_layer pre-activation
# batch norm layer
# paper link: https://arxiv.org/pdf/1502.03167.pdf
bnmeani = 1/n * hprebn.mean(dim=0, keepdim=True) # 1 x hidden_layer
bndiff = hprebn - bnmeani # n x hidden_layer
bndiff2 = bndiff ** 2 # n x hidden_layer
bnvar = 1/(n -1) * (bndiff2).sum(dim=0, keepdim=True) # 1 x hidden_layer, Bessel's correction n-1 instead of n. unbiased estimate instead of biased estimate. 論文ではnを訓練時にn,テスト時にはn-1を使っている
# andrej氏はどちらも固定しておくのがベストと唱える。batch sizeが小さいときにはbessels correctionを使うと良く、大きい時にはそうでないほうが良い。
bnvar_inv = (bnvar + 1e-5) ** -0.5 # 1 x hidden_layer
bnraw = bndiff * bnvar_inv # n x hidden_layer
hpreact = bngain * bnraw + bnbias # n x hidden_layer

# non linearity
h = torch.tanh(hpreact) # n x hidden_layer
# linear layer2
logits = h @ W2 + b2 # n x vocab_size
# cross entropy loss 
logits_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logits_maxes # n x vocab_size
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True) # n x 1
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv # n x vocab_size
logprobs = probs.log() # n x vocab_size
loss = -logprobs[torch.arange(n), Yb].mean() # scalar

# pytorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits,
        logits_maxes, logits, h, hpreact, bnraw, bnvar_inv, 
        bndiff2, bndiff, bnvar, bnvar_inv, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss


In [None]:
emb.shape, C.shape, Xb.shape
print(C[:5])

In [None]:
# exercise here: backprop through through the whole thing manually
# backproping through exactyl all of the variables 
# as they are defined in the forward pass

# batch単位で更新しているので、ixで指定した場所以外の勾配は0であり、作用しない
# logprobsと同じ形で勾配はすべて0, ただし、ixで指定した場所のみ1/n (平均値)
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n

dprobs = dlogprobs / probs # local_derivatives * upstream_derivatives(chain rule0), d/dx logx = 1/x
dcounts = counts_sum_inv * dprobs # d/dx x = 1, element wise broadcasting so that dcounts is the same shape as counts
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)# 32 x 1
dcounts_sum = - (counts_sum ** -2) * dcounts_sum_inv # 32 x 1
dcounts +=  torch.ones_like(counts) * dcounts_sum # 32 x 86 # 足し算の場合、要素ごとの勾配は1で形は変わらない
dnorm_logits = norm_logits.exp() * dcounts # 32 x 86 #要素ごとにexpをとるので、 expが抜けると勾配の計算が起きない
dlogits = torch.ones_like(logits) * dnorm_logits # 32 x 86
# dlogits = dnorm_logits.clone() # これでも同じ
dlogits_maxes = (-torch.ones_like(logits_maxes) * dnorm_logits).sum(1, keepdim=True) # 32 x 1
# dlogits_maxes = - logits_maxes.sum(1, keepdim=True) * dlogits # 32 x 1
dlogits += (F.one_hot(logits.max(dim=1).indices, num_classes=logits.shape[1])) * dlogits_maxes # 32 x 86
dh = dlogits @ W2.T # 32 x 86
dW2 = h.T @ dlogits # 86 x 86
db2 = dlogits.sum(0, keepdim=True) # 1 x 86
dhpreact = (1 - h**2) * dh
dbngian = (bnraw * dhpreact).sum(0, keepdim=True) # 86 x 1
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
dbnvar = -0.5 * (bnvar + 1e-5) ** -1.5 * dbnvar_inv
dbndiff2 = torch.ones_like(bndiff2) * 1/(n-1) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = -1 * dbndiff.sum(0, keepdim=True)
dhprebn = torch.clone(dbndiff)
dhprebn += torch.ones_like(dhprebn) * dbnmeani * 1/n * 1/batch_size
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim=False)
# demb = dembcat.view(batch_size, block_size, n_emb) #これは下と同じ
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for i in range(Xb.shape[1]):
        ix = Xb[k, i]
        dC[ix] += demb[k, i]

cmp("logprobs", dlogprobs, logprobs)
cmp("probs", dprobs, probs)
cmp("counts_sum_inv", dcounts_sum_inv, counts_sum_inv)
cmp("counts_sum", dcounts_sum, counts_sum)
cmp("counts", dcounts, counts)
cmp("norm_logits", dnorm_logits, norm_logits)
cmp("logits_maxes", dlogits_maxes, logits_maxes)
cmp("logits", dlogits, logits)
cmp("h", dh, h)
cmp("W2", dW2, W2)
cmp("d2", db2, b2)
cmp("dhpreact", dhpreact, hpreact)
cmp("dbngian", dbngian, bngain)
cmp("dbnraw", dbnraw, bnraw)
cmp("dbnbias", dbnbias, bnbias)
cmp("dbnvardiff", dbnvar_inv, bnvar_inv)
cmp("dbnvar", dbnvar, bnvar)
cmp("dbndiff2", dbndiff2, bndiff2)
cmp("dbndiff", dbndiff, bndiff)
cmp("dbnmeani", dbnmeani, bnmeani)
cmp("dhprebn", dhprebn, hprebn)
cmp("dembcat", dembcat, embcat)
cmp("dW1", dW1, W1)
cmp("db1", db1, b1)
cmp("demb", demb, emb)
cmp("dC", dC, C)

# 
# loss_fast = F.cross_entropy(logits, Yb)
# print(loss_fast.item(), "diff:", (loss_fast - loss).item())

In [None]:
(dbnvar_inv.shape, bnvar_inv.shape) # gradient and tensor shapes should match (both 1 x hidden)

In [None]:
# exercise: the batch norm forward pass collapsed into a single expression

# should reproduce the step-by-step hpreact computed above (same unbiased variance and eps)
hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hpreact_fast - hpreact).abs().max().item())


In [None]:
plt.imshow(
    F.one_hot(logits.max(dim=1).indices, num_classes=logits.shape[1])
) # exactly one 1 per row (at the row's max logit), 0 everywhere else
(counts[0].sum(), counts[0], counts.sum(dim=1, keepdim=True)) # evaluated but not displayed (not the last expression)
logprobs == torch.log(probs) # element-wise comparison, should be all True

In [None]:
# exercise 2: replace the hand-written cross entropy with F.cross_entropy

# before (step by step):
# logits_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logits_maxes # n x vocab_size
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdim=True) # n x 1
# counts_sum_inv = counts_sum**-1
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[torch.arange(n), Yb].mean()

# now: the whole forward pass of the loss in one fused call
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), "diff:", loss_fast.sub(loss).item())

In [None]:
# backward pass: d loss / d logits = (softmax(logits) - one_hot(Yb)) / n
# Computed from detached logits: the manual gradient is then a plain tensor with no autograd
# graph, and the in-place edits below never touch a tensor autograd might save for backward.
dlogits = F.softmax(logits.detach(), 1)
dlogits[range(n), Yb] -= 1
dlogits /= n

cmp("logits", dlogits, logits)

In [None]:
logits.shape, Yb.shape, loss_fast.shape

In [None]:
F.softmax(logits, 1)[0]

In [None]:
dlogits[0] * n, n

In [None]:
dlogits[0].sum()

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(dlogits.detach(), cmap="gray")

In [None]:
# exercise 3: backprop through batch norm in a single expression

# compute dhprebn directly from dhpreact for
# hpreact = bngain * (hprebn - mean) * (unbiased_var + 1e-5)**-0.5 + bnbias
dhprebn = bngain*bnvar_inv/n * (n * dhpreact - dhpreact.sum(0) - n/(n-1) * bnraw * (dhpreact*bnraw).sum(0))
cmp("hprebn", dhprebn, hprebn)

In [None]:
# init

n_embd = 10 # dimensionality of the character embedding vectors
n_hidden = 200 # number of neurons in the hidden layer

g = torch.Generator().manual_seed(214748)
C = torch.randn((vocab_size, n_embd), generator=g)

# layer1: Kaiming-style init for tanh, gain 5/3 divided by sqrt(fan_in)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / ((n_embd * block_size) ** 0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1

# layer2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# batch norm
bngain = torch.randn((1, n_hidden), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

# same optimization as before
max_steps = 200000
batch_size = 32
n = batch_size
lossi = []

for i in range(max_steps):

    # minibatch construction
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)

    hprebn = embcat @ W1 + b1

    # batch norm layer
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias

    # non-linear activation
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    # autograd backward pass (p.grad is what the update below uses)
    for p in parameters:
        p.grad = None
    loss.backward()

    # manual backward pass, kept alongside autograd so that the next cell can compare the
    # two (`grads` ends up holding the last step's manual gradients). no_grad: these are
    # plain gradient computations that need no autograd graph.
    with torch.no_grad():
        dlogits = F.softmax(logits, 1)
        dlogits[range(n), Yb] -= 1
        dlogits /= n

        # layer2
        dh = dlogits @ W2.T
        dW2 = h.T @ dlogits
        db2 = dlogits.sum(0)

        # tanh
        dhpreact = (1.0 - h**2) * dh

        # batchnorm backprop (single-expression form from exercise 3)
        dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
        dbnbias = dhpreact.sum(0, keepdim=True)
        dhprebn = bngain*bnvar_inv/n * (n * dhpreact - dhpreact.sum(0) - n/(n-1) * bnraw * (dhpreact*bnraw).sum(0))

        # layer1
        dembcat = dhprebn @ W1.T
        dW1 = embcat.T @ dhprebn
        db1 = dhprebn.sum(0)

        # embedding: scatter-add each position's gradient into the row of C it looked up
        demb = dembcat.view(emb.shape)
        dC = torch.zeros_like(C)
        dC.index_add_(0, Xb.reshape(-1), demb.reshape(-1, n_embd))

    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias] # same order as `parameters`

    # update
    lr = 0.1 if i < 10000 else 0.01
    for p in parameters:
        # to train on the manual gradients instead, iterate zip(parameters, grads) and use grad here
        p.data += -lr * p.grad

    # logging
    if i % 10000 == 0:
        print(f"{i:7d}: {max_steps:7d} {loss.item():.4f}")
    lossi.append(loss.log10().item())


In [None]:
# compare the manual gradients from the last training step against autograd's p.grad;
# the loop names avoid overwriting the global random generator `g`
for p, grad in zip(parameters, grads):
    cmp(str(tuple(p.shape)), grad, p)

In [None]:
# calibrate the batch norm at the end of training: compute the mean / variance of the
# hidden pre-activations over the entire training set instead of a single minibatch

with torch.no_grad():
    # push the whole training set through the first linear layer
    emb = C[Xtr] # embed every context character as a vector
    emb_cat = emb.view(emb.shape[0], -1) # concatenate the vectors of each context
    h_preactive = emb_cat @ W1 + b1 # hidden-layer pre-activation, before batch norm

    # dataset-level statistics, used by the evaluation code in place of per-batch ones
    bnmean = torch.mean(h_preactive, dim=0, keepdim=True)
    bnvar = torch.var(h_preactive, dim=0, keepdim=True, unbiased=True)

In [None]:
@torch.no_grad() # no need to track gradients, evaluating alone makes it faster
def split_loss(split):
    """Print the mean cross-entropy loss of the model on one data split.

    split: 'train', 'val' or 'test'.
    Batch norm uses the global bnmean / bnvar statistics (presumably the ones
    calibrated over the training set) rather than per-batch statistics.
    """
    x, y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]

    emb = C[x] # N x block_size x embedding dim
    emb_cat = emb.view(emb.shape[0], -1) # concatenate into (N, embedding dim * block_size)
    h_preact = emb_cat @ W1 + b1 # N x hidden
    # batch norm: normalize with the calibrated stats, then scale by bngain and shift by bnbias.
    # (bnvar + 1e-5) ** -0.5 is 1/std; the previous `**-0/5` parsed as `x**0 / 5`, a constant 0.2.
    h_preact = bngain * (h_preact - bnmean) * (bnvar + 1e-5) ** -0.5 + bnbias
    h = torch.tanh(h_preact) # N x hidden
    logits = h @ W2 + b2 # N x vocab_size
    loss = F.cross_entropy(logits, y)
    print(f"{split} loss: {loss.item():.4f}")

split_loss('train')
split_loss('val')

In [None]:
# sample from the model
g = torch.Generator().manual_seed(214483647 + 10)

with torch.no_grad(): # inference only: no autograd graph needed
    for _ in range(20):

        out = []
        context = [0] * block_size # start with all '.' tokens (index 0)
        while True:
            # forward pass, identical to training except that batch norm uses the
            # calibrated bnmean / bnvar (the model was trained with batch norm, so it must be applied here too)
            emb = C[torch.tensor([context])] # (1, block_size, embedding dim)
            hpreact = emb.view(1, -1) @ W1 + b1
            hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5) ** -0.5 + bnbias
            h = torch.tanh(hpreact)
            logits = h @ W2 + b2
            probs = F.softmax(logits, dim=1)
            # sample from the distribution
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            # shift the context window and append the new char
            context = context[1:] + [ix]
            out.append(ix)
            # stop once the end-of-word token '.' (index 0) is sampled
            if ix == 0:
                break

        print("".join(itos[i] for i in out)) # decode and print the word