# Plots for this paper

This file provides instructions for two areas of feature attribution comparison:
- Global direction scores
- Atom-level coloring accuracy


## Global direction comparison

In [None]:
import dill
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker
from IPython.display import SVG
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import wilcoxon

In [None]:
# GradInput: global direction scores plotted as "w/o group lasso" vs.
# "w. group lasso" for target 1O42-843.
# Reads the two score columns, runs a paired Wilcoxon signed-rank test on
# them, and plots both series with the mean relative change annotated.
global_dir_gradinput_1O42_843 = pd.read_csv(
    "./global_dir_compare/gradinput/global_dir_gradinput_1O42-843.csv"
)
data_x = global_dir_gradinput_1O42_843["global_dir_test_MSE_L_gradinput_1O42-843"]
data_y = global_dir_gradinput_1O42_843["global_dir_test_MSE_N_GL_gradinput_1O42-843"]

# Paired test between the two score columns.
stat, p = wilcoxon(data_x, data_y)

# Mean of the per-row relative change from data_x to data_y, in percent.
relative_change = (data_y - data_x) / data_x
change_xy = sum(relative_change) / len(data_x) * 100

# x positions 50, 55, ..., 95 (ten values); the columns are plotted against
# these, so each column is expected to hold ten rows.
z_values = np.arange(50, 100, 5)

# (values, marker color, line format, legend label) for each series.
series = (
    (data_x, 'blue', 'b--', 'w/o group lasso'),
    (data_y, 'red', 'r-', 'w. group lasso'),
)

plt.figure(figsize=(10, 6))
for values, marker_color, _, legend_label in series:
    plt.scatter(z_values, values, color=marker_color, label=legend_label)
for values, _, line_format, _ in series:
    plt.plot(z_values, values, line_format, alpha=0.5)

plt.title('Global direction (w. vs. w/o group lasso) across different minimum MCS %', fontsize=14)
plt.suptitle("GradInput", fontsize=16)
plt.xlabel('Minimum shared MCS %', fontsize=14)
plt.ylabel('Global direction scores', fontsize=14)

# Summary box placed at fixed data coordinates (x=72, y=0.51).
text_string = (f"Averaged increase: {round(change_xy, 2)}%\n"
               f"p-values (Wilcoxon): {str(round(p, 4))}")
box_style = {
    "facecolor": 'white',
    "alpha": 0.5,
    "edgecolor": 'black',
    "boxstyle": 'round,pad=0.5',
}
plt.text(72, 0.51, text_string, fontsize=12, color='green', bbox=box_style)

plt.legend()
plt.grid(True)
plt.show()

## Plot atom-level feature attribution

In [None]:
def SVGplot(hetero_data, iorj, smile, pair_id, mask_data, mask_pair, mask_id):
    """Draw one molecule as SVG with every atom highlighted by its mask value.

    Args:
        hetero_data: Nested container; the SMILES string is read from
            ``hetero_data[pair_id][iorj][smile]``.
        iorj: Key selecting which entry of the pair to draw.
        smile: Key under which the SMILES string is stored.
        pair_id: Index/key of the pair inside ``hetero_data``.
        mask_data: Nested container; the per-atom values are read from
            ``mask_data[mask_pair][mask_id]``.
        mask_pair: Index/key of the pair inside ``mask_data``.
        mask_id: Index/key of the mask vector within that pair.

    Returns:
        The SVG text of a 400x400 drawing of the molecule.

    Raises:
        ValueError: If RDKit cannot parse the SMILES string.
    """
    smiles = hetero_data[pair_id][iorj][smile]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a clear message instead of an AttributeError further down.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    num_atoms = mol.GetNumAtoms()

    # The mask is expected to hold one value per atom, indexed like the
    # RDKit atom indices (not checked here).
    mask_vector = mask_data[mask_pair][mask_id]

    # value_to_color is defined outside this block; it maps one mask value to
    # the color used for that atom's highlight.
    atom_colors = {i: value_to_color(mask_vector[i]) for i in range(len(mask_vector))}
    highlight_radii = dict.fromkeys(range(num_atoms), 0.5)

    drawer = rdMolDraw2D.MolDraw2DSVG(400, 400)
    opts = drawer.drawOptions()

    # Label every atom with its element symbol.
    for i in range(num_atoms):
        opts.atomLabels[i] = mol.GetAtomWithIdx(i).GetSymbol()

    # Highlight all atoms, each with its own color and a fixed radius.
    drawer.DrawMolecule(
        mol,
        highlightAtoms=range(num_atoms),
        highlightAtomColors=atom_colors,
        highlightAtomRadii=highlight_radii,
    )
    drawer.FinishDrawing()
    return drawer.GetDrawingText()

In [None]:
def create_color_map_legend():
    """Save a horizontal colorbar legend for the attribution colors.

    Builds a 'coolwarm' colorbar normalized to the range [-1, 1], labels it
    'Feature Attribution', and writes the figure to 'color_map_legend.svg'
    in the current working directory.
    """
    figure, colorbar_axis = plt.subplots(figsize=(6, 1))
    # Raise the bottom margin to half of the figure height.
    figure.subplots_adjust(bottom=0.5)

    # Mappable that ties the [-1, 1] normalization to the colormap.
    attribution_scale = plt.cm.ScalarMappable(
        norm=mcolors.Normalize(vmin=-1, vmax=1),
        cmap=plt.get_cmap('coolwarm'),
    )

    # Draw the colorbar into the single axis created above.
    legend_bar = plt.colorbar(attribution_scale, cax=colorbar_axis, orientation='horizontal')
    legend_bar.set_label('Feature Attribution')

    plt.savefig('color_map_legend.svg', format='svg')


create_color_map_legend()

In [None]:
# Load the saved mask values (path: ig / w_group_lasso) and the test-split
# data file for 1O42-843, then render one molecule from them.
# NOTE: dill.load can execute arbitrary code while unpickling; only open
# files from a trusted source.
mask_path = './colors/ig/1O42-843/w_group_lasso/1O42-843_seed_1337_nn_MSE+N_mean_ig_hiddenDim_32_GL_0.001_test.pt'
# NOTE(review): this path starts with '.code/' while the other paths in this
# notebook start with './' -- confirm it is not meant to be './code/'.
hetero_path = '.code/data/1O42-843/1O42-843_seed_1337_test.pt'

with open(mask_path, 'rb') as mask_file:
    mask_data = dill.load(mask_file)
with open(hetero_path, 'rb') as hetero_file:
    hetero_data = dill.load(hetero_file)

from IPython.display import SVG

# Draw the "data_j" molecule of pair 0, colored with mask_data[0][1].
svgplot = SVGplot(hetero_data, "data_j", "smiles", 0, mask_data, 0, 1)
SVG(svgplot)