# Package Installation and Environment Initialization

In [None]:
!pip install torch-geometric --quiet
!pip install torch --quiet
!pip install pandas --quiet
!pip install rdkit --quiet
!pip install numpy --quiet
!pip install matplotlib --quiet
!pip install seaborn --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.2/34.2 MB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [21]:
import torch
import torch_geometric
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool
import torch_geometric.nn as pyg_nn
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.loader import DenseDataLoader
from sklearn.preprocessing import LabelEncoder

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [None]:
pd.set_option("display.max_colwidth", None)

# Data Loading & Initial Exploration

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
# Path to the train dataset
file_path = '/content/drive/My Drive/GNN Project Materials/train_dataset.csv'

# Load the CSV file into a DataFrame
train_df = pd.read_csv(file_path)

# Display the first few rows
print(train_df.head())

   id                            buildingblock1_smiles buildingblock2_smiles  \
0   0  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
1   1  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
2   2  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
3   3  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
4   4  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   

     buildingblock3_smiles  \
0  Br.Br.NCC1CCCN1c1cccnn1   
1  Br.Br.NCC1CCCN1c1cccnn1   
2  Br.Br.NCC1CCCN1c1cccnn1   
3        Br.NCc1cccc(Br)n1   
4        Br.NCc1cccc(Br)n1   

                                                          molecule_smiles  \
0  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
2  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
3     C#CCOc1ccc(CNc2nc(NCc3cc

In [11]:
print(train_df.count())

id                       5246830
buildingblock1_smiles    5246830
buildingblock2_smiles    5246830
buildingblock3_smiles    5246830
molecule_smiles          5246830
protein_name             5246830
binds                    5246830
mol_wt                   5246830
logP                     5246830
rotamers                 5246830
dtype: int64


In [12]:
# Drop building block SMILES columns
train_df.drop(columns=['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'], inplace=True)

# Data Preprocessing

### Addressing Target Imbalance

In [13]:
# Check unique values in the 'binds' column
print(train_df['binds'].value_counts())

binds
0    5236321
1      10509
Name: count, dtype: int64


In [14]:
# Step 1: Separate the data into two parts: one for binds = 0 and one for binds = 1
binds_0_df = train_df[train_df['binds'] == 0]
binds_1_df = train_df[train_df['binds'] == 1]

# Step 2: Downsample binds = 0 to match the number of binds = 1
binds_0_downsampled_df = binds_0_df.sample(n=len(binds_1_df), random_state=42)

# Step 3: Combine the downsampled binds = 0 data with binds = 1 data
balanced_df = pd.concat([binds_0_downsampled_df, binds_1_df])

# Step 4: Shuffle the final dataset to mix the rows
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Step 5: Verify the shape and class distribution
print(f"Balanced dataset shape: {balanced_df.shape}")
print(f"Class distribution in balanced dataset: {balanced_df['binds'].value_counts()}")

# Step 6: Show the first few rows of the balanced dataset
balanced_df.head()

Balanced dataset shape: (21018, 7)
Class distribution in balanced dataset: binds
0    10509
1    10509
Name: count, dtype: int64


Unnamed: 0,id,molecule_smiles,protein_name,binds,mol_wt,logP,rotamers
0,3566327,C#CC[C@H](CC(=O)N[Dy])Nc1nc(NCC2Cc3ccccc3NC2=O)nc(Nc2ccc(Br)nc2OC)n1,sEH,0,742.055548,2.7786,2048
1,4617174,C#CC[C@H](Nc1nc(NCc2ccc(OC)c(OC)c2C)nc(Nc2ccc(O)cc2Cl)n1)C(=O)N[Dy],BRD4,1,674.09483,3.29622,2048
2,1425402,C#CC[C@@H](Nc1nc(NCc2ccccc2-c2cnn(C)c2)nc(Nc2cccc(Br)c2C)n1)C(=O)N[Dy],BRD4,0,722.065718,4.08302,1024
3,959535,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2ccc(C)cc2OC2CCOC2)nc(Nc2ccc3c(c2)CNC3=O)n1,BRD4,1,719.176001,2.72122,4096
4,617223,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCNC(=O)c2ccno2)nc(Nc2nc3c(C)cccc3s2)n1,BRD4,0,683.096705,2.13762,4096


In [15]:
# Save the sampled dataset to Drive
sampled_file_path = '/content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv'
balanced_df.to_csv(sampled_file_path, index=False)

# Confirm file has been saved
print(f"Sampled dataset saved to: {sampled_file_path}")

Sampled dataset saved to: /content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv


In [16]:
print(balanced_df.count())

id                 21018
molecule_smiles    21018
protein_name       21018
binds              21018
mol_wt             21018
logP               21018
rotamers           21018
dtype: int64


In [17]:
print(balanced_df['binds'].value_counts())

binds
0    10509
1    10509
Name: count, dtype: int64


In [18]:
# Ensure all SMILES are strings and not NaN
balanced_df['molecule_smiles'] = balanced_df['molecule_smiles'].fillna("").astype(str)

# Check that every item is indeed a string
print(balanced_df['molecule_smiles'].apply(type).unique())  # Should print <class 'str'>

[<class 'str'>]


## Converting SMILES to Molecular Graphs

In [22]:
# Encode proteins once for the entire dataset
proteins = balanced_df["protein_name"].unique()  # Extract unique protein names
label_encoder = LabelEncoder()
label_encoder.fit(proteins)

# Store the encoded proteins for easy access
protein_mapping = {protein: label_encoder.transform([protein])[0] for protein in proteins}

In [23]:
def smiles_to_graph_with_protein(smiles, protein_name, binds, protein_mapping):
    # Parse the SMILES string to create the molecule object
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Invalid SMILES string: {smiles}")
        return None

    # Initialize lists for atom and bond features
    atom_features = []
    bond_index = []
    bond_features = []

    # Extract atom features
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetFormalCharge(),
            int(atom.GetIsAromatic()),
            int(atom.GetHybridization()),
            atom.GetMass(),
        ]
        atom_features.append(features)

    # Extract bond features
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_index.append((i, j))
        bond_index.append((j, i))
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic()), bond.IsInRing()])
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic()), bond.IsInRing()])

    # Handle empty graphs
    if len(atom_features) == 0 or len(bond_index) == 0:
        print(f"Skipping invalid SMILES string or empty graph: {smiles}")
        return None

    # Convert features to tensors
    atom_features = torch.tensor(atom_features, dtype=torch.float)
    bond_index = torch.tensor(bond_index, dtype=torch.long).t().contiguous()
    bond_features = torch.tensor(bond_features, dtype=torch.float)

    # Scalar molecular features
    mol_wt = Descriptors.MolWt(mol)
    logP = Descriptors.MolLogP(mol)
    rotamers = Descriptors.NumRotatableBonds(mol)
    tpsa = Descriptors.TPSA(mol)
    qed = Descriptors.qed(mol)
    scalar_features = torch.tensor([mol_wt, logP, rotamers, tpsa, qed], dtype=torch.float)

    # Use pre-encoded protein name (no need for LabelEncoder here)
    protein_encoded_tensor = torch.tensor([protein_mapping[protein_name]], dtype=torch.long)

    # Create label tensor
    if binds is None:
        print(f"No label for molecule: {smiles}")
        return None
    label_tensor = torch.tensor([binds], dtype=torch.long)

    # Return a Data object
    return Data(
        x=atom_features,
        edge_index=bond_index,
        edge_attr=bond_features,
        scalar_features=scalar_features.unsqueeze(0),  # Ensure consistent shape
        protein_feature=protein_encoded_tensor,        # Ensure consistent shape
        y=label_tensor,  # Label tensor
    )

In [24]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader

# Split the dataset into train and test sets (80% train, 20% test)
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42)

# Function to process rows into graph data
def create_graph_data(df, protein_mapping):
    graph_data_list = []
    for i, row in df.iterrows():
        graph_data = smiles_to_graph_with_protein(
            row["molecule_smiles"], row["protein_name"], row["binds"], protein_mapping
        )
        if graph_data is not None:
            graph_data_list.append(graph_data)
    return graph_data_list

# Create graph data for train and test sets
train_graph_data_list = create_graph_data(train_df, protein_mapping)
test_graph_data_list = create_graph_data(test_df, protein_mapping)

# Create DataLoaders for batching
train_loader = DataLoader(train_graph_data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_graph_data_list, batch_size=32, shuffle=False)




# Building a GCN (Graph Convolutional Network)

In [25]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F

class GNNModel(nn.Module):
    def __init__(self, num_node_features, num_classes, dropout_rate):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)  # First GCN layer
        self.conv2 = GCNConv(64, 32)  # Second GCN layer
        self.batch_norm1 = nn.BatchNorm1d(64)  # Batch normalization after conv1
        self.batch_norm2 = nn.BatchNorm1d(32)  # Batch normalization after conv2
        self.dropout = nn.Dropout(p=dropout_rate)  # Dropout layer
        self.fc = nn.Linear(32, num_classes)  # Fully connected layer
        self.fc_dropout = nn.Dropout(p=dropout_rate)  # Dropout after FC

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = torch.relu(self.conv1(x, edge_index))  # First GCN layer
        x = self.batch_norm1(x)  # Apply batch normalization
        x = self.dropout(x)  # Apply dropout
        x = torch.relu(self.conv2(x, edge_index))  # Second GCN layer
        x = self.batch_norm2(x)  # Apply batch normalization
        x = self.dropout(x)  # Apply dropout
        x = global_mean_pool(x, batch)  # Global pooling
        x = self.fc(x)  # Final fully connected layer
        x = self.fc_dropout(x)  # Dropout after FC
        return x


In [27]:
import torch
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from torch_geometric.data import DataLoader

# 1. Define a function for training the model with early stopping
def train_with_early_stopping(model, data_loader, criterion, optimizer, num_epochs=50, patience=10):
    best_val_loss = float('inf')
    best_val_accuracy = 0.0
    epochs_without_improvement = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for data in data_loader:
            data = data.to(device)  # Ensure data is moved to the correct device
            optimizer.zero_grad()

            if data.y is None:
                print("Skipping graph with no label.")
                continue

            out = model(data)  # Forward pass
            loss = criterion(out, data.y)  # Calculate loss

            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights

            running_loss += loss.item()

        avg_loss = running_loss / len(data_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")

        # Early stopping check: monitor validation accuracy
        if avg_loss < best_val_loss:
            best_val_loss = avg_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Training stopped.")
            break

    return model

# 2. K-Fold Cross Validation
def k_fold_cross_validation(k=5, num_epochs=50, patience=10):
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_graph_data_list)):
        print(f"\nFold {fold + 1}/{k}")

        # Create train and validation sets
        train_data = [train_graph_data_list[i] for i in train_idx]
        val_data = [train_graph_data_list[i] for i in val_idx]

        # Create DataLoaders
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=16, shuffle=False)

        # Initialize the model, criterion, and optimizer
        model = GNNModel(num_node_features=7, num_classes=2, dropout_rate=0.2).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()

        # Train the model
        trained_model = train_with_early_stopping(model, train_loader, criterion, optimizer, num_epochs, patience)

        # Evaluate the model on the validation set
        trained_model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                if data.y is None:
                    continue
                out = trained_model(data)
                _, predicted = torch.max(out, dim=1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(data.y.cpu().numpy())

        val_accuracy = accuracy_score(val_labels, val_preds)
        print(f"Fold {fold + 1} - Validation Accuracy: {val_accuracy:.4f}")
        fold_accuracies.append(val_accuracy)

    avg_accuracy = sum(fold_accuracies) / k
    print(f"\nAverage Validation Accuracy over {k} folds: {avg_accuracy:.4f}")

# Run K-Fold Cross-Validation
k_fold_cross_validation(k=5, num_epochs=50, patience=5)


Fold 1/5




Epoch 1/50, Loss: 0.6761662165560138
Epoch 2/50, Loss: 0.6697875984216842
Epoch 3/50, Loss: 0.668996815661613
Epoch 4/50, Loss: 0.6615954270402543
Epoch 5/50, Loss: 0.663035144916471
Epoch 6/50, Loss: 0.6637565908193872
Epoch 7/50, Loss: 0.6589506368078601
Epoch 8/50, Loss: 0.6592473716217045
Epoch 9/50, Loss: 0.6552583933868816
Epoch 10/50, Loss: 0.6532640633982796
Epoch 11/50, Loss: 0.6495934433517502
Epoch 12/50, Loss: 0.6506268258043759
Epoch 13/50, Loss: 0.6464145071937412
Epoch 14/50, Loss: 0.6473129392024595
Epoch 15/50, Loss: 0.6426361125587141
Epoch 16/50, Loss: 0.6420263056947841
Epoch 17/50, Loss: 0.6396305101805153
Epoch 18/50, Loss: 0.644463865177243
Epoch 19/50, Loss: 0.6399361574834082
Epoch 20/50, Loss: 0.6375584654093639
Epoch 21/50, Loss: 0.6358930351263561
Epoch 22/50, Loss: 0.6336999157804655
Epoch 23/50, Loss: 0.6320805711567614
Epoch 24/50, Loss: 0.6239674833030678
Epoch 25/50, Loss: 0.6264175964734217
Epoch 26/50, Loss: 0.6261735613436819
Epoch 27/50, Loss: 0.625



Epoch 1/50, Loss: 0.6788958615837143
Epoch 2/50, Loss: 0.6662451661015805
Epoch 3/50, Loss: 0.6607308638648669
Epoch 4/50, Loss: 0.6565869634131614
Epoch 5/50, Loss: 0.6557309153388995
Epoch 6/50, Loss: 0.6525256295691773
Epoch 7/50, Loss: 0.6564639087735402
Epoch 8/50, Loss: 0.6503576457358711
Epoch 9/50, Loss: 0.6455773200560694
Epoch 10/50, Loss: 0.6437951955727249
Epoch 11/50, Loss: 0.6435331087758793
Epoch 12/50, Loss: 0.6421193688859838
Epoch 13/50, Loss: 0.6393642762540779
Epoch 14/50, Loss: 0.6336728071131689
Epoch 15/50, Loss: 0.6434752948218946
Epoch 16/50, Loss: 0.6408239341873618
Epoch 17/50, Loss: 0.6372900078256972
Epoch 18/50, Loss: 0.6359958152212513
Epoch 19/50, Loss: 0.6309537585651974
Epoch 20/50, Loss: 0.6343448486679658
Epoch 21/50, Loss: 0.6335535803178545
Epoch 22/50, Loss: 0.6284372358982685
Epoch 23/50, Loss: 0.6336658606206053
Epoch 24/50, Loss: 0.6306728111152559
Epoch 25/50, Loss: 0.6316482415522463
Epoch 26/50, Loss: 0.6298326861277205
Epoch 27/50, Loss: 0.



Epoch 1/50, Loss: 0.6788554632422756


KeyboardInterrupt: 