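# HW2: Neural Machine Translation (NMT) Evaluation
Load the trained checkpoint, compute automatic metrics (perplexity, corpus BLEU, and ROUGE-L F1) on the validation and public test splits, and inspect sample translations.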

In [1]:
# Architecture variant to evaluate. Options: 0 (LSTM encoder-decoder, implemented
# below), 1 and 2 (placeholder classes only in this notebook).
MODEL_NR = 0
print(f'You chose model {MODEL_NR} - LOL')

You chose model 0 - LOL


In [19]:
from __future__ import annotations
import math
import os
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

def set_seed(seed: int = 42) -> None:
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device).

    Args:
        seed: value handed to each random number generator.
    """
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Surface strings of the reserved vocabulary entries, keyed by role. The notebook
# indexes the checkpoint vocabularies with these strings, so they must be present there.
SPECIAL_TOKENS: Dict[str, str] = dict(
    pad='<pad>',
    sos='<sos>',
    eos='<eos>',
    unk='<unk>',
)

def simple_tokenize(text: str) -> List[str]:
    """Lowercase ``text`` and split it on runs of whitespace.

    Punctuation is not separated from words, e.g. ``"voice."`` stays one token.
    """
    stripped = text.strip()
    return stripped.lower().split()

def read_split(path: Path) -> List[Tuple[List[str], List[str]]]:
    """Read a tab-separated parallel corpus into tokenized (source, target) pairs.

    Each line is split on tabs: field 0 is the source sentence, field 1 the target;
    further fields are ignored and lines with fewer than two fields are skipped.
    Both sides are tokenized with ``simple_tokenize``.
    """
    result: List[Tuple[List[str], List[str]]] = []
    with path.open('r', encoding='utf-8') as handle:
        for raw in handle:
            fields = raw.rstrip('\n').split('\t')
            if len(fields) >= 2:
                src_text, tgt_text = fields[0], fields[1]
                result.append((simple_tokenize(src_text), simple_tokenize(tgt_text)))
    return result

def encode(tokens: List[str], stoi: Dict[str, int], add_sos_eos: bool = False) -> List[int]:
    """Map tokens to vocabulary ids, using the ``<unk>`` id for out-of-vocabulary tokens.

    Args:
        tokens: tokenized sentence.
        stoi: token -> id mapping; must contain ``<unk>`` (and ``<sos>``/``<eos>`` when
            ``add_sos_eos`` is set).
        add_sos_eos: when true, wrap the ids as ``<sos> ... <eos>``.
    """
    unk = SPECIAL_TOKENS['unk']
    ids = [stoi.get(tok, stoi[unk]) for tok in tokens]
    if not add_sos_eos:
        return ids
    return [stoi[SPECIAL_TOKENS['sos']], *ids, stoi[SPECIAL_TOKENS['eos']]]

class Example:
    """One encoded sentence pair.

    Attributes:
        src_ids: source token ids.
        tgt_in_ids: decoder input ids (target sequence without its last id).
        tgt_out_ids: decoder target ids (target sequence without its first id).
    """

    def __init__(self, src_ids: List[int], tgt_in_ids: List[int], tgt_out_ids: List[int]):
        self.src_ids, self.tgt_in_ids, self.tgt_out_ids = src_ids, tgt_in_ids, tgt_out_ids

class TranslationDataset(Dataset):
    """Map-style dataset of pre-encoded :class:`Example` objects.

    Source sequences get a trailing ``<eos>`` id (no ``<sos>``). Target sequences are
    wrapped as ``<sos> ... <eos>`` and split into a decoder input (all ids but the last)
    and a decoder output (all ids but the first), so the output is the input shifted by one.
    """

    def __init__(self, pairs, src_stoi, tgt_stoi):
        self.examples: List[Example] = []
        for src_tokens, tgt_tokens in pairs:
            source = encode(src_tokens, src_stoi)
            source.append(src_stoi[SPECIAL_TOKENS['eos']])
            target = encode(tgt_tokens, tgt_stoi, add_sos_eos=True)
            self.examples.append(Example(source, target[:-1], target[1:]))

    def __len__(self) -> int:
        """Number of encoded pairs."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> Example:
        """Return the ``idx``-th encoded pair."""
        return self.examples[idx]

def collate_pad(batch, pad_id_src: int, pad_id_tgt: int):
    """Right-pad a list of examples into batch tensors.

    Returns:
        ``(src, src_l, tgt_in, tgt_out, tgt_l)``: ``src``/``tgt_in``/``tgt_out`` are 2-D
        tensors padded to the longest sequence in the batch (targets padded to the longest
        ``tgt_in_ids``); ``src_l`` and ``tgt_l`` hold the unpadded source and target-output lengths.
    """
    longest_src = max(len(ex.src_ids) for ex in batch)
    longest_tgt = max(len(ex.tgt_in_ids) for ex in batch)

    src_rows, in_rows, out_rows = [], [], []
    for ex in batch:
        src_rows.append(ex.src_ids + [pad_id_src] * (longest_src - len(ex.src_ids)))
        in_rows.append(ex.tgt_in_ids + [pad_id_tgt] * (longest_tgt - len(ex.tgt_in_ids)))
        out_rows.append(ex.tgt_out_ids + [pad_id_tgt] * (longest_tgt - len(ex.tgt_out_ids)))

    src_tensor = torch.tensor(src_rows)
    tgt_in_tensor = torch.tensor(in_rows)
    tgt_out_tensor = torch.tensor(out_rows)
    src_lengths = torch.tensor([len(ex.src_ids) for ex in batch])
    tgt_lengths = torch.tensor([len(ex.tgt_out_ids) for ex in batch])
    return src_tensor, src_lengths, tgt_in_tensor, tgt_out_tensor, tgt_lengths

# Fix the Python and PyTorch random seeds before any model or data work.
set_seed(seed=42)


Using device: cuda


In [20]:
# Work inside a dedicated DLNLP25W folder under the starting directory (POSIX only).
# Re-running this cell must not nest DLNLP25W/DLNLP25W: if a previous run already moved
# the kernel into WORK_DIR, keep the existing HOME_DIR/WORK_DIR instead of recomputing
# them from the new cwd. Otherwise the relative '../checkpoints' and '../data' lookups
# in later cells would point at the wrong place.
_already_inside_work_dir = (
    'WORK_DIR' in globals() and Path.cwd().resolve() == Path(globals()['WORK_DIR']).resolve()
)
if not _already_inside_work_dir:
    HOME_DIR = Path(os.getcwd())
    WORK_DIR = HOME_DIR / 'DLNLP25W'
    if os.name == 'posix':
        # exist_ok=True: creating an existing folder is a no-op.
        os.makedirs(WORK_DIR, exist_ok=True)
        os.chdir(WORK_DIR)
print(os.listdir())

[]


In [None]:
# HELPER FUNCS

@torch.no_grad()
def evaluate_nll(loader: DataLoader, model: nn.Module, pad_id_tgt: int, device: torch.device):
    """Sum the teacher-forced token cross-entropy of ``model`` over ``loader``.

    Positions where ``tgt_out == pad_id_tgt`` are excluded from both the loss and the count.

    Returns:
        ``(loss_sum, token_count)``: summed negative log-likelihood (natural log) and the
        number of non-pad target tokens. ``loss_sum / token_count`` is the mean NLL.
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id_tgt, reduction='sum')
    model.eval()
    loss_sum, n_tokens = 0.0, 0
    for src, src_l, tgt_in, tgt_out, _tgt_l in loader:
        targets = tgt_out.to(device)
        logits = model(src.to(device), src_l.to(device), tgt_in.to(device))
        vocab = logits.size(-1)
        batch_loss = loss_fn(logits.reshape(-1, vocab), targets.reshape(-1))
        loss_sum += float(batch_loss.item())
        n_tokens += int(targets.ne(pad_id_tgt).sum().item())
    return loss_sum, n_tokens

def compute_perplexity(loss_sum: float, token_count: int) -> float:
    """Turn a summed NLL into perplexity, ``exp(loss_sum / token_count)``.

    Returns ``inf`` when ``token_count`` is zero or the computation overflows a float.
    """
    if token_count == 0:
        return float('inf')
    try:
        ppl = math.exp(loss_sum / token_count)
    except OverflowError:
        return float('inf')
    return float(ppl)

def corpus_bleu(refs, hyps, max_order: int = 4, smooth: bool = True) -> float:
    """Corpus-level BLEU with a single reference per hypothesis.

    Clipped n-gram matches and hypothesis n-gram counts are accumulated over the whole
    corpus for n = 1..max_order. The score is the geometric mean of the modified
    precisions times the brevity penalty, in [0, 1].

    Args:
        refs: iterable of reference token lists.
        hyps: iterable of hypothesis token lists aligned with ``refs`` (pairs are formed
            with ``zip``, so surplus items in the longer iterable are ignored).
        max_order: highest n-gram order.
        smooth: apply add-one smoothing to orders n >= 2 (Lin & Och, 2004), so a corpus
            with no higher-order matches still scores above zero. Unigram precision is
            never smoothed.

    Returns:
        BLEU as a float; 0.0 for an empty corpus or when no hypothesis tokens exist.
    """
    from collections import Counter

    def ngrams(sequence, n):
        # Multiset of the length-n windows of ``sequence``.
        return Counter(tuple(sequence[i:i + n]) for i in range(len(sequence) - n + 1))

    matches = [0] * max_order
    possible = [0] * max_order
    ref_len = 0
    hyp_len = 0

    for ref, hyp in zip(refs, hyps):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for n in range(1, max_order + 1):
            # Counter intersection keeps min(count_ref, count_hyp): clipped matches.
            matches[n - 1] += sum((ngrams(ref, n) & ngrams(hyp, n)).values())
            possible[n - 1] += max(len(hyp) - n + 1, 0)

    if hyp_len == 0:
        # Nothing was generated (or the corpus is empty). Without this guard, smoothing
        # turns every precision into 1 and the brevity penalty becomes exp(1 - ref_len),
        # which is e for an empty corpus.
        return 0.0

    precisions = []
    for order, (m, p) in enumerate(zip(matches, possible), start=1):
        if smooth and order > 1:
            precisions.append((m + 1) / (p + 1))
        else:
            precisions.append(m / p if p > 0 else 0.0)

    if min(precisions) <= 0:
        return 0.0
    geo_mean = math.exp(sum(math.log(val) for val in precisions) / max_order)

    # Brevity penalty: penalize hypotheses that are shorter overall than the references.
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return float(geo_mean * bp)

def lcs_length(x: List[str], y: List[str]) -> int:
    """Length of the longest common subsequence of ``x`` and ``y``.

    Dynamic programming in O(len(x) * len(y)) time, keeping a single row of
    O(len(y)) memory.
    """
    if not (x and y):
        return 0
    row = [0] * (len(y) + 1)
    for xi in x:
        diag = 0  # value of the previous row at column j - 1
        for j, yj in enumerate(y, start=1):
            above = row[j]
            row[j] = diag + 1 if xi == yj else max(above, row[j - 1])
            diag = above
    return row[-1]

def compute_rouge_l(refs, hyps) -> float:
    """Mean sentence-level ROUGE-L F1 over aligned reference/hypothesis pairs.

    For each pair, P = LCS / |hyp|, R = LCS / |ref| and F1 = 2PR / (P + R). Pairs
    where either side is empty score 0. Pairs are formed with ``zip``, so surplus
    items in the longer list are ignored.

    Returns:
        Mean F1 in [0, 1]; 0.0 when there is no pair to score.
    """
    scores = []
    for ref, hyp in zip(refs, hyps):
        if not ref or not hyp:
            scores.append(0.0)
            continue
        lcs = lcs_length(ref, hyp)
        prec = lcs / len(hyp)
        rec = lcs / len(ref)
        scores.append(0.0 if prec + rec == 0 else (2 * prec * rec) / (prec + rec))
    # Guard on the number of scored pairs rather than on ``refs`` alone: ``refs`` can be
    # non-empty while ``hyps`` is empty, which previously divided by zero.
    if not scores:
        return 0.0
    return float(sum(scores) / len(scores))

def ids_to_tokens(ids, itos, pad_id, sos_id=None, eos_id=None):
    """Convert an id sequence back into tokens.

    Pad ids are skipped; decoding stops at the first ``eos_id`` (when given and not
    equal to the pad id); ``sos_id`` (when given) is skipped; ids outside ``itos`` or
    mapping to ``None`` are dropped.
    """
    result = []
    for token_id in ids:
        is_pad = token_id == pad_id
        if not is_pad and eos_id is not None and token_id == eos_id:
            break
        if is_pad or (sos_id is not None and token_id == sos_id):
            continue
        if 0 <= token_id < len(itos) and itos[token_id] is not None:
            result.append(itos[token_id])
    return result

def build_itos(stoi: Dict[str, int]) -> List[str]:
    """Invert a token -> id mapping into an id-indexed list.

    The result has length ``max(id) + 1``; ids not used by any token hold ``None``
    (so entries are really ``Optional[str]``). An empty mapping yields ``[]``. If
    several tokens share an id, the one iterated last wins.
    """
    if not stoi:
        return []
    # Sized from the largest id, so every index assignment below is in range.
    itos = [None] * (max(stoi.values()) + 1)
    for token, idx in stoi.items():
        itos[idx] = token
    return itos

@torch.no_grad()
def gather_predictions(loader, model, src_itos, tgt_itos, pad_id_src, pad_id_tgt, sos_id, eos_id, src_eos_id, device, max_len=100, sample_cap=32):
    """Greedy-decode every batch from ``loader`` and collect token lists for scoring.

    Args:
        loader: yields ``(src, src_l, tgt_in, tgt_out, tgt_l)`` batches (the layout
            produced by ``collate_pad``); ``tgt_in`` and ``tgt_l`` are unused here.
        model: must provide ``greedy_decode(src, src_lens, max_len=, sos_id=, eos_id=)``.
        src_itos, tgt_itos: id -> token lists used for detokenization.
        pad_id_src, pad_id_tgt: padding ids skipped during detokenization.
        sos_id, eos_id: target start/end ids; target sequences are cut at ``eos_id``.
        src_eos_id: source end id; source sequences are cut there.
        device: device the model lives on.
        max_len: maximum number of decoding steps.
        sample_cap: how many leading examples (in loader order) to keep as samples.

    Returns:
        ``{'refs': [...], 'hyps': [...], 'samples': [...]}``. ``refs`` and ``hyps`` hold
        one token list per example in loader order; ``samples`` holds up to
        ``sample_cap`` dicts with space-joined ``'src'``, ``'ref'`` and ``'hyp'`` strings.
    """
    model.eval()
    refs, hyps, samples = [], [], []
    for src, src_l, _tgt_in, tgt_out, _tgt_l in loader:
        # Only the model inputs go to ``device``. Source ids and references are read
        # on the CPU, so they are not copied to the accelerator and back.
        preds = model.greedy_decode(
            src.to(device), src_l.to(device), max_len=max_len, sos_id=sos_id, eos_id=eos_id
        ).cpu()
        # Convert each batch to Python lists once instead of once per row.
        src_rows = src.cpu().tolist()
        src_lens = src_l.cpu().tolist()
        ref_rows = tgt_out.cpu().tolist()
        hyp_rows = preds.tolist()
        for b in range(len(src_rows)):
            src_tokens = ids_to_tokens(src_rows[b][:src_lens[b]], src_itos, pad_id_src, eos_id=src_eos_id)
            ref_tokens = ids_to_tokens(ref_rows[b], tgt_itos, pad_id_tgt, sos_id=sos_id, eos_id=eos_id)
            hyp_tokens = ids_to_tokens(hyp_rows[b], tgt_itos, pad_id_tgt, sos_id=sos_id, eos_id=eos_id)
            refs.append(ref_tokens)
            hyps.append(hyp_tokens)
            if len(samples) < sample_cap:
                samples.append({
                    'src': ' '.join(src_tokens),
                    'ref': ' '.join(ref_tokens),
                    'hyp': ' '.join(hyp_tokens)
                })
    return {
        'refs': refs,
        'hyps': hyps,
        'samples': samples
    }


In [29]:
if MODEL_NR == 0:
    class Encoder(nn.Module):
        """Embedding + LSTM encoder over right-padded, batch-first source id tensors."""

        def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=1, dropout=0.1):
            super().__init__()
            # NOTE(review): padding_idx is hard-coded to 0; presumably this matches the
            # <pad> id used at training time — TODO confirm against src_stoi.
            self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
            # nn.LSTM applies dropout only between stacked layers, so it is disabled
            # for a single layer.
            self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0.0)

        def forward(self, src, src_lens):
            """Encode ``src`` given the unpadded lengths ``src_lens``.

            Returns:
                ``(out, (h, c))``: per-step outputs re-padded to the longest length in
                the batch, and the final LSTM hidden/cell states.
            """
            emb = self.emb(src)
            # Packing skips padded positions; lengths must live on the CPU and need
            # not be sorted (enforce_sorted=False).
            packed = nn.utils.rnn.pack_padded_sequence(emb, src_lens.cpu(), batch_first=True, enforce_sorted=False)
            out, (h, c) = self.rnn(packed)
            out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            return out, (h, c)

    class Decoder(nn.Module):
        """Embedding + LSTM decoder that projects hidden states to target-vocabulary logits."""

        def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=1, dropout=0.1):
            super().__init__()
            # NOTE(review): padding_idx hard-coded to 0 — TODO confirm it equals the
            # <pad> id in tgt_stoi.
            self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
            self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0.0)
            self.proj = nn.Linear(hid_dim, vocab_size)

        def forward(self, tgt_in, hidden):
            """Run the decoder over ``tgt_in`` starting from LSTM state ``hidden``.

            Returns:
                ``(logits, hidden)``: vocabulary logits for every input position and the
                updated LSTM state.
            """
            emb = self.emb(tgt_in)
            out, hidden = self.rnn(emb, hidden)
            return self.proj(out), hidden

    class Seq2Seq(nn.Module):
        """Encoder-decoder without attention: the encoder's final state initialises the decoder."""

        def __init__(self, enc, dec):
            super().__init__()
            self.encoder = enc
            self.decoder = dec

        def forward(self, src, src_lens, tgt_in):
            """Teacher-forced pass: decoder logits for every position of ``tgt_in``."""
            _, h = self.encoder(src, src_lens)
            logits, _ = self.decoder(tgt_in, h)
            return logits

        @torch.no_grad()
        def greedy_decode(self, src, src_lens, max_len, sos_id, eos_id):
            """Greedy (argmax) decoding, feeding back one predicted token per step.

            Returns:
                LongTensor of shape ``(batch, max_len)`` on ``src.device``. Every position
                at or after a row's first ``eos_id`` holds ``eos_id``.
            """
            batch_size = src.size(0)
            _, state = self.encoder(src, src_lens)
            step_input = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=src.device)
            # Pre-filled with eos so that stopping early leaves the tail correctly filled.
            seqs = torch.full((batch_size, max_len), eos_id, dtype=torch.long, device=src.device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
            for t in range(max_len):
                # Only the previous token is fed back; the LSTM state carries the history.
                logits, state = self.decoder(step_input, state)
                step_input = logits[:, -1, :].argmax(-1, keepdim=True)
                tokens = step_input.squeeze(1)
                seqs[:, t] = tokens
                finished |= tokens == eos_id
                if bool(finished.all()):
                    # Rows are decoded independently, so once every row has emitted eos
                    # the remaining steps could only be overwritten with eos anyway.
                    break
            # Overwrite everything after each row's first eos with eos.
            return seqs.masked_fill(seqs.eq(eos_id).cumsum(dim=1) > 0, eos_id)
        
# MODEL 1: Using GRU, torch.float16 and ... (description left unfinished — TODO complete).
# NOTE(review): not implemented yet — the classes below are empty placeholders. build_model
# calls Encoder/Decoder with constructor arguments these stubs define no __init__ for, so
# selecting MODEL_NR = 1 presumably fails at model-build time until they are implemented.
if MODEL_NR == 1:
    class Encoder(nn.Module):
        """Placeholder for the MODEL 1 encoder (TODO implement)."""
        pass


    class Decoder(nn.Module):
        """Placeholder for the MODEL 1 decoder (TODO implement)."""
        pass


    class Seq2Seq(nn.Module):
        """Placeholder for the MODEL 1 encoder-decoder wrapper (TODO implement)."""
        pass

# MODEL 2: Using GRU, torch.float16, and student-teacher training based on fine-tuned BERT
# NOTE(review): not implemented yet — the classes below are empty placeholders. build_model
# calls Encoder/Decoder with constructor arguments these stubs define no __init__ for, so
# selecting MODEL_NR = 2 presumably fails at model-build time until they are implemented.
if MODEL_NR == 2:
    class Encoder(nn.Module):
        """Placeholder for the MODEL 2 encoder (TODO implement)."""
        pass


    class Decoder(nn.Module):
        """Placeholder for the MODEL 2 decoder (TODO implement)."""
        pass


    class Seq2Seq(nn.Module):
        """Placeholder for the MODEL 2 encoder-decoder wrapper (TODO implement)."""
        pass

In [30]:
def build_model(cfg: Dict, src_vocab_size: int, tgt_vocab_size: int) -> nn.Module:
    """Instantiate the Encoder/Decoder/Seq2Seq classes defined for the selected MODEL_NR.

    Recognised ``cfg`` keys (defaults in parentheses): ``'emb'`` (256), ``'hid'`` (512),
    ``'layers'`` (1), ``'dropout'`` (0.1). Encoder and decoder share these settings.
    """
    emb_dim, hid_dim = cfg.get('emb', 256), cfg.get('hid', 512)
    shared = {'num_layers': cfg.get('layers', 1), 'dropout': cfg.get('dropout', 0.1)}
    enc = Encoder(src_vocab_size, emb_dim, hid_dim, **shared)
    dec = Decoder(tgt_vocab_size, emb_dim, hid_dim, **shared)
    return Seq2Seq(enc, dec)


In [31]:
# Look for the selected model's last checkpoint in ./checkpoints, then up to two parent
# directories; the first existing path wins.
checkpoint_candidates = [
    Path(prefix) / 'checkpoints' / f'checkpoint_last_{MODEL_NR}.pt'
    for prefix in ('.', '..', '../..')
]
checkpoint_path = next(filter(Path.exists, checkpoint_candidates), None)
if checkpoint_path is None:
    raise FileNotFoundError(f'Could not locate checkpoint_last_{MODEL_NR}.pt in expected directories.')
print('Loading checkpoint from', checkpoint_path)


Loading checkpoint from ../checkpoints/checkpoint_last_0.pt


In [32]:
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_cfg = checkpoint.get('model_cfg', {})
src_stoi, tgt_stoi = dict(checkpoint['src_stoi']), dict(checkpoint['tgt_stoi'])
# Vocabulary size = largest id + 1 (tolerates gaps in the id range).
src_vocab_size, tgt_vocab_size = (max(vocab.values()) + 1 for vocab in (src_stoi, tgt_stoi))

model = build_model(model_cfg, src_vocab_size, tgt_vocab_size).to(device)
model.load_state_dict(checkpoint['model_state'])
model.eval()

print(f'Built model {MODEL_NR}')

# Special-token ids used for batching, detokenization and decoding.
pad_id_src, pad_id_tgt = src_stoi[SPECIAL_TOKENS['pad']], tgt_stoi[SPECIAL_TOKENS['pad']]
sos_id, eos_id = (tgt_stoi[SPECIAL_TOKENS[role]] for role in ('sos', 'eos'))
src_eos_id = src_stoi[SPECIAL_TOKENS['eos']]

src_itos, tgt_itos = build_itos(src_stoi), build_itos(tgt_stoi)
max_decode_len = checkpoint.get('max_decode_len', 100)

Built model 0


In [33]:
# Directories searched (in order, relative to the working directory) for the split files.
data_dir_candidates = [
    Path(rel) for rel in ('../data', '../../data', '../dataset_splits', '../../dataset_splits')
]

def resolve_split(filename: str) -> Path:
    """Return the resolved path of ``filename`` in the first data directory containing it.

    Raises:
        FileNotFoundError: if none of ``data_dir_candidates`` contains the file.
    """
    resolved = ((base / filename).resolve() for base in data_dir_candidates)
    found = next((candidate for candidate in resolved if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError(f'Could not find {filename} in expected data directories.')
    return found

split_files = dict(validation='val.txt', public_test='public_test.txt')

# One encoded dataset per split, built with the checkpoint's vocabularies.
datasets = {}
for split_name, filename in split_files.items():
    path = resolve_split(filename)
    pairs = read_split(path)
    datasets[split_name] = TranslationDataset(pairs, src_stoi, tgt_stoi)
    print(f'{split_name.title()} examples: {len(datasets[split_name])}')

batch_size = 64


def collate_fn(batch):
    """Pad a list of examples using the loaded vocabularies' source/target pad ids."""
    return collate_pad(batch, pad_id_src, pad_id_tgt)


# shuffle=False keeps examples in file order, so decoded samples are reproducible.
loaders = {
    name: DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
    for name, dataset in datasets.items()
}


Validation examples: 32428
Public_Test examples: 32428


In [34]:
# Per split: teacher-forced NLL -> perplexity, then greedy decoding -> corpus BLEU and
# ROUGE-L on the decoded token lists (both reported as percentages).
metrics_summary = {}
for split_name, loader in loaders.items():
    loss, tokens = evaluate_nll(loader, model, pad_id_tgt, device)
    ppl = compute_perplexity(loss, tokens)
    decoded = gather_predictions(
        loader, model, src_itos, tgt_itos,
        pad_id_src, pad_id_tgt, sos_id, eos_id, src_eos_id,
        device, max_len=max_decode_len, sample_cap=64,
    )
    bleu = 100 * corpus_bleu(decoded['refs'], decoded['hyps'])
    rouge = 100 * compute_rouge_l(decoded['refs'], decoded['hyps'])
    metrics_summary[split_name] = dict(
        perplexity=ppl,
        bleu=bleu,
        rouge_l=rouge,
        samples=decoded['samples'],
    )
    print(f"{split_name.title()} =>\t Perplexity: {ppl:.2f} | BLEU: {bleu:.2f} | ROUGE-L: {rouge:.2f}")


Validation =>	 Perplexity: 6.15 | BLEU: 25.19 | ROUGE-L: 53.60
Public_Test =>	 Perplexity: 6.11 | BLEU: 25.50 | ROUGE-L: 53.84


In [None]:
# Print a few decoded examples, preferring the validation split when it was evaluated.
# A single constant drives both the slice and the "fewer than" notice so they cannot drift apart.
NUM_EXAMPLES_TO_SHOW = 10
example_split = 'validation' if 'validation' in metrics_summary else next(iter(metrics_summary))
samples = metrics_summary[example_split]['samples'][:NUM_EXAMPLES_TO_SHOW]

print(
    f'Showing {len(samples)} example translations from the {example_split} split:'
)

for idx, sample in enumerate(samples, 1):
    print(f'[{idx}] SRC: {sample["src"]}')
    print(f'    REF: {sample["ref"]}')
    print(f'    HYP: {sample["hyp"]}')

if len(samples) < NUM_EXAMPLES_TO_SHOW:
    print(f'Fewer than {NUM_EXAMPLES_TO_SHOW} examples available in this split.')


Showing 10 example translations from the validation split:
[1] SRC: she turned around when she heard his voice.
    REF: sie drehte sich um, als sie seine stimme hörte.
    HYP: sie drehte sich um die <unk> als sie sich zu küssen.
[2] SRC: do you remember the first time we went to boston together?
    REF: erinnerst du dich an das erste mal, dass wir gemeinsam nach boston <unk>
    HYP: weißt du noch, wie wir das erstemal um uns noch nach boston gegangen ist?
[3] SRC: he is old.
    REF: er ist alt.
    HYP: er ist alt.
[4] SRC: i can still remember a few french words.
    REF: ein paar worte auf französisch kann ich noch.
    HYP: ich kann noch immer noch ein paar französische lieder singen.
[5] SRC: we helped him.
    REF: wir halfen ihm.
    HYP: wir haben ihm geholfen.
[6] SRC: you should've never come here.
    REF: sie hätten niemals hierherkommen sollen.
    HYP: du hättest nie hierherkommen sollen.
[7] SRC: did you buy what you wanted?
    REF: habt ihr das, was ihr <unk> gekau