In [None]:
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

# Load the preprocessor - used to convert SMILES strings into the graph-structured model inputs
import nfp
from preprocess_inputs_cfc import preprocessor
preprocessor.from_json('model_3_tfrecords_multi_halo_cfc/preprocessor.json')

from rdkit import Chem
from rdkit.Chem import Draw

#load model
class Slice(layers.Layer):
    """Keras layer that keeps the first half of axis 1 of a rank-3 input.

    The slice starts at offset 0 on every axis, keeps every entry of axes 0
    and 2, and keeps ``shape[1] // 2`` entries along axis 1.
    """

    def call(self, inputs):
        """Return ``inputs[:, :n // 2, :]`` where ``n`` is the dynamic size of axis 1."""
        input_shape = tf.shape(inputs)
        # Floor division keeps the count an integer tensor. `/` on an integer
        # tensor is true division and produces a floating-point tensor, which
        # is only usable as a slice size through an implicit conversion.
        num_bonds = input_shape[1] // 2
        output = tf.slice(inputs, [0, 0, 0], [-1, num_bonds, -1])
        # tf.slice with a dynamic size loses static shape information;
        # restore what is known about axes 0 and 2.
        output.set_shape(self.compute_output_shape(inputs.shape))
        return output

    def compute_output_shape(self, input_shape):
        """Static output shape: axes 0 and 2 are unchanged, axis 1 is unknown."""
        return [input_shape[0], None, input_shape[2]]
    
# Custom objects passed to Keras when deserialising the model: everything nfp
# exposes in `nfp.custom_objects`, plus the local Slice layer defined above.
custom_objects = {**nfp.custom_objects,'Slice':Slice}

# Load the Keras model saved at this HDF5 path; `model` is the module-level
# object that predict_bdes() calls.
model = tf.keras.models.load_model('model_3_multi_halo_cfc/best_model.hdf5', custom_objects=custom_objects)

def get_bdes(smiles_):
    """Summarise predicted C-H bond energies for each selected carbon atom.

    The SMILES is canonicalised, predictions are obtained from
    ``predict_bdes`` and the carbon -> C-H bond index mapping from
    ``get_bond_idx_for_C``. For every carbon in that mapping, the rows of the
    prediction table at the listed bond indices are reduced to min / max /
    mean for both the ``pred_bde`` and ``pred_bdfe`` columns.

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule (canonicalised internally).

    Returns
    -------
    dict
        ``{atom_idx: {'bde_min', 'bde_max', 'bde_avg',
        'bdfe_min', 'bdfe_max', 'bfde_avg'}}``.
    """
    canonical = Chem.CanonSmiles(smiles_)

    predictions = predict_bdes(canonical)
    ch_bonds_by_carbon = get_bond_idx_for_C(canonical)

    summary = {}
    for carbon_idx, bond_ids in ch_bonds_by_carbon.items():
        # Positional lookup: bond indices are used as row positions in the table.
        bde_values = predictions['pred_bde'].iloc[bond_ids]
        bdfe_values = predictions['pred_bdfe'].iloc[bond_ids]

        summary[carbon_idx] = dict(
            bde_min=min(bde_values),
            bde_max=max(bde_values),
            bde_avg=float(np.mean(bde_values)),
            bdfe_min=min(bdfe_values),
            bdfe_max=max(bdfe_values),
            # NOTE(review): key is spelled 'bfde_avg' (not 'bdfe_avg'); kept
            # as-is because the result is persisted and may be read elsewhere.
            bfde_avg=float(np.mean(bdfe_values)),
        )

    return summary

def predict_bdes(smiles_):
    """Run the loaded model on a single molecule and tabulate its output.

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule (canonicalised internally).

    Returns
    -------
    pandas.DataFrame
        The model output reshaped to two columns, ``pred_bde`` and
        ``pred_bdfe``, one row per predicted entry.
    """
    canonical = Chem.CanonSmiles(smiles_)

    # Element spec: the preprocessor's own signature plus the two scalar counts
    # that get_data() adds to every example.
    signature = dict(preprocessor.output_signature)
    signature['n_atom'] = tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
    signature['n_bond'] = tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    # Padding values, extended the same way for the two extra scalar entries.
    padding = dict(preprocessor.padding_values)
    padding['n_atom'] = tf.constant(0, dtype="int32")
    padding['n_bond'] = tf.constant(0, dtype="int32")

    def _single_example():
        # The dataset contains exactly one element: this molecule.
        yield get_data(canonical)

    dataset = tf.data.Dataset.from_generator(_single_example, output_signature=signature)
    test_smiles = dataset.padded_batch(batch_size=1000, padding_values=padding)

    raw_predictions = model.predict(test_smiles, verbose=True)

    # Flatten all leading axes so that each row holds one (bde, bdfe) pair.
    return pd.DataFrame(raw_predictions.reshape(-1, 2), columns=['pred_bde', 'pred_bdfe'])

def get_data(smiles):
    """Build the model input dictionary for one SMILES string.

    Calls the module-level ``preprocessor`` and adds two entries to the
    dictionary it returns: ``n_atom`` (length of the ``atom`` entry) and
    ``n_bond`` (length of the ``bond`` entry). The same dictionary object is
    returned.
    """
    features = preprocessor(smiles)
    features.update(
        n_atom=len(features['atom']),
        n_bond=len(features['bond']),
    )
    return features

def func(x):
    """Attach a ``bond_index`` entry counting ``0 .. n - 1`` to ``x`` and return it.

    ``n`` is ``predicted_bdes.shape[1]``, read from a *global* named
    ``predicted_bdes``. ``x`` is modified in place.

    NOTE(review): no visible cell defines ``predicted_bdes`` at notebook top
    level (it is only a local inside ``predict_bdes``), so this helper raises
    NameError unless that global has been created beforehand.
    """
    n_entries = predicted_bdes.shape[1]
    x['bond_index'] = range(n_entries)
    return x

def get_bond_idx_for_C(smiles_):
    """Map one representative carbon per symmetry group to its C-H bond indices.

    The molecule is canonicalised, parsed and given explicit hydrogens. Atoms
    are visited in index order; a carbon is selected when it has at least one
    hydrogen neighbour and its symmetry group (from ``group_symmetric_atoms``)
    has not been selected yet.

    Parameters
    ----------
    smiles_ : str
        Reactant SMILES (canonicalised internally).

    Returns
    -------
    dict
        Keys are atom indices of the selected carbons; values are lists of the
        bond indices (in the hydrogen-expanded molecule) of that carbon's
        C-H bonds.
    """
    canonical = Chem.CanonSmiles(smiles_)
    mol_with_h = Chem.AddHs(Chem.MolFromSmiles(canonical))

    # Only the group mapping is needed here; the annotated molecule is unused.
    _, carbon_groups = group_symmetric_atoms(canonical)

    seen_groups = set()
    ch_bonds_by_carbon = {}

    for atom in mol_with_h.GetAtoms():
        if atom.GetSymbol() != 'C':
            continue
        if not any(nb.GetSymbol() == 'H' for nb in atom.GetNeighbors()):
            continue

        group_label = carbon_groups[atom.GetIdx()]
        if group_label in seen_groups:
            # A symmetry-equivalent carbon was already recorded.
            continue
        seen_groups.add(group_label)

        ch_bond_ids = []
        for bond in atom.GetBonds():
            end_symbols = {bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()}
            if {'H', 'C'}.issubset(end_symbols):
                ch_bond_ids.append(bond.GetIdx())
        ch_bonds_by_carbon[atom.GetIdx()] = ch_bond_ids

    return ch_bonds_by_carbon

def group_symmetric_atoms(smiles_):
    """Label atoms with their canonical rank and collect the ranks of carbons.

    Stereochemistry is removed before ranking, and ranks are computed with
    ``breakTies=False`` so that atoms RDKit cannot distinguish share a rank.

    Parameters
    ----------
    smiles_ : str
        Reactant SMILES (canonicalised internally).

    Returns
    -------
    mol : Chem.Mol
        The parsed molecule; every atom carries its rank as the ``atomNote``
        property.
    idx_to_group : dict
        Atom index -> rank, for carbon atoms only.
    """
    canonical = Chem.CanonSmiles(smiles_)
    mol = Chem.MolFromSmiles(canonical)
    Chem.RemoveStereochemistry(mol)
    ranks = list(Chem.CanonicalRankAtoms(mol, breakTies=False))

    idx_to_group = {}
    for atom in mol.GetAtoms():
        atom_idx = atom.GetIdx()
        rank = ranks[atom_idx]
        atom.SetProp('atomNote', f"{rank}")
        if atom.GetSymbol() == 'C':
            idx_to_group[atom_idx] = rank

    return mol, idx_to_group


def is_mol_symmetric(smiles_):
    """Tell whether any two atoms of the molecule share a canonical rank.

    Stereochemistry is removed before ranking, and ties are left unbroken
    (``breakTies=False``). The check covers every atom of the parsed
    molecule, not only carbons.

    Parameters
    ----------
    smiles_ : str
        Reactant SMILES (canonicalised internally).

    Returns
    -------
    bool
        True if at least one rank occurs more than once, False otherwise.
    """
    canonical = Chem.CanonSmiles(smiles_)
    mol = Chem.MolFromSmiles(canonical)

    # Dropping stereo information lets otherwise-equivalent atoms tie.
    Chem.RemoveStereochemistry(mol)

    ranks = list(Chem.CanonicalRankAtoms(mol, breakTies=False))

    # Fewer distinct ranks than atoms means at least one tie.
    return len(set(ranks)) < len(ranks)
    
def get_bdes_and_diplay_m(smiles_):
    """Return the hydrogen-expanded molecule with predicted values as bond notes.

    Each bond gets a ``bondNote`` property holding the ``pred_bde`` value found
    at the row position equal to the bond index, rounded to 0 decimals.

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule (canonicalised internally).

    Returns
    -------
    Chem.Mol
        Molecule with explicit hydrogens and annotated bonds.
    """
    canonical = Chem.CanonSmiles(smiles_)

    bde_column = predict_bdes(canonical)['pred_bde']

    annotated = Chem.AddHs(Chem.MolFromSmiles(canonical))
    for bond in annotated.GetBonds():
        rounded_bde = round(bde_column.iloc[bond.GetIdx()], 0)
        bond.SetProp('bondNote', f"{rounded_bde}")

    return annotated


In [None]:
file_to_featurize = '../Data/numbered_reaction_1.csv'
from tqdm import tqdm

# Alternative (currently disabled) input: read the Reactant_SMILES column from
# a local CSV such as `file_to_featurize` with pd.read_csv instead of the
# Google Sheet used below.

BDES_JSON_PATH = '../Data/Descriptors/bdes.json'
CHECKPOINT_EVERY = 100  # write the JSON file every this many molecules


def save_descriptors(descriptors, path=BDES_JSON_PATH):
    """Write the descriptor dictionary to `path` as key-sorted, indented JSON."""
    with open(path, "w") as desc_file:
        json.dump(descriptors, desc_file, sort_keys=True, indent=1)
    print("\n\nBDE updated\n\n")


# Existing descriptors; predictions are merged into this dictionary.
with open(BDES_JSON_PATH) as desc_file:
    df_desc = json.load(desc_file)

sheet_id = "1OijQ0fiJTJn8OOOU9pJ9qm5JxmQWKpmb"
url      = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?gid=751558775&format=csv"
df_xlsx  = pd.read_csv(url, index_col=0)

# Unique reactants from the sheet, canonicalised; missing cells are skipped.
smiles_to_bde = [
    Chem.CanonSmiles(reactant)
    for reactant in df_xlsx.Reactant_SMILES.unique()
    if pd.notna(reactant)
]

# Extra molecules, given as hand-written (not canonicalised) SMILES.
smiles_target = ['O=C(OC)CC[C@@H](C)[C@H]1CC[C@@]2([H])[C@]3([H])CCC4C[C@H](OC(C)=O)CC[C@]4(C)[C@@]3([H])CC[C@@]21C', 
                 'CC([C@H](OC(CO)=O)C[C@@]([C@@H](O)[C@@H]1C)(C)CC)([C@@H]23)[C@@H](CCC12CCC3=O)C',
                 'O=C(OC)C(N1C(C(C=CC=C2)=C2C1=O)=O)C3CCCCC3',
                 'C[C@@H]1CC[C@H]2[C@H](C(=O)O[C@H]3[C@@]24[C@H]1CC[C@](O3)(OO4)C)C',
                 'CC([C@H](OC(COC(C)=O)=O)C[C@@]([C@@H](OC(C)=O)[C@@H]1C)(C)CC)([C@@H]23)[C@@H](CCC12CCC3=O)C']

smiles_to_bde.extend(smiles_target)

count = 0  # number of inputs that were not already in canonical form
# Wrap the list (not the enumerate object) so tqdm knows the total length.
for k, raw_smiles in enumerate(tqdm(smiles_to_bde)):

    smiles = Chem.CanonSmiles(raw_smiles)

    # Compare the *input* with its canonical form; comparing after the
    # reassignment could never detect a difference.
    if raw_smiles != smiles:
        print(f"{raw_smiles} is not canonical")
        count += 1

    # Create the entry if needed, then store/overwrite its predictions.
    df_desc.setdefault(smiles, {})['pred_bdes'] = get_bdes(smiles)

    # Periodic checkpoint so a long run does not lose all progress.
    if k % CHECKPOINT_EVERY == 0:
        save_descriptors(df_desc)

save_descriptors(df_desc)

print(count)

In [None]:
# Example: per-carbon min/max/mean summary of the predicted values for one
# molecule. The bare call on the last line is shown as the cell output.
smiles = 'CC(=O)O[C@H]1CCC[C@@H]2CCCC[C@@H]21'
get_bdes(smiles)

In [None]:
# Example: print the per-carbon summary, then draw the molecule with the
# rounded pred_bde value attached to each bond as a note.
smiles = 'CC(=O)O[C@H]1CCC[C@@H]2CCCC[C@@H]21'
print(get_bdes(smiles))
# Last expression of the cell: the 500x500 image is rendered as the output.
Draw.MolToImage(get_bdes_and_diplay_m(smiles), size=(500, 500))