In [None]:
# AUTHENTICATION FOR COLAB
# On Colab, this grants the kernel Google credentials (used by the BigQuery client later on).
# Outside Colab the `google.colab` import fails, and we only print a reminder to set up
# local Application Default Credentials. On Colab, run this cell before anything else.
try:
    from google.colab import auth
    auth.authenticate_user()
except ImportError:
    print("Not running in Colab. Ensure your local environment is authenticated via 'gcloud auth application-default login'")
else:
    print("Authenticated successfully!")

In [1]:
!pip install pytorch-forecasting pytorch-lightning

Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.6.1-py3-none-any.whl.metadata (14 kB)
Collecting pytorch-lightning
  Using cached pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torch!=2.0.1,<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading torch-2.10.0-cp313-none-macosx_11_0_arm64.whl.metadata (31 kB)
Collecting lightning<2.7.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
Collecting scipy<2.0,>=1.8 (from pytorch-forecasting)
  Using cached scipy-1.17.0-cp313-cp313-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting scikit-learn<2.0,>=1.2 (from pytorch-forecasting)
  Using cached scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting scikit-base<0.14.0 (from pytorch-forecasting)
  Downloading scikit_base-0.13.1-py3-none-any.whl.metadata (8.8 kB)
Collecting fsspec<2027.0,>=2022.5.0 (from fsspec[http]<2027.0,>=2022.5.0->lightning<2.7.0,>=2.0.0->pytorch-forecasting)


In [None]:
# Core imports for the notebook (dependencies are installed in the cell above).
import os
import warnings

import numpy as np
import pandas as pd
import torch

# pytorch-forecasting >= 1.0 builds its models on the unified `lightning` package
# (`lightning.pytorch`). The Trainer and callbacks must come from the SAME package as the
# model, so prefer it and fall back to the legacy standalone `pytorch_lightning`.
try:
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
except ImportError:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss

# NOTE(review): this silences ALL warnings for the rest of the session, including
# library deprecation and data warnings — consider narrowing the filter.
warnings.filterwarnings("ignore")
print(f"PyTorch Version: {torch.__version__}")

## 1. Data Loading

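We load the data from our BigQuery staging area, using the query in `sql/tft_training_data.sql`.
*   **Key Requirement:** TFT requires a `time_idx` (integer step) and `group_ids` (series identifier).
*   Our `group_ids` will be a single `group_id` column that combines `route_id` (e.g., A, C, E) and `track`, so each route/track pair is its own series.
*   Our `time_idx` must be continuous (no gaps) within each group; we build it as a per-group event counter.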

In [None]:
from google.cloud import bigquery
import os

# Read the training query from file.
# The path is relative to the working directory, i.e. it assumes the kernel was started
# in the 'notebooks' directory.
sql_path = os.path.join("sql", "tft_training_data.sql")

with open(sql_path, "r", encoding="utf-8") as f:
    query = f.read()

client = bigquery.Client(project="realtime-headway-prediction")
df = client.query(query).to_dataframe()

# Fail fast if the query result lacks a column that the feature-engineering and
# TimeSeriesDataSet cells read, instead of erroring deep inside pandas/TFT later.
REQUIRED_COLUMNS = {
    "route_id", "track", "arrival_time", "service_headway", "track_headway",
    "scheduled_headway", "hour_of_day", "day_of_week",
}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
assert not missing_columns, f"Query result is missing columns: {sorted(missing_columns)}"

# Convert Types
df['arrival_time'] = pd.to_datetime(df['arrival_time'])
df['hour_of_day'] = df['hour_of_day'].astype(float)  # TFT expects reals to be float
df['day_of_week'] = df['day_of_week'].astype(str)  # used as a categorical

print(f"Loaded {len(df)} rows")
df.head()

In [None]:
# Feature Engineering: Rolling Statistics
# Rolling windows depend on row order, so each step sorts chronologically within its
# partition BEFORE computing the window. `groupby(...).transform` keeps the original
# index, so the results align back onto `df` whatever the current sort order is.
# Sorting first also makes this cell idempotent: re-running it gives identical features.

# 1. Route-Level Statistics (the "A" train context)
df = df.sort_values(["route_id", "track", "arrival_time"])
indexer = df.groupby(["route_id", "track"])["service_headway"]
df["rolling_mean_10"] = indexer.transform(lambda x: x.rolling(window=10, min_periods=1).mean())
# The std of a single observation is NaN (ddof=1), so the first row per series is filled with 0.
df["rolling_std_10"] = indexer.transform(lambda x: x.rolling(window=10, min_periods=1).std()).fillna(0)
df["rolling_max_20"] = indexer.transform(lambda x: x.rolling(window=20, min_periods=1).max())

# 2. Track-Level Statistics (the "Invisible Traffic" context)
# Partition by track only (not by route), so the window sees every train on the physical track.
df = df.sort_values(["track", "arrival_time"])
track_indexer = df.groupby("track")["track_headway"]
df["rolling_track_mean_5"] = track_indexer.transform(lambda x: x.rolling(window=5, min_periods=1).mean().fillna(0))

# 3. Create Group ID (one time series per route/track pair)
df["group_id"] = df["route_id"] + "_" + df["track"]

# 4. Time Indexing & Time Delta
# Sort by group to establish the event sequence.
df = df.sort_values(["group_id", "arrival_time"])

# Event-based indexing: 0, 1, 2, ... within each group (no gaps by construction).
df["time_idx"] = df.groupby("group_id").cumcount()

# Explicit time feature (fixes the "time distortion" of event-based steps):
# wall-clock minutes since the previous event in the same group; 0 for the first event.
df["dt_since_prev"] = df.groupby("group_id")["arrival_time"].diff().dt.total_seconds() / 60
df["dt_since_prev"] = df["dt_since_prev"].fillna(0)

print("Features Added:")
print("- Route Context: rolling_mean_10, rolling_std_10")
print("- Track Context: rolling_track_mean_5 (Congestion indicator)")
print("- Time Context: dt_since_prev (Wall-clock minutes per step)")

df[['arrival_time', 'group_id', 'time_idx', 'service_headway', 'dt_since_prev', 'track_headway']].head(15)

## 2. TFT Data Structure (TimeSeriesDataSet)

This is the most critical part. We map our columns to TFT's input buckets.

*   **time_idx**: The integer event step within each series (a per-group counter, not wall-clock time).
*   **target**: `service_headway`
*   **group_ids**: `['group_id']` (`route_id` + `track`; defines a unique time series).
*   **static_categoricals**: `['route_id', 'track']` (things that don't change within a series).
*   **time_varying_known_reals / categoricals**: `['scheduled_headway', 'hour_of_day']` and `['day_of_week']` (we know the schedule in the future).
*   **time_varying_unknown_reals**: `service_headway` (the target), `rolling_mean_10`, `rolling_std_10`, `track_headway`, `rolling_track_mean_5`, `dt_since_prev` (we only know these in the past).

In [None]:
# Horizon configuration - adapt column names to your schema
max_encoder_length = 20     # past events the encoder may look at
max_prediction_length = 3   # future events the decoder forecasts

# TRAIN/VAL SPLIT on the event index: time_idx values up to 80% of the maximum train.
# NOTE(review): `time_idx` is a per-group event counter, but the cutoff is 80% of the
# LONGEST group's length. Groups shorter than the cutoff contribute no validation windows.
# The same time_idx can also map to different wall-clock times in different groups, so
# training rows of one group may postdate validation rows of another.
# Consider splitting on `arrival_time` instead — TODO confirm.
max_time_idx = df["time_idx"].max()
training_cutoff = int(max_time_idx * 0.8)

# Column roles. The target is also listed as an unknown real so that its past values
# are fed to the encoder.
static_categorical_cols = ["route_id", "track"]
known_real_cols = ["scheduled_headway", "hour_of_day"]
known_categorical_cols = ["day_of_week"]
unknown_real_cols = [
    "service_headway",       # target
    "rolling_mean_10",       # recent performance of THIS route
    "rolling_std_10",
    "track_headway",         # gap to ANY train in front
    "rolling_track_mean_5",  # is the physical track congested?
    "dt_since_prev",         # wall-clock duration of the step
]

train_df = df.loc[df["time_idx"] <= training_cutoff]

# 1. Training dataset - full visibility mode
training = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target="service_headway",
    group_ids=["group_id"],
    min_encoder_length=10,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=static_categorical_cols,
    time_varying_known_reals=known_real_cols,
    time_varying_known_categoricals=known_categorical_cols,
    time_varying_unknown_reals=unknown_real_cols,
    # Per-series standard (z-score) normalization of the target, no transformation
    target_normalizer=GroupNormalizer(groups=["group_id"], transformation=None),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# 2. Validation dataset: same encoders/normalizers as training, full frame,
#    predictions start right after the training cutoff.
validation = TimeSeriesDataSet.from_dataset(
    training,
    df,
    predict=False,
    stop_randomization=True,
    min_prediction_idx=training_cutoff + 1,
)

# Dataloaders (larger batches for validation, which needs no gradients)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

print("Datasets Configured (Full Visibility Mode).")
print(f"Features: {training.reals}")

## 3. Training the Model

We use PyTorch Lightning.
*   **QuantileLoss**: By default, optimizes 7 quantiles (2%, 10%, 25%, 50%, 75%, 90%, 98%), including the median (0.5); this is why the model uses `output_size=7`. This gives us prediction intervals (e.g., "Train is arriving in 5 mins +/- 1 min").

In [None]:
import pytorch_lightning as pl
try:
    import lightning.pytorch as pl_new
except ImportError:
    pl_new = None

# Seed the Python/NumPy/torch RNGs right before building the model, so that re-running
# this cell reproduces weight initialization and training-batch shuffling.
SEED = 42
pl.seed_everything(SEED, workers=True)

# MODEL CONFIGURATION
# output_size=7 matches the 7 quantiles that QuantileLoss() predicts by default.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.01,
    hidden_size=64,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# DIAGNOSTIC & FIX
# The model may subclass either `lightning.pytorch` or legacy `pytorch_lightning`.
# The Trainer (and its callbacks) must come from the same package as the model.
using_modern_pl = False
if pl_new is not None and isinstance(tft, pl_new.LightningModule):
    pl = pl_new
    using_modern_pl = True
elif isinstance(tft, pl.LightningModule):
    pass
else:
    raise TypeError("Library Version Mismatch: Please Restart Runtime/Kernel.")

# CALLBACKS (taken from the resolved `pl`, so they match the Trainer)
# - ModelCheckpoint on val_loss: makes `best_model_path` the lowest-validation-loss epoch.
#   The default checkpoint callback monitors nothing, so "best" would just mean the last epoch.
# - EarlyStopping: its patience (10) exceeds the LR plateau patience (4), so the learning
#   rate gets a chance to drop before training is stopped.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", min_delta=1e-4, patience=10)

# TRAINER CONFIGURATION - FULL DATA MODE
trainer = pl.Trainer(
    max_epochs=30,  # upper bound; EarlyStopping may end training sooner
    accelerator="auto",  # CUDA/MPS when available, otherwise CPU (a hard-coded "gpu" fails on CPU-only runtimes)
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=1.0,  # use 100% of the training batches
    callbacks=[checkpoint_callback, early_stop_callback],
)

print(f"Model Configured: Hidden Size=64, LR=0.01")
print("Training on 100% of batches (Handbrake removed).")

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

## 4. Evaluation & Interpretability

TFT's superpower is interpretability. We can plot:
1.  **Variable Importance**: Which features matter most? (Schedule vs. past delays.)
2.  **Attention Weights**: Is the model looking at the recent past or distant past?

In [None]:
import matplotlib.pyplot as plt

# 1. Load the best checkpoint written during training.
# `checkpoint_callback` is None when checkpointing is disabled, and `best_model_path` is ""
# when no checkpoint was saved (e.g. fit interrupted early). In either case, evaluate the
# in-memory model instead of crashing in load_from_checkpoint.
# NOTE(review): if the Trainer's checkpoint callback monitors no metric, the "best" path is
# simply the most recent checkpoint — confirm the training configuration.
checkpoint_cb = trainer.checkpoint_callback
best_model_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
if best_model_path:
    print(f"Loading best model from: {best_model_path}")
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
else:
    print("No checkpoint available; evaluating the in-memory model `tft` instead.")
    best_tft = tft

# 2. Get predictions on the validation set.
# mode="raw" returns the full network output (all quantiles), not just a point forecast.
# Depending on the pytorch-forecasting version, predict returns a longer (named) tuple
# starting with (output, x, ...) or a plain (output, x); index-based unpacking covers both.
prediction_result = best_tft.predict(val_dataloader, mode="raw", return_x=True)

if isinstance(prediction_result, tuple) and len(prediction_result) >= 2:
    raw_predictions = prediction_result[0]
    x = prediction_result[1]
    print(f"Prediction returned {len(prediction_result)} items. Successfully unpacked.")
else:
    # Not a tuple: assume an iterable of exactly (output, x); fails loudly otherwise.
    raw_predictions, x = prediction_result

# 3. INTERPRETATION: Variable Importance
# Shows which features (e.g. 'scheduled_headway', 'rolling_mean_10') the model actually used.
# Best-effort: a plotting failure is reported but does not stop the rest of the cell.
print("generating Variable Importance Plot...")
try:
    interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
    best_tft.plot_interpretation(interpretation)
    plt.show()
except Exception as e:
    print(f"Could not generate interpretation plot: {e}")

# 4. VISUALIZATION: Real vs Predicted
# A few validation samples; the shaded area shows the predicted quantile range.
print("\nPlotting example predictions (Grey cone = 10th-90th percentile confidence interval):")
for idx in [0, 10, 20]:  # spread-out indices to see some variety
    try:
        if idx < x["decoder_target"].shape[0]:
            fig = best_tft.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)
            plt.show()
            plt.close(fig)  # release the figure so repeated runs do not accumulate open figures
        else:
            print(f"Index {idx} out of bounds for plotting.")
    except Exception as e:
        print(f"Error plotting index {idx}: {e}")

print("\nINTERPRETATION GUIDE:")
print("- Encoder Variables: What *past* information mattered most?")
print("- Decoder Variables: What *future* information (like schedule) helped most?")
print("- Static Variables: Did the Route ID matter?")