# GNN Model V1 (Spatial baseline)

## STEP 0: Environment setup
Import libraries, load config, and initialize logging for a reproducible training run.

In [None]:
import gc
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
import torch.optim as optim

# GNN & PyTorch Geometric
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# Machine Learning utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Add src/ directory to path for config + logger utils
sys.path.append(os.path.abspath('../src'))
import config
from utils import setup_logger

logger = setup_logger(
    name="gnn-model-v1",
    log_dir=getattr(config, 'LOG_DIR', os.path.join(os.getcwd(), 'log')),
    filename='gnn_model_v1_training.log',
    level=logging.INFO,
    mode='w',
)

# Seed every RNG the notebook relies on (Python, NumPy, torch on all devices) so that
# model weight initialisation and DataLoader shuffling repeat across runs.
# NOTE: some CUDA kernels (e.g. scatter-based aggregation) are nondeterministic, so
# GPU runs may still differ slightly even with fixed seeds.
SEED = getattr(config, 'RANDOM_STATE', 42)
if SEED is not None:
    SEED = int(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    logger.info(f"Random seed: {SEED}")
else:
    logger.info("RANDOM_STATE is None: RNGs left unseeded (run is not reproducible).")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Device: {device}")

## STEP 1: Data loading and feature engineering
Load the cleaned dataset and build the node/trip features used by the spatial GNN baseline.

In [None]:
# --- 1. Load Data ---
clean_path = getattr(config, 'CLEANED_CSV_PATH', os.path.join(config.DATA_DIR, "vehicle_positions_cleaned.csv"))

if not os.path.exists(clean_path):
    logger.error(f"Cleaned data not found at {clean_path}")
    raise FileNotFoundError("Cleaned data not found. Please run the Cleaning Notebook first.")

logger.info(f"Loading data from: {clean_path}")
df = pd.read_csv(clean_path)

# --- 2. Extract Time Features ---
# Convert string timestamp to datetime objects
df['dt'] = pd.to_datetime(df['timestamp'])
df['hour'] = df['dt'].dt.hour
df['minute'] = df['dt'].dt.minute
df['day_of_week'] = df['dt'].dt.dayofweek
logger.info("Time features (hour, minute, day_of_week) extracted.")

# --- 3. Identify training rows ---
# Target-derived statistics (historical mean delay) and the scaler must be fitted on
# training rows only; fitting them on all rows leaks test-set delays into the features
# and makes the test metrics optimistic. This is the same split call, with the same
# arguments, as the training step (positional indices 0..len(df)-1, TEST_SIZE,
# RANDOM_STATE), so both steps select identical training rows. That equivalence
# requires RANDOM_STATE to be a fixed integer.
fit_idx, _ = train_test_split(
    np.arange(len(df)),
    test_size=float(getattr(config, 'TEST_SIZE', 0.2)),
    random_state=getattr(config, 'RANDOM_STATE', 42),
)
logger.info(f"Fitting target encoding and scaler on {len(fit_idx):,} training rows.")

# --- 4. Target Encoding (Historical Mean Delay, training rows only) ---
logger.info("Calculating Historical Mean Delays (Target Encoding) from training rows...")

# Verify column name (sometimes it is 'stop_id', sometimes 'last_stop_id')
stop_col = 'last_stop_id' if 'last_stop_id' in df.columns else 'stop_id'
logger.info(f"Using column '{stop_col}' as the Stop Identifier.")

fit_rows = df.iloc[fit_idx]
stop_history = fit_rows.groupby(stop_col)['delay_seconds'].mean()
global_mean = fit_rows['delay_seconds'].mean()

# Stops that never occur in the training rows fall back to the training-wide mean.
df['history_mean'] = df[stop_col].map(stop_history).fillna(global_mean)

# --- 5. Scaling (StandardScaler, fitted on training rows only) ---
# Neural Networks require normalized inputs (mean=0, std=1)
features_to_scale = ['latitude', 'longitude', 'hour', 'minute', 'day_of_week', 'history_mean']
logger.info(f"Scaling features: {features_to_scale}")

scaler = StandardScaler()
scaler.fit(df.iloc[fit_idx][features_to_scale])
X_scaled = scaler.transform(df[features_to_scale])

# Add these scaled features back to DF with distinct names
df[['lat_scaled', 'lon_scaled', 'h_scaled', 'm_scaled', 'd_scaled', 'hist_scaled']] = X_scaled

logger.info(f"Data Loaded & Processed. Total Rows: {len(df):,}")
logger.info("Feature Engineering Complete. Ready for Graph Construction.")

## STEP 2: Graph construction
Construct a stop graph (nodes=stops, edges=observed stop-to-stop transitions) for PyTorch Geometric.

In [None]:
# =============================================================================
# STEP 2: GRAPH CONSTRUCTION
# =============================================================================

logger.info("--- Constructing the Transit Graph ---")

# --- 1. Define Nodes (Unique Stops) ---
# Map string stop IDs (e.g., "F02165") to contiguous integer indices 0..num_nodes-1
# with a LabelEncoder so node i in the graph corresponds to stop_encoder.classes_[i].
stop_encoder = LabelEncoder()

# Same stop-identifier column as chosen during feature engineering (Step 1).
stop_col = 'last_stop_id' if 'last_stop_id' in df.columns else 'stop_id'
df['stop_idx'] = stop_encoder.fit_transform(df[stop_col])

num_nodes = len(stop_encoder.classes_)
logger.info(f"Total Unique Stops (Nodes): {num_nodes:,}")

# --- 2. Create Node Features Tensor [Num_Nodes, 3] ---
# Features: scaled latitude, longitude and historical mean delay.
# Aggregating by stop_idx yields one row per stop, ordered 0..num_nodes-1, so row i
# of x is node i. hist_scaled is constant per stop; lat/lon are presumably static
# per stop too, in which case 'mean' simply selects the value.
node_features_df = df.groupby('stop_idx')[['lat_scaled', 'lon_scaled', 'hist_scaled']].mean()

# Convert to Float Tensor (GNNs require Float32)
x = torch.tensor(node_features_df.to_numpy(), dtype=torch.float)
logger.info(f"Node Features Tensor created: {x.shape}")

# --- 3. Create Edge Index (observed stop-to-stop transitions) ---
logger.info("Inferring edges from trip sequences...")
# Order each trip chronologically on the parsed datetime ('dt') rather than the raw
# timestamp string, whose lexicographic order need not be chronological.
df_sorted = df.sort_values(by=['trip_id', 'dt'])

# The next observed stop of each row within the same trip (NaN on a trip's last row)
df_sorted['next_stop_idx'] = df_sorted.groupby('trip_id')['stop_idx'].shift(-1)

# Drop each trip's final row (no successor)
edges_df = df_sorted.dropna(subset=['next_stop_idx'])

# Consecutive observations at the same stop are not transitions: drop A -> A pairs.
# GCNConv adds its own self-loops by default, so this only makes the edge count and
# degree statistics reflect real stop-to-stop connections.
edges_df = edges_df[edges_df['stop_idx'] != edges_df['next_stop_idx']]

# Keep only unique connections to define the static (directed) graph structure.
# next_stop_idx became float because of the NaNs; cast back explicitly.
unique_edges = edges_df[['stop_idx', 'next_stop_idx']].drop_duplicates().astype('int64')

# Long tensor of shape [2, Num_Edges]: row 0 = source stop, row 1 = next stop
edge_index = torch.tensor(unique_edges.to_numpy().T, dtype=torch.long)

logger.info(f"Total Unique Connections (Edges): {edge_index.shape[1]:,}")

# --- 4. Package into Graph Object & Move to the training device ---
graph_data = Data(x=x, edge_index=edge_index)
graph_data = graph_data.to(device)

logger.info(f"Graph successfully created and moved to {device}!")
logger.info(f"Graph Info: {graph_data}")

# --- 5. Graph Sanity Check ---
# Average out-degree: directed transitions per stop (self-transitions excluded)
avg_degree = graph_data.num_edges / graph_data.num_nodes
logger.info(f"Average Degree: {avg_degree:.2f} (Avg connections per stop)")

## STEP 3: Model definition
Define a spatial GCN that combines stop embeddings with per-observation trip/time features.

In [None]:
# =============================================================================
# STEP 3: DEFINING THE MODEL ARCHITECTURE
# Spatial GCN over the stop graph + MLP head over per-observation trip features.
# =============================================================================

logger.info("--- Defining GNN Architecture ---")

class TransitGNN(torch.nn.Module):
    """Spatial GCN encoder over the stop graph followed by an MLP regression head.

    Two GCNConv layers compute an embedding for every stop from the full graph.
    For each observation in a batch, the embedding of its stop is concatenated
    with that observation's trip/time features and mapped to one delay value.

    Args:
        num_node_features: Number of columns of the node feature matrix ``x``.
        num_trip_features: Number of columns of ``trip_features``.
        dropout_p: Dropout probability applied after the first GCN layer.
            ``None`` (default) reads ``config.DROPOUT_RATE`` once, at construction
            time, falling back to 0.1 when it is not defined.
    """

    def __init__(self, num_node_features, num_trip_features, dropout_p=None):
        super().__init__()

        # Resolve the dropout rate once so the model's behaviour is fixed at
        # construction instead of depending on global config at every forward pass.
        if dropout_p is None:
            dropout_p = getattr(config, 'DROPOUT_RATE', 0.1)
        self.dropout_p = float(dropout_p)

        # 1. Graph Layers (Spatial Logic - "The City Structure")
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 32)

        # 2. Regression Layers (Temporal Logic - "The Specific Trip")
        self.fc1 = torch.nn.Linear(32 + num_trip_features, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, trip_features, stop_indices):
        """Predict one delay per observation.

        Args:
            x: Node feature matrix for the whole graph.
            edge_index: Graph connectivity, shape [2, E].
            trip_features: Per-observation features, one row per observation.
            stop_indices: Node index of each observation's stop (presumably 1-D,
                aligned with the rows of ``trip_features``).

        Returns:
            Tensor with one row per observation and a single output column.
        """
        # Message passing over the whole graph (every stop gets an embedding)
        h = F.relu(self.conv1(x, edge_index))
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = F.relu(self.conv2(h, edge_index))

        # Gather the embeddings of the stops referenced by this batch
        batch_node_embeddings = h[stop_indices]
        combined = torch.cat([batch_node_embeddings, trip_features], dim=1)

        out = F.relu(self.fc1(combined))
        out = F.relu(self.fc2(out))
        return self.fc3(out)

# Per-observation trip/time features consumed by the regression head. Deriving the
# count from the column list keeps it in sync with the training step, which builds
# its input tensors from the same columns.
trip_cols = ['h_scaled', 'm_scaled', 'd_scaled']

num_node_features = graph_data.x.shape[1]
num_trip_features = len(trip_cols)

logger.info(f"Model Input Params - Node Feats: {num_node_features}, Trip Feats: {num_trip_features}")

model = TransitGNN(num_node_features, num_trip_features).to(device)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Model initialized on {device} | trainable={trainable_params:,} total={total_params:,}")

## STEP 4: Training
Prepare DataLoaders and train the model using MAE (L1 loss). Log per-epoch train MAE.

In [None]:
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

# =============================================================================
# STEP 4: TRAINING THE GNN (CONFIG-DRIVEN)
# =============================================================================

logger.info("--- Preparing Data Loaders ---")

# Per-observation trip/time features fed to the regression head next to the stop
# embedding; their count defines the model's num_trip_features.
trip_cols = ['h_scaled', 'm_scaled', 'd_scaled']
num_trip_features = len(trip_cols)

test_size = float(getattr(config, 'TEST_SIZE', 0.2))
random_state = getattr(config, 'RANDOM_STATE', 42)
batch_size = int(getattr(config, 'GNN_V1_BATCH_SIZE', getattr(config, 'BATCH_SIZE', 16384)))
num_workers = int(getattr(config, 'NUM_WORKERS', 4))
# Pinned host memory only speeds up host->GPU copies; skip it on a CPU-only run.
pin_memory = bool(getattr(config, 'PIN_MEMORY', True)) and device.type == 'cuda'

X_trip = df[trip_cols].values
X_stop_idx = df['stop_idx'].values
y_target = df['delay_seconds'].values

# Random row-level split over positional indices 0..len(df)-1
indices = np.arange(len(df))
train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)

logger.info(f"Train rows: {len(train_idx):,} | Test rows: {len(test_idx):,}")

train_trip_feats = torch.tensor(X_trip[train_idx], dtype=torch.float32)
train_stop_indices = torch.tensor(X_stop_idx[train_idx], dtype=torch.long)
y_train_tensor = torch.tensor(y_target[train_idx], dtype=torch.float32).view(-1, 1)

test_trip_feats = torch.tensor(X_trip[test_idx], dtype=torch.float32)
test_stop_indices = torch.tensor(X_stop_idx[test_idx], dtype=torch.long)
y_test_tensor = torch.tensor(y_target[test_idx], dtype=torch.float32).view(-1, 1)

train_dataset = TensorDataset(train_trip_feats, train_stop_indices, y_train_tensor)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=False,
)

logger.info(f"Batch size: {batch_size} | train batches/epoch: {len(train_loader)}")

# (Re)initialize model to ensure a clean run
model = TransitGNN(num_node_features=num_node_features, num_trip_features=num_trip_features).to(device)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters | trainable={trainable_params:,} total={total_params:,}")

lr = float(getattr(config, 'GNN_V1_LR', getattr(config, 'LEARNING_RATE', 0.005)))
num_epochs = int(getattr(config, 'GNN_V1_NUM_EPOCHS', getattr(config, 'NUM_EPOCHS', 20)))

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.L1Loss()  # default reduction='mean': per-batch MAE

logger.info(f"Training config: epochs={num_epochs} lr={lr} loss=L1(MAE)")

train_losses = []
for epoch in range(num_epochs):
    model.train()
    abs_err_sum = 0.0
    n_seen = 0
    for trip_batch, stop_idx_batch, y_batch in train_loader:
        trip_batch = trip_batch.to(device, non_blocking=True)
        stop_idx_batch = stop_idx_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        # The GCN is re-run over the whole graph each batch so its weights get gradients.
        out = model(graph_data.x, graph_data.edge_index, trip_batch, stop_idx_batch)
        loss = criterion(out, y_batch)
        loss.backward()
        optimizer.step()
        # loss is a batch *mean*; weight it by batch size so the (possibly smaller)
        # final batch does not skew the epoch MAE.
        n_batch = y_batch.size(0)
        abs_err_sum += loss.item() * n_batch
        n_seen += n_batch
    avg_loss = abs_err_sum / max(1, n_seen)
    train_losses.append(avg_loss)
    logger.info(f"Epoch {epoch + 1}/{num_epochs} | Train MAE: {avg_loss:.2f} sec")

logger.info("Training complete.")

## STEP 5: Evaluation and diagnostics
Evaluate on the test split (MAE/RMSE/R²) and save a diagnostics figure to the configured plots directory (`config.PLOTS_DIR`, default `plots/`).

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# =============================================================================
# STEP 5: MODEL EVALUATION & PLOTTING
# =============================================================================

logger.info("--- Evaluating on test set ---")

batch_size = int(getattr(config, 'GNN_V1_BATCH_SIZE', getattr(config, 'BATCH_SIZE', 16384)))
num_workers = int(getattr(config, 'NUM_WORKERS', 4))
# Pinned host memory only speeds up host->GPU copies; skip it on a CPU-only run.
pin_memory = bool(getattr(config, 'PIN_MEMORY', True)) and device.type == 'cuda'
plots_dir = os.path.abspath(getattr(config, 'PLOTS_DIR', os.path.join(os.getcwd(), 'plots')))
os.makedirs(plots_dir, exist_ok=True)

test_dataset = TensorDataset(test_trip_feats, test_stop_indices, y_test_tensor)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

model.eval()
all_preds = []
all_targets = []

logger.info("Running inference on test set...")
with torch.no_grad():
    for trip_batch, stop_idx_batch, y_batch in test_loader:
        trip_batch = trip_batch.to(device, non_blocking=True)
        stop_idx_batch = stop_idx_batch.to(device, non_blocking=True)
        out = model(graph_data.x, graph_data.edge_index, trip_batch, stop_idx_batch)
        all_preds.append(out.cpu().numpy())
        all_targets.append(y_batch.cpu().numpy())

y_pred = np.vstack(all_preds).flatten()
y_true = np.vstack(all_targets).flatten()

mae = mean_absolute_error(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
r2 = r2_score(y_true, y_pred)

logger.info("--- GNN MODEL V1 RESULTS ---")
logger.info(f"MAE (sec):  {mae:.2f}")
logger.info(f"RMSE (sec): {rmse:.2f}")
logger.info(f"R²:         {r2:.4f}")

# Positive delta = this model's MAE is lower (better) than the champion's.
champion_mae = float(getattr(config, 'CHAMPION_MAE', 43.18))
logger.info(f"Delta vs champion MAE ({champion_mae:.2f}): {champion_mae - mae:+.2f}")

logger.info("Generating diagnostics plot...")
fig, axes = plt.subplots(1, 3, figsize=(20, 6))

# Plot a fixed random subsample (at most 5000 points) to keep scatter plots readable
rng = np.random.RandomState(getattr(config, 'RANDOM_STATE', 42))
plot_indices = rng.choice(len(y_pred), size=min(5000, len(y_pred)), replace=False)
y_pred_sub = y_pred[plot_indices]
y_true_sub = y_true[plot_indices]
residuals = y_true_sub - y_pred_sub

axes[0].scatter(y_true_sub, y_pred_sub, alpha=0.3, s=10, color='blue')
axes[0].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', lw=2, label='Perfect Fit')
axes[0].set_title(f"Predicted vs Actual\nMAE: {mae:.1f}s | R²: {r2:.2f}")
axes[0].set_xlabel("Actual Delay (s)")
axes[0].set_ylabel("Predicted Delay (s)")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].scatter(y_pred_sub, residuals, alpha=0.3, s=10, color='purple')
axes[1].axhline(0, color='red', linestyle='--', lw=2)
axes[1].set_title("Residuals vs Predictions")
axes[1].set_xlabel("Predicted Delay (s)")
axes[1].set_ylabel("Error (Actual - Predicted)")
axes[1].grid(True, alpha=0.3)

sns.histplot(residuals, bins=50, kde=True, ax=axes[2], color='green')
axes[2].axvline(0, color='red', linestyle='--', lw=2)
axes[2].set_title("Error Distribution")
axes[2].set_xlabel("Prediction Error (s)")
axes[2].set_ylabel("Count")

plt.tight_layout()
plot_path = os.path.join(plots_dir, 'gnn_model_v1_diagnostics.png')
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved plot: {plot_path}")
# Display before closing: once the figure is closed, plt.show() has nothing to render.
plt.show()
plt.close(fig)