In [1]:
import pathlib
import random
from typing import Literal
import math


import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import matplotlib.pyplot as plt


%matplotlib inline

In [2]:
USE_MPS = False
TORCH_SEED = 2147483647
SEED = 42
NAMES_FILE = "names.txt"
TERM_TOK = "."
MODEL_INITIAL_PATH = "model_initial.pth"

TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.1
TEST_SPLIT = 0.1
assert TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT == 1.0

# hyperparams
USE_BATCHNORM = True
CONTEXT_SIZE = 3
EMBEDDING_SIZE = 10
HIDDEN_SIZE = 100
NUM_LAYERS = 5
BATCH_SIZE = 64
LEARNING_RATE = 1e-1
NUM_EPOCHS = 5

random.seed(SEED)
torch.manual_seed(TORCH_SEED)

device = (
    torch.accelerator.current_accelerator() if torch.accelerator.is_available() and USE_MPS else torch.device("cpu")
)
device

device(type='cpu')

In [3]:
# build the dataset
class TrigramDataset(data.Dataset):
    def __init__(self, words: list[str], context_size: int, stoi: dict[str, int], terminal_token: str):
        self.context_size = context_size
        self.stoi = stoi
        self.data: list[tuple[list[str], int]] = []  # (context, label) tuples

        for word in words:
            context: list[int] = [0] * context_size
            for ch in word + terminal_token:
                ix = stoi[ch]
                self.data.append((context.copy(), ix))
                context = context[1:] + [ix]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        context, label = self.data[idx]
        return torch.tensor(context), torch.tensor(label)


words = [line.strip() for line in pathlib.Path(NAMES_FILE).open("r").readlines()]
random.shuffle(words)
chars = [TERM_TOK] + sorted(list(set("".join(words))))
vocab_size = len(chars)
stoi = {s: i for i, s in enumerate(chars)}
itos = chars


n1 = int(TRAIN_SPLIT * len(words))
n2 = int((TRAIN_SPLIT + VAL_SPLIT) * len(words))

train_dataset = TrigramDataset(words[:n1], CONTEXT_SIZE, stoi, TERM_TOK)
dev_dataset = TrigramDataset(words[n1:n2], CONTEXT_SIZE, stoi, TERM_TOK)
test_dataset = TrigramDataset(words[n2:], CONTEXT_SIZE, stoi, TERM_TOK)

train_dataloader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = data.DataLoader(dev_dataset, batch_size=BATCH_SIZE)
test_dataloader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

for X, y in test_dataloader:
    print(f"Shape of X [B, C]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
print(f"Vocab size: {vocab_size}")
print(f"train dataset size: {len(train_dataset)}")
print(f"dev dataset size: {len(dev_dataset)}")
print(f"test dataset size: {len(test_dataset)}")

Shape of X [B, C]: torch.Size([64, 3])
Shape of y: torch.Size([64]) torch.int64
Vocab size: 27
train dataset size: 182625
dev dataset size: 22655
test dataset size: 22866


In [4]:
# mosly a copy of torch Linear impl
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # stored as transpose because forward pass is x @ W.T
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # pytorch's weird default init, prolly from torch7 days
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.weight, a=-bound, b=bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, a=-bound, b=bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )


class BatchNorm1d(nn.Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-05,
        momentum: float | None = 0.1,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # for simplicity
        self.affine = True
        self.track_running_stats = True
        # parameters
        self.weight = nn.Parameter(torch.empty(num_features, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(num_features, **factory_kwargs))
        # buffers
        self.running_mean: torch.Tensor | None
        self.running_var: torch.Tensor | None
        self.register_buffer("running_mean", torch.zeros(num_features, **factory_kwargs))
        self.register_buffer("running_var", torch.ones(num_features, **factory_kwargs))
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def extra_repr(self) -> str:
        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
            **self.__dict__
        )

    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self._check_input_dim(input)
        # training mode, use batch stats and accumulate
        # eval mode, use running stats frozen when training finshed
        if self.training:
            xstd, xmean = torch.std_mean(input, dim=0, keepdim=True, unbiased=False)
            xvar = xstd**2
            # no need torch no_grad because buffers are not parameters
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        else:
            xmean = self.running_mean
            xvar = self.running_var
        input_hat = (input - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        return self.weight * input_hat + self.bias


class Tanh(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.tanh(input)

In [5]:
def fc_layer(in_size: int, out_size: int, norm_layer: nn.Identity | BatchNorm1d, bias: bool = True) -> nn.Sequential:
    """Returns a stack of linear->norm->tanh layers."""
    return nn.Sequential(
        Linear(in_size, out_size, bias=bias),
        norm_layer(out_size),
        Tanh(),
    )


class NgramLM(nn.Module):
    def __init__(
        self,
        context_size: int,
        embedding_dim: int,
        vocab_size: int,
        *,
        batchnorm: bool = False,
        hidden_size: int = HIDDEN_SIZE,
        num_layers: int = NUM_LAYERS,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        if batchnorm is False:
            bias = True
            norm_layer = nn.Identity
        else:
            bias = False
            norm_layer = BatchNorm1d

        layers = []
        layers.append(fc_layer(context_size * embedding_dim, hidden_size, norm_layer=norm_layer, bias=bias))
        for i in range(num_layers - 1):
            layers.append(fc_layer(hidden_size, hidden_size, norm_layer=norm_layer, bias=bias))
        layers.append(Linear(hidden_size, vocab_size, bias=bias))
        with torch.no_grad():
            for layer in layers:
                if isinstance(layer, Linear):
                    # make last layer less confident and all other scaled for tanh
                    if layer is not layers[-1]:
                        layer.weight.mul_(nn.init.calculate_gain("tanh"))
                    else:
                        layer.weight.mul_(0.1)

        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)  # (batch, context, embedding)
        x = x.view(x.shape[0], -1)  # (batch, context*embedding)
        logits = self.layers(x)  # (batch, vocab)
        return logits


if USE_BATCHNORM:
    model = NgramLM(CONTEXT_SIZE, EMBEDDING_SIZE, vocab_size, batchnorm=True).to(device)
else:
    model = NgramLM(CONTEXT_SIZE, EMBEDDING_SIZE, vocab_size).to(device)

torch.save(model.state_dict(), MODEL_INITIAL_PATH)
print("Model initialized and saved to ", MODEL_INITIAL_PATH)
print(model)
print(sum(p.numel() for p in model.parameters()))
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

Model initialized and saved to  model_initial.pth
NgramLM(
  (embedding): Embedding(27, 10)
  (layers): Sequential(
    (0): Sequential(
      (0): Linear(in_features=30, out_features=100, bias=False)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
    )
    (1): Sequential(
      (0): Linear(in_features=100, out_features=100, bias=False)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
    )
    (2): Sequential(
      (0): Linear(in_features=100, out_features=100, bias=False)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
    )
    (3): Sequential(
      (0): Linear(in_features=100, out_features=100, bias=False)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
    )
    (4): Sequential(
      (0): Linear(in_features=100, out_features=100, bia

In [6]:
def train(
    dataloader: data.DataLoader,
    model: nn.Module,
    loss_fn: torch.nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
) -> None:
    size = len(dataloader.dataset)
    model.train()  # set train mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # forward
        logits = model(X)
        loss = loss_fn(logits, y)

        # backprop
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(
    dataloader: data.DataLoader,
    model: nn.Module,
    loss_fn: torch.nn.modules.loss._Loss,
) -> None:
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            logits = model(X)
            test_loss += loss_fn(logits, y).item()
    test_loss /= num_batches
    print(f"Avg loss: {test_loss:>8f} \n")


test(test_dataloader, model, loss_fn)

Avg loss: 3.295619 



In [7]:
for t in range(NUM_EPOCHS):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

Epoch 1
-------------------------------
loss: 3.293431  [   64/182625]
loss: 2.834607  [ 6464/182625]
loss: 2.974813  [12864/182625]
loss: 2.605795  [19264/182625]
loss: 2.428613  [25664/182625]
loss: 2.409704  [32064/182625]
loss: 2.384225  [38464/182625]
loss: 2.511750  [44864/182625]
loss: 2.410230  [51264/182625]
loss: 2.538194  [57664/182625]
loss: 2.328156  [64064/182625]
loss: 2.357265  [70464/182625]
loss: 2.388290  [76864/182625]
loss: 1.997476  [83264/182625]
loss: 2.219913  [89664/182625]
loss: 2.332808  [96064/182625]
loss: 2.278374  [102464/182625]
loss: 2.457909  [108864/182625]
loss: 2.372762  [115264/182625]
loss: 2.467430  [121664/182625]
loss: 2.345519  [128064/182625]
loss: 2.230841  [134464/182625]
loss: 2.200317  [140864/182625]
loss: 2.389362  [147264/182625]
loss: 2.031267  [153664/182625]
loss: 2.285945  [160064/182625]
loss: 2.365198  [166464/182625]
loss: 2.168643  [172864/182625]
loss: 2.227643  [179264/182625]
Avg loss: 2.267794 

Epoch 2
-------------------

In [8]:
# sample
model.eval()
with torch.no_grad():
    for _ in range(20):
        context = [0] * CONTEXT_SIZE
        out = ""
        while True:
            logits = model(torch.tensor([context]))
            probs = logits.softmax(dim=1)
            ix = torch.multinomial(probs, num_samples=1).item()
            context = context[1:] + [ix]
            out += itos[ix]
            if ix == 0:
                break
        print(out)

fakeels.
lia.
zena.
niky.
lilo.
mivrehus.
briel.
cpiy.
aryanna.
javurie.
plie.
srion.
kendevaciuwarto.
jalea.
bryz.
nauln.
marelin.
ston.
bree.
yoz.


In [9]:
test(dev_dataloader, model, loss_fn)
test(test_dataloader, model, loss_fn)

Avg loss: 2.159508 

Avg loss: 2.160488 



No BatchNorm:
- Avg loss: 2.212028 
- Avg loss: 2.207326 

With BatchNorm:

- Avg loss: 2.159508 
- Avg loss: 2.160488 

## Part 2: Visualizing activations and gradients

In [10]:
if USE_BATCHNORM:
    model = NgramLM(CONTEXT_SIZE, EMBEDDING_SIZE, vocab_size, batchnorm=True).to(device)
else:
    model = NgramLM(CONTEXT_SIZE, EMBEDDING_SIZE, vocab_size).to(device)

model.load_state_dict(torch.load(MODEL_INITIAL_PATH, weights_only=True))
test(test_dataloader, model, loss_fn)

Avg loss: 3.295619 



In [11]:
# pathlib.Path(MODEL_INITIAL_PATH).unlink(missing_ok=True)