## 1️⃣ Import Required Libraries

In [20]:
import os
import sys
import copy
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, GAT
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# MLflow & DagsHub
import mlflow
import mlflow.pytorch

# Fix the NumPy and PyTorch global RNGs so repeated runs draw the same numbers.
torch.manual_seed(42)
np.random.seed(42)

# Report the environment this notebook is running in.
print("‚úÖ All libraries imported successfully!")
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    print("Device: üî• CUDA")
else:
    print("Device: üíª CPU")

‚úÖ All libraries imported successfully!
PyTorch version: 2.8.0
Device: üíª CPU


## 2️⃣ Load & Explore Service Metrics Data

In [21]:
# Directory that holds the CSV exports of the transport network.
data_path = '../data'

print("üìä Loading transport network data...\n")
# Read the four tables in a fixed order and unpack them into their frames.
nodes_df, edges_df, service_metrics_df, timetables_df = (
    pd.read_csv(data_path + '/' + csv_name)
    for csv_name in ('nodes.csv', 'edges.csv', 'service_metrics.csv', 'timetables.csv')
)

print("‚úÖ Data loaded!")
print(f"   üìç Nodes (locations): {nodes_df.shape[0]}")
print(f"   üõ£Ô∏è  Edges (routes): {edges_df.shape[0]}")
print(f"   üìä Service metrics: {service_metrics_df.shape[0]}")
print(f"   üöå Timetables: {timetables_df.shape[0]}\n")

# Preview the per-(day type, time period) baselines and summarise the metric columns.
print("üìà Service Metrics by Day Type & Time Period:")
print(service_metrics_df.head(10))
print("\nMetrics Statistics:")
print(
    service_metrics_df
    .loc[:, ['reliability_baseline', 'crowding_baseline', 'service_availability']]
    .describe()
)

üìä Loading transport network data...

‚úÖ Data loaded!
   üìç Nodes (locations): 40
   üõ£Ô∏è  Edges (routes): 1539
   üìä Service metrics: 15
   üöå Timetables: 12270

üìà Service Metrics by Day Type & Time Period:
   metric_id day_type    time_period  reliability_baseline  crowding_baseline  \
0          1  regular  early_morning                  0.85               0.20   
1          2  regular   morning_peak                  0.75               0.90   
2          3  regular         midday                  0.80               0.50   
3          4  regular   evening_peak                  0.70               0.95   
4          5  regular          night                  0.65               0.30   
5          6  weekend  early_morning                  0.80               0.15   
6          7  weekend        morning                  0.78               0.60   
7          8  weekend      afternoon                  0.82               0.70   
8          9  weekend        evening            

## 3️⃣ Initialize DagsHub Integration

In [22]:
import dagshub
from dotenv import load_dotenv

# Read environment variables from the .env file two directories above the notebook.
load_dotenv(dotenv_path='../../.env')

# Repository coordinates; the literals are fallbacks used when the variables are unset.
DAGSHUB_REPO_OWNER = os.getenv("DAGSHUB_REPO_OWNER", "your-username")
DAGSHUB_REPO_NAME = os.getenv("DAGSHUB_REPO_NAME", "travion-research-project")

# Try DagsHub first; on any failure fall back to a local file-based MLflow store.
try:
    dagshub.init(repo_owner=DAGSHUB_REPO_OWNER, repo_name=DAGSHUB_REPO_NAME, mlflow=True)
    print("‚úÖ DagsHub initialized!")
    print(f"üìä Tracking URI: {mlflow.get_tracking_uri()}")
except Exception as e:
    print(f"‚ö†Ô∏è  DagsHub initialization failed: {e}")
    print("üìä Using local MLflow tracking...")
    mlflow.set_tracking_uri("file:///tmp/mlruns")
    mlflow.set_experiment("transport-gnn-routing")

# Start MLflow run.
# mlflow.start_run() raises if a run is already active, which is exactly what
# happens when this cell is executed a second time in the same kernel. Close
# the previous run first so the cell can be re-run safely.
if mlflow.active_run() is not None:
    mlflow.end_run()
mlflow.start_run()
print(f"üöÄ MLflow run started!")

‚úÖ DagsHub initialized!
üìä Tracking URI: https://dagshub.com/iamsahan/ml-services.mlflow
üöÄ MLflow run started!


## 4️⃣ Prepare Graph Data Structure

In [23]:
print("üîß Building graph data structure...\n")

# 1. Node features = one-hot(type) + one-hot(region) + standardized (latitude, longitude)
type_encoder = LabelEncoder().fit(nodes_df['type'])
region_encoder = LabelEncoder().fit(nodes_df['region'])
type_encoded = type_encoder.transform(nodes_df['type'])
region_encoded = region_encoder.transform(nodes_df['region'])

# Row i of an identity matrix is the one-hot vector for class i.
type_one_hot = np.identity(len(type_encoder.classes_))[type_encoded]
region_one_hot = np.identity(len(region_encoder.classes_))[region_encoded]

# Standardize the coordinates; the fitted scaler is kept for later reuse.
coords = nodes_df[['latitude', 'longitude']].to_numpy()
coords_scaler = StandardScaler().fit(coords)
coords_normalized = coords_scaler.transform(coords)

# Concatenate the three feature groups column-wise into one matrix per node.
node_features = np.concatenate([type_one_hot, region_one_hot, coords_normalized], axis=1)
node_features_tensor = torch.tensor(node_features, dtype=torch.float32)

print(f"‚úÖ Node Features: {node_features_tensor.shape}")
print(f"   Types: {list(type_encoder.classes_)}")
print(f"   Regions: {list(region_encoder.classes_)}")

# 2. Edge Features: Mode + Distance + Duration + Fare + IsActive + Frequency
# NOTE(review): in the cells visible here, edge_features_tensor and
# target_labels_tensor are not consumed by the temporal model - confirm whether
# a later cell still needs them.
mode_encoder = LabelEncoder()
mode_encoder.fit(edges_df['mode'].unique())

# Number of timetable rows per service_id, used as a frequency proxy.
service_schedule_counts = timetables_df['service_id'].value_counts()

# Encode every edge's mode in a single call instead of one transform() per row.
mode_codes = mode_encoder.transform(edges_df['mode'])
# With a single mode, (num_classes - 1) is zero; clamp the divisor to 1 so the
# scaled mode feature is 0.0 instead of nan.
mode_scale = max(len(mode_encoder.classes_) - 1, 1)

# Fixed-divisor scaling (500 km, 600 min, 3000 LKR, 10 timetable rows).
distance_norm = edges_df['distance_km'].to_numpy(dtype=np.float64) / 500
duration_norm = edges_df['duration_min'].to_numpy(dtype=np.float64) / 600
fare_norm = edges_df['fare_lkr'].to_numpy(dtype=np.float64) / 3000
is_active = edges_df['is_active'].astype(float).to_numpy()
# Services absent from the timetable count as 1 schedule; the ratio is capped at 1.0.
frequency_norm = np.minimum(
    edges_df['service_id'].map(service_schedule_counts).fillna(1).to_numpy(dtype=np.float64) / 10,
    1.0,
)

# One row per edge, columns in the order listed in the section header.
edge_features = np.column_stack([
    mode_codes / mode_scale,
    distance_norm,
    duration_norm,
    fare_norm,
    is_active,
    frequency_norm,
]).astype(np.float32)
# Kept as a list of per-edge vectors for any code that still expects it.
edge_features_list = list(edge_features)

# Target: quality heuristic (active flag plus capped fare and duration penalties).
target_labels = (
    0.3 * is_active
    + 0.35 * (1 - np.minimum(fare_norm, 1.0))
    + 0.35 * (1 - np.minimum(duration_norm, 1.0))
)

edge_features_tensor = torch.tensor(edge_features, dtype=torch.float32)
target_labels_tensor = torch.tensor(target_labels, dtype=torch.float32)

print(f"\n‚úÖ Edge Features: {edge_features_tensor.shape}")
print(f"   Target range: [{target_labels.min():.3f}, {target_labels.max():.3f}]")

# 3. Graph Structure: Edge Index
# Position of each location_id in nodes_df; this is the node index used by the graph.
location_id_to_idx = dict(zip(nodes_df['location_id'], range(len(nodes_df))))

# One [source, destination] index pair per edge row; unknown ids raise KeyError.
edge_list = [
    [location_id_to_idx[origin], location_id_to_idx[destination]]
    for origin, destination in zip(edges_df['origin_id'], edges_df['destination_id'])
]

# Transpose the pair list so row 0 holds sources and row 1 holds destinations.
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

print(f"\n‚úÖ Graph Structure:")
print(f"   Nodes: {nodes_df.shape[0]}")
print(f"   Edges: {edge_index.size(1)}")

üîß Building graph data structure...

‚úÖ Node Features: torch.Size([40, 14])
   Types: ['airport', 'city', 'train_station']
   Regions: ['Central', 'Eastern', 'North Central', 'North Western', 'Northern', 'Sabaragamuwa', 'Southern', 'Uva', 'Western']

‚úÖ Edge Features: torch.Size([1539, 6])
   Target range: [0.533, 0.969]

‚úÖ Graph Structure:
   Nodes: 40
   Edges: 1539


## 5️⃣ Build Graph Neural Network Model

In [44]:
class TemporalTransportGNN(nn.Module):
    """Temporal-Aware GNN for transport routing - matches service architecture

    Three GCN layers embed every node. For each query the origin and
    destination embeddings are concatenated with a fused (day type, time
    period) embedding and a mode embedding, then passed through an MLP that
    ends in a sigmoid, so every prediction lies in [0, 1].
    """

    def __init__(
        self,
        node_features,
        num_day_types=2,  # regular, weekend (simplified for now)
        num_time_periods=5,  # early_morning, morning_peak, midday, evening_peak, night
        num_modes=3,  # bus, train, ridehailing
        hidden_dim=64
    ):
        """
        Args:
            node_features: Number of input features per node.
            num_day_types: Size of the day-type embedding table.
            num_time_periods: Size of the time-period embedding table.
            num_modes: Size of the transport-mode embedding table.
            hidden_dim: Width of the GCN layers; the temporal and mode
                embeddings use hidden_dim // 2.
        """
        super().__init__()

        # Location encoder (GCN)
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Temporal embeddings (each half of hidden_dim; concatenated in forward)
        self.day_type_embedding = nn.Embedding(num_day_types, hidden_dim // 2)
        self.time_period_embedding = nn.Embedding(num_time_periods, hidden_dim // 2)

        # Temporal fusion
        self.temporal_fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Mode embedding
        self.mode_embedding = nn.Embedding(num_modes, hidden_dim // 2)

        # Prediction head (origin + dest + temporal + mode -> reliability score)
        self.prediction_head = nn.Sequential(
            nn.Linear(hidden_dim * 2 + hidden_dim + hidden_dim // 2, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()  # Output: 0-1 reliability score
        )

    def forward(self, x, edge_index, origin_idx, dest_idx, day_type_id, time_period_id, mode_id):
        """
        Args:
            x: Node features [num_nodes, node_features]
            edge_index: Graph connectivity [2, num_edges]
            origin_idx: Origin node indices [batch_size]
            dest_idx: Destination node indices [batch_size]
            day_type_id: Day type IDs [batch_size]
            time_period_id: Time period IDs [batch_size]
            mode_id: Transport mode IDs [batch_size]

        Returns:
            Reliability predictions [batch_size]. The result is always 1-D,
            including when batch_size is 1.
        """
        # Learn location embeddings
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x = torch.relu(self.conv3(x, edge_index))

        # Get origin/destination embeddings
        origin_emb = x[origin_idx]
        dest_emb = x[dest_idx]

        # Get temporal embeddings
        day_emb = self.day_type_embedding(day_type_id)
        time_emb = self.time_period_embedding(time_period_id)
        temporal_emb = torch.cat([day_emb, time_emb], dim=1)
        temporal_emb = self.temporal_fusion(temporal_emb)

        # Get mode embedding
        mode_emb = self.mode_embedding(mode_id)

        # Combine all features
        combined = torch.cat([origin_emb, dest_emb, temporal_emb, mode_emb], dim=1)

        # Predict reliability. Squeeze only the trailing feature axis: a bare
        # squeeze() would also drop the batch axis for a single-sample batch and
        # return a 0-dim tensor that broadcasts against 1-D targets in the loss.
        predictions = self.prediction_head(combined).squeeze(-1)
        return predictions

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# For training, we'll use 2 day types (regular/weekend) and 5 time periods.
# The mode embedding is sized from the modes actually present in edges_df
# (mode_encoder was fitted on them), so a dataset with more than three modes
# cannot index past the end of the embedding table.
model = TemporalTransportGNN(
    node_features=node_features_tensor.shape[1],
    num_day_types=2,  # regular, weekend
    num_time_periods=5,  # early_morning, morning_peak, midday, evening_peak, night
    num_modes=len(mode_encoder.classes_),  # one embedding row per distinct mode in the data
    hidden_dim=64
).to(device)

print("‚úÖ Temporal GNN Model initialized!")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Device: {device}")
print(f"   Architecture: TemporalTransportGNN (matches service)")


‚úÖ Temporal GNN Model initialized!
   Parameters: 52,929
   Device: cpu
   Architecture: TemporalTransportGNN (matches service)


## 6️⃣ Train GNN Model

In [56]:
# Hyperparameters: optimisation
LEARNING_RATE = 0.001
WEIGHT_DECAY = 5e-4
BATCH_SIZE = 16

# Hyperparameters: schedule
NUM_EPOCHS = 100
PATIENCE = 25  # epochs without validation improvement before early stopping

# Hyperparameters: architecture
HIDDEN_DIM = 64

# Place the graph tensors on the same device as the model.
node_features_tensor, edge_index = node_features_tensor.to(device), edge_index.to(device)

# ======================================================================
# Prepare Training Data: (origin, dest, day_type, time_period, mode, rating)
# ======================================================================

# Map modes to IDs (order of first appearance in edges_df)
MODE_TO_ID = {mode: idx for idx, mode in enumerate(edges_df['mode'].unique())}

# Map day types and time periods to IDs
DAY_TYPE_TO_ID = {"regular": 0, "weekend": 1}
TIME_PERIOD_TO_ID = {
    "early_morning": 0,
    "morning_peak": 1,
    "midday": 2,
    "evening_peak": 3,
    "night": 4
}

# The baseline for a (day_type, time_period) pair does not depend on the edge,
# so resolve it once here instead of filtering service_metrics_df for every
# edge. Pairs without a row in service_metrics_df produce no samples; they are
# collected so the omission is reported instead of passing silently.
baseline_reliability = {}
missing_contexts = []
for day_type, day_id in DAY_TYPE_TO_ID.items():
    for time_period, time_id in TIME_PERIOD_TO_ID.items():
        baseline_df = service_metrics_df[
            (service_metrics_df['day_type'] == day_type) &
            (service_metrics_df['time_period'] == time_period)
        ]
        if len(baseline_df) > 0:
            baseline_reliability[(day_id, time_id)] = float(baseline_df['reliability_baseline'].values[0])
        else:
            missing_contexts.append((day_type, time_period))

# Create training samples from edges: one sample per edge and available context
training_data = []
for idx, edge in edges_df.iterrows():
    origin_idx = location_id_to_idx[edge['origin_id']]
    dest_idx = location_id_to_idx[edge['destination_id']]
    mode_id = MODE_TO_ID[edge['mode']]

    # Edge-level quality factor: active flag plus capped fare and duration penalties
    fare_norm = edge['fare_lkr'] / 3000
    duration_norm = edge['duration_min'] / 600
    is_active = float(edge['is_active'])
    quality_factor = 0.4 * is_active + 0.3 * (1 - min(fare_norm, 1.0)) + 0.3 * (1 - min(duration_norm, 1.0))

    # dicts keep insertion order, so samples are emitted day type first, then time period
    for (day_id, time_id), base_reliability in baseline_reliability.items():
        # Target: base reliability * service quality
        target = base_reliability * quality_factor

        training_data.append({
            'origin_idx': origin_idx,
            'dest_idx': dest_idx,
            'day_type_id': day_id,
            'time_period_id': time_id,
            'mode_id': mode_id,
            'target': target
        })

print(f"üìä Generated {len(training_data)} training samples")
print(f"   From {len(edges_df)} edges √ó {len(DAY_TYPE_TO_ID)} day types √ó {len(TIME_PERIOD_TO_ID)} time periods")
if missing_contexts:
    print(f"   Skipped {len(missing_contexts)} day/time combinations with no baseline in service_metrics: {missing_contexts}")

# Turn each integer field of the sample dicts into a long tensor on the training device.
_index_columns = {
    field: torch.tensor([sample[field] for sample in training_data], dtype=torch.long, device=device)
    for field in ('origin_idx', 'dest_idx', 'day_type_id', 'time_period_id', 'mode_id')
}
origins, dests, day_types, time_periods, modes = _index_columns.values()
targets = torch.tensor([sample['target'] for sample in training_data], dtype=torch.float32, device=device)

# Noisy copy of the targets (Gaussian, clipped to [0, 1]); the clean targets are kept as well.
noise_std = 0.08
targets_noisy = (targets + noise_std * torch.randn_like(targets)).clamp(0.0, 1.0)

# 70 / 15 / remainder split over a random permutation of the sample indices.
n_samples = len(targets)
train_size = int(0.7 * n_samples)
val_size = int(0.15 * n_samples)

indices = torch.randperm(n_samples, device=device)
train_indices, val_indices, test_indices = torch.split(
    indices, [train_size, val_size, n_samples - train_size - val_size]
)

print(f"\nüìä Data split:")
print(f"   Train: {len(train_indices)} | Val: {len(val_indices)} | Test: {len(test_indices)}")
print(f"   Target noise: ¬±{noise_std*100:.0f}% added\n")

# Loss, optimizer, and a scheduler that halves the learning rate after 10 epochs
# without a drop in the monitored value.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# Per-epoch metric history plus the running bests used for early stopping.
training_history = {
    metric: []
    for metric in ('epoch', 'train_loss', 'val_loss', 'val_mae', 'val_r2_score', 'val_accuracy')
}
best_val_loss, best_val_acc, patience_counter = float('inf'), 0.0, 0

print("üöÄ Starting training with Temporal GNN...\n")

def calculate_r2_score(predictions, targets):
    """Coefficient of determination (R^2) of ``predictions`` against ``targets``.

    Args:
        predictions: Tensor of predicted values.
        targets: Tensor of reference values with the same shape.

    Returns:
        A Python float. For empty input the result is nan. When all targets
        are identical (zero variance) the ratio is undefined, so the function
        returns 1.0 for a perfect fit and 0.0 otherwise instead of nan or -inf.
    """
    if targets.numel() == 0:
        return float('nan')

    ss_res = torch.sum((targets - predictions) ** 2)
    ss_tot = torch.sum((targets - torch.mean(targets)) ** 2)

    if float(ss_tot) == 0.0:
        # Constant targets: 1 - ss_res / 0 would be nan or -inf.
        return 1.0 if float(ss_res) == 0.0 else 0.0

    return float(1 - (ss_res / ss_tot))

def calculate_accuracy(predictions, targets, threshold=0.02):
    """Return the fraction of predictions whose absolute error is at most ``threshold``."""
    within_tolerance = (predictions - targets).abs().le(threshold)
    return within_tolerance.float().mean().item()

# Snapshot the starting weights so best_model_state is always defined, even if
# the validation loss never improves on its initial +inf (e.g. NaN from epoch 1).
best_model_state = copy.deepcopy(model.state_dict())
mlflow_warning_shown = False

for epoch in range(NUM_EPOCHS):
    # Training
    model.train()
    train_losses = []

    # Shuffle training indices
    perm = torch.randperm(len(train_indices), device=device)
    shuffled_train = train_indices[perm]

    for i in range(0, len(shuffled_train), BATCH_SIZE):
        batch_idx = shuffled_train[i:i+BATCH_SIZE]

        predictions = model(
            node_features_tensor,
            edge_index,
            origins[batch_idx],
            dests[batch_idx],
            day_types[batch_idx],
            time_periods[batch_idx],
            modes[batch_idx]
        )
        batch_targets = targets_noisy[batch_idx]

        # reshape(-1) keeps predictions 1-D even when the last batch holds one sample
        loss = criterion(predictions.reshape(-1), batch_targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_losses.append(loss.item())

    train_loss = np.mean(train_losses)

    # Validation (against clean targets)
    model.eval()
    with torch.no_grad():
        val_predictions = model(
            node_features_tensor,
            edge_index,
            origins[val_indices],
            dests[val_indices],
            day_types[val_indices],
            time_periods[val_indices],
            modes[val_indices]
        ).reshape(-1)
        val_targets = targets[val_indices]
        val_loss = criterion(val_predictions, val_targets).item()
        val_mae = torch.abs(val_predictions - val_targets).mean().item()
        val_r2 = calculate_r2_score(val_predictions, val_targets)
        val_acc = calculate_accuracy(val_predictions, val_targets, threshold=0.02) * 100

    training_history['epoch'].append(epoch)
    training_history['train_loss'].append(train_loss)
    training_history['val_loss'].append(val_loss)
    training_history['val_mae'].append(val_mae)
    training_history['val_r2_score'].append(val_r2)
    training_history['val_accuracy'].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc

    # Learning rate scheduling
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f"üìâ Learning rate reduced: {old_lr:.6f} ‚Üí {new_lr:.6f}")

    # Log to MLflow. Tracking stays best-effort, but the first failure is
    # reported instead of being swallowed, so missing metrics are explainable.
    try:
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_mae", val_mae, step=epoch)
        mlflow.log_metric("val_r2_score", val_r2, step=epoch)
        mlflow.log_metric("val_accuracy_percent", val_acc, step=epoch)
        mlflow.log_metric("learning_rate", optimizer.param_groups[0]['lr'], step=epoch)
    except Exception as log_error:
        if not mlflow_warning_shown:
            print(f"MLflow metric logging failed (training continues): {log_error}")
            mlflow_warning_shown = True

    # Print progress (the marker is computed before best_val_loss is updated below)
    if epoch % 5 == 0 or epoch < 10:
        print(f"{'‚úÖ' if val_loss < best_val_loss else '‚è≥'} Epoch {epoch+1:3d} | Loss: {val_loss:.4f} | MAE: {val_mae:.4f} | R¬≤: {val_r2:.4f} | Acc: {val_acc:.1f}%")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        best_model_state = copy.deepcopy(model.state_dict())
    else:
        patience_counter += 1

    if patience_counter >= PATIENCE:
        print(f"\n‚õî Early stopping at epoch {epoch+1}")
        break

# Restore the best-validation weights whether the loop stopped early or ran all
# epochs; otherwise a full-length run would keep the last epoch's weights.
model.load_state_dict(best_model_state)

print(f"\n‚úÖ Training completed!")
print(f"üìä Best validation loss: {best_val_loss:.4f}")
print(f"üìä Best R¬≤ Score: {max(training_history['val_r2_score']):.4f}")
print(f"üìä Best Accuracy (¬±2% threshold): {best_val_acc:.1f}%")
print(f"üìä Final Accuracy: {training_history['val_accuracy'][-1]:.1f}%")


üìä Generated 10773 training samples
   From 1539 edges √ó 2 day types √ó 5 time periods

üìä Data split:
   Train: 7541 | Val: 1615 | Test: 1617
   Target noise: ¬±8% added

üöÄ Starting training with Temporal GNN...

‚úÖ Epoch   1 | Loss: 0.0037 | MAE: 0.0481 | R¬≤: 0.5754 | Acc: 26.4%
‚è≥ Epoch   2 | Loss: 0.0037 | MAE: 0.0489 | R¬≤: 0.5721 | Acc: 25.3%
‚è≥ Epoch   3 | Loss: 0.0037 | MAE: 0.0482 | R¬≤: 0.5736 | Acc: 28.2%
‚è≥ Epoch   4 | Loss: 0.0039 | MAE: 0.0505 | R¬≤: 0.5489 | Acc: 25.8%
‚è≥ Epoch   5 | Loss: 0.0039 | MAE: 0.0503 | R¬≤: 0.5517 | Acc: 24.3%
‚è≥ Epoch   6 | Loss: 0.0037 | MAE: 0.0490 | R¬≤: 0.5683 | Acc: 26.6%
‚è≥ Epoch   7 | Loss: 0.0038 | MAE: 0.0493 | R¬≤: 0.5622 | Acc: 24.8%
‚è≥ Epoch   8 | Loss: 0.0041 | MAE: 0.0531 | R¬≤: 0.5225 | Acc: 22.9%
‚è≥ Epoch   9 | Loss: 0.0038 | MAE: 0.0495 | R¬≤: 0.5556 | Acc: 26.4%
‚è≥ Epoch  10 | Loss: 0.0037 | MAE: 0.0483 | R¬≤: 0.5700 | Acc: 26.5%
‚è≥ Epoch  11 | Loss: 0.0038 | MAE: 0.0483 | R¬≤: 0.5550 | Acc: 27.0%
üìâ Lea

## 7️⃣ Implement Routing Prediction System

In [57]:
# defaultdict is imported here because this cell uses it and no earlier cell
# imports it; without this line a fresh-kernel "Run All" fails with NameError.
from collections import defaultdict

# Build adjacency graph for visualization: node index -> list of outgoing edges
graph = defaultdict(list)
for idx, edge in edges_df.iterrows():
    src = location_id_to_idx[edge['origin_id']]
    dst = location_id_to_idx[edge['destination_id']]
    graph[src].append({
        'destination': dst,
        'edge_idx': idx,
        'mode': edge['mode'],
        'distance': edge['distance_km'],
        'duration': edge['duration_min'],
        'fare': edge['fare_lkr']
    })

# len(graph) counts only nodes that have at least one outgoing edge.
print(f"‚úÖ Built routing graph with {len(graph)} connected nodes")
print(f"‚úÖ Temporal model trained and ready for inference")
print(f"   Note: Model predicts reliability based on (origin, destination, day_type, time_period, mode)")


‚úÖ Built routing graph with 39 connected nodes
‚úÖ Temporal model trained and ready for inference
   Note: Model predicts reliability based on (origin, destination, day_type, time_period, mode)


In [60]:
from collections import defaultdict

# ============================================================================
# TEMPORAL INFERENCE API: Predict Reliability Using Trained GNN
# ============================================================================

def predict_reliability_temporal(origin_id, destination_id, day_type, time_period, mode, query_datetime=None):
    """
    Predict reliability score using the trained TemporalTransportGNN.

    Args:
        origin_id: Starting location ID (a key of location_id_to_idx)
        destination_id: Destination location ID (a key of location_id_to_idx)
        day_type: A key of DAY_TYPE_TO_ID ('regular' or 'weekend')
        time_period: A key of TIME_PERIOD_TO_ID ('early_morning', 'morning_peak',
            'midday', 'evening_peak', or 'night')
        mode: A key of MODE_TO_ID (the modes present in edges_df)
        query_datetime: Accepted for interface compatibility; it is not used in
            any computation here.

    Returns:
        ``{'error': <message>}`` when an input is not recognised. Otherwise a
        dictionary with the query echo, ``gnn_reliability_score``, the
        ``temporal_baseline`` metrics, ``temporal_baseline_source``
        ('service_metrics' when a matching row exists, 'default' when the
        hard-coded fallback was used), ``normalized_vs_baseline``, and, when a
        matching edge exists, that edge's service details.
    """
    # Validate inputs
    origin_idx = location_id_to_idx.get(origin_id)
    dest_idx = location_id_to_idx.get(destination_id)

    if origin_idx is None or dest_idx is None:
        return {'error': 'Invalid origin or destination location'}

    if day_type not in DAY_TYPE_TO_ID:
        return {'error': f'Invalid day_type. Must be one of: {list(DAY_TYPE_TO_ID.keys())}'}

    if time_period not in TIME_PERIOD_TO_ID:
        return {'error': f'Invalid time_period. Must be one of: {list(TIME_PERIOD_TO_ID.keys())}'}

    if mode not in MODE_TO_ID:
        return {'error': f'Invalid mode. Must be one of: {list(MODE_TO_ID.keys())}'}

    # Get model prediction for a batch containing just this query
    model.eval()
    with torch.no_grad():
        origin_tensor = torch.tensor([origin_idx], dtype=torch.long, device=device)
        dest_tensor = torch.tensor([dest_idx], dtype=torch.long, device=device)
        day_type_tensor = torch.tensor([DAY_TYPE_TO_ID[day_type]], dtype=torch.long, device=device)
        time_period_tensor = torch.tensor([TIME_PERIOD_TO_ID[time_period]], dtype=torch.long, device=device)
        mode_tensor = torch.tensor([MODE_TO_ID[mode]], dtype=torch.long, device=device)

        prediction = model(
            node_features_tensor,
            edge_index,
            origin_tensor,
            dest_tensor,
            day_type_tensor,
            time_period_tensor,
            mode_tensor
        )

        # Flatten first so both a 0-dim and a 1-element prediction are handled
        reliability_score = float(prediction.reshape(-1)[0].item())

    # Get baseline metrics for this temporal context
    baseline_df = service_metrics_df[
        (service_metrics_df['day_type'] == day_type) &
        (service_metrics_df['time_period'] == time_period)
    ]

    if len(baseline_df) > 0:
        temporal_baseline = {
            'reliability': float(baseline_df['reliability_baseline'].values[0]),
            'crowding': float(baseline_df['crowding_baseline'].values[0]),
            'availability': float(baseline_df['service_availability'].values[0])
        }
        baseline_source = 'service_metrics'
    else:
        # No row for this (day_type, time_period): fall back to fixed defaults
        # and say so in the result, so callers can tell measured from assumed.
        temporal_baseline = {'reliability': 0.75, 'crowding': 0.5, 'availability': 0.9}
        baseline_source = 'default'

    # Find the first edge matching origin, destination and mode with one
    # vectorized filter instead of scanning edges_df row by row on every call.
    matching_edges = edges_df[
        (edges_df['origin_id'] == origin_id) &
        (edges_df['destination_id'] == destination_id) &
        (edges_df['mode'] == mode)
    ]
    edge_info = matching_edges.iloc[0] if len(matching_edges) > 0 else None

    origin_name = nodes_df.loc[nodes_df['location_id'] == origin_id, 'name'].values[0]
    dest_name = nodes_df.loc[nodes_df['location_id'] == destination_id, 'name'].values[0]

    result = {
        'origin': origin_id,
        'origin_name': origin_name,
        'destination': destination_id,
        'destination_name': dest_name,
        'day_type': day_type,
        'time_period': time_period,
        'mode': mode,
        'gnn_reliability_score': reliability_score,
        'temporal_baseline': temporal_baseline,
        'temporal_baseline_source': baseline_source,
        'normalized_vs_baseline': reliability_score / temporal_baseline['reliability'] if temporal_baseline['reliability'] > 0 else 1.0
    }

    if edge_info is not None:
        result.update({
            'service_id': edge_info['service_id'],
            'operator': edge_info['operator'],
            'distance_km': float(edge_info['distance_km']),
            'duration_min': float(edge_info['duration_min']),
            'fare_lkr': float(edge_info['fare_lkr']),
            'is_active': bool(edge_info['is_active'])
        })

    return result

# Announce the API and list the identifiers it accepts.
print("‚úÖ Temporal Inference API initialized!")
print("   Function: predict_reliability_temporal(origin, destination, day_type, time_period, mode)")
print(f"   Day types: {[*DAY_TYPE_TO_ID]}")
print(f"   Time periods: {[*TIME_PERIOD_TO_ID]}")
print(f"   Modes: {[*MODE_TO_ID]}")


‚úÖ Temporal Inference API initialized!
   Function: predict_reliability_temporal(origin, destination, day_type, time_period, mode)
   Day types: ['regular', 'weekend']
   Time periods: ['early_morning', 'morning_peak', 'midday', 'evening_peak', 'night']
   Modes: ['bus', 'ride_hail', 'train']


## 8️⃣ Evaluate & Test - Find Best Method Between 2 Nodes

## 🎯 Predict Reliability for Individual Transport Methods/Edges

In [61]:
# ============================================================================
# Test: Predict Reliability Using Temporal GNN
# ============================================================================

print("="*80)
print("üéØ PREDICT RELIABILITY USING TEMPORAL GNN")
print("="*80)

# Get a test route: the first edge in edges_df
test_edge = edges_df.iloc[0]
origin_id = test_edge['origin_id']
dest_id = test_edge['destination_id']
mode = test_edge['mode']
origin_name = nodes_df[nodes_df['location_id'] == origin_id]['name'].values[0]
dest_name = nodes_df[nodes_df['location_id'] == dest_id]['name'].values[0]

print(f"\nüìç Route: {origin_name} ({origin_id}) ‚Üí {dest_name} ({dest_id})")
print(f"üöå Mode: {mode}")

# Temporal contexts to predict for. This list now drives the loop below;
# previously it was defined but unused and listed ("weekend", "afternoon"),
# a time period that is not a key of TIME_PERIOD_TO_ID.
test_cases = [
    ("regular", "morning_peak"),
    ("regular", "midday"),
    ("weekend", "morning_peak"),
]

print(f"\n{'Day Type':<12} | {'Time Period':<15} | {'GNN Score':<12} | {'Baseline':<10} | {'Normalized':<10}")
print("-" * 65)

for day_type, time_period in test_cases:
    result = predict_reliability_temporal(origin_id, dest_id, day_type, time_period, mode)

    if 'error' not in result:
        print(f"{day_type:<12} | {time_period:<15} | {result['gnn_reliability_score']:>10.4f} | {result['temporal_baseline']['reliability']:>8.1%} | {result['normalized_vs_baseline']:>8.2f}x")
    else:
        print(f"‚ùå Error: {result['error']}")

print("\n" + "="*80)
print("‚ú® Example result for midday regular day:")
result = predict_reliability_temporal(origin_id, dest_id, "regular", "midday", mode)
if 'error' not in result:
    print(f"   GNN Prediction: {result['gnn_reliability_score']:.4f}")
    print(f"   Baseline Reliability: {result['temporal_baseline']['reliability']:.1%}")
    print(f"   Normalized Score: {result['normalized_vs_baseline']:.2f}x")
    # Service details are present only when a matching edge was found
    if 'service_id' in result:
        print(f"   Service ID: {result['service_id']}")
        print(f"   Distance: {result['distance_km']:.1f} km")
        print(f"   Duration: {result['duration_min']:.0f} min")
        print(f"   Fare: {result['fare_lkr']:.0f} LKR")
print("="*80)


üéØ PREDICT RELIABILITY USING TEMPORAL GNN

üìç Route: Colombo (1) ‚Üí Negombo (5)
üöå Mode: bus

Day Type     | Time Period     | GNN Score    | Baseline   | Normalized
-----------------------------------------------------------------
regular      | morning_peak    |     0.6151 |    75.0% |     0.82x
regular      | midday          |     0.6564 |    80.0% |     0.82x
weekend      | morning_peak    |     0.6151 |    75.0% |     0.82x

‚ú® Example result for midday regular day:
   GNN Prediction: 0.6564
   Baseline Reliability: 80.0%
   Normalized Score: 0.82x
   Service ID: BUS_0001
   Distance: 31.4 km
   Duration: 50 min
   Fare: 184 LKR


In [62]:
# ============================================================================
# Test: Examples Using the Temporal Inference API
# ============================================================================

print("="*80)
print("üìä TEST: MULTIPLE ROUTES AND TEMPORAL SCENARIOS")
print("="*80)

# Test several routes with different temporal contexts.
# random_state pins the sampled routes so the printed table is reproducible
# across runs; reset_index() makes route_idx count 0, 1, 2 in the loop below.
test_routes = edges_df.sample(min(3, len(edges_df)), random_state=42).reset_index()

for route_idx, edge in test_routes.iterrows():
    origin_id = edge['origin_id']
    dest_id = edge['destination_id']
    mode = edge['mode']
    origin_name = nodes_df[nodes_df['location_id'] == origin_id]['name'].values[0]
    dest_name = nodes_df[nodes_df['location_id'] == dest_id]['name'].values[0]

    print(f"\n{'='*80}")
    print(f"Route {route_idx+1}: {origin_name} ‚Üí {dest_name}")
    print(f"Service: {edge['service_id']} | Mode: {mode} | Operator: {edge['operator']}")
    print(f"Distance: {edge['distance_km']:.1f} km | Duration: {edge['duration_min']:.0f} min | Fare: {edge['fare_lkr']:.0f} LKR")
    print(f"{'Day Type':<12} | {'Time Period':<15} | {'GNN Score':<12} | {'Baseline':<10} | {'Normalized':<10}")
    print("-" * 65)

    # Test across different temporal contexts
    for day_type in ["regular", "weekend"]:
        for time_period in ["morning_peak", "midday", "evening_peak"]:
            result = predict_reliability_temporal(origin_id, dest_id, day_type, time_period, mode)

            if 'error' not in result:
                print(f"{day_type:<12} | {time_period:<15} | {result['gnn_reliability_score']:>10.4f} | {result['temporal_baseline']['reliability']:>8.1%} | {result['normalized_vs_baseline']:>8.2f}x")
            else:
                # Error text is truncated to 20 characters to keep the table aligned
                print(f"{day_type:<12} | {time_period:<15} | ‚ùå {result['error'][:20]}")

print("\n" + "="*80)
print("‚úÖ Temporal inference tests complete!")
print("="*80)


üìä TEST: MULTIPLE ROUTES AND TEMPORAL SCENARIOS

Route 1: Kandy Railway Station ‚Üí Badulla Railway Station
Service: BUS_1534 | Mode: bus | Operator: SLTB
Distance: 57.0 km | Duration: 66 min | Fare: 351 LKR
Day Type     | Time Period     | GNN Score    | Baseline   | Normalized
-----------------------------------------------------------------
regular      | morning_peak    |     0.6151 |    75.0% |     0.82x
regular      | midday          |     0.6564 |    80.0% |     0.82x
regular      | evening_peak    |     0.5802 |    70.0% |     0.83x
weekend      | morning_peak    |     0.6151 |    75.0% |     0.82x
weekend      | midday          |     0.6564 |    75.0% |     0.88x
weekend      | evening_peak    |     0.5802 |    75.0% |     0.77x

Route 2: Matara ‚Üí Badulla
Service: TRAIN_0868 | Mode: train | Operator: SLR
Distance: 129.6 km | Duration: 173 min | Fare: 633 LKR
Day Type     | Time Period     | GNN Score    | Baseline   | Normalized
--------------------------------------------

In [63]:
# Test set evaluation (using clean targets and temporal parameters)
model.eval()
with torch.no_grad():
    # reshape(-1) keeps predictions 1-D even if the test split holds one sample
    test_predictions = model(
        node_features_tensor,
        edge_index,
        origins[test_indices],
        dests[test_indices],
        day_types[test_indices],
        time_periods[test_indices],
        modes[test_indices]
    ).reshape(-1)
    test_targets = targets[test_indices]  # Clean targets
    test_loss = criterion(test_predictions, test_targets).item()
    test_mae = torch.abs(test_predictions - test_targets).mean().item()
    test_r2 = calculate_r2_score(test_predictions, test_targets)
    test_acc = calculate_accuracy(test_predictions, test_targets, threshold=0.02) * 100  # share within 0.02, as a percentage

print("üìä Test Set Performance:")
print(f"   MSE: {test_loss:.4f}")
print(f"   MAE: {test_mae:.4f}")
print(f"   R¬≤ Score: {test_r2:.4f}")
print(f"   Accuracy (¬±2% threshold): {test_acc:.1f}%\n")

# Log test metrics. Tracking stays best-effort, but a failure is reported
# instead of being silently swallowed.
try:
    mlflow.log_metric("test_loss", test_loss)
    mlflow.log_metric("test_mae", test_mae)
    mlflow.log_metric("test_r2_score", test_r2)
    mlflow.log_metric("test_accuracy_percent", test_acc)
except Exception as log_error:
    print(f"MLflow test-metric logging failed: {log_error}")

print("‚úÖ Test evaluation complete!")


üìä Test Set Performance:
   MSE: 0.0038
   MAE: 0.0478
   R¬≤ Score: 0.5729
   Accuracy (¬±2% threshold): 30.3%

‚úÖ Test evaluation complete!


## Save Model & Log to DagsHub

In [None]:
# Save model to multiple locations
import os

# Save to local model folder
local_model_dir = "../model"
os.makedirs(local_model_dir, exist_ok=True)
local_model_path = os.path.join(local_model_dir, "transport_gnn_routing.pth")

# Save complete checkpoint with all artifacts (matches service expectations).
# The architecture sizes are read from the model itself instead of repeated as
# literals, so the checkpoint always describes the network whose weights it holds.
# NOTE(review): the checkpoint pickles scikit-learn encoders alongside tensors;
# loaders on recent PyTorch versions presumably need torch.load(..., weights_only=False)
# - confirm against the consuming service.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'node_features': node_features_tensor.cpu(),
    'edge_index': edge_index.cpu(),
    'region_encoder': region_encoder,
    'type_encoder': type_encoder,
    'mode_encoder': mode_encoder,
    'scaler': coords_scaler,
    # Critical: Save location mapping used during training
    'location_id_to_idx': location_id_to_idx,
    # Temporal mappings
    'mode_to_id': MODE_TO_ID,
    'day_type_to_id': DAY_TYPE_TO_ID,
    'time_period_to_id': TIME_PERIOD_TO_ID,
    'num_day_types': model.day_type_embedding.num_embeddings,
    'num_time_periods': model.time_period_embedding.num_embeddings,
    'num_modes': model.mode_embedding.num_embeddings,
    'hidden_dim': model.conv1.out_channels
}

torch.save(checkpoint, local_model_path)
print(f"‚úÖ Model saved locally with all artifacts: {local_model_path}")
print(f"   Includes: model_state_dict, node_features, edge_index, encoders, location_id_to_idx, temporal mappings")
print(f"   Architecture: TemporalTransportGNN")
print(f"   Location mapping: {len(location_id_to_idx)} locations")

# Save to /tmp for MLflow
temp_model_path = "/tmp/transport_gnn_routing.pth"
torch.save(checkpoint, temp_model_path)

# Log to MLflow/DagsHub. Best-effort: report the reason when the upload fails.
try:
    mlflow.log_artifact(temp_model_path, artifact_path="models")
    print(f"‚úÖ Model logged to DagsHub/MLflow")
except Exception as log_error:
    print(f"‚ö†Ô∏è  MLflow artifact logging skipped: {log_error}")
# Log hyperparameters
try:
    mlflow.log_params({
        "learning_rate": LEARNING_RATE,
        "epochs": epoch + 1,
        "batch_size": BATCH_SIZE,
        "hidden_dim": HIDDEN_DIM,
        "num_nodes": len(nodes_df),
        "num_edges": len(edges_df),
        "num_training_samples": len(training_data),
        "num_day_types": 2,
        "num_time_periods": 5,
        "num_modes": 3
    })
    print("‚úÖ Hyperparameters logged to DagsHub")
except:
    print("‚ÑπÔ∏è  Hyperparameters not logged")

# End run
mlflow.end_run()
print("\nüéâ Training complete! Temporal GNN model saved with CSV data integration.")


‚úÖ Model saved locally with all artifacts: ../model/transport_gnn_routing.pth
   Includes: model_state_dict, node_features, edge_index, encoders, temporal mappings
   Architecture: TemporalTransportGNN
‚úÖ Model logged to DagsHub/MLflow
‚úÖ Hyperparameters logged to DagsHub

üéâ Training complete! Temporal GNN model saved with CSV data integration.
