In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import yaml

# Config matplotlib
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# Custom utils
from utils.utils_data import *
from utils.utils_plot import *

In [None]:
# Percentile used as the threshold for precipitation extremes.
# The value is also embedded in the cached data/prediction file names below.
PRECIP_XTRM = 99

In [None]:
# Load the run configuration; the context manager closes the file handle
# (the bare open() previously leaked it). `yaml` is imported in the first cell.
# NOTE(review): `conf` is not used in the cells shown here — confirm it is still needed.
with open("config.yaml") as config_file:
    conf = yaml.safe_load(config_file)

In [None]:
# Grid axes of the target field (longitudes first, then latitudes)
lons_y, lats_y = (np.load(f'tmp/data/{axis}.npy') for axis in ('lons_y', 'lats_y'))

# Observed precipitation and the corresponding extreme indicator for the
# PRECIP_XTRM percentile (the indicator is plotted on a 0-1 scale later on)
y_prec = np.load('tmp/data/y_prec.npy')
y_xtrm = np.load(f'tmp/data/y_xtrm_0{PRECIP_XTRM}.npy')

# Percentile thresholds, compared against predicted precipitation to flag extremes
qq = np.load(f'tmp/data/qq_0{PRECIP_XTRM}.npy')

In [None]:
# Daily time axis of the evaluation period: 2016-01-01 up to, but excluding, 2022-01-01.
times = pd.to_datetime(
    np.arange(np.datetime64('2016-01-01'), np.datetime64('2022-01-01'))
)

In [None]:
# Models compared in the report figure, one row each, in this order.
models_report = [
    'RF4',
    'Pan-orig',
    'UNET2',
    'RaNet',
]

In [None]:
def create_xarray_from_pred_RF(preds, time, lats_y, lons_y):
    """Assemble per-grid-cell RF predictions into a (time, lat, lon) DataArray.

    Args:
        preds: 2-D array indexed as ``preds[ilat, ilon]``. Each entry is either
            None (no prediction for that cell) or array-like broadcastable to
            length ``time``. In this notebook it is the object array loaded with
            ``allow_pickle=True``.
        time: number of time steps (an int, e.g. ``len(times)``). No time
            coordinate is attached to the result.
        lats_y: latitude coordinate values; must match ``preds.shape[0]``.
        lons_y: longitude coordinate values; must match ``preds.shape[1]``.

    Returns:
        xr.DataArray of float64 with dims ("time", "lat", "lon") and ``lat``/``lon``
        coordinates. Cells whose prediction is None are NaN.
    """
    n_lat, n_lon = len(lats_y), len(lons_y)

    # Fill a plain numpy buffer and wrap it once at the end. Assigning element by
    # element through DataArray indexing gives the same values but is far slower.
    # Start from NaN so cells with a None prediction need no explicit assignment.
    values = np.full((time, n_lat, n_lon), np.nan)
    for ilat in range(n_lat):
        for ilon in range(n_lon):
            cell = preds[ilat, ilon]
            if cell is not None:
                values[:, ilat, ilon] = cell

    return xr.DataArray(values, dims=["time", "lat", "lon"],
                        coords=dict(lat=lats_y, lon=lons_y))

In [None]:
# Plotting: on the day with the most observed extremes, compare the truth (row 0)
# against every model in `models_report` (one row each).
def _load_model_output(kind, m_id):
    """Load one model's saved output for the test period as a numpy array.

    Args:
        kind: file prefix, either 'y_pred_test' (predicted precipitation values)
            or 'y_pred_test_xtrm' (saved extreme predictions).
        m_id: model identifier from `models_report`.

    Returns:
        numpy array. For 'RF4' the per-grid-cell object array is expanded into
        (time, lat, lon) with `create_xarray_from_pred_RF`, so None cells become NaN.
    """
    path = f'tmp/data/predictions/{kind}_{m_id}_{PRECIP_XTRM}.npy'
    if m_id == 'RF4':
        # RF outputs are stored as an object array of per-cell predictions,
        # which requires pickle. Only load these files from trusted sources.
        preds_per_cell = np.load(path, allow_pickle=True)
        return create_xarray_from_pred_RF(preds_per_cell, len(times), lats_y, lons_y).to_numpy()
    return np.load(path)


n_models = len(models_report)

n_rows = n_models + 1
fig, axs = plt.subplots(n_rows, 3, figsize=(10, n_rows * 3))

# Index of the day with the largest number of extreme grid cells. Use nansum so
# a NaN cell cannot poison a day's count: np.argmax returns the first NaN it meets.
# NOTE(review): assumes the first axis of y_xtrm/y_prec is aligned with the first
# axis of every model's predictions — confirm.
i_max_obs = np.argmax(np.nansum(y_xtrm, axis=(1, 2)))

# Shared colour limit for all precipitation-value panels (NaN cells ignored).
vmax = np.nanmax(y_prec[i_max_obs])

# Row 0: observations. Both extreme columns show the same observed extremes, so
# the truth row lines up with the two model extreme columns below.
plot_map(axs[0, 0], lons_y, lats_y, y_prec[i_max_obs], title="Prec. value - truth", vmin=0, vmax=vmax, show_colorbar=False, cmap=mpl.cm.YlGnBu)
plot_map(axs[0, 1], lons_y, lats_y, y_xtrm[i_max_obs], title="Prec. extreme - truth", vmin=0, vmax=1, show_colorbar=False)
plot_map(axs[0, 2], lons_y, lats_y, y_xtrm[i_max_obs], title="Prec. extreme - truth", vmin=0, vmax=1, show_colorbar=False)

for row, m_id in enumerate(models_report, start=1):
    y_pred_test = _load_model_output('y_pred_test', m_id)

    # Column 1: extremes obtained by thresholding the predicted values at qq.
    # Missing predictions or thresholds stay NaN. Previously `NaN > qq` evaluated
    # to False, so these cells were drawn as "no extreme".
    missing = np.isnan(y_pred_test) | np.isnan(qq)
    y_pred_bool = np.where(missing, np.nan, y_pred_test > qq)

    # Column 2: the model's separately saved extreme predictions.
    y_pred_test_xtrm = _load_model_output('y_pred_test_xtrm', m_id)

    plot_map(axs[row, 0], lons_y, lats_y, y_pred_test[i_max_obs], title=f"Prec. value - {m_id}", vmin=0, vmax=vmax, show_colorbar=False, cmap=mpl.cm.YlGnBu)
    plot_map(axs[row, 1], lons_y, lats_y, y_pred_bool[i_max_obs], title=f"Prec. extreme - {m_id}", vmin=0, vmax=1, show_colorbar=False)
    plot_map(axs[row, 2], lons_y, lats_y, y_pred_test_xtrm[i_max_obs], title=f"Prec. extreme - {m_id}", vmin=0, vmax=1, show_colorbar=False)

fig.tight_layout()
fig.savefig(f'figures/plot_model_comparison_{PRECIP_XTRM}_report.pdf')