### Ideas: what we still need to do

In [1]:
#from imports import *
import sys
import warnings
from pathlib import Path

import pandas as pd
import torch
from tqdm.auto import tqdm

sys.path.insert(0, "./DiffHydro")
sys.path.insert(0, "./DiffRoute")

from diffhydro import (TimeSeriesThDF, RivTree, RivTreeCluster,
                       StagedCatchmentInterpolator, CatchmentInterpolator) 
from diffhydro.pipelines import CalibrationRouter
from diffhydro.utils import nse_fn

from diffroute.io import read_rapid_graph

### New experiment: configuration

In [2]:
# Graph-segmentation thresholds forwarded to init_calib_exp
# (their exact semantics are defined there — TODO document).
plength_thr = 10**4
node_thr = 10**4

# Routing model settings.
model_name = "muskingum"
time_window = max_delay = 30  # max_delay is passed to CalibrationRouter
dt = 1/24  # presumably one hour expressed in days — TODO confirm

# Optimization settings.
epochs = 500  # NOTE(review): not used by the visible cells
n_iter = 100  # optimizer steps per training loop

# Preferred GPU. Fall back to CPU when this machine does not have that device,
# so Restart & Run All still works elsewhere (training will be much slower).
device = "cuda:6"
if torch.cuda.device_count() <= torch.device(device).index:
    warnings.warn(f"{device} is not available; falling back to CPU.")
    device = "cpu"

vpu = "604"

In [4]:
# NOTE(review): init_calib_exp is not imported by any visible cell — presumably it
# came from the commented-out `from imports import *`; TODO import it explicitly.
# The six results are unpacked by position; names suggest graph, interpolation
# table, and train/test runoff/discharge splits — TODO confirm against its definition.
# "1980" is passed through unchanged (meaning defined by init_calib_exp).
(
    g,
    interp_df,
    tr_runoff_pix,
    te_runoff_pix,
    tr_discharge,
    te_discharge,
) = init_calib_exp(vpu, "1980", plength_thr, node_thr, device)

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/32717 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/32717 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/416 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/381 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/56 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/56 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...


  0%|          | 0/56 [00:00<?, ?it/s]

### Small VPUs can be optimized in a single forward pass

In [None]:
# Interpolate the pixel runoff of both splits once, up front, with a single
# (non-staged) interpolator; the router is then calibrated with Adam.
ci = CatchmentInterpolator(g, tr_runoff_pix, interp_df)
ci = ci.to(device)
tr_runoff, te_runoff = (ci.interpolate_runoff(pix) for pix in (tr_runoff_pix, te_runoff_pix))

# Targets are the observed discharges, used as-is.
tr_y, te_y = tr_discharge, te_discharge

model = CalibrationRouter(g, max_delay, dt)
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=0.1)

In [None]:
# Anomaly detection records extra information for every autograd op and slows
# training considerably; switch it on only while hunting NaN/inf gradients.
detect_anomaly = False

with torch.autograd.set_detect_anomaly(detect_anomaly):
    pbar = tqdm(range(n_iter), desc="Training")

    for _ in pbar:
        # Training step: loss = 1 - mean NSE, so minimizing it maximizes NSE.
        out = model(tr_runoff)
        tr_nse = nse_fn(out.values, tr_y.values).mean()

        opt.zero_grad()
        loss = 1 - tr_nse
        loss.backward()
        opt.step()

        # Held-out evaluation needs no gradients; skip building an autograd graph
        # (which would otherwise be allocated and discarded every iteration).
        with torch.no_grad():
            out = model(te_runoff)
            te_nse = nse_fn(out.values, te_y.values).mean()

        pbar.set_postfix({"Tr NSE": tr_nse.item(), "Te NSE": te_nse.item()})

### Large graphs must be optimized sequentially, one cluster at a time

In [5]:
# Staged variant: the interpolator is queried per cluster
# (interpolate_runoff(..., cluster_idx)) by the sequential training loop.
ci = StagedCatchmentInterpolator(g, tr_runoff_pix, interp_df)
ci = ci.to(device)

model = CalibrationRouter(g, max_delay, dt)
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=0.1)

  0%|          | 0/56 [00:00<?, ?it/s]

In [7]:
for cluster_idx in tqdm(range(len(model.params))):
    nodes = g[cluster_idx].nodes    
    pbar = tqdm(range(n_iter), desc="Training")

    with torch.no_grad():
        tr_transfer_bucket = model.init_upstream_discharges(ci.yield_all_runoffs(tr_runoff_pix), cluster_idx)
        te_transfer_bucket = model.init_upstream_discharges(ci.yield_all_runoffs(te_runoff_pix), cluster_idx)
    
    tr_runoff = ci.interpolate_runoff(tr_runoff_pix, cluster_idx)
    te_runoff = ci.interpolate_runoff(te_runoff_pix, cluster_idx)
    
    assert pd.Series(nodes).isin(tr_discharge.columns).all()
    assert pd.Series(nodes).isin(te_discharge.columns).all()
    
    tr_y = tr_discharge[nodes]
    te_y = te_discharge[nodes]

    for _ in pbar:
        out = model.process_one_cluster(tr_runoff, cluster_idx, 
                                        tr_transfer_bucket)
        tr_nse = nse_fn(out.values, tr_y.values).mean()
        
        opt.zero_grad()
        loss = 1-tr_nse
        loss.backward()
        opt.step()
        
        out = model.process_one_cluster(te_runoff, cluster_idx, 
                                        te_transfer_bucket)
        te_nse = nse_fn(out.values, te_y.values).mean()
        
        pbar.set_postfix({"Tr NSE:": tr_nse.item(), "Te NSE":te_nse.item()})

  0%|          | 0/56 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
class RoutingDataset:
    """Sliding-window dataset pairing runoff inputs with discharge labels.

    Item ``idx`` is the window running from the ``idx``-th to the
    ``(idx + seq_len)``-th label of the shared time index, taken from both
    containers with ``obj[:, window]`` indexing.
    """

    def __init__(self, runoff_inp, discharge_lbl, seq_len):
        """
        Args:
            runoff_inp: time-indexed runoff container exposing ``.index`` and
                supporting ``obj[:, label_slice]`` (presumably a TimeSeriesThDF —
                TODO confirm).
            discharge_lbl: discharge container with the same ``.index``.
            seq_len: number of index steps spanned by each window; must be >= 1.

        Raises:
            AssertionError: if the two indices differ.
            ValueError: if ``seq_len < 1``.
        """
        # Check lengths first: element-wise == on indices of different lengths raises
        # an unrelated ValueError instead of a clear message.
        assert (len(runoff_inp.index) == len(discharge_lbl.index)
                and (runoff_inp.index == discharge_lbl.index).all()), \
            "runoff_inp and discharge_lbl must share the same index"
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.runoff_inp = runoff_inp
        self.discharge_lbl = discharge_lbl
        self.seq_len = seq_len

    def __getitem__(self, idx):
        """Return the ``(runoff, discharge)`` window starting at position ``idx``.

        Negative ``idx`` counts from the end. Out-of-range positions raise
        IndexError, which also lets plain iteration over the dataset terminate.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        time_index = self.runoff_inp.index
        # NOTE(review): label slicing is inclusive in pandas; if the container behaves
        # the same, each window holds seq_len + 1 steps — confirm intended length.
        window = slice(time_index[pos], time_index[pos + self.seq_len])
        return self.runoff_inp[:, window], self.discharge_lbl[:, window]

    def __len__(self):
        """Number of valid window start positions (0 when seq_len exceeds the series)."""
        return max(0, len(self.discharge_lbl.index) - self.seq_len)

    def random_sequence(self, n_samples):
        """Placeholder for sampling ``n_samples`` random windows (not implemented yet)."""
        raise NotImplementedError("RoutingDataset.random_sequence is not implemented yet")