In [None]:
import math
from datetime import datetime

import hvplot
import jax
import numpy as np
import optax
import polars as pl
from bokeh.models import DatetimeTickFormatter
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer, StandardScaler
import jax.numpy as jnp
from flax.training.early_stopping import EarlyStopping
import temporal_fusion_transformer as tft
from toolz import functoolz
import gc


# Activate the bokeh backend for hvplot-based figures.
hvplot.extension("bokeh")

# Tick formatter rendering month ticks as e.g. "Jan 1960".
xformatter = DatetimeTickFormatter(months="%b %Y")

In [None]:
# --- Window geometry -------------------------------------------------------
# Length of every sliding window (passed to timeseries_from_array and the model).
total_time_steps = 12
# Number of leading steps of each window handed to unpack_xy as the encoder part.
encoder_steps = 9

# --- Training --------------------------------------------------------------
batch_size = 8
num_epochs = 100

# NOTE(review): not referenced anywhere else in the visible notebook.
num_inputs = 3

In [None]:
# Load the monthly passenger counts and reshape them into (id, ts, y) plus
# calendar features derived from the timestamp.
raw_passengers = pl.read_csv("../data/air_passengers/AirPassengers.csv", try_parse_dates=True)

df = (
    raw_passengers
    # Parse "YYYY-MM" strings, snap them to the last day of the month and
    # attach a constant series id (single time series).
    .with_columns(
        pl.col("Month").str.to_date("%Y-%m").dt.month_end(),
        id=0,
    )
    .sort("Month")
    .upsample("Month", every="1mo")
    .rename({"#Passengers": "y", "Month": "ts"})
    .select("id", "ts", "y")
    .with_columns(
        month=pl.col("ts").dt.month(),
        year=pl.col("ts").dt.year(),
    )
)
df.head(10)

In [None]:
# Boundary date handed to the split helpers (plot here, actual split later).
test_boundary = datetime(year=1960, month=1, day=1)

split_plot = tft.utils.plot_split(df, test_boundary)
split_plot

In [None]:
# Shifts the month column down by one (dt.month() yields 1-based months).
month_shift = FunctionTransformer(lambda x: x - 1)

# Output column order follows this list: id, year, month, y.
column_steps = [
    ("id", "passthrough", ["id"]),
    ("year", StandardScaler(), ["year"]),
    ("month", month_shift, ["month"]),
    ("y", StandardScaler(), ["y"]),
]

preprocessor = ColumnTransformer(column_steps, verbose=True)

# NOTE(review): the scalers are fitted on the full frame, i.e. before the
# train/test split performed in a later cell -- confirm this is intended, since
# test-period statistics leak into the scaling.
preprocessor.fit(df)

In [None]:
# `split_dataframe` is never imported or defined in this notebook, so the bare
# name raises NameError on a fresh kernel. Call it through `tft.utils` like the
# other helpers (plot_split, timeseries_from_array, ...).
# NOTE(review): assumes the helper lives in `tft.utils` -- confirm.
train_dataframe, test_dataframe = tft.utils.split_dataframe(df, test_boundary)
len(train_dataframe), len(test_dataframe)

In [None]:
# Apply the already-fitted preprocessor to both partitions.
train_arr, test_arr = (
    preprocessor.transform(partition) for partition in (train_dataframe, test_dataframe)
)
train_arr.shape, test_arr.shape

In [None]:
# Build windows of `total_time_steps` steps from each preprocessed array.
xy_train, xy_test = (
    tft.utils.timeseries_from_array(values, total_time_steps) for values in (train_arr, test_arr)
)
xy_train.shape, xy_test.shape

In [None]:
# Separate every windowed array into model inputs and targets.
(x_train, y_train), (x_test, y_test) = (
    tft.utils.unpack_xy(windows, encoder_steps=encoder_steps) for windows in (xy_train, xy_test)
)
x_train.shape, y_train.shape, x_test.shape, y_test.shape

In [None]:
# Three independent PRNG keys: parameter init, the train-state key, and
# per-epoch batch shuffling.
init_key, dropout_key, shuffle_key = jax.random.split(jax.random.PRNGKey(69), 3)

model = tft.TemporalFusionTransformer(
    total_time_steps=total_time_steps,
    num_decoder_blocks=1,
    num_attention_heads=4,
    # id
    input_static_idx=[0],
    # year
    input_known_real_idx=[1],
    # month
    input_known_categorical_idx=[2],
    input_observed_idx=[],
    # Use the config constant (previously a literal 9) so the model stays in
    # sync with the value given to unpack_xy.
    num_encoder_steps=encoder_steps,
    static_categories_sizes=[1],
    known_categories_sizes=[12],
    latent_dim=16,
)

# Initialise parameters from one batch-sized slice (previously a literal 8).
params = model.init(init_key, x_train[:batch_size])
tx = optax.chain(
    optax.adaptive_grad_clip(0.1),
    optax.adam(5e-4),
)

state = tft.train_lib.TrainState.create(
    apply_fn=model.apply,
    tx=tx,
    params=params["params"],
    prng_key=dropout_key,
)
early_stopping = EarlyStopping(min_delta=0.1)

# Kept for inspection; not used by the loop below.
num_train_batches = math.ceil(len(x_train) / batch_size)

for epoch_id in range(num_epochs):
    # Derive a fresh shuffling key for this epoch.
    shuffle_key = jax.random.fold_in(shuffle_key, epoch_id)
    # Per-batch losses are collected in separate lists so that `train_loss` /
    # `test_loss` always denote the scalar epoch means.
    train_losses = []
    test_losses = []

    for step_id, x_batch, y_batch in tft.train_lib.enumerate_batches(
        x_train, y_train, batch_size, prng_key=shuffle_key
    ):
        state, train_loss_i = tft.train_lib.train_step(state, x_batch, y_batch)
        train_losses.append(train_loss_i)

    for _, x_batch, y_batch in tft.train_lib.enumerate_batches(
        x_test, y_test, batch_size, prng_key=shuffle_key
    ):
        test_losses.append(tft.train_lib.eval_step(state, x_batch, y_batch))

    train_loss = np.mean(train_losses)
    test_loss = np.mean(test_losses)

    # Log every fifth epoch (which includes the first) and the final one.
    if epoch_id % 5 == 0 or epoch_id == num_epochs - 1:
        print(
            f"epoch={epoch_id + 1}/{num_epochs},"
            f"train_loss={train_loss:.3f},"
            f"test_loss={test_loss:.3f}"
        )

    # EarlyStopping is immutable: `update` returns the new tracker state.
    early_stopping = early_stopping.update(test_loss)
    if early_stopping.should_stop:
        print("stopping early")
        break

gc.collect()

In [None]:
# Forward pass over all test windows with the trained parameters.
trained_variables = {"params": state.params}
predicted: tft.TftOutputs = model.apply(trained_variables, x_test)

# Show the shape of every leaf of the output structure.
jax.tree_util.tree_map(jnp.shape, predicted)

In [None]:
# Column name for each slice along the last axis of the logits (index 0, 1, 2).
quantile_columns = ("yhat_low", "yhat", "yhat_up")

predicted_data = {
    column: predicted.logits[..., position] for position, column in enumerate(quantile_columns)
}
jax.tree_util.tree_map(jnp.shape, predicted_data)

In [None]:
# Fetch the fitted "y" entry by name instead of by position: `transformers_[-2]`
# is only the target entry while a trailing remainder entry is present, and it
# silently picks another transformer otherwise. The result keeps the same
# (name, fitted_transformer, columns) tuple shape, so `target_scaler[1]` is
# still the fitted scaler.
target_scaler = next(entry for entry in preprocessor.transformers_ if entry[0] == "y")
target_scaler

In [None]:
def _restore_target_scale(series):
    """Convert one prediction leaf to an array, undo the target scaling, flatten to 1-D."""
    as_array = tft.utils.time_series_to_array(series)
    return target_scaler[1].inverse_transform(as_array).reshape(-1)


# NOTE: this overwrites `predicted_data` with its transformed version, so the
# cell must be run exactly once after the cell that builds `predicted_data`.
predicted_data = jax.tree_util.tree_map(_restore_target_scale, predicted_data)
jax.tree_util.tree_map(jnp.shape, predicted_data)

In [None]:
# Skip the first `encoder_steps` timestamps of the test period (previously a
# literal 9 duplicating the config constant).
# NOTE(review): assumes the number of remaining timestamps matches the length
# of the flattened predictions -- confirm if the window settings change.
future_ts = test_dataframe["ts"][encoder_steps:]
prediction_df = pl.DataFrame(
    {
        "ts": future_ts,
        **predicted_data,
    }
)
prediction_df

In [None]:
# Left-join predictions onto the test rows by timestamp; test rows without a
# matching prediction keep nulls in the yhat* columns.
report_columns = ["id", "ts", "year", "month", "y", "yhat_low", "yhat", "yhat_up"]
joined_predictions = test_dataframe.join(prediction_df, on="ts", how="left")
test_vs_prediction_df = joined_predictions.select(report_columns)
test_vs_prediction_df

In [None]:
# Plot the observed target against the predicted columns for the test period.
predictions_plot = tft.utils.plot_predictions_vs_real(test_vs_prediction_df)
predictions_plot

In [None]:
# Bundle the historical and future flags returned by the model ...
raw_importance = tft.FeatureImportance(
    historical_flags=predicted.historical_flags,
    future_flags=predicted.future_flags,
)
# ... and convert every leaf with the same helper used for the predictions.
features_importance = jax.tree_util.tree_map(tft.utils.time_series_to_array, raw_importance)

importance_plot = tft.utils.plot_feature_importance(
    test_dataframe["ts"],
    features_importance,
    feature_names=["year", "month"],
)
importance_plot