In [1]:
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.utils import subgraph
import torch.optim as optim
from torch.optim.lr_scheduler import *
from sklearn.model_selection import train_test_split
import random
from rdkit import Chem
import itertools
from model import DumplingGNN
import csv
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, matthews_corrcoef,
    precision_recall_curve, auc
)


In [2]:
# Load the data: open the SDF file with RDKit's SDMolSupplier. removeHs=False
# is passed so that hydrogens present in the file are not stripped.
supplier = Chem.SDMolSupplier('train_chembl.sdf', removeHs=False)

# Graph objects built from the molecules are collected in this list.
data_list = []

def mol_to_graph_with_coords(mol):
    """Convert an RDKit molecule with 3D coordinates into a PyG ``Data`` graph.

    Each atom becomes a node with 8 features: atomic number, degree, total
    hydrogen count, implicit valence, aromatic flag, followed by the x/y/z
    position taken from the molecule's default conformer. Each bond is stored
    as two directed edges (one per direction).

    Args:
        mol: molecule exposing the RDKit ``Mol`` accessors used below. It must
            carry a conformer and a ``"label"`` property parseable as a float.

    Returns:
        A ``Data`` object with ``x`` (float node features), ``edge_index``
        (long, shape ``[2, 2 * num_bonds]``; ``[2, 0]`` when the molecule has
        no bonds) and ``y`` (float tensor holding the single label).

    Raises:
        Whatever ``mol.GetConformer()`` / ``mol.GetProp("label")`` raise when
        the conformer or the property is missing, and ``ValueError`` when the
        label is not a valid float.
    """
    # The conformer is the same for every atom, so fetch it once.
    conformer = mol.GetConformer()

    # Extract atom (node) features.
    atom_features = []
    for atom in mol.GetAtoms():
        position = conformer.GetAtomPosition(atom.GetIdx())
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetImplicitValence(),
            atom.GetIsAromatic(),
        ] + list(position))

    # Extract the edge index: every bond is added in both directions.
    edge_indices = []
    for bond in mol.GetBonds():
        start_atom = bond.GetBeginAtomIdx()
        end_atom = bond.GetEndAtomIdx()
        edge_indices.append([start_atom, end_atom])
        edge_indices.append([end_atom, start_atom])

    # Convert features and indices to PyTorch tensors.
    x = torch.tensor(atom_features, dtype=torch.float)
    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    else:
        # Without bonds, torch.tensor([]) is 1-D (shape [0]) and .t() leaves it
        # that way; build the empty [2, 0] edge index explicitly instead.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Get the label.
    label = mol.GetProp("label")
    y = torch.tensor([float(label)], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, y=y)

In [3]:
# Rebuild the graph list from scratch so that re-running this cell does not
# append a second copy of every molecule (duplicates would end up on both
# sides of the train/test split). Entries the supplier yields as None are
# skipped.
data_list = [
    mol_to_graph_with_coords(mol)
    for mol in supplier
    if mol is not None
]


# 80/20 split; the fixed random_state keeps the partition identical between runs.
train_data_list, test_data_list = train_test_split(data_list, test_size=0.2, random_state=42)

# Training batches are shuffled; the test loader yields one graph per batch.
train_loader = DataLoader(train_data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data_list, batch_size=1)


In [4]:
# Seed PyTorch's global RNG before building the model so that weight
# initialisation, and later draws from the same RNG such as batch shuffling,
# should repeat across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Initialising the model
model = DumplingGNN(hidden_channels=32)
# The class name is reused for the CSV log and checkpoint file names.
model_name = model.__class__.__name__
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
# CrossEntropyLoss expects raw class logits and integer class targets.
criterion = torch.nn.CrossEntropyLoss()

In [5]:
def train(loader):
    """Run one optimisation pass over ``loader`` with the global model.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``. Every
    batch is fed to the model, the loss against ``batch.y`` (cast to long) is
    back-propagated and one optimiser step is taken.

    Args:
        loader: sized iterable of batches exposing a ``y`` attribute.

    Returns:
        The mean of the per-batch loss values (sum divided by ``len(loader)``).
    """
    model.train()
    batch_losses = []
    for batch in loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch), batch.y.long())
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)


In [6]:


def test(loader):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    pred_labels = []
    with torch.no_grad():
        for data in loader:
            out = model(data)
            probs = torch.sigmoid(out)  
            # print(probs)
            all_preds.extend(probs[:, 1].tolist()) 
            all_labels.extend(data.y.tolist()) 

            result = out.argmax(dim=1)  
            # print(result)
            pred_labels.extend(result.tolist())
            correct += (result == data.y).sum().item()  
            total += data.y.size(0)
    
    # Print CM Matrix
    print('Confusion Matrix:')
    table = torch.zeros(2, 2, dtype=torch.int32)
    for i in range(len(all_labels)):
        table[int(all_labels[i]), int(pred_labels[i] > 0.5)] += 1
    print(table)

    return all_preds, all_labels, pred_labels
  
NUM_EPOCHS = 500  # number of passes over the training set

# Column names of the CSV log; written once, while the file is still empty.
headers = ['epoch', 'Loss', 'Accuracy', 'Sensitivity (Recall)', 'Specificity', 'MCC', 'AUC', 'F1 Score', 'Balanced Accuracy', 'PRAUC', 'PPV', 'NPV']

# Per-epoch history and best-so-far trackers.
train_losses = []
test_accuracies = []
auc_scores = []
higest_accuracy = 0
higest_auc = 0
csv_file = f'{model_name}.csv'


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# NOTE(review): checkpoints are selected on the same split that is reported as
# "test"; use a separate validation split if unbiased test metrics are needed.
for epoch in range(NUM_EPOCHS):
    train_loss = train(train_loader)
    all_preds, all_labels, pred_labels = test(test_loader)

    # cm[i, j] counts samples with true class i predicted as class j.
    cm = confusion_matrix(all_labels, pred_labels)
    acc = accuracy_score(all_labels, pred_labels)
    # zero_division=0 keeps the value the default would return (0) when no
    # positive is predicted, without emitting a warning every epoch.
    se = recall_score(all_labels, pred_labels, zero_division=0)  # Sensitivity / Recall
    sp = _safe_ratio(cm[0, 0], cm[0, 0] + cm[0, 1])  # Specificity
    mcc = matthews_corrcoef(all_labels, pred_labels)
    auc_score = roc_auc_score(all_labels, all_preds)
    f1 = f1_score(all_labels, pred_labels, zero_division=0)
    ba = (se + sp) / 2  # Balanced Accuracy
    precision, recall, _ = precision_recall_curve(all_labels, all_preds)
    prauc = auc(recall, precision)  # PR AUC
    ppv = precision_score(all_labels, pred_labels, zero_division=0)  # Positive Predictive Value
    npv = _safe_ratio(cm[0, 0], cm[0, 0] + cm[1, 0])  # Negative Predictive Value

    row = [epoch+1, train_loss, acc, se, sp, mcc, auc_score, f1, ba, prauc, ppv, npv]

    # Append mode: rows from earlier runs with the same model name are kept.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        if file.tell() == 0:
            writer.writerow(headers)
        writer.writerow(row)

    train_losses.append(train_loss)
    test_accuracies.append(acc)
    auc_scores.append(auc_score)

    # Keep one checkpoint for the best accuracy and one for the best AUC.
    if acc > higest_accuracy:
        higest_accuracy = acc
        torch.save(model.state_dict(), f'saved_best_acc{model_name}.pth')
    if auc_score > higest_auc:
        higest_auc = auc_score
        torch.save(model.state_dict(), f'saved_best_auc{model_name}.pth')

    print(f'Epoch: {epoch+1}, Loss: {train_loss:.4f}, Test Accuracy: {acc:.4f}, AUC: {auc_score:.4f}')
    print(f"Now Highest Test Accuracy: {higest_accuracy:.4f}, Highest AUC: {higest_auc:.4f}")

print('Highest Test Accuracy:', higest_accuracy)
print('Highest AUC:', higest_auc)

# Training-loss and test-accuracy curves side by side.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 5))

ax_loss.plot(train_losses, label='Train Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(test_accuracies, label='Test Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

fig.tight_layout()
plt.show()









Confusion Matrix:
tensor([[100,   0],
        [ 77,   0]], dtype=torch.int32)
Epoch: 1, Loss: 0.6665, Test Accuracy: 0.5650, AUC: 0.5745
Now Highest Test Accuracy: 0.5650, Highest AUC: 0.5745


  _warn_prf(average, modifier, msg_start, len(result))


Confusion Matrix:
tensor([[100,   0],
        [ 77,   0]], dtype=torch.int32)
Epoch: 2, Loss: 0.6598, Test Accuracy: 0.5650, AUC: 0.5875
Now Highest Test Accuracy: 0.5650, Highest AUC: 0.5875


  _warn_prf(average, modifier, msg_start, len(result))


KeyboardInterrupt: 