In [6]:
import os
import time
import torch
import esm
import pandas as pd
from transformers import AutoModel, AutoTokenizer
from model import MMCLKins
from torch_geometric.data import Data
from process_3dkdavis import generate_inner_coor, generate_protein_sequence, generate_graph_feature, gen_seq_list, gene_smi_fes
from collates import collate
%matplotlib inline

Downloading the datasets

In [8]:
# Inputs read by the cells below must already be on disk:
#   ./cases/lrrk2/              ligand SDFs plus the 8fo7 kinase and pocket PDB files
#   ./pkls/virtual_screening/   trained MMCLKins checkpoint (.pkl)

Setting the dataset paths

In [3]:
# All inputs of the LRRK2 case study live under one directory.
case_dir = './cases/lrrk2'
ligand_path = case_dir + '/sdfs'                 # ligand structure files
kinase_path = case_dir + '/8fo7_kinase.pdb'      # kinase structure (PDB 8fo7)
pocket_path = case_dir + '/8fo7_pocket_12.pdb'   # binding-pocket structure

Extracting the features

In [4]:
def generate_kin_fes(ki, batch_converter, model1):
    """Extract structure-graph and ESM sequence features for one protein PDB file.

    Args:
        ki: Path to the PDB file (used both for the kinase and for the pocket).
        batch_converter: ESM batch converter from ``alphabet.get_batch_converter()``.
        model1: ESM model. Representations are read from layer 33, so this is
            expected to be the 33-layer ESM-1b model. Put it in ``eval()`` mode
            beforehand so dropout does not perturb the embeddings.

    Returns:
        ``(kinase_feas, pdbs_1022, error_pdbs)``:

        * ``kinase_feas`` -- single-entry dict mapping the path with its file
          extension and any leading dots removed (e.g.
          ``'./cases/lrrk2/8fo7_kinase.pdb'`` -> ``'/cases/lrrk2/8fo7_kinase'``)
          to a ``Data`` object holding all features.
        * ``pdbs_1022`` -- ``[ki]`` when the ESM token count differs from the
          number of rows of ``protein.node_s`` + 2 (presumably the BOS/EOS
          tokens ESM adds), else ``[]``.
        * ``error_pdbs`` -- ``[ki]`` when there is such a mismatch and the token
          count is at most 1022, else ``[]``.
    """
    pdbs_1022 = []
    error_pdbs = []
    # Gives the same key as `ki.split('.')[1]` for './...'-relative paths, without
    # collapsing to 'pdb' for absolute paths or breaking on other dots in the path.
    kina = os.path.splitext(ki)[0].lstrip('.')

    pro_index = gen_seq_list(ki)
    protein, protein_graph, coords = generate_graph_feature(ki)
    protein_sequence, chain_num = generate_protein_sequence(ki)
    pr_dist, pr_theta, pr_phi, pr_tau = generate_inner_coor(protein.x, protein.node_s, protein.edge_index)

    # ESM-1b handles at most 1022 residues per sequence, hence the truncation.
    _, _, batch_tokens = batch_converter([(0, protein_sequence[:1022])])
    with torch.no_grad():
        # Only the layer-33 representations are used, so contact prediction
        # (return_contacts=True) is not requested.
        results = model1(batch_tokens, repr_layers=[33])
    pro_token_repre = results["representations"][33]

    n_tokens = pro_token_repre.shape[1]
    n_nodes = protein.node_s.shape[0]
    if n_tokens != n_nodes + 2:
        pdbs_1022.append(ki)
        if n_tokens > 1022:
            print(f'{ki} pdb exceeds 1022!!!')
        else:
            # The sequence is not over-long, so the mismatch has another cause;
            # report it as such.
            error_pdbs.append(ki)
            print(f'{ki} pdb: {n_tokens - 2} ESM sequence tokens do not match '
                  f'{n_nodes} structure nodes!!!')

    kinase_fea = Data(pro_graphs=protein_graph,
                      pro_index=pro_index,
                      pro_atoms_feats_s=protein.node_s,
                      pro_atoms_feats_v=protein.node_v,
                      pro_coords_feats=protein.x,
                      pro_edges_feats_s=protein.edge_s,
                      pro_edges_feats_v=protein.edge_v,
                      pro_edge_index=protein.edge_index,
                      pro_token_repre=pro_token_repre,
                      pro_fp=protein_sequence,
                      pr_dist=pr_dist,
                      pr_theta=pr_theta,
                      pr_phi=pr_phi,
                      pr_tau=pr_tau)
    kinase_feas = {kina: kinase_fea}

    return kinase_feas, pdbs_1022, error_pdbs

In [5]:
# ESM-1b (33 layers, 650M parameters): sequence embeddings for the kinase and the pocket.
model1, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# The ESM README calls model.eval() after loading; without it dropout stays
# active and the extracted embeddings are not deterministic.
model1.eval()
batch_converter = alphabet.get_batch_converter()
kinase_fes, kina_1022, error_kina = generate_kin_fes(kinase_path, batch_converter, model1)
pocket_fes, pock_1022, error_pock = generate_kin_fes(pocket_path, batch_converter, model1)

# ChemBERTa SMILES model for the ligand features. Kept under its own names so it
# is not confused with the MMCLKins `model` built in the next cell.
chem_model_name = "DeepChem/ChemBERTa-10M-MLM"
chem_model = AutoModel.from_pretrained(chem_model_name)
chem_tokenizer = AutoTokenizer.from_pretrained(chem_model_name)
lig_fes, error_mol = gene_smi_fes(ligand_path, chem_model, chem_tokenizer)

Loading the trained MMCLKins model from its .pkl checkpoint

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
path = './pkls/virtual_screening/MCLLNGRCSE_DTI_pearson_best.pkl'

# Architecture hyper-parameters; presumably they must match the training run of the
# checkpoint, since strict=True below rejects missing/unexpected/mis-shaped weights.
lstm_dropout = 0.2
alpha = 0.2
num_heads = 2
hidden_dim = 256
dropout_rate = 0.3
model = MMCLKins(lstm_dropout, alpha, num_heads, hidden_dim, dropout_rate, n_head=8, smile_vocab=63, local_rank=device)
# map_location is required for the CPU fallback above: without it a checkpoint
# saved from a GPU cannot be deserialized on a machine without CUDA.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
model.to(device)

MMCLKins(
  (dropout): Dropout(p=0.3, inplace=False)
  (leakyrelu): LeakyReLU(negative_slope=0.2)
  (relu): ReLU()
  (prelu): PReLU(num_parameters=1)
  (elu): ELU(alpha=1.0)
  (conv): GCNConv(259, 256)
  (conv3): GCNConv(256, 256)
  (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (bilstm): LSTM(256, 256, dropout=0.2, bidirectional=True)
  (self_attention): A_MultiHeadAttention(
    (qkv_proj): Linear(in_features=512, out_features=1536, bias=True)
    (o_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (out_fc): Linear(in_features=512, out_features=256, bias=True)
  (out_m_fc): Linear(in_features=256, out_features=128, bias=True)
  (output_fc): Linear(in_features=128, out_features=1, bias=True)
  (smi_egnn): EGNNConv(
    (edge_mlp): Sequential(
      (0): Linear(in_features=57, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, b

Predicting the binding affinity between the kinase and the inhibitors

In [7]:
model.eval()

# generate_kin_fes returns a single-entry dict per PDB file; fetch those entries
# once instead of looking them up (and dumping the whole dict) for every ligand.
assert len(kinase_fes) == 1 and len(pocket_fes) == 1
kin_fes = next(iter(kinase_fes.values()))
poc_fes = next(iter(pocket_fes.values()))

pred_results = []
with torch.no_grad():
    for na, lig_feas in lig_fes.items():
        x_feats = collate(lig_feas, kin_fes, poc_fes).to(device)
        (y_pred, g_spo_att, g_spr_att, s_spo_att, s_spr_att,
         pg_spo_att, pg_spr_att, ps_spo_att, ps_spr_att,
         atominmol_indexes, subwinsmi_indexes, pocinpro_indexes) = model(x_feats)
        # Row order must match result_columns below.
        pred_results.append([
            na, y_pred.item(),
            g_spo_att.tolist(), g_spr_att.tolist(),
            s_spo_att.tolist(), s_spr_att.tolist(),
            atominmol_indexes[0].tolist(),
            pg_spo_att.tolist(), pg_spr_att.tolist(),
            subwinsmi_indexes[0].tolist(),
            ps_spo_att.tolist(), ps_spr_att.tolist(),
            pocinpro_indexes[0].tolist(),
        ])

result_columns = ['compound_id', 'pred_results', 'graph=smi+pocket_attention', 'graph=smi+pocket_in_protein_attention',
                  'sequence=smi+pocket_attention', 'sequence=smi+pocket_in_protein_attention', 'graph=smi_atominmol_indexes',
                  'graph=atom_in_mol+pocket_attention(atominmol_indexes)', 'graph=atom_in_mol+protein_attention(atominmol_indexes)',
                  'sequence=smi_subwinsmi_indexes', 'sequence=subword_in_smile+pocket_sequence_attention(subwinsmi_indexes)',
                  'sequence=subword_in_smile+protein_sequence_attention(subwinsmi_indexes)', 'pocinpro_indexes']
da = pd.DataFrame(columns=result_columns, data=pred_results)
da

Processing the results

In [9]:
import operator
# Rank compounds by predicted affinity (highest first); `da` itself is left untouched.
df = (
    da.assign(compound_id=da['compound_id'].str.replace('_P', '', regex=False))
    .sort_values(by='pred_results', ascending=False)
)
label = 'True'  # prefix of compound ids counted as active drugs
all_le = 11     # hard-coded denominator of the enrichment ratios (number of actives) -- confirm per library

# Keep only the best-scoring row per base compound id (text before the first '_').
nas = []   # unique base ids, best-scoring first
nass = []  # matching full result rows, same order
seen = set()
for pos, compound_id in enumerate(df['compound_id']):
    base_id = compound_id.split('_')[0]
    if base_id not in seen:
        seen.add(base_id)
        nas.append(base_id)
        # iloc, not loc: sort_values keeps the original index labels, so
        # df.loc[pos] would return an unrelated row, not the pos-th best one.
        nass.append(df.iloc[pos].tolist())
# Enrichment of actives in the ranking: m<p> counts ids starting with `label`
# among the top p% of `nas` (p = 1..10); m<p>r = m<p> / all_le.
all_length = len(nas)


def _count_actives_in_top(ranked_ids, prefix, percent):
    """Count ids starting with `prefix` among the first ceil(len * percent / 100) ids."""
    # Exact integer ceiling: float thresholds such as 100 * 0.07 == 7.000000000000001
    # made the former `i < all_length * 0.07` test include one extra compound.
    cutoff = -(-len(ranked_ids) * percent // 100)
    return sum(1 for cid in ranked_ids[:cutoff] if cid.startswith(prefix))


m1, m2, m3, m4, m5, m6, m7, m8, m9, m10 = (
    _count_actives_in_top(nas, label, pct) for pct in range(1, 11)
)
m1r, m2r, m3r, m4r, m5r, m6r, m7r, m8r, m9r, m10r = (
    hits / all_le for hits in (m1, m2, m3, m4, m5, m6, m7, m8, m9, m10)
)

# Output table: one best-scoring row per compound (same columns as the ranked
# frame `df`), followed by one text row per top-p% enrichment summary, which
# pd.concat places in an extra unnamed column.
new_csv = pd.DataFrame(columns=list(df.columns), data=nass)
enrichment_lines = [
    f'the number of the active drugs in top {pct}% ::: {hits} / {all_le} == {ratio}'
    for pct, hits, ratio in zip(
        range(1, 11),
        (m1, m2, m3, m4, m5, m6, m7, m8, m9, m10),
        (m1r, m2r, m3r, m4r, m5r, m6r, m7r, m8r, m9r, m10r),
    )
]
new_csv = pd.concat([new_csv, *(pd.Series(line) for line in enrichment_lines)], ignore_index=True)

save_path = './output/vs_lrrk2/'
os.makedirs(save_path, exist_ok=True)
new_csv.to_csv(os.path.join(save_path, 'vs_lrrk2.csv'), encoding='gbk', index=False)
new_csv


In [10]:
# Print the top-1% ... top-10% enrichment summary, one line per cut-off.
for pct, hits, ratio in zip(range(1, 11),
                            (m1, m2, m3, m4, m5, m6, m7, m8, m9, m10),
                            (m1r, m2r, m3r, m4r, m5r, m6r, m7r, m8r, m9r, m10r)):
    print(f'the number of the active drugs in top {pct}% ::: {hits} / {all_le} == {ratio}')