In [1]:
# === Jupyter-friendly inference (infer.py) =========================
import json
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as T

# ---- config ----
CKPT_DIR = Path("checkpoints_multiclass_strong")   # where best.pt lives
DATA_DIR = Path("disaster-ai/data/xbd/tier1")      # used if you pass relative paths

# ---- model builders ----
def build_model(backbone: str, n_classes: int):
    if backbone == "resnet18":
        m = torchvision.models.resnet18(weights=None)
        m.fc = nn.Linear(m.fc.in_features, n_classes)
    elif backbone == "resnet50":
        m = torchvision.models.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, n_classes)
    elif backbone == "efficientnet_b0":
        m = torchvision.models.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, n_classes)
    elif backbone == "vit_b_16":
        m = torchvision.models.vit_b_16(weights=None)
        m.heads.head = nn.Linear(m.heads.head.in_features, n_classes)
    else:
        raise ValueError(f"Unknown backbone: {backbone}")
    return m

def load_pipeline(ckpt_dir: Path = CKPT_DIR):
    """Load best.pt and return (model, transform, classes)."""
    ckpt_path = ckpt_dir / "best.pt"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu")

    classes  = ckpt["classes"]
    img_size = ckpt["img_size"]
    backbone = ckpt["backbone"]

    model = build_model(backbone, len(classes))
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    tform = T.Compose([
        T.Resize(int(img_size * 1.2)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
    ])
    return model, tform, classes

# ---- prediction helpers ----
def predict_one(model, tform, classes, path: Path):
    img = Image.open(path).convert("RGB")
    x = tform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1).squeeze(0)
        top = probs.argmax().item()
    return classes[top], float(probs[top])

def run_inference(image_paths, return_json=False, quiet=False):
    """
    image_paths: list[str|Path] or a single str/Path
    return_json: if True returns JSON string; else returns list of dicts
    quiet:       if False prints one-line summary per image
    """
    if isinstance(image_paths, (str, Path)):
        image_paths = [image_paths]

    model, tform, classes = load_pipeline()

    outputs = []
    for p in image_paths:
        path = Path(p)
        if not path.is_file():
            # allow relative paths like "images/foo.png"
            path = DATA_DIR / p
        if not path.is_file():
            raise FileNotFoundError(f"Image not found: {p} (also tried {path})")

        label, conf = predict_one(model, tform, classes, path)
        rec = {"image_path": str(path), "pred": label, "conf": round(conf, 4)}
        outputs.append(rec)
        if not quiet:
            print(f"{path.name:45s} -> {label:13s} (conf={conf:.2f})")

    if return_json:
        return json.dumps(outputs, indent=2)
    return outputs
# ==================================================================

In [5]:
# Single image
run_inference("images/socal-fire_00001323_post_disaster.png")

# Multiple images
imgs = [
    "images/palu-tsunami_00000125_post_disaster.png",
    "images/socal-fire_00001333_post_disaster.png",
]
preds = run_inference(imgs)

# If you prefer JSON output (e.g., to save or send somewhere)
json_str = run_inference(imgs, return_json=True, quiet=True)
print(json_str)

socal-fire_00001323_post_disaster.png         -> destroyed     (conf=0.96)
palu-tsunami_00000125_post_disaster.png       -> no-damage     (conf=1.00)
socal-fire_00001333_post_disaster.png         -> no-damage     (conf=1.00)
[
  {
    "image_path": "disaster-ai/data/xbd/tier1/images/palu-tsunami_00000125_post_disaster.png",
    "pred": "no-damage",
    "conf": 0.9997
  },
  {
    "image_path": "disaster-ai/data/xbd/tier1/images/socal-fire_00001333_post_disaster.png",
    "pred": "no-damage",
    "conf": 0.9997
  }
]


In [7]:
# --- Batch prediction -> CSV ---------------------------------------------------
from pathlib import Path
import pandas as pd

SUPPORTED_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")

def batch_predict_to_csv(
    image_dir="disaster-ai/data/xbd/tier1/images",
    out_csv="batch_predictions.csv",
    recursive=True,
    quiet=False,
):
    """
    Scan image_dir for images, run inference with the trained checkpoint,
    and save predictions to a CSV with columns: image_path, pred, conf [, error].
    """
    image_dir = Path(image_dir)
    assert image_dir.exists(), f"Directory not found: {image_dir}"

    # Collect paths
    if recursive:
        candidates = image_dir.rglob("*")
    else:
        candidates = image_dir.glob("*")
    paths = [p for p in candidates if p.suffix.lower() in SUPPORTED_EXTS]
    paths = sorted(paths)
    if not paths:
        print(f"No images found in {image_dir} (recursive={recursive}).")
        return []

    # Load your trained pipeline (uses the helpers you already have)
    model, tform, classes = load_pipeline()  # from your earlier cell
    recs = []

    for p in paths:
        try:
            label, conf = predict_one(model, tform, classes, p)  # from your earlier cell
            rec = {"image_path": str(p), "pred": label, "conf": round(conf, 4)}
            if not quiet:
                print(f"{p.name:45s} -> {label:13s} (conf={conf:.2f})")
        except Exception as e:
            # keep going even if one image fails
            rec = {"image_path": str(p), "pred": None, "conf": None, "error": str(e)}
            if not quiet:
                print(f"⚠️  {p.name}: {e}")
        recs.append(rec)

    pd.DataFrame(recs).to_csv(out_csv, index=False)
    ok = sum(1 for r in recs if r["pred"] is not None)
    print(f"✅ Saved {ok}/{len(recs)} predictions to {out_csv}")
    return recs
# ------------------------------------------------------------------------------

In [9]:
# default: your dataset images folder, recursive, prints progress
batch_predict_to_csv()

# custom folder + quiet mode
batch_predict_to_csv("some/other/folder", out_csv="preds_other.csv", recursive=True, quiet=True)

guatemala-volcano_00000000_post_disaster.png  -> no-damage     (conf=0.98)
guatemala-volcano_00000000_pre_disaster.png   -> no-damage     (conf=0.99)
guatemala-volcano_00000001_post_disaster.png  -> minor-damage  (conf=1.00)
guatemala-volcano_00000001_pre_disaster.png   -> minor-damage  (conf=0.91)
guatemala-volcano_00000002_post_disaster.png  -> destroyed     (conf=0.91)
guatemala-volcano_00000002_pre_disaster.png   -> no-damage     (conf=0.91)
guatemala-volcano_00000006_post_disaster.png  -> no-damage     (conf=1.00)
guatemala-volcano_00000006_pre_disaster.png   -> no-damage     (conf=1.00)
guatemala-volcano_00000007_post_disaster.png  -> no-damage     (conf=0.99)
guatemala-volcano_00000007_pre_disaster.png   -> no-damage     (conf=1.00)
guatemala-volcano_00000008_pre_disaster.png   -> no-damage     (conf=0.62)
guatemala-volcano_00000010_post_disaster.png  -> destroyed     (conf=1.00)
guatemala-volcano_00000013_post_disaster.png  -> no-damage     (conf=1.00)
guatemala-volcano_0000001

AssertionError: Directory not found: some/other/folder