In [1]:
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import torch

from diffhydro import LTIStagedRouter, StagedCatchmentInterpolator, TimeSeriesThDF
from diffhydro.utils import Timer, nse_fn
from diffhydro.io import read_rapid_graph, read_multiple_rapid_graphs

### Experiment parameters

In [2]:
# Root folder under which the GEOGloWS inputs are read (and downloaded if missing).
# Edit this to point at your own data location.
root: Path = Path("../data") # Set your data root path

In [3]:
# All inputs live in the "geoglows" sub-folder of the data root.
discharge_path = root.joinpath("geoglows", "retro_feather")                 # reference discharge, one feather file per VPU
vpu_config_path = root.joinpath("geoglows", "configs")                      # one configuration entry per VPU
runoff_path = root.joinpath("geoglows", "daily_sparse_runoff.feather")      # runoff table read below
interp_weight_path = root.joinpath("geoglows", "interp_weight.feather")     # interpolation weights, keyed by river_id

In [4]:
# Graph partitioning parameters (forwarded to the RAPID graph readers)
plength_thr = 10_000
node_thr = 10_000
# Routing model parameters
max_delay = 32
dt = 1 / 24  # NOTE(review): presumably the routing time step in days (i.e. one hour) - confirm
# Experiment paths and variables
device = "cuda:0"

### Download data if necessary

If the dataset is not already downloaded, please make sure your root path points to a folder with sufficient disk space.

If you are only interested in running the simulation with diffroute, then please leave the flag `download_gt_as_well = False`.
This will download everything needed to run the model (\~20 GB, \~1 min.), and will not download the original simulation results (\~700 GB).

Set `download_gt_as_well = True` only if you are interested in validating the results against the original data set.
This will additionally download the output river discharge, so make sure your root path points to a location with sufficient storage capacity (>700 GB).

In [None]:
import sys; sys.path.insert(0, "..") # Needed to access the download utilities
from utils.download import download_full_geoglows_data

# Set to True only if you also want the original simulated discharge, which is
# needed solely for the validation section (requires >700 GB of free disk space).
download_gt_as_well = False

# The discharge files are skipped unless the ground truth was explicitly requested.
download_full_geoglows_data(root, exclude_discharge=not download_gt_as_well)

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 754 files:   0%|          | 0/754 [00:00<?, ?it/s]

### Load runoff and interpolation weights

In [5]:
# Runoff table, truncated at the end of 2019 and divided by the number of
# seconds in a day (converts the values to m3 / s).
pixel_runoff = (
    pd.read_feather(runoff_path)
    .loc[:"2019"]
    .div(3600. * 24)
)
# Interpolation weights, indexed by river_id.
cat_interp_df = pd.read_feather(interp_weight_path).set_index("river_id")

### Run the full simulation and compute max Q stat

In [7]:
vpus = [x for x in vpu_config_path.glob("*") if not x.name.startswith(".")]
%time g = read_multiple_rapid_graphs(vpus, plength_thr=plength_thr, node_thr=node_thr).to(device)

  0%|          | 0/125 [00:00<?, ?it/s]

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/6838900 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/6838900 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/104665 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/70044 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/729 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/729 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...


  0%|          | 0/729 [00:00<?, ?it/s]

CPU times: user 6min 21s, sys: 1min 38s, total: 8min
Wall time: 8min 6s


In [None]:
# Wrap the pandas runoff table into a TimeSeriesThDF and move it to the device.
# `pixel_runoff` is rebound in place, so the conversion is guarded: re-running
# this cell after it already succeeded no longer passes a TimeSeriesThDF to
# `from_pandas`.
if isinstance(pixel_runoff, pd.DataFrame):
    pixel_runoff = TimeSeriesThDF.from_pandas(pixel_runoff)
pixel_runoff = pixel_runoff.to(device)

In [None]:
# Catchment interpolator built from the graph, the pixel runoff and the
# interpolation weights.
cat = StagedCatchmentInterpolator(
    g,
    pixel_runoff,
    cat_interp_df,
).to(device)

# Router configured with the experiment parameters defined above.
model = LTIStagedRouter(
    max_delay=max_delay,
    dt=dt,
).to(device)

In [None]:
# Catchment runoffs derived from the pixel runoff; consumed by the routing loop
# in the next cell.
# NOTE(review): the `yield_` prefix suggests a lazy, single-use generator - confirm.
cat_runoffs = cat.yield_all_runoffs(pixel_runoff)

In [None]:
# Run the full routing: iterate over every routed output and discard it.
# Nothing is stored here; the loop only drives the computation (with a progress bar).
for output in tqdm(model.route_all_clusters_yield(cat_runoffs, g)):
    pass

### Validate against the original simulation

In [None]:
# NSE series (indexed by the routed output's columns), one per routed output,
# accumulated over all VPUs.
nses = []

# NOTE: `g`, `cat`, `model` and `cat_runoffs` are rebound inside this loop and
# therefore replace the full-simulation objects created in the cells above.
for vpu in tqdm(sorted(discharge_path.glob("*.feather"))):

    # Reference discharge of the original simulation for this VPU.
    q = pd.read_feather(vpu)
    lbl = TimeSeriesThDF.from_pandas(q).to(device)
    # The VPU configuration shares its name with the discharge file stem.
    g = read_rapid_graph(vpu_config_path / vpu.stem,
                         plength_thr=plength_thr,
                         node_thr=node_thr).to(device)

    # Restrict the interpolation weights and the runoff to this VPU's nodes.
    interp_df = cat_interp_df.loc[g.nodes]
    pix_idxs = interp_df["pixel_idx"].unique()
    runoff = pixel_runoff[pix_idxs]

    cat = StagedCatchmentInterpolator(g, runoff, interp_df).to(device)
    # Same router configuration as in the full simulation; only parameters
    # defined in this notebook are passed.
    model = LTIStagedRouter(max_delay=max_delay, dt=dt).to(device)

    cat_runoffs = cat.yield_all_runoffs(runoff)
    cat_discharges = model.route_all_clusters_yield(cat_runoffs, g,
                                                    display_progress=False)
    for output in cat_discharges:
        # Compare each routed output with the matching reference columns.
        y = lbl[output.columns]
        nse = nse_fn(output.values, y.values)
        nses.append(pd.Series(nse.squeeze().cpu().numpy(), index=output.columns))

In [None]:
# Median NSE over all series collected in the validation loop, as a plain scalar.
(
    pd.concat(nses)
    .median()
    .item()
)