# GPT

- The GPT-2 paper: https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
- The GPT-3 paper: https://arxiv.org/pdf/2005.14165
- The Transformer paper: https://arxiv.org/pdf/1706.03762
- Based on Karpathy's tutorial: https://www.youtube.com/watch?v=kCc8FmEb1nY&t=3925s
- Karpathy's nanoGPT: https://github.com/karpathy/nanoGPT

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from tqdm import trange

#### The dataset

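We use the Spanish text of *El ingenioso hidalgo don Quijote de la Mancha* (`./data/quijote.txt`) instead of the tinyshakespeare dataset used in Karpathy's tutorial.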

In [3]:
with open("./data/quijote.txt", 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:100])

El ingenioso hidalgo don Quijote de la Mancha


TASA

Yo, Juan Gallo de Andrada, escribano de Cámara


#### The tokenizer

For now we use a character-level tokenizer: each character is one token, and the GPT has to predict the next character. The advantage is that the vocabulary is very small; the drawback is that the text becomes a very long sequence of tokens, so a context window of fixed length covers only a short stretch of text.

Later we can use another kind of tokenizer, for instance a word-based one; the Hugging Face libraries can be used for this. Another example is OpenAI's `tiktoken`, which implements the byte-pair encoding (BPE) used by GPT-2.
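To make the trade-off concrete, here is a minimal sketch comparing a character-level split with a naive word-level split on the opening line of the novel. The word split is just a regular expression that separates words from punctuation, an illustrative stand-in and not how the Hugging Face or `tiktoken` tokenizers actually work.

```python
import re

sample = "En un lugar de la Mancha, de cuyo nombre no quiero acordarme"

# Character-level: one token per character (what this notebook uses)
char_tokens = list(sample)

# Naive word-level (illustration only): words and punctuation marks become separate tokens
word_tokens = re.findall(r"\w+|[^\w\s]", sample)

print(len(char_tokens), len(word_tokens))  # 60 13
print(word_tokens[:7])  # ['En', 'un', 'lugar', 'de', 'la', 'Mancha', ',']
```

The word-level sequence is about 4.6 times shorter, but its vocabulary would have to contain every distinct word form in the corpus instead of 91 characters.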

In [4]:
## Tokenizer
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("vocabulary: ", chars)
print(f"vocab size: {vocab_size}")

# tokenizer: we encode/decode strings into lists of integers (keep it simple). GPT-2 uses a BPE tokenizer instead (available through tiktoken)
char_to_idx = {ch:i for i,ch in enumerate(chars)}
idx_to_char = {i:ch for i,ch in enumerate(chars)}

encode = lambda string : [char_to_idx[ch] for ch in string]
decode = lambda array : ''.join([idx_to_char[i] for i in array])



vocabulary:  ['\n', ' ', '!', '"', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'x', 'y', 'z', '¡', '«', '»', '¿', 'Á', 'É', 'Í', 'Ñ', 'Ó', 'Ú', 'à', 'á', 'é', 'í', 'ï', 'ñ', 'ó', 'ù', 'ú', 'ü']
vocab size: 91


In [5]:
print(encode("La casa roja"))
print(decode(encode("La casa roja")))

[31, 47, 1, 49, 47, 64, 47, 1, 63, 60, 56, 47]
La casa roja


In [6]:
# Encode all data
data = torch.tensor(encode(text), dtype=torch.int64)
print(f"{data.shape=}, {data.dtype=}")
print(data[:100])

data.shape=torch.Size([2097953]), data.dtype=torch.int64
tensor([25, 57,  1, 55, 59, 53, 51, 59, 55, 60, 64, 60,  1, 54, 55, 50, 47, 57,
        53, 60,  1, 50, 60, 59,  1, 36, 66, 55, 56, 60, 65, 51,  1, 50, 51,  1,
        57, 47,  1, 32, 47, 59, 49, 54, 47,  0,  0,  0, 39, 21, 38, 21,  0,  0,
        44, 60,  7,  1, 30, 66, 47, 59,  1, 27, 47, 57, 57, 60,  1, 50, 51,  1,
        21, 59, 50, 63, 47, 50, 47,  7,  1, 51, 64, 49, 63, 55, 48, 47, 59, 60,
         1, 50, 51,  1, 23, 82, 58, 47, 63, 47])


##### Create the train/validation splits

The data is just one long one-dimensional tensor of integers. We split it into a training set (the first 90%) and a validation set (the last 10%).

In [7]:
# create the training/validation splits
n = int(0.9 * len(data))
data_train = data[:n]
data_val = data[n:]
print(f"{data_train.shape=}, {data_train.dtype=}")
print(f"{data_val.shape=}, {data_val.dtype=}")

data_train.shape=torch.Size([1888157]), data_train.dtype=torch.int64
data_val.shape=torch.Size([209796]), data_val.dtype=torch.int64


##### What is the input and what is the target?

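- We feed the model sequences (chunks) of tokens rather than the whole text. The length of a chunk is called the context length (`block_size`); its maximum is fixed beforehand.
- We also pass several chunks at once, in batches (`batch_size`).

The target: given a sequence of characters, the target is the next character. Since this holds for every prefix of a chunk, a chunk of `block_size` tokens contains `block_size` training examples.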

In [8]:
# get a batch of data.
torch.manual_seed(1234)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # batches are placed on the same device as the model

def get_batch(split : str, batch_size : int, block_size : int): # returns x, y of shape (batch_size, block_size)
    dat = data_train if split == "train" else data_val
    ix = torch.randint(0, dat.shape[0] - block_size, (batch_size,)) # batch_size random starting indices
    x = torch.stack([dat[i:i+block_size] for i in ix])
    y = torch.stack([dat[i+1:i+block_size+1] for i in ix]) # the same chunks shifted by one position: the next tokens
    return x.to(device), y.to(device)


batch_size = 4 # how many sequences (randomly sampled) are fed to the transformer at once
block_size = 8 # the maximum context length for predictions
xb, yb = get_batch('train', batch_size=batch_size, block_size=block_size)
print(xb.shape, yb.shape)
print(xb, "\n", yb)

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When the input is {context.tolist()} the target is {target}")
    break

# This effectively contains 4 * 8 = 32 independent examples.

torch.Size([4, 8]) torch.Size([4, 8])
tensor([[ 1, 51, 59,  1, 57, 47, 64,  1],
        [60,  9,  0, 35, 60, 63, 62, 66],
        [51, 64, 65, 55, 58, 47, 50, 47],
        [55, 59, 65, 51, 59, 49, 55, 87]]) 
 tensor([[51, 59,  1, 57, 47, 64,  1, 65],
        [ 9,  0, 35, 60, 63, 62, 66, 51],
        [64, 65, 55, 58, 47, 50, 47,  1],
        [59, 65, 51, 59, 49, 55, 87, 59]])
When the input is [1] the target is 51
When the input is [1, 51] the target is 59
When the input is [1, 51, 59] the target is 1
When the input is [1, 51, 59, 1] the target is 57
When the input is [1, 51, 59, 1, 57] the target is 47
When the input is [1, 51, 59, 1, 57, 47] the target is 64
When the input is [1, 51, 59, 1, 57, 47, 64] the target is 1
When the input is [1, 51, 59, 1, 57, 47, 64, 1] the target is 65


#### The Model

In [118]:
# Hyper params
vocab_size = len(chars)
block_size = 256
n_embed = 384
num_heads = 6
num_layers = 6

batch_size = 64
learning_rate = 3e-4

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
max_iters = 5000
eval_iters = 20
eval_interval = 200
dropout = 0.2

##########

torch.manual_seed(1234)
xb, yb = get_batch('train', batch_size=batch_size, block_size=block_size)
print(f"{xb.shape=}")

xb.shape=torch.Size([64, 256])


In [119]:
# ------------------------------
# 
# Token Embedding
#
# ------------------------------

# Each token in the sequence (an integer in [0, vocab_size)) is assigned a learned vector of size n_embed from the embedding table.
# This encodes the identity of each token

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size : int, n_embed : int):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embed = n_embed
        self.embedding = nn.Embedding(vocab_size, n_embed)

    def forward(self, x):
        # multiplied by sqrt(d_model) in the paper
        return self.embedding(x) * np.sqrt(self.n_embed)  # (batch, block_size) -> (batch, block_size, n_embed)

# Test
token_embedding = TokenEmbedding(vocab_size=vocab_size, n_embed=n_embed)
print(f"{token_embedding(xb).shape=}")


token_embedding(xb).shape=torch.Size([64, 256, 384])
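`nn.Embedding` is just a trainable lookup table. A small sketch with toy sizes (and without the `sqrt(n_embed)` scaling used above) showing that the lookup gives the same result as multiplying one-hot vectors by the embedding matrix:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

V, C = 10, 4                          # toy vocab size and embedding size, for illustration only
emb = nn.Embedding(V, C)

idx = torch.tensor([[1, 3, 3, 7]])    # (batch=1, T=4) token indices

# nn.Embedding picks rows of its (V, C) weight matrix ...
lookup = emb(idx)                     # (1, 4, C)

# ... which is the same as multiplying one-hot vectors by that matrix
one_hot = F.one_hot(idx, num_classes=V).float()  # (1, 4, V)
matmul = one_hot @ emb.weight                    # (1, 4, C)

print(torch.allclose(lookup, matmul))  # True
```

The lookup is what PyTorch does in practice, since it avoids building the large one-hot matrix.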


In [120]:
# ------------------------------
# 
# Positional Embedding
#
# ------------------------------

# We also encode the position

class PositionalEmbedding(nn.Module):
    def __init__(self, block_size : int, n_embed : int):
        super().__init__()
        self.context_length = block_size
        self.n_embed = n_embed
        self.embedding = nn.Embedding(block_size, n_embed)

    def forward(self, x):
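        # The input is the tensor of positions (0, 1, ..., T-1), with T <= block_size
        # Note: the Transformer paper scales only the token embeddings by sqrt(d_model); scaling the positional embedding too is a choice of this notebook
        return self.embedding(x) * np.sqrt(self.n_embed)  # (block_size) -> (block_size, n_embed)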

# Test
positional_embedding = PositionalEmbedding(block_size=block_size, n_embed=n_embed)
print(positional_embedding(torch.arange(block_size)).shape)

torch.Size([256, 384])
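In `Transformer.forward` the `(T, n_embed)` positional embedding is added to the `(batch, T, n_embed)` token embedding by broadcasting. A toy-size sketch of what that addition does:

```python
import torch

B, T, C = 2, 5, 4                     # toy sizes, not the notebook's hyperparameters
tok_emb = torch.randn(B, T, C)        # one vector per token, per sequence
pos_emb = torch.randn(T, C)           # one vector per position, shared by all sequences

x = tok_emb + pos_emb                 # (T, C) is broadcast to (B, T, C)
print(x.shape)  # torch.Size([2, 5, 4])

# every sequence in the batch receives the same positional vector at position t
print(torch.allclose(x - tok_emb, pos_emb.expand(B, T, C), atol=1e-6))  # True
```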


In [121]:
# ------------------------------
# 
# Projection Layer
#
# ------------------------------

#  This is just a linear layer that projects from embedding space to vocab space. It is applied at the end of the transformer

class ProjectionLayer(nn.Module):
    def __init__(self, vocab_size : int, n_embed : int):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embed = n_embed
        self.projection_layer = nn.Linear(n_embed, vocab_size)

    def forward(self, x):
        return self.projection_layer(x)   # (batch, block_size, n_embed) -> (batch, block_size, vocab_size)

# Test
projection_layer = ProjectionLayer(vocab_size=vocab_size, n_embed=n_embed)
xt = torch.randn(batch_size, block_size, n_embed)
print(projection_layer(xt).shape)

torch.Size([64, 256, 91])


In [122]:
# ------------------------------
# 
# Multi-Head Attention
#
# ------------------------------

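# 1. Multi-head attention splits the embedding dimension into num_heads smaller heads of size head_size = n_embed // num_heads; each head computes its own attention.
# 2. The number of heads must therefore divide the embedding dimension n_embed (not the sequence dimension)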

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, n_embed : int, num_heads : int, block_size : int , dropout : float):
        super().__init__()
        assert n_embed % num_heads == 0, "num_heads must divide n_embed" # check that n_embed can be divided by the num of heads
        self.n_embed = n_embed
        self.num_heads = num_heads
        self.head_size = n_embed // num_heads

        self.query = nn.Linear(n_embed, n_embed, bias=False) # Wq
        self.key = nn.Linear(n_embed, n_embed, bias=False) # Wk
        self.value = nn.Linear(n_embed, n_embed, bias=False) # Wv
        self.proj = nn.Linear(n_embed, n_embed, bias=False) # Wo

        self.register_buffer('causal_mask', torch.tril(torch.ones(block_size, block_size))) # (block_size, block_size) lower diagonal matrix
   
        self.dropout = None if dropout is None else nn.Dropout(dropout)
        self.attention_scores = None


    def forward(self, x):
        batch, x_block_size, n_emb = x.shape
        
        q = self.query(x) # (batch, x_block_size, n_emb) @ (n_emb, n_emb) -> (batch, x_block_size, n_emb)
        k = self.key(x) 
        v = self.value(x) 

        # Split across the embedding dimension and rearange.
        # (batch, x_block_size, n_emb) -> (batch, x_block_size, num_head, head_size) ->  (batch, num_head, x_block_size, head_size)
        q = q.view(q.shape[0],  q.shape[1], self.num_heads, self.head_size).transpose(1,2) # (batch, num_head, x_block_size, head_size)
        k = k.view(k.shape[0],  k.shape[1], self.num_heads, self.head_size).transpose(1,2)
        v = v.view(v.shape[0],  v.shape[1], self.num_heads, self.head_size).transpose(1,2)

        # Get the weights
        weights = q @ k.transpose(-1, -2) / np.sqrt(self.head_size) # (batch, num_head, x_block_size, head_size) @  (batch, num_head, head_size, x_block_size) -> (batch, num_head, x_block_size, x_block_size)

        # Causal mask 
        weights = weights.masked_fill(self.causal_mask[:x_block_size, :x_block_size] == 0, float('-inf')) # (batch, num_head, x_block_size, x_block_size)

        # Apply softmax (the masked -inf entries become 0)
        weights = torch.softmax(weights, dim=-1) # softmax over each row (batch, num_heads, x_block_size, x_block_size)
        
        # Apply dropout
        # if self.dropout is not None: weights = self.dropout(weights)
        
        # Finally multiply by the values
        out =  weights @ v # (batch, num_head, x_block_size, x_block_size) @ (batch, num_head, x_block_size, head_size) -> (batch, num_head, x_block_size, head_size)
        
        # save the weights
        self.attention_scores = weights
        
        # Concat the heads
        # (batch, num_heads, x_block_size, head_size) -> (batch, x_block_size, num_heads, head_size) -> (batch, x_block_size, n_embed)
        out = out.transpose(1,2).contiguous().view(out.shape[0], -1, self.num_heads * self.head_size) # (batch, x_block_size, n_embed). 
        
        out = self.proj(out) # (batch, x_block_size, n_embed) @ (n_embed, n_embed) -> (batch, x_block_size, n_embed)

        # Apply dropout
        if self.dropout is not None: out = self.dropout(out)
        
        return out
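As a sanity check of the attention core (scores, causal mask, softmax, weighted sum of the values), we can compare the same steps against PyTorch's fused `F.scaled_dot_product_attention`. This sketch assumes PyTorch 2.0 or later, where that function and its `is_causal` flag exist; `manual_causal_attention` is a hypothetical helper that leaves out the output projection and dropout of `MultiHeadAttentionBlock`.

```python
import math
import torch
import torch.nn.functional as F

def manual_causal_attention(q, k, v):
    """Hypothetical helper: the attention steps of MultiHeadAttentionBlock.forward, without projections or dropout."""
    T = q.shape[-2]
    weights = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])     # (B, H, T, T) scaled scores
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))           # lower-triangular causal mask
    weights = weights.masked_fill(~mask, float('-inf'))             # hide future tokens
    weights = torch.softmax(weights, dim=-1)                        # each row sums to 1
    return weights @ v                                              # (B, H, T, head_size)

q, k, v = (torch.randn(2, 3, 5, 8) for _ in range(3))  # (batch, heads, T, head_size)

ours = manual_causal_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)    # default scale is 1/sqrt(head_size)
print(torch.allclose(ours, fused, atol=1e-5))
```

In a real model the fused call can replace the explicit matrix of weights, but then the weights are no longer available to store in `attention_scores`.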

In [123]:
# ------------------------------
# 
# Feed Forward
#
# ------------------------------

class FeedForward(nn.Module):
    def __init__(self, n_embed : int,  dropout : float):
        super().__init__()
        self.n_embed = n_embed
        self.ff1 = nn.Linear(n_embed, 4*n_embed) # hidden dim hard-coded to 4 * n_embed (the paper uses d_ff = 2048 = 4 * 512)
        self.ff2 = nn.Linear(4*n_embed, n_embed)
        
        self.dropout = nn.Dropout(dropout) 
    
    def forward(self, x):
        x = F.relu(self.ff1(x)) # (batch, x_block_size, n_embed) -> (batch, x_block_size, 4 * n_embed)
        x = self.ff2(x) # (batch, x_block_size, 4*n_embed) -> (batch, x_block_size, n_embed)
        x = self.dropout(x)
        return x
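The feed-forward block acts on each position independently (the "position-wise" feed-forward network of the Transformer paper): tokens exchange information only in the attention layer. A toy-size sketch checking this, using an `nn.Sequential` with the same layer shapes as `FeedForward` (dropout omitted):

```python
import torch
import torch.nn as nn

C = 8                                       # toy embedding size
ffn = nn.Sequential(nn.Linear(C, 4 * C), nn.ReLU(), nn.Linear(4 * C, C))

x = torch.randn(2, 5, C)                    # (batch, T, n_embed)
out_all = ffn(x)                            # all positions at once
out_one = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)  # one position at a time

print(torch.allclose(out_all, out_one, atol=1e-6))
```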

In [124]:
# ------------------------------
# 
# Layer Norm
#
# ------------------------------

# Check the notes
# Layer norm normalizes each token's embedding vector over its features (the last dimension, n_embed),
# independently of the other tokens and of the batch. Batch norm, by contrast, normalizes each feature across the batch.
# This is taken from Karpathy's tutorial.

class LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # nn.Parameter, so gamma and beta are trained and moved to the device together with the model
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        xmean = x.mean(-1, keepdim=True) # mean over the features of each token
        xvar = x.var(-1, keepdim=True) # variance over the features (unbiased: divides by dim - 1, nn.LayerNorm divides by dim)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to zero mean and unit variance
        return self.gamma * xhat + self.beta


# Test
layer_norm = LayerNorm(2)
with torch.no_grad():
    out = layer_norm(torch.randn(16, 8, 2))
print(out.shape)
print(out[0, :].std()) # with dim = 2 every normalized value is about ±0.707, hence std ≈ 0.73 rather than 1
print(out[0, :].mean()) # ≈ 0: each token (row) was normalized over its 2 features

torch.Size([16, 8, 2])
tensor(0.7303)
tensor(-3.7253e-09)


In [125]:
# ------------------------------
# 
# Transformer Layer
#
# ------------------------------


class Layer(nn.Module):
    def __init__(self, n_embed: int, num_heads: int, block_size : int,  dropout: float):
        super().__init__()
        self.self_attention = MultiHeadAttentionBlock(n_embed=n_embed, num_heads=num_heads, block_size=block_size, dropout=dropout)
        self.feed_forward = FeedForward(n_embed=n_embed, dropout=dropout)
        self.layer_norm1 = LayerNorm(n_embed)
        self.layer_norm2 = LayerNorm(n_embed)


    def forward(self, x):
        # Without the residual connections and layer norms (as below), training stalls at a loss of about 2.33 with 4 layers; the gradients probably vanish through the deep stack
        # x = self.self_attention(x)
        # x = self.feed_forward(x) 

        # Add the residual connections + layer norms
        x = x + self.self_attention(self.layer_norm1(x)) # (batch, x_block_size, n_embed) ->  (batch, x_block_size, n_embed)
        x = x + self.feed_forward(self.layer_norm2(x)) # (batch, x_block_size, n_embed) ->  (batch, x_block_size, n_embed)
        return x

In [126]:
class Transformer(nn.Module):
    def __init__(self, vocab_size : int, n_embed : int, num_heads: int, num_layers: int,  block_size : int):
        super().__init__()
        self.vocab_size = vocab_size 
        self.n_embed = n_embed
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.block_size = block_size

        self.token_embedding = TokenEmbedding(vocab_size=vocab_size, n_embed=n_embed)
        self.positional_embedding = PositionalEmbedding(block_size=block_size, n_embed=n_embed)

        # attention
        # self.sa_block = MultiHeadAttentionBlock(n_embed=n_embed, num_heads=num_heads, block_size=block_size, dropout=None)
        # self.ff_block = FeedForward(n_embed=n_embed, hid_dim=4*n_embed, dropout=0.0)

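        # dropout is fixed to 0.0 here, so the `dropout` hyperparameter defined above is not used
        self.layers = nn.Sequential(
            *[ Layer(n_embed=n_embed, num_heads=num_heads, block_size=block_size, dropout=0.0) for _ in range(num_layers)] )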


        self.projection_layer = ProjectionLayer(vocab_size=vocab_size, n_embed=n_embed)

        # Initialize the parameters
        self._init_parameters()

    def forward(self, idx, targets=None):
        batch_size, idx_block_size = idx.shape
        # print(f"{block_size=}")

        device = idx.device 

        tok_emb = self.token_embedding(idx) # (batch, block_size) -> (batch, block_size, n_embed)
        pos_emb = self.positional_embedding(torch.arange(idx_block_size, device=device)) # (block_size, n_embed)
        x = tok_emb + pos_emb # broadcasting works. (batch, block_size, n_embed)

        # attention
        # x = self.sa_block(x) # (batch, idx_block_size, n_embed) -> (batch, idx_block_size, n_embed)
        # x = self.ff_block(x) # (batch, idx_block_size, n_embed)
        x = self.layers(x) # (batch, idx_block_size, n_embed) -> (batch, idx_block_size, n_embed)


        # the last layer
        logits = self.projection_layer(x) # (batch, block_size, n_embed) -> (batch, block_size, vocab_size)

        loss = None
        if targets is not None:
            # cross entropy expects logits of shape (N, C). Each position in a sequence is an independent prediction, so we merge it with the batch dimension
            logits = logits.view(logits.shape[0] * idx_block_size, self.vocab_size) # (batch, block_size, vocab_size) -> (batch * block_size, vocab_size)
            # The targets are of the form (batch, block_size)
            targets = targets.view(targets.shape[0] * idx_block_size) # (batch * block_size,). F.cross_entropy accepts class indices directly, so no one-hot encoding is needed
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop the idx so that only take the maximum block size.
            idx_crop = idx[:, -self.block_size:]
            # forward pass the model
            logits, _ = self.forward(idx_crop) # (batch = 1 for one prediction, block_size) -> (batch, block_size, vocab_size)
            # get the last token, we want to predict the next token from the last (given all the context, etc...)
            logits = logits[:, -1, :] # (batch, block_size, vocab_size) -> (batch, vocab_size)
            # apply softmax to get a probability distribution across the vocabulary
            probs = F.softmax(logits, dim=-1) 

            # We could take the most probable token (greedy decoding); instead we sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (batch, 1)
            
            # Append the token to the previous one, that serves as a new context
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    

    # Xavier-uniform initialization of every weight matrix (biases and other 1-D parameters are skipped).
    # This overrides PyTorch's default initialization (Kaiming-uniform for nn.Linear, N(0, 1) for nn.Embedding).
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)

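# Generate text, starting from token 0 (the newline character '\n')
@torch.no_grad()
def write_text(model, max_tokens, decode, filename=None):
    device = next(model.parameters()).device # create the context on the same device as the model
    initial_context = torch.zeros((1, 1), dtype=torch.int64, device=device) # (batch_size = 1, block_size = 1)
    text = decode(model.generate(initial_context, max_new_tokens=max_tokens)[0].tolist())
    if filename is None:
        print(text)
    else:
        with open(filename, 'w', encoding='utf-8') as f:
            print(text, file=f)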
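`generate` draws the next token with `torch.multinomial`; the comment in it mentions greedy decoding (always taking the most probable token) as the alternative. A minimal sketch of both, with a temperature knob added as a common extra that this notebook does not use; `sample_next` is a hypothetical helper, not part of the model above.

```python
import torch
import torch.nn.functional as F

def sample_next(logits, temperature=1.0, greedy=False):
    """Hypothetical helper: pick the next token from logits of shape (batch, vocab_size)."""
    if greedy:
        return torch.argmax(logits, dim=-1, keepdim=True)   # most probable token, (batch, 1)
    probs = F.softmax(logits / temperature, dim=-1)         # temperature < 1 sharpens, > 1 flattens
    return torch.multinomial(probs, num_samples=1)          # random draw, (batch, 1)

logits = torch.tensor([[2.0, 1.0, 0.1]])
print(sample_next(logits, greedy=True))        # tensor([[0]])
print(sample_next(logits, temperature=0.8))    # random: one of the indices 0, 1, 2
```

Greedy decoding is deterministic but tends to repeat itself on a character-level model; sampling gives more varied text.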

In [127]:
# Test batch
model = Transformer(vocab_size=vocab_size, n_embed=n_embed, num_heads=4, num_layers=4, block_size=block_size)
model.to(device)
xt, yt = get_batch('train', batch_size=batch_size, block_size=block_size)
logitst, losst = model(xt, yt)
print(logitst.shape, losst.item())

torch.Size([16384, 91]) 8.133414268493652
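A useful sanity check on this number: a freshly initialized model that predicted the 91 characters uniformly would have a loss of $-\ln(1/91) = \ln 91 \approx 4.51$. The printed 8.13 (and 8.69 at step 0 of training) is clearly higher, which suggests the untrained model makes confidently wrong predictions. One plausible contributor, not verified here, is the `sqrt(n_embed)` ≈ 19.6 scaling applied to both embeddings. A sketch of the uniform baseline:

```python
import math
import torch
import torch.nn.functional as F

V = 91                                     # vocab_size of this corpus
N = 16                                     # a few dummy positions
uniform_logits = torch.zeros(N, V)         # equal logits -> uniform distribution over the V tokens
dummy_targets = torch.randint(0, V, (N,))  # with uniform predictions, any targets give the same loss

uniform_loss = F.cross_entropy(uniform_logits, dummy_targets)
print(uniform_loss.item(), math.log(V))  # both ≈ 4.5109
```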


#### Training

In [128]:
# ------------------------
#
# Training loop
#
# ------------------------


model = Transformer(vocab_size=vocab_size, n_embed=n_embed, num_heads=num_heads, num_layers=num_layers, block_size=block_size)
model.to(device)


# The optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Estimate loss
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size=batch_size, block_size=block_size)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# Test before training
write_text(model, max_tokens=1000, decode=decode, filename='./output/random.txt')

In [129]:

for step in range(max_iters):

    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # get batch
    xb, yb = get_batch('train', batch_size=batch_size, block_size=block_size)
    #forward
    logits, loss = model(xb, yb)
    #backward
    optimizer.zero_grad(set_to_none=True) # first we set to zero the gradients
    loss.backward() # compute the gradients
    optimizer.step() # update the parameters
    

step 0: train loss 8.6909, val loss 8.6736
step 200: train loss 2.3514, val loss 2.3619
step 400: train loss 2.3079, val loss 2.3128
step 600: train loss 2.2234, val loss 2.2232
step 800: train loss 2.0193, val loss 2.0335
step 1000: train loss 1.8432, val loss 1.8593
step 1200: train loss 1.7506, val loss 1.7587
step 1400: train loss 1.6447, val loss 1.6652
step 1600: train loss 1.5727, val loss 1.6028
step 1800: train loss 1.5128, val loss 1.5332
step 2000: train loss 1.4621, val loss 1.4933
step 2200: train loss 1.4192, val loss 1.4588
step 2400: train loss 1.3781, val loss 1.4159
step 2600: train loss 1.3531, val loss 1.3924


KeyboardInterrupt: 

In [130]:
# Save the model
checkpoint = { "model": model.state_dict(), "optimizer": optimizer.state_dict() }
torch.save(checkpoint, "checkpoint.pt")
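To resume training or generate text later, the checkpoint is loaded back into freshly built objects with the same hyperparameters. A minimal sketch using a toy `nn.Linear` and a temporary file as stand-ins for the notebook's `Transformer` and `checkpoint.pt`; the dictionary keys are the same as in the cell above.

```python
import os
import tempfile
import torch
import torch.nn as nn

# Toy stand-ins for the notebook's model and optimizer (any nn.Module / optimizer works the same way)
net = nn.Linear(4, 2)
opt = torch.optim.AdamW(net.parameters(), lr=3e-4)

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"model": net.state_dict(), "optimizer": opt.state_dict()}, path)

# Later, e.g. in a new session: rebuild the objects with the same hyperparameters, then restore their state
net2 = nn.Linear(4, 2)
opt2 = torch.optim.AdamW(net2.parameters(), lr=3e-4)
ckpt = torch.load(path, map_location="cpu")  # map_location: load GPU-saved tensors on a CPU-only machine
net2.load_state_dict(ckpt["model"])
opt2.load_state_dict(ckpt["optimizer"])

print(torch.equal(net.weight, net2.weight))  # True
```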

In [134]:
write_text(model, max_tokens=5000, decode=decode, filename='./output/trained.txt')

## Notes and tests on the attention mechanism

In [115]:
B, T, D = 4, 8, 2 # batch, block_size, embedding dimension (channels)
x_in = torch.randn(B, T, D) # plays the role of the values: the private information of each token

tril = torch.tril(torch.ones((T,T)))
wei = torch.zeros((T,T)) # Later these weights will be the dot products between the queries (what I look for) and the keys (what I have). Here they are all zero, so the softmax gives a uniform average over the allowed tokens
wei = wei.masked_fill(tril == 0, float('-inf')) # make the tokens only interact with the past tokens 
wei = F.softmax(wei, dim=-1) # normalize it
wei # this will be an interaction matrix between the tokens
print(f"{wei=}, {wei.shape=}")
x_out = wei @ x_in # x_in acts as the values
print(f"{x_out[0]=}, {x_out.shape=}")

wei=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]), wei.shape=torch.Size([8, 8])
x_out[0]=tensor([[ 5.4973e-02,  9.7783e-01],
        [ 7.6479e-01,  1.2942e-01],
        [ 6.9221e-01,  2.4644e-01],
        [ 4.8400e-01,  2.6666e-02],
        [ 2.3302e-01, -4.3245e-01],
        [ 6.8370e-02,  1.3068e-01],
        [ 2.4524e-01,  1.7378e-04],
        [ 1.0304e-01, -8.5936e-02]]), x_out.shape=torch.Size([4, 8, 2])
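Since all the unmasked weights are equal, `wei @ x_in` is just a running average: row $t$ of `x_out` is the mean of `x_in[:, :t+1]` (compare the first row of `x_out[0]`, which is the first token itself). A quick sketch checking this against a cumulative sum:

```python
import torch
import torch.nn.functional as F

B, T, D = 4, 8, 2
x_in = torch.randn(B, T, D)

# Masked uniform weights, as in the cell above
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
x_out = wei @ x_in

# Same result without a matrix: running sum divided by the number of tokens seen so far
counts = torch.arange(1, T + 1).view(1, T, 1)   # 1, 2, ..., T
x_avg = x_in.cumsum(dim=1) / counts

print(torch.allclose(x_out, x_avg, atol=1e-6))  # True
```

The matrix form is what generalizes: once `wei` comes from query-key dot products, each token gets its own data-dependent weights instead of a plain average.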


In [None]:

# ------------------------------
# 
# Head of Self-Attention
#
# ------------------------------

# Receives an input of shape (batch, block_size, head_size) and returns an output of the same shape

class Head(nn.Module):
    def __init__(self, head_size : int, block_size : int):
        super().__init__()
        self.head_size = head_size

        self.query = nn.Linear(head_size, head_size, bias=False) # The Wq matrix, just a layer of interaction
        self.key = nn.Linear(head_size, head_size, bias=False) # The Wk matrix
        self.value = nn.Linear(head_size, head_size, bias=False) # The Wv matrix
        
        # Causal mask for decoder attention. Registered as a buffer: it moves with the module (.to(device)) but is not trained
        self.register_buffer('mask', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        batch, x_block_size, head_size = x.shape

        q = self.query(x) # (batch, x_block_size, head_size) -> (batch, x_block_size, head_size) 
        k = self.key(x)
        v = self.value(x)

        # compute the weights
        wei = q @ k.transpose(-2,-1) / np.sqrt(head_size) # (batch, x_block_size, head_size) @ (batch, head_size, x_block_size) -> (batch, x_block_size, x_block_size)

        # mask the affinities, only interaction with past tokens
        wei = wei.masked_fill(self.mask[:x_block_size, :x_block_size] == 0, float('-inf'))
        
        # normalize
        wei = torch.softmax(wei, dim=-1) # softmax across the rows

        out = wei @ v # (batch, x_block_size, x_block_size) @ (batch, x_block_size, head_size) -> (batch, x_block_size, head_size)

        return out

# Test
head = Head(head_size=16, block_size=block_size)
xt = torch.randn(batch_size, block_size, 16)
print(head(xt).shape)

In [None]:
# ------------------------------
# 
# Multi-Head Self-Attention
#
# ------------------------------

# Receives an input of shape (batch, block_size, head_size) and returns (batch, block_size, num_heads * head_size):
# every head sees the same input, and their outputs are concatenated along the last dimension

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads : int , head_size : int, block_size : int):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size

        self.heads = nn.ModuleList([Head(head_size=head_size, block_size=block_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)  # (batch, x_block_size, head_size) -> (batch, x_block_size, head_size * num_heads)

# Test
multihead = MultiHeadAttention(num_heads=4, head_size=16//4, block_size=block_size)
xt = torch.randn(batch_size, block_size, 16//4)
print(multihead(xt).shape)
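In `MultiHeadAttentionBlock` the heads are not separate modules: the projected q, k and v are reshaped so that each head works on its own contiguous slice of the n_embed channels, and the head outputs are concatenated back with the inverse reshape. A toy-size sketch of that split and merge:

```python
import torch

B, T, C, H = 2, 5, 12, 3        # toy sizes; H must divide C
hs = C // H                     # head_size
x = torch.randn(B, T, C)

# Split as in MultiHeadAttentionBlock: (B, T, C) -> (B, T, H, hs) -> (B, H, T, hs)
heads = x.view(B, T, H, hs).transpose(1, 2)

# Head h sees channels h*hs ... (h+1)*hs - 1 of every token
for h in range(H):
    assert torch.equal(heads[:, h], x[..., h * hs:(h + 1) * hs])

# Merging back, as in the block's forward, recovers the original tensor exactly
merged = heads.transpose(1, 2).contiguous().view(B, T, C)
print(torch.equal(merged, x))  # True
```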

##### Layer Norm
From the paper https://arxiv.org/pdf/1607.06450

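For each input vector (in a Transformer, the embedding of one token, $x \in \mathbb{R}^D$ with $D$ = `n_embed`), we compute the mean and the variance over its $D$ features and normalize each feature as

$$ \mu = \frac{1}{D}\sum_{j=1}^{D} x_j, \qquad \sigma^2 = \frac{1}{D}\sum_{j=1}^{D} (x_j - \mu)^2, \qquad x_j' = \gamma_j \dfrac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_j $$

The statistics do not depend on the other tokens or on the other elements of the batch; $\gamma$ and $\beta$ are learned.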
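The formula can be checked against PyTorch's built-in `F.layer_norm`. A minimal sketch with the notebook's embedding size; note that it uses the biased variance (divide by $D$), as in the formula, whereas the `LayerNorm` class above uses `x.var`'s default unbiased estimator (divide by $D-1$), a negligible difference for $D = 384$.

```python
import torch
import torch.nn.functional as F

D, eps = 384, 1e-6
x = torch.randn(4, 8, D)                          # (batch, T, n_embed)
gamma, beta = torch.ones(D), torch.zeros(D)

mu = x.mean(dim=-1, keepdim=True)                 # one mean per token
var = x.var(dim=-1, keepdim=True, unbiased=False) # biased variance: divide by D, as in the formula
x_hat = gamma * (x - mu) / torch.sqrt(var + eps) + beta

ref = F.layer_norm(x, (D,), weight=gamma, bias=beta, eps=eps)
print(torch.allclose(x_hat, ref, atol=1e-5))
```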
