In [1]:
import pandas as pd
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Dataset Exploration and Modification

In [2]:
# Load the dinosaur names and normalise them (lower-case, no surrounding whitespace).
dinosaur_csv = pd.read_csv("dinosaur.csv")
dinosaurs = dinosaur_csv['Name'].tolist()
dinosaurs = [d.lower().strip() for d in dinosaurs]

# Shuffle with a dedicated, seeded RNG instead of the unseeded global one:
# the train/dev/test split is built from this list, so its order has to be
# the same on every run for the results to be reproducible.
random.Random(42).shuffle(dinosaurs)
dinosaurs[:5]

['xinjiangovenator',
 'deinocheirus',
 'amargatitanis',
 'dryptosauroides',
 'leaellynasaura']

In [5]:
# Vocabulary: every distinct character that occurs in the names, sorted.
chars = sorted(set(''.join(dinosaurs)))

# char -> integer id. Real characters get ids 1..len(chars);
# id 0 is reserved for '.', the special start/end marker.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0

# integer id -> char (the inverse of stoi)
itos = {code: ch for ch, code in stoi.items()}

vocab_size = len(itos)
print(vocab_size)
print("stoi: ", stoi)
print("itos: ", itos)

27
stoi:  {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
itos:  {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


# Dataset Initialization and Splitting

In [7]:
# Seeded generator so the random split below is the same on every run.
g = torch.Generator()
g.manual_seed(42)

# Partition the names into three disjoint subsets:
#   train (80%)          - used to fit the model parameters
#   dev/validation (10%) - used to tune hyperparameters
#   test (10%)           - used only to evaluate the model
train, dev, test = torch.utils.data.random_split(
    dinosaurs,
    lengths=[0.8, 0.1, 0.1],
    generator=g,
)

# context length: how many previous characters are used to predict the next one
block_size = 3

# function to create inputs and outputs from dataset
def build_dataset(dataset):
    """Turn an iterable of names into (context, next-character) pairs.

    Each name is terminated with '.', and a sliding window holding the ids of
    the previous `block_size` characters (left-padded with 0, the id of '.')
    is paired with the id of the character that follows it.

    Args:
        dataset: iterable of strings whose characters are all keys of `stoi`.

    Returns:
        (X, Y): X is an int64 tensor of shape (N, block_size) with the
        contexts, Y an int64 tensor of shape (N,) with the target ids.
        An empty dataset yields shapes (0, block_size) and (0,).

    Raises:
        KeyError: if a name contains a character that is not in `stoi`.
    """
    X, Y = [], []
    for d in dataset:
        context = [0] * block_size # create context 'window'
        for char in d + '.':
            idx = stoi[char] # get current character index

            X.append(context) # add current context to inputs
            Y.append(idx) # add current index to label
            context = context[1:] + [idx] # slide context 'window'

    # Pin dtype and shape explicitly: for an empty split torch.tensor([])
    # would be a float32 tensor of shape (0,), which cannot be used to index
    # the embedding table.
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y
    

# Build the (inputs, labels) tensors for each of the three splits ...
Xtr, Ytr = build_dataset(train)   # training
Xdev, Ydev = build_dataset(dev)   # validation
Xte, Yte = build_dataset(test)    # test

# ... then report their shapes.
print("TRAIN: ", Xtr.shape, Ytr.shape)
print("DEV: ", Xdev.shape, Ydev.shape)
print("TEST: ", Xte.shape, Yte.shape)

TRAIN:  torch.Size([11888, 3]) torch.Size([11888])
DEV:  torch.Size([1451, 3]) torch.Size([1451])
TEST:  torch.Size([1517, 3]) torch.Size([1517])


# Compare Gradients

In [8]:
# utility function to compare manual gradients to PyTorch gradients

def cmp(s, dt, t): # dt are manual gradients; t.grad are PyTorch gradients
    """Print how a hand-derived gradient compares with the autograd one.

    Args:
        s: label printed at the start of the report line.
        dt: manually computed gradient tensor.
        t: tensor whose `.grad` (filled in by `backward()`) is the reference.

    Raises:
        ValueError: if `t.grad` is None, i.e. no gradient was stored for `t`.
    """
    grad = t.grad
    if grad is None:
        raise ValueError(
            f"{s}: t.grad is None - call t.retain_grad() (or set requires_grad) "
            "before loss.backward()"
        )

    # A gradient must have exactly the shape of its tensor. Without this check
    # a wrongly shaped (but broadcastable) dt would be broadcast by the
    # comparisons below and could be reported as an exact match.
    if dt.shape != grad.shape:
        print(f'{s:15s} | exact: {str(False):5s} | approximate: {str(False):5s} | '
              f'shape mismatch: manual {tuple(dt.shape)} vs autograd {tuple(grad.shape)}')
        return

    ex = torch.all(dt == grad).item() # True only if every element is bit-for-bit equal

    # True if |dt - grad| <= atol + rtol * |grad| element-wise,
    # with torch.allclose's defaults rtol=1e-05 and atol=1e-08
    app = torch.allclose(dt, grad)

    maxdiff = (dt - grad).abs().max().item() # largest absolute element-wise difference

    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

# Initialize Model

In [9]:
n_embed = 10 # dimensionality of character embedding vectors
n_hidden = 64 # number of neurons in the hidden layer

# Every parameter is drawn from the seeded generator `g`, so the
# initialisation is identical on every run.

C = torch.randn((vocab_size, n_embed),             generator=g) # embedding lookup table, one row per character

# layer 1
# weights scaled by (5/3) / sqrt(fan_in), where fan_in = n_embed * block_size
W1 = torch.randn((n_embed * block_size ,n_hidden), generator=g) * (5/3) / ((n_embed * block_size)**0.5) # weights matrix
# b1 has no effect on the output: the batch norm that follows subtracts the
# per-feature batch mean, which cancels any constant offset
b1 = torch.randn(n_hidden,                         generator=g) * 0.01

# layer 2
W2 = torch.randn((n_hidden, vocab_size),           generator=g) * 0.01 # weights matrix
b2 = torch.randn(vocab_size,                       generator=g) * 0.01 # bias

# batch norm params: gain close to 1 and bias close to 0, with a little noise
batchn_gain = torch.randn((1, n_hidden),           generator=g) * 0.1 + 1.0
batchn_bias = torch.randn((1, n_hidden),           generator=g) * 0.1


params = [C,W1,b1,W2,b2,batchn_gain, batchn_bias]

print("# of parameters: ", sum(p.nelement() for p in params))
for p in params:
    p.requires_grad = True # let autograd track gradients for every parameter

# of parameters:  4137


In [10]:
batch_size = 32
n = batch_size # short alias used throughout the forward/backward pass

# draw `n` random row indices into the training set (seeded via g)
idx = torch.randint(low=0, high=Xtr.shape[0], size=(n,), generator=g)

# gather the mini-batch inputs and labels
Xb = Xtr[idx]
Yb = Ytr[idx]

# Expanded Forward Pass 

In [126]:
# forward pass (expanded because of manual backprop)
# Every intermediate value gets its own name so that the manual backward pass
# can compute, and check, a gradient for each one.

emb = C[Xb] # get character embeddings: (n, block_size, n_embed)
emb_concat = emb.view(emb.shape[0], -1) # concat vectors: (n, block_size * n_embed)

# linear layer 1
hprebn = emb_concat @ W1 + b1 # hidden layer pre-activation pre-batchnorm: (n, n_hidden)

# batch norm layer (same as gamma * (x - mean) / (std) + beta)
bnmeani = 1 / n * hprebn.sum(0, keepdims=True) # per-feature mean over the batch: (1, n_hidden)
bndiff = hprebn - bnmeani # centred activations
bndiff2 = bndiff**2
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdims=True) # variance with Bessel's correction (divide by n - 1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5 # 1 / sqrt(var + eps); the 1e-5 guards against division by zero
bnraw = bndiff * bnvar_inv # normalised activations
hpreact = batchn_gain * bnraw + batchn_bias # scale and shift

# non-linear activation function
h = torch.tanh(hpreact) # hidden layer

# linear layer 2
logits = h @ W2 + b2 # output layer: (n, vocab_size)

# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdims=True).values # row-wise maximum: (n, 1)
norm_logits = logits - logit_maxes # subtract for exp stability. ex: [5, -3, 10] - 10 ==> [-5, -13, 0]
counts = norm_logits.exp() # unnormalised probabilities
counts_sum = counts.sum(1, keepdims=True) # row-wise normaliser: (n, 1)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv # softmax probabilities
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # mean negative log-likelihood of the correct characters

# PyTorch backward pass
for p in params:
    p.grad = None # clear any gradients left over from a previous run
# intermediate (non-leaf) tensors do not keep their .grad by default;
# retain_grad() keeps it so the manual gradients can be compared against it
for t in [ logprobs, probs, counts_sum_inv, counts_sum, counts, norm_logits, logit_maxes, logits, h,
           hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, bnmeani, hprebn, emb_concat, emb]:
    t.retain_grad()
loss.backward()
loss.item()

3.2855911254882812

# Yay! Manual Backpropagation

In [133]:
# backprop through every variable in the forward pass manually
# Convention: dX is the gradient of the loss with respect to X and has the
# same shape as X. The steps walk the forward pass in reverse order, and each
# result is checked against the autograd gradient with cmp().
# When a value is used more than once in the forward pass, its gradient is the
# sum of the contributions (hence the `+=` lines further down).

# logprobs
# loss = -mean(logprobs[i, Yb[i]]): each selected entry gets -1/n, all others 0
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
cmp('logprobs', dlogprobs, logprobs)

# probs
# logprobs = log(probs), and d/dp log(p) = 1/p
dprobs = (1.0 / probs) * dlogprobs
cmp('probs', dprobs, probs)

# counts_sum_inv
# probs = counts * counts_sum_inv, where counts_sum_inv is (n, 1) and was
# broadcast across the columns, so its gradient is summed over dim 1
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)
dcounts = counts_sum_inv * dprobs # first contribution to dcounts (second one is added below)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

# counts_sum
# counts_sum_inv = counts_sum**-1, and d/dx x**-1 = -x**-2
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
cmp('counts_sum', dcounts_sum, counts_sum)

# counts
# counts_sum = counts.sum(1): every element of a row receives that row's gradient
dcounts += torch.ones_like(counts) * dcounts_sum
cmp("counts", dcounts, counts)

# norm_logits
# counts = exp(norm_logits), and d/dx exp(x) = exp(x) = counts
dnorm_logits = counts * dcounts
dlogits = dnorm_logits.clone() # first contribution to dlogits (norm_logits = logits - logit_maxes)
cmp("norm_logits", dnorm_logits, norm_logits)

# logit_maxes
# logit_maxes is subtracted and broadcast across columns: negate and sum over dim 1
dlogit_maxes = (-dnorm_logits).sum(1, keepdims=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)

# logits (can also use one hot encoding)
# the gradient of a row-wise max flows only to the position of the maximum;
# despite its name, `ones` starts as zeros and becomes a one-hot mask of those positions
ones = torch.zeros_like(logits)
ones[range(n), logits.max(1).indices] = 1
dlogits +=  ones * dlogit_maxes
cmp('logits', dlogits, logits)

# h, W2, b2
# logits = h @ W2 + b2
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0) # b2 was broadcast over the batch dimension, so sum over it
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

# hpreact
# h = tanh(hpreact), and d/dx tanh(x) = 1 - tanh(x)**2 = 1 - h**2
dhpreact = (1.0 - h**2) * dh
cmp('hpreact', dhpreact, hpreact)

# batchn_gain
# hpreact = batchn_gain * bnraw + batchn_bias; the gain is (1, n_hidden) and
# was broadcast over the batch, so sum over dim 0
dbatchn_gain = (bnraw * dhpreact).sum(0, keepdims=True)
cmp('batchn_gain', dbatchn_gain, batchn_gain)

# bnraw
dbnraw = batchn_gain * dhpreact
cmp('bnraw', dbnraw, bnraw)

# batchn_bias
# the bias was broadcast over the batch as well, so sum over dim 0
dbatchn_bias = dhpreact.sum(0, keepdims=True)
cmp('batchn_bias', dbatchn_bias, batchn_bias)

# bnvar_inv
# bnraw = bndiff * bnvar_inv, with bnvar_inv broadcast over the batch
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdims=True)
dbndiff = bnvar_inv * dbnraw # first contribution to dbndiff (second one is added below)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

# bnvar
# bnvar_inv = (bnvar + 1e-5)**-0.5, and d/dx x**-0.5 = -0.5 * x**-1.5
dbnvar = (-0.5 * (bnvar + 1e-5)**-1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)

# bndiff2
# bnvar = 1/(n-1) * bndiff2.sum(0): every row receives dbnvar scaled by 1/(n-1)
dbndiff2 = (1.0 / (n-1)) * torch.ones_like(bndiff2) * dbnvar
cmp('bndiff2', dbndiff2, bndiff2)

# bndiff
# bndiff2 = bndiff**2, and d/dx x**2 = 2x; this is the second contribution to dbndiff
dbndiff += (2 * bndiff) * dbndiff2
cmp('bndiff', dbndiff, bndiff)

# bnmeani
# bndiff = hprebn - bnmeani
dhprebn = dbndiff.clone() # first contribution to dhprebn (second one is added below)
dbnmeani = (-dbndiff).sum(0, keepdims=True) # bnmeani was broadcast over the batch: negate and sum over dim 0
cmp('bnmeani', dbnmeani, bnmeani)

# hprebn
# bnmeani = 1/n * hprebn.sum(0): every row receives dbnmeani scaled by 1/n
dhprebn += (1.0 / n) * torch.ones_like(hprebn) * dbnmeani
cmp('hprebn', dhprebn, hprebn)

# emb_concat, W1, b1
# hprebn = emb_concat @ W1 + b1
demb_concat = dhprebn @ W1.T
dW1 = emb_concat.T @ dhprebn
db1 = dhprebn.sum(0) # b1 was broadcast over the batch dimension, so sum over it
cmp('emb_concat', demb_concat, emb_concat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)

# emb
# emb_concat is only a reshaped view of emb, so undo the reshape
demb = demb_concat.view(emb.shape)
cmp('emb', demb, emb)

# C
# emb = C[Xb]: a row of C can be looked up several times within the batch, so
# the gradients of all its uses are accumulated into that row
# (note: the loop reuses the name `idx`, overwriting the batch indices drawn earlier)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range (Xb.shape[1]):
        idx = Xb[k,j]
        dC[idx] += demb[k,j]
cmp('C', dC, C)

print("yay!")

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
batchn_gain     | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: 0.0
batchn_bias     | exact: True  | approximate: True  | maxdiff:

### This was way too much! Rather than backpropagating through every single step, let's optimize and handle separate parts of the forward pass as single steps

# Backprop Cross-Entropy