# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numba import jit
from typing import Tuple, Dict, List
import time
import os

# Integer codes for agent states. They are also the column indices of the
# per-step state-count rows that the simulation records with
# np.bincount(states, minlength=5).
HEALTHY, INFECTED, RECOVERED, DEAD, VACCINATED = range(5)

class EnhancedEpidemicSimulation:
    """Agent-based epidemic simulation on a bounded 2-D grid.

    Each agent has a grid position, a state (HEALTHY, INFECTED, RECOVERED,
    DEAD, VACCINATED), an age category, an immunity level and a
    superspreader flag. Every step, agents take a bounded random walk. Then
    susceptible agents (HEALTHY, RECOVERED, VACCINATED) within distance 5 of
    an infected agent may become infected. Finally, infected agents resolve to
    DEAD or RECOVERED once ``recovery_period`` steps have elapsed.

    NOTE(review): ``reinfection_rate`` and ``mutation_rate`` are stored on the
    instance, but no method in this class reads them.
    """

    def __init__(
        self,
        grid_size: Tuple[int, int] = (100, 100),
        population: int = 1000,
        infection_rate: float = 0.3,
        step_size: int = 1,
        recovery_period: int = 50,
        death_rate: float = 0.02,
        time_steps: int = 500,
        initial_infected: int = 10,
        age_groups: Dict[str, float] = None,
        vaccination_rate: float = 0.0,
        vaccine_efficacy: float = 0.9,
        reinfection_rate: float = 0.3,
        superspreader_rate: float = 0.01,
        superspreader_multiplier: float = 3.0,
        social_distancing: float = 0.0,
        seasonal_effect: bool = True,
        mutation_rate: float = 0.001
    ):
        """Initialize enhanced epidemic simulation with realistic parameters.

        Args:
            grid_size: (rows, cols) of the grid. Agent positions are clipped to it.
            population: number of agents, capped at the number of grid cells.
            infection_rate: base per-contact transmission probability.
            step_size: maximum per-axis displacement per step.
            recovery_period: steps after infection before an agent resolves.
            death_rate: base death probability at resolution. It is scaled
                by the age risk factor and by ``1 - immunity``.
            time_steps: maximum number of steps ``run`` executes.
            initial_infected: agents infected at step 0. Capped at the
                number of non-vaccinated agents.
            age_groups: mapping of age-group label to population share.
                Labels must be keys of ``age_risk_factors``.
            vaccination_rate: fraction of agents vaccinated at the start.
            vaccine_efficacy: initial immunity level of vaccinated agents.
            reinfection_rate: stored but not used by this class.
            superspreader_rate: probability that an agent is a superspreader.
            superspreader_multiplier: transmission multiplier for superspreaders.
            social_distancing: fractional reduction in transmission.
            seasonal_effect: whether to apply a 365-step cosine modifier.
            mutation_rate: stored but not used by this class.
        """
        self.grid_size = grid_size
        self.population = min(population, grid_size[0] * grid_size[1])
        self.infection_rate = infection_rate
        self.step_size = step_size
        self.recovery_period = recovery_period
        self.death_rate = death_rate
        self.time_steps = time_steps
        self.initial_infected = initial_infected

        # Initialize new parameters
        self.vaccination_rate = vaccination_rate
        self.vaccine_efficacy = vaccine_efficacy
        self.reinfection_rate = reinfection_rate
        self.superspreader_rate = superspreader_rate
        self.superspreader_multiplier = superspreader_multiplier
        self.social_distancing = social_distancing
        self.seasonal_effect = seasonal_effect
        self.mutation_rate = mutation_rate

        # Age-group shares of the population, and each group's multiplier on
        # infection and death probability.
        self.age_groups = age_groups or {
            '0-17': 0.2,
            '18-49': 0.5,
            '50-69': 0.2,
            '70+': 0.1
        }
        self.age_risk_factors = {
            '0-17': 0.2,
            '18-49': 1.0,
            '50-69': 2.0,
            '70+': 5.0
        }

        # Per-agent arrays, populated by _initialize_simulation.
        self.positions = None
        self.states = None
        self.infection_times = None
        self.age_categories = None
        self.is_superspreader = None
        self.immunity_level = None
        # One row of state counts per executed step (appended by run).
        self.history = []

        self._initialize_simulation()

    def _initialize_simulation(self):
        """Place agents on distinct cells and assign ages, vaccination and initial infections."""
        # Rejection-sample distinct cells so no two agents start on the same cell.
        positions = set()
        while len(positions) < self.population:
            pos = (np.random.randint(0, self.grid_size[0]),
                   np.random.randint(0, self.grid_size[1]))
            positions.add(pos)

        # reshape keeps the (n, 2) layout even when the population is empty.
        self.positions = np.array(list(positions), dtype=int).reshape(-1, 2)
        self.states = np.zeros(self.population, dtype=np.int32)

        # Assign age categories by demographic share. The rounding shortfall
        # goes to '18-49'.
        age_categories = []
        for age_group, proportion in self.age_groups.items():
            age_categories.extend([age_group] * int(self.population * proportion))
        age_categories.extend(['18-49'] * (self.population - len(age_categories)))
        # Truncate in case the shares sum to more than 1; otherwise the array
        # would be longer than the population.
        self.age_categories = np.array(age_categories[:self.population])

        self.is_superspreader = np.random.random(self.population) < self.superspreader_rate
        self.immunity_level = np.zeros(self.population)

        # Vaccinated agents start with immunity equal to the vaccine efficacy.
        vaccinated_count = min(int(self.population * self.vaccination_rate), self.population)
        vaccinated_indices = np.random.choice(
            self.population,
            size=vaccinated_count,
            replace=False
        )
        self.states[vaccinated_indices] = VACCINATED
        self.immunity_level[vaccinated_indices] = self.vaccine_efficacy

        # Seed infections among non-vaccinated agents. Cap the count so
        # choice(replace=False) never asks for more agents than exist.
        available_indices = np.flatnonzero(self.states == HEALTHY)
        n_initial = min(self.initial_infected, len(available_indices))
        initial_infected_idx = np.random.choice(
            available_indices,
            size=n_initial,
            replace=False
        )
        self.states[initial_infected_idx] = INFECTED

        # -1 marks "not currently infected".
        self.infection_times = np.full(self.population, -1, dtype=np.int32)
        self.infection_times[initial_infected_idx] = 0

    def _get_seasonal_modifier(self, time_step):
        """Return the transmission multiplier for ``time_step``.

        This is ``1 + 0.3 * cos(2*pi*t/365)``, or 1.0 when seasonality is off.
        """
        if not self.seasonal_effect:
            return 1.0
        return 1.0 + 0.3 * np.cos(2 * np.pi * time_step / 365)

    def _get_age_risk_factor(self, idx):
        """Return the risk multiplier for agent ``idx``'s age category.

        Raises KeyError if the category is not in ``age_risk_factors``.
        """
        return self.age_risk_factors[self.age_categories[idx]]

    def _calculate_infection_probability(self, infected_idx, susceptible_idx, distance, time_step):
        """Return the probability that ``infected_idx`` infects ``susceptible_idx`` this step.

        The probability is the product of the base rate and several
        modifiers: seasonal, age risk, ``1 - immunity``, superspreader,
        ``1 - social_distancing`` and ``1 / (1 + distance**2)``. The result
        is capped at 1.0.
        """
        base_prob = self.infection_rate

        seasonal_mod = self._get_seasonal_modifier(time_step)
        age_risk_mod = self._get_age_risk_factor(susceptible_idx)
        immunity_mod = 1.0 - self.immunity_level[susceptible_idx]
        superspreader_mod = self.superspreader_multiplier if self.is_superspreader[infected_idx] else 1.0
        social_distance_mod = 1.0 - self.social_distancing
        distance_mod = 1.0 / (1.0 + distance ** 2)

        final_prob = (base_prob * seasonal_mod * age_risk_mod * immunity_mod *
                      superspreader_mod * social_distance_mod * distance_mod)

        return min(final_prob, 1.0)

    @staticmethod
    def _update_positions(positions, step_size, grid_size):
        """Move every agent by a random step in ``[-step_size, step_size]`` per axis.

        Positions are clipped to the grid. This is plain vectorised NumPy.
        Numba's nopython mode rejects the ``size=`` argument of
        ``np.random.randint``, and it would also use an RNG that
        ``np.random.seed`` does not control.
        """
        steps = np.random.randint(-step_size, step_size + 1, size=(len(positions), 2))
        upper = np.asarray(grid_size) - 1
        return np.clip(positions + steps, 0, upper)

    def _process_infections(self, time_step):
        """Give each susceptible agent a chance to be infected by nearby infected agents.

        Susceptible agents are HEALTHY, RECOVERED or VACCINATED. For each one,
        infected agents within Euclidean distance 5 are tried in index order.
        The first successful draw infects the agent and records ``time_step``.
        The infected set is snapshotted first, so agents infected during this
        call do not transmit until the next step.
        """
        infected_indices = np.flatnonzero(self.states == INFECTED)
        if len(infected_indices) == 0:
            return

        infected_positions = self.positions[infected_indices]
        susceptible_states = (HEALTHY, RECOVERED, VACCINATED)
        max_distance = 5  # maximum interaction distance

        for i in range(len(self.positions)):
            if self.states[i] not in susceptible_states:
                continue

            # Distances to all infected agents at once. Candidates are still
            # visited in index order, so random draws happen in the same sequence.
            distances = np.sqrt(np.sum((self.positions[i] - infected_positions) ** 2, axis=1))
            for infected_idx, distance in zip(infected_indices, distances):
                if distance > max_distance:
                    continue

                inf_prob = self._calculate_infection_probability(
                    infected_idx, i, distance, time_step
                )
                if np.random.random() < inf_prob:
                    self.states[i] = INFECTED
                    self.infection_times[i] = time_step
                    break

    def _process_recovery(self, time_step):
        """Resolve infections that have lasted at least ``recovery_period`` steps.

        Death probability is ``death_rate * age_risk * (1 - immunity)``.
        Survivors become RECOVERED and gain 0.5 immunity, capped at 1.0.
        Resolved agents have their infection time reset to -1.
        """
        for i in np.flatnonzero(self.states == INFECTED):
            if time_step - self.infection_times[i] < self.recovery_period:
                continue

            base_death_prob = self.death_rate * self._get_age_risk_factor(i)
            modified_death_prob = base_death_prob * (1.0 - self.immunity_level[i])

            if np.random.random() < modified_death_prob:
                self.states[i] = DEAD
            else:
                self.states[i] = RECOVERED
                self.immunity_level[i] = min(1.0, self.immunity_level[i] + 0.5)

            self.infection_times[i] = -1

    def run(self, verbose=True):
        """Run up to ``time_steps`` steps and return the recorded state counts.

        Each step moves agents, spreads infection, then resolves infections.
        The per-state counts are appended to ``self.history``. The run stops
        early once no agent is INFECTED.

        Args:
            verbose: print progress every 50 steps and the total runtime.

        Returns:
            Array with one row per recorded step and 5 columns. Column ``s``
            counts agents in state ``s``.

        NOTE(review): ``self.history`` is not cleared here, so a second call
        returns both runs' rows concatenated.
        """
        if verbose:
            print("Starting enhanced epidemic simulation...")

        start_time = time.time()

        for t in range(self.time_steps):
            self.positions = self._update_positions(
                self.positions, self.step_size, self.grid_size
            )
            self._process_infections(t)
            self._process_recovery(t)

            state_counts = np.bincount(self.states, minlength=5)
            self.history.append(state_counts)

            if verbose and (t + 1) % 50 == 0:
                print(f"Time step {t + 1}/{self.time_steps} completed")

            if not np.any(self.states == INFECTED):
                if verbose:
                    print("\nNo infected individuals remain. Ending simulation early.")
                break

        if verbose:
            print(f"\nSimulation completed in {time.time() - start_time:.2f} seconds")

        return np.array(self.history)

def calculate_statistics(histories, values):
    """Compute outbreak metrics per run and correlate each metric with the swept parameter.

    Args:
        histories: one history array per run, as returned by
            ``EnhancedEpidemicSimulation.run``. Row ``t`` holds the number of
            agents in each state after step ``t``, with columns indexed by the
            state constants (HEALTHY, INFECTED, RECOVERED, DEAD, VACCINATED).
        values: the parameter value used for each run, in the same order as
            ``histories``. Lists and arrays are both accepted.

    Returns:
        A dict with one list per metric: 'r0', 'doubling_time',
        'peak_intensity', 'total_infected_percent', 'outbreak_duration' and
        'mortality_rate'. Each list has one entry per history, and every
        entry is NaN for a history that could not be processed. The dict
        also has a 'correlations' key: the Pearson correlation of each metric
        with ``values``, or 0.0 when there are fewer than two finite points
        or no variation.
    """
    values = np.asarray(values)
    metrics = ['r0', 'doubling_time', 'peak_intensity',
               'total_infected_percent', 'outbreak_duration', 'mortality_rate']
    stats = {metric: [] for metric in metrics}

    def estimate_r0(infected_curve, window=10):
        """Return the mean ratio ``I[t + window] / I[t]`` over steps with ``I[t] > 0``.

        NOTE: this is a crude growth-ratio proxy, not a true basic reproduction
        number. Returns NaN when no ratio can be formed.
        """
        curve = np.asarray(infected_curve, dtype=float)
        if len(curve) <= window:
            return np.nan
        start, end = curve[:-window], curve[window:]
        positive = start > 0
        ratios = end[positive] / start[positive]
        ratios = ratios[np.isfinite(ratios)]
        return float(np.mean(ratios)) if ratios.size else np.nan

    def doubling_time(infected_curve):
        """Return the steps until active infections first reach twice the step-0 count.

        Returns NaN if step 0 has no infections or the count never doubles.
        """
        initial = infected_curve[0]
        if initial <= 0:
            return np.nan
        reached = np.flatnonzero(infected_curve >= 2 * initial)
        return float(reached[0]) if reached.size else np.nan

    for history in histories:
        # Collect one full row first so that a failure part-way through cannot
        # leave the per-metric lists with different lengths.
        row = dict.fromkeys(metrics, np.nan)
        try:
            history = np.asarray(history)
            infected_curve = history[:, INFECTED]
            final = history[-1]
            population = history[0].sum()
            # In EnhancedEpidemicSimulation no agent ever leaves
            # {INFECTED, RECOVERED, DEAD}, so the final occupancy of those
            # states counts every agent that was ever infected. Summing the
            # cumulative columns over time would count person-steps instead.
            total_cases = final[INFECTED] + final[RECOVERED] + final[DEAD]
            peak = infected_curve.max()
            # Outbreak lasts through the last step above 1% of the peak.
            above = np.flatnonzero(infected_curve > peak * 0.01)

            row.update(
                r0=estimate_r0(infected_curve),
                doubling_time=doubling_time(infected_curve),
                peak_intensity=peak,
                total_infected_percent=(total_cases / population * 100) if population > 0 else 0,
                outbreak_duration=int(above[-1]) + 1 if above.size else 0,
                mortality_rate=(final[DEAD] / total_cases * 100) if total_cases > 0 else 0,
            )
        except Exception as e:  # best effort: one malformed history must not abort the analysis
            print(f"Error calculating statistics: {str(e)}")
            row = dict.fromkeys(metrics, np.nan)

        for metric in metrics:
            stats[metric].append(row[metric])

    # Pearson correlation of each metric with the parameter, ignoring NaN/inf entries.
    stats['correlations'] = {}
    for metric in metrics:
        try:
            metric_values = np.asarray(stats[metric], dtype=float)
            valid = np.isfinite(metric_values)
            x, y = metric_values[valid], values[valid]
            if x.size > 1 and np.std(x) > 0 and np.std(y) > 0:
                correlation = float(np.corrcoef(x, y)[0, 1])
            else:
                correlation = 0.0
        except Exception as e:  # e.g. len(values) != len(histories)
            print(f"Error calculating correlation for {metric}: {str(e)}")
            correlation = 0.0
        stats['correlations'][metric] = correlation

    return stats

def run_parameter_analysis(base_params: Dict, param_variations: Dict):
    """Sweep each parameter independently around a shared baseline.

    For every ``param_name -> values`` entry in *param_variations*, one silent
    simulation is run per value. All other parameters are taken from
    *base_params*, which is not modified.

    Returns:
        A dict mapping each ``param_name`` to ``(values, histories)``, where
        ``histories[k]`` is the array that ``run`` returned for ``values[k]``.
    """
    results = {}
    for name, sweep in param_variations.items():
        print(f"\nAnalyzing variations in {name}:")
        runs = [
            EnhancedEpidemicSimulation(**{**base_params, name: setting}).run(verbose=False)
            for setting in sweep
        ]
        results[name] = (sweep, runs)
    return results
    
def plot_analysis(results: Dict, save_dir="plots"):
    """Create one four-panel analysis figure per swept parameter and save it as a PNG.

    The panels are:
      1. smoothed active-infection curves;
      2. peak and total cases against the parameter, annotated with the
         growth-ratio "R0" estimate;
      3. stacked shares of the final outcomes;
      4. a text summary of correlations and averaged metrics.

    Args:
        results: mapping ``param_name -> (values, histories)`` as returned by
            ``run_parameter_analysis``.
        save_dir: output directory, created if missing. Each figure is written
            to ``<save_dir>/<param_name>_analysis.png``.
    """
    os.makedirs(save_dir, exist_ok=True)
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.style.use('seaborn-v0_8-darkgrid')

    # Styling constants
    TITLE_SIZE = 24
    AXIS_LABEL_SIZE = 20
    TICK_LABEL_SIZE = 16
    LEGEND_SIZE = 14

    COLORS = {
        'healthy': '#2ecc71',
        'infected': '#e74c3c',
        'recovered': '#3498db',
        'dead': '#7f8c8d',
        'vaccinated': '#9b59b6'
    }

    param_labels = {
        'infection_rate': 'Infection Rate (β)',
        'vaccination_rate': 'Vaccination Coverage',
        'social_distancing': 'Social Distancing Compliance',
        'superspreader_rate': 'Superspreader Proportion'
    }
    # These parameters are stored as fractions but plotted as percentages.
    percent_params = ('vaccination_rate', 'social_distancing', 'superspreader_rate')
    window_size = 5  # moving-average width for the infection curves

    for param_name, (values, histories) in results.items():
        # asarray so that `* 100` scales numerically even if a plain list was supplied.
        values = np.asarray(values)
        label = param_labels.get(param_name, param_name)
        stats = calculate_statistics(histories, values)
        plot_values = values * 100 if param_name in percent_params else values
        final_states = np.array([h[-1] for h in histories])

        fig = plt.figure(figsize=(20, 15))
        try:
            # NOTE: the short third row is currently left empty.
            gs = plt.GridSpec(3, 2, figure=fig, height_ratios=[3, 3, 1])

            # 1. Infection curves plot
            ax1 = fig.add_subplot(gs[0, 0])
            for value, history in zip(plot_values, histories):
                infected_curve = history[:, INFECTED]
                # 'valid' convolution of a curve shorter than the window is
                # meaningless, so short curves are plotted unsmoothed.
                if len(infected_curve) >= window_size:
                    infected_curve = np.convolve(infected_curve,
                                                 np.ones(window_size) / window_size,
                                                 mode='valid')
                ax1.plot(infected_curve, alpha=0.8, label=f'{label} = {value:.2f}')

            ax1.set_xlabel('Time (days)', fontsize=AXIS_LABEL_SIZE)
            ax1.set_ylabel('Number of Infected', fontsize=AXIS_LABEL_SIZE)
            ax1.set_title('Active Infections Over Time', fontsize=TITLE_SIZE)
            ax1.tick_params(labelsize=TICK_LABEL_SIZE)
            ax1.grid(True, alpha=0.3)
            ax1.legend(fontsize=LEGEND_SIZE, bbox_to_anchor=(1.05, 1), loc='upper left')

            # 2. Peak analysis plot
            ax2 = fig.add_subplot(gs[0, 1])
            peak_infections = [np.max(h[:, INFECTED]) for h in histories]
            # Agents never leave {INFECTED, RECOVERED, DEAD}, so the final
            # occupancy of those states equals the number of distinct cases.
            total_cases = (final_states[:, INFECTED] + final_states[:, RECOVERED]
                           + final_states[:, DEAD])

            ax2.plot(plot_values, peak_infections, 'o-', color=COLORS['infected'],
                     label='Peak Infections', linewidth=2)
            ax2.plot(plot_values, total_cases, 's-', color=COLORS['recovered'],
                     label='Total Cases', linewidth=2)

            # Annotate each point with its R0 estimate, skipping NaN estimates.
            for value, r0, peak in zip(plot_values, stats['r0'], peak_infections):
                if not np.isnan(r0):
                    ax2.annotate(f'R0={r0:.2f}',
                                 xy=(value, peak),
                                 xytext=(10, 10),
                                 textcoords='offset points',
                                 fontsize=8,
                                 bbox=dict(boxstyle='round,pad=0.5',
                                           fc='yellow',
                                           alpha=0.3),
                                 arrowprops=dict(arrowstyle='->'))

            ax2.set_xlabel(label, fontsize=AXIS_LABEL_SIZE)
            ax2.set_ylabel('Number of Cases', fontsize=AXIS_LABEL_SIZE)
            ax2.set_title('Peak & Total Cases Analysis', fontsize=TITLE_SIZE)
            ax2.tick_params(labelsize=TICK_LABEL_SIZE)
            ax2.grid(True, alpha=0.3)
            ax2.legend(fontsize=LEGEND_SIZE)

            # 3. Final outcome distribution
            ax3 = fig.add_subplot(gs[1, 0])
            totals = final_states.sum(axis=1, keepdims=True)
            totals[totals == 0] = 1  # Avoid division by zero
            proportions = final_states / totals * 100

            # Size bars from the x-spacing. A fixed width of 0.8 makes them
            # overlap when the parameter values are closer than 0.8 apart.
            spacing = np.min(np.diff(np.sort(plot_values))) if len(plot_values) > 1 else 0
            bar_width = 0.8 * spacing if spacing > 0 else 0.8

            bottom = np.zeros(len(plot_values))
            for state_name, state_idx, color in (('Recovered', RECOVERED, COLORS['recovered']),
                                                 ('Dead', DEAD, COLORS['dead']),
                                                 ('Vaccinated', VACCINATED, COLORS['vaccinated'])):
                ax3.bar(plot_values, proportions[:, state_idx], width=bar_width, bottom=bottom,
                        label=state_name, color=color, alpha=0.8)
                bottom += proportions[:, state_idx]

            ax3.set_xlabel(label, fontsize=AXIS_LABEL_SIZE)
            ax3.set_ylabel('Percentage of Population', fontsize=AXIS_LABEL_SIZE)
            ax3.set_title('Final Outcome Distribution', fontsize=TITLE_SIZE)
            ax3.tick_params(labelsize=TICK_LABEL_SIZE)
            ax3.grid(True, alpha=0.3)
            ax3.legend(fontsize=LEGEND_SIZE)

            # 4. Statistical summary
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.axis('off')

            stats_text = [
                f"Statistical Summary for {label}:",
                "",
                "Parameter Correlations:"
            ]
            for metric, corr in stats['correlations'].items():
                if not np.isnan(corr):
                    stats_text.append(f"{metric.replace('_', ' ').title()}: {corr:.3f}")

            stats_text.extend(["", "Average Metrics:"])
            for metric in ['r0', 'peak_intensity', 'total_infected_percent', 'mortality_rate']:
                # Separate name so the swept `values` is not shadowed.
                metric_values = np.array(stats[metric], dtype=float)
                valid_values = metric_values[~np.isnan(metric_values)]
                if len(valid_values) > 0:
                    stats_text.extend([
                        f"{metric.replace('_', ' ').title()}:",
                        f"Mean: {np.mean(valid_values):.2f} ± {np.std(valid_values):.2f}"
                    ])

            ax4.text(0.05, 0.95, '\n'.join(stats_text),
                     transform=ax4.transAxes,
                     fontsize=12,
                     verticalalignment='top',
                     fontfamily='monospace')

            # Save plot
            fig.tight_layout()
            save_path = os.path.join(save_dir, f'{param_name}_analysis.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.5)
            print(f"Saved {param_name} analysis to {save_path}")
        finally:
            # Release the figure even if a panel fails, so figures do not accumulate.
            plt.close(fig)

if __name__ == "__main__":
    # Demo: sweep four parameters one at a time around a common baseline,
    # then write one analysis figure per parameter.
    base_params = dict(
        grid_size=(100, 100),
        population=1000,
        infection_rate=0.3,
        step_size=1,
        recovery_period=14,
        death_rate=0.02,
        time_steps=500,
        initial_infected=10,
        vaccination_rate=0.3,
        vaccine_efficacy=0.9,
        reinfection_rate=0.3,
        superspreader_rate=0.01,
        superspreader_multiplier=3.0,
        social_distancing=0.2,
        seasonal_effect=True,
        mutation_rate=0.001,
    )

    # (low, high) bounds of each sweep. Every sweep uses 10 evenly spaced points.
    sweep_bounds = {
        'infection_rate': (0.01, 0.81),
        'vaccination_rate': (0.0, 0.9),
        'social_distancing': (0.0, 0.6),
        'superspreader_rate': (0.0, 0.01),
    }
    param_variations = {
        name: np.linspace(low, high, 10)
        for name, (low, high) in sweep_bounds.items()
    }

    print("Starting parameter analysis...")
    results = run_parameter_analysis(base_params, param_variations)
    plot_analysis(results)
