In [1]:
import pytorch_lightning as pl
import numpy as np
import torch

## Datamodule

In [2]:
from shared import QM9MainDatamodule
from torch_geometric.loader import DataLoader as GeomDataLoader


class GraphQM9Datamodule(QM9MainDatamodule):
    """QM9 datamodule that batches molecular graphs with torch_geometric's DataLoader.

    Expects ``train_set``, ``val_set``, ``test_set``, ``batch_size`` and
    ``num_workers`` attributes, presumably initialised by ``QM9MainDatamodule``
    (TODO confirm against the shared base class).
    """

    def _graph_loader(self, dataset, shuffle=False):
        """Build a graph DataLoader over ``dataset`` using this module's batch size and worker count."""
        # All three splits use the configured worker count; train/val previously
        # hard-coded 4 and ignored ``self.num_workers``.
        return GeomDataLoader(dataset, batch_size=self.batch_size,
                              shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        # Reshuffle every epoch; the PyG DataLoader does not shuffle by default.
        return self._graph_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._graph_loader(self.val_set)

    def test_dataloader(self):
        return self._graph_loader(self.test_set)

## Networks

In [3]:
import torch.nn.functional as F
from shared import Module
from torch_geometric.nn.models.schnet import SchNet

class SchNetModule(Module):
    """Lightning module wrapping torch_geometric's ``SchNet`` regressor, trained with an L1 loss.

    Every keyword argument is forwarded to ``SchNet`` except ``lr``, which is
    removed and kept as ``self.lr`` for the optimizer.
    NOTE(review): when ``lr`` is not passed, ``self.lr`` must be provided by the
    shared ``Module`` base class, otherwise ``configure_optimizers`` fails.
    """

    def __init__(self, **kwargs):
        super().__init__()
        schnet_kwargs = dict(kwargs)
        # The learning rate is an optimizer setting, not a SchNet constructor argument.
        if 'lr' in schnet_kwargs:
            self.lr = schnet_kwargs.pop('lr')
        self.net = SchNet(**schnet_kwargs)

    def step(self, batch, batch_idx):
        """Run SchNet on a graph batch and return ``(loss, prediction, target)``.

        ``prediction`` and ``target`` are returned detached from the autograd graph;
        ``loss`` keeps its graph so it can be back-propagated.
        """
        atomic_numbers = batch.z
        positions = batch.pos
        target = batch.y
        graph_index = batch.batch
        prediction = self.net(atomic_numbers, positions, graph_index)
        loss = F.l1_loss(prediction, target)
        return loss, prediction.detach(), target.detach()

    def configure_optimizers(self):
        """Adam at ``self.lr`` with a StepLR that multiplies the rate by 0.96 every 20 scheduler steps.

        No ``interval`` is given, so Lightning steps the scheduler once per epoch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.96)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler}}

## Training

In [6]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

if __name__ == "__main__":
    logger = pl.loggers.tensorboard.TensorBoardLogger("./lightning_logs/", name='schnet', version='with_gradient_clip')
    # NOTE(review): this callback is constructed but not included in `callbacks`
    # below, so early stopping is currently disabled. Add it to the list to enable it.
    early_stop_callback = EarlyStopping(monitor="loss/val", patience=10, verbose=False, mode="min")
    # The filename placeholder must name the monitored metric ("loss/val").
    # The old template used "{val/loss}", which is not that metric. Lightning fills
    # unknown placeholders with 0, so every checkpoint was named "...val_loss0.00".
    checkpoint_callback = ModelCheckpoint(monitor='loss/val', filename='schnet-epoch{epoch:02d}-val_loss{loss/val:.2f}',
                                          auto_insert_metric_name=False, save_top_k=10)

    # gpus=[1] selects the device with index 1; gradients are clipped at 0.5.
    trainer = pl.Trainer(gpus=[1], logger=logger, max_epochs=200,
                         callbacks=[checkpoint_callback], gradient_clip_val=0.5)
    datamodule = GraphQM9Datamodule()
    # TODO: output_size requires a locally modified torch_geometric SchNet; make a proper subclass in future.
    model = SchNetModule(hidden_channels=128, num_filters=128,
                         num_interactions=6, num_gaussians=50,
                         cutoff=10.0, max_num_neighbors=32, output_size=12)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
if __name__ == "__main__":
    # Resume training from a checkpoint saved by an earlier run of this experiment.
    # Lightning raises if this file does not exist; set it to None to train from scratch.
    resume_ckpt = "./lightning_logs/schnet/with_gradient_clip/checkpoints/schnet-epoch99-val_loss0.00.ckpt"
    # Pass the datamodule by keyword; positionally it would occupy the dataloaders slot.
    trainer.fit(model, datamodule=datamodule, ckpt_path=resume_ckpt)
    # Evaluate the checkpoint with the best monitored "loss/val", not the last-epoch weights.
    trainer.test(model, datamodule=datamodule, ckpt_path="best")