In [303]:
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [None]:
# Hyper-parameters (everything a reader might want to tune lives here)
batch_size = 32  # number of independent sequences processed in parallel
block_size = 8   # context length: how many tokens the model may look back on
n_embd = 32      # embedding width
n_head = 4       # attention heads per block; head size is n_embd // n_head
lr = 1e-3        # learning rate handed to the optimizer
epochs = 20      # number of train/evaluate rounds
iters = 1000     # mini-batches drawn per round, for training and for evaluation
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Each block concatenates n_head heads of size n_embd // n_head, so the
# embedding width has to split evenly across the heads.
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"

# Data

In [727]:
# Download the Tiny Shakespeare corpus as one UTF-8 string.
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
res = requests.get(url, timeout=30)  # bounded wait instead of hanging forever
res.raise_for_status()  # fail loudly rather than tokenizing an HTTP error page
text = res.content.decode("utf-8")

In [728]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"{''.join(chars)} \nNumber of token: {vocab_size}")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 
Number of token: 65


In [729]:
# Character-level tokenization: lookup tables between characters and integer ids
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Encoder: turn a string into the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: turn an iterable of integer ids back into the string it spells."""
    return ''.join(itos[i] for i in l)

In [730]:
# Encode the whole corpus once into a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
# Peek at the first 100 token ids as a sanity check.
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [731]:
# Hold out the final 10% of the token stream for validation.
n = int(len(data)*0.9)
train, val = data[:n], data[n:]

In [732]:
# One training window: the targets are the inputs shifted one token ahead,
# so position t of y is the token that follows position t of x.
print(f'x: {train[:block_size]}')
print(f'y: {train[1:block_size+1]}')

x: tensor([18, 47, 56, 57, 58,  1, 15, 47])
y: tensor([47, 56, 57, 58,  1, 15, 47, 58])


In [733]:
torch.manual_seed(49)

def get_batch(phase):
    """Sample a random mini-batch of (input, target) windows.

    Args:
        phase: 'train' draws from the training split; any other value draws
            from the validation split.

    Returns:
        Tuple (x, y) of tensors shaped (batch_size, block_size), moved to
        `device`. y holds the same windows as x shifted one token ahead.
    """
    source = train if phase == 'train' else val
    # Random window starts; the bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # (batch_size, 1) + (block_size,) broadcasts to a grid of positions, so
    # every window is gathered with a single indexing operation.
    positions = starts.unsqueeze(1) + torch.arange(block_size)
    x = source[positions]
    y = source[positions + 1]
    return x.to(device), y.to(device)

xb, yb = get_batch('train')
# print('inputs:', xb.shape, '\n', xb)
# print('targets:', yb.shape, '\n', yb)

# Model

In [734]:
# Head
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Reads the module-level `n_embd` (input width) and `block_size` (largest
    sequence length the causal mask covers).

    Args:
        head_size: output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        # Bias-free linear maps from the embedding width to the head width.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones matrix; registered as a buffer so it moves
        # with the module but is not a trainable parameter.
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal)

    def forward(self, x):
        """Attend over x of shape (B, T, C); returns a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape
        keys = self.key(x)       # B, T, head
        queries = self.query(x)  # B, T, head
        values = self.value(x)   # B, T, head

        # Affinities between positions, scaled by 1/sqrt(head_size): B, T, T
        scores = queries @ keys.transpose(-2, -1) * self.head_size**-0.5
        # A position may not look at later positions: those scores become
        # -inf, so softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        return weights @ values

In [735]:
# Multi-head
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected.

    Args:
        num_heads: number of `Head` modules to run side by side.
        head_size: output width of each head.

    The concatenated head outputs have width num_heads * head_size; `proj`
    maps that back to the module-level `n_embd` so the result can be added to
    the residual stream.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Input width is what the heads actually produce. It equals n_embd in
        # the usual head_size = n_embd // num_heads setup, but deriving it
        # keeps the layer valid for any (num_heads, head_size) combination.
        self.proj = nn.Linear(num_heads * head_size, n_embd) # projection back to residual pathway

    def forward(self, x):
        """Apply every head to x and project the concatenation: (B, T, n_embd)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # B, T, num_heads * head_size
        out = self.proj(out)
        return out

In [736]:
# Feed-Forward Networks
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the width, apply ReLU, project back.

    Args:
        n_embd: width of both the input and the output features.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        # Expansion layer followed by the non-linearity.
        self.net = nn.Sequential(nn.Linear(n_embd, hidden), nn.ReLU())
        # Projection back to the residual pathway.
        self.proj = nn.Linear(hidden, n_embd)

    def forward(self, x):
        """Return a tensor shaped like x, transformed independently per position."""
        return self.proj(self.net(x))

In [737]:
# Transformer Block
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward net.

    Both sub-layers are wrapped in a residual connection, and each sees a
    layer-normalized copy of its input.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads; each head gets n_embd // n_head
            features.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa_heads = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        # Normalization happens before each sub-layer (pre-norm).
        attended = x + self.sa_heads(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

In [744]:
# Bigram Model
class BigramLanguageModel(nn.Module):
    """Character-level language model built from transformer blocks.

    Token and position embeddings are summed, passed through three `Block`s,
    and mapped to next-token logits by a linear head. Sizes come from the
    module-level `vocab_size`, `n_embd`, `block_size` and `n_head`.

    The class keeps its "bigram" name, but predictions attend over up to
    `block_size` previous tokens.
    """

    def __init__(self):
        super().__init__()
        self.token_embd_table = nn.Embedding(vocab_size, n_embd)
        self.position_embd_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head),
            Block(n_embd, n_head),
            Block(n_embd, n_head),
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids; T must not exceed block_size,
                because positions index the position-embedding table.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the cross-entropy over all positions.
        """
        B, T = idx.shape

        # idx and targets are both B, T
        tok_embd = self.token_embd_table(idx) # B, T, embd
        # Build the position ids on the same device as the input rather than
        # on a global, so the model works wherever its inputs live.
        pos_embd = self.position_embd_table(torch.arange(T, device=idx.device)) # T, embd
        x = tok_embd + pos_embd # broadcast over the batch: B, T, embd
        x = self.blocks(x)
        logits = self.lm_head(x) # B, T, vocab_size

        if targets is None:
            loss = None
        else:
            # F.cross_entropy wants (N, C) logits against (N,) class ids.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling needs no autograd graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after idx.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: idx with the sampled ids appended.
        """
        for _ in range(max_new_tokens):
            # The position table has only block_size entries: crop the context.
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond) # B, T, C
            logits = logits[:, -1, :] # last time step only: B, C
            probs = F.softmax(logits, dim=-1) # B, C
            idx_next = torch.multinomial(probs, num_samples=1) # B, 1
            idx = torch.cat((idx, idx_next), dim=1) # B, T+1
        return idx

In [745]:
torch.manual_seed(49)
# Move the model to the same device get_batch() moves each batch to; otherwise
# CPU weights meet CUDA inputs and the first forward pass fails on GPU machines.
model = BigramLanguageModel().to(device)

# Train Model

In [756]:
# No separate criterion object: the model's forward() computes the
# cross-entropy loss itself whenever targets are passed.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [747]:
def train_one_epoch():
    """Run `iters` optimization steps on random training batches.

    Uses the module-level `model`, `optimizer` and `get_batch`.

    Returns:
        The mean training loss over the steps taken, as a Python number.
    """
    model.train()
    n_steps = iters  # alternative: int(len(train)/batch_size)
    mean_loss = 0
    for _ in range(n_steps):
        inputs, targets = get_batch('train')

        # Forward pass: the model returns the loss when given targets.
        _, loss = model(inputs, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate each step's share of the mean.
        mean_loss += loss.item() / n_steps
    return mean_loss

In [748]:
def evaluate():
    """Estimate the validation loss from `iters` random validation batches.

    Uses the module-level `model` and `get_batch`; no gradients are tracked.

    Returns:
        The mean validation loss over the batches, as a Python number.
    """
    model.eval()
    n_steps = iters  # alternative: int(len(val)/batch_size)
    mean_loss = 0

    with torch.inference_mode():
        for _ in range(n_steps):
            inputs, targets = get_batch('val')
            # Forward pass only.
            _, loss = model(inputs, targets)
            mean_loss += loss.item() / n_steps
    return mean_loss

In [759]:
# Alternate one round of training with one validation estimate, and report both.
for epoch in tqdm(range(epochs)):
    train_loss, val_loss = train_one_epoch(), evaluate()  # evaluated left to right
    print(f"Epoch: {epoch+1}/{epochs} | training loss: {train_loss:.4f} | validation loss: {val_loss:.4f}")

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 1/20 | training loss: 1.9344 | validation loss: 2.0427
Epoch: 2/20 | training loss: 1.9410 | validation loss: 2.0314
Epoch: 3/20 | training loss: 1.9374 | validation loss: 2.0407
Epoch: 4/20 | training loss: 1.9345 | validation loss: 2.0354
Epoch: 5/20 | training loss: 1.9312 | validation loss: 2.0319
Epoch: 6/20 | training loss: 1.9259 | validation loss: 2.0335
Epoch: 7/20 | training loss: 1.9242 | validation loss: 2.0307
Epoch: 8/20 | training loss: 1.9272 | validation loss: 2.0258
Epoch: 9/20 | training loss: 1.9247 | validation loss: 2.0325
Epoch: 10/20 | training loss: 1.9219 | validation loss: 2.0301
Epoch: 11/20 | training loss: 1.9189 | validation loss: 2.0322
Epoch: 12/20 | training loss: 1.9153 | validation loss: 2.0305
Epoch: 13/20 | training loss: 1.9141 | validation loss: 2.0311
Epoch: 14/20 | training loss: 1.9092 | validation loss: 2.0262
Epoch: 15/20 | training loss: 1.9120 | validation loss: 2.0227
Epoch: 16/20 | training loss: 1.9138 | validation loss: 2.0246
E

# Generate

In [762]:
# Start from a single token with id 0 and sample 1000 further characters.
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx, max_new_tokens=1000)[0].tolist()))


To Anto swome his Self
Best their:
I ret sfyound tearny ther mornane. A thee pureb'ds
Dike
Ullews sit you, chach never stontwer of her!
I grower,
Some four this floght.

CORCADURE
MERSCINIOF VI' chimile,
I them makidicece.

HANGEY:
Ben thin awas, my instrant-arthere good do the woad! 
God cally
That himes a which tall: crema;
And, dedince,
WinSaven;
The ratie, have hears, deaps abothter, are it to burithere rour you dewo sit doth she meclaids,
Tide?

LORD GIARGIE:
Hold, man.

LAUSTEL:
For oul strow not rever sires; your old two rulded the the dispily, son we murdine.

Micks, broys tursil.

KING RICESTER:
Bhis my Warwolds: chart,
Have fromm from me, let ken
that that or we disprance so, and men pustry sive and somel seel.
Nurse did at cill'd man,
Thore ware lord?
War: Work, a plew do of reme caul.

I proan to mood, their.

KING EDWHBY:
Feliry for rist boy, of smyy wet lord. To do preak,
I will to were met, as you.
Bof dughter: thans blose sless. Where Eght.

LUCESTER:
Ah,
Hath no at em

# The mathematical trick in self-attention

In [494]:
torch.manual_seed(49)
# Toy sizes for the demonstrations below: batch, time (sequence), channels.
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [326]:
# Row t of the normalized lower-triangular matrix holds equal weights over
# positions 0..t, so multiplying by it averages each position with its past.
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
# bag of words: (T, T) is broadcast over the batch, (T, T) @ (B, T, C) --> (B, T, C)
xbow = wei @ x
xbow.shape

torch.Size([4, 8, 2])

In [344]:
torch.manual_seed(49)
tril = torch.tril(torch.ones(T, T))
a = torch.randn((T, T)) # affinity
# Tokens from the future cannot communicate: their affinities are set to -inf,
# which softmax turns into a weight of exactly zero.
a = F.softmax(a.masked_fill(tril == 0, float('-inf')), dim=-1)
xbow = a @ x
xbow.shape

torch.Size([4, 8, 2])

In [622]:
# self-attention: a single head, written out step by step
torch.manual_seed(47)

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# Three bias-free projections, created in the order key, query, value.
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))
k, q, v = key(x), query(x), value(x) # each B, T, head_size

# Raw affinities between positions: B, T, T
# (the head_size**-0.5 scaling is left out of this demonstration)
a = q @ k.transpose(-2, -1)

# Mask out future positions, then normalize each row into weights.
tril = torch.tril(torch.ones(T, T))
a = F.softmax(a.masked_fill(tril == 0, float('-inf')), dim=-1)

out = a @ v
out.shape

torch.Size([4, 8, 16])