General Evaluation Function for Multiple Models & Datasets

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)

In [6]:
def _collect_predictions(model, loader, device):
    """Run ``model`` over every batch of ``loader`` without gradients.

    Returns:
        tuple: ``(y_true, y_pred, y_prob)`` lists holding the true labels,
        the argmax class predictions and the softmax score of class index 1.
    """
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            # Column 1 of the softmax is used as the positive-class score;
            # assumes outputs are (batch, 2) logits -- TODO confirm.
            probs = torch.softmax(outputs, dim=1)[:, 1]
            preds = torch.argmax(outputs, dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
    return y_true, y_pred, y_prob


def evaluate_models(models, model_names, data_loaders, dataset_names, device='cuda'):
    """
    Evaluate every model on every dataset and report binary-classification metrics.

    For each (model, dataset) pair this computes accuracy, precision, recall,
    F1, ROC AUC and the confusion matrix, shows a confusion-matrix heatmap and
    a ROC curve, and prints a classification report. Class index 0 is labelled
    'Real' and class index 1 is labelled 'AI' in plots and reports.

    Args:
        models (list): list of PyTorch models.
        model_names (list): list of model names (strings), one per model.
        data_loaders (list): list of DataLoader objects yielding
            ``(images, labels)`` batches.
        dataset_names (list): list of dataset names (strings), one per loader.
        device (str): device to run evaluation ('cuda' or 'cpu').

    Returns:
        results_dict (dict): nested dictionary
            ``results_dict[model_name][dataset_name]`` with the keys
            'accuracy', 'precision', 'recall', 'f1', 'roc_auc' and
            'confusion_matrix'. 'roc_auc' is NaN when a dataset contains
            only one class, because ROC AUC is undefined there.

    Raises:
        ValueError: if ``models``/``model_names`` or
            ``data_loaders``/``dataset_names`` differ in length.
    """
    # zip() would silently drop unmatched entries, so fail loudly instead.
    if len(models) != len(model_names):
        raise ValueError('models and model_names must have the same length')
    if len(data_loaders) != len(dataset_names):
        raise ValueError('data_loaders and dataset_names must have the same length')

    class_ids = [0, 1]
    class_names = ['Real', 'AI']
    results_dict = {}

    for model, m_name in zip(models, model_names):
        model.to(device)
        model.eval()
        results_dict[m_name] = {}

        for loader, d_name in zip(data_loaders, dataset_names):
            y_true, y_pred, y_prob = _collect_predictions(model, loader, device)

            # Compute metrics (inside the dataset loop so that every dataset,
            # not only the last one, gets evaluated).
            acc = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred)
            rec = recall_score(y_true, y_pred)
            f1 = f1_score(y_true, y_pred)
            # Pin the label set so the matrix is always 2x2 and matches the
            # two tick labels, even if one class is absent from this dataset.
            cm = confusion_matrix(y_true, y_pred, labels=class_ids)

            # roc_auc_score raises when y_true holds a single class.
            has_both_classes = len(np.unique(y_true)) == 2
            auc = roc_auc_score(y_true, y_prob) if has_both_classes else float('nan')

            # Store results
            results_dict[m_name][d_name] = {
                'accuracy': acc,
                'precision': prec,
                'recall': rec,
                'f1': f1,
                'roc_auc': auc,
                'confusion_matrix': cm,
            }

            # --- Plot confusion matrix ---
            plt.figure(figsize=(5, 4))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names)
            plt.title(f'Confusion Matrix: {m_name} on {d_name}')
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.show()

            # --- ROC Curve ---
            if has_both_classes:
                fpr, tpr, _ = roc_curve(y_true, y_prob)
                plt.figure()
                plt.plot(fpr, tpr, label=f'{m_name} on {d_name} (AUC={auc:.2f})')
                plt.plot([0, 1], [0, 1], '--', color='gray')
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('ROC Curve')
                plt.legend()
                plt.show()
            else:
                print(f'Skipping ROC curve for {m_name} on {d_name}: '
                      'only one class present in the labels.')

            # --- Classification report ---
            print(f'Classification Report for {m_name} on {d_name}:')
            print(classification_report(y_true, y_pred, labels=class_ids,
                                        target_names=class_names))

    return results_dict

In [5]:
# --- Example Usage ---
# models = [model1, model2]
# model_names = ['ResNet50', 'EfficientNetB0']
# data_loaders = [val_loader1, val_loader2]
# dataset_names = ['CIFAKE', 'ExtraDataset']
# results = evaluate_models(models, model_names, data_loaders, dataset_names)