In [2]:
! pip install torch-geometric-temporal

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch-geometric-temporal
  Downloading torch_geometric_temporal-0.54.0.tar.gz (48 kB)
[K     |████████████████████████████████| 48 kB 2.5 MB/s 
Collecting torch_sparse
  Downloading torch_sparse-0.6.15.tar.gz (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 9.0 MB/s 
[?25hCollecting torch_scatter
  Downloading torch_scatter-2.0.9.tar.gz (21 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.1.0.post1.tar.gz (467 kB)
[K     |████████████████████████████████| 467 kB 47.0 MB/s 
Building wheels for collected packages: torch-geometric-temporal, torch-geometric, torch-scatter, torch-sparse
  Building wheel for torch-geometric-temporal (setup.py) ... [?25l[?25hdone
  Created wheel for torch-geometric-temporal: filename=torch_geometric_temporal-0.54.0-py3-none-any.whl size=86745 sha256=601cdac7bc6c69981fab19f8c646ab07acf0bee07a4bc9f277a390c6cbab1e8e
  Stored 

We load our data.

In [3]:
from torch_geometric_temporal.dataset import WindmillOutputSmallDatasetLoader
#from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader = WindmillOutputSmallDatasetLoader()
#loader = ChickenpoxDatasetLoader()

dataset = loader.get_dataset()

We look at our data.

**How many data points?**

**What does a single datapoint look like?**

**What are the edge indices?**

**What are the edge attributes?**


**What kind of task is this - regression or classification?**

In [4]:
print("Number of data points: ", dataset.snapshot_count)
snapshot = dataset[0]
print("Features: ", snapshot.num_features, "Edge index: ", len(snapshot.edge_index), "Edge attributes: ", len(snapshot.edge_attr))

Number of data points:  17464
Features:  8 Edge index:  2 Edge attributes:  121


We split it into train and test. **Note need to be careful with temporal data.**

In [5]:
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

We define a model.

**What type of model?**

In [6]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

We train the model.

**What optimizer?**

**How many epochs?**

In [9]:
from tqdm import tqdm

model = RecurrentGCN(node_features = 8)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(25)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("MSE: {:.4f}".format(cost))

  4%|▍         | 1/25 [00:31<12:29, 31.24s/it]

MSE: 1.0152


  8%|▊         | 2/25 [01:03<12:06, 31.58s/it]

MSE: 1.0072


 12%|█▏        | 3/25 [01:34<11:33, 31.53s/it]

MSE: 1.0044


 16%|█▌        | 4/25 [02:07<11:14, 32.12s/it]

MSE: 1.0046


 20%|██        | 5/25 [02:39<10:38, 31.93s/it]

MSE: 1.0052


 24%|██▍       | 6/25 [03:11<10:06, 31.95s/it]

MSE: 1.0047


 28%|██▊       | 7/25 [03:42<09:31, 31.75s/it]

MSE: 1.0037


 32%|███▏      | 8/25 [04:14<09:01, 31.87s/it]

MSE: 1.0030


 36%|███▌      | 9/25 [04:46<08:27, 31.74s/it]

MSE: 1.0027


 40%|████      | 10/25 [05:17<07:56, 31.76s/it]

MSE: 1.0026


 44%|████▍     | 11/25 [05:48<07:20, 31.46s/it]

MSE: 1.0025


 48%|████▊     | 12/25 [06:22<06:58, 32.19s/it]

MSE: 1.0024


 52%|█████▏    | 13/25 [06:55<06:27, 32.31s/it]

MSE: 1.0022


 56%|█████▌    | 14/25 [07:26<05:51, 31.98s/it]

MSE: 1.0020


 60%|██████    | 15/25 [07:57<05:18, 31.83s/it]

MSE: 1.0017


 64%|██████▍   | 16/25 [08:28<04:44, 31.59s/it]

MSE: 1.0014


 68%|██████▊   | 17/25 [08:59<04:11, 31.40s/it]

MSE: 1.0013


 72%|███████▏  | 18/25 [09:30<03:38, 31.24s/it]

MSE: 1.0012


 76%|███████▌  | 19/25 [10:01<03:06, 31.05s/it]

MSE: 1.0010


 80%|████████  | 20/25 [10:31<02:34, 30.94s/it]

MSE: 1.0009


 84%|████████▍ | 21/25 [11:02<02:03, 30.90s/it]

MSE: 1.0007


 88%|████████▊ | 22/25 [11:33<01:32, 30.89s/it]

MSE: 1.0005


 92%|█████████▏| 23/25 [12:05<01:02, 31.12s/it]

MSE: 1.0003


 96%|█████████▌| 24/25 [12:36<00:31, 31.04s/it]

MSE: 1.0001


100%|██████████| 25/25 [13:07<00:00, 31.50s/it]

MSE: 0.9999





We evaluate.

**What is MSE?**

In [10]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

MSE: 0.9856
