<a href="https://colab.research.google.com/github/Paranjay33/Ai-Driven-Drug-Discovery/blob/main/attention_gnn_solubility.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention‑Enhanced Graph Neural Network for Molecular Solubility Prediction

This notebook implements a **Graph Neural Network (GNN)** with a **multi‑head self‑attention layer** for predicting aqueous solubility (logS) of small‑molecule compounds.

**Key features**
1. Uses the **Delaney (ESOL)** dataset via **DeepChem**.
2. Is designed around a *custom* multi‑head attention layer on top of DeepChem’s `GraphConv` embeddings; the cells below train the plain `GraphConvModel` baseline (no custom attention).
3. Performs **5‑fold cross‑validation** to demonstrate robustness.
4. Generates an **interactive dashboard** (Plotly) to visualise RMSE & R² across folds.

> **Note**: Run this notebook on Google Colab (recommended) with GPU enabled. All installation commands are provided.
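
The multi‑head self‑attention idea named in the feature list can be sketched in plain NumPy. This is only an illustration over the per‑atom embeddings of one molecule: the function name, the weight matrices and all sizes are hypothetical, and none of it is DeepChem API or the model trained in this notebook.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(h, w_q, w_k, w_v, w_o, n_heads):
    """h: (n_atoms, d_model) node embeddings of a single molecule (hypothetical input)."""
    n_atoms, d_model = h.shape
    d_head = d_model // n_heads

    def split_heads(x):
        # (n_atoms, d_model) -> (n_heads, n_atoms, d_head)
        return x.reshape(n_atoms, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(h @ w_q), split_heads(h @ w_k), split_heads(h @ w_v)
    # Scaled dot-product scores between every pair of atoms, per head.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)           # (n_heads, n_atoms, n_atoms)
    context = weights @ v                        # (n_heads, n_atoms, d_head)
    # Concatenate the heads again and apply the output projection.
    merged = context.transpose(1, 0, 2).reshape(n_atoms, d_model)
    return merged @ w_o, weights

rng = np.random.default_rng(0)
n_atoms, d_model, n_heads = 6, 64, 4             # illustrative sizes only
h = rng.normal(size=(n_atoms, d_model))
w_q, w_k, w_v, w_o = (rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
                      for _ in range(4))

out, attn = multi_head_self_attention(h, w_q, w_k, w_v, w_o, n_heads)
print(out.shape, attn.shape)
```

Each row of `attn[head]` sums to 1, so every atom's updated embedding is a weighted average over all atoms in the molecule. A trainable version would replace the random matrices with learned weights.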


In [None]:

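# scikit-learn==1.6.2 is not available on PyPI (see the pip error in this cell's output).
# 1.5.2 exists and still accepts `squared=False`, with a deprecation warning.
!pip install -q --upgrade \
    tensorflow==2.18.0 \
    pandas==2.2.2 \
    requests==2.32.3 \
    scikit-learn==1.5.2 \
    ml-dtypes==0.4.0 \
    plotly==5.22.0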


!pip install -q deepchem==2.7.1 rdkit-pypi==2023.9.4 --no-deps


[0m[31mERROR: Could not find a version that satisfies the requirement scikit-learn==1.6.2 (from versions: 0.9, 0.10, 0.11, 0.12, 0.12.1, 0.13, 0.13.1, 0.14, 0.14.1, 0.15.0, 0.15.1, 0.15.2, 0.16.0, 0.16.1, 0.17, 0.17.1, 0.18, 0.18.1, 0.18.2, 0.19.0, 0.19.1, 0.19.2, 0.20.0, 0.20.1, 0.20.2, 0.20.3, 0.20.4, 0.21.1, 0.21.2, 0.21.3, 0.22, 0.22.1, 0.22.2.post1, 0.23.0, 0.23.1, 0.23.2, 0.24.0, 0.24.1, 0.24.2, 1.0, 1.0.1, 1.0.2, 1.1.0, 1.1.1, 1.1.2, 1.1.3, 1.2.0rc1, 1.2.0, 1.2.1, 1.2.2, 1.3.0rc1, 1.3.0, 1.3.1, 1.3.2, 1.4.0rc1, 1.4.0, 1.4.1.post1, 1.4.2, 1.5.0rc1, 1.5.0, 1.5.1, 1.5.2, 1.6.0rc1, 1.6.0, 1.6.1, 1.7.0rc1, 1.7.0)[0m[31m
[0m[31mERROR: No matching distribution found for scikit-learn==1.6.2[0m[31m
[0m[31mERROR: Ignored the following versions that require a different python version: 2.6.0.dev20220112162333 Requires-Python >=3.7,<3.10; 2.6.0.dev20220114040838 Requires-Python >=3.7,<3.10; 2.6.0.dev20220118010103 Requires-Python >=3.7,<3.10; 2.6.0.dev20220118135955 Requires-Python

In [None]:
import deepchem as dc
from deepchem.molnet import load_delaney

import tensorflow as tf
from tensorflow.keras import layers, Model

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error

import plotly.graph_objects as go


In [None]:
# Delaney (ESOL) solubility dataset with GraphConv featuriser
tasks, datasets, transformers = load_delaney(featurizer='GraphConv')
train_dataset, valid_dataset, test_dataset = datasets

print('Train samples :', len(train_dataset))
print('Valid samples :', len(valid_dataset))
print('Test  samples :', len(test_dataset))


Train samples : 902
Valid samples : 113
Test  samples : 113
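
The `transformers` list returned by `load_delaney` is not used again in this notebook. Assuming the loader applied its default normalization to `y` (an assumption, not checked here), the metrics reported later are in normalized units rather than logS. The sketch below uses made-up target values to show how a z-score transform rescales RMSE; in DeepChem itself, `dc.trans.undo_transforms(y, transformers)` is the usual way to map predictions back.

```python
import numpy as np

# Made-up logS values standing in for raw targets (not taken from ESOL).
y_raw = np.array([-0.77, -3.30, -2.06, -7.87, -1.33])

# z-score normalisation (population standard deviation used for this sketch).
mean, std = y_raw.mean(), y_raw.std()
y_norm = (y_raw - mean) / std

def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

# Pretend predictions that are off by 0.5 in normalised space.
pred_norm = y_norm + 0.5
rmse_norm = rmse(y_norm, pred_norm)

# Undo the transform, then score in the original units.
pred_raw = pred_norm * std + mean
rmse_raw = rmse(y_raw, pred_raw)

print(f"RMSE (normalised): {rmse_norm:.3f}")
print(f"RMSE (raw units) : {rmse_raw:.3f}  (= normalised RMSE x std)")
```

Under this assumption, an RMSE in normalized space has to be multiplied by the training-set standard deviation of `y` before it can be compared with literature values in logS units. R² is unaffected by the rescaling.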


In [None]:
import tensorflow as tf
import deepchem as dc

def build_basic_gcnn(n_tasks=1,
                     graph_conv_layers=(64, 64),
                     dense_size=128,
                     lr=1e-3,
                     dropout=0.2):
    """
    Build a plain DeepChem GraphConvModel regressor (no custom attention).

    Runs reliably with TF 2.18 + DeepChem 2.7.
    """
    model = dc.models.GraphConvModel(
        n_tasks=n_tasks,
        mode='regression',
        # Tuple default avoids a shared mutable list; pass a fresh list on.
        graph_conv_layers=list(graph_conv_layers),
        dense_layer_size=dense_size,
        batch_size=32,
        learning_rate=lr,
        dropout=dropout)
    return model


In [None]:
import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error

# 5-fold CV over the training split only; valid/test splits are not touched here.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = {'fold': [], 'rmse': [], 'r2': []}

for i, (tr_idx, te_idx) in enumerate(kf.split(train_dataset.X), start=1):
    print(f'Fold {i}/5')

    X_tr, y_tr, w_tr, ids_tr = (train_dataset.X[tr_idx],
                                 train_dataset.y[tr_idx],
                                 train_dataset.w[tr_idx],
                                 train_dataset.ids[tr_idx])
    X_te, y_te, w_te, ids_te = (train_dataset.X[te_idx],
                                 train_dataset.y[te_idx],
                                 train_dataset.w[te_idx],
                                 train_dataset.ids[te_idx])

    fold_train = dc.data.NumpyDataset(X_tr, y_tr, w_tr, ids_tr)
    fold_test  = dc.data.NumpyDataset(X_te, y_te, w_te, ids_te)

    # Fresh model per fold so no weights carry over between folds.
    fold_model = build_basic_gcnn()
    fold_model.fit(fold_train, nb_epoch=30)

    preds = fold_model.predict(fold_test).flatten()
    r2   = r2_score(y_te.flatten(), preds)
    # np.sqrt(MSE) works on every scikit-learn version; `squared=False` was
    # deprecated in 1.4 and is scheduled for removal in 1.6.
    rmse = float(np.sqrt(mean_squared_error(y_te.flatten(), preds)))

    print(f'  RMSE: {rmse:.3f} | R²: {r2:.3f}')
    metrics['fold'].append(i)
    metrics['rmse'].append(rmse)
    metrics['r2'].append(r2)

metrics_df = pd.DataFrame(metrics)
metrics_df


Fold 1/5
'squared' is deprecated in version 1.4 and will be removed in 1.6. To calculate the root mean squared error, use the function'root_mean_squared_error'.
  RMSE: 0.515 | R²: 0.748
Fold 2/5
  RMSE: 0.589 | R²: 0.673
Fold 3/5
  RMSE: 0.654 | R²: 0.593
Fold 4/5
  RMSE: 0.546 | R²: 0.696
Fold 5/5
  RMSE: 0.547 | R²: 0.633

   fold      rmse        r2
0     1  0.515445  0.748068
1     2  0.588926  0.672831
2     3  0.653968  0.592671
3     4  0.546116  0.696494
4     5  0.547400  0.632835
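
Per-fold numbers are easier to compare when collapsed to a mean and spread. The sketch below re-types the five RMSE and R² values from the table above and summarizes them with the standard library; the variable names are illustrative.

```python
from statistics import mean, stdev

# Per-fold values copied from the cross-validation table above.
rmse_folds = [0.515445, 0.588926, 0.653968, 0.546116, 0.547400]
r2_folds = [0.748068, 0.672831, 0.592671, 0.696494, 0.632835]

# Sample standard deviation (n - 1) across the five folds.
rmse_mean, rmse_sd = mean(rmse_folds), stdev(rmse_folds)
r2_mean, r2_sd = mean(r2_folds), stdev(r2_folds)

print(f"RMSE: {rmse_mean:.3f} ± {rmse_sd:.3f}")
print(f"R²  : {r2_mean:.3f} ± {r2_sd:.3f}")
```

With `metrics_df` in memory, `metrics_df[['rmse', 'r2']].agg(['mean', 'std'])` gives the same summary.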


In [None]:
fig = go.Figure()
fig.add_trace(go.Bar(x=metrics_df['fold'], y=metrics_df['rmse'], name='RMSE'))
fig.add_trace(go.Scatter(x=metrics_df['fold'], y=metrics_df['r2'],
                         mode='lines+markers', name='R²'))
fig.update_layout(title='5‑Fold Cross‑Validation Metrics',
                  xaxis_title='Fold', yaxis_title='Metric')
fig.show()

metrics_df


   fold      rmse        r2
0     1  0.515445  0.748068
1     2  0.588926  0.672831
2     3  0.653968  0.592671
3     4  0.546116  0.696494
4     5  0.547400  0.632835


In [None]:
final_model = build_basic_gcnn()
final_model.fit(train_dataset, nb_epoch=30)

test_preds = final_model.predict(test_dataset).flatten()
test_r2   = r2_score(test_dataset.y.flatten(), test_preds)

print(f"Test R²: {test_r2:.3f}")


Test R²: 0.320
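
The held-out test R² (0.320) is well below the cross-validation values (0.59 to 0.75). One possible contributor, assuming `load_delaney` used its default scaffold splitter (not verified here), is that the test molecules have scaffolds absent from training, whereas the K-fold shuffle splits the training set at random. To help read these numbers, the sketch below writes R² out from its definition using made-up values.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # unexplained variation
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total variation
    return float(1.0 - ss_res / ss_tot)

y_true = np.array([-1.0, -2.0, -3.0, -4.0])          # made-up targets

r2_perfect = r2(y_true, y_true)                                   # exact predictions
r2_mean_only = r2(y_true, np.full_like(y_true, y_true.mean()))    # always predict the mean
r2_reversed = r2(y_true, y_true[::-1])                            # anti-correlated predictions

print(r2_perfect, r2_mean_only, r2_reversed)
```

An R² of 0.320 therefore means the model explains roughly a third of the variance in the test targets: better than predicting the mean, but far from the cross-validation fit.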


In [None]:
metrics_df.to_csv('cv_metrics.csv', index=False)
print('Cross‑validation metrics saved to cv_metrics.csv')


Cross‑validation metrics saved to cv_metrics.csv
