# NMT Homework (Self-Contained): EN→DE

Train a translation model (English→German), measure perplexity and BLEU, save a checkpoint, and optionally export predictions for ML‑Arena.

Focus: experiment with modeling choices (e.g. LSTM with attention, Transformer, alternative decoding strategies) rather than boilerplate. The core evaluation functions are provided so that scores are computed consistently across students.

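Data: the course staff provides `dataset_splits/` in the repo root. Each file holds one sentence pair per line: the English source and the German target, separated by a tab. No additional setup is needed for data.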

## 0. Setup
Use `install.sh` or `pip install -r requirements.txt` to set up.

In [1]:
# !pip install torch tqdm
import torch
import sys
import os
import math
import random

print('PyTorch version:', torch.__version__)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Best effort: line-buffer stdout so progress prints appear promptly.
# Any failure (e.g. a stdout object without reconfigure()) is ignored.
try:
    sys.stdout.reconfigure(line_buffering=True)
except Exception:
    pass

PyTorch version: 2.8.0+cu128
Using device: cuda


## 1. Shared Utilities (PyTorch and the standard library only)
Tokenization, vocabulary, dataset, collate, and fixed evaluation (PPL, NLL, BLEU).

In [None]:
from typing import List, Tuple, Dict, Iterable
from collections import Counter
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def set_seed(seed: int = 42):
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device).

    Args:
        seed: Value handed to each seeding call.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Reserved vocabulary entries. build_vocab pins them to ids
# 0 (pad), 1 (sos), 2 (eos) and 3 (unk).
SPECIAL_TOKENS = dict(
    pad='<pad>',
    sos='<sos>',
    eos='<eos>',
    unk='<unk>',
)


def simple_tokenize(s: str) -> List[str]:
    """Lowercase ``s`` and split it on runs of whitespace.

    Leading and trailing whitespace is discarded, so a blank string yields ``[]``.
    """
    normalized = s.strip().lower()
    return normalized.split()


def read_split(path: str) -> List[Tuple[List[str], List[str]]]:
    """Load a split file of tab-separated (source, target) sentence pairs.

    Each line is split on tabs. The first field is the source sentence and the
    second is the target; further fields are ignored. Lines with fewer than two
    fields are skipped. Both sentences are tokenized with ``simple_tokenize``.

    Args:
        path: UTF-8 text file to read.

    Returns:
        A list of ``(source_tokens, target_tokens)`` tuples in file order.
    """
    result: List[Tuple[List[str], List[str]]] = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            fields = raw.rstrip('\n').split('\t')
            if len(fields) >= 2:
                source, target = fields[0], fields[1]
                result.append((simple_tokenize(source), simple_tokenize(target)))
    return result


def build_vocab(
    seqs: Iterable[List[str]],
    max_size: int | None = None
) -> Dict[str, int]:
    """Assign integer ids to tokens, with the special tokens fixed at ids 0-3.

    Args:
        seqs: Token sequences to count. They are consumed once, so a generator is fine.
        max_size: Keep only the ``max_size`` most frequent tokens. ``None`` (or
            any falsy value such as ``0``) keeps every token, in first-seen order.

    Returns:
        A token -> id mapping. Ids 0-3 are pad, sos, eos and unk. Corpus tokens
        get the following contiguous ids; a corpus token equal to a special
        token keeps the special id.
    """
    frequencies = Counter(token for seq in seqs for token in seq)
    if not max_size:
        ranked = frequencies.items()
    else:
        ranked = frequencies.most_common(max_size)

    stoi = {
        SPECIAL_TOKENS[name]: idx
        for idx, name in enumerate(('pad', 'sos', 'eos', 'unk'))
    }
    for word, _count in ranked:
        # setdefault leaves existing (special) entries untouched.
        stoi.setdefault(word, len(stoi))
    return stoi


def encode(
    tokens: List[str],
    stoi: Dict[str, int],
    add_sos_eos: bool = False
) -> List[int]:
    """Map tokens to vocabulary ids, using the ``<unk>`` id for unknown tokens.

    Args:
        tokens: Token sequence to encode.
        stoi: Vocabulary mapping; must contain the ``<unk>`` entry.
        add_sos_eos: If True, wrap the result in the ``<sos>`` and ``<eos>`` ids.

    Returns:
        A new list of ids.
    """
    unk_token = SPECIAL_TOKENS['unk']
    ids = [stoi.get(tok, stoi[unk_token]) for tok in tokens]
    if not add_sos_eos:
        return ids
    return [stoi[SPECIAL_TOKENS['sos']], *ids, stoi[SPECIAL_TOKENS['eos']]]


class Example:
    """Plain record holding one encoded translation pair.

    It stores the encoder input ids, the decoder input ids and the decoder
    target ids exactly as given (no copies are made).
    """
    def __init__(self, src_ids: List[int], tgt_in_ids: List[int], tgt_out_ids: List[int]):
        self.src_ids, self.tgt_in_ids, self.tgt_out_ids = src_ids, tgt_in_ids, tgt_out_ids


class TranslationDataset(Dataset):
    """Map-style dataset of translation pairs, encoded once at construction time.

    For each pair it stores:
      * source ids: source token ids followed by EOS (no SOS);
      * decoder input: SOS followed by the target ids;
      * decoder output: the target ids followed by EOS (the decoder input
        shifted left by one position).
    """
    def __init__(self, pairs, src_stoi, tgt_stoi):
        self.examples: List[Example] = [
            self._encode_pair(src_tokens, tgt_tokens, src_stoi, tgt_stoi)
            for src_tokens, tgt_tokens in pairs
        ]

    @staticmethod
    def _encode_pair(src_tokens, tgt_tokens, src_stoi, tgt_stoi):
        """Encode one (source, target) token pair into an Example."""
        src_ids = [*encode(src_tokens, src_stoi), src_stoi[SPECIAL_TOKENS['eos']]]
        full_tgt = encode(tgt_tokens, tgt_stoi, add_sos_eos=True)
        return Example(src_ids, full_tgt[:-1], full_tgt[1:])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_pad(batch, pad_id_src: int, pad_id_tgt: int):
    """Right-pad a list of Examples into batch tensors.

    Args:
        batch: Objects exposing ``src_ids``, ``tgt_in_ids`` and ``tgt_out_ids`` lists.
        pad_id_src: Fill value for source rows.
        pad_id_tgt: Fill value for decoder input/output rows.

    Returns:
        ``(src, src_lens, tgt_in, tgt_out, tgt_lens)``. ``src`` is padded to the
        longest source. ``tgt_in`` and ``tgt_out`` are both padded to the longest
        decoder input. The length tensors hold the unpadded source and
        decoder-output lengths.
    """
    src_width = max(len(ex.src_ids) for ex in batch)
    tgt_width = max(len(ex.tgt_in_ids) for ex in batch)

    src_rows, tgt_in_rows, tgt_out_rows = [], [], []
    src_lengths, tgt_lengths = [], []
    for ex in batch:
        src_rows.append(ex.src_ids + [pad_id_src] * (src_width - len(ex.src_ids)))
        tgt_in_rows.append(ex.tgt_in_ids + [pad_id_tgt] * (tgt_width - len(ex.tgt_in_ids)))
        tgt_out_rows.append(ex.tgt_out_ids + [pad_id_tgt] * (tgt_width - len(ex.tgt_out_ids)))
        src_lengths.append(len(ex.src_ids))
        tgt_lengths.append(len(ex.tgt_out_ids))

    return (
        torch.tensor(src_rows),
        torch.tensor(src_lengths),
        torch.tensor(tgt_in_rows),
        torch.tensor(tgt_out_rows),
        torch.tensor(tgt_lengths),
    )


def compute_perplexity(loss_sum: float, token_count: int) -> float:
    """Return ``exp(loss_sum / token_count)``, the perplexity of a summed NLL.

    Returns ``inf`` when there are no tokens or when the exponential overflows.
    """
    if token_count == 0:
        return float('inf')
    try:
        ppl = math.exp(loss_sum / token_count)
    except OverflowError:
        return float('inf')
    return float(ppl)


def corpus_bleu(
    refs: List[List[str]],
    hyps: List[List[str]],
    max_order: int = 4,
    smooth: bool = True
) -> float:
    """Corpus-level BLEU over paired reference/hypothesis token lists.

    Clipped n-gram matches and hypothesis n-gram totals for orders
    1..``max_order`` are pooled over the whole corpus before precisions are
    formed. Pairs are formed with ``zip``, so extra items in the longer list
    are ignored.

    Args:
        refs: One reference token list per sentence.
        hyps: One hypothesis token list per sentence.
        max_order: Highest n-gram order.
        smooth: If True, add one to the numerator and denominator of every
            order's precision, unigrams included. If False, an order with no
            hypothesis n-grams gets precision 0.

    Returns:
        The geometric mean of the precisions times the brevity penalty, in [0, 1].
    """
    def ngram_counts(tokens, n):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    matched = [0] * max_order
    total = [0] * max_order
    ref_len = hyp_len = 0

    for ref, hyp in zip(refs, hyps):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for order in range(1, max_order + 1):
            hyp_counts = ngram_counts(hyp, order)
            # Each reference n-gram earns at most as many matches as the hypothesis contains.
            matched[order - 1] += sum(
                min(ref_count, hyp_counts[gram])
                for gram, ref_count in ngram_counts(ref, order).items()
            )
            total[order - 1] += max(len(hyp) - order + 1, 0)

    counts = list(zip(matched, total))
    if smooth:
        precisions = [(m + 1) / (t + 1) for m, t in counts]
    else:
        precisions = [m / t if t > 0 else 0.0 for m, t in counts]

    geo_mean = (
        math.exp(sum((1 / max_order) * math.log(p) for p in precisions))
        if min(precisions) > 0
        else 0.0
    )

    # Brevity penalty; max(1, ...) avoids dividing by zero when every hypothesis is empty.
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / max(1, hyp_len))
    return float(geo_mean * bp)


@torch.no_grad()
def evaluate_nll(
    loader: DataLoader,
    model: nn.Module,
    pad_id_tgt: int,
    device: torch.device
):
    """Sum the teacher-forced cross-entropy over all non-pad target tokens.

    The model is switched to eval mode (and left there).

    Args:
        loader: Yields ``(src, src_lens, tgt_in, tgt_out, tgt_lens)`` batches.
        model: Called as ``model(src, src_lens, tgt_in)``; its logits are
            flattened to ``(-1, vocab)`` to line up with ``tgt_out.reshape(-1)``.
        pad_id_tgt: Target padding id, excluded from the loss and the token count.
        device: Device the input tensors are moved to.

    Returns:
        ``(total_loss, total_tokens)``, ready for ``compute_perplexity``.
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id_tgt, reduction='sum')
    model.eval()
    nll_sum, n_tokens = 0.0, 0

    for src, src_lens, tgt_in, tgt_out, _tgt_lens in loader:
        src, src_lens, tgt_in, tgt_out = (
            t.to(device) for t in (src, src_lens, tgt_in, tgt_out)
        )
        logits = model(src, src_lens, tgt_in)
        vocab = logits.size(-1)
        batch_nll = loss_fn(logits.reshape(-1, vocab), tgt_out.reshape(-1))
        nll_sum += float(batch_nll.item())
        n_tokens += int(tgt_out.ne(pad_id_tgt).sum().item())

    return nll_sum, n_tokens


@torch.no_grad()
def evaluate_bleu(
    loader: DataLoader,
    model: nn.Module,
    tgt_itos: List[str],
    sos_id: int,
    eos_id: int,
    device: torch.device,
    max_len: int = 100
):
    """Greedy-decode every batch of ``loader`` and score it with ``corpus_bleu``.

    References come from ``tgt_out``. References and hypotheses are both cut
    at their first ``eos_id``, and id 0 (the padding id assigned by
    ``build_vocab``) is dropped from both. ``sos_id`` is also dropped from
    hypotheses. The model is switched to eval mode and left there.

    Returns:
        Corpus BLEU as returned by ``corpus_bleu``.
    """
    model.eval()
    references: List[List[str]] = []
    hypotheses: List[List[str]] = []

    def before_eos(ids):
        # Ids up to (not including) the first EOS, or all ids if there is none.
        return ids[:ids.index(eos_id)] if eos_id in ids else ids

    for src, src_lens, _tgt_in, tgt_out, _tgt_lens in loader:
        src = src.to(device)
        decoded = model.greedy_decode(
            src, src_lens.to(device),
            max_len=max_len, sos_id=sos_id, eos_id=eos_id,
        )
        for row in range(src.size(0)):
            ref_ids = before_eos(tgt_out[row].tolist())
            hyp_ids = before_eos(decoded[row].tolist())
            references.append([tgt_itos[i] for i in ref_ids if i != 0])
            hypotheses.append([tgt_itos[i] for i in hyp_ids if i not in (0, sos_id)])

    return float(corpus_bleu(references, hypotheses))

## 2. Paths and Hyperparameters

In [None]:
set_seed(42)

# Data paths, relative to the working directory
train_path = 'dataset_splits/train.txt'
val_path = 'dataset_splits/val.txt'
public_test_path = 'dataset_splits/public_test.txt'

# Use the alternative public-test file name only when the default file is
# missing and the alternative exists.
if not os.path.exists(public_test_path):
    alt = 'dataset_splits/test_public.txt'
    if os.path.exists(alt):
        public_test_path = alt

private_test_path = 'dataset_splits/private_test.txt'

# Most-frequent corpus tokens kept per vocabulary (special tokens are added on top)
src_vocab_size = tgt_vocab_size = 30000

# Model hyperparameters (shared by encoder and decoder)
emb_dim, hid_dim = 256, 512
layers = 1
dropout = 0.1

# Training hyperparameters
batch_size = 64
epochs = 5
lr = 3e-4
max_decode_len = 100  # step limit for greedy decoding

# Checkpoint directory
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

print('Public test path:', public_test_path)

## 3. Load Data and Build Vocab

In [None]:
print('Loading splits...')
train_pairs, val_pairs, test_pairs = (
    read_split(path) for path in (train_path, val_path, public_test_path)
)

print(f'Train: {len(train_pairs):,} | Val: {len(val_pairs):,} | Public test: {len(test_pairs):,}')

# Vocabularies are built from the training split only.
src_stoi = build_vocab((src for src, _ in train_pairs), max_size=src_vocab_size)
tgt_stoi = build_vocab((tgt for _, tgt in train_pairs), max_size=tgt_vocab_size)

# Special token IDs
pad_id_src = src_stoi[SPECIAL_TOKENS['pad']]
pad_id_tgt, sos_id, eos_id = (
    tgt_stoi[SPECIAL_TOKENS[key]] for key in ('pad', 'sos', 'eos')
)

# Datasets (every split is encoded with the training vocabularies)
train_ds, val_ds, test_ds = (
    TranslationDataset(split, src_stoi, tgt_stoi)
    for split in (train_pairs, val_pairs, test_pairs)
)


def collate(batch):
    """Pad a list of Examples with the current source/target pad ids."""
    return collate_pad(batch, pad_id_src, pad_id_tgt)


train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=0)

# Inverse target vocabulary. build_vocab hands out contiguous ids from 0,
# so sorting by id gives a list whose position equals the id.
tgt_itos = [word for word, _ in sorted(tgt_stoi.items(), key=lambda item: item[1])]

print('Vocab sizes — src:', len(src_stoi), 'tgt:', len(tgt_stoi))

## 4. Build Model (Your Playground)
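Keep the model interface so the evaluation code keeps working. `forward(src, src_lens, tgt_in)` must return logits of shape `(batch, tgt_len, vocab)`. `greedy_decode(src, src_lens, max_len, sos_id, eos_id)` must return a `(batch, length)` tensor of target ids. Within that contract, try adding attention, a GRU, a Transformer, etc.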

In [None]:
class Encoder(nn.Module):
    """Embed source ids and run them through a (possibly stacked) LSTM.

    Padding positions are skipped by packing the batch with ``src_lens``.
    """
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=1, dropout=0.1):
        super().__init__()
        # Attribute names are state_dict keys (embedding.*, rnn.*); renaming
        # them would break saved checkpoints. Row 0 is the padding embedding.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(
            emb_dim, hid_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=0.0 if num_layers <= 1 else dropout,
        )

    def forward(self, src, src_lens):
        """Return ``(outputs, (h_n, c_n))``.

        ``outputs`` is re-padded to the longest length in ``src_lens``, with
        zeros at padded positions. ``src_lens`` is moved to the CPU because
        ``pack_padded_sequence`` needs CPU lengths.
        """
        packed_in = nn.utils.rnn.pack_padded_sequence(
            self.embedding(src),
            src_lens.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_out, state = self.rnn(packed_in)
        outputs, _lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return outputs, state


class Decoder(nn.Module):
    """LSTM over target ids, conditioned through its initial state, with a vocabulary head.

    ``forward`` embeds ``tgt_in``, runs the LSTM from ``hidden`` and projects
    every output position to vocabulary logits.
    """
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=1, dropout=0.1):
        super().__init__()
        # Attribute names are state_dict keys (embedding.*, rnn.*, projection.*);
        # renaming them would break saved checkpoints.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(
            emb_dim, hid_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=0.0 if num_layers <= 1 else dropout,
        )
        self.projection = nn.Linear(hid_dim, vocab_size)

    def forward(self, tgt_in, hidden):
        """Return ``(logits, new_hidden)`` with one logit row per input position.

        ``hidden`` is passed straight to the LSTM (e.g. the encoder's final
        ``(h, c)``). Feeding one token at a time and threading the returned
        state through gives the same logits as a full-sequence call.
        """
        rnn_out, new_hidden = self.rnn(self.embedding(tgt_in), hidden)
        return self.projection(rnn_out), new_hidden


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper exposing the contract the evaluation code relies on.

    ``forward(src, src_lens, tgt_in)`` returns teacher-forced decoder logits.
    ``greedy_decode(src, src_lens, max_len, sos_id, eos_id)`` returns a
    ``(batch, max_len)`` tensor of predicted target ids.
    """
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_lens, tgt_in):
        """Teacher-forcing pass: start the decoder from the encoder's final state."""
        _, hidden = self.encoder(src, src_lens)
        logits, _ = self.decoder(tgt_in, hidden)
        return logits

    @torch.no_grad()
    def greedy_decode(self, src, src_lens, max_len, sos_id, eos_id):
        """Decode greedily (argmax at each step), starting from ``sos_id``.

        Args:
            src, src_lens: Source batch, as passed to the encoder.
            max_len: Maximum number of generated tokens (width of the result).
            sos_id: Id fed as the first decoder input.
            eos_id: End-of-sequence id.

        Returns:
            A ``(batch, max_len)`` long tensor on ``src.device``. Every position
            at or after a row's first ``eos_id`` holds ``eos_id``. When
            ``max_len <= 0`` the result has zero columns.
        """
        batch_size = src.size(0)
        _, hidden = self.encoder(src, src_lens)

        # Only the previous token is fed each step; the recurrent state carries
        # the history, so the growing prefix never needs to be re-concatenated.
        step_input = torch.full(
            (batch_size, 1), sos_id, dtype=torch.long, device=src.device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
        steps = []

        for _ in range(max_len):
            logits, hidden = self.decoder(step_input, hidden)
            next_token = logits[:, -1, :].argmax(-1, keepdim=True)
            steps.append(next_token)
            step_input = next_token
            finished |= next_token.squeeze(1) == eos_id
            # Once every row has produced EOS, later tokens would be
            # overwritten with EOS anyway, so stop decoding.
            if bool(finished.all()):
                break

        if steps:
            sequences = torch.cat(steps, dim=1)
        else:
            sequences = torch.empty(
                (batch_size, 0), dtype=torch.long, device=src.device
            )

        # Restore the full (batch, max_len) width after an early stop.
        missing = max_len - sequences.size(1)
        if missing > 0:
            sequences = torch.cat(
                [sequences, sequences.new_full((batch_size, missing), eos_id)], dim=1
            )

        # Vectorized truncation: every position at/after the first EOS becomes EOS.
        seen_eos = (sequences == eos_id).cumsum(dim=1) > 0
        return sequences.masked_fill(seen_eos, eos_id)


# Instantiate model. Encoder and decoder share hid_dim and layers so the
# encoder's final LSTM state can initialise the decoder directly.
encoder = Encoder(len(src_stoi), emb_dim, hid_dim, num_layers=layers, dropout=dropout)
decoder = Decoder(len(tgt_stoi), emb_dim, hid_dim, num_layers=layers, dropout=dropout)
model = Seq2Seq(encoder, decoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Count trainable parameters
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Model parameters: {num_params:,}')

## 5. Train

In [None]:
# Summed loss (padding ignored) so epoch statistics are exact per-token totals.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id_tgt, reduction='sum')

for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    total_tokens = 0

    for src, src_lens, tgt_in, tgt_out, tgt_lens in train_loader:
        # Move to device
        src = src.to(device)
        src_lens = src_lens.to(device)
        tgt_in = tgt_in.to(device)
        tgt_out = tgt_out.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(src, src_lens, tgt_in)

        # Summed NLL over the non-pad target tokens of this batch
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_out.reshape(-1)
        )
        batch_tokens = int((tgt_out != pad_id_tgt).sum().item())

        # Backpropagate the per-token mean. Differentiating the raw sum scales
        # the gradient with the number of target tokens in the batch, which
        # makes the fixed clipping threshold below depend on batch size (and
        # likely clip almost every step).
        (loss / max(batch_tokens, 1)).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Accumulate statistics (summed loss, token count)
        total_loss += float(loss.item())
        total_tokens += batch_tokens

    # Compute perplexities
    train_ppl = compute_perplexity(total_loss, total_tokens)

    val_loss, val_tokens = evaluate_nll(val_loader, model, pad_id_tgt, device)
    val_ppl = compute_perplexity(val_loss, val_tokens)

    print(f'Epoch {epoch:02d} | train ppl: {train_ppl:.2f} | val ppl: {val_ppl:.2f}')

# Save checkpoint (weights of the last epoch, plus vocabularies and config)
checkpoint = {
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epochs,
    'src_stoi': src_stoi,
    'tgt_stoi': tgt_stoi,
    'model_cfg': {
        'emb': emb_dim,
        'hid': hid_dim,
        'layers': layers,
        'dropout': dropout
    }
}

checkpoint_path = os.path.join(save_dir, 'checkpoint_last.pt')
torch.save(checkpoint, checkpoint_path)
print('Saved checkpoint:', checkpoint_path)

## 6. Evaluate: Perplexity and BLEU (Public Test)

In [None]:
# Validation set
val_loss, val_tokens = evaluate_nll(val_loader, model, pad_id_tgt, device)
val_ppl = compute_perplexity(val_loss, val_tokens)

# Public test set
test_loss, test_tokens = evaluate_nll(test_loader, model, pad_id_tgt, device)
test_ppl = compute_perplexity(test_loss, test_tokens)

# BLEU score
bleu = evaluate_bleu(
    test_loader,
    model,
    tgt_itos,
    sos_id=sos_id,
    eos_id=eos_id,
    device=device,
    max_len=max_decode_len
)

print(f'Validation perplexity: {val_ppl:.2f}')
print(f'Public test perplexity: {test_ppl:.2f}')
print(f'Public test BLEU:       {bleu*100:.2f}')

## 7. Private Test (Optional)

In [None]:
if not os.path.exists(private_test_path):
    # The private split is optional; report and skip.
    print('Private test split not found at', private_test_path)
else:
    prv_pairs = read_split(private_test_path)
    prv_ds = TranslationDataset(prv_pairs, src_stoi, tgt_stoi)
    prv_loader = DataLoader(
        prv_ds, batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=0
    )

    prv_loss, prv_tokens = evaluate_nll(prv_loader, model, pad_id_tgt, device)
    prv_ppl = compute_perplexity(prv_loss, prv_tokens)

    prv_bleu = evaluate_bleu(
        prv_loader, model, tgt_itos,
        sos_id=sos_id, eos_id=eos_id, device=device, max_len=max_decode_len,
    )

    print(f'Private test perplexity: {prv_ppl:.2f}')
    print(f'Private test BLEU:       {prv_bleu*100:.2f}')

## 8. Export Predictions for ML‑Arena (Optional)

In [None]:
@torch.no_grad()
def decode_to_lines(
    loader: DataLoader,
    model: nn.Module,
    tgt_itos: List[str],
    sos_id: int,
    eos_id: int,
    device: torch.device,
    max_len: int
) -> List[str]:
    """Greedy-decode every batch in ``loader`` into space-joined target strings.

    Args:
        loader: Yields ``(src, src_lens, ...)`` batches; only the first two
            items are used.
        model: Exposes ``greedy_decode(src, src_lens, max_len=..., sos_id=..., eos_id=...)``.
        tgt_itos: Target id -> token lookup.
        sos_id: Start-of-sequence id (dropped from the output).
        eos_id: End-of-sequence id (each hypothesis is cut at its first occurrence).
        device: Device the source tensors are moved to.
        max_len: Maximum number of decoding steps.

    Returns:
        One string per example, in loader order. Id 0 (padding) and ``sos_id``
        are removed before joining.
    """
    # Decode in inference mode (dropout off), as evaluate_bleu does. Otherwise
    # the exported text would depend on the mode the model was last left in.
    model.eval()
    lines: List[str] = []

    for src, src_lens, *_ in loader:
        predictions = model.greedy_decode(
            src.to(device), src_lens.to(device),
            max_len=max_len, sos_id=sos_id, eos_id=eos_id,
        )
        for hyp_ids in predictions.tolist():
            if eos_id in hyp_ids:
                hyp_ids = hyp_ids[:hyp_ids.index(eos_id)]
            tokens = [tgt_itos[i] for i in hyp_ids if i != 0 and i != sos_id]
            lines.append(' '.join(tokens))

    return lines


# Configuration
export_split = 'private'  # 'public' or 'private'
export_format = 'tsv'     # 'tsv' or 'jsonl'
# The output file name follows the chosen split and format
# (the defaults give 'submissions/private_predictions.tsv').
export_out = f'submissions/{export_split}_predictions.{export_format}'

# Fail fast on typos. Without these checks an unknown split silently exports
# the private split, and an unknown format silently writes nothing.
if export_split not in ('public', 'private'):
    raise ValueError(f"export_split must be 'public' or 'private', got {export_split!r}")
if export_format not in ('tsv', 'jsonl'):
    raise ValueError(f"export_format must be 'tsv' or 'jsonl', got {export_format!r}")

export_path = public_test_path if export_split == 'public' else private_test_path

if not os.path.exists(export_path):
    # As in the private-test cell, a missing optional split is reported, not fatal.
    print('Export split not found at', export_path)
else:
    # Create output directory
    os.makedirs(os.path.dirname(export_out) or '.', exist_ok=True)

    pairs = read_split(export_path)
    exp_ds = TranslationDataset(pairs, src_stoi, tgt_stoi)
    exp_loader = DataLoader(
        exp_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate,
        num_workers=0
    )

    # Generate predictions
    predictions = decode_to_lines(
        exp_loader,
        model,
        tgt_itos,
        sos_id=sos_id,
        eos_id=eos_id,
        device=device,
        max_len=max_decode_len
    )

    # Export predictions: one record per example, id = position in the split file
    with open(export_out, 'w', encoding='utf-8') as f:
        if export_format == 'tsv':
            for i, hypothesis in enumerate(predictions):
                f.write(f'{i}\t{hypothesis}\n')
        else:
            import json
            for i, hypothesis in enumerate(predictions):
                json_obj = {'id': i, 'hyp': hypothesis}
                f.write(json.dumps(json_obj, ensure_ascii=False) + '\n')

    print(f'Wrote {len(predictions)} predictions to {export_out}')
    print('Adjust if ML‑Arena requires a different schema.')