# GNN Training

In [1]:
import os
import gc

from gnn_model.HeteroGAT import HeteroGAT
from gnn_model.Trainer import Trainer
from data_processing.data_loader import DataLoader
from data_processing.config import DataProcessingConfig
from utils import visualize_graph

In [2]:
import torch.serialization
from torch_geometric.data.storage import BaseStorage

# Allow-list PyG's BaseStorage with torch's weights-only unpickler so that pickled
# graph objects containing it can be deserialized (presumably the project DataLoader
# reads its cached graphs via torch.load — TODO confirm).
_safe_storage_classes = [BaseStorage]
torch.serialization.add_safe_globals(_safe_storage_classes)

## Data Preparation

In [3]:
# Scenarios
case_study = 'manhattan_case_study'
# case_study = 'GNN_study'  # alternative study (its scenarios were named e.g. 'test_small_1')
results_dir = os.path.join('..', 'studies', case_study, 'results')

# Manhattan scenarios to load, by number. Scenarios 4-12 were listed (disabled) in
# earlier runs; add their numbers here to include them.
active_scenario_ids = (1, 2, 3)
scenario_names = [f'test_manhattan_scenario_{idx}' for idx in active_scenario_ids]

# Absolute-ish paths (relative to this notebook) of each scenario's results folder.
scenarios = [os.path.join(results_dir, name) for name in scenario_names]

# Data-processing configuration built from library defaults; keyword overrides such as
# sim_duration=1800 can be passed here when needed.
config = DataProcessingConfig()

# Reprocess the raw scenarios only when True; otherwise previously processed data is reused.
overwrite = False

# Edge balancing. edge_balance_ratio is the positive:negative edge ratio and is
# presumably only consulted by the DataLoader when balance_edges is True — TODO confirm.
balance_edges = False
edge_balance_ratio = 1.0

# Positive-class weight handed to the Trainer (presumably a BCE-with-logits pos_weight,
# where 1.0 means unweighted — TODO confirm against Trainer).
pos_weight = 1.0

In [4]:
# Load the processed graphs plus split masks, validating both before anything downstream uses them.
def load_data_safely():
    """Load the scenario graphs and their train/val/test masks.

    Reads the module-level settings ``scenarios``, ``config``, ``overwrite``,
    ``balance_edges`` and ``edge_balance_ratio`` defined in the config cell.

    Returns:
        tuple: ``(data, masks)`` where ``data`` is the first value returned by
        ``DataLoader.load_data()`` and ``masks`` is a list of the remaining
        return values (the train, val and test masks, in that order).

    Raises:
        ValueError: if no data was loaded, or if ``load_data()`` did not return
            exactly three masks.
    """
    # Reclaim unreachable objects (e.g. from an earlier load) before materialising new graphs.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    loader = DataLoader(
        scenarios,
        config,
        overwrite=overwrite,
        balance_edges=balance_edges,
        edge_balance_ratio=edge_balance_ratio,
    )
    data, *masks = loader.load_data()

    # `not data` already covers both None and an empty sequence.
    if not data:
        raise ValueError("No data was loaded")
    # Downstream cells unpack exactly (train, val, test); fail here with a clear message
    # instead of letting a missing split surface later as sum(None) or a bad Trainer input.
    if len(masks) != 3:
        raise ValueError(f"Expected 3 split masks (train/val/test), got {len(masks)}")

    return data, masks
    
data, masks = load_data_safely()
# Unpack directly: a missing split must fail here, not be replaced by None and
# surface later as sum(None) or as an invalid input to the Trainer.
train_masks, val_masks, test_masks = masks

Loading scenarios: 100%|██████████| 3/3 [00:01<00:00,  2.08it/s]

Scenario split: Train=1, Val=1, Test=1 scenarios
Timestep split: Train=1439, Val=1440, Test=1440 timesteps





In [5]:
# Summarise the loaded dataset. load_data_safely() raises when nothing was loaded,
# so no "data is None" fallback is needed (it could never be reached).
n_train, n_val, n_test = (sum(mask) for mask in (train_masks, val_masks, test_masks))
print(f"Successfully loaded {len(data)} data points")
print(f"Train/Val/Test split: {n_train}/{n_val}/{n_test}")

Successfully loaded 4319 data points
Train/Val/Test split: 1439/1440/1440


## Model Setup

In [6]:
# Model hyperparameters
num_classes = 1         # binary classification: a single logit passed through a sigmoid
hidden_channels = 128   # width of the GNN hidden layers
epochs = 200            # upper bound only; early stopping may end training sooner
batch_size = 16         # graphs per mini-batch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


## Model Initialization

In [7]:
# Build the HeteroGAT model, then move its parameters to the selected device.
# (No error handling here: construction failures propagate as-is.)
model = HeteroGAT(hidden_channels, num_classes)
model = model.to(device=device)
print(f"Model initialized with {hidden_channels} hidden channels and {num_classes} output channel")

Model initialized with 128 hidden channels and 1 output channel


In [8]:
# Optimizer: Adam with a small L2 penalty via weight_decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

# The Trainer receives pos_weight, presumably to weight positives in its BCE loss — TODO confirm.
trainer = Trainer(
    data,
    device,
    masks,
    config,
    batch_size=batch_size,
    epochs=epochs,
    pos_weight=pos_weight,
)

Filtered 1 empty graphs from dataset

Training Configuration
--------------------------------------------------------------------------------
Batch Size:          16
Max Epochs:          200
Device:              cpu
Pos Weight:          1.0000
--------------------------------------------------------------------------------


In [9]:
# NOTE(review): this whole cell is commented out, so training never runs and `model`
# keeps its initial weights for any later use. Before re-enabling it:
#   - torch.serialization.add_safe_globals() expects the class/function objects
#     themselves (or (object, name) tuples), not dotted-path strings, so the string
#     argument below must be replaced by the numpy scalar object before this works.
#   - the broad `except Exception` only prints the error, so later cells would run
#     against an untrained model; consider letting the exception propagate.
# # Training with progress tracking and error handling
# try:
#     # Add numpy scalar to safe globals before training
#     import torch.serialization
#     torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar'])
    
#     # Start training
#     trainer.train(model, optimizer)
# except Exception as e:
#     print(f"Error during training: {str(e)}")
#     # Print more detailed error information
#     import traceback
#     traceback.print_exc()

## Baselines: Random Forest and XGBoost Edge Classifiers

In [10]:
from sklearn.ensemble import RandomForestClassifier

# Tabular (X, y) train/val/test splits for the vr_graph edges, ready for scikit-learn.
# NOTE(review): unlike load_data_safely(), this loader is built without
# overwrite/balance_edges/edge_balance_ratio — confirm whether the RF baseline should
# follow those settings.
loader = DataLoader(scenarios, config)
(X_train_vr, y_train_vr), (X_val_vr, y_val_vr), (X_test_vr, y_test_vr) = loader.get_edge_classification_data_for_rf(edge_type='vr_graph')

# Seeded so bootstrap sampling and feature subsampling are reproducible on Restart & Run All.
# n_jobs=-1 fits trees in parallel, which matters for the ~3.8M-row edge sets.
clf_vr = RandomForestClassifier(random_state=42, n_jobs=-1)
clf_vr.fit(X_train_vr, y_train_vr)
print("Validation accuracy:", clf_vr.score(X_val_vr, y_val_vr))

Validation accuracy: 0.9954856318387989


In [24]:
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Hard labels and positive-class (label 1) scores for the VR random forest on the validation split.
y_val_pred_vr = clf_vr.predict(X_val_vr)
y_val_proba_vr = None
if hasattr(clf_vr, "predict_proba"):
    y_val_proba_vr = clf_vr.predict_proba(X_val_vr)[:, 1]

print("Validation Metrics:")
# zero_division=0 reports 0 (instead of warning) when a class is never predicted.
for _label, _metric, _extra in (
    ("Accuracy:", accuracy_score, {}),
    ("Precision:", precision_score, {"zero_division": 0}),
    ("Recall:", recall_score, {"zero_division": 0}),
    ("F1 Score:", f1_score, {"zero_division": 0}),
):
    print(_label, _metric(y_val_vr, y_val_pred_vr, **_extra))
if y_val_proba_vr is not None:
    print("ROC AUC:", roc_auc_score(y_val_vr, y_val_proba_vr))
print("Confusion Matrix:\n", confusion_matrix(y_val_vr, y_val_pred_vr))
print("Classification Report:\n", classification_report(y_val_vr, y_val_pred_vr, zero_division=0))

Validation Metrics:
Accuracy: 0.9983148114292193
Precision: 0.9802828151347514
Recall: 0.9279333520022403
F1 Score: 0.9533900133787925
ROC AUC: 0.991741851883105
Confusion Matrix:
 [[3772514    1333]
 [   5147   66273]]
Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00   3773847
           1       0.98      0.93      0.95     71420

    accuracy                           1.00   3845267
   macro avg       0.99      0.96      0.98   3845267
weighted avg       1.00      1.00      1.00   3845267



In [None]:
import matplotlib.pyplot as plt

# Distribution of the random forest's validation scores for class 1 on VR edges.
# Counts are log-scaled: class 0 outnumbers class 1 by roughly 50:1 (see the support
# column in the metrics output), so on a linear axis the positive-side mass is invisible.
fig, ax = plt.subplots()
ax.hist(y_val_proba_vr, bins=50, log=True)
ax.set(
    title="Random forest (vr_graph): validation predicted probabilities",
    xlabel="Predicted probability of class 1",
    ylabel="Edge count (log scale)",
)
plt.show()

In [11]:
from data_processing.data_loader import DataLoader
from sklearn.ensemble import RandomForestClassifier  # imported here so this cell does not rely on an earlier one

# Same random-forest baseline, now on the rr_graph edge set.
loader = DataLoader(scenarios, config)
(X_train_rr, y_train_rr), (X_val_rr, y_val_rr), (X_test_rr, y_test_rr) = loader.get_edge_classification_data_for_rf(edge_type='rr_graph')

# Seeded for reproducibility; n_jobs=-1 fits trees in parallel.
clf_rr = RandomForestClassifier(random_state=42, n_jobs=-1)
clf_rr.fit(X_train_rr, y_train_rr)
print("Validation accuracy:", clf_rr.score(X_val_rr, y_val_rr))

In [None]:
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Hard labels and positive-class (label 1) scores for the RR random forest on the validation split.
y_val_pred_rr = clf_rr.predict(X_val_rr)
y_val_proba_rr = None
if hasattr(clf_rr, "predict_proba"):
    y_val_proba_rr = clf_rr.predict_proba(X_val_rr)[:, 1]

print("Validation Metrics:")
# zero_division=0 reports 0 (instead of warning) when a class is never predicted.
for _label, _metric, _extra in (
    ("Accuracy:", accuracy_score, {}),
    ("Precision:", precision_score, {"zero_division": 0}),
    ("Recall:", recall_score, {"zero_division": 0}),
    ("F1 Score:", f1_score, {"zero_division": 0}),
):
    print(_label, _metric(y_val_rr, y_val_pred_rr, **_extra))
if y_val_proba_rr is not None:
    print("ROC AUC:", roc_auc_score(y_val_rr, y_val_proba_rr))
print("Confusion Matrix:\n", confusion_matrix(y_val_rr, y_val_pred_rr))
print("Classification Report:\n", classification_report(y_val_rr, y_val_pred_rr, zero_division=0))

In [15]:
from sklearn.metrics import classification_report
from xgboost import XGBClassifier

# Gradient-boosted trees with library defaults as a second VR-edge baseline;
# fit() returns the estimator itself, so construction and training chain.
xgb_vr = XGBClassifier().fit(X_train_vr, y_train_vr)

# Per-class precision / recall / F1 on the validation split.
y_pred_vr = xgb_vr.predict(X_val_vr)
print(classification_report(y_val_vr, y_pred_vr))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00   3773847
           1       0.96      0.91      0.94     71420

    accuracy                           1.00   3845267
   macro avg       0.98      0.96      0.97   3845267
weighted avg       1.00      1.00      1.00   3845267



In [16]:
# XGBoost (VR edges) report on each split in turn. The training-split report shows
# fit quality; keep test-split scores for final reporting rather than model tuning.
_xgb_vr_runs = (
    ("train", "XGBoost - Training set classification report:", X_train_vr, y_train_vr),
    ("val", "XGBoost - Validation set classification report:", X_val_vr, y_val_vr),
    ("test", "XGBoost - Test set classification report:", X_test_vr, y_test_vr),
)
_xgb_vr_preds = {}
for _key, _header, _features, _labels in _xgb_vr_runs:
    _xgb_vr_preds[_key] = xgb_vr.predict(_features)
    print(_header)
    print(classification_report(_labels, _xgb_vr_preds[_key]))

y_train_pred_xgb_vr = _xgb_vr_preds["train"]
y_val_pred_xgb_vr = _xgb_vr_preds["val"]
y_test_pred_xgb_vr = _xgb_vr_preds["test"]

XGBoost - Training set classification report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00   3740710
           1       0.97      0.92      0.94     72266

    accuracy                           1.00   3812976
   macro avg       0.99      0.96      0.97   3812976
weighted avg       1.00      1.00      1.00   3812976

XGBoost - Validation set classification report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00   3773847
           1       0.96      0.91      0.94     71420

    accuracy                           1.00   3845267
   macro avg       0.98      0.96      0.97   3845267
weighted avg       1.00      1.00      1.00   3845267

XGBoost - Test set classification report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00   3830959
           1       0.98      0.83      0.90     70923

    accuracy                           1.00   

In [None]:
from sklearn.metrics import classification_report
from xgboost import XGBClassifier

# XGBoost baseline with library defaults for the rr_graph edges;
# fit() returns the estimator itself, so construction and training chain.
xgb_rr = XGBClassifier().fit(X_train_rr, y_train_rr)

# Per-class precision / recall / F1 on the validation split.
y_pred_rr = xgb_rr.predict(X_val_rr)
print(classification_report(y_val_rr, y_pred_rr))

In [None]:
# XGBoost (RR edges) report on each split in turn. The training-split report shows
# fit quality; keep test-split scores for final reporting rather than model tuning.
_xgb_rr_runs = (
    ("train", "XGBoost - Training set classification report:", X_train_rr, y_train_rr),
    ("val", "XGBoost - Validation set classification report:", X_val_rr, y_val_rr),
    ("test", "XGBoost - Test set classification report:", X_test_rr, y_test_rr),
)
_xgb_rr_preds = {}
for _key, _header, _features, _labels in _xgb_rr_runs:
    _xgb_rr_preds[_key] = xgb_rr.predict(_features)
    print(_header)
    print(classification_report(_labels, _xgb_rr_preds[_key]))

y_train_pred_xgb_rr = _xgb_rr_preds["train"]
y_val_pred_xgb_rr = _xgb_rr_preds["val"]
y_test_pred_xgb_rr = _xgb_rr_preds["test"]

## Visualization

Let's visualize a sample graph from our dataset and analyze model predictions.

### Graph Visualization Functions

The `visualize_graph` helper (imported from `utils`) renders a graph from our heterogeneous dataset using the following conventions:

1. Node visualization:
   - Vehicles: red nodes
   - Requests: turquoise nodes

2. Edge visualization:
   - True assignments: solid red lines
   - Non-assignments: dotted gray lines
   - Predicted assignments: semi-transparent blue lines (with probability scores)

3. Additional features:
   - Node labels (V for vehicles, R for requests)
   - Edge probability labels for predicted assignments
   - Comprehensive legend
   - Force-directed layout for clear visualization

In [17]:
# NOTE(review): the training cell is commented out, so `model` here still holds its
# initial (untrained) weights — any predicted-assignment overlay would not reflect a trained GNN.
# visualize_graph(data, graph_idx=0, model=model, device=device)