In [None]:
import yaml
import xarray as xr
import os
import pickle
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
import geopandas as gpd

from h2ox.ai.dataset.dataset_factory import DatasetFactory
from h2ox.ai.dataset.dataset import FcastDataset
from h2ox.ai.dataset.utils import group_consecutive_nans

%load_ext autoreload
%autoreload 2

In [None]:
# Load the experiment configuration. `with` closes the file handle after
# parsing; `yaml.safe_load` is equivalent to `yaml.load(..., Loader=yaml.SafeLoader)`.
with open('./../conf-all.yaml', 'r') as cfg_file:
    cfg = yaml.safe_load(cfg_file)

In [None]:
# Evaluate on Q4 2020 only: override the test split before the dataset is built.
cfg['dataset_parameters'].update(test_date_ranges=[['2020-10-01', '2020-12-31']])

In [None]:
# Build the dataset factory from the (modified) config loaded above.
dataset_factory = DatasetFactory(cfg)

In [None]:
# Build the PyTorch dataset; this doubles as a smoke test of the data pipeline.
ptds = dataset_factory.build_dataset()

In [None]:
import torch

In [None]:
from torch.utils.data import DataLoader
from h2ox.ai.dataset import maybe_load
from h2ox.ai.dataset.dataset import train_validation_test_split
from h2ox.ai.dataset.utils import calculate_errors, revert_to_levels
from h2ox.ai.model_gnn import initialise_gnn

In [None]:
# First sample of the dataset, used to size/initialise the model below.
item = ptds[0]

In [None]:
# Unpack the three config sections used throughout the rest of the notebook.
dataset_parameters, model_parameters, training_parameters = (
    cfg['dataset_parameters'],
    cfg['model_parameters'],
    cfg['training_parameters'],
)

In [None]:
# Pickle the first dataset item to disk as a dummy input. The `with` block
# guarantees the file is flushed and closed once the dump completes.
with open('./../models/kaveri_dummy_item.pkl', 'wb') as item_file:
    pickle.dump(item, item_file)

In [None]:
# Inspect the sample item (its structure is whatever the dataset's __getitem__ returns).
item

In [None]:
# Per-site std of the normalised target (keyed by the `global_sites` coordinate),
# passed to `initialise_gnn` below as `flow_std`.
# Reset first so a value left over from an earlier run is never silently reused
# when `norm_difference` is off.
std_target = None
if dataset_parameters["norm_difference"]:
    var_norms = ptds.augment_dict
    # NOTE(review): target_var is not used in this cell; the key below is
    # hard-coded - confirm whether it should be derived from target_var.
    target_var = ptds.target_var[0]

    # Convert once and reuse, instead of calling .to_dict() twice.
    std_dict = var_norms["std_norm"]["shift_targets_WATER_VOLUME"]["std"].to_dict()
    std_target = dict(
        zip(std_dict["coords"]["global_sites"]["data"], std_dict["data"])
    )

In [None]:
# Show the per-site std mapping that is handed to the model as flow_std.
print(std_target)

In [None]:
# Build the GNN on CPU using the configured hyper-parameters; `std_target`
# supplies the per-site `flow_std`.
model = initialise_gnn(
    item,
    sites=maybe_load(dataset_parameters["select_sites"]),
    sites_edges=maybe_load(dataset_parameters["sites_edges"]),
    flow_std=std_target,
    device='cpu',
    hidden_size=model_parameters["hidden_size"],
    num_layers=model_parameters["num_layers"],
    dropout=model_parameters["dropout"],
    bayesian_linear=model_parameters["bayesian_linear"],
    bayesian_lstm=model_parameters["bayesian_lstm"],
    lstm_params=model_parameters["lstm_params"],
)

In [None]:
# The model was built with device='cpu', so map the checkpoint tensors to CPU as
# well. This also lets a checkpoint saved on GPU load on a machine without CUDA.
model.load_state_dict(
    torch.load('./../experiments_2/sacred/147/model_epoch249.pt', map_location='cpu')
)

In [None]:
train_dd, validation_dd, test_dd = train_validation_test_split(
    ptds, cfg=dataset_parameters, time_dim="date"
)


def _make_loader(split_ds):
    """Wrap one dataset split in an unshuffled DataLoader configured from training_parameters."""
    return DataLoader(
        split_ds,
        batch_size=training_parameters["batch_size"],
        shuffle=False,
        num_workers=training_parameters["num_workers"],
    )


# build dataloaders (same settings for every split)
train_dl = _make_loader(train_dd)
val_dl = _make_loader(validation_dd)
test_dl = _make_loader(test_dd)

In [None]:
def move_to(obj, device):
    """Recursively move every tensor inside ``obj`` to ``device``.

    Tensors are moved with ``.to(device)``. Dicts and lists may be nested
    arbitrarily; they are rebuilt as new plain ``dict``/``list`` objects, so the
    input container is not mutated.

    Raises:
        TypeError: if ``obj`` (or any nested value) is not a tensor, dict or list.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to(value, device) for value in obj]
    raise TypeError("Invalid type for move_to")

In [None]:
batch = next(iter(test_dl))

In [None]:
model = model.eval()

In [None]:
model = model.train(False)

In [None]:
y_hat = model(batch)

In [None]:
y_hat[0,:,2]*krs_diff_std

In [None]:
krs_diff_std = 0.012115440092245034

In [None]:
from serve import H2OxHandler
from ts.context import Context
import xarray as xr
import json

In [None]:
# TorchServe context for running the handler locally. model_dir is resolved
# relative to the notebook, like the other './../' paths here, rather than a
# user-specific absolute path. The trailing separator of the original value is kept.
ctx = Context(
    model_name='kaveri',
    model_dir=os.path.join(os.path.abspath(os.path.join('..', 'models')), ''),
    manifest={'model': {'serializedFile': 'kaveri.pt'}},
    batch_size=8,
    # NOTE(review): passed as a string as before; the TorchServe worker appears
    # to pass an int GPU id - confirm the handler accepts either.
    gpu='0',
    mms_version=0.1,
    limit_max_image_pixels=True,
)

In [None]:
# Instantiate the TorchServe handler and initialise it from `ctx`.
inst = H2OxHandler()
inst.initialize(ctx)

In [None]:
# Sample request payload for 2020-10-01; `with` closes the file after parsing.
with open('./../data/kaveri_sample_2020_10_01.json', 'r') as sample_file:
    sample_data = json.load(sample_file)

In [None]:
# Preprocess the sample request with the handler; later cells index the result
# by date string ('2020-10-01').
inps = inst.preprocess(sample_data)

In [None]:
# Top-level groups of the dataloader batch, for comparison with `inps`.
batch.keys()

In [None]:
# Sanity check: the handler's preprocessed inputs for 2020-10-01 should match
# the dataloader's first test batch. For each input tensor (skipping the
# 'y'/'meta' groups and any nested 'y'), print both shapes and whether element
# 0 agrees on the last 4 entries of axis 1 (assumed to be time - confirm).
# Loop variable names are kept as kk/kk2/vv/vv2 because they leak into the
# notebook namespace.
for kk, vv in batch.items():
    if kk in ('y', 'meta'):
        continue
    for kk2, vv2 in vv.items():
        if kk2 == 'y':
            continue
        served_tensor = inps['2020-10-01'][kk][kk2]
        print(kk, kk2, served_tensor.shape, vv2.shape)
        print(np.isclose(served_tensor[0, -4:, :].cpu().numpy(),
                         vv2[0, -4:, :].cpu().numpy()).all())

In [None]:
# Move the locally loaded model to the first GPU for comparison with the handler.
model = model.to('cuda:0')

In [None]:
# NOTE(review): `device` is not a standard nn.Module attribute; presumably the
# GNN class defines it - confirm it reflects the `.to()` call above.
model.device

In [None]:
# Move every tensor in `batch` (except the 'y' and 'meta' groups) to the first
# GPU in place, printing each tensor's device before and after the move.
# `kk`/`kk2` are kept as loop names because they leak into the namespace.
for kk in batch:
    if kk in ('y', 'meta'):
        continue
    group = batch[kk]
    for kk2 in group:
        print(group[kk2].device)
        group[kk2] = group[kk2].to('cuda:0')
        print(group[kk2].device)

In [None]:
# Run the locally loaded model on the handler-preprocessed inputs. no_grad
# because this is inference only; the result is displayed as the cell output.
with torch.no_grad():
    y_served = model(inps['2020-10-01'])
y_served

In [None]:
# Load the same checkpoint into the handler's model (load_state_dict is strict by
# default, so every key must match).
inst.model.load_state_dict(torch.load('./../experiments_2/sacred/147/model_epoch249.pt'))

In [None]:
inst.model.eval()

In [None]:
# Load the raw checkpoint dict again for key-by-key comparison with the handler model.
weights = torch.load('./../experiments_2/sacred/147/model_epoch249.pt')

In [None]:
name

In [None]:
weights[name]

In [None]:
inst_weights = {name:W for name, W in inst.model.named_parameters()}

In [None]:
inst_weights[name].cpu().detach().numpy()

In [None]:
weights[name].cpu().numpy()

In [None]:
# Handler-model parameters that do not appear in the checkpoint.
[kk for kk in inst_weights.keys() if kk not in weights.keys()]

In [None]:
weights['encoders.narayanapura.lstms.0.weight_ih_sampler.eps_w'].shape

In [None]:
# NOTE(review): indexing only works if inst.model supports __getitem__ (e.g.
# nn.Sequential / nn.ModuleList) - confirm. no_grad has no effect on printing.
with torch.no_grad():
    print (inst.model[0])

In [None]:
inst.model

In [None]:
# Checkpoint entries that are not named parameters of the handler model
# (presumably buffers, since state dicts hold parameters and buffers).
missing_keys = [kk for kk in weights.keys() if kk not in inst_weights.keys()]

In [None]:
missing_keys

In [None]:
# All state-dict keys of the handler model (parameters and buffers).
inst_keys = inst.model.state_dict().keys()

In [None]:
inst.model.state_dict()[kk].data

In [None]:
weights[kk]

In [None]:
# Copy checkpoint values into the handler model for the non-parameter entries.
# `state_dict()` returns detached tensors that share storage with the module's
# parameters/buffers. Rebinding `.data` on one of them only swaps the storage of
# that temporary tensor and leaves the model untouched. `copy_` writes into the
# shared storage, so the module's own tensor is updated (across devices too).
with torch.no_grad():
    inst_state = inst.model.state_dict()
    for kk in missing_keys:
        inst_state[kk].copy_(weights[kk])

In [None]:
# Checkpoint vs handler value for `kk` (the last key left by the copy loop above).
weights[kk]

In [None]:
inst.model.state_dict()[kk]

In [None]:
# NOTE(review): `mask_model` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel. It looks like leftover scratch
# code - remove it or define mask_model. If it runs, it prints each state-dict
# key and fills that tensor with 1 in place.
with torch.no_grad():
    for layer in mask_model.state_dict():
        print(layer)
        #print(torch.ones_like(mask_model.state_dict()[layer].data))
        mask_model.state_dict()[layer].data.fill_(1)

In [None]:
# For each missing checkpoint key: is it present in the handler's state_dict?
[(kk in inst_keys) for kk in missing_keys]

In [None]:
# Check every checkpoint entry against the handler model and print the verdict.
# Use state_dict() (parameters + buffers) rather than `inst_weights` (named
# parameters only), because the checkpoint also holds non-parameter entries
# (see missing_keys) that would raise KeyError there.
inst_state = inst.model.state_dict()
for kk, vv in weights.items():
    print(kk, np.allclose(inst_state[kk].cpu().numpy(), vv.cpu().numpy()))

In [None]:
# Handler-model prediction for the sample date; no_grad because this is inference only.
with torch.no_grad():
    y_ts = inst.model(inps['2020-10-01'])

In [None]:
# Handler-model prediction at output index 2.
y_ts[0,:,2]

In [None]:
# NOTE(review): this shows index 6, while the handler output above and the
# krs_diff_std rescaling cell use index 2 - confirm the intended comparison.
y_hat[0,:,6]