In [None]:
import json
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from yes_no_dataset import YesNoDataset
from transformers import RobertaForSequenceClassification, DebertaForSequenceClassification, BertForSequenceClassification
import numpy as np
import random

def set_seed(seed=42):
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    then, only when CUDA is available, all CUDA devices as well.

    Args:
        seed: integer seed shared by every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed()

# Fine-tuned checkpoint and validation split (one JSON object per line).
model_path = 'BaseBERT_76.pth'
val_data_file = 'simplified-yes-no-dev.jsonl'

# Read explicitly as UTF-8 rather than the platform default encoding, and skip
# blank lines (e.g. a trailing newline), which json.loads would reject.
with open(val_data_file, 'r', encoding='utf-8') as f:
    val_data = [json.loads(line) for line in f if line.strip()]

val_dataset = YesNoDataset(val_data, max_length=128)
# shuffle=False keeps predictions in dataset order.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint's architecture must match the model built here. Alternative
# backbones (swap in together with the matching checkpoint):
#   RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
#   DebertaForSequenceClassification.from_pretrained('microsoft/deberta-base', num_labels=2)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# map_location remaps tensors that were saved on a GPU onto `device`, so the
# CPU fallback above actually works on machines without CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (PyTorch versions that support it can pass weights_only=True for a state dict).
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

def evaluate(model, val_loader, device):
    """Run ``model`` over ``val_loader`` without gradients and collect predictions.

    Args:
        model: classifier called as ``model(input_ids, attention_mask=...)``
            whose output exposes ``.logits``.
        val_loader: iterable of batches; each batch is a mapping with
            'input_ids', 'attention_mask' and 'label' tensors.
        device: torch device every batch tensor is moved to.

    Returns:
        ``(avg_loss, all_labels, all_preds)``:
        - ``avg_loss`` is the cross-entropy averaged over every example, not
          over batches, so a short final batch is not over-weighted. It is NaN
          when the loader yields no examples.
        - ``all_labels`` and ``all_preds`` are the true labels and the argmax
          predictions, as Python lists in loader order.
    """
    model.eval()
    all_labels = []
    all_preds = []
    # Sum (not mean) per batch, then divide by the total example count below.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    total_loss = 0.0
    num_examples = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask=attention_mask).logits
            total_loss += criterion(logits, labels).item()
            num_examples += labels.size(0)

            # The predicted class is the highest-scoring logit.
            preds = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    avg_loss = total_loss / num_examples if num_examples else float('nan')
    return avg_loss, all_labels, all_preds

# One pass over the validation split; the loss, labels and predictions are reused by the cells below.
val_loss, val_labels, val_preds = evaluate(model=model, val_loader=val_loader, device=device)

In [None]:
# Per-class precision / recall / F1. Class names mirror the confusion-matrix tick
# labels (0 -> 'No', 1 -> 'Yes'); labels=[0, 1] keeps both rows even if one
# class never appears in the labels or predictions.
print("Classification Report:")
print(classification_report(val_labels, val_preds, labels=[0, 1], target_names=['No', 'Yes']))

In [None]:
# Confusion matrix for the validation split. Pinning labels=[0, 1] keeps the
# matrix 2x2, so it always lines up with the two tick labels, even when one
# class is absent from both the labels and the predictions.
class_names = ['No', 'Yes']
cm = confusion_matrix(val_labels, val_preds, labels=[0, 1])

fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
# Headline metrics. With average='binary', sklearn scores only the positive
# class (its default pos_label=1).
precision, recall, f1 = (
    scorer(val_labels, val_preds, average='binary')
    for scorer in (precision_score, recall_score, f1_score)
)
accuracy = (np.asarray(val_labels) == np.asarray(val_preds)).mean()

print(f'Validation Loss: {val_loss:.4f}')
print(f'Validation Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')