In [8]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# --- Hyperparams ---
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 300  
learning_rate = 1e-2
device = torch.device('mps') or ('cuda' if torch.cuda.is_available() else 'cpu')
eval_iters = 200
n_embd = 32
print(device)
# -----------------

torch.manual_seed(1337)

# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()


chars = sorted(list(set(text)))
vocab_size = len(chars)

# "Tokenizer". Simple; pros: small vocab size, cons: long sequences
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: [itos[i] for i in l]
encoded = encode("HALLO")
encoded, decode(encoded)

torch.manual_seed(1337)

data = torch.tensor(encode(text), dtype=torch.long)
n = int(.9*len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split="train"):
  data = train_data if split == "train" else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y, = x.to(device), y.to(device)
  return x, y

@torch.no_grad() 
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Bigram(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape

    tok_embd = self.token_embedding_table(idx) # (B, T, C)
    pos_embd = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
    x = tok_embd + pos_embd
    logits = self.lm_head(x)
    # print(logits.shape)
    # Reshape for corssentropy

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C )
      targets = targets.view (B*T)
      # print(logits.shape, targets.shape)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      logits, loss = self(idx)
      logits = logits[:, -1, :] # (B, C)
      probs = F.softmax(logits, dim=1)
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

model = Bigram()
model = model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):
    xb, yb = get_batch()

    # Evaluate Loss
    logits, loss = model(xb, yb)
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()

# "".join([decode(batch) for batch in m.generate(torch.zeros((1, 1), dtype=torch.long), 100).tolist()][0])
context = torch.zeros((2, 1), dtype=torch.long, device=device)
print(''.join(decode(model.generate(context, 500)[0].tolist())))


mps
step 0: train loss 4.4800, val loss 4.4803
step 300: train loss 2.5399, val loss 2.5577
step 600: train loss 2.5167, val loss 2.5340
step 900: train loss 2.4968, val loss 2.5147

Foasth pr!
SKixchaPENGRYIOLOMUKEE&exby:
QUS:
3 COLineg agntheprdrrknteckeyr PHENThe?
TyONGrsothy.
DLOPELABUSimppry PS:
MANIF
Shbur$erwixikns
Fokncaknd-htNRLYoun, KGLIOMOMu gwirexthe.
AUFOMETHAnthcus;
NDUFr
TEKYCONTRDURDUKIEUjerks?
STh!
Tus
LUSA'st BEFoghy whe f,
SThogiznk&ofachang!
&CETDWler'dsuqughold ark'dz

PESThe
TJUCENCTINGBundmzeyby hindyongmy!
HENRRVof,
CI'Thy.
MARBRUCUMy,---GHIOfourknghm goutheppmplkedextherex'digackefithrkn'GENUS:
Y laghindKINGRBedouerdjurd.
JKEMjurKENGHEOKIFe.
IINGS:
Y


In [26]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 32])

In [27]:
xbow = torch.zeros((B, T, C))
for b in range(B):
  for t in range(T):
    prev = x[b, :t+1]
    xbow[b, t] = torch.mean(prev, 0)
    # print(prev, torch.mean(prev, 0))

In [28]:
wbow = torch.zeros((B, T, C))
w = torch.tril(torch.ones(T, T))
w = w / w.sum(1, keepdim=True)

# for b in range(B):
#   wbow[b] = w @ x[b] # (T, T) @ (T, C) --> (T, C)
wbow = w @ x
print(wbow.shape, xbow.shape)
torch.allclose(xbow, wbow)

torch.Size([4, 8, 32]) torch.Size([4, 8, 32])


True

In [29]:
# One more step
tril = torch.tril(torch.ones((T, T)))
weights = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=1) # Each row sums to one
xbow3 = weights @ x
torch.allclose(xbow, xbow3)

True

In [30]:
# a single head
head_size = 16 # Embedding dimension in head after linear layer
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, head_size) 
q = query(x)
v = value(x)
# Swap T and head_size dim --> (B, T, hs) @ (B, hs, T) --> (B, T, T)
weights = q @ k.transpose(-2, -1) 
tril = torch.tril(torch.ones((T, T)))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=1)
out = weights @ x
out.shape

torch.Size([4, 8, 32])

In [31]:
out

tensor([[[ 4.4919e-03, -1.7391e-03, -8.9361e-03,  ..., -1.9920e-02,
           3.7859e-02,  6.2334e-02],
         [-5.1158e-03, -2.6541e-03,  7.3581e-03,  ...,  9.8471e-03,
           2.2559e-02,  9.2822e-03],
         [-2.0666e-01,  1.4131e-01, -1.7812e-02,  ..., -1.2945e-01,
          -1.4731e-01,  2.3609e-02],
         ...,
         [-2.2176e+00,  1.1948e+00,  1.4382e+00,  ...,  2.3454e-01,
           1.9469e+00, -6.7925e-01],
         [-1.4471e+00, -2.2265e-01,  5.8033e-01,  ...,  2.0316e+00,
           1.2952e+00,  1.4777e+00],
         [-1.5152e+00, -1.0873e+00,  1.3338e+00,  ..., -1.2010e+00,
          -1.1554e+00,  2.0142e-01]],

        [[ 2.0221e-02, -4.8391e-02, -3.6381e-02,  ...,  2.2690e-03,
          -2.9152e-02, -1.1405e-01],
         [ 2.7081e-03,  3.3087e-02, -5.0248e-02,  ...,  2.9653e-02,
          -4.5851e-02,  3.8357e-02],
         [ 5.7537e-01, -9.4794e-02,  6.3360e-02,  ..., -1.4155e-01,
           2.0953e-01, -3.1606e-01],
         ...,
         [ 1.7553e-01,  3