In [1]:
import pytorch_lightning as pl
import numpy as np
import torch

## Datamodule

In [2]:
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from shared import QM9MainDatamodule

class FlattenQM9(Dataset):
    """Present a QM9 graph dataset as fixed-length flat feature vectors.

    Each molecule is zero-padded to ``MAX_NUMBER_ATOMS`` atoms and flattened, so
    every sample has the same length and can be batched by a plain torch
    ``DataLoader`` and fed to a dense network.

    Each item is ``(x, y)``:
      * ``x`` has shape ``(1, F)`` with
        ``F = MAX_NUMBER_ATOMS * (x_dim + pos_dim) + MAX_NUMBER_ATOMS**2 * edge_dim``.
        It holds the padded per-atom ``[data.x | data.pos]`` rows first, then the
        padded dense ``(atom_i, atom_j, edge_attr)`` tensor, both flattened row-major.
      * ``y`` is ``data.y`` returned unchanged.

    Raises:
        ValueError: from ``__getitem__`` when a molecule has more than
            ``MAX_NUMBER_ATOMS`` atoms.
    """

    # Padding size for the node and edge tensors; larger molecules are rejected.
    MAX_NUMBER_ATOMS = 29

    def __init__(self, qm9_dataset):
        # Indexable dataset whose items expose x, pos, edge_index, edge_attr and y.
        self.qm9 = qm9_dataset

    def __len__(self):
        return len(self.qm9)

    def __getitem__(self, idx):
        data = self.qm9[idx]
        num_atoms = data.x.shape[0]
        if num_atoms > self.MAX_NUMBER_ATOMS:
            raise ValueError(
                f"molecule {idx} has {num_atoms} atoms; FlattenQM9 supports at most "
                f"{self.MAX_NUMBER_ATOMS}"
            )

        # Dense (atom, atom, edge-feature) tensor; pairs without an edge stay zero.
        edges = torch.zeros((self.MAX_NUMBER_ATOMS, self.MAX_NUMBER_ATOMS, data.edge_attr.shape[1]))
        edges[data.edge_index[0], data.edge_index[1], :] = data.edge_attr

        # Per-atom [features | coordinates]; padding rows stay zero. The buffer must be
        # floating point: torch.full(..., 0) would infer int64 from the integer fill
        # value and silently truncate the coordinates to whole numbers.
        nodes = torch.zeros(
            (self.MAX_NUMBER_ATOMS, data.x.shape[1] + data.pos.shape[1]),
            dtype=edges.dtype,
        )
        nodes[:num_atoms, :] = torch.cat([data.x, data.pos], dim=1)

        x = torch.cat([nodes.flatten(), edges.flatten()]).reshape(1, -1)
        return x, data.y

class FlattenQM9Datamodule(QM9MainDatamodule):
    """QM9 datamodule that serves flattened, fixed-size samples via ``FlattenQM9``.

    The train/val/test splits come from ``QM9MainDatamodule.setup``. This subclass
    wraps each split in ``FlattenQM9`` and serves it through a plain torch
    ``DataLoader``, which can batch the samples because they all share one shape.
    """

    def setup(self, stage=None):
        super().setup(stage)
        # setup() can run more than once (e.g. by hand to inspect a sample, then again
        # by the Trainer), so only wrap splits that are not already wrapped.
        self.train_set = self._flatten_split(self.train_set)
        self.val_set = self._flatten_split(self.val_set)
        self.test_set = self._flatten_split(self.test_set)

    @staticmethod
    def _flatten_split(split):
        """Wrap ``split`` in ``FlattenQM9``; idempotent, and leaves ``None`` untouched."""
        if split is None or isinstance(split, FlattenQM9):
            return split
        return FlattenQM9(split)

    def _make_loader(self, dataset, shuffle=False):
        """Build a torch DataLoader using this datamodule's batch size and worker count."""
        return TorchDataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        # Reshuffle every epoch so the optimiser does not see batches in a fixed order.
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        return self._make_loader(self.test_set)

## Network

In [3]:
class FeedforwadNN(torch.nn.Module):
    """Multilayer perceptron with two ReLU-activated hidden layers and a linear head.

    Args:
        input_size: number of input features (size of the last dimension of ``x``).
        hidden_size: width of both hidden layers.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size, self.hidden_size, self.output_size = input_size, hidden_size, output_size
        # Submodules are registered in the order fc1, relu, fc2, fc3 so that seeded
        # parameter initialisation and state_dict keys stay stable.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_size)`` to ``(..., output_size)``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)


In [4]:
from shared  import Module

import torch.nn.functional as F

class FeedforwadModule(Module):
    """Training wrapper around ``FeedforwadNN`` using an L1 (mean absolute error) loss.

    ``step`` is the per-batch hook; it is presumably dispatched by the shared
    ``Module`` base class for train/val/test — TODO confirm against ``shared``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.net = FeedforwadNN(input_size, hidden_size, output_size)

    def step(self, batch, batch_idx):
        """Compute the L1 loss for one ``(features, targets)`` batch.

        Returns:
            ``(loss, predictions, targets)``; predictions and targets are detached
            from the autograd graph.
        """
        features, targets = batch
        predictions = self.net(features)
        return F.l1_loss(predictions, targets), predictions.detach(), targets.detach()

## Training

In [5]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

if __name__ == "__main__":
    # Seed python/numpy/torch RNGs so weight init and data shuffling are reproducible.
    SEED = 42
    pl.seed_everything(SEED)

    # Stop once the validation loss ("loss/val") has not improved for 10 epochs.
    early_stop_callback = EarlyStopping(monitor="loss/val", patience=10, verbose=False, mode="min")
    # The filename placeholder must name the monitored metric exactly ("loss/val");
    # any other key does not show the real validation loss. auto_insert_metric_name=False
    # stops Lightning from prefixing "loss/val=", whose "/" would create a sub-directory.
    checkpoint_callback = ModelCheckpoint(
        monitor="loss/val",
        mode="min",
        filename="feedforward-epoch{epoch:02d}-val_loss{loss/val:.2f}",
        auto_insert_metric_name=False,
        save_top_k=10,
    )
    logger = pl.loggers.tensorboard.TensorBoardLogger(
        "./lightning_logs/", name="feedforward_2_hidden_layers", version="version_0"
    )
    # NOTE: trains on CUDA device index 1; adjust for the machine at hand.
    trainer = pl.Trainer(
        gpus=[1],
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
        max_epochs=100,
    )
    datamodule = FlattenQM9Datamodule()

    # Materialise the splits once to read the flattened input width and target width.
    datamodule.setup()
    sample_features, sample_targets = datamodule.val_set[0]
    model = FeedforwadModule(
        input_size=sample_features.shape[1],
        hidden_size=100,
        output_size=sample_targets.shape[1],
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
if __name__ == "__main__":
    trainer.fit(model, datamodule=datamodule)
    # Evaluate the best checkpoint (lowest "loss/val") kept by ModelCheckpoint rather
    # than the final weights; with early stopping the last epoch can be well past the best.
    trainer.test(datamodule=datamodule, ckpt_path="best")