In [51]:
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader, EnglandCovidDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

In [14]:
class LitDiffConvModel(pl.LightningModule):

    def __init__(self, node_features, filters):
        super().__init__()
        self.recurrent = DCRNN(node_features, filters, 1)
        self.linear = torch.nn.Linear(filters, 1)


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x = train_batch.x
        y = train_batch.y.view(-1, 1)
        edge_index = train_batch.edge_index
        h = self.recurrent(x, edge_index)
        h = F.relu(h)
        h = self.linear(h)
        loss = F.mse_loss(h, y)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x = val_batch.x
        y = val_batch.y.view(-1, 1)
        edge_index = val_batch.edge_index
        h = self.recurrent(x, edge_index)
        h = F.relu(h)
        h = self.linear(h)
        loss = F.mse_loss(h, y)
        metrics = {'val_loss': loss}
        self.log_dict(metrics)
        return metrics

In [44]:
loader = ChickenpoxDatasetLoader()

dataset_loader = loader.get_dataset(lags = 32)

train_loader, val_loader = temporal_signal_split(dataset_loader,
                                                 train_ratio=0.2)

In [18]:
model = LitDiffConvModel(node_features=32,
                         filters=16)

early_stop_callback = EarlyStopping(monitor='val_loss',
                                    min_delta=0.00,
                                    patience=10,
                                    verbose=False,
                                    mode='max')

In [19]:
trainer = pl.Trainer(callbacks=[early_stop_callback])

trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type   | Params
-------------------------------------
0 | recurrent | DCRNN  | 4.7 K 
1 | linear    | Linear | 17    
-------------------------------------
4.7 K     Trainable params
0         Non-trainable params
4.7 K     Total params
0.019     Total estimated model params size (MB)


Epoch 176: : 490it [00:01, 309.76it/s, loss=0.19, v_num=0]


In [50]:
dataset_loader.edge_index

array([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,
         3,  3,  3,  3,  3,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,
         6,  6,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,
        10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12,
        12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15,
        15, 15, 16, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, 18,
        18, 18, 19, 19, 19, 19],
       [10,  6, 13,  1,  0,  5, 16,  0, 16,  1, 14, 10,  8,  2,  5,  8,
        15, 12,  9, 10,  3,  4, 13,  0, 10,  2,  5,  0, 16,  6, 14, 13,
        11, 18,  7, 17, 11, 18,  3,  2, 15,  8, 10,  9, 13,  3, 12, 10,
         5,  9,  8,  3, 10,  2, 13,  0,  6, 11,  7, 13, 18,  3,  9, 13,
        12, 13,  9,  6,  4, 12,  0, 11, 10, 18, 19,  1, 14,  6, 16,  3,
        15,  8, 16, 14,  1,  0,  6,  7, 19, 17, 18, 14, 18, 17,  7,  6,
        19, 11, 18, 14, 19, 17]])

In [52]:
loader = EnglandCovidDatasetLoader()

dataset_loader = loader.get_dataset(lags = 32)

train_loader, val_loader = temporal_signal_split(dataset_loader,
                                                 train_ratio=0.2)

In [62]:
dataset_loader.edge_indices[1].shape

(2, 1743)