# Step-size Methods vs Pulses & Runtime

Compares the `const`, `linesearch` (Armijo backtracking), and `adam` step-size methods under an identical baseline, basis, and stopping criteria.
The objective is fixed to `terminal`. Artifacts are saved under `artifacts/<run_name>/<method>/...`.
Figures will be added later.

In [None]:
# Imports
from pathlib import Path
import sys
import os

# Locate this file's directory when run as a script (``__file__`` defined);
# inside Jupyter fall back to the current working directory.
if '__file__' in globals():
    notebook_dir = Path(__file__).resolve().parent
else:
    notebook_dir = Path.cwd()
# When launched from ``<repo>/notebooks`` step up one level to the repo root.
repo_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir
# Make ``src`` importable and run from the repo root so relative paths such as
# ``./artifacts`` resolve against it.
if not any(entry == str(repo_root) for entry in map(str, sys.path)):
    sys.path.insert(0, str(repo_root))
os.chdir(repo_root)

from src import TimeGridSpec, PulseShapeSpec, BasisSpec


In [None]:
# Experiment configuration inputs
# --- Runtime budget -----------------------------------------------------------
max_time_min: float = 5.0  # wall-clock budget per optimizer run, in minutes

# --- Baseline pulse shapes and basis sizes -----------------------------------
omega_shape = dict(kind='polynomial', area_pi=1.0)
delta_shape = dict(kind='linear_chirp', area_pi=0.0, amplitude_scale=40.0)
K_omega: int = 6
K_delta: int = 0  # the run cell sets optimize_delta = (K_delta > 0)

# --- Stopping criteria shared by every method --------------------------------
max_iters: int = 200
grad_tol: float = 1e-4
rtol: float = 1e-5  # relative tolerance; exact meaning is up to the optimizer

# --- Step-size method parameters ---------------------------------------------
# const
const_learning_rate: float = 0.05
# linesearch (Armijo backtracking, per the notebook header)
alpha0: float = 0.1
ls_beta: float = 0.5
ls_sigma: float = 1e-4
ls_max_backtracks: int = 12
# adam
adam_learning_rate: float = 0.05
beta1: float = 0.9
beta2: float = 0.999
epsilon: float = 1e-8

# --- Penalty terms (weights of 0.0 presumably disable them; confirm) --------
power_weight: float = 0.0
neg_weight: float = 0.0
neg_kappa: float = 10.0

# --- Output ------------------------------------------------------------------
artifact_root = Path('./artifacts')
run_name: str = 'stepsize-comparison'
objective: str = 'terminal'


## Runner utilities

In [None]:

# Runner utilities
import csv
from pathlib import Path
from typing import Any, Dict

import numpy as np
from tqdm.auto import tqdm

from src import override_from_dict, run_experiment
from src.notebook_runners import (
    BaselineArrays,
    build_base_config,
    method_options,
    prepare_baseline,
)
from src.crab_notebook_utils import population_excited
from src.physics import propagate_piecewise_const

# Step-size methods compared by this notebook, in run and report order.
RUN_METHODS: tuple[str, ...] = ("const", "linesearch", "adam")


def extract_history_series(history: Dict[str, np.ndarray]) -> Dict[str, np.ndarray | None]:
    def _pull(key: str) -> np.ndarray | None:
        series = history.get(key)
        if series is None or len(series) == 0:
            return None
        return np.asarray(series, dtype=np.float64)

    return {
        "cost_total": _pull("total"),
        "cost_terminal": _pull("terminal"),
        "cost_power": _pull("power_penalty"),
        "cost_neg": _pull("neg_penalty"),
        "grad_norm": _pull("grad_norm"),
        "step_norm": _pull("step_norm"),
    }


def sum_oracle_calls(history: Dict[str, np.ndarray]) -> int:
    """Return the sum of the ``calls_per_iter`` history entry as a Python int.

    Yields 0 when that entry is missing or empty.
    """
    per_iter = history.get("calls_per_iter")
    if per_iter is not None and len(per_iter) > 0:
        return int(np.asarray(per_iter, dtype=np.int64).sum())
    return 0


def compute_pulse_metrics(omega: np.ndarray, t_us: np.ndarray) -> Dict[str, float]:
    """Summarize a pulse amplitude sampled on the time grid ``t_us``.

    Args:
        omega: Pulse samples; passed to ``trapezoid`` together with ``t_us``,
            so the two must have compatible lengths.
        t_us: Sample times used as the integration abscissa.

    Returns:
        ``max_abs_omega``: peak ``|omega|``.
        ``area_omega_over_pi``: trapezoidal integral of ``|omega|`` over
        ``t_us``, divided by pi.
        ``negativity_fraction``: fraction of samples with ``omega < 0``.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back only on older NumPy releases.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    magnitude = np.abs(omega)
    return {
        "max_abs_omega": float(np.max(magnitude)),
        "area_omega_over_pi": float(trapezoid(magnitude, t_us) / np.pi),
        "negativity_fraction": float(np.mean(omega < 0.0)),
    }


def save_trajectories(
    run_dir: Path,
    t_us: np.ndarray,
    psi_path: np.ndarray,
    rho_path: np.ndarray,
    pop_excited: np.ndarray,
) -> None:
    """Write the simulated trajectories to ``<run_dir>/trajectories.npz``.

    The archive is compressed and stores the arrays under the names
    ``t_us``, ``psi_path``, ``rho_path`` and ``pop_excited``.
    """
    bundle = {
        "t_us": t_us,
        "psi_path": psi_path,
        "rho_path": rho_path,
        "pop_excited": pop_excited,
    }
    np.savez_compressed(Path(run_dir).joinpath("trajectories.npz"), **bundle)


def _last_or_none(series: np.ndarray | None) -> float | None:
    """Return the last element of ``series`` as a float, or None if missing/empty."""
    if series is None or series.size == 0:
        return None
    return float(series[-1])


def _float_array_or_zeros(values: Any, like: np.ndarray) -> np.ndarray:
    """Convert ``values`` to float64; substitute zeros shaped like ``like`` for None."""
    if values is None:
        return np.zeros_like(like)
    return np.asarray(values, dtype=np.float64)


def build_method_payload(result, ctx: BaselineArrays) -> Dict[str, Any]:
    """Turn one ``run_experiment`` result into the per-method payload used by the summary.

    Side effect: re-propagates the final pulses with ``propagate_piecewise_const``
    (from ``ctx.psi0`` with step ``ctx.dt_us``) and writes the trajectories to
    ``<result.artifacts_dir>/trajectories.npz`` via ``save_trajectories``.

    Args:
        result: Object returned by ``run_experiment``; its ``history``,
            ``final_metrics``, ``pulses``, ``optimizer_state`` and
            ``artifacts_dir`` attributes are read.
        ctx: Baseline context from ``prepare_baseline``; supplies fallback
            initial pulses (``arrays["Omega0"]`` / ``arrays["Delta0"]``), the
            fallback time grid ``t_us``, plus ``dt_us`` and ``psi0``.

    Returns:
        Dict with history series, oracle-call count, runtime, initial/final
        pulses, time grid, trajectories, final metrics, and a status that is
        ``"error"`` when the optimizer status contains ``"fail"`` (case-
        insensitive) and ``"ok"`` otherwise.
    """
    history_map = extract_history_series(result.history)
    oracle_calls = sum_oracle_calls(result.history)
    metrics = result.final_metrics
    runtime_s = float(metrics.get("runtime_s", np.nan))

    pulses = result.pulses
    omega_final = np.asarray(pulses["omega"], dtype=np.float64)
    # A missing Delta pulse is represented as an all-zero detuning.
    delta_final = _float_array_or_zeros(pulses.get("delta"), omega_final)
    # Use the baseline arrays only when the result lacks the key. The explicit
    # membership test keeps the fallback lazy: ``dict.get(key, default)`` would
    # evaluate ``ctx.arrays["Omega0"]`` (and possibly raise KeyError) even when
    # ``omega_base`` is present.
    omega_init = np.asarray(
        pulses["omega_base"] if "omega_base" in pulses else ctx.arrays["Omega0"],
        dtype=np.float64,
    )
    delta_init = _float_array_or_zeros(
        pulses["delta_base"] if "delta_base" in pulses else ctx.arrays.get("Delta0"),
        omega_final,
    )
    t_us = np.asarray(pulses.get("t_us", ctx.t_us), dtype=np.float64)

    # Re-simulate the optimized pulses to obtain full state trajectories.
    traj = propagate_piecewise_const(
        omega_final,
        delta_final,
        float(ctx.dt_us),
        psi0=ctx.psi0,
    )
    rho_path = np.asarray(traj["rho_path"])
    psi_path = np.asarray(traj["psi_path"])
    pop_excited = population_excited(rho_path)
    save_trajectories(result.artifacts_dir, t_us, psi_path, rho_path, pop_excited)

    history_total = history_map["cost_total"]
    iterations = history_total.size if history_total is not None else 0
    grad_final = _last_or_none(history_map["grad_norm"])
    step_final = _last_or_none(history_map["step_norm"])
    pulse_metrics = compute_pulse_metrics(omega_final, t_us)
    raw_status = str(result.optimizer_state.get("status", "completed")).lower()
    status = "error" if "fail" in raw_status else "ok"
    return {
        "history": history_map,
        "oracle_calls": oracle_calls,
        "runtime_s": runtime_s,
        "final_pulses": {"omega": omega_final, "delta": delta_final},
        "initial_pulses": {"omega": omega_init, "delta": delta_init},
        "time_grid_us": t_us,
        "trajectories": {
            "psi_path": psi_path,
            "rho_path": rho_path,
            "pop_excited": pop_excited,
        },
        "metrics": {
            "total_final": float(metrics.get("total", np.nan)),
            "terminal_final": float(metrics.get("terminal", np.nan)),
            "power_final": float(metrics.get("power_penalty", 0.0)),
            "neg_final": float(metrics.get("neg_penalty", 0.0)),
            "iterations": iterations,
            "grad_norm_final": grad_final,
            "step_norm_final": step_final,
            **pulse_metrics,
        },
        "status": status,
        "status_detail": raw_status,
        "artifacts_dir": Path(result.artifacts_dir),
    }


def error_payload(message: str) -> Dict[str, Any]:
    """Build a placeholder payload for a method whose run raised an exception.

    It uses the same keys as ``build_method_payload`` so the summary cell can
    handle failed runs uniformly. Float metrics are NaN, counts are 0, and
    every array-valued field is None. ``message`` is stored as ``status_detail``.
    """
    missing = float("nan")
    history = dict.fromkeys(
        ("cost_total", "cost_terminal", "cost_power", "cost_neg", "grad_norm", "step_norm")
    )
    metrics = {
        "total_final": missing,
        "terminal_final": missing,
        "power_final": missing,
        "neg_final": missing,
        "iterations": 0,
        "grad_norm_final": None,
        "step_norm_final": None,
        "max_abs_omega": missing,
        "area_omega_over_pi": missing,
        "negativity_fraction": missing,
    }
    payload: Dict[str, Any] = {}
    payload["history"] = history
    payload["oracle_calls"] = 0
    payload["runtime_s"] = missing
    payload["final_pulses"] = dict.fromkeys(("omega", "delta"))
    payload["initial_pulses"] = dict.fromkeys(("omega", "delta"))
    payload["time_grid_us"] = None
    payload["trajectories"] = dict.fromkeys(("psi_path", "rho_path", "pop_excited"))
    payload["metrics"] = metrics
    payload["status"] = "error"
    payload["status_detail"] = message
    payload["artifacts_dir"] = None
    return payload


## Run optimizers

In [None]:

# Run all (const, linesearch, adam) - with progress
# Optional notebook-level overrides: each ``globals().get(...)`` returns None
# unless a variable with that name was defined earlier in the session (none of
# them is defined in the cells shown here).
# NOTE(review): presumably prepare_baseline substitutes defaults for None
# arguments; confirm in src.notebook_runners.
time_grid_cfg = globals().get("time_grid_params")
# Baseline context shared by all methods. Later code reads runner_ctx.config,
# .arrays, .t_us, .dt_us and .psi0.
runner_ctx = prepare_baseline(
    time_grid=time_grid_cfg,
    omega_shape=omega_shape,
    delta_shape=delta_shape,
    K_omega=K_omega,
    K_delta=K_delta,
    rho0=globals().get("rho0"),
    target=globals().get("target"),
    initial_omega=globals().get("initial_omega"),
    initial_delta=globals().get("initial_delta"),
)

# Penalty settings handed to build_base_config, coerced to plain floats.
penalties = {
    name: float(value)
    for name, value in (
        ("power_weight", power_weight),
        ("neg_weight", neg_weight),
        ("neg_kappa", neg_kappa),
    )
}
# Shared configuration and optimizer options for every method. The returned
# pair is unpacked into base_config and base_opts. In the run loop, base_opts is
# combined with each method's overrides via method_options (presumably the
# overrides take precedence; confirm in src.notebook_runners).
base_config, base_opts = build_base_config(
    runner_ctx.config,
    run_name=run_name,
    artifact_root=artifact_root,
    penalties=penalties,
    objective=objective,
    base_optimizer_options={
        "max_iters": int(max_iters),
        "grad_tol": float(grad_tol),
        "rtol": float(rtol),
        "max_time_s": float(max_time_min) * 60.0,  # minutes -> seconds
        "optimize_delta": bool(K_delta > 0),  # optimize Delta only when K_delta > 0
    },
)

# Per-method optimizer settings, layered over base_opts in the run loop.
method_overrides = dict(
    const=dict(learning_rate=float(const_learning_rate)),
    linesearch=dict(
        alpha0=float(alpha0),
        ls_beta=float(ls_beta),
        ls_sigma=float(ls_sigma),
        ls_max_backtracks=int(ls_max_backtracks),
    ),
    adam=dict(
        learning_rate=float(adam_learning_rate),
        beta1=float(beta1),
        beta2=float(beta2),
        epsilon=float(epsilon),
    ),
)

import traceback

# method name -> payload (from build_method_payload or error_payload). The
# entry stays None if the run is interrupted before either payload is built.
results: Dict[str, Any] = {}

for method in RUN_METHODS:
    # One progress bar per method, initially sized to the iteration cap.
    bar = tqdm(
        total=int(base_opts.get("max_iters", 0)),
        desc=f"{method:>10}",
        leave=False,
    )
    method_payload: Dict[str, Any] | None = None

    # ``_bar=bar`` binds this iteration's bar at definition time, avoiding the
    # late-binding closure pitfall. The bar grows if the optimizer reports more
    # iterations than its current total.
    def progress_cb(stats, state, _bar=bar):
        if _bar.total < stats.iteration:
            _bar.total = stats.iteration
        _bar.n = stats.iteration
        _bar.set_postfix(cost=f"{stats.total:.3e}")
        _bar.refresh()

    try:
        overrides = method_overrides.get(method, {})
        opts = method_options(method, base_opts, overrides)
        config = override_from_dict(base_config, {"optimizer_options": opts})
        result = run_experiment(
            config,
            method=method,
            run_name=f"{run_name}-{method}",
            exist_ok=True,
            progress_callback=progress_cb,
        )
        method_payload = build_method_payload(result, runner_ctx)
    except Exception as exc:
        # Top-level boundary: record the failure and continue with the other
        # methods. Print the full traceback, because the one-line message alone
        # hides where the error happened. Include the exception type in the
        # detail, because str(exc) can be empty (e.g. a bare AssertionError).
        print(f"[{method}] error: {exc}")
        traceback.print_exc()
        method_payload = error_payload(f"{type(exc).__name__}: {exc}")
    finally:
        # Fill the bar so it closes in a finished state, then store the payload.
        if bar.total and bar.n < bar.total:
            bar.n = bar.total
        bar.close()
        results[method] = method_payload


## Results summary

In [None]:
# Results summary (no plots)
# ``results`` is created by the "Run optimizers" cell. Looking it up through
# globals() turns "that cell has not run yet" into the intended message rather
# than a NameError.
if not globals().get("results"):
    raise RuntimeError("Run the optimizer cell first.")

# Column order shared by the CSV written below; the printed table follows it.
summary_fields = [
    "method",
    "total_final",
    "terminal_final",
    "power_final",
    "neg_final",
    "iterations",
    "runtime_s",
    "oracle_calls",
    "max_abs_omega",
    "area_omega_over_pi",
    "negativity_fraction",
]

header = (
    f"{'method':>10}  {'total':>12}  {'terminal':>12}  {'power':>10}  "
    f"{'neg':>10}  {'iters':>8}  {'runtime_s':>10}  {'oracle':>8}  "
    f"{'max|Omega|':>10}  {'area/pi':>10}  {'neg_frac':>10}"
)
print(header)
rows = []
for method in RUN_METHODS:
    data = results.get(method)
    # Skip methods with no payload (missing, or None after an interrupted run).
    if data is None:
        continue
    metrics = data["metrics"]
    row = {
        "method": method,
        "total_final": float(metrics["total_final"]),
        "terminal_final": float(metrics["terminal_final"]),
        "power_final": float(metrics["power_final"]),
        "neg_final": float(metrics["neg_final"]),
        "iterations": int(metrics["iterations"]),
        "runtime_s": float(data["runtime_s"]),
        "oracle_calls": int(data["oracle_calls"]),
        "max_abs_omega": float(metrics["max_abs_omega"]),
        "area_omega_over_pi": float(metrics["area_omega_over_pi"]),
        "negativity_fraction": float(metrics["negativity_fraction"]),
    }
    rows.append(row)
    line = (
        f"{method:>10}  {row['total_final']:12.5e}  {row['terminal_final']:12.5e}  "
        f"{row['power_final']:10.3e}  {row['neg_final']:10.3e}  {row['iterations']:8d}  "
        f"{row['runtime_s']:10.3f}  {row['oracle_calls']:8d}  {row['max_abs_omega']:10.3f}  "
        f"{row['area_omega_over_pi']:10.3f}  {row['negativity_fraction']:10.3f}"
    )
    print(line)

summary_dir = (Path(artifact_root) / run_name).resolve()
summary_dir.mkdir(parents=True, exist_ok=True)
csv_path = summary_dir / "summary.csv"
with csv_path.open("w", newline="", encoding="utf-8") as fh:
    writer = csv.DictWriter(fh, fieldnames=summary_fields)
    writer.writeheader()
    writer.writerows(rows)

print(f"Summary written to {csv_path}")
