In [1]:
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar
from isodisreg import idr

# Downloading data

WeatherBench2 Data Guide: https://weatherbench2.readthedocs.io/en/latest/data-guide.html

In [2]:
# All stores live in the public WeatherBench2 bucket and share the same grid suffix.
_wb2_root = 'gs://weatherbench2/datasets'
_grid = '64x32_equiangular_conservative.zarr'

# "Observations": datasets the forecasts are compared against.
obs_paths = dict(
    era5=f'{_wb2_root}/era5/1959-2022-6h-{_grid}',
    ifs_analysis=f'{_wb2_root}/hres_t0/2016-2022-6h-{_grid}',
)

# Forecasts: one store per model; the key is reused as the local file name prefix.
_graphcast_span = 'date_range_2019-11-16_2021-02-01_12_hours'
forecast_paths = dict(
    hres=f'{_wb2_root}/hres/2016-2022-0012-{_grid}',
    pangu=f'{_wb2_root}/pangu/2018-2022_0012_{_grid}',
    graphcast=f'{_wb2_root}/graphcast/2020/{_graphcast_span}-{_grid}',
    pangu_operational=f'{_wb2_root}/pangu_hres_init/2020_0012_{_grid}',
    graphcast_operational=f'{_wb2_root}/graphcast_hres_init/2020/{_graphcast_span}-{_grid}',
)

# Local directory the downloaded stores are written to (adjust as needed).
# The download cells concatenate this with the file name, so keep the trailing '/'.
save_path = 'data/'

In [3]:
# Surface variables to download from every dataset.
variables = ['2m_temperature', 'mean_sea_level_pressure', '10m_wind_speed']

# Forecast horizons to keep, in whole days.
lead_times = [np.timedelta64(days, 'D') for days in (1, 3, 5, 7, 10)]

# Forecast initialisation times to keep (label-based slice, both ends inclusive).
time_range = slice('2020-01-01', '2020-12-31')

In [None]:
# Download every "observation" dataset, restricted to the chosen variables, and
# store it locally as '<save_path><name>_64x32.zarr'.
for name, path in obs_paths.items():
    # The slice runs to 2021-01-10 so that observations exist for the valid time of
    # the longest lead time (10 days) issued at the end of the 2020 forecast range.
    ds = xr.open_zarr(
        store=path,
        storage_options={'token': 'anon'},
        decode_timedelta=True
        ).sel(time=slice('2020-01-01','2021-01-10'))[variables].drop_encoding() # drop encoding to avoid overlap errors

    ds = ds.sel(time=ds.time.dt.hour.isin([0, 12])) # only 00 and 12 UTC times are relevant since we only have full day lead times

    # rechunk: one chunk per time step spanning the whole horizontal grid.
    # -1 means "entire dimension", so the spec cannot get out of sync with the grid
    # shape (the 64x32 grid is 64 longitudes x 32 latitudes, so the previous
    # latitude=64 / longitude=32 spec was swapped and split longitude in two).
    ds = ds.chunk({'time': 1, 'latitude': -1, 'longitude': -1})

    with ProgressBar(): # not necessary but shows progress
        ds.to_zarr(save_path + f'{name}_64x32.zarr', mode='w', consolidated=True) # make sure save_path ends in /

In [None]:
# Download every forecast dataset, restricted to the chosen initialisation times,
# lead times and variables, and store it locally as '<save_path><name>_64x32.zarr'.
for name, path in forecast_paths.items():
    ds = xr.open_zarr(
        store=path,
        storage_options={'token': 'anon'},
        decode_timedelta=True
        ).sel(time=time_range, prediction_timedelta=lead_times)[variables].drop_encoding() # drop encoding to avoid overlap errors

    # rechunk: one chunk per initialisation time spanning the whole horizontal grid.
    # -1 means "entire dimension", so the spec cannot get out of sync with the grid
    # shape (the 64x32 grid is 64 longitudes x 32 latitudes, so the previous
    # latitude=64 / longitude=32 spec was swapped and split longitude in two).
    ds = ds.chunk({'time': 1, 'latitude': -1, 'longitude': -1})

    with ProgressBar(): # not necessary but shows progress
        ds.to_zarr(save_path + f'{name}_64x32.zarr', mode='w', consolidated=True) # make sure save_path ends in /

Overall, the downloads above amount to ~487 MB of data.

# Computing PC at an example grid point

In [21]:
# Open the locally stored GraphCast forecasts written by the download cell.
# The path is built from save_path (same concatenation as the download cell) so that
# changing the local data directory in one place keeps this cell working.
forecasts = xr.open_zarr(
        store=save_path + 'graphcast_64x32.zarr',
        decode_timedelta=True
        )

In [18]:
# Open the locally stored ERA5 data written by the download cell.
# The path is built from save_path (same concatenation as the download cell) so that
# changing the local data directory in one place keeps this cell working.
observations = xr.open_zarr(
        store=save_path + 'era5_64x32.zarr',
        decode_timedelta=True
        )

In [23]:
# Settings for the single-grid-point example below.
var = '2m_temperature'              # variable to evaluate
lead_time = np.timedelta64(3, 'D')  # forecast horizon to evaluate

# Target location in degrees (presumably Karlsruhe, Germany -- confirm); the nearest
# grid point to it is used.
ka_lat, ka_lon = 49.00937, 8.40444

In [25]:
# Forecast series of `var` at the grid point closest to (ka_lat, ka_lon).
# The lead time is selected exactly: applying method='nearest' to it as well would
# silently substitute a different horizon if `lead_time` were missing from the store.
# Only the spatial lookup uses nearest-neighbour matching.
preds_point = (
    forecasts
    .sel(prediction_timedelta=lead_time)
    .sel(latitude=ka_lat, longitude=ka_lon, method='nearest')
    .sel(time=time_range)[var]
    .load()
)

In [27]:
# Time each forecast refers to: initialisation time shifted by the lead time.
init_time = preds_point.time
valid_time = init_time + lead_time

In [28]:
# Observations at the same grid point as the forecasts, picked at the valid times.
obs_at_valid = observations.sel(
    time=valid_time,
    latitude=preds_point.latitude,
    longitude=preds_point.longitude,
)
obs_point = obs_at_valid[var].load()

In [29]:
# Fit IDR with the observations as response and the point forecast as the single
# covariate (a one-column DataFrame).
train_covariates = preds_point.to_dataframe()[[var]]
fitted_idr = idr(obs_point, train_covariates)

In [30]:
# Predict with the fitted model on the same forecasts it was fitted on (in-sample).
predict_covariates = preds_point.to_dataframe()[[var]]
easyuq_preds_point = fitted_idr.predict(predict_covariates, digits=8)

In [31]:
# Score the predictions against the observations with the predictions' `crps` method
# (presumably one CRPS value per forecast case -- verify against the isodisreg docs).
crps = easyuq_preds_point.crps(obs_point)

In [32]:
# Average the CRPS values into a single score for this grid point
# ("PC" as in the section title).
pc = np.mean(crps)

In [33]:
pc  # bare expression: display the mean CRPS as the cell output

np.float64(0.2569399556732324)