In [3]:
import joblib
from dgllife.model import model_zoo
from dgllife.utils import AttentiveFPAtomFeaturizer
from dgllife.utils import AttentiveFPBondFeaturizer
from torch.utils.data import DataLoader
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import mol_to_bigraph
import dgl
from multiprocessing import Pool
from rdkit import Chem
import torch
import numpy as np
from functools import partial
from rdkit.Chem.Crippen import MolLogP

In [4]:
pool=Pool(64)
def norm_mol(mol):
    """Normalise an RDKit molecule by a non-isomeric SMILES round trip.

    The molecule is written to SMILES with ``isomericSmiles=False``
    (stereochemistry dropped) and parsed back with ``Chem.MolFromSmiles``.

    Returns
    -------
    The re-parsed molecule, or ``None`` when ``mol`` is ``None`` or RDKit
    raises. ``Chem.MolFromSmiles`` itself may also return ``None``.
    """
    if mol is None:
        return None
    try:
        smiles = Chem.MolToSmiles(mol, isomericSmiles=False)
        return Chem.MolFromSmiles(smiles)
    except Exception:
        # Best effort: any RDKit failure marks the molecule as invalid; callers
        # treat a None result as "skip this molecule". Interrupts still propagate.
        return None
def collate_molgraphs(data):
    """Collate ``(smiles, graph)`` samples for a DataLoader.

    Returns ``(list_of_first_items, batched_graph)``: the graphs are merged
    with ``dgl.batch`` and the batched graph is given zero initializers for
    node and edge data.
    """
    smiles_batch, graphs = (list(column) for column in zip(*data))
    batched_graph = dgl.batch(graphs)
    zero_init = dgl.init.zero_initializer
    batched_graph.set_n_initializer(zero_init)
    batched_graph.set_e_initializer(zero_init)
    return smiles_batch, batched_graph
class GraphDataset(object):
    """Map-style dataset of ``(input, graph)`` pairs for a ``DataLoader``.

    Every element of ``smiles_list`` is converted once, up front, with
    ``smiles_to_graph``. Inputs with more than 100 items are converted through
    the module-level ``pool``; smaller inputs are converted serially in this
    process. Despite the parameter name, the items are whatever
    ``smiles_to_graph`` accepts (``MTATFP_model`` passes RDKit molecules).
    """

    def __init__(self, smiles_list, smiles_to_graph):
        self.smiles = smiles_list
        if len(smiles_list) <= 100:
            # Small input: presumably not worth the worker dispatch overhead.
            self.graphs = [smiles_to_graph(entry) for entry in self.smiles]
        else:
            self.graphs = pool.map(smiles_to_graph, self.smiles)

    def __getitem__(self, item):
        """Return the ``(input, graph)`` pair stored at position ``item``."""
        source = self.smiles[item]
        graph = self.graphs[item]
        return source, graph

    def __len__(self):
        """Number of inputs held by the dataset.

        Returns
        -------
        int
            Size for the dataset
        """
        return len(self.smiles)

class MTATFP_model:
    """Multi-task AttentiveFP scorer around a pretrained dgllife predictor.

    Calling an instance on a list of RDKit molecules returns two 1-D numpy
    arrays (task 0 and task 1 scores) aligned with the input list. Raw model
    outputs are negated and clipped to ``[SCORE_MIN, SCORE_MAX]``; molecules
    that ``norm_mol`` cannot normalise keep the default score of 1.
    """

    # Clipping bounds applied to the negated model output.
    SCORE_MIN = 1
    SCORE_MAX = 12  # NOTE(review): origin of this upper bound is undocumented ("12?") — confirm.

    def __init__(self, clf_path, device='cuda', n_tasks=2):
        """Build the AttentiveFP predictor and load its weights.

        Parameters
        ----------
        clf_path : str
            Path to a ``state_dict`` readable by ``torch.load``. ``torch.load``
            unpickles the file, so only load trusted checkpoints.
        device : str
            Torch device for the model and for every input batch.
        n_tasks : int
            Number of model outputs. ``__call__`` returns the first two, so
            this must be at least 2.
        """
        self.atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
        self.bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he')
        n_feats = self.atom_featurizer.feat_size('hv')
        e_feats = self.bond_featurizer.feat_size('he')
        self.n_tasks = n_tasks
        model = model_zoo.AttentiveFPPredictor(node_feat_size=n_feats,
                                               edge_feat_size=e_feats,
                                               num_layers=2,
                                               num_timesteps=1,
                                               graph_feat_size=300,
                                               n_tasks=n_tasks)
        model.load_state_dict(torch.load(clf_path, map_location=torch.device(device)))
        self.device = device
        self.gcn_net = model.to(device)

    def __call__(self, mol_list):
        """Score ``mol_list``.

        Parameters
        ----------
        mol_list : list
            RDKit molecules; ``None`` entries are allowed and get the default score.

        Returns
        -------
        tuple of numpy.ndarray
            ``(scores_task0, scores_task1)``, each of length ``len(mol_list)``.
        """
        normalized = pool.map(norm_mol, mol_list)
        # Positions (in mol_list) of molecules that survived normalisation.
        ind_list = [i for i, mol in enumerate(normalized) if mol is not None]
        new_mol_list = [normalized[i] for i in ind_list]

        # Molecules that failed normalisation keep this default score.
        final_scores = np.ones([len(mol_list), self.n_tasks])
        if ind_list:
            mol_to_graph = partial(mol_to_bigraph,
                                   node_featurizer=self.atom_featurizer,
                                   edge_featurizer=self.bond_featurizer)
            test_dataset = GraphDataset(new_mol_list, mol_to_graph)
            test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False,
                                     collate_fn=collate_molgraphs)
            # Inference only: disable training-mode layers and skip building
            # an autograd graph that would be discarded anyway.
            self.gcn_net.eval()
            results = []
            with torch.no_grad():
                for _, test_bg in test_loader:
                    test_bg = test_bg.to(self.device)
                    test_n_feats = test_bg.ndata.pop('hv').to(self.device)
                    test_e_feats = test_bg.edata.pop('he').to(self.device)
                    prediction = self.gcn_net(test_bg, test_n_feats, test_e_feats)
                    negated = -prediction.detach().cpu().numpy()
                    results.append(np.clip(negated, self.SCORE_MIN, self.SCORE_MAX))
            final_scores[ind_list] = np.concatenate(results, axis=0)
        return final_scores[:, 0], final_scores[:, 1]


In [8]:
import pandas as pd
from scipy.stats import pearsonr

df = pd.read_csv('init_test_test.csv')
# SMILES that RDKit cannot parse become None; MTATFP_model gives them its default score.
mol_list = [Chem.MolFromSmiles(smiles) for smiles in df['SMILES']]
xp_5ntp_list = df['5NTP_XP'].tolist()
# NOTE(review): the variable is named after 6N7D but reads the 6QU7_XP column — confirm the intended target.
xp_6n7d_list = df['6QU7_XP'].tolist()

mtatfp_model = MTATFP_model('data/models/dgl/dgl.pt')
scores1, scores2 = mtatfp_model(mol_list)
print(scores1, xp_5ntp_list)
print(scores2, xp_6n7d_list)
# Agreement between model scores and the reference XP columns (pearsonr was imported for this).
print('Pearson r (task 0 vs 5NTP_XP):', pearsonr(scores1, xp_5ntp_list)[0])
print('Pearson r (task 1 vs 6QU7_XP):', pearsonr(scores2, xp_6n7d_list)[0])


Process ForkPoolWorker-1:
Traceback (most recent call last):
Process ForkPoolWorker-2:
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Process ForkPoolWorker-3:
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
Process ForkPoolWorker-4:
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8/multiprocessing/queues.py", line 358, in get
    return _ForkingPickler.loads(res)
Traceback (most recent call last):
Traceback (most recent call last):
AttributeError: Can't get attribute 'norm_mol' on <module '__main__'>
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/chensheng/anaconda3/envs/mtdd/lib/python3.8

KeyboardInterrupt: 