05 — Model Evaluation & Inference

This notebook evaluates the ResNet18 model trained in Notebook 04. The goal is to measure how well the model performs on unseen images and understand where it succeeds or struggles.

What’s included:
1. Load the trained model: 
    Rebuild the ResNet18 architecture and load the saved .pth weights.
2. Prepare the test dataset: 
    Apply the same transforms used during validation and load images for evaluation.
3. Compute test accuracy: 
    Measure the model’s performance on unseen data.
4. Confusion matrix: 
    Visualize which classes the model predicts correctly and which ones it mixes up.
5. Sample predictions: 
    Display test images with true and predicted labels.
6. Misclassified samples: 
    Inspect the images the model got wrong to understand common mistakes.
7. Custom image inference: 
    Test the model on your own image to see how it performs outside the dataset.

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import models, transforms

# make the project root importable so the local dataset module resolves
sys.path.append("..")
from dataset import PlantVillageDataset

# run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
device

In [None]:
# evaluation preprocessing: fixed-size resize -> tensor -> per-channel normalization
normalize_step = transforms.Normalize(
  mean=[0.485, 0.456, 0.406],
  std=[0.229, 0.224, 0.225],
)
test_transforms = transforms.Compose([
  transforms.Resize((160, 160)),
  transforms.ToTensor(),
  normalize_step,
])

# NOTE(review): the "test" data is read from the val split directory -- confirm
# this split was not also used for model selection during training
eval_root = "../data/PlantVillage/val"
test_dataset = PlantVillageDataset(eval_root, transform=test_transforms)

# fixed iteration order (no shuffling), 16 images per batch
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {test_dataset.classes}")


In [None]:
# rebuild the ResNet18 architecture; weights=None means no pretrained weights are loaded
model = models.resnet18(weights=None)
num_classes = len(test_dataset.classes)

# replace the classifier head with one sized to our classes; read the input
# width from the existing head instead of hard-coding 512
model.fc = nn.Linear(model.fc.in_features, num_classes)

# load trained weights
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (consider weights_only=True on torch versions that support it)
model.load_state_dict(torch.load("../models/resnet18_best.pth", map_location=device))
model = model.to(device)
model.eval() # evaluation mode

In [None]:
# accuracy over the whole evaluation loader: correct predictions / samples seen
correct, total = 0, 0

# inference only, so gradient tracking is switched off
with torch.no_grad():
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images)
    # predicted class = index of the largest score along the class dimension
    preds = outputs.argmax(dim=1)

    total += labels.size(0)
    correct += preds.eq(labels).sum().item()

test_acc = correct / total
print(f"Final Test Accuracy: {test_acc:.4f}")

In [None]:
#confusion matrix -> table that shows how the model actually performed for each class
#each row represents what the leaf actually is
#each col represents what the model guessed

all_preds = []
all_targets = []

#not training, so no gradients are computed
with torch.no_grad():
  for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)

    outputs = model(images) #get model scores, one row per image

    #index of the highest score along the class dimension -> predicted class
    #(the original called the non-existent torch.ax here)
    preds = outputs.argmax(dim=1)

    #save preds and true labels into python lists
    #tensors may live on the gpu; numpy needs them on the cpu
    all_preds.extend(preds.cpu().numpy())
    all_targets.extend(labels.cpu().numpy())

#build confusion matrix (rows = true class, cols = predicted class)
cm = confusion_matrix(all_targets, all_preds)

#normalize so each row's values are in range 0-1 (fraction of predictions per true class)
#a class that only ever shows up as a prediction has an all-zero row; leave it
#at 0 instead of dividing by zero and filling the row with NaN
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
  cm,
  row_totals,
  out=np.zeros(cm.shape, dtype=float),
  where=row_totals != 0,
)

#plot
fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(cm_norm, annot=False, cmap="Blues", ax=ax)
ax.set_title("Confusion Matrix (Normalized)")
ax.set_ylabel("True Label")
ax.set_xlabel("Predicted Label")
plt.show()


In [None]:
def show_preds(model, dataset, num_imgs=6,
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), cols=3):
  """Plot the first samples of ``dataset`` with their true and predicted class names.

  Args:
    model: classifier called on a single-image batch; the index of the largest
      value along dim 1 of its output is taken as the predicted class.
    dataset: indexable object whose ``dataset[i]`` yields ``(image, label)`` with
      the image as a channels-first tensor, and which exposes ``classes`` (a
      sequence of class names indexed by label) and ``len()``.
    num_imgs: how many samples to show; clamped to ``len(dataset)``.
    mean, std: per-channel statistics used to undo normalization before display.
      The defaults match the ``Normalize`` step in this notebook's transforms.
      Pass ``None`` for either to display the tensor as-is.
    cols: number of columns in the subplot grid; rows are added as needed.

  Returns:
    None. The figure is rendered with ``plt.show()``.
  """
  model.eval()

  # feed inputs on whatever device the model's weights live on, so model and
  # input can never end up on different devices
  first_param = next(model.parameters(), None)
  run_device = first_param.device if first_param is not None else torch.device("cpu")

  # never index past the end of the dataset
  count = min(num_imgs, len(dataset))
  if count <= 0:
    return

  # grow the grid with the number of images instead of a fixed 2x3 layout
  rows = (count + cols - 1) // cols
  plt.figure(figsize=(14, 4 * rows))

  for i in range(count):
    img, label = dataset[i]

    with torch.no_grad():
      output = model(img.unsqueeze(0).to(run_device))
    pred_idx = output.argmax(dim=1).item()

    # undo normalization so imshow receives values in [0, 1]; plotting the
    # normalized tensor directly gets clipped and shows distorted colors
    img_display = img.detach().cpu()
    if mean is not None and std is not None:
      mean_t = torch.tensor(mean, dtype=img_display.dtype).view(-1, 1, 1)
      std_t = torch.tensor(std, dtype=img_display.dtype).view(-1, 1, 1)
      img_display = (img_display * std_t + mean_t).clamp(0, 1)
    img_display = img_display.permute(1, 2, 0).numpy() # CHW -> HWC for matplotlib

    true_class = dataset.classes[int(label)]
    pred_class = dataset.classes[pred_idx]

    plt.subplot(rows, cols, i + 1)
    plt.imshow(img_display)
    plt.title(f"True: {true_class}\nPred: {pred_class}")
    plt.axis("off")

  plt.show()

# preview the first six evaluation samples with their true and predicted class names
show_preds(model, test_dataset, num_imgs=6)