# Package Installation and Environment Initialization

In [1]:
!pip install torch-geometric --quiet
!pip install torch --quiet
!pip install pandas --quiet
!pip install rdkit --quiet
!pip install numpy --quiet
!pip install matplotlib --quiet
!pip install seaborn --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.2/34.2 MB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
import torch_geometric
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool
import torch_geometric.nn as pyg_nn
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.loader import DenseDataLoader

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
pd.set_option("display.max_colwidth", None)

# Data Loading & Initial Exploration

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Path to the train dataset
file_path = '/content/drive/My Drive/GNN Project Materials/train_dataset.csv'

# Load the CSV file into a DataFrame
train_df = pd.read_csv(file_path)

# Display the first few rows
print(train_df.head())

   id                            buildingblock1_smiles buildingblock2_smiles  \
0   0  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
1   1  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
2   2  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
3   3  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
4   4  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   

     buildingblock3_smiles  \
0  Br.Br.NCC1CCCN1c1cccnn1   
1  Br.Br.NCC1CCCN1c1cccnn1   
2  Br.Br.NCC1CCCN1c1cccnn1   
3        Br.NCc1cccc(Br)n1   
4        Br.NCc1cccc(Br)n1   

                                                          molecule_smiles  \
0  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
2  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
3     C#CCOc1ccc(CNc2nc(NCc3cc

In [6]:
print(train_df.count())

id                       5246830
buildingblock1_smiles    5246830
buildingblock2_smiles    5246830
buildingblock3_smiles    5246830
molecule_smiles          5246830
protein_name             5246830
binds                    5246830
mol_wt                   5246830
logP                     5246830
rotamers                 5246830
dtype: int64


In [7]:
# Drop building block SMILES columns
train_df.drop(columns=['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'], inplace=True)

# Data Preprocessing

### Addressing Target Imbalance

In [8]:
# Check unique values in the 'binds' column
print(train_df['binds'].value_counts())

binds
0    5236321
1      10509
Name: count, dtype: int64


In [9]:
# Step 1: Separate the data into two parts: one for binds = 0 and one for binds = 1
binds_0_df = train_df[train_df['binds'] == 0]
binds_1_df = train_df[train_df['binds'] == 1]

# Step 2: Downsample binds = 0 to match the number of binds = 1
binds_0_downsampled_df = binds_0_df.sample(n=len(binds_1_df), random_state=42)

# Step 3: Combine the downsampled binds = 0 data with binds = 1 data
balanced_df = pd.concat([binds_0_downsampled_df, binds_1_df])

# Step 4: Shuffle the final dataset to mix the rows
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Step 5: Verify the shape and class distribution
print(f"Balanced dataset shape: {balanced_df.shape}")
print(f"Class distribution in balanced dataset: {balanced_df['binds'].value_counts()}")

# Step 6: Show the first few rows of the balanced dataset
balanced_df.head()

Balanced dataset shape: (21018, 7)
Class distribution in balanced dataset: binds
0    10509
1    10509
Name: count, dtype: int64


Unnamed: 0,id,molecule_smiles,protein_name,binds,mol_wt,logP,rotamers
0,3566327,C#CC[C@H](CC(=O)N[Dy])Nc1nc(NCC2Cc3ccccc3NC2=O)nc(Nc2ccc(Br)nc2OC)n1,sEH,0,742.055548,2.7786,2048
1,4617174,C#CC[C@H](Nc1nc(NCc2ccc(OC)c(OC)c2C)nc(Nc2ccc(O)cc2Cl)n1)C(=O)N[Dy],BRD4,1,674.09483,3.29622,2048
2,1425402,C#CC[C@@H](Nc1nc(NCc2ccccc2-c2cnn(C)c2)nc(Nc2cccc(Br)c2C)n1)C(=O)N[Dy],BRD4,0,722.065718,4.08302,1024
3,959535,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2ccc(C)cc2OC2CCOC2)nc(Nc2ccc3c(c2)CNC3=O)n1,BRD4,1,719.176001,2.72122,4096
4,617223,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCNC(=O)c2ccno2)nc(Nc2nc3c(C)cccc3s2)n1,BRD4,0,683.096705,2.13762,4096


In [10]:
# Save the sampled dataset to Drive
sampled_file_path = '/content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv'
balanced_df.to_csv(sampled_file_path, index=False)

# Confirm file has been saved
print(f"Sampled dataset saved to: {sampled_file_path}")

Sampled dataset saved to: /content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv


In [11]:
print(balanced_df.count())

id                 21018
molecule_smiles    21018
protein_name       21018
binds              21018
mol_wt             21018
logP               21018
rotamers           21018
dtype: int64


In [12]:
print(balanced_df['binds'].value_counts())

binds
0    10509
1    10509
Name: count, dtype: int64


In [13]:
# Ensure all SMILES are strings and not NaN
balanced_df['molecule_smiles'] = balanced_df['molecule_smiles'].fillna("").astype(str)

# Check that every item is indeed a string
print(balanced_df['molecule_smiles'].apply(type).unique())  # Should print <class 'str'>

[<class 'str'>]


## Converting SMILES to Molecular Graphs

In [14]:
from sklearn.preprocessing import LabelEncoder
# Encode proteins once for the entire dataset
proteins = balanced_df["protein_name"].unique()  # Extract unique protein names from your dataset
label_encoder = LabelEncoder()
label_encoder.fit(proteins)

In [15]:
def smiles_to_graph_with_protein(smiles, protein_name, binds, label_encoder):
    # Parse the SMILES string to create the molecule object
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Invalid SMILES string: {smiles}")
        return None

    # Initialize lists for atom and bond features
    atom_features = []
    bond_index = []
    bond_features = []

    # Extract atom features
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            int(atom.GetIsAromatic()),
            int(atom.GetHybridization()),
            atom.GetFormalCharge(),
        ]
        atom_features.append(features)

    # Extract bond features
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_index.append((i, j))
        bond_index.append((j, i))
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic())])
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic())])

    # Handle empty graphs
    if len(atom_features) == 0 or len(bond_index) == 0:
        print(f"Skipping invalid SMILES string or empty graph: {smiles}")
        return None

    # Convert features to tensors
    atom_features = torch.tensor(atom_features, dtype=torch.float)
    bond_index = torch.tensor(bond_index, dtype=torch.long).t().contiguous()
    bond_features = torch.tensor(bond_features, dtype=torch.float)

    # Scalar molecular features
    mol_wt = Descriptors.MolWt(mol)
    logP = Descriptors.MolLogP(mol)
    rotamers = Descriptors.NumRotatableBonds(mol)
    scalar_features = torch.tensor([mol_wt, logP, rotamers], dtype=torch.float)

    # Encode protein name
    protein_encoded = label_encoder.transform([protein_name])[0]
    protein_encoded_tensor = torch.tensor([protein_encoded], dtype=torch.long)

    # Create label tensor
    if binds is None:
        print(f"No label for molecule: {smiles}")
        return None
    label_tensor = torch.tensor([binds], dtype=torch.long)

    # Return a Data object
    return Data(
        x=atom_features,
        edge_index=bond_index,
        edge_attr=bond_features,
        scalar_features=scalar_features.unsqueeze(0),  # Ensure consistent shape
        protein_feature=protein_encoded_tensor,  # Ensure consistent shape
        y=label_tensor,  # Label tensor
    )

In [16]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader

# Split the dataset into train and test sets (80% train, 20% test)
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42)

# Function to process rows into graph data
def create_graph_data(df):
    graph_data_list = []
    for i, row in df.iterrows():
        graph_data = smiles_to_graph_with_protein(
            row["molecule_smiles"], row["protein_name"], row["binds"], label_encoder
        )
        if graph_data is not None:
            graph_data_list.append(graph_data)
    return graph_data_list

# Create graph data for train and test sets
train_graph_data_list = create_graph_data(train_df)
test_graph_data_list = create_graph_data(test_df)

# Create DataLoaders for batching
train_loader = DataLoader(train_graph_data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_graph_data_list, batch_size=32, shuffle=False)



# Building a GCN (Graph Convolutional Network)

In [21]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool

class GNNModel(nn.Module):
    def __init__(self, num_node_features, num_classes, dropout_rate):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)  # First GCN layer
        self.conv2 = GCNConv(64, 32)  # Second GCN layer
        self.dropout = nn.Dropout(p=dropout_rate)  # Dropout layer
        self.fc = nn.Linear(32, num_classes)  # Fully connected layer

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = torch.relu(self.conv1(x, edge_index))  # First GCN layer
        x = self.dropout(x)  # Apply dropout
        x = torch.relu(self.conv2(x, edge_index))  # Second GCN layer
        x = self.dropout(x)  # Apply dropout
        x = global_mean_pool(x, batch)  # Global pooling
        x = self.fc(x)  # Final fully connected layer
        return x

In [30]:
from sklearn.model_selection import KFold
import torch
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import accuracy_score

# 1. Define a function for training the model with early stopping
def train_with_early_stopping(model, data_loader, criterion, optimizer, num_epochs=50, patience=10):
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    total_loss = 0.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for data in data_loader:
            optimizer.zero_grad()

            if data.y is None:
                print("Skipping graph with no label.")
                continue

            out = model(data)  # Forward pass
            loss = criterion(out, data.y)  # Calculate loss

            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights

            running_loss += loss.item()

        avg_loss = running_loss / len(data_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")

        # Early stopping check: monitor validation loss
        if avg_loss < best_val_loss:
            best_val_loss = avg_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Training stopped.")
            break

    return model

# 2. K-Fold Cross Validation
def k_fold_cross_validation(k=5, num_epochs=50, patience=10):
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_graph_data_list)):
        print(f"\nFold {fold + 1}/{k}")

        # Create train and validation sets
        train_data = [train_graph_data_list[i] for i in train_idx]
        val_data = [train_graph_data_list[i] for i in val_idx]

        # Create DataLoaders
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=16, shuffle=False)

        # Initialize the model, criterion, and optimizer
        model = GNNModel(num_node_features=4, num_classes=2, dropout_rate=0.2).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()

        # Train the model
        trained_model = train_with_early_stopping(model, train_loader, criterion, optimizer, num_epochs, patience)

        # Evaluate the model on the validation set
        trained_model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for data in val_loader:
                if data.y is None:
                    continue
                out = trained_model(data)
                _, predicted = torch.max(out, dim=1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(data.y.cpu().numpy())

        val_accuracy = accuracy_score(val_labels, val_preds)
        print(f"Fold {fold + 1} - Validation Accuracy: {val_accuracy:.4f}")
        fold_accuracies.append(val_accuracy)

    avg_accuracy = sum(fold_accuracies) / k
    print(f"\nAverage Validation Accuracy over {k} folds: {avg_accuracy:.4f}")

# Run K-Fold Cross-Validation
k_fold_cross_validation(k=5, num_epochs=50, patience=5)


Fold 1/5




Epoch 1/50, Loss: 0.6951395251237822
Epoch 2/50, Loss: 0.6932661018105097
Epoch 3/50, Loss: 0.6928529703943114
Epoch 4/50, Loss: 0.6927223550011797
Epoch 5/50, Loss: 0.691462608353278
Epoch 6/50, Loss: 0.6905181910421846
Epoch 7/50, Loss: 0.689309228543861
Epoch 8/50, Loss: 0.6890498926324878
Epoch 9/50, Loss: 0.6882911886528188
Epoch 10/50, Loss: 0.688129097825707
Epoch 11/50, Loss: 0.6867939002023441
Epoch 12/50, Loss: 0.6865284531493533
Epoch 13/50, Loss: 0.6856726198927826
Epoch 14/50, Loss: 0.6847849836956165
Epoch 15/50, Loss: 0.6841353175898086
Epoch 16/50, Loss: 0.6824867894760068
Epoch 17/50, Loss: 0.6835279309508633
Epoch 18/50, Loss: 0.6809354548505313
Epoch 19/50, Loss: 0.6800777805829586
Epoch 20/50, Loss: 0.6793803239549099
Epoch 21/50, Loss: 0.6774394909068322
Epoch 22/50, Loss: 0.6756060757053026
Epoch 23/50, Loss: 0.6746968872907187
Epoch 24/50, Loss: 0.6731931170300269
Epoch 25/50, Loss: 0.67107765605135
Epoch 26/50, Loss: 0.6708632380160083
Epoch 27/50, Loss: 0.66857



Epoch 1/50, Loss: 0.6949778187856096
Epoch 2/50, Loss: 0.6927801934065348
Epoch 3/50, Loss: 0.6925953968243139
Epoch 4/50, Loss: 0.6914431819167347
Epoch 5/50, Loss: 0.6897122286587918
Epoch 6/50, Loss: 0.6891332582684675
Epoch 7/50, Loss: 0.6876689433059284
Epoch 8/50, Loss: 0.6873630611281333
Epoch 9/50, Loss: 0.6864774004599428
Epoch 10/50, Loss: 0.6863345711466532
Epoch 11/50, Loss: 0.6861300953928554
Epoch 12/50, Loss: 0.6848733838333103
Epoch 13/50, Loss: 0.6845771670341492
Epoch 14/50, Loss: 0.6845848256285777
Epoch 15/50, Loss: 0.6846713016914839
Epoch 16/50, Loss: 0.6831418540900159
Epoch 17/50, Loss: 0.6826941614769019
Epoch 18/50, Loss: 0.6819630922942326
Epoch 19/50, Loss: 0.6814013667821034
Epoch 20/50, Loss: 0.6820321854172933
Epoch 21/50, Loss: 0.6818176572444747
Epoch 22/50, Loss: 0.6807591524617426
Epoch 23/50, Loss: 0.6792051840338781
Epoch 24/50, Loss: 0.6795632042584323
Epoch 25/50, Loss: 0.6787017358469197
Epoch 26/50, Loss: 0.6779814317166876
Epoch 27/50, Loss: 0.



Epoch 1/50, Loss: 0.6957519568245987
Epoch 2/50, Loss: 0.6932468798441213
Epoch 3/50, Loss: 0.6912601252798518
Epoch 4/50, Loss: 0.6893003846603966
Epoch 5/50, Loss: 0.6880313138190688
Epoch 6/50, Loss: 0.6869210040243288
Epoch 7/50, Loss: 0.6875260923053365
Epoch 8/50, Loss: 0.6847950821115627
Epoch 9/50, Loss: 0.6849186750711357
Epoch 10/50, Loss: 0.6850920550604921
Epoch 11/50, Loss: 0.6844723854705638
Epoch 12/50, Loss: 0.6833872195514856
Epoch 13/50, Loss: 0.6834289596707303
Epoch 14/50, Loss: 0.6817775818737452
Epoch 15/50, Loss: 0.6826974274712425
Epoch 16/50, Loss: 0.6823307183636497
Epoch 17/50, Loss: 0.6810932442917977
Epoch 18/50, Loss: 0.679578961926323
Epoch 19/50, Loss: 0.6805165740312492
Epoch 20/50, Loss: 0.6789992438910549
Epoch 21/50, Loss: 0.6757344278937713
Epoch 22/50, Loss: 0.6743809593560154
Epoch 23/50, Loss: 0.6734401449011849
Epoch 24/50, Loss: 0.6697239363377783
Epoch 25/50, Loss: 0.6673096731641772
Epoch 26/50, Loss: 0.6663523026510595
Epoch 27/50, Loss: 0.6



Epoch 1/50, Loss: 0.6946546713889141
Epoch 2/50, Loss: 0.693103666112766
Epoch 3/50, Loss: 0.6916803864321442
Epoch 4/50, Loss: 0.6905276722204955
Epoch 5/50, Loss: 0.6886833440534567
Epoch 6/50, Loss: 0.6887644635652957
Epoch 7/50, Loss: 0.6868924676450623
Epoch 8/50, Loss: 0.6871660867712586
Epoch 9/50, Loss: 0.6860826840440385
Epoch 10/50, Loss: 0.6849411796871462
Epoch 11/50, Loss: 0.6843578004666939
Epoch 12/50, Loss: 0.6856659436055794
Epoch 13/50, Loss: 0.6842411883379135
Epoch 14/50, Loss: 0.6832219231142868
Epoch 15/50, Loss: 0.6834901236750708
Epoch 16/50, Loss: 0.6818408981656632
Epoch 17/50, Loss: 0.6812467588539214
Epoch 18/50, Loss: 0.6804885635336287
Epoch 19/50, Loss: 0.6803451747446366
Epoch 20/50, Loss: 0.6776003099791927
Epoch 21/50, Loss: 0.6768274766511497
Epoch 22/50, Loss: 0.6758358394206634
Epoch 23/50, Loss: 0.674533414075264
Epoch 24/50, Loss: 0.6729383881111917
Epoch 25/50, Loss: 0.6720567935287031
Epoch 26/50, Loss: 0.6713028784030683
Epoch 27/50, Loss: 0.67



Epoch 1/50, Loss: 0.6950795885073586
Epoch 2/50, Loss: 0.6932415831103195
Epoch 3/50, Loss: 0.6931610881360901
Epoch 4/50, Loss: 0.6933118490486735
Epoch 5/50, Loss: 0.6929814286804653
Epoch 6/50, Loss: 0.6926490571637783
Epoch 7/50, Loss: 0.6916197379619131
Epoch 8/50, Loss: 0.6906550924928236
Epoch 9/50, Loss: 0.6903922787461071
Epoch 10/50, Loss: 0.6896047374581327
Epoch 11/50, Loss: 0.688390366812807
Epoch 12/50, Loss: 0.6885877122635223
Epoch 13/50, Loss: 0.6880927577738813
Epoch 14/50, Loss: 0.6868494405332558
Epoch 15/50, Loss: 0.6872211133398992
Epoch 16/50, Loss: 0.6862302246048391
Epoch 17/50, Loss: 0.6861617093137271
Epoch 18/50, Loss: 0.6850233119109581
Epoch 19/50, Loss: 0.6847487304638739
Epoch 20/50, Loss: 0.6842556017474244
Epoch 21/50, Loss: 0.6846363482804134
Epoch 22/50, Loss: 0.6829165643375638
Epoch 23/50, Loss: 0.6824395648080869
Epoch 24/50, Loss: 0.6829839354462913
Epoch 25/50, Loss: 0.6815708349065747
Epoch 26/50, Loss: 0.6808193454277501
Epoch 27/50, Loss: 0.6