In [None]:
# Import necessary libraries
import os
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from jiwer import wer, cer
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pickle
import re
import math

In [None]:
# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value for reproducibility.
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(SEED)

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Using device: {DEVICE}')

In [None]:
# Create output directory
OUTPUT_DIR = './output_1/'
# exist_ok=True makes the call idempotent and avoids the check-then-create
# race of testing os.path.exists() first.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Location of the input spreadsheet and of every artifact this notebook
# reads back or writes (all artifacts live under OUTPUT_DIR).
DATASET_PATH = './exportStatements.xlsx'


def _in_output_dir(filename):
    """Return `filename` joined onto OUTPUT_DIR."""
    return os.path.join(OUTPUT_DIR, filename)


VOCAB_PATH = _in_output_dir('word_vocab.pkl')
PREPROCESSED_DATA_PATH = _in_output_dir('preprocessed_data_word.pkl')
BEST_MODEL_PATH = _in_output_dir('best_transformer_model_word.pt')
BEST_CER_MODEL_PATH = _in_output_dir('best_transformer_model_cer.pt')
LOSS_PLOT_PATH = _in_output_dir('transformer_loss_plot_char.png')
WER_PLOT_PATH = _in_output_dir('wer_plot.png')
CER_PLOT_PATH = _in_output_dir('cer_plot.png')

In [None]:
# Read the Excel workbook at DATASET_PATH into a DataFrame.
df = pd.read_excel(io=DATASET_PATH)

In [None]:
# Report how many entries are missing in each of the two text columns.
print("Missing values in 'inFormalForm':", df['inFormalForm'].isna().sum())
print("Missing values in 'FormalForm':", df['FormalForm'].isna().sum())

# Keep only rows that have both an informal and a formal sentence, and
# renumber the index so it is contiguous again.
initial_length = len(df)
df = df.dropna(subset=['inFormalForm', 'FormalForm']).reset_index(drop=True)
final_length = len(df)

# Make sure both text columns hold strings (cells may have been read as numbers).
df = df.astype({'inFormalForm': str, 'FormalForm': str})

print(f"Dropped {initial_length - final_length} rows due to missing values.")

In [None]:
# Hold out 20% of the rows, then split that hold-out in half, which yields
# an 80% / 10% / 10% train / validation / test partition.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    random_state=SEED,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=SEED,
)

In [None]:
# Tokenizer used throughout the notebook.
# It is defined unconditionally (not only when the vocabulary is rebuilt)
# because later cells call `tokenize` even when the cached vocabulary is
# loaded from disk; defining it inside the "build" branch left it undefined
# on every run that found an existing vocabulary file.
#   \w+       : runs of word characters (letters, digits, underscore)
#   [^\s\w]+  : runs of characters that are neither whitespace nor word
#               characters (e.g. punctuation)
_TOKEN_PATTERN = re.compile(r'\w+|[^\s\w]+')


def tokenize(text):
    """Split `text` into a list of word tokens and punctuation-run tokens."""
    return _TOKEN_PATTERN.findall(text)


# Build word vocabulary from training data
if not os.path.exists(VOCAB_PATH):
    print('Building word vocabulary...')
    from collections import Counter

    # Collect all tokens from both columns of the training data
    all_words = []
    for text in train_df['inFormalForm'].tolist() + train_df['FormalForm'].tolist():
        all_words.extend(tokenize(text))

    # Build vocabulary: distinct tokens in sorted order so ids are deterministic
    word_counts = Counter(all_words)
    words = sorted(word_counts.keys())

    # Special tokens take ids 0..3; regular tokens are numbered after them
    special_tokens = ['<pad>', '<unk>', '<s>', '</s>']
    word2idx = {word: idx + len(special_tokens) for idx, word in enumerate(words)}
    for idx, token in enumerate(special_tokens):
        word2idx[token] = idx
    idx2word = {idx: word for word, idx in word2idx.items()}

    # Save vocabulary
    with open(VOCAB_PATH, 'wb') as f:
        pickle.dump({'word2idx': word2idx, 'idx2word': idx2word}, f)
    print('Word vocabulary built and saved.')
else:
    print('Loading existing word vocabulary...')
    # NOTE: pickle.load can execute arbitrary code; only load vocabulary files
    # that this notebook wrote itself.
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
        word2idx = vocab['word2idx']
        idx2word = vocab['idx2word']

In [None]:
# Look up the integer ids the vocabulary assigns to the four special tokens.
PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX = (
    word2idx[token] for token in ('<pad>', '<unk>', '<s>', '</s>')
)

# Show the ids as the cell output.
PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX

In [None]:
# Maximum sequence length (based on dataset)
def get_max_len(df_list):
    """Return the longest tokenized sentence length found in `df_list`.

    Both the 'inFormalForm' and 'FormalForm' columns of every frame are
    tokenized with `tokenize`; each length is counted with two extra
    positions for the BOS and EOS markers. Returns 0 when `df_list` is empty.
    """
    longest = 0
    for frame in df_list:
        for column in ('inFormalForm', 'FormalForm'):
            # +2 for BOS and EOS
            token_counts = frame[column].apply(lambda text: len(tokenize(text)) + 2)
            longest = max(longest, token_counts.max())
    return longest

MAX_LEN = get_max_len([train_df, val_df, test_df])

MAX_LEN

In [None]:
# Reuse the cached id sequences when they exist; otherwise build and cache them.
if os.path.exists(PREPROCESSED_DATA_PATH):
    print('Loading preprocessed data...')
    # Load preprocessed data
    with open(PREPROCESSED_DATA_PATH, 'rb') as f:
        data = pickle.load(f)
    train_src, train_trg = data['train_src'], data['train_trg']
    val_src, val_trg = data['val_src'], data['val_trg']
    test_src, test_trg = data['test_src'], data['test_trg']
    # The cached length replaces the freshly computed one so it matches the cached sequences.
    MAX_LEN = data['MAX_LEN']
else:
    print('Preprocessing data...')

    def preprocess_data(df, word2idx, max_len=MAX_LEN):
        """Convert both text columns of `df` into fixed-length id sequences.

        Each sentence becomes BOS + token ids (UNK for unknown tokens) + EOS,
        cut to `max_len` and right-padded with PAD up to `max_len`.
        Returns (src_sequences, trg_sequences) as lists of lists of ints.
        """
        def encode(text):
            ids = [BOS_IDX]
            ids.extend(word2idx.get(token, UNK_IDX) for token in tokenize(text))
            ids.append(EOS_IDX)
            # Truncate first, then pad whatever room is left
            ids = ids[:max_len]
            ids += [PAD_IDX] * (max_len - len(ids))
            return ids

        src_sequences = [encode(text) for text in df['inFormalForm'].tolist()]
        trg_sequences = [encode(text) for text in df['FormalForm'].tolist()]
        return src_sequences, trg_sequences

    # Tokenize and preprocess each split
    train_src, train_trg = preprocess_data(train_df, word2idx)
    val_src, val_trg = preprocess_data(val_df, word2idx)
    test_src, test_trg = preprocess_data(test_df, word2idx)

    # Save preprocessed data together with the length it was padded to
    cache_payload = {
        'train_src': train_src,
        'train_trg': train_trg,
        'val_src': val_src,
        'val_trg': val_trg,
        'test_src': test_src,
        'test_trg': test_trg,
        'MAX_LEN': MAX_LEN,
    }
    with open(PREPROCESSED_DATA_PATH, 'wb') as f:
        pickle.dump(cache_payload, f)
    print('Preprocessed data saved.')

In [None]:
# Prepare datasets
class TranslationDataset(Dataset):
    """Dataset of (source ids, target ids) pairs returned as long tensors."""

    def __init__(self, src_sequences, trg_sequences):
        # Store the id lists as given; tensors are created per item on access.
        self.src_sequences, self.trg_sequences = src_sequences, trg_sequences

    def __len__(self):
        # The number of source sequences defines the dataset size.
        return len(self.src_sequences)

    def __getitem__(self, idx):
        pair = (self.src_sequences[idx], self.trg_sequences[idx])
        return tuple(torch.tensor(ids, dtype=torch.long) for ids in pair)

In [None]:
# Collate function that pads the sequences of a batch to a common length
def collate_fn(batch, pad_idx):
    """Stack a list of (src, trg) tensor pairs into two padded batch tensors.

    Sources and targets are padded independently with `pad_idx`; both results
    are laid out batch-first.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    src_items, trg_items = zip(*batch)
    return (
        pad(src_items, padding_value=pad_idx, batch_first=True),
        pad(trg_items, padding_value=pad_idx, batch_first=True),
    )

In [None]:
# Create datasets and dataloaders
batch_size = 32  # Adjust as needed

train_dataset = TranslationDataset(train_src, train_trg)
val_dataset = TranslationDataset(val_src, val_trg)
test_dataset = TranslationDataset(test_src, test_trg)


def _collate_with_pad(batch):
    """Collate `batch` using the vocabulary's PAD id as the padding value."""
    return collate_fn(batch, PAD_IDX)


# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate_with_pad)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=_collate_with_pad)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=_collate_with_pad)

In [None]:
# Function to generate subsequent masks for target
def generate_square_subsequent_mask(sz, device=None):
    """Build the additive causal mask for a sequence of length `sz`.

    Args:
        sz: sequence length; the mask has shape (sz, sz).
        device: device to create the mask on. Defaults to the notebook-wide
            DEVICE when omitted, which is the previous behaviour.

    Returns:
        A float tensor holding 0.0 on and below the diagonal and -inf
        strictly above it, so position i cannot attend to positions > i.
    """
    if device is None:
        device = DEVICE
    # Fill with -inf, then keep only the strictly-upper triangle; triu zeroes the rest.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

In [None]:
# Function to create attention and padding masks
def create_mask(src, tgt, pad_idx):
    """Build the four masks passed to the model for one batch.

    `src` and `tgt` are indexed as (batch, seq_len): dimension 1 is used as
    the sequence length.

    Returns, in order:
        src_mask: all-False boolean (src_len, src_len) mask on DEVICE
            (no source position is hidden from any other).
        tgt_mask: causal mask from `generate_square_subsequent_mask`.
        src_padding_mask: boolean tensor, True where `src` equals `pad_idx`.
        tgt_padding_mask: boolean tensor, True where `tgt` equals `pad_idx`.
    """
    src_len, tgt_len = src.size(1), tgt.size(1)

    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    return src_mask, tgt_mask, src.eq(pad_idx), tgt.eq(pad_idx)

In [None]:
# Transformer Model Definition (same as before)
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer with learned positional embeddings.

    Inputs are batch-first id tensors of shape [batch_size, seq_len]. They are
    embedded, transposed to the sequence-first layout that `nn.Transformer`
    is constructed with here, and the decoder output is projected onto the
    target vocabulary.

    Args:
        num_encoder_layers, num_decoder_layers: depth of encoder / decoder.
        emb_size: embedding size, also the Transformer model dimension.
        nhead: number of attention heads.
        src_vocab_size, tgt_vocab_size: vocabulary sizes.
        dim_feedforward: hidden size of the feed-forward sublayers.
        dropout: dropout probability inside the Transformer.
        max_len: number of positions the positional embeddings cover;
            sequences longer than this cannot be embedded.
        pad_idx: id of the padding token (its embedding row is not trained).
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size,
                 nhead, src_vocab_size, tgt_vocab_size, dim_feedforward=512,
                 dropout=0.1, max_len=MAX_LEN, pad_idx=PAD_IDX):
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.emb_size = emb_size
        self.pad_idx = pad_idx
        self.max_len = max_len

        # Token embedding layers
        self.src_embedding = nn.Embedding(src_vocab_size, emb_size, padding_idx=pad_idx)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, emb_size, padding_idx=pad_idx)

        # Learned positional embeddings
        self.src_pos_embedding = nn.Embedding(max_len, emb_size)
        self.tgt_pos_embedding = nn.Embedding(max_len, emb_size)

        # Transformer
        self.transformer = nn.Transformer(d_model=emb_size, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout)

        # Output layer
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def _embed(self, tokens, token_embedding, pos_embedding):
        """Embed [batch, seq_len] ids and return a [seq_len, batch, emb] tensor.

        Token and positional embeddings are summed and the sum is scaled by
        sqrt(emb_size), exactly as every call site did before this helper
        was factored out.
        """
        batch_size, seq_len = tokens.size(0), tokens.size(1)
        positions = torch.arange(0, seq_len, device=tokens.device).unsqueeze(0).expand(batch_size, -1)
        emb = token_embedding(tokens) + pos_embedding(positions)
        emb = emb * math.sqrt(self.emb_size)
        return emb.transpose(0, 1)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        """Return logits of shape [batch_size, tgt_seq_len, tgt_vocab_size]."""
        # src and tgt shape: [batch_size, seq_len]
        src_emb = self._embed(src, self.src_embedding, self.src_pos_embedding)
        tgt_emb = self._embed(tgt, self.tgt_embedding, self.tgt_pos_embedding)

        # No memory (cross-attention) mask is used.
        output = self.transformer(src_emb, tgt_emb,
                                  src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=None,
                                  src_key_padding_mask=src_padding_mask,
                                  tgt_key_padding_mask=tgt_padding_mask,
                                  memory_key_padding_mask=memory_key_padding_mask)
        # Back to batch-first before projecting onto the vocabulary
        return self.generator(output.transpose(0, 1))

    def encode(self, src, src_mask, src_padding_mask=None):
        """Run only the encoder; returns sequence-first memory.

        `src_padding_mask` is optional (default None keeps the previous
        behaviour) and lets padded batches hide their PAD positions.
        """
        src_emb = self._embed(src, self.src_embedding, self.src_pos_embedding)
        return self.transformer.encoder(src_emb, mask=src_mask,
                                        src_key_padding_mask=src_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_padding_mask=None, memory_key_padding_mask=None):
        """Run only the decoder over `memory`; returns sequence-first output.

        The two padding masks are optional (defaults of None keep the previous
        behaviour) and allow decoding padded batches.
        """
        tgt_emb = self._embed(tgt, self.tgt_embedding, self.tgt_pos_embedding)
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)

In [None]:
# Initialize model parameters
VOCAB_SIZE = len(word2idx)
# Source and target sentences share a single vocabulary.
SRC_VOCAB_SIZE = TGT_VOCAB_SIZE = VOCAB_SIZE

# Hyperparameters (same as before)
num_encoder_layers = num_decoder_layers = 3
emb_size = 256
nhead = 8
dim_feedforward = 256
dropout = 0.1  # Adjust dropout rate as needed


model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    emb_size=emb_size,
    nhead=nhead,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    max_len=MAX_LEN,
    pad_idx=PAD_IDX,
).to(DEVICE)

In [None]:
# Adam with a small weight-decay term; the loss skips positions whose target is PAD.
optimizer = optim.Adam(
    model.parameters(),
    lr=0.0005,
    weight_decay=1e-5,
)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
def evaluate_wer(model, dataloader, idx2word, max_batches=None):
    """Compute average CER and WER of teacher-forced predictions.

    The decoder is fed the reference shifted right (trg[:, :-1]) and its
    argmax predictions are compared with the reference shifted left
    (trg[:, 1:]).

    Args:
        model: the seq2seq model; it is switched to eval mode.
        dataloader: yields (src, trg) id batches.
        idx2word: mapping from token id to token string.
        max_batches: if truthy, stop after this many batches.

    Returns:
        (avg_cer, avg_wer). Both are NaN when nothing could be scored.
    """
    model.eval()
    cer_scores = []
    wer_scores = []
    dropped = (PAD_IDX, EOS_IDX, UNK_IDX)

    with torch.no_grad():
        for batches_processed, (src, trg) in enumerate(dataloader, start=1):
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)
            trg_in = trg[:, :-1]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, trg_in, PAD_IDX)

            logits = model(src, trg_in, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
            predictions = logits.argmax(dim=-1).cpu().tolist()
            targets = trg[:, 1:].cpu().tolist()  # Remove first token (<s>) for target

            for pred_ids, trg_ids in zip(predictions, targets):
                # Score only positions that carry a real target token. The loss
                # ignores PAD targets, so the model's output at those positions
                # is untrained; keeping it would append arbitrary tokens to the
                # hypothesis and inflate both metrics.
                pred_ids = [p for p, t in zip(pred_ids, trg_ids)
                            if t != PAD_IDX and p not in dropped]
                trg_ids = [t for t in trg_ids if t not in dropped]

                pred_sentence = ' '.join([idx2word.get(idx, '') for idx in pred_ids])
                trg_sentence = ' '.join([idx2word.get(idx, '') for idx in trg_ids])

                # jiwer rejects empty reference strings; a reference can become
                # empty here when every target token was filtered out.
                if not trg_sentence:
                    continue

                cer_scores.append(cer(trg_sentence, pred_sentence))
                wer_scores.append(wer(trg_sentence, pred_sentence))

            if max_batches and batches_processed >= max_batches:
                break

    if not cer_scores:
        # Nothing was scored (empty loader or only empty references).
        return float('nan'), float('nan')

    avg_cer = np.mean(cer_scores)
    avg_wer = np.mean(wer_scores)
    return avg_cer, avg_wer

In [None]:
# Training loop with WER calculation (same as before)
N_EPOCHS = 100
CLIP = 1  # Max gradient norm for clipping; clipping is currently not applied (see note in the loop)
best_valid_loss = float('inf')
best_valid_cer = float('inf')
patience = 5  # Epochs without validation-loss improvement before stopping
counter = 0

train_losses = []
valid_losses = []

train_wers = []
valid_wers = []
train_cers = []
valid_cers = []

# Shuffled loader used to estimate training WER/CER on a few batches per epoch.
# It does not depend on the epoch, so it is built once instead of every iteration.
train_subset_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, PAD_IDX))

for epoch in range(1, N_EPOCHS + 1):
    start_time = time.time()

    # Training
    model.train()
    epoch_train_loss = 0
    for src, trg in tqdm(train_loader, desc=f'Training Epoch {epoch}/{N_EPOCHS}'):
        src = src.to(DEVICE)
        trg = trg.to(DEVICE)

        # Teacher forcing: the decoder reads trg[:, :-1] and must predict trg[:, 1:]
        trg_in = trg[:, :-1]
        trg_out = trg[:, 1:]

        # Create masks
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, trg_in, PAD_IDX)

        optimizer.zero_grad()
        output = model(src, trg_in, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # output: [batch_size, tgt_len - 1, vocab_size] -> [N, vocab_size]
        # trg_out: [batch_size, tgt_len - 1] -> [N]
        loss = criterion(output.reshape(-1, TGT_VOCAB_SIZE), trg_out.reshape(-1))
        loss.backward()

        # NOTE: gradient clipping is not applied (the call was disabled in this
        # notebook). To enable it, add here:
        #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP)

        optimizer.step()

        epoch_train_loss += loss.item()

    epoch_train_loss /= len(train_loader)
    train_losses.append(epoch_train_loss)

    # Validation
    model.eval()
    epoch_valid_loss = 0
    with torch.no_grad():
        for src, trg in tqdm(val_loader, desc=f'Validation Epoch {epoch}/{N_EPOCHS}'):
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)

            trg_in = trg[:, :-1]
            trg_out = trg[:, 1:]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, trg_in, PAD_IDX)

            output = model(src, trg_in, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = criterion(output.reshape(-1, TGT_VOCAB_SIZE), trg_out.reshape(-1))
            epoch_valid_loss += loss.item()

    epoch_valid_loss /= len(val_loader)
    valid_losses.append(epoch_valid_loss)

    # Evaluate WER / CER on the full validation set
    valid_cer, valid_wer = evaluate_wer(model, val_loader, idx2word)
    valid_wers.append(valid_wer)
    valid_cers.append(valid_cer)

    # ...and on 5 shuffled training batches
    train_cer, train_wer = evaluate_wer(model, train_subset_loader, idx2word, max_batches=5)
    train_wers.append(train_wer)
    train_cers.append(train_cer)

    print(f'\tTrain Loss: {epoch_train_loss:.3f}')
    print(f'\tValid Loss: {epoch_valid_loss:.3f}')
    print(f'\tTrain WER: {train_wer:.4f}')
    print(f'\tValid WER: {valid_wer:.4f}')
    print(f'\tTrain CER: {train_cer:.4f}')
    print(f'\tValid CER: {valid_cer:.4f}')

    # Checkpoint on best validation loss and track early-stopping patience.
    # The decision to stop is only recorded here; the loop exits at the end of
    # the epoch so the CER checkpoint and the timing report below still run.
    stop_training = False
    if epoch_valid_loss < best_valid_loss:
        best_valid_loss = epoch_valid_loss
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f'Validation loss improved. Model saved to {BEST_MODEL_PATH}.')
        counter = 0
    else:
        counter += 1
        stop_training = counter >= patience

    # Checkpoint on best validation CER (independent of the loss criterion)
    if valid_cer < best_valid_cer:
        best_valid_cer = valid_cer
        torch.save(model.state_dict(), BEST_CER_MODEL_PATH)
        print(f'Validation CER improved. Model saved to {BEST_CER_MODEL_PATH}.')

    end_time = time.time()
    epoch_mins, epoch_secs = divmod(int(end_time - start_time), 60)

    print(f'Epoch: {epoch:02} | Time: {epoch_mins}m {epoch_secs}s')

    if stop_training:
        print('Early stopping triggered.')
        break

In [None]:
# Plot training and validation loss per epoch and save the figure
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
ax.plot(range(1, len(valid_losses) + 1), valid_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training and Validation Loss')
fig.savefig(LOSS_PLOT_PATH)
plt.show()
print(f'Loss plot saved to {LOSS_PLOT_PATH}.')

In [None]:
# Plot training and validation WER per epoch and save the figure
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_wers) + 1), train_wers, label='Train WER')
ax.plot(range(1, len(valid_wers) + 1), valid_wers, label='Validation WER')
ax.set_xlabel('Epoch')
ax.set_ylabel('WER')
ax.legend()
ax.set_title('Training and Validation WER Over Epochs')
fig.savefig(WER_PLOT_PATH)
plt.show()
print(f'WER plot saved to {WER_PLOT_PATH}.')

In [None]:
# Plot training and validation CER per epoch and save the figure
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_cers) + 1), train_cers, label='Train CER')
ax.plot(range(1, len(valid_cers) + 1), valid_cers, label='Validation CER')
ax.set_xlabel('Epoch')
ax.set_ylabel('CER')
ax.legend()
ax.set_title('Training and Validation CER Over Epochs')
fig.savefig(CER_PLOT_PATH)
plt.show()
print(f'CER plot saved to {CER_PLOT_PATH}.')

In [None]:
# Function for inference
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one source sentence.

    Args:
        model: model exposing `encode`, `decode` and `generator`.
        src: source ids of shape [1, src_len]; only a single sentence is
            supported because the next token is read with `.item()`.
        src_mask: attention mask for the encoder.
        max_len: maximum length of the returned sequence, start symbol included.
        start_symbol: id used as the first decoder input.

    Returns:
        Long tensor of shape [1, out_len] beginning with `start_symbol` and
        ending with EOS_IDX if one was generated before `max_len` was reached.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1)).to(DEVICE)
            # decode() returns sequence-first output; make it batch-first
            out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
            # Score only the last position and take its most likely token
            next_word = model.generator(out[:, -1]).argmax(dim=1).item()
            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=1)
            if next_word == EOS_IDX:
                break
    return ys

In [None]:
def translate_sentence(sentence, model, word2idx, idx2word, device, max_len=MAX_LEN, decoding_strategy='greedy'):
    """Translate one raw sentence and return the output as a space-joined string.

    The sentence is tokenized, mapped to ids (UNK for unknown tokens), wrapped
    in BOS/EOS and cut to `max_len`. Only `decoding_strategy='greedy'` is
    implemented: 'beam' raises NotImplementedError, anything else ValueError.
    PAD, BOS, EOS and UNK ids never appear in the returned text.
    """
    model.eval()

    # Source ids: <s> tokens </s>, truncated to max_len
    src_ids = [BOS_IDX, *(word2idx.get(tok, UNK_IDX) for tok in tokenize(sentence)), EOS_IDX][:max_len]
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_len = src.shape[1]
    # All-False mask: every source position may attend to every other
    src_mask = torch.zeros(src_len, src_len).type(torch.bool).to(device)

    if decoding_strategy == 'beam':
        # Implement beam search decoding as needed
        raise NotImplementedError("Beam search decoding is not implemented yet.")
    if decoding_strategy != 'greedy':
        raise ValueError("Invalid decoding strategy")
    decoded = greedy_decode(model, src, src_mask, max_len, BOS_IDX).flatten()

    # Drop the leading BOS, then keep everything before the first EOS
    out_ids = decoded.cpu().numpy()[1:].tolist()
    if EOS_IDX in out_ids:
        out_ids = out_ids[:out_ids.index(EOS_IDX)]

    skipped = (PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX)
    return ' '.join(idx2word.get(idx, '') for idx in out_ids if idx not in skipped)

In [None]:
# Function to calculate CER and WER
def calculate_metrics(references, hypotheses):
    """Return (average CER, average WER) over paired references and hypotheses.

    Pairs are formed with zip, so any surplus items in the longer input are
    ignored.
    """
    scores = [(cer(ref, hyp), wer(ref, hyp)) for ref, hyp in zip(references, hypotheses)]
    avg_cer = np.mean([cer_score for cer_score, _ in scores])
    avg_wer = np.mean([wer_score for _, wer_score in scores])
    return avg_cer, avg_wer

In [None]:
# Evaluate and save results
def evaluate_and_save(model, df, src_sequences, trg_sequences, word2idx, idx2word, file_name):
    """Translate every row of `df`, score it, and write the results to CSV.

    Args:
        model: trained model; switched to eval mode.
        df: frame with 'inFormalForm' (source) and 'FormalForm' (target)
            columns, in the same row order as the id sequences.
        src_sequences: preprocessed source ids; used only for the progress
            total (the raw source text from `df` is what gets translated).
        trg_sequences: preprocessed target ids; the reference sentence is
            rebuilt from them so it is tokenized like the prediction.
        word2idx, idx2word: vocabulary mappings.
        file_name: CSV file name, created inside OUTPUT_DIR.

    Returns:
        DataFrame with Source, Target, Prediction, CER and WER columns,
        sorted by CER then WER (ascending).
    """
    model.eval()
    predictions = []
    cer_scores = []
    wer_scores = []
    skipped = (BOS_IDX, EOS_IDX, PAD_IDX)

    # Translate the raw source text instead of text rebuilt from ids: rebuilding
    # turns every out-of-vocabulary word into the literal string '<unk>', which
    # the tokenizer then splits into '<', 'unk', '>' rather than one UNK token.
    src_texts = df['inFormalForm'].tolist()

    with torch.no_grad():
        for src_text, trg_ids in tqdm(zip(src_texts, trg_sequences), total=len(src_sequences), desc=f'Evaluating {file_name}'):
            trg_sentence = ' '.join([idx2word.get(idx, '') for idx in trg_ids if idx not in skipped])

            pred_sentence = translate_sentence(src_text, model, word2idx, idx2word, DEVICE)
            predictions.append(pred_sentence)
            cer_scores.append(cer(trg_sentence, pred_sentence))
            wer_scores.append(wer(trg_sentence, pred_sentence))

    results_df = pd.DataFrame({
        'Source': df['inFormalForm'],
        'Target': df['FormalForm'],
        'Prediction': predictions,
        'CER': cer_scores,
        'WER': wer_scores
    })

    # Best predictions first
    results_df = results_df.sort_values(by=['CER', 'WER'], ascending=[True, True])

    results_path = os.path.join(OUTPUT_DIR, file_name)
    results_df.to_csv(results_path, index=False)
    avg_cer = np.mean(cer_scores)
    avg_wer = np.mean(wer_scores)
    print(f'Results saved to {results_path}')
    print(f'Average CER: {avg_cer:.4f}')
    print(f'Average WER: {avg_wer:.4f}')
    return results_df

In [None]:
# Load the best model (lowest validation loss).
# map_location remaps the stored tensors onto the current DEVICE, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
print('Best model loaded.')

# Evaluate on training data
print('Evaluating on training data...')
train_results = evaluate_and_save(model, train_df, train_src, train_trg, word2idx, idx2word, 'train_results_word_transformer.csv')

# Evaluate on validation data
print('Evaluating on validation data...')
val_results = evaluate_and_save(model, val_df, val_src, val_trg, word2idx, idx2word, 'val_results_word_transformer.csv')

# Evaluate on test data
print('Evaluating on test data...')
test_results = evaluate_and_save(model, test_df, test_src, test_trg, word2idx, idx2word, 'test_results_word_transformer.csv')

In [None]:
# Load the best CER model (lowest validation CER).
# map_location remaps the stored tensors onto the current DEVICE, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load(BEST_CER_MODEL_PATH, map_location=DEVICE))
print('Best CER model loaded.')

# Evaluate on training data
print('Evaluating on training data using best CER model...')
train_results = evaluate_and_save(model, train_df, train_src, train_trg, word2idx, idx2word, 'train_results_best_cer.csv')

# Evaluate on validation data
print('Evaluating on validation data using best CER model...')
val_results = evaluate_and_save(model, val_df, val_src, val_trg, word2idx, idx2word, 'val_results_best_cer.csv')

# Evaluate on test data using the best CER model
print('Evaluating on test data using best CER model...')
test_results = evaluate_and_save(model, test_df, test_src, test_trg, word2idx, idx2word, 'test_results_best_cer.csv')