In [None]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp
import numpy as np
from models.variable_susceptibility import (
    simulate_variable_susceptibility_hom,
    simulate_variable_susceptibility_pol,
    N_COMPARTMENTS
)
from typing import Dict, Tuple, Optional
import matplotlib.pyplot as plt

In [None]:
def run_homophilic_sir_test(
    test_params: Dict,
    show_plots: bool = True,
    conservation_rtol: float = 1e-5,
    negativity_tol: float = 1e-10,
) -> Tuple[bool, Dict[str, str], Dict[str, object], Optional[plt.Figure]]:
    """
    Run tests comparing homophilic and non-homophilic (h=0) SIR models.

    Runs ``simulate_variable_susceptibility_hom`` with the requested
    homophilic tendency, and ``simulate_variable_susceptibility_pol`` with
    the same parameters except ``homophilic_tendency`` forced to 0.0.
    Both trajectories are then checked for population conservation and
    non-negativity, and their epidemic peaks and final sizes are summarised.

    Args:
        test_params: Dictionary containing (all keys optional, defaults shown
            in the example below):
            'beta_params': Tuple[float, float] - Parameters (a,b) for beta distribution
            'transmission_rates': Tuple[float, float] - (min_rate, max_rate)
            'recovery_rate': float - Recovery rate gamma
            'homophilic_tendency': float - Homophily parameter h
            'dt': float - Time step size
            'n_steps': int - Number of simulation steps
            'initial_infected': float - Initial proportion infected
        show_plots: Whether to generate visualization
        conservation_rtol: Relative tolerance used when checking that
            S + I + R sums to 1 at every time step (default 1e-5, the value
            previously hard-coded).
        negativity_tol: Magnitude of negative values tolerated in S, I and R
            before the non-negativity check fails (default 1e-10, the value
            previously hard-coded).

    Returns:
        Tuple containing:
        - Python ``bool``: True if every check passed
        - Dictionary of test messages, keyed by test name
        - Dictionary of simulation results with keys 'homophilic' and
          'polarized' (each the (S_t, I_t, R_t) tuple returned by the
          simulator) and 'params_used' (the parameters actually used)
        - Matplotlib figure if show_plots=True, else None

    Example params:
    {
        'beta_params': (2.0, 2.0),            # Symmetric population distribution
        'transmission_rates': (0.1, 0.5),      # Range of transmission rates
        'recovery_rate': 0.1,                  # Recovery rate
        'homophilic_tendency': 2.0,           # Positive h for homophilic mixing
        'dt': 0.25,                           # Time step
        'n_steps': 400,                       # Number of steps
        'initial_infected': 0.01              # Initial infected proportion
    }

    Note:
        Trajectory arrays are assumed to be indexed (time, group). The
        ``.sum(axis=1)`` calls below rely on this; confirm against the model.
    """
    # Extract parameters with defaults
    params = {
        'transmission_rates': test_params.get('transmission_rates', (0.1, 0.5)),
        'recovery_rate': test_params.get('recovery_rate', 0.1),
        'homophilic_tendency': test_params.get('homophilic_tendency', 2.0),
        'dt': test_params.get('dt', 0.25)
    }

    beta_params = test_params.get('beta_params', (2.0, 2.0))
    n_steps = test_params.get('n_steps', 400)
    initial_infected = test_params.get('initial_infected', 0.01)

    # Simulation with the requested homophilic tendency.
    results_hom = simulate_variable_susceptibility_hom(
        beta_params, params, n_steps, initial_infected
    )

    # Reference run: identical parameters except h = 0 (labelled "polarized"
    # in keys and messages below).
    params_pol = {**params, 'homophilic_tendency': 0.0}
    results_pol = simulate_variable_susceptibility_pol(
        beta_params, params_pol, n_steps, initial_infected
    )

    test_messages: Dict[str, str] = {}
    all_passed = True
    cases = (("homophilic", results_hom), ("polarized", results_pol))

    # 1. Population conservation: S + I + R must sum to 1 at every step.
    for case, (S_t, I_t, R_t) in cases:
        total_per_step = (S_t + I_t + R_t).sum(axis=1)
        # Compare against the scalar 1.0 rather than jnp.ones(n_steps + 1), so
        # the check does not hard-code how many time points the simulator
        # returns. bool() turns the JAX scalar into a plain Python bool.
        conserved = bool(
            jnp.allclose(total_per_step, 1.0, rtol=conservation_rtol)
        )
        test_messages[f'population_conservation_{case}'] = (
            f"PASSED - Population is conserved ({case})" if conserved
            else f"FAILED - Population not conserved ({case})"
        )
        all_passed &= conserved

    # 2. Non-negativity, allowing tiny numerical undershoot.
    for case, compartments in cases:
        non_negative = all(
            bool(jnp.all(X >= -negativity_tol)) for X in compartments
        )
        test_messages[f'non_negative_{case}'] = (
            f"PASSED - All populations non-negative ({case})" if non_negative
            else f"FAILED - Negative populations detected ({case})"
        )
        all_passed &= non_negative

    # 3. Compare epidemic dynamics (informational, not pass/fail).
    I_total_h = results_hom[1].sum(axis=1)
    I_total_p = results_pol[1].sum(axis=1)

    # Row k corresponds to time k * dt (the plot helper uses the same mapping).
    peak_time_h = float(jnp.argmax(I_total_h)) * params['dt']
    peak_time_p = float(jnp.argmax(I_total_p)) * params['dt']

    peak_size_h = float(jnp.max(I_total_h))
    peak_size_p = float(jnp.max(I_total_p))

    test_messages['epidemic_peaks'] = (
        f"Homophilic - Peak: {peak_size_h:.1%} at {peak_time_h:.1f} days\n"
        f"Polarized - Peak: {peak_size_p:.1%} at {peak_time_p:.1f} days"
    )

    # 4. Final epidemic sizes: total recovered at the last time step.
    final_size_h = float(results_hom[2][-1].sum())
    final_size_p = float(results_pol[2][-1].sum())

    test_messages['final_sizes'] = (
        f"Final size - Homophilic: {final_size_h:.1%}, "
        f"Polarized: {final_size_p:.1%}"
    )

    # Store results
    simulation_results = {
        'homophilic': results_hom,
        'polarized': results_pol,
        'params_used': {
            'model_params': params,
            'beta_params': beta_params,
            'n_steps': n_steps,
            'initial_infected': initial_infected
        }
    }

    # Create visualization if requested
    fig = None
    if show_plots:
        fig = plot_homophilic_epidemic_curves(
            results_hom,
            results_pol,
            params['dt']
        )

    return bool(all_passed), test_messages, simulation_results, fig

def plot_homophilic_epidemic_curves(
    results_hom: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    results_pol: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    dt: float = 0.25,
    title: str = "Comparison of Homophilic vs Non-homophilic SIR"
) -> plt.Figure:
    """Plot comparison of homophilic and non-homophilic epidemic curves.

    Top panel: total infected (red) and recovered (green) proportions over
    time. Solid lines are the homophilic run; dashed lines are the h=0 run.
    Bottom panel: heatmap of the homophilic run's infected proportion per
    behaviour group over time.

    Args:
        results_hom: (S_t, I_t, R_t) from the homophilic simulation. Arrays
            are assumed to be indexed (time, group); the ``.sum(axis=1)`` and
            ``.T`` below rely on this.
        results_pol: (S_t, I_t, R_t) from the h=0 simulation. Assumed to have
            the same layout and number of time points as ``results_hom``.
        dt: Time step used to convert step indices to days.
        title: Title for the top panel.

    Returns:
        The created matplotlib Figure. It is not shown or closed here.
    """
    S_t_h, I_t_h, R_t_h = results_hom
    _, I_t_p, R_t_p = results_pol
    time = np.arange(len(S_t_h)) * dt

    fig, (ax_totals, ax_heat) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot total populations
    ax_totals.plot(time, I_t_h.sum(axis=1), 'r-', label='Infected (h≠0)', alpha=0.7)
    ax_totals.plot(time, I_t_p.sum(axis=1), 'r--', label='Infected (h=0)', alpha=0.7)
    ax_totals.plot(time, R_t_h.sum(axis=1), 'g-', label='Recovered (h≠0)', alpha=0.7)
    ax_totals.plot(time, R_t_p.sum(axis=1), 'g--', label='Recovered (h=0)', alpha=0.7)

    ax_totals.set_title(title)
    ax_totals.set_xlabel('Time (days)')
    ax_totals.set_ylabel('Total Population Proportion')
    ax_totals.grid(True, alpha=0.3)
    ax_totals.legend()

    # Heatmap of infected proportions: rows = behaviour groups, columns = steps.
    # Pad the x-extent by half a step on each side so each pixel column is
    # centred on its sample time. This also gives a single-step run a
    # non-zero width instead of a degenerate [0, 0] extent.
    im = ax_heat.imshow(
        np.asarray(I_t_h).T,
        aspect='auto',
        extent=[time[0] - dt / 2, time[-1] + dt / 2, 0, 1],
        origin='lower',
        cmap='YlOrRd'
    )

    ax_heat.set_title('Infected Proportions Across Behavior Groups (h≠0)')
    ax_heat.set_xlabel('Time (days)')
    ax_heat.set_ylabel('Behavior Group (normalized)')
    # Use the Figure API instead of the pyplot state machine, so this does not
    # depend on `fig` being the "current" figure.
    fig.colorbar(im, ax=ax_heat, label='Proportion Infected')

    fig.tight_layout()
    return fig

In [None]:
# Scenario h = -10 (negative homophilic tendency). All other settings are
# shared with the other scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=-10,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()

In [None]:
# Scenario h = 0.0001 (near-zero homophilic tendency), which should be close to
# the h=0 reference run. All other settings are shared with the other
# scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=0.0001,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()

In [None]:
# Scenario h = 5 (positive homophilic tendency). All other settings are shared
# with the other scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=5,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()

In [None]:
# Scenario h = 20 (strong positive homophilic tendency). All other settings
# are shared with the other scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=20,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()

In [None]:
# Scenario h = -5 (negative homophilic tendency). All other settings are
# shared with the other scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=-5,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()

In [None]:
# Scenario h = -20 (strongly negative homophilic tendency). All other settings
# are shared with the other scenario cells.
test_params = dict(
    beta_params=(1, 1),
    transmission_rates=(0.0, 0.6),
    recovery_rate=0.1,
    homophilic_tendency=-20,
    dt=0.25,
    n_steps=1000,
    initial_infected=1e-4,
)

success, messages, _, _ = run_homophilic_sir_test(test_params, show_plots=True)

print("Overall test success:", success)
print("\nDetailed test results:")
for name, text in messages.items():
    print(f"{name}: {text}")
plt.show()