In [8]:
# %%
# Cell 0. Batch case selection and paths (case-first layout)
from pathlib import Path

# Project root: ".." relative to the current working directory, made absolute.
ROOT = Path("..").resolve()

# Case IDs to process in this batch (the same cases that were generated previously).
CASE_IDS = [f"elephant_7_val_{i}" for i in range(1, 16)]  # i = 1..15 inclusive (15 cases)

# Inputs shared by every case: the ONNX model and the class-name list.
ONNX_PATH = ROOT / "data" / "models" / "resnet50_geirhos_tl_with_feats.onnx"
CLASS_NAMES_JSONL = ROOT / "data" / "class_names.jsonl"

# Fail fast if a shared input is missing, before any per-case work starts.
assert ONNX_PATH.exists(), f"Missing ONNX model: {ONNX_PATH}"
assert CLASS_NAMES_JSONL.exists(), f"Missing class_names.jsonl: {CLASS_NAMES_JSONL}"

# Output: rows are written to results/certainty_metrics.jsonl (same file as used before).
# The directory is created here if it does not exist yet.
OUT_DIR = ROOT / "results"
OUT_DIR.mkdir(parents=True, exist_ok=True)
CERT_JSONL = OUT_DIR / "certainty_metrics.jsonl"

# Helper to build per-case paths (matches previous notebook structure)
def case_paths(case_id: str) -> dict:
    """Build the dictionary of filesystem locations used for one case.

    All paths are derived from ``ROOT / "data" / "cases" / case_id``; nothing
    is checked for existence here.

    Args:
        case_id: Name of the case directory.

    Returns:
        A dict with the keys ``case_id``, ``case_dir``, ``case_jsonl``,
        ``occluded_png``, ``gt_png``, ``shapes_xy_npz`` and
        ``completions_dir`` (in that order).
    """
    base = ROOT / "data" / "cases" / case_id
    generated = base / "generated"
    return dict(
        case_id=case_id,
        case_dir=base,
        case_jsonl=generated / f"{case_id}.jsonl",
        occluded_png=base / "occluded.png",
        gt_png=base / "gt.png",
        shapes_xy_npz=generated / "shapes_xy.npz",
        completions_dir=generated / "completions",
    )

# Validate all cases exist and have expected files
# `cases` collects the path dicts of every case that passed all checks;
# `missing_any` records whether at least one requested case was rejected.
cases = []
missing_any = False
for cid in CASE_IDS:
    p = case_paths(cid)
    ok = True
    if not p["case_dir"].exists():
        print("Missing CASE_DIR:", p["case_dir"])
        ok = False
    # Every required file/directory is checked (no early exit) so that all
    # missing pieces of a case are reported in one pass.
    for k in ["case_jsonl", "occluded_png", "gt_png", "shapes_xy_npz", "completions_dir"]:
        if not p[k].exists():
            print(f"Missing {k} for {cid}:", p[k])
            ok = False
    if ok:
        cases.append(p)
    else:
        missing_any = True

# Summary of what will be processed and where results go.
print("Requested cases:", len(CASE_IDS))
print("Cases ready     :", len(cases))
print("OUT_DIR         :", OUT_DIR)
print("CERT_JSONL      :", CERT_JSONL)
if missing_any:
    print("Some cases are missing required files. Only 'Cases ready' will be processed.")


Requested cases: 15
Cases ready     : 15
OUT_DIR         : /home/hschatzle/monte-carlo-selection/results
CERT_JSONL      : /home/hschatzle/monte-carlo-selection/results/certainty_metrics.jsonl


In [9]:
# %%
# Cell 1. Load class names (JSONL: one JSON object per line)
import json

# Read one JSON object per non-blank line and keep its "class_name" field.
# The position in `classes` is used later as the class index into the logits.
with CLASS_NAMES_JSONL.open("r", encoding="utf-8") as f:
    classes = [json.loads(line)["class_name"] for line in f if line.strip()]

print("Loaded class list:", len(classes))
print("First 10 classes:", classes[:10])


Loaded class list: 54
First 10 classes: ['ant', 'bat', 'bear', 'bee', 'beetle', 'bird', 'bug', 'bull', 'butterfly', 'camel']


In [10]:
# %%
# Cell 2. ONNX session + preprocessing (shared)
import onnxruntime as ort
import numpy as np
from PIL import Image

# Thread settings for the shared in-notebook session (used for single-image
# inference, not for the per-worker sessions created during batch scoring).
so = ort.SessionOptions()
so.intra_op_num_threads = 8
so.inter_op_num_threads = 1

# CPU-only inference session shared by the helper functions below.
SESSION = ort.InferenceSession(
    str(ONNX_PATH),
    sess_options=so,
    providers=["CPUExecutionProvider"],
)

# The first model input receives the image tensor; OUTS lists every output name.
IN_NAME = SESSION.get_inputs()[0].name
OUTS = [o.name for o in SESSION.get_outputs()]

# These are the names you used in the other notebook
LOGITS_OUT_NAME = "output"
PENULT_OUT_NAME = "features"

# Fail early if the model does not expose the expected outputs.
assert LOGITS_OUT_NAME in OUTS, f"Model outputs {OUTS}, missing logits output '{LOGITS_OUT_NAME}'"
assert PENULT_OUT_NAME in OUTS, f"Model outputs {OUTS}, missing penult output '{PENULT_OUT_NAME}'"

# Preprocessing constants: square input size and per-channel (RGB) mean/std.
# NOTE(review): these look like the usual ImageNet normalization values —
# presumably what the model was trained with; confirm against the training code.
IM_SIZE = 224
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_png(p: str | Path) -> np.ndarray:
    """Load an image file and convert it into a model input tensor.

    The image is converted to RGB, resized to ``IM_SIZE x IM_SIZE`` with
    bilinear resampling, scaled to [0, 1], normalized per channel with
    ``MEAN`` / ``STD`` and rearranged from HWC to NCHW.

    Args:
        p: Path of the image file to load.

    Returns:
        A float32 array of shape ``(1, 3, IM_SIZE, IM_SIZE)``.
    """
    # Open via a context manager so the underlying file handle is released
    # deterministically instead of whenever the Image object is collected.
    with Image.open(p) as src:
        img = src.convert("RGB").resize((IM_SIZE, IM_SIZE), resample=Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = (x - MEAN[None, None, :]) / STD[None, None, :]
    x = np.transpose(x, (2, 0, 1))[None, ...]  # HWC -> (1, 3, H, W)
    return x.astype(np.float32, copy=False)

def infer_penult_and_logits(p: str | Path):
    """Run the shared ONNX session on one image file.

    Args:
        p: Path of the image to preprocess and feed to the model.

    Returns:
        A tuple ``(features, logits)``: the ``PENULT_OUT_NAME`` output as an
        array, and the ``LOGITS_OUT_NAME`` output flattened to 1-D.
    """
    wanted_outputs = [PENULT_OUT_NAME, LOGITS_OUT_NAME]
    feed = {IN_NAME: preprocess_png(p)}
    features, raw_logits = SESSION.run(wanted_outputs, feed)
    flat_logits = np.asarray(raw_logits).reshape(-1)
    return np.asarray(features), flat_logits

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over all elements of ``logits``.

    The input is flattened and promoted to float64; the maximum is subtracted
    before exponentiation to avoid overflow.

    Args:
        logits: Array-like of raw scores (any shape).

    Returns:
        A 1-D float64 array of probabilities with one entry per input element.
    """
    flat = np.asarray(logits, dtype=np.float64).reshape(-1)
    shifted = flat - np.max(flat)
    exps = np.exp(shifted)
    denom = np.sum(exps) + 1e-12  # small epsilon kept in the denominator
    return exps / denom


In [11]:
# %%
# Cell 3. Load shapes_xy.npz for a case and resolve completion PNG paths
import numpy as np
from pathlib import Path

def load_case_shapes(case: dict):
    """Load ``shapes_xy.npz`` for a case and resolve its completion PNG paths.

    Args:
        case: Case dict providing ``"shapes_xy_npz"`` (the archive to read)
            and ``"completions_dir"`` (fallback directory for PNG lookup).

    Returns:
        A tuple ``(png_paths, polygons_xy, matlab_1_indexed, missing)``:
        * ``png_paths``: list of ``Path``, one per entry of ``out_files``.
          A stored path is used as-is if it exists; otherwise its file name
          is looked up inside ``completions_dir``.
        * ``polygons_xy``: the ``polygons`` array from the archive.
        * ``matlab_1_indexed``: the archive's flag as ``bool``
          (``False`` when the key is absent).
        * ``missing``: number of resolved paths that do not exist on disk.
    """
    # SECURITY NOTE: allow_pickle=True can execute arbitrary code when loading
    # an untrusted file. It is kept because the archive may hold object arrays;
    # only use this on archives you generated yourself.
    #
    # The archive is opened as a context manager so its file handle is closed
    # once the arrays are read (np.load on an .npz otherwise leaves it open).
    with np.load(case["shapes_xy_npz"], allow_pickle=True) as z:
        out_files_raw = z["out_files"].tolist()
        polygons_xy = z["polygons"]
        matlab_1_indexed = bool(z["matlab_1_indexed"]) if "matlab_1_indexed" in z else False

    completions_dir = case["completions_dir"]

    def resolve_png_path(p: str) -> Path:
        # Prefer the stored path; fall back to the same file name under
        # this case's completions directory.
        pth = Path(p)
        if pth.exists():
            return pth
        return completions_dir / pth.name

    png_paths = [resolve_png_path(p) for p in out_files_raw]
    missing = sum(1 for p in png_paths if not p.exists())

    return png_paths, polygons_xy, matlab_1_indexed, missing

# quick check first ready case
# Smoke test on the first validated case only: confirms the archive loads and
# shows how many completions it lists and how many of their PNGs are missing.
if len(cases) > 0:
    png_paths0, polys0, m10, miss0 = load_case_shapes(cases[0])
    print("Example case:", cases[0]["case_id"])
    print("N completions:", len(png_paths0))
    print("Missing PNGs :", miss0)
    print("polygons_xy  :", polys0.shape, polys0.dtype)
    print("matlab_1_idx :", m10)


Example case: elephant_7_val_1
N completions: 10000
Missing PNGs : 0
polygons_xy  : (10000,) object
matlab_1_idx : True


In [12]:
# %%
# Cell 4. Compute target class per case from occluded image
# Target = argmax logits on occluded baseline (same as your current logic)

def target_from_occluded(case: dict):
    """Determine the target class of a case from its occluded image.

    The target is the class with the largest logit on ``case["occluded_png"]``.

    Args:
        case: Case dict providing ``"occluded_png"``.

    Returns:
        A tuple ``(target_idx, target_logit, target_prob)`` where the logit
        and softmax probability are those of the target class on the
        occluded image.
    """
    baseline_logits = infer_penult_and_logits(case["occluded_png"])[1]
    baseline_probs = softmax(baseline_logits)
    winner = int(np.argmax(baseline_logits))
    winner_logit = float(baseline_logits[winner])
    winner_prob = float(baseline_probs[winner])
    return winner, winner_logit, winner_prob

# sanity
# Smoke test on the first validated case: show which class the model predicts
# for the occluded image and how confident it is (logit and softmax probability).
if len(cases) > 0:
    tidx, tlog0, tpr0 = target_from_occluded(cases[0])
    print("Example case:", cases[0]["case_id"])
    print("TARGET idx:", tidx, "class:", classes[tidx])
    print("baseline logit:", tlog0, "prob:", tpr0)


Example case: elephant_7_val_1
TARGET idx: 21 class: elephant
baseline logit: 7.087069988250732 prob: 0.6987596641956342


In [13]:
# %%
# Cell 5. Robust parallel scoring (always returns tlog, tpr, tmar, tdel)

from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
from tqdm.auto import tqdm

# Number of worker processes used for scoring; also determines the chunk size
# (completions are split into roughly N_WORKERS * 8 chunks).
N_WORKERS = 8

def _score_chunk_onnx(args):
    """
    args = (chunk_pairs, target_idx, occ_tlog, onnx_path_str, logits_out_name)
    chunk_pairs: list of (i, path_str)
    Returns: (idxs, tlog_vals, tpr_vals, tmar_vals, tdel_vals, n_scored)
    """
    chunk_pairs, target_idx, occ_tlog, onnx_path_str, logits_out_name = args

    import os
    import numpy as np
    import onnxruntime as ort
    from PIL import Image

    so = ort.SessionOptions()
    so.intra_op_num_threads = 1
    so.inter_op_num_threads = 1

    sess = ort.InferenceSession(onnx_path_str, sess_options=so, providers=["CPUExecutionProvider"])
    in_name = sess.get_inputs()[0].name

    IM_SIZE = 224
    MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def preprocess(p):
        img = Image.open(p).convert("RGB").resize((IM_SIZE, IM_SIZE), resample=Image.BILINEAR)
        x = np.asarray(img, dtype=np.float32) / 255.0
        x = (x - MEAN[None, None, :]) / STD[None, None, :]
        x = np.transpose(x, (2, 0, 1))[None, ...]
        return x.astype(np.float32, copy=False)

    def softmax_local(logits):
        z = np.asarray(logits, dtype=np.float64).reshape(-1)
        z = z - np.max(z)
        e = np.exp(z)
        return e / (np.sum(e) + 1e-12)

    idxs = []
    v_tlog = []
    v_tpr  = []
    v_tmar = []
    v_tdel = []

    for i, pstr in chunk_pairs:
        try:
            if not os.path.exists(pstr):
                continue

            x = preprocess(pstr)
            logits = sess.run([logits_out_name], {in_name: x})[0]
            logits = np.asarray(logits).reshape(-1)

            prob = softmax_local(logits)
            tl = float(logits[int(target_idx)])
            other_max = float(np.max(np.delete(logits, int(target_idx))))

            idxs.append(int(i))
            v_tlog.append(tl)
            v_tpr.append(float(prob[int(target_idx)]))
            v_tdel.append(float(tl - float(occ_tlog)))
            v_tmar.append(float(tl - other_max))
        except Exception:
            continue

    return (
        np.asarray(idxs, dtype=np.int64),
        np.asarray(v_tlog, dtype=np.float64),
        np.asarray(v_tpr, dtype=np.float64),
        np.asarray(v_tmar, dtype=np.float64),
        np.asarray(v_tdel, dtype=np.float64),
        int(len(idxs)),
    )


def score_case_parallel(case: dict, png_paths: list, target_idx: int, occ_tlog: float):
    """Score every completion image of a case using a pool of worker processes.

    The images are split into chunks (about ``N_WORKERS * 8`` of them) and
    each chunk is scored by ``_score_chunk_onnx`` in a separate process.

    Args:
        case: Case dict; only ``"case_id"`` is read, for messages.
        png_paths: Completion image paths; position ``i`` in this list is
            position ``i`` in every returned array.
        target_idx: Class index whose logit / probability is tracked.
        occ_tlog: Target logit of the occluded baseline.

    Returns:
        A tuple ``(tlog, tpr, tmar, tdel)`` of float64 arrays of length
        ``len(png_paths)``. Entries for images that could not be scored
        stay ``NaN``.
    """
    N = len(png_paths)
    tlog = np.full(N, np.nan, dtype=np.float64)
    tpr  = np.full(N, np.nan, dtype=np.float64)
    tmar = np.full(N, np.nan, dtype=np.float64)
    tdel = np.full(N, np.nan, dtype=np.float64)

    if N == 0:
        print("No png_paths for case:", case.get("case_id", "unknown"))
        return tlog, tpr, tmar, tdel

    # Carry the original index with each path so results can be written back
    # in place regardless of the order in which chunks finish.
    pairs = [(i, str(p)) for i, p in enumerate(png_paths)]

    chunk_size = max(1, N // (N_WORKERS * 8))
    chunks = [pairs[j:j+chunk_size] for j in range(0, N, chunk_size)]

    args_list = [(ch, target_idx, occ_tlog, str(ONNX_PATH), LOGITS_OUT_NAME) for ch in chunks]

    n_scored_total = 0
    with ProcessPoolExecutor(max_workers=N_WORKERS) as ex:
        futs = [ex.submit(_score_chunk_onnx, a) for a in args_list]
        for fut in tqdm(as_completed(futs), total=len(futs), desc=f"Scoring {case.get('case_id','case')}"):
            idxs, a, b, c, d, n_scored = fut.result()
            if idxs.size:
                tlog[idxs] = a
                tpr[idxs]  = b
                tmar[idxs] = c
                tdel[idxs] = d
            n_scored_total += int(n_scored)

    finite = int(np.isfinite(tlog).sum())
    if finite == 0:
        # Nothing usable came back: print the inputs needed to diagnose it.
        print("WARNING: scored 0 images for", case.get("case_id", "case"))
        print("  N paths:", N)
        print("  Example path:", str(png_paths[0]))
        print("  ONNX_PATH:", str(ONNX_PATH))
        print("  LOGITS_OUT_NAME:", LOGITS_OUT_NAME)
    elif n_scored_total < N:
        # Partial failure (missing or unreadable images) would otherwise be
        # invisible because the best pick is made from the remaining images.
        print(f"WARNING: scored only {n_scored_total} of {N} images for", case.get("case_id", "case"))

    return tlog, tpr, tmar, tdel


In [14]:
# %%
# Cell 6. Choose best completion (by target margin or target logit) and write JSONL row
# Uses the same output schema you showed earlier.

import json
import numpy as np
from datetime import datetime, timezone

def certainty_metrics_from_logits(logits: np.ndarray) -> dict:
    """Summarize how certain a prediction is, given its logits.

    Args:
        logits: Array-like of class logits; flattened to 1-D float64.

    Returns:
        A dict with:
        * ``pred_idx``: index of the largest logit.
        * ``pred_logit``: value of the largest logit.
        * ``margin_top2``: largest logit minus the largest of the remaining
          logits (``inf`` if there is only one logit).
        * ``p_top1``: largest softmax probability.
        * ``entropy``: entropy of the softmax distribution (natural log).
        * ``energy_logsumexp``: log-sum-exp of the logits.
    """
    flat = np.asarray(logits, dtype=np.float64).reshape(-1)
    probs = softmax(flat)

    top = int(np.argmax(flat))
    top_logit = float(flat[top])
    if flat.size > 1:
        runner_up = float(np.max(np.delete(flat, top)))
    else:
        runner_up = float("-inf")

    # log-sum-exp, shifted by the maximum for numerical stability
    peak = float(np.max(flat))
    log_sum_exp = float(np.log(np.sum(np.exp(flat - peak)) + 1e-12) + peak)

    return {
        "pred_idx": top,
        "pred_logit": top_logit,
        "margin_top2": float(top_logit - runner_up),
        "p_top1": float(np.max(probs)),
        "entropy": float(-np.sum(probs * np.log(probs + 1e-12))),
        "energy_logsumexp": log_sum_exp,
    }

def cls_name(idx: int) -> str:
    """Return the entry of the global ``classes`` list at position ``idx``."""
    position = int(idx)
    return classes[position]

def pick_best_index(tlog: np.ndarray, tmar: np.ndarray, *, prefer: str = "margin") -> int:
    """Pick the index of the best-scoring completion.

    Only positions where both ``tlog`` and ``tmar`` are finite are considered.

    Args:
        tlog: Target logits, one per completion.
        tmar: Target margins, one per completion.
        prefer: ``"margin"`` to maximize ``tmar`` or ``"logit"`` to
            maximize ``tlog``.

    Returns:
        The index (into the input arrays) of the winning completion.

    Raises:
        RuntimeError: If no position has finite values in both arrays.
        ValueError: If ``prefer`` is neither ``"margin"`` nor ``"logit"``.
    """
    usable = np.isfinite(tlog) & np.isfinite(tmar)
    if not usable.any():
        raise RuntimeError("No finite scores to pick best completion.")

    if prefer == "margin":
        scores = tmar
    elif prefer == "logit":
        scores = tlog
    else:
        raise ValueError("prefer must be 'margin' or 'logit'")

    candidates = np.nonzero(usable)[0]
    winner = candidates[np.argmax(scores[candidates])]
    return int(winner)

def append_certainty_row(case: dict, best_png: Path):
    """Compute certainty metrics for one case and append them to ``CERT_JSONL``.

    The model is run on three images — the chosen best completion, the
    ground-truth image and the occluded image — and the certainty metrics of
    each are combined into a single JSON row, which is appended as one line
    to ``CERT_JSONL``.

    Args:
        case: Case dict providing ``"case_id"``, ``"gt_png"`` and
            ``"occluded_png"``.
        best_png: Path of the selected completion image.

    Returns:
        The row dict that was written.
    """
    # Logits for best completion, ground truth and occluded image (in that order).
    best_logits = infer_penult_and_logits(best_png)[1]
    gt_logits = infer_penult_and_logits(case["gt_png"])[1]
    occ_logits = infer_penult_and_logits(case["occluded_png"])[1]

    best_stats = certainty_metrics_from_logits(best_logits)
    gt_stats = certainty_metrics_from_logits(gt_logits)
    occ_stats = certainty_metrics_from_logits(occ_logits)

    # Reported metrics and how to read them; insertion order fixes the key
    # order of the "metrics" and "deltas" sections in the output row.
    directions = {
        "margin_top2": "higher = more certain",
        "p_top1": "higher = more certain",
        "entropy": "lower = more certain",
        "energy_logsumexp": "higher = more certain",
    }

    per_metric = {}
    best_minus_gt = {}
    gt_minus_occ = {}
    for name, direction in directions.items():
        best_val = best_stats[name]
        gt_val = gt_stats[name]
        occ_val = occ_stats[name]
        per_metric[name] = {
            "best": float(best_val),
            "gt": float(gt_val),
            "occ": float(occ_val),
            "direction": direction,
        }
        best_minus_gt[name] = float(best_val - gt_val)
        gt_minus_occ[name] = float(gt_val - occ_val)

    row = {
        "timestamp_utc": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        "case_id": case["case_id"],
        "best_png": str(best_png),
        "gt_png": str(case["gt_png"]),
        "occ_png": str(case["occluded_png"]),
        "best_pred_idx": int(best_stats["pred_idx"]),
        "gt_pred_idx": int(gt_stats["pred_idx"]),
        "occ_pred_idx": int(occ_stats["pred_idx"]),
        "best_pred_class": cls_name(best_stats["pred_idx"]),
        "gt_pred_class": cls_name(gt_stats["pred_idx"]),
        "occ_pred_class": cls_name(occ_stats["pred_idx"]),
        "metrics": per_metric,
        "deltas": {
            "best_minus_gt": best_minus_gt,
            "gt_minus_occ": gt_minus_occ,
        },
    }

    # Append mode: existing rows in the file are kept.
    serialized = json.dumps(row, ensure_ascii=False)
    with CERT_JSONL.open("a", encoding="utf-8") as f:
        f.write(serialized + "\n")

    return row

# Show the output target up front: rows are appended to it, never overwritten.
print("Will append rows to:", CERT_JSONL)


Will append rows to: /home/hschatzle/monte-carlo-selection/results/certainty_metrics.jsonl


In [None]:
# %%
# Cell 7. Run the batch: for each case, score completions, pick best, append row

import numpy as np

# Selection criterion passed to pick_best_index as `prefer`.
PICK_RULE = "margin"  # "margin" or "logit"

# NOTE: append_certainty_row opens CERT_JSONL in append mode, so re-running
# this cell adds another row per case to the file rather than replacing rows.
rows_written = []
for case in cases:
    # polygons_xy and matlab_1_indexed are loaded but not used in this loop.
    png_paths, polygons_xy, matlab_1_indexed, missing = load_case_shapes(case)
    if missing > 0:
        print("Warning:", case["case_id"], "missing PNGs:", missing)

    # The target class is whatever the model predicts for the occluded image.
    target_idx, occ_tlog, occ_tpr = target_from_occluded(case)

    # Score all completions in parallel (logits for target, margins, etc.)
    tlog, tpr, tmar, tdel = score_case_parallel(case, png_paths, target_idx, occ_tlog)

    # pick_best_index raises RuntimeError when no completion has finite
    # scores; that exception is not caught here and stops the whole batch.
    best_i = pick_best_index(tlog, tmar, prefer=PICK_RULE)
    best_png = png_paths[best_i]

    print("\nCASE:", case["case_id"])
    print("  target:", classes[target_idx], "idx:", target_idx, "occ_logit:", occ_tlog)
    print("  best:", best_png.name)
    print("  best target_logit:", float(tlog[best_i]), "best target_margin:", float(tmar[best_i]), "best target_prob:", float(tpr[best_i]))

    # Re-runs the model on best / gt / occluded images and appends one JSONL row.
    row = append_certainty_row(case, best_png)
    rows_written.append(row)

print("\nDone. Rows appended:", len(rows_written))


Scoring elephant_7_val_1:   0%|          | 0/65 [00:00<?, ?it/s]


CASE: elephant_7_val_1
  target: elephant idx: 21 occ_logit: 7.087069988250732
  best: completion_0001_06673.png
  best target_logit: 7.2816033363342285 best target_margin: 1.891324520111084 best target_prob: 0.6668654229684642


Scoring elephant_7_val_2:   0%|          | 0/65 [00:00<?, ?it/s]


CASE: elephant_7_val_2
  target: elephant idx: 21 occ_logit: 7.087069988250732
  best: completion_0001_03881.png
  best target_logit: 7.3613600730896 best target_margin: 1.9310007095336914 best target_prob: 0.6694445294113563


Scoring elephant_7_val_3:   0%|          | 0/65 [00:00<?, ?it/s]


CASE: elephant_7_val_3
  target: elephant idx: 21 occ_logit: 7.087069988250732
  best: completion_0001_01436.png
  best target_logit: 7.332314491271973 best target_margin: 1.9637165069580078 best target_prob: 0.6637961821264552


Scoring elephant_7_val_4:   0%|          | 0/65 [00:00<?, ?it/s]


CASE: elephant_7_val_4
  target: elephant idx: 21 occ_logit: 7.087069988250732
  best: completion_0001_05537.png
  best target_logit: 7.148738861083984 best target_margin: 1.8863897323608398 best target_prob: 0.6392780543373996


Scoring elephant_7_val_5:   0%|          | 0/65 [00:00<?, ?it/s]