In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
chars = sorted(list(set(text)))
chars_size = len(chars)
print('size:', chars_size, ''.join(chars))

size: 65 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [4]:
stoi = { c:i for i,c in enumerate(chars) }
itos = { i:c for i,c in enumerate(chars) }
encode = lambda str: [stoi[c] for c in str]
decode = lambda inds: ''.join(itos[i] for i in inds)

In [5]:
batch_size = 32
block_size = 8
batch_estimate = 300
estimate_interval = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
data = torch.tensor(encode(text), dtype=torch.long)

torch.manual_seed(1337)

# split data on training and valuation
ind = int(0.9*len(data))
data_tr = data[:ind]
data_val = data[ind:]

def get_batch(split):
    data = data_tr if split == 'train' else data_val
    inds = torch.randint(0, len(data) - block_size, (batch_size,))
    
    x = torch.stack([data[i:i+block_size] for i in inds])
    y = torch.stack([data[i+1:i+block_size+1] for i in inds])
    return x.to(device), y.to(device)

In [7]:
class BigramLanguageModel(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(size, size)
        
    def forward(self, x, targets=None):
        logits = self.token_embedding_table(x) # [batch_size, block_size, embedding_table]
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C) # ? why not B, T*C
            targets = targets.view(-1) 
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, ind, num_tokens):
        for _ in range(num_tokens):
            logits, loss = self(ind) # forward(x)
            logits = logits[:, -1, :] # get the last character [batch_size, embedding_table]
            probs = F.softmax(logits, dim=1)
            ind_next = torch.multinomial(probs, num_samples=1)
            ind = torch.cat((ind, ind_next), dim=1)
        
        return ind
        

In [8]:
model = BigramLanguageModel(chars_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [9]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    
    for split in ['train', 'val']:
        losses = torch.zeros(batch_estimate)
        for i in range(batch_estimate):
            x_b, y_b = get_batch(split)
            logits, loss = model(x_b, y_b)
            losses[i] = loss
        out[split] = losses.mean().item()
        
    model.train()
    return out

In [10]:
steps = 10000
for i in range(steps):
    if i % estimate_interval == 0:
        losses = estimate_loss()
        print('step:', i, 'train loss:', losses['train'], 'val loss:', losses['val'])
    
    x_b, y_b = get_batch('train')
    logits, loss = model.forward(x_b, y_b)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step: 0 train loss: 4.729331016540527 val loss: 4.726686000823975
step: 1000 train loss: 3.7340946197509766 val loss: 3.7374446392059326
step: 2000 train loss: 3.1215972900390625 val loss: 3.1260106563568115
step: 3000 train loss: 2.801257848739624 val loss: 2.8015177249908447
step: 4000 train loss: 2.6318860054016113 val loss: 2.648066520690918
step: 5000 train loss: 2.5684964656829834 val loss: 2.568488597869873
step: 6000 train loss: 2.52632474899292 val loss: 2.537660837173462
step: 7000 train loss: 2.4904325008392334 val loss: 2.5036895275115967
step: 8000 train loss: 2.480072259902954 val loss: 2.49685001373291
step: 9000 train loss: 2.470869541168213 val loss: 2.489243745803833


In [11]:
# generate from the model
num_tokens = 300
context = torch.zeros((1,1), dtype=torch.long, device=device)
result = model.generate(context, num_tokens=num_tokens)
print(decode(result[0].tolist()))




CExfik bridcowindakis s, bth

HAPORThobe d e.
S:
O:3 my d?
LUCous:
Wanthar u qur, vet?
F dXENDoate awice my.

HDEdarom oroup
Yowh$Frtof isth ble mil ndill, ath iree sengmin lat Heriliovets, and Win nghirileranousel lind me l.
HAshe ce hiry:
Supr aisspllw y.
Hurindu n Boopetelaves
MP:

Pl, d mothak


In [12]:
# self attention
B,T,C = 4,8,32
head_size = 16

x = torch.randn((B,T,C))
keys = nn.Linear(C, head_size, bias=False)
queries = nn.Linear(C, head_size, bias=False)
values = nn.Linear(C, head_size, bias=False)

k = keys(x)
q = queries(x)

weights = q @ k.transpose(-2,-1) * C**-0.5
tril = torch.tril(torch.ones((T,T)))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = values(x)
out = weights @ v
out.shape

torch.Size([4, 8, 16])