# In [1]:  (Jupyter cell prompt from the notebook export)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.integrate import odeint
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap

# Set style for better visualizations. The "seaborn-v0_8-*" style names exist
# only in matplotlib >= 3.6; older releases ship the same style under the name
# "seaborn-whitegrid". Pick whichever is available instead of raising OSError.
if 'seaborn-v0_8-whitegrid' in plt.style.available:
    plt.style.use('seaborn-v0_8-whitegrid')
elif 'seaborn-whitegrid' in plt.style.available:
    plt.style.use('seaborn-whitegrid')
sns.set_context("paper", font_scale=1.5)

def run_model_simulation(seed=None):
    """
    Build mock heat-wave scenario results and synthetic weather for visualization.

    No ODE system is solved here. Each scenario's hospitalization curve is a
    Gaussian bump whose height comes from a per-scenario table. The
    susceptible, recovered and death series are derived from that curve.

    Parameters
    ----------
    seed : int or None, optional
        Seed for the random noise added to pollution and humidity. With the
        default ``None`` the noise is drawn from NumPy's global random state,
        as before. Output then differs between runs unless the caller has
        seeded it with ``np.random.seed``.

    Returns
    -------
    results : dict
        Scenario name -> dict with keys 'time', 'S', 'H', 'R', 'D'. Each value
        is a 1-D array with one entry per time point.
    weather_data : dict
        Arrays 'time', 'temperature', 'pollution' and 'humidity', plus the
        scalar settings 'vulnerability', 'awareness', 'hospital_capacity' and
        'urban_heat_island'.
    """
    # A seeded private generator makes runs reproducible. None keeps the
    # global-state behaviour, so existing np.random.seed(...) calls still apply.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Parameters of the original ODE model, kept for reference only
    # (the mock below does not use them):
    # beta_base=0.3, gamma_base=0.25, mu_base=0.02, alpha_T=0.7, alpha_P=0.4,
    # alpha_V=0.5, alpha_A=0.5, alpha_H=0.4, hospital_threshold=0.7

    # Synthetic weather: 100 evenly spaced samples spanning 30 days.
    duration = 30
    n_points = 100
    time_points = np.linspace(0, duration, n_points)

    # Temperature: a 1-day cycle plus one slow cycle over the whole period,
    # clipped to 35-50 °C.
    temperature = (40
                   + 5 * np.sin(2 * np.pi * time_points / 1)
                   + 3 * np.sin(2 * np.pi * time_points / duration))
    temperature = np.clip(temperature, 35, 50)

    # Pollution (μg/m³): slow cycle plus Gaussian noise, clipped to 50-200.
    pollution = (120
                 + 60 * 0.7 * np.sin(2 * np.pi * time_points / duration)
                 + 20 * rng.randn(n_points))
    pollution = np.clip(pollution, 50, 200)

    # Humidity (%): slow cycle opposite in sign to pollution, plus noise,
    # clipped to 30-80.
    humidity = (55
                - 15 * np.sin(2 * np.pi * time_points / duration)
                + 10 * rng.randn(n_points))
    humidity = np.clip(humidity, 30, 80)

    weather_data = {
        'time': time_points,
        'temperature': temperature,
        'pollution': pollution,
        'humidity': humidity,
        'vulnerability': 0.3,
        'awareness': 0.2,
        'hospital_capacity': 250,
        'urban_heat_island': 6.0,
    }

    # Intervention settings per scenario. Only the scenario names (and their
    # order) drive the mock below; the settings themselves are not consumed.
    interventions = {
        'Baseline': {},
        'Public Awareness': {'awareness_boost': 0.4},
        'Hospital Capacity': {'hospital_capacity_increase': 150},
        'Vulnerability Reduction': {'vulnerability_reduction': 0.3},
        'Combined Strategy': {
            'awareness_boost': 0.4,
            'hospital_capacity_increase': 150,
            'vulnerability_reduction': 0.3,
        },
    }

    population = 15000

    # Peak height of the hospitalization curve for each scenario.
    peak_values = {
        'Baseline': 3500,
        'Public Awareness': 2800,
        'Hospital Capacity': 3500,  # Same peak but better outcomes
        'Vulnerability Reduction': 2500,
        'Combined Strategy': 1800,
    }

    # Fraction applied to H when accumulating deaths.
    death_rates = {
        'Baseline': 0.02,
        'Public Awareness': 0.018,
        'Hospital Capacity': 0.015,
        'Vulnerability Reduction': 0.013,
        'Combined Strategy': 0.01,
    }

    results = {}
    # Iterate scenario names only. The previous loop rebound `interventions`
    # to each scenario's settings dict, shadowing the mapping being iterated.
    for scenario in interventions:
        # Interventions delay the peak from day 10 to day 12.
        peak_day = 10 if scenario == 'Baseline' else 12

        # Hospitalization curve: Gaussian bump (std 5 days) centred on peak_day.
        H = peak_values[scenario] * np.exp(-0.5 * ((time_points - peak_day) / 5) ** 2)

        # NOTE(review): cumsum(diff(append(0, H))) telescopes back to H, so S is
        # population - 0.3 * H. It climbs back to `population` after the peak
        # rather than decreasing monotonically. Kept unchanged until the
        # intended cumulative quantity is decided.
        S = population - np.cumsum(np.diff(np.append(0, H))) * 0.3
        S = np.clip(S, 0, population)

        # Cumulative deaths: running sum of H scaled by the scenario death rate.
        D = np.cumsum(H * death_rates[scenario] / 30)

        # NOTE(review): same telescoping as S, so this is 0.3 * H - D, clipped at 0.
        R = np.cumsum(np.diff(np.append(0, H)) * 0.3) - D
        R = np.clip(R, 0, population)

        results[scenario] = {
            'time': time_points,
            'S': S,
            'H': H,
            'R': R,
            'D': D,
        }

    return results, weather_data

def line_chart_hospitalizations(results, figsize=(12, 6)):
    """
    Plot every scenario's hospitalization series ('H') against its 'time' array.

    Each curve's maximum is marked with a dot and annotated with the value
    rounded to a whole number.

    Parameters
    ----------
    results : dict
        Scenario name -> dict containing at least 'time' and 'H' arrays.
    figsize : tuple, optional
        Figure size in inches, passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the chart (not saved or shown here).
    """
    fig, ax = plt.subplots(figsize=figsize)
    palette = sns.color_palette("viridis", len(results))

    for color, (name, series) in zip(palette, results.items()):
        t = series['time']
        admissions = series['H']
        ax.plot(t, admissions, label=name, color=color, linewidth=2.5)

        # Mark and label the highest point of this curve.
        peak_x = t[np.argmax(admissions)]
        peak_y = np.max(admissions)
        ax.scatter(peak_x, peak_y, color=color, s=80, zorder=5)
        ax.annotate(f"{peak_y:.0f}",
                    (peak_x, peak_y),
                    xytext=(5, 5),
                    textcoords='offset points',
                    fontsize=9,
                    fontweight='bold')

    ax.set_xlabel('Days', fontsize=12)
    ax.set_ylabel('Hospital Admissions', fontsize=12)
    ax.set_title('Daily Hospital Admissions by Intervention Strategy', fontsize=14)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def bar_charts_comparison(results, figsize=(18, 10)):
    """Create bar charts for final death toll, peak hospitalizations, and duration of high burden"""
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    scenario_names = list(results.keys())
    colors = sns.color_palette("viridis", len(scenario_names))

    # 1. Final Death Toll
    final_deaths = [results[s]['D'][-1] for s in scenario_names]
    axes[0].bar(scenario_names, final_deaths, color=colors, alpha=0.8)

    # Add value labels
    for i, v in enumerate(final_deaths):
        axes[0].text(i, v + 5, f"{v:.0f}", ha='center', fontweight='bold')

    axes[0].set_ylabel('Total Deaths', fontsize=12)
    axes[0].set_title('Final Death Toll', fontsize=14)
    plt.setp(axes[0].get_xticklabels(), rotation=45, ha='right')
    axes[0].grid(True, alpha=0.3, axis='y')

    # 2. Peak Hospitalizations
    peak_hospitalizations = [np.max(results[s]['H']) for s in scenario_names]
    axes[1].bar(scenario_names, peak_hospitalizations, color=colors, alpha=0.8)

    # Add value labels
    for i, v in enumerate(peak_hospitalizations):
        axes[1].text(i, v + 50, f"{v:.0f}", ha='center', fontweight='bold')

    axes[1].set_ylabel('Peak Hospital Admissions', fontsize=12)
    axes[1].set_title('Peak Hospitalizations', fontsize=14)
    plt.setp(axes[1].get_xticklabels(), rotation=45, ha='right')
    axes[1].grid(True, alpha=0.3, axis='y')

    # 3. Duration of High Burden
    # Define high burden as > 50% of peak baseline
    threshold = 0.5 * np.max(results['Baseline']['H'])
    burden_duration = []

    for scenario in scenario_names:
        # Count days above threshold
        days_above = np.sum(results[scenario]['H'] > threshold)
        # Convert to actual days (assuming 100 points over 30 days)
        duration = days_above * (30 / 100)
        burden_duration.append(duration)

    axes[2].bar(scenario_names, burden_duration, color=colors, alpha=0.8)

    # Add value labels
    for i, v in enumerate(burden_duration):
        axes[2].text(i, v + 0.2, f"{v:.1f}", ha='center', fontweight='bold')

    axes[2].set_ylabel('Days', fontsize=12)
    axes[2].set_title('Duration of High Hospital Burden', fontsize=14)
    plt.setp(axes[2].get_xticklabels(), rotation=45, ha='right')
    axes[2].grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    return fig

def heatmap_visualization(results, weather_data, figsize=(15, 10)):
    """Create heatmaps showing temperature and health impacts over time"""
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(3, 1, height_ratios=[1, 1, 2], hspace=0.3)

    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    ax3 = plt.subplot(gs[2])

    scenario_names = list(results.keys())
    time_points = results[scenario_names[0]]['time']
    n_days = len(time_points)

    # 1. Temperature heatmap
    temperature_data = weather_data['temperature'].reshape(1, -1)
    temp_min, temp_max = 35, 50

    # Create a temperature colormap (cool to hot)
    cmap_temp = LinearSegmentedColormap.from_list('temp_cmap', ['#ADD8E6', '#FFFF00', '#FF0000'])

    im1 = ax1.imshow(temperature_data, aspect='auto', cmap=cmap_temp,
                    extent=[0, 30, 0, 1],
                    vmin=temp_min, vmax=temp_max)

    cbar1 = plt.colorbar(im1, ax=ax1)
    cbar1.set_label('Temperature (°C)')

    ax1.set_title('Daily Temperature Pattern')
    ax1.set_ylabel('')
    ax1.set_yticks([])
    ax1.set_xticks([])

    # 2. Pollution heatmap
    pollution_data = weather_data['pollution'].reshape(1, -1)
    poll_min, poll_max = 50, 200

    # Create a pollution colormap
    cmap_poll = LinearSegmentedColormap.from_list('poll_cmap', ['#FFFFFF', '#888888', '#000000'])

    im2 = ax2.imshow(pollution_data, aspect='auto', cmap=cmap_poll,
                    extent=[0, 30, 0, 1],
                    vmin=poll_min, vmax=poll_max)

    cbar2 = plt.colorbar(im2, ax=ax2)
    cbar2.set_label('Pollution (μg/m³)')

    ax2.set_title('Daily Pollution Levels')
    ax2.set_ylabel('')
    ax2.set_yticks([])
    ax2.set_xticks([])

    # 3. Health impact heatmap
    impact_data = np.zeros((len(scenario_names), n_days))
    for i, scenario in enumerate(scenario_names):
        impact_data[i, :] = results[scenario]['H']

    # Normalize for better visualization
    max_impact = np.max(impact_data)
    normalized_impact = impact_data / max_impact

    # Create impact colormap (low to high impact)
    cmap_impact = LinearSegmentedColormap.from_list('impact_cmap', ['#E8F8F5', '#2E86C1', '#1B2631'])

    im3 = ax3.imshow(normalized_impact, aspect='auto', cmap=cmap_impact,
                   extent=[0, 30, 0, len(scenario_names)])

    cbar3 = plt.colorbar(im3, ax=ax3)
    cbar3.set_label('Hospital Admissions (normalized)')

    ax3.set_title('Intervention Impact on Hospital Admissions Over Time')
    ax3.set_xlabel('Days')
    ax3.set_ylabel('Intervention Strategy')
    ax3.set_yticks(np.arange(0.5, len(scenario_names)))
    ax3.set_yticklabels(scenario_names)

    plt.tight_layout()
    return fig

def combined_visualization(results, weather_data, output_dir='.'):
    """
    Render the line chart, bar charts and heatmaps and save each as a 300-dpi PNG.

    Files written inside ``output_dir`` (created if missing):
    'hospitalization_linechart.png', 'intervention_metrics_barcharts.png'
    and 'heatmap_visualization.png'.

    Parameters
    ----------
    results : dict
        Scenario results as consumed by the three plotting functions.
    weather_data : dict
        Weather arrays consumed by ``heatmap_visualization``.
    output_dir : str or path-like, optional
        Directory for the images. Defaults to the current working directory.
    """
    from pathlib import Path

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    figure_builders = (
        ('hospitalization_linechart.png', lambda: line_chart_hospitalizations(results)),
        ('intervention_metrics_barcharts.png', lambda: bar_charts_comparison(results)),
        ('heatmap_visualization.png', lambda: heatmap_visualization(results, weather_data)),
    )
    for filename, build in figure_builders:
        fig = build()
        try:
            # Save this specific figure rather than whatever pyplot treats as current.
            fig.savefig(out / filename, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if saving fails, so repeated runs don't accumulate figures.
            plt.close(fig)

    print("All visualizations generated successfully!")

# Script entry: build the mock scenario results and synthetic weather, then
# render and save all figures.
results, weather_data = run_model_simulation()
combined_visualization(results=results, weather_data=weather_data)

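# Recorded output of the last notebook run:
# ModuleNotFoundError: No module named 'matplotlib'
# Install the plotting dependencies (e.g. `pip install matplotlib seaborn`) before running this cell.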