## **Spatial-Temporal Graph Neural Networks (ST-GNNs)**


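ST-GNNs jointly model spatial dependencies (how connected nodes influence each other) and temporal dependencies (how each node's signal evolves over time). This makes them well suited to graph-structured time series, in tasks such as traffic forecasting and human motion modeling.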
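
To make the two kinds of dependency concrete, here is a minimal sketch of one common input layout. It assumes a hypothetical toy graph of 3 nodes (for example, road sensors) with 5 time steps of 2 features each. The temporal axis is dimension 1 of `x`, while the spatial structure lives in `edge_index`, a `(2, num_edges)` tensor of source/target node indices in the COO format that PyTorch Geometric uses.

```python
import torch

# Hypothetical toy data: 3 nodes, 5 time steps, 2 features per node per step
num_nodes, seq_len, feat_dim = 3, 5, 2
x = torch.randn(num_nodes, seq_len, feat_dim)  # temporal axis is dim 1

# Spatial structure as a directed edge list: row 0 = source, row 1 = target.
# Edges 0 -> 1, 1 -> 0, 1 -> 2, 2 -> 1 form an undirected path 0 - 1 - 2.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# A temporal dependency relates x[i, t] to earlier steps of the same node i ...
node0_history = x[0]  # (seq_len, feat_dim)
# ... while a spatial dependency relates node i to the sources of edges j -> i.
in_neighbors_of_1 = edge_index[0, edge_index[1] == 1]

print(node0_history.shape)          # torch.Size([5, 2])
print(in_neighbors_of_1.tolist())   # [0, 2]
```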

**Imports**

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing  # requires PyTorch Geometric
```

**Temporal Aggregation Module**

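```python
class TemporalAggregator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # batch_first=True: input is (num_nodes, seq_len, input_dim), so each
        # node's feature history is treated as one sequence in the batch
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        # h_n: (num_layers, num_nodes, hidden_dim); keep the last layer's state
        _, (h_n, _) = self.lstm(x)
        return h_n[-1]  # (num_nodes, hidden_dim)
```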
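
The module treats each node as one sequence in the batch and keeps only the final hidden state. The sketch below checks what `h_n` holds, assuming a single-layer, unidirectional `nn.LSTM` (as in `TemporalAggregator`) with arbitrary toy sizes. In that case, the last layer's final hidden state is the same as the LSTM's per-step output at the last time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, seq_len, input_dim, hidden_dim = 5, 7, 16, 8  # toy sizes

lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
x = torch.randn(num_nodes, seq_len, input_dim)  # one sequence per node

output, (h_n, c_n) = lstm(x)
print(output.shape)  # per-step outputs: torch.Size([5, 7, 8])
print(h_n.shape)     # final hidden state: torch.Size([1, 5, 8])

# For a single-layer, unidirectional LSTM these two views coincide
node_embeddings = h_n[-1]
print(torch.allclose(node_embeddings, output[:, -1, :]))  # True
```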

**Spatial-Temporal GNN Model Definition**

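```python
class STGNN(MessagePassing):
    def __init__(self, node_feat_dim, hidden_dim, time_steps):
        super().__init__(aggr='add')  # sum the messages arriving at each node
        # Temporal step: summarize each node's history as one hidden vector
        self.temporal_agg = TemporalAggregator(node_feat_dim, hidden_dim)
        # Output head: `time_steps` scores per node
        self.fc = nn.Linear(hidden_dim, time_steps)

    def forward(self, x, edge_index):
        # x: (num_nodes, seq_len, node_feat_dim); edge_index: (2, num_edges), long
        x = self.temporal_agg(x)             # (num_nodes, hidden_dim)
        # Spatial step: each node receives the sum of its in-neighbors' vectors.
        # No self-loops are added, so a node's own embedding is not included.
        x = self.propagate(edge_index, x=x)  # (num_nodes, hidden_dim)
        # log_softmax treats the `time_steps` outputs as class log-probabilities;
        # for a regression-style forecast, return self.fc(x) instead.
        return F.log_softmax(self.fc(x), dim=1)

    def message(self, x_j):
        # x_j: embedding of the source node j of each edge j -> i
        return x_j
```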
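
With `aggr='add'` and a `message` that returns `x_j` unchanged, `propagate` computes, for every target node i, the sum of the embeddings of all source nodes j that have an edge j → i (PyTorch Geometric's default `flow='source_to_target'`). The following sketch reproduces that aggregation in plain PyTorch with `Tensor.index_add_`, assuming a hypothetical 3-node graph. It is meant to show the arithmetic, not to replace `MessagePassing`.

```python
import torch

# Toy node embeddings (num_nodes=3, hidden_dim=2), e.g. outputs of the LSTM step
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])
# Edges 0 -> 1, 2 -> 1, 1 -> 2 (row 0 = source j, row 1 = target i)
edge_index = torch.tensor([[0, 2, 1],
                           [1, 1, 2]])

def sum_aggregate(x, edge_index):
    src, dst = edge_index
    messages = x[src]                 # message(x_j): copy each source embedding
    out = torch.zeros_like(x)
    out.index_add_(0, dst, messages)  # aggr='add': sum messages per target node
    return out

out = sum_aggregate(x, edge_index)
print(out)
# Row 0: [0, 0]  node 0 has no incoming edges (and no self-loop)
# Row 1: [3, 2]  = x[0] + x[2]
# Row 2: [0, 1]  = x[1]
```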

**Instantiate Model**

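```python
node_feat_dim = 16  # input features per node per time step
hidden_dim = 8      # LSTM hidden size (node embedding size)
time_steps = 4      # number of outputs per node
stgnn_model = STGNN(node_feat_dim, hidden_dim, time_steps)
```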
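
Running the model requires PyTorch Geometric. As a dependency-free smoke test of the shapes involved, the sketch below mirrors the same three steps (LSTM summary, sum over in-neighbors, linear head with `log_softmax`) in plain PyTorch, using the sizes above. It assumes a hypothetical 4-node directed ring and an input window of 12 time steps; `PlainSTGNN` is an illustrative stand-in, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlainSTGNN(nn.Module):
    """Hypothetical PyG-free mirror of STGNN, for shape checks only."""
    def __init__(self, node_feat_dim, hidden_dim, time_steps):
        super().__init__()
        self.lstm = nn.LSTM(node_feat_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, time_steps)

    def forward(self, x, edge_index):
        _, (h_n, _) = self.lstm(x)                 # temporal summary per node
        h = h_n[-1]                                # (num_nodes, hidden_dim)
        src, dst = edge_index
        # Sum each source embedding into its target node (aggr='add')
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])
        return F.log_softmax(self.fc(agg), dim=1)  # (num_nodes, time_steps)

node_feat_dim, hidden_dim, time_steps = 16, 8, 4
num_nodes, seq_len = 4, 12  # hypothetical graph size and input window

# Directed ring 0 -> 1 -> 2 -> 3 -> 0
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
x = torch.randn(num_nodes, seq_len, node_feat_dim)

model = PlainSTGNN(node_feat_dim, hidden_dim, time_steps)
out = model(x, edge_index)
print(out.shape)             # torch.Size([4, 4])
print(out.exp().sum(dim=1))  # each row sums to ~1 (log-probabilities)
```

With PyTorch Geometric installed, the same `x` and `edge_index` can be passed directly to `stgnn_model(x, edge_index)`, which should return a tensor of the same `(num_nodes, time_steps)` shape.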