### First training stage:  Train CADTI

In [None]:
import pandas as pd
import torch
import os
from torch.utils.data import random_split
from torch.utils.data import Dataset,DataLoader
from utils.Func import extract_esm_feature,seq2fasta,extract_cp_feature,filter_invalid_smiles
from dgl.data.utils import load_graphs
from utils.TAVC_dataset import TAVC_Dataset_Train
from utils.collator import Collator_TAVC_Train
from utils.TAVC_trainer import TAVC_Trainer
from utils.scheduler import PolynomialDecayLR
from torch.optim import Adam
from torch.nn import  BCEWithLogitsLoss
from utils.model.KPGT_v2 import *
from utils.model.DeepAVC import *
from utils.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES,VIRTUAL_ATOM_FEATURE_PLACEHOLDER, VIRTUAL_BOND_FEATURE_PLACEHOLDER

In [40]:
args = {
    'config':'base',
    'd_fps': 512,
    'd_mds': 200,
    'dropout':0,
    'weight_decay':1e-6,
    'n_tasks':1,
    'lr': 1e-4,
    'kpgt_model_path':'/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/KPGT/KPGT.pth',
    'cp_feature_dir': '/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/CPI_dataset/demo/cp_feature',
    'pro_feature_dir': '/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/CPI_dataset/demo/pro_feature',
    'CADTI_model_path': '/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/DeepAVC/CADTI.pt',
    'n_epochs':20, 
    'device':'cuda:3',
    'random_seed': 42,
    'batch_size':32,
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio':0.1,
    'MLP_layer_num':2,
    'MLP_hidden_dim':256}
vocab = Vocab(N_ATOM_TYPES, N_BOND_TYPES)

In [3]:
# load CPI dataset
CPI_dataset = pd.read_pickle('/home2/kangboming/kangboming/workspace2/AVC_paper/data/T_AVC/cpi_data/cpi_dataset.pkl')

In [4]:
CPI_dataset_demo = CPI_dataset.sample(2000)

In [None]:
# Extract compound initial feature by RDKit(if necessary)
smiles_list = CPI_dataset_demo['SMILES'].to_list()
# filter compound with invalid smiles 
valid_smiles, invalid_smiles = filter_invalid_smiles(smiles_list)
extract_cp_feature(smiles_list = valid_smiles, 
                   output_dir = args['cp_feature_dir'],
                   num_workers=32)

In [None]:
# Extract protein initial feature by ESM-2 (if necessary)
pro_seq_list = list(CPI_dataset_demo['sequence'].unique())
# Transform protein sequences into the fasta format
seq2fasta(seq_list=pro_seq_list, 
          save_dir=args['pro_feature_dir'])

extract_esm_feature(
    model_location = '/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/ESM/esm2_t33_650M_UR50D.pt',
    fasta_file = os.path.join(args['pro_feature_dir'], 'target_seq.fasta'),
    output_dir = args['pro_feature_dir'],
    toks_per_batch = 10000,
    repr_layers = [-1],
    include=['per_tok'],
    device='cuda:3',
    truncation_seq_length = 1024)

In [7]:
target_id_list = [ f'Target_{i+1}' for i in range(len(CPI_dataset_demo['sequence'].unique()))]

In [8]:
seq2id_dict = dict(zip( CPI_dataset_demo['sequence'].unique(),target_id_list))

In [9]:
CPI_dataset_demo['target_idx'] = CPI_dataset['sequence'].map(seq2id_dict)

In [10]:
### load compound initial feature
graphs, label_dict = load_graphs(os.path.join(args['cp_feature_dir'], 'cp_graphs.pkl'))
fps = torch.load(os.path.join(args['cp_feature_dir'], 'cp_fps.pt'))
mds = torch.load(os.path.join(args['cp_feature_dir'], 'cp_mds.pt'))

In [11]:
assert len(graphs) == len(fps) == len(mds)

In [12]:
### load protein initial feature
pro_feature_dict = torch.load(os.path.join(args['pro_feature_dir'],'esm_feature.pt'),map_location=args['device'])

In [13]:
# Build dataset
CPI_dataset = TAVC_Dataset_Train(smiles_list = CPI_dataset_demo['SMILES'].to_list(),
                          target_seq_list=CPI_dataset_demo['target_idx'].to_list(),
                          target_feature_dict=pro_feature_dict,
                          label_list=CPI_dataset_demo['label'].to_list(),
                          graphs=graphs,
                          fps=fps,
                          mds=mds)

In [None]:
### data split
train_ratio = args['train_ratio']
val_ratio = args['val_ratio']
dataset_size = len(CPI_dataset)
train_size = int(train_ratio * dataset_size) 
val_size = int(val_ratio * dataset_size)   
test_size = dataset_size - train_size - val_size 

torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(CPI_dataset, [train_size, val_size, test_size])
print(f'Train size:{len(train_dataset)}\nValidation size:{len(val_dataset)}\nTest size:{len(test_dataset)}')

### build dataloader 
config = config_dict[args['config']]
collator = Collator_TAVC_Train(config['path_length'])

train_loader = DataLoader(train_dataset, 
                          batch_size=args['batch_size'], 
                          shuffle=True,  
                          drop_last=False, 
                          collate_fn=collator)
val_loader = DataLoader(val_dataset, batch_size=args['batch_size'], 
                        shuffle=False,
                          drop_last=False, 
                          collate_fn=collator)
test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], 
                         shuffle=False,  
                         drop_last=False, 
                         collate_fn=collator)


In [15]:
kpgt_model = LiGhTPredictor(
    d_node_feats=config['d_node_feats'],
    d_edge_feats=config['d_edge_feats'],
    d_g_feats=config['d_g_feats'],
    d_fp_feats=args['d_fps'],
    d_md_feats=args['d_mds'],
    d_hpath_ratio=config['d_hpath_ratio'],
    n_mol_layers=config['n_mol_layers'],
    path_length=config['path_length'],
    n_heads=config['n_heads'],
    n_ffn_dense_layers=config['n_ffn_dense_layers'],
    input_drop=0,
    attn_drop=args['dropout'],
    feat_drop=args['dropout'],
    n_node_types=vocab.vocab_size).to(args['device'])

In [16]:
### 加载预训练权重
kpgt_model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(args['kpgt_model_path'],map_location=args['device']).items()})
# 删除没用的模型结构
del kpgt_model.md_predictor
del kpgt_model.fp_predictor
del kpgt_model.node_predictor

In [17]:
### Model Initialization
CADTI_model = CADTI_Finetune(
d_model=256,
n_heads=8,
num_layers=1,
kpgt_model=kpgt_model,
smiles_dim=768,
protein_dim=1280,
kpgt_features_dim=2304,
mlp_hidden_dim=256,
num_classes=1,
dropout=0,
return_attn=True).to(args['device'])

In [None]:
print("model have {}M parameters in total that require gradients".format(
sum(p.numel() for p in CADTI_model.parameters() if p.requires_grad) / 1e6))

In [19]:
optimizer = Adam(CADTI_model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
lr_scheduler = PolynomialDecayLR(optimizer, 
                                warmup_updates=args['n_epochs']*(len(train_loader))//100, 
                                tot_updates=args['n_epochs']*len(train_loader),
                                lr=args['lr'], 
                                end_lr=1e-5, 
                                power=1)

loss_fn = BCEWithLogitsLoss(reduction='mean')

In [20]:
trainer = TAVC_Trainer(args=args, 
                        optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    loss_fn=loss_fn,
                    device=args['device'],
                    model_name='CADTI')

In [None]:
performance_res_df = trainer.fit(model=CADTI_model,
                    train_loader=train_loader,
                    val_loader=val_loader)

In [22]:
performance_res_df

Unnamed: 0,acc,recall,prec,f1,auroc,auprc,epoch,set
0,0.475,0.892473,0.466292,0.612546,0.579138,0.558452,1,val
1,0.46,0.55914,0.436975,0.490566,0.455834,0.461419,2,val
2,0.515,0.258065,0.461538,0.331034,0.496433,0.484353,3,val
3,0.5,0.591398,0.470085,0.52381,0.500553,0.455568,4,val
4,0.545,0.580645,0.509434,0.542714,0.55924,0.491858,5,val
5,0.555,0.612903,0.518182,0.561576,0.553311,0.492143,6,val
6,0.54,0.537634,0.505051,0.520833,0.554216,0.501016,7,val
7,0.505,0.451613,0.466667,0.459016,0.530901,0.475482,8,val
8,0.495,0.462366,0.457447,0.459893,0.534218,0.499018,9,val
9,0.5,0.419355,0.458824,0.438202,0.534017,0.498821,10,val


### Second training stage: Train DeepTAVC

In [25]:
args = {
    'config':'base',
    'd_fps': 512,
    'd_mds': 200,
    'dropout':0,
    'weight_decay':1e-6,
    'n_tasks':1,
    'lr': 1e-4,
    'kpgt_model_path':'/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/KPGT/KPGT.pth',
    'cp_feature_dir': '/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/TAVC_dataset/cp_feature',
    'pro_feature_dir': '/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/TAVC_dataset/pro_feature',
    'n_epochs':20, 
    'device':'cuda:3',
    'random_seed': 42,
    'batch_size':32,
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio':0.1,
    'MLP_layer_num':2,
    'MLP_hidden_dim':256}

In [23]:
TAVC_data = pd.read_csv('/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/TAVC_dataset/DeepTAVC_dataset.csv')

In [24]:
TAVC_data.head()

Unnamed: 0,canonical_smiles,sequence,pchembl_value,avd_label
0,B.CP(c1ccccc1)c1ccc(O)cc1,MTMTLHTKASGMALLHQIQGNELEPLNRPQLKIPLERPLGEVYLDS...,5.01,0
1,B.Oc1ccc(P(c2ccccc2)c2ccccc2)cc1,MTMTLHTKASGMALLHQIQGNELEPLNRPQLKIPLERPLGEVYLDS...,4.92,0
2,BP(=O)(COCCn1cnc2c(N)ncnc21)OP(=O)(O)OP(=O)(O)O,PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGK...,5.0075,0
3,BP(=O)(CO[C@H](C)Cn1cnc2c(N)ncnc21)OP(=O)(O)OP...,PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGK...,4.6325,0
4,BP(=O)(OC[C@@H]1C=C[C@H](n2cc(C)c(=O)[nH]c2=O)...,PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGK...,7.35,1


In [26]:
TAVC_dataset_demo = TAVC_data.sample(2000)

In [28]:
# Extract compound initial feature by RDKit(if necessary)
smiles_list = TAVC_dataset_demo['canonical_smiles'].to_list()
# filter compound with invalid smiles 
valid_smiles, invalid_smiles = filter_invalid_smiles(smiles_list)
extract_cp_feature(smiles_list = valid_smiles, 
                   output_dir = args['cp_feature_dir'],
                   num_workers=32)

100%|██████████| 2000/2000 [00:00<00:00, 5777.07it/s]

extracting graphs



[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
[Parallel(n_jobs=32)]: Done 171 tasks      | elapsed:    7.2s
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorchUsing backend: pytorch


Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

[Parallel(n_jobs=32)]: Done 1657 tasks      | elapsed:   10.3s
[Parallel(n_jobs=32)]: Done 1937 out of 2000 | elapsed:   10.9s remaining:    0.4s
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorchUsing backend: pytorchUsing backend: pytorch



Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
U

extracting fingerprints
extracting molecular descriptors


'Done!'

In [29]:
# Extract protein initial feature by ESM-2 (if necessary)
pro_seq_list = list(TAVC_dataset_demo['sequence'].unique())
# Transform protein sequences into the fasta format
seq2fasta(seq_list=pro_seq_list, 
          save_dir=args['pro_feature_dir'])

extract_esm_feature(
    model_location = '/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/ESM/esm2_t33_650M_UR50D.pt',
    fasta_file = os.path.join(args['pro_feature_dir'], 'target_seq.fasta'),
    output_dir = args['pro_feature_dir'],
    toks_per_batch = 10000,
    repr_layers = [-1],
    include=['per_tok'],
    device='cuda:3',
    truncation_seq_length = 1024)

Transferred model to GPUs
Read /home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC/TAVC_dataset/pro_feature/target_seq.fasta with 134 sequences
Processing 1 of 15 batches (28 sequences)
Processing 2 of 15 batches (21 sequences)
Processing 3 of 15 batches (17 sequences)
Processing 4 of 15 batches (14 sequences)
Processing 5 of 15 batches (12 sequences)
Processing 6 of 15 batches (10 sequences)
Processing 7 of 15 batches (8 sequences)
Processing 8 of 15 batches (7 sequences)
Processing 9 of 15 batches (5 sequences)
Processing 10 of 15 batches (3 sequences)
Processing 11 of 15 batches (3 sequences)
Processing 12 of 15 batches (2 sequences)
Processing 13 of 15 batches (2 sequences)
Processing 14 of 15 batches (1 sequences)
Processing 15 of 15 batches (1 sequences)


'Done!'

In [30]:
target_id_list = [ f'Target_{i+1}' for i in range(len(TAVC_dataset_demo['sequence'].unique()))]
seq2id_dict = dict(zip( TAVC_dataset_demo['sequence'].unique(),target_id_list))
TAVC_dataset_demo['target_idx'] = TAVC_dataset_demo['sequence'].map(seq2id_dict)

In [32]:
### load compound initial feature
graphs, label_dict = load_graphs(os.path.join(args['cp_feature_dir'], 'cp_graphs.pkl'))
fps = torch.load(os.path.join(args['cp_feature_dir'], 'cp_fps.pt'))
mds = torch.load(os.path.join(args['cp_feature_dir'], 'cp_mds.pt'))

In [33]:
assert len(graphs) == len(fps) == len(mds)

In [34]:
### load protein initial feature
pro_feature_dict = torch.load(os.path.join(args['pro_feature_dir'],'esm_feature.pt'),map_location=args['device'])

In [35]:
# Build dataset
TAVC_dataset = TAVC_Dataset_Train(smiles_list = TAVC_dataset_demo['canonical_smiles'].to_list(),
                          target_seq_list=TAVC_dataset_demo['target_idx'].to_list(),
                          target_feature_dict=pro_feature_dict,
                          label_list=TAVC_dataset_demo['avd_label'].to_list(),
                          graphs=graphs,
                          fps=fps,
                          mds=mds)

In [36]:
### data split
train_ratio = args['train_ratio']
val_ratio = args['val_ratio']
dataset_size = len(TAVC_dataset)
train_size = int(train_ratio * dataset_size) 
val_size = int(val_ratio * dataset_size)   
test_size = dataset_size - train_size - val_size 

torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(TAVC_dataset, [train_size, val_size, test_size])
print(f'Train size:{len(train_dataset)}\nValidation size:{len(val_dataset)}\nTest size:{len(test_dataset)}')

### build dataloader 
config = config_dict[args['config']]
collator = Collator_TAVC_Train(config['path_length'])

train_loader = DataLoader(train_dataset, 
                          batch_size=args['batch_size'], 
                          shuffle=True,  
                          drop_last=False, 
                          collate_fn=collator)
val_loader = DataLoader(val_dataset, batch_size=args['batch_size'], 
                        shuffle=False,
                          drop_last=False, 
                          collate_fn=collator)
test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], 
                         shuffle=False,  
                         drop_last=False, 
                         collate_fn=collator)


Train size:1600
Validation size:200
Test size:200


In [43]:
### Model Initialization
DeepTAVC = CADTI_Finetune(
d_model=256,
n_heads=8,
num_layers=1,
kpgt_model=kpgt_model,
smiles_dim=768,
protein_dim=1280,
kpgt_features_dim=2304,
mlp_hidden_dim=256,
num_classes=1,
dropout=0,
return_attn=True).to(args['device'])
print("model have {}M parameters in total that require gradients".format(
sum(p.numel() for p in DeepTAVC.parameters() if p.requires_grad) / 1e6))

model have 92.973273M parameters in total that require gradients


In [44]:
DeepTAVC.load_state_dict(torch.load(args['CADTI_model_path'],map_location=args['device'])) # 加载预训练权重

<All keys matched successfully>

In [45]:
trainer = TAVC_Trainer(args=args, 
                        optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    loss_fn=loss_fn,
                    device=args['device'],
                    model_name='DeepTAVC')

In [46]:
performance_res_df = trainer.fit(model=DeepTAVC,
                    train_loader=train_loader,
                    val_loader=val_loader)

  4%|▍         | 2/50 [00:00<00:09,  5.00it/s]

[Batch 1],3.121765613555908


100%|██████████| 50/50 [00:19<00:00,  2.53it/s]
100%|██████████| 7/7 [00:00<00:00,  7.26it/s]


[Epoch1], val_auroc: 0.771, val_auprc: 0.648 


  2%|▏         | 1/50 [00:00<00:08,  5.56it/s]

[Batch 1],0.5072362422943115


100%|██████████| 50/50 [00:16<00:00,  2.95it/s]
100%|██████████| 7/7 [00:00<00:00, 15.97it/s]


[Epoch2], val_auroc: 0.783, val_auprc: 0.674 


  2%|▏         | 1/50 [00:00<00:08,  5.62it/s]

[Batch 1],0.44070374965667725


100%|██████████| 50/50 [00:12<00:00,  4.08it/s]
100%|██████████| 7/7 [00:00<00:00,  7.78it/s]


[Epoch3], val_auroc: 0.787, val_auprc: 0.672 


  2%|▏         | 1/50 [00:00<00:24,  2.00it/s]

[Batch 1],0.18070515990257263


100%|██████████| 50/50 [00:16<00:00,  3.11it/s]
100%|██████████| 7/7 [00:01<00:00,  6.74it/s]


[Epoch4], val_auroc: 0.786, val_auprc: 0.659 


  2%|▏         | 1/50 [00:00<00:25,  1.92it/s]

[Batch 1],0.14406126737594604


100%|██████████| 50/50 [00:16<00:00,  3.12it/s]
100%|██████████| 7/7 [00:00<00:00, 17.00it/s]


[Epoch5], val_auroc: 0.786, val_auprc: 0.657 


  2%|▏         | 1/50 [00:00<00:08,  5.59it/s]

[Batch 1],0.11063903570175171


100%|██████████| 50/50 [00:12<00:00,  4.10it/s]
100%|██████████| 7/7 [00:00<00:00, 19.75it/s]


[Epoch6], val_auroc: 0.784, val_auprc: 0.662 


  2%|▏         | 1/50 [00:00<00:07,  6.87it/s]

[Batch 1],0.0610453300178051


100%|██████████| 50/50 [00:07<00:00,  6.58it/s]
100%|██████████| 7/7 [00:00<00:00, 19.94it/s]


[Epoch7], val_auroc: 0.784, val_auprc: 0.655 


  2%|▏         | 1/50 [00:00<00:07,  6.80it/s]

[Batch 1],0.05921315401792526


100%|██████████| 50/50 [00:07<00:00,  6.30it/s]
100%|██████████| 7/7 [00:00<00:00, 19.72it/s]


[Epoch8], val_auroc: 0.783, val_auprc: 0.650 


  2%|▏         | 1/50 [00:00<00:07,  6.76it/s]

[Batch 1],0.07419779896736145


100%|██████████| 50/50 [00:13<00:00,  3.79it/s]
100%|██████████| 7/7 [00:00<00:00, 16.97it/s]


[Epoch9], val_auroc: 0.783, val_auprc: 0.651 


  2%|▏         | 1/50 [00:00<00:08,  5.91it/s]

[Batch 1],0.04547952860593796


100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
100%|██████████| 7/7 [00:01<00:00,  6.95it/s]


[Epoch10], val_auroc: 0.782, val_auprc: 0.649 


  2%|▏         | 1/50 [00:00<00:24,  1.99it/s]

[Batch 1],0.015752270817756653


100%|██████████| 50/50 [00:18<00:00,  2.74it/s]
100%|██████████| 7/7 [00:01<00:00,  6.73it/s]


[Epoch11], val_auroc: 0.781, val_auprc: 0.644 


  4%|▍         | 2/50 [00:00<00:10,  4.46it/s]

[Batch 1],0.03219792991876602


100%|██████████| 50/50 [00:12<00:00,  4.09it/s]
100%|██████████| 7/7 [00:00<00:00, 20.05it/s]


[Epoch12], val_auroc: 0.777, val_auprc: 0.640 


  2%|▏         | 1/50 [00:00<00:20,  2.45it/s]

[Batch 1],0.024051669985055923


100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
100%|██████████| 7/7 [00:00<00:00, 15.15it/s]


[Epoch13], val_auroc: 0.778, val_auprc: 0.640 


  2%|▏         | 1/50 [00:00<00:08,  5.55it/s]

[Batch 1],0.01233658753335476


100%|██████████| 50/50 [00:17<00:00,  2.94it/s]
100%|██████████| 7/7 [00:00<00:00,  9.46it/s]


[Epoch14], val_auroc: 0.778, val_auprc: 0.642 


  2%|▏         | 1/50 [00:00<00:08,  6.04it/s]

[Batch 1],0.0046801394782960415


100%|██████████| 50/50 [00:08<00:00,  6.02it/s]
100%|██████████| 7/7 [00:00<00:00, 16.95it/s]


[Epoch15], val_auroc: 0.777, val_auprc: 0.637 


  2%|▏         | 1/50 [00:00<00:08,  5.74it/s]

[Batch 1],0.005778110586106777


100%|██████████| 50/50 [00:08<00:00,  5.87it/s]
100%|██████████| 7/7 [00:00<00:00, 17.08it/s]


[Epoch16], val_auroc: 0.777, val_auprc: 0.634 


  2%|▏         | 1/50 [00:00<00:08,  5.93it/s]

[Batch 1],0.004880940541625023


 92%|█████████▏| 46/50 [00:14<00:01,  3.25it/s]


KeyboardInterrupt: 