# Compound SV / SHAP mappping

In [1]:
%load_ext autoreload

# Import libraries

In [3]:
import pandas as pd
import os
from rdkit import Chem
from ML.ml_utils_reg import create_directory
from ML.utils_mol_draw import get_ecfp4_bit_info, shap_to_atom_weight, \
    get_atom_wise_weight_map_ref_mol, get_color_bar
from ML import utils_pil
from ML.utils_shap import get_mmp_core
from ccrlib_master.utils_mmp import count_heavy_atoms
import seaborn as sns
import tqdm
sns.set_theme(style="whitegrid")
%autoreload 2

# Load data

In [3]:
results_path = './regression_shap_mmp/ECFP4/regression/'

In [4]:
df_shap_loaded_all = pd.read_pickle(os.path.join(results_path, f'df_shap.pkl'))
df_shap_loaded_all.rename(columns={'train_test': 'Set'}, inplace=True)
df_shap_loaded_all['Set'] = df_shap_loaded_all['Set'].replace({'test': 'Test', 'train': 'Train'})
display(df_shap_loaded_all)

Unnamed: 0,trial,algorithm,split,explainer,cid,smiles,target ID,experimental,prediction,expected_value,mae,analog_series_id,mmp_id,similarity,dPot,shap_values,fingerprint,conf_expected,Set,mmp_trial
0,0,SVR,Random,SVETA,CHEMBL472581,CS(=O)(=O)CCNCc1nc(-c2ccc3ncnc(Nc4ccc5c(cnn5Cc...,203,7.150028,7.250393,6.646533,0.100365,305,166,0.812500,0.001286,"[0.0, 0.032219119696185815, -0.055681158185522...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.250393,Train,0
1,0,SVR,Random,SVETA,CHEMBL509914,O=C(C=Cc1ccccc1)Cc1cc2c(Nc3cccc(Br)c3)ncnc2cn1,203,8.040005,8.140346,6.646533,0.100341,27,125,0.789474,1.120014,"[0.0, 0.021741757320681867, -0.091442926251187...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.140346,Train,0
2,0,SVR,Random,SVETA,CHEMBL4541014,COc1cc(Nc2ncnc3cc(-c4ccccc4)sc23)cc(OC)c1OC,203,7.638272,7.538456,6.646533,0.099816,253,0,0.800000,1.151490,"[0.0, 0.04179085647819404, -0.0580110349601591...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.538456,Train,0
3,0,SVR,Random,SVETA,CHEMBL128987,Brc1cccc(Nc2ncnc3cc(NCCN4CCOCC4)ncc23)c1,203,8.489991,8.561117,6.646533,0.071125,17,250,0.881356,0.219966,"[0.0, 0.03716784869633809, -0.0843523984224269...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.561117,Train,0
4,0,SVR,Random,SVETA,CHEMBL447223,C=CC(=O)Nc1ccc2ncnc(Cc3ccc(F)c(Cl)c3)c2c1,203,9.119987,9.183653,6.646533,0.063666,80,194,0.661017,0.030041,"[0.0, 0.05648243399732695, -0.1151572299096595...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",9.183653,Train,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,9,RFR,Stratified,TreeExplainer,CHEMBL1766496,Cn1c(=O)c(Cc2ccc(F)cc2F)cc2cnc(NC3CCOCC3)nc21,260,7.397940,7.804251,7.365917,0.406311,283,294,0.558824,0.279841,"[0.0, -0.002593746648617525, -2.67996651018620...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.804251,Test,0
476,9,RFR,Stratified,TreeExplainer,CHEMBL4647072,c1ccc2cc(-c3nnc(N4CCNCC4)cc3-c3ccncc3)ccc2c1,260,6.795880,6.829682,7.365917,0.033802,142,314,0.755102,0.157608,"[0.0, -0.002915379880214459, 1.434880119631998...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",6.829682,Test,0
477,9,RFR,Stratified,TreeExplainer,CHEMBL1784169,Cc1nc(CNC(=O)c2ccc(-c3cc(C(=O)NC4CC4)cc(F)c3C)...,260,7.100015,6.861094,7.365917,0.238922,184,161,0.655738,0.600019,"[0.0, -0.0036892747342790245, 8.46013695991132...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",6.861094,Test,0
478,9,RFR,Stratified,TreeExplainer,CHEMBL3086677,O=C1c2cc(OCCC3CCOCC3)ccc2COc2cc(Nc3ccc(F)cc3F)...,260,6.770636,6.492897,7.365917,0.277739,28,253,0.704225,0.886942,"[0.0, -0.0037749859997347813, -1.2529271480161...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",6.492897,Test,0


In [5]:
df_shap_loaded = df_shap_loaded_all.query('Set == "Test"')
display(df_shap_loaded)

Unnamed: 0,trial,algorithm,split,explainer,cid,smiles,target ID,experimental,prediction,expected_value,mae,analog_series_id,mmp_id,similarity,dPot,shap_values,fingerprint,conf_expected,Set,mmp_trial
0,0,SVR,Random,SVETA,CHEMBL3774926,C#Cc1cccc(Nc2ncnc3cc4c(cc23)N(CCCN2CCOCC2)C(=O...,203,6.856985,6.896116,6.646533,0.039131,24,273,0.871429,0.182090,"[0.0, 0.024560749532099555, 0.0757187703470913...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",6.896116,Test,0
1,0,SVR,Random,SVETA,CHEMBL2334001,COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCO,203,7.197911,7.912631,6.646533,0.714721,278,9,0.630769,0.802089,"[0.0, 0.10321972999540505, -0.1523825149093888...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",7.912631,Test,0
2,0,SVR,Random,SVETA,CHEMBL4546122,COc1ccc(-c2cc3ncnc(Nc4cc(OC)c(OC)c(OC)c4)c3s2)cc1,203,7.283997,7.087565,6.646533,0.196432,253,1,0.820000,0.716003,"[0.0, 0.07220751201514818, -0.0532990340197648...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.087565,Test,0
3,0,SVR,Random,SVETA,CHEMBL453378,CCOC(=O)C=CC(=O)Cc1cc2c(Nc3cccc(Br)c3)ncnc2cn1,203,8.819874,8.350949,6.646533,0.468925,143,202,0.589744,0.270142,"[0.0, 0.020148654861969935, -0.081527995396772...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.350949,Test,0
4,0,SVR,Random,SVETA,CHEMBL3622675,C#Cc1cccc(Nc2ncnc3cc(OCC)c(NC(=O)C=CCN(C)C4CC4...,203,8.619789,8.556186,6.646533,0.063603,10,325,0.890411,0.138303,"[0.0, 0.007068350585841467, -0.062895413946332...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.556186,Test,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,9,RFR,Stratified,TreeExplainer,CHEMBL1766496,Cn1c(=O)c(Cc2ccc(F)cc2F)cc2cnc(NC3CCOCC3)nc21,260,7.397940,7.804251,7.365917,0.406311,283,294,0.558824,0.279841,"[0.0, -0.002593746648617525, -2.67996651018620...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.804251,Test,0
476,9,RFR,Stratified,TreeExplainer,CHEMBL4647072,c1ccc2cc(-c3nnc(N4CCNCC4)cc3-c3ccncc3)ccc2c1,260,6.795880,6.829682,7.365917,0.033802,142,314,0.755102,0.157608,"[0.0, -0.002915379880214459, 1.434880119631998...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",6.829682,Test,0
477,9,RFR,Stratified,TreeExplainer,CHEMBL1784169,Cc1nc(CNC(=O)c2ccc(-c3cc(C(=O)NC4CC4)cc(F)c3C)...,260,7.100015,6.861094,7.365917,0.238922,184,161,0.655738,0.600019,"[0.0, -0.0036892747342790245, 8.46013695991132...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",6.861094,Test,0
478,9,RFR,Stratified,TreeExplainer,CHEMBL3086677,O=C1c2cc(OCCC3CCOCC3)ccc2COc2cc(Nc3ccc(F)cc3F)...,260,6.770636,6.492897,7.365917,0.277739,28,253,0.704225,0.886942,"[0.0, -0.0037749859997347813, -1.2529271480161...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",6.492897,Test,0


# Load MMP datasets

In [6]:
mmp_path = "./ccrlib_master/"
df_mmp = pd.read_csv(mmp_path + "df_mmp_final_top10.csv")
df_mmp

Unnamed: 0,core,as,sub_1,cid_1,sub_2,cid_2,tid,mmp_id,dpot,similarity,mmp_trial
0,COc1cc(Nc2ncnc3cc([*:1])sc23)cc(OC)c1OC,253,c1ccc([*:1])cc1,CHEMBL4541014,Brc1ccc([*:1])cc1,CHEMBL4572443,203,0,1.151490,0.800000,0
1,COc1cc(Nc2ncnc3cc([*:1])sc23)cc(OC)c1OC,253,Ic1ccc([*:1])cc1,CHEMBL4473768,COc1ccc([*:1])cc1,CHEMBL4546122,203,1,0.716003,0.820000,0
2,COc1cc(Nc2ncnc3cc([*:1])sc23)cc(OC)c1OC,253,c1csc([*:1])c1,CHEMBL4460381,Clc1ccc([*:1])cc1,CHEMBL4552482,203,2,2.038223,0.685185,0
3,COc1cc2ncnc(N3CCCc4ccccc43)c2cc1NC(=O)C=CC[*:1],135,CN(C)[*:1],CHEMBL4176787,CC[*:1],CHEMBL4162530,203,3,0.338819,0.814286,0
4,Clc1cc(Nc2ncnc3cccc(O[*:1])c23)ccc1OCc1ccccn1,207,C[*:1],CHEMBL194389,C1CC([*:1])CCO1,CHEMBL193578,203,4,0.736759,0.718750,0
...,...,...,...,...,...,...,...,...,...,...,...
45337,O=c1cc(-c2[nH]c([*:1])nc2-c2ccc(F)cc2)cc[nH]1,3,O=C(O)c1ccc([*:1])cc1,CHEMBL3313935,C#Cc1ccc([*:1])cc1,CHEMBL3314276,260,475,0.504318,0.692308,9
45338,Nc1ccccc1Nc1ccc2c(c1)CCc1ccc(O[*:1])cc1C2=O,73,OCC(O)C[*:1],CHEMBL2152936_CHEMBL2152938,C[*:1],CHEMBL2152777,260,476,1.364568,0.677966,9
45339,O=C1NCc2c(-c3ccccc3Cl)nc(O[*:1])nc2N1c1c(Cl)cc...,144,CN(C)CC[*:1],CHEMBL211426,[*:1],CHEMBL213846,260,477,0.021189,0.616667,9
45340,O=C1c2ccc(Nc3ccccc3)cc2CCc2ccc(O[*:1])cc21,200,C[*:1],CHEMBL2152784,[*:1],CHEMBL2152796,260,478,0.162727,0.673913,9


# Select dataset to map

In [7]:
split = "Stratified"
df_shap_split = df_shap_loaded.query(f'split == "{split}"').sort_values('mmp_id', ascending=False)
df_shap_split

Unnamed: 0,trial,algorithm,split,explainer,cid,smiles,target ID,experimental,prediction,expected_value,mae,analog_series_id,mmp_id,similarity,dPot,shap_values,fingerprint,conf_expected,Set,mmp_trial
612,2,SVR,Stratified,SVETA,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,7.792598,6.780393,0.031311,469,712,0.577465,0.875061,"[0.0, -0.042315349732468664, 0.125772584619251...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.792598,Test,0
253,1,SVR,Stratified,SVETA,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,8.059602,6.605561,0.235693,469,712,0.577465,0.875061,"[0.0, 0.0049452889600290805, 0.026442399419653...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.059602,Test,0
686,4,SVR,Stratified,SVETA,CHEMBL1980297,Nc1ncc(-c2cnn(CCO)c2)c2scc(-c3ccc(NC(=O)Nc4ccc...,279,8.698970,7.513292,6.658732,1.185678,469,712,0.577465,0.875061,"[0.0, -0.041108049796584355, -0.03691504541052...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",7.513292,Test,0
575,6,RFR,Stratified,TreeExplainer,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,7.511273,7.162679,0.312635,469,712,0.577465,0.875061,"[0.0, 0.0006084641881898278, -0.00030742519273...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.511273,Test,0
686,4,RFR,Stratified,TreeExplainer,CHEMBL1980297,Nc1ncc(-c2cnn(CCO)c2)c2scc(-c3ccc(NC(=O)Nc4ccc...,279,8.698970,7.591821,6.987765,1.107149,469,712,0.577465,0.875061,"[0.0, -0.0005145513361640042, -3.7467281072167...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",7.591822,Test,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64,6,SVR,Stratified,SVETA,CHEMBL1241427_CHEMBL431784_CHEMBL99414,CNC(=O)Oc1ccc2c(c1)C1(C)CCN(C)C1N2,220,7.246417,6.358522,6.202381,0.887895,195,0,0.576271,1.644357,"[0.0, 0.025880837740166147, -0.060095810589210...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",6.358522,Test,0
51,2,RFR,Stratified,TreeExplainer,CHEMBL564986,O=C(Cc1ccc(C(F)(F)F)cc1)Nc1ccc(F)c(OCCCN2CCOCC...,2409,6.806875,7.304029,7.787603,0.497153,179,0,0.626866,1.373581,"[0.0, 0.0037664150739146864, 1.448210125090554...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",7.304028,Test,0
308,0,RFR,Stratified,TreeExplainer,CHEMBL3110016,O=C(NO)C1C(c2ccccc2)C1c1cccc(-c2ncc(F)cn2)c1,1865,5.275724,5.967153,6.960280,0.691429,293,0,0.560976,0.262496,"[0.0, -0.003005826466542203, -0.00214813620201...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",5.967153,Test,0
326,3,RFR,Stratified,TreeExplainer,CHEMBL4572443,COc1cc(Nc2ncnc3cc(-c4ccc(Br)cc4)sc23)cc(OC)c1OC,203,6.486782,7.381853,7.703351,0.895070,253,0,0.800000,1.151490,"[0.0, -0.0020277417244506067, -0.0003055260831...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.381853,Test,0


# Generate MOL, bit info and atom wise shap values

In [14]:
tqdm.tqdm.pandas()

In [16]:
df_shap_split['mol'] = [Chem.MolFromSmiles(x) for x in tqdm.tqdm(df_shap_split['smiles'])]
df_shap_split['bit_info'] = [get_ecfp4_bit_info(smiles) for smiles in tqdm.tqdm(df_shap_split['smiles'])]
df_shap_split['column_mapped_shap'] = df_shap_split.progress_apply(lambda x: shap_to_atom_weight(mol=x.mol, dict_bit_info=x.bit_info, shapley_values=x.shap_values), axis=1)
df_shap_split['HA'] = df_shap_split['smiles'].progress_apply(lambda x: count_heavy_atoms(x))
df_shap_split

100%|██████████| 90700/90700 [00:21<00:00, 4306.44it/s]
100%|██████████| 90700/90700 [00:29<00:00, 3073.58it/s]
100%|██████████| 90700/90700 [01:17<00:00, 1175.41it/s]
100%|██████████| 90700/90700 [00:00<00:00, 155842.18it/s]


Unnamed: 0,trial,algorithm,split,explainer,cid,smiles,target ID,experimental,prediction,expected_value,...,dPot,shap_values,fingerprint,conf_expected,Set,mmp_trial,mol,bit_info,column_mapped_shap,HA
612,2,SVR,Stratified,SVETA,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,7.792598,6.780393,...,0.875061,"[0.0, -0.042315349732468664, 0.125772584619251...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.792598,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260376...,"{56: ((14, 2),), 184: ((17, 2),), 191: ((12, 1...","[0.6138131423114661, 0.30262942439721496, 0.02...",27
253,1,SVR,Stratified,SVETA,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,8.059602,6.605561,...,0.875061,"[0.0, 0.0049452889600290805, 0.026442399419653...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",8.059602,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260376...,"{56: ((14, 2),), 184: ((17, 2),), 191: ((12, 1...","[0.5930084798165375, 0.30460524251568305, -0.0...",27
686,4,SVR,Stratified,SVETA,CHEMBL1980297,Nc1ncc(-c2cnn(CCO)c2)c2scc(-c3ccc(NC(=O)Nc4ccc...,279,8.698970,7.513292,6.658732,...,0.875061,"[0.0, -0.041108049796584355, -0.03691504541052...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",7.513292,Test,0,<rdkit.Chem.rdchem.Mol object at 0x0000025B395...,"{10: ((13, 2),), 43: ((7, 1),), 56: ((22, 2),)...","[0.35222626261260337, -0.07853901646811377, -0...",35
575,6,RFR,Stratified,TreeExplainer,CHEMBL231209,Nc1nccc2scc(-c3ccc(NC(=O)Nc4cccc(F)c4)cc3)c12,279,7.823909,7.511273,7.162679,...,0.875061,"[0.0, 0.0006084641881898278, -0.00030742519273...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.511273,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260376...,"{56: ((14, 2),), 184: ((17, 2),), 191: ((12, 1...","[0.053600748505578444, 0.03979514973030453, 0....",27
686,4,RFR,Stratified,TreeExplainer,CHEMBL1980297,Nc1ncc(-c2cnn(CCO)c2)c2scc(-c3ccc(NC(=O)Nc4ccc...,279,8.698970,7.591821,6.987765,...,0.875061,"[0.0, -0.0005145513361640042, -3.7467281072167...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",7.591822,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260376...,"{10: ((13, 2),), 43: ((7, 1),), 56: ((22, 2),)...","[0.08970334490605274, 0.03934064787168397, 0.0...",35
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64,6,SVR,Stratified,SVETA,CHEMBL1241427_CHEMBL431784_CHEMBL99414,CNC(=O)Oc1ccc2c(c1)C1(C)CCN(C)C1N2,220,7.246417,6.358522,6.202381,...,1.644357,"[0.0, 0.025880837740166147, -0.060095810589210...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",6.358522,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260354...,"{74: ((7, 2),), 120: ((18, 2),), 301: ((18, 1)...","[0.44120750893578875, 0.5744055703769274, 0.76...",19
51,2,RFR,Stratified,TreeExplainer,CHEMBL564986,O=C(Cc1ccc(C(F)(F)F)cc1)Nc1ccc(F)c(OCCCN2CCOCC...,2409,6.806875,7.304029,7.787603,...,1.373581,"[0.0, 0.0037664150739146864, 1.448210125090554...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",7.304028,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260354...,"{13: ((21, 1),), 41: ((1, 1),), 80: ((2, 0), (...","[-0.028324448300812773, -0.05160956879102481, ...",31
308,0,RFR,Stratified,TreeExplainer,CHEMBL3110016,O=C(NO)C1C(c2ccccc2)C1c1cccc(-c2ncc(F)cn2)c1,1865,5.275724,5.967153,6.960280,...,0.262496,"[0.0, -0.003005826466542203, -0.00214813620201...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",5.967153,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260354...,"{68: ((21, 2),), 335: ((4, 1),), 378: ((19, 0)...","[-0.04223952263293593, -0.0718805097932884, -0...",26
326,3,RFR,Stratified,TreeExplainer,CHEMBL4572443,COc1cc(Nc2ncnc3cc(-c4ccc(Br)cc4)sc23)cc(OC)c1OC,203,6.486782,7.381853,7.703351,...,1.151490,"[0.0, -0.0020277417244506067, -0.0003055260831...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",7.381853,Test,0,<rdkit.Chem.rdchem.Mol object at 0x00000260354...,"{62: ((12, 1),), 131: ((21, 1),), 162: ((20, 1...","[-0.003716929056370039, -0.006848392501659079,...",29


# Free memory

In [17]:
del df_shap_loaded_all

# Drawing drawing parameters

In [18]:
%autoreload
font = utils_pil.get_font(75)
font_bold = utils_pil.get_font(75, utils_pil.DEFAULT_FONT_BOLD)
mol_size = (1500, 1000)

In [71]:
from tqdm import tqdm

def get_compound_shap_map(df_shap_sp, target_id=None, trial=1, mmp_path_file = mmp_path +'mmp_results/top10/'):
    
    # generate images for each mmp compound
    df_shap_l = []

    df_target = df_shap_sp.loc[df_shap_sp['target ID'] == target_id]
    
    for trial in range(trial):
        
        df_target_trial = df_target.query('trial == @trial')
        
        for mmid in tqdm(df_target_trial.mmp_id.unique()[:]):
            
            df_shap_mmp = df_target_trial.query(f'mmp_id == {mmid}').sort_values(by='HA', ascending=True)
            
            mmp_core = get_mmp_core(df_shap_mmp['target ID'].unique()[0], df_shap_mmp.analog_series_id.unique()[0], df_shap_mmp.cid.values[0], mmp_path_file)
            
            df_shap_mmp['Mapping img'] = df_shap_mmp.apply(lambda x: get_atom_wise_weight_map_ref_mol(m=x.mol, weights=x.column_mapped_shap, mol_size=mol_size, 
                                                                                                      ref_mol=df_shap_mmp.head(1).mol.values[0],
                                                                                                      highlightsub=mmp_core, cmap=None), axis=1)
            df_shap_l.append(df_shap_mmp)

    df_shap_final = pd.concat(df_shap_l)
    df_shap_final.rename(columns={'train_test': 'Set'}, inplace=True)
    df_shap_final['Set'] = df_shap_final['Set'].str.replace('test', 'Test')
    del df_shap_l
    
    return df_shap_final

def get_shap_map(df_shap_final, prediction_type = 'correct', correct_thr=0.5, incorrect_thr=1.5):
    
    if prediction_type == 'correct':
        df_shap_final_ = df_shap_final.loc[(df_shap_final['mae'] < correct_thr)]

    else:
        df_shap_final_ = df_shap_final.loc[(df_shap_final['mae'] > incorrect_thr)]
        
    #display(df_shap_final_)
    df_img = df_shap_final_[['target ID', 'algorithm', 'explainer', 'Mapping img', 'prediction', 'experimental', 'Set', 'cid', 'mmp_id', 'trial']].copy()
    df_img['Prediction'] = df_img['prediction'].apply(lambda x: round(x, 1))
    df_img['Experimental'] =  df_img['experimental'].apply(lambda x: str(round(x, 1)))
    #display(df_img)
    del df_shap_final_

    column_img = 'img'
    column_img_title = 'Title image pred'
    column_img_title_ = 'Title image exp'

    df_img[column_img_title] = df_img['Prediction'].apply(lambda x: utils_pil.get_text_image(f'Prediction: {x}', font=font, width_spacing = 0, height_spacing= 10))
    df_img[column_img_title_] = df_img['Experimental'].apply(lambda x: utils_pil.get_text_image(f'Experimental:  {x}', font=font, width_spacing = 40, height_spacing= 10))

    df_img[column_img] = df_img[['Title image exp', 'Title image pred' ,'Mapping img']].apply(lambda x: utils_pil.get_grid_image(x.to_numpy().reshape((3,1))), axis=1)
    df_img.reset_index(inplace=True)
    
    final_img = []
    for mmidx in df_img.mmp_id.unique()[:]:
        df_img_target_mmp =  df_img.query(f'mmp_id == {mmidx}')
        if len(df_img_target_mmp) == 2:
            final_img.append(df_img_target_mmp)
    
    del df_img
    
    if len(final_img) > 0:
        df_final_img = pd.concat(final_img)
        df_final_img.rename(columns={'algorithm': 'Algorithm'}, inplace=True)
        #display(df_final_img)
        return df_final_img
    else:
        return None

def get_cpd_shap_img(df_image, prediction_type = 'correct', target_id=None, trial=1, max_mmp_images=10):
    
    if df_image is None:
        print('No images to generate')
    else:
        for trl in range(trial):
            df_img_target_trial = df_image.query('trial == @trl')

            for i, mmp_idx in enumerate(df_img_target_trial.mmp_id.unique()[:]):
                if i > max_mmp_images:
                    print('Max images reached')
                    break
                else:
                    print('Generating image for mmp_id: ', mmp_idx)
                    df_img_mmp = df_img_target_trial.query('mmp_id == @mmp_idx')
                    for algorithm in ['SVR', 'RFR']:
                        df_img_c = df_img_mmp.query('Algorithm == @algorithm')
                        
                        df_img_c = df_img_c.sort_values(by=['Experimental'], ascending=True)
            
                        df_img_pivot = df_img_c.pivot(index='mmp_id', columns=['Algorithm', 'Experimental', ], values='img')
                        
                        df_img_pivot_t = df_img_pivot.T.sort_values(by=['Experimental'], ascending=False)
                        
                        df_img_pivot_t['Algorithm'] = df_img_pivot.apply(
                            lambda x: utils_pil.get_text_image(str(x.name[0]), font=font_bold), axis=0)
                        df_img_pivot = df_img_pivot_t.T
            
                        del df_img_pivot_t
            
                        df_img_pivot = df_img_pivot.loc[['Algorithm', df_img_pivot.index[0], ]]
            
                        img = utils_pil.get_grid_image(df_img_pivot.to_numpy())
            
                        img.save(os.path.join(create_directory(f'./shap_map_{prediction_type}_predictions/{target_id}/{split}/'),
                                              f'{mmp_idx}_{algorithm}_{trial}.jpg'), quality=100, subsampling=0)

In [72]:
for t in df_shap_split['target ID'].unique()[:1]:
    
    df_shap_map = get_compound_shap_map(df_shap_split, target_id=t, trial=1)
    
    for pred_type in ['incorrect',]: #'correct'
        map_predtype = get_shap_map(df_shap_map, prediction_type=pred_type)
        shap_img  = get_cpd_shap_img(map_predtype, prediction_type=pred_type, target_id=t, trial=1, max_mmp_images=10)
        del map_predtype, shap_img
    del df_shap_map

100%|██████████| 713/713 [02:50<00:00,  4.17it/s]


Generating image for mmp_id:  709
Generating image for mmp_id:  701
Generating image for mmp_id:  530
Generating image for mmp_id:  358
Generating image for mmp_id:  314
Generating image for mmp_id:  234
Generating image for mmp_id:  220
Generating image for mmp_id:  196
Generating image for mmp_id:  191
Generating image for mmp_id:  178
Generating image for mmp_id:  140
Max images reached


# Generate Colorbar

In [None]:
get_color_bar(result_path='./figures/', cmap=None, ticklabels=['Negative', 0, 'Positive'], label='SV / SHAP values', font_size=20)