In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

In [75]:
# get data
text = open('../data/shakespeare.txt', 'r', encoding='utf-8').read()


# hyperparameters
device = 'cuda' if torch.cuda.is_available else 'cpu'
batch_size = 64
block_size = 128
dropout = 0.5
n_embd = 128
n_head = 16
n_layer = 32
learning_rate = 1e-3
eval_interval = 100
eval_iters = 200


# get a set of chars
chars = sorted(list(set(text)))
vocab_size = len(chars)

# create encodings
itos = {i:s for i, s in enumerate(chars)}
stoi = {s:i for i, s in enumerate(chars)}
encode = lambda s: list(stoi[c] for c in s)
decode = lambda l: ''.join(itos[i] for i in l)


# create train and val data splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(.9 * len(data))
train_data = data[:n]
val_data = data[n:]


# create get batch function minibatch
def get_batch(split):
    
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
    
    
# estimate the loss
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            x, y = get_batch(split)
            logits, loss = model(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out



# create head of attention
class Head(nn.Module):
    
    def __init__(self, head_size):
        super().__init__()
        
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # aggregation of the values
        v = self.value(x)
        out = wei @ v
        return out
        

# create multi head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x
    

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        x = self.net(x)
        return x
    

    
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        
        """ communication followed by computation """
        
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(head_size)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
    

class NGramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, x, targets=None):
        
        B,T = x.shape
        
        tok_emb = self.token_embedding_table(x)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
            
        else:
            
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, x, max_new_tokens):
        
        for _ in range(max_new_tokens):
            
            x_cond = x[:, -block_size:]
            logits, loss = self(x_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            x_next = torch.multinomial(probs, 1)
            x = torch.cat((x, x_next), dim=1)
            
        return x
    
model = NGramLanguageModel()
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

print(f'{sum(p.numel() for p in model.parameters()) / 1e6 } : M parameters')

6.365761 : M parameters


In [58]:
from tqdm import trange

max_iters = 3000
for iter in trange(max_iters):
    
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f'train loss = {losses["train"]} : val loss = {losses["val"]}')
        
    x, y = get_batch('train')
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    

  0%|          | 0/3000 [00:00<?, ?it/s]

train loss = 4.218393802642822 : val loss = 4.22935676574707


  3%|▎         | 100/3000 [02:24<37:09,  1.30it/s] 

train loss = 2.5226941108703613 : val loss = 2.5221376419067383


  7%|▋         | 200/3000 [04:50<38:09,  1.22it/s]   

train loss = 2.4266977310180664 : val loss = 2.4485127925872803


 10%|█         | 300/3000 [07:19<34:24,  1.31it/s]   

train loss = 2.320152759552002 : val loss = 2.345762014389038


 13%|█▎        | 400/3000 [09:42<33:00,  1.31it/s]   

train loss = 2.1968588829040527 : val loss = 2.226954460144043


 17%|█▋        | 500/3000 [12:05<31:08,  1.34it/s]   

train loss = 2.0987634658813477 : val loss = 2.145958662033081


 20%|██        | 600/3000 [14:27<29:51,  1.34it/s]   

train loss = 1.9908556938171387 : val loss = 2.0642499923706055


 23%|██▎       | 700/3000 [16:50<28:37,  1.34it/s]   

train loss = 1.8995281457901 : val loss = 1.9985389709472656


 27%|██▋       | 800/3000 [19:14<29:05,  1.26it/s]   

train loss = 1.8214160203933716 : val loss = 1.9342601299285889


 30%|███       | 900/3000 [21:39<28:06,  1.25it/s]   

train loss = 1.7487664222717285 : val loss = 1.891795039176941


 33%|███▎      | 1000/3000 [24:15<27:46,  1.20it/s]  

train loss = 1.7043712139129639 : val loss = 1.8505151271820068


 37%|███▋      | 1100/3000 [26:55<26:15,  1.21it/s]   

train loss = 1.6589720249176025 : val loss = 1.8446857929229736


 40%|████      | 1200/3000 [29:35<24:43,  1.21it/s]   

train loss = 1.615634799003601 : val loss = 1.803605556488037


 43%|████▎     | 1300/3000 [32:01<23:11,  1.22it/s]   

train loss = 1.5853596925735474 : val loss = 1.7693291902542114


 47%|████▋     | 1400/3000 [34:35<20:12,  1.32it/s]   

train loss = 1.5576202869415283 : val loss = 1.752343773841858


 50%|█████     | 1500/3000 [37:01<19:08,  1.31it/s]  

train loss = 1.5384726524353027 : val loss = 1.732603907585144


 53%|█████▎    | 1600/3000 [39:35<18:24,  1.27it/s]  

train loss = 1.514164924621582 : val loss = 1.7006127834320068


 57%|█████▋    | 1700/3000 [42:04<17:09,  1.26it/s]  

train loss = 1.4974780082702637 : val loss = 1.6869319677352905


 60%|██████    | 1800/3000 [44:31<15:11,  1.32it/s]  

train loss = 1.473813533782959 : val loss = 1.668248176574707


 63%|██████▎   | 1900/3000 [46:54<14:25,  1.27it/s]  

train loss = 1.4585782289505005 : val loss = 1.6629868745803833


 67%|██████▋   | 2000/3000 [49:18<12:29,  1.34it/s]  

train loss = 1.442018985748291 : val loss = 1.6347322463989258


 70%|███████   | 2100/3000 [51:41<11:17,  1.33it/s]  

train loss = 1.4318571090698242 : val loss = 1.6323715448379517


 73%|███████▎  | 2200/3000 [54:04<09:58,  1.34it/s]  

train loss = 1.418219804763794 : val loss = 1.629976511001587


 77%|███████▋  | 2300/3000 [56:31<09:51,  1.18it/s]  

train loss = 1.4095324277877808 : val loss = 1.6219401359558105


 80%|████████  | 2400/3000 [58:59<07:55,  1.26it/s]  

train loss = 1.4024990797042847 : val loss = 1.610627293586731


 83%|████████▎ | 2500/3000 [1:01:25<06:35,  1.26it/s]  

train loss = 1.3866733312606812 : val loss = 1.6030505895614624


 87%|████████▋ | 2600/3000 [1:03:53<05:12,  1.28it/s]  

train loss = 1.3790249824523926 : val loss = 1.592982530593872


 90%|█████████ | 2700/3000 [1:06:17<03:45,  1.33it/s]  

train loss = 1.3695988655090332 : val loss = 1.5816572904586792


 93%|█████████▎| 2800/3000 [1:08:41<02:34,  1.29it/s]  

train loss = 1.3650094270706177 : val loss = 1.5846518278121948


 97%|█████████▋| 2900/3000 [1:11:08<01:15,  1.32it/s]  

train loss = 1.3554720878601074 : val loss = 1.573350191116333


100%|█████████▉| 2999/3000 [1:13:34<00:00,  1.28it/s]

train loss = 1.3467787504196167 : val loss = 1.5717471837997437


100%|██████████| 3000/3000 [1:14:44<00:00,  1.49s/it]


In [59]:
estimate_loss()

{'train': tensor(1.3449), 'val': tensor(1.5692)}

In [76]:
model.load_state_dict(torch.load('../data/shakespeare2.pt'))

<All keys matched successfully>

In [84]:
# context = torch.zeros(1, 1, dtype=torch.long, device=device)
context = torch.tensor(encode('CLARK:\nCan you help me?\n\nKENT:'), dtype=torch.long, device=device).view(1, -1)
print(decode(model.generate(context, 500)[0].tolist()))

CLARK:
Can you help me?

KENT:
How cad if it thou quaked to her brace.

ESCALUS:
More thanks murderold. That we have thee pot on.

VIRGILIA:
Task yor Jom art stry to rount, by his man, threw--

KING EDWARD IV:
Hast you artiest, and good both arm's
And, on for name to the presctroiced.

QUEEN MARGARET:
So thould I there arM ocse thee home off,
as move here expas, what wardeds; thou shall spers,
The serve of a spurp'd rose belans!
O' have thath doney moore fouse! tharks Frather,
Now that that oppear from ith, if
Wedge to thee 


In [74]:
torch.save(model.state_dict(), '../data/shakespeare2.pt')