In [1]:
import jax
import torch
import torch.utils.data as torchdata
import numpy as np
import jax.numpy as jnp

import sys
sys.path.append("/project")
from model.hydro.gr4j_prod import ProductionStorage as TProd
from model.hydro.gr4j_prod_flax import ProductionStorage as JProd
from data.camels_sampler import CamelsBatchSampler
from data.utils import read_dataset_from_file, get_station_list

In [2]:
# --- Data location -----------------------------------------------------------
data_dir = "/data/camels/aus/"
sub_dir = "no-scale"
station_id = "105105A"

# Output directory (not referenced by the cells visible in this notebook section).
run_dir = "/project/results/hygr4j"

# --- Batching / sequence settings --------------------------------------------
# window_size is handed to the batch samplers and the *_create_sequence helpers;
# batch_size is handed to the batch samplers.
window_size = 7
batch_size = 256

# --- Production-store settings -----------------------------------------------
# Passed to the Flax production store as x1_init / s_init.
# NOTE(review): presumably the GR4J production-store capacity and initial
# storage -- confirm the units against the model code.
x1 = 530.927
s_init = 0.0

In [3]:
# Load the train / validation datasets for the configured station.
print(f"Reading data for station_id: {station_id}")
train_ds, val_ds = read_dataset_from_file(
    data_dir,
    sub_dir,
    station_id=station_id,
)

Reading data for station_id: 105105A


In [4]:
# Each dataset exposes three tensors, unpacked here as (t, X, y).
# NOTE(review): presumably t = timestamps, X = forcing inputs, y = target flow -- confirm.
t_train, X_train, y_train = train_ds.tensors
t_val, X_val, y_val = val_ds.tensors

# Target statistics are computed on the training split only.
# NOTE(review): mean/std propagate NaN; confirm y_train contains no NaN values.
y_mu, y_sigma = y_train.mean(dim=0), y_train.std(dim=0)

In [5]:
# Replace NaNs in the inputs of both splits.
X_train, X_val = torch.nan_to_num(X_train), torch.nan_to_num(X_val)

# Standardise the targets of both splits with the training mean / std.
# NOTE: y_train and y_val are rebound from themselves, so running this cell a
# second time standardises them twice; re-run the previous cell first.
y_train, y_val = (y_train - y_mu)/y_sigma, (y_val - y_mu)/y_sigma

Test with dummy data: check that sequences built batch-by-batch match the sequences built from the full series.

In [6]:
# Dummy series whose values equal the row index, so any misaligned window is
# immediately visible: X has two identical columns, y is a single column.
# NOTE(review): 100 rows is fewer than the configured batch_size (256), so the
# dummy check probably runs as a single batch -- confirm against the sampler.
n_dummy_rows = 100
x = torch.arange(n_dummy_rows).reshape(-1, 1)
y = torch.arange(n_dummy_rows).reshape(-1, 1)
X = torch.concat((x, x), dim=1)

In [7]:
def t_create_sequence(X, y, window_size):
    """Build sliding-window samples from torch tensors.

    For every start index ``i`` in ``range(1, len(X) - window_size)`` one
    sample is produced:

    * input: rows ``i .. i + window_size - 1`` of ``X`` concatenated along
      ``dim=1`` with rows ``i - 1 .. i + window_size - 2`` of ``y`` (the
      target series lagged by one step);
    * target: ``y[i + window_size - 1]``.

    The final row of ``X`` / ``y`` is never used, because the last start
    index is ``len(X) - window_size - 1``.

    Args:
        X: 2-D tensor of inputs, one row per time step.
        y: 2-D tensor of targets with the same number of rows as ``X``.
        window_size: number of time steps per sample; must not be ``None``.

    Returns:
        Tuple ``(Xs, ys)`` of stacked input windows and stacked targets.

    Raises:
        AssertionError: if ``window_size`` is ``None``.
        RuntimeError: from ``torch.stack`` when no window fits in the input.
    """
    assert window_size is not None, "Window size cannot be NoneType."

    starts = range(1, len(X) - window_size)

    # Input windows: current X rows next to the one-step-lagged y rows.
    windows = [
        torch.concat(
            [X[s:s + window_size], y[s - 1:s + window_size - 1]],
            dim=1,
        )
        for s in starts
    ]
    # Target: the y value at the last time step of each window.
    targets = [y[s + window_size - 1] for s in starts]

    return torch.stack(windows), torch.stack(targets)

In [8]:
# Reference result: sequences built from the whole dummy series in one pass.
X_seq, y_seq = t_create_sequence(X=X, y=y, window_size=window_size)

In [9]:
# Wrap the dummy tensors in their own dataset. A separate name is used so the
# real `train_ds` loaded earlier is not overwritten: the cell that unpacks
# `train_ds.tensors` into three tensors would fail on this two-tensor dataset
# if it were re-run after this cell.
dummy_ds = torchdata.TensorDataset(X, y)
train_batch_sampler = CamelsBatchSampler(dummy_ds, batch_size=batch_size,
                                         window_size=window_size,
                                         drop_last=False)

# The sampler yields whole batches of indices, so it is passed as
# `batch_sampler` rather than `sampler`.
train_dl = torchdata.DataLoader(dummy_ds, batch_sampler=train_batch_sampler,
                                num_workers=2, prefetch_factor=2)

In [10]:
# Per-batch input sequences and targets built from the dummy loader.
prod_seq = []
y_out = []

# The loop targets are named X_batch / y_batch so they do not overwrite the
# notebook-level dummy tensors X and y; otherwise re-running the earlier
# cells that consume X and y would silently operate on the last batch.
for (X_batch, q) in train_dl:
    x_seq, y_batch = t_create_sequence(X_batch, q, window_size=window_size)
    prod_seq.append(x_seq)
    y_out.append(y_batch)

In [11]:
# Sequences stitched together from the batches vs. the full-series reference.
A = torch.concat(prod_seq)
B = X_seq

# Assert instead of only displaying the result so that a regression stops a
# "Run All"; the collected targets are compared as well, not just the inputs.
inputs_match = torch.equal(A, B)
targets_match = torch.equal(torch.concat(y_out), y_seq)
assert inputs_match, "Batched input sequences differ from the full-series sequences."
assert targets_match, "Batched targets differ from the full-series targets."
inputs_match and targets_match

True

In [12]:
# From here on the real training split is used (aliases, not copies).
X_torch, y_torch = X_train, y_train

In [13]:
# NumPy versions of the training tensors for the JAX model.
X_np, y_np = (tensor.detach().numpy() for tensor in (X_torch, y_torch))

In [14]:
# Flax production store configured from the notebook constants.
j_prod_store = JProd(x1_init=x1, s_init=s_init, scale=1.0)

# Initialise the parameters with a fixed PRNG key and a dummy input.
# NOTE(review): the (1, 5) dummy shape presumably mirrors the number of input
# features -- confirm against the model.
init_key = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 5))
params = j_prod_store.init(rngs={'params': init_key}, x=dummy_input)
params



FrozenDict({
    params: {
        x1: DeviceArray([530.927], dtype=float32),
    },
})

In [15]:
# Reference run: apply the production store to the whole training series at once.
j_prod_out, j_s_store = j_prod_store.apply(params, X_np)
tuple(arr.shape for arr in (j_prod_out, j_s_store))

((8397, 9), (8397,))

In [16]:
# Batched loader over the real training tensors.
ds = torchdata.TensorDataset(X_torch, y_torch)

batch_sampler = CamelsBatchSampler(
    ds,
    batch_size=batch_size,
    window_size=window_size,
    drop_last=False,
)

dl = torchdata.DataLoader(
    ds,
    batch_sampler=batch_sampler,
    num_workers=2,
    prefetch_factor=2,
)

In [17]:
def j_create_sequence(X, y, window_size):
    """Build sliding-window samples as JAX arrays.

    For every start index ``i`` in ``range(1, len(X) - window_size)`` one
    sample is produced:

    * input: rows ``i .. i + window_size - 1`` of ``X`` concatenated along
      ``axis=1`` with rows ``i - 1 .. i + window_size - 2`` of ``y`` (the
      target series lagged by one step);
    * target: ``y[i + window_size - 1]``.

    The final row of ``X`` / ``y`` is never used, because the last start
    index is ``len(X) - window_size - 1``.

    Args:
        X: 2-D array of inputs, one row per time step.
        y: 2-D array of targets with the same number of rows as ``X``.
        window_size: number of time steps per sample; must not be ``None``.

    Returns:
        Tuple ``(Xs, ys)`` of stacked input windows and stacked targets.

    Raises:
        AssertionError: if ``window_size`` is ``None``.
    """
    assert window_size is not None, "Window size cannot be NoneType."

    starts = range(1, len(X) - window_size)

    # Input windows: current X rows next to the one-step-lagged y rows.
    windows = [
        jnp.concatenate(
            [X[s:s + window_size], y[s - 1:s + window_size - 1]],
            axis=1,
        )
        for s in starts
    ]
    # Target: the y value at the last time step of each window.
    targets = [y[s + window_size - 1] for s in starts]

    return jnp.stack(windows), jnp.stack(targets)

In [18]:
# Per-batch results, stitched back together and compared with the full-series
# run in the cells below.
j_prod_batch_out = []
j_prod_batch_seq = []
j_prod_batch_y = []
j_batch_input = []

# Start every run of this cell from the configured initial storage, so that
# re-running it does not inherit the state left behind by a previous pass.
j_prod_store.s_init = s_init

# The loop targets are named X_batch / y_batch so they do not overwrite the
# notebook-level X and y tensors used by the dummy-data cells.
for i, (X_batch, q) in enumerate(dl, start=1):

    X_batch = X_batch.detach().numpy()
    q = q.detach().numpy()

    out, s_store = j_prod_store.apply(params, X_batch)

    x_seq, y_batch = j_create_sequence(out, q, window_size=window_size)
    j_prod_batch_seq.append(x_seq)
    j_prod_batch_y.append(y_batch)

    if i == 1:
        j_prod_batch_out.append(out)
        j_batch_input.append(X_batch)
    else:
        # Every batch after the first drops its leading window_size + 1 rows.
        # NOTE(review): presumably those rows overlap the end of the previous
        # batch -- confirm against CamelsBatchSampler.
        j_prod_batch_out.append(out[window_size+1:])
        j_batch_input.append(X_batch[window_size+1:])

    # Hand the storage over to the next batch: take the stored value
    # window_size + 2 rows from the end and divide by x1 * scale.
    # NOTE(review): presumably this is the storage just before the first row
    # of the next (overlapping) batch, converted to the fraction that s_init
    # expects -- confirm against the model code.
    j_prod_store.s_init = s_store[-window_size-2]/(params['params']['x1']*j_prod_store.scale)


In [19]:
# Reference sequences built from the full-series production-store output.
j_prod_seq, j_prod_y = j_create_sequence(
    X=j_prod_out,
    y=y_np,
    window_size=window_size,
)

In [20]:
# The first stored batch must equal the first batch_size rows of the series.
first_batch_rows = X_np[:batch_size]
np.testing.assert_almost_equal(j_batch_input[0], first_batch_rows)

In [21]:
# The second stored batch (already trimmed by window_size + 1 rows in the
# loop) must equal the next batch_size rows of the series.
second_batch_rows = X_np[batch_size:2*batch_size]
np.testing.assert_almost_equal(j_batch_input[1], second_batch_rows)

In [22]:
# Stitching all stored batch inputs together must reproduce the full series.
stitched_input = np.concatenate(j_batch_input)
np.testing.assert_almost_equal(stitched_input, X_np)

In [23]:
# Batched production-store output must match the single full-series run.
stitched_out = np.concatenate(j_prod_batch_out)
np.testing.assert_almost_equal(stitched_out, j_prod_out)

In [25]:
# Sequences built per batch must match those built from the full series
# (compared to 5 decimals).
stitched_seq = np.concatenate(j_prod_batch_seq)
np.testing.assert_almost_equal(stitched_seq, j_prod_seq, decimal=5)

In [26]:
# Targets built per batch must match those built from the full series
# (compared to 5 decimals).
stitched_y = np.concatenate(j_prod_batch_y)
np.testing.assert_almost_equal(stitched_y, j_prod_y, decimal=5)