In [7]:
# Hyperparameters
#   n_embd     -- width of the token and position embeddings
#   batch_size -- number of independent sequences sampled per batch
#   block_size -- maximum context length (tokens per training sequence / position table size)
n_embd, batch_size, block_size = 32, 16, 32

In [8]:
# Let's turn our simple bigram model into a simple attention model.
# The model uses a single self-attention head, as derived in part 2.

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy
 
# Implementation of a single attention head   
    
class Head(nn.Module):
    
    def __init__(self, n_embd):
        super().__init__()
        
        self.key_repr = nn.Linear(n_embd, block_size, bias=False)
        self.query_repr = nn.Linear(n_embd, block_size, bias=False)
        self.value_repr = nn.Linear(n_embd, block_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, x: torch.Tensor):
        B, T, C = x.shape
        
        # create learned representations
        k = self.key_repr(x)
        q = self.query_repr(x)
        v = self.value_repr(x)
        
        # compute attention scores ('affinities between tokens')
        W = q @ k.transpose(-2,-1)  # (B,T,C) @ (B,C,T) -> (B,T,T)
        W *= C ** -0.5  # scaling of the dot product to keep softmax from saturating too much
        W = W.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T) # TODO: explain [:T, :T]
        W = F.softmax(W, dim=-1) # (B,T,T)
        
        # compute weighted aggregation of values
        out = W @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out        

In [9]:
# Implementation of the new model, which modifies the bigram model to use an attention head.
# For this we also change how the input is embedded:
# we add a learned positional embedding that tells the model at which position (time step) in the sequence each token sits.
# This is necessary because attention by itself is a position-agnostic communication mechanism: it does not take a token's position/time into account.

# In addition to the self-attention head, we also add a final linear layer called lm_head.
# It projects the output back to the size of the vocabulary, so we can compute a probability distribution over all characters in the vocabulary.

class SimpleAttentionModel(nn.Module):
    
    def __init__(self, n_embd: int, vocab_size: int):
        super().__init__()
        self.token_embeddings_table = nn.Embedding(vocab_size, n_embd)
        self.position_embeddings_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd=n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, inputs: torch.tensor, targets: torch.tensor = None):
        B, T = inputs.shape
        
        # inputs and targets are both (B,T) tensors of integers
        tok_emb = self.token_embeddings_table(inputs)
        pos_emb = self.position_embeddings_table(torch.arange(T)) # (T,C)
        
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x)  # apply one head of self-attention (B,T,C)
        logits = self.lm_head(x)  # (B,T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view (B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, inputs, max_new_tokens):
        # Generate new tokens one at a time, 
        # using now the full sequence of n = block_size tokens, weighted via self-attention,
        # to decide on the next generated token
        
        for _ in range(max_new_tokens):
            
            # make sure the inputs are at max block_size number of tokens
            # necessary because our position embedding can only encode position information for up to block_size tokens
            model_inputs = inputs[:, -block_size:]
            #print(f"Generate: inputs.shape = {inputs.shape}")
            
            logits, _ = self(model_inputs)  # shape: (B,T,C)
            # For generation, we only need the predictions from the last position
            probs = F.softmax(logits[:, -1, :], dim=-1)  # shape: (B,C)
            
            # Sample from the probability distribution to get the next token
            inputs_next = torch.multinomial(probs, num_samples=1)  # shape: (B,1)
            
            # Append the new token to our sequence
            inputs = torch.cat((inputs, inputs_next), dim=1)  # shape: (B,T+1)
        return inputs

In [10]:
# This cell adds context from part 1 so that we can perform a training run similar to the one for the bigram model.

# Read dataset
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    
# How we encode text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Dataset encoding
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

# Generating train and test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Generate random batch from the training split
def get_batch(split):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(data) - block_size, size=(batch_size,))
    x = torch.stack([data[i:i + block_size] for i in idxs])
    y = torch.stack([data[i+1:i+block_size+1] for i in idxs])
    return x,y  


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65
torch.Size([1115394]) torch.int64


In [11]:
model = SimpleAttentionModel(n_embd=n_embd, vocab_size=vocab_size)

inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(decoded_output)


.gRIuavUDT E:W KyHiHS
hY cEKmNeg3daaSGA.nQWSfQXuww$y;lQZUI
h:E-NSzUDZIjlFy$vIPC&?cwO&S,gv
 kh

VwiwMt!QWAZKn&L:J'.u;MOBrLm-:C3mAKOngcvn
WoVVQ!iDuKt3YK$rOGcGTKAIJupIO
GtOfGlEpT,;CxER$lEVJ&ewZ?BknZ-
hf:$ 
hFSzDhAc3'?DHWsV,L.XrXD.IuY3.3aWF,3VEG
&z!NorVa'mXMXSXZLFaq'ZImxYAMahxZGbKTin',Id3doCI;F
vuDcFipg,C,:g;3FminQxS'anNcZwVy;fx&;NXJ-
f3R;i
,aaUY
 dF'j$lK-3Wbs;&jhWDccwCci&FFZ
jticGBJPNAQciekv3dwfD&UE$zHAEMvq3;YZDYJ.!R
bqDRcGsose'g,Y'Qf.vYJ;pbkXMNI,cqbgEhBnrnHyN$gC'XHs
ElGJ,.z
pYE?rmyXQ:LpkzUIuLsBhlB


In [12]:

# Now we train our simple attention head model!
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for steps in range(10000):

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

# We can see that, compared to the bigram model's loss of roughly 2.57,
# we reach a lower loss of about 2.4 (2.39 in one run, 2.44 in another: the loss of a single random batch is noisy) with the same number of training steps!
# This is a nice improvement. However, if we now generate text with our trained model,
# it is still not "das Gelbe vom Ei" (German idiom, literally "the yolk of the egg", i.e. not great yet -- sry O_O)

2.442214250564575


In [13]:
inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(len(decoded_output))
print(decoded_output)

501

Why hary, omolurs itheande farr
'CHHien los hfid.
QUE:
Lostlo meaghorours
Fo sheras thasepies tava pu, wire dy I to, ge whe wred
T:
Hot!
SI thre sirsomeard, ofurd frak vexser eet thifr.
BRGAndld bere Rinoto's faly stem meakeet bre theexoceracmat Ork athe
JLETS:
O this.

Win blitilelyouslow hy ce
thangad ary yourf houin, kis farrt thit chas stsothite sefr f heargeame:
Pat un limit mones tit oter anst, lsives, shor aur?

Cormifir's t thiver adne shos fig myou peresworb,
NGrd woome.
NORWIIIS:
Horag
