In [1]:
# Regular EDA (exploratory data analysis) and plotting libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import warnings
import rdkit
from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdmolops
from rdkit.Chem import Descriptors, AllChem, rdMolDescriptors, rdCoordGen
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG=True
%matplotlib inline 

# ML models
import torch
from torch.nn import Linear, Dropout
import torch.nn as nn
import torch.optim as optim
import torch_geometric
from torch.optim.lr_scheduler import StepLR
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import GCNConv, TopKPooling, global_max_pool, global_mean_pool, radius_graph
from torch_geometric.nn import GINConv, GATConv, GraphConv
from torch_geometric.nn import GlobalAttention
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from torch_geometric.utils import to_networkx
from torch_geometric.utils import from_smiles
import torch.nn.functional as F
from e3nn.o3 import Irreps
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import Irreps
from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork, NetworkForAGraphWithAttributes

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score

In [2]:
df = pd.read_csv('BBBP.csv')
df = df.sample(n=2000, random_state=42).reset_index(drop=True)
df.shape

(2000, 4)

In [3]:
df

Unnamed: 0,num,name,p_np,smiles
0,1812,doliracetam,1,C1=CC=CC2=C1C(C(N2CC(N)=O)=O)C3=CC=CC=C3
1,696,methamphetamine,1,CN[C@@H](C)Cc1ccccc1
2,908,quinacillin,0,CC1(C)S[C@@H]2[C@H](NC(=O)c3nc4ccccc4nc3C(O)=O...
3,546,GR94839_L,0,c1(CC(N2[C@H](CN(CC2)C(=O)C)C[N@]2CC[C@H](O)C2...
4,1851,fluradoline,1,C1=C(F)C=CC3=C1C=C(SCCNC)C2=CC=CC=C2O3
...,...,...,...,...
1995,1532,rivastigmine,1,[C@H](C1=CC(=CC=C1)OC(N(CC)C)=O)(N(C)C)C
1996,833,fenbenicillin,0,CC1(C)S[C@@H]2[C@H](NC(=O)C(Oc3ccccc3)c4ccccc4...
1997,1271,formocortal,1,[C@]12(OC(O[C@@H]1CC3C2(CC(O)[C@@]4(F)C3CC(=C5...
1998,1594,tepirindole,1,C3=C(C1=CCN(CCC)CC1)C2=CC(=CC=C2[NH]3)Cl


In [4]:
def smiles_to_e3nn_graph(smiles, target):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Add hydrogens (important for 3D geometry)
    mol = Chem.AddHs(mol)

    # Generate 3D conformer with ETKDG
    if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) != 0:
        return None

    # UFF optimization
    try:
        AllChem.UFFOptimizeMolecule(mol)
    except:
        return None

    try:
        conf = mol.GetConformer()
    except:
        return None

    # Atom features and positions
    pos = []
    node_input = []
    node_attr = []
    for atom in mol.GetAtoms():
        p = conf.GetAtomPosition(atom.GetIdx())
        pos.append([p.x, p.y, p.z])

        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetMass(),
            int(atom.GetChiralTag()),
            atom.GetImplicitValence()
        ]
        node_input.append(features)
        node_attr.append([1.0])  # Placeholder scalar attr


    # Bond features and edge index
    edges = []
    edge_attr = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.extend([[i, j], [j, i]])

        bond_feats = [
            int(bond.GetBondTypeAsDouble()),  # bond type as float
            int(bond.IsInRing()),             # in-ring flag
            int(bond.GetIsConjugated())       # conjugation
        ]
        edge_attr.extend([bond_feats, bond_feats])

    try:
        data = Data(
            pos=torch.tensor(pos, dtype=torch.float),
            edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
            node_input=torch.tensor(node_input, dtype=torch.float),
            node_attr=torch.tensor(node_attr, dtype=torch.float),
            edge_attr=torch.tensor(edge_attr, dtype=torch.float),
            y=torch.tensor([target], dtype=torch.float),
            smiles=smiles  # optional: keep for eval
        )
    except:
        return None

    return data

In [5]:
RDLogger.DisableLog('rdApp.warning')

# Create all graphs with clear argument names
graph_list = [
    smiles_to_e3nn_graph(
        smiles=s,
        target=t)
    
    for s, t in zip(df['smiles'], df['p_np'])
]

# Filter out failed conformer generations
graph_list = [g for g in graph_list if g is not None]

[16:53:43] Explicit valence for atom # 1 N, 4, is greater than permitted
[16:53:56] Explicit valence for atom # 11 N, 4, is greater than permitted
[16:54:53] Explicit valence for atom # 5 N, 4, is greater than permitted
[16:55:03] UFFTYPER: Unrecognized atom type: Ca+2 (0)
[16:55:03] UFFTYPER: Unrecognized atom type: Ca+2 (0)
[16:55:11] UFFTYPER: Unrecognized charge state for atom: 16
[16:55:11] UFFTYPER: Unrecognized charge state for atom: 16
[16:55:38] Explicit valence for atom # 6 N, 4, is greater than permitted
[16:55:53] Explicit valence for atom # 5 N, 4, is greater than permitted
[16:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[16:56:15] Explicit valence for atom # 5 N, 4, is greater than permitted
[16:56:18] Explicit valence for atom # 12 N, 4, is greater than permitted
[16:56:22] Explicit valence for atom # 6 N, 4, is greater than permitted
[16:56:22] Explicit valence for atom # 5 N, 4, is greater than permitted
[16:56:29] Explicit valence for atom # 5

In [6]:
# Define the ratio for splitting the dataset (80% for training, 20% for validation)
train_ratio = 0.80

# Calculate the total number of samples in the dataset
dataset_size = len(graph_list)

# Calculate the number of samples for the training and validation sets
train_size = int(train_ratio * dataset_size)
val_size = dataset_size - train_size

# Set a random seed for reproducibility
random_seed = 66
generator = torch.Generator().manual_seed(random_seed)

# Split the dataset into training and validation subsets
train_dataset, val_dataset = random_split(graph_list, [train_size, val_size], generator=generator)

In [7]:
# Create the DataLoaders for the train and val sets
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [8]:
class EGNN(nn.Module):
    """
    Equivariant Graph Neural Network Classifier using e3nn and MLP head.
    Designed for binary classification with BCEWithLogitsLoss.
    """

    def __init__(self,
                 radius=3.0,
                 num_neighbors=12,
                 num_nodes=30,
                 mul=50,
                 layers=3,
                 lmax=2,
                 pool_nodes=False,
                 hidden_dim1=64,
                 hidden_dim2=32,
                 dropout=0.7):
        """
        Initialize the E3NNGraphClassifier.

        Parameters:
            radius (float): Max graph radius for neighbor search.
            num_neighbors (int): Max number of neighbors per node.
            num_nodes (int): Max number of nodes per graph.
            mul (int): Multiplicity for equivariant layers.
            layers (int): Number of equivariant layers.
            lmax (int): Maximum angular momentum.
            pool_nodes (bool): Whether to apply node pooling in equivariant network.
            hidden_dim1 (int): First hidden layer dimension in MLP head.
            hidden_dim2 (int): Second hidden layer dimension in MLP head.
            dropout (float): Dropout probability for MLP head.
        """
        super().__init__()

        # Define irreducible representations (irreps)
        self.irreps_node_input = Irreps("9x0e")      # 9 scalar features per node
        self.irreps_node_attr  = Irreps("0e")        # Dummy node attribute (scalar)
        self.irreps_edge_attr  = Irreps("3x0e")      # Optional edge attributes (3 scalars)
        self.irreps_node_output = Irreps("1x0e")     # Scalar output from equivariant network

        # Equivariant GNN Backbone
        self.network = NetworkForAGraphWithAttributes(
            irreps_node_input=self.irreps_node_input,
            irreps_node_attr=self.irreps_node_attr,
            irreps_edge_attr=self.irreps_edge_attr,
            irreps_node_output=self.irreps_node_output,
            max_radius=radius,
            num_neighbors=num_neighbors,
            num_nodes=num_nodes,
            mul=mul,
            layers=layers,
            lmax=lmax,
            pool_nodes=pool_nodes
        )

        # MLP Head for graph-level classification
        self.mlp = nn.Sequential(
            nn.Linear(self.irreps_node_output.dim * 2, hidden_dim1),  # Combine max & mean pooling
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim2, 1)  # Single output logit for binary classification
        )

    def forward(self, data):
        """
        Forward pass of the classifier.

        Parameters:
            data (torch_geometric.data.Batch): Batch of graph data.

        Returns:
            logits (torch.Tensor): Raw logits for each graph in the batch.
        """
        # Pass node features through equivariant graph network
        node_out = self.network(data)

        # Global pooling: concatenate max and mean pooled features across nodes
        pooled = torch.cat([
            global_max_pool(node_out, data.batch),
            global_mean_pool(node_out, data.batch)
        ], dim=1)

        # Pass pooled features through MLP head to obtain logits
        logits = self.mlp(pooled).squeeze(-1)

        return logits  # Raw logits (BCEWithLogitsLoss expects logits, not probabilities)


In [9]:
# Set device to GPU if available, else fallback to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the EGNN Graph Classifier with custom hyperparameters
model = EGNN(
    radius=3.0,           # Max radius for neighbor search
    num_neighbors=16,     # Max neighbors per node
    num_nodes=30,         # Max nodes per graph
    mul=50,              # Layer multiplicity in equivariant network
    layers=3,             # Number of equivariant layers
    lmax=2,               # Max angular momentum
    hidden_dim1=64,      # First hidden layer size (MLP)
    hidden_dim2=32,       # Second hidden layer size (MLP)
    dropout=0.7           # Dropout rate in MLP
).to(device)

# Optimizer: Adam for parameter updates
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

# Learning rate scheduler: halve LR every 10 epochs
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# Binary classification loss using raw logits
criterion = nn.BCEWithLogitsLoss()

In [10]:
def train(loader):
    model.train()
    total_loss = total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = model(batch)
        labels = batch.y.view(-1).float()  # Ensure float type for BCE loss

        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()

        total_loss += float(loss) * batch.num_graphs
        total_samples += batch.num_graphs

    return total_loss / total_samples


@torch.no_grad()
def val(loader):
    model.eval()
    total_loss = total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        labels = batch.y.view(-1).float()

        loss = criterion(logits, labels)
        total_loss += float(loss) * batch.num_graphs
        total_samples += batch.num_graphs

    return total_loss / total_samples

In [11]:
warnings.filterwarnings("ignore", category=UserWarning)


# Initialize lists to store training and val scores (RMSE values)
score_train = []
score_val = []

# Set the number of epochs for training
epochs = 25

# Loop over each epoch for training
for epoch in range(epochs):
    
    # Train the model and get the training RMSE
    train_rmse = train(train_loader)

    # Adjust learning rate
    scheduler.step()

    # Validate the model and get the val RMSE
    val_rmse = val(val_loader)

    # Append the training RMSE and val RMSE to their respective lists
    score_train.append(train_rmse)
    score_val.append(val_rmse)

    # Print the progress of the training process (epoch number, train loss, val loss)
    print(f'Epoch: {epoch+1}/{epochs} | Train Loss: {train_rmse:.4f}, '
          f'Validation Loss: {val_rmse:.4f}')

Epoch: 1/25 | Train Loss: 0.5918, Validation Loss: 0.5922
Epoch: 2/25 | Train Loss: 0.5818, Validation Loss: 0.5604
Epoch: 3/25 | Train Loss: 0.5569, Validation Loss: 0.5578
Epoch: 4/25 | Train Loss: 0.5786, Validation Loss: 0.5498
Epoch: 5/25 | Train Loss: 0.5730, Validation Loss: 0.5493
Epoch: 6/25 | Train Loss: 0.5713, Validation Loss: 0.5476
Epoch: 7/25 | Train Loss: 0.5496, Validation Loss: 0.5430
Epoch: 8/25 | Train Loss: 0.5392, Validation Loss: 0.5137
Epoch: 9/25 | Train Loss: 0.5320, Validation Loss: 0.4804
Epoch: 10/25 | Train Loss: 0.5244, Validation Loss: 0.4883
Epoch: 11/25 | Train Loss: 0.5017, Validation Loss: 0.4670
Epoch: 12/25 | Train Loss: 0.4960, Validation Loss: 0.4802
Epoch: 13/25 | Train Loss: 0.5119, Validation Loss: 0.4829
Epoch: 14/25 | Train Loss: 0.4777, Validation Loss: 0.4678
Epoch: 15/25 | Train Loss: 0.4883, Validation Loss: 0.4675
Epoch: 16/25 | Train Loss: 0.4801, Validation Loss: 0.5171
Epoch: 17/25 | Train Loss: 0.4648, Validation Loss: 0.4450
Epoch:

In [12]:
@torch.no_grad()
def eval_model(loader):
    model.eval()
    preds, trues, smi = [], [], []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits)   # Convert logits to probabilities
        preds.append(probs.cpu())
        trues.append(batch.y.view(-1).cpu())
        smi.extend(batch.smiles)

    preds = torch.cat(preds).numpy()
    trues = torch.cat(trues).numpy()

    df_result = pd.DataFrame({
        'smiles': smi,
        'actual': trues,
        'probability': preds,
        'pred': (preds >= 0.5).astype(int)
    })

    return df_result

In [13]:
val_res = eval_model(val_loader)

In [14]:
def evaluate_binary(y_true, y_prob):
    y_pred = (y_prob >= 0.5).astype(int)
    
    acc = accuracy_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1-score : {f1:.4f}")

    return {
        "accuracy": round(acc, 4),
        "roc_auc": round(auc, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1_score": round(f1, 4)
    }

In [15]:
evaluate_binary(val_res['actual'], val_res['pred']);

Accuracy : 0.8186
ROC-AUC  : 0.6463
Precision: 0.8149
Recall   : 0.9833
F1-score : 0.8912
