In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import geopandas as gp
from os.path import join, basename
from datetime import date, datetime

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from plot_tools import *
import seaborn as sns
from string import ascii_uppercase as letters

In [None]:
# Project root. Override with the COMPOUND_HOTSPOTS_ROOT environment variable so the
# notebook can be re-run outside the original machine's directory layout.
import os
root = os.environ.get('COMPOUND_HOTSPOTS_ROOT', r'/scratch/compound_hotspots')
ddir = join(root, 'data', '4-postprocessed')  # post-processed data read and written below
fdir = join(root, 'reports', 'figures')       # output directory for figures

In [None]:
import sys
import os
sys.path.append(os.path.abspath('../3-postprocess/'))
from peaks import peaks_over_threshold, get_peaks
import xstats as xs 

In [None]:
# For every driver, select the `Npeaks` largest annual-maximum peaks per river mouth
# (surge scenario) and store, at those times, the driver's own variables plus the
# centered running maximum (+/- `wdw` time steps) of all other variables.
drivers = ['WSE', 'Htide', 'Hskewsurge', 'Q']
Npeaks = 35    # number of top annual-maximum events kept per location
wdw = 1        # half-width of the running-max window [time steps]
min_dist = 14  # selects the peaks file produced with this minimum peak distance

fn_peaks = join(ddir, f'rivmth_peaks_d{min_dist}.zarr')
peaks = xr.open_zarr(fn_peaks)

# surge scenario only; NaN -> -inf so missing values never win a max / top-n comparison
peaks_ = peaks.sel(scen='surge').drop_vars('scen').fillna(-np.inf)
fn_out = join(ddir, f'rivmth_AMpeaks_wdw{wdw}.nc')

# centered running maximum over 2*wdw+1 time steps for every variable
peaks_wdwmax = (
    peaks_.rolling(time=wdw*2+1, min_periods=1, center=True)
    .construct('window')
    .max('window')
    .astype(np.float32)
)
dss = []
for driver in drivers:
    # driver value at its own peaks, -inf elsewhere
    da_peaks = peaks_[driver].where(peaks_[f'{driver}_peaks'], -np.inf)
    # True where a peak equals the maximum of its calendar year
    da_group_yr = da_peaks.groupby('time.year')
    da_peaks_am = da_group_yr == da_group_yr.max('time')
    # time indices of the Npeaks largest annual maxima (presumably dims (index, rank))
    itime = xs.xtopn_idx(da_peaks.where(da_peaks_am, -np.inf), n=Npeaks).load()
    # the driver's own variables are taken as-is; note this is a prefix match,
    # so e.g. 'Q' also removes any other variable whose name starts with 'Q'
    driver_vars = [v for v in peaks_.data_vars if v.startswith(driver)]
    ds_topn = xr.merge([
        peaks_[[driver, f'{driver}_rp', f'{driver}_peaks']],
        peaks_wdwmax.drop_vars(driver_vars),
    ]).isel(time=itime).transpose('rank', 'index')
    dss.append(ds_topn)
ds_out = xr.concat(dss, dim='driver').reset_coords('time')
ds_out['driver'] = xr.Variable('driver', drivers)
ds_out['rank'] = xr.Variable('rank', np.arange(Npeaks).astype(int) + 1)
ds_out.to_netcdf(fn_out)

In [None]:
# Figure defaults: tight bounding box, PNG output at 300 dpi.
rc = {
    'savefig.bbox': 'tight',
    'savefig.format': 'png',
    'savefig.dpi': 300,
}
context = 'paper'
# Larger fonts for the 'paper' context, default scale otherwise
# (a 0.75 scale for the 'talk' context was used previously).
sns.set(
    context=context,
    style='whitegrid',
    font_scale=1.2 if context == 'paper' else 1.,
    rc=rc,
)

# Diverging colormap (hues 220 -> 10) for signed differences.
cmap_div = sns.diverging_palette(220, 10, s=75, l=40, sep=1, as_cmap=True)

# Shared keyword sets for scatter/map plots and box plots.
plot_kwargs = dict(
    edgecolor=(0.5, 0.5, 0.5, 0.8),
    linewidth=0.5,
    legend=False,
    zorder=2,
)
box_kwargs = dict(
    whis=[5, 95],                  # whiskers at the 5th and 95th percentiles
    showfliers=False,
    boxprops=dict(linewidth=1.5),
    medianprops=dict(linewidth=1.5),
)

In [None]:
# Per-river-mouth attributes: rename the coordinate columns, log10-transform the
# magnitude variables and derive the range of the seasonal water-level signal.
attrs_fn = join(ddir, 'rivmth_mean_attrs.csv')
attrs = (
    pd.read_csv(attrs_fn, index_col='index')
    .rename(columns={'rivmth_lat': 'lat', 'rivmth_lon': 'lon'})
)
attrs[['Q_amax', 'Qmsl_amax', 'uparea']] = np.log10(attrs[['Q_amax', 'Qmsl_amax', 'uparea']])
attrs['Hseasrange'] = attrs['Hseas_amax'] - attrs['Hseas_amin']

## Bivariate plots

In [None]:
import pandas as pd
from scipy import stats
from scipy.interpolate import interp1d
from lmoments3 import distr
import sys
import os
sys.path.append(os.path.abspath('../3-postprocess/'))
import peaks as xpeaks
import xstats as xs 

In [None]:
def idx_topn(da, n):
    def _idx_topn(x, n=50):
        return np.argsort(x)[::-1][:n]
    
    return xr.apply_ufunc(
        _idx_topn, 
        da, 
        kwargs=dict(n=n),
        input_core_dims=[['time']], 
        output_core_dims=[['rank']], 
        vectorize=True, 
        dask='allowed', 
        output_dtypes=[int],
        output_sizes={'rank':n}
    )

In [None]:
# River-mouth reanalysis time series for ensemble member `model`, surge scenario.
# The daily skew-surge / total / tide variables are renamed to short driver names.
model = 'cnrs'
rm = {'Hskewsurge_day': 'Hskewsurge', 'Htot_day_max': 'Htot', 'Htide_day_max': 'Htide'}
fn_rivmth_ts = join(ddir, 'rivmth_reanalysis.zarr')
ds_rivmth = xr.open_zarr(fn_rivmth_ts)
ds_rivmth = (
    ds_rivmth[['WSE', 'Q', 'Hskewsurge_day', 'Htot_day_max', 'Htide_day_max']]
    .sel(scen='surge', ensemble=model)
    .rename(rm)
)

In [None]:
# Flag the peaks (minimum separation `min_dist` time steps) of every variable and
# attach return periods interpolated by `xs.xinterp_ev` over `nyears` years.
min_dist = 14
nyears = 35
ds_peaks = (
    xpeaks.get_peaks(ds_rivmth, min_dist=min_dist, dim='time')
    .reset_coords(drop=True)
    .reindex_like(ds_rivmth)  # back onto the full time axis; missing labels become NaN
)
ds_peaks_rp = xs.xinterp_ev(ds_peaks, ds_rivmth, nyears=nyears)
# Boolean peak masks. `xr.ufuncs` was removed from xarray; numpy ufuncs
# operate on Datasets directly.
ds_peaks = np.isfinite(ds_peaks).rename({v: f'{v}_peaks' for v in ds_peaks.data_vars})
ds_peaks_rp = ds_peaks_rp.rename({v: f'{v}_rp' for v in ds_peaks_rp.data_vars})
ds_out = xr.merge([ds_rivmth, ds_peaks, ds_peaks_rp]).transpose('time', 'index')
ds_out.to_netcdf(join(ddir, f'rivmth_peaks_d{min_dist}.nc'))

In [None]:
# Centered running maximum (+/- `wdw` time steps) of every variable in the
# peaks file; NaNs are replaced by -inf before taking the maximum.
wdw = 1
ds = xr.open_dataset(join(ddir, f'rivmth_peaks_d{min_dist}.nc'))
ds_wdwmax = (
    ds.fillna(-np.inf)
    .rolling(time=2 * wdw + 1, center=True, min_periods=1)
    .construct('window')
    .max('window')
    .astype(np.float32)
)

In [None]:
# Per driver: the `Npeaks` largest annual-maximum peaks per location, with the
# driver's own variables and the window maximum of every other variable.
drivers = ['WSE', 'Htot', 'Hskewsurge', 'Htide', 'Q']
Npeaks = 35
dss = []
for driver in drivers:
    # driver value at its own peaks, -inf elsewhere
    da_peaks = ds[driver].where(ds[f'{driver}_peaks'], -np.inf)
    # True where a peak equals the maximum of its calendar year
    da_group_yr = da_peaks.groupby('time.year')
    da_peaks_am = da_group_yr == da_group_yr.max('time')
    itime = idx_topn(da_peaks.where(da_peaks_am, -np.inf), n=Npeaks).load()
    # prefix match: drops every variable whose name starts with `driver`
    driver_vars = [v for v in ds.data_vars if v.startswith(driver)]
    ds_topn = xr.merge([
        ds[[driver, f'{driver}_rp', f'{driver}_peaks']],
        ds_wdwmax.drop_vars(driver_vars),
    ]).isel(time=itime).transpose('rank', 'index')
    dss.append(ds_topn)
ds_out = xr.concat(dss, dim='driver').reset_coords('time')
ds_out['driver'] = xr.Variable('driver', drivers)
ds_out['rank'] = xr.Variable('rank', np.arange(Npeaks).astype(int) + 1)
ds_out.to_netcdf(join(ddir, f'rivmth_am_wdw{wdw}.nc'))

In [None]:
# Inspect one location. Reload the combined top-N table written above: the loop
# variable `ds_topn` only holds the last driver and has no 'driver' dimension,
# which the scatter plot below groups on.
loc = 1230  # river-mouth index; 1648 was also inspected
ds_topn = xr.open_dataset(join(ddir, f'rivmth_am_wdw{wdw}.nc')).load()
df_ = ds_topn.sel(index=loc).reset_coords(drop=True).to_dataframe()
# mean WSE over the top-N events selected by each driver
ds_topn['WSE'].sel(index=loc).mean('rank').reset_coords(drop=True).to_series()

In [None]:
# ds_loc = ds.sel(index=loc).reset_coords(drop=True)
# ds_loc.where(ds_loc['Hskewsurge_day_peaks']==1).dropna('time').to_dataframe().sort_values('Hskewsurge_day_rp', ascending=False).head()
# ds_loc = ds.sel(index=loc).reset_coords(drop=True)
# peaks = ds_loc.where(ds_loc['Skewsurge']==1)['Skewsurge'].dropna('time').values
# Tpeaks, peaks = xs.weibull(peaks, nyears=35)
# fig,ax=plt.subplots(1,1)
# ax.scatter( Tpeaks, peaks,)
# ax.set_xscale('log')

In [None]:
# Return-period (rp) scatter at `loc`: x = `xname` rp, y = Q rp, colour = WSE rp,
# one marker style per driver that selected the event.
rps = [0.1, 0.5, 1, 2, 4, 8, 16, 32]  # colour-class boundaries [year]
cmap = ListedColormap(google_turbo_data[50:])
cmap.set_over('black')
cmap.set_under('gray')
norm = BoundaryNorm(rps, cmap.N)

xname = 'Htot'
kwargs = dict(cmap=cmap, norm=norm, colorbar=False, y='Q_rp', x=f'{xname}_rp', c='WSE_rp')
labs = dict(
    x=f'{xname} rp [year]',
    y='Q rp [year]',
    c='WSE rp [year]',
)
# marker / draw order / marker size per selecting driver (WSE drawn on top)
mmap = {'WSE': 'o', xname: '^', 'Q': 'd'}
zmap = {'WSE': 3, xname: 1, 'Q': 2}
smap = {'WSE': 40, xname: 30, 'Q': 30}

xmin, ymin = df_[[kwargs['x'], kwargs['y']]].min().values
xmax, ymax = df_[[kwargs['x'], kwargs['y']]].max().values

fig, ax = plt.subplots(1, 1, figsize=(5, 4.5))
ax.set_yscale('log')
ax.set_xscale('log')
for d, g in df_.groupby('driver'):
    if d not in mmap:
        continue
    g.plot.scatter(ax=ax, s=smap[d], marker=mmap[d], zorder=zmap[d], **kwargs)

fig.tight_layout()
pos = ax.get_position()
cax = fig.add_axes([1., pos.y0, 0.015, pos.height])
cbar = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm, extend='both')
cax.set_ylabel(labs['c'])  # colour label was defined but never applied

ax.set_ylabel(labs['y'])
ax.set_xlabel(labs['x'])
# `rps` already contains 0.1: de-duplicate and sort so no tick is drawn twice
rps2 = sorted(set([0.02, 0.05, 0.1, 0.2] + rps))
ax.set_yticks(rps2)
ax.set_yticklabels(rps2)
ax.set_xticks(rps2)
ax.set_xticklabels(rps2)
ax.set_xlim([xmin*0.9, xmax*1.1])
ax.set_ylim([ymin*0.9, ymax*1.1])

In [None]:
# Event table for one location: WSE peaks above the q-th percentile of WSE, with
# every variable (and its interpolated return period) at the event time, and as
# the centered running max over +/-1 and +/-2 time steps.
loc = 1648  # other locations inspected: 2279, 930, 1617, 499, 1728, 1230
nyears = 35
min_dist = 14
q = 99

ds_ts = ds_rivmth.sel(index=loc).reset_coords(drop=True)
df_ts = ds_ts.to_dataframe()
df_peaks = xr.merge([
    xpeaks.get_peaks(ds_ts, min_dist=min_dist, dim='time').reset_coords(drop=True),
]).to_dataframe()

# return period of every time step, interpolated from the peak distribution
for col in list(df_ts.columns):
    df_ts[f'{col}_rp'] = xs._interp_ev(df_peaks[col].values, df_ts[col].values, nyears)

wse_threshold = df_ts['WSE'].quantile(q / 100.)
events = df_peaks.index[df_peaks['WSE'] > wse_threshold]

df_ts1 = df_ts.rolling(3, min_periods=1, center=True).max()          # +/-1 step
df_ts2 = df_ts.rolling(2 * 2 + 1, min_periods=1, center=True).max()  # +/-2 steps

# dfs[k] holds the events sampled with a +/-k step running max
dfs = [frame.loc[events, :] for frame in (df_ts, df_ts1, df_ts2)]

In [None]:
# Three panels for the events at `loc`: Q rp against tide / skew-surge / total
# water level rp, coloured by WSE rp, using the +/-`window_size` running maxima.
rps = [0.5, 1, 2, 4, 8, 16, 32]  # colour-class boundaries [year]
cmap = ListedColormap(google_turbo_data[50:])
cmap.set_over('black')
norm = BoundaryNorm(rps, cmap.N)
kwargs = dict(cmap=cmap, norm=norm, colorbar=False, y='Q_rp', c='WSE_rp')
# Column names as they exist in `df_ts`: `ds_rivmth` renames the raw
# '*_day' / '*_day_max' variables to Htide / Hskewsurge / Htot, so the
# 'Htide_day_max_rp'-style columns are not present.
cols = dict(
    x0='Htide_rp',
    x1='Hskewsurge_rp',
    x2='Htot_rp',
)
labs = dict(
    x0='Htide rp [year]',
    x1='Hsurge rp [year]',
    x2='Htot rp [year]',
    y='Q rp [year]',
    c='WSE rp [year]',
)
window_size = 1  # index into `dfs` = half-width of the running-max window

# ascending WSE so the highest-WSE events are drawn last (on top)
df_ = dfs[window_size].sort_values('WSE', ascending=True).copy()
ymin, ymax = df_[kwargs['y']].min(), df_[kwargs['y']].max()

rps2 = [0.02, 0.05, 0.1, 0.2] + rps  # tick positions [year]
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(13, 4.5), sharey=True)
panels = list(zip((ax0, ax1, ax2), ('x0', 'x1', 'x2')))
ax0.set_yscale('log')  # shared by all panels
for ax, key in panels:
    ax.set_xscale('log')
    df_.plot.scatter(ax=ax, s=50, x=cols[key], **kwargs)
fig.tight_layout()

pos = ax2.get_position()
cax = fig.add_axes([1., pos.y0, 0.015, pos.height])
cbar = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm, extend='max')
cax.set_ylabel(labs['c'])

ax0.set_ylabel(labs['y'])
ax1.set_yticks(rps2)
ax1.set_yticklabels(rps2)
ax1.set_ylim([ymin*0.9, ymax*1.1])
for ax, key in panels:
    xvals = df_[cols[key]]
    ax.set_xlabel(labs[key])
    ax.set_xticks(rps2)
    ax.set_xticklabels(rps2)
    ax.set_xlim([xvals.min()*0.9, xvals.max()*1.1])

ax1.set_title(f'bi-variate conditions (ID={loc})')
fn = join(fdir, f'bivariate_loc{loc:04d}_{model}_q{q}d{min_dist}_wdw{window_size}_rp.png')
# plt.savefig(fn)

## Extreme value (EV) analysis

In [None]:
# Annual maxima of the WSE peaks at `loc` for every scenario of ensemble `model`.
# Years without a peak end up non-finite; later cells filter with np.isfinite.
q = 99
da_wse = xr.open_zarr(fn_rivmth_ts)['WSE'].sel(ensemble=model)
da_wse_ts = da_wse.sel(index=loc).reset_coords(drop=True).load()
# Alternatives tried before: peaks over the q-th percentile threshold
# (xpeaks.peaks_over_threshold) and plain annual maxima of the series.
da_wse_peaks = (
    xpeaks.get_peaks(da_wse_ts, min_dist=min_dist, dim='time')
    .fillna(-np.inf)
    .resample(time='A')
    .max('time')
    .load()
)

In [None]:
# Compare the EV water levels of the compound scenario with the surge-only and
# runoff-only scenarios, and flag which single driver gives the higher level.
fn_ev = join(ddir, 'rivmth_AMpeaks_d14_ev_gumb.nc')
ds_ev = xr.open_dataset(fn_ev)

da_wse_ev = ds_ev['WSE_ev'].sel(ensemble='cnrs')
wse_cmpnd = da_wse_ev.sel(scen='cmpnd')
wse_surge = da_wse_ev.sel(scen='surge')
wse_runoff = da_wse_ev.sel(scen='runoff')

diff_cmpnd_surge = (wse_cmpnd - wse_surge).rename('diff_surge')
diff_cmpnd_runoff = (wse_cmpnd - wse_runoff).rename('diff_runoff')

# both comparisons are written out (not negated) so NaNs are False in both flags
runoff_is_main_driver = (wse_runoff >= wse_surge).rename('runoff')
surge_is_main_driver = (wse_runoff < wse_surge).rename('surge')

# compound higher (lower) than both single-driver scenarios
compound_positive = ((diff_cmpnd_surge > 0) & (diff_cmpnd_runoff > 0)).rename('compound')
compound_negative = ((diff_cmpnd_surge < 0) & (diff_cmpnd_runoff < 0)).rename('compound_neg')

ds_cmpnd = xr.merge([
    diff_cmpnd_runoff,
    diff_cmpnd_surge,
    runoff_is_main_driver,
    surge_is_main_driver,
    compound_positive,
    compound_negative,
])

In [None]:
# Locations at rp=100 (presumably the 100-year return period) where the compound
# scenario exceeds both single-driver scenarios and runoff dominates (largest
# excess first), and where the compound level falls below both. Both tables are
# displayed explicitly; previously only the last expression was shown.
from IPython.display import display
df = ds_cmpnd.sel(rp=100).reset_coords(drop=True).to_dataframe()
display(df[np.logical_and(df['compound'], df['runoff'])].sort_values(by='diff_runoff', ascending=False))
display(df[df['compound_neg']].sort_values(by='diff_runoff', ascending=True))

In [None]:
#  ds_cmpnd.sel(rp=100, index=3612)

In [None]:
# Colour and marker per scenario, and the EV fits with confidence intervals
# (file name suggests 1e4 resamples; not verified here).
scen_cmap = dict(cmpnd='red', surge='blue', runoff='green', seas='cyan')
scen_mmap = dict(cmpnd='x', surge='^', runoff='o', seas='.')
fn_ev = join(ddir, 'rivmth_AMpeaks_d14_ev_gumb_ci_N1e4.nc')
ds_ev = xr.open_dataset(fn_ev)
ds_ev

In [None]:
# EV plot at one location: empirical annual maxima (plotting positions from
# `xs.weibull`) and fitted EV curves per scenario, with the stored CI bounds at
# levels `alpha` and 1 - `alpha`.
loc = 1648
alpha = 0.10
model = 'cnrs'
ds_ev_loc = ds_ev.sel(index=loc)
T = ds_ev_loc.rp.values
fig, ax = plt.subplots(1, 1)
for scen in ['surge', 'runoff', 'cmpnd', 'seas']:
    _ds = ds_ev_loc.sel(scen=scen, ensemble=model)
    Tpeaks, peaks = xs.weibull(_ds['WSE_am'].values)
    ax.scatter(Tpeaks, peaks, c=scen_cmap[scen], marker=scen_mmap[scen], label=scen)
    ax.plot(T, _ds['WSE_ev'].values, color=scen_cmap[scen])
    # CI bounds come from the pre-computed file. A fresh xs._lm_fit_ci bootstrap
    # used to be computed here, but it was overwritten immediately and only cost time.
    ci = _ds['WSE_ev_ci'].sel(alpha=[alpha, 1-alpha]).values
    ax.plot(T, ci[0, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)
    ax.plot(T, ci[1, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)
ax.set_xscale('log')
ax.legend()

In [None]:
# Monthly-mean cycle at `loc`: each year's Qmsl (left, green) and skew surge
# (right, blue), with the all-years monthly mean of Qmsl and Hseas in cyan.
# Uses its own name so the WSE DataArray `da_wse` is not replaced by a Dataset.
ds_wse_loc = xr.open_zarr(fn_rivmth_ts).sel(ensemble=model, index=loc)
fig, (ax, ax1) = plt.subplots(1, 2, figsize=(10, 5))
for yr in range(1980, 2015):
    # ISO dates avoid pandas guessing day-first order for strings like '31-12-1980'
    ds_yr = ds_wse_loc.sel(time=slice(f'{yr}-01-01', f'{yr}-12-31'))
    ds_yr['Qmsl'].groupby('time.month').mean('time').plot(ax=ax, color='green')
    ds_yr['Hskewsurge_day'].groupby('time.month').mean('time').plot(ax=ax1, color='blue')
ds_wse_loc['Qmsl'].groupby('time.month').mean('time').plot(ax=ax, color='cyan', linewidth=4)
ds_wse_loc['Hseas'].groupby('time.month').mean('time').plot(ax=ax1, color='cyan', linewidth=4)



In [None]:
# L-moment Gumbel fit of the annual-maximum WSE peaks for one scenario.
scen = 'surge'
peaks = da_wse_peaks.sel(scen=scen).values
# drop years without a peak (non-finite), as the other EV cells do
peaks = peaks[np.isfinite(peaks)]
pars = distr.gum.lmom_fit(peaks)
pars

In [None]:
# Gumbel fit of the annual-maximum WSE peaks for every location of `ds_rivmth`.
# `rps_out` is defined here: it was previously only defined in a later cell.
rps_out = np.array([1.01, 1.1, 1.2, 1.5, 2, 4, 8, 16, 35, 100])  # return periods [year]
wl_peaks_am = (
    xpeaks.get_peaks(ds_rivmth['WSE'], min_dist=min_dist, dim='time')
    .fillna(-np.inf)
    .resample(time='A')
    .max('time')
)
# NOTE(review): years without a peak are non-finite here and are not filtered;
# confirm xs.xlm_fit ignores them.
ds_rp = xs.xlm_fit(wl_peaks_am, fdist=distr.gum, nyears=None, rp=rps_out)
ds_rp

In [None]:
# Single-scenario EV plot (`scen` set in the fit cell above): empirical annual
# maxima, L-moment Gumbel fit and its resampled 5-95% band at `rps_out`.
scen_cmap = {'surge': 'blue', 'seas': 'cyan', 'tide': 'green'}
scen_mmap = {'surge': 'x', 'seas': 'o', 'tide': '^'}
scen_labs = {'surge': 'surge', 'tide': 'base'}
rps_out = np.array([1.01, 1.1, 1.2, 1.5, 2, 4, 8, 16, 35, 100])
kwargs = dict(nyears=None, rp=rps_out)

# Rebuild the peaks from the source array so the cell is idempotent (it used to
# overwrite the `peaks` it read), and drop years without a peak.
peaks = da_wse_peaks.sel(scen=scen).values
peaks = peaks[np.isfinite(peaks)]
rps_in, peaks = xs.weibull(peaks, nyears=35)
swe_rp, par = xs._lm_fit(peaks, fdist=distr.gum, **kwargs)
# Compute the CI here for these return periods. Previously it was commented out,
# so a stale `ci` from another cell (different rp axis) was plotted.
ci = xs._lm_fit_ci(peaks, fdist=distr.gum, n_samples=10000, alphas=np.array([0.05, 0.95]), **kwargs)

fig, ax = plt.subplots(1, 1)
ax.scatter(rps_in, peaks, c=scen_cmap[scen], marker=scen_mmap[scen], label=scen_labs[scen])
ax.plot(rps_out, swe_rp, color=scen_cmap[scen])
ax.plot(rps_out, ci[0, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)
ax.plot(rps_out, ci[1, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)
ax.set_xscale('log')

In [None]:
# gumbel
# ppf = loc - scale * ln(-ln(p))
# pars['loc']-pars['scale']*np.log(-np.log(1-1/rps_out))

In [None]:
# EV comparison at `loc` for the 'tide' (labelled 'base') and 'surge' scenarios:
# empirical annual maxima, L-moment Gumbel fit and its resampled 5-95% band.
scen_cmap = dict(surge='blue', seas='cyan', tide='green')
scen_mmap = dict(surge='x', seas='o', tide='^')
scen_labs = dict(surge='surge', tide='base')
rps_out = np.array([1.01, 1.1, 1.2, 1.5, 2, 4, 8, 16, 35, 100])
kwargs = dict(nyears=None, rp=rps_out)

fig, ax = plt.subplots(1, 1)

for scen in ['tide', 'surge']:
    colour = scen_cmap[scen]
    am_values = da_wse_peaks.sel(scen=scen).values
    peaks = am_values[np.isfinite(am_values)]
    rps_in, peaks = xs.weibull(peaks, nyears=35)
    ax.scatter(rps_in, peaks, c=colour, marker=scen_mmap[scen], label=scen_labs[scen])

    swe_rp, par = xs._lm_fit(peaks, fdist=distr.gum, **kwargs)
    ax.plot(rps_out, swe_rp, color=colour)

    ci = xs._lm_fit_ci(peaks, fdist=distr.gum, n_samples=10000, alphas=np.array([0.05, 0.95]), **kwargs)
    for bound in (ci[0, :], ci[1, :]):
        ax.plot(rps_out, bound, color=colour, linestyle='-.', alpha=0.6)

    if scen == 'tide':
        # lower y-limit anchored on the lowest tide-only annual maximum
        ymin = peaks.min()

ax.set_xscale('log')
ax.set_xticks([1, 10, 100])
ax.set_xticklabels([1, 10, 100])
ax.set_xlim([0.99, 101])
ylim = ax.get_ylim()
ax.set_ylim([ymin*0.98, ylim[-1]])
ax.set_ylabel('WSE [m]')
ax.set_xlabel('return period [year]')
ax.set_title(f'empirical ev distribution (ID={loc})')
ax.legend()
ax.grid(False)
fn = join(fdir, f'ev_loc{loc:04d}_{model}_q{q}d{min_dist}.png')
# plt.savefig(fn)

In [None]:
# Pre-computed peak return periods and their confidence intervals at one location.
loc = 1617
model = 'cnrs'
scen = 'surge'  # explicit; previously inherited from the last iteration of an earlier loop
alphas = np.array([0.05, 0.95])

# `q` and `min_dist` (file-name parameters) are set in earlier cells.
fn_peak_rp_ci = join(ddir, f'rivmth_peaks_q{q}d{min_dist}_rp_ci.nc')
fn_peak_rp = join(ddir, f'rivmth_peaks_q{q}d{min_dist}_rp.nc')
ds_rp = xr.open_dataset(fn_peak_rp)
ds_rp_ci = xr.open_dataset(fn_peak_rp_ci)
ds_rp_ci.sel(index=loc, ensemble=model, alpha=alphas, scen=scen)['WSE_peaks']

In [None]:
# Empirical EV plot at the current `loc` for the 'tide' (labelled 'base') and
# 'surge' scenarios, with the resampled 5-95% band of the interpolated return levels.
scen_cmap = {'surge': 'blue', 'seas': 'cyan', 'tide': 'green'}
scen_mmap = {'surge': 'x', 'seas': 'o', 'tide': '^'}
scen_labs = {'surge': 'surge', 'tide': 'base'}
nyears = 35                                  # record length [year]; previously leaked from another cell
rps_out = np.array([1, 2, 4, 8, 16, 35])     # return periods [year] for the CI band
kwargs = dict(nyears=nyears, rp=rps_out)
xticks = [0.5, 1, 2, 4, 8, 16, 32]           # labelled return periods [year]

# Recompute the annual-maximum WSE peaks for *this* `loc`. `da_wse_peaks` is built
# earlier for whatever `loc` was set then (1648), so the figure showed one
# location's data under another location's ID and file name.
da_wse_loc = (
    xr.open_zarr(fn_rivmth_ts)['WSE']
    .sel(ensemble=model, index=loc)
    .reset_coords(drop=True)
    .load()
)
da_wse_peaks_loc = (
    xpeaks.get_peaks(da_wse_loc, min_dist=min_dist, dim='time')
    .fillna(-np.inf)
    .resample(time='A')
    .max('time')
    .load()
)

fig, ax = plt.subplots(1, 1)

for scen in ['tide', 'surge']:
    peaks = da_wse_peaks_loc.sel(scen=scen).values
    peaks = peaks[np.isfinite(peaks)]  # drop years without a peak
    rps_in, peaks = xs.weibull(peaks, nyears=nyears)
    ax.scatter(rps_in, peaks, c=scen_cmap[scen], marker=scen_mmap[scen], label=scen_labs[scen])

    ci = xs._interp_rps_ci(peaks, n_samples=10000, alphas=np.array([0.05, 0.95]), **kwargs)
    ax.plot(rps_out, ci[0, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)
    ax.plot(rps_out, ci[1, :], color=scen_cmap[scen], linestyle='-.', alpha=0.6)

    if scen == 'tide':
        # lower y-limit: tide-only level interpolated at a 0.2-year return period
        ymin = xs._interp_rps(peaks, np.array([0.2]), nyears)[0]

ax.set_xscale('log')
ax.set_xticks(xticks)
ax.set_xticklabels(xticks)
ax.set_xlim([0.2, 40])
ylim = ax.get_ylim()
ax.set_ylim([ymin*0.98, ylim[-1]])
ax.set_ylabel('WSE [m]')
ax.set_xlabel('return period [year]')
ax.set_title(f'empirical ev distribution (ID={loc})')
ax.legend()
fn = join(fdir, f'ev_loc{loc:04d}_{model}_q{q}d{min_dist}.png')
plt.savefig(fn)