In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Global plotting configuration for the notebook.
# seaborn theme: 'talk' context, 'ticks' style, seaborn colour codes enabled,
# and legends drawn without a frame.
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib
# Figure resolution, font family and base font size applied to later figures.
matplotlib.rcParams['figure.dpi'] = 300
plt.rcParams["font.family"] = 'Arial'
plt.rcParams.update({'font.size': 20})
from rdkit import Chem
import numpy as np

In [None]:
# Read the SHAP table, add a 'can_smiles' column by round-tripping every
# 'Canonical_SMILES' entry through RDKit, then order the rows by that column.
df = pd.read_csv('CN_shap_250213.csv')
df['can_smiles'] = df['Canonical_SMILES'].map(
    lambda smiles: Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
)
df = df.sort_values(by='can_smiles')

In [None]:
# Turn each 'atomwise_shap' text cell into a NumPy float array: drop newlines,
# strip the first and last character (the enclosing brackets), split on
# whitespace and convert every token to float.
Shap_contrib = [
    np.array([float(token) for token in text.replace('\n', '')[1:-1].split()])
    for text in df['atomwise_shap']
]
df['Shap_contrib'] = Shap_contrib

In [None]:
# Smallest and largest per-atom SHAP contribution of every molecule.
df['Shap_min'] = df['Shap_contrib'].map(min)
df['Shap_max'] = df['Shap_contrib'].map(max)

In [None]:
# Order rows by their lowest atomic contribution (ascending) and preview three.
df = df.sort_values('Shap_min')
df.head(3)

In [None]:
# Order rows by their highest atomic contribution (descending) and preview three.
df = df.sort_values('Shap_max', ascending=False)
df.head(3)

In [None]:
# Overall (lowest, highest) atomic SHAP contribution across all molecules.
df['Shap_min'].min(), df['Shap_max'].max()

In [None]:
# Display the DataFrame as this cell's output.
df

In [None]:
from IPython.display import display, SVG, HTML
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D


# Module-level colour scale read by get_color(): values normalised over
# [0, 20] and mapped through the RdBu colormap.
cmap = cm.RdBu
norm = mpl.colors.Normalize(vmin=0, vmax=20)
m = cm.ScalarMappable(cmap=cmap, norm=norm)

#fig = plt.figure()
#ax = fig.add_subplot(111)
#ax.set_yticks([])
#plt.colorbar(m, ax=[ax], location='left')

def rgb2hex(r, g, b):
    """Format three integer colour channels as a ``#rrggbb`` string.

    Each channel is written as lowercase hexadecimal, zero-padded to at
    least two digits.
    """
    return f"#{r:02x}{g:02x}{b:02x}"

def get_color(x):
    """Return the colour that the module-level mappable ``m`` assigns to ``x``.

    ``x`` is converted with ``float`` before the lookup, so anything
    ``float()`` accepts (numbers, numeric strings) is allowed. The result is
    whatever ``m.to_rgba`` returns for that float.
    """
    value = float(x)
    return m.to_rgba(value)

def draw_mol_svg(mol_str, color_dict, prop, figsize=(600, 600)):
    """Draw a molecule as SVG text, optionally highlighting every atom.

    Parameters
    ----------
    mol_str : str
        SMILES string of the molecule to draw.
    color_dict : dict or None
        Per-atom colours handed to RDKit as ``highlightAtomColors``. It must
        hold exactly one entry per atom; atoms ``0 .. len(color_dict) - 1``
        are highlighted, so keys are expected to be those atom indices.
        ``None`` draws the molecule without any highlighting.
    prop : sequence
        Not used by the current implementation; kept so existing callers
        (which pass the per-atom values) keep working.
    figsize : tuple of int, optional
        ``(width, height)`` passed to ``MolDraw2DSVG``.

    Returns
    -------
    str
        The SVG markup with ``svg:`` / ``:svg`` namespace markers removed.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``mol_str``.
    AssertionError
        If ``color_dict`` is given and its size differs from the atom count.
    """
    mol = Chem.MolFromSmiles(mol_str)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # clear message instead of an AttributeError further down.
        raise ValueError("Could not parse SMILES: {!r}".format(mol_str))

    # Draw a kekulized copy with freshly computed 2D coordinates so the
    # molecule parsed above is left untouched.
    mc = Chem.Mol(mol.ToBinary())
    Chem.Kekulize(mc)
    rdDepictor.Compute2DCoords(mc)

    drawer = rdMolDraw2D.MolDraw2DSVG(*figsize)
    drawer.SetFontSize(40)

    if color_dict is not None:
        # The size check only makes sense when colours were supplied; doing it
        # before the None test made the un-highlighted branch unreachable.
        n_atoms = mol.GetNumAtoms()
        assert n_atoms == len(color_dict), "{} atoms in mol, {} colors".format(n_atoms, len(color_dict))

        drawer.DrawMolecule(
            mc, highlightAtoms=range(len(color_dict)),
            highlightAtomColors=color_dict,
            highlightBonds=False)
    else:
        drawer.DrawMolecule(mc)

    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    # Strip the XML namespace markers from the generated markup.
    svg = svg.replace('svg:', '').replace(':svg', '')
    return svg

def draw_shap(row):
    """Render one molecule with its atoms coloured by SHAP contribution.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide ``'Shap_contrib'`` (one value per atom),
        ``'Canonical_SMILES'``, ``'CN'`` and ``'predicted'``.

    Returns
    -------
    IPython.display.SVG
        The drawing produced by ``draw_mol_svg`` with a
        "Measured / Predicted" text line added.

    Notes
    -----
    The colour scale is specific to the row: it spans from the smallest
    contribution minus 20 % of its magnitude to the largest contribution plus
    20 % of its magnitude, mapped through the RdBu colormap. The min and max
    contributions are also printed.
    """
    contrib = row['Shap_contrib']
    shap_lo, shap_hi = min(contrib), max(contrib)

    norm = mpl.colors.Normalize(
        vmin=shap_lo - (abs(shap_lo) * 0.2),
        vmax=shap_hi + (abs(shap_hi) * 0.2),
    )
    mappable = cm.ScalarMappable(norm=norm, cmap=cm.RdBu)

    print('min_shap: ', shap_lo, 'max_shap:', shap_hi)
    svg = draw_mol_svg(
        row['Canonical_SMILES'],
        {i: mappable.to_rgba(value) for i, value in enumerate(contrib)},
        contrib,
    )

    # Insert the caption two entries from the end of the line list --
    # presumably just before the closing </svg> tag; verify against the
    # markup RDKit emits.
    svg_lines = svg.split('\n')
    svg_lines.insert(-2, f'<text x="0" y="600">Measured: {row["CN"]:.1f}, Predicted: {row["predicted"]:.1f}</text>')
    # The original line lacked its closing parenthesis (SyntaxError).
    return SVG('\n'.join(svg_lines))

In [None]:
# Molecule to inspect. Canonicalise it once up front rather than on every row.
# Other molecules examined with this cell:
#   'CCCC1=CC=C(C=C1)OC', 'CCCCC=CC', 'COc1ccccc1OC'
target_smiles = Chem.MolToSmiles(Chem.MolFromSmiles('CCCCCCCCCCCCCC(=O)OCC'))

for _, row in df.iterrows():
    if row['can_smiles'] == target_smiles:
        display(draw_shap(row))
        print(row['Canonical_SMILES'])
        break
else:
    # Make a missing molecule visible instead of silently producing no output.
    print(f"No row with can_smiles == {target_smiles!r}")

In [None]:
import cairosvg
import matplotlib as mpl
import matplotlib.cm as cm
from IPython.display import SVG
import io

def draw_shap(row, filename="output.png"):
    """Draw a molecule coloured by SHAP contribution, save it as PNG and return it.

    NOTE: this definition shadows the earlier single-argument ``draw_shap``
    in this notebook. Unlike that version, no "Measured / Predicted" caption
    is added to the drawing.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide ``'Shap_contrib'`` (one value per atom) and
        ``'Canonical_SMILES'``.
    filename : str, optional
        Path of the PNG file to write.

    Returns
    -------
    IPython.display.SVG
        The same drawing as an SVG object, so ``display(draw_shap(...))``
        shows the molecule instead of ``None``.

    Notes
    -----
    The colour scale is specific to the row: smallest contribution minus
    20 % of its magnitude up to the largest contribution plus 20 % of its
    magnitude, mapped through the RdBu colormap.
    """
    contrib = row['Shap_contrib']
    shap_lo, shap_hi = min(contrib), max(contrib)

    norm = mpl.colors.Normalize(
        vmin=shap_lo - (abs(shap_lo) * 0.2),
        vmax=shap_hi + (abs(shap_hi) * 0.2),
    )
    mappable = cm.ScalarMappable(norm=norm, cmap=cm.RdBu)

    print('min_shap: ', shap_lo, 'max_shap:', shap_hi)

    svg = draw_mol_svg(
        row['Canonical_SMILES'],
        {i: mappable.to_rgba(value) for i, value in enumerate(contrib)},
        contrib,
    )

    # Rasterise the SVG markup and write the PNG bytes to disk.
    png_bytes = cairosvg.svg2png(bytestring=svg.encode('utf-8'), dpi=300)
    with open(filename, "wb") as f:
        f.write(png_bytes)

    print(f"Saved: {filename}")
    # Previously nothing was returned, so callers wrapping this in display()
    # only ever showed "None".
    return SVG(svg)

# Example usage: draw_shap(row, filename="shap_output.png")

In [None]:
# Molecule to export. Canonicalise it once up front rather than on every row.
# Other molecules examined with this cell:
#   'CCCC1=CC=C(C=C1)O', 'CCCC=C(CCCC)CCCCCCC', 'CCCCCCCC(=O)OCCCC',
#   'CCCCCCCCCCCCCC(=O)OCC'
target_smiles = Chem.MolToSmiles(Chem.MolFromSmiles('CC(C)CCOC(=O)C(C)O'))

for _, row in df.iterrows():
    if row['can_smiles'] == target_smiles:
        rendered = draw_shap(row, filename="2.png")
        # Only display when the drawing function actually returns something;
        # display(None) would just print "None".
        if rendered is not None:
            display(rendered)
        print(row['Canonical_SMILES'])
        break
else:
    # Make a missing molecule visible instead of silently producing no output.
    print(f"No row with can_smiles == {target_smiles!r}")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from rdkit import Chem
from rdkit.Chem import Draw

def draw_shap_with_narrow_colorbar(row):
    """Show a molecule image next to a slim, tick-less SHAP colour bar.

    ``row`` must provide ``'Shap_contrib'`` (one value per atom) and
    ``'Canonical_SMILES'``. The colour scale runs from the smallest
    contribution minus 20 % of its magnitude to the largest contribution plus
    20 % of its magnitude on the RdBu colormap. The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    contrib = row['Shap_contrib']
    shap_lo, shap_hi = min(contrib), max(contrib)
    scale = mcolors.Normalize(
        vmin=shap_lo - (abs(shap_lo) * 0.2),
        vmax=shap_hi + (abs(shap_hi) * 0.2),
    )
    sm = cm.ScalarMappable(cmap=cm.RdBu, norm=scale)
    sm.set_array([])

    # One RGBA colour per atom, keyed by position in 'Shap_contrib'.
    atom_colors = {idx: sm.to_rgba(value) for idx, value in enumerate(contrib)}

    mol = Chem.MolFromSmiles(row['Canonical_SMILES'])
    # NOTE(review): some RDKit versions may ignore highlightAtomColors in
    # Draw.MolToImage -- confirm the per-atom colours actually appear.
    img = Draw.MolToImage(
        mol,
        size=(300, 300),
        highlightAtoms=list(atom_colors.keys()),
        highlightAtomColors=atom_colors,
    )

    ax.imshow(img)
    ax.axis("off")

    # Dedicated narrow axes on the right-hand side for the colour bar.
    cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
    cbar = plt.colorbar(sm, cax=cbar_ax)
    cbar.set_ticks([])
    cbar.set_label("SHAP Contribution", fontsize=12)

    plt.show()

# draw_shap_with_narrow_colorbar(row)

In [None]:
# Molecule to plot. Canonicalise it once up front rather than on every row.
# Other molecules examined with this cell:
#   'CCCC1=CC=C(C=C1)O', 'CCCC=C(CCCC)CCCCCCC', 'CC(C)CCOC(=O)C(C)O',
#   'CCCCCCCCCCCCCC(=O)OCC'
target_smiles = Chem.MolToSmiles(Chem.MolFromSmiles('CCCCCCCC(=O)OCCCC'))

for _, row in df.iterrows():
    if row['can_smiles'] == target_smiles:
        # The function shows the figure itself and returns None, so wrapping
        # it in display() only added a stray "None" to the output.
        draw_shap_with_narrow_colorbar(row)
        print(row['Canonical_SMILES'])
        break
else:
    # Make a missing molecule visible instead of silently producing no output.
    print(f"No row with can_smiles == {target_smiles!r}")