# DrugBAN Running Demo | [Paper](https://doi.org/10.1038/s42256-022-00605-1) | [Repo](https://github.com/peizhenbai/DrugBAN)

[Open In Colab](https://colab.research.google.com/github/pz-white/DrugBAN/blob/main/drugban_demo.ipynb) (click `Runtime` → `Run all` (Ctrl+F9))

This is a code demo of the DrugBAN framework for drug-target interaction prediction. Running the whole pipeline takes about 3 minutes.

## Setup

The first few code cells set up the notebook execution environment. They check whether the notebook is running on Google Colab and, if so, install the required packages.
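
The Colab check in the setup cell relies on `get_ipython()`, which exists only inside an IPython session. As a hedged alternative sketch (not part of the original notebook), the hypothetical helper `running_on_colab` below looks for `google.colab` in `sys.modules`, on the assumption that the Colab runtime has already imported that package before user code runs.

```python
import sys


def running_on_colab(modules=None):
    """Return True if the `google.colab` package is already loaded.

    `modules` defaults to `sys.modules`; it is a parameter only so the
    check can be exercised without actually being on Colab.
    """
    if modules is None:
        modules = sys.modules
    return "google.colab" in modules


print("Running on Colab" if running_on_colab() else "Not running on Colab")
```

Unlike the `get_ipython()` variant, this version also runs in a plain Python script, where it simply reports that Colab is not in use.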

In [None]:
# if 'google.colab' in str(get_ipython()):
#     print('Running on CoLab')
#     !pip uninstall --yes yellowbrick
#     !pip install -U -q psutil
#     !pip install dgl dgllife
#     !pip install rdkit-pypi
#     !pip install PrettyTable yacs
#     !git clone https://github.com/pz-white/DrugBAN.git
#     %cd DrugBAN
# else:
#     print('Not running on CoLab')

## Import Required Modules

In [1]:
from models import DrugBAN
from time import time
from utils import set_seed, graph_collate_func, mkdir
from configs import get_cfg_defaults
from dataloader import DTIDataset, MultiDataLoader
from torch.utils.data import DataLoader
from trainer import Trainer
from domain_adaptator import Discriminator
import torch
import argparse
import warnings, os
import pandas as pd
import pickle
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem

## Configuration

The customized configuration used in this demo is stored in `configs/DrugBAN_Demo.yaml`. This file overrides the defaults in `configs.py` wherever it specifies a value.
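
The override behaviour can be pictured with plain dictionaries. This is only a sketch of the idea: `merge_overrides` is a hypothetical helper, the default and override values are made up for illustration, and the real yacs `CfgNode` used by `get_cfg_defaults()` performs additional validation that is omitted here.

```python
import copy


def merge_overrides(defaults, overrides):
    """Return a copy of `defaults` with values from `overrides` applied.

    Nested dictionaries are merged key by key; any key absent from
    `overrides` keeps its default value.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical defaults and YAML overrides, for illustration only.
defaults = {"SOLVER": {"MAX_EPOCH": 100, "BATCH_SIZE": 64, "SEED": 2048}}
demo_yaml = {"SOLVER": {"MAX_EPOCH": 10, "SEED": 42}}

cfg_demo = merge_overrides(defaults, demo_yaml)
print(cfg_demo)
# {'SOLVER': {'MAX_EPOCH': 10, 'BATCH_SIZE': 64, 'SEED': 42}}
```

`BATCH_SIZE` keeps its default because the override does not mention it, which is the "where a value is specified" rule in miniature.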

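To save time when running the whole pipeline, the demo samples small subsets of the original BindingDB dataset, located at `datasets/bindingdb_sample`. Note that the cell below sets `data` to `binding_db_refined/random`, so it reads from `datasets/binding_db_refined/random` instead.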

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = "./configs/DrugBAN_Demo.yaml"
data = "binding_db_refined/random"
comet_support = False

cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_path)
cfg.freeze()

torch.cuda.empty_cache()
warnings.filterwarnings("ignore")
set_seed(cfg.SOLVER.SEED)
mkdir(cfg.RESULT.OUTPUT_DIR)
experiment = None
print(f"Config yaml: {cfg_path}")
print(f"Running on: {device}")
print("Hyperparameters:")
dict(cfg)

Config yaml: ./configs/DrugBAN_Demo.yaml
Running on: cpu
Hyperparameters:


{'DRUG': CfgNode({'NODE_IN_FEATS': 75, 'PADDING': True, 'HIDDEN_LAYERS': [128, 128, 128], 'NODE_IN_EMBEDDING': 128, 'MAX_NODES': 290}),
 'PROTEIN': CfgNode({'NUM_FILTERS': [128, 128, 128], 'KERNEL_SIZE': [3, 6, 9], 'EMBEDDING_DIM': 128, 'PADDING': True}),
 'BCN': CfgNode({'HEADS': 2}),
 'DECODER': CfgNode({'NAME': 'MLP', 'IN_DIM': 256, 'HIDDEN_DIM': 512, 'OUT_DIM': 128, 'BINARY': 1}),
 'SOLVER': CfgNode({'MAX_EPOCH': 10, 'BATCH_SIZE': 64, 'NUM_WORKERS': 0, 'LR': 5e-05, 'DA_LR': 0.001, 'SEED': 42}),
 'RESULT': CfgNode({'OUTPUT_DIR': './result/demo', 'SAVE_MODEL': True}),
 'DA': CfgNode({'TASK': False, 'METHOD': 'CDAN', 'USE': False, 'INIT_EPOCH': 10, 'LAMB_DA': 1, 'RANDOM_LAYER': False, 'ORIGINAL_RANDOM': False, 'RANDOM_DIM': None, 'USE_ENTROPY': True}),
 'COMET': CfgNode({'WORKSPACE': 'pz-white', 'PROJECT_NAME': 'DrugBAN', 'USE': False, 'TAG': None})}

## Data Loader

The train, validation, and test datasets are built with the `DTIDataset` class and batched with PyTorch's `DataLoader`.

In [3]:
# # Add coordinates to the small molecules in df and save them
# def add_xyz(df):
#     bad_mol = []
#     coords_list = []
#     list_ids = df.index.values
#     print(len(list_ids))
#     for index in list_ids:
#         if index%100 == 0:
#             print(index)
#         try:
#             mol_smiles = df.iloc[index]['SMILES']
#             # Add coordinates to the molecular graph
#             mol = Chem.MolFromSmiles(mol_smiles)
#             mol = Chem.AddHs(mol)
#             # Compute the coordinates
#             AllChem.EmbedMolecule(mol)
#             AllChem.UFFOptimizeMolecule(mol)
#             # Get the optimized molecular coordinates
#             coords = mol.GetConformer().GetPositions()
#             coords = torch.tensor(coords, dtype=torch.float32)
#         except:
#             bad_mol.append(index)
#             coords = None
#         coords_list.append(coords)
    
#     df['xyz'] = coords_list
#     # Drop the rows at the given indices
#     df = df.drop(bad_mol)
#     # Reset the index
#     df = df.reset_index(drop=True)
#     return bad_mol, df



In [4]:
# # Add coordinates to the small molecules in the binding_db dataset and save them
# bad_train, train_refine = add_xyz(df_train)
# with open("/Users/caozhiwei/RDKit/DrugBAN-main/datasets/binding_db_refined/random/train.pkl",'wb') as file:
#     pickle.dump(train_refine,file)

# bad_val, val_refine = add_xyz(df_val)
# with open("/Users/caozhiwei/RDKit/DrugBAN-main/datasets/binding_db_refined/random/val.pkl",'wb') as file:
#     pickle.dump(val_refine,file)

# bad_test, test_refine = add_xyz(df_test)
# with open("/Users/caozhiwei/RDKit/DrugBAN-main/datasets/binding_db_refined/random/test.pkl",'wb') as file:
#     pickle.dump(test_refine,file)

In [5]:
dataFolder = f'./datasets/{data}'

train_path = os.path.join(dataFolder, 'train.pkl')
val_path = os.path.join(dataFolder, "val.pkl")
test_path = os.path.join(dataFolder, "test.pkl")

with open(train_path, 'rb') as file:
    df_train = pickle.load(file)
with open(val_path, 'rb') as file:
    df_val = pickle.load(file)
with open(test_path, 'rb') as file:
    df_test = pickle.load(file)
print(len(df_train))

# Drop six hard-coded training rows, then renumber the index from 0.
df_train = df_train.drop([320, 2287, 6926, 11398, 11480, 29670])
df_train = df_train.reset_index(drop=True)
print(len(df_train))

34305
34299


In [6]:
train_dataset = DTIDataset(df_train.index.values, df_train)
val_dataset = DTIDataset(df_val.index.values, df_val)
test_dataset = DTIDataset(df_test.index.values, df_test)

params = {
    'batch_size': cfg.SOLVER.BATCH_SIZE,
    'shuffle': True,
    'num_workers': cfg.SOLVER.NUM_WORKERS,
    'drop_last': True,
    'collate_fn': graph_collate_func,
}
training_generator = DataLoader(train_dataset, **params)
# Evaluation loaders keep the sample order and every sample.
params['shuffle'] = False
params['drop_last'] = False
val_generator = DataLoader(val_dataset, **params)
test_generator = DataLoader(test_dataset, **params)
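
To make the effect of `drop_last` concrete, the sketch below works out how many batches the training loader yields per epoch. It uses the 34299 training rows printed above and `BATCH_SIZE: 64` from the configuration dump, and assumes `DTIDataset` exposes exactly one item per row. `num_batches` is a hypothetical helper written for this illustration.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches a loader yields per epoch for the given settings."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


train_rows = 34299  # training rows left after dropping six entries
batch_size = 64     # SOLVER.BATCH_SIZE from the configuration dump

print(num_batches(train_rows, batch_size, drop_last=True))   # 535
print(num_batches(train_rows, batch_size, drop_last=False))  # 536
print(train_rows % batch_size)                               # 59
```

Under these assumptions, 59 training samples are skipped in each epoch. Because the training loader shuffles, a different set of 59 is left out each time.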

In [None]:
# from models import swish, Interaction, FlexibleMol
# drug_extractor = FlexibleMol(node_channels=74, 
#                              num_radial=6, 
#                              num_spherical=3, 
#                              cutoff=3.0, 
#                              hidden_channels=128, 
#                              middle_channels=256,
#                              num_gnn=3, 
#                              num_lin=3, 
#                              num_res=4, 
#                              act=swish)

In [None]:
# # Find the indices of the small molecules that cause errors
# for i in range(len(train_dataset)):
#     if i % 1000 == 0:
#         print('now is {}'.format(i))
#     v_d = drug_extractor(train_dataset[i][0], train_dataset[i][1])
#     has_nan = torch.isnan(v_d)
#     if torch.any(has_nan):
#         print(i)

## Setup Model and Optimizer

Here, we use the previously defined configuration to set up the model and optimizer we will subsequently train.


In [7]:
model = DrugBAN(**cfg).to(device)
opt = torch.optim.Adam(model.parameters(), lr=cfg.SOLVER.LR)
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

## Train the Model and Test the Optimized Model

Optimize model parameters using the trainer and check test performance.

In [None]:
trainer = Trainer(model, opt, device, training_generator, val_generator, test_generator, opt_da=None, discriminator=None, experiment=experiment, **cfg)
result = trainer.train()
with open(os.path.join(cfg.RESULT.OUTPUT_DIR, "model_architecture.txt"), "w") as wf:
    wf.write(str(model))
print(f"Directory for saving result: {cfg.RESULT.OUTPUT_DIR}")
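
The trainer reports test metrics for a "best" model rather than for the final epoch. The notebook does not say how that model is chosen. The sketch below assumes the choice is the epoch with the highest validation AUROC, which is an assumption and not taken from the `Trainer` source. `best_epoch` and the score values are hypothetical.

```python
def best_epoch(val_scores):
    """Return (epoch, score) for the highest validation score.

    Epochs are numbered from 1; ties go to the earliest epoch.
    """
    best_idx = max(range(len(val_scores)), key=lambda i: val_scores[i])
    return best_idx + 1, val_scores[best_idx]


# Hypothetical validation AUROC values for a three-epoch run.
val_auroc = [0.65, 0.71, 0.69]
epoch, score = best_epoch(val_auroc)
print(f"Best epoch: {epoch} (validation AUROC {score})")
# Best epoch: 2 (validation AUROC 0.71)
```

Selecting on validation data and evaluating once on the test set keeps the test metrics from influencing model selection.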

## Expected Output

Awesome! You have completed all demo steps and should see output like the following. The exact numbers may differ because the environment setup on Colab changes over time.

```
Training at Epoch 1 with training loss 0.7483742804754347
Validation at Epoch 1 with validation loss 0.6943950802087784  AUROC 0.6544117647058824 AUPRC 0.44206349206349205
Test at Best Model of Epoch 1 with test loss 0.6565468311309814  AUROC 0.4245614035087719 AUPRC 0.4018830588082055 Sensitivity 0.0 Specificity 1.0 Accuracy 0.3877551020408163 Thred_optim 0.42230069637298584
Directory for saving result: ./result/demo
```
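
In the sample log, a sensitivity of 0.0 together with a specificity of 1.0 means that no test sample was predicted positive, so accuracy reduces to the fraction of negative samples. The sketch below reproduces that pattern. The label counts are hypothetical, chosen because 19/49 matches the reported accuracy, and the rule "predict positive when score >= threshold" is an assumption about how `Thred_optim` is applied.

```python
def binary_metrics(labels, scores, threshold):
    """Sensitivity, specificity and accuracy after thresholding scores.

    A sample is predicted positive when its score is >= threshold.
    """
    preds = [1 if s >= threshold else 0 for s in scores]
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    accuracy = (tp + tn) / len(labels)
    return sensitivity, specificity, accuracy


# Hypothetical test set: 30 positives and 19 negatives, with every score
# below the threshold, so every sample is predicted negative.
labels = [1] * 30 + [0] * 19
scores = [0.40] * 49

sens, spec, acc = binary_metrics(labels, scores, threshold=0.4223)
print(sens, spec, round(acc, 4))  # accuracy is 19/49, about 0.3878
```

A result like this after a single epoch usually says more about the threshold and the short training run than about the model's ranking ability, which is what AUROC and AUPRC measure.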

Finally, the results are saved in the Colab temporary directory `DrugBAN/result/demo`. You can access it by clicking the `Files` tab on the left side of the Colab interface.