<a href="https://colab.research.google.com/github/aaarrti/ISS/blob/main/notebooks/example_electricity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1. Train the model

In [1]:
# Authenticate the Colab user (presumably required to access the gs:// paths
# used further down — confirm against the bucket permissions).
from google.colab import auth

auth.authenticate_user()

Install the TFT re-implementation directly from git until a release is available.

In [4]:
# Hide jax, since we don't use it here.
!python -m pip uninstall --yes jax jaxlib chex optax orbax-checkpoint dopamine-rl
!python -m pip install -U keras_tuner
!python -m pip install 'git+https://github.com/aaarrti/tf2_temporal_fusion_transformer.git@dev'

Found existing installation: jax 0.4.12
Uninstalling jax-0.4.12:
  Successfully uninstalled jax-0.4.12
Found existing installation: jaxlib 0.4.12
Uninstalling jaxlib-0.4.12:
  Successfully uninstalled jaxlib-0.4.12
Found existing installation: chex 0.1.7
Uninstalling chex-0.1.7:
  Successfully uninstalled chex-0.1.7
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/aaarrti/tf2_temporal_fusion_transformer.git@dev
  Cloning https://github.com/aaarrti/tf2_temporal_fusion_transformer.git (to revision dev) to /tmp/pip-req-build-ced2uc3e
  Running command git clone --filter=blob:none --quiet https://github.com/aaarrti/tf2_temporal_fusion_transformer.git /tmp/pip-req-build-ced2uc3e
  Running command git checkout -b dev --track origin/dev
  Switched to a new branch 'dev'
  Branch 'dev' set up to track remote branch

In [3]:
# Register the %tensorboard magic used further down.
%load_ext tensorboard
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [5]:
import tensorflow as tf
import numpy as np
from keras.utils.tf_utils import can_jit_compile, set_random_seed
from keras.callbacks import TensorBoard, TerminateOnNaN, BackupAndRestore
from temporal_fusion_transformer.experiments import electricity_experiment
from temporal_fusion_transformer.tf.quantile_loss import QuantileLoss, QuantileRMSE
from temporal_fusion_transformer.utils import map_dict, filter_dict, make_tft_model
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import keras_tuner as kt
import pickle
import matplotx
from temporal_fusion_transformer.plotting import plot_predictions
from keras.api.keras.experimental import CosineDecay

# Globally apply matplotx's dracula style, passed through `duftify`.
plt.style.use(matplotx.styles.duftify(matplotx.styles.dracula))

In [6]:
# Resolve the TPU attached to this runtime, connect to it and initialise it.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
# Distribution strategy whose scope is used below when creating datasets and the tuner.
strategy = tf.distribute.TPUStrategy(tpu, experimental_spmd_xla_partitioning=True)
# Display the visible logical devices as a sanity check.
tf.config.list_logical_devices()

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU_SYSTEM:0', device_type='TPU_SYSTEM'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'),
 LogicalDevice(name='/job:worker/replica:0/task:0/device:XLA_CPU:0', device_typ

General setup.

In [7]:
prng_seed = 42
epochs = 5
# 8 TPU cores
batch_size = 512 * 8
# Hard-coded sample count, used only to derive steps_per_epoch.
# NOTE(review): presumably the size of the training split — confirm against the data.
num_electricity_samples = 1853057
steps_per_epoch = num_electricity_samples // batch_size
set_random_seed(prng_seed)

# Enable bfloat16 mixed precision and XLA auto-clustering only when
# can_jit_compile reports that JIT compilation is available.
if can_jit_compile(True):
    tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
    tf.config.optimizer.set_jit("autoclustering")

# Dtype that map_fn casts inputs and targets to; also shown as the cell output.
compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
compute_dtype

'bfloat16'

Load data.

In [8]:
# Structure of one element of the saved datasets; passed to
# tf.data.Dataset.load when reading the training/validation shards.
element_spec = {
    "identifier": tf.TensorSpec([None, 192, 1], dtype=tf.string),
    "time": tf.TensorSpec([None, 192, 1], dtype=tf.float32),
    "outputs": tf.TensorSpec([None, 24, 1], dtype=tf.float32),
    "inputs_static": tf.TensorSpec([None, 1], dtype=tf.int32),
    "inputs_known_real": tf.TensorSpec([None, 192, 3], dtype=tf.float32),
}

In [9]:
def map_fn(arg):
    """Convert a raw dataset element into an ``(inputs, targets)`` pair.

    Only ``inputs_static``, ``inputs_known_real`` and ``outputs`` are read
    from ``arg``. The real-valued inputs and the targets are cast to the
    global ``compute_dtype``; the static input is passed through unchanged.
    """
    features = {
        "static": arg["inputs_static"],
        "known_real": tf.cast(arg["inputs_known_real"], compute_dtype),
    }
    targets = tf.cast(arg["outputs"], compute_dtype)
    return features, targets


def _load_split(split: str, num_shards: int, training: bool) -> tf.data.Dataset:
    """Build the input pipeline for one split stored as numbered shards on GCS.

    Args:
        split: Sub-directory name of the split ("train" or "validation").
        num_shards: Number of shard directories ``0 .. num_shards - 1`` to read.
        training: When true, batches are reshuffled on every pass and the
            dataset is repeated ``epochs`` times; otherwise a single
            deterministic pass is produced.

    Returns:
        A dataset of ``(inputs, targets)`` pairs as produced by ``map_fn``,
        batched to ``batch_size`` with the remainder dropped.
    """
    shard_paths = [
        f"gs://tf2_tft_v2/data/electricity/{split}/{i}" for i in range(num_shards)
    ]
    ds = (
        tf.data.Dataset.from_tensor_slices(shard_paths)
        .flat_map(lambda path: tf.data.Dataset.load(path, element_spec=element_spec))
        .rebatch(batch_size, drop_remainder=True)
        .map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
        # Cache *before* shuffling: caching after the shuffle would freeze the
        # first shuffled order and replay it on every subsequent pass.
        .cache()
    )
    if training:
        # Elements are already batched here, so this shuffles the order of batches.
        ds = ds.shuffle(
            batch_size, seed=prng_seed, reshuffle_each_iteration=True
        ).repeat(epochs)
    return ds.prefetch(tf.data.AUTOTUNE)


with strategy.scope():
    train_ds = _load_split("train", num_shards=19, training=True)
    # Validation needs neither shuffling nor repetition: Keras iterates the
    # finite dataset once per validation run.
    validation_ds = _load_split("validation", num_shards=3, training=False)


train_ds.element_spec

({'static': TensorSpec(shape=(4096, 1), dtype=tf.int32, name=None),
  'known_real': TensorSpec(shape=(4096, 192, 3), dtype=tf.bfloat16, name=None)},
 TensorSpec(shape=(4096, 24, 1), dtype=tf.bfloat16, name=None))

Look for the best hyperparameters.

In [None]:
def hyper_model(hyper_params: kt.HyperParameters) -> tf.keras.Model:
    """Build and compile a TFT model for a single KerasTuner trial.

    Search space:
      * ``num_attention_heads``: integer in [1, 14].
      * ``hidden_layer_size``: integer in [5, 25].
      * ``grad_clip_norm``: whether gradients are clipped by norm; when true,
        ``clip_norm_val`` is sampled log-uniformly from [0.01, 10].
      * ``cosine_decay``: whether the learning rate follows a ``CosineDecay``
        schedule; when true, ``decay_steps`` and ``decay_alpha`` are sampled too.
      * ``learning_rate``: sampled log-uniformly from [1e-4, 0.01] in either case.

    Defaults are taken from ``electricity_experiment.default_params``.

    Args:
        hyper_params: Hyperparameter container supplied by the tuner.

    Returns:
        The compiled model created by ``make_tft_model`` (presumably a Keras
        model — it is compiled here and later fitted by the tuner).
    """
    defaults = electricity_experiment.default_params

    def sample_learning_rate() -> float:
        # Must be called inside a ``cosine_decay`` conditional scope so that the
        # hyperparameter is registered under the currently active condition.
        return hyper_params.Float(
            "learning_rate",
            1e-4,
            0.01,
            sampling="log",
            default=defaults.learning_rate,
        )

    num_attention_heads = hyper_params.Int(
        "num_attention_heads",
        min_value=1,
        max_value=14,
        default=defaults.num_attention_heads,
    )
    hidden_layer_size = hyper_params.Int(
        "hidden_layer_size",
        min_value=5,
        max_value=25,
        default=defaults.hidden_layer_size,
    )

    grad_clip_norm = hyper_params.Boolean("grad_clip_norm")
    cosine_decay = hyper_params.Boolean("cosine_decay")

    # No clipping unless the trial enables it.
    clip_norm_val = None
    with hyper_params.conditional_scope("grad_clip_norm", True):
        if grad_clip_norm:
            clip_norm_val = hyper_params.Float(
                "clip_norm_val",
                min_value=0.01,
                max_value=10,
                sampling="log",
                default=defaults.max_gradient_norm,
            )

    with hyper_params.conditional_scope("cosine_decay", True):
        if cosine_decay:
            init_learning_rate = sample_learning_rate()
            decay_steps = hyper_params.Float(
                "decay_steps", 0.1 * steps_per_epoch, steps_per_epoch, sampling="log"
            )
            decay_alpha = hyper_params.Float("decay_alpha", 0.1, 0.5, sampling="log")
            learning_rate = CosineDecay(init_learning_rate, decay_steps, decay_alpha)

    with hyper_params.conditional_scope("cosine_decay", False):
        if not cosine_decay:
            learning_rate = sample_learning_rate()

    hp_model = make_tft_model(
        electricity_experiment,
        unroll_lstm=True,
        num_attention_heads=num_attention_heads,
        hidden_layer_size=hidden_layer_size,
        prng_seed=prng_seed,
    )

    use_jit = can_jit_compile(True)
    hp_model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            jit_compile=use_jit,
            clipnorm=clip_norm_val,
        ),
        loss=QuantileLoss(hp_model.quantiles),
        jit_compile=use_jit,
        metrics=[QuantileRMSE(hp_model.quantiles)],
    )
    return hp_model


with strategy.scope():
    # Hyperband search over the space defined in `hyper_model`.
    tuner: kt.Hyperband = kt.Hyperband(
        hyper_model,
        # Select hyperparameters on the held-out split: `validation_data` is
        # passed to `search`, so Keras reports the metric with a `val_` prefix.
        # Optimising the training metric would favour over-fitted configurations.
        objective=kt.Objective("val_quantile_rmse", direction="min"),
        max_epochs=epochs,
        factor=3,
        directory="gs://tf2_tft_v2/logs/keras_tuner",
        project_name="tft_electricity",
        max_consecutive_failed_trials=1,
        seed=prng_seed,
    )

    tuner.search(
        train_ds,
        epochs=epochs,
        validation_data=validation_ds,
        callbacks=[
            TensorBoard(
                "gs://tf2_tft_v2/logs/tensorboard",
                update_freq=10,
                write_steps_per_second=True,
            ),
            TerminateOnNaN(),
            # NOTE(review): the backup directory is shared by all trials —
            # confirm that restoring across trials with different
            # architectures cannot happen after an interrupted run.
            BackupAndRestore("gs://tf2_tft_v2/logs/checkpoints"),
        ],
        steps_per_epoch=steps_per_epoch,
    )

    best_hps = tuner.get_best_hyperparameters()[0]

Train the model with the best hyperparameters.

In [None]:
# The model (and therefore its variables and optimizer) must be created inside
# the strategy scope, otherwise it is not placed on / distributed across the TPU.
with strategy.scope():
    model = tuner.hypermodel.build(best_hps)

history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=validation_ds,
    callbacks=[
        TensorBoard(
            "gs://tf2_tft_v2/logs/tensorboard",
            update_freq=10,
            write_steps_per_second=True,
        ),
        # Stop immediately if the loss becomes NaN.
        TerminateOnNaN(),
        # Allows an interrupted run to resume from the last completed epoch.
        BackupAndRestore("gs://tf2_tft_v2/logs/checkpoints"),
    ],
    steps_per_epoch=steps_per_epoch,
)

Launch TensorBoard.

In [None]:
# Points at the same log directory the TensorBoard callbacks above write to.
%tensorboard --logdir "gs://tf2_tft_v2/logs/tensorboard/"

Save weights.

In [None]:
# Save in the TF checkpoint format (path without an HDF5-style suffix): a
# `.keras`/`.h5` suffix makes `save_weights` write an HDF5 file, which cannot
# be written to a `gs://` location. This is also the path referenced by the
# `load_weights` call in the inference section.
model.save_weights("gs://tf2_tft_v2/data/electricity/weights/weights")

# Part 2. Inference and explanations.

In [None]:
# SECURITY: unpickling executes arbitrary code — only load scaler files that
# were produced by a trusted source.
# Selects the scaler stored under `.target["MT_001"]` of the pickled object.
with open("../data/electricity/scalers.pickle", "rb") as file:
    target_scaler: StandardScaler = pickle.load(file, fix_imports=True).target["MT_001"]

In [None]:
# TODO: update experiment with best hyper params.
# NOTE(review): the model is rebuilt from `electricity_experiment` without any
# overrides and, while the load_weights call below stays commented out, no
# trained weights are loaded — predictions further down presumably come from
# freshly initialised weights.
model = make_tft_model(electricity_experiment)
# model.load_weights("gs://tf2_tft_v2/data/electricity/weights/weights")
model.jit_compile = can_jit_compile(True)

In [None]:
# Rebinds the global batch size used during training to a small inference batch.
batch_size = 64

# Despite its name, `test_ds` holds a single batch (a dict of NumPy arrays):
# the first element of the re-batched test dataset.
test_ds = next(
    tf.data.Dataset.load("../data/electricity/test/0")
    .rebatch(batch_size, drop_remainder=True)
    .as_numpy_iterator()
)
# Show the shape of every entry in the batch.
map_dict(test_ds, value_mapper=np.shape)

In [None]:
# Make sure we do prediction for 1 entity: display the identifiers in the
# batch (last axis dropped) so they can be checked by eye.
test_ds["identifier"][..., 0]

In [None]:
# Time stamps and target values of the test batch, kept aside for plotting.
t, outputs = test_ds["time"], test_ds["outputs"]


def rename_inputs(arg):
    """Translate a raw dataset key into the key used for model inputs.

    Returns ``None`` for any key other than ``inputs_static`` and
    ``inputs_known_real``.
    """
    model_input_names = {
        "inputs_static": "static",
        "inputs_known_real": "known_real",
    }
    return model_input_names.get(arg)


# Keep only the two model inputs and rename them via `rename_inputs`.
x_batch = map_dict(
    filter_dict(
        test_ds, key_filter=lambda k: k in ("inputs_static", "inputs_known_real")
    ),
    key_mapper=rename_inputs,
)
# Pass the mapper by keyword, as in the shape check on the full test batch, so
# it is applied to the values regardless of map_dict's positional parameter order.
map_dict(x_batch, value_mapper=np.shape)

In [None]:
# Run the model on the prepared input batch.
logits = model.predict(
    x_batch,
    batch_size=batch_size,
)
# q_01_logits, q_05_logits, q_09_logits = logits[...,0], logits[...,1], logits[...,2]

# Split time stamps and targets at index `model.num_encoder_steps` along axis 1.
# NOTE(review): the training element_spec declares `outputs` with 24 steps but
# `time` with 192; if the test data has the same layout, slicing `outputs` at
# num_encoder_steps will not separate past from future values — confirm.
past_time = t[:, : model.num_encoder_steps, 0]
past_outputs = outputs[:, : model.num_encoder_steps]

look_ahead_time = t[:, model.num_encoder_steps :, 0]
look_ahead_outputs = outputs[:, model.num_encoder_steps :]


def scale_target(_, arr):
    """Undo the target scaling for plotting.

    The first positional argument is ignored; ``arr`` is passed to
    ``target_scaler.inverse_transform`` and the result is returned.
    """
    return target_scaler.inverse_transform(arr)


# Plot the predictions together with past and future target values, using the
# dracula style for this figure.
with plt.style.context(matplotx.styles.duftify(matplotx.styles.dracula)):
    plot_predictions(
        predicted_outputs=logits,
        future_timestamps=look_ahead_time,
        past_outputs=past_outputs,
        past_time_stamps=past_time,
        # Callable that maps values back through the target scaler.
        target_scaler=scale_target,
        num_outputs=1,
        quantiles=model.quantiles,
        future_outputs=look_ahead_outputs,
    )
    plt.tight_layout()
    plt.show()