In [5]:
import torch
from torch.nn import functional as F
from src.datasets.utils import *
from src.models.bigram import *

In [8]:
# Fix the global torch RNG seed so sampling in later cells is repeatable.
torch.manual_seed(435123)

# Hyperparameters / run configuration.
# NOTE(review): batch_size, context_length and device are not referenced by any
# other visible cell (get_batch and BigramLM are called without them) — confirm
# whether the imported helpers carry their own copies of these values.
batch_size, context_length = 32, 8
vocab_size = 37  # passed to BigramLM when the model is constructed
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
# Obtain the two data splits used by the rest of the notebook.
# NOTE(review): train_test_split arrives via one of the star imports at the top
# and is called with no arguments, so the data source and split ratio are not
# visible here — presumably it loads and splits the corpus itself; confirm.
train, test = train_test_split()

In [10]:
# Draw one (inputs, targets) pair from the training split; both are fed to the
# model as m(xb, yb) further down.
# NOTE(review): presumably each is (batch_size, context_length) — the call does
# not pass either value, so verify against get_batch's own defaults.
xb, yb = get_batch(train)

In [11]:
# Build the bigram language model for a vocabulary of `vocab_size` tokens.
# No .to(device) is applied here, so the `device` setting above is not used
# for the model in the visible cells.
m = BigramLM(vocab_size)

In [12]:
# Forward pass on one batch before any training: the model returns both the
# logits and the loss when targets are supplied.
# For reference, a uniform guess over vocab_size = 37 tokens would score
# ln(37) ~= 3.61 under a cross-entropy/NLL loss (the printed grad_fn suggests
# that is what the model uses — confirm in BigramLM).
logits, loss = m(xb, yb)
print(loss)

tensor(3.9769, grad_fn=<NllLossBackward0>)


In [13]:
# Sample 100 new tokens from the untrained model, starting from a single
# token with id 0 (one sequence of length one), and print the decoded text.
idx = torch.zeros(1, 1, dtype=torch.long)
generated = m.generate(idx, max_new_tokens=100)
generated_idx = generated[0].tolist()  # first (only) sequence -> plain list of ids
print(decode(generated_idx))

0x53ymyb2feewcy6px3b6cqo 4q0zdovqx5ec4jhyf7gjq0sqb2hvm1qi9ng1vjh7az4buwy15032lrpcy8ic36lkn9qq0v6n521v


In [17]:
# Fit the model with Adam (lr = 1e-3), one freshly drawn training batch per step.
optim = torch.optim.Adam(m.parameters(), lr=1e-3)
for steps in range(100_000):
    # Drop gradients left over from the previous step before computing new ones.
    optim.zero_grad(set_to_none=True)

    xb, yb = get_batch(train)
    logits, loss = m(xb, yb)

    loss.backward()
    optim.step()

# Loss of the very last batch only — a single sample, not an average over steps.
print(loss.item())

2.4556188583374023


In [18]:
# Same sampling procedure as before training, now with the fitted model:
# start from a single token with id 0 and generate 100 further tokens.
idx = torch.zeros(1, 1, dtype=torch.long)
generated = m.generate(idx, max_new_tokens=100)
generated_idx = generated[0].tolist()  # first (only) sequence -> plain list of ids
print(decode(generated_idx))

050 f b gu   d qulusplau  g as disinfovanvis  d athractithe  0p  p is  theaisedive biuealioultamg tit


In [19]:
# Small random tensor used by the averaging experiments that follow.
torch.manual_seed(1337)
B = 4  # batch
T = 8  # time
C = 2  # channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [20]:
# Reference (explicit loop) implementation of a causal running mean:
#   xbow[b, t] = mean over i <= t of x[b, i]
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # time steps 0..t inclusive -> (t+1, C)
        xbow[b, t] = xprev.mean(dim=0)  # average over the time axis -> (C,)

In [25]:
# Vectorized version: express the running mean as one matrix multiply.
# Row t of `wei` holds 1/(t+1) in columns 0..t and 0 elsewhere.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
# (T, T) is broadcast across the batch: (B, T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)

In [26]:
# Sanity check: the matrix-multiply version should match the loop result
# within floating-point tolerance.
torch.allclose(xbow, xbow2)

True

In [27]:
# Version 3: build the same averaging weights with a masked softmax.
# Scores start at zero; positions above the diagonal are set to -inf so the
# softmax gives them weight 0 and spreads weight evenly over positions 0..t.
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row sums to 1
xbow3 = torch.matmul(wei, x)

In [28]:
# Sanity check: the masked-softmax version should also match the loop result
# within floating-point tolerance.
torch.allclose(xbow, xbow3)

True

In [29]:
# version 4: self-attention! (starting point)
# Same masked-softmax averaging as version 3, now on a 32-channel input.
# The scores are all zeros, so every row t is still uniform over positions
# 0..t — nothing in this cell makes the weights depend on the data yet.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))

tril = torch.ones(T, T).tril()  # 1 where position j <= i, 0 above the diagonal
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))  # mask out j > i
wei = F.softmax(wei, dim=-1)  # normalize each row to sum to 1
out = torch.matmul(wei, x)  # (T, T) @ (B, T, C) -> (B, T, C)

out.shape

torch.Size([4, 8, 32])