In [None]:
%load_ext autoreload
%autoreload 2
from typing import Tuple
import unittest
import sys
import jax.numpy as jnp

In [None]:
def test_imports() -> Tuple[bool, str]:
    """Smoke-test the project's imports and two basic calls.

    Imports the ``core`` and ``models`` entry points listed below, then
    calls ``sir_step`` once on a three-array, single-element state and
    ``create_contact_matrix`` once for ten groups. The results of both
    calls are discarded; the check only confirms that nothing raises.

    Returns:
        Tuple of (success, message). On failure ``success`` is False and
        ``message`` carries the exception type name followed by its text,
        so an import failure can be told apart from a runtime failure in
        the smoke calls.
    """
    try:
        # Core imports. Several names are never called below; importing
        # them is itself the check that they exist.
        from core.sir_base import sir_step, simulate_trajectory
        from core.interaction import create_contact_matrix
        from core.population import (
            my_beta_symmetric,
            my_beta_asymmetric,
            initialize_states,
            generate_behavior_values
        )

        # Model imports (existence check only).
        from models.variable_susceptibility import (
            simulate_variable_susceptibility_pol,
            simulate_variable_susceptibility_hom,
            susceptibility_step_pol,
            susceptibility_step_hom
        )

        # Minimal functionality: one step on a tiny state.
        # NOTE(review): the three arrays are presumably (S, I, R) -- confirm
        # against core.sir_base.
        state = (jnp.array([1.0]), jnp.array([0.1]), jnp.array([0.0]))
        params = {
            'transmission_rate': 0.3,
            'recovery_rate': 0.1,
            'dt': 0.1
        }
        sir_step(state, params)

        n_groups = 10
        h = 0.5
        pops = jnp.ones(n_groups)
        create_contact_matrix(n_groups, h, pops)

        return True, "All imports and basic functionality tests passed"

    except Exception as e:
        # Deliberately broad: this function reports failures through its
        # return value instead of raising. The type name is included because
        # str(e) alone can be ambiguous (a KeyError renders as just the
        # quoted key).
        return False, f"Error during testing: {type(e).__name__}: {e}"

class TestImports(unittest.TestCase):
    """unittest wrapper around the ``test_imports`` smoke check."""

    def test_all_imports(self):
        # Surface the smoke check's own message as the assertion failure text.
        ok, detail = test_imports()
        self.assertTrue(ok, msg=detail)



In [None]:
# Run the smoke check directly in this cell and show its status message;
# both results stay available as notebook-level names.
(success, message) = test_imports()
print(message)