In [None]:
!pip install transformers datasets laonlp underthesea

In [None]:
import os
import torch
import math
import time
from torch.utils.data import Dataset, DataLoader
from torch import nn
from nltk.translate.bleu_score import corpus_bleu
from laonlp import word_tokenize as lao_tokenize
from underthesea import word_tokenize as vi_tokenize

# Sequence length and batch size used when encoding and batching the test set
BATCH_SIZE = 24
MAX_LENGTH = 128

# Vietnamese / Lao test files (read line by line by the evaluation script)
LAO_TEST_PATH = "/kaggle/input/vi-lo-dataset/test_lo.txt"
VI_TEST_PATH = "/kaggle/input/vi-lo-dataset/test_vi.txt"

# Saved artefacts: vocabulary file and best model checkpoint
OUTPUT_DIR = "/kaggle/input/machine-translation-models"
VOCAB_PATH = os.path.join(OUTPUT_DIR, "vocabularies_best.pt")
BEST_CKPT_PATH = os.path.join(OUTPUT_DIR, "best_model_best.pt")

# Vocabulary class (must match saved instances)
class Vocabulary:
    """Two-way token/index mapping with frequency counts and four special tokens.

    The special tokens are registered first, in the order pad, unk, sos, eos,
    so they always occupy the lowest indices. Attribute names are part of the
    pickled state and must not be renamed.
    """

    def __init__(self, pad_token="<pad>", unk_token="<unk>", sos_token="<sos>", eos_token="<eos>"):
        self.word2idx = {}
        self.idx2word = {}
        self.freq = {}
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        # Reserve the first indices for the special tokens.
        for special in (pad_token, unk_token, sos_token, eos_token):
            self.add_word(special)

    def add_word(self, word):
        """Register ``word`` with the next free index, or bump its count if known."""
        if word in self.word2idx:
            self.freq[word] += 1
            return
        new_index = len(self.word2idx)
        self.word2idx[word] = new_index
        self.idx2word[new_index] = word
        self.freq[word] = 1

    def encode(self, text, tokenizer_func=None, max_length=None):
        """Convert ``text`` into a list of indices wrapped in sos/eos.

        Args:
            text: Raw sentence.
            tokenizer_func: Callable returning a list of tokens; when falsy the
                text is split on whitespace.
            max_length: When truthy, the result is truncated (keeping a final
                eos) or right-padded with the pad index to exactly this length.

        Returns:
            List of integer indices; unknown tokens map to the unk index.
        """
        if tokenizer_func:
            words = tokenizer_func(text)
        else:
            words = text.split()
        sequence = [self.sos_token] + words + [self.eos_token]

        # Over-long sequences keep their head and are re-terminated with eos.
        if max_length and len(sequence) > max_length:
            sequence = sequence[:max_length - 1] + [self.eos_token]

        unk_index = self.word2idx[self.unk_token]
        ids = [self.word2idx.get(word, unk_index) for word in sequence]

        if max_length and len(ids) < max_length:
            pad_index = self.word2idx[self.pad_token]
            ids.extend([pad_index] * (max_length - len(ids)))
        return ids

    def decode(self, indices):
        """Turn indices back into a space-joined string.

        Decoding stops at the first eos token; pad and sos tokens are dropped,
        and indices missing from the vocabulary are rendered as the unk token.
        """
        words = [self.idx2word.get(index, self.unk_token) for index in indices]
        if self.eos_token in words:
            words = words[:words.index(self.eos_token)]
        return ' '.join(
            word for word in words
            if word != self.pad_token and word != self.sos_token
        )

# Positional encoding for transformer
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    stored as the buffer ``pe`` of shape ``(max_len, 1, d_model)`` so it
    broadcasts over the batch dimension.

    Args:
        d_model: Feature size of the embeddings (odd sizes are supported).
        max_len: Number of positions pre-computed in the table.
        dropout: Dropout probability applied after adding the encodings.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so drop the surplus frequency instead of failing on a
        # shape mismatch. For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Insert a singleton batch axis: (max_len, 1, d_model).
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (seq_len, batch, d_model)."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Transformer model definition
class CustomTransformer(nn.Module):
    """Encoder-decoder translation model wrapping ``nn.Transformer``.

    Holds separate source and target embeddings, one positional-encoding
    module applied to both sides, a sequence-first ``nn.Transformer`` core and
    a linear projection onto the target vocabulary. Sub-module attribute names
    are the keys of the saved ``state_dict`` and must stay as they are.
    """

    def __init__(self, source_vocab_size, target_vocab_size, d_model=768, nhead=12,
                 num_encoder_layers=8, num_decoder_layers=8, dim_feedforward=3072,
                 dropout=0.05, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx

        # Token embeddings; the row at pad_idx is the padding embedding.
        self.source_embedding = nn.Embedding(source_vocab_size, d_model, padding_idx=pad_idx)
        self.target_embedding = nn.Embedding(target_vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)

        core_config = dict(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
        )
        self.transformer = nn.Transformer(**core_config)
        self.fc_out = nn.Linear(d_model, target_vocab_size)

    def forward(self, src, tgt):
        """Compute target-vocabulary logits for every target position.

        Args:
            src: Source token indices, batch-first ``(batch, src_len)``.
            tgt: Target token indices, batch-first ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(batch, tgt_len, target_vocab_size)``.
        """
        # Padding masks are batch-first, which is what nn.Transformer expects
        # for key-padding masks even when the data itself is sequence-first.
        src_is_pad = src == self.pad_idx
        tgt_is_pad = tgt == self.pad_idx

        # The transformer core was built with batch_first=False.
        src_seq_first = src.transpose(0, 1)
        tgt_seq_first = tgt.transpose(0, 1)

        # Causal mask so a target position cannot attend to later positions.
        causal_mask = self.transformer.generate_square_subsequent_mask(
            tgt_seq_first.size(0)
        ).to(tgt_seq_first.device)

        scale = math.sqrt(self.d_model)
        encoder_input = self.positional_encoding(self.source_embedding(src_seq_first) * scale)
        decoder_input = self.positional_encoding(self.target_embedding(tgt_seq_first) * scale)

        hidden = self.transformer(
            encoder_input,
            decoder_input,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_is_pad,
            tgt_key_padding_mask=tgt_is_pad,
            memory_key_padding_mask=src_is_pad,
        )
        logits = self.fc_out(hidden)
        return logits.transpose(0, 1)

# Dataset for translation
class TranslationDataset(Dataset):
    """Map-style dataset over parallel source/target sentence lists.

    Each item is produced by calling ``encode(text, tokenizer, max_length)``
    on the matching vocabulary and wrapping the result in a ``long`` tensor.
    The dataset length is the number of source sentences.
    """

    def __init__(self, src_texts, tgt_texts, src_vocab, tgt_vocab,
                 src_tokenizer, tgt_tokenizer, max_length=MAX_LENGTH):
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab
        self.src_tokenizer, self.tgt_tokenizer = src_tokenizer, tgt_tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        """Return ``{"source": LongTensor, "target": LongTensor}`` for pair ``idx``."""
        source_ids = self.src_vocab.encode(
            self.src_texts[idx], self.src_tokenizer, self.max_length
        )
        target_ids = self.tgt_vocab.encode(
            self.tgt_texts[idx], self.tgt_tokenizer, self.max_length
        )
        item = {}
        item["source"] = torch.tensor(source_ids, dtype=torch.long)
        item["target"] = torch.tensor(target_ids, dtype=torch.long)
        return item

# Sentence translation function
def translate_sentence(model, sentence, src_vocab, tgt_vocab, device,
                       max_length=MAX_LENGTH, source_tokenizer=None):
    """Translate one sentence with greedy (arg-max) decoding.

    The model is switched to eval mode and the whole computation runs under
    ``torch.no_grad()`` so no autograd graph is built during inference.

    Args:
        model: Model exposing ``source_embedding``, ``target_embedding``,
            ``positional_encoding``, ``transformer``, ``fc_out`` and
            ``d_model`` (as ``CustomTransformer`` does).
        sentence: Raw source sentence.
        src_vocab: Source vocabulary (``word2idx`` plus special-token attributes).
        tgt_vocab: Target vocabulary (``word2idx``/``idx2word`` plus special tokens).
        device: Device on which the input tensors are created.
        max_length: Maximum number of decoding steps (tokens generated after sos).
        source_tokenizer: Callable returning tokens; whitespace split when falsy.

    Returns:
        The generated tokens joined by single spaces, with sos, eos and pad
        tokens removed.
    """
    model.eval()

    tokens = source_tokenizer(sentence) if source_tokenizer else sentence.split()
    tokens = [src_vocab.sos_token] + list(tokens) + [src_vocab.eos_token]
    src_unk_idx = src_vocab.word2idx[src_vocab.unk_token]
    src_indices = [src_vocab.word2idx.get(t, src_unk_idx) for t in tokens]

    # Loop-invariant lookups, hoisted out of the decoding loop.
    src_pad_idx = src_vocab.word2idx[src_vocab.pad_token]
    sos_idx = tgt_vocab.word2idx[tgt_vocab.sos_token]
    eos_idx = tgt_vocab.word2idx[tgt_vocab.eos_token]
    scale = math.sqrt(model.d_model)

    with torch.no_grad():
        # Encoder: run once; its output is reused at every decoding step.
        src_tensor = torch.tensor(src_indices, dtype=torch.long, device=device).unsqueeze(0)
        src_padding_mask = src_tensor == src_pad_idx
        src_emb = model.positional_encoding(
            model.source_embedding(src_tensor.transpose(0, 1)) * scale
        )
        memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask)

        # Decoder: feed the growing prefix and append the most likely token.
        tgt_indices = [sos_idx]
        for _ in range(max_length):
            tgt_tensor = torch.tensor(tgt_indices, dtype=torch.long, device=device).unsqueeze(0)
            tgt_emb = model.positional_encoding(
                model.target_embedding(tgt_tensor.transpose(0, 1)) * scale
            )
            tgt_mask = model.transformer.generate_square_subsequent_mask(len(tgt_indices)).to(device)
            out = model.transformer.decoder(tgt_emb, memory,
                                            tgt_mask=tgt_mask,
                                            memory_key_padding_mask=src_padding_mask)
            # Only the last position predicts the next token.
            pred = model.fc_out(out[-1, :, :])
            next_idx = pred.argmax(1).item()
            tgt_indices.append(next_idx)
            if next_idx == eos_idx:
                break

    # Strip special tokens from the decoded output.
    special_tokens = {tgt_vocab.sos_token, tgt_vocab.eos_token, tgt_vocab.pad_token}
    decoded = (tgt_vocab.idx2word.get(i, tgt_vocab.unk_token) for i in tgt_indices)
    return ' '.join(t for t in decoded if t not in special_tokens)

# Utility to load test data

def load_texts(path):
    """Return the non-blank lines of a UTF-8 text file, whitespace-stripped.

    The file is decoded with ``utf-8-sig`` so that a leading byte-order mark,
    if present, is discarded instead of ending up as U+FEFF at the start of
    the first sentence (``str.strip`` does not remove it). Files without a
    BOM are read exactly as with plain ``utf-8``.

    Args:
        path: Path of the text file to read.

    Returns:
        List of stripped lines in file order; lines that are empty after
        stripping are omitted.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]

if __name__ == '__main__':
    # Evaluation script: load the saved vocabularies and checkpoint, report
    # the teacher-forced test loss, then greedy-decode the test set and
    # report corpus BLEU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load vocabularies and model.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (needed here because the files contain Vocabulary instances). Only load
    # files from a trusted source.
    data = torch.load(VOCAB_PATH, map_location=device, weights_only=False)
    vi_vocab = data['vi_vocab']
    lao_vocab = data['lao_vocab']
    ckpt = torch.load(BEST_CKPT_PATH, map_location=device, weights_only=False)
    params = ckpt.get('model_params', {})
    model = CustomTransformer(
        # Fall back to the vocabulary sizes when the checkpoint carries no
        # hyper-parameters; a mismatch is then reported by load_state_dict.
        source_vocab_size=params.get('source_vocab_size', len(vi_vocab.word2idx)),
        target_vocab_size=params.get('target_vocab_size', len(lao_vocab.word2idx)),
        d_model=params.get('d_model', 768),
        nhead=params.get('nhead', 12),
        num_encoder_layers=params.get('num_encoder_layers', 8),
        num_decoder_layers=params.get('num_decoder_layers', 8),
        dim_feedforward=params.get('dim_feedforward', 3072),
        dropout=params.get('dropout', 0.05),
        pad_idx=params.get('pad_idx', vi_vocab.word2idx[vi_vocab.pad_token])
    ).to(device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # Load test data.
    vi_test = load_texts(VI_TEST_PATH)
    lo_test = load_texts(LAO_TEST_PATH)
    # load_texts drops blank lines from each file independently, so a length
    # mismatch means the sentence pairs are no longer aligned. Fail early
    # rather than after the expensive evaluation.
    if len(vi_test) != len(lo_test):
        raise ValueError(
            f"Test files are not parallel: {len(vi_test)} Vietnamese vs "
            f"{len(lo_test)} Lao non-blank lines"
        )
    if not vi_test:
        raise ValueError("Test set is empty")
    test_dataset = TranslationDataset(vi_test, lo_test, vi_vocab, lao_vocab, vi_tokenize, lao_tokenize)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Compute test loss (mean of the per-batch mean losses, padding ignored).
    criterion = nn.CrossEntropyLoss(ignore_index=lao_vocab.word2idx[lao_vocab.pad_token])
    total_loss = 0.0
    batches = 0
    with torch.no_grad():
        for batch in test_loader:
            src = batch['source'].to(device)
            tgt = batch['target'].to(device)
            # Teacher forcing: the decoder sees tokens [0, n-1) and must
            # predict tokens [1, n).
            tgt_in = tgt[:, :-1]
            tgt_out = tgt[:, 1:]
            out = model(src, tgt_in)
            out_dim = out.shape[-1]
            loss = criterion(out.contiguous().view(-1, out_dim), tgt_out.contiguous().view(-1))
            total_loss += loss.item()
            batches += 1
    print(f"Test Loss: {total_loss/batches:.4f}")

    # Compute BLEU.
    # References must be tokenised the same way as the hypotheses: the model
    # emits lao_tokenize tokens that are joined with spaces and split again,
    # so a plain whitespace split of the raw reference would not produce
    # comparable n-grams.
    refs = [[' '.join(lao_tokenize(r)).split()] for r in lo_test]
    start = time.time()
    # Inference only: make sure no autograd graph is built while decoding.
    with torch.no_grad():
        hyps = [
            translate_sentence(model, sent, vi_vocab, lao_vocab, device, MAX_LENGTH, vi_tokenize).split()
            for sent in vi_test
        ]
    bleu = corpus_bleu(refs, hyps) * 100
    print(f"Corpus BLEU: {bleu:.2f}")
    print(f"Test time: {time.time() - start:.2f}s")