In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import torch

from hydroml.config.config import Config
from hydroml.data.camels_aus_ds import get_dataset
from hydroml.models import get_model
from hydroml.models.lstm import HydroLSTM
from hydroml.prediction.prediction import process_and_convert_dataloader_to_xarray

# Reproducibility: model weight initialisation (and, presumably, any shuffled
# sampling in the dataloader) draws from torch's global RNG.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): absolute local path — consider making this configurable.
DATA_DIR = "Z://Data//CAMELS_AUS//preprocessed"

config = Config(
    target_features=['streamflow_mmd'],
    dynamic_features=['precipitation_AWAP', 'et_morton_wet_SILO'],
    static_features=[],
    evolving_static_features={'dynamic_feature_mean': {'dynamic_features': ['precipitation_AWAP']}},
    # Returned per sample under the 'metadata' key of each dataset item.
    evolving_metadata={'observed_target_std': {'target_features': ['streamflow_mmd']}},  # can be set automatically
    batch_size=16,
)

# 'cal' presumably selects the calibration period — confirm against get_dataset.
dataset = get_dataset(DATA_DIR, config, 'cal')
dataloader = dataset.to_dataloader()
model = HydroLSTM(config)  # or get_model(config)


                                                                         

In [59]:
# Test a single forward pass on one (unbatched) dataset sample.
data = dataset[3]  # arbitrary sample index

x_dynamic = data['x_dynamic']
x_static = data['x_static']
y = data['y']
date = pd.to_datetime(data['date'])
catchment_id = data['catchment_id']
metadata = data['metadata']

# Call the module itself rather than .forward() so any registered hooks run.
# NOTE(review): in the recorded output y has shape [1] while y_hat is 0-d —
# confirm this shape difference is intended before comparing them in a loss.
y_hat = model(x_dynamic, x_static)
y, y_hat, catchment_id, metadata, date


(tensor([4.3113]),
 tensor(-0.1323, grad_fn=<SqueezeBackward0>),
 '410730',
 {'observed_target_std': tensor([1.4631])},
 Timestamp('2011-01-14 00:00:00'))

In [78]:
# Test the loss on the first batch of the dataloader.
batch = next(iter(dataloader))
loss = model.loss(batch, 0)  # second argument presumably batch_idx — confirm
loss

tensor(2.1164, grad_fn=<MeanBackward0>)


In [182]:
# Test the prediction step; this only predicts one batch at a time.
# Fetch the batch here so the cell does not depend on a `batch` variable
# left over from another cell.
# NOTE(review): the recorded output had a leading dimension of 10 while
# config.batch_size is 16 — confirm which batch produced it.
predict_batch = next(iter(dataloader))
prediction = model.predict_step(predict_batch, 0, 0)  # presumably (batch, batch_idx, dataloader_idx)
prediction.shape


torch.Size([10, 1, 1])

In [181]:
# Make predictions for every batch in the dataloader (i.e. all catchments)
# and collect them into, presumably, an xarray Dataset.
ds = process_and_convert_dataloader_to_xarray(dataloader, model)
ds