# GNN Model V2 (Context-aware GAT)

## STEP 0: Environment setup
Initialize libraries, load config, and set up logging for a reproducible V2 training run.

In [None]:
import os
import sys
import logging

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
import torch.optim as optim

# GNN Imports
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

# Utils
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Add src/ directory to path for config + logger utils
sys.path.append(os.path.abspath('../src'))
import config
from utils import setup_logger

# Dedicated logger for this notebook run. `mode="w"` is forwarded to the
# project's setup_logger helper (presumably it truncates the previous log
# file -- confirm in utils.setup_logger).
logger = setup_logger(
    name="gnn-model-v2",
    log_dir=getattr(config, 'LOG_DIR', 'log'),
    filename="gnn_model_v2_training.log",
    level=logging.INFO,
    mode="w",
)

logger.info("--- STARTING GNN MODEL V2 TRAINING ---")
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Device: {device}")

# --- 1. Load Data ---
# Resolve the fallback path lazily: a getattr() default is evaluated eagerly,
# so building it from config.DATA_DIR inline would raise AttributeError for a
# config that defines CLEANED_CSV_PATH but not DATA_DIR.
clean_path = getattr(config, 'CLEANED_CSV_PATH', None)
if clean_path is None:
    clean_path = os.path.join(config.DATA_DIR, "vehicle_positions_cleaned.csv")
logger.info(f"Loading data from: {clean_path}")
df = pd.read_csv(clean_path)

# Sort strictly by (trip, time): the lag and sequence features computed later
# rely on rows of a trip being in chronological order. The index is reset so
# that positional and label-based row indices coincide afterwards.
df['dt'] = pd.to_datetime(df['timestamp'])
df = df.sort_values(['trip_id', 'dt']).reset_index(drop=True)
logger.info(f"Loaded rows: {len(df):,}")

## STEP 1: Feature engineering (context features)
Add real-time context features so the GNN can learn more than static stop popularity.

- `prev_stop_delay`: previous-stop delay within each `trip_id`
- Cyclical time: `hour_sin/cos`, `day_sin/cos`
- Trajectory context: `stop_sequence` within each trip
- Stop history: `history_mean` (target encoding)
- Scaling: separate scalers for node features vs. dynamic context

In [None]:
# =============================================================================
# STEP 1: FEATURE ENGINEERING (V2 UPGRADES)
# =============================================================================
logger.info("--- Generating Advanced Features ---")

# 1. Real-Time Lag (Previous Stop Delay)
# Shift delay by 1 within each trip; the first row of a trip has no
# predecessor and gets 0.
df['prev_stop_delay'] = df.groupby('trip_id')['delay_seconds'].shift(1).fillna(0)

# 2. Cyclical Time Embeddings
# sin/cos pairs keep hour 23 adjacent to hour 0 and Sunday adjacent to Monday.
df['hour'] = df['dt'].dt.hour
df['day_of_week'] = df['dt'].dt.dayofweek

df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
df['day_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
df['day_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)

# 3. [NEW] Stop Sequence (Trajectory Context)
# Counts 0, 1, 2... for every row in the trip.
# Helps model know if it's the start or end of a route.
logger.info("Generating 'Stop Sequence' feature...")
df['stop_sequence'] = df.groupby('trip_id').cumcount()

# --- Rows that fitted statistics may look at ---
# The target encoding and the scalers below are *fitted* quantities. Fitting
# them on every row lets validation/test targets leak into the features, so
# they are fitted on training rows only. The row-level split performed in the
# training step is reproduced here from the same config keys, seed and
# formula; train_test_split is deterministic for a fixed random_state, so both
# places select the same rows.
fit_seed = getattr(config, 'RANDOM_STATE', getattr(config, 'SEED', 42))
fit_test_size = float(getattr(config, 'TEST_SIZE', 0.2))
fit_val_size = float(getattr(config, 'VAL_SIZE', 0.1))
fit_val_rel = fit_val_size / max(1e-12, (1.0 - fit_test_size))
fit_val_rel = min(max(fit_val_rel, 0.0), 0.5)

fit_trainval_rows, _ = train_test_split(
    np.arange(len(df)), test_size=fit_test_size, random_state=fit_seed
)
fit_train_rows, _ = train_test_split(
    fit_trainval_rows, test_size=fit_val_rel, random_state=fit_seed
)
fit_mask = np.zeros(len(df), dtype=bool)
fit_mask[fit_train_rows] = True

# 4. Target Encoding (History)
# Mean delay per stop, computed from training rows only; stops that never
# appear in the training rows fall back to the training-set global mean.
stop_col = 'last_stop_id' if 'last_stop_id' in df.columns else 'stop_id'
global_mean = df.loc[fit_mask, 'delay_seconds'].mean()
stop_history = df.loc[fit_mask].groupby(stop_col)['delay_seconds'].mean()
df['history_mean'] = df[stop_col].map(stop_history).fillna(global_mean)

# --- SCALING ---
logger.info("Scaling Features...")

# A. Node Features (Static)
node_raw_cols = ['latitude', 'longitude', 'history_mean']
scaler_nodes = StandardScaler().fit(df.loc[fit_mask, node_raw_cols])
df[['lat_scaled', 'lon_scaled', 'hist_scaled']] = scaler_nodes.transform(df[node_raw_cols])

# B. Context Features (Dynamic)
# Clip lag to -30min to +60min to prevent extreme outliers from killing the gradients
df['prev_delay_clipped'] = df['prev_stop_delay'].clip(-1800, 3600)
scaler_lag = StandardScaler().fit(df.loc[fit_mask, ['prev_delay_clipped']])
df['prev_delay_scaled'] = scaler_lag.transform(df[['prev_delay_clipped']]).ravel()

# Scale Sequence
scaler_seq = StandardScaler().fit(df.loc[fit_mask, ['stop_sequence']])
df['seq_scaled'] = scaler_seq.transform(df[['stop_sequence']]).ravel()

logger.info("Feature Engineering Complete.")

## STEP 2: Graph construction (static city map)
Build a directed stop graph once and keep it on the training device (GPU when available, otherwise CPU).

- Nodes: unique stops (`stop_idx`) with static features (`lat_scaled`, `lon_scaled`, `hist_scaled`)
- Edges: observed stop-to-stop transitions within each trip (A → B)
- Output: `torch_geometric.data.Data(x, edge_index)` used by every batch

In [None]:
# =============================================================================
# STEP 2: GRAPH CONSTRUCTION
# =============================================================================
logger.info("--- Constructing Transit Graph ---")

# Encode each stop identifier as an integer node index.
stop_encoder = LabelEncoder()
df['stop_idx'] = stop_encoder.fit_transform(df[stop_col])

# Node feature matrix: one row per stop, holding the per-stop mean of the
# scaled static columns. groupby orders the groups by stop_idx, so row i of
# the tensor belongs to node i.
node_cols = ['lat_scaled', 'lon_scaled', 'hist_scaled']
node_features_df = df.groupby('stop_idx')[node_cols].mean()
x = torch.tensor(node_features_df.to_numpy(), dtype=torch.float)

# Directed edges: for every trip, link each row's stop to the stop of the
# chronologically next row. The last row of a trip has no successor (NaN) and
# is filtered out; duplicate (source, target) pairs are collapsed.
df_sorted = df.sort_values(by=['trip_id', 'dt'])
df_sorted['next_stop_idx'] = df_sorted.groupby('trip_id')['stop_idx'].shift(-1)
has_successor = df_sorted['next_stop_idx'].notna()
edges_df = df_sorted[has_successor]
edge_cols = ['stop_idx', 'next_stop_idx']
unique_edges = edges_df[edge_cols].drop_duplicates()
# Transpose to the (2, num_edges) layout; the cast to long also converts the
# float column produced by the NaN-bearing shift back to integers.
edge_index = torch.tensor(unique_edges.to_numpy().T, dtype=torch.long)

# Move the static graph to the selected device once; every batch reuses it.
graph_data = Data(x=x, edge_index=edge_index)
graph_data = graph_data.to(device)
logger.info(f"Graph on GPU: {graph_data}")

## STEP 3: Model definition (Context-aware GAT)
Use attention-based message passing to produce a stop embedding, then fuse it with per-observation context.

- Spatial encoder: 2-layer GAT over the stop graph
- Fusion: concatenate stop embedding + context features
- Head: small MLP regressor to predict delay seconds

In [None]:
# =============================================================================
# STEP 3: GNN MODEL V2 ARCHITECTURE
# =============================================================================
logger.info("--- Defining Model V2 (Wider Capacity) ---")

class ContextAwareGAT_V2(torch.nn.Module):
    """Two-layer GAT stop encoder fused with per-observation context features.

    The graph layers embed every node of the stop graph. For each observation
    in the batch, the embedding of its stop is concatenated with its context
    vector and passed through an MLP head that outputs one value per row.

    Args:
        num_node_features: Width of the node feature matrix ``x``.
        num_context_features: Width of the per-observation context vector.
        hidden_channels: Per-head output width of the first GAT layer.
        heads: Number of attention heads in the first GAT layer (their
            outputs are concatenated, giving ``hidden_channels * heads``).
        graph_out_channels: Width of the final stop embedding.
        attn_dropout: ``dropout`` argument forwarded to both GATConv layers.

    The keyword-only defaults reproduce the original fixed architecture
    (64 x 4 heads -> 128-d embedding, dropout 0.2).
    """

    def __init__(
        self,
        num_node_features,
        num_context_features,
        *,
        hidden_channels=64,
        heads=4,
        graph_out_channels=128,
        attn_dropout=0.2,
    ):
        super().__init__()

        # 1. Graph layers. The first layer concatenates its heads, so the
        #    second layer consumes hidden_channels * heads inputs.
        self.gat1 = GATConv(num_node_features, hidden_channels, heads=heads, dropout=attn_dropout)
        self.gat2 = GATConv(
            hidden_channels * heads, graph_out_channels, heads=1, concat=False, dropout=attn_dropout
        )

        # 2. Fusion (stop embedding + context vector)
        fusion_dim = graph_out_channels + num_context_features

        # 3. Regression head. BatchNorm1d needs more than one sample per batch
        #    while the module is in training mode.
        self.fc1 = torch.nn.Linear(fusion_dim, 256)
        self.bn1 = torch.nn.BatchNorm1d(256)

        self.fc2 = torch.nn.Linear(256, 128)
        self.bn2 = torch.nn.BatchNorm1d(128)

        self.fc3 = torch.nn.Linear(128, 64)
        self.fc_out = torch.nn.Linear(64, 1)

    def forward(self, x, edge_index, context_features, stop_indices):
        """Predict one value per observation in the batch.

        Args:
            x: Node feature matrix for the whole graph.
            edge_index: Edge list of the graph, as consumed by GATConv.
            context_features: 2-D tensor with one context row per observation
                (it is concatenated with the stop embedding along dim 1).
            stop_indices: Node index of each observation's stop; used to
                gather rows from the node embedding matrix.

        Returns:
            Tensor with one row per observation and a single output column.
        """
        # Graph phase: embed every node; this runs over the full graph on
        # each call, independent of the batch contents.
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = self.gat2(x, edge_index)
        x = F.elu(x)

        # Fusion phase: pick each observation's stop embedding and append its
        # context vector.
        batch_embeddings = x[stop_indices]
        combined = torch.cat([batch_embeddings, context_features], dim=1)

        # Regression phase
        out = self.fc1(combined)
        out = self.bn1(out)
        out = F.elu(out)

        out = self.fc2(out)
        out = self.bn2(out)
        out = F.elu(out)

        out = self.fc3(out)
        out = F.elu(out)

        return self.fc_out(out)

# Define Context Features (6 Total now)
context_cols = ['hour_sin', 'hour_cos', 'day_sin', 'day_cos', 'prev_delay_scaled', 'seq_scaled']
num_context = len(context_cols)

# Seed PyTorch before any parameters are created so weight initialisation and
# the DataLoader shuffling order repeat between runs. Uses the same config
# keys as the data split. NOTE: some GPU kernels may still be nondeterministic.
torch.manual_seed(int(getattr(config, 'RANDOM_STATE', getattr(config, 'SEED', 42))))

# Take the node-feature width from the graph tensor instead of hard-coding it,
# so the model cannot drift out of sync with the node feature matrix.
model = ContextAwareGAT_V2(
    num_node_features=graph_data.x.size(1),
    num_context_features=num_context,
).to(device)
logger.info(f"Model V2 Initialized. Context Features: {num_context}")


## STEP 4: Training (train/val/test + scheduler)
Train with MAE (L1 loss) and monitor validation MAE for `ReduceLROnPlateau`.

- Data flow: graph stays on the training device (GPU when available); batches stream `context_features` + `stop_indices`
- Splits: train/val/test with fixed seed for reproducibility
- Logging: per-epoch train MAE, val MAE, best val, and LR changes

In [None]:
# =============================================================================
# STEP 4: TRAINING WITH SCHEDULER (VAL-AWARE)
# =============================================================================
logger.info("--- Preparing V2 Training Loop (Train/Val/Test + Plateau Monitoring) ---")

# -------------------------
# Config-driven hyperparams
# -------------------------
# Every value falls back to a default when the config module lacks the key.
random_state = getattr(config, 'RANDOM_STATE', getattr(config, 'SEED', 42))
test_size = float(getattr(config, 'TEST_SIZE', 0.2))
val_size = float(getattr(config, 'VAL_SIZE', 0.1))  # fraction of total rows

batch_size = int(getattr(config, 'GNN_V2_BATCH_SIZE', 16384))
num_epochs = int(getattr(config, 'GNN_V2_NUM_EPOCHS', 50))
learning_rate = float(getattr(config, 'GNN_V2_LR', 0.003))
weight_decay = float(getattr(config, 'GNN_V2_WEIGHT_DECAY', getattr(config, 'WEIGHT_DECAY', 0.0)))
num_workers = int(getattr(config, 'NUM_WORKERS', 4))
# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just produces a DataLoader warning.
pin_memory = bool(getattr(config, 'PIN_MEMORY', True)) and device.type == 'cuda'

sched_factor = float(getattr(config, 'GNN_V2_SCHED_FACTOR', 0.5))
sched_patience = int(getattr(config, 'GNN_V2_SCHED_PATIENCE', 3))
sched_threshold = float(getattr(config, 'GNN_V2_SCHED_THRESHOLD', 1e-4))
sched_min_lr = float(getattr(config, 'GNN_V2_SCHED_MIN_LR', 1e-6))
sched_cooldown = int(getattr(config, 'GNN_V2_SCHED_COOLDOWN', 0))
champion_mae = float(getattr(config, 'CHAMPION_MAE', 43.18))

logger.info(
    "Hyperparams | "
    f"epochs={num_epochs} batch_size={batch_size} lr={learning_rate} wd={weight_decay} "
    f"test_size={test_size} val_size={val_size} seed={random_state} "
    f"plateau(factor={sched_factor}, patience={sched_patience}, threshold={sched_threshold}, min_lr={sched_min_lr}, cooldown={sched_cooldown})"
)

# -------------------------
# Prepare data arrays
# -------------------------
X_context = df[context_cols].values
X_stop_idx = df['stop_idx'].values
y_target = df['delay_seconds'].values

# Row-level split: first carve out the test rows, then split the remainder
# into train and validation.
indices = np.arange(len(df))
trainval_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)

# Convert val_size (global) into fraction of trainval
val_rel = val_size / max(1e-12, (1.0 - test_size))
val_rel = min(max(val_rel, 0.0), 0.5)  # keep sane
train_idx, val_idx = train_test_split(trainval_idx, test_size=val_rel, random_state=random_state)

logger.info(f"Split sizes | train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")

# -------------------------
# Tensors & loaders
# -------------------------
def make_tensors(idxs):
    """Return (context, stop index, target) tensors for the given row indices.

    Context is float32, stop indices are long (they index node embeddings),
    and the target is reshaped to a single column to match the model output.
    """
    ctx = torch.tensor(X_context[idxs], dtype=torch.float32)
    stp = torch.tensor(X_stop_idx[idxs], dtype=torch.long)
    yt = torch.tensor(y_target[idxs], dtype=torch.float32).view(-1, 1)
    return ctx, stp, yt

train_context, train_stop_indices, y_train_tensor = make_tensors(train_idx)
val_context, val_stop_indices, y_val_tensor = make_tensors(val_idx)
test_context, test_stop_indices, y_test_tensor = make_tensors(test_idx)

train_dataset = TensorDataset(train_context, train_stop_indices, y_train_tensor)
val_dataset = TensorDataset(val_context, val_stop_indices, y_val_tensor)

# The model's BatchNorm1d layers raise in training mode when a batch holds a
# single sample. Drop the trailing batch only in that exact case; any other
# partial batch is kept.
drop_single_sample_tail = (len(train_dataset) % batch_size) == 1

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=drop_single_sample_tail,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=False,
)

# -------------------------
# Optimizer & scheduler
# -------------------------
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# mode='min': the scheduler is stepped with validation MAE, where lower is better.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=sched_factor,
    patience=sched_patience,
    threshold=sched_threshold,
    cooldown=sched_cooldown,
    min_lr=sched_min_lr,
)
# L1 loss with the default mean reduction, i.e. the per-batch MAE in seconds.
criterion = torch.nn.L1Loss()

# Log model size (useful for thesis + reproducibility)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters | trainable={trainable_params:,} total={total_params:,}")

@torch.no_grad()
def evaluate_mae(loader) -> float:
    """Return the sample-weighted mean of ``criterion`` over ``loader``.

    Puts the module-level ``model`` into eval mode (it is not switched back
    here) and runs it without gradient tracking. Each batch's mean loss is
    weighted by the batch size, so a shorter final batch does not distort the
    result the way a plain average of per-batch means would.

    Args:
        loader: Iterable yielding ``(context, stop_indices, target)`` batches.

    Returns:
        The mean loss over all samples, or 0.0 when the loader is empty.
    """
    model.eval()
    weighted_loss_sum = 0.0
    n_samples = 0
    for context_batch, stop_idx_batch, y_batch in loader:
        context_batch = context_batch.to(device, non_blocking=True)
        stop_idx_batch = stop_idx_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)
        out = model(graph_data.x, graph_data.edge_index, context_batch, stop_idx_batch)
        # criterion returns the batch mean; multiply by the batch size to get
        # this batch's contribution to the overall sum.
        batch_n = int(y_batch.size(0))
        weighted_loss_sum += float(criterion(out, y_batch).item()) * batch_n
        n_samples += batch_n
    return weighted_loss_sum / max(1, n_samples)

logger.info("Starting Training with ReduceLROnPlateau (stepped on VAL MAE)...")

best_val = float('inf')
# Per-epoch curves, kept for later inspection/plotting.
history = {'train_mae': [], 'val_mae': [], 'lr': []}

for epoch in range(num_epochs):
    # evaluate_mae() leaves the model in eval mode, so re-enable training mode
    # (dropout, batch-norm statistics) at the start of every epoch.
    model.train()
    train_loss_sum = 0.0
    n_train_samples = 0

    for context_batch, stop_idx_batch, y_batch in train_loader:
        context_batch = context_batch.to(device, non_blocking=True)
        stop_idx_batch = stop_idx_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        # The static graph is passed in full on every step; only the context
        # rows and stop indices vary per batch.
        out = model(graph_data.x, graph_data.edge_index, context_batch, stop_idx_batch)
        loss = criterion(out, y_batch)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # true per-sample MAE even when the last batch is smaller.
        batch_n = int(y_batch.size(0))
        train_loss_sum += float(loss.item()) * batch_n
        n_train_samples += batch_n

    train_mae = train_loss_sum / max(1, n_train_samples)
    val_mae = evaluate_mae(val_loader)

    # Step the plateau scheduler on validation MAE and note whether it cut the LR.
    lr_before = float(optimizer.param_groups[0]['lr'])
    scheduler.step(val_mae)
    lr_after = float(optimizer.param_groups[0]['lr'])
    lr_note = "LR↓" if lr_after < lr_before else "LR="

    # Local best tracker: counts an epoch as improved only when it beats the
    # previous best by more than sched_threshold (absolute, in seconds).
    improved = val_mae < (best_val - sched_threshold)
    if improved:
        best_val = val_mae

    # Plateau diagnostics (read defensively; attributes may differ by version)
    sched_best = getattr(scheduler, 'best', None)
    bad_epochs = getattr(scheduler, 'num_bad_epochs', None)
    cooldown = getattr(scheduler, 'cooldown_counter', None)
    delta_to_champion = val_mae - champion_mae

    history['train_mae'].append(train_mae)
    history['val_mae'].append(val_mae)
    history['lr'].append(lr_after)

    logger.info(
        f"Epoch {epoch+1:03d}/{num_epochs} | "
        f"train_MAE={train_mae:.2f}s | val_MAE={val_mae:.2f}s | best_val={best_val:.2f}s | "
        f"Δ_vs_{champion_mae:.2f}s={delta_to_champion:+.2f}s | "
        f"{lr_note} {lr_after:.6g} | bad_epochs={bad_epochs} | sched_best={sched_best} | cooldown={cooldown}"
    )

logger.info("V2 Training Complete.")

## STEP 5: Evaluation
Evaluate on the held-out test split and save a diagnostics figure.

- Metrics: MAE / RMSE / $R^2$
- Compare: delta vs. the Random Forest baseline (`CHAMPION_MAE`)
- Artifact: save `plots/gnn_model_v2_diagnostics.png`

In [None]:
# =============================================================================
# STEP 5: MODEL V2 EVALUATION
# =============================================================================

logger.info("--- Starting Model V2 Evaluation on Test Set ---")

# Config-driven evaluation settings
batch_size = int(getattr(config, 'GNN_V2_EVAL_BATCH_SIZE', getattr(config, 'GNN_V2_BATCH_SIZE', 16384)))
num_workers = int(getattr(config, 'NUM_WORKERS', 4))
# Pinned memory is only useful for host-to-GPU copies; skip it on CPU-only hosts.
pin_memory = bool(getattr(config, 'PIN_MEMORY', True)) and device.type == 'cuda'
plots_dir = getattr(config, 'PLOTS_DIR', os.path.join(os.getcwd(), 'plots'))
os.makedirs(plots_dir, exist_ok=True)
# Same seed keys as the data split, so the plotted subsample repeats between runs.
plot_seed = int(getattr(config, 'RANDOM_STATE', getattr(config, 'SEED', 42)))

# 1. Create Test Loader
test_dataset = TensorDataset(test_context, test_stop_indices, y_test_tensor)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

model.eval()
all_preds = []
all_targets = []

# 2. Run Inference
logger.info("Running inference...")
with torch.no_grad():
    for context_batch, stop_idx_batch, y_batch in test_loader:
        context_batch = context_batch.to(device, non_blocking=True)
        stop_idx_batch = stop_idx_batch.to(device, non_blocking=True)

        # Forward Pass (targets stay on the CPU; they are only collected)
        out = model(graph_data.x, graph_data.edge_index, context_batch, stop_idx_batch)

        all_preds.append(out.cpu().numpy())
        all_targets.append(y_batch.cpu().numpy())

# Concatenate the per-batch column vectors into flat 1-D arrays
y_pred = np.vstack(all_preds).flatten()
y_true = np.vstack(all_targets).flatten()

# 3. Calculate Metrics
mae = mean_absolute_error(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
r2 = r2_score(y_true, y_pred)

logger.info("--- FINAL GNN MODEL V2 RESULTS ---")
logger.info(f"MAE (Mean Absolute Error): {mae:.2f} seconds")
logger.info(f"RMSE (Root Mean Sq Error): {rmse:.2f} seconds")
logger.info(f"R² Score:                  {r2:.4f}")

# Comparisons. Both reference numbers are config-driven; the defaults
# reproduce the previously hard-coded values (121 s and 43.18 s).
champion_mae = float(getattr(config, 'CHAMPION_MAE', 43.18))
baseline_gnn_mae = float(getattr(config, 'BASELINE_GNN_MAE', 121.0))
logger.info(f"VS Baseline GNN ({baseline_gnn_mae:.0f}s):     {baseline_gnn_mae - mae:+.2f}s improvement")
logger.info(f"VS Context RF ({champion_mae:.2f}s):  {champion_mae - mae:+.2f}s difference")

# 4. Generate Diagnostic Plots
fig, axes = plt.subplots(1, 3, figsize=(20, 6))

# Subsample for plotting (at most 5000 points) with a seeded generator so the
# saved figure is reproducible.
plot_rng = np.random.default_rng(plot_seed)
plot_indices = plot_rng.choice(len(y_pred), size=min(5000, len(y_pred)), replace=False)
y_pred_sub = y_pred[plot_indices]
y_true_sub = y_true[plot_indices]
residuals = y_true_sub - y_pred_sub  # positive = model under-predicted

# --- Plot A: Predicted vs Actual ---
axes[0].scatter(y_true_sub, y_pred_sub, alpha=0.3, s=10, color='blue')
axes[0].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', lw=2, label='Perfect Fit')
axes[0].set_title(f"GNN Model V2: Predicted vs Actual\nMAE: {mae:.2f}s | R²: {r2:.2f}")
axes[0].set_xlabel("Actual Delay (s)")
axes[0].set_ylabel("Predicted Delay (s)")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# --- Plot B: Residuals ---
axes[1].scatter(y_pred_sub, residuals, alpha=0.3, s=10, color='purple')
axes[1].axhline(0, color='red', linestyle='--', lw=2)
axes[1].set_title("Residuals vs Predictions")
axes[1].set_xlabel("Predicted Delay (s)")
axes[1].set_ylabel("Error (s)")
axes[1].grid(True, alpha=0.3)

# --- Plot C: Error Distribution ---
sns.histplot(residuals, bins=50, kde=True, ax=axes[2], color='green')
axes[2].axvline(0, color='red', linestyle='--', lw=2)
axes[2].set_title("Error Distribution")
axes[2].set_xlabel("Error (s)")

# Save to disk and close the figure to release its memory.
fig.tight_layout()
plot_path = os.path.join(plots_dir, "gnn_model_v2_diagnostics.png")
fig.savefig(plot_path, dpi=150)
plt.close(fig)
logger.info(f"Saved diagnostic plots to: {plot_path}")