In [353]:
from typing import Union, Any
import matplotlib.pyplot as plt
%matplotlib inline

In [354]:
# Read the whole file into memory first; the handle is only needed for the read.
with open('names.txt') as f:
    content = f.read()

# One entry of `words` per line of the file (line terminators are dropped).
words = content.splitlines()

# Bare last expression so the notebook displays the number of entries.
len(words)

32033

In [355]:
# Character vocabulary: every distinct character occurring in `words`, sorted.
chars = sorted(set(''.join(words)))

# Characters are numbered 1..len(chars); index 0 is kept for '.', which
# build_dataset uses both as context padding and as the end-of-word target.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0

# Inverse lookup (index -> character), built from the finished forward mapping.
itos = dict((ix, ch) for ch, ix in stoi.items())
vocab_size = len(itos)

In [356]:
import torch

# Build dataset splits
def build_dataset(words: list[str], block_size: int) -> (torch.tensor, torch.tensor):
    x, y = [], []
    for w in words:
        #print(w)
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            #print(''.join(itos[i] for i in context), "--->", ch)
            # Advance the rolling window of context
            context = context[1:] + [ix]
    X = torch.tensor(x)
    Y = torch.tensor(y)
    print(X.shape, Y.shape)
    return X, Y

import random
# Number of preceding characters the model sees for each prediction.
CONTEXT_SZ = 3

# Seeded in-place shuffle so the split below is reproducible on a fresh kernel.
# NOTE: `words` itself is reordered, so re-running this cell reshuffles it again.
random.seed(42)
random.shuffle(words)

# Cut points for an 80% / 10% / 10% train / dev / test split.
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))

Xtr, Ytr = build_dataset(words[:n1], CONTEXT_SZ)
Xdev, Ydev = build_dataset(words[n1:n2], CONTEXT_SZ)
Xte, Yte = build_dataset(words[n2:], CONTEXT_SZ)

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


# Model

In [357]:
def cmp(s: str, dt: Any, t: torch.Tensor) -> None:
    """Compare a manually derived gradient with the gradient autograd stored on ``t``.

    Prints one line reporting whether the two agree exactly, whether they agree
    within ``torch.allclose`` default tolerances, and the largest absolute
    element-wise difference.

    Args:
        s: Label printed at the start of the report line.
        dt: Manually computed gradient to check against ``t.grad``.
        t: Tensor whose ``.grad`` has been populated by a backward pass.

    Raises:
        ValueError: If ``t.grad`` is ``None`` (no gradient was retained/computed).
    """
    if t.grad is None:
        raise ValueError(
            f"{s}: tensor has no .grad; call retain_grad() on it and run backward() first"
        )
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f"{s:15s} | exact: {str(ex):5s} | approx: {str(app):5s} | maxdiff: {maxdiff}")

In [358]:
class Linear:
    """A linear layer."""

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        generator: Union[torch.Generator, None] = None,
        weight_gain: float | None = None,
        bias_gain: float = 0.1,
    ) -> None:
        """Initialize Linear."""
        if weight_gain is None:
            weight_gain = (5/3) / (fan_in**0.5)
        self.weight = torch.randn((fan_in, fan_out), generator=generator) * weight_gain
        # Note: bias is spurious given we're using batch norm, but calculating anyway
        self.bias = torch.randn(fan_out, generator=g) * bias_gain
        self.out: Union[torch.tensor, None] = None

    def __call__(self, x: torch.tensor) -> torch.tensor:
        """Forward pass"""
        self.out = x @ self.weight + self.bias
        return self.out

    def parameters(self):
        return [self.weight, self.bias]
    
    def tensors(self) -> dict[str, torch.tensor]:
        return {'out': self.out}

    def __str__(self) -> str:
        return self.__class__.__name__


class BatchNorm1d:
    """A batch normalization layer with explicitly stored intermediates.

    In training mode the input is normalized with the statistics of the current
    batch (computed over dimension 0) and the running buffers are updated. In
    evaluation mode (``self.training = False``) the running buffers are used
    instead and are left untouched.
    """

    def __init__(self, dim: int, eps: float = 1e-5, momentum=0.001) -> None:
        """Initialize BatchNorm1d.

        Args:
            dim: Number of features (size of dimension 1 of the input).
            eps: Constant added to the variance before the inverse square root.
            momentum: Weight given to the current batch statistics when
                updating the running buffers.
        """
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Adds small random numbers to unmask gradient errors
        self.bngain = torch.randn((1, dim)) * 0.1 + 1.0
        self.bnbias = torch.randn((1, dim)) * 0.1
        # Buffers, trained with a running momentum update
        self.running_mean = torch.zeros((1, dim))
        self.running_var = torch.ones((1, dim))
        # Intermediates of the latest forward pass; None until the first call
        # so that tensors() is safe to call on a fresh layer.
        self.xmean: Union[torch.Tensor, None] = None
        self.xvar: Union[torch.Tensor, None] = None
        self.xvar_inv: Union[torch.Tensor, None] = None
        self.xraw: Union[torch.Tensor, None] = None
        self.out: Union[torch.Tensor, None] = None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: normalize ``x`` and apply the learned gain and bias."""
        if self.training:
            self.xmean = x.mean(0, keepdim=True)
            # Bessel-corrected (unbiased) batch variance.
            self.xvar = x.var(0, keepdim=True, unbiased=True)
        else:
            # Evaluation: normalize with the accumulated statistics.
            self.xmean = self.running_mean
            self.xvar = self.running_var

        self.xvar_inv = (self.xvar + self.eps)**-0.5
        # Center with the MEAN (not the variance) before scaling.
        self.xraw = (x - self.xmean) * self.xvar_inv

        if self.training:
            # Update buffers; no_grad keeps them out of the autograd graph.
            with torch.no_grad():
                self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * self.xmean
                self.running_var = (1-self.momentum) * self.running_var + self.momentum * self.xvar

        self.out = self.bngain * self.xraw + self.bnbias
        return self.out

    def tensors(self) -> dict[str, torch.Tensor]:
        """Return the intermediates of the latest forward pass keyed by name."""
        return {
            'xmean': self.xmean,
            'xvar': self.xvar,
            'xvar_inv': self.xvar_inv,
            'xraw': self.xraw
        }

    def parameters(self) -> list[torch.Tensor]:
        """Return the trainable tensors: ``[bngain, bnbias]``."""
        return [self.bngain, self.bnbias]

    def __str__(self) -> str:
        return self.__class__.__name__


class Nonlinearity:
    """Element-wise tanh activation layer that remembers its latest output."""

    def __init__(self) -> None:
        """Create the tanh module; nothing is cached until the first call."""
        self.act = torch.nn.Tanh()
        self.out: Union[torch.tensor, None] = None

    def __str__(self) -> str:
        # The layer is identified purely by its class name.
        return type(self).__name__

    def parameters(self):
        """This layer has nothing to train, so the list is always empty."""
        return list()

    def tensors(self) -> dict[str, torch.tensor]:
        """Expose the cached output of the latest forward pass under 'out'."""
        return dict(out=self.out)

    def __call__(self, x: torch.tensor) -> torch.tensor:
        """Forward pass: apply tanh to ``x``, cache the result and return it."""
        activated = self.act(x)
        self.out = activated
        return activated

In [359]:
# Size of each character embedding vector.
EMBED_SZ = 10
# Width of the single hidden layer.
HIDDEN_LAYER_SZ = 64

class Model:
    """Character-level MLP: embedding -> Linear -> BatchNorm1d -> tanh -> Linear.

    The cross-entropy loss is computed step by step in ``forward_pass`` and
    every intermediate is kept as an attribute so its gradient can be inspected.
    """

    def __init__(self, vocab_size: int, generator: Union[torch.Generator, None] = None) -> None:
        """Build the embedding table and the layer stack.

        Args:
            vocab_size: Number of distinct tokens (rows of the embedding table
                and width of the output logits).
            generator: Random generator used for the embedding table and the
                Linear layers; ``None`` uses torch's default generator.
        """
        # Use the generator that was passed in rather than a notebook global.
        self.C = torch.randn((vocab_size, EMBED_SZ), generator=generator)
        self.layers = [
            Linear(EMBED_SZ * CONTEXT_SZ, HIDDEN_LAYER_SZ, generator=generator),
            BatchNorm1d(HIDDEN_LAYER_SZ),
            Nonlinearity(),
            Linear(HIDDEN_LAYER_SZ, vocab_size, generator=generator),
        ]

        for layer in self.layers:
            layer.training = True

        # Enable autograd on every parameter
        for p in self.parameters():
            p.requires_grad = True

        # Forward-pass intermediates; populated by forward_pass().
        self.emb = None
        self.embcat = None
        self.logits = None
        self.logits_maxes = None
        self.norm_logits = None
        self.counts = None
        self.counts_sum = None
        self.counts_sum_inv = None
        self.probs = None
        self.logprobs = None
        self.loss = None

    def parameters(self) -> list[torch.Tensor]:
        """Return the embedding table followed by every layer's parameters."""
        return [self.C] + [p for layer in self.layers for p in layer.parameters()]

    @property
    def num_parameters(self) -> int:
        """Total number of parameters."""
        return sum(p.nelement() for p in self.parameters())

    def forward_pass(self, x: torch.Tensor, y: Union[torch.Tensor, None] = None) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Integer tensor of token indices, one row per example.
            y: Target index for each row of ``x``. When omitted, the
                notebook-level ``Yb`` is used so that existing one-argument
                calls keep working; new code should pass ``y`` explicitly.

        Returns:
            The logits produced by the last layer. The loss and all of its
            intermediates are stored on the instance (``self.loss`` etc.).
        """
        if y is None:
            # Legacy behavior: targets come from the notebook global `Yb`.
            y = Yb

        self.emb = self.C[x] # (N, CONTEXT_SZ, EMBED_SZ)
        self.embcat = self.emb.view(self.emb.shape[0], -1)  # concatenate the vectors

        xhat = self.embcat
        for layer in self.layers:
            xhat = layer(xhat)

        self.logits = xhat

        # Explicit implementation of loss function
        # Subtracting the row-wise max keeps exp() from overflowing.
        self.logits_maxes = self.logits.max(1, keepdim=True).values
        self.norm_logits = self.logits - self.logits_maxes
        self.counts = self.norm_logits.exp()
        self.counts_sum = self.counts.sum(1, keepdim=True)
        self.counts_sum_inv = self.counts_sum**-1
        self.probs = self.counts * self.counts_sum_inv
        self.logprobs = self.probs.log()
        # Mean negative log-probability of the target class of each row.
        self.loss = -self.logprobs[range(x.shape[0]), y].mean()
        return xhat

    def tensors(self) -> dict[str, torch.Tensor]:
        """Return the loss intermediates plus every layer's cached tensors.

        Layer tensors are keyed as ``"<LayerName><index>-<tensor name>"``.
        """
        result = {
            'logprobs': self.logprobs,
            'probs': self.probs,
            'counts_sum_inv': self.counts_sum_inv,
            'counts_sum': self.counts_sum,
            'counts': self.counts,
            'norm_logits': self.norm_logits,
            'logits_maxes': self.logits_maxes
        }
        for i, layer in enumerate(self.layers):
            for name, t in layer.tensors().items():
                result[f"{layer}{i}-{name}"] = t
        return result


In [360]:
from statistics import mean
from tqdm.notebook import tqdm
from torch import nn
from torch.nn import functional as F

MINI_BATCH_SZ = 32

# One seeded generator drives both model initialization and batch sampling.
g = torch.Generator().manual_seed(2147483647)
model = Model(vocab_size, generator=g)
print(model.num_parameters)


# Calculate a single batch
ix = torch.randint(0, Xtr.shape[0], (MINI_BATCH_SZ,), generator=g) # (MINI_BATCH_SZ)
Xb, Yb = Xtr[ix], Ytr[ix]

# Forward pass & loss calculation (also fills model.loss and its intermediates)
logits = model.forward_pass(Xb)

# Reference loss from PyTorch; used only to sanity-check the explicit loss.
loss = F.cross_entropy(logits, Yb)
assert torch.allclose(model.loss, loss, atol=1e-5), "explicit loss disagrees with F.cross_entropy"

for p in model.parameters():
    p.grad = None

# Manual backprop here
for name, t in model.tensors().items():
    print(f"Tensor: {name}, {t.grad_fn}")
    # Intermediates are non-leaf tensors; without this their .grad is discarded.
    t.retain_grad()

# Use for correctness comparisons.
# Backpropagate through the explicit loss: its graph contains every retained
# intermediate (logprobs, probs, counts, ...). The F.cross_entropy graph does
# not, so calling backward on `loss` would leave their .grad as None.
model.loss.backward()




4137
Tensor: logprobs, <LogBackward0 object at 0x16afa6aa0>
Tensor: probs, <MulBackward0 object at 0x16afa6aa0>
Tensor: counts_sum_inv, <PowBackward0 object at 0x16afa6aa0>
Tensor: counts_sum, <SumBackward1 object at 0x16afa6aa0>
Tensor: counts, <ExpBackward0 object at 0x16afa6aa0>
Tensor: norm_logits, <SubBackward0 object at 0x16afa6aa0>
Tensor: logits_maxes, <MaxBackward0 object at 0x16afa6aa0>
Tensor: Linear0-out, <AddBackward0 object at 0x16afa6aa0>
Tensor: BatchNorm1d1-xmean, <MeanBackward1 object at 0x16afa6aa0>
Tensor: BatchNorm1d1-xvar, <VarBackward0 object at 0x16afa6aa0>
Tensor: BatchNorm1d1-xvar_inv, <PowBackward0 object at 0x16afa6aa0>
Tensor: BatchNorm1d1-xraw, <MulBackward0 object at 0x16afa6aa0>
Tensor: Nonlinearity2-out, <TanhBackward0 object at 0x16afa6aa0>
Tensor: Linear3-out, <AddBackward0 object at 0x16afa6aa0>
