In [1]:
import pandas as pd
import numpy as np
from tqdm.notebook import trange, tqdm
import jax
import optax
import matplotlib.pyplot as plt
import importlib
from pathlib import Path

# Force reload of project files
import data, train, models, metrics
importlib.reload(data)
from data import DataLoader


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Model / optimisation hyperparameters (consumed by the training cell).
hidden_size = 64
output_size = 1
num_epochs = 50
batch_size = 256
sequence_length = 90

# Data selection: input columns are grouped under the 'daily' and 'irregular'
# keys expected by DataLoader; `target_label` is the column being predicted.
features_dict = {
    'daily': ['gage_Q'],
    'irregular': ['Blue', 'Green', 'Red', 'Nir', 'Swir1', 'Swir2'],
}
target_label = 'tss'

# Overall period handed to the loader and the timestamp passed as `split_time`
# (presumably the train/test boundary -- confirm in data.DataLoader).
time_slice = slice('2008-01-01', '2023-12-31')
split_time = np.datetime64('2017-01-01')

data_dir = Path("../data/Sag")

# Build the loader. A disabled alternative to `range_norm_cols` that was kept
# in the original cell: log_norm_cols=['gage_Q', 'tss'].
dataloader = DataLoader(
    data_dir=data_dir,
    basins=['sag_daily_data', 'sag_daily_data_2'],
    features_dict=features_dict,
    target=target_label,
    time_slice=time_slice,
    split_time=split_time,
    batch_size=batch_size,
    sequence_length=sequence_length,
    discharge_col='gage_Q',
    range_norm_cols=['gage_Q', 'tss'],
)

In [None]:
importlib.reload(models)
importlib.reload(train)
from models import TAPLSTM
from train import make_step, lr_dict_scheduler

# Learning-rate schedule handed to train.lr_dict_scheduler
# (presumably {first_epoch: learning_rate} -- confirm in train.py).
lr_schedule = {
    0: 0.005,
    30: 0.001,
}

# Initialize the model
key = jax.random.PRNGKey(0)
model = TAPLSTM(daily_in_size=len(features_dict['daily']),
                irregular_in_size=len(features_dict['irregular']),
                static_in_size=dataloader.x_s['sag_daily_data'].shape[0],
                out_size=output_size,
                hidden_size=hidden_size,
                key=key)

# Initialize optimizer
current_lr = lr_dict_scheduler(0, lr_schedule)
optim = optax.adam(current_lr)
opt_state = optim.init(model)

# Training loop
loss_list = []  # mean training loss per epoch, stored as Python floats
pbar = trange(num_epochs, desc="Epoch")
for epoch in pbar:
    # Rebuild the optimizer only when the scheduled rate actually changes.
    # Creating a fresh optax.adam every epoch yields new function objects,
    # which can force a re-trace of a jitted step that receives `optim` as a
    # non-array argument. The existing opt_state is reused because Adam's
    # state does not depend on the learning rate.
    scheduled_lr = lr_dict_scheduler(epoch, lr_schedule)
    if scheduled_lr != current_lr:
        current_lr = scheduled_lr
        optim = optax.adam(current_lr)

    total_loss = 0
    num_batches = 0
    dataloader.train = True
    for _, _, batch in dataloader:
        loss, model, opt_state = make_step(model, batch, opt_state, optim,
                                           loss_name="mse",
                                           max_grad_norm=3,
                                           l2_weight=1E-5)
        total_loss += loss
        num_batches += 1

    if num_batches == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise RuntimeError(
            "The dataloader produced no training batches; check time_slice, "
            "split_time and sequence_length."
        )

    # Convert once per epoch so the history holds plain floats rather than
    # device arrays (one host sync per epoch instead of one per batch).
    current_loss = float(total_loss / num_batches)
    loss_list.append(current_loss)
    pbar.set_postfix_str(f"Loss: {current_loss:.4f}")

# Plot the per-epoch training loss on an explicit Axes.
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set_ylabel("Loss")
ax.set_xlabel("Epoch")
ax.set_title("Training loss")
plt.show()

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
import equinox as eqx


@eqx.filter_jit
def predict(model, batch):
    """Apply `model` to every sample in `batch`.

    The model is vectorized with `jax.vmap` (default axes, i.e. mapped over
    the leading axis of the inputs) and the call is compiled through
    `eqx.filter_jit`. Returns whatever the vectorized model returns.
    """
    batched_model = jax.vmap(model)
    return batched_model(batch)

# Predict on the test data and collect the results into a DataFrame indexed
# by (basin, date).
# NOTE(review): the target column is listed in `range_norm_cols`, so these
# values are presumably still in normalized units -- confirm before comparing
# against observations.
basins = []
dates = []
y_hat_batches = []

dataloader.train = False
for basin, date, batch in tqdm(dataloader):
    basins.extend(basin)
    dates.extend(date)
    # Move each batch to host memory in one transfer. Extending a Python list
    # with a JAX array iterates it row by row and creates one device array
    # per sample, which is needlessly slow.
    y_hat_batches.append(np.asarray(predict(model, batch)).reshape(-1))

# Create a multi-index
multi_index = pd.MultiIndex.from_arrays([basins, dates], names=['basin', 'date'])

# Flatten all batches into a single 1-D vector; an empty loader gives an
# empty vector instead of an error from np.concatenate.
y_hat = np.concatenate(y_hat_batches) if y_hat_batches else np.empty(0)

# Create a DataFrame with the multi-index
predictions = pd.DataFrame({target_label: y_hat}, index=multi_index)

In [None]:
# Plot the predictions, one line per basin. Only model output is drawn here;
# no observed values are part of the `predictions` frame.
fig, ax1 = plt.subplots(figsize=(12, 6))
predictions.unstack(level='basin').plot(ax=ax1)
ax1.set_ylabel(target_label)
ax1.set_title("Test-set predictions by basin")

plt.show()

In [None]:
# Display the loader's `scale` attribute as the cell output
# (presumably the normalization parameters for `range_norm_cols` -- TODO confirm in data.py).
dataloader.scale