In [92]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score

In [93]:
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

In [94]:
import torch

In [95]:
import torch.nn as nn
import torch.nn.functional as F

In [96]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, global_add_pool

# Download data

In [97]:
bbbp = pd.read_csv("BBBP.csv")
bbbp.head()

Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...


In [98]:
bbbp.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2050 entries, 0 to 2049
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   num     2050 non-null   int64 
 1   name    2050 non-null   object
 2   p_np    2050 non-null   int64 
 3   smiles  2050 non-null   object
dtypes: int64(2), object(2)
memory usage: 64.2+ KB


The data looks good. 2050 entries for each column and there do not seem to be any missing or unusual values.

In [99]:
bbbp['p_np'].value_counts()

1    1567
0     483
Name: p_np, dtype: int64

Looking at class sizes, we can tell this dataset is very biased. This could be an indication to try stratified sampling.

# Process data

#### Generate canonical SMILES

In [100]:
# suppress warnings from invalid molecules for readability
# invalid molecules will be removed
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

In [101]:
# function to generate canon SMILES
def gen_canon_smiles(smiles_list):
    
    invalid_ids = []
    canon_smiles = []

    for i in range(len(smiles_list)):   
        mol = Chem.MolFromSmiles(smiles_list[i])
        
        # do not append NoneType if invalid
        if mol is None: 
            invalid_ids.append(i)
            continue

        canon_smiles.append(Chem.MolToSmiles(mol))

    return canon_smiles, invalid_ids

In [102]:
# generate canon smiles
canon_smiles, invalid_ids = gen_canon_smiles(bbbp.smiles)

# drop rows with invalid SMILES
bbbp = bbbp.drop(invalid_ids)

# replace SMILES with canon SMILES
bbbp.smiles = canon_smiles

# drop duplicates to prevent train/valid/test contamination
bbbp.drop_duplicates(subset=['smiles'], inplace=True)

In [103]:
bbbp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1975 entries, 0 to 2049
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   num     1975 non-null   int64 
 1   name    1975 non-null   object
 2   p_np    1975 non-null   int64 
 3   smiles  1975 non-null   object
dtypes: int64(2), object(2)
memory usage: 77.1+ KB


The dataset has been reduced, however, now it should only contain non-duplicate and valid molecules.

In [104]:
bbbp['p_np'].value_counts()

1    1501
0     474
Name: p_np, dtype: int64

The dataset is still biased.

#### Generate molecular graphs

In [105]:
def one_hot_encoding(x, permitted_list):
    """
    Maps input elements x which are not in the permitted list to the last element
    of the permitted list.
    """

    if x not in permitted_list:
        x = permitted_list[-1]

    binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]

    return binary_encoding

In [106]:
def get_atom_features(atom, use_chirality = True, hydrogens_implicit = True):
    """
    Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.
    """

    # define list of permitted atoms
    
    permitted_list_of_atoms =  ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']
    
    if hydrogens_implicit == False:
        permitted_list_of_atoms = ['H'] + permitted_list_of_atoms
    
    # compute atom features
    
    atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
    
    n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
    
    formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
    
    hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
    
    is_in_a_ring_enc = [int(atom.IsInRing())]
    
    is_aromatic_enc = [int(atom.GetIsAromatic())]
    
    atomic_mass_scaled = [float((atom.GetMass() - 10.812)/116.092)]
    
    vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)]
    
    covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]

    atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled
                                    
    if use_chirality == True:
        chirality_type_enc = one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
        atom_feature_vector += chirality_type_enc
    
    if hydrogens_implicit == True:
        n_hydrogens_enc = one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])
        atom_feature_vector += n_hydrogens_enc

    return np.array(atom_feature_vector)

In [107]:
def get_bond_features(bond, use_stereochemistry = True):
    """
    Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.
    """

    permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

    bond_type_enc = one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)
    
    bond_is_conj_enc = [int(bond.GetIsConjugated())]
    
    bond_is_in_ring_enc = [int(bond.IsInRing())]
    
    bond_feature_vector = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc
    
    if use_stereochemistry == True:
        stereo_type_enc = one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
        bond_feature_vector += stereo_type_enc

    return np.array(bond_feature_vector)

In [108]:
def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y):
    """
    Inputs:
    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerial labels for the SMILES strings (such as associated pKi values)
    
    Outputs:
    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning
    """
    
    data_list = []
    
    for (smiles, y_val) in zip(x_smiles, y):
        
        # convert SMILES to RDKit mol object
        mol = Chem.MolFromSmiles(smiles)

        # get feature dimensions
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()
        unrelated_smiles = "O=O"
        unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
        n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
        n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)
            
        X = torch.tensor(X, dtype = torch.float)
        
        # construct edge index array E of shape (2, n_edges)
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
        torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
        E = torch.stack([torch_rows, torch_cols], dim = 0)
        
        # construct edge feature array EF of shape (n_edges, n_edge_features)
        EF = np.zeros((n_edges, n_edge_features))
        
        for (k, (i,j)) in enumerate(zip(rows, cols)):
            
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))
        
        EF = torch.tensor(EF, dtype = torch.float)
        
        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)
        
        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list

In [109]:
# create list of molecular graph objects from list of SMILES and list of labels
data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(bbbp.smiles, bbbp.p_np)

In [110]:
data_list[0]

Data(x=[20, 79], edge_index=[2, 40], edge_attr=[40, 10], y=[1])

#### Split data

In [111]:
# split data into training, validation, and test sets

labels = [int(data.y) for data in data_list] # array-like for stratification
train_ds, rest_list = train_test_split(data_list, train_size=0.8, stratify=labels)

labels = [int(data.y) for data in rest_list] # array-like for stratification
valid_ds, test_ds = train_test_split(rest_list, test_size=0.5, stratify=labels)

# print set ratios

total_len = len(bbbp)

print("train dataset:".ljust(15, " "), len(train_ds)/total_len) # 0.80
print("valid dataset:".ljust(15, " "), len(valid_ds)/total_len) # 0.10
print("test dataset:".ljust(15, " "), len(test_ds)/total_len) #0.10

train dataset:  0.8
valid dataset:  0.09974683544303797
test dataset:   0.10025316455696202


In [112]:
# create torch_geometric.loader.DataLoader objects

batch_size = 2**6 # 64

train_dl = DataLoader(train_ds, batch_size, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size, shuffle=True)

# Implementing a GNN using the PyTorch Geometric library

In [113]:
class Network(nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super().__init__()
        conv1_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_node_features*32)
        )
        
        conv2_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32*16)
        )
        
        self.conv1 = NNConv(num_node_features, 32, conv1_net)
        self.conv2 = NNConv(32,16, conv2_net)
        self.fc_1 = nn.Linear(16, 32)
        self.out = nn.Linear(32, 1)
        
    def forward(self, data):
        batch, x, edge_index, edge_attr = (data.batch, data.x, data.edge_index, data.edge_attr)
        
        # First graph conv layer
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        
        # Second graph conv layer
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        
        x = global_add_pool(x,batch)
        x = F.relu(self.fc_1(x))
        output = self.out(x)
        
        output = torch.sigmoid(output)
        
        return output

In [114]:
# initialize the network

num_node_features = data_list[0].num_node_features
num_edge_features = data_list[0].num_edge_features

model = Network(num_node_features, num_edge_features)

# initialize optimizer

loss_fn = nn.BCELoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

Network(
  (conv1): NNConv(79, 32, aggr=add, nn=Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=2528, bias=True)
  ))
  (conv2): NNConv(32, 16, aggr=add, nn=Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=512, bias=True)
  ))
  (fc_1): Linear(in_features=16, out_features=32, bias=True)
  (out): Linear(in_features=32, out_features=1, bias=True)
)

In [115]:
def train(model, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    best_f1 = 0.0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, data.y.view(-1, 1).float())
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.num_graphs
        train_loss /= len(train_loader.dataset)

        model.eval()
        y_true = []
        y_prob = []
        val_loss = 0.0
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                output = model(data)
                loss = loss_fn(output, data.y.view(-1, 1).float())
                val_loss += loss.item() * data.num_graphs
                y_true.append(data.y.detach().cpu().numpy())
                y_prob.append(output.detach().cpu().numpy())
            val_loss /= len(val_loader.dataset)

        y_true = np.concatenate(y_true)
        y_prob = np.concatenate(y_prob)
        y_pred = np.where(y_prob >= 0.5, 1, 0)
        f1 = f1_score(y_true, y_pred)

        print(f"Epoch {epoch+1:0>2}, train loss: {train_loss:.4f}, valid loss: {val_loss:.4f}, valid F1-score: {f1:.4f}")

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'best_model.pt')

    print(f"Best validation F1-score: {best_f1:.4f}")
    model.load_state_dict(torch.load('best_model.pt'))
    return model

In [116]:
trained_model = train(model, optimizer, loss_fn, train_dl, valid_dl, 20)

Epoch 01, train loss: 0.8513, valid loss: 0.5069, valid F1-score: 0.8776
Epoch 02, train loss: 0.4884, valid loss: 0.4679, valid F1-score: 0.8730
Epoch 03, train loss: 0.4415, valid loss: 0.4777, valid F1-score: 0.8827
Epoch 04, train loss: 0.4180, valid loss: 0.4318, valid F1-score: 0.8952
Epoch 05, train loss: 0.3993, valid loss: 0.4604, valid F1-score: 0.8609
Epoch 06, train loss: 0.3953, valid loss: 0.4658, valid F1-score: 0.8462
Epoch 07, train loss: 0.3933, valid loss: 0.4043, valid F1-score: 0.8939
Epoch 08, train loss: 0.3864, valid loss: 0.4226, valid F1-score: 0.9012
Epoch 09, train loss: 0.3801, valid loss: 0.4133, valid F1-score: 0.8859
Epoch 10, train loss: 0.3640, valid loss: 0.4096, valid F1-score: 0.9051
Epoch 11, train loss: 0.3652, valid loss: 0.3762, valid F1-score: 0.9079
Epoch 12, train loss: 0.3300, valid loss: 0.4040, valid F1-score: 0.9102
Epoch 13, train loss: 0.3325, valid loss: 0.3826, valid F1-score: 0.9079
Epoch 14, train loss: 0.3313, valid loss: 0.4017, v

In [117]:
def evaluate(model, loader, device):
    # Set the model to evaluation mode
    model.eval()
    
    # Initialize the lists for the true labels and predicted probabilities
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        # Iterate over the validation data
        for data in loader:
            # Move the data to the specified device
            data = data.to(device)
            
            # Forward pass
            output = model(data)
            
            # Convert the predicted probabilities to binary predictions
            y_pred_batch = (output > 0.5).int().cpu().numpy()
            
            # Append the true labels and predicted binary labels to the lists
            y_true.append(data.y.cpu().numpy())
            y_pred.append(y_pred_batch)
    
    # Concatenate the lists to obtain the true labels and predicted binary labels for the entire dataset
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    
    # Calculate the relevant performance metrics
    f1 = f1_score(y_true, y_pred)
    roc_auc = roc_auc_score(y_true, y_pred)
    
    return f1, roc_auc

In [118]:
test_f1, test_roc_auc = evaluate(model, test_dl, 'cpu')
print("F1 on test set:", test_f1, 
      "\nROC-AUC on test set:", test_roc_auc)

F1 on test set: 0.9221183800623054 
ROC-AUC on test set: 0.7537499999999999
