# Attention Is All You Need — Transformer Implementation

A faithful implementation of the original Transformer from [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762).

**Task:** German → English translation on IWSLT 2014 (a scaled-down version of the paper's WMT 2014 benchmark).

**Architecture:** Same design as the paper, at reduced size (see `TransformerConfig`) — post-norm residual connections, sinusoidal positional encoding, three-way weight tying, label-smoothed cross-entropy, and the paper's exact LR schedule.

In [None]:
# Cell 0: Setup & Installs
!pip install datasets tokenizers sacrebleu -q

In [None]:
# Cell 1: Configuration

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import copy
import time
import random
import numpy as np
from dataclasses import dataclass
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

@dataclass
class TransformerConfig:
    """Hyperparameters for the model, training loop, tokenizer, and optimizer.

    Defaults are a scaled-down version of the paper's base model; each inline
    "Paper:" note records the value used in the paper. ``vocab_size`` and the
    four special-token IDs are placeholders that the data-pipeline cell
    overwrites from the trained tokenizer.
    """
    # Model architecture (scaled down from paper's base config)
    n_layers: int = 3           # Paper: 6
    d_model: int = 256          # Paper: 512
    d_ff: int = 512             # Paper: 2048
    n_heads: int = 4            # Paper: 8
    # NOTE: informational in the code shown here -- MultiHeadAttention derives
    # its per-head width as d_model // n_heads and never reads this field.
    d_k: int = 64               # Paper: 64 (same)
    dropout: float = 0.1        # Paper: 0.1 (same)
    # Used both as the tokenizer truncation length and as the size of the
    # positional-encoding table.
    max_seq_len: int = 128

    # Training
    label_smoothing: float = 0.1  # Paper: 0.1 (same)
    warmup_steps: int = 4000      # Paper: 4000 (same)
    n_epochs: int = 25
    batch_size: int = 64

    # Tokenizer
    # Requested BPE vocabulary size; replaced by the tokenizer's actual size.
    vocab_size: int = 10000

    # Adam params (Section 5.3)
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_eps: float = 1e-9       # Paper: 10^-9 (not PyTorch default 10^-8)

    # Special token IDs (set after tokenizer training)
    pad_idx: int = 0
    bos_idx: int = 1
    eos_idx: int = 2
    unk_idx: int = 3

    # Seed applied to Python, NumPy, and PyTorch RNGs.
    seed: int = 42

# Single shared config instance. Later cells read it as a global and the data
# pipeline overwrites its tokenizer-dependent fields.
config = TransformerConfig()

# Reproducibility
# Seed every RNG the notebook touches: Python, NumPy, PyTorch CPU, and all CUDA devices.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)

# Device
# Prefer the GPU when one is visible; everything falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"\nConfig: {config}")

In [None]:
# Cell 2: Data Pipeline

from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors

# 1. Download IWSLT14 DE-EN
print("Downloading IWSLT14 DE-EN...")
raw_dataset = load_dataset("bbaaaa/iwslt14-de-en")
print(f"Train: {len(raw_dataset['train'])} pairs")
print(f"Val:   {len(raw_dataset['validation'])} pairs")
print(f"Test:  {len(raw_dataset['test'])} pairs")
print(f"\nExample: {raw_dataset['train'][0]}")

# 2. Train a shared BPE tokenizer (paper Section 5.1: shared source-target vocabulary)
print("\nTraining shared BPE tokenizer...")

# Imported locally: the cell-level `from tokenizers import ...` line does not
# bring `decoders` into scope.
from tokenizers import decoders

tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# The byte-level pre-tokenizer rewrites spaces and non-ASCII bytes as printable
# stand-in characters (e.g. "Ġ"). Without the matching decoder, decode() would
# return those raw symbols joined by spaces, so translations, references, and
# BLEU would all be computed on subword garbage instead of real text.
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=config.vocab_size,
    special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"],
    min_frequency=2,
    show_progress=True,
)


# Combine DE + EN training text for shared vocabulary
def training_corpus():
    """Yield every German and English training sentence, interleaved per pair."""
    for example in raw_dataset["train"]:
        yield example["translation"]["de"]
        yield example["translation"]["en"]


tokenizer.train_from_iterator(training_corpus(), trainer=trainer)

# Update config with actual special token IDs
# Done before building the post-processor so the template uses the IDs the
# tokenizer really assigned instead of the config's placeholder defaults.
config.pad_idx = tokenizer.token_to_id("<pad>")
config.bos_idx = tokenizer.token_to_id("<bos>")
config.eos_idx = tokenizer.token_to_id("<eos>")
config.unk_idx = tokenizer.token_to_id("<unk>")
config.vocab_size = tokenizer.get_vocab_size()

# Set up post-processing to add <bos> and <eos>
tokenizer.post_processor = processors.TemplateProcessing(
    single="<bos> $A <eos>",
    pair="<bos> $A <eos> <bos> $B:1 <eos>:1",
    special_tokens=[
        ("<bos>", config.bos_idx),
        ("<eos>", config.eos_idx),
    ],
)

tokenizer.enable_padding(pad_id=config.pad_idx, pad_token="<pad>")
tokenizer.enable_truncation(max_length=config.max_seq_len)

print(f"Vocab size: {config.vocab_size}")
print(f"Special tokens — PAD: {config.pad_idx}, BOS: {config.bos_idx}, EOS: {config.eos_idx}, UNK: {config.unk_idx}")

# Test tokenizer
# The decoded line should reproduce the input text (special tokens are skipped).
test_enc = tokenizer.encode("Hello, this is a test.")
print(f"\nTokenizer test:")
print(f"  Input: 'Hello, this is a test.'")
print(f"  Tokens: {test_enc.tokens}")
print(f"  IDs: {test_enc.ids}")
print(f"  Decoded: '{tokenizer.decode(test_enc.ids)}'")

In [None]:
# Cell 2b: Dataset & DataLoader

class TranslationDataset(Dataset):
    """In-memory IWSLT14 DE-EN pairs, tokenized once at construction time.

    Each item is a ``(source_ids, target_ids)`` pair of 1-D ``torch.long``
    tensors built from the German and English side of one example.

    Args:
        split_data: iterable of examples shaped like
            ``{"translation": {"de": ..., "en": ...}}``.
        tokenizer: object whose ``encode(text)`` result exposes ``.ids``.
        max_len: accepted for interface symmetry but not used here; any length
            cap comes from the truncation configured on ``tokenizer`` itself.
    """

    def __init__(self, split_data, tokenizer, max_len):
        def encode_side(example, lang):
            # Raw text -> token IDs -> long tensor for one language side.
            token_ids = tokenizer.encode(example["translation"][lang]).ids
            return torch.tensor(token_ids, dtype=torch.long)

        self.pairs = [
            (encode_side(example, "de"), encode_side(example, "en"))
            for example in tqdm(split_data, desc="Tokenizing")
        ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]


def collate_fn(batch):
    """Batch variable-length pairs by right-padding each side independently.

    ``batch`` is a sequence of ``(src, tgt)`` tensor pairs. Each side is padded
    with the global ``config.pad_idx`` up to the longest sequence on that side
    and returned batch-first as ``(src_padded, tgt_padded)``.
    """
    def pad_side(sequences):
        return nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=config.pad_idx
        )

    sources, targets = zip(*batch)
    return pad_side(sources), pad_side(targets)


print("Tokenizing datasets...")
# Every split is tokenized eagerly. max_seq_len is forwarded, but the dataset
# class does not read it -- the length cap comes from the truncation enabled
# on the tokenizer.
train_dataset = TranslationDataset(raw_dataset["train"], tokenizer, config.max_seq_len)
val_dataset = TranslationDataset(raw_dataset["validation"], tokenizer, config.max_seq_len)
test_dataset = TranslationDataset(raw_dataset["test"], tokenizer, config.max_seq_len)

# Only the training loader shuffles and drops the final short batch.
train_loader = DataLoader(
    train_dataset, batch_size=config.batch_size, shuffle=True,
    collate_fn=collate_fn, num_workers=2, pin_memory=True, drop_last=True
)
# Validation and test keep dataset order (the evaluation cell relies on the
# test order matching raw_dataset["test"]).
val_loader = DataLoader(
    val_dataset, batch_size=config.batch_size, shuffle=False,
    collate_fn=collate_fn, num_workers=2, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=config.batch_size, shuffle=False,
    collate_fn=collate_fn, num_workers=2, pin_memory=True
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")

# Quick check
# Pull one batch to confirm collation yields two batch-first padded tensors.
src_batch, tgt_batch = next(iter(train_loader))
print(f"\nSample batch shapes — src: {src_batch.shape}, tgt: {tgt_batch.shape}")

In [None]:
# Cell 3: Positional Encoding (Section 3.5)
#
# PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``: even columns hold the
    sine terms and odd columns the cosine terms of the same frequencies.

    Args:
        d_model: embedding width. May be odd; the last sine column then has no
            cosine partner.
        dropout: dropout probability applied after the addition.
        max_len: number of positions to precompute; inputs longer than this
            cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # exp(-2i * ln(10000) / d_model) == 1 / 10000^(2i / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so drop the unmatched last frequency instead of failing on shape.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # Buffer, not parameter: moves with the module but is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (batch, seq_len, d_model)."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Quick visualization
# Build a throwaway module and plot the first 128 positions of its table
# (positions on the x-axis, embedding dimensions on the y-axis).
pe_module = PositionalEncoding(config.d_model)
pe_values = pe_module.pe[0, :128, :].numpy()

fig, ax = plt.subplots(figsize=(10, 4))
# Transposed so that each row of the image is one embedding dimension.
cax = ax.imshow(pe_values.T, aspect='auto', cmap='RdBu')
ax.set_xlabel('Position')
ax.set_ylabel('Dimension')
ax.set_title('Sinusoidal Positional Encoding')
fig.colorbar(cax)
plt.tight_layout()
plt.show()

In [None]:
# Cell 4: Multi-Head Attention (Section 3.2)
#
# Scaled Dot-Product Attention:
#   Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
#
# Multi-Head Attention:
#   MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O
#   where head_i = Attention(Q W_Qi, K W_Ki, V W_Vi)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a final output projection.

    ``d_model`` is split evenly across ``n_heads`` heads of width
    ``d_k = d_model // n_heads``. Each of Q, K, V is projected by one
    ``d_model x d_model`` linear layer and then reshaped into heads.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # One fused projection per role; heads are carved out by reshaping.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # mixes the concatenated heads

        self.attn_dropout = nn.Dropout(p=dropout)

    def _split_heads(self, projected, batch_size):
        """(batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)."""
        per_head = projected.view(batch_size, -1, self.n_heads, self.d_k)
        return per_head.transpose(1, 2)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute ``softmax(q k^T / sqrt(d_k)) v`` with optional masking.

        Args:
            q, k, v: (batch, n_heads, seq_len, d_k)
            mask: broadcastable to (batch, n_heads, seq_len_q, seq_len_k);
                positions where it equals 0 are excluded from attention.
        Returns:
            ``(output, attn_weights)``; the weights are returned after dropout.
        """
        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        if mask is not None:
            # -inf before softmax gives masked keys exactly zero weight.
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        attn_weights = self.attn_dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attn_weights, v), attn_weights

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Returns ``(output, attn_weights)`` where ``output`` has shape
        (batch, seq_len_q, d_model).
        """
        batch_size = query.size(0)

        # Project, then split into heads.
        q = self._split_heads(self.w_q(query), batch_size)
        k = self._split_heads(self.w_k(key), batch_size)
        v = self._split_heads(self.w_v(value), batch_size)

        per_head_out, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)

        # Undo the head split: (batch, n_heads, seq, d_k) -> (batch, seq, d_model).
        merged = per_head_out.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.d_model)

        return self.w_o(merged), attn_weights

# Status line; d_k is shown the way the module derives it (d_model // n_heads).
print("MultiHeadAttention ready.")
print(f"  d_model={config.d_model}, n_heads={config.n_heads}, d_k={config.d_model // config.n_heads}")

In [None]:
# Cell 5: Position-wise Feed-Forward Network (Section 3.3)
#
# FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
# Inner dimension d_ff, with ReLU activation.

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied over the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``: expand from ``d_model`` to
    ``d_ff``, apply ReLU and dropout, then project back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expansion
        self.linear2 = nn.Linear(d_ff, d_model)   # contraction
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# Status line echoing the outer and inner widths taken from the config.
print("PositionwiseFeedForward ready.")
print(f"  d_model={config.d_model}, d_ff={config.d_ff}")

In [None]:
# Cell 6: Encoder Layer & Encoder Stack (Section 3.1)
#
# Each encoder layer has two sub-layers:
#   1. Multi-Head Self-Attention
#   2. Position-wise FFN
# Each sub-layer: LayerNorm(x + Sublayer(x))  [post-norm, as in the paper]

class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise FFN.

    Both sub-layers use the post-norm residual form
    ``LayerNorm(x + Dropout(Sublayer(x)))``.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        # One LayerNorm/Dropout pair per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, x, src_mask):
        """Run both sub-layers on ``x``; ``src_mask`` is passed to self-attention."""
        # Residual 1: self-attention (the attention weights are discarded).
        attended, _ = self.self_attn(x, x, x, src_mask)
        after_attn = self.norm1(x + self.dropout1(attended))

        # Residual 2: feed-forward.
        transformed = self.feed_forward(after_attn)
        return self.norm2(after_attn + self.dropout2(transformed))


class Encoder(nn.Module):
    """Token embedding + positional encoding + a stack of ``EncoderLayer`` blocks."""

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout, max_len, pad_idx):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, dropout, max_len)
        blocks = [EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, src, src_mask):
        """Encode token IDs ``src``; ``src_mask`` is forwarded to every layer."""
        # Embeddings are multiplied by sqrt(d_model) before positions are added (Section 3.4).
        scaled = self.embedding(src) * math.sqrt(self.d_model)
        hidden = self.pos_encoding(scaled)

        for block in self.layers:
            hidden = block(hidden, src_mask)

        return hidden

# Status line; no encoder is instantiated here, the values come straight from the config.
print("Encoder ready.")
print(f"  {config.n_layers} layers, d_model={config.d_model}, n_heads={config.n_heads}, d_ff={config.d_ff}")

In [None]:
# Cell 7: Decoder Layer & Decoder Stack (Section 3.1)
#
# Each decoder layer has three sub-layers:
#   1. Masked Multi-Head Self-Attention (causal mask)
#   2. Multi-Head Encoder-Decoder Attention (Q from decoder, K/V from encoder)
#   3. Position-wise FFN

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, FFN.

    Every sub-layer is wrapped as ``LayerNorm(x + Dropout(Sublayer(x)))``
    (post-norm).
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        # One LayerNorm/Dropout pair per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.dropout3 = nn.Dropout(p=dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Apply the three sub-layers to ``x``.

        ``tgt_mask`` restricts self-attention; ``src_mask`` restricts attention
        over ``enc_output``.
        """
        # Residual 1: self-attention over the target prefix.
        self_attended, _ = self.self_attn(x, x, x, tgt_mask)
        h = self.norm1(x + self.dropout1(self_attended))

        # Residual 2: queries from the decoder, keys/values from the encoder.
        cross_attended, _ = self.cross_attn(h, enc_output, enc_output, src_mask)
        h = self.norm2(h + self.dropout2(cross_attended))

        # Residual 3: feed-forward.
        transformed = self.feed_forward(h)
        return self.norm3(h + self.dropout3(transformed))


class Decoder(nn.Module):
    """Token embedding + positional encoding + a stack of ``DecoderLayer`` blocks."""

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout, max_len, pad_idx):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, dropout, max_len)
        blocks = [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, tgt, enc_output, src_mask, tgt_mask):
        """Decode token IDs ``tgt`` against ``enc_output``; returns hidden states, not logits."""
        # Embeddings are multiplied by sqrt(d_model) before positions are added (Section 3.4).
        scaled = self.embedding(tgt) * math.sqrt(self.d_model)
        hidden = self.pos_encoding(scaled)

        for block in self.layers:
            hidden = block(hidden, enc_output, src_mask, tgt_mask)

        return hidden

# Status line; no decoder is instantiated here, the values come straight from the config.
print("Decoder ready.")
print(f"  {config.n_layers} layers, d_model={config.d_model}, n_heads={config.n_heads}, d_ff={config.d_ff}")

In [None]:
# Cell 8: Full Transformer (Section 3)
#
# Combines Encoder + Decoder + final linear projection.
# Weight tying (Section 3.4): encoder embedding = decoder embedding = pre-softmax linear.

def create_padding_mask(seq, pad_idx):
    """Create mask to hide padding tokens.

    Args:
        seq: (batch, seq_len) token IDs.
        pad_idx: ID of the padding token.
    Returns:
        Boolean tensor of shape (batch, 1, 1, seq_len), True where the token
        is not padding; broadcastable over heads and query positions.
    """
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # True where not padding


def create_causal_mask(size, device):
    """Create causal (look-ahead) mask for decoder self-attention.

    Lower-triangular mask: position i can only attend to positions <= i.

    Returns:
        Float tensor of shape (1, 1, size, size) holding 1 on and below the
        diagonal and 0 above it.
    """
    mask = torch.tril(torch.ones(size, size, device=device)).unsqueeze(0).unsqueeze(0)
    return mask  # 1 in lower triangle, 0 in upper triangle


def create_tgt_mask(tgt, pad_idx):
    """Combined padding + causal mask for decoder.

    Returns:
        Boolean tensor of shape (batch, 1, tgt_len, tgt_len); entry [b, 0, i, j]
        is True when query position i may attend to key position j, i.e. when
        j <= i and token j of sequence b is not padding.
    """
    tgt_pad_mask = create_padding_mask(tgt, pad_idx)  # (batch, 1, 1, tgt_len), bool
    tgt_causal_mask = create_causal_mask(tgt.size(1), tgt.device)  # (1, 1, tgt_len, tgt_len), float
    # `&` is only defined for boolean/integer tensors, so the float causal mask
    # must be converted before it is combined with the boolean padding mask.
    return tgt_pad_mask & tgt_causal_mask.bool()  # broadcast: (batch, 1, tgt_len, tgt_len)


class Transformer(nn.Module):
    """Encoder-decoder Transformer with a shared embedding / output matrix.

    A single weight tensor serves as the encoder embedding, the decoder
    embedding, and the (bias-free) pre-softmax projection (Section 3.4).
    ``forward`` builds the attention masks itself from ``config.pad_idx``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoder and decoder stacks take the same hyperparameters.
        stack_args = (
            config.vocab_size, config.d_model, config.n_layers,
            config.n_heads, config.d_ff, config.dropout,
            config.max_seq_len, config.pad_idx,
        )
        self.encoder = Encoder(*stack_args)
        self.decoder = Decoder(*stack_args)

        # Pre-softmax linear layer
        self.output_projection = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Three-way weight tying: point all three modules at one Parameter.
        shared_weight = self.encoder.embedding.weight
        self.decoder.embedding.weight = shared_weight
        self.output_projection.weight = shared_weight

        # Runs after tying, so the shared matrix is initialised once.
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for matrix in matrices:
            nn.init.xavier_uniform_(matrix)

    def encode(self, src, src_mask):
        """Run only the encoder stack."""
        return self.encoder(src, src_mask)

    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        """Run only the decoder stack; returns hidden states, not logits."""
        return self.decoder(tgt, enc_output, src_mask, tgt_mask)

    def forward(self, src, tgt):
        """Return vocabulary logits for every position of ``tgt`` given ``src``."""
        pad_idx = self.config.pad_idx
        src_mask = create_padding_mask(src, pad_idx)
        tgt_mask = create_tgt_mask(tgt, pad_idx)

        memory = self.encode(src, src_mask)
        hidden = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.output_projection(hidden)


# Instantiate and count parameters
# The config must already carry the tokenizer's real vocab size and pad index.
model = Transformer(config).to(device)
# parameters() yields a tied tensor only once, so the shared embedding matrix
# is not double-counted.
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Transformer model created.")
print(f"  Total parameters:     {n_params:,}")
print(f"  Trainable parameters: {n_trainable:,}")
print(f"\nArchitecture:")
print(f"  Encoder: {config.n_layers} layers")
print(f"  Decoder: {config.n_layers} layers")
print(f"  d_model={config.d_model}, n_heads={config.n_heads}, d_ff={config.d_ff}")
print(f"  Weight tying: encoder emb = decoder emb = output projection")

In [None]:
# Cell 9: Label-Smoothed Cross-Entropy Loss (Section 5.4)
#
# True label gets probability (1 - eps), remaining eps distributed
# uniformly across all other classes. Ignores padding index.
# Paper: "hurts perplexity but improves accuracy and BLEU."

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    For each non-padding target the true class receives probability
    ``1 - smoothing`` and the remaining ``smoothing`` mass is spread uniformly
    over every class except the true one and the padding class. Positions
    whose target is ``pad_idx`` contribute nothing.

    Args:
        vocab_size: number of classes (must exceed 2).
        pad_idx: class index of the padding token.
        smoothing: probability mass moved away from the true class.
    """

    def __init__(self, vocab_size, pad_idx, smoothing=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits, target):
        """Return the mean smoothed cross-entropy over non-padding positions.

        Args:
            logits: (batch * seq_len, vocab_size)
            target: (batch * seq_len,)
        Returns:
            Scalar tensor. If every target is padding the result is 0 rather
            than NaN, so such a batch cannot poison the gradients.
        """
        log_probs = F.log_softmax(logits, dim=-1)

        # Create smoothed target distribution
        # Fill with uniform smoothing value
        smooth_targets = torch.full_like(log_probs, self.smoothing / (self.vocab_size - 2))  # -2 for pad and true label
        # Set the true label's probability
        smooth_targets.scatter_(1, target.unsqueeze(1), self.confidence)
        # The padding class never receives probability mass
        smooth_targets[:, self.pad_idx] = 0

        # Mask out padding positions entirely
        pad_mask = target == self.pad_idx
        smooth_targets[pad_mask] = 0

        # Cross-entropy with the smoothed targets: sum over vocab per position
        token_loss = -(smooth_targets * log_probs).sum(dim=-1)

        # Mean over non-pad tokens. The clamp keeps the divisor at 1 when there
        # are none, where a plain mean over an empty selection would be NaN.
        n_tokens = (~pad_mask).sum().clamp(min=1)
        return token_loss[~pad_mask].sum() / n_tokens

# Built after the data pipeline so vocab_size and pad_idx reflect the trained tokenizer.
criterion = LabelSmoothingLoss(config.vocab_size, config.pad_idx, config.label_smoothing)
print(f"Label Smoothing Loss ready (eps={config.label_smoothing})")

In [None]:
# Cell 10: Optimizer & LR Schedule (Section 5.3)
#
# lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
# Linear warmup for warmup_steps, then decay proportional to 1/sqrt(step).

def get_lr_lambda(d_model, warmup_steps):
    """Build the step -> learning-rate function of Section 5.3.

    The returned callable computes
    ``d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)``, treating any
    step below 1 as step 1.
    """
    def lr_factor(step):
        step = max(step, 1)  # step 0 would divide by zero in step ** -0.5
        inverse_sqrt_decay = step ** -0.5
        linear_warmup = step * warmup_steps ** -1.5
        return (d_model ** -0.5) * min(inverse_sqrt_decay, linear_warmup)

    return lr_factor

# LambdaLR multiplies the optimizer's base LR by the lambda's return value, so
# a base LR of 1.0 makes the schedule function the effective learning rate.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1.0,  # actual LR controlled by scheduler
    betas=(config.adam_beta1, config.adam_beta2),
    eps=config.adam_eps
)

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=get_lr_lambda(config.d_model, config.warmup_steps)
)

# Visualize the LR schedule
# A separate schedule function is evaluated here, so plotting does not advance
# the real scheduler.
steps = list(range(1, 60001))
lr_fn = get_lr_lambda(config.d_model, config.warmup_steps)
lrs = [lr_fn(s) for s in steps]

fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(steps, lrs)
ax.set_xlabel('Step')
ax.set_ylabel('Learning Rate')
ax.set_title(f'Transformer LR Schedule (warmup={config.warmup_steps}, d_model={config.d_model})')
# Dashed line marks the end of warmup, where the two branches of min() meet.
ax.axvline(x=config.warmup_steps, color='r', linestyle='--', alpha=0.5, label=f'warmup={config.warmup_steps}')
ax.legend()
plt.tight_layout()
plt.show()

print(f"Peak LR at step {config.warmup_steps}: {lr_fn(config.warmup_steps):.6f}")

In [None]:
# Cell 11: Training Loop

def train_epoch(model, dataloader, optimizer, scheduler, criterion, device, epoch):
    """Train for one pass over ``dataloader`` and return the mean per-token loss.

    Each batch is teacher-forced: the decoder sees the target without its last
    token and is scored against the target without its first token. The
    scheduler is stepped once per batch. The returned value weights every
    batch loss by its number of non-padding target tokens (padding is
    identified via the global ``config.pad_idx``).
    """
    model.train()
    loss_sum, token_count = 0, 0
    started = time.time()

    progress = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch_idx, (src, tgt) in enumerate(progress):
        src, tgt = src.to(device), tgt.to(device)

        # Shift by one: decoder input drops the final token, labels drop <bos>.
        decoder_input, labels = tgt[:, :-1], tgt[:, 1:]

        logits = model(src, decoder_input)

        # Flatten to (batch * seq_len, vocab) / (batch * seq_len,) for the criterion.
        loss = criterion(
            logits.contiguous().view(-1, logits.size(-1)),
            labels.contiguous().view(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        batch_tokens = (labels != config.pad_idx).sum().item()
        loss_sum += batch_loss * batch_tokens
        token_count += batch_tokens

        # Refresh the progress-bar readout only every 100 batches.
        if batch_idx % 100 == 0:
            elapsed = time.time() - started
            progress.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}',
                'tok/s': f'{token_count / elapsed:.0f}'
            })

    return loss_sum / token_count


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-token loss of ``model`` over ``dataloader``.

    Uses the same shifted teacher-forcing setup as training, in eval mode and
    without gradients. Batch losses are weighted by their count of non-padding
    target tokens (padding is identified via the global ``config.pad_idx``).
    """
    model.eval()
    loss_sum, token_count = 0, 0

    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        decoder_input, labels = tgt[:, :-1], tgt[:, 1:]

        logits = model(src, decoder_input)
        batch_loss = criterion(
            logits.contiguous().view(-1, logits.size(-1)),
            labels.contiguous().view(-1),
        )

        batch_tokens = (labels != config.pad_idx).sum().item()
        loss_sum += batch_loss.item() * batch_tokens
        token_count += batch_tokens

    return loss_sum / token_count

# This cell only defines the functions; training itself is launched in a later cell.
print("Training functions ready.")

In [None]:
# Cell 12: Greedy Decoding

@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len, bos_idx, eos_idx, device):
    """
    Generate a translation one token at a time, always taking the argmax.

    The source is encoded once; the decoder is then re-run on the whole
    growing prefix at every step. Only a single sentence is supported.

    Args:
        src: (1, src_len) source token IDs
        src_mask: (1, 1, 1, src_len) padding mask
        max_len: maximum generation length, counting the initial <bos>
    Returns:
        List of generated token IDs (excluding <bos>); the last one is <eos>
        unless the length limit was reached first.
    """
    model.eval()

    memory = model.encode(src, src_mask)

    # The prefix starts as just <bos> and grows by one column per step.
    generated = torch.tensor([[bos_idx]], dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        causal_mask = create_causal_mask(generated.size(1), device)
        hidden = model.decode(generated, memory, src_mask, causal_mask)
        # Only the final position predicts the next token.
        step_logits = model.output_projection(hidden[:, -1, :])
        chosen = step_logits.argmax(dim=-1, keepdim=True)  # (1, 1)
        generated = torch.cat([generated, chosen], dim=1)

        if chosen.item() == eos_idx:
            break

    return generated[0, 1:].tolist()  # drop the leading <bos>


def translate_sentence(model, sentence, tokenizer, config, device):
    """Greedy-translate one German sentence and return the English text.

    The sentence is tokenized, decoded with ``greedy_decode`` up to
    ``config.max_seq_len`` tokens, and a trailing <eos> is dropped before the
    IDs are turned back into a string.
    """
    model.eval()

    encoded = tokenizer.encode(sentence)
    src = torch.tensor([encoded.ids], dtype=torch.long, device=device)

    token_ids = greedy_decode(
        model, src, create_padding_mask(src, config.pad_idx), config.max_seq_len,
        config.bos_idx, config.eos_idx, device,
    )

    ends_with_eos = bool(token_ids) and token_ids[-1] == config.eos_idx
    return tokenizer.decode(token_ids[:-1] if ends_with_eos else token_ids)

# This cell only defines the decoding helpers; nothing has been translated yet.
print("Greedy decoding ready.")

In [None]:
# Cell 13: Beam Search Decoding
#
# Paper: beam size 4, length penalty alpha=0.6

@torch.no_grad()
def beam_search_decode(model, src, src_mask, max_len, bos_idx, eos_idx, device,
                       beam_size=4, alpha=0.6):
    """
    Beam search decoding with length normalization.

    Live hypotheses are ranked by their summed token log-probability. Once a
    hypothesis ends in <eos> it is scored as ``score / (length^alpha)``, where
    ``length`` counts every token in it including <bos> and <eos>. The best
    length-normalized hypothesis is returned.

    The search stops before ``max_len`` as soon as no live hypothesis can
    still overtake the best finished one. This is an exact bound, so it only
    saves decoder passes and does not change which hypothesis wins.

    Args:
        src: (1, src_len) source token IDs
        src_mask: (1, 1, 1, src_len) padding mask
        max_len: maximum hypothesis length, counting the initial <bos>
        beam_size: number of hypotheses kept after each step
        alpha: length-normalization exponent
    Returns:
        List of token IDs of the best hypothesis, without <bos> and without a
        trailing <eos>.
    """
    model.eval()

    enc_output = model.encode(src, src_mask)  # (1, src_len, d_model), reused every step

    completed = []  # (length-normalized score, tokens) of finished hypotheses
    best_completed = float('-inf')

    def split_finished(hypotheses):
        """Record hypotheses ending in <eos>; return the others in the same order."""
        nonlocal best_completed
        still_live = []
        for score, seq in hypotheses:
            if seq[-1] == eos_idx:
                norm_score = score / (len(seq) ** alpha)
                completed.append((norm_score, seq))
                best_completed = max(best_completed, norm_score)
            else:
                still_live.append((score, seq))
        return still_live

    # Each live beam: (raw summed log-prob, [token_ids]), best first.
    live = split_finished([(0.0, [bos_idx])])

    for _ in range(max_len - 1):
        if not live:
            break

        if completed:
            # Raw scores are sums of log-probs (<= 0) and only fall as tokens
            # are added, so the best any descendant can reach is the current
            # best raw score divided by the largest divisor it could still
            # get. If that cannot beat the best finished score, stop.
            next_len = len(live[0][1]) + 1
            largest_divisor = max(next_len ** alpha, max_len ** alpha)
            if best_completed >= live[0][0] / largest_divisor:
                break

        candidates = []
        for score, seq in live:
            ys = torch.tensor([seq], dtype=torch.long, device=device)
            tgt_mask = create_causal_mask(ys.size(1), device)
            dec_output = model.decode(ys, enc_output, src_mask, tgt_mask)
            logits = model.output_projection(dec_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # topk raises if asked for more entries than the vocabulary holds.
            k = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_ids = log_probs.topk(k, dim=-1)

            for log_prob, token in zip(topk_log_probs[0].tolist(), topk_ids[0].tolist()):
                candidates.append((score + log_prob, seq + [token]))

        # Keep top beam_size candidates; finished ones leave the live set here.
        candidates.sort(key=lambda c: c[0], reverse=True)
        live = split_finished(candidates[:beam_size])
    else:
        # Length budget exhausted: unfinished hypotheses compete as they are.
        completed.extend((score / (len(seq) ** alpha), seq) for score, seq in live)

    if not completed:
        # Only reachable when beam_size < 1 left nothing to expand.
        return []

    # max() returns the earliest of equally scored hypotheses, like a stable sort.
    best_seq = max(completed, key=lambda c: c[0])[1]

    # Remove <bos> and <eos>
    result = best_seq[1:]  # remove <bos>
    if result and result[-1] == eos_idx:
        result = result[:-1]

    return result

# The numbers in this message are the function's default arguments, written out by hand.
print(f"Beam search decoding ready (beam_size=4, alpha=0.6).")

In [None]:
# Cell 14: BLEU Evaluation

import sacrebleu

def compute_bleu(model, test_loader, tokenizer, config, device, use_beam=False):
    """Compute corpus-level BLEU on the test set.

    Every sentence is decoded individually (greedy by default, beam search
    when ``use_beam`` is true). Hypotheses and references are both produced by
    ``tokenizer.decode``, so references are the detokenized target IDs from
    the loader, not the raw dataset text.

    Returns:
        ``(bleu, predictions, references)`` where ``bleu`` is the sacrebleu
        result object and the two lists hold one string per test sentence, in
        loader order.
    """
    model.eval()
    special_ids = {config.pad_idx, config.bos_idx, config.eos_idx}
    predictions = []
    references = []

    for src_batch, tgt_batch in tqdm(test_loader, desc="Translating test set"):
        for src_row, tgt_row in zip(src_batch, tgt_batch):
            # Get source (remove padding): drop the batch padding so the
            # encoder and cross-attention run only over real tokens.
            src = src_row[src_row != config.pad_idx].unsqueeze(0).to(device)
            src_mask = create_padding_mask(src, config.pad_idx)

            # Decode
            if use_beam:
                # Beam search already strips <bos> and a trailing <eos>.
                output_ids = beam_search_decode(
                    model, src, src_mask, config.max_seq_len,
                    config.bos_idx, config.eos_idx, device
                )
            else:
                output_ids = greedy_decode(
                    model, src, src_mask, config.max_seq_len,
                    config.bos_idx, config.eos_idx, device
                )
                # Remove <eos> if present
                if output_ids and output_ids[-1] == config.eos_idx:
                    output_ids = output_ids[:-1]

            # Decode to text
            predictions.append(tokenizer.decode(output_ids))

            # Reference: decode target tokens (remove <bos>, <eos>, <pad>)
            ref_ids = [t for t in tgt_row.tolist() if t not in special_ids]
            references.append(tokenizer.decode(ref_ids))

    # Compute BLEU using sacrebleu (one reference stream)
    bleu = sacrebleu.corpus_bleu(predictions, [references])
    return bleu, predictions, references

# This cell only defines compute_bleu; scoring is run in the evaluation cell.
print("BLEU evaluation ready.")

In [None]:
# Cell 15: Run Training

# Per-epoch history for the loss plot, plus the best validation loss seen so far.
train_losses = []
val_losses = []
best_val_loss = float('inf')

print(f"Training for {config.n_epochs} epochs...")
print(f"{'='*60}")

for epoch in range(1, config.n_epochs + 1):
    epoch_start = time.time()

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, device, epoch)
    train_losses.append(train_loss)

    # Validate
    val_loss = evaluate(model, val_loader, criterion, device)
    val_losses.append(val_loss)

    # Covers both the training pass and the validation pass.
    epoch_time = time.time() - epoch_start

    # Save best model
    # Only the weights are checkpointed; the file is overwritten on every improvement.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_transformer.pt')
        marker = ' *'
    else:
        marker = ''

    # A trailing " *" flags epochs that produced a new best checkpoint.
    print(f"Epoch {epoch:2d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"Time: {epoch_time:.1f}s | LR: {scheduler.get_last_lr()[0]:.2e}{marker}")

    # Show a sample translation every 5 epochs
    # Always the first validation pair, translated greedily with the current weights.
    if epoch % 5 == 0:
        sample_de = raw_dataset['validation'][0]['translation']['de']
        sample_en = raw_dataset['validation'][0]['translation']['en']
        pred_en = translate_sentence(model, sample_de, tokenizer, config, device)
        print(f"  Sample — DE: {sample_de}")
        print(f"           EN: {sample_en}")
        print(f"         Pred: {pred_en}")

print(f"{'='*60}")
print(f"Best validation loss: {best_val_loss:.4f}")

# Plot loss curves
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss')
ax.plot(range(1, len(val_losses)+1), val_losses, label='Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training & Validation Loss')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Cell 16: Run Evaluation & Show Results

# Load best checkpoint
# Restores the weights saved at the lowest validation loss, replacing the
# final-epoch weights currently in the model.
model.load_state_dict(torch.load('best_transformer.pt', map_location=device))
print("Loaded best model checkpoint.\n")

# Compute BLEU on test set (greedy decoding)
print("Evaluating with greedy decoding...")
bleu_greedy, preds_greedy, refs = compute_bleu(model, test_loader, tokenizer, config, device, use_beam=False)
print(f"\nGreedy BLEU: {bleu_greedy.score:.2f}")
print(bleu_greedy)

# Compute BLEU with beam search
# The references from this second pass are discarded; `refs` above is reused.
print("\nEvaluating with beam search (beam=4, alpha=0.6)...")
bleu_beam, preds_beam, _ = compute_bleu(model, test_loader, tokenizer, config, device, use_beam=True)
print(f"\nBeam Search BLEU: {bleu_beam.score:.2f}")
print(bleu_beam)

# Show sample translations
print(f"\n{'='*80}")
print("Sample Translations (from test set)")
print(f"{'='*80}")

# random.sample raises ValueError if the test set has fewer than n_samples sentences.
n_samples = 10
indices = random.sample(range(len(refs)), n_samples)

for i, idx in enumerate(indices):
    # Valid only because test_loader is unshuffled, so position idx in the
    # prediction lists lines up with row idx of the raw test split.
    src_text = raw_dataset['test'][idx]['translation']['de']
    print(f"\n--- Example {i+1} ---")
    print(f"  Source (DE):     {src_text}")
    print(f"  Reference (EN):  {refs[idx]}")
    print(f"  Greedy (EN):     {preds_greedy[idx]}")
    print(f"  Beam (EN):       {preds_beam[idx]}")

print(f"\n{'='*80}")
print(f"Final Results:")
print(f"  Greedy BLEU:      {bleu_greedy.score:.2f}")
print(f"  Beam Search BLEU: {bleu_beam.score:.2f}")
print(f"  Model params:     {sum(p.numel() for p in model.parameters()):,}")
print(f"  Config: N={config.n_layers}, d_model={config.d_model}, d_ff={config.d_ff}, h={config.n_heads}")