In [1]:
import torch
from torch import nn, optim
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models, data

dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain", bond_feature="pretrain")

Loading /home/abhor/molecule-datasets/clintox.csv: 100%|█████| 1485/1485 [00:00<00:00, 133028.80it/s]
[15:01:30] Explicit valence for atom # 0 N, 5, is greater than permitted
Constructing molecules from SMILES:  16%|███▊                    | 232/1484 [00:00<00:01, 976.67it/s][15:01:31] Can't kekulize mol.  Unkekulized atoms: 9
Constructing molecules from SMILES:  63%|███████████████         | 929/1484 [00:01<00:00, 793.80it/s][15:01:31] Explicit valence for atom # 10 N, 4, is greater than permitted
[15:01:31] Explicit valence for atom # 10 N, 4, is greater than permitted
Constructing molecules from SMILES:  79%|██████████████████▏    | 1171/1484 [00:01<00:00, 757.01it/s][15:01:32] Can't kekulize mol.  Unkekulized atoms: 4
[15:01:32] Can't kekulize mol.  Unkekulized atoms: 4
Constructing molecules from SMILES: 100%|███████████████████████| 1484/1484 [00:01<00:00, 784.23it/s]


In [18]:
for i in range(len(dataset.data)):
    dataset[i]["graph"] = data.Molecule.from_smiles(dataset[i]["graph"].to_smiles(), atom_feature="pretrain", bond_feature="pretrain", with_hydrogen=True)



In [19]:
model = models.GIN(input_dim=dataset.node_feature_dim,
#                    hidden_dims=[300, 300, 300, 300, 300],
                   hidden_dims=[128, 128, 128, 128, 128],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")

task = tasks.AttributeMasking(model, mask_rate=0.15)

optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=None, batch_size=256)

solver.train(num_epoch=100)

# import json

# with open("models/clintox_gin.json", "w") as fout:
#     json.dump(solver.config_dict(), fout)
# solver.save("models/clintox_gin.pth")

23:50:53   Preprocess training set
23:50:53   {'batch_size': 256,
 'class': 'core.Engine',
 'gpus': None,
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.AttributeMasking',
          'graph_construction_model': None,
          'mask_rate': 0.15,
          'model': {'activation': 'relu',
                    'batch_norm': True,
                    'class': 'models.GIN',
                    'concat_hidden': False,
                    'edge_input_dim': 11,
                    'eps': 0,
                    'hidden_dims': [128, 128, 128, 128, 128],
                    'input_dim': 22,
                    'learn

23:51:12   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:12   Epoch 19 begin
23:51:13   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:13   Epoch 19 end
23:51:13   duration: 1.10 secs
23:51:13   speed: 5.44 batch / sec
23:51:13   ETA: 1.37 mins
23:51:13   ------------------------------
23:51:13   average accuracy: 0.851979
23:51:13   average cross entropy: 0.442334
23:51:13   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:13   Epoch 20 begin
23:51:14   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:14   Epoch 20 end
23:51:14   duration: 1.01 secs
23:51:14   speed: 5.96 batch / sec
23:51:14   ETA: 1.36 mins
23:51:14   ------------------------------
23:51:14   average accuracy: 0.85561
23:51:14   average cross entropy: 0.415975
23:51:14   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:14   Epoch 21 begin
23:51:16   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:16   Epoch 21 end
23:51:16   duration: 1.14 secs
23:51:16   speed: 5.26 batch / sec
23:51:16   ETA: 1.35 mins
23:51:16   ------------------------------
23:51:16   average accuracy: 0.8580

23:51:37   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:37   Epoch 42 end
23:51:37   duration: 1.15 secs
23:51:37   speed: 5.21 batch / sec
23:51:37   ETA: 58.01 secs
23:51:37   ------------------------------
23:51:37   average accuracy: 0.885856
23:51:37   average cross entropy: 0.320148
23:51:37   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:37   Epoch 43 begin
23:51:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:38   Epoch 43 end
23:51:38   duration: 1.03 secs
23:51:38   speed: 5.81 batch / sec
23:51:38   ETA: 57.01 secs
23:51:38   ------------------------------
23:51:38   average accuracy: 0.884558
23:51:38   average cross entropy: 0.32436
23:51:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:38   Epoch 44 begin
23:51:39   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:51:39   Epoch 44 end
23:51:39   duration: 1.07 secs
23:51:39   speed: 5.59 batch / sec
23:51:39   ETA: 56.06 secs
23:51:39   ------------------------------
23:51:39   average accuracy: 0.893947
23:51:39   average cross entropy: 0.317446
23:51:39   >>>>>>>>

23:52:01   Epoch 65 end
23:52:01   duration: 1.54 secs
23:52:01   speed: 3.90 batch / sec
23:52:01   ETA: 35.11 secs
23:52:01   ------------------------------
23:52:01   average accuracy: 0.899622
23:52:01   average cross entropy: 0.287455
23:52:01   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:01   Epoch 66 begin
23:52:02   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:02   accuracy: 0.903955
23:52:02   cross entropy: 0.324245
23:52:02   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:02   Epoch 66 end
23:52:02   duration: 1.08 secs
23:52:02   speed: 5.56 batch / sec
23:52:02   ETA: 34.10 secs
23:52:02   ------------------------------
23:52:02   average accuracy: 0.89404
23:52:02   average cross entropy: 0.306718
23:52:02   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:02   Epoch 67 begin
23:52:03   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:03   Epoch 67 end
23:52:03   duration: 1.01 secs
23:52:03   speed: 5.97 batch / sec
23:52:03   ETA: 33.05 secs
23:52:03   ------------------------------
23:52:03   average accuracy: 0.8996

23:52:24   Epoch 88 begin
23:52:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:24   Epoch 88 end
23:52:24   duration: 0.94 secs
23:52:24   speed: 6.41 batch / sec
23:52:24   ETA: 11.33 secs
23:52:24   ------------------------------
23:52:24   average accuracy: 0.906172
23:52:24   average cross entropy: 0.277536
23:52:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:24   Epoch 89 begin
23:52:25   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:25   Epoch 89 end
23:52:25   duration: 0.84 secs
23:52:25   speed: 7.18 batch / sec
23:52:25   ETA: 10.28 secs
23:52:25   ------------------------------
23:52:25   average accuracy: 0.910664
23:52:25   average cross entropy: 0.252135
23:52:25   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:25   Epoch 90 begin
23:52:26   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:52:26   Epoch 90 end
23:52:26   duration: 0.86 secs
23:52:26   speed: 6.94 batch / sec
23:52:26   ETA: 9.24 secs
23:52:26   ------------------------------
23:52:26   average accuracy: 0.90219
23:52:26   average cross entropy: 0.2

In [20]:
torch.save(model, "models/clintox_with_Hs_gin.pth")