In [2]:
from pathlib import Path
from tqdm import tqdm

import pandas as pd
import geopandas as gpd
import xarray as xr
import torch

from src.cat_interp import CatchmentInterpolator
from src.schedule import define_schedule
from src.models import MultiStageLearnedMLP
from src.utils import nse_loss_fn

In [3]:
from src.config import DATA_ROOT as root

### Helpers

In [4]:
from src.geoglow_io import *
from src.schedule import define_schedule
from src.cat_interp import CatchmentInterpolator
from diffroute import get_node_idxs

def load_vpu(vpu, device):
    """Load everything needed to build a routing model for one VPU.

    Reads the discharge table and the stream-attribute table for ``vpu`` from
    feather files under ``root``, loads the graph/runoff/interpolation data via
    ``load_geoflow``, and builds the cluster schedule and catchment
    interpolator.

    Args:
        vpu: VPU identifier; interpolated into the feather file names and
            passed to ``load_geoflow`` as a one-element list.
        device: Forwarded to ``CatchmentInterpolator`` as ``device``.

    Returns:
        Tuple ``(g, clusters_g, node_transfer, cat, discharges)``.
    """
    retro_path = root / "retro_feather" / f"{vpu}.feather"
    streams_path = root / "tdxhydro_feather" / f"streams_{vpu}.feather"
    all_discharges = pd.read_feather(retro_path)
    gdf = pd.read_feather(streams_path)

    g, runoff, interp_df = load_geoflow(vpu_numbers=[vpu])
    # Return value is discarded; presumably annotates ``g`` in place — TODO confirm.
    annotate_graph_with_physical_prop(g, gdf)
    node_labels = get_node_idxs(g).index

    # Select runoff rows by the discharge table's index labels, and keep only
    # the discharge columns that correspond to nodes of the graph.
    runoff = runoff.loc[all_discharges.index]
    discharges = all_discharges[node_labels]

    clusters_g, node_transfer = define_schedule(
        g,
        plength_thr=10**4,
        node_thr=20000,
    )
    cat = CatchmentInterpolator(clusters_g, runoff, interp_df, device=device)
    return g, clusters_g, node_transfer, cat, discharges

def test(ms, discharges, N):
    """Run the model without gradients and score each output with per-node NSE.

    For every output returned by ``ms.forward(N)``, the prediction is put in a
    DataFrame whose columns are the labels in ``ms.node_idxs[i].index`` and
    compared against the same columns of ``discharges``.

    Args:
        ms: Model exposing ``forward(N)`` and ``node_idxs``.
        discharges: DataFrame of reference values with one column per node
            label. Its index is dropped, so rows are matched to the
            predictions by position (assumes same length and ordering —
            TODO confirm).
        N: Forwarded unchanged to ``ms.forward``.

    Returns:
        Tuple ``(nses, df_out, df_y)`` of three equally long lists with one
        entry per output: the per-column NSE Series, the prediction frame and
        the reference frame.
    """
    df_out = []
    df_y = []
    nses = []

    with torch.no_grad():
        outputs = ms.forward(N)
        for i, output in enumerate(tqdm(outputs)):
            # Convert explicitly to NumPy rather than relying on pandas to
            # coerce a torch tensor.
            # assumes squeeze() leaves a 2-D (nodes, time) array — TODO confirm
            values = output.detach().cpu().squeeze().numpy()
            out = pd.DataFrame(values, index=ms.node_idxs[i].index).T
            y = discharges[out.columns].reset_index(drop=True)

            # NSE = 1 - MSE / Var(obs). ``.mean()`` normalises by n, so the
            # variance must use ddof=0 as well; pandas' default ddof=1 would
            # scale the ratio by (n-1)/n and overstate the score.
            nse = 1 - (((out - y) ** 2).mean() / y.var(ddof=0))

            nses.append(nse)
            df_out.append(out)
            df_y.append(y)
    return nses, df_out, df_y

#### Parameters

In [5]:
# --- Run configuration -------------------------------------------------------
# model_name, max_delay, dt and epochs are not referenced in the cells shown
# here; presumably consumed elsewhere — TODO confirm.
model_name = "muskingum"
max_delay = 50
dt = 1 / 24
epochs = 1000
n_iter = 100  # optimisation steps run for each N in the training loop

# --- Devices -----------------------------------------------------------------
# GPU memory gets full so test on another GPU
device = "cuda:3"
test_device = "cuda:4"

#### Init

In [None]:
# VPU 603 is loaded onto the training device, VPU 602 onto the test device.
(g, clusters_g, node_transfer, cat, discharges) = load_vpu(
    vpu="603", device=device
)
(g2, clusters_g2, node_transfer2, cat2, discharges2) = load_vpu(
    vpu="602", device=test_device
)

In [7]:
# Important: both models must normalise physical properties with the same std,
# so the stds read from the training graph ``g`` are reused for ``ms2``.
stds = read_physical_prop(g)[-1]

# Model on the training device.
ms = MultiStageLearnedMLP(
    clusters_g, node_transfer, cat, param_stds=stds
).to(device)
ms.init_layers_buffers(device)

# Second model, built from the "602" data, on the test device.
ms2 = MultiStageLearnedMLP(
    clusters_g2, node_transfer2, cat2, param_stds=stds
).to(test_device)
ms2.init_layers_buffers(test_device)

opt = torch.optim.Adam(ms.parameters(), lr=1e-4)

# One target tensor per entry of ms.node_idxs: the matching discharge columns,
# transposed, with a leading singleton dimension, moved to the training device.
ys = [
    torch.from_numpy(discharges[nodes_idx.index].to_numpy().T).unsqueeze(0).to(device)
    for nodes_idx in tqdm(ms.node_idxs)
]

  0%|          | 0/52 [00:00<?, ?it/s]
[A [00:00, ?it/s]
  2%|▏         | 1/52 [00:00<00:06,  7.34it/s]
100%|██████████| 52/52 [00:00<00:00, 150.30it/s]
52it [00:00, 150.55it/s]
  0%|          | 0/17 [00:00<?, ?it/s]
[A [00:00, ?it/s]
100%|██████████| 17/17 [00:00<00:00, 122.08it/s]
17it [00:00, 123.67it/s]


  0%|          | 0/52 [00:00<?, ?it/s]

#### Training

In [None]:
all_losses = []  # loss value of every optimisation step, across all N
res = []         # one (train NSEs, test NSEs) pair per value of N

# NOTE(review): the upper bound 18 is hard-coded — confirm it matches the
# intended number of stages.
for N in tqdm(range(1, 18)):
    # nn.Module.to() moves a module in place, and the end of the previous pass
    # sent ms.mlp to test_device; bring it back before training.
    ms.mlp.to(device)

    for step in tqdm(range(n_iter)):
        outputs = ms.forward(N)

        # Average the loss over the first N outputs and their targets.
        losses = []
        for out, y in zip(outputs[:N], ys[:N]):
            losses.append(nse_loss_fn(out.squeeze(), y.squeeze()))
        loss = sum(losses) / len(losses)

        loss.backward()
        opt.step()
        opt.zero_grad()

        loss_value = loss.item()
        all_losses.append(loss_value)
        print(loss_value)

    # Score the training model (evaluated with N=1 only).
    nses, df_out, df_y = test(ms, discharges, N=1)
    print([x.mean() for x in nses])
    print([x.median() for x in nses])

    # Hand the trained MLP to every router of the second model, then score it.
    for router in ms2.routers:
        router.mlp = ms.mlp.to(test_device)
    tnses, df_out, df_y = test(ms2, discharges2, N=1)
    print([nse.median() for nse in tnses])
    print([nse.mean() for nse in tnses])

    res.append((nses, tnses))

  0%|          | 0/100 [00:00<?, ?it/s]

2.242368221282959
2.1881749629974365
2.133002519607544
2.0766680240631104
2.019174814224243
1.9604743719100952
1.9006462097167969
1.8397632837295532
1.7778904438018799
1.7151025533676147
1.6514334678649902
1.5869015455245972
1.5216141939163208
1.4556466341018677
1.3890966176986694
1.322152853012085
1.254896879196167
1.1875550746917725
1.1204158067703247
1.053498387336731
0.9869477152824402
0.920948326587677
0.8557178378105164
0.7914872765541077
0.7285439968109131
0.667191207408905
0.6077373623847961
0.5504348278045654
0.49557480216026306
0.4433974325656891
0.3942476511001587
0.348408967256546
0.3060794472694397
0.26745766401290894
0.23268622159957886
0.20182795822620392
0.17487025260925293
0.15172916650772095
0.13227182626724243
0.1163206398487091
0.10362468659877777
0.0939001590013504
0.08682490885257721
0.0820600837469101
0.07924021035432816
0.07799863070249557
0.07798223942518234
0.0788537785410881
0.08031033724546432
0.08208249509334564
0.08394575119018555
0.08571714162826538
0.087

  0%|          | 0/1 [00:00<?, ?it/s]

[0.9944226]
[0.9974633]


  0%|          | 0/1 [00:00<?, ?it/s]

[0.99803555]
[0.9956648]


  0%|          | 0/100 [00:00<?, ?it/s]

0.18471604585647583
0.18344275653362274
0.18086251616477966
0.17720115184783936
0.17271052300930023
0.16765066981315613
0.1622563600540161
0.15674728155136108
0.1512976437807083
0.1460166573524475
0.14104968309402466
0.13649238646030426
0.13237354159355164
0.12871424853801727
0.12548577785491943
0.12265321612358093
0.12020111083984375
0.11808908730745316
0.11629696190357208
0.1147976964712143
0.11344299465417862
0.11211833357810974
0.11077067255973816
0.10939139127731323
0.10798883438110352
0.10658667236566544
0.10522255301475525
0.10390713065862656
0.10267352312803268
0.1015295535326004
0.10048311948776245
0.0995110496878624
0.09858889877796173
0.09768825024366379
0.09678944945335388
0.09586770832538605
0.09490355849266052
0.09388592094182968
0.0928037017583847
0.09165870398283005
0.09045694023370743
0.08920923620462418
0.08792690932750702
0.08662942796945572
0.08532685041427612
0.08403126895427704
0.08275626599788666
0.08150176703929901
0.08027857542037964
0.07908371090888977
0.07792

  0%|          | 0/1 [00:00<?, ?it/s]

[0.9970746]
[0.99871135]


  0%|          | 0/1 [00:00<?, ?it/s]

[0.9989303]
[0.997661]


  0%|          | 0/100 [00:00<?, ?it/s]

0.06627922505140305
0.06455418467521667
0.062437690794467926
0.06024621054530144
0.05822639912366867
0.056534670293331146
0.05521848797798157
0.05423600599169731
0.05350290238857269
0.052943915128707886
0.05251047760248184
0.052189670503139496
0.05198858305811882
0.05190981552004814
0.05191725119948387
0.05195818468928337
0.0519697368144989
0.05187736451625824
0.05163673684000969
0.051239024847745895
0.05069928616285324
0.05006356164813042
0.04939868301153183
0.04875871539115906
0.048166386783123016
0.04766044020652771
0.047235384583473206
0.046900954097509384
0.04663205146789551
0.04643293842673302
0.04628543555736542
0.04618849605321884
0.04612807184457779
0.04608898609876633
0.04605947062373161
0.04600213095545769
0.04589974880218506
0.04575556516647339
0.04555533453822136
0.04532461613416672
0.04506158083677292
0.04480241984128952
0.04455024003982544
0.04431487247347832
0.04411432147026062
0.04393746703863144
0.04379257559776306
0.04366854950785637
0.043575190007686615
0.0434959344

  0%|          | 0/1 [00:00<?, ?it/s]

[0.9976853]
[0.99909943]


  0%|          | 0/1 [00:00<?, ?it/s]

[0.99920183]
[0.99797916]


  0%|          | 0/100 [00:00<?, ?it/s]

0.05527282506227493
0.05472759157419205
0.05416979640722275
0.05393435060977936
0.05407347530126572
0.0543920174241066
0.05456317961215973
0.05434667319059372
0.05372242629528046
0.05285327136516571
0.051970697939395905
0.051217835396528244
0.05068817734718323
0.05028275400400162
0.04993029683828354
0.049599893391132355
0.04932345449924469
0.049147896468639374
0.049111492931842804
0.049163758754730225
0.04916901886463165
0.04908202588558197
0.048824138939380646
0.04843898490071297
0.047977373003959656
0.04752689599990845
0.047136805951595306
0.04682953283190727
0.04659903421998024
0.04643332213163376
0.04633459448814392
0.04629822075366974
0.04629889875650406
0.04632154852151871
0.04632954299449921
0.046267785131931305
0.046133853495121
0.04593333601951599
0.04569820687174797
0.04545566439628601
0.04523685947060585
0.04505433887243271
0.04490027576684952
0.04478500783443451
0.04470982775092125
0.044677525758743286
0.04463754594326019
0.04462286829948425
0.04459156095981598
0.0445092394

  0%|          | 0/1 [00:00<?, ?it/s]

[0.9980698]
[0.9992689]


  0%|          | 0/1 [00:00<?, ?it/s]

[0.99932474]
[0.9982877]


  0%|          | 0/100 [00:00<?, ?it/s]

0.051817383617162704


### Eval

In [9]:
# Share the trained MLP with every router of the second model.
for router in ms2.routers:
    router.mlp = ms.mlp.to(test_device)

# NOTE(review): ``N`` is the loop variable left over from the training cell, so
# this cell depends on how far that loop ran.
tnses, df_out, df_y = test(ms2, discharges2, N=N)

# Check if any NSE disparity between cluster: minimal
for train_nses, test_nses in res:
    print([series.median() for series in test_nses])

  0%|          | 0/12 [00:00<?, ?it/s]

[0.99803555]
[0.9989303]
[0.99920183]
[0.99932474]
[0.9995174]
[0.9996587]
[0.99973077]
[0.9998072]
[0.99982846]
[0.9998453]
[0.9998565]


In [15]:
# Median NSE over all entries of ``tnses`` stacked into a single Series.
pd.concat(tnses, axis=0).median()

0.9997825