In [61]:
from typing import Union, Any
import matplotlib.pyplot as plt
%matplotlib inline

In [62]:
# Load the raw dataset; every line of names.txt is treated as one word.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
with open('names.txt', encoding='utf-8') as f:
    content = f.read()
words = content.splitlines()
# Last expression of the cell: displays how many words were loaded.
len(words)

32033

In [63]:
# Character vocabulary: every distinct character that occurs in the words,
# in sorted order.
chars = sorted(set(''.join(words)))
# Characters get the ids 1..len(chars); id 0 is reserved for '.', which
# build_dataset uses both as left padding and as the end-of-word marker.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
# Reverse lookup (id -> character), obtained by flipping stoi.
itos = dict((idx, ch) for ch, idx in stoi.items())
vocab_size = len(itos)

In [64]:
import torch

# Build dataset splits
def build_dataset(words: list[str], block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn words into (context, next-character) training examples.

    Each word is extended with the '.' terminator and scanned left to right.
    For every character, the ids of the ``block_size`` characters that precede
    it (left-padded with 0, the id of '.') form one input row, and the id of
    the character itself is the matching target.

    Args:
        words: Words to encode. Every character must be a key of the
            module-level ``stoi`` mapping.
        block_size: Number of preceding characters used as context (>= 1).

    Returns:
        ``(X, Y)`` where ``X`` is an int64 tensor of shape
        ``(num_examples, block_size)`` and ``Y`` is an int64 tensor of shape
        ``(num_examples,)``. An empty ``words`` list yields shapes
        ``(0, block_size)`` and ``(0,)``.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
        KeyError: If a word contains a character that is not in ``stoi``.

    Side effects:
        Prints the shapes of ``X`` and ``Y``.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    contexts, targets = [], []
    for w in words:
        # Every word starts from an all-padding context.
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            contexts.append(context)
            targets.append(ix)
            # Advance the rolling window: drop the oldest id, append the newest.
            # This builds a new list, so rows already stored are not mutated.
            context = context[1:] + [ix]

    # Explicit dtype and shape so that an empty dataset still comes out as an
    # integer (0, block_size) tensor instead of a float tensor of shape (0,).
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), block_size)
    Y = torch.tensor(targets, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

import random
# Fixed seed so the shuffle (and therefore the split below) is reproducible.
random.seed(42)
# NOTE(review): shuffles `words` in place. Re-running only this cell shuffles
# the already-shuffled list again and produces a different split; re-run the
# cell that loads `words` first.
random.shuffle(words)
# Split boundaries: the first 80% of the words are used for training, the
# next 10% for dev and the remaining 10% for test.
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

# Number of preceding characters the model sees when predicting the next one.
CONTEXT_SZ = 3
# build_dataset prints the (X, Y) shapes of each split as a side effect.
Xtr, Ytr = build_dataset(words[:n1], CONTEXT_SZ)
Xdev, Ydev = build_dataset(words[n1:n2], CONTEXT_SZ)
Xte, Yte = build_dataset(words[n2:], CONTEXT_SZ)

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


# Model

In [65]:
def cmp(s: str, dt: torch.Tensor, t: torch.Tensor) -> None:
    """Compare a manually derived gradient with the one autograd stored on ``t``.

    Prints a single report line with:
      * exact:   whether every element of ``dt`` equals ``t.grad``,
      * approx:  whether ``torch.allclose(dt, t.grad)`` holds,
      * maxdiff: the largest absolute element-wise difference,
      * shape:   whether ``dt`` and ``t`` have the same shape, and ``t``'s shape.

    If ``t`` carries no gradient, or the two tensors cannot be compared at all
    (for example their shapes do not broadcast), a diagnostic line is printed
    instead of raising, so a loop over many tensors keeps going.

    Args:
        s: Label printed at the start of the line (padded to 25 characters).
        dt: The manually computed gradient.
        t: The tensor whose ``.grad`` was populated by ``backward()``.
    """
    grad = t.grad
    if grad is None:
        # Happens when backward() has not run yet, or when retain_grad() was
        # not called on a non-leaf tensor.
        print(f"{s:25s} | no autograd gradient on tensor (t.grad is None) | shape: ({t.shape})")
        return
    try:
        ex = torch.all(dt == grad).item()
        app = torch.allclose(dt, grad)
        maxdiff = (dt - grad).abs().max().item()
    except RuntimeError as err:
        # Raised by torch for operands that cannot be combined, e.g. shapes
        # that do not broadcast against each other.
        print(f"{s:25s} | not comparable: {err} | shape: {t.shape == dt.shape} ({t.shape})")
        return
    print(f"{s:25s} | exact: {str(ex):5s} | approx: {str(app):5s} | maxdiff: {maxdiff} | shape: {t.shape == dt.shape} ({t.shape})")

In [66]:
class Linear:
    """A linear layer."""

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        generator: Union[torch.Generator, None] = None,
        weight_gain: float | None = None,
        bias_gain: float = 0.1,
    ) -> None:
        """Initialize Linear."""
        if weight_gain is None:
            weight_gain = (5/3) / (fan_in**0.5)
        self.weight = torch.randn((fan_in, fan_out), generator=generator) * weight_gain
        # Note: bias is spurious given we're using batch norm, but calculating anyway
        self.bias = torch.randn(fan_out, generator=g) * bias_gain
        self.out: Union[torch.tensor, None] = None

    def __call__(self, x: torch.tensor) -> torch.tensor:
        """Forward pass"""
        self.out = x @ self.weight + self.bias
        return self.out

    def parameters(self):
        return [self.weight, self.bias]
    
    def tensors(self) -> dict[str, torch.tensor]:
        return {
            "out": self.out,
            "weight": self.weight,
            "bias": self.bias,
        }

    def __str__(self) -> str:
        return self.__class__.__name__


class BatchNorm1d:
    """A batch normalization layer that exposes every intermediate tensor.

    In training mode the forward pass is split into small steps (mean,
    difference, squared difference, variance, inverse standard deviation,
    normalized value) and each step is stored on the instance so a manual
    backward pass can be checked against autograd one step at a time.

    In evaluation mode (``self.training = False``) the input is normalized with
    the running statistics and the running buffers are not updated.
    """

    def __init__(self, dim: int, eps: float = 1e-5, momentum: float = 0.001) -> None:
        """Initialize BatchNorm1d.

        Args:
            dim: Number of features (size of the last axis of the input).
            eps: Constant added to the variance before taking the inverse
                square root.
            momentum: Weight of the current batch statistics in the running
                mean / variance update.
        """
        self.eps = eps
        self.momentum = momentum
        self.training = True
        self.dim = dim
        # Adds small random numbers to unmask gradient errors
        self.bngain = torch.randn((1, dim)) * 0.1 + 1.0
        self.bnbias = torch.randn((1, dim)) * 0.1
        # Buffers, updated with a running momentum update (not by backprop)
        self.running_mean = torch.zeros((1, dim))
        self.running_var = torch.ones((1, dim))
        # Intermediates of the most recent training-mode forward pass. They are
        # defined up front so tensors() works before the first forward pass.
        self.xmean: Union[torch.Tensor, None] = None
        self.xdiff: Union[torch.Tensor, None] = None
        self.xdiff2: Union[torch.Tensor, None] = None
        self.xvar: Union[torch.Tensor, None] = None
        self.xvar_inv: Union[torch.Tensor, None] = None
        self.xraw: Union[torch.Tensor, None] = None
        self.out: Union[torch.Tensor, None] = None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Returns ``bngain * xraw + bnbias`` where ``xraw`` is the normalized
        input, and caches the result on ``self.out``.
        """
        if self.training:
            # NOTE(review): the sums over the batch axis are divided by
            # `self.dim` (the feature count) and `self.dim - 1`, not by the
            # batch size `x.shape[0]`. These are true batch statistics only
            # when the two are equal. The manual-backprop cell of this notebook
            # differentiates exactly this formula (it also uses `.dim`), so the
            # forward pass is kept as is; change both together.
            self.xmean = (1/self.dim) * x.sum(0, keepdim=True)
            self.xdiff = x - self.xmean
            self.xdiff2 = self.xdiff**2
            self.xvar = 1/(self.dim - 1)*(self.xdiff2).sum(0, keepdim=True) # Bessel's correction
            self.xvar_inv = (self.xvar + self.eps)**-0.5
            self.xraw = self.xdiff * self.xvar_inv

            # Update buffers; no_grad keeps them out of the autograd graph.
            with torch.no_grad():
                self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * self.xmean
                self.running_var = (1-self.momentum) * self.running_var + self.momentum * self.xvar
        else:
            # Evaluation: normalize with the running statistics. Only `xraw`
            # and `out` are refreshed; the other cached intermediates keep the
            # values of the last training-mode pass.
            self.xraw = (x - self.running_mean) * (self.running_var + self.eps)**-0.5

        self.out = self.bngain * self.xraw + self.bnbias
        return self.out

    def tensors(self) -> dict[str, torch.Tensor]:
        """Named intermediates, parameters and output of the last forward pass."""
        return {
            'xmean': self.xmean,
            'xvar': self.xvar,
            'xvar_inv': self.xvar_inv,
            'xdiff': self.xdiff,
            'xdiff2': self.xdiff2,
            'xraw': self.xraw,
            'bnbias': self.bnbias,
            'bngain': self.bngain,
            'out': self.out,
        }

    def parameters(self) -> list[torch.Tensor]:
        """Trainable tensors of this layer: ``[bngain, bnbias]``."""
        return [self.bngain, self.bnbias]

    def __str__(self) -> str:
        return self.__class__.__name__


class Nonlinearity:
    """Element-wise tanh activation that remembers its most recent output."""

    def __init__(self) -> None:
        """Create the tanh module; no output is cached until the first call."""
        self.out: Union[torch.tensor, None] = None
        self.act = torch.nn.Tanh()

    def __call__(self, x: torch.tensor) -> torch.tensor:
        """Apply tanh to ``x``, cache the result on ``self.out`` and return it."""
        activated = self.act(x)
        self.out = activated
        return activated

    def parameters(self):
        """This layer has nothing to train."""
        return list()

    def tensors(self) -> dict[str, torch.tensor]:
        """The only tensor this layer exposes is its cached output."""
        return dict(out=self.out)

    def __str__(self) -> str:
        return type(self).__name__

In [67]:
# Size of each character embedding vector.
EMBED_SZ = 10
# Width of the hidden layer.
HIDDEN_LAYER_SZ = 64

class Model:
    """Character-level MLP: embedding -> Linear -> BatchNorm1d -> tanh -> Linear.

    The softmax / log-likelihood is computed in small explicit steps and every
    intermediate tensor is stored on the instance, so that a hand-written
    backward pass can be compared against autograd tensor by tensor.
    """

    def __init__(self, vocab_size: int, generator: torch.Generator = None) -> None:
        """Build the embedding table and the layer stack.

        Args:
            vocab_size: Number of distinct tokens (rows of the embedding table
                and width of the output layer).
            generator: Random generator used for the embedding table and
                passed on to the Linear layers. ``None`` uses torch's default
                generator.
        """
        # The `generator` argument is used throughout (not a notebook-level
        # global), so the class works on a fresh kernel and honours the seed
        # the caller passes in.
        self.C = torch.randn((vocab_size, EMBED_SZ), generator=generator)
        self.layers = [
            Linear(EMBED_SZ * CONTEXT_SZ, HIDDEN_LAYER_SZ, generator=generator),
            BatchNorm1d(HIDDEN_LAYER_SZ),
            Nonlinearity(),
            Linear(HIDDEN_LAYER_SZ, vocab_size, generator=generator),
        ]

        for layer in self.layers:
            layer.training = True

        # Reset parameters for training
        for p in self.parameters():
            p.requires_grad = True

        # Intermediates of the most recent forward pass (None until it runs).
        self.emb = None
        self.embcat = None
        self.logits = None
        self.logits_maxes = None
        self.norm_logits = None
        self.counts = None
        self.counts_sum = None
        self.counts_sum_inv = None
        self.probs = None
        self.logprobs = None

    def parameters(self) -> list[torch.Tensor]:
        """The embedding table followed by the parameters of every layer."""
        return [self.C] + [p for layer in self.layers for p in layer.parameters()]

    @property
    def num_parameters(self) -> int:
        """Total number of parameters."""
        return sum(p.nelement() for p in self.parameters())

    def forward_pass(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of contexts and return the logits.

        Besides returning the logits, this stores the step-by-step softmax
        intermediates (up to ``self.logprobs``) that ``loss`` and ``tensors``
        read.

        Args:
            x: Integer tensor of token ids, one row of context per example.
        """
        self.emb = self.C[x]  # (N, CONTEXT_SZ, EMBED_SZ)
        self.embcat = self.emb.view(self.emb.shape[0], -1)  # concatenate the vectors

        xhat = self.embcat
        for layer in self.layers:
            xhat = layer(xhat)

        self.logits = xhat

        # Explicit implementation of the softmax. The row maximum is subtracted
        # first so that exp() cannot overflow.
        self.logits_maxes = self.logits.max(1, keepdim=True).values
        self.norm_logits = self.logits - self.logits_maxes
        self.counts = self.norm_logits.exp()
        self.counts_sum = self.counts.sum(1, keepdim=True)
        self.counts_sum_inv = self.counts_sum**-1
        self.probs = self.counts * self.counts_sum_inv
        self.logprobs = self.probs.log()
        return xhat

    def loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Mean negative log-likelihood of targets ``y`` for the last forward pass.

        ``forward_pass`` must have been called on the same batch first; ``x``
        is only used for its batch size. Returns a 0-dim tensor.
        """
        return -self.logprobs[range(x.shape[0]), y].mean()

    def tensors(self) -> dict[str, torch.Tensor]:
        """Every named tensor of the last forward pass, in backward order.

        The order runs from the loss end of the graph to the input end: softmax
        intermediates first, then each layer from last to first (keys are
        ``"<LayerClass>-L<index>-<tensor name>"``), then the embedding tensors.
        """
        result = {
            'logprobs': self.logprobs,
            'probs': self.probs,
            'counts_sum_inv': self.counts_sum_inv,
            'counts_sum': self.counts_sum,
            'counts': self.counts,
            'norm_logits': self.norm_logits,
            'logits_maxes': self.logits_maxes,
            'logits': self.logits,
        }
        # Walk the layers last-to-first, and within a layer its tensors in
        # reverse declaration order.
        for i, layer in reversed(list(enumerate(self.layers))):
            for name, t in reversed(list(layer.tensors().items())):
                result[f"{layer}-L{i}-{name}"] = t
        result.update({
            "embcat": self.embcat,
            "emb": self.emb,
            "C": self.C,
        })
        return result


In [77]:
from statistics import mean
from tqdm.notebook import tqdm
from torch import nn
from torch.nn import functional as F

# Number of examples in the single mini-batch used for the gradient check.
MINI_BATCH_SZ = 32

# One seeded generator drives both the model initialization and the batch
# sampling below, so the order of these statements determines the numbers.
g = torch.Generator().manual_seed(2147483647)
model = Model(vocab_size, generator=g)
print(model.num_parameters)


# Sample a single mini-batch of training examples (indices drawn with replacement)
ix = torch.randint(0, Xtr.shape[0], (MINI_BATCH_SZ,), generator=g) # (MINI_BATCH_SZ)
Xb, Yb = Xtr[ix], Ytr[ix]

# Forward pass & loss calculation (loss() reads the logprobs cached by forward_pass)
logits = model.forward_pass(Xb)
loss = model.loss(Xb, Yb)
print(f"Loss: {loss:0.2f}")

# Sanity check: the step-by-step loss should agree with PyTorch's cross entropy
cross_loss = F.cross_entropy(logits, Yb)
print(f"Cross Loss: {cross_loss:0.2f}")

# Clear any gradients left on the parameters
for p in model.parameters():
    p.grad = None

# Ask autograd to keep the gradient of every intermediate tensor, so the
# manual backprop further down has a reference value for each of them.
for name, t in model.tensors().items():
    print(f"Tensor: {name}, {t.shape}, {t.grad_fn}")
    t.retain_grad()

# Autograd backward pass; its gradients are the reference for the comparisons
loss.backward()

4137
Loss: 3.92
Cross Loss: 3.92
Tensor: logprobs, torch.Size([32, 27]), <LogBackward0 object at 0x16612ace0>
Tensor: probs, torch.Size([32, 27]), <MulBackward0 object at 0x16612ace0>
Tensor: counts_sum_inv, torch.Size([32, 1]), <PowBackward0 object at 0x16612ace0>
Tensor: counts_sum, torch.Size([32, 1]), <SumBackward1 object at 0x16612ace0>
Tensor: counts, torch.Size([32, 27]), <ExpBackward0 object at 0x16612ace0>
Tensor: norm_logits, torch.Size([32, 27]), <SubBackward0 object at 0x16612ace0>
Tensor: logits_maxes, torch.Size([32, 1]), <MaxBackward0 object at 0x16612ace0>
Tensor: logits, torch.Size([32, 27]), <AddBackward0 object at 0x16612ace0>
Tensor: Linear-L3-bias, torch.Size([27]), None
Tensor: Linear-L3-weight, torch.Size([64, 27]), None
Tensor: Linear-L3-out, torch.Size([32, 27]), <AddBackward0 object at 0x16612ace0>
Tensor: Nonlinearity-L2-out, torch.Size([32, 64]), <TanhBackward0 object at 0x16612ace0>
Tensor: BatchNorm1d-L1-out, torch.Size([32, 64]), <AddBackward0 object at 0

# Manual Backprop

Work through the entire back propagation manually.

In [92]:
# dlogprobs: gradient of the loss w.r.t. logprobs (32x27)
#   loss = -self.logprobs[range(x.shape[0]), Yb].mean()
# Yb holds the index of the correct next character for each of the 32 rows, so
# only one entry per row takes part in the loss:
#   loss = -(a + b + c + ...) / 32
#        = -a/32 - b/32 - ...
#   dloss/da = -1/32 for each selected entry
# Every other entry of logprobs does not interact with the loss, so its
# gradient stays 0.
dlogprobs = torch.zeros_like(model.logprobs)
dlogprobs[range(Xb.shape[0]), Yb] = -1.0 / MINI_BATCH_SZ

# dprobs: gradient through the log
#   self.logprobs = self.probs.log()
#   d ln(a)/da = 1/a, so by the chain rule dloss/dprobs = (1 / probs) * dlogprobs
# Intuition: examples whose correct character was assigned a low probability
# get a boosted gradient.
dprobs = (1 / model.probs) * dlogprobs

# The softmax takes the (max-subtracted) logits, exponentiates them, then
# normalizes each row to create the probabilities. In the forward pass it is
# broken down into smaller steps:
#   counts         = norm_logits.exp()             32x27
#   counts_sum     = counts.sum(1, keepdims=True)  32x1
#   counts_sum_inv = counts_sum**-1                32x1
#   probs          = counts * counts_sum_inv       32x27
#
# dcounts_sum_inv: gradient through the product
#        self.probs = self.counts * self.counts_sum_inv
# c = a * b  =>  dc/da = b and dc/db = a
# With a[3x3] and b[3x1], b is broadcast across the columns:
# c = a11 * b1, a12 * b1, a13 * b1
#     a21 * b2, a22 * b2, a23 * b2
#     a31 * b3, a32 * b3, a33 * b3
# dc/da = b1, b1, b1
#         b2, b2, b2
#         b3, b3, b3
# dc/db (element-wise, before undoing the broadcast) =
#         a11, a12, a13
#         a21, a22, a23
#         a31, a32, a33
# Note that counts is 32x27 and counts_sum_inv is 32x1, so each b_i is reused
# by every column of row i. Its gradient is therefore the sum over that row
# (a_i1*dc_i1 + a_i2*dc_i2 + a_i3*dc_i3): sum over dim 1 and keep the 32x1 shape.
dcounts_sum_inv = (model.counts * dprobs).sum(1, keepdim=True)
# dcounts can't be finished yet because counts_sum_inv also depends on counts.
# First move further up that branch:
#        self.counts_sum_inv = self.counts_sum**-1
# c = 1/a
# dc/da = -1 / a**2
dcounts_sum = (-1 / (model.counts_sum**2)) * dcounts_sum_inv

# dcounts, contribution 1: the direct use of counts in probs (dc/da = b above)
dcounts = (model.counts_sum_inv * dprobs)

# dcounts, contribution 2: through counts_sum
#        self.counts_sum = self.counts.sum(1, keepdims=True)
# counts is 32x27
# counts_sum is 32x1
# a[3x3] = a11, a12, a13
#          a21, a22, a23
#          a31, a32, a33
# c = a.sum(1)
# c = a11 + a12 + a13,
#     a21 + a22 + a23,
#     a31 + a32 + a33,
# dc_i/da_ij = 1, so the gradient of each row sum is copied to every element
# of that row.
# dcounts = (contrib from the probs product) + (contrib from counts_sum)
# dcounts = (model.counts_sum_inv * dprobs) + (torch.ones_like(model.counts)) * dcounts_sum
dcounts += torch.ones_like(model.counts) * dcounts_sum

# dnorm_logits
#        self.counts = self.norm_logits.exp()
# c = e**a
# dc/da = e**a
# Could also be written as model.counts * dcounts, since counts == exp(norm_logits)
dnorm_logits = torch.exp(model.norm_logits) * dcounts

# dlogits_maxes
#        self.norm_logits = self.logits - self.logits_maxes
# c11 c12 c13     a11 a12 a13     b1
# c21 c22 c23  =  a21 a22 a23  -  b2
# c31 c32 c33     a31 a32 a33     b3
# e.g. c11 = a11 - b1
# dc/db = -1, and each b_i is broadcast across row i, so its gradient is the
# row sum of -1 * dnorm_logits (kept as 32x1).
# Aside: dlogits_maxes is basically zero (subtracting the row max does not
# change the softmax output), so it is a bit odd to backpropagate through it.
dlogits_maxes = (-1 * dnorm_logits).sum(1, keepdim=True)

# dlogits part 1: the direct path logits -> norm_logits (dc/da = 1).
# Clone since we modify it in place below.
dlogits = dnorm_logits.clone()

# dlogits part 2: the path through the row maximum
#        self.logits_maxes = self.logits.max(1, keepdim=True).values
# logits: 32x27
# logits_maxes: 32x1
# max:
# A[3x3] = a11, a12, a13
#          a21, a22, a23
#          a31, a32, a33
# C = max(A, 1)
# C = max(a11, a12, a13)
#     max(a21, a22, a23)
#     max(a31, a32, a33)
# The derivative flowing through is 1 for the entry that was plucked out as
# the maximum and 0 elsewhere. The one-hot mask scatters dlogits_maxes to the
# correct index in each row.
dlogits += F.one_hot(model.logits.max(1).indices, num_classes=model.logits.shape[1]) * dlogits_maxes

# Output linear layer (model.layers[3]); its input is the tanh output h
#   logits = h @ self.weight + self.bias
#   32x27  = 32x64 @ 64x27 + 27        (bias is broadcast over the 32 rows)
# Small example: C[2x4] = A[2x3] @ B[3x4] + d[4]
#   C11 = a11 * b11 + a12 * b21 + a13 * b31 + d1
#   C12 = a11 * b12 + a12 * b22 + a13 * b32 + d2
#   ...
#   C21 = a21 * b11 + a22 * b21 + a23 * b31 + d1
# Collecting terms gives:
#   dA = dC @ B.T,  dB = A.T @ dC,  dd = dC.sum(0)
dh = dlogits @ model.layers[3].weight.T  # 32x64
dw2 = model.layers[2].out.T @ dlogits
db2 = dlogits.sum(0)

# dhpreact: gradient through the tanh
#   self.out = self.act(x)
#   h = tanh(hpreact)
# dh/dhpreact = 1 / cosh(hpreact)**2 = 1 - tanh(hpreact)**2 = 1 - h**2
# dloss/dhpreact = (1 - h**2) * dh
dhpreact = (1 - model.layers[2].out**2) * dh

# Batch norm output: bngain, xraw, bnbias
#   self.out = self.bngain * self.xraw + self.bnbias
#      32x64 = 1x64 * 32x64 + 1x64
# dloss/dbngain = sum over the batch of (xraw * dhpreact)
# dloss/dxraw   = bngain * dhpreact
# dloss/dbnbias = sum over the batch of dhpreact
# bngain and bnbias are broadcast over the 32 rows, so their gradients are
# summed over dim 0 and stay 1x64.
dbngain = (model.layers[1].xraw * dhpreact).sum(0, keepdim=True)
dbnraw = (model.layers[1].bngain * dhpreact)
dbnbias = dhpreact.sum(0, keepdim=True)

# self.xraw = self.xdiff * self.xvar_inv
#       32x64 = (32x64) * 1x64
# dloss/dxvar_inv = sum over the batch of (xdiff * dxraw)
dxvar_inv = (model.layers[1].xdiff * dbnraw).sum(0, keepdim=True)
#    self.xraw = self.xdiff * self.xvar_inv
# dloss/dxdiff (contribution 1) = xvar_inv * dxraw
dxdiff =  model.layers[1].xvar_inv * dbnraw
# NOTE: We'll come back to dxdiff; xdiff2 adds a second contribution below.
#   self.xvar_inv = (self.xvar + self.eps)**-0.5
# dloss/dxvar = -0.5*(xvar + eps)**-1.5 * dxvar_inv
dxvar = (-0.5*(model.layers[1].xvar + model.layers[1].eps)**-1.5) * dxvar_inv

# self.xvar = 1/(self.dim - 1)*(self.xdiff2).sum(0, keepdim=True)
# The sum runs down each column (over the batch):
# a11 a12
# a21 a22
# -> b1 = 1/(dim-1)*(a11 + a21)
# -> b2 = 1/(dim-1)*(a12 + a22)
# dloss/dxdiff2 = 1/(dim - 1) * dxvar, copied to every row of the column.
# (The divisor mirrors the forward pass, which uses self.dim.)
dxdiff2 = 1 / (model.layers[1].dim - 1) * torch.ones_like(model.layers[1].xdiff2) * dxvar

#   self.xdiff2 = self.xdiff**2
# dloss/dxdiff (contribution 2) = 2 * xdiff * dxdiff2
dxdiff += (2 * model.layers[1].xdiff) * dxdiff2

#   self.xdiff = x - self.xmean
#   32x64 = 32x64 - 1x64
# Direct path: dloss/dx = dxdiff (cloned because it is modified in place below).
# Broadcasting in the forward pass means a sum in the backward pass:
# dloss/dxmean = sum over the batch of (-1 * dxdiff)
dhprebn = dxdiff.clone()
dxmean =  (-1 * dxdiff).sum(0, keepdim=True)

#   self.xmean = (1/self.dim) * x.sum(0, keepdim=True)
# Same pattern as dxdiff2: every row of x gets (1/dim) * dxmean.
dhprebn += (1 / model.layers[1].dim) * torch.ones_like(model.layers[0].out) * dxmean


# Input linear layer (model.layers[0])
#   self.out = x @ self.weight + self.bias
#   hprebn = embcat @ weight + bias
#   32x64 = 32x30 @ 30x64 + 64        (bias is broadcast over the 32 rows)
# Same pattern as the output linear layer above.
# dembcat: 32x30
dembcat = dhprebn @ model.layers[0].weight.T
# dw1: 30x64
dw1 = model.embcat.T @ dhprebn
# db1: 64
db1 = dhprebn.sum(0)

#        self.embcat = self.emb.view(self.emb.shape[0], -1)  # concatenate the vectors
# A view only re-arranges the shape, so the gradient is viewed back to emb's shape.
demb = dembcat.view(model.emb.shape)

#        self.emb = self.C[x] # (N, CONTEXT_SZ, EMBED_SZ)
# Route gradient back through the indexing. Which row in C did each 10-dim embedding come from?
# When a row was used multiple times, the gradients that arrive there have to add.
dC = torch.zeros_like(model.C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]

# Manual gradients keyed by the names that model.tensors() uses, so that each
# one can be matched with the tensor whose autograd .grad it should reproduce.
results = {
    "logprobs": dlogprobs,
    "probs": dprobs,
    "counts_sum_inv": dcounts_sum_inv, 
    "counts_sum": dcounts_sum,
    "counts": dcounts,
    "norm_logits": dnorm_logits,
    "logits_maxes": dlogits_maxes,
    "logits": dlogits,
    "Linear-L3-out": dlogits,
    "Linear-L3-weight": dw2,
    "Linear-L3-bias": db2,
    "Nonlinearity-L2-out": dh,
    "BatchNorm1d-L1-out": dhpreact,
    "BatchNorm1d-L1-bngain": dbngain,
    "BatchNorm1d-L1-bnbias": dbnbias,
    "BatchNorm1d-L1-xraw": dbnraw,
    "BatchNorm1d-L1-xvar_inv": dxvar_inv,
    "BatchNorm1d-L1-xvar": dxvar,
    "BatchNorm1d-L1-xdiff2": dxdiff2,
    "BatchNorm1d-L1-xdiff": dxdiff,
    "BatchNorm1d-L1-xmean": dxmean,
    "Linear-L0-out": dhprebn,
    "Linear-L0-weight": dw1,
    "Linear-L0-bias": db1,
    "embcat": dembcat,
    "emb": demb,
    "C": dC,
}

# Compare every manual gradient with autograd; tensors without a manual
# gradient are only listed with their shape.
for name, t1 in model.tensors().items():
    t2 = results.get(name)
    if t2 is not None:
        cmp(name, t2, t1)
    else:
        print(f"{name:25s} | shape: {t1.shape}")



# Shortcut: the whole softmax + cross-entropy chain collapses to
#   dloss/dlogits = (softmax(logits) - one_hot(Yb)) / batch size
dlogits2 = F.softmax(logits, 1)
dlogits2[range(MINI_BATCH_SZ), Yb] -= 1
dlogits2 /= MINI_BATCH_SZ

cmp('logits', dlogits2, model.tensors()["logits"])

logprobs                  | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 27]))
probs                     | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 27]))
counts_sum_inv            | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 1]))
counts_sum                | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 1]))
counts                    | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 27]))
norm_logits               | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 27]))
logits_maxes              | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 1]))
logits                    | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([32, 27]))
Linear-L3-bias            | exact: True  | approx: True  | maxdiff: 0.0 | shape: True (torch.Size([27]))
Linear-L3-weight          

In [71]:
# Scratch cell for checking tensor shapes while deriving the backward pass.
# The commented-out lines are earlier shape checks kept for reference.
#self.xraw = (x - self.xvar) * self.xvar_inv
#dbnraw.shape, model.layers[1].xraw.shape, model.layers[0].out.shape, model.layers[1].xraw.shape, model.layers[1].xvar_inv.shape

#model.layers[1].xdiff.shape, model.layers[0].out.shape,  model.layers[1].xmean.shape
#model.layers[0].out.shape, model.embcat.shape, model.layers[0].weight.shape, model.layers[0].bias.shape, 
# Displayed as the cell output: the embedding shapes before and after flattening.
model.emb.shape, model.embcat.shape

(torch.Size([32, 3, 10]), torch.Size([32, 30]))

In [None]:
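# Scratch notes (math, not executable code): derivatives of the batch mean and variance.
# u   = 1/m * sum(xi)                 =>  du/dxi = 1/m
# var = 1/(m-1) * sum((xi - u)**2)
# dvar/du  = -2/(m-1) * sum(xi - u)   (equals 0 when u is the exact mean of the xi)
# dvar/dxi =  2/(m-1) * (xi - u)      (treating u as a separate input)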
