In [None]:
import xarray as xr
from os.path import join
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

## read and align data

In [None]:
# locations of the forcing input data and of the analysis results
ddir = '../../1_data/2_forcing'
rdir = '../../4_results'

# axis label (name and unit, separated by a newline) for every plotted driver
labels = dict(
    qb='Discharge Buzi\n[m3/s]',
    qp='Discharge Pungwe\n[m3/s]',
    p='Rainfall\n[mm/hr]',
    h_tsw='Total waterlevel (incl. wave setup)\n[m+MSL]',
)

In [None]:
from eva import get_peaks, get_peak_hydrographs

# discharge: load the full 'discharge' variable into memory
fnq = join(ddir, 'cama_discharge_beira_daily.nc')
daq = xr.open_dataset(fnq)['discharge'].load()
# one variable per river, selected by its value of the "index" coordinate
dsq = xr.merge([
    daq.sel(index=idx).rename(name).reset_coords(drop=True)
    for name, idx in (('qb', 1), ('qp', 4))
])
# alternative source (not used): wflow model output
# fnq = r'../../3_models/wflow/run_vito_ksath1000/output_src.nc'
# daq = xr.open_dataset(fnq)['q_river'].load()
# dsq = xr.merge([
#     daq.sel(index=0).rename('qb').reset_coords(drop=True),
#     daq.sel(index=3).rename('qp').reset_coords(drop=True)
# ])

In [None]:
# GTSM waterlevels + ERA5 waves
# contains: "waterlevel" (tide+surge), "tide", "surge", "shww"
fnh = join(ddir, 'reanalysis_gtsm_v1_beira_extended.nc')
dsh0 = (
    xr.open_dataset(fnh)
    .load()
    .rename(dict(waterlevel='h_ts', surge='s', tide='t', shww='w'))
)
# derived series: 0.2 * w is added to the tide+surge level and to the surge
# (missing surge values count as zero in the latter)
dsh0 = dsh0.assign(
    h_tsw=dsh0['h_ts'] + 0.2 * dsh0['w'],
    sw=dsh0['s'].fillna(0) + 0.2 * dsh0['w'],
)
# skew surge (not used)
# high_tide = get_peaks(dsh0['t'].load(), period='12H').dropna('time').reindex_like(dsh0, 'nearest')
# dsh0['ss'] = dsh0['h_ts'] - high_tide
# dsh0['ssw'] = dsh0['h_tsw'] - high_tide

In [None]:
# ERA5 precipitation: open the file, then pull the 'precip' series into memory
fnp = join(ddir, 'era5_precip_beira_hourly_spatialmean.nc')
ds_precip = xr.open_dataset(fnp, chunks='auto')
dap0 = ds_precip['precip'].load()

In [None]:
# read timeseries and peaks data
# pandas offset alias used for resampling below: year start anchored on August
# NOTE(review): recent pandas versions prefer the 'YS-AUG' spelling — confirm
# against the pinned pandas version before changing it.
period='AS-AUG'

# daily driver time series (opened lazily; not loaded into memory here)
ds = xr.open_dataset(join(ddir, 'beira_drivers_daily.nc'))
# peaks table with a parsed 'time' column as index (first column of the csv)
df_peaks0 = pd.read_csv(join(rdir, 'drivers_am_peaks.csv'), index_col=0, parse_dates=['time'])
# peaks on the coordinates of `ds`; labels missing from the table become NaN
ds_peaks = df_peaks0.to_xarray().reindex_like(ds)
# maximum per `period` block; rows with any missing value are dropped
df_bm = df_peaks0.resample(period).max().dropna()

# short names of the individual drivers
drivers = ['qb', 'qp', 'p', 's', 'w']

In [None]:
# read distributions
from eva import get_frozen_dist, rps_dist, emperical_dist, _RPS

# Fitted marginal distributions: one row per driver, the first column holds the
# distribution name and the trailing columns hold its parameters.
dist_params = pd.read_csv(join(rdir, 'marginal_params.csv'), index_col=0).rename({'h_tsw': 'h_tsw0'})
dists = {}
for dvar, row in dist_params.iterrows():
    # Use explicit positional access: `row[0]` on a label-indexed Series relies on
    # the deprecated integer-as-position fallback of `Series.__getitem__`.
    dist_name = row.iloc[0]
    # 'gumb' takes the last two columns as parameters, every other distribution three
    n_params = 2 if dist_name == 'gumb' else 3
    params = row.iloc[-n_params:]
    dists[dvar] = get_frozen_dist(params, dist_name)

# surge: distribution built from tabulated return periods and surge levels
df_surge_emp_dist = pd.read_csv(join(rdir, 'marginal_surge.csv'), index_col=0)
dists['s'] = rps_dist(df_surge_emp_dist['rp[year]'].values, df_surge_emp_dist['surge[m]'].values)

# total water level: empirical distribution of the simulated annual maxima
df_sim_am0 = pd.read_csv(join(rdir, 'sim_AM.csv'), index_col=0)
dists['h_tsw'] = emperical_dist(df_sim_am0['h_tsw'].values, df_sim_am0['h_tsw'].size)

# return values per driver (columns) for every return period in _RPS (rows)
df_rps = pd.DataFrame(columns=dists.keys(), index=_RPS)
df_rps.index.name = 'rps'
for dvar in dists:
    df_rps[dvar] = dists[dvar].ppf(1-1/_RPS)


## analyse lag-times

In [None]:
import numpy as np
# correlation ufunc function
# from http://xarray.pydata.org/en/stable/dask.html#automatic-parallelization
def _covariance(x, y):
    return np.nanmean(
        (x - np.nanmean(x, axis=-1, keepdims=True))
        * (y - np.nanmean(y, axis=-1, keepdims=True)),
        axis=-1,
    )


def _pearson_correlation(x, y):
    return _covariance(x, y) / (np.nanstd(x, axis=-1) * np.nanstd(y, axis=-1))

def pearson_correlation(sim, obs, dim="time"):
    """Compute the Pearson correlation coefficient between two time series.

    Parameters
    ----------
    sim : xarray DataArray
        simulations time series
    obs : xarray DataArray
        observations time series
    dim : str, optional
        name of the dimension to correlate over (the default is 'time')

    Returns
    -------
    xarray DataArray
        the pearson correlation coefficient, named 'pearson_coef'
    """
    # apply the numpy implementation over `dim`, which is reduced in the output
    coef = xr.apply_ufunc(
        _pearson_correlation,
        sim,
        obs,
        input_core_dims=[[dim], [dim]],
        dask="parallelized",
        output_dtypes=[float],
    )
    return coef.rename("pearson_coef")

In [None]:
from datetime import timedelta

def time_lag_crosscorr(
    sim, obs, quantile=None, lags=np.arange(-10,11,1), t_unit='days', dim='time'
):
    """Returns the time lag between two time series based on the lag time
    with the maximum pearson correlation.

    For every lag ``dt`` the time axis of ``sim`` is shifted by ``dt`` and the
    series are correlated over their overlapping period, i.e. ``sim`` at time
    ``t`` is paired with ``obs`` at time ``t + dt``. The inputs are not modified
    (apart from ``obs`` being loaded into memory when ``quantile`` is given).

    Parameters
    ----------
    sim : xarray DataArray
        simulations time series
    obs : xarray DataArray
        observations time series
    quantile : float, optional
        quantile of ``obs`` along ``dim`` used as threshold; values of ``obs``
        below it are masked (the default is None, which does not use any threshold)
    lags : numpy ndarray, optional
        evenly spaced lag times to consider (the default is np.arange(-10,11,1))
    t_unit : str, optional
        time unit used to parse lags to timedelta format (the default is 'days')
    dim : str, optional
        name of time dimension in sim and obs (the default is 'time')

    Returns
    -------
    xarray Dataset
        'lag': lag time with the maximum correlation (NaN if the correlation is
        not finite for every lag) and 'lag_rho': the maximum correlation
    xarray DataArray
        correlation coefficient for every lag, along dimension 'dt'

    Raises
    ------
    ValueError
        if ``lags`` is empty
    """
    # explicit None check: truthiness fails on arrays
    if quantile is not None:
        obs.load()
        obs = obs.where(obs>=obs.quantile(quantile, dim=dim))
    lags = np.atleast_1d(np.asarray(lags))
    if lags.size == 0:
        raise ValueError("lags must contain at least one lag time")
    # loop through time lags and calculate cross correlation
    r = []
    time_org = sim[dim].to_index()
    for dt in lags:
        time_new = time_org + timedelta(**{t_unit: float(dt)})
        # period covered by both the original and the shifted time axis
        ts = slice(max(time_org.min(), time_new.min()), min(time_org.max(), time_new.max()))
        # shift a copy so the caller's `sim` keeps its time axis, also on errors
        sim_shifted = sim.assign_coords(**{dim: time_new})
        r.append(pearson_correlation(sim_shifted.sel(**{dim:ts}), obs.sel(**{dim:ts})))
    pearsonr = xr.concat(r, dim='dt')
    pearsonr['dt'] = xr.Variable('dt', lags)
    # get maximum cross corr
    pearsonr_max = pearsonr.max(dim='dt')
    pearsonr_max.name = 'lag_rho'
    pearsonr_max.attrs.update(description='maximum pearson coefficient for given time lag')
    # get lag time of maximum cross corr
    # NOTE that we assume evenly spaced lag times: lag = lags[0] + index * step.
    # lags[0] (not lags.min()) keeps this valid for descending lags as well.
    step = np.diff(lags)[0] if lags.size > 1 else 0
    lag = xr.where(
        np.isfinite(pearsonr).sum(dim='dt')==lags.size,
        pearsonr.argmax(dim='dt', skipna=False),
        np.nan)*step + lags[0]
    lag.name = 'lag'
    lag.attrs.update(description='time lag with maximum pearson coefficient', unit=t_unit)
    # merge max cross corr and lag time
    return xr.merge([lag, pearsonr_max]), pearsonr

In [None]:
# cross-correlate every other driver with the reference driver and keep, per
# driver, the lag time with the maximum correlation
ref_dvar = 'qb'
dvar_lst = ['qp', 'p', 's', 'w']
dt_lst = []
fig, ax = plt.subplots(1,1)
for dvar in dvar_lst:
    da_out, pearsonr = time_lag_crosscorr(ds[ref_dvar], ds[dvar])
    dt_lst.append(da_out.reset_coords(drop=True).compute())
    # correlation as function of lag time, one line per driver
    pearsonr.plot(label=dvar, ax=ax)
ds_dt = xr.concat(dt_lst, dim='dvar')
ds_dt['dvar'] = xr.IndexVariable('dvar', dvar_lst)
df_timelags = ds_dt['lag'].to_series()
# write next to the other results instead of repeating the hard-coded path
df_timelags.to_csv(join(rdir, 'lagtimes.csv'))
ax.set_xlabel('lag time [days]')
ax.set_ylabel('pearson correlation coefficient [-]')
ax.legend()

## create model events

In [None]:
# design events per driver, filled in by the cells below
events_dict = dict()

In [None]:
# discharge: qb, qp
# get design hydrographs by vertical averaging normalized AM peak hydrographs
# add a row for return period 0 (labelled 'non-flood' in the final figure),
# initialised with zeros for all drivers
# NOTE: this modifies df_rps in place; re-running this cell after the coastal
# cell resets the value that cell stored in row 0 for 'h_tsw'
df_rps.loc[0,:] = 0
df_rps = df_rps.sort_index()
q_event_lst = []
# 'qb' is processed last, so daq_hydrograph0 / daq_hydrograph hold the Buzi
# hydrographs after the loop (they are plotted in the next cell)
for dvar in ['qp', 'qb']:
    # update RP0 based on wet season average
    m = dsq[dvar].time.dt.month
    # mean over the months November up to and including April
    df_rps.loc[0,dvar] = dsq[dvar].isel(time=np.logical_or(m>=11, m<=4)).mean('time').compute().item()
    # create events
    daq_hydrograph0 = get_peak_hydrographs(dsq[dvar], ds_peaks[dvar], wdw_size=21).compute()
    # average over all peaks at every time step
    daq_hydrograph = daq_hydrograph0.mean('peak')
    # if dvar in df_timelags:
    #     daq_hydrograph['time'] = daq_hydrograph['time'] + df_timelags[dvar]
    # scale the mean hydrograph with the return value of every return period
    daq_events = df_rps[dvar].to_xarray() * daq_hydrograph
    # presumably converts the time offset from days to hours (the units
    # attribute is set to 'hour' below) — TODO confirm against get_peak_hydrographs
    daq_events['time'] = daq_events['time']*24
    events_dict[dvar] = daq_events
    q_event_lst.append(daq_events)
# combine both rivers and drop time steps with missing values
dsq_events = xr.merge(q_event_lst).dropna('time')
dsq_events['time'].attrs.update(units='hour')
# quick check: the 100-year event of both rivers
_ = dsq_events['qb'].sel(rps=100).plot.line(x='time')
_ = dsq_events['qp'].sel(rps=100).plot.line(x='time')
# dsq_events.to_netcdf(fnq.replace('.nc', '_events.nc'))

In [None]:
# individual peak hydrographs (thin, black) with their mean on top (thick, red)
df_q_hydrographs = daq_hydrograph0.reset_coords(drop=True).rename('q').to_dataframe().unstack(0)
ax_q = df_q_hydrographs.plot(color='k', lw=0.5, alpha=0.5, legend=False)
daq_hydrograph.to_series().plot(ax=ax_q, color='r', lw=2, legend=False)

In [None]:
# coastal
def sort_center(x, dim):
    """Rearrange the values along ``dim`` so they rise to a central maximum and fall again.

    The values are sorted in ascending order along ``dim``; the sorted values
    at even positions (0, 2, 4, ...) fill the first part of the axis and those
    at odd positions follow in reverse order (..., 5, 3, 1).

    Parameters
    ----------
    x : xarray DataArray
        input data
    dim : str
        name of the dimension along which the values are rearranged

    Returns
    -------
    xarray DataArray
        rearranged values on the dimensions and coordinates of ``x``
        (name and attributes are not copied)
    """
    axis = x.get_axis_num(dim)
    n = x[dim].size
    x_sorted = np.sort(np.asarray(x), axis=axis)
    reorder = np.append(np.arange(0, n, 2), np.arange(1, n, 2)[::-1])
    # np.take applies the 1D order along `axis` directly; this replaces indexing
    # with a list holding None, which recent NumPy versions no longer accept
    x_centered = np.take(x_sorted, reorder, axis=axis)
    # pass dims explicitly: they cannot be inferred reliably from the coords
    return xr.DataArray(x_centered, coords=x.coords, dims=x.dims)

# get design hydrographs by horizontal averaging normalized AM peak hydrographs
# NOTE(review): the factor 6 in the window sizes below presumably reflects six
# samples per hour (10-minute data), in line with the `/6  # hr` conversions — confirm
dasw = dsh0['sw']#.rolling(time=48, center=True).mean('time')
dasw_peaks = get_peaks(dasw, "BM", min_dist=6*24*5, period='AS-AUG')
dasw_hydrographs = get_peak_hydrographs(
    dasw,
    dasw_peaks, 
    wdw_size=int(6*24*5), 
    normalize=True,
    # n_peaks=20
)
# drop peaks with missing values and clip negative values to zero
dasw_hydrographs = np.maximum(0,dasw_hydrographs.dropna('peak'))
dasw_hydrographs['time'] = dasw_hydrographs['time']/6  # hr
dasw_hydrograph = sort_center(dasw_hydrographs, dim='time').mean('peak') # horizontal averaging
# dasw_hydrograph = dasw_hydrographs.mean('peak') # vertical averaging

# mean tidal signal around the peaks of the tide per 29.5-day block
da_mhws_peaks = get_peaks(dsh0['t'], "BM", min_dist=6*24*10, period="29.5D")
da_mhws_hydrographs = get_peak_hydrographs(
    dsh0['t'], da_mhws_peaks, 
    wdw_size=int(6*24*14.5), 
    normalize=False,
)
da_mhws_hydrographs['time'] = da_mhws_hydrographs['time']/6 # hr
da_mhws_hydrograph = da_mhws_hydrographs.mean('peak')
mhws = 3.8 # highest astronomical tide IHO tidal constituents
# the return period 0 event gets a peak level equal to mhws
# (requires row 0 of df_rps, which is added in the discharge cell)
df_rps.loc[0,'h_tsw'] = mhws
# rescale the mean tidal signal so that its maximum equals mhws
da_mhws_hydrograph = da_mhws_hydrograph/da_mhws_hydrograph.max('time')*mhws
# put the surge+wave shape on the time axis of the tidal signal; zero outside its window
dasw_hydrograph = dasw_hydrograph.reindex(time=da_mhws_hydrograph.time, fill_value=0)

# event = tidal signal + surge/wave shape scaled with the excess of the h_tsw
# return value over mhws (zero excess for return period 0)
da_h_events = da_mhws_hydrograph + dasw_hydrograph * (df_rps['h_tsw'].to_xarray()-mhws)
# da_h_events['time'] = da_h_events['time'] + df_timelags[['s','w']].min()*24
da_h_events['time'].attrs.update(units='hour')
_ = da_h_events.sel(rps=[0,2,500]).plot.line(x='time')
# da_h_events.to_netcdf(fnh.replace('.nc', '_events.nc'))
events_dict['h_tsw'] = da_h_events

In [None]:
# individual surge+wave hydrographs (thin, black) with the averaged shape on top (thick, red)
df_sw_hydrographs = dasw_hydrographs.reset_coords(drop=True).rename('sw').to_dataframe().unstack(0)
ax_sw = df_sw_hydrographs.plot(color='k', lw=0.5, alpha=0.5, legend=False)
sw_mean = dasw_hydrograph.reindex(time=dasw_hydrographs.time).to_series()
sw_mean.plot(ax=ax_sw, color='r', lw=2, legend=False)

In [None]:
from eva import eva_idf, get_hyetograph

# precip
# get design events from IDF curves and alternating block method
# durations passed to eva_idf; the longest one also sets the event length below
durations=np.array([1, 2, 3, 6, 12, 24], dtype=int)
# NOTE(review): df_rps.index contains return period 0 once the discharge cell
# has run; verify that eva_idf returns something meaningful for rps=0, which is
# selected from events_dict['p'] in the final figure.
da_p_bm = eva_idf(dap0, durations=durations, distribution='gumb', rps=df_rps.index.values)
da_p_events = get_hyetograph(da_p_bm['return_values'], dt=1, length=durations[-1])
# da_p_events['time'] = da_p_events['time'] + df_timelags['p']*24
da_p_events['time'].attrs.update(units='hour')
# quick check: the 2-year and 500-year events
_ = da_p_events.sel(rps=[2,500]).plot.line(x='time')
# da_p_events.to_netcdf(fnp.replace('.nc', '_events.nc'))
events_dict['p'] = da_p_events

In [None]:
# lag times are relative to qb; the total water level reuses the lag of the surge
df_timelags.loc['h_tsw'] = df_timelags.loc['s']

In [None]:
from string import ascii_uppercase as letters

# Plot the design events of all drivers in one figure (one panel per driver),
# each shifted by its lag time relative to the Buzi discharge peak.
rps = [0,2,100]
# legend entry for every return period in `rps`
event_names = ['non-flood', '2-year', '100-year']
n = len(labels)
fig, axes = plt.subplots(n, 1, figsize=(10, 3*n), sharex=True)

for i, dvar in enumerate(labels):
    # lag time in hours; drivers without a lag time entry are not shifted
    dt0 = df_timelags[dvar]*24 if dvar in df_timelags else 0
    # table with time as rows and one column per return period
    df0 = events_dict[dvar].sel(rps=rps).transpose('rps', ...).to_series().unstack(0)
    df0.columns = event_names
    df0.columns.name = 'design event'
    # event time is in hours: apply the lag and convert to days
    df0.index = (df0.index.values+dt0)/24
    df0.plot(ax=axes[i], legend=i==0)
    axes[i].set_ylabel(labels[dvar])
    # panel title: driver name without the unit on the second line of the label
    title = labels[dvar].split('\n')[0]
    axes[i].set_title(f'{letters[i]}) {title}')
    axes[i].grid()

axes[-1].set_xlabel('time relative to Buzi discharge peak [day]')
axes[-1].set_xlim([-6, 4])
# `bbox_inches` is the savefig keyword that trims the figure ('bbox_axes' does not exist)
plt.savefig(join(rdir, 'drivers_events.png'), dpi=300, bbox_inches='tight')