In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import geopandas as gp
from os.path import join, basename
from datetime import date, datetime
import os
from scipy.stats import rankdata
from scipy.interpolate import interp1d

In [None]:
import sys
import os
sys.path.append(os.path.abspath('../3-postprocess/'))
import xstats as xs 

from peaks import get_peaks


In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from plot_tools import *
import cartopy.crs as ccrs
import seaborn as sns
from string import ascii_uppercase as letters

In [None]:
# Figure defaults: tight bounding box, 300-dpi PNG output, seaborn whitegrid theme.
rc = {
    'savefig.bbox': 'tight',
    'savefig.format': 'png',
    'savefig.dpi': 300,
}
context = 'paper'
sns.set(
    context=context,
    style='whitegrid',
    font_scale=(1.2 if context == 'paper' else 1.),
    rc=rc,
)
# projection used for the location inset map
crs_sub = ccrs.PlateCarree()

# shared plotting keyword arguments
bmap_kwargs = {}
plot_kwargs = {
    'edgecolor': (0.5, 0.5, 0.5, 0.8),
    'linewidth': 0.5,
    'legend': False,
    'zorder': 2,
}
box_kwargs = {
    'whis': [5, 95],  # whiskers at the 5th and 95th percentiles
    'boxprops': {'linewidth': 1.},
    'medianprops': {'linewidth': 1.5},
    'showfliers': False,
    'flierprops': {'markersize': 2},
}

# locations of interest and display names of selected rivers, keyed by location index
locs = [1930, 1618, 809]
riv_names = {
    809: 'Mattapone',
    1695: 'Weser',
    1930: 'Dal',
    1618: 'Volta',
    2884: 'Ataran',
}
# compound-flood drivers: skew surge (`Hskewsurge`) and river discharge (`Q`)
drivers = ['Hskewsurge', 'Q']

In [None]:
# Project root. Defaults to the original /scratch layout; set the
# COMPOUND_HOTSPOTS_ROOT environment variable to run from another location.
root = os.environ.get('COMPOUND_HOTSPOTS_ROOT', r'/scratch/compound_hotspots')
ddir = join(root, 'data', '4-postprocessed')  # postprocessed input data
fdir = join(root, 'reports', 'figures')  # figure output directory
wdw = 1  # window parameter encoded in the driver file names (`..._wdw{wdw}...`)
model = 'mean'  # ensemble member whose summary statistics are loaded
rp = 10  # return period [years] selected from the statistics (`.sel(T=rp)`)

### Combine data

In [None]:
# Static river-mouth attributes indexed by location id, plus derived columns
# in the units used for plotting.
attrs_fn = join(ddir, 'rivmth_mean_attrs.csv')
attrs = (
    pd.read_csv(attrs_fn, index_col='index')
    .rename(columns={'rivmth_lat': 'lat', 'rivmth_lon': 'lon'})
    .assign(
        # log10 of the upstream area, floored at 0.001 to avoid log10(0)
        uparea_log10=lambda df: np.log10(np.maximum(df['uparea'].values, 0.001)),
        Hseasrange=lambda df: df['Hseas_amax'] - df['Hseas_amin'],
        mean_drain_length=lambda df: df['mean_drain_length'] / 1e3,  # [km]
        mean_drain_slope=lambda df: df['mean_drain_slope'] * 1e3,  # [m/km]
        uparea_100=lambda df: df['uparea'] / 1e2,
        # water levels scaled by 100 into the `_cm` columns
        Hseasrange_cm=lambda df: df['Hseasrange'] * 1e2,
        Hseas_amax_cm=lambda df: df['Hseas_amax'] * 1e2,
        Hsurge_amax_cm=lambda df: df['Hsurge_amax'] * 1e2,
    )
)

In [None]:
# Open the river-mouth time series lazily from the zarr store.
fn_rivmth_ts = join(ddir, 'rivmth_reanalysis.zarr')
ds = xr.open_zarr(fn_rivmth_ts)

# Extreme value analysis results and annual maxima with their drivers,
# restricted to the locations in `attrs`; all ensemble members are kept.
fn_wse_ev = join(ddir, 'rivmth_wse_ev.nc')
ds_rp = xr.open_dataset(fn_wse_ev).sel(index=attrs.index)
fn_drivers = join(ddir, 'rivmth_drivers_wdw{}.nc'.format(wdw))
ds_peaks = xr.open_dataset(fn_drivers).sel(index=attrs.index)

In [None]:
# Per-location summary datasets: Spearman rank statistics of the drivers,
# driver statistics, affected population and water-level extreme values
# (the latter three for the `model` ensemble member).
fn_spear = join(ddir, 'rivmth_drivers_wdw{}_spearmanrank.nc'.format(wdw))
fn_drivers = join(ddir, 'rivmth_drivers_wdw{}_ensemble-{}.nc'.format(wdw, model))
fn_impact = join(ddir, 'rivmth_pop_affected_ensemble-{}.nc'.format(model))
fn_wse_ev = join(ddir, 'rivmth_wse_ev_ensemble-{}.nc'.format(model))

ds_spear = xr.open_dataset(fn_spear)
ds_drivers = xr.open_dataset(fn_drivers)
ds_impact = xr.open_dataset(fn_impact)
ds_diff_h_stats = xr.open_dataset(fn_wse_ev)

# combine and keep only the return period of interest
ds_stats = xr.merge(
    [ds_diff_h_stats, ds_impact, ds_drivers]
).sel(T=rp)

In [None]:
# One row per river mouth: static attributes next to the T=rp statistics
# (pd.concat along columns, outer join on the location index), converted by
# plot_tools.pandas2geopandas (presumably point geometries from lat/lon).
gdf = pandas2geopandas(
    pd.concat(
        [attrs, ds_stats.reset_coords(drop=True).to_dataframe()],
        axis='columns',
    )
)

In [None]:
# Locations flagged `driver_insign`, listed by `Q_mean` in descending order.
(
    gdf[gdf['driver_insign']]
    .sort_values(by='Q_mean', ascending=False)
    .head(5)
)

## location plots

In [None]:
from scipy.stats import rankdata
from scipy.stats import ttest_ind
from scipy.interpolate import interp1d

def weibull(peaks, nyears=None):
    """Empirical return periods of ``peaks`` from Weibull plotting positions.

    Non-finite values (NaN, +/-inf) are dropped before ranking, so the result
    is aligned with ``peaks[np.isfinite(peaks)]`` and may be shorter than the
    input.

    Parameters
    ----------
    peaks : array_like
        Peak values (e.g. annual maxima). Any array-like is accepted
        (ndarray, list, pandas Series); it is flattened by the finite mask.
    nyears : float, optional
        Record length in years. If given, return periods are divided by the
        mean number of peaks per year (``n / nyears``); if None each peak is
        treated as one year.

    Returns
    -------
    numpy.ndarray
        Return period of each finite peak in input order:
        ``1 / (1 - rank / (n + 1)) / freq`` with ordinal ranks
        (ties ranked by order of appearance).
    """
    # asarray lets lists/Series be boolean-indexed like an ndarray
    peaks = np.asarray(peaks, dtype=float)
    peaks = peaks[np.isfinite(peaks)]
    peaks_rank = rankdata(peaks, method='ordinal')
    P = peaks_rank / (peaks.size + 1)  # Weibull non-exceedance probability
    freq = 1. if nyears is None else peaks.size / nyears
    rp = 1 / (1 - P) / freq
    return rp

def _interp_ev(peaks, vals, nyears=None):
    peaks = peaks[np.isfinite(peaks)]
    peaks.sort()
    peaks_rank = np.arange(peaks.size)+1
    P = peaks_rank/(peaks.size+1)
    freq = 1. if nyears is None else peaks.size / nyears
    rp = 1/(1-P)/freq
    kwargs = dict(
        kind='linear', bounds_error=False, assume_sorted=True,
        fill_value=(rp.min(), rp.max())
    )
    rp_out = interp1d(peaks, rp, **kwargs)(vals)
    return rp_out

In [None]:
# Styling of the sea-boundary scenarios in the extreme value panel:
# colours from the end of the tab10 palette, markers and display names.
scen_cmap = {
    scen: np.asarray(plt.cm.tab10.colors[i])
    for scen, i in (('surge', -1), ('seas', -2), ('tide', -3))
}
scen_mmap = dict(surge='o', seas='^', tide='d')
scen_nmap = dict(seas='seasonal')
# lower bound of the confidence interval; the pair (alpha, 1 - alpha) is plotted
alpha = 0.025

xlim = [0.9, 40]
rps2 = [1, 2, 4, 8, 16, 32]

In [None]:
# Map the daily variable names onto the short names used in the plots.
rm = dict(
    Hskewsurge_day='Hskewsurge',
    Htot_day_max='Htot',
    Htide_day_max='Htide',
    Hsurge_day_max='Hsurge',
)
# 'surge' scenario; the existing Hsurge/Htide variables are dropped first
# (presumably so the renamed daily maxima can take their names).
ds1 = (
    ds.sel(scen='surge')
    .drop(['Hsurge', 'Htide'])
    .rename(rm)
)
rps = np.array([0.2, 1, 2, 4, 8, 16, 32])
cmap = plt.cm.viridis_r
# colour classes bounded by the return periods of at least one year
norm = BoundaryNorm(rps[rps >= 1], cmap.N)

# marker, colour and label per variable
mmap = dict(Hskewsurge='^', Q='s')
cmmap = dict(
    Hskewsurge=plt.cm.tab10.colors[3],
    Q=plt.cm.tab10.colors[2],
    WSE=plt.cm.tab10.colors[0],
)
labs = dict(Hskewsurge='$H_{SS}$', Q='Q')
# ensemble member shown in the location plots
model2 = 'jrc'


In [None]:
%matplotlib inline
# Multi-panel figure per location: water level / discharge / surge time series,
# extreme value curves per sea-boundary scenario, driver return periods of the
# annual maxima, driver-vs-water-level ranks and a location inset map.
rm = {'Hskewsurge_day': 'Hskewsurge', 'Htot_day_max': 'Htot', 'Htide_day_max': 'Htide', 'Hsurge_day_max': 'Hsurge'}
ds1 = ds.sel(scen='surge').drop(['Hsurge', 'Htide']).rename(rm)

# rps[0] is the lower clip of the driver return periods; classes >= 1 year colour the scatter
rps = np.array([0.25, 0.5, 1, 2, 4, 8, 16, 32])
cmap = plt.cm.viridis_r
norm = BoundaryNorm(rps[rps >= 1], cmap.N)
kwargs = dict(cmap=cmap, norm=norm, linewidth=0.5, edgecolor='k', s=35)

# Record length [years] used to express driver-peak return periods in years.
# NOTE(review): hard-coded, presumably the length of the time series in ds1 - confirm.
nyears_ts = 35
# First year shown in the time-series panels and eligible for a year label.
plot_start_year = 2000
# Water-level return period [years] above which an annual maximum gets a year label.
label_min_rp = 10

for loc in [2294]:
    plt.close('all')
    fig = plt.figure(figsize=(20, 12))
    grid = plt.GridSpec(6, 5, hspace=0.4, wspace=0.4)
    name = riv_names.get(loc, 'unknown')

    # -- data ----------------------------------------------------------------
    ts_loc = ds1.sel(index=loc, ensemble=model2).reset_coords(drop=True)
    t = ts_loc.time.values
    tlim = t[ts_loc.time.dt.year == plot_start_year][0], t[-1]
    df_loc = ds_peaks.sel(index=loc, ensemble=model2).reset_coords(drop=True).to_dataframe()
    # dates of the annual maxima from (year index, day of year); %j is a
    # zero-padded three-digit day of year
    tam = pd.to_datetime(
        [f'{yr:04d}{doy:03.0f}' for yr, doy in zip(df_loc.index, df_loc.dayofyear)],
        format="%Y%j",
    )
    # driver peaks from get_peaks kept only above the 75th percentile;
    # nanpercentile prevents gaps in the series from turning the threshold
    # (and with it every peak) into NaN
    Hthresh = np.nanpercentile(ts_loc[drivers[0]], 75)
    Hpeaks = get_peaks(ts_loc[drivers[0]], min_dist=30).reindex(time=t)
    Hpeaks = Hpeaks.where(Hpeaks > Hthresh)
    Qthresh = np.nanpercentile(ts_loc[drivers[1]], 75)
    Qpeaks = get_peaks(ts_loc[drivers[1]], min_dist=45).reindex(time=t)
    Qpeaks = Qpeaks.where(Qpeaks > Qthresh)

    # -- time series panels (left four columns) --------------------------------
    # water surface elevation; annual maxima marked with dots and solid lines
    ax_wse = fig.add_subplot(grid[:2, :-1])
    ax_wse.plot(t, ts_loc['WSE'].values, color=cmmap['WSE'])
    ax_wse.plot(tam, ts_loc['WSE'].to_series().loc[tam], '.k')
    ymin, ymax = ax_wse.get_ylim()
    ax_wse.vlines(x=tam, ymin=ymin, ymax=ymax, zorder=-1, color='k', linewidth=1, linestyle='-')
    ax_wse.set_xticklabels([])
    ax_wse.set_xlim(tlim)
    ax_wse.set_ylim([ymin, ymax])
    ax_wse.set_ylabel('WSE')
    ax_wse.set_title(f'{loc}. {name} River')

    # discharge with its peaks; dates of the annual maxima dashed
    ax_q = fig.add_subplot(grid[2:4, :-1])
    ax_q.plot(t, ts_loc['Q'].values, color=cmmap['Q'])
    ax_q.plot(t, Qpeaks.values, '.k')
    ymin, ymax = ax_q.get_ylim()
    ax_q.vlines(x=tam, ymin=ymin, ymax=ymax, zorder=-1, color='k', linewidth=1, linestyle='--')
    ax_q.set_xlim(tlim)
    ax_q.set_ylim([ymin, ymax])
    ax_q.set_xticklabels([])
    ax_q.set_ylabel('Q')

    # skew surge (coloured), `Hseas_day_mean` (black) and the surge peaks
    ax_h = fig.add_subplot(grid[4:, :-1])
    ax_h.plot(t, ts_loc['Hskewsurge'].values, color=cmmap['Hskewsurge'])
    ax_h.plot(t, ts_loc['Hseas_day_mean'].values, 'k')
    ax_h.plot(t, Hpeaks.values, '.k')
    ax_h.set_ylabel('Hsurge')
    ymin, ymax = ax_h.get_ylim()
    ax_h.vlines(x=tam, ymin=ymin, ymax=ymax, zorder=-1, color='k', linewidth=1, linestyle='--')
    ax_h.set_xlim(tlim)
    ax_h.set_ylim([ymin, ymax])

    # -- extreme value panel: annual maxima, extreme values and CI per scenario
    ax = fig.add_subplot(grid[:2, -1])
    ds_ev_loc = ds_rp.sel(index=loc)
    z0 = attrs.loc[loc, 'z0']
    T = ds_ev_loc.T.values
    for scen in ['surge', 'seas', 'tide'][::-1]:
        _ds = ds_ev_loc.sel(scen=scen, ensemble=model2)
        peaks_am = _ds['annual_maxima'].values - z0
        Tpeaks_am = weibull(peaks_am)
        ev_ci = _ds['extreme_values_ci'].sel(alpha=[alpha, 1 - alpha]).values - z0
        ev = _ds['extreme_values'].values - z0
        ax.scatter(Tpeaks_am, peaks_am, c=scen_cmap[scen], marker=scen_mmap[scen], label=scen_nmap.get(scen, scen), zorder=2)
        ax.plot(T, ev, color=scen_cmap[scen], zorder=1)
        ax.plot(T, ev_ci[0, :], color=scen_cmap[scen], linestyle='-.', alpha=0.4, zorder=0)
        ax.plot(T, ev_ci[1, :], color=scen_cmap[scen], linestyle='-.', alpha=0.4, zorder=0)
    ax.set_xscale('log')
    ax.legend(loc='lower left', bbox_to_anchor=(1.05, 0.5), title='sea boundary')
    ax.set_xticks(rps2)
    ax.set_xticklabels(rps2)
    ax.set_xlabel('return period [years]')
    ax.set_ylabel('water surface elevation [m+EGM96]')

    # -- driver return periods of each annual maximum, coloured by the water-level return period
    ax = fig.add_subplot(grid[2:4, -1])
    c = weibull(df_loc['h'].values)
    xx = np.maximum(rps[0], _interp_ev(Qpeaks.dropna('time').values, df_loc[drivers[1]].values, nyears_ts))
    yy = np.maximum(rps[0], _interp_ev(Hpeaks.dropna('time').values, df_loc[drivers[0]].values, nyears_ts))
    im = ax.scatter(x=xx, y=yy, c=c, **kwargs)

    # label large annual maxima from plot_start_year on; iterate per point so
    # annotate receives scalar coordinates instead of masked arrays
    for yr, x_am, y_am, rp_am in zip(tam.year, xx, yy, c):
        if rp_am > label_min_rp and yr >= plot_start_year:
            ax.annotate(str(yr), (x_am, y_am))

    # -- rank of each driver vs rank of the water level ------------------------
    ax1 = fig.add_subplot(grid[4:, -1])
    for v in drivers:
        ax1.plot(rankdata(df_loc[v]), rankdata(df_loc['h']), color=cmmap[v], label=labs[v], marker=mmap[v], linewidth=0)
    ax1.plot([0, 41], [0, 41], '--k')

    # axis
    ax.set_xlim([0.85, 38])
    ax.set_ylim([0.85, 38])
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_aspect('equal')
    ax.grid(False)
    ax1.set_xlim([0.5, 41])
    ax1.set_ylim([0.5, 41])
    ax1.set_aspect('equal')
    ax1.grid(False)

    # labels
    ax.set_yticks(rps)
    ax.set_yticklabels(rps)
    ax.set_ylabel(labs[drivers[0]] + ' return period [years]')
    ax1.set_ylabel('riverine water level rank [-]')
    ax.set_xticks(rps)
    ax.set_xticklabels(rps)
    ax.set_xlabel(labs[drivers[1]] + ' return period [years]')
    ax1.set_xlabel('driver rank [-]')

    # Spearman statistics (_r and _p variables) of both drivers for this location
    spear_loc = ds_spear.sel(index=loc, ensemble=model2)
    Hr = float(spear_loc[f'{drivers[0]}_r'].values)
    Qr = float(spear_loc[f'{drivers[1]}_r'].values)
    Hp = float(spear_loc[f'{drivers[0]}_p'].values)
    Qp = float(spear_loc[f'{drivers[1]}_p'].values)
    ax1.text(1, 35.5, labs[drivers[0]]+f': {Hr:.2f} ({Hp:.2f})\n{labs[drivers[1]]}: {Qr:.2f} ({Qp:.2f})', fontsize='small')

    # colorbar placed right of the return-period panel
    pad, shrink, fraction = 0.02, 1.0, 0.04
    cax = fig.add_axes([1, 1, 0.1, 0.1])  # temporary position, set below
    cbar = fig.colorbar(im, extend='max', cax=cax)
    cbar.ax.set_ylabel("riverine water level\n return period [years]", rotation='vertical')
    posn = ax.get_position()
    cax.set_position([posn.x1 + pad, posn.y0 + posn.height * (1 - shrink) / 2., posn.width * fraction, posn.height * shrink])

    ax1.legend(loc='lower left', bbox_to_anchor=(1.05, 0.5), title='driver')

    # global inset map marking the location
    posn = ax1.get_position()
    axg = fig.add_axes([posn.x1 + 0.01, posn.y0 - 0.1, posn.width * 0.8, posn.height], projection=crs_sub)
    basemap(axg, bbox=(-180, -60, 180, 90), gridlines=False, outline=False, features=['land'])
    gdf.loc[[loc], :].plot(ax=axg, marker='o', color='red', markersize=25, legend=True)
    x, y = gdf.loc[loc, :].geometry.coords[0]
    axg.text(x + 10, y, loc, transform=crs_sub)

    fn_fig = join(fdir, f'loc{loc:04d}_{model2}_wdw{wdw}.png')
    os.makedirs(fdir, exist_ok=True)  # savefig does not create missing directories
    print(basename(fn_fig))
    fig.savefig(fn_fig)