## Gongguan YouBike Forecasting with the Temporal Fusion Transformer (TFT)

This notebook trains a TFT model on the focused **Gongguan case study** dataset. It uses the final feature-engineered data, which includes:
- Cyclical time features (sin/cos)
- Holiday and storm event flags
- Weather data (normalized)
- Static POI features (normalized)
- Behavioral station clusters (one-hot encoded)
- Historical lag and inflow/outflow features (normalized)

The workflow loads this data, configures a `TimeSeriesDataSet`, trains the model, and evaluates its performance.
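
As a rough illustration of why the cyclical time features listed above use sin/cos pairs, the sketch below maps an hour of day onto the unit circle. The helper name `cyclical_encode` is hypothetical and not part of the feature pipeline; the actual pipeline may compute these columns differently.

```python
import math
from typing import Tuple


def cyclical_encode(value: float, period: float) -> Tuple[float, float]:
    """Map a periodic value (e.g. hour of day) to a (sin, cos) point on the unit circle."""
    angle = 2.0 * math.pi * value / period
    return math.sin(angle), math.cos(angle)


# As raw integers, hour 23 and hour 0 are 23 apart; on the circle they are neighbours.
h0 = cyclical_encode(0, 24)
h12 = cyclical_encode(12, 24)
h23 = cyclical_encode(23, 24)

print(math.dist(h23, h0))  # small: 23:00 is close to 00:00
print(math.dist(h12, h0))  # largest possible: noon is opposite midnight
```

The same idea applies to day of week (period 7) and month (period 12).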

In [None]:
import warnings

import lightning.pytorch as pl
import matplotlib.pyplot as plt  # needed for the prediction plots in section 5
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
from pytorch_forecasting.data import GroupNormalizer

warnings.filterwarnings("ignore")
pl.seed_everything(42)

### 1. Load the Pre-processed Gongguan Dataset

We load the `gongguan_model_ready_features.parquet.gz` file. This is the final output of our master feature engineering pipeline.

In [None]:
file_path = "/content/drive/MyDrive/Youbike_Master_Project/YouBike_Demand_Forecast/data/gongguan_model_ready_features.parquet.gz"
data = pd.read_parquet(file_path)

# Ensure data types are correct for the library
data['occupancy_ratio'] = data['occupancy_ratio'].astype(np.float32)
# Convert boolean one-hot columns to integers
for col in data.columns:
    if data[col].dtype == 'bool':
        data[col] = data[col].astype(int)

print("Gongguan dataset loaded successfully.")
data.info()

### 2. Create a Time Index

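The `pytorch-forecasting` library requires an integer time index that increases by one per time step within each series. We create it by sorting on station and timestamp and numbering each station's rows consecutively. This assumes every station has one row per interval with no gaps, and that all stations start at the same timestamp; otherwise the same `time_idx` would refer to different clock times at different stations.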
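
If the data might contain missing intervals, one alternative is to derive the index from the timestamps themselves so that gaps stay visible. The sketch below is a minimal pure-Python version under two assumptions that are not stated in this notebook: a 10-minute sampling interval (inferred from the `24 * 6` window arithmetic used later) and timestamps stored as `datetime` values. The helper name `time_index` is hypothetical.

```python
from datetime import datetime, timedelta

STEP = timedelta(minutes=10)  # assumed sampling interval (6 steps per hour)


def time_index(timestamps, origin=None, step=STEP):
    """Return the number of whole steps between each timestamp and the origin."""
    origin = min(timestamps) if origin is None else origin
    return [(ts - origin) // step for ts in timestamps]


# Two readings are missing between 00:10 and 00:40.
stamps = [
    datetime(2024, 1, 1, 0, 0),
    datetime(2024, 1, 1, 0, 10),
    datetime(2024, 1, 1, 0, 40),
]
idx = time_index(stamps)
print(idx)  # [0, 1, 4]; a per-row running count would give [0, 1, 2]

# Rough pandas equivalent, assuming `data["time"]` is a datetime64 column:
#   ((data["time"] - data["time"].min()) // pd.Timedelta(minutes=10)).astype(int)
```

With gaps preserved in the index, `TimeSeriesDataSet` would likely also need `allow_missing_timesteps=True`.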

In [None]:
data.sort_values(['sno', 'time'], inplace=True)
# The time index must be an integer
data['time_idx'] = data.groupby('sno').cumcount()

print("Time index created.")

### 3. Define Features and Create the `TimeSeriesDataSet`

This is the most critical step. We programmatically identify all our engineered features and assign them to the correct category for the TFT model.

-   `max_encoder_length`: How much history the encoder sees (here 24 hours, i.e. 144 ten-minute steps).
-   `max_prediction_length`: How far ahead the model predicts (here 6 hours, i.e. 36 ten-minute steps).
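
To make the window arithmetic concrete, here is a small sketch that converts hours to steps and counts how many full encoder-plus-decoder windows fit in a series. It assumes 10-minute data (6 steps per hour, matching the `24 * 6` and `6 * 6` used in the next cell), a stride of one step, and that only full-length windows are used. The function names are hypothetical.

```python
STEPS_PER_HOUR = 6  # assumed: one record every 10 minutes


def window_lengths(history_hours, horizon_hours, steps_per_hour=STEPS_PER_HOUR):
    """Convert history and horizon from hours to time steps."""
    return history_hours * steps_per_hour, horizon_hours * steps_per_hour


def count_full_windows(n_steps, encoder_length, prediction_length):
    """Number of stride-1 windows of encoder + decoder length in n_steps rows."""
    return max(0, n_steps - encoder_length - prediction_length + 1)


encoder_length, prediction_length = window_lengths(24, 6)
one_week = 7 * 24 * STEPS_PER_HOUR

print(encoder_length, prediction_length)  # 144 36
print(count_full_windows(one_week, encoder_length, prediction_length))  # 829
```

A station with fewer than 180 rows (144 + 36) would contribute no full window under these assumptions.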

In [None]:
# Define sliding window parameters
max_prediction_length = 6 * 6 # Predict 6 hours ahead (36 steps)
max_encoder_length = 24 * 6 # Use 24 hours of history (144 steps)

# Define the training/validation split point: hold out the last 30 days
# of the time index (30 days * 24 hours * 6 steps per hour = 4320 steps)
training_cutoff = data["time_idx"].max() - (30 * 24 * 6)

# --- Automatically identify feature columns ---
target = 'occupancy_ratio'
group_ids = ['sno']

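# Static features do not change over time for a given station
static_reals = [col for col in data.columns if col.startswith('poi_') or col in ['lat', 'lng']]
static_categoricals = [col for col in data.columns if col.startswith('cluster_is_')]
# TimeSeriesDataSet rejects numeric categoricals, so store the one-hot cluster flags as strings
data[static_categoricals] = data[static_categoricals].astype(str)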

# Time-varying features that are known in the future
time_varying_known_reals = [
    'hour_sin', 'hour_cos', 'day_of_week_sin', 'day_of_week_cos', 'month_sin', 'month_cos',
    'is_major_storm_day'
] + [col for col in data.columns if col.startswith('day_is_')]

# Time-varying features that are not known in the future.
# The target itself is left out of this list, so the encoder sees past occupancy
# only through the engineered lag features; append `target` here to feed its raw history too.
time_varying_unknown_reals = [
    'Temperature', 'Dew Point', 'Humidity', 'Speed', 'Pressure',
] + [col for col in data.columns if '_lag_' in col or 'inflow' in col or 'outflow' in col]

# --- Create the TimeSeriesDataSet ---
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target=target,
    group_ids=group_ids,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=static_categoricals,
    static_reals=static_reals,
    time_varying_known_reals=time_varying_known_reals,
    time_varying_unknown_reals=time_varying_unknown_reals,
    target_normalizer=GroupNormalizer(groups=group_ids, transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

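# predict=True keeps only the final prediction window of each station, so validation
# scores one 6-hour forecast per station rather than every window in the held-out month.
# To validate on all held-out windows, use predict=False with min_prediction_idx=training_cutoff + 1.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)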

# Create dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=2)

print("TimeSeriesDataSets and DataLoaders created successfully.")

### 4. Configure and Train the TFT Model

With the data prepared, we can now define and train our model. We use `EarlyStopping` to halt training once the validation loss stops improving, which limits overfitting.
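
The stopping rule can be sketched in plain Python. This is a simplified illustration of the patience and `min_delta` idea using the same values as the next cell (`patience=5`, `min_delta=1e-4`, lower is better), not Lightning's actual implementation. The function name and the loss values are hypothetical.

```python
def early_stop_epoch(val_losses, patience=5, min_delta=1e-4):
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: the last real improvement happens at epoch 2.
losses = [0.50, 0.40, 0.35, 0.35, 0.36, 0.35, 0.35, 0.35, 0.30]
stop = early_stop_epoch(losses)
print(stop)  # 7: five epochs in a row (3 to 7) failed to beat 0.35 by at least 1e-4
```

Note that the improvement at epoch 8 is never seen, which is the trade-off a small `patience` makes.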

In [None]:
# Configure the trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min")
logger = TensorBoardLogger("lightning_logs")

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback],
    logger=logger,
)

# Define the TFT model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=64, # Increased hidden size for more complex patterns
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=32, # Increased for more features
    output_size=7,  # To predict 7 quantiles
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# Start training
print("Starting model training...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

### 5. Evaluate and Visualize Predictions

After training, we load the best performing model and visualize its predictions on the validation set to qualitatively assess its performance.
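
Because the model is trained with `QuantileLoss` and outputs seven quantiles, the loss shown in each plot title is a quantile (pinball) loss rather than a squared error. The sketch below shows the standard per-point pinball loss as an illustration; the library's implementation may scale or aggregate it differently, and the function name `pinball_loss` is hypothetical.

```python
def pinball_loss(y_true, y_pred, quantile):
    """Standard pinball loss for one observation and one quantile level."""
    error = y_true - y_pred
    return max(quantile * error, (quantile - 1.0) * error)


# For the 0.9 quantile, predicting too low costs far more than predicting too high.
under = pinball_loss(y_true=0.5, y_pred=0.4, quantile=0.9)  # about 0.09
over = pinball_loss(y_true=0.5, y_pred=0.6, quantile=0.9)   # about 0.01
print(under, over)

# For the median (0.5), both directions are penalised equally.
print(pinball_loss(0.5, 0.4, 0.5), pinball_loss(0.5, 0.6, 0.5))
```

This asymmetry is what pushes each output toward its own quantile, producing the shaded prediction bands in the plots.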

In [None]:
# Load the best model from the checkpoint
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

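# Make predictions on the validation set.
# With pytorch-forecasting >= 1.0 (the versions built on `lightning.pytorch`), predict()
# returns a single result object, so read its fields rather than unpacking two values.
predictions = best_tft.predict(val_dataloader, mode="raw", return_x=True)
raw_predictions, x = predictions.output, predictions.x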

# Visualize the first 5 examples (idx 0 to 4); these are not randomly sampled
print("Plotting predictions for the first 5 examples from the validation set...")
for i in range(5):
    fig, ax = plt.subplots(figsize=(10, 5))
    best_tft.plot_prediction(x, raw_predictions, idx=i, ax=ax, add_loss_to_title=True)
    plt.show()

### 6. Interpret Model Behavior

Finally, we can use TFT's built-in interpretability to see which features the model found most useful for its predictions.

In [None]:
# Calculate and plot interpretation
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
figs = best_tft.plot_interpretation(interpretation)

# Display the feature importance plots
for key, value in figs.items():
    print(f"--- Importance for: {key} ---")
    value.show()