In [None]:
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
from imageio.v3 import imread
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score, \
                            f1_score, precision_score, recall_score, roc_auc_score

import torch

# NOTE(review): presumably init_paths() makes the project-local `model` and `utils`
# modules importable (e.g. by extending sys.path) — confirm in __init__.py.
# If so, it has to run before the two project imports that follow.
from __init__ import init_paths
init_paths()

from model import ModClassifier
from utils import get_image_transform

## Setup

In [None]:
# Paths
# Every location below is derived from one base directory. Despite its name,
# CWD is the *parent* of the kernel's working directory; replace it with the
# base directory explicitly if the notebook is started from somewhere else.
CWD = Path.cwd().parent
DATA_DIR = CWD.joinpath("data", "test")  # root folder scanned for images; adjust if needed
CHECKPOINT = CWD.joinpath("weights", "default.ckpt")  # checkpoint the model cell tries to load
OUTPUT_CSV = CWD.joinpath("notebooks", "results.csv")  # set to None to skip saving

In [None]:
# Inference / visualization settings
BATCH_SIZE = 16   # default number of images per forward pass in predict_paths
NUM_WORKERS = 2   # not referenced by the cells visible in this notebook
THRESHOLD = 0.5   # default cut-off: probability above it is labelled "Safe"
IMAGE_SIZE = 224  # passed to get_image_transform
MAX_SHOW = 16     # max images to display

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## Data

In [None]:
# Preprocessing pipeline built by the project helper for IMAGE_SIZE inputs.
image_transform = get_image_transform(image_size=IMAGE_SIZE)

def load_img_as_tensor(path: Path):
    """Read one image file and return it after `image_transform`.

    Args:
        path: Location of the image file on disk.

    Returns:
        Whatever `image_transform` produces for the RGB image
        (NOTE(review): presumably a torch tensor, since callers stack the
        results with `torch.stack` — confirm in utils.get_image_transform).
    """
    # Open through a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory copy,
    # so it stays valid after the file is closed.
    with Image.open(path) as img:
        rgb = img.convert("RGB")  # same as ImageFolder
    return image_transform(rgb)

In [None]:
# Echo the resolved dataset directory as a quick sanity check before scanning it.
DATA_DIR

In [None]:
# File extensions (compared lower-cased) that count as images.
EXTS = {".jpg", ".jpeg", ".png"}
# Collect files exactly one folder below DATA_DIR (<DATA_DIR>/<class folder>/<file>).
# glob() yields entries in filesystem-dependent order, so sort to make the
# resulting table, CSV and figures reproducible across machines and runs.
images = sorted(
    p for p in Path(DATA_DIR).glob("*/*.*")
    if p.is_file() and p.suffix.lower() in EXTS
)
len(images), images[:5]

## Model

In [None]:
# Note: this model predicts whether content is safe (True, > THRESHOLD) or not (False, <= THRESHOLD)
if CHECKPOINT and Path(CHECKPOINT).exists():
    model = ModClassifier.load_from_checkpoint(str(CHECKPOINT)).to(device).eval()
else:
    # Make the fallback visible: without this message the notebook would run to
    # completion and report results for a model that never loaded the checkpoint.
    print(
        f"WARNING: checkpoint not found at {CHECKPOINT!r}; "
        "using ModClassifier() with its default construction - "
        "predictions will not reflect the checkpoint weights."
    )
    model = ModClassifier().to(device).eval()
CHECKPOINT

## Prediction

In [None]:
# optional: normalize folder names -> canonical labels
def _canon_label(name: str) -> str:
    n = name.strip().lower()
    if n in {"safe", "s"}:
        return "Safe"
    if n in {"not", "not safe", "n"}:
        return "Dangerous"
    return name  # fall back to raw folder name if it's something else

@torch.no_grad()
def predict_paths(paths, batch_size=BATCH_SIZE, threshold=THRESHOLD):
    """Classify images batch by batch and return one result row per image.

    Ground truth is taken from each file's parent folder name (normalised by
    `_canon_label`). The module-level `model` and `device` are used for inference.

    Args:
        paths: Iterable of `pathlib.Path` objects pointing at image files.
        batch_size: Number of images per forward pass; must be >= 1.
        threshold: A probability strictly greater than this is labelled "Safe",
            anything else "Dangerous".

    Returns:
        pandas.DataFrame with columns `filename`, `label`, `probability_safe`,
        `predicted_class`, `correct`. The columns exist even for empty input.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() silently yield nothing.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    paths = list(paths)  # accept generators too; len() and slicing are needed below
    columns = ["filename", "label", "probability_safe", "predicted_class", "correct"]
    rows = []
    for start in range(0, len(paths), batch_size):
        batch_paths = paths[start:start + batch_size]
        # GT labels from parent dir
        gt_labels = [_canon_label(p.parent.name) for p in batch_paths]

        batch_tensors = torch.stack([load_img_as_tensor(p) for p in batch_paths]).to(device)
        # squeeze(1) assumes the model emits one logit per image, shape (batch, 1) — TODO confirm
        logits = model(batch_tensors).squeeze(1).float()
        probs = torch.sigmoid(logits).cpu().numpy()  # probability of "Safe" (positive class)

        pred_classes = np.where(probs > threshold, "Safe", "Dangerous")

        for p, pr, pc, gt in zip(batch_paths, probs, pred_classes, gt_labels):
            pc = str(pc)  # plain Python str/bool in the rows instead of numpy scalars
            rows.append({
                "filename": str(p),
                "label": gt,                 # ground-truth from folder
                "probability_safe": float(pr),    # P(Safe)
                "predicted_class": pc,
                "correct": pc == gt,
            })
    # Passing `columns` keeps the schema stable when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)

In [None]:
# Run inference over every discovered image, then preview the first rows of the result table.
df = predict_paths(images)
df.head(10)

In [None]:
# Persist the prediction table unless saving was disabled (OUTPUT_CSV = None).
if OUTPUT_CSV is not None:
    csv_path = Path(OUTPUT_CSV)  # also accept a plain string path
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(csv_path, index=False)
    # A bare expression nested inside `if` is not echoed by the notebook
    # (only a cell's last top-level expression is), so report explicitly.
    print(f"Saved {len(df)} rows to {csv_path}")

## Visualize

In [None]:
def show_grid(df, max_n=MAX_SHOW, ncols=4):
    """Plot a random sample of predictions as an image grid.

    Each tile shows the image with a "GT | Pred" title, coloured green when the
    prediction matches the ground truth and red otherwise.

    Args:
        df: Prediction table with `filename`, `label` and `predicted_class`
            columns; a `correct` column is used when present.
        max_n: Upper bound on the number of images drawn.
        ncols: Number of grid columns.

    Returns:
        None. The figure is rendered with `plt.show()`.
    """
    n = min(max_n, len(df))
    if n <= 0:
        # A zero-row grid would need a zero-height figure, which matplotlib rejects.
        print("show_grid: nothing to display.")
        return

    sel = df.sample(n=n, random_state=None)  # unseeded: a different sample on every call
    nrows = (n + ncols - 1) // ncols  # ceiling division

    fig = plt.figure(figsize=(4 * ncols, 3 * nrows))
    for i, row in enumerate(sel.itertuples(index=False)):
        img = imread(row.filename)
        ax = fig.add_subplot(nrows, ncols, i + 1)
        # 2-D arrays are single-channel images: draw them with a gray colormap.
        ax.imshow(img, cmap=None if img.ndim == 3 else "gray")
        ax.axis("off")

        # Prefer the precomputed flag; otherwise derive it from the two label columns.
        correct = getattr(row, "correct", row.label == row.predicted_class)
        title_color = "tab:green" if correct else "tab:red"
        ax.set_title(
            f"GT: {row.label} | Pred: {row.predicted_class}",
            fontsize=9,
            color=title_color,
        )
    fig.tight_layout()
    plt.show()

In [None]:
# Display a random sample of predictions (at most MAX_SHOW images);
# titles are green for correct predictions and red for wrong ones.
show_grid(df)

## Example Image for README.md

In [None]:
# Images featured in the README figure, given as (class folder, file name) below DATA_DIR.
README_EXAMPLES = [
    ("safe", "n02102040_3916.JPEG"),
    ("not", "n03000684_16549.JPEG"),
    ("safe", "n02102040_8304.JPEG"),
    ("not", "n03000684_25484.JPEG"),
]
# Build the lookup keys from DATA_DIR instead of a hardcoded absolute prefix so
# they match the `filename` column (str(Path) of files found under DATA_DIR)
# wherever the dataset actually lives.
keepers = {
    str(Path(DATA_DIR) / folder / fname): True
    for folder, fname in README_EXAMPLES
}

# Keep the table's row order for the selected examples.
readme_rows = [row for row in df.itertuples(index=False) if row.filename in keepers]

if not readme_rows:
    print("None of the README example images were found in the prediction table.")
else:
    fig = plt.figure(figsize=(16, 3))
    for plt_cnt, row in enumerate(readme_rows, start=1):
        img = imread(row.filename)
        ax = fig.add_subplot(1, len(keepers), plt_cnt)
        # 2-D arrays are single-channel images: draw them with a gray colormap.
        ax.imshow(img, cmap=None if img.ndim == 3 else "gray")
        ax.axis("off")

        # Prefer the precomputed flag; otherwise derive it from the two label columns.
        correct = getattr(row, "correct", row.label == row.predicted_class)
        title_color = "tab:green" if correct else "tab:red"
        ax.set_title(
            f"GT: {row.label} | Pred: {row.predicted_class}",
            fontsize=15,
            color=title_color,
        )
    fig.tight_layout()
    plt.show()