In [None]:
%pylab inline
import os
import sys
import glob
import json
import itertools
import importlib
import pandas as pd
import xarray as xr
from numpy import pi
from jupyterthemes import jtplot
from matplotlib import patches
import hyeenna as hy
# Plot styling: use the theme jupyterthemes infers, with enlarged fonts and a wide default figure.
jtplot.style(jtplot.infer_theme(), fscale=1.4)
jtplot.figsize(x=16, y=6)
# Reload hyeenna and rebind `hy` to the reloaded module object
# (presumably so local edits to the package are picked up without restarting the kernel).
hy = importlib.reload(hy)

#wb_vars = ['precipitation', 'swe', 'evaporation', 'runoff', 'soil_moist']
# Water-balance variable names, their plot symbols, and the name -> symbol lookup.
wb_vars = ['runoff', 'evap', 'precip', 'soil_moist', 'swe']
wb_symbols = ['R', 'ET', 'P', 'ΔSM', 'ΔSWE']
wb_map = dict(zip(wb_vars, wb_symbols))
# NOTE(review): the three constants below are not referenced in the cells visible here;
# presumably they are estimator settings (distance metric, tolerance, neighbour count) - confirm.
METRIC = 'chebyshev'
EPS = 1e-12
K = 10

In [None]:
# Draw a trivial one-point plot, then enlarge the default figure size to 16x8.
# NOTE(review): presumably the throwaway plot only exists to initialise the inline
# backend before the figure size is changed - confirm it is still needed.
plt.plot([0])
jtplot.figsize(x=16, y=8)

In [None]:
def raw_data(data_array):
    """Return the contents of ``data_array`` as a flat, one-dimensional copy.

    ``data_array`` only needs a ``.values`` attribute holding a NumPy array
    (as xarray DataArrays do).
    """
    values = data_array.values
    return values.flatten()

def good_inds(x):
    """Return the indices of the usable entries of ``x``.

    An entry is usable when it lies strictly between -3000 and 3000 and is
    not NaN.

    Parameters
    ----------
    x : array-like
        Values to screen; converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        Indices along the first axis of ``x`` where all three conditions hold.
    """
    x = np.asarray(x)
    # np.logical_and is a binary ufunc: its third positional argument is the
    # ``out`` buffer, not a third operand, so the three masks are combined
    # explicitly with ``&`` instead.
    usable = (x > -3000) & (x < 3000) & ~np.isnan(x)
    return np.where(usable)[0]

def get_data(ds, dvars=True):
    """Pull the five water-balance series out of ``ds`` and standardise them.

    Parameters
    ----------
    ds : mapping of variable name -> data array
        Must provide 'evaporation', 'runoff', 'precipitation',
        'soil_moisture' and 'swe'; the last two must support
        ``.diff(dim='time')`` when ``dvars`` is true.
    dvars : bool
        If true, soil moisture and SWE are replaced by their differences
        along 'time' (labelled 'ΔSM' / 'ΔSWE'); otherwise the raw storages
        are used (labelled 'SM' / 'SWE').

    Returns
    -------
    names : list of str
        Symbols in the order runoff, evaporation, precipitation, soil
        moisture, SWE.
    varlist : list of numpy.ndarray
        The matching flattened series, each with its mean subtracted, divided
        by its standard deviation and rounded to 5 decimals.
    """
    # The first flattened sample is dropped, presumably so these series have the
    # same length as the differenced ones (assumes 'time' is the only dimension
    # - TODO confirm).
    fluxes = {
        key: raw_data(ds[key])[1:]
        for key in ('evaporation', 'runoff', 'precipitation')
    }
    if dvars:
        names = ['R', 'ET', 'P', 'ΔSM', 'ΔSWE']
        soil = raw_data(ds['soil_moisture'].diff(dim='time'))
        snow = raw_data(ds['swe'].diff(dim='time'))
    else:
        names = ['R', 'ET', 'P', 'SM', 'SWE']
        soil = raw_data(ds['soil_moisture'])[1:]
        snow = raw_data(ds['swe'])[1:]

    ordered = (fluxes['runoff'], fluxes['evaporation'], fluxes['precipitation'], soil, snow)
    standardised = []
    for series in ordered:
        zscore = (series - series.mean()) / series.std()
        standardised.append(np.around(zscore, 5))
    return names, standardised

In [None]:
# Open the per-region "info" NetCDF files for each of the three models.
# Every group is unpacked in the order will, snake, rockies, olys, cascade.
summa_will, summa_snake, summa_rockies, summa_olys, summa_cascade = [
    xr.open_dataset(path) for path in (
        './data/summa_will_info.nc',
        './data/summa_snake_info.nc',
        './data/summa_rockies_info.nc',
        './data/summa_olys_info.nc',
        './data/summa_cascade_info.nc',
    )
]

vic_will, vic_snake, vic_rockies, vic_olys, vic_cascade = [
    xr.open_dataset(path) for path in (
        './data/vic_will_info.nc',
        './data/vic_snake_info.nc',
        './data/vic_rockies_info.nc',
        './data/vic_olys_info.nc',
        './data/vic_cascade_info.nc',
    )
]

prms_will, prms_snake, prms_rockies, prms_olys, prms_cascade = [
    xr.open_dataset(path) for path in (
        './data/prms_will_info.nc',
        './data/prms_snake_info.nc',
        './data/prms_rockies_info.nc',
        './data/prms_olys_info.nc',
        './data/prms_cascade_info.nc',
    )
]

In [None]:
# Lookup of every opened dataset under its '<model>_<region>' key.
analysis_dict = dict(
    summa_will=summa_will,
    summa_snake=summa_snake,
    summa_rockies=summa_rockies,
    summa_olys=summa_olys,
    summa_cascade=summa_cascade,
    vic_will=vic_will,
    vic_snake=vic_snake,
    vic_rockies=vic_rockies,
    vic_olys=vic_olys,
    vic_cascade=vic_cascade,
    prms_will=prms_will,
    prms_snake=prms_snake,
    prms_rockies=prms_rockies,
    prms_olys=prms_olys,
    prms_cascade=prms_cascade,
)

In [None]:
# (dataset key prefix, legend label, line colour, fill colour) for each model,
# in drawing order. PRMS intentionally keeps its lighter fill colour.
_LAG_MODEL_STYLES = (
    ('summa', 'SUMMA', 'darkgreen', 'darkgreen'),
    ('prms', 'PRMS', 'darkslateblue', 'slateblue'),
    ('vic', 'VIC', 'darkorange', 'darkorange'),
)


def _normalized_bands(shuffle_results, mi):
    """Scale per-lag shuffle-test summaries by ``mi`` and clip them at zero.

    Parameters
    ----------
    shuffle_results : iterable of dict
        One entry per lag; each is read through its 'results', 'median' and
        'shuffled_thresh' keys.
    mi : float
        Normalising value. When it is not strictly positive, every output
        entry is 0.

    Returns
    -------
    lower, median, upper, thresh : list
        Per-lag 5th percentile, median, 95th percentile and shuffled
        threshold, each divided by ``mi`` and floored at 0.
    """
    lower, median, upper, thresh = [], [], [], []
    for res in shuffle_results:
        if mi > 0:
            p05, p95 = np.percentile(res['results'], [5, 95])
            raw = (p05, res['median'], p95, res['shuffled_thresh'])
            scaled = [np.max([value / mi, 0.0]) for value in raw]
        else:
            scaled = [0, 0, 0, 0]
        for series, value in zip((lower, median, upper, thresh), scaled):
            series.append(value)
    return lower, median, upper, thresh


def plot_lags(var_a, var_b, taus=np.arange(1, 10, 1), loc='olys', ss=3000, nrun=5, dvars=True, ax=False):
    """Plot lagged transfer entropy from ``var_a`` to ``var_b`` for SUMMA, PRMS and VIC.

    For every model and every lag in ``taus`` this runs ``hy.shuffle_test`` on
    ``hy.transfer_entropy``, divides the summaries by the value of
    ``hy.mutual_info`` for the two series, and draws the median (solid line),
    the shuffled threshold (dotted line) and the 5-95 percentile band.

    Parameters
    ----------
    var_a, var_b : str
        Symbols of the 'X' and 'Y' series; must be two different entries of
        the name list returned by ``get_data`` for the chosen ``dvars``.
    taus : sequence of int
        Lags passed as the 'tau' parameter and used as x coordinates.
    loc : str
        Region suffix used to look up '<model>_<loc>' in ``analysis_dict``.
    ss : int
        ``sample_size`` forwarded to ``hy.shuffle_test``.
    nrun : int
        ``nruns`` forwarded to ``hy.shuffle_test``.
    dvars : bool
        Forwarded to ``get_data`` (differenced vs. raw storage variables).
    ax : matplotlib axes or falsy
        Axes to draw on; a new figure and axes are created when falsy.

    Returns
    -------
    (figure, axes)
        The axes that was drawn on and the figure that owns it.

    Raises
    ------
    ValueError
        If ``var_a`` or ``var_b`` is not a known symbol, or if they are equal.
    """
    if not ax:
        _, ax = plt.subplots()

    if dvars:
        known = ['R', 'ET', 'P', 'ΔSM', 'ΔSWE']
    else:
        known = ['R', 'ET', 'P', 'SM', 'SWE']
    # Same failure mode as before (ValueError), but raised explicitly instead of
    # as a side effect of removing the two names from a scratch list.
    if var_a not in known or var_b not in known or var_a == var_b:
        raise ValueError(
            'var_a and var_b must be two different entries of {}; got {!r} and {!r}'.format(
                known, var_a, var_b))

    band_alpha = 0.3
    for prefix, label, line_color, fill_color in _LAG_MODEL_STYLES:
        ds = analysis_dict['{}_{}'.format(prefix, loc)]
        names, varlist = get_data(ds, dvars)
        var_dict = dict(zip(names, varlist))
        x, y = var_dict[var_a], var_dict[var_b]
        # Only the pairwise case is evaluated: conditioning on the remaining
        # water-balance variables (a 'Z' entry) was already disabled.
        st = [hy.shuffle_test(hy.transfer_entropy,
                              data={'X': x, 'Y': y},
                              params={'tau': t, 'omega': 1, 'nu': 1, 'k': 1, 'l': 1, 'm': 1},
                              nruns=nrun, sample_size=ss) for t in taus]
        mi = hy.mutual_info(x, y)
        lower, median, upper, thresh = _normalized_bands(st, mi)

        ax.plot(taus, median, lw=2.5, label=label, color=line_color)
        ax.plot(taus, thresh, lw=2.5, linestyle=':', color=line_color)
        ax.fill_between(taus, lower, upper,
                        facecolor=fill_color, alpha=band_alpha, label='')

    # Return the axes actually drawn on; plt.gca() would hand back whichever
    # axes is current, which is a different panel inside a subplot grid.
    return ax.figure, ax

In [None]:
# Settings for the lag-plot grid: region, lags, shuffle-test sample size and run count.
loc = 'snake'
taus = np.arange(1, 15, 2)
ss = 3000
nrun = 20
dvars = True
wb_vars = ['R', 'ET', 'P', 'ΔSM', 'ΔSWE'] if dvars else ['R', 'ET', 'P', 'SM', 'SWE']

# One panel per ordered (source, target) pair of water-balance symbols.
jtplot.figsize(x=16, y=12)
fig, axes = plt.subplots(nrows=5, ncols=5, sharex=True, sharey=False)
axes = np.array(axes).flatten()

for ax, v in zip(axes, itertools.product(wb_vars, wb_vars)):
    print(v)
    source, target = v
    # Panels for self-pairs and for pairs whose target is 'P' are left empty.
    if source != target and target != 'P':
        plot_lags(source, target, taus=taus, nrun=nrun, loc=loc, ax=ax, ss=ss, dvars=dvars)
        ax.set_title(r'${} \rightarrow {}$'.format(source, target))
        ax.set_xticks(taus)
plt.tight_layout()
# Put the figure size back to the 16x6 used for the rest of the notebook.
jtplot.figsize(x=16, y=6)