# Seq2Seq Machine Translation (PyTorch)

This notebook implements a **Sequence-to-Sequence model with Attention** for machine translation using **PyTorch**.

**Architecture:**
- **Encoder**: Bidirectional LSTM
- **Decoder**: LSTM with Luong-style multiplicative ("general") attention
- **Dataset**: WMT14 (source/target sentences pre-tokenized into integer token IDs, with separate source and target vocabularies)

**Original code** was written in MindSpore — this version is fully re-implemented in PyTorch.

In [None]:
# ============================================================
# Cell 1: Imports & Device Setup
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pickle
import codecs
import os
from collections import defaultdict

# Run on CUDA when a GPU is visible, otherwise fall back to CPU.
# Later cells read this global (e.g. as the default `device=` argument).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Hyperparameters & Configuration

In [None]:
# ============================================================
# Cell 2: Hyperparameters
# ============================================================
# ----- Paths -----
DATA_PATH = "./WMT14/"            # Root folder for dataset (files are read from DATA_PATH/raw/)
SAVE_DIR  = "./model_WMT14/"      # Where to save checkpoints

# ----- Model -----
EMB_DIM   = 128                   # Token embedding size
N_HIDDEN  = 256                   # LSTM hidden size (per direction in the encoder)
N_LAYER   = 1                     # Encoder LSTM layers; the decoder is a single LSTMCell regardless
DROPOUT   = 0.0                   # Encoder inter-layer dropout; no effect when N_LAYER == 1

# ----- Training -----
BATCH_SIZE  = 32
EPOCHS      = 10
LR          = 1e-3                # Adam learning rate
CLIP        = 5.0                 # Max global gradient norm for clipping
PRINT_FREQ  = 100                 # Log the running loss every N steps
CKPT_FREQ   = 200                 # NOTE: passed to Trainer but not currently used by it
PATIENCE    = 5                   # Epochs without val-loss improvement before early stopping
USE_AMP     = True   # Mixed-precision training (ignored automatically when CUDA is unavailable)

os.makedirs(SAVE_DIR, exist_ok=True)

## 2. Load Vocabularies

In [None]:
# ============================================================
# Cell 3: Load Vocabularies
# ============================================================
# Vocab pickles are assumed to hold a token -> id mapping — TODO confirm format.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(DATA_PATH, 'raw', 'src_vocab.pkl'), "rb") as f:
    src_vocab = pickle.load(f)
with open(os.path.join(DATA_PATH, 'raw', 'tgt_vocab.pkl'), "rb") as f:
    tgt_vocab = pickle.load(f)

# Ensure <pad> is index 0.
# NOTE(review): this only overwrites the '<pad>' entry. If another token already
# has id 0, both tokens map to 0 and the reverse map below keeps just one of them.
src_vocab['<pad>'] = 0
tgt_vocab['<pad>'] = 0

# Build reverse target vocab (id -> token) for decoding
tgt_idx2word = {v: k for k, v in tgt_vocab.items()}

print(f"Source vocab size: {len(src_vocab)}")
print(f"Target vocab size: {len(tgt_vocab)}")

In [None]:
# ============================================================
# Cell 3b: Load Vocabularies from TEXT files  (alternative)
# ============================================================
# Use this instead of Cell 3 if your vocab files are plain text
# with one word per line (not .pkl).

def load_vocab(path):
    """Read a plain-text vocabulary (one token per line) into a ``{token: id}`` dict.

    Ids are assigned in first-occurrence order of non-blank lines (duplicates
    are skipped). ``<pad>``, ``<unk>``, ``<s>`` and ``</s>`` are appended when
    missing. ``<pad>`` is then forced to id 0 by swapping ids with whichever
    token previously held 0.
    """
    vocab = {}
    with open(path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            token = raw.strip()
            if not token or token in vocab:
                continue
            vocab[token] = len(vocab)

    for special in ('<pad>', '<unk>', '<s>', '</s>'):
        vocab.setdefault(special, len(vocab))

    pad_id = vocab['<pad>']
    if pad_id != 0:
        # Exchange ids with the current holder of id 0 so ids stay unique.
        holder = next(tok for tok, idx in vocab.items() if idx == 0)
        vocab[holder] = pad_id
        vocab['<pad>'] = 0
    return vocab

# Overwrites any vocabularies loaded by Cell 3 — run only one of the two cells.
src_vocab = load_vocab(os.path.join(DATA_PATH, 'raw', 'src_vocab.txt'))
tgt_vocab = load_vocab(os.path.join(DATA_PATH, 'raw', 'tgt_vocab.txt'))

# Build reverse target vocab (id -> token) for decoding
tgt_idx2word = {v: k for k, v in tgt_vocab.items()}

print(f"Source vocab size: {len(src_vocab)}")
print(f"Target vocab size: {len(tgt_vocab)}")
print(f"Sample src entries: {dict(list(src_vocab.items())[:5])}")
print(f"Sample tgt entries: {dict(list(tgt_vocab.items())[:5])}")

## 3. Dataset & DataLoader

In [None]:
# ============================================================
# Cell 4: Dataset class  (replaces data.py)
# ============================================================

def tensorize(sents_batch, word2id, device=device):
    """Right-pad integer-id sentences to a common length and return a LongTensor.

    Args:
        sents_batch: non-empty list of token-id lists.
        word2id:     vocabulary; its ``'<pad>'`` id fills the padding.
        device:      device of the returned tensor.
    Returns:
        ``(batch, max_len)`` long tensor, where ``max_len`` is the longest sentence.
    """
    max_len = max(len(s) for s in sents_batch)
    pad_id = word2id['<pad>']
    # Pad on the host and build the tensor with a single copy. Assigning
    # element by element into a CUDA tensor issues one tiny device copy per token.
    padded = [list(sent) + [pad_id] * (max_len - len(sent)) for sent in sents_batch]
    return torch.tensor(padded, dtype=torch.long, device=device)


def prepro_batch(src_word, tgt_word, batch, device=device):
    """Turn one raw ``(sources, targets)`` batch into padded model tensors.

    Returns ``((sources, src_lengths, decoder_input), decoder_target)``.
    ``decoder_input`` is each target prefixed with ``<s>`` and
    ``decoder_target`` is each target suffixed with ``</s>``. Every sequence
    tensor is padded by ``tensorize``.
    """
    sources, abstracts = batch
    lengths = torch.tensor(list(map(len, sources)), dtype=torch.long, device=device)

    decoder_in = [[tgt_word["<s>"]] + seq for seq in abstracts]
    decoder_out = [seq + [tgt_word["</s>"]] for seq in abstracts]

    src_tensor = tensorize(sources, src_word, device)
    dec_in_tensor = tensorize(decoder_in, tgt_word, device)
    dec_out_tensor = tensorize(decoder_out, tgt_word, device)

    return (src_tensor, lengths, dec_in_tensor), dec_out_tensor


class GigaDataset:
    """Sequential batch iterator over parallel integer-id text files.

    Reads ``<path>/raw/src_<split>.txt`` and ``<path>/raw/tgt_<split>.txt``.
    Each file has one sentence per line, as whitespace-separated integer token
    ids. Training data is sorted by source length so that a batch holds
    similarly sized sentences (mirrors the original MindSpore version).
    """

    def __init__(self, path, split, batch_size, src_word, tgt_word):
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        self.batch_size = batch_size
        self.src_word = src_word
        self.tgt_word = tgt_word

        source_data = self._read_split(path, f"src_{split}.txt")
        target_data = self._read_split(path, f"tgt_{split}.txt")
        # Check before sorting: zip() would silently truncate mismatched files.
        if len(source_data) != len(target_data):
            raise ValueError(
                f"{split}: {len(source_data)} source lines vs {len(target_data)} target lines"
            )

        # Sort training data by source length for efficient batching
        if split == 'train':
            pairs = sorted(zip(source_data, target_data), key=lambda p: len(p[0]))
            source_data = [s for s, _ in pairs]
            target_data = [t for _, t in pairs]

        self.src = source_data
        self.tgt = target_data

        self.cur_ind = 0  # index of the first sentence of the next batch
        self.tot_batch = len(self.src) // self.batch_size

    @staticmethod
    def _read_split(path, filename):
        """Read one split file into a list of token-string lists (one per line)."""
        with open(os.path.join(path, 'raw', filename), 'r', encoding='utf-8') as f:
            # split() (not split(' ')) tolerates repeated/trailing spaces.
            return [line.split() for line in f]

    def __len__(self):
        """Number of sentence pairs (not batches)."""
        return len(self.src)

    def _next_span(self):
        """Return the ``[start, end)`` rows of the next batch and advance the cursor."""
        n = len(self.src)
        # Wrap once a full batch no longer fits. Use '>' rather than '>=' so the
        # final full batch of an exactly divisible dataset is still served.
        if self.cur_ind + self.batch_size > n:
            self.cur_ind = 0
        start = self.cur_ind
        end = min(start + self.batch_size, n)
        self.cur_ind = end if end < n else 0
        return start, end

    def next_batch(self):
        """Return the next ``((src, src_lengths, tgt_in), tgt_out)`` batch via ``prepro_batch``.

        Batches are served in order. A trailing partial batch is skipped unless
        the whole dataset is smaller than one batch.
        """
        start, end = self._next_span()
        src = [[int(tok) for tok in self.src[i]] for i in range(start, end)]
        tgt = [[int(tok) for tok in self.tgt[i]] for i in range(start, end)]
        return prepro_batch(self.src_word, self.tgt_word, [src, tgt])

print("Dataset class ready.")

## 4. Model Definition — Seq2Seq with Attention (PyTorch)

The model consists of:
- **Encoder**: Bidirectional LSTM that reads the source sentence
- **Attention Decoder**: LSTM decoder with Luong-style multiplicative ("general") attention over encoder outputs

In [None]:
# ============================================================
# Cell 5: Attention Decoder  (replaces AttnDecoder in Seq2Seq.py)
# ============================================================

class AttnDecoder(nn.Module):
    """LSTM decoder with multiplicative ("general", Luong-style) attention.

    Each step does the following:
    1. The LSTM cell consumes the embedded input token.
    2. The new hidden state ``h_t`` scores every encoder output ``e_s`` as
       ``h_t . (W_a e_s)``.
    3. The attention-weighted context is combined as
       ``tanh(W_c [context; h_t])``.
    4. The result is projected to log-probabilities over the output vocabulary.

    Note: ``n_layers`` is stored but only one ``LSTMCell`` is built. ``dropout``
    is accepted for interface compatibility but is currently unused.
    """

    def __init__(self, embedding, hidden_size, output_size,
                 enc_out_dim, n_layers=1, dropout=0.1):
        super().__init__()
        self.embedding = embedding
        self.n_layers = n_layers
        emb_size = embedding.embedding_dim

        self.decoder_cell = nn.LSTMCell(emb_size, hidden_size)
        # Maps encoder outputs into the decoder hidden space for scoring.
        self.attn = nn.Linear(enc_out_dim, hidden_size, bias=False)
        self.concat = nn.Linear(enc_out_dim + hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, target, init_states, enc_outs, attn_mask):
        """
        Args:
            target:      (batch, tgt_len)  — decoder input token ids
            init_states: tuple (h0, c0) each (batch, hidden)
            enc_outs:    (src_len, batch, enc_out_dim); batch may be 1 and is
                         then broadcast over all decoder rows
            attn_mask:   (batch, src_len) or (1, src_len) — 1/True = valid, 0 = pad
        Returns:
            logits: (batch, tgt_len, vocab) log-probabilities
        """
        max_len = target.size(1)
        h, c = init_states
        logits = []
        for i in range(max_len):
            inp_i = target[:, i:i+1]                       # (batch, 1)
            logit, (h, c) = self._step(inp_i, (h, c), enc_outs, attn_mask)
            logits.append(logit)
        return torch.stack(logits, dim=1)                   # (batch, tgt_len, vocab)

    def _step(self, inp, last_hidden, enc_outs, attn_mask):
        """Single decoding step; returns ``(log_probs (batch, vocab), (h_t, c_t))``."""
        h_prev, c_prev = last_hidden
        embed = self.embedding(inp).squeeze(1)              # (batch, emb)
        h_t, c_t = self.decoder_cell(embed, (h_prev, c_prev))

        # --- attention ---
        attn_weights = self._get_attn(h_t, enc_outs, attn_mask)  # (batch, 1, src_len)
        # matmul (unlike bmm) broadcasts a batch-1 encoder memory across all rows,
        # e.g. one encoded source shared by several beam hypotheses.
        context = torch.matmul(attn_weights, enc_outs.transpose(0, 1))  # (batch, 1, enc_dim)
        context = context.squeeze(1)                             # (batch, enc_dim)

        concat_out = torch.tanh(self.concat(torch.cat([context, h_t], dim=1)))
        logit = F.log_softmax(self.out(concat_out), dim=-1)      # (batch, vocab)
        return logit, (h_t, c_t)

    def _get_attn(self, dec_out, enc_outs, attn_mask):
        """Return attention weights (batch, 1, src_len) that are 0 on padded positions."""
        # enc_outs: (src_len, batch, enc_dim)
        query = dec_out.unsqueeze(0)                             # (1, batch, hidden)
        keys = self.attn(enc_outs)                               # (src_len, batch, hidden)
        scores = (query * keys).sum(dim=2).transpose(0, 1)       # (batch, src_len)
        # Fill with the dtype's most negative finite value. A fixed -1e18 does not
        # fit in float16 (autocast), and masked_fill raises on the overflow.
        scores = scores.masked_fill(attn_mask == 0, torch.finfo(scores.dtype).min)
        return F.softmax(scores, dim=1).unsqueeze(1)             # (batch, 1, src_len)

print("AttnDecoder ready.")

In [None]:
# ============================================================
# Cell 6: Seq2SeqSum  (replaces Seq2Seq.py main class)
# ============================================================

def len_mask(lens, device=device):
    """Build a ``(batch, max_len)`` boolean mask that is True on real tokens.

    Args:
        lens:   non-empty sequence or 1-D tensor of lengths.
        device: device of the returned mask.
    Returns:
        Bool tensor with ``mask[i, j]`` True iff ``j < lens[i]``; ``max_len`` is
        the largest length.
    Raises:
        ValueError: if ``lens`` is empty.
    """
    if len(lens) == 0:
        raise ValueError("len_mask() needs at least one length")
    lens_t = torch.as_tensor(lens, device=device).long()
    max_len = int(lens_t.max())
    # One broadcast comparison instead of a per-row loop of slice writes
    # (each of which is a separate kernel launch on CUDA).
    return torch.arange(max_len, device=device).unsqueeze(0) < lens_t.unsqueeze(1)


class Seq2SeqSum(nn.Module):
    """Translation model: (bi)LSTM encoder + attentional LSTM decoder.

    Args:
        src_vocab_size:  number of source token ids (source embedding rows).
        tgt_vocab_size:  number of target token ids (target embedding rows and
                         decoder output size).
        emb_dim:         embedding size for both sides.
        n_hidden:        LSTM hidden size (per encoder direction; decoder size).
        n_layer:         number of encoder LSTM layers. The decoder is seeded
                         from the top layer's final state.
        bi_enc:          use a bidirectional encoder.
        dropout:         encoder inter-layer dropout (ignored when n_layer == 1).
        share_embedding: reuse the source embedding table for decoder inputs.
                         Only meaningful with a joint source/target vocabulary.
                         By default each side gets its own table, because source
                         and target ids index different vocabularies.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, emb_dim,
                 n_hidden, n_layer=1, bi_enc=True, dropout=0.0,
                 share_embedding=False):
        super().__init__()
        self.n_layer = n_layer
        self.bi_enc = bi_enc
        self.n_hidden = n_hidden

        # Source-side embedding (encoder input); id 0 is <pad>.
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=0)

        # Encoder
        self.encoder = nn.LSTM(
            emb_dim, n_hidden, n_layer,
            bidirectional=bi_enc,
            dropout=0.0 if n_layer == 1 else dropout,
        )
        num_dirs = 2 if bi_enc else 1
        self.enc_out_dim = n_hidden * num_dirs

        # Learnable initial encoder states
        self.enc_init_h = nn.Parameter(torch.empty(n_layer * num_dirs, n_hidden).uniform_(-1e-2, 1e-2))
        self.enc_init_c = nn.Parameter(torch.empty(n_layer * num_dirs, n_hidden).uniform_(-1e-2, 1e-2))

        # Project encoder final hidden → decoder initial hidden
        self._dec_h = nn.Linear(self.enc_out_dim, n_hidden, bias=False)
        self._dec_c = nn.Linear(self.enc_out_dim, n_hidden, bias=False)

        # Target-side embedding (decoder input). Target ids must not index the
        # source table unless the vocabulary is genuinely shared.
        if share_embedding:
            if tgt_vocab_size > src_vocab_size:
                raise ValueError(
                    "share_embedding=True needs tgt_vocab_size <= src_vocab_size "
                    f"(got {tgt_vocab_size} > {src_vocab_size})"
                )
            tgt_embedding = self.embedding
        else:
            tgt_embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=0)

        # Decoder
        self.decoder = AttnDecoder(
            tgt_embedding, n_hidden, tgt_vocab_size,
            self.enc_out_dim, n_layer, dropout=dropout,
        )

    def forward(self, src, src_lengths, tgt):
        """
        Args:
            src:         (batch, src_len) — source token ids
            src_lengths: (batch,)         — true lengths (before padding)
            tgt:         (batch, tgt_len) — decoder input token ids (with <s>)
        Returns:
            logits: (batch, tgt_len, tgt_vocab) log-probabilities
        """
        enc_outs, init_dec_states = self.encode(src, src_lengths)
        attn_mask = len_mask(src_lengths, src.device)
        logits = self.decoder(tgt, init_dec_states, enc_outs, attn_mask)
        return logits

    def encode(self, src, src_lengths):
        """Encode a right-padded source batch.

        Args:
            src:         (batch, src_len) source ids, padded with 0.
            src_lengths: (batch,) true lengths; every length must be >= 1.
        Returns:
            enc_out:        (src_len, batch, enc_out_dim); zero past each length.
            (dec_h, dec_c): each (batch, n_hidden), the initial decoder state.
        """
        batch_size = src.size(0)
        num_dirs = 2 if self.bi_enc else 1
        # Expand learnable init states → (layers*dirs, batch, hidden)
        h0 = self.enc_init_h.unsqueeze(1).expand(-1, batch_size, -1).contiguous()
        c0 = self.enc_init_c.unsqueeze(1).expand(-1, batch_size, -1).contiguous()

        embed = self.embedding(src).transpose(0, 1)  # (src_len, batch, emb)
        # Pack so each sentence stops at its true length. Without packing, the
        # forward final state absorbs trailing <pad> steps and the backward
        # direction starts from them.
        lengths = torch.as_tensor(src_lengths).to(device='cpu', dtype=torch.int64)
        packed = nn.utils.rnn.pack_padded_sequence(embed, lengths, enforce_sorted=False)
        packed_out, (h, c) = self.encoder(packed, (h0, c0))
        enc_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, total_length=src.size(1))

        # h, c: (n_layer * num_dirs, batch, hidden), ordered layer-major.
        # Take the top layer and concatenate its directions → (batch, enc_out_dim).
        h = torch.cat(h.reshape(self.n_layer, num_dirs, batch_size, -1)[-1].unbind(0), dim=-1)
        c = torch.cat(c.reshape(self.n_layer, num_dirs, batch_size, -1)[-1].unbind(0), dim=-1)

        # Project to decoder hidden size
        return enc_out, (self._dec_h(h), self._dec_c(c))

    # ------------------------------------------------------------------
    # Beam search (used at inference time)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def beam_decode(self, inp, src_vocab, tgt_vocab, beam_size=4):
        """Beam-search decode ONE source sentence (switches the model to eval mode).

        Args:
            inp:       non-empty list of integer source token ids.
            src_vocab: unused; kept for interface compatibility.
            tgt_vocab: mapping providing the ``"<s>"`` and ``"</s>"`` ids.
            beam_size: number of beams.
        Returns:
            List of ``(token_id_list, score)`` for hypotheses that emitted
            ``</s>`` within 32 steps, sorted by summed log-probability,
            descending. Each id list starts with ``<s>`` and ends with ``</s>``.
            The list is empty if no hypothesis finished.
        """
        self.eval()
        dev = self.embedding.weight.device
        inp_t = torch.tensor([inp], dtype=torch.long, device=dev)
        inp_len = torch.tensor([len(inp)], dtype=torch.long, device=dev)
        attn_mask = torch.ones_like(inp_t, dtype=torch.bool)

        SOS = tgt_vocab["<s>"]
        EOS = tgt_vocab["</s>"]
        k = 50  # stop once this many hypotheses have been completed

        enc_outs, (h, c) = self.encode(inp_t, inp_len)  # enc_outs: (src_len, 1, enc_dim)
        # Expand for beam
        h = h.expand(beam_size, -1).contiguous()       # (beam, hidden)
        c = c.expand(beam_size, -1).contiguous()

        top_k_scores = torch.zeros(beam_size, device=dev)
        top_k_words = torch.full((beam_size, 1), SOS, dtype=torch.long, device=dev)
        prev_words = top_k_words

        completed_seqs = []
        completed_scores = []

        for step in range(1, 33):  # max 32 decoding steps
            n_live = prev_words.size(0)
            # Replicate the single encoded source across live beams so the
            # decoder's batched attention sees matching batch sizes.
            logit, (h, c) = self.decoder._step(
                prev_words, (h, c),
                enc_outs.expand(-1, n_live, -1),
                attn_mask.expand(n_live, -1),
            )
            # _step already returns log-probabilities: (live, vocab)
            vocab_size = logit.size(1)
            log_probs = top_k_scores.unsqueeze(1) + logit

            # At step 1 all beams are identical; expand only the first one.
            flat = log_probs[0] if step == 1 else log_probs.view(-1)
            top_k_scores, top_k_ids = flat.topk(min(flat.numel(), beam_size))

            beam_idx = top_k_ids // vocab_size
            word_idx = top_k_ids % vocab_size

            top_k_words = torch.cat([top_k_words[beam_idx], word_idx.unsqueeze(1)], dim=1)

            # Separate complete / incomplete
            finished = word_idx == EOS
            if finished.any():
                completed_seqs.extend(top_k_words[finished].tolist())
                completed_scores.extend(top_k_scores[finished].tolist())
                k -= int(finished.sum())

            if k <= 0:
                break

            # Keep only incomplete beams
            alive = ~finished
            top_k_words = top_k_words[alive]
            top_k_scores = top_k_scores[alive]
            h = h[beam_idx[alive]]
            c = c[beam_idx[alive]]
            prev_words = top_k_words[:, -1:]

            if top_k_words.size(0) == 0:
                break

        # Sort by score
        return sorted(zip(completed_seqs, completed_scores),
                      key=lambda x: x[1], reverse=True)

print("Seq2SeqSum model ready.")

## 5. Trainer  (replaces training.py)

Handles the training loop, validation, gradient clipping, checkpointing, and early stopping — all using PyTorch idioms.

In [None]:
# ============================================================
# Cell 7: Trainer class  (replaces training.py)
# ============================================================

class Trainer:
    """Epoch-based training loop for the seq2seq model.

    Handles AMP, gradient clipping, validation, per-epoch checkpointing and
    patience-based early stopping. The loaders must expose ``tot_batch`` and
    ``next_batch()``, as ``GigaDataset`` does. Before each validation pass the
    validation loader's ``cur_ind`` cursor is rewound to 0.
    """

    def __init__(self, model, optimizer, train_loader, val_loader,
                 save_dir, clip, print_freq, ckpt_freq, patience, epochs,
                 use_amp=True):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.save_dir = save_dir
        self.clip = clip
        self.print_freq = print_freq
        self.ckpt_freq = ckpt_freq  # stored for compatibility; not used by this loop
        self.patience = patience
        self.epochs = epochs
        # AMP is only enabled when CUDA is actually available.
        self.use_amp = use_amp and torch.cuda.is_available()

        self.scaler = GradScaler(enabled=self.use_amp)
        self.step = 0            # optimizer steps within the current epoch
        self.cur_epoch = 1
        self.current_p = 0       # consecutive epochs without val improvement
        self.best_val = float('inf')

    # ---- loss -------------------------------------------------------
    @staticmethod
    def compute_loss(logits, targets, pad_idx=0):
        """Mean NLL over non-<pad> positions.

        ``logits`` are log-probabilities (batch, tgt_len, vocab); ``targets``
        is (batch, tgt_len).
        """
        mask = (targets != pad_idx).view(-1)
        logits_flat = logits.view(-1, logits.size(2))[mask]
        targets_flat = targets.view(-1)[mask]
        return F.nll_loss(logits_flat, targets_flat)

    # ---- single training step ---------------------------------------
    def train_step(self, srcs, targets):
        """One optimizer step (AMP-scaled, gradient-clipped); returns the loss."""
        src, src_lens, tgt = srcs
        self.optimizer.zero_grad()

        with autocast(enabled=self.use_amp):
            logits = self.model(src, src_lens, tgt)
            loss = self.compute_loss(logits, targets)

        self.scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient norm.
        self.scaler.unscale_(self.optimizer)
        nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.step += 1
        return loss.item()

    # ---- validation -------------------------------------------------
    @torch.no_grad()
    def validate(self):
        """Average loss over the first ``min(100, tot_batch)`` validation batches.

        Raises:
            ValueError: if the validation loader yields no full batch.
        """
        self.model.eval()
        val_batches = min(100, self.val_loader.tot_batch)
        if val_batches == 0:
            raise ValueError("Validation set is smaller than one batch; cannot validate.")
        # Rewind so every call scores the same batches. Otherwise each epoch sees
        # a different slice and early stopping compares incomparable losses.
        self.val_loader.cur_ind = 0
        total_loss = 0.0
        for _ in range(val_batches):
            srcs, targets = self.val_loader.next_batch()
            src, src_lens, tgt = srcs
            with autocast(enabled=self.use_amp):
                logits = self.model(src, src_lens, tgt)
                loss = self.compute_loss(logits, targets)
            total_loss += loss.item()
        avg = total_loss / val_batches
        print(f"  Epoch {self.cur_epoch}  |  Val loss: {avg:.4f}")
        return avg

    # ---- checkpoint -------------------------------------------------
    def checkpoint(self):
        """Save model/optimizer state as ``ckpt-{epoch}e-{step}s.pt`` in ``save_dir``."""
        name = f"ckpt-{self.cur_epoch}e-{self.step}s.pt"
        path = os.path.join(self.save_dir, name)
        torch.save({
            'epoch': self.cur_epoch,
            'step': self.step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val': self.best_val,
        }, path)
        print(f"  Checkpoint saved → {path}")

    # ---- early stopping check ---------------------------------------
    def check_stop(self, val_loss):
        """Track the best val loss (saving a checkpoint on improvement).

        Returns True once ``patience`` consecutive epochs have not improved.
        """
        if val_loss < self.best_val:
            self.best_val = val_loss
            self.checkpoint()
            self.current_p = 0
        else:
            self.current_p += 1
        return self.current_p >= self.patience

    # ---- logging ----------------------------------------------------
    def log_info(self, running_loss):
        """Print epoch progress and the mean loss of the last ``print_freq`` steps."""
        total_steps = self.train_loader.tot_batch
        pct = 100 * self.step / total_steps
        avg = running_loss / self.print_freq
        print(f"  Epoch {self.cur_epoch} | step {self.step}/{total_steps} "
              f"({pct:.1f}%) | loss {avg:.4f}")

    # ---- main training loop ----------------------------------------
    def train(self):
        """Run up to ``epochs`` epochs; validate, checkpoint and check early stopping after each."""
        for epoch in range(1, self.epochs + 1):
            self.cur_epoch = epoch
            self.model.train()
            self.step = 0
            running_loss = 0.0

            for _ in range(self.train_loader.tot_batch):
                srcs, targets = self.train_loader.next_batch()
                running_loss += self.train_step(srcs, targets)

                if self.step % self.print_freq == 0:
                    self.log_info(running_loss)
                    running_loss = 0.0

            val_loss = self.validate()
            # check_stop already saves when val loss improves. Save here only
            # otherwise, so each epoch is written exactly once with an
            # up-to-date best_val.
            should_stop = self.check_stop(val_loss)
            if self.current_p > 0:
                self.checkpoint()

            if should_stop:
                print("Early stopping — finished training!")
                return

        print("Reached max epochs — finished training!")

print("Trainer ready.")

## 6. Instantiate Model, Data & Optimizer

In [None]:
# ============================================================
# Cell 8: Build everything
# ============================================================

# Data loaders (custom sequential iterators — GigaDataset, not torch DataLoader)
train_loader = GigaDataset(DATA_PATH, 'train', BATCH_SIZE, src_vocab, tgt_vocab)
val_loader   = GigaDataset(DATA_PATH, 'val',   BATCH_SIZE, src_vocab, tgt_vocab)

print(f"Training batches : {train_loader.tot_batch}")
print(f"Validation batches: {val_loader.tot_batch}")

# Model (bidirectional encoder by default; moved to the global `device`)
model = Seq2SeqSum(
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    emb_dim=EMB_DIM,
    n_hidden=N_HIDDEN,
    n_layer=N_LAYER,
    dropout=DROPOUT,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters : {total_params:,}")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

## 7. Train!

In [None]:
# ============================================================
# Cell 9: Run training
# ============================================================

# Trainer.train() runs up to EPOCHS epochs. After each one it validates and
# writes a checkpoint to SAVE_DIR. It stops early after PATIENCE consecutive
# epochs without a validation-loss improvement.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    save_dir=SAVE_DIR,
    clip=CLIP,
    print_freq=PRINT_FREQ,
    ckpt_freq=CKPT_FREQ,
    patience=PATIENCE,
    epochs=EPOCHS,
    use_amp=USE_AMP,
)

trainer.train()

## 8. Beam Search Decoding  (replaces decode.py)

Load a trained checkpoint and run beam-search on the test set, writing predictions to a file.

In [None]:
# ============================================================
# Cell 10: Beam-search decoding  (replaces decode.py)
# ============================================================

# ---- Settings (adjust after training) ----
# NOTE: Trainer.checkpoint() names files "ckpt-{epoch}e-{step}s.pt". Trainer.train()
# resets `step` to 0 at the start of each epoch and saves after the epoch's
# batches, so per-epoch checkpoints carry that epoch's batch count, not 0.
# Point MODEL_CKPT at a file that actually exists in SAVE_DIR.
MODEL_CKPT  = os.path.join(SAVE_DIR, "ckpt-6e-0s.pt")   # path to best checkpoint
BEAM_SIZE   = 50
OUTPUT_DIR  = "./output/"
OUT_FILE    = os.path.join(OUTPUT_DIR, "WMT14_output.txt")

os.makedirs(OUTPUT_DIR, exist_ok=True)


def load_model(ckpt_path, src_vocab, tgt_vocab):
    """Rebuild a ``Seq2SeqSum`` from the notebook hyper-parameters and load saved weights.

    Reads ``ckpt['model_state_dict']`` from the file written by
    ``Trainer.checkpoint``. Returns the model on ``device``, in eval mode.
    """
    net = Seq2SeqSum(len(src_vocab), len(tgt_vocab), EMB_DIM, N_HIDDEN, N_LAYER)
    net = net.to(device)
    state = torch.load(ckpt_path, map_location=device)['model_state_dict']
    net.load_state_dict(state)
    net.eval()
    return net


def run_beam_search(model, test_src, test_tgt, tgt_idx2word,
                    tgt_vocab, fout, beam_size=50):
    """Beam-decode every test sentence and write prediction/reference lines to *fout*.

    Two lines are written per pair:
    - the Python ``repr`` of the predicted token list: the best hypothesis
      with its first and last tokens (``<s>``/``</s>``) removed, or ``[]``
      when beam search finished nothing;
    - the ``repr`` of the reference token list.

    Progress is printed for the first sentence and every 50th. Note: passes the
    notebook-global ``src_vocab`` through to ``beam_decode``.
    """
    def to_words(ids):
        return [tgt_idx2word.get(i, '<unk>') for i in ids]

    for n, (src_ids, tgt_ids) in enumerate(zip(test_src, test_tgt), start=1):
        hyps = model.beam_decode(src_ids, src_vocab, tgt_vocab, beam_size)
        pred = to_words(hyps[0][0])[1:-1] if hyps else []

        tgt_str = to_words(tgt_ids)
        fout.write(f"{pred}\n{tgt_str}\n")

        if n == 1 or n % 50 == 0:
            print(f"[{n}] pred: {pred}")
            print(f"     tgt : {tgt_str}\n")


# ---- Load test data ----
def load_test_file(path):
    """Parse a file of space-separated integer ids into a list of id lists.

    Reading stops at the first empty line; anything after it is ignored.
    """
    sentences = []
    with open(path, 'r') as handle:
        for raw in handle:
            stripped = raw.rstrip('\r\n')
            if stripped == '':
                break
            sentences.append(list(map(int, stripped.split(' '))))
    return sentences


# ---- Uncomment below to run decoding after training ----
# The output file is written as ISO-8859-1 to match the encoding evaluate() reads.
# Writing tokens that are not representable in Latin-1 raises UnicodeEncodeError.
# test_src = load_test_file(os.path.join(DATA_PATH, 'raw', 'src_test.txt'))
# test_tgt = load_test_file(os.path.join(DATA_PATH, 'raw', 'tgt_test.txt'))
# decode_model = load_model(MODEL_CKPT, src_vocab, tgt_vocab)
# with open(OUT_FILE, 'w', encoding='ISO-8859-1') as fout:
#     run_beam_search(decode_model, test_src, test_tgt, tgt_idx2word, tgt_vocab, fout, BEAM_SIZE)
# print(f"Decoding complete → {OUT_FILE}")

print("Decoding functions ready. Uncomment the block above after training.")

## 9. Evaluation — BLEU & ROUGE  (replaces eval.py)

Compute sentence-averaged BLEU (individual 1- to 4-gram scores) and ROUGE-1/2/L (recall, precision, F1) from the decode output file.

In [None]:
# ============================================================
# Cell 11: Evaluation  (replaces eval.py)
# ============================================================
# Requires: pip install nltk rouge
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge


def bleu_scores(pred, tgt):
    """Return individual 1- to 4-gram BLEU scores of *pred* against the single reference *tgt*.

    ``nltk.translate.bleu_score.sentence_bleu`` takes ``(references,
    hypothesis)``. The original call passed them the other way round, which
    scored the reference against the prediction and inverted the brevity
    penalty. Each weight tuple selects one n-gram order (not cumulative BLEU-n).
    """
    hypothesis = [str(x) for x in pred]
    reference = [str(x) for x in tgt]
    orders = ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))
    return tuple(sentence_bleu([reference], hypothesis, weights=w) for w in orders)


def rouge_scores(pred, tgt):
    """Score one prediction against one reference with the ``rouge`` package.

    Tokens are space-joined (prediction tokens are stringified first). Returns
    the single score dict holding ``'rouge-1'``, ``'rouge-2'`` and ``'rouge-l'``.
    """
    hypothesis = ' '.join(map(str, pred))
    reference = ' '.join(tgt)
    scores = Rouge().get_scores(hyps=hypothesis, refs=reference)
    return scores[0]


def evaluate(output_path, result_path):
    """Score a decode output file and write an aggregate report.

    The file must hold alternating lines: a Python literal of a predicted token
    list, then the literal of its reference token list, as written by
    ``run_beam_search``. BLEU and ROUGE are averaged over sentence pairs. An
    empty prediction scores 0 on every metric. The report is printed and
    written to *result_path*.

    Raises:
        ValueError: if the non-blank lines do not form complete (pred, tgt) pairs.
    """
    import ast  # literal_eval parses the list reprs without executing code

    with open(output_path, 'r', encoding='ISO-8859-1') as f:
        data = [line.rstrip('\r\n') for line in f if line.strip()]
    if len(data) % 2:
        raise ValueError(
            f"{output_path}: expected alternating prediction/target lines, "
            f"got an odd number ({len(data)}) of non-blank lines"
        )

    tot = 0
    bleu_agg = [0.0, 0.0, 0.0, 0.0]
    rouge_1, rouge_2, rouge_l = defaultdict(float), defaultdict(float), defaultdict(float)

    for pred_line, tgt_line in zip(data[0::2], data[1::2]):
        # Safe replacement for eval(): only Python literals are accepted.
        pred = ast.literal_eval(pred_line)
        tgt = ast.literal_eval(tgt_line)

        if not pred:
            b = (0, 0, 0, 0)
            r = {'rouge-1': {'f': 0, 'p': 0, 'r': 0},
                 'rouge-2': {'f': 0, 'p': 0, 'r': 0},
                 'rouge-l': {'f': 0, 'p': 0, 'r': 0}}
        else:
            b = bleu_scores(pred, tgt)
            r = rouge_scores(pred, tgt)

        for i in range(4):
            bleu_agg[i] += b[i]
        for k in ('f', 'p', 'r'):
            rouge_1[k] += r['rouge-1'][k]
            rouge_2[k] += r['rouge-2'][k]
            rouge_l[k] += r['rouge-l'][k]
        tot += 1

    # Avoid ZeroDivisionError on an empty file; every average is then 0.
    denom = max(tot, 1)

    # Print & save
    lines = [
        f"Total samples: {tot}",
        f"BLEU-1: {bleu_agg[0]/denom:.4f}  BLEU-2: {bleu_agg[1]/denom:.4f}  "
        f"BLEU-3: {bleu_agg[2]/denom:.4f}  BLEU-4: {bleu_agg[3]/denom:.4f}",
        f"ROUGE-1  r:{rouge_1['r']/denom:.4f}  p:{rouge_1['p']/denom:.4f}  f:{rouge_1['f']/denom:.4f}",
        f"ROUGE-2  r:{rouge_2['r']/denom:.4f}  p:{rouge_2['p']/denom:.4f}  f:{rouge_2['f']/denom:.4f}",
        f"ROUGE-L  r:{rouge_l['r']/denom:.4f}  p:{rouge_l['p']/denom:.4f}  f:{rouge_l['f']/denom:.4f}",
    ]

    for l in lines:
        print(l)

    with open(result_path, 'w') as fout:
        fout.write('\n'.join(lines))
    print(f"\nResults saved → {result_path}")


# ---- Uncomment after decoding ----
# Reads OUT_FILE (written by the decoding block) and saves the metric report
# next to it as WMT14_result.txt.
# evaluate(OUT_FILE, os.path.join(OUTPUT_DIR, "WMT14_result.txt"))

print("Evaluation functions ready. Uncomment the line above after decoding.")