In [None]:
from __future__ import annotations

import json
import math
import sys
from pathlib import Path

import numpy as np
from qspectro2d.config.create_sim_obj import load_simulation, load_simulation_config
from qspectro2d.core.simulation.time_axes import compute_times_local, compute_t_coh

# Ensure package imports work when running in notebook
ROOT = Path.cwd().resolve()
if ROOT.name == "notebooks":
    ROOT = ROOT.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

SWEEPS_ROOT = ROOT / "data" / "jobs" / "sweeps"

summary_candidates = sorted(
    SWEEPS_ROOT.glob("*/summary.json"),
    key=lambda path: path.stat().st_mtime,
    reverse=True,
 )
SUMMARY_PATH = summary_candidates[0] if summary_candidates else None
if SUMMARY_PATH is None:
    print(f"No sweep summary.json found under {SWEEPS_ROOT.resolve()}.")
else:
    print(f"Using sweep summary: {SUMMARY_PATH}")

def time_str_to_seconds(time_str: str) -> int:
    h, m, s = (int(part) for part in time_str.split(":"))
    return h * 3600 + m * 60 + s

def estimate_slurm_resources(
    n_times: int,  # number of time steps in the local grid
    n_inhom: int,
    n_t_coh: int,  # number of coherence times -> how many combinations
    n_batches: int,
    *,
    workers: int = 1,
    N_dim: int,
    solver: str = "lindblad",
    mem_safety: float = 100.0,
    base_mb: float = 500.0,
    time_safety: float = 5.0,
    base_time: float = 0.0,
    rwa_sl: bool = True,
    summary_path: Path | None = None,
 ) -> tuple[str, str]:
    """
    Estimate SLURM memory and runtime for QuTiP evolutions.

    Scaling model (per batch):
        time ~ base_t * solver_factor * rwa_factor
               * (n_times / 1000) * (N_dim**2)
               * (combos_per_batch / workers)
               * time_safety
    """
    # ---------------------- MEMORY ----------------------
    bytes_per_solver = n_times * (N_dim) * 16  # only store the expectation values
    total_bytes = mem_safety * workers * bytes_per_solver
    mem_mb = base_mb + total_bytes / (1024**2)
    requested_mem = f"{int(math.ceil(mem_mb))}M"

    # ---------------------- TIME ------------------------
    # Number of total independent simulations
    combos_total = n_inhom * n_t_coh
    batches = max(1, n_batches)
    combos_per_batch = max(1, int(math.ceil(combos_total / batches)))

    # Empirical baseline: base_t s per combo for lindblad+RWA, 1 atom, n_times=1000, N=2
    base_t = 2.3
    solver_factor = {
        "paper_eqs": 1.0,
        "lindblad": 1.0,
        "redfield": 1.06,
    }
    # Conservative no-RWA factor (max observed across solvers)
    rwa_factor = 3.0 if not rwa_sl else 1.0

    # Optional calibration using sweep summary
    if summary_path and summary_path.exists():
        try:
            summary_data = json.loads(summary_path.read_text(encoding="utf-8"))
            by_label = {entry["label"]: entry for entry in summary_data}
            baseline_entry = by_label.get("baseline")
            if baseline_entry:
                baseline_cfg = load_simulation_config(baseline_entry["config_path"])
                baseline_times = np.asarray(compute_times_local(baseline_cfg), dtype=float)
                baseline_runtime = float(baseline_entry["runtime_s"])

                # Calibrate solver factors from sweep ratios
                solver_factor = {
                    "redfield": 1.0,
                    "lindblad": solver_factor.get("lindblad", 1.0),
                    "paper_eqs": solver_factor.get("paper_eqs", 1.0),
                }
                lindblad_entry = by_label.get("config_solver=lindblad")
                if lindblad_entry:
                    solver_factor["lindblad"] = (
                        float(lindblad_entry["runtime_s"]) / baseline_runtime
                    )
                paper_entry = by_label.get("config_solver=paper_eqs")
                if paper_entry:
                    solver_factor["paper_eqs"] = (
                        float(paper_entry["runtime_s"]) / baseline_runtime
                    )

                # Calibrate RWA factor if present
                no_rwa_entry = by_label.get("laser_rwa_sl=False")
                if no_rwa_entry:
                    rwa_factor = float(no_rwa_entry["runtime_s"]) / baseline_runtime

                # Calibrate base_t from baseline runtime (assume combos=1, safety=1)
                n_times_base = max(1, len(baseline_times))
                base_t = (
                    baseline_runtime
                    / ((n_times_base / 1000) * (N_dim**2) * solver_factor["lindblad"])
                )
                print(f"Calibrated base_t={base_t:.4g} s (summary)")
        except Exception:
            pass

    if solver not in solver_factor:
        raise ValueError(f"Unsupported solver '{solver}'.")
    if rwa_sl:
        rwa_factor = 1.0

    # scaling ~ n_times * N^2  (sparse regime)
    time_per_combo = (
        base_t
        * solver_factor[solver]
        * rwa_factor
        * (n_times / 1000)
        * (N_dim**2)
    )

    # total time for one batch (divide by workers)
    total_seconds = time_per_combo * combos_per_batch * time_safety / max(1, workers)

    # Ensure minimum time of 1 minute to avoid SLURM rejection
    total_seconds = max(total_seconds, base_time)

    # convert to HH:MM:SS, clip to max 24h if needed
    h = int(total_seconds // 3600)
    m = int((total_seconds % 3600) // 60)
    s = int(total_seconds % 60)
    # Cap at 3 days (72 hours) to fit GPGPU partition limit
    if h >= 72:
        h, m, s = 72, 0, 0
    requested_time = f"{h:02d}:{m:02d}:{s:02d}"

    return requested_mem, requested_time


Using sweep summary: C:\Users\leopo\.vscode\thesis_python\data\jobs\sweeps\_monomer_1d_20260126_182009\summary.json


In [2]:
# Typical parameter scenarios without using a sweep summary
scenarios = [
    {
        "label": "small_1d",
        "n_times": 800,
        "n_inhom": 1,
        "n_t_coh": 10,
        "workers": 1,
        "N_dim": 2,
        "solver": "redfield",
        "rwa_sl": True,
    },
    {
        "label": "medium_1d",
        "n_times": 1500,
        "n_inhom": 5,
        "n_t_coh": 20,
        "workers": 2,
        "N_dim": 2,
        "solver": "redfield",
        "rwa_sl": True,
    },
    {
        "label": "large_1d",
        "n_times": 2500,
        "n_inhom": 10,
        "n_t_coh": 40,
        "workers": 4,
        "N_dim": 2,
        "solver": "redfield",
        "rwa_sl": True,
    },
    {
        "label": "medium_2d",
        "n_times": 1500,
        "n_inhom": 5,
        "n_t_coh": 50,
        "workers": 4,
        "N_dim": 4,
        "solver": "redfield",
        "rwa_sl": True,
    },
]

print("Estimated runtimes (using sweep summary if present):")
for sc in scenarios:
    _, est_time = estimate_slurm_resources(
        n_times=sc["n_times"],
        n_inhom=sc["n_inhom"],
        n_t_coh=sc["n_t_coh"],
        n_batches=1,
        workers=sc["workers"],
        N_dim=sc["N_dim"],
        solver=sc["solver"],
        rwa_sl=sc["rwa_sl"],
        summary_path=SUMMARY_PATH,
    )
    print(
        f"  {sc['label']}: n_times={sc['n_times']}, "
        f"n_inhom={sc['n_inhom']}, n_t_coh={sc['n_t_coh']}, "
        f"N_dim={sc['N_dim']}, workers={sc['workers']} -> {est_time}"
    )

Estimated runtimes (using sweep summary if present):
Calibrated base_t=21.57 s (summary)
  small_1d: n_times=800, n_inhom=1, n_t_coh=10, N_dim=2, workers=1 -> 00:57:31
Calibrated base_t=21.57 s (summary)
  medium_1d: n_times=1500, n_inhom=5, n_t_coh=20, N_dim=2, workers=2 -> 08:59:14
Calibrated base_t=21.57 s (summary)
  large_1d: n_times=2500, n_inhom=10, n_t_coh=40, N_dim=2, workers=4 -> 29:57:27
Calibrated base_t=5.392 s (summary)
  medium_2d: n_times=1500, n_inhom=5, n_t_coh=50, N_dim=4, workers=4 -> 11:14:02


In [9]:
# Estimate runtime for current _monomer.yaml settings (no sweep summary)
config_path = ROOT / "scripts" / "simulation_configs" / "_monomer.yaml"

sim = load_simulation(str(config_path), run_validation=False)

n_times = len(compute_times_local(sim.simulation_config))
n_t_coh = len(compute_t_coh(sim.simulation_config))
n_inhom = sim.simulation_config.n_inhomogen
workers = max(1, sim.simulation_config.max_workers)
N_dim = sim.system.dimension
solver = sim.simulation_config.ode_solver
rwa_sl = sim.simulation_config.rwa_sl

_, est_time = estimate_slurm_resources(
    n_times=n_times,
    n_inhom=n_inhom,
    n_t_coh=n_t_coh,
    n_batches=1,
    workers=workers,
    N_dim=N_dim,
    solver=solver,
    rwa_sl=rwa_sl,
 )

print(
    "Current _monomer.yaml estimate (using sweep summary if present):",
    f"n_times={n_times}, n_t_coh={n_t_coh}, n_inhom={n_inhom}, ",
    f"solver={solver}, rwa_sl={rwa_sl} -> {est_time}",
)

Current _monomer.yaml estimate (using sweep summary if present): n_times=13826, n_t_coh=1, n_inhom=1,  solver=redfield, rwa_sl=True -> 00:00:54


In [4]:
# Compare estimated vs actual runtimes using sweep summary (if available)
if SUMMARY_PATH is None:
    print("No sweep summary found; cannot compare estimates.")
else:
    summary = json.loads(SUMMARY_PATH.read_text(encoding="utf-8"))
    results = []
    for entry in summary:
        sim = load_simulation(entry["config_path"], run_validation=False)
        n_times = len(compute_times_local(sim.simulation_config))
        n_t_coh = len(compute_t_coh(sim.simulation_config))
        n_inhom = sim.simulation_config.n_inhomogen
        workers = max(1, sim.simulation_config.max_workers)
        N_dim = sim.system.dimension
        solver = sim.simulation_config.ode_solver
        rwa_sl = sim.simulation_config.rwa_sl

        _, est_time = estimate_slurm_resources(
            n_times=n_times,
            n_inhom=n_inhom,
            n_t_coh=n_t_coh,
            n_batches=1,
            workers=workers,
            N_dim=N_dim,
            solver=solver,
            rwa_sl=rwa_sl,
            summary_path=SUMMARY_PATH,
        )
        est_seconds = time_str_to_seconds(est_time)
        actual = float(entry["runtime_s"])
        ratio = est_seconds / actual if actual > 0 else float("inf")
        results.append({
            "label": entry["label"],
            "actual_s": actual,
            "est_s": est_seconds,
            "ratio": ratio,
        })

    failures = [row for row in results if row["ratio"] < 2.0]
    print(f"Total cases: {len(results)}")
    print(f"Failures (<2x): {len(failures)}")
    print("\nTop 5 lowest ratios:")
    for row in sorted(results, key=lambda r: r["ratio"])[:5]:
        print(
            f"  {row['label']}: actual={row['actual_s']:.3f}s, "
            f"est={row['est_s']}s, ratio={row['ratio']:.2f}"
        )
    print("\nAll failures:")
    for row in failures:
        print(
            f"  {row['label']}: actual={row['actual_s']:.3f}s, "
            f"est={row['est_s']}s, ratio={row['ratio']:.2f}"
        )

Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Calibrated base_t=21.57 s (summary)
Total cases: 11
Failures (<2x): 10

Top 5 lowest ratios:
  config_solver=redfield: actual=14.924s, est=4s, ratio=0.27
  config_t_det_max=5.0: actual=11.171s, est=3s, ratio=0.27
  config_t_wait=20.0: actual=11.076s, est=3s, ratio=0.27
  config_t_coh=10.0: actual=11.763s, est=4s, ratio=0.34
  bath_bath_type=drudelorentz: actual=11.737s, est=4s, ratio=0.34

All failures:
  baseline: actual=11.475s, est=4s, ratio=0.35
  config_solver=redfield: actual=14.924s, est=4s, ratio=0.27
  config_solver=paper_eqs: actual=10.373s, est=4s, ratio=0.39
  bath_bath_type=drudelorentz: actual=11.737s, est=4s, ratio=0.