ML Dog Identifier

Our research questions are:

Can our custom CNN match ResNet-18 in classification accuracy?


Which dog breeds are most often misclassified, and why?


How does image background—indoor versus outdoor—affect model performance?


Can we prune the model to reduce inference time without significant accuracy loss?


Finally, is real-time inference achievable on mobile hardware?


In [None]:
!pip install -q tqdm

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [None]:
!pip install -q kaggle

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so files stored there are readable from this runtime.
drive.mount('/content/drive')

In [None]:
# Root folder of the dog images on the mounted Drive; the cells below expect one sub-folder per breed.
data_dir = "/content/drive/MyDrive/stanford-dogs/images/Images"

In [None]:
import os

# Every sub-directory of the dataset root is treated as one breed (class).
classes = sorted(
    entry
    for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)
print(f"Found {len(classes)} breeds:\n", classes)

# Number of directory entries (any file type) inside each breed folder.
counts = {
    breed: len(os.listdir(os.path.join(data_dir, breed)))
    for breed in classes
}
print("\nSample counts:", dict(counts))


In [None]:
import shutil, random
from pathlib import Path

# Seed the global RNG so the per-breed shuffles below are repeatable.
random.seed(42)

# Split ratios (the test split receives whatever is left after train and val).
train_ratio, val_ratio = 0.8, 0.1
test_ratio = 1.0 - train_ratio - val_ratio

# Where to write the split folders
base_split = Path("/content/stanford-dogs-splits")
for split in ["train", "val", "test"]:
    (base_split / split).mkdir(parents=True, exist_ok=True)

# Perform split, one breed at a time
for cls in classes:
    # glob() yields files in filesystem-dependent order. Sort first so that the
    # seeded shuffle assigns each image to the same split on every run; otherwise
    # a re-run could copy an image into a second split and leak data between them.
    imgs = sorted((Path(data_dir) / cls).glob("*.jpg"))
    random.shuffle(imgs)
    n = len(imgs)
    n_train = int(train_ratio * n)
    n_val = int(val_ratio * n)

    # Contiguous slices of the shuffled list; rounding remainders go to "test".
    splits = {
        "train": imgs[:n_train],
        "val": imgs[n_train:n_train + n_val],
        "test": imgs[n_train + n_val:]
    }

    for split_name, files in splits.items():
        target_dir = base_split / split_name / cls
        target_dir.mkdir(parents=True, exist_ok=True)
        for f in files:
            shutil.copy(f, target_dir / f.name)

print("Done! Check `/content/stanford-dogs-splits/{train,val,test}`")

In [None]:
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Deterministic preprocessing: squash to 224x224 (standard for ResNet), convert to a
# tensor, and normalise each channel with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Locations of the three split folders written by the split cell.
base_split = "/content/stanford-dogs-splits"
train_dir = base_split + "/train"
val_dir = base_split + "/val"
test_dir = base_split + "/test"

# One ImageFolder per split, all sharing the same preprocessing.
train_dataset, val_dataset, test_dataset = (
    ImageFolder(root=split_root, transform=transform)
    for split_root in (train_dir, val_dir, test_dir)
)

In [None]:
from torch.utils.data import DataLoader

batch_size = 32

# Only the training loader shuffles; val/test keep the dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
val_loader, test_loader = (
    DataLoader(split_ds, shuffle=False, batch_size=batch_size, num_workers=2)
    for split_ds in (val_dataset, test_dataset)
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

# Pull a single mini-batch from the training loader.
images, labels = next(iter(train_loader))
class_names = train_dataset.classes

# Tile the first 16 images 4-per-row, move channels last for matplotlib, then
# undo the normalisation (x * std + mean) and clamp into [0, 1].
img_grid = make_grid(images[:16], nrow=4, padding=2)
np_img = img_grid.permute(1, 2, 0).numpy()
np_img = np_img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
np_img = np.clip(np_img, 0, 1)

plt.figure(figsize=(8, 8))
plt.imshow(np_img)
# The title is the list of class names for the 16 images shown.
plt.title([class_names[i] for i in labels[:16]])
plt.axis('off')
plt.show()

In [None]:
!wget -q http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar

In [None]:
import torch
from torchvision import models as tv_models

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 with a 365-way head, to receive the downloaded Places365 weights.
places_model = tv_models.resnet18(num_classes=365)

checkpoint = torch.load('resnet18_places365.pth.tar', map_location=device)

# Strip every 'module.' occurrence from the checkpoint keys so they match the
# parameter names of the plain torchvision model.
state_dict = {
    key.replace('module.', ''): tensor
    for key, tensor in checkpoint['state_dict'].items()
}
places_model.load_state_dict(state_dict)

# Inference only: move to the chosen device and switch to eval mode.
places_model.to(device)
places_model.eval()

print("Places365 model loaded successfully.")

In [None]:
!wget -q https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt

!wget -q https://raw.githubusercontent.com/csailvision/places365/master/IO_places365.txt

In [None]:
# Category names: the first space-separated token of every line.
with open('categories_places365.txt') as f:
    scene_classes = [row.strip().split(' ')[0] for row in f]

# Indoor/outdoor flags: the last whitespace-separated token of every line, kept
# as the raw integer from the file (no offset applied).
# NOTE(review): upstream Places365 documents the raw flag as 1 = indoor, 2 = outdoor — confirm.
with open('IO_places365.txt') as f:
    io_labels = [int(row.strip().split()[-1]) for row in f]

In [None]:
from torchvision import transforms

# Preprocessing for the Places365 scene classifier: squash to 256x256 (aspect ratio
# is not preserved), take the central 224x224 crop, convert to a tensor, and
# normalise each channel with the ImageNet mean/std.
places_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225])
])

In [None]:
import os
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
from pathlib import Path

test_dir = "/content/stanford-dogs-splits/test"

# IO_places365.txt marks indoor scenes with 1 and outdoor scenes with 2, and
# `io_labels` holds those raw values (they are read without a "- 1" offset).
# Comparing against 1 would therefore swap the two background labels.
OUTDOOR_FLAG = 2

# Tag every test image with the indoor/outdoor flag of its top-1 Places365 scene.
records = []
for breed in classes:
    breed_dir = Path(test_dir) / breed
    for img_path in breed_dir.glob("*.jpg"):
        # Close the file handle right away instead of leaving it to the garbage collector.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        inp = places_transform(img).unsqueeze(0).to(device)  # add a batch dimension
        with torch.no_grad():
            logits = places_model(inp)
        idx = torch.argmax(logits, dim=1).item()
        bg = "outdoor" if io_labels[idx] == OUTDOOR_FLAG else "indoor"
        records.append({
            "path": str(img_path),
            "breed": breed,
            "background": bg
        })

# Fix the column set so the frame keeps its "path" column even when no image was found.
df_bg = pd.DataFrame(records, columns=["path", "breed", "background"])
df_bg.head()

In [None]:
from torchvision import transforms

# Training pipeline: random crop/flip/colour augmentation, then tensor + normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/test pipeline: deterministic resize and centre crop, no augmentation.
val_test_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Augmentation applies to the training split only; val/test share the deterministic pipeline.
train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset, test_dataset = (
    ImageFolder(root=split_root, transform=val_test_transform)
    for split_root in (val_dir, test_dir)
)

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
val_loader, test_loader = (
    DataLoader(split_ds, shuffle=False, batch_size=batch_size, num_workers=2)
    for split_ds in (val_dataset, test_dataset)
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

# Show the first 16 images of one training batch in a 4-per-row grid.
images, labels = next(iter(train_loader))
img_grid = make_grid(images[:16], nrow=4)

# Channels last for matplotlib, then undo the normalisation and clamp into [0, 1].
np_img = img_grid.permute(1, 2, 0).numpy()
np_img = np_img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
np_img = np.clip(np_img, 0, 1)

plt.figure(figsize=(8, 8))
plt.imshow(np_img)
plt.axis('off')
plt.show()

In [None]:
from torchvision import transforms

# Training transforms with augmentation (random crop + resize, horizontal flip,
# colour jitter), followed by tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/Test transforms (no augmentation): resize, centre crop, normalise.
val_test_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Re-create datasets with new transforms (augmentation on the training split only)
train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset, test_dataset = (
    ImageFolder(root=split_root, transform=val_test_transform)
    for split_root in (val_dir, test_dir)
)

# DataLoaders: shuffle for training, fixed dataset order for val/test
batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
val_loader, test_loader = (
    DataLoader(split_ds, shuffle=False, batch_size=batch_size, num_workers=2)
    for split_ds in (val_dataset, test_dataset)
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

# Grab a batch and tile 16 augmented images, 4 per row
images, labels = next(iter(train_loader))
img_grid = make_grid(images[:16], nrow=4)

# Channels last for matplotlib, then undo the normalisation and clamp into [0, 1]
np_img = img_grid.permute(1, 2, 0).numpy()
np_img = np_img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
np_img = np.clip(np_img, 0, 1)

plt.figure(figsize=(8, 8))
plt.imshow(np_img)
plt.axis('off')
plt.show()

Custom CNN

In [None]:
import torch.nn as nn
import torchvision.models as models

# Number of classes, taken from the class list that ImageFolder built for the training split
num_classes = len(train_dataset.classes)

class SmallCNN(nn.Module):
    """Compact CNN classifier for 3-channel images.

    Four Conv(3x3) -> BatchNorm -> ReLU -> MaxPool(2) stages widen the channels
    3 -> 16 -> 32 -> 64 -> 128 while halving the spatial size each time. Global
    average pooling reduces the result to a 128-vector, and a two-layer MLP with
    dropout maps it to `num_classes` logits.

    Args:
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Build the stages in order so layer indices inside `features`
        # (and therefore the state_dict keys) stay 0..15.
        widths = (3, 16, 32, 64, 128)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stages)

        # Collapse whatever spatial size is left to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images `x`."""
        return self.classifier(self.pool(self.features(x)))

# Instantiate the custom network and move it to the selected device.
small_cnn = SmallCNN(num_classes).to(device)

In [None]:
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn

# Start from torchvision's default pretrained ResNet-18 weights.
# Note: the name `resnet18` is rebound from the imported constructor to the model instance.
weights = ResNet18_Weights.DEFAULT
resnet18 = resnet18(weights=weights)

# Swap the final fully connected layer for one sized to our classes.
in_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(in_features, num_classes)

# Freeze every parameter whose name does not contain "fc", so only the new head trains.
for name, param in resnet18.named_parameters():
    if "fc" in name:
        continue
    param.requires_grad_(False)

resnet18 = resnet18.to(device)

Begin Training

In [None]:
import time
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over `loader`.

    Puts `model` in train mode, and for every batch computes the cross-entropy
    loss, back-propagates, and steps `optimizer`.

    Returns:
        (mean loss per sample, fraction of correctly classified samples).
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, labels in tqdm(loader, desc="Train", leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        batch_n = inputs.size(0)
        preds = logits.argmax(dim=1)
        # cross_entropy returns the batch mean; weight by batch size to accumulate a sum.
        loss_sum += loss.item() * batch_n
        n_correct += (preds == labels).sum().item()
        n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen

def validate(model, loader, device):
    """Evaluate `model` on `loader` without updating any weights.

    Puts `model` in eval mode and disables gradient tracking.

    Returns:
        (mean cross-entropy loss per sample, fraction of correctly classified samples).
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validate", leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)

            batch_n = inputs.size(0)
            preds = logits.argmax(dim=1)
            # cross_entropy returns the batch mean; weight by batch size to accumulate a sum.
            loss_sum += loss.item() * batch_n
            n_correct += (preds == labels).sum().item()
            n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# SmallCNN: AdamW over all parameters, with a cosine learning-rate schedule whose
# half-period (T_max) is 10 scheduler steps.
optimizer_small = optim.AdamW(small_cnn.parameters(),
                              lr=1e-3, weight_decay=1e-4)
scheduler_small = CosineAnnealingLR(optimizer_small, T_max=10)

# ResNet-18: same optimizer settings, but only parameters that still require
# gradients are handed to the optimizer.
optimizer_resnet = optim.AdamW(
    filter(lambda p: p.requires_grad, resnet18.parameters()),
    lr=1e-3, weight_decay=1e-4
)
scheduler_resnet = CosineAnnealingLR(optimizer_resnet, T_max=10)

def fit_model(model, train_loader, val_loader, optimizer, scheduler, device, epochs=10):
    """Train `model` for `epochs` epochs, validating after each one.

    Every epoch runs one training pass, one validation pass, then one
    `scheduler.step()`, and prints the four metrics.

    Returns:
        dict with keys "train_loss", "train_acc", "val_loss", "val_acc", each a
        list holding one value per epoch.
    """
    metric_names = ("train_loss", "train_acc", "val_loss", "val_acc")
    history = {metric: [] for metric in metric_names}

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, device)
        scheduler.step()

        epoch_values = (train_loss, train_acc, val_loss, val_acc)
        for metric, value in zip(metric_names, epoch_values):
            history[metric].append(value)

        print(f"Epoch {epoch}/{epochs} | "
              f"Train: {train_loss:.4f}, {train_acc:.4f} | "
              f"Val:   {val_loss:.4f}, {val_acc:.4f}")

    return history

# Train the custom CNN for 10 epochs and keep its per-epoch metrics.
hist_small = fit_model(small_cnn, train_loader, val_loader,
                       optimizer_small, scheduler_small,
                       device, epochs=10)

# Train ResNet-18 with the same loaders and epoch count for comparison.
hist_resnet = fit_model(resnet18, train_loader, val_loader,
                        optimizer_resnet, scheduler_resnet,
                        device, epochs=10)


In [None]:
import matplotlib.pyplot as plt

def plot_history(hist, title):
    """Plot train/val loss and accuracy curves side by side.

    Args:
        hist: dict with "train_loss", "val_loss", "train_acc" and "val_acc"
            lists, one entry per epoch.
        title: prefix used in both subplot titles.
    """
    epochs = range(1, len(hist["train_loss"]) + 1)

    plt.figure(figsize=(12, 5))

    # (train key, val key, legend suffix, axis/title label) for each panel, left to right.
    panels = (
        ("train_loss", "val_loss", "Loss", "Loss"),
        ("train_acc", "val_acc", "Acc", "Accuracy"),
    )
    for position, (train_key, val_key, short, full) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, hist[train_key], label=f"Train {short}")
        plt.plot(epochs, hist[val_key], label=f"Val {short}")
        plt.title(f"{title} — {full}")
        plt.xlabel("Epoch")
        plt.ylabel(full)
        plt.legend()

    plt.show()

# Learning curves for both models, one figure each.
plot_history(hist_small, "SmallCNN")

plot_history(hist_resnet, "ResNet-18")

Evaluate

In [None]:
# Evaluate SmallCNN on the held-out test loader (reuses validate(), so no weights change)
test_loss_small, test_acc_small = validate(small_cnn, test_loader, device)
print(f"SmallCNN Test — Loss: {test_loss_small:.4f}, Accuracy: {test_acc_small:.4f}")

# Evaluate ResNet-18 on the same test loader
test_loss_resnet, test_acc_resnet = validate(resnet18, test_loader, device)
print(f"ResNet-18 Test — Loss: {test_loss_resnet:.4f}, Accuracy: {test_acc_resnet:.4f}")

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# Collect ResNet-18 predictions and ground-truth labels over the whole test set.
all_preds, all_labels = [], []

resnet18.eval()
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing", leave=False):
        inputs = inputs.to(device)
        outputs = resnet18(inputs)
        preds = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(all_labels, all_preds)
class_names = test_dataset.classes

# Label both axes with the class names for readability.
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
df_cm.head()

In [None]:
# Work on a copy whose diagonal (the correct predictions) is zeroed out.
cm_no_diag = cm.copy()
np.fill_diagonal(cm_no_diag, 0)

# Flat positions of the N largest remaining cells, largest first.
N = 10
flat = cm_no_diag.flatten()
top_idxs = np.argsort(flat)[::-1][:N]

# Convert flat positions back to (row, column) = (actual, predicted) indices.
rows, cols = np.unravel_index(top_idxs, cm_no_diag.shape)
counts = flat[top_idxs]

misclassifications = pd.DataFrame({
    "Predicted": [class_names[col] for col in cols],
    "Actual": [class_names[row] for row in rows],
    "Count": counts,
})

misclassifications

In [None]:
import pandas as pd

# File paths in dataset order. The prediction lists line up with them only because
# they were collected from a non-shuffled loader over the same dataset.
paths, true_idxs = zip(*test_dataset.samples)

df_results = pd.DataFrame({
    "path": paths,
    "actual": [test_dataset.classes[label] for label in all_labels],
    "predicted": [test_dataset.classes[pred] for pred in all_preds],
})

df_results.head()

In [None]:
# Attach the indoor/outdoor tag to each prediction by joining on the image path.
# merge() defaults to an inner join, so rows without a matching path in df_bg are dropped.
df_merged = df_results.merge(df_bg[["path", "background"]], on="path")

# Quick peek
df_merged.head()

In [None]:
# 1 where the predicted breed equals the actual breed, else 0.
df_merged["correct"] = df_merged["predicted"].eq(df_merged["actual"]).astype(int)

# Mean of the 0/1 flag per background type is the accuracy for that group.
accuracy_by_bg = (
    df_merged.groupby("background")["correct"]
    .mean()
    .reset_index(name="accuracy")
)

print(accuracy_by_bg)

import matplotlib.pyplot as plt

plt.figure(figsize=(6, 4))
plt.bar(accuracy_by_bg["background"], accuracy_by_bg["accuracy"])
plt.ylim(0, 1)
plt.title("Test Accuracy by Background Type")
plt.xlabel("Background")
plt.ylabel("Accuracy")
plt.show()


In [None]:
import time

# Inference only: put the model in eval mode.
resnet18.eval()

# The timed window covers the whole loop below, i.e. batch loading and the
# host-to-device copy as well as the forward pass.
start_time = time.perf_counter()
running_corrects = 0
total_images = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Baseline Inference"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = resnet18(inputs).argmax(dim=1)
        running_corrects += (preds == labels).sum().item()
        total_images += inputs.size(0)

end_time = time.perf_counter()

elapsed = end_time - start_time
baseline_acc = running_corrects / total_images
baseline_time = elapsed / total_images  # seconds per image

print(f"Baseline ResNet-18 — Accuracy: {baseline_acc:.4f}, "
      f"Avg Inference Time: {baseline_time*1000:.2f} ms/image")

Pruning

In [None]:
import torch.nn.utils.prune as prune
import torch.nn as nn

# Zero the 30% smallest-magnitude (L1) weights of every Conv2d and Linear layer.
# NOTE(review): unstructured pruning only zeroes entries of dense tensors, so it
# presumably will not speed up dense kernels by itself — confirm against the timings.
prunable_types = (nn.Conv2d, nn.Linear)
for module in resnet18.modules():
    if not isinstance(module, prunable_types):
        continue
    prune.l1_unstructured(module, name="weight", amount=0.3)
    # Bake the mask into `weight` and drop the pruning re-parametrisation.
    prune.remove(module, "weight")

print("Applied 30% L1 unstructured pruning to all Conv2d and Linear layers.")

In [None]:
# Inference only: put the pruned model in eval mode.
resnet18.eval()

# Same measurement as the baseline run: the timed window covers the whole loop,
# including batch loading and the host-to-device copy.
start_time = time.perf_counter()
running_corrects_pruned = 0
total_images = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Pruned Inference"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = resnet18(inputs).argmax(dim=1)
        running_corrects_pruned += (preds == labels).sum().item()
        total_images += inputs.size(0)

end_time = time.perf_counter()

elapsed = end_time - start_time
pruned_acc = running_corrects_pruned / total_images
pruned_time = elapsed / total_images  # seconds per image

print(f"Pruned ResNet-18 — Accuracy: {pruned_acc:.4f}, "
      f"Avg Inference Time: {pruned_time*1000:.2f} ms/image")

In [None]:
import pandas as pd

# Side-by-side summary of accuracy and per-image time (converted from seconds to
# milliseconds) before and after pruning.
df_prune = pd.DataFrame({
    "Model": ["Baseline", "Pruned (30%)"],
    "Accuracy": [baseline_acc, pruned_acc],
    "Time_ms_per_image": [baseline_time*1000, pruned_time*1000]
})
df_prune

In [None]:
import torch

import copy

# nn.Module.to() moves a module in place and returns the same object, so calling it
# on `resnet18` directly would also move the model used by the earlier cells onto the
# CPU and break any re-run that feeds it inputs on `device`. Export from a deep copy
# instead and leave `resnet18` where it is.
resnet18_cpu = copy.deepcopy(resnet18).to('cpu')
resnet18_cpu.eval()

# Create example input matching our model's expected size: one 3x224x224 image
example_input = torch.randn(1, 3, 224, 224)

# Trace the model by running it once on the example input
traced_model = torch.jit.trace(resnet18_cpu, example_input)

# Save the TorchScript module
traced_model.save("resnet18_pruned_mobile.pt")

print("Model exported to TorchScript: resnet18_pruned_mobile.pt")

In [None]:
import torch
import time
from tqdm.notebook import tqdm

# Reload the exported TorchScript module and keep it on the CPU in eval mode.
mobile_model = torch.jit.load("resnet18_pruned_mobile.pt")
mobile_model.eval()
mobile_model.to('cpu')

latencies = []  # one entry per image, in milliseconds

# Feed the test images one at a time; only the forward call is timed.
with torch.no_grad():
    for inputs, _ in tqdm(test_loader, desc="Mobile CPU Inference"):
        for img in inputs:
            single = img.unsqueeze(0).to('cpu')  # add a batch dimension of 1
            t_start = time.perf_counter()
            mobile_model(single)
            t_end = time.perf_counter()
            latencies.append((t_end - t_start) * 1000)

# Compute average latency
avg_latency = sum(latencies) / len(latencies)
print(f"Average CPU inference latency: {avg_latency:.2f} ms/image")

In [None]:
import pandas as pd

# Test accuracy of the two models.
df_accuracy = pd.DataFrame({
    "Model": ["SmallCNN", "ResNet-18"],
    "Test Accuracy": [test_acc_small, test_acc_resnet],
})

# Single-row table holding the average per-image CPU latency.
df_latency = pd.DataFrame({
    "Model": ["CPU inference (ms/image)"],
    "Value": [avg_latency],
})

# Display all tables, each under its heading.
sections = [
    ("=== Test Accuracy Comparison ===", df_accuracy),
    ("\n=== Top Misclassified Breed Pairs ===", misclassifications),
    ("\n=== Accuracy by Background ===", accuracy_by_bg),
    ("\n=== Pruning Trade-Off ===", df_prune),
    ("\n=== Mobile CPU Latency ===", df_latency),
]
for heading, table in sections:
    print(heading)
    display(table)

In [None]:
import matplotlib.pyplot as plt

# Bar chart: test accuracy per model, with the y-axis fixed to [0, 1].
plt.figure()
plt.bar(df_accuracy["Model"], df_accuracy["Test Accuracy"])
plt.ylim(0,1)
plt.title("Test Accuracy: SmallCNN vs ResNet-18")
plt.xlabel("Model")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Bar chart: test accuracy per background type, with the y-axis fixed to [0, 1].
plt.figure()
plt.bar(accuracy_by_bg["background"], accuracy_by_bg["accuracy"])
plt.ylim(0,1)
plt.title("Test Accuracy by Image Background")
plt.xlabel("Background")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Bar chart: accuracy of the baseline vs the pruned model, with the y-axis fixed to [0, 1].
plt.figure()
plt.bar(df_prune["Model"], df_prune["Accuracy"])
plt.ylim(0,1)
plt.title("Accuracy Before vs After Pruning")
plt.xlabel("Model")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Bar chart: per-image time in milliseconds for the baseline vs the pruned model.
plt.figure()
plt.bar(df_prune["Model"], df_prune["Time_ms_per_image"])
plt.title("Inference Time Before vs After Pruning")
plt.xlabel("Model")
plt.ylabel("Time (ms per image)")
plt.show()

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Display name for each class: the part of the folder name after its first '-'.
# Raises IndexError for a class folder whose name contains no '-'.
short_names = [name.split('-', 1)[1] for name in test_dataset.classes]

# Per-class precision/recall/F1 computed from the labels and predictions gathered earlier.
report = classification_report(all_labels, all_preds, target_names=short_names)
print("Classification Report:\n", report)

In [None]:
import random
import matplotlib.pyplot as plt

# Prepare short names mapping: the part of each class folder name after its first '-'
short_names = [name.split('-', 1)[1] for name in test_dataset.classes]

# Build list of (path, actual_idx, pred_idx)
paths = [p for p, _ in test_dataset.samples]
pairs = list(zip(paths, all_labels, all_preds))

# Split correct and incorrect
correct_samples   = [(p, a, pr) for p, a, pr in pairs if a == pr]
incorrect_samples = [(p, a, pr) for p, a, pr in pairs if a != pr]

# Randomly pick up to 3 of each. random.sample() raises ValueError when asked for
# more items than the population holds, so clamp the size to what is available.
random.seed(42)
n_show = 3
sample_correct   = random.sample(correct_samples,   min(n_show, len(correct_samples)))
sample_incorrect = random.sample(incorrect_samples, min(n_show, len(incorrect_samples)))

def display_grid(samples, title):
    """Show the images in `samples` in one row, titled with actual vs predicted breed.

    Args:
        samples: list of (image path, actual class index, predicted class index).
        title: figure-level title.

    Prints a short message instead of drawing an empty figure when `samples` is empty.
    """
    if not samples:
        print(f"{title}: nothing to display")
        return

    plt.figure(figsize=(12, 4))
    for i, (path, actual, pred) in enumerate(samples):
        img = plt.imread(path)
        plt.subplot(1, len(samples), i + 1)
        plt.imshow(img)
        plt.title(f"Actual: {short_names[actual]}\nPred: {short_names[pred]}")
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Titles report the number of images actually shown (3 in the usual case).
display_grid(sample_correct,   f"{len(sample_correct)} Correct Predictions")
display_grid(sample_incorrect, f"{len(sample_incorrect)} Incorrect Predictions")