In [1]:
import sys
from pathlib import Path
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import hvplot.pandas

import torch

from diffhydro import (TimeSeriesThDF, CatchmentInterpolator, StagedCatchmentInterpolator,
                       RivTree, RivTreeCluster, nse_fn)
from diffhydro.utils import SimpleTimeSeriesSampler
from diffhydro.pipelines import LearnedRouter, LearningModule

from diffroute.io import _read_rapid_graph
from diffroute.graph_utils import define_schedule

from utils.io import load_vpu

### Configuration

In [2]:
# --- Graph segmentation thresholds (forwarded to `load_vpu`) ---
plength_thr = 10**4
node_thr = 10**4

# --- Routing model ---
max_delay = 30   # passed to LearnedRouter; also reused as the sampler's `init_len` below
dt = 1/24        # passed to LearnedRouter; NOTE(review): time-step units not visible here — confirm

# --- Training schedule ---
# NOTE(review): `epochs`, `n_iter_per_cluster` and `n_clusters` are not referenced by any
# cell of this notebook (the `module.learn` call hard-codes n_iter=20, n_epoch=20) — confirm intent.
epochs = 500
n_iter_per_cluster = 100
n_clusters = 20

# --- Hardware and train/test VPU ids ---
device = "cuda:1"
tr_vpu = "603"
te_vpu = "602"

# --- Sampling windows: presumably `init_len` warm-up steps followed by `pred_len` steps ---
init_len, pred_len = max_delay, 1024

# --- Reproducibility ---
# Training is presumably stochastic (window sampling, parameter init); seed the numpy and
# torch RNGs so Restart & Run All reproduces the same NSE curves.
# torch.manual_seed seeds all devices, including CUDA.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Base directory for the GEOGLOWS input files read in the cells below.
root = Path("..") / "data" / "geoglows"

In [4]:
# Interpolation weights, indexed by the "river_id" column.
interp_df = (
    pd.read_feather(root / "interp_weight.feather")
    .set_index("river_id")
)

# Runoff rows up to and including "2019" (assuming a DatetimeIndex), divided by the
# number of seconds in a day — presumably converting daily totals to per-second
# rates (confirm units) — and wrapped as a TimeSeriesThDF.
runoff = TimeSeriesThDF.from_pandas(
    pd.read_feather(root / "daily_sparse_runoff.feather").loc[:"2019"]
    / (3600. * 24)
)

### Data

In [6]:
# Both VPUs are loaded with the same segmentation thresholds; only the VPU id differs.
vpu_load_kwargs = {"plength_thr": plength_thr, "node_thr": node_thr}

# Training VPU: graph, discharge and runoff on `device`.
tr_g, tr_discharge, tr_runoff = load_vpu(
    root, tr_vpu, runoff, interp_df, device, **vpu_load_kwargs
)
# Held-out test VPU, same outputs.
te_g, te_discharge, te_runoff = load_vpu(
    root, te_vpu, runoff, interp_df, device, **vpu_load_kwargs
)

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/55576 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/55576 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/689 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/625 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/53 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/53 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...


  0%|          | 0/53 [00:00<?, ?it/s]

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/29581 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/29581 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/403 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/329 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/17 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/17 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...


  0%|          | 0/17 [00:00<?, ?it/s]

In [7]:
# Pair each VPU's runoff with its discharge in a sampler; window sizes come from the
# config cell (`init_len` == `max_delay`, `pred_len` == 1024).
tr_data = SimpleTimeSeriesSampler(
    tr_runoff, tr_discharge,
    init_len, pred_len,
)
te_data = SimpleTimeSeriesSampler(
    te_runoff, te_discharge,
    init_len, pred_len,
)

In [8]:
# Learned routing model, configured by `max_delay` and `dt` from the config cell.
model = LearnedRouter(max_delay, dt)

# Bundle the router with the train/test samplers and graphs, then move it to `device`.
module = LearningModule(model, tr_data, te_data, tr_g, te_g)
module = module.to(device)

In [15]:
# Short training run. `learn` returns (test NSE, train NSE), in that order; both are
# plotted with `.hvplot` in the next cell.
# NOTE(review): hard-coded here rather than using `epochs` from the config cell — confirm intent.
n_iter, n_epoch = 20, 20
te_nse, tr_nse = module.learn(n_iter=n_iter, n_epoch=n_epoch)

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/20 [00:00<?, ?it/s]

In [16]:
# Overlay the test and train NSE curves returned by `module.learn`.
# The y-axis is zoomed to [0.99, 1]. NOTE(review): any NSE below 0.99 is clipped out of
# view — widen `ylim` if a curve appears to be missing.
nse_overlay = te_nse.hvplot(label="Test") * tr_nse.hvplot(label="Train")
nse_overlay.opts(ylim=(.99, 1), ylabel="NSE", title="Test vs. train NSE")

### Short run looks OK — next: staged sequential training

In [None]:
# Staged training via `module.learn_staged_sequentially`, called with default arguments.
# NOTE(review): this cell has not been executed yet (In [None]). `n_iter_per_cluster` and
# `n_clusters` from the config cell are not passed here — confirm whether they were meant to be.
results = module.learn_staged_sequentially()