In [20]:
import numpy as np
import pandas as pd
import tensorflow as tf
from typing import Dict

from data import dataset_partitioners
from data.dataset_generators import WindowedDatasetGenerator
from networks.recurrent import AutoRegressiveMultiStepLSTM


config: Dict = {
    "epochs": 2,
    "batch_size": 32,
    "learning_rate": 3e-4,
    "train_data_filepath": "/mnt/c/Users/JPhillips/ldz/data/basic/train.csv",
    "test_data_filepath": "/mnt/c/Users/JPhillips/ldz/data/basic/test.csv",
}


# Load data
train_df = pd.read_csv(config["train_data_filepath"])
test_df = pd.read_csv(config["test_data_filepath"])

# Split training data into train-validation sets
data_splitter = dataset_partitioners.TimeSeriesDataSplitter(train_df)
train_df, validation_df = data_splitter(0.2)

# Build TF datasets
dataset_generator = WindowedDatasetGenerator(
    batch_size=config["batch_size"],
    input_width=24,
    label_width=24,
    shift=1,
    train_df=train_df,
    validation_df=validation_df,
    test_df=test_df,
)
train_dataset = dataset_generator.train_dataset
validation_dataset = dataset_generator.validation_dataset
test_dataset = dataset_generator.test_dataset

# model = tf.keras.load_model(cfg.predict.saved_model_path)
model = AutoRegressiveMultiStepLSTM(
    recurrent_units=256,
    num_features=66,   # Must match train_df.shape[1].
    num_out_steps=24,  # Must match config_data["window_label_width"].
)

model.compile(
    loss = tf.losses.MeanSquaredError(),
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4),
)

In [None]:
checkpoint_dir = Path(f"{cfg.logging.checkpoint_dir}/{model.name}")
if not checkpoint_dir.exists():
    Path.mkdir(checkpoint_dir, parents=True)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_dir,
    save_weights_only=False,
    verbose=1,
    monitor="val_loss",
    save_best_only=False,
    save_freq=cfg.logging.save_freq,
)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs/{model.name}", histogram_freq=1)

callbacks = [tensorboard_callback]
if cfg.logging.save_model:
    callbacks.append(checkpoint_callback)

In [None]:
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

In [21]:
history = model.fit(
    train_dataset,
    epochs=2,
    validation_data=validation_dataset,
    # callbacks=callbacks
)

print("HISTORY: \n", history)

Epoch 1/2
Epoch 2/2
HISTORY: 
 <keras.callbacks.History object at 0x7fb7b8423f10>


In [22]:
model.save("tmp/ar_model")

2022-05-17 18:20:34.454497: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: tmp/ar_model/assets


In [None]:
example = 0.5*np.ones((1,24,66))

In [58]:
model2 = tf.keras.models.load_model("tmp/ar_model")
model2.predict(example)

array([[[26.966831  ,  8.68774   ,  2.9315445 , ...,  3.6222932 ,
          1.4841013 ,  1.044779  ],
        [28.055494  ,  9.063206  ,  3.0927482 , ...,  3.8401082 ,
          1.5994511 ,  1.0131922 ],
        [28.159857  ,  9.092141  ,  3.146552  , ...,  3.870171  ,
          1.6423197 ,  0.9317352 ],
        ...,
        [28.193731  ,  9.087407  ,  3.3121045 , ...,  3.95938   ,
          1.6351811 ,  0.8109978 ],
        [28.193344  ,  9.087151  ,  3.3166075 , ...,  3.9619372 ,
          1.6338671 ,  0.80837333],
        [28.192976  ,  9.086931  ,  3.3208897 , ...,  3.9643805 ,
          1.6325773 ,  0.80587566]]], dtype=float32)

In [59]:
model3 = tf.keras.models.load_model("outputs/2022-05-17/16-00-33/saved_models/checkpoints/AutoRegressiveMultiStepLSTM")
model3.predict(example)

array([[[183.75342   ,   4.783211  ,   0.6952013 , ...,   2.3021913 ,
           0.30798957,   2.6254535 ],
        [161.63907   ,   7.3271427 ,   2.0071354 , ...,   2.5737944 ,
           0.373977  ,   2.8834164 ],
        [136.89952   ,   9.210284  ,   2.9288535 , ...,   2.569141  ,
           0.6925335 ,   2.5410297 ],
        ...,
        [ 63.747128  ,  13.833287  ,   4.4672227 , ...,   4.299303  ,
           1.426804  ,   1.5913831 ],
        [ 62.505276  ,  13.949562  ,   4.4759736 , ...,   4.3127565 ,
           1.4014196 ,   1.5984558 ],
        [ 60.562443  ,  14.198771  ,   4.5263495 , ...,   4.2959843 ,
           1.3420253 ,   1.630337  ]]], dtype=float32)