In [1]:
import pytorch_lightning as pl
import numpy as np
import torch

## Datamodule

In [2]:
from torch_geometric.datasets import QM9


In [3]:
from shared import QM9MainDatamodule
from torch_geometric.loader import DataLoader as GeomDataLoader


class GraphQM9Datamodule(QM9MainDatamodule):
    """QM9 datamodule that serves batches through PyTorch Geometric's DataLoader.

    The splits (``train_set``/``val_set``/``test_set``) and the ``batch_size`` /
    ``num_workers`` settings are read from ``self``; presumably they are prepared
    by ``QM9MainDatamodule`` (not visible here). Only the loader class is
    swapped, so that graphs are collated into a PyG ``Batch`` exposing
    ``z``, ``pos``, ``y`` and the ``batch`` assignment vector used by the model.
    """

    def _make_loader(self, dataset, shuffle=False):
        """Build a PyG DataLoader over ``dataset`` using the shared loader settings."""
        return GeomDataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                              num_workers=self.num_workers)

    def train_dataloader(self):
        # Reshuffle each epoch so optimization does not see samples in a fixed order.
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        # Evaluation order does not matter; keep it deterministic.
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        return self._make_loader(self.test_set)

## Networks

In [4]:
from torch_geometric.nn.models.dimenet import DimeNet as GeoDimeNet

class DimeNet(GeoDimeNet):
    """PyG ``DimeNet`` whose constructor runs with autograd disabled.

    Workaround: constructing the upstream model outside ``torch.no_grad()``
    fails during parameter initialization, so the whole ``__init__`` is
    executed in a no-grad context. All arguments are forwarded unchanged
    to the PyG implementation.
    """

    @torch.no_grad()
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

In [5]:
import torch.nn.functional as F
from shared import Module

class DimeNetModule(Module):
    """Training module wrapping ``DimeNet`` with an L1 (mean absolute error) loss.

    Every keyword argument is forwarded to the ``DimeNet`` constructor.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.net = DimeNet(**kwargs)

    def step(self, batch, batch_idx):
        """Run the network on one PyG batch and score it against its targets.

        Args:
            batch: PyG batch providing ``z``, ``pos``, ``y`` and the ``batch``
                graph-assignment vector.
            batch_idx: index of the batch (unused here).

        Returns:
            Tuple ``(loss, predictions, targets)``; the last two are detached
            from the autograd graph.
        """
        atomic_numbers = batch.z
        positions = batch.pos
        targets = batch.y
        graph_index = batch.batch
        predictions = self.net(atomic_numbers, positions, graph_index)
        # NOTE(review): assumes targets has the same width as the net's
        # out_channels (12 in the training cell) — confirm the datamodule
        # selects those QM9 target columns.
        loss = F.l1_loss(predictions, targets)
        return loss, predictions.detach(), targets.detach()

## Training

In [6]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Validation metric used for both early stopping and checkpoint selection.
# It must match the key the module logs; the checkpoint filename template
# below interpolates the same metric, so both are derived from this constant.
MONITOR_METRIC = "loss/val"
# Lightning fills "{name:fmt}" fields from the logged metrics; a field naming a
# metric that is never logged cannot be filled with its value (Lightning falls
# back to 0), hence the template is built from MONITOR_METRIC.
CHECKPOINT_FILENAME = "dimenet-epoch{epoch:02d}-val_loss{" + MONITOR_METRIC + ":.2f}"

if __name__ == "__main__":
    logger = pl.loggers.tensorboard.TensorBoardLogger("./lightning_logs/", name='dimenet', version='with_gradient_clip')
    early_stop_callback = EarlyStopping(monitor=MONITOR_METRIC, patience=10, verbose=True, mode="min")
    checkpoint_callback = ModelCheckpoint(monitor=MONITOR_METRIC, filename=CHECKPOINT_FILENAME,
                                          auto_insert_metric_name=False, save_top_k=10)

    trainer = pl.Trainer(gpus=[0], logger=logger, max_epochs=100,
                         callbacks=[checkpoint_callback, early_stop_callback], gradient_clip_val=0.5)
    datamodule = GraphQM9Datamodule()
    model = DimeNetModule(hidden_channels=128, out_channels=12, num_blocks=6,
                          num_bilinear=8, num_spherical=7, num_radial=6,
                          cutoff=5.0, envelope_exponent=5, num_before_skip=1,
                          num_after_skip=2, num_output_layers=3)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
if __name__ == "__main__":
    # Pass the datamodule by keyword: the second positional parameter of
    # fit/test is a dataloader argument, not the datamodule.
    trainer.fit(model, datamodule=datamodule)
    # Evaluate the best checkpoint by the monitored validation loss rather than
    # the final in-memory weights, which, when early stopping fires, come from
    # epochs after the best validation score.
    trainer.test(datamodule=datamodule, ckpt_path="best")