**Log P prediction using Graph Isomorphism Network (GIN)** 

This notebook uses the GIN model from the TorchDrug package for molecular property prediction.


Install torchdrug

In [None]:
pip install torchdrug

Import relevant libraries

In [None]:
import torch
from torchdrug import data, datasets
from torchdrug import core, models, tasks, utils

Load dataset

In [None]:
# Load the Lipophilicity dataset through TorchDrug and split it 80% / 10% / 10%.
dataset = datasets.Lipophilicity("~/molecule-datasets/")
num_samples = len(dataset)
num_train = int(0.8 * num_samples)
num_valid = int(0.1 * num_samples)
# The test split takes whatever remains, so the three sizes always sum to
# the dataset size even when the 80% / 10% products are truncated.
lengths = [num_train, num_valid, num_samples - (num_train + num_valid)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

Visualize the first 10 molecules of the dataset together with their labels


In [None]:
# Draw the first 10 molecules in a single row, each captioned with its label(s).
graphs = []
labels = []
for i in range(10):
    sample = dataset[i]
    graphs.append(sample.pop("graph"))
    # Everything left in the sample after removing the graph is shown as a
    # "name: value" caption (presumably the task targets -- confirm against
    # the dataset). Use a float format: "%d" would truncate a continuous
    # target such as 3.54 to "3".
    label = ["%s: %.2f" % (k, v) for k, v in sample.items()]
    label = ", ".join(label)
    labels.append(label)
graph = data.Molecule.pack(graphs)
graph.visualize(labels, num_row=1)

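A custom root mean squared error (RMSE) loss module. Note that it is not used below: the task is trained with the built-in "mse" criterion and only reports RMSE as a metric.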

In [None]:
class RMSELoss(torch.nn.Module):
    """Root mean squared error loss, computed as ``sqrt(MSE(x, y) + eps)``.

    Args:
        eps: Small constant added to the mean squared error before taking the
            square root. It keeps the gradient of the square root finite when
            the prediction matches the target exactly. Defaults to ``1e-6``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Build the MSE criterion once instead of on every forward call.
        # Referenced through ``torch.nn`` because the bare name ``nn`` is
        # never imported in this file.
        self.criterion = torch.nn.MSELoss()

    def forward(self, x, y):
        """Return the scalar RMSE between predictions ``x`` and targets ``y``."""
        return torch.sqrt(self.criterion(x, y) + self.eps)

Define the GIN model and the property prediction task

In [None]:
# GIN encoder with four hidden layers of width 256; shortcut connections,
# batch normalisation and hidden-layer concatenation are all switched on.
model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256] * 4,
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)
# Property prediction over the dataset's tasks: optimised with the "mse"
# criterion, evaluated with the "mae" and "rmse" metrics.
task = tasks.PropertyPrediction(
    model,
    task=dataset.tasks,
    criterion="mse",
    metric=("mae", "rmse"),
)

Train model

In [None]:
# Adam over all task parameters (the GIN encoder plus the prediction head).
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
# Train on the first GPU when CUDA is available; otherwise pass gpus=None so
# the engine runs on CPU instead of failing on a machine without a GPU.
gpus = [0] if torch.cuda.is_available() else None
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=gpus, batch_size=1024)
solver.train(num_epoch=100)
# Report validation metrics right after training finishes.
solver.evaluate("valid")

Evaluate model on validation set

In [None]:
# Compute the task metrics (MAE and RMSE) on the validation split; the
# returned metric dict is displayed as the cell output.
solver.evaluate("valid")

05:57:49   Evaluate on valid
05:57:49   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
05:57:49   mean absolute error [exp]: 0.764657
05:57:49   root mean squared error [exp]: 0.971925


{'mean absolute error [exp]': tensor(0.7647, device='cuda:0'),
 'root mean squared error [exp]': tensor(0.9719, device='cuda:0')}

Evaluate model on test set

In [None]:
# Compute the task metrics (MAE and RMSE) on the held-out test split; the
# returned metric dict is displayed as the cell output.
solver.evaluate("test")

06:11:18   Evaluate on test
06:11:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
06:11:18   mean absolute error [exp]: 0.834587
06:11:18   root mean squared error [exp]: 1.04739


{'mean absolute error [exp]': tensor(0.8346, device='cuda:0'),
 'root mean squared error [exp]': tensor(1.0474, device='cuda:0')}