In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns


sns.set()

In [2]:
words = open('names.txt','r').read().splitlines()

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

### Slitting the dataset into training, test and validation 

In [33]:
def build_dataset(words:list):
    '''Seperates the data into training, test and validation'''
    block_size = 3 # Context Length: How many characters we take as input to predict the output 
    X,Y = [],[]
    
    for w in words:
        #print(w)
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            #print(''.join(itos[i] for i in context), '---->', itos[ix])
            context = context[1:] + [ix]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X,Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
Xtr,Ytr = build_dataset(words[:n1])
Xdev,Ydev = build_dataset(words[n1:n2])
Xte,Yte = build_dataset(words[n2:])

In [34]:
print(f'Training examples shapes: {Xtr.shape},{Ytr.shape}')
print(f'Validation examples shapes: {Xdev.shape},{Ydev.shape}')
print(f'Testing examples shapes: {Xte.shape},{Yte.shape}')

Training examples shapes: torch.Size([182580, 3]),torch.Size([182580])
Validation examples shapes: torch.Size([22767, 3]),torch.Size([22767])
Testing examples shapes: torch.Size([22799, 3]),torch.Size([22799])


In [37]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [39]:
class Linear:
    def __init__(self,fan_in:int,fan_out:int,bias:bool = True):
        self.weight = torch.randn((fan_in,fan_out),generator = g) / fan_in**0.5
        self.bias = torch.zeros(fan_out) if bias is True else None

    def __call__(self,x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

class BatchNormalization:
    def __init__(self,dim,eps = 1e-5, momentum = 0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True

        # Parameters: Trained during Back-Propogation
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

        # Parameters: Updated using running mean and std
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self,x):
        # Calculate the forward pass
        if self.training:
            xmean = x.mean(0,keepdim = True)
            xvar = x.var(0, keepdim = True, unbiased = True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # Normalization to unit variance
        self.out = self.gamma * xhat + self.beta

        # Updating the running mean and variance
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma , self.beta]

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []

In [41]:
def cmp(s,dt,t):
    '''Utility function to compare the gradients'''
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt,t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'| Exact : {ex} | Absolute : {app} | Max Difference : {maxdiff}')

In [43]:
# Generating the Neural Network
n_embed = 10
n_hidden = 100
g = torch.Generator().manual_seed(21644343223)
block_size = 3
vocab_size = 27
C = torch.randn((vocab_size,n_embed), generator = g)
layers = [
    Linear(n_embed*block_size,n_hidden),
    BatchNormalization(n_hidden),
    Tanh(),
    Linear(n_hidden,n_hidden),
    BatchNormalization(n_hidden),
    Tanh(),
    Linear(n_hidden,n_hidden),
    BatchNormalization(n_hidden),
    Tanh(),
    Linear(n_hidden,n_hidden),
    BatchNormalization(n_hidden),
    Tanh(),
    Linear(n_hidden,n_hidden),
    BatchNormalization(n_hidden),
    Tanh(),
    Linear(n_hidden,vocab_size),
    BatchNormalization(vocab_size),
]

with torch.no_grad():
    layers[-1].gamma *=  0.1
    for layer in layers[:-1]:
        if isinstance(layer,Linear):
            layer.weight *= 5/3

parameters = [C] + [p for layer in layers for p in layer.parameters()]
print(sum(p.nelement() for p in  parameters))

for p in parameters:
    p.requires_grad = True

47551


In [45]:
batch_size = 32
max_steps = 200000
losses = []

for i in range(max_steps):
    # Mini-batch construct
    ix = torch.randint(0,Xtr.shape[0],(batch_size, ), generator = g)
    Xb,yb = Xtr[ix],Ytr[ix]

    # ----- FORWARD PASS ----- #
    emb = C[Xb] # Embed the characters into vectors
    x = emb.view(emb.shape[0],-1) # Reshape the embedding vector into 2-dimensional matrix
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x,yb) # Categorical Cross-Entropy as the loss function

    # ----- BACKWARD PASS ----- #
    for layer in layers:
        layer.out.retain_grad()
    for p in parameters:
        p.grad = None
    loss.backward()

    # ----- UPDATE ----- #
    lr = 0.1 if i < (max_steps // 2) else 0.01
    for p in parameters:
        p.data -= lr * p.grad

    # ----- PRINTING LOSS ----- #
    if (i % 10000 == 0):
        print(f'{i}/{max_steps} -> {loss.item()}')
    losses.append(loss.log10().item())
    break

0/200000 -> 3.3125298023223877


In [47]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP
vocab_size = 27
block_size = 3
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [49]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [51]:
#forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.4921, grad_fn=<NegBackward0>)

# Excercise-1: Below code has hard-coded derivatives of all variables

1. `dlogprobs`

Now `logprobs` is just calculating the loss using the following formula $$loss = \frac{-1}{n} \sum Y_{true} \times log(Y_{pred})$$ where $Y_{true}$ just provides the index of the true value

Effectively what the loss is actually $$loss = \frac{-1}{n} log(Y_{pred})$$ where $Y_{pred}$ has a shape of $batch \times vocab$

In [55]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[torch.arange(n),Yb] = -1.0/n

cmp('logprobs',dlogprobs,logprobs)


| Exact : True | Absolute : True | Max Difference : 0.0


2. `probs`

Now derivative of probs w.r.t loss is given by 
$$\frac{ \delta L}{\delta probs} = \frac{\delta L}{\delta logprobs} \times \frac{\delta logprobs}{\delta probs}$$

We have already calculated $\frac{\delta L}{\delta logprobs}$. And $\frac{\delta logprobs}{\delta probs}$ is given by
$$
\frac{\delta logprobs}{\delta probs} = \frac{1}{probs}
$$


In [58]:
dprobs = torch.zeros_like(probs)
dprobs = 1.0/probs * dlogprobs

cmp('probs',dprobs,probs)

| Exact : True | Absolute : True | Max Difference : 0.0


3. `dcounts_sum_inv`

Now derivative of `dcounts_sum_inv` w.r.t loss is given by
$$
\frac{ \delta L}{\delta csi} = \frac{\delta L}{\delta probs} \times \frac{\delta probs}{\delta csi}
$$

Therefore $\frac{\delta probs}{\delta csi}$ is given by
$$
\frac{\delta probs}{\delta csi} = counts
$$



In [61]:
dcounts_sum_inv = torch.zeros_like(counts_sum_inv)
dcounts_sum_inv = (counts * dprobs).sum(1,keepdims = True)

cmp('d_counts_inv',dcounts_sum_inv,counts_sum_inv)

| Exact : True | Absolute : True | Max Difference : 0.0


4 `dcounts`

Same as `dcounts_sum_inv` we can calculate `dcounts` as follows
$$
\frac{ \delta L}{\delta counts} = \frac{\delta L}{\delta probs} \times \frac{\delta probs}{\delta counts}
$$

Therefore $\frac{\delta probs}{\delta counts}$ is given by
$$
\frac{\delta probs}{\delta csi} = csi
$$


In [64]:
dcounts = torch.zeros_like(counts)

dcounts  = counts_sum_inv * dprobs

cmp('dcounts',dcounts,counts)

| Exact : False | Absolute : False | Max Difference : 0.005815508309751749


5 `dcounts_sum`

In [67]:
dcounts_sum = torch.zeros_like(dcounts)

dcounts_sum = -1/(counts_sum**2) * dcounts_sum_inv

cmp('dcounts_sum', dcounts_sum,counts_sum)

| Exact : True | Absolute : True | Max Difference : 0.0


6 `dcounts_sum`

In [70]:
dcounts += torch.ones_like(counts) * dcounts_sum

cmp('dcounts',dcounts,counts)

| Exact : True | Absolute : True | Max Difference : 0.0


7 `dnorm_logits`

In [73]:
dnorm_logits = torch.zeros_like(norm_logits)

dnorm_logits = counts * dcounts

cmp('dnorm_logits',dnorm_logits,norm_logits)

| Exact : True | Absolute : True | Max Difference : 0.0


8 `dlogits`

In [78]:
dlogits = torch.zeros_like(logits)
dlogits = dnorm_logits + F.one_hot(logits.max(1).indices,num_classes = logits.shape[1]) * dlogits_maxes


cmp('dlogits',dlogits,logits)

| Exact : True | Absolute : True | Max Difference : 0.0


9 `dlogits_maxes`

In [76]:
dlogits_maxes = torch.zeros_like(dlogits)
dlogits_maxes = (-dnorm_logits).sum(1,keepdims = True)

cmp('dlogit_maxes',dlogits_maxes,logit_maxes)

| Exact : True | Absolute : True | Max Difference : 0.0


10 `dh`

In [None]:
print(h.shape,W2.shape,b2.shape,dlogits.shape)

In [80]:
dh = torch.zeros_like(h)
dh = dlogits @ torch.transpose(W2,0,1)

cmp('dh',dh,h)

| Exact : True | Absolute : True | Max Difference : 0.0


11 `W2`

In [83]:
dW2 = torch.zeros_like(W2)

dW2 = torch.transpose(h,0,1) @ dlogits

cmp('dW2',dW2,W2)

| Exact : True | Absolute : True | Max Difference : 0.0


12 `db2`

In [86]:
db2 = dlogits.sum(0)

cmp('db2',db2,b2)

| Exact : True | Absolute : True | Max Difference : 0.0


13 `dhpreact`

In [89]:
dhpreact = (torch.ones_like(hpreact) - h**2)*dh
cmp('dhpreact',dhpreact,hpreact)

| Exact : True | Absolute : True | Max Difference : 0.0


14 `dbngain`

In [92]:
dbngain = torch.zeros_like(bngain)
dbngain = (bnraw * dhpreact).sum(0,keepdims = True)
cmp('dbngain',dbngain,bngain)

| Exact : True | Absolute : True | Max Difference : 0.0


15 `dbnraw`

In [95]:
dbnraw = torch.zeros_like(bnraw)

dbnraw = bngain * dhpreact

cmp('dbnraw',dbnraw,bnraw)
# print(dbnraw.shape,dhpreact.shape, bngain.shape,(bngain*dhpreact).shape)

| Exact : True | Absolute : True | Max Difference : 0.0


16 `dbnbias`

In [98]:
dbnbias = torch.zeros_like(bnbias)

dbnbias = dhpreact.sum(0,keepdims = True)

cmp('dbnbias',dbnbias,bnbias)
# print(dhpreact.sum(0,keepdims = True).shape,dbnbias.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


17 `dbndiff`

In [112]:
dbndiff = torch.zeros_like(bndiff)

dbndiff = bnvar_inv * dbnraw 
dbndiff += (2*bndiff) * dbndiff2

cmp('dbndiff',dbndiff,bndiff)
# print(dbndiff.shape,dbnraw.shape,bnvar_inv.shape,bnvar.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


18 `dbnvar_inv`

In [104]:
dbnvar_inv = (bndiff * dbnraw).sum(0,keepdims = True)

cmp('dbnvar_inv',dbnvar_inv, bnvar_inv)

| Exact : True | Absolute : True | Max Difference : 0.0


19 `dbnvar`

In [107]:
dbnvar = torch.zeros_like(bnvar)
dbnvar = (-0.5 * (bnvar + 1e-5)**(-1.5)) * dbnvar_inv
# print(dbnvar.shape,bnvar_inv.shape)

cmp('dbnvar',dbnvar,bnvar)

| Exact : True | Absolute : True | Max Difference : 0.0


20 `dbndiff2`

In [110]:
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar

cmp('dbndiff2',dbndiff2,bndiff2)

| Exact : True | Absolute : True | Max Difference : 0.0


21 `dbmneai`

In [114]:
dbnmeani = torch.zeros_like(bnmeani)

dbnmeani = -1.0 * dbndiff.sum(0,keepdims = True)

cmp('dbnmeani',dbnmeani,bnmeani)
# print(dbnmeani.shape,dbndiff.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


22 `dbhprebn`

In [117]:
dbhprebn = torch.zeros_like(hprebn)

dbhprebn = dbndiff + (1/n)*dbnmeani

cmp('dbhprebn',dbhprebn,hprebn)
# print(dbhprebn.shape,bnmeani.shape,bndiff.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


23 `dbembcat`

In [120]:
dbembcat = torch.zeros_like(embcat)

dbembcat = dbhprebn @ torch.transpose(W1,0,1)


cmp('dbembcat',dbembcat,embcat)
# print(dbembcat.shape, dbhprebn.shape, W1.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


24 `dW1`

In [123]:
dW1 = torch.zeros_like(W1)

dW1 = torch.transpose(embcat,0,1) @ dbhprebn

cmp('dW1',dW1,W1)
# print(dbembcat.shape, dbhprebn.shape, W1.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


25 `db1`

In [126]:
db1 = torch.zeros_like(b1)

db1 = dbhprebn.sum(0,keepdims = True)


cmp('db1',db1,b1)
# print(db1.shape,dbhprebn.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


26 `demb`

In [135]:
dbemb = torch.zeros_like(emb)

dbemb = dbembcat.view(dbembcat.shape[0],block_size,-1)


cmp('dbemb',dbemb,emb)
# print(emb.shape,dbembcat.shape)

| Exact : True | Absolute : True | Max Difference : 0.0


27 `dC`

In [140]:
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += dbemb[k,j]

cmp('dC',dC,C)

| Exact : True | Absolute : True | Max Difference : 0.0


# Excecise-2: Using the `cross entropy` loss for the loss

In [143]:
loss_fast = F.cross_entropy(logits,Yb)

print('fast_loss : ',loss_fast.item(),'diff : ',(loss_fast - loss).abs().item())

fast_loss :  3.4920926094055176 diff :  0.0


We need to calculate $\frac {\delta \mathbf{L}}{\delta logits}$. The forumla for loss is given by
$$loss = \frac{-1}{n} \sum Y_{true} \times log(Y_{pred})$$

In [146]:
dlogits = F.softmax(logits,1)
dlogits[range(n),Yb] -= 1
dlogits /= n

cmp('dlogits',dlogits,logits)

| Exact : False | Absolute : True | Max Difference : 5.3551048040390015e-09


# Excercise-3: Faster Batch Normalization

In [148]:
dhprebn = bngain*bnvar_inv/n * (n * dhpreact - dhpreact.sum(0) - n/(n-1) * bnraw*(dhpreact*bnraw).sum(0))

cmp('dhprebn',dhprebn,hprebn)

| Exact : False | Absolute : True | Max Difference : 4.656612873077393e-10
