# Simulation Study Runner

This notebook executes the main simulation study. It runs the Monte Carlo simulations for each of the 12 defined scenarios in turn, executing the replications within each scenario in parallel across multiple CPU cores.

**Key Features:**
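- **Parallel Processing:** Uses `joblib` (`Parallel` and `delayed`) to run the Monte Carlo replications of each scenario on multiple cores.
- **Checkpointing:** Skips a (scenario, MC run) combination if a metrics JSON file from a successful run already exists, unless `config.OVERWRITE_EXISTING_RESULTS` is set. This allows interrupted simulations to be resumed.
- **Output Saving:** Saves detailed results for each run:
    - Summaries of posteriors (mean, median, quantiles as `.npz` files) to `OUTPUT_DIR_POSTERIOR_SUMMARIES`.
    - Benchmark results (cCFR, aCFR, benchmark CIs, and ITS estimates) to `OUTPUT_DIR_BENCHMARK_RESULTS`.
    - Calculated evaluation metrics (`.json` files) to `OUTPUT_DIR_RUN_METRICS_JSON`.
    - Raw MCMC posterior samples (`.npz` files) go to `OUTPUT_DIR_POSTERIOR_SAMPLES` only in the earlier, disabled version of the worker; the active version creates this directory but does not write to it.
- The per-run metrics of each scenario are also saved as one CSV file per scenario in `OUTPUT_DIR_RESULTS_CSV`, together with a combined file, `all_scenarios_metrics_combined.csv`.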

**Prerequisites:**
1. Ensure `sampler.py` is modified so that its `sample` function accepts an `rng_key` argument.
2. All helper Python modules (`config.py`, `data_generation.py`, `benchmarks.py`, `model_fitting.py`, `evaluation.py`, `results_io.py`) must be in the same directory or accessible in `PYTHONPATH`.

## 1. Imports and Configuration

In [1]:
import pandas as pd
import numpy as np
import os
import time
import jax 
import json
from joblib import Parallel, delayed # For parallel processing
from tqdm import tqdm # For progress bars

# Custom modules
import config
import data_generation
import benchmarks
import model_fitting
import evaluation
import results_io 

print(f"Using {config.NUM_CORES_TO_USE} cores for parallel processing.")
print(f"Global base seed for the study: {config.GLOBAL_BASE_SEED}")

Using 5 cores for parallel processing.
Global base seed for the study: 2025
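The global base seed printed above is expanded into per-run seeds later on: `main_simulation_runner` computes a scenario base seed, and the worker adds the run index to get the data-generation seed, plus an offset of 100000 for the JAX PRNG key. The sketch below reproduces that arithmetic in a hypothetical helper, `derive_run_seeds`, so the scheme can be checked in isolation. It assumes 10 runs per scenario, which is what the progress output at the end of this notebook shows.

```python
def derive_run_seeds(global_base_seed, scenario_idx, mc_run_idx, num_mc_runs, jax_offset=100_000):
    """Hypothetical helper mirroring the seed arithmetic used in this notebook.

    Returns (dgp_seed, jax_seed) for one (scenario, MC run) pair.
    """
    # Same formula as main_simulation_runner: scenarios are spaced num_mc_runs * 1000 apart
    scenario_base_seed = global_base_seed + scenario_idx * num_mc_runs * 1000
    # Same formulas as run_single_mc_replication
    dgp_seed = scenario_base_seed + mc_run_idx
    jax_seed = dgp_seed + jax_offset
    return dgp_seed, jax_seed


# Example with the printed base seed (2025) and 10 runs per scenario
for scen_idx, run_idx in [(0, 0), (0, 9), (1, 0)]:
    print(scen_idx, run_idx, derive_run_seeds(2025, scen_idx, run_idx, num_mc_runs=10))
```

Data-generation seeds never collide across scenarios, because the scenario spacing (`num_mc_runs * 1000`) always exceeds the number of runs. The integer values of the two seed families can coincide, however: with 10 runs, scenario index 10 gets data-generation seed 102025, the same integer as the JAX seed of scenario 0, run 0. Whether that matters depends on how `data_generation.simulate_scenario_data` uses its seed; it is harmless if that seed drives a different generator than `jax.random.PRNGKey`.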


## 2. Worker Function for Single Monte Carlo Replication

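This function encapsulates all operations for one Monte Carlo run: data generation, model fitting, benchmark calculation, metric evaluation, and saving. It includes checkpointing: if a successful run has already been saved, it returns the stored metrics instead of recomputing them. Plot data is not returned by this worker in this version; `analyze_results_notebook.ipynb` will handle plot data generation from saved summaries.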
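The checkpoint rule can be sketched with the standard library alone. The helpers below are hypothetical and are not the `results_io` API: a run counts as complete only if its metrics JSON exists, parses, and records no error, and metrics are written to a temporary file and then renamed, so a run interrupted mid-write (for example by a `KeyboardInterrupt`) cannot leave a truncated file behind. Write-then-rename is a suggested technique; whether `results_io.save_run_metrics` already does something equivalent depends on its implementation.

```python
import json
import os
import tempfile


def metrics_path(out_dir, scenario_id, mc_run_idx):
    # Same file-naming pattern as the worker; run numbers are 1-based in file names
    return os.path.join(out_dir, f"metrics_scen_{scenario_id}_run_{mc_run_idx + 1}.json")


def is_run_complete(out_dir, scenario_id, mc_run_idx):
    """True only if a readable metrics file exists and records no error."""
    path = metrics_path(out_dir, scenario_id, mc_run_idx)
    if not os.path.exists(path):
        return False
    try:
        with open(path) as fh:
            metrics = json.load(fh)
    except (OSError, json.JSONDecodeError):
        # A truncated or unreadable file is treated as "not done", so the run is redone
        return False
    # The worker stores the string "None" when no error occurred
    return metrics.get("error") in (None, "None")


def save_metrics_atomic(out_dir, scenario_id, mc_run_idx, metrics):
    """Write metrics to a temporary file, then rename it over the target in one step."""
    os.makedirs(out_dir, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=out_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as fh:
            # default=str stringifies values that json cannot encode natively
            json.dump(metrics, fh, default=str)
        # os.replace is atomic when source and target are on the same file system
        os.replace(tmp_path, metrics_path(out_dir, scenario_id, mc_run_idx))
    except BaseException:
        # Clean up the temporary file if anything failed before the rename
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

With this split, the worker's early return becomes `if not config.OVERWRITE_EXISTING_RESULTS and is_run_complete(...)`, and a half-written file simply causes the run to be recomputed.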

In [2]:
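# NOTE: Earlier version of the worker, disabled and kept for reference. Unlike the
# active version in the next cell, it checkpoints raw posterior samples, can also fit
# the orthogonalized sCFR-O model, and uses the parametric-bootstrap ITS benchmark.
# def run_single_mc_replication(args_tuple):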
#     """
#     Executes one full Monte Carlo replication for a given scenario.

#     This function is designed to be called by joblib.Parallel. It handles data generation, 
#     model fitting, metric calculation, and saving results for a single run, including 
#     checkpointing to skip re-computation if results already exist.

#     Args:
#         args_tuple (tuple): A tuple containing the arguments for the worker.
#             - scenario_config_dict (dict): Configuration for the specific scenario.
#             - mc_run_idx (int): The index of the current Monte Carlo run (e.g., 0 to N-1).
#             - scenario_base_seed (int): The base seed for this scenario's set of runs.

#     Returns:
#         tuple: (metrics_dictionary, None). plot_data is None as plotting is handled later.
#     """

#     scenario_config_dict, mc_run_idx, scenario_base_seed = args_tuple
#     scenario_id = scenario_config_dict["id"]
    
#     metrics_filepath = os.path.join(
#         config.OUTPUT_DIR_RUN_METRICS_JSON,
#         f"metrics_scen_{scenario_id}_run_{mc_run_idx+1}.json"
#     )

#     if not config.OVERWRITE_EXISTING_RESULTS and os.path.exists(metrics_filepath):
#         loaded_metrics = results_io.load_run_metrics(scenario_id, mc_run_idx, config.OUTPUT_DIR_RUN_METRICS_JSON)
#         if loaded_metrics and loaded_metrics.get("error") in [None, "None"]:
#             return loaded_metrics, None
    
#     run_specific_seed_dgp = scenario_base_seed + mc_run_idx 
#     jax_prng_key = jax.random.PRNGKey(run_specific_seed_dgp + 100000)
#     current_run_metrics = {"scenario_id": scenario_id, "mc_run": mc_run_idx + 1, "error": "None"}
    
#     try:
#         sim_data = data_generation.simulate_scenario_data(scenario_config_dict, run_seed=run_specific_seed_dgp)
        
#         # --- Model Fitting with separate checkpointing for raw samples ---
#         posterior_samples_scfr = results_io.load_raw_posterior_samples(scenario_id, mc_run_idx, config.OUTPUT_DIR_POSTERIOR_SAMPLES, model_name="sCFR")
#         if posterior_samples_scfr is None:
#             posterior_samples_scfr, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key, use_orthogonalization=False)
#             results_io.save_raw_posterior_samples(scenario_id, mc_run_idx, posterior_samples_scfr, config.OUTPUT_DIR_POSTERIOR_SAMPLES, model_name="sCFR")
        
#         posterior_samples_scfr_o = None
#         if config.COMPARE_SCFR_AND_SCFR_O:
#             posterior_samples_scfr_o = results_io.load_raw_posterior_samples(scenario_id, mc_run_idx, config.OUTPUT_DIR_POSTERIOR_SAMPLES, model_name="sCFR_O")
#             if posterior_samples_scfr_o is None:
#                 jax_prng_key, subkey = jax.random.split(jax_prng_key)
#                 posterior_samples_scfr_o, _ = model_fitting.fit_proposed_model(sim_data, subkey, use_orthogonalization=True)
#                 results_io.save_raw_posterior_samples(scenario_id, mc_run_idx, posterior_samples_scfr_o, config.OUTPUT_DIR_POSTERIOR_SAMPLES, model_name="sCFR_O")
        
#         # --- Calculate benchmarks and all metrics ---
#         benchmark_r_t_estimates = {
#             "cCFR_cumulative": benchmarks.calculate_crude_cfr(sim_data["d_t"], sim_data["c_t"], cumulative=True),
#             "aCFR_cumulative": benchmarks.calculate_nishiura_cfr_cumulative(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
#         }
#         benchmark_cis = benchmarks.calculate_benchmark_cis_with_bayesian(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
#         its_results = benchmarks.calculate_its_with_parametric_bootstrap(
#             d_t=sim_data["d_t"],
#             c_t=sim_data["c_t"],
#             f_s=sim_data["f_s_true"],
#             Bm=sim_data["Bm_true"],
#             intervention_times_abs=sim_data["true_intervention_times_0_abs"],
#             intervention_signs=sim_data["beta_signs_true"]
#         )
        
#         calculated_metrics = evaluation.collect_all_metrics(
#             sim_data, posterior_samples_scfr, posterior_samples_scfr_o, 
#             benchmark_r_t_estimates, benchmark_cis, its_results
#         )
#         current_run_metrics.update(calculated_metrics)
#         current_run_metrics["status"] = "computed"
#         results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
        
#     except Exception as e:
#         import traceback
#         error_message = f"ERROR in WORKER {os.getpid()} for Scen {scenario_id}, Run {mc_run_idx + 1}: {type(e).__name__} - {e}\n{traceback.format_exc()}"
#         print(f"\n---!!! {error_message} !!!---\n") # Make error more visible
#         current_run_metrics["error"] = error_message
#         current_run_metrics["status"] = "error_in_run"
#         try:
#             results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
#         except Exception as save_e:
#             print(f"    WORKER {os.getpid()}: CRITICAL ERROR - Could not save error metrics: {save_e}")
            
#     return current_run_metrics, None

In [3]:
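def run_single_mc_replication(args_tuple):
    """Executes one full Monte Carlo replication for a given scenario.

    Generates data, fits the sCFR model, computes the benchmarks and evaluation
    metrics, and saves the results. Returns early with the stored metrics if a
    successful metrics file already exists and overwriting is disabled.

    Args:
        args_tuple (tuple): (scenario_config_dict, mc_run_idx, scenario_base_seed),
            where mc_run_idx is 0-based.

    Returns:
        tuple: (metrics_dict, None). The second element is a placeholder for plot
        data, which is generated later from the saved summaries.
    """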
    scenario_config_dict, mc_run_idx, scenario_base_seed = args_tuple
    scenario_id = scenario_config_dict["id"]
    
    metrics_filepath = os.path.join(
        config.OUTPUT_DIR_RUN_METRICS_JSON,
        f"metrics_scen_{scenario_id}_run_{mc_run_idx+1}.json"
    )

    if not config.OVERWRITE_EXISTING_RESULTS and os.path.exists(metrics_filepath):
        loaded_metrics = results_io.load_run_metrics(scenario_id, mc_run_idx, config.OUTPUT_DIR_RUN_METRICS_JSON)
        if loaded_metrics and loaded_metrics.get("error") in [None, "None"]:
            return loaded_metrics, None
    
    run_specific_seed_dgp = scenario_base_seed + mc_run_idx 
    jax_prng_key = jax.random.PRNGKey(run_specific_seed_dgp + 100000)
    current_run_metrics = {"scenario_id": scenario_id, "mc_run": mc_run_idx + 1, "error": "None"}
    
    try:
        sim_data = data_generation.simulate_scenario_data(scenario_config_dict, run_seed=run_specific_seed_dgp)
        
        # Fit the standard sCFR model
        posterior_samples_scfr, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key)
        results_io.save_posterior_summary_for_run(scenario_id, mc_run_idx, posterior_samples_scfr, config.OUTPUT_DIR_POSTERIOR_SUMMARIES)

        # Benchmark Calculations
        benchmark_r_t_estimates = {
            "cCFR_cumulative": benchmarks.calculate_crude_cfr(sim_data["d_t"], sim_data["c_t"], cumulative=True),
            "aCFR_cumulative": benchmarks.calculate_nishiura_cfr_cumulative(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
        }
        benchmark_cis = benchmarks.calculate_benchmark_cis_with_bayesian(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
        
        its_results = benchmarks.calculate_its_with_penalized_mle(
            d_t=sim_data["d_t"], c_t=sim_data["c_t"], f_s=sim_data["f_s_true"],
            Bm=sim_data["Bm_true"],
            intervention_times_abs=sim_data["true_intervention_times_0_abs"],
            intervention_signs=sim_data["beta_signs_true"]
        )

        all_benchmark_results = {**benchmark_r_t_estimates, **benchmark_cis, **its_results}
        results_io.save_benchmark_results(scenario_id, mc_run_idx, all_benchmark_results, config.OUTPUT_DIR_BENCHMARK_RESULTS)
        
        # Collect all scalar metrics
        calculated_metrics = evaluation.collect_all_metrics(
            sim_data, posterior_samples_scfr,
            benchmark_r_t_estimates, benchmark_cis, its_results
        )
        current_run_metrics.update(calculated_metrics)
        results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
        
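    except Exception as e:
        import traceback
        error_message = (
            f"ERROR in WORKER {os.getpid()} for Scen {scenario_id}, Run {mc_run_idx + 1}: "
            f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
        )
        current_run_metrics["error"] = error_message
        # Guard the save: an exception raised here would otherwise propagate out of
        # the worker and abort the whole Parallel call for this scenario.
        try:
            results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
        except Exception as save_e:
            current_run_metrics["error"] += f"\nCould not save error metrics: {save_e}"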
            
    return current_run_metrics, None
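The line `all_benchmark_results = {**benchmark_r_t_estimates, **benchmark_cis, **its_results}` silently keeps the last value whenever two benchmark dictionaries share a key. If these key sets are meant to be disjoint, a small guard makes an accidental overlap fail loudly instead. `merge_disjoint` is a hypothetical helper, not part of `results_io` or `benchmarks`.

```python
def merge_disjoint(*dicts):
    """Merge dictionaries, raising ValueError if any key appears in more than one."""
    merged = {}
    for d in dicts:
        # Keys already present would be overwritten by a plain {**a, **b} merge
        overlap = merged.keys() & d.keys()
        if overlap:
            raise ValueError(f"Duplicate result keys: {sorted(overlap)}")
        merged.update(d)
    return merged
```

In the worker it would replace the unpacking: `all_benchmark_results = merge_disjoint(benchmark_r_t_estimates, benchmark_cis, its_results)`. Because the worker catches exceptions, a collision would then be recorded in that run's `error` field rather than going unnoticed.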

## 3. Main Simulation Execution Function

In [4]:
def main_simulation_runner():
    """
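    Main function to orchestrate the entire simulation study.
    It iterates through the defined scenarios sequentially and, for each scenario,
    runs the Monte Carlo replications in parallel with joblib, showing tqdm progress
    bars. The inner bar advances as runs are dispatched to workers, not as they finish.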
    """
    start_time_total = time.time()

    # Create all output directories
    for dir_path in [config.OUTPUT_DIR_TABLES, config.OUTPUT_DIR_RESULTS_CSV,
                     config.OUTPUT_DIR_POSTERIOR_SUMMARIES, config.OUTPUT_DIR_BENCHMARK_RESULTS,
                     config.OUTPUT_DIR_POSTERIOR_SAMPLES, config.OUTPUT_DIR_RUN_METRICS_JSON]:
        os.makedirs(dir_path, exist_ok=True)

    all_results_metrics_list = []
    with tqdm(total=len(config.SCENARIOS), desc="Overall Scenario Progress") as outer_pbar:
        for scenario_idx, scenario_config_dict in enumerate(config.SCENARIOS):
            scenario_id = scenario_config_dict["id"]
            outer_pbar.set_description(f"Processing Scen: {scenario_id}")
            
            scenario_base_seed = config.GLOBAL_BASE_SEED + (scenario_idx * config.NUM_MONTE_CARLO_RUNS * 1000)
            mc_args_list = [(scenario_config_dict, i, scenario_base_seed) for i in range(config.NUM_MONTE_CARLO_RUNS)]

            try:
                # Use joblib Parallel for execution with a nested tqdm progress bar
                scenario_run_outputs = Parallel(n_jobs=config.NUM_CORES_TO_USE)(
                    delayed(run_single_mc_replication)(args) for args in tqdm(
                        mc_args_list, desc=f"  MC Runs for {scenario_id}", leave=False, position=1
                    )
                )
                
                current_scenario_metrics_list = [result[0] for result in scenario_run_outputs]
                
                if current_scenario_metrics_list:
                    df_scenario_metrics = pd.DataFrame(current_scenario_metrics_list)
                    csv_path = os.path.join(config.OUTPUT_DIR_RESULTS_CSV, f"metrics_scen_{scenario_id}.csv")
                    df_scenario_metrics.to_csv(csv_path, index=False)
                    all_results_metrics_list.extend(current_scenario_metrics_list)
            
            except Exception as e_pool:
                print(f"\nCritical error during parallel processing for scenario {scenario_id}: {e_pool}")
            
            outer_pbar.update(1)

    end_time_total = time.time()
    print(f"\nAll Monte Carlo simulations completed in {end_time_total - start_time_total:.2f} seconds.")

    # Save combined results
    if all_results_metrics_list:
        results_df_all = pd.DataFrame(all_results_metrics_list)
        combined_csv_path = os.path.join(config.OUTPUT_DIR_RESULTS_CSV, "all_scenarios_metrics_combined.csv")
        results_df_all.to_csv(combined_csv_path, index=False)
        print(f"Combined metrics saved to {combined_csv_path}.")
    
    print("\nSimulation runs complete. Use analyze_results.py to generate plots and final tables.")
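The inner progress bar in `main_simulation_runner` wraps the argument list, so it advances when joblib dispatches a run rather than when the run finishes. One way to count completions, sketched here with the standard library's `concurrent.futures` swapped in for joblib, is to submit every task and call a callback as each future completes. `run_with_progress` and its `on_done` parameter are hypothetical. Recent joblib versions (1.3 and later) offer a comparable option through `Parallel(..., return_as="generator")`, whose output can be wrapped in `tqdm` directly.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed


def run_with_progress(worker, args_list, n_workers, on_done=None,
                      executor_cls=ProcessPoolExecutor):
    """Run worker(args) for every args in args_list, calling on_done() as each task finishes.

    Results are returned in the order of args_list, regardless of completion order.
    ThreadPoolExecutor can be passed as executor_cls for quick local testing.
    """
    results = [None] * len(args_list)
    with executor_cls(max_workers=n_workers) as pool:
        # Map each future back to the position of its arguments
        future_to_idx = {pool.submit(worker, args): i for i, args in enumerate(args_list)}
        for future in as_completed(future_to_idx):
            # .result() re-raises any exception raised inside the worker
            results[future_to_idx[future]] = future.result()
            if on_done is not None:
                on_done()
    return results
```

In the scenario loop it would be called inside `with tqdm(total=len(mc_args_list)) as pbar:` with `on_done=lambda: pbar.update(1)`. One caveat: `ProcessPoolExecutor` must be able to import the worker function in child processes, and on platforms that spawn workers (Windows, and macOS by default) functions defined inside a notebook often cannot be found there. joblib's default backend avoids this, which is a reason to prefer the `return_as="generator"` route when it is available.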

## 4. Execute Simulation

The following cell starts the simulation. Before running it, make sure the settings in `config.py` are as desired and that the `sample` function in `sampler.py` accepts an `rng_key` argument.

In [5]:
# When converting this notebook to a .py script, guard this call with
# `if __name__ == "__main__":`, which is recommended for process-based parallelism
# on platforms that start workers by spawning (e.g. Windows).
main_simulation_runner()

Processing Scen: S01:   0%|                                                                     | 0/12 [00:00<?, ?it/s]
  MC Runs for S01:   0%|                                                                        | 0/10 [00:00<?, ?it/s]
Processing Scen: S01:   0%|                                                                     | 0/12 [00:03<?, ?it/s]

KeyboardInterrupt
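Because the per-scenario and combined CSV files are only written after the parallel runs for a scenario return, an interrupted session like the one above leaves completed runs only as per-run JSON files. A hypothetical helper, `collect_run_metrics`, can gather those files back into rows (one dict per run) that `pd.DataFrame` turns into the same kind of table. It assumes the `metrics_scen_{scenario_id}_run_{k}.json` naming used by the worker and the `scenario_id` and `mc_run` fields the worker writes into each file.

```python
import glob
import json
import os


def collect_run_metrics(metrics_dir, scenario_id=None):
    """Load per-run metrics JSON files into a list of dicts, sorted by scenario and run."""
    scen = scenario_id if scenario_id is not None else "*"
    pattern = os.path.join(metrics_dir, f"metrics_scen_{scen}_run_*.json")
    rows = []
    for path in glob.glob(pattern):
        try:
            with open(path) as fh:
                rows.append(json.load(fh))
        except (OSError, json.JSONDecodeError):
            # Skip files left unreadable or truncated by an interrupted write
            continue
    rows.sort(key=lambda row: (str(row.get("scenario_id")), int(row.get("mc_run", 0))))
    return rows
```

For example, `pd.DataFrame(collect_run_metrics(config.OUTPUT_DIR_RUN_METRICS_JSON, "S01"))` rebuilds the S01 table from whatever runs finished before the interruption.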

