In [None]:
!pip uninstall -y torch torchtext torchaudio torchvision torchdata

!pip install torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/cu121

import os

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import warnings
from tqdm import tqdm
import os
import math
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from sklearn.model_selection import train_test_split
import torchtext.data.metrics

warnings.filterwarnings("ignore")

In [None]:
class InputEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab_size: number of rows in the lookup table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model, self.vocab_size = d_model, vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # sqrt(d_model) scaling as in the original Transformer paper.
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The table is precomputed for ``seq_len`` positions and registered as the buffer
    ``pe`` with shape (1, seq_len, d_model). Even feature columns hold sines and odd
    columns hold cosines.

    Args:
        d_model: feature width (odd values are supported).
        seq_len: maximum number of positions the module can encode.
        dropout: dropout probability applied after the addition.

    Raises (from ``forward``):
        ValueError: if the input's second dimension exceeds ``seq_len``.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than there are frequencies;
        # slice div_term so the shapes match (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # assumes x is (batch, seq, d_model), as produced by InputEmbeddings
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the positional-encoding "
                f"table size {max_len}; increase seq_len or truncate the input."
            )
        x = x + (self.pe[:, :seq_len, :]).requires_grad_(False)
        return self.dropout(x)

class LayerNormalization(nn.Module):
    """Normalize over the last dimension with a learnable gain ``alpha`` and ``bias``.

    Unlike ``nn.LayerNorm``, this divides by the (unbiased, torch default) standard
    deviation plus ``eps``, rather than by sqrt(variance + eps).

    Args:
        d_model: size of the last dimension (length of ``alpha`` / ``bias``).
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied to each position: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input/output width.
        d_ff: hidden width.
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        # Registration order (linear_1, dropout, linear_2) fixes parameters() order.
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``h`` heads.

    ``forward(q, k, v, mask)`` takes (batch, seq, d_model) tensors and projects them
    with ``w_q``/``w_k``/``w_v``. It attends per head and merges the heads back through
    ``w_o``. The attention weights of the most recent call are stored in
    ``self.attention_scores`` with shape (batch, h, q_len, k_len).

    Args:
        d_model: model width; must be divisible by ``h``.
        h: number of heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_model, self.h = d_model, h
        self.d_k = d_model // h
        # Creation order matters: it fixes RNG consumption at init and parameters() order.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(weights @ value, weights)`` for (batch, heads, seq, d_k) inputs.

        Scores at positions where ``mask == 0`` are set to the most negative finite
        value of their dtype before the softmax, so those positions get ~zero weight.
        """
        scale = math.sqrt(query.shape[-1])
        scores = query @ key.transpose(-2, -1) / scale
        if mask is not None:
            scores.masked_fill_(mask == 0, torch.finfo(scores.dtype).min)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ value, weights

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
        batch, seq = projected.shape[0], projected.shape[1]
        return projected.view(batch, seq, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))
        context, self.attention_scores = self.attention(query, key, value, mask, self.dropout)
        merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.h * self.d_k)
        return self.w_o(merged)

class ResidualConnection(nn.Module):
    """Post-norm residual wrapper: ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: width passed to the LayerNormalization.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.norm = LayerNormalization(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # sublayer is any callable mapping x to a tensor of the same shape.
        update = self.dropout(sublayer(x))
        return self.norm(x + update)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual wrapper.

    Args:
        self_attention_block: attention module called as ``(x, x, x, src_mask)``.
        feed_forward_block: position-wise MLP.
        d_model: width for the residual LayerNorms.
        dropout: dropout probability inside the residual wrappers.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, d_model: int, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(ResidualConnection(d_model, dropout) for _ in range(2))

    def forward(self, x, src_mask):
        def self_attend(t):
            return self.self_attention_block(t, t, t, src_mask)

        attn_residual, ff_residual = self.residual_connections
        x = attn_residual(x, self_attend)
        return ff_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
    """Stack of encoder blocks followed by a final LayerNormalization.

    The norm width is read from the first block's feed-forward input size, so
    ``layers`` must be non-empty.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        d_model = layers[0].feed_forward_block.linear_1.in_features
        self.layers = layers
        self.norm = LayerNormalization(d_model)

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then feed-forward.

    Each sublayer is wrapped in its own residual connection. Cross-attention takes
    queries from the decoder stream and keys/values from ``encoder_output``.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, d_model: int, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(ResidualConnection(d_model, dropout) for _ in range(3))

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        def self_attend(t):
            return self.self_attention_block(t, t, t, tgt_mask)

        def cross_attend(t):
            return self.cross_attention_block(t, encoder_output, encoder_output, src_mask)

        self_residual, cross_residual, ff_residual = self.residual_connections
        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return ff_residual(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Stack of decoder blocks followed by a final LayerNormalization.

    The norm width is read from the first block's feed-forward input size, so
    ``layers`` must be non-empty.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        d_model = layers[0].feed_forward_block.linear_1.in_features
        self.layers = layers
        self.norm = LayerNormalization(d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from model width to target-vocabulary logits (no softmax applied)."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits

class Transformer(nn.Module):
    """Encoder-decoder container exposing separate ``encode``/``decode``/``project`` steps.

    Keeping the three stages separate lets training run them in one pass, while
    greedy decoding reuses one encoder output across many decoder calls.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        # Keep this registration order: it fixes parameters() order, which
        # build_transformer iterates for initialisation and which saved optimizer
        # state is matched against by position.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    """Assemble an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size / tgt_vocab_size: embedding table sizes; the target size is also
            the output projection width.
        src_seq_len / tgt_seq_len: positional-encoding table lengths.
        d_model, N, h, dropout, d_ff: width, number of blocks per stack, attention heads,
            dropout probability, and feed-forward hidden width.

    Returns:
        The assembled ``Transformer``.

    Sub-modules are created in a fixed order (embeddings, encoder blocks, decoder
    blocks, projection), so construction consumes torch's RNG reproducibly.
    """
    src_embed, tgt_embed = InputEmbeddings(d_model, src_vocab_size), InputEmbeddings(d_model, tgt_vocab_size)
    src_pos, tgt_pos = PositionalEncoding(d_model, src_seq_len, dropout), PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder = Encoder(nn.ModuleList([
        EncoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            PositionwiseFeedForward(d_model, d_ff, dropout),
            d_model,
            dropout,
        )
        for _ in range(N)
    ]))
    decoder = Decoder(nn.ModuleList([
        DecoderBlock(
            MultiHeadAttention(d_model, h, dropout),  # masked self-attention
            MultiHeadAttention(d_model, h, dropout),  # cross-attention
            PositionwiseFeedForward(d_model, d_ff, dropout),
            d_model,
            dropout,
        )
        for _ in range(N)
    ]))

    transformer = Transformer(
        encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos,
        ProjectionLayer(d_model, tgt_vocab_size),
    )
    # Only matrices (dim > 1) are re-initialised; biases and LayerNorm vectors keep
    # their defaults.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return transformer

In [None]:
# Paths default to the Kaggle layout. They can be overridden through environment
# variables so the notebook runs outside Kaggle without code edits.
DATA_DIR_KAGGLE = os.environ.get("NMT_DATA_DIR", "/kaggle/input/medical-en-vi-corpus/corpus")
OUTPUT_DIR = os.environ.get("NMT_OUTPUT_DIR", "/kaggle/working/")
# Per-language tokenizer file; format with lang='en' / lang='vi'.
TOKENIZER_FILE_TPL = os.path.join(OUTPUT_DIR, "tokenizer_{lang}.json")

UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"
SOS_TOKEN = "[SOS]"
EOS_TOKEN = "[EOS]"
# Passed to BpeTrainer in this order. Changing the order changes the special-token
# ids, which would invalidate previously trained tokenizers and checkpoints.
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]

def train_tokenizer_if_not_exists(lang, vocab_size, data_iterator):
    """Return the BPE tokenizer for ``lang``, training and saving it on first use.

    Args:
        lang: language code substituted into ``TOKENIZER_FILE_TPL`` (e.g. 'en', 'vi').
        vocab_size: target vocabulary size passed to ``BpeTrainer``.
        data_iterator: iterable of raw sentences; consumed only when training.

    Returns:
        A ``Tokenizer`` loaded from the JSON file on disk.

    NOTE: an existing tokenizer file is reused as-is, even if it was trained with a
    different ``vocab_size`` or corpus. Delete the file to force retraining.
    """
    tokenizer_path = TOKENIZER_FILE_TPL.format(lang=lang)
    if not os.path.exists(tokenizer_path):
        # Create the target directory before the (slow) training run, so a bad path
        # fails immediately rather than after training.
        tokenizer_dir = os.path.dirname(tokenizer_path)
        if tokenizer_dir:
            os.makedirs(tokenizer_dir, exist_ok=True)
        print(f"Starting tokenizer for language '{lang}'...")
        tokenizer = Tokenizer(BPE(unk_token=UNK_TOKEN))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)
        tokenizer.train_from_iterator(data_iterator, trainer)
        tokenizer.save(tokenizer_path)
    else:
        print(f"Tokenizer for language '{lang}' existed.")
    return Tokenizer.from_file(tokenizer_path)

class TranslationDataset(Dataset):
    """Map-style dataset of (English, Vietnamese) sentence pairs as token-id tensors.

    Each item is ``{"src_ids": LongTensor, "tgt_ids": LongTensor}``, and both sequences
    are wrapped as ``[SOS] ... [EOS]`` using their own tokenizer's ids. No truncation
    or padding happens here; padding is done by ``Collate``.

    Args:
        data_pairs: indexable sequence of ``(en_text, vi_text)`` tuples.
        tokenizer_en / tokenizer_vi: source / target tokenizers.
    """

    def __init__(self, data_pairs, tokenizer_en, tokenizer_vi):
        self.data_pairs = data_pairs
        self.tokenizer_en = tokenizer_en
        self.tokenizer_vi = tokenizer_vi
        self.sos_token_id_en = tokenizer_en.token_to_id(SOS_TOKEN)
        self.eos_token_id_en = tokenizer_en.token_to_id(EOS_TOKEN)
        self.sos_token_id_vi = tokenizer_vi.token_to_id(SOS_TOKEN)
        self.eos_token_id_vi = tokenizer_vi.token_to_id(EOS_TOKEN)

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        en_text, vi_text = self.data_pairs[idx]
        src_ids = [self.sos_token_id_en, *self.tokenizer_en.encode(en_text).ids, self.eos_token_id_en]
        tgt_ids = [self.sos_token_id_vi, *self.tokenizer_vi.encode(vi_text).ids, self.eos_token_id_vi]
        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "tgt_ids": torch.tensor(tgt_ids, dtype=torch.long),
        }

class Collate:
    """DataLoader ``collate_fn`` that right-pads source and target ids to batch-first tensors.

    Args:
        pad_token_id: value used to pad both ``src_ids`` and ``tgt_ids``.
    """

    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        pad = torch.nn.utils.rnn.pad_sequence
        src_padded = pad([item["src_ids"] for item in batch], batch_first=True, padding_value=self.pad_token_id)
        tgt_padded = pad([item["tgt_ids"] for item in batch], batch_first=True, padding_value=self.pad_token_id)
        return {"src_ids": src_padded, "tgt_ids": tgt_padded}

In [None]:
# Cell 4: experiment configuration and DataLoader construction
def get_config():
    """Return the hyper-parameters and paths for this run as a plain dict.

    Notable entries:
      * batch_size x gradient_accumulation_steps = 4 x 8 = 32 sequences per optimizer step.
      * seq_len: length of the positional-encoding tables. It is also the greedy-decoding
        length cap in ``translate``.
      * preload: checkpoint tag to resume from ("03" -> ``<basename>03.pt``); ``None``
        trains from scratch.
      * patience: epochs without validation-loss improvement before early stopping.
    """
    return dict(
        batch_size=4,
        gradient_accumulation_steps=8,
        num_epochs=30,
        lr=1e-4,
        seq_len=900,
        d_model=256,
        num_blocks=6,
        num_heads=8,
        dropout=0.1,
        d_ff=1024,
        model_folder=os.path.join(OUTPUT_DIR, "weights"),
        tokenizer_folder=OUTPUT_DIR,
        model_basename="transformer_medical_v_kaggle_",
        preload="03",
        patience=3,
        vocab_size=30000,
    )

def get_dataloaders(config):
    """Load the parallel training corpus, split it, and build tokenizers and DataLoaders.

    Reads ``train.en.txt`` / ``train.vi.txt`` from ``DATA_DIR_KAGGLE``; the files are
    aligned by line number. Holds out 10% as validation with a fixed seed, then gets one
    BPE tokenizer per language, training it on the training split only when no saved
    file exists. Both splits are wrapped in ``TranslationDataset``.

    Args:
        config: dict from ``get_config``; reads ``vocab_size`` and ``batch_size``.

    Returns:
        ``(train_loader, val_loader, tokenizer_en, tokenizer_vi)``

    Raises:
        ValueError: if the two corpus files have different line counts, or if the two
            tokenizers disagree on the ``[PAD]`` id.
    """
    with open(os.path.join(DATA_DIR_KAGGLE, "train.en.txt"), 'r', encoding='utf-8') as f:
        train_en_lines = [line.strip() for line in f]
    with open(os.path.join(DATA_DIR_KAGGLE, "train.vi.txt"), 'r', encoding='utf-8') as f:
        train_vi_lines = [line.strip() for line in f]
    # zip() would silently drop the tail of the longer file, so refuse a count mismatch.
    if len(train_en_lines) != len(train_vi_lines):
        raise ValueError(
            f"train.en.txt has {len(train_en_lines)} lines but train.vi.txt has "
            f"{len(train_vi_lines)}; the parallel corpus must be line-aligned."
        )
    full_train_pairs = list(zip(train_en_lines, train_vi_lines))
    train_pairs, val_pairs = train_test_split(full_train_pairs, test_size=0.1, random_state=42)

    # When trained here, tokenizers see only the training split (no validation leakage).
    tokenizer_en = train_tokenizer_if_not_exists('en', config['vocab_size'], (pair[0] for pair in train_pairs))
    tokenizer_vi = train_tokenizer_if_not_exists('vi', config['vocab_size'], (pair[1] for pair in train_pairs))

    train_dataset = TranslationDataset(train_pairs, tokenizer_en, tokenizer_vi)
    val_dataset = TranslationDataset(val_pairs, tokenizer_en, tokenizer_vi)

    # Collate pads source AND target with this one id. The training loop masks and
    # ignores target padding using the Vietnamese [PAD] id, so the two ids must match.
    pad_token_id = tokenizer_en.token_to_id(PAD_TOKEN)
    if tokenizer_vi.token_to_id(PAD_TOKEN) != pad_token_id:
        raise ValueError(
            f"[PAD] id differs between tokenizers (en={pad_token_id}, "
            f"vi={tokenizer_vi.token_to_id(PAD_TOKEN)}); Collate needs a shared pad id."
        )
    collate_fn = Collate(pad_token_id)

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, tokenizer_en, tokenizer_vi

# Build the run configuration once; the training and evaluation cells read it.
config = get_config()
(train_loader, val_loader,
 tokenizer_en, tokenizer_vi) = get_dataloaders(config)

In [None]:
# Cell 5: training loop (mixed precision, gradient accumulation, checkpointing, early stopping)
def _build_masks(encoder_input, decoder_input, pad_id, device):
    """Return ``(encoder_mask, decoder_mask)`` for batch-first id tensors.

    ``encoder_mask`` is (batch, 1, 1, src_len): True where the source is not padding.
    ``decoder_mask`` combines the target padding mask with a lower-triangular causal
    mask, giving (batch, 1, tgt_len, tgt_len). True marks positions that may be
    attended to.
    """
    encoder_mask = (encoder_input != pad_id).unsqueeze(1).unsqueeze(2)
    decoder_pad_mask = (decoder_input != pad_id).unsqueeze(1).unsqueeze(2)
    tgt_len = decoder_input.size(1)
    decoder_causal_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=device)).bool()
    return encoder_mask, decoder_pad_mask & decoder_causal_mask


def _batch_loss(model, batch, loss_fn, pad_id, vocab_size, device, use_amp):
    """Run one teacher-forced forward pass on ``batch`` and return the scalar loss.

    The decoder input is ``tgt_ids[:, :-1]`` and the labels are ``tgt_ids[:, 1:]``
    (next-token prediction). The forward pass runs under float16 autocast only when
    ``use_amp`` is set.
    """
    encoder_input = batch['src_ids'].to(device)
    decoder_input = batch['tgt_ids'][:, :-1].to(device)
    label = batch['tgt_ids'][:, 1:].to(device)
    encoder_mask, decoder_mask = _build_masks(encoder_input, decoder_input, pad_id, device)
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        encoder_output = model.encode(encoder_input, encoder_mask)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
        proj_output = model.project(decoder_output)
        return loss_fn(proj_output.reshape(-1, vocab_size), label.reshape(-1))


def _save_checkpoint(path, epoch, model, optimizer, scaler, validation_loss):
    """Write epoch, model/optimizer/scaler state and validation loss to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'validation_loss': validation_loss,
    }, path)


def train_model(config, model, train_loader, val_loader, tokenizer_vi):
    """Train ``model`` with teacher forcing, checkpointing every epoch, with early stopping.

    Args:
        config: dict from ``get_config``. Reads ``model_folder``, ``lr``, ``preload``,
            ``num_epochs``, ``gradient_accumulation_steps``, ``model_basename`` and
            ``patience``. Optional ``preload_folder`` / ``preload_basename`` override
            where the resume checkpoint is looked up.
        model: a ``Transformer`` exposing ``encode`` / ``decode`` / ``project``.
        train_loader / val_loader: yield dicts with padded ``src_ids`` / ``tgt_ids``.
        tokenizer_vi: target tokenizer; supplies the pad id and the vocabulary size.

    Side effects: writes ``<model_basename><epoch>.pt`` every epoch, plus
    ``best_model.pt`` whenever validation loss improves, to ``config['model_folder']``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Mixed precision (autocast + loss scaling) only on CUDA; both become no-ops otherwise.
    use_amp = device.type == "cuda"

    os.makedirs(config['model_folder'], exist_ok=True)

    model.to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params / 1e6:.2f}M")

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    pad_id = tokenizer_vi.token_to_id(PAD_TOKEN)
    vocab_size = tokenizer_vi.get_vocab_size()
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1).to(device)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    accum_steps = config['gradient_accumulation_steps']

    initial_epoch = 0
    best_val_loss = float('inf')

    if config['preload'] is not None:
        resume_folder = config.get('preload_folder', "/kaggle/input/nlpmidterm-checkpoints-run1/weights")
        old_model_basename = config.get('preload_basename', "transformer_medical_v_kaggle_")

        model_filename = os.path.join(resume_folder, f"{old_model_basename}{config['preload']}.pt")
        if os.path.exists(model_filename):
            print(f"Loading model from: {model_filename}")
            # map_location lets a checkpoint written on GPU load on a CPU-only machine.
            state = torch.load(model_filename, map_location=device)
            model.load_state_dict(state['model_state_dict'])
            initial_epoch = state['epoch'] + 1
            optimizer.load_state_dict(state['optimizer_state_dict'])
            # Older checkpoints carry no scaler state; a disabled scaler saves an empty dict.
            if use_amp and state.get('scaler_state_dict'):
                scaler.load_state_dict(state['scaler_state_dict'])
            if 'validation_loss' in state:
                best_val_loss = state['validation_loss']
        else:
            print("Checkpoint not found, starting training from scratch.")

    epochs_no_improve = 0
    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        train_loss_acc = 0
        train_batch_count = 0
        num_batches = len(train_loader)
        batch_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch:02d}", leave=False)
        for i, batch in enumerate(batch_iterator):
            loss = _batch_loss(model, batch, loss_fn, pad_id, vocab_size, device, use_amp)
            train_loss_acc += loss.item()
            train_batch_count += 1
            scaler.scale(loss / accum_steps).backward()
            # Step at every accumulation boundary AND on the epoch's last batch, so a
            # trailing partial group is neither dropped nor carried into the next epoch.
            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            batch_iterator.set_postfix({"loss": f"{loss.item():6.4f}"})
        avg_train_loss = train_loss_acc / max(train_batch_count, 1)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            val_iterator = tqdm(val_loader, desc=f"Validating Epoch {epoch:02d}", leave=False)
            for batch in val_iterator:
                loss = _batch_loss(model, batch, loss_fn, pad_id, vocab_size, device, use_amp)
                val_loss += loss.item()
        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"--- Epoch {epoch:02d} Summary ---")
        print(f"Average Training Loss: {avg_train_loss:.4f} | Average Validation Loss: {avg_val_loss:.4f}")

        epoch_model_filename = os.path.join(config['model_folder'], f"{config['model_basename']}{epoch:02d}.pt")
        _save_checkpoint(epoch_model_filename, epoch, model, optimizer, scaler, avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_filename = os.path.join(config['model_folder'], "best_model.pt")
            _save_checkpoint(best_model_filename, epoch, model, optimizer, scaler, best_val_loss)
            print(f"Validation loss improved! Saved best model to {best_model_filename}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss. {epochs_no_improve}/{config['patience']} epochs.")
        if epochs_no_improve >= config['patience']:
            print(f"Early stopping at epoch {epoch} due to no improvement in validation loss.")
            break


# Seed torch's global RNG so weight initialisation (build_transformer) and the training
# DataLoader's shuffle order are reproducible across kernel restarts. `config` may
# carry an optional "seed" entry; the default 42 matches the train/val split seed.
torch.manual_seed(config.get("seed", 42))
model = build_transformer(
    tokenizer_en.get_vocab_size(), tokenizer_vi.get_vocab_size(), config['seq_len'], config['seq_len'],
    d_model=config['d_model'], N=config['num_blocks'], h=config['num_heads'],
    dropout=config['dropout'], d_ff=config['d_ff']
)
train_model(config, model, train_loader, val_loader, tokenizer_vi)

In [None]:
def translate(model: nn.Module, sentence: str, tokenizer_en: Tokenizer, tokenizer_vi: Tokenizer, device: torch.device, config):
    """Greedily translate one English sentence into Vietnamese.

    The source is wrapped as ``[SOS] tokens [EOS]`` and encoded once. The decoder then
    starts from ``[SOS]`` and appends the arg-max next token each step, until it emits
    ``[EOS]`` or the target reaches ``config['seq_len']`` tokens.

    Args:
        model: a ``Transformer`` (``encode``/``decode``/``project``) already on ``device``.
        sentence: raw English text (may be empty).
        tokenizer_en / tokenizer_vi: source / target tokenizers.
        device: device to run on.
        config: reads ``seq_len``.

    Returns:
        The decoded Vietnamese string. The leading ``[SOS]`` and everything from the
        first ``[EOS]`` on are dropped before decoding.

    NOTE(review): no decoder is configured on the tokenizers in
    ``train_tokenizer_if_not_exists``. Confirm that ``decode`` output does not come out
    as space-separated subword pieces, which would depress the whitespace-split BLEU.
    """
    model.eval()
    max_len = config['seq_len']
    tgt_eos_id = tokenizer_vi.token_to_id("[EOS]")

    # Reserve two slots for [SOS]/[EOS]. The positional-encoding table covers only
    # seq_len positions, so a longer source would fail inside model.encode.
    src_tokens = tokenizer_en.encode(sentence).ids[: max(max_len - 2, 0)]
    src_ids = [tokenizer_en.token_to_id("[SOS]"), *src_tokens, tokenizer_en.token_to_id("[EOS]")]
    # Explicit dtype: with an empty sentence, torch.tensor([]) would default to float32.
    src = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_mask = (src != tokenizer_en.token_to_id("[PAD]")).unsqueeze(1).unsqueeze(2)

    # Pure inference: keep autograd off for every step, including the projection.
    with torch.no_grad():
        encoder_output = model.encode(src, src_mask)
        decoder_input = torch.full((1, 1), tokenizer_vi.token_to_id("[SOS]"), dtype=torch.long, device=device)
        while decoder_input.size(1) < max_len:
            tgt_len = decoder_input.size(1)
            decoder_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=device)).bool()
            out = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)
            next_word = model.project(out[:, -1]).argmax(dim=1, keepdim=True)
            decoder_input = torch.cat([decoder_input, next_word], dim=1)
            if next_word.item() == tgt_eos_id:
                break

    output_ids = decoder_input[0, 1:].tolist()
    if tgt_eos_id in output_ids:
        output_ids = output_ids[: output_ids.index(tgt_eos_id)]
    return tokenizer_vi.decode(output_ids)

def calculate_bleu(candidates, references):
    """Compute and print corpus BLEU of ``candidates`` against single ``references``.

    Both sides are tokenised by whitespace ``split()``, and each candidate gets exactly
    one reference. The score is printed multiplied by 100 and returned unscaled, as
    produced by ``torchtext.data.metrics.bleu_score``.

    Args:
        candidates: list of hypothesis strings.
        references: list of reference strings, aligned with ``candidates``.
    """
    print("Calculating BLEU score...")
    candidate_corpus = [candidate.split() for candidate in candidates]
    reference_corpus = [[reference.split()] for reference in references]
    bleu = torchtext.data.metrics.bleu_score(candidate_corpus, reference_corpus)
    banner = "=============================================="
    print(banner)
    print(f"  FINAL BLEU SCORE: {bleu * 100:.2f}")
    print(banner)
    return bleu

def evaluate_model(config):
    """Translate the test split with ``best_model.pt`` and report corpus BLEU.

    Reloads both tokenizers from ``TOKENIZER_FILE_TPL`` and rebuilds the model from
    ``config``. Greedily translates every line of ``test.en.txt``, writes the hypotheses
    to ``OUTPUT_DIR/test_outputs.txt``, and scores them against ``test.vi.txt``.

    Args:
        config: dict from ``get_config`` (model hyper-parameters, ``seq_len``,
            ``model_folder``).

    Returns:
        The value returned by ``calculate_bleu`` (printed there multiplied by 100).

    Raises:
        ValueError: if the test source and reference files have different line counts.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer_en = Tokenizer.from_file(TOKENIZER_FILE_TPL.format(lang='en'))
    tokenizer_vi = Tokenizer.from_file(TOKENIZER_FILE_TPL.format(lang='vi'))

    model = build_transformer(
        tokenizer_en.get_vocab_size(), tokenizer_vi.get_vocab_size(), config['seq_len'], config['seq_len'],
        d_model=config['d_model'], N=config['num_blocks'], h=config['num_heads'],
        dropout=config['dropout'], d_ff=config['d_ff']
    ).to(device)
    model_filename = os.path.join(config['model_folder'], "best_model.pt")
    print(f"Loading best model from: {model_filename}")
    # map_location lets a checkpoint written on GPU load on a CPU-only machine.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    with open(os.path.join(DATA_DIR_KAGGLE, "test.en.txt"), 'r', encoding='utf-8') as f:
        source_sentences = [line.strip() for line in f]
    with open(os.path.join(DATA_DIR_KAGGLE, "test.vi.txt"), 'r', encoding='utf-8') as f:
        reference_sentences = [line.strip() for line in f]
    # Check alignment before the slow translation loop rather than after it.
    if len(source_sentences) != len(reference_sentences):
        raise ValueError(
            f"test.en.txt has {len(source_sentences)} lines but test.vi.txt has "
            f"{len(reference_sentences)}; the test set must be line-aligned."
        )

    translated_sentences = []
    print(f"Starting translation of {len(source_sentences)} sentences in the test set...")
    for sentence in tqdm(source_sentences, desc="Translating"):
        translated_sentences.append(translate(model, sentence, tokenizer_en, tokenizer_vi, device, config))

    output_filename = os.path.join(OUTPUT_DIR, "test_outputs.txt")
    with open(output_filename, 'w', encoding='utf-8') as f:
        for sentence in translated_sentences:
            f.write(sentence + '\n')
    print(f"Saved translation results to file: {output_filename}")

    return calculate_bleu(translated_sentences, reference_sentences)

# Translate the held-out test set with the best checkpoint and report corpus BLEU.
evaluate_model(config=config)