In [1]:
from pathlib import Path

import torch
import numpy as np
import pandas as pd
import hvplot.pandas

import diffhydro as dh
import diffhydro.pipelines as dhp
import xtensor as xt

from utils.io import load_ono_data, split_and_normalize_data

### Parameters

In [2]:
N = 10                  # Number of times to repeat the experiment
lr = .005               # Initial runoff learning rate
wd = .001               # Runoff (LSTM) weight decay

n_epoch = 10            # Number of training epochs
init_window = 100       # LSTM initialization (warm-up) window length
pred_len = 200          # Prediction length

# Seed the torch and numpy RNGs so runs are repeatable on a fresh kernel
# (some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Preferred GPU for this experiment. Fall back to the default CUDA device, or
# to the CPU, when that GPU does not exist, so the notebook still runs elsewhere.
PREFERRED_DEVICE = "cuda:6"
if torch.cuda.is_available() and torch.cuda.device_count() > torch.device(PREFERRED_DEVICE).index:
    device = PREFERRED_DEVICE
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Impulse-response functions to evaluate. Alternatives previously listed here:
# 'pure_lag', 'linear_storage', 'nash_cascade', 'muskingum'.
irf_fns = ["hayami"]

root = Path("./data")

### Load data

In [3]:
# Load the dataset from `root` via the project helper: inputs, labels, static
# attributes and `g` (presumably the river-network graph — TODO confirm in utils.io).
inp, lbl, static, g = load_ono_data(root)
# Split into train/test sets and normalize, per the helper's name. `lbl_std` is
# presumably the label scale used for normalization — TODO confirm in utils.io.
inp_tr, inp_te, lbl_tr, lbl_te, lbl_std = split_and_normalize_data(inp, lbl)
# `.t()` transposes so the unpacking below iterates over attributes instead of nodes.
# NOTE(review): assumes `static.values` is a 2-D torch tensor whose three columns are
# ordered (basin_area, cat_area, channel_dist) — confirm against load_ono_data.
basin_area, cat_area, channel_dist = static.values.t()

In [4]:
# Train a single joint runoff-routing model with the configured window sizes.
# The runoff learning rate and weight decay come from the parameters cell
# (`lr`, `wd`), so editing them there actually takes effect.
tr_ds = dhp.JointRoutingRunoffDataset(inp_tr, lbl_tr, g,
                                      init_window, pred_len,
                                      cat_area, basin_area,
                                      channel_dist)
val_ds = dhp.JointRoutingRunoffDataset(inp_te, lbl_te, g,
                                       init_window, pred_len,
                                       cat_area, basin_area,
                                       channel_dist)

param_model = dhp.utils.MLP(2, 2)
model = dhp.RunoffRoutingModel(param_model, dt=.25, max_delay=100).to(device)

module = dhp.RunoffRoutingModule(model, tr_ds, val_ds,
                                 batch_size=256,
                                 clip_grad_norm=1,
                                 routing_lr=10**-4,
                                 routing_wd=10**-3,
                                 runoff_lr=lr,
                                 runoff_wd=wd,
                                 scheduler_step_size=None,
                                 scheduler_gamma=.3)

trloss, valloss = module.train(n_epoch, device)
y_te, o_te = module.extract_test(device)

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
# Overlay the training and validation loss curves, each labeled for the legend.
# The y-axis is zoomed to [0.8, 1]; widen or drop `ylim` if the losses fall outside it.
(trloss.hvplot(label="train") * valloss.hvplot(label="validation")).opts(
    ylim=(.8, 1), ylabel="loss", title="Training vs. validation loss")

In [10]:
#module.extract_test??

In [4]:
# Sensitivity sweep over the LSTM init window and the prediction length.
# Every configuration trains a fresh model for `n_epoch` epochs (15 trainings in
# total), so the sweep is disabled by default; set RUN_SWEEP = True to run it.
# Results: res[(init_window, pred_len)] = (trloss, valloss, y_test, o_test).
RUN_SWEEP = False
SWEEP_INIT_WINDOWS = [100, 200, 300]
SWEEP_PRED_LENS = [10, 30, 50, 100, 200]


def _train_sweep_config(sweep_init_window, sweep_pred_len):
    """Train one joint runoff-routing model for a single sweep configuration.

    Reads the notebook globals (inp_tr, lbl_tr, inp_te, lbl_te, g, cat_area,
    basin_area, channel_dist, device, n_epoch, lr, wd). All datasets, models
    and modules stay local, so the sweep does not overwrite the single-run
    globals (`model`, `module`, `trloss`, ...).

    Args:
        sweep_init_window: LSTM initialization window length for this run.
        sweep_pred_len: Prediction length for this run.

    Returns:
        (trloss, valloss, y_test, o_test) as returned by
        RunoffRoutingModule.train and RunoffRoutingModule.extract_test.
    """
    sweep_tr_ds = dhp.JointRoutingRunoffDataset(inp_tr, lbl_tr, g,
                                                sweep_init_window, sweep_pred_len,
                                                cat_area, basin_area,
                                                channel_dist)
    sweep_val_ds = dhp.JointRoutingRunoffDataset(inp_te, lbl_te, g,
                                                 sweep_init_window, sweep_pred_len,
                                                 cat_area, basin_area,
                                                 channel_dist)

    sweep_model = dhp.RunoffRoutingModel(dhp.utils.MLP(2, 2),
                                         dt=.25, max_delay=100).to(device)

    sweep_module = dhp.RunoffRoutingModule(sweep_model, sweep_tr_ds, sweep_val_ds,
                                           batch_size=256,
                                           clip_grad_norm=1,
                                           routing_lr=10**-4,
                                           routing_wd=10**-3,
                                           runoff_lr=lr,
                                           runoff_wd=wd,
                                           scheduler_step_size=None,
                                           scheduler_gamma=.3)
    sweep_trloss, sweep_valloss = sweep_module.train(n_epoch, device)
    y_new, o_new = sweep_module.extract_test(device)
    return sweep_trloss, sweep_valloss, y_new, o_new


res = {}
if RUN_SWEEP:
    # Distinct loop names so the config-cell `init_window` / `pred_len` are not clobbered.
    for sweep_init_window in SWEEP_INIT_WINDOWS:
        for sweep_pred_len in SWEEP_PRED_LENS:
            res[(sweep_init_window, sweep_pred_len)] = _train_sweep_config(
                sweep_init_window, sweep_pred_len)