## Taipei YouBike Forecasting with the Temporal Fusion Transformer (TFT)

This notebook adapts the standard TFT workflow for the specific task of forecasting the `occupancy_ratio` of Taipei's YouBike stations. It uses the feature-engineered dataset we prepared previously, which includes time-based features, holiday information, station clusters, and weather data.

The key steps are:
1.  **Load the pre-processed data.**
2.  **Create the `TimeSeriesDataSet`**, the dataset class from `pytorch-forecasting` that builds the encoder/decoder sequences (the "sliding window" we discussed) and produces the data loaders.
3.  **Configure and train the TFT model** using PyTorch Lightning.
4.  **Evaluate the model** by visualizing its predictions against the actual data.
5.  **Interpret the model** to see which features it found most important.

In [None]:
import matplotlib.pyplot as plt  # used for the prediction plots in step 5
import numpy as np
import pandas as pd
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
from pytorch_forecasting.data import GroupNormalizer

# Set a seed for reproducibility
pl.seed_everything(42)

### 1. Load the Pre-processed Data

We load the `model_ready_dl_features.parquet.gz` file created in the previous step. This file contains all our engineered features, with categorical variables one-hot encoded and numerical variables normalized.

In [None]:
file_path = "/content/drive/MyDrive/Youbike_Master_Project/YouBike_Demand_Forecast/data/model_ready_dl_features.parquet.gz"
data = pd.read_parquet(file_path)

# PyTorch Forecasting expects a floating-point target when a target normalizer is used
data['occupancy_ratio'] = data['occupancy_ratio'].astype(np.float32)

print("Data loaded successfully.")
data.info()

### 2. Create a Time Index

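The `pytorch-forecasting` library requires an integer time index that increases by one per time step. We create it by sorting the data and assigning a cumulative count within each station. This assumes every station has a reading for every 10-minute step, with no gaps. Each station's index also starts at 0, so a given `time_idx` refers to the same clock time across stations only if all stations start recording together.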
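
If some stations have missing readings or start at different times, a cumulative count compresses the gaps without warning. One possible alternative, sketched below on a toy frame, derives the index from the timestamp itself. The helper name `add_time_idx` and the toy data are hypothetical; the 10-minute frequency is an assumption based on the six-steps-per-hour arithmetic used in this notebook.

```python
import pandas as pd

def add_time_idx(df, time_col="time", freq="10min"):
    """Return a copy with time_idx counted in `freq` steps from the earliest timestamp."""
    out = df.copy()
    ts = pd.to_datetime(out[time_col])
    # Floor-divide the elapsed time by the step size to get an integer step count
    out["time_idx"] = ((ts - ts.min()) // pd.Timedelta(freq)).astype("int64")
    return out

# Toy data (hypothetical): station A is missing its 00:20 reading, station B starts late
demo = pd.DataFrame({
    "sno": ["A", "A", "A", "B", "B"],
    "time": pd.to_datetime([
        "2024-01-01 00:00", "2024-01-01 00:10", "2024-01-01 00:30",
        "2024-01-01 00:10", "2024-01-01 00:20",
    ]),
})
demo = add_time_idx(demo)
print(demo)
```

With an index built this way, gaps stay visible as missing `time_idx` values. `TimeSeriesDataSet` would then need `allow_missing_timesteps=True`, or the gaps would have to be filled beforehand.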

In [None]:
data.sort_values(['sno', 'time'], inplace=True)
data['time_idx'] = data.groupby('sno').cumcount()

print("Time index created.")
print(data[['sno', 'time', 'time_idx']].head())

### 3. Define Model Parameters and Create the `TimeSeriesDataSet`

This is the most critical step. We define the sliding window parameters and configure the `TimeSeriesDataSet`, which will handle the memory-efficient creation of training sequences.

-   `max_encoder_length`: How many past time steps the model uses for its prediction (the history). 24 hours = 144 steps (24 * 6).
-   `max_prediction_length`: How many future time steps the model will forecast. 6 hours = 36 steps (6 * 6).
-   **Feature Types:** We must tell the model which features are which:
    -   `target`: The variable we want to predict (`occupancy_ratio`).
    -   `group_ids`: The identifier for each time series (`sno`).
    -   `static_categoricals` / `static_reals`: Features that are constant for each station (like its cluster ID). The code collects these columns in a `static_features` list.
    -   `time_varying_known_reals`: Features we know in advance for the future (e.g., we know the hour of the day and day of the week for tomorrow).
    -   `time_varying_unknown_reals`: Features we don't know in advance (e.g., the weather tomorrow, or the occupancy ratio itself).
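
To make the window arithmetic concrete, here is a minimal pure-Python sketch of how a single station's series splits into encoder/decoder windows. It assumes only full-length windows are used; the library may also be configured to allow shorter ones. The helper names `count_windows` and `window_bounds` are hypothetical and not part of `pytorch-forecasting`.

```python
def count_windows(n_steps, encoder_length, prediction_length):
    """Number of full (encoder + decoder) windows that fit in one series."""
    return max(0, n_steps - (encoder_length + prediction_length) + 1)

def window_bounds(start, encoder_length, prediction_length):
    """Return (encoder index range, decoder index range) for a window starting at `start`."""
    split = start + encoder_length
    return range(start, split), range(split, split + prediction_length)

STEPS_PER_HOUR = 6                       # one reading every 10 minutes
encoder_length = 24 * STEPS_PER_HOUR     # 144 steps of history
prediction_length = 6 * STEPS_PER_HOUR   # 36 steps to forecast

one_week = 7 * 24 * STEPS_PER_HOUR       # 1008 steps for one station
n_windows = count_windows(one_week, encoder_length, prediction_length)
enc, dec = window_bounds(0, encoder_length, prediction_length)
print(n_windows, enc, dec)
```

Each window needs 180 consecutive steps, so one week of data for a single station yields 829 overlapping training samples under these assumptions.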

In [None]:
# Define sliding window parameters
max_prediction_length = 6 * 6 # Predict 6 hours ahead
max_encoder_length = 24 * 6 # Use 24 hours of history

# Define the training/validation split point
# Hold out the last 30 days from training. Because time_idx is a per-station count,
# the cutoff is measured against the longest series.
training_cutoff = data["time_idx"].max() - (30 * 24 * 6)  # 30 days of 10-min intervals

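# Define all feature columns
static_features = [col for col in data.columns if 'cluster_is_' in col]
# static_categoricals must be string or categorical dtype; one-hot 0/1 columns are numeric, so cast them
data[static_features] = data[static_features].astype(str)
time_varying_known_features = [col for col in data.columns if 'day_is_' in col] + ['hour', 'day_of_week']
time_varying_unknown_features = ['occupancy_ratio', 'Temperature', 'Dew Point', 'Humidity', 'Speed', 'Pressure']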

# Create the TimeSeriesDataSet for training
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="occupancy_ratio",
    group_ids=["sno"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=static_features,
    time_varying_known_reals=time_varying_known_features,
    time_varying_unknown_reals=time_varying_unknown_features,
    target_normalizer=GroupNormalizer(
        groups=["sno"], transformation="softplus"
    ),  # Use softplus to ensure predictions are non-negative
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

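# Create the TimeSeriesDataSet for validation.
# With predict=True, only the last max_prediction_length steps of each station are predicted,
# so validation covers the final 6 hours per station rather than the whole held-out month.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)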

# Create dataloaders for model training
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=2)

print("TimeSeriesDataSets and DataLoaders created successfully.")

### 4. Configure and Train the TFT Model

Now we set up the model using PyTorch Lightning.
- We use `EarlyStopping` to monitor the validation loss and stop training automatically once it stops improving, which helps limit overfitting.
- A `TensorBoardLogger` is used to log the training progress, which can be visualized later.
- The `TemporalFusionTransformer.from_dataset()` method conveniently creates a model with parameters tailored to our specific dataset.
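
The interaction between `patience` and `min_delta` is easy to misread, so here is a minimal pure-Python sketch of the stopping rule as I understand it for `mode="min"`. The function `early_stop_epoch` is a hypothetical illustration, not Lightning's implementation, and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience=10, min_delta=1e-4):
    """Return the 0-based epoch at which training would stop, or None if it never triggers."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:   # only an improvement larger than min_delta counts
            best = loss
            wait = 0
        else:
            wait += 1                 # another epoch without meaningful improvement
            if wait >= patience:
                return epoch
    return None

losses = [0.50, 0.40, 0.35, 0.36, 0.35, 0.37]  # made-up validation losses
stop = early_stop_epoch(losses, patience=3, min_delta=1e-4)
print(stop)
```

In this toy run the loss bottoms out at epoch 2, and the three following epochs fail to beat it by more than `min_delta`, so training would stop at epoch 5.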

In [None]:
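# Configure the trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
# Keep the checkpoint with the lowest validation loss so it can be reloaded after training
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
logger = TensorBoardLogger("lightning_logs")

trainer = pl.Trainer(
    max_epochs=50,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=logger,
)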

# Define the TFT model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10, 
    reduce_on_plateau_patience=4,
)

print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# Start training
# This step can take a very long time depending on your hardware.
print("Starting model training...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

### 5. Evaluate and Visualize Predictions

After training, we load the best model (based on the lowest validation loss) and use it to make predictions on the validation set. Visualizing these predictions is the best way to see how well the model has learned the patterns.
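
The shaded bands in these plots are the predicted quantiles, which were trained with `QuantileLoss`. As a reminder of what that objective rewards, here is a minimal sketch of the standard quantile (pinball) loss for a single observation. The library's implementation may differ by a constant factor and works on tensors; the quantile list below is what I believe the default to be, so treat it as an assumption.

```python
def pinball_loss(y_true, y_pred, q):
    """Quantile (pinball) loss for one observation and a quantile level q in (0, 1)."""
    error = y_true - y_pred
    # Under-prediction is weighted by q, over-prediction by (1 - q)
    return max(q * error, (q - 1) * error)

# Assumed default quantiles (seven values, matching output_size=7)
QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]

# For the 0.9 quantile, predicting too low costs more than predicting too high
under = pinball_loss(y_true=0.6, y_pred=0.5, q=0.9)
over = pinball_loss(y_true=0.6, y_pred=0.7, q=0.9)
print(under, over)
```

This asymmetry pushes each output toward its own quantile of the occupancy distribution, which is why the bands widen when the model is less certain.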

In [None]:
# Load the best model from the checkpoint
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

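# Make predictions on the validation set.
# In pytorch-forecasting >= 1.0, predict() returns a Prediction object; pull out the fields we need.
prediction = best_tft.predict(val_dataloader, mode="raw", return_x=True)
raw_predictions, x = prediction.output, prediction.x

# Visualize some examples
print("Plotting predictions for the first 5 examples from the validation set...")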
for i in range(5):
    fig, ax = plt.subplots(figsize=(10, 5))
    best_tft.plot_prediction(x, raw_predictions, idx=i, ax=ax, add_loss_to_title=True)
    plt.show()

### 6. Interpret Model Behavior

A key advantage of the TFT is its interpretability. We can look at the built-in feature importance plots to understand what the model learned.

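-   **Attention:** Which past time steps the model attended to most.
-   **Static Variable Importance:** Which station-level features were most important.
-   **Encoder Importance:** Which past features were most important for making a prediction.
-   **Decoder Importance:** Which future-known features were most important.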
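
With `reduction="sum"`, the importance scores are accumulated over all validation samples, so their absolute size depends on how many samples there are. They are easier to compare as percentage shares. The sketch below shows that normalization on toy numbers; `to_percent` is a hypothetical helper, and the feature names and values are placeholders rather than results from this model.

```python
def to_percent(importance):
    """Convert summed importance scores to percentage shares, largest first."""
    total = sum(importance.values())
    shares = {name: 100.0 * value / total for name, value in importance.items()}
    return sorted(shares.items(), key=lambda item: item[1], reverse=True)

# Placeholder scores (not model output)
toy_importance = {"feature_b": 3.0, "feature_a": 6.0, "feature_c": 1.0}
ranked = to_percent(toy_importance)
for name, share in ranked:
    print(f"{name}: {share:.1f}%")
```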

In [None]:
# Calculate and plot interpretation
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
figs = best_tft.plot_interpretation(interpretation)

# Display the feature importance plots
for key, value in figs.items():
    print(f"--- Importance for: {key} ---")
    value.show()