## Notebook 3: Graph Neural Network Model

To build, train, and evaluate a Graph Neural Network (GNN) to predict DILI. We will compare its performance against our RandomForest baseline to see if learning from the molecular graph structure provides a significant advantage.

### Setup

In [10]:
import pandas as pd
import numpy as np
import ast

# RDKit for chemoinformatics
from rdkit import Chem
from rdkit.Chem import AllChem

# PyTorch and PyTorch Geometric
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
import torch.nn.functional as F

# Scikit-learn for our final model and evaluation
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, roc_auc_score

from tqdm.notebook import tqdm

print("Libraries imported successfully.")

Libraries imported successfully.


### Load the Processed Data

In [11]:
try:
    df = pd.read_csv('data/processed/dili_data_clean.csv')
    df.dropna(subset=['fingerprint', 'smiles'], inplace=True)
    print("Processed data loaded successfully.")
    print(f"Shape of the dataset: {df.shape}")
except FileNotFoundError:
    print("Error: dili_data_clean.csv not found.")

Processed data loaded successfully.
Shape of the dataset: (907, 5)


### Convert Molecules to Graph Representations

We need to convert each SMILES string into a graph format that PyG can understand. A graph is defined by:
- **Nodes (Atoms):** We'll create a feature vector for each atom.
- **Edges (Bonds):** We'll define the connections between atoms.

In [12]:
def get_atom_features(atom):
    features = []
    features.append(atom.GetAtomicNum())
    features.append(atom.GetDegree())
    features.append(atom.GetFormalCharge())
    # Convert enum to a numerical value
    features.append(int(atom.GetHybridization()))
    features.append(atom.GetIsAromatic())
    return features

def get_bond_features(bond):
    bond_type = bond.GetBondType()
    return [
        bond_type == Chem.rdchem.BondType.SINGLE,
        bond_type == Chem.rdchem.BondType.DOUBLE,
        bond_type == Chem.rdchem.BondType.TRIPLE,
        bond_type == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing(),
    ]

def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None
    
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features, dtype=torch.float)

    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_feats = get_bond_features(bond)
        edge_indices.extend([(i, j), (j, i)])
        edge_attrs.extend([bond_feats, bond_feats])

    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Create a list of graph objects
print("Converting SMILES to graph objects...")
data_list = [smiles_to_graph(s) for s in tqdm(df['smiles'])]

successful_indices = [i for i, d in enumerate(data_list) if d is not None]
data_list = [data_list[i] for i in successful_indices]
df_model = df.iloc[successful_indices].copy()
labels = df_model['dili_concern'].values

# Assign labels to each graph object
for i, data in enumerate(data_list):
    data.y = torch.tensor([labels[i]], dtype=torch.float)

print(f"Successfully created {len(data_list)} graph objects.")

Converting SMILES to graph objects...


  0%|          | 0/907 [00:00<?, ?it/s]



Successfully created 907 graph objects.


### Create Train and Test Sets

We'll split our list of graphs into a training set and a testing set.


In [13]:
# Assign labels to each graph object
for i, data in enumerate(data_list):
    data.y = torch.tensor([labels[i]], dtype=torch.float)

# Split the data
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42, stratify=labels)

print(f"Number of training graphs: {len(train_data)}")
print(f"Number of testing graphs: {len(test_data)}")

# Create DataLoaders to handle batching
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

Number of training graphs: 725
Number of testing graphs: 182




### Define the GNN Model Architecture

In [14]:
class GATClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden_channels):
        super(GATClassifier, self).__init__()
        self.conv1 = GATConv(num_node_features, hidden_channels, heads=2, edge_dim=num_edge_features)
        self.conv2 = GATConv(hidden_channels * 2, hidden_channels, heads=1, edge_dim=num_edge_features)
        self.lin1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def extract_embedding(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.conv1(x, edge_index, edge_attr)
        x = F.elu(x)
        x = self.conv2(x, edge_index, edge_attr)
        embedding = global_mean_pool(x, batch)
        return embedding

    def forward(self, data):
        embedding = self.extract_embedding(data)
        x = F.relu(self.lin1(embedding))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

# Initialize the model
num_node_features = data_list[0].x.shape[1]
num_edge_features = data_list[0].edge_attr.shape[1]
gnn_model = GATClassifier(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    hidden_channels=64
)
print("GNN Classifier defined.")

GNN Classifier defined.


### Train the GNN to Create Good Features

In [15]:
# Split data for GNN training
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42, stratify=labels)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Setup optimizer and loss
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
neg_count = np.sum(labels == 0)
pos_count = np.sum(labels == 1)
pos_weight_value = neg_count / pos_count if pos_count > 0 else 1
pos_weight_tensor = torch.tensor([pos_weight_value], dtype=torch.float)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

# Training loop
print("Training GNN feature extractor...")
for epoch in range(1, 101):
    gnn_model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = gnn_model(data)
        loss = criterion(out, data.y.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:02d}, GNN Training Loss: {total_loss / len(train_loader):.4f}')
print("GNN training finished.")



Training GNN feature extractor...
Epoch: 20, GNN Training Loss: 0.3577
Epoch: 40, GNN Training Loss: 0.3485
Epoch: 60, GNN Training Loss: 0.3424
Epoch: 80, GNN Training Loss: 0.3376
Epoch: 100, GNN Training Loss: 0.3336
GNN training finished.


### Generate Graph Embeddings

In [16]:
full_loader = DataLoader(data_list, batch_size=32, shuffle=False)
gnn_model.eval()
embeddings = []
with torch.no_grad():
    for data in tqdm(full_loader, desc="Generating Trained GNN Embeddings"):
        embedding = gnn_model.extract_embedding(data)
        embeddings.append(embedding.cpu().numpy())

gnn_features = np.concatenate(embeddings, axis=0)
print(f"Shape of trained GNN features: {gnn_features.shape}")

Generating Trained GNN Embeddings:   0%|          | 0/29 [00:00<?, ?it/s]

Shape of trained GNN features: (907, 64)


### Create and Evaluate the Final Hybrid Model

In [17]:
fingerprints = np.array(df_model['fingerprint'].apply(ast.literal_eval).tolist())
X_hybrid = np.concatenate([fingerprints, gnn_features], axis=1)
y = labels

X_train, X_test, y_train, y_test = train_test_split(X_hybrid, y, test_size=0.2, random_state=42, stratify=y)

# Define a more extensive parameter grid to search
param_grid = {
    'n_estimators': [100, 200, 300, 400],
    'max_depth': [10, 20, 30, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2']
}

# Initialize the RandomForest model
rf = RandomForestClassifier(class_weight='balanced', random_state=42)

# Initialize GridSearchCV
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=3, n_jobs=-1, verbose=2, scoring='roc_auc')

# Fit GridSearchCV
print("\nStarting extensive GridSearchCV to find best hyperparameters...")
grid_search.fit(X_train, y_train)

print("\nBest parameters found: ", grid_search.best_params_)

# Use the best model found by the grid search to make predictions
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)

# Calculate final metrics
hybrid_accuracy = accuracy_score(y_test, y_pred)
hybrid_roc_auc = roc_auc_score(y_test, y_pred)


Starting extensive GridSearchCV to find best hyperparameters...
Fitting 3 folds for each of 288 candidates, totalling 864 fits
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100; total time=   2.5s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100; total time=   2.5s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100; total time=   2.5s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=200; total time=   3.9s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=200; total time=   2.5s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=200; total time=   2.6s
[CV] END max_depth=10, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=300; total time=   3.9s
[CV] END max_depth=10, max_featu

### Compare to Baseline and Conclude

In [None]:
print("--- Hybrid Model Performance ---")
print(f"Accuracy: {hybrid_accuracy:.3f}")
print(f"ROC AUC:  {hybrid_roc_auc:.3f}")

print("\n--- Final Comparison ---")
print("Metric         | RandomForest (Baseline) | GNN-Only | Hybrid Model")
print("----------------|-------------------------|----------|--------------")
rf_accuracy = 0.768 
rf_roc_auc = 0.761
gnn_roc_auc = 0.609 
print(f"ROC AUC       | {rf_roc_auc:.3f}                   | {gnn_roc_auc:.3f}    | {hybrid_roc_auc:.3f}")

--- Hybrid Model Performance ---
Accuracy: 0.764
ROC AUC:  0.677

--- Final Comparison ---
Metric         | RandomForest (Baseline) | GNN-Only | Hybrid Model
----------------|-------------------------|----------|--------------
ROC AUC       | 0.761                   | 0.609    | 0.677
