# Water balance plotting



In [None]:
%pylab inline
import shapely
import calendar
import pandas as pd
import xarray as xr
import geopandas as gp
import matplotlib as mpl
import cartopy.crs as ccrs
from rasterio import features
from affine import Affine
import summa_plot as sp
from jupyterthemes import jtplot
from matplotlib import patches

# Notebook plotting theme: the inferred theme is applied first, then the
# explicit 'grade3' theme is applied on top of it with the same font scale.
jtplot.style(jtplot.infer_theme(), fscale=1.4)
jtplot.style('grade3', fscale=1.4)
# Default figure size, set through both jtplot and matplotlib's rcParams.
jtplot.figsize(x=19, y=4)
mpl.rcParams['figure.figsize'] = (19, 4)

# Domain-wide totals.
# NOTE(review): units and provenance are not recorded in this notebook - confirm.
totAREA = 810000
totHRU = 11723
totGRID = 23929.0

# Water year used by the monthly plots: 1 Oct (YEAR - 1) through 30 Sep YEAR.
YEAR = 2008
year_slice = slice(f'10-01-{YEAR - 1}', f'09-30-{YEAR}')

# Dataset variables that scale_to_area() weights by relative area.
wb_vars = [
    'precipitation',
    'swe',
    'evaporation',
    'runoff',
    'soil_moisture',
    'canopy_moisture',
]

# Region keys and their display names (paired by position).
loc_names = ['willamette', 'snake', 'olys', 'cascade', 'rockies']
loc_longnames = [
    'Willamette',
    'Snake',
    'Olympics',
    'North Cascades',
    'Canadian Rockies',
]
name_map = dict(zip(loc_names, loc_longnames))

# Column labels given to the SUMMA table before plotting; the order follows
# the column order used by the *_water_balance functions
# (evaporation, runoff, precipitation, soil_moisture, swe).
wb_longnames = [
    'Evapotranspiration (ET)',
    'Runoff',
    'Precipitation',
    'Soil & canopy moisture',
    'Snow water equivalent (SWE)',
]

delta_vars = ['swe', 'soil_liquid']
def scale_to_area(ds, gdf, variables=None):
    """Return a copy of ``ds`` with selected variables multiplied by relative area.

    Parameters
    ----------
    ds : mapping of variable name -> array
        In this notebook an ``xr.Dataset``. It is NOT modified: the scaling
        is applied to a shallow copy, so calling this function again on the
        same input never scales the data twice.
    gdf : table with a ``rel_area`` column
        The weights are taken as a plain array (``gdf['rel_area'].values``)
        and multiplied positionally, so ``ds`` must already be ordered like
        the rows of ``gdf``.
    variables : iterable of str, optional
        Names of the variables to scale. Defaults to the module-level
        ``wb_vars`` list.

    Returns
    -------
    A shallow copy of ``ds`` in which each requested variable has been
    replaced by its area-weighted version; other variables are untouched.
    """
    names = wb_vars if variables is None else variables
    rel_area = gdf['rel_area'].values
    # Shallow copy: re-assigning a variable below rebinds it in the copy only.
    scaled = ds.copy()
    for var in names:
        scaled[var] = scaled[var] * rel_area
    return scaled


In [None]:
# Create an empty figure before re-applying the theme.
# NOTE(review): presumably a workaround so the settings below take effect in
# the inline backend - confirm whether this call is still needed.
plt.subplots()
# Re-apply the 'grade3' theme with a larger font scale and switch the default
# figure size to 18 x 10 (set through both jtplot and matplotlib's rcParams).
jtplot.style('grade3', fscale=1.5)
jtplot.figsize(x=18, y=10)
mpl.rcParams['figure.figsize'] = (18, 10)

# Load data

In [None]:
# Shapefiles: the full domain and one sub-shape per study region.
SHAPEFILE = './data/shapefile.shp'
WILLAMETTE = './data/subshapes/willamette2.shp'
SNAKE = './data/subshapes/snake.shp'
ROCKIES = './data/subshapes/can_rockies.shp'
OLYMPIC = './data/subshapes/olympics.shp'

# Model output, one file per model and region.
summa_will = xr.open_dataset('./data/summa_will.nc')
summa_snake = xr.open_dataset('./data/summa_snake.nc')
summa_rockies = xr.open_dataset('./data/summa_rockies.nc')
summa_olys = xr.open_dataset('./data/summa_olys.nc')

vic_will = xr.open_dataset('./data/vic_will.nc')
vic_snake = xr.open_dataset('./data/vic_snake.nc')
vic_rockies = xr.open_dataset('./data/vic_rockies.nc')
vic_olys = xr.open_dataset('./data/vic_olys.nc')

prms_will = xr.open_dataset('./data/prms_will.nc')
prms_snake = xr.open_dataset('./data/prms_snake.nc')
prms_rockies = xr.open_dataset('./data/prms_rockies.nc')
prms_olys = xr.open_dataset('./data/prms_olys.nc')


def _load_region(shapefile, summa):
    """Read a region shapefile and align the SUMMA dataset to it.

    The shapefile is reprojected to EPSG:4326 and reduced to the rows whose
    ``hru`` id occurs in ``summa``. The dataset is then re-selected with
    exactly those ids, in shapefile row order, so that positional weights
    taken from the table (``rel_area``) line up with the dataset's ``hru``
    axis.

    Returns the ``(gdf, summa)`` pair.
    """
    gdf = gp.GeoDataFrame.from_file(shapefile)
    gdf = gdf.to_crs({'init': 'epsg:4326'})
    gdf = gdf[gdf['hru'].isin(summa.hru.values)]
    return gdf, summa.sel(hru=gdf['hru'].values)


## Load in regions
# Every region gets the same alignment step; previously only the Willamette
# dataset was re-selected to the shapefile's HRU order.
gdf_will, summa_will = _load_region(WILLAMETTE, summa_will)
gdf_snake, summa_snake = _load_region(SNAKE, summa_snake)
gdf_rockies, summa_rockies = _load_region(ROCKIES, summa_rockies)
gdf_olys, summa_olys = _load_region(OLYMPIC, summa_olys)

## Preprocess SUMMA
summa_will_seas = scale_to_area(summa_will, gdf_will)
summa_snake_seas = scale_to_area(summa_snake, gdf_snake)
summa_rockies_seas = scale_to_area(summa_rockies, gdf_rockies)
summa_olys_seas = scale_to_area(summa_olys, gdf_olys)

In [None]:

def calc_monthly_flux(da: xr.DataArray, year: int, prms: bool=False) -> xr.DataArray:
    """Difference of ``da`` between consecutive month ends of a water year.

    The water year runs from October of ``year - 1`` through September of
    ``year``. For each of its twelve months the value selected at the end of
    the previous month is subtracted from the value selected at the end of
    the month itself. When ``prms`` is true the very first baseline is
    1 October instead of 30 September.

    The result is built with ``np.array`` (one entry per month, October
    first), i.e. it is a NumPy array rather than an ``xr.DataArray``.
    """
    water_year_months = [(year - 1, m) for m in (10, 11, 12)]
    water_year_months += [(year, m) for m in range(1, 10)]

    # Last calendar day of every month, formatted MM-DD-YYYY.
    month_ends = [
        f'{month:02d}-{calendar.monthrange(yr, month)[1]}-{yr}'
        for yr, month in water_year_months
    ]
    # Each month starts from the previous month's end; the first month
    # starts from the end of September (or 1 October for PRMS).
    first_baseline = f'10-01-{year-1}' if prms else f'9-30-{year-1}'
    baselines = [first_baseline] + month_ends[:-1]

    deltas = []
    for opening_date, closing_date in zip(baselines, month_ends):
        closing = da.sel(time=closing_date).values
        opening = da.sel(time=opening_date).values
        deltas.append(closing - opening)
    return np.array(deltas)

def calc_monthly_sum(da: xr.DataArray, year: int, prms: bool=False) -> xr.DataArray:
    """Sum ``da`` over time within each month of a water year.

    The water year runs from October of ``year - 1`` through September of
    ``year``. Each month is selected with a first-day-to-last-day time slice
    and summed along the ``time`` dimension. ``prms`` is accepted for
    signature parity with ``calc_monthly_flux`` and is not used.

    The twelve totals (October first) are packed with ``np.array``, so the
    result is a NumPy array rather than an ``xr.DataArray``.
    """
    water_year_months = [(year - 1, m) for m in (10, 11, 12)]
    water_year_months += [(year, m) for m in range(1, 10)]

    totals = []
    for yr, month in water_year_months:
        last_day = calendar.monthrange(yr, month)[1]
        window = slice(f'{month:02d}-01-{yr}', f'{month:02d}-{last_day}-{yr}')
        totals.append(da.sel(time=window).sum(dim='time'))
    return np.array(totals)

def monthly_water_balance(ds: xr.Dataset, year: int, prms=False,
                          agg_dims: list=None, weights: pd.Series=None) -> pd.DataFrame:
    """Build the monthly water-balance table for one water year.

    A month-grouped sum of the water year (1 Oct ``year - 1`` to
    30 Sep ``year``) serves as the container. Its values are then
    overwritten: storage terms (``swe``; ``soil_moisture`` plus
    ``canopy_moisture``) with month-to-month differences from
    ``calc_monthly_flux``, and ``evaporation``, ``runoff`` and
    ``precipitation`` with monthly totals from ``calc_monthly_sum``.
    Those helpers return values in water-year order (October first), so the
    rows are positional; the ``month`` labels produced by the groupby are
    not re-assigned.

    ``agg_dims``, when given, names the dimensions summed out before the
    conversion to a DataFrame. The DataFrame index is shifted down by one.
    ``weights`` is accepted but not used.
    """
    columns = ['evaporation', 'runoff', 'precipitation', 'soil_moisture', 'swe']

    water_year = ds.sel(time=slice(f'10-01-{year-1}', f'9-30-{year}'))
    monthly = water_year.groupby(water_year.time.dt.month).sum(dim=['time'])

    # Storage terms: change between consecutive month ends.
    monthly['swe'].values = calc_monthly_flux(ds['swe'], year, prms)
    monthly['soil_moisture'].values = calc_monthly_flux(ds['soil_moisture'], year, prms)
    monthly['soil_moisture'].values += calc_monthly_flux(ds['canopy_moisture'], year, prms)

    # Flux terms: totals within each month.
    for flux_name in ('evaporation', 'runoff', 'precipitation'):
        monthly[flux_name].values = calc_monthly_sum(ds[flux_name], year, prms)

    monthly = monthly[columns]
    if agg_dims is not None:
        monthly = monthly.sum(dim=agg_dims)

    table = monthly.to_dataframe()
    table.index -= 1
    return table

In [None]:
def plot_water_balance(summa, vic, prms, gdf, keyname, ax=None, legend=True):
    """Plot one region's monthly water balance for SUMMA, VIC and PRMS.

    For water year ``YEAR`` the monthly table of each model is computed with
    ``monthly_water_balance`` and drawn as three side-by-side stacked bars
    per month (SUMMA plain, VIC hatched ``//``, PRMS hatched ``..``).

    NOTE: a later cell of this notebook defines a seasonal function with the
    same name; whichever definition was executed last is the one in effect.

    Parameters
    ----------
    summa : dataset aggregated over its ``hru`` dimension.
    vic, prms : datasets aggregated over ``lat`` and ``lon``. ``prms`` must
        also contain a variable named ``keyname``; its sum normalises both
        gridded tables.
    gdf : table with a ``rel_area`` column; its sum normalises the SUMMA table.
    keyname : key into ``name_map`` (panel title) and name of the variable
        read from ``prms``.
    ax : axes to draw on; a new figure is created when omitted.
    legend : whether to draw the model legend and the variable legend.

    Returns
    -------
    The axes that were drawn on.
    """
    scaleHRU = gdf['rel_area'].sum()
    # Read the normaliser before the sign flip below negates every variable.
    scaleGRID = prms[keyname].sum(skipna=True).values

    summa_seas = summa
    # Negate every VIC/PRMS variable, then negate precipitation back so it
    # keeps its original sign.
    # NOTE(review): presumably this matches the gridded models' sign
    # convention to SUMMA's - confirm.
    vic_seas = -1 * vic
    vic_seas['precipitation'] *= -1
    prms_seas = -1 * prms
    prms_seas['precipitation'] *= -1

    s_df = monthly_water_balance(summa_seas, YEAR, agg_dims=['hru']) / scaleHRU
    v_df = monthly_water_balance(vic_seas, YEAR, agg_dims=['lat', 'lon']) / scaleGRID
    p_df = monthly_water_balance(prms_seas, YEAR, agg_dims=['lat', 'lon'], prms=True) / scaleGRID
    s_df.columns = wb_longnames

    if ax is None:
        fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    if legend:
        patch = [patches.Rectangle([1, 1], 60, 30, edgecolor='black'),
                 patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='///'),
                 patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='..')]
        # Build the model legend on the target axes (not on whatever axes
        # pyplot considers current) and pin it as an extra artist so the
        # variable legend created below does not replace it.
        model_legend = ax.legend(patch, ['SUMMA', 'VIC', 'PRMS'],
                                 bbox_to_anchor=(0.25, -0.1), title='')
        ax.add_artist(model_legend)

    # `position` offsets the three bar groups so they sit next to each other.
    s_bar = s_df.plot(kind='bar', ax=ax, stacked=True, position=1.5,
                      legend=legend, width=0.3)
    v_df.plot.bar(ax=ax, stacked=True, position=0.5, width=0.3, legend=False,
                  fill=True, hatch='//')
    p_df.plot.bar(ax=ax, stacked=True, position=-.5, width=0.3, legend=False,
                  fill=True, hatch='..')

    # Tick labels are positional: rows are in water-year order, Oct..Sep.
    months = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep']
    plt.sca(ax)
    plt.title(name_map[keyname])
    plt.xticks(np.arange(12), months, rotation=45)
    if legend:
        s_bar.legend(bbox_to_anchor=(.7, -0.1), title='')

    ax.axhline(0, color='black', linewidth=2, label='')
    ax.set_xlim([-0.5, 11.5])
    ax.set_xlabel('')
    ax.set_axisbelow(True)
    return ax

In [None]:
def calc_seasonal_flux(da: xr.DataArray, year: int, prms: bool=False) -> np.ndarray:
    """Change in ``da`` over each season of a year.

    Seasons are Dec-Feb (starting in December of ``year - 1``), Mar-May,
    Jun-Aug and Sep-Nov. For each season the value selected on the last day
    before the season is subtracted from the value selected on the season's
    last day, so the change spans the same days that ``calc_seasonal_sum``
    totals. All four baselines now follow this rule, which is also the one
    ``calc_monthly_flux`` uses; previously only the fall baseline did, and
    the other three started one day late.

    ``prms`` is accepted for signature parity with the monthly helper and is
    not used.

    Returns
    -------
    np.ndarray
        One entry per season, in the order winter, spring, summer, fall.
    """
    feb_end = 29 if calendar.isleap(year) else 28
    end = [f'02-{feb_end}-{year}', f'05-31-{year}', f'08-31-{year}', f'11-30-{year}']
    # Each season's baseline is the previous season's end; winter starts
    # from 30 November of the previous year.
    start = [f'11-30-{year-1}'] + end[:-1]
    fluxes = np.array([da.sel(time=e).values - da.sel(time=s).values
                       for s, e in zip(start, end)])
    return fluxes

def calc_seasonal_sum(da: xr.DataArray, year: int, prms: bool=False) -> xr.DataArray:
    """Sum ``da`` over time within each season of a year.

    Seasons are Dec-Feb (starting 1 December of ``year - 1``), Mar-May,
    Jun-Aug and Sep-Nov. Each one is selected with a first-day-to-last-day
    time slice and summed along the ``time`` dimension. ``prms`` is accepted
    but not used.

    The four totals (winter, spring, summer, fall) are packed with
    ``np.array``, so the result is a NumPy array rather than an
    ``xr.DataArray``.
    """
    last_feb_day = calendar.monthrange(year, 2)[1]
    windows = (
        slice(f'12-01-{year-1}', f'02-{last_feb_day}-{year}'),
        slice(f'03-01-{year}', f'05-31-{year}'),
        slice(f'06-01-{year}', f'08-31-{year}'),
        slice(f'09-01-{year}', f'11-30-{year}'),
    )
    totals = []
    for window in windows:
        totals.append(da.sel(time=window).sum(dim='time'))
    return np.array(totals)

def seasonal_water_balance(ds: xr.Dataset, year: int, prms=False,
                          agg_dims: list=None, weights: pd.Series=None) -> pd.DataFrame:
    """Build the seasonal water-balance table for one year.

    A season-grouped sum of the period 30 Nov ``year - 1`` to
    31 Dec ``year`` serves as the container. Its values are then
    overwritten: storage terms (``swe``; ``soil_moisture`` plus
    ``canopy_moisture``) with seasonal differences from
    ``calc_seasonal_flux``, and ``evaporation``, ``runoff`` and
    ``precipitation`` with seasonal totals from ``calc_seasonal_sum``.
    Those helpers return values in the order winter, spring, summer, fall,
    so the rows are positional; the ``season`` labels produced by the
    groupby are not re-assigned.
    NOTE(review): confirm the label order of the groupby matches that
    positional order before relying on the labels.

    ``agg_dims``, when given, names the dimensions summed out before the
    conversion to a DataFrame. Unlike the monthly variant, the index is not
    shifted. ``weights`` is accepted but not used.
    """
    columns = ['evaporation', 'runoff', 'precipitation', 'soil_moisture', 'swe']

    period = ds.sel(time=slice(f'11-30-{year-1}', f'12-31-{year}'))
    seasonal = period.groupby(period.time.dt.season).sum(dim=['time'])

    # Storage terms: change across each season.
    seasonal['swe'].values = calc_seasonal_flux(ds['swe'], year, prms)
    seasonal['soil_moisture'].values = calc_seasonal_flux(ds['soil_moisture'], year, prms)
    seasonal['soil_moisture'].values += calc_seasonal_flux(ds['canopy_moisture'], year, prms)

    # Flux terms: totals within each season.
    for flux_name in ('evaporation', 'runoff', 'precipitation'):
        seasonal[flux_name].values = calc_seasonal_sum(ds[flux_name], year, prms)

    seasonal = seasonal[columns]
    if agg_dims is not None:
        seasonal = seasonal.sum(dim=agg_dims)
    return seasonal.to_dataframe()

In [None]:
def plot_water_balance(summa, vic, prms, gdf, keyname, ax=None, legend=True):
    """Plot one region's mean seasonal water balance for SUMMA, VIC and PRMS.

    ``seasonal_water_balance`` is evaluated for every second year from 1951
    through 2007 and averaged per season. The result is drawn as three
    side-by-side stacked bars per season (SUMMA plain, VIC hatched ``//``,
    PRMS hatched ``..``). Each processed year is printed as progress output.

    NOTE: this definition reuses the name of the monthly plotting function
    defined in an earlier cell and replaces it once this cell has run.

    Parameters
    ----------
    summa : dataset aggregated over its ``hru`` dimension.
    vic, prms : datasets aggregated over ``lat`` and ``lon``. ``prms`` must
        also contain a variable named ``keyname``; its sum normalises both
        gridded tables.
    gdf : table with a ``rel_area`` column; its sum normalises the SUMMA table.
    keyname : key into ``name_map`` (panel title) and name of the variable
        read from ``prms``.
    ax : axes to draw on; a new figure is created when omitted.
    legend : whether to draw the model legend and the variable legend.

    Returns
    -------
    The axes that were drawn on.
    """
    scaleHRU = gdf['rel_area'].sum()
    # Read the normaliser before the sign flip below negates every variable.
    scaleGRID = prms[keyname].sum(skipna=True).values

    summa_seas = summa
    # Negate every VIC/PRMS variable, then negate precipitation back so it
    # keeps its original sign.
    # NOTE(review): presumably this matches the gridded models' sign
    # convention to SUMMA's - confirm.
    vic_seas = -1 * vic
    vic_seas['precipitation'] *= -1
    prms_seas = -1 * prms
    prms_seas['precipitation'] *= -1

    s_df = []
    v_df = []
    p_df = []
    for year in np.arange(1951, 2009, 2):
        print(year)
        s_df.append(seasonal_water_balance(summa_seas, year, agg_dims=['hru']) / scaleHRU)
        v_df.append(seasonal_water_balance(vic_seas, year, agg_dims=['lat', 'lon']) / scaleGRID)
        p_df.append(seasonal_water_balance(prms_seas, year, agg_dims=['lat', 'lon'], prms=True) / scaleGRID)
    # Average the per-year tables season by season.
    s_df = pd.concat(s_df).groupby('season').mean()
    v_df = pd.concat(v_df).groupby('season').mean()
    p_df = pd.concat(p_df).groupby('season').mean()
    s_df.columns = wb_longnames

    if ax is None:
        fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    if legend:
        patch = [patches.Rectangle([1, 1], 60, 30, edgecolor='black'),
                 patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='///'),
                 patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='..')]
        # Build the model legend on the target axes (not on whatever axes
        # pyplot considers current) and pin it as an extra artist so the
        # variable legend created below does not replace it.
        model_legend = ax.legend(patch, ['SUMMA', 'VIC', 'PRMS'],
                                 bbox_to_anchor=(0.25, -0.1), title='')
        ax.add_artist(model_legend)

    # `position` offsets the three bar groups so they sit next to each other.
    s_bar = s_df.plot(kind='bar', ax=ax, stacked=True, position=1.5,
                      legend=legend, width=0.3)
    v_df.plot.bar(ax=ax, stacked=True, position=0.5, width=0.3, legend=False,
                  fill=True, hatch='//')
    p_df.plot.bar(ax=ax, stacked=True, position=-.5, width=0.3, legend=False,
                  fill=True, hatch='..')

    # Tick labels are applied by position, not by the table's season labels.
    # NOTE(review): this assumes the rows come out in the order winter,
    # spring, summer, fall - confirm against the grouped table.
    seasons = ['Winter', 'Spring', 'Summer', 'Fall']
    plt.sca(ax)
    plt.title(name_map[keyname])
    plt.xticks(np.arange(4), seasons, rotation=45)
    if legend:
        s_bar.legend(bbox_to_anchor=(.7, -0.1), title='')

    ax.axhline(0, color='black', linewidth=2, label='')
    ax.set_xlim([-0.5, 3.5])
    ax.set_xlabel('')
    ax.set_axisbelow(True)
    return ax

In [None]:
# 2 x 3 grid: four region panels (positions 0, 1, 3, 4) plus two blank cells
# (positions 2 and 5); the last cell hosts the shared legend.
fig, axes = plt.subplots(nrows=2, ncols=3, sharex=True)
axes = axes.flatten()
axes[0].set_ylabel("Flux (mm/month)")

# One row per panel:
# (progress message, region key, SUMMA, VIC, PRMS, HRU table, axes index, tag)
_panels = [
    ('snake...', 'snake', summa_snake_seas, vic_snake, prms_snake, gdf_snake, 0, 'a)'),
    ('will...', 'willamette', summa_will_seas, vic_will, prms_will, gdf_will, 1, 'b)'),
    ('rockies...', 'rockies', summa_rockies_seas, vic_rockies, prms_rockies, gdf_rockies, 3, 'c)'),
    ('olys...', 'olys', summa_olys_seas, vic_olys, prms_olys, gdf_olys, 4, 'd)'),
]
for _msg, _key, _summa, _vic, _prms, _gdf, _idx, _tag in _panels:
    print(_msg)
    plot_water_balance(_summa, _vic, _prms, _gdf,
                       _key, ax=axes[_idx], legend=False)

# Shared legend: three hatch swatches for the models, then one colour swatch
# per water-balance term.
patch = [
    patches.Rectangle([1, 1], 60, 30, edgecolor='black'),
    patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='///'),
    patches.Rectangle([1, 1], 60, 30, edgecolor='black', hatch='..'),
]
colors = ['#3472c6', '#83a83b', '#c44e52', '#8172b2', '#ff914d']
patch2 = [patches.Rectangle([1, 1], 60, 30, color=c) for c in colors]
varnames = ['Evaporation', 'Runoff', 'Precipitation', 'ΔSoil Liquid', 'ΔSWE']
legend = axes[-1].legend(patch + patch2, ['SUMMA', 'VIC', 'PRMS'] + varnames, title='')
axes[-1].add_artist(legend)

# Panel tags, placed just above the top-left corner of each panel.
for _msg, _key, _summa, _vic, _prms, _gdf, _idx, _tag in _panels:
    axes[_idx].text(0.05, 1.03, _tag, transform=axes[_idx].transAxes)

# Hide the frame and grid of the two cells that carry no data.
for _blank in (axes[-1], axes[2]):
    _blank.set_axis_off()
    _blank.grid(False)
plt.tight_layout()