## A Molecule Classifier
In this section we demonstrate how to use TorchDrug to build a molecule classifier. TorchDrug provides a large collection of popular datasets and models for drug discovery and graph representation learning.
We use the ClinTox dataset throughout this section. It poses two binary classification tasks for each molecule: predicting whether the molecule is toxic in clinical trials, and whether it is approved by the FDA.
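To make the two-task setup concrete, here is a minimal sketch of how each ClinTox molecule pairs one SMILES string with two binary labels. The column names `FDA_APPROVED` and `CT_TOX` are assumed to match the raw clintox.csv file, and the three rows are made-up toy data, not real ClinTox labels.

```python
import csv
import io

# Hypothetical toy rows in the assumed ClinTox CSV layout:
# a SMILES string followed by two binary label columns (labels are made up).
TOY_CSV = """smiles,FDA_APPROVED,CT_TOX
CCO,1,0
c1ccccc1,0,1
CC(=O)O,1,0
"""

def load_multitask(text, target_fields=("FDA_APPROVED", "CT_TOX")):
    """Return (smiles, label_vector) pairs with one binary label per task."""
    reader = csv.DictReader(io.StringIO(text))
    samples = []
    for row in reader:
        labels = [int(row[field]) for field in target_fields]
        samples.append((row["smiles"], labels))
    return samples

samples = load_multitask(TOY_CSV)
for smiles, labels in samples:
    print(smiles, labels)
```

Because every molecule carries a label vector rather than a single class, the classifier later needs one output per task.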

In [1]:
import torch
from torch import utils
from torchdrug import datasets

# Download (if needed) and load ClinTox; molecules are built from SMILES strings
dataset = datasets.ClinTox('../data/')

# 80% train, 10% validation, and the remaining ~10% test
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = utils.data.random_split(dataset, lengths)

19:52:26   Downloading http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz to C:\Users\Leanne/data/clintox.csv.gz
19:52:27   Extracting C:\Users\Leanne/data/clintox.csv.gz to C:\Users\Leanne/data\clintox.csv


Loading C:\Users\Leanne/data\clintox.csv: 100%|██████████| 1485/1485 [00:00<00:00, 82718.55it/s]
[19:52:27] Explicit valence for atom # 0 N, 5, is greater than permitted
Constructing molecules from SMILES:  18%|█▊        | 273/1484 [00:00<00:02, 547.91it/s][19:52:27] Can't kekulize mol.  Unkekulized atoms: 9
Constructing molecules from SMILES:  66%|██████▌   | 976/1484 [00:01<00:01, 461.47it/s][19:52:29] Explicit valence for atom # 10 N, 4, is greater than permitted
[19:52:29] Explicit valence for atom # 10 N, 4, is greater than permitted
Constructing molecules from SMILES:  79%|███████▉  | 1171/1484 [00:02<00:00, 459.60it/s][19:52:29] Can't kekulize mol.  Unkekulized atoms: 4
[19:52:29] Can't kekulize mol.  Unkekulized atoms: 4
Constructing molecules from SMILES: 100%|██████████| 1484/1484 [00:03<00:00, 486.50it/s]
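The split in the first cell can be sketched in plain Python. This is an illustration of the idea behind `random_split`, not PyTorch's implementation: the train and validation sizes are truncated fractions of the dataset, and the test set takes whatever remains so that the three sizes always sum to the dataset length. The helper names `split_lengths` and `random_split_indices` are hypothetical. The `n = 1484` below is the SMILES count from the log; a few molecules fail to parse, so the real `len(dataset)` may be slightly smaller.

```python
import random

def split_lengths(n, fractions=(0.8, 0.1)):
    """Train/valid sizes from fractions; the remainder becomes the test set."""
    lengths = [int(f * n) for f in fractions]
    lengths.append(n - sum(lengths))
    if any(length < 0 for length in lengths):
        # e.g. fractions (0.8, 1.0) leave a negative test size
        raise ValueError(f"fractions {fractions} exceed the dataset size")
    return lengths

def random_split_indices(n, lengths, seed=0):
    """Shuffle indices 0..n-1 and cut them into consecutive, disjoint chunks."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits

n = 1484
lengths = split_lengths(n)
print(lengths)  # [1187, 148, 149]
train_idx, valid_idx, test_idx = random_split_indices(n, lengths)
```

Raising on a negative size is a deliberate choice here: a fraction such as `1.0` for the validation set would otherwise leave no valid 80/10/10 partition.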


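## Define a GNN for Classification
We define a graph neural network (GNN) to encode the molecule graphs. Specifically, we use a Graph Isomorphism Network (GIN) with 4 hidden layers of 256 units each. `short_cut=True` adds residual connections between layers, `batch_norm=True` applies batch normalization, and `concat_hidden=True` concatenates the outputs of all layers to form the final representation.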
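The GIN update can be sketched without any deep learning library. Each layer computes h_v' = MLP((1 + eps) * h_v + sum of h_u over neighbors u), which matches the `eps: 0` entry in the training log. This minimal sketch assumes a sum readout over nodes (we believe this is the TorchDrug GIN default, but treat it as an assumption), omits batch normalization, and uses identity functions in place of the learned MLPs so the arithmetic is easy to follow. The function names are hypothetical.

```python
def gin_layer(features, neighbors, mlp, eps=0.0):
    """One GIN update: h_v' = mlp((1 + eps) * h_v + sum of neighbor features)."""
    out = []
    for v, h_v in enumerate(features):
        agg = [(1.0 + eps) * x for x in h_v]
        for u in neighbors[v]:
            agg = [a + x for a, x in zip(agg, features[u])]
        out.append(mlp(agg))
    return out

def gin_forward(features, neighbors, mlps, short_cut=True, concat_hidden=True):
    """Stack GIN layers, optionally with residuals and concatenated layer outputs."""
    hiddens, h = [], features
    for mlp in mlps:
        new_h = gin_layer(h, neighbors, mlp)
        # Residual connection only when input and output widths match
        if short_cut and len(new_h[0]) == len(h[0]):
            new_h = [[a + b for a, b in zip(x, y)] for x, y in zip(new_h, h)]
        hiddens.append(new_h)
        h = new_h
    if concat_hidden:
        # Concatenate each node's outputs from every layer
        node_repr = [sum((layer[v] for layer in hiddens), []) for v in range(len(features))]
    else:
        node_repr = h
    # Sum readout: the graph embedding is the sum of node representations
    graph_repr = [sum(col) for col in zip(*node_repr)]
    return node_repr, graph_repr

def identity(vec):
    return list(vec)

features = [[1.0], [2.0], [3.0]]  # one scalar feature per atom
neighbors = [[1], [0, 2], [1]]    # path graph 0-1-2 (three bonded atoms)
node_repr, graph_repr = gin_forward(features, neighbors, [identity, identity])
print(graph_repr)  # [20.0, 68.0]
```

With `concat_hidden=True`, the representation width is the sum of the hidden widths, so the four 256-unit layers in the model below yield a 1024-dimensional output.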

In [5]:
from torchdrug import core, models, tasks

# Reuse the dataset loaded above; its node feature dimension sets the GIN input size
model = models.GIN(input_dim=dataset.node_feature_dim, hidden_dims=[256, 256, 256, 256],
                   short_cut=True, batch_norm=True, concat_hidden=True)


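The GIN model by itself is just a neural network with no training target. To use it for classification, we wrap it in a property prediction task, which adds a prediction head on top of the graph representation. We train with the binary cross-entropy (BCE) criterion, one binary label per task, and evaluate with AUPRC and AUROC.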
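Binary cross-entropy scores each task's prediction independently. The minimal sketch below assumes the model emits one raw score (logit) per task, as PyTorch's `binary_cross_entropy_with_logits` expects, and uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)). Averaging the two task losses with equal weight is an assumption made for illustration; how TorchDrug weights or masks tasks is not shown here.

```python
import math

def bce_with_logits(logit, target):
    """Stable BCE on a raw score x: equals -[y*log(sig(x)) + (1-y)*log(1-sig(x))]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def multitask_bce(logits, targets):
    """Average per-task losses for one molecule (e.g. FDA approval, clinical toxicity)."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)

print(bce_with_logits(0.0, 1))  # log(2), about 0.6931
```

The stable form avoids overflow in `exp` for large negative logits, which a naive sigmoid followed by `log` would hit.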

In [6]:
task = tasks.PropertyPrediction(model, task=dataset.tasks, criterion="bce", metric=("auprc", "auroc"))
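The task above reports AUROC and AUPRC. Both are ranking metrics, so they depend only on how scores order positives relative to negatives. This sketch computes AUROC as the fraction of positive/negative pairs ranked correctly, and uses average precision, one common estimate of AUPRC. The function names and toy scores are hypothetical, we assume each metric is computed separately per task, and tie handling in `average_precision` is simplified.

```python
def auroc(scores, labels):
    """Probability a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def average_precision(scores, labels):
    """Mean precision at each positive's rank, sorting by descending score."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, total = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0

scores = [0.1, 0.4, 0.35, 0.8]
labels = [0, 0, 1, 1]
print(auroc(scores, labels))              # 0.75
print(average_precision(scores, labels))  # about 0.8333
```

AUPRC is often more informative than AUROC when positives are rare, which is why reporting both is useful for imbalanced tasks such as clinical toxicity.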

In [7]:
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
# gpus=[0] trains on the first CUDA device; pass gpus=None to train on the CPU
solver = core.Engine(task, train_set, valid_set, test_set, optimizer, batch_size=1024, gpus=[0])
solver.train(num_epoch=100)

20:11:45   Preprocess training set
20:11:45   {'batch_size': 1024,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.PropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': None,
          'metric': ('auprc', 'auroc'),
          'mlp_batch_norm': False,
          'mlp_dropout': 0,
          'model': {'activation': 'relu',
                    'batch_norm': True,
                    'class': 'models.GIN',
                    'concat_hidden': True,
                    'edge_input_dim': None,
                    'eps': 0,
                  

ImportError: DLL load failed while importing torch_ext: The specified module could not be found.