# Tables for paper

## Setup

In [1]:
# imports
from dataclasses import dataclass
from __future__ import annotations
from IPython.display import display
import json
import math
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import pandas as pd
from pathlib import Path
import re
from scipy.ndimage import label as cc_label
import tifffile as tiff
from typing import Dict, List, Optional, Tuple
import warnings


In [2]:
# define constants
PROB_THRESHOLD = 0.5
STRUCT_3D_26 = np.ones((3, 3, 3), dtype=bool)
PATCH_ID_RE = re.compile(r"(patch_\d{3})")
VOL_RE = re.compile(r"(vol\d{3})")
CH_RE = re.compile(r"(ch[01])")

# datatype mapping
DTYPES_GT = ["amyloid_plaque_patches", "cell_nucleus_patches", "vessels_patches"]
DTYPE_CANON = {
    "amyloid_plaque_patches": "Amyloid Plaque",
    "cell_nucleus_patches": "Cell Nucleus",
    "vessels_patches": "Vessels",
}

DTYPE_FEWMANY_ALIASES = {
    "amyloid_plaque_patches": ["amyloid_plaque", "amyloid_plaque_patches"],
    "cell_nucleus_patches": ["cell_nucleus", "cell_nucleus_patches"],
    "vessels_patches": ["vessels", "vessels_patches"],
}

# shot definitions
SHOTS = {
    "Zero-shot": {"mode": "zeroshot"},
    "Few-shot":  {"mode": "fewmany", "ntr": 5},
    "Many-shot": {"mode": "fewmany", "ntr": 15},
}

# folds
FOLDS = [0, 1, 2]



In [3]:
# data paths

# ground truth
GT_ROOT = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches")

# zeroshot
ZEROSHOT_ROOTS = {
    # Unet zeroshot preds: _zeroshot_unet, Swin: _zeroshot/results
    "unet_zeroshot_root": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_zeroshot_unet"),
    "swin_zeroshot_root": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_zeroshot"),
}

# few/many-shot
FEWMANY_ROOTS = {
    "Unet Image+CLIP": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_super_sweep2"),
    "Unet Image-only": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_bright_sweep_26"),
    "Unet Random-init": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_unet_random2"),

    "SwinUNETR Image+CLIP": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_autumn_sweep_27_v2"),
    "SwinUNETR Image+CLIP (overtrain)": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_autumn_sweep_27_long"),
    "SwinUNETR Image-only": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_expert_sweep_31_v2"),
    "SwinUNETR Random-init": Path("/midtier/paetzollab/scratch/ads4015/temp_selma_segmentation_preds_rand_v2"),

    "microSAM base": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_cross_val_b2"),
    "microSAM large": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_cross_val_l"),

    "CellSeg3D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuned_cross_val"),

    "Cellpose 2D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellpose/cross_val"),
    "Cellpose 3D": Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellpose/cross_val"),
}


## Functions

In [4]:
# data locating

def _first_existing(paths: List[Path]) -> Optional[Path]:
    for p in paths:
        if p.exists():
            return p
    return None

# find gt label for given patch (ex: .../{dtype_gt}/{patch_id}_{vol}_ch0_label.nii.gz)
def find_gt_label_for_patch(dtype_gt: str, patch_id: str, vol: str, ch: Optional[str]) -> Optional[Path]:
    gt_dir = GT_ROOT / dtype_gt
    candidates: List[Path] = []
    if ch is not None:
        candidates.append(gt_dir / f"{patch_id}_{vol}_{ch}_label.nii.gz")
    else:
        candidates.extend([
            gt_dir / f"{patch_id}_{vol}_ch0_label.nii.gz",
            gt_dir / f"{patch_id}_{vol}_ch1_label.nii.gz",
        ])

    for p in candidates:
        if p.exists():
            return p
    return None

# parse patch, vol, ch from filename
def parse_patch_tokens(p: Path) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    m_patch = PATCH_ID_RE.search(p.name)
    m_vol = VOL_RE.search(p.name)
    m_ch = CH_RE.search(p.name)
    return (
        m_patch.group(1) if m_patch else None,
        m_vol.group(1) if m_vol else None,
        m_ch.group(1) if m_ch else None,
    )

def pick_run_dir(parent: Path, fold: int, ntr: int) -> Optional[Path]:
    if not parent.exists():
        return None
    pat = re.compile(rf"cvfold{fold}_ntr{ntr}(?:_|$)")
    candidates = [d for d in parent.iterdir() if d.is_dir() and pat.search(d.name)]
    if not candidates:
        return None
    candidates.sort(key=lambda d: d.stat().st_mtime, reverse=True)
    return candidates[0]

def list_pred_files_for_run(run_dir: Path) -> List[Path]:
    """
    Handles different layouts:
    - .../run_dir/preds/*.nii.gz
    - .../run_dir/preds/preds/*.nii.gz
    - .../run_dir/patches/*.nii.gz (microSAM)
    - CellSeg3D: *_instances.tif in run_dir
    """
    # most common layouts
    candidates = [
        run_dir / "preds",
        run_dir / "preds" / "preds",
        run_dir / "patches",
    ]
    for c in candidates:
        if c.exists() and c.is_dir():
            files = sorted([p for p in c.rglob("*") if p.is_file() and (p.suffix in [".gz", ".tif"] or p.name.endswith(".nii.gz"))])
            if files:
                return files

    # fallback: run_dir itself
    files = sorted([p for p in run_dir.rglob("*") if p.is_file() and (p.suffix in [".gz", ".tif"] or p.name.endswith(".nii.gz"))])
    return files



In [5]:
# data loading functions

def load_nifti(path: Path) -> np.ndarray:
    img = nib.load(str(path))
    return img.get_fdata(dtype=np.float32)

def _squeeze_singleton_channel(arr: np.ndarray) -> np.ndarray:
    return arr[0] if (arr.ndim == 4 and arr.shape[0] == 1) else arr

def _coerce_pred_gt_shapes(pred: np.ndarray, gt: np.ndarray):
    return _squeeze_singleton_channel(pred), _squeeze_singleton_channel(gt)


In [6]:
# metrics functions

def safe_mean(xs: List[float]) -> float:
    xs = [x for x in xs if x is not None and not (isinstance(x, float) and (math.isnan(x) or math.isinf(x)))]
    return float(np.mean(xs)) if len(xs) else float("nan")

def binary_dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def background_dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    pred_bg = np.logical_not(pred.astype(bool))
    gt_bg = np.logical_not(gt.astype(bool))
    inter = np.logical_and(pred_bg, gt_bg).sum()
    return (2.0 * inter + eps) / (pred_bg.sum() + gt_bg.sum() + eps)

def total_dice(pred_bin: np.ndarray, gt_bin: np.ndarray, eps: float = 1e-8) -> float:
    return 0.5 * (binary_dice(pred_bin, gt_bin, eps) + background_dice(pred_bin, gt_bin, eps))

def foreground_dice(pred_bin: np.ndarray, gt_bin: np.ndarray, eps: float = 1e-8) -> float:
    pred_bin = pred_bin.astype(bool)
    gt_bin = gt_bin.astype(bool)
    fg_union = np.logical_or(pred_bin, gt_bin)
    if fg_union.sum() == 0:
        return 1.0
    pred_f = pred_bin[fg_union]
    gt_f = gt_bin[fg_union]
    inter = np.logical_and(pred_f, gt_f).sum()
    return (2.0 * inter + eps) / (pred_f.sum() + gt_f.sum() + eps)

def count_instances(mask: np.ndarray, structure=STRUCT_3D_26) -> int:
    mask = mask.astype(bool)
    if mask.sum() == 0:
        return 0
    _, n = cc_label(mask, structure=structure)
    return int(n)

def compute_metrics_for_pair(pred_path: Path, gt_path: Path) -> dict:
    if pred_path.suffix == ".tif":
        pred_arr = tiff.imread(str(pred_path)).astype(np.float32)
    else:
        pred_arr = load_nifti(pred_path)

    gt_arr = load_nifti(gt_path)
    pred_arr, gt_arr = _coerce_pred_gt_shapes(pred_arr, gt_arr)

    pred_bin = pred_arr >= PROB_THRESHOLD
    gt_bin = gt_arr > 0.5

    if pred_bin.shape != gt_bin.shape:
        raise ValueError(f"Shape mismatch: pred {pred_bin.shape}, gt {gt_bin.shape}")

    td = total_dice(pred_bin, gt_bin)
    fd = foreground_dice(pred_bin, gt_bin)
    n_pred = count_instances(pred_bin)
    n_gt = count_instances(gt_bin)

    inst_dice = 1.0 if (n_pred == 0 and n_gt == 0) else (2.0 * min(n_pred, n_gt)) / float(n_pred + n_gt)

    return {
        "total_dice": float(td),
        "foreground_dice": float(fd),
        "instance_dice": float(inst_dice),
        "instances_pred": int(n_pred),
        "instances_gt": int(n_gt),
    }

# wrapper to return only total_dice and instance_dice using compute_metrics_for_pair() function
def compute_pair_metrics(pred_path: Path, gt_label_path: Path) -> Dict[str, float]:
    d = compute_metrics_for_pair(pred_path, gt_label_path)
    return {
        "total_dice": float(d["total_dice"]),
        "instance_dice": float(d["instance_dice"]),
    }


In [7]:
# zero-shot data collecting
def zeroshot_collect_unet(dtype_gt: str, variant_suffix: str) -> List[Tuple[Path, Path]]:
    """
    Unet zeroshot layout described:
    .../temp_selma_segmentation_preds_zeroshot_unet/results/{dtype_gt}/patch_000_..._pred_{variant}.nii.gz
    variant_suffix in {"image_clip","image_only","random"}
    """
    root = ZEROSHOT_ROOTS["unet_zeroshot_root"] / "results" / dtype_gt
    pairs = []
    for i in range(10):
        patch_id = f"patch_{i:03d}"
        pred_glob = list(root.glob(f"{patch_id}_*pred_{variant_suffix}.nii.gz"))
        if not pred_glob:
            continue
        pred_path = sorted(pred_glob)[0]
        patch_id2, vol, ch = parse_patch_tokens(pred_path)
        if patch_id2 is None or vol is None or ch is None:
            continue
        gt = find_gt_label_for_patch(dtype_gt, patch_id2, vol, ch)
        if gt is None:
            continue
        pairs.append((pred_path, gt))
    return pairs

def zeroshot_collect_swin(dtype_gt: str, variant_suffix: str) -> List[Tuple[Path, Path]]:
    """
    Swin zeroshot layout described:
    .../temp_selma_segmentation_preds_zeroshot/results/{dtype_gt}/patch_000_..._pred_{variant}.nii.gz
    """
    root = ZEROSHOT_ROOTS["swin_zeroshot_root"] / "results" / dtype_gt
    pairs = []
    for i in range(10):
        patch_id = f"patch_{i:03d}"
        pred_glob = list(root.glob(f"{patch_id}_*pred_{variant_suffix}.nii.gz"))
        if not pred_glob:
            continue
        pred_path = sorted(pred_glob)[0]
        patch_id2, vol, ch = parse_patch_tokens(pred_path)
        if patch_id2 is None or vol is None or ch is None:
            continue
        gt = find_gt_label_for_patch(dtype_gt, patch_id2, vol, ch)
        if gt is None:
            continue
        pairs.append((pred_path, gt))
    return pairs


In [8]:
# few/many-shot data collecting

def fewmany_collect_pairs(model_root: Path, dtype_gt: str, fold: int, ntr: int, *, mode: str) -> List[Tuple[Path, Path]]:
    """
    Collect (pred, gt_label) pairs for a given model, datatype, fold, and ntr.

    mode:
      - "standard": roots like .../<exp_root>/preds/{dtype}/<run_dir>/(preds|patches)/*.nii.gz
      - "microsam": similar but uses 'patches' folder
      - "cellseg": tif instances
      - "cellpose2d"/"cellpose3d": select files containing pred2d or pred3d
    """

    if mode in {"cellpose2d", "cellpose3d"}:
        pred_kind = "2d" if mode == "cellpose2d" else "3d"
        return collect_cellpose_pairs(model_root, fold=fold, ntr=ntr, pred_kind=pred_kind)

    # find datatype folder under model_root/preds
    preds_root = model_root / "preds"

    dtype_candidates = [preds_root / name for name in DTYPE_FEWMANY_ALIASES[dtype_gt]]
    dtype_dir = _first_existing(dtype_candidates)
    if dtype_dir is None:
        return []

    run_dir = pick_run_dir(dtype_dir, fold=fold, ntr=ntr)
    if run_dir is None:
        return []

    files = list_pred_files_for_run(run_dir)

    # filter per-mode
    if mode == "cellseg":
        files = [p for p in files if p.suffix == ".tif" and p.name.endswith("_instances.tif")]
    elif mode == "cellpose2d":
        files = [p for p in files if p.name.endswith(".nii.gz") and ("pred2d_" in p.name)]
    elif mode == "cellpose3d":
        files = [p for p in files if p.name.endswith(".nii.gz") and ("pred3d_" in p.name)]
    else:
        # standard/microsam: keep nii.gz preds (and ignore any labels etc)
        files = [p for p in files if p.name.endswith(".nii.gz") and ("_label" not in p.name)]

    pairs: List[Tuple[Path, Path]] = []
    for pred_path in files:
        patch_id, vol, ch = parse_patch_tokens(pred_path)
        if patch_id is None or vol is None or ch is None:
            continue
        gt = find_gt_label_for_patch(dtype_gt, patch_id, vol, ch)
        if gt is None:
            continue
        pairs.append((pred_path, gt))

    return pairs

# few/many-shot evaluation
def eval_fold(model_root: Path, dtype_gt: str, fold: int, ntr: int, mode: str) -> Dict[str, float]:
    pairs = fewmany_collect_pairs(model_root, dtype_gt, fold, ntr, mode=mode)
    if not pairs:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    tds, ids = [], []
    for pred, gt in pairs:
        m = compute_pair_metrics(pred, gt)
        tds.append(m["total_dice"])
        ids.append(m["instance_dice"])

    return {"total_dice": safe_mean(tds), "instance_dice": safe_mean(ids)}


In [9]:
# collect cellpose pairs
def collect_cellpose_pairs(
    model_root: Path,
    fold: int,
    ntr: int,
    pred_kind: str,  # "2d" or "3d"
) -> List[Tuple[Path, Path]]:
    """
    Cellpose layout:
      .../preds/cell_nucleus_patches/<run_dir>/patch_007_vol006_pred2d_...nii.gz
      .../preds/cell_nucleus_patches/<run_dir>/patch_007_vol006_pred3d_...nii.gz

    GT labels:
      .../selma3d_finetune_patches/cell_nucleus_patches/patch_007_vol006_ch0_label.nii.gz (or ch1)
    """
    assert pred_kind in {"2d", "3d"}
    dtype_gt = "cell_nucleus_patches"
    dtype_dir = model_root / "preds" / dtype_gt
    if not dtype_dir.exists():
        return []

    # Find run dir (most recent) for this fold+ntr
    pat = re.compile(rf"cvfold{fold}_ntr{ntr}(?:_|$)")
    run_dirs = [d for d in dtype_dir.iterdir() if d.is_dir() and pat.search(d.name)]
    if not run_dirs:
        return []
    run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
    run_dir = run_dirs[0]

    # Gather preds directly under run_dir (your example has them there)
    all_preds = sorted(run_dir.rglob("*.nii.gz"))
    want_tag = f"_pred{pred_kind}_"
    preds = [p for p in all_preds if want_tag in p.name]

    pairs: List[Tuple[Path, Path]] = []
    gt_dir = GT_ROOT / dtype_gt

    for pred_path in preds:
        m_patch = PATCH_ID_RE.search(pred_path.name)
        m_vol = VOL_RE.search(pred_path.name)
        if not m_patch or not m_vol:
            continue
        patch_id = m_patch.group(1)
        vol = m_vol.group(1)

        # Cellpose pred doesn't encode channel → try ch0 then ch1
        gt0 = gt_dir / f"{patch_id}_{vol}_ch0_label.nii.gz"
        gt1 = gt_dir / f"{patch_id}_{vol}_ch1_label.nii.gz"

        if gt0.exists():
            pairs.append((pred_path, gt0))
        elif gt1.exists():
            pairs.append((pred_path, gt1))
        # else: no GT label found -> skip

    return pairs


In [10]:
# model registry (what to compute for each model)
@dataclass
class ModelSpec:
    name: str
    # zeroshot source: "unet" or "swin" or None
    zeroshot_source: Optional[str] = None
    zeroshot_variant: Optional[str] = None  # "image_clip"|"image_only"|"random"
    fewmany_root: Optional[Path] = None
    fewmany_mode: Optional[str] = None      # "standard"|"microsam"|"cellseg"|"cellpose2d"|"cellpose3d"

MODEL_SPECS: List[ModelSpec] = [
    # --- UNet zeroshot variants (all from same unet zeroshot root) ---
    ModelSpec("Unet Image+CLIP", zeroshot_source="unet", zeroshot_variant="image_clip", fewmany_root=FEWMANY_ROOTS["Unet Image+CLIP"], fewmany_mode="standard"),
    ModelSpec("Unet Image-only", zeroshot_source="unet", zeroshot_variant="image_only", fewmany_root=FEWMANY_ROOTS["Unet Image-only"], fewmany_mode="standard"),
    ModelSpec("Unet Random-init", zeroshot_source="unet", zeroshot_variant="random", fewmany_root=FEWMANY_ROOTS["Unet Random-init"], fewmany_mode="standard"),

    # --- Swin zeroshot variants ---
    ModelSpec("SwinUNETR Image+CLIP", zeroshot_source="swin", zeroshot_variant="image_clip", fewmany_root=FEWMANY_ROOTS["SwinUNETR Image+CLIP"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Image+CLIP (overtrain)", zeroshot_source=None, zeroshot_variant=None, fewmany_root=FEWMANY_ROOTS["SwinUNETR Image+CLIP (overtrain)"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Image-only", zeroshot_source="swin", zeroshot_variant="image_only", fewmany_root=FEWMANY_ROOTS["SwinUNETR Image-only"], fewmany_mode="standard"),
    ModelSpec("SwinUNETR Random-init", zeroshot_source="swin", zeroshot_variant="random", fewmany_root=FEWMANY_ROOTS["SwinUNETR Random-init"], fewmany_mode="standard"),

    # --- Other methods (no zeroshot) ---
    ModelSpec("microSAM base", fewmany_root=FEWMANY_ROOTS["microSAM base"], fewmany_mode="microsam"),
    ModelSpec("microSAM large", fewmany_root=FEWMANY_ROOTS["microSAM large"], fewmany_mode="microsam"),

    # CellSeg3D only exists for nucleus (tif)
    ModelSpec("CellSeg3D", fewmany_root=FEWMANY_ROOTS["CellSeg3D"], fewmany_mode="cellseg"),

    # Cellpose only nucleus, two separate rows
    ModelSpec("Cellpose 2D", fewmany_root=FEWMANY_ROOTS["Cellpose 2D"], fewmany_mode="cellpose2d"),
    ModelSpec("Cellpose 3D", fewmany_root=FEWMANY_ROOTS["Cellpose 3D"], fewmany_mode="cellpose3d"),
]

MODEL_ORDER = [
    "Unet Image+CLIP",
    "SwinUNETR Image+CLIP",
    "SwinUNETR Image+CLIP (overtrain)",
    "Unet Image-only",
    "SwinUNETR Image-only",
    "Unet Random-init",
    "SwinUNETR Random-init",
    "microSAM base",
    "microSAM large",
    "CellSeg3D",
    "Cellpose 2D",
    "Cellpose 3D",
]


In [11]:
# evaluation orchestration

def eval_zeroshot(model: ModelSpec, dtype_gt: str) -> Dict[str, float]:
    if model.zeroshot_source is None or model.zeroshot_variant is None:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    if model.zeroshot_source == "unet":
        pairs = zeroshot_collect_unet(dtype_gt, model.zeroshot_variant)
    elif model.zeroshot_source == "swin":
        pairs = zeroshot_collect_swin(dtype_gt, model.zeroshot_variant)
    else:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    if not pairs:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    tds, ids = [], []
    for pred, gt in pairs:
        m = compute_pair_metrics(pred, gt)
        tds.append(m["total_dice"])
        ids.append(m["instance_dice"])
    return {"total_dice": safe_mean(tds), "instance_dice": safe_mean(ids)}

def model_supports_dtype(model: ModelSpec, dtype_gt: str) -> bool:
    # CellSeg3D and Cellpose are nucleus-only
    if model.name in ["CellSeg3D", "Cellpose 2D", "Cellpose 3D"]:
        return dtype_gt == "cell_nucleus_patches"
    return True

def eval_fewmany(model: ModelSpec, dtype_gt: str, ntr: int) -> Dict[str, float]:
    if model.fewmany_root is None or model.fewmany_mode is None:
        return {"total_dice": float("nan"), "instance_dice": float("nan")}
    if not model_supports_dtype(model, dtype_gt):
        return {"total_dice": float("nan"), "instance_dice": float("nan")}

    fold_metrics = []
    for fold in FOLDS:
        fm = eval_fold(model.fewmany_root, dtype_gt, fold=fold, ntr=ntr, mode=model.fewmany_mode)
        fold_metrics.append(fm)

    # average across folds
    return {
        "total_dice": safe_mean([m["total_dice"] for m in fold_metrics]),
        "instance_dice": safe_mean([m["instance_dice"] for m in fold_metrics]),
    }


In [12]:
# build results table
def build_results_table() -> pd.DataFrame:
    """
    Returns a wide DataFrame with MultiIndex columns:
      (Datatype, Shot, Metric)
    """
    records = []
    for model in MODEL_SPECS:
        row = {"Model": model.name}
        for dtype_gt in DTYPES_GT:
            dtype_name = DTYPE_CANON[dtype_gt]

            # Zero-shot
            z = eval_zeroshot(model, dtype_gt) if model_supports_dtype(model, dtype_gt) else {"total_dice": np.nan, "instance_dice": np.nan}
            row[(dtype_name, "Zero-shot", "Tot Dice")] = z["total_dice"]
            row[(dtype_name, "Zero-shot", "Inst Dice")] = z["instance_dice"]

            # Few-shot / Many-shot
            f = eval_fewmany(model, dtype_gt, ntr=5)
            m = eval_fewmany(model, dtype_gt, ntr=15)

            row[(dtype_name, "Few-shot", "Tot Dice")] = f["total_dice"]
            row[(dtype_name, "Few-shot", "Inst Dice")] = f["instance_dice"]
            row[(dtype_name, "Many-shot", "Tot Dice")] = m["total_dice"]
            row[(dtype_name, "Many-shot", "Inst Dice")] = m["instance_dice"]

        records.append(row)

    df = pd.DataFrame.from_records(records)

    # put Model as index and make multiindex columns
    df = df.set_index("Model")
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df


In [13]:
# style and display final table
def style_table(df: pd.DataFrame) -> "pd.io.formats.style.Styler":
    sty = df.copy()

    def _highlight_max(s):
        # ignore NaNs
        if s.dropna().empty:
            return [""] * len(s)
        m = s.max(skipna=True)
        return ["font-weight:700;" if (pd.notna(v) and v == m) else "" for v in s]

    return (sty.style
            .format(precision=3, na_rep="—")
            .apply(_highlight_max, axis=0))

## Build table

In [14]:
# build results table
df = build_results_table()

# sort
df = df.reindex(MODEL_ORDER)

# display
display(df.round(4))

# display styled table
display(style_table(df))


Unnamed: 0_level_0,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Vessels,Vessels,Vessels,Vessels,Vessels,Vessels
Unnamed: 0_level_1,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot
Unnamed: 0_level_2,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice
Model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
Unet Image+CLIP,0.4083,0.018,0.6765,0.5799,0.8043,0.693,0.4598,0.0574,0.7838,0.9106,0.8097,0.936,0.564,0.088,0.8603,0.7635,0.921,0.8135
SwinUNETR Image+CLIP,0.4721,0.1166,0.4852,0.0069,0.7303,0.6836,0.7196,0.7324,0.7895,0.9493,0.8112,0.9618,0.6816,0.478,0.8579,0.7253,0.8971,0.8292
SwinUNETR Image+CLIP (overtrain),,,0.6078,0.6055,0.6761,0.6805,,,0.7841,0.9358,0.8036,0.9767,,,0.8619,0.81,0.8858,0.9014
Unet Image-only,0.4323,0.1516,0.5006,0.03,0.4998,0.0,0.5547,0.2686,0.8001,0.8688,0.8211,0.9649,0.5331,0.3203,0.6058,0.5605,0.887,0.8426
SwinUNETR Image-only,0.4771,0.1098,0.5555,0.2354,0.6856,0.5681,0.7252,0.7673,0.7854,0.9299,0.8039,0.9574,0.6827,0.4733,0.8331,0.7382,0.8601,0.7362
Unet Random-init,0.4692,0.0002,0.5,0.1136,0.7216,0.5844,0.4828,0.0033,0.751,0.8596,0.8011,0.8833,0.5359,0.0039,0.8334,0.8578,0.8925,0.685
SwinUNETR Random-init,0.4808,0.1901,0.4993,0.0092,0.6041,0.5161,0.7337,0.7647,0.7836,0.9015,0.8055,0.9573,0.6854,0.4559,0.7995,0.5295,0.8683,0.6845
microSAM base,,,0.6618,0.4143,0.7127,0.4725,,,0.4934,0.0046,0.5181,0.0978,,,0.8633,0.7708,0.901,0.8058
microSAM large,,,0.623,0.372,0.7695,0.6343,,,0.5283,0.1175,0.5634,0.2173,,,0.7804,0.7352,0.844,0.8229
CellSeg3D,,,,,,,,,0.5092,0.7089,0.5082,0.7412,,,,,,


Unnamed: 0_level_0,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Amyloid Plaque,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Cell Nucleus,Vessels,Vessels,Vessels,Vessels,Vessels,Vessels
Unnamed: 0_level_1,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot,Zero-shot,Zero-shot,Few-shot,Few-shot,Many-shot,Many-shot
Unnamed: 0_level_2,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice,Tot Dice,Inst Dice
Model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
Unet Image+CLIP,0.408,0.018,0.676,0.580,0.804,0.693,0.460,0.057,0.784,0.911,0.81,0.936,0.564,0.088,0.860,0.764,0.921,0.813
SwinUNETR Image+CLIP,0.472,0.117,0.485,0.007,0.730,0.684,0.720,0.732,0.79,0.949,0.811,0.962,0.682,0.478,0.858,0.725,0.897,0.829
SwinUNETR Image+CLIP (overtrain),—,—,0.608,0.605,0.676,0.681,—,—,0.784,0.936,0.804,0.977,—,—,0.862,0.810,0.886,0.901
Unet Image-only,0.432,0.152,0.501,0.030,0.500,0.000,0.555,0.269,0.8,0.869,0.821,0.965,0.533,0.320,0.606,0.560,0.887,0.843
SwinUNETR Image-only,0.477,0.110,0.556,0.235,0.686,0.568,0.725,0.767,0.785,0.93,0.804,0.957,0.683,0.473,0.833,0.738,0.860,0.736
Unet Random-init,0.469,0.000,0.500,0.114,0.722,0.584,0.483,0.003,0.751,0.86,0.801,0.883,0.536,0.004,0.833,0.858,0.893,0.685
SwinUNETR Random-init,0.481,0.190,0.499,0.009,0.604,0.516,0.734,0.765,0.784,0.901,0.806,0.957,0.685,0.456,0.799,0.530,0.868,0.684
microSAM base,—,—,0.662,0.414,0.713,0.472,—,—,0.493,0.005,0.518,0.098,—,—,0.863,0.771,0.901,0.806
microSAM large,—,—,0.623,0.372,0.769,0.634,—,—,0.528,0.118,0.563,0.217,—,—,0.780,0.735,0.844,0.823
CellSeg3D,—,—,—,—,—,—,—,—,0.509,0.709,0.508,0.741,—,—,—,—,—,—


## Format and save for paper

In [26]:
# style helper function

# return string-valued copy of dataframe where each value max (Tot/Inst) is wrapped as \\textbf{...} within each column
def latex_format_bold_pairs(df_pairs: pd.DataFrame, decimals: int = 3) -> pd.DataFrame:

    fmt = f"{{:.{decimals}f}}"
    out = pd.DataFrame(index=df_pairs.index, columns=df_pairs.columns, dtype=object)

    for col in df_pairs.columns:
        s = df_pairs[col]

        # Extract numeric arrays for max computation
        tot_vals = []
        inst_vals = []
        for v in s:
            if isinstance(v, tuple) and len(v) == 2 and v[0] is not None and v[1] is not None:
                tot_vals.append(v[0])
                inst_vals.append(v[1])
            else:
                tot_vals.append(np.nan)
                inst_vals.append(np.nan)

        tot_arr = np.array(tot_vals, dtype=float)
        inst_arr = np.array(inst_vals, dtype=float)

        # Determine maxima (ignore NaN)
        tot_max = np.nanmax(tot_arr) if np.isfinite(tot_arr).any() else np.nan
        inst_max = np.nanmax(inst_arr) if np.isfinite(inst_arr).any() else np.nan

        # Build formatted strings with bolding
        col_out = []
        for (tot, inst), t, u in zip(s, tot_arr, inst_arr):
            if not np.isfinite(t) or not np.isfinite(u):
                col_out.append("--")
                continue

            tot_str = fmt.format(float(t))
            inst_str = fmt.format(float(u))

            if np.isfinite(tot_max) and np.isclose(t, tot_max, rtol=0, atol=1e-12):
                tot_str = rf"\textbf{{{tot_str}}}"
            if np.isfinite(inst_max) and np.isclose(u, inst_max, rtol=0, atol=1e-12):
                inst_str = rf"\textbf{{{inst_str}}}"

            col_out.append(f"{tot_str}/{inst_str}")

        out[col] = col_out

    return out

# wrap model names using makecell so the first column can be narrower
def _latex_wrap_model_name(name: str) -> str:
    mapping = {
        "Unet Image+CLIP": r"\makecell[l]{Unet I+T}",
        "SwinUNETR Image+CLIP": r"\makecell[l]{Swin I+T}",
        "SwinUNETR Image+CLIP (overtrain)": r"\makecell[l]{Swin I+T\\(over)}",
        "Unet Image-only": r"\makecell[l]{Unet I}",
        "SwinUNETR Image-only": r"\makecell[l]{Swin I}",
        "Unet Random-init": r"\makecell[l]{Unet R}",
        "SwinUNETR Random-init": r"\makecell[l]{Swin R}",
        "microSAM base": r"\makecell[l]{uSAM (b)}",
        "microSAM large": r"\makecell[l]{uSAM (l)}",
        "CellSeg3D": r"CellSeg3D",
        "Cellpose 2D": r"\makecell[l]{Cellpose2D}",
        "Cellpose 3D": r"\makecell[l]{Cellpose3D}",
    }
    return mapping.get(name, name)

# inject LaTeX formatting into table environment (immediately after \begin{table} or \begin{sidewaystable})
def inject_table_formatting(
    latex_str: str,
    add_centering: bool = True, # whether to add \centering
    fontsize_cmd: str = r"\fontsize{8}{9}\selectfont", # font size command to inject
    tabcolsep_pt: int = 2, # tabcolsep in points
    arraystretch: float = 1.05, # arraystretch value
) -> str:

    lines = latex_str.splitlines()
    out = []
    injected = False

    for line in lines:
        out.append(line)
        s = line.strip()
        if (not injected) and (s.startswith(r"\begin{table}") or s.startswith(r"\begin{sidewaystable}")):
            if add_centering:
                out.append(r"\centering")
            out.append(fontsize_cmd)
            out.append(rf"\setlength{{\tabcolsep}}{{{tabcolsep_pt}pt}}")
            out.append(rf"\renewcommand{{\arraystretch}}{{{arraystretch}}}")
            injected = True

    if not injected:
        warnings.warn("Could not inject formatting: no table environment found.")
    return "\n".join(out)

# function to collapse (Tot Dice, Inst Dice) into one column per shot
def collapse_metrics(df_in: pd.DataFrame) -> pd.DataFrame:

    # build new columns in a stable order
    dtypes = df_in.columns.get_level_values(0).unique().tolist()
    shots  = df_in.columns.get_level_values(1).unique().tolist()

    out = pd.DataFrame(index=df_in.index)
    for dt in dtypes:
        for sh in shots:
            tot = df_in[(dt, sh, "Tot Dice")]
            inst = df_in[(dt, sh, "Inst Dice")]
            out[(dt, sh, "T/I")] = list(zip(tot, inst))

    out.columns = pd.MultiIndex.from_tuples(out.columns)
    return out

# abbreviate zero-shot/few-shot/many-shot in LaTeX table headers to ZS/FS/MS
def abbreviate_shot_headers(latex_str: str) -> str:
    latex_str = re.sub(r"\bZero-shot\b", "ZS", latex_str)
    latex_str = re.sub(r"\bFew-shot\b",  "FS", latex_str)
    latex_str = re.sub(r"\bMany-shot\b", "MS", latex_str)
    return latex_str

# merge shot and metric header rows into one ('ZS T/I', 'FS T/I', 'MS T/I')
def merge_shot_and_metric_header_rows(latex_str: str) -> str:

    lines = latex_str.splitlines()

    # Find consecutive header rows:
    #  - shot row: contains ZS/FS/MS AND '&'
    #  - metric row: next line contains 'T/I' AND '&'
    shot_i = None
    metric_i = None

    for i in range(len(lines) - 1):
        if ("ZS" in lines[i] or "FS" in lines[i] or "MS" in lines[i]) and "&" in lines[i]:
            if "T/I" in lines[i + 1] and "&" in lines[i + 1]:
                shot_i = i
                metric_i = i + 1
                break

    if shot_i is None or metric_i is None:
        warnings.warn("Could not find shot+metric header rows to merge.")
        return latex_str

    shot_line = lines[shot_i].rstrip()
    metric_line = lines[metric_i].rstrip()

    # Strip trailing '\\' from both lines (we will add it back once, at the end)
    shot_has_end = shot_line.endswith(r"\\")
    metric_has_end = metric_line.endswith(r"\\")
    if shot_has_end:
        shot_line = shot_line[:-2].rstrip()
    if metric_has_end:
        metric_line = metric_line[:-2].rstrip()

    shot_cells = [c.strip() for c in shot_line.split("&")]
    metric_cells = [c.strip() for c in metric_line.split("&")]

    # Merge: keep first cell (row label header), append " T/I" to shot cells where metric cell is T/I
    merged_cells = []
    n = min(len(shot_cells), len(metric_cells))
    for j in range(n):
        a = shot_cells[j]
        b = metric_cells[j]
        if j == 0:
            merged_cells.append(a)
            continue

        if "T/I" in b:
            # If a is \multicolumn{1}{r}{ZS}, inject inside braces if present
            m = re.match(r"(\\multicolumn\{1\}\{[lrc]\}\{)(.*)(\})$", a)
            if m:
                merged_cells.append(f"{m.group(1)}{m.group(2)} T/I{m.group(3)}")
            else:
                merged_cells.append(f"{a} T/I")
        else:
            merged_cells.append(a)

    # If shot row had extra columns beyond metric row, keep them (rare)
    if len(shot_cells) > n:
        merged_cells.extend([c.strip() for c in shot_cells[n:]])

    # Rebuild merged line and ALWAYS end with '\\'
    lines[shot_i] = " & ".join(merged_cells) + r" \\"

    # Delete metric row
    del lines[metric_i]

    return "\n".join(lines)



In [27]:
# output directory
OUTDIR = Path("/midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables")
OUTDIR.mkdir(parents=True, exist_ok=True)

# save as csv
df.to_csv(OUTDIR / "segmentation_results_wide.csv")

# convert for latex
DECIMALS = 2
latex_numeric = df.copy().round(DECIMALS)

# collapse Tot and Inst into one column per shot
df_pairs = collapse_metrics(latex_numeric)

# format with bolding
latex_str_df = latex_format_bold_pairs(df_pairs, decimals=DECIMALS)

# wrap model names
latex_str_df.index = [_latex_wrap_model_name(str(i)) for i in latex_str_df.index]

latex_str = latex_str_df.to_latex(
    escape=False,
    multicolumn=True,
    multirow=True,
    caption="Segmentation performance (Tot Dice and Inst Dice) for zero-shot, few-shot (ntr5), and many-shot (ntr15). Values are averaged across 10 patches for zero-shot and across 3 CV folds for few/many-shot.",
    label="tab:segmentation_results",
    bold_rows=False,
    longtable=False,
    index=True,
)

# abbreviate shot headers
latex_str = abbreviate_shot_headers(latex_str)

# merge ZS/FS/MS and T/I header rows
latex_str = merge_shot_and_metric_header_rows(latex_str)

# modify tabular to use @{} ... @{} to reduce side padding
latex_str = re.sub(
    r"\\begin\{tabular\}\{[^}]*\}",
    r"\\begin{tabular}{@{}lccccccccc@{}}",
    latex_str,
    count=1
)

latex_str = latex_str.replace(r"\multicolumn{3}{r}{Amyloid Plaque}", r"\multicolumn{3}{c}{Amyloid Plaque}")
latex_str = latex_str.replace(r"\multicolumn{3}{r}{Cell Nucleus}", r"\multicolumn{3}{c}{Cell Nucleus}")
latex_str = latex_str.replace(r"\multicolumn{3}{r}{Vessels}", r"\multicolumn{3}{c}{Vessels}")

# inject 8pt font size inside table
latex_str = inject_table_formatting(latex_str, add_centering=True)

# overwrite file with modified string (since to_latex already wrote it)
out_tex = OUTDIR / "segmentation_results.tex"
out_tex.write_text(latex_str)

print(f"[Saved] {OUTDIR/'segmentation_results_wide.csv'}")
print(f"[Saved] {OUTDIR/'segmentation_results.tex'}")


[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results_wide.csv
[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results.tex


In [29]:
import re
import numpy as np
import pandas as pd
from pathlib import Path

# ---------- config ----------
DECIMALS = 2
OUTDIR = Path("/midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables")
OUTDIR.mkdir(parents=True, exist_ok=True)

# column groups (datatypes)
DT_ORDER = ["Amyloid Plaque", "Cell Nucleus", "Vessels"]

# shot order + abbreviations
SHOT_ORDER = [("Zero-shot", "ZS"), ("Few-shot", "FS"), ("Many-shot", "MS")]

# metrics (keep separate; no stacking)
METRIC_ORDER = [("Tot Dice", "Tot"), ("Inst Dice", "Inst")]

# ---------- helper: wrap/brev model names (use your existing mapping) ----------
def _latex_wrap_model_name(name: str) -> str:
    mapping = {
        "Unet Image+CLIP": r"\makecell[l]{Unet I+T}",
        "SwinUNETR Image+CLIP": r"\makecell[l]{Swin I+T}",
        "SwinUNETR Image+CLIP (overtrain)": r"\makecell[l]{Swin I+T\\(over)}",
        "Unet Image-only": r"\makecell[l]{Unet I}",
        "SwinUNETR Image-only": r"\makecell[l]{Swin I}",
        "Unet Random-init": r"\makecell[l]{Unet R}",
        "SwinUNETR Random-init": r"\makecell[l]{Swin R}",
        "microSAM base": r"\makecell[l]{uSAM (b)}",
        "microSAM large": r"\makecell[l]{uSAM (l)}",
        "CellSeg3D": r"CellSeg3D",
        "Cellpose 2D": r"\makecell[l]{Cellpose2D}",
        "Cellpose 3D": r"\makecell[l]{Cellpose3D}",
    }
    return mapping.get(name, name)

# ---------- helper: bold best per column WITHIN each shot block ----------
def bold_best_within_shot_blocks(df_num: pd.DataFrame, shot_sizes: dict, decimals: int = 2) -> pd.DataFrame:
    """
    df_num: index is MultiIndex (Shot, Model), columns are MultiIndex (Datatype, Metric)
    shot_sizes: dict shot_abbrev -> number of rows in that block
    Returns a string-valued dataframe with \textbf{} on maxima within each shot block per column.
    """
    fmt = f"{{:.{decimals}f}}"
    out = df_num.copy().astype(object)

    # format all numeric to strings later; keep a numeric copy for max logic
    for shot_abbrev in shot_sizes.keys():
        block = out.loc[(shot_abbrev, slice(None)), :]
        # compute max per column ignoring NaN
        col_max = block.max(axis=0, skipna=True)

        # convert each cell to string and bold if equals block max (ties ok)
        def format_cell(v, m):
            if pd.isna(v):
                return "--"
            s = fmt.format(float(v))
            if pd.notna(m) and np.isclose(float(v), float(m), rtol=0, atol=1e-12):
                return rf"\textbf{{{s}}}"
            return s

        for col in out.columns:
            m = col_max[col]
            mask = (out.index.get_level_values(0) == shot_abbrev)
            out.loc[mask, col] = out.loc[mask, col].map(lambda v: format_cell(v, m))

    # non-shot rows shouldn't exist, but safe:
    out = out.fillna("--")
    return out.astype(object)

# ---------- reshape your existing df ----------
# Your df currently has columns: (Datatype, Shot, Metric) and index: Model
# We will produce df_long with index: (ShotAbbrev, Model) and columns: (Datatype, MetricShort)

def build_shots_as_rows_table(df: pd.DataFrame) -> pd.DataFrame:
    # ensure expected metric names
    # (If your df columns are "Total Dice"/"Instance Dice", rename here.)
    rename_metric = {
        "Total Dice": "Tot Dice",
        "Instance Dice": "Inst Dice",
        "Tot Dice": "Tot Dice",
        "Inst Dice": "Inst Dice",
    }

    # normalize metric level
    cols = []
    for (dt, sh, met) in df.columns:
        cols.append((dt, sh, rename_metric.get(met, met)))
    df2 = df.copy()
    df2.columns = pd.MultiIndex.from_tuples(cols)

    # build long table blocks shot-by-shot
    blocks = []
    for sh_full, sh_abbrev in SHOT_ORDER:
        # extract columns for this shot: (dt, sh_full, metric)
        subcols = []
        for dt in DT_ORDER:
            for met_full, met_short in METRIC_ORDER:
                subcols.append((dt, sh_full, met_full))

        sub = df2[subcols].copy()

        # collapse to (dt, met_short)
        sub.columns = pd.MultiIndex.from_tuples([(dt, met_short) for (dt, _, met_full), met_short in zip(subcols, [m[1] for dt in DT_ORDER for m in METRIC_ORDER])])

        # attach shot as row level
        sub.index = pd.MultiIndex.from_product([[sh_abbrev], sub.index], names=["Shot", "Model"])
        blocks.append(sub)

    df_long = pd.concat(blocks, axis=0)

    # keep model order within each shot block using your MODEL_SPECS order (already in df index order if you set it)
    return df_long

df_long_num = build_shots_as_rows_table(df).astype(float).round(DECIMALS)

# bold best per column within each shot block
shot_sizes = {sh_abbrev: df.shape[0] for _, sh_abbrev in SHOT_ORDER}  # each block has all models
df_long_str = bold_best_within_shot_blocks(df_long_num, shot_sizes=shot_sizes, decimals=DECIMALS)

# wrap model names for LaTeX
df_long_str = df_long_str.copy()
df_long_str.index = pd.MultiIndex.from_tuples(
    [(shot, _latex_wrap_model_name(str(model))) for shot, model in df_long_str.index],
    names=["Shot", "Model"]
)

# ---------- make LaTeX ----------
latex = df_long_str.to_latex(
    escape=False,
    multicolumn=True,
    multirow=True,     # needed for multiindex rows
    index=True,        # includes Shot and Model index columns
    caption="Segmentation performance (Tot and Inst Dice) across datatypes. Shots: zero-shot (ZS), few-shot (FS, ntr5), many-shot (MS, ntr15). Values averaged across 10 patches for ZS and 3 CV folds for FS/MS.",
    label="tab:segmentation_results_shots_rows",
)

# ---------- postprocess LaTeX to look like your example ----------
# 1) Make datatype headers centered (pandas may default to 'r')
latex = latex.replace(r"\multicolumn{2}{r}{Amyloid Plaque}", r"\multicolumn{2}{c}{Amyloid Plaque}")
latex = latex.replace(r"\multicolumn{2}{r}{Cell Nucleus}", r"\multicolumn{2}{c}{Cell Nucleus}")
latex = latex.replace(r"\multicolumn{2}{r}{Vessels}", r"\multicolumn{2}{c}{Vessels}")

# 2) Replace the "Shot" column values with a single vertical multirow label per block.
#    We expect the first column is Shot, second is Model. We'll convert first row in each block to \multirow{N}{*}{\rotatebox{90}{ZS}}
#    and blank out subsequent shot entries.
lines = latex.splitlines()
out_lines = []
in_body = False
current_shot = None
shot_row_count = 0
block_total = df.shape[0]  # number of models per shot block

def shot_label_tex(sh):
    return rf"\multirow{{{block_total}}}{{*}}{{\rotatebox{{90}}{{\emph{{{sh}}}}}}}"

for line in lines:
    # detect start of data body (after \midrule)
    if line.strip() == r"\midrule":
        in_body = True
        current_shot = None
        shot_row_count = 0
        out_lines.append(line)
        continue

    if in_body and line.strip().startswith(r"\bottomrule"):
        in_body = False
        out_lines.append(line)
        continue

    if in_body and "&" in line and line.strip().endswith(r"\\"):
        # row like: ZS & \makecell... & ... \\
        parts = [p.strip() for p in line.split("&")]
        if len(parts) >= 2:
            shot_cell = parts[0]
            # new block start if shot_cell is ZS/FS/MS
            if shot_cell in ["ZS", "FS", "MS"]:
                current_shot = shot_cell
                shot_row_count = 0

            if current_shot in ["ZS", "FS", "MS"]:
                if shot_row_count == 0:
                    parts[0] = shot_label_tex(current_shot)
                else:
                    parts[0] = ""  # blank for subsequent rows in block
                shot_row_count += 1

            out_lines.append(" & ".join(parts))
            continue

    out_lines.append(line)

latex = "\n".join(out_lines)

# 3) Force a clean column spec: Shot col (c), Model col (l), then 6 numeric cols centered
latex = re.sub(
    r"\\begin\{tabular\}\{[^}]*\}",
    r"\\begin{tabular}{@{}clcccccc@{}}",
    latex,
    count=1
)

# 4) Inject tight formatting (8pt + small col sep)
def inject_table_formatting(latex_str: str) -> str:
    lines = latex_str.splitlines()
    out = []
    injected = False
    for line in lines:
        out.append(line)
        if (not injected) and line.strip() == r"\begin{table}":
            out.append(r"\centering")
            out.append(r"\fontsize{8}{9}\selectfont")
            out.append(r"\setlength{\tabcolsep}{2pt}")
            out.append(r"\renewcommand{\arraystretch}{1.05}")
            injected = True
    return "\n".join(out)

latex = inject_table_formatting(latex)

# save
(out_tex := OUTDIR / "segmentation_results_shots_as_rows.tex").write_text(latex)
print(f"[Saved] {out_tex}")


[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results_shots_as_rows.tex


In [33]:
import re
import numpy as np
import pandas as pd
import warnings
from pathlib import Path

# =========================
# CONFIG
# =========================
OUTDIR = Path("/midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables")
OUTDIR.mkdir(parents=True, exist_ok=True)

DECIMALS = 2
FONTSIZE_CMD = r"\fontsize{8}{9}\selectfont"
TABCOLSEP_PT = 2
ARRAYSTRETCH = 1.05

SHOT_MAP = {"Zero-shot": "ZS", "Few-shot": "FS", "Many-shot": "MS"}
SHOT_ORDER = ["ZS", "FS", "MS"]

DTYPE_ORDER = ["Amyloid Plaque", "Cell Nucleus", "Vessels"]
METRIC_ORDER = ["Tot", "Inst"]  # short names

# If your level-0 dtype names already match these, this is fine.
DTYPE_REMAP = {
    "Amyloid Plaque": "Amyloid Plaque",
    "Cell Nucleus": "Cell Nucleus",
    "Vessels": "Vessels",
}

# Model-grouping and display (order matters inside each shot section)
GROUPS = [
    ("I+T (U/S)",  ["Unet Image+CLIP", "SwinUNETR Image+CLIP"]),
    ("I+T (over)", ["SwinUNETR Image+CLIP (overtrain)"]),
    ("I (U/S)",    ["Unet Image-only", "SwinUNETR Image-only"]),
    ("R (U/S)",    ["Unet Random-init", "SwinUNETR Random-init"]),
    ("uSAM (b/l)", ["microSAM base", "microSAM large"]),
    ("CellSeg3D",  ["CellSeg3D"]),
    ("Cellpose (2D/3D)", ["Cellpose 2D", "Cellpose 3D"]),
]

def wrap_group_name(name: str) -> str:
    mapping = {
        "I+T (U/S)": r"\makecell[l]{I+T\\(U/S)}",
        "I+T (over)": r"\makecell[l]{I+T\\(over)}",
        "I (U/S)": r"\makecell[l]{I\\(U/S)}",
        "R (U/S)": r"\makecell[l]{R\\(U/S)}",
        "uSAM (b/l)": r"\makecell[l]{uSAM\\(b/l)}",
        "CellSeg3D": "CellSeg3D",
        "Cellpose (2D/3D)": r"\makecell[l]{Cellpose\\(2D/3D)}",
    }
    return mapping.get(name, name)

# =========================
# Helpers
# =========================
def _fmt(v: float, decimals: int) -> str:
    if v is None or (isinstance(v, float) and (np.isnan(v) or np.isinf(v))):
        return "--"
    return f"{float(v):.{decimals}f}"

def join_inline(primary: str, secondary: str) -> str:
    """Return 'primary/secondary' with compact handling of '--'."""
    primary = "--" if primary is None else str(primary)
    secondary = "--" if secondary is None else str(secondary)

    if primary == "--" and secondary == "--":
        return "--"
    if primary == "--":
        return f"--/{secondary}"
    if secondary == "--":
        return f"{primary}/--"
    return f"{primary}/{secondary}"

def inject_table_formatting(
    latex_str: str,
    add_centering: bool = True,
    fontsize_cmd: str = FONTSIZE_CMD,
    tabcolsep_pt: int = TABCOLSEP_PT,
    arraystretch: float = ARRAYSTRETCH,
) -> str:
    lines = latex_str.splitlines()
    out = []
    injected = False
    for line in lines:
        out.append(line)
        s = line.strip()
        if (not injected) and s.startswith(r"\begin{table}"):
            if add_centering:
                out.append(r"\centering")
            out.append(fontsize_cmd)
            out.append(rf"\setlength{{\tabcolsep}}{{{tabcolsep_pt}pt}}")
            out.append(rf"\renewcommand{{\arraystretch}}{{{arraystretch}}}")
            injected = True
    if not injected:
        warnings.warn("Could not inject formatting: no \\begin{table} found.")
    return "\n".join(out)

# =========================
# 1) Convert your wide df -> numeric long df: index=(Shot, Model), cols=(Datatype, Metric)
# =========================
def wide_to_long(df_wide: pd.DataFrame) -> pd.DataFrame:
    if not isinstance(df_wide.columns, pd.MultiIndex) or df_wide.columns.nlevels != 3:
        raise ValueError("Expected `df` columns MultiIndex with 3 levels: (Datatype, Shot, Metric).")

    dtypes = list(df_wide.columns.get_level_values(0).unique())
    shots_full = list(df_wide.columns.get_level_values(1).unique())

    rows = []
    idx = []
    for shot_full in shots_full:
        shot = SHOT_MAP.get(shot_full, shot_full)
        for model in df_wide.index:
            row = {}
            for dt in dtypes:
                dt_disp = DTYPE_REMAP.get(dt, dt)
                row[(dt_disp, "Tot")] = df_wide.loc[model, (dt, shot_full, "Tot Dice")] if (dt, shot_full, "Tot Dice") in df_wide.columns else np.nan
                row[(dt_disp, "Inst")] = df_wide.loc[model, (dt, shot_full, "Inst Dice")] if (dt, shot_full, "Inst Dice") in df_wide.columns else np.nan
            rows.append(row)
            idx.append((shot, model))

    out = pd.DataFrame(rows, index=pd.MultiIndex.from_tuples(idx, names=["Shot", "Model"]))
    out.columns = pd.MultiIndex.from_tuples(out.columns, names=["Datatype", "Metric"])
    out = out.reindex(columns=pd.MultiIndex.from_product([DTYPE_ORDER, METRIC_ORDER]))
    # Ensure shot order ZS/FS/MS
    out = out.reindex(pd.MultiIndex.from_product([SHOT_ORDER, df_wide.index.tolist()], names=["Shot", "Model"]))
    return out

# =========================
# 2) Combine models within each shot into groups
# =========================
def combine_models_by_shot(df_long_num: pd.DataFrame) -> (pd.DataFrame, pd.DataFrame):
    """
    Returns (primary_num, secondary_num) with index=(Shot, GroupName) and same columns.
    primary is first member, secondary is second member (for U/S, b/l, 2D/3D).
    """
    cols = df_long_num.columns
    idx_out = pd.MultiIndex.from_product([SHOT_ORDER, [g[0] for g in GROUPS]], names=["Shot", "Group"])
    primary = pd.DataFrame(index=idx_out, columns=cols, dtype=float)
    secondary = pd.DataFrame(index=idx_out, columns=cols, dtype=float)

    for shot in SHOT_ORDER:
        for group_name, members in GROUPS:
            out_idx = (shot, group_name)
            if len(members) == 1:
                m = members[0]
                if (shot, m) in df_long_num.index:
                    primary.loc[out_idx] = df_long_num.loc[(shot, m)].values
                else:
                    primary.loc[out_idx] = np.nan
                secondary.loc[out_idx] = np.nan
            else:
                m1, m2 = members[0], members[1]
                primary.loc[out_idx] = df_long_num.loc[(shot, m1)].values if (shot, m1) in df_long_num.index else np.nan
                secondary.loc[out_idx] = df_long_num.loc[(shot, m2)].values if (shot, m2) in df_long_num.index else np.nan

    return primary, secondary

# =========================
# 3) Bold best within each shot section for each column (Datatype, Metric)
# =========================
def format_and_bold_within_shot(primary_num: pd.DataFrame, decimals: int = 2) -> pd.DataFrame:
    out = pd.DataFrame(index=primary_num.index, columns=primary_num.columns, dtype=object)

    for shot in SHOT_ORDER:
        shot_rows = [idx for idx in primary_num.index if idx[0] == shot]
        block = primary_num.loc[shot_rows]

        for col in primary_num.columns:
            s = block[col]
            # format
            formatted = s.map(lambda v: _fmt(v, decimals)).astype(object)

            if not s.dropna().empty:
                m = float(s.max(skipna=True))
                bold_mask = s.notna() & np.isclose(s.astype(float), m, rtol=0, atol=1e-12)
                formatted.loc[bold_mask] = formatted.loc[bold_mask].map(lambda x: rf"\textbf{{{x}}}")

            out.loc[shot_rows, col] = formatted.values

    out = out.fillna("--")
    return out

def format_secondary(secondary_num: pd.DataFrame, decimals: int = 2) -> pd.DataFrame:
    out = secondary_num.map(lambda v: _fmt(v, decimals)).astype(object)
    return out.fillna("--")

def join_primary_secondary(primary_str: pd.DataFrame, secondary_str: pd.DataFrame) -> pd.DataFrame:
    out = primary_str.copy().astype(object)
    for idx in out.index:
        joined_row = []
        for col in out.columns:
            joined_row.append(join_inline(primary_str.loc[idx, col], secondary_str.loc[idx, col]))
        out.loc[idx] = joined_row
    return out

# =========================
# 4) Build final (Shot sections + models) table and export
# =========================
df_long = wide_to_long(df.copy().round(DECIMALS))
primary_num, secondary_num = combine_models_by_shot(df_long)

primary_str = format_and_bold_within_shot(primary_num, decimals=DECIMALS)
secondary_str = format_secondary(secondary_num, decimals=DECIMALS)
df_final = join_primary_secondary(primary_str, secondary_str)

# Wrap group names for LaTeX and reorder rows: ZS block then FS then MS, each with GROUPS order
group_order = [g[0] for g in GROUPS]
wrapped_group_order = [wrap_group_name(g) for g in group_order]

df_final.index = pd.MultiIndex.from_tuples(
    [(shot, wrap_group_name(group)) for (shot, group) in df_final.index],
    names=["Shot", "Model"]
)

row_order = [(shot, wrap_group_name(g)) for shot in SHOT_ORDER for g in group_order]
df_final = df_final.reindex(pd.MultiIndex.from_tuples(row_order, names=["Shot", "Model"]))

# Save CSV (string table)
df_final.to_csv(OUTDIR / "segmentation_results_shot_sections.csv")

# Make LaTeX
latex_str = df_final.to_latex(
    escape=False,
    multicolumn=True,
    multirow=True,
    caption=(
        "Segmentation performance (Tot Dice and Inst Dice) for ZS (zero-shot), FS (few-shot, ntr5), "
        "and MS (many-shot, ntr15). Within each shot section, the best value in each column is bolded. "
        "For combined rows (U/S, b/l, 2D/3D), cells show first/second (e.g., UNet/Swin)."
    ),
    label="tab:segmentation_results_shot_sections",
    bold_rows=False,
    longtable=False,
    index=True,
)

# Center datatype headers (each datatype spans 2 columns: Tot + Inst)
latex_str = latex_str.replace(r"\multicolumn{2}{r}{Amyloid Plaque}", r"\multicolumn{2}{c}{Amyloid Plaque}")
latex_str = latex_str.replace(r"\multicolumn{2}{r}{Cell Nucleus}", r"\multicolumn{2}{c}{Cell Nucleus}")
latex_str = latex_str.replace(r"\multicolumn{2}{r}{Vessels}", r"\multicolumn{2}{c}{Vessels}")

# Reduce side padding in tabular spec: wrap {..} with @{} .. @{}
latex_str = re.sub(
    r"\\begin\{tabular\}\{([^}]*)\}",
    r"\\begin{tabular}{@{}\1@{}}",
    latex_str,
    count=1
)

# Inject 8pt + spacing
latex_str = inject_table_formatting(latex_str, add_centering=True)

# Write tex
out_tex = OUTDIR / "segmentation_results_shot_sections.tex"
out_tex.write_text(latex_str)

print(f"[Saved] {OUTDIR/'segmentation_results_shot_sections.csv'}")
print(f"[Saved] {OUTDIR/'segmentation_results_shot_sections.tex'}")

# Show in notebook
display(df_final)


[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results_shot_sections.csv
[Saved] /midtier/paetzollab/scratch/ads4015/lsm_fm_paper/tables/segmentation_results_shot_sections.tex


Unnamed: 0_level_0,Unnamed: 1_level_0,Amyloid Plaque,Amyloid Plaque,Cell Nucleus,Cell Nucleus,Vessels,Vessels
Unnamed: 0_level_1,Unnamed: 1_level_1,Tot,Inst,Tot,Inst,Tot,Inst
Shot,Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
ZS,\makecell[l]{I+T\\(U/S)},0.41/0.47,0.02/0.12,0.46/0.72,0.06/0.73,\textbf{0.56}/0.68,0.09/0.48
ZS,\makecell[l]{I+T\\(over)},--,--,--,--,--,--
ZS,\makecell[l]{I\\(U/S)},0.43/0.48,\textbf{0.15}/0.11,\textbf{0.55}/0.73,\textbf{0.27}/0.77,0.53/0.68,\textbf{0.32}/0.47
ZS,\makecell[l]{R\\(U/S)},\textbf{0.47}/0.48,0.00/0.19,0.48/0.73,0.00/0.76,0.54/0.69,0.00/0.46
ZS,\makecell[l]{uSAM\\(b/l)},--,--,--,--,--,--
ZS,CellSeg3D,--,--,--,--,--,--
ZS,\makecell[l]{Cellpose\\(2D/3D)},--,--,--,--,--,--
FS,\makecell[l]{I+T\\(U/S)},\textbf{0.68}/0.49,0.58/0.01,0.78/0.79,0.91/0.95,\textbf{0.86}/0.86,0.76/0.73
FS,\makecell[l]{I+T\\(over)},0.61/--,\textbf{0.61}/--,0.78/--,\textbf{0.94}/--,\textbf{0.86}/--,0.81/--
FS,\makecell[l]{I\\(U/S)},0.50/0.56,0.03/0.24,\textbf{0.80}/0.79,0.87/0.93,0.61/0.83,0.56/0.74
