In [1]:
import os
import sys
import time
import argparse
import datetime as dt
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm, trange
from sklearn.metrics import mean_squared_error

# Project root that holds the `model`, `data` and `utils` packages imported
# below. The DEEPGR4J_ROOT environment variable overrides the default location
# so the notebook can run outside the original scratch directory.
root_dir = os.environ.get('DEEPGR4J_ROOT',
                          '/srv/scratch/z5370003/projects/DeepGR4J-Extremes')
# Re-running this cell in a live kernel must not stack duplicate entries.
if root_dir not in sys.path:
    sys.path.append(root_dir)

from model.tf.ml import ConvNet, LSTM
from model.tf.hydro import ProductionStorage
from data.tf.camels_dataset import HybridDataset
from utils.training import EarlyStopper, Trainer, TiltedLossMultiQuantile
from utils.evaluation import nse, normalize

2024-09-30 09:26:42.731804: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-30 09:26:42.751748: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-30 09:26:42.757704: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-30 09:26:42.774368: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def _build_parser():
    """Assemble the command-line interface for the quantile flow training run."""
    cli = argparse.ArgumentParser(description='Train Quantile Flow Model')

    # (flag, add_argument keyword options), registered in exactly this order.
    option_table = (
        # Data selection and output locations
        ('--window_size', dict(type=int, default=10, help='Size of the sliding window')),
        ('--camels_dir', dict(type=str, default='../../data/camels/aus', help='Directory containing CAMELS dataset')),
        ('--gr4j_logfile', dict(type=str, default='../results/gr4j/result.csv', help='GR4J calibration results file')),
        ('--station_id', dict(type=str, nargs='+', default=None, help='Station ID')),
        ('--results_dir', dict(type=str, default='../results/flow_cdf', help='Directory to save results')),
        ('--state_outlet', dict(type=str, default=None, help='State outlet')),
        ('--map_zone', dict(type=int, default=None, help='Map zone')),
        ('--ts_model', dict(type=str, default='lstm', help='Time series model')),
        # LSTM hyper-parameters
        ('--hidden_dim', dict(type=int, default=32, help='Hidden dimension for LSTM')),
        ('--lstm_dim', dict(type=int, default=64, help='LSTM dimension')),
        ('--n_layers', dict(type=int, default=4, help='Number of LSTM layers')),
        ('--ts_output_dim', dict(type=int, default=8, help='Output dimension for LSTM')),
        ('--dropout', dict(type=float, default=0.2, help='Dropout rate for LSTM')),
        # CNN hyper-parameters
        ('--n_channels', dict(type=int, default=1, help='Number of channels')),
        ('--n_filters', dict(type=int, nargs='+', default=[16, 16, 8], help='Number of filters for CNN')),
        ('--cnn_dropout', dict(type=float, default=0.1, help='Dropout rate for CNN')),
        # Optimisation and early stopping
        ('--batch_size', dict(type=int, default=512, help='Batch size')),
        ('--epochs', dict(type=int, default=200, help='Number of epochs')),
        ('--verbose', dict(type=int, default=1, help='Verbosity level')),
        ('--learning_rate', dict(type=float, default=1e-3, help='Learning rate')),
        ('--beta_1', dict(type=float, default=0.89, help='Beta 1 for Adam optimizer')),
        ('--beta_2', dict(type=float, default=0.97, help='Beta 2 for Adam optimizer')),
        ('--weight_decay', dict(type=float, default=2e-2, help='Weight decay for Adam optimizer')),
        ('--patience', dict(type=int, default=10, help='Patience for early stopping')),
        ('--min_delta', dict(type=float, default=0.01, help='Minimum change in validation loss for early stopping')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


# Parser shared by the cells below.
parser = _build_parser()

_StoreAction(option_strings=['--min_delta'], dest='min_delta', nargs=None, const=None, default=0.01, type=<class 'float'>, choices=None, help='Minimum change in validation loss for early stopping', metavar=None)

In [3]:
def get_model(args):
    """Build the time-series model selected by ``args.ts_model``.

    Parameters
    ----------
    args : argparse.Namespace-like
        Must provide ``ts_model``, ``window_size``, ``ts_input_dim``,
        ``quantiles`` and ``dropout``; additionally ``hidden_dim``,
        ``lstm_dim`` and ``n_layers`` for ``'lstm'``, or ``n_channels`` and
        ``n_filters`` for ``'cnn'``.

    Returns
    -------
    LSTM or ConvNet
        The model, configured with one output per entry of ``args.quantiles``.

    Raises
    ------
    ValueError
        If ``args.ts_model`` is neither ``'lstm'`` nor ``'cnn'``. Previously
        this fell through to an ``UnboundLocalError`` at the return statement.
    """
    if args.ts_model == 'lstm':
        return LSTM(window_size=args.window_size,
                    input_dim=args.ts_input_dim,
                    hidden_dim=args.hidden_dim,
                    lstm_dim=args.lstm_dim,
                    n_layers=args.n_layers,
                    output_dim=len(args.quantiles),
                    dropout=args.dropout)

    if args.ts_model == 'cnn':
        # NOTE(review): the parser defines --cnn_dropout, but this branch reads
        # args.dropout; left as-is to keep results unchanged. Confirm which
        # option is meant to drive the CNN dropout rate.
        return ConvNet(n_ts=args.window_size,
                       n_features=args.ts_input_dim,
                       n_channels=args.n_channels,
                       out_dim=len(args.quantiles),
                       n_filters=args.n_filters,
                       dropout_p=args.dropout)

    raise ValueError(
        f"Unknown ts_model {args.ts_model!r}; expected 'lstm' or 'cnn'"
    )

In [4]:
def generate_predictions(model, dl, loss_fn, results_dir, ts_model=None, target_scaler=None):
    """Run ``model`` over ``dl``, save one plot per station and return scores.

    Parameters
    ----------
    model : callable
        Called as ``model(timeseries, training=False)`` on every batch.
    dl : iterable of dict
        Batches with the keys ``'timeseries'``, ``'target'`` and
        ``'station_id'``.
    loss_fn : callable
        Called as ``loss_fn(true, preds)`` on the concatenated (still scaled)
        tensors; the value is only printed.
    results_dir : str
        Directory (created if missing) that receives ``<station>.png`` files.
    ts_model : str, optional
        Model kind; ``'cnn'`` adds a trailing axis to the inputs. Defaults to
        the notebook-level ``args.ts_model``.
    target_scaler : object with ``inverse_transform``, optional
        Used to map predictions and targets back to their original scale.
        Defaults to the notebook-level ``camels_ds.target_scaler``.

    Returns
    -------
    tuple of dict
        ``(mse_score, nse_score, nnse_score)``, each keyed by the station id
        exactly as it appears in ``batch['station_id']``. ``nnse_score`` holds
        ``normalize(nse)`` of the corresponding NSE value.

    Raises
    ------
    ValueError
        If ``dl`` yields no batches.
    """
    # Fall back to the notebook globals so existing four-argument calls keep
    # working, while still allowing the function to be used without them.
    if ts_model is None:
        ts_model = args.ts_model
    if target_scaler is None:
        target_scaler = camels_ds.target_scaler

    preds, true, stations = [], [], []
    for batch in dl:
        timeseries = batch['timeseries']
        if ts_model == 'cnn':
            # The CNN path is fed inputs with one extra trailing axis.
            timeseries = tf.expand_dims(timeseries, axis=-1)
        preds.append(model(timeseries, training=False))
        true.append(batch['target'])
        stations.append(batch['station_id'])

    if not preds:
        raise ValueError("generate_predictions received an empty data loader")

    preds = tf.concat(preds, axis=0)
    true = tf.concat(true, axis=0)
    stations = tf.concat(stations, axis=0)

    # The loss is reported on the scaled values, before the inverse transform.
    print(f"{results_dir} loss: {loss_fn(true, preds).numpy():.4f}")

    # Back to the original target scale, as numpy arrays.
    preds = target_scaler.inverse_transform(preds.numpy())
    true = target_scaler.inverse_transform(true.numpy())

    # Only the predictions are clipped at zero; the targets are left as-is.
    preds = np.clip(preds, 0, None)
    stations = stations.numpy().flatten()

    os.makedirs(results_dir, exist_ok=True)
    mse_score = {}
    nse_score = {}
    nnse_score = {}

    n_outputs = preds.shape[-1]
    # Scores use the middle output column (presumably the median quantile
    # when several quantiles are predicted -- confirm).
    score_col = n_outputs // 2

    for station in np.unique(stations):
        idx = (stations == station)
        observed = true[idx, -1]
        # Station ids arrive as bytes from TensorFlow string tensors; fall back
        # to str() so other id types still produce a usable file name.
        station_name = station.decode("utf-8") if isinstance(station, bytes) else str(station)

        fig, ax = plt.subplots(figsize=(16, 6))
        try:
            ax.plot(observed, label="True")
            if n_outputs > 1:
                for i in range(n_outputs):
                    ax.plot(preds[idx, i], alpha=0.65, label=f"Pred {i}")
            else:
                ax.plot(preds[idx, 0], alpha=0.65, label="Pred")
            ax.legend()
            fig.savefig(os.path.join(results_dir, f'{station_name}.png'))
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

        mse_score[station] = mean_squared_error(observed, preds[idx, score_col])
        nse_score[station] = nse(observed, preds[idx, score_col])
        nnse_score[station] = normalize(nse_score[station])

    return mse_score, nse_score, nnse_score


In [5]:
# Parse an explicitly empty argument list so every option takes its default;
# the kernel's own command line is not consulted.
args = parser.parse_args(args=[])

In [6]:
# Experiment-specific overrides applied on top of the parsed defaults.
# `ts_input_dim` is not a parser option, so it is added to the namespace here.
# NOTE(review): ts_input_dim is 9 while ts_vars below lists 5 variables;
# presumably the remaining inputs are produced by the dataset -- confirm.
vars(args).update(
    ts_model='cnn',
    n_filters=[8, 8, 6],
    batch_size=256,
    window_size=7,
    ts_input_dim=9,
    camels_dir='../../../data/camels/aus',
    gr4j_logfile='../../results/gr4j/result.csv',
    station_id=['616065'],
    epochs=500,
    weight_decay=0.1,
    dropout=0.2,
    min_delta=0.01,
    patience=10,
    n_layers=1,
)
# The output folder name embeds the model kind chosen above.
args.results_dir = f'../../results/deepgr4j_{args.ts_model}_test'

# Time-series variables handed to HybridDataset as `ts_vars`.
ts_vars = [
    'precipitation_AWAP',
    'et_morton_actual_SILO',
    'tmax_awap',
    'tmin_awap',
    'vprp_awap',
]

print(args)

Namespace(window_size=7, camels_dir='../../../data/camels/aus', gr4j_logfile='../../results/gr4j/result.csv', station_id=['616065'], results_dir='../../results/deepgr4j_cnn_test', state_outlet=None, map_zone=None, ts_model='cnn', hidden_dim=32, lstm_dim=64, n_layers=1, ts_output_dim=8, dropout=0.2, n_channels=1, n_filters=[8, 8, 6], cnn_dropout=0.1, batch_size=256, epochs=500, verbose=1, learning_rate=0.001, beta_1=0.89, beta_2=0.97, weight_decay=0.1, patience=10, min_delta=0.01, ts_input_dim=9)


In [7]:
# Quantile levels to predict. get_model() sizes the model output as
# len(args.quantiles), so a single level (0.5) gives a single-output model.
args.quantiles = tf.convert_to_tensor([0.5])

# Production storage object (from model.tf.hydro) handed to the dataset below
# and later reachable as camels_ds.prod.
prod = ProductionStorage()

# Build the hybrid dataset from the CAMELS directory and the GR4J calibration
# log, using the ts_vars list as inputs and 'streamflow_mmd' as the target.
camels_ds = HybridDataset(data_dir=args.camels_dir,
                          gr4j_logfile=args.gr4j_logfile,
                          prod=prod,
                          ts_vars=ts_vars,
                          target_vars=['streamflow_mmd'],
                          window_size=args.window_size)

# NOTE(review): station_list and results_all are not used in the cells visible
# here -- confirm whether later cells still need them.
station_list = camels_ds.get_station_list()
results_all = []

# Prepare the data for the selected station(s) / region filters. The recorded
# run prints "Fitting scalers" at this step.
camels_ds.prepare_data(station_list=args.station_id,
                       state_outlet=args.state_outlet,
                       map_zone=args.map_zone)

Fitting scalers


In [8]:
# Split the prepared data into training and test datasets. test_size=0.3 is
# passed straight to get_datasets (presumably a 30% hold-out -- confirm whether
# the split is chronological or random).
train_ds, test_ds = camels_ds.get_datasets(test_size=0.3)

In [9]:
# Instantiate the network and push one random batch through it before calling
# summary() (presumably so the layers are built -- confirm).
model = get_model(args)
dummy_shape = (args.batch_size, args.window_size, args.ts_input_dim)
if args.ts_model != 'lstm':
    # Every non-LSTM model is fed an extra trailing axis of size 1.
    dummy_shape = dummy_shape + (1,)
model(tf.random.normal(dummy_shape, dtype=tf.float32))

model.summary()

# Adam configured from the parsed hyper-parameters.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=args.learning_rate,
    beta_1=args.beta_1,
    beta_2=args.beta_2,
    epsilon=1e-8,
    amsgrad=False,
    weight_decay=args.weight_decay,
)

# Plain mean squared error is used for this run; the alternative left by the
# author is TiltedLossMultiQuantile(quantiles=args.quantiles).
loss_fn = tf.keras.losses.MeanSquaredError()

# NOTE(review): the recorded run goes past epoch 266 while the validation loss
# sits near 0.0004, so early stopping with min_delta=0.01 apparently never
# triggered -- check EarlyStopper's min_delta semantics against the loss scale.
early_stopper = EarlyStopper(patience=args.patience, min_delta=args.min_delta)

# Wire model, optimizer, loss and early stopping into the project trainer.
trainer = Trainer(
    model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    model_type='ts',
    early_stopper=early_stopper,
)

# Fit on the training split, validating on the test split.
model, train_losses, test_losses = trainer.train(train_ds, test_ds, args)

2024-09-30 09:27:48.660706: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-09-30 09:27:48.821888: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 1
Training loss over epoch: 0.0265
Validation loss over epoch: 0.0093
Time taken: 1.64s


2024-09-30 09:27:49.161598: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 2
Training loss over epoch: 0.0077
Validation loss over epoch: 0.0049
Time taken: 0.34s

Epoch 3
Training loss over epoch: 0.0050
Validation loss over epoch: 0.0038
Time taken: 0.21s


2024-09-30 09:27:49.591015: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 4
Training loss over epoch: 0.0038
Validation loss over epoch: 0.0030
Time taken: 0.21s

Epoch 5
Training loss over epoch: 0.0031
Validation loss over epoch: 0.0023
Time taken: 0.34s

Epoch 6
Training loss over epoch: 0.0026
Validation loss over epoch: 0.0019
Time taken: 0.21s

Epoch 7
Training loss over epoch: 0.0022
Validation loss over epoch: 0.0017
Time taken: 0.22s


2024-09-30 09:27:50.589824: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 8
Training loss over epoch: 0.0019
Validation loss over epoch: 0.0013
Time taken: 0.23s

Epoch 9
Training loss over epoch: 0.0017
Validation loss over epoch: 0.0012
Time taken: 0.34s

Epoch 10
Training loss over epoch: 0.0016
Validation loss over epoch: 0.0010
Time taken: 0.24s

Epoch 11
Training loss over epoch: 0.0015
Validation loss over epoch: 0.0009
Time taken: 0.26s

Epoch 12
Training loss over epoch: 0.0014
Validation loss over epoch: 0.0010
Time taken: 0.34s

Epoch 13
Training loss over epoch: 0.0013
Validation loss over epoch: 0.0011
Time taken: 0.23s

Epoch 14
Training loss over epoch: 0.0012
Validation loss over epoch: 0.0008
Time taken: 0.21s

Epoch 15
Training loss over epoch: 0.0012
Validation loss over epoch: 0.0007
Time taken: 0.22s


2024-09-30 09:27:52.659579: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 16
Training loss over epoch: 0.0011
Validation loss over epoch: 0.0007
Time taken: 0.22s

Epoch 17
Training loss over epoch: 0.0011
Validation loss over epoch: 0.0008
Time taken: 0.22s

Epoch 18
Training loss over epoch: 0.0010
Validation loss over epoch: 0.0007
Time taken: 0.23s

Epoch 19
Training loss over epoch: 0.0010
Validation loss over epoch: 0.0007
Time taken: 0.24s

Epoch 20
Training loss over epoch: 0.0009
Validation loss over epoch: 0.0008
Time taken: 0.22s

Epoch 21
Training loss over epoch: 0.0009
Validation loss over epoch: 0.0006
Time taken: 0.21s

Epoch 22
Training loss over epoch: 0.0009
Validation loss over epoch: 0.0006
Time taken: 0.23s

Epoch 23
Training loss over epoch: 0.0008
Validation loss over epoch: 0.0006
Time taken: 0.22s

Epoch 24
Training loss over epoch: 0.0008
Validation loss over epoch: 0.0006
Time taken: 0.22s

Epoch 25
Training loss over epoch: 0.0008
Validation loss over epoch: 0.0005
Time taken: 0.22s

Epoch 26
Training loss over epoch: 0.00

2024-09-30 09:27:56.315387: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 32
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.34s

Epoch 33
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0004
Time taken: 0.34s

Epoch 34
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.34s

Epoch 35
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.22s

Epoch 36
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.21s

Epoch 37
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0004
Time taken: 0.21s

Epoch 38
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0004
Time taken: 0.21s

Epoch 39
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.21s

Epoch 40
Training loss over epoch: 0.0006
Validation loss over epoch: 0.0005
Time taken: 0.21s

Epoch 41
Training loss over epoch: 0.0005
Validation loss over epoch: 0.0004
Time taken: 0.23s

Epoch 42
Training loss over epoch: 0.00

2024-09-30 09:28:03.810633: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 65
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.19s

Epoch 66
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.33s

Epoch 67
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0003
Time taken: 0.18s

Epoch 68
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 69
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 70
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 71
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 72
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 73
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 74
Training loss over epoch: 0.0004
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 75
Training loss over epoch: 0.00

2024-09-30 09:28:14.421952: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 128
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 129
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 130
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 131
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 132
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 133
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0005
Time taken: 0.16s

Epoch 134
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 135
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 136
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 137
Training loss over epoch: 0.0003
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 138
Training loss over 

2024-09-30 09:28:35.082184: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



Epoch 256
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0005
Time taken: 0.17s

Epoch 257
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 258
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 259
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0005
Time taken: 0.16s

Epoch 260
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 261
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.17s

Epoch 262
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.20s

Epoch 263
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0005
Time taken: 0.17s

Epoch 264
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0005
Time taken: 0.16s

Epoch 265
Training loss over epoch: 0.0002
Validation loss over epoch: 0.0004
Time taken: 0.16s

Epoch 266
Training loss over 

In [10]:
# Pick the output sub-directory: '<state_outlet>_<map_zone>' when either region
# filter is set, otherwise 'aus'.
# NOTE(review): if only one filter is set, the other appears literally as
# 'None' in the folder name -- confirm that this is intended.
results_dir = os.path.join(
    args.results_dir,
    'aus' if (args.state_outlet is None and args.map_zone is None)
    else f'{args.state_outlet}_{args.map_zone}',
)

In [11]:
# Score the fitted model on the training split; per-station plots are written
# to <results_dir>/training.
os.makedirs(args.results_dir, exist_ok=True)
train_scores = generate_predictions(
    model,
    train_ds.batch(args.batch_size),
    loss_fn,
    os.path.join(results_dir, 'training'),
)
train_mse, train_nse, train_nnse = train_scores
print(train_nse, train_nnse)

../../results/deepgr4j_cnn_test/aus/training loss: 0.0002
{b'616065': 0.9359732866287231} {b'616065': 0.9398260282691459}


In [12]:
# Score the fitted model on the held-out split; per-station plots are written
# to <results_dir>/testing.
test_scores = generate_predictions(
    model,
    test_ds.batch(args.batch_size),
    loss_fn,
    os.path.join(results_dir, 'testing'),
)
test_mse, test_nse, test_nnse = test_scores
print(test_nse, test_nnse)

../../results/deepgr4j_cnn_test/aus/testing loss: 0.0004
{b'616065': 0.20985883474349976} {b'616065': 0.5586151636576187}


In [13]:
# Inspect the x1 value held by the production storage attached to the dataset
# (presumably the GR4J production-store capacity parameter -- confirm).
camels_ds.prod.get_x1()

1200.0

In [14]:
# Peek at the dataset's time-series array (values look standardised after the
# scaler fit -- confirm). NOTE(review): this echoes the full array repr;
# consider showing camels_ds.ts_arr.shape or a small slice instead.
camels_ds.ts_arr

array([[[-0.38690794, -1.1614778 ,  1.8360896 , ..., -0.7012847 ,
         -0.31864512, -0.88869184],
        [-0.38690794,  1.01882   ,  1.5434538 , ...,  1.1764288 ,
         -0.31864512, -0.88869184],
        [-0.38690794,  1.3845892 ,  1.2760656 , ...,  1.4914361 ,
         -0.31864512, -0.88869184],
        ...,
        [-0.36000356,  0.597657  ,  2.1198955 , ...,  0.71985894,
         -0.31864512, -0.88869184],
        [-0.36651716,  1.4623109 ,  1.9932163 , ...,  1.4872379 ,
         -0.31864512, -0.88869184],
        [-0.38690794,  2.157238  ,  0.9718659 , ...,  2.1568558 ,
         -0.31864512, -0.88869184]],

       [[-0.38690794,  1.01882   ,  1.5434538 , ...,  1.1764288 ,
         -0.31864512, -0.88869184],
        [-0.38690794,  1.3845892 ,  1.2760656 , ...,  1.4914361 ,
         -0.31864512, -0.88869184],
        [-0.38690794,  1.4809911 ,  1.8608522 , ...,  1.5744592 ,
         -0.31864512, -0.88869184],
        ...,
        [-0.36651716,  1.4623109 ,  1.9932163 , ...,  