## Masked multi-head attention

In [18]:
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset built from a single text string.

    The text is tokenized once and cut into windows of ``max_length`` tokens.
    Each sample is an ``(input_ids, target_ids)`` pair of 1-D integer tensors
    in which ``target_ids[t]`` is the token that follows ``input_ids[t]``.

    * Text longer than ``max_length`` tokens: sliding windows starting every
      ``stride`` tokens. Only windows whose shifted target window fits in the
      text are emitted, so trailing tokens may be dropped.
    * Text of at most ``max_length`` tokens: exactly one sample, right-padded
      with ``pad_token_id``. Target positions that have no real next token
      (the last real token and every padding slot) are set to
      ``ignore_index`` instead of a made-up token, so a loss such as
      ``torch.nn.CrossEntropyLoss`` (default ``ignore_index=-100``) skips them.

    Args:
        txt: Raw text to tokenize.
        tokenizer: Object with an ``encode(text, allowed_special=...)`` method
            returning a list of ints (e.g. a tiktoken encoding).
        max_length: Number of tokens per sample; must be >= 1.
        stride: Offset between consecutive window starts; must be >= 1.
        pad_token_id: Token id used to right-pad inputs shorter than
            ``max_length``.
        ignore_index: Target value for positions without a real next token.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is smaller than 1.
    """

    def __init__(self, txt, tokenizer, max_length, stride, *,
                 pad_token_id=0, ignore_index=-100):
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        if len(token_ids) <= max_length:
            # Single (possibly padded) sample.
            n_pad = max_length - len(token_ids)
            input_chunk = token_ids + [pad_token_id] * n_pad
            # Position t's target is token t+1. Positions with no real next
            # token get ignore_index rather than a fabricated target.
            target_chunk = token_ids[1:]
            target_chunk += [ignore_index] * (max_length - len(target_chunk))
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))
        else:
            # The last start is len - max_length - 1, so the target window
            # (which needs one extra token) stays inside token_ids.
            for i in range(0, len(token_ids) - max_length, stride):
                input_chunk = token_ids[i : i + max_length]
                target_chunk = token_ids[i + 1 : i + max_length + 1]
                self.input_ids.append(torch.tensor(input_chunk))
                self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` pair at ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=2, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a DataLoader of ``(input_ids, target_ids)`` pairs from raw text.

    The text is tokenized with tiktoken's GPT-2 encoding and windowed by
    ``GPTDatasetV1`` using ``max_length`` and ``stride``. The remaining
    arguments are passed unchanged to ``torch.utils.data.DataLoader``.

    Note: with ``drop_last=True`` a final batch smaller than ``batch_size``
    is discarded. A dataset with fewer samples than ``batch_size`` therefore
    yields no batches at all.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    return DataLoader(
        GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )



# Tokenize a short demo sentence. Its token count becomes the window length,
# so the dataset built below yields exactly one unpadded sample.
raw_text = "Yours journey start with one steps"
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = tokenizer.encode(raw_text, allowed_special={"<|endoftext|>"})

max_length = len(token_ids)
vocab_size = 50257  # vocabulary size of the GPT-2 tokenizer used above
output_dim = 3  # embedding width, kept tiny so printed tensors stay readable

# One learned vector per token id, plus one learned vector per position
# 0..max_length-1.
# NOTE(review): no torch.manual_seed() precedes these layers, so their random
# initialization, and therefore the values of input_embeddings, differ
# between runs.
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(max_length, output_dim)

# One sample per batch, non-overlapping windows, deterministic order.
dataloader = create_dataloader_v1(
    raw_text, batch_size=1, max_length=max_length,
    stride=max_length, shuffle=False
)

data_iter = iter(dataloader)
inputs, targets = next(data_iter)  # each of shape (1, max_length)

token_embeddings = token_embedding_layer(inputs)
# Shape (max_length, output_dim); broadcasts over the batch dimension below.
pos_embeddings = pos_embedding_layer(torch.arange(max_length))

input_embeddings = token_embeddings + pos_embeddings
input_embeddings[0]  # notebook display of the first (only) sample


tensor([[ 1.9363, -0.6354,  1.5966],
        [ 0.1375,  1.2886,  2.6458],
        [-1.8003,  1.8194, -0.6992],
        [-2.8168, -0.4561, -0.4425],
        [ 0.1895, -0.6746,  1.6575],
        [-0.4451, -0.2274, -0.9869],
        [ 0.1162, -1.4226,  1.1148]], grad_fn=<SelectBackward0>)

In [19]:
import torch.nn as nn
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and earlier positions.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum supported sequence length. It sizes the
            precomputed causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.
            Defaults to ``False``, which matches the previous behavior.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions. Stored as a
        # buffer so it moves with the module's device and is not a parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, inputs):
        """Compute causal self-attention context vectors.

        Args:
            inputs: Tensor of shape ``(..., num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds ``context_length``.
        """
        num_tokens = inputs.shape[-2]
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise ValueError(
                f"sequence length {num_tokens} exceeds "
                f"context_length {context_length}"
            )

        queries = self.w_query(inputs)
        keys = self.w_key(inputs)
        values = self.w_value(inputs)

        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask to the actual sequence length. Without this, any
        # input shorter than context_length fails to broadcast.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values

In [20]:
import torch
import torch.nn as nn

class MultiHeadCausalAttention(nn.Module):
    """Multi-head causal self-attention built from independent heads.

    Runs ``num_heads`` separate ``CausalAttention`` modules on the same input.
    Each head projects to ``d_out // num_heads`` features. Their outputs are
    concatenated along the last dimension and passed through a final
    ``d_out -> d_out`` linear projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size; must be divisible by ``num_heads``.
        context_length: Maximum sequence length, passed to every head.
        dropout: Attention-weight dropout probability, passed to every head.
        num_heads: Number of attention heads; must be >= 1.

    Raises:
        ValueError: If ``num_heads`` < 1 or ``d_out`` is not divisible by
            ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads):
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under -O.
        # The num_heads check also avoids a ZeroDivisionError below.
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.heads = nn.ModuleList([
            CausalAttention(d_in, self.head_dim, context_length, dropout)
            for _ in range(num_heads)
        ])

        # Combines the concatenated per-head outputs into the final d_out features.
        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, inputs):
        """Return the projected, concatenated outputs of all heads.

        The output has the same leading dimensions as each head's output,
        with a last dimension of ``d_out``.
        """
        concatenated = torch.cat([head(inputs) for head in self.heads], dim=-1)
        return self.out_proj(concatenated)

In [23]:
# Seed the global RNG so the heads' and out_proj's initial weights are reproducible.
torch.manual_seed(789)
cols = input_embeddings.shape[-1]  # embedding width fed to the attention block
ca = MultiHeadCausalAttention(
    cols,
    output_dim,
    context_length=max_length,
    dropout=0.0,
    num_heads=3,
)

result = ca(input_embeddings)
result

tensor([[[ 0.4205,  0.5018, -0.4273],
         [ 0.2285,  0.0079, -0.4412],
         [ 0.3052, -0.6301, -0.3327],
         [ 0.3230, -0.6968, -0.2863],
         [ 0.3216, -0.7436, -0.2866],
         [ 0.3670, -0.5896, -0.2913],
         [ 0.3527, -0.5647, -0.2942]]], grad_fn=<ViewBackward0>)