<a href="https://colab.research.google.com/github/MZiaAfzal71/Edge-Aware-GNN/blob/main/Models/Descriptor_Augmented_Edge_Aware_GNN_for_ESOL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Descriptor-Augmented Edge-Aware GNN for ESOL

This notebook implements a **Descriptor-Augmented Edge-Aware Graph Neural Network (DA-EA-GNN)**
for predicting aqueous solubility on the **ESOL (Delaney) dataset**.

Each molecule is represented using a **hybrid molecular representation** that combines:
- A **graph-based representation** derived from its **SMILES string**, and
- **Global physicochemical descriptors** computed using **RDKit**

In the molecular graph:
- **Nodes** correspond to atoms with chemically meaningful atom features
- **Edges** correspond to bonds with explicit bond features
- Message passing is performed using **edge-aware GNN layers** to incorporate bond information
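
To make the edge convention concrete, here is a minimal sketch (standard library only) of how an undirected bond list can be expanded into the directed edge list that message passing operates on. The toy bond list, `BOND_TYPES`, and the helper names are assumptions for illustration; the notebook builds the same structure from RDKit bond objects.

```python
# Toy chain of three heavy atoms (0-1-2), given as (begin, end, bond_type).
# The bond list and the feature layout are illustrative assumptions.
bonds = [(0, 1, "SINGLE"), (1, 2, "SINGLE")]
BOND_TYPES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]

def one_hot_bond(bond_type):
    """One-hot encode the bond type over BOND_TYPES."""
    return [float(bond_type == t) for t in BOND_TYPES]

def to_directed_edges(bonds):
    """Expand each undirected bond into two directed edges sharing one feature vector."""
    edge_index, edge_attr = [], []
    for i, j, bond_type in bonds:
        feats = one_hot_bond(bond_type)
        edge_index += [[i, j], [j, i]]
        edge_attr += [feats, feats]
    return edge_index, edge_attr

edge_index, edge_attr = to_directed_edges(bonds)
print(edge_index)
print(edge_attr[0])
```

Storing both directions lets information flow from either atom of a bond to the other, while the duplicated feature vector keeps the bond description identical in both directions.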

In addition to the graph representation:
- A set of **RDKit molecular descriptors** is computed for each molecule
- All descriptors, or a chemically relevant subset tailored for solubility prediction, are used
- Descriptor features are normalized and fused with graph-level embeddings
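
As a rough illustration of the normalization step, the sketch below applies median/IQR scaling followed by clipping to a single descriptor column using only the standard library. The helper names `robust_scale_fit` and `robust_scale_apply` and the toy values are assumptions for this example; the notebook itself relies on scikit-learn's `RobustScaler` and clips scaled values to [-5, 5].

```python
import statistics

def robust_scale_fit(column):
    """Return (median, IQR) of a training column; a zero IQR falls back to 1.0."""
    q1, median, q3 = statistics.quantiles(column, n=4, method="inclusive")
    iqr = q3 - q1
    return median, (iqr if iqr != 0 else 1.0)

def robust_scale_apply(column, median, iqr, clip=5.0):
    """Center by the median, divide by the IQR, then clip to [-clip, clip]."""
    scaled = [(v - median) / iqr for v in column]
    return [max(-clip, min(clip, s)) for s in scaled]

train_col = [1.0, 2.0, 3.0, 4.0, 5.0]   # toy training values
val_col = [3.0, 100.0]                  # toy validation values with one outlier

# Statistics come from the training column only and are reused for validation
median, iqr = robust_scale_fit(train_col)
print(robust_scale_apply(val_col, median, iqr))
```

The outlier `100.0` is mapped to the clip bound rather than to a very large scaled value, which is the point of combining a robust scaler with clipping.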

In this notebook:
- Both **structure-level (atom–bond)** and **molecule-level (descriptor)** information are used
- Training and evaluation are performed using:
  - **Repeated cross-validation**, and
  - A **Bemis–Murcko scaffold-based split** to assess generalization to unseen chemical scaffolds
- Results are compared with the structure-only edge-aware GNN to evaluate the impact of descriptor fusion
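
The idea behind a scaffold-based split can be sketched as follows, assuming each molecule already has a scaffold key. This notebook reads precomputed split labels from the dataset file, so the grouping rule, the `scaffold_split` helper, and the toy scaffold strings below are illustrative assumptions rather than the procedure used to create those labels.

```python
from collections import defaultdict

def scaffold_split(scaffolds, train_frac=0.8):
    """Assign whole scaffold groups to train (largest first) until train_frac is reached."""
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)

    # Largest groups first; ties are broken by first member index for determinism
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    cutoff = train_frac * len(scaffolds)
    train, held_out = [], []
    for group in ordered:
        if len(train) + len(group) <= cutoff:
            train += group
        else:
            held_out += group
    return sorted(train), sorted(held_out)

# Hypothetical scaffold keys, one per molecule ("" stands for acyclic molecules)
scaffolds = ["c1ccccc1", "c1ccccc1", "c1ccccc1", "C1CCCCC1", "C1CCCCC1",
             "c1ccncc1", "c1ccncc1", "", "", "c1ccoc1"]
train_idx, test_idx = scaffold_split(scaffolds)
print(train_idx, test_idx)
```

Because groups are never divided, no scaffold appears on both sides of the split, which is what makes this protocol harder than a random split.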

This setup allows us to study the effect of incorporating global physicochemical information
alongside edge-aware message passing under both random and chemically realistic evaluation
protocols for molecular property prediction.

The notebook is designed to be:
- **Reproducible**
- **Interpretable**
- **Focused on understanding representation choices**

The ESOL dataset contains 1,128 small molecules with experimentally measured aqueous solubility
values and serves as a standard benchmark for evaluating molecular machine learning models.


In [None]:
# 1️⃣ Fetch data
!git clone https://github.com/MZiaAfzal71/Edge-Aware-GNN.git

In [None]:
# 2️⃣ Change current/working directory
%cd Edge-Aware-GNN/ESOL\ Dataset

In [None]:
# 3️⃣ Install rdkit and PyG
!pip install rdkit torch_geometric

In [None]:
# 4️⃣ Imports
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import random
import copy
import os

import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, AttentionalAggregation
import torch.optim as optim


from rdkit import Chem
from rdkit.Chem import Descriptors
# from rdkit.Chem.Scaffolds import MurckoScaffold

from sklearn.model_selection import RepeatedKFold, train_test_split
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# import matplotlib.pyplot as plt
# import seaborn as sns
# import warnings
# warnings.filterwarnings("ignore")

In [None]:
# 5️⃣ Set random seeds for reproducibility across Python, NumPy, and PyTorch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [None]:
# 6️⃣ Utility functions to convert SMILES strings into graph representations with atom and bond features

ELECTRONEGATIVITY = {
    1: 2.20, 6: 2.55, 7: 3.04, 8: 3.44,
    9: 3.98, 15: 2.19, 16: 2.58,
    17: 3.16, 35: 2.96, 53: 2.66
}

def atom_features(atom):
    Z = atom.GetAtomicNum()

    hyb = atom.GetHybridization()
    hyb_onehot = [
        hyb == Chem.rdchem.HybridizationType.SP,
        hyb == Chem.rdchem.HybridizationType.SP2,
        hyb == Chem.rdchem.HybridizationType.SP3
    ]

    chiral = atom.GetChiralTag()
    chiral_onehot = [
        chiral == Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        chiral == Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW
    ]

    return [
        Z,
        # atom.GetMass(),
        atom.GetDegree(),
        atom.GetTotalValence(),
        atom.GetTotalNumHs(),
        atom.GetFormalCharge(),
        float(atom.GetIsAromatic()),
        float(atom.IsInRing()),
        *hyb_onehot,
        # *chiral_onehot,
        ELECTRONEGATIVITY.get(Z, 0.0)
    ]

def bond_features(bond):
    bt = bond.GetBondType()

    stereo = bond.GetStereo()
    stereo_onehot = [
        stereo == Chem.rdchem.BondStereo.STEREOE,
        stereo == Chem.rdchem.BondStereo.STEREOZ
    ]

    return [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        float(bond.GetIsConjugated()),
        float(bond.IsInRing()),
        # float(bond.IsRotor()),
        # float(bond.GetIsAmide()),
        # *stereo_onehot
    ]

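def smiles_to_graph(smiles, y=None):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit returns None for SMILES it cannot parse
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    x = [atom_features(atom) for atom in mol.GetAtoms()]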

    edge_index, edge_attr = [], []

    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bf = bond_features(bond)

        edge_index += [[i, j], [j, i]]
        edge_attr += [bf, bf]

    data = Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float)
    )

    if y is not None:
        data.y = torch.tensor(y, dtype=torch.float)

    return data

In [None]:
# 7️⃣ PyTorch Dataset Class

class MoleculeDataset(Dataset):
    def __init__(self, df, desc, smiles_col, target_col, scaler=None, fit_scaler=False):
        self.smiles = df[smiles_col].values
        self.targets = df[target_col].values.astype(np.float32)

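        # Descriptors are precomputed and passed in as a 2D array (one row per molecule)
        descriptors = desc

        # Normalize descriptors, then clip to [-5, 5] to limit the influence of outliers
        if fit_scaler:
            # Training split: fit a new scaler on these descriptors
            self.scaler = RobustScaler()
            self.descriptors = self.scaler.fit_transform(descriptors)
            self.descriptors = np.clip(self.descriptors, -5, 5)
        elif scaler is not None:
            # Validation/test split: reuse the scaler fitted on the training split
            self.scaler = scaler
            self.descriptors = self.scaler.transform(descriptors)
            self.descriptors = np.clip(self.descriptors, -5, 5)
        else:
            # No scaler supplied: values are clipped without any scaling
            self.scaler = None
            self.descriptors = np.clip(descriptors, -5, 5)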

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        graph = smiles_to_graph(self.smiles[idx])
        graph.y = torch.tensor([self.targets[idx]], dtype=torch.float)

        desc = torch.tensor(self.descriptors[idx], dtype=torch.float)
        return graph, desc

In [None]:
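# 8️⃣ Collate Function (Required for PyG + Descriptors)
from torch_geometric.data import Batch

def collate_fn(batch):
    graphs, descs = zip(*batch)
    # Batch (not Data) provides from_data_list and builds the `batch` assignment vector
    batch_graph = Batch.from_data_list(list(graphs))
    batch_desc = torch.stack(descs)
    return batch_graph, batch_desc

# Note: torch_geometric.loader.DataLoader appears to discard a user-supplied
# collate_fn and use PyG's own collater, which batches (graph, tensor) pairs
# in the same way. This function matters when using torch.utils.data.DataLoader.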

In [None]:
# 9️⃣ GINE-based graph neural network with attentional readout for molecular graphs

class SimpleGINE(nn.Module):
    def __init__(
        self,
        node_dim,
        edge_dim,
        hidden_dim=128,
        dropout=0.1
    ):
        super().__init__()

        def make_mlp():
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Linear(hidden_dim, hidden_dim)
            )

        self.node_emb = nn.Linear(node_dim, hidden_dim)

        self.conv1 = GINEConv(make_mlp(), edge_dim=edge_dim)
        self.conv2 = GINEConv(make_mlp(), edge_dim=edge_dim)
        self.conv3 = GINEConv(make_mlp(), edge_dim=edge_dim)

        # ---- Attention readout ----
        self.readout = AttentionalAggregation(
            gate_nn=nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
        )

        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.node_emb(x)

        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.relu(self.conv3(x, edge_index, edge_attr))

        return self.readout(x, batch)

In [None]:
# 🔟 GNN + Descriptor Fusion Model

class GatedFusion(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()

        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid()
        )

    def forward(self, g, d):
        gate = self.gate(torch.cat([g, d], dim=1))
        return gate * g + (1.0 - gate) * d
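
The gating rule in `GatedFusion.forward` can be checked by hand with plain Python numbers. In the sketch below the gate logits are chosen manually (an assumption for illustration; in the model they come from the gate network applied to the concatenated embeddings), and `gated_mix` is a hypothetical helper.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gated_mix(g, d, gate_logits):
    """Blend g and d per dimension: gate * g + (1 - gate) * d, with gate = sigmoid(logit)."""
    gates = [sigmoid(z) for z in gate_logits]
    return [a * gi + (1.0 - a) * di for a, gi, di in zip(gates, g, d)]

g = [1.0, 1.0, 1.0]   # stand-in for a graph embedding
d = [0.0, 0.0, 0.0]   # stand-in for a descriptor embedding

# Logit 0 mixes equally, a large positive logit keeps g, a large negative logit keeps d
fused = gated_mix(g, d, [0.0, 20.0, -20.0])
print(fused)
```

Since the gate is computed per hidden dimension, the model can rely on the graph embedding for some features and on the descriptor embedding for others within the same molecule.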

In [None]:
# 1️⃣1️⃣ Descriptor-Augmented Edge-aware graph neural network (DA_EAGNN) based on GINE for molecular property prediction

class DA_EAGNN(nn.Module):
    def __init__(
        self,
        node_dim,
        desc_dim,
        edge_dim,
        hidden_dim=128,
        dropout=0.1
    ):
        super().__init__()

        self.gnn = SimpleGINE(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )

        self.desc_net = nn.Sequential(
            nn.Linear(desc_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.fusion = GatedFusion(hidden_dim)

        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, graphs, desc):
        g = self.gnn(
            graphs.x,
            graphs.edge_index,
            graphs.edge_attr,
            graphs.batch
        )

        d = self.desc_net(desc)

        h = self.fusion(g, d)
        return self.head(h).squeeze(-1)


In [None]:
# 1️⃣2️⃣ Utility function for repeated k-fold cross-validation with PyG data loaders

def run_repeated_kfold_cv(
    df,
    descriptors,
    smiles_col="smiles",
    target_col="target",
    n_splits=5,
    n_repeats=5,
    batch_size=32,
    seed=42
):
    rkf = RepeatedKFold(
        n_splits=n_splits,
        n_repeats=n_repeats,
        random_state=seed
    )

    split_id = 0

    for train_idx, val_idx in rkf.split(df):
        repeat = split_id // n_splits
        fold   = split_id % n_splits

        train_df = df.iloc[train_idx]
        val_df   = df.iloc[val_idx]

        train_desc = descriptors[train_idx]
        val_desc = descriptors[val_idx]

        # ---- Train dataset (fit scaler) ----
        train_dataset = MoleculeDataset(
            train_df,
            train_desc,
            smiles_col,
            target_col,
            fit_scaler=True
        )

        # ---- Validation dataset (reuse scaler) ----
        val_dataset = MoleculeDataset(
            val_df,
            val_desc,
            smiles_col,
            target_col,
            scaler=train_dataset.scaler
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn
        )

        yield repeat, fold, train_loader, val_loader

        split_id += 1
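
The bookkeeping above (`repeat = split_id // n_splits`, `fold = split_id % n_splits`) relies on splits being produced repeat by repeat. The standard-library sketch below mimics that ordering on ten sample indices. The `repeated_kfold` helper is hypothetical, and its fold assignment is deliberately simple, so it is not meant to reproduce the exact indices that scikit-learn's `RepeatedKFold` generates.

```python
import random

def repeated_kfold(n_samples, n_splits, n_repeats, seed=0):
    """Yield (repeat, fold, train_idx, val_idx); each repeat reshuffles the indices."""
    rng = random.Random(seed)
    for repeat in range(n_repeats):
        order = list(range(n_samples))
        rng.shuffle(order)
        for fold in range(n_splits):
            # Every n_splits-th shuffled index goes to this fold's validation set
            val = sorted(order[fold::n_splits])
            val_set = set(val)
            train = [i for i in range(n_samples) if i not in val_set]
            yield repeat, fold, train, val

splits = list(repeated_kfold(n_samples=10, n_splits=5, n_repeats=2))
for split_id, (repeat, fold, train, val) in enumerate(splits):
    # The running split counter recovers (repeat, fold) via divmod
    assert (repeat, fold) == divmod(split_id, 5)
    print(repeat, fold, val)
```

Within one repeat every sample is used for validation exactly once; across repeats the partition changes, which is what reduces the variance of the averaged metrics.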


In [None]:
# 1️⃣3️⃣ Training utility with early stopping, learning rate scheduling, and model checkpointing

class Trainer:
    def __init__(
        self,
        model,
        device,
        lr=1e-3,
        weight_decay=1e-4,
        patience=10,
        max_epochs=100
    ):
        self.model = model
        self.device = device
        self.patience = patience
        self.max_epochs = max_epochs

        self.criterion = nn.MSELoss()
        # self.criterion = nn.SmoothL1Loss(beta=1.0)

        self.optimizer = optim.Adam(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.5,
            patience=5
        )

        self.best_val_loss = float("inf")
        self.best_model_state = None
        self.history = {
            "train_loss": [],
            "val_loss": []
        }


    def _train_one_epoch(self, train_loader):
        self.model.train()
        running_loss = 0.0

        for graphs, desc in train_loader:
            graphs = graphs.to(self.device)
            desc = desc.to(self.device)

            self.optimizer.zero_grad()
            preds = self.model(graphs, desc)
            loss = self.criterion(preds, graphs.y)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        return running_loss / len(train_loader)


    def _evaluate(self, loader):
        self.model.eval()
        total_loss = 0.0
        preds_all, targets_all = [], []

        with torch.no_grad():
            for graphs, desc in loader:
                graphs = graphs.to(self.device)
                desc = desc.to(self.device)

                preds = self.model(graphs, desc)
                loss = self.criterion(preds, graphs.y)

                total_loss += loss.item()
                preds_all.append(preds.cpu())
                targets_all.append(graphs.y.cpu())

        return (
            total_loss / len(loader),
            torch.cat(preds_all),
            torch.cat(targets_all)
        )


    def fit(self, train_loader, val_loader, verbose=True):
        patience_counter = 0

        for epoch in range(self.max_epochs):

            train_loss = self._train_one_epoch(train_loader)
            val_loss, _, _ = self._evaluate(val_loader)

            self.scheduler.step(val_loss)

            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)

            if verbose:
                print(
                    f"Epoch {epoch:03d} | "
                    f"Train: {train_loss:.4f} | "
                    f"Val: {val_loss:.4f}"
                )

            # ---- Best Model Tracking ----
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = copy.deepcopy(
                    self.model.state_dict()
                )
                patience_counter = 0
            else:
                patience_counter += 1

            # ---- Early Stopping ----
            if patience_counter >= self.patience:
                if verbose:
                    print("Early stopping triggered.")
                break

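        # Restore best model (the state stays None only if no epoch improved, e.g. NaN losses)
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)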

        return self.model, self.best_val_loss


    def test(self, test_loader):
        return self._evaluate(test_loader)


    def save_best_model(self, path):
        torch.save(self.best_model_state, path)


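    def load_model(self, path):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        self.model.load_state_dict(torch.load(path, map_location=self.device))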
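
The patience rule in `Trainer.fit` can be traced on a list of validation losses without training anything. The sketch below is a simplified stand-in with a hypothetical helper name and made-up loss values; it mirrors the counter logic only and omits the learning rate scheduler and checkpointing.

```python
def early_stopping_trace(val_losses, patience):
    """Return (best_epoch, stop_epoch) under a reset-on-improvement patience counter."""
    best_loss, best_epoch, counter = float("inf"), None, 0
    stop_epoch = len(val_losses) - 1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: remember this epoch and reset the counter
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            stop_epoch = epoch
            break
    return best_epoch, stop_epoch

losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]   # made-up validation losses
result = early_stopping_trace(losses, patience=3)
print(result)
```

With `patience=3` the run stops at epoch 5 and the weights from epoch 2 would be restored. Note that epochs 4 and 5 improve on their predecessors but not on the best loss, so they still count against the patience budget.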


In [None]:
# 1️⃣4️⃣ Calculate all available RDKit descriptors for a list of SMILES strings.

def rdkit_descriptors_from_smiles(smiles):
    """
    Parameters
    ----------
    smiles : iterable of str
        SMILES strings, one per molecule.

    Returns
    -------
    pd.DataFrame
        One row per molecule, with descriptor names as columns.
        Rows for invalid SMILES are filled with NaN.
    """
    # Descriptor (name, function) pairs provided by RDKit
    desc_list = Descriptors.descList
    desc_names = [name for name, _ in desc_list]

    rdkit_descs = []

    for sm in tqdm(smiles, total=len(smiles)):
        mol = Chem.MolFromSmiles(sm)
        if mol is None:
            # Keep row alignment with the input by emitting an all-NaN row
            rdkit_descs.append([np.nan] * len(desc_names))
            continue

        sm_descs = []
        for _, func in desc_list:
            try:
                sm_descs.append(func(mol))
            except Exception:
                # A failing descriptor becomes NaN instead of aborting the run
                sm_descs.append(np.nan)

        rdkit_descs.append(sm_descs)

    return pd.DataFrame(rdkit_descs, columns=desc_names)


In [None]:
# 1️⃣5️⃣ Load dataset, standardize target variable, and prepare data for modeling

tqdm.pandas()

file_path = "delaney-processed-scaffold.csv"
smiles_col = "smiles"
target_col = "measured log solubility in mols per litre"

set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

df = pd.read_csv(file_path)

y = df[target_col]

y_mean = y.mean()
y_std  = y.std()

y_scaled = (y - y_mean) / y_std

df_data = pd.concat([df[smiles_col], y_scaled], axis=1)

rdkit_descriptors = rdkit_descriptors_from_smiles(df[smiles_col])
variance_df = rdkit_descriptors.var()
zero_var_columns = variance_df[variance_df == 0].index.tolist()
cleaned_rdkit_descs = rdkit_descriptors.drop(columns = zero_var_columns)
X = cleaned_rdkit_descs.values
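
The target standardization above, and the inverse transform applied later before computing metrics, can be sketched with the standard library. The values below are toy numbers, not taken from the dataset; `statistics.stdev` is used because it computes the sample standard deviation, which matches the pandas default for `Series.std()`.

```python
import statistics

y = [-1.0, -2.0, -3.0, -4.0, -5.0]   # toy target values

y_mean = statistics.mean(y)
y_std = statistics.stdev(y)          # sample standard deviation (ddof=1)

# Forward transform: zero mean, unit (sample) standard deviation
y_scaled = [(v - y_mean) / y_std for v in y]

# Inverse transform: applied to predictions before reporting RMSE, MAE, and R²
y_restored = [s * y_std + y_mean for s in y_scaled]

print(y_mean, y_std)
print(y_restored)
```

Because the inverse transform is linear, an RMSE or MAE measured in scaled units equals the value in original units divided by `y_std`, which is why the notebook converts back before reporting.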


In [None]:
# 1️⃣6️⃣ Run repeated k-fold cross-validation training and record train/validation metrics per fold

fold_results = []

for repeat, fold, train_loader, val_loader in run_repeated_kfold_cv(
    df_data,
    X,
    smiles_col=smiles_col,
    target_col=target_col,
    n_splits=5,
    n_repeats=5
):
    print(f"\n===== Repeat {repeat + 1} | Fold {fold + 1} =====")

    model = DA_EAGNN(
        node_dim=11,
        desc_dim=X.shape[1],
        edge_dim=6,
        hidden_dim=128,
        dropout=0.1
    ).to(device)

    trainer = Trainer(
        model=model,
        device=device,
        lr=1e-3,
        patience=20,
        max_epochs=150
    )

    model, best_val_loss = trainer.fit(
        train_loader,
        val_loader
    )

    _, train_preds, train_targets = trainer.test(train_loader)
    _, val_preds, val_targets = trainer.test(val_loader)

    train_preds_true = (train_preds * y_std + y_mean).numpy()
    train_targets_true = (train_targets * y_std + y_mean).numpy()

    val_preds_true = (val_preds * y_std + y_mean).numpy()
    val_targets_true = (val_targets * y_std + y_mean).numpy()

    # scikit-learn metrics take (y_true, y_pred); R² is not symmetric, so the order matters
    train_rmse = np.sqrt(mean_squared_error(
        train_targets_true,
        train_preds_true
    ))

    train_r2 = r2_score(
        train_targets_true,
        train_preds_true
    )

    train_mae = mean_absolute_error(
        train_targets_true,
        train_preds_true
    )

    val_rmse = np.sqrt(mean_squared_error(
        val_targets_true,
        val_preds_true
    ))

    val_r2 = r2_score(
        val_targets_true,
        val_preds_true
    )

    val_mae = mean_absolute_error(
        val_targets_true,
        val_preds_true
    )

    fold_results.append({
        "repeat": repeat + 1,
        "fold": fold + 1,
        "best_train_rmse": train_rmse,
        "best_train_r2": train_r2,
        "best_train_mae": train_mae,
        "best_val_rmse": val_rmse,
        "best_val_r2": val_r2,
        "best_val_mae": val_mae
    })

    print(
        f"Repeat {repeat + 1} | Fold {fold + 1} | "
        f"Best Val Loss: {best_val_loss:.4f}"
    )

fold_results_df = pd.DataFrame(fold_results)
fold_results_df.to_csv("Fold results descriptor augmented edge aware GNN.csv", index=False)
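
One detail worth keeping in mind when reading the R² columns: unlike RMSE and MAE, R² is not symmetric in its arguments, and scikit-learn's `r2_score` expects `(y_true, y_pred)` in that order. The hand-rolled `r2` function and the toy numbers below are assumptions used only to show the asymmetry.

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot, with SS_tot around mean(y_true)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

y_true = [1.0, 2.0, 3.0, 4.0]   # toy targets
y_pred = [2.0, 2.0, 3.0, 3.0]   # toy predictions with a compressed range

r2_correct = r2(y_true, y_pred)   # normalized by the variance of the targets
r2_swapped = r2(y_pred, y_true)   # normalized by the variance of the predictions
print(r2_correct, r2_swapped)
```

Here the same pair of vectors gives about 0.6 in the intended order and -1.0 when swapped, because the residual sum is unchanged while the normalizing variance differs.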


In [None]:
# 1️⃣7️⃣ Train an ensemble of descriptor-augmented edge-aware GNN models using a scaffold-based split and report train/validation metrics

def train_ensemble_scaffold(
    df,
    X,
    split_col,
    device,
    smiles_col=smiles_col,
    target_col=target_col,
    y_mean=y_mean,
    y_std=y_std,
    batch_size=32,
    num_models=10,
    seed_start=42
):
    train_ind = split_col[split_col == "Train"].index
    val_ind = split_col[split_col != "Train"].index

    # train_ind, val_ind = train_test_split(list(range(1128)), test_size=0.2)

    train_df = df.loc[train_ind]
    train_X = X[train_ind, :]
    train_y_true = (df[target_col][train_ind] * y_std + y_mean).to_numpy()

    val_df = df.loc[val_ind]
    val_X = X[val_ind, :]
    val_y_true = (df[target_col][val_ind] * y_std + y_mean).to_numpy()

    # ---- Train dataset (fit scaler) ----
    train_dataset = MoleculeDataset(
        train_df,
        train_X,
        smiles_col,
        target_col,
        fit_scaler=True
    )

    # ---- Validation dataset (reuse scaler) ----
    val_dataset = MoleculeDataset(
        val_df,
        val_X,
        smiles_col,
        target_col,
        scaler=train_dataset.scaler
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    trained_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )


    fold_results = []

    for i in range(num_models):
        print(f"\n===== Ensemble model {i+1}/{num_models} =====")
        set_seed(seed_start + i)

        model = DA_EAGNN(
            node_dim=11,
            desc_dim=X.shape[1],
            edge_dim=6,
            hidden_dim=128,
            dropout=0.1
        ).to(device)

        trainer = Trainer(
            model=model,
            device=device,
            lr=1e-3,
            patience=20,
            max_epochs=150
        )

        model, best_val_loss = trainer.fit(
            train_loader,
            val_loader
        )

        _, train_preds, _ = trainer.test(trained_loader)
        _, val_preds, _ = trainer.test(val_loader)

        train_preds_true = (train_preds * y_std + y_mean).numpy()

        val_preds_true = (val_preds * y_std + y_mean).numpy()

        # scikit-learn metrics take (y_true, y_pred); R² is not symmetric, so the order matters
        train_rmse = np.sqrt(mean_squared_error(
            train_y_true,
            train_preds_true
        ))

        train_r2 = r2_score(
            train_y_true,
            train_preds_true
        )

        train_mae = mean_absolute_error(
            train_y_true,
            train_preds_true
        )

        val_rmse = np.sqrt(mean_squared_error(
            val_y_true,
            val_preds_true
        ))

        val_r2 = r2_score(
            val_y_true,
            val_preds_true
        )

        val_mae = mean_absolute_error(
            val_y_true,
            val_preds_true
        )

        fold_results.append({
            "repeat": i + 1,
            "best_train_rmse": train_rmse,
            "best_train_r2": train_r2,
            "best_train_mae": train_mae,
            "best_val_rmse": val_rmse,
            "best_val_r2": val_r2,
            "best_val_mae": val_mae
        })

        print(
            f"Ensemble {i + 1} | "
            f"Best Val Loss: {best_val_loss:.4f}"
        )


    return fold_results


In [None]:
# 1️⃣8️⃣ Run scaffold-based ensemble training for the descriptor-augmented edge-aware GNN and save per-model performance metrics

results = train_ensemble_scaffold(
            df_data,
            X,
            df['BM-Scaffold'],
            device,
            smiles_col=smiles_col,
            target_col=target_col,
            y_mean=y_mean,
            y_std=y_std
        )
results_df = pd.DataFrame(results)
results_df.to_csv("Ensemble results descriptor augmented edge aware GNN Scaffold.csv", index=False)
