In [1]:
import pickle
import pandas as pd

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem

import shap
# Register TensorFlow's 'AddV2' op in shap's DeepExplainer op-handler table,
# mapping it to shap's `passthrough` handler. Presumably needed because the
# layers.Add() in shap_CN's readout model lowers to AddV2, which this shap
# version does not handle out of the box — TODO confirm.
# NOTE(review): this pokes a private module (shap.explainers._deep) and may
# break on shap upgrades.
shap.explainers._deep.deep_tf.op_handlers['AddV2'] = shap.explainers._deep.deep_tf.passthrough

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

#from tensorflow.compat.v1.keras.backend import get_session
#tf.compat.v1.disable_v2_behavior()
#tf.compat.v1.enable_eager_execution()

In [2]:
def _build_readout_model(w_readout, w_norelu, w_relu, feat_dim=64):
    """Rebuild the readout head of the trained network as a small Keras model.

    Graph: ``dense_final(gf + dense(denserelu(af)))``. The atom-feature input
    ``af`` goes through a ReLU Dense layer and a linear Dense layer. It is then
    added to the graph-feature input ``gf``, and a final Dense layer emits one
    value per row.

    Parameters
    ----------
    w_readout, w_norelu, w_relu : indexable (kernel, bias) pairs
        Weights for the ``dense_final``, ``dense`` and ``denserelu`` layers,
        in the order they are stored in the weights pickle. Each element must
        provide ``.numpy()``.
    feat_dim : int
        Width of the ``af`` / ``gf`` inputs and of the two hidden Dense layers.

    Returns
    -------
    tf.keras.Model
        Takes ``[af, gf]`` and returns an array of shape (n_rows, 1).
    """
    af_input = layers.Input(shape=[feat_dim], dtype=tf.float32, name='af')
    gf_input = layers.Input(shape=[feat_dim], dtype=tf.float32, name='gf')

    af = layers.Dense(feat_dim, activation='relu', name='denserelu')(af_input)
    af = layers.Dense(feat_dim, name='dense')(af)
    gf = layers.Add()([gf_input, af])
    prediction = layers.Dense(1, name='dense_final')(gf)

    model = tf.keras.Model([af_input, gf_input], [prediction])

    # Look layers up by name instead of by position in model.layers: the
    # positional order is an implementation detail of the functional API.
    for layer_name, weights in (('dense_final', w_readout),
                                ('dense', w_norelu),
                                ('denserelu', w_relu)):
        model.get_layer(layer_name).set_weights(
            [weights[0].numpy(), weights[1].numpy()])
    return model


def _distribute_shap_to_atoms(af, af_avg, feature_shap):
    """Spread per-feature SHAP values of the atom-averaged vector onto atoms.

    Atom ``j`` of molecule ``i`` receives the share
    ``af[i, j, k] / sum_j af[i, j, k]`` of the SHAP value of feature ``k``.
    Because ``af_avg`` is the mean over the atom axis, ``n_atoms * af_avg``
    equals that sum.

    Parameters
    ----------
    af : array, shape (n_mol, n_atoms, n_feat)
    af_avg : array, shape (n_mol, n_feat)
        Mean of ``af`` over axis 1.
    feature_shap : array, shape (n_mol, n_feat)

    Returns
    -------
    numpy.ndarray, float64, shape (n_mol, n_atoms, n_feat)
        A feature whose atom sum is zero produces NaN/inf, exactly as the
        elementwise division would.
    """
    af = np.asarray(af, dtype=np.float64)
    n_atoms = af.shape[1]
    denom = n_atoms * np.asarray(af_avg, dtype=np.float64)[:, np.newaxis, :]
    shap_b = np.asarray(feature_shap, dtype=np.float64)[:, np.newaxis, :]
    return af / denom * shap_b


def _rescale_to_prediction(atom_scores, predicted):
    """Scale each molecule's atom scores so that each row sums to its prediction.

    Parameters
    ----------
    atom_scores : array, shape (n_mol, n_atoms)
    predicted : array-like, shape (n_mol,)

    Returns
    -------
    numpy.ndarray, shape (n_mol, n_atoms)
        Rows sum to ``predicted`` up to rounding. A row whose scores sum to
        zero yields inf/NaN.
    """
    atom_scores = np.asarray(atom_scores)
    scale = np.asarray(predicted) / atom_scores.sum(axis=-1)
    return atom_scores * scale[:, np.newaxis]


def shap_CN(num_atoms_of_interest,
            csv_path='molecules_to_predict.csv',
            weights_path='weights_last_layer.pkl',
            features_path='feat_vectors_for_shap.pkl'):
    """Attribute the predicted CN of molecules of one size to their atoms with SHAP.

    Steps:
      1. Select the rows of ``csv_path`` whose molecule has exactly
         ``num_atoms_of_interest`` heavy atoms.
      2. Rebuild the readout head from ``weights_path`` and print the mean and
         max absolute error against the ``predicted`` column as a sanity check.
      3. Run ``shap.DeepExplainer`` on the atom-averaged atom features and the
         graph features.
      4. Spread the summed per-feature SHAP values back over atoms, then
         rescale each molecule's atom scores so they sum to its predicted CN.

    Parameters
    ----------
    num_atoms_of_interest : int
        Heavy-atom count of the molecules to explain.
    csv_path, weights_path, features_path : str
        Input files. The defaults are the files this notebook was written for.

    Returns
    -------
    df_sub : pandas.DataFrame
        The selected rows with a ``total_atoms`` column. It is a detached copy,
        so callers may add columns to it.
    atomwise_shap_normalized_to_CN : numpy.ndarray, shape (len(df_sub), n_atoms)
        Per-atom contributions. Each row sums to the molecule's ``predicted``.
    atom_shap_total : numpy.ndarray, shape (len(df_sub),)
        Row sums of ``atomwise_shap_normalized_to_CN``.
    """
    df = pd.read_csv(csv_path)
    df['total_atoms'] = [Chem.MolFromSmiles(smi).GetNumHeavyAtoms()
                         for smi in df.Canonical_SMILES]
    mask = (df['total_atoms'] == num_atoms_of_interest).to_numpy()
    df_sub = df[mask].copy()
    if df_sub.empty:
        # Nothing to explain; DeepExplainer cannot work with an empty background.
        return df_sub, np.zeros((0, num_atoms_of_interest)), np.zeros(0)

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(weights_path, 'rb') as f:
        # CN_readout, norelu, relu
        w1, w2, w3 = pickle.load(f)
    with open(features_path, 'rb') as f:
        af_pkl, gf_pkl = pickle.load(f)

    # Rows of the pickled arrays are assumed to align positionally with the
    # CSV rows. Only the first num_atoms_of_interest atom slots are kept
    # (presumably the remainder is padding — TODO confirm).
    rows = np.flatnonzero(mask)
    af_pkl = af_pkl[rows][:, 0:num_atoms_of_interest]
    gf_pkl = gf_pkl[rows]

    model = _build_readout_model(w1, w2, w3)

    # The readout head consumes atom features averaged over the atom axis.
    af_pkl_avg = np.mean(af_pkl, axis=1)
    pred = model.predict([af_pkl_avg, gf_pkl]).squeeze()
    # Sanity check: the rebuilt head should reproduce the stored predictions.
    abs_err = np.abs(pred - df_sub.predicted)
    print(abs_err.mean(), abs_err.max())

    #### SHAP part ####
    explainer = shap.DeepExplainer(model, [af_pkl_avg, gf_pkl])
    shap_values = explainer.shap_values([af_pkl_avg, gf_pkl])
    af_shap, gf_shap = shap_values[0]

    # Graph-feature SHAP values are added feature-by-feature to the
    # atom-feature ones before being distributed over atoms.
    all_shap = af_shap + gf_shap
    atomwise_shap = _distribute_shap_to_atoms(af_pkl, af_pkl_avg, all_shap)

    # for atom color map plot
    atomwise_shap_for_plot = np.sum(atomwise_shap, axis=2)

    atomwise_shap_normalized_to_CN = _rescale_to_prediction(
        atomwise_shap_for_plot, np.array(df_sub.predicted))
    atom_shap_total = np.sum(atomwise_shap_normalized_to_CN, axis=-1)
    return df_sub, atomwise_shap_normalized_to_CN, atom_shap_total

In [3]:
df = pd.read_csv('molecules_to_predict.csv')
df['total_atoms'] = [Chem.MolFromSmiles(smi).GetNumHeavyAtoms() for smi in df.Canonical_SMILES]

# Heavy-atom counts to explain (range end is exclusive: 8..24 atoms).
ATOM_COUNTS = range(8, 25)

# Collect one frame per molecule size and concatenate once at the end.
# Growing a frame with pd.concat inside the loop is quadratic, and
# concatenating onto an empty placeholder frame is deprecated in recent pandas.
per_size_frames = []
for n_atoms in ATOM_COUNTS:
    df_sub, atomwise_shap, atom_shap_total = shap_CN(n_atoms)
    per_size_frames.append(df_sub.assign(atom_shap_total=list(atom_shap_total),
                                         atomwise_shap=list(atomwise_shap)))

# reindex keeps the original column order (input columns, then the two SHAP
# columns) and also yields a correctly shaped empty frame if nothing was collected.
output_columns = list(df.columns) + ['atom_shap_total', 'atomwise_shap']
df_w_shap = (pd.concat(per_size_frames) if per_size_frames
             else pd.DataFrame()).reindex(columns=output_columns)

2.9231015719738832e-06 1.335449218231588e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


2.670873232012304e-06 2.387451172580768e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


2.432051561199441e-06 9.344482421624889e-06


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


3.6005653246407032e-06 2.1198730465243898e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.356046109455248e-06 4.878417968257054e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.449384034259652e-06 3.473632813211225e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


2.950825051731113e-06 1.0954589839684559e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


3.5141292857694576e-06 1.5322021482688797e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


3.894559860301294e-06 1.8078613280181344e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.720552571650198e-06 2.3002929687265805e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.981856281816022e-06 1.8763427732437776e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


6.762483724050602e-06 1.838134765819177e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


5.767883299601095e-06 1.831054687784217e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.966895918354843e-06 1.5246582037775624e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


3.7810959596757107e-06 1.5653320318165242e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


4.944732664435847e-06 1.6982421868760866e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


5.437627155953824e-06 2.4577636722256102e-05


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


In [4]:
# NOTE(review): this cell is disabled plotting code wrapped in a string
# literal, so running it only echoes the string. It cannot simply be
# un-quoted: names such as af_shap_summed / gf_shap_summed are not defined at
# notebook level, and `df` here is the full, unfiltered table. Either delete
# this cell or have shap_CN return the values to plot.
'''
import matplotlib.pyplot as plt
import matplotlib

import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})
%matplotlib inline

matplotlib.rcParams['figure.dpi'] = 300
#plt.rcParams["font.family"] = 'Arial'
#plt.rcParams.update({'font.size': 24})

plt.scatter(af_shap_summed + gf_shap_summed, df.predicted)
'''

'\nimport matplotlib.pyplot as plt\nimport matplotlib\n\nimport seaborn as sns\nsns.set(context=\'talk\', style=\'ticks\',\n        color_codes=True, rc={\'legend.frameon\': False})\n%matplotlib inline\n\nmatplotlib.rcParams[\'figure.dpi\'] = 300\n#plt.rcParams["font.family"] = \'Arial\'\n#plt.rcParams.update({\'font.size\': 24})\n\nplt.scatter(af_shap_summed + gf_shap_summed, df.predicted)\n'

In [5]:
# Inspect the combined table: the input columns plus the per-molecule SHAP
# total and the per-atom SHAP array (pandas truncates the rendered rows).
df_w_shap

Unnamed: 0,Canonical_SMILES,Device_tier,Train/Valid/Test,CN,predicted,glob_vector,total_atoms,atom_shap_total,atomwise_shap
9,C1CC=CCCC=C1,1,Train,25.7,25.178371,1.279825 6.559031 -3.800939 -1.587972 1.457567...,8,25.178371,"[3.1698394378240833, 3.1698394378240833, 3.124..."
30,COc1ccccc1,1,Train,6.2,5.952137,0.297395 2.443781 -1.300710 -0.222946 0.281806...,8,5.952137,"[0.9108201797130551, 1.0394858541399652, 0.932..."
43,CCCCC(C)CC,1,Train,45.0,46.368923,2.267500 11.448519 -7.420034 -3.916420 2.46237...,8,46.368923,"[4.8671194106524815, 13.378437286900466, -0.48..."
64,CCCCCC(C)C,1,Train,52.6,53.743870,2.649966 13.292315 -8.495021 -4.692094 2.94048...,8,53.743870,"[2.789452487661217, 5.362448544461024, 7.05991..."
73,CC1=CC(=CC=C1)C,1,Train,7.0,7.101751,0.114691 2.321426 -1.313794 -0.092019 0.266234...,8,7.101751,"[0.3512299984009973, 1.472342742429441, 0.3305..."
...,...,...,...,...,...,...,...,...,...
496,O(C(=O)CCCCCCCC=CCCCCCCCC)C(C)CC,3,Train,72.0,69.144760,3.250906 17.743423 -10.671388 -5.775561 3.9041...,24,69.144760,"[2.4940025036433457, 1.4305945791904493, 0.424..."
525,CCCCOC(=O)CCCCCCC/C=C/C/C=C/C/C=C/CC,1,Valid,28.6,28.940542,1.226841 7.070701 -4.538584 -2.350751 1.198096...,24,28.940542,"[0.9716006150550872, 2.356929163283433, 2.3697..."
614,CCCCOC(=O)CCCCCCCC=CCC=CCC=CCC,3,Test,29.0,28.940540,1.226841 7.070702 -4.538585 -2.350750 1.198097...,24,28.940540,"[0.9716004844608673, 2.356928512211588, 2.3697..."
621,CCCCCCCCC(CCC)C(CCC)CCCCCCCC,3,Test,47.0,44.491978,2.486418 10.892790 -6.722002 -3.544846 2.27920...,24,44.491978,"[1.1879794266889057, 1.712365627548747, 2.1823..."


In [6]:
# The atomwise_shap column holds numpy arrays, which to_csv writes via str().
# That output wraps lines and rounds to 8 significant digits. Write each array
# as a full-precision list literal instead, readable back with
# ast.literal_eval (or json.loads for finite values).
df_w_shap.assign(
    atomwise_shap=df_w_shap['atomwise_shap'].map(lambda arr: str(np.asarray(arr).tolist()))
).to_csv('CN_shap_220614.csv', index = False)