In [None]:
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

# load the preprocessor - used to convert SMILES to the graph structure
import nfp
from preprocess_inputs_cfc import preprocessor
preprocessor.from_json('model_3_tfrecords_multi_halo_cfc/preprocessor.json')

from rdkit import Chem
from rdkit.Chem import Draw

#load model
class Slice(layers.Layer):
    """Keras layer that keeps only the first half of axis 1 of a rank-3 input.

    The class is passed to ``load_model`` under the key ``'Slice'`` so the
    saved model can be deserialised.
    """

    def call(self, inputs):
        """Return ``inputs[:, :n // 2, :]`` where ``n`` is the runtime size of axis 1.

        The local name ``num_bonds`` suggests axis 1 indexes bonds
        (presumably each bond appears twice in the input -- TODO confirm
        against the preprocessor).
        """
        input_shape = tf.shape(inputs)
        # Floor division keeps the slice size an integer tensor. True division
        # (`/`) on an integer tensor yields a float, which then has to be cast
        # back to an integer before it can be used as a slice size.
        num_bonds = input_shape[1] // 2
        output = tf.slice(inputs, [0, 0, 0], [-1, num_bonds, -1])
        # The slice size is only known at run time, so re-attach the static
        # batch and feature dimensions to the result.
        output.set_shape(self.compute_output_shape(inputs.shape))
        return output

    def compute_output_shape(self, input_shape):
        """Static output shape: same first and last dims, unknown size on axis 1."""
        return [input_shape[0], None, input_shape[2]]
    
# Objects Keras needs to rebuild the saved model: everything nfp registers
# plus the local Slice layer.
custom_objects = dict(nfp.custom_objects)
custom_objects['Slice'] = Slice

model = tf.keras.models.load_model(
    'model_3_multi_halo_cfc/best_model.hdf5',
    custom_objects=custom_objects,
)


def get_bdes(smiles_):
    """Return the predicted BDE and BDFE of every bond of a molecule.

    The SMILES is canonicalised, predictions are obtained from
    ``predict_bdes`` and each bond of the hydrogen-completed RDKit molecule
    is paired with the prediction row at the bond's index.

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule.

    Returns
    -------
    dict
        Maps ``"<begin atom idx>-<end atom idx>"`` to
        ``{'bde': float, 'bdfe': float}``.
    """
    canonical = Chem.CanonSmiles(smiles_)
    predictions = predict_bdes(canonical)

    # Explicit hydrogens are added so that bonds to H are iterated as well.
    mol_with_hs = Chem.AddHs(Chem.MolFromSmiles(canonical))

    bde_column = predictions['pred_bde']
    bdfe_column = predictions['pred_bdfe']

    # Rows are looked up positionally by RDKit bond index.
    # NOTE(review): this assumes the prediction rows follow RDKit's bond
    # order for the H-added molecule -- confirm against the preprocessor.
    return {
        f"{bond.GetBeginAtomIdx()}-{bond.GetEndAtomIdx()}": {
            'bde': float(bde_column.iloc[bond.GetIdx()]),
            'bdfe': float(bdfe_column.iloc[bond.GetIdx()]),
        }
        for bond in mol_with_hs.GetBonds()
    }

def predict_bdes(smiles_):
    """Run the loaded model on a single molecule.

    The SMILES is canonicalised, converted to model inputs with ``get_data``,
    wrapped in a one-example ``tf.data`` pipeline and passed to the
    module-level ``model``.

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule.

    Returns
    -------
    pandas.DataFrame
        The model output reshaped to two columns, ``pred_bde`` and
        ``pred_bdfe``.
    """
    smiles = Chem.CanonSmiles(smiles_)

    # The preprocessor's own signature/padding, extended with the two scalar
    # count fields that get_data adds to every example.
    signature = {
        **preprocessor.output_signature,
        'n_atom': tf.TensorSpec(shape=(), dtype=tf.int32, name=None),
        'n_bond': tf.TensorSpec(shape=(), dtype=tf.int32, name=None),
    }
    padding = {
        **preprocessor.padding_values,
        'n_atom': tf.constant(0, dtype="int32"),
        'n_bond': tf.constant(0, dtype="int32"),
    }

    def single_example():
        # The dataset contains exactly one element: this molecule.
        yield get_data(smiles)

    dataset = tf.data.Dataset.from_generator(single_example, output_signature=signature)
    test_smiles = dataset.padded_batch(batch_size=1000, padding_values=padding)

    predicted_bdes = model.predict(test_smiles, verbose=True)

    return pd.DataFrame(predicted_bdes.reshape(-1, 2), columns=['pred_bde', 'pred_bdfe'])

def get_data(smiles):
    """Build the model input dictionary for one SMILES string.

    Calls the module-level ``preprocessor`` and adds two entries, ``n_atom``
    and ``n_bond``, holding the lengths of the ``atom`` and ``bond`` entries.

    Returns
    -------
    dict
        The preprocessor output with ``n_atom`` and ``n_bond`` added.
    """
    features = preprocessor(smiles)
    atom_count = len(features['atom'])
    bond_count = len(features['bond'])
    features['n_atom'] = atom_count
    features['n_bond'] = bond_count
    return features

    
def get_bdes_and_diplay_m(smiles_):
    """Return an RDKit molecule whose bonds are annotated with predicted BDEs.

    Each bond of the hydrogen-completed molecule gets a ``bondNote`` property
    containing the predicted BDE rounded to zero decimals, ready to be drawn
    (e.g. with ``Draw.MolToImage``).

    Parameters
    ----------
    smiles_ : str
        SMILES string of the molecule.

    Returns
    -------
    rdkit.Chem.Mol
        The molecule with explicit hydrogens and annotated bonds.
    """
    canonical = Chem.CanonSmiles(smiles_)

    bde_values = predict_bdes(canonical)['pred_bde']

    annotated = Chem.AddHs(Chem.MolFromSmiles(canonical))
    for bond in annotated.GetBonds():
        # Prediction rows are looked up positionally by RDKit bond index.
        label = round(bde_values.iloc[bond.GetIdx()], 0)
        bond.SetProp('bondNote', f"{label}")

    return annotated


In [None]:
#file_to_featurize = '../Data/numbered_reaction_1.csv'
from tqdm import tqdm
from rdkit.Chem import AllChem
# --- Assemble the list of molecules whose BDEs will be predicted -------------
# Sources: the filtered ZINC CSV, several hand-picked molecules, the target
# molecules, and the Reactant_SMILES column of a Google sheet (network access).
file_to_featurize = '../ZINC_molecules/ZINC_data/ZINC_lead/all_smiles_ds1_f1_SP3_filtered.csv'
df_ = pd.read_csv(file_to_featurize, index_col=0)

smiles_to_bde = list(df_.smiles.unique())
smiles_to_bde += ['CC(=O)O[C@H]1CC[C@]2(C)[C@H]3CC[C@]4(C)[C@@H]([C@H](C)CCCC(C)C)CC[C@H]4[C@@H]3C[C@@H](Br)[C@@]2(Br)C1',
               'CC(=O)O[C@@H]1CC[C@@]2(C)[C@H](CC[C@@H]3[C@@H]2CC[C@]2(C)[C@@H](OC(C)=O)CC[C@@H]32)C1',
               'CC(=O)O[C@@H]1CC[C@@]2(C)[C@H](CC[C@@H]3[C@@H]2CC[C@]2(C)[C@@H]([C@H](C)CCCC(C)C)CC[C@@H]32)C1',
               'COC(=O)[C@H](CC(C)C)NC(=O)[C@H](C)C(=O)OC(C)(C)C',
               'C1C2C[C@H]3CC(C[C@@H]1C3)[C@@]21CO1',
               'C1CC[C@H]2O[C@H]2C1',
               'CC(=O)O[C@H]1CCC[C@@H]2CCCC[C@@H]21']

# Missing values read from the CSV are NaN; filter them out explicitly before
# canonicalising (the previous `x == x` test relied on NaN != NaN).
smiles_to_bde = [Chem.CanonSmiles(smi) for smi in smiles_to_bde if pd.notna(smi)]

smiles_to_bde += ["COC(=O)CC[C@@H](C)[C@H]1CC[C@H]2[C@@H]3CCC4C[C@H](OC(C)=O)CC[C@]4(C)[C@H]3CC[C@]12C",
"C[C@@H]1CC[C@H]2[C@@H](C)C(=O)O[C@@H]3O[C@@]4(C)CC[C@@H]1[C@]32OO4",
"CC[C@]1(C)C[C@@H](OC(=O)CO)C2(C)[C@H](C)CCC3(CCC(=O)[C@H]32)[C@@H](C)[C@@H]1O",
"COC(=O)C(C1CCCCC1)N1C(=O)c2ccccc2C1=O",
"CC[C@]1(C)C[C@@H](OC(=O)COC(C)=O)C2(C)[C@H](C)CCC3(CCC(=O)[C@H]32)[C@@H](C)[C@@H]1OC(C)=O"
 ]

smiles_target = ['O=C(OC)CC[C@@H](C)[C@H]1CC[C@@]2([H])[C@]3([H])CCC4C[C@H](OC(C)=O)CC[C@]4(C)[C@@]3([H])CC[C@@]21C', 
                 'CC([C@H](OC(CO)=O)C[C@@]([C@@H](O)[C@@H]1C)(C)CC)([C@@H]23)[C@@H](CCC12CCC3=O)C',
                 'O=C(OC)C(N1C(C(C=CC=C2)=C2C1=O)=O)C3CCCCC3',
                 'C[C@@H]1CC[C@H]2[C@H](C(=O)O[C@H]3[C@@]24[C@H]1CC[C@](O3)(OO4)C)C',
                 'CC([C@H](OC(COC(C)=O)=O)C[C@@]([C@@H](OC(C)=O)[C@@H]1C)(C)CC)([C@@H]23)[C@@H](CCC12CCC3=O)C']
smiles_to_bde += smiles_target

sheet_id = "1OijQ0fiJTJn8OOOU9pJ9qm5JxmQWKpmb"
url      = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?gid=751558775&format=csv"
# Rows without a reactant SMILES are dropped right after download.
df_xlsx  = pd.read_csv(url, index_col=0).dropna(subset=['Reactant_SMILES'])
# A dedicated name is used here so `smiles` is not bound to a whole array.
sheet_smiles = df_xlsx.Reactant_SMILES.values

smiles_to_bde += list(sheet_smiles)

# Canonicalise everything and drop duplicates. sorted(...) rather than
# list(set(...)): the iteration order of a set of strings differs between
# interpreter sessions, which made the processing order (and therefore what
# each periodic checkpoint contains) change from run to run.
smiles_to_bde = sorted({Chem.CanonSmiles(smi) for smi in smiles_to_bde})

# Load the existing descriptor store (a dict keyed by SMILES). The context
# manager closes the file even if json.load raises on malformed content.
with open('../Data/Descriptors/bdes_graph.json') as desc_file:
    df_desc = json.load(desc_file)

# --- Predict BDEs for every molecule and store them in df_desc --------------
# One place for the output path so every checkpoint writes to the same file.
BDE_JSON_PATH = '../Data/Descriptors/bdes_graph.json'
# Write the whole store to disk every this many molecules.
CHECKPOINT_EVERY = 100

count = 0  # number of list entries that changed when re-canonicalised
print(f"there are {len(smiles_to_bde)} SMILES to compute BDEs for")
# enumerate(tqdm(...)) instead of tqdm(enumerate(...)): tqdm can only read the
# total (for percentage / ETA) from the list itself, not from an enumerate object.
for k, listed_smiles in enumerate(tqdm(smiles_to_bde)):

    # `smiles` (the canonical form) is the key used in df_desc.
    smiles = Chem.CanonSmiles(listed_smiles)

    # Compare against the list entry itself. Previously the entry was
    # overwritten by its canonical form before the comparison, so the check
    # could not flag the entry.
    if smiles != listed_smiles:
        print(f"{listed_smiles} is not canonical")
        count += 1

    m = Chem.MolFromSmiles(smiles)
    m = Chem.AddHs(m)
    AllChem.EmbedMolecule(m)
    block = Chem.MolToMolBlock(m)

    # NOTE(review): the return value of EmbedMolecule is ignored and only an
    # empty mol block stops the loop. RDKit reports embedding failure through
    # a negative return code -- confirm whether this guard can ever trigger,
    # and whether `break` (abort the run) rather than `continue` is intended.
    if block == '':
        print(f"{smiles} is not a valid molecule, no embedding possible")
        break

    # Create the entry on first sight, otherwise only replace 'pred_bdes' and
    # keep any other descriptors already stored for this molecule.
    df_desc.setdefault(smiles, {})['pred_bdes'] = get_bdes(smiles)

    if k % CHECKPOINT_EVERY == 0:
        with open(BDE_JSON_PATH, "w") as desc_file:
            json.dump(df_desc, desc_file, sort_keys=True, indent=1)
        print("\n\nBDE updated\n\n")

# Final save so molecules processed after the last checkpoint are persisted.
with open(BDE_JSON_PATH, "w") as desc_file:
    json.dump(df_desc, desc_file, sort_keys=True, indent=1)
    print("\n\nBDE updated\n\n")
print(count)

In [None]:
# Show the current `smiles` and the Python type stored for the 'bde' value of
# its '0-1' bond entry.
print(smiles)
current_pred_bdes = df_desc[smiles]['pred_bdes']
type(current_pred_bdes['0-1']['bde'])

In [None]:
# Same type check for one fixed molecule of the descriptor store.
fixed_entry = df_desc['BrC12CC3CC(CC(C3)C1)C2']
type(fixed_entry['pred_bdes']['0-1']['bde'])

In [None]:
# Write the in-memory descriptor store back to its JSON file.
output_path = '../Data/Descriptors/bdes_graph.json'
with open(output_path, "w") as desc_file:
    json.dump(df_desc, desc_file, indent=1, sort_keys=True)
print("\n\nBDE updated\n\n")

In [None]:
# Display the in-memory descriptor dictionary as the cell output.
# NOTE(review): this renders every entry; previewing a few keys may be easier to read.
df_desc

In [None]:
# Example: print the per-bond predictions for one molecule, then draw it with
# the predicted BDE written next to each bond.
smiles = 'CC(Cl)CC1CCCC1c2ccccc2'
bond_predictions = get_bdes(smiles)
print(bond_predictions)
annotated_mol = get_bdes_and_diplay_m(smiles)
Draw.MolToImage(annotated_mol, size=(500, 500))