# LG-CoTrain: Alternative Early Stopping — Full Experiment Run

This notebook runs the **complete experiment sweep** (all budgets × all seed sets × all events)
under each of the 6 stopping strategies, then produces the same comparison charts as
notebook 04 but with statistically reliable mean ± std estimates.

| Strategy | Key Idea |
|---|---|
| `baseline` | Original: stop when ensemble macro-F1 plateaus for `patience` epochs |
| `no_early_stopping` | Run all `finetune_max_epochs`; restore best-ever checkpoint (upper bound) |
| `per_class_patience` | Stop only when **every** class F1 has individually plateaued |
| `weighted_macro_f1` | Weight rare classes more in the stopping metric |
| `balanced_dev` | Resample dev set to equal class sizes for the stopping signal |
| `scaled_threshold` | Require a larger improvement delta for highly imbalanced events |

**Total experiments**: 6 strategies × 10 events × 4 budgets × 3 seeds = **720 runs**

Results are stored in `results/{PSEUDO_LABEL_SOURCE}/stop/{strategy}/` (e.g. `results/gpt-4o/stop/baseline/`).
See **notebook 04** for a fast preview using budget=50, seed=1 only.

## Resume Support

Every `(strategy, event, budget, seed_set)` combination is written to its own
`{results_root}/{event}/{budget}_set{seed_set}/metrics.json` as soon as it completes.
Re-running the experiment cell automatically skips combinations that are already complete.

In [1]:
import json
import statistics
import sys
import time
from pathlib import Path


def _find_repo_root(marker: str = "lg_cotrain") -> Path:
    for candidate in [Path().resolve()] + list(Path().resolve().parents):
        if (candidate / marker).is_dir():
            return candidate
    raise RuntimeError(
        f"Cannot find repo root: no ancestor directory contains '{marker}/'. "
        "Run the notebook from inside the repository."
    )


# Put the repository root at the front of sys.path (once) so the
# `lg_cotrain` imports below work from any directory inside the repository.
repo_root = _find_repo_root()
if sys.path.count(str(repo_root)) == 0:
    sys.path.insert(0, str(repo_root))

import matplotlib.pyplot as plt
import numpy as np

from lg_cotrain.run_all import BUDGETS, SEED_SETS, format_summary_table, run_all_experiments


class ProgressTracker:
    """Track global progress across all strategies × events × budgets × seeds.

    The tracker is seeded with the number of experiments that were already
    complete before this session, so the ``done/total`` counter and the
    percentage reflect the whole sweep. The ETA, however, is derived only
    from experiments finished *during this session*: earlier experiments
    cost no time on the current clock, so including them in the rate would
    make a resumed sweep look far faster than it is.

    Attributes:
        total: Number of experiments in the whole sweep.
        done: Experiments completed so far (earlier sessions + this one).
        start_time: ``time.time()`` value at which this session started.
    """

    def __init__(self, total: int, already_done: int, start_time: float):
        """Create a tracker.

        Args:
            total: Total number of experiments in the sweep.
            already_done: Experiments completed before this session started.
            start_time: ``time.time()`` timestamp marking the session start.
        """
        self.total = total
        self.done = already_done
        self.start_time = start_time
        # Baseline for the ETA: only work done after this point took time
        # measured against ``start_time``.
        self._done_at_start = already_done

    def update(self, event, budget, seed_set, status):
        """Record one more finished experiment and print a progress line.

        The signature matches the ``_on_experiment_done`` callback this
        tracker is registered as; ``event``, ``budget`` and ``seed_set`` are
        accepted for that reason but not used here.

        Args:
            event: Event name of the finished experiment (unused).
            budget: Budget of the finished experiment (unused).
            seed_set: Seed set of the finished experiment (unused).
            status: Short status text appended to the progress line.
        """
        self.done += 1
        elapsed = time.time() - self.start_time
        pct = 100.0 * self.done / self.total if self.total else 0
        elapsed_h = elapsed / 3600
        # Never report a negative amount of remaining work.
        remaining = max(self.total - self.done, 0)
        # Average seconds per experiment over this session only.
        session_done = self.done - self._done_at_start
        eta_h = (elapsed / session_done) * remaining / 3600 if session_done > 0 else 0
        print(
            f"  [PROGRESS] {self.done}/{self.total} ({pct:.1f}%)"
            f"  |  Elapsed: {elapsed_h:.2f}h  |  ETA: {eta_h:.2f}h  |  {status}"
        )


# Echo the resolved repository root and the sweep grid imported from
# lg_cotrain.run_all so the run is self-describing.
for label, value in (
    ("Repo root", repo_root),
    ("Budgets  ", BUDGETS),
    ("Seed sets", SEED_SETS),
):
    print(f"{label}: {value}")

Repo root: D:\Workspace\Co-Training
Budgets  : [5, 10, 25, 50]
Seed sets: [1, 2, 3]


In [None]:
# ---- Configuration ----

# Pseudo-label source; also the name of the results sub-tree for this sweep.
PSEUDO_LABEL_SOURCE = "gpt-4o"

# Stopping strategies to compare, in the order they are run and reported.
STRATEGIES = [
    "baseline",
    "no_early_stopping",
    "per_class_patience",
    "weighted_macro_f1",
    "balanced_dev",
    "scaled_threshold",
]

DATA_ROOT = str(repo_root / "data")

# Events are discovered from disk: one per sub-directory of <DATA_ROOT>/original.
TARGET_EVENTS = sorted(
    [entry.name for entry in Path(DATA_ROOT, "original").iterdir() if entry.is_dir()]
)

# Results layout: results/<PSEUDO_LABEL_SOURCE>/stop/<strategy>/
STRATEGY_RESULTS_ROOTS = dict(
    (name, str(repo_root.joinpath("results", PSEUDO_LABEL_SOURCE, "stop", name)))
    for name in STRATEGIES
)

# Size of the sweep: per strategy, and across all strategies.
expts_per_strategy = len(TARGET_EVENTS) * len(BUDGETS) * len(SEED_SETS)
total_runs = expts_per_strategy * len(STRATEGIES)

print(
    f"Strategies : {STRATEGIES}",
    f"Events     : {len(TARGET_EVENTS)} events",
    f"Per strategy: {expts_per_strategy} experiments  ({len(TARGET_EVENTS)} events × {len(BUDGETS)} budgets × {len(SEED_SETS)} seeds)",
    f"Grand total : {total_runs} experiments",
    "",
    sep="\n",
)
for strategy_name, root in STRATEGY_RESULTS_ROOTS.items():
    print(f"  {strategy_name:<25} → {root}")

## Running Experiments

The loop below processes one strategy at a time, running all events × budgets × seeds
for each. Completed experiments are skipped automatically.

> **Runtime estimate**: ~3-5 hours per strategy on GPU, ~18-30 hours total.
> Run overnight or split across sessions — resume support ensures no work is lost.

In [3]:
# Count the metrics.json files already on disk, across every strategy, so the
# progress counter starts from the true position when a sweep is resumed.
already_done = sum(
    (
        Path(STRATEGY_RESULTS_ROOTS[name]) / ev / f"{b}_set{sd}" / "metrics.json"
    ).exists()
    for name in STRATEGIES
    for ev in TARGET_EVENTS
    for b in BUDGETS
    for sd in SEED_SETS
)
total_experiments = expts_per_strategy * len(STRATEGIES)

print(
    f"Total experiments : {total_experiments}",
    f"Already completed : {already_done}",
    f"Remaining         : {total_experiments - already_done}",
    "",
    sep="\n",
)

overall_start = time.time()
tracker = ProgressTracker(total_experiments, already_done, overall_start)
all_strategy_results = {}  # strategy -> event -> list[result_dict]

# One strategy at a time; within a strategy, one event at a time. Each call to
# run_all_experiments covers that event's budgets × seed sets and reports every
# finished experiment to the shared tracker.
for strategy in STRATEGIES:
    results_root = STRATEGY_RESULTS_ROOTS[strategy]
    strat_start = time.time()
    print("\n" + "=" * 70)
    print(f"Strategy: {strategy}  →  {results_root}")
    print("=" * 70)
    all_strategy_results[strategy] = {}

    for event in TARGET_EVENTS:
        print(f"\n  Event: {event}")
        results = run_all_experiments(
            event,
            pseudo_label_source=PSEUDO_LABEL_SOURCE,
            stopping_strategy=strategy,
            data_root=DATA_ROOT,
            results_root=results_root,
            _on_experiment_done=tracker.update,
        )
        # Stored only once the whole event has finished, so an interrupted
        # event leaves no partial entry in memory.
        all_strategy_results[strategy][event] = results
        print()
        print(format_summary_table(results, event))

    strat_elapsed = time.time() - strat_start
    print(
        f"\n  Strategy '{strategy}' done in "
        f"{strat_elapsed / 3600:.2f}h ({strat_elapsed / 60:.1f}min)"
    )

total_elapsed = time.time() - overall_start
print("\n" + "=" * 70)
print(
    "All strategies complete in "
    f"{total_elapsed / 3600:.2f}h ({total_elapsed / 60:.1f}min)"
)

Total experiments : 720
Already completed : 0
Remaining         : 720


Strategy: baseline  →  D:\Workspace\Co-Training\results\gpt-4o-stop-baseline

  Event: california_wildfires_2018


  from .autonotebook import tqdm as notebook_tqdm


[1/12] budget=5, seed=1 -- starting...


2026-02-20 09:31:59,510 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=5, seed_set=1
2026-02-20 09:31:59,557 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 09:31:59,570 - lg_cotrain - INFO - D_l1: 30, D_l2: 20, D_LG: 5113
2026-02-20 09:31:59,571 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1117.91it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1144.22it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 09:32:18,527 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.0981, m

[1/12] budget=5, seed=1 -- done (macro_f1=0.6291)
  [PROGRESS] 1/720 (0.1%)  |  Elapsed: 0.26h  |  ETA: 184.97h  |  done
[2/12] budget=5, seed=2 -- starting...


2026-02-20 09:47:06,918 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=5, seed_set=2
2026-02-20 09:47:06,969 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 09:47:06,980 - lg_cotrain - INFO - D_l1: 30, D_l2: 20, D_LG: 5113
2026-02-20 09:47:06,982 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1019.40it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1042.36it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 09:47:27,775 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.0968, m

[2/12] budget=5, seed=2 -- done (macro_f1=0.5977)
  [PROGRESS] 2/720 (0.3%)  |  Elapsed: 0.50h  |  ETA: 179.07h  |  done
[3/12] budget=5, seed=3 -- starting...


2026-02-20 10:01:36,466 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=5, seed_set=3
2026-02-20 10:01:36,513 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 10:01:36,513 - lg_cotrain - INFO - D_l1: 30, D_l2: 20, D_LG: 5113
2026-02-20 10:01:36,523 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1119.15it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1155.75it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 10:01:55,369 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1056, m

[3/12] budget=5, seed=3 -- done (macro_f1=0.5642)
  [PROGRESS] 3/720 (0.4%)  |  Elapsed: 0.74h  |  ETA: 176.36h  |  done
[4/12] budget=10, seed=1 -- starting...


2026-02-20 10:15:57,277 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=10, seed_set=1
2026-02-20 10:15:57,312 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 10:15:57,321 - lg_cotrain - INFO - D_l1: 50, D_l2: 50, D_LG: 5063
2026-02-20 10:15:57,323 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1196.34it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1180.10it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 10:16:16,219 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1004, 

[4/12] budget=10, seed=1 -- done (macro_f1=0.6017)
  [PROGRESS] 4/720 (0.6%)  |  Elapsed: 0.97h  |  ETA: 174.08h  |  done
[5/12] budget=10, seed=2 -- starting...


2026-02-20 10:30:02,106 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=10, seed_set=2
2026-02-20 10:30:02,158 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 10:30:02,169 - lg_cotrain - INFO - D_l1: 50, D_l2: 50, D_LG: 5063
2026-02-20 10:30:02,172 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1188.14it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1148.49it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 10:30:21,136 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.0967, 

[5/12] budget=10, seed=2 -- done (macro_f1=0.5955)
  [PROGRESS] 5/720 (0.7%)  |  Elapsed: 1.21h  |  ETA: 172.84h  |  done
[6/12] budget=10, seed=3 -- starting...


2026-02-20 10:44:11,945 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=10, seed_set=3
2026-02-20 10:44:11,987 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 10:44:12,001 - lg_cotrain - INFO - D_l1: 50, D_l2: 50, D_LG: 5063
2026-02-20 10:44:12,004 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1169.40it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1070.68it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 10:44:30,934 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1066, 

[6/12] budget=10, seed=3 -- done (macro_f1=0.6334)
  [PROGRESS] 6/720 (0.8%)  |  Elapsed: 1.46h  |  ETA: 173.47h  |  done
[7/12] budget=25, seed=1 -- starting...


2026-02-20 10:59:08,595 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=25, seed_set=1
2026-02-20 10:59:08,633 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 10:59:08,645 - lg_cotrain - INFO - D_l1: 130, D_l2: 120, D_LG: 4913
2026-02-20 10:59:08,647 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1118.13it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1090.39it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 10:59:28,142 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1076

[7/12] budget=25, seed=1 -- done (macro_f1=0.6243)
  [PROGRESS] 7/720 (1.0%)  |  Elapsed: 1.70h  |  ETA: 173.25h  |  done
[8/12] budget=25, seed=2 -- starting...


2026-02-20 11:13:44,047 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=25, seed_set=2
2026-02-20 11:13:44,096 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 11:13:44,105 - lg_cotrain - INFO - D_l1: 130, D_l2: 120, D_LG: 4913
2026-02-20 11:13:44,106 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1086.54it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1054.73it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 11:14:04,325 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1022

[8/12] budget=25, seed=2 -- done (macro_f1=0.6577)
  [PROGRESS] 8/720 (1.1%)  |  Elapsed: 1.95h  |  ETA: 173.80h  |  done
[9/12] budget=25, seed=3 -- starting...


2026-02-20 11:28:50,897 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=25, seed_set=3
2026-02-20 11:28:50,947 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 11:28:50,958 - lg_cotrain - INFO - D_l1: 130, D_l2: 120, D_LG: 4913
2026-02-20 11:28:50,960 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1023.16it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1010.95it/s, Materializing param=bert.pooler.dense.weight]
2026-02-20 11:29:11,457 - lg_cotrain - INFO - Phase 1 epoch 1/7: mean_prob1=0.1079

KeyboardInterrupt: 

In [None]:
# Load any results that already existed (re-run safe)
#
# `all_strategy_results` is normally created by the experiment cell. When this
# cell is run on a fresh kernel for analysis only (experiment cell skipped),
# start from an empty mapping instead of failing with NameError; everything is
# then read back from the metrics.json files on disk.
try:
    all_strategy_results
except NameError:
    all_strategy_results = {}

for strategy in STRATEGIES:
    results_root = Path(STRATEGY_RESULTS_ROOTS[strategy])
    if strategy not in all_strategy_results:
        all_strategy_results[strategy] = {}

    for event in TARGET_EVENTS:
        # Results produced in this session take precedence over files on disk.
        if event in all_strategy_results[strategy]:
            continue
        results = []
        for budget in BUDGETS:
            for seed_set in SEED_SETS:
                path = results_root / event / f"{budget}_set{seed_set}" / "metrics.json"
                if path.exists():
                    # Explicit encoding: do not depend on the platform locale.
                    with open(path, encoding="utf-8") as f:
                        results.append(json.load(f))
        if results:
            all_strategy_results[strategy][event] = results

# Build lookup: lookup[strategy][event][(budget, seed_set)] -> result
lookup = {}
for strategy in STRATEGIES:
    lookup[strategy] = {}
    for event in TARGET_EVENTS:
        lookup[strategy][event] = {}
        for r in all_strategy_results.get(strategy, {}).get(event, []):
            # In-memory result lists may contain None entries; skip them.
            if r is not None:
                lookup[strategy][event][(r["budget"], r["seed_set"])] = r

# Coverage report: how many of the expected experiments are available.
print("Experiments loaded per strategy:")
for strategy in STRATEGIES:
    n = sum(len(lookup[strategy][e]) for e in TARGET_EVENTS)
    expected = expts_per_strategy
    pct = 100 * n / expected if expected else 0
    print(f"  {strategy:<25}: {n:>3}/{expected}  ({pct:.0f}%)")

In [None]:
# Summary tables: for each event, show mean ± std macro-F1 (across seed sets)
# per strategy × budget, plus a delta-from-baseline table.

label_w = 26
# Every budget column is as wide as its widest cell, " 0.0000±0.0000"
# (14 characters), so headers, values and N/A placeholders line up.
col_w = 14
na_cell = f" {'N/A':<{col_w - 1}}"


def _seed_f1s(strategy, event, budget):
    """Return the test macro-F1 of every seed set that has a result.

    Seed sets with no loaded result, or whose result has no
    ``test_macro_f1`` entry, are left out.
    """
    scores = (
        lookup[strategy][event].get((budget, seed), {}).get("test_macro_f1")
        for seed in SEED_SETS
    )
    return [score for score in scores if score is not None]


def _table_header(last_column):
    """Return the header line and a separator rule of exactly the same length."""
    budget_cols = "".join(f" {'B=' + str(b):<{col_w - 1}}" for b in BUDGETS)
    header = f"{'Strategy':<{label_w}}{budget_cols} | {last_column}"
    return header, "-" * len(header)


for event in TARGET_EVENTS:
    print(f"\n{'=' * 70}")
    print(f"Event: {event}")
    print(f"{'=' * 70}")

    # --- Absolute macro-F1 ---
    header, rule = _table_header("Mean")
    print(header)
    print(rule)

    # strategy -> {budget: mean F1, or None when no seed has a result};
    # reused below so the delta table does not recompute the means.
    means_by_strategy = {}
    for strategy in STRATEGIES:
        row = f"{strategy:<{label_w}}"
        budget_means = []
        for budget in BUDGETS:
            f1s = _seed_f1s(strategy, event, budget)
            if f1s:
                m = statistics.mean(f1s)
                # Sample std needs at least two seeds; report 0 otherwise.
                sd = statistics.stdev(f1s) if len(f1s) >= 2 else 0.0
                budget_means.append(m)
                row += f" {m:.4f}±{sd:.4f}"
            else:
                budget_means.append(None)
                row += na_cell
        valid = [v for v in budget_means if v is not None]
        row += f" | {statistics.mean(valid):.4f}" if valid else " | N/A"
        print(row)
        means_by_strategy[strategy] = dict(zip(BUDGETS, budget_means))

    baseline_means = means_by_strategy.get("baseline", {})

    # --- Delta vs baseline ---
    print()
    print("Delta vs baseline  (+) = better:")
    header, rule = _table_header("Mean Δ")
    print(header)
    print(rule)

    for strategy in STRATEGIES:
        if strategy == "baseline":
            continue
        row = f"{strategy:<{label_w}}"
        deltas = []
        for budget in BUDGETS:
            m = means_by_strategy[strategy][budget]
            base = baseline_means.get(budget)
            if m is not None and base is not None:
                d = m - base
                deltas.append(d)
                # The "+" format flag prints an explicit sign for both directions.
                row += f" {format(d, '+.4f'):<{col_w - 1}}"
            else:
                row += na_cell
        row += f" | {statistics.mean(deltas):+.4f}" if deltas else " | N/A"
        print(row)

In [None]:
# Grouped bar chart of test macro-F1. One subplot per event, one group of bars
# per budget, one bar per strategy. Bar height is the mean over the seed sets
# that have a result (0 when none do); the error bar is the sample std when at
# least two seeds are available, otherwise 0.

n_events     = len(TARGET_EVENTS)
n_strategies = len(STRATEGIES)
bar_width    = 0.8 / n_strategies
colors       = plt.cm.tab10(np.linspace(0, 1, n_strategies))

# Layout: up to 5 events per row
ncols = min(5, n_events)
nrows = -(-n_events // ncols)  # ceiling division
fig, axes = plt.subplots(
    nrows, ncols,
    figsize=(4.5 * ncols, 4.5 * nrows),
    sharey=False,
    squeeze=False,  # always a 2-D grid, so no special case for a single subplot
)
axes_flat = axes.ravel()

group_centers = np.arange(len(BUDGETS))

for ax, event in zip(axes_flat, TARGET_EVENTS):
    for idx, strategy in enumerate(STRATEGIES):
        # For each budget, the F1 values of the seed sets that have a result.
        per_budget = [
            [
                f1
                for f1 in (
                    lookup[strategy][event].get((budget, seed), {}).get("test_macro_f1")
                    for seed in SEED_SETS
                )
                if f1 is not None
            ]
            for budget in BUDGETS
        ]
        means = [statistics.mean(vals) if vals else 0 for vals in per_budget]
        errs = [statistics.stdev(vals) if len(vals) >= 2 else 0 for vals in per_budget]
        # Centre the strategies' bars around each budget tick.
        shift = (idx - n_strategies / 2 + 0.5) * bar_width
        ax.bar(
            group_centers + shift,
            means,
            bar_width * 0.9,
            yerr=errs,
            capsize=2,
            label=strategy,
            color=colors[idx],
            alpha=0.85,
        )
    ax.set_title(event.replace("_", " ").title(), fontsize=9)
    ax.set_xlabel("Budget")
    ax.set_ylabel("Macro-F1")
    ax.set_xticks(group_centers)
    ax.set_xticklabels([str(b) for b in BUDGETS])
    ax.set_ylim(0, 1)
    ax.grid(axis="y", alpha=0.3)

# Hide unused subplots
for spare_ax in axes_flat[n_events:]:
    spare_ax.set_visible(False)

# A single legend on the first subplot serves every panel.
axes_flat[0].legend(fontsize=7, loc="upper left", framealpha=0.7)
fig.suptitle(
    f"Stopping Strategy Comparison — All Budgets & Seeds\n(pseudo-labels: {PSEUDO_LABEL_SOURCE})",
    fontsize=12,
)
plt.tight_layout()
plt.show()

In [None]:
# Cross-event summary: for each strategy, show mean ± std macro-F1 per budget,
# pooled over all events and seed sets, plus the overall mean across budgets.
# Useful for picking a single "best" strategy to recommend.

# Every budget column is as wide as its widest cell, " 0.0000±0.0000"
# (14 characters), so the header, the values and N/A placeholders line up.
summary_col_w = 14
summary_header = (
    f"{'Strategy':<26}"
    + "".join(f" {'B=' + str(b):<{summary_col_w - 1}}" for b in BUDGETS)
    + " | Overall"
)

print("Grand summary — mean macro-F1 across all events and seeds\n")
print(summary_header)
print("-" * len(summary_header))

for strategy in STRATEGIES:
    row = f"{strategy:<26}"
    all_f1s = []  # every available F1 for this strategy, across all budgets
    for budget in BUDGETS:
        f1s = [
            lookup[strategy][e].get((budget, s), {}).get("test_macro_f1")
            for e in TARGET_EVENTS
            for s in SEED_SETS
        ]
        f1s = [f for f in f1s if f is not None]
        if f1s:
            m = statistics.mean(f1s)
            # Sample std needs at least two values; report 0 otherwise.
            sd = statistics.stdev(f1s) if len(f1s) >= 2 else 0.0
            all_f1s.extend(f1s)
            row += f" {m:.4f}±{sd:.4f}"
        else:
            row += f" {'N/A':<{summary_col_w - 1}}"
    row += f" | {statistics.mean(all_f1s):.4f}" if all_f1s else " | N/A"
    print(row)

# Per-class F1 heatmap for each event at budget=5 (hardest case).
# One figure per event: rows are strategies, columns are classes, and each cell
# is the per-class test F1 averaged over the seed sets that have a result.
from lg_cotrain.data_loading import CLASS_LABELS

HEATMAP_BUDGET = 5
print(f"\nPer-class F1 heatmaps at budget={HEATMAP_BUDGET} (hardest imbalance scenario)")

for event in TARGET_EVENTS:
    strategies_with_data = []
    class_f1_matrix = []

    for strategy in STRATEGIES:
        # Per-class F1 vectors of every seed set that has one at this budget.
        per_class_all_seeds = [
            lookup[strategy][event][(HEATMAP_BUDGET, s)]["test_per_class_f1"]
            for s in SEED_SETS
            if (HEATMAP_BUDGET, s) in lookup[strategy][event]
            and "test_per_class_f1" in lookup[strategy][event][(HEATMAP_BUDGET, s)]
        ]
        if per_class_all_seeds:
            mean_per_class = [
                statistics.mean(seed[i] for seed in per_class_all_seeds)
                for i in range(len(per_class_all_seeds[0]))
            ]
            strategies_with_data.append(strategy)
            class_f1_matrix.append(mean_per_class)

    if not strategies_with_data:
        print(f"  No per-class data for {event} at budget={HEATMAP_BUDGET}, skipping.")
        continue

    data = np.array(class_f1_matrix)

    # Size the axes from the data rather than from CLASS_LABELS: indexing
    # data[i, j] over len(CLASS_LABELS) raises IndexError when an event's
    # per-class vector is shorter than the global label list.
    n_classes = data.shape[1]
    if n_classes == len(CLASS_LABELS):
        class_names = list(CLASS_LABELS)
    else:
        # NOTE(review): the training log reports classes as detected per event,
        # so an event may carry fewer classes than CLASS_LABELS. The stored
        # results do not visibly include the class names, so fall back to
        # positional labels rather than attaching possibly wrong names —
        # confirm the ordering against lg_cotrain before relabelling.
        print(
            f"  Note: {event} has {n_classes} per-class scores but CLASS_LABELS has "
            f"{len(CLASS_LABELS)} entries; labelling columns by index."
        )
        class_names = [f"class {j}" for j in range(n_classes)]

    fig, ax = plt.subplots(
        figsize=(max(9, n_classes * 0.75), len(strategies_with_data) * 0.65 + 1.8)
    )
    im = ax.imshow(data, cmap="RdYlGn", aspect="auto", vmin=0, vmax=1)

    ax.set_xticks(range(n_classes))
    ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=8)
    ax.set_yticks(range(len(strategies_with_data)))
    ax.set_yticklabels(strategies_with_data, fontsize=9)
    ax.set_title(
        f"{event}  |  Budget={HEATMAP_BUDGET}  |  Per-class F1 (mean across seeds)",
        fontsize=10,
    )

    # Annotate every cell; mid-range cells are light on this colormap, so they
    # get black text and the darker extremes get white.
    for i in range(len(strategies_with_data)):
        for j in range(n_classes):
            val = data[i, j]
            color = "black" if 0.25 < val < 0.75 else "white"
            ax.text(j, i, f"{val:.2f}", ha="center", va="center", fontsize=7, color=color)

    fig.colorbar(im, ax=ax, label="F1 Score")
    plt.tight_layout()
    plt.show()

In [None]:
# Rebuild multi-tab dashboard with all strategy result sets.
# Scans everything under <repo_root>/results (not only this sweep) and writes
# a single HTML file next to the result folders.

from lg_cotrain.dashboard import discover_result_sets, generate_html_multi

TOP_RESULTS_ROOT = str(repo_root / "results")

result_sets = discover_result_sets(TOP_RESULTS_ROOT)
print(f"Discovered {len(result_sets)} model(s):")
for model, types in result_sets.items():
    for exp_type, experiments in types.items():
        for name, path in experiments:
            print(f"  {model}/{exp_type}/{name:<20} -> {path}")

html = generate_html_multi(result_sets, data_root=DATA_ROOT)
dashboard_path = Path(TOP_RESULTS_ROOT) / "dashboard.html"
# Write UTF-8 explicitly: without an encoding, write_text uses the platform
# locale (e.g. cp1252 on Windows), which can fail or garble any non-ASCII
# characters in the generated page.
dashboard_path.write_text(html, encoding="utf-8")
print(f"\nDashboard written to: {dashboard_path}")
print("Open in a browser to compare all strategies across all budgets and events.")