In [13]:
import math
from datetime import datetime

import hvplot
import jax
import numpy as np
import optax
import polars as pl
from bokeh.models import DatetimeTickFormatter
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer, StandardScaler
import jax.numpy as jnp
from flax.training.early_stopping import EarlyStopping
import temporal_fusion_transformer as tft
from toolz import functoolz
import gc


xformatter = DatetimeTickFormatter(months="%b %Y")
hvplot.extension("bokeh")

In [2]:
total_time_steps = 12
encoder_steps = 9
num_inputs = 3
batch_size = 8
num_epochs = 100

In [3]:
df = (
    pl.read_csv("../data/air_passengers/AirPassengers.csv", try_parse_dates=True)
    .with_columns(id=0)
    .with_columns(pl.col("Month").str.to_date("%Y-%m"))
    .with_columns(pl.col("Month").dt.month_end())
    .sort("Month")
    .upsample("Month", every="1mo")
    .rename({"#Passengers": "y", "Month": "ts"})
    .select("id", "ts", "y")
    .with_columns(pl.col("ts").dt.month().alias("month"), pl.col("ts").dt.year().alias("year"))
)
df.head(10)

id,ts,y,month,year
i32,date,i64,i8,i32
0,1949-01-31,112,1,1949
0,1949-02-28,118,2,1949
0,1949-03-31,132,3,1949
0,1949-04-30,129,4,1949
0,1949-05-31,121,5,1949
0,1949-06-30,135,6,1949
0,1949-07-31,148,7,1949
0,1949-08-31,148,8,1949
0,1949-09-30,136,9,1949
0,1949-10-31,119,10,1949


In [4]:
test_boundary = datetime(1960, 1, 1)

tft.utils.plot_split(df, test_boundary)

In [5]:
preprocessor = ColumnTransformer(
    [
        ("id", "passthrough", ["id"]),
        ("year", StandardScaler(), ["year"]),
        (
            "month",
            FunctionTransformer(
                lambda x: x - 1,
            ),
            ["month"],
        ),
        ("y", StandardScaler(), ["y"]),
    ],
    verbose=True,
)
preprocessor.fit(df)

[ColumnTransformer] ............ (1 of 4) Processing id, total=   0.0s
[ColumnTransformer] .......... (2 of 4) Processing year, total=   0.0s
[ColumnTransformer] ......... (3 of 4) Processing month, total=   0.0s
[ColumnTransformer] ............. (4 of 4) Processing y, total=   0.0s


In [6]:
train_dataframe, test_dataframe = split_dataframe(df, test_boundary)
len(train_dataframe), len(test_dataframe)

(132, 12)

In [7]:
train_arr = preprocessor.transform(train_dataframe)
test_arr = preprocessor.transform(test_dataframe)
train_arr.shape, test_arr.shape

((132, 4), (12, 4))

In [8]:
xy_train = tft.utils.timeseries_from_array(train_arr, total_time_steps)
xy_test = tft.utils.timeseries_from_array(test_arr, total_time_steps)
xy_train.shape, xy_test.shape

((121, 12, 4), (1, 12, 4))

In [9]:
x_train, y_train = tft.utils.unpack_xy(xy_train, encoder_steps=encoder_steps)
x_test, y_test = tft.utils.unpack_xy(xy_test, encoder_steps=encoder_steps)
x_train.shape, y_train.shape, x_test.shape, y_test.shape

((121, 12, 3), (121, 3, 1), (1, 12, 3), (1, 3, 1))

In [14]:
init_key, dropout_key, shuffle_key = jax.random.split(jax.random.PRNGKey(69), 3)

model = tft.TemporalFusionTransformer(
    total_time_steps=total_time_steps,
    num_decoder_blocks=1,
    num_attention_heads=4,
    # id
    input_static_idx=[0],
    # year
    input_known_real_idx=[1],
    # month,
    input_known_categorical_idx=[2],
    input_observed_idx=[],
    num_encoder_steps=9,
    static_categories_sizes=[1],
    known_categories_sizes=[12],
    latent_dim=16,
)

params = model.init(init_key, x_train[:8])
tx = optax.chain(
    optax.adaptive_grad_clip(0.1),
    optax.adam(5e-4),
)
# tx = optax.contrib.mechanize(optax.adam(1e-3))

state = tft.train_lib.TrainState.create(
    apply_fn=model.apply,
    tx=tx,
    params=params["params"],
    prng_key=dropout_key,
)
early_stopping = EarlyStopping(min_delta=0.1)

num_train_batches = math.ceil(len(x_train) / batch_size)

for epoch_id in range(num_epochs):
    shuffle_key = jax.random.fold_in(shuffle_key, epoch_id)
    train_loss = []
    test_loss = []

    for step_id, x_batch, y_batch in tft.train_lib.enumerate_batches(
        x_train, y_train, batch_size, prng_key=shuffle_key
    ):
        state, train_loss_i = tft.train_lib.train_step(state, x_batch, y_batch)
        train_loss.append(train_loss_i)

    for _, x_batch, y_batch in tft.train_lib.enumerate_batches(
        x_test, y_test, batch_size, prng_key=shuffle_key
    ):
        test_loss.append(tft.train_lib.eval_step(state, x_batch, y_batch))

    train_loss = np.mean(train_loss)
    test_loss = np.mean(test_loss)
    if epoch_id == 0 or epoch_id == num_epochs - 1 or epoch_id % 5 == 0:
        print(
            f"epoch={epoch_id + 1}/{num_epochs},"
            f"train_loss={train_loss:.3f},"
            f"test_loss={test_loss:.3f}"
        )

    early_stopping = early_stopping.update(test_loss)
    if early_stopping.should_stop:
        print("stopping early")
        break

gc.collect()

epoch=1/100,train_loss=0.475,test_loss=0.513
epoch=6/100,train_loss=0.128,test_loss=0.145
epoch=11/100,train_loss=0.105,test_loss=0.129
epoch=16/100,train_loss=0.091,test_loss=0.102
epoch=21/100,train_loss=0.079,test_loss=0.097
epoch=26/100,train_loss=0.070,test_loss=0.072
epoch=31/100,train_loss=0.063,test_loss=0.066
epoch=36/100,train_loss=0.048,test_loss=0.063
epoch=41/100,train_loss=0.043,test_loss=0.113
epoch=46/100,train_loss=0.042,test_loss=0.103
epoch=51/100,train_loss=0.036,test_loss=0.136
epoch=56/100,train_loss=0.035,test_loss=0.128
epoch=61/100,train_loss=0.037,test_loss=0.100
epoch=66/100,train_loss=0.036,test_loss=0.086
epoch=71/100,train_loss=0.030,test_loss=0.094
epoch=76/100,train_loss=0.028,test_loss=0.120
epoch=81/100,train_loss=0.030,test_loss=0.123
epoch=86/100,train_loss=0.028,test_loss=0.088
epoch=91/100,train_loss=0.030,test_loss=0.096
epoch=96/100,train_loss=0.025,test_loss=0.079
epoch=100/100,train_loss=0.028,test_loss=0.095


0

In [15]:
predicted: tft.TftOutputs = model.apply({"params": state.params}, x_test)
jax.tree_util.tree_map(jnp.shape, predicted)

TftOutputs(logits=(1, 3, 1, 3), static_flags=(1, 1), historical_flags=(1, 9, 2), future_flags=(1, 3, 2))

In [16]:
predicted_data = {
    "yhat_low": predicted.logits[..., 0],
    "yhat": predicted.logits[..., 1],
    "yhat_up": predicted.logits[..., 2],
}
jax.tree_util.tree_map(jnp.shape, predicted_data)

{'yhat': (1, 3, 1), 'yhat_low': (1, 3, 1), 'yhat_up': (1, 3, 1)}

In [17]:
target_scaler = preprocessor.transformers_[-2]
target_scaler

('y', StandardScaler(), ['y'])

In [18]:
predicted_data = jax.tree_util.tree_map(
    functoolz.compose(
        lambda y_pr: target_scaler[1].inverse_transform(y_pr).reshape(-1),
        tft.utils.time_series_to_array,
    ),
    predicted_data,
)
jax.tree_util.tree_map(jnp.shape, predicted_data)

{'yhat': (3,), 'yhat_low': (3,), 'yhat_up': (3,)}

In [19]:
future_ts = test_dataframe["ts"][9:]
prediction_df = pl.DataFrame(
    {
        "ts": future_ts,
        **predicted_data,
    }
)
prediction_df

ts,yhat,yhat_low,yhat_up
date,f32,f32,f32
1960-10-31,405.769684,383.156769,437.868988
1960-11-30,358.299286,344.405792,396.047089
1960-12-31,395.334137,377.393402,440.140869


In [20]:
test_vs_prediction_df = test_dataframe.join(prediction_df, on=["ts"], how="left").select(
    "id", "ts", "year", "month", "y", "yhat_low", "yhat", "yhat_up"
)
test_vs_prediction_df

id,ts,year,month,y,yhat_low,yhat,yhat_up
i32,date,i32,i8,i64,f32,f32,f32
0,1960-01-31,1960,1,417,,,
0,1960-02-29,1960,2,391,,,
0,1960-03-31,1960,3,419,,,
0,1960-04-30,1960,4,461,,,
0,1960-05-31,1960,5,472,,,
0,1960-06-30,1960,6,535,,,
0,1960-07-31,1960,7,622,,,
0,1960-08-31,1960,8,606,,,
0,1960-09-30,1960,9,508,,,
0,1960-10-31,1960,10,461,383.156769,405.769684,437.868988


In [21]:
tft.utils.plot_predictions_vs_real(test_vs_prediction_df)

In [22]:
features_importance = jax.tree_util.tree_map(
    tft.utils.time_series_to_array,
    tft.FeatureImportance(
        historical_flags=predicted.historical_flags,
        future_flags=predicted.future_flags,
    ),
)

tft.utils.plot_feature_importance(
    test_dataframe["ts"], features_importance, feature_names=["year", "month"]
)