In [1]:
import numpy as np 
import pandas as pd
import matplotlib as plt

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import ChebConv  # Chebyshev graph convolution

# Spatio-temporal convolution block
class STConvBlock(nn.Module):
    """Gated temporal conv -> Chebyshev graph conv -> gated temporal conv -> batch norm.

    Args:
        in_channels: Number of feature channels in the input tensor.
        out_channels: Number of feature channels produced by the block.
        K: Used both as the temporal kernel width and as the ChebConv filter
            size. With ``padding=K // 2`` only an odd ``K`` preserves the
            number of time steps (an even ``K`` lengthens the time axis by
            one per temporal convolution).
        num_nodes: Stored on the instance; not otherwise used by the block.

    Shapes:
        input:  ``[batch_size, in_channels, num_nodes, time_steps]``
        output: ``[batch_size, out_channels, num_nodes, time_steps]`` (odd ``K``)
    """

    def __init__(self, in_channels, out_channels, K, num_nodes):
        super(STConvBlock, self).__init__()
        # F.glu splits the channel axis in two (values * sigmoid(gates)), so each
        # temporal convolution must emit 2 * out_channels for the gated result to
        # carry out_channels features.
        self.temporal1 = nn.Conv2d(in_channels, 2 * out_channels, kernel_size=(1, K), padding=(0, K // 2))
        self.graph_conv = ChebConv(out_channels, out_channels, K)
        self.temporal2 = nn.Conv2d(out_channels, 2 * out_channels, kernel_size=(1, K), padding=(0, K // 2))
        self.num_nodes = num_nodes
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, edge_index):
        """Apply the block to ``x`` using the graph described by ``edge_index``."""
        # x: [batch_size, channels, num_nodes, time_steps]
        x = F.glu(self.temporal1(x), dim=1)  # First gated temporal conv -> out_channels
        x = x.permute(0, 3, 2, 1).contiguous()  # [batch_size, time_steps, num_nodes, channels]
        batch_size, time_steps, num_nodes, channels = x.size()

        # Graph convolution over each time step: fold time into the leading axis so
        # every (sample, time step) pair is one set of node features on the same graph.
        # NOTE(review): relies on ChebConv accepting [*, num_nodes, channels] input with a
        # single shared edge_index (static-graph mode) - confirm for the installed PyG version.
        x = x.reshape(batch_size * time_steps, num_nodes, channels)
        x = self.graph_conv(x, edge_index)
        # Infer the channel count from the graph-conv output instead of reusing the input's.
        x = x.reshape(batch_size, time_steps, num_nodes, -1).permute(0, 3, 2, 1)

        x = F.glu(self.temporal2(x), dim=1)  # Second gated temporal conv -> out_channels
        return self.bn(x)  # Batch norm over the channel axis

# Full STGCN Model
class STGCN(nn.Module):
    """Stack of ``STConvBlock`` layers followed by a 1x1 convolution to one output channel.

    Args:
        in_channels: Feature channels of the model input.
        out_channels: Feature channels produced by every block.
        K: Kernel / filter size forwarded to each ``STConvBlock``.
        num_nodes: Number of graph nodes, forwarded to each ``STConvBlock``.
        num_blocks: How many ``STConvBlock`` layers to stack.
    """

    def __init__(self, in_channels, out_channels, K, num_nodes, num_blocks):
        super(STGCN, self).__init__()
        # Only the first block sees the raw input; every later block consumes the
        # out_channels features emitted by its predecessor.
        self.blocks = nn.ModuleList([
            STConvBlock(in_channels if idx == 0 else out_channels, out_channels, K, num_nodes)
            for idx in range(num_blocks)
        ])
        self.final_conv = nn.Conv2d(out_channels, 1, kernel_size=(1, 1))  # Final output layer

    def forward(self, x, edge_index):
        """Run all blocks, then project to a single channel.

        Returns the projection with the singleton channel axis (dim 1) removed.
        """
        for block in self.blocks:
            x = block(x, edge_index)
        x = self.final_conv(x)
        # Drop only the channel axis; a bare squeeze() would also remove the batch
        # axis whenever a batch holds a single sample.
        return x.squeeze(1)



In [None]:
from torch_geometric.data import DataLoader
from sklearn.metrics import mean_absolute_error

# Assumes `traffic_data`, a PyTorch Geometric dataset, has already been loaded;
# it is not defined in this cell.
# Every batch it yields must expose `x` (node features), `edge_index`
# (the edges of the graph) and `y` (the targets).

def train(model, optimizer, data_loader):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Callable module invoked as ``model(data.x, data.edge_index)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        data_loader: Iterable of batches exposing ``x``, ``edge_index`` and ``y``.
            It does not need to support ``len()``.

    Returns:
        The MSE loss averaged over the batches that were processed.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in data_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Count batches as they are consumed so the average is correct even for
        # iterables that have no (or an inexact) __len__.
        num_batches += 1
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute a mean training loss")
    return total_loss / num_batches

def test(model, data_loader):
    model.eval()
    preds, labels = [], []
    for data in data_loader:
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            preds.append(out.cpu().numpy())
            labels.append(data.y.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    labels = np.concatenate(labels, axis=0)
    return mean_absolute_error(labels, preds)

# Example setup
model = STGCN(in_channels=1, out_channels=64, K=3, num_nodes=228, num_blocks=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Hold out the tail of the dataset for evaluation so the reported MAE is not
# measured on samples the model was trained on.
# NOTE(review): assumes `traffic_data` supports len() and slicing, and that its
# samples are stored in chronological order - confirm against the dataset.
TRAIN_FRACTION = 0.8
num_train = int(TRAIN_FRACTION * len(traffic_data))
train_loader = DataLoader(traffic_data[:num_train], batch_size=64, shuffle=True)
test_loader = DataLoader(traffic_data[num_train:], batch_size=64, shuffle=False)

# Training loop
for epoch in range(50):
    train_loss = train(model, optimizer, train_loader)
    test_mae = test(model, test_loader)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Test MAE: {test_mae}")


In [7]:

import torch
from torch_geometric.data import Data

# Placeholders: each name below is bound to the literal Ellipsis object (`...`)
# until real loading code replaces it.
adj_matrix = ...  # Load adjacency matrix W from PeMSD7
features = ...    # Load node features X, such as traffic flow or speed data
labels = ...      # Load labels, which might be future traffic flow for forecasting


In [8]:
# Display the current value of adj_matrix (still the Ellipsis placeholder until real data is loaded).
adj_matrix

Ellipsis

In [None]:
# NOTE(review): PeMSDataset is not defined or imported in any visible cell, and
# the root path is a placeholder - both must be supplied before this cell can run.
dataset = PeMSDataset(root='/path/to/dataset')

# Evaluate on a held-out tail of the data rather than on the training samples.
# NOTE(review): assumes the dataset supports len() and slicing, and that its
# samples are stored in chronological order - confirm.
num_train = int(0.8 * len(dataset))
train_loader = DataLoader(dataset[:num_train], batch_size=64, shuffle=True)
test_loader = DataLoader(dataset[num_train:], batch_size=64, shuffle=False)
