# Zero-shot Quantification

In [1]:
# ==========================================================
# Compute Dice metrics for ZERO-SHOT predictions only
# ==========================================================

from pathlib import Path
import numpy as np
import pandas as pd
import nibabel as nib
from scipy.ndimage import label as cc_label

# ----------------------------
# Paths
# ----------------------------

ZEROSHOT_ROOT = Path(
    "/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_zeroshot"
)

GT_ROOT = Path(
    "/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches"
)

ZEROSHOT_MODELS = ["clip_text", "clip_notext", "imgonly", "random"]

DATA_TYPE_TO_GT_SUBDIR = {
    "amyloid_plaque": "amyloid_plaque_patches",
    "c_fos_positive": "c_fos_positive_patches",
    "cell_nucleus": "cell_nucleus_patches",
    "vessels": "vessels_patches",
}

STRUCT_3D_26 = np.ones((3, 3, 3), dtype=bool)
PROB_THRESHOLD = 0.5


# ----------------------------
# Metric definitions (UNCHANGED)
# ----------------------------

def binary_dice(pred, gt, eps=1e-8):
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2 * inter + eps) / (pred.sum() + gt.sum() + eps)

def background_dice(pred, gt, eps=1e-8):
    pred_bg = ~pred
    gt_bg = ~gt
    inter = np.logical_and(pred_bg, gt_bg).sum()
    return (2 * inter + eps) / (pred_bg.sum() + gt_bg.sum() + eps)

def total_dice(pred, gt):
    return 0.5 * (binary_dice(pred, gt) + background_dice(pred, gt))

def foreground_dice(pred, gt, eps=1e-8):
    fg_union = np.logical_or(pred, gt)
    if fg_union.sum() == 0:
        return 1.0
    p = pred[fg_union]
    g = gt[fg_union]
    inter = np.logical_and(p, g).sum()
    return (2 * inter + eps) / (p.sum() + g.sum() + eps)

def count_instances(mask):
    if mask.sum() == 0:
        return 0
    _, n = cc_label(mask.astype(bool), structure=STRUCT_3D_26)
    return n


# ----------------------------
# Compute metrics
# ----------------------------

rows = []

for model in ZEROSHOT_MODELS:

    for data_type, gt_sub in DATA_TYPE_TO_GT_SUBDIR.items():

        pred_dir = ZEROSHOT_ROOT / f'{data_type}_patches' / model / "preds"
        gt_dir = GT_ROOT / gt_sub

        if not pred_dir.exists():
            continue

        for prob_path in pred_dir.glob("*_prob.nii.gz"):
            stem = prob_path.name.replace("_prob.nii.gz", "")
            gt_path = gt_dir / f"{stem}_label.nii.gz"

            if not gt_path.exists():
                continue

            pred_prob = nib.load(prob_path).get_fdata().astype(np.float32)
            gt_bin = nib.load(gt_path).get_fdata() > 0.5

            pred_bin = pred_prob >= PROB_THRESHOLD

            td = total_dice(pred_bin, gt_bin)
            fd = foreground_dice(pred_bin, gt_bin)

            n_pred = count_instances(pred_bin)
            n_gt = count_instances(gt_bin)
            inst_d = 1.0 if (n_pred == 0 and n_gt == 0) else (
                2 * min(n_pred, n_gt) / (n_pred + n_gt)
            )

            rows.append({
                "model": model,
                "data_type": data_type,
                "total_dice": td,
                "foreground_dice": fd,
                "instance_dice": inst_d,
            })

metrics_df_zeroshot = pd.DataFrame(rows)
print("[INFO] Zero-shot metric rows:", len(metrics_df_zeroshot))


[INFO] Zero-shot metric rows: 352


In [2]:
display(metrics_df_zeroshot)

Unnamed: 0,model,data_type,total_dice,foreground_dice,instance_dice
0,clip_text,amyloid_plaque,0.364780,0.002520,0.007989
1,clip_text,amyloid_plaque,0.398683,0.001690,0.009677
2,clip_text,amyloid_plaque,0.370114,0.003098,0.006143
3,clip_text,amyloid_plaque,0.393572,0.001185,0.005839
4,clip_text,amyloid_plaque,0.415510,0.000141,0.001505
...,...,...,...,...,...
347,random,vessels,0.508492,0.047377,0.651163
348,random,vessels,0.506109,0.070662,0.719298
349,random,vessels,0.665294,0.377072,0.243902
350,random,vessels,0.494823,0.018872,0.164948


In [3]:
# Identify numeric metric columns explicitly
metric_cols = ["total_dice", "foreground_dice", "instance_dice"]

summary_table = (
    metrics_df_zeroshot
    .groupby(["model", "data_type"])[metric_cols]
    .mean()
    .reset_index()
    .sort_values(["data_type", "model"])
)

# Display nicely (only format numeric columns)
try:
    display(summary_table.style.format({c: "{:.4f}" for c in metric_cols}))
except NameError:
    print(summary_table)


Unnamed: 0,model,data_type,total_dice,foreground_dice,instance_dice
0,clip_notext,amyloid_plaque,0.4973,0.0124,0.1189
4,clip_text,amyloid_plaque,0.3927,0.0044,0.0105
8,imgonly,amyloid_plaque,0.5001,0.0062,0.0698
12,random,amyloid_plaque,0.5162,0.0389,0.1965
1,clip_notext,c_fos_positive,0.452,0.0009,0.3634
5,clip_text,c_fos_positive,0.3425,0.0327,0.8735
9,imgonly,c_fos_positive,0.509,0.0405,0.8097
13,random,c_fos_positive,0.5395,0.0914,0.5675
2,clip_notext,cell_nucleus,0.5003,0.0194,0.3123
6,clip_text,cell_nucleus,0.4114,0.0602,0.3194
