In [1]:
!pip install pyvi

Collecting pyvi
  Downloading pyvi-0.1.1-py2.py3-none-any.whl.metadata (2.5 kB)
Collecting sklearn-crfsuite (from pyvi)
  Downloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl.metadata (4.9 kB)
Collecting python-crfsuite>=0.9.7 (from sklearn-crfsuite->pyvi)
  Downloading python_crfsuite-0.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Downloading pyvi-0.1.1-py2.py3-none-any.whl (8.5 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.5/8.5 MB[0m [31m76.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl (10 kB)
Downloading python_crfsuite-0.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m60.7 MB/s[0m eta [36m0:00:00

In [2]:
import math
import random
from collections import Counter, defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# -------------------------
# Seed (deterministic for reproducibility)
# -------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# ==============================
# 2Ô∏è‚É£ LOAD D·ªÆ LI·ªÜU
# ==============================
train_en = open("/kaggle/input/en-vi-ds/data/train.en", "r", encoding="utf-8").read().splitlines()
train_vi = open("/kaggle/input/en-vi-ds/data/train.vi", "r", encoding="utf-8").read().splitlines()
test_en  = open("/kaggle/input/en-vi-ds/data/tst2013.en", "r", encoding="utf-8").read().splitlines()
test_vi  = open("/kaggle/input/en-vi-ds/data/tst2013.vi", "r", encoding="utf-8").read().splitlines()

print("Train:", len(train_en), "Test:", len(test_en))


# ==============================
# 3Ô∏è‚É£ BPE TOKENIZER
# ==============================
class BPETokenizer:
    def __init__(self, texts, vocab_size=5000, min_freq=2, max_samples=50000):
        """
        texts: list of sentences
        vocab_size: target vocabulary size
        min_freq: minimum word frequency
        max_samples: limit number of sentences for faster training
        """
        print(f"Initializing BPE Tokenizer (vocab_size={vocab_size})...")
        self.vocab_size = vocab_size
        self.word2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.bpe_codes = {}
        
        # Limit data for faster BPE training
        if len(texts) > max_samples:
            print(f"Using {max_samples}/{len(texts)} samples for BPE training")
            texts = random.sample(texts, max_samples)
        
        self.build_bpe(texts, min_freq)
    
    def get_stats(self, vocab):
        """Count frequency of adjacent symbol pairs"""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i + 1]] += freq
        return pairs
    
    def merge_vocab(self, pair, vocab):
        """Merge the most frequent pair in vocabulary"""
        new_vocab = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        
        for word in vocab:
            new_word = word.replace(bigram, replacement)
            new_vocab[new_word] = vocab[word]
        return new_vocab
    
    def build_bpe(self, texts, min_freq):
        """Build BPE vocabulary - optimized version"""
        print("Step 1: Counting word frequencies...")
        # Count word frequencies
        word_freq = Counter()
        for i, line in enumerate(texts):
            if i % 10000 == 0 and i > 0:
                print(f"  Processed {i}/{len(texts)} lines")
            words = line.strip().lower().split()
            word_freq.update(words)
        
        print(f"Step 2: Found {len(word_freq)} unique words")
        
        # Filter by min_freq and prepare vocab
        vocab = {}
        for word, freq in word_freq.items():
            if freq >= min_freq:
                vocab[' '.join(list(word)) + ' </w>'] = freq
        
        print(f"Step 3: After filtering (min_freq={min_freq}): {len(vocab)} words")
        
        # Learn BPE merges
        num_merges = min(self.vocab_size - len(self.word2idx), 3000)  # Limit merges
        print(f"Step 4: Learning {num_merges} BPE merges...")
        
        for i in range(num_merges):
            if i % 100 == 0:
                print(f"  BPE merge {i}/{num_merges}")
            
            pairs = self.get_stats(vocab)
            if not pairs:
                print(f"  No more pairs to merge at iteration {i}")
                break
            
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.bpe_codes[best] = i
        
        # Build final vocabulary
        print("Step 5: Building final vocabulary...")
        for word in vocab.keys():
            for token in word.split():
                if token not in self.word2idx:
                    idx = len(self.word2idx)
                    self.word2idx[token] = idx
                    self.idx2word[idx] = token
        
        print(f"‚úì BPE Tokenizer ready! Vocabulary size: {len(self.word2idx)}")
        print()
    
    def apply_bpe(self, word):
        """Apply BPE codes to a word"""
        word = ' '.join(list(word)) + ' </w>'
        
        while True:
            pairs = [(word[i:i+2], i) for i in range(len(word.split())-1)]
            if not pairs:
                break
            
            # Find the pair with lowest merge order
            bigrams = [(' '.join([word.split()[i], word.split()[i+1]]), i) 
                      for i in range(len(word.split())-1)]
            
            valid_bigrams = [(self.bpe_codes.get(tuple(bg.split())), bg, pos) 
                           for bg, pos in bigrams 
                           if tuple(bg.split()) in self.bpe_codes]
            
            if not valid_bigrams:
                break
            
            # Merge the pair with lowest index (learned earliest)
            _, bigram, pos = min(valid_bigrams)
            word_list = word.split()
            word_list[pos] = ''.join(bigram.split())
            del word_list[pos + 1]
            word = ' '.join(word_list)
        
        return word.split()
    
    def encode(self, text):
        """Encode text to token IDs"""
        tokens = []
        for word in text.lower().split():
            bpe_tokens = self.apply_bpe(word)
            for token in bpe_tokens:
                tokens.append(self.word2idx.get(token, 3))
        return tokens
    
    def decode(self, ids):
        """Decode token IDs to text"""
        words = []
        current_word = ""
        
        for i in ids:
            if i == 2:  # eos
                break
            if i > 3:
                token = self.idx2word.get(i, "<unk>")
                if token.endswith('</w>'):
                    current_word += token[:-4]
                    words.append(current_word)
                    current_word = ""
                else:
                    current_word += token
        
        if current_word:
            words.append(current_word)
        
        return " ".join(words)


print("Building source BPE tokenizer...")
tok_src = BPETokenizer(train_en, vocab_size=5000, min_freq=2, max_samples=50000)

print("Building target BPE tokenizer...")
tok_trg = BPETokenizer(train_vi, vocab_size=5000, min_freq=2, max_samples=50000)


# ==============================
# 4Ô∏è‚É£ DATA AUGMENTATION
# ==============================
class DataAugmentation:
    @staticmethod
    def random_swap(text, n=1):
        """Randomly swap n words in the text"""
        words = text.split()
        if len(words) < 2:
            return text
        
        for _ in range(n):
            idx1, idx2 = random.sample(range(len(words)), 2)
            words[idx1], words[idx2] = words[idx2], words[idx1]
        
        return ' '.join(words)
    
    @staticmethod
    def random_deletion(text, p=0.1):
        """Randomly delete words with probability p"""
        words = text.split()
        if len(words) == 1:
            return text
        
        new_words = [word for word in words if random.random() > p]
        
        if len(new_words) == 0:
            return random.choice(words)
        
        return ' '.join(new_words)
    
    @staticmethod
    def augment(text, method='swap'):
        """Apply augmentation method"""
        if method == 'swap':
            return DataAugmentation.random_swap(text, n=1)
        elif method == 'delete':
            return DataAugmentation.random_deletion(text, p=0.1)
        else:
            return text


# ==============================
# 5Ô∏è‚É£ DATASET + collate (with augmentation)
# ==============================
class TranslationDataset(Dataset):
    def __init__(self, src, trg, tok_src, tok_trg, augment=False):
        self.src = src
        self.trg = trg
        self.tok_src = tok_src
        self.tok_trg = tok_trg
        self.augment = augment
        self.aug = DataAugmentation()

    def __len__(self): 
        return len(self.src)

    def __getitem__(self, idx):
        src_text = self.src[idx]
        trg_text = self.trg[idx]
        
        # Apply augmentation with 30% probability during training
        if self.augment and random.random() < 0.3:
            method = random.choice(['swap', 'delete'])
            src_text = self.aug.augment(src_text, method)
            trg_text = self.aug.augment(trg_text, method)
        
        s = [1] + self.tok_src.encode(src_text) + [2]
        t = [1] + self.tok_trg.encode(trg_text) + [2]
        return torch.tensor(s), torch.tensor(t)

def collate_fn(batch, pad_idx=0):
    src, trg = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0)
    trg = nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
    return src, trg


dataset = TranslationDataset(train_en, train_vi, tok_src, tok_trg, augment=True)
print(f"Total dataset size: {len(dataset)}")

train_len = int(0.9 * len(dataset))
val_len = len(dataset) - train_len
train_set, val_set = random_split(dataset, [train_len, val_len])
print(f"Train: {train_len}, Val: {val_len}")

# Disable augmentation for validation
val_set.dataset.augment = False

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_set, batch_size=32, shuffle=False, collate_fn=collate_fn)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
print()


# ==============================
# 6Ô∏è‚É£ POSITIONAL ENCODING
# ==============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


# ==============================
# 7Ô∏è‚É£ LABEL SMOOTHING LOSS
# ==============================
class LabelSmoothingLoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1, ignore_index=-100):
        super().__init__()
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
    
    def forward(self, pred, target):
        """
        pred: (batch_size * seq_len, num_classes)
        target: (batch_size * seq_len)
        """
        pred = pred.log_softmax(dim=-1)
        
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.num_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist[:, self.ignore_index] = 0
            
            mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
            if mask.dim() > 0 and mask.size(0) > 0:
                true_dist.index_fill_(0, mask.squeeze(), 0.0)
        
        return torch.mean(torch.sum(-true_dist * pred, dim=-1))


# ==============================
# 8Ô∏è‚É£ TRANSFORMER MODEL
# ==============================
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model=256, nhead=4, num_layers=3, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx

        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True
        )

        self.fc = nn.Linear(d_model, trg_vocab)
        self.fc.weight = self.trg_emb.weight   # weight tying

    def forward(self, src, trg):
        device = src.device
        src_mask = (src == self.pad_idx)
        trg_mask = (trg == self.pad_idx)

        seq_len = trg.size(1)
        subsequent_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)

        src_emb = self.pos(self.src_emb(src))
        trg_emb = self.pos(self.trg_emb(trg))

        out = self.transformer(
            src_emb, trg_emb,
            tgt_mask=subsequent_mask,
            src_key_padding_mask=src_mask,
            tgt_key_padding_mask=trg_mask,
            memory_key_padding_mask=src_mask
        )
        return self.fc(out)


# ==============================
# 9Ô∏è‚É£ EARLY STOPPING
# ==============================
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001, mode='min'):
        """
        patience: number of epochs to wait before stopping
        min_delta: minimum change to qualify as improvement
        mode: 'min' for loss, 'max' for accuracy
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        
    def __call__(self, val_metric):
        score = -val_metric if self.mode == 'min' else val_metric
        
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            print(f"     ‚ö† EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0


# ==============================
# üîü TRAINING (with AdamW + Scheduler + Early Stopping)
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=20, lr=3e-4, pad_idx=0, 
                patience=5, warmup_epochs=2):
    model.to(device)
    print(f"Model moved to {device}")
    
    # AdamW optimizer with weight decay
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.98))
    
    # Multi-step learning rate scheduler with warmup
    warmup_steps = warmup_epochs * len(train_loader)
    total_steps = epochs * len(train_loader)
    
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        # Cosine decay after warmup
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    
    # Label smoothing loss
    loss_fn = LabelSmoothingLoss(
        num_classes=len(tok_trg.word2idx), 
        smoothing=0.1, 
        ignore_index=pad_idx
    )
    
    # Early stopping
    early_stopping = EarlyStopping(patience=patience, min_delta=0.001, mode='min')
    
    print(f"Optimizer: AdamW (lr={lr}, weight_decay=0.01)")
    print(f"Scheduler: Warmup + Cosine Decay (warmup={warmup_epochs} epochs)")
    print(f"Loss: LabelSmoothing(0.1)")
    print(f"Early Stopping: patience={patience}")
    print()

    best_val = float('inf')
    train_losses = []
    val_losses = []

    for ep in range(1, epochs+1):
        print(f"{'='*60}")
        print(f"Epoch {ep}/{epochs}")
        print(f"{'='*60}")
        
        model.train()
        total_loss = 0
        batch_count = 0

        for batch_idx, (src, trg) in enumerate(train_loader):
            src, trg = src.to(device), trg.to(device)
            opt.zero_grad()

            out = model(src, trg[:, :-1])
            loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            scheduler.step()  # Update learning rate every batch
            
            total_loss += loss.item()
            batch_count += 1
            
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / batch_count
                current_lr = opt.param_groups[0]['lr']
                print(f"  Batch {batch_idx+1}/{len(train_loader)} | Loss: {avg_loss:.4f} | LR: {current_lr:.6f}")

        # validation
        print(f"\n  Running validation...")
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for src, trg in val_loader:
                src, trg = src.to(device), trg.to(device)
                out = model(src, trg[:, :-1])
                loss = loss_fn(out.reshape(-1, out.size(-1)), trg[:, 1:].reshape(-1))
                val_loss += loss.item()

        avg_train = total_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)
        current_lr = opt.param_groups[0]['lr']
        
        train_losses.append(avg_train)
        val_losses.append(avg_val)

        print(f"\n  üìä Epoch {ep} Summary:")
        print(f"     Train Loss: {avg_train:.4f}")
        print(f"     Val Loss:   {avg_val:.4f}")
        print(f"     LR:         {current_lr:.6f}")

        if avg_val < best_val:
            best_val = avg_val
            torch.save(model.state_dict(), "best_model.pt")
            print(f"     ‚úî New best model saved! (Val Loss: {best_val:.4f})")
        
        # Early stopping check
        early_stopping(avg_val)
        if early_stopping.early_stop:
            print(f"\nüõë Early stopping triggered at epoch {ep}!")
            print(f"   Best Val Loss: {best_val:.4f}")
            break
        
        print()
    
    # Plot training history
    print("\nüìà Training History:")
    print(f"Best Val Loss: {best_val:.4f} at epoch {val_losses.index(min(val_losses)) + 1}")
    
    return train_losses, val_losses


# ==============================
# üîü BEAM SEARCH DECODING
# ==============================
def beam_search_decode(model, src, tok_trg, device, beam_size=5, max_len=60):
    model.eval()

    sos = 1
    eos = 2

    memory_src = src.to(device)

    # Encode
    with torch.no_grad():
        src_mask = (memory_src == 0)
        src_emb = model.pos(model.src_emb(memory_src))
        memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_mask)

    sequences = [[sos]]
    scores = [0]

    for _ in range(max_len):
        all_candidates = []

        for i in range(len(sequences)):
            seq = sequences[i]
            score = scores[i]

            if seq[-1] == eos:
                all_candidates.append((seq, score))
                continue

            trg = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(device)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(trg.size(1)).to(device)

            with torch.no_grad():
                trg_emb = model.pos(model.trg_emb(trg))
                out = model.transformer.decoder(
                    trg_emb, memory,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_mask
                )
                logits = model.fc(out[:, -1])  # last token
                log_probs = torch.log_softmax(logits, dim=-1)

            topk = torch.topk(log_probs, beam_size)
            next_tokens = topk.indices[0]
            next_scores = topk.values[0]

            for k in range(beam_size):
                candidate = seq + [next_tokens[k].item()]
                candidate_score = score + next_scores[k].item()
                all_candidates.append((candidate, candidate_score))

        ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)

        sequences = []
        scores = []
        for i in range(beam_size):
            sequences.append(ordered[i][0])
            scores.append(ordered[i][1])

    best_seq = sequences[0]
    return tok_trg.decode(best_seq[1:])


# ==============================
# 1Ô∏è‚É£1Ô∏è‚É£ TRANSLATE (using BEAM SEARCH)
# ==============================
def translate(model, text, tok_src, tok_trg, device, beam_size=5):
    src = [1] + tok_src.encode(text) + [2]
    src = torch.tensor(src, dtype=torch.long).unsqueeze(0)
    return beam_search_decode(model, src, tok_trg, device, beam_size=beam_size)


# ==============================
# 1Ô∏è‚É£2Ô∏è‚É£ EVALUATE BLEU
# ==============================
def evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50):
    model.eval()
    smooth = SmoothingFunction().method1

    total_bleu = 0
    n = min(n, len(test_en))

    for i in range(n):
        pred = translate(model, test_en[i], tok_src, tok_trg, device, beam_size=5)
        bleu = sentence_bleu([test_vi[i].split()], pred.split(), smoothing_function=smooth)
        total_bleu += bleu

        if i < 10:
            print("\nEN:", test_en[i])
            print("GT:", test_vi[i])
            print("PR:", pred)
            print("BLEU:", bleu)

    print("\nAVERAGE BLEU =", total_bleu / n)




Train: 133317 Test: 1268
Building source BPE tokenizer...
Initializing BPE Tokenizer (vocab_size=5000)...
Using 50000/133317 samples for BPE training
Step 1: Counting word frequencies...
  Processed 10000/50000 lines
  Processed 20000/50000 lines
  Processed 30000/50000 lines
  Processed 40000/50000 lines
Step 2: Found 31792 unique words
Step 3: After filtering (min_freq=2): 17929 words
Step 4: Learning 3000 BPE merges...
  BPE merge 0/3000
  BPE merge 100/3000
  BPE merge 200/3000
  BPE merge 300/3000
  BPE merge 400/3000
  BPE merge 500/3000
  BPE merge 600/3000
  BPE merge 700/3000
  BPE merge 800/3000
  BPE merge 900/3000
  BPE merge 1000/3000
  BPE merge 1100/3000
  BPE merge 1200/3000
  BPE merge 1300/3000
  BPE merge 1400/3000
  BPE merge 1500/3000
  BPE merge 1600/3000
  BPE merge 1700/3000
  BPE merge 1800/3000
  BPE merge 1900/3000
  BPE merge 2000/3000
  BPE merge 2100/3000
  BPE merge 2200/3000
  BPE merge 2300/3000
  BPE merge 2400/3000
  BPE merge 2500/3000
  BPE merge 26

In [3]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
model = TransformerModel(
    src_vocab=len(tok_src.word2idx),
    trg_vocab=len(tok_trg.word2idx),
    d_model=256,
    nhead=4,
    num_layers=3,
    pad_idx=0
)

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train model with 20 epochs, warmup, and early stopping
print("\n" + "="*50)
print("TRAINING")
print("="*50)
train_losses, val_losses = train_model(
    model, train_loader, val_loader, device, 
    epochs=20,           # Increased to 20 epochs
    lr=3e-4,
    patience=5,          # Early stopping patience
    warmup_epochs=2      # Warmup for first 2 epochs
)

# Load best model
model.load_state_dict(torch.load("best_model.pt"))

# Evaluate
print("\n" + "="*50)
print("EVALUATION")
print("="*50)
evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=50)

Using device: cuda

Model parameters: 13,365,103

TRAINING
Model moved to cuda
Optimizer: AdamW (lr=0.0003, weight_decay=0.01)
Scheduler: Warmup + Cosine Decay (warmup=2 epochs)
Loss: LabelSmoothing(0.1)
Early Stopping: patience=5

Epoch 1/20




  Batch 100/3750 | Loss: 20.0097 | LR: 0.000004
  Batch 200/3750 | Loss: 16.3349 | LR: 0.000008
  Batch 300/3750 | Loss: 14.2925 | LR: 0.000012
  Batch 400/3750 | Loss: 12.8346 | LR: 0.000016
  Batch 500/3750 | Loss: 11.9500 | LR: 0.000020
  Batch 600/3750 | Loss: 11.2478 | LR: 0.000024
  Batch 700/3750 | Loss: 10.7148 | LR: 0.000028
  Batch 800/3750 | Loss: 10.2561 | LR: 0.000032
  Batch 900/3750 | Loss: 9.8490 | LR: 0.000036
  Batch 1000/3750 | Loss: 9.4868 | LR: 0.000040
  Batch 1100/3750 | Loss: 9.1487 | LR: 0.000044
  Batch 1200/3750 | Loss: 8.8646 | LR: 0.000048
  Batch 1300/3750 | Loss: 8.5964 | LR: 0.000052
  Batch 1400/3750 | Loss: 8.3433 | LR: 0.000056
  Batch 1500/3750 | Loss: 8.0958 | LR: 0.000060
  Batch 1600/3750 | Loss: 7.8559 | LR: 0.000064
  Batch 1700/3750 | Loss: 7.6421 | LR: 0.000068
  Batch 1800/3750 | Loss: 7.4383 | LR: 0.000072
  Batch 1900/3750 | Loss: 7.2412 | LR: 0.000076
  Batch 2000/3750 | Loss: 7.0528 | LR: 0.000080
  Batch 2100/3750 | Loss: 6.8766 | LR: 0.

  output = torch._nested_tensor_from_mask(



  üìä Epoch 1 Summary:
     Train Loss: 4.9331
     Val Loss:   1.9641
     LR:         0.000150
     ‚úî New best model saved! (Val Loss: 1.9641)

Epoch 2/20
  Batch 100/3750 | Loss: 2.1444 | LR: 0.000154
  Batch 200/3750 | Loss: 2.1018 | LR: 0.000158
  Batch 300/3750 | Loss: 2.0671 | LR: 0.000162
  Batch 400/3750 | Loss: 2.0653 | LR: 0.000166
  Batch 500/3750 | Loss: 2.0628 | LR: 0.000170
  Batch 600/3750 | Loss: 2.0582 | LR: 0.000174
  Batch 700/3750 | Loss: 2.0498 | LR: 0.000178
  Batch 800/3750 | Loss: 2.0304 | LR: 0.000182
  Batch 900/3750 | Loss: 2.0262 | LR: 0.000186
  Batch 1000/3750 | Loss: 2.0164 | LR: 0.000190
  Batch 1100/3750 | Loss: 2.0096 | LR: 0.000194
  Batch 1200/3750 | Loss: 2.0052 | LR: 0.000198
  Batch 1300/3750 | Loss: 1.9988 | LR: 0.000202
  Batch 1400/3750 | Loss: 1.9885 | LR: 0.000206
  Batch 1500/3750 | Loss: 1.9807 | LR: 0.000210
  Batch 1600/3750 | Loss: 1.9779 | LR: 0.000214
  Batch 1700/3750 | Loss: 1.9705 | LR: 0.000218
  Batch 1800/3750 | Loss: 1.9671