In [None]:
import os
import sys
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence

# GNN / PyG
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Torch utils
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Project imports (config + logger)
# '../src' is resolved against the current working directory, so the kernel
# must be started from the notebook's own folder for the two imports below
# to be found.
sys.path.append(os.path.abspath('../src'))
import config
from utils import setup_logger

# Apply seaborn's 'whitegrid' style to every figure drawn later in the notebook.
sns.set_style('whitegrid')

# GNN Model V3 — Temporal + Spatial Learning

This notebook implements a stronger iteration of our GNN by adding **temporal memory** on top of the spatial stop-graph model.

## What changes vs Model V2
- **Model V2**: Spatial stop embeddings (GAT) + per-row context (lag + time + sequence index).
- **Model V3**: Spatial stop embeddings (GAT) + **windowed temporal encoder (GRU)** over each trip’s recent context history.

## Why this can improve MAE
Transit delay is dynamic: being late at the previous stop and how that lateness evolves across the last few stops are strong predictors of the next delay. A GRU can encode this short-term evolution more directly than a single lag value.

All hyperparameters and paths are loaded from `src/config.py` (with `GNN_V3_*` keys).

In [None]:
# =============================================================================
# STEP 0: RUNTIME SETUP (CONFIG + LOGGING)
# =============================================================================

# Run logger: one file per notebook run (mode='w' is forwarded to setup_logger).
log_dir = getattr(config, 'LOG_DIR', 'log')
logger = setup_logger(
    name='gnn_model_v3',
    log_dir=log_dir,
    filename='gnn_model_v3_training.log',
    level=logging.INFO,
    mode='w',
)

# Reproducibility: seed NumPy and PyTorch from a single config value, then pick
# the compute device (CUDA generators are seeded too when a GPU is present).
seed = int(getattr(config, 'SEED', 42))
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

logger.info('--- STARTING GNN MODEL V3 TRAINING ---')
logger.info(f"Device: {device}")
if device.type == 'cuda':
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

# Output folder for figures; created up front so the final savefig cannot fail
# on a missing directory.
plots_dir = getattr(config, 'PLOTS_DIR', os.path.join(os.getcwd(), 'plots'))
os.makedirs(plots_dir, exist_ok=True)

## 1. Data Loading + Canonical Sorting

We load the cleaned dataset and enforce a strict sort by `(trip_id, timestamp)` so that temporal features (lags, deltas, rolling windows) are correct and reproducible.

In [None]:
# =============================================================================
# STEP 1: LOAD DATA
# =============================================================================

# Path of the cleaned CSV: explicit config key first, otherwise <DATA_DIR>/vehicle_positions_cleaned.csv
clean_path = getattr(
    config,
    'CLEANED_CSV_PATH',
    os.path.join(getattr(config, 'DATA_DIR', 'data'), 'vehicle_positions_cleaned.csv'),
)
logger.info(f"Loading data from: {clean_path}")
df = pd.read_csv(clean_path)

# Fail fast when the file lacks a column every later step relies on.
required_cols = ['timestamp', 'trip_id', 'delay_seconds', 'latitude', 'longitude']
missing_required = [col for col in required_cols if col not in df.columns]
if missing_required:
    raise ValueError(f"Missing required columns: {missing_required}")

# The stop identifier may be stored under either name; prefer 'last_stop_id'.
if 'last_stop_id' in df.columns:
    stop_col = 'last_stop_id'
else:
    stop_col = 'stop_id'
if stop_col not in df.columns:
    raise ValueError("No stop column found (expected 'last_stop_id' or 'stop_id')")

# Minimal cleanup, then the canonical (trip_id, dt) ordering that every
# per-trip shift/diff/rolling feature and the window builder depend on:
#   1. drop rows with a missing key, target or coordinate
#   2. parse timestamps (unparseable values become NaT and are dropped)
#   3. sort by trip then time and renumber rows 0..n-1
# NOTE(review): if `timestamp` is stored as a numeric epoch, pd.to_datetime
# without `unit=` interprets it as nanoseconds — confirm the CSV format.
df = (
    df.dropna(subset=['timestamp', 'trip_id', stop_col, 'delay_seconds', 'latitude', 'longitude'])
      .assign(dt=lambda frame: pd.to_datetime(frame['timestamp'], errors='coerce'))
      .dropna(subset=['dt'])
      .sort_values(['trip_id', 'dt'])
      .reset_index(drop=True)
)
logger.info(f"Rows after basic cleanup: {len(df):,}")

## 2. Feature Engineering (V3: Temporal-Signal Upgrade)

Model V3 adds features that help a temporal encoder learn *how delay evolves* within a trip:

- **Lag**: previous stop delay (`prev_stop_delay`).
- **Rolling lag**: short-window mean of lag (`rolling_prev_delay`) to stabilize noise.
- **Time delta**: seconds since the previous observation in the trip (`time_delta_sec`).
- **Progress**: normalized trip progress $\in [0,1]$ so the model can compare different trip lengths.
- **Cyclical time**: sine/cosine for hour and day-of-week.

We then scale static node features and dynamic context features for stable training.

In [None]:
# =============================================================================
# STEP 2: FEATURE ENGINEERING (V3)
# =============================================================================

logger.info('--- Generating V3 Features ---')

# Tunables, all overridable from config through GNN_V3_* keys.
lag_clip_min = int(getattr(config, 'GNN_V3_LAG_CLIP_MIN', -1800))
lag_clip_max = int(getattr(config, 'GNN_V3_LAG_CLIP_MAX', 3600))
time_delta_clip = int(getattr(config, 'GNN_V3_TIME_DELTA_CLIP_SEC', 900))
rolling_w = int(getattr(config, 'GNN_V3_ROLLING_LAG_WINDOW', 3))

# --- Lag: delay on the previous row of the same trip (0.0 on a trip's first row).
df['prev_stop_delay'] = df.groupby('trip_id')['delay_seconds'].shift(1).fillna(0.0)

# --- Rolling mean of the lag. It averages already-shifted values, so it never
# includes the current row's delay. Dropping the trip_id index level restores
# the original row index so the assignment aligns row by row.
df['rolling_prev_delay'] = (
    df.groupby('trip_id')['prev_stop_delay']
    .rolling(window=rolling_w, min_periods=1)
    .mean()
    .reset_index(level=0, drop=True)
)

# --- Seconds since the previous observation of the trip, limited to [0, time_delta_clip].
df['time_delta_sec'] = (
    df.groupby('trip_id')['dt']
    .diff()
    .dt.total_seconds()
    .fillna(0.0)
    .clip(0, time_delta_clip)
)

# --- Normalised position in the trip: 0 on the first row, 1 on the last.
# Single-row trips take the 0.0 branch instead of dividing by zero.
stop_sequence = df.groupby('trip_id').cumcount().astype(np.float32)
trip_len = df.groupby('trip_id')[stop_col].transform('size').astype(np.float32)
df['progress'] = np.where(trip_len > 1, stop_sequence / (trip_len - 1.0), 0.0)

# --- Cyclical encodings so hour 23 sits next to hour 0 and Sunday next to Monday.
df['hour'] = df['dt'].dt.hour
df['day_of_week'] = df['dt'].dt.dayofweek
for _prefix, _source_col, _period in (('hour', 'hour', 24), ('day', 'day_of_week', 7)):
    _angle = 2 * np.pi * df[_source_col] / _period
    df[f'{_prefix}_sin'] = np.sin(_angle)
    df[f'{_prefix}_cos'] = np.cos(_angle)

# --- Per-stop historical mean delay (static node signal).
# NOTE(review): this mean is taken over every row of df, including rows that
# Step 4 later assigns to validation/test, and the scalers below are fit on the
# full frame as well. Both let held-out statistics reach the training features;
# consider fitting them on training trips only.
stop_history = df.groupby(stop_col)['delay_seconds'].mean()
global_mean = float(df['delay_seconds'].mean())
df['history_mean'] = df[stop_col].map(stop_history).fillna(global_mean)

# -------------------------
# Scaling
# -------------------------
logger.info('Scaling features...')

# Static node features: coordinates + historical mean, standardised jointly.
scaler_nodes = StandardScaler()
df[['lat_scaled', 'lon_scaled', 'hist_scaled']] = scaler_nodes.fit_transform(
    df[['latitude', 'longitude', 'history_mean']]
)

# Dynamic context features: clip the two lag signals first so extreme delays
# do not dominate the standardisation.
df['prev_delay_clipped'] = df['prev_stop_delay'].clip(lag_clip_min, lag_clip_max)
df['rolling_prev_delay_clipped'] = df['rolling_prev_delay'].clip(lag_clip_min, lag_clip_max)

# One independent scaler per context feature (kept under separate names so each
# can be reused on its own later).
scaler_lag = StandardScaler()
scaler_roll = StandardScaler()
scaler_delta = StandardScaler()
scaler_prog = StandardScaler()
for _scaler, _raw_col, _scaled_col in (
    (scaler_lag, 'prev_delay_clipped', 'prev_delay_scaled'),
    (scaler_roll, 'rolling_prev_delay_clipped', 'rolling_prev_delay_scaled'),
    (scaler_delta, 'time_delta_sec', 'time_delta_scaled'),
    (scaler_prog, 'progress', 'progress_scaled'),
):
    df[_scaled_col] = _scaler.fit_transform(df[[_raw_col]])

logger.info('Feature engineering complete.')

## 3. Construct the Static Stop Graph

We build a directed stop graph once and keep it resident on the compute device (the GPU when one is available, otherwise the CPU). Each training example will reference a stop index (node) plus a **temporal context window**.

- **Nodes**: unique stops, with static features `[lat_scaled, lon_scaled, hist_scaled]`.
- **Edges**: inferred from sequential stop transitions within trips (A → B if B follows A).

In [None]:
# =============================================================================
# STEP 3: GRAPH CONSTRUCTION
# =============================================================================

logger.info('--- Constructing Transit Graph ---')

# Map every stop identifier (as a string) to a dense integer node id 0..K-1.
stop_encoder = LabelEncoder()
df['stop_idx'] = stop_encoder.fit_transform(df[stop_col].astype(str))

# Node feature matrix: average the scaled static features over all rows that
# reference a stop. groupby sorts by stop_idx, so row i describes node i.
node_cols = ['lat_scaled', 'lon_scaled', 'hist_scaled']
node_features_df = df.groupby('stop_idx')[node_cols].mean()
x = torch.tensor(node_features_df.to_numpy(), dtype=torch.float32)

# Directed edges A -> B whenever B is the next observation after A within the
# same trip. The last row of each trip has no successor (NaN) and is dropped.
df_sorted = df.sort_values(by=['trip_id', 'dt']).assign(
    next_stop_idx=lambda frame: frame.groupby('trip_id')['stop_idx'].shift(-1)
)
edges_df = df_sorted[df_sorted['next_stop_idx'].notna()]
unique_edges = edges_df.loc[:, ['stop_idx', 'next_stop_idx']].drop_duplicates()
# Transposed to the (2, num_edges) layout; the float successor ids produced by
# the shift are cast to int64 here.
edge_index = torch.tensor(unique_edges.to_numpy().T, dtype=torch.long)

graph_data = Data(x=x, edge_index=edge_index).to(device)
logger.info(f"Graph ready | nodes={graph_data.num_nodes:,} edges={graph_data.num_edges:,}")

## 4. Build a Windowed Temporal Dataset (Per-Trip Sequences)

The key upgrade in Model V3 is that we do not feed only a single “previous delay” scalar.
Instead, we create a short window of recent context for each row (within the same `trip_id`) and encode it with a GRU.

**How the windows work**
- For each row at position $t$ in a trip, we take the last $k$ timesteps of context features up to and including $t$; when the trip has fewer than $k$ observations so far, the remainder of the window is zero-padded.
- We also store the true window length (for packing into the GRU).
- The static graph stays unchanged; each example also carries a `stop_idx` to pull the correct node embedding.

In [None]:
# =============================================================================
# STEP 4: WINDOWED SEQUENCE DATASET (V3)
# =============================================================================

logger.info('--- Preparing lazy temporal windows for GRU ---')

seq_len = int(getattr(config, 'GNN_V3_SEQ_LEN', 12))
split_by_trip = bool(getattr(config, 'GNN_V3_SPLIT_BY_TRIP', True))

# Per-timestep context features fed to the GRU, in this column order.
context_seq_cols = [
    'hour_sin', 'hour_cos', 'day_sin', 'day_cos',
    'prev_delay_scaled', 'rolling_prev_delay_scaled',
    'time_delta_scaled', 'progress_scaled',
]
missing_ctx = [col for col in context_seq_cols if col not in df.columns]
if missing_ctx:
    raise ValueError(f"Missing context columns for V3: {missing_ctx}")

# IMPORTANT: with ~15M rows, materialising an (n, seq_len, feat) tensor would
# allocate multiple GB and can freeze VS Code. Only flat base arrays are kept;
# windows are sliced out of them per batch by the collate function.
y_base = df['delay_seconds'].to_numpy(dtype=np.float32, copy=False).reshape(-1, 1)
stop_idx_base = df['stop_idx'].to_numpy(dtype=np.int64, copy=False)
X_ctx_base = df[context_seq_cols].to_numpy(dtype=np.float32, copy=False)

n = len(df)
num_feat = len(context_seq_cols)
logger.info(f"Base arrays | rows={n:,} seq_len={seq_len} features={num_feat}")

# Trip segmentation. Rows are sorted by (trip_id, dt), so every trip is one
# contiguous run; `change` holds the row positions where a new trip begins.
trip_codes = pd.factorize(df['trip_id'].astype(str), sort=False)[0]
change = np.flatnonzero(np.diff(trip_codes) != 0) + 1
starts = np.concatenate(([0], change))
ends = np.concatenate((change, [n]))

# For every row: the first row index of its trip and its 0-based offset inside
# that trip. Repeating each start over the length of its run is equivalent to
# filling the two arrays one trip at a time.
trip_start = np.repeat(starts, ends - starts).astype(np.int64, copy=False)
pos_in_trip = np.arange(n, dtype=np.int64) - trip_start

# -------------------------
# Train / Val / Test split
# -------------------------
random_state = int(getattr(config, 'RANDOM_STATE', getattr(config, 'SEED', 42)))
test_size = float(getattr(config, 'TEST_SIZE', 0.2))
val_size = float(getattr(config, 'VAL_SIZE', 0.1))  # fraction of total rows
indices = np.arange(n)

# VAL_SIZE is expressed against the whole dataset, but the second split only
# sees the train+val pool, so rescale it (guarding test_size == 1) and keep the
# result within [0, 0.5].
val_rel = val_size / max(1e-12, (1.0 - test_size))
val_rel = min(max(val_rel, 0.0), 0.5)

if split_by_trip:
    # Split on trip ids so all rows of a trip land in the same partition.
    trip_id_str = df['trip_id'].astype(str)
    trips = trip_id_str.unique()
    trainval_trips, test_trips = train_test_split(trips, test_size=test_size, random_state=random_state)
    train_trips, val_trips = train_test_split(trainval_trips, test_size=val_rel, random_state=random_state)

    train_mask = trip_id_str.isin(train_trips).to_numpy()
    val_mask = trip_id_str.isin(val_trips).to_numpy()
    test_mask = trip_id_str.isin(test_trips).to_numpy()

    train_idx, val_idx, test_idx = indices[train_mask], indices[val_mask], indices[test_mask]
else:
    # Row-level split: rows of one trip may be spread across partitions.
    trainval_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
    train_idx, val_idx = train_test_split(trainval_idx, test_size=val_rel, random_state=random_state)

logger.info(f"Split sizes | train={len(train_idx):,} val={len(val_idx):,} test={len(test_idx):,} | split_by_trip={split_by_trip}")

class RowIndexDataset(torch.utils.data.Dataset):
    """Dataset whose items are plain row numbers.

    It stores only an array of row indices; turning an index into model inputs
    is left to the DataLoader's collate function.
    """

    def __init__(self, row_indices: np.ndarray):
        # Array of row positions belonging to this partition.
        self.row_indices = row_indices

    def __len__(self):
        """Number of rows in this partition."""
        n_rows = self.row_indices.shape[0]
        return int(n_rows)

    def __getitem__(self, i: int):
        """Return the i-th stored row index as a built-in int."""
        row = self.row_indices[i]
        return int(row)

def make_window_batch(row_indices: list[int]):
    """Collate a list of row indices into padded context windows for the GRU.

    For each row ``r`` the window holds the context features of the last
    ``min(seq_len, pos_in_trip[r] + 1)`` rows of the same trip, ending at ``r``
    (rows of a trip are contiguous in ``X_ctx_base``).

    Windows are stored oldest-to-newest starting at time index 0 and are
    zero-padded at the END. This layout is required by
    ``pack_padded_sequence(..., lengths)``, which consumes the first
    ``lengths[i]`` timesteps of every sequence: padding at the front would make
    the GRU read zeros and silently discard the most recent observations of
    every window shorter than ``seq_len``.

    Args:
        row_indices: row positions into the module-level base arrays.

    Returns:
        Tuple of tensors ``(seq, lengths, stop_idx, y)`` with shapes
        ``(B, seq_len, num_feat)`` float32, ``(B,)`` int64, ``(B,)`` int64 and
        ``(B, 1)`` float32.
    """
    # Builds (B, seq_len, num_feat) lazily for this batch only.
    batch_size = len(row_indices)
    seq = np.zeros((batch_size, seq_len, num_feat), dtype=np.float32)
    lengths = np.empty((batch_size,), dtype=np.int64)
    stop_idx = np.empty((batch_size,), dtype=np.int64)
    y = np.empty((batch_size, 1), dtype=np.float32)

    for b, r in enumerate(row_indices):
        # Number of real timesteps available: never reach back past the first
        # row of the trip, never exceed the window size.
        L = min(seq_len, int(pos_in_trip[r]) + 1)
        start = r - L + 1
        # Valid steps go to the FRONT of the time axis; the tail stays zero.
        seq[b, :L, :] = X_ctx_base[start : r + 1]
        lengths[b] = L
        stop_idx[b] = stop_idx_base[r]
        y[b, 0] = y_base[r, 0]

    return (
        torch.from_numpy(seq),
        torch.from_numpy(lengths),
        torch.from_numpy(stop_idx),
        torch.from_numpy(y),
    )

# One index-only dataset per partition; the actual tensors are assembled per
# batch by make_window_batch when it is used as the DataLoader collate_fn.
train_dataset = RowIndexDataset(train_idx)
val_dataset = RowIndexDataset(val_idx)
test_dataset = RowIndexDataset(test_idx)

## 5. Define the V3 Architecture (GAT + GRU Fusion)

Model V3 has two branches:

1) **Spatial branch (GAT)**: learns a stop embedding from the static stop graph.
2) **Temporal branch (GRU)**: learns a compact state from the last $k$ steps of context.

We then **fuse** these representations and regress the delay in seconds.

In [None]:
# =============================================================================
# STEP 5: MODEL DEFINITION (V3)
# =============================================================================

logger.info('--- Defining Model V3 (GAT + GRU) ---')

# GRU input width: one value per context column of a timestep.
num_seq_features = len(context_seq_cols)
# Capacity / regularisation knobs, overridable from config.
gru_hidden = int(getattr(config, 'GNN_V3_GRU_HIDDEN', 64))
dropout = float(getattr(config, 'GNN_V3_DROPOUT', 0.2))

class ContextAwareGAT_GRU_V3(nn.Module):
    """Two-branch delay regressor: GAT stop embedding fused with a GRU history state.

    * Spatial branch: two GATConv layers over the stop graph produce a
      128-dimensional embedding per node; the rows for the batch's stops are
      gathered from it.
    * Temporal branch: a single-layer GRU encodes each example's context
      window; its final hidden state is used.
    * Head: the two embeddings are concatenated and passed through an MLP
      (256 -> 128 -> 64 -> 1) with ELU, BatchNorm on the first two layers and
      dropout after each hidden layer.
    """

    def __init__(self, num_node_features: int, num_seq_features: int, gru_hidden: int, dropout: float):
        """
        Args:
            num_node_features: width of the node feature matrix.
            num_seq_features: number of context features per timestep.
            gru_hidden: GRU hidden size (width of the temporal embedding).
            dropout: dropout probability used in the GAT layers and the MLP head.
        """
        super().__init__()
        gat_hidden, gat_heads, stop_dim = 64, 4, 128

        # Spatial encoder (similar capacity to V2): 4 concatenated heads, then
        # a single averaged head.
        self.gat1 = GATConv(num_node_features, gat_hidden, heads=gat_heads, dropout=dropout)
        self.gat2 = GATConv(gat_hidden * gat_heads, stop_dim, heads=1, concat=False, dropout=dropout)

        # Temporal encoder
        self.gru = nn.GRU(
            input_size=num_seq_features,
            hidden_size=gru_hidden,
            num_layers=1,
            batch_first=True,
            dropout=0.0,
        )

        # Regression head over [stop embedding | temporal embedding]
        self.fc1 = nn.Linear(stop_dim + gru_hidden, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 64)
        self.fc_out = nn.Linear(64, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_nodes, edge_index, seq_batch, lengths, stop_indices):
        """Predict one value per example.

        Args:
            x_nodes: node feature matrix of the stop graph.
            edge_index: graph connectivity passed to both GAT layers.
            seq_batch: batch-first context windows, one per example.
            lengths: number of valid timesteps per window. pack_padded_sequence
                reads the FIRST ``lengths[i]`` steps of each row, so the valid
                steps must sit at the start of the time axis.
            stop_indices: node index of each example's stop.

        Returns:
            Tensor with one output column per example (``fc_out`` has width 1).
        """
        # Spatial branch: embed every node, then pick the rows for this batch.
        node_emb = F.elu(self.gat1(x_nodes, edge_index))
        node_emb = F.elu(self.gat2(node_emb, edge_index))
        stop_emb = node_emb[stop_indices]  # [B, 128]

        # Temporal branch: packing needs the lengths on the CPU; with
        # enforce_sorted=False the batch does not have to be length-sorted.
        packed = pack_padded_sequence(
            seq_batch,
            lengths=lengths.detach().cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, last_hidden = self.gru(packed)
        temp_emb = last_hidden[-1]  # [B, gru_hidden]

        # Fusion + MLP head
        hidden = torch.cat([stop_emb, temp_emb], dim=1)
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = self.dropout(F.elu(norm(linear(hidden))))
        hidden = self.dropout(F.elu(self.fc3(hidden)))
        return self.fc_out(hidden)

# Build the network from the graph's node-feature width and the Step 5
# hyperparameters, then move it to the compute device.
model = ContextAwareGAT_GRU_V3(
    num_node_features=graph_data.x.size(1),
    num_seq_features=num_seq_features,
    gru_hidden=gru_hidden,
    dropout=dropout,
)
model = model.to(device)

# Count only the parameters the optimizer will actually update.
trainable_params = sum(
    param.numel() for param in filter(lambda param: param.requires_grad, model.parameters())
)
logger.info(f"Model V3 initialized | params(trainable)={trainable_params:,} | seq_len={seq_len} | seq_features={num_seq_features}")

## 6. Training Loop (Val-Aware Scheduler)

We use the same training philosophy as Model V2:
- Stream batches of lightweight tensors with a `DataLoader`.
- Keep the graph resident on the compute device (GPU when available).
- Optimize MAE (L1 loss).
- Step `ReduceLROnPlateau` on **validation MAE** (not training loss).

In [None]:
# =============================================================================
# STEP 6: TRAINING (V3)
# =============================================================================

logger.info('--- Preparing DataLoaders + Training ---')

# Optimisation hyperparameters
batch_size = int(getattr(config, 'GNN_V3_BATCH_SIZE', 4096))
num_epochs = int(getattr(config, 'GNN_V3_NUM_EPOCHS', 50))
learning_rate = float(getattr(config, 'GNN_V3_LR', 0.003))
weight_decay = float(getattr(config, 'GNN_V3_WEIGHT_DECAY', 0.0))

# DataLoader settings
num_workers = int(getattr(config, 'NUM_WORKERS', 4))
pin_memory = bool(getattr(config, 'PIN_MEMORY', True))

# ReduceLROnPlateau settings
sched_factor = float(getattr(config, 'GNN_V3_SCHED_FACTOR', 0.5))
sched_patience = int(getattr(config, 'GNN_V3_SCHED_PATIENCE', 3))
sched_threshold = float(getattr(config, 'GNN_V3_SCHED_THRESHOLD', 1e-4))
sched_min_lr = float(getattr(config, 'GNN_V3_SCHED_MIN_LR', 1e-6))
sched_cooldown = int(getattr(config, 'GNN_V3_SCHED_COOLDOWN', 0))

# Reference MAE that each epoch's validation MAE is compared against in the log.
champion_mae = float(getattr(config, 'CHAMPION_MAE', 43.18))

logger.info(
    "Hyperparams | "
    f"epochs={num_epochs} batch_size={batch_size} lr={learning_rate} wd={weight_decay} "
    f"seq_len={seq_len} gru_hidden={gru_hidden} dropout={dropout} "
    f"plateau(factor={sched_factor}, patience={sched_patience}, threshold={sched_threshold}, min_lr={sched_min_lr}, cooldown={sched_cooldown})"
)

# Both loaders share every setting except shuffling; windows are built lazily
# by make_window_batch (Step 4) acting as the collate function.
# NOTE(review): with num_workers > 0 the collate function runs in worker
# processes; on platforms that start workers with "spawn", a notebook-defined
# collate_fn may not be picklable — confirm on the target OS.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=False,
    collate_fn=make_window_batch,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# mode='min': the monitored quantity (validation MAE) should decrease.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=sched_factor,
    patience=sched_patience,
    threshold=sched_threshold,
    cooldown=sched_cooldown,
    min_lr=sched_min_lr,
)
# L1 loss == mean absolute error, the metric being optimised.
criterion = nn.L1Loss()

@torch.no_grad()
def evaluate_mae(loader) -> float:
    """Return the mean absolute error of ``model`` over every sample in ``loader``.

    The error is accumulated as a sum of absolute errors divided by the number
    of target values, i.e. a true per-sample mean. Averaging per-batch means
    instead would give a short final batch the same weight as a full one and
    bias the reported metric.

    Args:
        loader: iterable yielding ``(seq_batch, lengths, stop_idx, y_batch)``
            tuples as produced by the notebook's DataLoaders.

    Returns:
        The MAE as a Python float; ``0.0`` when the loader yields nothing.

    Side effects:
        Switches the module-level ``model`` to eval mode (it is not switched back).
    """
    model.eval()
    abs_err_sum = 0.0
    n_samples = 0
    for seq_batch, lengths, stop_idx, y_batch in loader:
        seq_batch = seq_batch.to(device, non_blocking=True)
        lengths = lengths.to(device, non_blocking=True)
        stop_idx = stop_idx.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)
        out = model(graph_data.x, graph_data.edge_index, seq_batch, lengths, stop_idx)
        # Sum (not mean) per batch so every sample carries equal weight overall.
        abs_err_sum += float(torch.abs(out - y_batch).sum().item())
        n_samples += int(y_batch.numel())
    return abs_err_sum / max(1, n_samples)

logger.info('Starting training (scheduler stepped on VAL MAE)...')
best_val = float('inf')
best_state = None  # snapshot of the weights that produced best_val
history = {'train_mae': [], 'val_mae': [], 'lr': []}

for epoch in range(num_epochs):
    model.train()
    total_train = 0.0      # running sum of absolute errors over the epoch
    n_train_samples = 0    # number of samples that contributed to total_train
    for seq_batch, lengths, stop_idx, y_batch in train_loader:
        batch_n = int(y_batch.shape[0])
        # BatchNorm1d cannot compute batch statistics from a single sample in
        # train mode and raises; with drop_last=False the final batch can have
        # size 1, so skip it instead of crashing the epoch.
        if batch_n < 2:
            continue

        seq_batch = seq_batch.to(device, non_blocking=True)
        lengths = lengths.to(device, non_blocking=True)
        stop_idx = stop_idx.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        out = model(graph_data.x, graph_data.edge_index, seq_batch, lengths, stop_idx)
        loss = criterion(out, y_batch)
        loss.backward()
        optimizer.step()

        # criterion is a batch mean; weight it by the batch size so the epoch
        # figure is a per-sample MAE even when the last batch is short.
        total_train += float(loss.item()) * batch_n
        n_train_samples += batch_n

    train_mae = total_train / max(1, n_train_samples)
    val_mae = evaluate_mae(val_loader)

    # Step the plateau scheduler on validation MAE and note whether it cut the LR.
    lr_before = float(optimizer.param_groups[0]['lr'])
    scheduler.step(val_mae)
    lr_after = float(optimizer.param_groups[0]['lr'])
    lr_note = 'LR↓' if lr_after < lr_before else 'LR='
    improved = val_mae < (best_val - sched_threshold)
    if improved:
        best_val = val_mae
        # Keep a detached copy of the best weights so they can be restored
        # after the last epoch.
        best_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

    # Scheduler internals are read defensively: attribute names may vary by version.
    sched_best = getattr(scheduler, 'best', None)
    bad_epochs = getattr(scheduler, 'num_bad_epochs', None)
    cooldown = getattr(scheduler, 'cooldown_counter', None)
    delta_to_champion = val_mae - champion_mae

    history['train_mae'].append(train_mae)
    history['val_mae'].append(val_mae)
    history['lr'].append(lr_after)

    logger.info(
        f"Epoch {epoch+1:03d}/{num_epochs} | "
        f"train_MAE={train_mae:.2f}s | val_MAE={val_mae:.2f}s | best_val={best_val:.2f}s | "
        f"Δ_vs_{champion_mae:.2f}s={delta_to_champion:+.2f}s | "
        f"{lr_note} {lr_after:.6g} | bad_epochs={bad_epochs} | sched_best={sched_best} | cooldown={cooldown}"
    )

# Leave `model` holding the weights with the best validation MAE rather than
# whatever the final epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)
    logger.info(f"Restored best validation weights (val_MAE={best_val:.2f}s)")

logger.info('V3 training complete.')

## 7. Evaluation + Diagnostics Plot

We evaluate on the held-out test set and save a diagnostic figure in the configured plots directory (`plots/` by default) as:
- `gnn_model_v3_diagnostics.png`

In [None]:
# =============================================================================
# STEP 7: EVALUATION (V3)
# =============================================================================

logger.info('--- Starting Model V3 Evaluation on Test Set ---')

# Evaluation batch size falls back to the training batch size, then to 4096.
eval_batch_size = int(getattr(config, 'GNN_V3_EVAL_BATCH_SIZE', getattr(config, 'GNN_V3_BATCH_SIZE', 4096)))
test_loader = DataLoader(
    test_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=make_window_batch,
)

model.eval()
all_preds = []
all_targets = []

logger.info('Running inference...')
with torch.no_grad():
    for seq_batch, lengths, stop_idx, y_batch in test_loader:
        # Only the model inputs go to the device; targets stay on the CPU
        # because they are only collected for the metrics below.
        seq_batch, lengths, stop_idx = (
            tensor.to(device, non_blocking=True) for tensor in (seq_batch, lengths, stop_idx)
        )
        out = model(graph_data.x, graph_data.edge_index, seq_batch, lengths, stop_idx)
        all_preds.append(out.detach().cpu().numpy())
        all_targets.append(y_batch.detach().cpu().numpy())

# Stack the per-batch (B, 1) arrays and flatten to 1-D vectors.
y_pred = np.concatenate(all_preds, axis=0).reshape(-1)
y_true = np.concatenate(all_targets, axis=0).reshape(-1)

mae = mean_absolute_error(y_true, y_pred)
rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
r2 = float(r2_score(y_true, y_pred))

logger.info('--- FINAL GNN MODEL V3 RESULTS ---')
logger.info(f"MAE (Mean Absolute Error): {mae:.2f} seconds")
logger.info(f"RMSE (Root Mean Sq Error): {rmse:.2f} seconds")
logger.info(f"R² Score:                  {r2:.4f}")

# Positive difference = this model's test MAE is lower than the reference MAE.
champion_mae = float(getattr(config, 'CHAMPION_MAE', 43.18))
logger.info(f"VS Context RF ({champion_mae:.2f}s):  {champion_mae - mae:+.2f}s difference")

# Diagnostics: three panels drawn from a random subsample of at most 5000 test points.
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
plot_n = min(5000, len(y_pred))
plot_indices = np.random.choice(len(y_pred), size=plot_n, replace=False)
y_pred_sub = y_pred[plot_indices]
y_true_sub = y_true[plot_indices]
residuals = y_true_sub - y_pred_sub

# Panel 1: predicted vs actual, with the identity line spanning the full target range.
ax_fit = axes[0]
ax_fit.scatter(y_true_sub, y_pred_sub, alpha=0.3, s=10, color='blue')
target_range = [y_true.min(), y_true.max()]
ax_fit.plot(target_range, target_range, 'r--', lw=2, label='Perfect Fit')
ax_fit.set(
    title=f"GNN Model V3: Predicted vs Actual\nMAE: {mae:.2f}s | R²: {r2:.2f}",
    xlabel='Actual Delay (s)',
    ylabel='Predicted Delay (s)',
)
ax_fit.legend()
ax_fit.grid(True, alpha=0.3)

# Panel 2: residuals (actual - predicted) against the prediction.
ax_resid = axes[1]
ax_resid.scatter(y_pred_sub, residuals, alpha=0.3, s=10, color='purple')
ax_resid.axhline(0, color='red', linestyle='--', lw=2)
ax_resid.set(
    title='Residuals vs Predictions',
    xlabel='Predicted Delay (s)',
    ylabel='Error (s)',
)
ax_resid.grid(True, alpha=0.3)

# Panel 3: distribution of the residuals.
ax_hist = axes[2]
sns.histplot(residuals, bins=50, kde=True, ax=ax_hist, color='green')
ax_hist.axvline(0, color='red', linestyle='--', lw=2)
ax_hist.set(title='Error Distribution', xlabel='Error (s)')

# Save to disk and close the figure so it does not linger in memory.
fig.tight_layout()
plot_path = os.path.join(plots_dir, 'gnn_model_v3_diagnostics.png')
fig.savefig(plot_path, dpi=150)
plt.close(fig)
logger.info(f"Saved diagnostic plots to: {plot_path}")