In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    precision_recall_fscore_support,
    roc_curve,
    auc
)
import itertools
from torchsummary import summary
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import timm

In [None]:
# ---------------------- USER CONFIGURATION ----------------------
# Each path can be overridden with an environment variable, so the notebook
# runs unchanged on different machines without editing hard-coded local paths.

# 1. Path to the saved model weights (e.g., best_model_efficientnet_b0.pth).
#    main() passes the loaded object to model.load_state_dict(), so the file
#    must deserialize to a state_dict. A relative path resolves against the
#    kernel's current working directory. Override with EVAL_MODEL_PATH.
MODEL_PATH = os.environ.get("EVAL_MODEL_PATH", "best_model_efficientnet_b0.pth")

# 2. Test data directory structured for torchvision.datasets.ImageFolder:
#    TEST_DIR/
#       class0_name/
#         img1.png
#         img2.png
#         ...
#       class1_name/
#         ...
#    ImageFolder assigns class indices from the *sorted* subfolder names. The
#    names (and therefore their order) must match the classes the model was
#    trained on, e.g. "wP", "wR", ..., "empty".
#    Empty by default: set it here or export EVAL_TEST_DIR before running.
TEST_DIR = os.environ.get("EVAL_TEST_DIR", "")

# 3. Batch size for the evaluation DataLoader (mainly affects speed/memory)
BATCH_SIZE = 32

# 4. Device: GPU if CUDA is available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------------------------------------------------------

In [None]:
def _build_test_loader(test_dir, batch_size, img_size=224, num_workers=2):
    """Build the ImageFolder test dataset and a non-shuffled DataLoader over it.

    Args:
        test_dir: Root folder with one sub-folder per class (ImageFolder layout).
        batch_size: Number of images per batch.
        img_size: Side length (pixels) of the square resize applied to every image.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(dataset, loader)`` tuple.

    Raises:
        FileNotFoundError: If ``test_dir`` is empty or is not an existing directory.
    """
    # Fail early with an actionable message instead of a generic scandir error.
    if not test_dir or not os.path.isdir(test_dir):
        raise FileNotFoundError(
            f"Test directory not found: {test_dir!r}. "
            "Set TEST_DIR in the configuration cell."
        )

    test_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Standard ImageNet mean/std; presumably identical to the training
        # preprocessing. TODO confirm against the training notebook.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    dataset = datasets.ImageFolder(root=test_dir, transform=test_transforms)
    # shuffle=False keeps the prediction order aligned with dataset.samples.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    return dataset, loader


def _load_model(model_path, num_classes, device):
    """Rebuild EfficientNet-B0 with a ``num_classes``-way head and load saved weights.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint's layer shapes
            (e.g. its class count) do not match the rebuilt model.
    """
    print("[INFO] Creating EfficientNet-B0 model")
    model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes)
    print(f"[INFO] Loading weights from: {model_path}")
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint file cannot execute code on load (needs torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _run_inference(model, loader, device):
    """Run ``model`` over ``loader`` and collect true labels, predictions and probabilities.

    Returns:
        ``(all_labels, all_preds, all_probs)`` numpy arrays of shapes
        ``(N,)``, ``(N,)`` and ``(N, num_classes)``.

    Raises:
        RuntimeError: If the loader yields no batches.
    """
    label_batches, prob_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            logits = model(inputs.to(device))
            prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Labels are only collected, never fed to the model, so they stay on CPU.
            label_batches.append(labels.numpy())

    if not label_batches:
        raise RuntimeError("The test loader produced no batches; nothing to evaluate.")

    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)
    all_preds = all_probs.argmax(axis=1)
    return all_labels, all_preds, all_probs


def _plot_confusion_matrix(all_labels, all_preds, class_names, out_path="confusion_matrix.png"):
    """Save a row-normalized confusion-matrix heatmap and return the raw count matrix."""
    num_classes = len(class_names)
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    row_sums = cm.sum(axis=1, keepdims=True)
    # A class absent from the test set has an all-zero row. Leave it at 0 instead of
    # producing NaN, which would also poison cm_norm.max() and the text-color threshold.
    cm_norm = np.divide(
        cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title('Normalized Confusion Matrix')
    fig.colorbar(image, ax=ax)
    tick_marks = np.arange(num_classes)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Use white text on dark cells and black text on light cells.
    thresh = cm_norm.max() / 2.0
    for i, j in itertools.product(range(num_classes), range(num_classes)):
        val = cm_norm[i, j]
        ax.text(j, i, f"{val:.2f}",
                horizontalalignment="center",
                color="white" if val > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    return cm


def _plot_prf(all_labels, all_preds, class_names, out_path="prf_bar_chart.png"):
    """Save grouped per-class precision/recall/F1 bars; return the sklearn metric arrays."""
    num_classes = len(class_names)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_preds, labels=list(range(num_classes)), zero_division=0
    )

    x = np.arange(num_classes)
    width = 0.25
    fig, ax = plt.subplots(figsize=(12, 6))
    for offset, scores, name in ((-width, precision, 'Precision'),
                                 (0.0, recall, 'Recall'),
                                 (width, f1, 'F1-Score')):
        ax.bar(x + offset, scores, width, label=name)
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_ylabel('Score')
    ax.set_title('Precision, Recall, F1-Score per Class')
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    return precision, recall, f1, support


def _plot_roc(all_labels, all_probs, class_names, out_path="roc_curves.png"):
    """Save one-vs-rest ROC curves per class; return a ``{class_name: AUC}`` dict.

    Classes whose test labels are all positive or all negative are skipped, because
    their ROC curve is undefined (sklearn warns and returns NaN rates).
    """
    num_classes = len(class_names)
    # One-vs-rest binarization: row k is the one-hot vector of all_labels[k].
    labels_onehot = np.eye(num_classes, dtype=int)[all_labels]

    fig, ax = plt.subplots(figsize=(10, 8))
    aucs = {}
    for i, name in enumerate(class_names):
        y_true = labels_onehot[:, i]
        n_pos = int(y_true.sum())
        if n_pos == 0 or n_pos == y_true.size:
            print(f"[WARN] Skipping ROC for class '{name}': "
                  "needs both positive and negative test samples.")
            continue
        fpr, tpr, _ = roc_curve(y_true, all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        aucs[name] = roc_auc
        ax.plot(fpr, tpr, label=f"{name} (AUC = {roc_auc:.2f})")

    ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance-level diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves for Each Class')
    ax.legend(loc='lower right', fontsize='small')
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    return aucs


def main(img_size: int = 224):
    """Evaluate the saved EfficientNet-B0 checkpoint on TEST_DIR and save diagnostic plots.

    Reads the configuration-cell constants MODEL_PATH, TEST_DIR, BATCH_SIZE and
    DEVICE. Writes confusion_matrix.png, prf_bar_chart.png and roc_curves.png to
    the current working directory, then prints the model architecture and a
    layer-by-layer summary.

    Args:
        img_size: Square input resolution used both for resizing the test images
            and for the torchsummary input shape. Defaults to 224.
    """
    # (A) Load the test dataset with ImageFolder
    test_dataset, test_loader = _build_test_loader(TEST_DIR, BATCH_SIZE, img_size)
    idx_to_class = test_dataset.classes
    num_classes = len(idx_to_class)
    print(f"[INFO] Found {num_classes} classes: {idx_to_class}")

    # (B) Rebuild EfficientNet-B0 and load the weights
    model = _load_model(MODEL_PATH, num_classes, DEVICE)

    # (C) Inference: collect true labels, predictions and probabilities
    print("[INFO] Running inference on test set...")
    all_labels, all_preds, all_probs = _run_inference(model, test_loader, DEVICE)
    print("[INFO] Inference complete.")

    # (D) Normalized confusion-matrix heatmap
    print("[INFO] Generating confusion matrix...")
    _plot_confusion_matrix(all_labels, all_preds, idx_to_class, "confusion_matrix.png")
    print("[INFO] Saved: confusion_matrix.png")

    # (E) Precision / recall / F1 bar chart per class
    print("[INFO] Computing precision, recall, F1-score per class...")
    _plot_prf(all_labels, all_preds, idx_to_class, "prf_bar_chart.png")
    print("[INFO] Saved: prf_bar_chart.png")

    # (F) One-vs-rest ROC curves per class (from softmax probabilities)
    print("[INFO] Generating ROC curves per class...")
    _plot_roc(all_labels, all_probs, idx_to_class, "roc_curves.png")
    print("[INFO] Saved: roc_curves.png")

    # (G) Model architecture summary (text)
    print("\n=== Model Architecture ===")
    print(model)
    print("\n=== Detailed Layer-by-Layer Summary ===")
    summary(model, (3, img_size, img_size))

if __name__ == "__main__":
    main()