# Subrun Comparison

Compare parent runs to their subruns — the **deterministic swap** use case:
- Parent runs use Config A (e.g., original reranker or embedder)
- Subruns reuse the same memories, test cases, and queries, but run experiments with Config B
- Because everything except the experiment config is identical, any F1 difference can be attributed to the config change (assuming the rest of the pipeline is deterministic)

**Typical workflow:**
1. Run a full pipeline batch → parent runs
2. Run a subrun batch pointing at those parent run IDs → subruns
3. Open this notebook, paste the parent run IDs into `PARENT_RUN_IDS` (Cell 1), and run all cells to compare

In [None]:
# Cell 0 — Setup & Imports
import os
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from memory_retrieval.experiments.comparison import (
    compare_configs,
    fingerprint_diff,
    group_runs_by_fingerprint,
    load_run_summaries,
    load_subrun_summaries,
)
from memory_retrieval.experiments.metrics_adapter import extract_metric_from_nested
from memory_retrieval.infra.figures import create_figure_session, save_figure
from memory_retrieval.infra.runs import PHASE1, PHASE2

# Find project root: the nearest directory (the cwd itself or one of its
# ancestors) that contains pyproject.toml. Raise instead of silently running
# from the wrong directory, then make that directory the working directory so
# relative data paths resolve consistently.
_search_start = Path.cwd()
for PROJECT_ROOT in (_search_start, *_search_start.parents):
    if (PROJECT_ROOT / "pyproject.toml").exists():
        break
else:
    raise RuntimeError("Could not find project root")
os.chdir(PROJECT_ROOT)


def _extract_f1_from_summary(
    macro: dict, metric_path: str, strategy: str | None = None
) -> float | None:
    """Return the primary (optimal) F1 stored under ``metric_path`` in a macro_averaged dict.

    Thin wrapper around ``extract_metric_from_nested`` that always asks for the
    ``"f1"`` metric key; ``strategy`` is forwarded unchanged. Returns None when
    the adapter finds no value (callers in this notebook check ``is not None``).
    """
    wanted_metric = "f1"
    return extract_metric_from_nested(macro, metric_path, strategy, metric_key=wanted_metric)


def _extract_f1_from_per_test_case(
    per_tc: dict, metric_path: str, strategy: str | None = None
) -> float | None:
    """Return the F1 stored under ``metric_path`` in a single per_test_case entry.

    Same lookup as for macro-averaged summaries: delegates to
    ``extract_metric_from_nested`` with ``metric_key="f1"`` and forwards
    ``strategy`` unchanged. Returns None when no value is found.
    """
    wanted_metric = "f1"
    return extract_metric_from_nested(per_tc, metric_path, strategy, metric_key=wanted_metric)


def _fingerprint_hash(summary: dict) -> str:
    fingerprint = summary.get("config_fingerprint", {})
    return fingerprint.get("fingerprint_hash", "unknown")


def _summary_parent_anchor(summary: dict) -> str | None:
    """Return the parent-run anchor used to match groups fairly."""
    if summary.get("source_kind") == "subrun":
        return summary.get("parent_run_id")
    return summary.get("run_id")


def _label_for_index(index: int) -> str:
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if index < len(alphabet):
        return alphabet[index]
    return f"G{index + 1}"


# Cell 0 done: imports loaded, cwd moved to PROJECT_ROOT, helper functions defined.
print("Setup complete.")


In [None]:
# Cell 1 — Configuration

# Which phase to analyze
PHASE = PHASE2

# Parent run IDs to compare.
# Copy these from run_batch.py output or data/<phase>/batches/<batch_id>/batch_manifest.json.
# Set to None to load all parent runs that have subruns.
PARENT_RUN_IDS = None  # e.g., ["run_20260218_143022", "run_20260218_145301"]

# Rerank strategy — must match what was used in the subrun's ExperimentConfig.
# Shown in titles/logs only for METRIC_PATH == "post_rerank" (see METRIC_LABEL below).
STRATEGY = "default"

# Metric path: "post_rerank" for reranking experiments (rerank-score thresholds),
# "pre_rerank" for vector-only results (distance thresholds).
METRIC_PATH = "pre_rerank"

# Optional detailed pair (e.g., "A_vs_B"). None = auto-select baseline-vs-first-variant.
DETAIL_PAIR = None

# Derived label for titles/logs — avoids showing a strategy name when using pre_rerank,
# and labels any other metric path by its own name rather than mislabeling it as distance.
if METRIC_PATH == "post_rerank":
    METRIC_LABEL = STRATEGY
elif METRIC_PATH == "pre_rerank":
    METRIC_LABEL = "pre-rerank (distance)"
else:
    METRIC_LABEL = METRIC_PATH

print(f"Phase: {PHASE}")
print(f"Strategy: {STRATEGY}")
print(f"Metric path: {METRIC_PATH}")
print(f"Metric label: {METRIC_LABEL}")
print(f"Detail pair: {DETAIL_PAIR}")


In [None]:
# Cell 2 — Load Parent/Subruns, Group by Fingerprint, and Show Coverage
#
# Loads parent runs and their subruns, tags each summary with source_kind, groups
# every summary by config fingerprint, and assigns letter labels with parent-run
# groups first (so group "A" is a parent configuration whenever any parent runs
# were loaded). Records which parent anchors each group covers, opens the
# figure-export session, and prints a group table plus a per-parent coverage table.

parent_summaries = [
    {**summary, "source_kind": "parent"}
    for summary in load_run_summaries(PHASE, run_ids=PARENT_RUN_IDS)
]
subrun_summaries = [
    {**summary, "source_kind": "subrun"}
    for summary in load_subrun_summaries(PHASE, parent_run_ids=PARENT_RUN_IDS)
]

loaded_parent_run_ids = sorted(
    summary["run_id"] for summary in parent_summaries if summary.get("run_id")
)
all_summaries = parent_summaries + subrun_summaries
fingerprint_groups = group_runs_by_fingerprint(all_summaries)

parent_hashes = sorted({_fingerprint_hash(summary) for summary in parent_summaries})
# Build the membership set once instead of once per candidate hash.
parent_hash_set = set(parent_hashes)
ordered_hashes = parent_hashes + sorted(
    hash_key for hash_key in fingerprint_groups if hash_key not in parent_hash_set
)
group_labels = {
    hash_key: _label_for_index(index) for index, hash_key in enumerate(ordered_hashes)
}

# Parent run IDs each group is anchored to (subruns anchor to their parent_run_id).
# NOTE(review): assumes group_runs_by_fingerprint keys groups by the same value
# _fingerprint_hash returns; otherwise fingerprint_groups[hash_key] can KeyError.
parent_ids_by_hash: dict[str, set[str]] = {}
for hash_key in ordered_hashes:
    anchors = (_summary_parent_anchor(summary) for summary in fingerprint_groups[hash_key])
    parent_ids_by_hash[hash_key] = {parent_id for parent_id in anchors if parent_id is not None}

if PARENT_RUN_IDS is None:
    parents_part = f"parents-all-n{len(loaded_parent_run_ids)}"
elif loaded_parent_run_ids:
    parents_part = (
        f"parents-{loaded_parent_run_ids[0]}-{loaded_parent_run_ids[-1]}-n{len(loaded_parent_run_ids)}"
    )
else:
    parents_part = "parents-none-none-n0"

if ordered_hashes:
    groups_part = f"groups-{ordered_hashes[0][:8]}-{ordered_hashes[-1][:8]}-n{len(ordered_hashes)}"
else:
    groups_part = "groups-none-none-n0"

comparison_key = (
    f"{parents_part}__subruns-n{len(subrun_summaries)}__{groups_part}"
    f"__strategy-{STRATEGY}__metric-{METRIC_PATH}"
)
FIGURE_SESSION = create_figure_session(
    root_dir=PROJECT_ROOT / "data" / "comparisons" / "exports",
    notebook_slug=f"subrun_comparison/{PHASE}",
    context_key=comparison_key,
    context={
        "phase": PHASE,
        "strategy": STRATEGY,
        "metric_path": METRIC_PATH,
        "detail_pair": DETAIL_PAIR,
        "parent_run_ids": PARENT_RUN_IDS,
        "loaded_parent_run_ids": loaded_parent_run_ids,
        "subrun_ids": [summary.get("run_id", "unknown") for summary in subrun_summaries],
        "group_hashes": ordered_hashes,
        "group_labels": group_labels,
    },
)
print(f"Figure export session: {FIGURE_SESSION.session_dir}")

if not subrun_summaries:
    print("No subruns found. Run a subrun batch first:")
    print()
    print(
        "  from memory_retrieval.experiments.batch_runner import BatchSubrunConfig, run_subrun_batch"
    )
    print("  outcome = run_subrun_batch(BatchSubrunConfig(")
    print("      phase=PHASE2,")
    print("      parent_run_ids=[...],")
    print(
        "      experiment_config=ExperimentConfig(search_backend=VectorBackend(), reranker=Reranker(...)),"
    )
    print("  ))")
else:
    print(f"Loaded {len(parent_summaries)} parent runs")
    print(f"Loaded {len(subrun_summaries)} subruns")
    print(f"Found {len(ordered_hashes)} fingerprint groups")
    print()

    header = (
        f"{'Group':>6} {'Hash':>10} {'Parents':>8} {'Subruns':>8} "
        f"{'Anchors':>8} {'Avg F1':>8} {'Kind':>20}"
    )
    print(header)
    print("-" * len(header))

    for hash_key in ordered_hashes:
        label = group_labels[hash_key]
        group_summaries = fingerprint_groups[hash_key]
        parent_count = sum(1 for summary in group_summaries if summary.get("source_kind") == "parent")
        subrun_count = sum(1 for summary in group_summaries if summary.get("source_kind") == "subrun")

        f1_values = []
        for summary in group_summaries:
            f1_value = _extract_f1_from_summary(
                summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY
            )
            if f1_value is not None:
                f1_values.append(f1_value)
        avg_f1 = float(np.mean(f1_values)) if f1_values else float("nan")

        kind_parts = []
        if parent_count:
            kind_parts.append("parent")
        if subrun_count:
            kind_parts.append("subrun")
        kind_text = "+".join(kind_parts) if kind_parts else "unknown"

        avg_text = f"{avg_f1:.4f}" if not np.isnan(avg_f1) else "n/a"
        # Truncate the hash (as the plots do) so long hashes don't break column alignment.
        print(
            f"{label:>6} {hash_key[:8]:>10} {parent_count:>8} {subrun_count:>8} "
            f"{len(parent_ids_by_hash[hash_key]):>8} {avg_text:>8} {kind_text:>20}"
        )

    print("\nParent-anchor coverage (which groups each parent run has):")
    coverage_header = f"{'Parent Run ID':<30} {'Groups':>12}"
    print(coverage_header)
    print("-" * len(coverage_header))

    groups_by_parent: dict[str, set[str]] = {}
    for summary in all_summaries:
        parent_id = _summary_parent_anchor(summary)
        if parent_id is None:
            continue
        hash_key = _fingerprint_hash(summary)
        groups_by_parent.setdefault(parent_id, set()).add(group_labels.get(hash_key, "?"))

    for parent_run_id in sorted(groups_by_parent):
        labels = ",".join(sorted(groups_by_parent[parent_run_id]))
        print(f"{parent_run_id:<30} {labels:>12}")


In [None]:
# Cell 3 — Config Diff by Fingerprint Group
#
# Prints the config-fingerprint differences between the baseline group (the first
# entry of ordered_hashes) and every other group. Each group is represented by the
# config_fingerprint of its first summary.

if len(ordered_hashes) == 0:
    print("Cannot compute diffs: no fingerprint groups found.")
elif len(ordered_hashes) == 1:
    only_hash = ordered_hashes[0]
    print(f"Only one fingerprint group found: {group_labels[only_hash]} ({only_hash})")
    print("Need at least two groups to compare config diffs.")
else:
    baseline_hash, *other_hashes = ordered_hashes
    baseline_label = group_labels[baseline_hash]
    baseline_fingerprint = fingerprint_groups[baseline_hash][0].get("config_fingerprint", {})

    print(f"Baseline group: {baseline_label} ({baseline_hash})")

    if baseline_fingerprint:
        banner = "=" * 72
        for hash_key in other_hashes:
            label = group_labels[hash_key]
            candidate_fingerprint = fingerprint_groups[hash_key][0].get("config_fingerprint", {})

            print("\n" + banner)
            print(f"Group {baseline_label} ({baseline_hash}) -> Group {label} ({hash_key})")
            print(banner)

            if not candidate_fingerprint:
                print("Target group has no config fingerprint.")
            else:
                diff = fingerprint_diff(baseline_fingerprint, candidate_fingerprint)
                if diff:
                    # Pad field names to the longest one so the values line up.
                    field_width = max(map(len, diff))
                    for field, values in diff.items():
                        print(f"  {field:<{field_width}}  {str(values['a']):<40} -> {values['b']}")
                else:
                    print("No config differences detected.")
    else:
        print("Baseline group has no config fingerprint.")


In [None]:
# Cell 4 — Statistical Comparison Matrix (A vs B, A vs C, B vs C, ...)
#
# For every pair of fingerprint groups, both sides are restricted to the parent
# anchors they share, so each comparison covers the same parent runs. Results are
# tabulated, then one valid pair is selected for the detailed views in Cells 5 and 7:
# DETAIL_PAIR when it matches, otherwise the first pair whose group A is the
# baseline (ordered_hashes[0]), otherwise the first valid pair.

pairwise_comparisons = []
selected_pair_result = None
comparison = None


def _format_p_value(p_value) -> str:
    """Render a p-value for the comparison table, or "n/a" when none was reported."""
    # NOTE(review): assumes compare_configs may report p_value=None (e.g. when every
    # paired difference is zero); formatting None with ".4f" would raise TypeError.
    if p_value is None:
        return "n/a"
    return f"{float(p_value):.4f}"


if len(ordered_hashes) < 2:
    print("Need at least two fingerprint groups for comparison.")
else:
    for hash_a, hash_b in combinations(ordered_hashes, 2):
        label_a = group_labels[hash_a]
        label_b = group_labels[hash_b]

        # Only compare on parent runs both groups cover, so the pair is like-for-like.
        matched_parent_ids = parent_ids_by_hash[hash_a] & parent_ids_by_hash[hash_b]
        if not matched_parent_ids:
            continue

        group_a = [
            summary
            for summary in fingerprint_groups[hash_a]
            if _summary_parent_anchor(summary) in matched_parent_ids
        ]
        group_b = [
            summary
            for summary in fingerprint_groups[hash_b]
            if _summary_parent_anchor(summary) in matched_parent_ids
        ]

        result = compare_configs(
            group_a,
            group_b,
            strategy=STRATEGY,
            metric_path=METRIC_PATH,
        )

        pairwise_comparisons.append(
            {
                "pair_id": f"{label_a}_vs_{label_b}",
                "hash_a": hash_a,
                "hash_b": hash_b,
                "label_a": label_a,
                "label_b": label_b,
                "matched_parent_ids": sorted(matched_parent_ids),
                "group_a": group_a,
                "group_b": group_b,
                "comparison": result,
            }
        )

    if not pairwise_comparisons:
        print("No comparable group pairs found after parent-anchor matching.")
    else:
        print("=" * 92)
        print("PAIRWISE COMPARISONS (parent-anchor matched)")
        print("=" * 92)
        print(
            f"{'Pair':<10} {'Hash A':<10} {'Hash B':<10} {'Anchors':>8} {'Runs A':>8} {'Runs B':>8} "
            f"{'Mean Diff':>10} {'p-value':>10} {'Significance':>14}"
        )
        print("-" * 92)

        for entry in pairwise_comparisons:
            result = entry["comparison"]
            # Hashes are truncated so they fit the 10-wide columns.
            hash_a_text = entry["hash_a"][:8]
            hash_b_text = entry["hash_b"][:8]
            if "error" in result:
                print(
                    f"{entry['pair_id']:<10} {hash_a_text:<10} {hash_b_text:<10} "
                    f"{len(entry['matched_parent_ids']):>8} {len(entry['group_a']):>8} {len(entry['group_b']):>8} "
                    f"{'n/a':>10} {'n/a':>10} {'error':>14}"
                )
                continue

            mean_diff = result["mean_diff"]
            p_text = _format_p_value((result.get("wilcoxon") or {}).get("p_value"))
            significance = result["significance"]
            print(
                f"{entry['pair_id']:<10} {hash_a_text:<10} {hash_b_text:<10} "
                f"{len(entry['matched_parent_ids']):>8} {result['num_runs_a']:>8} {result['num_runs_b']:>8} "
                f"{mean_diff:>+10.4f} {p_text:>10} {significance:>14}"
            )

        # Surface the actual error text; the table only says "error".
        for entry in pairwise_comparisons:
            if "error" in entry["comparison"]:
                print(f"  {entry['pair_id']} error: {entry['comparison']['error']}")

        valid_pairs = [entry for entry in pairwise_comparisons if "error" not in entry["comparison"]]

        if not valid_pairs:
            print("\nAll pairwise comparisons returned errors.")
        else:
            if DETAIL_PAIR:
                # Accept "A_vs_B", "a vs b", "AvsB", ... (case/space-insensitive).
                normalized = DETAIL_PAIR.strip().upper().replace(" ", "")
                for entry in valid_pairs:
                    pair_id = entry["pair_id"].upper().replace(" ", "")
                    pair_compact = pair_id.replace("_VS_", "VS")
                    if normalized in {pair_id, pair_compact}:
                        selected_pair_result = entry
                        break
                if selected_pair_result is None:
                    print(
                        f"\nDETAIL_PAIR={DETAIL_PAIR!r} matched no valid pair; "
                        "falling back to the default selection."
                    )

            if selected_pair_result is None:
                baseline_hash = ordered_hashes[0]
                baseline_candidates = [
                    entry for entry in valid_pairs if entry["hash_a"] == baseline_hash
                ]
                selected_pair_result = baseline_candidates[0] if baseline_candidates else valid_pairs[0]

            comparison = selected_pair_result["comparison"]
            print("\nSelected detailed pair:")
            print(
                f"  {selected_pair_result['pair_id']} "
                f"({selected_pair_result['hash_a']} vs {selected_pair_result['hash_b']})"
            )
            print(
                f"  Mean F1: {selected_pair_result['label_a']}={comparison['mean_f1_a']:.4f}, "
                f"{selected_pair_result['label_b']}={comparison['mean_f1_b']:.4f}, "
                f"delta={comparison['mean_diff']:+.4f}"
            )


In [None]:
# Cell 5 — Per-Test-Case Paired Plot (Selected Pair)
#
# Left: each test case's F1 under group A (blue) and group B (orange), joined by a
# line coloured by the direction of change. Right: per-test-case F1 delta (B - A)
# with the mean delta as a dashed line. Both panels share one colour rule, so a
# test case never reads as "unchanged" in one panel and "degraded" in the other.

# Deltas with |diff| <= DELTA_TOLERANCE are treated as unchanged (grey).
DELTA_TOLERANCE = 0.001


def _delta_color(diff: float) -> str:
    """Return green for an improvement, red for a degradation, grey when unchanged."""
    if diff > DELTA_TOLERANCE:
        return "#2ecc71"
    if diff < -DELTA_TOLERANCE:
        return "#e74c3c"
    return "#95a5a6"


if selected_pair_result and comparison and "paired_differences" in comparison:
    paired = comparison["paired_differences"]
    label_a = selected_pair_result["label_a"]
    label_b = selected_pair_result["label_b"]

    fig, axes = plt.subplots(1, 2, figsize=(16, max(5, len(paired) * 0.55 + 2)))

    test_case_ids = [entry["test_case_id"] for entry in paired]
    y_positions = list(range(len(test_case_ids)))

    # Left: paired dot plot
    ax = axes[0]
    for y_pos, entry in enumerate(paired):
        ax.plot(
            [entry["f1_a"], entry["f1_b"]],
            [y_pos, y_pos],
            color=_delta_color(entry["diff"]),
            linewidth=2.5,
            alpha=0.6,
        )
        ax.scatter(
            entry["f1_a"], y_pos, color="#3498db", s=70, zorder=5, edgecolors="white", linewidth=0.5
        )
        ax.scatter(
            entry["f1_b"], y_pos, color="#e67e22", s=70, zorder=5, edgecolors="white", linewidth=0.5
        )

    ax.set_yticks(y_positions)
    ax.set_yticklabels(test_case_ids, fontsize=8)
    ax.set_xlabel("F1 Score")
    ax.set_title(f"Per-Test-Case F1: Group {label_a} (blue) vs Group {label_b} (orange)")
    # Empty artists only to build the legend.
    ax.scatter([], [], color="#3498db", s=70, label=f"Group {label_a}")
    ax.scatter([], [], color="#e67e22", s=70, label=f"Group {label_b}")
    ax.plot([], [], color="#2ecc71", linewidth=2.5, label="Improved")
    ax.plot([], [], color="#e74c3c", linewidth=2.5, label="Degraded")
    ax.plot([], [], color="#95a5a6", linewidth=2.5, label="Unchanged")
    ax.legend(fontsize=8, loc="lower right")
    ax.set_xlim(-0.05, 1.05)
    ax.grid(True, alpha=0.3, axis="x")

    # Right: delta bar chart (same colour rule as the left panel)
    ax = axes[1]
    diffs = [entry["diff"] for entry in paired]
    bar_colors = [_delta_color(diff) for diff in diffs]
    ax.barh(y_positions, diffs, color=bar_colors, alpha=0.75, edgecolor="white", linewidth=0.3)
    ax.axvline(0, color="black", linewidth=0.8)
    ax.axvline(
        comparison["mean_diff"],
        color="#2c3e50",
        linewidth=1.5,
        linestyle="--",
        label=f"Mean delta: {comparison['mean_diff']:+.3f}",
    )
    ax.set_yticks(y_positions)
    ax.set_yticklabels(test_case_ids, fontsize=8)
    ax.set_xlabel(f"F1 Delta (Group {label_b} - Group {label_a})")
    ax.set_title("Per-Test-Case Delta")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="x")

    plt.suptitle(
        (
            f"Subrun Comparison - {PHASE} ({METRIC_LABEL}) | "
            f"{selected_pair_result['pair_id']} | "
            f"{comparison['significance'].replace('_', ' ')}"
        ),
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    pair_slug = selected_pair_result["pair_id"].lower()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        f"per_test_case_paired_plot_{pair_slug}",
        title=f"Subrun Comparison - {PHASE} ({METRIC_LABEL}) - {selected_pair_result['pair_id']}",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")
else:
    print("Run Cell 4 first to generate at least one valid pairwise comparison.")


In [None]:
# Cell 6 — Threshold Sweep Overlay (All Fingerprint Groups)
# Supports both metric paths:
# - post_rerank: rerank score threshold sweep from rerank_strategies[STRATEGY].threshold_sweep
# - pre_rerank: distance threshold sweep from baseline.distance_threshold_sweep
#
# For each fingerprint group, the per-run sweeps are merged on threshold value: the
# mean F1 curve is drawn, a min-max band is shaded when the group has more than one
# run, and the threshold with the best mean F1 is marked and annotated.

if not ordered_hashes:
    print("Cannot plot: no fingerprint groups loaded.")
else:
    fig, ax = plt.subplots(figsize=(12, 6))

    palette = sns.color_palette("tab10", n_colors=max(3, len(ordered_hashes)))
    line_styles = ["-", "--", "-.", ":"]
    is_pre_rerank_mode = METRIC_PATH == "pre_rerank"
    if is_pre_rerank_mode:
        threshold_label = "Distance Threshold (cosine distance)"
        annotation_prefix = "d"
        plot_mode_label = "Distance"
    else:
        threshold_label = "Rerank Score Threshold"
        annotation_prefix = "t"
        plot_mode_label = "Rerank Score"

    plotted_groups = 0

    for group_index, hash_key in enumerate(ordered_hashes):
        label = group_labels.get(hash_key, hash_key[:8])
        color = palette[group_index % len(palette)]

        # One {threshold: f1} mapping per run that actually recorded a sweep.
        all_sweep_dicts: list[dict[float, float]] = []
        for summary in fingerprint_groups[hash_key]:
            if is_pre_rerank_mode:
                sweep_data = summary.get("baseline", {}).get("distance_threshold_sweep", {})
            else:
                strategy_block = summary.get("rerank_strategies", {}).get(STRATEGY, {})
                sweep_data = strategy_block.get("threshold_sweep", {})

            full_sweep = sweep_data.get("full_sweep", [])
            if full_sweep:
                all_sweep_dicts.append({point["threshold"]: point["f1"] for point in full_sweep})

        if not all_sweep_dicts:
            continue

        all_thresholds = sorted(set().union(*(sweep.keys() for sweep in all_sweep_dicts)))

        # Each threshold comes from at least one run, so every bucket is non-empty.
        f1_buckets = [
            [sweep[threshold] for sweep in all_sweep_dicts if threshold in sweep]
            for threshold in all_thresholds
        ]
        mean_f1_values = [float(np.mean(bucket)) for bucket in f1_buckets]
        min_f1_values = [min(bucket) for bucket in f1_buckets]
        max_f1_values = [max(bucket) for bucket in f1_buckets]

        ax.plot(
            all_thresholds,
            mean_f1_values,
            label=f"Group {label} ({hash_key[:8]}, n={len(all_sweep_dicts)})",
            color=color,
            linewidth=2,
            linestyle=line_styles[group_index % len(line_styles)],
        )
        plotted_groups += 1

        if len(all_sweep_dicts) > 1:
            ax.fill_between(all_thresholds, min_f1_values, max_f1_values, color=color, alpha=0.15)

        best_index = int(np.argmax(mean_f1_values))
        best_threshold = all_thresholds[best_index]
        best_f1 = mean_f1_values[best_index]
        ax.axvline(best_threshold, color=color, linestyle=":", alpha=0.55, linewidth=1.2)
        ax.annotate(
            f"{annotation_prefix}={best_threshold:.3f}\nF1={best_f1:.3f}",
            xy=(best_threshold, best_f1),
            xytext=(8, 8),
            textcoords="offset points",
            fontsize=7,
            color=color,
            ha="left",
        )

    if plotted_groups == 0:
        if is_pre_rerank_mode:
            print(
                "No distance threshold sweep data found for METRIC_PATH='pre_rerank'. "
                "Expected baseline.distance_threshold_sweep.full_sweep in run summaries."
            )
        else:
            print(
                f"No rerank threshold sweep data found for METRIC_PATH='post_rerank' "
                f"and STRATEGY='{STRATEGY}'."
            )
        plt.close(fig)
    else:
        sweep_title = (
            f"{plot_mode_label} Threshold Sweep by Fingerprint Group - {PHASE} ({METRIC_LABEL})"
        )
        ax.set_xlabel(threshold_label)
        ax.set_ylabel("Macro F1")
        ax.set_title(sweep_title)
        ax.legend(fontsize=9)
        ax.set_ylim(0, 1.05)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            "threshold_sweep_overlay_all_groups",
            title=sweep_title,
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")


In [None]:
# Cell 7 — Per-Test-Case Heatmaps (All Groups + Selected Pair Delta)
#
# Heatmap 1: mean F1 per test case (rows) for every fingerprint group (columns),
# averaged over all runs in the group (not restricted to shared parent anchors as
# in Cell 4); cells are left blank (NaN) when a group has no value for a test case.
# Heatmap 2 (only when Cell 4 selected a pair): per-test-case delta (B - A) for the
# selected pair, followed by a text breakdown sorted by delta.

all_test_case_ids = sorted(
    {
        test_case_id
        for summary in all_summaries
        for test_case_id in summary.get("per_test_case", {})
    }
)

if not ordered_hashes:
    print("Cannot plot: no fingerprint groups loaded.")
elif not all_test_case_ids:
    # sns.heatmap cannot draw a 0-row matrix.
    print("Cannot plot: no per_test_case entries found in the loaded summaries.")
else:
    heatmap_data = np.full((len(all_test_case_ids), len(ordered_hashes)), np.nan)
    column_labels = []

    for col_index, hash_key in enumerate(ordered_hashes):
        label = group_labels.get(hash_key, hash_key[:8])
        column_labels.append(f"{label}\n{hash_key[:8]}")
        group_summaries = fingerprint_groups[hash_key]

        for row_index, test_case_id in enumerate(all_test_case_ids):
            f1_values = []
            for summary in group_summaries:
                per_tc = summary.get("per_test_case", {}).get(test_case_id, {})
                f1_value = _extract_f1_from_per_test_case(per_tc, METRIC_PATH, STRATEGY)
                if f1_value is not None:
                    f1_values.append(f1_value)
            if f1_values:
                heatmap_data[row_index, col_index] = float(np.mean(f1_values))

    fig, ax = plt.subplots(
        figsize=(max(8, len(ordered_hashes) * 2.2), max(6, len(all_test_case_ids) * 0.55))
    )
    sns.heatmap(
        heatmap_data,
        annot=True,
        fmt=".2f",
        xticklabels=column_labels,
        yticklabels=all_test_case_ids,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        linewidths=0.5,
        ax=ax,
    )
    ax.set_title(f"Per-Test-Case F1 by Fingerprint Group - {PHASE} ({METRIC_LABEL})")
    ax.set_ylabel("Test Case")
    ax.set_xlabel("Fingerprint Group")
    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "per_test_case_f1_heatmap_all_groups",
        title=f"Per-Test-Case F1 by Fingerprint Group - {PHASE} ({METRIC_LABEL})",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

    if selected_pair_result and comparison and "paired_differences" in comparison:
        label_a = selected_pair_result["label_a"]
        label_b = selected_pair_result["label_b"]
        pair_id = selected_pair_result["pair_id"]

        delta_by_case = {
            entry["test_case_id"]: entry["diff"] for entry in comparison.get("paired_differences", [])
        }
        delta_rows = [delta_by_case.get(test_case_id, np.nan) for test_case_id in all_test_case_ids]
        delta_data = np.array(delta_rows, dtype=float).reshape(-1, 1)

        # Symmetric colour range around 0. np.nanmax returns NaN for an all-NaN column
        # and max(NaN, 0.01) stays NaN, so only finite deltas are considered.
        finite_deltas = np.abs(delta_data[np.isfinite(delta_data)])
        delta_max = max(float(finite_deltas.max()), 0.01) if finite_deltas.size else 0.01

        fig, ax = plt.subplots(figsize=(3.3, max(6, len(all_test_case_ids) * 0.55)))
        sns.heatmap(
            delta_data,
            annot=True,
            fmt="+.2f",
            xticklabels=[f"Delta\n({label_b}-{label_a})"],
            yticklabels=all_test_case_ids,
            cmap="RdYlGn",
            vmin=-delta_max,
            vmax=delta_max,
            center=0,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_title(f"Selected Pair Delta - {pair_id}")
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            f"per_test_case_delta_heatmap_{pair_id.lower()}",
            title=f"Selected Pair Delta - {pair_id}",
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")

        print("\nSelected pair test-case breakdown:")
        print(f"  {'Test Case':<35} {'F1 A':>8} {'F1 B':>8} {'Delta':>8}")
        print(f"  {'-' * 35} {'-' * 8} {'-' * 8} {'-' * 8}")
        for entry in sorted(comparison["paired_differences"], key=lambda item: item["diff"], reverse=True):
            print(
                f"  {entry['test_case_id']:<35} {entry['f1_a']:>8.3f} {entry['f1_b']:>8.3f} {entry['diff']:>+8.3f}"
            )
