In [2]:
import base64
from io import BytesIO
from pathlib import Path

import dash
import pandas as pd
import plotly.express as px
import plotly.io as pio
from dash import dcc, html
from dash.dependencies import Input, Output
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


In [4]:
# Load the pre-processed datasets. Paths are built with pathlib so they resolve on any OS
# (backslash-separated string literals only work as path separators on Windows).
# NOTE(review): pd.read_pickle can execute arbitrary code from the file — only load trusted pickles.
DATA_DIR = Path('..') / 'ML_31_12'
small = pd.read_pickle(DATA_DIR / 'small.pickle')
df = pd.read_pickle(DATA_DIR / 'df.pkl')

# Number of aromatic rings per molecule (used below as the per-point label via marker_col='Ar').
small['Ar'] = small['mol'].apply(Descriptors.NumAromaticRings)


In [8]:
def mol_to_base64(mol, size=(300, 300)):
    """Draw an RDKit molecule and return it as a PNG data URI.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule to depict.
    size : tuple
        Image size in pixels, forwarded unchanged to ``Draw.MolToImage``.

    Returns
    -------
    str
        ``"data:image/png;base64,<payload>"``, usable directly as an HTML ``<img src>``.
    """
    picture = Draw.MolToImage(mol, size=size)
    with BytesIO() as stream:
        picture.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{encoded}"

def rdkit_plot(df, x_col, y_col, method='MolWt', color='blue', line_color='rgba(255, 0, 0, 0.1)',
               cluster_col=None, marker_col=None, jupyter_mode='external'):
    """Scatter ``y_col`` against ``x_col`` with a linear fit and a molecule-on-hover preview.

    Prints the R² of a least-squares fit of ``y_col`` on ``x_col``, builds a Plotly
    scatter (optionally coloured by ``cluster_col`` and labelled with ``marker_col``),
    overlays the fitted line, and serves it in a Dash app that shows the 2-D structure
    of whichever data point the cursor is over.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``x_col``, ``y_col``, a ``mol`` column of RDKit molecules and the
        hover columns ``MW``, ``name`` and ``IPA``. The caller's frame is not modified.
    x_col, y_col : str
        Columns for the x and y axes. Rows missing either value are dropped.
    method : str
        Name of a function in ``rdkit.Chem.Descriptors`` evaluated on every molecule into
        a ``calculated`` column. Any value other than ``'MolWt'`` makes that column the
        point label, overriding ``marker_col``.
    color : str
        Marker colour; applied only when no point labels are drawn.
    line_color : str
        Colour of the regression line.
    cluster_col : str, optional
        Column mapped to marker colour.
    marker_col : str, optional
        Column whose values are written above each point.
    jupyter_mode : str
        Forwarded to ``Dash.run(jupyter_mode=...)``, e.g. ``'external'``, ``'inline'``,
        ``'tab'`` or ``'jupyterlab'``.

    Raises
    ------
    ValueError
        If ``method`` is not an attribute of ``rdkit.Chem.Descriptors``, or if no rows
        remain after dropping missing ``x_col`` / ``y_col`` values.
    """
    # Work on a copy so the caller's frame is untouched (and to avoid SettingWithCopyWarning).
    df = df.dropna(subset=[x_col, y_col]).copy()
    if df.empty:
        raise ValueError(f"No rows with both '{x_col}' and '{y_col}' present")

    # Resolve the descriptor once, up front, so a typo fails before any per-row work.
    descriptor = getattr(Descriptors, method, None)
    if descriptor is None:
        raise ValueError(f"Descriptor method '{method}' not found in RDKit Descriptors")

    # Least-squares fit of y on x; report goodness of fit.
    x = df[x_col].values.reshape(-1, 1)
    lin = LinearRegression().fit(x, df[y_col])
    print('R² score is', round(r2_score(df[y_col], lin.predict(x)), 2))

    df['calculated'] = df['mol'].apply(descriptor)
    if method != 'MolWt':
        marker_col = 'calculated'

    # Positional row id carried through customdata: Plotly's pointIndex is relative to a
    # single trace, and a categorical colour column splits the points into several traces,
    # so pointIndex alone cannot be used to index df.
    df['__row_pos'] = range(len(df))

    plot_args = {
        'x': x_col,
        'y': y_col,
        'custom_data': ['MW', 'name', 'IPA', '__row_pos'],
        'color': cluster_col,
    }
    if marker_col:
        # Labels as trace text instead of one layout annotation per row: much cheaper for
        # large frames, and labels hide together with their trace when toggled in the legend.
        plot_args['text'] = marker_col
    else:
        plot_args['color_discrete_sequence'] = [color]

    fig = px.scatter(df, **plot_args)

    if marker_col:
        fig.update_traces(textposition='top center')
        fig.update_layout(
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
        )

    fig.update_traces(hovertemplate='MW=%{customdata[0]}<br>name=%{customdata[1]}<br>%{customdata[2]}')

    # Fitted line, with x sorted so the path runs monotonically left to right.
    x_line = df[x_col].sort_values()
    fig.add_scatter(
        x=x_line,
        y=x_line * lin.coef_[0] + lin.intercept_,
        mode='lines',
        line=dict(color=line_color),
        showlegend=False,
    )

    # Dash app: the graph plus an image pane updated on hover.
    app = dash.Dash(__name__)
    app.layout = html.Div([
        dcc.Graph(id='scatter-plot', figure=fig, style={"width": "70%", "display": "inline-block"}),
        html.Img(id='mol-image', style={"width": "20%", "display": "inline-block"})
    ])

    mols = df['mol']

    @app.callback(Output('mol-image', 'src'), Input('scatter-plot', 'hoverData'))
    def update_image(hoverData):
        if not hoverData:
            return ''
        customdata = hoverData['points'][0].get('customdata')
        if not customdata:
            # The regression-line trace carries no customdata and has no molecule.
            return ''
        # Render lazily: only the hovered molecule is drawn, not every row up front.
        return mol_to_base64(mols.iloc[int(customdata[-1])])

    # In Dash >= 2.11 the Jupyter display mode is selected with ``jupyter_mode``;
    # ``mode=`` is not a Dash.run parameter, so it did not choose the display mode.
    app.run(jupyter_mode=jupyter_mode)


In [9]:
# Tol vs. MW, coloured by ACN and labelled with the aromatic-ring count.
rdkit_plot(
    small,
    'MW',
    'Tol',
    cluster_col='ACN',
    marker_col='Ar',
)

R² score is 0.09
