<a href="https://colab.research.google.com/github/MartijnRoozendaal1/vvr-prediction/blob/main/Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Setup

In [None]:
# Import necessary libraries
import numpy as np
import torch
from sklearn.metrics import (
    precision_score, recall_score, f1_score, matthews_corrcoef,
    precision_recall_curve, auc, ConfusionMatrixDisplay
)
import matplotlib.pyplot as plt

2. Load Data

In [None]:
# Build the test-set DataLoader from the saved tensors
# (update the two paths below if your files live elsewhere).
from torch.utils.data import DataLoader, TensorDataset

X_tensor_test = torch.load('/content/drive/MyDrive/X_tensor_test.pt')
y_tensor_test = torch.load('/content/drive/MyDrive/y_tensor_test.pt')

# Pair each feature row with its label, then serve the pairs in batches.
test_dataset = TensorDataset(X_tensor_test, y_tensor_test)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=32,  # Adjust batch size if needed
)


3. Evaluation

In [None]:
# Function to evaluate the model and optimize the threshold
def evaluate_model_with_threshold(model, test_loader):
    """
    Evaluate a binary classifier, tune its decision threshold, and plot results.

    The model is run over every batch of ``test_loader``; its raw outputs are
    treated as logits and passed through a sigmoid. Thresholds from 0.10 to
    0.99 (step 0.01) are scanned and the one with the highest Matthews
    correlation coefficient (MCC) is used for the final predictions.

    Note:
        The threshold is selected on the same data the metrics are computed
        on, so the reported scores are optimistically biased. For an unbiased
        estimate, choose the threshold on a separate validation split.

    Args:
        model: Trained PyTorch model producing one logit per sample.
        test_loader: DataLoader yielding ``(X_batch, y_batch)`` pairs.

    Raises:
        ValueError: If ``test_loader`` yields no samples.

    Displays:
        - Optimal threshold and its MCC
        - Metrics: Precision, Recall, F1, PR-AUC
        - Confusion Matrix
        - Precision-Recall Curve
    """
    model.eval()
    all_labels, all_logits = [], []

    # Generate predictions
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            # reshape(-1) rather than squeeze(): squeeze() turns a batch of a
            # single sample into a 0-d tensor, whose tolist() is a bare float
            # that list.extend() cannot iterate (TypeError on the last batch
            # whenever the dataset size modulo the batch size is 1).
            outputs = model(X_batch).reshape(-1)
            all_logits.extend(outputs.tolist())
            # Flatten labels the same way so (N, 1) label tensors line up
            # one-to-one with the flattened logits.
            all_labels.extend(y_batch.reshape(-1).tolist())

    if not all_logits:
        raise ValueError("test_loader yielded no samples; nothing to evaluate.")

    all_logits = np.array(all_logits)
    all_labels = np.array(all_labels)

    # The sigmoid does not depend on the threshold, so compute it once
    # instead of on every iteration of the search below.
    probabilities = torch.sigmoid(torch.tensor(all_logits)).numpy()

    # Find the optimal threshold (first threshold wins on ties)
    best_mcc, best_threshold = -1, 0.5
    for threshold in np.arange(0.1, 1, 0.01):
        predictions = (probabilities > threshold).astype(np.float32)
        mcc = matthews_corrcoef(all_labels, predictions)
        if mcc > best_mcc:
            best_mcc = mcc
            best_threshold = threshold

    # Final predictions
    final_predictions = (probabilities > best_threshold).astype(np.float32)

    # Calculate metrics
    precision = precision_score(all_labels, final_predictions)
    recall = recall_score(all_labels, final_predictions)
    f1 = f1_score(all_labels, final_predictions)

    # Compute precision-recall curve and AUC. Raw logits are used as scores:
    # the sigmoid is monotonic, so the ranking of samples is the same.
    precisions, recalls, _ = precision_recall_curve(all_labels, all_logits)
    pr_auc = auc(recalls, precisions)

    # Print metrics
    print(f"Optimal Threshold: {best_threshold:.2f}, Best MCC: {best_mcc:.2f}")
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}, PR-AUC: {pr_auc:.2f}")

    # Display Confusion Matrix
    ConfusionMatrixDisplay.from_predictions(all_labels, final_predictions)
    plt.title("Confusion Matrix")
    plt.show()

    # Plot Precision-Recall Curve
    plt.plot(recalls, precisions, marker='.')
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.show()
