In [1]:
!cp -r "/kaggle/input/vi-tone-no-tone/data" /kaggle/working/

In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import random
import time
from torch.utils.data import Dataset, DataLoader
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import string
import matplotlib.pyplot as plt

# Create the directory where results are saved (no error if it already exists)
os.makedirs("/kaggle/working/results", exist_ok=True)

# Dataset of aligned (source, target) sentence pairs
class TranslationDatasetFull(Dataset):
    """Parallel-sentence dataset that yields fixed-length token-id tensors.

    Both CSV files are read without a header as two columns (ID, Sentence);
    row i of the source file is paired with row i of the target file.  A pair
    is dropped when either side is empty after stripping, so the two sentence
    lists always stay aligned.

    Args:
        in_file: path to the source-side CSV.
        out_file: path to the target-side CSV.
        in_vocab: source vocabulary, either a token->id dict or a path to a
            one-token-per-line text file.
        out_vocab: target vocabulary, same conventions as ``in_vocab``.
        max_len: length every encoded sequence is padded/truncated to.

    Raises:
        ValueError: if the two CSV files do not have the same number of rows.
    """

    def __init__(self, in_file, out_file, in_vocab, out_vocab, max_len=50):
        in_rows = self._load_rows(in_file)
        out_rows = self._load_rows(out_file)
        if len(in_rows) != len(out_rows):
            raise ValueError(
                "Source and target files have different row counts: "
                f"{len(in_rows)} vs {len(out_rows)}"
            )
        # Filter pairwise: dropping empties per file would shift every later
        # sentence and silently pair sources with the wrong targets.
        kept = [(s, t) for s, t in zip(in_rows, out_rows) if s and t]
        self.in_sentences = [s for s, _ in kept]
        self.out_sentences = [t for _, t in kept]
        self.in_vocab = self._load_vocab(in_vocab) if isinstance(in_vocab, str) else in_vocab
        self.out_vocab = self._load_vocab(out_vocab) if isinstance(out_vocab, str) else out_vocab
        self.max_len = max_len

    def _load_rows(self, file_path):
        """Return the stripped 'Sentence' cell of every row (may be '')."""
        # dtype=str + keep_default_na=False keep cells verbatim: an empty cell
        # stays '' instead of becoming NaN (which str() turns into 'nan'), and
        # literal words such as "nan" or "null" are not swallowed as missing.
        df = pd.read_csv(
            file_path,
            encoding='utf-8',
            header=None,
            names=['ID', 'Sentence'],
            dtype=str,
            keep_default_na=False,
        )
        # Rows with no second field at all are still NaN; treat them as empty.
        return ['' if pd.isna(s) else str(s).strip() for s in df['Sentence']]

    def _load_sentences(self, file_path):
        """Return the non-empty, stripped sentences of a single CSV file."""
        return [s for s in self._load_rows(file_path) if s]

    def _load_vocab(self, vocab_path):
        """Read a one-token-per-line vocabulary file into a token->id dict.

        Ids are the positions of the non-blank lines.  The special tokens
        <unk>, <pad>, <sos>, <eos> are appended after the largest id when the
        file does not already contain them.
        """
        vocab = {}
        with open(vocab_path, 'r', encoding='utf-8') as f:
            words = [line.strip() for line in f if line.strip()]
        for idx, word in enumerate(words):
            vocab[word] = idx
        required_tokens = ['<unk>', '<pad>', '<sos>', '<eos>']
        max_idx = max(vocab.values()) if vocab else -1
        for token in required_tokens:
            if token not in vocab:
                max_idx += 1
                vocab[token] = max_idx
        return vocab

    def _encode_sentence(self, sentence, vocab, max_len):
        """Encode a whitespace-tokenised sentence to exactly ``max_len`` ids.

        Unknown tokens map to <unk>; the result is truncated or right-padded
        with <pad>.
        """
        tokens = sentence.strip().split()
        token_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]
        # A negative repeat count yields [], so over-long inputs are only truncated.
        token_ids = token_ids[:max_len] + [vocab['<pad>']] * (max_len - len(token_ids))
        return token_ids

    def _encode_decoder_sentence(self, sentence, vocab, max_len):
        """Encode a target sentence as <sos> ... <eos>, fixed to ``max_len`` ids.

        Note that truncation cuts from the right, so an over-long sentence
        loses its <eos> marker.
        """
        tokens = sentence.strip().split()
        full_tokens = [vocab['<sos>']] + [vocab.get(token, vocab['<unk>']) for token in tokens] + [vocab['<eos>']]
        if len(full_tokens) < max_len:
            full_tokens += [vocab['<pad>']] * (max_len - len(full_tokens))
        else:
            full_tokens = full_tokens[:max_len]
        return full_tokens

    def __len__(self):
        return len(self.in_sentences)

    def __getitem__(self, idx):
        """Return (src_ids, tgt_ids) as two long tensors of length ``max_len``."""
        in_sentence = self.in_sentences[idx]
        out_sentence = self.out_sentences[idx]
        src = self._encode_sentence(in_sentence, self.in_vocab, self.max_len)
        tgt = self._encode_decoder_sentence(out_sentence, self.out_vocab, self.max_len)
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)

# Custom collate function to enforce max_len
def collate_fn(batch, max_len=50):
    """Stack a list of (src, tgt) tensor pairs into two batch tensors.

    Args:
        batch: list of (src, tgt) 1-D tensors, as produced by the dataset.
        max_len: every sequence is cut to at most this many positions before
            stacking.  Defaults to 50 so ``DataLoader(collate_fn=collate_fn)``
            behaves as before; use ``functools.partial`` for another limit.

    Returns:
        (src_batch, tgt_batch), each stacked along a new leading batch axis.
    """
    src_batch = torch.stack([src[:max_len] for src, _ in batch])
    tgt_batch = torch.stack([tgt[:max_len] for _, tgt in batch])
    return src_batch, tgt_batch

# Create the datasets and DataLoaders.
# All three splits share the same vocabulary files; each dataset re-reads them.
train_dataset = TranslationDatasetFull(
    "/kaggle/working/data/train/source.csv",
    "/kaggle/working/data/train/target.csv",
    "/kaggle/working/data/vocab/input_vocab.txt",
    "/kaggle/working/data/vocab/output_vocab.txt",
    max_len=50
)
val_dataset = TranslationDatasetFull(
    "/kaggle/working/data/val/source.csv",
    "/kaggle/working/data/val/target.csv",
    "/kaggle/working/data/vocab/input_vocab.txt",
    "/kaggle/working/data/vocab/output_vocab.txt",
    max_len=50
)
test_dataset = TranslationDatasetFull(
    "/kaggle/working/data/test/source.csv",
    "/kaggle/working/data/test/target.csv",
    "/kaggle/working/data/vocab/input_vocab.txt",
    "/kaggle/working/data/vocab/output_vocab.txt",
    max_len=50
)

# Only the training loader shuffles; validation and test keep file order.
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2, collate_fn=collate_fn)

# Debug: print dataset sizes, vocabulary samples and a few raw sentences
print(f"Number of sentences in train_dataset: {len(train_dataset)}")
print(f"Number of sentences in val_dataset: {len(val_dataset)}")
print(f"Number of sentences in test_dataset: {len(test_dataset)}")
print(f"Expected train batches: {len(train_loader)}")
print("Input vocab size:", len(train_dataset.in_vocab))
print("Output vocab size:", len(train_dataset.out_vocab))
print("Sample input vocab:", list(train_dataset.in_vocab.items())[:5])
print("Sample output vocab:", list(train_dataset.out_vocab.items())[:5])
print("Value of <pad>:", train_dataset.out_vocab.get('<pad>', "Not found"))
print("Sample source:", train_dataset.in_sentences[:5])
print("Sample target:", train_dataset.out_sentences[:5])

Number of sentences in train_dataset: 4393646
Number of sentences in val_dataset: 549205
Number of sentences in test_dataset: 549207
Expected train batches: 34326
Input vocab size: 1450
Output vocab size: 5805
Sample input vocab: [('a', 0), ('ac', 1), ('ach', 2), ('ai', 3), ('am', 4)]
Sample output vocab: [('a', 0), ('a1', 1), ('a1c', 2), ('a1ch', 3), ('a1i', 4)]
Value of <pad>: 5802
Sample source: ['tenedos barronus uoc ralph vary chamberlin mieu ta nam', 'ngay giao su tran van huong uoc quoc truong phan khac suu bo nhiem lam thu tuong', 'trong noi inh cac vi quy nhan cung cung tan co ngau nhien lam sai ieu gi quach hau cung khong truy cuu con o truoc mat tao phi bao che', 'chung uoc su dung cho che tao cac cam bien tia hong ngoai hoac nhiet ien', 'uchukeiji gyaban uoc bat au tu mot tam hinh minh hoa cua murakami katsushi mot nhan vien thiet ke cua hang bandai nguoi a e lai ten tuoi minh trong lich su nganh o choi voi nhieu san pham oc ao']
Sample target: ['tenedos barronus d9u7o75c r

In [None]:
# Define Transformer Model
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
DIM_MODEL = 512   # embedding / hidden width (d_model)
N_HEADS = 8       # attention heads per attention module
N_LAYERS = 4      # layers in the encoder stack and in the decoder stack
D_FF = 512        # inner width of the position-wise feed-forward network
DROPOUT = 0.1     # dropout probability
MAX_LEN = 50      # maximum sequence length in tokens
NUM_EPOCHS = 6    # default number of training epochs

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: width of the query/key/value inputs and of the output.
        n_heads: number of heads; each head works on ``d_model // n_heads``
            features.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_heads != 0:
            # Otherwise split_heads would reshape into the wrong layout or fail
            # with an opaque view() error at the first forward pass.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.depth = d_model // n_heads

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape [batch, seq, d_model] into [batch, n_heads, seq, depth]."""
        x = x.view(batch_size, -1, self.n_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None, padding_mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: [batch, query_len, d_model].
            key, value: [batch, key_len, d_model].
            mask: optional additive mask (0 = keep, -inf = block) of shape
                [q, k], [batch, q, k] or already 4-D.  It is cropped to
                (query_len, key_len) when larger.
            padding_mask: optional [batch, key_len] mask whose entries equal
                to 1/True mark key positions that must not be attended to.

        Returns:
            Tensor of shape [batch, query_len, d_model].
        """
        batch_size = query.size(0)
        query_len = query.size(1)
        key_len = key.size(1)

        query = self.split_heads(self.query_linear(query), batch_size)  # [batch_size, n_heads, query_len, depth]
        key = self.split_heads(self.key_linear(key), batch_size)        # [batch_size, n_heads, key_len, depth]
        value = self.split_heads(self.value_linear(value), batch_size)  # [batch_size, n_heads, key_len, depth]

        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)  # [batch_size, n_heads, query_len, key_len]

        # Apply the padding mask (prevent attention to padded key positions).
        if padding_mask is not None:
            # padding_mask: [batch_size, key_len] -> [batch_size, 1, 1, key_len]
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
            # Use the most negative finite value rather than -inf: the softmax
            # weight is still exactly 0 whenever some key is visible, but a row
            # whose keys are ALL padded no longer turns into NaN.
            attention_scores = attention_scores.masked_fill(
                padding_mask == 1, torch.finfo(attention_scores.dtype).min
            )

        # Apply the causal mask or source mask (if any).
        if mask is not None:
            # Bring the mask to 4-D and let broadcasting expand the batch/head
            # axes instead of materialising batch_size * n_heads copies.
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(1)  # [1, 1, q, k]
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)               # [batch_size, 1, q, k]
            # Crop so the mask matches the actual score shape (e.g. a square
            # source mask reused for cross-attention with fewer queries).
            mask = mask[:, :, :query_len, :key_len]
            attention_scores = attention_scores + mask

        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, value)  # [batch_size, n_heads, query_len, depth]
        # Merge heads back: [batch_size, query_len, d_model]
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, self.d_model)
        output = self.out_linear(context)
        return output

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Maps the last dimension d_model -> d_ff -> d_model.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        expanded = self.linear1(x)
        activated = F.relu(expanded)
        regularised = self.dropout(activated)
        return self.linear2(regularised)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, padding_mask=None):
        # Sub-layer 1: self-attention over x
        attended = self.attention(x, x, x, mask=mask, padding_mask=padding_mask)
        after_attention = self.layer_norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward
        transformed = self.ffn(after_attention)
        return self.layer_norm2(after_attention + self.dropout(transformed))

class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Self-attention over the target sequence (tgt_mask is the causal mask)
        own_context = self.self_attention(
            x, x, x, mask=tgt_mask, padding_mask=tgt_padding_mask
        )
        hidden = self.layer_norm1(x + self.dropout(own_context))

        # Cross-attention: target queries against the encoder output
        source_context = self.cross_attention(
            hidden, enc_output, enc_output, mask=src_mask, padding_mask=src_padding_mask
        )
        hidden = self.layer_norm2(hidden + self.dropout(source_context))

        # Position-wise feed-forward
        transformed = self.ffn(hidden)
        return self.layer_norm3(hidden + self.dropout(transformed))

class Encoder(nn.Module):
    """Token embedding + sinusoidal positions + a stack of encoder layers.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding / hidden width.
        n_heads: attention heads per layer.
        n_layers: number of ``EncoderLayer`` modules.
        d_ff: inner width of each layer's feed-forward network.
        max_len: number of positions in the positional table; inputs longer
            than this cannot be encoded.
        dropout: dropout probability.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_len=5000, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Register the table as a buffer so it follows model.to(device) and is
        # not copied host->device on every forward pass.  persistent=False
        # keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            'positional_encoding',
            self._generate_positional_encoding(max_len, d_model),
            persistent=False,
        )
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def _generate_positional_encoding(self, max_len, d_model):
        """Build the fixed sin/cos position table of shape [1, max_len, d_model]."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        return pe

    def forward(self, x, mask=None, padding_mask=None):
        """Encode token ids [batch, seq_len] into [batch, seq_len, d_model]."""
        x = self.embedding(x) * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # .to() is a no-op once the buffer lives on the model's device; it is
        # kept so models pickled before the table became a buffer still run.
        x = x + self.positional_encoding[:, :seq_len, :].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask=mask, padding_mask=padding_mask)
        return x

class Decoder(nn.Module):
    """Token embedding + sinusoidal positions + decoder layers + vocab projection.

    Args:
        vocab_size: size of the target vocabulary (embedding rows and number
            of output logits).
        d_model: embedding / hidden width.
        n_heads: attention heads per layer.
        n_layers: number of ``DecoderLayer`` modules.
        d_ff: inner width of each layer's feed-forward network.
        max_len: number of positions in the positional table; target prefixes
            longer than this cannot be decoded.
        dropout: dropout probability.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_len=5000, dropout=0.1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Register the table as a buffer so it follows model.to(device) and is
        # not copied host->device on every forward pass.  persistent=False
        # keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            'positional_encoding',
            self._generate_positional_encoding(max_len, d_model),
            persistent=False,
        )
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.output_linear = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def _generate_positional_encoding(self, max_len, d_model):
        """Build the fixed sin/cos position table of shape [1, max_len, d_model]."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        return pe

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        """Decode target ids [batch, tgt_len] into logits [batch, tgt_len, vocab_size]."""
        x = self.embedding(x) * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # .to() is a no-op once the buffer lives on the model's device; it is
        # kept so models pickled before the table became a buffer still run.
        x = x + self.positional_encoding[:, :seq_len, :].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, enc_output, src_mask=src_mask, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
        logits = self.output_linear(x)
        return logits

class Transformer(nn.Module):
    """Encoder-decoder model: encodes ``src`` and decodes ``tgt`` against it.

    The same ``d_model``, ``n_heads``, ``n_layers``, ``d_ff``, ``max_len`` and
    ``dropout`` are used for both the encoder and the decoder.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, n_layers, d_ff, max_len=5000, dropout=0.1):
        super().__init__()
        shared_args = (d_model, n_heads, n_layers, d_ff, max_len, dropout)
        self.encoder = Encoder(src_vocab_size, *shared_args)
        self.decoder = Decoder(tgt_vocab_size, *shared_args)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        """Return the decoder logits for ``tgt`` given ``src``."""
        memory = self.encoder(src, mask=src_mask, padding_mask=src_padding_mask)
        return self.decoder(
            tgt,
            memory,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_padding_mask=src_padding_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

# Generate target mask for causal attention
def generate_square_subsequent_mask(sz, device=None):
    """Return an additive causal mask of shape [sz, sz].

    Entry (i, j) is 0.0 where position i may attend to position j (j <= i)
    and -inf where it may not (j > i).

    Args:
        sz: sequence length.
        device: device to create the mask on; defaults to the module-level
            ``DEVICE`` so existing one-argument calls behave as before.
    """
    if device is None:
        device = DEVICE
    # Strictly-upper triangle stays -inf, everything else is zeroed by triu.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

# Greedy decoding for inference
def greedy_decode(model, src, src_mask, src_padding_mask, max_len, start_symbol):
    """Greedily decode target ids for a batch of source sequences.

    At every step the highest-scoring token is appended.  A sequence is
    finished once it emits <eos>; finished sequences are filled with <pad>
    while the rest of the batch keeps decoding.  For a batch of one this is
    exactly "stop at the first <eos>".

    Args:
        model: object exposing ``encoder`` and ``decoder`` callables.
        src: source token ids, [batch, src_len].
        src_mask: additive source attention mask.
        src_padding_mask: [batch, src_len] mask marking padded source tokens.
        max_len: maximum length of the returned sequences, including the
            start symbol.
        start_symbol: id of the token every output sequence starts with.

    Returns:
        Long tensor [batch, T] with T <= max_len, on ``DEVICE``.

    The <pad> and <eos> ids are read from the global
    ``train_dataset.out_vocab``.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    src_padding_mask = src_padding_mask.to(DEVICE)

    pad_idx = train_dataset.out_vocab['<pad>']
    eos_idx = train_dataset.out_vocab['<eos>']
    batch_size = src.size(0)

    # Inference only: never build an autograd graph, even if the caller forgot.
    with torch.no_grad():
        memory = model.encoder(src, mask=src_mask, padding_mask=src_padding_mask)

        ys = torch.full((batch_size, 1), start_symbol, dtype=torch.long, device=DEVICE)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=DEVICE)

        for _ in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1)).to(DEVICE)
            tgt_padding_mask = (ys == pad_idx).to(DEVICE)
            out = model.decoder(ys, memory, src_mask=src_mask, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
            # Best token for the last position of every sequence in the batch.
            next_word = out[:, -1, :].argmax(dim=1)
            # Sequences that already ended only receive padding.
            next_word = next_word.masked_fill(finished, pad_idx)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            finished = finished | (next_word == eos_idx)
            if bool(finished.all()):
                break
    return ys

# Vocabulary class for token lookup
class SimpleVocab:
    """Wrap a token->id dict and provide the reverse id->token lookup."""

    def __init__(self, vocab):
        self.vocab = vocab
        # Reverse mapping; if two tokens share an id, the later one wins.
        self.inv_vocab = {index: token for token, index in vocab.items()}

    def lookup_tokens(self, ids):
        """Return the token for each id; ids without a token give "<unk>"."""
        resolve = self.inv_vocab.get
        return [resolve(token_id, "<unk>") for token_id in ids]

# id -> token lookup helpers built from the training set's vocabularies
out_vocab_transform = SimpleVocab(train_dataset.out_vocab)
in_vocab_transform = SimpleVocab(train_dataset.in_vocab)

# Translation function
def translate(model, src_sentence, max_len=MAX_LEN):
    """Translate one raw sentence with greedy decoding.

    Args:
        model: trained model accepted by ``greedy_decode``.
        src_sentence: raw input text.
        max_len: upper bound on the decoded length, including <sos>.  It
            should not exceed the number of positions the decoder's
            positional table was built with.

    Returns:
        The predicted tokens joined by spaces, with <sos>/<eos> removed and
        cut to at most as many words as ``src_sentence`` contains.
    """
    model.eval()
    src_tensor = simple_text_transform(src_sentence).to(DEVICE)
    seq_len = src_tensor.size(1)
    src_mask = torch.zeros(seq_len, seq_len, device=DEVICE).type(torch.float)
    src_padding_mask = (src_tensor == train_dataset.in_vocab['<pad>']).to(DEVICE)

    # Honour max_len: decoding up to seq_len + 5 tokens can run past the
    # decoder's positional table and crash when no <eos> is produced in time.
    decode_len = min(seq_len + 5, max_len)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        ys = greedy_decode(model, src_tensor, src_mask, src_padding_mask, max_len=decode_len, start_symbol=train_dataset.out_vocab['<sos>'])
    tgt_tokens = ys.squeeze(0).tolist()

    tokens = out_vocab_transform.lookup_tokens(tgt_tokens)
    num_words = len(src_sentence.split())
    translation = " ".join(tokens).replace("<sos>", "").replace("<eos>", "").strip()
    # The output is cut to the number of words in the input sentence.
    if len(translation.split()) > num_words:
        translation = " ".join(translation.split()[:num_words])
    return translation

def simple_text_transform(sentence: str):
    """Turn a raw sentence into a [1, MAX_LEN] tensor of source token ids.

    The text is stripped, lower-cased, cleared of ASCII punctuation and split
    on whitespace.  Tokens missing from the global ``train_dataset.in_vocab``
    map to <unk>; the id list is truncated / right-padded with <pad> to
    exactly MAX_LEN entries.
    """
    vocab = train_dataset.in_vocab
    drop_punctuation = str.maketrans("", "", string.punctuation)
    words = sentence.strip().lower().translate(drop_punctuation).split()

    ids = [vocab.get(word, vocab['<unk>']) for word in words][:MAX_LEN]
    padding = [vocab['<pad>']] * (MAX_LEN - len(ids))
    return torch.tensor(ids + padding, dtype=torch.long).unsqueeze(0)

# Training and Evaluation Functions
def calculate_metrics(model, iterator, in_vocab, out_vocab, device, max_len, criterion):
    """Evaluate ``model`` on ``iterator``: loss, token accuracy and BLEU.

    Loss and accuracy are computed with teacher forcing over every batch of
    ``iterator``.  BLEU is computed with greedy decoding on a random sample of
    about 1% of ``iterator.dataset`` (at least one example), so it varies
    from call to call unless ``random`` is seeded.

    Args:
        model: the Transformer to evaluate (switched to eval mode here).
        iterator: DataLoader yielding (src, tgt) batches.
        in_vocab, out_vocab: token->id dicts containing the special tokens.
        device: device the batches and masks are placed on.
        max_len: maximum decoded length passed to ``greedy_decode``.
        criterion: loss applied to flattened logits and target ids.

    Returns:
        (avg_loss, accuracy, avg_bleu): mean per-batch loss, fraction of
        non-pad target tokens predicted correctly, and mean sentence BLEU.
        Each value is 0 when there was nothing to average.
    """
    model.eval()
    total_loss = 0
    total_correct = 0
    total_tokens = 0
    bleu_scores = []
    idx2word = {idx: word for word, idx in out_vocab.items()}
    src_pad_idx = in_vocab['<pad>']
    tgt_pad_idx = out_vocab['<pad>']
    # Ids removed from hypotheses and references before scoring BLEU.
    special_ids = {tgt_pad_idx, out_vocab['<sos>'], out_vocab['<eos>']}
    smoothie = SmoothingFunction().method4

    total_samples = len(iterator.dataset)
    # min() keeps random.sample valid when the dataset is empty.
    sample_size = min(total_samples, max(1, total_samples // 100))
    sampled_indices = random.sample(range(total_samples), sample_size)
    sampled_dataset = torch.utils.data.Subset(iterator.dataset, sampled_indices)
    sampled_loader = DataLoader(sampled_dataset, batch_size=iterator.batch_size, shuffle=False, pin_memory=True, num_workers=2, collate_fn=collate_fn)

    with torch.no_grad():
        # Pass 1: teacher-forced loss and token accuracy over the full iterator.
        for src, tgt in iterator:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = tgt[:, :-1]
            src_mask = torch.zeros(src.size(1), src.size(1), device=device).type(torch.float)
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            src_padding_mask = (src == src_pad_idx).to(device)
            tgt_padding_mask = (tgt_input == tgt_pad_idx).to(device)
            output = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
            output_dim = output.shape[-1]
            output = output.reshape(-1, output_dim)
            tgt_flat = tgt[:, 1:].reshape(-1)
            loss = criterion(output, tgt_flat)
            total_loss += loss.item()
            preds = output.argmax(dim=1)
            # Padding positions count neither as correct nor as evaluated.
            non_pad_mask = tgt_flat != tgt_pad_idx
            correct = (preds == tgt_flat) & non_pad_mask
            total_correct += correct.sum().item()
            total_tokens += non_pad_mask.sum().item()

        # Pass 2: sentence-level BLEU from greedy decoding on the sample.
        for src, tgt in sampled_loader:
            src, tgt = src.to(device), tgt.to(device)
            for i in range(src.shape[0]):
                src_sent = src[i].unsqueeze(0)
                tgt_sent = tgt[i].tolist()
                src_mask = torch.zeros(src_sent.size(1), src_sent.size(1), device=device).type(torch.float)
                src_padding_mask = (src_sent == src_pad_idx).to(device)
                pred_sent = greedy_decode(model, src_sent, src_mask, src_padding_mask, max_len, out_vocab['<sos>'])
                pred_sent = pred_sent.squeeze(0).tolist()
                pred_tokens = [idx2word.get(idx, '<unk>') for idx in pred_sent if idx not in special_ids]
                ref_sent = [idx2word.get(idx, '<unk>') for idx in tgt_sent if idx not in special_ids]
                bleu = sentence_bleu([ref_sent], pred_tokens, smoothing_function=smoothie)
                bleu_scores.append(bleu)

    num_batches = len(iterator)
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    accuracy = total_correct / total_tokens if total_tokens > 0 else 0
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    return avg_loss, accuracy, avg_bleu

def _save_metric_plot(epochs, series, title, ylabel, path):
    """Plot one or more per-epoch curves and save the figure to ``path``.

    ``series`` is an iterable of (values, label, color) tuples; a color of
    None uses matplotlib's default colour cycle.
    """
    plt.figure(figsize=(10, 5))
    for values, label, color in series:
        if color is None:
            plt.plot(epochs, values, label=label)
        else:
            plt.plot(epochs, values, label=label, color=color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid()
    plt.savefig(path)
    plt.close()


def train(model, train_loader, val_loader, optimizer, criterion, in_vocab, out_vocab, max_len, device, num_epochs=NUM_EPOCHS, clip=1, patience=3):
    """Train ``model`` with teacher forcing, validating after every epoch.

    After each epoch the metrics are appended to
    /kaggle/working/results/training_log.txt and printed.  The whole model
    object is saved to /kaggle/working/results/transformer_best.pt whenever
    the validation loss improves; training stops early after ``patience``
    consecutive epochs without improvement.  Loss, accuracy and BLEU plots are
    written to the results directory at the end.

    Args:
        model: the Transformer to train.
        train_loader, val_loader: DataLoaders yielding (src, tgt) batches.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss applied to flattened logits and target ids.
        in_vocab, out_vocab: token->id dicts containing '<pad>'.
        max_len: maximum decoded length used for validation BLEU.
        device: device the batches and masks are placed on.
        num_epochs: maximum number of epochs.
        clip: max gradient norm for clipping.
        patience: epochs without validation-loss improvement before stopping.

    Returns:
        (train_losses, val_losses, val_accuracies, val_bleu_scores), one
        entry per completed epoch.
    """
    train_losses = []
    val_losses = []
    val_accuracies = []
    val_bleu_scores = []
    best_val_loss = float('inf')
    patience_counter = 0
    src_pad_idx = in_vocab['<pad>']
    tgt_pad_idx = out_vocab['<pad>']
    log_file = "/kaggle/working/results/training_log.txt"
    with open(log_file, 'w', encoding='utf-8') as f:
        f.write("Epoch,Train Loss,Val Loss,Val Accuracy,Val BLEU,VRAM (MB),Epoch Time (s)\n")

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        batch_count = 0
        start_time = time.time()
        for src, tgt in train_loader:
            src, tgt = src.to(device, non_blocking=True), tgt.to(device, non_blocking=True)
            # Teacher forcing: feed tokens [0, T-1), predict tokens [1, T).
            tgt_input = tgt[:, :-1]
            tgt_labels = tgt[:, 1:].reshape(-1)

            optimizer.zero_grad()
            src_mask = torch.zeros(src.size(1), src.size(1), device=device).type(torch.float)
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            src_padding_mask = (src == src_pad_idx).to(device)
            tgt_padding_mask = (tgt_input == tgt_pad_idx).to(device)
            output = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
            output = output.reshape(-1, output.shape[-1])
            loss = criterion(output, tgt_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1

        # An empty loader gives NaN rather than a ZeroDivisionError.
        train_loss = epoch_loss / batch_count if batch_count else float('nan')
        val_loss, val_accuracy, val_bleu = calculate_metrics(model, val_loader, in_vocab, out_vocab, device, max_len, criterion)
        # Epoch time includes the validation pass above.
        epoch_time = time.time() - start_time
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        val_bleu_scores.append(val_bleu)

        vram_mb = torch.cuda.memory_allocated() / 1024**2
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(f"{epoch+1},{train_loss:.3f},{val_loss:.3f},{val_accuracy:.3f},{val_bleu:.3f},{vram_mb:.2f},{epoch_time:.2f}\n")
        print(f'Epoch: {epoch+1:02}')
        print(f'\tTrain Loss: {train_loss:.3f}')
        print(f'\tVal Loss: {val_loss:.3f}')
        print(f'\tVal Accuracy: {val_accuracy:.3f}')
        print(f'\tVal BLEU: {val_bleu:.3f}')
        print(f'\tEpoch Time: {epoch_time:.2f} seconds')
        print(f'\tVRAM allocated: {vram_mb:.2f} MB')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # NOTE(review): this pickles the whole model object, so loading it
            # requires these class definitions to be importable.
            torch.save(model, '/kaggle/working/results/transformer_best.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after epoch {epoch+1}')
                break

    # Plot and save metrics
    epochs = range(1, len(train_losses) + 1)
    _save_metric_plot(
        epochs,
        [(train_losses, 'Train Loss', None), (val_losses, 'Validation Loss', None)],
        'Loss over Epochs',
        'Loss',
        '/kaggle/working/results/loss_plot.png',
    )
    _save_metric_plot(
        epochs,
        [(val_accuracies, 'Validation Accuracy', 'green')],
        'Validation Accuracy over Epochs',
        'Accuracy',
        '/kaggle/working/results/accuracy_plot.png',
    )
    _save_metric_plot(
        epochs,
        [(val_bleu_scores, 'Validation BLEU', 'blue')],
        'Validation BLEU Score over Epochs',
        'BLEU Score',
        '/kaggle/working/results/bleu_plot.png',
    )

    return train_losses, val_losses, val_accuracies, val_bleu_scores

# Initialize model, criterion, optimizer
# Vocabulary sizes come from the training dataset; max_len=MAX_LEN sets the
# number of positions in both positional-encoding tables.
model = Transformer(
    src_vocab_size=len(train_dataset.in_vocab),
    tgt_vocab_size=len(train_dataset.out_vocab),
    d_model=DIM_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    max_len=MAX_LEN,
    dropout=DROPOUT
).to(DEVICE)

# Padding positions in the target are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.out_vocab['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# Train the model (num_epochs, clip and patience use the defaults of train())
train_losses, val_losses, val_accuracies, val_bleu_scores = train(
    model, train_loader, val_loader, optimizer, criterion, 
    train_dataset.in_vocab, train_dataset.out_vocab, MAX_LEN, DEVICE
)

Epoch: 01
	Train Loss: 1.696
	Val Loss: 0.918
	Val Accuracy: 0.756
	Val BLEU: 0.595
	Epoch Time: 11593.25 seconds
	VRAM allocated: 425.98 MB
Epoch: 02
	Train Loss: 0.969
	Val Loss: 0.543
	Val Accuracy: 0.850
	Val BLEU: 0.713
	Epoch Time: 11534.57 seconds
	VRAM allocated: 425.98 MB
Epoch: 03
	Train Loss: 0.634
	Val Loss: 0.342
	Val Accuracy: 0.904
	Val BLEU: 0.798
	Epoch Time: 11569.35 seconds
	VRAM allocated: 425.98 MB
