In [11]:
import numpy as np
import pandas as pd
from SimulationConfig import SimulationConfig
from simulation import Simulation
import itertools
from typing import Dict, List, Tuple

In [None]:

class SensitivityAnalyzer:
    """Three-level sensitivity sweep over prevalence, treatment access and chronic fraction.

    Every combination of the 'low' / 'base' / 'high' values returned by
    ``define_parameter_ranges`` is simulated, and each run is summarised by
    its DLES (person-years at intensity index 90 and above, converted to days).
    """

    # Scenario labels in display / sort order.
    LEVELS = ('low', 'base', 'high')

    def __init__(self):
        self.base_config = SimulationConfig()
        # Result rows (one dict per scenario) of the most recent
        # run_sensitivity_analysis() call.
        self.results = []

    def define_parameter_ranges(self) -> Dict[str, Dict[str, float]]:
        """Define the parameter variations to test.

        Returns:
            Mapping of parameter name -> {'low' | 'base' | 'high' -> value}.
        """
        return {
            'prevalence': {
                'low': 26,      # per 100,000
                'base': 53,
                'high': 95
            },
            'treatment_access': {
                'low': 0.25,    # 25%
                'base': 0.43,   # 43%
                'high': 0.60    # 60%
            },
            'chronic_fraction': {
                'low': 0.10,    # 10%
                'base': 0.20,   # 20%
                'high': 0.25    # 25%
            }
        }

    def create_config_variant(self, prevalence: int, treatment_access: float,
                            chronic_fraction: float) -> "SimulationConfig":
        """Create a fresh configuration with the three swept parameters overridden.

        Args:
            prevalence: Annual prevalence per 100,000.
            treatment_access: Proportion treated; the untreated proportion is
                set to its complement.
            chronic_fraction: Proportion chronic; the episodic proportion is
                set to its complement.

        Returns:
            A new ``SimulationConfig`` with the overrides applied.
        """
        config = SimulationConfig()
        config.annual_prevalence_per_100k = prevalence
        config.prop_treated = treatment_access
        config.prop_untreated = 1 - treatment_access
        config.prop_chronic = chronic_fraction
        config.prop_episodic = 1 - chronic_fraction
        # Use smaller simulation size for speed.
        # NOTE(review): whether Simulation rescales its outputs to the full
        # population when only 5% is simulated is not visible here - confirm.
        config.percent_of_patients_to_simulate = 0.05
        return config

    def calculate_dles(self, simulation: "Simulation") -> Dict[str, float]:
        """Calculate DLES and other key metrics from a finished simulation.

        Reads ``simulation.get_results()['global_person_years']``, a mapping
        whose values are sliceable sequences of person-years.
        Presumably each sequence is indexed by intensity in tenths
        (index 90 == 9.0/10) - TODO confirm against Simulation.

        Returns:
            Dict with keys 'dles', 'ylss', 'total_person_years' and
            'total_ch_sufferers'.
        """
        results = simulation.get_results()
        groups = list(results['global_person_years'].values())

        # Total person-years at >=9/10 intensity (index 90+), converted to days.
        total_extreme_pain = sum(sum(group_data[90:]) for group_data in groups)
        dles = total_extreme_pain * 365

        # Same for >=7/10 intensity (index 70+).
        # NOTE(review): this is multiplied by 365 like DLES, so 'ylss' is in
        # days; if YLSS is meant to be in years the factor should go - confirm.
        total_severe_pain = sum(sum(group_data[70:]) for group_data in groups)
        ylss = total_severe_pain * 365

        # Total person-years in any pain.
        total_pain = sum(sum(group_data) for group_data in groups)

        return {
            'dles': dles,
            'ylss': ylss,
            'total_person_years': total_pain,
            'total_ch_sufferers': simulation.get_total_ch_sufferers()
        }

    def run_sensitivity_analysis(self) -> pd.DataFrame:
        """Run all 27 parameter combinations and collect their metrics.

        ``self.results`` is reset at the start, so calling this again on the
        same instance (e.g. re-running a notebook cell) does not duplicate rows.

        Returns:
            One row per scenario. ``dles_pct_change`` is always present; it is
            NaN when the base-case DLES is zero or missing, because a relative
            change is undefined in that case.
        """
        param_ranges = self.define_parameter_ranges()

        # Get all combinations
        combinations = list(itertools.product(
            param_ranges['prevalence'].items(),
            param_ranges['treatment_access'].items(),
            param_ranges['chronic_fraction'].items()
        ))

        print(f"Running {len(combinations)} scenarios...")

        # Fresh list per run; rebinding (not clearing) leaves any list a
        # caller kept from a previous run untouched.
        self.results = []
        base_case_dles = None

        for i, ((prev_label, prev_val), (treat_label, treat_val), (chronic_label, chronic_val)) in enumerate(combinations):
            print(f"Scenario {i+1}/{len(combinations)}: {prev_label} prevalence, {treat_label} treatment, {chronic_label} chronic")

            # Create configuration for this scenario
            config = self.create_config_variant(prev_val, treat_val, chronic_val)

            # Run simulation
            simulation = Simulation(config)
            simulation.run()

            # Calculate metrics
            metrics = self.calculate_dles(simulation)

            is_base_case = (prev_label == 'base' and treat_label == 'base' and chronic_label == 'base')
            if is_base_case:
                base_case_dles = metrics['dles']

            self.results.append({
                'scenario': i + 1,
                'prevalence_label': prev_label,
                'prevalence_value': prev_val,
                'treatment_label': treat_label,
                'treatment_value': treat_val,
                'chronic_label': chronic_label,
                'chronic_value': chronic_val,
                'dles': metrics['dles'],
                'ylss': metrics['ylss'],
                'total_person_years': metrics['total_person_years'],
                'total_ch_sufferers': metrics['total_ch_sufferers'],
                'is_base_case': is_base_case
            })

        # Convert to DataFrame and calculate percentage changes
        df = pd.DataFrame(self.results)

        if base_case_dles:
            df['dles_pct_change'] = ((df['dles'] - base_case_dles) / base_case_dles) * 100
        else:
            # Zero / missing base case: keep the column so downstream tables
            # do not fail with KeyError, but mark the change as undefined.
            df['dles_pct_change'] = float('nan')

        return df

    def create_summary_table(self, df: pd.DataFrame) -> pd.DataFrame:
        """Create the compact matrix: rows = treatment x chronic, columns = prevalence.

        Cells hold the scenario's DLES as a full, comma-grouped day count
        (not thousands). Rows are matched on the 'low' / 'base' / 'high'
        labels stored in ``df``; the percentages shown in the 'Chronic %'
        column are derived from ``define_parameter_ranges``.
        """
        chronic_fractions = self.define_parameter_ranges()['chronic_fraction']
        pivot_data = []

        for treat_label in self.LEVELS:
            for chronic_label in self.LEVELS:
                pct = int(round(chronic_fractions[chronic_label] * 100))
                chronic_display = f"{pct}% (Base)" if chronic_label == 'base' else f"{pct}%"
                row_data = {'Treatment Access': treat_label.title(), 'Chronic %': chronic_display}

                for prev_label in self.LEVELS:
                    # Find the matching scenario by label.
                    scenario = df[
                        (df['prevalence_label'] == prev_label) &
                        (df['treatment_label'] == treat_label) &
                        (df['chronic_label'] == chronic_label)
                    ]

                    if not scenario.empty:
                        dles_full = scenario['dles'].iloc[0]
                        row_data[f"{prev_label.title()} Prevalence"] = f"{int(dles_full):,}"

                pivot_data.append(row_data)

        return pd.DataFrame(pivot_data)

    def create_summary_stats(self, df: pd.DataFrame) -> Dict:
        """Create summary statistics over all scenarios.

        Requires a row flagged ``is_base_case`` and a ``dles_pct_change``
        column, both produced by ``run_sensitivity_analysis``.
        """
        base_case = df[df['is_base_case']].iloc[0]

        return {
            'total_scenarios': len(df),
            'min_dles': df['dles'].min(),
            'max_dles': df['dles'].max(),
            'base_dles': base_case['dles'],
            'min_pct_change': df['dles_pct_change'].min(),
            'max_pct_change': df['dles_pct_change'].max(),
            'q25_dles': df['dles'].quantile(0.25),
            'q75_dles': df['dles'].quantile(0.75),
            'coefficient_variation': (df['dles'].std() / df['dles'].mean()) * 100
        }

    def format_parameter_labels(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with human-readable parameter columns added.

        Adds 'prevalence_formatted', 'treatment_formatted' and
        'chronic_formatted', e.g. ``"53 (base)"`` / ``"43% (base)"``. A label
        other than 'low' / 'base' / 'high' gets no suffix. Fractions are
        rounded (not truncated) to whole percents so that float error such as
        ``0.29 * 100 == 28.999...`` does not show up as 28%.
        """
        def with_level(text, label):
            return f"{text} ({label})" if label in self.LEVELS else text

        def as_percent(value):
            return f"{int(round(value * 100))}%"

        df_formatted = df.copy()
        df_formatted['prevalence_formatted'] = [
            with_level(f"{int(value)}", label)
            for label, value in zip(df['prevalence_label'], df['prevalence_value'])
        ]
        df_formatted['treatment_formatted'] = [
            with_level(as_percent(value), label)
            for label, value in zip(df['treatment_label'], df['treatment_value'])
        ]
        df_formatted['chronic_formatted'] = [
            with_level(as_percent(value), label)
            for label, value in zip(df['chronic_label'], df['chronic_value'])
        ]

        return df_formatted

    def create_detailed_table(self, df: pd.DataFrame) -> pd.DataFrame:
        """Create the detailed results table, sorted low -> base -> high per parameter.

        The percent change is rounded to the nearest whole percent and shown
        with an explicit sign; an undefined (NaN) change is shown as "n/a".
        """
        # Format the labels
        df_formatted = self.format_parameter_labels(df)

        # Ordered categoricals so sorting follows low < base < high instead
        # of alphabetical order.
        level_order = list(self.LEVELS)
        for column in ('prevalence', 'treatment', 'chronic'):
            df_formatted[f'{column}_cat'] = pd.Categorical(
                df_formatted[f'{column}_label'], categories=level_order, ordered=True
            )

        # Sort by the categorical columns
        df_sorted = df_formatted.sort_values(['prevalence_cat', 'treatment_cat', 'chronic_cat'])

        def format_pct_change(x):
            return "n/a" if pd.isna(x) else f"{int(round(x)):+d}%"

        # Create the final table with nice column names and formatting
        detailed_table = pd.DataFrame({
            'Prevalence': df_sorted['prevalence_formatted'],
            'Treatment access': df_sorted['treatment_formatted'],
            'Chronic %': df_sorted['chronic_formatted'],
            'DLES': df_sorted['dles'].apply(lambda x: f"{int(x):,}"),
            'DLES (% change)': df_sorted['dles_pct_change'].apply(format_pct_change)
        })

        return detailed_table.reset_index(drop=True)

def main():
    """Run the sensitivity analysis, print the report and save the detailed table.

    Side effects: prints to stdout and writes
    ``detailed_sensitivity_results.csv`` to the current working directory.

    Returns:
        Tuple ``(results_df, summary_table, summary_stats, detailed_table)``.
    """
    banner = "=" * 80
    analyzer = SensitivityAnalyzer()

    # Run analysis
    results_df = analyzer.run_sensitivity_analysis()

    # Create summary table
    summary_table = analyzer.create_summary_table(results_df)

    # Calculate summary statistics
    summary_stats = analyzer.create_summary_stats(results_df)

    # Print results
    print("\n" + banner)
    print("SENSITIVITY ANALYSIS RESULTS")
    print(banner)

    print(f"\nScenarios tested: {summary_stats['total_scenarios']}")
    print(f"Base case DLES: {summary_stats['base_dles']:,.0f} days")
    print(f"Range: {summary_stats['min_dles']:,.0f} - {summary_stats['max_dles']:,.0f} days")
    # Explicit sign on both ends instead of a hard-coded "+" before the max.
    print(f"Percentage change from base: {summary_stats['min_pct_change']:+.1f}% to {summary_stats['max_pct_change']:+.1f}%")
    print(f"25th-75th percentile: {summary_stats['q25_dles']:,.0f} - {summary_stats['q75_dles']:,.0f} days")
    print(f"Coefficient of variation: {summary_stats['coefficient_variation']:.1f}%")

    print("\n" + banner)
    print("COMPACT SENSITIVITY TABLE")
    print(banner)
    # The table cells are full day counts, so the header must not claim thousands.
    print("Global Days Lived with Extreme Suffering (DLES), in days")
    print()
    print(summary_table.to_string(index=False))

    print("\n" + banner)
    print("DETAILED RESULTS")
    print(banner)

    # Create and display detailed table with proper formatting
    detailed_table = analyzer.create_detailed_table(results_df)
    print(detailed_table.to_string(index=False))

    # Persist the detailed table next to the notebook.
    detailed_table.to_csv('detailed_sensitivity_results.csv', index=False)
    print("\nDetailed table saved to: detailed_sensitivity_results.csv")
    return results_df, summary_table, summary_stats, detailed_table

In [33]:
# Run the full analysis. main() returns, in order: the raw results frame,
# the compact summary table, the summary-statistics dict and the detailed table.
results_df, summary_table, summary_stats, detailed_table = main()

Running 27 scenarios...
Scenario 1/27: low prevalence, low treatment, low chronic
Scenario 2/27: low prevalence, low treatment, base chronic
Scenario 3/27: low prevalence, low treatment, high chronic
Scenario 4/27: low prevalence, base treatment, low chronic
Scenario 5/27: low prevalence, base treatment, base chronic
Scenario 6/27: low prevalence, base treatment, high chronic
Scenario 7/27: low prevalence, high treatment, low chronic
Scenario 8/27: low prevalence, high treatment, base chronic
Scenario 9/27: low prevalence, high treatment, high chronic
Scenario 10/27: base prevalence, low treatment, low chronic
Scenario 11/27: base prevalence, low treatment, base chronic
Scenario 12/27: base prevalence, low treatment, high chronic
Scenario 13/27: base prevalence, base treatment, low chronic
Scenario 14/27: base prevalence, base treatment, base chronic
Scenario 15/27: base prevalence, base treatment, high chronic
Scenario 16/27: base prevalence, high treatment, low chronic
Scenario 17/27


Detailed table saved to: detailed_sensitivity_results.csv
