# In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import DebertaForQuestionAnswering, RobertaForQuestionAnswering
from short_answer_dataset import ShortAnswerDataset
import json
import matplotlib as plt
import numpy as np
import seaborn as sns

# Prefer the default CUDA device when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def evaluate(model, val_data, batch_size=4, max_length=512, max_answers=3):
    """Evaluate an extractive question-answering model on ``val_data``.

    The raw data is wrapped in ``ShortAnswerDataset`` and iterated in order
    (``shuffle=False``). For every row, the argmax start and end indices of
    the model's logits form the predicted span. That span is stored next to
    the row's gold ``(start_positions, end_positions)`` span, grouped by the
    row's ``example_id``. Each example is then scored by taking, per metric,
    the best value over all (predicted, gold) pairs sharing its id. The
    precision and recall maxima may therefore come from different pairs.

    Args:
        model: QA model whose forward pass accepts ``input_ids``,
            ``attention_mask``, ``start_positions`` and ``end_positions``, and
            returns an object with ``loss``, ``start_logits`` and
            ``end_logits``.
        val_data: Raw examples forwarded to ``ShortAnswerDataset``.
        batch_size: DataLoader batch size.
        max_length: Forwarded to ``ShortAnswerDataset``.
        max_answers: Forwarded to ``ShortAnswerDataset``.

    Returns:
        ``(mean_loss, em, f1, precision, recall)``. ``mean_loss`` averages
        only the batches whose loss was finite, and is ``nan`` when there were
        none. The four metrics are averaged over examples, and are 0 when
        there are no examples.
    """
    val_dataset = ShortAnswerDataset(val_data, max_length=max_length, max_answers=max_answers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    total_loss = 0.0
    finite_loss_batches = 0
    predictions_dict = {}
    true_answers_dict = {}

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            example_ids = batch['example_id']
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)

            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions,
            )

            loss = outputs.loss
            if torch.isfinite(loss):
                total_loss += loss.item()
                finite_loss_batches += 1
            else:
                # Only the loss is unusable; the batch's predictions are still
                # scored so these examples do not silently drop out of EM/F1.
                print("Non-finite loss encountered during evaluation.")

            start_preds = torch.argmax(outputs.start_logits, dim=-1)
            end_preds = torch.argmax(outputs.end_logits, dim=-1)

            for i, raw_id in enumerate(example_ids):
                # Ids may arrive as tensors (numeric ids) or as plain values.
                example_id = raw_id.item() if isinstance(raw_id, torch.Tensor) else raw_id
                pred_span = (start_preds[i].item(), end_preds[i].item())
                true_span = (start_positions[i].item(), end_positions[i].item())
                predictions_dict.setdefault(example_id, []).append(pred_span)
                true_answers_dict.setdefault(example_id, []).append(true_span)

    em, f1, precision, recall = _aggregate_span_metrics(predictions_dict, true_answers_dict)

    # Divide by the batches that actually contributed a loss, not len(val_loader).
    mean_loss = total_loss / finite_loss_batches if finite_loss_batches else float('nan')
    return mean_loss, em, f1, precision, recall


def _best_pair_scores(pred_spans, true_spans):
    """Return ``[em, f1, precision, recall]``, each maximized independently over all span pairs."""
    best = [0, 0, 0, 0]
    for pred_span in pred_spans:
        for true_span in true_spans:
            scores = (
                exact_match_score(pred_span, true_span),
                f1_score(pred_span, true_span),
                precision_score(pred_span, true_span),
                recall_score(pred_span, true_span),
            )
            best = [max(current, score) for current, score in zip(best, scores)]
    return best


def _aggregate_span_metrics(predictions_dict, true_answers_dict):
    """Average each example's best-pair scores; returns ``(em, f1, precision, recall)``."""
    per_example = [
        _best_pair_scores(pred_spans, true_answers_dict[example_id])
        for example_id, pred_spans in predictions_dict.items()
    ]
    if not per_example:
        return 0, 0, 0, 0
    count = len(per_example)
    # zip(*rows) transposes to one column per metric: em, f1, precision, recall.
    em, f1, precision, recall = (sum(column) / count for column in zip(*per_example))
    return em, f1, precision, recall

def exact_match_score(pred_span, true_span):
    """Return 1 if the predicted span is identical to the gold span, otherwise 0."""
    return 1 if pred_span == true_span else 0

def f1_score(pred_span, true_span):
    """Token-level F1 between two inclusive ``(start, end)`` index spans.

    Returns the harmonic mean of overlap precision and recall, or 0 when the
    spans share no index.
    """
    (p_lo, p_hi), (t_lo, t_hi) = pred_span, true_span
    overlap = min(p_hi, t_hi) - max(p_lo, t_lo) + 1
    if overlap > 0:
        prec = overlap / (p_hi - p_lo + 1)
        rec = overlap / (t_hi - t_lo + 1)
        if prec + rec > 0:
            return 2 * prec * rec / (prec + rec)
    return 0

def precision_score(pred_span, true_span):
    """Fraction of the predicted span's indices that lie inside the gold span.

    Spans are inclusive ``(start, end)`` pairs; returns 0 when they do not overlap.
    """
    (p_lo, p_hi), (t_lo, t_hi) = pred_span, true_span
    overlap = min(p_hi, t_hi) - max(p_lo, t_lo) + 1
    if overlap <= 0:
        return 0
    return overlap / (p_hi - p_lo + 1)

def recall_score(pred_span, true_span):
    """Fraction of the gold span's indices that are covered by the predicted span.

    Spans are inclusive ``(start, end)`` pairs; returns 0 when they do not overlap.
    """
    (p_lo, p_hi), (t_lo, t_hi) = pred_span, true_span
    overlap = min(p_hi, t_hi) - max(p_lo, t_lo) + 1
    if overlap <= 0:
        return 0
    return overlap / (t_hi - t_lo + 1)


# Alternative backbone (the checkpoint loaded below must match the chosen architecture):
# model = DebertaForQuestionAnswering.from_pretrained('microsoft/deberta-base')
model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads when `device` fell back to the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load('RoBERTa_64_EM_78_F1.pth', map_location=device))
model.to(device)

def load_data(file_path, subset_size=None, seed=None):
    """Load a JSON Lines file, optionally down-sampling it.

    Args:
        file_path: Path to a UTF-8 ``.jsonl`` file holding one JSON value per
            line. Blank (whitespace-only) lines are skipped.
        subset_size: If given and smaller than the number of records, return
            a random sample of that many records, drawn without replacement.
        seed: Optional seed that makes the sample reproducible. ``None`` draws
            from the module-level ``random`` generator.

    Returns:
        The decoded records as a list.
    """
    # Local import: `random` is not imported at module level in this file.
    import random

    with open(file_path, 'r', encoding='utf-8') as f:
        data = [json.loads(line.strip()) for line in f if line.strip()]

    if subset_size is not None and subset_size < len(data):
        sampler = random if seed is None else random.Random(seed)
        data = sampler.sample(data, subset_size)
    return data

# Evaluate on up to 5000 dev examples (randomly sampled when the file holds more).
val_data = load_data('short_answers_only-dev.jsonl', 5000)

val_loss, val_em, val_f1, val_precision, val_recall = evaluate(
    model, val_data, batch_size=8, max_length=256, max_answers=1
)

# Report the returned loss and the four span metrics.
print(f"Validation Loss: {val_loss}")
print(f"Exact Match (EM): {val_em}")
print(f"F1 Score: {val_f1}")
print(f"Precision: {val_precision}")
print(f"Recall: {val_recall}")
