# GNN Training (Fixed Version)

In [1]:
import os
import torch
import gc
from tqdm import tqdm

from gnn_model.HeteroGAT import HeteroGAT
from gnn_model.trainer import Trainer
from data_processing.data_loader import DataLoader
from data_processing.config import DataProcessingConfig

## Data Preparation

In [2]:
# Scenarios
results_dir = os.path.join('..', 'studies', 'manhattan_case_study', 'results')
scenario_names = ['gnn_ex1', # 'gnn_ex2', 'gnn_ex3', 'gnn_ex4'
                  ]
scenarios = [os.path.join(results_dir, sc) for sc in scenario_names]

# Create config
config = DataProcessingConfig(
    sim_duration=86400,
)

# Set to True only when data needs to be reprocessed
overwrite = True

In [None]:
# Safe data loading with error handling
def load_data_safely():
    try:
        gc.collect()  # Clean up memory before loading
        torch.cuda.empty_cache()  # Clear GPU cache if available
        
        loader = DataLoader(scenarios, config, overwrite=overwrite)
        data, masks = loader.load_data()
        
        # Validate loaded data
        if not data or len(data) == 0:
            raise ValueError("No data was loaded")
            
        return data, masks
    except Exception as e:
        print(f"Error loading data: {str(e)}")
        return None, None

data, masks = load_data_safely()

  0%|          | 0/1 [00:00<?, ?it/s]

Error processing timestep 44460: 3261462
Failed timesteps: [44460]


In [None]:
# Validate loaded data
if data is not None:
    print(f"Successfully loaded {len(data)} data points")
    print(f"Train/Val/Test split: {sum(masks['train'])}/{sum(masks['val'])}/{sum(masks['test'])}")
else:
    print("Failed to load data. Please check the error message above.")

Successfully loaded 2879 data points


TypeError: list indices must be integers or slices, not str

## Model Setup

In [None]:
# Model parameters
num_classes = 2
hidden_channels = 64
epochs = 200
batch_size = 32

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Initialize model with error handling
try:
    model = HeteroGAT(hidden_channels, num_classes).to(device=device)
    print("Model initialized successfully")
except Exception as e:
    print(f"Error initializing model: {str(e)}")

In [None]:
# Training setup
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Initialize trainer with error handling
if data is not None and model is not None:
    try:
        trainer = Trainer(data, device, masks, config, batch_size=batch_size)
        print("Trainer initialized successfully")
    except Exception as e:
        print(f"Error initializing trainer: {str(e)}")

In [None]:
# Training with progress tracking and error handling
if 'trainer' in locals():
    try:
        trainer.train(model, criterion, optimizer)
    except Exception as e:
        print(f"Error during training: {str(e)}")
else:
    print("Trainer not initialized. Cannot proceed with training.")