In [62]:
import time

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

In [145]:
# Hyperparameters and run configuration.

# -- model
context_length = 256  # longest token window the model is built for
n_embd = 384          # embedding width
num_heads = 6         # attention heads per block

# -- optimisation
batch_size = 64
lr = 3e-4
epochs = 1            # iterations of the training loop (one batch each)

# -- sampling
max_new_token = 500

# -- hardware: use the GPU whenever torch can see one
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Tinyshakespeare dataset

In [126]:
# Download the Tiny Shakespeare corpus (one plain-text file) over HTTP.
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
# requests applies no timeout by default, so without one this cell could hang
# forever on an unresponsive server.
res = requests.get(url, timeout=30)
res.raise_for_status()  # raise on any 4xx/5xx response instead of decoding an error page
text = res.content.decode("utf-8")

In [127]:
# Report the corpus size in characters.
print(f'length of dataset in characters: {len(text)}')

length of dataset in characters: 1115394


In [128]:
# Preview the first 1000 characters of the corpus.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [129]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"{''.join(chars)} \nNumber of token: {vocab_size}")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 
Number of token: 65


In [130]:
# Character-level tokenizer: each distinct character is one token, identified
# by its position in `chars`.
itos = dict(enumerate(chars))                 # token id -> character
stoi = {ch: idx for idx, ch in itos.items()}  # character -> token id


def encode(s):
    """Encoder: turn a string into the list of its integer token ids."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Decoder: turn a sequence of integer token ids back into a string."""
    return ''.join(itos[idx] for idx in l)

In [131]:
# Round trip: show the token ids of a sample string, then decode them back.
sample_ids = encode("Optimum prideee!!!")
print(sample_ids)
print(decode(sample_ids))

[27, 54, 58, 47, 51, 59, 51, 1, 54, 56, 47, 42, 43, 43, 43, 2, 2, 2]
Optimum prideee!!!


In [132]:
# Encode the whole corpus once into a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.size(), data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


# Train / validation data split

In [133]:
# Keep the first 90% of the token stream for training and hold out the rest.
n = int(0.9 * len(data))
train, test = data[:n], data[n:]

In [134]:
# Peek at the first context_length + 1 training tokens: enough for one input
# window plus its one-step-shifted targets.
train[:context_length+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [149]:
# Seed torch's global RNG so the random batch sampling below is repeatable.
torch.manual_seed(1337)

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects the training tokens; any other value selects
            the held-out tokens.

    Returns:
        Tuple (x, y) of tensors shaped (batch_size, context_length), both
        moved to `device`. y is x shifted one position to the right, so
        y[b, t] is the token that follows x[b, :t+1].
    """
    source = train if split == 'train' else test
    # Random window starts; the upper bound leaves room for the one-token
    # look-ahead the targets need.
    starts = torch.randint(len(source) - context_length, (batch_size, ))
    # Broadcast (batch_size, 1) starts against (context_length,) offsets so
    # every window is gathered with a single indexing operation.
    positions = starts.unsqueeze(1) + torch.arange(context_length)
    x = source[positions].to(device)
    y = source[positions + 1].to(device)
    return x, y

# Draw one training batch to inspect.
xb, yb = get_batch('train')

# Debugging aid (disabled): prints every (context, target) pair in the batch.
# for b in range(batch_size):
#     for t in range(context_length):
#         context = xb[b, :t+1]
#         target = yb[b, t]
#         print(f"Input: {context} | target: {target}")
#     print("---"*20)

# Single head self-attention

In [137]:
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        n_embd: size of the last dimension of the input.
        head_size: size of the last dimension of the output.
        context_length: side length of the causal mask, i.e. the longest
            sequence this head can process.
    """

    def __init__(self, n_embd, head_size, context_length):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.head_size = head_size
        # Lower-triangular matrix of ones; registered as a buffer so it is
        # stored with the module without being a trainable parameter.
        causal = torch.ones(context_length, context_length).tril()
        self.register_buffer('tril', causal)

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to attention output (B, T, head_size)."""
        _, seq_len, _ = x.shape
        queries = self.query(x)  # (B, T, H)
        keys = self.key(x)       # (B, T, H)
        values = self.value(x)   # (B, T, H)

        # Pairwise affinities, scaled by sqrt(head_size).
        scores = (queries @ keys.transpose(-2, -1)) / self.head_size**0.5  # (B, T, T)
        # Positions above the diagonal are in the future: give them -inf so
        # softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        # Weighted sum of the values.
        return weights @ values

# Multi-head self-attention

In [138]:
class MultiHeadAttention(nn.Module):
    """Several `Head`s run in parallel, concatenated and linearly projected.

    Args:
        num_heads: number of attention heads.
        n_embd: size of the last dimension of the input and of the output.
        head_size: output size of each individual head.
        context_length: forwarded to every head.
    """

    def __init__(self, num_heads, n_embd, head_size, context_length):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, context_length) for _ in range(num_heads)]
        )
        # Project the concatenated heads back to n_embd so the output width
        # always matches the input width. The two sizes coincide only when
        # num_heads * head_size == n_embd; projecting to
        # num_heads * head_size otherwise yields an output that cannot be
        # added back to the input in a residual connection.
        self.proj = nn.Linear(num_heads * head_size, n_embd)

    def forward(self, x):
        """Return the projected concatenation of all head outputs for x."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

# Feed-Forward Networks

In [139]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4 * fan_out, apply ReLU, project to fan_out.

    Args:
        fan_in: size of the last dimension of the input.
        fan_out: size of the last dimension of the output.
    """

    def __init__(self, fan_in, fan_out):
        super().__init__()
        hidden = 4 * fan_out
        expand = nn.Linear(fan_in, hidden)
        contract = nn.Linear(hidden, fan_out)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the MLP along the last dimension of x."""
        return self.net(x)

# Decoder Transformer block

In [140]:
class Block(nn.Module):
    """Transformer block: self-attention then feed-forward, each applied to a
    layer-normalised input and added back through a residual connection.

    Args:
        num_heads: number of attention heads.
        n_embd: width of the block's input and output.
        head_size: output size of each attention head.
        context_length: forwarded to the attention layer.
    """

    def __init__(self, num_heads, n_embd, head_size, context_length):
        super().__init__()
        self.sd_heads = MultiHeadAttention(
            num_heads=num_heads,
            n_embd=n_embd,
            head_size=head_size,
            context_length=context_length,
        )
        self.ffwd = FeedForward(fan_in=n_embd, fan_out=n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        attended = self.sd_heads(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

# Bigram language model

In [141]:
class BigramLanguageModel(nn.Module):
    """Decoder-only Transformer language model over integer token ids.

    Despite the name, predictions come from a stack of attention `Block`s
    rather than a bigram lookup table.

    Args:
        vocab_size: number of distinct token ids.
        n_embd: embedding width used throughout the model.
        num_heads: attention heads per block.
        head_size: output size of each attention head.
        context_length: longest sequence the model can process; also the
            number of learned position embeddings.
        n_layer: number of stacked `Block`s (default 6).
    """

    def __init__(self,
                 vocab_size,
                 n_embd,
                 num_heads,  # heads params
                 head_size,  # heads params
                 context_length,
                 n_layer=6):
        super().__init__()
        self.token_embbeding_table = nn.Embedding(vocab_size, n_embd)  # (B, T) -> (B, T, C)
        self.position_embbeding_table = nn.Embedding(context_length, n_embd)
        # n_layer identical blocks followed by a final LayerNorm. Building
        # them in a comprehension keeps the Sequential indices the same as
        # listing each block by hand.
        self.blocks = nn.Sequential(
            *[Block(num_heads, n_embd, head_size, context_length) for _ in range(n_layer)],
            nn.LayerNorm(n_embd),
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.context_length = context_length

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the cross-entropy against targets.
        """
        B, T = idx.shape
        token_emb = self.token_embbeding_table(idx)  # (B, T, C)  C=n_embd
        # Create the position indices on the input's own device rather than
        # on a notebook-level global, so the model works wherever it and its
        # inputs currently live.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embbeding_table(positions)  # (T, C) C=n_embd
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, C) C=vocab_size

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; skip building the autograd graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after idx.

        Args:
            idx: (B, T) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        for _ in range(max_new_tokens):
            # Only the last context_length tokens fit in the model.
            idx_cond = idx[:, -self.context_length:]
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [146]:
# Build the model from the config values and move it to the selected device.
# Each head gets an equal share of the embedding width.
model = BigramLanguageModel(
    vocab_size=vocab_size,
    n_embd=n_embd,
    num_heads=num_heads,
    head_size=n_embd // num_heads,
    context_length=context_length,
).to(device)

In [147]:
# AdamW over every model parameter, using the configured learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Total number of parameter elements in the model (displayed as the cell output).
sum(param.numel() for param in model.parameters())

10788929

In [148]:
# Training loop: every iteration draws one random training batch and takes a
# single optimizer step. The progress bar uses `tqdm.notebook.tqdm`, the
# replacement for the deprecated `tqdm.tqdm_notebook` wrapper.
for epoch in tqdm(range(epochs)):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the last batch processed.
print(loss.item())

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for epoch in tqdm_notebook(range(epochs)):


  0%|          | 0/1 [00:00<?, ?it/s]

4.287578105926514


In [124]:
# Sample 1000 new tokens starting from a single token with id 0 and decode
# them to text. The start tensor is created on `device` because the model was
# moved there when it was built; a CPU tensor would not match the model's
# weights when `device` is 'cuda'.
start = torch.zeros((1, 1), dtype=torch.long, device=device)
out = decode(model.generate(start, 1000)[0].tolist())
print(out)


and wo ticherge naive desens, alld papay my ar dive, ffore wen try him fordsiat or, and wase sowr with ait so hus me are treingear may,
An sconhe and fich sonith
I os nese'd epper pleesse,
Whan witte cee soneow we't me arwagh.

LOUS: wia we you hish We thes muse oun with deilse ry lity the for stry me do ad ley yeepel of a'd Foke
Corke deares,
I'll olst that you ace he caway;
Toslaize! ar! ELorjow he

garcuntlearwith the mest habilan wis. hey my I bow got at fir
For mu-akre baan, I we?

Ase wo?

MOV will is or leivowparaw with will'.

KI bathis noticee hiss kis: ad, his is woor wo when any my ast his coud ther ebs.

Butixswle, and mowrrinfewere sear andlarby fear and che shed
What brat. ce my home ader and cany bly yourowspener, of pounk, g; Porce swore's
Wootill noteades.
Thirase,
As of o, met buse lioour capunciruges:er,
Rouhke are! as themavee.

Then of qruce.

Kh an delot ary wis arsage my sbrounes, igheran en my knoake fore lied the she hit marnys enceand ecan uh, thou.

NRUEEN E

In [145]:
# before feed-forward 2.160
# after feed-forward 2.549
# block 2.354
# add residual connections 1.85