In [1]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import torch

from diffroute import get_node_idxs

from src.geoglow_io import load_geoflow
from src.cat_interp import CatchmentInterpolator
from src.schedule import define_schedule
from src.models import CalibratedRouting
from src.utils import nse_loss_fn

In [2]:
from src.config import DATA_ROOT as root

In [3]:
def callibrate_model(model, x, y, epochs):
    model.model.aggregator.init_buffers(device)
    model.init_params()
    opt = torch.optim.Adam(model.parameters(), lr=.5)
    
    losses = []
    for i in tqdm(range(epochs)):
        out = model(x) 
        loss = nse_loss_fn(out.squeeze(), y.squeeze())
        if loss.isnan().item(): break
        loss.backward()
        opt.step()
        opt.zero_grad()
        losses.append(loss.item())
        print(losses[-1])
    with torch.no_grad():
        out = model(x)
        
    return out, losses

In [4]:
plength_thr=10**5
node_thr=10**4
runoff_to_output=False
device = "cuda:4"

time_window=30
dt=1/24
block_size=16
model_name="muskingum"
irf_agg="log_triton"
index_precomp="cpu"
sample_mode="avg"
data = []

epochs = 1000
max_delay = time_window
irf_fn=model_name
sampling_mode=sample_mode
irf_agg = "log_triton"
conv_imp="triton"
index_precomp = "cpu"

max_delay = 30
block_size = 16
dt = 1/24
epochs = 500
vpu = "605"

In [5]:
g, runoff, interp_df = load_geoflow(vpu_numbers=[vpu])
clusters_g, node_transfer = define_schedule(g, plength_thr=10**4, node_thr=20000)
nodes_idx = get_node_idxs(g)

Loading runoffs...


  0%|          | 0/1 [00:00<?, ?it/s]

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/286905 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/286905 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/3463 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/3427 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/102 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/102 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...
#### Cluster Annotations ... ####


In [6]:
all_discharges = pd.read_feather(root / "retro_feather" / f"{vpu}.feather")

tr_index = all_discharges.index[all_discharges.index.year< 1980] 
te_index = all_discharges.index[all_discharges.index.year>=1980]

tr_runoff = runoff.loc[tr_index] / 3600 / 24
tr_discharges = all_discharges.loc[tr_index, nodes_idx.index]
tr_cat = CatchmentInterpolator(clusters_g, tr_runoff, interp_df, device=device)

te_runoff = runoff.loc[te_index] / 3600 / 24
te_discharges = all_discharges.loc[te_index, nodes_idx.index]
te_cat = CatchmentInterpolator(clusters_g, te_runoff, interp_df, device=device)

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

### Experiments

In [None]:
outputs, outputs_te, models = [],[],[]
all_losses = []
transfered_inputs = {i:[] for i,_ in enumerate(clusters_g)}
transfered_inputs_te = {i:[] for i,_ in enumerate(clusters_g)}

for i,g in enumerate(tqdm(clusters_g)):
    nodes_idx = get_node_idxs(g)    
    model = CalibratedRouting(g, 
                               nodes_idx=nodes_idx,
                               max_delay=time_window, 
                               block_size=block_size,
                               irf_fn=model_name, 
                               irf_agg=irf_agg, 
                               index_precomp=index_precomp, 
                               runoff_to_output=runoff_to_output,
                               dt=dt, sampling_mode=sample_mode,
                               ).to(device)

    # Assemble cluster
    x = tr_cat.read_catchment(i) 
    for e_dst, inp_dis in transfered_inputs[i]: x[:,e_dst] += inp_dis.squeeze()
    y = torch.from_numpy(tr_discharges[nodes_idx.index].values.T)[None].to(x.device)
    
    # Calibration
    out, losses = callibrate_model(model, x, y, epochs)

    # Evaluation
    x = te_cat.read_catchment(i) 
    for e_dst, inp_dis in transfered_inputs_te[i]: x[:,e_dst] += inp_dis.squeeze()   
    with torch.no_grad(): out_te =  model(x)

    # Transmit to downstream clusters
    for (c_idx, e_src, e_dst) in node_transfer[i]:
        transfered_inputs[c_idx].append((e_dst, out[:,e_src].clone().detach()))       
        transfered_inputs_te[c_idx].append((e_dst, out_te[:,e_src].clone().detach()))       

    outputs.append(pd.DataFrame(out.cpu().squeeze(), index=nodes_idx.index).T)
    outputs_te.append(pd.DataFrame(out_te.cpu().squeeze(), index=nodes_idx.index).T)
    
    models.append(model.cpu())
    all_losses.append(losses)

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

2.2627456188201904
1.215781569480896
0.3738154172897339
0.26802685856819153
0.4643107056617737
0.5486710071563721
0.49715954065322876
0.39129379391670227
0.3019059896469116
0.25791317224502563
0.24832585453987122
0.24541305005550385
0.22883765399456024
0.19700093567371368
0.1618330031633377
0.13672174513339996
0.12611056864261627
0.124087393283844
0.12175871431827545
0.11403991281986237
0.10138774663209915
0.087760329246521
0.0770530179142952
0.07060522586107254
0.06696175783872604
0.06345668435096741
0.05842364951968193
0.05232778191566467
0.04694455862045288
0.04354241117835045
0.04180643707513809
0.04027869179844856
0.037715308368206024
0.034044887870550156
0.0301715899258852
0.027060767635703087
0.02502744272351265
0.02368927001953125
0.022454626858234406
0.021032201126217842
0.019541967660188675
0.018244756385684013
0.01723240502178669
0.016345974057912827
0.01537337526679039
0.014279711060225964
0.01323835738003254
0.012453708797693253
0.011957334354519844
0.011583305895328522
0.

  0%|          | 0/500 [00:00<?, ?it/s]

2.310561418533325
1.1118309497833252
0.28373318910598755
0.3062053918838501
0.5474051237106323
0.6042450070381165
0.5140641331672668
0.38385581970214844
0.28651750087738037
0.24335186183452606
0.23542535305023193
0.23337194323539734
0.21794214844703674
0.18698959052562714
0.15322022140026093
0.13112297654151917
0.12355232983827591
0.12174368649721146
0.1166011169552803
0.10595856606960297
0.09312326461076736
0.0826122984290123
0.07661733776330948
0.0736040472984314
0.06984619796276093
0.06303229182958603
0.05442431941628456
0.04723157361149788
0.043448399752378464
0.04236774146556854
0.04156135767698288
0.03911271318793297
0.03505146503448486
0.030869251117110252
0.027979478240013123
0.026616329327225685
0.02581145614385605
0.02436244674026966
0.021986940875649452
0.019526800140738487
0.01800495572388172
0.01756664365530014
0.017415517941117287
0.016675323247909546
0.015195689164102077
0.013581442646682262
0.012542147189378738
0.012236484326422215
0.012210395187139511
0.011895927600562

  0%|          | 0/500 [00:00<?, ?it/s]

2.241441249847412
1.03379225730896
0.3083544075489044
0.35680103302001953
0.5268887281417847
0.5466773509979248
0.46390628814697266
0.3648524880409241
0.29921695590019226
0.2724483609199524
0.2637653648853302
0.24957554042339325
0.2206016331911087
0.1846376657485962
0.15565554797649384
0.14038006961345673
0.13476654887199402
0.13084784150123596
0.1236877515912056
0.11307455599308014
0.10150681436061859
0.09126725792884827
0.08287859708070755
0.07571059465408325
0.06926563382148743
0.06351302564144135
0.05843876302242279
0.053788408637046814
0.049177151173353195
0.04445692524313927
0.04004064202308655
0.036700017750263214
0.034774165600538254
0.033666156232357025
0.0323067232966423
0.03006695955991745
0.027208132669329643
0.024565713480114937
0.02278277277946472
0.021783368661999702
0.020962471142411232
0.019821934401988983
0.018367961049079895
0.01695854961872101
0.01587977260351181
0.01512104831635952
0.014483430422842503
0.013787776231765747
0.013012592680752277
0.012272285297513008


  0%|          | 0/500 [00:00<?, ?it/s]

2.014587640762329
0.8449533581733704
0.22901907563209534
0.35537296533584595
0.5203214883804321
0.5158057808876038
0.41719958186149597
0.313285231590271
0.2511218786239624
0.23111918568611145
0.2284470796585083
0.2197723388671875
0.19724753499031067
0.16658172011375427
0.13760730624198914
0.11803603172302246
0.10970418155193329
0.1078336238861084
0.10550286620855331
0.0982908084988594
0.08625437319278717
0.07330843061208725
0.06390143185853958
0.059422817081213
0.057806096971035004
0.056014612317085266
0.052399467676877975
0.047317709773778915
0.04210088029503822
0.0378408245742321
0.03485659137368202
0.032681893557310104
0.030586719512939453
0.02827211655676365
0.02597525715827942
0.023957202211022377
0.022156422957777977
0.02039971947669983
0.018744828179478645
0.017455600202083588
0.01664326712489128
0.01605086401104927
0.015272303484380245
0.014167995192110538
0.013017118908464909
0.012200774624943733
0.011760610155761242
0.011370029300451279
0.010757575742900372
0.0100215692073106

  0%|          | 0/500 [00:00<?, ?it/s]

2.4445364475250244
1.2147995233535767
0.5005813241004944
0.4321826100349426
0.49850010871887207
0.5138391852378845
0.4637717604637146
0.4004829227924347
0.3623713254928589
0.34520423412323
0.32759520411491394
0.29748472571372986
0.2589682340621948
0.2234654277563095
0.19756479561328888
0.18070164322853088
0.16999508440494537
0.16184715926647186
0.15239126980304718
0.14054743945598602
0.127879798412323
0.11500384658575058
0.10184167325496674
0.08981907367706299
0.0806482806801796
0.07437353581190109
0.0696796402335167
0.06542067974805832
0.061161164194345474
0.0568961426615715
0.052840910851955414
0.0490177720785141
0.04492644593119621
0.040256958454847336
0.03563544526696205
0.03213978931307793
0.030154045671224594
0.028916049748659134
0.0272779893130064
0.02483327127993107
0.02222503162920475
0.020341023802757263
0.019400814548134804
0.01888410560786724
0.0181896835565567
0.017213359475135803
0.01622323878109455
0.015410743653774261
0.014701676554977894
0.013952959328889847
0.01315448

  0%|          | 0/500 [00:00<?, ?it/s]

1.957848072052002
0.9013990163803101
0.27811506390571594
0.37658852338790894
0.501059353351593
0.441450297832489
0.3233305513858795
0.24390047788619995
0.22044554352760315
0.22707629203796387
0.2303728461265564
0.213293194770813
0.18094587326049805
0.1485828012228012
0.12683972716331482
0.1154719889163971
0.1087413877248764
0.1030903235077858
0.09713034331798553
0.09020166844129562
0.08255895227193832
0.0746665820479393
0.06700697541236877
0.060473062098026276
0.0556534081697464
0.052151452749967575
0.04901987686753273
0.04567812383174896
0.04223198443651199
0.03900641202926636
0.036190636456012726
0.03387502580881119
0.03193432092666626
0.02997606061398983
0.02775643579661846
0.02551618218421936
0.023725129663944244
0.022550690919160843
0.021712632849812508
0.02080312930047512
0.019630666822195053
0.018331628292798996
0.01719891093671322
0.016365526244044304
0.015685776248574257
0.014978715218603611
0.014289528131484985
0.013769224286079407
0.01338463556021452
0.0129398163408041
0.012

  0%|          | 0/500 [00:00<?, ?it/s]

1.5786683559417725
0.6971840262413025
0.2674447298049927
0.3575258255004883
0.40090447664260864
0.3147265911102295
0.22220240533351898
0.18497395515441895
0.18894967436790466
0.19822090864181519
0.19116708636283875
0.1663559228181839
0.13551262021064758
0.11233392357826233
0.10204790532588959
0.09966418892145157
0.09699452668428421
0.08981503546237946
0.07966732978820801
0.07046867907047272
0.06435626745223999
0.06069299206137657
0.0575157068669796
0.05334119126200676
0.0482572466135025
0.04366216063499451
0.040567539632320404
0.038465745747089386
0.036190494894981384
0.03334113582968712
0.030488336458802223
0.028396744281053543
0.027278294786810875
0.026597749441862106
0.025550318881869316
0.023887965828180313
0.02214093692600727
0.020905861631035805
0.02014388144016266
0.019330786541104317
0.018176808953285217
0.016962885856628418
0.01613379456102848
0.01572108455002308
0.015356982126832008
0.014760326594114304
0.014015262015163898


### Format results

In [16]:
tr_nses, te_nses = [],[]
for tr_o, te_o in tqdm(zip(outputs, outputs_te)):
    tr_y = tr_discharges[tr_o.columns].reset_index(drop=True)
    te_y = te_discharges[te_o.columns].reset_index(drop=True)
    tr_nses.append(((tr_o-tr_y)**2).mean() / tr_y.var())
    te_nses.append(((te_o-te_y)**2).mean() / te_y.var())

ote = pd.concat(outputs_te, axis=1)
te_nses = 1-pd.concat(te_nses)

0it [00:00, ?it/s]

In [19]:
te_nses.median()

0.99993324