In [30]:
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader = ChickenpoxDatasetLoader()

dataset = loader.get_dataset()

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

In [16]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [17]:
from tqdm.auto import tqdm

model = RecurrentGCN(node_features = 4)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

  0%|          | 0/200 [00:00<?, ?it/s]

In [38]:
dataset

<torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal at 0x7f3567e10d00>

In [33]:
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader = WikiMathsDatasetLoader()

dataset = loader.get_dataset(lags=14)

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)

for time, snapshot in enumerate(train_dataset):
    print(time, snapshot, snapshot.x[:1, :5])

0 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[-0.4323, -0.4739,  0.2659,  0.4844,  0.5367]])
1 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[-0.4739,  0.2659,  0.4844,  0.5367,  0.6412]])
2 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[0.2659, 0.4844, 0.5367, 0.6412, 0.2179]])
3 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[ 0.4844,  0.5367,  0.6412,  0.2179, -0.7617]])
4 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[ 0.5367,  0.6412,  0.2179, -0.7617, -0.4067]])
5 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[ 0.6412,  0.2179, -0.7617, -0.4067,  0.3064]])
6 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[ 0.2179, -0.7617, -0.4067,  0.3064,  0.4972]])
7 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068]) tensor([[-0.7617, -0.4067,  0

In [39]:
plt

NameError: name 'plt' is not defined