In [27]:
import pandas as pd
import torch
import os
from tqdm import tqdm
from torch.utils.data import random_split
from torch.utils.data import Dataset,DataLoader
from utils.Func import extract_esm_feature,seq2fasta,extract_cp_feature,filter_invalid_smiles
from dgl.data.utils import load_graphs
from utils.TAVC_dataset import TAVC_Dataset_infer
from utils.collator import Collator_TAVC_Infer
from utils.TAVC_trainer import TAVC_Trainer
from utils.scheduler import PolynomialDecayLR
from torch.optim import Adam
from torch.nn import  BCEWithLogitsLoss
from utils.model.KPGT_v2 import *
from utils.model.DeepAVC import *
from utils.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES,VIRTUAL_ATOM_FEATURE_PLACEHOLDER, VIRTUAL_BOND_FEATURE_PLACEHOLDER

In [23]:
# Inference settings shared by the cells below: model sizes, feature/checkpoint
# locations and the device everything is moved to.
args = dict(
    # Model / optimisation settings.
    config='base',
    d_fps=512,
    d_mds=200,
    dropout=0,
    batch_size=32,
    weight_decay=1e-6,
    n_tasks=1,
    lr=1e-4,
    # Directories holding the extracted compound and protein features.
    cp_feature_dir='/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC_inference/cp_feature',
    pro_feature_dir='/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC_inference/pro_feature',
    # Checkpoints for the KPGT model and the DeepTAVC model.
    kpgt_model_path='/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/KPGT/KPGT.pth',
    DeepTAVC_model_path='/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/DeepAVC/DeepTAVC.pt',
    # Device used for feature loading and model inference.
    device='cuda:3',
)
# Vocabulary object built from the featurizer's atom/bond type counts; its
# vocab_size is later passed to the KPGT model as n_node_types.
vocab = Vocab(N_ATOM_TYPES, N_BOND_TYPES)

In [7]:
# Demo table of compound/target pairs to score; later cells read its
# 'SMILES' and 'sequence' columns.
TAVC_demo_dataset = pd.read_csv(
    filepath_or_buffer='/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC_inference/DeepTAVC_inference_demo.csv'
)

In [12]:
# Extract the initial compound features with RDKit (only needed when they
# have not been written to args['cp_feature_dir'] yet).
smiles_list = TAVC_demo_dataset['SMILES'].tolist()
# Split the SMILES into valid and invalid ones; only the valid ones are featurized.
valid_smiles, invalid_smiles = filter_invalid_smiles(smiles_list)
extract_cp_feature(
    smiles_list=valid_smiles,
    output_dir=args['cp_feature_dir'],
    num_workers=32,
)

100%|██████████| 20/20 [00:00<00:00, 4510.25it/s]

extracting graphs



[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorchUsing backend: pytorchUsing backend: pytorch



Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch


extracting fingerprints
extracting molecular descriptors


[Parallel(n_jobs=32)]: Done  20 out of  20 | elapsed:    2.5s remaining:    0.0s
[Parallel(n_jobs=32)]: Done  20 out of  20 | elapsed:    2.5s finished


'Done!'

In [13]:
# Extract the initial protein features with ESM-2 (only needed when they
# have not been written to args['pro_feature_dir'] yet).
pro_seq_list = list(TAVC_demo_dataset['sequence'].unique())
# Write the distinct protein sequences in FASTA format under the protein feature directory.
seq2fasta(seq_list=pro_seq_list,
          save_dir=args['pro_feature_dir'])

extract_esm_feature(
    model_location = '/home2/kangboming/kangboming/workspace2/AVC_paper/github/pretrained_model/ESM/esm2_t33_650M_UR50D.pt',
    fasta_file = os.path.join(args['pro_feature_dir'], 'target_seq.fasta'),
    output_dir = args['pro_feature_dir'],
    toks_per_batch = 10000,
    repr_layers = [-1],
    include=['per_tok'],
    # Take the device from args so this step cannot drift from the device
    # used for loading features and running the models.
    device=args['device'],
    truncation_seq_length = 1024)

Transferred model to GPUs
Read /home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepTAVC_inference/pro_feature/target_seq.fasta with 12 sequences
Processing 1 of 2 batches (9 sequences)
Processing 2 of 2 batches (3 sequences)


'Done!'

In [8]:
# Label each distinct protein sequence 'Target_1', 'Target_2', ... in the
# order the sequences are returned by unique(), then tag every row with its label.
seq2id_dict = {
    seq: f'Target_{rank}'
    for rank, seq in enumerate(TAVC_demo_dataset['sequence'].unique(), start=1)
}
target_id_list = list(seq2id_dict.values())
TAVC_demo_dataset['target_idx'] = TAVC_demo_dataset['sequence'].map(seq2id_dict)


In [9]:
### Load the compound features stored in args['cp_feature_dir']
graphs, label_dict = load_graphs(os.path.join(args['cp_feature_dir'], 'cp_graphs.pkl'))
# Fingerprints first, then molecular descriptors (same order as before).
fps, mds = (
    torch.load(os.path.join(args['cp_feature_dir'], file_name))
    for file_name in ('cp_fps.pt', 'cp_mds.pt')
)

In [10]:
# The three compound feature collections must all have the same number of entries.
assert len({len(graphs), len(fps), len(mds)}) == 1

In [11]:
### load protein initial feature
# Read the stored protein features from args['pro_feature_dir'], mapping
# their storage onto the inference device.
pro_feature_dict = torch.load(
    os.path.join(args['pro_feature_dir'], 'esm_feature.pt'),
    map_location=args['device'],
)

In [12]:
### Build the pieces needed for the DataLoader
# Select the configuration named in args instead of repeating the literal
# 'base', so changing args['config'] actually takes effect.
config = config_dict[args['config']]
collator = Collator_TAVC_Infer(config['path_length'])

In [15]:
# Inference dataset combining the per-row SMILES / target ids with the loaded
# compound features and the protein feature dictionary.
# NOTE(review): graphs/fps/mds were extracted from valid_smiles only, while the
# full 'SMILES' column is passed here - presumably these must stay index-aligned;
# confirm that no SMILES were filtered out (or filter the table accordingly).
mydataset = TAVC_Dataset_infer(
    graphs=graphs,
    fps=fps,
    mds=mds,
    target_feature_dict=pro_feature_dict,
    smiles_list=TAVC_demo_dataset['SMILES'].tolist(),
    target_id_list=TAVC_demo_dataset['target_idx'].tolist(),
)

In [18]:
# Batches are produced in dataset order (no shuffling, no dropped tail) so the
# predictions line up with the input rows. The batch size is taken from args
# rather than being repeated here as a literal.
myloader = DataLoader(mydataset,
                        batch_size=args['batch_size'],
                        shuffle=False,
                        drop_last=False,
                        collate_fn=collator)

In [19]:
# Sanity check: pull the first collated batch from the loader and display it.
next(iter(myloader))

(['COc1ccc(C2=C(c3c[nH]c4ccccc34)C(=O)NC2)cc1',
  'Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCCN(C(=O)CCCCC(=O)NCC#CCN2CCN(CC#Cc3ccc4c(c3)CN(C3CCC(=O)NC3=O)C4=O)CC2)C1',
  'C=CC(=O)N1CCC[C@@H](n2nc(CCCOc3ccccc3)c3c(N)ncnc32)C1',
  'CN1[C@@H]2CC[C@H]1C[C@@H](NC(c1ccc(Cl)cc1)c1ccc(Cl)cc1)C2',
  'N#CN[C@H]1C[C@@H](NS(=O)(=O)c2cccc(F)c2)c2ccc(-c3ncnc4[nH]ccc34)cc21',
  'Oc1ccc2c(c1)CCCC(C1CCC(O)CC1)=C2c1ccc(O[C@H]2CCN(CCCF)C2)cc1',
  'CC(C)CCn1cc2c(nc(NC(=O)Cc3ccc(F)cc3)n3nc(-c4ccco4)nc23)n1',
  'CCC(=O)N1CC[C@@H](Cc2ccc(F)cc2)C[C@@H]1CCCNC(=O)Nc1cc(C(C)=O)cc(C(C)=O)c1',
  'COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)c2ccc3c(c2)[C@@H](NC(=O)OC(C)(C)C)CCS3)cc1',
  'N#C[C@H]1C[C@@H](O)CC[C@@H]1n1cc(C(N)=O)c(Nc2ccc([C@@H](O)C(F)(F)F)cc2)n1',
  'Cc1nc2cnc3ccc(-c4cc(Cl)c(O)c(Cl)c4)cc3c2n1C1CCN(C)CC1',
  'CN(C)C1CCN(c2ncc(-c3cnc4[nH]ccc4n3)c(N[C@H]3CCCN(S(C)(=O)=O)C3)n2)CC1',
  'O=C(CCCCCCN1C/C=C/CCOc2cccc(c2)-c2ccnc(n2)Nc2cccc(c2)C1)NO',
  'CC(N)C12CC3CC(CC(C3)C1)C2',
  'C=CC(=O

In [22]:
# Build the KPGT predictor: architecture sizes come from the selected config,
# fingerprint/descriptor widths and dropout rates from args.
kpgt_model = LiGhTPredictor(
    **{
        key: config[key]
        for key in (
            'd_node_feats', 'd_edge_feats', 'd_g_feats', 'd_hpath_ratio',
            'n_mol_layers', 'path_length', 'n_heads', 'n_ffn_dense_layers',
        )
    },
    d_fp_feats=args['d_fps'],
    d_md_feats=args['d_mds'],
    input_drop=0,
    attn_drop=args['dropout'],
    feat_drop=args['dropout'],
    n_node_types=vocab.vocab_size,
)
kpgt_model = kpgt_model.to(args['device'])

In [24]:
### Load the pretrained KPGT weights
# Every 'module.' occurrence is removed from the checkpoint keys before loading.
kpgt_model.load_state_dict({
    name.replace('module.', ''): tensor
    for name, tensor in torch.load(args['kpgt_model_path'], map_location=args['device']).items()
})
# Drop the sub-modules that are not needed for this model.
del kpgt_model.md_predictor, kpgt_model.fp_predictor, kpgt_model.node_predictor

In [29]:
### Model initialisation
# Wrap the KPGT model in the DeepTAVC network, then move it to the inference device.
DeepTAVC = CADTI_Finetune(
    kpgt_model=kpgt_model,
    d_model=256,
    n_heads=8,
    num_layers=1,
    smiles_dim=768,
    protein_dim=1280,
    kpgt_features_dim=2304,
    mlp_hidden_dim=256,
    num_classes=1,
    dropout=0,
    return_attn=True,
)
DeepTAVC = DeepTAVC.to(args['device'])

In [30]:
# Load the pretrained DeepTAVC weights. The call stays the cell's last
# expression so its key-matching report is displayed.
DeepTAVC.load_state_dict(
    torch.load(args['DeepTAVC_model_path'], map_location=args['device'])
)

<All keys matched successfully>

In [39]:
def forward_epoch(model, batched_data, device):
    """Run the model on a single collated batch.

    Args:
        model: callable invoked as ``model(graphs, fps, mds, target_feature)``;
            the first element of its return value is used as the predictions.
        batched_data: 6-tuple ``(smiless, graphs, fps, mds, target_feature,
            target_seqs)`` as unpacked below.
        device: device the graph and feature inputs are moved to.

    Returns:
        Tuple ``(predictions, smiless, target_seqs)``.
    """
    smiless, graphs, fps, mds, target_feature, target_seqs = batched_data

    # Move every model input to the target device (each one exactly once).
    graphs = graphs.to(device)
    fps = fps.to(device)
    mds = mds.to(device)
    target_feature = target_feature.to(device)

    predictions = model(graphs, fps, mds, target_feature)[0]
    return predictions, smiless, target_seqs


def eval(model, dataloader, device):
    """Score every batch of ``dataloader`` and collect the results in a table.

    Note: this function shadows the builtin ``eval``; the name is kept because
    existing callers use it.

    Args:
        model: model with an ``eval()`` method, callable as described in
            ``forward_epoch``.
        dataloader: sized iterable of collated batches.
        device: device the inputs are moved to.

    Returns:
        ``pandas.DataFrame`` with columns ``SMILES``, ``target_idx`` and
        ``pred_score`` (sigmoid of the model output), one row per sample in
        iteration order.
    """
    model.eval()
    pred_scores_all = []
    smiles_list = []
    target_seqs_list = []

    # Inference only: disable autograd so no computation graph is recorded
    # for the forward passes.
    with torch.no_grad():
        for batched_data in tqdm(dataloader, total=len(dataloader)):
            predictions, smiless, target_seqs = forward_epoch(model, batched_data, device)
            # Drop dimension 1 of the model output before applying the sigmoid.
            pred_scores = torch.sigmoid(predictions.squeeze(1))
            # Extend with numpy scalars (not Python floats) to keep the column dtype.
            pred_scores_all.extend(pred_scores.detach().cpu().numpy())
            smiles_list.extend(smiless)
            target_seqs_list.extend(target_seqs)

    res_df = pd.DataFrame({
        'SMILES': smiles_list,
        'target_idx': target_seqs_list,
        'pred_score': pred_scores_all
    })

    return res_df

In [40]:
# Score every compound/target pair served by the loader with the DeepTAVC model.
res_df = eval(DeepTAVC, myloader, args['device'])

100%|██████████| 1/1 [00:00<00:00,  6.35it/s]


In [45]:
# Attach the predicted scores to the input table. The predictions are
# de-duplicated on the join keys first: otherwise a (SMILES, target_idx) pair
# that occurs k times in the input matches k prediction rows and the inner
# join emits k*k rows for it.
final_res_df = pd.merge(
    TAVC_demo_dataset,
    res_df.drop_duplicates(subset=['SMILES', 'target_idx']),
    on=['SMILES', 'target_idx'],
)

In [46]:
# Display the input table together with the merged pred_score column.
final_res_df

Unnamed: 0,SMILES,sequence,pchembl_value,label,target_idx,pred_score
0,COc1ccc(C2=C(c3c[nH]c4ccccc34)C(=O)NC2)cc1,MPALARDGGQLPLLVVFSAMIFGTITNQDLPVIKCVLINHKNNDSS...,5.0,0,Target_1,8.380084e-08
1,Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCC...,MAAVILESIFLKRSQQKKKTSPLNFKKRLFLLTVHKLSYYEYDFER...,8.05,1,Target_2,1.0
2,C=CC(=O)N1CCC[C@@H](n2nc(CCCOc3ccccc3)c3c(N)nc...,MAAVILESIFLKRSQQKKKTSPLNFKKRLFLLTVHKLSYYEYDFER...,7.65,1,Target_2,0.9999826
3,CN1[C@@H]2CC[C@H]1C[C@@H](NC(c1ccc(Cl)cc1)c1cc...,MLLARMNPQVQPENNGADTGPEQPLRARKTAELLVVKERNGVQCLL...,5.12,0,Target_3,1.483947e-10
4,N#CN[C@H]1C[C@@H](NS(=O)(=O)c2cccc(F)c2)c2ccc(...,MQYLNIKEDCNAMAFCAKMRSSKKTEVNLEAPEPGVEVIFYLSDRE...,5.57,0,Target_4,2.556997e-11
5,Oc1ccc2c(c1)CCCC(C1CCC(O)CC1)=C2c1ccc(O[C@H]2C...,MTMTLHTKASGMALLHQIQGNELEPLNRPQLKIPLERPLGEVYLDS...,8.31,1,Target_5,0.9997373
6,CC(C)CCn1cc2c(nc(NC(=O)Cc3ccc(F)cc3)n3nc(-c4cc...,MLLETQDALYVALELVIAALSVAGNVLVCAAVGTANTLQTPTNYFL...,5.3,0,Target_6,1.417815e-08
7,CCC(=O)N1CC[C@@H](Cc2ccc(F)cc2)C[C@@H]1CCCNC(=...,METTPLNSQKQLSACEDGEDCQENGVLQKVVPTPGDKVESGQISNG...,6.65,0,Target_7,0.0001204301
8,COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2c...,PQITLWQRPFVTIKIEGQLKEALLDTGADDTVLEEMNLPGRWKPKM...,9.42,1,Target_8,0.9999999
9,N#C[C@H]1C[C@@H](O)CC[C@@H]1n1cc(C(N)=O)c(Nc2c...,MGMACLTMTEMEGTSTSSIYQNGDISGNANSMKQIDPVLQVYLYHS...,8.52,1,Target_9,1.0
