# LG-CoTrain: Global Optuna Hyperparameter Tuning

This notebook finds the **optimal hyperparameters** for the LG-CoTrain pipeline
using a single global [Optuna](https://optuna.org/) study. Each trial runs the
full 3-phase pipeline across all 10 disaster events (budget=50, seed=1) and
optimizes the **mean dev macro-F1**.
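
Written out, the objective for one hyperparameter configuration is a mean over events of each event's dev macro-F1. Macro-F1 is itself the unweighted mean of per-class F1. The sketch below makes two assumptions: it uses the standard macro-F1 definition (the one scikit-learn uses for `average="macro"`), and it weights every event equally. `macro_f1` and `mean_dev_macro_f1` are hypothetical helpers, not functions from `lg_cotrain`.

```python
from collections import Counter


def macro_f1(y_true: list, y_pred: list) -> float:
    """Unweighted mean of per-class F1 over every class in y_true or y_pred."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("need two non-empty label lists of equal length")
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p was wrong
            fn[t] += 1  # true class t was missed
    classes = set(y_true) | set(y_pred)
    # Per-class F1 = 2TP / (2TP + FP + FN); each class here has a non-zero denominator
    return sum(2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in classes) / len(classes)


def mean_dev_macro_f1(per_event_f1: dict) -> float:
    """Objective (assumed): each event's dev macro-F1 weighted equally."""
    return sum(per_event_f1.values()) / len(per_event_f1)
```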

### Why global tuning?

- **No test-set leakage**: The objective uses `dev_macro_f1`, not test F1.
- **Generalizable**: One set of hyperparameters that works across all events,
  rather than overfitting to a single event.
- **Efficient**: Optuna's TPE sampler concentrates trials in promising regions of the search space, and MedianPruner stops unpromising trials early.

### Search space

| Parameter | Range | Scale |
|-----------|-------|-------|
| `lr` | 1e-5 to 1e-3 | Log-uniform |
| `batch_size` | [8, 16, 32, 64] | Categorical |
| `cotrain_epochs` | 5 to 20 | Uniform integer |
| `finetune_patience` | 4 to 10 | Uniform integer |
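
To build intuition for the table, the same space can be sampled with the standard library. This is plain random search, not TPE, and the constant and function names are hypothetical. In Optuna the space would typically be declared with `trial.suggest_float("lr", 1e-5, 1e-3, log=True)`, `trial.suggest_categorical`, and `trial.suggest_int`. This notebook does not show how `optuna_tuner` actually declares it.

```python
import math
import random

# Hypothetical restatement of the search-space table (bounds copied from it)
LR_RANGE = (1e-5, 1e-3)
BATCH_SIZES = (8, 16, 32, 64)
COTRAIN_EPOCHS_RANGE = (5, 20)
PATIENCE_RANGE = (4, 10)


def sample_config(rng: random.Random) -> dict:
    """Draw one configuration at random (random search, not TPE)."""
    lo, hi = LR_RANGE
    return {
        # Log-uniform: sample the base-10 exponent uniformly, then exponentiate
        "lr": 10 ** rng.uniform(math.log10(lo), math.log10(hi)),
        "batch_size": rng.choice(BATCH_SIZES),
        # randint is inclusive on both ends, matching "5 to 20" and "4 to 10"
        "cotrain_epochs": rng.randint(*COTRAIN_EPOCHS_RANGE),
        "finetune_patience": rng.randint(*PATIENCE_RANGE),
    }


def in_search_space(cfg: dict) -> bool:
    """Check that a configuration lies inside the table's bounds."""
    return (
        LR_RANGE[0] <= cfg["lr"] <= LR_RANGE[1]
        and cfg["batch_size"] in BATCH_SIZES
        and COTRAIN_EPOCHS_RANGE[0] <= cfg["cotrain_epochs"] <= COTRAIN_EPOCHS_RANGE[1]
        and PATIENCE_RANGE[0] <= cfg["finetune_patience"] <= PATIENCE_RANGE[1]
    )
```

The paper's fixed defaults (lr=2e-5, batch_size=32, cotrain_epochs=10, patience=5) pass `in_search_space`. The study can therefore, in principle, land on them.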

### Paper deviation note

The original paper (Cornelia et al. 2025) uses **fixed** hyperparameters:
lr=2e-5, batch_size=32, cotrain_epochs=10, patience=5. This notebook explores
whether tuning these improves performance.

### Usage

1. Run cells 1-3 to configure and launch the Optuna study
2. After the study completes, results are saved as a JSON file
3. Cells 4-7 visualize and analyze the results
4. Cell 8 shows the CLI command to apply the best hyperparameters

In [2]:
import importlib
import sys
import time
from pathlib import Path


def _find_repo_root(marker: str = "lg_cotrain") -> Path:
    for candidate in [Path().resolve()] + list(Path().resolve().parents):
        if (candidate / marker).is_dir():
            return candidate
    raise RuntimeError(
        f"Cannot find repo root: no ancestor directory contains '{marker}/'. "
        "Run the notebook from inside the repository."
    )


repo_root = _find_repo_root()
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

import optuna

import lg_cotrain.optuna_tuner
importlib.reload(lg_cotrain.optuna_tuner)
from lg_cotrain.optuna_tuner import ALL_EVENTS, create_objective, run_study

print(f"Repo root: {repo_root}")
print(f"Optuna version: {optuna.__version__}")
print(f"Events ({len(ALL_EVENTS)}): {ALL_EVENTS}")

Repo root: D:\Workspace\Co-Training
Optuna version: 4.7.0
Events (10): ['california_wildfires_2018', 'canada_wildfires_2016', 'cyclone_idai_2019', 'hurricane_dorian_2019', 'hurricane_florence_2018', 'hurricane_harvey_2017', 'hurricane_irma_2017', 'hurricane_maria_2017', 'kaikoura_earthquake_2016', 'kerala_floods_2018']


## 1. Configuration

Set the number of trials, budget, and seed for the tuning study.

- **`N_TRIALS`**: More trials usually yield better hyperparameters, but each trial runs
  the full pipeline across every selected event (all 10 by default). Start with 10-20 for a quick scan;
  use 50+ for thorough tuning.
- **`TUNING_BUDGET`**: Budget level used during tuning (50 = most labeled data).
- **`TUNING_SEED`**: Seed set used during tuning.

In [3]:
# ---- Tuning Configuration ----

N_TRIALS      = 20        # Number of Optuna trials
TUNING_BUDGET = 50        # Budget level for tuning
TUNING_SEED   = 1         # Seed set for tuning
STUDY_NAME    = "lg_cotrain_global"  # Optuna study name

DATA_ROOT    = str(repo_root / "data")
RESULTS_ROOT = str(repo_root / "results" / "optuna")

# Optionally restrict to a subset of events for faster iteration
# Set to None to use all 10 events
EVENTS = None  # or e.g. ["hurricane_harvey_2017", "kerala_floods_2018"]

events_to_use = EVENTS or ALL_EVENTS
runs_per_trial = len(events_to_use)
total_runs = N_TRIALS * runs_per_trial

print(f"Study name   : {STUDY_NAME}")
print(f"Trials       : {N_TRIALS}")
print(f"Events/trial : {runs_per_trial}")
print(f"Total runs   : {total_runs} (upper bound, pruning may reduce this)")
print(f"Budget       : {TUNING_BUDGET}")
print(f"Seed set     : {TUNING_SEED}")
print(f"Results root : {RESULTS_ROOT}")
print()
print("Search space:")
print("  lr               : 1e-5 to 1e-3  (log-uniform)")
print("  batch_size       : [8, 16, 32, 64]")
print("  cotrain_epochs   : 5 to 20")
print("  finetune_patience: 4 to 10")

Study name   : lg_cotrain_global
Trials       : 20
Events/trial : 10
Total runs   : 200 (upper bound, pruning may reduce this)
Budget       : 50
Seed set     : 1
Results root : D:\Workspace\Co-Training\results\optuna

Search space:
  lr               : 1e-5 to 1e-3  (log-uniform)
  batch_size       : [8, 16, 32, 64]
  cotrain_epochs   : 5 to 20
  finetune_patience: 4 to 10
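
`events_to_use = EVENTS or ALL_EVENTS` falls back to all events when `EVENTS` is `None` and also when it is an empty list, because `or` treats both as false. This cell does not check event names, so a misspelled name may only surface once the study reaches it. A hypothetical helper that validates the subset before anything runs:

```python
def resolve_events(events, all_events):
    """Return the events to tune on, rejecting unknown names before any run starts."""
    if not events:  # None or [] -> every event, same as `EVENTS or ALL_EVENTS`
        return list(all_events)
    unknown = sorted(set(events) - set(all_events))
    if unknown:
        raise ValueError(f"Unknown event name(s): {unknown}")
    return list(events)
```

Used as `events_to_use = resolve_events(EVENTS, ALL_EVENTS)`, it raises immediately on a typo such as `"kerala_flood_2018"`.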


## 2. Run the Optuna Study

This cell launches the study. Each trial:

1. Optuna's TPE sampler picks a set of hyperparameters
2. The full 3-phase pipeline runs for each event
3. After each event, the running mean dev F1 is reported to the pruner
4. If the trial looks unpromising (its running mean is below the median of earlier trials at the same step), it is pruned early
5. The final objective = mean dev macro-F1 across all events
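
The per-trial steps above can be sketched without Optuna. This is a simplified, hypothetical illustration, not the code in `optuna_tuner`. `run_event` stands in for one run of the 3-phase pipeline, and `min_events=3` mirrors the summary's "prune after 3+ events". The sketch approximates the rule Optuna's `MedianPruner` applies: compare the trial's best intermediate value so far with the median of completed trials at the same step. It omits `MedianPruner` options such as `n_startup_trials` and `interval_steps`.

```python
import statistics
from typing import Callable


def run_trial_sketch(
    events: list,
    run_event: Callable[[str], float],
    completed_curves: list,
    min_events: int = 3,
) -> tuple:
    """Run events in order with median-style pruning; return (running means, pruned?).

    run_event:        hypothetical callable, event name -> dev macro-F1
    completed_curves: running-mean curves of earlier *completed* trials
    """
    scores = []
    curve = []  # running mean after each event (the values reported to the pruner)
    for step, event in enumerate(events):
        scores.append(run_event(event))
        curve.append(sum(scores) / len(scores))
        if len(scores) < min_events:
            continue  # too early to judge this trial
        peers = [c[step] for c in completed_curves if len(c) > step]
        # Prune if even the trial's best running mean so far trails the peers' median
        if peers and max(curve) < statistics.median(peers):
            return curve, True
    return curve, False  # not pruned: curve[-1] is the mean dev macro-F1 objective
```

A trial scoring 0.3 on every event is stopped after its third event when earlier completed trials averaged 0.6. A trial scoring 0.7 runs all events and returns 0.7 as its objective.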

**Progress tracking**: After each event completes, you'll see the current trial,
event, that event's dev F1 and the trial's running mean, elapsed time, and estimated time remaining.

**This will take a long time** (hours to days depending on GPU and N_TRIALS).
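
The ETA printed by the tracker below is an upper bound. It multiplies the average time per finished event run by the number of runs left, assuming no remaining trial gets pruned. Here the same arithmetic as `ProgressTracker.on_event_done` is pulled out into a standalone function (`eta_hours` is a hypothetical refactor) to work through a concrete case:

```python
def eta_hours(elapsed_s, runs_done, n_trials, trials_done, n_events, event_idx):
    """Upper-bound ETA in hours, using the same formula as ProgressTracker."""
    avg_per_run = elapsed_s / runs_done
    remaining_this_trial = n_events - (event_idx + 1)  # events left in the current trial
    remaining_future = (n_trials - trials_done - 1) * n_events  # later trials, unpruned
    return avg_per_run * (remaining_this_trial + remaining_future) / 3600
```

Suppose there are 20 trials of 10 events, and the first 3 runs took 1.5 h (30 min per run). That leaves up to 7 + 19 × 10 = 197 runs, so `eta_hours(5400, 3, 20, 0, 10, 2)` is 98.5 h. This is where "hours to days" comes from.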

In [None]:
import time


class ProgressTracker:
    """Track per-event progress across all Optuna trials."""

    def __init__(self, n_trials: int, n_events: int, start_time: float):
        self.n_trials = n_trials
        self.n_events = n_events
        self.total_runs = n_trials * n_events  # upper bound (pruning reduces this)
        self.runs_done = 0
        self.trials_done = 0
        self.start_time = start_time
        self._current_trial = -1

    def on_event_done(self, trial_number, event, event_idx, n_events, dev_f1, mean_f1):
        """Called after each event within a trial."""
        self.runs_done += 1

        # Detect trial transitions
        if trial_number != self._current_trial:
            if self._current_trial >= 0:
                self.trials_done += 1
            self._current_trial = trial_number

        elapsed = time.time() - self.start_time
        elapsed_h = elapsed / 3600
        # ETA based on average time per run so far
        if self.runs_done > 0:
            avg_per_run = elapsed / self.runs_done
            # Estimate remaining: remaining trials × events per trial
            remaining_this_trial = n_events - (event_idx + 1)
            remaining_future = (self.n_trials - self.trials_done - 1) * n_events
            remaining_runs = remaining_this_trial + remaining_future
            eta_h = avg_per_run * remaining_runs / 3600
        else:
            eta_h = 0

        print(
            f"  Trial {trial_number + 1}/{self.n_trials} | "
            f"Event {event_idx + 1}/{n_events} ({event}) | "
            f"dev_F1={dev_f1:.4f} (mean={mean_f1:.4f}) | "
            f"Elapsed: {elapsed_h:.2f}h | ETA: {eta_h:.2f}h"
        )


start_time = time.time()
tracker = ProgressTracker(N_TRIALS, len(events_to_use), start_time)

study = run_study(
    n_trials=N_TRIALS,
    events=events_to_use,
    budget=TUNING_BUDGET,
    seed_set=TUNING_SEED,
    data_root=DATA_ROOT,
    results_root=RESULTS_ROOT,
    study_name=STUDY_NAME,
    _on_event_done=tracker.on_event_done,
)

elapsed = time.time() - start_time
print(f"\nStudy completed in {elapsed / 3600:.2f}h ({elapsed / 60:.1f}min)")
print(f"Runs executed: {tracker.runs_done} (of {tracker.total_runs} max)")
print(f"Pruned trials saved ~{tracker.total_runs - tracker.runs_done} runs")

[I 2026-02-20 18:37:54,594] A new study created in memory with name: lg_cotrain_global
2026-02-20 18:37:59,381 - lg_cotrain - INFO - Starting LG-CoTrain: event=california_wildfires_2018, budget=50, seed_set=1
2026-02-20 18:37:59,410 - lg_cotrain - INFO - Detected 10 classes for event california_wildfires_2018: ['caution_and_advice', 'displaced_people_and_evacuations', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people', 'not_humanitarian', 'other_relevant_information', 'requests_or_urgent_needs', 'rescue_volunteering_or_donation_effort', 'sympathy_and_support']
2026-02-20 18:37:59,539 - lg_cotrain - INFO - D_l1: 250, D_l2: 250, D_LG: 4663
2026-02-20 18:37:59,545 - lg_cotrain - INFO - === Phase 1: Weight Generation ===
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1051.21it/s, Materializing param=bert.pooler.dense.weight]
Loading weights: 100%|███████████████| 199/199 [00:00<00:00, 1075.41it/s, Materializing param=bert.poo

## 3. Save Results to JSON

Export the study results as a human-readable JSON file for easy access.

In [None]:
import json
from pathlib import Path

# Build results dict
results = {
    "study_name": STUDY_NAME,
    "n_trials": len(study.trials),
    "best_trial": {
        "number": study.best_trial.number,
        "mean_dev_macro_f1": round(study.best_value, 6),
        "params": study.best_params,
    },
    "paper_defaults": {
        "lr": 2e-5,
        "batch_size": 32,
        "cotrain_epochs": 10,
        "finetune_patience": 5,
    },
    "search_space": {
        "lr": "1e-5 to 1e-3 (log-uniform)",
        "batch_size": [8, 16, 32, 64],
        "cotrain_epochs": "5 to 20",
        "finetune_patience": "4 to 10",
    },
    "trials": [],
}

for t in study.trials:
    trial_info = {
        "number": t.number,
        "state": t.state.name,
        "params": t.params,
    }
    if t.value is not None:
        trial_info["mean_dev_macro_f1"] = round(t.value, 6)
    if t.datetime_start and t.datetime_complete:
        trial_info["duration_seconds"] = round(
            (t.datetime_complete - t.datetime_start).total_seconds(), 1
        )
    results["trials"].append(trial_info)

# Save to results directory
output_dir = Path(RESULTS_ROOT)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "optuna_results.json"

with open(output_path, "w") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {output_path}")
print(f"\nBest trial #{results['best_trial']['number']}:")
print(f"  Mean dev macro-F1: {results['best_trial']['mean_dev_macro_f1']}")
for k, v in results['best_trial']['params'].items():
    print(f"  {k}: {v}")
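
The JSON written above can be read back later without re-running the study, for example to compare the best trial with the paper's fixed defaults. This is a minimal sketch. `load_results` and `best_vs_defaults` are hypothetical helpers that rely only on keys this cell writes (`best_trial.params` and `paper_defaults`).

```python
import json
from pathlib import Path


def load_results(path) -> dict:
    """Read the optuna_results.json file written by the cell above."""
    return json.loads(Path(path).read_text())


def best_vs_defaults(results: dict) -> dict:
    """Map each tuned parameter to (paper_default, best_found)."""
    best = results["best_trial"]["params"]
    return {name: (default, best.get(name)) for name, default in results["paper_defaults"].items()}
```

For example, `best_vs_defaults(load_results(output_path))["lr"]` returns the pair `(2e-05, <best lr>)`.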

## 4. Optimization History

Plot how the objective value evolved across trials. Completed trials are shown
as blue dots; pruned trials as red X marks. The dashed line tracks the
running best value.
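
The plotting cell below mixes data preparation with drawing. The preparation step can be separated out as shown here. This minimal sketch works on hypothetical plain-dict trial records whose keys mirror Optuna `FrozenTrial` attributes (`number`, `state`, `value`, `intermediate_values`). Unlike the cell, it drops pruned trials that never reported a value instead of plotting them at 0.

```python
from itertools import accumulate


def history_points(trials):
    """Return (completed xy, pruned xy, running-best xy) for the history plot."""
    done = sorted((t for t in trials if t["state"] == "COMPLETE"), key=lambda t: t["number"])
    pruned = [t for t in trials if t["state"] == "PRUNED" and t["intermediate_values"]]
    completed_xy = [(t["number"], t["value"]) for t in done]
    # Pruned trials sit at their last reported (highest-step) intermediate value
    pruned_xy = [
        (t["number"], t["intermediate_values"][max(t["intermediate_values"])]) for t in pruned
    ]
    # Running best = cumulative maximum of completed values in trial order
    best = accumulate((v for _, v in completed_xy), max)
    running_best = [(n, b) for (n, _), b in zip(completed_xy, best)]
    return completed_xy, pruned_xy, running_best
```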

In [None]:
import matplotlib.pyplot as plt
import numpy as np

completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
pruned    = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]

fig, ax = plt.subplots(figsize=(10, 5))

# Completed trials
if completed:
    ax.scatter(
        [t.number for t in completed],
        [t.value for t in completed],
        color="tab:blue", alpha=0.7, label=f"Completed ({len(completed)})",
        zorder=3,
    )

# Pruned trials (show at last reported value)
if pruned:
    pruned_vals = []
    for t in pruned:
        if t.intermediate_values:
            last_step = max(t.intermediate_values.keys())
            pruned_vals.append(t.intermediate_values[last_step])
        else:
            pruned_vals.append(0)
    ax.scatter(
        [t.number for t in pruned],
        pruned_vals,
        color="tab:red", marker="x", alpha=0.5, label=f"Pruned ({len(pruned)})",
        zorder=3,
    )

# Running best line
if completed:
    sorted_completed = sorted(completed, key=lambda t: t.number)
    running_best = []
    best_so_far = -1
    for t in sorted_completed:
        best_so_far = max(best_so_far, t.value)
        running_best.append(best_so_far)
    ax.plot(
        [t.number for t in sorted_completed],
        running_best,
        "--", color="tab:green", alpha=0.8, label="Running best",
    )

ax.set_xlabel("Trial number")
ax.set_ylabel("Mean dev macro-F1")
ax.set_title("Optimization History")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Total trials : {len(study.trials)}")
print(f"Completed    : {len(completed)}")
print(f"Pruned       : {len(pruned)}")

## 5. Parameter Distributions

Visualize how Optuna explored the search space. For each parameter, we show:
- A histogram of sampled values (completed trials only)
- The best trial's value highlighted in red
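
The `lr` panel uses a log10 axis because log-uniform sampling gives each decade of [1e-5, 1e-3] the same probability. The closed-form check below uses a hypothetical helper and describes the sampling prior, before TPE starts favoring promising regions. It shows how different this is from plain uniform sampling:

```python
import math


def fraction_below(threshold, lo, hi, log_scale):
    """Probability that a draw from [lo, hi] falls below `threshold`."""
    if log_scale:
        # Log-uniform: probability mass is proportional to the distance in log space
        return (math.log(threshold) - math.log(lo)) / (math.log(hi) - math.log(lo))
    return (threshold - lo) / (hi - lo)
```

Here `fraction_below(1e-4, 1e-5, 1e-3, log_scale=True)` is 0.5, while the uniform case gives 1/11 (about 0.09). A uniform prior would therefore spend roughly 91% of its draws above 1e-4.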

In [None]:
best = study.best_trial  # best completed trial, highlighted in each panel
params = ["lr", "batch_size", "cotrain_epochs", "finetune_patience"]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for ax, param in zip(axes.flat, params):
    values = [t.params[param] for t in completed]
    best_val = best.params[param]

    if param == "lr":
        # Log scale for learning rate
        log_values = [np.log10(v) for v in values]
        ax.hist(log_values, bins=15, alpha=0.7, color="tab:blue", edgecolor="white")
        ax.axvline(np.log10(best_val), color="tab:red", linestyle="--",
                   label=f"Best: {best_val:.2e}")
        ax.set_xlabel(f"{param} (log10)")
    elif param == "batch_size":
        # Categorical: bar chart
        from collections import Counter
        counts = Counter(values)
        categories = [8, 16, 32, 64]
        bar_counts = [counts.get(c, 0) for c in categories]
        bar_colors = ["tab:red" if c == best_val else "tab:blue" for c in categories]
        ax.bar([str(c) for c in categories], bar_counts, color=bar_colors, alpha=0.7,
               edgecolor="white")
        ax.set_xlabel(param)
    else:
        ax.hist(values, bins=range(min(values), max(values) + 2), alpha=0.7,
                color="tab:blue", edgecolor="white", align="left")
        ax.axvline(best_val, color="tab:red", linestyle="--",
                   label=f"Best: {best_val}")
        ax.set_xlabel(param)

    ax.set_ylabel("Count")
    ax.set_title(param)
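    # Only draw a legend where a labeled artist exists (the batch_size bar chart has none)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=9)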
    ax.grid(True, alpha=0.3, axis="y")

plt.suptitle("Parameter Distributions (Completed Trials)", fontsize=13)
plt.tight_layout()
plt.show()

## 6. Parameter vs Objective Scatter Plots

See how each parameter correlates with the objective value.

In [None]:
best = study.best_trial  # best completed trial, marked with a red dashed line

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for ax, param in zip(axes.flat, params):
    values = [t.params[param] for t in completed]
    objectives = [t.value for t in completed]

    if param == "lr":
        ax.scatter(values, objectives, alpha=0.6, color="tab:blue")
        ax.set_xscale("log")
        ax.axvline(best.params[param], color="tab:red", linestyle="--", alpha=0.5)
    elif param == "batch_size":
        # Add jitter for visibility
        jittered = [v + np.random.uniform(-1.5, 1.5) for v in values]
        ax.scatter(jittered, objectives, alpha=0.6, color="tab:blue")
        ax.set_xticks([8, 16, 32, 64])
    else:
        jittered = [v + np.random.uniform(-0.3, 0.3) for v in values]
        ax.scatter(jittered, objectives, alpha=0.6, color="tab:blue")
        ax.axvline(best.params[param], color="tab:red", linestyle="--", alpha=0.5)

    ax.set_xlabel(param)
    ax.set_ylabel("Mean dev macro-F1")
    ax.set_title(f"{param} vs Objective")
    ax.grid(True, alpha=0.3)

plt.suptitle("Parameter vs Objective (Completed Trials)", fontsize=13)
plt.tight_layout()
plt.show()

## 7. Trial Summary Table

Show all completed trials sorted by objective value (best first).

In [None]:
sorted_trials = sorted(completed, key=lambda t: t.value, reverse=True)

print(f"{'#':>4}  {'Mean Dev F1':>12}  {'lr':>10}  {'batch':>6}  {'co_ep':>6}  {'pat':>4}  {'Duration':>10}")
print("-" * 64)  # matches the 64-character header row

for t in sorted_trials[:20]:  # Show top 20
    duration = (t.datetime_complete - t.datetime_start).total_seconds() if t.datetime_complete else 0
    duration_str = f"{duration / 60:.1f}min" if duration < 3600 else f"{duration / 3600:.1f}h"
    print(
        f"{t.number:>4}  {t.value:>12.4f}  {t.params['lr']:>10.2e}  "
        f"{t.params['batch_size']:>6}  {t.params['cotrain_epochs']:>6}  "
        f"{t.params['finetune_patience']:>4}  {duration_str:>10}"
    )

if len(sorted_trials) > 20:
    print(f"  ... and {len(sorted_trials) - 20} more trials")

## 8. Apply Best Hyperparameters

Use the best hyperparameters found by Optuna to run the full experiment grid
(all budgets, all seed sets, all events).

**Copy the CLI command below** and run it in a terminal, or use the
`run_experiment.py` API directly.
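
Building the command as an argument list rather than one formatted string avoids shell quoting issues and keeps `lr` at full precision. This is a minimal sketch that assumes `run_experiment` accepts exactly the flags printed by the cell below; this notebook does not show its argument parser. `build_run_command` is a hypothetical helper.

```python
import shlex


def build_run_command(best_params: dict, events: list, output_folder=None) -> list:
    """Return an argv list for lg_cotrain.run_experiment (flags copied from the cell below)."""
    cmd = ["python", "-m", "lg_cotrain.run_experiment"]
    if len(events) == 1:
        cmd += ["--event", events[0]]
    else:
        cmd += ["--events", *events]
    cmd += [
        # repr() gives the shortest string that round-trips the float exactly
        "--lr", repr(float(best_params["lr"])),
        "--batch-size", str(best_params["batch_size"]),
        "--cotrain-epochs", str(best_params["cotrain_epochs"]),
        "--finetune-patience", str(best_params["finetune_patience"]),
    ]
    if output_folder is not None:
        cmd += ["--output-folder", output_folder]
    return cmd
```

`shlex.join(cmd)` gives a copy-pasteable string, and `subprocess.run(cmd, check=True)` would launch the command directly. Replacing `"python"` with `sys.executable` reuses the notebook's interpreter.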

In [None]:
bp = study.best_params
# Significant-digit formatting keeps small learning rates intact; fixed-point
# ":.6f" would print 2.3456e-05 as 0.000023.

print("Apply the best hyperparameters via CLI:\n")
print("# Single event:")
print(
    f"python -m lg_cotrain.run_experiment \\"
    f"\n    --event kaikoura_earthquake_2016 \\"
    f"\n    --lr {bp['lr']:.6g} \\"
    f"\n    --batch-size {bp['batch_size']} \\"
    f"\n    --cotrain-epochs {bp['cotrain_epochs']} \\"
    f"\n    --finetune-patience {bp['finetune_patience']}"
)

print("\n# All events (full sweep):")
events_str = " ".join(ALL_EVENTS)
print(
    f"python -m lg_cotrain.run_experiment \\"
    f"\n    --events {events_str} \\"
    f"\n    --lr {bp['lr']:.6g} \\"
    f"\n    --batch-size {bp['batch_size']} \\"
    f"\n    --cotrain-epochs {bp['cotrain_epochs']} \\"
    f"\n    --finetune-patience {bp['finetune_patience']} \\"
    f"\n    --output-folder results/gpt-4o/test/optuna-tuned"
)

## Summary

This notebook ran a **global Optuna hyperparameter study** to find optimal
`lr`, `batch_size`, `cotrain_epochs`, and `finetune_patience` for the
LG-CoTrain pipeline.

### Methodology
- **Objective**: Mean dev macro-F1 across all 10 disaster events (budget=50, seed=1)
- **Sampler**: TPE (Tree-structured Parzen Estimator)
- **Pruner**: MedianPruner (prune after 3+ events if below median)
- **No test-set leakage**: Only dev set metrics used for optimization
- **Output**: Results saved as JSON to `results/optuna/optuna_results.json`

### CLI equivalent
```bash
python -m lg_cotrain.optuna_tuner --n-trials 20
```