# Importing Necessary Packages

In [None]:
# Standard library imports
import gc
import os
import pickle
import sys
from collections import defaultdict
from datetime import datetime, timedelta
from pathlib import Path

# Third-party scientific computing
import numpy as np
import pandas as pd
from scipy import stats

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Geospatial
import geopandas as gpd

# Machine learning
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Other utilities
import itertools
import joblib

# Custom modules - Add paths
functions_path = Path("/home/mokr/Loss_Functions_Paper/ML_Functions/")
sys.path.append(str(functions_path))

# Custom module imports
import ML_functions
from ML_functions import (
    HydroDataset,
    transform_CMAL_parameters_multi,
    run_ensemble_predictions,
    replace_keys, load_and_unnormalize
)

# NOTE(review): the wildcard import hides where later helpers come from; plotting helpers
# called further down (e.g. plot_hydrograph, generate_random_plots) are presumably
# provided here — TODO confirm and switch to explicit imports.
from ML_Plots import *
# Parenthesised import lists: a trailing comma without surrounding parentheses is a
# SyntaxError, and parentheses let the long lists wrap cleanly.
from ML_Losses import (
    compute_log_likelihood,
    compute_CDF,
    crps_loss,
    get_member_summaries_torch,
)
from ML_Processing import process_ensemble_predictions
from ML_Metrics import (
    calculate_nse,
    calculate_kge,
    compute_crps,
    compute_crps_np,
    calculate_overall_crps,
    calculate_model_crps,
    count_zero_variance,
)

# Loading Data

In [None]:

# Directory holding the pickled forecast results
data_dir = "/home/mokr/Loss_Functions_Paper/forecast_results/"

# Result sets to load; each one lives in "<name>_Part0_all_test.pkl" inside data_dir
variable_names = [
    "ensemble_summaries",
    "crps_per_leadtime",
    "stored_forecasts",
    # "KGE_scores",
    # "NSE_scores",
    # "basin_forecasts",
    "variogram_scores",
    "metadata",
]

# Read every pickle into a dictionary keyed by result-set name
loaded_data = {}
for name in variable_names:
    pickle_path = os.path.join(data_dir, f"{name}_Part0_all_test.pkl")
    with open(pickle_path, "rb") as f:
        loaded_data[name] = pickle.load(f)

# Expose each result set under its own module-level name
# (KGE_scores, NSE_scores and basin_forecasts are currently not loaded)
(
    ensemble_summaries,
    crps_per_leadtime,
    stored_forecasts,
    metadata,
    variogram_scores,
) = (
    loaded_data[key]
    for key in (
        "ensemble_summaries",
        "crps_per_leadtime",
        "stored_forecasts",
        "metadata",
        "variogram_scores",
    )
)


In [None]:
# Translate the stored experiment identifiers into the display names used in figures
key_map = {
    "CRPS": "Conditional",
    "NonBinary": "Probabilistic",
    "Fixed_Seeded": "Seeded (Static)",
    "Non_Fixed_Seeded": "Seeded (Variable)",
}

# Rename the keys of every result dictionary, in the same order as before
(
    ensemble_summaries,
    crps_per_leadtime,
    stored_forecasts,
    variogram_scores,
) = (
    replace_keys(result_dict, key_map)
    for result_dict in (
        ensemble_summaries,
        crps_per_leadtime,
        stored_forecasts,
        variogram_scores,
    )
)

# Canonical model order and the colour assigned to each model
models = ['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic']
colors = ['#2E86AB', '#F18F01', '#06A77D', '#A23B72']
model_colors = {name: colour for name, colour in zip(models, colors)}


In [None]:
# Basin ID of every sample, read straight from the 'basin_idx' field of metadata
basin_ids = metadata['basin_idx']

# One column per lead time (0-based suffix)
lead_columns = [f'Lead_{t}' for t in range(10)]

# Per-model table: index = basin ID, columns = lead times, values = mean CRPS
catchment_crps_dfs = {}
for model in models:
    catchment_crps_dfs[model] = (
        pd.DataFrame(crps_per_leadtime[model], columns=lead_columns)
        .assign(Basin=basin_ids)
        .groupby('Basin')
        .mean()
    )

# --- Example usage: preview the 'Conditional' model ---
print("\nAverage CRPS per Catchment (Conditional Model):")
print(catchment_crps_dfs['Conditional'].head())

catchment_crps_dfs['Conditional']

In [None]:
# Plot the mean CRPS per lead time (lead times 3-10) for every model, with a
# t-distribution confidence band around each mean.
leadtimes = np.arange(3, 11)
print(f"Lead times: {leadtimes}")
Tick_Size = 16

# CRPS column k holds lead time k + 1, so this is the column of the first plotted lead time.
first_col = int(leadtimes[0]) - 1

fig, ax = plt.subplots(figsize=(12, 7))
models = ['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic']
colors = ['#2E86AB', '#F18F01', '#06A77D', '#A23B72']
linestyles = ['-', '--', '-.', ':']
confidence_level = 0.95  # 95% confidence interval

for idx, model in enumerate(models):

    data = crps_per_leadtime[model]  # expected shape: (n_samples, 10)

    # Mean across samples (axis=0), ignoring NaNs, restricted to the plotted lead times
    mean_vals = np.nanmean(data, axis=0)[first_col:]

    # Confidence bounds, one entry per plotted lead time
    ci_lower = np.zeros(len(leadtimes))
    ci_upper = np.zeros(len(leadtimes))

    for i in range(len(leadtimes)):
        # Read the same column the mean was taken from. Indexing with plain `i`
        # would pair the mean of lead time 3 with the spread of lead time 1, etc.
        leadtime_data = data[:, first_col + i]
        leadtime_data = leadtime_data[~np.isnan(leadtime_data)]

        if len(leadtime_data) > 1:
            # Standard error of the mean
            sem = stats.sem(leadtime_data)
            n = len(leadtime_data)

            # Two-sided t critical value for the requested confidence level
            t_critical = stats.t.ppf(1 - (1 - confidence_level) / 2, df=n - 1)

            # Half-width of the confidence interval
            margin = t_critical * sem

            ci_lower[i] = mean_vals[i] - margin
            ci_upper[i] = mean_vals[i] + margin

        else:
            # Not enough data for an interval: collapse the band onto the mean
            ci_lower[i] = mean_vals[i]
            ci_upper[i] = mean_vals[i]
            print(f"    WARNING: Not enough data for CI (n={len(leadtime_data)})")

    # Confidence interval as a shaded region
    ax.fill_between(leadtimes, ci_lower, ci_upper,
                    color=colors[idx], alpha=0.2)

    # Mean line
    ax.plot(leadtimes, mean_vals,
            color=colors[idx],
            linestyle=linestyles[idx],
            linewidth=2.5,
            label=model,
            marker='o',
            markersize=6)

    print(f"✓ Successfully plotted {model}")

print(f"\n{'='*60}")
print("Setting up plot formatting...")
ax.set_xlabel('Lead Time', fontsize=18, fontweight='bold')
ax.set_ylabel('CRPS', fontsize=18, fontweight='bold')
ax.set_title('CRPS by Lead Time (with 95% Confidence Intervals)',
             fontsize=22, fontweight='bold', pad=20)
ax.legend(loc='best', fontsize=16, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle='--')
ax.tick_params(axis='both', which='major', labelsize=Tick_Size)
ax.set_xticks(leadtimes)
plt.tight_layout()
print("Displaying plot...")
plt.show()
# print("Done!")

In [None]:

# Rank-histogram configuration
num_leadtimes = 10
num_members = 11
num_bins = num_members + 1  # 12 bins for 11 members


def _rank_histogram(member_values, observed, n_bins):
    """Return the normalised rank histogram of `observed` within `member_values`.

    Samples with a NaN in any member or in the observation are dropped. The rank
    of a sample is the number of ensemble members strictly below the observation.
    """
    usable = ~np.isnan(member_values).any(axis=1) & ~np.isnan(observed)
    member_values = member_values[usable]
    observed = observed[usable]

    ranks = np.sum(member_values < observed[:, np.newaxis], axis=1)

    # Unit-width bins 0, 1, ..., n_bins - 1
    hist, _ = np.histogram(ranks, bins=np.arange(n_bins + 1), density=True)
    return hist


# One panel per lead time, arranged in a 2x5 grid
fig = plt.figure(figsize=(25, 12))
fig.suptitle('Rank Histograms by Lead Time', fontsize=28, fontweight='bold', y=0.98)
axes = fig.subplots(2, 5).flatten()

# Handles/labels collected once for the shared figure legend
global_handles = []
global_labels = []

# Bar layout shared by all panels
num_models = len(models)
x = np.arange(num_bins)
width = 0.8 / num_models
ideal_probability = 1.0 / num_bins

for leadtime_idx, ax in enumerate(axes[:num_leadtimes]):
    first_panel = leadtime_idx == 0

    # Rank distribution of every model at this lead time
    model_rank_distributions = {
        model: _rank_histogram(
            stored_forecasts[model][:, :, leadtime_idx],       # (n_samples, n_members)
            stored_forecasts['Discharge'][:, 0, leadtime_idx],  # (n_samples,)
            num_bins,
        )
        for model in models
    }

    # Grouped bars, one group per rank and one bar per model
    for j, model in enumerate(models):
        bar = ax.bar(
            x + j * width,
            model_rank_distributions[model],
            width=width,
            label=model,
            color=colors[j],
            edgecolor='black',
            linewidth=0.5,
            alpha=0.8,
        )
        if first_panel:
            global_handles.append(bar)
            global_labels.append(model)

    # Reference level of a perfectly uniform rank distribution
    uniform_line = ax.axhline(y=ideal_probability, color='red', linestyle='--',
                              linewidth=2, alpha=0.7, zorder=10, label='Uniform')
    if first_panel:
        global_handles.append(uniform_line)
        global_labels.append('Uniform')

    # Panel formatting
    ax.set_title(f"Lead Time {leadtime_idx + 1}", fontsize=18, fontweight='bold', pad=10)
    ax.set_xlabel("Ensemble Members Below Observation", fontsize=14, labelpad=5)
    ax.set_ylabel("Probability", fontsize=14, fontweight='bold')
    ax.set_xticks(x + width * (num_models - 1) / 2)
    ax.set_xticklabels(range(num_bins), fontsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0, 0.35)  # Adjust based on your data

    # Hide any per-panel legend; the figure-level legend below replaces it
    if ax.get_legend():
        ax.legend().set_visible(False)

# Single legend at the top of the figure
fig.legend(handles=global_handles, labels=global_labels,
           loc='upper center', bbox_to_anchor=(0.5, 0.95),
           ncol=len(global_labels), fontsize=16, frameon=False,
           fancybox=False, shadow=True)

plt.tight_layout()
plt.subplots_adjust(top=0.85, hspace=0.3, wspace=0.3)
plt.show()

# Diagnosing Spread Issues

In [None]:
# CRPS magnitude bands, each half-open [lower, upper); the last band is unbounded above
intervals = [(0, 0.1), (0.1, 1), (1, 10), (10, np.inf)]
interval_labels = ['0-0.1', '0.1-1', '1-10', '10+']
bands = list(zip(intervals, interval_labels))


def _valid_crps(model):
    """Return all non-NaN CRPS values of `model` as one flat array."""
    flat = crps_per_leadtime[model].flatten()
    return flat[~np.isnan(flat)]


def _band_mask(values, lower, upper):
    """Return a boolean mask of `values` in [lower, upper), or >= lower for an infinite upper bound."""
    if upper == np.inf:
        return values >= lower
    return (values >= lower) & (values < upper)


print("CRPS Distribution Across Models")
print("="*80)

# Per-model counts and percentages for every band
results = {}

for model in models:
    all_crps = _valid_crps(model)
    total_count = len(all_crps)

    print(f"\n{model}:")
    print(f"  Total predictions: {total_count}")

    counts = []
    percentages = []

    for (lower, upper), label in bands:
        count = np.sum(_band_mask(all_crps, lower, upper))
        pct = (count / total_count * 100) if total_count > 0 else 0
        counts.append(count)
        percentages.append(pct)

        print(f"  {label:>10}: {count:6d} ({pct:5.1f}%)")

    results[model] = {
        'counts': counts,
        'percentages': percentages,
        'total': total_count
    }

# Summary table of raw counts
print("\n" + "="*80)
print("\nSUMMARY TABLE (Counts)")
print("-"*80)

count_df = pd.DataFrame(
    {model: results[model]['counts'] for model in models},
    index=interval_labels,
)
print(count_df.to_string())

# Summary table of percentages
print("\n" + "-"*80)
print("\nSUMMARY TABLE (Percentages)")
print("-"*80)

pct_df = pd.DataFrame(
    {model: [f"{p:.1f}%" for p in results[model]['percentages']] for model in models},
    index=interval_labels,
)
print(pct_df.to_string())

# Mean CRPS inside every band and its contribution to the overall mean
print("\n" + "="*80)
print("\nMEAN CRPS WITHIN EACH INTERVAL (per model)")
print("-"*80)

table_rule = f"  {'-'*10}-+-{'-'*8}-+-{'-'*6}-+-{'-'*10}-+-{'-'*12}"

for model in models:
    all_crps = _valid_crps(model)
    overall_mean = np.mean(all_crps)
    total_count = len(all_crps)

    print(f"\n{model}:")
    print(f"  Overall mean CRPS: {overall_mean:.2f}")
    print(f"  Total predictions: {total_count}")
    print(f"  {'Interval':>10} | {'Mean':>8} | {'Count':>6} | {'% of total':>10} | {'Contribution':>12}")
    print(table_rule)

    total_contribution = 0

    for (lower, upper), label in bands:
        interval_data = all_crps[_band_mask(all_crps, lower, upper)]

        if len(interval_data) == 0:
            print(f"  {label:>10} | {'no data':>8} | {0:6d} | {0:9.1f}% | {0:12.2f}")
            continue

        mean_crps = np.mean(interval_data)
        count = len(interval_data)
        percentage = (count / total_count) * 100

        # Contribution to the overall mean = share of samples * mean within the band
        contribution = (count / total_count) * mean_crps
        total_contribution += contribution

        print(f"  {label:>10} | {mean_crps:8.2f} | {count:6d} | {percentage:9.1f}% | {contribution:12.2f}")

    print(table_rule)
    print(f"  {'TOTAL':>10} | {overall_mean:8.2f} | {total_count:6d} | {100.0:9.1f}% | {total_contribution:12.2f}")

    # The band contributions should add up to the overall mean
    if abs(total_contribution - overall_mean) > 0.01:
        print(f"  WARNING: Contribution sum ({total_contribution:.2f}) doesn't match overall mean ({overall_mean:.2f})")


# Rank the models inside every band by their mean CRPS
print("\n" + "="*80)
print("\nHYPOTHESIS CHECK: Does your best model perform relatively better at high errors?")
print("-"*80)

for (lower, upper), label in bands:
    print(f"\n{label} interval:")

    interval_means = {}
    interval_contributions = {}

    for model in models:
        all_crps = _valid_crps(model)
        total_count = len(all_crps)
        interval_data = all_crps[_band_mask(all_crps, lower, upper)]

        if len(interval_data) > 0:
            interval_means[model] = np.mean(interval_data)
            interval_contributions[model] = (len(interval_data) / total_count) * interval_means[model]
        else:
            interval_means[model] = np.nan
            interval_contributions[model] = 0

    # Lowest mean first; models without data in this band sort last
    sorted_models = sorted(
        interval_means.items(),
        key=lambda item: np.inf if np.isnan(item[1]) else item[1],
    )

    for rank, (model, mean_val) in enumerate(sorted_models, 1):
        if np.isnan(mean_val):
            continue
        contrib = interval_contributions[model]
        print(f"  {rank}. {model:25s}: mean = {mean_val:8.2f}, contribution = {contrib:8.2f}")

# Which band dominates each model's overall CRPS?
print("\n" + "="*80)
print("\nWHICH INTERVAL CONTRIBUTES MOST TO OVERALL CRPS?")
print("-"*80)

for model in models:
    all_crps = _valid_crps(model)
    total_count = len(all_crps)
    overall_mean = np.mean(all_crps)

    contributions = []

    for (lower, upper), label in bands:
        interval_data = all_crps[_band_mask(all_crps, lower, upper)]

        if len(interval_data) > 0:
            contribution = (len(interval_data) / total_count) * np.mean(interval_data)
            contributions.append((label, contribution))
        else:
            contributions.append((label, 0))

    # Largest contribution first
    contributions.sort(key=lambda entry: entry[1], reverse=True)

    print(f"\n{model}:")
    for rank, (label, contrib) in enumerate(contributions, 1):
        pct_of_total = (contrib / overall_mean) * 100 if overall_mean > 0 else 0
        print(f"  {rank}. {label:>10}: {contrib:8.2f} ({pct_of_total:5.1f}% of total CRPS)")
        


# Making CDF Plots

In [None]:
# Basin ID of every sample, in stored order
catchments = np.array(metadata['basin_idx'])

# Start index of every run of identical basin IDs, followed by the end sentinel.
# NOTE(review): this treats each contiguous run as one catchment — presumably
# samples of a basin are stored contiguously; confirm.
change_points = np.where(catchments[:-1] != catchments[1:])[0] + 1
boundaries = np.concatenate([[0], change_points, [len(catchments)]])

# Catchment IDs in order of appearance (one per run)
unique_catchments = [catchments[run_start] for run_start in boundaries[:-1]]

# Long-format records: one row per (model, catchment, lead time)
results = []

for model in models:
    data = crps_per_leadtime[model]  # (n_samples, 10)

    for catchment, start, end in zip(unique_catchments, boundaries[:-1], boundaries[1:]):
        # All lead times of this catchment in one slice
        catchment_data = data[start:end, :]

        mean_crps = np.nanmean(catchment_data, axis=0)
        counts = np.sum(~np.isnan(catchment_data), axis=0)

        results.extend(
            {
                'model': model,
                'catchment': catchment,
                'leadtime': leadtime,
                'mean_crps': mean_crps[leadtime - 1],
                'n_samples': counts[leadtime - 1],
            }
            for leadtime in range(1, 11)
        )

df_catchment = pd.DataFrame(results)

# Wide format: one row per (model, catchment), one column per lead time
df_pivot = df_catchment.pivot_table(
    index=['model', 'catchment'],
    columns='leadtime',
    values='mean_crps'
)
df_pivot.columns = [f'leadtime_{i}' for i in df_pivot.columns]
df_pivot = df_pivot.reset_index()

print("Long format DataFrame shape:", df_catchment.shape)


In [None]:
# Empirical CDFs of the catchment-averaged CRPS, one panel per selected lead time
Title_Size = 28
leadtimes = [3, 10]

fig, axes = plt.subplots(1, 2, figsize=(24, 10))
axes = axes.flatten()

colors = ['#2E86AB', '#F18F01', '#06A77D', '#A23B72']
linestyles = ['-', '--', '-.', ':']
linewidth = 2.5

for ax, leadtime in zip(axes, leadtimes):
    col = f'leadtime_{leadtime}'

    for model_idx, model in enumerate(models):
        # Catchment-averaged CRPS of this model at this lead time
        model_data = df_pivot[df_pivot['model'] == model]
        data = model_data[col].dropna().values

        if len(data) == 0:
            continue

        # Sorted values against their cumulative share give the empirical CDF
        sorted_data = np.sort(data)
        cum_prob = np.arange(1, len(sorted_data) + 1) / len(sorted_data)

        ax.plot(sorted_data, cum_prob,
                color=colors[model_idx],
                linestyle=linestyles[model_idx],
                linewidth=linewidth,
                alpha=0.85)

    ax.set_xlabel('CRPS (Catchment Average)', fontsize=Title_Size, fontweight='bold')
    ax.set_ylabel('Cumulative Probability', fontsize=Title_Size, fontweight='bold')
    ax.set_title(f'Lead Time {leadtime}', fontsize=Title_Size, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim(0, 1)
    ax.set_xscale('log')

    # Larger tick labels
    ax.tick_params(axis='both', which='major', labelsize=24)
    # ax.tick_params(axis='both', which='minor', labelsize=16)

# Proxy artists so the shared legend shows each model's colour and line style
legend_elements = [
    Line2D([0], [0], color=colors[i], linestyle=linestyles[i],
           linewidth=linewidth, label=model_label)
    for i, model_label in enumerate(models)
]

# Shared legend at the top of the figure
fig.legend(handles=legend_elements, loc='upper center', ncol=4,
           frameon=False, fontsize=30, bbox_to_anchor=(0.5, 0.984))

plt.suptitle('Cumulative Distribution Functions of Catchment-Averaged CRPS by Lead Time',
             fontsize=36, fontweight='bold', y=0.998)
plt.tight_layout(rect=[0, 0, 1, 0.98])  # Leave space at top for legend
plt.show()


# Getting hydrographs over 3 years

In [None]:
# Load the discharge scalers that were saved with joblib.
# NOTE(review): presumably needed to un-normalise discharge values; no use of
# `scalers` is visible in this section — confirm it is still required.
scaler_path = '/home/mokr/Loss_Functions_Paper/Scalers/discharge_caravan_scalers.joblib'
scalers = joblib.load(scaler_path)

In [None]:

# Split forecasts, observations, metadata, dates and per-sample CRPS by basin.

# Each metadata record is read as (timestamp string, basin ID, ...)
basin_ids = np.array([meta[1] for meta in metadata])
dates = np.array([datetime.strptime(meta[0], '%Y-%m-%d %H:%M:%S') for meta in metadata])

# Unique basins in order of first appearance (dict keys keep insertion order)
unique_basins = list(dict.fromkeys(basin_ids))

forecasts_by_model_and_basin = {model: {} for model in models}
average_crps_by_basin = {model: {} for model in models}
truth_by_basin = {}
metadata_by_basin = {}
dates_by_basin = {}

discharge = stored_forecasts['Discharge']

for basin in unique_basins:
    # Build the sample mask once per basin and reuse it for every split
    mask = basin_ids == basin

    # Observations: first entry along axis 1, all lead times
    truth_by_basin[basin] = discharge[mask][:, 0, :]
    metadata_by_basin[basin] = metadata[mask]
    dates_by_basin[basin] = dates[mask]

    for model in models:
        forecasts_by_model_and_basin[model][basin] = stored_forecasts[model][mask]
        # Per-sample CRPS averaged over lead times. The earlier scalar-per-basin
        # version was overwritten by this one and has been removed.
        # NOTE(review): np.mean yields NaN for a sample with any NaN lead time,
        # whereas other cells use np.nanmean — confirm this is intended.
        average_crps_by_basin[model][basin] = np.mean(crps_per_leadtime[model][mask], axis=1)

print(f"Scaled example: {forecasts_by_model_and_basin[models[0]][unique_basins[0]][0, 0, 0]}")

In [None]:
# Sanity check: show the first metadata record of one basin
metadata_by_basin['camelscl_5707002'][0]

In [None]:
# Draw one plot for a fixed basin and sample index via generate_random_plots.
# NOTE(review): basin_idx / sample_idx presumably override the helper's random
# selection — confirm against its signature.
plots = generate_random_plots(
    forecasts_by_model_and_basin, 
    truth_by_basin,
    metadata_by_basin,
    models=models,
    colors=colors,
    n_plots=1,
    basin_idx = 'hysets_0422026250',
    sample_idx = 727,
    average_crps_by_basin=average_crps_by_basin
)
# Basin/sample combinations recorded from earlier runs:
# Plot 1: Basin = hysets_01589290, Sample = 40
# Plot 6: Basin = hysets_0422026250, Sample = 727


In [None]:
# Sanity check: dimensions of the Conditional model's CRPS array
crps_per_leadtime['Conditional'].shape


In [None]:
# Hydrograph grid for the Conditional model: two basins x two lead times (2x2)
basins = ['hysets_11042631', 'camelscl_5707002']
leadtimes = [2, 9]  # Leadtime 3 (index 2) and Leadtime 10 (index 9)

plot_hydrograph_grid(
    forecasts_by_model_and_basin, 
    truth_by_basin,
    dates_by_basin,
    basins, 
    'Conditional', 
    leadtimes
)

In [None]:
# NOTE(review): `basin` is assigned here but the call below passes a different,
# hard-coded basin ID — confirm which basin is intended.
basin = 'hysets_06177500' # Flashy arid hysets_08330600, hysets_11042631

# Hydrograph of a single basin for the Conditional model at lead-time index 2
plot_hydrograph(
    forecasts_by_model_and_basin,
    truth_by_basin,
    dates_by_basin,
    basin='hysets_07311900',
    model='Conditional',
    leadtime_idx=2
)

 # camelscl_5707002 Probabilistic

# Recording the overall scores of each model

- NSE, KGE, autoregression prediction, etc. scores for random ensemble members
- CRPS scores for each lead time
- CRPS scores for how well each model predicts autoregression, mean flow, variance of flow, etc.

In [None]:
# CRPS of the Conditional model's ensemble summaries for three summary metrics
model_crps = calculate_model_crps(
    ensemble_summaries, 
    model_names=['Conditional'],
    metrics=['total_flow', 'variance', 'autoregression'])

In [None]:
# Display the result of the previous cell
model_crps

In [None]:
# Summary-metric CRPS of every model, reported as "mean ± CI"
model_names = ['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic']
metrics = ['total_flow', 'variance', 'autoregression', 'gamma', 'num_rise']

model_crps_mean, model_crps_ci = calculate_crps_with_ci(
    ensemble_summaries, model_names, metrics, confidence=0.95
)

# First column holds the metric labels; one further column is added per model
comparison_data = {'CRPS Metric': ['Total Flow', 'Variance', 'Autoregression', 'Gamma', 'Num Rise']}

for model_name in model_names:
    # Skip models for which no result was returned
    if model_name not in model_crps_mean:
        continue
    means = model_crps_mean[model_name]
    half_widths = model_crps_ci[model_name]
    comparison_data[model_name] = [
        f"{means[metric]:.3g} ± {half_widths[metric]:.3g}" for metric in metrics
    ]

comparison_df = pd.DataFrame(comparison_data)

# Transposed view: one row per model
comparison_df.T

In [None]:
# Re-display the transposed comparison table (one row per model)
comparison_df.T

In [None]:
# Summary-metric CRPS of every model as plain numbers (3 significant figures)
model_names = ['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic']
metrics = ['total_flow', 'variance', 'autoregression', 'gamma', 'num_rise']

# All models and metrics in a single call
model_crps = calculate_model_crps(ensemble_summaries, model_names, metrics)

# First column holds the metric labels; one further column is added per model
comparison_data = {'CRPS Metric': ['Total Flow', 'Variance', 'Autoregression',  'Gamma', 'Num Rise']}

for model_name in model_names:
    # Skip models for which no result was returned
    if model_name not in model_crps:
        continue
    per_metric = model_crps[model_name]
    # Round-trip through a "%.3g" string to keep 3 significant figures
    comparison_data[model_name] = [float(f"{per_metric[metric]:.3g}") for metric in metrics]

comparison_df = pd.DataFrame(comparison_data)
pd.set_option('display.precision', 4)

# Transposed view: one row per model
comparison_df.T


In [None]:
# Per-example (not averaged) CRPS of the summary metrics for every model
model_crps_per_example = calculate_model_crps(
    ensemble_summaries, 
    model_names=['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic'],  # Add your model names
    metrics=['total_flow', 'autoregression', 'gamma', 'num_rise'],  
    return_mean=False
)


# One DataFrame per metric. `model_names` comes from an earlier cell and is passed
# twice. NOTE(review): presumably once as lookup keys and once as column/display
# names — confirm against create_metric_crps_df.
flow_crps_df = create_metric_crps_df('total_flow', model_names, model_names, model_crps_per_example)
autoregression_crps_df = create_metric_crps_df('autoregression', model_names, model_names, model_crps_per_example)
gamma_crps_df = create_metric_crps_df('gamma', model_names, model_names, model_crps_per_example)
num_rise_crps_df = create_metric_crps_df('num_rise', model_names, model_names, model_crps_per_example)




print(flow_crps_df.shape)  # Should be (695505, n_models)
flow_crps_df.head()

# Plotting Spread

In [None]:
# NOTE(review): `bins` is not used anywhere in the visible code — confirm it is needed.
bins = 11
# Dispersion table built from the ensemble summaries, with the 'Discharge' entry
# passed as the second argument
dispersion_df = get_dispersion_calculations(ensemble_summaries, ensemble_summaries['Discharge'])

# Summary metrics shown in the rank-histogram figure
metrics = ['total_flow', 'autoregression', 'gamma', 'num_rise']


In [None]:
# Remove every column whose name mentions 'Discharge'
discharge_columns = [name for name in dispersion_df.columns if 'Discharge' in name]
dispersion_df = dispersion_df.drop(columns=discharge_columns)

In [None]:
# Rank histograms of the summary metrics: one panel per metric, one bar group per
# rank bin and one bar per model, with a shared legend at the top of the figure.
fig = plt.figure(figsize=(16, 15))
fig.suptitle('Rank Histograms of Temporal Characteristics', fontsize=28, fontweight='bold', y=0.93)

# 2x2 grid of panels, flattened so it can be indexed by metric position
axes = fig.subplots(2, 2)
axes = axes.flatten()
models = ['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic']
colors = ['#2E86AB', '#F18F01', '#06A77D', '#A23B72']

# Handles and labels collected for the figure-level legend
global_handles = []
global_labels = []
Tick_Size = 20
Axis_Size = 22

for i, metric in enumerate(metrics):
    ax = axes[i]
    
    # Columns belonging to the current metric.
    # NOTE(review): this is a substring match, so a metric name contained in
    # another column name would also be picked up — confirm the column naming.
    metric_columns = [col for col in dispersion_df.columns if metric in col]
    
    # NOTE(review): this rebinds the module-level `model_names`, which later cells
    # pass to create_metric_crps_df — confirm that is intended.
    model_names = []
    metric_values = []
    
    for col in metric_columns:
        # Model name = column name with the "_<metric>" part removed
        model_name = col.replace(f'_{metric}', '')
        
        # Use the display-name mapping only if a `model_name_mapping` variable exists
        display_name = model_name_mapping.get(model_name, model_name) if 'model_name_mapping' in locals() else model_name
            
        model_names.append(display_name)
        # NOTE(review): presumably the table stores percentages and this converts
        # them to probabilities — confirm.
        metric_values.append(dispersion_df[col].values / 100) 
    
    # Rank bins come from the table's index
    rank_bins = dispersion_df.index.tolist()
    
    # Grouped-bar layout: the bars of one group share a total width of 0.8
    num_models = len(model_names)
    x = np.arange(len(rank_bins)) 
    width = 0.8 / num_models 
    
    for j, values in enumerate(metric_values):
        # Keep the bar container so it can be used as a legend handle.
        # NOTE(review): colours follow column order, not model name — confirm the
        # columns are ordered like `models` so colours match the other figures.
        bar = ax.bar(
            x + j * width, 
            values, 
            width=width, 
            label=model_names[j], # Label is assigned here
            color=colors[j % len(colors)],
            edgecolor='black',
            linewidth=0.5
        )
        
        # Store handles/labels ONLY from the first subplot to avoid duplicates
        if i == 0:
            global_handles.append(bar)
            global_labels.append(model_names[j])
    
    # Reference level of a uniform rank distribution, labelled 'Uniform'
    ideal_probability = 1.0 / len(rank_bins)
    uniform_line = ax.axhline(y=ideal_probability, color='red', linestyle='--', linewidth=2, 
               alpha=0.7, zorder=10, label='Uniform')
    
    # Add the Uniform line to the global legend lists (only once)
    if i == 0:
        global_handles.append(uniform_line)
        global_labels.append('Uniform')

    # Panel formatting
    metric_title = metric.replace('_', ' ').title()
    ax.set_title(f"{metric_title}", fontsize=22, fontweight='bold', pad=10)
    ax.set_xlabel("Ensemble Members Below Observation", fontsize=Axis_Size, labelpad=1)
    ax.set_ylabel("Probability", fontsize=Axis_Size, fontweight='bold')
    # Centre each tick under its bar group
    ax.set_xticks(x + width * (num_models - 1) / 2)
    ax.set_xticklabels(rank_bins, fontsize=Tick_Size)
    ax.tick_params(axis='y', labelsize=Tick_Size)
    ax.grid(True, alpha=0.3, axis='y')
    
    # Hide any per-panel legend; the figure-level legend below replaces it
    if ax.get_legend():
        ax.legend().set_visible(False)

# Same y-range on every panel
for ax in axes:
    ax.set_ylim(0, 0.6)

# Single legend at the top, built from the collected handles
# (one column per model plus one for the Uniform line)
fig.legend(handles=global_handles, labels=global_labels, 
           loc='upper center', bbox_to_anchor=(0.5, 0.9), 
           ncol=len(global_labels), fontsize=20, frameon=False, 
           fancybox=False, shadow=True)

plt.tight_layout()
plt.subplots_adjust(top=0.82, hspace=0.25)
plt.show()

In [None]:
# NOTE(review): both calls repeat computations made in earlier cells — confirm
# this cell is still needed.
# CRPS of the Conditional model for three summary metrics
model_crps = calculate_model_crps(
    ensemble_summaries, 
    model_names=['Conditional'],
    metrics=['total_flow', 'variance', 'autoregression'])

# Per-example (not averaged) CRPS of the summary metrics for every model
model_crps_per_example = calculate_model_crps(
    ensemble_summaries, 
    model_names=['Conditional', 'Seeded (Static)', 'Seeded (Variable)', 'Probabilistic'],  # Add your model names
    metrics=['total_flow', 'autoregression', 'gamma', 'num_rise'],  
    return_mean=False
)


In [None]:
# Display the basin ID of every sample
metadata['basin_idx']

In [None]:


# One per-example CRPS DataFrame per summary metric.
# NOTE(review): `model_names` is whatever an earlier cell last bound it to (the
# rank-histogram cell rebinds it) — confirm it holds the intended model list here.
flow_crps_df = create_metric_crps_df('total_flow', model_names, model_names, model_crps_per_example)
autoregression_crps_df = create_metric_crps_df('autoregression', model_names, model_names, model_crps_per_example)
gamma_crps_df = create_metric_crps_df('gamma', model_names, model_names, model_crps_per_example)
num_rise_crps_df = create_metric_crps_df('num_rise', model_names, model_names, model_crps_per_example)




print(flow_crps_df.shape)  # Should be (695505, n_models)
flow_crps_df.head()

# Making CRPS Map

In [None]:
# 1. Tag every per-example CRPS table with the basin of each sample
basins = metadata['basin_idx']
for per_metric_df in (flow_crps_df, autoregression_crps_df, gamma_crps_df, num_rise_crps_df):
    per_metric_df['Basin'] = basins


def _mean_over_forecasts(cells):
    """Stack the arrays of one group into a matrix and average over its rows (forecasts)."""
    return np.mean(np.stack(cells), axis=0)


# 2. Average the per-forecast arrays within every basin
grouped = flow_crps_df.groupby('Basin', as_index=False).agg(_mean_over_forecasts)

# 3. Long format: one row per (basin, model), then one row per array element
melted = grouped.melt(id_vars='Basin', var_name='Model', value_name='CRPS')
long_flow_df = melted.explode('CRPS')

# 4. Lead time = 1-based position of the element within its (basin, model) group
long_flow_df['Leadtime'] = long_flow_df.groupby(['Basin', 'Model']).cumcount() + 1

# explode can leave the column as object dtype, so force it back to float
long_flow_df['CRPS'] = long_flow_df['CRPS'].astype(float)

print(long_flow_df.head())

In [None]:
# Convert each per-example CRPS table to its long-format counterpart.
# The (table, metric name) pairs are processed in the order listed.
long_flow_df, long_autoregression_df, long_gamma_df, long_num_rise_df = (
    process_crps_df(metric_table, basins, metric_key)
    for metric_table, metric_key in (
        (flow_crps_df, 'total_flow'),
        (autoregression_crps_df, 'autoregression'),
        (gamma_crps_df, 'gamma'),
        (num_rise_crps_df, 'num_rise'),
    )
)
    

In [None]:
# Keep only the rows of the total-flow table that belong to the Conditional model.
is_conditional_row = long_flow_df['Model'] == 'Conditional'
Conditional_Flow_DF = long_flow_df[is_conditional_row]

In [None]:
# Build CRPS tables from the per-lead-time CRPS values and the basin ids.
# The result is looked up by model name further down
# (basin_tables['Conditional'], basin_tables['Probabilistic']).
basin_tables = create_basin_crps_tables(crps_per_leadtime, basins)


In [None]:
# Attach coordinates to every basin-level table. The attributes directory is
# named once so all six calls are guaranteed to read from the same location.
caravan_attributes_path = "/perm/mokr/Caravans/Caravan/attributes/"

# Wide per-basin CRPS tables for the two models compared on the maps.
CRPS_df = add_coordinates_to_basin_table(basin_tables['Conditional'], csv_base_path=caravan_attributes_path)
Probabilistic_df = add_coordinates_to_basin_table(basin_tables['Probabilistic'], csv_base_path=caravan_attributes_path)

# Long-format tables, one per summary metric.
Flow_df = add_coordinates_to_basin_table(long_flow_df, csv_base_path=caravan_attributes_path)
Autoregression_df = add_coordinates_to_basin_table(long_autoregression_df, csv_base_path=caravan_attributes_path)
Gamma_df = add_coordinates_to_basin_table(long_gamma_df, csv_base_path=caravan_attributes_path)
NumRise_df = add_coordinates_to_basin_table(long_num_rise_df, csv_base_path=caravan_attributes_path)

In [None]:
# One dictionary per metric table, built in the order listed. The dictionaries
# are later indexed by model name (e.g. Flow_dict['Conditional']).
Flow_dict, Autoregression_dict, Gamma_dict, NumRise_dict = map(
    create_model_dict,
    (Flow_df, Autoregression_df, Gamma_df, NumRise_df),
)

In [None]:
# Country polygons drawn as the grey background of the maps below
# (presumably the Natural Earth 1:110m admin-0 layer, judging by the file name).
shapefile_path = (
    "/home/mokr/GLOFAS_ML_Flood_Modelling/data/"
    "ne_110m_admin_0_countries/raw/ne_110m_admin_0_countries.shp"
)
world = gpd.read_file(shapefile_path)

In [None]:
# Pair the Conditional table (CRPS_df) with the Probabilistic table at a single
# lead time. Rows are matched on basin id AND coordinates; with the default
# inner merge only basins present in both tables are kept.
leadtime = 3
leadtime_col = f'Leadtime_{leadtime}'
key_cols = ['Basin', 'Latitude', 'Longitude']
merged = CRPS_df[key_cols + [leadtime_col]].merge(
    Probabilistic_df[key_cols + [leadtime_col]],
    on=key_cols,
    suffixes=("_CRPS", "_Prob")
)

# Skill score of the Conditional model relative to the Probabilistic one:
#   1 - CRPS_conditional / CRPS_probabilistic
# Positive values mean the Conditional model has the lower CRPS.
# NOTE(review): a zero Probabilistic CRPS would yield inf/NaN here — confirm it cannot occur.
merged['difference'] = 1 - (merged[f'{leadtime_col}_CRPS'] / merged[f'{leadtime_col}_Prob'])
# Plain CRPS difference (Conditional minus Probabilistic).
merged['abs_difference'] = merged[f'{leadtime_col}_CRPS'] - merged[f'{leadtime_col}_Prob']

# Convert to a GeoDataFrame of points in geographic coordinates (lon/lat, EPSG:4326).
gdf_diff = gpd.GeoDataFrame(
    merged,
    geometry=gpd.points_from_xy(merged.Longitude, merged.Latitude),
    crs="EPSG:4326"
)

In [None]:
# Quick inspection of the per-model total-flow tables.
# NOTE(review): a notebook cell only renders the value of its last expression,
# so the first two lines below produce no visible output.
Flow_dict['Conditional']
Flow_dict['Probabilistic']
# Rows of the Conditional table whose 'Location' equals 'camelscl'.
Flow_dict['Conditional'][Flow_dict['Conditional']['Location'] == 'camelscl']

In [None]:
# Show the merged table ordered by ascending 'difference'.
# The result is displayed only; `merged` itself is left unsorted.
merged.sort_values('difference')

In [None]:


# Empirical CDF and summary statistics of the skill score in merged['difference'].
# NaN entries are dropped once, up front, so the curve, the median and the
# printed statistics all describe the same set of basins: np.sort/np.median
# propagate NaN, whereas Series.mean/std/min/max silently skip it.
diff_values = merged['difference'].dropna()

# Sort the differences
sorted_diff = np.sort(diff_values)

# Cumulative probability of each sorted value: i / n for i = 1..n
cumulative_prob = np.arange(1, len(sorted_diff) + 1) / len(sorted_diff)

# Create the CDF plot
plt.figure(figsize=(10, 6))
plt.plot(sorted_diff, cumulative_prob, linewidth=2)
plt.xlabel('Difference (1 - CRPS/Probabilistic)', fontsize=12)
plt.ylabel('Cumulative Probability', fontsize=12)
plt.title('CDF of Performance Difference (CRPS vs Probabilistic)', fontsize=14)
plt.grid(True, alpha=0.3)

# Add reference lines
plt.axvline(x=0, color='red', linestyle='--', alpha=0.5, label='No difference')
plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.3, label='Median')

# Add median value annotation
median_diff = np.median(diff_values)
plt.axvline(x=median_diff, color='green', linestyle='--', alpha=0.5, label=f'Median: {median_diff:.4f}')

plt.legend()
plt.tight_layout()
plt.show()

# Print summary statistics (all computed on the same NaN-free sample)
print("Summary Statistics for Difference:")
print(f"Mean: {diff_values.mean():.4f}")
print(f"Median: {median_diff:.4f}")
print(f"Std Dev: {diff_values.std():.4f}")
print(f"Min: {diff_values.min():.4f}")
print(f"Max: {diff_values.max():.4f}")
# difference > 0  <=>  CRPS / Probabilistic < 1
print(f"% of locations where CRPS < Probabilistic: {(diff_values > 0).mean()*100:.2f}%")

In [None]:

# Two-panel map of the skill score in gdf_diff['difference'] (North America and
# South America) sharing one discrete colour scale and one colorbar.

# --- Choose discrete difference bounds (adjust as needed) ---
bounds = [-4, -0.25, -0.05, 0.05, 0.25, 1]
base_cmap = plt.get_cmap('RdBu')   # diverging colormap
n_colours = len(bounds) - 1  # Number of color segments
colors = base_cmap([0.1, 0.25, 0.5, 0.75, 0.9])  # one sampled colour per segment
cmap = mcolors.ListedColormap(colors)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# --- Map extents: (lon_min, lon_max), (lat_min, lat_max) ---
# Each extent drives both the point selection and the axis limits of its
# panel, so every basin inside the visible area is drawn. Previously the
# South America panel selected points in a smaller box than it displayed,
# which hid basins south of 35°S or west of 75°W.
us_lon, us_lat = (-130, -50), (25, 60)
sa_lon, sa_lat = (-85, -33), (-60, 15)

# --- 2 Subplots ---
fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [2.35, 1]}, figsize=(20, 10))
# Title follows the `leadtime` used to build `merged`, instead of a hard-coded 3.
fig.suptitle(f'Conditional-LSTM CRPSS at {leadtime}-Day Leadtime', fontsize=32, fontweight='bold', y=0.91)

Title_Size = 28
Tick_Size = 22

point_lon = gdf_diff.geometry.x
point_lat = gdf_diff.geometry.y

# ========== US MAP ==========
world.plot(ax=ax1, color='lightgray', edgecolor='black')

us_gdf = gdf_diff[(point_lon >= us_lon[0]) & (point_lon <= us_lon[1]) &
                  (point_lat >= us_lat[0]) & (point_lat <= us_lat[1])]

us_gdf.plot(
    column='difference',
    ax=ax1,
    cmap=cmap,
    norm=norm,
    legend=False,
    markersize=20
)

ax1.set_title('United States & Canada', fontsize=Title_Size)
ax1.set_xlim(us_lon)
ax1.set_ylim(us_lat)

# ========== SOUTH AMERICA MAP ==========
world.plot(ax=ax2, color='lightgray', edgecolor='black')

# Name kept as `brazil_gdf` in case other cells refer to it; it now covers
# the whole displayed South America extent.
brazil_gdf = gdf_diff[(point_lon >= sa_lon[0]) & (point_lon <= sa_lon[1]) &
                      (point_lat >= sa_lat[0]) & (point_lat <= sa_lat[1])]

brazil_gdf.plot(
    column='difference',
    ax=ax2,
    cmap=cmap,
    norm=norm,
    legend=False,
    markersize=20
)

ax2.set_title('South America', fontsize=Title_Size)
ax2.set_xlim(sa_lon)
ax2.set_ylim(sa_lat)

for ax in [ax1, ax2]:
    ax.tick_params(axis='both', which='major', labelsize=Tick_Size)
    ax.set_xlabel('Longitude', fontsize=28, fontweight='bold')
    ax.set_ylabel('Latitude', fontsize=28, fontweight='bold')

# --- Single Colorbar ---
# A dedicated axes is carved off the right-hand side of ax2 for the colorbar.
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="5%", pad=0.1)

# The points are drawn with legend=False, so the colorbar is built from a
# standalone mappable carrying the same cmap/norm.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, cax=cax)
cbar.ax.tick_params(labelsize=Tick_Size)
cbar.set_label("CRPSS", fontsize=28, fontweight='bold')

plt.tight_layout()
plt.show()



In [None]:
# Metric label -> per-model table dictionary, as consumed by calculate_crpss.
metric_dicts = dict(
    Total_Flow=Flow_dict,
    Autocorrelation=Autoregression_dict,
    Gamma=Gamma_dict,
    NumRise=NumRise_dict,
)

# CRPSS of the Conditional model with the Probabilistic model as reference.
CRPSS_Stats = calculate_crpss('Conditional', 'Probabilistic', metric_dicts)

# Copy the coordinates from the Conditional total-flow table, re-indexed 0..n-1.
# NOTE(review): this assumes CRPSS_Stats has the same row order and a matching
# 0..n-1 index — confirm against calculate_crpss.
conditional_flow_table = Flow_dict['Conditional']
for coord_name in ('Latitude', 'Longitude'):
    CRPSS_Stats[coord_name] = conditional_flow_table[coord_name].reset_index(drop=True)

# Convert to a GeoDataFrame of lon/lat points.
# NOTE(review): this rebinds `gdf_diff`, which an earlier cell builds from `merged`.
gdf_diff = gpd.GeoDataFrame(
    CRPSS_Stats,
    geometry=gpd.points_from_xy(CRPSS_Stats['Longitude'], CRPSS_Stats['Latitude']),
    crs="EPSG:4326",
)

In [None]:
# Two-panel map (North America and South America) of one CRPSS column of
# gdf_diff, sharing one discrete colour scale and one colorbar.

# Column of gdf_diff to map. A single name drives the title and BOTH panels;
# previously the left panel plotted 'NumRise' while the right panel plotted
# 'Total_Flow' under the same 'CRPSS of NumRise' title and colorbar.
crpss_metric = 'NumRise'

# --- Choose discrete difference bounds (adjust as needed) ---
bounds = [-2, -0.25, -0.05, 0.05, 0.25, 1]
base_cmap = plt.get_cmap('RdBu')   # diverging colormap
n_colours = len(bounds) - 1  # Number of color segments
colors = base_cmap([0.1, 0.25, 0.5, 0.75, 0.9])  # one sampled colour per segment
cmap = mcolors.ListedColormap(colors)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# --- Map extents: (lon_min, lon_max), (lat_min, lat_max) ---
# Each extent drives both the point selection and the axis limits of its
# panel, so every basin inside the visible area is drawn. Previously the
# South America panel selected points in a smaller box than it displayed.
us_lon, us_lat = (-130, -50), (25, 60)
sa_lon, sa_lat = (-85, -33), (-60, 15)

# --- 2 Subplots ---
fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [2.35, 1]}, figsize=(20, 10))
fig.suptitle(f'CRPSS of {crpss_metric}', fontsize=24, fontweight='bold', y=0.91)

Title_Size = 18

point_lon = gdf_diff.geometry.x
point_lat = gdf_diff.geometry.y

# ========== US MAP ==========
world.plot(ax=ax1, color='lightgray', edgecolor='black')

us_gdf = gdf_diff[(point_lon >= us_lon[0]) & (point_lon <= us_lon[1]) &
                  (point_lat >= us_lat[0]) & (point_lat <= us_lat[1])]

us_gdf.plot(
    column=crpss_metric,
    ax=ax1,
    cmap=cmap,
    norm=norm,
    legend=False,
    markersize=20
)

ax1.set_title('US: Conditional-LSTM CRPSS', fontsize=Title_Size)
ax1.set_xlim(us_lon)
ax1.set_ylim(us_lat)

# ========== SOUTH AMERICA MAP ==========
world.plot(ax=ax2, color='lightgray', edgecolor='black')

# Name kept as `brazil_gdf` in case other cells refer to it; it now covers
# the whole displayed South America extent.
brazil_gdf = gdf_diff[(point_lon >= sa_lon[0]) & (point_lon <= sa_lon[1]) &
                      (point_lat >= sa_lat[0]) & (point_lat <= sa_lat[1])]

brazil_gdf.plot(
    column=crpss_metric,
    ax=ax2,
    cmap=cmap,
    norm=norm,
    legend=False,
    markersize=20
)

ax2.set_title('South America: Conditional-LSTM CRPSS', fontsize=Title_Size)
ax2.set_xlim(sa_lon)
ax2.set_ylim(sa_lat)

for ax in [ax1, ax2]:
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xlabel('Longitude', fontsize=16, fontweight='bold')
    ax.set_ylabel('Latitude', fontsize=16, fontweight='bold')

# --- Single Colorbar ---
# A dedicated axes is carved off the right-hand side of ax2 for the colorbar.
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="5%", pad=0.1)

# The points are drawn with legend=False, so the colorbar is built from a
# standalone mappable carrying the same cmap/norm.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, cax=cax)
cbar.ax.tick_params(labelsize=14)
cbar.set_label("CRPSS", fontsize=16, fontweight='bold')

plt.tight_layout()
plt.show()
