# 01 - Transformer from Scratch

A from-scratch PyTorch implementation of the Transformer architecture, built up component by component and trained as a small word-level language model.
1. **Sinusoidal Positional Encoding** - Fixed position embeddings using sine/cosine functions
- Positional encoding ON vs OFF
- Number of attention heads: 4 vs 8

In [None]:
import math
import random
import time
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Covers Python's ``random`` module, NumPy's global generator, and
    PyTorch's CPU and CUDA generators.

    Args:
        seed: Value passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    
# Seed all generators once up front so the cells below are repeatable.
set_seed(42)


# Request deterministic cuDNN kernels and disable its autotuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    # Report the name and total memory (in decimal GB) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 4. Configuration

In [None]:
@dataclass
class TransformerConfig:
    """Hyperparameters for the Transformer models and their training runs."""

    # Model architecture
    d_model: int = 256          # Embedding dimension
    n_heads: int = 8            # Number of attention heads
    n_layers: int = 4           # Number of encoder/decoder layers
    d_ff: int = 512             # Feed-forward hidden dimension
    max_seq_len: int = 128      # Maximum sequence length
    vocab_size: int = 10000     # Vocabulary size (will be updated)
    dropout: float = 0.1        # Dropout rate
    
    # Training
    batch_size: int = 32        # Sequences per mini-batch
    learning_rate: float = 1e-4 # Initial learning rate handed to the optimizer
    n_epochs: int = 10          # Passes over the training data
    
    # Ablation flags
    use_positional_encoding: bool = True  # False skips adding positional encodings
    
# Shared default configuration used by the cells below.
config = TransformerConfig()


In [None]:
# In practice, you would load WikiText-2 or another corpus
SAMPLE_CORPUS = """
The quick brown fox jumps over the lazy dog.
A journey of a thousand miles begins with a single step.
To be or not to be that is the question.
All that glitters is not gold.
The only thing we have to fear is fear itself.
In the beginning was the word and the word was with god.
It was the best of times it was the worst of times.
Call me ishmael some years ago never mind how long precisely.
It is a truth universally acknowledged that a single man in possession of a good fortune must be in want of a wife.
Happy families are all alike every unhappy family is unhappy in its own way.
The sun rose slowly over the mountains casting long shadows across the valley below.
She walked through the forest listening to the birds singing their morning songs.
The old man sat by the fire remembering the days of his youth.
The city streets were busy with people hurrying to their destinations.
A gentle breeze blew through the open window bringing the scent of flowers.
The children played in the garden while their parents watched from the porch.
He picked up the book and began to read losing himself in the story.
The stars twinkled in the night sky like diamonds scattered across velvet.
She smiled at the memory of their first meeting so many years ago.
The waves crashed against the shore creating a soothing rhythm.
Deep in the forest there lived a wise old owl who knew many secrets.
The train departed from the station carrying passengers to distant lands.
Music filled the air as the orchestra began their evening performance.
The scientist worked late into the night trying to solve the puzzle.
Rain began to fall gently at first then harder until it became a downpour.
"""

# Repeat the sample text 50 times and lowercase it to get a larger toy corpus.
CORPUS = (SAMPLE_CORPUS * 50).lower()
print(f"Corpus length: {len(CORPUS)} characters")

In [None]:
class SimpleTokenizer:
    """Whitespace tokenizer with a frequency-filtered vocabulary.

    Text is split with ``str.split()``; no case folding or punctuation
    handling is performed. The five special tokens always occupy indices
    0-4 in the order ``<PAD>, <UNK>, <BOS>, <EOS>, <MASK>``.
    """

    def __init__(self, min_freq: int = 1):
        """Create an unfitted tokenizer.

        Args:
            min_freq: Minimum number of occurrences a word needs in the
                fitted text to receive its own index.
        """
        self.min_freq = min_freq
        self.word2idx: Dict[str, int] = {}
        self.idx2word: Dict[int, str] = {}

        # Special tokens
        self.pad_token = "<PAD>"
        self.unk_token = "<UNK>"
        self.bos_token = "<BOS>"
        self.eos_token = "<EOS>"
        self.mask_token = "<MASK>"

        self.special_tokens = [
            self.pad_token, self.unk_token,
            self.bos_token, self.eos_token,
            self.mask_token
        ]

    def fit(self, text: str) -> None:
        """Build the vocabulary from ``text``, replacing any earlier one.

        Words are indexed in sorted order after the special tokens. Words
        seen fewer than ``min_freq`` times are left out and will encode to
        ``<UNK>``. Prints the resulting vocabulary size.

        Args:
            text: Corpus to count whitespace-separated words in.
        """
        # Count word frequencies
        word_freq: Dict[str, int] = {}
        for word in text.split():
            word_freq[word] = word_freq.get(word, 0) + 1

        # Start from fresh mappings so that refitting cannot leave stale
        # words (and an inflated vocab_size) behind.
        word2idx: Dict[str, int] = {
            token: i for i, token in enumerate(self.special_tokens)
        }
        for word in sorted(word_freq):
            if word_freq[word] < self.min_freq:
                continue
            if word in word2idx:
                # A corpus word spelled like a special token must not steal
                # that token's reserved index (e.g. change pad_idx).
                continue
            word2idx[word] = len(word2idx)

        self.word2idx = word2idx
        self.idx2word = {idx: word for word, idx in word2idx.items()}

        print(f"Vocabulary size: {len(self.word2idx)}")

    def encode(self, text: str) -> List[int]:
        """Map whitespace-separated words to indices.

        Unknown words map to the ``<UNK>`` index. Raises ``KeyError`` if
        called before ``fit``, because ``<UNK>`` has no index yet.
        """
        unk_idx = self.word2idx[self.unk_token]
        return [self.word2idx.get(word, unk_idx) for word in text.split()]

    def decode(self, indices: List[int]) -> str:
        """Join the words for ``indices`` with spaces.

        Indices without a vocabulary entry are rendered as ``<UNK>``.
        """
        return " ".join([self.idx2word.get(idx, self.unk_token) for idx in indices])

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.word2idx)

    @property
    def pad_idx(self) -> int:
        """Index of the ``<PAD>`` token (available after ``fit``)."""
        return self.word2idx[self.pad_token]

# Build the tokenizer; a word must occur at least twice in CORPUS to get an index.
tokenizer = SimpleTokenizer(min_freq=2)
tokenizer.fit(CORPUS)

# Replace the placeholder vocab_size in the shared config with the fitted size.
config.vocab_size = tokenizer.vocab_size
print(f"Updated vocab size: {config.vocab_size}")

In [None]:
class LanguageModelingDataset(Dataset):
    """Sliding-window next-token dataset over one tokenized text.

    Item ``i`` is the pair ``(tokens[i:i + seq_len],
    tokens[i + 1:i + seq_len + 1])``: the target is the input shifted one
    token to the right. Consecutive items overlap by ``seq_len - 1`` tokens.
    """

    # The tokenizer annotation is a quoted forward reference: only an
    # ``encode(text)`` method is required at runtime.
    def __init__(self, text: str, tokenizer: "SimpleTokenizer", seq_len: int):
        """Tokenize ``text`` once and remember the window length.

        Args:
            text: Raw corpus.
            tokenizer: Object whose ``encode(text)`` returns a list of ints.
            seq_len: Number of tokens in each input/target sequence.
        """
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Tokenize entire corpus
        self.tokens = tokenizer.encode(text)
        print(f"Total tokens: {len(self.tokens)}")

    def __len__(self) -> int:
        """Number of full windows that still have a next-token target."""
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(input, target)`` pair starting at token ``idx``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this, plain iteration (``for x, y in dataset``)
                would never stop and out-of-range indices would silently
                return truncated or empty tensors.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(
                f"index out of range for dataset with {n_items} items"
            )

        # Input: tokens[idx:idx+seq_len]
        # Target: tokens[idx+1:idx+seq_len+1] (shifted by 1)
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y

# Build the sliding-window dataset over the corpus and a shuffled loader.
# drop_last=True discards a final batch smaller than batch_size.
dataset = LanguageModelingDataset(CORPUS, tokenizer, config.max_seq_len)
dataloader = DataLoader(
    dataset, 
    batch_size=config.batch_size, 
    shuffle=True,
    drop_last=True
)

print(f"Number of batches: {len(dataloader)}")

## 6. Transformer Components

Now we build each component of the Transformer step by step.

### 6.1 Sinusoidal Positional Encoding

The positional encoding adds position information to embeddings using sine and cosine functions:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position signals to a batch of embeddings.

    The table is precomputed for ``max_seq_len`` positions and stored as the
    buffer ``pe`` of shape ``(1, max_seq_len, d_model)``. Even feature
    columns hold sines, odd columns hold cosines. Both even and odd
    ``d_model`` are supported.
    """

    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        """Precompute the encoding table.

        Args:
            d_model: Size of the feature dimension being encoded.
            max_seq_len: Number of positions to precompute.
            dropout: Dropout probability applied after the addition.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # Inverse frequencies 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so drop the surplus frequency instead of crashing.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # Shape: (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings for the first ``x.size(1)`` positions, then apply dropout.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Smoke test: the encoding must not change the shape of its input.
pe = SinusoidalPositionalEncoding(d_model=config.d_model, max_seq_len=config.max_seq_len)
test_input = torch.zeros(2, 10, config.d_model)
test_output = pe(test_input)
assert test_output.shape == test_input.shape, "Shape mismatch!"
print(f"Positional encoding test passed. Output shape: {test_output.shape}")

In [None]:
# Visualize positional encoding: heat map of the first 50 positions of a
# 64-dimensional table (positions on the x-axis, feature dimensions on the y-axis).
pe_viz = SinusoidalPositionalEncoding(d_model=64, max_seq_len=100)
pe_matrix = pe_viz.pe[0, :50, :].numpy()

plt.figure(figsize=(12, 4))
plt.imshow(pe_matrix.T, aspect='auto', cmap='RdBu')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Sinusoidal Positional Encoding')
plt.colorbar()
plt.tight_layout()
plt.show()

### 6.2 Scaled Dot-Product Attention

The attention mechanism computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

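Dividing by $\sqrt{d_k}$ keeps the magnitude of the dot products from growing with the head dimension $d_k$; without it, large scores would push the softmax into saturated regions where gradients become very small.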

In [None]:
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout: Optional[nn.Dropout] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: Tensor of shape ``(..., q_len, d_k)``.
        key: Tensor of shape ``(..., k_len, d_k)``.
        value: Tensor of shape ``(..., k_len, d_v)``.
        mask: Optional tensor broadcastable to ``(..., q_len, k_len)``.
            Positions where ``mask == 0`` are excluded from attention.
        dropout: Optional dropout module applied to the attention weights.

    Returns:
        ``(output, attention_weights)``. The weights are returned after
        dropout, so in training mode their rows need not sum to 1.

    Note:
        A query row whose keys are all masked receives a uniform
        distribution over the keys rather than NaN.
    """
    d_k = query.size(-1)

    # Compute attention scores: Q @ K^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill with the most negative finite value rather than -inf: masked
        # entries still get exactly zero weight next to any unmasked entry,
        # but a fully masked row no longer produces NaN weights and NaN
        # gradients that would spread through the whole model.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Apply dropout if provided
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # Compute output: attention_weights @ V
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Test scaled dot-product attention on random inputs (no mask, no dropout):
# check output shapes and that every attention row sums to 1.
batch, n_heads, seq_len, d_k = 2, 4, 10, 32
q = torch.randn(batch, n_heads, seq_len, d_k)
k = torch.randn(batch, n_heads, seq_len, d_k)
v = torch.randn(batch, n_heads, seq_len, d_k)

out, attn_weights = scaled_dot_product_attention(q, k, v)
assert out.shape == (batch, n_heads, seq_len, d_k), f"Output shape mismatch: {out.shape}"
assert attn_weights.shape == (batch, n_heads, seq_len, seq_len), f"Attention weights shape mismatch: {attn_weights.shape}"
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch, n_heads, seq_len)), "Attention weights don't sum to 1!"
print(f"Scaled dot-product attention test passed.")
print(f"Output shape: {out.shape}, Attention weights shape: {attn_weights.shape}")

### 6.3 Multi-Head Attention

Multi-head attention allows the model to attend to information from different representation subspaces:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, ..., head_h)W^O$$

where $head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned Q/K/V and output projections.

    ``d_model`` features are split evenly across ``n_heads`` heads of size
    ``d_k = d_model // n_heads``. The weights returned by the most recent
    forward pass are kept in ``attention_weights`` for inspection.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head feature size

        # Input projections, then the projection applied to the merged heads.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Filled in by forward(); None until the first call.
        self.attention_weights = None

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch, seq_len, d_model) into (batch, n_heads, seq_len, d_k)."""
        per_head = projected.view(batch_size, -1, self.n_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape (batch, q_len, d_model).
            key: Tensor of shape (batch, k_len, d_model).
            value: Tensor of shape (batch, k_len, d_model).
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        n_batch = query.size(0)

        heads_q = self._split_heads(self.W_q(query), n_batch)
        heads_k = self._split_heads(self.W_k(key), n_batch)
        heads_v = self._split_heads(self.W_v(value), n_batch)

        context, self.attention_weights = scaled_dot_product_attention(
            heads_q, heads_k, heads_v, mask=mask, dropout=self.dropout
        )

        # Merge the heads back into one (batch, q_len, d_model) tensor.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(n_batch, -1, self.d_model)

        return self.W_o(merged)

# Test Multi-Head Attention: self-attention must preserve the input shape.
mha = MultiHeadAttention(d_model=config.d_model, n_heads=config.n_heads)
x = torch.randn(2, 10, config.d_model)
out = mha(x, x, x)  # Self-attention
assert out.shape == x.shape, f"MHA output shape mismatch: {out.shape}"
print(f"Multi-Head Attention test passed. Output shape: {out.shape}")

### 6.4 Position-wise Feed-Forward Network

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``, expanding from
    ``d_model`` to ``d_ff`` features and back.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a tensor with the same shape as ``x``."""
        hidden = self.linear1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# Test FFN: the block must map (batch, seq, d_model) back to the same shape.
ffn = PositionwiseFeedForward(config.d_model, config.d_ff)
x = torch.randn(2, 10, config.d_model)
out = ffn(x)
assert out.shape == x.shape, f"FFN output shape mismatch: {out.shape}"
print(f"FFN test passed. Output shape: {out.shape}")

### 6.5 Encoder Layer

Each encoder layer consists of:
1. Multi-head self-attention
2. Add & Norm (residual connection + layer normalization)
3. Feed-forward network
4. Add & Norm

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer block: self-attention, then feed-forward.

    Each sub-layer output passes through dropout, is added to the
    sub-layer input, and is then layer-normalized.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the block on ``x``; ``mask`` is forwarded to self-attention."""
        # Sub-layer 1: self-attention + residual + norm
        attended = self.self_attention(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward + residual + norm
        transformed = self.feed_forward(hidden)
        return self.norm2(hidden + self.dropout(transformed))

# Test encoder layer: an unmasked forward pass must preserve the input shape.
enc_layer = EncoderLayer(config.d_model, config.n_heads, config.d_ff)
x = torch.randn(2, 10, config.d_model)
out = enc_layer(x)
assert out.shape == x.shape
print(f"Encoder layer test passed. Output shape: {out.shape}")

### 6.6 Decoder Layer

Each decoder layer has:
1. Masked multi-head self-attention (causal)
2. Add & Norm
3. Multi-head cross-attention (attending to encoder output)
4. Add & Norm
5. Feed-forward network
6. Add & Norm

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm decoder block with self-, cross-attention and feed-forward.

    Each of the three sub-layers is followed by dropout, a residual
    addition, and its own layer normalization.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        self_mask: Optional[torch.Tensor] = None,
        cross_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the block.

        Args:
            x: Decoder-side input of shape (batch, tgt_len, d_model).
            encoder_output: Tensor supplying keys and values for
                cross-attention, shape (batch, src_len, d_model).
            self_mask: Optional mask for the self-attention sub-layer.
            cross_mask: Optional mask for the cross-attention sub-layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Sub-layer 1: (masked) self-attention
        self_attended = self.self_attention(x, x, x, self_mask)
        hidden = self.norm1(x + self.dropout(self_attended))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder
        cross_attended = self.cross_attention(
            hidden, encoder_output, encoder_output, cross_mask
        )
        hidden = self.norm2(hidden + self.dropout(cross_attended))

        # Sub-layer 3: position-wise feed-forward
        transformed = self.feed_forward(hidden)
        return self.norm3(hidden + self.dropout(transformed))

# Test decoder layer with different target (10) and source (15) lengths:
# the output must keep the target-side shape.
dec_layer = DecoderLayer(config.d_model, config.n_heads, config.d_ff)
x = torch.randn(2, 10, config.d_model)
enc_out = torch.randn(2, 15, config.d_model)
out = dec_layer(x, enc_out)
assert out.shape == x.shape
print(f"Decoder layer test passed. Output shape: {out.shape}")

### 6.7 Mask Generation

We need two types of masks:
1. **Padding mask**: Prevents attention to padding tokens
2. **Causal mask**: Prevents attention to future tokens (for decoder)

In [None]:
def create_padding_mask(seq: torch.Tensor, pad_idx: int) -> torch.Tensor:
    """Return a boolean mask that is True wherever ``seq`` is not padding.

    Two singleton axes are inserted after the first dimension, so a
    ``(batch, seq_len)`` input yields a ``(batch, 1, 1, seq_len)`` mask.
    """
    is_token = seq.ne(pad_idx)
    return is_token[:, None, None]


def create_causal_mask(size: int, device: torch.device) -> torch.Tensor:
    """Return a ``(1, 1, size, size)`` boolean lower-triangular mask.

    Entry ``[0, 0, i, j]`` is True exactly when ``j <= i``, i.e. position
    ``i`` may attend to itself and to earlier positions only.
    """
    positions = torch.arange(size, device=device)
    # Row index i down the rows, column index j across: keep j <= i.
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed[None, None]


# Test masks: print the padding mask for two right-padded rows (pad id 0)
# and the 5x5 causal mask.
test_seq = torch.tensor([[1, 2, 3, 0, 0], [1, 2, 0, 0, 0]])
pad_mask = create_padding_mask(test_seq, pad_idx=0)
print(f"Padding mask shape: {pad_mask.shape}")
print(f"Padding mask example:\n{pad_mask[0, 0, 0]}")

causal_mask = create_causal_mask(5, device='cpu')
print(f"\nCausal mask shape: {causal_mask.shape}")
print(f"Causal mask:\n{causal_mask[0, 0]}")

In [None]:
# Sanity check: masking correctness
# Verify that causal mask prevents attending to future positions

seq_len = 5
causal = create_causal_mask(seq_len, 'cpu')

for i in range(seq_len):
    # Row i of a lower-triangular mask must allow exactly positions 0..i.
    allowed_positions = causal[0, 0, i, :].sum().item()
    assert allowed_positions == i + 1, f"Position {i} should attend to {i+1} positions, got {allowed_positions}"

print("Causal mask correctness verified!")

## 7. Full Transformer Model (Encoder-Decoder)

In [None]:
class TransformerEncoderDecoder(nn.Module):
    """Encoder-decoder Transformer producing vocabulary logits.

    Source and target tokens have separate embedding tables but share one
    positional-encoding module. No mask is created internally; callers
    supply ``src_mask`` / ``tgt_mask`` when they need them.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Token embeddings for the two sides
        self.src_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.tgt_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Shared positional encoding
        self.positional_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(config.d_model, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])

        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(config.d_model, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])

        # Map decoder states to vocabulary logits
        self.output_projection = nn.Linear(config.d_model, config.vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _embed(self, table: nn.Embedding, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings, scale by sqrt(d_model), optionally add positions."""
        embedded = table(token_ids) * math.sqrt(self.config.d_model)
        if not self.config.use_positional_encoding:
            return embedded
        return self.positional_encoding(embedded)

    def encode(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the encoder stack over source token ids ``src``."""
        hidden = self._embed(self.src_embedding, src)
        for block in self.encoder_layers:
            hidden = block(hidden, src_mask)
        return hidden

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the decoder stack over target ids ``tgt`` given ``encoder_output``.

        ``tgt_mask`` is used for decoder self-attention and ``src_mask`` for
        cross-attention over the encoder output.
        """
        hidden = self._embed(self.tgt_embedding, tgt)
        for block in self.decoder_layers:
            hidden = block(hidden, encoder_output, tgt_mask, src_mask)
        return hidden

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return logits of shape (batch, tgt_len, vocab_size)."""
        memory = self.encode(src, src_mask)
        decoded = self.decode(tgt, memory, tgt_mask, src_mask)
        return self.output_projection(decoded)

# Test encoder-decoder with random token ids and different source (10) and
# target (8) lengths; logits must be (batch, tgt_len, vocab_size).
enc_dec_model = TransformerEncoderDecoder(config)
src = torch.randint(0, config.vocab_size, (2, 10))
tgt = torch.randint(0, config.vocab_size, (2, 8))
output = enc_dec_model(src, tgt)
assert output.shape == (2, 8, config.vocab_size)
print(f"Encoder-Decoder test passed. Output shape: {output.shape}")

## 8. Decoder-Only Transformer (for Language Modeling)

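For causal language modeling there is no source sequence to attend to, so the cross-attention sub-layer is dropped. Each block then consists of self-attention plus a feed-forward network — structurally the same as an encoder layer — and a causal mask keeps every position from attending to later tokens.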

In [None]:
class DecoderOnlyTransformer(nn.Module):
    """Causal language model built from self-attention-only blocks.

    The stack reuses ``EncoderLayer`` (self-attention plus feed-forward, no
    cross-attention) and makes it autoregressive by always applying a
    lower-triangular mask. The output projection shares its weight matrix
    with the token embedding.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Positional encoding
        self.positional_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # Transformer layers (decoder blocks without cross-attention)
        self.layers = nn.ModuleList([
            EncoderLayer(config.d_model, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(config.d_model)

        # Output projection (tied with embeddings for efficiency)
        self.output_projection = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: both modules now hold the same Parameter object.
        self.output_projection.weight = self.token_embedding.weight

        # Initialize weights
        self._init_weights()

        # Precompute the (max_seq_len, max_seq_len) causal mask once; as a
        # buffer it moves with the module on .to(device).
        self.register_buffer(
            'causal_mask',
            create_causal_mask(config.max_seq_len, 'cpu')[0, 0]
        )

    def _init_weights(self):
        """Draw Linear/Embedding weights from N(0, 0.02^2) and zero Linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return next-token logits for every position.

        Args:
            x: Token ids of shape (batch, seq_len), with
                ``seq_len <= config.max_seq_len``.
            pad_mask: Optional boolean mask that is AND-ed with the causal
                mask; it must broadcast against (batch, 1, seq_len, seq_len),
                e.g. the output of ``create_padding_mask``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        batch_size, seq_len = x.shape

        # Create causal mask
        causal_mask = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
        causal_mask = causal_mask.expand(batch_size, 1, seq_len, seq_len)

        # Combine with padding mask if provided
        if pad_mask is not None:
            mask = causal_mask & pad_mask
        else:
            mask = causal_mask

        # Embed tokens
        h = self.token_embedding(x) * math.sqrt(self.config.d_model)

        # Add positional encoding
        if self.config.use_positional_encoding:
            h = self.positional_encoding(h)

        # Apply transformer layers
        for layer in self.layers:
            h = layer(h, mask)

        # Final normalization and projection
        h = self.norm(h)
        logits = self.output_projection(h)

        return logits

    def generate(
        self,
        prompt: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """Autoregressively extend ``prompt`` by ``max_new_tokens`` tokens.

        Generation runs in eval mode without gradients. The module's
        previous train/eval mode is restored afterwards, so calling this
        in the middle of training does not silently disable dropout for
        the remaining steps.

        Args:
            prompt: Token ids of shape (batch, prompt_len).
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before sampling.
                ``0`` selects greedy (argmax) decoding instead of sampling.

        Returns:
            Tensor of shape (batch, prompt_len + max_new_tokens).
        """
        was_training = self.training
        self.eval()
        x = prompt.clone()

        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Condition on at most the last max_seq_len tokens.
                    x_cond = x[:, -self.config.max_seq_len:]

                    logits = self(x_cond)[:, -1, :]
                    if temperature == 0:
                        # Dividing by zero would give NaN probabilities;
                        # the zero-temperature limit is the argmax.
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        probs = F.softmax(logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    x = torch.cat([x, next_token], dim=1)
        finally:
            self.train(was_training)

        return x

# Test decoder-only model on random token ids: logits must be
# (batch, seq_len, vocab_size).
model = DecoderOnlyTransformer(config)
x = torch.randint(0, config.vocab_size, (2, 20))
output = model(x)
assert output.shape == (2, 20, config.vocab_size)
print(f"Decoder-only test passed. Output shape: {output.shape}")

# Count parameters (the tied embedding/output weight is counted once,
# since parameters() yields each shared Parameter a single time).
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

## 9. Training Loop

In [None]:
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    pad_idx: int
) -> Tuple[float, int]:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab_size) logits.
        dataloader: Yields ``(inputs, targets)`` batches of token ids.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        pad_idx: Target id that is excluded from the loss.

    Returns:
        ``(average loss per non-padding target token, number of such tokens)``.

    Raises:
        ZeroDivisionError: If no batch contained a non-padding target.
    """
    model.train()
    total_loss = 0.0
    total_tokens = 0

    for batch_x, batch_y in tqdm(dataloader, desc="Training", leave=False):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        num_tokens = (batch_y != pad_idx).sum().item()
        if num_tokens == 0:
            # Mean cross-entropy over zero valid targets is NaN; stepping on
            # it would overwrite the weights with NaN. Skip such batches.
            continue

        # Forward pass
        optimizer.zero_grad()
        logits = model(batch_x)

        # Compute loss (ignore padding)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            batch_y.view(-1),
            ignore_index=pad_idx
        )

        # Backward pass with gradient-norm clipping at 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # The loss is a per-token mean; weight it by the token count so the
        # epoch average is per token rather than per batch.
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    return total_loss / total_tokens, total_tokens


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    pad_idx: int
) -> float:
    """Compute the perplexity of ``model`` over ``dataloader``.

    Perplexity is ``exp`` of the summed cross-entropy divided by the number
    of targets that differ from ``pad_idx``. The model is switched to eval
    mode and no gradients are recorded.
    """
    model.eval()
    nll_sum = 0.0
    target_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            scores = model(inputs)
            vocab = scores.size(-1)
            batch_nll = F.cross_entropy(
                scores.view(-1, vocab),
                targets.view(-1),
                ignore_index=pad_idx,
                reduction='sum',
            )

            nll_sum += batch_nll.item()
            target_count += (targets != pad_idx).sum().item()

    return math.exp(nll_sum / target_count)

In [None]:
def train_model(
    config: TransformerConfig,
    dataloader: DataLoader,
    device: torch.device,
    tokenizer: SimpleTokenizer,
    model_name: str = "Transformer"
) -> Tuple[DecoderOnlyTransformer, List[float], List[float]]:
    """Build a ``DecoderOnlyTransformer`` from ``config`` and train it.

    Uses AdamW with a cosine-annealed learning rate over ``config.n_epochs``
    epochs and prints one progress line per epoch.

    Args:
        config: Model and training hyperparameters.
        dataloader: Batches used for training and also for the per-epoch
            perplexity measurement.
        device: Device the model and batches are placed on.
        tokenizer: Supplies ``pad_idx`` for the loss.
        model_name: Label printed in the training banner.

    Returns:
        ``(model, train_losses, perplexities)`` with one list entry per epoch.
    """
    model = DecoderOnlyTransformer(config).to(device)
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(0.9, 0.98),
        eps=1e-9
    )
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.n_epochs
    )
    
    train_losses = []
    perplexities = []
    
    print(f"\n{'='*50}")
    print(f"Training {model_name}")
    print(f"Config: heads={config.n_heads}, pos_enc={config.use_positional_encoding}")
    print(f"{'='*50}")
    
    start_time = time.time()
    total_tokens_processed = 0
    
    for epoch in range(config.n_epochs):
        epoch_start = time.time()
        
        train_loss, tokens = train_epoch(
            model, dataloader, optimizer, device, tokenizer.pad_idx
        )
        total_tokens_processed += tokens
        
        # NOTE: perplexity is measured on the same dataloader used for
        # training, so it reflects fit to the training data, not held-out data.
        perplexity = evaluate(model, dataloader, device, tokenizer.pad_idx)
        
        # Step scheduler (once per epoch)
        scheduler.step()
        
        # Record
        train_losses.append(train_loss)
        perplexities.append(perplexity)
        
        # epoch_time covers both the training pass and the evaluation pass,
        # while `tokens` counts training tokens only.
        epoch_time = time.time() - epoch_start
        tokens_per_sec = tokens / epoch_time
        
        print(f"Epoch {epoch+1}/{config.n_epochs} | "
              f"Loss: {train_loss:.4f} | "
              f"PPL: {perplexity:.2f} | "
              f"Time: {epoch_time:.1f}s | "
              f"Tokens/sec: {tokens_per_sec:.0f}")
    
    total_time = time.time() - start_time
    print(f"\nTotal training time: {total_time:.1f}s")
    print(f"Average tokens/sec: {total_tokens_processed / total_time:.0f}")
    
    return model, train_losses, perplexities

In [None]:
# Sanity check: Gradient flow
# Run one forward/backward pass on random tokens and verify that every
# trainable parameter received a finite gradient.
print("Checking gradient flow...")

test_model = DecoderOnlyTransformer(config).to(device)
test_x = torch.randint(0, config.vocab_size, (2, 20)).to(device)
test_y = torch.randint(0, config.vocab_size, (2, 20)).to(device)

test_model.zero_grad()
logits = test_model(test_x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), test_y.view(-1))
loss.backward()

# Check that all parameters have gradients
for name, param in test_model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"No gradient for {name}"
        assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}"
        assert not torch.isinf(param.grad).any(), f"Inf gradient for {name}"
        
print("All gradients are valid!")

In [None]:
# Sanity check: Attention weights shape
print("\nChecking attention weights...")

mha = MultiHeadAttention(d_model=config.d_model, n_heads=config.n_heads)
# A freshly constructed module is in training mode, where attention dropout
# zeroes and rescales entries of the stored weights, so their rows would not
# sum to 1. Switch to eval mode so the check sees the plain softmax output.
mha.eval()
x = torch.randn(2, 15, config.d_model)  # batch=2, seq_len=15
_ = mha(x, x, x)

attn_weights = mha.attention_weights
expected_shape = (2, config.n_heads, 15, 15)
assert attn_weights.shape == expected_shape, f"Expected {expected_shape}, got {attn_weights.shape}"
print(f"Attention weights shape: {attn_weights.shape} ✓")

# Check attention weights sum to 1
attn_sum = attn_weights.sum(dim=-1)
assert torch.allclose(attn_sum, torch.ones_like(attn_sum), atol=1e-5), "Attention weights don't sum to 1"
print("Attention weights sum to 1 ✓")

## 11. Training with Default Configuration

In [None]:
# Train with default config (8 heads, with positional encoding).
# Re-seed first so every run in this notebook starts from the same RNG state.
set_seed(42)
model_default, losses_default, ppl_default = train_model(
    config, dataloader, device, tokenizer, "Default (8 heads, with PE)"
)

## 12. Ablation 1: Positional Encoding ON vs OFF

In [None]:
# Ablation: Without positional encoding.
# Copy every field from the default config and change only the PE flag.
config_no_pe = TransformerConfig(
    d_model=config.d_model,
    n_heads=config.n_heads,
    n_layers=config.n_layers,
    d_ff=config.d_ff,
    max_seq_len=config.max_seq_len,
    vocab_size=config.vocab_size,
    dropout=config.dropout,
    batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    n_epochs=config.n_epochs,
    use_positional_encoding=False  # Disable PE
)

# Same seed as the default run so the two runs differ only in the PE flag.
set_seed(42)
model_no_pe, losses_no_pe, ppl_no_pe = train_model(
    config_no_pe, dataloader, device, tokenizer, "Without Positional Encoding"
)

## 13. Ablation 2: Number of Heads (4 vs 8)

In [None]:
# Ablation: 4 attention heads.
# Copy every field from the default config and change only n_heads.
config_4heads = TransformerConfig(
    d_model=config.d_model,
    n_heads=4,  # Changed from 8 to 4
    n_layers=config.n_layers,
    d_ff=config.d_ff,
    max_seq_len=config.max_seq_len,
    vocab_size=config.vocab_size,
    dropout=config.dropout,
    batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    n_epochs=config.n_epochs,
    use_positional_encoding=True
)

# Same seed as the default run so the two runs differ only in the head count.
set_seed(42)
model_4heads, losses_4heads, ppl_4heads = train_model(
    config_4heads, dataloader, device, tokenizer, "4 Attention Heads"
)

## 14. Results Visualization

In [None]:
# Plot training curves for the three runs side by side:
# left panel = per-epoch training loss, right panel = per-epoch perplexity.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
axes[0].plot(losses_default, label='8 heads, with PE', marker='o')
axes[0].plot(losses_no_pe, label='8 heads, no PE', marker='s')
axes[0].plot(losses_4heads, label='4 heads, with PE', marker='^')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Perplexity curves
axes[1].plot(ppl_default, label='8 heads, with PE', marker='o')
axes[1].plot(ppl_no_pe, label='8 heads, no PE', marker='s')
axes[1].plot(ppl_4heads, label='4 heads, with PE', marker='^')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Perplexity')
axes[1].set_title('Perplexity Comparison')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Results table: last-epoch training loss and perplexity for each run.
print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)
print(f"{'Configuration':<30} {'Final Loss':<15} {'Final PPL':<15}")
print("-"*70)
print(f"{'8 heads, with PE (default)':<30} {losses_default[-1]:<15.4f} {ppl_default[-1]:<15.2f}")
print(f"{'8 heads, without PE':<30} {losses_no_pe[-1]:<15.4f} {ppl_no_pe[-1]:<15.2f}")
print(f"{'4 heads, with PE':<30} {losses_4heads[-1]:<15.4f} {ppl_4heads[-1]:<15.2f}")
print("="*70)

## 15. Text Generation Demo

In [None]:
# Generate some text with the default model at three sampling temperatures.
model_default.eval()

# Create a prompt; words missing from the vocabulary encode to <UNK>.
prompt_text = "the quick brown"
prompt_tokens = tokenizer.encode(prompt_text)
prompt_tensor = torch.tensor([prompt_tokens]).to(device)  # shape (1, prompt_len)

print(f"Prompt: '{prompt_text}'")
print("\nGenerated continuations:")
print("-" * 50)

for temp in [0.7, 1.0, 1.3]:
    generated = model_default.generate(
        prompt_tensor, 
        max_new_tokens=20,
        temperature=temp
    )
    # The returned sequence includes the prompt followed by the new tokens.
    generated_text = tokenizer.decode(generated[0].tolist())
    print(f"Temperature {temp}: {generated_text}")

## 16. Speed Analysis

In [None]:
# Measure tokens per second
def measure_throughput(model, batch_size=32, seq_len=128, n_iters=50):
    """Measure forward-pass throughput of ``model`` in tokens per second.

    The vocabulary size and device are taken from the model being
    benchmarked (``model.config.vocab_size`` and the device of its first
    parameter). The notebook-level ``config`` / ``device`` are used only as
    a fallback when the model does not provide them.

    Args:
        model: Module called as ``model(x)`` with ``x`` of shape
            (batch_size, seq_len) holding random token ids.
        batch_size: Sequences per forward pass.
        seq_len: Tokens per sequence.
        n_iters: Number of timed forward passes (after 5 warm-up passes).

    Returns:
        ``batch_size * seq_len * n_iters`` divided by the elapsed seconds.
    """
    model.eval()

    vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
    if vocab_size is None:
        vocab_size = config.vocab_size
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device(device)

    x = torch.randint(0, vocab_size, (batch_size, seq_len)).to(run_device)

    def _sync():
        # CUDA kernels launch asynchronously; wait for them before reading the clock.
        if run_device.type == "cuda":
            torch.cuda.synchronize(run_device)

    with torch.no_grad():
        # Warmup
        for _ in range(5):
            _ = model(x)
        _sync()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(n_iters):
            _ = model(x)
        _sync()
        elapsed = time.perf_counter() - start

    total_tokens = batch_size * seq_len * n_iters
    return total_tokens / elapsed

# Benchmark the default model's forward pass and report the result.
throughput = measure_throughput(model_default)
print(f"Inference throughput: {throughput:,.0f} tokens/second")
print(f"Device: {device}")