## Compute SPI for diffusion model evaluation

In [64]:
import glob
import re

import hdc.algo
import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xss

from scipy.stats import spearmanr, pearsonr
from read_data import get_cond_ecmwf_dataset, get_mean_std_data

**Read ECMWF forecasts and CHIRPS data**

In [76]:
# Open each split once and read both variables from the same dataset handle,
# rather than globbing and opening the same Zarr stores twice per split.
train_ds = xr.open_mfdataset('data/train/*.zarr')
test_ds = xr.open_mfdataset('data/test/*.zarr')

# ECMWF forecasts: load into memory and join train + test along time.
train_ecmwf = train_ds.ecmwf.load()
test_ecmwf = test_ds.ecmwf.load()
ecmwf = xr.concat([train_ecmwf, test_ecmwf], dim='time')

# CHIRPS: same stores, same train/test concatenation.
train_chirps = train_ds.chirps.load()
test_chirps = test_ds.chirps.load()
chirps = xr.concat([train_chirps, test_chirps], dim='time')



**Read benchmarks** 

In [66]:
# Quantile-mapping benchmark: the `scen` variable of the train and test stores,
# loaded into memory and joined along time.
train_qm = xr.open_zarr('data/benchmarks_results/QM/train_quantile_mapping.zarr')['scen'].load()
test_qm = xr.open_zarr('data/benchmarks_results/QM/test_quantile_mapping.zarr')['scen'].load()
qm = xr.concat((train_qm, test_qm), dim='time')

# Bilinear-interpolation benchmark.
# NOTE(review): the train store is read through variable `tp` while the test
# store is read through `bilint` — confirm the two stores really use different
# variable names and that this is not a typo.
train_bilint = xr.open_zarr('data/benchmarks_results/bilinear/train_bilinear.zarr')['tp'].load()
test_bilint = xr.open_zarr('data/benchmarks_results/bilinear/test_bilinear.zarr')['bilint'].load()
bilint = xr.concat((train_bilint, test_bilint), dim='time')

#train_climax = xr.open_zarr('data/benchmarks_results/ClimaX/train_climax.zarr')
#test_climax = xr.open_zarr('data/benchmarks_results/ClimaX/test_climax.zarr')
#climax = xr.concat([train_climax, test_climax], dim='time')

**Compute SPI (Standardized Precipitation Index)**

Diffusion

In [67]:
# glob.glob returns matches in arbitrary (filesystem-dependent) order, but the
# concatenated batches are labelled positionally with train_ecmwf's time
# coordinate below, so the files must be read in a deterministic order.
# A natural sort keeps e.g. batch_train_2 ahead of batch_train_10.
# NOTE(review): assumes the numeric part of each file name follows time order — confirm.
train_diff_files = sorted(
    glob.glob('data/diffusion/batch_train_*.npy'),
    key=lambda path: [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path)],
)
if not train_diff_files:
    # Fail with a clear message instead of np.concatenate's generic error on an empty list.
    raise FileNotFoundError("no files match 'data/diffusion/batch_train_*.npy'")

# Stack all batches along the first axis.
train_diff_samples = np.concatenate([np.load(file) for file in train_diff_files], axis=0)
# Average over axis 1, then keep index 0 of the fourth remaining axis.
# NOTE(review): presumably axis 1 is the sample axis and the last axis a channel — confirm.
train_diff = np.mean(train_diff_samples, axis=1)[:, :, :, 0]
# Wrap the array with the coordinates of the training ECMWF data.
train_diff_da = xr.DataArray(
    data=train_diff,
    dims=["time", "latitude", "longitude"],
    coords=train_ecmwf.coords,
)

In [68]:
# Load the test-split diffusion samples from a single .npy file.
test_diff_samples = np.load('data/diffusion/batch_5samples.npy')
# Average over axis 1 and keep index 0 of the trailing axis, leaving the three
# axes named below.
# NOTE(review): presumably axis 1 is the sample axis and the last axis a channel — confirm.
test_diff = test_diff_samples.mean(axis=1)[..., 0]
# Wrap the array with the coordinates of the test ECMWF data.
test_diff_da = xr.DataArray(
    test_diff,
    coords=test_ecmwf.coords,
    dims=("time", "latitude", "longitude"),
)

In [69]:
### Get the training-split statistics used for de-normalizing
# NOTE(review): the helper is named get_mean_std_data but two of its results are
# bound to var_* names — confirm whether it returns standard deviations or
# variances, since these values are later used as a multiplicative scale.
mean_chirps, var_chirps, mean_ecmwf, var_ecmwf = get_mean_std_data(split="train")

In [70]:
# De-normalize the diffusion outputs: x * (var_chirps + 1e-4) + mean_chirps.
# The normalized DataArrays are rebuilt from the raw NumPy arrays on every run,
# so re-executing this cell cannot apply the scaling a second time (the previous
# `x = x * scale + mean` form compounded on each re-run).
# NOTE(review): confirm the scale matches the normalization used in training
# (std vs. variance, and the 1e-4 epsilon).
denorm_scale = var_chirps + 1e-4
train_diff_da = xr.DataArray(
    data=train_diff,
    dims=["time", "latitude", "longitude"],
    coords=train_ecmwf.coords,
) * denorm_scale + mean_chirps
test_diff_da = xr.DataArray(
    data=test_diff,
    dims=["time", "latitude", "longitude"],
    coords=test_ecmwf.coords,
) * denorm_scale + mean_chirps

In [49]:
# Join the train and test diffusion outputs along time and tag the result with
# a NaN `nodata` attribute before the SPI computation.
da_diff = xr.concat(
    (train_diff_da, test_diff_da),
    dim='time',
).assign_attrs(nodata=np.nan)

In [50]:
# Display the last timestamp of the training split; the SPI cells pass this
# same expression as `calibration_stop`.
train_ecmwf.time.isel(time=-1).values

numpy.datetime64('2015-12-21T00:00:00.000000000')

In [51]:
# SPI of the diffusion output over the full (train + test) series, with the
# calibration window spanning the first to the last training timestamp.
spi_diffusion = da_diff.hdc.algo.spi(
    calibration_start=train_ecmwf["time"][0].values,
    calibration_stop=train_ecmwf["time"][-1].values,
)

In [55]:
# Keep only the time steps that belong to the test split.
spi_diffusion = spi_diffusion.sel(time=test_ecmwf.time)

In [85]:
# Display the diffusion SPI array for inspection.
spi_diffusion

In [87]:
# Write the diffusion SPI to a Zarr store.
# NOTE(review): no `mode` is passed, so re-running this cell presumably fails
# when the store already exists — confirm and pass an explicit mode if
# overwriting is intended.
spi_diffusion.to_zarr("data/diffusion/test_diffusion_spi.zarr")

<xarray.backends.zarr.ZarrStore at 0x7faa0b4398c0>

Benchmarks

In [60]:
# Tag the bilinear benchmark with a NaN `nodata` attribute, then compute its SPI
# with the calibration window spanning the first to the last training timestamp.
bilint.attrs['nodata'] = np.nan
spi_bilint = bilint.hdc.algo.spi(
    calibration_start=train_ecmwf["time"][0].values,
    calibration_stop=train_ecmwf["time"][-1].values,
)

In [88]:
# Keep only the time steps that belong to the test split.
spi_bilint = spi_bilint.sel(time=test_ecmwf.time)

In [89]:
# Write the bilinear-benchmark SPI to a Zarr store.
# NOTE(review): no `mode` is passed, so re-running this cell presumably fails
# when the store already exists — confirm and pass an explicit mode if
# overwriting is intended.
spi_bilint.to_zarr("data/benchmarks_results/bilinear/test_bilint_spi.zarr")

<xarray.backends.zarr.ZarrStore at 0x7faa0b4c16c0>

In [62]:
# Tag the quantile-mapping benchmark with a NaN `nodata` attribute, then compute
# its SPI with the calibration window spanning the first to the last training
# timestamp.
qm.attrs['nodata'] = np.nan
spi_qm = qm.hdc.algo.spi(
    calibration_start=train_ecmwf["time"][0].values,
    calibration_stop=train_ecmwf["time"][-1].values,
)

In [80]:
# Keep only the time steps that belong to the test split.
spi_qm = spi_qm.sel(time=test_ecmwf.time)

In [81]:
# Write the quantile-mapping SPI to a Zarr store.
# NOTE(review): no `mode` is passed, so re-running this cell presumably fails
# when the store already exists — confirm and pass an explicit mode if
# overwriting is intended.
spi_qm.to_zarr("data/benchmarks_results/QM/test_qm_spi.zarr")

<xarray.backends.zarr.ZarrStore at 0x7faa0b451640>

Ground truth

In [78]:
# SPI of the CHIRPS series, using the same calibration window as the model and
# benchmark SPIs (first to last training timestamp).
# NOTE(review): unlike the other inputs, no `nodata` attribute is set on `chirps`
# in the visible cells — confirm it is carried over from the Zarr store or is
# not required here.
spi_gt = chirps.hdc.algo.spi(
    calibration_start=train_ecmwf["time"][0].values,
    calibration_stop=train_ecmwf["time"][-1].values,
)

In [82]:
# Keep only the time steps that belong to the test split.
spi_gt = spi_gt.sel(time=test_ecmwf.time)

In [84]:
# Write the CHIRPS SPI to a Zarr store.
# NOTE(review): no `mode` is passed, so re-running this cell presumably fails
# when the store already exists — confirm and pass an explicit mode if
# overwriting is intended.
spi_gt.to_zarr("data/test_spi_chirps.zarr")

<xarray.backends.zarr.ZarrStore at 0x7faa0b749c40>