# Traffic Prediction Modeling
This notebook demonstrates how to load and preprocess the METR-LA traffic dataset, split it into training, test, and validation sets, and train one of three models (GCN-LSTM, GAT-LSTM, or STGAT) using PyTorch Geometric and PyTorch Lightning. Adjust the configuration parameters as needed.

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric import nn, data
import pytorch_lightning as L
import matplotlib.pyplot as plt
import pandas as pd
import os
from pathlib import Path
import shutil
import math
import sys

from typing import List, Tuple, Union

from CustomDatasets import TrafficDataset
from CustomModels import GATLSTMModel, GCNLSTMModel, STGATModel

In [None]:
def split_dataset(
    dataset: TrafficDataset,
    possible_slot: int,
    split_days: tuple,
):
    """Slice ``dataset`` into consecutive train, test and validation parts.

    Args:
        dataset: Any object supporting slice indexing.
        possible_slot: Number of samples that make up one day.
        split_days: Three values ``(train_days, test_days, val_days)``.
            Only the first two are used to compute boundaries; the third
            is unpacked but ignored.

    Returns:
        A ``(train, test, val)`` tuple, in that order. ``val`` is whatever
        remains after the train and test slices.
    """
    train_days, test_days, _unused_val_days = split_days

    # Boundaries are expressed in samples, not days.
    train_end = int(train_days * possible_slot)
    test_end = train_end + int(test_days * possible_slot)

    return (
        dataset[:train_end],
        dataset[train_end:test_end],
        dataset[test_end:],
    )

In [None]:
# Configuration
# Remove old processed data (presumably so TrafficDataset regenerates it with
# the settings below -- confirm against CustomDatasets).
# shutil.rmtree(..., ignore_errors=True) is the portable equivalent of
# `rm -rf`: it works outside IPython and on Windows, and stays silent when
# the directory does not exist.
shutil.rmtree("metr-la-dataset/processed", ignore_errors=True)

# Number of 5-minute slots in one day: (24 h * 60 min) / 5 min = 288.
POSSIBLE_SLOT = (24 * 60) // 5
config = {
    "F": 12,  # passed to the models as their input channel count
    "H": 12,  # adjust horizon as needed (passed as prediction_time_step)
    "N_DAYS": 44,
    "N_DAY_SLOT": POSSIBLE_SLOT,
    "BATCH_SIZE": 64,
    "LR": 2e-4,
    "WEIGHT_DECAY": 5e-4,
}
# Samples per day handed to split_dataset. Presumably the number of sliding
# windows of length F + H that fit in one day -- confirm against
# TrafficDataset.
config["N_SLOT"] = config["N_DAY_SLOT"] - (config["H"] + config["F"]) + 1

In [None]:
# Load dataset (choose model: 'gcn', 'gat', or 'stgat')
import gc

model_name = 'gcn'  # choose 'gat' or 'stgat'
# The attention-based models get the dataset's GAT variant.
gat_version = model_name in ['gat', 'stgat']

# First construction (presumably triggers preprocessing, since the processed
# directory was removed earlier -- confirm against TrafficDataset).
dataset = TrafficDataset(config, root='metr-la-dataset', gat_version=gat_version)

# Clean memory and reload.
# The reference must be dropped *before* collecting; otherwise the first
# instance stays alive until `dataset` is rebound and both copies coexist.
del dataset
gc.collect()
torch.cuda.empty_cache()

dataset = TrafficDataset(config, root='metr-la-dataset', gat_version=gat_version)

# Split: split_dataset returns (train, test, val) in that order.
train, test, val = split_dataset(dataset, config['N_SLOT'], (30, 9, 5))

# Settings shared by all three loaders. Pinned memory only helps host-to-GPU
# copies, so request it only when CUDA is present (avoids a warning on CPU).
loader_kwargs = dict(
    batch_size=config['BATCH_SIZE'],
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
train_loader = tg.loader.DataLoader(train, shuffle=True, **loader_kwargs)
val_loader = tg.loader.DataLoader(val, shuffle=False, **loader_kwargs)
test_loader = tg.loader.DataLoader(test, shuffle=False, **loader_kwargs)

print(f'Train: {len(train)}, Val: {len(val)}, Test: {len(test)}')

In [None]:
# Training and evaluation
callbacks = [
    L.callbacks.Timer(),
    # Stop when the logged 'val_loss' has not improved for 10 epochs.
    L.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode='min')
]

# Build the model selected by `model_name`. All three share the same LSTM
# sizes, dropout, optimiser settings and prediction horizon.
if model_name == 'gcn':
    model = GCNLSTMModel(
        in_channel=config['F'], gcn_hidden_channel=config['F'],
        n_nodes=dataset.n_node, drop_out=0.2,
        lstm_dim=[32, 128], prediction_time_step=config['H'],
        lr=config['LR'], weight_decay=config['WEIGHT_DECAY'],
        batch_size=config['BATCH_SIZE']
    )
elif model_name == 'gat':
    model = GATLSTMModel(
        in_channel=config['F'], gat_out_channel=config['F'], n_nodes=dataset.n_node,
        att_heads=8, drop_out=0.2, lstm_dim=[32, 128],
        prediction_time_step=config['H'], lr=config['LR'],
        weight_decay=config['WEIGHT_DECAY'], batch_size=config['BATCH_SIZE'],
        concat_gat=True
    )
elif model_name == 'stgat':
    model = STGATModel(
        in_channel=config['F'], out_channel=config['F'], n_nodes=dataset.n_node,
        att_head_nodes=8, drop_out=0.2, lstm_dim=[32, 128],
        prediction_time_step=config['H'], lr=config['LR'],
        weight_decay=config['WEIGHT_DECAY'], batch_size=config['BATCH_SIZE']
    )
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'gcn', 'gat', or 'stgat'"
    )

use_cuda = torch.cuda.is_available()
trainer = L.Trainer(
    accelerator='cuda' if use_cuda else 'cpu',
    # Always a single device. `devices=None` is rejected by the CPU
    # accelerator in Lightning 2.x, which requires an int > 0.
    devices=1,
    callbacks=callbacks,
    # Half precision only on GPU; full precision on CPU.
    precision=16 if use_cuda else 32,
    max_epochs=30, default_root_dir=f'./logs/{model_name}'
)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

In [None]:
# Plot training metrics
metrics = model.history
# NOTE(review): assumes every history series has exactly
# model.current_epoch entries -- confirm against the model's logging.
epochs = range(model.current_epoch)

# One figure per metric:
# (train key, val key, train label, val label, y-axis label, title)
curve_specs = [
    ('loss', 'val_loss', 'Train Loss', 'Val Loss', 'MSE Loss', 'MSE'),
    ('train_MAE', 'val_MAE', 'Train MAE', 'Val MAE', 'MAE', 'MAE'),
    ('train_RMSE', 'val_RMSE', 'Train RMSE', 'Val RMSE', 'RMSE', 'RMSE'),
]

for train_key, val_key, train_label, val_label, y_label, title in curve_specs:
    plt.figure(figsize=(10,5))
    plt.plot(epochs, metrics[train_key], label=train_label)
    plt.plot(epochs, metrics[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid()
    plt.title(title)