## Import Required Libraries

In [1]:
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from tqdm.autonotebook import tqdm

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer,  Baseline, QuantileLoss
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Silence library warnings so the notebook output stays readable.
# NOTE(review): this hides *every* warning for the whole kernel session (including
# pandas / pytorch-forecasting notices) — narrow the filter if results look off.
warnings.filterwarnings("ignore")

# Seed the Python / NumPy / torch RNGs through Lightning for reproducible runs.
pl.seed_everything(42)

# Report the compute environment.
cuda_available = torch.cuda.is_available()
print(f"Pytorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_available}")
if cuda_available:
    device_lines = (
        f"Device Count: {torch.cuda.device_count()}",
        f"Device Name: {torch.cuda.get_device_name(0)}",
    )
    for line in device_lines:
        print(line)

  from tqdm.autonotebook import tqdm
Seed set to 42


Pytorch Version: 2.10.0
CUDA Available: False


## Observation of Processed Dataset

In [2]:
# Processed parquet file produced by the upstream pipeline.
# NOTE: the path is relative to this notebook's directory — adjust it if the notebook moves.
DATA_PATH = "../local_artifacts/processed_data/training_data.parquet"

data = pd.read_parquet(DATA_PATH)

# Overall size and the span of arrival times covered.
arrival_col = data['arrival_time']
print(f"Dataset Shape: {data.shape}")
print(f"Date range: {arrival_col.min()} to {arrival_col.max()}")

# Null counts per column. The forecast target is 'service_headway'; the gaps in the
# travel-time / preceding-gap columns are handled in the imputation section below.
print("\nMissing Values:")
print(data.isna().sum())

# Column dtypes.
print("\nColumn Data Types:")
print(data.dtypes)

# First rows for a quick visual check.
data.head()

Dataset Shape: (75197, 20)
Date range: 2025-07-18T08:53:57+00:00 to 2026-01-19T10:43:23+00:00

Missing Values:
trip_uid                   0
trip_date                  0
arrival_time               0
timestamp                  0
group_id                   0
route_id                   0
direction                  0
stop_id                    0
time_idx                   0
day_of_week                0
hour_sin                   0
hour_cos                   0
regime_id                  0
track_id                   0
service_headway          260
preceding_train_gap        2
empirical_median           0
travel_time_14th          50
travel_time_23rd       23448
travel_time_34th          29
dtype: int64

Column Data Types:
trip_uid                object
trip_date               object
arrival_time            object
timestamp              float64
group_id                object
route_id                object
direction               object
stop_id                 object
time_idx                 int

Unnamed: 0,trip_uid,trip_date,arrival_time,timestamp,group_id,route_id,direction,stop_id,time_idx,day_of_week,hour_sin,hour_cos,regime_id,track_id,service_headway,preceding_train_gap,empirical_median,travel_time_14th,travel_time_23rd,travel_time_34th
0,1752815340_A..S53R,2025-07-18 05:09:00+00:00,2025-07-18T09:42:36+00:00,1752832000.0,A_South,A,S,A32S,29213862,4,0.566406,-0.824126,Day,A3,12.666667,,17.366667,2.0,,4.416667
1,1752816300_A..S58R,2025-07-18 05:25:00+00:00,2025-07-18T10:00:11+00:00,1752833000.0,A_South,A,S,A32S,29213880,4,0.5,-0.866025,Day,A3,17.583333,17.583333,10.633333,2.0,,4.1
2,1752817230_A..S57R,2025-07-18 05:40:30+00:00,2025-07-18T10:12:37+00:00,1752834000.0,A_South,A,S,A32S,29213892,4,0.45399,-0.891007,Day,A3,12.433333,12.433333,10.633333,1.833333,,3.933333
3,1752818010_A..S58R,2025-07-18 05:53:30+00:00,2025-07-18T10:23:36+00:00,1752834000.0,A_South,A,S,A32S,29213903,4,0.410719,-0.911762,Day,A3,10.983333,10.983333,10.633333,1.833333,,4.016667
4,1752818580_A..S57R,2025-07-18 06:03:00+00:00,2025-07-18T10:32:36+00:00,1752835000.0,A_South,A,S,A32S,29213912,4,0.374607,-0.927184,Day,A3,9.0,9.0,10.633333,1.833333,,3.916667


## Apply Imputation and 2nd Level Processing

In [3]:
# 1. Drop rows without a target: 'service_headway' is what the model forecasts.
# .copy() makes this an independent frame so the column writes below never act on a view.
data = data.dropna(subset=['service_headway']).copy()

# 2. Structural missingness.
# -1.0 marks "station skipped" (express) and is kept distinct from a genuine 0.0.
data['travel_time_23rd'] = data['travel_time_23rd'].fillna(-1.0)

# Fill the remaining numeric gaps with 0.0 (safe for now).
# NOTE(review): 0.0 is a plausible real value for preceding_train_gap (bunched trains);
# consider a sentinel or an explicit missing-indicator column instead.
data['preceding_train_gap'] = data['preceding_train_gap'].fillna(0.0)
data['travel_time_14th'] = data['travel_time_14th'].fillna(0.0)
data['travel_time_34th'] = data['travel_time_34th'].fillna(0.0)

# 3. Ensure categorical columns are strings (pytorch-forecasting encodes them as labels).
cat_cols = ['group_id', 'route_id', 'direction', 'regime_id', 'track_id']
data = data.astype({col: str for col in cat_cols})

# 4. Strictly re-index time_idx.
# 'arrival_time' is stored as ISO-8601 text, so sorting the raw strings is only
# chronological when every value shares one format and UTC offset. Sort on a parsed
# UTC timestamp instead so trains are ordered by real arrival time within each group.
data = (
    data.assign(_arrival_sort_key=pd.to_datetime(data['arrival_time'], utc=True))
    .sort_values(['group_id', '_arrival_sort_key'])
    .drop(columns='_arrival_sort_key')
)

# Continuous 0..n-1 counter per group: the TFT uses it as the train-sequence index.
data["time_idx"] = data.groupby('group_id').cumcount()

# Invariants: no missing values remain in the target or the imputed columns.
imputed_cols = ['service_headway', 'preceding_train_gap',
                'travel_time_14th', 'travel_time_23rd', 'travel_time_34th']
assert data[imputed_cols].notna().all().all(), "unexpected NaNs after imputation"

print(f"Cleaned Shape:{data.shape}")
print(f"Max time_idx (Sequence length): {data['time_idx'].max()}")
# Verify the distinct -1.0 "skipped" signal survived.
print(f"Travel time 23rd Unique values: {data['travel_time_23rd'].unique()[:5]}")

data.head()


Cleaned Shape:(74937, 20)
Max time_idx (Sequence length): 29661
Travel time 23rd Unique values: [ 3.2         3.65       -1.          3.48333333  3.58333333]


Unnamed: 0,trip_uid,trip_date,arrival_time,timestamp,group_id,route_id,direction,stop_id,time_idx,day_of_week,hour_sin,hour_cos,regime_id,track_id,service_headway,preceding_train_gap,empirical_median,travel_time_14th,travel_time_23rd,travel_time_34th
23270,1752813480_A..S74X043,2025-07-18 04:38:00+00:00,2025-07-18T09:09:03+00:00,1752830000.0,A_South,A,S,A32S,0,4,0.678801,-0.734323,Day,A1,15.1,15.1,17.366667,1.933333,3.2,4.766667
23273,1752814680_A..S74X043,2025-07-18 04:58:00+00:00,2025-07-18T09:29:56+00:00,1752831000.0,A_South,A,S,A32S,1,4,0.612217,-0.79069,Day,A1,20.883333,6.983333,17.366667,2.2,3.65,4.916667
0,1752815340_A..S53R,2025-07-18 05:09:00+00:00,2025-07-18T09:42:36+00:00,1752832000.0,A_South,A,S,A32S,2,4,0.566406,-0.824126,Day,A3,12.666667,0.0,17.366667,2.0,-1.0,4.416667
1,1752816300_A..S58R,2025-07-18 05:25:00+00:00,2025-07-18T10:00:11+00:00,1752833000.0,A_South,A,S,A32S,3,4,0.5,-0.866025,Day,A3,17.583333,17.583333,10.633333,2.0,-1.0,4.1
2,1752817230_A..S57R,2025-07-18 05:40:30+00:00,2025-07-18T10:12:37+00:00,1752834000.0,A_South,A,S,A32S,4,4,0.45399,-0.891007,Day,A3,12.433333,12.433333,10.633333,1.833333,-1.0,3.933333


## Time Based Splits and ML Dataset Creation

In [8]:
# 1. Define date splits.
# Always parse to tz-aware UTC datetimes: the split boundaries below are UTC Timestamps,
# and comparing them against strings or tz-naive datetimes raises. A check such as
# `dtype == 'object'` would miss pandas' dedicated string dtype and would pass tz-naive
# datetimes through unchanged; `utc=True` handles strings, naive and aware values alike.
data['arrival_time_dt'] = pd.to_datetime(data['arrival_time'], utc=True)

# Split boundaries (half-open periods):
#   train: arrival < train_end_date
#   val:   train_end_date <= arrival < val_end_date
#   test:  val_end_date   <= arrival < test_end_date   (later rows are not used)
train_end_date = pd.Timestamp("2025-11-18", tz="UTC")
val_end_date = pd.Timestamp("2025-12-18", tz="UTC")
test_end_date = pd.Timestamp("2026-01-18", tz="UTC")


# 2. Helper that builds a val/test frame carrying enough history for the encoder.
def get_slice_with_lookback(full_df, start_date, end_date, lookback=12,
                            time_col='arrival_time_dt', group_col='group_id',
                            order_col='time_idx'):
    """
    Return the rows with ``start_date <= time_col < end_date``, prefixed by the last
    ``lookback`` rows preceding ``start_date`` in each group.

    The prepended rows give the model encoder history for the first predictions of
    the period. They are taken with ``groupby().tail`` on ``full_df``'s existing row
    order, so ``full_df`` is expected to be sorted chronologically within each group.

    Parameters
    ----------
    full_df : pd.DataFrame
        Source frame containing ``time_col``, ``group_col`` and ``order_col``.
    start_date, end_date : pd.Timestamp
        Half-open period [start_date, end_date); must be comparable with ``time_col``.
    lookback : int, default 12
        Number of pre-period rows to keep per group (should match the encoder length).
    time_col, group_col, order_col : str
        Column names of the timestamp, the series id and the per-series sequence index.

    Returns
    -------
    pd.DataFrame
        Lookback rows plus in-period rows, sorted by (group_col, order_col). If no rows
        precede ``start_date``, the in-period rows are returned unchanged.
    """
    timestamps = full_df[time_col]
    core_df = full_df[(timestamps >= start_date) & (timestamps < end_date)]

    prior_df = full_df[timestamps < start_date]
    if prior_df.empty:
        return core_df

    # Vectorised per-group tail instead of iterating over groups and concatenating.
    lookback_df = prior_df.groupby(group_col).tail(lookback)
    # Sort so each group's lookback rows directly precede its in-period rows.
    return pd.concat([lookback_df, core_df]).sort_values([group_col, order_col])

# 3. Create the split dataframes.
# Window sizes are defined first so the val/test lookback can be tied to the encoder
# length instead of repeating the magic number 12.
max_prediction_length = 1
max_encoder_length = 12

train_df = data[data['arrival_time_dt'] < train_end_date]
val_df_input = get_slice_with_lookback(data, train_end_date, val_end_date, lookback=max_encoder_length)
test_df_input = get_slice_with_lookback(data, val_end_date, test_end_date, lookback=max_encoder_length)

assert not train_df.empty, "empty training split"
assert not val_df_input.empty, "empty validation split"
assert not test_df_input.empty, "empty test split"

print(f"Train rows: {len(train_df)}")
print(f"Val Rows (with context): {len(val_df_input)}")
print(f"Test Rows (with context): {len(test_df_input)}")

# 4. Create the TimeSeriesDataSets.
training = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target="service_headway",
    group_ids=["group_id"],
    min_encoder_length=10,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["route_id", "direction"],
    time_varying_known_categoricals=["regime_id", "track_id"],
    time_varying_known_reals=["time_idx", "hour_sin", "hour_cos", "empirical_median"],
    time_varying_unknown_reals=[
        "service_headway",
        "preceding_train_gap",
        "travel_time_14th",
        "travel_time_23rd",
        "travel_time_34th",
    ],
    target_normalizer=GroupNormalizer(
        groups=["group_id"], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Val/test reuse the encoders and normalizers fitted on the training split.
# `min_prediction_idx` cannot mark the period boundary here: time_idx restarts at 0 for
# every group, so the boundary index differs per group. The prepended lookback rows are
# meant purely as encoder context.
# NOTE(review): this relies on pytorch-forecasting starting a window at each row with a
# full max_encoder_length encoder, so the first decoder target per group is the first
# in-period row only while lookback == max_encoder_length — confirm on library upgrades.
validation = TimeSeriesDataSet.from_dataset(training, val_df_input, predict=False, stop_randomization=True)
test = TimeSeriesDataSet.from_dataset(training, test_df_input, predict=False, stop_randomization=True)

batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
test_dataloader = test.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

print(f"Batches (Train/Val/Test): {len(train_dataloader)} / {len(val_dataloader)} / {len(test_dataloader)}")

# Peek at one training batch to sanity-check what the model receives.
x, y = next(iter(train_dataloader))
print("\nCategorical feature names:", training.static_categoricals + training.time_varying_known_categoricals)
# x keys include "encoder_cat", "encoder_cont", "decoder_cat", "decoder_cont"
print("Encoder Shape (Batch, Time, Features):", x['encoder_cont'].shape)

Train rows: 49743
Val Rows (with context): 12304
Test Rows (with context): 12588
Train Batches: 776
Val Batches: 20
Test Batches 20
Total Batches (Train/Val/Test): 776 / 20 / 20

Feature names: ['route_id', 'direction', 'regime_id', 'track_id']
Encoder Shape (Batch, Time, Features): torch.Size([64, 12, 13])


## Model Architecture and Training Setup

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Launch TensorBoard on the "tensorboard_logs" directory, which the TensorBoardLogger
# created in the training cell below writes to (that cell runs after this one).
# This will open a panel below. It might be empty at first until training starts writing logs.
%tensorboard --logdir tensorboard_logs

In [None]:
# 1. Configure the Temporal Fusion Transformer from the training dataset's schema.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,           # reasonable starting point
    hidden_size=32,                # kept small: limited data, keep parameter count low
    attention_head_size=4,         # sufficient for simple temporal patterns
    dropout=0.3,                   # high dropout to encourage generalization
    hidden_continuous_size=16,     # reduced projection size for continuous inputs
    output_size=3,                 # must equal the number of quantiles in the loss below
    loss=QuantileLoss([0.1, 0.5, 0.9]),  # probabilistic forecast: P10 / P50 / P90
    log_interval=10,               # logging frequency
    reduce_on_plateau_patience=4,  # reduce LR if loss doesn't improve for 4 epochs
)

print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# 2. Training callbacks.
# Early stopping: halt once val_loss has not improved by at least min_delta for
# `patience` consecutive epochs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4,
    patience=15,
    verbose=False,
    mode="min",
)

lr_logger = LearningRateMonitor()

# Keep the checkpoint with the lowest validation loss. Without an explicit monitor,
# Lightning's default checkpoint only tracks the most recent epoch, so the evaluation's
# `best_model_path` would point to the final epoch — up to `patience` epochs past the
# validation optimum.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# Logger writes to the same "tensorboard_logs" directory the TensorBoard panel watches.
logger = TensorBoardLogger("tensorboard_logs", name="headway_tft")

trainer = pl.Trainer(
    max_epochs=50,
    accelerator="auto",       # uses a GPU if one is found
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,    # guards against exploding gradients from outliers
    callbacks=[lr_logger, early_stop_callback, checkpoint_callback],
    logger=logger,
    limit_train_batches=1.0,
)

print("Starting Training...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

## Evaluation Metrics

**Part A: Loss Visualization (the developer's view)**
A standard TensorBoard-style plot of training vs. validation loss.

**Part B: Performance Metrics (the data scientist's view)**
We run the model on the test set, which it has never seen, and compute:

- MAE (in minutes)
- sMAPE (percentage error)
- Quantile loss and P10/P90 coverage (calibration of the prediction interval)

**Part C: Representative Predictions (the operator's view)**
We select a few specific examples from the test set to plot:

- Routine: a standard rush-hour sequence.
- Disruption: a case where `preceding_train_gap` was high (interaction delay).
- Overnight: a case from the "Regime Shift" window (22:00–23:00), to check whether the model widens its confidence intervals as requested.

> **TODO:** the cell below currently plots only test sample #0; the routine / disruption / overnight selections are not implemented yet.

In [None]:
# torch, pandas, numpy and matplotlib.pyplot are imported in the first cell.

# 1. Locate the best checkpoint recorded by the trainer's checkpoint callback(s).
best_model_path = None
for cb in getattr(trainer, "checkpoint_callbacks", None) or []:
    path = getattr(cb, "best_model_path", None)
    if path:
        best_model_path = path
        print(f"Best model found at: {best_model_path}")
        break

if not best_model_path and getattr(trainer, "checkpoint_callback", None):
    best_model_path = getattr(trainer.checkpoint_callback, "best_model_path", None)
    print(f"Best model found via fallback: {best_model_path}")

# A checkpoint callback that never saved reports best_model_path == "" (not None), so
# test for falsiness; an `is None` check would hand "" to load_from_checkpoint.
if not best_model_path:
    raise ValueError("No best model checkpoint found. Did training fail?")

best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Move to GPU for inference when one is available.
# NOTE(review): predict() may configure its own device placement — confirm this is needed.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Inference Device: {device}")
best_tft.to(device)

# 2. Global predictions on the held-out test set.
print("Generating predictions on full test set...")
prediction_result = best_tft.predict(test_dataloader, mode="raw", return_x=True)

# With return_x=True the result is tuple-like: item 0 is the raw network output and
# item 1 is the input dict `x` (recent pytorch-forecasting releases return a named tuple
# with `.output` / `.x`). Anything else means the API changed; fail loudly instead of
# silently treating the output object as `x`.
if not isinstance(prediction_result, (tuple, list)):
    raise TypeError(
        f"Unexpected return type from predict(): {type(prediction_result).__name__}"
    )
raw_prediction, x = prediction_result[0], prediction_result[1]

# 3. Metrics (computed on CPU).
print("Calculating Metrics...")
# Last axis holds the quantiles [0.1, 0.5, 0.9] configured in the model's QuantileLoss.
predictions = raw_prediction.prediction.cpu()
actuals = x["decoder_target"].cpu()

mae_metric = MAE()
smape_metric = SMAPE()
quantile_loss_metric = QuantileLoss(quantiles=[0.1, 0.5, 0.9])

loss_val = quantile_loss_metric(predictions, actuals)
p50_forecast = predictions[:, :, 1]  # median
mae_val = mae_metric(p50_forecast, actuals)
smape_val = smape_metric(p50_forecast, actuals)

print(f"\n--- Global Test Metrics ---")
print(f"Quantile Loss: {loss_val.mean().item():.4f}")
print(f"MAE (P50):     {mae_val.mean().item():.4f} minutes")
print(f"sMAPE (P50):   {smape_val.mean().item():.4f}")

# Calibration: the share of actuals at or below each quantile should match its level.
p10 = predictions[:, :, 0]
p90 = predictions[:, :, 2]
p10_coverage = (actuals <= p10).float().mean()
p90_coverage = (actuals <= p90).float().mean()
print(f"P10 Coverage:  {p10_coverage.item():.3f} (Target 0.10)")
print(f"P90 Coverage:  {p90_coverage.item():.3f} (Target 0.90)")

# Per-group breakdown of the P50 absolute error.
print("\n--- Metrics by Group ID ---")
if "groups" in x:
    # Encoded group ids, one per prediction sample.
    group_ids = x["groups"].cpu().view(-1).numpy()

    # pytorch-forecasting registers the encoder for a `group_ids` column under the key
    # "__group_id__<column>"; a plain "<column>" key only exists when the column is also
    # declared as a categorical (group_id is not, here). Look up both.
    # NOTE(review): key convention comes from TimeSeriesDataSet internals — re-check on upgrade.
    encoders = training.categorical_encoders
    group_encoder = encoders.get("__group_id__group_id", encoders.get("group_id"))

    if group_encoder is None:
        print("Could not find an encoder for group_id; skipping group breakdown.")
    else:
        # Mean absolute error per sample (averaged over the prediction horizon).
        abs_errors = torch.abs(p50_forecast - actuals).mean(dim=1).numpy()
        res_df = pd.DataFrame({"group_idx": group_ids, "mae": abs_errors})

        # Decode each distinct integer id back to its group name.
        unique_idxs = np.unique(group_ids)
        decoded_names = group_encoder.inverse_transform(torch.tensor(unique_idxs, dtype=torch.long))
        idx_map = dict(zip(unique_idxs, decoded_names))
        res_df["group_name"] = res_df["group_idx"].map(idx_map)

        grouped_stats = (
            res_df.groupby("group_name")["mae"]
            .agg(["mean", "count"])
            .rename(columns={"mean": "MAE", "count": "Samples"})
        )
        print(grouped_stats.sort_values("MAE", ascending=False))
else:
    print("Could not find group information in input tensors.")

# 4. Plot one test sample: encoder history, actual future and the P10–P90 band.
print("\nVisualizing Single Prediction (Manual Plot):")
idx = 0

# Samples are stacked into padded tensors, so trim to this sample's true lengths;
# otherwise padding values would be drawn as part of the history.
# NOTE(review): assumes right-padding (real values first), as pytorch-forecasting's
# collation does — confirm if the library changes.
enc_len = int(x["encoder_lengths"][idx]) if "encoder_lengths" in x else x["encoder_target"].shape[1]
dec_len = int(x["decoder_lengths"][idx]) if "decoder_lengths" in x else x["decoder_target"].shape[1]

encoder_target = x["encoder_target"][idx, :enc_len].cpu()
decoder_target = x["decoder_target"][idx, :dec_len].cpu()
prediction_p10 = predictions[idx, :dec_len, 0].cpu()
prediction_p50 = predictions[idx, :dec_len, 1].cpu()
prediction_p90 = predictions[idx, :dec_len, 2].cpu()

history_time = range(-enc_len, 0)
future_time = range(0, dec_len)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(history_time, encoder_target, label="History", color="gray", marker=".")
ax.plot(future_time, decoder_target, label="Actual Future", color="black", marker="o")
ax.plot(future_time, prediction_p50, label="Forecast P50", color="blue", marker="x")
ax.fill_between(future_time, prediction_p10, prediction_p90, color="blue", alpha=0.2, label="Confidence (P10-P90)")
ax.set(
    title=f"Test Sample #{idx}: Headway Prediction",
    xlabel="Trains relative to forecast start",
    ylabel="Headway (minutes)",
)
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()