In [1]:
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader = WikiMathsDatasetLoader()

dataset = loader.get_dataset(lags=14)

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [4]:
def test(model):
    model.eval()
    cost = 0
    for time, snapshot in enumerate(test_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost = cost.item()
    print("MSE: {:.4f}".format(cost))

In [7]:
from tqdm import tqdm

model = RecurrentGCN(node_features=14, filters=32)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(50)):
    for time, snapshot in enumerate(train_dataset):
        model.train()
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = torch.mean((y_hat-snapshot.y)**2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()
    test(model)

test(model)

  2%|▏         | 1/50 [00:14<12:07, 14.86s/it]

MSE: 0.9470


  4%|▍         | 2/50 [00:29<11:41, 14.62s/it]

MSE: 0.8677


  6%|▌         | 3/50 [00:44<11:31, 14.71s/it]

MSE: 0.8119


  8%|▊         | 4/50 [00:58<11:14, 14.66s/it]

MSE: 0.8151


 10%|█         | 5/50 [01:14<11:11, 14.93s/it]

MSE: 0.8665


 12%|█▏        | 6/50 [01:30<11:15, 15.35s/it]

MSE: 0.8081


 14%|█▍        | 7/50 [01:45<10:59, 15.34s/it]

MSE: 0.8020


 16%|█▌        | 8/50 [02:00<10:44, 15.33s/it]

MSE: 0.8054


 18%|█▊        | 9/50 [02:15<10:24, 15.24s/it]

MSE: 0.8066


 20%|██        | 10/50 [02:30<10:05, 15.14s/it]

MSE: 0.8063


 22%|██▏       | 11/50 [02:46<09:52, 15.20s/it]

MSE: 0.8090


 24%|██▍       | 12/50 [03:01<09:39, 15.25s/it]

MSE: 0.8097


 26%|██▌       | 13/50 [03:16<09:24, 15.24s/it]

MSE: 0.8083


 28%|██▊       | 14/50 [03:32<09:10, 15.29s/it]

MSE: 0.8246


 30%|███       | 15/50 [03:48<09:09, 15.70s/it]

MSE: 0.8435


 32%|███▏      | 16/50 [04:03<08:47, 15.52s/it]

MSE: 0.8427


 34%|███▍      | 17/50 [04:20<08:42, 15.85s/it]

MSE: 0.8158


 36%|███▌      | 18/50 [04:35<08:18, 15.57s/it]

MSE: 0.7951


 38%|███▊      | 19/50 [04:51<08:03, 15.61s/it]

MSE: 0.8174


 40%|████      | 20/50 [05:06<07:45, 15.53s/it]

MSE: 0.8357


 42%|████▏     | 21/50 [05:21<07:23, 15.30s/it]

MSE: 0.8213


 44%|████▍     | 22/50 [05:36<07:09, 15.34s/it]

MSE: 0.8062


 46%|████▌     | 23/50 [05:52<06:59, 15.52s/it]

MSE: 0.8012


 48%|████▊     | 24/50 [06:08<06:49, 15.73s/it]

MSE: 0.7995


 50%|█████     | 25/50 [06:24<06:33, 15.74s/it]

MSE: 0.8074


 52%|█████▏    | 26/50 [06:40<06:21, 15.89s/it]

MSE: 0.8134


 54%|█████▍    | 27/50 [06:57<06:08, 16.03s/it]

MSE: 0.8047


 56%|█████▌    | 28/50 [07:13<05:54, 16.10s/it]

MSE: 0.8132


 58%|█████▊    | 29/50 [07:30<05:41, 16.26s/it]

MSE: 0.8216


 60%|██████    | 30/50 [07:46<05:28, 16.43s/it]

MSE: 0.8339


 62%|██████▏   | 31/50 [08:02<05:07, 16.20s/it]

MSE: 0.8012


 64%|██████▍   | 32/50 [08:18<04:48, 16.00s/it]

MSE: 0.8018


 66%|██████▌   | 33/50 [08:35<04:36, 16.28s/it]

MSE: 0.7956


 68%|██████▊   | 34/50 [08:52<04:24, 16.50s/it]

MSE: 0.8089


 70%|███████   | 35/50 [09:08<04:05, 16.35s/it]

MSE: 0.7989


 72%|███████▏  | 36/50 [09:23<03:46, 16.19s/it]

MSE: 0.8269


 74%|███████▍  | 37/50 [09:43<03:45, 17.35s/it]

MSE: 0.8142


 76%|███████▌  | 38/50 [10:00<03:25, 17.15s/it]

MSE: 0.8169


 78%|███████▊  | 39/50 [10:17<03:06, 16.99s/it]

MSE: 0.8048


 80%|████████  | 40/50 [10:35<02:53, 17.33s/it]

MSE: 0.8141


 82%|████████▏ | 41/50 [10:52<02:36, 17.38s/it]

MSE: 0.8038


 84%|████████▍ | 42/50 [11:09<02:17, 17.16s/it]

MSE: 0.8127


 86%|████████▌ | 43/50 [11:27<02:02, 17.47s/it]

MSE: 0.8095


 88%|████████▊ | 44/50 [11:46<01:46, 17.77s/it]

MSE: 0.8137


 90%|█████████ | 45/50 [12:04<01:29, 17.81s/it]

MSE: 0.8250


 92%|█████████▏| 46/50 [12:21<01:10, 17.56s/it]

MSE: 0.8353


 94%|█████████▍| 47/50 [12:39<00:53, 17.71s/it]

MSE: 0.8227


 96%|█████████▌| 48/50 [12:55<00:34, 17.27s/it]

MSE: 0.8282


 98%|█████████▊| 49/50 [13:11<00:16, 16.83s/it]

MSE: 0.8061


100%|██████████| 50/50 [13:27<00:00, 16.15s/it]

MSE: 0.8207





MSE: 0.8207
