# 0 环境配置

In [1]:
import os
import sys
import json
import pickle
import random
import timeit
import argparse
import statistics
import subprocess
from collections import Counter
from functools import reduce
from importlib import reload
from statistics import mean, median

import cairosvg
from joblib import Parallel, delayed

In [2]:
# Make the notebook's working directory importable so the local `utils` package resolves.
base_path = os.path.abspath('')
# Guard against piling up duplicate sys.path entries when this cell is re-run.
if base_path not in sys.path:
    sys.path.append(base_path)

In [4]:
import oddt
import numpy as np
import pandas as pd
import seaborn as sns
import nglview as nv
import matplotlib.pyplot as plt
from openbabel import openbabel
from icecream import ic
from scipy import stats
from sklearn import metrics
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from rdkit import Chem
from rdkit import RDLogger   
from rdkit.Chem import AllChem, Draw, rdMolAlign
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.Draw import DrawingOptions
from espsim.helpers import mlCharges
from pandarallel import pandarallel
from tqdm import tqdm



In [5]:
import utils.functions as fn
import utils.alignment as al
import utils.similarity as sm
import utils.metrics as mt

In [6]:
%matplotlib inline
plt.style.use('ggplot')
# plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体
plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的“-”负号的乱码问题

sns.set_theme()

RDLogger.EnableLog('rdApp.*') 

random.seed = 2022

In [None]:
# Re-import the local utils modules so edits to their source take effect without restarting the kernel.
reload(fn)
reload(al)
reload(sm)
reload(mt)


# 1 测试

## 1.1 Mol2MolSupplier

In [None]:
# 1 Mol2MolSupplier
# Smoke test of the project's mol2 reader on the PDE5A demo actives; the cell displays
# len(database) (presumably the number of molecules read).
database=fn.Mol2MolSupplier(os.path.join(base_path, 'demo-data/pde5a/actives_final.mol2'), sanitize=True)
len(database)

## 1.2 DUD-E Correlation

In [None]:
# 2 auc_dude
# Per-method DUD-E AUC values pasted as one flat string: each method name is followed by
# four "mean ± std" columns (TODO: record which benchmark setting each column is).
auc_dude = 'ROCSComboscore 0.598 ± 0.152 0.681 ± 0.166 0.674 ± 0.115 0.727 ± 0.128 ROCSColorscore 0.620 ± 0.139 0.712 ± 0.159 0.677 ± 0.117 0.752 ± 0.136 ROCSShapeTanimoto 0.547 ± 0.138 0.611 ± 0.163 0.618 ± 0.105 0.667 ± 0.127 Phase Shape_Mmod 0.677 ± 0.143 0.686 ± 0.145 0.772 ± 0.105 0.769 ± 0.108 Phase Shape_Ele 0.674 ± 0.153 0.688 ± 0.158 0.753 ± 0.105 0.750 ± 0.111 Phase Shape_Pharm 0.692 ± 0.160 0.694 ± 0.168 0.761 ± 0.145 0.767 ± 0.143 Shape-it 0.541 ± 0.133 0.590 ± 0.141 0.612 ± 0.110 0.639 ± 0.115 Align-it 0.659 ± 0.137 0.680 ± 0.157 0.729 ± 0.132 0.746 ± 0.139 ShaEPbest 0.658 ± 0.122 0.660 ± 0.139 0.709 ± 0.099 0.699 ± 0.120 ShaEPshape 0.625 ± 0.139 0.632 ± 0.143 0.681 ± 0.105 0.676 ± 0.116 ShaEPESP 0.606 ± 0.109 0.591 ± 0.117 0.627 ± 0.105 0.585 ± 0.125 SHAFTS 0.733 ± 0.144 0.731 ± 0.157 0.792 ± 0.129 0.782 ± 0.135 WEGA 0.645 ± 0.143 0.659 ± 0.154 0.716 ± 0.107 0.716 ± 0.125 LIGSIFT 0.718 ± 0.133 0.755 ± 0.143 0.758 ± 0.117 0.784 ± 0.120 LS-align 0.699 ± 0.126 0.759 ± 0.119 0.773 ± 0.098 0.786 ± 0.096'

# Some names are two words ("Phase Shape_Mmod", ...). Dropping the "Phase" token makes every
# record exactly RECORD_LEN tokens: <name> followed by four "<mean> ± <std>" triples.
RECORD_LEN = 13
auc_dude_list = [token for token in auc_dude.split(' ') if token != 'Phase']

# One row per method: [name, mean_1, mean_2, mean_3, mean_4]; the std tokens are discarded.
auc_dude_dict = []
for offset in range(0, len(auc_dude_list) - RECORD_LEN + 1, RECORD_LEN):
    record = auc_dude_list[offset:offset + RECORD_LEN]
    auc_dude_dict.append([record[0], *map(float, record[1::3])])

df = pd.DataFrame(auc_dude_dict).sort_values(by=[1])
df

In [None]:
# Pearson correlation (coefficient pe, p-value pv) between AUC column 1 and each of columns 2-4.
for i in [2, 3, 4]:
    pe, pv = stats.pearsonr(df[1].to_numpy(),df[i].to_numpy())
    ic(f'pe: {pe:>10}, pv: {pv:>10}')

In [None]:
# 4x4 Pearson correlation matrix across the four AUC columns (rows of the input = variables).
pccs = np.corrcoef([df[i].to_numpy() for i in [1,2,3,4]])
ic(pccs)

In [None]:
# Visualise the AUC-column correlation matrix.
sns.heatmap(pccs)

In [None]:
# Overlay the four AUC columns per method (x = row position after sorting by column 1).
plt.figure(figsize=(5, 5), dpi=200)
for i in [1,2,3,4]:
    plt.scatter(list(range(df.shape[0])), df[i].tolist())

## 1.3 去重

In [None]:
# table = Chem.GetPeriodicTable()

# # for ele in [14, 16, 19]:
# #     ligand.GetAtomWithIdx(ele).SetNumExplicitHs(1)
# ligand = Chem.RemoveAllHs(ligand, sanitize=True)
# ligand.RemoveAllConformers()
# for atom in ligand.GetAtoms():
#     if atom.GetTotalValence() != table.GetDefaultValence(atom.GetAtomicNum()):
#         atom.SetFormalCharge(atom.GetTotalValence()-table.GetDefaultValence(atom.GetAtomMapNum()))
# Chem.SanitizeMol(ligand)
# ligand

In [None]:
# PDE5A crystal ligand read from mol2, plus a SMILES string (presumably the same compound — confirm)
# to compare against it.
m1 = Chem.MolFromMol2File(os.path.join(base_path, f'demo-data/pde5a/crystal_ligand.mol2'))
m2 = Chem.MolFromSmiles('CCCc1nn(c2c1nc([nH]c2=O)c1cc(ccc1OCC)S(=O)(=O)n1ccn(cc1)C)C')

In [None]:
# Drop m1's conformers (in place) so the depiction below is not drawn from the 3D coordinates.
m1.RemoveAllConformers()
m1

In [None]:
# Display the SMILES-derived molecule for visual comparison with m1.
m2

In [None]:
# Compare InChIKeys and canonical SMILES to check whether m1 and m2 encode the same compound.
ic(Chem.MolToInchiKey(m1))
ic(Chem.MolToInchiKey(m2))
ic(Chem.MolToSmiles(m1))
ic(Chem.MolToSmiles(m2))

## 1.4 绘图

In [None]:
# parser = argparse.ArgumentParser( 'smiles to png inmage' )
# parser.add_argument( 'smiles' )
# parser.add_argument( '--filename', default="mol." )
 
# parser.add_argument( 'smiles' )
 
# param = parser.parse_args()
# smiles = param.smiles
# fname = param.filename


In [None]:
def draw_mol(mol, save_dir, fname):
    """Render `mol` into `save_dir` as three images.

    Writes ``<fname>.png`` and ``<fname>.svg`` with RDKit, then rasterises the SVG
    with cairosvg into ``svg_<fname>.png``. The legacy global ``DrawingOptions`` are
    set first (font size, resolution, bond width). The directory is not created here.
    """
    DrawingOptions.atomLabelFontSize = 55
    DrawingOptions.dotsPerAngstrom = 100
    DrawingOptions.bondLineWidth = 3.0

    png_file = os.path.join(save_dir, fname+".png")
    svg_file = os.path.join(save_dir, fname+'.svg')
    svg_png_file = os.path.join(save_dir, "svg_"+fname+".png")

    Draw.MolToFile(mol, png_file)
    Draw.MolToFile(mol, svg_file)
    cairosvg.svg2png(url=svg_file, write_to=svg_png_file)

In [None]:
# NOTE(review): absolute, machine-specific output path; consider deriving it from base_path.
# draw_mol does not create the directory, so it must already exist.
draw_mol(m1, '/home/jovyan/work-home/DUD-E/src/img', 'pde5a_cry_lig')

# 2 DUD-E

In [7]:
# DUD-E root (relative to base_path) and the worker count passed as n_jobs to the parallel helpers.
dude_dir = 'dude/all'
n_jobs = 8

In [None]:
# DUD-E targets handled by this notebook; target_flag picks the one every following cell works on.
target_list = ['pde5a', 'akt1', 'ada', 'andr', 'def', 'gria2', 'egfr', 'gcr', 'igf1r']
target_flag = 6
target_list[target_flag]

In [None]:
# Decompress the target's gzipped actives/decoys SDF files in place (skipped if already done).
gz_a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_final.sdf.gz')
gz_d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_final.sdf.gz')
for gz_path in [gz_a_path, gz_d_path]:
    if os.path.exists(gz_path):
        # Argument list (no shell) so the path is passed verbatim even with spaces or
        # metacharacters; check=True raises on failure instead of silently continuing.
        subprocess.run(['gzip', '-d', gz_path], check=True)

In [None]:
# decoys_path 
d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_final.sdf')
# actives_path 
a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_final.sdf')
# target_rec = os.path.join(base_path, f'demo-data/{target_list[target_flag]}/receptor_FH.pdb')
# ligand_path 
l_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'crystal_ligand.mol2')

## 2.0 数据去重

In [None]:
# Inspect the ADA actives titled CHEMBL405920 (exported earlier to their own mol2 file by the
# commented-out code below), presumably to look at duplicated entries.
# act_405 = [mol for mol in act if mol.title=='CHEMBL405920']
act_405_path = os.path.join(base_path, dude_dir, 'ada', 'actives_CHE405.mol2')
# w = oddt.toolkits.rdk.Outputfile('mol2', act_405_path, overwrite=True)
# for m in act_405:
#     m.Mol = ''
#     w.write(m)
# w.close()
act_405 = fn.Mol2MolSupplier(act_405_path, sanitize=True)
Draw.MolsToGridImage(act_405, molsPerRow=8, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in act_405],maxMols=100)

In [None]:
# Load the raw actives/decoys SDF with explicit hydrogens kept; records RDKit cannot parse (None) are dropped.
# act = fn.Mol2MolSupplier(a_path.split('.')[0]+'.mol2', sanitize=True)
# dec = fn.Mol2MolSupplier(d_path.split('.')[0]+'.mol2', sanitize=True)
act = [m for m in Chem.SDMolSupplier(a_path, removeHs=False) if m]
dec = [m for m in Chem.SDMolSupplier(d_path, removeHs=False) if m]

ic(len(act))
ic(len(dec))

In [None]:
def de_duplicate(m_list, key_func=None):
    """Drop molecules whose name or structure key has already been seen.

    The first occurrence wins and input order is preserved. A molecule is dropped
    when its ``_Name`` property OR its structure key matches an earlier *kept*
    molecule; a dropped molecule does not register its name/key.

    Args:
        m_list: iterable of molecules exposing ``GetProp('_Name')``.
        key_func: callable mapping a molecule to a hashable structure key.
            Defaults to ``Chem.MolToInchiKey``.

    Returns:
        List of the kept molecules.
    """
    if key_func is None:
        key_func = Chem.MolToInchiKey

    # Sets give O(1) membership checks; the previous list scans made this O(n^2)
    # on DUD-E decoy sets with tens of thousands of entries.
    seen_names = set()
    seen_keys = set()
    unique_mols = []

    for m in m_list:
        name = m.GetProp('_Name')
        key = key_func(m)
        if name in seen_names or key in seen_keys:
            continue
        seen_names.add(name)
        seen_keys.add(key)
        unique_mols.append(m)

    return unique_mols

In [None]:
# De-duplicate the actives and save them next to the originals as actives_dd.sdf.
act_dd = de_duplicate(act)
ic(len(act_dd))
with Chem.SDWriter(os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_dd.sdf')) as w:
    for m in act_dd:
        w.write(m)

In [None]:
# De-duplicate the decoys and save them next to the originals as decoys_dd.sdf.
dec_dd = de_duplicate(dec)
ic(len(dec_dd))
with Chem.SDWriter(os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_dd.sdf')) as w:
    for m in dec_dd:
        w.write(m)

In [None]:
# From here on, point a_path/d_path at the de-duplicated SDF files.
a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_dd.sdf')
d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_dd.sdf')

In [None]:
def sdf2mol2(sdf_path, mol_path):
    """Convert an SDF file to mol2 via ODDT.

    Records are read with ODDT's default toolkit (falsy entries are skipped) and
    written with the RDKit-backed mol2 writer.

    Args:
        sdf_path: input SDF path.
        mol_path: output mol2 path; overwritten if it already exists.
    """
    # Read everything first so a failing input never truncates an existing output file.
    mols = [mol for mol in oddt.toolkit.readfile('sdf', sdf_path) if mol]
    writer = oddt.toolkits.rdk.Outputfile('mol2', mol_path, overwrite=True)
    try:
        for m in mols:
            # NOTE(review): blanks the `Mol` attribute before writing; the reason is not
            # visible here (presumably a workaround for the mol2 writer) — confirm.
            m.Mol = ''
            writer.write(m)
    finally:
        # Close even if a record fails to write, so the file handle is released.
        writer.close()

In [None]:
# Write mol2 copies of the de-duplicated SDF files (same stem, .mol2 extension).
# os.path.splitext strips only the final extension, so dots elsewhere in the path
# (e.g. in a directory name) are left intact.
sdf2mol2(a_path, os.path.splitext(a_path)[0] + '.mol2')
sdf2mol2(d_path, os.path.splitext(d_path)[0] + '.mol2')

In [None]:
# Sanity check: re-read the converted actives mol2 (same stem as a_path) and count parseable molecules.
t_l = [mol for mol in oddt.toolkit.readfile('mol2', os.path.splitext(a_path)[0] + '.mol2') if mol]
len(t_l)

In [None]:
# m = AllChem.MolFromSmiles('c1ccc(cc1)COC(=O)c1cc2c(c(c1)CCN1C=NC3=C1NC=[NH]C[C@@H]3O)CCCC2')
# type(m)

## 2.1 数据加载

In [None]:
a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_dd.sdf')
d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_dd.sdf')
l_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'crystal_ligand.mol2')
ic(a_path)

# Load the de-duplicated decoys/actives (explicit H kept, unparseable records dropped) and
# preview a random handful of each. Legends are taken from the sampled molecules themselves
# so every label matches the structure drawn above it; display() is needed because only a
# cell's last expression is rendered automatically.
# decoys = fn.Mol2MolSupplier(d_path, sanitize=True)
decoys = [m for m in Chem.SDMolSupplier(d_path, removeHs=False) if m]
ic(len(decoys))
decoys_preview = random.sample(decoys, min(7, len(decoys)))
display(Draw.MolsToGridImage(decoys_preview, molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in decoys_preview], maxMols=100))

actives = [m for m in Chem.SDMolSupplier(a_path, removeHs=False) if m]
ic(len(actives))
actives_preview = random.sample(actives, min(7, len(actives)))
display(Draw.MolsToGridImage(actives_preview, molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in actives_preview], maxMols=100))

# Reference molecule: crystal ligand from mol2, with hydrogens added at 3D positions.
ligand = AllChem.MolFromMol2File(l_path, removeHs=False)
ligand = Chem.AddHs(ligand, addCoords=True)
ligand

## 2.1 RDKIT

### 2.1.1 测试分子对齐

In [None]:
# Smoke test of O3A alignment on 10 actives; uses the notebook's configured worker count
# rather than a hard-coded one.
prbs, actives_o3a_score = al.rdkit_o3a(actives[10:20], ligand, n_jobs=n_jobs)
statistics.mean(actives_o3a_score)

In [None]:
# 3D view of the crystal ligand together with the O3A-aligned probes returned above.
view = fn.show_ligands([ligand]+prbs)
view

In [None]:
# Crippen-O3A smoke test on 10 actives. The helper returns (aligned_mols, scores),
# so unpack before averaging the scores.
_, actives_crippeno3a_score = al.rdkit_crippeno3a(actives[:10], ligand)
statistics.mean(actives_crippeno3a_score)

In [None]:
# 3D view of the ligand with the first 10 actives.
# NOTE(review): these are the input molecules, not the ones returned by the alignment
# helper — confirm whether alignment modifies them in place.
view = fn.show_ligands([ligand]+actives[:10])
view

In [None]:
# rmsd = rdkit_alignmol(actives[:10], ligand)
# statistics.mean(rmsd)
# # RuntimeError: No sub-structure match found between the probe and query mol

In [None]:
# O3A smoke test on 10 decoys. rdkit_o3a returns (aligned_mols, scores); average the scores only.
_, decoys_o3a_score = al.rdkit_o3a(decoys[:10], ligand)
statistics.mean(decoys_o3a_score)

In [None]:
# 3D view of the ligand with the first 10 decoys.
# NOTE(review): shows the input molecules, not those returned by the alignment helper.
view = fn.show_ligands([ligand]+decoys[:10])
view

In [None]:
# Crippen-O3A smoke test on 10 decoys. The helper returns (aligned_mols, scores); average the scores only.
_, decoys_crippeno3a_score = al.rdkit_crippeno3a(decoys[:10], ligand)
statistics.mean(decoys_crippeno3a_score)

In [None]:
# 3D view of the ligand with the first 10 decoys.
# NOTE(review): shows the input molecules, not those returned by the alignment helper.
view = fn.show_ligands([ligand]+decoys[:10])
view

In [None]:
# Precompute Crippen atom contributions for the first 10 actives and for the reference ligand,
# to pass them explicitly to rdkit_crippeno3a in the next cell.
prb_crippen = [Chem.rdMolDescriptors._CalcCrippenContribs(mol) for mol in actives[:10]]
ref_crippen = Chem.rdMolDescriptors._CalcCrippenContribs(ligand)

In [None]:
# Crippen-O3A with precomputed contributions. The helper returns (aligned_mols, scores);
# average the scores only.
_, actives_crippeno3a_score = al.rdkit_crippeno3a(actives[:10], ligand, prb_crippen=prb_crippen, ref_crippen=ref_crippen)
statistics.mean(actives_crippeno3a_score)

In [None]:
# prb_mmff = [AllChem.MMFFGetMoleculeProperties(mol) for mol in actives[0:10]]
# ref_mmff = AllChem.MMFFGetMoleculeProperties(ligand)
# decoys_o3a_score = rdkit_o3a(decoys[:10], ligand, prb_mmff=prb_mmff, ref_mmff=ref_mmff)
# statistics.mean(decoys_o3a_score)
# # RuntimeError: Invariant Violation
# # 	Negative weight specified for a point
# # 	Violation occurred on line 57 in file Code/Numerics/Alignment/AlignPoints.cpp
# # 	Failed Expression: wData[i] > 0.0
# # 	RDKIT: 2022.03.4
# # 	BOOST: 1_74

### 2.1.2 分子三维结构相似度计算

In [None]:
# actives_sampled = random.sample(actives, 100)
# decoys_sampled = random.sample(decoys, 500)
# RDLogger.DisableLog('rdApp.*') 
# Keep RDKit log messages enabled for the full runs below.
RDLogger.EnableLog('rdApp.*')

In [None]:
# actives_sampled_path = os.path.join(base_path, 'demo-data/pde5a/actives_sample_100.sdf')
# decoys_sampled_path = os.path.join(base_path, 'demo-data/pde5a/decoys_sample_500.sdf')


In [None]:
# writer = Chem.SDWriter(actives_sampled_path)
# for cid in range(len(actives_sampled)):
#     writer.write(actives_sampled[cid])
    
# writer = Chem.SDWriter(decoys_sampled_path)
# for cid in range(len(decoys_sampled)):
#     writer.write(decoys_sampled[cid])

In [None]:
# NOTE(review): dcg_score is not used in the visible part of this notebook.
from sklearn.metrics import dcg_score


# actives = [m for m in Chem.SDMolSupplier(actives_sampled_path) if m]
# decoys = [m for m in Chem.SDMolSupplier(decoys_sampled_path) if m]
# ligand = AllChem.MolFromMol2File(l_path)

# Dataset sizes for the full RDKit benchmark.
len_a = len(actives)
len_d = len(decoys)
ic(len_a)
ic(len_d)

In [None]:
def _timed_alignment(align_func, mols):
    """Align `mols` onto `ligand` with `align_func`; return (aligned_mols, elapsed_seconds)."""
    t0 = timeit.default_timer()
    aligned, align_score = align_func(mols, ligand, n_jobs=n_jobs, verbose=1)
    return aligned, timeit.default_timer() - t0

# Pre-align every active and decoy onto the crystal ligand once per alignment method.
# The wall times are later added to each shape score's throughput.
actives_o, a_o_a_time = _timed_alignment(al.rdkit_o3a, actives)
ic(a_o_a_time)
decoys_o, a_o_d_time = _timed_alignment(al.rdkit_o3a, decoys)
ic(a_o_d_time)

actives_c, a_c_a_time = _timed_alignment(al.rdkit_crippeno3a, actives)
ic(a_c_a_time)
decoys_c, a_c_d_time = _timed_alignment(al.rdkit_crippeno3a, decoys)
ic(a_c_d_time)

In [None]:
# writer = Chem.SDWriter(actives_sampled_path)
# for cid in range(len(actives_sampled)):
#     actives_sampled[cid].SetProp('_Name', f'activeInChIKey{Chem.MolToInchiKey(actives_sampled[cid])}')
#     writer.write(actives_sampled[cid])
    
# writer = Chem.SDWriter(decoys_sampled_path)
# for cid in range(len(decoys_sampled)):
#     decoys_sampled[cid].SetProp('_Name', f'InChIKey{Chem.MolToInchiKey(decoys_sampled[cid])}')
#     writer.write(decoys_sampled[cid])

In [None]:
# sampled_path = os.path.join(base_path, 'demo-data/pde5a/sample_600.sdf')
# writer = Chem.SDWriter(sampled_path)
# for cid in range(len(actives)):
#     writer.write(actives[cid])
# for cid in range(len(decoys)):
#     writer.write(decoys[cid])

In [None]:
# actives = actives[:10]
# decoys = decoys[:50]
# lan_a = len(actives)
# len_d = len(decoys)

In [None]:
def toy_test_similarity(s_func, prb_mols, ref_mol, align_time=0, n_jobs=None):
    """Score `prb_mols` against `ref_mol` with `s_func` and measure throughput.

    Prints the mean and median score, then returns the scores together with a
    throughput in molecules per second. `align_time` (seconds) is added to the
    measured scoring time so that pre-alignment cost is charged to the method.

    Args:
        s_func: callable ``s_func(prb_mols, ref_mol, n_jobs=..., verbose=1)``
            returning one score per probe molecule.
        prb_mols: probe molecules to score.
        ref_mol: reference molecule.
        align_time: extra seconds included in the throughput denominator.
        n_jobs: worker count forwarded to `s_func`; defaults to the
            notebook-level ``n_jobs`` setting.

    Returns:
        ``(similarity_list, speed)`` where ``speed`` is molecules per second.
    """
    if n_jobs is None:
        n_jobs = globals()['n_jobs']

    start = timeit.default_timer()
    similarity_list = s_func(prb_mols, ref_mol, n_jobs=n_jobs, verbose=1)
    end = timeit.default_timer()

    # statistics.mean/median raise on empty input; len() also works for array results.
    if len(similarity_list) > 0:
        print(f'mean   : {mean(similarity_list)}')
        print(f'median : {median(similarity_list)}')
    elapsed = end - start + align_time
    speed = len(prb_mols) / elapsed if elapsed > 0 else float('inf')
    return similarity_list, speed

In [None]:
# Result accumulators keyed by method name:
#   sim_dict -> list of (score, label) pairs, label 1 = active, 0 = decoy
#   tim_dict -> throughputs (molecules/s), one per scoring pass
sim_dict = dict()
tim_dict = dict()

In [None]:
similarity_methods = {
    'rdkit_protrude': sm.rdkit_shape_protrude_dist, 
    'rdkit_tanimoto': sm.rdkit_shape_tanimoto_dist,
    'rdkit_tversky': sm.rdkit_shape_tversky_index,
    'rdkit_sc': sm.rdkit_sc_score,
    'rdkit_sc_tanimoto': sm.rdkit_sc_score_tanimoto,
    'rdkit_sc_tversky': sm.rdkit_sc_score_tversky,
}

# Each shape method is run on both pre-aligned pose sets. Tuple layout:
# (key suffix, aligned actives, aligned decoys, actives align seconds, decoys align seconds).
alignment_variants = (
    ('-o3a', actives_o, decoys_o, a_o_a_time, a_o_d_time),
    ('-crippeno3a', actives_c, decoys_c, a_c_a_time, a_c_d_time),
)

for _name, func in similarity_methods.items():
    for suffix, act_poses, dec_poses, act_align_time, dec_align_time in alignment_variants:
        name = _name + suffix
        print(f'🌟 {name} 🌟')
        result_tmp = sim_dict.get(name, [])
        t_tmp = tim_dict.get(name, [])

        print('----👇 actives 👇----')
        similarity_score, time_a = toy_test_similarity(func, act_poses, ligand, act_align_time)
        result_tmp += [(ele, 1) for ele in similarity_score]
        t_tmp.append(time_a)

        print('----👇 decoys 👇----')
        similarity_score, time_d = toy_test_similarity(func, dec_poses, ligand, dec_align_time)
        result_tmp += [(ele, 0) for ele in similarity_score]
        t_tmp.append(time_d)

        sim_dict[name] = result_tmp
        tim_dict[name] = t_tmp
        print(f'speed: {mean([time_a, time_d])}')

In [None]:
similarity_methods = {
    'rdkit_pharm_tanimoto': sm.rdkit_pharm_tanimoto,
    'rdkit_pharm_tversky': sm.rdkit_pharm_tversky,
    'rdkit_fp_maccs': sm.rdkit_fp_maccs,
    'rdkit_fp_maccs_tanimoto': sm.rdkit_fp_maccs_tainimoto,
    'rdkit_fp_maccs_tversky': sm.rdkit_fp_maccs_tversky,
    'rdkit_fp_margan': sm.rdkit_fp_morgan,
    'rdkit_fp_margan_tanimoto': sm.rdkit_fp_morgan_tanimoto,
    'rdkit_fp_margan_tversky': sm.rdkit_fp_morgan_tversky,
}

# Pharmacophore/fingerprint methods are scored on the unaligned molecules, so no alignment
# time is charged. Each pass: (banner, molecules, class label).
scoring_passes = (
    ('----👇 actives 👇----', actives, 1),
    ('----👇 decoys 👇----', decoys, 0),
)

for name, func in similarity_methods.items():
    print(f'🌟 {name} 🌟')
    result_tmp = sim_dict.get(name, [])
    t_tmp = tim_dict.get(name, [])
    pass_speeds = []
    for banner, mols, label in scoring_passes:
        print(banner)
        similarity_score, speed = toy_test_similarity(func, mols, ligand)
        result_tmp.extend((ele, label) for ele in similarity_score)
        t_tmp.append(speed)
        pass_speeds.append(speed)
    sim_dict[name] = result_tmp
    tim_dict[name] = t_tmp
    print(f'speed: {mean(pass_speeds)}')


In [None]:
# similarity_methods = {
#     'rdkit_protrude': sm.rdkit_shape_protrude_dist, 
#     'rdkit_tanimoto': sm.rdkit_shape_tanimoto_dist,
#     'rdkit_tversky': sm.rdkit_shape_tversky_index,
#     'rdkit_sc': sm.rdkit_sc_score,
#     'rdkit_sc_tanimoto': sm.rdkit_sc_score_tanimoto,
#     'rdkit_sc_tversky': sm.rdkit_sc_score_tversky,
#     'rdkit_pharm_tanimoto': sm.rdkit_pharm_tanimoto,
#     'rdkit_pharm_tversky': sm.rdkit_pharm_tversky,
#     'rdkit_fp_maccs': sm.rdkit_fp_maccs,
#     'rdkit_fp_maccs_tanimoto': sm.rdkit_fp_maccs_tainimoto,
#     'rdkit_fp_maccs_tversky': sm.rdkit_fp_maccs_tversky,
#     'rdkit_fp_margan': sm.rdkit_fp_morgan,
#     'rdkit_fp_margan_tanimoto': sm.rdkit_fp_morgan_tanimoto,
#     'rdkit_fp_margan_tversky': sm.rdkit_fp_morgan_tversky,
# }

# do_not_align = ['rdkit_pharm_tanimoto', 'rdkit_pharm_tversky','rdkit_fp_maccs', 'rdkit_fp_maccs_tanimoto', 'rdkit_fp_maccs_tversky', 'rdkit_fp_margan', 'rdkit_fp_margan_tanimoto', 'rdkit_fp_margan_tversky']

# align_funcs_map_a = {
#     'o3a': a_o_a_time, 
#     'crippeno3a': a_c_a_time,
# }
# align_funcs_map_d = {
#     'o3a': a_o_d_time, 
#     'crippeno3a': a_c_d_time
# }
# align_funcs_list = ['o3a', 'crippeno3a']

# sim_dict = dict()
# tim_dict = dict()

# for name, func in similarity_methods.items():
#     ic(f'🌟 {name} 🌟')
#     actives_sim_list = []
#     decoys_sim_list = []
#     ic('----👇 actives 👇----')
#     align_funcs = align_funcs_list if name not in do_not_align else [None]
#     tmp_time_list = []
#     for align_name in align_funcs:
#         if align_name:
#             result_name = f'{name}-{align_name}'
#             align_func = align_funcs_map_a[align_name]
#             actives = actives_o if align_name=='o3a' else actives_c
#         else:
#             result_name = name
#         tmp_sim, time = toy_test_similarity(
#                             func, 
#                             actives, 
#                             ligand, 
#                             align_func=align_func
#                             )
#         result_tmp = sim_dict.get(result_name, [])
#         result_tmp += [(ele, 1) for ele in tmp_sim]
#         sim_dict[result_name] = result_tmp
#         tmp_time_list.append(time)
        
#         t_tmp = tim_dict.get(result_name, [])
#         t_tmp.append(time)
#         tim_dict[result_name] = t_tmp

#     ic('----👇 decoys 👇----')
#     for align_name in align_funcs:
#         if align_name:
#             result_name = f'{name}-{align_name}'
#             align_func = align_funcs_map_d[align_name]
#             decoys = decoys_o 
#         else:
#             result_name = name
#         tmp_sim, time = toy_test_similarity(
#                         func, 
#                         decoys, 
#                         ligand, 
#                         align_func=align_func
#                         )
#         result_tmp = sim_dict.get(result_name, [])
#         result_tmp += [(ele, 0) for ele in tmp_sim]
#         assert len(result_tmp) == (len_a + len_d), f"len(result_tmp) = {len(result_tmp)}"
#         sim_dict[result_name] = result_tmp
#         tmp_time_list.append(time)

#         t_tmp = tim_dict.get(result_name, [])
#         t_tmp.append(time)
#         tim_dict[result_name] = t_tmp
    
#     for i in range(len(tmp_time_list)//2):
#         ic(mean(tmp_time_list[i::len(tmp_time_list)//2]))

In [None]:
# Quick separation check: per method, mean score of actives (label 1) vs. decoys (label 0).
for name, score_list in sim_dict.items():
    mean_actives_score = mean([ele[0] for ele in score_list if ele[1]==1])
    mean_decoys_score = mean([ele[0] for ele in score_list if ele[1]==0])
    ic(name)
    ic(mean_actives_score)
    ic(mean_decoys_score)

## 2.2 ODDT

### 2.2.1 测试

In [None]:
# Peek at the crystal ligand as read by ODDT's default toolkit.
next(oddt.toolkit.readfile('mol2', l_path))

In [None]:
# Count the decoys ODDT can read (falsy entries dropped).
tmp = [mol for mol in list(oddt.toolkit.readfile('sdf', d_path)) if mol]
len(tmp)

In [None]:
# Display the first ODDT-read decoy.
tmp[0]

### 2.2.2 相似度计算

In [None]:
# Reload ligand/actives/decoys as ODDT molecules for the ODDT-based similarity methods.
ligand = next(oddt.toolkit.readfile('mol2', l_path))
actives = [mol for mol in oddt.toolkit.readfile('sdf', a_path) if mol]
decoys = [mol for mol in oddt.toolkit.readfile('sdf', d_path) if mol]
# actives = [mol for mol in oddt.toolkit.readfile('sdf', actives_sampled_path) if mol]
# decoys = [mol for mol in oddt.toolkit.readfile('sdf', decoys_sampled_path) if mol]

len_a = len(actives)
len_d = len(decoys)

In [None]:
# actives = actives[:10]
# decoys = decoys[:50]
# len_a = len(actives)
# len_d = len(decoys)

In [None]:
# Show the ODDT dataset sizes.
ic(len_a)
ic(len_d)

In [None]:
def toy_test_similarity(s_func, prb_mols, ref_mol, align_time=0, n_jobs=None):
    """Score `prb_mols` against `ref_mol` with `s_func` and measure throughput.

    Prints the mean and median score, then returns the scores together with a
    throughput in molecules per second. `align_time` (seconds) is added to the
    measured scoring time so that pre-alignment cost is charged to the method.

    Args:
        s_func: callable ``s_func(prb_mols, ref_mol, n_jobs=..., verbose=1)``
            returning one score per probe molecule.
        prb_mols: probe molecules to score.
        ref_mol: reference molecule.
        align_time: extra seconds included in the throughput denominator.
        n_jobs: worker count forwarded to `s_func`; defaults to the
            notebook-level ``n_jobs`` setting.

    Returns:
        ``(score_list, speed)`` where ``speed`` is molecules per second.
    """
    if n_jobs is None:
        n_jobs = globals()['n_jobs']

    start = timeit.default_timer()
    score_list = s_func(prb_mols, ref_mol, n_jobs=n_jobs, verbose=1)
    end = timeit.default_timer()

    # statistics.mean/median raise on empty input; len() also works for array results.
    if len(score_list) > 0:
        print(f'mean   : {mean(score_list)}')
        print(f'median : {median(score_list)}')
    elapsed = end - start + align_time
    speed = len(prb_mols) / elapsed if elapsed > 0 else float('inf')
    return score_list, speed


In [None]:
# ODDT descriptor-based methods, scored on the unaligned ODDT molecules loaded above
# (no alignment time is charged).
similarity_methods = {
    'oddt_usr': sm.oddt_usr,
    'oddt_usr_cat': sm.oddt_usr_cat,
    'oddt_electroshape': sm.oddt_electroshape,
}
for name, func in similarity_methods.items():
    print(f'🌟 {name} 🌟')
    result_tmp = sim_dict.get(name, [])
    t_tmp = tim_dict.get(name, [])
    print('----👇 actives 👇----')
    similarity_score, time_a = toy_test_similarity(func, actives, ligand)
    result_tmp += [(ele, 1) for ele in similarity_score]
    t_tmp.append(time_a)
    print('----👇 decoys 👇----')
    similarity_score, time_d = toy_test_similarity(func, decoys, ligand)
    result_tmp += [(ele, 0) for ele in similarity_score]
    t_tmp.append(time_d)
    # assert len(result_tmp) == 600, f"len(result_tmp) = {len(result_tmp)}"
    sim_dict[name] = result_tmp
    tim_dict[name] = t_tmp
    print(f'speed: {mean([time_a, time_d])}')

In [None]:
# List the methods collected so far.
sim_dict.keys()

In [None]:
# Per-target directory for the pickled RDKit/ODDT results.
save_dir = os.path.join(base_path, f'demo-data/result/{target_list[target_flag]}/')

In [None]:
# Create the result directory, including any missing parents; a no-op if it already exists.
# Done in Python instead of via a shell command, so failures raise instead of being ignored.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Persist the collected scores (sim_dict) and throughputs (tim_dict) for this target.
save_dir = os.path.join(base_path, f'demo-data/result/{target_list[target_flag]}/')
with open(os.path.join(save_dir, f'{target_list[target_flag]}-sim_dict.pickle'), 'wb') as f:
    pickle.dump(sim_dict, f)
with open(os.path.join(save_dir, f'{target_list[target_flag]}-tim_dict.pickle'), 'wb') as f:
    pickle.dump(tim_dict, f)

In [None]:
# Reload previously saved results for this target.
# pickle.load can execute arbitrary code: only load files this notebook wrote itself.
save_dir = os.path.join(base_path, f'demo-data/result/{target_list[target_flag]}/')
ic(save_dir)
with open(os.path.join(save_dir, f'{target_list[target_flag]}-sim_dict.pickle'), 'rb') as f:
    sim_dict = pickle.load(f)
with open(os.path.join(save_dir, f'{target_list[target_flag]}-tim_dict.pickle'), 'rb') as f:
    tim_dict = pickle.load(f)
sim_dict.keys()

## 2.3 acpc

In [None]:
# Reload the de-duplicated dataset as RDKit molecules (act/dec/lig) for the ACPC runs.
a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_dd.sdf')
d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_dd.sdf')
l_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'crystal_ligand.mol2')
ic(a_path)

# decoys = fn.Mol2MolSupplier(d_path, sanitize=True)
dec = [m for m in Chem.SDMolSupplier(d_path, removeHs=False) if m]
ic(len(dec))
# Draw.MolsToGridImage(random.sample(decoys, 7), molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in decoys[:7]],maxMols=100)

act = [m for m in Chem.SDMolSupplier(a_path, removeHs=False) if m]
ic(len(act))
# Draw.MolsToGridImage(random.sample(actives, 7), molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in actives[:7]],maxMols=100)

# Reference: crystal ligand with hydrogens added at 3D positions.
lig = AllChem.MolFromMol2File(l_path, removeHs=False)
lig = Chem.AddHs(lig, addCoords=True)
lig

In [None]:
# Fresh accumulators for the ACPC results (method name -> scores / throughputs).
sim_dict = dict()
tim_dict = dict()

In [None]:
# Rename every molecule after its InChIKey; actives get an "actives" prefix, presumably so that
# sm.acpc can recover the active/decoy label from the names in ACPC's output — confirm.
for mol in act:
    mol.SetProp('_Name', f'activesInChIKey{AllChem.MolToInchiKey(mol)}')
for mol in dec:
    mol.SetProp('_Name', f'InChIKey{AllChem.MolToInchiKey(mol)}')

### 2.3.1 mol2

In [None]:
# ACPC mol2 route: all probes (actives then decoys) in one mol2 file, the crystal ligand
# as the reference, and a per-target result text file.
prb_acpc_ml2_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'prb_acpc.mol2')
ref_acpc_ml2_path = l_path
out_acpc_ml2_path = os.path.join(base_path, 'demo-data/result/acpc', target_list[target_flag], f'{target_list[target_flag]}-ml2-result.txt')
# Ensure the result directory exists before ACPC (presumably via sm.acpc) writes into it.
os.makedirs(os.path.dirname(out_acpc_ml2_path), exist_ok=True)

fn.MolToSpecialFormatFile(act+dec, prb_acpc_ml2_path)

In [None]:
# Score all probes with ACPC from the mol2 file; throughput = molecules per second of wall time.
sat = timeit.default_timer()
sim_dict['acpc-ml2'] = sm.acpc(ref_path=ref_acpc_ml2_path, prb_path=prb_acpc_ml2_path, out_path=out_acpc_ml2_path, num_core=n_jobs)
end = timeit.default_timer()
tim_dict['acpc-ml2'] = [len(act+dec)/(end-sat)]

In [None]:
prb_acpc_ml2_bin_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'prb_acpc.mol2.bin')
out_acpc_ml2_bin_path = os.path.join(base_path, 'demo-data/result/acpc', target_list[target_flag], f'{target_list[target_flag]}-ml2-bin-result.txt')
# Encode the probe mol2 into ACPC's binary input with the acpc_codec CLI. An argument list
# (no shell) passes paths verbatim; check=True raises if encoding fails instead of letting
# the next cell run on a missing or stale .bin file.
subprocess.run(['acpc_codec', '-i', prb_acpc_ml2_path, '-o', prb_acpc_ml2_bin_path], check=True)

In [None]:
# Same ACPC scoring, reading the encoded binary probe file instead of the mol2.
sat = timeit.default_timer()
sim_dict['acpc-ml2-bin'] = sm.acpc(ref_path=ref_acpc_ml2_path, prb_path=prb_acpc_ml2_bin_path, out_path=out_acpc_ml2_bin_path, num_core=n_jobs)
end = timeit.default_timer()
tim_dict['acpc-ml2-bin'] = [len(act+dec)/(end-sat)]

### 2.3.2 pqr

In [None]:
# ACPC pqr route: write probes and the reference ligand as .pqr files (format chosen by
# fn.MolToSpecialFormatFile, presumably from the extension).
prb_acpc_pqr_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'prb_acpc.pqr')
fn.MolToSpecialFormatFile(act+dec, prb_acpc_pqr_path)
ref_acpc_pqr_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'ref_acpc.pqr')
fn.MolToSpecialFormatFile([lig], ref_acpc_pqr_path)
out_acpc_pqr_path = os.path.join(base_path, 'demo-data/result/acpc', target_list[target_flag], f'{target_list[target_flag]}-pqr-result.txt')

In [None]:
# ACPC scoring from the pqr files; throughput = molecules per second of wall time.
sat = timeit.default_timer()
sim_dict['acpc-pqr'] = sm.acpc(ref_path=ref_acpc_pqr_path, prb_path=prb_acpc_pqr_path, out_path=out_acpc_pqr_path, num_core=n_jobs)
end = timeit.default_timer()
tim_dict['acpc-pqr'] = [len(act+dec)/(end-sat)]

## 2.4 ESP-SIM

In [None]:
# Reload the de-duplicated dataset as RDKit molecules (act/dec/lig) for the ESP-Sim runs.
a_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'actives_dd.sdf')
d_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'decoys_dd.sdf')
l_path = os.path.join(base_path, dude_dir, target_list[target_flag], 'crystal_ligand.mol2')
ic(a_path)
ic(d_path)
ic(l_path)

# decoys = fn.Mol2MolSupplier(d_path, sanitize=True)
dec = [m for m in Chem.SDMolSupplier(d_path, removeHs=False) if m]
ic(len(dec))
# Draw.MolsToGridImage(random.sample(decoys, 7), molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in decoys[:7]],maxMols=100)

act = [m for m in Chem.SDMolSupplier(a_path, removeHs=False) if m]
ic(len(act))
# Draw.MolsToGridImage(random.sample(actives, 7), molsPerRow=7, subImgSize=(150,150), legends=[mol.GetProp('_Name') for mol in actives[:7]],maxMols=100)

# Reference: crystal ligand with hydrogens added at 3D positions.
lig = AllChem.MolFromMol2File(l_path, removeHs=False)
lig = Chem.AddHs(lig, addCoords=True)
lig

In [None]:
# Display the first active.
act[0]

In [None]:
# Dataset sizes, used for throughput and to split the ML charge list into actives/decoys.
len_a = len(act)
len_d = len(dec)

In [None]:
# Fresh accumulators for the ESP-Sim results.
sim_dict = dict()
tim_dict = dict()

In [None]:
# Pre-align actives and decoys onto the ligand with O3A and Crippen-O3A. The combined
# wall time of each method is later added to the ESP-Sim throughputs.
ali_o3a_time = 0
sta = timeit.default_timer()
act_o3a, align_score =al.rdkit_o3a(act, lig, n_jobs=n_jobs, verbose=1)
dec_o3a, align_score =al.rdkit_o3a(dec, lig, n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
ali_o3a_time = end - sta
ic(ali_o3a_time)

ali_cri_time = 0
sta = timeit.default_timer()
act_cri, align_score =al.rdkit_crippeno3a(act, lig, n_jobs=n_jobs, verbose=1)
dec_cri, align_score =al.rdkit_crippeno3a(dec, lig, n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
ali_cri_time = end - sta
ic(ali_cri_time)

In [None]:
# ESP-Sim on Crippen-O3A poses with sm.esp_sim's default options (the key assumes these are
# tanimoto shape, MMFF charges, tanimoto metric — defaults not visible here). Throughput
# includes the alignment time.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_cri, lig, n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_cri, lig, n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-crippeno3a-tanimoto-mmff-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-crippeno3a-tanimoto-mmff-tanimoto'] = [(len_a + len_d) / (end - sat + ali_cri_time)]

In [None]:
# ESP-Sim on O3A poses with sm.esp_sim's default options; throughput includes the O3A alignment time.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-tanimoto-mmff-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-tanimoto-mmff-tanimoto'] = [(len_a + len_d) / (end - sat + ali_o3a_time)]

In [None]:
# ESP-Sim on O3A poses with the protrude shape similarity.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, shape_sim='protrude', n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, shape_sim='protrude', n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-protrude-mmff-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-protrude-mmff-tanimoto'] = [(len_a + len_d) / (end - sat + ali_o3a_time)]

In [None]:
# ESP-Sim on O3A poses with the tversky shape similarity.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, shape_sim='tversky', n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, shape_sim='tversky', n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-tversky-mmff-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-tversky-mmff-tanimoto'] = [(len_a + len_d) / (end - sat + ali_o3a_time)]

In [None]:
# ESP-Sim on O3A poses using Gasteiger partial charges.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, systems='gasteiger', n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, systems='gasteiger', n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-tanimoto-gasteiger-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-tanimoto-gasteiger-tanimoto'] = [(len_a + len_d) / (end - sat + ali_o3a_time)]

In [None]:
# Precompute ML partial charges once for all probes (actives first, then decoys) and for the
# reference; the time is charged to the ML variant's throughput.
# NOTE(review): charges come from the unaligned molecules — assumes alignment keeps atom order.
sat = timeit.default_timer()
prb_charge = mlCharges(act+dec)
ref_charge = mlCharges([lig])[0]
end = timeit.default_timer()
ml_char_time = end - sat

In [None]:
# ESP-Sim on O3A poses with the precomputed ML charges; the first len_a charge sets belong to actives.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, systems='ml', prb_charge=prb_charge[:len_a], ref_charge=ref_charge, n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, systems='ml', prb_charge=prb_charge[len_a:], ref_charge=ref_charge, n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-tanimoto-ml-tanimoto'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-tanimoto-ml-tanimoto'] = [(len_a + len_d) / (end - sat + ali_o3a_time + ml_char_time)]

In [None]:
# ESP-Sim on O3A poses using the Carbo electrostatic similarity metric.
sat = timeit.default_timer()
act_sim_list = sm.esp_sim(act_o3a, lig, metric='carbo', n_jobs=n_jobs, verbose=1)
dec_sim_list = sm.esp_sim(dec_o3a, lig, metric='carbo', n_jobs=n_jobs, verbose=1)
end = timeit.default_timer()
sim_dict['esp-sim-o3a-tanimoto-mmff-carbo'] = [(sim, 1) for sim in act_sim_list] + [(sim, 0) for sim in dec_sim_list]
tim_dict['esp-sim-o3a-tanimoto-mmff-carbo'] = [(len_a + len_d) / (end - sat + ali_o3a_time)]

In [None]:
# List the ESP-Sim variants collected.
sim_dict.keys()

In [None]:
# Per-target ESP-Sim result directory; created with any missing parents, no-op if present.
save_dir = os.path.join(base_path, f'demo-data/result/espsim/{target_list[target_flag]}/')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Persist the ESP-Sim scores and throughputs for this target.
with open(os.path.join(save_dir, f'{target_list[target_flag]}-sim_dict.pickle'), 'wb') as f:
    pickle.dump(sim_dict, f)
with open(os.path.join(save_dir, f'{target_list[target_flag]}-tim_dict.pickle'), 'wb') as f:
    pickle.dump(tim_dict, f)

In [None]:
ic(save_dir)
# Reload the pickles written by the dump cell (our own output, so unpickling is acceptable here).
sim_path = os.path.join(save_dir, f'{target_list[target_flag]}-sim_dict.pickle')
tim_path = os.path.join(save_dir, f'{target_list[target_flag]}-tim_dict.pickle')
with open(sim_path, 'rb') as f:
    sim_dict = pickle.load(f)
with open(tim_path, 'rb') as f:
    tim_dict = pickle.load(f)
sim_dict.keys()

## 2.3 评价指标

In [None]:
# One ROC curve per scoring method, saved as a numbered PNG in save_dir.
for i, (name, score) in enumerate(sim_dict.items(), start=1):
    plt.figure(figsize=(5, 5), dpi=200)

    # `score` holds (similarity, label) pairs; label 1 = active, 0 = decoy.
    score_list = sorted(score, key=lambda pair: pair[0])
    y_score = np.asarray([pair[0] for pair in score_list])
    y_lable = np.asarray([pair[1] for pair in score_list])

    fpr, tpr, _ = metrics.roc_curve(y_lable, y_score)
    roc_auc = metrics.auc(fpr, tpr)

    plt.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc)
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")  # random-ranking diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC {name}")
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(save_dir, f'{i :0>2}-roc-{name}.png'))
    plt.show()

In [None]:
# logROC: ROC drawn on a log-scaled FPR axis (emphasises early enrichment), one PNG per method.
for i, (name, score) in enumerate(sim_dict.items(), start=1):
    fig, ax = plt.subplots(figsize=(5, 5), dpi=200)

    score_list = sorted(score, key=lambda pair: pair[0])
    y_lable = np.asarray([pair[1] for pair in score_list])
    y_score = np.asarray([pair[0] for pair in score_list])
    fpr, tpr, _ = metrics.roc_curve(y_lable, y_score)

    # FPR window for the logAUC; FPR values below log_min are clamped up to it.
    log_min = 0.001
    log_max = 1.0
    fpr = fpr.clip(log_min)
    idx = fpr <= log_max
    # Map log10(FPR) in [log10(log_min), 0] linearly onto [0, 1] before integrating.
    log_fpr = 1 - np.log10(fpr[idx]) / np.log10(log_min)
    log_roc_auc = metrics.auc(log_fpr, tpr[idx])

    ax.plot(fpr[idx], tpr[idx], color="darkorange", lw=2,
            label="logROC curve (area = %0.2f)" % log_roc_auc)
    x = np.linspace(0.001, 1, 200)
    y = x
    ax.plot(x, y, color="navy", lw=2, linestyle="--")  # random-ranking diagonal

    ax.set_ylim([0.0, 1.05])
    ax.set_xscale('log')
    ax.set_xlim([0.001, 1])
    ax.set_xticks([0.001, 0.01, 0.1, 1.0])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"logROC {name}")
    ax.legend(loc="lower right")
    fig.savefig(os.path.join(save_dir, f'{i :0>2}-log_roc-{name}.png'))
    plt.show()

In [None]:
# Chance-level diagonal (TPR = FPR) drawn in the rescaled log-FPR coordinate used for logAUC.
# log_min is defined here so the cell does not depend on a value left over from the logROC loop.
log_min = 0.001
x = np.linspace(log_min, 1, 200)
y = x
x = 1 - np.log10(x) / np.log10(log_min)
plt.plot(x, y, color="navy", lw=2, linestyle="--")
plt.xlabel("Rescaled log10 FPR")
plt.ylabel("True Positive Rate")

In [None]:
# Same chance diagonal, this time on a logarithmic x axis.
x = np.linspace(0.001, 1, 200)
y = x
fig, ax = plt.subplots()
ax.plot(x, y, color="navy", lw=2, linestyle="--")
ax.set_xscale('log')

In [None]:
# Scratch: checks that a negated sort key gives descending order (used by the metrics cell).
# np.asarray([1,3,4,5,6,7]).clip(4)
sorted([1,2,5,6,7], key=lambda x: -x)

In [None]:
# Compute screening metrics per method (best-scored molecules first) and collect one row each.
result_list = []
for name, score in sim_dict.items():
    # Descending similarity: the top-ranked molecules come first.
    score_list = sorted(score, key=lambda pair: -pair[0])
    y_score = np.asarray([pair[0] for pair in score_list])
    y_lable = np.asarray([pair[1] for pair in score_list])

    metric_values = mt.calculate_metrics(y_lable=y_lable, y_score=y_score)

    print(f'--- 🌟 {name} 🌟 ---')
    mean_time = mean(tim_dict[name])
    row = {'name': name, 'time': mean_time}
    print(f'time: {mean_time}')
    for m_name, m_value in metric_values.items():
        row[m_name] = m_value
        print(f'{m_name:<30}: {m_value}')
    result_list.append(row)

In [None]:
# Write the per-method metrics table (one row per scoring method, from result_list) to Excel.
pd.DataFrame(result_list).to_excel(os.path.join(save_dir, f'{target_list[target_flag]}_results.xlsx'))

In [None]:
# External PDE5 docking scores. The score is negated so that a larger x ranks higher
# (the metrics cell uses ascending_score=False); labels map 'active' -> 1, 'decoy' -> 0.
molcalx = '/home/jovyan/work-home/molecule-3d-similarity/pde5_score.csv'
df = pd.read_csv(molcalx)
df_sorted = df.sort_values(by=['score'], ascending=True)
score = df_sorted['score']
x = np.asarray(score) * -1
label_map = {'active': 1, 'decoy': 0}
labels = [label_map[status] for status in df_sorted['label']]
y = np.array(labels)

In [None]:
# Virtual-screening metrics for the docking scores via oddt. The module is imported explicitly:
# the bare name `m` is not bound to a metrics module anywhere in this notebook (it is reused as
# an RDKit molecule loop variable), and sklearn's `metrics` is already taken.
from oddt import metrics as oddt_metrics

auc = oddt_metrics.roc_auc(y, x, pos_label=1, ascending_score=False)
print("ROC AUC = ",auc)
bedroc = oddt_metrics.bedroc(y, x, alpha=20.0, pos_label=1)
print("alpha=20.0 BEDROC = ",bedroc)
logauc = oddt_metrics.roc_log_auc(y, x, pos_label=1, ascending_score=False, log_min=0.001, log_max=1.0)
print("logAUC [0.1%,100%]  = ",logauc)

In [None]:
# Reference AUC table (method name followed by four "mean ± std" values), pasted as one string;
# it is the same text as `auc_dude` near the top of the notebook. TODO: record the source table.
result_string = 'ROCSComboscore 0.598 ± 0.152 0.681 ± 0.166 0.674 ± 0.115 0.727 ± 0.128 ROCSColorscore 0.620 ± 0.139 0.712 ± 0.159 0.677 ± 0.117 0.752 ± 0.136 ROCSShapeTanimoto 0.547 ± 0.138 0.611 ± 0.163 0.618 ± 0.105 0.667 ± 0.127 Phase Shape_Mmod 0.677 ± 0.143 0.686 ± 0.145 0.772 ± 0.105 0.769 ± 0.108 Phase Shape_Ele 0.674 ± 0.153 0.688 ± 0.158 0.753 ± 0.105 0.750 ± 0.111 Phase Shape_Pharm 0.692 ± 0.160 0.694 ± 0.168 0.761 ± 0.145 0.767 ± 0.143 Shape-it 0.541 ± 0.133 0.590 ± 0.141 0.612 ± 0.110 0.639 ± 0.115 Align-it 0.659 ± 0.137 0.680 ± 0.157 0.729 ± 0.132 0.746 ± 0.139 ShaEPbest 0.658 ± 0.122 0.660 ± 0.139 0.709 ± 0.099 0.699 ± 0.120 ShaEPshape 0.625 ± 0.139 0.632 ± 0.143 0.681 ± 0.105 0.676 ± 0.116 ShaEPESP 0.606 ± 0.109 0.591 ± 0.117 0.627 ± 0.105 0.585 ± 0.125 SHAFTS 0.733 ± 0.144 0.731 ± 0.157 0.792 ± 0.129 0.782 ± 0.135 WEGA 0.645 ± 0.143 0.659 ± 0.154 0.716 ± 0.107 0.716 ± 0.125 LIGSIFT 0.718 ± 0.133 0.755 ± 0.143 0.758 ± 0.117 0.784 ± 0.120 LS-align 0.699 ± 0.126 0.759 ± 0.119 0.773 ± 0.098 0.786 ± 0.096'

In [None]:
# Token 13 of the space-split string is the second method name ('ROCSColorscore').
result_string.split(' ')[13]

In [None]:
result_string.split(' ')[14]

In [None]:
# TODO: incomplete — first value of the first two methods only.
cry_single = [0.598, 0.620, ]

In [None]:
# NOTE(review): x and y here are whatever the most recently executed cell assigned (in order, the
# MolCalx score/label cell), not cry_single — confirm this is the intended correlation.
pccs = np.corrcoef(x, y)

## 2.4 可视化

In [None]:
# Two seed SMILES per DUD-E target: index 0 is used as `cry` and index 1 as `clu` below
# (presumably a crystal-structure ligand and a cluster representative — TODO confirm).
seed_ligands = {
    'pde5a': [
        # NOTE(review): the sulfonyl ring is written aromatic ('n1ccn(cc1)C'), which RDKit reads as a
        # dihydropyrazine; sildenafil has a saturated N-methylpiperazine ('N1CCN(CC1)C', as in the
        # second entry). Verify this SMILES before relying on exact-match lookups.
        'CCCc1nn(c2c1nc([nH]c2=O)c1cc(ccc1OCC)S(=O)(=O)n1ccn(cc1)C)C',
        'CCCc1nn(c2c1nc([nH]c2=O)c1cc(ccc1OCC)S(=O)(=O)N1CCN(CC1)Cc1ccc2c(c1)OCO2)C',
    ],
    'akt1': [
        'Clc1c[nH]c2c1c(ncn2)N1CCc2c(C1)[nH]cn2',
        '[NH3+]C(Cc1c[nH]c2c1cccc2)COc1cncc(c1)c1ccc2c(c1)C(C(=O)N2)C(Cc1ccccn1)(C)C'
    ],
    'ada': [
        'CC(C(n1cnc(c1)C(=O)N)CCc1cccc2c1cccc2)O',
        'CC(C(n1cnc2c1ccnc2N)CCCCCC)O'
    ],
    'andr': [
        'OC1CCC2(C(=C1)CCC1C2CCC2(C1CCC2O)C)C',
        'N#Cc1ccc(cc1C(F)(F)F)N1C(=O)C2C(C1=O)C1(OC2(C)CN(C1)c1cccc(c1)C(=O)N)C'
    ],
    'def': [
        'CCCCCC(C(=O)NC(C(=O)N1CCCC1CO)C(C)C)CC(=O)NO',
        'O=CN(CC(C(=O)NC(C(C)(C)C)C(=O)c1ccc(cc1)F)CC1CCCC1)O'
    ],
    'gria2': [
        'O=c1[nH]c2cc(c(cc2n(c1=O)CP(=O)(O)O)N1CCOCC1)C(F)(F)F',
        '[O-][N+](=O)c1cc2[nH]c(=O)c(nc2cc1n1ccc(c1)C(=O)[O-])[O-]',
    ]
}

In [None]:
# Canonicalise the current target's two seed ligands, key them by InChIKey, and build
# explicit-hydrogen molecules for fingerprinting.
cry_seed, clu_seed = seed_ligands[target_list[target_flag]][:2]
cry_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(cry_seed))
clu_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(clu_seed))

cry_parsed = Chem.MolFromSmiles(cry_smiles)
clu_parsed = Chem.MolFromSmiles(clu_smiles)
cry_inchi = Chem.MolToInchiKey(cry_parsed)
clu_inchi = Chem.MolToInchiKey(clu_parsed)

# AddHs returns a new molecule, so the parsed molecules above are left untouched.
cry_mol = AllChem.AddHs(cry_parsed, addCoords=True)
clu_mol = AllChem.AddHs(clu_parsed, addCoords=True)

In [None]:
# Report whether either seed ligand occurs among the actives. Both seeds are matched by
# InChIKey: canonical SMILES of molecules that carry explicit hydrogens (e.g. loaded with Hs kept)
# differ from the hydrogen-free seed SMILES, so a SMILES comparison can silently never match.
for m in actives:
    tmp_inchi = Chem.MolToInchiKey(m)

    if cry_inchi == tmp_inchi:
        print('cry find')

    if clu_inchi == tmp_inchi:
        print('clu find')

In [None]:
# Morgan feature fingerprints (radius 2): all actives, then ligand, cry_mol and clu_mol (last three rows).
fp_list = [
    np.asarray(AllChem.GetMorganFingerprintAsBitVect(mol, 2, useFeatures=True))
    for mol in [*actives, ligand, cry_mol, clu_mol]
]

tsne = TSNE(n_components=2, random_state=2022)
z = tsne.fit_transform(fp_list)



In [None]:
# Scatter the t-SNE embedding: rows [:-3] are the actives, z[-3] is `ligand`,
# z[-2] is cry_mol and z[-1] is clu_mol.
plt.figure(figsize=(5, 5), dpi=200)
# Actives only; slicing [:-2] would also include the ligand row and draw it twice.
plt.scatter(z[:-3, 0], z[:-3, 1])
plt.scatter(z[-3][0], z[-3][1])
# plt.scatter(z[-2][0], z[-2][1], color='beige')
# plt.scatter(z[67][0], z[67][1], color='beige')
# plt.scatter(z[-1][0], z[-1][1], color='springgreen')
# plt.scatter(z[285][0], z[285][1], color='springgreen')
# plt.scatter(z[215][0], z[215][1], color='lightcyan')
# NOTE: hard-coded row index picked by manual inspection; only meaningful for one target.
plt.scatter(z[291][0], z[291][1], color='lightcyan')
# plt.scatter(z[259][0], z[259][1], color='lightskyblue')
# plt.scatter(z[119][0], z[119][1], color='lightskyblue')
# plt.scatter(z[26][0], z[26][1], color='lightcoral')
# plt.scatter(z[46][0], z[46][1], color='lightcoral')
# seed = random.sample(list(range(len(actives))), 1)
# ic(seed)
# plt.scatter(z[seed[0]][0], z[seed[0]][1], color='beige')

In [None]:
from rdkit.DataStructs import TanimotoSimilarity, TverskySimilarity
# Tanimoto similarity between one hand-picked active and the reference ligand (feature Morgan, r=2).
fp1, fp2 = (
    AllChem.GetMorganFingerprintAsBitVect(mol, 2, useFeatures=True)
    for mol in (actives[-97], ligand)
)
TanimotoSimilarity(fp1, fp2)

In [None]:
# 285
# -1

# 67
# -2

# 215
# 291

# 259
# 119

# 26
# 46

In [None]:
len(actives)

In [None]:
actives[26].RemoveAllConformers()
actives[26]

In [None]:
actives[-12].RemoveAllConformers()
actives[-12]

In [None]:
actives[-15].RemoveAllConformers()
actives[-15]

In [None]:
actives[-97].RemoveAllConformers()
actives[-97]

In [None]:
actives[-99].RemoveAllConformers()
actives[-99]

In [None]:
actives[-5].RemoveAllConformers()
actives[-5]

In [None]:
clu_mol

In [None]:
cry_mol

In [None]:
ligand.RemoveAllConformers()
ligand

In [None]:
actives[-8].RemoveAllConformers()
actives[-8]

In [None]:
fp_list = []
for m in actives:
    fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(m,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(ligand,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(cry_mol,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(clu_mol,2)))

tsne = TSNE(n_components=2, random_state=2022)
z = tsne.fit_transform(fp_list)

plt.figure(figsize=(5, 5), dpi=200)
plt.scatter(z[:-2, 0], z[:-2, 1])
plt.scatter(z[-3][0], z[-3][1])
plt.scatter(z[-2][0], z[-2][1])
plt.scatter(z[-1][0], z[-1][1])

In [None]:
fp_list = []
for m in actives:
    fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(m,2,useFeatures=True)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(ligand,2,useFeatures=True)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(cry_mol,2,useFeatures=True)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(clu_mol,2,useFeatures=True)))

tsne = PCA(n_components=2, random_state=2022)
z = tsne.fit_transform(fp_list)

plt.figure(figsize=(5, 5), dpi=200)
plt.scatter(z[:-2, 0], z[:-2, 1])
plt.scatter(z[-3][0], z[-3][1])
plt.scatter(z[-2][0], z[-2][1])
plt.scatter(z[-1][0], z[-1][1])

In [None]:
fp_list = []
for m in actives:
    fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(m,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(ligand,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(cry_mol,2)))
fp_list.append(np.asarray(AllChem.GetMorganFingerprintAsBitVect(clu_mol,2)))

tsne = PCA(n_components=2, random_state=2022)
z = tsne.fit_transform(fp_list)

plt.figure(figsize=(5, 5), dpi=200)
plt.scatter(z[:-2, 0], z[:-2, 1])
plt.scatter(z[-3][0], z[-3][1])
plt.scatter(z[-2][0], z[-2][1])
plt.scatter(z[-1][0], z[-1][1])

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem

In [None]:
# Benzamide test molecule: add hydrogens, embed one 3D conformer with a fixed seed, MMFF-optimise.
mol = Chem.AddHs(Chem.MolFromSmiles('c1ccccc(C(N)=O)1'))
AllChem.EmbedMolecule(mol, randomSeed=0xf00d)
AllChem.MMFFOptimizeMolecule(mol)
mol

In [None]:
# Debug inspection of the embedded molecule: mol block, per-atom numbers/indices, conformer API and coordinates.
ic(Chem.MolToMolBlock(mol))
for atom in mol.GetAtoms():
    ic(atom.GetAtomicNum())
    ic(atom.GetIdx())
ic(dir(mol.GetConformer(0)))
ic(mol.GetConformer(0).GetPositions())

In [None]:
# Print the API and atomic number of the first atom only.
for atm in mol.GetAtoms():
    ic(dir(atm))
    ic(atm.GetAtomicNum())
    break

In [None]:
# Adjacency-matrix row of atom 0 (1 where a bond exists).
Chem.GetAdjacencyMatrix(mol)[0]

In [None]:
[a.GetAtomicNum() for a in mol.GetAtoms()]

In [None]:
# Bond type of the first bond attached to atom 6 (the amide carbon of the benzamide built above).
# The empty subscript `[]` was a SyntaxError; index 0 is assumed — TODO confirm the intended bond.
mol.GetAtomWithIdx(6).GetBonds()[0].GetBondType()

In [None]:
# Label every atom with a 1-based map number (its index + 1) so the depiction shows atom indices.
for atom in mol.GetAtoms():
    atom.SetProp("molAtomMapNumber", str(atom.GetIdx() + 1))

mol

In [None]:
# Collect the actives/decoys ISM (SMILES) file paths of every directory under the DUD-E root.
tmp_list = []
for root, dirs, files in os.walk(os.path.join(base_path, dude_dir), topdown=False):
    tmp_dict = dict.fromkeys(('act_ism', 'dec_ism'))
    for name in files:
        if 'actives_final.ism' in name:
            key = 'act_ism'
        elif 'decoys_final.ism' in name:
            key = 'dec_ism'
        else:
            continue
        tmp_dict[key] = os.path.join(root, name)
    tmp_list.append(tmp_dict)

In [None]:
# Keep only directories where both ISM files were found (os.walk also yields non-target directories).
ism_path_df = pd.DataFrame(tmp_list).dropna()

In [None]:
# NOTE(review): no parallel_apply call is visible in this section; the cells below use tqdm/joblib.
pandarallel.initialize(progress_bar=True)

In [None]:
def extract_info_from_smiles(series: pd.Series):
    """Count element occurrences in one target's DUD-E actives and decoys ISM files.

    Parameters
    ----------
    series : pd.Series
        A row of ``ism_path_df`` holding ``act_ism`` / ``dec_ism`` paths; ``series.name`` is its index.

    Returns
    -------
    list
        ``[actives_counts, decoys_counts]``; each maps atomic number -> number of atoms.
        SMILES that RDKit fails to parse contribute nothing.
    """
    print(f"🌟  {series.name+1}-{series['act_ism'].split('/')[-2]}  🌟")
    columns = ['smiles', '-', 'chemblid']
    act = pd.read_csv(series['act_ism'], sep=' ',  header=None, names=columns)
    dec = pd.read_csv(series['dec_ism'], sep=' ',  header=None, names=columns)

    # Despite the name, these are atomic numbers (nuclear charges), not partial charges.
    def get_charges(row: pd.Series):
        parsed = AllChem.MolFromSmiles(row['smiles'])
        return [atom.GetAtomicNum() for atom in parsed.GetAtoms()] if parsed else []

    def count_elements(frame: pd.DataFrame):
        return dict(Counter(num for nums in frame['charges'].tolist() for num in nums if nums))

    ic(act.head(5))
    tqdm.pandas()
    act['charges'] = act.progress_apply(get_charges, axis=1)
    dec['charges'] = dec.progress_apply(get_charges, axis=1)

    return [count_elements(act), count_elements(dec)]

In [None]:
# Smoke-test on the first target before processing all of them.
extract_info_from_smiles(ism_path_df.iloc[0])

In [None]:
# One job per target; `out` is a list of [actives_counts, decoys_counts] pairs.
out = Parallel(n_jobs=n_jobs, verbose=1)(delayed(extract_info_from_smiles)(row) for _, row in ism_path_df.iterrows()) 

In [None]:
out

In [None]:
def sum_dict(a,b):
    """Merge two count dictionaries, adding the values of shared keys.

    Keys present in only one input keep their value (the other side counts as 0).
    Returns a new dict; neither input is modified.
    """
    return {key: sum(d.get(key, 0) for d in (a, b)) for key in a.keys() | b.keys()}
 
def test():
    """Print the reduction of the demo dicts ``a``, ``b`` and ``c`` with ``sum_dict``."""
    from functools import reduce
    merged = reduce(sum_dict, (a, b, c))
    print(merged)

a = {'a': 1, 'b': 2, 'c': 3}
b = {'a': 1, 'b': 3, 'd': 4}
c = {'g': 3, 'f': 5, 'a': 10}
test()

In [None]:
# Lower-case element symbols indexed by atomic number (index 0 is a '-' placeholder), H .. Pu.
# Vanadium was stored as 'v ' with a trailing space, which corrupted any symbol lookup/output for Z=23.
__ATOM_LIST__ = \
    ['-', 'h',  'he',
     'li', 'be', 'b',  'c',  'n',  'o',  'f',  'ne',
     'na', 'mg', 'al', 'si', 'p',  's',  'cl', 'ar',
     'k',  'ca', 'sc', 'ti', 'v',  'cr', 'mn', 'fe', 'co', 'ni', 'cu',
     'zn', 'ga', 'ge', 'as', 'se', 'br', 'kr',
     'rb', 'sr', 'y',  'zr', 'nb', 'mo', 'tc', 'ru', 'rh', 'pd', 'ag',
     'cd', 'in', 'sn', 'sb', 'te', 'i',  'xe',
     'cs', 'ba', 'la', 'ce', 'pr', 'nd', 'pm', 'sm', 'eu', 'gd', 'tb', 'dy',
     'ho', 'er', 'tm', 'yb', 'lu', 'hf', 'ta', 'w',  're', 'os', 'ir', 'pt',
     'au', 'hg', 'tl', 'pb', 'bi', 'po', 'at', 'rn',
     'fr', 'ra', 'ac', 'th', 'pa', 'u',  'np', 'pu']

In [None]:
# Aggregate the per-target element counts (keyed by atomic number) over all targets.
# The {} initial value keeps reduce from raising TypeError when `out` is empty.
act_out = reduce(sum_dict, [ele[0] for ele in out], {})
dec_out = reduce(sum_dict, [ele[1] for ele in out], {})
all_out = reduce(sum_dict, [act_out, dec_out])
# all_out already holds every key seen in any actives/decoys dict, so no second reduction is needed.
out_key = [(num, __ATOM_LIST__[num]) for num in all_out]

print(act_out)
print(dec_out)
print(all_out)


def _by_symbol(counts):
    """Relabel an {atomic_number: count} dict with element symbols, ordered by atomic number."""
    return {__ATOM_LIST__[num]: cnt for num, cnt in sorted(counts.items(), key=lambda x: x[0])}


print(_by_symbol(act_out))
print(_by_symbol(dec_out))
print(_by_symbol(all_out))

print(dict(sorted(out_key, key=lambda x: x[0])))


In [None]:
# Print element symbols with their total counts, most frequent first.
ranked = sorted(all_out.items(), key=lambda item: -item[1])
symbol_counts = {__ATOM_LIST__[num]: cnt for num, cnt in ranked}
for key, val in symbol_counts.items():
    print(f'"{key}": {val}')

## 2.5 EGNN

### 2.5.1 Preprocess data

In [8]:
# 1. Collect each target's actives/decoys SDF paths, decompressing *.sdf.gz archives in place.
import gzip
import shutil


def _gunzip_in_place(gz_path):
    """Decompress ``gz_path`` next to itself (dropping '.gz') and delete the archive, like ``gzip -d``.

    If the decompressed file already exists it is left untouched and the archive is kept
    (gzip also refuses to overwrite without -f). Decompression goes to a temporary file first,
    so an interrupted run never leaves a truncated .sdf behind. Returns the decompressed path.
    """
    out_path = gz_path[:-len('.gz')]
    if not os.path.exists(out_path):
        part_path = out_path + '.part'
        with gzip.open(gz_path, 'rb') as src, open(part_path, 'wb') as dst:
            shutil.copyfileobj(src, dst)
        os.replace(part_path, out_path)
        os.remove(gz_path)
    return out_path


tmp_list = []
for root, dirs, files in os.walk(os.path.join(base_path, dude_dir), topdown=False):
    tmp_dict = {'target': root.split('/')[-1], 'act_sdf': None, 'dec_sdf': None}
    for name in files:
        if 'actives_final.sdf' in name:
            key = 'act_sdf'
        elif 'decoys_final.sdf' in name:
            key = 'dec_sdf'
        else:
            continue
        path = os.path.join(root, name)
        # Only a real '.gz' suffix is an archive; a bare substring test on 'gz' is too loose.
        tmp_dict[key] = _gunzip_in_place(path) if name.endswith('.gz') else path
    tmp_list.append(tmp_dict)

In [11]:
# Keep only directories with both SDF files; dropna drops the non-target directories os.walk also yields.
sdf_path_df = pd.DataFrame(tmp_list).dropna()
ic(sdf_path_df.shape)
sdf_path_df.sample(5)

ic| sdf_path_df.shape: (101, 3)


Unnamed: 0,target,act_sdf,dec_sdf
23,cxcr4,/home/jovyan/work-home/molecule-3d-similarity/...,/home/jovyan/work-home/molecule-3d-similarity/...
70,nos1,/home/jovyan/work-home/molecule-3d-similarity/...,/home/jovyan/work-home/molecule-3d-similarity/...
92,tgfr1,/home/jovyan/work-home/molecule-3d-similarity/...,/home/jovyan/work-home/molecule-3d-similarity/...
84,pur2,/home/jovyan/work-home/molecule-3d-similarity/...,/home/jovyan/work-home/molecule-3d-similarity/...
76,pgh2,/home/jovyan/work-home/molecule-3d-similarity/...,/home/jovyan/work-home/molecule-3d-similarity/...


In [13]:
def de_duplicate(m_list):
    """Drop molecules whose name or InChIKey has already been seen, keeping first occurrences.

    Parameters
    ----------
    m_list : iterable of RDKit molecules
        Each must carry a ``_Name`` property (as set by ``SDMolSupplier``).

    Returns
    -------
    list
        The retained molecules, in input order.
    """
    # Sets give O(1) membership checks; list lookups made this O(n^2) on decoy sets with
    # tens of thousands of molecules.
    seen_names = set()
    seen_inchikeys = set()
    unique_mols = []

    for m in tqdm(m_list):
        name = m.GetProp('_Name')
        inck = Chem.MolToInchiKey(m)
        if name in seen_names or inck in seen_inchikeys:
            continue
        seen_names.add(name)
        seen_inchikeys.add(inck)
        unique_mols.append(m)

    return unique_mols

In [12]:
import rdkit
# RDKit bond type -> integer edge code used in the exported `edges` column
# (1 single, 2 double, 3 triple, 4 aromatic).
map_dict = {
    getattr(rdkit.Chem.rdchem.BondType, bond_name): code
    for code, bond_name in enumerate(('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'), start=1)
}

In [15]:
# Scratch path check. NOTE(review): the captured output below does not end in 'egnn/data', so it is
# stale relative to this expression.
os.path.join(os.path.dirname(base_path), 'egnn/data', )

'/home/jovyan/work-home'

In [25]:
def extract_info_from_sdf(series: pd.Series):
    """Export one DUD-E target's de-duplicated actives and decoys to a CSV for EGNN.

    Each row holds: name, canonical SMILES, label ('act'/'dec'), ``charges`` (atomic numbers,
    JSON), ``positions`` (coordinates of conformer 0, JSON) and ``edges``, a list of
    ``(i, j, bond_code)`` tuples with ``i < j`` sorted by ``(i, j)``, where ``bond_code`` comes
    from ``map_dict`` (a bond type missing from ``map_dict`` raises KeyError).

    Parameters
    ----------
    series : pd.Series
        A row of ``sdf_path_df`` with ``target``, ``act_sdf`` and ``dec_sdf`` entries.

    The CSV is written to ``<parent of base_path>/egnn/dude/data/<target>.csv``.
    """
    print(f"🌟  {series['target']}  🌟")
    act = de_duplicate([m for m in Chem.SDMolSupplier(series['act_sdf'], removeHs=False) if m])
    dec = de_duplicate([m for m in Chem.SDMolSupplier(series['dec_sdf'], removeHs=False) if m])

    def get_info(mol, name, label):
        smiles      = AllChem.MolToSmiles(mol)
        charges     = json.dumps([a.GetAtomicNum() for a in mol.GetAtoms()])
        positions   = json.dumps(mol.GetConformer(0).GetPositions().tolist())
        # Read each bond's endpoints and type directly from the bond. Pairing adjacency-matrix
        # neighbours with atom.GetBonds() by position assumes RDKit returns an atom's bonds sorted
        # by neighbour index, which it does not (order follows bond creation, i.e. the SDF bond
        # block), so bond types could be attached to the wrong edge.
        edges = []
        for bond in mol.GetBonds():
            i, j = sorted((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
            edges.append((i, j, map_dict[bond.GetBondType()]))
        edges.sort()  # ascending i, then ascending j
        return {'name':name, 'smiles':smiles, 'label':label, 'charges':charges, 'positions':positions, 'edges':edges}

    act_out = Parallel(n_jobs=n_jobs, verbose=1)(delayed(get_info)(mol, mol.GetProp('_Name'), 'act') for mol in act)
    dec_out = Parallel(n_jobs=n_jobs, verbose=1)(delayed(get_info)(mol, mol.GetProp('_Name'), 'dec') for mol in dec)

    save_path = os.path.join(os.path.dirname(base_path), 'egnn/dude/data')
    os.makedirs(save_path, exist_ok=True)

    out_df = pd.DataFrame(act_out + dec_out)
    ic(out_df.sample(5))
    out_df.to_csv(os.path.join(save_path, series['target']+'.csv'), index=False)

In [28]:
# Spot-check one exported CSV. `edges` is stored as the str() of a list of (i, j, bond_code) tuples, not JSON.
pd.read_csv('/home/jovyan/work-home/egnn/dude/data/aa2ar.csv').sample(5)

Unnamed: 0,name,smiles,label,charges,positions,edges
24337,ZINC01656377,[H]c1c([H])c([H])c(N([H])[H])c(Sc2c([H])c([H])...,dec,"[6, 6, 6, 6, 6, 6, 7, 16, 6, 6, 6, 6, 6, 6, 7,...","[[4.0376, 3.9609, 4.3182], [2.916, 3.2525, 3.9...","[(0, 1, 4), (0, 5, 4), (0, 20, 1), (1, 2, 4), ..."
13197,ZINC65280765,[H]c1c([H])c(S(=O)(=O)N([H])[H])c([H])c([N+](=...,dec,"[6, 6, 6, 7, 6, 6, 8, 6, 6, 6, 6, 6, 6, 6, 6, ...","[[-0.0187, 1.5258, 0.0104], [0.0021, -0.0041, ...","[(0, 1, 1), (0, 22, 1), (0, 23, 1), (0, 24, 1)..."
9715,ZINC16401530,[H]c1c([H])c(S(=O)(=O)N([H])[H])c([H])c([H])c1...,dec,"[6, 6, 6, 6, 6, 6, 6, 7, 6, 16, 7, 7, 6, 6, 6,...","[[-0.3599, 2.7788, 1.8186], [-1.0099, 3.2267, ...","[(0, 1, 4), (0, 5, 4), (0, 26, 1), (1, 2, 4), ..."
14859,ZINC14244778,[H]c1c([H])c(C(=O)N([H])N([H])C(=O)C([H])([H])...,dec,"[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 16, 8, 8, 6,...","[[3.369, -0.1417, 3.4158], [4.4076, -0.976, 3....","[(0, 1, 4), (0, 9, 4), (0, 32, 1), (1, 2, 4), ..."
2653,ZINC40660176,[H]c1nc(N([H])c2c([H])c([H])c(Cl)c(C(F)(F)F)c2...,dec,"[6, 6, 6, 7, 6, 6, 7, 6, 6, 7, 6, 6, 6, 6, 6, ...","[[2.4124, 1.3886, 2.5638], [2.4131, 2.0734, 1....","[(0, 1, 4), (0, 5, 4), (0, 33, 1), (1, 2, 4), ..."


In [None]:
# Smoke-test the export on the first target.
extract_info_from_sdf(sdf_path_df.iloc[0])

In [None]:
# Export every target. NOTE(review): extract_info_from_sdf starts its own joblib Parallel pools,
# so this nests parallelism and may oversubscribe CPUs — consider n_jobs=1 at one of the levels.
Parallel(n_jobs=n_jobs)(delayed(extract_info_from_sdf)(series) for _, series in sdf_path_df.iterrows())