## Import Required Libraries

In [1]:
import os
import warnings

import pandas as pd
import numpy as np
import torch
from tqdm.autonotebook import tqdm

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer,  Baseline, QuantileLoss
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# supress warnings to keep the output clean
warnings.filterwarnings("ignore")

# set seed for reproducibility
pl.seed_everything(42)

print(f"Pytorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device Count: {torch.cuda.device_count()}")
    print(f"Device Name: {torch.cuda.get_device_name(0)}")

  from tqdm.autonotebook import tqdm
Seed set to 42


Pytorch Version: 2.10.0
CUDA Available: False


## Observation of Processed Dataset

In [2]:
# define path to the processed parquet file
# note: adjust relative path if your notebook location differs
DATA_PATH = "../local_artifacts/processed_data/training_data.parquet"

# load the df
try:
    data = pd.read_parquet(DATA_PATH)
except FileNotFoundError:
    print(f"File not found at {DATA_PATH}. Please run the pipeline first.")
    # create dummy data so the cell doesn't crash completely during static analysis checks
    data = pd.DataFrame()

if not data.empty:
    # display basic statistics
    print(f"Dataset Shape: {data.shape}")
    
    if 'arrival_time' in data.columns:
        print(f"Date range: {data['arrival_time'].min()} to {data['arrival_time'].max()}")
    
    # check for nulls in potential target/group columns
    print("\nMissing Values:")
    print(data.isnull().sum())

    # Inspect column types
    print("\nColumn Data Types:")
    print(data.dtypes)

    # preview
    data.head()

Dataset Shape: (74937, 21)
Date range: 2025-07-18T09:09:03+00:00 to 2026-01-19T10:43:23+00:00

Missing Values:
trip_uid                          0
arrival_time                      0
group_id                          0
route_id                          0
time_idx                          0
day_of_week                       0
hour_sin                          0
hour_cos                          0
regime_id                         0
track_id                          0
service_headway                   0
preceding_train_gap               1
preceding_route_id                1
empirical_median                  0
upstream_headway_14th            49
travel_time_14th                 48
travel_time_14th_deviation       48
travel_time_23rd              23418
travel_time_23rd_deviation    23418
travel_time_34th                 27
travel_time_34th_deviation       27
dtype: int64

Column Data Types:
trip_uid                       object
arrival_time                   object
group_id                

## Apply Imputation and 2nd Level Processing

In [3]:
# 1. ensure categoricals are strings
cat_cols = ['group_id','route_id','direction','regime_id','track_id', 'preceding_route_id']
for col in cat_cols:
    if col in data.columns:
        # Handle nulls in categoricals (e.g. preceding_route_id might be null for first train)
        data[col] = data[col].fillna("None").astype(str)

# 2. Parse dates (Required for splitting)
if 'arrival_time' in data.columns:
    data['arrival_time_dt'] = pd.to_datetime(data['arrival_time'])

# --- NEW IMPUTATION LOGIC FOR PIPELINE OUTPUTS ---

# 3a. Handle General Missing Reals (Filling with Neutrals)
# Preceding Train Gap & Upstream Headway: Fill with Median
for col in ['preceding_train_gap', 'upstream_headway_14th']:
    if col in data.columns:
         data[col] = data[col].fillna(data[col].median())

# Travel Time Deviations: Fill with 0.0 (Assume on-time if unknown)
dev_cols = [c for c in data.columns if 'deviation' in c]
for col in dev_cols:
    data[col] = data[col].fillna(0.0)

# Other Travel Times (14th, 34th): Fill with Median
tt_cols = ['travel_time_14th', 'travel_time_34th']
for col in tt_cols:
    if col in data.columns:
        data[col] = data[col].fillna(data[col].median())

# 3b. Handle 23rd St Express/Local logic (Special Case)
# Logic: Express trains don't stop at 23rd, so they have NULL travel times.
# We create a binary flag 'stops_at_23rd' to tell the model this is intentional.
# We then fill the actual travel_time value with the MEAN of valid stops to "neutralize" it input-wise.

# Create flag: 1 if valid stop (not null and > 0), 0 otherwise
data['stops_at_23rd'] = np.where((data['travel_time_23rd'].notna()) & (data['travel_time_23rd'] > 0), 1.0, 0.0)

# Calculate mean of VALID stops only
valid_mean_23rd = data.loc[data['stops_at_23rd'] == 1.0, 'travel_time_23rd'].mean()

# Fill invalid/missing rows with that mean
data.loc[data['stops_at_23rd'] == 0.0, 'travel_time_23rd'] = valid_mean_23rd
# Just in case any NaNs remain (e.g. the column was purely NaN)
data['travel_time_23rd'] = data['travel_time_23rd'].fillna(valid_mean_23rd if pd.notna(valid_mean_23rd) else 0.0)


# 4. Correct Time Index (Physical Time)
# REFACTOR: Decouple time_idx from row count. Use absolute minute-index.
# This prevents warping of time during service gaps.
if 'arrival_time_dt' not in data.columns:
    data['arrival_time_dt'] = pd.to_datetime(data['arrival_time'])
    
# Find global min for anchor
min_time = data['arrival_time_dt'].min()
# Calculate minutes elapsed since start
data['time_idx'] = ((data['arrival_time_dt'] - min_time).dt.total_seconds() / 60).astype(int)

# Sort by group and new physical time index
data = data.sort_values(['group_id', 'time_idx'])

print(f"Cleaned Shape:{data.shape}")
print(f"Max time_idx (Sequence length): {data['time_idx'].max()}")
try:
    print(f"Valid Mean 23rd: {valid_mean_23rd:.2f}")
except:
    pass

# Verify no nulls remain
print("\nRemaining Nulls:")
print(data.isnull().sum().sum())

data.head()

Cleaned Shape:(74937, 23)
Max time_idx (Sequence length): 266494
Valid Mean 23rd: 3.77

Remaining Nulls:
0


Unnamed: 0,trip_uid,arrival_time,group_id,route_id,time_idx,day_of_week,hour_sin,hour_cos,regime_id,track_id,...,empirical_median,upstream_headway_14th,travel_time_14th,travel_time_14th_deviation,travel_time_23rd,travel_time_23rd_deviation,travel_time_34th,travel_time_34th_deviation,arrival_time_dt,stops_at_23rd
47375,1752813480_A..S74X043,2025-07-18T09:09:03+00:00,A_South,A,0,4,0.678801,-0.734323,Day,A1,...,17.366667,15.083333,1.933333,-0.066667,3.2,-0.241667,4.766667,-0.066667,2025-07-18 09:09:03+00:00,1.0
47376,1752814680_A..S74X043,2025-07-18T09:29:56+00:00,A_South,A,20,4,0.612217,-0.79069,Day,A1,...,17.366667,20.616667,2.2,0.2,3.65,0.208333,4.916667,0.083333,2025-07-18 09:29:56+00:00,1.0
47377,1752815340_A..S53R,2025-07-18T09:42:36+00:00,A_South,A,33,4,0.566406,-0.824126,Day,A3,...,17.366667,12.866667,2.0,0.0,3.773551,0.0,4.416667,-0.416667,2025-07-18 09:42:36+00:00,0.0
47378,1752816300_A..S58R,2025-07-18T10:00:11+00:00,A_South,A,51,4,0.5,-0.866025,Day,A3,...,10.633333,17.583333,2.0,0.05,3.773551,0.0,4.1,-0.066667,2025-07-18 10:00:11+00:00,0.0
47379,1752817230_A..S57R,2025-07-18T10:12:37+00:00,A_South,A,63,4,0.45399,-0.891007,Day,A3,...,10.633333,12.6,1.833333,-0.116667,3.773551,0.0,3.933333,-0.233333,2025-07-18 10:12:37+00:00,0.0


In [4]:
# DIAGNOSTIC: Check time_idx continuity
# This helps diagnose "no meaningful learning" issues caused by broken time scales

print("Checking time_idx continuity and uniqueness...")

# 1. Uniqueness check
duplicates = data.duplicated(['group_id', 'time_idx']).sum()
print(f"Duplicate (group_id, time_idx) pairs: {duplicates}")

# 2. Continuity check
# We expect time_idx to strictly increment by 1 for each group
data = data.sort_values(['group_id', 'time_idx'])
data['time_idx_diff'] = data.groupby('group_id')['time_idx'].diff()

# The first row of each group will have NaN diff, which is fine.
# Subsequent rows MUST have diff == 1.0
gaps = data[data['time_idx_diff'] > 1.0]

if len(gaps) > 0:
    print(f"⚠️ WARNING: Found {len(gaps)} gaps in time_idx sequence!")
    print("Top 5 gaps:")
    print(gaps[['group_id', 'arrival_time', 'time_idx', 'time_idx_diff']].head())
    print("Advice: If gaps exist, the model cannot learn temporal patterns effectively. Re-run the data pipeline with the fix for time_idx generation.")
else:
    print("✅ time_idx is continuous (no gaps found).")

# Clean up temporary column
if 'time_idx_diff' in data.columns:
    data.drop(columns=['time_idx_diff'], inplace=True)

Checking time_idx continuity and uniqueness...
Duplicate (group_id, time_idx) pairs: 25
Top 5 gaps:
      group_id               arrival_time  time_idx  time_idx_diff
47376  A_South  2025-07-18T09:29:56+00:00        20           20.0
47377  A_South  2025-07-18T09:42:36+00:00        33           13.0
47378  A_South  2025-07-18T10:00:11+00:00        51           18.0
47379  A_South  2025-07-18T10:12:37+00:00        63           12.0
47380  A_South  2025-07-18T10:23:36+00:00        74           11.0
Advice: If gaps exist, the model cannot learn temporal patterns effectively. Re-run the data pipeline with the fix for time_idx generation.


## Time Based Splits and ML Dataset Creation

In [6]:
# 1 define date splits
# Ensure datetime column exists and handle timezone
if 'arrival_time_dt' not in data.columns:
    data['arrival_time_dt'] = pd.to_datetime(data['arrival_time'])

# Detect if the dataset is Timezone Aware (The new parquet file likely is UTC)
is_tz_aware = data['arrival_time_dt'].dt.tz is not None
print(f"Dataset Timezone Aware: {is_tz_aware}")

# CUTOFF DATES REVERTED TO ORIGINAL PER USER REQUEST
# We match the timezone of the cutoffs to the data to avoid TypeError: "Cannot compare tz-naive and tz-aware"
tz = "UTC" if is_tz_aware else None

train_end_date = pd.Timestamp("2025-11-18", tz=tz) # Adjusted approx
val_end_date = pd.Timestamp("2025-12-18", tz=tz)
test_end_date = pd.Timestamp("2026-01-18", tz=tz) 

print(f"Cutoff Dates (Train/Val/Test): {train_end_date}, {val_end_date}, {test_end_date}")

# 2 helper function to create dataset inputs with correct lookback context
def get_slice_with_lookback(full_df, start_date, end_date, lookback=20):
    """
    Return rows between start_date and end_date,
    Plus the last 'lookback' rows before start_date for EACH group
    """
    # get the core data for the period
    mask = (full_df['arrival_time_dt'] >= start_date) & (full_df['arrival_time_dt'] < end_date)
    core_df = full_df[mask]

    # we need to preprend the last 'lookback' rows from BEFORE start_date for EACH group
    # to serve as the history for the first few predictions
    prior_df = full_df[full_df['arrival_time_dt'] < start_date]
    pre_data = []

    for g_id, group in prior_df.groupby('group_id'):
        pre_data.append(group.tail(lookback))
    
    if pre_data:
        lookback_df = pd.concat(pre_data)
        # concat and sort to ensure time continuity per group
        return pd.concat([lookback_df, core_df]).sort_values(['group_id','time_idx'])
    return core_df

# 3 create physical dataframes
train_df = data[data['arrival_time_dt']< train_end_date]
val_df_input = get_slice_with_lookback(data, train_end_date, val_end_date, lookback=20)
test_df_input = get_slice_with_lookback(data, val_end_date, test_end_date, lookback=20)

print(f"Train rows: {len(train_df)}")
print(f"Val Rows (with context): {len(val_df_input)}")
print(f"Test Rows (with context): {len(test_df_input)}")

# create datasets
max_prediction_length = 1
# improvement plan said "Start with 20". 
max_encoder_length = 20 

training = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target="service_headway",
    group_ids=["group_id"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # REMOVED: direction, added route_id as static
    static_categoricals=["route_id"], 
    time_varying_known_categoricals=["regime_id", "track_id"],
    time_varying_known_reals=["time_idx", "hour_sin", "hour_cos", "empirical_median"],
    # ADDED: preceding_route_id (unknown because it varies per step)
    time_varying_unknown_categoricals=["preceding_route_id"],
    time_varying_unknown_reals=[
        "service_headway",
        "preceding_train_gap",
        "upstream_headway_14th",
        "travel_time_14th",
        "travel_time_14th_deviation",
        "travel_time_23rd",
        "travel_time_23rd_deviation",
        "travel_time_34th",
        "travel_time_34th_deviation",
        "stops_at_23rd" # ADDED: binary flag for express trains
    ],
    target_normalizer=GroupNormalizer(
        groups=["group_id"], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    # ADDED: Allow gaps in time_idx because we switched to physical minutes
    allow_missing_timesteps=True 
)

# use from_dataset with sliced dataframes
validation = TimeSeriesDataSet.from_dataset(training, val_df_input, predict=False, stop_randomization=True)
test = TimeSeriesDataSet.from_dataset(training, test_df_input, predict=False, stop_randomization=True)

# OPTIMIZED: Per Infrastructure Plan
batch_size = 128 # User requested 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=4)
test_dataloader = test.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=4)

print(f"Train Batches: {len(train_dataloader)}")
print(f"Val Batches: {len(val_dataloader)}")
print(f"Test Batches {len(test_dataloader)}")
print(f"Total Batches (Train/Val/Test): {len(train_dataloader)} / {len(val_dataloader)} / {len(test_dataloader)}")

# visualize what the model sees
x, y = next(iter(train_dataloader))
print("\nFeature names:", training.static_categoricals + training.time_varying_known_categoricals + training.time_varying_unknown_categoricals)
print("Encoder Shape (Batch, Time, Features):", x['encoder_cont'].shape)

Dataset Timezone Aware: True
Cutoff Dates (Train/Val/Test): 2025-11-18 00:00:00+00:00, 2025-12-18 00:00:00+00:00, 2026-01-18 00:00:00+00:00
Train rows: 49743
Val Rows (with context): 12336
Test Rows (with context): 12620
Train Batches: 430
Val Batches: 11
Test Batches 11
Total Batches (Train/Val/Test): 430 / 11 / 11

Feature names: ['route_id', 'regime_id', 'track_id', 'preceding_route_id']
Encoder Shape (Batch, Time, Features): torch.Size([128, 20, 18])


## Model Architecture and Training Setup

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Launch TensorBoard pointing to the logs directory we configured above
# This will open a panel below. It might be empty at first until training starts writing logs.
%tensorboard --logdir tensorboard_logs

In [None]:
# 1 configure the Temporal Fusion Transformer

# FIX: Define a subclass to handle BFloat16 plotting issue with Matplotlib
# PyTorch Forecasting's plot_prediction method crashes when passing BFloat16 tensors to Matplotlib.
class TFTWithBF16Fix(TemporalFusionTransformer):
    def plot_prediction(self, x, out, idx, **kwargs):
        # Helper to cast BF16 tensors to Float32
        def to_float32(val):
            if isinstance(val, torch.Tensor) and val.dtype == torch.bfloat16:
                return val.float()
            return val
            
        # Recursive cast for nested dictionaries/tensors
        # NOTE: 'x' is typically a dictionary of tensors, so preserving the dict structure is key.
        # However, we must ensure we don't accidentally cast things that aren't tensors.
        if isinstance(x, dict):
             # Create a new dictionary with casted values
             x = {k: to_float32(v) for k, v in x.items()}
             
    
        
        # PROPOSED FIX: Cast everything recursively but carefully reconstruction the object.
        from pytorch_forecasting.utils import to_list
        from collections import namedtuple
        
        # If 'out' is the special Output class or NamedTuple
        if hasattr(out, "_asdict"):
            out_dict = out._asdict()
            out_dict = {k: to_float32(v) for k, v in out_dict.items()}
            # Reconstruct
            out = out.__class__(**out_dict)
        elif isinstance(out, dict):
            # If it's already a dict (some versions do this), just cast
            # If it needs to be an object with .iget, we can't easily fake it if it wasn't one already.
            # But usually it IS one.
             out = {k: to_float32(v) for k, v in out.items()}
             # If the parent expects an object with .iget, we are in trouble if we just pass a dict.
             # However, usually 'out' coming from .step() is the object.
        
        # We also need to traverse lists/tuples if they exist
        out = to_float32(out) # Fallback for single tensor
        
        return super().plot_prediction(x, out, idx, **kwargs)

# SIMPLER FIX: Just disable plotting during training log
# If the casting is too brittle due to internal object structures, 
# we can override log_prediction to do nothing or handle it manually.

class TFTDisablePlotting(TemporalFusionTransformer):
    def log_prediction(self, x, out, batch_idx, **kwargs):
        # SKIP plotting during training to avoid BFloat16 Matplotlib crash
        pass

tft = TFTDisablePlotting.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=128,            # INCREASED: Scaled to 128 per optimization report
    attention_head_size=4,      # INCREASED: Multi-scale attention
    dropout=0.1,                # REDUCED: 0.1 to facilitate initial learning
    hidden_continuous_size=64,  # INCREASED: 64 to prevent bottleneck
    output_size=3,              # 3 quantiles [0.1, 0.5, 0.9]
    loss=QuantileLoss([0.1, 0.5, 0.9]),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# 2 configure training callbacks
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4,
    patience=15, 
    verbose=False,
    mode="min"
)

lr_logger = LearningRateMonitor()

# configure logger
logger = TensorBoardLogger("tensorboard_logs", name="headway_tft")

# initialize trainer
trainer = pl.Trainer(
    max_epochs=100,             # INCREASED: 100 epochs
    accelerator="auto", 
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,      # CHANGED: 0.1 for LSTM stability
    callbacks=[lr_logger, early_stop_callback, TQDMProgressBar(refresh_rate=20)], 
    logger=logger,
    limit_train_batches=1.0,    # CHANGED: Use 100% of data
    precision="bf16-mixed"      # CHANGED: Mixed precision for A100
)

print("Starting Training...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

## Evaluation Metrics
Part A: Loss Visualization (The Developer's View)<br>
Standard TensorBoard-style plot of Train vs Val Loss.<br>
<br>
Part B: Performance Metrics (The Data Scientist's View)<br>
We will run the model on the Test Set (which it has never seen) and compute:<br>
<br>
MAE (in minutes)<br>
sMAPE (Percentage error)<br>
Part C: Representative Predictions (The Operator's View)<br>
We will select a few specific examples from the test set to plot:<br>
<br>
Routine: A standard rush-hour sequence.<br>
Disruption: A case where preceding_train_gap was high (interaction delay).<br>
Overnight: A case from the "Regime Shift" (22:00-23:00) to see if it widens confidence intervals as requested.<br>

In [None]:
# 1. Load best model
best_model_path = None
if getattr(trainer, "checkpoint_callbacks", None):
    for cb in trainer.checkpoint_callbacks:
        path = getattr(cb, "best_model_path", None)
        if path:
            best_model_path = path
            print(f"Best model found at: {best_model_path}")
            break

if best_model_path is None and getattr(trainer, "checkpoint_callback", None):
    best_model_path = getattr(trainer.checkpoint_callback, "best_model_path", None)
    print(f"Best model found via fallback: {best_model_path}")

if best_model_path is None:
    # Just a warning or fallback if training didn't complete / save
    print("No best model found. Using last model state if available.")
    best_tft = tft
else:
    # Load model using the standard class
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Performance: Move to GPU if available for inference
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Inference Device: {device}")
best_tft.to(device)

# 2. Global Predictions
print("Generating predictions on full test set...")
raw_prediction = best_tft.predict(test_dataloader, mode="raw", return_x=True)

# Extract components
predictions = raw_prediction.output["prediction"] 
x = raw_prediction.x

# 3. Metrics (Move to CPU for calculation)
print("Calculating Metrics...")
predictions_cpu = predictions.cpu()
actuals_cpu = x["decoder_target"].cpu()

# QuantileLoss expects target with shape (Batch, Time), so we strictly preserve dimensions here.
# NOTE: removing previous squeeze() logic which caused IndexError
mae_metric = MAE()
smape_metric = SMAPE()
quantile_loss_metric = QuantileLoss(quantiles=[0.1, 0.5, 0.9])

loss_val = quantile_loss_metric(predictions_cpu, actuals_cpu)

# For Point Metrics (MAE, SMAPE), we use the Median (P50)
p50_forecast = predictions_cpu[:, :, 1] # shape (Batch, prediction_len)
mae_val = mae_metric(p50_forecast, actuals_cpu)
smape_val = smape_metric(p50_forecast, actuals_cpu)

print(f"\n--- Global Test Metrics ---")
print(f"Quantile Loss: {loss_val.mean().item():.4f}")
print(f"MAE (P50):     {mae_val.mean().item():.4f} minutes")
print(f"sMAPE (P50):   {smape_val.mean().item():.4f}")

# FLATTEN TENSORS for Analysis and Plotting (Batch*Time)
# This prevents shape mismatch issues with Pandas/Matplotlib
actuals_flat = actuals_cpu.view(-1)
p50_flat = p50_forecast.view(-1)
p10_flat = predictions_cpu[:, :, 0].view(-1)
p90_flat = predictions_cpu[:, :, 2].view(-1)

# Calibration
p10_coverage = (actuals_flat <= p10_flat).float().mean()
p90_coverage = (actuals_flat <= p90_flat).float().mean()
print(f"P10 Coverage:  {p10_coverage.item():.3f} (Target 0.10)")
print(f"P90 Coverage:  {p90_coverage.item():.3f} (Target 0.90)")

# --- NEW SECTION: Group Breakdown ---
print("\n--- Metrics by Group ID ---")
if "groups" in x:
    # Groups are repeated for prediction_length if > 1, so we must be careful.
    # But usually 'groups' is (Batch,), so we repeat it to match flattened outputs if needed.
    # However x['groups'] from raw_prediction.x usually maps 1-to-1 with Batch.
    # Since we flattened prediction-wise (Batch*Time), we need to repeat group_ids if time > 1.
    
    group_ids_batch = x["groups"].cpu().view(-1) # (Batch)
    # If prediction_len > 1, we need to repeat group_ids
    if predictions_cpu.shape[1] > 1:
        group_ids = group_ids_batch.repeat_interleave(predictions_cpu.shape[1]).numpy()
    else:
        group_ids = group_ids_batch.numpy()
    
    # Get Decoder
    if hasattr(training, "categorical_encoders") and training.categorical_encoders is not None:
        group_encoder = training.categorical_encoders["group_id"]
    else:
        group_encoder = best_tft.dataset_parameters["categorical_encoders"]["group_id"]
    
    # Calculate errors
    abs_errors = torch.abs(p50_flat - actuals_flat).numpy()
    
    res_df = pd.DataFrame({
        "group_idx": group_ids,
        "mae": abs_errors
    })
    
    # Map index to name
    unique_idxs = np.unique(group_ids)
    decoded_names = group_encoder.inverse_transform(torch.tensor(unique_idxs, dtype=torch.long))
    idx_map = dict(zip(unique_idxs, decoded_names))
    res_df["group_name"] = res_df["group_idx"].map(idx_map)
    
    grouped_stats = res_df.groupby("group_name")["mae"].agg(['mean', 'count']).rename(columns={'mean': 'MAE', 'count': 'Samples'})
    print(grouped_stats.sort_values("MAE", ascending=False))
else:
    print("Could not find group information in input tensors.")
    res_df = pd.DataFrame() # Fallback

# 4. EXECUTIVE SUITE VISUALIZATION
print("\n--- Generating Executive Visualization Suite ---")
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid", context="paper", font_scale=1.2)
fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(2, 2)

# Prepare Data Arrays (use flattened)
y_true = actuals_flat.numpy()
y_pred = p50_flat.numpy()
residuals = y_pred - y_true

# PLOT 1: Truth Scatter (Top Left)
ax1 = fig.add_subplot(gs[0, 0])
# Downsample for scatter presentation if huge
if len(y_true) > 5000:
    idx_sample = np.random.choice(len(y_true), 5000, replace=False)
    p_true, p_pred = y_true[idx_sample], y_pred[idx_sample]
else:
    p_true, p_pred = y_true, y_pred

sns.scatterplot(x=p_true, y=p_pred, alpha=0.1, color="#2c3e50", edgecolor=None, ax=ax1)
# Identity line
max_val = max(p_true.max(), p_pred.max())
ax1.plot([0, max_val], [0, max_val], color="#e74c3c", linestyle="--", linewidth=2, label="Perfect Accuracy")
ax1.set_title("Observed vs. Predicted Headways", fontweight="bold")
ax1.set_xlabel("Observed Headway (min)")
ax1.set_ylabel("Predicted Headway (min)")
ax1.legend()

# PLOT 2: reliability Histogram (Top Right)
ax2 = fig.add_subplot(gs[0, 1])
sns.histplot(residuals, bins=50, kde=True, color="#3498db", ax=ax2)
ax2.axvline(x=0, color='black', linestyle='--')
# Annotate 90% bounds
r_p05 = np.percentile(residuals, 5)
r_p95 = np.percentile(residuals, 95)
ax2.axvline(x=r_p05, color="#e74c3c", linestyle=":", alpha=0.5)
ax2.axvline(x=r_p95, color="#e74c3c", linestyle=":", alpha=0.5)
ax2.set_title(f"Error Distribution (90% within [{r_p05:.1f}, {r_p95:.1f}] min)", fontweight="bold")
ax2.set_xlabel("Prediction Error (min)")
ax2.set_xlim(-5, 5) # Focus on the core

# PLOT 3: Operator View / Timeline (Bottom Full Width)
ax3 = fig.add_subplot(gs[1, :])

if not res_df.empty:
    # Find the largest group
    biggest_group = grouped_stats["Samples"].idxmax()
    subset_mask = (res_df["group_name"] == biggest_group).values
    
    # Extract sequence (assuming dataset is time-ordered per group)
    y_true_sub = actuals_flat[subset_mask].numpy()
    y_pred_sub = p50_flat[subset_mask].numpy()
    y_p10_sub = p10_flat[subset_mask].numpy()
    y_p90_sub = p90_flat[subset_mask].numpy()
    
    # Take a slice of busy time (e.g. 100 trains in the middle)
    start_idx = max(0, len(y_true_sub) // 2 - 50)
    end_idx = min(len(y_true_sub), start_idx + 100)
    
    if start_idx < end_idx:
        x_seq = range(start_idx, end_idx)
        
        ax3.plot(x_seq, y_true_sub[start_idx:end_idx], label="Observed", color="black", linewidth=2, marker='o', markersize=4)
        ax3.plot(x_seq, y_pred_sub[start_idx:end_idx], label="AI Forecast", color="#2980b9", linewidth=2, linestyle="--")
        ax3.fill_between(x_seq, y_p10_sub[start_idx:end_idx], y_p90_sub[start_idx:end_idx], color="#3498db", alpha=0.2, label="Confidence Interval")
        
        ax3.set_title(f"Live Tracking Concept: '{biggest_group}' (Sequence of 100 Arrivals)", fontweight="bold")
        ax3.set_xlabel("Sequential Train Arrivals")
        ax3.set_ylabel("Headway (min)")
        ax3.legend(loc="upper right")
        ax3.grid(True, alpha=0.3)
    else:
        ax3.text(0.5, 0.5, "Insufficient data for timeline", ha='center')
else:
    ax3.text(0.5, 0.5, "No Group Data Available for Timeline", ha='center')

plt.tight_layout()
plt.show()