# Simulation Study Runner

This notebook executes the main simulation study. It performs Monte Carlo simulations for all defined scenarios in parallel, leveraging multiple CPU cores and displaying progress with `tqdm` progress bars.

**Key Features:**
- **Parallel Processing:** Uses `joblib` to speed up Monte Carlo runs.
- **Progress Bars:** Provides real-time feedback on the simulation progress.
- **Checkpointing:** Skips computation for runs where results are already found, allowing simulations to be resumed.
- **Comprehensive Output:** Saves raw MCMC samples, posterior summaries, benchmark results, and evaluation metrics for each run.

## 1. Imports and Configuration

In [1]:
import pandas as pd
import numpy as np
import os
import time
import jax 
import json
from joblib import Parallel, delayed
from tqdm import tqdm

# Custom modules
import config
import data_generation
import benchmarks
import model_fitting
import evaluation
import results_io 

print(f"Using {config.NUM_CORES_TO_USE} cores for parallel processing.")
print(f"Global base seed for the study: {config.GLOBAL_BASE_SEED}")

Using 5 cores for parallel processing.
Global base seed for the study: 2025


## 2. Worker Function for a Single Monte Carlo Run

This function encapsulates all operations for one simulation run. It is designed to be called in parallel by `joblib`.

In [2]:
def run_single_mc_replication(args_tuple):
    """Executes one full Monte Carlo replication for a given scenario."""
    scenario_config_dict, mc_run_idx, scenario_base_seed = args_tuple
    scenario_id = scenario_config_dict["id"]
    
    # --- Checkpoint: Skip if results already exist and overwriting is disabled ---
    if not config.OVERWRITE_EXISTING_RESULTS:
        if results_io.load_run_metrics(scenario_id, mc_run_idx, config.OUTPUT_DIR_RUN_METRICS_JSON):
            return results_io.load_run_metrics(scenario_id, mc_run_idx, config.OUTPUT_DIR_RUN_METRICS_JSON)
    
    run_specific_seed = scenario_base_seed + mc_run_idx 
    jax_prng_key = jax.random.PRNGKey(run_specific_seed + 100000)
    current_run_metrics = {"scenario_id": scenario_id, "mc_run": mc_run_idx + 1, "error": "None"}
    
    try:
        # 1. Generate Data
        sim_data = data_generation.simulate_scenario_data(scenario_config_dict, run_seed=run_specific_seed)
        
        # 2. Fit the proposed sCFR model
        posterior_samples_scfr, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key)
        
        # 3. Save raw samples and summaries for the sCFR model
        results_io.save_raw_posterior_samples(scenario_id, mc_run_idx, posterior_samples_scfr, config.OUTPUT_DIR_POSTERIOR_SAMPLES)
        results_io.save_posterior_summary_for_run(scenario_id, mc_run_idx, posterior_samples_scfr, config.OUTPUT_DIR_POSTERIOR_SUMMARIES)

        # 4. Run all benchmark models
        benchmark_r_t_estimates = {
            "cCFR_cumulative": benchmarks.calculate_crude_cfr(sim_data["d_t"], sim_data["c_t"], cumulative=True),
            "aCFR_cumulative": benchmarks.calculate_nishiura_cfr_cumulative(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
        }
        benchmark_cis = benchmarks.calculate_benchmark_cis_with_bayesian(sim_data["d_t"], sim_data["c_t"], sim_data["f_s_true"])
        its_results = benchmarks.calculate_its_with_penalized_mle(
            d_t=sim_data["d_t"], c_t=sim_data["c_t"], f_s=sim_data["f_s_true"],
            Bm=sim_data["Bm_true"],
            intervention_times_abs=sim_data["true_intervention_times_0_abs"],
            intervention_signs=sim_data["beta_signs_true"]
        )

        all_benchmark_results = {**benchmark_r_t_estimates, **benchmark_cis, **its_results}
        results_io.save_benchmark_results(scenario_id, mc_run_idx, all_benchmark_results, config.OUTPUT_DIR_BENCHMARK_RESULTS)
        
        # 5. Collect all scalar evaluation metrics
        calculated_metrics = evaluation.collect_all_metrics(
            sim_data, 
            posterior_samples_scfr,
            benchmark_r_t_estimates, 
            benchmark_cis, 
            its_results
        )
        current_run_metrics.update(calculated_metrics)
        
    except Exception as e:
        import traceback
        error_message = f"ERROR in WORKER: {e}\n{traceback.format_exc()}"
        current_run_metrics["error"] = error_message
    
    # 6. Save the final metrics for this run
    results_io.save_run_metrics(scenario_id, mc_run_idx, current_run_metrics, config.OUTPUT_DIR_RUN_METRICS_JSON)
    return current_run_metrics

## 3. Main Simulation Execution Loop

This function orchestrates the study, iterating through all scenarios and executing the Monte Carlo runs in parallel.

In [3]:
def main_simulation_runner():
    """Main function to run the entire simulation study with progress bars."""
    start_time_total = time.time()

    # Create all necessary output directories
    for dir_path in [config.OUTPUT_DIR_TABLES, config.OUTPUT_DIR_RESULTS_CSV, 
                     config.OUTPUT_DIR_POSTERIOR_SUMMARIES, config.OUTPUT_DIR_BENCHMARK_RESULTS,
                     config.OUTPUT_DIR_POSTERIOR_SAMPLES, config.OUTPUT_DIR_RUN_METRICS_JSON]:
        os.makedirs(dir_path, exist_ok=True)

    all_results_metrics_list = []
    
    # Outer loop for scenarios with a progress bar
    for scenario_idx, scenario_config in enumerate(tqdm(config.SCENARIOS, desc="Overall Scenario Progress")):
        scenario_id = scenario_config["id"]
        scenario_base_seed = config.GLOBAL_BASE_SEED + (scenario_idx * config.NUM_MONTE_CARLO_RUNS * 100)
        
        # Prepare arguments for each Monte Carlo run
        mc_args_list = [(scenario_config, i, scenario_base_seed) for i in range(config.NUM_MONTE_CARLO_RUNS)]

        try:
            # Use joblib for parallel execution with a nested tqdm progress bar
            scenario_metrics = Parallel(n_jobs=config.NUM_CORES_TO_USE)(
                delayed(run_single_mc_replication)(args) for args in tqdm(
                    mc_args_list, desc=f"MC Runs for {scenario_id}", leave=False
                )
            )
            all_results_metrics_list.extend(scenario_metrics)
        
        except Exception as e_pool:
            print(f"\nCRITICAL ERROR during parallel processing for scenario {scenario_id}: {e_pool}")

    end_time_total = time.time()
    print(f"\nAll simulations completed in {end_time_total - start_time_total:.2f} seconds.")

    # Save a final combined CSV of all metrics
    if all_results_metrics_list:
        results_df_all = pd.DataFrame([m for m in all_results_metrics_list if m is not None])
        results_df_all.to_csv(os.path.join(config.OUTPUT_DIR_RESULTS_CSV, "all_scenarios_metrics_combined.csv"), index=False)
        print("Combined metrics for all runs saved.")
    
    print("\nSimulation runs complete. You can now use the analysis notebook to generate plots and tables.")

## 4. Execute Simulation

The following cell will start the simulation process. Ensure all configurations in `config.py` are set as desired.

In [4]:
if __name__ == '__main__':
    main_simulation_runner()

Overall Scenario Progress:   0%|                                                                | 0/12 [00:00<?, ?it/s]
[ARuns for S01:   0%|                                                                           | 0/5 [00:00<?, ?it/s]
[ARuns for S01: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 44.23it/s]
Overall Scenario Progress:   8%|████▋                                                   | 1/12 [00:40<07:30, 40.91s/it]
[ARuns for S02:   0%|                                                                           | 0/5 [00:00<?, ?it/s]
Overall Scenario Progress:  17%|█████████▎                                              | 2/12 [02:35<14:02, 84.21s/it]
[ARuns for S03:   0%|                                                                           | 0/5 [00:00<?, ?it/s]
Overall Scenario Progress:  25%|█████████████▊                                         | 3/12 [05:39<19:28, 129.89s/it]
[ARuns for S04:   0%|                  


All simulations completed in 1437.44 seconds.
Combined metrics for all runs saved.

Simulation runs complete. You can now use the analysis notebook to generate plots and tables.



