In [None]:
import xarray as xr
import numpy as np
import xesmf as xe
import glob

# Step 1: Configuration.
# file_pattern matches the forecast files to load (read in sorted order); `date` is the
# forecast issue date. File i is labelled with time `date` + i days below
# (assumes one file per day — TODO confirm against the data).
file_pattern = "/directory/data???.nc"
date="2026-02-12"

##################################################################################


file_list = sorted(glob.glob(file_pattern))
if not file_list:
    # Fail here with a clear message; otherwise the loop is skipped and xr.concat below
    # raises a much less obvious error about having nothing to concatenate.
    raise FileNotFoundError(f"No forecast files match {file_pattern!r}")

# Step 2: Initialize empty lists for each variable
t2m_list, msl_list, tp_list = [], [], []

# Step 3: Loop through each file, extract variables, and add a time dimension
for i, file in enumerate(file_list):
    # Opened lazily with dask chunks along ensemble_member; the files must stay open
    # because the arrays are only read when computed later.
    ds = xr.open_dataset(file, chunks={"ensemble_member": 10})

    # Synthetic daily time label: issue date + i days.
    time_val = np.datetime64(date) + np.timedelta64(i, 'D')

    for var, var_list in zip(["t2m", "msl", "tp"], [t2m_list, msl_list, tp_list]):
        if var in ds:
            var_data = ds[var].expand_dims(time=[time_val])
            var_list.append(var_data)
        else:
            # Best effort: a missing variable leaves a gap in that variable's time axis.
            print(f"Warning: {var} not found in {file}")

# Step 4: Concatenate each variable along the new time dimension
t2m_all = xr.concat(t2m_list, dim='time')
msl_all = xr.concat(msl_list, dim='time')
tp_all = xr.concat(tp_list, dim='time')

In [None]:
# Inspect the concatenated t2m array (with its new `time` dimension) before regridding.
t2m_all

### interpolate to EC grid

In [None]:
#### Interpolate to the EC grid: a regular global 1.5° lat/lon grid, latitudes ordered north → south.
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(90.0, -90.1, -1.5)),  # 90, 88.5, ..., -90 (121 points)
        "lon": (["lon"], np.arange(0.0, 360.0, 1.5)),    # 0, 1.5, ..., 358.5 (240 points)
    }
)

# periodic=True: the target grid is global (and the source presumably is too), so bilinear
# interpolation must wrap across the 0°/360° seam instead of leaving points near it unmapped.
# One regridder is built from t2m and reused for msl and tp — this assumes all three
# variables share the same source grid (TODO confirm).
regridder = xe.Regridder(t2m_all, ds_out, "bilinear", periodic=True, reuse_weights=False)

t2m_interp = regridder(t2m_all)
msl_interp = regridder(msl_all)
tp_interp = regridder(tp_all)

### Compute week-3 and week-4 means for t2m and msl, and weekly totals for tp

In [None]:
# Week slices as 0-based positions along the `time` axis built in the loading cell
# (position 0 = first file, labelled with the issue date `date`). Both ends are inclusive.
# NOTE(review): with that labelling, positions 19–25 are issue date + 19 … + 25 days
# (20260303–20260309 for issue date 20260212), while the week-3 climatology used in this
# notebook starts on 20260302 (issue date + 18). Confirm the lead time of the first file.
week_slices = [(19, 25), (26, 32)]  # Week 3 and Week 4: days 19–25 and 26–32 for Thursday start

# Weekly statistics per variable; each list receives one entry per week.
weekly_data = {"t2m": [], "msl": [], "tp": []}

# Loop over weeks
for i, (start, end) in enumerate(week_slices):
    for varname, var_interp in zip(["t2m", "msl", "tp"], [t2m_interp, msl_interp, tp_interp]):
        # isel silently truncates a slice that runs past the end of the axis, which would
        # give a mean/sum over fewer than 7 days. Fail loudly instead.
        n_times = var_interp.sizes['time']
        if end >= n_times:
            raise ValueError(
                f"{varname}: week {i + 1} needs time index {end}, "
                f"but only {n_times} time steps are available"
            )
        time_slice = var_interp.isel(time=slice(start, end + 1))

        if varname == "tp":
            # Weekly total precipitation.
            # NOTE(review): summing assumes tp holds per-day amounts, not accumulations
            # since initialisation — confirm for the model output.
            week_stat = time_slice.sum(dim='time')
        else:
            week_stat = time_slice.mean(dim='time')  # Weekly mean for temperature and pressure

        # Weeks are labelled 1 and 2 (first/second entry of week_slices, i.e. week 3 and week 4);
        # later cells select them with sel(week=1) / sel(week=2).
        week_stat = week_stat.expand_dims(week=[i + 1])
        weekly_data[varname].append(week_stat)

# Combine into a Dataset with a `week` dimension
weekly_dataset = xr.Dataset({
    var: xr.concat(weekly_data[var], dim='week') for var in weekly_data
})

# Optional: add metadata
weekly_dataset['t2m'].attrs['units'] = 'K'
weekly_dataset['msl'].attrs['units'] = 'Pa'
weekly_dataset['tp'].attrs['units'] = 'm'  # assuming 'tp' is in meters


# Save to NetCDF (optional)
#weekly_dataset.to_netcdf("spire_ai_20250626_weekly_avg_sum_week3_4.nc")



## Estimate probabilities

#### Retrieve the climatological quintile bounds from AI Quest

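The dates are the start dates (Mondays) of the week-3 and week-4 forecast windows, as listed in the associated Excel file. Week 3 starts 18 days and week 4 starts 25 days after the issue date. For example, if the forecast is issued on 20250626, week 3 starts on 20250714 and week 4 starts on 20250721. Replace "password" with your AI Quest password.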

In [None]:
from AI_WQ_package import retrieve_evaluation_data
# Download the 20-year climatological quintile bounds for every variable, first for the
# week-3 window (20260302) and then for week 4 (20260309). Keep these dates in step with
# the forecast issue date.
for clim_date in ('20260302', '20260309'):
    for clim_var in ('tas', 'mslp', 'pr'):
        retrieve_evaluation_data.retrieve_20yr_quintile_clim(clim_date, clim_var, 'password')

In [None]:
def _open_quintile_bounds(path, var):
    """Open a climatology file and return `var` with its length-1 `time` dimension squeezed out."""
    return xr.open_dataset(path)[var].squeeze('time')


# Week-3 (20260302) and week-4 (20260309) quintile bounds for each variable.
tas_quintile_bounds_wk3 = _open_quintile_bounds("tas_20yrCLIM_WEEKLYMEAN_quintiles_20260302.nc", 'tas')
tas_quintile_bounds_wk4 = _open_quintile_bounds("tas_20yrCLIM_WEEKLYMEAN_quintiles_20260309.nc", 'tas')

mslp_quintile_bounds_wk3 = _open_quintile_bounds("mslp_20yrCLIM_WEEKLYMEAN_quintiles_20260302.nc", 'mslp')
mslp_quintile_bounds_wk4 = _open_quintile_bounds("mslp_20yrCLIM_WEEKLYMEAN_quintiles_20260309.nc", 'mslp')

TP_quintile_bounds_wk3 = _open_quintile_bounds("pr_20yrCLIM_WEEKLYSUM_quintiles_20260302.nc", 'pr')
TP_quintile_bounds_wk4 = _open_quintile_bounds("pr_20yrCLIM_WEEKLYSUM_quintiles_20260309.nc", 'pr')


In [None]:
def _weekly_forecast(week, var):
    """Return variable `var` for week label `week` from weekly_dataset, with lat/lon renamed to latitude/longitude."""
    return weekly_dataset.sel(week=week)[var].rename({'lat': 'latitude', 'lon': 'longitude'})


# Week label 1 = week 3, week label 2 = week 4.
t2m_forecasts_wk3 = _weekly_forecast(1, 't2m')
t2m_forecasts_wk4 = _weekly_forecast(2, 't2m')
msl_forecasts_wk3 = _weekly_forecast(1, 'msl')
msl_forecasts_wk4 = _weekly_forecast(2, 'msl')
tp_forecasts_wk3 = _weekly_forecast(1, 'tp')
tp_forecasts_wk4 = _weekly_forecast(2, 'tp')


If needed, multiply by 1000 to convert precipitation from metres to millimetres, so that the forecast matches the units of the climatology quintiles. Otherwise skip this cell.

In [None]:
# Convert weekly precipitation totals from metres (the unit weekly_dataset['tp'] is tagged
# with) to millimetres. The values are rebuilt from weekly_dataset instead of scaling
# tp_forecasts_wk3/wk4 in place, so re-running this cell cannot apply the factor twice.
tp_forecasts_wk3 = weekly_dataset.sel(week=1).tp.rename({'lat': 'latitude', 'lon': 'longitude'}) * 1000
tp_forecasts_wk4 = weekly_dataset.sel(week=2).tp.rename({'lat': 'latitude', 'lon': 'longitude'}) * 1000
tp_forecasts_wk3.attrs['units'] = 'mm'
tp_forecasts_wk4.attrs['units'] = 'mm'

#### Function to estimate probabilities using the "counting members" method


In [None]:
def categorize_into_quintiles(forecast_week, quintile_bounds):
    """Turn an ensemble forecast into category probabilities by counting members.

    Parameters
    ----------
    forecast_week : xr.DataArray
        Ensemble forecast with an ``ensemble_member`` dimension plus spatial
        dimensions, e.g. ``(ensemble_member, latitude, longitude)``.
    quintile_bounds : xr.DataArray
        Climatological category boundaries along a ``quantile`` dimension
        (4 boundaries for quintiles) on the same spatial grid.

    Returns
    -------
    xr.DataArray
        Fraction of ensemble members in each category, with a new ``quintile``
        dimension labelled 1..N, where N = number of boundaries + 1 (1..5 for
        quintiles). Category k covers ``(edge[k-1], edge[k]]``. The lowest and
        highest categories are open-ended (-inf / +inf outer edges). Members
        that are NaN fall in no category, so probabilities then sum to < 1.
    """
    n_bounds = quintile_bounds.sizes['quantile']
    lowest_bound = quintile_bounds.isel(quantile=0)

    # Edges: -inf, bound_0, ..., bound_{n-1}, +inf  ->  n_bounds + 1 categories.
    quantile_edges = xr.concat(
        [xr.full_like(lowest_bound, -np.inf)]
        + [quintile_bounds.isel(quantile=k) for k in range(n_bounds)]
        + [xr.full_like(lowest_bound, np.inf)],
        dim='edge',
    )

    n_members = forecast_week.sizes['ensemble_member']
    counts = []
    for q in range(n_bounds + 1):
        in_bin = ((forecast_week > quantile_edges.isel(edge=q)) &
                  (forecast_week <= quantile_edges.isel(edge=q + 1)))
        counts.append(in_bin.sum(dim='ensemble_member') / n_members)

    probs = xr.concat(counts, dim=xr.DataArray(np.arange(1, n_bounds + 2), dims="quintile"))
    return probs

#### Calculating probabilities

In [None]:
# Category probabilities per variable and week. Each week's forecast is compared with the
# climatology bounds of that same week (the week-4 mslp and tp probabilities previously
# used the week-3 forecasts by mistake).
probs_tas_wk3 = categorize_into_quintiles(t2m_forecasts_wk3, tas_quintile_bounds_wk3)
probs_tas_wk4 = categorize_into_quintiles(t2m_forecasts_wk4, tas_quintile_bounds_wk4)

probs_mslp_wk3 = categorize_into_quintiles(msl_forecasts_wk3, mslp_quintile_bounds_wk3)
probs_mslp_wk4 = categorize_into_quintiles(msl_forecasts_wk4, mslp_quintile_bounds_wk4)

probs_tp_wk3 = categorize_into_quintiles(tp_forecasts_wk3, TP_quintile_bounds_wk3)
probs_tp_wk4 = categorize_into_quintiles(tp_forecasts_wk4, TP_quintile_bounds_wk4)

In [None]:
# Inspect the week-3 temperature category probabilities (`quintile` dimension labelled 1–5).
probs_tas_wk3

## Forecast submission

In [None]:
from AI_WQ_package import forecast_submission

In [None]:
# Empty submission templates, one per variable and forecast period
# (period '1' = week 3, period '2' = week 4). Replace 'team name', 'model name' and
# 'password' with your AI Quest registration details before running.
forecast_tas_wk3, forecast_tas_wk4 = (
    forecast_submission.AI_WQ_create_empty_dataarray('tas', '20260212', period, 'team name', 'model name', 'password')
    for period in ('1', '2')
)

forecast_mslp_wk3, forecast_mslp_wk4 = (
    forecast_submission.AI_WQ_create_empty_dataarray('mslp', '20260212', period, 'team name', 'model name', 'password')
    for period in ('1', '2')
)

forecast_tp_wk3, forecast_tp_wk4 = (
    forecast_submission.AI_WQ_create_empty_dataarray('pr', '20260212', period, 'team name', 'model name', 'password')
    for period in ('1', '2')
)

In [None]:
# Copy the computed temperature probabilities into the submission templates.
# NOTE(review): assigning to .values replaces the raw array without aligning coordinates
# (xarray only checks the shape), so this assumes probs_tas_wk3/wk4 have the same
# dimension order and latitude/longitude ordering as the templates — confirm.
forecast_tas_wk3.values=probs_tas_wk3
forecast_tas_wk4.values=probs_tas_wk4

In [None]:
# Display the week-3 temperature probabilities again.
# NOTE(review): this repeats an earlier display cell; perhaps forecast_tas_wk3 was meant to be inspected here — confirm.
probs_tas_wk3

In [None]:
# Optional: save local copies of the filled templates (file names use the issue date).
#forecast_tas_wk3.to_netcdf("tas_20260212_p1_SAIS2S.nc")
#forecast_tas_wk4.to_netcdf("tas_20260212_p2_SAIS2S.nc")

In [None]:
# Submit the week-3 (period '1') and week-4 (period '2') temperature forecasts, in that order.
forecast_tas_wk3_submit, forecast_tas_wk4_submit = (
    forecast_submission.AI_WQ_forecast_submission(template, 'tas', '20260212', period, 'team name', 'model name', 'password')
    for period, template in (('1', forecast_tas_wk3), ('2', forecast_tas_wk4))
)


In [None]:
# Copy the computed mean-sea-level-pressure probabilities into the submission templates.
# NOTE(review): assigning to .values replaces the raw array without aligning coordinates
# (xarray only checks the shape), so this assumes probs_mslp_wk3/wk4 have the same
# dimension order and latitude/longitude ordering as the templates — confirm.
forecast_mslp_wk3.values=probs_mslp_wk3
forecast_mslp_wk4.values=probs_mslp_wk4

In [None]:
# Optional: save local copies of the filled templates (file names use the issue date).
#forecast_mslp_wk3.to_netcdf("mslp_20260212_p1_SAIS2S.nc")
#forecast_mslp_wk4.to_netcdf("mslp_20260212_p2_SAIS2S.nc")

In [None]:
# Submit the week-3 (period '1') and week-4 (period '2') mean-sea-level-pressure forecasts, in that order.
forecast_mslp_wk3_submit, forecast_mslp_wk4_submit = (
    forecast_submission.AI_WQ_forecast_submission(template, 'mslp', '20260212', period, 'team name', 'model name', 'password')
    for period, template in (('1', forecast_mslp_wk3), ('2', forecast_mslp_wk4))
)


In [None]:
# Copy the computed precipitation probabilities into the submission templates.
# NOTE(review): assigning to .values replaces the raw array without aligning coordinates
# (xarray only checks the shape), so this assumes probs_tp_wk3/wk4 have the same
# dimension order and latitude/longitude ordering as the templates — confirm.
forecast_tp_wk3.values=probs_tp_wk3
forecast_tp_wk4.values=probs_tp_wk4

In [None]:
# Optional: save local copies of the filled templates (file names use the issue date).
#forecast_tp_wk3.to_netcdf("pr_20260212_p1_SAIS2S.nc")
#forecast_tp_wk4.to_netcdf("pr_20260212_p2_SAIS2S.nc")

In [None]:
# Submit the week-3 (period '1') and week-4 (period '2') precipitation forecasts, in that order.
forecast_tp_wk3_submit, forecast_tp_wk4_submit = (
    forecast_submission.AI_WQ_forecast_submission(template, 'pr', '20260212', period, 'team name', 'model name', 'password')
    for period, template in (('1', forecast_tp_wk3), ('2', forecast_tp_wk4))
)


Plotting forecast

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import matplotlib.colors as mcolors

def plot_ecmwf_quintile_discrete(q5_da, title="Prob. of quintile 80<=x<100 for T2M forecast_issue_date:20260212.forecast period:20260302 to 20260309 "):
    """Draw a global map of one category's probability using ECMWF-style discrete colour bins.

    Parameters
    ----------
    q5_da : xr.DataArray
        Probability of a single category as a fraction in [0, 1], with
        ``latitude`` and ``longitude`` coordinates. It is plotted in percent.
    title : str
        Figure title.
    """
    percent = q5_da * 100

    # Bin edges in percent (13 bins) and an approximation of the ECMWF palette.
    bin_edges = [0, 5, 10, 15, 20, 25, 33, 40, 50, 60, 70, 80, 90, 100]
    palette = [
        "#8000A8", "#665DC1", "#0000FF", "#0070DD", "#00A5C9",
        "#00E0DD", "#6CF3C6", "#4CE07F", "#A5E400", "#F0F300",
        "#FFCC00", "#FF9900", "#FF4D00", "#E60000"
    ]
    # NOTE(review): 14 colours for 13 bins. BoundaryNorm spreads the bins across all
    # colours, so one palette entry is skipped — confirm the intended palette.
    cmap = mcolors.ListedColormap(palette)
    norm = mcolors.BoundaryNorm(boundaries=bin_edges, ncolors=len(palette))

    fig, ax = plt.subplots(figsize=(13, 6), subplot_kw={"projection": ccrs.PlateCarree()})

    mesh = ax.pcolormesh(
        percent.longitude, percent.latitude, percent,
        cmap=cmap, norm=norm, shading='auto', transform=ccrs.PlateCarree()
    )

    ax.coastlines(resolution='110m', linewidth=0.7)
    ax.add_feature(cfeature.BORDERS, linewidth=0.4)
    ax.set_global()
    ax.set_title(title, fontsize=14)

    colorbar = fig.colorbar(mesh, ax=ax, orientation='horizontal', pad=0.05, aspect=50, ticks=bin_edges)
    colorbar.set_label('%')

    fig.tight_layout()
    plt.show()


In [None]:
# quintile=5 is the highest category (above the last climatological bound, presumably the
# 80th percentile) of the week-3 temperature probabilities.
q5 = probs_tas_wk3.sel(quintile=5)

plot_ecmwf_quintile_discrete(q5_da=q5)