In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
import shutil
from glob import glob
import pysumma as ps
import pysumma.evaluation as pse
import pysumma.plotting as psp
from tqdm import tqdm_notebook as tqdm
import xarray as xr
import pandas as pd
from pathlib import Path
import seaborn as sns
import warnings
from sklearn import linear_model
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)

sns.set_context('talk')
# 'seaborn-bright' was deprecated in matplotlib 3.6 and removed in 3.8, where the
# same sheet is called 'seaborn-v0_8-bright'; fall back to it on newer versions.
try:
    mpl.style.use('seaborn-bright')
except OSError:
    mpl.style.use('seaborn-v0_8-bright')
mpl.rcParams['figure.figsize'] = (10, 8)
from matplotlib.colors import LinearSegmentedColormap

# Qualitative palette for the discrete 'veg' colormap (one level per colour).
colors = ['#e6194b', '#3cb44b', '#ffe119',
          '#4363d8', '#f58231', '#911eb4',
          '#46f0f0', '#f032e6', '#bcf60c',
          '#fabebe', '#008080', '#e6beff', ]
# NOTE(review): this rebinds `cm`, shadowing the matplotlib.cm module that
# %pylab places in the namespace.
cm = LinearSegmentedColormap.from_list('veg', colors, N=len(colors))

In [None]:
# Discover the site directories and shuffle them with a fixed seed.
# NOTE(review): os.listdir returns entries in a filesystem-dependent order, so the
# seeded shuffle is only reproducible on the same filesystem. Sorting first would
# fix that, but would also change fold membership relative to any other notebook
# that reproduces this split -- confirm before changing.
sites = os.listdir('../sites/')
seed = 50334
np.random.seed(seed)
np.random.shuffle(sites)

# Hold out `nfold` equally sized groups of test sites; the remainder that does
# not divide evenly is appended to the last group.
nfold = 5
equal_divisor_len = (len(sites) // nfold) * nfold
kfold_test_sites = list(np.array(sites[:equal_divisor_len]).reshape(nfold, -1))
kfold_test_sites[-1] = np.hstack([kfold_test_sites[-1], sites[equal_divisor_len:]])
# Each fold trains on every site outside its test group.
kfold_train_sites = [list(set(sites) - set(test_sites)) for test_sites in kfold_test_sites]

# NN output files for every (non-excluded) site that appears in some training fold.
bad_sites = []
all_fluxfiles = {
    s: f'../sites/{s}/output/flux_nn_output_{s}_timestep.nc'
    for fold in kfold_train_sites
    for s in fold
    if s not in bad_sites
}
all_statefiles = {
    s: f'../sites/{s}/output/state2_nn_output_{s}_timestep.nc'
    for fold in kfold_train_sites
    for s in fold
    if s not in bad_sites
}

sim_sites = [s for s in sites if s not in bad_sites]

In [None]:
def _load_rounded_timeseries(path, drop_vars=(), **isel):
    """Load the dataset at ``path`` into memory and normalise it for merging.

    ``drop_vars`` are removed (raising if any is absent). ``isel`` selects single
    indices and drops those dimensions. The ``time`` coordinate is rounded to the
    nearest 30 minutes (presumably so model and forcing time stamps line up).
    """
    with xr.open_dataset(path) as d:
        data = d.load()
    if drop_vars:
        data = data.drop_vars(list(drop_vars))
    data = data.isel(**isel, drop=True)
    data['time'] = data['time'].dt.round('30min')
    return data


def merge_site_data(site):
    """Merge simulations, forcing/observations, parameters and attributes for one site.

    The SUMMA-NN2W (state), SUMMA-NN1W (flux) and SUMMA-SA (template) outputs are
    stacked along a new ``type`` dimension. Everything is trimmed to the period
    shared with the forcing file, and time steps whose ``gap_filled`` flag is
    non-zero are masked to NaN.

    Parameters
    ----------
    site : str
        Site name; must be a key of ``all_statefiles`` and ``all_fluxfiles``.

    Returns
    -------
    xr.Dataset or None
        The merged dataset, or ``None`` (after printing the site name and the
        error) if any file is missing or any step fails.
    """
    try:
        state_data = _load_rounded_timeseries(
            all_statefiles[site], drop_vars=['scalarSoilWatBalError'], hru=0, gru=0)
        flux_data = _load_rounded_timeseries(all_fluxfiles[site], hru=0, gru=0)
        sim_data = _load_rounded_timeseries(
            f'../sites/{site}/output/template_output_{site}_timestep.nc', hru=0, gru=0)
        obs_data = _load_rounded_timeseries(f'../sites/{site}/forcings/{site}.nc', hru=0)
        with xr.open_dataset(f'../sites/{site}/params/parameter_trial.nc') as d:
            parm_data = d.load()
        with xr.open_dataset(f'../sites/{site}/params/local_attributes.nc') as d:
            attr_data = d.load()

        # Variables excluded from the comparison; skip any a given output lacks.
        vars_to_drop = ['scalarSoilWatBalError', 'scalarCanopyWat', 'mLayerVolFracIce',
                        'mLayerVolFracLiq', 'mLayerTranspire']
        state_data = state_data.drop_vars(vars_to_drop, errors='ignore')
        flux_data = flux_data.drop_vars(vars_to_drop, errors='ignore')
        sim_data = sim_data.drop_vars(vars_to_drop, errors='ignore')

        # Restrict everything to the period covered by all three simulations and
        # the observations.
        windows = [pse.trim_time(part, obs_data) for part in (state_data, flux_data, sim_data)]
        times = slice(max(w.start for w in windows), min(w.stop for w in windows))
        state_data = state_data.sel(time=times)
        flux_data = flux_data.sel(time=times)
        sim_data = sim_data.sel(time=times)
        obs_data = obs_data.sel(time=times)

        sims = (xr.concat([state_data, flux_data, sim_data], dim='type')
                  .assign_coords({'type': ['SUMMA-NN2W', 'SUMMA-NN1W', 'SUMMA-SA']}))
        ds = xr.merge([sims, obs_data, parm_data, attr_data])
        # Mask every time step whose gap_filled flag is non-zero.
        return ds.where(ds['gap_filled'] == 0, other=np.nan)
    except Exception as err:
        # Best effort: report the failing site and why; the caller drops None entries.
        # (Catching Exception rather than everything lets Ctrl-C interrupt the load loop.)
        print(site, err)
        return None

In [None]:
def compute_metric(simvar, obsvar, metric=pse.nash_sutcliffe_efficiency):
    """Score a simulated series against observations.

    Parameters
    ----------
    simvar, obsvar
        Simulated and observed values, passed to ``metric`` in that order.
    metric : callable, optional
        Called as ``metric(simvar, obsvar)``; defaults to
        ``pse.nash_sutcliffe_efficiency``.

    Returns
    -------
    Whatever ``metric`` returns.
    """
    score = metric(simvar, obsvar)
    return score

In [None]:
# Load and merge every simulated site; merge_site_data returns None on failure.
site_data = {site: merge_site_data(site) for site in tqdm(sim_sites)}

# Keep only the sites that loaded successfully, preserving sim_sites order.
site_data = {site: data for site, data in site_data.items() if data is not None}

In [None]:
# Sites excluded from the evaluation by hand (currently none).
bad_sites = []
complete_sites = [site for site in site_data if site not in bad_sites]

In [None]:
# Latent heat (Q_le): per-site NSE and KGE of each SUMMA configuration
# (SUMMA-SA -> standalone_*, SUMMA-NN1W -> flux_*, SUMMA-NN2W -> state_*)
# against the observed Qle_cor. Scores are computed for the full record
# (*_nse_qle / *_kge_qle) and separately for daytime (hours 8-19) and
# nighttime (hours 20-7) samples.
# scalarLatHeatTotal is negated so that its sign matches the observations.
standalone_day_nse_qle, flux_day_nse_qle, state_day_nse_qle = {}, {}, {}
standalone_day_kge_qle, flux_day_kge_qle, state_day_kge_qle = {}, {}, {}
standalone_night_nse_qle, flux_night_nse_qle, state_night_nse_qle = {}, {}, {}
standalone_night_kge_qle, flux_night_kge_qle, state_night_kge_qle = {}, {}, {}
standalone_nse_qle, flux_nse_qle, state_nse_qle = {}, {}, {}
standalone_kge_qle, flux_kge_qle, state_kge_qle = {}, {}, {}

for site in tqdm(complete_sites):
    hours = site_data[site]['time'].dt.hour.values
    daytime_filter = np.logical_and(hours >= 8, hours < 20)
    nighttime_filter = np.logical_or(hours < 8, hours >= 20)

    sim_qle = -site_data[site]['scalarLatHeatTotal']
    flux_all = sim_qle.sel(type='SUMMA-NN1W').values
    state_all = sim_qle.sel(type='SUMMA-NN2W').values
    standalone_all = sim_qle.sel(type='SUMMA-SA').values
    observed_qle_cor = site_data[site]['Qle_cor']

    flux_daytime, flux_nighttime = flux_all[daytime_filter], flux_all[nighttime_filter]
    state_daytime, state_nighttime = state_all[daytime_filter], state_all[nighttime_filter]
    standalone_daytime = standalone_all[daytime_filter]
    standalone_nighttime = standalone_all[nighttime_filter]
    observed_daytime = observed_qle_cor.values[daytime_filter]
    observed_nighttime = observed_qle_cor.values[nighttime_filter]

    standalone_day_nse_qle[site] = compute_metric(standalone_daytime, observed_daytime)
    flux_day_nse_qle[site] = compute_metric(flux_daytime, observed_daytime)
    state_day_nse_qle[site] = compute_metric(state_daytime, observed_daytime)

    standalone_day_kge_qle[site] = compute_metric(standalone_daytime, observed_daytime, metric=pse.kling_gupta_efficiency)
    flux_day_kge_qle[site] = compute_metric(flux_daytime, observed_daytime, metric=pse.kling_gupta_efficiency)
    state_day_kge_qle[site] = compute_metric(state_daytime, observed_daytime, metric=pse.kling_gupta_efficiency)

    standalone_night_nse_qle[site] = compute_metric(standalone_nighttime, observed_nighttime)
    flux_night_nse_qle[site] = compute_metric(flux_nighttime, observed_nighttime)
    state_night_nse_qle[site] = compute_metric(state_nighttime, observed_nighttime)

    standalone_night_kge_qle[site] = compute_metric(standalone_nighttime, observed_nighttime, metric=pse.kling_gupta_efficiency)
    flux_night_kge_qle[site] = compute_metric(flux_nighttime, observed_nighttime, metric=pse.kling_gupta_efficiency)
    state_night_kge_qle[site] = compute_metric(state_nighttime, observed_nighttime, metric=pse.kling_gupta_efficiency)

    # NOTE(review): the full-record scores receive the observations as a
    # DataArray, while the day/night scores receive plain arrays; confirm the
    # metrics treat both the same way.
    standalone_nse_qle[site] = compute_metric(standalone_all, observed_qle_cor)
    flux_nse_qle[site] = compute_metric(flux_all, observed_qle_cor)
    state_nse_qle[site] = compute_metric(state_all, observed_qle_cor)

    standalone_kge_qle[site] = compute_metric(standalone_all, observed_qle_cor, metric=pse.kling_gupta_efficiency)
    flux_kge_qle[site] = compute_metric(flux_all, observed_qle_cor, metric=pse.kling_gupta_efficiency)
    state_kge_qle[site] = compute_metric(state_all, observed_qle_cor, metric=pse.kling_gupta_efficiency)

# ECDF plotting positions: the k-th smallest score (k = 1..n) is drawn at
# nonexceedance probability k/n. This is the same expression the sensible-heat
# cell uses for y_vals, so the plots do not depend on which cell ran last.
y_vals = np.arange(0, len(complete_sites)) / len(complete_sites) + (1/ len(complete_sites))

In [None]:
# Sensible heat (Q_h): per-site NSE and KGE of each SUMMA configuration against
# the observed Qh_cor. Scores are computed for the full record and separately
# for daytime (hours 8-19) and nighttime (hours 20-7) samples.
# scalarSenHeatTotal is negated so that its sign matches the observations.
standalone_day_nse_qh, flux_day_nse_qh, state_day_nse_qh = {}, {}, {}
standalone_day_kge_qh, flux_day_kge_qh, state_day_kge_qh = {}, {}, {}
standalone_night_nse_qh, flux_night_nse_qh, state_night_nse_qh = {}, {}, {}
standalone_night_kge_qh, flux_night_kge_qh, state_night_kge_qh = {}, {}, {}
standalone_nse_qh, flux_nse_qh, state_nse_qh = {}, {}, {}
standalone_kge_qh, flux_kge_qh, state_kge_qh = {}, {}, {}

# (type label, day NSE, day KGE, night NSE, night KGE, full NSE, full KGE)
qh_targets = [
    ('SUMMA-NN1W', flux_day_nse_qh, flux_day_kge_qh, flux_night_nse_qh,
     flux_night_kge_qh, flux_nse_qh, flux_kge_qh),
    ('SUMMA-NN2W', state_day_nse_qh, state_day_kge_qh, state_night_nse_qh,
     state_night_kge_qh, state_nse_qh, state_kge_qh),
    ('SUMMA-SA', standalone_day_nse_qh, standalone_day_kge_qh, standalone_night_nse_qh,
     standalone_night_kge_qh, standalone_nse_qh, standalone_kge_qh),
]

for site in tqdm(complete_sites):
    hour = site_data[site]['time'].dt.hour.values
    daytime_filter = np.logical_and(hour >= 8, hour < 20)
    nighttime_filter = np.logical_or(hour < 8, hour >= 20)

    obs_da = site_data[site]['Qh_cor']
    obs_day = obs_da.values[daytime_filter]
    obs_night = obs_da.values[nighttime_filter]

    for label, day_nse, day_kge, night_nse, night_kge, all_nse, all_kge in qh_targets:
        sim = -site_data[site]['scalarSenHeatTotal'].sel(type=label).values
        sim_day, sim_night = sim[daytime_filter], sim[nighttime_filter]

        day_nse[site] = compute_metric(sim_day, obs_day)
        day_kge[site] = compute_metric(sim_day, obs_day, metric=pse.kling_gupta_efficiency)
        night_nse[site] = compute_metric(sim_night, obs_night)
        night_kge[site] = compute_metric(sim_night, obs_night, metric=pse.kling_gupta_efficiency)
        # Full-record scores get the observations as a DataArray.
        all_nse[site] = compute_metric(sim, obs_da)
        all_kge[site] = compute_metric(sim, obs_da, metric=pse.kling_gupta_efficiency)

# ECDF plotting positions: k-th smallest score (k = 1..n) at probability k/n.
y_vals = np.arange(0, len(complete_sites)) / len(complete_sites) + (1/ len(complete_sites))

In [None]:
# Empirical CDFs (nonexceedance probability vs. score) of per-site NSE (top row)
# and KGE (bottom row), for latent heat (left) and sensible heat (right).
fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharey='row', sharex='row')
axes = axes.flatten()

ecdf_panels = [
    # (metric label, title, SA scores, NN1W scores, NN2W scores)
    ('NSE', 'Latent heat', standalone_nse_qle, flux_nse_qle, state_nse_qle),
    ('NSE', 'Sensible heat', standalone_nse_qh, flux_nse_qh, state_nse_qh),
    ('KGE', None, standalone_kge_qle, flux_kge_qle, state_kge_qle),
    ('KGE', None, standalone_kge_qh, flux_kge_qh, state_kge_qh),
]
for ax, (metric_name, title, sa_scores, nn1w_scores, nn2w_scores) in zip(axes, ecdf_panels):
    for scores, label, color in ((sa_scores, 'SA', '#4363d8'),
                                 (nn1w_scores, 'NN1W', '#800000'),
                                 (nn2w_scores, 'NN2W', '#f58231')):
        ax.plot(sorted(scores.values()), y_vals, label=label, color=color, linewidth=2)
    ax.axhline(0, color='black')
    ax.set_ylim([0, 1])
    ax.set_xlim([-0.5, 1])
    ax.set_xlabel(metric_name)
    if title is not None:
        ax.set_title(title)
axes[0].legend()

plt.tight_layout()
fig.text(-.01, 0.27, 'Nonexceedance Probability', rotation=90)



In [None]:
# Empirical CDFs of per-site KGE for latent heat (left) and sensible heat (right),
# drawn with large fonts and thick lines.
fig, axes = plt.subplots(1, 2, figsize=(17, 9), sharey='row', sharex='row')
axes = axes.flatten()

kge_panels = [
    (standalone_kge_qle, flux_kge_qle, state_kge_qle),
    (standalone_kge_qh, flux_kge_qh, state_kge_qh),
]
for ax, (sa_scores, nn1w_scores, nn2w_scores) in zip(axes, kge_panels):
    ax.plot(sorted(sa_scores.values()), y_vals, label='SA', color='#4363d8', linewidth=3)
    ax.plot(sorted(nn1w_scores.values()), y_vals, label='NN1W', color='#800000', linewidth=3)
    ax.plot(sorted(nn2w_scores.values()), y_vals, label='NN2W', color='#f58231', linewidth=3)
    ax.axhline(0, color='black')
    ax.set_ylim([0, 1])
    ax.set_xlim([-0.5, 1])
    ax.set_xlabel('KGE', fontsize=24)
axes[0].legend(fontsize=24)

plt.tight_layout()
fig.text(-.01, 0.3, 'Nonexceedance Probability', rotation=90, fontsize=24)



In [None]:
# Per-site SUMMA-NN score (y) against SUMMA-SA score (x): NSE on the top row,
# KGE on the bottom row; latent heat on the left, sensible heat on the right.
# Points above the 1:1 line are sites where the NN configuration scores higher.
fig, axes = plt.subplots(2, 2, figsize=(10, 10), sharey='row', sharex='col')
axes = axes.flatten()
s = 10
plot_range = [-.41, 1]

scatter_panels = [
    # (SA scores, NN1W scores, NN2W scores, metric label, y-position of that label)
    (standalone_nse_qle, flux_nse_qle, state_nse_qle, 'NSE', -0.33),
    (standalone_nse_qh, flux_nse_qh, state_nse_qh, 'NSE', 0.),
    (standalone_kge_qle, flux_kge_qle, state_kge_qle, 'KGE', -0.33),
    (standalone_kge_qh, flux_kge_qh, state_kge_qh, 'KGE', -0.33),
]
for ax, (sa_scores, nn1w_scores, nn2w_scores, metric_name, text_y) in zip(axes, scatter_panels):
    sa_values = list(sa_scores.values())
    ax.plot(plot_range, plot_range, color='black')
    ax.scatter(sa_values, list(nn1w_scores.values()), label='NN1W', s=s, color='#800000', linewidth=2)
    ax.scatter(sa_values, list(nn2w_scores.values()), label='NN2W', s=s, color='#f58231', linewidth=2)
    ax.set_ylim(plot_range)
    ax.set_xlim(plot_range)
    ax.text(0.74, text_y, metric_name)
axes[0].set_title('Latent Heat')
axes[1].set_title('Sensible Heat')
axes[1].legend()

fig.text(0.4, 0.05, 'SUMMA-SA Performance')
fig.text(0.01, 0.33, 'SUMMA-NN Performance', rotation=90)

In [None]:
# Per-site change in latent-heat NSE relative to SUMMA-SA (NN minus SA), listed
# from the worst to the best SA score.
# Site names come from the score dict itself, so labels stay paired with scores
# even if complete_sites is redefined after the scores were computed.
nse_qle_sites = list(standalone_nse_qle.keys())
sa_nse_qle = list(standalone_nse_qle.values())
nn1w_nse_qle = list(flux_nse_qle.values())
nn2w_nse_qle = list(state_nse_qle.values())
sort_idx = np.argsort(sa_nse_qle)
for i, idx in enumerate(sort_idx):
    dnse_nn1w = nn1w_nse_qle[idx] - sa_nse_qle[idx]
    dnse_nn2w = nn2w_nse_qle[idx] - sa_nse_qle[idx]
    print(i, nse_qle_sites[idx], dnse_nn1w, dnse_nn2w)
 

In [None]:
# Change in score relative to SUMMA-SA (NN minus SA) for every site, ordered by
# SA score (x = SA percentile rank). The black curve is 1 - SA score: the largest
# gain possible, since NSE and KGE are bounded above by 1.
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(16, 8))
ax = ax.flatten()

# Shared lower y-limit. Differences below it would fall outside the (shared) axes,
# so they are drawn at the limit with a downward caret (marker 7) instead.
dscore_floor = -1


def _plot_score_change(axis, sa_scores, nn1w_scores, nn2w_scores, floor=dscore_floor):
    """Scatter NN1W/NN2W-minus-SA scores on ``axis`` against the SA percentile rank.

    The three score dicts must share the same key order. Ranks run from 0 (worst
    SA score) upward in steps of 100 / number of sites. Differences below ``floor``
    are drawn at ``floor`` with a down-caret marker.
    """
    sa = np.asarray(list(sa_scores.values()), dtype=float)
    order = np.argsort(sa)
    rank_pct = 100 * np.arange(len(order)) / len(order)
    axis.axhline(0, color='grey')
    axis.plot(rank_pct, 1.0 - sa[order], color='black', label='Maximum improvement')
    for scores, color, marker, label in ((nn1w_scores, '#800000', 'D', 'NN1W'),
                                         (nn2w_scores, '#f58231', 'X', 'NN2W')):
        diff = np.asarray(list(scores.values()), dtype=float)[order] - sa[order]
        off_scale = diff < floor
        axis.scatter(rank_pct[~off_scale], diff[~off_scale], color=color, marker=marker, label=label)
        axis.scatter(rank_pct[off_scale], np.full(off_scale.sum(), floor), color=color, marker=7)


_plot_score_change(ax[0], standalone_nse_qle, flux_nse_qle, state_nse_qle)
_plot_score_change(ax[1], standalone_nse_qh, flux_nse_qh, state_nse_qh)
_plot_score_change(ax[2], standalone_kge_qle, flux_kge_qle, state_kge_qle)
_plot_score_change(ax[3], standalone_kge_qh, flux_kge_qh, state_kge_qh)
ax[3].set_ylim([dscore_floor, 1])  # sharey=True applies this to every panel

ax[0].text(0,  -0.93, 'Latent Heat, NSE')
ax[1].text(60, -0.93, 'Sensible Heat, NSE')
ax[2].text(0,  -0.93, 'Latent Heat, KGE')
ax[3].text(60, -0.93, 'Sensible Heat, KGE')

handles, labels = ax[1].get_legend_handles_labels()
ax[1].legend(handles[::-1], labels[::-1], loc='lower left', frameon=False)
plt.tight_layout()
plt.gcf().text(x=0.35, y=-0.02, s='SA Performance Rank (0=worst, 100=best)')
plt.gcf().text(x=-0.0, y=0.22, s='Performance difference from SA', rotation=90)

In [None]:
# Per-site change in sensible-heat KGE relative to SUMMA-SA (NN minus SA), listed
# from the worst to the best SA score. Site names are read from the score dict
# itself so they stay paired with the scores.
kge_qh_sites = list(standalone_kge_qh.keys())
sa_kge_qh = list(standalone_kge_qh.values())
nn1w_kge_qh = list(flux_kge_qh.values())
nn2w_kge_qh = list(state_kge_qh.values())
sort_idx = np.argsort(sa_kge_qh)
for i, idx in enumerate(sort_idx):
    dkge_nn1w = nn1w_kge_qh[idx] - sa_kge_qh[idx]
    dkge_nn2w = nn2w_kge_qh[idx] - sa_kge_qh[idx]
    print(i, kge_qh_sites[idx], dkge_nn1w, dkge_nn2w)
 

In [None]:
# Mean diurnal cycle of latent heat at three example sites: each SUMMA
# configuration and the observed latent heat (Qle_cor) averaged by hour of day.
fig, axes = plt.subplots(1, 3, figsize=(16,6), sharex=True, sharey=False)
axes = axes.flatten()
plot_sites = ['AT-Neu', 'DK-Eng', 'CH-Cha']
# zip stops at the shorter sequence, so sites beyond the number of panels are ignored.
for ax_site, site in zip(axes, plot_sites):
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    # Compare against observed *latent* heat, not sensible heat (Qh_cor).
    observed_qle = site_data[site]['Qle_cor']

    for series, color, label in ((flux_qle, 'crimson', 'NN1W'),
                                 (state_qle, 'forestgreen', 'NN2W'),
                                 (standalone_qle, 'royalblue', 'SA'),
                                 (observed_qle, 'black', 'Observed')):
        series.groupby(series.time.dt.hour).mean(dim='time').plot(ax=ax_site, color=color, label=label)

    ax_site.set_ylabel('')
    ax_site.set_xlabel('')
    ax_site.set_title(site)

fig.text(0.5, -0.02, r'Hour of day', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)

In [None]:
# Full latent-heat time series at three example sites: observations (Qle_cor)
# plus each SUMMA configuration.
fig, axes = plt.subplots(1, 3, figsize=(16,6), sharex=True, sharey=False)
axes = axes.flatten()
plot_sites = ['AT-Neu', 'DK-Eng', 'CH-Cha']
# zip stops at the shorter sequence, so sites beyond the number of panels are ignored.
for ax_site, site in zip(axes, plot_sites):
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    # Compare against observed *latent* heat, not sensible heat (Qh_cor).
    observed_qle = site_data[site]['Qle_cor']

    observed_qle.plot(ax=ax_site, color='black', label='Observed')
    flux_qle.plot(ax=ax_site, color='crimson', label='NN1W')
    state_qle.plot(ax=ax_site, color='forestgreen', label='NN2W')
    standalone_qle.plot(ax=ax_site, color='royalblue', label='SA')

    ax_site.set_ylabel('')
    ax_site.set_xlabel('')
    ax_site.set_title(site)

fig.text(0.5, -0.02, r'Time', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)

In [None]:
# Full latent-heat time series at three further example sites: observations
# (Qle_cor) plus each SUMMA configuration. The x axis is time, not month.
fig, axes = plt.subplots(1, 3, figsize=(16,6), sharex=True, sharey=False)
axes = axes.flatten()
plot_sites = ['IT-Noe', 'AU-ASM', 'US-Var']
# zip stops at the shorter sequence, so sites beyond the number of panels are ignored.
for ax_site, site in zip(axes, plot_sites):
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qle_cor']

    observed_qle.plot(ax=ax_site, color='black', label='Observed')
    flux_qle.plot(ax=ax_site, color='crimson', label='NN1W')
    state_qle.plot(ax=ax_site, color='forestgreen', label='NN2W')
    standalone_qle.plot(ax=ax_site, color='royalblue', label='SA')

    ax_site.set_ylabel('')
    ax_site.set_xlabel('')
    ax_site.set_title(site)

fig.text(0.5, -0.02, r'Time', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)

In [None]:
# Latent heat averaged by day of year (1-366) at three example sites, for each
# SUMMA configuration and the observations (Qle_cor).
# Grouping by `time.dt.day` would average by day-of-month (1-31) across all
# months, which carries no seasonal signal.
# NOTE(review): day of year is assumed to be the intended seasonal grouping -- confirm.
fig, axes = plt.subplots(1, 3, figsize=(16,6), sharex=True, sharey=False)
axes = axes.flatten()
plot_sites = ['IT-Noe', 'AU-ASM', 'US-Var']

# zip stops at the shorter sequence, so sites beyond the number of panels are ignored.
for ax_site, site in zip(axes, plot_sites):
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qle_cor']

    for series, color, label in ((flux_qle, 'crimson', 'NN1W'),
                                 (state_qle, 'forestgreen', 'NN2W'),
                                 (standalone_qle, 'royalblue', 'SA'),
                                 (observed_qle, 'black', 'Observed')):
        series.groupby(series.time.dt.dayofyear).mean(dim='time').plot(ax=ax_site, color=color, label=label)

    ax_site.set_ylabel('')
    ax_site.set_xlabel('')
    ax_site.set_title(site)

fig.text(0.5, -0.02, r'Day of year', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)

In [None]:
# Positions (in score-dict order) of sites where an NN configuration's
# latent-heat NSE is more than `thresh` below SUMMA-SA's.
thresh = 0.25
sa_nse_qle_arr = np.array(list(standalone_nse_qle.values()))
state_decrease_idx = np.flatnonzero(sa_nse_qle_arr - np.array(list(state_nse_qle.values())) > thresh)
flux_decrease_idx = np.flatnonzero(sa_nse_qle_arr - np.array(list(flux_nse_qle.values())) > thresh)

In [None]:
# Positions of sites where NN2W's latent-heat NSE is more than `thresh` below SA's.
state_decrease_idx

In [None]:
# Positions of sites where NN1W's latent-heat NSE is more than `thresh` below SA's.
flux_decrease_idx

In [None]:
# Median SUMMA-SA latent-heat NSE over the sites where either NN configuration
# lost more than `thresh` NSE. Sites flagged for both configurations are counted
# once; they are deduplicated by site position rather than by score value, so
# two distinct sites with equal scores are not merged.
# NOTE(review): despite the variable name, this is a median, not a mean.
mean_sa_where_decrease = np.median(
    np.array(list(standalone_nse_qle.values()))[np.union1d(state_decrease_idx, flux_decrease_idx)]
)

In [None]:
# Median SA latent-heat NSE over the sites flagged in state_decrease_idx / flux_decrease_idx.
mean_sa_where_decrease

In [None]:
# Index, within the ascending-sorted SA latent-heat NSE scores, of the score closest to mean_sa_where_decrease.
np.argmin(np.abs(np.array(sorted(standalone_nse_qle.values())) - mean_sa_where_decrease))

In [None]:
# ECDF nonexceedance probability of the SA latent-heat NSE score closest to
# mean_sa_where_decrease. The index is computed rather than hard-coded, so the
# lookup stays valid if the set of sites changes.
y_vals[np.argmin(np.abs(np.array(sorted(standalone_nse_qle.values())) - mean_sa_where_decrease))]

In [None]:
def _median_gain(nn_scores, sa_scores, decimals=3):
    """Median of ``nn_scores`` minus median of ``sa_scores``, rounded to ``decimals``."""
    nn_median = np.median(list(nn_scores.values()))
    sa_median = np.median(list(sa_scores.values()))
    return np.around(nn_median - sa_median, decimals)


# Median change relative to SUMMA-SA, per metric (NSE/KGE), flux (qle/qh) and NN variant.
dnse_qle_flux = _median_gain(flux_nse_qle, standalone_nse_qle)
dnse_qh_flux = _median_gain(flux_nse_qh, standalone_nse_qh)

dnse_qle_state = _median_gain(state_nse_qle, standalone_nse_qle)
dnse_qh_state = _median_gain(state_nse_qh, standalone_nse_qh)

dkge_qle_flux = _median_gain(flux_kge_qle, standalone_kge_qle)
dkge_qh_flux = _median_gain(flux_kge_qh, standalone_kge_qh)

dkge_qle_state = _median_gain(state_kge_qle, standalone_kge_qle)
dkge_qh_state = _median_gain(state_kge_qh, standalone_kge_qh)

In [None]:
# Summary sentence built from the median score changes computed above.
(
    f"In the case of the NN1W we showed a median increase in NSE of {dnse_qle_flux} for latent heat "
    f"and {dnse_qh_flux} for sensible heat, while the NN2W showed a median increase in NSE of "
    f"{dnse_qle_state} for latent heat and {dnse_qh_state} for sensible heat. "
    f"Likewise, for KGE these were {dkge_qle_flux} (Q_le) and {dkge_qh_flux} (Q_h) for NN1W "
    f"and {dkge_qle_state} (Q_le) and {dkge_qh_state} (Q_h) for NN2W."
)

In [None]:
def L_vap(T):
    """Latent heat of vaporization of water in J/g.

    `T` is air temperature in kelvin (273.16 is subtracted to get degrees
    Celsius); the result is a cubic polynomial in that Celsius temperature.
    Works elementwise on scalars and arrays.
    """
    celsius = T - 273.16
    linear = 2.36 * celsius
    quadratic = 0.0016 * celsius**2
    cubic = 0.00006 * celsius**3
    return 2500.8 - linear + quadratic - cubic

def Qle_to_ET(Q, T):
    """Convert a latent heat flux `Q` into an evaporation mass flux.

    Divides `Q` by the latent heat of vaporization at temperature `T` (kelvin),
    rescaled from J/g to J/kg. Assuming `Q` is in W m-2, the result is in
    kg m-2 s-1 (i.e. mm of water per second) — TODO confirm units upstream.
    """
    grams_per_kg = 1e3
    joules_per_kg = grams_per_kg * L_vap(T)
    return (1 / joules_per_kg) * Q
    
    

In [None]:
# Names of the sites where NN2W latent-heat NSE degraded by more than `thresh`.
np.array(list(standalone_nse_qle.keys()))[state_decrease_idx]

In [None]:
# Ten sites with the lowest NN2W latent-heat NSE. Names come from the SA dict's key order,
# so this assumes standalone_nse_qle and state_nse_qle share that order — TODO confirm.
worst_state_sites = np.array(list(standalone_nse_qle.keys()))[np.argsort(list(state_nse_qle.values()))][0:10]

In [None]:
# Only a cell's last expression is displayed automatically, so show the list explicitly.
display(worst_state_sites)

# Weekly diagnostics for the worst NN2W sites, one row per site. Columns:
# observed (black) vs. simulated latent heat, total soil water, weekly precipitation
# total, total runoff. Simulated variables show the first three `type` entries.
check_sites = worst_state_sites
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works when only one site is plotted.
fig, axes = plt.subplots(len(check_sites), 4, figsize=(18, 4 * len(check_sites)), squeeze=False)
for i, s in enumerate(check_sites):
    # Only the bottom row draws a legend.
    add_legend = (i == len(check_sites) - 1)
    (site_data[s]['Qle_cor']            ).resample({'time': 'W'}).mean().plot.line(x='time', ax=axes[i, 0], color='black')
    (-site_data[s]['scalarLatHeatTotal']).isel(type=slice(0,3)).resample({'time': 'W'}).mean().plot.line(x='time', ax=axes[i, 0], add_legend=False, alpha=0.7)
    (site_data[s]['scalarTotalSoilWat'] ).isel(type=slice(0,3)).resample({'time': 'W'}).mean().plot.line(x='time', ax=axes[i, 1], add_legend=False)
    (site_data[s]['pptrate'] ).resample({'time': 'W'}).sum().plot.line(x='time', ax=axes[i, 2], add_legend=False)
    (site_data[s]['scalarTotalRunoff'] ).resample({'time': 'W'}).mean().plot.line(x='time', ax=axes[i, 3], add_legend=add_legend)
    axes[i, 0].set_title(s)
    axes[i, 1].set_title(s)

plt.tight_layout()

In [None]:
def merge_site_attr_parm(sites):
    """Merge each site's SUMMA parameter and local-attribute files into one Dataset.

    For every site, ``../sites/{site}/params/parameter_trial.nc`` and
    ``../sites/{site}/params/local_attributes.nc`` are loaded into memory and
    merged. The merged datasets are stacked along a new ``site`` dimension,
    and ``hru``/``gru`` are reduced to their first element and dropped.

    Sites whose files cannot be opened/merged, or whose parameter file lacks
    ``heightCanopyTop``, are skipped. Each skipped site is printed with the reason.

    Parameters
    ----------
    sites : iterable of str
        Site names (directory names under ``../sites/``).

    Returns
    -------
    xarray.Dataset
        One entry per successfully loaded site, with a ``site`` coordinate.

    Raises
    ------
    ValueError
        If none of the requested sites could be loaded.
    """
    ds_list = []
    s_list = []
    for site in sites:
        try:
            with xr.open_dataset(f'../sites/{site}/params/parameter_trial.nc') as d:
                parm_data = d.load()
            with xr.open_dataset(f'../sites/{site}/params/local_attributes.nc') as d:
                attr_data = d.load()
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if 'heightCanopyTop' not in parm_data:
                raise KeyError('heightCanopyTop missing from parameter_trial.nc')
            merged = xr.merge([parm_data, attr_data])
        except Exception as err:
            # Best-effort: skip bad sites but report why. `Exception` (not a bare
            # `except:`) lets KeyboardInterrupt/SystemExit propagate.
            print(f'{site}: {err!r}')
            continue
        ds_list.append(merged)
        s_list.append(site)
    if not ds_list:
        raise ValueError('merge_site_attr_parm: none of the requested sites could be loaded')
    return xr.concat(ds_list, dim='site').assign_coords({'site': s_list}).isel(hru=0, gru=0, drop=True)

In [None]:
# Vegetation-type index per site, cast to int so it can index the colour list below.
# NOTE(review): vegTypeIndex is averaged over time; presumably constant per site — TODO confirm.
vegtypes = np.array([site_data[s]['vegTypeIndex'].mean().values[()] for s in complete_sites]).astype(int)

In [None]:
# Categorical palette with one colour per vegetation class (same list as the setup cell).
# vegTypeIndex is presumably 1-based, hence `vegtypes-1`; a value above len(colors)
# would raise IndexError.
colors = ['#e6194b', '#3cb44b', '#ffe119', 
          '#4363d8', '#f58231', '#911eb4', 
          '#46f0f0', '#f032e6', '#bcf60c', 
          '#fabebe', '#008080', '#e6beff', ]
veg_colors = np.array(colors)[vegtypes-1]
cm = LinearSegmentedColormap.from_list('veg', colors, N=len(colors))


In [None]:

# Same computation as the earlier vegetation-type cell, kept as an int ndarray.
# Rebinding `vegtypes` to a plain list of floats made `np.array(colors)[vegtypes-1]`
# fail if the palette cell was re-run afterwards.
vegtypes = np.array([site_data[s]['vegTypeIndex'].mean().values[()] for s in complete_sites]).astype(int)

In [None]:
# Evaluate L_vap (J/g) at -40..30 degC in 10-degree steps (converted to kelvin first).
temps = np.arange(-40, 40, 10) + 273.16
lvaps = L_vap(temps)

In [None]:
# NOTE(review): appears unused in this part of the notebook — the Hargreaves `calc_pet`
# defined further down rebinds this name with a different signature. Consider removing.
from metsim.physics import calc_pet

In [None]:
# Long-term evaporative fraction ET/P per site for each SUMMA configuration and for the
# observations: total ET (latent heat converted with Qle_to_ET at observed air temperature)
# divided by total precipitation.
# NOTE(review): summing rates only gives a meaningful ratio if pptrate and the Qle_to_ET
# output share units (presumably kg m-2 s-1) and timestep — TODO confirm. xarray's sum skips
# NaN, so gaps in Qle_cor lower observed ET but not P.
flux_etp = {}
state_etp = {}
standalone_etp = {}
observed_etp = {}


for site in tqdm(complete_sites):
   
    # NOTE(review): agg_period is not used in this cell.
    agg_period = '30D'
    # Model latent heat is negated to match the sign of the observed Qle_cor.
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qle_cor']
    observed_p = (site_data[site]['pptrate']).sum().values
    
    observed_T = site_data[site]['airtemp']
    flux_et = Qle_to_ET(flux_qle, observed_T).sum().values
    state_et = Qle_to_ET(state_qle, observed_T).sum().values
    standalone_et = Qle_to_ET(standalone_qle, observed_T).sum().values
    observed_et = Qle_to_ET(observed_qle, observed_T).sum().values
    
    # The sums above are 0-d arrays, so these nanmean calls just unwrap scalars.
    flux_etp[site] = np.nanmean(flux_et) / np.nanmean(observed_p)
    state_etp[site] = np.nanmean(state_et) / np.nanmean(observed_p)
    standalone_etp[site] = np.nanmean(standalone_et) / np.nanmean(observed_p)
    observed_etp[site] = np.nanmean(observed_et) / np.nanmean(observed_p)

In [None]:
def calc_phi(X, Y, steps_per_day=48):
    """Phase lag (radians) of a flux relative to the radiation forcing.

    Fits  Y ~ c + b1 * X[:, 1] + b2 * X[:, 2]  by ordinary least squares.
    Column 1 of X is the radiation and column 2 its change per time step.
    Returns arctan(omega * b2 / b1), where omega = 2*pi / steps_per_day is the
    angular frequency of the diurnal cycle in radians per time step.

    Parameters
    ----------
    X : array-like, shape (n, 3)
        Design matrix. Column 0 is expected to be a column of ones; it is
        absorbed by the fitted intercept, so its coefficient is ignored.
    Y : array-like, shape (n,)
        Flux values, one per row of X.
    steps_per_day : number, default 48
        Time steps per diurnal cycle (48 for half-hourly data).

    Returns
    -------
    float
        Phase lag in radians, in (-pi/2, pi/2).

    Raises
    ------
    ValueError
        If X or Y contain NaN or inf. This matches the input check of
        scikit-learn's LinearRegression, which this function replaces.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if not (np.all(np.isfinite(X)) and np.all(np.isfinite(Y))):
        raise ValueError('calc_phi: X and Y must contain only finite values')
    # Centre both sides so the intercept is fitted implicitly. This is what
    # LinearRegression(fit_intercept=True) does. The constant column becomes
    # all zeros, and lstsq's minimum-norm solution gives it coefficient 0.
    X_centered = X - X.mean(axis=0)
    Y_centered = Y - Y.mean()
    coef, *_ = np.linalg.lstsq(X_centered, Y_centered, rcond=None)
    b0, b1, b2 = coef
    omega = 2 * np.pi / steps_per_day
    phi = np.arctan(omega * b2 / b1)
    return phi


In [None]:
# Simulated vs. observed ET/P per site (black 1:1 line), coloured by that run's
# latent-heat KGE. Panels: SA, NN1W, NN2W.
# NOTE(review): x, y and colour lists are paired by position, so all dicts are assumed
# to share the same site order — TODO confirm.
fig, axes = plt.subplots(1, 3, figsize=(22, 5), sharex=True, sharey=True)

sc3 = axes[0].scatter(list(observed_etp.values()), list(standalone_etp.values()), c=list(standalone_kge_qle.values()), marker='o', cmap='turbo', vmin=0, vmax=1, alpha=0.8)
plt.colorbar(sc3, ax=axes[0])
axes[0].plot([0,1.2], [0,1.2], color='black')


sc1= axes[1].scatter(list(observed_etp.values()), list(flux_etp.values()),    c=list(flux_kge_qle.values()),    marker='o', cmap='turbo', vmin=0, vmax=1, alpha=0.8)
plt.colorbar(sc1, ax=axes[1])
axes[1].plot([0,1.2], [0,1.2], color='black')

sc2= axes[2].scatter(list(observed_etp.values()), list(state_etp.values()),    c=list(state_kge_qle.values()),    marker='o', cmap='turbo', vmin=0, vmax=1, alpha=0.8)
plt.colorbar(sc2, ax=axes[2], label='KGE')
axes[2].plot([0,1.2], [0,1.2], color='black')

axes[0].set_xlabel('Observed ET/P')
axes[1].set_xlabel('Observed ET/P')
axes[2].set_xlabel('Observed ET/P')
axes[0].set_ylabel('Simulated ET/P')


axes[0].set_title('SA')
axes[1].set_title('NN1W')
axes[2].set_title('NN2W')

In [None]:
# Diurnal hysteresis phase lag of Qle and Qh relative to shortwave radiation, per site.
# Median diurnal cycles are built on a half-hourly clock, then each flux is regressed
# on [1, SW, dSW/dt] with calc_phi.
phi_qle_obs  = []
phi_qle_sa   = []
phi_qle_nn1w = []
phi_qle_nn2w = []

phi_qh_obs  = []
phi_qh_sa   = []
phi_qh_nn1w = []
phi_qh_nn2w = []

for site in tqdm(complete_sites):
    ds = site_data[site]
    # Half-hour-of-day index (0..47). NOTE: this adds a variable to site_data[site] in place.
    ds['halfhour'] = (ds['time'].dt.hour * 2) + (ds['time'].dt.minute / 30)
    swrad = ds['SWRadAtm'].groupby(ds['halfhour']).median()
    # Per-step change in shortwave. diff drops the first timestep, so the grouping key
    # below is sliced to match.
    dswrad_dt = ds['SWRadAtm'].diff(dim='time').ffill(dim='time')
    dswrad_dt = dswrad_dt.groupby(ds['halfhour'].isel(time=slice(1,None))).median()
    # Model fluxes are negated to match the sign of the observed *_cor variables.
    qle = -ds['scalarLatHeatTotal'].groupby(ds['halfhour']).median(dim='time')
    qle_obs = ds['Qle_cor'].groupby(ds['halfhour']).median()
    qh = -ds['scalarSenHeatTotal'].groupby(ds['halfhour']).median(dim='time')
    qh_obs = ds['Qh_cor'].groupby(ds['halfhour']).median()
    
    Y_qle_sa = qle.sel(type='SUMMA-SA').values
    Y_qle_nn1w = qle.sel(type='SUMMA-NN1W').values
    Y_qle_nn2w = qle.sel(type='SUMMA-NN2W').values
    Y_qle_obs = qle_obs.values
    
    
    Y_qh_sa = qh.sel(type='SUMMA-SA').values
    Y_qh_nn1w = qh.sel(type='SUMMA-NN1W').values
    Y_qh_nn2w = qh.sel(type='SUMMA-NN2W').values
    Y_qh_obs = qh_obs.values
    # Design matrix: intercept, radiation, radiation tendency (one row per half-hour of day).
    X = np.vstack([np.ones_like(swrad.values), swrad.values, dswrad_dt.values]).T
    
    phi_qle_obs.append( calc_phi(X,  Y_qle_obs))
    phi_qle_sa.append( calc_phi(X,   Y_qle_sa))
    phi_qle_nn1w.append( calc_phi(X, Y_qle_nn1w))
    phi_qle_nn2w.append( calc_phi(X, Y_qle_nn2w))
    
    phi_qh_obs.append( calc_phi(X,  Y_qh_obs))
    phi_qh_sa.append( calc_phi(X,   Y_qh_sa))
    phi_qh_nn1w.append( calc_phi(X, Y_qh_nn1w))
    phi_qh_nn2w.append( calc_phi(X, Y_qh_nn2w))

In [None]:
# Phase-lag differences (observed minus simulated), converted from radians of the diurnal
# cycle to minutes. A full cycle of 2*pi rad is 24 h = 1440 min (48 half-hour steps x 30 min,
# matching omega = 2*pi/48 in calc_phi). The former factor 30 * 24 = 720 halved every lag.
rad_to_minutes = (24 * 60) / (2 * np.pi)
dphi_qle_sa =   rad_to_minutes * (np.array(phi_qle_obs) - np.array(phi_qle_sa))
dphi_qle_nn1w = rad_to_minutes * (np.array(phi_qle_obs) - np.array(phi_qle_nn1w))
dphi_qle_nn2w = rad_to_minutes * (np.array(phi_qle_obs) - np.array(phi_qle_nn2w))
dphi_qh_sa =    rad_to_minutes * (np.array(phi_qh_obs) - np.array(phi_qh_sa))
dphi_qh_nn1w =  rad_to_minutes * (np.array(phi_qh_obs) - np.array(phi_qh_nn1w))
dphi_qh_nn2w =  rad_to_minutes * (np.array(phi_qh_obs) - np.array(phi_qh_nn2w))

In [None]:
# Fill colours for the three SUMMA configurations, reused by the box plots below.
sa_color='#4363d8'
nn1w_color='#800000'
nn2w_color='#f58231'

In [None]:
# Box plots of the phase-lag difference (observed minus simulated, from the dphi_* arrays)
# for latent heat (left) and sensible heat (right); one box per configuration at x = 1, 2, 3.
fig, ax = plt.subplots(1, 2, figsize=(16, 8), sharey=True)
ax = ax.flatten()
showfliers = False
# Zero line: no difference from the observed lag.
ax[0].axhline(0, color='black')
ax[1].axhline(0, color='black')
pa_sa = ax[0].boxplot(dphi_qle_sa, positions=[1], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
pa_nn1w = ax[0].boxplot(dphi_qle_nn1w, positions=[2], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
pa_nn2w = ax[0].boxplot(dphi_qle_nn2w, positions=[3], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
# patch_artist=True makes the boxes fillable patches.
for patch in pa_sa['boxes']:
    patch.set_facecolor(sa_color)
for patch in pa_nn1w['boxes']:
    patch.set_facecolor(nn1w_color)
for patch in pa_nn2w['boxes']:
    patch.set_facecolor(nn2w_color)

    
pa_sa = ax[1].boxplot(dphi_qh_sa, positions=[1], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
pa_nn1w = ax[1].boxplot(dphi_qh_nn1w, positions=[2], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
pa_nn2w = ax[1].boxplot(dphi_qh_nn2w, positions=[3], showfliers=showfliers, patch_artist=True, 
                       notch=True, widths=0.7, medianprops={'color': 'black'})
for patch in pa_sa['boxes']:
    patch.set_facecolor(sa_color)
for patch in pa_nn1w['boxes']:
    patch.set_facecolor(nn1w_color)
for patch in pa_nn2w['boxes']:
    patch.set_facecolor(nn2w_color)


ax[0].set_ylabel('Difference in phase lag \nfrom observed (minutes)')
ax[0].set_title('Latent heat')
ax[1].set_title('Sensible heat')

# Tick labels line up with the box positions 1, 2, 3.
ax[0].set_xticklabels(['SA', 'NN1W', 'NN2W'])
ax[1].set_xticklabels(['SA', 'NN1W', 'NN2W'])

In [None]:
site = 'AT-Neu'

# Hourly-mean sensible heat against hourly-mean shortwave (diurnal hysteresis loop).
# NOTE: the `qle` / `qle_obs` names are historical. They hold *sensible* heat
# (scalarSenHeatTotal / Qh_cor), so the y-axis is labelled accordingly.
ds = site_data[site]
swrad = ds['SWRadAtm'].groupby(ds['time'].dt.hour).mean()
qle = -ds['scalarSenHeatTotal'].groupby(ds['time'].dt.hour).mean(dim='time')
qle_obs = ds['Qh_cor'].groupby(ds['time'].dt.hour).mean()


plt.plot(swrad, qle_obs, marker='o', label='Obs', color='black')
#plt.plot(swrad, qle.sel(type='SUMMA-SA'),   label='SA',   color='#4363d8', marker='o' )
#plt.plot(swrad, qle.sel(type='SUMMA-NN1W'), label='NN1W', color='#800000', marker='o' )
#plt.plot(swrad, qle.sel(type='SUMMA-NN2W'), label='NN2W', color='#f58231', marker='o' )
plt.xlabel(r'Shortwave $(W/m^2)$')
plt.ylabel(r'Sensible Heat $(W/m^2)$')
#plt.title(f'Site: {site}')
#plt.legend()

In [None]:
# Site currently selected for the diurnal-hysteresis plots.
site

In [None]:
# Hourly-mean sensible heat vs. hourly-mean shortwave for `site`: observations and each
# SUMMA configuration. NOTE: the `qle` names hold *sensible* heat (scalarSenHeatTotal / Qh_cor).
ds = site_data[site]
swrad = ds['SWRadAtm'].groupby(ds['time'].dt.hour).mean()
qle = -ds['scalarSenHeatTotal'].groupby(ds['time'].dt.hour).mean(dim='time')
qle_obs = ds['Qh_cor'].groupby(ds['time'].dt.hour).mean()

plt.plot(swrad, qle_obs, marker='o', label='Obs')
plt.plot(swrad, qle.sel(type='SUMMA-SA'),   label='SA',   marker='o' )
plt.plot(swrad, qle.sel(type='SUMMA-NN1W'), label='NN1W', marker='o' )
plt.plot(swrad, qle.sel(type='SUMMA-NN2W'), label='NN2W', marker='o' )
plt.xlabel(r'Shortwave $(W/m^2)$')
plt.ylabel(r'Sensible Heat $(W/m^2)$')
plt.title(f'Site: {site}')
plt.legend()

In [None]:
def calc_pet(tmin, tmax, tmean, rad, rad_mult=1):
    """Hargreaves potential evapotranspiration.

    PET = 0.0023 * (tmean + 17.8) * sqrt(tmax - tmin) * 0.408 * rad * rad_mult

    Temperatures are in degrees Celsius. In this notebook `rad` is a radiation
    term in MJ m-2 day-1; 0.408 presumably converts that to an equivalent depth
    of evaporated water. `rad_mult` rescales the radiation term (1 = unchanged).
    Works elementwise on scalars, numpy arrays and xarray objects.
    """
    temperature_term = 0.0023 * (tmean + 17.8)
    range_term = (tmax - tmin) ** 0.5
    radiation_term = rad * rad_mult
    return temperature_term * range_term * 0.408 * radiation_term

In [None]:
# Aridity index PET/P per site: Hargreaves PET on 30-day windows divided by mean daily precipitation.
petp = {}
pet = {}
dailyp = {}
for site in tqdm(complete_sites):
    ds = site_data[site] 
    # pptrate is multiplied by the half-hour step length (1800 s) before the daily sum.
    # This gives a daily depth, mm/day if pptrate is kg m-2 s-1 — TODO confirm units.
    daily_p = 1800.0 * ds['pptrate'].resample(time='D').sum()
    dailyp[site] = np.nanmean(daily_p)
    
    # Calculate temperatures in C
    seconds_per_half_hour = 1800.0
    to_mega = 1e6
    kelvin_to_celcius = 273.16
    # Hargreaves uses the mean *daily* temperature range. Take daily extremes first, then
    # average them over each 30-day window. Taking min/max over the whole 30-day window
    # yields the window's extreme range, which overstates sqrt(tmax - tmin) and hence PET.
    daily_tmin = ds['airtemp'].resample(time='D').min()
    daily_tmax = ds['airtemp'].resample(time='D').max()
    tmean = ds['airtemp'].resample(time='30D').mean() - kelvin_to_celcius
    tmin  = daily_tmin.resample(time='30D').mean() - kelvin_to_celcius
    tmax  = daily_tmax.resample(time='30D').mean() - kelvin_to_celcius
    
    # Aggregate and convert to MJ / day * m^2. Sum of half-hourly W/m^2 x 1800 s / 1e6,
    # then averaged per 30-day window. This is incoming shortwave (SWRadAtm), not net radiation.
    netrad  = ds['SWRadAtm'].resample(time='D').sum().resample(time='30D').mean().values
    netrad = (netrad * seconds_per_half_hour) / (to_mega )#* 0.72)
   
    # Hargreaves PET (mm/day) for each 30-day window. Despite the names below, no yearly
    # aggregation happens: the ratio is of long-term means.
    daily_pet = calc_pet(tmin, tmax, tmean, netrad)
    
    yearly_pet = daily_pet
    yearly_p = daily_p
    pet[site] = np.nanmean(yearly_pet)
    petp[site] = np.nanmean(yearly_pet) / np.nanmean(yearly_p)

In [None]:
# Budyko curve ET/P = sqrt(phi * tanh(1/phi) * (1 - exp(-phi))) over the aridity index
# phi = PET/P. At phi = 0 the 1/phi term divides by zero but the product evaluates to 0;
# RuntimeWarnings are silenced in the setup cell.
phi = np.arange(0, 7.01, 0.01)
budyko = np.sqrt(phi * np.tanh(1/phi) * (1 - np.exp(-phi)))
fig, axes = plt.subplots(1, 3, figsize=(22,5), sharex=True, sharey=True)
axes = axes.flatten()

# (title, per-site ET/P, per-site latent-heat KGE used for colouring)
budyko_panels = [
    ('SUMMA-NN1W', flux_etp, flux_kge_qle),
    ('SUMMA-NN2W', state_etp, state_kge_qle),
    ('SUMMA-SA', standalone_etp, standalone_kge_qle),
]
for axis, (title, etp, kge) in zip(axes, budyko_panels):
    axis.plot(phi, budyko, color='black')
    # Each colorbar is bound to its own panel's scatter; previously all three used the first one.
    sc = axis.scatter(list(petp.values()), list(etp.values()), c=list(kge.values()),
                      vmin=0, vmax=1, cmap='turbo')
    plt.colorbar(sc, label='KGE', ax=axis)
    axis.set_xlabel('PET / P')
    axis.set_ylabel('ET / P')
    axis.set_title(title)
    axis.set_ylim([0, 2])

In [None]:
# Resampling periods evaluated in addition to the native half-hourly timestep.
# (A finer set, ['3H', '6H', 'D', '7D', '14D', '30D'], was tried earlier.)
aggregate_freqs = ['3H', 'D', '7D', '30D']

# Each aggregate_* list holds one {site: score} dict per period. Index 0 is the
# native-timestep dict computed earlier; index k + 1 is aggregate_freqs[k].
aggregate_sa_nse_qle = [standalone_nse_qle]
aggregate_sa_kge_qle = [standalone_kge_qle]
aggregate_sa_nse_qh = [standalone_nse_qh]
aggregate_sa_kge_qh = [standalone_kge_qh]

aggregate_flux_nse_qle = [flux_nse_qle]
aggregate_flux_kge_qle = [flux_kge_qle]
aggregate_flux_nse_qh = [flux_nse_qh]
aggregate_flux_kge_qh = [flux_kge_qh]

aggregate_state_nse_qle = [state_nse_qle]
aggregate_state_kge_qle = [state_kge_qle]
aggregate_state_nse_qh = [state_nse_qh]
aggregate_state_kge_qh = [state_kge_qh]

for freq in aggregate_freqs:
    print(freq)
    # NN1W/NN2W results go into agg_* dicts. These used to be named flux_nse_qle,
    # flux_kge_qle, state_nse_qle, ..., which silently replaced the native-timestep
    # metric dicts with the last (30D) results after this cell ran.
    agg_flux_nse_qle = {}
    agg_state_nse_qle = {}
    sa_nse_qle = {}
    
    agg_flux_nse_qh = {}
    agg_state_nse_qh = {}
    sa_nse_qh = {}
    
    agg_flux_kge_qle = {}
    agg_state_kge_qle = {}
    sa_kge_qle = {}
    
    agg_flux_kge_qh = {}
    agg_state_kge_qh = {}
    sa_kge_qh = {}
       
    for s in tqdm(complete_sites):
        # Period means of observed and simulated fluxes (model fluxes negated, as elsewhere).
        obs_qle   = site_data[s]['Qle_cor'].resample({'time': freq}).mean()
        flux_qle  = -site_data[s]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W').resample({'time': freq}).mean()
        state_qle = -site_data[s]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W').resample({'time': freq}).mean()
        sa_qle    = -site_data[s]['scalarLatHeatTotal'].sel(type='SUMMA-SA').resample({'time': freq}).mean()
        
        obs_qh   = site_data[s]['Qh_cor'].resample({'time': freq}).mean()
        flux_qh  = -site_data[s]['scalarSenHeatTotal'].sel(type='SUMMA-NN1W').resample({'time': freq}).mean()
        state_qh = -site_data[s]['scalarSenHeatTotal'].sel(type='SUMMA-NN2W').resample({'time': freq}).mean()
        sa_qh    = -site_data[s]['scalarSenHeatTotal'].sel(type='SUMMA-SA').resample({'time': freq}).mean()
        
        # compute_metric's default metric presumably is NSE (it fills the *_nse_* dicts);
        # KGE is requested explicitly.
        agg_flux_nse_qle[s] = compute_metric(flux_qle, obs_qle)
        agg_state_nse_qle[s] = compute_metric(state_qle, obs_qle)
        sa_nse_qle[s] = compute_metric(sa_qle, obs_qle)
        
        agg_flux_nse_qh[s] = compute_metric(flux_qh, obs_qh)
        agg_state_nse_qh[s] = compute_metric(state_qh, obs_qh)
        sa_nse_qh[s] = compute_metric(sa_qh, obs_qh)
        
        agg_flux_kge_qle[s] = compute_metric(flux_qle, obs_qle, metric=pse.kling_gupta_efficiency)
        agg_state_kge_qle[s] = compute_metric(state_qle, obs_qle, metric=pse.kling_gupta_efficiency)
        sa_kge_qle[s] = compute_metric(sa_qle, obs_qle, metric=pse.kling_gupta_efficiency)
        
        agg_flux_kge_qh[s] = compute_metric(flux_qh, obs_qh, metric=pse.kling_gupta_efficiency)
        agg_state_kge_qh[s] = compute_metric(state_qh, obs_qh, metric=pse.kling_gupta_efficiency)
        sa_kge_qh[s] = compute_metric(sa_qh, obs_qh, metric=pse.kling_gupta_efficiency)
       
       
    aggregate_flux_nse_qle.append(agg_flux_nse_qle)
    aggregate_state_nse_qle.append(agg_state_nse_qle)
    aggregate_sa_nse_qle.append(sa_nse_qle)
    
    aggregate_flux_nse_qh.append(agg_flux_nse_qh)
    aggregate_state_nse_qh.append(agg_state_nse_qh)
    aggregate_sa_nse_qh.append(sa_nse_qh)
    
    aggregate_flux_kge_qle.append(agg_flux_kge_qle)
    aggregate_state_kge_qle.append(agg_state_kge_qle)
    aggregate_sa_kge_qle.append(sa_kge_qle)
    
    aggregate_flux_kge_qh.append(agg_flux_kge_qh)
    aggregate_state_kge_qh.append(agg_state_kge_qh)
    aggregate_sa_kge_qh.append(sa_kge_qh)

In [None]:
# Index 0 of each aggregate_* list must remain the native (half-hourly) metrics seeded
# before the aggregation loop. Do not overwrite it with flux_kge_qle / state_kge_qle:
# once that loop has run those names can refer to per-frequency (30D) results, which
# would put 30D scores in the "Half-hourly" slot of the figure below.
for _agg_list in (aggregate_flux_kge_qle, aggregate_state_kge_qle,
                  aggregate_flux_kge_qh, aggregate_state_kge_qh):
    assert len(_agg_list) == len(aggregate_freqs) + 1, 're-run the aggregation cell from the top'

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(16, 12), sharex=True)

# One label per entry of the aggregate_* lists: index 0 is the native half-hourly
# timestep, the rest follow `aggregate_freqs`.
aggregate_names = ['Half-hourly', 'Three-hourly',  'Daily', 'Weekly',  'Monthly']
assert len(aggregate_names) == len(aggregate_sa_kge_qle), 'labels out of sync with aggregate_* lists'
showfliers = True


sa_color='#4363d8'
nn1w_color='#800000'
nn2w_color='#f58231'

# Each aggregation period gets a group of 4 x-slots: SA, NN1W, NN2W, then a gap.
group_width = 4
# (axis, SA, NN1W, NN2W) KGE lists: latent heat on top, sensible heat below.
kge_panels = [
    (ax[0], aggregate_sa_kge_qle, aggregate_flux_kge_qle, aggregate_state_kge_qle),
    (ax[1], aggregate_sa_kge_qh,  aggregate_flux_kge_qh,  aggregate_state_kge_qh),
]
for axis, sa_list, fl_list, st_list in kge_panels:
    for i, name in enumerate(aggregate_names):
        sa_kge = sa_list[i].values()
        fl_kge = fl_list[i].values()
        st_kge = st_list[i].values()

        base = group_width * i
        pa_sa = axis.boxplot(sa_kge, positions=[1 + base], showfliers=showfliers, patch_artist=True,
                             notch=True, widths=0.7, medianprops={'color': 'black'})
        pa_fl = axis.boxplot(fl_kge, positions=[2 + base], showfliers=showfliers, patch_artist=True,
                             notch=True, widths=0.7, medianprops={'color': 'black'})
        pa_st = axis.boxplot(st_kge, positions=[3 + base], showfliers=showfliers, patch_artist=True,
                             notch=True, widths=0.7, medianprops={'color': 'black'})
        for box_set, color in ((pa_sa, sa_color), (pa_fl, nn1w_color), (pa_st, nn2w_color)):
            for patch in box_set['boxes']:
                patch.set_facecolor(color)

ax[0].legend([pa["boxes"][0] for pa in [pa_sa, pa_fl, pa_st]], ['SA', 'NN1W', 'NN2W'], loc='lower left')
# Centre each label under the NN1W box of its group; derived, so it follows aggregate_names.
ax[1].set_xticks([2 + group_width * i for i in range(len(aggregate_names))])
ax[1].set_xticklabels(aggregate_names, fontsize=24)
#ax[1].set_ylabel('KGE (Sensible heat)')
#ax[0].set_ylabel('KGE (Latent heat)')
#ax[1].set_xlabel('Aggregation period', fontsize=28)
ax[1].set_ylim([-0.41, 1])
ax[0].set_ylim([-0.41, 1])

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16,6), sharex=True, sharey=False)
axes = axes.flatten()
# Mean seasonal cycle of latent heat (averaged by day of year) at three example sites.
plot_sites = ['AT-Neu', 'DK-Eng', 'CH-Cha']
for i, site in tqdm(enumerate(plot_sites)):
    if i == len(axes): break
   
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qle_cor']
    
    flux_qle.groupby(flux_qle.time.dt.dayofyear).mean(dim='time').plot(ax=axes[i], color='crimson', label='NN1W')
    state_qle.groupby(state_qle.time.dt.dayofyear).mean(dim='time').plot(ax=axes[i], color='forestgreen', label='NN2W')
    standalone_qle.groupby(standalone_qle.time.dt.dayofyear).mean(dim='time').plot(ax=axes[i], color='royalblue', label='SA')
    observed_qle.groupby(observed_qle.time.dt.dayofyear).mean(dim='time').plot(ax=axes[i], color='black', label='Observed')
    
    # Monthly alternative (switch the x-label back to 'Month' if re-enabled):
    #flux_qle.groupby(flux_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='crimson', label='NN1W')
    #state_qle.groupby(state_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='forestgreen', label='NN2W')
    #standalone_qle.groupby(standalone_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='royalblue', label='SA')
    #observed_qle.groupby(observed_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='black', label='Observed')
    
    axes[i].set_ylabel('')
    axes[i].set_xlabel('')
    axes[i].set_title(site)
    #axes[i].set_xlim([0, 12])
    
# The curves are grouped by day of year, not by month.
fig.text(0.5, -0.02, r'Day of year', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)
#axes[0].legend(bbox_to_anchor=(1.0, 1.0))

In [None]:
# Per-site latent-heat error metrics for each configuration, computed on timesteps where
# all four series (three runs + observations) are finite:
#   MBE - mean bias error (W/m^2)
#   r   - Pearson correlation coefficient
#   NME - normalized mean error: sum|sim - obs| / sum|mean(obs) - obs| (dimensionless)
standalone_mbe = {}
state_mbe = {}
flux_mbe = {}

standalone_r = {}
state_r = {}
flux_r = {}

standalone_nme = {}
state_nme = {}
flux_nme = {}

for site in tqdm(complete_sites):
    state = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W').values
    flux = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W').values
    standalone = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA').values
    observed = site_data[site]['Qle_cor'].values
    nanfilter = ~np.logical_or(np.logical_or(np.isnan(state), np.isnan(flux)), np.logical_or(np.isnan(standalone), np.isnan(observed)))
    state = state[nanfilter]
    flux = flux[nanfilter]
    standalone = standalone[nanfilter]
    observed = observed[nanfilter]
    
    standalone_mbe[site] = np.mean(standalone - observed)
    state_mbe[site] = np.mean(state - observed)
    flux_mbe[site] = np.mean(flux - observed)
    
    standalone_r[site] = np.corrcoef(standalone, observed)[0, 1]
    state_r[site] = np.corrcoef(state, observed)[0, 1]
    flux_r[site] = np.corrcoef(flux, observed)[0, 1]
    
    # Denominator: error of always predicting the observed mean.
    obs_spread = np.sum(np.abs(np.mean(observed) - observed))
    standalone_nme[site] = np.sum(np.abs(standalone - observed)) / obs_spread
    state_nme[site] = np.sum(np.abs(state - observed)) / obs_spread
    flux_nme[site] = np.sum(np.abs(flux - observed)) / obs_spread

fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey='row', sharex=False)
axes = axes.flatten()

hist_styles = [('royalblue', 'SUMMA-SA', 0.6), ('crimson', 'SUMMA-NN1W', 0.6), ('forestgreen', 'SUMMA-NN2W', 0.4)]
hist_panels = [
    (axes[0], (standalone_mbe, flux_mbe, state_mbe), r'Mean Bias Error ($W/m^2$)'),
    (axes[1], (standalone_r, flux_r, state_r), r'Correlation coefficient'),
    # NME is a ratio of absolute errors, so it has no units.
    (axes[2], (standalone_nme, flux_nme, state_nme), r'Normalized mean error (-)'),
]
for axis, metric_dicts, xlabel in hist_panels:
    series = [np.array(list(d.values()), dtype=float) for d in metric_dicts]
    # Shared bin edges so the overlaid histograms are directly comparable
    # (each series previously got its own 10 bins over its own range).
    pooled = np.concatenate(series)
    bins = np.histogram_bin_edges(pooled[np.isfinite(pooled)], bins=10)
    for values, (color, label, alpha) in zip(series, hist_styles):
        axis.hist(values, bins=bins, color=color, label=label, alpha=alpha)
    axis.set_xlabel(xlabel)
axes[2].legend()

In [None]:
fig, axes = plt.subplots(8, 10, figsize=(18,16), sharex=True, sharey='row')
axes = axes.flatten()

# Mean diurnal cycle (by hour of day) of latent heat, one panel per site.
for i, site in tqdm(enumerate(complete_sites)):
    if i == len(axes): break
   
    flux_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarLatHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qle_cor']
    flux_qle.groupby(flux_qle.time.dt.hour).mean(dim='time').plot(ax=axes[i], color='crimson', label='NN1W')
    state_qle.groupby(state_qle.time.dt.hour).mean(dim='time').plot(ax=axes[i], color='forestgreen', label='NN2W')
    standalone_qle.groupby(standalone_qle.time.dt.hour).mean(dim='time').plot(ax=axes[i], color='royalblue', label='SA')
    observed_qle.groupby(observed_qle.time.dt.hour).mean(dim='time').plot(ax=axes[i], color='black', label='Observed')
    
    axes[i].set_ylabel('')
    axes[i].set_xlabel('')
    axes[i].set_title(site)

# Hide the grid panels left over when there are fewer sites than panels.
for unused_ax in axes[min(len(complete_sites), len(axes)):]:
    unused_ax.set_visible(False)
# The curves are grouped by hour of day, not by month.
fig.text(0.5, -0.02, r'Hour of day', ha='center', )
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)
axes[0].legend(bbox_to_anchor=(1.0, 1.0))

In [None]:
fig, axes = plt.subplots(8, 10, figsize=(18,16), sharex=True, sharey='row')
axes = axes.flatten()

# Mean monthly cycle of sensible heat, one panel per site.
# NOTE: the *_qle names are reused here but hold sensible heat (scalarSenHeatTotal / Qh_cor).
for i, site in tqdm(enumerate(complete_sites)):
    if i == len(axes): break
   
    flux_qle = -site_data[site]['scalarSenHeatTotal'].sel(type='SUMMA-NN1W')
    state_qle = -site_data[site]['scalarSenHeatTotal'].sel(type='SUMMA-NN2W')
    standalone_qle = -site_data[site]['scalarSenHeatTotal'].sel(type='SUMMA-SA')
    observed_qle = site_data[site]['Qh_cor']
    flux_qle.groupby(flux_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='crimson', label='NN1W')
    state_qle.groupby(state_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='forestgreen', label='NN2W')
    standalone_qle.groupby(standalone_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='royalblue', label='SA')
    observed_qle.groupby(observed_qle.time.dt.month).mean(dim='time').plot(ax=axes[i], color='black', label='Observed')
    
    axes[i].set_ylabel('')
    axes[i].set_xlabel('')
    axes[i].set_title(site)

n_plotted = min(len(complete_sites), len(axes))
# Hide the grid panels left over when there are fewer sites than panels.
for unused_ax in axes[n_plotted:]:
    unused_ax.set_visible(False)
fig.text(0.5, -0.02, r'Month', ha='center', )
fig.text(-0.02, 0.5, r'Sensible heat $(W/m^2)$', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)
# Put the legend on the last *populated* panel. axes[-1] is empty whenever there are
# fewer than len(axes) sites, which left the figure without a legend.
if n_plotted:
    axes[n_plotted - 1].legend(bbox_to_anchor=(1.0, 1.0))

In [None]:
fig, axes = plt.subplots(8, 10, figsize=(18,16), sharex=True, sharey='row')
axes = axes.flatten()

# Monthly-mean Bowen ratio (sensible / latent heat) per site, observed vs. each SUMMA run.
# Uses the same `type` labels as the other cells ('SUMMA-SA', 'SUMMA-NN1W', 'SUMMA-NN2W').
# The earlier version selected type='coupled' / 'standalone' and hru=1, which no other cell uses.
# NOTE(review): assumes these variables carry no `hru` dimension, as the other cells imply.
bowen_runs = [('SUMMA-SA', 'royalblue', 'SA'),
              ('SUMMA-NN1W', 'crimson', 'NN1W'),
              ('SUMMA-NN2W', 'forestgreen', 'NN2W')]

def _monthly_mean(da):
    """Mean of `da` for each calendar month over its whole time axis."""
    return da.groupby(da.time.dt.month).mean(dim='time')

n_plotted = 0
for i, site in tqdm(enumerate(complete_sites)):
    if i == len(axes): break
    ds = site_data[site]
    observed_qh = _monthly_mean(ds['Qh_cor'])
    observed_qle = _monthly_mean(ds['Qle_cor'])
    (observed_qh / observed_qle).plot(ax=axes[i], color='black', label='Observed')
    for run_type, color, label in bowen_runs:
        # Model fluxes are negated, as in the other cells.
        sim_qh = _monthly_mean(-ds['scalarSenHeatTotal'].sel(type=run_type))
        sim_qle = _monthly_mean(-ds['scalarLatHeatTotal'].sel(type=run_type))
        (sim_qh / sim_qle).plot(ax=axes[i], color=color, label=label)
    axes[i].set_ylabel('')
    axes[i].set_xlabel('')
    axes[i].set_title(site)
    # Clip the axis: the ratio blows up when monthly latent heat is near zero.
    axes[i].set_ylim([-5, 10])
    n_plotted = i + 1

for unused_ax in axes[n_plotted:]:
    unused_ax.set_visible(False)
fig.text(0.5, -0.02, r'Month', ha='center', )
fig.text(-0.02, 0.5, r'Bowen ratio', va='center', rotation='vertical', )
plt.tight_layout(pad=0.1)
if n_plotted:
    axes[n_plotted - 1].legend(bbox_to_anchor=(1.0, 1.0))

In [None]:
# Per-site NSE loss of NN2W relative to the standalone run (positive = coupling hurt).
# NOTE(review): standalone_nse / state_nse are not defined in this part of the notebook —
# presumably per-site NSE dicts from an earlier cell; TODO confirm which flux they score.
site_dnse = {}
dnse_list = []
for s in complete_sites:
    nse_sa = standalone_nse[s]
    nse_2w = state_nse[s]
    dnse = nse_sa - nse_2w
    site_dnse[s] = dnse
    # dnse_list follows complete_sites order, so its indices map back to site names.
    dnse_list.append(dnse)

In [None]:
# Sorted per-site NSE loss (SA minus NN2W); points above zero are sites where coupling hurt.
plt.plot(sorted(dnse_list), marker='o')
plt.axhline(0, color='black')
plt.xlabel('Site rank')
plt.ylabel('NSE(SA) - NSE(NN2W)');

In [None]:
# Positions (in complete_sites order) of the five largest NSE losses, smallest of the five first.
big_diffs = np.argsort(dnse_list)[-5:]
big_diffs

In [None]:
# Names and NSE losses of the five largest decreases. Returned as one tuple because only a
# cell's last expression is displayed; the loss values were previously computed and discarded.
np.array(complete_sites)[big_diffs], np.array(dnse_list)[big_diffs]