In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import properscoring as ps
import matplotlib.pyplot as plt

# --- 1. Create Sample z500 Ensemble Data ---
# This section simulates a realistic z500 forecast dataset.
# You would replace this with your actual data loading.

# Coordinate axes of the synthetic forecast cube
init_times = pd.to_datetime(
    [
        '2025-07-01T00:00',
        '2025-07-02T00:00',
    ]
)
forecast_hours = np.arange(0, 5 * 24 + 1, 24)  # daily lead times: 0 h through 120 h
ensemble_members = np.arange(20)
latitudes = np.linspace(50, 30, num=10)
longitudes = np.linspace(-110, -90, num=20)
dims = ('initialization_time', 'forecast_hour', 'latitude', 'longitude')

# Smooth stand-in "truth" for z500 (in meters): a sinusoidal pattern in
# latitude and longitude added to a typical z500 value.
base_height = 5600
# 'ij' indexing yields (latitude, longitude)-shaped grids directly.
lat_grid, lon_grid = np.meshgrid(latitudes, longitudes, indexing='ij')
lat_wave = np.sin(np.deg2rad(lat_grid) * 4)
lon_wave = np.cos(np.deg2rad(lon_grid) * 2)
truth_field = base_height + 150 * lat_wave * lon_wave

# The truth does not change with initialization time or lead time, so the
# same 2-D field is repeated along both leading axes.
per_init_truth = np.stack([truth_field for _ in forecast_hours])  # (forecast_hour, latitude, longitude)
truth_cube = np.stack([per_init_truth for _ in init_times])  # (initialization_time, forecast_hour, latitude, longitude)

truth_coords = dict(
    initialization_time=init_times,
    forecast_hour=forecast_hours,
    latitude=latitudes,
    longitude=longitudes,
)
ds_truth = xr.DataArray(truth_cube, coords=truth_coords, dims=dims).to_dataset(name='z500')

# Create an ensemble forecast by adding growing random noise to the truth.
# This simulates increasing forecast error and spread over time.
#
# The noise is drawn from an explicitly seeded generator (rather than the
# unseeded global numpy state) so that the sample data, the CRPS values and
# the resulting plot are identical on every run.
rng = np.random.default_rng(seed=42)

# Noise standard deviation per lead time: 0 at the first forecast hour,
# growing linearly to 100 at the last one.
forecast_error_scaling = np.linspace(0, 100, len(forecast_hours))

# One independent standard-normal draw per
# (initialization_time, number, forecast_hour, latitude, longitude) element.
noise_shape = (
    len(init_times),
    len(ensemble_members),
    len(forecast_hours),
    len(latitudes),
    len(longitudes),
)
# The scaling varies only along the forecast_hour axis (axis 2); the added
# length-1 axes let it broadcast over every other dimension.
random_noise = (
    rng.standard_normal(noise_shape)
    * forecast_error_scaling[np.newaxis, np.newaxis, :, np.newaxis, np.newaxis]
)

# truth_field is (latitude, longitude) and broadcasts across the leading axes.
forecast_data = truth_field + random_noise

# Wrap the raw forecast array in a labelled dataset; the 'number' dimension
# indexes the ensemble members.
ensemble_dims = ('initialization_time', 'number', 'forecast_hour', 'latitude', 'longitude')
ensemble_coords = dict(
    initialization_time=init_times,
    number=ensemble_members,
    forecast_hour=forecast_hours,
    latitude=latitudes,
    longitude=longitudes,
)
ds_ensemble = xr.DataArray(
    forecast_data,
    coords=ensemble_coords,
    dims=ensemble_dims,
).to_dataset(name='z500')


# --- 2. Calculate CRPS ---
# This is the core verification logic.

# Pull the z500 variable out of each dataset
forecasts_da = ds_ensemble['z500']
truth_da = ds_truth['z500']

# properscoring operates on plain numpy arrays and pairs forecasts with
# observations by axis position, not by dimension name. Order the forecast
# axes explicitly as "the truth's dimensions, then 'number'" so the leading
# axes line up with the observations regardless of how the forecast dataset
# happens to store its dimensions. transpose() raises if the forecast does
# not have exactly these dimensions, instead of silently mis-pairing axes.
fcst_values = forecasts_da.transpose(*truth_da.dims, 'number').values
# Observation values (truth), one per forecast point
obs_values = truth_da.values

# CRPS for every point; the ensemble members sit on the last axis, which is
# the axis crps_ensemble collapses by default. The result is a numpy array
# with the same shape as obs_values.
crps_values = ps.crps_ensemble(obs_values, fcst_values)

# Re-attach the truth's dimension names and coordinates to the numpy result
crps_da = xr.DataArray(
    crps_values,
    dims=truth_da.dims,
    coords=truth_da.coords
)

# Average the CRPS over space and initialization time to get a score for each forecast hour.
# On a regular latitude-longitude grid the area of a cell shrinks with
# cos(latitude), so an unweighted mean over-represents the higher-latitude
# rows. Weight each row by cos(latitude) to obtain an area-weighted mean.
# NOTE(review): assumes 'latitude' is in degrees on a regular lat-lon grid.
lat_weights = np.cos(np.deg2rad(crps_da['latitude']))
crps_mean = crps_da.weighted(lat_weights).mean(dim=["initialization_time", "latitude", "longitude"])


# --- 3. Plot the Results ---
fig, ax = plt.subplots(figsize=(10, 6))

# Line plot of the mean CRPS, drawn by xarray onto the prepared axes
crps_mean.plot.line(ax=ax, marker='o', label='CRPS')

# Replace the title and axis labels that xarray sets automatically
ax.set(
    title='Ensemble Forecast Skill (z500)',
    xlabel='Forecast Lead Time (hours)',
    ylabel='Mean CRPS (meters)',
)
ax.grid(True, linestyle='--', alpha=0.7)
ax.legend()

fig.tight_layout()
plt.show()