In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GINConv, global_add_pool
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report
from rdkit import Chem
import pandas as pd
import numpy as np

from skfp.datasets.lrgb import load_peptides_func
from skfp.model_selection import scaffold_train_test_split

dataset = load_peptides_func(as_frame=True)
smiles, toxicity = dataset["SMILES"], dataset["toxic"]

train_smiles, test_smiles, train_toxicity, test_toxicity = scaffold_train_test_split(
    smiles, toxicity, test_size=0.2
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def smiles_to_pyg_graph(smiles_str, y):
    mol = Chem.MolFromSmiles(smiles_str)
    if mol is None:
        return None

    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic())
        ])
    x = torch.tensor(atom_features, dtype=torch.float)

    edge_indices = []
    for bond in mol.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        edge_indices.append([u, v])
        edge_indices.append([v, u])

    if not edge_indices:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    y_tensor = torch.tensor([y], dtype=torch.long)

    return Data(x=x, edge_index=edge_index, y=y_tensor)


In [9]:
train_data_list = []
for i in range(len(train_smiles)):
    graph_data = smiles_to_pyg_graph(train_smiles[i], train_toxicity.iloc[i])
    if graph_data:
        train_data_list.append(graph_data)

test_data_list = []
for i in range(len(test_smiles)):
    graph_data = smiles_to_pyg_graph(test_smiles[i], test_toxicity.iloc[i])
    if graph_data:
        test_data_list.append(graph_data)



In [10]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data_list, batch_size=32, shuffle=False)

In [11]:
class GIN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes, hidden_channels):
        super(GIN, self).__init__()
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(num_node_features, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        ))
        self.conv3 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        ))

        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)

        x = global_add_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x

if train_data_list:
    num_node_features = train_data_list[0].x.shape[1]
else:
    raise ValueError("No valid training graphs were generated. Check SMILES input and RDKit conversion.")

num_classes = 2
hidden_channels = 64

model = GIN(num_node_features, num_classes, hidden_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    y_true = []
    y_pred = []
    y_pred_proba = []
    with torch.no_grad():
        for data in loader:
            out = model(data)
            pred_labels = out.argmax(dim=1)
            y_true.extend(data.y.tolist())
            y_pred.extend(pred_labels.tolist())
            y_pred_proba.extend(F.softmax(out, dim=1)[:, 1].tolist())

    accuracy = accuracy_score(y_true, y_pred)
    try:
        auroc = roc_auc_score(y_true, y_pred_proba)
    except ValueError:
        auroc = float('nan')

    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    
    return accuracy, auroc, report, y_true, y_pred

In [None]:
print("Starting GIN model training...")
num_epochs = 500
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_acc, train_auroc, _, _, _ = evaluate(train_loader)
    test_acc, test_auroc, test_report_dict, _, _ = evaluate(test_loader)
    
    if epoch % 10 == 0 or epoch == num_epochs:
        print(f'\n--- Epoch {epoch:03d} Test Metrics ---')
        print(f"Accuracy: {test_acc:.4f}, AUROC: {test_auroc:.4f}")
        print(f"Precision (Class 0): {test_report_dict['0']['precision']:.4f}, Recall (Class 0): {test_report_dict['0']['recall']:.4f}, F1 (Class 0): {test_report_dict['0']['f1-score']:.4f}")
        print(f"Precision (Class 1): {test_report_dict['1']['precision']:.4f}, Recall (Class 1): {test_report_dict['1']['recall']:.4f}, F1 (Class 1): {test_report_dict['1']['f1-score']:.4f}")
        print(f"Macro Avg F1: {test_report_dict['macro avg']['f1-score']:.4f}")
        print(f"Weighted Avg F1: {test_report_dict['weighted avg']['f1-score']:.4f}")
    
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train Acc: {train_acc:.4f}, Train AUROC: {train_auroc:.4f}, '
          f'Test Acc: {test_acc:.4f}, Test AUROC: {test_auroc:.4f}')

print("\nGIN model training complete.")
# Clean output because GIN produces a lot of logs

In [15]:
print("\nFinal Test Metrics:")
final_test_acc, final_test_auroc, final_test_report_dict, y_true_final, y_pred_final = evaluate(test_loader)
print(f"Accuracy: {final_test_acc:.4f}")
print(f"AUROC: {final_test_auroc:.4f}")
report = classification_report(y_true_final, y_pred_final, zero_division=0)
print("Final Test Classification Report:\n", report)


Final Test Metrics:
Accuracy: 0.7647
AUROC: 0.7244
Final Test Classification Report:
               precision    recall  f1-score   support

           0       0.76      0.98      0.85      2153
           1       0.84      0.29      0.43       954

    accuracy                           0.76      3107
   macro avg       0.80      0.63      0.64      3107
weighted avg       0.78      0.76      0.72      3107



In [18]:
with open("reports/gin_report.txt", "w") as f:
    f.write(report)
