In [None]:
import torch
import re
import numpy as np
import pickle

from torch.utils.data import DataLoader
from models.blstm.blstm import BLSTM
from utils.data_loader import DiacritizationDataset
from utils.utils import preprocess, separate_diacritics
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load the diacritic inventory and the raw UTF-8 test corpus, then clean the text
# with the project's `preprocess` helper.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, project-generated files.
with open('utils/diacritics.pickle', 'rb') as pickle_file:
    diacritics = pickle.load(pickle_file)

with open('data/test.txt', 'r', encoding='utf-8') as text_file:
    test_data = text_file.read()

cleaned_test_data = preprocess(test_data, diacritics)

In [None]:
# Split the cleaned corpus into sentence-like chunks on Latin/Arabic punctuation and
# newlines, discarding chunks that are empty or whitespace-only.
split_punct = {",", ".", "،", ":", "?", "؟", "؛", "«", "»", "\n"}
# Sort before joining so the regex pattern string is the same on every run (set
# iteration order depends on string-hash randomization); the character class matches
# exactly the same characters either way.
split_pattern = f"[{re.escape(''.join(sorted(split_punct)))}]"
test_sentences = [s for s in re.split(split_pattern, cleaned_test_data) if s.strip()]

print(f"Total test_sentences: {len(test_sentences)}")

In [None]:
# Load the symbol -> index vocabularies, then encode every test sentence as a
# sequence of letter ids (X_test) and aligned diacritic ids (y_test).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('utils/diacritic2id.pickle', 'rb') as f:
    diacritic2idx = pickle.load(f)

with open('utils/letter2idx.pickle', 'rb') as f:
    letter2idx = pickle.load(f)

X_test = []
y_test = []

for sentence in test_sentences:
    # Use a distinct name: reusing `diacritics` here would overwrite the diacritic
    # inventory loaded for `preprocess`, leaving hidden state behind.
    chars, sentence_diacritics = separate_diacritics(sentence.strip(), diacritic2idx)
    # Raises KeyError on a letter/diacritic absent from the training vocabularies.
    X_test.append([letter2idx[char] for char in chars])
    y_test.append([diacritic2idx[d] for d in sentence_diacritics])

In [None]:
# Sizes that the model and the padding step must agree on.
vocab_size, num_classes = len(letter2idx), len(diacritic2idx)
# NOTE(review): presumably the padded sequence length used at training time — confirm
# it matches the checkpoint being evaluated.
max_length = 1236

In [None]:
# Inverse vocabulary (id -> letter); on duplicate ids the later entry wins, as before.
idx2letter = dict(zip(letter2idx.values(), letter2idx.keys()))

In [None]:
# Letters and diacritics are scored position-by-position, so each X/y pair must be
# the same length before padding, or the labels would be shifted.
assert all(len(x) == len(y) for x, y in zip(X_test, y_test)), \
    "letter and diacritic sequences are misaligned"

# pad_sequences truncates from the *front* by default ('pre'). Keep that behaviour,
# but make it explicit and report it, because truncated sentences silently drop
# characters from the evaluation.
n_truncated = sum(len(seq) > max_length for seq in X_test)
if n_truncated:
    print(f"Warning: {n_truncated} test sentences exceed max_length={max_length} "
          "and will be truncated (leading characters dropped).")

X_test = pad_sequences(
    X_test,
    maxlen=max_length,
    padding='post',
    truncating='pre',
    value=letter2idx['<PAD>'],
)

y_test = pad_sequences(
    y_test,
    maxlen=max_length,
    padding='post',
    truncating='pre',
    value=diacritic2idx['<PAD>'],
)

In [None]:
# Wrap the padded arrays for batched evaluation; shuffle=False keeps batch order
# deterministic. Pinned host memory only helps host->GPU copies, so request it only
# when evaluating on CUDA (recent PyTorch versions warn when pin_memory=True but no
# accelerator is available).
test_dataset = DiacritizationDataset(X_test, y_test, idx2letter=idx2letter)
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

In [None]:
# Build the network with the same vocabulary / label sizes used to encode the test set.
model = BLSTM(num_classes=num_classes, vocab_size=vocab_size)

In [None]:
# Restore the trained weights (mapped onto `device`), move the model there and switch
# it to inference mode; the model is the cell's displayed output.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
weights = torch.load("./models/blstm/blstm_model.pth", map_location=device)
model.load_state_dict(weights)
del weights  # drop the extra reference so the loaded tensors can be freed
model.to(device).eval()

In [None]:
@torch.no_grad()
def _masked_accuracy(model, data_loader, select_mask, device=None):
    """Accuracy of ``model``'s argmax predictions over positions chosen by ``select_mask``.

    Args:
        model: module mapping a batch of index sequences to per-position scores;
            predictions are ``argmax`` over the last dimension.
        data_loader: iterable of ``(batch_X, batch_y, batch_mask)`` triples.
        select_mask: callable ``(batch_y, batch_mask) -> bool tensor`` choosing the
            positions to score.
        device: where batches are moved; defaults to the device of the model's first
            parameter (CPU if the model has none).

    Returns:
        Fraction of selected positions whose prediction equals the target, or 0.0
        if no position was selected.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_correct = 0
    total_selected = 0

    for batch_X, batch_y, batch_mask in data_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        preds = model(batch_X).argmax(dim=-1)
        mask = select_mask(batch_y, batch_mask)

        total_correct += (preds[mask] == batch_y[mask]).sum().item()
        total_selected += mask.sum().item()

    return total_correct / total_selected if total_selected > 0 else 0.0


def evaluate_full_sequence(model, data_loader, pad_idx=15, device=None):
    """Per-character accuracy over every non-padding target position.

    Args:
        model: trained tagger returning per-position class scores.
        data_loader: iterable of ``(batch_X, batch_y, batch_mask)`` triples.
        pad_idx: target id treated as padding and excluded from scoring. The default
            15 preserves the previous hard-coded value; it should equal
            ``diacritic2idx['<PAD>']``.
        device: optional device override; defaults to the model's device.

    Returns:
        Accuracy as a float in [0, 1] (0.0 when there are no scored positions).
    """
    return _masked_accuracy(model, data_loader, lambda y, _m: y != pad_idx, device)


def evaluate_last_char_accuracy(model, data_loader, device=None):
    """Accuracy restricted to positions where the dataset's ``batch_mask`` equals 1.

    NOTE(review): the name suggests the dataset marks word-final characters with 1 —
    confirm against ``DiacritizationDataset``.

    Args:
        model: trained tagger returning per-position class scores.
        data_loader: iterable of ``(batch_X, batch_y, batch_mask)`` triples.
        device: optional device override; defaults to the model's device.

    Returns:
        Accuracy as a float in [0, 1] (0.0 when no position is marked).
    """
    return _masked_accuracy(model, data_loader, lambda _y, m: m == 1, device)

In [None]:
# Score the restored model on the test set with both metrics, then report them.
full_acc = evaluate_full_sequence(model, test_loader)
last_char_acc = evaluate_last_char_accuracy(model, test_loader)

for report_line in (
    f"Full sequence accuracy: {full_acc:.4f}",
    f"Last char accuracy: {last_char_acc:.4f}",
):
    print(report_line)