<div style="display: flex; align-items: center;">
    <span style="font-size: 24px; color: #003366; font-weight: 500;">Predicting Molecule Binding Using a Graph Neural Network</span>
    <img src="../logo.svg" style="height: 50px; width: auto; margin-left: auto;"/>
</div>

In [None]:
import os
import sys
import time
import rdkit
import torch
import psutil
import pickle
import warnings
import numpy as np
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F 

from rdkit import Chem
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem import DataStructs, AllChem, Descriptors
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

from torch.optim import SGD
from torch.nn import Linear, Dropout
from torch.nn.functional import relu
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Subset, DataLoader
from torch_geometric.datasets import MoleculeNet
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool, global_max_pool

from sklearn.manifold import TSNE
from sklearn.utils import resample
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.model_selection import KFold, train_test_split

from tqdm import tnrange
from collections import Counter
from ogb.utils import smiles2graph
from IPython.display import display, HTML
from standardiser import break_bonds, neutralise, unsalt, standardise

warnings.filterwarnings("ignore")
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 1: Check system availability </h2>
</div>

In [None]:
def check_availability():
    """Pick the torch device and print how much GPU/CPU capacity is currently free.

    If ``CUDA_VISIBLE_DEVICES`` is unset it is defaulted to ``"0"`` (only the first GPU is
    exposed); this only takes effect if CUDA has not been initialised yet in this process.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()``, otherwise ``cpu``.
    """
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Close the pipe deterministically instead of leaving it to garbage collection.
        with os.popen('nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits') as pipe:
            gpu_info = pipe.readlines()
        try:
            gpu_available = 100 - int(gpu_info[0].strip())
            gpu_result = f"\033[1m\033[34mGPU availability: \033[91m{gpu_available:.2f}%\033[0m"
        except (IndexError, ValueError):
            # nvidia-smi missing from PATH or produced no/unparseable output; still usable.
            gpu_result = "GPU is available, but its utilisation could not be read from nvidia-smi"
    else:
        device = torch.device("cpu")
        gpu_result = 'GPU is not available, using CPU instead'

    # Without an interval, psutil.cpu_percent() compares against the previous call and the
    # very first call in a process returns a meaningless 0.0; sample a short window instead.
    cpu_percentage = psutil.cpu_percent(interval=0.5)
    cpu_available = 100 - cpu_percentage
    cpu_result = f"\033[1m\033[34mCPU availability: \033[91m{cpu_available:.2f}%\033[0m"

    print(gpu_result)
    print(cpu_result)
    return device

device = check_availability()

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 2: Load data </h2>
</div>

In [None]:
# Raw input table; later cells rely on its 'id', 'SMILES' and 'Target' columns.
df = pd.read_csv('../data/leash_bio_brd4.csv')
display(df.iloc[:5])
print(df.shape)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 3: Remove salts and standardise SMILES </h2>
</div>

In [None]:
def remove_salts(df):
    """Strip salt/non-organic fragments from each SMILES and standardise the remaining parent.

    Rows whose SMILES cannot be parsed, whose parent cannot be re-parsed, or which the
    standardiser rejects are dropped. The input frame is not modified.

    Args:
        df: DataFrame with a ``SMILES`` column.

    Returns:
        tuple: (cleaned DataFrame whose ``SMILES`` column holds the standardised SMILES,
        number of input rows, number of rows kept).
    """
    def remove_salt(smiles):
        """Return the standardised parent SMILES, or None if the molecule must be dropped."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # None (not '') so that dropna() below actually removes unparseable rows.
            return None

        mol = break_bonds.run(mol)
        mol = neutralise.run(mol)
        non_salt_frags = [
            frag for frag in Chem.GetMolFrags(mol, asMols=True)
            if not unsalt.is_nonorganic(frag) and not unsalt.is_salt(frag)
        ]
        non_salt_smiles = '.'.join(Chem.MolToSmiles(frag) for frag in non_salt_frags)

        parent = Chem.MolFromSmiles(non_salt_smiles)
        if parent is None:
            # Never hand None to the standardiser; treat a failed re-parse as a rejection.
            return None
        try:
            standard_mol = standardise.run(parent)
        except standardise.StandardiseException:
            return None
        return Chem.MolToSmiles(standard_mol)

    initial_count = len(df)
    # assign() works on a copy, so the caller's frame does not gain a SMILES_unsalt column.
    df_unsalt = (
        df.assign(SMILES_unsalt=df['SMILES'].apply(remove_salt))
        .dropna(subset=['SMILES_unsalt'])
        .drop(columns=['SMILES'])
        .rename(columns={'SMILES_unsalt': 'SMILES'})
    )
    final_count = len(df_unsalt)
    print(f"\033[1m\033[34mNumber of datapoints removed: \033[91m{initial_count - final_count}\033[0m")
    print(f"\033[1m\033[34mNumber of datapoints remaining: \033[91m{final_count}\033[0m")
    return df_unsalt, initial_count, final_count

df_remove_salts, initial_count, after_salts_count = remove_salts(df)

In [None]:
# Keep only the identifier, cleaned SMILES and label; work on an independent copy.
df = df_remove_salts.loc[:, ['id', 'SMILES', 'Target']].copy()

display(df.head())
print(df.shape)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 4: Balance dataset </h2>
</div>

In [None]:
# Class balance before downsampling.
df.Target.value_counts()

In [None]:
# Downsample whichever label is more frequent to the size of the other, so the balancing
# also works when label 1 outnumbers label 0 (resample without replacement would
# otherwise raise). On a tie, label 0 is treated as the majority.
class_counts = df['Target'].value_counts()
if class_counts.get(0, 0) >= class_counts.get(1, 0):
    majority_label, minority_label = 0, 1
else:
    majority_label, minority_label = 1, 0

df_majority = df[df['Target'] == majority_label]
df_minority = df[df['Target'] == minority_label]

df_majority_downsampled = resample(df_majority, replace=False, n_samples=df_minority.shape[0], random_state=42)
df = pd.concat([df_majority_downsampled, df_minority])
# Shuffle so the two classes are interleaved, then renumber rows 0..n-1.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df['Target'].value_counts()

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 5: Train-Test split </h2>
</div>

In [None]:
# Hold out 10% for testing, preserving the label ratio in both splits.
train_df, test_df = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    stratify=df['Target'],
)

for _position, (_title, _split_df) in enumerate((("Train Data", train_df), ("Test Data", test_df))):
    if _position > 0:
        print("-" * 80)
    print(_title)
    display(_split_df.head())
    print(_split_df.shape)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 6: Visualise train-test data </h2>
</div>

In [None]:
def generate_ecfp(smiles_list, radius=2, n_bits=2048):
    """Compute Morgan (ECFP-style) bit fingerprints for a sequence of SMILES strings.

    Args:
        smiles_list: iterable of SMILES strings (e.g. a pandas Series).
        radius: Morgan radius passed to ``GetMorganGenerator``.
        n_bits: fingerprint length (``fpSize``).

    Returns:
        np.ndarray of shape ``(len(smiles_list), n_bits)``. Rows for SMILES that RDKit
        cannot parse are all zeros; an empty input yields shape ``(0, n_bits)``, so the
        result can always be ``np.vstack``-ed with other fingerprint matrices.
    """
    smiles_list = list(smiles_list)
    generator = GetMorganGenerator(radius=radius, fpSize=n_bits)
    fingerprints = np.zeros((len(smiles_list), n_bits))
    for row, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue  # keep the all-zero row for unparseable SMILES
        bits = np.zeros((n_bits,))
        # ConvertToNumpyArray fills the buffer natively instead of np.array(bitvect), which
        # walks the bit vector element by element.
        DataStructs.ConvertToNumpyArray(generator.GetFingerprint(mol), bits)
        fingerprints[row] = bits
    return fingerprints

X_train = generate_ecfp(train_df['SMILES'])
X_test = generate_ecfp(test_df['SMILES'])
y_train = train_df['Target']
y_test = test_df['Target']

# Embed train and test jointly so both live in one t-SNE map, then split the rows back.
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(np.vstack((X_train, X_test)))
n_train_rows = len(X_train)
tsne_train, tsne_test = tsne_results[:n_train_rows], tsne_results[n_train_rows:]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(tsne_train[:, 0], tsne_train[:, 1], c='#7b1fa2', label=f'Train Data (n={len(X_train)})', s=10, alpha=0.7)
ax.scatter(tsne_test[:, 0], tsne_test[:, 1], c='#ff6f00', label=f'Test Data (n={len(X_test)})', s=10, alpha=1)
ax.set_title('t-SNE plot of Train and Test Data')
ax.set_xlabel('t-SNE Component 1')
ax.set_ylabel('t-SNE Component 2')
ax.legend()
os.makedirs('model_files', exist_ok=True)
fig.savefig('model_files/tsne_train_vs_test_data.png', bbox_inches='tight')
plt.show()

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 7: Convert Data into Graph format </h2>
</div>

In [None]:
class CustomMoleculeNetDataset(InMemoryDataset):
    """In-memory PyG dataset over a list of labelled molecular ``Data`` graphs.

    Build the list with :meth:`create_data_list`, then wrap it. The original ``Data``
    objects are kept in ``self.data_list`` and returned directly by indexing.
    """
    def __init__(self, data_list):
        super(CustomMoleculeNetDataset, self).__init__(".", transform=None, pre_transform=None)
        self.data_list = data_list
        # Collate into the (data, slices) storage that InMemoryDataset expects.
        self.data, self.slices = self.collate(data_list)

    @staticmethod
    def create_data_list(df):
        """Convert each row of ``df`` (``SMILES``, ``Target``) into a labelled ``Data`` graph.

        Node features, edge index and edge features are taken from ``smiles2graph``'s
        ``node_feat``, ``edge_index`` and ``edge_feat``. The label is stored as a float
        tensor of shape (1, 1), so batched labels have shape (num_graphs, 1).
        """
        data_list = []
        for _, row in df.iterrows():
            graph = smiles2graph(row['SMILES'])
            data = Data(
                x=torch.tensor(graph['node_feat']),
                edge_index=torch.tensor(graph['edge_index']),
                edge_attr=torch.tensor(graph['edge_feat'])
            )
            data.smiles = row['SMILES']
            data.y = torch.tensor([[row['Target']]], dtype=torch.float)
            data_list.append(data)
        return data_list

    def __getitem__(self, idx):
        """Return one graph for an integer index, or a list of graphs for a slice/sequence.

        Accepts Python ints, NumPy integer scalars, slices, tensors, lists/tuples and
        NumPy arrays of indices. Raises TypeError for anything else instead of silently
        returning None.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, slice):
            return self.data_list[idx]
        if isinstance(idx, (int, np.integer)):
            return self.data_list[int(idx)]
        if isinstance(idx, (list, tuple, np.ndarray)):
            return [self.data_list[int(i)] for i in idx]
        raise TypeError(f"Unsupported index type: {type(idx).__name__}")

# Graphs are built from the training split only; the test split is converted later.
data_list = CustomMoleculeNetDataset.create_data_list(train_df)
dataset = CustomMoleculeNetDataset(data_list)

sample_graph = dataset[0]
print("Dataset type: ", type(dataset))
print("Dataset features: ", dataset.num_features)
print("Dataset length: ", len(dataset))
print("Dataset sample: ", sample_graph)
print("Sample nodes: ", sample_graph.num_nodes)
print("Sample edges: ", sample_graph.num_edges)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 8: Model Architecture </h2>
</div>

In [None]:
torch.manual_seed(42)
class MolecularGraphNeuralNetwork(torch.nn.Module):
    """Four-layer GCN that maps a batch of molecular graphs to one logit per graph.

    Each GCN layer is followed by LeakyReLU; the first three are additionally followed by
    BatchNorm and dropout. Node embeddings are pooled per graph with both global max and
    global mean pooling, concatenated, and fed to a single linear output unit (raw logits;
    apply sigmoid for probabilities).

    Args:
        num_node_features: width of the input node features. When None, falls back to
            ``dataset.num_features`` from the module-level ``dataset`` so existing
            no-argument calls keep working.
        embedding_size: hidden width of every GCN/BatchNorm layer.
        dropout: dropout probability applied after each of the first three blocks.
    """
    def __init__(self, num_node_features=None, embedding_size=128, dropout=0.2):
        super(MolecularGraphNeuralNetwork, self).__init__()
        if num_node_features is None:
            num_node_features = dataset.num_features
        # Submodules are registered in a fixed order so seeded initialisation is reproducible.
        self.initial_conv = GCNConv(num_node_features, embedding_size)
        self.conv1 = GCNConv(embedding_size, embedding_size)
        self.conv2 = GCNConv(embedding_size, embedding_size)
        self.conv3 = GCNConv(embedding_size, embedding_size)
        # Input is [max-pool || mean-pool], hence twice the embedding size.
        self.out = torch.nn.Linear(embedding_size * 2, 1)
        self.bn1 = torch.nn.BatchNorm1d(embedding_size)
        self.bn2 = torch.nn.BatchNorm1d(embedding_size)
        self.bn3 = torch.nn.BatchNorm1d(embedding_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_index):
        """Compute logits for every graph in the batch.

        Args:
            x: node feature matrix (cast to float by the caller).
            edge_index: graph connectivity in PyG's edge_index format.
            batch_index: graph id of every node, used for per-graph pooling.

        Returns:
            Tensor of shape (num_graphs, 1) with raw logits.
        """
        for conv, bn in ((self.initial_conv, self.bn1), (self.conv1, self.bn2), (self.conv2, self.bn3)):
            x = conv(x, edge_index)
            x = F.leaky_relu(x, negative_slope=0.01)
            x = bn(x)
            x = self.dropout(x)
        x = self.conv3(x, edge_index)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = torch.cat([global_max_pool(x, batch_index), global_mean_pool(x, batch_index)], dim=1)
        return self.out(x)

model = MolecularGraphNeuralNetwork().to(device)
print(model)
print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

In [None]:
class CustomEarlyStopping:
    """Signal a stop once validation loss has not improved for ``patience`` epochs.

    Epochs before ``min_epochs`` are ignored entirely: they neither update the best loss
    nor count toward patience. Once the stop flag is set it is never cleared.
    """

    def __init__(self, patience, min_epochs):
        self.patience, self.min_epochs = patience, min_epochs
        self.best_loss, self.best_epoch = np.inf, 0
        self.early_stop = False

    def __call__(self, epoch, avg_test_loss):
        """Record this epoch's validation loss; return True when training should stop."""
        if epoch < self.min_epochs:
            return False

        if not avg_test_loss < self.best_loss:
            epochs_without_improvement = epoch - self.best_epoch
            if epochs_without_improvement >= self.patience:
                self.early_stop = True
                display(HTML(f"<font color='green'><small>Early stopping at epoch {epoch+1}</small></font>"))
            return self.early_stop

        self.best_loss, self.best_epoch = avg_test_loss, epoch
        return self.early_stop

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 9: Model Training </h2>
</div>

In [None]:
NUM_FOLDS = 5
num_graphs_per_batch = 256
n_epochs = 250

train_loss_per_fold = {}
validation_loss_per_fold = {}
skf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)


def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-graph loss.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given; otherwise
    evaluates with gradients disabled. Batch losses are weighted by batch size so a short
    final batch does not skew the epoch average. Returns NaN for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_graphs = 0.0, 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)
            if training:
                optimizer.zero_grad()
            pred = model(batch.x.float(), batch.edge_index, batch.batch)
            loss = loss_fn(pred, batch.y.float())
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * batch.num_graphs
            total_graphs += batch.num_graphs
    return total_loss / total_graphs if total_graphs else float('nan')


for fold, (train_idx, validation_idx) in enumerate(skf.split(dataset.data_list)):
    start_time_fold = time.time()

    # Deliberately not wrapped in torch.nn.DataParallel: with more than one visible GPU it
    # chunks every input tensor along dim 0, which splits edge_index (2 x E) and the batch
    # vector across devices and corrupts the graphs. Unwrapped state_dict keys therefore
    # carry no 'module.' prefix.
    model = MolecularGraphNeuralNetwork().to(device)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-3)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=20, factor=0.5, min_lr=1e-6)
    custom_early_stopping = CustomEarlyStopping(patience=20, min_epochs=150)

    train_loader = DataLoader([dataset.data_list[idx] for idx in train_idx], batch_size=num_graphs_per_batch, shuffle=True, num_workers=3)
    # Validation order does not affect the loss, so there is no need to shuffle it.
    validation_loader = DataLoader([dataset.data_list[idx] for idx in validation_idx], batch_size=num_graphs_per_batch, shuffle=False, num_workers=3)
    display(HTML(f"<small>Fold {fold + 1}, Train Data: {len(train_loader.dataset)}, Validation Data: {len(validation_loader.dataset)}</small>"))

    train_loss_per_fold[fold] = []
    validation_loss_per_fold[fold] = []
    best_validation_loss = np.inf
    os.makedirs('model_files', exist_ok=True)
    checkpoint_path = f'model_files/model_fold_{fold+1}.pth'

    for epoch in tnrange(n_epochs, leave=False):
        train_loss = _run_epoch(model, train_loader, loss_fn, optimizer)
        validation_loss = _run_epoch(model, validation_loader, loss_fn)
        train_loss_per_fold[fold].append(train_loss)
        validation_loss_per_fold[fold].append(validation_loss)

        current_lr = optimizer.param_groups[0]['lr']

        if epoch % 20 == 0:
            display(HTML(f"<font color='grey'><small>Epoch {epoch+1},   TrainLoss {train_loss:.4f},   ValidationLoss {validation_loss:.4f},   LearningRate {current_lr:.6f}</small></font>"))

        # Checkpoint only on improvement, so the saved weights are the best-validation ones
        # rather than those of the last epoch before early stopping.
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), checkpoint_path)

        if custom_early_stopping(epoch, validation_loss):
            break

        scheduler.step(validation_loss)

    np.save(f'model_files/train_losses_fold_{fold+1}.npy', np.array(train_loss_per_fold[fold]))
    np.save(f'model_files/validation_losses_fold_{fold+1}.npy', np.array(validation_loss_per_fold[fold]))

    end_time_fold = time.time()
    fold_time = round((end_time_fold - start_time_fold) / 60, 2)
    display(HTML(f"<font color='green'><b><small>Fold {fold + 1} checkpoints saved in model_fold_{fold+1}.pth, Time Taken: {fold_time:.2f} minutes</small></b></font>"))
    print("-" * 80)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 10: Training and Validation losses </h2>
</div>

In [None]:
def plot_training_validation_losses(train_loss_per_fold, validation_loss_per_fold, num_folds, save_path=None):
    """Plot training vs. validation loss curves, one stacked subplot per fold.

    Args:
        train_loss_per_fold: mapping fold index (0-based) -> sequence of per-epoch train losses.
        validation_loss_per_fold: mapping fold index -> sequence of per-epoch validation losses.
        num_folds: number of folds (subplots) to draw; works for ``num_folds == 1`` too.
        save_path: optional file path; the figure is saved only when one is given.
    """
    # squeeze=False keeps `axes` 2-D even for a single fold, so indexing is uniform.
    fig, axes = plt.subplots(num_folds, 1, figsize=(9, 3 * num_folds), sharex=True, squeeze=False)
    for k, ax in enumerate(axes[:, 0]):
        ax.plot(train_loss_per_fold[k], label='Training Loss', color='#e64a19')
        ax.plot(validation_loss_per_fold[k], label='Validation Loss', color='#388e3c')
        ax.set_ylabel(f'Losses (Fold {k+1})', fontsize=8)
        ax.set_xlabel('Epoch Number', fontsize=8)
        ax.tick_params(axis='both', labelsize=8)
        ax.legend(fontsize=8, loc='upper right')
    fig.tight_layout()
    # The default save_path=None previously crashed inside savefig(None).
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

NUM_FOLDS = 5
train_loss_per_fold, validation_loss_per_fold = {}, {}

# Reload the per-fold loss curves written to disk by the training loop.
for fold in range(NUM_FOLDS):
    train_losses = np.load(f'model_files/train_losses_fold_{fold+1}.npy', allow_pickle=True)
    validation_losses = np.load(f'model_files/validation_losses_fold_{fold+1}.npy', allow_pickle=True)
    train_loss_per_fold[fold], validation_loss_per_fold[fold] = train_losses.tolist(), validation_losses.tolist()

plot_training_validation_losses(
    train_loss_per_fold,
    validation_loss_per_fold,
    NUM_FOLDS,
    'model_files/training_and_validation_losses.png',
)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 11: Make predictions on test data </h2>
</div>

In [None]:
NUM_FOLDS = 5
num_graphs_per_batch = 256
model_folder = 'model_files'

class CustomMoleculeNetDataset_predict(InMemoryDataset):
    """Label-free variant of the training dataset, for running inference on molecules.

    Graphs are built exactly like the training graphs but without a ``y`` attribute.
    """
    def __init__(self, data_list):
        # super() must name this class: it does not inherit from CustomMoleculeNetDataset,
        # so super(CustomMoleculeNetDataset, self) raises TypeError.
        super(CustomMoleculeNetDataset_predict, self).__init__(".", transform=None, pre_transform=None)
        self.data_list = data_list
        self.data, self.slices = self.collate(data_list)

    @staticmethod
    def create_data_list(df):
        """Convert each row's ``SMILES`` into an unlabelled ``Data`` graph via ``smiles2graph``."""
        data_list = []
        for _, row in df.iterrows():
            graph = smiles2graph(row['SMILES'])
            data = Data(
                x=torch.tensor(graph['node_feat']),
                edge_index=torch.tensor(graph['edge_index']),
                edge_attr=torch.tensor(graph['edge_feat'])
            )
            data.smiles = row['SMILES']
            data_list.append(data)
        return data_list

test_data = CustomMoleculeNetDataset_predict.create_data_list(test_df)
# No shuffling (the default), so predictions stay in test_df row order.
test_loader = DataLoader(test_data, batch_size=num_graphs_per_batch)


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return ``state_dict`` with a leading ``prefix`` (added by DataParallel) removed from each key."""
    return {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}


models = []
for fold in range(NUM_FOLDS):
    model = MolecularGraphNeuralNetwork()
    model_checkpoint_path = f'{model_folder}/model_fold_{fold+1}.pth'
    checkpoint = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))
    # Accept checkpoints saved with or without a DataParallel wrapper.
    model.load_state_dict(_strip_module_prefix(checkpoint))
    # Move each model to the device once here, not once per test batch.
    model = model.to(device)
    model.eval()
    models.append(model)

predictions = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        # One (num_graphs, 1) column of sigmoid probabilities per fold model.
        batch_predictions = np.concatenate(
            [torch.sigmoid(fold_model(batch.x.float(), batch.edge_index, batch.batch)).cpu().numpy()
             for fold_model in models],
            axis=1,
        )
        # Ensemble: average the fold probabilities, then threshold at 0.5 for a 0/1 label.
        mean_predictions = (batch_predictions.mean(axis=1) > 0.5).astype(int)
        predictions.extend(mean_predictions)

test_results = pd.DataFrame({'id':test_df['id'], 'SMILES': test_df['SMILES'], 'Target': test_df['Target'], 'Target_pred': predictions})
display(test_results.head())
print(test_results.shape)

<div style="background-color:#4B6587; color:#F0E5CF; padding: 1px; border-radius: 10px;">
    <h2 style="font-size: 16px; margin-left: 10px;"> Step 12: Model Evaluation </h2>
</div>

In [None]:
# Make the parent directory importable for the local my_cm module; the membership check
# stops repeated runs of this cell from appending duplicate sys.path entries.
if os.path.abspath("..") not in sys.path:
    sys.path.append(os.path.abspath(".."))
from my_cm import *  # NOTE(review): presumably provides PrettyConfusionMatrix; prefer an explicit import

true_labels = test_results['Target']
predicted_labels = test_results['Target_pred']
# Fix the label set so the matrix is always 2x2 (matching the two display labels below),
# even if one class is absent from the true or predicted labels.
cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
PrettyConfusionMatrix(cm, labels=('0', '1'), save_path='model_files/my_confusion_matrix.png')