# Temporal GNN

In [5]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.nn.recurrent import A3TGCN

In [6]:
data = torch.load("../data/trains.pt")
train_dataset, test_dataset = temporal_signal_split(data, train_ratio=0.8)

In [7]:
class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods):
        super(TemporalGNN, self).__init__()
        # a3tgcn expects edge weight to be just one feature.
        # this merges multiple features into one.
        self.edge_attr_to_weight = torch.nn.Linear(2, 1)
        self.tgnn = A3TGCN(in_channels=node_features, 
                           out_channels=32, 
                           periods=periods)
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index, edge_attr):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        # if len(edge_attr.shape) > 1 and edge_attr.shape[-1] != 1:
        #     edge_attr = self.edge_attr_to_weight(edge_attr).flatten()
        h = self.tgnn(X=x, edge_index=edge_index, edge_weight=edge_attr)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [8]:
# GPU support
device = torch.device('cpu') # cuda
subset = 2000

# Create model and optimizers
model = TemporalGNN(node_features=1, periods=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

print("Running training...")
for epoch in range(1000):
    loss = 0
    step = 0
    for snapshot in train_dataset:
        snapshot = snapshot.to(device)
        # Get model predictions
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Mean squared error
        loss = loss + torch.mean((y_hat-snapshot.y) ** 2) 
        step += 1
        if step > subset:
          break

    loss = loss / (step + 1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("Epoch {} train MSE: {:.4f}".format(epoch, loss.item()))

Running training...
Epoch 0 train MSE: 0.5253
Epoch 1 train MSE: 0.4679
Epoch 2 train MSE: 0.4236
Epoch 3 train MSE: 0.3826
Epoch 4 train MSE: 0.3425
Epoch 5 train MSE: 0.3059
Epoch 6 train MSE: 0.2767
Epoch 7 train MSE: 0.2584
Epoch 8 train MSE: 0.2542
Epoch 9 train MSE: 0.2633
Epoch 10 train MSE: 0.2773
Epoch 11 train MSE: 0.2861
Epoch 12 train MSE: 0.2860
Epoch 13 train MSE: 0.2794
Epoch 14 train MSE: 0.2701
Epoch 15 train MSE: 0.2614
Epoch 16 train MSE: 0.2552
Epoch 17 train MSE: 0.2524
Epoch 18 train MSE: 0.2524
Epoch 19 train MSE: 0.2544
Epoch 20 train MSE: 0.2569
Epoch 21 train MSE: 0.2593
Epoch 22 train MSE: 0.2608
Epoch 23 train MSE: 0.2612
Epoch 24 train MSE: 0.2605
Epoch 25 train MSE: 0.2589
Epoch 26 train MSE: 0.2568
Epoch 27 train MSE: 0.2546
Epoch 28 train MSE: 0.2528
Epoch 29 train MSE: 0.2516
Epoch 30 train MSE: 0.2511
Epoch 31 train MSE: 0.2513
Epoch 32 train MSE: 0.2519
Epoch 33 train MSE: 0.2527
Epoch 34 train MSE: 0.2532
Epoch 35 train MSE: 0.2532
Epoch 36 train MSE

In [9]:
model.eval()
loss = 0
step = 0
horizon = 288

# Store for analysis
predictions = []
labels = []

for snapshot in test_dataset:
    snapshot = snapshot.to(device)
    # Get predictions
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    # Mean squared error
    loss = loss + torch.mean((y_hat-snapshot.y)**2)
    # Store for analysis below
    labels.append(snapshot.y)
    predictions.append(y_hat)
    step += 1
    if step > horizon:
          break

loss = loss / (step+1)
loss = loss.item()
print("Test MSE: {:.4f}".format(loss))

Test MSE: 0.2340


In [6]:
y_hat[:5]

tensor([[1.0855],
        [1.8012],
        [2.0900],
        [2.1391],
        [2.0034]], grad_fn=<SliceBackward0>)

In [7]:
snapshot.y[:5]

tensor([[5.],
        [2.],
        [3.],
        [4.],
        [1.]])