<a href="https://colab.research.google.com/github/MedDataInt/Drug-discovery-from-TorchDrug/blob/main/TorchDrug_Pretraining_and_Finetuning_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Introduction

In many drug discovery tasks, it is costly in both time and money to collect labeled data. As a solution, self-supervised pretraining is introduced to learn molecular representations from massive unlabeled data.

In this tutorial, we will demonstrate how to pretrain a graph neural network on molecules, and how to finetune the model on downstream tasks.

### Manual Steps

0.   Get your own copy of this file via "File > Save a copy in Drive...".
1.   Set the runtime to **GPU** via "Runtime > Change runtime type...".

### Colab Tutorials

#### Quick Start
1. [Basic Usage and Pipeline](https://colab.research.google.com/drive/1Tbnr1Fog_YjkqU1MOhcVLuxqZ4DC-c8-#forceEdit=true&sandboxMode=true)

#### Drug Discovery Tasks
1. [Property Prediction](https://colab.research.google.com/drive/1sb2w3evdEWm-GYo28RksvzJ74p63xHMn?usp=sharing#forceEdit=true&sandboxMode=true)
2. [Pretrained Molecular Representations](https://colab.research.google.com/drive/10faCIVIfln20f2h1oQk2UrXiAMqZKLoW?usp=sharing#forceEdit=true&sandboxMode=true)
3. [De Novo Molecule Design](https://colab.research.google.com/drive/1JEMiMvSBuqCuzzREYpviNZZRVOYsgivA?usp=sharing#forceEdit=true&sandboxMode=true)
4. [Retrosynthesis](https://colab.research.google.com/drive/1IH1hk7K3MaxAEe5m6CFY7Eyej3RuiEL1?usp=sharing#forceEdit=true&sandboxMode=true)
5. [Knowledge Graph Reasoning](https://colab.research.google.com/drive/1-sjqQZhYrGM0HiMuaqXOiqhDNlJi7g_I?usp=sharing#forceEdit=true&sandboxMode=true)

In [None]:
import os
import torch
os.environ["TORCH_VERSION"] = torch.__version__

!pip install torch-scatter torch-cluster -f https://pytorch-geometric.com/whl/torch-$TORCH_VERSION.html
!pip install torchdrug

Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Collecting torchdrug
  Downloading torchdrug-0.1.2.post1-py3-none-any.whl (191 kB)
Collecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
Collecting rdkit-pypi
  Downloading rdkit_pypi-2021.9.5.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.3 MB)
Installing collected packages: rdkit-pypi, ninja, torchdrug
Successfully installed ninja-1.10.2.3 rdkit-pypi-2021.9.5.

# Self-Supervised Pretraining

Pretraining is an effective approach to transfer learning in Graph Neural Networks for graph-level property prediction. Here we focus on pretraining GNNs via different self-supervised strategies. These methods typically construct unsupervised loss functions based on structural information in molecules.

For illustrative purposes, we only use the ClinTox dataset in this tutorial, which is much smaller than the standard pretraining datasets. For real applications, we suggest using larger datasets like ZINC2M.



## InfoGraph

InfoGraph (IG) maximizes the mutual information between graph-level and node-level representations. The model is trained to distinguish whether a node-graph pair comes from a single graph or from two different graphs. The following figure illustrates the high-level idea of InfoGraph.

![infograph.png](https://raw.githubusercontent.com/DeepGraphLearning/torchdrug/master/asset/model/infograph.png)
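
The pair-discrimination idea can be sketched in a few lines of plain Python. This is a simplified illustration, not TorchDrug's implementation: the mean readout, the dot-product score, and the softplus-based loss are assumptions chosen for brevity, and `pair_loss` is a hypothetical helper name.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def mean_readout(node_embeddings):
    # Graph-level embedding as the mean of the node embeddings.
    dim = len(node_embeddings[0])
    n = len(node_embeddings)
    return [sum(v[i] for v in node_embeddings) / n for i in range(dim)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pair_loss(graphs):
    """Average loss over all node-graph pairs.

    graphs: list of graphs, each given as a list of node embedding vectors.
    A node paired with its own graph is a positive; any other pairing is a negative.
    """
    graph_embeddings = [mean_readout(g) for g in graphs]
    total, count = 0.0, 0
    for i, nodes in enumerate(graphs):
        for node in nodes:
            for j, graph_emb in enumerate(graph_embeddings):
                score = dot(node, graph_emb)
                # Push scores up for positive pairs and down for negative pairs.
                total += softplus(-score) if i == j else softplus(score)
                count += 1
    return total / count

# Two graphs whose nodes agree with their own graph embedding.
aligned = [[[2.0, 0.0], [2.0, 0.0]], [[-2.0, 0.0], [-2.0, 0.0]]]
# Two graphs whose node embeddings cancel out, so the graph embedding is uninformative.
mixed = [[[2.0, 0.0], [-2.0, 0.0]], [[2.0, 0.0], [-2.0, 0.0]]]

print(pair_loss(aligned) < pair_loss(mixed))
```

The loss is lower when each node embedding agrees with the embedding of the graph it belongs to and disagrees with the embeddings of other graphs, which is the behaviour the pretraining objective rewards.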

We use GIN as our graph representation model, and wrap it with InfoGraph.


In [None]:
import torch
from torch import nn
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

# Use the "pretrain" featurization for atoms and bonds, matching the later cells.
dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                           bond_feature="pretrain")

gin_model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300],
                       edge_input_dim=dataset.edge_feature_dim,
                       batch_norm=True, readout="mean")
model = models.InfoGraph(gin_model, separate_model=False)

task = tasks.Unsupervised(model)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=10)
solver.save("clintox_gin_infograph.pth")

01:25:29   Downloading http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz to /root/molecule-datasets/clintox.csv.gz
01:25:30   Extracting /root/molecule-datasets/clintox.csv.gz to /root/molecule-datasets/clintox.csv


Loading /root/molecule-datasets/clintox.csv: 100%|██████████| 1485/1485 [00:00<00:00, 36184.89it/s]
Constructing molecules from SMILES: 100%|██████████| 1484/1484 [00:04<00:00, 348.53it/s]


01:25:44   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:25:44   Epoch 0 begin
01:26:10   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:10   graph-node mutual information: -0.064598
01:26:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:11   Epoch 0 end
01:26:11   duration: 36.41 secs
01:26:11   speed: 0.16 batch / sec
01:26:11   ETA: 5.46 mins
01:26:11   max GPU memory: 342.8 MiB
01:26:11   ------------------------------
01:26:11   average graph-node mutual information: -0.266218
01:26:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:11   Epoch 1 begin
01:26:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:11   Epoch 1 end
01:26:12   duration: 0.64 secs
01:26:12   speed: 9.43 batch / sec
01:26:12   ETA: 2.47 mins
01:26:12   max GPU memory: 338.8 MiB
01:26:12   ------------------------------
01:26:12   average graph-node mutual information: 0.186866
01:26:12   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:12   Epoch 2 begin
01:26:12   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:12   Epoch 2 end
01:26:12   duration: 0.62 secs
01:26:12   speed:

## Attribute Masking

The aim of Attribute Masking (AM) is to capture domain knowledge by learning the regularities of node and edge attributes distributed over the graph structure. The high-level idea is to predict atom types in molecular graphs from randomly masked node features.

![attrmasking.png](https://raw.githubusercontent.com/DeepGraphLearning/torchdrug/master/asset/model/attribute_masking.png)
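
The masking step can be illustrated without any graph library. The sketch below works on a plain list of atom symbols; the `MASK` placeholder, the `mask_atoms` helper, and the rounding rule for the number of masked atoms are assumptions made for illustration and are not taken from TorchDrug.

```python
import random

MASK = "<mask>"  # hypothetical placeholder standing in for a masked atom feature

def mask_atoms(atom_types, mask_rate=0.15, seed=0):
    """Hide a random subset of atoms and return (masked input, targets).

    targets maps each masked position to the original atom type,
    which is what the model is trained to predict.
    """
    rng = random.Random(seed)
    num_masked = max(1, round(mask_rate * len(atom_types)))
    positions = rng.sample(range(len(atom_types)), num_masked)
    masked = list(atom_types)
    targets = {}
    for pos in positions:
        targets[pos] = atom_types[pos]
        masked[pos] = MASK
    return masked, targets

# Heavy atoms of aspirin (SMILES: CC(=O)OC1=CC=CC=C1C(=O)O), 13 atoms in total.
atoms = ["C", "C", "O", "O", "C", "C", "C", "C", "C", "C", "C", "O", "O"]
masked, targets = mask_atoms(atoms, mask_rate=0.15)
print(masked)
print(targets)
```

During pretraining, the encoder sees only the masked input and the loss compares its predictions at the masked positions with the stored targets.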

Again, we use GIN as our graph representation model.



In [None]:
import torch
from torch import nn
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                           bond_feature="pretrain")

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[300, 300, 300, 300, 300],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")
task = tasks.AttributeMasking(model, mask_rate=0.15)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=10)
solver.save("clintox_gin_attributemasking.pth")

# Finetune on Labeled Datasets

When GNN pretraining is finished, we can finetune the pretrained GNN on downstream tasks. Here we use the BACE dataset for illustration, which contains 1,513 molecules with binding affinity results for a set of inhibitors of human β-secretase 1 (BACE-1).

First, we download the BACE dataset and split it into training, validation and test sets. Note that we need to set the atom and bond features of the dataset to `pretrain` so that it is compatible with the pretrained model.



In [None]:
from torchdrug import data

dataset = datasets.BACE("~/molecule-datasets/",
                        atom_feature="pretrain", bond_feature="pretrain")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = data.ordered_scaffold_split(dataset, lengths)

01:27:18   Downloading http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv to /root/molecule-datasets/bace.csv


Loading /root/molecule-datasets/bace.csv: 100%|██████████| 1514/1514 [00:00<00:00, 7727.68it/s]
Constructing molecules from SMILES: 100%|██████████| 1513/1513 [00:03<00:00, 463.58it/s]
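
The split sizes follow from simple integer arithmetic. As a worked example, using the 1,513 molecules reported in the loading log:

```python
num_molecules = 1513  # size of BACE as reported in the loading log

train_len = int(0.8 * num_molecules)
valid_len = int(0.1 * num_molecules)
# The test set takes whatever is left, so rounding never drops a molecule.
test_len = num_molecules - train_len - valid_len

print(train_len, valid_len, test_len)  # 1210 151 152
```

Computing the last length as a remainder guarantees that the three lengths sum to the dataset size.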


Then, we define the same model as in the pretraining stage and set up the optimizer and solver for our downstream task. The only difference is that we use the `PropertyPrediction` task to support supervised learning.



In [None]:
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[300, 300, 300, 300, 300],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=256)

01:27:54   Preprocess training set
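
To help read the loss values reported during finetuning, here is a small worked example of binary cross entropy for a single prediction. It assumes the loss is computed from raw logits; `bce_with_logits` is a hypothetical helper written for this illustration, not a TorchDrug function.

```python
import math

def bce_with_logits(logit, label):
    # Binary cross entropy for one prediction, computed from the raw logit.
    # Equivalent to -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(logit).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# An uninformative model outputs logit 0 (probability 0.5) for every molecule.
print(round(bce_with_logits(0.0, 1), 4))  # 0.6931
# A confident, correct prediction is penalised far less.
print(bce_with_logits(4.0, 1) < bce_with_logits(0.0, 1))
```

A loss near ln 2 ≈ 0.693 therefore indicates predictions that are close to chance, and values well below it indicate that the classifier is fitting the training labels.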


Now we can load our pretrained model and finetune it on downstream datasets.



In [None]:
checkpoint = torch.load("clintox_gin_infograph.pth")["model"]
task.load_state_dict(checkpoint, strict=False)

solver.train(num_epoch=100)
solver.evaluate("valid")

01:28:43   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:43   Epoch 0 begin
01:28:43   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:43   binary cross entropy: 0.689819
01:28:44   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:44   Epoch 0 end
01:28:44   duration: 49.59 secs
01:28:44   speed: 0.10 batch / sec
01:28:44   ETA: 1.36 hours
01:28:44   max GPU memory: 300.2 MiB
01:28:44   ------------------------------
01:28:44   average binary cross entropy: 0.611645
01:28:44   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:44   Epoch 1 begin
01:28:44   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:44   Epoch 1 end
01:28:44   duration: 0.52 secs
01:28:44   speed: 9.65 batch / sec
01:28:44   ETA: 40.92 mins
01:28:44   max GPU memory: 300.4 MiB
01:28:44   ------------------------------
01:28:45   average binary cross entropy: 0.530311
01:28:45   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:45   Epoch 2 begin
01:28:45   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:28:45   Epoch 2 end
01:28:45   duration: 0.49 secs
01:28:45   speed: 10.27 batch / sec
01:28:45

{'auprc [Class]': tensor(0.8905, device='cuda:0'),
 'auroc [Class]': tensor(0.6084, device='cuda:0')}
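
Because the checkpoint is loaded with `strict=False`, parameter names that do not match are skipped silently instead of raising an error, so it is worth checking which weights were actually restored. The sketch below compares two sets of parameter names in plain Python. The names are hypothetical and `compare_keys` is an illustrative helper, not part of TorchDrug or PyTorch.

```python
def compare_keys(checkpoint_keys, model_keys):
    """Return (matched, missing, unexpected) parameter names.

    missing: expected by the model but absent from the checkpoint.
    unexpected: present in the checkpoint but unknown to the model.
    """
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    matched = sorted(checkpoint_keys & model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return matched, missing, unexpected

# Hypothetical parameter names, for illustration only.
checkpoint_keys = ["model.layers.0.weight", "model.layers.1.weight", "head.weight"]
model_keys = ["model.layers.0.weight", "model.layers.1.weight", "mlp.weight"]

matched, missing, unexpected = compare_keys(checkpoint_keys, model_keys)
print("matched:", matched)
print("missing:", missing)
print("unexpected:", unexpected)
```

In PyTorch, `load_state_dict(..., strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields that carry the same information, so printing its return value is a quick way to confirm that the encoder weights were loaded.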