### This notebook computes all the scores for the VAE models presented in the paper.
We recommend running this notebook in a GPU-enabled environment. On a reasonably fast GPU (Nvidia RTX 4090), the notebook should take around 2 hours to finish.

In [None]:
import sys
sys.path.append('..')
import os
from tqdm import tqdm
import torch
import numpy as np
import pandas as pd
from multiprocessing.pool import Pool
from pysteps.verification.detcatscores import det_cat_fct_init as contingency_init
from pysteps.verification.detcatscores import det_cat_fct_accum as contingency_accum
from pysteps.verification.detcatscores import det_cat_fct_compute as contingency_compute
from pysteps.verification.detcontscores import det_cont_fct_init as continuous_init
from pysteps.verification.detcontscores import det_cont_fct_accum as continuous_accum
from pysteps.verification.detcontscores import det_cont_fct_compute as continuous_compute
from pysteps.utils.spectral import rapsd
from pysteps.visualization.spectral import plot_spectrum1d
from pysteps.verification.salscores import sal
import matplotlib.pyplot as plt
import pickle
from gptcast.data import MiaradDataModule
from gptcast.models import VAEGANVQ
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
def normalized_reflectivity_to_rainrate(arr: np.ndarray,
                                        minmax: tuple = (-20, 60),
                                        a: float = 200.0,
                                        b: float = 1.6):
    """
    Convert min-max normalized reflectivity to rain rate.

    Parameters
    ----------
    arr :
        Reflectivity normalized to the 0 - 1 range (array-like or scalar).
        The input is never modified.
    minmax :
        ``(lower, upper)`` reflectivity bounds in dBZ that the values 0 and 1
        correspond to. The normalization is inverted as
        ``dBZ = arr * (upper - lower) + lower``.
    a, b :
        Coefficients of the Z-R power law ``Z = a * R ** b``.

    Returns
    -------
    numpy.ndarray
        Rain rate in mm/h with the same shape as ``arr``. Rates below
        0.04 mm/h are set to 0.
    """
    lower, upper = minmax
    # Undo the min-max normalization. The lower bound has to be added back,
    # otherwise any range that does not start at 0 dBZ is shifted.
    dbz = np.asarray(arr) * (upper - lower) + lower
    z_linear = 10.0 ** (dbz / 10.0)  # wradlib.trafo.idecibel
    rr = (z_linear / a) ** (1.0 / b)  # wradlib.zr.z_to_r
    # np.where (instead of boolean-mask assignment) also works for scalars.
    return np.where(rr < 0.04, 0.0, rr)

In [None]:
# Options for the MIARAD data module, collected in one place before the call.
# NOTE(review): clip_and_normalize=[0, 60, -1, 1] presumably clips the data to
# 0-60 and rescales it to [-1, 1] — confirm against MiaradDataModule.
datamodule_kwargs = {
    "clip_and_normalize": [0, 60, -1, 1],
    "crop": None,
    "batch_size": 8,
    "num_workers": 8,
    "pin_memory": False,
}
md = MiaradDataModule.load_from_zenodo(**datamodule_kwargs)

# Only the test stage is prepared here.
md.setup(stage="test")

In [4]:
# Folder that receives every CSV, PNG and pickle written by this notebook.
output_dir = '../data/verification_vae/'

# exist_ok lets the notebook be re-run when the folder is already present.
os.makedirs(name=output_dir, exist_ok=True)

In [5]:
# Thresholds at which the categorical (contingency-table) scores are evaluated;
# they are compared against fields that have been converted to rain rate.
test_thresholds = [0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 50.0]

# Wavelength tick marks for the spectrum plots: powers of two from 256 down to 2.
wavelength_ticks = [2 ** exponent for exponent in range(8, 0, -1)]

In [None]:
model_list = ["vae_mae", "vae_mwae"]
max_dbz = 60     # reflectivity (dBZ) that the value 1 of the 0-1 range stands for
sal_workers = 8  # worker processes used to compute SAL for the images of a batch


def _to_unit_range_images(tensor):
    """Turn a model tensor in [-1, 1] into a (n_images, H, W) array in [0, 1]."""
    arr = tensor.cpu().numpy()
    # Fold every leading axis into one instead of calling a bare squeeze():
    # squeeze() would also drop the batch axis when a batch holds a single
    # image, and the per-image loops below would then iterate over rows.
    # NOTE(review): assumes the tensors are (batch, 1, H, W) — confirm.
    arr = arr.reshape(-1, *arr.shape[-2:])
    return (arr.clip(-1, 1) + 1) / 2


for model in model_list:
    ae = VAEGANVQ.load_from_zenodo(model, device=device).eval()
    cat_tables = [contingency_init(k, axis=None) for k in test_thresholds]
    cont_scores = continuous_init()
    tdl = md.test_dataloader()
    input_spectras = []
    recons_spectras = []
    sal_scores = []

    # The worker pool is created once per model; spawning and tearing down the
    # processes for every single batch is pure overhead.
    with Pool(sal_workers) as pool, torch.no_grad():
        for batch in tqdm(tdl, desc=model):
            # crop the last two axes to 288 x 368 before running the model
            input_image = ae.get_input(batch, 'image').to(device=device)[..., :288, :368]
            dec, _ = ae(input_image, return_pred_indices=False)

            # rescale to 0-1
            inpt = _to_unit_range_images(input_image)
            recons = _to_unit_range_images(dec)

            # compute SAL (structure, amplitude, location) score
            # SAL scores are computed on the reflectivity, reconstruction first
            sal_scores.extend(pool.starmap(sal, zip(recons * max_dbz, inpt * max_dbz)))

            # transform to rainrate
            inpt = normalized_reflectivity_to_rainrate(inpt, minmax=(0, max_dbz))
            recons = normalized_reflectivity_to_rainrate(recons, minmax=(0, max_dbz))

            # sanity checks: NaN or inf would silently corrupt the accumulated scores
            assert inpt.shape == recons.shape
            assert np.isfinite(recons).all(), "non-finite value in the reconstruction"
            assert np.isfinite(inpt).all(), "non-finite value in the input"

            # cycle through the batch and collect the power spectrum of every image
            for inpt_el, recons_el in zip(inpt, recons):
                spectra, freq = rapsd(inpt_el, fft_method=np.fft, return_freq=True, d=1, normalize=False)
                input_spectras.append(spectra)
                spectra, freq = rapsd(recons_el, fft_method=np.fft, return_freq=True, d=1, normalize=False)
                recons_spectras.append(spectra)

            # categorical and continuous scores are accumulated on the whole batch
            for c in cat_tables:
                contingency_accum(c, recons, inpt)
            continuous_accum(cont_scores, recons, inpt)

    if not input_spectras:
        # Without this the code below would fail later with a NameError on `freq`.
        raise RuntimeError(f"the test dataloader yielded no images for model {model!r}")

    cat_scores = {str(table['thr']): contingency_compute(table) for table in cat_tables}
    cont_scores = continuous_compute(cont_scores)
    inpt_mean_spectra = np.array(input_spectras).mean(axis=0)
    recons_mean_spectra = np.array(recons_spectras).mean(axis=0)

    # save scores to disk as CSV
    pd.DataFrame(cat_scores).T.to_csv(os.path.join(output_dir, f'{model}.cat.csv'))
    pd.DataFrame(cont_scores, index=[0]).to_csv(os.path.join(output_dir, f'{model}.cont.csv'))

    # plot of the mean spectra of the input and reconstructed images
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    plot_spectrum1d(freq, inpt_mean_spectra, x_units='km', y_units='mm/h', color='k', lw=2, ax=ax, wavelength_ticks=wavelength_ticks, label='Input')
    plot_spectrum1d(freq, recons_mean_spectra, x_units='km', y_units='mm/h', color='red', ax=ax, wavelength_ticks=wavelength_ticks, label='Reconstructed')
    ax.legend()  # the labels passed above are only shown once a legend is drawn
    fig.savefig(os.path.join(output_dir, f'{model}.spectra.png'))
    plt.close(fig)

    # save spectra to disk as pickle
    with open(os.path.join(output_dir, f'{model}.spectra.pkl'), 'wb') as f:
        pickle.dump({'input': input_spectras, 'recons': recons_spectras, 'freq': freq, 'sal': sal_scores}, f)
