# Training stability and setup

## Batch Norm vs Group Norm

In [None]:
%load_ext autoreload
%autoreload 2
import utils.read_runs as rr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def plot_validation_metric_comparison(
    run_paths,
    run_labels,
    column="val_loss",
    as_metric="identity",
    best_mode="best_epoch",
    title=None,
    save_path=None,
    colors=None
):
    """
    Plot validation metrics for multiple runs with median and IQR.

    For each run, the per-trial curves stored in ``column`` are NaN-padded to
    a common length; the per-epoch median, the 25th-75th percentile band and
    the single best trial are then drawn on one shared axis. A run that
    cannot be loaded or processed is reported with a printed warning and
    skipped, so the remaining runs are still plotted.

    Args:
        run_paths: List of paths to optuna runs
        run_labels: List of labels for each run
        column: Column to plot (default "val_loss")
        as_metric: '1-minus' to convert loss -> Dice, 'identity' for raw values
        best_mode: 'final' or 'best_epoch' for selecting best trial
            (any value other than 'final' is treated as 'best_epoch')
        title: Plot title
        save_path: Where to save the plot
        colors: List of colors for each run

    Returns:
        None. The figure is shown and, if ``save_path`` is given, saved.
    """
    if colors is None:
        colors = [f"C{i}" for i in range(len(run_paths))]

    # The axis label and the "better" direction depend only on the arguments.
    # Resolving them before the loop keeps them defined even when every run
    # fails to load; assigning them inside the loop left them unbound then.
    if as_metric == "1-minus":
        y_label = "Validation Soft Dice"
        prefer = "max"
    else:
        y_label = f"Validation {column.replace('_', ' ').title()}"
        prefer = "min"

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    for run_path, label, color in zip(run_paths, run_labels, colors):
        try:
            df = rr.get_optuna_df(run_path)
            series = df[column].dropna()

            # Collect trials as arrays (variable lengths)
            trials = [np.asarray(v, dtype=float) for v in series]
            if not trials:
                raise ValueError(f"no '{column}' curves found")
            if as_metric == "1-minus":
                trials = [1.0 - v for v in trials]  # convert Dice loss -> Dice

            # NaN-pad so shorter trials simply drop out of the later epochs
            max_len = max(len(t) for t in trials)
            arr = np.full((len(trials), max_len), np.nan)
            for j, t in enumerate(trials):
                arr[j, :len(t)] = t

            epochs = np.arange(1, max_len + 1)
            median = np.nanmedian(arr, axis=0)
            q1 = np.nanpercentile(arr, 25, axis=0)
            q3 = np.nanpercentile(arr, 75, axis=0)

            # Plot median line and IQR fill
            ax.plot(epochs, median, color=color, label=f"{label} (median)", linewidth=2)
            ax.fill_between(epochs, q1, q3, color=color, alpha=0.2, label=f"{label} (IQR)")

            # Score every trial, then pick the best one in the preferred direction
            if best_mode == "final":
                # last non-NaN value of each row
                scores = np.array([
                    row[~np.isnan(row)][-1] if np.any(~np.isnan(row)) else np.nan
                    for row in arr
                ])
            else:  # best_epoch
                scores = np.nanmax(arr, axis=1) if prefer == "max" else np.nanmin(arr, axis=1)
            best_idx = int(np.nanargmax(scores)) if prefer == "max" else int(np.nanargmin(scores))

            ax.plot(epochs, arr[best_idx, :], color=color, ls="--", lw=1, alpha=0.7,
                    label=f"{label} (best trial)")

        except Exception as e:
            # Best-effort: one unreadable run must not prevent plotting the others
            print(f"Warning: Could not load data for {label}: {e}")

    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title or "Validation Metric Comparison")
    ax.grid(True, alpha=0.25)
    # Only build a legend when something was actually drawn
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper right')
    fig.tight_layout()

    if save_path:
        # Save through the figure handle rather than pyplot's "current figure"
        fig.savefig(save_path, dpi=200, bbox_inches='tight')

    plt.show()

# Create output directory (and any missing parents) for the normalization figures
fig_dir = Path("figs/normalization_comparison")
fig_dir.mkdir(parents=True, exist_ok=True)

# Compare BatchNorm vs GroupNorm for TH prediction
# NOTE(review): the run paths below are empty placeholders — fill them in with the
# actual optuna run directories before executing this cell.
th_runs = [
    "",  # TH BatchNorm
    ""   # TH GroupNorm + anneal
]
# Labels and colors are matched to th_runs by position
th_labels = ["TH — BatchNorm", "TH — GroupNorm + Anneal"]
th_colors = ["C0", "C1"]

# column / as_metric / best_mode are left at their defaults
# ("val_loss", "identity", "best_epoch")
plot_validation_metric_comparison(
    run_paths=th_runs,
    run_labels=th_labels,
    colors=th_colors,
    title="TH Prediction: BatchNorm vs GroupNorm Comparison",
    save_path=fig_dir / "th_batchnorm_vs_groupnorm.pdf"
)

# Compare BatchNorm vs GroupNorm for NF prediction
# NOTE(review): placeholder paths here as well.
nf_runs = [
    "",  # NF BatchNorm
    ""   # NF GroupNorm + anneal
]
# Labels and colors are matched to nf_runs by position
nf_labels = ["NF — BatchNorm", "NF — GroupNorm + Anneal"]
nf_colors = ["C2", "C3"]

plot_validation_metric_comparison(
    run_paths=nf_runs,
    run_labels=nf_labels,
    colors=nf_colors,
    title="NF Prediction: BatchNorm vs GroupNorm Comparison",
    save_path=fig_dir / "nf_batchnorm_vs_groupnorm.pdf"
)

## Annealing and learning-rate behavior

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import json
import os

def plot_trial_analysis_separate(
    trial_path, 
    figsize=(10, 6),
    save_dir=None,                 # e.g. ".../trial_1/plots"
    save_formats=("png",),         # e.g. ("png","pdf")
    dpi=300,
    show=True,
    close=False
):
    """
    Create 4 separate plots for a single trial and optionally save them.

    The plots are, in order: learning-rate schedule, loss-component weights,
    unweighted component losses and weighted component losses. All of them
    are driven by ``<trial_path>/metadata.json``.

    Args:
        trial_path: Trial directory containing ``metadata.json``; its last
            path component is used as the trial name in titles and file names.
        figsize: Size passed to every figure.
        save_dir: Directory to write the figures to. Nothing is saved when it
            is ``None`` or empty; the directory is created if missing.
        save_formats: File extensions to save (a single string is accepted
            and treated as one format).
        dpi: Resolution used when saving.
        show: Call ``plt.show()`` after each figure is built.
        close: Close each figure after it has been saved/shown.

    Returns:
        figures, weight_lists, saved_paths
        - figures: list[Figure]
        - weight_lists: dict of lists
        - saved_paths: dict[str, list[str]] mapping short name -> list of saved files

    Raises:
        FileNotFoundError: if ``metadata.json`` is not present in ``trial_path``.
    """
    # Load metadata
    metadata_path = os.path.join(trial_path, "metadata.json")
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"metadata.json not found in {trial_path}")
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    trial_name = os.path.basename(os.path.normpath(trial_path))
    figures = []
    saved_paths = {"lr": [], "weights": [], "comp": [], "comp_weighted": []}

    # A bare string would otherwise be iterated character by character
    if isinstance(save_formats, str):
        save_formats = (save_formats,)

    # Ensure output dir. Truthiness (not `is not None`) matches the save checks
    # below and avoids os.makedirs("") failing for an empty save_dir.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def _render(short_name, file_stem, draw):
        """Build one figure with `draw(ax)`, then save/show/close it per the arguments."""
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        result = draw(ax)
        fig.tight_layout()
        figures.append(fig)
        if save_dir:
            for fmt in save_formats:
                p = os.path.join(save_dir, f"{trial_name}_{file_stem}.{fmt}")
                fig.savefig(p, dpi=dpi, bbox_inches="tight")
                saved_paths[short_name].append(p)
        if show:
            plt.show()
        if close:
            plt.close(fig)
        return result

    # 1) Learning rate
    _render("lr", "learning_rate",
            lambda ax: plot_learning_rate(ax, metadata, trial_name))

    # 2) Loss weights (the per-epoch weights are reused by plot 4)
    weight_lists = _render("weights", "loss_weights",
                           lambda ax: plot_loss_weights(ax, metadata, trial_name))

    # 3) Component losses (unweighted)
    _render("comp", "component_losses",
            lambda ax: plot_component_losses(ax, metadata, trial_name))

    # 4) Weighted component losses
    _render("comp_weighted", "weighted_component_losses",
            lambda ax: plot_weighted_component_losses(ax, metadata, weight_lists, trial_name))

    return figures, weight_lists, saved_paths

def plot_learning_rate(ax, metadata, trial_name):
    """
    Draw the learning-rate schedule stored under ``metadata["lrs"]`` on ``ax``.

    The y-axis switches to log scale when the schedule is non-empty and every
    value is strictly positive. Each entry of ``metadata["loss_weight_history"]``
    whose ``"epoch"`` lies inside the schedule is marked with a dashed red
    vertical line and an "Anneal @epoch" annotation.
    """
    schedule = metadata.get("lrs", [])
    n_epochs = len(schedule)

    ax.plot(list(range(n_epochs)), schedule, 'b-', linewidth=2, label='Learning Rate')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.set_title(f'Learning Rate Schedule - {trial_name}')

    # Log scale is only valid for a non-empty, strictly positive schedule
    if n_epochs > 0 and all(value > 0 for value in schedule):
        ax.set_yscale('log')
    ax.grid(True, alpha=0.3)

    # Mark annealing events that fall inside the plotted range
    for change in metadata.get("loss_weight_history", []):
        at = change.get("epoch", None)
        if at is None or not (0 <= at < n_epochs):
            continue
        ax.axvline(x=at, color='red', linestyle='--', alpha=0.7)
        ax.text(at, schedule[at], f'Anneal\n@{at}', rotation=90, ha='right', va='bottom', fontsize=8)

    ax.legend()

def plot_loss_weights(ax, metadata, trial_name):
    """
    Reconstruct and plot the per-epoch loss-component weights of a trial.

    The weights start at ``metadata["w_dice"|"w_cldice"|"w_bce"]`` (defaults
    0.6 / 0.2 / 0.2) for all ``metadata["epochs_run"]`` epochs. Every entry of
    ``metadata["loss_weight_history"]`` then overrides the weights from its
    ``"epoch"`` onwards; a weight the event does not mention keeps the value
    in effect at that epoch. Events are applied in list order.

    Events whose epoch lies at or beyond ``epochs_run`` cannot affect the
    plotted range and are ignored for the weight curves (indexing the weight
    lists with such an epoch used to raise IndexError). A negative epoch is
    treated as taking effect from epoch 0. Every event with an epoch still
    gets a dashed vertical marker.

    Args:
        ax: Axes-like object to draw on.
        metadata: Trial metadata dict (see keys above).
        trial_name: Name used in the plot title.

    Returns:
        dict with keys 'w_dice', 'w_cldice', 'w_bce', each a list of length
        ``epochs_run`` holding the weight in effect at every epoch.
    """
    epochs_run = metadata.get("epochs_run", 0)
    epochs = list(range(epochs_run))
    history = metadata.get("loss_weight_history", [])

    weights = {
        'w_dice': [metadata.get("w_dice", 0.6)] * epochs_run,
        'w_cldice': [metadata.get("w_cldice", 0.2)] * epochs_run,
        'w_bce': [metadata.get("w_bce", 0.2)] * epochs_run,
    }

    for event in history:
        ep = event.get("epoch", None)
        if ep is None:
            continue
        start = max(ep, 0)
        if start >= epochs_run:
            # Event falls outside the recorded epochs: nothing to update
            continue
        for key, values in weights.items():
            # Look up the fallback only after the range check, so it cannot go out of bounds
            new_value = event.get(key, values[start])
            values[start:] = [new_value] * (epochs_run - start)

    ax.plot(epochs, weights['w_dice'], 'r-', linewidth=2, label='Dice Weight')
    ax.plot(epochs, weights['w_cldice'], 'g-', linewidth=2, label='clDice Weight')
    ax.plot(epochs, weights['w_bce'], 'b-', linewidth=2, label='BCE Weight')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Weight')
    ax.set_title(f'Loss Component Weights - {trial_name}')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Mark every weight-change event
    for event in history:
        ep = event.get("epoch", None)
        if ep is not None:
            ax.axvline(x=ep, color='black', linestyle='--', alpha=0.7)

    return weights

def plot_component_losses(ax, metadata, trial_name):
    """
    Plot the raw (unweighted) Dice, clDice and BCE losses per epoch.

    Reads ``metadata["comp_loss"]``, a list with one dict per epoch holding the
    keys "dice", "cldice" and "bce". When that list is missing or empty, only
    a placeholder message and the title are drawn.
    """
    heading = f'Component Losses (Unweighted) - {trial_name}'
    records = metadata.get("comp_loss", [])

    if not records:
        ax.text(0.5, 0.5, 'No component loss data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(heading)
        return

    xs = list(range(len(records)))
    # Extract every series first so a malformed record fails before anything is drawn
    curves = [
        ('r-', 'Dice Loss', [rec["dice"] for rec in records]),
        ('g-', 'clDice Loss', [rec["cldice"] for rec in records]),
        ('b-', 'BCE Loss', [rec["bce"] for rec in records]),
    ]
    for fmt, legend, ys in curves:
        ax.plot(xs, ys, fmt, linewidth=2, label=legend)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(heading)
    ax.grid(True, alpha=0.3)
    ax.legend()

def plot_weighted_component_losses(ax, metadata, weight_lists, trial_name):
    """
    Plot each component loss multiplied by its per-epoch weight.

    ``metadata["comp_loss"]`` supplies one dict per epoch ("dice", "cldice",
    "bce"); ``weight_lists`` supplies the matching weight lists under
    'w_dice', 'w_cldice' and 'w_bce'. Only the epochs covered by the losses
    and all three weight lists are drawn. A placeholder message is shown when
    either input is missing or empty.
    """
    heading = f'Weighted Component Losses - {trial_name}'
    records = metadata.get("comp_loss", [])

    if not records or not weight_lists:
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(heading)
        return

    dice_w = weight_lists['w_dice']
    cldice_w = weight_lists['w_cldice']
    bce_w = weight_lists['w_bce']

    # Number of epochs for which both a loss record and all weights exist
    shared = min(len(records), len(dice_w), len(cldice_w), len(bce_w))
    xs = list(range(shared))

    curves = [
        ('r-', 'Weighted Dice Loss', [records[k]["dice"] * dice_w[k] for k in xs]),
        ('g-', 'Weighted clDice Loss', [records[k]["cldice"] * cldice_w[k] for k in xs]),
        ('b-', 'Weighted BCE Loss', [records[k]["bce"] * bce_w[k] for k in xs]),
    ]
    for fmt, legend, ys in curves:
        ax.plot(xs, ys, fmt, linewidth=2, label=legend)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Weighted Loss')
    ax.set_title(heading)
    ax.grid(True, alpha=0.3)
    ax.legend()


In [None]:
from pathlib import Path

# Output directory for the per-trial annealing / learning-rate figures
anneal_fig_dir = Path("figs/anneal_vs_dice_comparison")
anneal_fig_dir.mkdir(parents=True, exist_ok=True)

# Render the four per-trial plots and save them as PDF.
# NOTE(review): trial_path is an empty placeholder — point it at a trial directory
# that contains metadata.json before running this cell.
figs, weights, saved = plot_trial_analysis_separate(
    trial_path="",
    save_dir=anneal_fig_dir,       # figures are written here
    save_formats=("pdf",),         # or ("png","pdf")
    show=True,
    close=True                     # close each figure after it has been shown/saved
)
# Mapping of plot short name -> list of files written
print(saved)


## Difference with and without peripherin

In [None]:
def compare_th_experiments_with_without_peripherin(
    figsize=(12, 8),
    save_path=None,
    with_peripherin_path="",
    without_peripherin_path="",
):
    """
    Compare NF prediction experiments with and without peripherin channel.

    The mean validation-loss curve of each run (averaged over its trials by
    ``get_mean_validation_curve``) is plotted on a shared axis, with a text
    box showing the final value of each mean curve.

    NOTE(review): the function name says "th" while the labels and title say
    "NF" — confirm which experiment this is meant to describe.

    Args:
        figsize (tuple): Figure size
        save_path (str): Optional path to save the plot
        with_peripherin_path (str): Optuna run directory of the experiment
            with the peripherin channel (37 channels). Defaults to the empty
            placeholder that used to be hard-coded in the body.
        without_peripherin_path (str): Optuna run directory of the experiment
            without the peripherin channel (36 channels). Same default.

    Returns:
        dict with 'with_peripherin_mean' and 'without_peripherin_mean' (lists,
        empty when a run could not be loaded) and 'final_with' /
        'final_without' (last value of each curve, NaN when the curve is empty).
    """
    # Get validation curves for both experiments
    with_peripherin_mean = get_mean_validation_curve(with_peripherin_path)
    without_peripherin_mean = get_mean_validation_curve(without_peripherin_path)

    # Create the comparison plot
    fig, ax = plt.subplots(figsize=figsize, dpi=150)

    # Each curve gets its own 1-based epoch axis; the two runs may differ in length
    epochs_with = list(range(1, len(with_peripherin_mean) + 1))
    epochs_without = list(range(1, len(without_peripherin_mean) + 1))

    ax.plot(epochs_with, with_peripherin_mean, 'b-', linewidth=2.5,
            label='NF with Peripherin (37 channels)', alpha=0.8)
    ax.plot(epochs_without, without_peripherin_mean, 'r-', linewidth=2.5,
            label='NF without Peripherin (36 channels)', alpha=0.8)

    # Formatting
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Mean Validation Loss', fontsize=12)
    ax.set_title('NF Prediction: With vs Without Peripherin Channel', fontsize=14, pad=20)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=11, loc='upper right')

    # Final value of each mean curve (NaN when the run produced no curve)
    final_with = with_peripherin_mean[-1] if len(with_peripherin_mean) > 0 else float('nan')
    final_without = without_peripherin_mean[-1] if len(without_peripherin_mean) > 0 else float('nan')

    # Add text box with final values
    textstr = f'Final validation loss:\nWith Peripherin: {final_with:.4f}\nWithout Peripherin: {final_without:.4f}'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props)

    fig.tight_layout()

    if save_path:
        # Save through the figure handle rather than pyplot's "current figure"
        fig.savefig(save_path, dpi=200, bbox_inches='tight')

    plt.show()

    return {
        'with_peripherin_mean': with_peripherin_mean,
        'without_peripherin_mean': without_peripherin_mean,
        'final_with': final_with,
        'final_without': final_without
    }

def get_mean_validation_curve(run_path):
    """
    Get the mean validation curve across all trials in a run.

    Each non-null entry of the run's ``val_loss`` column is treated as one
    trial's per-epoch curve. Curves may be lists, tuples or NumPy arrays
    (previously only ``list`` was accepted and anything else was silently
    dropped); entries that are empty or not one-dimensional are skipped.
    Trials of different length are NaN-padded, so later epochs are averaged
    over the trials that actually reached them.

    This is best-effort: any failure while loading or processing the run is
    printed and an empty list is returned instead of raising.

    Args:
        run_path (str): Path to the optuna run directory

    Returns:
        list: Mean validation loss at each epoch across all trials
            (empty when no usable data was found)
    """
    try:
        # Load the run data
        df = rr.get_optuna_df(run_path)

        if df.empty:
            print(f"Warning: No data found in {run_path}")
            return []

        # Extract validation loss series
        val_loss_series = df['val_loss'].dropna()

        if val_loss_series.empty:
            print(f"Warning: No validation loss data found in {run_path}")
            return []

        # Convert to list of arrays (each trial can have different length)
        trials = []
        for curve in val_loss_series:
            if not isinstance(curve, (list, tuple, np.ndarray)):
                continue
            values = np.asarray(curve, dtype=float)
            # Skip empty or non-1-D entries so one bad trial does not abort the run
            if values.ndim != 1 or values.size == 0:
                continue
            trials.append(values)

        if not trials:
            print(f"Warning: No valid trials found in {run_path}")
            return []

        # Find the maximum length across all trials
        max_epochs = max(len(trial) for trial in trials)

        # Create array to hold all trials (padded with NaN)
        all_trials = np.full((len(trials), max_epochs), np.nan)
        for i, trial in enumerate(trials):
            all_trials[i, :len(trial)] = trial

        # Calculate mean across trials at each epoch (ignoring NaN values)
        mean_curve = np.nanmean(all_trials, axis=0)

        # Remove trailing NaN values
        valid_indices = ~np.isnan(mean_curve)
        if np.any(valid_indices):
            last_valid = np.where(valid_indices)[0][-1]
            mean_curve = mean_curve[:last_valid + 1]

        print(f"Processed {len(trials)} trials from {run_path}")
        print(f"Mean curve length: {len(mean_curve)} epochs")

        return mean_curve.tolist()

    except Exception as e:
        # Deliberate best-effort boundary: report and let the caller plot what it has
        print(f"Error processing {run_path}: {e}")
        return []

# Run the comparison (the figure is also written to the given PDF path)
results = compare_th_experiments_with_without_peripherin(
    save_path="figs/th_with_without_peripherin_comparison.pdf"
)

# Print some summary statistics
print("\nSummary:")
print(f"With Peripherin - Final loss: {results['final_with']:.4f}")
print(f"Without Peripherin - Final loss: {results['final_without']:.4f}")
# The final values are NaN when a run produced no curve; only compare when both exist
if not np.isnan(results['final_with']) and not np.isnan(results['final_without']):
    difference = results['final_with'] - results['final_without']
    print(f"Difference (With - Without): {difference:.4f}")
    # Positive difference = higher final loss with peripherin.
    # A difference of exactly 0 falls into the else branch.
    if difference > 0:
        print("Without peripherin performs better (lower loss)")
    else:
        print("With peripherin performs better (lower loss)")

# Peripherin plots

## Anneal vs no anneal

In [None]:

import json

def plot_validation_metric_with_continuation(
    run_path,
    continuation_paths=None,  
    dice_only_run_path=None,  
    column="val_loss",
    as_metric="identity",
    best_mode="best_hard_dice",  
    title=None,
    save_path=None,
):
    """
    Plot validation metric with optional continuation phases and dice-only comparison.

    Draws the per-epoch median and IQR over all trials of ``run_path``, the
    curve of the selected best trial, the median curve of an optional
    standalone dice-only run, and optional continuation curves appended after
    the last epoch of the best trial. A text box summarises the best value of
    each curve (lowest for losses, highest when ``as_metric="1-minus"``).

    Args:
        run_path: Path of the original optuna run.
        continuation_paths: dict mapping continuation labels to their directories;
            each directory is expected to hold ``metadata_dice_continuation.json``
            with ``training_curves.val_losses``. Unreadable or missing entries
            are reported and skipped.
        dice_only_run_path: path to standalone dice-only run
            (e.g., peripherin_scratch_2025-08-19_08-20-01)
        column: Column of the run dataframe holding the per-trial curves.
        as_metric: '1-minus' to plot 1 - value (higher is better), anything
            else for raw values (lower is better).
        best_mode: 'best_hard_dice' selects the trial with the highest
            ``val_dice_at_best_thr`` in the peripherin dataframe, falling back
            to best-epoch selection when that is unavailable or does not map
            to a loaded curve; 'final' uses the last value of each trial;
            any other value uses the best epoch of each trial.
        title: Plot title (defaults to "<run dir name> with continuations").
        save_path: Optional path to save the figure.

    Returns:
        (fig, ax)

    Raises:
        ValueError: if the run contains no curves in ``column``.
    """
    # --- Original run ---------------------------------------------------------
    df = rr.get_optuna_df(run_path)
    series = df[column].dropna()

    # Collect trials as arrays (variable lengths)
    trials = [np.asarray(v, dtype=float) for v in series]
    if not trials:
        raise ValueError(f"No '{column}' curves found in {run_path}")

    if as_metric == "1-minus":
        trials = [1.0 - v for v in trials]
        y_label = "Validation Soft Dice"
        prefer = "max"
    else:
        y_label = f"Validation {column.replace('_', ' ').title()}"
        prefer = "min"  # for loss, we want minimum

    # "Best" depends on the direction of the plotted quantity
    best_of = np.nanmax if prefer == "max" else np.nanmin
    extreme = "highest" if prefer == "max" else "lowest"

    arr = _pad_curves(trials)
    max_len = arr.shape[1]

    # --- Continuation curves --------------------------------------------------
    continuation_data = {}
    if continuation_paths:
        for cont_label, cont_path in continuation_paths.items():
            try:
                cont_meta_path = Path(cont_path) / "metadata_dice_continuation.json"
                if not cont_meta_path.exists():
                    print(f"Warning: {cont_meta_path} not found; skipping {cont_label}")
                    continue
                cont_meta = json.loads(cont_meta_path.read_text())
                cont_losses = cont_meta.get("training_curves", {}).get("val_losses", [])
                if cont_losses:
                    if as_metric == "1-minus":
                        cont_losses = [1.0 - v for v in cont_losses]
                    continuation_data[cont_label] = np.array(cont_losses)
                    print(f"Loaded {len(cont_losses)} continuation epochs for {cont_label}")
            except Exception as e:
                print(f"Warning: Could not load continuation data for {cont_label}: {e}")

    # --- Standalone dice-only run (median over its trials) ----------------------
    dice_only_data = None
    if dice_only_run_path:
        try:
            dice_df = rr.get_optuna_df(dice_only_run_path)
            dice_series = dice_df[column].dropna()
            dice_trials = [np.asarray(v, dtype=float) for v in dice_series]
            if as_metric == "1-minus":
                dice_trials = [1.0 - v for v in dice_trials]

            if dice_trials:
                # Calculate median curve instead of selecting best trial
                dice_only_data = np.nanmedian(_pad_curves(dice_trials), axis=0)
                print(f"Loaded dice-only median curve: {len(dice_only_data)} epochs from {len(dice_trials)} trials")
            else:
                print(f"Warning: No '{column}' curves found in dice-only run")

        except Exception as e:
            print(f"Warning: Could not load dice-only run data: {e}")

    # --- Best-trial selection ---------------------------------------------------
    if best_mode == "best_hard_dice":
        best_idx = None
        # Use the same selection as get_best_trial() - based on val_dice_at_best_thr
        periph_df = rr.get_peripherin_df(run_path)
        if periph_df is not None and "val_dice_at_best_thr" in periph_df.columns:
            best_trial_name = periph_df["val_dice_at_best_thr"].astype(float).idxmax()
            name = str(best_trial_name)
            if name.startswith("trial_"):
                # Extract trial number from name like "trial_15".
                # NOTE(review): this assumes the trial number equals the row
                # position among the loaded curves — confirm that no trial
                # is dropped by the dropna() above.
                try:
                    candidate = int(name.split("_")[1])
                except ValueError:
                    candidate = None
            else:
                # Fallback to index position
                candidate = periph_df.index.get_loc(best_trial_name)

            # get_loc may return a slice/mask for duplicate labels, and the
            # trial number may exceed the number of loaded curves.
            if isinstance(candidate, (int, np.integer)) and 0 <= candidate < len(trials):
                best_idx = int(candidate)
                print(f"Selected trial {best_trial_name} (index {best_idx}) based on val_dice_at_best_thr")
            else:
                print(f"Warning: trial {best_trial_name} does not map to a loaded curve")

        if best_idx is None:
            # Fallback to best epoch method
            best_idx = _select_best_row(arr, prefer, "best_epoch")
            print(f"Fallback: Selected trial {best_idx} based on best epoch")
    else:
        best_idx = _select_best_row(arr, prefer, best_mode)

    best_curve = arr[best_idx, :]

    # Find the actual end point of the best trial (where it stopped)
    best_trial_valid_epochs = np.where(~np.isnan(best_curve))[0]
    if len(best_trial_valid_epochs) > 0:
        best_trial_end_epoch = best_trial_valid_epochs[-1] + 1  # +1 for 1-based indexing
        best_trial_final_loss = best_curve[best_trial_valid_epochs[-1]]
    else:
        best_trial_end_epoch = 1
        best_trial_final_loss = np.nan

    print(f"Best trial ended at epoch {best_trial_end_epoch}")

    # Calculate statistics for original runs
    epochs = np.arange(1, max_len + 1)
    median = np.nanmedian(arr, axis=0)
    q1 = np.nanpercentile(arr, 25, axis=0)
    q3 = np.nanpercentile(arr, 75, axis=0)

    # --- Plot -------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(10, 4), dpi=150)

    # Original run curves
    ax.plot(epochs, median, color="C0", label="Median (original)", linewidth=2)
    ax.fill_between(epochs, q1, q3, color="C0", alpha=0.2, label="IQR (original)")
    ax.plot(epochs, best_curve, color="C1", ls="--", lw=1.5, label=f"Best trial (#{best_idx})")

    # Add dice-only standalone curve if provided
    if dice_only_data is not None:
        dice_epochs = np.arange(1, len(dice_only_data) + 1)
        ax.plot(dice_epochs, dice_only_data, color="C4", linewidth=2,
                label="Dice-only (median)", linestyle="-.")

    # Add continuation data starting from the actual end of the best trial
    colors = ["C2", "C3"]  # Reserve C4 for dice-only standalone

    for i, (cont_label, cont_curve) in enumerate(continuation_data.items()):
        if len(cont_curve) > 0:
            # Start continuation from where the best trial actually ended
            cont_epochs = np.arange(best_trial_end_epoch + 1, best_trial_end_epoch + 1 + len(cont_curve))
            color = colors[i % len(colors)]

            # Connect the curves at transition point
            if not np.isnan(best_trial_final_loss):
                ax.plot([best_trial_end_epoch, best_trial_end_epoch + 1],
                        [best_trial_final_loss, cont_curve[0]],
                        color=color, linestyle=":", alpha=0.7)

            # Plot continuation (single curve, no ribbon since it's just one trial)
            ax.plot(cont_epochs, cont_curve, color=color, linewidth=2,
                    label=f"{cont_label} continuation")

            # Mark transition point
            if i == 0:  # Only add the label once
                ax.axvline(best_trial_end_epoch + 0.5, color="red", linestyle="--",
                           alpha=0.7, linewidth=1, label="Dice-only start")

    # --- Text box with the best value of each curve -----------------------------
    final_texts = []

    # Best value of the selected trial
    if not np.isnan(best_trial_final_loss):
        final_texts.append(
            f"Pre-continuation {extreme}: {best_of(best_curve[best_trial_valid_epochs]):.4f}"
        )

    # Dice-only run best value
    if dice_only_data is not None and len(dice_only_data) > 0:
        final_texts.append(f"Dice-only {extreme}: {best_of(dice_only_data):.4f}")

    # Continuation(s) best values; name them only when there are several
    several = len(continuation_data) > 1
    for cont_label, cont_curve in continuation_data.items():
        if len(cont_curve) > 0:
            prefix = f"Continuation ({cont_label})" if several else "Continuation"
            final_texts.append(f"{prefix} {extreme}: {best_of(cont_curve):.4f}")

    if final_texts:
        textstr = "Best validation:\n" + "\n".join(final_texts)
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
        ax.text(
            0.3, 0.98, textstr,
            transform=ax.transAxes,
            fontsize=9,
            verticalalignment='top',
            horizontalalignment='left',
            bbox=props
        )

    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title or f"{Path(run_path).name} with continuations")
    ax.grid(True, alpha=0.25)
    ax.legend(loc="upper right", frameon=True)
    fig.tight_layout()

    if save_path:
        # Save through the figure handle rather than pyplot's "current figure"
        fig.savefig(save_path, dpi=200, bbox_inches='tight')

    plt.show()
    return fig, ax


def _pad_curves(curves):
    """Stack variable-length 1-D curves into a NaN-padded (n_curves, max_len) array."""
    width = max(len(c) for c in curves)
    padded = np.full((len(curves), width), np.nan)
    for row, curve in enumerate(curves):
        padded[row, :len(curve)] = curve
    return padded


def _select_best_row(arr, prefer, mode):
    """Return the row of `arr` with the best final value ('final') or best single epoch (otherwise)."""
    if mode == "final":
        # last non-NaN value of each row
        scores = np.array([
            row[~np.isnan(row)][-1] if np.any(~np.isnan(row)) else np.nan
            for row in arr
        ])
    else:  # best_epoch
        scores = np.nanmax(arr, axis=1) if prefer == "max" else np.nanmin(arr, axis=1)
    return int(np.nanargmax(scores)) if prefer == "max" else int(np.nanargmin(scores))

# Create output directory (and any missing parents) for the continuation figures
anneal_fig_dir = Path("figs/anneal_vs_dice_comparison")
anneal_fig_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): every run_path / continuation directory / dice_only_run_path below
# is an empty placeholder — fill them in before running this cell.

# 1. Scratch: Annealed vs Dice-only
plot_validation_metric_with_continuation(
    run_path="",
    continuation_paths={
        "dice_only": ""
    },
    dice_only_run_path="",
    column="val_loss",
    best_mode="best_hard_dice",
    title="PRPH: Scratch — Annealed vs Dice-only",
    save_path=anneal_fig_dir / "scratch_annealed_vs_dice_only.pdf"
)

# 2. TH Transfer: Annealed vs Dice-only  
plot_validation_metric_with_continuation(
    run_path="",
    continuation_paths={
        "dice_only": ""
    },
    dice_only_run_path="",
    column="val_loss",
    best_mode="best_hard_dice",
    title="PRPH: TH Transfer — Annealed vs Dice-only",
    save_path=anneal_fig_dir / "th_transfer_annealed_vs_dice_only.pdf"
)

# 3. NF Transfer: Annealed vs Dice-only
plot_validation_metric_with_continuation(
    run_path="",
    continuation_paths={
        "dice_only": ""
    },
    dice_only_run_path="",
    column="val_loss",
    best_mode="best_hard_dice",
    title="PRPH: NF Transfer — Annealed vs Dice-only",
    save_path=anneal_fig_dir / "nf_transfer_annealed_vs_dice_only.pdf"
)

## Anneal vs soft_dice difference plots

In [None]:
# === PRPH annealed(continued) vs dice-only: delta histogram + zoomed scatter with highlights ===
# Outputs:
#   figs/anneal_vs_dice/delta_hist_{ctx}_{metric}.pdf
#   figs/anneal_vs_dice/scatter_{ctx}_{metric}_zoom.pdf
#   figs/anneal_vs_dice/highlighted_{ctx}_{metric}.csv
#
# Contexts use "*_continued" as the annealed variant, per your preference.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# -------------------- CONFIG --------------------
# Per-sample metrics table.
# NOTE(review): presumably needs the columns sample_idx, model_label and the
# METRIC column, since _pair_xy reads those — confirm against the loading code.
CSV_PATH = "per_sample_metrics.csv"
METRIC   = "soft_dice"    # or "cldice"
# Display name used in axis labels: "clDice" for cldice, otherwise the title-cased metric name
METRIC_LABEL = "clDice" if METRIC == "cldice" else METRIC.replace("_", " ").title()

# Pair the runs you want to compare (annealed=continued vs pure dice).
# Keys are context names; the values are model labels for each variant.
CONTEXTS = {
    "scratch": {"annealed": "scratch_continued", "dice": "scratch_dice_only"},
    "nf":      {"annealed": "nf_continued",      "dice": "nf_dice_only"},
    "th":      {"annealed": "th_continued",      "dice": "th_dice_only"},
}

# Highlighting controls for scatter (a point is highlighted if it meets either rule)
TOP_K_ABS_DELTA = 12        # show top-|Δ| cases
DELTA_THRESHOLD = 0.05      # also highlight any |Δ| >= this threshold
SHOW_LABELS     = True      # annotate highlighted points with sample_idx

# Bootstrap for CI
N_BOOT          = 10_000    # number of bootstrap resamples
SEED            = 42        # RNG seed, so the CI is reproducible

# Output directory (created together with any missing parents)
out_dir = Path("figs/anneal_vs_dice")
out_dir.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------

# ---------- helpers ----------
def _pair_xy(df, annealed_label, dice_label, metric):
    g = df.groupby(["sample_idx", "model_label"], as_index=False).agg({metric: "mean"})
    wide = g[g["model_label"].isin([annealed_label, dice_label])].pivot(
        index="sample_idx", columns="model_label", values=metric
    ).dropna(subset=[annealed_label, dice_label])
    x = wide[annealed_label].to_numpy()   # Annealed (continued)
    y = wide[dice_label].to_numpy()       # Dice-only
    return x, y, wide.index.values

def _bootstrap_mean_ci(x, n_boot=N_BOOT, ci=95, seed=SEED):
    """
    Percentile-bootstrap confidence interval for the mean of ``x``.

    Draws ``n_boot`` resamples of ``x`` (with replacement, same size as ``x``)
    from a generator seeded with ``seed`` and takes the central ``ci`` percent
    of the resampled means.

    Returns:
        (mean, (lo, hi)): the sample mean of ``x`` and the interval bounds.
    """
    generator = np.random.default_rng(seed)
    resamples = generator.choice(x, size=(n_boot, x.size), replace=True)
    resampled_means = resamples.mean(axis=1)

    # Percentage of mass left in each tail
    tail = (100-ci)/2
    lower = np.percentile(resampled_means, tail)
    upper = np.percentile(resampled_means, 100-tail)
    return x.mean(), (lower, upper)

def _legend_and_textbox(ax, text_lines, kind="scatter"):
    """
    kind = "hist" or "scatter"
    - Legend is always loc='upper right'
    - Text placement:
        * hist: top-right, below legend
        * scatter: top-left
    """
    loc = "upper left" if kind == "scatter" else "upper right"
    ax.legend(loc=loc, frameon=True)

    if not text_lines:
        return

    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    text = "\n".join(text_lines)

    if kind == "hist":
        # Place textbox below legend area; tweak y-coord if needed for your legend height
        ax.text(
            0.98, 0.80, text,
            transform=ax.transAxes, fontsize=9,
            va='top', ha='right', bbox=props
        )
    else:
        # scatter: top-left
        ax.text(
            0.02, 0.80, text,
            transform=ax.transAxes, fontsize=9,
            va='top', ha='left', bbox=props
        )

# ---------- plotters ----------
def plot_delta_hist(diffs, out_path, title=""):
    """
    Histogram of per-sample metric differences with a bootstrap summary.

    diffs: Dice-only − Annealed (continued)

    Draws the histogram, a dashed reference line at Δ = 0 and a line at the
    mean Δ, adds a summary box (mean, 95% bootstrap CI, n), saves the figure
    to ``out_path`` and shows it.
    """
    center, (ci_low, ci_high) = _bootstrap_mean_ci(diffs)
    summary_lines = [
        "Summary:",
        f"Mean Δ: {center:.3f}",
        f"95% CI: [{ci_low:.3f}, {ci_high:.3f}]",
        f"n = {diffs.size}",
    ]

    fig, ax = plt.subplots(figsize=(6.2, 4.2), dpi=130)
    ax.hist(diffs, bins=15, alpha=0.8, edgecolor='black')

    # Reference lines: no difference, and the observed mean difference
    ax.axvline(0, linestyle="--", linewidth=1.0, color='red', alpha=0.9, label='Δ = 0')
    ax.axvline(center, linewidth=1.2, color='blue', label=f'Mean Δ = {center:.3f}')

    ax.set_xlabel(f"Δ {METRIC_LABEL} (Dice-only − Annealed)")
    ax.set_ylabel("Number of Slides")
    ax.set_title(title)  # keep blank or minimal to avoid clutter

    _legend_and_textbox(ax, summary_lines, kind="hist")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.show()
    print("Saved:", out_path)

def plot_scatter_zoom(x, y, sample_ids, out_path, title="", pad=0.02,
                      top_k=TOP_K_ABS_DELTA, thr=DELTA_THRESHOLD, show_labels=SHOW_LABELS):
    """
    Scatter Annealed (continued) (x) against Dice-only (y) and highlight large deltas.

    A sample is highlighted when it is among the `top_k` largest |Δ| or when
    |Δ| >= `thr`, with Δ = y − x. Highlighted points with Δ > 0 and Δ < 0 are
    drawn as separate series (the latter with square markers). The figure is
    saved to `out_path`, shown, and the saved path is printed.

    Args:
        x, y: 1-D NumPy arrays of paired per-sample metric values.
        sample_ids: Array of sample identifiers aligned with `x` / `y`.
        out_path: Destination passed to `fig.savefig`.
        title: Axes title (callers keep it empty).
        pad: Fraction of the data span added as margin on both axes.
        top_k: Number of largest-|Δ| samples to highlight.
        thr: Absolute-delta threshold for highlighting.
        show_labels: Whether to annotate highlighted points with their id.

    Returns:
        DataFrame with columns sample_idx, annealed_continued, dice_only,
        delta_dice_minus_annealed and abs_delta_rank for the highlighted
        samples, sorted by rank; an empty frame with the same columns when
        nothing is highlighted.
    """
    # Shared, padded limits so both axes cover the same range and the
    # identity line runs corner to corner.
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())
    margin = pad * max(hi - lo, 1e-6)
    lo -= margin
    hi += margin

    deltas = y - x  # Dice-only − Annealed
    by_abs_desc = np.argsort(-np.abs(deltas))  # indices, largest |Δ| first
    top_idx = set(by_abs_desc[:top_k])
    thr_idx = set(np.where(np.abs(deltas) >= thr)[0])
    highlight_idx = np.array(sorted(top_idx | thr_idx), dtype=int)
    has_highlights = highlight_idx.size > 0

    fig, ax = plt.subplots(figsize=(6.2, 6.2), dpi=130)

    # Faint background layer with every slide.
    ax.scatter(x, y, alpha=0.25, edgecolor="none", label="All slides")

    if has_highlights:
        highlighted_deltas = deltas[highlight_idx]
        imp = highlight_idx[highlighted_deltas > 0]   # improvements
        deg = highlight_idx[highlighted_deltas < 0]   # degradations
        emphasis = dict(alpha=0.9, edgecolor="black", linewidths=0.6)

        if imp.size:
            ax.scatter(x[imp], y[imp], label=f"Highlighted (Δ>0, n={imp.size})", **emphasis)
        if deg.size:
            ax.scatter(x[deg], y[deg], label=f"Highlighted (Δ<0, n={deg.size})",
                       marker="s", **emphasis)

        if show_labels:
            for idx in highlight_idx:
                ax.annotate(
                    str(sample_ids[idx]),
                    (x[idx], y[idx]),
                    xytext=(3, 3),
                    textcoords="offset points",
                    fontsize=8, alpha=0.9,
                )

    # Identity line: points above it have Δ > 0.
    ax.plot([lo, hi], [lo, hi], linestyle="--", linewidth=1.0, color='gray', alpha=0.9, label='Identity')
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel(f"Annealed (continued) ({METRIC_LABEL})")
    ax.set_ylabel(f"Dice-only ({METRIC_LABEL})")
    ax.set_title(title)  # normally blank

    # Summary box (top-left, under the legend).
    mean_delta, (ci_lo, ci_hi) = _bootstrap_mean_ci(deltas)
    summary_lines = [
        "Δ summary:",
        f"Mean Δ: {mean_delta:.3f}",
        f"95% CI: [{ci_lo:.3f}, {ci_hi:.3f}]",
        f"|Δ| ≥ {thr:.3f}: {len(thr_idx)}",
        f"Top-|Δ| shown: {min(top_k, deltas.size)}",
    ]
    _legend_and_textbox(ax, summary_lines, kind="scatter")

    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.show()
    print("Saved:", out_path)

    # Table of highlighted samples for the caller to save.
    if not has_highlights:
        return pd.DataFrame(columns=[
            "sample_idx","annealed_continued","dice_only","delta_dice_minus_annealed","abs_delta_rank"
        ])

    # Rank 1 = largest |Δ| over all samples; ties broken by position.
    abs_rank = pd.Series(np.abs(deltas)).rank(ascending=False, method="first").astype(int)
    highlighted = pd.DataFrame({
        "sample_idx": sample_ids[highlight_idx],
        "annealed_continued": x[highlight_idx],
        "dice_only": y[highlight_idx],
        "delta_dice_minus_annealed": deltas[highlight_idx],
        "abs_delta_rank": abs_rank.iloc[highlight_idx].values
    })
    return highlighted.sort_values(
        ["abs_delta_rank", "delta_dice_minus_annealed"], ascending=[True, False]
    )

# ---------- run ----------
# For every context, pair the annealed and dice-only runs per sample, then
# produce a Δ histogram, a highlighted scatter and a CSV of the highlights.
df = pd.read_csv(CSV_PATH)

for ctx, labels in CONTEXTS.items():
    annealed_label = labels["annealed"]
    dice_label     = labels["dice"]

    # Pair per-sample metrics
    x, y, sids = _pair_xy(df, annealed_label, dice_label, METRIC)
    if x.size == 0:
        # Without this guard the plotters fail on an empty array (x.min()
        # raises); skip the context instead of aborting the whole loop.
        print(f"Skipping {ctx}: no overlapping samples for {annealed_label} vs {dice_label}.")
        continue

    deltas = y - x  # Dice-only − Annealed

    # Delta histogram (minimal title)
    plot_delta_hist(
        deltas,
        out_dir / f"delta_hist_{ctx}_{METRIC}.pdf",
        title=""  # keep empty to avoid clutter; LaTeX caption will explain
    )

    # Scatter with highlights (minimal title)
    hi_df = plot_scatter_zoom(
        x, y, sids,
        out_dir / f"scatter_{ctx}_{METRIC}_zoom.pdf",
        title=""  # keep empty
    )

    # Save highlighted table
    hi_csv = out_dir / f"highlighted_{ctx}_{METRIC}.csv"
    hi_df.to_csv(hi_csv, index=False)
    print("Saved:", hi_csv)


## Transfer vs no transfer

In [None]:
# === Pretrained vs Scratch: delta histogram + zoomed scatter with highlights ===
# Uses existing helpers: _bootstrap_mean_ci, _legend_and_textbox, and global configs (CSV_PATH, METRIC, METRIC_LABEL,
# TOP_K_ABS_DELTA, DELTA_THRESHOLD, SHOW_LABELS). Titles kept minimal; legend left for scatter (per previous change).

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# ---------- generic pair + plotters for pretrained vs scratch ----------
def _pair_xy_generic(df, x_label, y_label, metric):
    """
    Pair a per-sample metric between two models.

    Rows are restricted to the two labels, averaged per
    (sample_idx, model_label) so duplicates count once, and pivoted to one row
    per sample. Only samples with a non-NaN value for both labels are kept.

    Args:
        df: Long-format frame with columns "sample_idx", "model_label" and `metric`.
        x_label: Model label returned as `x` (Scratch in the calling cell).
        y_label: Model label returned as `y` (Pretrained in the calling cell).
        metric: Name of the metric column to pair.

    Returns:
        (x, y, sample_ids): aligned NumPy arrays. All three are empty when
        either label has no rows or the two labels share no sample.
    """
    wanted = [x_label, y_label]
    sub = df[df["model_label"].isin(wanted)]

    # If a label is absent the pivot below would lack its column and
    # dropna(subset=...) would raise KeyError; report "no pairs" instead so
    # callers can skip the comparison.
    if not set(wanted) <= set(sub["model_label"]):
        no_values = np.empty(0, dtype=float)
        return no_values, no_values.copy(), sub["sample_idx"].to_numpy()[:0]

    per_sample = sub.groupby(["sample_idx", "model_label"], as_index=False).agg({metric: "mean"})
    wide = per_sample.pivot(
        index="sample_idx", columns="model_label", values=metric
    ).dropna(subset=wanted)

    x = wide[x_label].to_numpy()
    y = wide[y_label].to_numpy()
    return x, y, wide.index.values

def plot_delta_hist_pretrained(diffs, out_path, title=""):
    """
    Histogram of paired per-slide deltas: Pretrained − Scratch.

    Draws a 15-bin histogram with vertical markers at Δ = 0 and at the mean Δ,
    adds a summary box (mean, CI bounds from `_bootstrap_mean_ci`, n), saves
    the figure to `out_path`, shows it and prints the saved path.

    Args:
        diffs: Array of deltas; must expose `.size` (e.g. a NumPy array).
        out_path: Destination passed to `fig.savefig`.
        title: Axes title (callers keep it empty).
    """
    mean_delta, interval = _bootstrap_mean_ci(diffs)
    lo, hi = interval

    fig, ax = plt.subplots(figsize=(6.2, 4.2), dpi=130)
    ax.hist(diffs, bins=15, alpha=0.8, edgecolor='black')

    # Vertical markers: zero difference (dashed red) and the mean (blue).
    ax.axvline(0, label='Δ = 0', color='red', linestyle="--", linewidth=1.0, alpha=0.9)
    ax.axvline(mean_delta, label=f'Mean Δ = {mean_delta:.3f}', color='blue', linewidth=1.2)

    ax.set_title(title)
    ax.set_xlabel(f"Δ {METRIC_LABEL} (Pretrained − Scratch)")
    ax.set_ylabel("Number of Slides")

    box_lines = ["Summary:"]
    box_lines.append(f"Mean Δ: {mean_delta:.3f}")
    box_lines.append(f"95% CI: [{lo:.3f}, {hi:.3f}]")
    box_lines.append(f"n = {diffs.size}")
    _legend_and_textbox(ax, box_lines, kind="hist")

    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.show()
    print("Saved:", out_path)

def plot_scatter_zoom_pretrained(x, y, sample_ids, out_path, title="", pad=0.02,
                                 top_k=TOP_K_ABS_DELTA, thr=DELTA_THRESHOLD, show_labels=SHOW_LABELS):
    """
    Scatter Scratch (x) against Pretrained (y) and highlight large deltas.

    A sample is highlighted when it is among the `top_k` largest |Δ| or when
    |Δ| >= `thr`, with Δ = y − x. The figure is saved to `out_path`, shown,
    and the saved path is printed.

    Args:
        x, y: 1-D NumPy arrays of paired per-sample metric values.
        sample_ids: Array of sample identifiers aligned with `x` / `y`.
        out_path: Destination passed to `fig.savefig`.
        title: Axes title (callers keep it empty).
        pad: Fraction of the data span added as margin on both axes.
        top_k: Number of largest-|Δ| samples to highlight.
        thr: Absolute-delta threshold for highlighting.
        show_labels: Whether to annotate highlighted points with their id.

    Returns:
        DataFrame with columns sample_idx, scratch, pretrained,
        delta_pretrained_minus_scratch and abs_delta_rank for the highlighted
        samples, sorted by rank; an empty frame with the same columns when
        nothing is highlighted.
    """
    out_columns = [
        "sample_idx","scratch","pretrained","delta_pretrained_minus_scratch","abs_delta_rank"
    ]

    # Common square range for both axes, widened by `pad` of the span.
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())
    extra = pad * max(hi - lo, 1e-6)
    lo, hi = lo - extra, hi + extra

    deltas = y - x  # Pretrained − Scratch
    order = np.argsort(-np.abs(deltas))          # largest |Δ| first
    top_idx = set(order[:top_k])
    thr_idx = set(np.where(np.abs(deltas) >= thr)[0])
    highlight_idx = np.array(sorted(top_idx.union(thr_idx)), dtype=int)
    any_highlight = highlight_idx.size > 0

    fig, ax = plt.subplots(figsize=(6.2, 6.2), dpi=130)
    ax.scatter(x, y, alpha=0.25, edgecolor="none", label="All slides")

    if any_highlight:
        sign = deltas[highlight_idx]
        imp, deg = highlight_idx[sign > 0], highlight_idx[sign < 0]
        marked = {"alpha": 0.9, "edgecolor": "black", "linewidths": 0.6}
        if imp.size:
            ax.scatter(x[imp], y[imp], label=f"Highlighted (Δ>0, n={imp.size})", **marked)
        if deg.size:
            ax.scatter(x[deg], y[deg], label=f"Highlighted (Δ<0, n={deg.size})",
                       marker="s", **marked)
        if show_labels:
            for pos in highlight_idx:
                ax.annotate(str(sample_ids[pos]), (x[pos], y[pos]),
                            xytext=(3, 3), textcoords="offset points",
                            fontsize=8, alpha=0.9)

    # Identity line: points above it are slides where Pretrained scores higher.
    ax.plot([lo, hi], [lo, hi], linestyle="--", linewidth=1.0, color='gray', alpha=0.9, label='Identity')
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel(f"Scratch ({METRIC_LABEL})")
    ax.set_ylabel(f"Pretrained ({METRIC_LABEL})")
    ax.set_title(title)

    mean_delta, (ci_lo, ci_hi) = _bootstrap_mean_ci(deltas)
    box_lines = [
        "Δ summary:",
        f"Mean Δ: {mean_delta:.3f}",
        f"95% CI: [{ci_lo:.3f}, {ci_hi:.3f}]",
        f"|Δ| ≥ {thr:.3f}: {len(thr_idx)}",
        f"Top-|Δ| shown: {min(top_k, deltas.size)}",
    ]
    _legend_and_textbox(ax, box_lines, kind="scatter")

    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.show()
    print("Saved:", out_path)

    if not any_highlight:
        return pd.DataFrame(columns=out_columns)

    # Rank 1 = largest |Δ| over all samples; ties broken by position.
    abs_rank = pd.Series(np.abs(deltas)).rank(ascending=False, method="first").astype(int)
    picked = pd.DataFrame({
        "sample_idx": sample_ids[highlight_idx],
        "scratch": x[highlight_idx],
        "pretrained": y[highlight_idx],
        "delta_pretrained_minus_scratch": deltas[highlight_idx],
        "abs_delta_rank": abs_rank.iloc[highlight_idx].values
    })
    return picked.sort_values(
        ["abs_delta_rank", "delta_pretrained_minus_scratch"], ascending=[True, False]
    )

# ---------- run comparisons ----------
# Pretrained-vs-scratch driver: for each entry of COMPARISONS, pair the two
# runs per sample, plot the Δ histogram and highlighted scatter, and save the
# highlighted samples as CSV. All outputs go to `out_dir`.
# NOTE: this rebinds the notebook-global `out_dir` and `df`; cells that read
# `out_dir` will write here if re-run after this cell.
out_dir = Path("figs/pretrained_vs_scratch"); out_dir.mkdir(parents=True, exist_ok=True)
df = pd.read_csv(CSV_PATH)

# Each entry names the scratch (x) and pretrained (y) model labels to compare;
# "ctx" and "variant" are only used to build output file names and messages.
COMPARISONS = [
    # continued (annealed) variants
    {"ctx": "nf", "variant": "continued", "scratch": "scratch_continued", "pretrained": "nf_continued"},
    {"ctx": "th", "variant": "continued", "scratch": "scratch_continued", "pretrained": "th_continued"},
    # dice-only variants
    {"ctx": "nf", "variant": "dice", "scratch": "scratch_dice_only", "pretrained": "nf_dice_only"},
    {"ctx": "th", "variant": "dice", "scratch": "scratch_dice_only", "pretrained": "th_dice_only"},
]

for c in COMPARISONS:
    x, y, sids = _pair_xy_generic(df, c["scratch"], c["pretrained"], METRIC)
    if x.size == 0:
        # The plotters call x.min()/x.max(), which fail on empty arrays.
        print(f"Skipping {c['ctx']} ({c['variant']}): no overlapping samples.")
        continue

    deltas = y - x  # Pretrained − Scratch

    # Δ histogram (minimal title)
    plot_delta_hist_pretrained(
        deltas,
        out_dir / f"delta_hist_{c['ctx']}_{c['variant']}_{METRIC}.pdf",
        title=""
    )

    # Scatter with highlights (minimal title)
    hi_df = plot_scatter_zoom_pretrained(
        x, y, sids,
        out_dir / f"scatter_{c['ctx']}_{c['variant']}_{METRIC}_zoom.pdf",
        title=""
    )

    # Save highlighted table
    hi_csv = out_dir / f"highlighted_{c['ctx']}_{c['variant']}_{METRIC}.csv"
    hi_df.to_csv(hi_csv, index=False)
    print("Saved:", hi_csv)


## Overview table

In [None]:

# Overview table of top models, including both *_continued and *_annealed (no files saved)

import pandas as pd
import numpy as np

# --- Config ---
CSV_PATH = "per_sample_metrics.csv"   # adjust if needed
RNG_SEED = 42       # default seed for the bootstrap in summarize_metric
N_BOOT   = 10_000   # default number of bootstrap resamples in summarize_metric

# --- Load ---
df = pd.read_csv(CSV_PATH)

# Metrics available: keep only the supported metric columns present in the CSV
METRICS = [m for m in ["soft_dice", "cldice"] if m in df.columns]
if not METRICS:
    raise ValueError("No supported metrics found. Expected 'soft_dice' and/or 'cldice' columns in CSV.")

# Include both *_continued and *_annealed, plus dice-only variants.
# The list order is also the row order used for the summary table.
preferred_models = [
    "scratch_continued", "scratch_annealed",
    "nf_continued",      "nf_annealed",
    "th_continued",      "th_annealed",
    "scratch_dice_only", "nf_dice_only", "th_dice_only",
]

# Model labels that actually occur in the CSV
present = set(df["model_label"].unique())

# Keep only those that exist, in the preferred order; optionally include baseline at the end if present
models = [m for m in preferred_models if m in present]
if "baseline" in present and "baseline" not in models:
    models.append("baseline")

# If none matched, use all labels (sorted for a deterministic order)
if not models:
    models = sorted(present)

# --- Helpers ---
def summarize_metric(df_in: pd.DataFrame, metric: str, model_label: str, n_boot: int = N_BOOT, seed: int = RNG_SEED):
    """
    Summarise one metric for one model as a display-ready table row.

    Rows for `model_label` are first averaged per `sample_idx` so duplicated
    samples count once. Samples whose averaged value is NaN are dropped before
    any statistic is computed; otherwise a single NaN would turn the mean, the
    bootstrap CI and the median into NaN and inflate `n`.

    Args:
        df_in: Long-format frame with columns "model_label", "sample_idx" and `metric`.
        metric: Name of the metric column to summarise.
        model_label: Value of "model_label" selecting the model.
        n_boot: Number of bootstrap resamples for the CI of the mean.
        seed: Seed for the bootstrap random generator.

    Returns:
        dict with keys "Metric", "Model", "Mean [95% CI]", "Median [IQR]" and
        "n". The CI is a percentile bootstrap (2.5 / 97.5) of the mean; the
        IQR is reported as [25th, 75th percentile]. When no valid sample
        exists both statistic fields are "—" and "n" is 0.
    """
    # Aggregate per-sample first (avoid duplicates), then discard missing values
    per_sample = (df_in[df_in["model_label"] == model_label]
                  .groupby("sample_idx", as_index=False)[metric].mean())
    x = per_sample[metric].dropna().to_numpy()
    n = x.size
    if n == 0:
        return {
            "Metric": metric, "Model": model_label,
            "Mean [95% CI]": "—", "Median [IQR]": "—", "n": 0
        }

    # Percentile bootstrap CI for the mean: resample n values with
    # replacement n_boot times and take the mean of each resample.
    rng = np.random.default_rng(seed)
    boots = rng.choice(x, size=(n_boot, n), replace=True).mean(axis=1)
    ci_low = np.percentile(boots, 2.5)
    ci_high = np.percentile(boots, 97.5)

    mean = x.mean()
    median = np.median(x)
    q1 = np.percentile(x, 25)
    q3 = np.percentile(x, 75)

    return {
        "Metric": metric,
        "Model": model_label,
        "Mean [95% CI]": f"{mean:.3f} [{ci_low:.3f}, {ci_high:.3f}]",
        "Median [IQR]": f"{median:.3f} [{q1:.3f}, {q3:.3f}]",
        "n": int(n)
    }

# --- Build + show table ---
# One row per (metric, model) combination, produced by summarize_metric.
rows = []
for metric in METRICS:
    for m in models:
        rows.append(summarize_metric(df, metric, m))

summary_df = pd.DataFrame(rows)

# Order rows: by metric then preferred model order.
# "Model" becomes an ordered categorical so it sorts in the order of `models`
# rather than alphabetically; "Metric" holds plain strings and sorts lexically.
summary_df["Model"] = pd.Categorical(summary_df["Model"], categories=models, ordered=True)
summary_df = summary_df.sort_values(["Metric", "Model"]).reset_index(drop=True)

# Display in Jupyter (a bare trailing expression renders as the cell output)
summary_df


## Stringyness

In [None]:
# === Stringyness perspective: Paired Δ-clDice at matched Soft-Dice ===
# For each context, compare clDice(annealed) − clDice(dice-only)
# Plots:
#   1) Δ-clDice histogram with bootstrap 95% CI (legend top-right, textbox below legend)
#   2) Soft-Dice scatter (x = dice-only, y = annealed) colored by Δ-clDice (legend upper-left, textbox below legend)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path

# -------------------- CONFIG --------------------
CSV_PATH = "per_sample_metrics.csv"
# Use "*_continued" as your annealed variant:
# each context maps to the model labels of its annealed and dice-only runs.
CONTEXTS = {
    "scratch": {"annealed": "scratch_continued", "dice": "scratch_dice_only"},
    "nf":      {"annealed": "nf_continued",      "dice": "nf_dice_only"},
    "th":      {"annealed": "th_continued",      "dice": "th_dice_only"},
}
N_BOOT = 10_000   # default number of bootstrap resamples in _bootstrap_mean_ci
SEED   = 42       # default seed for the bootstrap random generator
SAVE   = True     # default for the plotters' `save` flag (write PDFs to OUTDIR)
# Output directory for this cell's figures; created on the spot if missing.
OUTDIR = Path("figs/stringyness"); OUTDIR.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------

# ---------- helpers (re-use if already defined) ----------
def _bootstrap_mean_ci(x, n_boot=N_BOOT, ci=95, seed=SEED):
    """
    Mean of `x` with a percentile-bootstrap confidence interval.

    Args:
        x: Array-like of values; converted to a float NumPy array.
        n_boot: Number of bootstrap resamples (each of size `x.size`, drawn
            with replacement).
        ci: Confidence level in percent; the interval spans the central
            `ci` percent of the resampled means.
        seed: Seed for `np.random.default_rng`, making the result reproducible.

    Returns:
        (mean, (lower, upper)) as plain Python floats.
    """
    values = np.asarray(x, dtype=float)
    tail = (100 - ci) / 2  # percent left out on each side of the interval

    generator = np.random.default_rng(seed)
    resamples = generator.choice(values, size=(n_boot, values.size), replace=True)
    boot_means = resamples.mean(axis=1)

    lower = np.percentile(boot_means, tail)
    upper = np.percentile(boot_means, 100 - tail)
    return float(values.mean()), (float(lower), float(upper))

def _pair_metrics(df, annealed_label, dice_label):
    """
    Build a per-sample table pairing soft_dice and cldice for two models.

    Rows for the two labels are averaged per (sample_idx, model_label) and
    pivoted so every sample becomes one row with the columns
    "<metric>_<model_label>". Samples lacking any of the four values are
    dropped.

    Args:
        df: Long-format frame with columns "sample_idx", "model_label",
            "soft_dice" and "cldice".
        annealed_label: Model label of the annealed run.
        dice_label: Model label of the dice-only run.

    Returns:
        DataFrame with "sample_idx" plus the four paired metric columns.

    Raises:
        ValueError: If `df` lacks one of the required columns.
        KeyError: If a paired column is absent after the pivot (for example
            when one of the labels has no rows).
    """
    metrics = ["soft_dice", "cldice"]

    need_cols = {"sample_idx", "model_label", "soft_dice", "cldice"}
    missing = need_cols - set(df.columns)
    if missing:
        raise ValueError(f"CSV missing columns: {missing}")

    both_models = df.loc[df["model_label"].isin([annealed_label, dice_label])]
    per_sample = both_models.groupby(
        ["sample_idx", "model_label"], as_index=False
    )[metrics].mean()

    # The pivot yields MultiIndex columns (metric, model_label); flatten each
    # pair to "<metric>_<model_label>".
    wide = per_sample.pivot(index="sample_idx", columns="model_label", values=metrics)
    wide.columns = ["{}_{}".format(*pair) for pair in wide.columns.to_flat_index()]

    # Columns every paired sample must have, in reporting order.
    cols = [
        f"{metric}_{label}"
        for metric in metrics
        for label in (annealed_label, dice_label)
    ]
    missing_cols = [c for c in cols if c not in wide.columns]
    if missing_cols:
        raise KeyError(f"Expected after pivot, not found: {missing_cols}. Got: {list(wide.columns)}")

    # Keep only samples present for both models.
    return wide.dropna(subset=cols).reset_index()


def _textbox(ax, lines, x, y, align=("left","top")):
    """
    Write `lines` as a rounded, wheat-coloured text box on `ax`.

    Args:
        ax: Axes-like object providing `text` and `transAxes`.
        lines: Strings joined with newlines to form the box content.
        x, y: Anchor position in axes coordinates.
        align: (horizontal, vertical) pair. "left" / "top" are used as given;
            any other value falls back to "right" / "bottom" respectively.
    """
    horizontal = "right" if align[0] != "left" else "left"
    vertical = "bottom" if align[1] != "top" else "top"
    content = "\n".join(lines)
    ax.text(
        x, y, content,
        transform=ax.transAxes, fontsize=9,
        ha=horizontal, va=vertical,
        bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.85},
    )

# ---------- plotting ----------
def plot_delta_hist(deltas, ctx, save=SAVE):
    """
    Histogram of per-slide Δ clDice (annealed − dice-only) for one context.

    Draws a 15-bin histogram with vertical markers at Δ = 0 and at the mean Δ,
    plus a text box with the mean, the CI bounds from `_bootstrap_mean_ci` and
    the sample count. The figure is always shown; it is written to
    OUTDIR / "delta_cldice_hist_<ctx>.pdf" only when `save` is truthy.

    NOTE: an earlier cell defines `plot_delta_hist(diffs, out_path, title)`;
    running this cell rebinds the name to this signature.

    Args:
        deltas: Array of deltas; must expose `.size` (e.g. a NumPy array).
        ctx: Context key, used only in the output file name.
        save: Whether to write the PDF.
    """
    mean_delta, interval = _bootstrap_mean_ci(deltas)
    lo, hi = interval

    fig, ax = plt.subplots(figsize=(6.2, 4.2), dpi=130)
    ax.hist(deltas, bins=15, alpha=0.85, edgecolor="black")

    # Vertical markers: zero difference (dashed red) and the mean (blue).
    ax.axvline(0, label="Δ = 0", color="red", linestyle="--", linewidth=1.0, alpha=0.9)
    ax.axvline(mean_delta, label=f"Mean Δ = {mean_delta:.3f}", color="blue", linewidth=1.2)

    ax.set_xlabel("Δ clDice (annealed − dice-only)")
    ax.set_ylabel("Number of slides")
    ax.set_title("")  # left blank; the figure caption carries the explanation
    ax.legend(loc="upper right", frameon=True)

    # Summary box in the top-right area, placed below the legend.
    stats_lines = [
        f"Mean Δ: {mean_delta:.3f}",
        f"95% CI: [{lo:.3f}, {hi:.3f}]",
        f"n = {deltas.size}",
    ]
    _textbox(ax, stats_lines, x=0.98, y=0.70, align=("right","top"))

    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save:
        fp = OUTDIR / f"delta_cldice_hist_{ctx}.pdf"
        fig.savefig(fp, dpi=200, bbox_inches="tight")
        print("Saved:", fp)

    plt.show()

def plot_softdice_scatter_colored(pair_df, annealed_label, dice_label, ctx, save=SAVE):
    """
    Scatter Soft-Dice (dice-only on x, annealed on y) coloured by Δ clDice.

    Δ clDice = cldice(annealed) − cldice(dice-only) per sample, mapped through
    a diverging colormap centred at 0. The figure is always shown; it is
    written to OUTDIR / "softdice_scatter_colored_<ctx>.pdf" only when `save`
    is truthy.

    Args:
        pair_df: Frame with the columns "soft_dice_<label>" and
            "cldice_<label>" for both labels (as built by `_pair_metrics`).
        annealed_label: Model label of the annealed run.
        dice_label: Model label of the dice-only run.
        ctx: Context key, used only in the output file name.
        save: Whether to write the PDF.
    """
    x = pair_df[f"soft_dice_{dice_label}"].to_numpy()
    y = pair_df[f"soft_dice_{annealed_label}"].to_numpy()
    deltas = pair_df[f"cldice_{annealed_label}"].to_numpy() - pair_df[f"cldice_{dice_label}"].to_numpy()

    # Range padding and identity
    pad = 0.02
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    lo = min(xmin, ymin); hi = max(xmax, ymax)
    span = max(hi - lo, 1e-6)
    lo -= pad*span; hi += pad*span

    # Diverging color centered at 0 for ΔclDice.
    # TwoSlopeNorm requires vmin < vcenter < vmax strictly, so the raw data
    # range is only usable when the deltas straddle zero. When they are all
    # >= 0, all <= 0 or all zero, fall back to limits symmetric about zero
    # instead of letting the constructor raise ValueError.
    d_min = float(np.min(deltas))
    d_max = float(np.max(deltas))
    if not (d_min < 0.0 < d_max):
        limit = max(abs(d_min), abs(d_max), 1e-6)
        d_min, d_max = -limit, limit
    norm = mcolors.TwoSlopeNorm(vmin=d_min, vcenter=0.0, vmax=d_max)

    fig, ax = plt.subplots(figsize=(6.2, 6.2), dpi=130)
    sc = ax.scatter(x, y, c=deltas, cmap="coolwarm", norm=norm, alpha=0.9, edgecolor="none")
    ax.plot([lo, hi], [lo, hi], linestyle="--", linewidth=1.0, color="gray", alpha=0.9, label="Identity")

    ax.set_xlim(lo, hi); ax.set_ylim(lo, hi); ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("Soft-Dice (dice-only)")
    ax.set_ylabel("Soft-Dice (annealed)")
    ax.set_title("")

    # Legend upper-left, textbox below legend (top-left)
    ax.legend(loc="upper left", frameon=True)
    mean_delta, (ci_lo, ci_hi) = _bootstrap_mean_ci(deltas)
    _textbox(ax, [
        f"ΔclDice mean: {mean_delta:.3f}",
        f"95% CI: [{ci_lo:.3f}, {ci_hi:.3f}]",
        f"n = {deltas.size}",
    ], x=0.02, y=0.70, align=("left","top"))

    # Colorbar for ΔclDice
    cbar = plt.colorbar(sc, ax=ax, shrink=0.9, pad=0.02)
    cbar.set_label("Δ clDice (annealed − dice-only)")

    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    if save:
        fp = OUTDIR / f"softdice_scatter_colored_{ctx}.pdf"
        fig.savefig(fp, dpi=200, bbox_inches="tight")
        print("Saved:", fp)
    plt.show()

# ---------- run ----------
# For each context, pair the annealed and dice-only runs per sample and draw
# the two stringyness plots. Rebinds the notebook-global `df`.
df = pd.read_csv(CSV_PATH)

for ctx, lbls in CONTEXTS.items():
    annealed_label = lbls["annealed"]
    dice_label     = lbls["dice"]

    # One row per sample that has soft_dice and cldice for both models.
    # _pair_metrics raises (rather than returning an empty frame) when a
    # required column or one of the labels is missing entirely.
    paired = _pair_metrics(df, annealed_label, dice_label)
    if paired.empty:
        print(f"[{ctx}] No overlapping samples for {annealed_label} vs {dice_label}. Skipping.")
        continue

    # 1) Δ-clDice histogram
    deltas = paired[f"cldice_{annealed_label}"].to_numpy() - paired[f"cldice_{dice_label}"].to_numpy()
    plot_delta_hist(deltas, ctx)

    # 2) Soft-Dice scatter colored by Δ-clDice
    plot_softdice_scatter_colored(paired, annealed_label, dice_label, ctx)
