In [None]:
import submitit
import inspect
import os
import pandas as pd

from frust.stepper import Stepper
import frust.vis as vis
from pathlib import Path
from tooltoad.chemutils import xyz2mol

from rdkit import Chem

# Live

In [3]:
# scripts/submit_multistage_ts_rpos.py
import os
import re
import importlib
import inspect
import pandas as pd
import submitit

# ─── CONFIG ─────────────────────────────────────────────────────────────
# Pipeline module whose `run_*` callables are submitted as stages.
MODULE_PATH   = "frust.pipelines.run_ts_per_rpos"  # your run_* module
# CSV that must provide a "smiles" column, and the TS template structure.
CSV_PATH      = "../datasets/dft_finalists.csv"
TS_XYZ        = "../structures/ts4_TMP.xyz"

# One sub-directory per structure tag is created below ROOT_SAVE_DIR.
ROOT_SAVE_DIR = "dft_finalists_ts4"
# Forwarded to the stages as `work_dir`; NOTE(review): the meaning of None is
# decided by the pipeline functions — confirm there.
WORK_DIR      = None
LOG_DIR       = f"logs/{ROOT_SAVE_DIR}"

PARTITION     = "kemi1"
USE_SLURM     = True
DEBUG         = False
PRODUCTION    = True
# Non-production runs pass n_confs=1 to run_init; production passes None.
N_CONFS       = None if PRODUCTION else 1

# Per-stage resources, keyed by stage function name.
#   cpus -> cpus_per_task, mem -> mem_gb, time -> timeout_min
RES = {
    "run_init":    {"cpus": 16, "mem": 20, "time": 7200},
    "run_hess":    {"cpus": 8,  "mem": 64, "time": 7200},
    "run_OptTS":   {"cpus": 16, "mem": 20, "time": 7200},
    "run_freq":    {"cpus": 8,  "mem": 64, "time": 7200},
    "run_solv":    {"cpus": 16, "mem": 20, "time": 3600},
    "run_cleanup": {"cpus": 2,  "mem": 2,  "time": 60},
}


# ─── HELPERS ────────────────────────────────────────────────────────────
def _sanitize(s: str) -> str:
    """Replace every run of characters outside ``A-Za-z0-9_.-`` with one ``_``."""
    unsafe_run = re.compile(r"[^A-Za-z0-9_.-]+")
    return unsafe_run.sub("_", s)

def _tag_from_ts_struct(ts_struct: dict) -> str:
    """Return the sanitized form of the first key of ``ts_struct``.

    Raises:
        IndexError: if ``ts_struct`` has no keys.
    """
    names = list(ts_struct)
    first_name = names[0]
    return _sanitize(first_name)

def _stage_res(fn_name: str):
    """Look up the ``(cpus, mem, time)`` resources for a stage.

    Stage names missing from ``RES`` fall back to cpus=4, mem=20, time=720.
    """
    fallback = dict(cpus=4, mem=20, time=720)
    spec = RES.get(fn_name, fallback)
    return tuple(spec[field] for field in ("cpus", "mem", "time"))

def _clear_sticky(executor):
    """Reset the executor's ``slurm_additional_parameters`` to an empty dict."""
    no_extras = {}
    executor.update_parameters(slurm_additional_parameters=no_extras)

def _submit(executor, fn, kwargs, fn_name, tag, dep_job):
    """Configure ``executor`` for one stage and submit ``fn(**kwargs)``.

    Args:
        executor: submitit executor whose parameters are updated in place.
        fn: stage callable to run.
        kwargs: keyword arguments passed to ``fn``.
        fn_name: stage name; selects resources via ``_stage_res`` and is part
            of the Slurm job name.
        tag: structure tag used as the job-name prefix.
        dep_job: previously submitted job this one must wait for, or ``None``
            for the first stage of a chain.

    Returns:
        The job object returned by ``executor.submit``.
    """
    _clear_sticky(executor)
    cpus, mem, tmo = _stage_res(fn_name)

    extra = {}
    if USE_SLURM:
        # Only chain on a predecessor when there is one.
        if dep_job is not None:
            extra["dependency"] = f"afterok:{dep_job.job_id}"
        # The node exclusion must hold for every Slurm job, including the
        # first stage of a chain; previously it was silently dropped whenever
        # there was no dependency.
        extra["exclude"] = "node236,node237,node238,node239"

    executor.update_parameters(
        slurm_job_name=f"{tag}_{fn_name}" if USE_SLURM else None,
        slurm_partition=PARTITION if USE_SLURM else None,
        cpus_per_task=cpus,
        mem_gb=mem,
        timeout_min=tmo,
        slurm_additional_parameters=extra,
    )
    return executor.submit(fn, **kwargs)

def _next_parquet(current: str, fn_name: str) -> str:
    """Name the parquet file that stage ``fn_name`` writes for input ``current``.

    The text after the last ``.`` of ``current`` is dropped and a
    stage-specific suffix is appended. Stages without an entry in the table
    (e.g. run_init / run_cleanup) leave the name unchanged.
    """
    stage_suffix = {
        "run_hess": ".hess.parquet",
        "run_OptTS": ".optts.parquet",
        "run_freq": ".freq.parquet",
        "run_solv": ".solv.parquet",
    }
    suffix = stage_suffix.get(fn_name)
    if suffix is None:
        return current

    head, dot, _ = current.rpartition(".")
    # Without any "." the whole name acts as the stem.
    base = head if dot else current
    return base + suffix

# ─── LOAD PIPE & DISCOVER STAGES (order-of-definition) ──────────────────
# Every callable attribute of the pipeline module whose name starts with
# "run_" becomes a stage, in the order the names appear in the module namespace.
mod = importlib.import_module(MODULE_PATH)
runs = [
    stage
    for attr, stage in vars(mod).items()
    if attr.startswith("run_") and callable(stage)
]

# ─── INPUT PREP ─────────────────────────────────────────────────────────
from frust.utils.mols import create_ts_per_rpos

# Keep each SMILES once, in first-seen order, before building the TS inputs.
df = pd.read_csv(CSV_PATH)
unique_smiles = dict.fromkeys(df["smiles"])
smi_list = list(unique_smiles)
job_inputs = create_ts_per_rpos(smi_list, TS_XYZ)

# Output and log directories must exist before anything is submitted.
for out_dir in (ROOT_SAVE_DIR, LOG_DIR):
    os.makedirs(out_dir, exist_ok=True)

# ─── EXECUTOR ───────────────────────────────────────────────────────────
# Slurm runs use the auto-detecting executor; otherwise jobs run locally.
if USE_SLURM:
    executor = submitit.AutoExecutor(LOG_DIR)
    default_partition = PARTITION
else:
    executor = submitit.LocalExecutor(LOG_DIR)
    default_partition = None

# Baseline resources; `_submit` overrides them for each stage.
executor.update_parameters(
    slurm_partition=default_partition,
    cpus_per_task=4,
    mem_gb=20,
    timeout_min=720,
)

# ─── SUBMIT CHAINS ──────────────────────────────────────────────────────
# One chain of stage jobs per TS structure. Each stage is submitted with the
# previous stage's job as its dependency.
all_jobs = []
for ts_struct in job_inputs:
    tag = _tag_from_ts_struct(ts_struct)
    save_dir = os.path.join(ROOT_SAVE_DIR, tag)
    os.makedirs(save_dir, exist_ok=True)

    last_job = None
    # Name handed to the first parquet-consuming stage as `parquet_path`.
    current_parquet = "init.parquet"

    for fn in runs:
        fn_name = fn.__name__
        sig = inspect.signature(fn)
        n_cores, mem_gb, _ = _stage_res(fn_name)

        # Candidate keyword arguments for this kind of stage.
        common = {"save_dir": save_dir, "work_dir": WORK_DIR, "debug": DEBUG}
        resources = {"n_cores": n_cores, "mem_gb": mem_gb}
        if fn_name == "run_cleanup":
            kwargs = {"save_dir": save_dir}
        elif fn_name == "run_init":
            kwargs = {
                **common,
                "ts_struct": ts_struct,
                "n_confs": N_CONFS,
                **resources,
            }
        else:
            # Later stages read the parquet written by the stage before them.
            kwargs = {**common, "parquet_path": current_parquet, **resources}

        # Pass only the arguments the stage function actually declares.
        kwargs = {
            key: value
            for key, value in kwargs.items()
            if key in sig.parameters
        }

        job = _submit(executor, fn, kwargs, fn_name, tag, last_job)
        all_jobs.append(job)
        last_job = job

        # Advance the file name to the one this stage is expected to produce.
        if fn_name not in ("run_init", "run_cleanup"):
            current_parquet = _next_parquet(current_parquet, fn_name)

    _clear_sticky(executor)

job_ids = [j.job_id for j in all_jobs]
print("Submitted Slurm job IDs:", job_ids)

Submitted Slurm job IDs: ['55096764', '55096765', '55096766', '55096767', '55096768', '55096769', '55096770', '55096771', '55096772', '55096773', '55096774', '55096775', '55096776', '55096777', '55096778', '55096779', '55096780', '55096781', '55096782', '55096783', '55096784', '55096785', '55096786', '55096787', '55096788', '55096789', '55096790', '55096791', '55096792', '55096793', '55096794', '55096795', '55096796', '55096797', '55096798', '55096799', '55096800', '55096801', '55096802', '55096803', '55096804', '55096805']
