Cell 1: Imports + constants

In [None]:
import os
import math
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms

# Tiling configuration
PATCH_SIZE = 64  # side length, in pixels, of each square tile
STRIDE = 64      # step between tiles; 64 == PATCH_SIZE means no overlap (try 32 later for overlap)

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Per-channel statistics used by the Normalize step below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference preprocessing: must match training.
preprocess = transforms.Compose(
    [
        # PIL image -> float tensor, scaling [0,255] -> [0,1]
        transforms.ToTensor(),
        # Channel-wise (x - mean) / std
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

Cell 2: Load model

In [None]:
# Cell 2: Load model (matches Notebook 03 setup)

import os
import sys
from pathlib import Path

import torch

# Resolve the project root and make `src` importable BEFORE importing from it.
# The kernel may be started either from the repository root or from notebooks/;
# in the latter case `src` is not on sys.path until we add the parent directory.
PROJECT_ROOT = Path(os.getcwd()).resolve()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent

if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print("Project root:", PROJECT_ROOT)

# Project-local imports: intentionally placed after the sys.path setup above,
# otherwise they fail with ModuleNotFoundError on a fresh kernel in notebooks/.
from src.dataset import load_eurosat_dataset
from src.models import build_model

# Use the same device variable everywhere
device = DEVICE  # uses DEVICE from Cell 1
print("Device:", device)

# Data dir (same as Notebook 03)
DATA_DIR = PROJECT_ROOT / "data" / "raw" / "EuroSAT_RGB"
print("Data dir:", DATA_DIR)

# Checkpoint
models_dir = PROJECT_ROOT / "models"
MODEL_PATH = models_dir / "resnet18_pre_ft_best.pth"   # fine-tuned best

# Fail early with a helpful listing if the expected checkpoint is absent.
if not MODEL_PATH.exists():
    print("Missing checkpoint:", MODEL_PATH)
    print("Available .pth files:")
    for p in sorted(models_dir.glob("*.pth")):
        print(" -", p.name)
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")

# Get class_names (only to set num_classes correctly); the other three
# return values of the loader are not needed for inference and are discarded.
_, _, _, class_names = load_eurosat_dataset(
    data_dir=DATA_DIR,
    img_size=64,
    batch_size=64,
    seed=42,
    aug_level="light",
)

print("Num classes:", len(class_names))
print("Classes:", class_names)

# Build model architecture and load weights.
# NOTE(review): pretrained=True presumably fetches ImageNet weights that are
# immediately overwritten by load_state_dict below; if build_model produces the
# same architecture with pretrained=False, that would avoid the download —
# confirm against src.models before changing.
model = build_model(
    num_classes=len(class_names),
    model_name="resnet18",
    pretrained=True,
    freeze_backbone=False,
).to(device)

# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
state = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state)
model.eval()  # inference mode for layers such as dropout / batch norm

print("Loaded checkpoint:", MODEL_PATH.name)

Cell 3: Sanity check

In [None]:
def predict_patch(pil_patch: Image.Image):
    """
    Classify a single image patch with the notebook-level ``model``.

    The patch is run through the module-level ``preprocess`` pipeline, given a
    leading batch dimension, moved to ``DEVICE`` and passed through ``model``
    with gradient tracking disabled.

    pil_patch: PIL Image (expected to be 64x64, RGB)
    returns: (pred_class:int, probs:np.ndarray) where ``probs`` is the softmax
             over the model's logits and ``pred_class`` is its argmax index.
    """
    batch = preprocess(pil_patch).unsqueeze(0)  # add batch axis in front
    batch = batch.to(DEVICE)

    with torch.no_grad():
        class_probs = torch.softmax(model(batch), dim=1)

    # Drop the batch axis and hand back a plain NumPy vector on the host.
    probs = class_probs.squeeze(0).cpu().numpy()
    return int(np.argmax(probs)), probs

# Quick test: push one small tile through predict_patch and inspect the result.
test_patch = (
    Image.open("path/to/some_small_tile.png")
    .convert("RGB")
    .resize((64, 64))
)
pred, probs = predict_patch(test_patch)

# Predicted class index, its probability, and the length of the probability vector.
print("Pred:", pred)
print("Top prob:", float(probs[pred]))
print("Probs shape:", probs.shape)

Cell 4: Tiling a large image into a prediction grid

In [None]:
def pad_to_multiple(img: Image.Image, multiple: int):
    """
    Grow ``img`` so that its width and height are both multiples of ``multiple``.

    The original pixels stay anchored at the top-left corner; extra space is
    added on the right and bottom of a fresh RGB canvas created with Pillow's
    default fill.

    returns: (image, (pad_w, pad_h)) — the input image itself and (0, 0) when
             no padding is required, otherwise the new canvas and the number
             of pixels added in each direction.
    """
    width, height = img.size
    target_w = int(math.ceil(width / multiple) * multiple)
    target_h = int(math.ceil(height / multiple) * multiple)
    extra = (target_w - width, target_h - height)

    if extra != (0, 0):
        canvas = Image.new("RGB", (target_w, target_h))
        canvas.paste(img, (0, 0))
        return canvas, extra

    # Already aligned: hand the input back untouched.
    return img, (0, 0)


def sliding_window_predict(image_path: str, patch_size=64, stride=64, return_confidence=False):
    """
    Classify a large image tile-by-tile and return a grid of predictions.

    The image is loaded as RGB, padded on the right/bottom (via
    ``pad_to_multiple``) so both sides are multiples of ``patch_size``, and then
    scanned with a ``patch_size`` x ``patch_size`` window moved in steps of
    ``stride``. Each window is classified with ``predict_patch``.

    image_path: path of the image file to open.
    patch_size: side length of each square window, in pixels (must be > 0).
    stride: step between consecutive windows, in pixels (must be > 0).
    return_confidence: when True, also collect the full probability vector of
        every window.

    returns: (pred_grid, conf_grid, meta)
        pred_grid: int32 array of shape (rows, cols) with predicted class indices.
        conf_grid: float32 array of shape (rows, cols, num_classes) when
            ``return_confidence`` is True and at least one window was
            processed, otherwise None.
        meta: dict with "orig_size", "padded_size", "pad_added", "patch_size",
            "stride", "rows" and "cols".

    raises: ValueError if ``patch_size`` or ``stride`` is not positive.

    NOTE: padding only aligns the image to ``patch_size``. If ``stride`` does
    not evenly divide ``patch_size``, a strip along the right/bottom edge may
    not be covered by any window.
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError(
            f"patch_size and stride must be positive, got patch_size={patch_size}, stride={stride}"
        )

    # Open the file once: remember the original size and release the file
    # handle as soon as the pixels have been copied by convert().
    with Image.open(image_path) as source:
        orig_size = source.size
        img = source.convert("RGB")

    img, pad = pad_to_multiple(img, patch_size)

    w, h = img.size
    cols = (w - patch_size) // stride + 1
    rows = (h - patch_size) // stride + 1

    pred_grid = np.zeros((rows, cols), dtype=np.int32)
    conf_grid = None  # allocated lazily once the number of classes is known

    for r in range(rows):
        y = r * stride
        for c in range(cols):
            x = c * stride
            patch = img.crop((x, y, x + patch_size, y + patch_size))
            pred, probs = predict_patch(patch)

            pred_grid[r, c] = pred
            if return_confidence:
                if conf_grid is None:
                    # Size the last axis from the first result instead of
                    # running an extra throw-away prediction up front.
                    conf_grid = np.zeros((rows, cols, probs.shape[0]), dtype=np.float32)
                conf_grid[r, c, :] = probs

    meta = {
        "orig_size": orig_size,
        "padded_size": (w, h),
        "pad_added": pad,
        "patch_size": patch_size,
        "stride": stride,
        "rows": rows,
        "cols": cols
    }

    return pred_grid, conf_grid, meta

In [None]:
# Large input image to tile and classify.
IMAGE_PATH = "path/to/large_satellite.png"

# Run the tiled inference with the notebook-wide patch/stride settings and
# keep the per-tile probability vectors as well.
pred_grid, conf_grid, meta = sliding_window_predict(
    image_path=IMAGE_PATH,
    patch_size=PATCH_SIZE,
    stride=STRIDE,
    return_confidence=True,
)

# Summarise what came back.
print(meta)
print("pred_grid shape:", pred_grid.shape)
print("conf_grid shape:", conf_grid.shape if conf_grid is not None else None)

Cell 5: Quick visualization

In [None]:
import matplotlib.pyplot as plt

# Show the grid of predicted class indices as an image (one cell per tile).
fig, ax = plt.subplots()
ax.imshow(pred_grid)
ax.set_title("Prediction Grid (class indices)")
plt.show()