In [1]:
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

In [2]:
loader = ChickenpoxDatasetLoader()

dataset = loader.get_dataset()

In [3]:
dir(loader)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_dataset',
 '_edge_weights',
 '_edges',
 '_get_edge_weights',
 '_get_edges',
 '_get_targets_and_features',
 '_read_web_data',
 'features',
 'get_dataset',
 'lags',
 'targets']

In [4]:
a = loader.features

In [5]:
len(a)

517

In [6]:
a[0].shape

(20, 4)

In [7]:
dir(dataset)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__get_item__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__next__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_check_temporal_consistency',
 '_get_edge_index',
 '_get_edge_weight',
 '_get_features',
 '_get_target',
 '_set_snapshot_count',
 'edge_index',
 'edge_weight',
 'features',
 'snapshot_count',
 'targets']

In [8]:
dataset.edge_index

array([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,
         3,  3,  3,  3,  3,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,
         6,  6,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,
        10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12,
        12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15,
        15, 15, 16, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, 18,
        18, 18, 19, 19, 19, 19],
       [10,  6, 13,  1,  0,  5, 16,  0, 16,  1, 14, 10,  8,  2,  5,  8,
        15, 12,  9, 10,  3,  4, 13,  0, 10,  2,  5,  0, 16,  6, 14, 13,
        11, 18,  7, 17, 11, 18,  3,  2, 15,  8, 10,  9, 13,  3, 12, 10,
         5,  9,  8,  3, 10,  2, 13,  0,  6, 11,  7, 13, 18,  3,  9, 13,
        12, 13,  9,  6,  4, 12,  0, 11, 10, 18, 19,  1, 14,  6, 16,  3,
        15,  8, 16, 14,  1,  0,  6,  7, 19, 17, 18, 14, 18, 17,  7,  6,
        19, 11, 18, 14, 19, 17]])

In [9]:
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [10]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [11]:
from tqdm import tqdm

model = RecurrentGCN(node_features = 4)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:05<00:00,  3.08it/s]


In [12]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

MSE: 1.0245


In [13]:
import io
import json
import numpy as np
from six.moves import urllib

In [14]:
url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/chickenpox.json"
_dataset = json.loads(urllib.request.urlopen(url).read())

In [15]:
_dataset.keys()

dict_keys(['edges', 'node_ids', 'FX'])

In [16]:
np.array(_dataset["edges"]).T.shape

(2, 102)

In [17]:
stacked_target =np.array(_dataset["FX"])

In [18]:
stacked_target.shape

(521, 20)

In [19]:
stacked_target[0].shape

(20,)

In [20]:
lags = 4

In [21]:
features = [stacked_target[i:i+lags,:].T for i in range(stacked_target.shape[0]-lags)]

In [22]:
features[0].shape

(20, 4)

In [23]:
targets  = [stacked_target[i+lags,:].T for i in range(stacked_target.shape[0]-lags)]

In [24]:
_edges = np.array(_dataset["edges"]).T

In [25]:
_edge_weights = np.ones(_edges.shape[1])

In [26]:
_edge_weights.shape

(102,)

In [27]:
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

In [28]:
dataset = StaticGraphTemporalSignal(_edges, _edge_weights, features, targets)

In [29]:
dataset

<torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal at 0x25a2bcaf5c0>

In [30]:
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [31]:
model = RecurrentGCN(node_features = 4)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:02<00:00,  3.19it/s]


In [32]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

MSE: 1.0260


In [42]:
k = list(train_dataset)[0]

In [44]:
k.x.shape

torch.Size([20, 4])

In [45]:
k.edge_index

tensor([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,
          3,  3,  3,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10,
         10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
         13, 14, 14, 14, 14, 14, 14, 15, 15, 15, 16, 16, 16, 16, 16, 17, 17, 17,
         17, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19],
        [10,  6, 13,  1,  0,  5, 16,  0, 16,  1, 14, 10,  8,  2,  5,  8, 15, 12,
          9, 10,  3,  4, 13,  0, 10,  2,  5,  0, 16,  6, 14, 13, 11, 18,  7, 17,
         11, 18,  3,  2, 15,  8, 10,  9, 13,  3, 12, 10,  5,  9,  8,  3, 10,  2,
         13,  0,  6, 11,  7, 13, 18,  3,  9, 13, 12, 13,  9,  6,  4, 12,  0, 11,
         10, 18, 19,  1, 14,  6, 16,  3, 15,  8, 16, 14,  1,  0,  6,  7, 19, 17,
         18, 14, 18, 17,  7,  6, 19, 11, 18, 14, 19, 17]])

In [46]:
k.edge_attr

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [49]:
model(k.x, k.edge_index, k.edge_attr)

tensor([[ 0.0729],
        [-0.1257],
        [ 0.4060],
        [ 0.5820],
        [ 0.1707],
        [ 0.0650],
        [-0.1160],
        [ 0.1641],
        [-0.1907],
        [ 0.2460],
        [-0.0037],
        [ 0.1704],
        [ 0.1678],
        [-0.0207],
        [-0.0222],
        [ 0.1170],
        [-0.0387],
        [-0.1362],
        [-0.2043],
        [ 0.0226]], grad_fn=<AddmmBackward>)