In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

block_size = 64
emb_dim = 32
batch_size = 64
head_dim = 32

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [88]:
# Create the dataset
with open('shakespeare.txt') as file:
    content = file.read()

vocabulary = list(set(list(content)))
stoi = {s:i for i,s in enumerate(vocabulary)}
itos = {i:s for i,s in enumerate(vocabulary)}
print(len(vocabulary))

tokenized = torch.tensor([stoi[char] for char in list(content)])
n = int(0.9*len(tokenized))
train_data = tokenized[:n]
val_data = tokenized[n:]
print(len(train_data), len(val_data))

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]) # For every letter, the target is the letter to the right
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

x, y = get_batch('train')
x.shape, y.shape # Check how for each data element (64 features) there are 64 targets. Thats is because self-attention generates an output for each position (this can be used to make the prediction of the next token)

65
1003854 111540


(torch.Size([64, 64]), torch.Size([64, 64]))

In [97]:
decode = lambda x: "".join([itos[ix.item()] for ix in x])

# Examples:
for i in range(8):
    print(f'\n\nEXAMPLE {i+1}:')
    for j in range(x.shape[1]):
        print(f'{decode(x[i,:j+1])} -> {decode([y[i,j]])}')



EXAMPLE 1:
o -> i
oi -> s
ois -> e
oise -> 

oise
 -> I
oise
I ->  
oise
I  -> t
oise
I t -> r
oise
I tr -> e
oise
I tre -> m
oise
I trem -> b
oise
I tremb -> l
oise
I trembl -> i
oise
I trembli -> n
oise
I tremblin -> g
oise
I trembling ->  
oise
I trembling  -> w
oise
I trembling w -> a
oise
I trembling wa -> k
oise
I trembling wak -> e
oise
I trembling wake -> d
oise
I trembling waked -> ,
oise
I trembling waked, ->  
oise
I trembling waked,  -> a
oise
I trembling waked, a -> n
oise
I trembling waked, an -> d
oise
I trembling waked, and ->  
oise
I trembling waked, and  -> f
oise
I trembling waked, and f -> o
oise
I trembling waked, and fo -> r
oise
I trembling waked, and for ->  
oise
I trembling waked, and for  -> a
oise
I trembling waked, and for a ->  
oise
I trembling waked, and for a  -> s
oise
I trembling waked, and for a s -> e
oise
I trembling waked, and for a se -> a
oise
I trembling waked, and for a sea -> s
oise
I trembling waked, and for a seas -> o
oise
I trembling w

In [85]:
class Head(nn.Module):
    def __init__(self,inp_dim, h_dim, block_size):
        super().__init__()
        self.Wq = nn.Parameter(torch.randn((inp_dim, h_dim), device=device) * inp_dim**-0.5)
        self.Wk = nn.Parameter(torch.randn((inp_dim, h_dim), device=device) * inp_dim**-0.5)
        self.Wv = nn.Parameter(torch.randn((inp_dim, h_dim), device=device) * inp_dim**-0.5)
        self.tril = torch.tril(torch.ones(block_size, block_size, device=device))

    def forward(self, x):
        B, T, C = x.shape
        Q = x @ self.Wq
        K = x @ self.Wk
        V = x @ self.Wv
        attention = Q @ K.transpose(-2, -1) # [B,T,C]@[B,C,T] = [B,T,T]
        attention *= Q.shape[-1]**-0.5
        masked_attention = attention.masked_fill(self.tril[:T,:T] == 0, -torch.inf) # Mask future tokens. The [:T,:T] is to be able to work with inputs of less than T tokens, this is helpful for generation
        att_weights = F.softmax(masked_attention, dim=2) # Regularize the weights
        return att_weights @ V # [B,T,T]@[B,T,C] = [B,T,C] Make a weighted average of the vectors the possition attends to

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(input_dim, head_dim//num_heads, block_size) for _ in range(num_heads)])
        self.linear = nn.Linear(head_dim, input_dim)

    def forward(self, x):
        att = torch.cat([h(x) for h in self.heads], dim=-1)
        att = self.linear(att)
        return att
    
class Block(nn.Module):
    def __init__(self, num_heads, emb_dim, block_size):
        super().__init__()
        self.selfAttention = MultiHeadAttention(emb_dim, num_heads, block_size)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ffd = nn.Sequential(
            nn.Linear(emb_dim, 4*emb_dim), # scale up
            nn.ReLU(),
            nn.Linear(4*emb_dim, emb_dim) # scale down again for compatibility with the residual connections
        )
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        att = self.selfAttention(x)
        att = self.ln1(att + x)
        x = self.ffd(x)
        x = self.ln2(att + x)
        return x


class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(len(vocabulary), emb_dim)
        self.pos_emb = nn.Embedding(block_size, emb_dim)

        self.blocks = nn.Sequential(*[Block(4, emb_dim, block_size) for _ in range(4)]) # Without the * it breaks

        self.final_dropout = nn.Dropout(0.4)

        self.final_proj = nn.Linear(head_dim, len(vocabulary))

    def forward(self, x, targets=None):
        emb1 = self.token_emb(x)
        # print(emb1.shape)
        emb2 = self.pos_emb(torch.arange(emb1.shape[1], device=device)) # Use emb.shape[1] instead of block_size because of inference. To being able to generate a prediction starting from 1 token with a max of block_size
        # print(emb2.shape)
        emb = emb1 + emb2
        # print(emb.shape)
        x = self.blocks(emb)

        x = self.final_dropout(x)
        # print(att.shape)
        logits = self.final_proj(x)
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(batch_size*block_size, -1), targets.view(batch_size*block_size))
        return logits, loss

In [86]:
@torch.no_grad
def estimate_loss():
    model.eval()
    samples = 10
    # Estimate train loss
    losses = []
    for i in range(samples):
        x, y = get_batch('train')
        _, loss = model(x, y)
        losses.append(loss)
    print(f'Train loss: {sum(losses)/samples:.2f}', end='\t')
    # Estimate Validation loss
    losses = []
    for i in range(samples):
        x, y = get_batch('valid')
        _, loss = model(x, y)
        losses.append(loss)
    print(f'Val loss: {sum(losses)/samples:.2f}')
    model.train()

In [87]:
import math

steps = 100000
learning_rate = 1e-3
lossi = []

model = GPTModel()
model.train()
model = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# Baseline loss
print('baseline loss', -math.log(1/len(vocabulary)))

for i in range(steps):
    x, y = get_batch('train')
    logits, loss = model(x, y)
    if i % 10000 == 0:
        estimate_loss()
    
    # Update the weights
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Track stats
    lossi.append(loss.item())
loss.item()

0.056705 M parameters
baseline loss 4.174387269895637
Train loss: 4.33	Val loss: 4.32
Train loss: 1.79	Val loss: 1.92
Train loss: 1.69	Val loss: 1.84
Train loss: 1.65	Val loss: 1.84


KeyboardInterrupt: 

In [98]:
# Generation
model.eval()
@torch.no_grad
def generate(max_tokens=500):
    sample = torch.tensor([[stoi['\n']]], device=device)
    while sample.shape[-1] < max_tokens:
        logits, loss = model(sample[:,:block_size])
        logits = logits[:,-1,:]
        # print(logits.shape)
        probs = F.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, 1)
        sample = torch.cat((sample, next_tok), dim=-1)
    return decode(sample[0])

print(generate())


Ghe nof man a-dandce attes one for
by sejothd's the escroken toy         l   l  u                         on               s             p omodf  i  tf               o      o    '       r   ,        n                      tl   ,         fo     ol          l   b      s-!    ,   
       m            os   n ,      o       o              . . s             ,            ,    n f       
    o                 t;y          n    m      '                 -            t n g    n     
e   ,   l  it       f 
