In [25]:
import time
from collections import defaultdict
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import pandas as pd

from diffroute import LTIRouter, get_node_idxs, read_params

from src.schedule import define_schedule
from src.cat_interp import CatchmentInterpolator
from src.geoglow_io import extract_all_graphs

In [2]:
# Be sure to set the correct download path in src/config
from src.config import DATA_ROOT as root

In [3]:
def run_sim(clusters_g, node_transfer, cat):
    """Route runoff through every cluster, in order, and gather the outputs.

    Clusters are processed in the order given by ``clusters_g``. After a
    cluster has been routed, the columns of its output listed in
    ``node_transfer`` are stored and later added to the input of the
    receiving cluster before that cluster is routed.

    Besides its arguments, this function reads the following notebook-level
    globals: ``model_name``, ``device``, ``time_window``, ``block_size``,
    ``irf_agg``, ``index_precomp``, ``runoff_to_output``, ``dt``,
    ``sample_mode``, ``runoff`` (its index labels the output columns before
    the transpose) and ``q`` (its index selects the returned rows).

    Parameters
    ----------
    clusters_g : iterable
        Cluster graphs; each one is handed to ``get_node_idxs``,
        ``read_params`` and ``LTIRouter``.
    node_transfer : mapping
        ``node_transfer[i]`` yields ``(c_idx, e_src, e_dst)`` tuples: column
        ``e_src`` of cluster ``i``'s output is added to column ``e_dst`` of
        cluster ``c_idx``'s input.
    cat : object
        Provides ``read_catchment(i)``, the input tensor of cluster ``i``.

    Returns
    -------
    pandas.DataFrame
        Per-cluster outputs concatenated column-wise, restricted to the rows
        in ``q.index`` and divided by 86400.0
        (presumably a per-day -> per-second conversion; confirm units).
    """
    outputs = []
    # Inflows waiting to be injected into each cluster, keyed by cluster position.
    transfered_inputs = {i: [] for i, _ in enumerate(clusters_g)}

    for i, g in enumerate(tqdm(clusters_g)):
        nodes_idx = get_node_idxs(g)

        params = (
            read_params(g, model_name, nodes_idx)
            .float()
            .to(device, non_blocking=True)
        )

        model = LTIRouter(
            g,
            nodes_idx=nodes_idx,
            max_delay=time_window,
            block_size=block_size,
            irf_fn=model_name,
            irf_agg=irf_agg,
            index_precomp=index_precomp,
            runoff_to_output=runoff_to_output,
            dt=dt,
            sampling_mode=sample_mode,
        ).to(device)
        model.aggregator.init_buffers(device)

        x = cat.read_catchment(i)

        # Add the discharge handed over by previously routed clusters.
        for e_dst, inp_dis in transfered_inputs[i]:
            x[:, e_dst] += inp_dis.squeeze()

        out = model(x.to(device, non_blocking=True), params)

        # Keep a detached copy of every column another cluster will consume.
        for (c_idx, e_src, e_dst) in node_transfer[i]:
            transfered_inputs[c_idx].append((e_dst, out[:, e_src].clone().detach()))

        output = pd.DataFrame(out.cpu().squeeze(), index=nodes_idx.index, columns=runoff.index).T
        outputs.append(output)

    outputs = pd.concat(outputs, axis=1).loc[q.index] / 86400.0
    return outputs

### Experiment

In [4]:
# Only the parameters that matter for this experiment are kept here.

# Router configuration (forwarded to LTIRouter / read_params in the simulation functions).
model_name = "muskingum"
irf_agg = "log_triton"
index_precomp = "cpu"
sample_mode = "avg"
runoff_to_output = False

# Numerical settings: `time_window` is passed to LTIRouter as `max_delay`.
time_window = 30
dt = 1 / 24
block_size = 16

# Thresholds handed to define_schedule when splitting the graph into clusters.
plength_thr = 100_000  # Maximum depth of 100,000 paths per cluster
node_thr = 10_000      # Maximum width of 10,000 nodes per cluster

# NOTE(review): hard-coded GPU index; adjust to the machine running the notebook.
device = "cuda:5"

In [6]:
# Input locations, all resolved relative to the configured data root.
config_root = root / 'configs'
root_discharge = root / "retro_feather"
runoff = pd.read_feather(root / "daily_sparse_runoff.feather")
interp_df = pd.read_feather(root / "interp_weight.feather")
# VPU identifiers are the stems of the per-VPU discharge files. Only "*.feather"
# entries are considered, because the evaluation loop reads "<vpu>.feather", and
# the list is sorted so the processing order does not depend on the filesystem.
vpus = sorted(p.stem for p in root_discharge.glob("*.feather"))

### First numerical eval.

- Here, we compute vpu by vpu to facilitate evaluation.
- It takes a lot of time as we have a lot of redundant IO, but it is still faster than doing the global evaluation in-memory because:
    - With global evaluation we would need to manipulate pandas DataFrames of >1 TB, which blows up memory.
    - Even variance computations on half of the dataset take hours.
- Hence, we evaluate per VPU here and do the fast timing with the global execution below; this is simpler and the results are the same.
- We should find a solution to unify the two.

In [None]:
# One (nse, mae, qmean, qvar) tuple of per-column statistics is collected per VPU.
res = []

for vpu in tqdm(vpus):
    # Build the graph for this VPU only and divide each node's "k" by 86,400
    # (3600 * 24; presumably a seconds -> days conversion, confirm units).
    G = extract_all_graphs(config_root, [vpu])
    for n in G.nodes:
        G.nodes[n]["k"] /= 86_400

    clusters_g, node_transfer = define_schedule(
        G,
        plength_thr=plength_thr,
        node_thr=node_thr,
    )
    cat = CatchmentInterpolator(clusters_g, runoff, interp_df, device=device)

    # Reference discharges; `q` is also read as a global by run_sim.
    print("reading river discharges...")
    q = pd.read_feather(root_discharge / f"{vpu}.feather")
    qmean, qvar = q.mean(), q.var()

    print("Running model...")
    out = run_sim(clusters_g, node_transfer, cat)

    # Per-column scores from the simulated-minus-reference residual.
    residual = out - q
    mae = residual.abs().mean()
    nse = 1 - residual.pow(2).mean() / qvar
    del residual  # do not keep the large intermediate frame alive

    res.append((nse, mae, qmean, qvar))

In [24]:
# Transpose the list of per-VPU tuples into four groups and concatenate each
# group into a single Series; the last expression displays the median NSE.
nse, mae, qmean, qvar = (pd.concat(parts) for parts in zip(*res))
nse.median()

np.float32(0.9996492)

### Second, global timing

In [9]:
# Module-level alias: blocks the host until the work queued on a CUDA device is done.
sync = torch.cuda.synchronize


def run_sim_timed(clusters_g, node_transfer, cat):
    """Run the cluster-by-cluster simulation and report cumulative time per step.

    The routing logic mirrors ``run_sim``; each stage is wrapped in a
    ``time.perf_counter`` measurement. CUDA work is queued asynchronously, so
    every stage that may touch device tensors is followed by ``sync(device)``;
    otherwise its cost would be charged to the next synchronised stage.

    Steps:
        1. ``get_node_idxs``             5. add transferred inflows to the input
        2. ``read_params`` + device copy 6. forward pass of the router
        3. ``LTIRouter`` construction    7. store columns needed downstream
        4. ``cat.read_catchment``        8. copy output to CPU / build DataFrame

    Reads the notebook-level globals ``model_name``, ``device``,
    ``time_window``, ``block_size``, ``irf_agg``, ``index_precomp``,
    ``runoff_to_output``, ``dt``, ``sample_mode`` and ``runoff``.

    Parameters
    ----------
    clusters_g : iterable
        Cluster graphs, processed in order.
    node_transfer : mapping
        ``node_transfer[i]`` yields ``(c_idx, e_src, e_dst)`` tuples describing
        which output column of cluster ``i`` feeds which input column of
        cluster ``c_idx``.
    cat : object
        Provides ``read_catchment(i)``, the input tensor of cluster ``i``.

    Returns
    -------
    tuple
        ``(model, step_time, outputs)``: the router built for the last cluster
        (``None`` when ``clusters_g`` is empty), a ``defaultdict`` mapping step
        number to cumulative seconds, and the list of per-cluster DataFrames.
    """
    step_time = defaultdict(float)
    outputs = []
    transfered_inputs = {i: [] for i, _ in enumerate(clusters_g)}
    model = None  # remains None if there is nothing to simulate

    for i, g in enumerate(tqdm(clusters_g)):

        # -- Step 1: node indices (CPU only) -----------------------
        t0 = time.perf_counter()
        nodes_idx = get_node_idxs(g)
        step_time[1] += time.perf_counter() - t0

        # -- Step 2: parameters + host-to-device copy --------------
        t0 = time.perf_counter()
        params = (
            read_params(g, model_name, nodes_idx)
            .float()
            .to(device, non_blocking=True)
        )
        sync(device)                                            # wait for copy
        step_time[2] += time.perf_counter() - t0

        # -- Step 3: model construction ----------------------------
        t0 = time.perf_counter()
        model = LTIRouter(
            g,
            nodes_idx=nodes_idx,
            max_delay=time_window,
            block_size=block_size,
            irf_fn=model_name,
            irf_agg=irf_agg,
            index_precomp=index_precomp,
            runoff_to_output=runoff_to_output,
            dt=dt,
            sampling_mode=sample_mode,
        ).to(device)
        model.aggregator.init_buffers(device)
        sync(device)
        step_time[3] += time.perf_counter() - t0

        # -- Step 4: catchment input -------------------------------
        t0 = time.perf_counter()
        x = cat.read_catchment(i)
        # The interpolator may queue device work (assumption: it was built
        # with a device); synchronise so that work is timed here.
        sync(device)
        step_time[4] += time.perf_counter() - t0

        # -- Step 5: inject inflows from upstream clusters ---------
        t0 = time.perf_counter()
        for e_dst, inp_dis in transfered_inputs[i]:
            x[:, e_dst] += inp_dis.squeeze()
        sync(device)                        # inflows are slices of device outputs
        step_time[5] += time.perf_counter() - t0

        # -- Step 6: routing ---------------------------------------
        t0 = time.perf_counter()
        out = model(x.to(device, non_blocking=True), params)
        sync(device)                                            # wait for kernel
        step_time[6] += time.perf_counter() - t0

        # -- Step 7: keep columns needed by downstream clusters ----
        t0 = time.perf_counter()
        for (c_idx, e_src, e_dst) in node_transfer[i]:
            transfered_inputs[c_idx].append((e_dst, out[:, e_src].clone().detach()))
        sync(device)                        # clone() of a device tensor is async
        step_time[7] += time.perf_counter() - t0

        # -- Step 8: device-to-host copy and DataFrame formatting --
        t0 = time.perf_counter()
        output = pd.DataFrame(out.cpu().squeeze(), index=nodes_idx.index, columns=runoff.index).T
        outputs.append(output)
        step_time[8] += time.perf_counter() - t0

    # -- Final report ----------------------------------------------
    total = sum(step_time.values())
    print("\n=== cumulative run‑time by step (s) ===")
    for k in range(1, 9):
        # Guard against a zero total (e.g. no clusters were processed).
        share = step_time[k] / total if total else 0.0
        print(f"Step {k}: {step_time[k]:8.3f} s  ({share:5.1%})")
    print(f"Total: {total:8.3f} s")

    return model, step_time, outputs

In [13]:
# Load full global dataset
# NOTE(review): `load_geoflow` is not imported in any cell shown in this notebook;
# on a fresh kernel this cell raises NameError unless it is imported first.
# NOTE: this rebinds G, runoff, interp_df, clusters_g, node_transfer and cat,
# replacing the per-VPU objects created by the evaluation cells above.
G, runoff, interp_df = load_geoflow(vpu_numbers=None)
clusters_g, node_transfer = define_schedule(G, plength_thr=plength_thr, node_thr=node_thr)
cat = CatchmentInterpolator(clusters_g, runoff, interp_df, device=device)

Loading runoffs...


  0%|          | 0/125 [00:00<?, ?it/s]

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/6838900 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/6838900 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/104665 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/70044 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/729 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/729 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...
#### Cluster Annotations ... ####


  0%|          | 0/729 [00:00<?, ?it/s]

In [14]:
# The first execution is slower, probably because of Triton JIT compilation overhead.
# The returned (model, step_time, outputs) tuple is displayed as the cell result.
run_sim_timed(clusters_g, node_transfer, cat)

  0%|          | 0/729 [00:00<?, ?it/s]


=== cumulative run‑time by step (s) ===
Step 1:   10.809 s  ( 2.8%)
Step 2:   49.479 s  (12.6%)
Step 3:  148.009 s  (37.8%)
Step 4:    0.148 s  ( 0.0%)
Step 5:    0.000 s  ( 0.0%)
Step 6:  182.918 s  (46.7%)
Step 7:    0.000 s  ( 0.0%)
Step 8:    0.000 s  ( 0.0%)
Total:  391.364 s


(LTIRouter(
   (aggregator): RoutingIRFAggregator()
   (conv): BlockSparseCausalConv()
 ),
 defaultdict(float,
             {1: 10.808951897546649,
              2: 49.47939977608621,
              3: 148.00921990536153,
              4: 0.14795606583356857,
              6: 182.91803489252925,
              5: 0.0,
              7: 0.0,
              8: 0.0}),
 [])

In [15]:
# The second execution is fast.
# Step 1 computes node_idxs, Step 2 reads the parameters, Step 3 initializes the model,
# Step 4 interpolates the catchments, and Step 6 is the routing itself.
run_sim_timed(clusters_g, node_transfer, cat)

  0%|          | 0/729 [00:00<?, ?it/s]


=== cumulative run‑time by step (s) ===
Step 1:    9.510 s  ( 4.0%)
Step 2:   22.298 s  ( 9.4%)
Step 3:  131.652 s  (55.4%)
Step 4:    0.123 s  ( 0.1%)
Step 5:    0.000 s  ( 0.0%)
Step 6:   74.030 s  (31.2%)
Step 7:    0.000 s  ( 0.0%)
Step 8:    0.000 s  ( 0.0%)
Total:  237.613 s


(LTIRouter(
   (aggregator): RoutingIRFAggregator()
   (conv): BlockSparseCausalConv()
 ),
 defaultdict(float,
             {1: 9.509653015062213,
              2: 22.298364767804742,
              3: 131.65237718820572,
              4: 0.12266553193330765,
              6: 74.03042277693748,
              5: 0.0,
              7: 0.0,
              8: 0.0}),
 [])

# Comments on speed:
- Routing itself is fast.
- When saving everything to CPU, time is dominated by CPU manipulation of the output (step 8).
- We should optimize this by overlapping IO with compute, and probably by having a CPU-side process do the formatting, with a pinned-memory tensor to receive the data.
- But whether this effort is worthwhile depends on the needs of applications.
- Hence, please reach out if you use this code so we can identify the needs.