In [1]:
import torch
import pickle

from utils.utils import create_data_pipeline
from models.bilstm2.bilstm2 import BiLSTM
# from models.blstm.blstm import BLSTM

In [2]:
# Pick the compute device once: the GPU when PyTorch reports CUDA as
# available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load the two lookup tables used throughout the notebook. Both are read with
# pickle, which can execute arbitrary code while loading, so only point these
# paths at files you trust.

# Diacritic label -> integer id.
with open('utils/diacritic2id.pickle', 'rb') as diacritic_file:
    diacritic2idx = pickle.load(diacritic_file)

# Letter -> integer id.
with open('utils/letter2idx.pickle', 'rb') as letter_file:
    letter2idx = pickle.load(letter_file)

In [4]:
# Model dimensions come straight from the sizes of the two lookup tables.
vocab_size, num_classes = len(letter2idx), len(diacritic2idx)

In [5]:
def pad_collate_fn(batch):
    """Collate variable-length samples into padded, batch-first tensors.

    Args:
        batch: Sequence of ``(x, y, mask)`` triples of tensors.

    Returns:
        A 4-tuple ``(x_padded, y_padded, mask_padded, lengths)``: inputs padded
        with ``letter2idx['<PAD>']``, labels padded with
        ``diacritic2idx['<PAD>']``, masks padded with ``0``, and a ``torch.long``
        tensor with the unpadded length of each input.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    inputs, labels, masks = zip(*batch)

    # Record the true lengths before padding hides them.
    input_lengths = [len(seq) for seq in inputs]

    inputs_padded = pad(inputs, batch_first=True, padding_value=letter2idx['<PAD>'])
    labels_padded = pad(labels, batch_first=True, padding_value=diacritic2idx['<PAD>'])
    masks_padded = pad(masks, batch_first=True, padding_value=0)

    return (
        inputs_padded,
        labels_padded,
        masks_padded,
        torch.tensor(input_lengths, dtype=torch.long),
    )

In [6]:
# Build the dataset and loader used for evaluation, batching with
# pad_collate_fn so every batch is padded to a common length.
# NOTE(review): the variables are named test_* but the corpus is
# 'data/val.txt' — confirm this is the intended split.
test_dataset, test_loader = create_data_pipeline(
    corpus_path='data/val.txt',
    batch_size=32,
    train=False,
    collate_fn=pad_collate_fn,
    letter2idx=letter2idx,
    diacritic2idx=diacritic2idx,
)

In [7]:
# Instantiate the network with input/output sizes taken from the lookup tables.
model = BiLSTM(
    num_classes=num_classes,
    vocab_size=vocab_size,
)

In [8]:
# Restore the trained weights and the metrics saved alongside them.
# weights_only=True restricts unpickling to tensors and plain Python
# containers/primitives, so a tampered checkpoint cannot execute code on load;
# it also opts in explicitly instead of relying on the version-dependent default.
# NOTE(review): assumes the checkpoint holds only tensors and plain Python
# values; if it stores other pickled objects this call will raise — confirm.
checkpoint = torch.load(
    "./models/bilstm2/bilstm2.pth",
    map_location=device,
    weights_only=True,
)

model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Validation Loss: {checkpoint['val_loss']:.4f}")
print(f"Validation Accuracy: {checkpoint['val_accuracy']:.4f}")

Loaded checkpoint from epoch 6
Validation Loss: 0.0689
Validation Accuracy: 0.9781


  checkpoint = torch.load("./models/bilstm2/bilstm2.pth", map_location=device)


In [9]:
@torch.no_grad()
def evaluate_full_sequence(model, data_loader, pad_idx=None, device=None):
    """Compute accuracy over every non-padding label position.

    Args:
        model: Module invoked as ``model(batch_X)``; predictions are the argmax
            of its output over the last dimension.
        data_loader: Iterable yielding 4-tuples whose first two items are the
            input tensor and the label tensor; the last two items are ignored.
        pad_idx: Label id that marks padding and is excluded from scoring.
            Defaults to ``diacritic2idx['<PAD>']``, the same value
            ``pad_collate_fn`` pads labels with, so the two cannot drift apart.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters).

    Returns:
        Fraction of non-padding positions predicted correctly, or ``0.0`` when
        no position was scored.
    """
    model.eval()

    if pad_idx is None:
        # Look the id up instead of hard-coding it as a magic number.
        pad_idx = diacritic2idx['<PAD>']
    if device is None:
        # Follow the model so inputs always land where the weights are.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_correct = 0
    total_tokens = 0

    for batch_X, batch_y, _, _ in data_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)

        preds = model(batch_X).argmax(dim=-1)

        # Score only real labels; padded positions carry no information.
        mask = batch_y != pad_idx
        total_correct += (preds[mask] == batch_y[mask]).sum().item()
        total_tokens += mask.sum().item()

    return total_correct / total_tokens if total_tokens > 0 else 0.0


@torch.no_grad()
def evaluate_last_char_accuracy(model, data_loader, device=None):
    """Compute accuracy restricted to the positions flagged by the batch mask.

    Args:
        model: Module invoked as ``model(batch_X)``; predictions are the argmax
            of its output over the last dimension.
        data_loader: Iterable yielding 4-tuples whose first three items are the
            input tensor, the label tensor and a mask tensor; the fourth item
            is ignored. Positions where the mask equals ``1`` are scored.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters), so
            the function no longer depends on a notebook-global ``device``.

    Returns:
        Fraction of flagged positions predicted correctly, or ``0.0`` when no
        position was flagged.
    """
    model.eval()

    if device is None:
        # Follow the model so inputs always land where the weights are.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_correct = 0
    total_important = 0

    for batch_X, batch_y, batch_mask, _ in data_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        preds = model(batch_X).argmax(dim=-1)

        # Only positions explicitly marked with 1 count; mask padding is 0.
        mask = batch_mask == 1
        total_correct += (preds[mask] == batch_y[mask]).sum().item()
        total_important += mask.sum().item()

    return total_correct / total_important if total_important > 0 else 0.0

In [10]:
# Run both metrics over the same loader, then report them to four decimals.
full_acc, last_char_acc = (
    evaluate_full_sequence(model, test_loader),
    evaluate_last_char_accuracy(model, test_loader),
)

print(f"Full sequence accuracy: {full_acc:.4f}")
print(f"Last char accuracy: {last_char_acc:.4f}")

Full sequence accuracy: 0.9778
Last char accuracy: 0.9568
