In [None]:
!pip install -q evaluate
import matplotlib.pyplot as plt
import seaborn as sns
import random
import numpy as np
import pandas as pd

import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from torchsummary import summary

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from tqdm import tqdm
import evaluate

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed handed to PyTorch (CPU generator, the current CUDA
            device and all CUDA devices), NumPy's legacy global generator and
            Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)

In [None]:
ds = load_dataset("thainq107/iwslt2015-en-vi")
train_data, valid_data, test_data = ds["train"], ds["validation"], ds["test"]

# Sentence lengths measured in *characters* (len() of the raw strings), pooled
# over both languages of the training split.
en_lengths = [len(sentence) for sentence in train_data["en"]]
vi_lengths = [len(sentence) for sentence in train_data["vi"]]
combined_lengths = en_lengths + vi_lengths

# Quartiles and inter-quartile range of the pooled character lengths.
Q1, Q3 = np.percentile(combined_lengths, [25, 75])
IQR = Q3 - Q1

# Upper Tukey fence of the character-length distribution. Kept for reference
# only: it is a character count, whereas MAX_LENGTH below is applied to token
# id sequences, so the two are not interchangeable.
IQR_UPPER_FENCE = int(Q3 + 1.5 * IQR)

# Fixed sequence length (in tokens) used when padding/truncating batches.
MAX_LENGTH = 50

In [None]:
# Special tokens shared by the English and Vietnamese vocabularies
unk_token = "[UNK]"
pad_token = "[PAD]"
mask_token = "[MASK]"  # experimental: defined here but not added to special_tokens
sos_token = "<sos>"
eos_token = "<eos>"
special_tokens = [unk_token, pad_token, sos_token, eos_token]


def _train_bpe_tokenizer(sentences):
    """Train a whitespace-pre-tokenized BPE tokenizer on `sentences`.

    Returns the trained tokenizer together with the trainer that was used,
    so both remain available at notebook level.
    """
    tokenizer = Tokenizer(BPE(unk_token=unk_token))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(special_tokens=special_tokens)
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    return tokenizer, trainer


# One independent vocabulary per language, each fitted on the training split
en_tokenizer, en_trainer = _train_bpe_tokenizer(train_data["en"])
vi_tokenizer, vi_trainer = _train_bpe_tokenizer(train_data["vi"])

# Reverse lookup (id -> token) used when turning predicted ids back into text
def id_to_token(tokenizer):
    """Return the inverse of `tokenizer.get_vocab()`, i.e. a dict id -> token."""
    token_to_id = tokenizer.get_vocab()
    return dict(zip(token_to_id.values(), token_to_id.keys()))
    
# id -> token tables for both languages, used when decoding model output
en_id_to_token, vi_id_to_token = (
    id_to_token(trained_tokenizer) for trained_tokenizer in (en_tokenizer, vi_tokenizer)
)

In [None]:
# Vocabulary sizes are read from the trained tokenizers.
VOCAB_SIZE = len(en_tokenizer.get_vocab())   # number of entries in the English (source) vocabulary
OUTPUT_DIM = len(vi_tokenizer.get_vocab())   # number of entries in the Vietnamese (target) vocabulary
EMBEDDING_DIM = 256                          # embedding width passed to Encoder and Decoder
BATCH_SIZE = 32                              # batch size passed to every DataLoader
HIDDEN_SIZE = 512                            # LSTM hidden size passed to Encoder and Decoder
N_LAYERS = 2                                 # number of stacked LSTM layers
ENCODER_DROPOUT = 0.2                        # dropout probability inside the encoder
DECODER_DROPOUT = 0.2                        # dropout probability inside the decoder
BIDIRECTIONAL = False                        # whether the LSTMs are built bidirectional
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def tokenize(data, tokenizer, max_length=MAX_LENGTH):
    """Encode one sentence as token ids wrapped in <sos> ... <eos>.

    Args:
        data: Raw sentence string.
        tokenizer: Trained tokenizer exposing ``encode`` and ``token_to_id``.
        max_length: Upper bound on the returned length *including* the two
            boundary tokens. Sentences that are too long are cut before
            <eos> is appended, so the end marker is never lost. ``None``
            disables truncation.

    Returns:
        list[int]: ``[<sos>] + ids + [<eos>]``.
    """
    body_ids = tokenizer.encode(data).ids

    if max_length is not None:
        # Reserve two slots for <sos>/<eos>; never slice with a negative bound.
        body_ids = body_ids[:max(max_length - 2, 0)]

    return [tokenizer.token_to_id(sos_token)] + body_ids + [tokenizer.token_to_id(eos_token)]

def tokenize_and_numericalize(data, src_tokenizer, trg_tokenizer, max_length=MAX_LENGTH):
    """Encode one parallel example into source/target id sequences.

    Args:
        data: Mapping with the raw sentences under the keys ``"en"`` and ``"vi"``.
        src_tokenizer: Tokenizer applied to ``data["en"]``.
        trg_tokenizer: Tokenizer applied to ``data["vi"]``.
        max_length: Forwarded to ``tokenize`` for both sides (previously it was
            accepted here but silently dropped).

    Returns:
        dict with ``"src_ids"`` and ``"trg_ids"`` lists of token ids.
    """
    return {
        "src_ids": tokenize(data["en"], src_tokenizer, max_length),
        "trg_ids": tokenize(data["vi"], trg_tokenizer, max_length),
    }

def _encode_example(example):
    """Add `src_ids` / `trg_ids` to one example using the EN and VI tokenizers."""
    return tokenize_and_numericalize(example, en_tokenizer, vi_tokenizer)


# Apply the same encoding to all three splits (train, validation, test)
train_data, valid_data, test_data = (
    split.map(_encode_example) for split in (train_data, valid_data, test_data)
)

In [None]:
def get_collate_fn(src_pad_index, trg_pad_index):
    """Build a collate function that emits fixed-width batches.

    Every sequence is brought to exactly MAX_LENGTH ids: longer ones are cut
    on the right, shorter ones are right-padded with the matching pad index.

    Args:
        src_pad_index: Pad id used for the ``"src_ids"`` sequences.
        trg_pad_index: Pad id used for the ``"trg_ids"`` sequences.

    Returns:
        A function mapping a list of examples to a dict of two stacked tensors
        (``"src_ids"`` and ``"trg_ids"``), one row per example.
    """
    def _fit_to_max_length(ids, pad_index):
        # Clip what is too long, pad what is too short, pass exact fits through
        sequence = torch.tensor(ids)
        n_tokens = len(sequence)
        if n_tokens > MAX_LENGTH:
            return sequence[:MAX_LENGTH]
        if n_tokens < MAX_LENGTH:
            filler = torch.full((MAX_LENGTH - n_tokens,), pad_index)
            return torch.cat([sequence, filler])
        return sequence

    def collate_fn(batch):
        src_rows = [_fit_to_max_length(example["src_ids"], src_pad_index) for example in batch]
        trg_rows = [_fit_to_max_length(example["trg_ids"], trg_pad_index) for example in batch]
        return {"src_ids": torch.stack(src_rows), "trg_ids": torch.stack(trg_rows)}

    return collate_fn


def get_data_loader(dataset, batch_size, src_pad_index, trg_pad_index, shuffle=False):
    """Wrap `dataset` in a DataLoader that pads/truncates through `get_collate_fn`.

    Args:
        dataset: Dataset whose examples carry ``"src_ids"`` and ``"trg_ids"``.
        batch_size: Number of examples per batch.
        src_pad_index: Pad id for source sequences.
        trg_pad_index: Pad id for target sequences.
        shuffle: Whether the loader reshuffles the data every epoch.
    """
    loader_options = {
        "batch_size": batch_size,
        "collate_fn": get_collate_fn(src_pad_index, trg_pad_index),
        "shuffle": shuffle,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


# Pad ids are looked up separately in each language's vocabulary
src_pad_index, trg_pad_index = (
    trained_tokenizer.token_to_id(pad_token) for trained_tokenizer in (en_tokenizer, vi_tokenizer)
)
_pad_indices = (src_pad_index, trg_pad_index)

# Only the training loader shuffles; validation and test keep dataset order
train_data_loader = get_data_loader(train_data, BATCH_SIZE, *_pad_indices, shuffle=True)
valid_data_loader = get_data_loader(valid_data, BATCH_SIZE, *_pad_indices)
test_data_loader = get_data_loader(test_data, BATCH_SIZE, *_pad_indices)

In [None]:
# for batch in train_data_loader:
#     print(batch["trg_ids"])
#     break

In [None]:
class Encoder(nn.Module):
    """LSTM encoder producing per-token states and a summary hidden state.

    Args:
        vocab_size: Number of rows in the source embedding table.
        embedding_dim: Width of the token embeddings.
        hidden_size: LSTM hidden size; also the width of both returned tensors.
        n_layers: Number of stacked LSTM layers.
        dropout_p: Dropout applied to the embeddings and between LSTM layers.
        bidirectional: Build a bidirectional LSTM and project its outputs
            back down to ``hidden_size``.
        device: Stored on the instance for callers that read it. Inputs are
            moved to wherever the module's weights actually live, so the
            encoder also works when this value does not match.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, n_layers=2, dropout_p=0.1, bidirectional=False, device='cuda'):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers=n_layers, batch_first=True, bidirectional=bidirectional, dropout=dropout_p)
        self.dropout = nn.Dropout(dropout_p)
        # Only used in the bidirectional case, to merge both directions
        self.fc = nn.Linear(hidden_size * 2 if bidirectional else hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

        self.device = device
        self.bidirectional = bidirectional

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: Tensor of token ids, [batch_size, seq_len].

        Returns:
            out: [batch_size, seq_len, hidden_size] layer-normalized states.
            hidden: [batch_size, hidden_size] final hidden state of the last
                layer (sum of both directions when bidirectional).
        """
        # Follow the weights instead of a possibly stale `self.device` string
        x = x.to(self.embedding.weight.device)
        embedding = self.dropout(self.embedding(x))

        out, (hidden, cell) = self.lstm(embedding)

        if self.bidirectional:
            hidden = hidden[-2, :, :] + hidden[-1, :, :]
            # Project 2*hidden_size -> hidden_size *before* normalizing;
            # LayerNorm(hidden_size) cannot accept the raw 2*hidden_size output.
            out = self.fc(out)
        else:
            hidden = hidden[-1, :, :]

        out = self.layer_norm(out)
        return out, hidden


In [None]:
class BahdanauAttention(nn.Module):
    """Additive attention: score = Va(tanh(Wa(query) + Ua(keys)))."""

    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)  # projects the query (decoder state)
        self.Ua = nn.Linear(hidden_size, hidden_size)  # projects the keys (encoder outputs)
        self.Va = nn.Linear(hidden_size, 1)            # reduces each position to one score

    def forward(self, keys, query):
        """Attend over `keys` with `query`.

        Args:
            keys: [batch_size, seq_len, hidden_size]
            query: [batch_size, hidden_size]

        Returns:
            context_vector: [batch_size, hidden_size]
            attention_weights: [batch_size, seq_len], each row sums to 1.
        """
        # [batch, 1, hidden] broadcast against [batch, seq_len, hidden]
        projected_query = self.Wa(query)[:, None, :]
        projected_keys = self.Ua(keys)

        # One scalar energy per source position -> [batch, seq_len]
        energies = self.Va(torch.tanh(projected_query + projected_keys)).squeeze(2)
        attention_weights = F.softmax(energies, dim=-1)

        # Weighted sum of the (unprojected) keys -> [batch, hidden]
        context_vector = torch.bmm(attention_weights.unsqueeze(1), keys).squeeze(1)
        return context_vector, attention_weights

In [None]:
class LuongAttention(nn.Module):
    """Scaled dot-product attention between a decoder state and encoder outputs.

    `self.attn` is registered as a layer, but only its `in_features` is read
    (as the scaling dimension); its weights are never applied in `forward`.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, hidden_size)

    def forward(self, encoder_outputs, hidden):
        """Attend over `encoder_outputs` with `hidden`.

        Args:
            encoder_outputs: [batch_size, seq_len, hidden_size]
            hidden: [batch_size, hidden_size]

        Returns:
            context: [batch_size, hidden_size]
            attention_weights: [batch_size, seq_len], each row sums to 1.
        """
        scale = self.attn.in_features ** 0.5

        # Column vector per example so bmm yields one dot product per position
        query = hidden.unsqueeze(2)                                          # [batch, hidden, 1]
        attn_energies = (torch.bmm(encoder_outputs, query) / scale).squeeze(2)  # [batch, seq_len]

        attention_weights = torch.softmax(attn_energies, dim=1)

        # Weighted sum of encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attention_weights


In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder with Luong or Bahdanau attention.

    Each call consumes one token per example, runs it through the LSTM,
    attends over the encoder outputs with the resulting hidden state and
    returns vocabulary logits for the next token.

    Args:
        hidden_size: LSTM hidden size; must match the encoder output width.
        vocab_size: Size of the target vocabulary (embedding rows and logits).
        embedding_dim: Width of the target token embeddings.
        attention_type: ``'Luong'`` or ``'Bahdanau'``.
        dropout_p: Dropout applied to embeddings and between LSTM layers.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Build the LSTM bidirectional.

    Raises:
        ValueError: If ``attention_type`` is not one of the two known names.
    """

    def __init__(self, hidden_size, vocab_size, embedding_dim, attention_type, dropout_p=0.1, n_layers=1, bidirectional=False):
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1

        if attention_type == 'Luong':
            self.attention = LuongAttention(hidden_size)
        elif attention_type == 'Bahdanau':
            self.attention = BahdanauAttention(hidden_size)
        else:
            raise ValueError("Unknown attention type")

        # Embedding layer
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)

        # Dropout layer
        self.dropout = nn.Dropout(dropout_p)

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout_p
        )

        # Width of one LSTM output step (both directions when bidirectional)
        lstm_features = hidden_size * self.n_directions

        # Combines [lstm_output ; context] into a single hidden_size vector
        self.fc = nn.Linear(lstm_features + hidden_size, hidden_size)

        # Output layer
        self.out = nn.Linear(self.hidden_size, self.vocab_size)

        # Batch normalization layers; bn_fc must match the concatenation that
        # feeds `fc` (identical to hidden_size * 2 in the unidirectional case).
        self.bn_lstm = nn.BatchNorm1d(lstm_features)
        self.bn_fc = nn.BatchNorm1d(lstm_features + hidden_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: [batch_size] token ids fed at this step.
            hidden: [batch_size, hidden_size] state from the encoder or the
                previous step; a full [n_layers * n_directions, batch_size,
                hidden_size] tensor is also accepted and used as is.
            encoder_outputs: [batch_size, seq_len, hidden_size]

        Returns:
            prediction: [batch_size, vocab_size] unnormalized logits.
            hidden: [batch_size, hidden_size] state to pass to the next step.
            attention_weights: [batch_size, seq_len]
        """
        # Follow the module's own weights rather than a notebook-level global
        device = self.embedding.weight.device

        input = input.unsqueeze(1).to(device)
        embedded = self.dropout(self.embedding(input))  # [batch_size, 1, embedding_dim]

        hidden = hidden.to(device)
        if hidden.dim() == 2:
            # Share the single summary vector across every layer/direction
            hidden = hidden.unsqueeze(0).repeat(self.n_layers * self.n_directions, 1, 1)

        # nn.LSTM requires an (h_0, c_0) pair; a bare tensor is rejected.
        # The step interface only carries one hidden tensor, so the cell
        # state starts from zeros at every step.
        h_0 = hidden.contiguous()
        c_0 = torch.zeros_like(h_0)

        lstm_output, (hidden, _) = self.lstm(embedded, (h_0, c_0))  # [batch_size, 1, lstm_features]

        lstm_output = self.bn_lstm(lstm_output.squeeze(1))  # [batch_size, lstm_features]

        if self.lstm.bidirectional:
            # Merge the last layer's two directions into one query vector
            hidden = hidden[-2, :, :] + hidden[-1, :, :]
        else:
            hidden = hidden[-1, :, :]

        context, attention_weights = self.attention(encoder_outputs, hidden)

        output = torch.cat([lstm_output, context], dim=1)
        output = self.bn_fc(output)

        output = self.fc(output)  # Use linear layer to combine them

        prediction = self.out(output)  # [batch_size, vocab_size]

        return prediction, hidden, attention_weights

In [None]:
# for batch in train_data_loader:
#     t = batch
#     break

In [None]:
# encoder = Encoder(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, N_LAYERS, ENCODER_DROPOUT, BIDIRECTIONAL, DEVICE).to(DEVICE)
# out, hidden = encoder(t["src_ids"])
# out.shape, hidden.shape

In [None]:
# decoder = Decoder(HIDDEN_SIZE, VOCAB_SIZE, EMBEDDING_DIM, "Luong", DECODER_DROPOUT, N_LAYERS, BIDIRECTIONAL).to(DEVICE)
# prediction, hidden, att = decoder(t["trg_ids"][:, 0], hidden, out)
# prediction.shape, hidden.shape, att.shape

In [None]:
# seq2seq = Seq2Seq(encoder, decoder)
# output = seq2seq(t["src_ids"], t["trg_ids"])
# output.shape

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing.

    Args:
        encoder: Module mapping source ids to ``(encoder_outputs, hidden)``.
        decoder: Module exposing ``vocab_size`` and mapping
            ``(input, hidden, encoder_outputs)`` to
            ``(logits, hidden, attention_weights)``.
    """

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode `target.size(1) - 1` steps for every example in the batch.

        Args:
            source: [batch_size, src_len] source token ids.
            target: [batch_size, trg_len] target token ids; column 0 is the
                start token fed at the first step.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the
                model's own argmax.

        Returns:
            [batch_size, trg_len, decoder.vocab_size] logits. Position 0 is
            never predicted and stays all zeros.
        """
        batch_size = source.size(0)
        trg_len = target.size(1)

        encoder_outputs, hidden = self.encoder(source)

        # Size and place the buffer from the modules actually in use instead
        # of notebook-level globals, so it always matches the decoder logits.
        device = encoder_outputs.device
        outputs = torch.zeros(batch_size, trg_len, self.decoder.vocab_size, device=device)
        target = target.to(device)

        input = target[:, 0]  # SOS

        for t in range(1, trg_len):
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[:, t, :] = output

            teacher_force = torch.rand(1, device=device).item() < teacher_forcing_ratio
            input = target[:, t] if teacher_force else output.argmax(1)

        return outputs

In [None]:
ATTENTION_TYPE = "Bahdanau"

# The encoder embeds English ids (VOCAB_SIZE); the decoder embeds and predicts
# Vietnamese ids, so it must be sized with OUTPUT_DIM, not the source vocabulary.
encoder = Encoder(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, N_LAYERS, ENCODER_DROPOUT, BIDIRECTIONAL, DEVICE)
decoder = Decoder(HIDDEN_SIZE, OUTPUT_DIM, EMBEDDING_DIM, ATTENTION_TYPE, DECODER_DROPOUT, N_LAYERS, BIDIRECTIONAL)

# Moving the wrapper moves both sub-modules, since they are registered on it
seq2seq = Seq2Seq(encoder, decoder).to(DEVICE)
seq2seq

In [None]:
def plot_metrics(train_losses, val_losses, bleu_scores):
    """Draw the loss curves and the validation BLEU curve side by side.

    Args:
        train_losses: Per-epoch training losses; its length sets the x-axis.
        val_losses: Per-epoch validation losses.
        bleu_scores: Per-epoch validation BLEU scores.
    """
    epoch_axis = range(len(train_losses))

    fig, (loss_ax, bleu_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: training vs. validation loss
    loss_ax.plot(epoch_axis, train_losses, label='Train Loss', color='blue')
    loss_ax.plot(epoch_axis, val_losses, label='Validation Loss', color='red')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    # Right panel: validation BLEU
    bleu_ax.plot(epoch_axis, bleu_scores, label='Validation BLEU Score', color='green')
    bleu_ax.set_xlabel('Epochs')
    bleu_ax.set_ylabel('BLEU Score')
    bleu_ax.set_title('Validation BLEU Score over Epochs')
    bleu_ax.legend()

    fig.tight_layout()
    plt.show()


def compute_bleu(predictions, targets):
    """Corpus-level BLEU of `predictions` against one reference each.

    Args:
        predictions (list of list of int): Generated token-id sequences.
        targets (list of list of int): Reference token-id sequences, aligned
            with `predictions` by position.

    Returns:
        float: BLEU score, smoothed with NLTK's ``method4``.
    """
    # corpus_bleu expects a *list of references* per hypothesis; here each
    # hypothesis has exactly one reference.
    references = [[reference] for reference in targets]

    smoother = SmoothingFunction()
    return corpus_bleu(references, predictions, smoothing_function=smoother.method4)



In [None]:
def train_fn(model, train_loader, optimizer, criterion, clip, teacher_forcing_ratio=0.5, device='cuda'):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: Callable as ``model(source, target, teacher_forcing_ratio)``
            returning logits shaped [batch_size, trg_len, vocab_size].
        train_loader: Iterable of dicts with ``'src_ids'`` and ``'trg_ids'``
            tensors shaped [batch_size, seq_len].
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking ``(logits [N, vocab_size], labels [N])``.
        clip: Max gradient norm; a falsy value disables clipping.
        teacher_forcing_ratio: Forwarded to the model.
        device: Device the batch tensors are moved to.

    Returns:
        float: Loss averaged over the batches of `train_loader`.
    """
    model.train()  # Set model to training mode
    epoch_train_loss = 0

    for batch in train_loader:
        source = batch['src_ids'].to(device)
        target = batch['trg_ids'].to(device)

        optimizer.zero_grad()  # Clear previous gradients

        # Forward pass
        output = model(source, target, teacher_forcing_ratio)

        # Tensors are batch-first, so the never-predicted start position must
        # be dropped along the *time* axis (dim 1); slicing dim 0 would drop
        # the first example instead. The slices are non-contiguous, hence
        # reshape rather than view.
        logits = output[:, 1:].reshape(-1, output.size(-1))
        labels = target[:, 1:].reshape(-1)
        loss = criterion(logits, labels)

        loss.backward()  # Backpropagation

        # Gradient clipping to prevent exploding gradients
        if clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()  # Update the model parameters

        epoch_train_loss += loss.item()

    # Calculate average training loss for the epoch
    avg_train_loss = epoch_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")
    return avg_train_loss

def evaluate_fn(model, val_loader, criterion, device='cuda'):
    """Evaluate the model without teacher forcing.

    Args:
        model: Callable as ``model(source, target, teacher_forcing_ratio=0)``
            returning logits shaped [batch_size, trg_len, vocab_size].
        val_loader: Iterable of dicts with ``'src_ids'`` and ``'trg_ids'``
            tensors shaped [batch_size, seq_len].
        criterion: Loss taking ``(logits [N, vocab_size], labels [N])``.
        device: Device the batch tensors are moved to.

    Returns:
        (float, float): Mean batch loss and corpus BLEU over `val_loader`.
    """
    model.eval()  # Set model to evaluation mode
    epoch_val_loss = 0
    val_predictions = []
    val_targets = []

    with torch.no_grad():  # No need to compute gradients during evaluation
        for batch in val_loader:
            source = batch['src_ids'].to(device)
            target = batch['trg_ids'].to(device)

            # Forward pass (no teacher forcing during evaluation)
            output = model(source, target, teacher_forcing_ratio=0)

            # Drop the never-predicted start position along the *time* axis
            # (dim 1); the tensors are batch-first, so slicing dim 0 would
            # drop the first example instead.
            step_logits = output[:, 1:]
            step_labels = target[:, 1:]

            loss = criterion(step_logits.reshape(-1, output.size(-1)), step_labels.reshape(-1))
            epoch_val_loss += loss.item()

            # Collect the same positions for BLEU, as plain lists of ints.
            # NOTE(review): padding positions are still included on both
            # sides; stripping them needs the pad/eos ids, which this
            # function does not receive.
            val_predictions.extend(step_logits.argmax(dim=-1).cpu().tolist())
            val_targets.extend(step_labels.cpu().tolist())

    # Calculate average validation loss for the epoch
    avg_val_loss = epoch_val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    val_bleu_score = compute_bleu(val_predictions, val_targets)
    print(f"Validation BLEU Score: {val_bleu_score:.4f}")

    return avg_val_loss, val_bleu_score

In [None]:
import torch
# Path of a checkpoint to resume from. Left empty, the setup cell below skips
# loading and starts from epoch 0 with empty histories.
checkpoint_path = ""

In [None]:
def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, scheduler,
                       n_epochs=1, teacher_forcing_ratio=0.5, device='cuda',
                       start_epoch=0, train_losses=None, val_losses=None, bleu_scores=None,
                       best_valid_loss=float("inf")):
    """Train for `n_epochs`, validate after each one, checkpoint the best model.

    Args:
        model: Model handed to ``train_fn`` / ``evaluate_fn``.
        train_loader: Training batches.
        val_loader: Validation batches.
        optimizer: Optimizer whose state is stored in the checkpoint.
        criterion: Loss function.
        scheduler: Scheduler stepped once per epoch with the validation loss.
        n_epochs: Number of epochs to run in this call.
        teacher_forcing_ratio: Forwarded to ``train_fn``.
        device: Device forwarded to ``train_fn`` / ``evaluate_fn``.
        start_epoch: Index of the first epoch (non-zero when resuming).
        train_losses, val_losses, bleu_scores: Histories to continue from;
            ``None`` or an empty list starts a fresh history.
        best_valid_loss: Best validation loss so far; a checkpoint is written
            to ``'checkpoint.pth'`` whenever an epoch beats it.
    """
    train_losses = train_losses or []
    val_losses = val_losses or []
    bleu_scores = bleu_scores or []

    final_epoch = start_epoch + n_epochs

    for epoch in tqdm(range(start_epoch, final_epoch), desc="Training Epochs"):
        print(f"Epoch {epoch+1}/{final_epoch}")

        # Train the model
        train_loss = train_fn(model, train_loader, optimizer, criterion,
                              clip=1.0, teacher_forcing_ratio=teacher_forcing_ratio, device=device)
        train_losses.append(train_loss)

        # Evaluate the model
        val_loss, val_bleu_score = evaluate_fn(model, val_loader, criterion, device=device)
        val_losses.append(val_loss)
        bleu_scores.append(val_bleu_score)

        # Step the scheduler *before* checkpointing so the saved scheduler
        # state already accounts for this epoch; saving first would leave a
        # resumed run one scheduler step behind.
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_valid_loss:
            best_valid_loss = val_loss
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        'bleu_score': val_bleu_score
                        }, 'checkpoint.pth')

        # Print stats
        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.3f} | Train PPL: {np.exp(train_loss):.3f}")
        print(f"Epoch {epoch+1} | Valid Loss: {val_loss:.3f} | Valid PPL: {np.exp(val_loss):.3f}")
        print(f"Epoch {epoch+1} | Valid BLEU Score: {val_bleu_score:.3f}")
        # Report the learning rate explicitly so scheduler reductions are visible
        print(f"Epoch {epoch+1} | Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Plot
    plot_metrics(train_losses, val_losses, bleu_scores)


In [None]:
# `model` wraps the same encoder/decoder objects as `seq2seq`, so both share
# one set of parameters.
model = Seq2Seq(encoder, decoder).to(DEVICE)

# Optimize the parameters of the model that is actually trained below.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# The loss is computed on *target* ids, so the target-side pad id is the one
# that must be ignored.
criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_index)
# `verbose` is deprecated in recent PyTorch releases; the current learning
# rate can be read from optimizer.param_groups instead.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)


if checkpoint_path:
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only host
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    start_epoch = checkpoint['epoch'] + 1
    # The checkpoint stores only the metrics of the epoch it was written at,
    # so the resumed histories start from that single point.
    train_losses = [checkpoint['train_loss']]
    val_losses = [checkpoint['val_loss']]
    bleu_scores = [checkpoint['bleu_score']]
    best_valid_loss = checkpoint['val_loss']  # checkpoints are only written on a new best

    print(f"Resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    train_losses = []
    val_losses = []
    bleu_scores = []
    best_valid_loss = float("inf")

In [None]:
# Resume training
train_and_evaluate(
    model,
    train_data_loader,
    valid_data_loader,
    optimizer,
    criterion,
    scheduler,
    n_epochs=10,
    teacher_forcing_ratio=0.5,
    device='cuda',
    start_epoch=start_epoch,
    train_losses=train_losses,  # start with previous
    val_losses=val_losses,
    bleu_scores=bleu_scores,
    best_valid_loss=best_valid_loss # resume best val loss
)


In [None]:
# model = Seq2Seq(encoder, decoder).to(DEVICE)  # or however you instantiated it

# # Load the weights
# checkpoint = torch.load(checkpoint_path)
    
# model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# def translate_sentence(sentence, model, en_tokenizer, vi_tokenizer, vi_id_to_token, device, max_output_length=MAX_LENGTH):
#     model.eval()  # Set the model to evaluation mode
#     with torch.no_grad():  # Disable gradient computation
#         # Tokenize the input sentence (convert to token IDs)
#         encoding = en_tokenizer.encode(sentence)
        
#         # Add <SOS> and <EOS> tokens to the input
#         src_ids = [en_tokenizer.token_to_id(sos_token)] + encoding.ids + [en_tokenizer.token_to_id(eos_token)]
        
#         # Convert the source IDs to a tensor with shape [1, seq_len]
#         tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)  # [1, seq_len]

#         # Pass the tensor through the encoder
#         encoder_outputs, hidden = model.encoder(tensor)  # hidden, cell from LSTM encoder

#         # Initialize input for the decoder with <SOS> token
#         input_token = torch.LongTensor([vi_tokenizer.token_to_id(sos_token)]).to(device)  # [1]
        
#         outputs = [input_token.item()]  # Store the generated tokens
#         for _ in range(max_output_length):
            
#             # Pass the current input token along with hidden and cell states to the decoder
#             output, hidden, _ = model.decoder(input_token, hidden, encoder_outputs)

#             # Get the predicted token (index of the highest probability)
#             predicted_token = output.argmax(-1).item()
#             # print(f"Predicted Token ID: {predicted_token}")
#             # Append the predicted token to the output sequence
#             outputs.append(predicted_token)

#             # Update the input token for the next step (teacher-forcing is not used here)
#             input_token = torch.LongTensor([predicted_token]).to(device)

#             # If the decoder outputs <EOS>, stop the generation
#             if predicted_token == vi_tokenizer.token_to_id(eos_token):
#                 break
        
#         # Convert the predicted token IDs back to words
#         translated_tokens = [vi_id_to_token[idx] for idx in outputs]
        
#     return translated_tokens

# # Test translation on a single example

# def en_vi(i):
#     return train_data[i]["en"], train_data[i]["vi"]

# sentence ,expected_translation = en_vi(9)
# print("Source (English):", sentence)
# print("Expected Translation (Vietnamese):", expected_translation)

# # Run the translation function
# translation = translate_sentence(sentence, model, en_tokenizer, vi_tokenizer, vi_id_to_token, DEVICE)
# print("Model Translation:", translation)