In [4]:
%load_ext autoreload
%autoreload 2
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np

from models.variable_susceptibility import simulate_variable_susceptibility_hom
from utils.parallel import (
    parameter_sweep,
    create_h_sweep,
    create_susceptibility_sweep,
    create_beta_sweep
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def test_h_sweep(h_min: float = -5.0, h_max: float = 5.0, n_h: int = 50):
    """Test parameter sweep varying only homophilic tendency"""
    
    # Setup parameters
    base_params = {
        'transmission_rates': (0.1, 0.5),
        'recovery_rate': 0.1,
        'dt': 0.25,
        'beta_params': (2.0, 2.0)
    }
    
    param_ranges = create_h_sweep(h_min, h_max, n_h)
    n_steps = 100
    
    # Run sweep
    S_t, I_t, R_t = parameter_sweep(
        simulation_fn=simulate_variable_susceptibility_hom,
        param_ranges=param_ranges,  # Changed from params to param_ranges
        base_params=base_params,
        n_steps=n_steps,
        sweep_type='h'
    )
    
    return S_t, I_t, R_t  # Return results for plotting/analysis

def test_susceptibility_symmetric_sweep(
    susc_min: float = 0.2,
    susc_max: float = 0.8,
    n_susc: int = 10,
    ab_min: float = 0.5,
    ab_max: float = 5.0,
    n_ab: int = 10
):
    """Test parameter sweep varying susceptibility range and symmetric beta params"""
    
    # Added beta_params to base_params
    base_params = {
        'recovery_rate': 0.1,
        'dt': 0.25,
        'homophilic_tendency': 0,
        'beta_params': (1.0, 1.0),  # Added this line
        'transmission_rates': (0.0, 0.6)  # Added this line
    }
    
    param_ranges = create_susceptibility_sweep(
        susc_min, susc_max, n_susc,
        ab_min, ab_max, n_ab
    )
    
    n_steps = 100
    
    # Run sweep
    S_t, I_t, R_t = parameter_sweep(
        simulation_fn=simulate_variable_susceptibility_hom,
        param_ranges=param_ranges,
        base_params=base_params,
        n_steps=n_steps,
        sweep_type='susc'
    )
    
    return S_t, I_t, R_t

def test_asymmetric_beta_sweep( 
        range_a: dict ={"m" : 0.5, "M" : 5.0, "n" : 10}, 
        range_b: dict ={"m" : 0.5, "M" : 5.0, "n" : 10}, 
        h: float = 0, dt: float = 0.25, recovery_rate: float = 0.1, transmission_rates: tuple = (0, 0.6), T = 1000):
    """Test parameter sweep varying asymmetric beta parameters"""
    
    base_params = {
        'transmission_rates': transmission_rates,
        'recovery_rate': recovery_rate,
        'dt': 0.25,
        'homophilic_tendency': h,
        'beta_params': (1.0, 1.0)  # Added default beta_params
    }
    
    param_ranges = create_beta_sweep(
        range_a["m"], range_a["M"], range_a["n"],
        range_b["m"], range_b["M"], range_b["n"]
    )
    
    n_steps = np.round(T / dt).astype(int)
    
    # Run sweep
    S_t, I_t, R_t = parameter_sweep(
        simulation_fn=simulate_variable_susceptibility_hom,
        param_ranges=param_ranges,
        base_params=base_params,
        n_steps=n_steps,
        sweep_type='beta'
    )
    
    return S_t, I_t, R_t

In [29]:
# NOTE(review): with the defaults (T=1000, dt=0.25 -> 4000 steps per run, on a
# 10 x 10 grid of (a, b) values) this run was interrupted (KeyboardInterrupt).
# Consider a smaller T or a coarser grid while iterating.
S, I, R = test_asymmetric_beta_sweep()

KeyboardInterrupt: 

In [23]:
def run_all_tests():
    """Run the three sweep smoke tests on small grids and print sanity checks.

    The homophily, symmetric-susceptibility and asymmetric-beta sweeps are run
    in turn, their result shapes are printed, and each result is checked for
    population conservation, non-negativity and final epidemic size range.
    The checks are printed, not asserted.
    """

    def validate_results(S, I, R, name):
        """Print basic sanity checks for one sweep's (S, I, R) arrays."""
        print(f"\nValidating {name}...")
        # Assumes arrays are (n_sweep_points, n_steps + 1, n_groups), as the
        # printed shapes suggest, with compartments summing to 1 over the last
        # (group) axis. TODO confirm against parameter_sweep.
        total = S + I + R
        is_conserved = np.allclose(total.sum(axis=-1), 1.0)
        print(f"Population conserved: {is_conserved}")

        non_negative = (S >= 0).all() and (I >= 0).all() and (R >= 0).all()
        print(f"All values non-negative: {non_negative}")

        # Recovered mass at the last time step, summed across the last axis.
        final_size = R[:, -1, :].sum(axis=-1)
        print(f"Range of final epidemic sizes: [{final_size.min():.3f}, {final_size.max():.3f}]")

    # Test 1: Homophily sweep
    print("Testing homophily sweep...")
    S_h, I_h, R_h = test_h_sweep(h_min=-5.0, h_max=5.0, n_h=10)
    print(f"Homophily sweep shapes: S:{S_h.shape}, I:{I_h.shape}, R:{R_h.shape}")

    # Test 2: Susceptibility symmetric sweep
    print("\nTesting susceptibility symmetric sweep...")
    S_s, I_s, R_s = test_susceptibility_symmetric_sweep(
        susc_min=0.1,
        susc_max=0.9,
        n_susc=5,
        ab_min=1.0,
        ab_max=4.0,
        n_ab=5
    )
    print(f"Susceptibility sweep shapes: S:{S_s.shape}, I:{I_s.shape}, R:{R_s.shape}")

    # Test 3: Asymmetric beta sweep.
    # test_asymmetric_beta_sweep takes {"m": min, "M": max, "n": count} range
    # dicts; the old a_min/.../n_b keywords raise TypeError.
    print("\nTesting asymmetric beta sweep...")
    S_b, I_b, R_b = test_asymmetric_beta_sweep(
        range_a={"m": 1.0, "M": 3.0, "n": 5},
        range_b={"m": 2.0, "M": 4.0, "n": 5},
        T=25.0,  # 25 / dt (default 0.25) = 100 steps, matching the other two sweeps
    )
    print(f"Beta sweep shapes: S:{S_b.shape}, I:{I_b.shape}, R:{R_b.shape}")

    validate_results(S_h, I_h, R_h, "homophily sweep")
    validate_results(S_s, I_s, R_s, "susceptibility sweep")
    validate_results(S_b, I_b, R_b, "beta sweep")

In [24]:
# Run all three sweeps on small grids, print the result shapes, then print the
# conservation / non-negativity / final-size checks for each sweep.
run_all_tests()

Testing homophily sweep...
Homophily sweep shapes: S:(10, 101, 100), I:(10, 101, 100), R:(10, 101, 100)

Testing susceptibility symmetric sweep...
Susceptibility sweep shapes: S:(25, 101, 100), I:(25, 101, 100), R:(25, 101, 100)

Testing asymmetric beta sweep...
Beta sweep shapes: S:(25, 101, 100), I:(25, 101, 100), R:(25, 101, 100)

Validating homophily sweep...
Population conserved: True
All values non-negative: True
Range of final epidemic sizes: [0.006, 0.016]

Validating susceptibility sweep...
Population conserved: True
All values non-negative: True
Range of final epidemic sizes: [0.000, 0.184]

Validating beta sweep...
Population conserved: True
All values non-negative: True
Range of final epidemic sizes: [0.002, 0.151]


In [16]:
# NOTE(review): S_t must already exist from an earlier sweep run (e.g. the
# asymmetric beta sweep cell further down); on a fresh kernel this cell raises
# NameError. Run it after that cell.
S_t = np.array(S_t)
# Only a cell's last expression is displayed, so show type and shape together
# (a bare `type(S_t)` line before `S_t.shape` was never shown).
type(S_t), S_t.shape

(10, 101, 100)

In [None]:
# Capture the results instead of leaving a bare call, whose return value (three
# large arrays) would be dumped into the cell output and then discarded.
S_s, I_s, R_s = test_susceptibility_symmetric_sweep(
    susc_min=0.1,
    susc_max=0.9,
    n_susc=20,
    ab_min=1.0,
    ab_max=4.0,
    n_ab=15
)
S_s.shape, I_s.shape, R_s.shape

In [27]:
# test_asymmetric_beta_sweep takes {"m": min, "M": max, "n": count} range dicts;
# the old a_min/.../n_b keywords raise TypeError.
# NOTE(review): a 100 x 100 grid means 10,000 simulations. T=25.0 keeps 100 steps
# per run at the default dt=0.25 (as in the earlier 101-time-point results)
# rather than the default T=1000 (4000 steps), which would be far slower and
# much larger in memory.
S_t, I_t, R_t = test_asymmetric_beta_sweep(
    range_a={"m": 1.0, "M": 3.0, "n": 100},
    range_b={"m": 2.0, "M": 4.0, "n": 100},
    T=25.0,
)