# Text Transformers: Comprehensive Implementation

This notebook provides a complete implementation of text-based transformer architectures using the `rearrange` function for efficient tensor operations.

## Learning Objectives
- Implement attention mechanisms with rearrange operations
- Build the core components (attention, positional encoding, transformer layer) shared by encoder-only, decoder-only, and encoder-decoder transformers
- Understand multi-head attention and positional encoding
- Practice with real tensor operations and shapes

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import seaborn as sns

# Seed PyTorch and NumPy so every run of the notebook draws the same random numbers
torch.manual_seed(42)
np.random.seed(42)

# Report the runtime environment; the device name is only queried when CUDA exists
print(f"PyTorch version: {torch.__version__}")
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Attention Mechanism Implementation

We'll implement the core attention mechanisms using `rearrange` for efficient tensor operations.

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(D_k)) V.

    Only the last two dimensions are interpreted here; any leading
    dimensions (batch, heads, ...) are batched by ``torch.matmul``. The
    key/value length M may differ from the query length N, which is what
    cross-attention needs.

    Args:
        Q: Query tensor (..., N, D_k), e.g. (B, H, N, D_k)
        K: Key tensor (..., M, D_k)
        V: Value tensor (..., M, D_v)
        mask: Optional tensor broadcastable to the score shape (..., N, M),
            e.g. (B, 1, N, M). Positions where ``mask == 0`` are excluded
            from attention.

    Returns:
        Tuple ``(output, attention_weights)`` with shapes (..., N, D_v)
        and (..., N, M).
    """
    # Only the feature size is needed, so inputs are not forced to be 4-D
    d_k = Q.size(-1)

    # Similarity of every query with every key, scaled by sqrt(D_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill with the most negative finite value of the score dtype.
        # A hard-coded -1e9 is not representable in float16 and makes
        # masked_fill fail for half-precision inputs.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)

    # Normalize over the key axis so each query's weights sum to 1
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum of the values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

# Smoke-test the attention function on random tensors.
# Dimension legend:
#   B   - batch size (sequences processed in parallel)
#   H   - number of attention heads
#   N   - sequence length (tokens per sequence)
#   D_k - per-head feature size of queries and keys
#   D_v - per-head feature size of values
# Usually d_model = H * D_k, and D_k == D_v is the common choice.
B, H, N = 2, 4, 8
D_k = D_v = 64

# Queries and keys share a shape; values may use a different feature size
qk_shape = (B, H, N, D_k)
Q = torch.randn(*qk_shape)
K = torch.randn(*qk_shape)
V = torch.randn(B, H, N, D_v)

output, weights = scaled_dot_product_attention(Q, K, V)

# Every row of the weight matrix is a softmax, so it should total 1
row_totals = weights.sum(dim=-1)
print(f"Attention output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Attention weights sum (should be 1): {row_totals[0, 0, :5]}")

## 2. Multi-Head Attention with Rearrange

Implement multi-head attention using `rearrange` for efficient tensor reshaping.

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module using rearrange operations.

    Projects query/key/value with separate linear layers, splits the
    feature axis into ``num_heads`` heads, runs scaled dot-product
    attention per head, merges the heads and applies an output projection.

    Args:
        d_model: Feature size of the inputs and outputs; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads H.
        dropout: Probability for the ``self.dropout`` module.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # NOTE: registered for interface compatibility; forward() does not
        # apply it.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (B, N, d_model)
            key: (B, M, d_model); M may differ from N (cross-attention)
            value: (B, M, d_model)
            mask: Optional mask passed through to the attention function,
                broadcastable to (B, H, N, M).

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (B, N, d_model) and (B, H, N, M).
        """
        # Project, then split the feature axis into heads in a single
        # rearrange. Each input keeps its own sequence length, so key/value
        # sequences are not required to match the query length.
        Q = rearrange(self.W_q(query), 'B N (H D) -> B H N D', H=self.num_heads)
        K = rearrange(self.W_k(key), 'B M (H D) -> B H M D', H=self.num_heads)
        V = rearrange(self.W_v(value), 'B M (H D) -> B H M D', H=self.num_heads)

        # Apply attention
        attention_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads back into the feature axis: (B, N, H * D_v)
        output = rearrange(attention_output, 'B H N D -> B N (H D)')
        output = self.W_o(output)

        return output, attention_weights

# Run MultiHeadAttention as self-attention on a random batch and report shapes
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

B, N = 2, 10
x = torch.randn(B, N, d_model)

# Self-attention: the same tensor serves as query, key and value
output, weights = mha(query=x, key=x, value=x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print("Multi-head attention working correctly!")

## 3. Positional Encoding

Implement different types of positional encoding using rearrange operations.

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding using rearrange operations.

    Precomputes a fixed (non-trainable) table of shape (1, max_len, d_model).
    Even feature columns hold sin(pos * w_i) and odd columns hold
    cos(pos * w_i), with w_i = 10000^(-2i / d_model). The first N rows are
    added to an input of shape (B, N, d_model).

    Args:
        d_model: Feature size D of the inputs (even or odd).
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # One frequency per (sin, cos) column pair
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even
        # columns, so trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add a leading batch axis so the table broadcasts over the batch
        pe = rearrange(pe, 'N D -> 1 N D')

        # Buffer: follows the module across devices but is not a parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (B, N, D) with the first N rows of the table added.

        Raises:
            ValueError: If N exceeds the ``max_len`` the table was built for.
        """
        _, seq_len, _ = x.shape

        table_len = self.pe.shape[1]
        if seq_len > table_len:
            # Without this check the slice below is silently truncated and
            # the addition either fails with an opaque broadcast error or,
            # for a one-row table, broadcasts the wrong values.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len {table_len}"
            )

        return x + self.pe[:, :seq_len, :]

class LearnedPositionalEmbedding(nn.Module):
    """
    Trainable positional embedding: one learned vector per position index.

    Args:
        d_model: Size of each position vector.
        max_len: Number of rows in the embedding table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embeddings of positions 0..N-1 to ``x`` of shape (B, N, D)."""
        batch, seq_len, _ = x.shape

        # Build a (B, N) grid of position ids on the same device as the input
        index_row = torch.arange(seq_len, device=x.device)
        index_grid = repeat(index_row, 'N -> B N', B=batch)

        # Look up one vector per position and add it to the input
        return x + self.pos_embedding(index_grid)

# Apply both positional encodings to the same random batch and compare shapes
d_model, max_len = 128, 100
sinusoidal_pe = SinusoidalPositionalEncoding(d_model, max_len)
learned_pe = LearnedPositionalEmbedding(d_model, max_len)

B, N = 2, 20
x = torch.randn(B, N, d_model)

# Both encoders add a position signal, so the shape must be unchanged
x_with_sin_pe, x_with_learned_pe = sinusoidal_pe(x), learned_pe(x)

print(f"Original input shape: {x.shape}")
print(f"With sinusoidal PE shape: {x_with_sin_pe.shape}")
print(f"With learned PE shape: {x_with_learned_pe.shape}")
print("Positional encodings working correctly!")

## 4. Transformer Layer

Implement a complete transformer layer with rearrange operations.

In [None]:
class TransformerLayer(nn.Module):
    """
    Transformer block: multi-head self-attention followed by a
    position-wise feed-forward network.

    Each sub-layer's output passes through dropout, is added to the
    sub-layer's input, and the sum is layer-normalized.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Sub-layer 1: self-attention
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Sub-layer 2: expand to d_ff, apply ReLU and dropout, project back
        ffn_stages = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.feed_forward = nn.Sequential(*ffn_stages)

        # One LayerNorm per sub-layer, plus a shared residual dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention."""
        # Attention sub-layer: the weights are discarded, only the output is used
        attended = self.attention(x, x, x, mask)[0]
        hidden = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with its own residual connection and norm
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

# Push a random batch through one TransformerLayer and check the shape is kept
d_model, num_heads, d_ff = 256, 8, 1024
layer = TransformerLayer(d_model, num_heads, d_ff)

B, N = 2, 15
x = torch.randn(B, N, d_model)

# No mask: every token may attend to every other token
output = layer(x, mask=None)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Transformer layer working correctly!")

## 5. Summary

We've implemented the core building blocks of text transformers using rearrange operations:

### **Key Components:**
1. **Attention Mechanisms**: Scaled dot-product attention
2. **Multi-Head Attention**: Efficient implementation with rearrange
3. **Positional Encoding**: Sinusoidal and learned embeddings
4. **Transformer Layers**: Complete transformer blocks

### **Benefits of Rearrange:**
- Clean tensor reshaping for multi-head attention
- Efficient positional encoding broadcasting
- Intuitive tensor manipulation syntax

### **Next Steps:**
- Explore vision transformers
- Implement decoder-only and encoder-decoder models
- Add training loops and optimization