In [24]:
!ls -l /kaggle/input/translate7

total 60
-rw-r--r-- 1 nobody nogroup 3858 Aug 25 03:13 analisis.py
-rw-r--r-- 1 nobody nogroup 1109 Aug 25 03:13 attention.py
-rw-r--r-- 1 nobody nogroup 1521 Aug 25 03:13 decoder.py
-rw-r--r-- 1 nobody nogroup  908 Aug 25 03:13 encoder.py
-rw-r--r-- 1 nobody nogroup 4429 Aug 25 03:13 eval.py
-rw-r--r-- 1 nobody nogroup 1692 Aug 25 03:13 heatmap.py
-rw-r--r-- 1 nobody nogroup  862 Aug 25 03:13 prepare_data.py
-rw-r--r-- 1 nobody nogroup 3754 Aug 25 03:13 seq2seq.py
-rw-r--r-- 1 nobody nogroup  665 Aug 25 03:13 sp_train.py
-rw-r--r-- 1 nobody nogroup 2335 Aug 25 03:13 top_words.py
-rw-r--r-- 1 nobody nogroup 6537 Aug 25 03:13 transformer.py
-rw-r--r-- 1 nobody nogroup 8006 Aug 25 03:13 util.py


In [25]:
import unicodedata
from collections import Counter
from pathlib import Path
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import math
import sacrebleu
import sys
import os
import csv
import random
import numpy as np
import re
from typing import List, Dict, Tuple, Optional

# Set seeds for reproducibility
def set_seed(seed: int = 42):
    """Sets seeds for reproducibility across different devices."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# --- Global Constants ---
SPECIALS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD, BOS, EOS, UNK = range(4)
CLIP = 1.0

# --- Improved Helper functions ---
def normalize(text: str) -> str:
    """Improved text normalization for consistent tokenization."""
    text = unicodedata.normalize("NFKC", text.lower().strip())
    # Remove special characters but keep basic punctuation
    text = re.sub(r'[^\w\s.,!?-]', '', text)
    # Normalize multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def to_ids(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    """Converts a list of tokens to a list of IDs with BOS and EOS tokens."""
    ids = [BOS]
    for tok in tokens:
        ids.append(vocab.get(tok, UNK))
    ids.append(EOS)
    return ids

def decode_ids(ids: List[int], itos: Dict[int, str]) -> str:
    """Decodes a list of IDs back to a string."""
    tokens = []
    # Ensure input is a list of integers
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    
    for i in ids:
        if i == EOS:
            break
        if i != BOS and i != PAD:
            tokens.append(itos.get(i, '<UNK>'))
    
    return ' '.join(tokens)

def collate_batch(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pads a batch of sequences and sorts them by source length."""
    src_list, trg_list = [], []
    for src, trg in batch:
        src_list.append(src)
        trg_list.append(trg)
    
    # Sort by source length for more efficient packing/padding
    sorted_batch = sorted(zip(src_list, trg_list), key=lambda x: len(x[0]), reverse=True)
    src_list, trg_list = zip(*sorted_batch)
    
    src_padded = torch.nn.utils.rnn.pad_sequence(src_list, batch_first=True, padding_value=PAD)
    trg_padded = torch.nn.utils.rnn.pad_sequence(trg_list, batch_first=True, padding_value=PAD)
    return src_padded, trg_padded

def load_pairs(file_path: Path, max_len: int = 25, max_pairs: Optional[int] = None) -> List[Tuple[List[str], List[str]]]:
    """Loads and preprocesses sentence pairs from a file."""
    pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if max_pairs and i >= max_pairs:
                break
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                src, trg = normalize(parts[0]).split(' '), normalize(parts[1]).split(' ')
                # Filter by length and ensure non-empty
                if 3 <= len(src) <= max_len and 3 <= len(trg) <= max_len:
                    pairs.append((src, trg))
    return pairs

def split_pairs(pairs: List[Tuple], train_ratio: float = 0.8, val_ratio: float = 0.1) -> Tuple[List, List, List]:
    """Shuffles and splits data into training, validation, and test sets."""
    random.shuffle(pairs)
    n = len(pairs)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return pairs[:n_train], pairs[n_train:n_train + n_val], pairs[n_train + n_val:]

def build_vocab(token_lists: List[List[str]], min_freq: int = 2, max_size: int = 30000) -> Tuple[Dict, Dict]:
    """Builds a vocabulary with frequency and size filtering."""
    counter = Counter()
    for toks in token_lists:
        counter.update(toks)
    
    # Filter by minimum frequency and size
    filtered = [w for w, c in counter.items() if c >= min_freq]
    filtered.sort(key=lambda w: counter[w], reverse=True)
    
    vocab = {sp: i for i, sp in enumerate(SPECIALS)}
    for w in filtered[:max_size - len(SPECIALS)]:
        if w not in vocab:
            vocab[w] = len(vocab)
            
    itos = {i: w for w, i in vocab.items()}
    return vocab, itos

def evaluate_sacrebleu(model, loader: DataLoader, trg_itos: Dict[int, str], beam_size: int = 1) -> Tuple[float, float]:
    """Evaluates the model using SacreBLEU and chrF scores."""
    model.eval()
    refs, hyps = [], []
    
    with torch.no_grad():
        for src, trg in tqdm(loader, desc="Evaluating"):
            src = src.to(model.device)
            
            try:
                if beam_size > 1:
                    translated_ids = model.beam_search_decode(src, max_len=40, beam_size=beam_size)
                else:
                    translated_ids = model.greedy_decode(src, max_len=40)
                    
                for i in range(src.size(0)):
                    hyp = decode_ids(translated_ids[i], trg_itos)
                    ref = decode_ids(trg[i], trg_itos)
                    
                    if hyp.strip() and ref.strip():
                        hyps.append(hyp)
                        refs.append([ref])
            
            except Exception as e:
                print(f"Error in evaluation batch: {e}")
                continue
    
    if not hyps or not refs:
        return 0.0, 0.0
    
    try:
        bleu_score = sacrebleu.corpus_bleu(hyps, refs).score
        chrf_score = sacrebleu.corpus_chrf(hyps, refs).score
    except Exception as e:
        print(f"Error calculating BLEU/chrF: {e}")
        return 0.0, 0.0
    
    return bleu_score, chrf_score

def print_translations_and_analyze(model, loader: DataLoader, src_itos: Dict[int, str], trg_itos: Dict[int, str], num_examples: int = 5):
    """Prints translation examples with greedy and beam search results."""
    model.eval()
    print("\n--- Translation Examples & Error Analysis ---")
    
    with torch.no_grad():
        for i, (src, trg) in enumerate(loader):
            if i >= num_examples:
                break
                
            src = src.to(model.device)
            
            # Use the first example from the batch
            src_seq = src[0:1]
            trg_seq = trg[0]
            
            src_text = decode_ids(src_seq.squeeze(0), src_itos)
            ref_text = decode_ids(trg_seq, trg_itos)
            
            print("--------------------------------------------------")
            print(f"Source: {src_text}")
            print(f"Reference: {ref_text}")
            
            try:
                ys_greedy = model.greedy_decode(src_seq, max_len=40)
                hyp_greedy = decode_ids(ys_greedy.squeeze(0), trg_itos)
                print(f"Greedy: {hyp_greedy}")
            except Exception as e:
                print(f"Greedy decoding failed: {e}")
            
            try:
                ys_beam = model.beam_search_decode(src_seq, max_len=40, beam_size=3)
                if ys_beam and len(ys_beam) > 0:
                    hyp_beam = decode_ids(ys_beam[0], trg_itos)
                    print(f"Beam Search (k=3): {hyp_beam}")
            except Exception as e:
                print(f"Beam search failed: {e}")

class EarlyStopping:
    """Implements early stopping based on validation loss."""
    def __init__(self, patience: int = 10, min_delta: float = 0.001, restore_best_weights: bool = True):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.restore_best_weights = restore_best_weights
        self.best_weights = None

    def __call__(self, val_loss: float, model: Optional[nn.Module] = None):
        if self.best_loss is None:
            self.best_loss = val_loss
            if model and self.restore_best_weights:
                self.best_weights = {k: v.clone() for k, v in model.state_dict().items()}
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if model and self.restore_best_weights:
                self.best_weights = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if model and self.restore_best_weights and self.best_weights:
                    model.load_state_dict(self.best_weights)

# --- Dataset Class ---
class NMTDataset(Dataset):
    """Custom Dataset for Machine Translation."""
    def __init__(self, pairs: List[Tuple], src_vocab: Dict, trg_vocab: Dict):
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.data = []
        
        for src_tokens, trg_tokens in pairs:
            src_ids = to_ids(src_tokens, src_vocab)
            trg_ids = to_ids(trg_tokens, trg_vocab)
            self.data.append((src_ids, trg_ids))

    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        src_ids, trg_ids = self.data[idx]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(trg_ids, dtype=torch.long)

# --- Improved Positional Encoding ---
class PositionalEncoding(nn.Module):
    """Injects positional information into embeddings."""
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

# --- Improved Transformer Model ---
class TransformerSeq2Seq(nn.Module):
    """A standard Transformer-based sequence-to-sequence model."""
    def __init__(self, src_vocab_size: int, trg_vocab_size: int, d_model: int = 512, nhead: int = 8,
                 num_encoder_layers: int = 6, num_decoder_layers: int = 6, dim_feedforward: int = 2048,
                 dropout: float = 0.1, pad_token_id: int = PAD, device: str = 'cpu'):
        super().__init__()
        self.d_model = d_model
        self.device = device
        self.pad_token_id = pad_token_id
        
        self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=self.pad_token_id)
        self.trg_embedding = nn.Embedding(trg_vocab_size, d_model, padding_idx=self.pad_token_id)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        self.pos_decoder = PositionalEncoding(d_model, dropout=dropout)
        
        encoder_layers = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout,
            batch_first=True, norm_first=True
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)
        
        decoder_layers = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout,
            batch_first=True, norm_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_decoder_layers)
        
        self.generator = nn.Linear(d_model, trg_vocab_size)
        self._initialize_parameters()
    
    def _initialize_parameters(self):
        """Xavier initialization for model parameters."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p, gain=1.0)
    
    def _get_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Generates a source padding mask."""
        return (src == self.pad_token_id)

    def _get_trg_mask(self, trg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generates a target subsequence mask and padding mask."""
        trg_len = trg.size(1)
        trg_sub_mask = torch.triu(torch.ones(trg_len, trg_len, device=self.device) * float('-inf'), diagonal=1)
        trg_pad_mask = (trg == self.pad_token_id)
        return trg_sub_mask, trg_pad_mask

    def forward(self, src: torch.Tensor, trg: torch.Tensor) -> Tuple[torch.Tensor, None]:
        src_mask = self._get_src_mask(src)
        trg_sub_mask, trg_pad_mask = self._get_trg_mask(trg)
        
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        trg_emb = self.trg_embedding(trg) * math.sqrt(self.d_model)
        
        src_emb = self.pos_encoder(src_emb)
        trg_emb = self.pos_decoder(trg_emb)
        
        encoder_outputs = self.transformer_encoder(src_emb, src_key_padding_mask=src_mask)
        decoder_outputs = self.transformer_decoder(
            trg_emb, encoder_outputs,
            tgt_mask=trg_sub_mask,
            tgt_key_padding_mask=trg_pad_mask,
            memory_key_padding_mask=src_mask
        )
        
        outputs = self.generator(decoder_outputs)
        return outputs, None

    def greedy_decode(self, src: torch.Tensor, max_len: int = 40) -> torch.Tensor:
        """Performs greedy decoding to generate output sequences."""
        self.eval()
        batch_size = src.size(0)
        
        with torch.no_grad():
            src_mask = self._get_src_mask(src)
            src_emb = self.pos_encoder(self.src_embedding(src) * math.sqrt(self.d_model))
            encoder_outputs = self.transformer_encoder(src_emb, src_key_padding_mask=src_mask)
            
            ys = torch.ones(batch_size, 1, dtype=torch.long, device=self.device).fill_(BOS)
            
            for _ in range(max_len - 1):
                trg_sub_mask, trg_pad_mask = self._get_trg_mask(ys)
                trg_emb = self.pos_decoder(self.trg_embedding(ys) * math.sqrt(self.d_model))
                
                decoder_outputs = self.transformer_decoder(
                    trg_emb, encoder_outputs,
                    tgt_mask=trg_sub_mask,
                    tgt_key_padding_mask=trg_pad_mask,
                    memory_key_padding_mask=src_mask
                )
                
                outputs = self.generator(decoder_outputs)
                next_word_id = outputs.argmax(dim=-1)[:, -1:]
                ys = torch.cat([ys, next_word_id], dim=1)
                
                if (next_word_id == EOS).all():
                    break
        return ys

    def beam_search_decode(self, src: torch.Tensor, max_len: int = 40, beam_size: int = 3) -> List[List[int]]:
        """Performs beam search decoding for a single batch example."""
        self.eval()
        if src.size(0) != 1:
            # Handle batch_size > 1 by decoding each example sequentially
            return [self.beam_search_decode(s.unsqueeze(0), max_len, beam_size)[0] for s in src]

        with torch.no_grad():
            src_mask = self._get_src_mask(src)
            src_emb = self.pos_encoder(self.src_embedding(src) * math.sqrt(self.d_model))
            encoder_outputs = self.transformer_encoder(src_emb, src_key_padding_mask=src_mask)

            # Initialize beams with BOS token
            beams = [([BOS], 0.0)]  # (sequence, score)
            completed_beams = []

            for _ in range(max_len):
                candidates = []
                for seq, score in beams:
                    if seq[-1] == EOS:
                        completed_beams.append((seq, score / len(seq)))
                        continue

                    ys = torch.tensor([seq], dtype=torch.long, device=self.device)
                    trg_sub_mask, trg_pad_mask = self._get_trg_mask(ys)
                    trg_emb = self.pos_decoder(self.trg_embedding(ys) * math.sqrt(self.d_model))
                    
                    decoder_outputs = self.transformer_decoder(
                        trg_emb, encoder_outputs,
                        tgt_mask=trg_sub_mask,
                        tgt_key_padding_mask=trg_pad_mask,
                        memory_key_padding_mask=src_mask
                    )

                    outputs = self.generator(decoder_outputs)
                    log_probs = F.log_softmax(outputs[0, -1], dim=-1)
                    
                    top_probs, top_indices = log_probs.topk(beam_size)
                    
                    for prob, idx in zip(top_probs, top_indices):
                        new_seq = seq + [idx.item()]
                        new_score = score + prob.item()
                        candidates.append((new_seq, new_score))
                
                if not candidates:
                    break

                candidates.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
                beams = candidates[:beam_size]
                
                if not beams:
                    break
            
            final_results = sorted(completed_beams + [(seq, score / len(seq)) for seq, score in beams], 
                                    key=lambda x: x[1], reverse=True)

            if not final_results:
                return [[BOS, EOS]]
            
            return [res[0] for res in final_results]


# --- Improved Loss Function ---
class LabelSmoothingLoss(nn.Module):
    """Implements label smoothing to regularize the model."""
    def __init__(self, smoothing: float = 0.1, ignore_index: int = PAD):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.criterion = nn.KLDivLoss(reduction='batchmean')
        
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred = pred.log_softmax(dim=-1)
        
        with torch.no_grad():
            # Create a uniform distribution
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (pred.size(-1) - 1))
            
            # Put 1.0 - smoothing at the true target index
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
            
            # Mask out the PAD token contribution
            true_dist[:, self.ignore_index] = 0
            
            # Handle padding in target
            pad_mask = (target == self.ignore_index).unsqueeze(1)
            true_dist.masked_fill_(pad_mask, 0)
        
        return self.criterion(pred, true_dist)


# --- Improved Learning Rate Scheduler ---
class WarmupScheduler:
    """A standard Transformer learning rate scheduler."""
    def __init__(self, optimizer: torch.optim.Optimizer, d_model: int, warmup_steps: int = 4000, factor: float = 1.0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.get_learning_rate()
        for p in self.optimizer.param_groups:
            p['lr'] = lr
    
    def get_learning_rate(self) -> float:
        return self.factor * (self.d_model ** (-0.5) *
                              min(self.step_num ** (-0.5), self.step_num * (self.warmup_steps ** (-1.5))))

def epoch_run(model, loader: DataLoader, criterion: nn.Module, optimizer: Optional[torch.optim.Optimizer],
              train: bool = True, model_type: str = 'transformer', scheduler: Optional[WarmupScheduler] = None) -> Tuple[float, float]:
    """Runs a single training or validation epoch."""
    model.train() if train else model.eval()
    total_loss, total_tokens = 0.0, 0
    device = next(model.parameters()).device
    
    progress_bar = tqdm(loader, desc=f"{'Training' if train else 'Validation'}")
    
    with torch.set_grad_enabled(train):
        for src, trg in progress_bar:
            src = src.to(device)
            trg = trg.to(device)
            
            if model_type == 'transformer':
                # Shift target for training
                outputs, _ = model(src, trg[:, :-1])
                # Reshape for loss calculation
                outputs_for_loss = outputs.reshape(-1, outputs.size(-1))
                target = trg[:, 1:].contiguous().reshape(-1)
                
                loss = criterion(outputs_for_loss, target)
            else:
                raise ValueError(f"Unsupported model type: {model_type}")
            
            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                optimizer.step()
                if scheduler:
                    scheduler.step()
            
            n_tokens = (target != PAD).sum().item()
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens
            
            progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / total_tokens if total_tokens > 0 else 0.0
    ppl = math.exp(avg_loss) if avg_loss < 100 else float('inf')
    return avg_loss, ppl

def main():
    parser = argparse.ArgumentParser(description='Train a Transformer-based NMT model.')
    parser.add_argument('--data_path', type=str, default='/kaggle/input/translate3/ind-eng/ind.txt', help='Path to txt data')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--d_model', type=int, default=512, help='Model dimension')
    parser.add_argument('--nhead', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--num_enc_layers', type=int, default=6, help='Number of encoder layers')
    parser.add_argument('--num_dec_layers', type=int, default=6, help='Number of decoder layers')
    parser.add_argument('--dim_feedforward', type=int, default=2048, help='Feedforward network dimension')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--max_vocab', type=int, default=25000, help='Maximum vocabulary size')
    parser.add_argument('--checkpoint', type=str, default='best_model.pt', help='Path to save model checkpoint')
    parser.add_argument('--beam_size', type=int, default=1, help='Beam size for evaluation')
    
    args, unknown = parser.parse_known_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on: {device}")

    # Load and preprocess data
    data_file = Path(args.data_path)
    if not data_file.exists():
        print(f"Error: Data file not found at {data_file}. Please check the path.")
        return
        
    print("Loading and preprocessing data...")
    pairs = load_pairs(data_file, max_len=50)
    print(f"Loaded {len(pairs)} sentence pairs")
    
    if len(pairs) == 0:
        print("Error: No valid sentence pairs found!")
        return
    
    train_pairs, val_pairs, test_pairs = split_pairs(pairs, 0.8, 0.1)
    print(f"Split: Train={len(train_pairs)}, Val={len(val_pairs)}, Test={len(test_pairs)}")

    # Build vocabularies
    print("Building vocabularies...")
    src_tokens = [src for src, _ in train_pairs]
    trg_tokens = [trg for _, trg in train_pairs]
    
    src_vocab, src_itos = build_vocab(src_tokens, min_freq=2, max_size=args.max_vocab)
    trg_vocab, trg_itos = build_vocab(trg_tokens, min_freq=2, max_size=args.max_vocab)

    print(f"Source vocab size: {len(src_vocab)} | Target vocab size: {len(trg_vocab)}")

    # Create datasets
    train_ds = NMTDataset(train_pairs, src_vocab, trg_vocab)
    val_ds = NMTDataset(val_pairs, src_vocab, trg_vocab)
    test_ds = NMTDataset(test_pairs, src_vocab, trg_vocab)
    
    print(f"Dataset sizes: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}")

    # Create data loaders
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_batch)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_batch)

    # Create model
    print("Creating Transformer model...")
    model = TransformerSeq2Seq(
        len(src_vocab), len(trg_vocab),
        args.d_model, args.nhead,
        args.num_enc_layers, args.num_dec_layers,
        args.dim_feedforward, args.dropout,
        PAD, device
    ).to(device)
    
    # Loss and optimizer
    criterion = LabelSmoothingLoss(smoothing=0.1, ignore_index=PAD)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
    
    # Learning rate scheduler
    scheduler_warmup = WarmupScheduler(optimizer, args.d_model, warmup_steps=4000, factor=0.5)
    
    # Early stopping
    early_stopping = EarlyStopping(patience=10, min_delta=0.001, restore_best_weights=True)

    # Training history
    best_val_bleu = -1.0
    history = {"train_loss": [], "val_loss": [], "train_ppl": [], "val_ppl": [], "val_bleu": [], "val_chrf": []}

    print(f"Starting training for {args.epochs} epochs...")
    
    for epoch in range(1, args.epochs + 1):
        # Training
        train_loss, train_ppl = epoch_run(
            model, train_loader, criterion, optimizer,
            train=True, model_type='transformer', scheduler=scheduler_warmup
        )
        
        # Validation
        val_loss, val_ppl = epoch_run(
            model, val_loader, criterion, None,
            train=False, model_type='transformer'
        )
        
        # Evaluation
        val_bleu, val_chrf = evaluate_sacrebleu(model, val_loader, trg_itos=trg_itos, beam_size=args.beam_size)
        
        # Update history
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_ppl"].append(train_ppl)
        history["val_ppl"].append(val_ppl)
        history["val_bleu"].append(val_bleu)
        history["val_chrf"].append(val_chrf)

        print(f"Epoch {epoch:02d} | Train Loss {train_loss:.4f} PPL {train_ppl:.2f} | "
              f"Val Loss {val_loss:.4f} PPL {val_ppl:.2f} | "
              f"Val BLEU {val_bleu:.2f} | Val chrF {val_chrf:.2f}")

        # Save best model
        if val_bleu > best_val_bleu:
            best_val_bleu = val_bleu
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_bleu': val_bleu,
                'val_loss': val_loss,
                'src_vocab': src_vocab,
                'trg_vocab': trg_vocab,
                'src_itos': src_itos,
                'trg_itos': trg_itos,
                'args': args
            }, args.checkpoint)
            print(f"✓ Saved best model with BLEU {val_bleu:.2f}")
        
        # Early stopping
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    # Final evaluation
    print("\n--- Loading best model for final evaluation ---")
    try:
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {checkpoint.get('epoch', 'unknown')}")
    except Exception as e:
        print(f"Could not load checkpoint: {e}")
        print("Using current model state for evaluation")
    
    test_loss, test_ppl = epoch_run(model, test_loader, criterion, None, train=False, model_type='transformer')
    test_bleu, test_chrf = evaluate_sacrebleu(model, test_loader, trg_itos=trg_itos, beam_size=args.beam_size)
    
    # Save results
    results_file = f"final_results_transformer.csv"
    with open(results_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Model", "Test_Loss", "Test_PPL", "Test_BLEU", "Test_chrF", "Best_Val_BLEU"])
        writer.writerow(["transformer", test_loss, test_ppl, test_bleu, test_chrf, best_val_bleu])
    
    print(f"\nFINAL TEST RESULTS:")
    print(f"Loss: {test_loss:.4f} | PPL: {test_ppl:.2f}")
    print(f"SacreBLEU: {test_bleu:.2f} | chrF: {test_chrf:.2f}")
    print(f"Best Val BLEU: {best_val_bleu:.2f}")
    print(f"Results saved to: {results_file}")
    
    # Show translation examples
    print_translations_and_analyze(model, test_loader, src_itos, trg_itos, num_examples=10)

if __name__ == "__main__":
    main()

Running on: cuda
Loading and preprocessing data...
Loaded 13518 sentence pairs
Split: Train=10814, Val=1351, Test=1353
Building vocabularies...
Source vocab size: 3457 | Target vocab size: 3530
Dataset sizes: Train=10814, Val=1351, Test=1353
Creating Transformer model...
Starting training for 10 epochs...


Training: 100%|██████████| 338/338 [00:16<00:00, 20.49it/s, loss=2.88]
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.68it/s, loss=2.43]
Evaluating: 100%|██████████| 43/43 [00:08<00:00,  4.93it/s]


Epoch 01 | Train Loss 3.2186 PPL 24.99 | Val Loss 2.2017 PPL 9.04 | Val BLEU 48.89 | Val chrF 56.65
✓ Saved best model with BLEU 48.89


Training: 100%|██████████| 338/338 [00:16<00:00, 20.34it/s, loss=1.97] 
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.15it/s, loss=2.09]
Evaluating: 100%|██████████| 43/43 [00:06<00:00,  7.02it/s]


Epoch 02 | Train Loss 2.2449 PPL 9.44 | Val Loss 1.8421 PPL 6.31 | Val BLEU 2.82 | Val chrF 15.87


Training: 100%|██████████| 338/338 [00:16<00:00, 20.53it/s, loss=1.99] 
Validation: 100%|██████████| 43/43 [00:00<00:00, 64.79it/s, loss=1.7] 
Evaluating: 100%|██████████| 43/43 [00:07<00:00,  5.76it/s]


Epoch 03 | Train Loss 1.8868 PPL 6.60 | Val Loss 1.5945 PPL 4.93 | Val BLEU 3.07 | Val chrF 18.04


Training: 100%|██████████| 338/338 [00:16<00:00, 20.51it/s, loss=1.55] 
Validation: 100%|██████████| 43/43 [00:00<00:00, 64.61it/s, loss=1.62]
Evaluating: 100%|██████████| 43/43 [00:08<00:00,  5.06it/s]


Epoch 04 | Train Loss 1.5938 PPL 4.92 | Val Loss 1.4207 PPL 4.14 | Val BLEU 24.45 | Val chrF 47.40


Training: 100%|██████████| 338/338 [00:16<00:00, 20.56it/s, loss=0.698]
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.77it/s, loss=1.29] 
Evaluating: 100%|██████████| 43/43 [00:07<00:00,  5.49it/s]


Epoch 05 | Train Loss 1.3396 PPL 3.82 | Val Loss 1.3427 PPL 3.83 | Val BLEU 56.55 | Val chrF 46.05
✓ Saved best model with BLEU 56.55


Training: 100%|██████████| 338/338 [00:16<00:00, 20.26it/s, loss=1.44] 
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.93it/s, loss=1.23] 
Evaluating: 100%|██████████| 43/43 [00:04<00:00,  8.80it/s]


Epoch 06 | Train Loss 1.1333 PPL 3.11 | Val Loss 1.2549 PPL 3.51 | Val BLEU 37.99 | Val chrF 73.04


Training: 100%|██████████| 338/338 [00:16<00:00, 20.50it/s, loss=1.22] 
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.76it/s, loss=1.18] 
Evaluating: 100%|██████████| 43/43 [00:08<00:00,  5.15it/s]


Epoch 07 | Train Loss 0.9668 PPL 2.63 | Val Loss 1.2058 PPL 3.34 | Val BLEU 29.50 | Val chrF 48.90


Training: 100%|██████████| 338/338 [00:16<00:00, 20.52it/s, loss=0.587]
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.51it/s, loss=1.09] 
Evaluating: 100%|██████████| 43/43 [00:08<00:00,  5.12it/s]


Epoch 08 | Train Loss 0.8420 PPL 2.32 | Val Loss 1.1583 PPL 3.18 | Val BLEU 64.59 | Val chrF 65.33
✓ Saved best model with BLEU 64.59


Training: 100%|██████████| 338/338 [00:16<00:00, 20.41it/s, loss=0.861]
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.07it/s, loss=1.1]  
Evaluating: 100%|██████████| 43/43 [00:07<00:00,  5.71it/s]


Epoch 09 | Train Loss 0.7434 PPL 2.10 | Val Loss 1.1630 PPL 3.20 | Val BLEU 87.62 | Val chrF 69.71
✓ Saved best model with BLEU 87.62


Training: 100%|██████████| 338/338 [00:16<00:00, 20.47it/s, loss=0.915]
Validation: 100%|██████████| 43/43 [00:00<00:00, 65.00it/s, loss=1.21] 
Evaluating: 100%|██████████| 43/43 [00:05<00:00,  7.71it/s]


Epoch 10 | Train Loss 0.6649 PPL 1.94 | Val Loss 1.1610 PPL 3.19 | Val BLEU 33.26 | Val chrF 55.58

--- Loading best model for final evaluation ---
Could not load checkpoint: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL argparse.Namespace was not an allowed global by default. Please use `torch.serialization.add_safe_globals([Namespace])` or the `torch.serialization.safe_globals([Namespace])` context manager 

Validation: 100%|██████████| 43/43 [00:00<00:00, 65.71it/s, loss=1.29] 
Evaluating: 100%|██████████| 43/43 [00:04<00:00,  8.84it/s]



FINAL TEST RESULTS:
Loss: 1.1928 | PPL: 3.30
SacreBLEU: 52.43 | chrF: 49.03
Best Val BLEU: 87.62
Results saved to: final_results_transformer.csv

--- Translation Examples & Error Analysis ---
--------------------------------------------------
Source: if you turn here, you can probably avoid a lot of traffic.
Reference: kalau kamu belok ke sini, mungkin kamu bisa terhindar dari kemacetan.
Greedy: kalau kamu <unk> bisa <unk> mungkin dia masih melakukan ini?
Beam Search (k=3): kalau kau bisa <unk> di luar biasa.
--------------------------------------------------
Source: she was in the hospital for six weeks because of her <unk>
Reference: dia berada di rumah sakit selama enam minggu karena sakitnya
Greedy: dia berada di ruangan itu karena enam belas tahun.
Beam Search (k=3): dia berada di ruangan itu karena enam belas tahun.
--------------------------------------------------
Source: we learned about the <unk> of eating a healthy lunch.
Reference: kami mempelajari tentang <unk> memakan ma