## Notebook 6: Deeper GIN Model

Improve upon our previous GIN model by increasing its depth. By adding more layers, we give the model the capacity to learn more complex relationships and potentially increase its predictive power.


### Setup

In [1]:
import pandas as pd
import numpy as np
import ast
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
from tqdm.notebook import tqdm

# RDKit for chemoinformatics
from rdkit import Chem

# PyTorch Geometric
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_add_pool

# Scikit-learn for evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

print("Libraries imported successfully.")


Libraries imported successfully.


### Load and Prepare Data

In [2]:
try:
    df = pd.read_csv('data/processed/dili_data_clean.csv')
    df.dropna(subset=['fingerprint', 'smiles'], inplace=True)
    print("Processed data loaded successfully.")
except FileNotFoundError:
    print("Error: dili_data_clean.csv not found.")

# Graph Data Conversion
def get_atom_features(atom):
    features = [atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(), int(atom.GetHybridization()), atom.GetIsAromatic()]
    return features

def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features, dtype=torch.float)
    if mol.GetNumBonds() > 0:
        edge_indices = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices.extend([(i, j), (j, i)])
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

print("Converting SMILES to graph objects...")
data_list = [smiles_to_graph(s) for s in tqdm(df['smiles'])]

successful_indices = [i for i, d in enumerate(data_list) if d is not None]
data_list = [data_list[i] for i in successful_indices]
y = df['dili_concern'].iloc[successful_indices].values

for i, data in enumerate(data_list):
    data.y = torch.tensor([y[i]], dtype=torch.float)

print(f"Final aligned dataset size: {len(data_list)}")

Processed data loaded successfully.
Converting SMILES to graph objects...


  0%|          | 0/907 [00:00<?, ?it/s]



Final aligned dataset size: 907


### Create Train and Test Splits

In [3]:
indices = np.arange(len(data_list))
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42, stratify=y)

train_data = [data_list[i] for i in train_indices]
test_data = [data_list[i] for i in test_indices]
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

print("Train/test splits created.")

Train/test splits created.




### Define the Deeper GIN Model

Increased the model's depth from 2 to 4 GIN layers. We also add Batch Normalization, which helps stabilize training for deeper networks.


In [4]:
num_node_features = data_list[0].x.shape[1]

class DeeperGIN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(DeeperGIN, self).__init__()
        
        # Layer 1
        mlp1 = Sequential(Linear(num_node_features, hidden_channels), BatchNorm1d(hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(mlp1)
        
        # Layer 2
        mlp2 = Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv2 = GINConv(mlp2)
        
        # Layer 3
        mlp3 = Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(mlp3)
        
        # Layer 4
        mlp4 = Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv4 = GINConv(mlp4)

        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).relu()
        x = self.conv4(x, edge_index).relu()
        
        x = global_add_pool(x, batch)
        
        x = self.lin(x)
        return x

model = DeeperGIN(hidden_channels=64)
print("Deeper GIN Model defined:")
print(model)


Deeper GIN Model defined:
DeeperGIN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
  ))
  (conv2): GINConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
  ))
  (conv3): GINConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
  ))
  (conv4): GINConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

### Train the Deeper Model

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

neg_count = np.sum(y[train_indices] == 0)
pos_count = np.sum(y[train_indices] == 1)
pos_weight_value = neg_count / pos_count if pos_count > 0 else 1
pos_weight_tensor = torch.tensor([pos_weight_value], dtype=torch.float)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

print("Starting Deeper GIN training...")
for epoch in range(1, 101):
    loss = train()
    scheduler.step()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
print("Training finished.")


Starting Deeper GIN training...
Epoch: 10, Loss: 0.3130
Epoch: 20, Loss: 0.2773
Epoch: 30, Loss: 0.2369
Epoch: 40, Loss: 0.2063
Epoch: 50, Loss: 0.1771
Epoch: 60, Loss: 0.1566
Epoch: 70, Loss: 0.1233
Epoch: 80, Loss: 0.1136
Epoch: 90, Loss: 0.1060
Epoch: 100, Loss: 0.0974
Training finished.


### Evaluate the Deeper Model

In [6]:
def test(loader):
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.batch)
            probs = torch.sigmoid(out).view(-1)
            all_probs.extend(probs.tolist())
            all_labels.extend(data.y.view(-1).tolist())
    return np.array(all_probs), np.array(all_labels)

y_probs, y_true = test(test_loader)
y_pred = (y_probs > 0.5).astype(int)

# Calculate metrics
deep_gin_accuracy = accuracy_score(y_true, y_pred)
deep_gin_roc_auc = roc_auc_score(y_true, y_probs)

### Compare Results and Conclude

In [7]:
print("--- Deeper GIN Model Performance ---")
print(f"Accuracy: {deep_gin_accuracy:.3f}")
print(f"ROC AUC:  {deep_gin_roc_auc:.3f}")

print("\n--- Comparison ---")
print("Metric         | RandomForest (Baseline) | Deeper GIN Model")
print("----------------|-------------------------|------------------")
rf_roc_auc = 0.761 # From our previous best model
print(f"ROC AUC       | {rf_roc_auc:.3f}                   | {deep_gin_roc_auc:.3f}")


--- Deeper GIN Model Performance ---
Accuracy: 0.742
ROC AUC:  0.739

--- Comparison ---
Metric         | RandomForest (Baseline) | Deeper GIN Model
----------------|-------------------------|------------------
ROC AUC       | 0.761                   | 0.739
