In [64]:
import sys
sys.path.append("..")
import math
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from rich import print
from common import RNG
%load_ext rich

The rich extension is already loaded. To reload it, use:
  %reload_ext rich


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
# vocab
# "train" the Tokenizer, so we're able to map between characters and tokens
def _read_text(path):
    """Return the full contents of the text file at `path`, closing it afterwards."""
    with open(path, "r") as f:
        return f.read()


train_text = _read_text("../data/train.txt")
assert all(c == "\n" or ("a" <= c <= "z") for c in train_text)
uchars = sorted(set(train_text))  # unique characters we see in the input
vocab_size = len(uchars)
char_to_token = {c: i for i, c in enumerate(uchars)}
token_to_char = {i: c for i, c in enumerate(uchars)}
EOT_TOKEN = char_to_token["\n"]  # designate \n as the delimiting <|endoftext|> token
# pre-tokenize all the splits one time up here
# NOTE: assumes val/test only contain characters seen in train; otherwise the
# char_to_token lookup raises KeyError.
test_tokens = [char_to_token[c] for c in _read_text("../data/test.txt")]
val_tokens = [char_to_token[c] for c in _read_text("../data/val.txt")]
train_tokens = [char_to_token[c] for c in train_text]  # reuse: no need to re-read train.txt


In [24]:
def dataloader(tokens, context_length, batch_size):
    """Yield (inputs, targets) batches forever from a sliding window over `tokens`.

    Each example is a window of `context_length + 1` consecutive tokens. The
    first `context_length` tokens are the input and the last one is the target.
    The window advances by one token per example. It wraps back to the start
    of `tokens` once the next window would run past the end.

    Args:
        tokens: sequence of integer token ids.
        context_length: number of input tokens per example (>= 0).
        batch_size: number of examples per batch (>= 1).

    Yields:
        (inputs, targets) as torch tensors of shape (batch_size, context_length)
        and (batch_size,).

    Raises:
        ValueError: on the first `next()`, if `batch_size < 1`, if
            `context_length < 0`, or if `tokens` is too short to hold one
            full window.
    """
    n = len(tokens)
    # batch_size <= 0 would otherwise loop forever without ever yielding.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if context_length < 0:
        raise ValueError(f"context_length must be >= 0, got {context_length}")
    # Too few tokens would otherwise produce short (ragged/malformed) windows.
    if n <= context_length:
        raise ValueError(
            f"need at least context_length + 1 = {context_length + 1} tokens, got {n}"
        )
    inputs, targets = [], []
    pos = 0
    while True:
        # simple sliding window over the tokens, of size context_length + 1
        window = tokens[pos : pos + context_length + 1]
        inputs.append(window[:-1])
        targets.append(window[-1])
        # once we've collected a batch, emit it
        if len(inputs) == batch_size:
            yield torch.tensor(inputs), torch.tensor(targets)
            inputs, targets = [], []
        # advance the position; wrap around once the next window would be short
        pos += 1
        if pos + context_length >= n:
            pos = 0

In [17]:
class MLP(nn.Module):
    """Fixed-context MLP language model.

    The `context_length` token ids are embedded and concatenated. One tanh
    hidden layer then maps them to logits over the vocabulary.

    Args:
        vocab_size: number of distinct tokens.
        context_length: number of tokens fed in per example.
        embedding_size: width of each token embedding.
        hidden_size: width of the tanh hidden layer.
        rng: random source exposing `randn(n, mu=, sigma=)` and
            `rand(n, a=, b=)`, used to (re)initialize every parameter.
    """

    def __init__(self, vocab_size, context_length, embedding_size, hidden_size, rng):
        super().__init__()
        # token embedding table: one row of `embedding_size` per token id
        self.wte = nn.Embedding(vocab_size, embedding_size)
        # the context's embeddings are laid side by side before the first Linear
        flat_width = context_length * embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(flat_width, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, vocab_size),
        )
        self.reinit(rng)

    @torch.no_grad()
    def reinit(self, rng):
        """Overwrite every parameter with values drawn from our own `rng`.

        This is a deliberate hack, not typical PyTorch code. The layers above
        stay idiomatic, and this method then replaces PyTorch's random init
        with draws from `rng`. That gives full control over initialization and
        makes results comparable with other implementations.

        The distributions mirror PyTorch's defaults: the embedding is
        N(0, 1), and each Linear's weight and bias are U(-k, k) with
        k = 1/sqrt(fan_in). The order of draws (embedding, then layer 0
        weight/bias, then layer 2 weight/bias) determines every value, so it
        must stay fixed.
        """

        def fill_from(param, values):
            # values is a flat sequence of param.numel() numbers
            param.copy_(torch.tensor(values).view_as(param))

        table = self.wte.weight
        fill_from(table, rng.randn(table.numel(), mu=0, sigma=1.0))

        for linear in (self.mlp[0], self.mlp[2]):
            bound = linear.in_features ** -0.5
            for param in (linear.weight, linear.bias):
                fill_from(param, rng.rand(param.numel(), a=-bound, b=bound))

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a (B, T) batch of token ids.

        `logits` has shape (B, vocab_size). `loss` is the cross-entropy
        against `targets` (shape (B,)), or None when no targets are given.
        """
        batch, _ = idx.size()
        flat = self.wte(idx).view(batch, -1)  # (B, T, E) -> (B, T * E)
        logits = self.mlp(flat)
        loss = None if targets is None else F.cross_entropy(logits, targets)
        return logits, loss


In [21]:
# Sanity check: vocabulary size. The vocab cell asserts train.txt holds only
# "\n" and a-z, so this is at most 27. It is left as the last expression so
# the notebook displays it instead of print()-ing it.
vocab_size

In [46]:
# --- model shape ---
context_length = 3  # 3 tokens of context predict the 4th, i.e. a 4-gram model
embedding_size = 48
hidden_size = 512

# --- optimization ---
batch_size = 16
num_steps = 1000
learning_rate = 7e-4

# --- model construction (weights drawn from our own seeded RNG) ---
init_rng = RNG(1337)
model = MLP(vocab_size, context_length, embedding_size, hidden_size, rng=init_rng)
model



MLP(
  (wte): Embedding(27, 48)
  (mlp): Sequential(
    (0): Linear(in_features=144, out_features=512, bias=True)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=27, bias=True)
  )
)

In [53]:
@torch.inference_mode()
def eval_split(model, tokens, max_batches=None):
    """Return the mean cross-entropy loss of `model` on batches drawn from `tokens`.

    Batches come from `dataloader(tokens, context_length, batch_size)`, using
    the notebook-level globals `context_length` and `batch_size`. About
    `len(tokens) // batch_size` batches are evaluated, capped at
    `max_batches` when given.

    Args:
        model: callable returning `(logits, loss)` for `(inputs, targets)`.
        tokens: sequence of token ids for the split.
        max_batches: optional upper bound on the number of batches.

    Returns:
        The average of the per-batch losses. Every batch has exactly
        `batch_size` examples, so this equals the per-example mean.

    Raises:
        ValueError: if no batch would be evaluated (the split has fewer than
            `batch_size` tokens, or `max_batches` <= 0). Previously this was
            a ZeroDivisionError.
    """
    num_batches = len(tokens) // batch_size
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)
    if num_batches <= 0:
        raise ValueError(
            f"eval_split: no batches to evaluate (len(tokens)={len(tokens)}, "
            f"batch_size={batch_size}, max_batches={max_batches})"
        )
    data_iter = dataloader(tokens, context_length, batch_size)
    total_loss = 0.0
    for _ in range(num_batches):
        inputs, targets = next(data_iter)
        _, loss = model(inputs, targets)
        total_loss += loss.item()
    return total_loss / num_batches


In [61]:
def softmax(logits):
    """Numerically stable softmax of a 1-D tensor of logits, shape (V,) -> (V,).

    The largest logit is subtracted first, so the biggest exponent is exp(0)
    and nothing overflows.
    """
    shifted = logits - logits.max()
    weights = shifted.exp()
    return weights / weights.sum()


def sample_discrete(probs, coinf):
    """Sample an index from a discrete distribution by inverse-CDF lookup.

    Args:
        probs: 1-D sequence of probabilities (e.g. a torch.Tensor of shape
            (V,) or a list of floats). It is expected to sum to ~1.
        coinf: a uniform random number, presumably in [0, 1).

    Returns:
        The first index i with coinf < probs[0] + ... + probs[i]. If
        rounding leaves the running total at or below coinf, the last index.

    Raises:
        ValueError: if `probs` is empty. The old fallback would have
            returned -1, which is not a valid token.
    """
    if len(probs) == 0:
        raise ValueError("sample_discrete: probs must be non-empty")
    cdf = 0.0
    for i, prob in enumerate(probs):
        cdf += prob
        if coinf < cdf:
            return i
    return len(probs) - 1  # in case of rounding errors

In [47]:
# Endless stream of training batches; the training loop pulls from this iterator.
train_data_iter = dataloader(train_tokens, context_length=context_length, batch_size=batch_size)

In [51]:
# AdamW over all model parameters; lr is overwritten each step by the cosine schedule.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

In [48]:
# Peek at one batch: X has shape (batch_size, context_length), y has shape
# (batch_size,). It comes from a *fresh* loader, so inspecting data does not
# advance train_data_iter and silently skip a training batch.
X, y = next(dataloader(train_tokens, context_length, batch_size))
X.shape, y.shape

In [57]:
# Training loop. Previously `step` was hard-coded to 100, so the cell took a
# single step at a fixed lr and the schedule/last-step logic never advanced.
for step in range(num_steps):
    # cosine learning rate schedule, decaying from learning_rate toward 0
    lr = learning_rate * 0.5 * (1 + math.cos(math.pi * step / num_steps))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    # every now and then evaluate the validation loss
    last_step = step == num_steps - 1
    if step % 100 == 0 or last_step:
        train_loss = eval_split(model, train_tokens, max_batches=20)
        val_loss = eval_split(model, val_tokens)
        print(
            f"step {step:6d} | train_loss {train_loss:.6f} | val_loss {val_loss:.6f} | lr {lr:e}"
        )
    # --- training step ---
    # get the next batch of training data
    inputs, targets = next(train_data_iter)
    # forward pass (calculate the loss)
    logits, loss = model(inputs, targets)
    # backward pass (calculate the gradients)
    loss.backward()
    # step the optimizer (update the parameters), then clear the gradients
    optimizer.step()
    optimizer.zero_grad()


In [74]:
# Visualize the predicted next-token distribution for one example. The
# logits/inputs/targets come from the most recent training forward pass.
sample_idx = 12  # which example of the batch to show
assert 0 <= sample_idx < logits.size(0), (
    f"sample_idx={sample_idx} is out of range for a batch of {logits.size(0)}"
)
logits_vis = logits[sample_idx]

# convert logits to probabilities (detached: plotting needs no autograd graph)
probs = softmax(logits_vis.detach())
context_str = "".join(token_to_char[t.item()] for t in inputs[sample_idx])
target_char = token_to_char[targets[sample_idx].item()]

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=[token_to_char[i] for i in range(vocab_size)],
        y=probs.numpy(),
        mode="lines+markers",
    )
)

fig.update_layout(
    title=f"""Logits as probabilities of the next token <br>
    Input: {context_str} | Target: {target_char}""",
    xaxis_title="Token",
    yaxis_title="Probability",
)

fig.show()
