# Temporal GNN

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.nn.recurrent import A3TGCN

In [2]:
data = torch.load("../data/trains_time.pt")
train_dataset, test_dataset = temporal_signal_split(data, train_ratio=0.8)

In [3]:
class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods):
        super(TemporalGNN, self).__init__()
        # a3tgcn expects edge weight to be just one feature.
        # this merges multiple features into one.
        self.edge_attr_to_weight = torch.nn.Linear(2, 1)
        self.tgnn = A3TGCN(in_channels=node_features, 
                           out_channels=32, 
                           periods=periods)
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index, edge_attr):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        # if len(edge_attr.shape) > 1 and edge_attr.shape[-1] != 1:
        #     edge_attr = self.edge_attr_to_weight(edge_attr).flatten()
        h = self.tgnn(X=x, edge_index=edge_index, edge_weight=edge_attr)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [4]:
# GPU support
device = torch.device('cpu') # cuda
subset = 2000

# Create model and optimizers
model = TemporalGNN(node_features=1, periods=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

print("Running training...")
for epoch in range(1000):
    loss = 0
    step = 0
    for snapshot in train_dataset:
        snapshot = snapshot.to(device)
        # Get model predictions
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Mean squared error
        loss = loss + torch.mean((y_hat-snapshot.y) ** 2) 
        step += 1
        if step > subset:
          break

    loss = loss / (step + 1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("Epoch {} train MSE: {:.4f}".format(epoch, loss.item()))

Running training...
Epoch 0 train MSE: 3021.6807
Epoch 1 train MSE: 3003.7402
Epoch 2 train MSE: 2991.0933
Epoch 3 train MSE: 2985.2104
Epoch 4 train MSE: 2981.1799
Epoch 5 train MSE: 2977.3689
Epoch 6 train MSE: 2973.6040
Epoch 7 train MSE: 2969.8323
Epoch 8 train MSE: 2966.0217
Epoch 9 train MSE: 2962.1619
Epoch 10 train MSE: 2958.2412
Epoch 11 train MSE: 2954.2473
Epoch 12 train MSE: 2949.4641
Epoch 13 train MSE: 2945.0627
Epoch 14 train MSE: 2940.6699
Epoch 15 train MSE: 2936.2300
Epoch 16 train MSE: 2931.7332
Epoch 17 train MSE: 2927.1968
Epoch 18 train MSE: 2922.6340
Epoch 19 train MSE: 2918.0662
Epoch 20 train MSE: 2913.5093
Epoch 21 train MSE: 2908.9829
Epoch 22 train MSE: 2904.5037
Epoch 23 train MSE: 2900.0845
Epoch 24 train MSE: 2895.7393
Epoch 25 train MSE: 2891.4773
Epoch 26 train MSE: 2887.3069
Epoch 27 train MSE: 2883.2322
Epoch 28 train MSE: 2879.2588
Epoch 29 train MSE: 2875.3892
Epoch 30 train MSE: 2871.6218
Epoch 31 train MSE: 2867.9585
Epoch 32 train MSE: 2864.3967


In [5]:
model.eval()
loss = 0
step = 0
horizon = 288

# Store for analysis
predictions = []
labels = []

for snapshot in test_dataset:
    snapshot = snapshot.to(device)
    # Get predictions
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    # Mean squared error
    loss = loss + torch.mean((y_hat-snapshot.y)**2)
    # Store for analysis below
    labels.append(snapshot.y)
    predictions.append(y_hat)
    step += 1
    if step > horizon:
          break

loss = loss / (step+1)
loss = loss.item()
print("Test MSE: {:.4f}".format(loss))

Test MSE: 2013.7495
