## Compute SPI for diffusion model evaluation

In [2]:
import glob
import hdc.algo
import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xss

from scipy.stats import spearmanr, pearsonr
from read_data import get_cond_ecmwf_dataset, get_mean_std_data

2024-04-09 18:16:56.842880: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-09 18:16:56.845095: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-09 18:16:56.894292: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-09 18:16:56.895066: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


**Read forecasts (ECMWF) and ground truth (CHIRPS)**

In [3]:
# Load the ECMWF forecast and the CHIRPS variable for both splits.
# Each multi-file dataset is opened once and both variables are read from it;
# the `with` block closes the underlying stores as soon as the arrays have been
# loaded into memory, so nothing lazy outlives this cell.
with xr.open_mfdataset('data/train/*.zarr') as train_ds:
    train_ecmwf = train_ds.ecmwf.load()
    train_chirps = train_ds.chirps.load()

with xr.open_mfdataset('data/test/*.zarr') as test_ds:
    test_ecmwf = test_ds.ecmwf.load()
    test_chirps = test_ds.chirps.load()

# Full series = train followed by test along the time axis.
ecmwf = xr.concat([train_ecmwf, test_ecmwf], dim='time')
chirps = xr.concat([train_chirps, test_chirps], dim='time')



**Read benchmarks** 

In [6]:
# Quantile-mapping benchmark: variable `scen`, train split then test split.
qm_parts = [
    xr.open_zarr(path)['scen'].load()
    for path in (
        'data/benchmarks_results/QM/train_quantile_mapping.zarr',
        'data/benchmarks_results/QM/test_quantile_mapping.zarr',
    )
]
train_qm, test_qm = qm_parts
qm = xr.concat(qm_parts, dim='time')

# Bilinear-interpolation benchmark: variable `bilint`, train split then test split.
bilint_parts = [
    xr.open_zarr(path)['bilint'].load()
    for path in (
        'data/benchmarks_results/bilinear/train_bilinear.zarr',
        'data/benchmarks_results/bilinear/test_bilinear.zarr',
    )
]
train_bilint, test_bilint = bilint_parts
bilint = xr.concat(bilint_parts, dim='time')

#train_climax = xr.open_zarr('data/benchmarks_results/ClimaX/train_climax.zarr')
#test_climax = xr.open_zarr('data/benchmarks_results/ClimaX/test_climax.zarr')
#climax = xr.concat([train_climax, test_climax], dim='time')

**Compute SPI** 

Diffusion

In [16]:
# glob.glob returns paths in arbitrary, filesystem-dependent order, but the
# concatenated batches are labelled below with train_ecmwf's coordinates, so
# the files must be read in a deterministic order. Sorting by (length, name)
# gives numeric order for a plain integer suffix, zero-padded or not.
# NOTE(review): assumes the `*` part is a batch counter — confirm the naming scheme.
train_diff_files = sorted(
    glob.glob('data/diffusion/batch_train_*.npy'),
    key=lambda path: (len(path), path),
)
if not train_diff_files:
    # Fail with a clear message instead of np.concatenate's generic error on an empty list.
    raise FileNotFoundError("No files match 'data/diffusion/batch_train_*.npy'")

# Stack all batches along the first axis.
train_diff_samples = np.concatenate([np.load(file) for file in train_diff_files], axis=0)
# Average over axis 1 (presumably the sample axis — TODO confirm), then keep
# index 0 of the trailing axis so a 3-D array remains.
train_diff = np.mean(train_diff_samples, axis=1)[:, :, :, 0]
# Wrap as a DataArray that reuses the coordinates of the training forecasts.
train_diff_da = xr.DataArray(
    data=train_diff,
    dims=["time", "latitude", "longitude"],
    coords=train_ecmwf.coords,
)

In [1]:
# Diffusion output for the test split, stored in a single .npy file.
test_diff_samples = np.load('data/diffusion/batch_25samples_north.npy')
# Mean over axis 1 (presumably the 25 samples — TODO confirm), then index 0 of
# the trailing axis, leaving a 3-D array.
test_diff_mean = test_diff_samples.mean(axis=1)
test_diff = test_diff_mean[:, :, :, 0]
# Label the array with the coordinates of the test forecasts.
test_diff_da = xr.DataArray(
    test_diff,
    coords=test_ecmwf.coords,
    dims=("time", "latitude", "longitude"),
)

NameError: name 'np' is not defined

In [18]:
# Statistics of the training split, used below to de-normalise the diffusion output.
# NOTE(review): the helper is named "mean_std" while the results are stored as
# `var_*` — confirm whether these are standard deviations or variances.
(
    mean_chirps,
    var_chirps,
    mean_ecmwf,
    var_ecmwf,
) = get_mean_std_data(split="train")

In [19]:
# De-normalise with the CHIRPS statistics: x * (var_chirps + 1e-4) + mean_chirps.
# Caution: this cell rebinds its own inputs, so running it twice rescales twice.
chirps_scale = var_chirps + 1e-4
train_diff_da = train_diff_da * chirps_scale + mean_chirps
test_diff_da = test_diff_da * chirps_scale + mean_chirps

In [20]:
# Join train and test predictions into one series along time.
da_diff = xr.concat((train_diff_da, test_diff_da), dim='time')
# Record NaN as the nodata value on the array's attributes.
# NOTE(review): this is set before the monthly aggregation, whereas the
# benchmark cells set it afterwards — confirm the attribute survives groupby().sum().
da_diff.attrs.update(nodata=np.nan)

In [21]:
# Monthly totals: label every timestamp with the first day of its month and sum per label.
diff_month_labels = da_diff.time.dt.strftime('%Y-%m-01')
da_diff = da_diff.groupby(diff_month_labels).sum().rename({'strftime': 'time'})
# The group labels are strings; turn them back into datetimes.
# NOTE(review): sum() skips NaN by default, so an all-NaN pixel becomes 0 — confirm that is intended.
da_diff.coords['time'] = pd.to_datetime(da_diff.time.values)

In [22]:
# SPI of the diffusion series via the hdc.algo accessor, calibrated over the
# span from the first to the last training timestamp.
calib_start = train_ecmwf.time.isel(time=0).values
calib_stop = train_ecmwf.time.isel(time=-1).values
spi_diffusion = da_diff.hdc.algo.spi(
    calibration_start=calib_start,
    calibration_stop=calib_stop,
)

In [23]:
# Keep only time steps whose calendar year occurs in the test split.
test_years = test_ecmwf.time.dt.year
diff_in_test_years = spi_diffusion.time.dt.year.isin(test_years)
spi_diffusion = spi_diffusion.where(diff_in_test_years, drop=True)

In [24]:
# Persist the test-period diffusion SPI; mode "w" overwrites an existing store.
spi_diffusion.to_zarr(store="data/diffusion/test_diffusion_spi.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x7f9f5c22f5c0>

Benchmarks

In [73]:
# Monthly totals: label every timestamp with the first day of its month and sum per label.
bilint_month_labels = bilint.time.dt.strftime('%Y-%m-01')
bilint = bilint.groupby(bilint_month_labels).sum().rename({'strftime': 'time'})
# The group labels are strings; turn them back into datetimes.
# NOTE(review): sum() skips NaN by default, so an all-NaN pixel becomes 0 — confirm that is intended.
bilint.coords['time'] = pd.to_datetime(bilint.time.values)

In [75]:
# Record NaN as the nodata value (presumably read by the hdc accessor — confirm).
bilint.attrs.update(nodata=np.nan)
# SPI of the bilinear benchmark, calibrated over the span from the first to
# the last training timestamp.
calib_start = train_ecmwf.time.isel(time=0).values
calib_stop = train_ecmwf.time.isel(time=-1).values
spi_bilint = bilint.hdc.algo.spi(
    calibration_start=calib_start,
    calibration_stop=calib_stop,
)

In [76]:
# Keep only time steps whose calendar year occurs in the test split.
test_years = test_ecmwf.time.dt.year
bilint_in_test_years = spi_bilint.time.dt.year.isin(test_years)
spi_bilint = spi_bilint.where(bilint_in_test_years, drop=True)

In [77]:
# Persist the test-period bilinear SPI; mode "w" overwrites an existing store.
spi_bilint.to_zarr(store="data/benchmarks_results/bilinear/test_bilint_spi.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x7f014a0d35c0>

In [78]:
# Monthly totals: label every timestamp with the first day of its month and sum per label.
qm_month_labels = qm.time.dt.strftime('%Y-%m-01')
qm = qm.groupby(qm_month_labels).sum().rename({'strftime': 'time'})
# The group labels are strings; turn them back into datetimes.
# NOTE(review): sum() skips NaN by default, so an all-NaN pixel becomes 0 — confirm that is intended.
qm.coords['time'] = pd.to_datetime(qm.time.values)

In [79]:
# Record NaN as the nodata value (presumably read by the hdc accessor — confirm).
qm.attrs.update(nodata=np.nan)
# SPI of the quantile-mapping benchmark, calibrated over the span from the
# first to the last training timestamp.
calib_start = train_ecmwf.time.isel(time=0).values
calib_stop = train_ecmwf.time.isel(time=-1).values
spi_qm = qm.hdc.algo.spi(
    calibration_start=calib_start,
    calibration_stop=calib_stop,
)

In [80]:
# Keep only time steps whose calendar year occurs in the test split.
test_years = test_ecmwf.time.dt.year
qm_in_test_years = spi_qm.time.dt.year.isin(test_years)
spi_qm = spi_qm.where(qm_in_test_years, drop=True)

In [81]:
# Persist the test-period quantile-mapping SPI; mode "w" overwrites an existing store.
spi_qm.to_zarr(store="data/benchmarks_results/QM/test_qm_spi.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x7f01483afe40>

Ground truth

In [82]:
# Monthly totals: label every timestamp with the first day of its month and sum per label.
chirps_month_labels = chirps.time.dt.strftime('%Y-%m-01')
chirps = chirps.groupby(chirps_month_labels).sum().rename({'strftime': 'time'})
# The group labels are strings; turn them back into datetimes.
# NOTE(review): sum() skips NaN by default, so an all-NaN pixel becomes 0 — confirm that is intended.
chirps.coords['time'] = pd.to_datetime(chirps.time.values)

In [83]:
# SPI of the CHIRPS ground truth, calibrated over the span from the first to
# the last training timestamp.
# NOTE(review): unlike the diffusion and benchmark series, no 'nodata' attribute
# is set on `chirps` in the visible cells — confirm it already carries one or
# that spi does not need it.
calib_start = train_ecmwf.time.isel(time=0).values
calib_stop = train_ecmwf.time.isel(time=-1).values
spi_gt = chirps.hdc.algo.spi(
    calibration_start=calib_start,
    calibration_stop=calib_stop,
)

In [84]:
# Keep only time steps whose calendar year occurs in the test split.
test_years = test_ecmwf.time.dt.year
gt_in_test_years = spi_gt.time.dt.year.isin(test_years)
spi_gt = spi_gt.where(gt_in_test_years, drop=True)

In [85]:
# Persist the test-period ground-truth SPI; mode "w" overwrites an existing store.
spi_gt.to_zarr(store="data/test_spi_chirps.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x7f014a06fcc0>