In [1]:
!pip install pyvi

Collecting pyvi
  Downloading pyvi-0.1.1-py2.py3-none-any.whl.metadata (2.5 kB)
Collecting sklearn-crfsuite (from pyvi)
  Downloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl.metadata (4.9 kB)
Collecting python-crfsuite>=0.9.7 (from sklearn-crfsuite->pyvi)
  Downloading python_crfsuite-0.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Downloading pyvi-0.1.1-py2.py3-none-any.whl (8.5 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.5/8.5 MB[0m [31m59.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl (10 kB)
Downloading python_crfsuite-0.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m47.1 MB/s[0m eta [36m

In [None]:
import math
import random
from collections import Counter, defaultdict
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# -------------------------
# Seed for reproducibility
# -------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)

# ==============================
# LOAD DATA
# ==============================
train_en = open("/kaggle/input/en-vi-ds/data/train.en", "r", encoding="utf-8").read().splitlines()
train_vi = open("/kaggle/input/en-vi-ds/data/train.vi", "r", encoding="utf-8").read().splitlines()
test_en  = open("/kaggle/input/en-vi-ds/data/tst2013.en", "r", encoding="utf-8").read().splitlines()
test_vi  = open("/kaggle/input/en-vi-ds/data/tst2013.vi", "r", encoding="utf-8").read().splitlines()

print(f"Train: {len(train_en)}, Test: {len(test_en)}")


# ==============================
# OPTIMIZED BPE TOKENIZER
# ==============================
class BPETokenizer:
    def __init__(self, texts, vocab_size=8000, min_freq=2, max_samples=50000):
        """Optimized BPE with caching and faster merging"""
        print(f"Initializing BPE Tokenizer (vocab_size={vocab_size})...")
        self.vocab_size = vocab_size
        self.word2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.bpe_codes = {}
        self.cache = {}  # Cache for encoded words
        
        if len(texts) > max_samples:
            print(f"Sampling {max_samples}/{len(texts)} for BPE training")
            texts = random.sample(texts, max_samples)
        
        self.build_bpe(texts, min_freq)
    
    def get_stats(self, vocab):
        """Optimized pair counting"""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs
    
    def merge_vocab(self, pair, vocab):
        """Optimized vocabulary merging"""
        new_vocab = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        
        for word, freq in vocab.items():
            new_word = word.replace(bigram, replacement)
            new_vocab[new_word] = freq
        return new_vocab
    
    def build_bpe(self, texts, min_freq):
        """Build BPE vocabulary with progress tracking"""
        print("Step 1: Counting word frequencies...")
        word_freq = Counter()
        
        # Batch processing for efficiency
        batch_size = 10000
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            for line in batch:
                word_freq.update(line.strip().lower().split())
            if (i // batch_size) % 5 == 0:
                print(f"  Processed {min(i + batch_size, len(texts))}/{len(texts)} lines")
        
        print(f"Step 2: Found {len(word_freq)} unique words")
        
        # Filter and prepare vocab
        vocab = {
            ' '.join(list(word)) + ' </w>': freq
            for word, freq in word_freq.items()
            if freq >= min_freq
        }
        
        print(f"Step 3: After filtering (min_freq={min_freq}): {len(vocab)} words")
        
        # Learn BPE merges
        num_merges = min(self.vocab_size - len(self.word2idx), 5000)
        print(f"Step 4: Learning {num_merges} BPE merges...")
        
        for i in range(num_merges):
            if i % 200 == 0:
                print(f"  BPE merge {i}/{num_merges}")
            
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.bpe_codes[best] = i
        
        # Build final vocabulary
        print("Step 5: Building final vocabulary...")
        for word in vocab.keys():
            for token in word.split():
                if token not in self.word2idx:
                    idx = len(self.word2idx)
                    self.word2idx[token] = idx
                    self.idx2word[idx] = token
        
        print(f"‚úì Vocabulary size: {len(self.word2idx)}\n")
    
    def apply_bpe(self, word):
        """Apply BPE with caching"""
        if word in self.cache:
            return self.cache[word]
        
        word_tokens = ' '.join(list(word)) + ' </w>'
        
        while True:
            symbols = word_tokens.split()
            if len(symbols) == 1:
                break
            
            # Find best pair to merge
            pairs = [(symbols[i], symbols[i+1]) for i in range(len(symbols)-1)]
            valid_pairs = [(self.bpe_codes.get(p, float('inf')), i, p) 
                          for i, p in enumerate(pairs) 
                          if p in self.bpe_codes]
            
            if not valid_pairs:
                break
            
            # Merge the earliest learned pair
            _, pos, pair = min(valid_pairs)
            symbols[pos] = ''.join(pair)
            del symbols[pos + 1]
            word_tokens = ' '.join(symbols)
        
        result = word_tokens.split()
        self.cache[word] = result
        return result
    
    def encode(self, text):
        """Encode text to token IDs"""
        tokens = []
        for word in text.lower().split():
            bpe_tokens = self.apply_bpe(word)
            tokens.extend(self.word2idx.get(token, 3) for token in bpe_tokens)
        return tokens
    
    def decode(self, ids):
        """Decode token IDs to text"""
        words = []
        current_word = ""
        
        for idx in ids:
            if idx == 2:  # eos
                break
            if idx > 3:
                token = self.idx2word.get(idx, "<unk>")
                if token.endswith('</w>'):
                    current_word += token[:-4]
                    words.append(current_word)
                    current_word = ""
                else:
                    current_word += token
        
        if current_word:
            words.append(current_word)
        
        return " ".join(words)


print("Building tokenizers...")
tok_src = BPETokenizer(train_en, vocab_size=8000, min_freq=2, max_samples=50000)
tok_trg = BPETokenizer(train_vi, vocab_size=8000, min_freq=2, max_samples=50000)


# ==============================
# DATA AUGMENTATION (OPTIMIZED)
# ==============================
class DataAugmentation:
    @staticmethod
    def random_swap(words: List[str], n=1) -> List[str]:
        """Randomly swap n words"""
        if len(words) < 2:
            return words
        
        words = words.copy()
        for _ in range(min(n, len(words) // 2)):
            idx1, idx2 = random.sample(range(len(words)), 2)
            words[idx1], words[idx2] = words[idx2], words[idx1]
        return words
    
    @staticmethod
    def random_deletion(words: List[str], p=0.1) -> List[str]:
        """Randomly delete words with probability p"""
        if len(words) == 1:
            return words
        
        new_words = [word for word in words if random.random() > p]
        return new_words if new_words else [random.choice(words)]


# ==============================
# DATASET (OPTIMIZED)
# ==============================
class TranslationDataset(Dataset):
    def __init__(self, src, trg, tok_src, tok_trg, augment=False, max_len=100):
        self.tok_src = tok_src
        self.tok_trg = tok_trg
        self.augment = augment
        self.max_len = max_len
        
        # Pre-encode all data for faster training
        print("Pre-encoding dataset...")
        self.data = []
        for i, (s, t) in enumerate(zip(src, trg)):
            if i % 10000 == 0 and i > 0:
                print(f"  Encoded {i}/{len(src)} samples")
            
            s_tokens = [1] + tok_src.encode(s) + [2]
            t_tokens = [1] + tok_trg.encode(t) + [2]
            
            # Filter out very long sequences
            if len(s_tokens) <= max_len and len(t_tokens) <= max_len:
                self.data.append((s_tokens, t_tokens, s.split(), t.split()))
        
        print(f"‚úì Encoded {len(self.data)} samples\n")

    def __len__(self): 
        return len(self.data)

    def __getitem__(self, idx):
        s_tokens, t_tokens, s_words, t_words = self.data[idx]
        
        # Apply augmentation with 20% probability
        if self.augment and random.random() < 0.2:
            aug_s = DataAugmentation.random_swap(s_words, n=1)
            aug_t = DataAugmentation.random_swap(t_words, n=1)
            
            s_tokens = [1] + self.tok_src.encode(' '.join(aug_s)) + [2]
            t_tokens = [1] + self.tok_trg.encode(' '.join(aug_t)) + [2]
        
        return torch.tensor(s_tokens, dtype=torch.long), torch.tensor(t_tokens, dtype=torch.long)


def collate_fn(batch):
    """Optimized collate with proper padding"""
    src, trg = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0)
    trg = nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
    return src, trg


# Create datasets
dataset = TranslationDataset(train_en, train_vi, tok_src, tok_trg, augment=True, max_len=100)
train_len = int(0.95 * len(dataset))
val_len = len(dataset) - train_len
train_set, val_set = random_split(dataset, [train_len, val_len])

# Disable augmentation for validation
val_set.dataset.augment = False

# Optimized batch sizes
BATCH_SIZE = 64 if torch.cuda.is_available() else 32

train_loader = DataLoader(
    train_set, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True if torch.cuda.is_available() else False
)
val_loader = DataLoader(
    val_set, 
    batch_size=BATCH_SIZE * 2, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True if torch.cuda.is_available() else False
)

print(f"Train: {train_len}, Val: {val_len}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}\n")


# ==============================
# OPTIMIZED POSITIONAL ENCODING
# ==============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# ==============================
# LABEL SMOOTHING LOSS
# ==============================
class LabelSmoothingLoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1, ignore_index=0):
        super().__init__()
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
    
    def forward(self, pred, target):
        """
        pred: (batch_size * seq_len, num_classes)
        target: (batch_size * seq_len)
        """
        pred = pred.log_softmax(dim=-1)
        
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.num_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            
            # Mask padding
            mask = target == self.ignore_index
            true_dist[mask] = 0
        
        return torch.mean(torch.sum(-true_dist * pred, dim=-1))


# ==============================
# IMPROVED TRANSFORMER MODEL
# ==============================
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model=512, nhead=8, 
                 num_layers=6, dim_feedforward=2048, dropout=0.1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # Embeddings with scaling
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model, dropout=dropout)
        
        # Scale factor for embeddings
        self.scale = math.sqrt(d_model)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model, 
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True  # Pre-LN for better training stability
        )

        # Output projection with weight tying
        self.fc = nn.Linear(d_model, trg_vocab)
        self.fc.weight = self.trg_emb.weight
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Xavier initialization for better convergence"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, trg):
        # Create masks
        src_mask = (src == self.pad_idx)
        trg_mask = (trg == self.pad_idx)
        seq_len = trg.size(1)
        
        # Causal mask for decoder
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=trg.device
        )

        # Embeddings with scaling + positional encoding
        src_emb = self.pos(self.src_emb(src) * self.scale)
        trg_emb = self.pos(self.trg_emb(trg) * self.scale)

        # Transformer forward
        out = self.transformer(
            src_emb, trg_emb,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_mask,
            tgt_key_padding_mask=trg_mask,
            memory_key_padding_mask=src_mask
        )
        
        return self.fc(out)


# ==============================
# EARLY STOPPING
# ==============================
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0005, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        
    def __call__(self, val_metric):
        score = -val_metric if self.mode == 'min' else val_metric
        
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            print(f"     ‚ö† EarlyStopping: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0


# ==============================
# OPTIMIZED TRAINING WITH MIXED PRECISION
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=25, lr=5e-4, 
                patience=7, warmup_epochs=2, use_amp=True):
    model.to(device)
    print(f"Device: {device}")
    print(f"Mixed Precision: {use_amp and torch.cuda.is_available()}")
    
    # Optimizer with weight decay
    opt = torch.optim.AdamW(
        model.parameters(), 
        lr=lr, 
        betas=(0.9, 0.98), 
        eps=1e-9,
        weight_decay=0.0001
    )
    
    # Learning rate scheduler with warmup
    warmup_steps = warmup_epochs * len(train_loader)
    total_steps = epochs * len(train_loader)
    
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    
    # Loss function
    loss_fn = LabelSmoothingLoss(
        num_classes=len(tok_trg.word2idx), 
        smoothing=0.1, 
        ignore_index=0
    )
    
    # Mixed precision scaler
    scaler = GradScaler() if use_amp and torch.cuda.is_available() else None
    
    # Early stopping
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0005)
    
    print(f"\nTraining Config:")
    print(f"  Epochs: {epochs}")
    print(f"  Learning Rate: {lr}")
    print(f"  Warmup Epochs: {warmup_epochs}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Early Stopping Patience: {patience}")
    print()

    best_val = float('inf')
    train_losses, val_losses = [], []

    for ep in range(1, epochs + 1):
        print(f"{'='*60}")
        print(f"Epoch {ep}/{epochs}")
        print(f"{'='*60}")
        
        # Training
        model.train()
        total_loss = 0
        
        for batch_idx, (src, trg) in enumerate(train_loader):
            src, trg = src.to(device), trg.to(device)
            opt.zero_grad()
            
            # Mixed precision training
            if scaler:
                with autocast():
                    out = model(src, trg[:, :-1])
                    loss = loss_fn(
                        out.reshape(-1, out.size(-1)), 
                        trg[:, 1:].reshape(-1)
                    )
                
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
            else:
                out = model(src, trg[:, :-1])
                loss = loss_fn(
                    out.reshape(-1, out.size(-1)), 
                    trg[:, 1:].reshape(-1)
                )
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            
            scheduler.step()
            total_loss += loss.item()
            
            # Progress update
            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                lr_current = opt.param_groups[0]['lr']
                print(f"  Batch {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {avg_loss:.4f} | LR: {lr_current:.6f}")

        # Validation
        print(f"\n  Validating...")
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for src, trg in val_loader:
                src, trg = src.to(device), trg.to(device)
                
                if scaler:
                    with autocast():
                        out = model(src, trg[:, :-1])
                        loss = loss_fn(
                            out.reshape(-1, out.size(-1)), 
                            trg[:, 1:].reshape(-1)
                        )
                else:
                    out = model(src, trg[:, :-1])
                    loss = loss_fn(
                        out.reshape(-1, out.size(-1)), 
                        trg[:, 1:].reshape(-1)
                    )
                
                val_loss += loss.item()

        # Epoch summary
        avg_train = total_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)
        lr_current = opt.param_groups[0]['lr']
        
        train_losses.append(avg_train)
        val_losses.append(avg_val)

        print(f"\n  üìä Summary:")
        print(f"     Train Loss: {avg_train:.4f}")
        print(f"     Val Loss:   {avg_val:.4f}")
        print(f"     LR:         {lr_current:.6f}")

        # Save best model
        if avg_val < best_val:
            best_val = avg_val
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'val_loss': best_val,
            }, "best_model.pt")
            print(f"     ‚úÖ Best model saved! (Val Loss: {best_val:.4f})")
        
        # Early stopping
        early_stopping(avg_val)
        if early_stopping.early_stop:
            print(f"\nüõë Early stopping at epoch {ep}")
            print(f"   Best Val Loss: {best_val:.4f}")
            break
        
        print()
    
    return train_losses, val_losses


# ==============================
# BEAM SEARCH DECODING
# ==============================
def beam_search_decode(model, src, tok_trg, device, beam_size=5, max_len=80):
    model.eval()
    sos, eos = 1, 2
    
    src = src.to(device)
    
    # Encode source
    with torch.no_grad():
        src_mask = (src == 0)
        src_emb = model.pos(model.src_emb(src) * model.scale)
        memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_mask)
    
    # Initialize beam
    sequences = [[sos]]
    scores = [0.0]
    
    for _ in range(max_len):
        all_candidates = []
        
        for seq, score in zip(sequences, scores):
            if seq[-1] == eos:
                all_candidates.append((seq, score))
                continue
            
            trg = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(device)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                trg.size(1), device=device
            )
            
            with torch.no_grad():
                trg_emb = model.pos(model.trg_emb(trg) * model.scale)
                out = model.transformer.decoder(
                    trg_emb, memory,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_mask
                )
                logits = model.fc(out[:, -1])
                log_probs = F.log_softmax(logits, dim=-1)
            
            # Get top-k candidates
            topk_probs, topk_ids = torch.topk(log_probs, beam_size)
            
            for k in range(beam_size):
                candidate_seq = seq + [topk_ids[0, k].item()]
                candidate_score = score + topk_probs[0, k].item()
                all_candidates.append((candidate_seq, candidate_score))
        
        # Select top beam_size candidates
        ordered = sorted(all_candidates, key=lambda x: x[1] / len(x[0]), reverse=True)
        sequences = [seq for seq, _ in ordered[:beam_size]]
        scores = [score for _, score in ordered[:beam_size]]
        
        # Early stopping if all beams end with EOS
        if all(seq[-1] == eos for seq in sequences):
            break
    
    best_seq = sequences[0]
    return tok_trg.decode(best_seq[1:])


# ==============================
# TRANSLATION FUNCTION
# ==============================
def translate(model, text, tok_src, tok_trg, device, beam_size=5):
    src = [1] + tok_src.encode(text) + [2]
    src = torch.tensor(src, dtype=torch.long).unsqueeze(0)
    return beam_search_decode(model, src, tok_trg, device, beam_size=beam_size)


# ==============================
# EVALUATION
# ==============================
def evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=100):
    model.eval()
    smooth = SmoothingFunction().method1
    
    total_bleu = 0
    n = min(n, len(test_en))
    
    print(f"Evaluating {n} test samples...")
    
    for i in range(n):
        pred = translate(model, test_en[i], tok_src, tok_trg, device, beam_size=5)
        bleu = sentence_bleu(
            [test_vi[i].split()], 
            pred.split(), 
            smoothing_function=smooth
        )
        total_bleu += bleu
        
        # Show first 5 examples
        if i < 5:
            print(f"\n--- Example {i+1} ---")
            print(f"EN: {test_en[i]}")
            print(f"GT: {test_vi[i]}")
            print(f"PR: {pred}")
            print(f"BLEU: {bleu:.4f}")
    
    avg_bleu = total_bleu / n
    print(f"\n{'='*60}")
    print(f"AVERAGE BLEU SCORE: {avg_bleu:.4f}")
    print(f"{'='*60}")
    
    return avg_bleu


# ==============================
# MAIN EXECUTION
# ==============================
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*60}")
    print(f"DEVICE: {device}")
    print(f"{'='*60}\n")
    
    # Initialize model
    model = TransformerModel(
        src_vocab=len(tok_src.word2idx),
        trg_vocab=len(tok_trg.word2idx),
        d_model=512,
        nhead=8,
        num_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        pad_idx=0
    )
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}\n")
    
    # Train
    print("="*60)
    print("TRAINING")
    print("="*60)
    train_losses, val_losses = train_model(
        model, train_loader, val_loader, device,
        epochs=25,
        lr=5e-4,
        patience=7,
        warmup_epochs=2,
        use_amp=True
    )
    
    # Load best model
    print("\n" + "="*60)
    print("LOADING BEST MODEL")
    print("="*60)
    checkpoint = torch.load("best_model.pt")
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úì Loaded model from epoch {checkpoint['epoch']}")
    print(f"  Best Val Loss: {checkpoint['val_loss']:.4f}\n")
    
    # Evaluate on test set
    print("="*60)
    print("EVALUATION ON TEST SET")
    print("="*60)
    avg_bleu = evaluate_test_set(model, test_en, test_vi, tok_src, tok_trg, device, n=100)
    
    # Interactive translation
    print("\n" + "="*60)
    print("INTERACTIVE TRANSLATION")
    print("="*60)
    print("Enter English sentences to translate (or 'quit' to exit):\n")
    
    while True:
        try:
            text = input("EN: ").strip()
            if text.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break
            if not text:
                continue
                
            translation = translate(model, text, tok_src, tok_trg, device, beam_size=5)
            print(f"VI: {translation}\n")
            
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}\n")

Train: 133317, Test: 1268
Building tokenizers...
Initializing BPE Tokenizer (vocab_size=8000)...
Sampling 50000/133317 for BPE training
Step 1: Counting word frequencies...
  Processed 10000/50000 lines
Step 2: Found 31792 unique words
Step 3: After filtering (min_freq=2): 17929 words
Step 4: Learning 5000 BPE merges...
  BPE merge 0/5000
  BPE merge 200/5000
  BPE merge 400/5000
  BPE merge 600/5000
  BPE merge 800/5000
  BPE merge 1000/5000
  BPE merge 1200/5000
  BPE merge 1400/5000
  BPE merge 1600/5000
  BPE merge 1800/5000
  BPE merge 2000/5000
  BPE merge 2200/5000
  BPE merge 2400/5000
  BPE merge 2600/5000
  BPE merge 2800/5000
  BPE merge 3000/5000
  BPE merge 3200/5000
  BPE merge 3400/5000
  BPE merge 3600/5000
  BPE merge 3800/5000
  BPE merge 4000/5000
  BPE merge 4200/5000
  BPE merge 4400/5000
  BPE merge 4600/5000
  BPE merge 4800/5000
Step 5: Building final vocabulary...
‚úì Vocabulary size: 15346

Initializing BPE Tokenizer (vocab_size=8000)...
Sampling 50000/133317 



Model Parameters: 55,608,190
Trainable Parameters: 55,608,190

TRAINING
Device: cuda
Mixed Precision: True

Training Config:
  Epochs: 25
  Learning Rate: 0.0005
  Warmup Epochs: 2
  Batch Size: 64
  Early Stopping Patience: 7

Epoch 1/25


  scaler = GradScaler() if use_amp and torch.cuda.is_available() else None
  with autocast():


  Batch 100/1902 | Loss: 3.1693 | LR: 0.000013
  Batch 200/1902 | Loss: 3.0361 | LR: 0.000026
  Batch 300/1902 | Loss: 2.8676 | LR: 0.000039
  Batch 400/1902 | Loss: 2.7547 | LR: 0.000053
  Batch 500/1902 | Loss: 2.6666 | LR: 0.000066
  Batch 600/1902 | Loss: 2.5985 | LR: 0.000079
  Batch 700/1902 | Loss: 2.5338 | LR: 0.000092
  Batch 800/1902 | Loss: 2.4754 | LR: 0.000105
  Batch 900/1902 | Loss: 2.4252 | LR: 0.000118
  Batch 1000/1902 | Loss: 2.3835 | LR: 0.000131
  Batch 1100/1902 | Loss: 2.3398 | LR: 0.000145
  Batch 1200/1902 | Loss: 2.3018 | LR: 0.000158
  Batch 1300/1902 | Loss: 2.2672 | LR: 0.000171
  Batch 1400/1902 | Loss: 2.2384 | LR: 0.000184
  Batch 1500/1902 | Loss: 2.2083 | LR: 0.000197
  Batch 1600/1902 | Loss: 2.1796 | LR: 0.000210
  Batch 1700/1902 | Loss: 2.1534 | LR: 0.000223
  Batch 1800/1902 | Loss: 2.1303 | LR: 0.000237
  Batch 1900/1902 | Loss: 2.1074 | LR: 0.000250

  Validating...


  with autocast():



  üìä Summary:
     Train Loss: 2.1072
     Val Loss:   1.5934
     LR:         0.000250
     ‚úÖ Best model saved! (Val Loss: 1.5934)

Epoch 2/25
  Batch 100/1902 | Loss: 1.6319 | LR: 0.000263
  Batch 200/1902 | Loss: 1.6299 | LR: 0.000276
  Batch 300/1902 | Loss: 1.6142 | LR: 0.000289
  Batch 400/1902 | Loss: 1.6177 | LR: 0.000303
  Batch 500/1902 | Loss: 1.6133 | LR: 0.000316
  Batch 600/1902 | Loss: 1.6123 | LR: 0.000329
  Batch 700/1902 | Loss: 1.6010 | LR: 0.000342
  Batch 800/1902 | Loss: 1.5991 | LR: 0.000355
  Batch 900/1902 | Loss: 1.5927 | LR: 0.000368
  Batch 1000/1902 | Loss: 1.5907 | LR: 0.000381
  Batch 1100/1902 | Loss: 1.5836 | LR: 0.000395
  Batch 1200/1902 | Loss: 1.5746 | LR: 0.000408
  Batch 1300/1902 | Loss: 1.5675 | LR: 0.000421
  Batch 1400/1902 | Loss: 1.5604 | LR: 0.000434
  Batch 1500/1902 | Loss: 1.5551 | LR: 0.000447
  Batch 1600/1902 | Loss: 1.5494 | LR: 0.000460
  Batch 1700/1902 | Loss: 1.5430 | LR: 0.000473
  Batch 1800/1902 | Loss: 1.5379 | LR: 0.000

EN:  When I was seven years old , I saw my first public execution , but I thought my life in North Korea was normal


In [5]:
# =========================================================
# 0. IMPORTS + SEED
# =========================================================
import math
import random
from collections import Counter, defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F

def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device)

# =========================================================
# 1. BPE TOKENIZER (GI·ªêNG L√öC TRAIN)
# =========================================================
class BPETokenizer:
    def __init__(self, texts, vocab_size=8000, min_freq=2, max_samples=50000):
        self.vocab_size = vocab_size
        self.word2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.bpe_codes = {}
        self.cache = {}

        if len(texts) > max_samples:
            texts = random.sample(texts, max_samples)

        self.build_bpe(texts, min_freq)

    def get_stats(self, vocab):
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def merge_vocab(self, pair, vocab):
        new_vocab = {}
        bigram = " ".join(pair)
        replacement = "".join(pair)
        for word, freq in vocab.items():
            new_vocab[word.replace(bigram, replacement)] = freq
        return new_vocab

    def build_bpe(self, texts, min_freq):
        word_freq = Counter()
        for line in texts:
            word_freq.update(line.lower().split())

        vocab = {
            " ".join(list(word)) + " </w>": freq
            for word, freq in word_freq.items()
            if freq >= min_freq
        }

        num_merges = min(self.vocab_size - len(self.word2idx), 5000)

        for i in range(num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.bpe_codes[best] = i

        for word in vocab.keys():
            for token in word.split():
                if token not in self.word2idx:
                    idx = len(self.word2idx)
                    self.word2idx[token] = idx
                    self.idx2word[idx] = token

    def apply_bpe(self, word):
        if word in self.cache:
            return self.cache[word]

        tokens = " ".join(list(word)) + " </w>"
        while True:
            symbols = tokens.split()
            pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
            valid = [(self.bpe_codes.get(p, 1e9), i, p) for i, p in enumerate(pairs) if p in self.bpe_codes]
            if not valid:
                break
            _, pos, pair = min(valid)
            symbols[pos] = "".join(pair)
            del symbols[pos + 1]
            tokens = " ".join(symbols)

        self.cache[word] = tokens.split()
        return self.cache[word]

    def encode(self, text):
        ids = []
        for word in text.lower().split():
            for t in self.apply_bpe(word):
                ids.append(self.word2idx.get(t, 3))
        return ids

    def decode(self, ids):
        words, cur = [], ""
        for idx in ids:
            if idx == 2:
                break
            if idx > 3:
                tok = self.idx2word.get(idx, "<unk>")
                if tok.endswith("</w>"):
                    cur += tok[:-4]
                    words.append(cur)
                    cur = ""
                else:
                    cur += tok
        if cur:
            words.append(cur)
        return " ".join(words)

# =========================================================
# 2. LOAD DATA ‚Üí BUILD TOKENIZER
# =========================================================
train_en = open("/kaggle/input/en-vi-ds/data/train.en", encoding="utf-8").read().splitlines()
train_vi = open("/kaggle/input/en-vi-ds/data/train.vi", encoding="utf-8").read().splitlines()

print("Building tokenizers...")
tok_src = BPETokenizer(train_en)
tok_trg = BPETokenizer(train_vi)

# =========================================================
# 3. TRANSFORMER MODEL (GI·ªêNG L√öC TRAIN)
# =========================================================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerModel(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model=512):
        super().__init__()
        self.scale = math.sqrt(d_model)
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.trg_emb = nn.Embedding(trg_vocab, d_model, padding_idx=0)
        self.pos = PositionalEncoding(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            batch_first=True,
            norm_first=True
        )

        self.fc = nn.Linear(d_model, trg_vocab)
        self.fc.weight = self.trg_emb.weight

# =========================================================
# 4. LOAD CHECKPOINT (FIX VOCAB MISMATCH)
# =========================================================
ckpt = torch.load(
    "/kaggle/input/envi-final/pytorch/default/1/best_model (11).pt",
    map_location=device
)

src_vocab_ckpt = ckpt["model_state_dict"]["src_emb.weight"].shape[0]
trg_vocab_ckpt = ckpt["model_state_dict"]["trg_emb.weight"].shape[0]

print("Checkpoint vocab:", src_vocab_ckpt, trg_vocab_ckpt)

model = TransformerModel(
    src_vocab=src_vocab_ckpt,
    trg_vocab=trg_vocab_ckpt
).to(device)

model.load_state_dict(ckpt["model_state_dict"])
model.eval()
print("‚úÖ Model loaded")

# =========================================================
# 5. SAFE ENCODE (KH√ìA TOKEN ID)
# =========================================================
def safe_encode(tok, text, max_id):
    return [i if i < max_id else 3 for i in tok.encode(text)]

# =========================================================
# 6. BEAM SEARCH TRANSLATION
# =========================================================
def translate(text, beam_size=5, max_len=80):
    sos, eos = 1, 2

    src_ids = safe_encode(tok_src, text, src_vocab_ckpt)
    src = torch.tensor([sos] + src_ids + [eos]).unsqueeze(0).to(device)

    with torch.no_grad():
        src_mask = (src == 0)
        memory = model.transformer.encoder(
            model.pos(model.src_emb(src) * model.scale),
            src_key_padding_mask=src_mask
        )

    beams = [([sos], 0.0)]

    for _ in range(max_len):
        new_beams = []
        for seq, score in beams:
            if seq[-1] == eos:
                new_beams.append((seq, score))
                continue

            trg = torch.tensor(seq).unsqueeze(0).to(device)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(trg.size(1)).to(device)

            with torch.no_grad():
                out = model.transformer.decoder(
                    model.pos(model.trg_emb(trg) * model.scale),
                    memory,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_mask
                )
                logp = F.log_softmax(model.fc(out[:, -1]), dim=-1)

            topk = torch.topk(logp, beam_size)
            for i in range(beam_size):
                new_beams.append(
                    (seq + [topk.indices[0, i].item()],
                     score + topk.values[0, i].item())
                )

        beams = sorted(new_beams, key=lambda x: x[1] / len(x[0]), reverse=True)[:beam_size]
        if all(b[0][-1] == eos for b in beams):
            break

    return tok_trg.decode(beams[0][0][1:])

# =========================================================
# 7. TEST
# =========================================================
print(translate("I love machine learning"))
print(translate("This model was trained using transformer architecture"))
print(translate("How are you today?"))


DEVICE: cuda
Building tokenizers...
Checkpoint vocab: 15346 7038
‚úÖ Model loaded
t√¥i y√™u m√°y h·ªçc .
m√¥ h√¨nh n√†y ƒë∆∞·ª£c hu·∫•n luy·ªán b·∫±ng c√°ch s·ª≠ d·ª•ng kn tr√∫c chuy·ªÉn ho√° .
ng√†y nay b·∫°n l√† ai ?


In [18]:
print(translate("Secondly various projects, researches, assignments and practical scenarios are conducted in universities or colleges from where students get exposure and experience to various problems which they might have to face in their real life while practicing. Like in dentistry the students have to work on tooth for scaling, wiring etc from which they get practical exposure."))
print(translate("Moreover universities have huge libraries carrying thousands of books of different subjects and other study material like fictional, non-fictional, journals, newspapers, reports which are huge sources of information for the students and teachers."))

th·ª© hai l√† nh·ªØng nh√† nghi√™n c·ª©u , nh·ªØng b√†i lu·∫≠n lu·∫≠n v√† th·ª±c t·∫ø ƒë∆∞·ª£c tn h√†nh t·∫°i c√°c tr∆∞·ªùng ƒë·∫°i h·ªçc ho·∫∑c ƒë·∫°i h·ªçc t·ª´ nh·ªØng n∆°i h·ªçc sinh c√≥ ƒë∆∞·ª£c trti√™u v√† kinh nghi·ªám ƒë·ªÉ gi·∫£i quy·∫øt nhi·ªÅu v·∫•n ƒë·ªÅ kh√°c nhau m√† h·ªç c√≥ th·ªÉ gi·∫£i quy·∫øt ƒë∆∞·ª£c b·∫±ng c√°ch √°p d·ª•ng trong cu·ªôc s·ªëng .
nhi·ªÅu tr∆∞·ªùng ƒë·∫°i h·ªçc h∆°n c√≥ nh·ªØng th∆∞ vkh·ªïng l·ªì mang theo h√†ng ng√†n nh·ªØng ƒë·ªÅ t√†i kh√°c nhau v√† nh·ªØng nghi√™n c·ª©u kh√°c nh∆∞ nh·ªØng t√†i li·ªáu h∆∞ c·∫•u , kh√¥ng ph·∫£i t·∫°p ch√≠ , b√°o ch√≠ , b√°o c√°o , b√°o c√°o l√† nh·ªØng m·∫©u tin l·ªõn cho nh·ªØng th√¥ng tin l·ªõn .
