<a href="https://www.kaggle.com/code/aisuko/multi-head-attention-plus-data-loading?scriptVersionId=164079402" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

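In this notebook, we implement a data-loading pipeline (a sliding-window `Dataset` wrapped in a PyTorch `DataLoader`) and two implementations of causal multi-head attention.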

In [1]:
%%capture
# tiktoken provides BPE (byte pair encoding, also known as digram coding): an algorithm whose vocabulary combines tokens that encode single characters (including single digits or single punctuation marks) with tokens that encode whole words (even long compound words).
!pip install tiktoken==0.6.0

# DataLoader

In [2]:
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset cut from a single long text.

    The text is tokenized once. A window of ``max_length`` tokens is then slid
    across the token stream in steps of ``stride``. Each sample is an
    ``(input, target)`` pair of 1-D tensors, where the target is the input
    shifted one token to the right.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer

        # '<|endoftext|>' is allowed so that document separators in the text
        # are encoded as the special token rather than rejected.
        ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # A window starting at `s` needs max_length + 1 tokens (the inputs plus
        # one extra target token), so the last valid start is
        # len(ids) - max_length - 1, which is exactly what this range yields.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [torch.tensor(ids[s + 1:s + max_length + 1]) for s in starts]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    """Build a DataLoader of ``(input, target)`` token windows over ``txt``.

    The text is tokenized with tiktoken's GPT-2 encoding. GPTDatasetV1 then
    windows it, using windows of ``max_length`` tokens that advance by
    ``stride``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of ``batch_size``
        samples. The batches are shuffled when ``shuffle`` is True.
    """
    dataset = GPTDatasetV1(
        txt,
        tiktoken.get_encoding('gpt2'),
        max_length,
        stride,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Toy corpus for the dataloader demo, inlined so the notebook needs no data file.
# To read it from disk instead, use:
# with open('small-text-sample.txt', 'r', encoding='utf-8') as f:
#     raw_text=f.read()

raw_text='''Once upon a time in a quiet village nestled among rolling hills and whispering forests, there lived a young girl named Elara. Elara was known for her boundless curiosity and her love for the stars. Every night, she would climb to the highest hill near her home to gaze at the glittering sky, dreaming of distant worlds and galaxies.

In the heart of the village, there was an ancient library, tended by an old, wise librarian named Mr. Bramwell. This library was a treasure trove of books on every subject, but most importantly, it housed a collection of old star maps and celestial guides. Elara, fascinated by these books, spent countless hours with Mr. Bramwell, learning about constellations, planets, and the mysteries of the universe.

One evening, while studying an old star map, Elara noticed a small, uncharted star that twinkled differently. She shared this discovery with Mr. Bramwell, who was equally intrigued. They decided to observe this star every night, noting its unique patterns and movements. This small, mysterious star, which they named "Elara's Star," became the center of their nightly adventures.

As days turned into weeks, the villagers began to take notice of Elara's star. The uncharted star brought the community together, with people of all ages joining Elara and Mr. Bramwell on the hill each night to gaze at the sky. The nightly gatherings turned into a festival of stars, where stories were shared, friendships were formed, and the mysteries of the cosmos were contemplated.

The story of Elara and her star spread far and wide, attracting astronomers and dreamers from distant lands. The once quiet village became a beacon of wonder, a place where the sky seemed a little closer and the stars a bit friendlier. Elara's curiosity had not only unveiled a hidden star but had also brought her community together, reminding everyone that sometimes, the most extraordinary discoveries are waiting just above us, in the starlit sky.
'''
    
# Embedding hyperparameters.
vocab_size = 50257   # GPT-2 BPE vocabulary size
output_dim = 256     # width of every embedding vector
max_len = 1024       # number of positions the positional table can index
block_size = max_len

# Tokenize the sample text with the same GPT-2 encoding the dataloader uses.
tokenizer = tiktoken.get_encoding('gpt2')
encoded_text = tokenizer.encode(raw_text)

# Token and absolute-position embedding tables. They are created token-first
# so that random initialisation consumes the RNG in the same order.
token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = nn.Embedding(block_size, output_dim)

# Tiny 4-token windows. With stride 5, consecutive windows do not overlap
# (one token between them is never an input).
max_length = 4
dataloader = create_dataloader(
    raw_text,
    batch_size=8,
    max_length=max_length,
    stride=5,
)

In [3]:
# Take a single batch and build the input embeddings consumed by the attention
# layers below: token embedding + absolute positional embedding.
# next(iter(...)) states the intent directly. The previous `for ...: break`
# form would fail later with an obscure NameError if the loader were empty.
batch = next(iter(dataloader))
x, y = batch  # input token ids and next-token targets, both (batch, seq_len)

# Derive positions from the actual sequence length (and device) instead of
# the global `max_length`, so shorter sequences or GPU tensors also work.
seq_len = x.shape[1]
token_embeddings = token_embedding_layer(x)
pos_embeddings = pos_embedding_layer(torch.arange(seq_len, device=x.device))

# (seq_len, output_dim) broadcasts over the batch axis.
input_embeddings = token_embeddings + pos_embeddings

print(input_embeddings.shape)

torch.Size([8, 4, 256])


# Multi-head Attention

## Variant A: Simple implementation

In [4]:
class CausalSelfAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: feature size of the input embeddings.
        d_out: size of the query/key/value projections and of the output.
        block_size: longest sequence the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the future keys (j > i) that
        # query i must not attend to.
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        """Map x of shape (b, n_tokens, d_in) to context vectors (b, n_tokens, d_out)."""
        n_tokens = x.shape[1]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, n, d_out) @ (b, d_out, n) -> (b, n, n); row i = query i vs. every key.
        attn_scores = queries @ keys.transpose(1, 2)
        # In-place: hide future keys from each query.
        attn_scores.masked_fill_(
            self.mask.bool()[:n_tokens, :n_tokens], -torch.inf
        )

        # Normalise over the key axis (last dim) so each query's weights sum
        # to 1. Using dim=1 would normalise over queries instead, giving
        # invalid weights and letting later tokens influence earlier outputs.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
    

class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention made of independent single-head modules.

    Each of the ``num_heads`` CausalSelfAttention heads maps d_in -> d_out.
    Their outputs are concatenated along the feature axis, giving
    ``num_heads * d_out`` features. The result is then passed through a square
    linear projection of that same width.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias))
        self.heads = nn.ModuleList(heads)

        combined_dim = num_heads * d_out
        self.out_proj = nn.Linear(combined_dim, combined_dim)

    def forward(self, x):
        per_head = [attention(x) for attention in self.heads]
        return self.out_proj(torch.cat(per_head, dim=-1))
    

# Variant A demo: two heads of width d_in // 2. The concatenated output
# therefore has the same width as the input embeddings.
torch.manual_seed(123)
d_in = output_dim
num_heads = 2
d_out = d_in // num_heads
block_size = max_length  # the causal mask only needs to cover one window

mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)

batch = input_embeddings
context_vecs = mha(batch)
print('context_vecs.shape:', context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])


## Variant B: Alternative implementation

In [5]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with fused Q/K/V projections.

    Instead of stacking separate single-head modules, one d_in -> d_out
    projection each for queries, keys and values is computed. Each projection
    is split into ``num_heads`` heads of ``head_dim = d_out // num_heads``
    features. The head outputs are merged back to ``d_out`` features and
    passed through ``out_proj``.

    Args:
        d_in: feature size of the input.
        d_out: total output size across all heads; must be divisible by num_heads.
        block_size: longest sequence the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the Q/K/V projections include a bias term.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order (query, key, value, output) is kept so parameter
        # initialisation consumes the RNG in the same sequence.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark future positions.
        causal = torch.triu(torch.ones(block_size, block_size), diagonal=1)
        self.register_buffer('mask', causal)

    def _split_heads(self, t):
        """Reshape (b, T, d_out) into (b, num_heads, T, head_dim)."""
        n_batch, n_tok, _ = t.shape
        return t.view(n_batch, n_tok, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.W_query(x))
        k = self._split_heads(self.W_key(x))
        v = self._split_heads(self.W_value(x))

        # Per-head scores: (b, H, T, hd) @ (b, H, hd, T) -> (b, H, T, T).
        scores = q @ k.transpose(2, 3)
        # The (T, T) boolean mask broadcasts over the batch and head axes.
        future = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(future, -torch.inf)

        # Scale by sqrt(head_dim) and normalise over the key axis.
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # (b, H, T, hd) -> (b, T, H, hd) -> (b, T, d_out), where d_out = H * hd.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_out)
        return self.out_proj(merged)

# Variant B demo: one fused module whose d_out equals d_in. The module splits
# it internally into 2 heads of d_in // 2 features each.
torch.manual_seed(123)
d_in = output_dim
d_out = d_in
block_size = max_length  # the causal mask only needs to cover one window

mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)

batch = input_embeddings
context_vecs = mha(batch)
print('context_vecs.shape:', context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])
