# Stochastic Simulation Estimation

- We create a pipeline to automate repeated simulation under known parameters
- We automate model estimation, using Stan or Monte-Carlo MLE
- Collect parameter estimates
- Compute bias, RMSE and coverage across replicates

# Estimation Framework Choice — Stan vs NONMEM + SAEM

In the original study, model estimation was performed in NONMEM using the **Stochastic Approximation Expectation–Maximization (SAEM)** algorithm.  
To enable full reproducibility in an open-source environment, we replaced SAEM with **Bayesian inference implemented in Stan (via CmdStanPy)**.

Stan integrates over both **random effects** and **hidden states** using Hamiltonian Monte Carlo (HMC) sampling, producing posterior distributions for all parameters.  
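This approach targets the same marginal likelihood that SAEM maximizes, so under weakly informative priors the posterior summaries are comparable to the SAEM maximum-likelihood estimates (MLE); in addition, it provides richer uncertainty quantification and does not depend on proprietary software (such as NONMEM).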

Accordingly, all simulation–estimation (SSE) results in this notebook are based on **Stan-derived posterior means and credible intervals** as open-source analogues to SAEM estimates.


In [None]:
import os 
from pathlib import Path
import time 

import datetime
import logging
import pickle

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

#for parallel processing
from joblib import Parallel, delayed
from cmdstanpy import CmdStanModel
import scipy.optimize as opt
from scipy.stats import multivariate_normal

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, str(Path("...").resolve()))
from mHMM.src.emissions import EmissionModel
from mHMM.src.transitions import TransitionModel

# SSE config: project-relative locations of simulated data, results and the Stan model
BASE_DIR = Path('...')
DATA_DIR = BASE_DIR / 'data' / 'simulated'
RESULTS_DIR = BASE_DIR / 'data' / 'results' / 'SSE'
STAN_FILE = BASE_DIR / 'models' / 'stan' / 'mhmm_model.stan'

# The log file lives in RESULTS_DIR, so the directory must exist before logging is configured
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# One log file per run, named after the start time (also reused for the aggregated CSV name)
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
logfile = RESULTS_DIR / f"sse_run_{timestamp}.log"
logging.basicConfig(
    filename=str(logfile),
    filemode='a',
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)
# Mirror log records to the console in addition to the log file
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)

# logging.INFO is the integer level constant; logging.info() is the function that emits a record
logging.info('SSE runner initialized')

In [None]:
from numpy.random import SeedSequence, default_rng

def rng_from_seed(seed):
    """Return a NumPy ``Generator`` built from ``seed``.

    Parameters
    ----------
    seed : int or None
        Integer seed. The same integer always yields a generator producing
        the same stream. When ``None``, a 32-bit value derived from the
        current wall-clock time (microseconds) is used, so the result is not
        reproducible.

    Returns
    -------
    numpy.random.Generator
    """
    if seed is None:
        # Time-based fallback, masked to 32 bits
        seed = int(time.time() * 1e6) & 0xFFFFFFFF
    # Call default_rng on the seed sequence; returning the bare function
    # would hand callers an object with no sampling methods.
    return default_rng(SeedSequence(seed))

def spawn_seed(seed):
    """Derive a new integer seed from ``seed`` for use in a subprocess.

    A child ``SeedSequence`` is spawned from ``SeedSequence(seed)`` and one
    32-bit word of its generated state is returned. The same ``seed`` always
    maps to the same derived value.

    Parameters
    ----------
    seed : int or None
        Parent seed passed to ``SeedSequence``.

    Returns
    -------
    int
        Derived seed as a plain Python int.
    """
    child = SeedSequence(seed).spawn(1)[0]
    # A spawned child keeps the parent's ``entropy`` attribute (only its
    # spawn_key differs), so returning ``child.entropy`` would just echo the
    # input seed. The generated state is what actually differs.
    return int(child.generate_state(1)[0])

Simulation Wrapper (deterministic, our reference scenario)

In [None]:
def simulate_reference_dataset(seed, N_subj=100, T_weeks=60, init_probs=(0.9, 0.1),
                               trans_params=None, em_params=None):
    """Simulate one longitudinal dataset for the reference scenario.

    For every subject, individual effects are drawn from the emission model,
    a two-state path is simulated from ``init_probs`` and the transition
    matrix, and a bivariate-normal observation is drawn at each time point.

    Parameters
    ----------
    seed : int or None
        Passed to ``rng_from_seed`` to obtain the random generator.
    N_subj : int
        Number of subjects (IDs run from 1 to ``N_subj``).
    T_weeks : int
        Number of time points per subject.
    init_probs : sequence of two floats
        Probabilities of starting in state 0 and state 1.
    trans_params : dict or None
        Transition settings with keys 'hPRE', 'hPER' and optional 'gPRE',
        'gPER', 'trt', 'slp'. Reference values are used when ``None``.
    em_params : dict or None
        Keyword arguments for ``EmissionModel``. Reference values are used
        when ``None``.

    Returns
    -------
    pandas.DataFrame
        One row per subject and time point with columns ID, Week (1-based),
        State, FEV1 and PRO.
    """
    generator = rng_from_seed(seed)

    if trans_params is None:
        trans_params = {
            'hPRE': 0.1, 'hPER': 0.3,
            'gPRE': 0.0, 'gPER': 0.0,
            'trt': 0, 'slp': 0,
        }
    if em_params is None:
        em_params = {
            'hFEV1R': 3.0, 'hFEV1E': 0.5,
            'x2_FEV1R': 0.03, 'x2_FEV1E': 0.03,
            'hPROR': 2.5, 'hPROE': 0.5,
            'x2_PROR': 0.09, 'x2_PROE': 0.09,
            'r2_FEV1': 0.015, 'r2_PRO': 0.05,
            'qR': -0.33, 'qE': -0.33,
            'PE': 0.2, 'PHL': 10.0,
        }

    emission = EmissionModel(**em_params)
    transition = TransitionModel(
        hpRE=trans_params['hPRE'],
        hpER=trans_params['hPER'],
        gpRE=trans_params.get('gPRE', 0.0),
        gpER=trans_params.get('gPER', 0.0),
        trt=trans_params.get('trt', 0),
        slp=trans_params.get('slp', 0),
    )
    # Row i of the matrix is used as the distribution of the next state given state i
    P = transition.transition_matrix()
    state_labels = [0, 1]

    records = []
    for subject_id in range(1, N_subj + 1):
        effects = emission.sample_individual_effects(rng=generator)

        # Hidden-state path: initial draw, then one draw per subsequent time point
        path = np.zeros(T_weeks, dtype=int)
        path[0] = generator.choice(state_labels, p=init_probs)
        for step in range(1, T_weeks):
            path[step] = generator.choice(state_labels, p=P[path[step - 1], :])

        # Observations: bivariate normal with state-dependent mean and covariance
        for step, state in enumerate(path):
            mean_vec = np.array([
                emission.individual_fev1(effects, state),
                emission.individual_pro(effects, step, state),
            ])
            sigma = emission.emission_cov(state)
            obs = generator.multivariate_normal(mean_vec, sigma)
            records.append({
                'ID': subject_id,
                'Week': int(step + 1),
                'State': int(state),
                'FEV1': float(obs[0]),
                'PRO': float(obs[1]),
            })

    return pd.DataFrame(records)

# Stan Data Builder

In [None]:
def build_stan_data(df, init_probs=(0.9,0.1),trt_slp_val=0.0):
    """Convert a long-format dataset into the flat data dict passed to Stan.

    Rows are ordered by (ID, Week) and the observations of all subjects are
    concatenated into flat arrays; ``subj_start`` (1-based) and ``subj_len``
    locate each subject's segment inside those arrays.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ID, Week, FEV1 and PRO.
    init_probs : sequence of float
        Initial state probabilities, stored under 'init_prob'.
    trt_slp_val : float
        Value stored for every subject under 'trt_slp'.

    Returns
    -------
    dict
        Keys: N, T_max, total_obs, subj_start, subj_len, y1_flat, y2_flat,
        time_flat, init_prob, trt_slp.

    Raises
    ------
    ValueError
        If ``df`` contains no observations.
    """
    ordered = df.sort_values(['ID', 'Week']).reset_index(drop=True)

    y1_list, y2_list, time_list = [], [], []
    subj_start, subj_len, trt_slp = [], [], []

    pos = 0  # number of observations already written to the flat arrays
    # The frame is already sorted, so a single groupby pass (sort=False keeps
    # that order) replaces a full boolean-mask scan and re-sort per subject.
    for _, sub in ordered.groupby('ID', sort=False):
        n_obs = len(sub)
        subj_start.append(pos + 1)  # 1-based index for Stan
        subj_len.append(n_obs)
        pos += n_obs
        y1_list.extend(sub['FEV1'].astype(float).tolist())
        y2_list.extend(sub['PRO'].astype(float).tolist())
        time_list.extend(sub['Week'].astype(float).tolist())
        trt_slp.append(float(trt_slp_val))

    if not subj_len:
        # Same exception type max() would raise on an empty list, with a clearer message
        raise ValueError("build_stan_data: df contains no observations")

    return {
        'N': len(subj_len),
        'T_max': max(subj_len),
        'total_obs': len(y1_list),
        'subj_start': subj_start,
        'subj_len': subj_len,
        'y1_flat': y1_list,
        'y2_flat': y2_list,
        'time_flat': time_list,
        'init_prob': list(init_probs),
        'trt_slp': trt_slp,
    }


Fit Wrapper

In [None]:
# Build the CmdStanModel once at module level; fit_stan_for_df reuses this
# `stan_model` object for every fit.

# Stop early with a readable message if the Stan program file is missing
assert STAN_FILE.exists(), f"Stan file not found at {STAN_FILE}"
stan_model = CmdStanModel(stan_file=str(STAN_FILE))

def fit_stan_for_df(df, out_dir, test_mode=True, chains=2, seed=12345, adapt_delta=0.85, threads=1):
    """Fit the module-level ``stan_model`` to one dataset and persist the fit.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset passed to ``build_stan_data``.
    out_dir : str or Path
        Directory for the pickled fit and the CmdStan CSV files; created if
        missing.
    test_mode : bool
        If True, use 400 warmup and 400 sampling iterations and force two
        chains; otherwise 1000/1000 iterations and ``chains`` chains.
    chains : int
        Number of chains when ``test_mode`` is False.
    seed : int
        Seed passed to the sampler.
    adapt_delta : float
        Sampler ``adapt_delta`` setting.
    threads : int
        Forwarded as ``threads_per_chain``.

    Returns
    -------
    Path
        Path of the pickle file containing the fit object.

    Raises
    ------
    Exception
        Any error raised while sampling or saving is logged and re-raised.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stan_data = build_stan_data(df)

    # Shorter runs in test mode
    iter_warmup = 400 if test_mode else 1000
    iter_sampling = 400 if test_mode else 1000
    chains = 2 if test_mode else chains

    try:
        fit = stan_model.sample(
            data=stan_data,
            chains=chains,
            iter_warmup=iter_warmup,
            iter_sampling=iter_sampling,
            adapt_delta=adapt_delta,
            seed=seed,
            show_progress=False,
            # CmdStanModel.sample has no `threads` keyword; the per-chain
            # thread count is passed as `threads_per_chain`.
            threads_per_chain=threads,
        )

        # Move the CSV output into out_dir before pickling, so the pickled
        # object is written after the files have reached their final location.
        fit.save_csvfiles(dir=str(out_dir))

        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        pkl = out_dir / f"stan_fit_{ts}.pkl"
        # Keep only the dump inside the `with`, so the file is closed before
        # anything else happens
        with open(pkl, "wb") as f:
            pickle.dump(fit, f)

    except Exception:
        logging.exception("Stan fit failed")
        raise

    logging.info(f"Stan fit saved: {pkl}")
    return pkl


# Extract posterior summaries (means & 95% CI)

In [None]:
def summarize_stan_fit(pickle_path):
    """Load a pickled fit and summarise the posterior draws of key parameters.

    Columns stored on a transformed scale are mapped back before
    summarising: ``logit_hpRE``/``logit_hpER`` through the inverse logit, and
    ``atanh_qR``/``atanh_qE`` through ``tanh``. A derived column is only
    created when its source column is present in the draws.

    Parameters
    ----------
    pickle_path : str or Path
        Pickle file containing an object that exposes ``draws_pd()``.

    Returns
    -------
    dict
        For every summarised parameter ``k`` found in the draws:
        ``k_mean``, ``k_sd`` (sample SD, ddof=1), ``k_2.5%`` and ``k_97.5%``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files written by
    # this pipeline.
    with open(pickle_path, 'rb') as f:
        fit = pickle.load(f)

    df = fit.draws_pd()

    # Back-transform only when the source column exists. Assigning None
    # otherwise would create an all-missing column (or clobber an existing one).
    for logit_col, prob_col in (("logit_hpRE", "hpRE"), ("logit_hpER", "hpER")):
        if logit_col in df.columns:
            df[prob_col] = 1 / (1 + np.exp(-df[logit_col]))
    # tanh must be applied to the atanh_* column, not to the target column
    for atanh_col, corr_col in (("atanh_qR", "qR"), ("atanh_qE", "qE")):
        if atanh_col in df.columns:
            df[corr_col] = np.tanh(df[atanh_col])

    keys = ["hFEV1R","hFEV1E","hPROR","hPROE","r2_FEV1","r2_PRO","qR","qE","hpRE","hpER","PE","PHL"]
    result = {}
    for k in keys:
        if k not in df.columns:
            continue
        arr = df[k].dropna().values
        if arr.size == 0:
            # Nothing to summarise for this parameter
            continue
        result[k+"_mean"] = float(np.mean(arr))
        result[k+"_sd"] = float(np.std(arr, ddof=1))
        result[k+"_2.5%"] = float(np.percentile(arr, 2.5))
        result[k+"_97.5%"] = float(np.percentile(arr, 97.5))

    return result


Single Pipeline: Simulate --> Fit --> Summarize

In [None]:
def run_single_replicate(rep_idx, seed_base, out_root, test_mode=True, use_stan=True, stan_threads=1,
                         N_subj=100, T_weeks=60):
    """Run one SSE replicate: simulate a dataset, fit it, summarise the fit.

    Files are written to ``<out_root>/rep_<rep_idx, 4 digits>``:
    ``simulated.csv``, the outputs of ``fit_stan_for_df``, and
    ``summary.pkl`` on success or ``error.txt`` on failure.

    Parameters
    ----------
    rep_idx : int
        Replicate index; also used to derive the seed.
    seed_base : int
        Base seed. The replicate seed is ``(seed_base + rep_idx)`` masked to
        31 bits.
    out_root : str or Path
        Root directory for per-replicate folders.
    test_mode : bool
        Forwarded to ``fit_stan_for_df``.
    use_stan : bool
        If False, no model is fitted and the summary is a placeholder.
    stan_threads : int
        Forwarded to ``fit_stan_for_df`` as ``threads``.
    N_subj, T_weeks : int
        Size of the simulated dataset (defaults match the reference scenario).

    Returns
    -------
    dict
        On success: rep, seed, elapsed_s, status plus the summary entries.
        On failure: rep, seed, elapsed_s, status='failed' and the error text.
        Failures are logged and returned rather than raised, so one bad
        replicate does not abort a batch.
    """
    start_time = time.time()
    rep_out = Path(out_root) / f"rep_{rep_idx:04d}"
    rep_out.mkdir(parents=True, exist_ok=True)
    seed = int((seed_base + rep_idx) & 0x7FFFFFFF)

    try:
        # Simulate and keep a copy of the data next to the fit
        df = simulate_reference_dataset(seed, N_subj=N_subj, T_weeks=T_weeks)
        df.to_csv(rep_out / "simulated.csv", index=False)

        if use_stan:
            pkl_path = fit_stan_for_df(df, out_dir=rep_out, test_mode=test_mode,
                                       seed=seed, threads=stan_threads)
            summary = summarize_stan_fit(pkl_path)
        else:
            summary = {"status": "MLE_not_implemented"}

        meta = {
            "rep": rep_idx,
            "seed": seed,
            "elapsed_s": time.time() - start_time,
            "status": "ok",
        }
        # Summary entries are merged last, so a "status" key in the summary
        # (the non-Stan placeholder) replaces "ok".
        meta.update(summary)

        with open(rep_out / "summary.pkl", "wb") as f:
            pickle.dump(meta, f)
        return meta

    except Exception as e:
        logging.exception(f"Replicate {rep_idx} failed")

        with open(rep_out / "error.txt", "w") as f:
            f.write(str(e))
        # Include seed and timing so a failed replicate can be reproduced
        return {
            "rep": rep_idx,
            "seed": seed,
            "elapsed_s": time.time() - start_time,
            "status": "failed",
            "error": str(e),
        }


# Parallel SSE executor (joblib) and aggregator 

In [None]:
def run_sse(n_reps=50, seed_base=1234, out_root=RESULTS_DIR, test_mode=True,
             use_stan=True, n_jobs=2, stan_threads=1):
    """Run ``n_reps`` replicates through joblib and collect their results.

    Replicates are numbered from 1 to ``n_reps``; each is handled by
    ``run_single_replicate``. The per-replicate result dicts are combined
    into one DataFrame, which is also written to
    ``<out_root>/sse_summary_all_<timestamp>.csv``.

    Parameters
    ----------
    n_reps : int
        Number of replicates.
    seed_base : int
        Base seed forwarded to every replicate.
    out_root : str or Path
        Output root for replicate folders and the aggregated CSV.
    test_mode, use_stan, stan_threads
        Forwarded unchanged to ``run_single_replicate``.
    n_jobs : int
        Number of joblib workers.

    Returns
    -------
    pandas.DataFrame
        One row per replicate.
    """
    logging.info(f"Starting SSE: n_reps={n_reps}, test_mode={test_mode}, use_stan={use_stan}")
    t0 = time.time()

    replicate = delayed(run_single_replicate)
    # Lazy task stream; the tqdm bar advances as joblib pulls tasks from it
    tasks = (
        replicate(idx + 1, seed_base, out_root, test_mode, use_stan, stan_threads)
        for idx in tqdm(range(n_reps), desc="SSE replicates")
    )
    records = Parallel(n_jobs=n_jobs)(tasks)

    df_res = pd.DataFrame(records)
    csv_out = Path(out_root) / f"sse_summary_all_{timestamp}.csv"
    df_res.to_csv(csv_out, index=False)

    minutes = (time.time() - t0) / 60
    logging.info(f"SSE finished in {minutes: .2f} min. Results saved to {csv_out}")
    return df_res


In [None]:
# TEST CELL
# Smoke test: 4 replicates in test mode on 2 joblib workers, with outputs
# written under RESULTS_DIR.
df_results = run_sse(n_reps=4, seed_base=999, out_root=RESULTS_DIR, test_mode=True, use_stan=True, n_jobs=2)

Post-Processing: bias, RMSE, coverage

In [None]:
def compute_sse_metrics(summary_df, reference_params):
    """Compute bias, RMSE and interval coverage per parameter across replicates.

    Only rows with ``status == "ok"`` are used. For each parameter ``key`` in
    ``reference_params`` the column ``<key>_mean`` supplies the point
    estimates and ``<key>_2.5%`` / ``<key>_97.5%`` supply the interval.
    Parameters without a ``<key>_mean`` column are skipped with a warning.

    Parameters
    ----------
    summary_df : pandas.DataFrame
        Per-replicate summaries, including a ``status`` column.
    reference_params : dict
        Maps parameter name to its true value.

    Returns
    -------
    pandas.DataFrame
        One row per parameter with columns Bias, RMSE, Coverage and N
        (number of successful replicates). Coverage is NaN when the interval
        columns are missing.
    """
    # Restrict to successful replicates
    ok = summary_df[summary_df["status"] == "ok"].copy()
    n_ok = ok.shape[0]
    logging.info(f"Computing metrics on {n_ok} succesful replicates out of {len(summary_df)}")

    # Must be a dict: results are stored under the parameter name
    metrics = {}
    for key, true_val in reference_params.items():
        mean_key = f"{key}_mean"
        lower_key = f"{key}_2.5%"
        upper_key = f"{key}_97.5%"

        if mean_key not in ok.columns:
            logging.warning(f"{mean_key} not in summaries; skipping")
            continue

        errors = ok[mean_key].astype(float) - true_val
        bias = errors.mean()
        rmse = np.sqrt((errors ** 2).mean())

        if lower_key in ok.columns and upper_key in ok.columns:
            # Parenthesise both comparisons: `&` binds tighter than `>=`
            covered = (ok[lower_key] <= true_val) & (ok[upper_key] >= true_val)
            cov = covered.mean()
        else:
            cov = np.nan

        metrics[key] = {"Bias": float(bias), "RMSE": float(rmse), "Coverage": float(cov), "N": int(n_ok)}

    return pd.DataFrame(metrics).T

## Recommended workflow to run SSE

1. **Smoke test**: run `run_sse(n_reps=4, test_mode=True, n_jobs=2)` to verify the pipeline.
2. **Small experiment**: run `n_reps=20` with `test_mode=True` to validate behavior and resource use.
3. **Full SSE**: run `n_reps=100` (or 200) with `test_mode=False` and `n_jobs` set to the number of CPU cores you can allocate.
   - Use `stan_threads>=1` and ensure your machine has enough memory (Stan uses memory per chain).
4. **Parallelization notes**:
   - `n_jobs` controls number of concurrent replicates. Each replicate launches CmdStan and will use `stan_threads` threads.
   - On HPC, prefer job-array or dask for large SSE jobs to avoid too many simultaneous CmdStan compiles/samples.
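5. **Caching**: results are saved per-replicate in `data/results/SSE/rep_XXXX`, so completed replicates can be inspected or re-aggregated without refitting. Note that `run_sse` does not skip replicates that already exist; re-running recomputes them and overwrites `simulated.csv` and `summary.pkl`.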
