In [None]:
import matplotlib
import torch
import torchdrug as td
from torchdrug import data
%matplotlib inline  

# Build a six-node cycle: node i is linked to node (i + 1) mod 6,
# which yields [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]].
num_node = 6
edge_list = [[src, (src + 1) % num_node] for src in range(num_node)]
graph = data.Graph(edge_list, num_node=num_node)
graph.visualize()

In [None]:
# Parse a single molecule from its SMILES string (a six-membered carbon ring
# with alternating double bonds) and draw it.
mol = data.Molecule.from_smiles("C1=CC=CC=C1")
mol.visualize()
# Show the shapes of the node-level and edge-level feature tensors, one per line.
print(mol.node_feature.shape, mol.edge_feature.shape, sep="\n")

In [None]:
# Pack several molecules, given as SMILES strings, into one batched object.
smiles_list = [
    "CCSCCSP(=S)(OC)OC",
    "CCOC(=O)N",
    "N(Nc1ccccc1)c2ccccc2",
    "NC(=O)c1cccnc1",
]
mols = data.PackedMolecule.from_smiles(smiles_list)
mols.visualize()
# Move the batch to the GPU only when one is available, so this cell also
# runs on CPU-only machines instead of raising on `.cuda()`.
if torch.cuda.is_available():
    mols = mols.cuda()
print(mols)

In [None]:
import torch
from torchdrug import datasets

# Load the ClinTox dataset from the given local directory.
dataset = datasets.ClinTox("~/molecule-datasets/")

# 80% / 10% / 10% split; the test set takes whatever remains after integer
# truncation so the three lengths always sum to len(dataset).
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]

# Seed the split explicitly so train/valid/test membership (and therefore the
# reported metrics) is reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(0)
train_set, valid_set, test_set = torch.utils.data.random_split(
    dataset, lengths, generator=split_generator
)

In [None]:
from torchdrug import core, models, tasks

# GIN graph encoder with four hidden layers of width 256; the input width is
# taken from the dataset's node feature dimension.
model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256] * 4,
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)

In [None]:
# Wrap the encoder in a property-prediction task over the dataset's targets,
# using the "bce" criterion and reporting the "auprc" and "auroc" metrics.
task = tasks.PropertyPrediction(
    model,
    task=dataset.tasks,
    criterion="bce",
    metric=("auprc", "auroc"),
)

In [None]:
# Adam over every parameter exposed by the task, with learning rate 1e-3.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# The engine ties the task, the three dataset splits and the optimizer together.
solver = core.Engine(
    task, train_set, valid_set, test_set, optimizer,
    batch_size=1024,
)

# Train for 10 epochs.
solver.train(num_epoch=10)

In [None]:
# Evaluate on the "valid" split; as the last expression in the cell, the
# returned value is shown as the cell output.
solver.evaluate("valid")