# Config tuning

In [None]:
# --- Training schedule ---
# NOTE(review): these three values are not read anywhere in this section;
# presumably consumed by the DataLoader / training loop / LR scheduler later on.
BATCH_SIZE = 16
NUM_EPOCHS = 30
WARMUP_STEPS = 256

# --- Model size ---
# NOTE(review): presumably forwarded to Transformer(...) -- confirm at the call site.
# MultiHeadAttention asserts that the model dimension is divisible by the head
# count (384 / 6 = 64 features per head).
MODEL_DIM = 384
NUM_HEADS = 6
NUM_ENC_LAYERS = 6
NUM_DEC_LAYERS = 6
DROPOUT = 0.15
# Maximum sequence lengths per language. The Transformer defaults in this file
# use the same 150 / 180 values for its sinusoidal position tables.
MAX_LEN_EN = 150
MAX_LEN_VI = 180

# --- Positional scheme and tokenization ---
# NOTE(review): USE_ROPE presumably selects pos_type='rope' instead of 'pos', and
# USE_SUBWORD presumably selects SubwordVocabulary over Vocabulary -- verify.
USE_ROPE = False
USE_SUBWORD = True
VOCAB_SIZE_EN = 12000
VOCAB_SIZE_VI = 7000
VOCAB_MODEL_TYPE = 'unigram'  # presumably the SentencePiece model type -- TODO confirm

# Translation direction: source and target language keys.
SRC = 'en'
TRG = 'vi'

# NOTE(review): presumably the DataLoader worker count; 20 may exceed the
# available CPU cores on some machines -- confirm.
NUM_WORKERS = 20

# Define functions and classes

In [None]:
!pip install -q \
    torch \
    torchvision \
    torchaudio \
    datasets \
    sentencepiece \
    sacrebleu \
    rouge-score \
    tqdm \
    numpy \
    matplotlib \
    seaborn

import os
import gc
import torch
import torch.nn as nn
import math
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import re
import sentencepiece as spm
import torch
from tqdm import tqdm
from sacrebleu.metrics import BLEU
from rouge_score import rouge_scorer
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from typing import Literal
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import html
import unicodedata

class RoPE(nn.Module):
    """Rotary positional embedding over consecutive (even, odd) feature pairs.

    Args:
        dim: Size of the feature axis of the tensors passed to ``forward``.
        base: Base of the geometric progression of rotation frequencies.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2).float() / dim
        # One inverse frequency per feature pair.
        self.register_buffer('inv_freq', 1.0 / (base ** exponents))

    def forward(self, x):
        """Rotate each feature pair of ``x`` by an angle that grows with position.

        Args:
            x: Tensor of shape [B, L, D]; axis 1 is treated as the position axis.

        Returns:
            Tensor of shape [B, L, D].
        """
        batch, seq_len, feat = x.shape
        steps = torch.arange(seq_len, device=x.device).float()
        # angles[n, d] = position n * inverse frequency d  ->  [L, D/2]
        angles = steps[:, None] * self.inv_freq[None, :]
        # A leading singleton axis lets cos/sin broadcast over the batch.
        cos = angles.cos().unsqueeze(0)
        sin = angles.sin().unsqueeze(0)

        pairs = x.view(batch, seq_len, feat // 2, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        # Standard 2-D rotation of every (even, odd) pair.
        rotated = torch.stack(
            [even * cos - odd * sin, even * sin + odd * cos],
            dim=-1,
        )
        # Re-interleave the pairs back into a flat feature axis.
        return rotated.flatten(-2)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    Even feature columns hold sines and odd columns hold cosines of
    ``position * div_term``. The table is precomputed for ``max_len`` positions
    and stored as the buffer ``pe`` of shape [1, max_len, embedding_dim].

    Args:
        embedding_dim: Feature size of the embeddings (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(
        self,
        embedding_dim: int,
        max_len: int
    ):
        super().__init__()
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        angles = position * div_term  # [max_len, ceil(embedding_dim / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer cosine column than sine
        # columns, so trim the angles to match instead of failing on the assignment.
        pe[:, 1::2] = torch.cos(angles[:, :embedding_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Tensor of shape [B, L, embedding_dim] with L <= max_len. A longer
               sequence fails in the addition because the shapes do not broadcast.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1), :]

class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        model_dim: Input and output feature size.
        hidden_dim: Width of the inner projection.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
            self,
            model_dim: int,
            hidden_dim: int = 2048,
            dropout: float = 0.1
        ):
        super().__init__()
        self.linear1 = nn.Linear(model_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map ``x`` ([..., model_dim]) to a tensor of the same shape."""
        expanded = self.linear1(x)
        activated = self.dropout(self.activation(expanded))
        return self.linear2(activated)

class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention with optional rotary embeddings.

    Args:
        model_dim: Feature size of the inputs and the output.
        num_heads: Number of heads; must divide ``model_dim``.
        dropout: Dropout probability applied to the attention weights.
        use_rope: When True, queries and keys are rotated by ``RoPE`` (per head)
            before the dot product.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        use_rope: bool = False
    ):
        super().__init__()
        assert model_dim % num_heads == 0, "embedding_dim ph·∫£i chia h·∫øt cho num_heads"
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.use_rope = use_rope

        if use_rope:
            self.rope = RoPE(self.head_dim)

        self.Q_linear = nn.Linear(model_dim, model_dim)
        self.K_linear = nn.Linear(model_dim, model_dim)
        self.V_linear = nn.Linear(model_dim, model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape [B, L, D] into [B, num_heads, L, head_dim]."""
        batch, length = x.size(0), x.size(1)
        return x.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def _rotate(self, x):
        """Apply the rotary embedding to [B, H, L, head_dim], using x's own length."""
        batch, heads, length, dim = x.shape
        flat = x.reshape(batch * heads, length, dim)
        return self.rope(flat).reshape(batch, heads, length, dim)

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: [B, L_q, D]
            k: [B, L_k, D]
            v: [B, L_v, D]
            mask: Optional tensor broadcastable to [B, num_heads, L_q, L_k];
                positions where ``mask == 0`` are excluded from attention.
        Returns:
            context: [B, L_q, D]
        """
        B = q.size(0)
        L_q = q.size(1)

        # Linear projections followed by the head split:
        # [B, L, D] -> [B, num_heads, L, head_dim]
        Q = self._split_heads(self.Q_linear(q))
        K = self._split_heads(self.K_linear(k))
        V = self._split_heads(self.V_linear(v))

        if self.use_rope:
            # Q and K are rotated independently so each keeps its own sequence
            # length. Reusing Q's length for K breaks cross-attention, where
            # L_k (source length) generally differs from L_q (target length).
            Q = self._rotate(Q)
            K = self._rotate(K)

        # [B, num_heads, L_q, L_k]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # [B, num_heads, L_q, head_dim]
        context = torch.matmul(attn_weights, V)

        # Merge the heads back: [B, L_q, D]
        context = context.transpose(1, 2).contiguous().view(B, L_q, self.model_dim)
        out = self.out_proj(context)
        return out

class Encoder(nn.Module):
    """One encoder layer: self-attention then feed-forward, each wrapped in
    dropout + residual connection + LayerNorm (normalization after the residual).

    Args:
        model_dim: Feature size of the layer.
        num_heads: Number of attention heads.
        ff_hidden_dim: Inner width of the feed-forward sub-layer.
        dropout: Dropout probability shared by the sub-layers and residual paths.
        use_rope: Forwarded to the attention module.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        ff_hidden_dim=2048,
        dropout=0.1,
        use_rope: bool = False
    ):
        super().__init__()
        self.mha = MultiHeadAttention(model_dim, num_heads, dropout, use_rope)
        self.ffn = FeedForward(model_dim, ff_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Transform ``x`` of shape [B, L, D] into a tensor of the same shape.

        ``mask`` is passed unchanged to the self-attention module.
        """
        # Self-attention sub-layer.
        attended = self.mha(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer.
        projected = self.ffn(hidden)
        return self.norm2(hidden + self.dropout(projected))

class Decoder(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder
    output, then feed-forward. Each sub-layer is followed by dropout, a
    residual connection and LayerNorm.

    Args:
        model_dim: Feature size of the layer.
        num_heads: Number of attention heads.
        ff_hidden_dim: Inner width of the feed-forward sub-layer.
        dropout: Dropout probability shared by the sub-layers and residual paths.
        use_rope: Forwarded to both attention modules.
    """

    def __init__(
        self,
        model_dim,
        num_heads,
        ff_hidden_dim=2048,
        dropout=0.1,
        use_rope: bool = False
    ):
        super().__init__()
        self.self_attn = MultiHeadAttention(model_dim, num_heads, dropout, use_rope)
        self.cross_attn = MultiHeadAttention(model_dim, num_heads, dropout, use_rope)
        self.ffn = FeedForward(model_dim, ff_hidden_dim, dropout)

        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.norm3 = nn.LayerNorm(model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [B, T, D] target-side hidden states.
            enc_output: [B, L, D] encoder output attended to by cross-attention.
            src_mask: [B, 1, 1, L] mask used in cross-attention.
            tgt_mask: [B, 1, T, T] mask used in self-attention.
        Returns:
            [B, T, D]
        """
        # 1) Self-attention over the target states.
        state = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))

        # 2) Cross-attention: queries from the target, keys/values from the encoder.
        memory_read = self.cross_attn(state, enc_output, enc_output, src_mask)
        state = self.norm2(state + self.dropout(memory_read))

        # 3) Position-wise feed-forward.
        return self.norm3(state + self.dropout(self.ffn(state)))

class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Args:
        src_vocab_size: Size of the source embedding table.
        tgt_vocab_size: Size of the target embedding table and output projection.
        model_dim: Hidden feature size.
        num_heads: Attention heads per layer.
        num_enc_layers: Number of encoder layers.
        num_dec_layers: Number of decoder layers.
        ff_hidden_dim: Inner width of every feed-forward sub-layer.
        max_len_src: Positions precomputed for the source sinusoidal table.
        max_len_trg: Positions precomputed for the target sinusoidal table.
        dropout: Dropout probability used throughout.
        pos_type: 'pos' adds sinusoidal encodings to the embeddings; 'rope'
            skips them and enables rotary embeddings inside attention.
    """

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        model_dim=512,
        num_heads=8,
        num_enc_layers=6,
        num_dec_layers=6,
        ff_hidden_dim=2048,
        max_len_src=150,
        max_len_trg=180,
        dropout=0.1,
        pos_type: Literal['pos', 'rope'] = 'pos'
    ):
        super().__init__()
        self.model_dim = model_dim
        self.pos_type = pos_type
        rotary = pos_type == 'rope'

        self.src_embedding = nn.Embedding(src_vocab_size, model_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, model_dim)
        # The sinusoidal tables only exist in 'pos' mode.
        if self.pos_type == 'pos':
            self.encoder_pe = PositionalEncoding(model_dim, max_len_src)
            self.decoder_pe = PositionalEncoding(model_dim, max_len_trg)
        self.dropout = nn.Dropout(dropout)

        encoder_stack = []
        for _ in range(num_enc_layers):
            encoder_stack.append(
                Encoder(model_dim, num_heads, ff_hidden_dim, dropout, use_rope=rotary)
            )
        self.encoder_layers = nn.ModuleList(encoder_stack)

        decoder_stack = []
        for _ in range(num_dec_layers):
            decoder_stack.append(
                Decoder(model_dim, num_heads, ff_hidden_dim, dropout, use_rope=rotary)
            )
        self.decoder_layers = nn.ModuleList(decoder_stack)

        self.output_proj = nn.Linear(model_dim, tgt_vocab_size)

    def _embed(self, tokens, table, pe_name):
        """Embed token ids, scale by sqrt(model_dim), add positions in 'pos' mode, apply dropout."""
        hidden = table(tokens) * math.sqrt(self.model_dim)
        if self.pos_type == 'pos':
            hidden = getattr(self, pe_name)(hidden)
        return self.dropout(hidden)

    def encode(self, src, src_mask=None):
        """
        Args:
            src: [B, S] source token ids.
            src_mask: [B, 1, 1, S]
        Returns: [B, S, D]
        """
        hidden = self._embed(src, self.src_embedding, 'encoder_pe')
        for block in self.encoder_layers:
            hidden = block(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            tgt: [B, T] target token ids.
            enc_output: [B, S, D]
            src_mask: [B, 1, 1, S]
            tgt_mask: [B, 1, T, T]
        Returns: [B, T, D]
        """
        hidden = self._embed(tgt, self.tgt_embedding, 'decoder_pe')
        for block in self.decoder_layers:
            hidden = block(hidden, enc_output, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: [B, S]
            tgt: [B, T]
        Returns: [B, T, tgt_vocab_size] unnormalized logits.
        """
        memory = self.encode(src, src_mask)                      # [B, S, D]
        decoded = self.decode(tgt, memory, src_mask, tgt_mask)   # [B, T, D]
        return self.output_proj(decoded)                         # [B, T, V]

def preprocess_text(text: str) -> str:
    """Normalize one sentence for tokenization.

    Steps, in order: decode HTML entities, lowercase, apply Unicode NFKC
    normalization, collapse every whitespace run to a single space and trim
    both ends.
    """
    decoded = html.unescape(text).lower()
    canonical = unicodedata.normalize('NFKC', decoded)
    collapsed = re.sub(r'\s+', ' ', canonical)
    return collapsed.strip()

def preprocess_dataset(dataset: Dataset, ignore: Literal['vi', 'en', None] = None):
    """Normalize the 'en' and 'vi' columns of a dataset, keeping the originals.

    Every example gains 'raw_en' and 'raw_vi' copies of the untouched text;
    'vi' and 'en' are then replaced by ``preprocess_text`` output, except for
    the language named by ``ignore``.

    Args:
        dataset: Object exposing ``map(fn)`` over dict-like examples.
            NOTE(review): the annotation resolves to torch's Dataset, which has
            no ``map``; presumably a Hugging Face dataset is passed -- confirm.
        ignore: Language code whose column is left unnormalized, or None.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    def _normalize(example):
        # Keep the raw text first: the evaluator later uses it as reference.
        example['raw_en'] = example['en']
        example['raw_vi'] = example['vi']
        for lang in ('vi', 'en'):
            if lang != ignore:
                example[lang] = preprocess_text(example[lang])
        return example

    return dataset.map(_normalize)

def invert_html_sign(text: str) -> str:
    """Replace the five HTML-special characters with their named entities.

    '&' is escaped first so that the ampersands introduced by the later
    replacements are not escaped a second time.
    """
    html_map = {
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&apos;',
    }
    for sign, value in html_map.items():
        text = text.replace(sign, value)
    return text


# A token made only of ASCII letters, digits, or characters in the
# U+00C0..U+1EF9 range (which covers the Vietnamese accented letters).
# Written with \u escapes so the range survives any re-encoding of this file.
_ALNUM_WORD = re.compile(r"[A-Za-z\u00C0-\u1EF90-9]+")


def _first_word_index(words):
    """Return the index of the first purely alphanumeric token, or len(words)."""
    for i, word in enumerate(words):
        if _ALNUM_WORD.fullmatch(word):
            return i
    return len(words)


def postprocess_text(raw_input: str, pred: str) -> str:
    """Re-format a lowercase model prediction to resemble the raw corpus text.

    Steps:
      1. Escape HTML-special characters as entities.
      2. Surround punctuation with spaces (entities are shielded while doing so).
      3. Restore the casing of words that appear capitalized in ``raw_input``
         somewhere other than at a sentence start (treated as proper nouns).
      4. Capitalize the first word of the prediction and every word that
         follows a '.', '?' or '!' token.

    Args:
        raw_input: The raw (un-normalized) source sentence.
        pred: The decoded prediction.

    Returns:
        The prediction as a single space-joined string.
    """
    pred = pred.strip()
    pred = invert_html_sign(pred)

    # Shield entities behind placeholders so the ';' inside them is not spaced out.
    entities = re.findall(r"&[a-z]+;", pred)
    for i, ent in enumerate(entities):
        pred = pred.replace(ent, f"@@{i}@@")
    pred = re.sub(r"([,.\?!'\":;])", r' \1 ', pred)
    for i, ent in enumerate(entities):
        pred = pred.replace(f"@@{i}@@", ent)

    # Collect source words that are capitalized although they do not open a sentence.
    raw_input_words = raw_input.split()
    start_id = _first_word_index(raw_input_words)
    upper_words = []
    for i, word in enumerate(raw_input_words):
        if i <= start_id:
            continue
        if word[0].isupper() and raw_input_words[i - 1][0] not in ['.', '?', '!']:
            upper_words.append(word)

    # Restore casing on whole tokens only. A plain substring replace would turn
    # a source word such as "I" into an upper-case "I" inside every other word.
    for word in upper_words:
        pattern = r"(?<!\w)" + re.escape(word.lower()) + r"(?!\w)"
        pred = re.sub(pattern, lambda _match, _word=word: _word, pred)

    # Sentence-initial capitalization.
    pred_words = pred.split()
    start_id = _first_word_index(pred_words)
    all_words = []
    for i, word in enumerate(pred_words):
        if i < start_id:
            all_words.append(word)
        elif i == start_id or pred_words[i - 1] in ['.', '?', '!']:
            all_words.append(word[0].upper() + word[1:])
        else:
            all_words.append(word)
    return ' '.join(all_words)

UNK_TOKEN = '<unk>'
PAD_TOKEN = '<pad>'
SOS_TOKEN = '<sos>'  # Start of Sentence
EOS_TOKEN = '<eos>'  # End of Sentence


class Vocabulary:
    """Word-level vocabulary with the four special tokens at ids 0-3.

    Args:
        freq_threshold: A token is added once it has been seen this many times
            within a single ``build_vocabulary`` call.
    """

    def __init__(self, freq_threshold=2):
        self.itos = {0: PAD_TOKEN, 1: SOS_TOKEN, 2: EOS_TOKEN, 3: UNK_TOKEN}
        self.stoi = {PAD_TOKEN: 0, SOS_TOKEN: 1, EOS_TOKEN: 2, UNK_TOKEN: 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer(text):
        """Split into lowercase word tokens and single punctuation characters."""
        return [tok.lower() for tok in re.findall(r"\w+|[^\w\s]", text, re.UNICODE)]

    def build_vocabulary(self, sentence_list):
        """Add every token that reaches ``freq_threshold`` occurrences.

        Ids are assigned in the order tokens reach the threshold, continuing
        from the current vocabulary size. Calling this again therefore extends
        the vocabulary instead of overwriting ids assigned by an earlier call.
        Tokens already present keep their id.
        """
        frequencies = Counter()

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1

                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    # itos keys are always 0..len-1, so len() is the next free id.
                    idx = len(self.itos)
                    self.stoi[word] = idx
                    self.itos[idx] = word

    def numericalize(self, text):
        """Tokenize ``text`` and map each token to its id (unknown tokens -> <unk>)."""
        unk = self.stoi[UNK_TOKEN]
        return [self.stoi.get(token, unk) for token in self.tokenizer(text)]

class SubwordVocabulary:
    """Thin wrapper around a trained SentencePiece model.

    Exposes the processor as ``sp`` and caches the ids the model reports for
    its pad / bos / eos / unk pieces as ``pad_idx`` / ``sos_idx`` / ``eos_idx``
    / ``unk_idx``.

    Args:
        spm_model_path: Path to the SentencePiece model file to load.
    """

    def __init__(self, spm_model_path):
        processor = spm.SentencePieceProcessor()
        processor.load(spm_model_path)
        self.sp = processor

        self.pad_idx = processor.pad_id()
        self.sos_idx = processor.bos_id()
        self.eos_idx = processor.eos_id()
        self.unk_idx = processor.unk_id()

    def __len__(self):
        """Number of pieces in the loaded model."""
        return self.sp.get_piece_size()

    def numericalize(self, text):
        """Encode ``text`` into a list of integer piece ids."""
        piece_ids = self.sp.encode(text, out_type=int)
        return piece_ids


class BilingualDataset(Dataset):
    """Parallel-corpus dataset for word-level vocabularies.

    Args:
        dataset: Indexable collection of dict-like rows holding the columns
            ``<lang>`` (normalized text) and ``raw_<lang>`` (original text).
        src_vocab: Source vocabulary exposing ``stoi`` and ``numericalize``.
        trg_vocab: Target vocabulary exposing ``stoi`` and ``numericalize``.
        src_lang: Column name of the source language.
        trg_lang: Column name of the target language.
    """

    def __init__(self, dataset, src_vocab, trg_vocab, src_lang='en', trg_lang='vi'):
        self.dataset = dataset
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.src_lang = src_lang
        self.trg_lang = trg_lang

    def __len__(self):
        return len(self.dataset)

    def _encode(self, vocab, text):
        """Return ``text`` as token ids wrapped in <sos> ... <eos>."""
        ids = [vocab.stoi[SOS_TOKEN]]
        ids += vocab.numericalize(text)
        ids.append(vocab.stoi[EOS_TOKEN])
        return ids

    def __getitem__(self, index):
        """Return (raw source text, source ids, raw target text, target ids)."""
        # Fetch the row once: every lookup on the underlying dataset may
        # materialize the whole row again.
        row = self.dataset[index]
        raw_src_text = row[f'raw_{self.src_lang}']
        raw_trg_text = row[f'raw_{self.trg_lang}']

        src_numericalized = self._encode(self.src_vocab, row[self.src_lang])
        trg_numericalized = self._encode(self.trg_vocab, row[self.trg_lang])

        return raw_src_text, torch.tensor(src_numericalized), raw_trg_text, torch.tensor(trg_numericalized)

class SpmBilingualDataset(Dataset):
    """Parallel-corpus dataset for subword vocabularies.

    Args:
        dataset: Indexable collection of dict-like rows holding the columns
            ``<lang>`` (normalized text) and ``raw_<lang>`` (original text).
        src_vocab: Source vocabulary exposing ``sos_idx``, ``eos_idx`` and
            ``numericalize``.
        trg_vocab: Target vocabulary with the same interface.
        src_lang: Column name of the source language.
        trg_lang: Column name of the target language.
    """

    def __init__(self, dataset, src_vocab, trg_vocab, src_lang='en', trg_lang='vi'):
        self.dataset = dataset
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.src_lang = src_lang
        self.trg_lang = trg_lang

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _encode(vocab, text):
        """Return ``text`` as piece ids wrapped in the vocabulary's sos/eos ids."""
        return [vocab.sos_idx] + vocab.numericalize(text) + [vocab.eos_idx]

    def __getitem__(self, idx):
        """Return (raw source text, source ids, raw target text, target ids)."""
        # Fetch the row once: every lookup on the underlying dataset may
        # materialize the whole row again.
        row = self.dataset[idx]

        raw_src_text = row[f'raw_{self.src_lang}']
        src_ids = self._encode(self.src_vocab, row[self.src_lang])

        raw_trg_text = row[f'raw_{self.trg_lang}']
        trg_ids = self._encode(self.trg_vocab, row[self.trg_lang])

        return raw_src_text, torch.tensor(src_ids), raw_trg_text, torch.tensor(trg_ids)

class Collate:
    """Batch collation callable: optional truncation followed by right-padding.

    Args:
        pad_idx: Value used to pad both source and target tensors.
        max_src_len: If not None, source sequences are cut to this many ids.
        max_trg_len: If not None, target sequences are cut to this many ids.
    """

    def __init__(self, pad_idx, max_src_len=None, max_trg_len=None):
        self.pad_idx = pad_idx
        self.max_src_len = max_src_len
        self.max_trg_len = max_trg_len

    def _pad(self, sequences, limit):
        """Truncate each 1-D tensor to ``limit`` (when set) and pad into [B, L_max]."""
        if limit is not None:
            sequences = [seq[:limit] for seq in sequences]
        return pad_sequence(sequences, batch_first=True, padding_value=self.pad_idx)

    def __call__(self, batch):
        """Collate items of the form (raw_src, src_ids, raw_trg, trg_ids).

        Returns:
            (list of raw source strings, padded source tensor,
             list of raw target strings, padded target tensor)
        """
        # Transpose the batch: one list per field of the item tuple.
        raw_src, src, raw_trg, trg = (
            [sample[field] for sample in batch] for field in range(4)
        )
        src = self._pad(src, self.max_src_len)
        trg = self._pad(trg, self.max_trg_len)
        return raw_src, src, raw_trg, trg


class Evaluator:
    """
    Evaluates translation quality of a model with corpus BLEU and mean ROUGE-L F1.
    """

    def __init__(self, model, test_loader, src_vocab, tgt_vocab, device, use_subword: bool = False):
        """
        Args:
            model: Trained Transformer model
            test_loader: DataLoader for the test set; each batch must unpack as
                (raw_src, src_ids, raw_tgt, tgt_ids)
            src_vocab: Source vocabulary
            tgt_vocab: Target vocabulary
            device: torch.device handed to the decoder
            use_subword: True when the vocabularies wrap a SentencePiece
                processor (``.sp``); False for word-level vocabularies (``.itos``)
        """
        self.model = model
        self.test_loader = test_loader
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.device = device
        self.use_subword = use_subword

        # BLEU tokenization is 'none' for subword runs and '13a' otherwise.
        self.bleu_metric = BLEU(tokenize='none' if use_subword else '13a')
        # NOTE(review): rouge_score's default tokenizer is believed to keep only
        # [a-z0-9] characters -- verify that it is suitable for Vietnamese text.
        self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

    def indices_to_sentence(self, indices, vocab, remove_special=True):
        """
        Convert token indices into a sentence string.

        Args:
            indices: list or tensor - Token indices
            vocab: Vocabulary object
            remove_special: bool - Drop <sos>, <eos> and <pad> from the output.
                Only applies to word-level vocabularies; in subword mode the
                ids are handed to ``vocab.sp.decode_ids`` unchanged.

        Returns:
            sentence: str
        """
        if torch.is_tensor(indices):
            indices = indices.tolist()
        if self.use_subword:
            return vocab.sp.decode_ids(indices)

        words = [vocab.itos[idx] for idx in indices]
        if remove_special:
            words = [w for w in words if w not in ['<sos>', '<eos>', '<pad>']]
        return ' '.join(words)

    def evaluate_with_decoder(self, decoder, desc="Evaluation"):
        """
        Evaluate the model with a specific decoder (Greedy or Beam Search).

        Every test sentence is decoded one at a time, post-processed with
        ``postprocess_text`` and compared against the raw target text.

        Args:
            decoder: GreedySearchDecoder or BeamSearchDecoder instance
            desc: str - Description for the progress bar

        Returns:
            results: dict with keys 'bleu', 'rouge_l', 'num_samples', 'examples'
        """
        self.model.eval()

        references = []  # Ground truth translations
        hypotheses = []  # Model predictions
        rouge_scores = []

        # Sample translations to display
        examples = []
        num_examples = 5

        print(f"\n{'='*60}")
        print(f"üìä {desc}")
        print(f"{'='*60}\n")

        with torch.no_grad():
            for raw_src_batch, src_batch, raw_tgt_batch, _ in tqdm(self.test_loader, desc=desc):
                for i, (raw_src, raw_tgt) in enumerate(zip(raw_src_batch, raw_tgt_batch)):
                    src = src_batch[i:i+1]  # [1, S]

                    # Decode
                    pred_indices = decoder.decode(src, self.src_vocab, self.tgt_vocab, self.device)

                    # Convert to sentences; the reference is the raw target text
                    pred_sentence = self.indices_to_sentence(pred_indices, self.tgt_vocab)
                    pred_sentence = postprocess_text(raw_input=raw_src, pred=pred_sentence)
                    ref_sentence = raw_tgt

                    # Collect for metrics
                    hypotheses.append(pred_sentence)
                    references.append(ref_sentence)

                    # Calculate ROUGE-L for this pair
                    rouge_result = self.rouge_scorer.score(ref_sentence, pred_sentence)
                    rouge_scores.append(rouge_result['rougeL'].fmeasure)

                    # Save examples; the source is only decoded for the few
                    # samples that are actually displayed
                    if len(examples) < num_examples:
                        examples.append({
                            'source': self.indices_to_sentence(src[0], self.src_vocab),
                            'reference': ref_sentence,
                            'prediction': pred_sentence
                        })

        # Calculate BLEU
        # sacrebleu expects list of references for each hypothesis
        bleu_score = self.bleu_metric.corpus_score(hypotheses, [references])

        # Average ROUGE-L
        avg_rouge = np.mean(rouge_scores)

        results = {
            'bleu': bleu_score.score,
            'rouge_l': avg_rouge,
            'num_samples': len(hypotheses),
            'examples': examples
        }

        # Print results
        print(f"\n{'='*60}")
        print(f"üìà Results:")
        print(f"   BLEU Score:   {bleu_score.score:.2f}")
        print(f"   ROUGE-L F1:   {avg_rouge:.4f}")
        print(f"   Samples:      {len(hypotheses)}")
        print(f"{'='*60}\n")

        # Print examples
        print(f"{'='*60}")
        print(f"üìù Translation Examples:")
        print(f"{'='*60}")
        for idx, ex in enumerate(examples, 1):
            print(f"\nExample {idx}:")
            print(f"  Source:     {ex['source']}")
            print(f"  Reference:  {ex['reference']}")
            print(f"  Prediction: {ex['prediction']}")
        print(f"\n{'='*60}\n")

        return results

    def compare_decoders(self, greedy_decoder, beam_decoder):
        """
        Compare Greedy Search against Beam Search on the same test loader.

        Args:
            greedy_decoder: GreedySearchDecoder instance
            beam_decoder: BeamSearchDecoder instance

        Returns:
            comparison: dict with the results of both decoders under 'greedy'
                and 'beam', plus the beam-minus-greedy deltas under 'improvement'
        """
        print("\n" + "="*60)
        print("üîç COMPARING DECODING STRATEGIES")
        print("="*60)

        greedy_results = self.evaluate_with_decoder(greedy_decoder, "Greedy Search")
        beam_results = self.evaluate_with_decoder(beam_decoder, f"Beam Search (k={beam_decoder.beam_size})")

        # Summary comparison
        print("\n" + "="*60)
        print("üìä COMPARISON SUMMARY")
        print("="*60)
        print(f"\n{'Method':<20} {'BLEU':<10} {'ROUGE-L':<10}")
        print("-" * 40)
        print(f"{'Greedy Search':<20} {greedy_results['bleu']:<10.2f} {greedy_results['rouge_l']:<10.4f}")
        print(f"{'Beam Search':<20} {beam_results['bleu']:<10.2f} {beam_results['rouge_l']:<10.4f}")
        print("-" * 40)

        improvement_bleu = beam_results['bleu'] - greedy_results['bleu']
        improvement_rouge = beam_results['rouge_l'] - greedy_results['rouge_l']

        print(f"{'Improvement':<20} {improvement_bleu:<10.2f} {improvement_rouge:<10.4f}")
        print("="*60 + "\n")

        return {
            'greedy': greedy_results,
            'beam': beam_results,
            'improvement': {
                'bleu': improvement_bleu,
                'rouge_l': improvement_rouge
            }
        }


def calculate_bleu_score(references, hypotheses):
    """
    Compute the corpus-level BLEU score with a default-configured sacrebleu BLEU.

    Args:
        references: list of str - Ground truth translations (one per hypothesis)
        hypotheses: list of str - Model predictions

    Returns:
        bleu_score: float
    """
    # sacrebleu takes a list of reference streams; there is a single one here.
    reference_streams = [references]
    return BLEU().corpus_score(hypotheses, reference_streams).score


def calculate_rouge_score(references, hypotheses):
    """
    Compute the mean sentence-level ROUGE-L F1 over aligned pairs.

    Args:
        references: list of str - Ground truth translations
        hypotheses: list of str - Model predictions

    Returns:
        avg_rouge_l: float - Average ROUGE-L F1 score
    """
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
    f_measures = [
        scorer.score(reference, hypothesis)['rougeL'].fmeasure
        for reference, hypothesis in zip(references, hypotheses)
    ]
    return np.mean(f_measures)

class GreedySearchDecoder:
    """
    Greedy Search: pick the highest-scoring token at every step.
    Fast, but usually lower quality than Beam Search.
    """

    def __init__(self, model, max_len=100, use_subword=False):
        """
        Args:
            model: Trained Transformer model
            max_len: Maximum number of tokens generated after <sos>
            use_subword: True when the vocabularies expose ``pad_idx`` /
                ``sos_idx`` / ``eos_idx`` (subword); False when they expose
                ``stoi`` / ``itos`` (word-level)
        """
        self.model = model
        self.max_len = max_len
        self.use_subword = use_subword

    @torch.no_grad()
    def decode(self, src, src_vocab, tgt_vocab, device):
        """
        Greedy-decode one source sentence.

        Args:
            src: [1, S] - Source tensor (batch_size=1)
            src_vocab: Source vocabulary (provides the padding id)
            tgt_vocab: Target vocabulary (provides the <sos>/<eos> ids)
            device: torch.device

        Returns:
            decoded_tokens: list of int - Token ids, starting with <sos> and
                ending with <eos> when it was produced within ``max_len`` steps
        """
        self.model.eval()

        src = src.to(device)
        if self.use_subword:
            pad_idx = src_vocab.pad_idx
            # Generation happens in the target language, so the start/stop ids
            # come from the target vocabulary (as in BeamSearchDecoder).
            sos_idx = tgt_vocab.sos_idx
            eos_idx = tgt_vocab.eos_idx
        else:
            pad_idx = src_vocab.stoi['<pad>']
            sos_idx = tgt_vocab.stoi['<sos>']
            eos_idx = tgt_vocab.stoi['<eos>']

        # Encode the source once. The mask helpers are defined elsewhere in the
        # notebook.
        src_mask = create_padding_mask(src, pad_idx).to(device)
        enc_output = self.model.encode(src, src_mask)  # [1, S, D]

        # Start from the <sos> token
        decoded_tokens = [sos_idx]

        for _ in range(self.max_len):
            # Re-run the decoder on everything generated so far
            tgt = torch.tensor([decoded_tokens], dtype=torch.long, device=device)  # [1, T]
            tgt_mask = create_causal_mask(len(decoded_tokens), device)
            dec_output = self.model.decode(tgt, enc_output, src_mask, tgt_mask)  # [1, T, D]

            # Logits of the last position only
            logits = self.model.output_proj(dec_output[:, -1, :])  # [1, V]

            # Greedy: take the arg-max token
            next_token = logits.argmax(dim=-1).item()
            decoded_tokens.append(next_token)

            # Stop on <eos>
            if next_token == eos_idx:
                break

        return decoded_tokens

    def translate(self, src_sentence, src_vocab, tgt_vocab, device):
        """
        Translate one sentence from the source to the target language.

        The source ids are wrapped in <sos> ... <eos>, matching how the dataset
        classes in this file build the encoder input.

        Args:
            src_sentence: str - Source sentence (whitespace-tokenized when
                using a word-level vocabulary)
            src_vocab: Source vocabulary
            tgt_vocab: Target vocabulary
            device: torch.device

        Returns:
            translation: str - Translated sentence
        """
        # Tokenize and convert to a tensor
        if self.use_subword:
            src_ids = (
                [src_vocab.sos_idx]
                + src_vocab.numericalize(src_sentence)
                + [src_vocab.eos_idx]
            )
        else:
            unk = src_vocab.stoi['<unk>']
            src_ids = (
                [src_vocab.stoi['<sos>']]
                + [src_vocab.stoi.get(tok, unk) for tok in src_sentence.split()]
                + [src_vocab.stoi['<eos>']]
            )
        src_tensor = torch.LongTensor([src_ids]).to(device)  # [1, S]

        # Decode
        decoded_ids = self.decode(src_tensor, src_vocab, tgt_vocab, device)

        if self.use_subword:
            # The subword vocabulary exposes its SentencePiece processor as `.sp`.
            return tgt_vocab.sp.decode_ids(decoded_ids)

        words = [
            tgt_vocab.itos[i]
            for i in decoded_ids
            if tgt_vocab.itos[i] not in ['<sos>', '<eos>', '<pad>']
        ]
        return ' '.join(words)


class BeamSearchDecoder:
    """
    Beam search decoder.

    Keeps the best ``beam_size`` partial hypotheses at every step instead of
    only the single most likely token. Each active hypothesis is re-decoded
    from scratch at every step, so this is slower than greedy decoding.
    """

    def __init__(self, model, beam_size=5, max_len=100, length_penalty=0.6, use_subword=False):
        """
        Args:
            model: Trained Transformer exposing ``encode``, ``decode`` and
                ``output_proj``.
            beam_size: Number of hypotheses kept after every step.
            max_len: Maximum number of decoding steps.
            length_penalty: Exponent alpha of the length normalization
                ``score / len(tokens) ** alpha`` (0.0 = no normalization,
                1.0 = divide by the full length).
            use_subword: If True, special-token ids are read from the
                ``*_idx`` attributes of the vocabularies; otherwise they are
                looked up in ``stoi``.
        """
        self.model = model
        self.beam_size = beam_size
        self.max_len = max_len
        self.length_penalty = length_penalty
        self.use_subword = use_subword

    def _normalized_score(self, score, tokens):
        """Return the length-normalized score used to rank hypotheses."""
        return score / (len(tokens) ** self.length_penalty)

    @torch.no_grad()
    def decode(self, src, src_vocab, tgt_vocab, device):
        """
        Run beam search for one source sequence.

        Args:
            src: [1, S] - Source token ids.
            src_vocab: Source-side vocabulary (provides the padding id).
            tgt_vocab: Target-side vocabulary (provides <sos>/<eos> ids).
            device: torch.device to run on.

        Returns:
            best_sequence: list of int - Token ids of the best hypothesis,
                starting with <sos> (and ending with <eos> if it finished).
        """
        self.model.eval()

        src = src.to(device)
        if self.use_subword:
            pad_idx = src_vocab.pad_idx
            sos_idx = tgt_vocab.sos_idx
            eos_idx = tgt_vocab.eos_idx
        else:
            pad_idx = src_vocab.stoi['<pad>']
            sos_idx = tgt_vocab.stoi['<sos>']
            eos_idx = tgt_vocab.stoi['<eos>']

        # Encode the source once; every hypothesis reuses it.
        src_mask = create_padding_mask(src, pad_idx).to(device)
        enc_output = self.model.encode(src, src_mask)  # [1, S, D]

        # Each hypothesis is (cumulative log-probability, token ids).
        beams = [(0.0, [sos_idx])]
        completed_beams = []

        for _ in range(self.max_len):
            candidates = []

            for score, tokens in beams:
                # Finished hypotheses are retired instead of expanded.
                if tokens[-1] == eos_idx:
                    completed_beams.append((score, tokens))
                    continue

                tgt = torch.LongTensor([tokens]).to(device)  # [1, T]
                tgt_mask = create_causal_mask(len(tokens), device)

                dec_output = self.model.decode(tgt, enc_output, src_mask, tgt_mask)
                logits = self.model.output_proj(dec_output[:, -1, :])  # [1, V]
                log_probs = F.log_softmax(logits, dim=-1)  # [1, V]

                # topk raises if k exceeds the vocabulary size, so clamp it.
                k = min(self.beam_size, log_probs.size(-1))
                topk_log_probs, topk_indices = log_probs.topk(k, dim=-1)

                for token_score, token_id in zip(
                    topk_log_probs[0].tolist(), topk_indices[0].tolist()
                ):
                    candidates.append((score + token_score, tokens + [token_id]))

            if not candidates:
                # Nothing was expanded. Finished beams were already retired
                # above; keep only unfinished ones so none is counted twice.
                beams = [beam for beam in beams if beam[1][-1] != eos_idx]
                break

            # Rank by length-normalized score. The sort is stable, so ties
            # keep their generation order.
            candidates.sort(key=lambda cand: self._normalized_score(*cand), reverse=True)
            beams = candidates[:self.beam_size]

            # Early stop once enough hypotheses have finished.
            if len(completed_beams) >= self.beam_size:
                break

        # Hypotheses still in the beam compete with the finished ones.
        completed_beams.extend(beams)

        # The pool is never empty: the initial <sos> beam, or its
        # descendants, is always in `beams` or `completed_beams`.
        best = max(completed_beams, key=lambda beam: self._normalized_score(*beam))
        return best[1]

    def translate(self, src_sentence, src_vocab, tgt_vocab, device):
        """
        Translate one sentence using beam search.

        Args:
            src_sentence: str - Source sentence (whitespace-tokenized when
                ``use_subword`` is False).
            src_vocab: Source-side vocabulary.
            tgt_vocab: Target-side vocabulary.
            device: torch.device to run on.

        Returns:
            translation: The decoded sentence. On the word-level path this is
                a str with <sos>/<eos>/<pad> removed; on the subword path it
                is whatever ``tgt_vocab.decode`` returns.
        """
        if self.use_subword:
            src_ids = (
                [src_vocab.sos_idx]
                + src_vocab.encode(src_sentence)
                + [src_vocab.eos_idx]
            )
        else:
            # Word-level path: unknown words map to <unk>; no <sos>/<eos>
            # markers are added here.
            tokens = src_sentence.split()
            src_ids = [
                src_vocab.stoi.get(tok, src_vocab.stoi['<unk>'])
                for tok in tokens
            ]
        src_tensor = torch.LongTensor([src_ids]).to(device)  # [1, S]

        decoded_ids = self.decode(src_tensor, src_vocab, tgt_vocab, device)

        if self.use_subword:
            return tgt_vocab.decode(decoded_ids)

        words = [
            tgt_vocab.itos[i]
            for i in decoded_ids
            if tgt_vocab.itos[i] not in ['<sos>', '<eos>', '<pad>']
        ]
        return ' '.join(words)

def create_padding_mask(seq, pad_idx):
    """
    Build the padding mask used by attention.

    Args:
        seq: [B, L] - Token ids.
        pad_idx: int - Id of the padding token.

    Returns:
        mask: [B, 1, 1, L] float tensor
              - 1.0 : real token (may be attended to)
              - 0.0 : padding token (masked out)
    """
    is_token = seq != pad_idx
    # Insert two singleton axes after the batch axis: [B, L] -> [B, 1, 1, L].
    expanded = torch.unsqueeze(torch.unsqueeze(is_token, 1), 2)
    return expanded.to(torch.float32)


def create_causal_mask(seq_len, device='cpu'):
    """
    Build the causal (look-ahead) mask for decoder self-attention.

    Args:
        seq_len: int - Sequence length.
        device: torch.device or str - Where the mask is allocated.

    Returns:
        mask: [1, 1, seq_len, seq_len] float tensor
              - 1.0 : position may be attended to (itself and earlier)
              - 0.0 : blocked (future positions)
    """
    full = torch.ones(seq_len, seq_len, dtype=torch.float32, device=device)
    # Keep the lower triangle including the diagonal, so position i sees
    # positions 0..i.
    lower = torch.tril(full)
    return lower.unsqueeze(0).unsqueeze(0)


def create_masks(src, tgt, pad_idx, device=None):
    """
    Build the masks needed by the Transformer encoder-decoder.

    Args:
        src: [B, S] - Source token ids.
        tgt: [B, T] - Target token ids.
        pad_idx: int - Padding token id.
        device: torch.device or str - Where the causal mask is allocated.
            Defaults to the device of ``tgt`` so the two target masks can
            always be combined.

    Returns:
        src_mask: [B, 1, 1, S]
                  (1 = attend, 0 = padding)

        tgt_mask: [B, 1, T, T]
                  Padding mask AND causal mask: a key position is visible
                  only if it is not padding and not in the future
                  (1 = attend, 0 = masked).
    """
    if device is None:
        # A CPU causal mask cannot be multiplied with a mask derived from a
        # CUDA tensor, so follow the target tensor by default.
        device = tgt.device

    # Source padding mask
    src_mask = create_padding_mask(src, pad_idx)

    # Target padding mask
    tgt_padding_mask = create_padding_mask(tgt, pad_idx)  # [B, 1, 1, T]

    # Target causal mask
    tgt_len = tgt.size(1)
    tgt_causal_mask = create_causal_mask(tgt_len, device)  # [1, 1, T, T]

    # Combine with an element-wise product, which is a logical AND on 0/1
    # masks. Broadcasting expands [B, 1, 1, T] * [1, 1, T, T] to [B, 1, T, T].
    tgt_mask = tgt_padding_mask * tgt_causal_mask

    return src_mask, tgt_mask


def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, filepath):
    """
    Write a training checkpoint with ``torch.save``.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``train_loss`` and ``val_loss``.

    Args:
        model: nn.Module - Model whose weights are saved.
        optimizer: Optimizer whose state is saved.
        epoch: int - Current epoch.
        train_loss: float - Training loss.
        val_loss: float - Validation loss.
        filepath: str - Destination passed to ``torch.save``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
    )
    torch.save(state, filepath)
    print(f"‚úÖ Checkpoint saved: {filepath}")


def load_checkpoint(filepath, model, optimizer=None, device='cpu'):
    """
    Restore a checkpoint holding the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``train_loss`` and ``val_loss``.

    Args:
        filepath: str - Checkpoint location passed to ``torch.load``.
        model: nn.Module - Receives the saved weights.
        optimizer: Optimizer (optional) - Receives the saved optimizer state.
        device: str - ``map_location`` for the loaded tensors.

    Returns:
        epoch: int - Epoch stored in the checkpoint.
        train_loss: float
        val_loss: float
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(filepath, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    print(f"‚úÖ Checkpoint loaded: {filepath}")
    epoch = state['epoch']
    val_loss = state['val_loss']
    print(f"   Epoch: {epoch}, Val Loss: {val_loss:.4f}")

    return epoch, state['train_loss'], val_loss


def count_parameters(model):
    """
    Count the parameters of a model.

    Args:
        model: nn.Module

    Returns:
        total: int - Number of elements across all parameters.
        trainable: int - Number of elements in parameters with
            ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

class Trainer:
    """
    Training driver for the Transformer model.

    Runs per-epoch training and validation, tracks loss and learning-rate
    history, saves the best and periodic checkpoints, and applies early
    stopping.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        device,
        pad_idx,
        checkpoint_dir='checkpoints',
        log_dir='logs'
    ):
        """
        Args:
            model: Transformer model (moved to ``device`` here).
            train_loader: DataLoader for training; each batch unpacks as
                ``(raw_src, src, raw_tgt, tgt)``.
            val_loader: DataLoader for validation, same batch layout.
            optimizer: Optimizer.
            criterion: Loss function applied to ``[N, V]`` logits and ``[N]``
                targets.
            device: torch.device.
            pad_idx: Padding token index used to build the masks.
            checkpoint_dir: Directory for checkpoints.
            log_dir: Directory for the training-history JSON.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.pad_idx = pad_idx

        # parents=True lets nested output paths such as "runs/exp1/ckpt"
        # be created instead of raising FileNotFoundError.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.log_dir = Path(log_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # One entry per completed epoch.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rates': []
        }

        self.best_val_loss = float('inf')

    def train_epoch(self, epoch, warmup_scheduler=None):
        """
        Train for one epoch.

        Args:
            epoch: int - Epoch number, used only in the progress-bar label.
            warmup_scheduler: WarmupScheduler - Stepped once per batch
                (optional).

        Returns:
            avg_loss: float - Mean loss over the batches that were actually
                used for an update. Batches skipped for NaN/Inf loss are
                excluded. NaN if no batch was used.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch} [Train]')
        for _raw_src, src, _raw_tgt, tgt in pbar:
            src = src.to(self.device)  # [B, S]
            tgt = tgt.to(self.device)  # [B, T]

            # Teacher forcing: the decoder input drops the last token and
            # the expected output drops the first (<sos>).
            tgt_input = tgt[:, :-1]  # [B, T-1]
            tgt_output = tgt[:, 1:]  # [B, T-1]

            src_mask, tgt_mask = create_masks(src, tgt_input, self.pad_idx, self.device)

            self.optimizer.zero_grad()
            logits = self.model(src, tgt_input, src_mask, tgt_mask)  # [B, T-1, V]

            # Flatten for the criterion:
            #   [B, T-1, V] -> [B * (T-1), V]
            #   [B, T-1]    -> [B * (T-1)]
            logits = logits.reshape(-1, logits.size(-1))
            tgt_output = tgt_output.reshape(-1)

            loss = self.criterion(logits, tgt_output)

            # A non-finite loss would poison the weights; skip the update.
            if not torch.isfinite(loss):
                print(f"\n‚ö†Ô∏è Warning: NaN/Inf loss detected! Skipping batch.")
                continue

            loss.backward()

            # Gradient clipping guards against exploding gradients.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # The warmup scheduler sets the LR for this update, so it is
            # stepped before the optimizer.
            if warmup_scheduler is not None:
                warmup_scheduler.step()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            current_lr = self.optimizer.param_groups[0]['lr']
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{current_lr:.2e}'})

        # Average over counted batches only: a skipped batch adds nothing to
        # total_loss and must not dilute the mean.
        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    @torch.no_grad()
    def validate(self, epoch):
        """
        Evaluate the model on the validation loader.

        Args:
            epoch: int - Epoch number, used only in the progress-bar label.

        Returns:
            avg_loss: float - Mean validation loss per batch. NaN if the
                loader yields no batch.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.val_loader, desc=f'Epoch {epoch} [Val]  ')
        for _raw_src, src, _raw_tgt, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            src_mask, tgt_mask = create_masks(src, tgt_input, self.pad_idx, self.device)

            logits = self.model(src, tgt_input, src_mask, tgt_mask)

            logits = logits.reshape(-1, logits.size(-1))
            tgt_output = tgt_output.reshape(-1)

            loss = self.criterion(logits, tgt_output)
            total_loss += loss.item()
            num_batches += 1

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    def train(self, num_epochs, warmup_scheduler=None, plateau_scheduler=None, patience=5):
        """
        Main training loop.

        Args:
            num_epochs: int - Maximum number of epochs.
            warmup_scheduler: WarmupScheduler - Stepped per batch (optional).
            plateau_scheduler: ReduceLROnPlateau - Stepped per epoch with the
                validation loss (optional).
            patience: int - Epochs without validation improvement before
                early stopping.
        """
        print("\n" + "="*60)
        print("üöÄ B·∫ÆT ƒê·∫¶U TRAINING")
        print("="*60)

        total, trainable = count_parameters(self.model)
        print(f"üìä Model Parameters:")
        print(f"   Total: {total:,}")
        print(f"   Trainable: {trainable:,}")
        print("="*60 + "\n")

        epochs_no_improve = 0

        for epoch in range(1, num_epochs + 1):
            train_loss = self.train_epoch(epoch, warmup_scheduler)
            val_loss = self.validate(epoch)

            # LR as left by the last optimizer update of this epoch.
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rates'].append(current_lr)

            print(f"\nüìà Epoch {epoch}/{num_epochs} Summary:")
            print(f"   Train Loss: {train_loss:.4f}")
            print(f"   Val Loss:   {val_loss:.4f}")
            print(f"   LR:         {current_lr:.6f}")

            # Keep the checkpoint with the lowest validation loss.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                epochs_no_improve = 0

                checkpoint_path = self.checkpoint_dir / 'best_model.pt'
                save_checkpoint(
                    self.model,
                    self.optimizer,
                    epoch,
                    train_loss,
                    val_loss,
                    checkpoint_path
                )
                print(f"   ‚ú® New best model! (Val Loss: {val_loss:.4f})")
            else:
                epochs_no_improve += 1
                print(f"   ‚è≥ No improvement for {epochs_no_improve} epoch(s)")

            # NOTE(review): a warmup scheduler that writes an absolute LR on
            # every step (as WarmupScheduler in this file does) overwrites
            # any reduction made here on the next batch - confirm the two
            # are meant to be combined.
            if plateau_scheduler is not None:
                plateau_scheduler.step(val_loss)

            if epochs_no_improve >= patience:
                print(f"\n‚ö†Ô∏è  Early stopping triggered! No improvement for {patience} epochs.")
                break

            # Periodic checkpoint every 5 epochs.
            if epoch % 5 == 0:
                checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
                save_checkpoint(
                    self.model,
                    self.optimizer,
                    epoch,
                    train_loss,
                    val_loss,
                    checkpoint_path
                )

            print("-" * 60 + "\n")

        self.save_history()

        print("\n" + "="*60)
        print("‚úÖ TRAINING COMPLETED!")
        print(f"üìä Best Validation Loss: {self.best_val_loss:.4f}")
        print("="*60 + "\n")

    def save_history(self):
        """Write the training history to ``training_history.json`` in the log directory."""
        history_path = self.log_dir / 'training_history.json'
        with open(history_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, indent=2)
        print(f"üìù Training history saved: {history_path}")


def create_optimizer(model, learning_rate=1e-4, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-4):
    """
    Build an AdamW optimizer over all parameters of ``model``.

    Args:
        model: nn.Module
        learning_rate: float - Initial learning rate.
        betas: tuple - Adam beta coefficients.
        eps: float - Epsilon for numerical stability.
        weight_decay: float - Weight-decay coefficient. AdamW applies it as
            decoupled weight decay, not as an L2 term added to the loss.

    Returns:
        optimizer: torch.optim.AdamW
    """
    hyperparams = {
        'lr': learning_rate,
        'betas': betas,
        'eps': eps,
        'weight_decay': weight_decay,
    }
    return torch.optim.AdamW(model.parameters(), **hyperparams)


def create_scheduler(optimizer, mode='plateau', factor=0.5, patience=3, min_lr=1e-6, step_size=5):
    """
    Build a learning-rate scheduler.

    Args:
        optimizer: Optimizer
        mode: str - 'plateau' (ReduceLROnPlateau) or 'step' (StepLR).
        factor: float - Multiplicative LR reduction factor (``factor`` for
            plateau, ``gamma`` for step).
        patience: int - Epochs without improvement before the LR is reduced
            (plateau mode only).
        min_lr: float - Lower bound on the LR (plateau mode only).
        step_size: int - Number of scheduler steps between reductions
            (step mode only).

    Returns:
        scheduler: Learning rate scheduler

    Raises:
        ValueError: If ``mode`` is neither 'plateau' nor 'step'.
    """
    if mode == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=factor,
            patience=patience,
            min_lr=min_lr
        )
    if mode == 'step':
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=step_size,
            gamma=factor
        )
    raise ValueError(f"Unknown scheduler mode: {mode}")


class WarmupScheduler:
    """
    Learning rate scheduler with warmup.
    Implements the schedule from the "Attention is All You Need" paper.

    lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))

    With ``warmup_steps == 0`` the warmup term is dropped and the schedule
    is the pure inverse-square-root decay ``d_model^(-0.5) * step^(-0.5)``.
    """

    def __init__(self, optimizer, d_model, warmup_steps=4000):
        """
        Args:
            optimizer: Optimizer whose ``param_groups`` learning rates are
                overwritten by this scheduler.
            d_model: int - Model dimension.
            warmup_steps: int - Number of warmup steps (>= 0).

        Raises:
            ValueError: If ``warmup_steps`` is negative.
        """
        if warmup_steps < 0:
            # A negative base raised to -1.5 is a complex number, which
            # would otherwise surface as an obscure TypeError.
            raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")

        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0
        # Set the LR for the first update right away.
        self._update_lr()

    def step(self):
        """Advance one step and update the learning rate."""
        self.current_step += 1
        self._update_lr()

    def _update_lr(self):
        """Compute the scheduled LR and write it into every param group."""
        step = max(self.current_step, 1)  # Step 0 is treated as step 1.
        scale = step ** -0.5
        if self.warmup_steps > 0:
            # Linear ramp during warmup, inverse-sqrt decay afterwards.
            scale = min(scale, step * self.warmup_steps ** -1.5)
        lr = (self.d_model ** -0.5) * scale

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_last_lr(self):
        """Return the current learning rate of every param group."""
        return [param_group['lr'] for param_group in self.optimizer.param_groups]




# Global plotting defaults, applied at import time to every figure created
# afterwards.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 10,
})


def plot_training_curves(history_path, save_dir='figures'):
    """
    Plot the training/validation loss curves and the learning-rate schedule.

    The figure is saved as ``training_curves.png`` in ``save_dir`` and then
    shown.

    Args:
        history_path: str - Path to training_history.json
        save_dir: str - Directory the figure is saved to
    """
    with open(history_path, 'r') as history_file:
        history = json.load(history_file)

    out_dir = Path(save_dir)
    out_dir.mkdir(exist_ok=True)

    # Epochs are numbered from 1 on the x axis.
    epochs = range(1, len(history['train_loss']) + 1)

    fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: train vs. validation loss.
    loss_series = (
        ('train_loss', 'b-o', 'Train Loss'),
        ('val_loss', 'r-s', 'Val Loss'),
    )
    for key, fmt, label in loss_series:
        loss_ax.plot(epochs, history[key], fmt, label=label, linewidth=2, markersize=4)
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
    loss_ax.legend(fontsize=11)
    loss_ax.grid(True, alpha=0.3)

    # Right panel: learning rate per epoch on a log scale.
    lr_ax.plot(epochs, history['learning_rates'], 'g-^', linewidth=2, markersize=4)
    lr_ax.set_xlabel('Epoch', fontsize=12)
    lr_ax.set_ylabel('Learning Rate', fontsize=12)
    lr_ax.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    lr_ax.set_yscale('log')
    lr_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    save_path = out_dir / 'training_curves.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Saved: {save_path}")

    plt.show()


def plot_metrics_comparison(comparison_results, save_dir='figures'):
    """
    Draw side-by-side bar charts comparing BLEU and ROUGE-L for greedy and
    beam search.

    The figure is saved as ``metrics_comparison.png`` in ``save_dir`` and
    then shown.

    Args:
        comparison_results: dict - Must hold ``'greedy'`` and ``'beam'``
            entries, each with ``'bleu'`` and ``'rouge_l'`` values.
        save_dir: str - Directory the figure is saved to
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(exist_ok=True)

    methods = ['Greedy Search', 'Beam Search']
    decoder_keys = ('greedy', 'beam')
    bleu_scores = [comparison_results[key]['bleu'] for key in decoder_keys]
    rouge_scores = [comparison_results[key]['rouge_l'] for key in decoder_keys]

    fig, (bleu_ax, rouge_ax) = plt.subplots(1, 2, figsize=(14, 5))
    colors = ['#3498db', '#e74c3c']

    # (axis, values, y label, title, number format of the bar labels)
    panels = (
        (bleu_ax, bleu_scores, 'BLEU Score', 'BLEU Score Comparison', '.2f'),
        (rouge_ax, rouge_scores, 'ROUGE-L F1 Score', 'ROUGE-L Score Comparison', '.4f'),
    )
    for ax, scores, ylabel, title, number_format in panels:
        bars = ax.bar(methods, scores, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        # 20% headroom so the value labels fit above the bars.
        ax.set_ylim(0, max(scores) * 1.2)

        # Print each value centered on top of its bar.
        for bar, score in zip(bars, scores):
            ax.text(
                bar.get_x() + bar.get_width()/2.,
                bar.get_height(),
                format(score, number_format),
                ha='center', va='bottom', fontsize=11, fontweight='bold',
            )

    plt.tight_layout()

    save_path = out_dir / 'metrics_comparison.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Saved: {save_path}")

    plt.show()


def plot_loss_histogram(history_path, save_dir='figures'):
    """
    Plot overlaid histograms of the per-epoch train and validation losses.

    The figure is saved as ``loss_histogram.png`` in ``save_dir`` and then
    shown.

    Args:
        history_path: str - Path to training_history.json
        save_dir: str - Directory the figure is saved to
    """
    with open(history_path, 'r') as history_file:
        history = json.load(history_file)

    out_dir = Path(save_dir)
    out_dir.mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Semi-transparent bars keep both distributions visible where they overlap.
    series = (
        ('train_loss', 'Train Loss', 'blue'),
        ('val_loss', 'Val Loss', 'red'),
    )
    for key, label, color in series:
        ax.hist(history[key], bins=20, alpha=0.5, label=label, color=color, edgecolor='black')

    ax.set_xlabel('Loss Value', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Loss Distribution', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    save_path = out_dir / 'loss_histogram.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Saved: {save_path}")

    plt.show()


def create_summary_table(comparison_results, history_path, save_dir='figures'):
    """
    Render a summary table of the training and evaluation metrics.

    The table is saved as ``summary_table.png`` in ``save_dir`` and then
    shown. Training-summary cells show ``N/A`` when the history holds no
    epochs, instead of raising on the empty lists.

    Args:
        comparison_results: dict - Must hold ``'greedy'`` and ``'beam'``
            entries (each with ``'bleu'``, ``'rouge_l'``, ``'num_samples'``)
            and an ``'improvement'`` entry (``'bleu'``, ``'rouge_l'``).
        history_path: str - Path to the training history JSON
        save_dir: str - Directory the figure is saved to
    """
    with open(history_path, 'r') as f:
        history = json.load(f)

    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    def _fmt(values, spec, pick=lambda seq: seq[-1]):
        """Format one value picked from ``values``, or 'N/A' if it is empty."""
        return format(pick(values), spec) if values else 'N/A'

    greedy = comparison_results['greedy']
    beam = comparison_results['beam']
    improvement = comparison_results['improvement']

    # Section title -> {row label: already-formatted value}
    summary = {
        'Training Summary': {
            'Total Epochs': len(history['train_loss']),
            'Final Train Loss': _fmt(history['train_loss'], '.4f'),
            'Final Val Loss': _fmt(history['val_loss'], '.4f'),
            'Best Val Loss': _fmt(history['val_loss'], '.4f', pick=min),
            'Final LR': _fmt(history['learning_rates'], '.6f')
        },
        'Greedy Search': {
            'BLEU Score': f"{greedy['bleu']:.2f}",
            'ROUGE-L F1': f"{greedy['rouge_l']:.4f}",
            'Samples': greedy['num_samples']
        },
        'Beam Search': {
            'BLEU Score': f"{beam['bleu']:.2f}",
            'ROUGE-L F1': f"{beam['rouge_l']:.4f}",
            'Samples': beam['num_samples']
        },
        'Improvement (Beam vs Greedy)': {
            'BLEU': f"{improvement['bleu']:+.2f}",
            'ROUGE-L': f"{improvement['rouge_l']:+.4f}"
        }
    }

    # The figure is addressed through pyplot below, so only the axes is kept.
    _, ax = plt.subplots(figsize=(10, 8))
    ax.axis('tight')
    ax.axis('off')

    # Flatten into rows: section title, separator, one row per metric,
    # then a blank spacer row.
    table_data = []
    for section, metrics in summary.items():
        table_data.append([section, '', ''])
        table_data.append(['‚îÄ' * 30, '‚îÄ' * 20, '‚îÄ' * 10])
        for key, value in metrics.items():
            table_data.append(['  ' + key, str(value), ''])
        table_data.append(['', '', ''])

    table = ax.table(cellText=table_data, cellLoc='left', loc='center',
                    colWidths=[0.5, 0.3, 0.2])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Highlight every row whose second column is empty. That covers the
    # section-title rows and also the blank spacer rows.
    for i, row in enumerate(table_data):
        if row[1] == '':
            for j in range(3):
                cell = table[(i, j)]
                cell.set_facecolor('#3498db')
                cell.set_text_props(weight='bold', color='white')

    plt.title('Training & Evaluation Summary', fontsize=16, fontweight='bold', pad=20)

    save_path = save_dir / 'summary_table.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Saved: {save_path}")

    plt.show()


def generate_all_plots(history_path, comparison_results, save_dir='figures'):
    """
    Produce every visualization in one call: training curves, metric
    comparison, loss histogram and summary table, in that order.

    Args:
        history_path: str - Path to training_history.json
        comparison_results: dict - Evaluation results passed to the metric
            comparison and summary plots
        save_dir: str - Directory the figures are saved to
    """
    rule = "=" * 60

    print("\n" + rule)
    print("üìä GENERATING VISUALIZATIONS")
    print(rule + "\n")

    plot_training_curves(history_path, save_dir)
    plot_metrics_comparison(comparison_results, save_dir)
    plot_loss_histogram(history_path, save_dir)
    create_summary_table(comparison_results, history_path, save_dir)

    print("\n" + rule)
    print("‚úÖ ALL VISUALIZATIONS GENERATED!")
    print(f"üìÅ Saved to: {save_dir}/")
    print(rule + "\n")

def _print_section(title):
    """Print ``title`` framed by two 60-character '=' rules, after a blank line."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


def _train_sentencepiece_models(train_dataset, config):
    """
    Train the English and Vietnamese SentencePiece models.

    Dumps the "en" and "vi" fields of ``train_dataset`` to temporary text
    files, trains ``spm_en`` and ``spm_vi`` from them, and always removes
    the temporary files afterwards, even if training raises.

    Args:
        train_dataset: Iterable of records with "en" and "vi" string fields.
        config: Mapping providing 'vocab_size_en', 'vocab_size_vi' and
            'vocab_model_type'.
    """
    corpus_en = "temp_train.en"
    corpus_vi = "temp_train.vi"
    try:
        with open(corpus_en, "w", encoding="utf-8") as f_en, \
                open(corpus_vi, "w", encoding="utf-8") as f_vi:
            for x in train_dataset:
                f_en.write(x["en"].strip() + "\n")
                f_vi.write(x["vi"].strip() + "\n")

        spm.SentencePieceTrainer.train(
            input=corpus_en,
            model_prefix="spm_en",
            vocab_size=config['vocab_size_en'],
            model_type=config['vocab_model_type'],
            character_coverage=1.0,
            pad_id=0, bos_id=1, eos_id=2, unk_id=3
        )

        spm.SentencePieceTrainer.train(
            input=corpus_vi,
            model_prefix="spm_vi",
            vocab_size=config['vocab_size_vi'],
            model_type=config['vocab_model_type'],
            character_coverage=0.9995,
            pad_id=0, bos_id=1, eos_id=2, unk_id=3
        )
    finally:
        # Clean up the temporary corpora even when training fails.
        for path in (corpus_en, corpus_vi):
            if os.path.exists(path):
                os.remove(path)


def _build_loader(data, config, pad_idx, *, batch_size, shuffle):
    """
    Wrap ``data`` in a DataLoader using the pipeline's shared collate settings.

    A fresh ``Collate`` is created per loader, configured with ``pad_idx``
    and the source/target maximum lengths looked up from ``config``.
    """
    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=Collate(
            pad_idx=pad_idx,
            max_src_len=config[f"max_len_{config['src']}"],
            max_trg_len=config[f"max_len_{config['trg']}"]
        ),
        num_workers=config['num_workers']
    )


def main():
    """
    Run the end-to-end translation pipeline.

    Stages, in order: assemble and print the configuration, load and
    preprocess the IWSLT2015 en-vi dataset, build vocabularies (subword or
    word-level depending on ``USE_SUBWORD``), create the data loaders, build
    the Transformer, set up loss/optimizer/schedulers, train, reload the
    best checkpoint, evaluate with greedy and beam search, and generate
    the plots.

    Side effects: writes SentencePiece model files (subword mode only),
    and relies on the trainer to produce 'checkpoints/best_model.pt' and
    'logs/training_history.json', which are read back here.
    """
    rule = "=" * 60

    # ==================== CONFIGURATION ====================

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nüñ•Ô∏è  Device: {device}")

    # Hyperparameters
    # NOTE(review): 'use_rope' is recorded here but is not passed to the
    # Transformer constructor below - confirm how the model picks it up.
    CONFIG = {
        'src': SRC,
        'trg': TRG,
        'use_subword': USE_SUBWORD,
        'use_rope': USE_ROPE,
        'vocab_size_en': VOCAB_SIZE_EN,
        'vocab_size_vi': VOCAB_SIZE_VI,
        'vocab_model_type': VOCAB_MODEL_TYPE,
        'num_workers': NUM_WORKERS,
        # Model
        'model_dim': MODEL_DIM,
        'num_heads': NUM_HEADS,
        'num_enc_layers': NUM_ENC_LAYERS,
        'num_dec_layers': NUM_DEC_LAYERS,
        'ff_hidden_dim': MODEL_DIM * 4,
        'dropout': DROPOUT,
        'max_len_en': MAX_LEN_EN,
        'max_len_vi': MAX_LEN_VI,

        # Training (no fixed learning rate: the optimizer gets a base LR
        # of 1.0 that the warmup scheduler is expected to rescale)
        'batch_size': BATCH_SIZE,
        'num_epochs': NUM_EPOCHS,
        'weight_decay': 1e-5,
        'warmup_steps': WARMUP_STEPS,  # Warmup for stable training
        'patience': 5,  # Early stopping

        # Data
        'freq_threshold': 2,  # Minimum word frequency

        # Inference
        'beam_size': 5,
        'length_penalty': 0.6,
    }

    # Maximum sequence lengths for the configured translation direction.
    max_len_src = CONFIG[f"max_len_{CONFIG['src']}"]
    max_len_trg = CONFIG[f"max_len_{CONFIG['trg']}"]

    _print_section("‚öôÔ∏è  CONFIGURATION")
    for key, value in CONFIG.items():
        print(f"  {key:<20}: {value}")
    print(rule + "\n")

    # ==================== 1. LOAD DATA ====================

    _print_section("üì• LOADING IWSLT2015 DATASET")
    dataset = load_dataset('thainq107/iwslt2015-en-vi')
    train_dataset = dataset['train']
    val_dataset = dataset['validation']
    test_dataset = dataset['test']

    train_dataset = preprocess_dataset(train_dataset)
    val_dataset = preprocess_dataset(val_dataset)
    test_dataset = preprocess_dataset(test_dataset, ignore=CONFIG['trg'])

    print(f"  Train samples: {len(train_dataset):,}")
    print(f"  Val samples:   {len(val_dataset):,}")
    print(f"  Test samples:  {len(test_dataset):,}")
    print(rule + "\n")

    # ==================== 2. BUILD VOCABULARY ====================

    _print_section("üìö BUILDING VOCABULARY")

    if CONFIG['use_subword']:
        # SentencePiece models are only needed (and only trained) in
        # subword mode.
        _train_sentencepiece_models(train_dataset, CONFIG)
        src_vocab = SubwordVocabulary(f"spm_{CONFIG['src']}.model")
        trg_vocab = SubwordVocabulary(f"spm_{CONFIG['trg']}.model")
    else:
        # Word-level vocabularies are built from the raw training sentences.
        src_sentences = [x[CONFIG['src']] for x in train_dataset]
        trg_sentences = [x[CONFIG['trg']] for x in train_dataset]

        src_vocab = Vocabulary(freq_threshold=CONFIG['freq_threshold'])
        src_vocab.build_vocabulary(src_sentences)

        trg_vocab = Vocabulary(freq_threshold=CONFIG['freq_threshold'])
        trg_vocab.build_vocabulary(trg_sentences)

    # ==================== 3. CREATE DATALOADERS ====================

    _print_section("üîÑ CREATING DATALOADERS")

    # NOTE(review): the padding index is taken from the source vocabulary
    # and reused for target padding and the loss; this assumes both
    # vocabularies share the same pad index - confirm.
    if CONFIG['use_subword']:
        pad_idx = src_vocab.pad_idx
        dataset_cls = SpmBilingualDataset
    else:
        pad_idx = src_vocab.stoi[PAD_TOKEN]
        dataset_cls = BilingualDataset

    train_data, val_data, test_data = (
        dataset_cls(
            split, src_vocab, trg_vocab,
            src_lang=CONFIG['src'], trg_lang=CONFIG['trg']
        )
        for split in (train_dataset, val_dataset, test_dataset)
    )

    train_loader = _build_loader(
        train_data, CONFIG, pad_idx,
        batch_size=CONFIG['batch_size'], shuffle=True
    )
    val_loader = _build_loader(
        val_data, CONFIG, pad_idx,
        batch_size=CONFIG['batch_size'], shuffle=False
    )
    # The test set is loaded one sentence at a time.
    test_loader = _build_loader(
        test_data, CONFIG, pad_idx,
        batch_size=1, shuffle=False
    )

    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches:   {len(val_loader)}")
    print(f"  Test batches:  {len(test_loader)}")
    print(rule + "\n")

    # ==================== 4. CREATE MODEL ====================

    _print_section("üèóÔ∏è  CREATING TRANSFORMER MODEL")

    model = Transformer(
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(trg_vocab),
        model_dim=CONFIG['model_dim'],
        num_heads=CONFIG['num_heads'],
        num_enc_layers=CONFIG['num_enc_layers'],
        num_dec_layers=CONFIG['num_dec_layers'],
        ff_hidden_dim=CONFIG['ff_hidden_dim'],
        max_len_src=max_len_src,
        max_len_trg=max_len_trg,
        dropout=CONFIG['dropout']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"  Total parameters:     {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(rule + "\n")

    # ==================== 5. SETUP TRAINING ====================

    _print_section("üéØ SETUP TRAINING")

    # Loss function (ignore padding)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.05)

    # Optimizer (the learning rate will be adjusted by the warmup scheduler)
    optimizer = create_optimizer(
        model,
        learning_rate=1.0,  # Base LR, to be rescaled by the warmup scheduler
        weight_decay=CONFIG['weight_decay']
    )

    # Warmup scheduler (following the paper "Attention is All You Need")
    warmup_scheduler = WarmupScheduler(
        optimizer,
        d_model=CONFIG['model_dim'],
        warmup_steps=CONFIG['warmup_steps']
    )

    # Learning rate scheduler (after warmup)
    plateau_scheduler = create_scheduler(
        optimizer,
        mode='plateau',
        factor=0.5,
        patience=3
    )

    print(f"  Loss function: CrossEntropyLoss (ignore_index={pad_idx}, label_smoothing=0.05)")
    print(f"  Optimizer: Adam (base_lr=1.0, weight_decay={CONFIG['weight_decay']})")
    print(f"  Warmup Scheduler: {CONFIG['warmup_steps']} steps")
    print(f"  Plateau Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)")
    print(rule + "\n")

    # ==================== 6. TRAINING ====================

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        pad_idx=pad_idx,
        checkpoint_dir='checkpoints',
        log_dir='logs'
    )

    trainer.train(
        num_epochs=CONFIG['num_epochs'],
        warmup_scheduler=warmup_scheduler,
        plateau_scheduler=plateau_scheduler,
        patience=CONFIG['patience']
    )

    # ==================== 7. LOAD BEST MODEL ====================

    _print_section("üì¶ LOADING BEST MODEL")

    checkpoint = torch.load('checkpoints/best_model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"  Best epoch: {checkpoint['epoch']}")
    print(f"  Best val loss: {checkpoint['val_loss']:.4f}")
    print(rule + "\n")

    # ==================== 8. INFERENCE & EVALUATION ====================

    _print_section("üîç INFERENCE & EVALUATION")
    print()

    # Create decoders. Both use the same maximum output length so that the
    # greedy-vs-beam comparison is not skewed by different truncation caps.
    greedy_decoder = GreedySearchDecoder(
        model,
        max_len=max_len_trg,
        use_subword=CONFIG['use_subword']
    )
    beam_decoder = BeamSearchDecoder(
        model,
        beam_size=CONFIG['beam_size'],
        max_len=max_len_trg,
        length_penalty=CONFIG['length_penalty'],
        use_subword=CONFIG['use_subword']
    )

    # Evaluate
    evaluator = Evaluator(model, test_loader, src_vocab, trg_vocab, device, use_subword=CONFIG['use_subword'])
    comparison_results = evaluator.compare_decoders(greedy_decoder, beam_decoder)

    # ==================== 9. VISUALIZATION ====================

    generate_all_plots(
        history_path='logs/training_history.json',
        comparison_results=comparison_results,
        save_dir='figures'
    )

    # ==================== DONE ====================

    _print_section("‚úÖ ALL TASKS COMPLETED!")
    print("\nüìÅ Output files:")
    print("  - checkpoints/best_model.pt")
    print("  - logs/training_history.json")
    print("  - figures/training_curves.png")
    print("  - figures/metrics_comparison.png")
    print("  - figures/loss_histogram.png")
    print("  - figures/summary_table.png")
    print("\n" + rule + "\n")

# Start training pipeline

In [None]:
# Entry point for this cell: runs the whole pipeline defined in main().
main()