# In [None]:
"""
Example of how to use the BEACON framework to run parameter sweeps.

This example demonstrates how to:
1. Import a model (SIRM, SIRT, or SIRV)
2. Define parameter ranges for sweeping
3. Run the sweep and analyze results
"""

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Import the model and batch sweep functionality
from src.models import SIRM, SIRT, SIRV
from src.utils.batch_sweep import sweep_two_parameters
from src.utils.distributions import pol_to_alpha

# Example 1: Sweep over polarization and mask effectiveness for the SIRM model
def run_mask_sweep():
    """Sweep the SIRM model over mask effectiveness and polarization, then plot.

    Calls ``sweep_two_parameters`` with ``mu_max`` as the first swept
    parameter and ``beta_params`` (polarization) as the second. It then draws
    a heatmap of the infected fraction, saves it to
    ``sirm_mask_polarization_sweep.png`` and shows it.

    Returns:
        The ``(states, r0s, hs, param_grid)`` tuple from ``sweep_two_parameters``.
    """
    print("Running SIRM model sweep (mask-wearing intervention)...")

    # Sweep ranges: "m" is the lower bound, "M" the upper bound, "n" the
    # number of grid points along that axis.
    # Polarization (pol) range from 0 to 1.
    pol_range = {"m": 0, "M": 1, "n": 10}
    # Maximum mask effectiveness (mu_max) range from 0 to 1.
    mask_range = {"m": 0, "M": 1, "n": 10}

    # Customize model parameters (optional).
    custom_params = {
        'recovery_rate': 0.1,      # Recovery rate
        'beta_M': 0.6,             # Maximum susceptibility
        'dT': 0.25,                # Time step
        'homophilic_tendency': 0,  # No homophily in this example
    }

    # Run the sweep over mask effectiveness (param1) and polarization (param2).
    states, r0s, hs, param_grid = sweep_two_parameters(
        model_module=SIRM,
        param1_name="mu_max",       # First parameter to sweep: maximum mask-wearing
        param1_range=mask_range,
        param2_name="beta_params",  # Second parameter to sweep: polarization (through beta parameters)
        param2_range=pol_range,
        custom_base_params=custom_params,
        n_steps=1000,               # Number of simulation steps
        initial_infected_prop=1e-4, # Initial proportion of infected individuals
        population_size=100,        # Number of population compartments
        batch_size=100              # Process 100 parameter combinations at a time
    )

    # Take the grid shape from the sweep ranges instead of a hard-coded 10, so
    # that editing an "n" above does not break the reshapes below.
    # NOTE(review): assumes param_grid rows are ordered with param1 (mu_max)
    # varying slowest, i.e. a row-major product param1 x param2. Confirm this
    # against sweep_two_parameters.
    grid_shape = (mask_range["n"], pol_range["n"])

    # Column 0 holds the mu_max values and column 1 the second swept parameter.
    mu_vals = param_grid[:, 0].reshape(grid_shape)
    alpha_vals = param_grid[:, 1].reshape(grid_shape)

    # SIRM states unpack into three compartments. Only S is needed here.
    # Infected fraction = 1 - remaining susceptible. There is no vaccinated
    # compartment in SIRM, so nothing else is subtracted.
    # NOTE(review): assumes S is (n_combinations, population_size) holding
    # final susceptible fractions that sum to 1 before infection. Confirm.
    S, _, _ = states
    total_infections = np.array(1 - np.sum(S, axis=1)).reshape(grid_shape)

    # Plot the results.
    plt.figure(figsize=(10, 8))

    # C has the same shape as X/Y (one value per grid point), so request
    # 'auto' shading explicitly. Under the legacy 'flat' default
    # (Matplotlib < 3.5), same-shape inputs drop the last row and column.
    # Newer versions raise on them under 'flat'.
    plt.pcolormesh(mu_vals, alpha_vals, total_infections, cmap='viridis',
                   shading='auto')
    plt.colorbar(label='Fraction of Population Infected')

    plt.xlabel('Maximum Mask Effectiveness (μ_max)')
    plt.ylabel('Alpha Parameter (related to polarization)')
    plt.title('SIRM Model: Impact of Mask Effectiveness and Population Polarization')

    plt.tight_layout()
    plt.savefig('sirm_mask_polarization_sweep.png')
    plt.show()

    return states, r0s, hs, param_grid

# Example 2: Sweep over polarization and testing rate for the SIRT model
def run_test_sweep():
    """Sweep the SIRT model over maximum testing rate and polarization.

    No ``custom_base_params`` are passed, so everything except the two swept
    parameters is left at the sweep's defaults. The raw sweep output is
    returned without post-processing.

    Returns:
        ``(states, r0s, hs, param_grid)`` as produced by ``sweep_two_parameters``.
    """
    print("Running SIRT model sweep (testing intervention)...")

    sweep_config = dict(
        model_module=SIRT,
        # Maximum testing rate, swept from 0 to 0.1 (n=10).
        # NOTE(review): presumably applied as a (0, value) pair; confirm in batch_sweep.
        param1_name="testing_rates",
        param1_range=dict(m=0, M=0.1, n=10),
        # Polarization, swept from 0 to 1 (n=10) via the beta parameters.
        param2_name="beta_params",
        param2_range=dict(m=0, M=1, n=10),
        n_steps=1000,
        population_size=100,
    )
    states, r0s, hs, param_grid = sweep_two_parameters(**sweep_config)

    # TODO: post-process the results, for example as run_mask_sweep does.
    return states, r0s, hs, param_grid

# Example 3: Sweep over homophily and polarization for the SIRV model
def run_vaccination_sweep():
    """Sweep the SIRV model over homophilic tendency and polarization.

    The vaccination rates are held fixed through ``custom_base_params``. The
    contact-matrix formulation is switched on because homophily is swept away
    from zero. The raw sweep output is returned without post-processing.

    Returns:
        ``(states, r0s, hs, param_grid)`` as produced by ``sweep_two_parameters``.
    """
    print("Running SIRV model sweep (vaccination intervention)...")

    homophily_span = {"m": -5, "M": 5, "n": 10}      # homophilic tendency, -5 to 5
    polarization_span = {"m": 0, "M": 1, "n": 10}    # polarization, 0 to 1
    fixed_overrides = {
        'vaccination_rates': (0, 0.05),  # fixed maximum vaccination rate
    }

    sweep_output = sweep_two_parameters(
        model_module=SIRV,
        param1_name="homophilic_tendency",
        param1_range=homophily_span,
        param2_name="beta_params",
        param2_range=polarization_span,
        custom_base_params=fixed_overrides,
        n_steps=1000,
        population_size=100,
        # The contact matrix is required whenever homophily is non-zero.
        use_contact_matrix=True,
    )
    states, r0s, hs, param_grid = sweep_output

    # TODO: post-process the results, for example as run_mask_sweep does.
    return states, r0s, hs, param_grid

if __name__ == "__main__":
    # Available examples, keyed by a short name. To run a different one,
    # change `chosen_example` (or call several runners in turn).
    example_runners = {
        "mask": run_mask_sweep,                # Example 1: SIRM (mask-wearing intervention)
        "testing": run_test_sweep,             # Example 2: SIRT (testing intervention)
        "vaccination": run_vaccination_sweep,  # Example 3: SIRV (vaccination intervention)
    }
    chosen_example = "mask"
    states, r0s, hs, param_grid = example_runners[chosen_example]()