In [None]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp
import numpy as np
from core.sir_base import simulate_trajectory
from typing import Dict, Tuple, Optional
import matplotlib.pyplot as plt


In [None]:

def plot_epidemic_curves(
    results: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    dt: float = 0.1,
    title: str = "SIR Epidemic Curves"
) -> plt.Figure:
    """
    Plot SIR curves from simulation results.

    The top panel shows S, I and R against time. The bottom panel shows the
    infected curve with its peak marked by a dashed line, a dot and a text
    annotation.

    Args:
        results: Tuple of (S_t, I_t, R_t) arrays whose first axis is time.
            Any trailing axes are treated as sub-populations (presumably one
            column per population -- confirm against simulate_trajectory).
        dt: Time step size, used to convert step indices to days on the x-axis
        title: Plot title

    Returns:
        matplotlib figure object

    Raises:
        ValueError: if the trajectories contain no time steps.
    """
    S_t, I_t, R_t = results
    n_points = len(S_t)
    if n_points == 0:
        raise ValueError("results must contain at least one time step")
    time = np.arange(n_points) * dt

    # `gridspec_kw` works on every matplotlib version; the `height_ratios`
    # shortcut argument of plt.subplots only exists from matplotlib 3.6.
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(10, 8), gridspec_kw={'height_ratios': [2, 1]}
    )

    # Main epidemic curves
    ax1.plot(time, S_t, 'b-', label='Susceptible', alpha=0.7)
    ax1.plot(time, I_t, 'r-', label='Infected', alpha=0.7)
    ax1.plot(time, R_t, 'g-', label='Recovered', alpha=0.7)

    ax1.set_title(title)
    ax1.set_xlabel('Time (days)')
    ax1.set_ylabel('Proportion of Population')
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Collapse any population axes into one 1-D total-infected series.
    # - np.argmax flattens N-D input, so without this the index is not
    #   guaranteed to be a time index.
    # - Indexing then yields a true scalar instead of a length-1 array,
    #   which float() rejects for JAX arrays.
    # For a single population this is exactly the original I_t curve.
    infected_total = np.asarray(I_t).reshape(len(I_t), -1).sum(axis=1)
    peak_idx = int(np.argmax(infected_total))
    peak_time = peak_idx * dt
    peak_value = float(infected_total[peak_idx])

    # Infected curve with peak annotation
    ax2.plot(time, infected_total, 'r-', label='Infected', alpha=0.7)
    ax2.axvline(peak_time, color='gray', linestyle='--', alpha=0.5)
    ax2.plot(peak_time, peak_value, 'ko')
    ax2.annotate(f'Peak: {peak_value:.1%}\nDay {peak_time:.1f}',
                 xy=(peak_time, peak_value),
                 xytext=(10, 10), textcoords='offset points')

    ax2.set_xlabel('Time (days)')
    ax2.set_ylabel('Infected')
    ax2.grid(True, alpha=0.3)

    # Lay out this specific figure rather than whatever pyplot considers current
    fig.tight_layout()
    return fig

def run_basic_sir_test(show_plots: bool = True) -> Tuple[bool, Dict[str, str], Dict[str, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]], Optional[plt.Figure]]:
    """
    Run basic sanity checks on the SIR model and optionally plot the run.

    Simulates one population (S=0.99, I=0.01, R=0) for 1000 steps of
    dt=0.1 using simulate_trajectory, then checks that:
      1. S + I + R stays equal to its initial value (rtol=1e-5)
      2. S, I and R never go negative
      3. the infected curve peaks strictly inside the simulated window
      4. more than 60% of the total population has recovered by the end
      5. less than 1% of the total population is still infected at the end

    Args:
        show_plots: Whether to generate visualization

    Returns:
        Tuple containing:
        - Boolean indicating if all tests passed (a plain Python bool)
        - Dictionary of test messages keyed by test name
        - Dictionary of simulation results
        - Matplotlib figure if show_plots=True, else None
    """
    # Initialize test parameters
    params = {
        'transmission_rate': 0.3,  # R0 = 0.3 / 0.1 = 3
        'recovery_rate': 0.1,      # ~10 day recovery period
        'dt': 0.1                  # Time step of 0.1 days
    }

    # Single population initial conditions
    total_population = 1.0
    initial_infected = 0.01
    initial_state = (
        jnp.array([total_population - initial_infected]),  # S
        jnp.array([initial_infected]),                     # I
        jnp.array([0.0])                                   # R
    )

    n_steps = 1000  # Simulate 100 days with dt=0.1

    # Run simulation
    results = simulate_trajectory(initial_state, params, n_steps)
    S_t, I_t, R_t = results

    test_messages: Dict[str, str] = {}

    # 1. Population conservation.
    # Compare each population with its own initial total, broadcast over
    # time. The scalar overall total is kept for normalising fractions below.
    initial_totals = initial_state[0] + initial_state[1] + initial_state[2]
    expected_population = float(jnp.sum(initial_totals))
    # bool() turns the 0-d JAX result into a real Python bool
    pop_conservation = bool(
        jnp.allclose(S_t + I_t + R_t, initial_totals, rtol=1e-5)
    )
    test_messages['population_conservation'] = (
        "PASSED - Population is conserved" if pop_conservation
        else "FAILED - Population not conserved"
    )

    # 2. Non-negative populations
    non_negative = bool(
        jnp.all(S_t >= 0) & jnp.all(I_t >= 0) & jnp.all(R_t >= 0)
    )
    test_messages['non_negative'] = (
        "PASSED - All populations non-negative" if non_negative
        else "FAILED - Negative populations detected"
    )

    # 3. Epidemic peak.
    # Sum over any population axis so argmax (which flattens N-D input)
    # yields a time index.
    infected_total = jnp.reshape(I_t, (len(I_t), -1)).sum(axis=1)
    peak_time = int(jnp.argmax(infected_total))
    # A maximum at the first or last sample means the curve is monotone over
    # the window, so there is no interior peak. Bound by the actual series
    # length: `< n_steps` is always true if the trajectory has n_steps points.
    has_peak = 0 < peak_time < len(infected_total) - 1
    test_messages['epidemic_peak'] = (
        f"PASSED - Epidemic peaks at t={peak_time * params['dt']:.1f} days" if has_peak
        else "FAILED - No clear epidemic peak"
    )

    # 4. Final size, as a fraction of the total initial population
    # (the same normalisation the conservation check and the plot use)
    final_recovered = float(jnp.sum(R_t[-1])) / expected_population
    final_size_ok = final_recovered > 0.6
    test_messages['final_size'] = (
        f"PASSED - Final recovered fraction: {final_recovered:.1%}" if final_size_ok
        else f"FAILED - Final recovered fraction too small: {final_recovered:.1%}"
    )

    # 5. Final infected, as a fraction of the total initial population
    final_infected = float(infected_total[-1]) / expected_population
    final_infected_ok = final_infected < 0.01
    test_messages['final_infected'] = (
        f"PASSED - Final infected fraction: {final_infected:.1%}" if final_infected_ok
        else f"FAILED - Final infected fraction too large: {final_infected:.1%}"
    )

    all_passed = all((
        pop_conservation,
        non_negative,
        has_peak,
        final_size_ok,
        final_infected_ok,
    ))

    # Store results for inspection
    simulation_results = {
        'base_case': (S_t, I_t, R_t),
    }

    # Create visualization if requested
    fig = plot_epidemic_curves(results, params['dt']) if show_plots else None

    return all_passed, test_messages, simulation_results, fig

In [None]:
if __name__ == '__main__':
    # Text-only report. Plotting is disabled so this cell does not build (and,
    # under the inline backend, render) a figure it then throws away.
    success, messages, _, _ = run_basic_sir_test(show_plots=False)
    print("Overall test success:", success)
    print("\nDetailed test results:")
    for test_name, message in messages.items():
        print(f"{test_name}: {message}")

In [None]:
if __name__ == '__main__':
    # Run the SIR checks with plotting enabled, print a per-test report,
    # then display the resulting figure.
    success, messages, _, fig = run_basic_sir_test(show_plots=True)

    print("Overall test success:", success)
    print("\nDetailed test results:")
    for name, verdict in messages.items():
        line = f"{name}: {verdict}"
        print(line)

    plt.show()