In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
with open('tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(len(text))
print(text[:1000])

1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for re

In [3]:
chars = sorted(list(set(text)))
ctoi = {c:i for i, c in enumerate(chars)}
itoc = {i:c for c, i in ctoi.items()}
encode = lambda x: [ctoi[c] for c in x]
decode = lambda x: ''.join([itoc[i] for i in x])

In [4]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [5]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train = data[:n]
val = data[n:]

In [16]:
torch.manual_seed(42442)
# params
vocab_size = len(chars)
batch_size = 32 
block_size = 64 # context length
max_iters = 100000
learning_rate = 4e-4
emb_dim = 32
num_head = 8

def get_batch(split):
    data = train if split == 'train' else val
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

## Attention
Each character should have:
1. a key - info about itself
2. a query - something to match with previous characters
3. a value - something thats output if theres matching

In [7]:
torch.manual_seed(42447)

# self-attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)

        # lower triangular mask keeps attention to past tokens only
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        wei = q@k.transpose(-2, -1) * self.head_size**-0.5 # scaling factor to keep variance constant
        # make all weights to future tokens to be -inf so softmax will give them 0.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # [:T, :T] allows handling of variable sized context lengths
        wei = F.softmax(wei, dim=-1)
        out = wei @ v
        
        return out


In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(emb_dim, emb_dim)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # cat all head outputs
        out = self.proj(out) # linear projection out
        return out

In [9]:
class FF(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(features, 4*features),
                                    nn.ReLU(),
                                    nn.Linear(4*features, features))
    
    def forward(self, x):
        return self.layer(x)

In [10]:
class TransformerBlock(nn.Module):

    def __init__(self, emb_dim, n_head):
        super().__init__()
        head_size = emb_dim // n_head
        self.heads = MultiHeadAttention(n_head, head_size)
        self.ffwd = FF(emb_dim)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.heads(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [11]:
torch.manual_seed(42442)

class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.positional_embedding = nn.Embedding(block_size, emb_dim)
        self.blocks = nn.Sequential(
            TransformerBlock(emb_dim, num_head),
            TransformerBlock(emb_dim, num_head),
            TransformerBlock(emb_dim, num_head),
        )
        self.norm = nn.LayerNorm(emb_dim)
        self.linear_out = nn.Linear(emb_dim, vocab_size)
    
    def forward(self, x, y=None):
        B, T = x.shape

        tok_embeddings = self.token_embedding(x) #  (Batch, Time, Channels=emb_dim) needs to be (Batch, Channels, Time) for F.cross_entropy
        pos_embeddings = self.positional_embedding(torch.arange(T)) # (Time, Channels=emb_dim)
        
        x = tok_embeddings + pos_embeddings
        x = self.blocks(x)
        x= self.norm(x)
        logits = self.linear_out(x) # (Batch, Time, Channels=vocab_size)

        if y is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # (Batch, Channels, Time)
            y = y.view(B*T)
            loss = F.cross_entropy(logits, y)
        return logits, loss

    def generate(self, idx, num_tokens=100):
        for _ in range(num_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # last time step only - prediction
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [12]:
m = LanguageModel()
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

for steps in range(max_iters):
    xb, yb = get_batch('train')

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 2500 == 0:
        print(loss.item())

4.2884745597839355
2.2887392044067383
2.153799295425415
1.9790315628051758
1.9709405899047852
1.921905517578125
1.8027561902999878
1.7355912923812866
1.6589696407318115
1.7431864738464355
1.6915522813796997
1.667686104774475
1.6288710832595825
1.5517337322235107
1.7228153944015503
1.6182876825332642
1.7119801044464111
1.615836501121521
1.6016979217529297
1.6036126613616943
1.618397831916809
1.6096761226654053
1.5946933031082153
1.6630442142486572
1.6356638669967651
1.6423982381820679
1.590080976486206
1.6452021598815918
1.642789602279663
1.5645372867584229
1.6224675178527832
1.5617073774337769
1.6046854257583618
1.5925236940383911
1.4771521091461182
1.623162865638733
1.5698789358139038
1.5553843975067139
1.506607174873352
1.5822010040283203


In [15]:
@torch.no_grad()
def estimate_val_loss():
    eval_iters = 200
    out = {}
    m.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        X, Y = get_batch(val)
        logits, loss = m(X, Y)
        losses[k] = loss.item()
    out[val] = losses.mean()
    m.train()
    return out

estimate_val_loss()

{tensor([12,  0,  0,  ..., 45,  8,  0]): tensor(1.7774)}

In [14]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(context, 3000)[0].tolist()))


I come torlow, good part in field but are grace,
That both one, lord of that each:
Aft I lod. Crivise and to to-jewle farewel:
My came your marks up to toward, let's ration
That estagers, nor a offent is and switer
With rast the son promp; the to your noble his
Brament, I'll years no make you his friend
Octremites wishs s erring in made faps quint
Should have made with And him nexessed home too
Upound the plaugant Potas desir.
Is that the far brial play this is bot?--
Percitue father crand.

Proth purness;
Pooldom, my lord!
Lord to that opes her shall the blood
With him ne'e-wrepsitie streride becans,
Or and my from and my will what we stones
if I never God make within his find on.
Ressure me addestagual, thousand lest we Clarence
Trut yield of his my this back anway.

LEONTES:
Prick it young unto conce of the bestiancions.

KING RICHARD ICI, find that in hath is my speak:
let your cun Lord; thy speak wome the pute bed,
The, and by unfring inciedinance!

MENENIUS:
Verman, by me to the