In [None]:
from hydromt_sfincs import SfincsModel
from os.path import join
import xarray as xr

In [None]:
# Model and output locations (relative paths).
model_root = r'../../3_models/SFINCS'
fdir = r'../../4_results'  # output folder for figures
base_root = join(model_root, "01_rivpowlaw")  # parent folder of the per-event/scenario models

# Event name -> (start, end) date strings, used to slice the forcing in time.
events = dict(
    idai=('20190314', '20190322'),
    eloise=('20210119', '20210128'),
)

# Scenario suffix -> forcing names read from model folder `<event>_<scen>`.
# The first name is the scenario's own forcing; any further names are read
# as base/reference forcing.
scens = dict(
    h=['bzs'],
    q=['dis'],
    p=['netampr', 'dis', 'bzs'],
)


In [None]:
# Read the forcing of every event/scenario model into two dicts per event:
#   forcing[event]          -> first forcing name listed for each scenario
#   forcing[f'{event}_base'] -> remaining (reference) forcing names
forcing = {}
for event in events:
    event_forcing = forcing[event] = {}
    base_forcing = forcing[f'{event}_base'] = {}
    for scen, fnames in scens.items():
        mod = SfincsModel(join(base_root, f'{event}_{scen}'), mode='r')
        main_name, extra_names = fnames[0], fnames[1:]
        # .load() pulls the data into memory
        event_forcing[main_name] = mod.forcing[main_name].load()
        for fname in extra_names:
            base_forcing[fname] = mod.forcing[fname].load()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from string import ascii_uppercase as abcd


def _to_dataframe(da):
    """Convert a forcing DataArray to pandas with a matplotlib date-number index.

    If ``to_series`` yields a MultiIndex, level 0 (presumably the boundary-point
    dimension -- TODO confirm) is unstacked so each point becomes a column.
    A plain time series stays a Series, which is what ``Axes.bar`` needs.
    """
    df = da.squeeze().to_series()
    if isinstance(df.index, pd.MultiIndex):
        df = df.unstack(0)
    # convert dates a-priori as automatic conversion doesn't always work
    df.index = mdates.date2num(df.index)
    return df


# Figure layout: one column per event, one row per forcing variable.
kwargs = dict()  # optional overrides for plt.subplots
colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']
kwargs0 = dict(figsize=(14, 9), squeeze=False)  # squeeze=False: axes is always 2D
kwargs0.update(**kwargs)
nrows = max(len(forcing[event]) for event in events)
ncols = len(events)
fig, axes = plt.subplots(nrows, ncols, **kwargs0)


for j, (event, dates) in enumerate(events.items()):
    trange = slice(*dates)
    for i, name in enumerate(forcing[event]):
        ax = axes[i, j]
        da = forcing[event][name].sel(time=trange)
        longname = da.attrs.get("standard_name", "")
        unit = da.attrs.get("unit", "")
        prefix = ""
        if da.ndim == 3:
            # gridded forcing: plot the spatial mean time series
            da = da.mean(dim=[da.raster.x_dim, da.raster.y_dim])
            prefix = "mean "
        # single-index pandas object (bar plots don't work with xarray)
        df = _to_dataframe(da)

        if longname == "precipitation":
            longname = 'runoff'
            ax.bar(df.index, df.values, facecolor="darkblue", label='spatial avg.')
            # nanmax: a single NaN must not turn the axis limit into NaN
            ylim = [0, np.nanmax(df.to_numpy()) * 1.05]
        else:
            # actual forcing (df) vs. reference forcing from the base model (df1):
            # tide for water levels, climatology for discharge
            df1 = _to_dataframe(forcing[f'{event}_base'][name].sel(time=trange))
            is_wl = longname == 'waterlevel'
            if is_wl:
                dh = 0.8  # undo EGM correction
                unit = 'm+MSL'
                df1, df = df1 - dh, df - dh
            bound, suffix, suffix1 = ('H', 'tot', 'tide') if is_wl else ('Q', 'event', 'clim.')
            cols = [3] if is_wl else [0, 3]  # positional columns (boundary points) to draw
            icol0 = 0 if is_wl else 1  # colour offset so H and Q lines get different colours

            # plot actual
            df.columns = [f'{bound}{ib} - {suffix}' for ib in df.columns.values]
            for icol, col in enumerate(df.columns[cols]):
                df[col].plot.line(ax=ax, color=colors[icol + icol0], label=col)

            # plot tide/clim
            df1.columns = [f'{bound}{ib} - {suffix1}' for ib in df1.columns.values]
            for icol, col in enumerate(df1.columns[cols]):
                df1[col].plot.line(ax=ax, ls='--', color=colors[icol + icol0], label=col)

            # y-limits from the lines actually drawn (actual and reference)
            plotted = np.concatenate([
                df.iloc[:, cols].to_numpy().ravel(),
                df1.iloc[:, cols].to_numpy().ravel(),
            ])
            ymin, ymax = np.nanmin(plotted), np.nanmax(plotted)
            if is_wl:
                # pad by 5% of the range; scaling min/max by 1.05 clips the data
                # whenever min > 0. Fallback pad keeps a flat line visible.
                pad = 0.05 * (ymax - ymin) or 0.05
                ylim = [ymin - pad, ymax + pad]
            else:
                ylim = [0, ymax * 1.05]

        if j > 0:
            # keep y-limits identical across a row, widened so no event is clipped
            y0, y1 = axes[i, 0].get_ylim()
            ylim = [min(ylim[0], y0), max(ylim[1], y1)]
            for ax_prev in axes[i, :j]:
                ax_prev.set_ylim(ylim)
        if j == ncols - 1:
            ax.legend(
                bbox_to_anchor=(1, 1),
                loc="upper right",
                ncol=2,
            )

        ax.set_ylim(ylim)
        if j == 0:
            ax.set_ylabel(f"{prefix}{longname}\n[{unit}]")
        else:
            # hide tick labels (and any offset text) on the non-first columns
            ax.yaxis.set_major_formatter(mticker.NullFormatter())

        ax.set_xlim([df.index[0], df.index[-1]])
        # a fresh locator per Axis: matplotlib locators must not be shared between axes
        locator = mdates.AutoDateLocator()
        ax.xaxis.set_major_locator(locator)
        if i == nrows - 1:
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
        else:
            ax.xaxis.set_major_formatter(mticker.NullFormatter())
        if i == 0:
            ax.set_title(event.capitalize())
        ax.text(0.01, 0.9, abcd[i * ncols + j], fontsize=14, fontweight='bold', transform=ax.transAxes)

fig.subplots_adjust(wspace=0.05)
fig.savefig(join(fdir, 'forcing.png'), dpi=225, bbox_inches="tight")