In [1]:
import copy
import itertools
import os

import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm

# Import your custom modules
import config
import data_generation
import model_fitting
import evaluation

# --- Configuration for this Test ---
# NUM_REPLICATIONS bounds the inner `for i in range(NUM_REPLICATIONS)` loops of the
# K=1 and K=2 sweeps below; replication `i` is simulated and fitted with seed
# `base_seed + i`. Each replication runs a full model fit, so this value scales the
# total runtime of the notebook linearly.
NUM_REPLICATIONS = 10 # Number of MC runs for each experimental setting.

In [2]:
def plot_scenario_test_results(plot_data, output_dir):
    """
    Generates and saves a single plot for a specific scenario and parameter combination.

    The figure overlays the true factual and counterfactual curves with the estimated
    mean curves and their shaded interval bands, and is written to
    ``<output_dir>/visual_parameter_tests/visual_test_<scenario_id>_<param_tag>.pdf``.
    The subdirectory is created if it does not exist.

    Parameters
    ----------
    plot_data : dict
        Must provide the keys 'scenario_id' and 'param_tag' (used in the title and
        the file name) and the series 'true_r_t', 'true_rcf_t', 'est_r_mean',
        'est_r_lower', 'est_r_upper', 'est_rcf_mean', 'est_rcf_lower' and
        'est_rcf_upper'. The x-axis runs over ``len(plot_data['true_r_t'])`` points,
        so every series is expected to have that length.
    output_dir : str
        Parent directory under which the 'visual_parameter_tests' folder is created.

    Returns
    -------
    None
    """
    scenario_id = plot_data['scenario_id']
    param_tag = plot_data['param_tag']
    T_analyze = len(plot_data['true_r_t'])
    time_points = np.arange(T_analyze)

    fig, ax = plt.subplots(figsize=(10, 6))
    # The function is called once per parameter setting inside long loops, so the
    # figure must be released even when drawing or saving fails; otherwise open
    # figures accumulate for the lifetime of the kernel.
    try:
        # Plot true curves
        ax.plot(time_points, plot_data['true_r_t'], color='black', linestyle='--', label="True Factual $r_t$")
        ax.plot(time_points, plot_data['true_rcf_t'], color='gray', linestyle=':', label="True Counterfactual $r_t$")

        # Plot estimated curves
        ax.plot(time_points, plot_data['est_r_mean'], color='blue', label="sCFR Factual (Mean)")
        ax.fill_between(time_points, plot_data['est_r_lower'], plot_data['est_r_upper'], color='blue', alpha=0.2, label="sCFR Factual 95% CrI")
        ax.plot(time_points, plot_data['est_rcf_mean'], color='cyan', linestyle='-.', label="sCFR Counterfactual (Mean)")
        ax.fill_between(time_points, plot_data['est_rcf_lower'], plot_data['est_rcf_upper'], color='cyan', alpha=0.2, label="sCFR Counterfactual 95% CrI")

        ax.set_title(f"Visual Inspection for {scenario_id} with {param_tag}")
        ax.set_xlabel("Time (days)")
        ax.set_ylabel("Case Fatality Rate")
        ax.legend(loc='best')
        ax.grid(True, linestyle=':', alpha=0.6)

        # Create a dedicated subdirectory for these plots
        visual_test_dir = os.path.join(output_dir, "visual_parameter_tests")
        os.makedirs(visual_test_dir, exist_ok=True)
        plot_filename = os.path.join(visual_test_dir, f"visual_test_{scenario_id}_{param_tag}.pdf")
        # Save through the figure object rather than pyplot's implicit "current figure".
        fig.savefig(plot_filename)
    finally:
        plt.close(fig)

In [None]:
print("--- Starting Full Analysis of beta_abs and lambda Magnitude ---")
SCENARIOS_by_id = {s['id']: s for s in config.SCENARIOS}
results = []
base_seed = config.GLOBAL_BASE_SEED

# Define test grids (`itertools` is imported with the other modules at the top of the notebook)
beta_k1_tests = [0.5, 1, 2]
lambda_k1_tests = [0.5, 1, 2]
beta_k2_tests = [(1,1)]
lambda_k2_tests = list(itertools.product([0.5, 1, 2], [0.5, 1, 2]))

# --- Test K=1 Scenarios ---
# For every single-intervention scenario and every (beta, lambda) pair on the grid:
# simulate and fit NUM_REPLICATIONS datasets, average bias/coverage over the
# replications, and save one diagnostic plot from the pooled posterior curves.
k1_scenarios = {sid: conf for sid, conf in SCENARIOS_by_id.items() if conf["num_interventions_K_true"] == 1}
for scenario_id, base_config in tqdm(k1_scenarios.items(), desc="Testing K=1 Scenarios"):
    for beta_val in beta_k1_tests:
        for lambda_val in lambda_k1_tests:
            run_metrics = {"Bias_beta": [], "Coverage_beta": [], "Bias_lambda": [], "Coverage_lambda": []}
            p_samples_stack, p_cf_samples_stack = [], []

            for i in range(NUM_REPLICATIONS):
                # Work on a copy so the shared scenario definition is never mutated.
                current_config = copy.deepcopy(base_config)
                current_config["true_beta_abs_0"] = np.array([beta_val])
                current_config["true_lambda_0"] = np.array([lambda_val])

                # Replication i uses seed base_seed + i; the same seed sequence is
                # reused for every scenario and grid point.
                sim_data = data_generation.simulate_scenario_data(current_config, run_seed=(base_seed + i))
                jax_prng_key = jax.random.PRNGKey(base_seed + i)
                posterior_samples, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key)

                # A falsy result is skipped, so it contributes neither metrics nor curves.
                if posterior_samples:
                    est_beta_m, est_beta_l, est_beta_u = evaluation.get_posterior_estimates(posterior_samples, "beta_abs")
                    est_lambda_m, est_lambda_l, est_lambda_u = evaluation.get_posterior_estimates(posterior_samples, "lambda")

                    run_metrics["Bias_beta"].append(evaluation.calculate_param_bias(beta_val, est_beta_m[0]))
                    run_metrics["Coverage_beta"].append(evaluation.calculate_param_cri_coverage(beta_val, est_beta_l[0], est_beta_u[0]))
                    run_metrics["Bias_lambda"].append(evaluation.calculate_param_bias(lambda_val, est_lambda_m[0]))
                    run_metrics["Coverage_lambda"].append(evaluation.calculate_param_cri_coverage(lambda_val, est_lambda_l[0], est_lambda_u[0]))

                    p_samples_stack.append(posterior_samples.get("p"))
                    p_cf_samples_stack.append(posterior_samples.get("p_cf"))

            # Average each metric over the replications that produced samples. When
            # none did, record NaN explicitly instead of calling np.mean on an empty
            # list (same value, but without the "Mean of empty slice" RuntimeWarning).
            mean_metrics = {
                name: (np.mean(values) if values else np.nan)
                for name, values in run_metrics.items()
            }

            results.append({
                "Scenario": scenario_id, "Baseline Shape": base_config["cfr_type_name"], "Num Interventions": 1,
                "True β_abs_1": beta_val, "True λ_1": lambda_val,
                "Bias (β1)": mean_metrics["Bias_beta"], "Coverage (β1)": mean_metrics["Coverage_beta"],
                "Bias (λ1)": mean_metrics["Bias_lambda"], "Coverage (λ1)": mean_metrics["Coverage_lambda"],
            })

            if p_samples_stack:
                T_analyze = config.T_ANALYSIS_LENGTH
                # Pool the posterior draws of all successful replications once, then
                # derive the mean and both interval bounds from the same stacked array.
                p_stack = np.vstack(p_samples_stack)
                p_cf_stack = np.vstack(p_cf_samples_stack)
                est_r_lower, est_r_upper = np.percentile(p_stack, [2.5, 97.5], axis=0)
                est_rcf_lower, est_rcf_upper = np.percentile(p_cf_stack, [2.5, 97.5], axis=0)
                plot_data = {
                    'scenario_id': scenario_id,
                    'param_tag': f"b1_{beta_val}_l1_{lambda_val}",
                    # The true curves are taken from the last simulated replication.
                    'true_r_t': sim_data['true_r_0_t'][:T_analyze],
                    'true_rcf_t': sim_data['true_rcf_0_t'][:T_analyze],
                    'est_r_mean': np.mean(p_stack, axis=0)[:T_analyze],
                    'est_r_lower': est_r_lower[:T_analyze],
                    'est_r_upper': est_r_upper[:T_analyze],
                    'est_rcf_mean': np.mean(p_cf_stack, axis=0)[:T_analyze],
                    'est_rcf_lower': est_rcf_lower[:T_analyze],
                    'est_rcf_upper': est_rcf_upper[:T_analyze],
                }
                plot_scenario_test_results(plot_data, config.OUTPUT_DIR_PLOTS)

--- Starting Full Analysis of beta_abs and lambda Magnitude ---


Testing K=1 Scenarios:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# --- Test K=2 Scenarios ---
# Same procedure as the K=1 sweep, for scenarios with two interventions: each grid
# point sets a pair of beta values and a pair of lambda values, and bias/coverage
# are tracked separately for the first and second intervention.
k2_scenarios = {sid: conf for sid, conf in SCENARIOS_by_id.items() if conf["num_interventions_K_true"] == 2}
for scenario_id, base_config in tqdm(k2_scenarios.items(), desc="Testing K=2 Scenarios"):
    for beta_vals in beta_k2_tests:
        for lambda_vals in lambda_k2_tests:
            run_metrics = {"Bias_β1": [], "Cov_β1": [], "Bias_λ1": [], "Cov_λ1": [], 
                           "Bias_β2": [], "Cov_β2": [], "Bias_λ2": [], "Cov_λ2": []}
            p_samples_stack, p_cf_samples_stack = [], []

            for i in range(NUM_REPLICATIONS):
                # Work on a copy so the shared scenario definition is never mutated.
                current_config = copy.deepcopy(base_config)
                current_config["true_beta_abs_0"] = np.array(beta_vals)
                current_config["true_lambda_0"] = np.array(lambda_vals)

                # Replication i uses seed base_seed + i; the same seed sequence is
                # reused for every scenario and grid point.
                sim_data = data_generation.simulate_scenario_data(current_config, run_seed=(base_seed + i))
                jax_prng_key = jax.random.PRNGKey(base_seed + i)
                posterior_samples, _ = model_fitting.fit_proposed_model(sim_data, jax_prng_key)

                # Only fits that returned samples with exactly two beta_abs columns
                # are scored; anything else is skipped for this replication.
                if posterior_samples and posterior_samples["beta_abs"].shape[1] == 2:
                    est_beta_m, est_beta_l, est_beta_u = evaluation.get_posterior_estimates(posterior_samples, "beta_abs")
                    est_lambda_m, est_lambda_l, est_lambda_u = evaluation.get_posterior_estimates(posterior_samples, "lambda")

                    run_metrics["Bias_β1"].append(evaluation.calculate_param_bias(beta_vals[0], est_beta_m[0]))
                    run_metrics["Cov_β1"].append(evaluation.calculate_param_cri_coverage(beta_vals[0], est_beta_l[0], est_beta_u[0]))
                    run_metrics["Bias_λ1"].append(evaluation.calculate_param_bias(lambda_vals[0], est_lambda_m[0]))
                    run_metrics["Cov_λ1"].append(evaluation.calculate_param_cri_coverage(lambda_vals[0], est_lambda_l[0], est_lambda_u[0]))
                    run_metrics["Bias_β2"].append(evaluation.calculate_param_bias(beta_vals[1], est_beta_m[1]))
                    run_metrics["Cov_β2"].append(evaluation.calculate_param_cri_coverage(beta_vals[1], est_beta_l[1], est_beta_u[1]))
                    run_metrics["Bias_λ2"].append(evaluation.calculate_param_bias(lambda_vals[1], est_lambda_m[1]))
                    run_metrics["Cov_λ2"].append(evaluation.calculate_param_cri_coverage(lambda_vals[1], est_lambda_l[1], est_lambda_u[1]))

                    p_samples_stack.append(posterior_samples.get("p"))
                    p_cf_samples_stack.append(posterior_samples.get("p_cf"))

            # Average each metric over the replications that were scored. When none
            # were, record NaN explicitly instead of calling np.mean on an empty list
            # (same value, but without the "Mean of empty slice" RuntimeWarning).
            mean_metrics = {
                name: (np.mean(values) if values else np.nan)
                for name, values in run_metrics.items()
            }

            results.append({
                "Scenario": scenario_id, "Baseline Shape": base_config["cfr_type_name"], "Num Interventions": 2,
                "True β_abs_1": beta_vals[0], "True λ_1": lambda_vals[0],
                "Bias (β1)": mean_metrics["Bias_β1"], "Coverage (β1)": mean_metrics["Cov_β1"],
                "Bias (λ1)": mean_metrics["Bias_λ1"], "Coverage (λ1)": mean_metrics["Cov_λ1"],
                "True β_abs_2": beta_vals[1], "True λ_2": lambda_vals[1],
                "Bias (β2)": mean_metrics["Bias_β2"], "Coverage (β2)": mean_metrics["Cov_β2"],
                "Bias (λ2)": mean_metrics["Bias_λ2"], "Coverage (λ2)": mean_metrics["Cov_λ2"]
            })

            if p_samples_stack:
                T_analyze = config.T_ANALYSIS_LENGTH
                # Pool the posterior draws of all scored replications once, then
                # derive the mean and both interval bounds from the same stacked array.
                p_stack = np.vstack(p_samples_stack)
                p_cf_stack = np.vstack(p_cf_samples_stack)
                est_r_lower, est_r_upper = np.percentile(p_stack, [2.5, 97.5], axis=0)
                est_rcf_lower, est_rcf_upper = np.percentile(p_cf_stack, [2.5, 97.5], axis=0)
                plot_data = {
                    'scenario_id': scenario_id,
                    'param_tag': f"b_{beta_vals[0]}-{beta_vals[1]}_l_{lambda_vals[0]}-{lambda_vals[1]}",
                    # The true curves are taken from the last simulated replication.
                    'true_r_t': sim_data['true_r_0_t'][:T_analyze],
                    'true_rcf_t': sim_data['true_rcf_0_t'][:T_analyze],
                    'est_r_mean': np.mean(p_stack, axis=0)[:T_analyze],
                    'est_r_lower': est_r_lower[:T_analyze],
                    'est_r_upper': est_r_upper[:T_analyze],
                    'est_rcf_mean': np.mean(p_cf_stack, axis=0)[:T_analyze],
                    'est_rcf_lower': est_rcf_lower[:T_analyze],
                    'est_rcf_upper': est_rcf_upper[:T_analyze],
                }
                plot_scenario_test_results(plot_data, config.OUTPUT_DIR_PLOTS)

In [None]:
# --- Aggregate and Report Results ---
def _print_intervention_summary(results_df, num_interventions, value_cols, index_cols, column_renames=None):
    """
    Print the pivoted bias/coverage table for scenarios with `num_interventions` interventions.

    Rows of `results_df` whose 'Num Interventions' equals `num_interventions` are
    pivoted (pandas' default mean aggregation) with `index_cols` as the row index and
    `value_cols` as the values. Columns whose (possibly renamed) label starts with
    'Bias' are printed with three decimals; all other columns are printed as whole
    percentages.

    Parameters
    ----------
    results_df : pandas.DataFrame
        One row per (scenario, parameter setting) as collected in `results`.
    num_interventions : int
        Value of the 'Num Interventions' column to select.
    value_cols : list of str
        Metric columns to aggregate.
    index_cols : list of str
        Columns that form the row index of the summary.
    column_renames : dict, optional
        Mapping applied to the summary's column labels before printing.

    Returns
    -------
    pandas.DataFrame
        The summary table, or an empty DataFrame when there is nothing to summarise.
    """
    print(f"\n\n--- Numerical Summary for K={num_interventions} Scenarios ---")
    subset = results_df[results_df['Num Interventions'] == num_interventions]
    # Columns such as 'True β_abs_2' only exist if at least one K=2 row was recorded,
    # so check before pivoting instead of letting pivot_table raise a KeyError.
    missing_cols = [col for col in value_cols + index_cols if col not in subset.columns]
    if subset.empty or missing_cols:
        print(f"No results available for K={num_interventions} scenarios.")
        return pd.DataFrame()

    summary = subset.pivot_table(values=value_cols, index=index_cols)
    if column_renames:
        summary = summary.rename(columns=column_renames)
    formatters = {
        col: ('{:.3f}'.format if col.startswith('Bias') else '{:.0%}'.format)
        for col in summary.columns
    }
    print(summary.to_string(formatters=formatters))
    return summary


if not results:
    # Nothing was collected: report it and stop here. Building the summaries from an
    # empty frame would fail because the 'Num Interventions' column does not exist.
    print("\nNo results generated.")
    results_df = pd.DataFrame()
    summary_k1 = pd.DataFrame()
    summary_k2 = pd.DataFrame()
else:
    results_df = pd.DataFrame(results)

    # --- Numerical Summary for K=1 ---
    summary_k1 = _print_intervention_summary(
        results_df,
        1,
        value_cols=['Bias (β1)', 'Coverage (β1)', 'Bias (λ1)', 'Coverage (λ1)'],
        index_cols=['Baseline Shape', 'True β_abs_1', 'True λ_1'],
    )

    # --- Numerical Summary for K=2 ---
    # Column labels are shortened so the eight-metric table stays readable when printed.
    summary_k2 = _print_intervention_summary(
        results_df,
        2,
        value_cols=['Bias (β1)', 'Coverage (β1)', 'Bias (λ1)', 'Coverage (λ1)',
                    'Bias (β2)', 'Coverage (β2)', 'Bias (λ2)', 'Coverage (λ2)'],
        index_cols=['Baseline Shape', 'True β_abs_1', 'True β_abs_2', 'True λ_1', 'True λ_2'],
        column_renames={'Bias (β1)': 'Bias(β1)', 'Coverage (β1)': 'Cov(β1)',
                        'Bias (λ1)': 'Bias(λ1)', 'Coverage (λ1)': 'Cov(λ1)',
                        'Bias (β2)': 'Bias(β2)', 'Coverage (β2)': 'Cov(β2)',
                        'Bias (λ2)': 'Bias(λ2)', 'Coverage (λ2)': 'Cov(λ2)'},
    )