In [None]:
# If you need to install packages, uncomment these lines:
# %pip install torch torchvision pillow matplotlib

import os
from pathlib import Path
import math

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# ---- User-configurable paths ----
CKPT_PATH = Path("./src/outputs/models/best_resnet50.pth")
IMAGES_DIR = Path("./data/inference")
OUTPUT_DIR = Path("./data/inference_annotated")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Inference image size
TARGET_SIZE = (64, 64)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [3]:
def letterbox_pad_to_size(img: Image.Image, target_size=(64, 64), fill=(0, 0, 0)):
    """
    Resize keeping aspect ratio and pad to target_size (w, h).
    Returns a PIL.Image of exactly target_size.
    """
    tw, th = target_size
    w, h = img.size

    # Compute scale while preserving aspect ratio
    scale = min(tw / w, th / h)
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))

    # Resize
    img_resized = img.resize((new_w, new_h), Image.BICUBIC)

    # Create background and paste centered
    background = Image.new("RGB", (tw, th), fill)
    paste_x = (tw - new_w) // 2
    paste_y = (th - new_h) // 2
    background.paste(img_resized, (paste_x, paste_y))
    return background

# Normalization (adjust if your training used different stats)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),  # ensure 3 channels
    transforms.Lambda(lambda im: letterbox_pad_to_size(im, TARGET_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [4]:
def load_model_resnet50_from_ckpt(ckpt_path: Path, device: torch.device):
    """
    Load a ResNet-50 model from a checkpoint. Handles common patterns:
      - torch.save(model.state_dict())
      - torch.save({'state_dict': ..., 'class_to_idx': ...})
      - torch.save({'model_state_dict': ..., ...})
      - torch.save(model)  # full model (pickled)
    Returns (model, idx_to_class) where idx_to_class may be None if unknown.
    """
    ckpt = torch.load(ckpt_path, map_location=device)

    state_dict = None
    idx_to_class = None

    if isinstance(ckpt, dict):
        # Common keys for state dicts
        if 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict):
            state_dict = ckpt['state_dict']
        elif 'model_state_dict' in ckpt and isinstance(ckpt['model_state_dict'], dict):
            state_dict = ckpt['model_state_dict']

        # Class mapping if present
        if 'idx_to_class' in ckpt and isinstance(ckpt['idx_to_class'], (list, dict)):
            if isinstance(ckpt['idx_to_class'], list):
                idx_to_class = {i: name for i, name in enumerate(ckpt['idx_to_class'])}
            else:
                idx_to_class = {int(k): v for k, v in ckpt['idx_to_class'].items()}
        elif 'class_to_idx' in ckpt and isinstance(ckpt['class_to_idx'], dict):
            class_to_idx = ckpt['class_to_idx']
            idx_to_class = {int(v): k for k, v in class_to_idx.items()}

    # If checkpoint is a full model object (pickled)
    if state_dict is None and hasattr(ckpt, 'state_dict') and callable(ckpt.state_dict):
        try:
            model = ckpt.to(device)
            model.eval()
            # Optional mapping on the model
            if hasattr(model, 'idx_to_class'):
                ic = getattr(model, 'idx_to_class')
                if isinstance(ic, list):
                    idx_to_class = {i: name for i, name in enumerate(ic)}
                elif isinstance(ic, dict):
                    idx_to_class = {int(k): v for k, v in ic.items()}
            return model, idx_to_class
        except Exception as e:
            print(f"Failed to use checkpoint as a full model: {e}")

    # Build a base resnet50
    model = models.resnet50(weights=None)
    in_features = model.fc.in_features

    # Try to infer num_classes from state_dict
    num_classes = None
    if state_dict is not None:
        for k in ['fc.weight', 'module.fc.weight']:
            if k in state_dict:
                num_classes = state_dict[k].shape[0]
                break
    if num_classes is None:
        num_classes = 1000  # fallback (ImageNet)
        print("[Warning] Could not infer num_classes from checkpoint. Defaulting to 1000. Adjust if needed.")

    # Replace head
    model.fc = nn.Linear(in_features, num_classes, bias=True)

    # Remove 'module.' prefix if from DataParallel
    def strip_module(sd):
        if not any(k.startswith('module.') for k in sd.keys()):
            return sd
        return {k.replace('module.', '', 1): v for k, v in sd.items()}

    if state_dict is not None:
        state_dict = strip_module(state_dict)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print("Missing keys:", missing)
        if unexpected:
            print("Unexpected keys:", unexpected)

    model.to(device).eval()
    return model, idx_to_class


In [5]:
def load_image_paths(root: Path):
    exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
    return sorted([p for p in root.rglob('*') if p.suffix.lower() in exts])

@torch.no_grad()
def predict_logits(model, batch):
    return model(batch)

def softmax_probs(logits):
    return torch.softmax(logits, dim=1)

def get_label(idx, idx_to_class):
    if idx_to_class and idx in idx_to_class:
        return str(idx_to_class[idx])
    return str(idx)

def draw_label_on_image(img: Image.Image, text: str, margin=6):
    """
    Draw a solid rectangle and white text at top-left of the image.
    """
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    # Robust text size
    if hasattr(draw, 'textbbox'):
        left, top, right, bottom = draw.textbbox((0,0), text, font=font)
        text_w, text_h = right - left, bottom - top
    else:
        text_w, text_h = draw.textsize(text, font=font)

    pad = 4
    rect_w = text_w + 2 * pad
    rect_h = text_h + 2 * pad

    draw.rectangle([margin, margin, margin + rect_w, margin + rect_h], fill=(0, 0, 0))
    draw.text((margin + pad, margin + pad), text, fill=(255, 255, 255), font=font)
    return img


In [6]:
# Load model
assert CKPT_PATH.exists(), f"Checkpoint not found at {CKPT_PATH}"
model, idx_to_class = load_model_resnet50_from_ckpt(CKPT_PATH, device)

# Collect images
image_paths = load_image_paths(IMAGES_DIR)
print(f"Found {len(image_paths)} image(s) in {IMAGES_DIR}")

# Inference loop
for img_path in image_paths:
    orig = Image.open(img_path).convert('RGB')

    # Preprocess to 64x64 with aspect-ratio padding
    x = preprocess(orig)                # 3x64x64
    x = x.unsqueeze(0).to(device)       # 1x3x64x64

    # Forward pass
    logits = predict_logits(model, x)
    probs = softmax_probs(logits).squeeze(0).cpu()

    # Top-1 prediction
    top1_prob, top1_idx = torch.max(probs, dim=0)
    label = get_label(int(top1_idx.item()), idx_to_class)
    conf = float(top1_prob.item())

    # Annotate original image for display/saving
    annotated = orig.copy()
    draw_label_on_image(annotated, f"{label} ({conf:.2f})")

    # Save annotated copy
    out_path = OUTPUT_DIR / img_path.name
    annotated.save(out_path)

    # Display inline
    plt.figure()
    plt.imshow(annotated)
    plt.axis('off')
    plt.title(f"{img_path.name} → {label} ({conf:.2f})")
    plt.show()


AssertionError: Checkpoint not found at src\outputs\best_resnet50.pth

In [None]:

from math import ceil
saved = sorted([p for p in OUTPUT_DIR.glob('*') if p.suffix.lower() in {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}])
n = len(saved)
if n:
    cols = min(4, n)
    rows = math.ceil(n / cols)
    plt.figure(figsize=(4*cols, 4*rows))
    for i, p in enumerate(saved, 1):
        img = Image.open(p).convert('RGB')
        plt.subplot(rows, cols, i)
        plt.imshow(img)
        plt.axis('off')
        plt.title(p.name, fontsize=9)
    plt.tight_layout()
    plt.show()
