In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import hydromt
from os.path import join
import matplotlib.pyplot as plt
from hydromt_sfincs import SfincsModel
import subprocess
import os
import geopandas as gpd

In [None]:
# Simulation scenarios: one row per combination of driver return periods.
# The surge/tide column 'h_tsw_rp' is renamed to 'h_rp' to match the other drivers.
rdir = r'../../4_results'

scens = (
    pd.read_csv(join(rdir, 'sim_SCEN.csv'), index_col=0)
    .rename(columns={'h_tsw_rp': 'h_rp'})
)
print(scens.shape[0])
scens.head(10)

In [None]:
# Read the event time series of every driver; each is indexed by an `rps` coordinate.
ddir = r'../../1_data/2_forcing'

ds_q = xr.open_dataset(join(ddir, r'cama_discharge_beira_daily_events.nc'))
# raise all discharge values to at least the maximum of the rps=0 event (per variable)
ds_q = np.maximum(ds_q, ds_q.sel(rps=0).max())

da_h, da_p = (
    xr.open_dataarray(join(ddir, fn))
    for fn in (
        r'reanalysis_gtsm_v1_beira_extended_events.nc',
        r'era5_precip_beira_hourly_spatialmean_events.nc',
    )
)

# driver key -> event DataArray with non-index coordinates dropped
events = {
    key: da.reset_coords(drop=True)
    for key, da in [
        ('h', da_h),
        ('qb', ds_q['qb']),
        ('qp', ds_q['qp']),
        ('p', da_p),
    ]
}

In [None]:
# lagtimes relative to qb
from datetime import timedelta
# driver key -> lag time (the CSV 'lag' column is in days); .to_dict() yields
# native Python numbers, which timedelta accepts
lagtimes = {
    key: timedelta(days=lag_days)
    for key, lag_days in pd.read_csv(join(rdir, 'lagtimes.csv'), index_col=0)['lag'].to_dict().items()
}
# the water level driver 'h' takes the lag listed under 's'; 'qb' is the reference (no lag)
lagtimes.update(h=lagtimes['s'], qb=timedelta(days=0))
postfix = ''  # suffix appended to every scenario directory name
lagtimes

In [None]:
# "worst case" zero timelag
# lagtimes = {key: timedelta(days=0) for key in events.keys()}
# postfix = '_dt0'
# lagtimes

In [None]:
from hydromt_sfincs.utils import parse_datetime, write_timeseries, write_inp

# Simulation window (SFINCS 'yyyymmdd HHMMSS' strings); events are anchored at t0,
# 7 days after the reference time.
tstart = '20200101 000000'
tstop = '20200115 000000'
tref = parse_datetime(tstart)
t0 = tref + timedelta(days=7)

def get_ts(dvar, rp, lagtimes=lagtimes, events=events):
    """Return the event of driver ``dvar`` for return period ``rp`` on the simulation grid.

    The event's time coordinate (offsets in hours) is shifted by the driver lag
    ``lagtimes[dvar]`` and anchored at the global ``t0``. The result is then
    reindexed onto a regular grid from ``tstart`` to ``tstop``. The grid step is
    the event's first time step, and values outside the event are zero.

    Parameters
    ----------
    dvar : str
        Driver key in ``events`` and ``lagtimes`` (e.g. 'qb', 'qp', 'h', 'p').
    rp : number
        Value looked up in the ``rps`` coordinate of ``events[dvar]``.
    lagtimes, events : dict
        Lag times (``timedelta``) and event DataArrays per driver key.

    Returns
    -------
    pandas.Series or None
        The forcing time series, or None if ``rp`` is not available for ``dvar``.

    Raises
    ------
    ValueError
        If event timestamps inside the simulation window are not on the simulation
        grid, e.g. a lag that is not a multiple of the time step. Reindexing would
        otherwise silently replace the whole event by zeros.
    """
    da = events[dvar]
    if rp not in da.rps:
        return None
    ts = da.sel(rps=rp).to_series()

    lag_hours = lagtimes[dvar].total_seconds() / 3600
    ts.index = t0 + np.array([timedelta(hours=dt) for dt in ts.index.values + lag_hours])

    step = np.diff(ts.index.to_pydatetime())[0]
    dates = pd.date_range(tstart, tstop, freq=step)

    # Event values outside the window are intentionally dropped, but every value
    # inside it must land on a grid timestamp or reindex would zero it out.
    in_window = (ts.index >= dates[0]) & (ts.index <= dates[-1])
    if not ts.index[in_window].isin(dates).all():
        raise ValueError(
            f"get_ts({dvar!r}, {rp}): event timestamps are not aligned with the "
            f"simulation grid (start {dates[0]}, step {step}); check the lag time "
            f"{lagtimes[dvar]}."
        )
    return ts.reindex(dates, fill_value=0)

## Prepare simulation events


In [None]:
# Base SFINCS model, opened in 'r+' mode; every scenario below is written from this object.
basename = '00_base_riv'
mdir = r"../../3_models/sfincs"
mod0 = SfincsModel(
    join(mdir, basename),
    mode='r+',
)

In [None]:
# Discharge source points of the Buzi and Pungwe rivers: entries 1 and 4 of the
# base model discharge forcing.
src = mod0.forcing['dis'].vector.to_gdf().loc[[1, 4], :]

# Water level boundary point(s) at the GTSM output station location(s), given in
# EPSG:4326 and reprojected to the CRS of the source points.
bnd = (
    gpd.GeoDataFrame(
        index=np.atleast_1d(da_h['stations'].values),
        geometry=gpd.points_from_xy(
            x=np.atleast_1d(da_h['station_x_coordinate'].values),
            y=np.atleast_1d(da_h['station_y_coordinate'].values),
        ),
        crs=4326,
    )
    .to_crs(src.crs)
)

In [None]:
# Scenario config: a copy of the base model config, overridden with the simulation
# window, binary output settings, forcing file names and the initial water level
# file taken from the all-zero scenario directory.
config = mod0.config.copy()
config.update(
    tref=tstart,
    tstart=tstart,
    tstop=tstop,
    outputformat='bin',
    dtmaxout=86400,
    dtout=86400,
    dtwnd=600,
    alpha=0.7,
    precipfile='sfincs.precip',
    bzsfile='sfincs.bzs',
    bndfile='sfincs.bnd',
    inifile='../qb000_qp000_h000_p000/sfincs.zsini',
)
# drop the base model's netamprfile entry (if any) from the scenario config
config.pop('netamprfile', None)

In [None]:
# Select the scenarios to write:
#  - rows where all four drivers have the same return period, and
#  - for each driver, rows where all *other* drivers are zero.
# Each group is sorted on its varying column and its first row is dropped.
index_cols = ['qb_rp', 'qp_rp', 'p_rp', 'h_rp']
scen_rps = scens[index_cols]
scens0_lst = [
    scens[(np.diff(scen_rps.to_numpy(), axis=1) == 0).all(axis=1)]
    .sort_values('qb_rp')
    .iloc[1:]
]
for col in index_cols:
    zero_cols = [c for c in index_cols if c != col]
    scens0_lst.append(
        scens[scens[zero_cols].eq(0).all(axis=1)].sort_values(col).iloc[1:]
    )
scens0 = pd.concat(scens0_lst)

In [None]:
def _required_ts(dvar, rp, scen):
    """Return ``get_ts(dvar, rp)``, raising a descriptive error if ``rp`` is unavailable.

    ``get_ts`` returns None for a return period missing from the events. For the
    mandatory drivers (qb, qp, h) that would otherwise surface as an opaque
    ``AttributeError`` on ``.rename``.
    """
    ts = get_ts(dvar, rp)
    if ts is None:
        raise ValueError(f'{scen}: return period {rp} not available for driver {dvar!r}')
    return ts


# Write one SFINCS model per selected scenario. The base model object mod0 is
# re-rooted for each scenario and receives the scenario config and forcing.
for _, run in scens0.iterrows():
    name = run['scen']
    root = join(mdir, f'{name}{postfix}')
    if os.path.isfile(join(root, 'sfincs.inp')):
        continue  # scenario already written

    mod0.set_root(root, mode='w')
    mod0._write_gis = False
    # Re-apply the shared config on every iteration: 'inifile' and 'precipfile'
    # may have been popped from mod0.config while writing the previous scenario.
    mod0.setup_config(**config)
    mod0.config.pop('restartfile', None)
    # Select the return-period columns by name; a positional slice would depend
    # on the column order of sim_SCEN.csv.
    if (run[index_cols] == 0).all():
        # all-zero run: move the inifile entry to restartfile
        mod0.setup_config(restartfile=mod0.config.pop('inifile'))

    # discharge: qb -> source point 1, qp -> source point 4 (must match src.index)
    q0 = pd.concat(
        [
            _required_ts('qb', run['qb_rp'], name).rename(1),
            _required_ts('qp', run['qp_rp'], name).rename(4),
        ],
        axis=1,
    )
    mod0.set_forcing_1d(ts=q0, xy=src, name='discharge')

    # water level at the (single) boundary point
    h0 = _required_ts('h', run['h_rp'], name).rename(bnd.index.item()).to_frame()
    mod0.set_forcing_1d(ts=h0, xy=bnd, name='waterlevel')

    # Precipitation is optional. Without an event, drop any precip forcing and
    # file entry still attached to mod0 (from the base model or a previous scenario).
    p0 = get_ts('p', run['p_rp'])
    if p0 is None:
        mod0.forcing.pop('precip', None)
        mod0.config.pop('precipfile', None)
    else:
        mod0.set_forcing_1d(ts=p0, xy=None, name='precip')

    mod0.write_forcing()
    mod0.write_config(rel_path=f'../{basename}')

    mod0.plot_forcing()
    plt.close('all')  # avoid accumulating one figure set per scenario