In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from pathlib import Path


In [None]:

# This cell expects the following to already exist in the kernel, produced by the
# TLO model data-loading step (they are NOT defined in this notebook section):
#   appointment_delayed_scenarios   - dict mapping scenario name -> per-year values
#   appointment_cancelled_scenarios - dict mapping scenario name -> per-year values
#   scenario_names                  - list of names: 'Baseline' plus the GCM model draws

# ---- Analysis settings ----
ssps = ["ssp245"]
service = "ANC"

# ---- Input / output locations ----
# NOTE(review): absolute local paths; consider making these configurable.
# Downscaled CMIP6 precipitation predictions (one sub-folder per SSP).
precip_base_path = "/Users/rem76/Desktop/Climate_change_health/Data/Precipitation_data/Downscaled_CMIP6_data_CIL"

# TLO model run outputs (appointment disruptions). Figures and CSVs produced
# below are written back into this same run folder.
results_folder = Path('/Users/rem76/PycharmProjects/TLOmodel/outputs/rm916@ic.ac.uk/climate_scenario_runs_lhs_param_scan-2026-01-25T135152Z')
output_folder = results_folder

# Per-SSP central-draw results, recorded so that the summary at the end of this
# cell reports every SSP rather than only the last one processed by the loop.
central_draws_by_ssp = {}

# Row window of the monthly precipitation CSV that is plotted: rows 12..203
# inclusive, i.e. 16 years of monthly rows. Presumably row 0 is January of the
# year before the disruption series starts (2025) -- TODO confirm against the
# CSV's date column.
PRECIP_FIRST_MONTH_ROW = 12
PRECIP_END_MONTH_ROW = 17 * 12  # exclusive

output_folder.mkdir(parents=True, exist_ok=True)

for ssp_scenario in ssps:
    # Every scenario other than 'Baseline' is treated as one GCM model draw.
    draw_scenario_names = [s for s in scenario_names if s != 'Baseline']
    if not draw_scenario_names:
        raise ValueError("scenario_names contains no non-baseline (GCM draw) scenarios")

    # Total disruption (delayed + cancelled) per draw, multiplied by 100 so it can
    # be plotted as a percentage (values presumably stored as fractions -- confirm
    # against the loading code).
    all_draws_total = []
    for scenario_name in draw_scenario_names:
        delayed = np.array(appointment_delayed_scenarios[scenario_name], dtype=float) * 100
        cancelled = np.array(appointment_cancelled_scenarios[scenario_name], dtype=float) * 100
        all_draws_total.append(delayed + cancelled)

    # Fail with a clear message rather than building a ragged array.
    draw_lengths = {len(total) for total in all_draws_total}
    if len(draw_lengths) != 1:
        raise ValueError(
            f"Draws for {ssp_scenario} cover different numbers of years: {sorted(draw_lengths)}"
        )
    all_draws_array = np.array(all_draws_total)  # shape: (n_draws, n_years)

    # Ensemble mean and inter-draw percentile bands, computed per year.
    mean_total = np.mean(all_draws_array, axis=0)
    percentile_25 = np.percentile(all_draws_array, 25, axis=0)
    percentile_75 = np.percentile(all_draws_array, 75, axis=0)
    percentile_10 = np.percentile(all_draws_array, 10, axis=0)
    percentile_90 = np.percentile(all_draws_array, 90, axis=0)

    # Time axis: one year-end timestamp per simulated year, starting in 2025.
    n_years = len(mean_total)
    start_date = pd.date_range(start='2025-01', periods=n_years, freq='YE')

    # "Central 50%": the half of the draws (at least one) whose trajectories have
    # the smallest mean absolute distance from the ensemble mean across years.
    distances = np.mean(np.abs(all_draws_array - mean_total), axis=1)
    sorted_indices = np.argsort(distances)
    n_central = max(1, int(len(sorted_indices) * 0.5))
    central_50_indices = sorted_indices[:n_central]
    central_50_scenario_names = [draw_scenario_names[i] for i in central_50_indices]

    print(f"Scenarios in central 50%: {central_50_scenario_names}")

    # Which draws are central, and how far each is from the ensemble mean.
    central_draws_info = pd.DataFrame({
        'scenario_name': central_50_scenario_names,
        'mean_distance_from_ensemble_mean': distances[central_50_indices]
    })
    central_draws_info.to_csv(output_folder / f'central_50_percent_draws_{ssp_scenario}.csv', index=False)

    # Raw delayed/cancelled series for the central draws: kept in memory and also
    # written out in long format (one row per draw and year).
    central_draws_data = {}
    for scenario_name in central_50_scenario_names:
        central_draws_data[scenario_name] = {
            'delayed': appointment_delayed_scenarios[scenario_name],
            'cancelled': appointment_cancelled_scenarios[scenario_name]
        }
    central_draws_long = pd.DataFrame([
        {'scenario_name': name, 'year': year, 'delayed': delayed_value, 'cancelled': cancelled_value}
        for name, series in central_draws_data.items()
        for year, delayed_value, cancelled_value in zip(
            start_date.year, series['delayed'], series['cancelled']
        )
    ])
    central_draws_long.to_csv(output_folder / f'central_50_percent_draws_data_{ssp_scenario}.csv', index=False)

    central_draws_by_ssp[ssp_scenario] = {'info': central_draws_info, 'data': central_draws_data}

    # Mean monthly precipitation prediction for this SSP and service.
    weather_data_prediction_monthly = pd.read_csv(
        f"{precip_base_path}/{ssp_scenario}/mean_monthly_prediction_weather_by_facility_{service}.csv"
    )
    mask = (weather_data_prediction_monthly.index >= PRECIP_FIRST_MONTH_ROW) & \
           (weather_data_prediction_monthly.index < PRECIP_END_MONTH_ROW)
    weather_data_prediction_monthly = weather_data_prediction_monthly.loc[mask].reset_index(drop=True)
    # Drop the first column (presumably a date/label column -- verify), average the
    # remaining columns (presumably one per facility) for each month, then sum each
    # consecutive block of 12 monthly rows into a yearly total.
    weather_data_avg = weather_data_prediction_monthly.iloc[:, 1:].mean(axis=1)
    yearly_precip_mean = weather_data_avg.groupby(weather_data_avg.index // 12).sum()

    if len(yearly_precip_mean) != n_years:
        raise ValueError(
            f"Precipitation covers {len(yearly_precip_mean)} years but the disruption "
            f"series cover {n_years} years for {ssp_scenario}; check the month-row window."
        )

    fig, ax = plt.subplots(figsize=(12, 6))

    # Precipitation on the primary y-axis (bold line).
    color_precip = '#1C6E8C'
    color_disruption = '#5A716A'
    ax.plot(start_date, yearly_precip_mean, label='Mean Precipitation',
            color=color_precip, linewidth=3, linestyle='-', zorder=5)
    ax.set_xlabel("Year")
    ax.set_ylabel("Cumulative Precipitation (mm)", color=color_precip)
    ax.tick_params(axis='y', labelcolor=color_precip)
    ax.grid(False)

    # Appointment disruption on a secondary y-axis.
    ax2 = ax.twinx()

    # NOTE(review): these bands are percentiles across GCM draws, not statistical
    # confidence intervals; the "CI" legend labels may warrant renaming.
    ax2.fill_between(start_date, percentile_10, percentile_90,
                     alpha=0.15, color=color_disruption, label='80% CI')
    ax2.fill_between(start_date, percentile_25, percentile_75,
                     alpha=0.3, color=color_disruption, label='50% CI')

    # Bold ensemble-mean line for total disruption.
    ax2.plot(start_date, mean_total, label="Mean Total Disrupted",
             linewidth=3, color=color_disruption, zorder=5)
    ax2.set_ylabel("Appointment Disruption (%)", rotation=-90, labelpad=25)

    # One legend covering both axes.
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left', frameon=False)

    # Panel label for the multi-panel figure.
    ax.text(0.0, 1.05, '(C)', transform=ax.transAxes,
            fontsize=14, va='top', ha='right')

    fig.tight_layout()
    fig.savefig(output_folder / f'precipitation_disruption_ci_{ssp_scenario}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so looping over many SSPs does not accumulate open figures.
    plt.close(fig)

# Summary for every SSP processed above.
print("\n=== Central 50% Draws Summary ===")
for ssp, ssp_results in central_draws_by_ssp.items():
    ssp_central_info = ssp_results['info']
    print(f"\n[{ssp}]")
    print(ssp_central_info)
    print(f"\nThese {len(ssp_central_info)} scenarios can be used for further analysis.")