# *** Compare SFR stages Vs RIV stages ***

- The RIV stage comes from the RIV IDF files provided to us by Deltares.
- System 1 (DETAILWATERGANGEN, which contains the detailed network) is used for the calculations. The other systems are mostly or completely absent from our region of interest; for evidence, see ./SFRstage_Vs_RIVstage_NBr43.ipynb.

# 0. Initial

## 0.0. Imports

In [1]:
import imod

In [2]:
import WS_Mdl.utils as U
import WS_Mdl.utils_imod as UIM
import WS_Mdl.geo as G

In [3]:
import importlib as IL
# Re-import the WS_Mdl helper modules so source edits are picked up without restarting the kernel.
for _mod in (U, UIM):
    IL.reload(_mod)
IL.reload(G)  # kept as the last expression so the notebook still echoes the reloaded module

<module 'WS_Mdl.geo' from 'G:\\code\\WS_Mdl\\geo.py'>

In [4]:
import os
from os import listdir as LD, makedirs as MDs
from os.path import join as PJ, basename as PBN, dirname as PDN, exists as PE
import shutil as sh
import pandas as pd
from datetime import datetime as DT
import matplotlib.pyplot as plt
from pathlib import Path
import re
import xarray as xr

In [5]:
import plotly.graph_objects as go
import re

In [None]:
from pathlib import Path
import xarray as xr
import pandas as pd
from tqdm.notebook import tqdm

## 0.1. Options

In [6]:
# Switch off verbose mode in WS_Mdl.utils (presumably suppresses its diagnostic prints — confirm in utils).
U.set_verbose(False)

In [7]:
# Run whose SFR output is analysed, and the reference run whose RIV package and heads it is compared to.
MdlN, MdlN_RIV = 'NBr47', 'NBr43'

In [8]:
# Main model: path dictionary, parsed INI, grid definition and first stress-period date (ISO format).
d_Pa = U.get_MdlN_Pa(MdlN)
d_INI = U.INI_to_d(d_Pa['INI'])
Xmin, Ymin, Xmax, Ymax, cellsize, N_R, N_C = U.Mdl_Dmns_from_INI(d_Pa['INI'])
SP_date_1st = DT.strptime(d_INI['SDATE'], '%Y%m%d').strftime('%Y-%m-%d')
dx = dy = float(d_INI['CELLSIZE'])

# RIV reference model: only its paths, INI and start date are needed here.
d_Pa_RIV = U.get_MdlN_Pa(MdlN_RIV)
d_INI_RIV = U.INI_to_d(d_Pa_RIV['INI'])
SP_date_1st_RIV = DT.strptime(d_INI_RIV['SDATE'], '%Y%m%d').strftime('%Y-%m-%d')

In [15]:
# 1-based system numbers: RIV system in the RIV model's PRJ, DRN system in the main PRJ.
# Set either to None to be prompted interactively.
N_system_RIV = 1
N_system_DRN = 2

In [9]:
# Both models are clipped to the main model's extent later, so warn if their grids are not identical.
dims_RIV = U.Mdl_Dmns_from_INI(d_Pa_RIV['INI'])
if dims_RIV != (Xmin, Ymin, Xmax, Ymax, cellsize, N_R, N_C):
    print("Warning: Model dimensions for RIV model differ from main model.")

In [10]:
# First and last simulation dates of the RIV model, converted from YYYYMMDD to YYYY-MM-DD.
SP_date_1st_RIV = DT.strptime(d_INI_RIV['SDATE'], '%Y%m%d').strftime('%Y-%m-%d')
SP_date_last_RIV = DT.strptime(d_INI_RIV['EDATE'], '%Y%m%d').strftime('%Y-%m-%d')

In [84]:
# Which per-reach time series to load and plot.
load_head, load_head_RIV, load_P = True, True, True  # main-model heads, RIV-model heads, precipitation

# 1. Load Stuff

## 1.0. Load PRJ

In [11]:
# Read the main model's PRJ file; PRJ holds the package definitions (e.g. PRJ['(drn)'], PRJ['(top)']),
# PRJ_OBS presumably its observation section — confirm in WS_Mdl.utils_imod.
PRJ, PRJ_OBS = UIM.r_PRJ_with_OBS(d_Pa['PRJ'])

## 1.1. Load SFR

In [13]:
DF_ = pd.read_csv(UIM.get_SFR_OBS_Out_Pas(MdlN)[0])  # 666 replace 1. This needs to be standardized.

# Keep 'time' plus the per-cell stage columns (named like L{L}_R{R}_C{C}).
# NOTE(review): the filter keeps any column containing an 'L' anywhere — confirm no other such columns exist.
obs_cols = [col for col in DF_.columns if col == 'time' or 'L' in col]
DF = DF_[obs_cols].copy()

# Observation time is in days; time = 1 is mapped onto the first stress-period date.
t0 = pd.to_datetime(SP_date_1st)
DF['time'] = t0 + pd.to_timedelta(DF['time'] - 1, unit='D')

GDF_SFR = U.SFR_PkgD_to_DF(MdlN)

In [17]:
RIV_params = ['conductance', 'stage', 'bottom_elevation', 'infiltration_factor']
PRJ_RIV, PRJ_OBS_RIV = UIM.r_PRJ_with_OBS(d_Pa_RIV['PRJ'])
n_RIV_systems = PRJ_RIV['(riv)']['n_system']

# Summarise every RIV system (file name or constant per parameter) so the user can pick one.
l_N_system_RIV_print = []
for i in range(n_RIV_systems):
    l_N_system_RIV_print.append(f'System {i + 1}:')
    for j in RIV_params:
        if 'path' in PRJ_RIV['(riv)'][j][i]:
            l_N_system_RIV_print.append(f'\t{j:<20}: {PBN(PRJ_RIV["(riv)"][j][i]["path"])}')
        elif 'constant' in PRJ_RIV['(riv)'][j][i]:
            l_N_system_RIV_print.append(f'\t{j:<20}: {PRJ_RIV["(riv)"][j][i]["constant"]}')
        else:
            l_N_system_RIV_print.append(f'\t{j:<20}: N/A')

str_N_system_RIV_print = '\n'.join(l_N_system_RIV_print)

if N_system_RIV is None:
    print(
        f'  - You need to choose one of {n_RIV_systems} river systems.\nHere is some information about the RIV systems:\n{str_N_system_RIV_print}\n'
    )
    N_system_RIV = int(input('Select the number of the RIV system you want to plot (1-indexed).'))

# Validate both hard-coded and interactive choices. Previously an invalid value was only printed,
# after which index N_system_RIV - 1 silently picked the wrong system (0 -> -1 -> last) or raised IndexError.
if not 1 <= N_system_RIV <= n_RIV_systems:
    raise ValueError(f'Invalid system number. It should be >= 1 & <= {n_RIV_systems}.')

A_RIV_Stg = UIM.xr_clip_Mdl_Aa(imod.idf.open(PRJ_RIV['(riv)']['stage'][N_system_RIV - 1]['path']), MdlN=MdlN)
A_RIV_Btm = UIM.xr_clip_Mdl_Aa(imod.idf.open(PRJ_RIV['(riv)']['bottom_elevation'][N_system_RIV - 1]['path']), MdlN=MdlN)

# Sanity check: does every cell with a defined bottom elevation have bottom == stage?
if A_RIV_Btm.notnull().sum().values == (A_RIV_Btm == A_RIV_Stg).sum().values:
    print('\tAll river bottom elevations are equal to stage elevations.')


	All river bottom elevations are equal to stage elevations.


## 1.2. Load DRN

In [18]:
DRN_params = ['conductance', 'elevation', 'n_system', 'active']
n_DRN_systems = PRJ['(drn)']['n_system']

# Summarise every DRN system so the user can pick one. Per-system parameters are lists with one entry
# per system (a dict holding 'path' or 'constant'); block-wide settings such as 'n_system' are scalars.
# (The previous version tested `i in PRJ['(drn)'][j]` before `j` was defined and against a list of
# dicts, so the summary was always empty.)
l_N_system_DRN_print = []
for i in range(n_DRN_systems):
    l_N_system_DRN_print.append(f'System {i + 1}:')
    for j in DRN_params:
        value = PRJ['(drn)'].get(j)
        entry = value[i] if isinstance(value, list) and i < len(value) else value
        if isinstance(entry, dict) and 'path' in entry:
            l_N_system_DRN_print.append(f'\t{j:<20}: {PBN(entry["path"])}')
        elif isinstance(entry, dict) and 'constant' in entry:
            l_N_system_DRN_print.append(f'\t{j:<20}: {entry["constant"]}')
        elif entry is not None and not isinstance(entry, (dict, list)):
            l_N_system_DRN_print.append(f'\t{j:<20}: {entry}')
        else:
            l_N_system_DRN_print.append(f'\t{j:<20}: N/A')

str_N_system_DRN_print = '\n'.join(l_N_system_DRN_print)

if N_system_DRN is None:
    print(
        f'  - You need to choose one of {n_DRN_systems} drain systems.\nHere is some information about the DRN systems:\n{str_N_system_DRN_print}\n'
    )
    N_system_DRN = int(input('Select the number of the DRN system you want to plot (1-indexed).'))

# Validate both hard-coded and interactive choices; an out-of-range value would otherwise index the wrong system.
if not 1 <= N_system_DRN <= n_DRN_systems:
    raise ValueError(f'Invalid system number. It should be >= 1 & <= {n_DRN_systems}.')

A_DRN_Elv = UIM.xr_clip_Mdl_Aa(imod.idf.open(PRJ['(drn)']['elevation'][N_system_DRN - 1]['path']), MdlN=MdlN)

## 1.3. Load TOP BOT

In [19]:
# Layer tops/bottoms from the PRJ. In the file-name pattern, {layer} supplies the layer number and {name}
# absorbs the model-number suffix (imod uses it as the DataArray name) — this avoids pattern-match errors.
l_Pa_TOP = [entry['path'] for entry in PRJ['(top)']['top']]
A_TOP = UIM.xr_clip_Mdl_Aa(imod.idf.open(l_Pa_TOP, pattern=r'TOP_L{layer}_{name}'), MdlN=MdlN)

l_Pa_BOT = [entry['path'] for entry in PRJ['(bot)']['bottom']]
A_BOT = UIM.xr_clip_Mdl_Aa(imod.idf.open(l_Pa_BOT, pattern=r'BOT_L{layer}_{name}'), MdlN=MdlN)

## 1.4. Load head

In [20]:
# Share of SFR reaches per model layer (column 'k'); only layers holding >= 1% of the reaches
# (l_SFR_Ls) are pre-loaded in the plotting loop.
DF_ = pd.DataFrame({'L': GDF_SFR.k.value_counts().index, 'count': GDF_SFR.k.value_counts()})
DF_['percentage'] = (GDF_SFR.k.value_counts(normalize=True) * 100).apply(lambda x: round(x, 2))
l_SFR_Ls = [int(i) for i in sorted(DF_.loc[DF_['percentage'] >= 1, 'L'].unique())]
print(
    f"--- Reading HD data...\nOnly nSFR-relevant Ls will be loaded - Ls: {l_SFR_Ls}.\n\tEach of those Ls contains at least 1% of the SFR reaches.\n\tIf you request to plot a TS for a reach not included in those Ls, it'll have to load separately, which may take some time."
)

# Open the binary head files; they are subset and .compute()d later, per temporal extent.
# On failure the error is reported and re-raised: later cells use A_HD_ / A_HD_RIV_ whenever the
# corresponding flag is True, so silently continuing only postponed the failure to a NameError.
if load_head:
    try:
        A_HD_ = imod.mf6.open_hds(
            hds_path=d_Pa['Out_HD_Bin'],
            grb_path=d_Pa['DIS_GRB'],
            simulation_start_time=pd.to_datetime(SP_date_1st),
            time_unit='d',
        ).astype('float32')
    except Exception as e:
        print(f'🔴🔴🔴 - An error occurred while loading HD data: {e}')
        raise

if load_head_RIV:
    try:
        A_HD_RIV_ = imod.mf6.open_hds(
            hds_path=d_Pa_RIV['Out_HD_Bin'],
            grb_path=d_Pa_RIV['DIS_GRB'],
            simulation_start_time=pd.to_datetime(SP_date_1st_RIV),
            time_unit='d',
        ).astype('float32')
    except Exception as e:
        print(f'🔴🔴🔴 - An error occurred while loading RIV-model HD data: {e}')
        raise


--- Reading HD data...
Only nSFR-relevant Ls will be loaded - Ls: [1, 3, 4, 5, 7, 9, 10, 11].
	Each of those Ls contains at least 1% of the SFR reaches.
	If you request to plot a TS for a reach not included in those Ls, it'll have to load separately, which may take some time.


## 1.5. Load Rainfall

In [96]:
# Meteo table referenced by the third 'extra' entry of the PRJ. Columns: day, year, P (path to the
# precipitation raster, relative to the PRJ folder) and PET.
Pa_meteo = PRJ['extra']['paths'][2][0]
DF_meteo = pd.read_csv(Pa_meteo, names=['day', 'year', 'P', 'PET'])

In [103]:
# Build 'YYYY-DDD' strings and parse them as year + day-of-year.
# NOTE(review): the +1 assumes the 'day' column is 0-based (0 = 1 January), and .astype(int) truncates
# any fractional day — confirm against the meteo file.
year_str = DF_meteo['year'].astype(int).astype(str)
doy_str = (DF_meteo['day'].astype(int) + 1).astype(str)
DF_meteo['DT'] = pd.to_datetime(year_str + '-' + doy_str, format='%Y-%j')

In [110]:
def load_P_Aa(DF_meteo, Xmin, Ymin, Xmax, Ymax, d_Pa):
    """Read the precipitation rasters listed in ``DF_meteo`` and stack them along ``time``.

    Parameters
    ----------
    DF_meteo : pandas.DataFrame
        One row per time step. Column ``'P'`` holds the raster path relative to the folder
        containing the PRJ file; column ``'DT'`` holds the matching datetime.
    Xmin, Ymin, Xmax, Ymax : float
        Bounding box of the area of interest; each raster is clipped to it.
    d_Pa : dict
        Model path dictionary; only ``d_Pa['PRJ']`` is used.

    Returns
    -------
    xarray.DataArray or None
        The clipped rasters concatenated along a single ``time`` dimension labelled with
        ``DF_meteo['DT']``, or None (after printing a message) when ``DF_meteo`` has no rows.
    """
    prj_dir = Path(d_Pa['PRJ']).parent

    def _read_clipped(rel_path):
        # Open one raster lazily, drop a singleton 'band' dimension and clip it to the AOI.
        arr = imod.rasterio.open((prj_dir / rel_path).resolve())
        if 'band' in arr.dims:
            arr = arr.squeeze('band', drop=True)
        # Rasters are often stored with descending y; sort so the ascending y-slice selects the AOI.
        return arr.sortby('y').sel(x=slice(Xmin, Xmax), y=slice(Ymin, Ymax))

    # Serial on purpose: a parallel version was tried and was about as fast.
    arrays = [
        _read_clipped(row['P'])
        for _, row in tqdm(DF_meteo.iterrows(), total=len(DF_meteo), desc="Loading P")
    ]

    if not arrays:
        print("No data loaded.")
        return None

    # Concatenate on a new 'time' dimension (instead of a (year, day) MultiIndex), then label it with 'DT'.
    times = DF_meteo['DT'].values
    return xr.concat(arrays, dim='time').assign_coords(time=times)

# 2. Load cell data and plot

In [112]:
import sfr_plotting as sf
import importlib as IL

---

In [113]:
from tqdm.dask import TqdmCallback

while True:
    In1 = input(
        f'Start date is {DF["time"].min()} to {DF["time"].max()} for model {MdlN}.\n\nPress any key except Y and E to continue using this temporal extent.\nPress Y to set another temporal extent.\nPress E to exit.\n\n'
    )
    if In1.upper() == 'E':
        break

    # 'Y' -> ask for a custom period and trim the SFR stage table to it; anything else keeps the full period.
    start_date = (
        pd.to_datetime(input('Enter start date (YYYY-MM-DD):\n')) if In1.upper() == 'Y' else DF['time'].min()
    )
    end_date = pd.to_datetime(input('Enter end date (YYYY-MM-DD):\n')) if In1.upper() == 'Y' else DF['time'].max()
    DF_trim = (
        DF.copy().loc[(DF['time'] >= start_date) & (DF['time'] <= end_date)].reset_index(drop=True)
        if In1.upper() == 'Y'
        else DF.copy()
    )

    # Compute heads (SFR-relevant layers only) and precipitation for this period once, so each reach
    # query below only has to extract a point time series.
    with TqdmCallback(desc="Loading HD"):
        if load_head:
            A_HD = A_HD_.sel(layer=l_SFR_Ls).sel(time=slice(start_date, end_date)).compute()
        if load_head_RIV:
            A_HD_RIV = A_HD_RIV_.sel(layer=l_SFR_Ls).sel(time=slice(start_date, end_date)).compute()
    if load_P:
        DF_meteo_DT_trim = DF_meteo.loc[(DF_meteo['DT'] >= start_date) & (DF_meteo['DT'] <= end_date)]
        A_P = load_P_Aa(DF_meteo_DT_trim, Xmin, Ymin, Xmax, Ymax, d_Pa)

    while True:
        In2 = input(
            "Provide a cell ID (L R C) (with spaces or commas as separators) or a reach number. If you're providing a reach number, prefix it with 'R' (e.g., R15). Type 'E' to quit:\n"
        ).strip()
        # A blank entry (e.g. an accidental Enter) used to fail on int('') with a misleading error; just re-prompt.
        if not In2:
            continue
        try:
            if In2.upper() == 'E':
                break
            elif In2.upper().startswith('R'):
                reach = int(In2.upper().replace('R', ''))
                L, R, C = U.reach_to_cell_id(reach, GDF_SFR)
            else:
                # Split on commas and/or whitespace; drop empty pieces left by stray separators (e.g. "1,2,3,").
                parts = [p for p in re.split(r'[,\s]+', In2) if p]
                L, R, C = [int(j) for j in parts]
                reach = GDF_SFR.loc[(GDF_SFR.k == L) & (GDF_SFR.i == R) & (GDF_SFR.j == C), 'reach'].values[0]
            X, Y = U.reach_to_XY(reach, GDF_SFR)

            SFR_Stg = DF_trim[['time', f'L{L}_R{R}_C{C}']]

            # Constant reference levels at this location, with their plotting style.
            # NOTE(review): the reach is looked up via the 'reach' column above but via 'rno' here —
            # confirm GDF_SFR carries both with the same numbering.
            d_Ct = {
                'SFR riverbed top': {
                    'value': GDF_SFR.loc[GDF_SFR['rno'] == reach, 'rtp'].values[0],
                    'line': dict(color='#ff0000', width=2, dash='dash'),
                },
                'RIV stage': {
                    'value': round(float(UIM.xr_get_value(A_RIV_Stg, X, Y, dx, dy)), 3),
                    'line': dict(color='#0000ff', width=3),
                },
                'RIV bottom': {
                    'value': round(float(UIM.xr_get_value(A_RIV_Btm, X, Y, dx, dy)), 3),
                    'line': dict(color='#0000ff', width=2, dash='dash'),
                },
                'DRN elevation': {
                    'value': round(float(UIM.xr_get_value(A_DRN_Elv, X, Y, dx, dy)), 3),
                    'line': dict(color='#d000ff', width=2, dash='dash'),
                },
                'top': {
                    'value': round(float(UIM.xr_get_value(A_TOP, X, Y, dx, dy, L=L)), 3),
                    'line': dict(color='#a47300', width=2, dash='dash'),
                },
                'bottom': {
                    'value': round(float(UIM.xr_get_value(A_BOT, X, Y, dx, dy, L=L)), 3),
                    'line': dict(color='#a47300', width=2, dash='dash'),
                },
            }

            # Prepare args for plotting function & make Directory for output
            r_info = {'reach': reach, 'L': L, 'R': R, 'C': C, 'X': X, 'Y': Y, 'MdlN': MdlN, 'MdlN_RIV': MdlN_RIV}
            X_axis = SFR_Stg['time']

            Pa_Out = PJ(d_Pa['PoP_Out_MdlN'], f'SFR/SFR_stage_TS-reach{reach}.html')
            os.makedirs(os.path.dirname(Pa_Out), exist_ok=True)

            # NOTE(review): optional series are appended positionally, so disabling load_head shifts HD_RIV / P_TS
            # into earlier slots of plot_SFR_reach_TS — confirm its signature handles that.
            args = [r_info, X_axis, SFR_Stg, d_Ct, Pa_Out]

            # Extract head time series at this location (compute to load from Dask)
            if load_head:
                with TqdmCallback(desc="Loading reach HD"):
                    try:
                        HD_TS = UIM.xr_get_value(A_HD.sel(time=slice(start_date, end_date)), X, Y, dx, dy, L=L)
                    except Exception as e:
                        print(
                            f"L {L} probably contains less than 1% of the SFR reaches, so its HD wasn't loaded.\nAn error occurred while extracting head time series: {e}"
                        )
                        print('Attempting to load full head data for the specified layer...')
                        A_HD_L = A_HD_.sel(layer=L)
                        HD_TS = UIM.xr_get_value(A_HD_L.sel(time=slice(start_date, end_date)), X, Y, dx, dy)
                HD = pd.DataFrame({'time': HD_TS.time.values, 'head': HD_TS.values})
                args.append(HD)

            if load_head_RIV:
                with TqdmCallback(desc="Loading reach HD RIV"):
                    try:
                        HD_TS_RIV = UIM.xr_get_value(A_HD_RIV.sel(time=slice(start_date, end_date)), X, Y, dx, dy, L=L)
                    except Exception as e:
                        # Fallback now lives inside the except (it used to run unconditionally and overwrite
                        # a successfully extracted series).
                        print(
                            f"L {L} probably contains less than 1% of the SFR reaches, so its HD wasn't loaded.\nAn error occurred while extracting head time series: {e}"
                        )
                        print('Attempting to load full head data for the specified layer...')
                        A_HD_L_RIV = A_HD_RIV_.sel(layer=L)
                        HD_TS_RIV = UIM.xr_get_value(A_HD_L_RIV.sel(time=slice(start_date, end_date)), X, Y, dx, dy)
                HD_RIV = pd.DataFrame({'time': HD_TS_RIV.time.values, 'head': HD_TS_RIV.values})
                args.append(HD_RIV)

            if load_P:
                with TqdmCallback(desc="Loading reach P"):
                    P_TS = UIM.xr_get_value((A_P.sel(time=slice(start_date, end_date))), X, Y, dx, dy)
                args.append(P_TS)

            IL.reload(sf)  # pick up edits to sfr_plotting without restarting the kernel
            sf.plot_SFR_reach_TS(*args)
        except Exception as e:
            print(f"An error occurred while processing the input: {e}\nPlease try again with a valid cell ID or reach number.")

Loading HD:   0%|          | 0/7669 [00:00<?, ?it/s]

Loading HD:   0%|          | 0/3286 [00:00<?, ?it/s]

Loading P:   0%|          | 0/2557 [00:00<?, ?it/s]

🟡 - Retrieved value coordinates (X: 114500.0, Y: 394500.0) differ from requested coordinates (X: 114437.5, Y: 394962.5) by more than half the cell size (dx: 25.0, dy: 25.0).
That may be valid if the resolution of the two arrays is different, but you should double-check.
An error occurred while processing the input: invalid literal for int() with base 10: ''
Please try again with a valid cell ID or reach number.
An error occurred while processing the input: invalid literal for int() with base 10: ''
Please try again with a valid cell ID or reach number.
🟡 - Retrieved value coordinates (X: 114500.0, Y: 394500.0) differ from requested coordinates (X: 114437.5, Y: 394962.5) by more than half the cell size (dx: 25.0, dy: 25.0).
That may be valid if the resolution of the two arrays is different, but you should double-check.
An error occurred while processing the input: invalid literal for int() with base 10: ''
Please try again with a valid cell ID or reach number.
An error occurred while 