In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# project code (e.g. gnn_pytorch, data.ZINC) are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [24]:
import logging

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, SubsetRandomSampler, random_split

from data.ZINC.smiles_to_graph import convertToGraph, get_prop
from gnn_pytorch.models import GNN_Config, VanillaGCN
from gnn_pytorch.trainer import Trainer, TrainerConfig
from gnn_pytorch.utils import normalize

In [12]:

# Sizes of the train/validation/test partitions. They add up to 500_000 and,
# when passed to random_split, must equal the number of SMILES in the dataset.
TRAIN_SIZE = 360_000
VAL_SIZE = 90_000
TEST_SIZE = 50_000

# Timestamped INFO-level logging for this notebook and the project modules.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

In [45]:
class SmilesDataset(Dataset):
    """Molecules read from a SMILES text file, converted to graphs on access.

    Possible Properties: ["LOGP", "TPSA", "QED"]

    Args:
        path_to_smiles: Path to a text file with one SMILES string per line.
            Surrounding whitespace is stripped and blank lines are skipped.
        property: Property name forwarded to ``get_prop`` for every molecule.
        normalize: If True, the adjacency tensor is passed through the
            ``normalize`` helper (imported from ``gnn_pytorch.utils``) before
            being returned.
    """

    def __init__(self, path_to_smiles: str, property: str, normalize: bool = False):
        # Keep only non-blank lines. Do not assume the file ends with a
        # newline: chopping off the last split element would otherwise drop
        # a real molecule, and blank lines would yield empty SMILES.
        with open(path_to_smiles, encoding="utf-8") as f:
            self.smiles_list = [line.strip() for line in f if line.strip()]
        self.property = property
        # Boolean flag. Inside __init__ this parameter shadows the imported
        # ``normalize`` function; __getitem__ uses the module-level function.
        self.normalize = normalize

    def __len__(self):
        """Number of (non-blank) SMILES lines read from the file."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return ``(feat, adj, prop)`` as float tensors for molecule ``idx``.

        ``convertToGraph`` returns ``(adj, feat)``; they are reordered here.
        ``prop`` is wrapped in a one-element list, so it has shape ``(1,)``
        (assuming ``get_prop`` returns a scalar; TODO confirm).
        """
        # It could possibly be faster by converting the list of smiles
        # into a graph dataset beforehand instead of doing this every idx.
        smiles = self.smiles_list[idx]
        # NOTE(review): the meaning of the second argument (1) is defined by
        # data.ZINC.smiles_to_graph.convertToGraph; confirm there.
        adj, feat = convertToGraph(smiles, 1)
        prop = get_prop(smiles, self.property)

        feat = torch.tensor(feat, dtype=torch.float)
        adj = torch.tensor(adj, dtype=torch.float)
        prop = torch.tensor([prop], dtype=torch.float)

        if self.normalize:
            adj = normalize(adj)

        return feat, adj, prop

# LOGP-target dataset whose adjacency tensors are normalized on access.
logp_dataset = SmilesDataset(
    path_to_smiles='./data/ZINC/smiles.txt',
    property='LOGP',
    normalize=True,
)

In [46]:
# Reproducible three-way split: seeding a dedicated generator with 42 yields
# the same partition on every run. With integer sizes, random_split requires
# them to sum to len(logp_dataset).
train_set, val_set, test_set = random_split(
    dataset=logp_dataset,
    lengths=[TRAIN_SIZE, VAL_SIZE, TEST_SIZE],
    generator=torch.Generator().manual_seed(42),
)

In [47]:
# Five-layer vanilla GCN (GNN_Config / VanillaGCN come from gnn_pytorch.models,
# imported in the top import cell).
van_gcn_config = GNN_Config(n_layers=5)
vanilla_gcn = VanillaGCN(van_gcn_config)

05/21/2022 23:55:05 - INFO - gnn_pytorch.models -   number of parameters: 5.477770e+05


In [50]:
# Single-epoch training run (Trainer / TrainerConfig come from
# gnn_pytorch.trainer, imported in the top import cell).
tconf = TrainerConfig(max_epochs=1)
# Validate on the dedicated validation split. test_set stays untouched until
# the final evaluation, so model selection does not leak test data.
trainer = Trainer(model=vanilla_gcn, train_dataset=train_set, val_dataset=val_set, config=tconf)

In [51]:
# Run the training loop configured in the previous cell (tconf); progress and
# the running train loss are reported by gnn_pytorch.trainer.Trainer.
trainer.train()

epoch 1 iter 313: train loss 1.53601.:   9%|▊         | 314/3600 [00:31<05:30,  9.95it/s]
