# Imports, loading, preprocessing

In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import os
import warnings
import pandas as pd
from glob import glob
import xarray as xr
import seaborn as sns
from IPython.display import SVG
from functools import partial
import geopandas as gpd
from shapely.geometry import Polygon, Point, MultiPolygon

from sklearn import linear_model
import pysumma.plotting as psp
import pysumma.utils as psu
import pysumma.evaluation as pse

from tqdm.notebook import tqdm
import networkx as nx

from sknetwork.visualization import svg_graph, svg_digraph, svg_bigraph
from sknetwork.embedding import *
from sknetwork.clustering import Louvain

from scipy import sparse
from IPython.display import SVG

from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.utils import plot_model, model_to_dot
from tensorflow.keras import layers
from tensorflow.keras.callbacks import Callback, EarlyStopping

from sklearn import preprocessing
from sklearn.linear_model import LinearRegression

import shap
import innvestigate

# NOTE(review): TensorFlow is imported in the cell above, so these environment
# variables are set *after* TF has loaded. TF_CPP_MIN_LOG_LEVEL may therefore have no
# effect on messages emitted at import time, and CUDA_VISIBLE_DEVICES only applies if
# no GPU has been initialised yet. Moving these lines above the imports is safest.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["OMP_NUM_THREADS"] = "1"
os.environ['TF_NUM_INTEROP_THREADS'] = '1'
os.environ['TF_NUM_INTRAOP_THREADS'] = '1'

# Apply the thread limits through the runtime API too, since TF is already imported.
try:
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
except RuntimeError:
    # TF raises once its runtime context is initialised (e.g. when this cell is re-run);
    # the thread pools can no longer be changed at that point, so keep going.
    pass

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)

# matplotlib >= 3.6 renamed the seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names; prefer the new name when it is available.
_talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in mpl.style.available else 'seaborn-talk'
mpl.style.use(_talk_style)
sns.set_context('talk')
mpl.rcParams['figure.figsize'] = (8, 6)

K.set_floatx('float32')

In [None]:
model = keras.models.load_model('../new_models/all_var_dense_dropout.h5')
# os.listdir order is arbitrary, so sort for a reproducible site order (it fixes the
# row/column order of every per-site list and matrix below). Also skip hidden entries
# such as '.ipynb_checkpoints' or '.DS_Store', which are not sites.
sim_sites = sorted(s for s in os.listdir('../sites') if not s.startswith('.'))

In [None]:
def _load_site_files(suffix):
    """Read ``../data_for_paper_2/{site}_{suffix}.nc`` for every site in ``sim_sites``.

    ``xr.load_dataset`` reads the data fully into memory and closes the file, whereas
    ``open_dataset(...).load()`` leaves one open file handle per site.
    """
    return {s: xr.load_dataset(f'../data_for_paper_2/{s}_{suffix}.nc')
            for s in tqdm(sim_sites, desc=suffix)}


site_ds = _load_site_files('data')        # forcing / simulated / observed fluxes
site_r_qle = _load_site_files('lrp_qle')  # LRP relevance to latent heat
site_r_qh = _load_site_files('lrp_qh')    # LRP relevance to sensible heat

In [None]:
site_attrs = gpd.GeoDataFrame.from_file('../data_for_paper_2/site_attrs.shp')

In [None]:
site_attrs.index = site_attrs['Site']

In [None]:
# Names of the NN input features, in the column order produced by format_data_for_nn.
feature_strings = [
    'airtemp',   # column 0
    'relhum',    # column 1
    'swradatm',  # column 2
    'surf_sm',   # column 3
    'lai',       # column 4
    'vegtype',   # column 5
]

# Plot colour per variable name (includes some names that are not NN inputs).
var_to_color = dict(
    swradatm='#4363d8',
    airtemp='#e6194B',
    relhum='#f58231',
    surf_sm='#911eb4',
    transpirable='#9A6324',
    vegtype='#469990',
    soiltype='goldenrod',
)

# Functions

In [None]:
def calc_relhum(T, p, sh):
    """Relative humidity from air temperature, pressure and specific humidity.

    Evaluates RH = 0.263 * p * sh / exp(17.67 * (T - 273.16) / (T - 29.65)).
    T is in kelvin (the formula subtracts 273.16). The result appears to be in percent,
    since format_data_for_nn divides it by 100. NOTE(review): the 0.263 coefficient is
    the form usually quoted for p in Pa; confirm the units of ``airpres``.
    Works elementwise on scalars or numpy arrays.
    """
    freezing_k = 273.16
    saturation_factor = np.exp(17.67 * (T - freezing_k) / (T - 29.65))
    return sh * (0.263 * p) / saturation_factor

def L_vap(T):
    """Latent heat of vaporization in J/g, as a cubic polynomial in temperature.

    T is in kelvin; the polynomial is evaluated in degrees Celsius (T - 273.16).
    """
    celsius = T - 273.16
    return 2500.8 - 2.36 * celsius + 0.0016 * celsius ** 2 - 0.00006 * celsius ** 3

def Qle_to_ET(Q, T):
    """Convert a latent heat flux Q into an evaporation mass flux.

    L_vap returns J/g; multiplying by 1e3 converts it to J/kg. So for Q in W/m^2 the
    result is in kg m^-2 s^-1.
    """
    joules_per_kg = 1e3 * L_vap(T)
    return (1 / joules_per_kg) * Q

def calc_pet(tmin, tmax, tmean, rad, rad_mult=1):
    """Hargreaves-style potential evapotranspiration.

    PET = 0.0023 * (tmean + 17.8) * sqrt(tmax - tmin) * 0.408 * rad * rad_mult.
    NOTE(review): 0.408 is presumably the MJ m^-2 d^-1 -> mm d^-1 conversion, so
    ``rad`` would be in MJ m^-2 d^-1; confirm at call sites. ``rad_mult`` scales the
    radiation term.
    """
    temperature_term = 0.0023 * (tmean + 17.8)
    range_term = (tmax - tmin) ** 0.5
    return temperature_term * range_term * 0.408 * (rad * rad_mult)


In [None]:
def format_data_for_nn(ds, use_mask=True):
    """Build the scaled (input, output) float32 arrays the NN expects from one site.

    Parameters
    ----------
    ds : xarray.Dataset
        Site dataset with 'airtemp', 'SWRadAtm', 'gap_filled', 'airpres', 'spechum',
        'scalarLAI', 'vegTypeIndex', 'surf_sm', 'Qle_cor' and 'Qh_cor'; optionally
        'heightCanopyTop'.
    use_mask : bool
        If True, keep only rows where ``gap_filled == 0``.

    Returns
    -------
    (train_input, train_output)
        train_input has columns [airtemp, relhum, swradatm, surf_sm, lai*canwidth,
        vegtype]; train_output has columns [Qle_cor/500, Qh_cor/500].
    """
    # Affine rescale of temperature: 273.15 K maps to 0.5.
    airtemp = (((ds['airtemp'].values / 27.315) - 10) / 2) + 0.5
    swradatm = ds['SWRadAtm'].values / 1000
    mask = ds['gap_filled'].values
    # Pass plain ndarrays (including spechum) so relhumid is an ndarray, not a DataArray.
    relhumid = calc_relhum(ds['airtemp'].values, ds['airpres'].values,
                           ds['spechum'].values) / 100
    relhumid[relhumid < 0] = 0
    lai = ds['scalarLAI'].values / 12
    try:
        canwidth = ds['heightCanopyTop'].values / 20
    except KeyError:
        # Dataset has no canopy-height variable: fall back to an inverse-LAI proxy.
        canwidth = 1 / (lai + 0.001)
    # Site-constant vegetation index broadcast to one value per time step.
    vegtype = ds['vegTypeIndex'].values[()] * np.ones_like(mask) / 12
    surf_sm = ds['surf_sm'].values

    train_input = np.vstack([airtemp,          # 0
                             relhumid,         # 1
                             swradatm,         # 2
                             surf_sm,          # 3
                             lai * canwidth,   # 4
                             vegtype,          # 5
                             ]).T
    # Fluxes are scaled by 1/500; downstream code multiplies model output by 500.
    train_output = np.vstack([ds['Qle_cor'].values / 500,
                              ds['Qh_cor'].values / 500]).T

    if use_mask:
        keep = mask == 0
        train_input = train_input[keep]
        train_output = train_output[keep]
    return train_input.astype(np.float32), train_output.astype(np.float32)

In [None]:
def lrp_to_ds(r, timedim, feature_strings=feature_strings):
    """Wrap an LRP relevance matrix as an xarray Dataset, one variable per feature.

    Parameters
    ----------
    r : 2-D array
        Relevances indexed ``r[:, i]`` for the i-th feature; rows follow ``timedim``.
    timedim : array-like
        Values of the ``time`` coordinate, one per row of ``r``.
    feature_strings : list of str
        Variable names in column order (defaults to the module-level feature list).
    """
    # The previous in_sum/out_sum reductions were never used and have been dropped.
    return xr.merge([
        xr.DataArray(r[:, i], coords={'time': timedim}, dims=['time'], name=feature)
        for i, feature in enumerate(feature_strings)
    ])

def analyze_model(new_input, time_dim, analyzer, feature_strings=None):
    """Run LRP for both network outputs and wrap the relevances as Datasets.

    Parameters
    ----------
    new_input : array
        NN input rows, e.g. as produced by ``format_data_for_nn``.
    time_dim : array-like
        ``time`` coordinate values, one per input row.
    analyzer : innvestigate analyzer
        Must accept ``neuron_selection``. Output 0 is treated as Q_le and output 1 as Q_h.
    feature_strings : list of str, optional
        Feature names in input-column order; defaults to the six model inputs.

    Returns
    -------
    (r_qle_ds, r_qh_ds) : tuple of xarray.Dataset
    """
    if feature_strings is None:
        feature_strings = ['airtemp', 'relhum', 'swradatm', 'surf_sm', 'lai', 'vegtype']
    r_qle = analyzer.analyze(new_input, neuron_selection=0)
    r_qh = analyzer.analyze(new_input, neuron_selection=1)
    return (lrp_to_ds(r_qle, time_dim, feature_strings),
            lrp_to_ds(r_qh, time_dim, feature_strings))
        

In [None]:
def xargmax(ds, var='SWRadAtm', dim=None):
    """Return the slice of ``ds`` along ``dim`` at which ``ds[var]`` is largest.

    Meant for ``groupby(...).map`` to pick each group's peak time step. ``dim`` must
    be supplied; the ``None`` default cannot be used as an ``isel`` keyword.
    """
    peak_position = ds[var].argmax(dim)
    indexers = {dim: peak_position}
    return ds.isel(**indexers)


In [None]:
%%time
site = 'US-Whs'
site = 'CH-Fru'
vars_of_interest = ['airtemp', 'relhum', 'SWRadAtm', 'LWRadAtm', 'surf_sm', 'scalarLAI', 'pptrate',#'heightCanopyTop', 
                    'scalarLatHeatTotal', 'scalarSenHeatTotal', 'Qle_cor', 'Qh_cor', 'time', 'day', 'theta_sat', 'fieldCapacity', 'critSoilWilting']
ds = site_ds[site][vars_of_interest]
max_ds = ds.groupby(ds['day']).apply(xargmax, dim='time')
max_ds['day'] = max_ds['time']
max_ds = max_ds.drop('time').rename({'day': 'time'})
max_ds['surf_sm'] = max_ds['surf_sm'] * (max_ds['theta_sat'] - max_ds['fieldCapacity']) + max_ds['fieldCapacity']
max_r_qle = site_r_qle[site].sel(time=max_ds['time'])
max_r_qh = site_r_qh[site].sel(time=max_ds['time'])

In [None]:
horizontal = True
if horizontal:
    fig, axes = plt.subplots(2, 5, figsize=(18, 8), sharex=True)
else:
    fig, axes = plt.subplots(5, 2, figsize=(8, 15), sharex=True)
    axes = axes.T  # keep axes[row, col] indexing as (2, 5) in both layouts

# Day-of-year grouping keys for the daily-maximum snapshots.
doy_qle = max_r_qle['time'].dt.dayofyear
doy_qh = max_r_qh['time'].dt.dayofyear

# Columns 0-3: an NN input (top row) and its relevance to Q_le / Q_h (bottom row).
# Each entry: (input series, relevance variable, input y-label, relevance y-label).
input_panels = [
    (max_ds['SWRadAtm'], 'swradatm', r'Shortwave ($W/m^2$)', r'$R_{shortwave}$'),
    (max_ds['airtemp'] - 273.16, 'airtemp', r'Temperature ($^{\circ}C$)', r'$R_{temperature}$'),
    (max_ds['relhum'], 'relhum', r'Rel. Humidity (%)', r'$R_{humidity}$'),
    (max_ds['surf_sm'], 'surf_sm', 'Soil saturation (frac)', r'$R_{moisture}$'),
]
for col, (series, rvar, input_ylabel, r_ylabel) in enumerate(input_panels):
    # Only the first input trace is labelled; it feeds the top-left legend.
    input_kwargs = {'label': 'Input to NNLRP'} if col == 0 else {}
    series.groupby(doy_qle).mean().plot(ax=axes[0, col], color='dimgrey', **input_kwargs)
    # Same labels in every column (previously humidity/moisture were mislabelled R_T).
    max_r_qle[rvar].groupby(doy_qle).mean().plot(ax=axes[1, col], label='Relevance to latent heat')
    max_r_qh[rvar].groupby(doy_qh).mean().plot(ax=axes[1, col], label='Relevance to sensible heat')
    axes[0, col].set_ylabel(input_ylabel)
    axes[1, col].set_ylabel(r_ylabel)
    axes[1, col].axhline(0, color='black', linestyle='--')

# Column 4: observed vs simulated turbulent fluxes.
flux_panels = [
    (axes[0, 4], 'Qle_cor', 'scalarLatHeatTotal', r'Latent heat ($W/m^2$)'),
    (axes[1, 4], 'Qh_cor', 'scalarSenHeatTotal', r'Sensible heat ($W/m^2$)'),
]
for ax, obs_var, sim_var, ylabel in flux_panels:
    max_ds[obs_var].groupby(doy_qle).mean().plot(ax=ax, color='black', label='Observed')
    max_ds[sim_var].groupby(doy_qle).mean().plot(ax=ax, color='crimson', label='Simulated')
    ax.set_ylabel(ylabel)

# Clear the automatic x-axis labels on every panel.
for ax in axes.flat:
    ax.set_xlabel('')

# Month ticks (day-of-year of the first of every other month).
doy_idx = [32, 92, 152, 213, 274, 335]
doy_lab = ['Feb', 'Apr', 'Jun', 'Aug', 'Oct', 'Dec']
for ax in (axes[0, 4], axes[1, 4]):
    ax.set_xticks(doy_idx)
    ax.set_xticklabels(doy_lab, rotation=90)
axes[1, 4].set_xlim([1, 365])
if horizontal:
    for col in range(4):
        axes[1, col].set_xticklabels(doy_lab, rotation=90)
        axes[1, col].set_xlim([1, 365])

plt.tight_layout()
if horizontal:
    axes[1, 0].legend(bbox_to_anchor=(2, -0.3))
    axes[0, 0].legend(bbox_to_anchor=(1.3, 1.2))
else:
    axes[1, 0].legend(bbox_to_anchor=(.975, 1.3), ncol=2)
axes[1, 4].legend(bbox_to_anchor=(1, -0.3))

In [None]:
# Per site: correlation between the soil-moisture and humidity relevance series,
# separately for Q_le and Q_h, alongside the site's PET/P.
pvals, r_sm_rh_qle, r_sm_rh_qh = [], [], []
for site in tqdm(sim_sites):
    p = site_attrs.loc[site]['PET/P']
    qle_rel, qh_rel = site_r_qle[site], site_r_qh[site]
    r_sm_qle, r_rh_qle = qle_rel['surf_sm'], qle_rel['relhum']
    r_sm_qh, r_rh_qh = qh_rel['surf_sm'], qh_rel['relhum']

    r_sm_rh_qle.append(np.corrcoef(r_sm_qle, r_rh_qle)[0, 1])
    r_sm_rh_qh.append(np.corrcoef(r_sm_qh, r_rh_qh)[0, 1])
    pvals.append(p)
    

In [None]:
# Per site: correlation between a variable's relevance to Q_le and its relevance to Q_h.
pvals = []
sm_rvals, sw_rvals, t_rvals, rh_rvals = [], [], [], []
# Relevance variable -> list that collects its per-site correlation.
corr_targets = {
    'surf_sm': sm_rvals,
    'swradatm': sw_rvals,
    'airtemp': t_rvals,
    'relhum': rh_rvals,
}
for site in tqdm(sim_sites):
    p = site_attrs.loc[site]['PET/P']
    for var, out in corr_targets.items():
        r_qle = site_r_qle[site][var]
        r_qh = site_r_qh[site][var]
        out.append(np.corrcoef(r_qle, r_qh)[0, 1])
    pvals.append(p)
    

In [None]:
# Per site: ratio of mean daily relevance to Q_le over mean daily relevance to Q_h.
pvals, t_rrat, sw_rrat = [], [], []
ratio_targets = (('airtemp', t_rrat), ('swradatm', sw_rrat))
for site in sim_sites:
    p = site_attrs.loc[site]['PET/P']
    for var, out in ratio_targets:
        r_qle = site_r_qle[site][var].resample({'time': 'D'}).mean()
        r_qh = site_r_qh[site][var].resample({'time': 'D'}).mean()
        out.append(np.mean(r_qle) / np.mean(r_qh))
    pvals.append(p)
    

In [None]:
# Correlations from the preceding cells plotted against PET/P (same sim_sites order).
fig, axes = plt.subplots(2, 1, dpi=150, sharex=True, figsize=(8, 12))

top, bottom = axes
top.plot(pvals, sm_rvals, label='Surface moisture', linewidth=0, marker='o', color='tab:blue')
top.set_ylabel(r'Correlation($R_{SM\rightarrow Q_{le}}, R_{SM\rightarrow Q_{h}}$)')

bottom.scatter(pvals, r_sm_rh_qle)
bottom.set_xlabel('PET/P')
bottom.set_ylabel(r'Correlation($R_{SM\rightarrow Q_{le}}, R_{RH\rightarrow Q_{le}}$)')

# PET/P = 1 separates energy-limited from water-limited sites.
for ax in axes:
    ax.axvline(1, linestyle='--', color='grey', zorder=-10)
plt.tight_layout()

In [None]:
# Per-site mean relevance of each input, optionally as signed fractions of the total.
mean_r_qle_df = {}
mean_r_qh_df = {}
normalize = True
# Sites ordered by increasing PET/P; this dict order sets the bar-plot x order.
petp = site_attrs['PET/P'].sort_values().to_dict()
for site in petp:
    # (The previous per-site median fluxes avg_qle/avg_qh were never used; removed.)
    mean_qle = site_r_qle[site].to_dataframe().mean()
    mean_qh = site_r_qh[site].to_dataframe().mean()
    if normalize:
        # Divide by the sum of absolute means so |contributions| sum to 1.
        mean_qle = mean_qle / mean_qle.abs().sum()
        mean_qh = mean_qh / mean_qh.abs().sum()
    mean_r_qle_df[site] = mean_qle
    mean_r_qh_df[site] = mean_qh

In [None]:
# Position (in increasing-PET/P order) of the site whose PET/P is closest to 1.
petp_values = np.fromiter(petp.values(), dtype=float)
petp_idx = np.abs(petp_values - 1.0).argmin()

In [None]:
# Human-readable feature names for plot legends.
rename_dict = {'airtemp': 'Temperature',
               'relhum': 'Humidity',
               'swradatm': 'Shortwave',
               'surf_sm': 'Soil Moisture',
               'lai': 'LAI',
               'vegtype': 'Vegetation Type'
               }
# Rename BOTH dictionaries so the Q_le and Q_h bar charts use identical labels
# (previously only the Q_le one was renamed). Renaming is idempotent on re-run.
mean_r_qle_df = {s: d.rename(rename_dict) for s, d in mean_r_qle_df.items()}
mean_r_qh_df = {s: d.rename(rename_dict) for s, d in mean_r_qh_df.items()}

In [None]:
# Stacked bars of relevance fractions per site, sites ordered by increasing PET/P.
fig, axes = plt.subplots(2, 1, figsize=(18, 10), sharex=True, sharey=True)

bar_panels = [
    (mean_r_qle_df, r'Fraction of relevance to $Q_{le}$'),
    (mean_r_qh_df, r'Fraction of relevance to $Q_{h}$'),
]
for ax, (fractions, ylabel) in zip(axes, bar_panels):
    pd.DataFrame(fractions).T.plot.bar(ax=ax, stacked=True, legend=False, width=0.7)
    ax.axhline(0, color='black', zorder=-10)
    ax.set_ylabel(ylabel)
    # Dashed line just after the site whose PET/P is closest to 1.
    ax.axvline(petp_idx + 0.505, linestyle='--', color='dimgrey')

axes[0].legend(ncol=7)
axes[1].set_xlabel(r'Increasing PET/P $\longrightarrow$')
plt.tight_layout()
plt.tight_layout()

In [None]:
# Epsilon-LRP on the trained network; 'index' mode selects one output neuron per call.
analyzer = innvestigate.analyzer.LRPEpsilon(model, epsilon=1e-3, neuron_selection_mode='index')
# Output neuron 0 -> latent heat, 1 -> sensible heat (same order as train_output).
qle_analyzer, qh_analyzer = (partial(analyzer.analyze, neuron_selection=neuron)
                             for neuron in (0, 1))

In [None]:
def sensitivity_to_sm(new_input, surf_mults, qle_lrp, qh_lrp):
    """Sweep the surface-moisture input and record NN fluxes and their LRP relevance.

    Parameters
    ----------
    new_input : array
        NN input rows (feature column 3 is surf_sm). Only the FIRST row's prediction
        and relevance are recorded for each sweep value.
    surf_mults : sequence of float
        Values written into column 3 of every row.
    qle_lrp, qh_lrp : callable
        Relevance functions for the latent / sensible heat outputs. Each is called with
        the modified input and indexed ``[0, 3]`` (first row, surf_sm column).

    Returns
    -------
    xarray.Dataset
        Variables 'Qle', 'Qh' (model output * 500, undoing the /500 output scaling),
        'r_Qle_sm' and 'r_Qh_sm', all on a 'surface_moisture' coordinate.
    """
    qle_vals, qh_vals, r_qle, r_qh = [], [], [], []
    for sm in surf_mults:
        modified_input = new_input.copy()
        modified_input[:, 3] = sm  # column 3 = surface moisture
        # One forward pass per sweep value (the original called predict twice).
        prediction = model.predict(modified_input)
        qle_vals.append(500 * prediction[0, 0])
        qh_vals.append(500 * prediction[0, 1])
        r_qle.append(qle_lrp(modified_input)[0, 3])
        r_qh.append(qh_lrp(modified_input)[0, 3])

    dims = ['surface_moisture']
    return xr.Dataset(
        {
            'Qle': (dims, np.asarray(qle_vals).flatten()),
            'Qh': (dims, np.asarray(qh_vals).flatten()),
            'r_Qle_sm': (dims, np.asarray(r_qle).flatten()),
            'r_Qh_sm': (dims, np.asarray(r_qh).flatten()),
        },
        coords={'surface_moisture': surf_mults},
    )

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(10, 10), sharey='row')
# Column-major flatten: axes[0]/axes[1] = flux/relevance for the first site,
# axes[2]/axes[3] = flux/relevance for the second site.
axes = axes.T.flatten()

var = 'airtemp'
sm_vals = np.arange(0, 1, 0.01)
for col, site in enumerate(['CH-Fru', 'US-Whs']):
    ax_flux, ax_rel = axes[2 * col], axes[2 * col + 1]
    # Time step with the largest `var` relevance to Q_le. nanargmax ignores gaps,
    # whereas np.argmin(|x - max|) returned the first NaN when any were present.
    time_idx = int(np.nanargmax(site_r_qle[site][var].values))
    # Same one-row slice for both sites. CH-Fru previously used time_idx-1..time_idx,
    # so row 0 (the row sensitivity_to_sm records) was the wrong time step.
    new_input, new_output = format_data_for_nn(
        site_ds[site].isel(time=slice(time_idx, time_idx + 1)), use_mask=False)
    # sm_ds already carries the 'surface_moisture' coordinate.
    sm_ds = sensitivity_to_sm(new_input, sm_vals, qle_analyzer, qh_analyzer)

    ax_flux.plot(sm_vals, sm_ds['Qle'], label='Latent heat')
    ax_flux.plot(sm_vals, sm_ds['Qh'], label='Sensible heat')
    ax_rel.plot(sm_vals, sm_ds['r_Qle_sm'], label='Latent heat')
    ax_rel.plot(sm_vals, sm_ds['r_Qh_sm'], label='Sensible heat')
    ax_rel.axhline(0, color='grey', linestyle='--')
    ax_flux.set_title(site)
    ax_rel.set_xlabel('Degree of saturation')

axes[0].legend()
axes[0].set_ylabel(r'Heat flux $(W/m^2)$')
axes[1].set_ylabel('Relevance')
plt.tight_layout()

In [None]:
def _append_linear_fit_scores(X, y, r2_scores, kge_scores):
    """Fit OLS of y on X; append in-sample R^2 and KGE(y_hat, y) to the given lists."""
    regr = linear_model.LinearRegression().fit(X, y)
    y_hat = regr.predict(X)
    r2_scores.append(regr.score(X, y))
    kge_scores.append(pse.kling_gupta_efficiency(y_hat, y))


# How well is each simulated flux explained by a linear model of
# (a) its LRP relevances and (b) the scaled NN inputs?
linearizability_rel_qle = []
linearizability_rel_qle_kge = []
linearizability_rel_qh = []
linearizability_rel_qh_kge = []
linearizability_phys_qle = []
linearizability_phys_qle_kge = []
linearizability_phys_qh = []
linearizability_phys_qh_kge = []

for site in tqdm(sim_sites):
    phys_ds = site_ds[site]
    # (time, 1) matrices of the simulated fluxes.
    qle = phys_ds[['scalarLatHeatTotal']].to_array().values.T
    qh = phys_ds[['scalarSenHeatTotal']].to_array().values.T

    # (a) Relevance-based models: each flux regressed on its own LRP relevances.
    _append_linear_fit_scores(site_r_qle[site].to_array().values.T, qle,
                              linearizability_rel_qle, linearizability_rel_qle_kge)
    _append_linear_fit_scores(site_r_qh[site].to_array().values.T, qh,
                              linearizability_rel_qh, linearizability_rel_qh_kge)

    # (b) Input-based models: both fluxes regressed on the scaled NN inputs.
    X_phys = format_data_for_nn(phys_ds, use_mask=False)[0]
    _append_linear_fit_scores(X_phys, qle, linearizability_phys_qle, linearizability_phys_qle_kge)
    _append_linear_fit_scores(X_phys, qh, linearizability_phys_qh, linearizability_phys_qh_kge)

In [None]:
# One violin per column: KGE of the relevance-based linear model for Q_le (column 0)
# and Q_h (column 1), one value per site.
data = np.vstack([linearizability_rel_qle_kge, linearizability_rel_qh_kge]).T
vplot = sns.violinplot(data=data, color='silver')

# Restyle the violin artists by their position in the axes' child list.
# NOTE(review): these indices depend on the exact artist order seaborn/matplotlib
# create (presumably violin bodies at 0 and 2, inner box/whisker artists at 1 and 3-7).
# This is fragile across seaborn versions; re-check the figure after upgrading.
plt.gca().get_children()[1].set_color('white')
plt.gca().get_children()[3].set_color('white')
plt.gca().get_children()[4].set_color('black')
plt.gca().get_children()[5].set_color('black')
plt.gca().get_children()[6].set_color('black')
plt.gca().get_children()[7].set_color('black')
plt.gca().get_children()[0].set_edgecolor(None)
plt.gca().get_children()[2].set_edgecolor(None)
plt.xticks(ticks=[0,1], labels=['Latent heat', 'Sensible heat'])
plt.ylabel(r'KGE($Q_{NNLRP}, Q_{LM}$)')

In [None]:
# KGE against observations of (a) the relevance-based linear surrogate and
# (b) the full simulated flux, per site.
linear_qle_kge = []
linear_qh_kge = []
full_qle_kge = []
full_qh_kge = []

for site in tqdm(sim_sites):
    phys_ds = site_ds[site]
    qle = phys_ds[['scalarLatHeatTotal']].to_array().values.T
    qh = phys_ds[['scalarSenHeatTotal']].to_array().values.T
    qle_obs = phys_ds[['Qle_cor']].to_array().values.T
    qh_obs = phys_ds[['Qh_cor']].to_array().values.T

    # Linear surrogate: regress each simulated flux on its own LRP relevances.
    X_qle = site_r_qle[site].to_array().values.T
    qle_hat = linear_model.LinearRegression().fit(X_qle, qle).predict(X_qle)
    X_qh = site_r_qh[site].to_array().values.T
    qh_hat = linear_model.LinearRegression().fit(X_qh, qh).predict(X_qh)

    # The original redundant refit that appended to linearizability_phys_qh(_kge)
    # here is removed: it corrupted those lists from the previous cell.
    linear_qle_kge.append(pse.kling_gupta_efficiency(qle_hat, qle_obs))
    full_qle_kge.append(pse.kling_gupta_efficiency(qle, qle_obs))
    linear_qh_kge.append(pse.kling_gupta_efficiency(qh_hat, qh_obs))
    full_qh_kge.append(pse.kling_gupta_efficiency(qh, qh_obs))

In [None]:
# Shortest site record (number of time steps) and the position of that site in site_ds.
site_lengths = [len(site_data['time']) for site_data in site_ds.values()]
minsize = np.min(site_lengths)
# NOTE(review): this name shadows the `argmin` function that %pylab puts in the namespace.
argmin = np.argmin(site_lengths)

In [None]:
# Cross-site transferability: fit a linear model of each flux on LRP relevances at
# site_a and score it at every site_b. Row i = source (site_a), column j = target.
n_sites = len(sim_sites)
similarity_qle = np.full((n_sites, n_sites), np.nan)
similarity_qh = np.full((n_sites, n_sites), np.nan)
similarity_mean = np.full((n_sites, n_sites), np.nan)
similarity_qle_kge = np.full((n_sites, n_sites), np.nan)
similarity_qh_kge = np.full((n_sites, n_sites), np.nan)

coef_qle = []
coef_qh = []

regr_vars = ['airtemp', 'relhum', 'swradatm', 'surf_sm', 'lai', 'vegtype']

# Seeded generator so the per-site subsample (and the matrices) is reproducible.
SUBSAMPLE_SEED = 42
rng = np.random.default_rng(SUBSAMPLE_SEED)


def _regression_arrays(site):
    """Return (X_qle, X_qh, qle, qh) for one site.

    X_* are the LRP relevances of ``regr_vars`` as (time, feature) matrices;
    qle / qh are the simulated fluxes as (time, 1) matrices.
    """
    X_qle = site_r_qle[site][regr_vars].to_array().values.T
    X_qh = site_r_qh[site][regr_vars].to_array().values.T
    qle = site_ds[site][['scalarLatHeatTotal']].to_array().values.T
    qh = site_ds[site][['scalarSenHeatTotal']].to_array().values.T
    return X_qle, X_qh, qle, qh


# Extract each site's arrays once, instead of once per (site_a, site_b) pair.
regression_data = {s: _regression_arrays(s) for s in sim_sites}

for i, site_a in enumerate(tqdm(sim_sites)):
    X_qle, X_qh, qle, qh = regression_data[site_a]

    # Fit on an equal-size subsample (minsize = shortest record) drawn WITHOUT
    # replacement, so every site contributes the same number of distinct time steps.
    n_train = min(minsize, X_qle.shape[0])
    sel_idxs = rng.choice(X_qle.shape[0], size=n_train, replace=False)

    regr_qle = linear_model.LinearRegression().fit(X_qle[sel_idxs, :], qle[sel_idxs, :])
    coef_qle.append(np.hstack([regr_qle.intercept_, regr_qle.coef_.flatten()]))

    regr_qh = linear_model.LinearRegression().fit(X_qh[sel_idxs, :], qh[sel_idxs, :])
    coef_qh.append(np.hstack([regr_qh.intercept_, regr_qh.coef_.flatten()]))

    for j, site_b in enumerate(sim_sites):
        Xb_qle, Xb_qh, qle_b, qh_b = regression_data[site_b]
        qle_hat = regr_qle.predict(Xb_qle)
        qh_hat = regr_qh.predict(Xb_qh)
        score_qle = regr_qle.score(Xb_qle, qle_b)
        score_qh = regr_qh.score(Xb_qh, qh_b)

        similarity_qle_kge[i, j] = pse.kling_gupta_efficiency(qle_hat, qle_b)
        similarity_qh_kge[i, j] = pse.kling_gupta_efficiency(qh_hat, qh_b)
        # R^2 clipped at 0; np.maximum keeps NaN propagation like np.max([0, x]).
        similarity_qle[i, j] = np.maximum(0.0, score_qle)
        similarity_qh[i, j] = np.maximum(0.0, score_qh)
        similarity_mean[i, j] = np.maximum(0.0, 0.5 * (score_qh + score_qle))

In [None]:
# Source x target KGE matrices (not the R^2 ones) gathered into one Dataset.
_pair_coords = {'source': sim_sites, 'target': sim_sites}
similarity_ds = xr.Dataset({
    name: xr.DataArray(data=matrix, dims=['source', 'target'], coords=_pair_coords)
    for name, matrix in (('Qle', similarity_qle_kge), ('Qh', similarity_qh_kge))
})
similarity_ds['mean'] = 0.5 * (similarity_ds['Qle'] + similarity_ds['Qh'])

In [None]:
# Per-site regression coefficients: intercept followed by one slope per relevance variable.
regrs = ['intercept'] + regr_vars
_coef_coords = {'site': sim_sites, 'coef': regrs}
coef_ds = xr.Dataset({
    key: xr.DataArray(data=coefs, dims=['site', 'coef'], coords=_coef_coords)
    for key, coefs in (('qle', coef_qle), ('qh', coef_qh))
})

In [None]:
from matplotlib.colors import LinearSegmentedColormap
# Palette for vegetation types, one colour per type.
colors = ['orange', 'goldenrod', 'forestgreen', 'yellowgreen', 'lightgreen', 'khaki',
          'mediumseagreen', 'chocolate', 'dodgerblue', 'firebrick', 'red']
# Zero-based vegetation codes present in the table ('Veg Code' appears to be 1-based).
veg_idxs = site_attrs['Veg Code'].unique().astype(int) - 1
# Discrete colormap with exactly one entry per palette colour.
cmp = LinearSegmentedColormap.from_list('veg', colors, N=len(colors))

In [None]:
# Alphabetically sorted vegetation types paired with the palette colours.
veg_types = sorted(site_attrs['Veg Type'].unique())
# zip() would silently drop types beyond the palette length, surfacing later as a
# KeyError in color_mapping lookups; fail here with a clear message instead.
if len(veg_types) > len(colors):
    raise ValueError(f'{len(veg_types)} vegetation types but only {len(colors)} colors defined')
color_mapping = dict(zip(veg_types, colors))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 12))

# Bounded KGE_m = KGE / (2 - KGE) of the mean (Q_le, Q_h) cross-site KGE.
bounded_mean = similarity_ds['mean'] / (2 - similarity_ds['mean'])
for s in similarity_ds.source.values:
    veg_color = color_mapping[site_attrs.loc[s]['Veg Type']]
    # x: site s's model applied to every site; y: every site's model applied to s.
    ax.scatter(bounded_mean.sel(source=s), bounded_mean.sel(target=s),
               color=veg_color, s=250, marker='.', alpha=1)

# Reference lines at KGE_m = -0.17 (numerically the bounded value of KGE = 1 - sqrt(2);
# presumably the mean-flow benchmark).
for draw_ref in (ax.axhline, ax.axvline):
    draw_ref(-.17, color='dimgrey', linestyle='--')

ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
legend_elements = [Line2D([0], [0], marker='o', color='w', label=v, markerfacecolor=c, markersize=12, )
                   for v, c in color_mapping.items()]
ax.legend(handles=legend_elements, fontsize=14, loc='lower left', frameon=False)
ax.set_xlabel('Site performance as predictor ($KGE_m$)')
ax.set_ylabel('Site performance as predictand ($KGE_m$)')

In [None]:
from matplotlib.lines import Line2D

fig, axes = plt.subplots(1, 3, figsize=(21, 7), dpi=250, sharey=True)
axes = axes.flatten()

# Bounded KGE_m = KGE / (2 - KGE) of the mean cross-site KGE (maps (-inf, 1] to (-1, 1]).
bounded_kge = similarity_ds['mean'] / (2 - similarity_ds['mean'])

# Sites forming the "green (evergreen)" cluster; presumably evergreen needleleaf forests.
enf = ['RU-Fyo', 'CA-Qfo', 'AU-ASM', 'US-Prr', 'IT-Ren', 'IT-Lav', 'US-GLE', 'IT-SRo', 'DE-Tha',
       'FI-Let', 'CA-TP3', 'US-NR1', 'FI-Hyy', 'US-Blo', 'DE-Obe', 'FR-LBr', 'CA-TP1', 'FI-Sod',
       'AU-Wac']
_enf_set = set(enf)
# All remaining sites, in sim_sites order (deterministic, unlike a set difference).
nonenf = [s for s in sim_sites if s not in _enf_set]


def _plot_median_iqr(ax, sites, peers=None):
    """Plot each site's median bounded KGE as predictor (x) and predictand (y).

    Whiskers span the interquartile range. If ``peers`` is given, only pairs with
    those sites are used (as targets for x, as sources for y).
    """
    for s in sites:
        c = color_mapping[site_attrs.loc[s]['Veg Type']]
        as_source = bounded_kge.sel(source=s)
        as_target = bounded_kge.sel(target=s)
        if peers is not None:
            as_source = as_source.sel(target=peers)
            as_target = as_target.sel(source=peers)
        xmed = as_source.median()
        ymed = as_target.median()
        ax.plot([as_source.quantile(q=0.25), as_source.quantile(q=0.75)], [ymed, ymed],
                color=c, linewidth=3, alpha=0.75, zorder=-1)
        ax.plot([xmed, xmed], [as_target.quantile(q=0.25), as_target.quantile(q=0.75)],
                color=c, linewidth=3, alpha=0.75, zorder=-1)
        ax.scatter(xmed, ymed, color=c, s=400, marker='.', edgecolor='black', alpha=1, linewidth=1)


_plot_median_iqr(axes[0], similarity_ds.source.values)
_plot_median_iqr(axes[1], nonenf, peers=nonenf)
_plot_median_iqr(axes[2], enf, peers=enf)

# -0.17 is numerically the bounded value of KGE = 1 - sqrt(2) (presumably the
# mean-flow benchmark).
for ax in axes:
    ax.axhline(-.17, color='dimgrey', linestyle='--')
    ax.axvline(-.17, color='dimgrey', linestyle='--')
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_xlabel('Site performance as predictor ($KGE_m$)')
axes[0].set_ylabel('Site performance as predictand ($KGE_m$)')

legend_elements = [Line2D([0], [0], marker='o', color='w', label=v, markerfacecolor=c, markersize=12, )
                   for v, c in color_mapping.items()]
axes[2].legend(handles=legend_elements, fontsize=14, loc='lower left', frameon=True, framealpha=1)

axes[0].set_title('All sites')
axes[1].set_title('Purple cluster (non-evergreen)')
axes[2].set_title('Green cluster (evergreen)')
plt.tight_layout()

In [None]:
import sklearn.cluster

# Agglomerative clustering of sites on their stacked (Q_h, Q_le) regression coefficients.
# distance_threshold=0 with n_clusters=None builds the full tree (every site ends up in
# its own cluster) and records merge distances for the dendrogram.
site_coefs = np.hstack([coef_qh, coef_qle])
cluster = sklearn.cluster.AgglomerativeClustering(distance_threshold=0, n_clusters=None)
cluster.fit(site_coefs)
# Both names refer to the same label array.
km_labels_qh = km_labels_qle = cluster.labels_

In [None]:
def plot_dendrogram(cluster_model, **kwargs):
    """Draw a SciPy dendrogram for a fitted sklearn AgglomerativeClustering model.

    Builds a SciPy linkage matrix [child_a, child_b, distance, n_leaves] from
    ``cluster_model.children_`` and ``cluster_model.distances_``. Per the sklearn docs,
    ``distances_`` only exists when the model was fitted with ``distance_threshold`` set
    or ``compute_distances=True``.

    Parameters
    ----------
    cluster_model : fitted AgglomerativeClustering
    **kwargs : forwarded to ``scipy.cluster.hierarchy.dendrogram``
        (e.g. ``no_plot=True``, ``labels=...``).

    Returns
    -------
    (linkage_matrix, dendrogram_dict)
    """
    # Local import: `dendrogram` is not imported anywhere else in these cells, so the
    # original raised NameError.
    from scipy.cluster.hierarchy import dendrogram

    n_samples = len(cluster_model.labels_)
    counts = np.zeros(cluster_model.children_.shape[0])
    for i, merge in enumerate(cluster_model.children_):
        # Leaves (index < n_samples) count 1; an internal node's index refers to
        # merge (index - n_samples), whose leaf count was computed earlier.
        counts[i] = sum(1 if child < n_samples else counts[child - n_samples]
                        for child in merge)

    linkage_matrix = np.column_stack([cluster_model.children_, cluster_model.distances_,
                                      counts]).astype(float)

    d = dendrogram(linkage_matrix, **kwargs)
    return linkage_matrix, d