In [5]:
import pandas as pd
df = pd.read_csv('tree_data_merged.csv')

In [6]:
import json
import os
import sys
import numpy as np
import re

# Add project root to path to import utilities
sys.path.append('/home/utkarsh/TreeInventorization')

from src.utils.mask_serialization import load_panorama_masks, deserialize_ultralytics_mask

# Define mask directory
MASK_DIR = '/home/utkarsh/TreeInventorization/data/masks'

def load_mask_for_row(row):
    """
    Load and deserialize mask for a given CSV row.
    
    Args:
        row: DataFrame row containing pano_id and image_path
    
    Returns:
        Deserialized mask data or None if not found
    """
    try:
        # Extract pano_id from the row
        pano_id = row['pano_id']
        image_path = row['image_path']
        
        # Construct mask JSON file path
        mask_json_path = os.path.join(MASK_DIR, f"{pano_id}_masks.json")
        
        # Check if mask file exists
        if not os.path.exists(mask_json_path):
            # Silently skip - mask file doesn't exist for this pano
            return None
        
        # Load all masks for this panorama
        mask_data = load_panorama_masks(mask_json_path)
        
        # Extract view number from image path (e.g., "view15" from filename)
        image_name = os.path.basename(image_path)
        view_match = re.search(r'view(\d+)', image_name)
        
        if not view_match:
            print(f"⚠️ Could not extract view number from: {image_name}")
            return None
            
        view_num = view_match.group(1)
        view_key = f"view_{view_num}"
        
        # Get masks for this specific view
        if view_key in mask_data.get('views', {}):
            view_masks = mask_data['views'][view_key]
            
            # Extract tree and box numbers from filename
            tree_box_match = re.search(r'tree(\d+)_box(\d+)', image_name)
            if tree_box_match:
                tree_num = tree_box_match.group(1)
                box_num = tree_box_match.group(2)
                tree_index = f"{tree_num}-{box_num}"
                
                # Find matching mask by tree_index
                for mask in view_masks:
                    if mask.get('tree_index') == tree_index:
                        # Deserialize the mask
                        deserialized = deserialize_ultralytics_mask(mask.get('mask_data', {}))
                        
                        return {
                            'pano_id': pano_id,
                            'image_path': image_path,
                            'view': view_key,
                            'tree_index': tree_index,
                            'confidence': mask.get('confidence', 0),
                            'mask': deserialized,
                            'original_mask_data': mask
                        }
        
        return None
        
    except Exception as e:
        print(f"❌ Error loading mask for row: {str(e)}")
        return None

# Process all rows in the dataframe
print(f"Processing {len(df)} rows from CSV...")
mask_results = []

for idx, row in df.iterrows():
    mask_data = load_mask_for_row(row)
    mask_results.append(mask_data)
    
    # Show progress every 50 rows
    if (idx + 1) % 50 == 0:
        print(f"Processed {idx + 1}/{len(df)} rows")

# Count successful loads
successful_loads = sum(1 for m in mask_results if m is not None)
print(f"\n✅ Successfully loaded masks for {successful_loads}/{len(df)} rows")

# Show sample of successfully loaded masks
if successful_loads > 0:
    print(f"\nSample of loaded masks:")
    count = 0
    for result in mask_results:
        if result and count < 5:
            print(f"  - Pano: {result['pano_id']}, View: {result['view']}, Tree: {result['tree_index']}, Conf: {result['confidence']:.3f}")
            if result['mask'] and result['mask'].get('xy'):
                print(f"    Polygon points: {len(result['mask']['xy'][0]) if result['mask']['xy'] else 0}")
            count += 1


Processing 441 rows from CSV...
Processed 50/441 rows
Processed 100/441 rows
Processed 150/441 rows
Processed 200/441 rows
Processed 250/441 rows
Processed 300/441 rows
Processed 350/441 rows
Processed 400/441 rows

✅ Successfully loaded masks for 441/441 rows

Sample of loaded masks:
  - Pano: zXNDa-3To6LUVVwCusNXAg, View: view_6, Tree: 0-0, Conf: 0.390
    Polygon points: 1000
  - Pano: zXNDa-3To6LUVVwCusNXAg, View: view_15, Tree: 0-0, Conf: 0.816
    Polygon points: 715
  - Pano: zn_BsL9aEKPuJ-IlqPplxg, View: view_6, Tree: 0-0, Conf: 0.309
    Polygon points: 831
  - Pano: zn_BsL9aEKPuJ-IlqPplxg, View: view_15, Tree: 0-0, Conf: 0.779
    Polygon points: 708
  - Pano: Eg2I2AlO8wExryXAspOx7g, View: view_4, Tree: 0-0, Conf: 0.341
    Polygon points: 1044


In [7]:
# Additional utility functions for mask processing

def get_mask_polygons(mask_data):
    """Extract polygon coordinates from deserialized mask."""
    if mask_data and mask_data.get('xy'):
        return mask_data['xy'][0]  # Return first polygon
    return None

def get_mask_binary(mask_data):
    """Get binary mask if available from deserialized data."""
    if mask_data and mask_data.get('data') is not None:
        return mask_data['data']
    return None

def get_mask_bbox(polygon):
    """Calculate bounding box from polygon points."""
    if polygon is not None and len(polygon) > 0:
        x_coords = polygon[:, 0]
        y_coords = polygon[:, 1]
        return {
            'x_min': float(np.min(x_coords)),
            'y_min': float(np.min(y_coords)),
            'x_max': float(np.max(x_coords)),
            'y_max': float(np.max(y_coords)),
            'width': float(np.max(x_coords) - np.min(x_coords)),
            'height': float(np.max(y_coords) - np.min(y_coords))
        }
    return None

# Example usage with loaded masks
if mask_results:
    first_result = next((r for r in mask_results if r), None)
    if first_result and first_result.get('mask'):
        mask = first_result['mask']
        
        print(f"\nAnalyzing first mask:")
        print(f"  Image: {os.path.basename(first_result['image_path'])}")
        print(f"  Tree index: {first_result['tree_index']}")
        
        # Get polygon
        polygon = get_mask_polygons(mask)
        if polygon is not None:
            print(f"  Polygon shape: {polygon.shape}")
            
            # Get bounding box
            bbox = get_mask_bbox(polygon)
            if bbox:
                print(f"  Bounding box: x=[{bbox['x_min']:.1f}, {bbox['x_max']:.1f}], y=[{bbox['y_min']:.1f}, {bbox['y_max']:.1f}]")
                print(f"  Box dimensions: {bbox['width']:.1f} x {bbox['height']:.1f}")
        
        # Check for binary mask
        binary_mask = get_mask_binary(mask)
        if binary_mask is not None:
            print(f"  Binary mask shape: {binary_mask.shape}")



Analyzing first mask:
  Image: zXNDa-3To6LUVVwCusNXAg_view6_tree0_box0.jpg
  Tree index: 0-0
  Polygon shape: (1000, 2)
  Bounding box: x=[62.0, 981.0], y=[0.0, 449.0]
  Box dimensions: 919.0 x 449.0
  Binary mask shape: (720, 1024)


In [8]:
# Add mask data to dataframe for easier processing
def add_masks_to_dataframe(df, mask_results):
    """
    Add mask data columns to the original dataframe.
    
    Args:
        df: Original dataframe
        mask_results: List of mask results from load_mask_for_row
    
    Returns:
        DataFrame with additional mask columns
    """
    df_with_masks = df.copy()
    
    # Initialize new columns
    df_with_masks['has_mask'] = False
    df_with_masks['mask_confidence'] = np.nan
    df_with_masks['mask_polygon'] = None
    df_with_masks['mask_bbox'] = None
    df_with_masks['mask_data'] = None
    
    # Add mask data to corresponding rows
    for idx, mask_result in enumerate(mask_results):
        if mask_result is not None:
            df_with_masks.at[idx, 'has_mask'] = True
            df_with_masks.at[idx, 'mask_confidence'] = mask_result.get('confidence', 0)
            
            # Get polygon if available
            if mask_result.get('mask') and mask_result['mask'].get('xy'):
                polygon = mask_result['mask']['xy'][0]
                df_with_masks.at[idx, 'mask_polygon'] = polygon
                
                # Calculate and store bbox
                bbox = get_mask_bbox(polygon)
                df_with_masks.at[idx, 'mask_bbox'] = bbox
            
            # Store full mask data
            df_with_masks.at[idx, 'mask_data'] = mask_result
    
    return df_with_masks

# Create enhanced dataframe with mask data
df_with_masks = add_masks_to_dataframe(df, mask_results)

# Show statistics
print(f"\nDataFrame with masks:")
print(f"  Total rows: {len(df_with_masks)}")
print(f"  Rows with masks: {df_with_masks['has_mask'].sum()}")
print(f"  Rows without masks: {(~df_with_masks['has_mask']).sum()}")

if df_with_masks['has_mask'].any():
    print(f"\n  Average confidence: {df_with_masks['mask_confidence'].mean():.3f}")
    print(f"  Min confidence: {df_with_masks['mask_confidence'].min():.3f}")
    print(f"  Max confidence: {df_with_masks['mask_confidence'].max():.3f}")

# Example: Filter rows with high confidence masks
high_conf_masks = df_with_masks[df_with_masks['mask_confidence'] > 0.7]
print(f"\n  High confidence masks (>0.7): {len(high_conf_masks)} rows")



DataFrame with masks:
  Total rows: 441
  Rows with masks: 441
  Rows without masks: 0

  Average confidence: 0.607
  Min confidence: 0.251
  Max confidence: 0.986

  High confidence masks (>0.7): 172 rows


In [22]:
df = df_with_masks

In [39]:
"""
Mask Usability Classifier — v2 (binary: usable vs not)
------------------------------------------------------
Changes in this version:
  • Keeps square input (default 224) and adds correct ImageNet normalization.
  • Joint image↔mask horizontal flip for Fusion so they stay aligned.
  • Backbone selector for the image branch: resnet18, resnet50, efficientnet_b0.
  • Always returns tensors from Dataset (no None), so default collate works.
  • Saves backbone + normalization in checkpoints and uses them at inference.

Variants supported:
  1) mask_cnn         — Tiny CNN on the mask only (kept for completeness)
  2) resnet_overlay   — Transfer learning on overlayed images (recommended)
  3) fusion           — Image + Mask late-fusion

Quick start:
    results = run_experiment(
        df_with_masks,
        variant="resnet_overlay",            # or "fusion"
        backbone="resnet50",                 # "resnet18" | "resnet50" | "efficientnet_b0"
        out_dir="/mnt/data/mask_quality/res50",
        image_size=224,
        num_epochs=30,
        batch_size=16,
    )

    infer = load_for_inference("/mnt/data/mask_quality/res50/best_overall.pt")
    is_ok, prob = predict_row(df_with_masks.iloc[0], infer)
    print(is_ok, prob)
"""
from __future__ import annotations
import os
import json
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

try:
    from torchvision import transforms
    from torchvision.models import (
        resnet18, ResNet18_Weights,
        resnet50, ResNet50_Weights,
        efficientnet_b0, EfficientNet_B0_Weights,
    )
except Exception as e:
    raise RuntimeError("torchvision is required. pip install torchvision")

# -----------------------------
# Utilities
# -----------------------------

def _safe_makedirs(p: str):
    os.makedirs(p, exist_ok=True)


def _to_bool_label(x: Any) -> int:
    """Map your 'correct' column values to 1/0. Accepts yes/no, 1/0, True/False."""
    if isinstance(x, str):
        s = x.strip().lower()
        if s in {"yes", "true", "1"}:  # usable
            return 1
        if s in {"no", "false", "0"}:  # not usable
            return 0
    if isinstance(x, (int, np.integer)):
        return int(x)
    if isinstance(x, bool):
        return int(x)
    raise ValueError(f"Unrecognized label value: {x}")


def extract_mask_array(mask_obj: Any, orig_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """Return a 2D uint8 mask {0,1} from various formats (np, dict->data, RLE, polygon)."""
    if isinstance(mask_obj, np.ndarray):
        m = mask_obj
        if m.ndim == 3 and m.shape[-1] == 1:
            m = m[..., 0]
        return (m > 0.5).astype(np.uint8)

    if not isinstance(mask_obj, dict):
        raise ValueError("mask_obj must be np.ndarray or dict")

    def _get(key):
        return mask_obj.get(key) if isinstance(mask_obj, dict) else None

    # Try nested dicts with direct 'data'
    for k in ["mask", "mask_data", "data", "rle"]:
        d = _get(k)
        if isinstance(d, dict) and "data" in d and isinstance(d["data"], np.ndarray):
            m = d["data"]
            if m.ndim == 3 and m.shape[-1] == 1:
                m = m[..., 0]
            return (m > 0.5).astype(np.uint8)

    # Try top-level 'data'
    if "data" in mask_obj and isinstance(mask_obj["data"], np.ndarray):
        m = mask_obj["data"]
        if m.ndim == 3 and m.shape[-1] == 1:
            m = m[..., 0]
        return (m > 0.5).astype(np.uint8)

    # Try RLE
    rle = None
    for k in ["mask", "mask_data", "rle"]:
        d = _get(k)
        if isinstance(d, dict) and "rle" in d:
            rle = d["rle"]
            break
    if rle is not None and isinstance(rle, dict) and "counts" in rle and "size" in rle:
        try:
            from pycocotools import mask as maskUtils  # type: ignore
            rr = {
                "size": rle["size"],
                "counts": rle["counts"] if isinstance(rle["counts"], (bytes, bytearray)) else rle["counts"].encode("ascii"),
            }
            m = maskUtils.decode(rr)
            if m.ndim == 3:
                m = m[:, :, 0]
            return (m > 0).astype(np.uint8)
        except Exception:
            pass

    # Try polygons
    polys = None
    for k in ["mask", "mask_data"]:
        d = _get(k)
        if isinstance(d, dict) and "xy" in d:
            polys = d["xy"]
            break
    if polys is not None:
        if isinstance(polys, np.ndarray):
            polys = [polys]
        if orig_shape is None:
            for k in ["mask", "mask_data"]:
                d = _get(k)
                if isinstance(d, dict) and "orig_shape" in d:
                    osz = d["orig_shape"]
                    if isinstance(osz, (list, tuple)) and len(osz) >= 2:
                        orig_shape = (int(osz[0]), int(osz[1]))
                        break
        if orig_shape is None:
            xs, ys = [], []
            for p in polys:
                xs.append(np.max(p[:, 0])); ys.append(np.max(p[:, 1]))
            H = int(max(ys) + 1); W = int(max(xs) + 1)
            orig_shape = (H, W)
        H, W = orig_shape
        m = np.zeros((H, W), dtype=np.uint8)
        try:
            import cv2
            for p in polys:
                cv2.fillPoly(m, [p.astype(np.int32)], 1)
            return m
        except Exception:
            for p in polys:
                p = p.astype(int)
                for y in range(p[:, 1].min(), p[:, 1].max() + 1):
                    xs = p[p[:, 1] == y][:, 0]
                    if len(xs) >= 2:
                        m[y, xs.min(): xs.max()+1] = 1
            return m

    raise ValueError("Could not extract a binary mask from provided mask_obj")


# -----------------------------
# Backbones for the image branch
# -----------------------------

def make_image_backbone(name: str = "resnet18", pretrained: bool = True):
    name = name.lower()
    if name == "resnet18":
        weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        m = resnet18(weights=weights)
        in_feats = m.fc.in_features
        m.fc = nn.Identity()
        norm = None
        if weights is not None:
            t = weights.transforms()
            norm = {"mean": list(t.mean), "std": list(t.std)}
        return m, in_feats, norm
    if name == "resnet50":
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        m = resnet50(weights=weights)
        in_feats = m.fc.in_features
        m.fc = nn.Identity()
        norm = None
        if weights is not None:
            t = weights.transforms()
            norm = {"mean": list(t.mean), "std": list(t.std)}
        return m, in_feats, norm
    if name == "efficientnet_b0":
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        m = efficientnet_b0(weights=weights)
        in_feats = m.classifier[1].in_features
        m.classifier = nn.Identity()
        norm = None
        if weights is not None:
            t = weights.transforms()
            norm = {"mean": list(t.mean), "std": list(t.std)}
        return m, in_feats, norm
    raise ValueError("unknown backbone; use resnet18 | resnet50 | efficientnet_b0")


# -----------------------------
# Dataset
# -----------------------------
class MaskQualityDataset(Dataset):
    def __init__(self,
                 df: pd.DataFrame,
                 use_overlay: bool = True,
                 use_mask: bool = True,
                 image_key: str = "image_path",
                 mask_key: str = "mask_data",
                 label_key: str = "correct",
                 image_size: int = 224,
                 image_norm: Optional[Dict[str, List[float]]] = None,
                 augment: bool = True,
                 joint_flip_for_fusion: bool = True):
        self.df = df.reset_index(drop=True)
        self.use_overlay = use_overlay
        self.use_mask = use_mask
        self.image_key = image_key
        self.mask_key = mask_key
        self.label_key = label_key
        self.image_size = int(image_size)
        self.image_norm = image_norm
        self.augment = augment
        self.joint_flip_for_fusion = joint_flip_for_fusion

        # Pre-build simple transforms that do NOT include random flip (we'll do it jointly)
        img_tf_list = [transforms.Resize((self.image_size, self.image_size))]
        if augment:
            img_tf_list.append(transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)], p=0.3))
        img_tf_list.append(transforms.ToTensor())
        if image_norm is not None:
            img_tf_list.append(transforms.Normalize(mean=image_norm["mean"], std=image_norm["std"]))
        self.img_tf = transforms.Compose(img_tf_list) if use_overlay else None

        self.mask_tf = None
        if use_mask:
            self.mask_tf = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((self.image_size, self.image_size)),
            ])

        self.labels = np.array([_to_bool_label(x) for x in self.df[self.label_key].values], dtype=np.int64)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        y = self.labels[idx]

        # Prepare placeholders so collate never sees None
        img_t = torch.zeros(3, self.image_size, self.image_size)
        mask_t = torch.zeros(1, self.image_size, self.image_size)

        # Load image/mask
        if self.use_overlay:
            try:
                with Image.open(row[self.image_key]) as im:
                    im = im.convert("RGB")
                    img_pil = im
            except Exception:
                img_pil = Image.new("RGB", (self.image_size, self.image_size))
        if self.use_mask:
            try:
                m = extract_mask_array(row[self.mask_key])
            except Exception:
                m = np.zeros((self.image_size, self.image_size), dtype=np.uint8)

        # Joint flip if both branches are used
        if self.augment and self.joint_flip_for_fusion and self.use_overlay and self.use_mask:
            if np.random.rand() < 0.5:
                img_pil = transforms.functional.hflip(img_pil)
                m = np.fliplr(m).copy()

        # Apply per-branch transforms
        if self.use_overlay:
            img_t = self.img_tf(img_pil)
        if self.use_mask:
            m = (m > 0).astype(np.uint8)
            mask_t = self.mask_tf(m)

        return img_t, mask_t, y


# -----------------------------
# Models
# -----------------------------
class TinyMaskCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(128, 1))
    def forward(self, x_mask: torch.Tensor) -> torch.Tensor:
        h = self.net(x_mask)
        return self.head(h).squeeze(1)  # logits


class ImageOverlayClassifier(nn.Module):
    def __init__(self, backbone: str = "resnet18", pretrained: bool = True):
        super().__init__()
        self.backbone, in_feats, self.norm = make_image_backbone(backbone, pretrained)
        self.head = nn.Sequential(nn.Dropout(0.3), nn.Linear(in_feats, 1))
        self.backbone_name = backbone
    def forward(self, x_img: torch.Tensor) -> torch.Tensor:
        f = self.backbone(x_img)
        return self.head(f).squeeze(1)


class FusionNet(nn.Module):
    def __init__(self, backbone: str = "resnet18", pretrained: bool = True):
        super().__init__()
        self.img_backbone, in_feats, self.norm = make_image_backbone(backbone, pretrained)
        self.mask_branch = TinyMaskCNN()
        self.head = nn.Sequential(nn.Dropout(0.3), nn.Linear(in_feats + 128, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 1))
        self.backbone_name = backbone
    def forward(self, x_img: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
        fi = self.img_backbone(x_img)
        fm = self.mask_branch.net(x_mask)
        fm = self.mask_branch.head[0:2](fm)  # Flatten + Dropout
        fm = fm.view(fm.size(0), -1)
        return self.head(torch.cat([fi, fm], dim=1)).squeeze(1)


# -----------------------------
# Training / Eval
# -----------------------------
@dataclass
class TrainConfig:
    variant: str = "resnet_overlay"  # 'mask_cnn' | 'resnet_overlay' | 'fusion'
    backbone: str = "resnet18"       # image branch backbone
    out_dir: str = "/mnt/data/mask_quality"
    num_epochs: int = 30
    batch_size: int = 16
    image_size: int = 224
    lr: float = 1e-4
    weight_decay: float = 1e-4
    early_stop_patience: int = 5
    num_folds: int = 5
    seed: int = 42


def _make_model(variant: str, backbone: str) -> nn.Module:
    if variant == "mask_cnn":
        return TinyMaskCNN()
    elif variant == "resnet_overlay":
        return ImageOverlayClassifier(backbone=backbone, pretrained=True)
    elif variant == "fusion":
        return FusionNet(backbone=backbone, pretrained=True)
    else:
        raise ValueError("Unknown variant")


def _get_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _compute_class_weight(y: np.ndarray) -> torch.Tensor:
    pos = (y == 1).sum(); neg = (y == 0).sum()
    if pos == 0:
        return torch.tensor(1.0)
    return torch.tensor(max(1.0, neg / max(1, pos)), dtype=torch.float32)


def _metrics_from_logits(logits: torch.Tensor, y_true: torch.Tensor, threshold: float = 0.5) -> Dict[str, Any]:
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    y = y_true.detach().cpu().numpy()
    yhat = (probs >= threshold).astype(int)
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
    acc = accuracy_score(y, yhat)
    prec, rec, f1, _ = precision_recall_fscore_support(y, yhat, average='binary', zero_division=0)
    try:
        auc = roc_auc_score(y, probs)
    except Exception:
        auc = float('nan')
    cm = confusion_matrix(y, yhat).tolist()
    return {"acc": acc, "prec": prec, "rec": rec, "f1": f1, "auc": auc, "confusion_matrix": cm, "threshold": threshold}


def _optimal_threshold(logits: np.ndarray, y_true: np.ndarray) -> float:
    from sklearn.metrics import roc_curve
    probs = 1 / (1 + np.exp(-logits))
    fpr, tpr, thr = roc_curve(y_true, probs)
    j = tpr - fpr
    i = int(np.argmax(j))
    return float(thr[i])


def _train_one_epoch(model, loader, device, optimizer, criterion):
    model.train(); total_loss = 0.0
    for img_t, mask_t, y in loader:
        y = y.float().to(device)
        if isinstance(model, FusionNet):
            img_t = img_t.to(device); mask_t = mask_t.to(device)
            logits = model(img_t, mask_t)
        elif isinstance(model, ImageOverlayClassifier):
            img_t = img_t.to(device)
            logits = model(img_t)
        else:
            mask_t = mask_t.to(device)
            logits = model(mask_t)
        loss = criterion(logits, y)
        optimizer.zero_grad(set_to_none=True); loss.backward(); optimizer.step()
        total_loss += float(loss.item()) * y.size(0)
    return total_loss / len(loader.dataset)


def _eval_logits(model, loader, device) -> Tuple[np.ndarray, np.ndarray]:
    model.eval(); ys, lg = [], []
    with torch.no_grad():
        for img_t, mask_t, y in loader:
            y = y.to(device)
            if isinstance(model, FusionNet):
                img_t = img_t.to(device); mask_t = mask_t.to(device)
                logits = model(img_t, mask_t)
            elif isinstance(model, ImageOverlayClassifier):
                img_t = img_t.to(device)
                logits = model(img_t)
            else:
                mask_t = mask_t.to(device)
                logits = model(mask_t)
            ys.append(y.cpu().numpy()); lg.append(logits.cpu().numpy())
    return np.concatenate(lg), np.concatenate(ys)


def run_experiment(df: pd.DataFrame,
                   variant: str = "resnet_overlay",
                   backbone: str = "resnet18",
                   out_dir: str = "/mnt/data/mask_quality",
                   num_epochs: int = 30,
                   batch_size: int = 16,
                   image_size: int = 224,
                   seed: int = 42,
                   num_folds: int = 5,
                   early_stop_patience: int = 5,
                   lr: float = 1e-4,
                   weight_decay: float = 1e-4) -> Dict[str, Any]:
    """Train + evaluate with K-fold CV. Returns metrics summary and saves best checkpoint."""
    _safe_makedirs(out_dir)

    # ensure labels are available
    y_all = np.array([_to_bool_label(x) for x in df['correct'].values], dtype=np.int64)

    from sklearn.model_selection import StratifiedKFold
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=seed)

    device = _get_device()
    fold_summaries: List[Dict[str, Any]] = []
    best_overall = {"f1": -1.0, "path": None}

    # Build a temp backbone to get normalization stats
    tmp_backbone, _, norm = make_image_backbone(backbone, pretrained=True)
    del tmp_backbone

    for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros_like(y_all), y_all), start=1):
        df_tr = df.iloc[tr_idx].reset_index(drop=True)
        df_va = df.iloc[va_idx].reset_index(drop=True)

        use_overlay = (variant != 'mask_cnn')
        use_mask = (variant != 'resnet_overlay')

        ds_tr = MaskQualityDataset(df_tr,
                                   use_overlay=use_overlay,
                                   use_mask=use_mask,
                                   image_size=image_size,
                                   image_norm=norm,
                                   augment=True,
                                   joint_flip_for_fusion=True)
        ds_va = MaskQualityDataset(df_va,
                                   use_overlay=use_overlay,
                                   use_mask=use_mask,
                                   image_size=image_size,
                                   image_norm=norm,
                                   augment=False,
                                   joint_flip_for_fusion=False)
        dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
        dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

        model = _make_model(variant, backbone).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        pos_weight = _compute_class_weight(np.array([_to_bool_label(x) for x in df_tr['correct'].values])).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        best_state = None; best_f1 = -1.0; best_thr = 0.5; no_improve = 0
        for epoch in range(1, num_epochs + 1):
            tr_loss = _train_one_epoch(model, dl_tr, device, optimizer, criterion)
            logits_va, y_va = _eval_logits(model, dl_va, device)
            thr = _optimal_threshold(logits_va, y_va)
            m = _metrics_from_logits(torch.tensor(logits_va), torch.tensor(y_va), threshold=thr)
            print(f"Fold {fold:02d} | Epoch {epoch:02d} | loss {tr_loss:.4f} | val F1 {m['f1']:.3f} | thr {thr:.3f}")
            if m['f1'] > best_f1 + 1e-4:
                best_f1 = m['f1']; best_thr = thr
                best_state = {k: v.cpu() for k, v in model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= early_stop_patience:
                    break

        assert best_state is not None, "Training did not improve at all. Check data."
        model.load_state_dict(best_state)
        logits_va, y_va = _eval_logits(model, dl_va, device)
        fold_metrics = _metrics_from_logits(torch.tensor(logits_va), torch.tensor(y_va), threshold=best_thr)

        ckpt_path = os.path.join(out_dir, f"{variant}_{backbone}_fold{fold}.pt")
        torch.save({
            "variant": variant,
            "backbone": backbone,
            "state_dict": model.state_dict(),
            "threshold": best_thr,
            "image_size": image_size,
            "norm_mean": (norm["mean"] if norm else None),
            "norm_std": (norm["std"] if norm else None),
        }, ckpt_path)
        fold_metrics.update({"ckpt_path": ckpt_path, "fold": fold})
        fold_summaries.append(fold_metrics)

        if fold_metrics['f1'] > best_overall['f1']:
            best_overall = {"f1": fold_metrics['f1'], "path": ckpt_path}

    def _agg(key):
        vals = [f[key] for f in fold_summaries if isinstance(f.get(key), (int, float)) and not math.isnan(f[key])]
        return float(np.mean(vals)), float(np.std(vals))

    acc_m, acc_s = _agg('acc'); prec_m, prec_s = _agg('prec'); rec_m, rec_s = _agg('rec')
    f1_m, f1_s = _agg('f1'); auc_m, auc_s = _agg('auc')

    model_card = {
        "variant": variant,
        "backbone": backbone,
        "folds": fold_summaries,
        "summary": {
            "acc_mean": acc_m, "acc_std": acc_s,
            "prec_mean": prec_m, "prec_std": prec_s,
            "rec_mean": rec_m, "rec_std": rec_s,
            "f1_mean": f1_m, "f1_std": f1_s,
            "auc_mean": auc_m, "auc_std": auc_s,
        },
        "best_overall": best_overall,
    }

    with open(os.path.join(out_dir, "model_card.json"), "w") as f:
        json.dump(model_card, f, indent=2)

    if best_overall['path'] is not None:
        best_dst = os.path.join(out_dir, "best_overall.pt")
        if os.path.abspath(best_dst) != os.path.abspath(best_overall['path']):
            import shutil; shutil.copy2(best_overall['path'], best_dst)
        model_card["best_overall"]["path"] = best_dst

    return model_card


# -----------------------------
# Inference helpers
# -----------------------------
@dataclass
class InferenceBundle:
    variant: str
    backbone: str
    model: nn.Module
    threshold: float
    image_size: int
    device: torch.device
    norm_mean: Optional[List[float]]
    norm_std: Optional[List[float]]


def load_for_inference(ckpt_path: str) -> InferenceBundle:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    variant = ckpt["variant"]
    backbone = ckpt.get("backbone", "resnet18")
    model = _make_model(variant, backbone)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    device = _get_device(); model.to(device)
    thr = float(ckpt.get("threshold", 0.5))
    image_size = int(ckpt.get("image_size", 224))
    norm_mean = ckpt.get("norm_mean", None)
    norm_std = ckpt.get("norm_std", None)
    return InferenceBundle(variant, backbone, model, thr, image_size, device, norm_mean, norm_std)


def _build_image_infer_tf(image_size: int, norm_mean: Optional[List[float]], norm_std: Optional[List[float]]):
    tfs = [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    if norm_mean is not None and norm_std is not None:
        tfs.append(transforms.Normalize(mean=norm_mean, std=norm_std))
    return transforms.Compose(tfs)


def predict_row(row: pd.Series, infer: InferenceBundle,
                image_key: str = "image_path",
                mask_key: str = "mask_data") -> Tuple[bool, float]:
    """Return (bool_is_usable, probability_of_usable)."""
    img_t = torch.zeros(1, 3, infer.image_size, infer.image_size, device=infer.device)
    mask_t = torch.zeros(1, 1, infer.image_size, infer.image_size, device=infer.device)

    if infer.variant != 'mask_cnn':
        with Image.open(row[image_key]) as im:
            im = im.convert("RGB")
            tf = _build_image_infer_tf(infer.image_size, infer.norm_mean, infer.norm_std)
            img_t = tf(im).unsqueeze(0).to(infer.device)

    if infer.variant != 'resnet_overlay':
        m = extract_mask_array(row[mask_key])
        m = (m > 0).astype(np.uint8)
        tfm = transforms.Compose([transforms.ToTensor(), transforms.Resize((infer.image_size, infer.image_size))])
        mask_t = tfm(m).unsqueeze(0).to(infer.device)

    with torch.no_grad():
        if infer.variant == 'fusion':
            logits = infer.model(img_t, mask_t)
        elif infer.variant == 'resnet_overlay':
            logits = infer.model(img_t)
        else:
            logits = infer.model(mask_t)
        prob = torch.sigmoid(logits).item()
        return (prob >= infer.threshold), float(prob)


if __name__ == "__main__":
    print("Import this module and call run_experiment(...). See docstring for usage.")


Import this module and call run_experiment(...). See docstring for usage.


In [40]:
# Custom collate to allow None for image or mask depending on variant
import torch

def collate_mask_quality(batch):
    imgs, masks, ys = zip(*batch)
    # y as tensor
    y_batch = torch.tensor(ys, dtype=torch.int64)

    # stack only when present; otherwise keep None
    img_batch = None
    mask_batch = None
    if any(t is not None for t in imgs):
        img_batch = torch.stack([t for t in imgs if t is not None])
    if any(t is not None for t in masks):
        mask_batch = torch.stack([t for t in masks if t is not None])

    return img_batch, mask_batch, y_batch



In [42]:
for v in ["resnet_overlay", "fusion"]:
    for b in ["resnet18", "resnet50", "efficientnet_b0"]:
        print("\n--------------------------------\n")
        print(f"Running {v} with {b}")
        run_experiment(df_with_masks, variant=v, backbone=b, out_dir=f"mask_quality/{v}_{b}")
        print("\n--------------------------------\n")

Fold 01 | Epoch 01 | loss 0.7405 | val F1 0.789 | thr 0.520
Fold 01 | Epoch 02 | loss 0.2454 | val F1 0.800 | thr 0.497
Fold 01 | Epoch 03 | loss 0.0814 | val F1 0.783 | thr 0.504
Fold 01 | Epoch 04 | loss 0.0534 | val F1 0.800 | thr 0.322
Fold 01 | Epoch 05 | loss 0.0451 | val F1 0.806 | thr 0.289
Fold 01 | Epoch 06 | loss 0.0463 | val F1 0.795 | thr 0.154
Fold 01 | Epoch 07 | loss 0.0410 | val F1 0.775 | thr 0.192
Fold 01 | Epoch 08 | loss 0.0297 | val F1 0.757 | thr 0.048
Fold 01 | Epoch 09 | loss 0.0204 | val F1 0.800 | thr 0.271
Fold 01 | Epoch 10 | loss 0.0169 | val F1 0.818 | thr 0.503
Fold 01 | Epoch 11 | loss 0.0096 | val F1 0.824 | thr 0.477
Fold 01 | Epoch 12 | loss 0.0150 | val F1 0.800 | thr 0.391
Fold 01 | Epoch 13 | loss 0.0117 | val F1 0.818 | thr 0.371
Fold 01 | Epoch 14 | loss 0.0083 | val F1 0.818 | thr 0.392
Fold 01 | Epoch 15 | loss 0.0038 | val F1 0.831 | thr 0.446
Fold 01 | Epoch 16 | loss 0.0118 | val F1 0.794 | thr 0.312
Fold 01 | Epoch 17 | loss 0.0037 | val F

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/utkarsh/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:08<00:00, 11.7MB/s]


Fold 01 | Epoch 01 | loss 0.8826 | val F1 0.758 | thr 0.497
Fold 01 | Epoch 02 | loss 0.6270 | val F1 0.776 | thr 0.556
Fold 01 | Epoch 03 | loss 0.2931 | val F1 0.789 | thr 0.259
Fold 01 | Epoch 04 | loss 0.1258 | val F1 0.853 | thr 0.447
Fold 01 | Epoch 05 | loss 0.0642 | val F1 0.781 | thr 0.288
Fold 01 | Epoch 06 | loss 0.0612 | val F1 0.778 | thr 0.067
Fold 01 | Epoch 07 | loss 0.0262 | val F1 0.812 | thr 0.111
Fold 01 | Epoch 08 | loss 0.0135 | val F1 0.831 | thr 0.605
Fold 01 | Epoch 09 | loss 0.0100 | val F1 0.825 | thr 0.478
Fold 02 | Epoch 01 | loss 0.8650 | val F1 0.759 | thr 0.564
Fold 02 | Epoch 02 | loss 0.5784 | val F1 0.806 | thr 0.720
Fold 02 | Epoch 03 | loss 0.3026 | val F1 0.781 | thr 0.313
Fold 02 | Epoch 04 | loss 0.1297 | val F1 0.828 | thr 0.819
Fold 02 | Epoch 05 | loss 0.0568 | val F1 0.800 | thr 0.531
Fold 02 | Epoch 06 | loss 0.0321 | val F1 0.842 | thr 0.729
Fold 02 | Epoch 07 | loss 0.0292 | val F1 0.807 | thr 0.705
Fold 02 | Epoch 08 | loss 0.0193 | val F

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /home/utkarsh/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:01<00:00, 14.1MB/s]


Fold 01 | Epoch 01 | loss 0.8687 | val F1 0.743 | thr 0.523
Fold 01 | Epoch 02 | loss 0.6598 | val F1 0.776 | thr 0.566
Fold 01 | Epoch 03 | loss 0.4875 | val F1 0.784 | thr 0.467
Fold 01 | Epoch 04 | loss 0.3492 | val F1 0.788 | thr 0.541
Fold 01 | Epoch 05 | loss 0.1985 | val F1 0.818 | thr 0.614
Fold 01 | Epoch 06 | loss 0.1431 | val F1 0.806 | thr 0.621
Fold 01 | Epoch 07 | loss 0.0905 | val F1 0.758 | thr 0.682
Fold 01 | Epoch 08 | loss 0.1018 | val F1 0.806 | thr 0.592
Fold 01 | Epoch 09 | loss 0.1088 | val F1 0.829 | thr 0.405
Fold 01 | Epoch 10 | loss 0.0497 | val F1 0.829 | thr 0.350
Fold 01 | Epoch 11 | loss 0.0483 | val F1 0.812 | thr 0.821
Fold 01 | Epoch 12 | loss 0.0301 | val F1 0.795 | thr 0.182
Fold 01 | Epoch 13 | loss 0.0232 | val F1 0.806 | thr 0.540
Fold 01 | Epoch 14 | loss 0.0423 | val F1 0.776 | thr 0.794
Fold 02 | Epoch 01 | loss 0.8554 | val F1 0.767 | thr 0.529
Fold 02 | Epoch 02 | loss 0.6519 | val F1 0.771 | thr 0.476
Fold 02 | Epoch 03 | loss 0.4700 | val F

In [43]:
"""
Model Comparison + Majority-Vote Ensemble (for Mask Usability Classifier v2)
-----------------------------------------------------------------------------
This helper script does two things:

1) Collects all model_cards from /mnt/data/mask_quality/** and prints a table
   comparing variants/backbones by mean F1 (and other metrics). Saves CSV.

2) Loads all trained checkpoints (best_overall.pt) and runs prediction on a
   given DataFrame (e.g., tree_data_merged.csv). It produces 6 predictions per
   row (for the 2 variants × 3 backbones), then computes a majority-vote
   ensemble and reports metrics. Saves predictions CSV.

Notes:
- Expects the v2 module available as `mask_usability_classifier` with
  load_for_inference(...) and predict_row(...).
- If your CSV stores `mask_data` as a string, we attempt to parse it via JSON
  or ast.literal_eval. If mask is still missing, Fusion models will skip that
  row (vote uses only available model predictions for that row).
- Ensemble majority uses a conservative tie-break: ratio > 0.5 → positive; else 0.

Usage example (in a notebook):

    import pandas as pd
    from model_compare_and_ensemble import (
        collect_model_cards, compare_models_table,
        load_all_infers, predict_all_models,
        compute_metrics, evaluate_ensemble
    )

    cards_df = compare_models_table(root="/mnt/data/mask_quality")
    display(cards_df)

    df = pd.read_csv("/mnt/data/tree_data_merged.csv")
    preds_df = predict_all_models(df, root="/mnt/data/mask_quality",
                                  image_key="image_path", mask_key="mask_data")

    # Metrics per individual model and the ensemble
    for col in [c for c in preds_df.columns if c.startswith("pred__")] + ["ensemble_pred"]:
        m = compute_metrics(preds_df["correct"], preds_df[col])
        print(col, m)

    # Quick ensemble report
    print(evaluate_ensemble(preds_df))

Outputs saved to:
- /mnt/data/mask_quality/model_comparison.csv
- /mnt/data/mask_quality/ensemble_predictions.csv
"""
from __future__ import annotations
import os
import json
import math
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from ast import literal_eval

# -----------------------------
# 1) Collect & compare model cards
# -----------------------------

def _find_all_cards(root: str) -> List[Dict[str, Any]]:
    cards: List[Dict[str, Any]] = []
    for dirpath, dirnames, filenames in os.walk(root):
        if "model_card.json" in filenames:
            p = os.path.join(dirpath, "model_card.json")
            try:
                with open(p, "r") as f:
                    card = json.load(f)
                card["_card_path"] = p
                cards.append(card)
            except Exception:
                pass
    return cards


def collect_model_cards(root: str = "/mnt/data/mask_quality") -> pd.DataFrame:
    cards = _find_all_cards(root)
    rows = []
    for c in cards:
        s = c.get("summary", {})
        rows.append({
            "dir": os.path.dirname(c.get("_card_path", "")),
            "variant": c.get("variant"),
            "backbone": c.get("backbone", "resnet18"),
            "f1_mean": s.get("f1_mean"), "f1_std": s.get("f1_std"),
            "acc_mean": s.get("acc_mean"), "prec_mean": s.get("prec_mean"),
            "rec_mean": s.get("rec_mean"), "auc_mean": s.get("auc_mean"),
            "best_ckpt": c.get("best_overall", {}).get("path"),
        })
    df = pd.DataFrame(rows)
    return df


def compare_models_table(root: str = "/mnt/data/mask_quality") -> pd.DataFrame:
    df = collect_model_cards(root)
    if not df.empty:
        df = df.sort_values(["f1_mean", "acc_mean"], ascending=[False, False])
        out = os.path.join(root, "model_comparison.csv")
        df.to_csv(out, index=False)
        print(f"Saved: {out}")
    else:
        print("No model_card.json files found under", root)
    return df


# -----------------------------
# 2) Load all models & batch predict
# -----------------------------

def _list_best_checkpoints(root: str) -> List[Tuple[str, str]]:
    """Return list of (name, best_ckpt_path). Name is derived from directory name.
    Looks for 'best_overall.pt' in each leaf directory.
    """
    pairs: List[Tuple[str, str]] = []
    for dirpath, dirnames, filenames in os.walk(root):
        if "best_overall.pt" in filenames:
            name = os.path.basename(dirpath)  # e.g., 'resnet_overlay_resnet50' or 'fusion_efficientnet_b0'
            pairs.append((name, os.path.join(dirpath, "best_overall.pt")))
    return sorted(pairs)


def load_all_infers(root: str = "/mnt/data/mask_quality") -> Dict[str, InferenceBundle]:
    pairs = _list_best_checkpoints(root)
    infers: Dict[str, InferenceBundle] = {}
    for name, ckpt in pairs:
        try:
            infers[name] = load_for_inference(ckpt)
        except Exception as e:
            print(f"Failed to load {name} @ {ckpt}: {e}")
    if not infers:
        raise RuntimeError(f"No best_overall.pt found under {root}")
    return infers


# -----------------------------
# 3) Robust mask parsing for CSVs
# -----------------------------

def _parse_mask_cell(cell: Any) -> Any:
    """Try to recover a mask object from CSV cell. Returns a dict/np array or None.
    Accepts: dict already, JSON string, python-literal string, or None.
    """
    if isinstance(cell, (dict, np.ndarray)):
        return cell
    if cell is None or (isinstance(cell, float) and np.isnan(cell)):
        return None
    if isinstance(cell, str):
        s = cell.strip()
        # Try JSON first
        try:
            return json.loads(s)
        except Exception:
            pass
        # Try python literal (lists, tuples, dicts, arrays as lists)
        try:
            return literal_eval(s)
        except Exception:
            pass
    return None


# -----------------------------
# 4) Predict all models and build ensemble
# -----------------------------

def predict_all_models(df: pd.DataFrame,
                       root: str = "/mnt/data/mask_quality",
                       image_key: str = "image_path",
                       mask_key: str = "mask_data",
                       require_mask_for_fusion: bool = False) -> pd.DataFrame:
    """Run every saved model on every row and return a predictions DataFrame.

    Columns added:
      - pred__{model_name}  ∈ {0,1}
      - prob__{model_name}  ∈ [0,1]
      - ensemble_pred, ensemble_prob (vote ratio)

    If a Fusion model needs a mask but none can be parsed for a row:
      - if require_mask_for_fusion=False (default), we skip that model for that row
        (its pred/prob are NaN), and the ensemble uses the remaining votes.
      - if True, we drop rows lacking masks before prediction.
    """
    infers = load_all_infers(root)
    df = df.copy()

    # Ensure mask column is parsed (if present)
    if mask_key in df.columns:
        df[mask_key] = df[mask_key].apply(_parse_mask_cell)

    # Optional: drop rows without mask if we must support Fusion
    if require_mask_for_fusion:
        # Keep only rows with a parseable mask
        has_mask = df[mask_key].apply(lambda x: x is not None)
        df = df[has_mask].reset_index(drop=True)

    # Prepare output columns
    for name in infers.keys():
        df[f"pred__{name}"] = np.nan
        df[f"prob__{name}"] = np.nan

    # Iterate rows
    for i in range(len(df)):
        row = df.iloc[i]
        for name, infer in infers.items():
            try:
                # If fusion and mask missing, skip this model for this row
                if infer.variant == 'fusion' and mask_key in df.columns and row[mask_key] is None:
                    continue
                yb, pr = predict_row(row, infer, image_key=image_key, mask_key=mask_key)
                df.at[i, f"pred__{name}"] = int(yb)
                df.at[i, f"prob__{name}"] = float(pr)
            except Exception as e:
                # Keep NaNs on failure
                pass

    # Ensemble: majority over available predictions per row
    pred_cols = [c for c in df.columns if c.startswith("pred__")]
    pred_mat = df[pred_cols].to_numpy(dtype=float)  # NaNs possible
    votes = np.nansum(pred_mat, axis=1)
    counts = np.sum(~np.isnan(pred_mat), axis=1)
    with np.errstate(invalid='ignore', divide='ignore'):
        ratio = votes / counts
    # Conservative tie-break: strictly greater than 0.5
    ens_pred = (ratio > 0.5).astype(float)
    # Where no models predicted (counts==0), set to NaN
    ens_pred[counts == 0] = np.nan

    df["ensemble_prob"] = ratio
    df["ensemble_pred"] = ens_pred

    # Save predictions
    out_csv = os.path.join(root, "ensemble_predictions.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv}")

    return df


# -----------------------------
# 5) Metrics
# -----------------------------

def _label_to_int(arr: pd.Series | np.ndarray) -> np.ndarray:
    return np.array([_to_bool_label(x) for x in arr], dtype=int)


def compute_metrics(y_true: pd.Series | np.ndarray, y_pred: pd.Series | np.ndarray) -> Dict[str, Any]:
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
    y = _label_to_int(y_true)
    yp = pd.Series(y_pred).astype(float).to_numpy()
    mask = ~np.isnan(yp)
    y = y[mask]
    yp = yp[mask].astype(int)
    if len(y) == 0:
        return {"n": 0, "acc": np.nan, "prec": np.nan, "rec": np.nan, "f1": np.nan, "cm": [[0,0],[0,0]]}
    acc = accuracy_score(y, yp)
    prec, rec, f1, _ = precision_recall_fscore_support(y, yp, average='binary', zero_division=0)
    cm = confusion_matrix(y, yp).tolist()
    return {"n": int(len(y)), "acc": acc, "prec": prec, "rec": rec, "f1": f1, "cm": cm}


def evaluate_ensemble(df_preds: pd.DataFrame, label_col: str = "correct") -> Dict[str, Any]:
    # Individual model metrics
    per_model = {}
    for c in [c for c in df_preds.columns if c.startswith("pred__")]:
        per_model[c] = compute_metrics(df_preds[label_col], df_preds[c])

    # Ensemble metrics
    ens = compute_metrics(df_preds[label_col], df_preds["ensemble_pred"])

    # Pack
    return {
        "ensemble": ens,
        "per_model": per_model,
        "n_rows": int(len(df_preds)),
        "n_rows_with_any_vote": int(np.sum(~np.isnan(df_preds["ensemble_pred"]))),
        "vote_counts_summary": df_preds[[c for c in df_preds.columns if c.startswith("pred__")]]
            .notna().sum(axis=1).describe().to_dict(),
    }


In [44]:
# 1) Compare all trained models (reads every model_card.json)
cards_df = compare_models_table(root="mask_quality")
display(cards_df)

# 2) Predict with all six models on your CSV
df = pd.read_csv("tree_data_merged.csv")
preds_df = predict_all_models(
    df, 
    root="mask_quality",
    image_key="image_path",
    mask_key="mask_data",          # if masks are strings, this function will try to parse them
    require_mask_for_fusion=False  # keep False to allow partial votes when masks are missing
)

# 3) Metrics per model and for the majority-vote ensemble
for col in [c for c in preds_df.columns if c.startswith("pred__")] + ["ensemble_pred"]:
    m = compute_metrics(preds_df["correct"], preds_df[col])
    print(col, m)

print("Ensemble summary:", evaluate_ensemble(preds_df))


Saved: mask_quality/model_comparison.csv


Unnamed: 0,dir,variant,backbone,f1_mean,f1_std,acc_mean,prec_mean,rec_mean,auc_mean,best_ckpt
2,mask_quality/fusion_resnet50,fusion,resnet50,0.86452,0.039754,0.902503,0.837991,0.90086,0.93663,mask_quality/fusion_resnet50/fusion_resnet50_f...
5,mask_quality/resnet_overlay_resnet50,resnet_overlay,resnet50,0.860515,0.038029,0.897983,0.821142,0.913763,0.920645,mask_quality/resnet_overlay_resnet50/resnet_ov...
4,mask_quality/fusion_resnet18,fusion,resnet18,0.859625,0.052442,0.902477,0.848594,0.873763,0.92805,mask_quality/fusion_resnet18/fusion_resnet18_f...
0,mask_quality/resnet_overlay_resnet18,resnet_overlay,resnet18,0.841338,0.026632,0.886645,0.823654,0.867527,0.926663,mask_quality/resnet_overlay_resnet18/resnet_ov...
3,mask_quality/resnet_overlay_efficientnet_b0,resnet_overlay,efficientnet_b0,0.837285,0.030429,0.884397,0.812307,0.867097,0.913359,mask_quality/resnet_overlay_efficientnet_b0/re...
1,mask_quality/fusion_efficientnet_b0,fusion,efficientnet_b0,0.837122,0.035479,0.884372,0.808824,0.867527,0.914839,mask_quality/fusion_efficientnet_b0/fusion_eff...


  ckpt = torch.load(ckpt_path, map_location="cpu")


Saved: mask_quality/ensemble_predictions.csv
pred__fusion_efficientnet_b0 {'n': 0, 'acc': nan, 'prec': nan, 'rec': nan, 'f1': nan, 'cm': [[0, 0], [0, 0]]}
pred__fusion_resnet18 {'n': 0, 'acc': nan, 'prec': nan, 'rec': nan, 'f1': nan, 'cm': [[0, 0], [0, 0]]}
pred__fusion_resnet50 {'n': 0, 'acc': nan, 'prec': nan, 'rec': nan, 'f1': nan, 'cm': [[0, 0], [0, 0]]}
pred__resnet_overlay_efficientnet_b0 {'n': 441, 'acc': 0.981859410430839, 'prec': np.float64(0.9673202614379085), 'rec': np.float64(0.9801324503311258), 'f1': np.float64(0.9736842105263158), 'cm': [[285, 5], [3, 148]]}
pred__resnet_overlay_resnet18 {'n': 441, 'acc': 0.981859410430839, 'prec': np.float64(0.9798657718120806), 'rec': np.float64(0.9668874172185431), 'f1': np.float64(0.9733333333333334), 'cm': [[287, 3], [5, 146]]}
pred__resnet_overlay_resnet50 {'n': 441, 'acc': 0.9886621315192744, 'prec': np.float64(0.9802631578947368), 'rec': np.float64(0.9867549668874173), 'f1': np.float64(0.9834983498349835), 'cm': [[287, 3], [2, 14