In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, cut into ``num_heads``
    heads of width ``d_k = d_model // num_heads``, attended per head, merged
    back to ``d_model`` features and passed through a final linear layer.

    Args:
        d_model: Feature width of the inputs and outputs; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # feature width of one head

        # Learned projections for Q, K and V, plus the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) to (batch, num_heads, seq_len, d_k)."""
        batch, length, _ = x.shape
        per_head = x.view(batch, length, self.num_heads, self.d_k)
        return per_head.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq_len, d_k) back to (batch, seq_len, d_model)."""
        batch, _, length, _ = x.shape
        # Put the head axis next to d_k so the two can be flattened together.
        merged = x.permute(0, 2, 1, 3)
        return merged.reshape(batch, length, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable against the attention scores;
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, query seq_len, d_model).
        """
        # Project first, then split each result into heads.
        projected = (self.W_q(query), self.W_k(key), self.W_v(value))
        q_heads, k_heads, v_heads = (self.split_heads(t) for t in projected)

        # Scaled dot-product similarity between every query/key pair.
        logits = (q_heads @ k_heads.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masked positions are set to -inf so softmax gives them zero weight
        # (with a causal mask: a token cannot attend to later tokens).
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of values per head, then merge heads and mix them.
        context = self.combine_heads(weights @ v_heads)
        return self.W_o(context)

`d_model` --> the dimensionality of all representations
- if `d_model` = 256
  - each token embedding is a 256-dimensional vector
  - each position embedding is 256-dimensional
  - output of each attention layer is 256-dimensional
  - everything flowing through the model lives in a 256-dimensional space
- Think 'width' of a neural network - how much information capacity each token holds.

`d_k` --> dimension per Attention Head
- the dimension of key (as well as query and value) vectors 
  - (`d_k` = `d_model` / `num_heads`)
- splits the `d_model` dimension into `num_heads` smaller "heads" that run in parallel, each able to learn a different aspect of the relationships between tokens.
- each head operates independently and gets concatenated back together to output `d_model`

`Q`, `K`, `V` 
- `Q` -> Asks a question on what information is needed.
- `K` -> Key mapping to the question.
- `V` -> The content of the key. Think key-value pair in python `dict`.
- Think soft database lookup

Attention formula 

`scores = Q @ K^T / sqrt(d_k)`   How relevant is each token?

`attention_weights = softmax(scores)`   Normalize to probabilities

`output = attention_weights @ V`   Weighted sum of values

Quick lesson on softmax 
- Given [0.9, -1.5, 3.2], we first want to exponentiate all values to make all values positive and amplify the differences -> [2.45960311116, 0.22313016014, 24.5325301971].
- We then normalize it to find the probability distribution -> [2.45960311116/27.2152634684, 0.22313016014/27.2152634684, 24.5325301971/27.2152634684] --> [0.090376136, 0.008198731, 0.901425133]

`dropout`
- Controls regularization.
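- In training, dropout randomly sets a given fraction of neurons to zero.
- If `dropout` = 0.1, each neuron is zeroed with probability 0.1, so on average 10% of neurons are deactivated and 90% remain active.
  - remaining neurons are scaled up by 1 / (1 - `dropout`) (here 1 / 0.9) to compensate, keeping the expected activation unchanged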
- This prevents overfitting so that the training does not rely on a single/few neurons, making the model more robust.
- Only activated during training, auto-disabled during inference/test. 

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: d_model -> d_ff -> d_model.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied after the ReLU activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

Forward pass



In [None]:
class TransformerLayer(nn.Module):
    """One post-norm transformer block: self-attention, then feed-forward.

    Each sub-layer's output goes through dropout, is added back to its input
    (residual connection) and the sum is layer-normalized.

    Args:
        d_model: Feature width of the block's input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside both sub-layers and on
            their outputs.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers.
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # One LayerNorm and one Dropout per residual branch.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention sub-layer."""
        # Self-attention branch: query, key and value are all ``x``.
        attended = self.dropout1(self.attention(x, x, x, mask))
        hidden = self.norm1(x + attended)

        # Feed-forward branch on the normalized attention result.
        transformed = self.dropout2(self.feed_forward(hidden))
        return self.norm2(hidden + transformed)

In [None]:
class TwoLayerTransformer(nn.Module):
    """Causal language model built from two stacked transformer layers.

    Token ids are embedded, summed with learned positional embeddings, passed
    through two ``TransformerLayer`` blocks under a causal mask and projected
    to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Feature width used throughout the model.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        max_seq_len: Longest sequence the positional embedding table covers.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4, d_ff=1024,
                 max_seq_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Learned positional embeddings, one row per position < max_seq_len
        self.positional_embedding = nn.Embedding(max_seq_len, d_model)

        # Two transformer layers
        self.layer1 = TransformerLayer(d_model, num_heads, d_ff, dropout)
        self.layer2 = TransformerLayer(d_model, num_heads, d_ff, dropout)

        # Output projection to vocabulary logits
        self.output_projection = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every parameter with more than one dimension.

        One-dimensional parameters (biases, LayerNorm weights) keep their
        default initialization.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def create_causal_mask(self, seq_len, device):
        """Return a (1, 1, seq_len, seq_len) lower-triangular mask of ones.

        Entry (i, j) is 1 when j <= i and 0 otherwise, so a position may
        attend to itself and earlier positions but not to later ones.
        """
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return mask.view(1, 1, seq_len, seq_len)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            IndexError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = x.size()

        # Fail early with a clear message instead of an opaque out-of-range
        # error from the positional embedding lookup.
        if seq_len > self.max_seq_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )

        # Position indices 0..seq_len-1, shared across the batch
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)

        # Token + positional embeddings
        token_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(positions)
        x = self.dropout(token_emb + pos_emb)

        # Create causal mask
        mask = self.create_causal_mask(seq_len, x.device)

        # Pass through two transformer layers
        x = self.layer1(x, mask)
        x = self.layer2(x, mask)

        # Project to vocabulary
        logits = self.output_projection(x)

        return logits