In [6]:
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
import numpy as np

In [2]:

loader = WikiMathsDatasetLoader()

dataset = loader.get_dataset(lags=14)

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)

In [15]:
print(np.array(dataset.features).shape[0])

717


In [14]:
print(dataset.shape[0])

AttributeError: 'StaticGraphTemporalSignal' object has no attribute 'shape'

In [12]:
print(dataset.edge_weight)

[1 4 2 ... 1 1 2]


In [8]:
print(np.array(dataset.targets).shape)

(717, 1068)


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [4]:
from tqdm import tqdm

model = RecurrentGCN(node_features=14, filters=32)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(50)):
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = torch.mean((y_hat-snapshot.y)**2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

  8%|▊         | 4/50 [01:47<20:38, 26.93s/it]


KeyboardInterrupt: 

In [None]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

In [18]:
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    print(data)
    data.num_graphs

    x = scatter_mean(data.x, data.batch, dim=0)
    print(x.size())

Batch(batch=[1071], edge_index=[2, 4114], ptr=[33], x=[1071, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1190], edge_index=[2, 4502], ptr=[33], x=[1190, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1015], edge_index=[2, 3970], ptr=[33], x=[1015, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[931], edge_index=[2, 3680], ptr=[33], x=[931, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1076], edge_index=[2, 4202], ptr=[33], x=[1076, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[910], edge_index=[2, 3550], ptr=[33], x=[910, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1074], edge_index=[2, 3836], ptr=[33], x=[1074, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[940], edge_index=[2, 3556], ptr=[33], x=[940, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1098], edge_index=[2, 4010], ptr=[33], x=[1098, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[962], edge_index=[2, 3820], ptr=[33], x=[962, 21], y=[32])
torch.Size([32, 21])
Batch(batch=[1180], edge_index=[2, 4218], ptr=[33], x=[118