# Functions and Imports (no user input)

In [None]:
import os
import pandas as pd
import numpy as np
import openpyxl

from scipy import stats
from scipy.interpolate import interp1d

from sklearn import linear_model
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler, QuantileTransformer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from yellowbrick.cluster import KElbowVisualizer

from datetime import date

import holoviews as hv
import bokeh
from bokeh.models import HoverTool
hv.extension('bokeh')

from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D

from IPython.display import SVG

# For alternative fonts:
# import matplotlib.font_manager
# fonts = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
# font_names = [matplotlib.font_manager.get_font(f).family_name for f in fonts]
# print(font_names)


In [None]:
"""
--------------------------------------------------------------------------------------------------------------------------------
    Image generation from SMILES strings
--------------------------------------------------------------------------------------------------------------------------------
"""

# adapted from https://birdlet.github.io/2018/06/06/rdkit_svg_web/
def DrawMol(dataframe, smiles_column_loc, image_column, id_column='Name', molSize=(200, 100), kekulize=True):
    """
    Render the SMILES string of every row as an SVG and store it in a new column.

    Parameters:
        dataframe: pandas DataFrame holding one molecule per row (modified in place).
        smiles_column_loc: 0-indexed position of the column that contains the SMILES strings.
        image_column: name of the column that receives the SVG text.
        id_column: column used to identify a row in the progress/error messages.
        molSize: (width, height) of the drawing passed to the RDKit SVG drawer.
        kekulize: if True, try to kekulize the molecule before drawing.

    Returns:
        The same dataframe, with `image_column` inserted directly after the SMILES column
        (or overwritten if a column with that name already exists).

    Rows with a blank/NaN SMILES string, or whose molecule cannot be drawn, receive None so
    that the image list always stays aligned with the dataframe rows.
    """
    images = []
    for _, row in dataframe.iterrows():
        smiles_string = row.iloc[smiles_column_loc]  # position-based lookup

        # Blank or NaN SMILES: keep a placeholder so the row count still matches
        if pd.isnull(smiles_string) or smiles_string == '':
            print(f"Skipping ID: {row[id_column]} due to blank or NaN SMILES")
            images.append(None)
            continue

        try:
            mc = Chem.MolFromSmiles(smiles_string)
            if mc is None:
                raise ValueError('SMILES string could not be parsed')

            if kekulize:
                try:
                    Chem.Kekulize(mc)
                except Exception:
                    # Kekulization failed: start again from a fresh, un-kekulized molecule
                    mc = Chem.MolFromSmiles(smiles_string)

            if not mc.GetNumConformers():
                Chem.rdDepictor.Compute2DCoords(mc)

            drawer = rdMolDraw2D.MolDraw2DSVG(*molSize)
            drawer.DrawMolecule(mc)
            drawer.FinishDrawing()
            svg = drawer.GetDrawingText().replace('svg:', '')
            images.append(SVG(svg).data)
        except Exception:
            print(f"Error for ID: {row[id_column]} with SMILES: {smiles_string}")
            images.append(None)

    try:
        dataframe.insert(smiles_column_loc+1, image_column, images)
    except ValueError:
        # insert() refuses to create a duplicate column name; overwrite the existing column instead
        dataframe[image_column] = images

    return dataframe



"""
--------------------------------------------------------------------------------------------------------------------------------
    Allow for user to input both name (string) or index number for column headers
--------------------------------------------------------------------------------------------------------------------------------
"""

def get_column_loc(column, dataframe):
    """
    Resolve a column reference into a (position, name) pair.

    `column` may be a column name (string) or a 0-indexed position; the missing half of the
    pair is looked up in `dataframe.columns`.
    """
    given_by_name = isinstance(column, str)
    position = dataframe.columns.get_loc(column) if given_by_name else column
    name = column if given_by_name else dataframe.columns[column]
    return position, name
    



"""
--------------------------------------------------------------------------------------------------------------------------------
    HoloViews: Scatter Plot
--------------------------------------------------------------------------------------------------------------------------------
"""
def scatter_plot(
    dataframe: pd.DataFrame, 
    x: str,  # x-axis data
    y: str,  # y-axis data
    title: str = 'default',  # title of plot
    x_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    x_range: tuple = None,  # range of x-axis
    y_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    y_range: tuple = None,  # range of y-axis
    legend: str = '',  # string with data label if using classifiers/building plots by category
    svgs: str = None,  # string with column name of svgs 
    hover_list: list = None,  # list of column names with data to be shown on hover 
    marker: str = 'o',  # marker type - most of the matplotlib markers are supported (https://matplotlib.org/stable/api/markers_api.html)
    bubbleplot: bool = False,  # if True, will create a bubble plot
    size: int = 10,  # size of markers (recommended: 10-20)
    bubblesize: str = None,  # string with column name for size of points in bubbleplot
    heatmap: bool = False,  # if True, will create a heatmap
    heatmap_col: str = '',  # string with column name used to color points in heatmap
    clabel: str = 'default', # label for heatmap colorbar
    heatmap_color: str = 'Plasma',  # colormap of heatmap
    color: str = '#931319',  # color of markers (name of a color cycle/palette when groupby is used)
    line_color: str = '#29323d',  # color of marker outline
    alpha: float = 1,  # transparency of markers
    groupby: str = None,  # string with column name to group data by
    height: int = 500,  #plot height (recommended: 500)
    width: int = 500,  #plot width (recommended: 500)
):
    
    """
    scatter_plot function based off of HoloViews 'Scatter' element. See documentation for more information:
    hv.help(hv.Scatter)
    https://holoviews.org/reference/elements/bokeh/Scatter.html

    Builds one hv.Scatter of `y` against `x` and returns it. Optional layers:
      - hover: if `svgs`, `hover_list` or `groupby` is given, a Bokeh HoverTool is attached that shows
        the svg (if any) followed by the hover columns;
      - heatmap: points are colored by `heatmap_col` using the `heatmap_color` colormap;
      - bubbleplot: points are sized by the `bubblesize` column;
      - groupby: points are colored by the `groupby` column using the color cycle named by `color`
        (this takes precedence over the heatmap coloring).

    Axis limits: an explicitly supplied `x_range` / `y_range` is always applied. When none is supplied,
    limits padded by 10% of the data span are applied only for the combined heatmap + bubble plot;
    otherwise the axes are left to auto-range.

    The caller's `hover_list` is never modified.

    Raises:
        ValueError: if `heatmap` is requested without `heatmap_col`, or `bubbleplot` without `bubblesize`.
    """

    if heatmap and not heatmap_col:
        raise ValueError("heatmap=True requires 'heatmap_col' to name a dataframe column")
    if bubbleplot and not bubblesize:
        raise ValueError("bubbleplot=True requires 'bubblesize' to name a dataframe column")

    if x_label == 'default':  # if no x_label provided, use x column name
        x_label = x
    if y_label == 'default':  # if no y_label provided, use y column name
        y_label = y
    if clabel == 'default':  # if no clabel provided, use heatmap_col column name
        clabel = heatmap_col

    show_hover = svgs is not None or hover_list is not None or groupby is not None

    # Work on a copy so the list passed in by the caller is left untouched
    extra_columns = list(hover_list) if hover_list is not None else []
    palette = None
    if groupby is not None:
        palette = hv.Cycle(color).values
        color = palette[0]  # placeholder only; replaced by the per-group coloring below
        extra_columns.insert(0, groupby)
    if svgs is not None:
        extra_columns.insert(0, svgs)
    # y always comes first; duplicates are dropped because a dimension may only appear once
    vdims = list(dict.fromkeys([y, *extra_columns]))

    options = dict(
        marker=marker, height=height, width=width, xlabel=x_label, ylabel=y_label,
        color=color, alpha=alpha, size=size, line_color=line_color,
    )
    default_title = f'{y_label} vs. {x_label}'

    if show_hover:
        # Tooltips are built before the heatmap/bubble columns are appended, so those are not listed
        options['tools'] = [HoverTool(tooltips=_scatter_tooltips(vdims, y, svgs))]

    if heatmap:
        if heatmap_col not in vdims:
            vdims.append(heatmap_col)
        options.update(color=heatmap_col, cmap=heatmap_color, colorbar=True, clabel=clabel)
        default_title += f', colored by {heatmap_col}'

    if bubbleplot:
        if bubblesize not in vdims:
            vdims.append(bubblesize)
        min_size = dataframe[bubblesize].min()
        max_size = dataframe[bubblesize].max()
        options['size'] = ((hv.dim(bubblesize)-min_size)/max_size)*(10*size)
        default_title += f', sized by {bubblesize}'

    use_padded_limits = heatmap and bubbleplot
    if x_range:
        options['xlim'] = x_range
    elif use_padded_limits:
        options['xlim'] = _padded_range(dataframe[x])
    if y_range:
        options['ylim'] = y_range
    elif use_padded_limits:
        options['ylim'] = _padded_range(dataframe[y])

    options['title'] = default_title if title == 'default' else title

    plot = hv.Scatter(dataframe, kdims=[x], vdims=vdims, label=legend).opts(**options)

    if groupby is not None:
        plot = plot.opts(color=groupby, cmap=palette)

    return plot


def _padded_range(values):
    """Return (min - pad, max + pad) for a pandas Series, where pad is 10% of the data span."""
    low, high = values.min(), values.max()
    pad = abs(high - low) / 10
    return (low - pad, high + pad)


def _scatter_tooltips(columns, y, svgs):
    """Build the HTML hover template: the svg (if any) followed by one line per hover column."""
    pieces = ['<div>']
    if svgs is not None:
        pieces.append(f'<div>@{svgs}{{safe}}</div>')

    few_columns = len(columns) < 4  # few entries: large bold value only; otherwise small 'name: value' rows
    for label in columns:
        if label == svgs or label == y:
            continue
        if few_columns:
            pieces.append(f'<div><span style="font-size: 17px; font-weight: bold;">@{label}</span></div>')
        else:
            pieces.append(f'<div><span style="font-size: 12px;">{label}: @{label}</span></div>')

    pieces.append('</div>')
    return ''.join(pieces)
        



"""
--------------------------------------------------------------------------------------------------------------------------------
    HoloViews: Slope
--------------------------------------------------------------------------------------------------------------------------------
"""

def plot_slope(
    dataframe: pd.DataFrame, 
    x: str,  # string with column name, used to determine slope
    y: str,  # string with column name, used to determine slope
    x_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    y_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    color: str = '#000000',  # color of slope line
    line_width: int = 2,  # width of slope line
    alpha: int = 1,  # transparency of slope line
    height: int = 500,  # plot height (recommended: 500)
    width: int = 500,  # plot width (recommended: 500)
    pad: float = 0.0  # padding around data points
):
    """
    Fit a linear regression of `y` on `x` (scipy.stats.linregress) and draw it as an hv.Slope.

    The line is overlaid with an hv.Bounds box (drawn with color=None) that spans the data
    extent, widened by `pad` on every side.

    Returns:
        (plot, slope, intercept, r_value, p_value, std_err)
    """
    # Fall back to the column names when no axis labels are supplied
    x_label = x if x_label == 'default' else x_label
    y_label = y if y_label == 'default' else y_label

    x_values = dataframe[x]
    y_values = dataframe[y]

    x_min, x_max = x_values.agg(['min', 'max'])
    y_min, y_max = y_values.agg(['min', 'max'])
    extent = (x_min-pad, y_min-pad, x_max+pad, y_max+pad)
    line = hv.Bounds(extent).opts(color=None)

    slope, intercept, r_value, p_value, std_err = stats.linregress(x_values, y_values)
    slope_style = dict(
        xlabel=x_label,
        ylabel=y_label,
        line_color=color,
        line_width=line_width,
        alpha=alpha,
        height=height,
        width=width,
    )
    slope_plt = hv.Slope(slope, intercept).opts(**slope_style)

    plt = hv.Overlay(slope_plt*line)
    return plt, slope, intercept, r_value, p_value, std_err
        



"""
--------------------------------------------------------------------------------------------------------------------------------
    HoloViews: Confidence Interval
--------------------------------------------------------------------------------------------------------------------------------
"""

def plot_confidence_interval(
    dataframe: pd.DataFrame,  # dataframe
    x: str,  # string with column name, used to determine confidence interval
    y: str,  # string with column name, used to determine confidence interval
    x_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    x_range: tuple = None,  # range of x-axis
    y_label: str = 'default',  # axis label to be printed on plot (does not need to match dataframe name)
    y_range: tuple = None,  # range of y-axis
    ci: float = 0.999,  # confidence level (0.9-0.99 recommended)
    color: str = '#5289a1',  # color of confidence interval
    line_color: str = '#FFFFFF',  # color of confidence interval line
    alpha: float = 0.2,  # transparency of confidence interval
    height: int = 500,  #plot height (recommended: 500)
    width: int = 500,  #plot width (recommended: 500)
):
        
    """ 
    Confidence interval calculations use inferences made on the mean and variance of the distributed data (assumes normal distribution)
    and is calculated by applying a student-t test. Plotting function based off of HoloViews 'Area' element as 'area between curves'. 
    See documentation for more information:
    hv.help(hv.Area)
    https://holoviews.org/reference/elements/bokeh/Area.html

    The band drawn is the confidence band of the fitted regression line (mean response):
        y_hat(x) +/- t * s * sqrt(1/n + (x - x_mean)^2 / S_xx)
    evaluated directly on 1000 evenly spaced x values that extend half a data span beyond the data
    on both sides. Rows where `x` or `y` is NaN are ignored.

    Returns:
        hv.Area spanning the lower and upper confidence limits.

    Raises:
        ValueError: if `ci` is not strictly between 0 and 1, or fewer than 3 complete (x, y) pairs are available.
    """  

    if not 0 < ci < 1:
        raise ValueError(f"'ci' must be strictly between 0 and 1, got {ci}")

    if x_label == 'default':  # if no x_label provided, use x column name
        x_label = x
    if y_label == 'default':  # if no y_label provided, use y column name
        y_label = y

    # Keep only complete pairs so that n, the fit and the sums all describe the same rows
    complete = dataframe[x].notna() & dataframe[y].notna()
    x_values = dataframe.loc[complete, x].to_numpy(dtype=float)
    y_values = dataframe.loc[complete, y].to_numpy(dtype=float)

    n = len(x_values)
    if n < 3:
        raise ValueError('At least 3 complete (x, y) pairs are required (n - 2 degrees of freedom)')

    x_low, x_high = x_values.min(), x_values.max()
    x_span = abs(x_high - x_low)
    if not x_range:
        x_range = (x_low - x_span/10, x_high + x_span/10)
    if not y_range:
        y_low, y_high = y_values.min(), y_values.max()
        y_buffer = abs(y_high - y_low)/10
        y_range = (y_low - y_buffer, y_high + y_buffer)

    t_value = stats.t.ppf(1 - (1 - ci) / 2, n - 2)  # two-sided t-value for n-2 degrees of freedom
    x_mean = np.mean(x_values)  # mean of x values

    slope, intercept, r_value, p_value, std_err = stats.linregress(x_values, y_values)

    S_xx = np.sum(np.square(x_values - x_mean))  # sum of squared deviations of x from its mean
    residuals = y_values - (slope * x_values + intercept)
    SSE = np.sum(np.square(residuals))  # sum of squared errors; computed from residuals so it cannot go negative
    s = np.sqrt(SSE / (n - 2))  # residual standard deviation

    # The band is an analytic function of x, so evaluate it directly (no interpolation/extrapolation needed)
    ci_x = np.linspace(x_low - x_span / 2, x_high + x_span / 2, num=1000)  # extends beyond data range
    fitted = slope * ci_x + intercept
    half_width = t_value * s * np.sqrt(1 / n + np.square(ci_x - x_mean) / S_xx)
    ci_upper_y = fitted + half_width
    ci_lower_y = fitted - half_width

    # plot confidence interval
    ci_plt = hv.Area((ci_x, ci_upper_y, ci_lower_y), vdims=['ci_y1', 'ci_y2']).opts(xlabel=x_label, ylabel=y_label, color=color, alpha=alpha, line_color=line_color, height=height, width=width, xlim=x_range, ylim=y_range)
    return ci_plt

# Data Import from Data Frame and Generate SMILES Images (skip)

In [None]:
# excel_file = 'example_spreadsheet.xlsx'
# excel_sheet = 'Sheet1'
# header = 0  # row number of header (0-index) (set header = 1 to drop row with x1, x2... column names if present)

# id_column = 'Name' #name or 0-index
# smiles_column = 'Smiles' #name or 0-index (leave blank if not available)
# response_column = 'Fake Yield' #name or 0-index (leave blank if not available)
# descriptor_start_column = 'P_NMR_min' #name or 0-index
# image_column = 'Image' #name of column that svgs will go into (not pre-existing)

# # Read in data
# df = pd.read_excel(excel_file, excel_sheet, header=header, engine='openpyxl')
# df = df.dropna(axis=1, how='any')

# # Generate list of descriptors 
# descriptor_start_column_loc, descriptor_start_column = get_column_loc(descriptor_start_column, df)
# descriptors = list(df.columns)[descriptor_start_column_loc:]

# #  Generate folder for any saved figures, named with run date
# run_date = date.today().strftime("%b-%d-%Y")
# if not os.path.exists(run_date):
#     os.makedirs(run_date)

# # Generate SMILES images
# image_column = 'Image' #name of column that svgs will go into (not pre-existing)

# smiles_column_loc, smiles_column = get_column_loc(smiles_column, df)
# df = DrawMol(df, smiles_column_loc, image_column)

# Data Import from Pickle File (SMILES SVGs pre-generated)

In [None]:
id_column = 'Name' #name or 0-index
response_column = 'Fake Yield' #name or 0-index (leave blank if not available)
descriptor_start_column = 'P_NMR_min' #name or 0-index
image_column = 'Image'  

# Read in data from pickle file (svgs are already embedded)
# NOTE: unpickling can execute arbitrary code - only load pickle files from a trusted source
df = pd.read_pickle('tutorial_df.pkl')

# Generate list of descriptors (every column from descriptor_start_column to the last column)
descriptor_start_column_loc, descriptor_start_column = get_column_loc(descriptor_start_column, df)
descriptors = list(df.columns)[descriptor_start_column_loc:]

#  Generate folder for any saved figures, named with run date
run_date = date.today().strftime("%b-%d-%Y")
os.makedirs(run_date, exist_ok=True)  # no error if the folder already exists

# Modeling

## Univariate Scatter Plot

In [None]:
# Plot the response against a single descriptor; the svg image and the ID are shown on hover
dataframe = df
x_axis = 'P_NBO_max'
y_axis = response_column
hover_list = [id_column]

scatter_plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    svgs=image_column,
    hover_list=hover_list,
    color='#7291ab',
    alpha=0.8,
)
scatter_plt

## Slope Line

Zoom out to see the fitted line; this plot does not set fixed axis ranges.

In [None]:
# Fit a linear regression of y_axis on x_axis; plot_slope returns the line plot together with the regression statistics
line, slope, intercept, r_value, p_value, std_err = plot_slope(dataframe, x=x_axis, y=y_axis)

line

## Confidence Interval

In [None]:
# Shaded confidence band (ci=0.99) for the regression of y_axis on x_axis
confidence_interval = plot_confidence_interval(dataframe, x=x_axis, y=y_axis, ci=0.99)  # Confidence interval is calculated incorrectly here but represents the function
confidence_interval

## Combine Plots

In [None]:
# The * operator overlays HoloViews plots on shared axes
scatter_plt * line  # bottom layer (first) contains interactive information (if multiple layers contain hover information), top layer is last

In [None]:
# Overlay the scatter points, the regression line and the confidence band; `plt` is reused by the save cell below
plt = scatter_plt * line * confidence_interval  # bottom layer (first) contains interactive information, top layer is last
plt

# Save HTML File

In [None]:
# f-string so the axis names are actually substituted into the file name
file_name = f'univariate model {y_axis} vs {x_axis}'
file_path = run_date

# Write the current plot (`plt`) as a standalone interactive HTML file inside the run-date folder
hv.save(plt, os.path.join(file_path, file_name + '.html'), fmt='html')

# Chemical Space: PCA

## Run PCA

In [None]:
dataframe = df 
train_dataframe = dataframe  # rows used to fit the PCA (here: the full dataframe)
n_components = 2

# Fit PCA on the training rows, then project every row of the dataframe
pca = PCA(n_components=n_components)
pca.fit(train_dataframe[descriptors])
principal_components = pca.transform(dataframe[descriptors])

# Add principal components to dataframe as pc1, pc2, ... (one column of scores per component)
for component_number, scores in enumerate(principal_components.T, start=1):
    dataframe[f'pc{component_number}'] = scores

# Print explained variance
pca_score = pca.explained_variance_ratio_
print('Variance explained by each principal component:')
for component_number, variance in enumerate(pca_score, start=1):
    percent = round(variance*100, 1)
    print(f'pc{component_number}: {percent}%')

## Interactive PCA Plot

In [None]:
# Scatter plot of the first two principal components with svg + ID hover information
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'
hover_list = [id_column]  # set svgs = column name if images are available, they will be automatically added to hover_list

plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    svgs=image_column,
    hover_list=hover_list,
    color='#7291ab',
    alpha=0.8,
)
plt

## Cluster PCA Space (k-Means)

In [None]:
# # Uncomment if KElbowVisualizer (below) is producing a font error

# import matplotlib.pyplot as plt
# plt.rcParams["font.family"] = "DejaVu Sans"  # or another font available on your system

In [None]:
dataframe = df
# Columns holding principal components: exactly 'pc' followed by digits (pc1, pc2, ...),
# so unrelated columns that merely start with 'pc' (or non-string column labels) are not picked up
dimred_columns = [
    col for col in dataframe.columns
    if isinstance(col, str) and col.startswith('pc') and col[2:].isdigit()
]
if not dimred_columns:
    raise ValueError("No principal component columns (pc1, pc2, ...) found - run the PCA cell first")

clustering_algorithm = 'kMeans'
clustering_range = (5,15)
random_state = 42
n_init = 10 

model = KMeans(random_state=random_state, n_init=n_init)
visualizer = KElbowVisualizer(model, k=clustering_range)
visualizer.fit(dataframe[dimred_columns])  # Fit the data to the visualizer
n_clust = visualizer.elbow_value_
if n_clust is None:
    # No elbow was located: still show the curve, then stop with a clear message
    visualizer.show()
    raise ValueError(f'No elbow found for k in {clustering_range}; adjust clustering_range or set n_clust manually')
print(f'Optimal number of clusters using distortion score (elbow plot): {n_clust}')
visualizer.show()  # Finalize and render the figure

# Fit k-means with the optimal number of clusters
kmeans_clustering = KMeans(n_clusters=n_clust, random_state=random_state, n_init=n_init)
kmeans_clustering.fit(dataframe[dimred_columns])

# Add cluster number to dataframe (1-based labels)
dataframe[f'{clustering_algorithm}_cluster'] = kmeans_clustering.labels_ + 1

## Interactive Clustered PCA Plot

In [None]:
# PCA scatter plot with the points colored by their k-means cluster
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'

groupby = 'kMeans_cluster'  # name or 0-index for column used to group data
color_palette = 'Category20'  # color palette for grouped data- limited Bokeh support for use with 'groupby'

plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    svgs=image_column,
    hover_list=[id_column],
    groupby=groupby,
    color=color_palette,
)
plt

# Interactive Plotting Tools (advanced mode)

The examples below use the PCA plot generated above and can be customized in many ways.

## Categorical Coloring (user defined colors and labels, in contrast to cluster plotting method above)

In [None]:
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'

category_column = 'Commercial'  # name or 0-index for column containing category criteria
category_criteria = {  # column value: [label, color] (column value can be string or integer)
    0: ['Non-commercial', '#7291ab'],
    1: ['Commercial', '#45d7ed']
}

# Build one scatter layer per category and overlay them in dictionary order
plt = None
for key, value in category_criteria.items():
    legend, color = value[:2]
    subset = dataframe.loc[dataframe[category_column] == key]
    subplt = scatter_plot(subset, x=x_axis, y=y_axis, legend=legend, svgs=image_column, hover_list=[id_column], color=color)
    # Compare against None explicitly: the truth value of a plot object is not a reliable "already set" test
    plt = subplt if plt is None else plt * subplt

base_plt = scatter_plot(dataframe, x=x_axis, y=y_axis, svgs=image_column, hover_list=[id_column], alpha=0)  # add svg hover info to all layers (default is bottom plt only)
plt = base_plt if plt is None else base_plt * plt  # plt is None only when category_criteria is empty
plt

## Heatmap

In [None]:
# PCA scatter plot with the points colored by a continuous column
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'
color_col = 'bite_angle'  # dataframe column used to color data points
heatmap_color = 'Plasma_r'
title = f'PCA, colored by {color_col}'
hover_list = [id_column]

plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    title=title,
    svgs=image_column,
    hover_list=hover_list,
    heatmap=True,
    heatmap_col=color_col,
    heatmap_color=heatmap_color,
    alpha=0.8,
)
plt

## Bubble Plot

In [None]:

# PCA scatter plot with the marker size taken from a dataframe column
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'
size_col = 'bite_angle'  # dataframe column used to size data points
title = f'PCA, sized by {size_col}'
hover_list = [id_column]

plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    title=title,
    svgs=image_column,
    hover_list=hover_list,
    bubbleplot=True,
    bubblesize=size_col,
    alpha=0.8,
)

plt

## Heatmap Bubble Plot

In [None]:
# PCA scatter plot colored by one column and sized by another; optionally saved as HTML
dataframe = df
x_axis = 'pc1'
y_axis = 'pc2'
color_col = response_column  # dataframe column used to color data points
heatmap_color = 'Plasma'
size_col = 'bite_angle'  # dataframe column used to size data points
hover_list = [id_column]

save_plot = False  # set to True to write the plot to file_path/file_name.html
file_name = 'pca heatmap bubbleplot'
file_path = run_date

plt = scatter_plot(
    dataframe,
    x=x_axis,
    y=y_axis,
    svgs=image_column,
    hover_list=hover_list,
    heatmap=True,
    heatmap_col=color_col,
    heatmap_color=heatmap_color,
    bubbleplot=True,
    bubblesize=size_col,
    alpha=0.8,
)

if save_plot:
    hv.save(plt, f'{file_path}/{file_name}.html', fmt='html')

plt