In [None]:
%pylab inline
import os
import sys
import glob
import json
import itertools
import importlib
import pandas as pd
import xarray as xr
from numpy import pi
import geopandas as gp
import summa_plot as sp
import matplotlib as mpl
from pprint import pprint 
from functools import reduce
from jupyterthemes import jtplot
from summa_plot.spatial import add_map_features
import hyeenna as hy
# Plot theming. The first call applies the inferred notebook theme; the second
# call applies 'grade3' with the same font scale and presumably supersedes it.
# TODO(review): confirm the first call is still needed.
jtplot.style(jtplot.infer_theme(), fscale=2.2)
jtplot.style('grade3', fscale=2.2)
# Default size for figures created after this point (presumably via rcParams).
jtplot.figsize(x=13, y=13)
# Re-execute the hyeenna module so library edits are picked up without a kernel
# restart (development convenience; a no-op for a fresh kernel).
hy = importlib.reload(hy)

In [None]:
# Draw and clear a throwaway figure, switching the default figure size in between.
# NOTE(review): the dummy plot is presumably a workaround so the jtplot/matplotlib
# settings take effect for later figures — confirm it is still required.
plt.plot([0])
jtplot.figsize(x=20, y=16)
plt.clf()

In [None]:
def fix(ds):
    """Flip the sign of every variable in ``ds`` except ``precipitation``.

    The whole object is negated with in-place ``*=``. Precipitation is then
    negated a second time, which restores its original sign. The same (mutated)
    object is returned.

    NOTE(review): presumably this brings the other model outputs onto the
    notebook's sign convention — confirm against the SUMMA output definitions.
    """
    flip = -1
    ds *= flip                       # every variable, precipitation included
    ds['precipitation'] *= flip      # undo the flip for precipitation only
    return ds

def subtract_yearly_min(da):
    """Re-baseline ``da`` so that, within every calendar year, its minimum over time is zero.

    Values are grouped by ``da.time.dt.year``. Each group has its own minimum along
    ``time`` subtracted, computed with NaNs skipped; other dimensions are kept, so
    each position is re-baselined separately. The ``year`` coordinate is dropped
    from the result.
    """
    def _minus_own_min(one_year):
        return one_year - one_year.min(dim='time', skipna=True)

    by_year = da.groupby(da.time.dt.year)
    rebased = by_year.apply(_minus_own_min)
    return rebased.drop('year')


In [None]:
# Load the SUMMA output for each basin and pass it through `fix`
# (every variable negated except precipitation).
summa_will, summa_snake, summa_rockies, summa_olys = [
    fix(xr.open_dataset(path))
    for path in ('./data/summa_will.nc',
                 './data/summa_snake.nc',
                 './data/summa_rockies.nc',
                 './data/summa_olys.nc')
]

In [None]:
# Re-baseline the SUMMA storage terms: within each calendar year, SWE and soil
# moisture are shifted so their minimum is zero. Order of updates: basin by basin,
# SWE before soil moisture.
for _summa_ds in (summa_will, summa_snake, summa_rockies, summa_olys):
    for _var in ('swe', 'soil_moisture'):
        _summa_ds[_var] = subtract_yearly_min(_summa_ds[_var])
del _summa_ds, _var

In [None]:
def _open_basin(path, mask_name):
    """Open the NetCDF file at ``path`` and keep only cells where ``ds[mask_name] == 1``.

    With ``drop=True``, xarray discards coordinate labels along which the mask is
    entirely false. Remaining entries outside the mask become NaN.
    """
    basin = xr.open_dataset(path)
    return basin.where(basin[mask_name] == 1, drop=True)


# VIC output, clipped to each basin's mask.
vic_will = _open_basin('./data/vic_will.nc', 'willamette')
vic_snake = _open_basin('./data/vic_snake.nc', 'snake')
vic_rockies = _open_basin('./data/vic_rockies.nc', 'rockies')
vic_olys = _open_basin('./data/vic_olys.nc', 'olys')

# PRMS output, clipped the same way.
prms_will = _open_basin('./data/prms_will.nc', 'willamette')
prms_snake = _open_basin('./data/prms_snake.nc', 'snake')
prms_rockies = _open_basin('./data/prms_rockies.nc', 'rockies')
prms_olys = _open_basin('./data/prms_olys.nc', 'olys')

In [None]:
# Every model/basin combination, keyed "<model>_<basin>".
analysis_dict = {
    'summa_will': summa_will,
    'summa_snake': summa_snake,
    'summa_rockies': summa_rockies,
    'summa_olys': summa_olys,
    'vic_will': vic_will,
    'vic_snake': vic_snake,
    'vic_rockies': vic_rockies,
    'vic_olys': vic_olys,
    'prms_will': prms_will,
    'prms_snake': prms_snake,
    'prms_rockies': prms_rockies,
    'prms_olys': prms_olys
}

# Aggregate every dataset to monthly steps. All variables are summed over the month,
# except SWE and soil moisture, which are averaged (presumably because they are
# storage states rather than fluxes).
# NOTE: xarray's default skipna sum turns all-NaN (masked-out) cells into 0, not NaN.
for k, v in analysis_dict.items():
    print(k)
    sub = v.resample(time='M').sum(dim='time')
    for var in ('swe', 'soil_moisture'):
        monthly_mean = v[var].resample(time='M').mean(dim='time')
        # Assign the DataArray itself (coordinate-aligned, in the summed variable's
        # dim order) rather than writing positionally into `.values`.
        sub[var] = monthly_mean.transpose(*sub[var].dims)
    # NOTE(review): on monthly (month-end) timestamps `weekofyear` takes only ~12
    # distinct values; weekly_info_transfer groups on it.
    sub['weekofyear'] = sub.time.dt.weekofyear
    sub['month'] = sub.time.dt.month
    analysis_dict[k] = sub


In [None]:
def raw_data(data_array):
    """Return all values of ``data_array`` as a new, independent 1-D numpy array.

    Elements are laid out in row-major (C) order: the last dimension varies fastest.
    """
    return data_array.values.reshape(-1).copy()

def get_data(ds, dvars=True):
    """Extract the water-balance series from ``ds`` as flat numpy arrays.

    Parameters
    ----------
    ds : dataset-like
        Indexable by 'evaporation', 'runoff', 'precipitation', 'soil_moisture'
        and 'swe'.
    dvars : bool
        If True, SM/SWE are returned as step-to-step differences (ΔSM, ΔSWE).
        Otherwise their raw values are returned.

    Returns
    -------
    (names, series)
        ``names`` is ``['R', 'ET', 'P', 'ΔSM', 'ΔSWE']`` (or ``'SM', 'SWE'`` when
        ``dvars`` is False). ``series`` holds the matching 1-D arrays.

        Every series has ``n - 2`` elements. Fluxes drop their first two entries;
        differences are ``np.diff(x)[1:]``, so entry ``j`` is ``x[j+2] - x[j+1]``,
        aligned with flux entry ``j+2``.

    NOTE(review): arrays are flattened before differencing, so for multi-dimensional
    inputs ``np.diff`` runs across the flattened order — confirm this is intended.
    """
    evap = raw_data(ds['evaporation'])[2:]
    runoff = raw_data(ds['runoff'])[2:]
    precip = raw_data(ds['precipitation'])[2:]
    storage_vars = ('soil_moisture', 'swe')
    if not dvars:
        names = ['R', 'ET', 'P', 'SM', 'SWE']
        storages = [raw_data(ds[name])[2:] for name in storage_vars]
    else:
        names = ['R', 'ET', 'P', 'ΔSM', 'ΔSWE']
        storages = [np.diff(raw_data(ds[name]))[1:] for name in storage_vars]
    return names, [runoff, evap, precip] + storages

In [None]:
def data_dict(ds, dvars=True):
    """Return ``get_data(ds, dvars)`` as a ``{name: series}`` dict, preserving get_data's order."""
    names, series = get_data(ds, dvars)
    return dict(zip(names, series))

def filter_data(filter_func, *args):
    """Subset parallel arrays in two stages.

    First keep indices that pass ``filter_func`` in every array. Then keep those
    where the first array is strictly positive.

    Parameters
    ----------
    filter_func : callable
        Maps one array to the indices to keep (e.g. an ``np.argwhere`` result,
        as ``notnan`` returns).
    *args
        Either a single sequence of arrays (``filter_data(f, [a, b, c])``, the
        original calling convention) or the arrays as separate arguments
        (``filter_data(f, a, b, c)``).

    Returns
    -------
    list of numpy.ndarray
        One array per input, in input order. For 1-D inputs each has shape
        ``(n_kept, 1)``: indexing with an ``np.argwhere`` result adds the
        trailing axis.

    Raises
    ------
    TypeError
        If no arrays are given.
    """
    if not args:
        raise TypeError('filter_data() needs at least one array')
    arrays = args[0] if len(args) == 1 else args
    # Ravel each index set so a lone (n, 1) argwhere result is treated exactly
    # like the 1-D output np.intersect1d produces when several arrays are given.
    keep_sets = [np.ravel(filter_func(x)) for x in arrays]
    good_inds = reduce(np.intersect1d, keep_sets)
    arrays = [x[good_inds] for x in arrays]
    # Second pass: keep only entries where the first array is positive.
    good_inds = np.argwhere(arrays[0] > 0)
    return [x[good_inds] for x in arrays]

def notnan(x):
    """Return ``np.argwhere`` indices of entries of ``x`` that are not NaN and exceed -1000.

    The -1000 floor presumably screens out large negative fill values
    (e.g. -9999) — TODO confirm against the source files.
    """
    valid = ~np.isnan(x) & (x > -1000)
    return np.argwhere(valid)

In [None]:
def ts_info_transfer(ds, key='weekofyear', shift=3, nruns=20, sample_size=3000):
    """Estimate per-period information transfer from previous-period drivers to runoff.

    For each unique value ``p`` of ``ds[key]``:

    - Runoff in period ``p`` is paired with ET, P, ΔSM and ΔSWE from period
      ``p - 1``. When ``p - 1 == 0`` it wraps to the largest period present.
    - Indices with NaN/fill values or non-positive current runoff are dropped via
      ``filter_data(notnan, ...)``.
    - ``hy.conditional_mutual_info`` is estimated with X = current runoff,
      Y = previous-period driver and Z = previous-period runoff.
      Presumably this is I(X; Y | Z) — confirm in hyeenna's docs.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``key`` plus the variables read by ``get_data``.
    key : str
        Period variable to group on, e.g. ``'month'`` or ``'weekofyear'``.
    shift : int
        Positions passed to ``np.roll`` on each result column. With the default of 3
        and months 1..12, month 10 (October) moves to the first row.
    nruns, sample_size : int
        Forwarded to ``hy.estimator_stats``. Defaults equal the previously
        hard-coded values.

    Returns
    -------
    pandas.DataFrame
        Indexed by the sorted unique ``ds[key]`` values, with columns
        ``'P', 'ET', 'ΔSM', 'ΔSWE'``. Each holds the mean estimate per period,
        with negative estimates clipped to 0, then rolled by ``shift``.
    """
    def _call(y, xm, ym, how='mean'):
        return hy.estimator_stats(
            hy.conditional_mutual_info,
            data={'X': y, 'Y': xm, 'Z': ym},
            params={},
            nruns=nruns, sample_size=sample_size,
        )[how]

    periods = np.unique(ds[key])
    # Estimator call order per period: ET, P, ΔSM, ΔSWE.
    drivers = ['ET', 'P', 'ΔSM', 'ΔSWE']
    estimates = {name: [] for name in drivers}

    for period in periods:
        print(period)
        prev_period = period - 1
        if prev_period == 0:
            prev_period = periods[-1]
        ds_now = ds.where(ds[key] == period, drop=True)
        ds_prev = ds.where(ds[key] == prev_period, drop=True)
        r_now = data_dict(ds_now)['R']
        prev_raw = data_dict(ds_prev)
        r_now, *prev_series = filter_data(notnan, [r_now, *prev_raw.values()])
        prev = dict(zip(prev_raw.keys(), prev_series))
        for name in drivers:
            estimates[name].append(_call(r_now, prev[name], prev['R']))

    result = pd.DataFrame(index=periods)
    for name in ['P', 'ET', 'ΔSM', 'ΔSWE']:
        values = np.array(estimates[name])
        values[values < 0] = 0   # negative estimates are treated as no transfer
        result[name] = np.roll(values, shift)
    return result

In [None]:
# Basin to analyse; analysis_dict keys are "<model>_<basin>".
loc = 'rockies'
# Monthly information transfer to runoff for SUMMA in the chosen basin.
summa_ts_df = ts_info_transfer(analysis_dict['summa_{}'.format(loc)], 'month')

In [None]:
# Same monthly analysis for VIC in the chosen basin.
vic_ts_df = ts_info_transfer(analysis_dict['vic_{}'.format(loc)], 'month')

In [None]:
# Same monthly analysis for PRMS in the chosen basin.
prms_ts_df = ts_info_transfer(analysis_dict['prms_{}'.format(loc)], 'month')

In [None]:
# Stacked-area comparison of monthly information transfer to runoff, one panel per
# model. All panels share one y-range so the models are directly comparable.
fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
ax = ax.flatten()
monthly_panels = [('SUMMA', summa_ts_df), ('VIC', vic_ts_df), ('PRMS', prms_ts_df)]
monthly_colors = ['#c44e52', '#3472c6', '#8172b2', '#ff914d']
# 10% headroom above the tallest stacked column across all three models.
ymax = 1.1*np.max([panel_df.sum(axis=1).max() for panel_title, panel_df in monthly_panels])
for panel_ax, (panel_title, panel_df) in zip(ax, monthly_panels):
    panel_df.plot.area(color=monthly_colors, ax=panel_ax, legend=False)

# One tick per month; the shared x-axis shows labels only on the bottom panel.
ax[0].set_xticks(list(range(1, 13)))
months = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep']
ax[2].set_xticklabels(months, rotation=30)
ax[0].set_xlim([1, 12])

for panel_ax, (panel_title, panel_df) in zip(ax, monthly_panels):
    panel_ax.set_ylim([0, ymax])
    panel_ax.set_title(panel_title)
    panel_ax.grid(False)
ax[1].set_ylabel('Information transferred to runoff (nats)')

In [None]:
def weekly_info_transfer(ds, shift=15, nruns=30, sample_size=3000):
    """Estimate week-to-week information transfer from previous-week terms to current-week ET.

    For each unique ``ds['weekofyear']`` value ``w``:

    - The target X is the ET series of week ``w`` (``get_data``'s 'ET' entry).
    - Each source Y is one of previous-week runoff, P, ΔSM or ΔSWE.
    - The conditioning variable Z is previous-week ET.

    The week before week 1 is taken to be the largest week number present in ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``weekofyear`` plus the variables read by ``get_data``.
    shift : int
        ``np.roll`` applied to every smoothed column. Default 15, the previously
        hard-coded value.
    nruns, sample_size : int
        Forwarded to ``hy.estimator_stats``.

    Returns
    -------
    pandas.DataFrame
        Indexed by the sorted unique week numbers, with columns ``'P', 'R', 'ΔSM', 'ΔSWE'``.
        Each column is the mean estimate for that *source* term, so ``'R'`` is
        information from previous-week runoff, not a runoff target. Values are
        clipped at 0, smoothed with a circular 3-point mean, then rolled by ``shift``.

    NOTE(review): unlike ``ts_info_transfer``, no ``filter_data(notnan, ...)`` step is
    applied here, so NaN/fill values reach the estimator — confirm this is intended.
    """
    def _call(y, xm, ym):
        return hy.estimator_stats(
            hy.conditional_mutual_info,
            data={'X': y, 'Y': xm, 'Z': ym},
            params={},
            nruns=nruns, sample_size=sample_size,
        )['mean']

    def _smooth(x):
        # Circular 3-point moving average: the first and last weeks are neighbours.
        return np.mean([np.roll(x, -1), x, np.roll(x, 1)], axis=0)

    weeks = np.unique(ds['weekofyear'])
    sources = ['R', 'P', 'ΔSM', 'ΔSWE']   # estimator call order per week
    estimates = {name: [] for name in sources}
    for woy in weeks:
        prev_woy = woy - 1
        if prev_woy == 0:
            # Wrap to the last week actually present rather than assuming week 52.
            prev_woy = weeks[-1]
        ds_now = ds.where(ds['weekofyear'] == woy, drop=True)
        ds_prev = ds.where(ds['weekofyear'] == prev_woy, drop=True)
        # get_data's series are ordered R, ET, P, ...; index 1 is ET.
        et_now = get_data(ds_now)[1][1]
        prev_names, prev_series = get_data(ds_prev)
        prev = dict(zip(prev_names, prev_series))
        for name in sources:
            estimates[name].append(_call(et_now, prev[name], prev['ET']))

    weekly_df = pd.DataFrame(index=weeks)
    for name in ['P', 'R', 'ΔSM', 'ΔSWE']:
        values = np.array(estimates[name])
        values[values < 0] = 0   # negative estimates are treated as no transfer
        weekly_df[name] = np.roll(_smooth(values), shift)
    return weekly_df

In [None]:
# Weekly information transfer to ET for the Willamette basin, one call per model.
# NOTE(review): analysis_dict entries are resampled to monthly steps (time='M'), so
# 'weekofyear' holds only the week numbers of month-end dates and `woy - 1` is usually
# absent. The weekly resample presumably needs restoring before these results are meaningful.
vic_weekly_df = weekly_info_transfer(analysis_dict['vic_will'])
summa_weekly_df = weekly_info_transfer(analysis_dict['summa_will'])
prms_weekly_df = weekly_info_transfer(analysis_dict['prms_will'])

In [None]:
# Stacked-area comparison of weekly information transfer to ET, one panel per model.
fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
ax = ax.flatten()
weekly_colors = ['#c44e52', '#83a83b', '#8172b2', '#ff914d']
summa_weekly_df.plot.area(color=weekly_colors, ax=ax[0], legend=False)
vic_weekly_df.plot.area(color=weekly_colors, ax=ax[1])   # middle panel carries the legend
prms_weekly_df.plot.area(color=weekly_colors, ax=ax[2], legend=False)

# The index is week-of-year (1..53). Place one month label roughly every 4.1 weeks.
# The series are rolled so the axis presumably starts in October.
months = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep']
ax[0].set_xticks([4.1*i for i in range(1, 13)])
# The x-axis is shared, so labels are visible only on the bottom panel; rotate those.
ax[2].set_xticklabels(months, rotation=45)
# Show the full weekly range.
ax[0].set_xlim([1, 53])
ax[0].set_title('SUMMA')
ax[1].set_title('VIC')
ax[2].set_title('PRMS')
for a in ax:
    a.grid(False)
ax[1].set_ylabel('Information transferred to ET (nats)')