In [None]:
%pylab inline
import pandas as pd
import pysumma as ps
import xarray as xr
from matplotlib import cm
import seaborn as sns
from pathlib import Path
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
%run ../lib/summa_snow_layering.py

sns.set_context('poster')
mpl.style.use('seaborn-bright')
mpl.rcParams['figure.figsize'] = (18, 12)

In [None]:
# Load each site's perturbation output fully into memory.
dana, coldeport, reynolds = (
    xr.open_dataset(path).load()
    for path in ('../processed/dana_perturbations.nc',
                 '../processed/coldeport_perturbations.nc',
                 '../processed/reynolds_perturbations.nc')
)

# Shift snow/surface temperatures from Kelvin to degrees relative to 273.16 K.
# NOTE(review): 0 degC is 273.15 K (273.16 K is water's triple point); presumably
# 273.16 was chosen to match the model's freezing temperature - confirm.
for _temp_var in ('scalarSnowTemp', 'scalarSurfaceTemp'):
    for _site_ds in (dana, reynolds, coldeport):
        _site_ds[_temp_var] -= 273.16
del _temp_var, _site_ds

In [None]:
# Start years of each site's water years; the metrics loop treats entry y as
# y/10/01 - (y+1)/09/30. Each range runs from the record's first calendar year
# up to, but not including, its last one.
# NOTE(review): Dana drops one extra final year (`drop=1`); presumably deliberate
# (e.g. an incomplete last season) - confirm.
dana_years, cdp_years, reynolds_years = (
    np.arange(ds.time.dt.year.values[0], ds.time.dt.year.values[-1] - drop)
    for ds, drop in ((dana, 1), (coldeport, 0), (reynolds, 0))
)

In [None]:
# One shade per layer thickness (thin, mid, thick) for the 2-, 3- and 4-layer
# configurations, in the same order as `layers`, followed by a single colour
# for the CLM configuration.
colors_2l = ['wheat', 'orange', 'peru']
colors_3l = ['skyblue', 'dodgerblue', 'royalblue']
colors_4l = ['violet', 'deeppink', 'crimson']
colors_all = colors_2l + colors_3l + colors_4l + ['lime']

In [None]:
def year(x, y=None):
    """Return a date-string slice spanning one water year, ``x``/10/01 to ``y``/09/30.

    Parameters
    ----------
    x : int-like
        Calendar year in which the water year starts (1 October).
    y : int-like, optional
        Calendar year in which it ends (30 September). Defaults to ``x + 1``,
        which is how every call in this notebook uses it.

    Returns
    -------
    slice
        Label slice for use with ``.sel(time=...)``.
    """
    if y is None:
        y = x + 1
    return slice('{}/10/01'.format(x), '{}/09/30'.format(y))

In [None]:
# Per-site look-ups keyed by the short site codes: water-year start years and
# the loaded dataset.
year_dict = dict(dana=dana_years, cdp=cdp_years, reynolds=reynolds_years)
ds_dict = dict(dana=dana, cdp=coldeport, reynolds=reynolds)

In [None]:
# Site codes, labels of the `dt` coordinate (presumably temperature
# perturbations), and the layering configurations compared against 'JRDN'.
sites = 'dana cdp reynolds'.split()
temps = ['-2.0K', '+0.0K', '+2.0K', '+4.0K']
layers = [f'{n}L_{t}' for n in (2, 3, 4) for t in ('thin', 'mid', 'thick')] + ['CLM']

In [None]:
# "<site>_<year>" and "<site>_<model>_<year>" labels for every site, layering
# configuration and start year except each site's last one.
# NOTE(review): the metrics loop iterates year_dict[site][1:-1] (it also skips
# the first year), so these lists hold one more year per site than
# site_year_strings - confirm which set is intended before pairing them.
_label_years = {s: year_dict[s][:-1] for s in sites}
site_years = [f'{s}_{y}' for s, ys in _label_years.items() for y in ys]
site_model_years = [f'{s}_{m}_{y}'
                    for s, ys in _label_years.items()
                    for m in layers
                    for y in ys]

In [None]:
# Accumulators filled by the metrics loop. Per-configuration metrics are nested
# as {layer: {dt: [one value per site-year]}}; configuration-independent
# quantities are {dt: [one value per site-year]}. Every dict and list is a
# fresh object (nothing is shared between accumulators).
(nses, kges, sdds, psds, kges_melt, surftemp_kge, surftemp_premelt_kge,
 coldcontent_kge, average_snopack_temp) = (
    {l: {dt: [] for dt in temps} for l in layers} for _ in range(9)
)
(snowtemp_min, total_snowfalls, total_precips,
 accum_temp, melt_temp, season_temp) = (
    {dt: [] for dt in temps} for _ in range(6)
)
site_year_strings = []

# Skill of every layering configuration relative to the 'JRDN' run, per site,
# `dt` and water year. `nse`, `kge`, `sdd_diff` and `ps_diff` are not defined in
# this notebook; presumably they come from ../lib/summa_snow_layering.py (%run).
var = 'scalarSWE'
# Window padding in time steps (three days if output is hourly -
# NOTE(review): confirm the time step).
shift = 24 * 3
for site in sites:
    site_ds = ds_dict[site]
    for dt in temps:
        for layer in layers:
            # The first and last entries of each site's year range are skipped.
            for y in year_dict[site][1:-1]:
                window = year(y, y + 1)
                sim = site_ds[var].sel(time=window, dt=dt, model=layer)
                jrdn = site_ds[var].sel(time=window, dt=dt, model='JRDN')
                sim_sst = site_ds['scalarSurfaceTemp'].sel(time=window, dt=dt, model=layer)
                jrdn_sst = site_ds['scalarSurfaceTemp'].sel(time=window, dt=dt, model='JRDN')
                sim_cc = site_ds['scalarColdContent'].sel(time=window, dt=dt, model=layer)
                jrdn_cc = site_ds['scalarColdContent'].sel(time=window, dt=dt, model='JRDN')

                # Whole-year SWE metrics.
                nses[layer][dt].append(nse(sim.values, jrdn.values))
                kges[layer][dt].append(kge(sim.values, jrdn.values))
                sdds[layer][dt].append(sdd_diff(sim, jrdn))
                psds[layer][dt].append(ps_diff(sim, jrdn))

                # Season landmarks as time-step indices into this water year:
                # start = peak SWE; firstsnow = first non-zero SWE up to and
                # including the peak; meltout = first zero SWE at/after the peak.
                start = int(sim.argmax().values[()])
                # `start + 1` keeps the search window non-empty when the peak
                # falls on the first time step.
                firstsnow = int(np.where(sim.isel(time=slice(0, start + 1)) != 0)[0][0])
                zeros_after_peak = np.where(sim.isel(time=slice(start, None)) == 0)[0]
                if len(zeros_after_peak):
                    meltout = start + int(zeros_after_peak[0])
                    stop = meltout - shift
                else:
                    # No melt-out within the year: stop `shift` steps before the
                    # end, and let the SWE melt window run to the end of the year
                    # (an end index of -shift + shift == 0 would leave it empty).
                    meltout = None
                    stop = -shift
                # Clamp at 0: a negative start would wrap to the end of the year.
                melt_season = slice(max(start - shift, 0), stop)
                swe_melt_season = slice(start, meltout)
                # NOTE(review): "premelt" extends 2 * shift steps past the peak.
                premelt_season = slice(firstsnow, start + 2 * shift)

                # Cold content over the pre-melt window, ignoring steps where
                # either run is NaN.
                sim_cc = sim_cc.isel(time=premelt_season).values
                jrdn_cc = jrdn_cc.isel(time=premelt_season).values
                mask_cc = np.logical_and(~np.isnan(sim_cc), ~np.isnan(jrdn_cc))
                kges_melt[layer][dt].append(kge(sim.isel(time=swe_melt_season).values,
                                                jrdn.isel(time=swe_melt_season).values))
                surftemp_kge[layer][dt].append(kge(sim_sst.isel(time=melt_season).values,
                                                   jrdn_sst.isel(time=melt_season).values))
                coldcontent_kge[layer][dt].append(kge(sim_cc[mask_cc], jrdn_cc[mask_cc]))
                # Pass NumPy arrays to kge, like every other metric here, instead
                # of taking `.values` of kge's return value.
                surftemp_premelt_kge[layer][dt].append(
                    kge(sim_sst.isel(time=premelt_season).values[mask_cc],
                        jrdn_sst.isel(time=premelt_season).values[mask_cc]))
                average_snopack_temp[layer][dt].append(
                    site_ds['scalarSnowTemp'].sel(time=window, dt=dt, model=layer).mean().values[()])

                # Quantities that do not depend on the layering configuration are
                # recorded once per site/dt/year, on the first configuration.
                if layer == layers[0]:
                    if dt == '+0.0K':
                        site_year_strings.append(f'{site}_{y}')
                    airtemp = site_ds['airtemp'].sel(time=window, dt=dt)
                    # NOTE(review): 3600 * sum presumably converts per-second rates
                    # at hourly steps into totals - confirm units.
                    total_snowfalls[dt].append(3600 * site_ds['scalarSnowfall'].sel(
                        time=window, dt=dt, model='JRDN').sum().values[()])
                    accum_temp[dt].append(airtemp.isel(time=premelt_season).median().values[()])
                    melt_temp[dt].append(airtemp.isel(time=melt_season).median().values[()])
                    season_temp[dt].append(airtemp.isel(time=slice(firstsnow, stop)).median().values[()])
                    total_precips[dt].append(3600 * site_ds['pptrate'].sel(
                        time=window, dt=dt).sum().values[()])
                    # NOTE(review): despite its name this stores the *median* JRDN
                    # snow temperature, not the minimum - confirm intent.
                    snowtemp_min[dt].append(site_ds['scalarSnowTemp'].sel(
                        time=window, dt=dt, model='JRDN').median().values[()])
                    
                 
# Legend handles for the layering configurations: a circular marker in each
# configuration's colour (colors_all follows the same order), with the CLM
# configuration labelled 'CLM-like'.
l_layers = ['2L_thin', '2L_mid', '2L_thick',
            '3L_thin', '3L_mid', '3L_thick',
            '4L_thin', '4L_mid', '4L_thick',
            'CLM-like']
_marker_kw = dict(marker='o', color='w', markersize=15)
legend_elements = [Line2D([0], [0], markerfacecolor=face, label=name, **_marker_kw)
                   for name, face in zip(l_layers, colors_all)]

In [None]:
mpl.rcParams['figure.figsize'] = (20, 14)

# Fill colour per `dt` label.
colors = ['powderblue', 'azure', '#ffe119', '#f58231']
color_dict = {'-2.0K': colors[0],
              '+0.0K': colors[1],
              '+2.0K': colors[2],
              '+4.0K': colors[3],}


def _boxplot_by_model(ax, metric, fill_colors):
    """Draw grouped, notched box plots of a ``{model: {dt: [values]}}`` metric on ``ax``.

    One box per (model, dt), filled with ``fill_colors[dt]``. Boxes of a model sit
    at consecutive integer positions and models are separated by a one-position
    gap. One x tick per model, labelled with the model name, is placed at the
    centre of its group. NaN values are dropped before plotting so a single NaN
    does not blank out a box.
    """
    x_pos = 0
    centers = []
    for model_dict in metric.values():
        group_start = x_pos
        for dt, vals in model_dict.items():
            vals = [v for v in vals if not np.isnan(v)]
            pa = ax.boxplot(vals, positions=[x_pos], widths=[0.8],
                            notch=True, showfliers=False, patch_artist=True,
                            medianprops={'color': 'black'})
            for patch in pa['boxes']:
                patch.set_facecolor(fill_colors[dt])
            x_pos += 1
        centers.append((group_start + x_pos - 1) / 2)
        x_pos += 1  # gap between model groups
    ax.set_xlim(-1., x_pos - 1)
    ax.set_xticks(centers)
    ax.set_xticklabels(list(metric.keys()), rotation=45)
    return ax


fig, axes = plt.subplots(2, 2, sharex=True)
axes = axes.flatten()

# (metric, y label, reference line, y limits or None)
panels = [
    (nses, r'NSE', 1, (0.8, 1.05)),
    (kges, r'KGE', 1, (0.6, 1.05)),
    (psds, r'$\Delta$ Peak SWE (mm)', 0, None),
    (sdds, r'$\Delta$SDD (days)', 0, None),
]
for ax, (metric, ylabel, ref_value, ylim) in zip(axes, panels):
    _boxplot_by_model(ax, metric, color_dict)
    ax.axhline(ref_value, color='black')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)

# One legend entry per `dt`, drawn from the same colours as the boxes.
axes[3].legend(handles=[Patch(facecolor=c, edgecolor='black', label=dt)
                        for dt, c in color_dict.items()],
               loc='lower right')
plt.tight_layout()