# Simulation Study Runner

This notebook executes the main simulation study. It performs Monte Carlo simulations for 12 defined scenarios in parallel, leveraging multiple CPU cores. 

**Key Features:**
- **Parallel Processing:** Uses `joblib` (`Parallel` and `delayed`) to distribute Monte Carlo runs across CPU cores.
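- **Checkpointing:** Skips work that is already on disk for a given (scenario, MC run) pair. If a metrics JSON file without a recorded error exists, the run is skipped entirely; if only the raw posterior samples exist, they are loaded and MCMC is not rerun. This allows interrupted simulations to be resumed.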
- **Output Saving:** Saves detailed results for each run:
    - Raw MCMC posterior samples (`.npz` files) to `OUTPUT_DIR_POSTERIOR_SAMPLES`.
    - Summaries of posteriors (mean, median, quantiles as `.npz` files) to `OUTPUT_DIR_POSTERIOR_SUMMARIES`.
    - Calculated evaluation metrics (`.json` files) to `OUTPUT_DIR_RUN_METRICS_JSON`.
- Aggregated metrics per scenario are also saved as CSV files to `OUTPUT_DIR_RESULTS_CSV`.
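
The checkpointing rule can be sketched in a few lines of standard-library Python. This is an illustration only, not the notebook's implementation: `run_with_checkpoint`, `compute_fn`, and `fake_compute` are hypothetical names, and the file name simply mirrors the `metrics_scen_{id}_run_{n}.json` pattern used by the worker in Section 2.

```python
import json
import os
import tempfile


def run_with_checkpoint(out_dir, scenario_id, mc_run_idx, compute_fn):
    """Return cached metrics if a successful result exists, else compute and save.

    Hypothetical helper, for illustration only.
    """
    path = os.path.join(out_dir, f"metrics_scen_{scenario_id}_run_{mc_run_idx + 1}.json")

    # Reuse a previous result only if it was recorded without an error.
    if os.path.exists(path):
        with open(path) as fh:
            cached = json.load(fh)
        if cached.get("error") in (None, "None"):
            cached["status"] = "loaded_metrics_from_file"
            return cached

    # Otherwise do the expensive work and persist the result.
    metrics = {"scenario_id": scenario_id, "mc_run": mc_run_idx + 1, "error": "None"}
    metrics.update(compute_fn())
    metrics["status"] = "computed"
    with open(path, "w") as fh:
        json.dump(metrics, fh)
    return metrics


calls = []


def fake_compute():
    # Stand-in for data generation + MCMC + metric calculation.
    calls.append(1)
    return {"rmse": 0.5}


with tempfile.TemporaryDirectory() as tmp:
    first = run_with_checkpoint(tmp, "S01", 0, fake_compute)
    second = run_with_checkpoint(tmp, "S01", 0, fake_compute)

print(first["status"], second["status"], len(calls))
```

The second call returns the cached record without invoking `fake_compute` again, which is the behavior that makes an interrupted study resumable.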

**Prerequisites:**
1. Ensure `sampler.py` is modified to accept `rng_key` in its `sample` function (as detailed above).
2. All helper Python modules (`config.py`, `data_generation.py`, `benchmarks.py`, `model_fitting.py`, `evaluation.py`, `results_io.py`) must be in the same directory or accessible in `PYTHONPATH`.

## 1. Imports and Configuration

In [1]:
import pandas as pd
import numpy as np
import os
import time
import jax 
import json
from joblib import Parallel, delayed # For parallel processing
from tqdm import tqdm # For progress bars

# Custom modules
import config
import data_generation
import benchmarks
import model_fitting
import evaluation
import results_io 

print(f"Using {config.NUM_CORES_TO_USE} cores for parallel processing.")
print(f"Global base seed for the study: {config.GLOBAL_BASE_SEED}")

Using 10 cores for parallel processing.
Global base seed for the study: 2025



## 2. Worker Function for Single Monte Carlo Replication

This function encapsulates all operations for one Monte Carlo run. It includes checkpointing. Plot data is not returned by this worker in this version; `analyze_results_notebook.ipynb` will handle plot data generation from saved summaries.

In [2]:
def run_single_mc_replication(args_tuple):
    """
    Executes one full Monte Carlo replication for a given scenario.

    This function is designed to be called by joblib.Parallel. It handles data generation, 
    model fitting, metric calculation, and saving results for a single run, including 
    checkpointing to skip re-computation if results already exist.

    Args:
        args_tuple (tuple): A tuple containing the arguments for the worker.
            - scenario_config_dict (dict): Configuration for the specific scenario.
            - mc_run_idx (int): The index of the current Monte Carlo run (e.g., 0 to N-1).
            - scenario_base_seed (int): The base seed for this scenario's set of runs.

    Returns:
        tuple: (metrics_dictionary, None). plot_data is None as plotting is handled later.
    """
    scenario_config_dict, mc_run_idx, scenario_base_seed = args_tuple
    scenario_id = scenario_config_dict["id"]
    
    run_specific_seed_dgp = scenario_base_seed + mc_run_idx 
    jax_prng_key_mcmc = jax.random.PRNGKey(scenario_base_seed + mc_run_idx + 100000) 

    metrics_filepath = os.path.join(
        config.OUTPUT_DIR_RUN_METRICS_JSON,
        f"metrics_scen_{scenario_id}_run_{mc_run_idx+1}.json"
    )
    raw_posterior_samples_filepath = os.path.join(
        config.OUTPUT_DIR_POSTERIOR_SAMPLES,
        f"scen_{scenario_id}_run_{mc_run_idx+1}_posterior_samples.npz"
    )

    # --- Checkpoint: If final metrics file exists for a successful run, skip ---
    if os.path.exists(metrics_filepath):
        loaded_metrics = results_io.load_run_metrics(scenario_id, mc_run_idx, config.OUTPUT_DIR_RUN_METRICS_JSON)
        if loaded_metrics and loaded_metrics.get("error") in [None, "None"]:
            loaded_metrics["status"] = "loaded_metrics_from_file"
            return loaded_metrics, None 
    
    current_run_metrics = {"scenario_id": scenario_id, "mc_run": mc_run_idx + 1, "error": "None", "status": "init"}
    
    try:
        # Checkpoint: Load existing raw samples if MCMC was already run
        posterior_samples = None
        if os.path.exists(raw_posterior_samples_filepath):
            posterior_samples = results_io.load_raw_posterior_samples(scenario_id, mc_run_idx, config.OUTPUT_DIR_POSTERIOR_SAMPLES)
            current_run_metrics["status"] = "loaded_raw_samples"
        
        # Regenerate data deterministically for this run to get true values for metrics
        sim_data = data_generation.simulate_scenario_data(scenario_config_dict, run_seed=run_specific_seed_dgp)

        # If samples were not loaded, run the MCMC model fitting
        if posterior_samples is None: 
            posterior_samples, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key_mcmc)
            current_run_metrics["status"] = "computed_mcmc"
            results_io.save_raw_posterior_samples(scenario_id, mc_run_idx, posterior_samples, config.OUTPUT_DIR_POSTERIOR_SAMPLES)
        
        # Save posterior summary (mean, quantiles) from raw samples
        results_io.save_posterior_summary_for_run(
            scenario_id, mc_run_idx, posterior_samples, config.OUTPUT_DIR_POSTERIOR_SUMMARIES
        )
            
        # Calculate benchmarks and metrics
        benchmark_r_t_estimates = {
            "cCFR_cumulative": benchmarks.calculate_crude_cfr(sim_data["d_t"], sim_data["c_t"], cumulative=True),
            "aCFR_cumulative": benchmarks.calculate_nishiura_cfr_cumulative(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
        }
        
        # Bayesian credible intervals for the benchmarks (fast, no loop)
        benchmark_cis = benchmarks.calculate_benchmark_cis_with_bayesian(
            sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"]
        )

        # Calculate ITS benchmark results
        its_results = benchmarks.calculate_its_with_poisson_mle(
            d_t=sim_data["d_t"], c_t=sim_data["c_t"], f_s=sim_data["f_s_true"],
            Bm=sim_data["Bm_true"],
            intervention_times_abs=sim_data["true_intervention_times_0_abs"],
            intervention_signs=sim_data["beta_signs_true"]
        )

        # Calculate evaluation metrics from the posterior samples and all benchmark results
        calculated_metrics = evaluation.collect_all_metrics(
            sim_data, posterior_samples, benchmark_r_t_estimates, benchmark_cis, its_results
        )
        current_run_metrics.update(calculated_metrics)
        
        # Save the calculated metrics for this specific run
        results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)

    except Exception as e:
        import traceback
        error_message = f"ERROR in WORKER {os.getpid()} for Scen {scenario_id}, Run {mc_run_idx + 1}: {type(e).__name__} - {e}\n{traceback.format_exc()}"
        print(f"\n---!!! {error_message} !!!---\n") # Make error more visible
        current_run_metrics["error"] = error_message
        current_run_metrics["status"] = "error_in_run"
        try:
            results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
        except Exception as save_e:
            print(f"    WORKER {os.getpid()}: CRITICAL ERROR - Could not save error metrics: {save_e}")
            
    return current_run_metrics, None
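
The worker's error handling follows a common pattern for batch jobs: catch the exception inside the worker, store it in the returned record, and let the remaining replications continue. A minimal sketch of that pattern, with hypothetical names (`safe_replication`, `flaky_task`) and no dependency on the project's modules:

```python
import traceback


def safe_replication(task, run_idx):
    """Run task(run_idx); return a record with an error message instead of raising."""
    record = {"mc_run": run_idx + 1, "error": "None", "status": "init"}
    try:
        record.update(task(run_idx))
        record["status"] = "computed"
    except Exception as exc:
        # Keep the exception type, message, and traceback for later inspection.
        record["error"] = f"{type(exc).__name__} - {exc}\n{traceback.format_exc()}"
        record["status"] = "error_in_run"
    return record


def flaky_task(run_idx):
    # Hypothetical task that fails on one replication.
    if run_idx == 2:
        raise ValueError("fit did not converge")
    return {"estimate": run_idx * 0.1}


records = [safe_replication(flaky_task, i) for i in range(4)]
failed = [r["mc_run"] for r in records if r["status"] == "error_in_run"]
print(failed)
```

Because failed runs still produce a record, the aggregated CSV files can contain rows whose `status` is `error_in_run`; filtering on that column before computing summary statistics is advisable.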

## 3. Main Simulation Execution Function

In [5]:
def main_simulation_runner():
    """
    Main function to orchestrate the entire simulation study.
    It iterates through all defined scenarios and uses a joblib Parallel backend 
    to execute all Monte Carlo runs with real-time progress bars.
    """
    start_time_total = time.time()

    # Create all output directories
    for dir_path in [config.OUTPUT_DIR_PLOTS, config.OUTPUT_DIR_TABLES, 
                     config.OUTPUT_DIR_RESULTS_CSV, config.OUTPUT_DIR_POSTERIOR_SUMMARIES,
                     config.OUTPUT_DIR_POSTERIOR_SAMPLES, config.OUTPUT_DIR_RUN_METRICS_JSON]:
        os.makedirs(dir_path, exist_ok=True)

    all_results_metrics_list = []
    with tqdm(total=len(config.SCENARIOS), desc="Overall Scenario Progress") as outer_pbar:
        for scenario_idx, scenario_config_dict in enumerate(config.SCENARIOS):
            scenario_id = scenario_config_dict["id"]
            outer_pbar.set_description(f"Processing Scen: {scenario_id}")
            
            scenario_base_seed = config.GLOBAL_BASE_SEED + (scenario_idx * config.NUM_MONTE_CARLO_RUNS * 1000)
            mc_args_list = [(scenario_config_dict, i, scenario_base_seed) for i in range(config.NUM_MONTE_CARLO_RUNS)]

            try:
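                # Use joblib Parallel for execution with a nested tqdm progress bar.
                # tqdm wraps the task generator, so the inner bar advances as tasks are
                # dispatched to workers, not as they finish.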
                scenario_run_outputs = Parallel(n_jobs=config.NUM_CORES_TO_USE)(
                    delayed(run_single_mc_replication)(args) for args in tqdm(
                        mc_args_list, desc=f"  MC Runs for {scenario_id}", leave=False, position=1
                    )
                )
                
                current_scenario_metrics_list = [result[0] for result in scenario_run_outputs]
                
                if current_scenario_metrics_list:
                    df_scenario_metrics = pd.DataFrame(current_scenario_metrics_list)
                    csv_path = os.path.join(config.OUTPUT_DIR_RESULTS_CSV, f"metrics_scen_{scenario_id}.csv")
                    df_scenario_metrics.to_csv(csv_path, index=False)
                    all_results_metrics_list.extend(current_scenario_metrics_list)
            
            except Exception as e_pool:
                print(f"\nCritical error during parallel processing for scenario {scenario_id}: {e_pool}")
            
            outer_pbar.update(1)

    end_time_total = time.time()
    print(f"\nAll Monte Carlo simulations completed in {end_time_total - start_time_total:.2f} seconds.")

    # Save combined results
    if all_results_metrics_list:
        results_df_all = pd.DataFrame(all_results_metrics_list)
        results_df_all.to_csv(os.path.join(config.OUTPUT_DIR_RESULTS_CSV, "all_scenarios_metrics_combined.csv"), index=False)
        print("Combined metrics saved.")
    
    print("\nSimulation runs complete. Use analyze_results.py to generate plots and final tables.")
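
The seeding scheme can be checked with plain integer arithmetic. The sketch below reproduces the two formulas from the code (`scenario_base_seed + mc_run_idx` for data generation, and the same value plus 100000 for the MCMC key) using values visible in this notebook: base seed 2025 and 12 scenarios. The run count of 10 is an assumption read off the `0/10` progress bars in the output of Section 4.

```python
GLOBAL_BASE_SEED = 2025      # printed in Section 1
NUM_SCENARIOS = 12           # stated in the introduction
NUM_MONTE_CARLO_RUNS = 10    # assumption: inferred from the "0/10" progress bars
MCMC_OFFSET = 100000         # constant added in run_single_mc_replication

dgp_seeds = {}
mcmc_seeds = {}
for s in range(NUM_SCENARIOS):
    base = GLOBAL_BASE_SEED + s * NUM_MONTE_CARLO_RUNS * 1000
    for i in range(NUM_MONTE_CARLO_RUNS):
        dgp_seeds[(s, i)] = base + i
        mcmc_seeds[(s, i)] = base + i + MCMC_OFFSET

# Seeds are unique within each stream ...
assert len(set(dgp_seeds.values())) == NUM_SCENARIOS * NUM_MONTE_CARLO_RUNS
assert len(set(mcmc_seeds.values())) == NUM_SCENARIOS * NUM_MONTE_CARLO_RUNS

# ... but the two streams share some integer values.
shared = set(dgp_seeds.values()) & set(mcmc_seeds.values())
print(len(shared), min(shared), max(shared))  # 20 102025 112034
```

With these values the offset of 100000 equals ten scenario strides (10 × 1000 × 10), so the data-generation seeds of the eleventh and twelfth scenarios coincide numerically with the MCMC seeds of the first two. Whether this matters depends on how `data_generation` consumes `run_seed`, which is not shown here; if it seeds a NumPy generator while MCMC uses a JAX key, the two streams come from different generators and the overlap is likely harmless. An offset that scales with the number of scenarios and runs would avoid the coincidence altogether.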



## 4. Execute Simulation

The following cell will start the simulation process. Ensure all configurations in `config.py` are set as desired and `sampler.py` has been modified for dynamic `rng_key`.

In [6]:
# If this notebook is converted to a .py script, wrap the call below in
# `if __name__ == "__main__":` so that worker processes can import the module
# without re-running the study (required on Windows, good practice elsewhere).
main_simulation_runner()

Processing Scen: S01:   0%|                                           | 0/12 [00:00<?, ?it/s]
  MC Runs for S01:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S02:   8%|██▉                                | 1/12 [00:16<03:03, 16.73s/it]
  MC Runs for S02:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S03:  17%|█████▊                             | 2/12 [00:29<02:21, 14.19s/it]
  MC Runs for S03:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S04:  25%|████████▊                          | 3/12 [00:43<02:10, 14.48s/it]
  MC Runs for S04:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S05:  33%|███████████▋                       | 4/12 [00:53<01:41, 12.63s/it]
  MC Runs for S05:   0%|                                              | 0/10 [00:00<?, ?it/s]
                                 


---!!! ERROR in WORKER 663415 for Scen S01, Run 5: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multivariate_normal(popt, pcov, size=n_mc_ci)
                                                        ^^^^
UnboundLocalError: cannot access local variable 'pcov' where it is not associated with a value
 !!!---


---!!! ERROR in WORKER 663415 for Scen S02, Run 5: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multiva

Processing Scen: S06:  42%|██████████████▌                    | 5/12 [01:10<01:39, 14.17s/it]
  MC Runs for S06:   0%|                                              | 0/10 [00:00<?, ?it/s]


---!!! ERROR in WORKER 663412 for Scen S01, Run 1: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multivariate_normal(popt, pcov, size=n_mc_ci)
                                                        ^^^^
UnboundLocalError: cannot access local variable 'pcov' where it is not associated with a value
 !!!---


---!!! ERROR in WORKER 663412 for Scen S02, Run 8: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multiva

Processing Scen: S07:  50%|█████████████████▌                 | 6/12 [01:30<01:36, 16.08s/it]
  MC Runs for S07:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S08:  58%|████████████████████▍              | 7/12 [01:45<01:19, 15.81s/it]
  MC Runs for S08:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S09:  67%|███████████████████████▎           | 8/12 [01:56<00:56, 14.22s/it]
  MC Runs for S09:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S10:  75%|██████████████████████████▎        | 9/12 [02:12<00:43, 14.66s/it]
  MC Runs for S10:   0%|                                              | 0/10 [00:00<?, ?it/s]
Processing Scen: S11:  83%|████████████████████████████▎     | 10/12 [02:19<00:25, 12.52s/it]
  MC Runs for S11:   0%|                                              | 0/10 [00:00<?, ?it/s]
                                 


---!!! ERROR in WORKER 668060 for Scen S07, Run 10: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multivariate_normal(popt, pcov, size=n_mc_ci)
                                                        ^^^^
UnboundLocalError: cannot access local variable 'pcov' where it is not associated with a value
 !!!---


---!!! ERROR in WORKER 668060 for Scen S08, Run 9: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multiv

Processing Scen: S12:  92%|███████████████████████████████▏  | 11/12 [02:28<00:11, 11.30s/it]
  MC Runs for S12:   0%|                                              | 0/10 [00:00<?, ?it/s]


---!!! ERROR in WORKER 668061 for Scen S07, Run 9: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multivariate_normal(popt, pcov, size=n_mc_ci)
                                                        ^^^^
UnboundLocalError: cannot access local variable 'pcov' where it is not associated with a value
 !!!---


---!!! ERROR in WORKER 668061 for Scen S08, Run 10: UnboundLocalError - cannot access local variable 'pcov' where it is not associated with a value
Traceback (most recent call last):
  File "/tmp/ipykernel_663324/330236072.py", line 75, in run_single_mc_replication
  File "/mnt/data/users/hengtao/simulation/benchmarks.py", line 187, in calculate_its_with_nls
    param_samples = np.random.multiv

Processing Scen: S12: 100%|██████████████████████████████████| 12/12 [02:46<00:00, 13.91s/it]


All Monte Carlo simulations completed in 166.98 seconds.
Combined metrics saved.

Simulation runs complete. Use analyze_results.py to generate plots and final tables.
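
The `UnboundLocalError` in the log above is the usual symptom of a variable that is assigned only inside a `try` block and then used after an `except` branch has swallowed a failure. The contents of `benchmarks.py` are not shown in this notebook, so the following is a generic sketch of that failure mode and one way to guard against it; `fit`, `fragile_summary`, and `guarded_summary` are hypothetical names.

```python
def fit(data):
    """Hypothetical stand-in for an optimizer that can fail."""
    if not data:
        raise RuntimeError("optimizer failed")
    mean = sum(data) / len(data)
    return mean, 1.0  # (point estimate, variance placeholder)


def fragile_summary(data):
    try:
        popt, pcov = fit(data)
    except RuntimeError:
        pass  # failure swallowed: popt and pcov were never assigned
    return popt, pcov  # raises UnboundLocalError after a failed fit


def guarded_summary(data):
    try:
        popt, pcov = fit(data)
    except RuntimeError:
        return None  # signal "no estimate" explicitly to the caller
    return popt, pcov


try:
    fragile_summary([])
    outcome = "no error"
except UnboundLocalError:
    outcome = "UnboundLocalError"

print(outcome, guarded_summary([]), guarded_summary([1.0, 3.0]))
```

Returning an explicit sentinel (or NaN-filled estimates) on a failed fit lets the caller record the run as a non-converged benchmark instead of losing the whole replication to an unrelated exception.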



