### Analysis of the network's performance when applying the composite LRP rule (LRP-composite)

In [None]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= '0.20'

from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_score, recall_score, roc_auc_score, roc_curve
from sklearn.utils import class_weight

# TensorFlow ≥2.0 is required
import tensorflow_addons as tfa
import tensorflow as tf
assert tf.__version__ >= '2.0'

from tensorflow import keras
from tensorflow.keras import layers, regularizers

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Common imports
import os
import glob
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import dask
import datetime
import math
import pickle
import pathlib
import hashlib
import seaborn
dask.config.set({'array.slicing.split_large_chunks': False})

# To make this notebook's output stable across runs
np.random.seed(42)

# Config matplotlib
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import StrMethodFormatter

mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Dotenv
from dotenv import dotenv_values

# Custom utils
from utils.utils_data import *
from utils.utils_ml import *
from utils.utils_resnet import *
from utils.utils_plot import *
from utils.DNN_models import *
from utils.rank_relevances import *

In [None]:
import yaml

# Experiment configuration; the cells below read e.g. conf['varnames'] and
# conf['i_shape']. The context manager closes the file once it is parsed.
with open("config.yaml", encoding="utf-8") as fh:
    conf = yaml.safe_load(fh)

PRECIP_XTRM = 0.95  # Percentile (threshold) for the extremes
PRECIP_DATA = 'ERA5-low'

In [None]:
# Load the TRAINING samples.
# xr.load_dataarray reads each file into memory and closes it again
# (xr.open_dataarray would keep the file handle open).
dg_train_X = np.array(xr.load_dataarray('tmp/data/dg_train_X.nc'))
# Extremes target; the file name encodes the percentile threshold PRECIP_XTRM
# (0.95 -> '..._xtrm0.95th.nc'). Used below as `y_bool`.
dg_train_Y_xtrm = xr.load_dataarray(f'tmp/data/dg_train_Y_xtrm{PRECIP_XTRM}th.nc')

In [None]:
# Load the TEST samples: predictors, target and extremes target.
# xr.load_dataarray reads each file into memory and closes it again.
dg_test_X = np.array(xr.load_dataarray('tmp/data/dg_test_X.nc'))
dg_test_Y = np.array(xr.load_dataarray('tmp/data/dg_test_Y.nc'))
# The file name encodes the percentile threshold PRECIP_XTRM (0.95 -> '0.95th').
dg_test_Y_xtrm = xr.load_dataarray(f'tmp/data/dg_test_Y_xtrm{PRECIP_XTRM}th.nc')

In [None]:
# Grid coordinates of the target domain.
lons_x = np.load('tmp/data/lons_y.npy')
lats_y = np.load('tmp/data/lats_y.npy')

# Daily time stamps of the TRAINING period, 1979-01-01 .. 2005-12-31
# (end date exclusive in np.arange). For the test period use
# '2016-01-01' .. '2022-01-01' instead.
times = pd.to_datetime(
    np.arange('1979-01-01', '2006-01-01', dtype='datetime64[D]')
)

In [None]:
y_bool = dg_train_Y_xtrm  # extremes target handed to the relevance helpers below

In [None]:
models = [f'UNET{k}' for k in range(1, 5)]  # U-Net variants whose LRP maps are compared

In [None]:
# Load the composite-LRP relevances of every model (files computed on the
# training set) and rank the input variables by their mean local relevance.
lrp_all = []
list_df = []

for m_id in models:
    print('LRP', m_id)
    lrp = np.load(f'tmp/LRP/lrpcomp_train_DNN_{m_id}.npy')
    lrp_all.append(lrp)
    localrel_avg, localrel_max = getmap_localrel(
        lrp, conf['i_shape'], conf['varnames'], y_bool, lats_y, lons_x, times,
        icrop=3,
    )
    # Mean over the first two axes -> one value per variable.
    # Assumes the variable axis is last (lat, lon, var) — TODO confirm.
    rel_varmeans = localrel_avg.mean(axis=(0, 1))
    # Variable indices by decreasing mean relevance; computed once and reused
    # for both the names and the values.
    order = np.argsort(rel_varmeans)[::-1]
    df_sortedvars = [conf['varnames'][i] for i in order]
    df_sortval = [rel_varmeans[i] for i in order]
    data = {'Model': m_id, 'Variable': df_sortedvars, 'Values': df_sortval}
    df = pd.DataFrame(data)
    list_df.append(df)
    


In [None]:
# Stack the per-model rankings in the order wanted for the plot legend.
# Frames are picked by model name, not list position, so the order stays
# right if `models` is edited.
plot_order = ['UNET2', 'UNET4', 'UNET3', 'UNET1']
df_all = pd.concat([list_df[models.index(name)] for name in plot_order])

In [None]:
# Wide table: rows are rank positions (the RangeIndex each per-model frame was
# built with), columns are ('Variable' | 'Values') x Model.
dfpivot = pd.pivot(df_all, columns='Model')

In [None]:
# plot the outputs

In [None]:
df_all.head(n=5)  # peek at the first rows of the stacked ranking table

In [None]:
import seaborn as sns  # the import cell binds `seaborn`, but this cell uses `sns`

# Bar chart of the mean local relevance per input variable, one bar per model.
fig, ax = plt.subplots(figsize=(14, 6))

# sns.barplot draws on `ax` and returns that same Axes.
b = sns.barplot(ax=ax, x="Variable", y="Values", hue="Model", data=df_all)
ax.set_xlabel('')
ax.set_ylabel('')
ax.tick_params(axis='x', rotation=90)
ax.tick_params(labelsize=15)
ax.yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}'))

# Save the figure.
# NOTE(review): the relevances are loaded from `lrpcomp_train_*` files, yet
# the file name says "test" — confirm which split the name should reflect.
fname = 'Ranking_predictors_test_UNETs'
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/' + fname + '.pdf')

In [None]:
# Select UNET2 as the best network (highest scores). Look it up by name so the
# choice cannot silently drift if `models` is reordered.
best_model = 'UNET2'
rel = lrp_all[models.index(best_model)]

In [None]:
# Pixel-wise mean/max relevance maps of the selected model.
# NOTE(review): coordinates are passed as (lons_x, lats_y) here, whereas
# getmap_localrel above receives (lats_y, lons_x) — confirm getmap_rel's order.
pixelrel_avg, pixelrel_max = getmap_rel(
    rel,
    conf['i_shape'],
    y_bool,
    lons_x,
    lats_y,
    False,
)

In [None]:
def twoslope_limits(data_min, data_max, vcenter, vmin=None, vmax=None):
    """Return ``(vmin, vmax)`` limits that are valid for ``TwoSlopeNorm``.

    ``matplotlib.colors.TwoSlopeNorm`` raises ``ValueError`` unless
    ``vmin < vcenter < vmax``. Explicit ``vmin``/``vmax`` take precedence over
    the data extremes. Any limit on the wrong side of ``vcenter`` is replaced
    by ``None``, so matplotlib autoscales it from the data instead of raising.
    This covers, e.g., non-negative relevances with ``vcenter=0``.
    """
    lo = data_min if vmin is None else vmin
    hi = data_max if vmax is None else vmax
    lo = lo if lo < vcenter else None
    hi = hi if hi > vcenter else None
    return lo, hi


def plot_xr_rel(rel, lats_y, lons_x, vnames, fname, cmap='Reds', vmin=None, vmax=None, vcenter=None, plot=True):
    """Plot one relevance map per input variable on a lat/lon grid.

    Parameters
    ----------
    rel : array-like
        Relevance values laid out as (lat, lon, variable), matching
        ``lats_y``, ``lons_x`` and ``vnames``.
    lats_y, lons_x : array-like
        Latitude / longitude coordinates of the grid.
    vnames : sequence of str
        Variable names; one panel (titled with the name) per variable.
    fname : str
        Output file stem; the figure is saved to ``figures/<fname>.pdf``.
    cmap : str
        Matplotlib colormap.
    vmin, vmax : float, optional
        Colour limits. Without ``vcenter`` they are passed to xarray directly.
        With ``vcenter`` they override the data extremes of the norm.
    vcenter : float, optional
        If given, colour with a diverging ``TwoSlopeNorm`` centred on it.
    plot : bool
        If True, save the figure as PDF; otherwise only redraw it.
    """
    import os
    import matplotlib.colors as mcolors  # not bound by the notebook's import cell

    mx = xr.DataArray(rel, dims=["lat", "lon", "variable"],
                      coords=dict(lat=lats_y, lon=lons_x, variable=vnames))

    common = dict(col="variable", col_wrap=6, robust=True, cmap=cmap,
                  yincrease=False, extend='both',
                  cbar_kwargs={"orientation": "vertical", "shrink": 0.9, "aspect": 50})

    if vcenter is None:
        g = mx.plot.pcolormesh("lon", "lat", vmin=vmin, vmax=vmax,
                               figsize=(12, 8), **common)
    else:
        # Honour the requested centre; drop limits that would make
        # TwoSlopeNorm raise (e.g. data min > vcenter).
        lo, hi = twoslope_limits(float(mx.min()), float(mx.max()), vcenter, vmin, vmax)
        norm = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=lo, vmax=hi)
        g = mx.plot.pcolormesh("lon", "lat", norm=norm,
                               figsize=(14, 14), **common)

    # Read the country outlines once rather than once per panel.
    # NOTE(review): gpd.datasets is deprecated/removed in newer geopandas
    # releases — confirm against the pinned geopandas version.
    world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
    for ax, title in zip(g.axes.flat, vnames):
        world.boundary.plot(ax=ax, lw=1, color='k')
        ax.set_xlim(min(lons_x), max(lons_x))
        ax.set_ylim(min(lats_y), max(lats_y))
        ax.set_title(title)

    # Leave room on the right for the colorbar and between the panels.
    plt.subplots_adjust(right=0.8, wspace=0.1, hspace=0.3)
    if plot:
        os.makedirs('figures', exist_ok=True)
        plt.savefig('figures/' + fname + '.pdf')
    else:
        plt.draw()

In [None]:
# Pixel-wise mean relevance maps of UNET2 (training set); plot defaults to
# True, so the figure is written to figures/<fname>.pdf.
plot_xr_rel(
    pixelrel_avg,
    lats_y,
    lons_x,
    vnames=conf['varnames'],
    fname='relevances_train_UNET2_pixelwise',
)

In [None]:
#pixelrel_avg_u4, pixelrel_max_4 = getmap_rel(lrp_all[3], conf['i_shape'], y_bool, lons_x, lats_y, False)

In [None]:
#plot_xr_rel(pixelrel_avg_u4, lats_y, lons_x, conf['varnames'], 'relevances_train_UNET4_pixelwise')