In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from models.Resnet_model import get_resnet18  # Adjust this import based on your structure

# --- Evaluation configuration ------------------------------------------------
# Root folder handed to ImageFolder below (ImageFolder expects one
# sub-directory per class).
test_data_path = r"C:\Users\Lenovo\Desktop\ai_project\datasets\currencyDataset"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing: resize to 224x224, convert to a tensor, then normalize each
# channel. The mean/std values match the commonly used ImageNet statistics.
# NOTE(review): intended to mirror the training/validation pipeline — confirm
# against the training script.
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ImageFolder derives class names from the sub-directory names; entry i of
# ``class_names`` is the name for label i produced by the dataset.
test_dataset = ImageFolder(root=test_data_path, transform=test_transforms)
class_names = test_dataset.classes

# Batches of 10, shuffled so each pass yields images in a random order.
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)


def load_model(model_path, num_classes, device):
    """
    Build the model with ``get_resnet18`` and load its weights from a checkpoint.

    Supported checkpoint layouts:
      * a plain ``state_dict`` (parameter name -> tensor mapping);
      * a dict that stores the weights under ``"state_dict"`` or
        ``"model_state_dict"`` (the key used by PyTorch's "saving a general
        checkpoint" recipe), possibly next to extra metadata such as the
        optimizer state or epoch;
      * a whole pickled ``torch.nn.Module`` — its ``state_dict()`` is used.

    Args:
        model_path (str): Path to the saved checkpoint file.
        num_classes (int): Number of classes for the model's output layer.
        device (torch.device): Device to load the model on (CPU/GPU).

    Returns:
        torch.nn.Module: The loaded model, moved to ``device`` and set to eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint's
            keys or tensor shapes do not match the freshly built model.
    """
    model = get_resnet18(num_classes=num_classes)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint))

    model.to(device)
    model.eval()
    return model


def _extract_state_dict(checkpoint):
    """Return the model weights held by a loaded checkpoint object."""
    # A fully pickled module is not a mapping, so ``key in checkpoint`` would fail.
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    # Wrapped checkpoints keep the weights under a well-known key.
    for key in ("state_dict", "model_state_dict"):
        if key in checkpoint:
            return checkpoint[key]
    # Otherwise assume the object itself is the state_dict.
    return checkpoint


def test_model(
    model,
    test_loader,
    class_names,
    device,
    num_images_to_display=20,
    *,
    images_per_row=5,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """
    Run the model on test batches and visualize its predictions.

    One figure is shown per batch until ``num_images_to_display`` images have
    been plotted; iteration then stops, so only as many batches as needed are
    passed through the model.

    Args:
        model (torch.nn.Module): Trained model; switched to eval mode here.
        test_loader (Iterable): Yields ``(images, labels)`` batches, e.g. a DataLoader.
        class_names (list): Class names indexed by label / predicted class index.
        device (torch.device): Device to run the model on.
        num_images_to_display (int): Total number of test images to display.
            Values <= 0 return immediately without running inference.
        images_per_row (int): Number of subplot columns in each figure.
        mean (sequence of float): Per-channel mean used by the input
            ``Normalize`` transform; used here to undo normalization for display.
        std (sequence of float): Per-channel std used by the input
            ``Normalize`` transform.
    """
    model.eval()
    if num_images_to_display <= 0:
        return  # nothing to show, so skip inference entirely

    # Loop-invariant: only used to reverse the input normalization for display.
    mean_np = np.asarray(mean, dtype=np.float64)
    std_np = np.asarray(std, dtype=np.float64)

    images_shown = 0
    with torch.no_grad():
        for images, labels in test_loader:
            num_images = min(len(images), num_images_to_display - images_shown)
            if num_images <= 0:
                continue  # empty batch: nothing to plot

            # Labels are only used for titles, so they stay on the CPU.
            predictions = model(images.to(device)).argmax(dim=1)

            _plot_prediction_grid(
                images.cpu().numpy()[:num_images],
                labels.cpu().numpy()[:num_images],
                predictions.cpu().numpy()[:num_images],
                class_names,
                images_per_row,
                mean_np,
                std_np,
            )

            images_shown += num_images
            if images_shown >= num_images_to_display:
                break


def _plot_prediction_grid(images_np, labels_np, predictions_np, class_names,
                          images_per_row, mean, std):
    """Show one figure with each image titled by its predicted and actual class."""
    num_images = len(images_np)
    rows = -(-num_images // images_per_row)  # ceiling division
    # squeeze=False keeps ``axes`` 2-D even when there is a single row, so the
    # row-major flat iteration below is valid for every grid size.
    _, axes = plt.subplots(rows, images_per_row, figsize=(15, 3 * rows), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')  # also blanks the unused cells of a partially filled last row
        if idx >= num_images:
            continue
        # (C, H, W) -> (H, W, C) for imshow, then undo Normalize and clamp to [0, 1].
        image = np.transpose(images_np[idx], (1, 2, 0))
        image = (image * std + mean).clip(0, 1)
        ax.imshow(image)
        ax.set_title(f"Pred: {class_names[predictions_np[idx]]}\nAct: {class_names[labels_np[idx]]}")

    plt.tight_layout()
    plt.show()


# Paths and Configurations
model_path = "./checkpoints/saved_model.pth"  # Path to your saved checkpoint
# ``device`` is reused from the dataset configuration at the top of this file
# rather than recomputed, so there is a single source of truth for it.

# Load model; the output layer size follows the number of dataset classes.
model = load_model(model_path, num_classes=len(class_names), device=device)

# Test and visualize predictions
test_model(model, test_loader, class_names, device, num_images_to_display=80)