In [38]:

import os
import ray
import pickle
import pandas as pd
import numpy as np
from rdkit import Chem
from tqdm.auto import tqdm
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import warnings
warnings.filterwarnings(action='ignore')
from multiprocessing import Pool, get_context
import multiprocessing



In [28]:

def smi2_2Dcoords(smi):
    """Compute 2D layout coordinates for a SMILES string, hydrogens included.

    The SMILES is parsed with RDKit, explicit hydrogens are added, and
    ``AllChem.Compute2DCoords`` generates the layout whose positions are
    returned.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        ``np.float32`` array of conformer positions with one row per atom of
        the hydrogen-added molecule.

    Raises:
        ValueError: if RDKit cannot parse ``smi`` (``MolFromSmiles`` gives None).
        AssertionError: if the number of coordinate rows differs from the
            number of atoms.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # Fail with a clear message instead of an opaque error inside AddHs.
        raise ValueError("invalid SMILES: {}".format(smi))
    mol = AllChem.AddHs(mol)
    AllChem.Compute2DCoords(mol)
    coordinates = mol.GetConformer().GetPositions().astype(np.float32)
    # This has to be a real `assert`: a bare `cond, msg` expression only builds
    # a tuple and checks nothing.
    assert len(mol.GetAtoms()) == len(coordinates), "2D coordinates shape is not align with {}".format(smi)
    return coordinates




In [29]:
def _mmff_positions_or_2d(mol, smi):
    """Return positions of ``mol`` after MMFF optimization, or 2D coordinates of ``smi`` if that fails."""
    try:
        AllChem.MMFFOptimizeMolecule(mol)       # some conformers cannot be MMFF-optimized
        return mol.GetConformer().GetPositions()
    except Exception:
        print("Failed to generate 3D, replace with 2D")
        return smi2_2Dcoords(smi)


def smi2_3Dcoords(smi,cnt):
    """Generate ``cnt`` 3D coordinate sets for a SMILES string.

    For each seed in ``range(cnt)`` the hydrogen-added molecule is embedded
    with ``AllChem.EmbedMolecule`` and then MMFF-optimized. If the first
    embedding does not succeed, the molecule is re-embedded without hydrogens
    using ``maxAttempts=5000`` and hydrogens are added afterwards with
    coordinates. Any failure along the way falls back to the 2D coordinates
    from ``smi2_2Dcoords``.

    Args:
        smi: SMILES string of the molecule.
        cnt: number of coordinate sets to produce (seeds ``0 .. cnt-1``).

    Returns:
        List of ``cnt`` ``np.float32`` arrays, each with one row per atom of
        the hydrogen-added molecule.

    Raises:
        ValueError: if RDKit cannot parse ``smi``.
        AssertionError: if a coordinate set does not have one row per atom.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError("invalid SMILES: {}".format(smi))
    mol = AllChem.AddHs(mol)
    coordinate_list=[]
    for seed in range(cnt):
        # `except Exception` (not a bare `except`) so KeyboardInterrupt and
        # SystemExit still stop the run instead of becoming a 2D fallback.
        try:
            # A fixed randomSeed makes the embedding reproducible; RDKit uses a
            # random seed only when randomSeed is -1.
            res = AllChem.EmbedMolecule(mol, randomSeed=seed)
            if res == 0:
                coordinates = _mmff_positions_or_2d(mol, smi)
            else:
                # Embedding failed (RDKit reports -1): retry on the molecule
                # without hydrogens with many more attempts, then add the
                # hydrogens together with their coordinates. Using `else`
                # guarantees `coordinates` is always (re)assigned for this seed.
                mol_tmp = Chem.MolFromSmiles(smi)
                AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)
                mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)
                coordinates = _mmff_positions_or_2d(mol_tmp, smi)
        except Exception:
            print("Failed to generate 3D, replace with 2D")
            coordinates = smi2_2Dcoords(smi)

        assert len(mol.GetAtoms()) == len(coordinates), "3D coordinates shape is not align with {}".format(smi)
        coordinate_list.append(coordinates.astype(np.float32))
    return coordinate_list



In [30]:
def inner_smi2coords(content, cnt=10, max_atoms=400):
    """Build the atom/coordinate record for one input row.

    Args:
        content: sequence whose first element is the SMILES string; the
            remaining elements are kept unchanged as the target values.
        cnt: number of 3D coordinate sets to generate; the result holds
            ``cnt + 1`` coordinate sets (``cnt`` 3D plus one 2D).
        max_atoms: if the parsed molecule (counted before hydrogens are
            added) has more atoms than this, 3D generation is skipped and
            every coordinate set is the 2D layout.

    Returns:
        dict with keys ``'atoms'`` (element symbols after adding hydrogens),
        ``'coordinates'`` (list of ``cnt + 1`` arrays), ``'mol'`` (the
        hydrogen-added molecule), ``'smi'`` and ``'target'``.

    Raises:
        ValueError: if RDKit cannot parse the SMILES string.
    """
    smi = content[0]
    target = content[1:]

    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError("invalid SMILES: {}".format(smi))
    if len(mol.GetAtoms()) > max_atoms:
        coords_2d = smi2_2Dcoords(smi)
        # Independent copies: `[arr] * n` would store the same array object n
        # times, so an in-place edit of one entry would change all of them.
        coordinate_list = [coords_2d.copy() for _ in range(cnt + 1)]
        print("atom num >{},use 2D coords".format(max_atoms),smi)
    else:
        coordinate_list = smi2_3Dcoords(smi,cnt)
        coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))
    mol = AllChem.AddHs(mol)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]  # symbols after adding H
    return {'atoms': atoms, 'coordinates': coordinate_list, 'mol':mol,'smi': smi, 'target': target}




In [31]:
def smi2coords(content):
    """Run ``inner_smi2coords`` on one row, returning None instead of raising.

    Args:
        content: sequence whose first element is the SMILES string.

    Returns:
        The record produced by ``inner_smi2coords``, or None if it raised an
        ordinary exception (the failing SMILES and the reason are printed).
    """
    try:
        return inner_smi2coords(content)
    except Exception as exc:
        # `except Exception` rather than a bare `except`, so KeyboardInterrupt
        # and SystemExit are not swallowed; the reason is included so failures
        # can be diagnosed from the log.
        print("failed smiles: {} ({}: {})".format(content[0], type(exc).__name__, exc))
        return None



In [39]:
# Convert every row of the input CSV into a coordinate record using a pool of
# worker processes. Each row is passed as a tuple in column order: the first
# column is read as the SMILES string, the remaining columns as the target.
inpath = './mol_property_demo.csv'
name = 'train.pt'  # label used in the summary message below
df = pd.read_csv(inpath)
sz = len(df)
content_list = zip(*(df[c].values.tolist() for c in df))
outputs = list()
with get_context('fork').Pool(16) as pool:
    for inner_output in tqdm(pool.imap(smi2coords, content_list), total=len(df)):
        # smi2coords returns None for rows it could not process; skip those.
        if inner_output is None:
            continue
        outputs.append(inner_output)
    i = len(outputs)
    print('{} process {} lines'.format(name, i))


  0%|          | 0/20 [00:00<?, ?it/s]

train.pt process 20 lines


In [21]:
# Display the list of per-molecule records collected by the processing cell.
outputs

[{'atoms': ['O',
   'C',
   'N',
   'C',
   'C',
   'N',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'C',
   'N',
   'C',
   'C',
   'C',
   'C',
   'N',
   'C',
   'C',
   'N',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H',
   'H'],
  'coordinates': [array([[-1.079646  , -1.9251337 ,  1.0858862 ],
          [-1.563253  , -1.3573889 ,  0.11348289],
          [-2.261242  , -1.9585762 , -0.89298546],
          [-2.5639415 , -3.3709674 , -0.9289337 ],
          [-2.6967113 , -1.0099312 , -1.7387233 ],
          [-2.3620906 ,  0.20935145, -1.4723111 ],
          [-1.5232918 ,  0.15792969, -0.21651633],
          [-0.10979961,  0.68218917, -0.54424036],
          [ 1.0647279 ,  0.0832902 , -0.04880654],
          [ 2.3456736 ,  0.5912071 , -0.34444657],
          [ 2.4334588 ,  1.7363776 , -1.1496334 ],
          [ 1.2854766 ,  2.3559892 , -1.6403302 ],
          [ 0.02862

In [11]:
# Display `outputs`; this raises NameError when the cell that builds `outputs`
# has not been run in the current kernel.
outputs

NameError: name 'outputs' is not defined