# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
    roc_auc_score
)
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
import os
from tqdm import tqdm
import random

# Add missing imports
from torchvision import transforms, datasets
import timm

# --- 1. Load Your Configurations ---
# Root directory handed to torchvision's ImageFolder loader.
DATA_DIR = "dataset"
# Side length (pixels) of the square center crop fed to the model.
IMG_SIZE = 240
# Use the GPU when one is available, otherwise fall back to the CPU.
# (The previous expression selected "cpu" in both branches, so CUDA was
# never used.)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# timm architecture names to evaluate; each one is expected to have a
# checkpoint at trained_models/<name>_best.pth.
MODEL_NAMES = ['crossvit_tiny_240']

# --- 2. Recreate Validation DataLoader (Same as Training) ---
# Deterministic preprocessing pipeline: resize, center-crop to IMG_SIZE,
# convert to a tensor, then normalize each channel.
# Per-channel statistics (presumably the standard ImageNet mean/std -- confirm
# they match what was used during training).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
val_transforms = transforms.Compose(_val_steps)

# Load class mappings (assuming same as training)
# No transform is attached: in the code visible here this instance is only
# consulted for its class metadata, never iterated for images.
full_dataset = datasets.ImageFolder(root=DATA_DIR)
NUM_CLASSES = len(full_dataset.classes)
# Inverse lookup: integer class index -> class name.
idx_to_class = {
    class_index: class_name
    for class_name, class_index in full_dataset.class_to_idx.items()
}

# --- 3. Evaluation Function ---
def evaluate_model(model, val_loader):
    """Run ``model`` over ``val_loader`` and compute classification metrics.

    Args:
        model: Module called as ``model(images)``. Its output is treated as
            per-class scores along dimension 1 (softmax/argmax are taken
            over that dimension).
        val_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        A dict of JSON-serialisable values with the keys ``accuracy``,
        ``precision``, ``recall``, ``f1_score`` (weighted averages),
        ``roc_auc`` (float, or ``None`` when it could not be computed),
        ``confusion_matrix`` (nested lists), ``classification_report``
        (dict), ``predictions``, ``true_labels`` and ``class_probabilities``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(DEVICE)

            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            preds = outputs.argmax(dim=1)

            # ``tolist`` produces native Python ints/floats directly, so the
            # collected values are ready for ``json.dump``.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())
            all_probs.extend(probs.cpu().tolist())

    class_names = list(full_dataset.class_to_idx.keys())
    # Pin the label set explicitly so the report and the confusion matrix
    # always cover every class, even one that never occurs in this split.
    # Without it, classification_report raises when the number of observed
    # classes differs from len(target_names), and the matrix shrinks.
    label_ids = list(range(len(class_names)))

    # Calculate metrics
    accuracy = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, average="weighted"))
    recall = float(recall_score(all_labels, all_preds, average="weighted"))
    f1 = float(f1_score(all_labels, all_preds, average="weighted"))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids).tolist()
    cls_report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
    )

    # ROC-AUC is best-effort: a failure here must not discard the metrics
    # above, but the reason is reported instead of being silently dropped.
    try:
        if len(class_names) == 2:
            # Binary targets need a 1-D score for the class with the greater
            # label; passing the full two-column matrix raises.
            positive_scores = [row[1] for row in all_probs]
            roc_auc = float(roc_auc_score(all_labels, positive_scores))
        else:
            roc_auc = float(
                roc_auc_score(
                    all_labels,
                    all_probs,
                    labels=label_ids,
                    multi_class="ovo",
                    average="weighted",
                )
            )
    except Exception as err:  # deliberate best-effort boundary
        print(f"ROC-AUC could not be computed: {err}")
        roc_auc = None

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "roc_auc": roc_auc,
        "confusion_matrix": cm,
        "classification_report": cls_report,
        "predictions": all_preds,
        "true_labels": all_labels,
        "class_probabilities": all_probs,
    }

# --- 4. Visualization Functions ---
def plot_confusion_matrix(cm, class_names, model_name):
    """Draw ``cm`` as an annotated heatmap and save it to a PNG file.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with
            ``"d"``).
        class_names: Tick labels used for both axes.
        model_name: Used in the plot title and as the prefix of the output
            file ``<model_name>_confusion_matrix.png``.
    """
    figure, axes = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=axes,
    )
    axes.set_title(f"Confusion Matrix - {model_name}")
    axes.set_xlabel("Predicted")
    axes.set_ylabel("True")
    figure.savefig(f"{model_name}_confusion_matrix.png", bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

# --- 5. Main Evaluation Loop ---
def _build_val_loader():
    """Rebuild the validation split and return ``(loader, n_val, n_total)``."""
    val_dataset = datasets.ImageFolder(root=DATA_DIR, transform=val_transforms)

    # Recreate the exact same train/val split as training
    random.seed(42)  # Same seed as training
    indices = list(range(len(val_dataset)))
    random.shuffle(indices)
    split_point = int(0.8 * len(indices))  # Same ratio as training
    val_indices = indices[split_point:]  # Only validation indices

    # A fixed-order Subset (instead of a random sampler) visits exactly the
    # same samples but in a stable order, so the per-sample predictions that
    # get saved line up with ``val_indices`` on every run.
    val_subset = torch.utils.data.Subset(val_dataset, val_indices)
    val_loader = DataLoader(
        val_subset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps when batches are copied to a GPU.
        pin_memory=DEVICE.type == "cuda",
    )
    return val_loader, len(val_indices), len(val_dataset)


def _print_summary(model_name, results):
    """Print the headline metrics and per-class scores for one model."""
    print(f"\nResults for {model_name}:")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Precision: {results['precision']:.4f}")
    print(f"Recall: {results['recall']:.4f}")
    print(f"F1-Score: {results['f1_score']:.4f}")
    # Compare against None explicitly: a legitimate score of 0.0 is falsy.
    if results['roc_auc'] is not None:
        print(f"ROC-AUC: {results['roc_auc']:.4f}")
    print("\nClassification Report:")
    for class_name, metrics in results["classification_report"].items():
        # Scalar entries (e.g. overall accuracy) are not dicts; skip them.
        if isinstance(metrics, dict) and 'precision' in metrics:
            print(f"{class_name}: Precision={metrics['precision']:.3f}, Recall={metrics['recall']:.3f}, F1={metrics['f1-score']:.3f}")


def _save_comparison(all_results):
    """Print a per-model metric table and write it to model_comparison.csv."""
    print("\n=== Model Comparison ===")
    comparison_df = pd.DataFrame.from_dict({
        model: {
            'Accuracy': results['accuracy'],
            'Precision': results['precision'],
            'Recall': results['recall'],
            'F1-Score': results['f1_score'],
            # Already a float or None; keep 0.0 instead of coercing to None.
            'ROC-AUC': results['roc_auc'],
        }
        for model, results in all_results.items()
    }, orient='index')

    print(comparison_df)

    # Save comparison to CSV
    comparison_df.to_csv("model_comparison.csv")
    print("\nModel comparison saved to model_comparison.csv")


def main():
    """Evaluate every model in ``MODEL_NAMES`` on the validation split.

    For each model whose checkpoint exists at
    ``trained_models/<name>_best.pth`` this writes ``<name>_metrics.json``
    and ``<name>_confusion_matrix.png`` and prints a summary. When at least
    one model was evaluated, a comparison table is printed and saved to
    ``model_comparison.csv``. Models without a checkpoint are skipped.
    """
    val_loader, num_val, num_total = _build_val_loader()
    print(f"Evaluating on {num_val} validation samples out of {num_total} total samples")

    class_names = list(full_dataset.class_to_idx.keys())
    all_results = {}

    for model_name in MODEL_NAMES:
        print(f"\n=== Evaluating {model_name} ===")

        # Check for the checkpoint first so no model is built just to be
        # thrown away.
        model_path = f"trained_models/{model_name}_best.pth"
        if not os.path.exists(model_path):
            print(f"Model weights not found at {model_path}. Skipping...")
            continue

        # Load model architecture
        model = timm.create_model(model_name, pretrained=False, num_classes=NUM_CLASSES)
        model.name = model_name

        # Load trained weights
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model = model.to(DEVICE)

        # Evaluate
        results = evaluate_model(model, val_loader)
        all_results[model_name] = results

        # Save metrics
        with open(f"{model_name}_metrics.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)

        # Visualizations
        plot_confusion_matrix(
            np.array(results["confusion_matrix"]),
            class_names,
            model_name
        )

        _print_summary(model_name, results)

    # Compare all models
    if all_results:
        _save_comparison(all_results)
    else:
        print("No models were successfully evaluated.")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()