In [None]:
"""
© 2025, Savanna-AI and Nicholus Mboga, nicholus.mboga@savanna-ai.be
All rights reserved: Software for Teaching purposes only. For any commercial use contact
nicholus.mboga@savanna-ai.be
"""


In [None]:
import sys
import os
sys.path.append(os.path.abspath('/Users/nicholus/Documents/GitHub/CI_CD_Tutorial/src/')) 
import torch
import pandas as pd
import random
from main import CustomVGG, EuroSATDataset  # Adjust import if needed
from torchvision import transforms
from PIL import Image
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go
import numpy as np

In [None]:

def plot_images_with_predictions(sample_df, num_rows=2, num_cols=5):
    """
    Show a grid of images titled with their ground-truth and predicted class labels.

    Each row of ``sample_df`` becomes one Plotly subplot, filled left to right and
    top to bottom. If the DataFrame holds more rows than ``num_rows * num_cols``,
    extra grid rows are added so that every image gets a cell instead of the call
    failing on an out-of-range subplot index.

    Parameters
    ----------
    sample_df : pandas.DataFrame
        DataFrame containing columns 'filename', 'ground_truth_class', and 'predicted_class'.
        Each row corresponds to an image and its associated labels.
    num_rows : int, optional
        Minimum number of rows in the subplot grid. Default is 2.
    num_cols : int, optional
        Number of columns in the subplot grid. Default is 5.

    Returns
    -------
    None
        Displays an interactive Plotly figure directly in the notebook or script.

    Raises
    ------
    ValueError
        If ``num_rows`` or ``num_cols`` is smaller than 1.
    """
    if num_rows < 1 or num_cols < 1:
        raise ValueError("num_rows and num_cols must both be at least 1")

    # Walk the frame once and keep only the fields needed for titles and images.
    records = [
        (row['filename'], row['ground_truth_class'], row['predicted_class'])
        for _, row in sample_df.iterrows()
    ]

    # Ceiling division: how many rows are needed to place every image.
    rows_needed = (len(records) + num_cols - 1) // num_cols
    grid_rows = max(num_rows, rows_needed)

    titles = [f"GT: {truth}<br>Pred: {pred}" for _, truth, pred in records]
    fig = sp.make_subplots(rows=grid_rows, cols=num_cols, subplot_titles=titles)

    for position, (filename, _, _) in enumerate(records):
        # The context manager releases the file handle as soon as the pixels are read.
        with Image.open(filename) as source_img:
            pixels = np.array(source_img.convert('RGB').resize((64, 64)))
        fig.add_trace(
            go.Image(z=pixels),
            row=1 + position // num_cols, col=1 + position % num_cols
        )

    # 400 px is the original height for up to two rows; taller grids get 200 px per row.
    fig.update_layout(
        height=max(400, 200 * grid_rows),
        width=1000,
        title_text="EuroSAT: Actual vs Predicted Classes"
    )
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.show()

def run_inference_and_plot(
    data_dir, 
    checkpoint_path, 
    num_images_to_plot=10, 
    batch_size=8
):
    """
    Run a trained model over a folder of images, tabulate the predictions, and plot a sample.

    The function loads the images through ``EuroSATDataset``, restores ``CustomVGG``
    from a checkpoint, predicts the class of every image, and then displays a
    reproducible random sample (fixed seed) of images with their actual and
    predicted classes.

    Parameters
    ----------
    data_dir : str
        Path to the directory containing images for inference.
        Expected structure: data_dir/class_name/image.jpg.
    checkpoint_path : str
        Path to the saved model checkpoint file.
    num_images_to_plot : int, optional
        Number of images to sample and visualize. Default is 10. If fewer images
        are available, all of them are plotted.
    batch_size : int, optional
        Batch size for loading images during inference. Default is 8.

    Returns
    -------
    pandas.DataFrame
        DataFrame with columns 'filename', 'predicted_class', and 'ground_truth_class'
        for all images processed. The columns are present even when no image was found.
    """
    # 1. Preprocessing: tensor conversion, 64x64 resize, and channel normalisation.
    #    NOTE(review): this should match the transform used at training time — confirm
    #    against the dataset/module definition in main.py.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 2. Dataset and loader. shuffle=False keeps batches in dataset order, which the
    #    filename lookup below relies on.
    dataset = EuroSATDataset(data_dir, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # 3. Restore the model and move it to the best available device (MPS, then CUDA, then CPU).
    model = CustomVGG.load_from_checkpoint(checkpoint_path)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 4. Predict batch by batch and collect one record per image.
    class_map = CustomVGG.get_class_map()
    results = []
    with torch.no_grad():
        offset = 0  # position of the current batch's first image within the dataset
        for images, labels in loader:
            outputs = model(images.to(device))
            # Bring the whole batch of predictions to the host in one transfer
            # instead of reading the device tensor element by element.
            predicted_ids = outputs.argmax(dim=1).cpu().tolist()
            for i, predicted_id in enumerate(predicted_ids):
                # assumes dataset.images lists file paths in dataset index order — TODO confirm
                results.append({
                    "filename": dataset.images[offset + i],
                    "predicted_class": class_map[predicted_id],
                    "ground_truth_class": class_map[int(labels[i])]
                })
            offset += len(predicted_ids)

    # Fix the column set so an empty dataset still yields the documented schema.
    df = pd.DataFrame(results, columns=["filename", "predicted_class", "ground_truth_class"])

    # 5. Visualize a random sample of up to `num_images_to_plot` images with Plotly.
    sample_df = df.sample(n=min(num_images_to_plot, len(df)), random_state=42)

    # Size the grid from the sample so requests for more than 10 images still fit;
    # 10 or fewer images keep the 2x5 layout.
    grid_cols = 5
    grid_rows = max(2, (len(sample_df) + grid_cols - 1) // grid_cols)
    plot_images_with_predictions(sample_df, num_rows=grid_rows, num_cols=grid_cols)

    return df



In [4]:
# Run inference on the sample test data and visualise a handful of predictions.
# The project root defaults to the original local checkout; set the
# CI_CD_TUTORIAL_ROOT environment variable to run this cell from another location.
PROJECT_ROOT = os.environ.get(
    "CI_CD_TUTORIAL_ROOT", "/Users/nicholus/Documents/GitHub/CI_CD_Tutorial"
)

df = run_inference_and_plot(
    data_dir=os.path.join(PROJECT_ROOT, "data", "Sample_Test_Data"),
    checkpoint_path=os.path.join(
        PROJECT_ROOT,
        "lightning_logs", "version_8", "checkpoints",
        "best-epoch=04-val_acc=0.76.ckpt",
    ),
    num_images_to_plot=10
)