# Fusion evaluation: HSI × ResNet-18 (10 runs)

This notebook evaluates **late fusion** between:

- **Vision expert:** 10 independently trained ResNet-18 models (same architecture, different seeds)
- **Context expert:** XGBoost Habitat Suitability Index (HSI)
- **Fusion rule:** Product of Experts (PoE):  
$$
p(y \mid x, c) \propto p_{\text{vis}}(y \mid x)\, p_{\text{ctx}}(y \mid c)
$$

What you get:
- Accuracy (Top-1 / Top-3 / Top-5) for **ResNet**, **HSI**, **Fused**
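- Calibration metrics for **ResNet** and **Fused**: top-label ECE, multi-class Brier score (squared error averaged over all samples *and* classes, i.e. the classic Brier sum divided by the number of classes), and log-loss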
- McNemar test (paired) per run: ResNet vs Fused
- Mean confusion matrices across all 10 runs
- Clean, repo-relative paths (works even if you run from `notebooks/`)

**Expected repository layout:**
- Models: `models/vision/*.pth`
- Temperatures: `models/vision/temperatures/temperature_<run_tag>.npy`
- HSI model + mapping: `models/context/`
- Outputs: `outputs/fusion_poe/{figures,metrics,preds}`
- Data: `data/amsterdam/images_no_vespula/test2` and metadata parquet under `data/amsterdam/val/`


In [None]:
import sys, platform
import numpy as np
import pandas as pd
import torch

# Environment report: interpreter, OS, torch build and visible accelerators.
_mps_backend = getattr(torch.backends, "mps", None)
print("Python:", sys.version.split()[0])
print("Platform:", platform.platform())
print("Torch:", torch.__version__)
print("MPS available:", _mps_backend is not None and _mps_backend.is_available())
print("CUDA available:", torch.cuda.is_available())

# Optional: reproducibility baseline (seeds NumPy's global RNG and torch's RNG).
_BASE_SEED = 42
np.random.seed(_BASE_SEED)
torch.manual_seed(_BASE_SEED)


## 1) Configuration (repo-relative paths)

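This cell loads every path from `configs/paths.yaml` via `digital_naturalist.paths.load_paths`, so all locations are repo-relative and portable. It then creates the output folders and fails early if any required input is missing.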

In [None]:
from __future__ import annotations

from pathlib import Path
import numpy as np
import pandas as pd
import torch  # ‚úÖ needed for DEVICE

# Single source of truth for paths
from digital_naturalist.paths import load_paths

# -------------------------------------------------------------------
# Load paths (repo-relative, portable)
# -------------------------------------------------------------------
P = load_paths("configs/paths.yaml")

# -------------------------------------------------------------------
# Outputs (fusion)
# -------------------------------------------------------------------
OUT_DIR = P["OUT_FUSION_POE"]
FIG_DIR = OUT_DIR / "figures"
MET_DIR = OUT_DIR / "metrics"
PRD_DIR = OUT_DIR / "preds"
REPO_ROOT = P["REPO_ROOT"]

# Shared logs directory (repo-wide)
LOG_DIR = P["OUTPUTS_ROOT"] / "logs"

# -------------------------------------------------------------------
# Models
# -------------------------------------------------------------------
VISION_MODEL_DIR = P["VISION_MODEL_DIR"]
VISION_TEMPS_DIR = P["VISION_TEMPS_DIR"]
CONTEXT_MODEL_DIR = P["CONTEXT_MODEL_DIR"]

# Context model artifacts (filenames live inside models/context)
HSI_MODEL_PATH = CONTEXT_MODEL_DIR / "xgboost_hsi_model_FINAL_no_vespula.json"
FEATURES_PATH = CONTEXT_MODEL_DIR / "feature_names_FINAL_no_vespula.csv"
SPECIES_MAP_PATH = CONTEXT_MODEL_DIR / "species_mapping_FINAL_no_vespula.csv"

# Backwards-compatible aliases (old cells may use these names)
FEATURE_NAMES_PATH = FEATURES_PATH
SPECIES_MAPPING_PATH = SPECIES_MAP_PATH  # some cells use this name

# -------------------------------------------------------------------
# Choose which image split to fuse.
# Default: temporal holdout "test2" when it is configured AND exists on disk;
# otherwise fall back to "test" (with a visible warning if test2 was configured).
# -------------------------------------------------------------------
TEST_SPLIT_NAME = "test2"  # used in logs/labels in many notebooks

IMAGE_SPLIT_DIR = P.get("IMAGE_TEST2_DIR", None)
if IMAGE_SPLIT_DIR is None or not Path(IMAGE_SPLIT_DIR).exists():
    if IMAGE_SPLIT_DIR is not None:
        print(f"⚠️  Configured test2 split not found ({IMAGE_SPLIT_DIR}); falling back to 'test'.")
    IMAGE_SPLIT_DIR = P["IMAGE_TEST_DIR"]
    TEST_SPLIT_NAME = "test"

# Backwards-compatible alias used in older pairing code
IMAGE_ROOT = IMAGE_SPLIT_DIR

# -------------------------------------------------------------------
# Metadata pool for synthetic pairing (GBIF val parquet)
# (Used to sample context rows per species to match images 1:1)
# -------------------------------------------------------------------
GBIF_VAL_DIR = P["GBIF_VAL_DIR"]
GBIF_VAL_PARQUET = GBIF_VAL_DIR / "observations_filtered_50m_accuracy.parquet"

# Backwards-compatible alias
METADATA_PATH = GBIF_VAL_PARQUET

# -------------------------------------------------------------------
# Run / fusion settings (often assumed later)
# -------------------------------------------------------------------
NUM_CLASSES = 8
BATCH_SIZE = 32

GRID_DECIMALS = 2  # MUST match HSI training
SCALE = 10 ** GRID_DECIMALS

SEED_PAIRING = 42  # synthetic image↔metadata pairing seed

DEVICE = torch.device(
    "mps" if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# -------------------------------------------------------------------
# Ensure output dirs exist
# -------------------------------------------------------------------
for d in (FIG_DIR, MET_DIR, PRD_DIR, LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

# -------------------------------------------------------------------
# Sanity checks (fail early with a clear error).
# Explicit raises instead of `assert`, which is stripped under `python -O`.
# -------------------------------------------------------------------
_REQUIRED_PATHS = (
    (IMAGE_SPLIT_DIR, "Missing image split dir"),
    (VISION_MODEL_DIR, "Missing vision model dir"),
    (VISION_TEMPS_DIR, "Missing vision temps dir"),
    (CONTEXT_MODEL_DIR, "Missing context model dir"),
    # These may be absent in a "no-models" repo,
    # but for running fusion locally they must exist:
    (HSI_MODEL_PATH, "Missing HSI model"),
    (FEATURES_PATH, "Missing feature names CSV"),
    (SPECIES_MAP_PATH, "Missing species mapping CSV"),
    (GBIF_VAL_PARQUET, "Missing GBIF val parquet"),
)
for _path, _msg in _REQUIRED_PATHS:
    if not _path.exists():
        raise FileNotFoundError(f"{_msg}: {_path}")

# -------------------------------------------------------------------
# Printout
# -------------------------------------------------------------------
print("\n=== Fusion PoE — resolved paths ===")
print("Config:              configs/paths.yaml")
print("Device:              ", DEVICE)
print("Test split name:     ", TEST_SPLIT_NAME)
print("Image split dir:     ", IMAGE_SPLIT_DIR)
print("GBIF val parquet:    ", GBIF_VAL_PARQUET)
print("Vision models dir:   ", VISION_MODEL_DIR)
print("Vision temps dir:    ", VISION_TEMPS_DIR)
print("Context model dir:   ", CONTEXT_MODEL_DIR)
print("HSI model:           ", HSI_MODEL_PATH)
print("Feature names:       ", FEATURES_PATH)
print("Species mapping:     ", SPECIES_MAP_PATH)
print("Outputs/figures:     ", FIG_DIR)
print("Outputs/metrics:     ", MET_DIR)
print("Outputs/preds:       ", PRD_DIR)
print("Logs:                ", LOG_DIR)


## 2) Load wild test images and metadata pool

We have:
- **Images:** labeled wild camera images (Aarhus rooftop camera) stored in an `ImageFolder` structure.
- **Metadata pool:** Amsterdam GBIF observations with weather + habitat features.

**Important:** The pairing of each test image to a metadata row is *synthetic* (matched only by species) to test whether ecological priors can rescue visually ambiguous cases under cross-sensor domain shift.

In [None]:
import numpy as np
import pandas as pd
from torchvision import datasets, transforms

# -------------------------
# Synthetic pairing set-up
# -------------------------
# SEED_PAIRING comes from the configuration cell (single source of truth) and
# controls the deterministic image↔context pairing. It is deliberately not
# redefined here, so changing it in the config actually takes effect.
rng = np.random.default_rng(SEED_PAIRING)

# Species list (must match BOTH: ImageFolder class folders AND GBIF parquet 'species' names)
WILD_SPECIES = [
    "Apis mellifera", "Eristalis tenax",
    "Bombus terrestris", "Coccinella septempunctata",
    "Bombus lapidarius", "Episyrphus balteatus",
    "Aglais urticae", "Eupeodes corollae",
]

# Load metadata pool (context rows)
metadata = pd.read_parquet(GBIF_VAL_PARQUET)
metadata = metadata[metadata["species"].isin(WILD_SPECIES)].copy()
# Pairs store index *labels* that are later materialised with .loc; duplicated
# labels would make .loc return several rows per pair. Renumber only in that case.
if not metadata.index.is_unique:
    metadata = metadata.reset_index(drop=True)
print("Metadata rows:", len(metadata), "| parquet:", GBIF_VAL_PARQUET)

# Load images ONLY from the chosen split folder (e.g., .../images_no_vespula/test2/<class>/...)
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_dataset = datasets.ImageFolder(str(IMAGE_SPLIT_DIR), transform=test_transform)
print("Images in split:", len(test_dataset), "| split:", IMAGE_SPLIT_DIR)
print("ImageFolder classes:", [c.replace("_", " ") for c in test_dataset.classes])

# Build a dataframe of image paths + species (from ImageFolder labels)
image_df = pd.DataFrame(
    [{
        "image_path": p,
        "species": test_dataset.classes[y].replace("_", " "),
        "resnet_label_idx": int(y),  # index in ImageFolder class order
    } for p, y in test_dataset.samples]
)

# Image classes outside WILD_SPECIES are never paired; say so instead of dropping silently.
unpaired_classes = sorted(set(image_df["species"]) - set(WILD_SPECIES))
if unpaired_classes:
    print(f"⚠️  Image classes not in WILD_SPECIES (ignored): {unpaired_classes}")

# Synthetic 1:1 pairing by species (sample metadata WITH replacement if needed)
matched_rows = []
for sp in WILD_SPECIES:
    imgs = image_df[image_df["species"] == sp]
    metas = metadata[metadata["species"] == sp]
    if len(imgs) == 0 or len(metas) == 0:
        print(f"⚠️  Skipping species (missing imgs or metas): {sp} | imgs={len(imgs)} metas={len(metas)}")
        continue

    # Sample without replacement when the pool is large enough, otherwise reuse rows.
    replace = len(metas) < len(imgs)
    chosen = rng.choice(metas.index.to_numpy(), size=len(imgs), replace=replace)

    for img_path, meta_idx in zip(imgs["image_path"].tolist(), chosen.tolist()):
        matched_rows.append({
            "image_path": img_path,
            "species": sp,
            "metadata_index": int(meta_idx),
        })

if not matched_rows:
    # An empty DataFrame has no columns, which would surface later as an obscure KeyError.
    raise RuntimeError(
        "No image/metadata pairs were formed; check that WILD_SPECIES matches "
        "the ImageFolder class names and the parquet 'species' values."
    )

matched_df = pd.DataFrame(matched_rows).reset_index(drop=True)
print("Matched pairs:", len(matched_df))

# Materialize metadata rows in the same order as matched_df
metadata_for_hsi = metadata.loc[matched_df["metadata_index"].to_numpy()].reset_index(drop=True)
assert len(metadata_for_hsi) == len(matched_df)


## 3) Load HSI model and compute context probabilities

The `engineer_hsi_features` function **must exactly match** the feature engineering in the HSI training script (same formulas, thresholds and encodings); it is kept unchanged from that script. Any drift here silently changes the inputs the XGBoost model sees.

In [None]:
import numpy as np
import pandas as pd
import xgboost as xgb

# Context (HSI) expert: XGBoost classifier restored from its saved model file.
hsi_model = xgb.XGBClassifier()
hsi_model.load_model(str(HSI_MODEL_PATH))

# Trained feature order: prefer a column literally named "feature", else take the first column.
feature_names = pd.read_csv(FEATURE_NAMES_PATH)
feature_list = (
    feature_names["feature"]
    if "feature" in feature_names.columns
    else feature_names.iloc[:, 0]
).tolist()

# Two-way lookup between species name and the HSI class index.
species_mapping = pd.read_csv(SPECIES_MAPPING_PATH)
_species_pairs = list(zip(species_mapping["species"], species_mapping["idx"]))
species_to_idx = dict(_species_pairs)
idx_to_species = {idx: sp for sp, idx in _species_pairs}

print("HSI features:", len(feature_list))
print("Species mapping:", len(species_to_idx))
# -------- Feature engineering (must match training) --------
def engineer_hsi_features(df: pd.DataFrame) -> pd.DataFrame:
    """Build the engineered feature columns expected by the trained HSI XGBoost model.

    This must stay identical to the feature engineering in the HSI training
    script: any change to a formula, threshold or encoding changes the inputs
    the model sees at inference time.

    Parameters
    ----------
    df : pd.DataFrame
        Raw context rows. Columns read: ``final_latitude``, ``final_longitude``,
        ``hour_local``, ``obs_dt_utc`` (parsed with ``unit="ms"``, i.e. epoch
        milliseconds), ``doy``, ``obs_month``, ``temp_c``, ``rhum``, ``wspd_ms``,
        ``prcp_mm``, ``cloud_cover``, ``swrad``, ``vpd_kpa``,
        ``wc{10,50,100,250}_{tree,shrub,grass,herb_wetland,builtup,bare,cropland,water}``
        and ``final_accuracy_m``. A missing column raises ``KeyError``.
        Relies on the module-level ``SCALE`` for spatial binning.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with the engineered columns added (the input is not mutated).
    """
    df = df.copy()

    # Spatial bins (lat_bin / lon_bin)
    # Coordinates are coerced to numeric (unparseable -> NaN) and snapped to a grid of
    # 1/SCALE degrees; nullable "Int64" keeps NaN coordinates as <NA> instead of failing.
    lat = pd.to_numeric(df["final_latitude"], errors="coerce")
    lon = pd.to_numeric(df["final_longitude"], errors="coerce")
    df["lat_bin"] = np.round(lat * SCALE).astype("Int64")
    df["lon_bin"] = np.round(lon * SCALE).astype("Int64")

    # Temporal features
    # Cyclic (sin/cos) encodings so hour 23 is close to hour 0 and week 52 close to week 1.
    df["hour_sin"] = np.sin(2 * np.pi * df["hour_local"] / 24)
    df["hour_cos"] = np.cos(2 * np.pi * df["hour_local"] / 24)
    df["week_of_year"] = pd.to_datetime(df["obs_dt_utc"], unit="ms").dt.isocalendar().week
    df["week_sin"] = np.sin(2 * np.pi * df["week_of_year"] / 52)
    df["week_cos"] = np.cos(2 * np.pi * df["week_of_year"] / 52)

    # Crude sinusoidal day-length model: 12 ± 6 h, crossing 12 h at day-of-year 80.
    # Sunrise/sunset are placed symmetrically around 12:00.
    day_length = 12 + 6 * np.sin(2 * np.pi * (df["doy"] - 80) / 365)
    sunrise_hour = 12 - day_length / 2
    sunset_hour = 12 + day_length / 2
    df["hours_since_sunrise"] = df["hour_local"] - sunrise_hour
    df["hours_until_sunset"] = sunset_hour - df["hour_local"]
    df["is_golden_hour"] = ((df["hours_since_sunrise"] < 2) | (df["hours_until_sunset"] < 2)).astype(int)

    # Season flags by month (months outside these sets get all-zero flags).
    df["is_spring"] = df["obs_month"].isin([3, 4, 5]).astype(int)
    df["is_summer"] = df["obs_month"].isin([6, 7, 8]).astype(int)
    df["is_fall"] = df["obs_month"].isin([9, 10]).astype(int)

    # Weather features
    # Fixed thresholds turned into 0/1 indicator columns.
    df["is_optimal_temp"] = ((df["temp_c"] >= 15) & (df["temp_c"] <= 28)).astype(int)
    df["temp_squared"] = df["temp_c"] ** 2
    df["is_humid"] = (df["rhum"] > 70).astype(int)
    df["is_dry"] = (df["rhum"] < 40).astype(int)
    df["is_calm"] = (df["wspd_ms"] < 3).astype(int)
    df["is_windy"] = (df["wspd_ms"] > 7).astype(int)
    df["has_rain"] = (df["prcp_mm"] > 0.5).astype(int)
    df["is_sunny"] = (df["cloud_cover"] < 30).astype(int)
    df["is_overcast"] = (df["cloud_cover"] > 70).astype(int)
    # Day length is floored at 1 h to avoid dividing by tiny/negative values.
    df["swrad_per_hour"] = df["swrad"] / np.maximum(day_length, 1)

    # Habitat composition
    # Sums of land-cover classes per buffer radius (NOTE(review): wc* values presumably
    # cover shares within the radius — confirm units against the training data).
    for radius in [10, 50, 100, 250]:
        df[f"vegetation_total_{radius}"] = (
            df[f"wc{radius}_tree"] + df[f"wc{radius}_shrub"] + df[f"wc{radius}_grass"]
        )
        df[f"natural_total_{radius}"] = (
            df[f"wc{radius}_tree"] + df[f"wc{radius}_shrub"] +
            df[f"wc{radius}_grass"] + df[f"wc{radius}_herb_wetland"]
        )
        df[f"impervious_{radius}"] = df[f"wc{radius}_builtup"] + df[f"wc{radius}_bare"]

    # Habitat diversity
    # Shannon entropy over 6 land-cover classes (rows re-normalised to sum to 1; +1e-6
    # avoids log(0)), plus richness (#classes above 0.05) and dominance (largest share).
    for radius in [10, 50, 100, 250]:
        habitat_cols = [
            f"wc{radius}_tree", f"wc{radius}_shrub", f"wc{radius}_grass",
            f"wc{radius}_cropland", f"wc{radius}_builtup", f"wc{radius}_water"
        ]
        habitat_matrix = df[habitat_cols].values + 1e-6
        habitat_matrix = habitat_matrix / habitat_matrix.sum(axis=1, keepdims=True)
        shannon = -np.sum(habitat_matrix * np.log(habitat_matrix), axis=1)

        df[f"habitat_diversity_{radius}"] = shannon
        df[f"habitat_richness_{radius}"] = (df[habitat_cols] > 0.05).sum(axis=1)
        df[f"habitat_dominance_{radius}"] = df[habitat_cols].max(axis=1)

    # Cross-scale gradients
    # Differences between a small and a larger buffer (local minus surroundings).
    df["vegetation_gradient_10_50"] = df["vegetation_total_10"] - df["vegetation_total_50"]
    df["vegetation_gradient_50_250"] = df["vegetation_total_50"] - df["vegetation_total_250"]
    df["urban_gradient_10_50"] = df["wc10_builtup"] - df["wc50_builtup"]
    df["urban_gradient_50_250"] = df["wc50_builtup"] - df["wc250_builtup"]
    df["water_gradient_10_100"] = df["wc10_water"] - df["wc100_water"]
    df["tree_gradient_10_100"] = df["wc10_tree"] - df["wc100_tree"]

    # Weather × habitat interactions
    # Products of a weather variable with a habitat composition column built above.
    df["temp_x_vegetation_50"] = df["temp_c"] * df["vegetation_total_50"]
    df["temp_x_builtup_50"] = df["temp_c"] * df["wc50_builtup"]
    df["temp_x_water_50"] = df["temp_c"] * df["wc50_water"]
    df["humidity_x_vegetation_50"] = df["rhum"] * df["vegetation_total_50"]
    df["humidity_x_wetland_50"] = df["rhum"] * df["wc50_herb_wetland"]
    df["wind_x_vegetation_100"] = df["wspd_ms"] * df["vegetation_total_100"]
    df["wind_x_tree_shelter"] = df["wspd_ms"] * df["wc100_tree"]
    df["solar_x_vegetation"] = df["swrad"] * df["vegetation_total_50"]
    df["solar_x_builtup"] = df["swrad"] * df["wc50_builtup"]
    df["vpd_x_vegetation"] = df["vpd_kpa"] * df["vegetation_total_50"]
    df["vpd_x_water_proximity"] = df["vpd_kpa"] * (1 - df["wc50_water"])

    # Urban context
    # Hand-weighted composite indices; the weights are fixed constants from training.
    df["urban_heat_index"] = df["wc50_builtup"] * 2 + df["wc250_builtup"] - 0.5 * df["vegetation_total_50"]
    df["floral_resources"] = (
        df["wc10_grass"] * 0.5 + df["wc50_grass"] * 1.0 +
        df["wc50_shrub"] * 1.5 + df["wc50_cropland"] * 0.8
    )
    df["cavity_nesting_habitat"] = df["wc50_tree"] + df["wc50_builtup"] * 0.2
    df["ground_nesting_habitat"] = df["wc10_grass"] + df["wc10_bare"] * 0.5
    # Depends on habitat_richness_50 / habitat_diversity_50 computed in the diversity loop.
    df["habitat_edges_50"] = df["habitat_richness_50"] * df["habitat_diversity_50"]

    # Temporal × habitat interactions
    # morning/afternoon split at 12:00; depends on floral_resources computed above.
    df["spring_x_vegetation"] = df["is_spring"] * df["vegetation_total_50"]
    df["summer_x_water"] = df["is_summer"] * df["wc50_water"]
    df["morning_x_flowers"] = (df["hour_local"] < 12).astype(int) * df["floral_resources"]
    df["afternoon_x_flowers"] = (df["hour_local"] >= 12).astype(int) * df["floral_resources"]

    # Spatial precision
    # log1p of the reported GPS accuracy, plus a flag for accuracy <= 10.
    df["log_gps_accuracy"] = np.log1p(df["final_accuracy_m"])
    df["is_precise"] = (df["final_accuracy_m"] <= 10).astype(int)

    return df

metadata_engineered = engineer_hsi_features(metadata_for_hsi)

# The HSI model was fitted on a fixed feature set/order; refuse to predict if any is absent.
_available_cols = set(metadata_engineered.columns)
missing_feats = [name for name in feature_list if name not in _available_cols]
if missing_feats:
    raise KeyError(f"Missing engineered features (first 10): {missing_feats[:10]}")

# Same NaN handling as in the original pipeline: fill with 0, then predict class probabilities.
X_hsi = metadata_engineered.loc[:, feature_list].fillna(0).to_numpy()
hsi_probs = hsi_model.predict_proba(X_hsi)
print("HSI probs:", hsi_probs.shape)

# y_true in HSI index-space (KeyError if a paired species is absent from the mapping).
y_true = np.fromiter(
    (species_to_idx[name] for name in matched_df["species"].tolist()),
    dtype=int,
    count=len(matched_df),
)
assert hsi_probs.shape[0] == len(y_true)
assert hsi_probs.shape[1] == NUM_CLASSES
assert hsi_probs.shape[1] == NUM_CLASSES


## 4) Helpers (metrics, plotting, McNemar, ResNet prediction)

ResNet outputs are re-ordered to the **HSI index order** before fusion, to guarantee the elementwise product is aligned.

In [None]:
from __future__ import annotations
import math
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.metrics import confusion_matrix, log_loss

# Optional: McNemar p-values need scipy's chi-square distribution. Any failure to
# import SciPy (deliberately broad: missing or broken install) just disables them.
try:
    from scipy.stats import chi2
except Exception:
    SCIPY_AVAILABLE = False
else:
    SCIPY_AVAILABLE = True

def topk_accuracy(y_true: np.ndarray, probs: np.ndarray, k: int) -> float:
    """Fraction of samples whose true class is among the ``k`` highest-probability classes.

    Parameters
    ----------
    y_true : array of int, shape (n,)
        True class indices (column indices into ``probs``).
    probs : array, shape (n, n_classes)
        Per-class scores/probabilities.
    k : int
        Number of top classes to consider; values above ``n_classes`` are clipped,
        values ``<= 0`` give an accuracy of 0.0.

    Returns
    -------
    float
        Top-k accuracy in [0, 1], or NaN for empty input.
    """
    y_true = np.asarray(y_true, dtype=int)
    probs = np.asarray(probs)
    if y_true.size == 0:
        return float("nan")

    k = int(min(k, probs.shape[1]))
    if k <= 0:
        # Guard: probs[:, -0:] would select *all* columns and report 100% accuracy.
        return 0.0

    topk = np.argsort(probs, axis=1)[:, -k:]
    # Vectorised membership test: does any of the k columns equal the true label?
    hits = (topk == y_true[:, None]).any(axis=1)
    return float(hits.mean())

def multiclass_brier_score(y_true: np.ndarray, probs: np.ndarray) -> float:
    """Mean squared error between predicted probabilities and one-hot targets.

    The mean runs over *all* n_samples x n_classes entries, i.e. this equals the
    classic multi-class Brier score (sum over classes, mean over samples)
    divided by the number of classes.
    """
    p = np.asarray(probs, dtype=float)
    labels = np.asarray(y_true, dtype=int)
    # Row `labels[i]` of the identity matrix is the one-hot target of sample i.
    targets = np.eye(p.shape[1], dtype=float)[labels]
    squared_error = (p - targets) ** 2
    return float(squared_error.mean())

def ece_toplabel(y_true: np.ndarray, probs: np.ndarray, n_bins: int = 15) -> float:
    """Top-label Expected Calibration Error with ``n_bins`` equal-width confidence bins.

    Confidence is the max probability per row and correctness is argmax == label.
    Bins are half-open [lo, hi) except the last, which is closed so that a
    confidence of exactly 1.0 is counted. Empty bins contribute nothing.
    """
    labels = np.asarray(y_true, dtype=int)
    p = np.asarray(probs, dtype=float)
    confidence = p.max(axis=1)
    hit = (p.argmax(axis=1) == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1
    total = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_upper = (confidence <= hi) if b == last_bin else (confidence < hi)
        in_bin = (confidence >= lo) & below_upper
        if in_bin.any():
            gap = abs(hit[in_bin].mean() - confidence[in_bin].mean())
            total += in_bin.mean() * gap
    return float(total)

def mcnemar_test(y_true: np.ndarray, y_a: np.ndarray, y_b: np.ndarray):
    """Paired, continuity-corrected McNemar test between predictions y_a (ResNet) and y_b (Fused).

    Returns
    -------
    tuple
        ``(n11, n10, n01, n00, chi2_cc, p_value)``: n11 both correct, n10 only A
        correct, n01 only B correct, n00 both wrong;
        ``chi2_cc = (|n01 - n10| - 1)**2 / (n01 + n10)`` (0.0 when there are no
        discordant pairs); ``p_value`` is the chi-square(1) survival probability,
        or NaN when SciPy is unavailable or there are no discordant pairs.
    """
    hit_a = y_a == y_true
    hit_b = y_b == y_true

    n11 = int(np.sum(hit_a & hit_b))
    n10 = int(np.sum(hit_a & ~hit_b))
    n01 = int(np.sum(~hit_a & hit_b))
    n00 = int(np.sum(~hit_a & ~hit_b))

    discordant = n01 + n10
    if discordant == 0:
        # No disagreements between the two classifiers: statistic 0, p-value undefined here.
        return n11, n10, n01, n00, 0.0, float("nan")

    statistic = (abs(n01 - n10) - 1) ** 2 / discordant
    p_value = float(chi2.sf(statistic, df=1)) if SCIPY_AVAILABLE else float("nan")
    return n11, n10, n01, n00, float(statistic), p_value

# ---------- plotting (seaborn heatmap) ----------
import matplotlib.pyplot as plt
import seaborn as sns

def _pretty_class_name(name: str) -> str:
    return name.replace("_", " ")

def save_confusion_matrix_png(
    cm: np.ndarray,
    class_names: list[str],
    out_path: Path,
    normalize: bool = False,
    title: str | None = None,
    *,
    cmap: str = "Blues",
    annot_fontsize: int = 12,
    row_norm_colours: bool = True,
) -> None:
    """Render a confusion matrix as a seaborn heatmap and save it as a PNG.

    - normalize=False: cells are annotated with the values of ``cm`` — as
      integers when every entry is integral (per-run counts), otherwise with
      one decimal (e.g. counts averaged over runs). Colours are row-normalised
      by default so each row spans the same colour scale.
    - normalize=True: cells are annotated with row-normalised percentages and
      coloured by those same values.

    Parameters
    ----------
    cm : array, shape (n_classes, n_classes)
        Rows are true classes, columns predicted classes (matching the axis labels).
    class_names : list[str]
        Tick labels in row/column order; underscores are displayed as spaces.
    out_path : Path
        Destination PNG; parent directories are created if needed.
    normalize : bool
        Annotate with row-normalised proportions instead of raw values.
    title : str | None
        Plot title; when falsy a default describing the annotation/colour mode is used.
    row_norm_colours : bool
        If True and normalize=False, colours use row-normalised values while
        annotations show raw values; if False, colours use raw values.
    """
    cm = np.asarray(cm, dtype=float)
    out_path = Path(out_path)

    # Row-normalised matrix for colouring; empty rows stay all-zero instead of dividing by 0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / np.maximum(row_sums, 1e-12)

    if normalize:
        data = cm_norm
        annot = cm_norm
        fmt = ".1%"
        cbar_label = "Row-normalised proportion"
        default_title = "Confusion matrix (row-normalised)"
    else:
        data = cm_norm if row_norm_colours else cm
        rounded = np.round(cm)
        if np.allclose(cm, rounded):
            # Integral counts: round (not truncate) so float noise like 2.9999999 shows as 3.
            annot, fmt = rounded.astype(int), "d"
        else:
            # Non-integral values (e.g. mean counts over runs) would be truncated by int().
            annot, fmt = cm, ".1f"
        cbar_label = "Row-normalised proportion" if row_norm_colours else "Count"
        default_title = (
            "Confusion matrix (counts annotated; colours row-normalised)"
            if row_norm_colours
            else "Confusion matrix (counts)"
        )

    tick_labels = [_pretty_class_name(c) for c in class_names]

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            data,
            ax=ax,
            cmap=cmap,
            vmin=0.0,
            vmax=1.0 if (normalize or row_norm_colours) else None,
            square=True,
            annot=annot,
            fmt=fmt,
            annot_kws={"size": annot_fontsize},
            linewidths=0.5,
            linecolor="white",
            cbar_kws={"label": cbar_label},
            xticklabels=tick_labels,
            yticklabels=tick_labels,
        )

        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(title or default_title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting/saving fails, so repeated
        # calls do not accumulate open matplotlib figures.
        plt.close(fig)
# ---------- dataset for matched image paths ----------
class ImagePathDataset(Dataset):
    """Map-style dataset over an explicit, ordered list of image file paths.

    Each item is ``(transform(rgb_image), idx)``; returning the position lets
    the caller write batched outputs back into the original order.
    """

    def __init__(self, image_paths: list[str], transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # PIL is imported lazily here; it is not among the notebook's top-level imports.
        from PIL import Image

        p = self.image_paths[idx]
        # Context manager closes the file handle deterministically; convert()
        # returns an independent, fully loaded in-memory copy.
        with Image.open(p) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        return x, idx  # idx to keep ordering

def load_resnet18(num_classes: int, weights_path: Path, device: torch.device) -> nn.Module:
    """Build a torchvision ResNet-18 with a ``num_classes``-way head and load a checkpoint.

    The file at ``weights_path`` is passed straight to ``load_state_dict``, so it
    must be a plain state_dict for this architecture. Returns the model moved to
    ``device`` and switched to eval mode.
    """
    net = models.resnet18(weights=None)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True restricts loading to tensors).
    state = torch.load(weights_path, map_location=device)
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    return net

@torch.no_grad()
def resnet_predict_proba_in_hsi_order(
    weights_path: Path,
    image_paths: list[str],
    transform,
    device: torch.device,
    resnet_class_species: list[str],
    species_to_idx: dict[str, int],
    temperature: float | None = None,
    batch_size: int = 32,
):
    """Run one ResNet-18 checkpoint over ``image_paths`` and return softmax
    probabilities with columns re-ordered into the HSI class-index order.

    Parameters
    ----------
    weights_path : Path
        Checkpoint passed to ``load_resnet18``.
    image_paths : list[str]
        Images to score; output row i corresponds to ``image_paths[i]``.
    transform
        Preprocessing applied to each RGB PIL image.
    device : torch.device
        Device used for inference.
    resnet_class_species : list[str]
        Species name of each ResNet output column, in ResNet column order.
    species_to_idx : dict[str, int]
        Species name -> HSI column index.
    temperature : float | None
        Temperature-scaling divisor applied to the logits before softmax; None = no scaling.
    batch_size : int
        DataLoader batch size.

    Returns
    -------
    np.ndarray, shape (len(image_paths), NUM_CLASSES), float32
        Column j holds the probability of the species with HSI index j.

    Raises
    ------
    KeyError
        If a ResNet class name is missing from ``species_to_idx``.
    ValueError
        If the ResNet classes do not map one-to-one onto HSI indices
        0..NUM_CLASSES-1, or if ``temperature`` is not a positive number.
    """
    # Validate the column mapping *before* running the model: a missing or duplicated
    # class would otherwise silently leave an all-zero column or overwrite another
    # class's probabilities, misaligning the PoE product with the HSI probabilities.
    perm = np.array([int(species_to_idx[s]) for s in resnet_class_species], dtype=int)
    if len(perm) != NUM_CLASSES or sorted(perm.tolist()) != list(range(NUM_CLASSES)):
        raise ValueError(
            f"ResNet classes {resnet_class_species} do not map one-to-one onto HSI "
            f"indices 0..{NUM_CLASSES - 1} (got {perm.tolist()})"
        )
    if temperature is not None and not float(temperature) > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    model = load_resnet18(NUM_CLASSES, weights_path, device)

    ds = ImagePathDataset(image_paths, transform)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)

    probs_resnet_order = np.zeros((len(image_paths), NUM_CLASSES), dtype=np.float32)

    for xb, idxs in dl:
        logits = model(xb.to(device))
        if temperature is not None:
            logits = logits / float(temperature)
        # idxs are dataset positions, so rows land in image_paths order.
        probs_resnet_order[idxs.numpy()] = F.softmax(logits, dim=1).detach().cpu().numpy()

    # ResNet column c holds species resnet_class_species[c], whose HSI column is perm[c].
    probs_hsi_order = np.empty_like(probs_resnet_order)
    probs_hsi_order[:, perm] = probs_resnet_order
    return probs_hsi_order


## 5) Run fusion across all 10 ResNet models

This will:
- load each ResNet checkpoint `models/vision/resnet18_run_*.pth` (other `.pth` files are ignored)
- load matching temperature file if present (`models/vision/temperatures/temperature_<run_tag>.npy`)
- compute ResNet probs, fuse with HSI probs, compute metrics + McNemar
- save tables/plots to `outputs/fusion_poe/...`


### Rescued vs Hurt analysis

To make the fusion behaviour interpretable, we decompose the net Top-1 accuracy change into:

- **Rescued:** ResNet wrong but fused correct  
- **Hurt:** ResNet correct but fused wrong  

For each run we report rescued/hurt counts and rates overall, and also aggregated per species (true label).


In [None]:
from __future__ import annotations
from pathlib import Path
import numpy as np
import pandas as pd

# Only the per-seed ResNet run checkpoints take part; other .pth files are ignored.
model_files = sorted(
    ckpt
    for ckpt in VISION_MODEL_DIR.glob("*.pth")
    if ckpt.name.startswith("resnet18_run_")
)
print(f"Found {len(model_files)} vision checkpoints in {VISION_MODEL_DIR}")

if not model_files:
    raise FileNotFoundError(f"No resnet checkpoints found in {VISION_MODEL_DIR}")

# ResNet output columns are assumed to follow this split's ImageFolder class order
# (valid only if the training folders had the same names) — TODO confirm vs. training.
resnet_class_species = [folder.replace("_", " ") for folder in test_dataset.classes]

# Matched image paths, in the order used for pairing (row order of y_true / hsi_probs).
image_paths = list(matched_df["image_path"])

# Per-run accumulators
all_results = []
mcnemar_rows = []
run_tags = []  # same order as model_files
all_resnet_preds = []
all_fused_preds = []

# Per-species accumulators (one list per species, filled once per run)
species_improvements = {name: [] for name in WILD_SPECIES}
species_rescued = {name: [] for name in WILD_SPECIES}
species_hurt = {name: [] for name in WILD_SPECIES}

# Rescued / Hurt analysis rows (per-run + per-species)
rescue_rows = []
rescue_species_rows = []

# Optional: save full probabilities (can be large)
SAVE_FULL_PROBS = False

# Run-independent quantities, computed once instead of on every iteration.
ALL_LABELS = list(range(NUM_CLASSES))
class_names = [idx_to_species[i] for i in range(NUM_CLASSES)]
species_arr = matched_df["species"].to_numpy()
species_positions = {sp: np.flatnonzero(species_arr == sp) for sp in WILD_SPECIES}
hsi_top1 = topk_accuracy(y_true, hsi_probs, 1)  # context-only accuracy is the same for every run


def _repo_relative_str(path) -> str:
    """Return *path* relative to REPO_ROOT when it lies inside the repo, else ``str(path)``.

    Both sides are resolved, so symlinks and string-prefix look-alikes
    (e.g. ``/repo2`` vs ``/repo``) cannot make ``relative_to`` raise.
    """
    try:
        return str(Path(path).resolve().relative_to(Path(REPO_ROOT).resolve()))
    except ValueError:
        return str(path)


pred_image_paths = [_repo_relative_str(p) for p in image_paths]

for i, model_path in enumerate(model_files, 1):
    run_tag = model_path.stem  # e.g. resnet18_run_10_seed_51
    run_tags.append(run_tag)
    print("\n" + "="*80)
    print(f"MODEL {i}/{len(model_files)}: {run_tag}")
    print("="*80)

    # Temperature (post-hoc calibration); without a file, raw logits are used.
    t_path = VISION_TEMPS_DIR / f"temperature_{run_tag}.npy"
    temperature = float(np.load(t_path).ravel()[0]) if t_path.exists() else None
    if temperature is not None:
        print(f"  🌡️  Using temperature scaling T*={temperature:.4f}")
    else:
        print("  ⚠️  No temperature file found -> uncalibrated logits")

    # ResNet probabilities (aligned to HSI idx order)
    resnet_probs = resnet_predict_proba_in_hsi_order(
        weights_path=model_path,
        image_paths=image_paths,
        transform=test_transform,
        device=DEVICE,
        resnet_class_species=resnet_class_species,
        species_to_idx=species_to_idx,
        temperature=temperature,
        batch_size=BATCH_SIZE,
    )

    # Fusion (PoE): elementwise product of expert probabilities, renormalised per row.
    fused_probs = resnet_probs * hsi_probs
    fused_probs = fused_probs / np.maximum(fused_probs.sum(axis=1, keepdims=True), 1e-12)

    # Predictions
    y_pred_resnet = resnet_probs.argmax(axis=1)
    y_pred_fused = fused_probs.argmax(axis=1)

    all_resnet_preds.append(y_pred_resnet)
    all_fused_preds.append(y_pred_fused)

    # Rescued / Hurt (per-sample)
    res_correct = (y_pred_resnet == y_true)
    fus_correct = (y_pred_fused == y_true)
    rescued_mask = (~res_correct) & fus_correct
    hurt_mask = res_correct & (~fus_correct)

    n_rescued = int(rescued_mask.sum())
    n_hurt = int(hurt_mask.sum())
    n_total = int(len(y_true))
    rescue_rows.append({
        "run_tag": run_tag,
        "rescued_n": n_rescued,
        "hurt_n": n_hurt,
        "net_rescue_n": n_rescued - n_hurt,
        "rescued_rate": n_rescued / n_total if n_total else np.nan,
        "hurt_rate": n_hurt / n_total if n_total else np.nan,
    })

    # Per-species rescued / hurt counts and Top-1 improvement
    for sp, idxs in species_positions.items():
        if len(idxs) == 0:
            continue
        r_sp = int(rescued_mask[idxs].sum())
        h_sp = int(hurt_mask[idxs].sum())
        species_rescued[sp].append(r_sp)
        species_hurt[sp].append(h_sp)
        rescue_species_rows.append({
            "run_tag": run_tag,
            "species": sp,
            "n_samples": int(len(idxs)),
            "rescued_n": r_sp,
            "hurt_n": h_sp,
            "net_rescue_n": r_sp - h_sp,
            "rescued_rate": r_sp / len(idxs),
            "hurt_rate": h_sp / len(idxs),
        })
        r = topk_accuracy(y_true[idxs], resnet_probs[idxs], 1)
        f = topk_accuracy(y_true[idxs], fused_probs[idxs], 1)
        species_improvements[sp].append((f - r) * 100.0)

    # Accuracy metrics
    res_top1 = topk_accuracy(y_true, resnet_probs, 1)
    res_top3 = topk_accuracy(y_true, resnet_probs, 3)
    res_top5 = topk_accuracy(y_true, resnet_probs, 5)

    fus_top1 = topk_accuracy(y_true, fused_probs, 1)
    fus_top3 = topk_accuracy(y_true, fused_probs, 3)
    fus_top5 = topk_accuracy(y_true, fused_probs, 5)

    imp_top1 = (fus_top1 - res_top1) * 100.0

    # Calibration metrics (ResNet vs Fused). `labels` is required: without it
    # log_loss infers classes from y_true and raises when a species is absent.
    res_brier = multiclass_brier_score(y_true, resnet_probs)
    fus_brier = multiclass_brier_score(y_true, fused_probs)
    res_ece = ece_toplabel(y_true, resnet_probs, n_bins=15)
    fus_ece = ece_toplabel(y_true, fused_probs, n_bins=15)
    res_ll = float(log_loss(y_true, resnet_probs, labels=ALL_LABELS))
    fus_ll = float(log_loss(y_true, fused_probs, labels=ALL_LABELS))

    # McNemar
    n11, n10, n01, n00, chi2_cc, p_val = mcnemar_test(y_true, y_pred_resnet, y_pred_fused)

    print(f"  ResNet Top-1: {res_top1:.1%} | Fused Top-1: {fus_top1:.1%} | Δ={imp_top1:+.2f} pp")
    print(f"  ECE: {res_ece:.4f} -> {fus_ece:.4f} | Brier: {res_brier:.4f} -> {fus_brier:.4f} | LogLoss: {res_ll:.4f} -> {fus_ll:.4f}")
    if not np.isnan(p_val):
        print(f"  McNemar: chi2_cc={chi2_cc:.3f}, p={p_val:.3e} (n01={n01}, n10={n10})")
    else:
        print(f"  McNemar: chi2_cc={chi2_cc:.3f} (p-value unavailable) (n01={n01}, n10={n10})")

    # Save per-run outputs
    run_metrics_dir = MET_DIR / run_tag
    run_fig_dir = FIG_DIR / run_tag
    run_pred_dir = PRD_DIR / run_tag
    for run_dir in (run_metrics_dir, run_fig_dir, run_pred_dir):
        run_dir.mkdir(parents=True, exist_ok=True)

    cm_res = confusion_matrix(y_true, y_pred_resnet, labels=ALL_LABELS)
    cm_fus = confusion_matrix(y_true, y_pred_fused, labels=ALL_LABELS)

    np.save(run_metrics_dir / f"confusion_matrix_resnet_{run_tag}.npy", cm_res)
    np.save(run_metrics_dir / f"confusion_matrix_fused_{run_tag}.npy", cm_fus)

    save_confusion_matrix_png(cm_res, class_names, run_fig_dir / f"confusion_matrix_resnet_{run_tag}_counts.png", normalize=False,
                              title=f"ResNet confusion (counts) – {run_tag}")
    save_confusion_matrix_png(cm_res, class_names, run_fig_dir / f"confusion_matrix_resnet_{run_tag}_norm.png", normalize=True,
                              title=f"ResNet confusion (normalized) – {run_tag}")
    save_confusion_matrix_png(cm_fus, class_names, run_fig_dir / f"confusion_matrix_fused_{run_tag}_counts.png", normalize=False,
                              title=f"Fused confusion (counts) – {run_tag}")
    save_confusion_matrix_png(cm_fus, class_names, run_fig_dir / f"confusion_matrix_fused_{run_tag}_norm.png", normalize=True,
                              title=f"Fused confusion (normalized) – {run_tag}")

    # Prediction table (essential columns)
    pred_df = pd.DataFrame({
        "image_path": pred_image_paths,
        "species": matched_df["species"],
        "y_true": y_true,
        "y_pred_resnet": y_pred_resnet,
        "y_pred_fused": y_pred_fused,
        "resnet_conf": resnet_probs.max(axis=1),
        "fused_conf": fused_probs.max(axis=1),
    })
    pred_df.to_csv(run_pred_dir / f"predictions_{run_tag}.csv", index=False)

    if SAVE_FULL_PROBS:
        np.save(run_pred_dir / f"resnet_probs_{run_tag}.npy", resnet_probs.astype(np.float32))
        np.save(run_pred_dir / f"fused_probs_{run_tag}.npy", fused_probs.astype(np.float32))

    all_results.append({
        "run_tag": run_tag,
        "temperature_T": temperature if temperature is not None else np.nan,
        "resnet_top1": res_top1,
        "resnet_top3": res_top3,
        "resnet_top5": res_top5,
        "hsi_top1": hsi_top1,
        "fused_top1": fus_top1,
        "fused_top3": fus_top3,
        "fused_top5": fus_top5,
        "improvement_top1_pp": imp_top1,
        "rescued_n": n_rescued,
        "hurt_n": n_hurt,
        "net_rescue_n": (n_rescued - n_hurt),
        "rescued_rate": (n_rescued / n_total) if n_total else np.nan,
        "hurt_rate": (n_hurt / n_total) if n_total else np.nan,
        "resnet_ece": res_ece,
        "fused_ece": fus_ece,
        "resnet_brier": res_brier,
        "fused_brier": fus_brier,
        "resnet_logloss": res_ll,
        "fused_logloss": fus_ll,
    })

    mcnemar_rows.append({
        "run_tag": run_tag,
        "n11_both_correct": n11,
        "n10_resnet_only": n10,
        "n01_fused_only": n01,
        "n00_both_wrong": n00,
        "chi2_cc": chi2_cc,
        "p_value": p_val,
    })

# Per-run metric and McNemar tables (one row per checkpoint, sorted by run tag)
results_df = (
    pd.DataFrame(all_results)
    .sort_values("run_tag")
    .reset_index(drop=True)
)
mcnemar_df = (
    pd.DataFrame(mcnemar_rows)
    .sort_values("run_tag")
    .reset_index(drop=True)
)
for _frame, _fname in (
    (results_df, "fusion_results_all_models_with_calibration.csv"),
    (mcnemar_df, "fusion_mcnemar_all_models.csv"),
):
    _frame.to_csv(MET_DIR / _fname, index=False)

# Rescued / Hurt tables: overall per run, and per (species, run)
rescue_df = (
    pd.DataFrame(rescue_rows)
    .sort_values("run_tag")
    .reset_index(drop=True)
)
rescue_species_df = (
    pd.DataFrame(rescue_species_rows)
    .sort_values(["species", "run_tag"])
    .reset_index(drop=True)
)
for _frame, _fname in (
    (rescue_df, "fusion_rescued_hurt_all_models.csv"),
    (rescue_species_df, "fusion_rescued_hurt_per_species.csv"),
):
    _frame.to_csv(MET_DIR / _fname, index=False)

# Quick printed summary (mean ¬± std across runs)
def _mean_std(x):
    x = np.asarray(x, dtype=float)
    return float(np.nanmean(x)), float(np.nanstd(x))

# Run-level rescued/hurt rates and net rescues, summarized across runs.
m_r, s_r = _mean_std(rescue_df["rescued_rate"].values)
m_h, s_h = _mean_std(rescue_df["hurt_rate"].values)
m_n, s_n = _mean_std(rescue_df["net_rescue_n"].values)

print("\nRescued/Hurt (mean ± std across runs):")
print(f"  rescued_rate: {m_r:.4f} ± {s_r:.4f}")
print(f"  hurt_rate   : {m_h:.4f} ± {s_h:.4f}")
print(f"  net_rescue_n: {m_n:.2f} ± {s_n:.2f}")

print("\nPer-species rescued/hurt (mean ± std counts across runs):")
for sp in WILD_SPECIES:
    if len(species_rescued[sp]) == 0:
        continue
    # Same helper as the run-level lines above, so every line uses one mean/std convention.
    r_mean, r_std = _mean_std(species_rescued[sp])
    h_mean, h_std = _mean_std(species_hurt[sp])
    print(f"  {sp:<30} rescued {r_mean:6.2f} ± {r_std:5.2f} | hurt {h_mean:6.2f} ± {h_std:5.2f}")

print("\nSaved:")
print("-", MET_DIR / "fusion_results_all_models_with_calibration.csv")
print("-", MET_DIR / "fusion_mcnemar_all_models.csv")
# The rescued/hurt CSVs are also written earlier in this cell; list them too.
print("-", MET_DIR / "fusion_rescued_hurt_all_models.csv")
print("-", MET_DIR / "fusion_rescued_hurt_per_species.csv")

# McNemar summary (paired ResNet vs fused, one test per run)
valid_p = mcnemar_df["p_value"].dropna()
if len(valid_p) > 0:
    num_sig = int((valid_p < 0.05).sum())
    mean_chi2 = float(mcnemar_df["chi2_cc"].mean())
    # Positive => fused fixes more samples than it breaks, on average per run.
    mean_delta = float((mcnemar_df["n01_fused_only"] - mcnemar_df["n10_resnet_only"]).mean())
    print(f"\nMcNemar significant (p<0.05): {num_sig}/{len(valid_p)}")
    print(f"Mean McNemar chi2_cc: {mean_chi2:.3f}")
    print(f"Mean (fused-only correct - resnet-only correct): {mean_delta:.2f}")
else:
    print("\nMcNemar p-values not available (SciPy not installed or no discordant pairs).")


# Mean confusion matrices: raw-count confusion matrices per run, averaged over runs.
all_resnet_preds = np.stack(all_resnet_preds, axis=0)  # (n_models, n_samples)
all_fused_preds = np.stack(all_fused_preds, axis=0)
if all_resnet_preds.shape != all_fused_preds.shape:
    # Each ResNet run is paired with its fused counterpart by position.
    raise ValueError(
        f"ResNet/fused prediction shapes differ: {all_resnet_preds.shape} vs {all_fused_preds.shape}"
    )
n_cm_runs = all_resnet_preds.shape[0]
cm_labels = list(range(NUM_CLASSES))  # fixed label order so matrices align across runs

cm_res_sum = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=float)
cm_fus_sum = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=float)
for res_preds, fus_preds in zip(all_resnet_preds, all_fused_preds):
    cm_res_sum += confusion_matrix(y_true, res_preds, labels=cm_labels)
    cm_fus_sum += confusion_matrix(y_true, fus_preds, labels=cm_labels)

cm_res_mean = cm_res_sum / n_cm_runs
cm_fus_mean = cm_fus_sum / n_cm_runs

# File names keep their existing "10runs" suffix; the figure titles report the actual run count.
np.save(MET_DIR / "confusion_matrix_resnet_MEAN_10runs.npy", cm_res_mean)
np.save(MET_DIR / "confusion_matrix_fused_MEAN_10runs.npy", cm_fus_mean)

class_names = [idx_to_species[i] for i in range(NUM_CLASSES)]
save_confusion_matrix_png(cm_res_mean, class_names, FIG_DIR / "confusion_matrix_resnet_MEAN_10runs_norm.png", normalize=True,
                          title=f"Mean confusion (normalized) – ResNet ({n_cm_runs} runs)")
save_confusion_matrix_png(cm_fus_mean, class_names, FIG_DIR / "confusion_matrix_fused_MEAN_10runs_norm.png", normalize=True,
                          title=f"Mean confusion (normalized) – Fused ({n_cm_runs} runs)")

print("-", MET_DIR / "confusion_matrix_resnet_MEAN_10runs.npy")
print("-", MET_DIR / "confusion_matrix_fused_MEAN_10runs.npy")
print("-", FIG_DIR / "confusion_matrix_resnet_MEAN_10runs_norm.png")
print("-", FIG_DIR / "confusion_matrix_fused_MEAN_10runs_norm.png")

# Short printed summary.
# Uses `_mean_std` defined earlier in this cell; the identical redefinition here was redundant.
print("\n" + "=" * 80)
print("SUMMARY (mean ± std across runs)")
print("=" * 80)

summary_cols = [
    "resnet_top1", "fused_top1", "improvement_top1_pp",
    "resnet_ece", "fused_ece",
    "resnet_brier", "fused_brier",
    "resnet_logloss", "fused_logloss",
]
for col in summary_cols:
    m, s = _mean_std(results_df[col].values)
    print(f"{col:>22}: {m:.4f} ± {s:.4f}")

print("\nPer-species Top-1 improvement (mean ± std pp):")
for sp in WILD_SPECIES:
    if len(species_improvements[sp]) == 0:
        continue
    # Same helper as the run-level summary, so both sections share one std convention.
    m, s = _mean_std(species_improvements[sp])
    print(f"  {sp:<30} {m:+6.2f} ± {s:.2f}")


## 6) Optional quick figures

Two simple figures that are commonly useful in the thesis:
- distribution of Top-1 improvements across runs
- per-species mean improvement (mean ¬± std across runs)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# 1) Improvement distribution (one value per run, one row of results_df per run)
improvement = results_df["improvement_top1_pp"]
mean_improvement = improvement.mean()
n_hist_runs = len(results_df)  # actual number of evaluated runs, used in the title

fig, ax = plt.subplots(figsize=(7, 4))
ax.hist(improvement, bins=10, edgecolor="black", alpha=0.7)
ax.axvline(mean_improvement, linestyle="--", linewidth=2,
           label=f"Mean: {mean_improvement:.2f} pp")
ax.set_xlabel("Top-1 improvement (percentage points)")
ax.set_ylabel("Frequency")
ax.set_title(f"Fusion improvement across {n_hist_runs} ResNet runs")
ax.legend()
fig.tight_layout()
fig.savefig(FIG_DIR / "fusion_improvement_hist.png", dpi=250)
plt.show()

# 2) Per-species improvement bars: mean ± population std (np.std default) across runs;
#    species with no recorded runs are drawn as 0 ± 0.
species = WILD_SPECIES
means = [np.mean(species_improvements[sp]) if len(species_improvements[sp]) else 0.0 for sp in species]
stds  = [np.std(species_improvements[sp]) if len(species_improvements[sp]) else 0.0 for sp in species]

fig, ax = plt.subplots(figsize=(9, 4))
x = np.arange(len(species))
ax.bar(x, means, yerr=stds, capsize=4, alpha=0.7, edgecolor="black")
ax.axhline(0, linewidth=1)
# Tick labels use only the last word of each species name to keep them short.
ax.set_xticks(x, labels=[sp.split()[-1] for sp in species], rotation=45, ha="right")
ax.set_ylabel("Top-1 improvement (pp)")
ax.set_title("Per-species fusion improvement (mean ± std across runs)")
fig.tight_layout()
fig.savefig(FIG_DIR / "fusion_improvement_per_species.png", dpi=250)
plt.show()

print("Saved:")
print("-", FIG_DIR / "fusion_improvement_hist.png")
print("-", FIG_DIR / "fusion_improvement_per_species.png")


In [None]:
import numpy as np
import pandas as pd
from IPython.display import display
# --- prerequisites ---
# Per-sample species labels; assumed row-aligned with y_true and with the columns of the
# stacked prediction arrays (only the lengths are checked below).
species_arr = matched_df["species"].to_numpy()
y_true_arr = np.asarray(y_true, dtype=int)

all_resnet_preds_arr = np.asarray(all_resnet_preds)  # (n_runs, n_samples)
all_fused_preds_arr  = np.asarray(all_fused_preds)   # (n_runs, n_samples)

# Explicit checks: a silent mismatch here would mislabel every table built below.
if all_resnet_preds_arr.shape != all_fused_preds_arr.shape:
    raise ValueError(
        f"ResNet/fused prediction shapes differ: "
        f"{all_resnet_preds_arr.shape} vs {all_fused_preds_arr.shape}"
    )
if all_resnet_preds_arr.ndim != 2 or not (
    all_resnet_preds_arr.shape[1] == len(y_true_arr) == len(species_arr)
):
    raise ValueError(
        f"Sample counts disagree: preds {all_resnet_preds_arr.shape}, "
        f"y_true {len(y_true_arr)}, species {len(species_arr)}"
    )
# zip() below would silently drop runs, or pair tags with the wrong predictions, otherwise.
if len(run_tags) != all_resnet_preds_arr.shape[0]:
    raise ValueError(
        f"{len(run_tags)} run tags but {all_resnet_preds_arr.shape[0]} prediction rows"
    )

species_list = list(pd.unique(species_arr))
# Sample indices per species, computed once instead of once per run for each table.
species_idx = [(sp, np.flatnonzero(species_arr == sp)) for sp in species_list]

# A) Per-species Top-1 per run, and B) rescued/hurt per species per run, in one pass.
rows = []
rh_rows = []
for run_tag, y_r, y_f in zip(run_tags, all_resnet_preds_arr, all_fused_preds_arr):
    res_correct = y_r == y_true_arr
    fus_correct = y_f == y_true_arr

    rescued = ~res_correct & fus_correct   # ResNet wrong, fused right
    hurt    = res_correct & ~fus_correct   # ResNet right, fused wrong

    for sp, idx in species_idx:
        if idx.size == 0:  # e.g. a NaN species label never compares equal to itself
            continue
        n_sp = int(idx.size)

        r_acc = float(np.mean(res_correct[idx]))
        f_acc = float(np.mean(fus_correct[idx]))

        rows.append({
            "run_tag": run_tag,
            "species": sp,
            "n": n_sp,
            "cnn_acc": r_acc,                        # in [0,1]
            "fused_acc": f_acc,                      # in [0,1]
            "delta_pp": 100.0 * (f_acc - r_acc),     # percentage points
        })
        rh_rows.append({
            "run_tag": run_tag,
            "species": sp,
            "n": n_sp,
            "rescued_n": int(rescued[idx].sum()),
            "hurt_n": int(hurt[idx].sum()),
        })

per_run_species_df = pd.DataFrame(rows)
rescued_hurt_df = pd.DataFrame(rh_rows)

# Save the raw per-run tables (handy for debugging / appendix)
per_run_species_df.to_csv(MET_DIR / "per_species_top1_by_run.csv", index=False)
rescued_hurt_df.to_csv(MET_DIR / "rescued_hurt_by_run.csv", index=False)

print("Saved raw tables:")
print("-", MET_DIR / "per_species_top1_by_run.csv")
print("-", MET_DIR / "rescued_hurt_by_run.csv")


print("\nPreview: per-run per-species Top-1 table (first 20 rows)")
display(per_run_species_df.head(20))

print("\nPreview: per-run rescued/hurt table (first 20 rows)")
display(rescued_hurt_df.head(20))



In [None]:
import numpy as np
import pandas as pd
from IPython.display import display

def fmt_mean_std(x, is_percent=True):
    """Format ``mean ± std`` of *x* for a results table.

    Args:
        x: 1-D sequence of per-run values (e.g. accuracies in [0, 1]).
        is_percent: if True, scale by 100 and show one decimal; otherwise show
            three decimals unscaled.

    Returns:
        A string such as ``"87.3 ± 1.2"``. The std is the sample std (ddof=1);
        with a single value it is reported as 0.0.
    """
    m = float(np.mean(x))
    s = float(np.std(x, ddof=1)) if len(x) > 1 else 0.0
    if is_percent:
        return f"{100*m:.1f} ± {100*s:.1f}"
    return f"{m:.3f} ± {s:.3f}"

# -----------------------------
# Table: Per-species Top-1
# -----------------------------
# pandas' "std" aggregation is the sample std (ddof=1); it is NaN with a single run,
# which the string columns below display as 0.
g = per_run_species_df.groupby("species", sort=False)

table_top1 = g.agg(
    N=("n", "first"),
    cnn_mean=("cnn_acc", "mean"),
    cnn_std=("cnn_acc", "std"),
    fused_mean=("fused_acc", "mean"),
    fused_std=("fused_acc", "std"),
    delta_mean=("delta_pp", "mean"),
    delta_std=("delta_pp", "std"),
    delta_min=("delta_pp", "min"),
    delta_max=("delta_pp", "max"),
).reset_index()

# Relative improvement in %: delta_mean is in pp (= 100 * (fused - cnn)), so
# delta_mean / (100 * cnn_mean) == (fused - cnn) / cnn.
table_top1["rel_improv_pct"] = 100.0 * (table_top1["delta_mean"] / (100.0 * table_top1["cnn_mean"]))

# Pretty string columns (match your screenshot style)
table_top1["CNN (%)"]   = [f"{100*m:.1f} ± {100*s:.1f}" for m, s in zip(table_top1["cnn_mean"],  table_top1["cnn_std"].fillna(0))]
table_top1["Fused (%)"] = [f"{100*m:.1f} ± {100*s:.1f}" for m, s in zip(table_top1["fused_mean"], table_top1["fused_std"].fillna(0))]
table_top1["Δ (pp)"]    = [f"{m:+.2f} ± {s:.2f}"        for m, s in zip(table_top1["delta_mean"], table_top1["delta_std"].fillna(0))]
table_top1["Rel. improv. (%)"] = [f"{x:+.1f}" for x in table_top1["rel_improv_pct"]]
table_top1["Range (pp)"] = [f"{mn:+.2f} to {mx:+.2f}" for mn, mx in zip(table_top1["delta_min"], table_top1["delta_max"])]

# Sort like your table: by improvement descending
table_top1 = table_top1.sort_values("delta_mean", ascending=False).reset_index(drop=True)

# Overall mean row: per-run overall accuracy (species accuracies weighted by N),
# then mean ± std of those per-run values across runs.
overall_rows = []
for run_tag, run_block in per_run_species_df.groupby("run_tag", sort=False):
    w = run_block["n"].to_numpy()
    cnn = (run_block["cnn_acc"].to_numpy() * w).sum() / w.sum()
    fus = (run_block["fused_acc"].to_numpy() * w).sum() / w.sum()
    overall_rows.append({"run_tag": run_tag, "cnn": cnn, "fused": fus, "delta_pp": 100*(fus-cnn)})

overall_df = pd.DataFrame(overall_rows)

# Overall N = number of samples scored in one run (sum of per-species N). The old
# expression returned a single species' N whenever all species had equal counts.
n_total_top1 = int(per_run_species_df.groupby("run_tag", sort=False)["n"].sum().iloc[0])

# Sample std over runs; NaN with a single run, shown as 0 like the per-species rows.
delta_sd = overall_df["delta_pp"].std(ddof=1)
delta_sd = 0.0 if pd.isna(delta_sd) else float(delta_sd)

overall_row = {
    "species": "Mean (Overall)",
    "N": n_total_top1,
    "CNN (%)":   fmt_mean_std(overall_df["cnn"].to_numpy(), is_percent=True),
    "Fused (%)": fmt_mean_std(overall_df["fused"].to_numpy(), is_percent=True),
    "Δ (pp)":    f"{overall_df['delta_pp'].mean():+.2f} ± {delta_sd:.2f}",
    "Rel. improv. (%)": f"{(100*(overall_df['fused'].mean()/overall_df['cnn'].mean() - 1)):+.1f}",
    "Range (pp)": f"{overall_df['delta_pp'].min():+.2f} to {overall_df['delta_pp'].max():+.2f}",
}

table_top1_out = table_top1[["species","N","CNN (%)","Fused (%)","Δ (pp)","Rel. improv. (%)","Range (pp)"]].copy()
table_top1_out = pd.concat([table_top1_out, pd.DataFrame([overall_row])], ignore_index=True)

table_top1_out.to_csv(MET_DIR / "table_per_species_top1.csv", index=False)


# -----------------------------
# Table: Rescued / Hurt
# -----------------------------
g2 = rescued_hurt_df.groupby("species", sort=False)
table_rh = g2.agg(
    N=("n", "first"),
    rescued_mean=("rescued_n", "mean"),
    hurt_mean=("hurt_n", "mean"),
).reset_index()

table_rh["rescued_pct"] = 100.0 * (table_rh["rescued_mean"] / table_rh["N"])
table_rh["hurt_pct"]    = 100.0 * (table_rh["hurt_mean"] / table_rh["N"])
table_rh["net_pp"]      = table_rh["rescued_pct"] - table_rh["hurt_pct"]

# Format like screenshot: mean counts with (percentage)
table_rh["Rescued Count (%)"] = [f"{c:.1f} ({p:.1f}%)" for c, p in zip(table_rh["rescued_mean"], table_rh["rescued_pct"])]
table_rh["Hurt Count (%)"]    = [f"{c:.1f} ({p:.1f}%)" for c, p in zip(table_rh["hurt_mean"], table_rh["hurt_pct"])]
table_rh["Net Impact"]        = [f"{x:+.2f} pp" for x in table_rh["net_pp"]]

# Overall mean row (mean rescued/hurt totals across runs)
totals_by_run = rescued_hurt_df.groupby("run_tag").agg(
    rescued=("rescued_n","sum"),
    hurt=("hurt_n","sum"),
    N=("n","sum"),
).reset_index()

# Every run scores the same samples, so any run's total N is the overall N.
N_total = int(totals_by_run["N"].iloc[0])
rescued_mean = float(totals_by_run["rescued"].mean())
hurt_mean = float(totals_by_run["hurt"].mean())
rescued_pct = 100.0 * rescued_mean / N_total
hurt_pct = 100.0 * hurt_mean / N_total

overall_rh = {
    "species": "Overall Mean",
    "N": N_total,
    "Rescued Count (%)": f"{rescued_mean:.1f} ({rescued_pct:.1f}%)",
    "Hurt Count (%)": f"{hurt_mean:.1f} ({hurt_pct:.1f}%)",
    "Net Impact": f"{(rescued_pct - hurt_pct):+.2f} pp",
}

table_rh_out = table_rh[["species","N","Rescued Count (%)","Hurt Count (%)","Net Impact"]].copy()
table_rh_out = pd.concat([table_rh_out, pd.DataFrame([overall_rh])], ignore_index=True)

table_rh_out.to_csv(MET_DIR / "table_rescued_hurt.csv", index=False)

print("Saved aggregated tables:")
print("-", MET_DIR / "table_per_species_top1.csv")
print("-", MET_DIR / "table_rescued_hurt.csv")

print("\n✅ Table: Per-species Top-1 (final)")
display(table_top1_out.style.hide(axis="index"))

print("\n✅ Table: Rescued / Hurt (final)")
display(table_rh_out.style.hide(axis="index"))