In [None]:
def evaluate_best_model(
    model,
    test_loader,
    load_path="best_model.pth",
):
    """
    Evaluate the best saved model on test data and display results.

    Loads the state dict stored at ``load_path`` into ``model``, moves the
    model to the notebook-level ``device``, and runs it over ``test_loader``
    in eval mode with gradients disabled. It then:

    - prints the test accuracy;
    - prints the confusion matrix;
    - prints a per-class classification report (when class names exist);
    - plots the confusion matrix.

    Relies on the notebook globals ``device``, ``num_classes`` and
    ``classes`` (a sequence of class names; may be None/empty, in which case
    the report is skipped and integer labels are shown on the plot).

    Args:
        model: Module whose architecture matches the saved state dict.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        load_path: Path of the saved ``state_dict`` file.

    Returns:
        tuple: ``(accuracy, cm)``.
            ``accuracy`` is the percentage of correctly classified test
            samples. ``cm`` is the confusion matrix over
            ``np.arange(num_classes)``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa) instead of failing inside torch.load.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    correct = 0
    total = 0
    label_batches = []
    prediction_batches = []

    print("\nEvaluating the best model on test data...")

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)

            # Forward pass; predicted class is the index of the max score.
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            # Running accuracy counters.
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Keep host-side copies for the confusion matrix / report.
            label_batches.append(labels.cpu().numpy())
            prediction_batches.append(predicted.cpu().numpy())

            # Drop references to device tensors before the next batch.
            del images, labels, outputs, predicted

    # Fail with a clear message instead of a bare ZeroDivisionError.
    if total == 0:
        raise ValueError(
            "test_loader yielded no samples; cannot compute test accuracy."
        )

    all_labels = np.concatenate(label_batches)
    all_predictions = np.concatenate(prediction_batches)

    # Calculate accuracy
    accuracy = 100 * correct / total
    print(f"\nTest Accuracy: {accuracy:.2f}%")

    # Use the full label set so the matrix is always
    # num_classes x num_classes, even if a class is absent from this split.
    class_indices = np.arange(num_classes)

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=class_indices)
    print("\nConfusion Matrix:")
    print(cm)

    # Explicit check: `if classes:` raises for multi-element numpy arrays.
    has_class_names = classes is not None and len(classes) > 0

    # Classification Report
    if has_class_names:
        print("\nClassification Report:")
        # Passing labels= keeps target_names aligned with class indices; without
        # it, sklearn raises when a class never appears in labels or predictions.
        print(
            classification_report(
                all_labels,
                all_predictions,
                labels=class_indices,
                target_names=classes,
            )
        )

    # Display Confusion Matrix
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=classes if has_class_names else class_indices,
    )
    disp.plot(cmap=plt.cm.Blues)
    disp.ax_.set_title("Confusion Matrix: Test Data")
    plt.show()

    return accuracy, cm


In [None]:
# Reload the weights saved in best_model.pth and score them on the test loader,
# keeping the accuracy (%) and the confusion matrix for later cells.
test_accuracy, test_cm = evaluate_best_model(model, test_loader, load_path="best_model.pth")