# Tables for paper

## Setup

In [1]:
# imports
from __future__ import annotations

import json
import math
import os
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import tifffile as tiff
from IPython.display import display
from scipy.ndimage import label as cc_label


In [2]:
# define constants

# predictions are binarized with `pred >= PROB_THRESHOLD` in compute_metrics_for_pair
PROB_THRESHOLD = 0.5
# all-ones 3x3x3 structuring element (face, edge and corner neighbours are connected);
# default `structure` for the connected-component count in count_instances
STRUCT_3D_26 = np.ones((3, 3, 3), dtype=bool)
# filename tokens: "patch_" + 3 digits, "vol" + 3 digits, and "ch0"/"ch1"
PATCH_ID_RE = re.compile(r"(patch_\d{3})")
VOL_RE = re.compile(r"(vol\d{3})")
CH_RE = re.compile(r"(ch[01])")

# datatype mapping
# ground-truth folder names under GT_ROOT, in the order they appear in the table
DTYPES_GT = ["amyloid_plaque_patches", "cell_nucleus_patches", "vessels_patches"]
# ground-truth folder name -> display name used in the table's top column level
DTYPE_CANON = {
    "amyloid_plaque_patches": "Amyloid Plaque",
    "cell_nucleus_patches": "Cell Nucleus",
    "vessels_patches": "Vessels",
}

# ground-truth folder name -> folder names tried (in order) under <model_root>/preds
DTYPE_FEWMANY_ALIASES = {
    "amyloid_plaque_patches": ["amyloid_plaque", "amyloid_plaque_patches"],
    "cell_nucleus_patches": ["cell_nucleus", "cell_nucleus_patches"],
    "vessels_patches": ["vessels", "vessels_patches"],
}

# shot definitions
# "ntr" is the number that appears in run-directory names (cvfold{fold}_ntr{ntr});
# presumably the number of training patches -- TODO confirm
SHOTS = {
    "Zero-shot": {"mode": "zeroshot"},
    "Few-shot":  {"mode": "fewmany", "ntr": 5},
    "Many-shot": {"mode": "fewmany", "ntr": 15},
}

# folds
# cross-validation fold indices averaged over in eval_fewmany
FOLDS = [0, 1, 2]



In [3]:
# data paths
# NOTE(review): all roots are absolute cluster paths; the notebook only runs where they are mounted.

# ground truth
# labels are looked up as GT_ROOT/<dtype>/<patch>_<vol>_<ch>_label.nii.gz
GT_ROOT = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches")

# zeroshot
# the zero-shot collectors read predictions from <root>/results/<dtype>/
ZEROSHOT_ROOTS = {
    # Unet zeroshot preds: _zeroshot_unet, Swin: _zeroshot/results
    "unet_zeroshot_root": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_zeroshot_unet"),
    "swin_zeroshot_root": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_zeroshot"),
}

# few/many-shot
# keys are the model names used in MODEL_SPECS; values are the experiment roots
# whose "preds" subfolder is searched for run directories
FEWMANY_ROOTS = {
    "Unet Image+CLIP": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_super_sweep2"),
    "Unet Image-only": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_bright_sweep_26"),
    "Unet Random-init": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_unet_random2"),

    "SwinUNETR Image+CLIP": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_autumn_sweep_27_v2"),
    "SwinUNETR Image+CLIP (overtrain)": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_autumn_sweep_27_long"),
    "SwinUNETR Image-only": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_expert_sweep_31_v2"),
    "SwinUNETR Random-init": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_rand_v2"),

    "microSAM base": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_cross_val_b2"),
    "microSAM large": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_cross_val_l"),

    "CellSeg3D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuned_cross_val"),

    # both Cellpose rows share one root; the 2D/3D predictions are told apart by
    # the "_pred2d_" / "_pred3d_" tag in the file name (see collect_cellpose_pairs)
    "Cellpose 2D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellpose/cross_val"),
    "Cellpose 3D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellpose/cross_val"),
}


## Functions

In [4]:
# data locating

def _first_existing(paths: List[Path]) -> Optional[Path]:
    for p in paths:
        if p.exists():
            return p
    return None

# find gt label for given patch (ex: .../{dtype_gt}/{patch_id}_{vol}_ch0_label.nii.gz)
def find_gt_label_for_patch(dtype_gt: str, patch_id: str, vol: str, ch: Optional[str]) -> Optional[Path]:
    """Locate the ground-truth label file of one patch.

    Candidates are ``GT_ROOT / dtype_gt / "{patch_id}_{vol}_{ch}_label.nii.gz"``.
    When ``ch`` is ``None`` the channel is unknown, so ``ch0`` is tried first
    and ``ch1`` second.

    Returns:
        The first candidate that exists on disk, or ``None`` if there is none.
    """
    label_dir = GT_ROOT / dtype_gt
    channels = ("ch0", "ch1") if ch is None else (ch,)
    for channel in channels:
        label_path = label_dir / f"{patch_id}_{vol}_{channel}_label.nii.gz"
        if label_path.exists():
            return label_path
    return None

# parse patch, vol, ch from filename
def parse_patch_tokens(p: Path) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Extract ``(patch_id, vol, ch)`` from the file name of ``p``.

    Each token is the first match of PATCH_ID_RE, VOL_RE and CH_RE respectively
    in ``p.name``; a token that does not occur is returned as ``None``.
    """
    filename = p.name

    def _token(pattern) -> Optional[str]:
        found = pattern.search(filename)
        return found.group(1) if found else None

    return _token(PATCH_ID_RE), _token(VOL_RE), _token(CH_RE)

def pick_run_dir(parent: Path, fold: int, ntr: int) -> Optional[Path]:
    """Pick the most recently modified run directory for one fold / ntr setting.

    A run directory is a sub-directory of ``parent`` whose name contains
    ``cvfold{fold}_ntr{ntr}`` followed by ``_`` or the end of the name (so
    ``ntr5`` never matches ``ntr15``).

    Returns:
        The matching directory with the largest modification time, or ``None``
        when ``parent`` does not exist or nothing matches.
    """
    if not parent.exists():
        return None
    run_name = re.compile(rf"cvfold{fold}_ntr{ntr}(?:_|$)")
    matching = [entry for entry in parent.iterdir() if entry.is_dir() and run_name.search(entry.name)]
    # max() keeps the first of several equally recent directories, exactly like
    # taking element 0 of a stable descending sort.
    return max(matching, key=lambda entry: entry.stat().st_mtime, default=None)

def list_pred_files_for_run(run_dir: Path) -> List[Path]:
    """
    List candidate prediction files (``.gz`` / ``.tif``) belonging to one run.

    Handles different layouts:
    - .../run_dir/preds/*.nii.gz
    - .../run_dir/preds/preds/*.nii.gz
    - .../run_dir/patches/*.nii.gz (microSAM)
    - CellSeg3D: *_instances.tif in run_dir

    The sub-folders are searched recursively in the order above and the first
    one that yields any file wins; if none does, ``run_dir`` itself is searched
    recursively. The result is sorted.
    """

    def _gather(folder: Path) -> List[Path]:
        # recursive search; keeps regular files ending in .gz (this covers .nii.gz) or .tif
        return sorted(
            entry
            for entry in folder.rglob("*")
            if entry.is_file() and (entry.suffix in [".gz", ".tif"] or entry.name.endswith(".nii.gz"))
        )

    # most common layouts
    for subdir in (run_dir / "preds", run_dir / "preds" / "preds", run_dir / "patches"):
        if subdir.is_dir():
            found = _gather(subdir)
            if found:
                return found

    # fallback: run_dir itself
    return _gather(run_dir)



In [5]:
# data loading functions

def load_nifti(path: Path) -> np.ndarray:
    """Load the file at ``path`` with nibabel and return its data as a float32 array."""
    return nib.load(str(path)).get_fdata(dtype=np.float32)

def _squeeze_singleton_channel(arr: np.ndarray) -> np.ndarray:
    return arr[0] if (arr.ndim == 4 and arr.shape[0] == 1) else arr

def _coerce_pred_gt_shapes(pred: np.ndarray, gt: np.ndarray):
    """Apply ``_squeeze_singleton_channel`` to both arrays and return ``(pred, gt)``."""
    squeezed_pred = _squeeze_singleton_channel(pred)
    squeezed_gt = _squeeze_singleton_channel(gt)
    return squeezed_pred, squeezed_gt


In [6]:
# metrics functions

def safe_mean(xs: List[float]) -> float:
    """Mean of the finite entries of ``xs``.

    ``None``, NaN and +/-inf entries are skipped, whether they are Python
    floats or NumPy scalars of any precision.

    Returns:
        The mean as a Python float, or NaN when no finite entry remains.
    """
    # math.isfinite accepts Python and NumPy scalars alike. An
    # isinstance(x, float) gate would miss np.float32, which does not subclass
    # float, so a float32 NaN would slip through and poison the mean.
    finite = [x for x in xs if x is not None and math.isfinite(x)]
    return float(np.mean(finite)) if finite else float("nan")

def binary_dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """Dice overlap of two masks; both inputs are cast to bool first.

    Computes ``(2 * |pred & gt| + eps) / (|pred| + |gt| + eps)``. ``eps`` keeps
    the ratio defined (it evaluates to 1) when both masks are empty.
    """
    pred_mask = pred.astype(bool)
    gt_mask = gt.astype(bool)
    overlap = (pred_mask & gt_mask).sum()
    denominator = pred_mask.sum() + gt_mask.sum() + eps
    return (2.0 * overlap + eps) / denominator

def background_dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """Dice overlap of the background of two masks.

    Both inputs are cast to bool and inverted, then scored as
    ``(2 * |bg_pred & bg_gt| + eps) / (|bg_pred| + |bg_gt| + eps)``.
    """
    pred_background = ~pred.astype(bool)
    gt_background = ~gt.astype(bool)
    overlap = (pred_background & gt_background).sum()
    denominator = pred_background.sum() + gt_background.sum() + eps
    return (2.0 * overlap + eps) / denominator

def total_dice(pred_bin: np.ndarray, gt_bin: np.ndarray, eps: float = 1e-8) -> float:
    """Unweighted average of ``binary_dice`` (foreground) and ``background_dice``."""
    foreground_score = binary_dice(pred_bin, gt_bin, eps)
    background_score = background_dice(pred_bin, gt_bin, eps)
    return 0.5 * (foreground_score + background_score)

def foreground_dice(pred_bin: np.ndarray, gt_bin: np.ndarray, eps: float = 1e-8) -> float:
    """Dice overlap evaluated only on voxels that are foreground in either mask.

    Both inputs are cast to bool. If neither mask has any foreground the score
    is defined as 1.0; otherwise the masks are restricted to their union and
    scored as ``(2 * |pred & gt| + eps) / (|pred| + |gt| + eps)``.
    """
    pred_mask = pred_bin.astype(bool)
    gt_mask = gt_bin.astype(bool)

    union = pred_mask | gt_mask
    if not union.any():
        return 1.0

    # restrict both masks to the voxels where at least one of them is set
    pred_in_union = pred_mask[union]
    gt_in_union = gt_mask[union]
    overlap = (pred_in_union & gt_in_union).sum()
    denominator = pred_in_union.sum() + gt_in_union.sum() + eps
    return (2.0 * overlap + eps) / denominator

def count_instances(mask: np.ndarray, structure=STRUCT_3D_26) -> int:
    """Count the connected foreground components of ``mask``.

    ``mask`` is cast to bool; an all-background mask yields 0 without running
    the labelling. ``structure`` is the connectivity passed to
    ``scipy.ndimage.label`` (default: the all-ones 3x3x3 element).
    """
    foreground = mask.astype(bool)
    if not foreground.any():
        return 0
    num_components = cc_label(foreground, structure=structure)[1]
    return int(num_components)

def compute_metrics_for_pair(pred_path: Path, gt_path: Path) -> dict:
    """Score one prediction file against its ground-truth label file.

    The prediction is read with tifffile when its suffix is ``.tif`` and as
    NIfTI otherwise; the label is always read as NIfTI. After dropping a
    leading singleton axis from 4-D arrays, the prediction is binarized with
    ``>= PROB_THRESHOLD`` and the label with ``> 0.5``.

    Returns:
        dict with ``total_dice``, ``foreground_dice``, ``instance_dice`` and the
        connected-component counts ``instances_pred`` / ``instances_gt``.
        ``instance_dice`` compares the two counts only
        (``2 * min(n_pred, n_gt) / (n_pred + n_gt)``, or 1.0 when both are 0).

    Raises:
        ValueError: if the binarized arrays differ in shape.
    """
    is_tif = pred_path.suffix == ".tif"
    raw_pred = tiff.imread(str(pred_path)).astype(np.float32) if is_tif else load_nifti(pred_path)
    raw_gt = load_nifti(gt_path)
    raw_pred, raw_gt = _coerce_pred_gt_shapes(raw_pred, raw_gt)

    pred_mask = raw_pred >= PROB_THRESHOLD
    gt_mask = raw_gt > 0.5
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(f"Shape mismatch: pred {pred_mask.shape}, gt {gt_mask.shape}")

    voxel_total = total_dice(pred_mask, gt_mask)
    voxel_foreground = foreground_dice(pred_mask, gt_mask)

    n_pred = count_instances(pred_mask)
    n_gt = count_instances(gt_mask)
    if n_pred == 0 and n_gt == 0:
        # nothing to find and nothing found counts as a perfect match
        inst_dice = 1.0
    else:
        inst_dice = (2.0 * min(n_pred, n_gt)) / float(n_pred + n_gt)

    return {
        "total_dice": float(voxel_total),
        "foreground_dice": float(voxel_foreground),
        "instance_dice": float(inst_dice),
        "instances_pred": int(n_pred),
        "instances_gt": int(n_gt),
    }

# wrapper to return only total_dice and instance_dice using compute_metrics_for_pair() function
def compute_pair_metrics(pred_path: Path, gt_label_path: Path) -> Dict[str, float]:
    """Run ``compute_metrics_for_pair`` and keep only ``total_dice`` and ``instance_dice``."""
    full_metrics = compute_metrics_for_pair(pred_path, gt_label_path)
    return {key: float(full_metrics[key]) for key in ("total_dice", "instance_dice")}


In [7]:
# zero-shot data collecting
def _zeroshot_collect_pairs(
    root: Path,
    dtype_gt: str,
    variant_suffix: str,
    n_patches: int = 10,
) -> List[Tuple[Path, Path]]:
    """Pair the zero-shot predictions found in ``root`` with their GT labels.

    For each patch index ``0 .. n_patches - 1`` the files matching
    ``patch_XXX_*pred_{variant_suffix}.nii.gz`` are globbed and the
    alphabetically first one is used. A patch is skipped when no prediction
    exists, when patch/vol/ch cannot be parsed from the file name, or when no
    ground-truth label is found.
    """
    pairs: List[Tuple[Path, Path]] = []
    for i in range(n_patches):
        patch_id = f"patch_{i:03d}"
        matches = sorted(root.glob(f"{patch_id}_*pred_{variant_suffix}.nii.gz"))
        if not matches:
            continue
        pred_path = matches[0]
        patch_id2, vol, ch = parse_patch_tokens(pred_path)
        if patch_id2 is None or vol is None or ch is None:
            continue
        gt = find_gt_label_for_patch(dtype_gt, patch_id2, vol, ch)
        if gt is None:
            continue
        pairs.append((pred_path, gt))
    return pairs


def zeroshot_collect_unet(dtype_gt: str, variant_suffix: str) -> List[Tuple[Path, Path]]:
    """
    Unet zeroshot layout described:
    .../temp_selma_segmentation_preds_zeroshot_unet/results/{dtype_gt}/patch_000_..._pred_{variant}.nii.gz
    variant_suffix in {"image_clip","image_only","random"}

    Returns a list of (prediction, ground-truth label) path pairs.
    """
    root = ZEROSHOT_ROOTS["unet_zeroshot_root"] / "results" / dtype_gt
    return _zeroshot_collect_pairs(root, dtype_gt, variant_suffix)


def zeroshot_collect_swin(dtype_gt: str, variant_suffix: str) -> List[Tuple[Path, Path]]:
    """
    Swin zeroshot layout described:
    .../temp_selma_segmentation_preds_zeroshot/results/{dtype_gt}/patch_000_..._pred_{variant}.nii.gz

    Returns a list of (prediction, ground-truth label) path pairs.
    """
    root = ZEROSHOT_ROOTS["swin_zeroshot_root"] / "results" / dtype_gt
    return _zeroshot_collect_pairs(root, dtype_gt, variant_suffix)


In [8]:
# few/many-shot data collecting

def fewmany_collect_pairs(model_root: Path, dtype_gt: str, fold: int, ntr: int, *, mode: str) -> List[Tuple[Path, Path]]:
    """
    Collect (pred, gt_label) pairs for a given model, datatype, fold, and ntr.

    mode:
      - "standard": roots like .../<exp_root>/preds/{dtype}/<run_dir>/(preds|patches)/*.nii.gz
      - "microsam": similar but uses 'patches' folder
      - "cellseg": tif instances
      - "cellpose2d"/"cellpose3d": select files containing pred2d or pred3d
        (delegated to collect_cellpose_pairs; nucleus datatype only)

    Returns an empty list when the datatype folder or run directory is missing.
    Predictions whose file name lacks a patch/vol/ch token, or that have no
    ground-truth label, are skipped.
    """

    if mode in {"cellpose2d", "cellpose3d"}:
        # collect_cellpose_pairs always reads cell_nucleus_patches; without this
        # guard nucleus results would be returned for any requested datatype.
        if dtype_gt != "cell_nucleus_patches":
            return []
        pred_kind = "2d" if mode == "cellpose2d" else "3d"
        return collect_cellpose_pairs(model_root, fold=fold, ntr=ntr, pred_kind=pred_kind)

    # find datatype folder under model_root/preds
    preds_root = model_root / "preds"

    dtype_candidates = [preds_root / name for name in DTYPE_FEWMANY_ALIASES[dtype_gt]]
    dtype_dir = _first_existing(dtype_candidates)
    if dtype_dir is None:
        return []

    run_dir = pick_run_dir(dtype_dir, fold=fold, ntr=ntr)
    if run_dir is None:
        return []

    files = list_pred_files_for_run(run_dir)

    # filter per-mode (the cellpose modes have already returned above)
    if mode == "cellseg":
        files = [p for p in files if p.suffix == ".tif" and p.name.endswith("_instances.tif")]
    else:
        # standard/microsam: keep nii.gz preds (and ignore any labels etc)
        files = [p for p in files if p.name.endswith(".nii.gz") and ("_label" not in p.name)]

    pairs: List[Tuple[Path, Path]] = []
    for pred_path in files:
        patch_id, vol, ch = parse_patch_tokens(pred_path)
        if patch_id is None or vol is None or ch is None:
            continue
        gt = find_gt_label_for_patch(dtype_gt, patch_id, vol, ch)
        if gt is None:
            continue
        pairs.append((pred_path, gt))

    return pairs

# few/many-shot evaluation
def eval_fold(model_root: Path, dtype_gt: str, fold: int, ntr: int, mode: str) -> Dict[str, float]:
    """Average ``total_dice`` and ``instance_dice`` over all pairs of one fold.

    Pairs come from ``fewmany_collect_pairs``; when there are none, both
    metrics are NaN.
    """
    pairs = fewmany_collect_pairs(model_root, dtype_gt, fold, ntr, mode=mode)
    if not pairs:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    per_pair = [compute_pair_metrics(pred, gt) for pred, gt in pairs]
    return {
        "total_dice": safe_mean([scores["total_dice"] for scores in per_pair]),
        "instance_dice": safe_mean([scores["instance_dice"] for scores in per_pair]),
    }


In [9]:
# collect cellpose pairs
def collect_cellpose_pairs(
    model_root: Path,
    fold: int,
    ntr: int,
    pred_kind: str,  # "2d" or "3d"
) -> List[Tuple[Path, Path]]:
    """
    Collect (pred, gt_label) pairs for Cellpose predictions of one fold / ntr.

    Cellpose layout:
      .../preds/cell_nucleus_patches/<run_dir>/patch_007_vol006_pred2d_...nii.gz
      .../preds/cell_nucleus_patches/<run_dir>/patch_007_vol006_pred3d_...nii.gz

    GT labels:
      .../selma3d_finetune_patches/cell_nucleus_patches/patch_007_vol006_ch0_label.nii.gz (or ch1)

    Returns an empty list when the datatype folder or a matching run directory
    is missing. Predictions without a patch/vol token or without a label are
    skipped.

    Raises:
        ValueError: if ``pred_kind`` is neither "2d" nor "3d".
    """
    if pred_kind not in {"2d", "3d"}:
        raise ValueError(f"pred_kind must be '2d' or '3d', got {pred_kind!r}")

    dtype_gt = "cell_nucleus_patches"

    # Most recent run dir for this fold+ntr; pick_run_dir also returns None when
    # the datatype folder does not exist.
    run_dir = pick_run_dir(model_root / "preds" / dtype_gt, fold=fold, ntr=ntr)
    if run_dir is None:
        return []

    want_tag = f"_pred{pred_kind}_"
    pairs: List[Tuple[Path, Path]] = []
    for pred_path in sorted(run_dir.rglob("*.nii.gz")):
        if want_tag not in pred_path.name:
            continue
        patch_id, vol, _ = parse_patch_tokens(pred_path)
        if patch_id is None or vol is None:
            continue
        # Cellpose pred doesn't encode channel -> ch=None tries ch0 then ch1
        gt = find_gt_label_for_patch(dtype_gt, patch_id, vol, None)
        if gt is not None:
            pairs.append((pred_path, gt))

    return pairs


In [10]:
# model registry (what to compute for each model)
@dataclass
class ModelSpec:
    """Describes one model row of the results table and where its predictions live.

    Fields left at ``None`` mean the corresponding evaluation is not available;
    eval_zeroshot / eval_fewmany then report NaN for it.
    """
    # row label; becomes the "Model" index of the results table
    name: str
    # zeroshot source: "unet" or "swin" or None
    zeroshot_source: Optional[str] = None
    # suffix of the zero-shot prediction files (..._pred_{variant}.nii.gz)
    zeroshot_variant: Optional[str] = None  # "image_clip"|"image_only"|"random"
    # experiment root whose "preds" subfolder holds the few/many-shot runs
    fewmany_root: Optional[Path] = None
    # selects the file layout / filter used by fewmany_collect_pairs
    fewmany_mode: Optional[str] = None      # "standard"|"microsam"|"cellseg"|"cellpose2d"|"cellpose3d"

# One ModelSpec per table row. Entries without zeroshot_source/zeroshot_variant
# get NaN zero-shot columns; the few/many-shot roots come from FEWMANY_ROOTS.
MODEL_SPECS: List[ModelSpec] = [
    # --- UNet zeroshot variants (all from same unet zeroshot root) ---
    ModelSpec("Unet Image+CLIP", zeroshot_source="unet", zeroshot_variant="image_clip", fewmany_root=FEWMANY_ROOTS["Unet Image+CLIP"], fewmany_mode="standard"),
    ModelSpec("Unet Image-only", zeroshot_source="unet", zeroshot_variant="image_only", fewmany_root=FEWMANY_ROOTS["Unet Image-only"], fewmany_mode="standard"),
    ModelSpec("Unet Random-init", zeroshot_source="unet", zeroshot_variant="random", fewmany_root=FEWMANY_ROOTS["Unet Random-init"], fewmany_mode="standard"),

    # --- Swin zeroshot variants ---
    ModelSpec("SwinUNETR Image+CLIP", zeroshot_source="swin", zeroshot_variant="image_clip", fewmany_root=FEWMANY_ROOTS["SwinUNETR Image+CLIP"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Image+CLIP (overtrain)", zeroshot_source=None, zeroshot_variant=None, fewmany_root=FEWMANY_ROOTS["SwinUNETR Image+CLIP (overtrain)"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Image-only", zeroshot_source="swin", zeroshot_variant="image_only", fewmany_root=FEWMANY_ROOTS["SwinUNETR Image-only"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Random-init", zeroshot_source="swin", zeroshot_variant="random", fewmany_root=FEWMANY_ROOTS["SwinUNETR Random-init"], fewmany_mode="standard"),

    # --- Other methods (no zeroshot) ---
    ModelSpec("microSAM base", fewmany_root=FEWMANY_ROOTS["microSAM base"], fewmany_mode="microsam"),
    ModelSpec("microSAM large", fewmany_root=FEWMANY_ROOTS["microSAM large"], fewmany_mode="microsam"),

    # CellSeg3D only exists for nucleus (tif)
    ModelSpec("CellSeg3D", fewmany_root=FEWMANY_ROOTS["CellSeg3D"], fewmany_mode="cellseg"),

    # Cellpose only nucleus, two separate rows
    ModelSpec("Cellpose 2D", fewmany_root=FEWMANY_ROOTS["Cellpose 2D"], fewmany_mode="cellpose2d"),
    ModelSpec("Cellpose 3D", fewmany_root=FEWMANY_ROOTS["Cellpose 3D"], fewmany_mode="cellpose3d"),
]

# Row order of the final table (applied with df.reindex); grouped by
# pretraining variant rather than by architecture. Names must match
# ModelSpec.name exactly -- reindex yields an all-NaN row for any name that
# has no matching row in the table.
MODEL_ORDER = [
    "Unet Image+CLIP",
    "SwinUNETR Image+CLIP",
    "SwinUNETR Image+CLIP (overtrain)",
    "Unet Image-only",
    "SwinUNETR Image-only",
    "Unet Random-init",
    "SwinUNETR Random-init",
    "microSAM base",
    "microSAM large",
    "CellSeg3D",
    "Cellpose 2D",
    "Cellpose 3D",
]


In [11]:
# evaluation orchestration

def eval_zeroshot(model: ModelSpec, dtype_gt: str) -> Dict[str, float]:
    """Mean ``total_dice`` / ``instance_dice`` of a model's zero-shot predictions.

    Both metrics are NaN when the model has no zero-shot source or variant,
    when the source is neither "unet" nor "swin", or when no
    (prediction, label) pair is found.
    """
    missing = {"total_dice": float("nan"), "instance_dice": float("nan")}
    if model.zeroshot_source is None or model.zeroshot_variant is None:
        return missing

    collectors = {"unet": zeroshot_collect_unet, "swin": zeroshot_collect_swin}
    collect = collectors.get(model.zeroshot_source)
    if collect is None:
        return missing

    pairs = collect(dtype_gt, model.zeroshot_variant)
    if not pairs:
        return missing

    per_pair = [compute_pair_metrics(pred, gt) for pred, gt in pairs]
    return {
        "total_dice": safe_mean([scores["total_dice"] for scores in per_pair]),
        "instance_dice": safe_mean([scores["instance_dice"] for scores in per_pair]),
    }

def model_supports_dtype(model: ModelSpec, dtype_gt: str) -> bool:
    """Tell whether ``model`` is evaluated on datatype ``dtype_gt``.

    CellSeg3D and both Cellpose rows are nucleus-only; every other model
    supports all datatypes.
    """
    nucleus_only = model.name in ["CellSeg3D", "Cellpose 2D", "Cellpose 3D"]
    return (not nucleus_only) or dtype_gt == "cell_nucleus_patches"

def eval_fewmany(model: ModelSpec, dtype_gt: str, ntr: int) -> Dict[str, float]:
    """Average the per-fold few/many-shot metrics of ``model`` over FOLDS.

    Both metrics are NaN when the model has no few/many-shot root or mode, or
    does not support ``dtype_gt``. Folds whose own result is NaN are left out
    of the average by ``safe_mean``.
    """
    unavailable = (
        model.fewmany_root is None
        or model.fewmany_mode is None
        or not model_supports_dtype(model, dtype_gt)
    )
    if unavailable:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    per_fold = [
        eval_fold(model.fewmany_root, dtype_gt, fold=fold, ntr=ntr, mode=model.fewmany_mode)
        for fold in FOLDS
    ]

    # average across folds
    return {
        key: safe_mean([scores[key] for scores in per_fold])
        for key in ("total_dice", "instance_dice")
    }


In [12]:
# build results table
def build_results_table() -> pd.DataFrame:
    """
    Evaluate every model in MODEL_SPECS on every datatype and shot setting.

    Returns a wide DataFrame indexed by model name with MultiIndex columns:
      (Datatype, Shot, Metric)
    where Shot follows the order of SHOTS and Metric is "Tot Dice" / "Inst Dice".
    Combinations that cannot be evaluated hold NaN.
    """
    records = []
    for model in MODEL_SPECS:
        row = {"Model": model.name}
        for dtype_gt in DTYPES_GT:
            dtype_name = DTYPE_CANON[dtype_gt]

            # Shot names and ntr values come from SHOTS so the table cannot
            # drift from the configuration cell.
            for shot_name, shot_cfg in SHOTS.items():
                if shot_cfg["mode"] == "zeroshot":
                    if model_supports_dtype(model, dtype_gt):
                        scores = eval_zeroshot(model, dtype_gt)
                    else:
                        scores = {"total_dice": np.nan, "instance_dice": np.nan}
                else:
                    scores = eval_fewmany(model, dtype_gt, ntr=shot_cfg["ntr"])

                row[(dtype_name, shot_name, "Tot Dice")] = scores["total_dice"]
                row[(dtype_name, shot_name, "Inst Dice")] = scores["instance_dice"]

        records.append(row)

    df = pd.DataFrame.from_records(records)

    # put Model as index and make multiindex columns
    df = df.set_index("Model")
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df


In [13]:
# style and display final table
def style_table(df: pd.DataFrame) -> "pd.io.formats.style.Styler":
    """Return a Styler over a copy of ``df`` with each column's maximum in bold.

    Numbers are shown with 3 decimals and missing values as an em dash. Ties
    for the maximum are all bolded; all-NaN columns get no highlighting.
    """

    def _bold_column_max(column):
        # ignore NaNs
        if column.dropna().empty:
            return [""] * len(column)
        top = column.max(skipna=True)
        is_top = column.notna() & (column == top)
        return ["font-weight:700;" if flag else "" for flag in is_top]

    styler = df.copy().style
    styler = styler.format(precision=3, na_rep="—")
    return styler.apply(_bold_column_max, axis=0)

## Build table

In [15]:
# build the results table and order its rows as listed in MODEL_ORDER
df = build_results_table().reindex(MODEL_ORDER)

# plain view rounded to 4 decimals, followed by the styled view
display(df.round(4))
display(style_table(df))


Unnamed: 0_level_0,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Vessels,Vessels,Vessels,Vessels,Vessels,Vessels
Unnamed: 0_level_1,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot
Unnamed: 0_level_2,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice
Model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
Unet Image+CLIP,0.4083,0.018,0.6765,0.5799,0.8043,0.693,0.4598,0.0574,0.7838,0.9106,0.8097,0.936,0.564,0.088,0.8603,0.7635,0.921,0.8135
SwinUNETR Image+CLIP,0.4721,0.1166,0.4852,0.0069,0.7303,0.6836,0.7196,0.7324,0.7895,0.9493,0.8112,0.9618,0.6816,0.478,0.8579,0.7253,0.8971,0.8292
SwinUNETR Image+CLIP (overtrain),,,0.6078,0.6055,0.6761,0.6805,,,0.7841,0.9358,0.8036,0.9767,,,0.8619,0.81,0.8858,0.9014
Unet Image-only,0.4323,0.1516,0.5006,0.03,0.4998,0.0,0.5547,0.2686,0.8001,0.8688,0.8211,0.9649,0.5331,0.3203,0.6058,0.5605,0.887,0.8426
SwinUNETR Image-only,0.4771,0.1098,0.5555,0.2354,0.6856,0.5681,0.7252,0.7673,0.7854,0.9299,0.8039,0.9574,0.6827,0.4733,0.8331,0.7382,0.8601,0.7362
Unet Random-init,0.4692,0.0002,0.5,0.1136,0.7216,0.5844,0.4828,0.0033,0.751,0.8596,0.8011,0.8833,0.5359,0.0039,0.8334,0.8578,0.8925,0.685
SwinUNETR Random-init,0.4808,0.1901,0.4993,0.0092,0.6041,0.5161,0.7337,0.7647,0.7836,0.9015,0.8055,0.9573,0.6854,0.4559,0.7995,0.5295,0.8683,0.6845
microSAM base,,,0.6618,0.4143,0.7127,0.4725,,,0.4934,0.0046,0.5181,0.0978,,,0.8633,0.7708,0.901,0.8058
microSAM large,,,0.623,0.372,0.7695,0.6343,,,0.5283,0.1175,0.5634,0.2173,,,0.7804,0.7352,0.844,0.8229
CellSeg3D,,,,,,,,,0.5092,0.7089,0.5082,0.7412,,,,,,


Unnamed: 0_level_0,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Vessels,Vessels,Vessels,Vessels,Vessels,Vessels
Unnamed: 0_level_1,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot
Unnamed: 0_level_2,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice
Model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
Unet Image+CLIP,0.408,0.018,0.676,0.580,0.804,0.693,0.460,0.057,0.784,0.911,0.81,0.936,0.564,0.088,0.860,0.764,0.921,0.813
SwinUNETR Image+CLIP,0.472,0.117,0.485,0.007,0.730,0.684,0.720,0.732,0.79,0.949,0.811,0.962,0.682,0.478,0.858,0.725,0.897,0.829
SwinUNETR Image+CLIP (overtrain),—,—,0.608,0.605,0.676,0.681,—,—,0.784,0.936,0.804,0.977,—,—,0.862,0.810,0.886,0.901
Unet Image-only,0.432,0.152,0.501,0.030,0.500,0.000,0.555,0.269,0.8,0.869,0.821,0.965,0.533,0.320,0.606,0.560,0.887,0.843
SwinUNETR Image-only,0.477,0.110,0.556,0.235,0.686,0.568,0.725,0.767,0.785,0.93,0.804,0.957,0.683,0.473,0.833,0.738,0.860,0.736
Unet Random-init,0.469,0.000,0.500,0.114,0.722,0.584,0.483,0.003,0.751,0.86,0.801,0.883,0.536,0.004,0.833,0.858,0.893,0.685
SwinUNETR Random-init,0.481,0.190,0.499,0.009,0.604,0.516,0.734,0.765,0.784,0.901,0.806,0.957,0.685,0.456,0.799,0.530,0.868,0.684
microSAM base,—,—,0.662,0.414,0.713,0.472,—,—,0.493,0.005,0.518,0.098,—,—,0.863,0.771,0.901,0.806
microSAM large,—,—,0.623,0.372,0.769,0.634,—,—,0.528,0.118,0.563,0.217,—,—,0.780,0.735,0.844,0.823
CellSeg3D,—,—,—,—,—,—,—,—,0.509,0.709,0.508,0.741,—,—,—,—,—,—


## Format and save for paper

In [18]:
# style helper function

# return string-valued copy of dataframe where each column max is wrapped as \\textbf{...}
def latex_bold_column_max(df_numeric: pd.DataFrame, decimals: int = 3) -> pd.DataFrame:
    """Render a numeric frame as LaTeX-ready strings with column maxima in bold.

    Every value is formatted with ``decimals`` fixed decimals; NaN becomes
    ``--``. Within each column the maximum (and any value within 1e-12 of it)
    is wrapped in ``\\textbf{...}``. The input frame is not modified.
    """
    template = "{:." + str(decimals) + "f}"

    def _render(value):
        return "--" if pd.isna(value) else template.format(float(value))

    result = df_numeric.copy()
    for name in result.columns:
        column = result[name]
        if column.dropna().empty:
            # all-NaN column: nothing to bold, filled with '--' at the end
            continue
        top = float(column.max(skipna=True))
        is_top = column.notna() & np.isclose(column.astype(float), top, rtol=0, atol=1e-12)
        rendered = column.map(_render)
        # bold max (ties too)
        rendered[is_top] = rendered[is_top].map(lambda text: rf"\textbf{{{text}}}")
        result[name] = rendered

    # any remaining NaNs (all-NaN columns) -> '--'
    return result.fillna("--")

# wrap model names using makecell so the first column can be narrower
def _latex_wrap_model_name(name: str) -> str:
    mapping = {
        "Unet Image+CLIP": r"\makecell[l]{Unet\\Image+CLIP}",
        "SwinUNETR Image+CLIP": r"\makecell[l]{SwinUNETR\\Image+CLIP}",
        "SwinUNETR Image+CLIP (overtrain)": r"\makecell[l]{SwinUNETR\\Image+CLIP\\(overtrain)}",
        "Unet Image-only": r"\makecell[l]{Unet\\Image-only}",
        "SwinUNETR Image-only": r"\makecell[l]{SwinUNETR\\Image-only}",
        "Unet Random-init": r"\makecell[l]{Unet\\Random-init}",
        "SwinUNETR Random-init": r"\makecell[l]{SwinUNETR\\Random-init}",
        "microSAM base": r"\makecell[l]{microSAM\\base}",
        "microSAM large": r"\makecell[l]{microSAM\\large}",
        "CellSeg3D": r"CellSeg3D",
        "Cellpose 2D": r"\makecell[l]{Cellpose\\2D}",
        "Cellpose 3D": r"\makecell[l]{Cellpose\\3D}",
    }
    return mapping.get(name, name)


In [20]:
# output directory
OUTDIR = Path("/midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables")
OUTDIR.mkdir(parents=True, exist_ok=True)

# save as csv
# full-precision values; rounding below only affects the LaTeX table
df.to_csv(OUTDIR / "segmentation_results_wide.csv")

# convert for latex
# rounding to 3 decimals BEFORE picking maxima means values that tie after
# rounding are all bolded
latex_numeric = df.copy().round(3)
latex_str_df = latex_bold_column_max(latex_numeric)

# wrap model names
latex_str_df.index = [_latex_wrap_model_name(str(i)) for i in latex_str_df.index]

# escape=False keeps the \textbf / \makecell markup intact; no buffer is passed,
# so to_latex only returns the table as a string
latex_str = latex_str_df.to_latex(
    escape=False,
    multicolumn=True,
    multirow=True,
    caption="Segmentation performance (Tot Dice and Inst Dice) for zero-shot, few-shot (ntr5), and many-shot (ntr15). Values are averaged across 10 patches for zero-shot and across 3 CV folds for few/many-shot.",
    label="tab:segmentation_results",
    bold_rows=False,
    longtable=False,
    index=True,
)

# centre the datatype headers (each spans 6 columns: 3 shots x 2 metrics)
latex_str = latex_str.replace(r"\multicolumn{6}{r}{Amyloid Plaque}", r"\multicolumn{6}{c}{Amyloid Plaque}")
latex_str = latex_str.replace(r"\multicolumn{6}{r}{Cell Nucleus}", r"\multicolumn{6}{c}{Cell Nucleus}")
latex_str = latex_str.replace(r"\multicolumn{6}{r}{Vessels}", r"\multicolumn{6}{c}{Vessels}")

# write the modified LaTeX string to disk (this is the only write of the .tex file)
out_tex = OUTDIR / "segmentation_results.tex"
out_tex.write_text(latex_str)

print(f"[Saved] {OUTDIR/'segmentation_results_wide.csv'}")
print(f"[Saved] {OUTDIR/'segmentation_results.tex'}")


[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results_wide.csv
[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results.tex
