# Imports, loading, preprocessing

In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import os
import warnings
import pandas as pd
from glob import glob
import xarray as xr
import seaborn as sns
from IPython.display import SVG
from functools import partial
import geopandas as gpd
from shapely.geometry import Polygon, Point, MultiPolygon

from sklearn import linear_model
import pysumma.plotting as psp
import pysumma.utils as psu
import pysumma.evaluation as pse

from tqdm.notebook import tqdm

from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.utils import plot_model, model_to_dot
from tensorflow.keras import layers
from tensorflow.keras.callbacks import Callback, EarlyStopping

from sklearn import preprocessing
from sklearn.linear_model import LinearRegression

import shap
import innvestigate

# Process-wide configuration: GPU selection, TensorFlow log level and thread counts.
# NOTE(review): these variables are assigned after TensorFlow (and numpy) have
# already been imported above. Some of them (e.g. TF_CPP_MIN_LOG_LEVEL,
# OMP_NUM_THREADS) are presumably only read at import/initialisation time --
# confirm they still take effect, or move this block above the imports.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["OMP_NUM_THREADS"] = "1"
os.environ['TF_NUM_INTEROP_THREADS'] = '1'
os.environ['TF_NUM_INTRAOP_THREADS'] = '1'

# Silence TensorFlow logging below ERROR and three whole warning categories.
# These filters are global, so warnings raised by the analysis code are hidden too.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)

# Plot styling. `mpl` is not imported explicitly; it is expected to come from
# `%pylab inline`.
# NOTE(review): newer matplotlib releases presumably renamed the seaborn styles
# (e.g. 'seaborn-v0_8-talk') -- confirm against the installed version.
mpl.style.use('seaborn-talk')
sns.set_context('talk')
mpl.rcParams['figure.figsize'] = (8, 6)

# Keras backend default float type.
K.set_floatx('float32')

In [None]:
# Keras model loaded from disk; used below by modify_variables() and by the LRP analysis.
model = keras.models.load_model('../new_models/all_var_dense_dropout.h5')

# One sub-directory per site. os.listdir() returns entries in arbitrary order and
# also lists stray files (e.g. '.DS_Store'), which would break the per-site file
# loading below, so keep directories only and sort for a reproducible site order.
sim_sites = sorted(
    entry for entry in os.listdir('../sites')
    if os.path.isdir(os.path.join('../sites', entry))
)

In [None]:
def mse_eb(y_true, y_pred):
    """Loss: MSE on the first two output columns plus a scaled sum penalty.

    The first term is the per-sample mean squared error between columns 0-1 of
    ``y_true`` and ``y_pred``. The second term is the batch mean of
    ``(sum(y_pred[:, 0:2]) + y_true[:, 2]) ** 2`` divided by 10, which the
    original author describes as an energy balance constraint.
    """
    flux_true = y_true[:, 0:2]
    flux_pred = y_pred[:, 0:2]

    # Ordinary mean squared error over the two predicted columns.
    squared_error = K.square(flux_true - flux_pred)
    mse = K.mean(squared_error, axis=-1)

    # Penalise a non-zero value of (sum of predictions + third column of y_true).
    balance_residual = K.sum(flux_pred, axis=-1) + y_true[:, 2]
    sum_constraint = K.mean(K.square(balance_residual)) / 10

    return mse + sum_constraint

In [None]:
# Names of the network input features; list position == input column index.
feature_strings = [
    'airtemp',
    'relhum',
    'swradatm',
    'surf_sm',
    'lai',
    'vegtype',
]

# Hex colour assigned to each input feature.
var_to_color = dict(
    swradatm='#4363d8',
    airtemp='#e6194B',
    relhum='#f58231',
    surf_sm='#911eb4',
    lai='#9A6324',
    vegtype='#469990',
)

In [None]:
def _load_first_hru(path):
    """Read a NetCDF file into memory, keeping only the first HRU, then close it.

    Using the dataset as a context manager releases the file handle as soon as
    the data has been loaded; otherwise one open handle per site and per file
    type stays alive for the lifetime of the kernel.
    """
    with xr.open_dataset(path) as ds:
        return ds.isel(hru=0, drop=True).load()


# Forcing data per site.
site_dict = {s: _load_first_hru(f'../sites/{s}/forcings/{s}.nc')
             for s in sim_sites}
# Local attributes per site.
site_attr = {s: _load_first_hru(f'../sites/{s}/params/local_attributes.nc')
             for s in sim_sites}

# Prepared "lrp_nn" model output per site.
site_outp = {s: _load_first_hru(f'../prepped_output_for_casper/lrp_nn_output_{s}_timestep.nc')
             for s in sim_sites}

# Trial parameters per site.
site_parm = {s: _load_first_hru(f'../sites/{s}/params/parameter_trial.nc')
             for s in sim_sites}

# "template" model output per site (referred to as standalone / SA further down).
site_sa = {s: _load_first_hru(f'../sites/{s}/output/template_output_{s}_timestep.nc')
           for s in sim_sites}

# NOTE(review): later cells read and write `site_nn1w`, but no cell in this
# notebook loads it -- add the corresponding load here or remove its uses.


# Functions

In [None]:
def trim_time(sim, obs, roundto='min'):
    """Return the time slice shared by ``sim`` and ``obs``.

    Both inputs have their ``'time'`` entry rounded to ``roundto`` IN PLACE
    (callers see the rounded times afterwards). The window runs from the later
    of the two second time stamps to the earlier of the two second-to-last
    time stamps, i.e. the first and last step of each series are excluded.

    Returns
    -------
    slice
        ``slice(start, stop)`` suitable for label-based time selection.
    """
    # Round in place so that both series share identical time labels.
    for series in (sim, obs):
        series['time'] = series['time'].dt.round(roundto)

    sim_times = sim['time'].values
    obs_times = obs['time'].values

    latest_start = max(sim_times[1], obs_times[1])
    earliest_stop = min(sim_times[-2], obs_times[-2])
    return slice(latest_start, earliest_stop)

def calc_relhum(T, p, sh):
    """Compute relative humidity from temperature, pressure and specific humidity.

    Evaluates ``sh * 0.263 * p / exp(17.67 * (T - T0) / (T - 29.65))`` with
    ``T0 = 273.16``. Works element-wise on arrays.
    NOTE(review): presumably T is in K, p in Pa and the result in percent
    (callers divide by 100) -- confirm units.
    """
    T0 = 273.16
    exponent = (17.67 * (T - T0)) / (T - 29.65)
    numerator = sh * (0.263 * p)
    return numerator / np.exp(exponent)

def L_vap(T):
    """Cubic polynomial in ``T - 273.16`` for the latent heat of vaporisation.

    The original author notes the result is in J / g.
    """
    celsius = T - 273.16
    linear_term = 2.36 * celsius
    quadratic_term = 0.0016 * celsius**2
    cubic_term = 0.00006 * celsius**3
    return 2500.8 - linear_term + quadratic_term - cubic_term

def Qle_to_ET(Q, T):
    """Divide the flux ``Q`` by ``1000 * L_vap(T)``.

    ``L_vap`` is documented as J / g, so the factor 1000 (g per kg) turns the
    divisor into a per-kilogram quantity.
    """
    g_per_kg = 1e3
    inverse_latent_heat = 1 / (g_per_kg * L_vap(T))
    return inverse_latent_heat * Q

def calc_pet(tmin, tmax, tmean, rad, rad_mult=1):
    """Evaluate ``0.0023 * (tmean + 17.8) * sqrt(tmax - tmin) * 0.408 * rad * rad_mult``.

    The calling code refers to this as Hargreaves' equation for potential
    evapotranspiration. ``rad_mult`` is an optional scale factor on ``rad``.
    """
    temperature_term = 0.0023 * (tmean + 17.8)
    range_term = (tmax - tmin) ** 0.5
    radiation_term = rad * rad_mult
    return temperature_term * range_term * 0.408 * radiation_term

def etl_single_site(ds, use_mask=True):
    """Build scaled network inputs and target fluxes for one site.

    Parameters
    ----------
    ds : dataset indexed by variable name (used with ``.values``, ``.isel`` and
        ``.where``, so presumably an ``xarray.Dataset``)
        Must provide the forcing, output, attribute and parameter variables
        read below.
    use_mask : bool, default True
        If True, keep only the time steps whose ``gap_filled`` value is 0.

    Returns
    -------
    (train_input, train_output) : tuple of float32 arrays
        ``train_input`` columns: 0 air temperature, 1 relative humidity,
        2 shortwave radiation, 3 surface soil moisture, 4 ``lai * canwidth``,
        5 vegetation type. ``train_output`` columns: ``Qle_cor / 500`` and
        ``Qh_cor / 500``.
    """
    # Fixed scalings applied to the raw variables; keep in sync with the
    # scalings the loaded model was trained with.
    airtemp = (((ds['airtemp'].values / 27.315) - 10) / 2) + 0.5
    swradatm = ds['SWRadAtm'].values / 1000
    mask = ds['gap_filled'].values

    # All three arguments are passed as plain arrays so the result is a plain
    # array as well.
    relhumid = calc_relhum(ds['airtemp'].values,
                           ds['airpres'].values,
                           ds['spechum'].values) / 100
    relhumid[relhumid < 0] = 0

    lai = ds['scalarLAI'].values / 12
    try:
        canwidth = ds['heightCanopyTop'].values / 20
    except KeyError:
        # Variable not present in this dataset: fall back to a factor that
        # makes `lai * canwidth` close to 1. Only a missing variable is
        # handled here; any other error is allowed to surface.
        canwidth = 1 / (lai + 0.001)

    # Scalar site properties broadcast to the length of the time series.
    vegtype = ds['vegTypeIndex'].values[()] * np.ones_like(mask) / 12
    thetasat = ds['theta_sat'].values[()] * np.ones_like(mask)
    sm_min = ds['critSoilWilting'].values[()] * np.ones_like(mask)

    # Surface moisture: weighted mean over a window of layers along `midToto`.
    nlayers_top = 0
    nlayers_bot = 4
    surf_idx = -len(ds.midSoil)
    start = surf_idx + nlayers_top
    stop = surf_idx + nlayers_bot
    # The window is expressed with negative (from-the-end) indices. If the
    # stop reaches 0 or beyond, it would be read as an index from the START of
    # the axis and select nothing, so run to the end of the axis instead.
    layer_slice = slice(start, stop if stop < 0 else None)

    depth = ds['mLayerHeight'].copy(deep=True)
    # -9999 marks missing entries; blank them out before justifying.
    depth.values = psp.utils.justify(depth.where(depth != -9999).values)
    depth = depth.isel(midToto=layer_slice)
    # Normalise so the weights of the selected layers sum to one.
    depth = depth / depth.sum(dim='midToto')

    surf_sm = ds['mLayerVolFracWat'].copy(deep=True)
    surf_sm.values = psp.utils.justify(surf_sm.where(surf_sm != -9999).values)
    surf_sm = surf_sm.isel(midToto=layer_slice)
    surf_sm *= depth
    surf_sm = surf_sm.sum(dim='midToto')
    # Rescale between the wilting value (0) and theta_sat (1).
    surf_sm = (surf_sm - sm_min[0]) / (thetasat[0] - sm_min[0])

    train_input = np.vstack([airtemp,         # 0
                             relhumid,        # 1
                             swradatm,        # 2
                             surf_sm,         # 3
                             lai * canwidth,  # 4
                             vegtype,         # 5
                             ]).T

    train_output = np.vstack([ds['Qle_cor'].values / 500,
                              ds['Qh_cor'].values / 500]).T

    if use_mask:
        keep = mask == 0
        train_input = train_input[keep]
        train_output = train_output[keep]
    return train_input.astype(np.float32), train_output.astype(np.float32)

In [None]:
def modify_variables(ds):
    """Overwrite the heat fluxes in ``ds`` with network predictions and add helpers.

    ``ds`` is modified in place and also returned. Added or replaced variables:

    * ``transpirable``: column 5 of the network input.
    * ``surf_sm``: column 3 of the network input.
    * ``scalarLatHeatTotal`` / ``scalarSenHeatTotal``: 500 x model output
      columns 0 and 1.
    * ``halfhourofday``: ``2 * hour + minute // 30``.
    * ``day``: running count of time steps where ``halfhourofday == 0``.
    """
    new_input, _ = etl_single_site(ds, use_mask=False)

    # NOTE(review): column 5 of the input built by etl_single_site is the
    # vegetation type, which does not match the name 'transpirable' -- confirm
    # which column is intended.
    ds['transpirable'] = ds['scalarLatHeatTotal'].copy(deep=True)
    ds['transpirable'].values = new_input[:, 5]
    ds['surf_sm'] = ds['scalarLatHeatTotal'].copy(deep=True)
    ds['surf_sm'].values = new_input[:, 3]

    # Run the network once and reuse the result for both fluxes; predicting
    # twice on the same input only doubles the run time.
    predicted = model.predict(new_input)
    ds['scalarLatHeatTotal'].values = 500 * predicted[:, 0].flatten()
    ds['scalarSenHeatTotal'].values = 500 * predicted[:, 1].flatten()

    ds['halfhourofday'] = ds['scalarLatHeatTotal'].copy(deep=True)
    ds['halfhourofday'].values = (2 * ds['time'].dt.hour + ds['time'].dt.minute // 30)

    # A new day starts whenever the half-hour index wraps back to 0.
    ds['day'] = (ds['halfhourofday'] == 0).cumsum()

    return ds
 

In [None]:
def lrp_to_ds(r, timedim, feature_strings=feature_strings):
    """Convert a 2-D relevance array into a dataset with one variable per feature.

    Parameters
    ----------
    r : 2-D array
        Column ``i`` holds the values for ``feature_strings[i]``.
    timedim : sequence
        Coordinate values for the ``time`` dimension (same length as ``r``).
    feature_strings : sequence of str
        Variable names, one per column of ``r``.

    Returns
    -------
    Dataset produced by ``xr.merge`` with one ``time``-indexed variable per feature.
    """
    feature_das = [
        xr.DataArray(r[:, i], coords={'time': timedim}, dims=['time'], name=feature)
        for i, feature in enumerate(feature_strings)
    ]
    return xr.merge(feature_das)

In [None]:
def analyze_model(new_input, time_dim, analyzer):
    """Run ``analyzer`` for output neurons 0 and 1 and package the results.

    Parameters
    ----------
    new_input : 2-D array
        Network input, one row per time step.
    time_dim : sequence
        Time coordinate matching the rows of ``new_input``.
    analyzer : object with ``analyze(x, neuron_selection=...)``

    Returns
    -------
    (r_qle_ds, r_qh_ds)
        Datasets from ``lrp_to_ds`` for neuron 0 and neuron 1 respectively.
    """
    r_qle = analyzer.analyze(new_input, neuron_selection=0)
    r_qh = analyzer.analyze(new_input, neuron_selection=1)

    # Use the module-level feature list rather than a second hard-coded copy,
    # so the column naming cannot drift from the rest of the notebook.
    r_qle_ds = lrp_to_ds(r_qle, time_dim, feature_strings)
    r_qh_ds = lrp_to_ds(r_qh, time_dim, feature_strings)
    return r_qle_ds, r_qh_ds
        

In [None]:
def compute_metric(simvar, obsvar, metric=pse.nash_sutcliffe_efficiency):
    """Evaluate ``metric`` on the overlapping time window of two series.

    Deep copies of both inputs are handed to ``pse.trim_time`` so the caller's
    objects are left untouched; the returned window is then used to select
    from those copies.
    """
    sim_copy = simvar.copy(deep=True)
    obs_copy = obsvar.copy(deep=True)
    window = pse.trim_time(sim_copy, obs_copy)
    sim_window = sim_copy.sel(time=window)
    obs_window = obs_copy.sel(time=window)
    return metric(sim_window, obs_window)

In [None]:
# Restrict every per-site dataset to the time window shared by the forcings
# and the prepared model output.
for s in sim_sites:
    # trim times to match
    in_ds = site_dict[s]
    outp_ds = site_outp[s]
    # trim_time rounds the 'time' entry of both datasets in place before
    # computing the window, so the selections below use the rounded labels.
    ts = trim_time(in_ds, outp_ds)
    in_ds = in_ds.sel(time=ts)
    # Unlike etl_single_site, this value is neither divided by 100 nor clipped at 0.
    in_ds['relhum'] = calc_relhum(in_ds['airtemp'], in_ds['airpres'], in_ds['spechum'])
    outp_ds = outp_ds.sel(time=ts)
    site_dict[s] = in_ds
    site_outp[s] = outp_ds
    # NOTE(review): `site_nn1w` is never created in the visible notebook; as
    # written this line raises NameError unless it was defined elsewhere.
    site_nn1w[s] = site_nn1w[s].sel(time=ts)
    site_sa[s] = site_sa[s].sel(time=ts)
    

# Merge forcings, attributes, output and parameters per site, then replace the
# heat fluxes with network predictions (see modify_variables).
site_ds = {s: modify_variables(xr.merge([site_dict[s], site_attr[s], site_outp[s], site_parm[s]])) for s in tqdm(sim_sites)}

# Slicing, filtering, etc

In [None]:
def _strip_quotes(label):
    """Trim surrounding whitespace and remove single quotes from a table label."""
    return label.strip().replace("'", "")


# Both lookup tables are indexed by their last column, a quoted name.
veg_igbp = pd.read_csv('./VEGPARM_IGBP_MODIS_NOAH.TBL', index_col=-1, skipinitialspace=True)
veg_igbp.index = veg_igbp.index.map(_strip_quotes)

soil_rosetta = pd.read_csv('./SOILPARM_ROSETTA.TBL', index_col=-1, skipinitialspace=True)
soil_rosetta.index = soil_rosetta.index.map(_strip_quotes)

# Collect latitude, longitude, elevation, soil code and vegetation code per site.
# NOTE(review): `sites` is not defined in the visible notebook (only `sim_sites`
# is) -- confirm where it comes from, or whether `sim_sites` was intended.
data = np.ones((len(sites), 5))
for i, site in enumerate(sites):
    # Open each file in a context manager so its handle is released right
    # after the five scalars have been read.
    with xr.open_dataset(f'../sites/{site}/params/local_attributes.nc') as local_attrs:
        data[i, 0] = local_attrs['latitude'].values[0]
        data[i, 1] = local_attrs['longitude'].values[0]
        data[i, 2] = local_attrs['elevation'].values[0]
        data[i, 3] = local_attrs['soilTypeIndex'].values[0]
        data[i, 4] = local_attrs['vegTypeIndex'].values[0]
site_attrs = gpd.GeoDataFrame(data, index=sites, columns=['Latitude', 'Longitude', 'Elevation', 'Soil Code', 'Veg Code'])
# Point geometry is (x=longitude, y=latitude).
site_attrs['geometry'] = [Point(xy) for xy in zip(site_attrs['Longitude'], site_attrs['Latitude'])]
veg_idxs = site_attrs['Veg Code'].unique().astype(int)
# Translate the numeric codes into the table row labels; a code with no
# matching row raises IndexError.
site_attrs['Veg Type'] = site_attrs['Veg Code'].apply(lambda x: veg_igbp.where(veg_igbp['VEGTYPINDEX'] == x).dropna().index[0])
site_attrs['Soil Type'] = site_attrs['Soil Code'].apply(lambda x: soil_rosetta.where(soil_rosetta['SOILTYPINDEX']== x).dropna().index[0])

veg_types = np.unique(site_attrs['Veg Type'].values)

# Generic 12-colour palette (hex codes).
colors = [
    '#e6194b',
    '#3cb44b',
    '#ffe119',
    '#4363d8',
    '#f58231',
    '#911eb4',
    '#46f0f0',
    '#f032e6',
    '#bcf60c',
    '#fabebe',
    '#008080',
    '#e6beff',
]

# 12 named matplotlib colours; the name suggests one per vegetation class.
vegcolors = [
    'goldenrod',
    'firebrick',
    'grey',
    'chocolate',
    'orange',
    'dodgerblue',
    'khaki',
    'red',
    'lightgreen',
    'yellowgreen',
    'mediumseagreen',
    'forestgreen',
]

In [None]:
# Vegetation and soil codes in the same order as `sim_sites`.
_site_rows = [site_attrs.loc[site] for site in sim_sites]
vegtypes = [row['Veg Code'] for row in _site_rows]
soiltypes = [row['Soil Code'] for row in _site_rows]

In [None]:
# Per-site water-balance indices, keyed by site name.
sim_etp = {}   # simulated ET / P
obs_etp = {}   # observed ET / P

petp = {}      # PET / P
pet = {}       # mean PET
dailyp = {}    # mean daily precipitation

for site in sim_sites:
    # NOTE(review): `agg_period` is never used below.
    agg_period = '1000D'
    # The simulated latent heat flux is negated before conversion; the observed one is not.
    sim_qle = -site_outp[site]['scalarLatHeatTotal']
    obs_qle = site_dict[site]['Qle_cor']
    obs_p = (site_dict[site]['pptrate']).sum().values
    
    obs_T = site_dict[site]['airtemp']
    # Convert the fluxes with Qle_to_ET and total them over the whole record.
    sim_et = Qle_to_ET(sim_qle, obs_T).sum().values
    obs_et = Qle_to_ET(obs_qle, obs_T).sum().values
    
    # The operands are already single totals, so np.nanmean only unwraps them.
    sim_etp[site] = np.nanmean(sim_et) / np.nanmean(obs_p)
    obs_etp[site] = np.nanmean(obs_et) / np.nanmean(obs_p)

    ds = site_dict[site] 
    # Daily precipitation totals: rate summed per day times 1800
    # (presumably seconds per half-hourly time step -- confirm the forcing step).
    daily_p = 1800.0 * ds['pptrate'].resample(time='D').sum()
    dailyp[site] = np.nanmean(daily_p)
    
    # Calculate temperatures in C
    seconds_per_half_hour = 1800.0
    to_mega = 1e6
    # Despite the name, this is the offset subtracted to go from K to degrees C.
    kelvin_to_celcius = 273.16
    tmean = ds['airtemp'].resample(time='30D').mean() - kelvin_to_celcius
    tmin  = ds['airtemp'].resample(time='30D').min()  - kelvin_to_celcius
    tmax  = ds['airtemp'].resample(time='30D').max()  - kelvin_to_celcius
    
    # Aggregate and convert to MJ / day * m^2
    # (daily sums of SWRadAtm, averaged over the same 30-day windows as the temperatures)
    netrad  = ds['SWRadAtm'].resample(time='D').sum().resample(time='30D').mean().values
    netrad = (netrad * seconds_per_half_hour) / (to_mega )#* 0.72)
   
    # Use Hargreave's eq to estimate PET, then aggregate to yearly
    daily_pet = calc_pet(tmin, tmax, tmean, netrad)
    
    # NOTE(review): no yearly aggregation actually happens; the "yearly" names
    # are aliases of the per-window PET and the daily precipitation series.
    yearly_pet = daily_pet
    yearly_p = daily_p
    pet[site] = np.nanmean(yearly_pet)
    petp[site] = np.nanmean(yearly_pet) / np.nanmean(yearly_p)

    
def _by_ascending_value(mapping):
    """Return a new dict holding the items of *mapping* ordered by increasing value."""
    return dict(sorted(mapping.items(), key=lambda item: item[1]))


# Re-order every index dictionary from the smallest to the largest value.
obs_etp = _by_ascending_value(obs_etp)
sim_etp = _by_ascending_value(sim_etp)

petp = _by_ascending_value(petp)
pet = _by_ascending_value(pet)

In [None]:
# ET/PET per site, obtained as (ET/P) / (PET/P), then ordered by increasing value.
# Built with a comprehension so no throwaway module-level names are left behind;
# the former loop temporaries `e` and `p` were plain globals, and `e` presumably
# shadowed the constant that `%pylab inline` injects.
et_pet = {s: obs_etp[s] / petp[s] for s in tqdm(sim_sites)}

et_pet = dict(sorted(et_pet.items(), key=lambda item: item[1]))

In [None]:

def _site_scores(sim_by_site, sim_var, obs_var, negate_sim, **metric_kwargs):
    """Score one simulated variable against one observed variable for every site.

    ``sim_by_site`` maps site name -> dataset holding ``sim_var``; the observed
    series always comes from ``site_ds``. When ``negate_sim`` is true the
    simulated series is negated before scoring. Extra keyword arguments are
    passed to ``compute_metric`` (omit ``metric`` to use its default).
    """
    scores = {}
    for site in tqdm(sim_sites):
        simulated = sim_by_site[site][sim_var]
        if negate_sim:
            simulated = -simulated
        scores[site] = compute_metric(simulated, site_ds[site][obs_var], **metric_kwargs)
    return scores


# --- Latent heat -----------------------------------------------------------
# The nn1w and standalone outputs are negated before scoring; the predictions
# stored in site_ds are used as they are.
nn1w_nse_qle = _site_scores(site_nn1w, 'scalarLatHeatTotal', 'Qle_cor', True)
lrp_nse_qle = _site_scores(site_ds, 'scalarLatHeatTotal', 'Qle_cor', False)
standalone_nse_qle = _site_scores(site_sa, 'scalarLatHeatTotal', 'Qle_cor', True)

standalone_kge_qle = _site_scores(site_sa, 'scalarLatHeatTotal', 'Qle_cor', True,
                                  metric=pse.kling_gupta_efficiency)
nn1w_kge_qle = _site_scores(site_nn1w, 'scalarLatHeatTotal', 'Qle_cor', True,
                            metric=pse.kling_gupta_efficiency)
lrp_kge_qle = _site_scores(site_ds, 'scalarLatHeatTotal', 'Qle_cor', False,
                           metric=pse.kling_gupta_efficiency)

# --- Sensible heat ---------------------------------------------------------
nn1w_nse_qh = _site_scores(site_nn1w, 'scalarSenHeatTotal', 'Qh_cor', True)
lrp_nse_qh = _site_scores(site_ds, 'scalarSenHeatTotal', 'Qh_cor', False)
standalone_nse_qh = _site_scores(site_sa, 'scalarSenHeatTotal', 'Qh_cor', True)

standalone_kge_qh = _site_scores(site_sa, 'scalarSenHeatTotal', 'Qh_cor', True,
                                 metric=pse.kling_gupta_efficiency)
nn1w_kge_qh = _site_scores(site_nn1w, 'scalarSenHeatTotal', 'Qh_cor', True,
                           metric=pse.kling_gupta_efficiency)
lrp_kge_qh = _site_scores(site_ds, 'scalarSenHeatTotal', 'Qh_cor', False,
                          metric=pse.kling_gupta_efficiency)

# Evenly spaced fractions in [0, 1), one per site.
y_vals = np.arange(0, len(sim_sites)) / len(sim_sites)

In [None]:
# Column name -> per-site dictionary. Insertion order is the column order of
# the written file.
_derived_columns = {
    'ET/PET': et_pet,
    'PET': pet,
    'PET/P': petp,
    'ET/P_obs': obs_etp,
    'ET/P_sim': sim_etp,

    'KGE_Qle_SA': standalone_kge_qle,
    'KGE_Qh_SA': standalone_kge_qh,
    'NSE_Qle_SA': standalone_nse_qle,
    'NSE_Qh_SA': standalone_nse_qh,

    'KGE_Qle_NN1W': nn1w_kge_qle,
    'KGE_Qh_NN1W': nn1w_kge_qh,
    'NSE_Qle_NN1W': nn1w_nse_qle,
    'NSE_Qh_NN1W': nn1w_nse_qh,

    'KGE_Qle_NNLRP': lrp_kge_qle,
    'KGE_Qh_NNLRP': lrp_kge_qh,
    'NSE_Qle_NNLRP': lrp_nse_qle,
    'NSE_Qh_NNLRP': lrp_nse_qh,
}
# Going through pd.Series aligns each dictionary with the frame's site index.
for _column, _values in _derived_columns.items():
    site_attrs[_column] = pd.Series(_values)

site_attrs['Site'] = site_attrs.index
# NOTE(review): the shapefile format presumably limits field names to 10
# characters, so names such as 'KGE_Qle_NN1W' and 'KGE_Qle_NNLRP' may be
# truncated to the same prefix on write -- check the columns of the saved file.
site_attrs.to_file("../data_for_paper_2/site_attrs.shp")

In [None]:
plt.subplots(figsize=(10, 7))

sa_color='#4363d8'
nn1w_color='#f58231'
lrp_color='#ffe119'

# (scores, x position, face colour): latent heat on the left (0-2),
# sensible heat on the right (3.75-5.75).
_box_specs = [
    (standalone_kge_qle, 0, sa_color),
    (nn1w_kge_qle, 1, nn1w_color),
    (lrp_kge_qle, 2, lrp_color),
    (standalone_kge_qh, 3.75, sa_color),
    (nn1w_kge_qh, 4.75, nn1w_color),
    (lrp_kge_qh, 5.75, lrp_color),
]
for _scores, _position, _facecolor in _box_specs:
    _artists = plt.boxplot(pd.Series(_scores), notch=True, patch_artist=True,
                           positions=[_position], widths=[0.7],
                           medianprops={'color': 'black'})
    for patch in _artists['boxes']:
        patch.set_facecolor(_facecolor)

plt.gca().set_ylim([-0.5, 1])
# NOTE(review): the middle box of each group shows the `nn1w_*` scores but is
# labelled 'NN2W' -- confirm which name is intended.
plt.xticks([0, 1, 2, 3.75, 4.75, 5.75], labels=2*['SA', 'NN2W', 'NNLRP'])
# Group captions at y = -0.8, i.e. below the lower y-limit of -0.5.
plt.text(0.3, -0.8, 'Latent Heat', fontsize=22)
plt.text(4., -0.8, 'Sensible Heat', fontsize=22)
plt.ylabel('KGE')

In [None]:
# Write the merged, prediction-augmented dataset of every site to NetCDF.
for s, ds in tqdm(site_ds.items()):
    out_path = f'../data_for_paper_2/{s}_data.nc'
    ds.to_netcdf(out_path)

In [None]:
# Per-site relevance datasets for the two model outputs
# (neuron 0 -> site_r_qle, neuron 1 -> site_r_qh).
site_r_qle = {}
site_r_qh = {}

# The analyzer depends only on the model, so build it once and reuse it for
# every site instead of rebuilding it on each iteration.
analyzer = innvestigate.analyzer.LRPEpsilon(model, epsilon=1e-3, neuron_selection_mode='index')

for site in tqdm(sim_sites):
    # Unmasked inputs, so the rows line up with the full time axis of the site.
    new_input, new_output = etl_single_site(site_ds[site], use_mask=False)
    t = site_ds[site]['time'].values

    r_qle_ds, r_qh_ds = analyze_model(new_input, t, analyzer)
    site_r_qle[site] = r_qle_ds
    site_r_qh[site] = r_qh_ds


In [None]:
# Write the per-site relevance dataset for the latent heat output to NetCDF.
for s, ds in tqdm(site_r_qle.items()):
    out_path = f'../data_for_paper_2/{s}_lrp_qle.nc'
    ds.to_netcdf(out_path)

In [None]:
# Write the per-site relevance dataset for the sensible heat output to NetCDF.
for s, ds in tqdm(site_r_qh.items()):
    out_path = f'../data_for_paper_2/{s}_lrp_qh.nc'
    ds.to_netcdf(out_path)