In [1]:
import pathlib
import random
from typing import Literal
import math
import enum


import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torch.nn.functional as F
import matplotlib.pyplot as plt


%matplotlib inline

In [2]:
class ModelType(enum.StrEnum):
    CHAR_BIGRAM = "char_bigram"
    WORD_MLP = "word_mlp"


# --- run configuration --------------------------------------------------------
USE_ACCELERATOR = True  # set False to force CPU even when an accelerator is available
TORCH_SEED = 2147483647  # seed for torch's global RNG
SEED = 42  # seed for Python's `random` module
INPUT_FILE = "tinyshakespeare.txt"
MODEL_TYPE: ModelType = ModelType.CHAR_BIGRAM


# Dataset split fractions; together they must cover the whole dataset.
TRAIN_SPLIT = 0.9
VAL_SPLIT = 0.1
# Compare with a tolerance rather than `==`: sums of decimal fractions are not
# exact in binary floating point, so exact equality can reject valid splits.
assert math.isclose(TRAIN_SPLIT + VAL_SPLIT, 1.0), "TRAIN_SPLIT + VAL_SPLIT must equal 1.0"

# hyperparams
CONTEXT_SIZE = 8  # maximum context length (in tokens) used for each prediction
EMBEDDING_SIZE = 24
HIDDEN_SIZE = 128
NUM_LAYERS = 5
BATCH_SIZE = 64  # the number of independent sequences to process at once
LEARNING_RATE = 1e-3
NUM_EPOCHS = 2

# Seed every RNG source up front so Restart & Run All is reproducible.
random.seed(SEED)
torch.manual_seed(TORCH_SEED)

# Use the current accelerator (e.g. CUDA) when one is present and allowed, else CPU.
# NOTE: `torch.accelerator` is only available in recent PyTorch releases.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available() and USE_ACCELERATOR
    else torch.device("cpu")
)
device

device(type='cuda')

In [3]:
# Three equivalent ways to compute a causal running mean ("bag of words") over time.
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
print(x[0])  # (T, C)

# 1) Explicit loops: xbow[b, t] is the mean of x[b, 0..t].
xbow = torch.zeros_like(x)
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

# 2) Equal-weighted aggregation: a row-normalized lower-triangular matrix averages the past.
tril = torch.tril(torch.ones(T, T))
wei_avg = tril / tril.sum(dim=1, keepdim=True)
xbow2 = wei_avg @ x  # (T, T) @ (B, T, C) -(broadcast)-> (B, T, C)
assert torch.allclose(xbow, xbow2)

# 3) Variable-weight aggregation: start from zero scores, forbid future positions with
#    -inf, and let softmax turn each row into uniform weights over the allowed positions.
wei = torch.zeros((T, T)).masked_fill(tril == 0.0, float("-inf")).softmax(dim=1)
xbow3 = wei @ x
assert torch.allclose(xbow, xbow3)

tensor([[ 1.5674, -0.2373],
        [-0.0274, -1.1008],
        [ 0.2859, -0.0296],
        [-1.5471,  0.6049],
        [ 0.0791,  0.9046],
        [-0.4713,  0.7868],
        [-0.3284, -0.4330],
        [ 1.3729,  2.9334]])


In [4]:
# Small worked example: multiplying by a row-normalized lower-triangular matrix
# gives, in row i, the mean of rows 0..i of `b`.
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)
print(a)

b = torch.tensor([[2.0, 9.0], [2.0, 5.0], [8.0, 1.0]])
result = torch.stack([b[: i + 1].mean(dim=0) for i in range(len(b))])
print(result)

torch.allclose(a @ b, result)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 9.],
        [2., 7.],
        [4., 5.]])


True

In [5]:
# Softmax-based causal averaging again, now with a wider channel dimension.
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn((B, T, C))
tril = torch.tril(torch.ones(T, T))
print(tril)
# Zero scores plus a -inf mask above the diagonal -> uniform weights over past positions.
wei = torch.zeros((T, T)).masked_fill(tril == 0.0, float("-inf")).softmax(dim=1)
xbow3 = wei @ x

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])


In [8]:
# Scaled dot-product causal attention on small integer-valued tensors, so the
# intermediate numbers are easy to inspect by hand.
B, T, C = 2, 4, 16  # batch, time, embedding channels
D_head = 8  # width of queries / keys / values
# The draw order (embeddings, W_q, W_k, W_v) fixes the RNG stream; keep it.
embeddings = torch.randint(0, 5, (B, T, C)).float()
W_q = torch.randint(0, 5, (C, D_head)).float()
W_k = torch.randint(0, 5, (C, D_head)).float()
W_v = torch.randint(0, 5, (C, D_head)).float()

# Project every token; each result is (B, T, D_head).
queries, keys, values = (embeddings @ W for W in (W_q, W_k, W_v))

# Scores: (B, T, D_head) @ (B, D_head, T) -> (B, T, T), scaled by 1/sqrt(D_head).
tril = torch.tril(torch.ones(T, T))  # 1 where a query may attend (same or earlier position)
raw_scores = queries @ keys.transpose(-2, -1)
attn_scores = (raw_scores / torch.sqrt(torch.tensor(D_head, dtype=torch.float32))).masked_fill(
    tril == 0.0, float("-inf")
)
attn_weights = F.softmax(attn_scores, dim=-1)  # (B, T, T); each row sums to 1
# NOTE(review): despite the name, this is the attention output (a weighted sum of
# `values`), not a residual connection.
residuals = attn_weights @ values  # (B, T, T) @ (B, T, D_head) -> (B, T, D_head)

for shown in (embeddings[0], attn_weights[0], values[0], residuals[0]):
    print(shown)

tensor([[3., 4., 1., 2., 3., 0., 1., 4., 0., 1., 4., 1., 3., 1., 1., 4.],
        [0., 1., 0., 3., 4., 3., 1., 4., 1., 3., 4., 4., 1., 2., 4., 3.],
        [3., 4., 0., 0., 4., 4., 1., 3., 1., 3., 4., 4., 4., 0., 0., 1.],
        [2., 0., 0., 0., 1., 0., 4., 2., 4., 3., 2., 2., 0., 2., 3., 3.]])
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]])
tensor([[23., 60., 44., 55., 67., 81., 52., 67.],
        [30., 69., 53., 66., 78., 72., 69., 65.],
        [28., 68., 48., 62., 73., 75., 64., 66.],
        [25., 39., 49., 58., 49., 48., 45., 52.]])
tensor([[23., 60., 44., 55., 67., 81., 52., 67.],
        [30., 69., 53., 66., 78., 72., 69., 65.],
        [30., 69., 53., 66., 78., 72., 69., 65.],
        [30., 69., 53., 66., 78., 72., 69., 65.]])


In [42]:
# One causal self-attention head built from bias-free nn.Linear projections.
B, T, C = 4, 8, 32  # batch, time, channels
head_size = 16
x = torch.randn((B, T, C))

# Created in q, k, v order so the RNG draws for their weight init are unchanged.
W_q, W_k, W_v = (nn.Linear(C, head_size, bias=False) for _ in range(3))

# Each projection maps (B, T, C) -> (B, T, head_size).
x_q, x_k, x_v = W_q(x), W_k(x), W_v(x)

mask = torch.tril(torch.ones(T, T))
attn = (
    (x_q @ x_k.transpose(-2, -1) / head_size**0.5)  # scaled scores, (B, T, T)
    .masked_fill(mask == 0.0, float("-inf"))  # no attending to future positions
    .softmax(dim=-1)
)
out = attn @ x_v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

print(out[0])

tensor([[ 8.1925e-03, -6.0443e-01, -6.5695e-01, -4.3440e-02,  5.8751e-01,
         -7.3220e-01, -4.6700e-01,  2.0095e-01, -5.4697e-01, -4.9193e-01,
          2.4311e-01,  5.2821e-01, -3.1823e-01, -4.9768e-01, -5.8761e-01,
         -3.6939e-01],
        [-7.8378e-02, -3.4855e-01, -3.1900e-01, -5.0265e-01,  3.3814e-01,
         -4.7783e-01, -6.1466e-01,  9.3785e-02, -1.0540e-01, -4.1946e-01,
          4.3377e-01,  6.1648e-01, -4.4797e-01, -6.4522e-01, -4.9924e-01,
          8.0813e-02],
        [-3.4285e-01, -4.2789e-01, -1.5630e-01, -3.9100e-01,  1.4745e-01,
         -4.0785e-01, -4.9640e-01,  2.5882e-02, -8.4196e-02, -1.9822e-01,
          4.4566e-01,  2.7950e-01, -3.9456e-01, -6.4231e-01, -4.5462e-01,
         -1.0927e-01],
        [-4.4342e-01, -3.8915e-01, -1.1667e-01, -4.8522e-01,  3.4610e-01,
         -2.8422e-01, -3.3227e-01,  6.2355e-02, -5.4839e-02, -3.4362e-01,
          3.9218e-01,  8.1700e-02, -4.9806e-01, -5.2871e-01, -2.5724e-01,
         -2.0421e-01],
        [-4.0587e-01