In [1]:
import joblib
from dgllife.model import model_zoo
from dgllife.utils import AttentiveFPAtomFeaturizer
from dgllife.utils import AttentiveFPBondFeaturizer
from torch.utils.data import DataLoader
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import mol_to_bigraph
import dgl
from multiprocessing import Pool
from rdkit import Chem
import torch
import numpy as np
from functools import partial
import pandas as pd
from scipy.stats import pearsonr
import argparse
from glob import glob
import os
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
import sys
sys.path.append('utils')
import dual_MCTS
import importlib
import iter_gen
importlib.reload(iter_gen)
from iter_gen import task
from matplotlib import pyplot as plt

# Shared worker pool; GraphDataset uses it to build graphs in parallel for
# large SMILES lists. Created once at import time and never closed here.
pool = Pool(processes=64)

def collate_molgraphs(data):
    """Collate a sequence of (smiles, graph) pairs into one batched graph.

    Parameters
    ----------
    data : iterable of (smiles, graph) pairs
        Typically items produced by ``GraphDataset.__getitem__``.

    Returns
    -------
    (list, DGLGraph)
        The SMILES strings in input order and the graphs merged with
        ``dgl.batch``.

    Raises
    ------
    ValueError
        If ``data`` is empty or its items are not 2-tuples (tuple unpacking fails).
    """
    smiles_col, graph_col = (list(column) for column in zip(*data))

    batched = dgl.batch(graph_col)
    # Register zero initializers for node and edge features on the batch.
    batched.set_n_initializer(dgl.init.zero_initializer)
    batched.set_e_initializer(dgl.init.zero_initializer)
    return smiles_col, batched

class GraphDataset:
    """In-memory dataset pairing SMILES strings with their graphs.

    All graphs are built eagerly in the constructor. Lists with more than
    100 entries are converted through the module-level multiprocessing
    ``pool``. Shorter lists are converted serially in the current process.
    """

    def __init__(self, smiles_list, smiles_to_graph):
        """
        Parameters
        ----------
        smiles_list : sequence of str
            SMILES strings; stored as-is (not copied).
        smiles_to_graph : callable
            Maps one SMILES string to a graph. On the parallel path it is sent
            to worker processes, so it must be picklable (e.g. a module-level
            function or a ``functools.partial`` of one).
        """
        self.smiles = smiles_list
        if len(smiles_list) <= 100:
            self.graphs = [smiles_to_graph(smi) for smi in self.smiles]
        else:
            self.graphs = pool.map(smiles_to_graph, self.smiles)

    def __getitem__(self, item):
        """Return the ``(smiles, graph)`` pair at position ``item``."""
        smi = self.smiles[item]
        graph = self.graphs[item]
        return smi, graph

    def __len__(self):
        """Number of SMILES strings (and graphs) in the dataset."""
        return len(self.smiles)

# Per-task evaluation results, keyed by task_name; the evaluation cell stores a
# tuple of metric lists plus true/predicted scores here for the plotting cell.
result = dict()

  _mcf.append(_pains, sort=True)['smarts'].values]


Use the GPU


In [7]:
# Plot-label names (matplotlib mathtext) for each task's two targets.
DISPLAY_NAMES = {
    'gsk3b_jnk3': ('GSK3' + r"$\beta$", 'JNK3'),
    'rorgt_dhodh': ('ROR' + r"$\gamma$" + 't', 'DHODH'),
}
task_name = 'rorgt_dhodh'  # alternative: 'gsk3b_jnk3'
target1, target2 = DISPLAY_NAMES[task_name]
generated_path = 'data/outputs/generated/'
model_path = 'data/models/dgl'
N_MODEL_ITERS = 5  # evaluates checkpoints gen_iter_0.pt .. gen_iter_{N_MODEL_ITERS-1}.pt

pdb_id1 = task[task_name]['pdb_id1']
pdb_id2 = task[task_name]['pdb_id2']
prec = task[task_name]['prec']
color1, color2 = task[task_name]['color']

# Directories are named after task_name itself (e.g. 'rorgt_dhodh'), so
# target1/target2 are not reused for paths and keep their display names for
# the plot labels in the next cell.
generated_dir = os.path.join(generated_path, task_name)
model_dir = os.path.join(model_path, task_name)
final_train_csv = os.path.join(generated_dir, 'util_iter_4.csv')
test_csv = os.path.join(generated_dir, 'gen_iter_5_scores.csv')
train_df = pd.read_csv(final_train_csv)
train_set = set(train_df['SMILES'])
test_df = pd.read_csv(test_csv)

# Build the held-out evaluation set: exclude molecules seen in training and rows
# missing either docking score.
mol_list = []
score1_list = []
score2_list = []
n_unparsable = 0
for smiles, score1, score2 in zip(test_df['SMILES'],
                                  test_df[f'{pdb_id1}_{prec}'],
                                  test_df[f'{pdb_id2}_{prec}']):
    if smiles in train_set or np.isnan(score1) or np.isnan(score2):
        continue
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        n_unparsable += 1
        continue
    mol_list.append(mol)
    # Stored negated; the plotting cell negates them back for display.
    score1_list.append(-score1)
    score2_list.append(-score2)
print(len(mol_list))
if n_unparsable:
    print(f'skipped {n_unparsable} unparsable SMILES')

target1_mse_list = []
target2_mse_list = []
target1_r2_list = []
target2_r2_list = []
target1_pcc_list = []
target2_pcc_list = []
for model_iter in range(N_MODEL_ITERS):
    # Separate name so the base `model_path` above survives re-running this cell.
    ckpt_path = os.path.join(model_dir, f'gen_iter_{model_iter}.pt')
    mtatfp_model = dual_MCTS.MTATFP_model(ckpt_path)
    scores1, scores2 = mtatfp_model(mol_list)
    per_target = (
        ('target1', score1_list, scores1, target1_mse_list, target1_r2_list, target1_pcc_list),
        ('target2', score2_list, scores2, target2_mse_list, target2_r2_list, target2_pcc_list),
    )
    for tag, truth, pred, mse_list, r2_list, pcc_list in per_target:
        mse = mean_squared_error(truth, pred)
        r2 = r2_score(truth, pred)
        pcc = pearsonr(truth, pred)[0]
        mse_list.append(mse)
        r2_list.append(r2)
        pcc_list.append(pcc)
        print(ckpt_path, f"{tag}|mean_squared_error:", mse)
        print(ckpt_path, f"{tag}|r2 score:", r2)
        print(ckpt_path, f"{tag}|pcc:", pcc)

# scores1/scores2 hold the predictions of the last (final-iteration) model.
result[task_name] = (target1_mse_list, target2_mse_list, target1_r2_list, target2_r2_list,
                     target1_pcc_list, target2_pcc_list, score1_list, score2_list,
                     scores1, scores2)

5406
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target1|mean_squared_error: 1.8486469952414852
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target1|r2 score: 0.21791156800174372
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target1|pcc: 0.5609333672161727
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target2|mean_squared_error: 2.8266408788635444
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target2|r2 score: 0.10466086185552193
data/models/dgl/rorgt_dhodh/gen_iter_0.pt target2|pcc: 0.47292005110789137
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target1|mean_squared_error: 1.3589916408862714
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target1|r2 score: 0.425065118297156
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target1|pcc: 0.669622602591349
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target2|mean_squared_error: 2.347143115505303
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target2|r2 score: 0.25654188692580926
data/models/dgl/rorgt_dhodh/gen_iter_1.pt target2|pcc: 0.5854587237599781
data/models/dgl/rorgt_dhod

In [8]:
# task_name='gsk3b_jnk3'
color1, color2 = task[task_name]['color']
(target1_mse_list, target2_mse_list, target1_r2_list, target2_r2_list,
 target1_pcc_list, target2_pcc_list, score1_list, score2_list,
 scores1, scores2) = result[task_name]
img_dir = os.path.join('data/outputs/images/nn_test', task_name)
os.makedirs(img_dir, exist_ok=True)  # savefig does not create missing directories

FONT = 'Times New Roman'
R2_LABEL = r"R$\mathregular{^2}$"  # raw string: avoids the invalid '\m' escape


def _axis_font(size):
    """Return the fontdict used for axis labels in these figures."""
    return {'family': FONT, 'size': size}


def _plot_parity(true_scores, pred_scores, mse, r2, pcc, target, out_path):
    """Save a predicted-vs-docked scatter plot with a y=x reference line.

    The stored scores are negated, so they are negated back here for display.
    MSE, R^2 and PCC are written in the upper-left corner.
    """
    true_arr = -np.asarray(true_scores)
    pred_arr = -np.asarray(pred_scores)
    both = np.concatenate([true_arr, pred_arr])
    lo, hi = np.min(both), np.max(both)
    gap = (hi - lo) / 15  # vertical spacing between metric annotations

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.grid(True, color="gray", linewidth=0.5, axis="both", zorder=0)
    ax.scatter(true_arr, pred_arr, s=1, zorder=10)
    ax.plot([lo, hi], [lo, hi], c='black', zorder=100)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    annotations = ['MSE={:.3f}'.format(mse),
                   R2_LABEL + '={:.3f}'.format(r2),
                   'PCC={:.3f}'.format(pcc)]
    for row, text in enumerate(annotations, start=1):
        ax.text(lo + 0.1, hi - row * gap, text, fontsize=14)
    ax.set_xlabel("Docking Score on " + target, fontdict=_axis_font(18))
    ax.set_ylabel("Predicted Docking Score on " + target, fontdict=_axis_font(18))
    plt.setp(list(ax.get_xticklabels()) + list(ax.get_yticklabels()),
             fontfamily=FONT, fontsize=14)
    fig.savefig(out_path, dpi=250, transparent=True)
    plt.close(fig)  # release the figure instead of leaving an empty one open


def _plot_mse_r2(mse_list, r2_list, target, out_path, mse_color, r2_color, width=0.3):
    """Save side-by-side bars of MSE (left y-axis) and R^2 (right y-axis) per model iteration."""
    mse_pos = np.arange(len(mse_list))
    r2_pos = np.arange(len(r2_list))
    tick_labels = ['Init.'] + ['Iter.{}'.format(i) for i in range(1, len(mse_list))]

    fig, ax_mse = plt.subplots(figsize=(6, 6))
    ax_mse.grid(True, color="gray", linewidth=0.5, axis="both", zorder=0)
    ax_mse.set_xticks(mse_pos)
    ax_mse.set_xticklabels(tick_labels)
    plt.setp(ax_mse.get_xticklabels(), fontfamily=FONT, fontsize=14)
    ax_mse.set_xlabel("Model for " + target, fontdict=_axis_font(16))
    ax_mse.set_ylabel("MSE of predicted docking scores", fontdict=_axis_font(16))
    ax_mse.bar(mse_pos - 0.2, mse_list, width=width, edgecolor='black',
               color=mse_color, zorder=100, label='MSE')

    # R^2 is drawn on a twin axis sharing x, with its own y-axis on the right.
    ax_r2 = ax_mse.twinx()
    ax_r2.set_ylabel(R2_LABEL + " of predicted docking scores", fontdict=_axis_font(16))
    ax_r2.bar(r2_pos + 0.2, r2_list, width=width, edgecolor='black',
              color=r2_color, zorder=100, label=R2_LABEL)

    # One legend combining the bars from both axes.
    handles, labels = [], []
    for axis in (ax_mse, ax_r2):
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        handles.extend(axis_handles)
        labels.extend(axis_labels)
    ax_r2.legend(handles, labels, loc=(0.2, 0.85), fontsize=14)
    fig.tight_layout()
    fig.savefig(out_path, dpi=250, transparent=True)
    plt.close(fig)


_plot_parity(score1_list, scores1, target1_mse_list[-1], target1_r2_list[-1],
             target1_pcc_list[-1], target1, os.path.join(img_dir, 'plot1.png'))
_plot_parity(score2_list, scores2, target2_mse_list[-1], target2_r2_list[-1],
             target2_pcc_list[-1], target2, os.path.join(img_dir, 'plot2.png'))
_plot_mse_r2(target1_mse_list, target1_r2_list, target1,
             os.path.join(img_dir, 'target1_r2_mse.png'), color1, color2)
_plot_mse_r2(target2_mse_list, target2_r2_list, target2,
             os.path.join(img_dir, 'target2_r2_mse.png'), color1, color2)

<Figure size 600x600 with 0 Axes>

<Figure size 600x600 with 0 Axes>

<Figure size 600x600 with 0 Axes>

<Figure size 600x600 with 0 Axes>