In [413]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import matplotlib.pyplot as plt # for making figures
import re
import torch.nn as nn
%matplotlib inline

In [414]:
if torch.backends.mps.is_available():
    print("Using mps")
    torch.set_default_device("mps") #will run metal performance shaders on Mac for better performance

Using mps


In [415]:
text = open('game_of_thrones.txt', 'r').read()

In [416]:
print(text[:1000])

A Song of Ice and Fire

A Game of Thrones

PROLOGUE

We should start back, Gared urged as the woods began to grow dark around them.  The wildlings are dead.

Do the dead frighten you? Ser Waymar Royce asked with just the hint of a smile.

Gared did not rise to the bait. He was an old man, past fifty, and he had seen the lordlings come and go.  Dead is dead, he said.  We have no business with the dead.

Are they dead? Royce asked softly.  What proof have we?

Will saw them, Gared said.  If he says they are dead, that's proof enough for me.

Will had known they would drag him into the quarrel sooner or later. He wished it had been later rather than sooner.  My mother told me that dead men sing no songs, he put in.

My wet nurse said the same thing, Will, Royce replied.  Never believe anything you hear at a woman's tit. There are things to be learned even from the dead. His voice echoed, too loud in the twilit forest.

We have a long ride before us, Gared pointed out.  Eight days, maybe n

In [417]:
lines = text.splitlines()

In [418]:
special_chars = ',?;.:/*!+-()[]{}"\'&'
sentences = [re.sub(f'[{re.escape(special_chars)}]', ' \g<0> ', s).split(' ') for s in lines] #tokenize sentences into words
sentences = [[w for w in s if len(w)] for s in sentences] #remove null values

In [419]:
sentences[:8]

[['A', 'Song', 'of', 'Ice', 'and', 'Fire'],
 [],
 ['A', 'Game', 'of', 'Thrones'],
 [],
 ['PROLOGUE'],
 [],
 ['We',
  'should',
  'start',
  'back',
  ',',
  'Gared',
  'urged',
  'as',
  'the',
  'woods',
  'began',
  'to',
  'grow',
  'dark',
  'around',
  'them',
  '.',
  'The',
  'wildlings',
  'are',
  'dead',
  '.'],
 []]

In [420]:
vocab = set()
for s in sentences:
    for word in s:
        vocab.add(word)
vocab.add('<unk>')
vocab.add('<s>')
vocab.add('<pad>')
vocab.add('\n')
vocab_size = len(vocab)

In [421]:
vocab_size

13395

In [422]:
lookup = {word:i for i,word in enumerate(vocab)}
reverse_lookup = {value:key for key,value in lookup.items()}
encode = lambda x: [lookup[word] for word in x]    
decode = lambda x: ' '.join([reverse_lookup[c] for c in x])

In [423]:
encode('\n')

[10877]

In [424]:
data = []
for sentence in sentences:
    if len(sentence) == 0:
        data.append('\n')
        continue
    for word in sentence:
        data.append(word)

In [425]:
data = torch.tensor(encode(data), dtype=torch.long)

In [426]:
print(data.shape, data.dtype)

print(data[:10])
print(decode(data[:100].tolist()))

torch.Size([360857]) torch.int64
tensor([10218,  7756,  4517,  1902,  5017, 10379, 10877, 10218,  4141,  4517],
       device='mps:0')
A Song of Ice and Fire 
 A Game of Thrones 
 PROLOGUE 
 We should start back , Gared urged as the woods began to grow dark around them . The wildlings are dead . 
 Do the dead frighten you ? Ser Waymar Royce asked with just the hint of a smile . 
 Gared did not rise to the bait . He was an old man , past fifty , and he had seen the lordlings come and go . Dead is dead , he said . We have no business with the dead . 
 Are


In [427]:
n_training = int(0.8*len(data))
n_dev = int(0.9*len(data))
train_data = data[:n_training]
val_data = data[n_training:n_dev]
test_data = data[n_dev:]

In [428]:
seq_len = 10
x = train_data[:seq_len]
y = train_data[1:seq_len+1]
for t in range(seq_len):
    context = x[:t+1]
    target = y[t]
    context = decode(context.tolist())
    target = decode([target.item()])
    
    print(f"when input is {context}, target is {target}")

when input is A, target is Song
when input is A Song, target is of
when input is A Song of, target is Ice
when input is A Song of Ice, target is and
when input is A Song of Ice and, target is Fire
when input is A Song of Ice and Fire, target is 

when input is A Song of Ice and Fire 
, target is A
when input is A Song of Ice and Fire 
 A, target is Game
when input is A Song of Ice and Fire 
 A Game, target is of
when input is A Song of Ice and Fire 
 A Game of, target is Thrones


In [None]:
batch_size = 4
seq_len = 10
def get_batch(batch_size, seq_len, split='train'):
    datasets = {'train': train_data, 'val': val_data, 'test': test_data}
    dataset = datasets[split]
    ix = torch.randint(len(dataset) - seq_len, (batch_size,))
    x = torch.stack([dataset[i:i+seq_len] for i in ix])
    y = torch.stack([dataset[i+1:i+seq_len+1] for i in ix])
    return x, y

In [430]:
xb, yb = get_batch(batch_size, seq_len)

In [431]:
print(xb.shape, yb.shape)
print(xb) #input to transformer
print(yb)

torch.Size([4, 10]) torch.Size([4, 10])
tensor([[13191,  4973,  8219,  3768,  4364, 11823,  3183, 12923,  8452,  5086],
        [ 4676,  4500,  5017, 11569, 11746,  5329,  4973, 13346,  9686,  8452],
        [ 7733,  8121,  7384,  2689,  6100,  4349,  5677, 13284,  9375, 13154],
        [ 4973, 10877, 11741, 10053,  3768,  4364, 11174,  8446,  9803,  6100]],
       device='mps:0')
tensor([[ 4973,  8219,  3768,  4364, 11823,  3183, 12923,  8452,  5086, 12923],
        [ 4500,  5017, 11569, 11746,  5329,  4973, 13346,  9686,  8452,  9356],
        [ 8121,  7384,  2689,  6100,  4349,  5677, 13284,  9375, 13154,  8452],
        [10877, 11741, 10053,  3768,  4364, 11174,  8446,  9803,  6100, 11383]],
       device='mps:0')


In [432]:
for b in range(1):
    for t in range(seq_len):
        context = xb[b][:t+1]
        target = yb[b][t]
        context = decode(context.tolist())
        target = decode([target.item()])
        print(f"when input is: {context}, target is: {target}")
        

when input is: me, target is: .
when input is: me ., target is: What
when input is: me . What, target is: '
when input is: me . What ', target is: s
when input is: me . What ' s, target is: wrong
when input is: me . What ' s wrong, target is: with
when input is: me . What ' s wrong with, target is: you
when input is: me . What ' s wrong with you, target is: ,
when input is: me . What ' s wrong with you ,, target is: are
when input is: me . What ' s wrong with you , are, target is: you


In [433]:
emb = nn.Embedding(vocab_size, 24)
print(emb(xb).shape) # (B, T, C)

pe = torch.zeros(seq_len, 24)
print(pe[:seq_len, :24].shape)
print(pe.shape)

torch.Size([4, 10, 24])
torch.Size([10, 24])
torch.Size([10, 24])


In [434]:
class PositionalEncoding(nn.Module):

    def __init__(self, embd_dim, seq_len):
        super().__init__()
        self.pe = torch.zeros(seq_len, embd_dim) # (T, C)
        self.pe.requires_grad = False

        pos = torch.arange(0, seq_len).float()
        pos = pos.unsqueeze(dim=1)

        even_positions = torch.arange(0, embd_dim, 2).float()
        self.pe[:, 0::2] = torch.sin(pos / (10000 ** (even_positions / embd_dim)))
        self.pe[:, 1::2] = torch.cos(pos / (10000 ** (even_positions / embd_dim)))
        # compute sinusoidal positional embeddings as in the attention paper
        
        
    def forward(self, x): # x is of shape (B, T, C)
        B, T = x.shape
        return self.pe[:T, :] # (T, C)
              

In [435]:
class Head(nn.Module): #single head of self attention
    def __init__(self, n_head, embd_dim, seq_len, dropout):
        super().__init__()
        self.query = nn.Linear(embd_dim, n_head, bias=False)
        self.key = nn.Linear(embd_dim, n_head, bias=False)
        self.value = nn.Linear(embd_dim, n_head, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape # batch, time (seq_len), channel (# of features, or # of embedding dimensions)
        key = self.key(x) # (B, T, n_head)
        query = self.query(x) # (B, T, n_head)

        # compute attention scores
        wei = query @ key.transpose(-2, -1) * self.n_head ** -0.5  # (B, T, n_head) x (B, n_head, T) = (B, T, T)
        # internally, pyTorch is doing:
        # for every b in B (batch_dim) do: matrix multiply of (T, C) x (C, T) = (T, T) for every b = (B, T, T)
        #initialize with std dev of 0 and variance of 1 so that softmax works properly ( does not converge to one hot encodings)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        value = self.value(x) # (B, T, n_head)
        out = wei @ value # (B, T, T) x (B, T, n_head) = (B, T, n_head)
        return out      

In [436]:
embd_dim = 32
n_head = 5
emb = nn.Embedding(vocab_size, embd_dim)
pe = PositionalEncoding(embd_dim, seq_len)
h = Head(n_head, embd_dim, seq_len, 0.2)

In [437]:
x = emb(xb) + pe(xb)
x = h(x)
x.shape

torch.Size([4, 10, 5])

In [438]:
class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, head_size, embd_dim, seq_len, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, embd_dim, seq_len, dropout) for _ in range(n_heads)])
        self.proj = nn.Linear(embd_dim, embd_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [439]:
class FeedForward(nn.Module):

    def __init__(self, embd_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embd_dim, 4 * embd_dim),
            nn.ReLU(),
            nn.Linear(4 * embd_dim, embd_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [440]:
class Block(nn.Module):
    def __init__(self, embd_dim, n_heads, seq_len, dropout):
        super().__init__()
        head_size = embd_dim // n_heads
        self.sa = MultiHeadAttention(n_heads, head_size, embd_dim, seq_len, dropout)
        self.ffwd = FeedForward(embd_dim, dropout)
        self.ln1 = nn.LayerNorm(embd_dim)
        self.ln2 = nn.LayerNorm(embd_dim)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) #skip connection
        x = x + self.ffwd(self.ln2(x)) #skip connection
        return x

In [441]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embd_dim, seq_len, n_heads, n_layers, dropout):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embd_dim)
        self.pe = PositionalEncoding(embd_dim, seq_len)
        self.blocks = nn.Sequential(*[Block(embd_dim, n_heads, seq_len, dropout) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(embd_dim)
        self.lm_head = nn.Linear(embd_dim, vocab_size)
        self.seq_len = seq_len

    def forward(self, x, targets=None):
        B, T = x.shape
        x = self.emb(x) + self.pe(x)
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss   

    def generate(self, context, max_new_tokens):
        # context is (B, T) shape 
        for _ in range(max_new_tokens):
            context_max = context[:, -self.seq_len:] #crop context to get last seq_len tokens
            logits, loss = self(context_max) #get the predictions
            logits = logits[:, -1, :] # new shape is (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
            context = torch.cat((context, next_token), dim=1) # (B, T+1)
        return context

In [442]:
embd_dim = 16
batch_size = 8
n_heads = 4
n_layers = 4
dropout = 0.1
seq_len = 8

model = LanguageModel(vocab_size, embd_dim, seq_len, n_heads, n_layers, dropout)

In [443]:
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

0.454995 M parameters


In [474]:
learning_rate = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [445]:
def estimate_loss(eval_iters):
    with torch.no_grad():
        out = {}
        model.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(batch_size, seq_len, split)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out

In [475]:
eval_interval = 100
eval_iters = 100
max_iters = 1000
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(eval_iters)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch(batch_size, seq_len)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    

step 0: train loss 4.9999, val loss 5.6076
step 100: train loss 4.9774, val loss 5.5121
step 200: train loss 4.8922, val loss 5.4117
step 300: train loss 4.9755, val loss 5.4341
step 400: train loss 4.9473, val loss 5.4772
step 500: train loss 4.9234, val loss 5.5013
step 600: train loss 4.9295, val loss 5.4918
step 700: train loss 5.0022, val loss 5.3713
step 800: train loss 4.9329, val loss 5.4762
step 900: train loss 4.9393, val loss 5.4992
step 999: train loss 4.9339, val loss 5.5361


In [479]:
encode("He '".split(' '))

[13095, 3768]

In [480]:
# encode('\n') = 10877
context = torch.tensor([13095, 3768], dtype=torch.long).view(1, 2)
# context = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(context, max_new_tokens=100)[0].tolist()))

He ' s ride himself every year . Khal Drogo century sister ' s death was up at first particular friends , Marillion . 
 Give water , thin legs must make them you here out , part home , now . His men doors time you think he said Mormont gave him , the center of the east one and Catelyn ' s floor and shelters . You had told Bronn said Baelish truly see her . shuddered , so skulls around the stone parapets and at him , Queen 
 For a helm with them when I was because Jon
