In [2]:
import sys; sys.path.append("/gaozhangyang/experiments/MotifRetro")
import pandas as pd
from rdkit import Chem
import json
import copy
from src.feat.utils import atom_to_edit_tuple, get_bond_tuple, fix_incomplete_mappings, reac_to_canonical, fix_explicit_hs, renumber_atoms_for_mapping, mark_reactants
from src.utils.chem_utils import mol2smi, smi2mol, compare_smis
from rdkit.Chem import Draw
from src.feat.reaction_actions import StopAction, AddMotifAction
from tqdm import tqdm
import numpy as np
from src.feat.featurize_gzy_psvae import ReactionSampleGenerator as ReactionSampleGenerator_gzy
import json

ImportError: cannot import name 'compare_smis' from 'src.utils.chem_utils' (/gaozhangyang/experiments/MotifRetro/src/utils/chem_utils.py)

In [1]:
# utils
# Load the action vocabulary and normalise its one-hot lookup tables.
# JSON object keys are always strings, so any property-value key that parses
# as an integer is converted back to ``int``; other keys are kept as strings.
ACTION_VOCAB_PATH = "/gaozhangyang/experiments/MotifRetro/data/uspto_50k/feat/uspto_50k_frag/action_vocab.json"


def _coerce_int_keys(values):
    """Return a copy of ``values`` with integer-like string keys converted to ``int``.

    Args:
        values: mapping of property value (as loaded from JSON) -> one-hot entry.

    Returns:
        dict with the same entries; keys accepted by ``int()`` become ints,
        keys that raise ``ValueError`` are kept unchanged.
    """
    converted = {}
    for prop_val, val_oh in values.items():
        try:
            prop_val = int(prop_val)
        except ValueError:
            pass  # non-numeric key: keep the original string
        converted[prop_val] = val_oh
    return converted


# ``with`` closes the file handle (the bare ``open()`` call leaked it).
with open(ACTION_VOCAB_PATH, 'r') as vocab_file:
    action_vocab = json.load(vocab_file)

props = action_vocab['prop2oh']
# Only the 'atom' and 'bond' tables are kept, matching the original loop.
prop2oh = {
    type_key: {key: _coerce_int_keys(values) for key, values in props[type_key].items()}
    for type_key in ('atom', 'bond')
}

action_vocab['prop2oh'] = prop2oh

def preprocess_mols(target, source):
    """Parse a (target, source) SMILES pair and normalise both molecules.

    The two molecules are passed, in order, through ``fix_incomplete_mappings``,
    ``reac_to_canonical``, ``fix_explicit_hs`` and ``renumber_atoms_for_mapping``.
    ``fix_explicit_hs`` resolves at call time, so the redefinition later in this
    cell (if executed) shadows the version imported from ``src.feat.utils``.

    Args:
        target: SMILES for the target side (the reactants in the scan loop).
        source: SMILES for the source side (the product in the scan loop).

    Returns:
        ``(target_mol, source_mol)`` after preprocessing.

    Raises:
        ValueError: if either SMILES cannot be parsed by RDKit.
    """
    target_mol = Chem.MolFromSmiles(target)
    source_mol = Chem.MolFromSmiles(source)
    # RDKit signals a parse failure by returning None instead of raising; fail
    # here with the offending SMILES rather than deep inside a helper.
    if target_mol is None:
        raise ValueError(f"Could not parse target SMILES: {target!r}")
    if source_mol is None:
        raise ValueError(f"Could not parse source SMILES: {source!r}")

    target_mol, source_mol = fix_incomplete_mappings(target_mol, source_mol)
    target_mol, source_mol = reac_to_canonical(target_mol, source_mol)
    source_mol = fix_explicit_hs(source_mol)
    target_mol = fix_explicit_hs(target_mol)
    source_mol = renumber_atoms_for_mapping(source_mol)
    target_mol = renumber_atoms_for_mapping(target_mol)
    return target_mol, source_mol

# visualization
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D, MolsToGridImage

def _build_draw_options():
    """Return the ``MolDrawOptions`` that ``plot_states`` passes to ``MolsToGridImage``.

    Molecules are drawn as-is (``prepareMolsBeforeDrawing`` disabled), with
    bond line width 4 and a minimum font size of 12.
    """
    options = rdMolDraw2D.MolDrawOptions()
    # prepare_mol() already runs PrepareMolForDrawing on every molecule.
    options.prepareMolsBeforeDrawing = False
    options.bondLineWidth = 4
    options.minFontSize = 12
    return options


drawOptions = _build_draw_options()


def prepare_mol(mol, new_am):
    """Make ``mol`` drawing-ready and collect the atoms to highlight.

    Args:
        mol: RDKit molecule to draw.
        new_am: collection of atom-map numbers whose atoms should be highlighted.

    Returns:
        ``(mol_draw, highlight_idx)``: the molecule returned by
        ``rdMolDraw2D.PrepareMolForDrawing`` and the indices (in ``mol``) of the
        atoms whose map number is in ``new_am``, in atom order.
    """
    highlight_idx = [
        atom_idx
        for atom_idx, atom in enumerate(mol.GetAtoms())
        if atom.GetAtomMapNum() in new_am
    ]

    # Kekulization can fail (presumably for some partially edited aromatic
    # intermediates); fall back to preparing the molecule without kekulizing.
    try:
        mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol)
    except Chem.KekulizeException:
        mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=False)

    return mol_draw, highlight_idx

def plot_states(states, target_mol, source_mol, mols_per_row=5, sub_img_size=(500, 500)):
    """Draw a sequence of molecule states as a grid image.

    In every state, atoms are highlighted if their atom-map number occurs in
    ``target_mol`` but not in ``source_mol``.

    Args:
        states: iterable of RDKit molecules (e.g. the output of ``get_edit_sates``).
        target_mol: molecule whose atom-map numbers define the candidate set.
        source_mol: molecule whose atom-map numbers are excluded from highlighting.
        mols_per_row: grid columns (default 5, the original value).
        sub_img_size: per-molecule image size (default ``(500, 500)``, the original value).

    Returns:
        The image produced by ``MolsToGridImage``.
    """
    target_am = {atom.GetAtomMapNum() for atom in target_mol.GetAtoms()}
    source_am = {atom.GetAtomMapNum() for atom in source_mol.GetAtoms()}
    new_am = target_am - source_am  # a set gives O(1) membership tests in prepare_mol

    mol_list = []
    highlight_atom_lists = []
    for state in states:
        mol, highlight = prepare_mol(state, new_am)
        mol_list.append(mol)
        highlight_atom_lists.append(highlight)

    # prepare_mol returns *atom* indices, so they must go to highlightAtomLists;
    # passing them as highlightBondLists highlighted unrelated bonds that
    # happened to share those indices.
    return MolsToGridImage(
        mol_list,
        molsPerRow=mols_per_row,
        subImgSize=sub_img_size,
        drawOptions=drawOptions,
        highlightAtomLists=highlight_atom_lists,
    )


# get edit states
def get_edit_sates(sample_generator, max_steps=100, verbose=True):
    """Replay a sample generator's edit actions and collect the intermediate molecules.

    (The misspelled name is kept so existing calls keep working.)

    Starting from ``sample_generator.source_mol``, this repeatedly requests the
    next action via ``generate_gen_action()`` and applies it. A copy of each
    resulting molecule is recorded. The loop ends when a ``StopAction`` is
    returned or ``max_steps`` actions have been requested.

    Side effect: ``sample_generator.source_mol`` is overwritten after every
    applied action.

    Args:
        sample_generator: object exposing ``source_mol``, ``target_mol`` and
            ``generate_gen_action()``.
        max_steps: maximum number of actions to request (default 100, the
            original hard-coded limit). A ``RuntimeWarning`` is issued if no
            ``StopAction`` arrives within this many steps.
        verbose: if true (default, original behaviour), print each action.

    Returns:
        list: ``[copy of initial source_mol, copy after each applied action, ...,
        sample_generator.target_mol]``.
    """
    import warnings

    states = [copy.deepcopy(sample_generator.source_mol)]

    for _ in range(max_steps):
        reaction_action = sample_generator.generate_gen_action()
        if verbose:
            print(reaction_action)
        if isinstance(reaction_action, StopAction):
            break

        # Key step: advance the generator's molecule by applying the action.
        sample_generator.source_mol = reaction_action.apply(sample_generator.source_mol)
        latent_mol = copy.deepcopy(sample_generator.source_mol)
        # Non-strict refresh so intermediates with unusual valences don't raise.
        latent_mol.UpdatePropertyCache(strict=False)
        states.append(latent_mol)
    else:
        # The loop ran out without a StopAction: the path may be truncated.
        warnings.warn(
            f"get_edit_sates: no StopAction within {max_steps} steps; edit path may be truncated",
            RuntimeWarning,
            stacklevel=2,
        )

    states.append(sample_generator.target_mol)
    return states

def fix_explicit_hs(mol):
    """Return a sanitized copy of ``mol`` with hydrogen bookkeeping refreshed.

    The ``NoImplicit`` flag is cleared on every atom. Then the explicit Hs are
    added as atoms (``Chem.AddHs(explicitOnly=True)``), H atoms are removed
    again (``Chem.RemoveHs``), and the result is sanitized.

    Note: this redefinition shadows the ``fix_explicit_hs`` imported from
    ``src.feat.utils`` in the imports cell.

    Args:
        mol: RDKit molecule; it is not modified.

    Returns:
        A new, sanitized RDKit molecule.

    Raises:
        RDKit sanitization exceptions from ``Chem.SanitizeMol`` if the
        molecule cannot be sanitized.
    """
    # Work on a copy: previously the caller's atoms had NoImplicit cleared
    # even though a new molecule was returned.
    mol = Chem.Mol(mol)
    for atom in mol.GetAtoms():
        atom.SetNoImplicit(False)

    mol = Chem.AddHs(mol, explicitOnly=True)
    mol = Chem.RemoveHs(mol)

    Chem.SanitizeMol(mol)
    return mol

def mol2smi(mol, rm_am=False, fixH=False):
    """Write ``mol`` as a SMILES string without modifying the input.

    Note: this redefinition shadows the ``mol2smi`` imported from
    ``src.utils.chem_utils`` in the imports cell.

    Args:
        mol: RDKit molecule.
        rm_am: if true, set every atom-map number to 0 before writing.
        fixH: if true, pass the working copy through ``fix_explicit_hs`` first.

    Returns:
        The string produced by ``Chem.MolToSmiles`` (default arguments).
    """
    work = copy.deepcopy(mol)
    work = fix_explicit_hs(work) if fixH else work

    if rm_am:
        for atom in work.GetAtoms():
            atom.SetAtomMapNum(0)

    return Chem.MolToSmiles(work)

NameError: name 'json' is not defined

In [3]:
# Training split of USPTO-50k (first CSV column is the index).
RAW_TRAIN_CSV = "/gaozhangyang/experiments/MotifRetro/data/uspto_50k/raw_train.csv"
data = pd.read_csv(RAW_TRAIN_CSV, index_col=0)

# Alternative dataset (disabled):
# data = pd.read_csv("/gaozhangyang/experiments/MotifRetro/data/uspto_full/USPTO_FULL_train.csv", index_col=0)
#
# Optional length filter and statistics (disabled; requires an 'edit_len' column):
# print("cover rate {}".format((data['edit_len']<40).sum()/len(data)))
# data = data[data['edit_len']<40]
# length_array = np.array(data['edit_len'])
# print(length_array.mean(), length_array.max(), length_array.min())

In [6]:
# Replay the MotifRetro edit path for every reaction and check that the final
# state reproduces the target. Reactions with long paths are recorded.
vocab_path = "/gaozhangyang/experiments/MotifRetro/data/uspto_50k/adding_motif_trees.json"
# A path with more states than this (source + intermediates + target) counts as long.
LONG_PATH_STATES = 5

long_path_idx = []
for idx in range(len(data)):
    rxn_smiles = data.iloc[idx]["reactants>reagents>production"]
    # Assumes an empty reagents field ("reactants>>product"), so ``target``
    # holds the reactants and ``source`` the product.
    target, source = rxn_smiles.split(">>")

    target_mol, source_mol = preprocess_mols(target, source)

    motifretro_sample_generator = ReactionSampleGenerator_gzy(
        Chem.rdchem.RWMol(source_mol),
        target_mol,
        action_vocab=action_vocab,
        keep_actions_list=None,
        use_motif_action=True,
        vocab_path=vocab_path,
    )

    motifretro_states = get_edit_sates(motifretro_sample_generator)
    if len(motifretro_states) > LONG_PATH_STATES:
        long_path_idx.append(idx)

    # states[-2] is the last edited molecule (or the source copy if no edit was
    # applied); states[-1] is the target. Wrap the comparison so that failures
    # report which reaction broke.
    try:
        path_reaches_target = compare_smis(
            mol2smi(motifretro_states[-2]), mol2smi(motifretro_states[-1])
        )
    except Exception as exc:
        raise RuntimeError(
            f"SMILES comparison failed for reaction idx={idx}: {rxn_smiles}"
        ) from exc
    assert path_reaches_target, (
        f"Edit path does not reproduce the target for reaction idx={idx}: {rxn_smiles}"
    )

plot_states(motifretro_states, target_mol, source_mol)



Add Motif [*:1][C:28](=[O:29])[O:30][CH2:31][c:32]1[cH:33][cH:34][cH:35][cH:36][cH:37]1
Stop
Delete bond (1, 20)
Add Motif [*:1][OH:31]
Stop
Delete bond (1, 44)
Add Motif [*:1]=[O:49]
Stop
Delete bond (1, 6)
Add Motif [*:1]=[O:65]
Stop
Delete bond (1, 9)
Add Motif [*:1][Cl:35]
Stop
Delete bond (1, 2)
Add Motif [*:1][Br:29]
Add Motif [*:2][B:26]([OH:27])[OH:28]
Stop
Delete bond (1, 2)
Add Motif [*:1][Br:43]
Add Motif [*:2][B:34]1[O:35][C:36]([CH3:37])([CH3:38])[C:39]([CH3:40])([CH3:41])[O:42]1
Stop
Edit bond (7, 8): Bond Type=Double, Bond Stereo=None
Stop
Delete bond (1, 11)
Add Motif [*:1][OH:20]
Stop
Delete bond (1, 20)
Add Motif [*:1][Br:30]
Stop
Add Motif [*:1][C:21]([CH3:22])([CH3:23])[CH3:24]
Stop
Edit Atom 1: Formal Charge=0, Chiral Type=None, 0, Is Aromatic=Yes
Add Motif [*:1][CH3:23]
Add Motif [*:23][c:24]1[cH:25][cH:26][cH:27][cH:30][cH:31]1
Add Motif [*:27][O:28][CH3:29]
Stop
Delete bond (1, 36)
Add Motif [*:1][OH:42]
Stop
Delete bond (1, 7)
Add Motif [*:1]=[O:21]
Stop
Delete

ArgumentError: Python argument types in
    rdkit.Chem.rdmolfiles.MolToSmiles(NoneType, int)
did not match C++ signature:
    MolToSmiles(RDKit::ROMol mol, bool isomericSmiles=True, bool kekuleSmiles=False, int rootedAtAtom=-1, bool canonical=True, bool allBondsExplicit=False, bool allHsExplicit=False, bool doRandom=False)
    MolToSmiles(RDKit::ROMol mol, RDKit::SmilesWriteParams params)

In [7]:
# Index of the last reaction reached by the scan loop above (the one that raised).
idx

208

In [16]:
# Inspect the mapped SMILES of the last state, i.e. the target molecule that
# get_edit_sates appends at the end of the path.
mol2smi(motifretro_states[-1])

'[NH3+:2][C@H:3]([CH3:4])[CH:5]1[CH2:6][CH2:7][O:8][CH2:9][CH2:10]1.[c:1]1([F:17])[n:11][cH:12][cH:13][cH:14][c:15]1[I:16]'

True