In [None]:
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import numpy as np
from models.variable_susceptibility import (
    simulate_variable_susceptibility_hom,
    simulate_variable_susceptibility_pol,
    N_COMPARTMENTS
)
from core.interaction import create_contact_matrix
from typing import Dict, Tuple, Optional, Any
import matplotlib.pyplot as plt



In [None]:
def plot_contact_matrix(
    matrix: jnp.ndarray,
    title: str = "Contact Matrix",
    populations: Optional[jnp.ndarray] = None
) -> plt.Figure:
    """
    Plot a contact matrix as a heatmap, optionally beside a population profile.

    The heatmap is drawn with ``origin='lower'``, so group 0 is at the bottom.
    A text box to the right of the heatmap reports the max and mean entry and
    the ranges of the row and column sums.

    Args:
        matrix: 2-D contact matrix to plot (rows labelled "Group i", columns
            "Group j"). Anything accepted by ``np.asarray`` works.
        title: Title of the heatmap panel.
        populations: Optional 1-D population value per group. When given, it is
            drawn in a narrow panel to the right of the heatmap, on a
            normalized [0, 1] group axis.

    Returns:
        The matplotlib Figure containing the panel(s).
    """
    # Bring (possibly device-resident JAX) data to host once, rather than
    # once per reduction/format call below.
    mat = np.asarray(matrix)

    if populations is not None:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6),
                                       gridspec_kw={'width_ratios': [3, 1]})
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=(8, 6))

    # Plot contact matrix
    im = ax1.imshow(mat, origin='lower', aspect='equal', cmap='YlOrRd')
    ax1.set_title(title)
    ax1.set_xlabel('Group j')
    ax1.set_ylabel('Group i')
    # Attach the colorbar through this figure explicitly, so it cannot land on
    # whatever figure pyplot happens to consider "current".
    fig.colorbar(im, ax=ax1, label='Contact rate')

    # Axis captions for the row/column sums. The numeric ranges of the sums
    # are reported in the properties box below.
    row_sums = mat.sum(axis=1)
    col_sums = mat.sum(axis=0)
    ax1.text(1.02, 0.5, 'Row sums\n(outgoing)',
             transform=ax1.transAxes, rotation=270, va='center')
    ax1.text(0.5, -0.1, 'Column sums\n(incoming)',
             transform=ax1.transAxes, ha='center')

    # Summary statistics of the matrix
    props = (
        f"Max value: {mat.max():.3f}\n"
        f"Mean value: {mat.mean():.3f}\n"
        f"Row sum range: [{row_sums.min():.3f}, {row_sums.max():.3f}]\n"
        f"Col sum range: [{col_sums.min():.3f}, {col_sums.max():.3f}]"
    )
    ax1.text(1.35, 0.95, props, transform=ax1.transAxes,
             bbox=dict(facecolor='white', alpha=0.8))

    if populations is not None:
        pops = np.asarray(populations)
        n_groups = len(pops)
        # Place group i at the normalized centre of its row, (i + 0.5) / n.
        # The profile then spans (0, 1) symmetrically, matching imshow's
        # pixel-centre convention, instead of starting at 0 and stopping one
        # bin short of the top of ylim(0, 1).
        group_pos = (np.arange(n_groups) + 0.5) / n_groups
        ax2.plot(pops, group_pos, 'k-', label='Population')
        ax2.fill_betweenx(group_pos, pops, alpha=0.3)
        ax2.set_title('Population\nDistribution')
        ax2.set_xlabel('Density')
        ax2.set_ylim(0, 1)
        ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig

def run_homophilic_sir_test(
    test_params: Dict,
    show_plots: bool = True
) -> Tuple[bool, Dict[str, str], Dict[str, Any], Dict[str, plt.Figure]]:
    """
    Run the homophilic SIR model and an h=0 counterpart, and check invariants.

    Both runs share every model parameter except ``homophilic_tendency``, which
    is forced to 0.0 for the comparison ("polarized") run. The only check
    currently performed is population conservation. For each run, the S+I+R
    total summed over groups must equal 1 at each of the ``n_steps + 1``
    time points (``rtol=1e-5``).

    Args:
        test_params: Parameter overrides. Recognised keys, with defaults:
            ``transmission_rates`` (0.1, 0.5), ``recovery_rate`` 0.1,
            ``homophilic_tendency`` 2.0, ``dt`` 0.25, ``beta_params``
            (2.0, 2.0), ``n_steps`` 400, ``initial_infected`` 0.01.
        show_plots: Whether to build the diagnostic figures.

    Returns:
        Tuple containing:
        - ``all_passed``: Python ``bool``, True iff every check passed
        - ``test_messages``: check name -> "PASSED - ..." / "FAILED - ..."
        - ``simulation_results``: raw results of both runs, both contact
          matrices, the initial per-group populations and the parameters used
        - ``figures``: name -> Figure (empty when ``show_plots`` is False)
    """
    # Extract parameters with defaults
    params = {
        'transmission_rates': test_params.get('transmission_rates', (0.1, 0.5)),
        'recovery_rate': test_params.get('recovery_rate', 0.1),
        'homophilic_tendency': test_params.get('homophilic_tendency', 2.0),
        'dt': test_params.get('dt', 0.25)
    }

    beta_params = test_params.get('beta_params', (2.0, 2.0))
    n_steps = test_params.get('n_steps', 400)
    initial_infected = test_params.get('initial_infected', 0.01)

    # Run homophilic simulation
    results_hom = simulate_variable_susceptibility_hom(
        beta_params, params, n_steps, initial_infected
    )

    # Run the comparison with homophily switched off. Copy first so `params`
    # (returned in simulation_results) still reflects the homophilic run.
    params_pol = params.copy()
    params_pol['homophilic_tendency'] = 0.0
    results_pol = simulate_variable_susceptibility_pol(
        beta_params, params_pol, n_steps, initial_infected
    )

    # Initial per-group population: S + I + R at the first time index
    S0, I0, R0 = results_hom[0][0], results_hom[1][0], results_hom[2][0]
    populations = S0 + I0 + R0

    # Generate contact matrices
    h = params['homophilic_tendency']
    C_hom = create_contact_matrix(N_COMPARTMENTS, h, populations)
    C_pol = create_contact_matrix(N_COMPARTMENTS, 0.0, populations)

    # Population-conservation check for both runs
    test_messages = {}
    all_passed = True
    expected_totals = jnp.ones(n_steps + 1)
    for case, results in [("homophilic", results_hom), ("polarized", results_pol)]:
        S_t, I_t, R_t = results
        total_pop_t = S_t + I_t + R_t
        # jnp.allclose returns a 0-d JAX array. Convert it to a Python bool so
        # `all_passed` really is a bool, as the return annotation promises.
        pop_conservation = bool(jnp.allclose(
            total_pop_t.sum(axis=1),
            expected_totals,
            rtol=1e-5
        ))
        test_messages[f'population_conservation_{case}'] = (
            f"PASSED - Population is conserved ({case})" if pop_conservation
            else f"FAILED - Population not conserved ({case})"
        )
        all_passed = all_passed and pop_conservation

    # Store results including matrices
    simulation_results = {
        'homophilic': results_hom,
        'polarized': results_pol,
        'contact_matrix_hom': C_hom,
        'contact_matrix_pol': C_pol,
        'populations': populations,
        'params_used': {
            'model_params': params,
            'beta_params': beta_params,
            'n_steps': n_steps,
            'initial_infected': initial_infected
        }
    }

    # Create visualizations if requested
    figures = {}
    if show_plots:
        figures['dynamics'] = plot_homophilic_epidemic_curves(
            results_hom,
            results_pol,
            params['dt']
        )

        figures['contact_matrix_hom'] = plot_contact_matrix(
            C_hom,
            f"Contact Matrix (h={h})",
            populations
        )

        figures['contact_matrix_pol'] = plot_contact_matrix(
            C_pol,
            "Contact Matrix (h=0)",
            populations
        )

    return all_passed, test_messages, simulation_results, figures
def plot_homophilic_epidemic_curves(
    results_hom: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    results_pol: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    dt: float = 0.25,
    title: str = "Comparison of Homophilic vs Non-homophilic SIR"
) -> plt.Figure:
    """
    Plot comparison of homophilic and non-homophilic epidemic curves.

    Top panel: infected and recovered proportions summed over groups, versus
    time, for both runs (solid = homophilic, dashed = h=0).
    Bottom panel: heatmap of the per-group infected proportion for the
    homophilic run, with groups on a normalized [0, 1] axis.

    Args:
        results_hom: ``(S_t, I_t, R_t)`` from the homophilic run. Axis 0 is
            time (its length sets the time axis) and axis 1 is the group
            (summed over for the top panel).
        results_pol: ``(S_t, I_t, R_t)`` from the h=0 run, same layout.
        dt: Time step used to convert step indices into days.
        title: Title of the top panel.

    Returns:
        The matplotlib Figure with both panels.
    """
    S_t_h, I_t_h, R_t_h = results_hom
    _, I_t_p, R_t_p = results_pol
    time = np.arange(len(S_t_h)) * dt

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    # Plot total populations
    ax = axes[0]
    ax.plot(time, I_t_h.sum(axis=1), 'r-', label='Infected (h≠0)', alpha=0.7)
    ax.plot(time, I_t_p.sum(axis=1), 'r--', label='Infected (h=0)', alpha=0.7)
    ax.plot(time, R_t_h.sum(axis=1), 'g-', label='Recovered (h≠0)', alpha=0.7)
    ax.plot(time, R_t_p.sum(axis=1), 'g--', label='Recovered (h=0)', alpha=0.7)

    ax.set_title(title)
    ax.set_xlabel('Time (days)')
    ax.set_ylabel('Total Population Proportion')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Plot heatmap of infected proportions.
    # imshow's `extent` gives pixel *edges*. Padding by half a step centres
    # column k on its sample time k*dt. It also keeps the extent from
    # collapsing to zero width when there is only one time step.
    half_step = 0.5 * dt
    ax = axes[1]
    im = ax.imshow(
        np.asarray(I_t_h).T,
        aspect='auto',
        extent=[time[0] - half_step, time[-1] + half_step, 0, 1],
        origin='lower',
        cmap='YlOrRd'
    )

    ax.set_title('Infected Proportions Across Behavior Groups (h≠0)')
    ax.set_xlabel('Time (days)')
    ax.set_ylabel('Behavior Group (normalized)')
    # Use this figure's colorbar/layout rather than pyplot's "current" figure.
    fig.colorbar(im, ax=ax, label='Proportion Infected')

    fig.tight_layout()
    return fig
# plot_homophilic_epidemic_curves is defined after run_homophilic_sir_test; this is fine because the name is only looked up when run_homophilic_sir_test is called.

In [None]:
if __name__ == '__main__':
    # Example usage. Inside a Jupyter kernel `__name__` is '__main__', so this
    # block runs whenever the cell is executed.
    # NOTE(review): `beta_params` presumably parameterizes the susceptibility
    # distribution used by the simulators. Confirm in models.variable_susceptibility.
    test_params = dict(
        beta_params=(0.2, 0.2),
        transmission_rates=(0.1, 0.5),
        recovery_rate=0.1,
        homophilic_tendency=0.02,
        dt=0.25,
        n_steps=400,
        initial_infected=0.01,
    )

    outcome = run_homophilic_sir_test(test_params, show_plots=True)
    success, messages, results, figs = outcome

    # Overall verdict first, then one line per individual check
    print("Overall test success:", success)
    print("\nDetailed test results:")
    for test_name, message in messages.items():
        print(f"{test_name}: {message}")

    # Render every figure created above
    plt.show()