In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolDescriptors
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from tqdm import tqdm

# ---------------------------
# 1. Data Preparation
# ---------------------------

def mol_to_graph_data_obj(mol):
    """Convert RDKit mol object to PyTorch Geometric Data object"""
    # Node features: atomic number one-hot encoding (simplified here as atomic number)
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([atom.GetAtomicNum()])
    x = torch.tensor(atom_features, dtype=torch.float)

    # Edges
    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_index.append([i, j])
        edge_index.append([j, i])
        # Bond type as edge feature (single=1, double=2, triple=3, aromatic=4)
        bond_type = bond.GetBondType()
        if bond_type == Chem.rdchem.BondType.SINGLE:
            bt = 1
        elif bond_type == Chem.rdchem.BondType.DOUBLE:
            bt = 2
        elif bond_type == Chem.rdchem.BondType.TRIPLE:
            bt = 3
        elif bond_type == Chem.rdchem.BondType.AROMATIC:
            bt = 4
        else:
            bt = 0
        edge_attr.append([bt])
        edge_attr.append([bt])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data

def get_rdkit_fingerprint(mol, fp_size=2048):
    """Generate RDKit Daylight-like fingerprint as numpy array"""
    fp = Chem.RDKFingerprint(mol, fpSize=fp_size)
    arr = np.zeros((fp_size,), dtype=int)
    on_bits = list(fp.GetOnBits())
    arr[on_bits] = 1
    return arr

# Load your dataset CSV with 'canonical_smiles' and 'label' columns
df = pd.read_csv('../data/merged_dataset.csv')

# Filter valid molecules
valid_rows = []
graph_data_list = []
fingerprints = []
labels = []

print("Processing molecules and generating features...")
for idx, row in tqdm(df.iterrows(), total=len(df)):
    smi = row['canonical_smiles']
    label = row['activity_label']  # assuming binary label: 0 or 1
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        continue
    # Graph data
    graph_data = mol_to_graph_data_obj(mol)
    graph_data.y = torch.tensor([label], dtype=torch.float)
    graph_data_list.append(graph_data)

    # Fingerprint
    fp = get_rdkit_fingerprint(mol)
    fingerprints.append(fp)

    labels.append(label)
    valid_rows.append(idx)

# Convert fingerprints to tensor
fingerprints = torch.tensor(np.array(fingerprints), dtype=torch.float)

# ---------------------------
# 2. Dataset and DataLoader
# ---------------------------

# Split indices for train/test
train_idx, test_idx = train_test_split(range(len(graph_data_list)), test_size=0.2, stratify=labels, random_state=42)

train_graphs = [graph_data_list[i] for i in train_idx]
test_graphs = [graph_data_list[i] for i in test_idx]

train_fps = fingerprints[train_idx]
test_fps = fingerprints[test_idx]

train_labels = torch.tensor([labels[i] for i in train_idx], dtype=torch.float)
test_labels = torch.tensor([labels[i] for i in test_idx], dtype=torch.float)

train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

# ---------------------------
# 3. FP-GNN Model Definition
# ---------------------------

class FP_GNN(nn.Module):
    def __init__(self, fp_dim=2048, gnn_hidden_dim=128, fp_hidden_dim=128, out_dim=1):
        super(FP_GNN, self).__init__()
        # GNN layers (Graph Attention Network)
        self.conv1 = GATConv(in_channels=1, out_channels=gnn_hidden_dim, heads=4, concat=True)
        self.conv2 = GATConv(in_channels=gnn_hidden_dim*4, out_channels=gnn_hidden_dim, heads=4, concat=False)
        
        # Fingerprint MLP
        self.fp_mlp = nn.Sequential(
            nn.Linear(fp_dim, fp_hidden_dim),
            nn.ReLU(),
            nn.Linear(fp_hidden_dim, fp_hidden_dim),
            nn.ReLU()
        )
        
        # Combined MLP
        self.combined_mlp = nn.Sequential(
            nn.Linear(gnn_hidden_dim + fp_hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim)
        )
        
    def forward(self, data, fp):
        # data: PyG batch object with x, edge_index, edge_attr, batch
        x, edge_index = data.x, data.edge_index
        
        # GNN forward
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        # Global pooling (mean)
        x = global_mean_pool(x, data.batch)
        
        # Fingerprint forward
        fp_emb = self.fp_mlp(fp)
        
        # Concatenate graph and fingerprint embeddings
        combined = torch.cat([x, fp_emb], dim=1)
        
        out = self.combined_mlp(combined)
        return out.squeeze(1)  # output shape: (batch_size,)

# ---------------------------
# 4. Training and Evaluation Functions
# ---------------------------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FP_GNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

def train_epoch(model, loader, fps, optimizer, criterion):
    model.train()
    total_loss = 0
    for i, batch in enumerate(loader):
        batch = batch.to(device)
        batch_fp = fps[i*loader.batch_size : i*loader.batch_size + batch.num_graphs].to(device)
        optimizer.zero_grad()
        out = model(batch, batch_fp)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, fps, labels):
    model.eval()
    preds = []
    trues = []
    for i, batch in enumerate(loader):
        batch = batch.to(device)
        batch_fp = fps[i*loader.batch_size : i*loader.batch_size + batch.num_graphs].to(device)
        out = model(batch, batch_fp)
        preds.append(torch.sigmoid(out).cpu())
        trues.append(batch.y.cpu())
    preds = torch.cat(preds).numpy()
    trues = torch.cat(trues).numpy()
    roc_auc = roc_auc_score(trues, preds)
    pr_auc = average_precision_score(trues, preds)
    return roc_auc, pr_auc

# ---------------------------
# 5. Run Training and Evaluation
# ---------------------------

num_epochs = 30

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, train_fps, optimizer, criterion)
    roc_auc, pr_auc = evaluate(model, test_loader, test_fps, test_labels)
    print(f"Epoch {epoch:02d} - Train Loss: {train_loss:.4f} - Test ROC AUC: {roc_auc:.4f} - Test PR AUC: {pr_auc:.4f}")


Processing molecules and generating features...


100%|██████████| 33208/33208 [05:23<00:00, 102.73it/s]


Epoch 01 - Train Loss: 0.6934 - Test ROC AUC: 0.5619 - Test PR AUC: 0.5840
Epoch 02 - Train Loss: 0.6933 - Test ROC AUC: 0.5355 - Test PR AUC: 0.5317
Epoch 03 - Train Loss: 0.6932 - Test ROC AUC: 0.4980 - Test PR AUC: 0.5176
Epoch 04 - Train Loss: 0.6931 - Test ROC AUC: 0.5120 - Test PR AUC: 0.5213
Epoch 05 - Train Loss: 0.6928 - Test ROC AUC: 0.5274 - Test PR AUC: 0.5308
Epoch 06 - Train Loss: 0.6920 - Test ROC AUC: 0.5184 - Test PR AUC: 0.5216
Epoch 07 - Train Loss: 0.6919 - Test ROC AUC: 0.5191 - Test PR AUC: 0.5207
Epoch 08 - Train Loss: 0.6919 - Test ROC AUC: 0.5232 - Test PR AUC: 0.5219
Epoch 09 - Train Loss: 0.6918 - Test ROC AUC: 0.5201 - Test PR AUC: 0.5213
Epoch 10 - Train Loss: 0.6917 - Test ROC AUC: 0.5214 - Test PR AUC: 0.5216
Epoch 11 - Train Loss: 0.6916 - Test ROC AUC: 0.5211 - Test PR AUC: 0.5217
Epoch 12 - Train Loss: 0.6915 - Test ROC AUC: 0.5212 - Test PR AUC: 0.5208
Epoch 13 - Train Loss: 0.6915 - Test ROC AUC: 0.5200 - Test PR AUC: 0.5210
Epoch 14 - Train Loss: 0.

In [None]:
# Plot of Train loss over different epochs
import matplotlib.pyplot as plt

train_losses = []  # List to store loss values for each epoch

num_epochs = 30
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, train_fps, optimizer, criterion)
    train_losses.append(train_loss)
    print(f"Epoch {epoch} Loss: {train_loss:.4f}")

# After training, plot the train loss curve
plt.figure(figsize=(8,5))
plt.plot(range(1, num_epochs + 1), train_losses, marker='o', color='blue')
plt.title('Training Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
