In [166]:
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader, EnglandCovidDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

In [167]:
class LitDiffConvModel(pl.LightningModule):
    """Lightning module: one DCRNN layer, a ReLU, then a linear read-out per node.

    Each batch is expected to expose ``x``, ``edge_index`` and ``y`` attributes.
    The loss is the mean squared error between the single-value prediction per
    row of ``x`` and ``y`` reshaped to a column vector.

    NOTE(review): edge weights (if the batch carries any) are not passed to the
    DCRNN layer, and no hidden state is carried between batches — confirm that
    both are intended.

    Args:
        node_features: Number of input features per node (the number of lags
            when the dataset is built with ``get_dataset(lags=...)``).
        filters: Number of output channels of the DCRNN layer.
        K: Filter size handed to ``DCRNN`` (third positional argument).
            Defaults to 1, the value previously hard-coded.
        lr: Learning rate for the Adam optimizer. Defaults to 1e-2, the value
            previously hard-coded.
    """

    def __init__(self, node_features, filters, K=1, lr=1e-2):
        super().__init__()
        self.lr = lr
        self.recurrent = DCRNN(node_features, filters, K)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index):
        """Return one prediction per row of ``x`` (a tensor with one column)."""
        h = self.recurrent(x, edge_index)
        h = F.relu(h)
        return self.linear(h)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Compute the MSE loss for one batch; shared by training and validation."""
        prediction = self(batch.x, batch.edge_index)
        # The linear head emits a column vector, so reshape the targets to match.
        target = batch.y.view(-1, 1)
        return F.mse_loss(prediction, target)

    def training_step(self, train_batch, batch_idx):
        """Return the training loss for ``train_batch``."""
        return self._shared_step(train_batch)

    def validation_step(self, val_batch, batch_idx):
        """Log and return ``{'val_loss': loss}`` for ``val_batch``."""
        metrics = {'val_loss': self._shared_step(val_batch)}
        self.log_dict(metrics)
        return metrics

In [168]:
# Build the Chickenpox temporal signal using 32 lagged values as node features.
loader = ChickenpoxDatasetLoader()
dataset_loader = loader.get_dataset(lags=32)

# NOTE(review): train_ratio=0.2 presumably assigns only 20% of the snapshots to
# training and the rest to validation — confirm this is intended.
train_loader, val_loader = temporal_signal_split(
    dataset_loader,
    train_ratio=0.2,
)

In [172]:
# node_features must match the number of lags used when building the dataset (32).
model = LitDiffConvModel(node_features=32,
                         filters=16)

# 'val_loss' is a mean squared error, so lower is better: early stopping has to
# run in 'min' mode. With 'max' it would stop once the loss had gone `patience`
# epochs without increasing, i.e. while the model was still improving.
early_stop_callback = EarlyStopping(monitor='val_loss',
                                    min_delta=0.00,
                                    patience=10,
                                    verbose=False,
                                    mode='min')

In [179]:
# Feature matrix of the first snapshot; each printed row holds 32 values (one per lag).
dataset_loader.features[0]

array([[-1.08135724e-03,  2.85705967e-02,  3.54742090e-01,
         2.95438182e-01,  7.10565536e-01, -6.83076297e-01,
         2.95438182e-01, -2.08645035e-01,  2.13385932e+00,
        -1.42437514e+00, -8.90639974e-01,  1.39256048e+00,
        -1.15750756e+00,  7.40217490e-01, -6.53424343e-01,
         1.71873197e+00, -1.15750756e+00,  6.51261628e-01,
        -5.05164573e-01,  1.15534484e+00,  5.91957721e-01,
        -4.75512620e-01, -1.03889974e+00, -1.08135724e-03,
         1.74838392e+00, -3.41105606e+00,  8.58825306e-01,
        -1.66159078e+00, -1.19689173e-01,  4.14045997e-01,
        -4.45860666e-01, -1.49341127e-01],
       [-7.11136085e-01, -5.98430173e-01,  1.90511208e-01,
         1.09215850e+00, -7.24692522e-02,  1.01702123e+00,
        -1.08682246e+00,  4.53491669e-01, -1.85175164e-01,
        -8.23841997e-01,  5.66197581e-01, -4.10586987e-01,
         7.54040767e-01, -1.12439109e+00,  7.91609404e-01,
        -8.98979271e-01,  6.03766218e-01,  5.66197581e-01,
        -1.31

In [180]:
# Feature matrix of the second snapshot; in the recorded output each row is the
# corresponding row of the first snapshot shifted by one time step.
dataset_loader.features[1]

array([[ 2.85705967e-02,  3.54742090e-01,  2.95438182e-01,
         7.10565536e-01, -6.83076297e-01,  2.95438182e-01,
        -2.08645035e-01,  2.13385932e+00, -1.42437514e+00,
        -8.90639974e-01,  1.39256048e+00, -1.15750756e+00,
         7.40217490e-01, -6.53424343e-01,  1.71873197e+00,
        -1.15750756e+00,  6.51261628e-01, -5.05164573e-01,
         1.15534484e+00,  5.91957721e-01, -4.75512620e-01,
        -1.03889974e+00, -1.08135724e-03,  1.74838392e+00,
        -3.41105606e+00,  8.58825306e-01, -1.66159078e+00,
        -1.19689173e-01,  4.14045997e-01, -4.45860666e-01,
        -1.49341127e-01, -3.07333111e-02],
       [-5.98430173e-01,  1.90511208e-01,  1.09215850e+00,
        -7.24692522e-02,  1.01702123e+00, -1.08682246e+00,
         4.53491669e-01, -1.85175164e-01, -8.23841997e-01,
         5.66197581e-01, -4.10586987e-01,  7.54040767e-01,
        -1.12439109e+00,  7.91609404e-01, -8.98979271e-01,
         6.03766218e-01,  5.66197581e-01, -1.31223428e+00,
         1.99

In [171]:
# Train with early stopping on the validation loss.
# NOTE(review): the RuntimeError recorded below (20x48 vs 47x16) presumably means
# the `model` object in the kernel was built with node_features=31, while the data
# has 32 lags (32 features + 16 filters = 48). Re-create `model` with
# node_features=32 before calling fit — TODO confirm.
trainer = pl.Trainer(callbacks=[early_stop_callback])

trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type   | Params
-------------------------------------
0 | recurrent | DCRNN  | 4.6 K 
1 | linear    | Linear | 17    
-------------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x48 and 47x16)

In [150]:
# Displays the whole list of per-snapshot feature arrays.
# NOTE(review): this output is very large; a shape or a single element is usually enough.
dataset_loader.features

[array([[-1.08135724e-03,  2.85705967e-02,  3.54742090e-01,
          2.95438182e-01,  7.10565536e-01, -6.83076297e-01,
          2.95438182e-01, -2.08645035e-01,  2.13385932e+00,
         -1.42437514e+00, -8.90639974e-01,  1.39256048e+00,
         -1.15750756e+00,  7.40217490e-01, -6.53424343e-01,
          1.71873197e+00, -1.15750756e+00,  6.51261628e-01,
         -5.05164573e-01,  1.15534484e+00,  5.91957721e-01,
         -4.75512620e-01, -1.03889974e+00, -1.08135724e-03,
          1.74838392e+00, -3.41105606e+00,  8.58825306e-01,
         -1.66159078e+00, -1.19689173e-01,  4.14045997e-01,
         -4.45860666e-01, -1.49341127e-01],
        [-7.11136085e-01, -5.98430173e-01,  1.90511208e-01,
          1.09215850e+00, -7.24692522e-02,  1.01702123e+00,
         -1.08682246e+00,  4.53491669e-01, -1.85175164e-01,
         -8.23841997e-01,  5.66197581e-01, -4.10586987e-01,
          7.54040767e-01, -1.12439109e+00,  7.91609404e-01,
         -8.98979271e-01,  6.03766218e-01,  5.66197581e-

In [149]:
# Dense adjacency table built from the edge list (the recorded output has 20 node columns).
# NOTE(review): make_adj is defined in a later cell of this notebook, so this cell
# fails on a fresh top-to-bottom run.
make_adj(dataset_loader.edge_index, dataset_loader.edge_weight)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,1,1,0,0,0,1,1,0,0,0,1,0,0,1,0,0,1,0,0,0
1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0
2,0,0,1,0,0,1,0,0,1,0,1,0,0,0,0,0,0,0,0,0
3,0,0,0,1,0,0,0,0,1,1,1,0,1,0,0,1,0,0,0,0
4,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0
5,1,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
6,1,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,1,0,1,0
7,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,1,0
8,0,0,1,1,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0
9,0,0,0,1,0,0,0,0,0,1,1,0,1,1,0,0,0,0,0,0


# England COVID dataset

In [125]:
# Rebind the loader/dataset/split names to the England COVID signal (32 lags).
loader = EnglandCovidDatasetLoader()
dataset_loader = loader.get_dataset(lags=32)

# NOTE(review): train_ratio=0.2 presumably assigns only 20% of the snapshots to
# training and the rest to validation — confirm this is intended.
train_loader, val_loader = temporal_signal_split(
    dataset_loader,
    train_ratio=0.2,
)

In [126]:
dataset_loader.features[0].shape # 129 regions (nodes) and 32 lags

(129, 32)

In [127]:
# Feature matrix of the first snapshot of the dataset currently bound to dataset_loader.
dataset_loader.features[0]

array([[-1.4697223 , -1.92830573, -1.69901402, ..., -0.32326373,
        -0.43790959, -0.43790959],
       [-1.2509701 , -1.18115383, -1.32078636, ...,  1.12278286,
         0.56425275,  1.12278286],
       [-1.09340295, -1.09340295, -1.09340295, ...,  0.63650546,
         1.78977774, -0.13234272],
       ...,
       [-1.00045398, -1.08022874, -0.7611297 , ..., -0.1229316 ,
        -0.28248112,  0.99391507],
       [-0.88180575, -0.88180575, -0.63248871, ..., -0.00919612,
        -0.1961839 ,  0.1154624 ],
       [-1.16647027, -1.16647027, -0.48229059, ..., -0.65333551,
         1.22815861,  1.57024845]])

In [41]:
dataset_loader.targets[0].shape # 129 targets

(129,)

In [55]:
# Number of edge weights in the first snapshot (2158 in the recorded output).
len(dataset_loader.edge_weights[0])

2158

In [78]:
import pandas as pd
import numpy as np
def make_adj(edge_indices, edge_weights):
    """Build a dense adjacency table from an edge list.

    Args:
        edge_indices: Array-like of shape (2, n_edges); row 0 holds the source
            node of each edge and row 1 the target node, both as non-negative
            integer ids.
        edge_weights: Array-like of length n_edges with the weight of each edge.

    Returns:
        A square ``pandas.DataFrame`` whose index and columns are
        ``0 .. max(node id)``; entry ``[i, j]`` is the weight of edge ``i -> j``
        and 0 where no edge exists. The dtype follows the weights (at least
        int64), so float weights are stored without any implicit upcasting.
        An empty edge list yields an empty (0 x 0) frame.

    Raises:
        ValueError: If the number of weights differs from the number of edges.
    """
    edge_indices = np.asarray(edge_indices)
    edge_weights = np.asarray(edge_weights)

    # .max() is undefined for an empty array, so handle "no edges" explicitly.
    n_nodes = int(edge_indices.max()) + 1 if edge_indices.size else 0
    n_edges = edge_indices.shape[1] if edge_indices.size else 0
    if n_edges != edge_weights.shape[0]:
        raise ValueError(
            f"got {n_edges} edges but {edge_weights.shape[0]} edge weights"
        )

    # Choose the dtype up front so float weights fit without upcasting later.
    adj = np.zeros((n_nodes, n_nodes),
                   dtype=np.result_type(edge_weights.dtype, np.int64))
    if n_edges:
        # One vectorised assignment instead of a per-edge .loc write.
        adj[edge_indices[0], edge_indices[1]] = edge_weights

    node_ids = np.arange(n_nodes)
    return pd.DataFrame(adj, index=node_ids, columns=node_ids)

# Dense adjacency table for the first snapshot's edge list.
A = make_adj(dataset_loader.edge_indices[0], dataset_loader.edge_weights[0])

In [106]:
def make_edges(A):
    """Convert a dense adjacency table back into an edge list.

    Args:
        A: ``pandas.DataFrame`` whose index labels are source nodes, whose
            column labels are target nodes and whose entries are edge weights
            (0 meaning "no edge").

    Returns:
        A tuple ``(edge_indices, edge_weights)``. ``edge_indices`` has shape
        (2, n_edges) with row labels in row 0 and column labels in row 1;
        ``edge_weights`` has length n_edges. Edges are listed column by column,
        in the order produced by ``pandas.melt``.
    """
    long_form = pd.melt(A, ignore_index=False).reset_index()
    long_form = long_form[long_form["value"] != 0]
    # Extract the label columns and the weight column separately: taking one
    # combined .values array would coerce the integer node ids to the weights'
    # float dtype.
    edge_indices = long_form.iloc[:, :2].to_numpy().T
    edge_weights = long_form.iloc[:, 2].to_numpy()
    return edge_indices, edge_weights

In [118]:
# Number of per-snapshot feature arrays in the dataset.
len(dataset_loader.features)

29