## Notebook 4: State-of-the-Art Model (AttentiveFP)

Implement, train, and evaluate a Graph Isomorphism Network (GIN), a powerful and robust GNN architecture. Our aim is to surpass the RandomForest baseline and achieve an ROC AUC score greater than 0.81.


### Setup

In [5]:
import pandas as pd
import numpy as np
import ast
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from tqdm.notebook import tqdm

# RDKit for chemoinformatics
from rdkit import Chem

# PyTorch Geometric
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_add_pool

# Scikit-learn for evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

print("Libraries imported successfully.")

Libraries imported successfully.


### Load and Prepare Data

In [6]:
try:
    df = pd.read_csv('data/processed/dili_data_clean.csv')
    df.dropna(subset=['fingerprint', 'smiles'], inplace=True)
    print("Processed data loaded successfully.")
except FileNotFoundError:
    print("Error: dili_data_clean.csv not found.")

# --- Graph Conversion Functions ---
def get_atom_features(atom):
    features = []
    features.append(atom.GetAtomicNum())
    features.append(atom.GetDegree())
    features.append(atom.GetFormalCharge())
    features.append(int(atom.GetHybridization()))
    features.append(atom.GetIsAromatic())
    return features

def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None

    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features, dtype=torch.float)

    if mol.GetNumBonds() > 0:
        edge_indices = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices.extend([(i, j), (j, i)])
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

# --- Create Graph Dataset ---
print("Converting SMILES to graph objects...")
data_list = [smiles_to_graph(s) for s in tqdm(df['smiles'])]

successful_indices = [i for i, d in enumerate(data_list) if d is not None]
data_list = [data_list[i] for i in successful_indices]
labels = df['dili_concern'].iloc[successful_indices].values

for i, data in enumerate(data_list):
    data.y = torch.tensor([labels[i]], dtype=torch.float)

print(f"Successfully created {len(data_list)} graph objects.")

Processed data loaded successfully.
Converting SMILES to graph objects...


  0%|          | 0/907 [00:00<?, ?it/s]



Successfully created 907 graph objects.


### Create Train and Test Sets

In [7]:
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42, stratify=labels)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

print(f"Number of training graphs: {len(train_data)}")
print(f"Number of testing graphs: {len(test_data)}")

Number of training graphs: 725
Number of testing graphs: 182




### Define the GIN Model

Define a GIN with a simple Multi-Layer Perceptron (MLP) as its internal network.


In [8]:
num_node_features = data_list[0].x.shape[1]

class GIN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GIN, self).__init__()

        # Define the MLP for the GIN convolution
        mlp1 = Sequential(Linear(num_node_features, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(mlp1)

        mlp2 = Sequential(Linear(hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv2 = GINConv(mlp2)

        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()

        # Use global add pooling to get a graph-level representation
        x = global_add_pool(x, batch)

        x = self.lin(x)
        return x

model = GIN(hidden_channels=64)
print("GIN Model defined:")
print(model)

GIN Model defined:
GIN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (conv2): GINConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (lin): Linear(in_features=64, out_features=1, bias=True)
)


### Train the Model

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Calculate class weights
neg_count = np.sum(labels == 0)
pos_count = np.sum(labels == 1)
pos_weight_value = neg_count / pos_count if pos_count > 0 else 1
pos_weight_tensor = torch.tensor([pos_weight_value], dtype=torch.float)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

print("Starting GIN training...")
for epoch in range(1, 101): # Train for 100 epochs
    loss = train()
    scheduler.step()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
print("Training finished.")

Starting GIN training...
Epoch: 10, Loss: 0.3584
Epoch: 20, Loss: 0.3541
Epoch: 30, Loss: 0.3436
Epoch: 40, Loss: 0.3403
Epoch: 50, Loss: 0.3428
Epoch: 60, Loss: 0.3316
Epoch: 70, Loss: 0.3214
Epoch: 80, Loss: 0.3213
Epoch: 90, Loss: 0.3194
Epoch: 100, Loss: 0.3162
Training finished.


### Evaluate the Model

In [10]:
def test(loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.batch)
            preds = (torch.sigmoid(out) > 0.5).float()
            all_preds.extend(preds.view(-1).tolist())
            all_labels.extend(data.y.view(-1).tolist())
    return np.array(all_preds), np.array(all_labels)

y_pred, y_true = test(test_loader)

# Calculate metrics
gin_accuracy = accuracy_score(y_true, y_pred)
gin_roc_auc = roc_auc_score(y_true, y_pred)

### Compare Results and Conclude

In [11]:
print("--- GIN Model Performance ---")
print(f"Accuracy: {gin_accuracy:.3f}")
print(f"ROC AUC:  {gin_roc_auc:.3f}")

print("\n--- Comparison ---")
print("Metric         | RandomForest (Baseline) | GIN Model")
print("----------------|-------------------------|-----------")
rf_roc_auc = 0.761
print(f"ROC AUC       | {rf_roc_auc:.3f}                   | {gin_roc_auc:.3f}")

--- GIN Model Performance ---
Accuracy: 0.665
ROC AUC:  0.648

--- Comparison ---
Metric         | RandomForest (Baseline) | GIN Model
----------------|-------------------------|-----------
ROC AUC       | 0.761                   | 0.648
