In [9]:
import os
from collections import Counter
import random
from itertools import product 


import pandas as pd

from rdkit.Chem import AllChem
from rdkit.Chem.rdmolfiles import MolFromSmiles
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
from rdkit.Chem.Draw.IPythonConsole import DrawMorganBit, DrawMorganBits, display_pil_image
from rdkit.Chem.Draw import SimilarityMaps

import shap
import numpy as np

In [10]:
# Morgan fingerprint settings **USE THE SAME VALUES AS THE MAIN PROGRAM**
# Both values are forwarded to get_morgan_fp() by the driver cell of this notebook.
# NOTE(review): presumably they must match the main program so that the bit
# indices line up with the columns of the SHAP value CSVs - confirm.
radius = 2  # Morgan fingerprint radius given to the fingerprint generator
n_bits = 4096  # fingerprint size (fpSize) given to the fingerprint generator

In [11]:
def get_morgan_fp(dye_smiles_dict: dict, radius: int = 2, n_bits: int = 4096):
    """Compute the Morgan fingerprint and bit-info map of every dye.

    Args:
        dye_smiles_dict: Mapping of dye name -> SMILES string.
        radius: Radius passed to the Morgan fingerprint generator.
        n_bits: Fingerprint size (``fpSize``) passed to the generator.

    Returns:
        A 3-tuple of dicts, each keyed by dye name:
          1. the RDKit fingerprint object,
          2. the same fingerprint converted with ``np.array``,
          3. the bit-info map collected while fingerprinting (per the RDKit
             documentation: set bit -> tuple of (atom index, radius) pairs).

    Raises:
        ValueError: If RDKit cannot parse the SMILES string of a dye.
    """
    dye_morgan_fp_dict = {}
    dye_morgan_fp_bit_dict = {}
    dye_morgan_fp_lst_dict = {}

    fp_gen = AllChem.GetMorganGenerator(radius=radius, fpSize=n_bits)
    # A single AdditionalOutput is reused for every molecule; the bit-info map
    # is read back right after each GetFingerprint call.
    ao = AllChem.AdditionalOutput()
    ao.CollectBitInfoMap()

    for dye, smiles in dye_smiles_dict.items():
        mol = MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # here with the offending dye instead of deep inside RDKit.
            raise ValueError(f"Could not parse SMILES {smiles!r} for dye {dye!r}")

        fp = fp_gen.GetFingerprint(mol, additionalOutput=ao)

        dye_morgan_fp_dict[dye] = fp
        dye_morgan_fp_lst_dict[dye] = np.array(fp)
        dye_morgan_fp_bit_dict[dye] = ao.GetBitInfoMap()

    return dye_morgan_fp_dict, dye_morgan_fp_lst_dict, dye_morgan_fp_bit_dict

In [12]:
# Build the {dye name: SMILES} lookup table from the CSV file
dye_smile_df = pd.read_csv('dye_smiles.csv', usecols=['Dye', 'SMILES'])
# Pair the two columns row by row; a repeated dye name keeps its last SMILES
dye_smile_dict = dict(zip(dye_smile_df['Dye'], dye_smile_df['SMILES']))


In [13]:
def average_shap_values(sorted_shap_values_df, dye_morgan_fp_lst_dict):
    """Average the numeric columns of the SHAP table separately for each dye.

    Args:
        sorted_shap_values_df: DataFrame with a "Dye" column; every other
            numeric column is averaged.
        dye_morgan_fp_lst_dict: Dict whose keys are the dye names to report.
            Only the keys are used.

    Returns:
        DataFrame with one row per dye (in key order) and one column per
        numeric column of the input. A dye without any rows in the input
        gets a row of NaN.
    """
    per_dye_means = {}
    for dye in dye_morgan_fp_lst_dict:
        # Keep only the rows that belong to this dye, then average them
        is_this_dye = sorted_shap_values_df["Dye"] == dye
        per_dye_means[dye] = sorted_shap_values_df[is_this_dye].mean(numeric_only=True)

    # The dict holds one column per dye; transpose so dyes become rows
    return pd.DataFrame(per_dye_means).T


In [14]:
def get_shap_list(average_shap_values_df, dye_morgan_fp_dict, dye_morgan_fp_bit_dict, dye_name, dye_smile_dict):
    """Spread a dye's averaged per-bit SHAP values over its atoms.

    Args:
        average_shap_values_df: DataFrame indexed by dye name whose column
            labels are fingerprint bit indices stored as strings.
        dye_morgan_fp_dict: Not used by this function; kept in the signature.
        dye_morgan_fp_bit_dict: Dict of dye name -> bit-info map, where each
            entry's first element is used as an atom index.
        dye_name: Name of the dye to process.
        dye_smile_dict: Dict of dye name -> SMILES string.

    Returns:
        NumPy array with one accumulated SHAP value per atom of the dye.
    """
    # One accumulator slot per atom of the molecule
    mol = MolFromSmiles(dye_smile_dict[dye_name])
    shap_lst = np.zeros(mol.GetNumAtoms())

    # The bit-info map lists the set bits of the fingerprint. To inspect the
    # bit assignments visually, DrawMorganBit(mol, bit_id, bit_info, useSVG=False)
    # can be called here and its result saved (this generates a lot of files).
    bit_info = dye_morgan_fp_bit_dict[dye_name]
    for bit_id, environments in bit_info.items():
        column = str(bit_id)
        # Bits without a SHAP column contribute nothing
        if column not in average_shap_values_df.columns:
            continue

        # Credit the averaged SHAP value to the atom of every occurrence of the bit
        for environment in environments:
            shap_lst[environment[0]] += average_shap_values_df.at[dye_name, column]

    return shap_lst

In [15]:
def draw_dye_heatmaps(dye_name, shap_lst, model, matrix, dye_smile_dict):
    """Draw the per-atom SHAP heatmap of a dye and save it as a PNG.

    The image is written to ``<dye_name>_heatmaps/<dye_name>_<model>_<matrix>_heatmap.png``
    relative to the current working directory; the folder is created if needed.

    Args:
        dye_name: Name of the dye; used to look up the SMILES and in the paths.
        shap_lst: Per-atom weights passed to the similarity map.
        model: Model label, used only in the file name.
        matrix: Matrix label, used only in the file name.
        dye_smile_dict: Dict of dye name -> SMILES string.

    Returns:
        The object returned by ``SimilarityMaps.GetSimilarityMapFromWeights``
        (previously this function returned None).
    """
    mol = MolFromSmiles(dye_smile_dict[dye_name])

    # Use a similarity map to add the weighted SHAP values to the proper atoms
    # NOTE(review): the savefig call below assumes this returns a matplotlib
    # figure - confirm against the installed RDKit version.
    fig = SimilarityMaps.GetSimilarityMapFromWeights(mol, shap_lst)

    # Save the figure in an appropriate folder. exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    dye_folder = f"{dye_name}_heatmaps"
    os.makedirs(dye_folder, exist_ok=True)

    heatmap_path = os.path.join(dye_folder, f"{dye_name}_{model}_{matrix}_heatmap.png")
    fig.savefig(heatmap_path, bbox_inches='tight', transparent=True)

    return fig

In [None]:
# Driver: for every model/matrix combination, load the raw SHAP values,
# average them per dye and save one heatmap per dye.
matrix_lst = ["RO", "EtOH", "CHCl3"]
model_lst = ["LR", "RF", "NN"]

print("Did you properly set the model and matrix list?   What about the Fingerprinting radius and n_bits?")

# The fingerprints depend only on the SMILES and the fingerprint settings, not on
# the model or matrix, so they are computed once instead of once per combination.
dye_morgan_fp_dict, dye_morgan_fp_lst_dict, dye_morgan_fp_bit_dict = get_morgan_fp(dye_smile_dict, radius, n_bits)

# product() yields the pairs in the same order as "for model: for matrix:"
for model, matrix in product(model_lst, matrix_lst):
    #Load the proper shap value csv
    shap_values_df = pd.read_csv(f"shap_value_raw_all_{model}_{matrix}.csv")

    #Sort the dataframe by dye
    sorted_shap_values_df = shap_values_df.sort_values('Dye')

    #Average the shap values for each dye
    average_shap_values_df = average_shap_values(sorted_shap_values_df, dye_morgan_fp_lst_dict)

    #Finalize and draw on a per dye basis
    for dye_name in dye_smile_dict:
        shap_list = get_shap_list(average_shap_values_df, dye_morgan_fp_dict, dye_morgan_fp_bit_dict, dye_name, dye_smile_dict)

        #Draw the Heatmap (saved to disk by draw_dye_heatmaps)
        draw_dye_heatmaps(dye_name, shap_list, model, matrix, dye_smile_dict)

