In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Run on the MPS backend when PyTorch reports it as available; otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Deterministic test-time preprocessing (no random augmentation is applied here):
# resize to 96x96, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std look like the standard ImageNet statistics - confirm they
# match what was used when the model was trained.
transform = transforms.Compose(
    [
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Test images are read from ./database/test; shuffle=False keeps the batch order fixed.
test_dataset = ImageFolder(root='./database/test', transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
# Define the evaluation function
def _plot_confusion_matrix(cm, class_names):
    """Draw an annotated confusion-matrix heatmap using matplotlib only.

    Args:
        cm: Square 2-D integer array; rows are true classes, columns are predictions.
        class_names: Tick labels, one per row/column of ``cm``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(cm, interpolation="nearest", cmap="Blues")
    fig.colorbar(image, ax=ax)

    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)

    # Use a light font on dark cells so every count stays readable.
    threshold = cm.max() / 2.0 if cm.size else 0
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            ax.text(
                col,
                row,
                format(cm[row, col], "d"),
                ha="center",
                va="center",
                color="white" if cm[row, col] > threshold else "black",
            )

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()
    plt.show()


def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader``, print metrics and plot a confusion matrix.

    Prints accuracy plus micro- and macro-averaged precision, recall and F1, then
    shows the confusion matrix as a heatmap.

    Args:
        model: Module to evaluate. It is switched to eval mode, and batches are
            moved to the device of its first parameter (CPU if it has none).
        test_loader: Iterable yielding ``(inputs, labels)`` batches. If
            ``test_loader.dataset.classes`` exists it supplies the axis labels;
            otherwise numeric class indices are shown.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    # Take the device from the model itself instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    target_batches = []
    pred_batches = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(model_device))
            preds = outputs.argmax(dim=1)
            # Labels are only needed on the host, so they are never sent to the device.
            target_batches.append(labels.cpu().numpy())
            pred_batches.append(preds.cpu().numpy())

    if not target_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    test_targets = np.concatenate(target_batches)
    test_outputs = np.concatenate(pred_batches)

    # Pin the label set so the matrix always has one row/column per known class,
    # even when a class never occurs in the targets or the predictions. Indices
    # beyond the known class names (if any) are shown by their numeric id.
    dataset = getattr(test_loader, "dataset", None)
    class_names = list(getattr(dataset, "classes", None) or [])
    highest_seen = int(max(test_targets.max(), test_outputs.max()))
    n_labels = max(len(class_names), highest_seen + 1)
    display_names = class_names + [str(i) for i in range(len(class_names), n_labels)]

    cm = confusion_matrix(test_targets, test_outputs, labels=np.arange(n_labels))

    # Calculate and report metrics
    print(f'Accuracy: {accuracy_score(test_targets, test_outputs):.4f}')
    for average in ('micro', 'macro'):
        tag = average.capitalize()
        precision = precision_score(test_targets, test_outputs, average=average)
        recall = recall_score(test_targets, test_outputs, average=average)
        f1 = f1_score(test_targets, test_outputs, average=average)
        print(f'Precision ({tag}): {precision:.4f}')
        print(f'Recall ({tag}): {recall:.4f}')
        print(f'F1 Score ({tag}): {f1:.4f}')

    # Plot confusion matrix
    _plot_confusion_matrix(cm, display_names)

In [None]:
# Load the trained model
# NOTE(review): CustomCNN is not defined or imported in any cell visible here; it must be
# defined or imported earlier for this cell to run on a fresh kernel.
model = CustomCNN().to(device)
# map_location="cpu" lets a checkpoint saved on another device (e.g. CUDA) load on this
# machine; load_state_dict then copies the weights into the model's own parameters.
# NOTE(security): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('models/best_model.pth', map_location="cpu"))

# Evaluate the model
evaluate_model(model, test_loader)