Caveat: bokeh plots will not be rendered in GitHub preview

In [30]:
import math
from datetime import datetime

import hvplot
import jax
import numpy as np
import optax
import polars as pl
import holoviews as hv
from polars import selectors
from bokeh.models import DatetimeTickFormatter
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer, StandardScaler
import jax.numpy as jnp
from flax.training.early_stopping import EarlyStopping
import temporal_fusion_transformer as tft
from toolz import functoolz
import gc
from typing import NamedTuple
from copy import deepcopy

jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
xformatter = DatetimeTickFormatter(months="%b %Y")
hvplot.extension("bokeh")

In [15]:
TOTAL_TIME_STEPS = 9
NUM_ENCODER_STEPS = 6
FORECAST_MONTHS = 6

## 1) Load data

In [16]:
df = (
    pl.read_csv("../data/air_passengers/AirPassengers.csv", try_parse_dates=True)
    .with_columns(id=0)
    .with_columns(pl.col("Month").str.to_date("%Y-%m"))
    .with_columns(pl.col("Month").dt.month_end())
    .sort("Month")
    .upsample("Month", every="1mo")
    .rename({"#Passengers": "y", "Month": "ts"})
    .select("id", "ts", "y")
    .with_columns(pl.col("ts").dt.month().alias("month"), pl.col("ts").dt.year().alias("year"))
    .with_columns(
        pl.col("month").alias("month_sin").sin(), pl.col("month").alias("month_cos").cos()
    )
    .drop("month")
)
df.head()

id,ts,y,year,month_sin,month_cos
i32,date,i64,i32,f64,f64
0,1949-01-31,112,1949,0.841471,0.540302
0,1949-02-28,118,1949,0.909297,-0.416147
0,1949-03-31,132,1949,0.14112,-0.989992
0,1949-04-30,129,1949,-0.756802,-0.653644
0,1949-05-31,121,1949,-0.958924,0.283662


In [17]:
df["ts"].min(), df["ts"].max()

(datetime.date(1949, 1, 31), datetime.date(1960, 12, 31))

In [18]:
def plot_line(dataframe: pl.DataFrame, **kwargs) -> hv.Layout:
    kw = {
        "y": "y",
        "x": "ts",
        "xformatter": xformatter,
        "legend": True,
        "grid": True,
        "height": 200,
        **dict(**kwargs),
    }
    return dataframe.plot.line(**kw) * dataframe.plot.scatter(**kw)


def split_df(dataframe: pl.DataFrame, test_months: int) -> tuple[pl.DataFrame, pl.DataFrame]:
    test_steps = test_months - 1
    dataframe = dataframe.with_columns(
        pl.col("ts").max().dt.offset_by(f"-{test_steps+NUM_ENCODER_STEPS}mo").alias("test_boundary")
    ).with_columns(pl.col("ts").max().dt.offset_by(f"-{test_steps}mo").alias("train_boundary"))
    train_dataframe = dataframe.filter(
        pl.col("ts").dt.date() <= pl.col("train_boundary").dt.date()
    ).drop("test_boundary")
    test_dataframe = dataframe.filter(
        pl.col("ts").dt.date() >= pl.col("test_boundary").dt.date()
    ).drop("test_boundary")
    print(f"{len(train_dataframe) = }, {len(test_dataframe) = }")
    return train_dataframe, test_dataframe


def plot_split(dataframe: pl.DataFrame, test_months: int):
    train_dataframe, test_dataframe = split_df(dataframe, test_months)
    train_dataframe = train_dataframe
    test_dataframe = test_dataframe

    return (plot_line(train_dataframe) + plot_line(test_dataframe, color="red")).cols(1)


plot_split(df, FORECAST_MONTHS)

len(train_dataframe) = 139, len(test_dataframe) = 12


Note, that we need to overlap `NUM_ENCODER_STEPS` month

In [19]:
raw_train_df, raw_test_df = split_df(df, FORECAST_MONTHS)

len(train_dataframe) = 139, len(test_dataframe) = 12


## 2) Prepare inputs

In [20]:
class Preprocessor:
    def __init__(self):
        self.target = StandardScaler()
        self.real = StandardScaler()

    def fit(self, dataframe: pl.DataFrame):
        self.real.fit(dataframe.select("year").to_numpy(order="c"))
        self.target.fit(dataframe.select("y").to_numpy(order="c"))

    def transform(self, dataframe: pl.DataFrame) -> pl.DataFrame:
        real_arr = self.real.transform(dataframe.select("year").to_numpy(order="c"))
        target_arr = self.target.transform(dataframe.select("y").to_numpy(order="c"))

        dataframe = dataframe.drop(["year", "y"]).with_columns(
            year=real_arr[..., 0], y=target_arr[..., 0]
        )
        return dataframe

    def inverse_transform(self, dataframe: pl.DataFrame) -> pl.DataFrame:

        real_arr = self.real.inverse_transform(dataframe.select("year").to_numpy(order="c"))
        target_arr = self.target.inverse_transform(dataframe.select("y").to_numpy(order="c"))
        dataframe = (
            dataframe.drop(["year", "y"])
            .with_columns(YEAR=real_arr[..., 0], HICP=real_arr[..., 1], y=target_arr[..., 0])
            .with_columns(pl.col("year").cast(pl.Int32), pl.col("y").cast(pl.Int32))
        )
        for i in ["yhat", "yhat_low", "yhat_up"]:

            if i in dataframe.columns:
                y_hat = self.target.inverse_transform(dataframe.select(i).to_numpy(order="c"))

                dataframe = (
                    dataframe.drop(i)
                    .with_columns(y_q=y_hat[..., 0])
                    .rename({"y_q": i})
                    .with_columns(pl.col(i).cast(pl.Int32))
                )
        return dataframe


preprocessor = Preprocessor()
preprocessor.fit(raw_train_df)

train_scaled_df = preprocessor.transform(raw_train_df)
test_scaled_df = preprocessor.transform(raw_test_df)

In [21]:
def df_to_arr(dataframe: pl.DataFrame) -> np.ndarray:
    return dataframe.select(
        # we still need some id
        "id",
        "year",
        "month_sin",
        "month_cos",
        "y",
    ).to_numpy(order="c")


train_arr = df_to_arr(train_scaled_df)
test_arr = df_to_arr(test_scaled_df)
train_arr.shape, test_arr.shape

((139, 5), (12, 5))

In [22]:
xy_train = tft.utils.timeseries_from_array(train_arr, TOTAL_TIME_STEPS)
xy_test = tft.utils.timeseries_from_array(test_arr, TOTAL_TIME_STEPS)
xy_train.shape, xy_test.shape

((131, 9, 5), (4, 9, 5))

In [23]:
x_train, y_train = tft.utils.unpack_xy(xy_train, encoder_steps=NUM_ENCODER_STEPS)
x_test, y_test = tft.utils.unpack_xy(xy_test, encoder_steps=NUM_ENCODER_STEPS)
x_train.shape, y_train.shape, x_test.shape, y_test.shape

((131, 9, 4), (131, 3, 1), (4, 9, 4), (4, 3, 1))

## 3) Train TFT model

tbh, you are probably better of hard-coding your own embedding layer, 
since for loops take forever to compile

In [24]:
init_key, dropout_key, shuffle_key = jax.random.split(jax.random.PRNGKey(69), 3)

batch_size = 8
num_epochs = 50
latent_dim = 32


class BestWeights(NamedTuple):
    epoch: int
    weights: dict


embeds = tft.InputEmbedding(
    # id
    input_static_idx=[0],
    # year, month_sin, month_cos
    input_known_real_idx=[1, 2, 3],
    input_known_categorical_idx=[],
    input_observed_idx=[],
    # 1 will return nan
    static_categories_sizes=[2],
    known_categories_sizes=[],
    latent_dim=latent_dim,
)

model = tft.TemporalFusionTransformer(
    total_time_steps=TOTAL_TIME_STEPS,
    num_encoder_steps=NUM_ENCODER_STEPS,
    num_decoder_blocks=1,
    num_attention_heads=4,
    latent_dim=latent_dim,
    num_static_inputs=1,
    num_non_static_inputs=3,
    num_known_inputs=3,
    embedding_layer=embeds,
)

params = model.init(init_key, x_train[:8])
tx = optax.chain(
    optax.adaptive_grad_clip(0.1),
    optax.adam(1e-3),
)
# tx = optax.contrib.mechanize(optax.adam(1e-3))

state = tft.train_lib.TrainState.create(
    apply_fn=model.apply,
    tx=tx,
    params=params["params"],
    prng_key=dropout_key,
)
del params
early_stopping = EarlyStopping(min_delta=0.01, patience=5)
best_weight = BestWeights(0, deepcopy(state.params))

num_train_batches = math.ceil(len(x_train) / batch_size)

for epoch_id in range(num_epochs):
    shuffle_key = jax.random.fold_in(shuffle_key, epoch_id)
    train_loss = []
    test_loss = []

    for step_id, x_batch, y_batch in tft.train_lib.enumerate_batches(x_train, y_train, batch_size):
        state, train_loss_i = tft.train_lib.train_step(state, x_batch, y_batch)
        train_loss.append(train_loss_i)

    for _, x_batch, y_batch in tft.train_lib.enumerate_batches(x_test, y_test, batch_size):
        test_loss.append(tft.train_lib.eval_step(state, x_batch, y_batch))

    train_loss = np.mean(train_loss)
    test_loss = np.mean(test_loss)
    print(
        f"epoch={epoch_id + 1}/{num_epochs},"
        f"train_loss={train_loss:.3f},"
        f"test_loss={test_loss:.3f}"
    )

    early_stopping = early_stopping.update(test_loss)

    if early_stopping.has_improved:
        best_weight = BestWeights(epoch_id, deepcopy(state.params))

    if early_stopping.should_stop:
        print(f"stopping early, restoring best weights from epoch: {best_weight.epoch+1}")
        state = state.replace(params=best_weight.weights)
        break

gc.collect()

epoch=1/50,train_loss=2.482,test_loss=3.847
epoch=2/50,train_loss=3.089,test_loss=4.682
epoch=3/50,train_loss=2.801,test_loss=4.439
epoch=4/50,train_loss=2.636,test_loss=5.538
epoch=5/50,train_loss=2.495,test_loss=5.767
epoch=6/50,train_loss=2.483,test_loss=5.770
epoch=7/50,train_loss=2.433,test_loss=5.534
stopping early, restoring best weights from epoch: 1


0

## 3) Run inference

In [54]:
y_pred: tft.TftOutputs = model.apply({"params": state.params}, x_test)
jax.tree_util.tree_map(jnp.shape, y_pred)

TftOutputs(logits=(4, 3, 1, 3), static_flags=(4, 1), historical_flags=(4, 6, 3), future_flags=(4, 3, 3))

In [55]:
last_6_months = test_scaled_df["ts"].top_k(6)[::-1]
last_6_months

ts
date
1960-07-31
1960-08-31
1960-09-30
1960-10-31
1960-11-30
1960-12-31


In [57]:
def inverse_transform_logits(logits: np.ndarray | jnp.ndarray) -> np.ndarray:
    return preprocessor.target.inverse_transform(tft.utils.time_series_to_array(logits)).reshape(-1)


yhat_low = inverse_transform_logits(y_pred.logits[..., 0])
yhat = inverse_transform_logits(y_pred.logits[..., 1])
yhat_up = inverse_transform_logits(y_pred.logits[..., 2])

yhat_low = np.minimum(yhat, yhat_low)
yhat_up = np.maximum(yhat, yhat_up)

last_month_forecat_df = (
    pl.DataFrame(
        {
            "ts": last_6_months,
            "yhat_low": yhat_low,
            "yhat": yhat,
            "yhat_up": yhat_up,
        }
    )
    .with_columns(selectors.float().round().cast(pl.Int32))
    .join(raw_test_df.select("ts", "y"), on="ts")
    
)
last_month_forecat_df

ts,yhat_low,yhat,yhat_up,y
date,i32,i32,i32,i64
1960-07-31,243,390,486,622
1960-08-31,123,371,452,606
1960-09-30,102,385,446,508
1960-10-31,98,364,457,461
1960-11-30,187,424,547,390
1960-12-31,235,399,482,432


In [64]:
tft.utils.plot_predictions_vs_real(
    df.join(last_month_forecat_df, on="ts", how="left").filter(pl.col("ts").dt.year() >= 1960)
)

### 3.2) Plot feature importance

In [65]:
historical_a_batch = tft.utils.time_series_to_array(y_pred.historical_flags)
historical_a_df = pl.DataFrame(
    {
        "ts": raw_test_df["ts"][:9],
        "year_a": historical_a_batch[..., 0],
        "month_sin_a": historical_a_batch[..., 1],
        "month_cos_a": historical_a_batch[..., 2],
    }
).with_columns(selectors.float() / historical_a_batch.max())
historical_a_df

ts,year_a,month_sin_a,month_cos_a
date,f32,f32,f32
1960-01-31,0.103918,0.137624,0.995212
1960-02-29,0.145449,0.09914,0.992165
1960-03-31,0.096108,0.154483,0.986163
1960-04-30,0.070562,0.496946,0.669246
1960-05-31,0.073506,0.333257,0.829991
1960-06-30,0.07884,0.243731,0.914183
1960-07-31,0.093254,0.1626,0.9809
1960-08-31,0.142106,0.100815,0.993833
1960-09-30,0.115275,0.121479,1.0


In [66]:
tft.utils.plot_feature_importance(historical_a_df, "Historical Feature Importance")

In [67]:
future_a_batch = tft.utils.time_series_to_array(y_pred.future_flags)
future_a_df = pl.DataFrame(
    {
        "ts": raw_test_df["ts"].top_k(6)[::-1],
        "year_a": future_a_batch[..., 0],
        "month_sin_a": future_a_batch[..., 1],
        "month_cos_a": future_a_batch[..., 2],
    }
).with_columns(selectors.float() / future_a_batch.max())
future_a_df

ts,year_a,month_sin_a,month_cos_a
date,f32,f32,f32
1960-07-31,0.091718,0.977007,0.175608
1960-08-31,0.137456,1.0,0.106877
1960-09-30,0.181662,0.973098,0.089574
1960-10-31,0.153823,0.99197,0.09854
1960-11-30,0.085933,0.956186,0.202213
1960-12-31,0.074078,0.813456,0.356799


In [68]:
tft.utils.plot_feature_importance(future_a_df, "Future Feature Importance")

### 3.3) Generate report

In [69]:
def mape(name: str):
    alias = name.replace("yhat", "mape")
    return ((pl.col(name) - pl.col("y")) / pl.col("y")).round(2).abs().alias(alias)


mape_df = last_month_forecat_df.with_columns(
    mape("yhat_low"),
    mape("yhat"),
    mape("yhat_up"),
).select(["ts", "y", "yhat_low", "yhat", "yhat_up", "mape_low", "mape", "mape_up"])
mape_df

ts,y,yhat_low,yhat,yhat_up,mape_low,mape,mape_up
date,i64,i32,i32,i32,f64,f64,f64
1960-07-31,622,243,390,486,0.61,0.37,0.22
1960-08-31,606,123,371,452,0.8,0.39,0.25
1960-09-30,508,102,385,446,0.8,0.24,0.12
1960-10-31,461,98,364,457,0.79,0.21,0.01
1960-11-30,390,187,424,547,0.52,0.09,0.4
1960-12-31,432,235,399,482,0.46,0.08,0.12


This model would need hyperparameter fine-tuning IRL, 
but this example was meant just as demonstration.