In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
# ok biolerplate done, now we get to the action:

In [6]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [7]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [8]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [9]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3277, grad_fn=<NegBackward0>)

In [20]:
emb.shape, emb

(torch.Size([32, 3, 10]),
 tensor([[[-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
            2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
          [-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
            2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
          [-9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
            1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00]],
 
         [[ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
            1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00],
          [ 4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
            1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01],
          [-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
            2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01]],
 
         [[-5.6533e-01,  5.4281e-01,  1.7549e-01, 

In [21]:
embcat.shape, embcat

(torch.Size([32, 30]),
 tensor([[-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
          -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
          -9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
           1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00],
         [ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
           1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00,
           4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
           1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01,
          -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
         [-5.6533e-01,  5.4281e-01,  1.7549e-01, -2.2901e+00, -7.0928e-01,


In [22]:
hprebn.shape, hprebn

(torch.Size([32, 64]),
 tensor([[-1.2944e+00, -1.6063e+00,  5.8236e-01,  ...,  2.8577e+00,
          -6.9102e-01,  1.4514e+00],
         [-2.4448e+00,  7.3323e-04, -2.5433e+00,  ..., -1.4529e+00,
           4.0477e-01, -1.4181e-01],
         [-1.4482e+00,  9.7922e-01,  1.4362e+00,  ...,  3.0377e-01,
          -1.2690e+00,  1.4736e+00],
         ...,
         [-1.7455e-01, -2.8238e-01, -4.0602e-01,  ..., -1.1733e+00,
           2.5727e-01,  1.6736e+00],
         [-1.9748e+00, -2.0943e+00,  7.6590e-01,  ...,  3.5904e-01,
           2.3257e+00, -1.4833e+00],
         [ 2.5080e+00, -3.0651e-01, -8.6864e-01,  ..., -1.7992e+00,
          -2.5073e+00,  4.5976e-02]], grad_fn=<AddBackward0>))

In [23]:
bnmeani.shape, bnmeani

(torch.Size([1, 64]),
 tensor([[-0.2923, -0.0526, -0.6841,  0.7001, -0.5036,  0.4972, -0.0668, -0.1110,
          -0.4643,  0.0302,  0.6675, -0.3830, -0.4715, -0.3956, -0.5209,  0.3950,
           0.0292, -1.6880,  0.3933,  0.8486, -0.6091, -1.2410, -0.0519, -0.2812,
           0.2267,  1.3896, -0.6374,  0.3277, -0.5348,  1.1896,  0.3176,  0.6083,
           0.7051, -0.5859, -0.2753,  1.9107, -1.1769, -0.7579,  0.1236,  0.4828,
           0.2302,  0.2900,  0.5919, -1.0644, -0.2177,  0.7055,  0.4917, -0.3559,
           0.6701,  1.5317, -0.5005, -0.2265,  1.7797,  0.6933,  1.5666, -1.1398,
          -0.4648, -0.8345,  0.6744, -0.1985, -1.3676, -0.5960,  0.1881,  0.7816]],
        grad_fn=<MulBackward0>))

In [24]:
bndiff.shape, bndiff

(torch.Size([32, 64]),
 tensor([[-1.0021, -1.5537,  1.2665,  ...,  3.4537, -0.8792,  0.6698],
         [-2.1525,  0.0533, -1.8592,  ..., -0.8569,  0.2166, -0.9234],
         [-1.1559,  1.0318,  2.1204,  ...,  0.8998, -1.4571,  0.6920],
         ...,
         [ 0.1178, -0.2298,  0.2781,  ..., -0.5773,  0.0691,  0.8920],
         [-1.6825, -2.0417,  1.4501,  ...,  0.9550,  2.1376, -2.2649],
         [ 2.8003, -0.2539, -0.1845,  ..., -1.2032, -2.6954, -0.7357]],
        grad_fn=<SubBackward0>))

In [25]:
bndiff2.shape, bndiff2

(torch.Size([32, 64]),
 tensor([[1.0042e+00, 2.4139e+00, 1.6040e+00,  ..., 1.1928e+01, 7.7294e-01,
          4.4859e-01],
         [4.6332e+00, 2.8452e-03, 3.4566e+00,  ..., 7.3430e-01, 4.6927e-02,
          8.5274e-01],
         [1.3360e+00, 1.0647e+00, 4.4960e+00,  ..., 8.0959e-01, 2.1232e+00,
          4.7888e-01],
         ...,
         [1.3866e-02, 5.2797e-02, 7.7355e-02,  ..., 3.3327e-01, 4.7776e-03,
          7.9563e-01],
         [2.8309e+00, 4.1687e+00, 2.1026e+00,  ..., 9.1209e-01, 4.5692e+00,
          5.1298e+00],
         [7.8419e+00, 6.4469e-02, 3.4037e-02,  ..., 1.4477e+00, 7.2654e+00,
          5.4118e-01]], grad_fn=<PowBackward0>))

In [26]:
bnvar.shape, bnvar

(torch.Size([1, 64]),
 tensor([[2.9677, 1.2682, 2.1554, 1.4667, 4.6703, 2.9705, 2.0831, 1.4324, 1.3041,
          1.2941, 1.5416, 1.3771, 3.7823, 1.3272, 1.0750, 2.1390, 2.0086, 4.0767,
          0.9473, 2.6632, 5.2006, 1.3070, 2.7925, 2.8935, 1.6161, 2.7019, 1.9697,
          1.4905, 2.7771, 1.8918, 1.4346, 3.3289, 2.6032, 1.7041, 2.5652, 3.1497,
          2.4295, 3.7256, 0.3788, 1.5171, 1.6284, 2.5108, 2.3224, 1.2441, 1.9543,
          2.9438, 1.9290, 2.3941, 1.5941, 2.6831, 2.9968, 1.2193, 3.5117, 2.7399,
          3.2270, 3.6581, 1.3968, 2.2523, 2.1150, 0.7927, 1.5964, 3.3448, 2.2397,
          1.8453]], grad_fn=<MulBackward0>))

In [27]:
bnvar_inv.shape, bnvar_inv

(torch.Size([1, 64]),
 tensor([[0.5805, 0.8880, 0.6811, 0.8257, 0.4627, 0.5802, 0.6928, 0.8355, 0.8757,
          0.8790, 0.8054, 0.8522, 0.5142, 0.8680, 0.9645, 0.6837, 0.7056, 0.4953,
          1.0274, 0.6128, 0.4385, 0.8747, 0.5984, 0.5879, 0.7866, 0.6084, 0.7125,
          0.8191, 0.6001, 0.7270, 0.8349, 0.5481, 0.6198, 0.7661, 0.6244, 0.5635,
          0.6416, 0.5181, 1.6249, 0.8119, 0.7836, 0.6311, 0.6562, 0.8965, 0.7153,
          0.5828, 0.7200, 0.6463, 0.7920, 0.6105, 0.5777, 0.9056, 0.5336, 0.6041,
          0.5567, 0.5228, 0.8461, 0.6663, 0.6876, 1.1232, 0.7915, 0.5468, 0.6682,
          0.7361]], grad_fn=<PowBackward0>))

In [28]:
bnraw.shape, bnraw

(torch.Size([32, 64]),
 tensor([[-0.5817, -1.3797,  0.8627,  ...,  1.8884, -0.5875,  0.4930],
         [-1.2495,  0.0474, -1.2664,  ..., -0.4685,  0.1448, -0.6798],
         [-0.6710,  0.9163,  1.4443,  ...,  0.4920, -0.9736,  0.5094],
         ...,
         [ 0.0684, -0.2040,  0.1894,  ..., -0.3157,  0.0462,  0.6566],
         [-0.9767, -1.8131,  0.9877,  ...,  0.5222,  1.4283, -1.6673],
         [ 1.6255, -0.2255, -0.1257,  ..., -0.6579, -1.8011, -0.5415]],
        grad_fn=<MulBackward0>))

In [29]:
hpreact.shape, hpreact

(torch.Size([32, 64]),
 tensor([[-0.6044, -1.4444,  0.8227,  ...,  1.9963, -0.6541,  0.5122],
         [-1.2776,  0.0673, -1.6394,  ..., -0.6261,  0.0124, -0.8218],
         [-0.6944,  0.9877,  1.4953,  ...,  0.4426, -1.0056,  0.5308],
         ...,
         [ 0.0508, -0.1990,  0.0442,  ..., -0.4560, -0.0773,  0.6982],
         [-1.0026, -1.9035,  0.9673,  ...,  0.4762,  1.1808, -1.9450],
         [ 1.6206, -0.2217, -0.3202,  ..., -0.8368, -1.7588, -0.6646]],
        grad_fn=<AddBackward0>))

In [30]:
h.shape, h

(torch.Size([32, 64]),
 tensor([[-0.5402, -0.8946,  0.6766,  ...,  0.9638, -0.5744,  0.4716],
         [-0.8559,  0.0672, -0.9274,  ..., -0.5554,  0.0124, -0.6760],
         [-0.6008,  0.7564,  0.9043,  ...,  0.4158, -0.7639,  0.4860],
         ...,
         [ 0.0508, -0.1964,  0.0442,  ..., -0.4268, -0.0771,  0.6033],
         [-0.7627, -0.9565,  0.7475,  ...,  0.4432,  0.8277, -0.9599],
         [ 0.9247, -0.2181, -0.3097,  ..., -0.6841, -0.9424, -0.5814]],
        grad_fn=<TanhBackward0>))

In [31]:
logits.shape, logits

(torch.Size([32, 27]),
 tensor([[ 7.7128e-01,  9.3263e-01, -5.1015e-01,  4.1128e-01, -5.0394e-01,
           9.7078e-01, -2.9692e-01,  9.7793e-02, -6.0967e-01,  1.5935e-01,
           1.3660e-01,  1.3891e-01,  1.6699e-01, -1.8374e-01,  9.1531e-02,
          -7.9658e-01, -1.3272e+00, -4.8379e-01, -6.4224e-01,  6.0472e-01,
           3.9351e-01, -3.6342e-01, -3.2196e-01,  7.8311e-01,  7.0067e-01,
          -2.3011e-01, -4.0359e-01],
         [ 3.9979e-01,  2.9018e-01,  8.1399e-01,  2.9322e-01, -1.3265e-01,
          -2.0699e-01, -8.1407e-01,  1.4490e-01, -5.9691e-01, -4.4663e-01,
           2.9095e-01, -1.0618e-02,  1.8841e-01, -3.8759e-01,  2.5519e-01,
           3.8835e-02, -3.8375e-01, -8.2597e-01, -6.2724e-01, -8.9176e-02,
          -7.4729e-01, -5.4558e-01, -1.0391e+00,  4.0443e-01, -4.5633e-01,
          -9.0184e-02, -4.9872e-01],
         [-5.0269e-01, -4.2230e-01, -8.3734e-01, -8.9921e-01, -3.3879e-01,
           2.1711e-01,  6.0008e-01,  6.3464e-01,  5.6448e-01,  1.2675e-02,
   

In [32]:
logit_maxes.shape, logit_maxes

(torch.Size([32, 1]),
 tensor([[0.9708],
         [0.8140],
         [1.2211],
         [0.6820],
         [1.7148],
         [0.9150],
         [0.7778],
         [1.5061],
         [1.0722],
         [1.0249],
         [1.7570],
         [2.0432],
         [1.0086],
         [0.8913],
         [0.5998],
         [0.7661],
         [0.9427],
         [0.7801],
         [1.0086],
         [0.8390],
         [0.8303],
         [1.0503],
         [1.0086],
         [1.1103],
         [1.5602],
         [0.9252],
         [1.1188],
         [0.8919],
         [0.9395],
         [0.7835],
         [1.1258],
         [0.8638]], grad_fn=<MaxBackward0>))

In [33]:
norm_logits.shape, norm_logits

(torch.Size([32, 27]),
 tensor([[-0.1995, -0.0381, -1.4809, -0.5595, -1.4747,  0.0000, -1.2677, -0.8730,
          -1.5805, -0.8114, -0.8342, -0.8319, -0.8038, -1.1545, -0.8792, -1.7674,
          -2.2979, -1.4546, -1.6130, -0.3661, -0.5773, -1.3342, -1.2927, -0.1877,
          -0.2701, -1.2009, -1.3744],
         [-0.4142, -0.5238,  0.0000, -0.5208, -0.9466, -1.0210, -1.6281, -0.6691,
          -1.4109, -1.2606, -0.5230, -0.8246, -0.6256, -1.2016, -0.5588, -0.7752,
          -1.1977, -1.6400, -1.4412, -0.9032, -1.5613, -1.3596, -1.8531, -0.4096,
          -1.2703, -0.9042, -1.3127],
         [-1.7237, -1.6433, -2.0584, -2.1203, -1.5598, -1.0039, -0.6210, -0.5864,
          -0.6566, -1.2084, -1.7411, -1.2391, -1.0512, -0.8983, -1.5063, -1.4009,
          -2.0944, -1.1996, -1.3850,  0.0000, -0.6337, -1.1380, -1.1169, -1.0992,
          -1.1033, -1.7579, -1.4735],
         [-0.7926, -1.0046, -0.4975, -0.2862, -0.2143, -0.9017, -0.4309, -0.5628,
          -0.4070, -1.2227, -0.5791, -1.089

In [34]:
counts.shape, counts

(torch.Size([32, 27]),
 tensor([[0.8191, 0.9626, 0.2274, 0.5715, 0.2288, 1.0000, 0.2815, 0.4177, 0.2059,
          0.4442, 0.4342, 0.4352, 0.4476, 0.3152, 0.4151, 0.1708, 0.1005, 0.2335,
          0.1993, 0.6935, 0.5614, 0.2634, 0.2745, 0.8289, 0.7633, 0.3009, 0.2530],
         [0.6609, 0.5923, 1.0000, 0.5941, 0.3880, 0.3602, 0.1963, 0.5122, 0.2439,
          0.2835, 0.5927, 0.4384, 0.5349, 0.3007, 0.5719, 0.4606, 0.3019, 0.1940,
          0.2366, 0.4053, 0.2099, 0.2568, 0.1567, 0.6639, 0.2807, 0.4049, 0.2691],
         [0.1784, 0.1933, 0.1277, 0.1200, 0.2102, 0.3664, 0.5374, 0.5563, 0.5186,
          0.2987, 0.1753, 0.2897, 0.3495, 0.4073, 0.2217, 0.2464, 0.1231, 0.3013,
          0.2503, 1.0000, 0.5306, 0.3205, 0.3273, 0.3331, 0.3318, 0.1724, 0.2291],
         [0.4527, 0.3662, 0.6080, 0.7511, 0.8071, 0.4059, 0.6499, 0.5696, 0.6657,
          0.2944, 0.5604, 0.3363, 0.4993, 0.5338, 0.8175, 1.0000, 0.3456, 0.3861,
          0.2510, 0.7380, 0.2387, 0.3348, 0.5102, 0.2953, 0.3844, 0.6569

In [35]:
counts_sum.shape, counts_sum

(torch.Size([32, 1]),
 tensor([[11.8491],
         [11.1105],
         [ 8.7165],
         [14.0813],
         [ 5.7018],
         [12.1429],
         [12.4093],
         [ 7.7319],
         [10.7903],
         [10.8823],
         [ 6.5783],
         [ 5.1394],
         [10.9708],
         [11.8559],
         [16.1933],
         [12.3383],
         [11.7193],
         [12.3621],
         [10.9708],
         [11.2925],
         [12.5316],
         [10.9061],
         [10.9708],
         [ 9.6062],
         [ 6.4158],
         [12.5034],
         [ 8.4255],
         [11.3254],
         [12.4496],
         [13.3811],
         [11.6646],
         [12.5479]], grad_fn=<SumBackward1>))

In [36]:
counts_sum_inv.shape, counts_sum_inv

(torch.Size([32, 1]),
 tensor([[0.0844],
         [0.0900],
         [0.1147],
         [0.0710],
         [0.1754],
         [0.0824],
         [0.0806],
         [0.1293],
         [0.0927],
         [0.0919],
         [0.1520],
         [0.1946],
         [0.0912],
         [0.0843],
         [0.0618],
         [0.0810],
         [0.0853],
         [0.0809],
         [0.0912],
         [0.0886],
         [0.0798],
         [0.0917],
         [0.0912],
         [0.1041],
         [0.1559],
         [0.0800],
         [0.1187],
         [0.0883],
         [0.0803],
         [0.0747],
         [0.0857],
         [0.0797]], grad_fn=<PowBackward0>))

In [37]:
probs.shape, probs

(torch.Size([32, 27]),
 tensor([[0.0691, 0.0812, 0.0192, 0.0482, 0.0193, 0.0844, 0.0238, 0.0353, 0.0174,
          0.0375, 0.0366, 0.0367, 0.0378, 0.0266, 0.0350, 0.0144, 0.0085, 0.0197,
          0.0168, 0.0585, 0.0474, 0.0222, 0.0232, 0.0700, 0.0644, 0.0254, 0.0214],
         [0.0595, 0.0533, 0.0900, 0.0535, 0.0349, 0.0324, 0.0177, 0.0461, 0.0220,
          0.0255, 0.0533, 0.0395, 0.0481, 0.0271, 0.0515, 0.0415, 0.0272, 0.0175,
          0.0213, 0.0365, 0.0189, 0.0231, 0.0141, 0.0598, 0.0253, 0.0364, 0.0242],
         [0.0205, 0.0222, 0.0146, 0.0138, 0.0241, 0.0420, 0.0617, 0.0638, 0.0595,
          0.0343, 0.0201, 0.0332, 0.0401, 0.0467, 0.0254, 0.0283, 0.0141, 0.0346,
          0.0287, 0.1147, 0.0609, 0.0368, 0.0375, 0.0382, 0.0381, 0.0198, 0.0263],
         [0.0321, 0.0260, 0.0432, 0.0533, 0.0573, 0.0288, 0.0462, 0.0405, 0.0473,
          0.0209, 0.0398, 0.0239, 0.0355, 0.0379, 0.0581, 0.0710, 0.0245, 0.0274,
          0.0178, 0.0524, 0.0170, 0.0238, 0.0362, 0.0210, 0.0273, 0.0467

In [38]:
logprobs.shape, logprobs

(torch.Size([32, 27]),
 tensor([[-2.6717, -2.5104, -3.9532, -3.0318, -3.9470, -2.4723, -3.7400, -3.3452,
          -4.0527, -3.2837, -3.3064, -3.3041, -3.2760, -3.6268, -3.3515, -4.2396,
          -4.7702, -3.9268, -4.0853, -2.8383, -3.0495, -3.8064, -3.7650, -2.6599,
          -2.7424, -3.6731, -3.8466],
         [-2.8221, -2.9317, -2.4079, -2.9287, -3.3545, -3.4289, -4.0360, -3.0770,
          -3.8188, -3.6685, -2.9309, -3.2325, -3.0335, -3.6095, -2.9667, -3.1830,
          -3.6056, -4.0478, -3.8491, -3.3111, -3.9692, -3.7675, -4.2610, -2.8175,
          -3.6782, -3.3121, -3.7206],
         [-3.8890, -3.8086, -4.2236, -4.2855, -3.7251, -3.1692, -2.7862, -2.7516,
          -2.8218, -3.3736, -3.9063, -3.4043, -3.2164, -3.0635, -3.6715, -3.5661,
          -4.2596, -3.3648, -3.5502, -2.1652, -2.7989, -3.3032, -3.2822, -3.2644,
          -3.2685, -3.9231, -3.6387],
         [-3.4374, -3.6495, -3.1424, -2.9311, -2.8592, -3.5466, -3.0758, -3.2076,
          -3.0518, -3.8676, -3.2239, -3.734

In [39]:
loss.shape, loss

(torch.Size([]), tensor(3.3277, grad_fn=<NegBackward0>))

In [41]:
l = -logprobs[range(n), Yb]

In [42]:
l.shape, l

(torch.Size([32]),
 tensor([4.0527, 2.9667, 3.5661, 3.3178, 4.0940, 3.5239, 3.0496, 4.0385, 3.2198,
         4.2614, 3.1726, 1.6369, 2.8379, 2.9222, 3.0287, 3.1440, 3.9667, 3.0623,
         3.6613, 3.2564, 2.8960, 2.9260, 4.2945, 3.9861, 3.5233, 2.8240, 2.9169,
         3.8597, 2.8218, 3.3063, 3.2330, 3.1199], grad_fn=<NegBackward0>))

In [44]:
Yb

tensor([ 8, 14, 15, 22,  0, 19,  9, 14,  5,  1, 20,  3,  8, 14, 12,  0, 11,  0,
        26,  9, 25,  0,  1,  1,  7, 18,  9,  3,  5,  9,  0, 18])

In [46]:
-logprobs[1, 14]

tensor(2.9667, grad_fn=<NegBackward0>)

In [53]:
logprobs.grad

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.0312,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000

In [71]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
dlogprobs

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.0312,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000

In [None]:
# Exercise 1: backprop through the whole thing manually, 
# backpropagating through exactly all of the variables 
# as they are defined in the forward pass above, one by one


In [83]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
cmp('dlogprops', dlogprobs, logprobs)

dprobs = 1.0 / probs * dlogprobs
cmp('dprops', dprobs, probs)


dlogprops       | exact: True  | approximate: True  | maxdiff: 0.0
dprops          | exact: True  | approximate: True  | maxdiff: 0.0


In [79]:
dprobs = 1.0 /probs * dlogprobs
dprobs

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -1.7985,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.6071,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1056,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000

In [78]:
probs.grad

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -1.7985,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.6071,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1056,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000

In [None]:
dcounts_sum_inv = 