# Interpretable ML for Hydration Free Energy Prediction

**A Hands-On Workshop**


![title](images/paper.png)


## 📋 Overview

This notebook guides you through analyzing physicochemical features and hydration free energy (`dG_exp`) data using **exploratory data analysis (EDA)** and **grouping strategies**. By the end, you'll:

-   Understand key descriptors (polar surface area, logP, etc.)
-   Visualize distributions, correlations, and outliers
-   Group molecules by chemical properties and analyze trends


## 🧪 Dataset Description

We’ll use a curated dataset from the **FreeSolv database**, containing:

-   **Target**: Experimental hydration free energy (`dG_exp`)
-   **Features**:
    -   `pol`: Polar electrostatic energy
    -   `psa`: Polar surface area
    -   `logP`: Octanol-water partition coefficient
    -   `n_donors`: Number of hydrogen-bond donors
    -   `n_acceptors`: Number of hydrogen-bond acceptors
    -   `nrotb`: Number of rotatable bonds
-   **Metadata**: `mobleyID` (unique molecule identifier), `group_id` (integer code for the chemical category)


### Polar Electrostatic Energy

![Polar electrostatic energy](images/polarization_energy.png)


### Polar Surface Area

![Polar surface area](images/psa.png)


In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rich.console import Console
from rich.table import Table
import plotly.express as px

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", None)

In [5]:
# Load the pre-grouped FreeSolv descriptor table (features, dG_exp, group_id).
GROUPED_DATA_CSV = "data/groups/0.1/grouped_data.csv"
df = pd.read_csv(GROUPED_DATA_CSV)

# Preview the first rows to confirm the expected columns were loaded.
df.head()


Unnamed: 0,mobleyID,pol,psa,n_donors,nrotb,group_id,dG_exp,n_acceptors,logP
0,mobley_7532833,-7.491408,23.79,0,0,7.0,-3.88,1,0.52988
1,mobley_2198613,-1.497948,0.0,0,0,4.0,-0.63,0,1.2451
2,mobley_9257453,-9.095077,20.23,1,0,5.0,-7.29,1,2.699
3,mobley_755351,-13.409148,35.25,1,1,5.0,-7.29,2,1.2774
4,mobley_9729792,-3.356425,0.0,0,0,,-0.99,0,2.0587


In [14]:
console = Console()

# Human-readable names for the integer `group_id` codes.
group_names = {
    0: "Alkanol",
    1: "Alkanone",
    2: "Alkene",
    3: "Alkyl Alkanoate",
    4: "Halo Alkane",
    5: "Aromatic",
    6: "Aliphatic cyclic",
    7: "N-based Aliphatic",
}

# Label for molecules whose `group_id` has no entry in `group_names`
# (e.g. -1, which occurs in this dataset, or a missing value).
UNGROUPED_LABEL = "Ungrouped"

# Add a human-readable group column. `.map` yields NaN for unmapped codes, and
# `value_counts()` skips NaN by default, so without the `fillna` those molecules
# would silently drop out of every per-group count below.
df["group_name"] = df["group_id"].map(group_names).fillna(UNGROUPED_LABEL)

# Display basic statistics using Rich
console.print("[bold magenta]Basic Statistics:[/bold magenta]")
df.describe()

Unnamed: 0,pol,psa,n_donors,nrotb,group_id,dG_exp,n_acceptors,logP
count,643.0,643.0,643.0,643.0,643.0,643.0,643.0,643.0
mean,-9.583435,20.889207,0.343701,1.62986,2.734059,-3.806952,1.382582,1.92694
std,10.516359,23.8065,0.627811,1.971726,2.630259,3.846124,1.61085,1.491136
min,-103.116055,0.0,0.0,0.0,-1.0,-25.47,0.0,-3.5854
25%,-11.689927,0.0,0.0,0.0,-1.0,-5.73,0.0,1.1233
50%,-7.678724,17.07,0.0,1.0,4.0,-3.54,1.0,1.7801
75%,-3.611565,26.3,1.0,3.0,5.0,-1.22,2.0,2.56965
max,-0.049148,136.1,6.0,12.0,7.0,3.43,9.0,9.8876


In [15]:
# Create a table for group distribution
group_distribution = df["group_name"].value_counts().reset_index()
group_distribution.columns = ["Group Name", "Count"]

table = Table(title="Group Distribution")
table.add_column("Group Name", justify="left", style="cyan", no_wrap=True)
table.add_column("Count", justify="right", style="green")

for index, row in group_distribution.iterrows():
    table.add_row(row["Group Name"], str(row["Count"]))

console.print(table)

# Visualize the distribution of groups using Plotly
fig = px.bar(
    group_distribution,
    x="Group Name",
    y="Count",
    title="Distribution of Groups",
    labels={"Group Name": "Group Name", "Count": "Number of Molecules"},
    text_auto=True,
)

fig.update_traces(
    marker_color="lightseagreen",
    marker_line_color="rgb(8,48,107)",
    marker_line_width=1.5,
    opacity=0.6,
)

fig.update_layout(
    xaxis_title="Group Name", yaxis_title="Number of Molecules", template="plotly_white"
)

fig.show()

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# Shared axis labels. `pol` is the polar electrostatic energy (see the dataset
# description above), not polarizability, so label it accordingly.
POL_LABEL = "Polar electrostatic energy (pol)"
DG_LABEL = "Experimental Free Energy (dG_exp)"

# 1. Distribution of Experimental Free Energy (dG_exp)
fig1 = px.histogram(
    df,
    x="dG_exp",
    nbins=30,
    title="Distribution of Experimental Free Energy (dG_exp)",
    labels={"dG_exp": DG_LABEL},
    color_discrete_sequence=["lightseagreen"],
)

fig1.update_layout(
    template="plotly_white",
    xaxis_title=DG_LABEL,
    yaxis_title="Count",
)

# 2. Scatter plot: polar electrostatic energy (pol) vs. dG_exp, coloured by group.
# Plotly Express computes `trendline="lowess"` with statsmodels, which must be
# installed for this cell to run.
fig2 = px.scatter(
    df,
    x="pol",
    y="dG_exp",
    title="Polar Electrostatic Energy (pol) vs. Experimental Free Energy (dG_exp)",
    labels={"pol": POL_LABEL, "dG_exp": DG_LABEL},
    color="group_name",
    hover_name="mobleyID",
    trendline="lowess",
)

fig2.update_layout(
    template="plotly_white",
    xaxis_title=POL_LABEL,
    yaxis_title=DG_LABEL,
)

# 3. Box Plot: Experimental Free Energy (dG_exp) by Group
fig3 = px.box(
    df,
    x="group_name",
    y="dG_exp",
    title="Experimental Free Energy (dG_exp) by Group",
    labels={"group_name": "Group", "dG_exp": DG_LABEL},
    color="group_name",
)

fig3.update_layout(
    template="plotly_white",
    xaxis_title="Group",
    yaxis_title=DG_LABEL,
)

# Show all plots
fig1.show()
fig2.show()
fig3.show()

In [23]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

# 1. Correlation Plot (Square Heatmap)
# Pairwise correlations (pandas default: Pearson) between the descriptors and dG_exp.
corr_matrix = df[
    ["pol", "psa", "n_donors", "nrotb", "dG_exp", "n_acceptors", "logP"]
].corr()

fig1 = go.Figure(
    data=go.Heatmap(
        z=corr_matrix.values,
        x=corr_matrix.columns,
        y=corr_matrix.columns,
        colorscale="Viridis",
        colorbar=dict(title="Correlation"),
        zmin=-1,  # Ensure the scale is from -1 to 1
        zmax=1,
    )
)

# One text annotation per cell with the correlation value.
# With Viridis on [-1, 1], negative values fall in the dark purple/blue half and
# positive values in the light green/yellow half. Choose the text colour by sign
# so it stays readable; the previous `abs(value) > 0.5 -> white` rule put white
# text on yellow for strong positive correlations, including the diagonal.
labels = corr_matrix.columns
annotations = [
    go.layout.Annotation(
        text=f"{value:.2f}",  # Format to 2 decimal places
        x=labels[j],
        y=labels[i],
        xref="x1",
        yref="y1",
        font=dict(color="white" if value < 0 else "black"),
        showarrow=False,
    )
    for i, row_values in enumerate(corr_matrix.values)
    for j, value in enumerate(row_values)
]

fig1.update_layout(
    title="Correlation Heatmap",
    template="plotly_white",
    xaxis_title="Features",
    yaxis_title="Features",
    annotations=annotations,
    width=600,  # Square layout
    height=600,
)

fig1.show()

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

# Descriptors plotted against dG_exp, one panel each on a 2 x 3 grid.
features = ["pol", "psa", "logP", "n_donors", "n_acceptors", "nrotb"]

fig = make_subplots(
    rows=2,
    cols=3,
    subplot_titles=[f"{feat} vs. dG_exp" for feat in features],
)

# Regression target shared by every panel.
y = df["dG_exp"].values

for i, feat in enumerate(features):
    # Panels fill row by row (1-based): indices 0-2 -> row 1, 3-5 -> row 2.
    row, col = i // 3 + 1, i % 3 + 1

    # Raw observations for this descriptor.
    fig.add_trace(
        go.Scatter(
            x=df[feat],
            y=df["dG_exp"],
            mode="markers",
            marker=dict(color="lightseagreen", opacity=0.6),
            name=feat,
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Single-feature linear fit of dG_exp on this descriptor (fit() returns the model).
    x = df[feat].values.reshape(-1, 1)
    model = LinearRegression().fit(x, y)
    y_pred = model.predict(x)

    fig.add_trace(
        go.Scatter(
            x=df[feat],
            y=y_pred,
            mode="lines",
            line=dict(color="darkorange", width=2),
            name=f"{feat} Estimator",
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Axis titles for this panel.
    fig.update_xaxes(title_text=feat, row=row, col=col)
    fig.update_yaxes(title_text="dG_exp", row=row, col=col)

# Overall figure styling and size.
fig.update_layout(
    title="Feature Relationships with Experimental Free Energy (dG_exp)",
    template="plotly_white",
    height=800,
    width=1200,
)

fig.show()