The material in this notebook follows part 4 of Andrej Karpathy's stellar Makemore tutorial (https://www.youtube.com/watch?v=q8SA3rM6ckI)


In [19]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline

!git clone https://github.com/karpathy/makemore

fatal: destination path 'makemore' already exists and is not an empty directory.


In [20]:
with open('makemore/names.txt','r') as file:
  words = file.read().splitlines()

In [21]:
word_lengths = torch.tensor([len(w) for w in words]).float()
print(
 f"""
 This dataset contains {word_lengths.nelement()} names\n
 The minimum name length is {word_lengths.min()} characters.\n 
 The maximum name length is {word_lengths.max()} characters.\n
 The mean name length is  {word_lengths.mean():.2f} characters. \n
 The associated standard deviation is {word_lengths.std():.2f} characters.
 """
 )


 This dataset contains 32033 names

 The minimum name length is 2.0 characters.
 
 The maximum name length is 15.0 characters.

 The mean name length is  6.12 characters. 

 The associated standard deviation is 1.44 characters.
 


In [22]:
#building the character vocabulary and lookup tables to map from characters to integer indices and back

chars = ['.']+sorted(list(set(''.join(words))))  #as before, '.' is used as a start/stop/padding special character
s_to_i = {s:i for i,s in enumerate(chars)}
i_to_s = {i:s for s,i in s_to_i.items()}
block_size = 3 #context length, size of the block that supports the prediction: P(x_n| x_{n-1}, x_{n-2}, x_{n-3} )
vocab_size = len(i_to_s)
print(i_to_s)
print(vocab_size)


{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}
27


In [23]:
def build_dataset(words, block_size):
  X,Y = [], []

  for w in words:
    #print(w)
    context = [0] * block_size #init context using indices of chars
    for ch in w+'.':
      ix = s_to_i[ch]
      X.append(context)
      Y.append(ix)
      #print(''.join(i_to_s[i] for i in context), '--->', i_to_s[ix]) #context ---> current, training pattern
      context = context[1:]+[ix]

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape,Y.shape)
  return X,Y

#training split (used to train parameters), dev/validation split (used to train hyperparameters), test split (at end with the final model)
# 80%, 10%, 10%
import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1], block_size)
Xdev, Ydev = build_dataset(words[n1:n2], block_size)
Xte, Yte = build_dataset(words[n2:], block_size)

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [24]:
# reference gradients for comparison to validate our manual gradients
def cmp(s,dt,t):
  ex = torch.all(dt == t.grad).item() #all entries exact match?
  app = torch.allclose(dt, t.grad) #all entries close within some tolerance?
  maxdiff = (dt - t.grad).abs().max().item() #max error?
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')


In [25]:
#re-initialize all parameters using the Kaiming init method for tanh
n_embd = 10 #embedding dimension
n_hidden = 64
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size,n_embd),             generator=g) #embedding matrix
W1 = torch.randn((n_embd * block_size,n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)  #scaling can be important, large weights that occur by chance (high dimensional gaussian) can cause the tanh nonlinearity to saturate, even at initialization. Saturated nonlinearities are flat, meaning the gradient of the loss wrt those parameters is zero. No learning for those parameters.

# note that with batch normalization, the bias b1 is useless. We include it here anyway just to show it will have zero gradient.
b1 = torch.randn(n_hidden,                       generator=g) * 0.1
W2 = torch.randn((n_hidden,vocab_size),          generator=g) * 0.1 #scaling can help with unbalanced initial probabilities output by the softmax layer due to outliers in the random input layer
b2 = torch.randn(vocab_size,                     generator=g) * 0.1

#
bngain = torch.randn((1,n_hidden))*0.1 + 1.0 #batch normalization gain
bnbias = torch.randn((1,n_hidden))*0.1 #batch normalization bias
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1,n_hidden))

#As stated in the lecture, the above initializations are somewhat nonstandard.
#We are avoiding certain initializations, such as zero, so that improper implementations
#of backprop will be fully exposed (nothing is hidden). 

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
  p.requires_grad = True

4137


# Detailed forward pass
Exploring the guts of a single forward pass through the network

In [26]:
#construct minibatch indices
batch_size = 32
ix = torch.randint(0,Xtr.shape[0],(batch_size,), generator=g) 
Xb, Yb = Xtr[ix], Ytr[ix] #batch

print(Xb.shape,Yb.shape)

torch.Size([32, 3]) torch.Size([32])


In [27]:
# forward pass
# recall: 
# C is (vocab_size,n_embd)
# Xb is (batch_size,block_size) 


emb = C[Xb] #embedding characters (batch_size, block_size, n_embd)
embcat = emb.view(emb.shape[0],-1) #(batch_size, block_size * n_embd)

# Linear layer 1
hprebn = embcat @ W1 + b1 #h(idden)preb(atch)n(ormalization), size = (batch_size, n_hidden)
# each row of hprebn is a vector of preactivations for the corresponding input example.

# BatchNorm layer
bnmeani = 1/batch_size * hprebn.sum(dim=0,keepdim=True) #(1,n_hidden), batch average for each preactivation
bndiff = hprebn - bnmeani #(batch_size, n_hidden)
bndiff2 = bndiff**2
bnvar= 1/(batch_size-1) * bndiff2.sum(dim=0,keepdim=True) #Bessel correction to sample variance
bnstd_inv = (bnvar+1e-5)**-0.5 #1/bnstd_dev
bnraw = bndiff * bnstd_inv # normalization (hprebn - bnmean) / bnstd_dev
hpreact = bngain * bnraw + bnbias #(batch_size, n_hidden)

# Non-linearity
h = torch.tanh(hpreact) #(batch_size, n_hidden)

#Linear layer 2
logits = h @ W2 + b2 #(batch_size, vocab_size)

#cross entropy loss, does the same thing as F.cross_entropy(logits,Yb)
logit_maxes = logits.max(1, keepdim=True).values #calculates max of each output layer of each example, size=(batch_size,1)
norm_logits = logits - logit_maxes #subtracts the max element from each output layer/row, size=(batch_size,vocab_size)
counts = norm_logits.exp() #(batch_size, vocab_size)
counts_sum = counts.sum(1,keepdim=True) #(batch_size,1)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv #(batch_size, vocab_size)
logprobs = probs.log() #k-th row contains the log likelihood distribution over all next tokens for the k-th example in the batch
loss = -logprobs[range(batch_size),Yb].mean() # negative log likelihood loss, averaged over the batch


#PyTorch backward pass
#backward pass
for p in parameters:
  p.grad = None

for t in [logprobs,probs,counts,counts_sum, counts_sum_inv,
         norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnstd_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()

loss.backward()
loss

tensor(3.3420, grad_fn=<NegBackward0>)

In [28]:
## Exercise 1: backprop through the whole thing manually, one-by-one.
# notation dg is shorthand for d(loss)/dg

#d(loss)/d(logprobs)
# For example k, logprobs is a (1,vocab_size) sized log probability distribution over the vocab for the next words.
# Hence, dlogprobs is of dimension (batch_size,vocab_size)
# The negative log likelihood loss only plucks out the -logprob associated with Yb[k]. 
# Hence, this index is the only parameter in logprob that influences the loss.
# This is averaged over the batch, hence each score has a coefficient of 1/batch_size.
# The dimension of dlogprobs should be (batch_size,vocab_size) 

dlogprobs = torch.zeros_like(logprobs) 
dlogprobs[range(batch_size),Yb] = -1/batch_size
cmp('logprobs', dlogprobs, logprobs)

#d(loss)/d(probs) 
# For a fixed example k, logprobs is a (1,vocab_size) sized log probability distribution.
# Each entry of is the natural logarithm of the corresponding entry of probs.
# Hence, the dimension of dprobs is (batch_size, vocab_size)

dprobs = probs**-1.0 * dlogprobs 
cmp('probs', dprobs, probs)

#d(loss)/d(counts_sum_inv)
# For a fixed example k, counts_sum_inv is a (1,1) sized parameter.
# And for the batch of examples, counts_sum_inv is (batch_size,1)
# Hence, d(probs)/d(counts_sum_inv) should be (batch_size,vocab_size).
# This tensor contains all of the linear perturbations to the probabilities due to perturbations in counts_sum_inv.
# From above, we have that d(loss)/d(probs) is of dimension (batch_size, vocab_size).
# For each example, we can sum up all of the linear perturbations to the loss due to the perturbations to counts_sum_inv.
dcounts_sum_inv = (counts * dprobs).sum(dim=1, keepdim=True)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

#d(loss)/d(counts_sum)
# For a fixed example k, counts_sum is a (1,1) sized parameter.
# And for the batch of examples, counts_sum is (batch_size,1)
# Hence, d(count_sum_inv)/d(counts) should be (batch_size,1).
# From above, we have that d(loss)/d(count_sum_inv)) is of dimension (batch_size, vocab_size).
# For each example, we can sum up all of the linear perturbations to the loss due to the perturbations to counts_sum_inv.
dcounts_sum = (-counts_sum**-2)  * dcounts_sum_inv
cmp('counts_sum', dcounts_sum, counts_sum)

#d(loss)/d(counts)
# First, note that counts appears in two terms: probs and counts_sum .
# Therefore, we must sum up the derivatives of these two terms (product rule).
dcounts = counts_sum_inv * dprobs + torch.ones_like(counts) * dcounts_sum
cmp('counts', dcounts, counts)

#d(loss)/d(norm_logits)
# norm_logits is (batch_size, vocab_size)
dnorm_logits = counts * dcounts
cmp('norm_logits', dnorm_logits, norm_logits)

#d(loss)/d(logit_maxes)
# We should expect this derivative to be effectively zero since the normalization of the cross-entropy loss is included
# only for numerical stability.
dlogit_maxes = (-dnorm_logits).sum(dim=1, keepdim=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)

#d(loss)/d(logits)
# logits appears in both the expressions for norm_logits and logits_max.
# recall, logits is (batch_size, vocab_size)
arg_idx = torch.zeros_like(logits)
arg_idx[range(batch_size), logits.argmax(dim=1)] = 1.0
dlogits = dnorm_logits.clone() + arg_idx*dlogit_maxes # note, we took care of the negative sign when computing dlogit_maxes
cmp('logits', dlogits, logits)

#d(loss)/dh
# recall, the h is of dimension (batch_size, n_hidden).
# also recall that W2 is of dimension (n_hidden,vocab_size)
dh = dlogits @ W2.T
cmp('h', dh, h)

#d(loss)/d(W2)
# recall, the weight matrix W2 is of dimension (n_hidden, vocab_size)
# h is (batch_size, n_hidden)
# dlogits is (batch_size, vocab_size)
dW2 = h.T @ dlogits
cmp('W2', dW2, W2)

#d(loss)/d(b2)
# recall, the bias vector b2 is of dimension (vocab_size,) but will be broadcast to (1,vocab_size during operations)
# h is (batch_size, n_hidden)
# dlogits is (batch_size, vocab_size)
db2 = dlogits.sum(dim=0)
cmp('b2', db2, b2)

#d(loss)/d(hpreact)
# recall, the the preactivation layer in hpreact is (batch_size, n_hidden).
# After the application of tanh, the resulting activation layer of course has the same dimension.
dhpreact = (1.0 - h**2) * dh
cmp('hpreact', dhpreact, hpreact)

#d(loss)/d(bngain)
# recall, bngain is of dimension (1,n_hidden).
# After the application of tanh, the resulting activation layer of course has the same dimension.
dbngain = (bnraw*dhpreact).sum(dim=0,keepdim=True)
cmp('bngain', dbngain, bngain)

#d(loss)/d(bnraw)
#recall, bnraw is of dimension (batch_size, n_hidden)
dbnraw = bngain * dhpreact
cmp('bnraw', dbnraw, bnraw)

#d(loss)/d(bnbias)
#recall, bnbias is of dimension (1, n_hidden)
dbnbias = dhpreact.sum(dim=0,keepdim=True)
cmp('bnbias', dbnbias, bnbias)

#d(loss)/d(bnstd_inv)
#recall, bnstd_inv is of dimension (1, n_hidden)
dbnstd_inv = (bndiff*dbnraw).sum(dim=0, keepdim=True)
cmp('bnstd_inv', dbnstd_inv, bnstd_inv)

#d(loss)/d(bnvar)
#recall, bnvar is of dimension (1, n_hidden)
dbnvar= -0.5*(bnvar+1e-5)**-1.5 *dbnstd_inv
cmp('bnvar', dbnvar, bnvar)

#d(loss)/d(bndiff2)
#recall, bndiff2 is of dimension (batch_size, n_hidden)
dbndiff2 = 1/(batch_size-1)*torch.ones_like(bndiff2)*dbnvar
cmp('bndiff2', dbndiff2, bndiff2)

#d(loss)/d(bndiff)
#recall, bndiff is of dimension (batch_size, n_hidden)
dbndiff = bnstd_inv * dbnraw + 2.0 * bndiff * dbndiff2
cmp('bndiff', dbndiff, bndiff)

#d(loss)/d(bnmeani)
#recall, bnmeani is of dimension (1,n_hidden)
dbnmeani= -1.0 * dbndiff.sum(dim=0, keepdim=True)
cmp('bnmeani', dbnmeani, bnmeani)

#d(loss)/d(hprebn)
#recall, hprebn is of dimension (batch_size, n_hidden)
dhprebn = dbndiff.clone() + 1.0/batch_size*torch.ones_like(hprebn) *dbnmeani
cmp('hprebn', dhprebn, hprebn)

#d(loss)/d(W1)
#recall, W1 is of dimension (block_size * n_embed, n_hidden)
dW1 = embcat.T @ dhprebn
cmp('W1', dW1, W1)

#d(loss)/d(b1)
#recall, b1 is of dimension (n_hidden,)
db1 = dhprebn.sum(dim=0)
cmp('b1', db1, b1)

#d(loss)/d(embcat)
#recall embcat is of dimension (batch_size, block_size * n_embed)
dembcat = dhprebn@W1.T
cmp('embcat', dembcat, embcat)

#d(loss)/d(emb)
demb = dembcat.view(emb.shape)
cmp('emb', demb, emb)

#d(loss)/d(C)
#recall, C is of dimension (vocab_size, n_embed)
dC = torch.zeros_like(C)
for ii in range(Xb.shape[0]):
  for jj in range(Xb.shape[1]):
    dC[Xb[ii,jj]] += demb[ii,jj,:]
cmp('C', dC, C)

# I seem to get mostly approximate equivalence whereas the lecture has exact equivalence for all derivatives. Perhaps I have a bug somewhere...

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
bnraw           | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
bnbias  

In [29]:
# Exercise 2: backprop through the cross_entropy loss again, but this time all in one go rather than tracking each individual step.

# Basically, the long way -- as investigated in the previous cell -- had the following forward pass:

# logit_maxes = logits.max(1, keepdim=True).values #calculates max of each output layer of each example, size=(batch_size,1)
# norm_logits = logits - logit_maxes #subtracts the max element from each output layer/row, size=(batch_size,vocab_size)
# counts = norm_logits.exp() #(batch_size, vocab_size)
# counts_sum = counts.sum(1,keepdim=True) #(batch_size,1)
# counts_sum_inv = counts_sum**-1
# probs = counts * counts_sum_inv #(batch_size, vocab_size)
# logprobs = probs.log() #k-th row contains the log likelihood distribution over all next tokens for the k-th example in the batch
# loss = -logprobs[range(batch_size),Yb].mean() # negative log likelihood loss, averaged over the batch

# now, just consider:
loss_fast = F.cross_entropy(logits,Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())

3.342024564743042 diff: -2.384185791015625e-07


In [30]:
#backward pass
# here we implement an analytical formula for d(loss)/dlogits

# first, let us recall that logits is of dimension (batch_size, vocab_size).
# now, let us consider the case of a single example rather than a minibatch.
# the model output in terms of logits z is of dimension (vocab_size,).
# this can be converted to probabilities by applying a softmax operator to get:
#               p_i = exp(z_i) / $_j exp(z_j)
# which is stored in the vector p of dimension (vocab_size,).
# the nll loss is therefore -log p_y, where y is the index of target training data.
#
# hence, we are interested in 
# d/dz_k (-log p_y) = -1/p_y dp_y/dz_k
#                   = -1/p_y d/dz_k (exp(z_y) / $_j exp(z_j)) 
#                   = -1/p_y [( ($_j exp(z_j)) * d(exp(z_y))/dz_k - exp(z_y)*d($_j exp(z_j))/dz_k) ) / ($_j exp(z_j))**2 ] 
#                   = -1/p_y [( ($_j exp(z_j)) * d(exp(z_y))/dz_k - exp(z_y)*exp(z_k)) ) / ($_j exp(z_j))**2 ] 
#                   = -1/p_y [ ($_j exp(z_j)) * d(exp(z_y))/dz_k / ($_j exp(z_j))**2 - exp(z_y)*exp(z_k)) ) / ($_j exp(z_j))**2 ] 
#                   = -1/p_y [  d(exp(z_y))/dz_k / ($_j exp(z_j)) - exp(z_y)*exp(z_k)) ) / ($_j exp(z_j))**2 ] 
#                   = -1/p_y [  p_y * 1_{k==y} - p_y p_k ]
#                   = - [ 1_{k==y} - p_k ]
#                   = p_k - 1_{k==y}
# now recall our loss is averaged over a minibatch.

#dlogits_ = 1/batch_size * (probs - F.one_hot(Yb, num_classes=vocab_size) )
dlogits_ = F.softmax(logits, dim=1)
dlogits_[range(batch_size),Yb] -= 1
dlogits_ *= 1/batch_size
cmp('logits', dlogits_, logits)

logits          | exact: False | approximate: True  | maxdiff: 5.122274160385132e-09


An interpretation of the negative log marginal likelihood derivatives with respect to the logits is as follows. Recall, the logits are soft maxed into a probability distribution $p$ over the `vocab_size` classes. Let $k$ index those classes and let $y$ by the index of the observed training data point.

If $k \neq y$, i.e., the $k$-th class does not correspond to the correct data label, then

`d(loss)/d(logits[k]) = probs[k]`.

Hence, further decreasing the loss via gradient descent pulls down the $k$-th logit.

Conversely, if $k = y$, i.e., the $k$-th class is the correct data label, then

`d(loss)/d(logits[k]) = probs[k] - 1 < 0`

Therefore, further decreasing the loss via gradient descent pushes up the $k$-th logit.

Basically, the output probabilities are adjusted to support the training data labels (while opposing incorrect labels). 

In [31]:
# Exercise 3: backprop through the batch normalization again, but this time all in one go.

# recall, the forward pass:

# bnmeani = 1/batch_size * hprebn.sum(dim=0,keepdim=True) #(1,n_hidden), batch average for each preactivation
# bndiff = hprebn - bnmeani #(batch_size, n_hidden)
# bndiff2 = bndiff**2
# bnvar= 1/(batch_size-1) * bndiff2.sum(dim=0,keepdim=True) #Bessel correction to sample variance
# bnstd_inv = (bnvar+1e-5)**-0.5 #1/bnstd_dev
# bnraw = bndiff * bnstd_inv # normalization (hprebn - bnmean) / bnstd_dev
# hpreact = bngain * bnraw + bnbias #(batch_size, n_hidden)

# now, using one line:
hpreact_fast = bngain * (hprebn - hprebn.mean(dim=0,keepdim=True)) / torch.sqrt(hprebn.var(dim=0, keepdim=True, unbiased = True) + 1e-5) + bnbias
print('maxdiff: ', (hpreact_fast - hpreact).abs().max())


maxdiff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)


In [32]:
# backward pass

# our goal here is to calculate the d(loss)/d(hprebn) given d(loss)/d(hpreact). 
# from the analytical formula, we have: 

dhprebn_ = bngain * bnstd_inv / batch_size * ( batch_size*dhpreact - dhpreact.sum(dim=0) - batch_size / (batch_size-1) * bnraw *(dhpreact * bnraw).sum(dim=0))

cmp('hprebn', dhprebn_, hprebn)

hprebn          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10


## Putting it all together

In [33]:
#re-initialize all parameters using the Kaiming init method for tanh
n_embd = 10 #embedding dimension
n_hidden = 200
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size,n_embd),             generator=g) #embedding matrix
W1 = torch.randn((n_embd * block_size,n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)  #scaling can be important, large weights that occur by chance (high dimensional gaussian) can cause the tanh nonlinearity to saturate, even at initialization. Saturated nonlinearities are flat, meaning the gradient of the loss wrt those parameters is zero. No learning for those parameters.

# note that with batch normalization, the bias b1 is useless. We include it here anyway just to show it will have zero gradient.
b1 = torch.randn(n_hidden,                       generator=g) * 0.1
W2 = torch.randn((n_hidden,vocab_size),          generator=g) * 0.1 #scaling can help with unbalanced initial probabilities output by the softmax layer due to outliers in the random input layer
b2 = torch.randn(vocab_size,                     generator=g) * 0.1

#
bngain = torch.randn((1,n_hidden))*0.1 + 1.0 #batch normalization gain
bnbias = torch.randn((1,n_hidden))*0.1 #batch normalization bias

#As stated in the lecture, the above initializations are somewhat nonstandard.
#We are avoiding certain initializations, such as zero, so that improper implementations
#of backprop will be fully exposed (nothing is hidden). 

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
  p.requires_grad = True

batch_size = 32
max_iters = 200000
lossi=[]

with torch.no_grad():
  for i in range(max_iters):

    ix = torch.randint(0,Xtr.shape[0],(batch_size,), generator=g) 
    Xb, Yb = Xtr[ix], Ytr[ix] #batch

    #forward pass
    emb = C[Xb] #embedding characters (batch_size, block_size, n_embd)
    embcat = emb.view(emb.shape[0],-1) #(batch_size, block_size * n_embd)

    # Linear layer 1
    hprebn = embcat @ W1 + b1 #h(idden)preb(atch)n(ormalization), size = (batch_size, n_hidden)
    # each row of hprebn is a vector of preactivations for the corresponding input example.

    # BatchNorm layer
    bnmean = hprebn.mean(dim=0, keepdim=True)
    bnvar = hprebn.var(dim=0, keepdim=True, unbiased=True)
    bnstd_inv = (bnvar+1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnstd_inv
    hpreact = bngain * bnraw + bnbias #(batch_size, n_hidden)

    # Non-linearity
    h = torch.tanh(hpreact) #(batch_size, n_hidden)

    #Linear layer 2
    logits = h @ W2 + b2 #(batch_size, vocab_size)

    #cross entropy loss, does the same thing as F.cross_entropy(logits,Yb)
    loss = F.cross_entropy(logits, Yb) # negative log likelihood loss, averaged over the batch


    #PyTorch backward pass
    #backward pass
    for p in parameters:
      p.grad = None

    #loss.backward() #for debug
    # Manual backprop:

    dlogits = F.softmax(logits, dim=1)
    dlogits[range(batch_size),Yb] -= 1
    dlogits /= batch_size
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(dim=0)
    dhpreact = (1.0 - h**2) * dh
    dbngain = (bnraw*dhpreact).sum(dim=0,keepdim=True)
    dbnbias = dhpreact.sum(dim=0,keepdim=True)
    dhprebn = bngain * bnstd_inv / batch_size * ( batch_size*dhpreact - dhpreact.sum(dim=0) - batch_size / (batch_size-1) * bnraw *(dhpreact * bnraw).sum(dim=0))
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(dim=0)
    dembcat = dhprebn@W1.T
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for ii in range(Xb.shape[0]):
      for jj in range(Xb.shape[1]):
        dC[Xb[ii,jj]] += demb[ii,jj,:]

    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]

    #update
    lr = 0.1 if i < 100000 else 0.01
    for p, grad in zip(parameters,grads):
      p.data += -lr * grad

    # track stats
    if i % 10000 == 0:
      print(f'{i:7d}/{max_iters:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

    # if i == 1000: #debug
    #   break


12297
      0/ 200000: 3.7805
  10000/ 200000: 2.1499
  20000/ 200000: 2.3771
  30000/ 200000: 2.4338
  40000/ 200000: 1.9939
  50000/ 200000: 2.4139
  60000/ 200000: 2.4395
  70000/ 200000: 1.9687
  80000/ 200000: 2.3636
  90000/ 200000: 2.1345
 100000/ 200000: 1.9419
 110000/ 200000: 2.3656
 120000/ 200000: 2.0369
 130000/ 200000: 2.4437
 140000/ 200000: 2.3271
 150000/ 200000: 2.2033
 160000/ 200000: 1.9513
 170000/ 200000: 1.8305
 180000/ 200000: 2.1005
 190000/ 200000: 1.8946


In [34]:
#debug
# for p,g in zip(parameters, grads):
#   cmp(str(tuple(p.shape)),g,p)

In [35]:
# calibrate the batch norm parameters for evaluation
with torch.no_grad():
  emb=C[Xtr]
  embcat = emb.view(emb.shape[0],-1)
  hpreact = embcat @W1 + b1
  bnmean = hpreact.mean(0,keepdim=True)
  bnvar = hpreact.var(0,keepdim=True)

In [36]:
@torch.no_grad()
def split_loss(split):
  x,y = { 
      'train': (Xtr, Ytr),
      'val': (Xdev,Ydev),
      'test': (Xte,Yte),
  }[split]
  emb = C[x] #(N, block_size, n_embd)
  embcat = emb.view(emb.shape[0],-1)
  hpreact = embcat @ W1 + b1
  hpreact = bngain* (hpreact - bnmean) * (bnvar+1e-5)**-0.5 + bnbias
  h = torch.tanh(hpreact)
  logits = h @ W2 + b2
  loss = F.cross_entropy(logits, y)
  print(split, loss.item())
  
split_loss('train')
split_loss('test')

train 2.0714404582977295
test 2.1091206073760986
