In [38]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing, global_mean_pool, global_max_pool
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt

# NOTE(review): `dir` shadows the Python builtin of the same name and is not
# referenced by any code visible in this notebook - consider renaming it
# (e.g. DATA_DIR) or removing it once no other cell depends on it.
dir = '../data'


In [39]:
# Atom featurisation
## Auxiliary function for one-hot encoding based on a list of
## permitted values

def one_hot_encoding(x, permitted_list):
    """
    One-hot encode ``x`` against ``permitted_list``.

    A value that does not occur in ``permitted_list`` is replaced by the last
    element of the list (the catch-all slot) before encoding. The result is a
    list of 0/1 ints with one entry per element of ``permitted_list``.
    """
    target = x if x in permitted_list else permitted_list[-1]
    return [int(target == candidate) for candidate in permitted_list]
    
    
# Main atom featurisation function

def get_atom_features(atom, use_chirality=True):
    """
    Build the float32 feature vector for a single atom.

    The vector is the concatenation of, in order:
      * a one-hot of the element symbol over a fixed vocabulary (symbols
        outside the vocabulary fall into the trailing 'Unknown' slot),
      * a one-hot of the atom degree over 0..4 plus a 'MoreThanFour' slot,
      * a single 0/1 ring-membership flag,
      * optionally (``use_chirality=True``) a one-hot of the chiral tag.

    With the vocabularies below this gives 22 entries with chirality and 18
    without.
    """
    element_vocab = ['C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I','Dy', 'Unknown']
    symbol = atom.GetSymbol()
    if symbol not in element_vocab:
        symbol = 'Unknown'

    # Collect the feature groups in their final order, then flatten once.
    feature_groups = [
        one_hot_encoding(symbol, element_vocab),
        one_hot_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 'MoreThanFour']),
        [int(atom.IsInRing())],
    ]
    if use_chirality:
        chiral_vocab = ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"]
        feature_groups.append(one_hot_encoding(str(atom.GetChiralTag()), chiral_vocab))

    flat_features = [value for group in feature_groups for value in group]
    return np.array(flat_features, dtype=np.float32)

# Bond featurization

def get_bond_features(bond):
    """
    Build the float32 feature vector for a single bond.

    The vector is a one-hot of the bond type over SINGLE / DOUBLE / TRIPLE /
    AROMATIC plus a trailing 'Unknown' slot for every other type, followed by
    a single 0/1 ring-membership flag (6 entries in total).
    """
    bond_type_enum = Chem.rdchem.BondType
    known_types = [bond_type_enum.SINGLE, bond_type_enum.DOUBLE, bond_type_enum.TRIPLE, bond_type_enum.AROMATIC, 'Unknown']

    kind = bond.GetBondType()
    if kind not in known_types:
        kind = 'Unknown'

    encoded = one_hot_encoding(kind, known_types)
    encoded.append(int(bond.IsInRing()))
    return np.array(encoded, dtype=np.float32)


def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y=None):
    """
    Convert SMILES strings (and optional labels) into PyTorch Geometric graphs.

    Parameters
    ----------
    x_smiles : sequence of str
        SMILES strings, one per molecule.
    y : indexable, optional
        Labels aligned with ``x_smiles``; ``y[i]`` is the label (scalar or
        1-D vector) of ``x_smiles[i]``. When omitted, the graphs carry no ``y``.

    Returns
    -------
    list of Data
        One graph per SMILES that RDKit could parse, with node features ``x``,
        a directed ``edge_index`` holding both directions of every bond, the
        matching ``edge_attr`` and, when labels are given, ``y`` with a leading
        dimension of 1.

    Notes
    -----
    SMILES that RDKit cannot parse are skipped, so the returned list can be
    shorter than ``x_smiles``. Labels stay aligned because they are looked up
    by the original index, but callers that later pair predictions with
    external ids by position must account for skipped entries themselves.
    """
    data_list = []
    # Width of a bond feature vector; only needed for molecules without bonds,
    # so it is resolved lazily from a probe bond the first time it is required.
    num_bond_features = None

    for index, smiles in enumerate(x_smiles):
        mol = Chem.MolFromSmiles(smiles)

        if not mol:  # Skip invalid SMILES strings
            continue

        # Node features: stack into one ndarray first; building a tensor from
        # a Python list of ndarrays is slow and triggers a PyTorch warning.
        atom_features = np.array(
            [get_atom_features(atom) for atom in mol.GetAtoms()], dtype=np.float32
        )
        x = torch.from_numpy(atom_features)

        # Edge features
        edge_pairs = []
        edge_features = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_pairs += [(start, end), (end, start)]  # Undirected graph
            bond_feature = get_bond_features(bond)
            edge_features += [bond_feature, bond_feature]  # Same features in both directions

        if edge_pairs:
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
            edge_attr = torch.from_numpy(np.stack(edge_features).astype(np.float32, copy=False))
        else:
            # A molecule without bonds (e.g. a single atom) must still expose
            # edge_index as [2, 0] and edge_attr as [0, F]; the naive
            # conversion of empty lists yields 1-D tensors that cannot be
            # batched together with ordinary graphs.
            if num_bond_features is None:
                probe_bond = Chem.MolFromSmiles('CC').GetBondWithIdx(0)
                num_bond_features = len(get_bond_features(probe_bond))
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, num_bond_features), dtype=torch.float)

        # Creating the Data object
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        if y is not None:
            # Wrapping in a list keeps the leading dimension of 1 so that
            # batching concatenates labels into one row per graph.
            data.y = torch.tensor(np.asarray([y[index]]), dtype=torch.float)

        data_list.append(data)

    return data_list

def featurize_data_in_batches(smiles_list, labels_list, batch_size):
    """
    Featurize ``smiles_list`` chunk by chunk while reporting progress.

    ``smiles_list`` (and ``labels_list`` when it is not None) is sliced into
    consecutive chunks of at most ``batch_size`` items; every chunk is turned
    into graphs and the graphs are collected, in order, into one flat list.
    The progress bar advances by the number of SMILES in each chunk.
    """
    featurized = []
    progress = tqdm(total=len(smiles_list), desc="Featurizing data")

    for start in range(0, len(smiles_list), batch_size):
        stop = start + batch_size
        smiles_chunk = smiles_list[start:stop]
        labels_chunk = None if labels_list is None else labels_list[start:stop]

        featurized += create_pytorch_geometric_graph_data_list_from_smiles_and_labels(smiles_chunk, labels_chunk)
        progress.update(len(smiles_chunk))

    progress.close()
    return featurized

In [40]:
# Narrow integer dtypes for the building-block ids and the binding labels
# keep the very large training table compact in memory.
dtypes = {
    **dict.fromkeys(['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'], np.int16),
    **dict.fromkeys(['binds_BRD4', 'binds_HSA', 'binds_sEH'], np.byte),
}

train = pd.read_csv('../shrunken_data/train.csv', dtype=dtypes)
print(len(train))
train.head()

98415610


Unnamed: 0,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,binds_BRD4,binds_HSA,binds_sEH
0,0,0,0,C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...,0,0,0
1,0,0,1,C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC...,0,0,0
2,0,0,2,C#CCOc1ccc(CNc2nc(NCc3ccc(OCC#C)cc3)nc(N[C@@H]...,0,0,0
3,0,0,6,C#CCOc1ccc(CNc2nc(NCCNC(=O)C(=C)C)nc(N[C@@H](C...,0,0,0
4,0,0,10,C#CCOc1ccc(CNc2nc(NCC(=O)NCC=C)nc(N[C@@H](CC#C...,0,0,0


In [41]:
# Work on the first 3000 rows only so that featurization and training stay fast.
df = train.iloc[:3000]

In [42]:
# Sanity check: one row per molecule, one column per protein target.
df[['binds_BRD4','binds_HSA','binds_sEH']].to_numpy().shape

(3000, 3)

In [43]:
# Define the batch size for featurization
# Featurize the training subset; every graph carries a 3-entry label row
# (BRD4, HSA, sEH).
batch_size = 256  # number of SMILES featurized per progress-bar update
smiles_list = df['molecule_smiles'].tolist()
labels_list = df[['binds_BRD4','binds_HSA','binds_sEH']].to_numpy()
train_data = featurize_data_in_batches(
    smiles_list=smiles_list,
    labels_list=labels_list,
    batch_size=batch_size,
)
    

Featurizing data: 100%|██████████| 3000/3000 [00:03<00:00, 968.95it/s] 


In [45]:
# Same dtype mapping as used for the training table.
dtypes = {
    **dict.fromkeys(['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'], np.int16),
    **dict.fromkeys(['binds_BRD4', 'binds_HSA', 'binds_sEH'], np.byte),
}

test = pd.read_csv('../shrunken_data/test.csv', dtype=dtypes)
print(len(test))
test.head()

878022


Unnamed: 0,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,is_BRD4,is_HSA,is_sEH
0,0,17,17,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,True,True,True
1,0,17,87,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,True,True,True
2,0,17,99,C#CCCC[C@H](Nc1nc(NCC2(O)CCCC2(C)C)nc(Nc2ccc(C...,True,True,True
3,0,17,244,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2sc(Cl)c...,True,True,True
4,0,17,394,C#CCCC[C@H](Nc1nc(NCC2CCC(SC)CC2)nc(Nc2ccc(C=C...,True,True,True


In [46]:
# Keep only the first 4000 test molecules; re-running this cell is harmless
# because slicing an already truncated frame returns the same rows.
test = test.iloc[:4000]

In [47]:
# Define the batch size for featurization
# Featurize the test subset; no labels are attached to the test graphs.
batch_size = 256  # number of SMILES featurized per progress-bar update
smiles_list = test['molecule_smiles'].tolist()
test_data = featurize_data_in_batches(
    smiles_list=smiles_list,
    labels_list=None,
    batch_size=batch_size,
)
    

Featurizing data: 100%|██████████| 4000/4000 [00:04<00:00, 939.09it/s]


In [53]:
# Inspect two training graphs: node/edge tensor shapes plus the attached label row.
train_data[0],train_data[2]

(Data(x=[41, 22], edge_index=[2, 88], edge_attr=[88, 6], y=[1, 3]),
 Data(x=[40, 22], edge_index=[2, 84], edge_attr=[84, 6], y=[1, 3]))

In [54]:
# Inspect two test graphs: same layout as the training graphs but without `y`.
test_data[0],test_data[1]

(Data(x=[35, 22], edge_index=[2, 74], edge_attr=[74, 6]),
 Data(x=[40, 22], edge_index=[2, 86], edge_attr=[86, 6]))

In [55]:
# MODELLING

#Define custom GNN layer
class CustomGNNLayer(MessagePassing):
    """
    Message-passing layer that mixes neighbour node features with bond features.

    For every edge the message is the concatenation of the source node's
    features and the edge's features. Messages arriving at a node are reduced
    with an element-wise max, and the reduced vector is projected by a linear
    layer to ``out_channels``.

    Parameters
    ----------
    in_channels : int
        Width of the incoming node features.
    out_channels : int
        Width of the node features produced by the layer.
    edge_dim : int, default 6
        Width of the edge feature vectors. The default matches the bond
        featurization used in this notebook (5 bond-type slots + ring flag).
    """

    def __init__(self, in_channels, out_channels, edge_dim=6):
        super().__init__(aggr='max')
        # The linear layer consumes [node features | edge features].
        self.lin = nn.Linear(in_channels + edge_dim, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the updated node features."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Per-edge message: source-node features concatenated with edge features."""
        return torch.cat((x_j, edge_attr), dim=1)

    def update(self, aggr_out):
        """Project the max-aggregated messages to the output width."""
        return self.lin(aggr_out)

#Define GNN Model
class GNNModel(nn.Module):
    """
    Graph-level predictor built from a stack of ``CustomGNNLayer`` blocks.

    Each of the ``num_layers`` blocks applies message passing, batch
    normalisation, ReLU and dropout to the node features. The node features
    are then max-pooled per graph and mapped by a linear head to
    ``out_channels`` raw scores (no activation is applied to the output).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout_rate,out_channels=1):
        super().__init__()
        self.num_layers = num_layers

        # Only the first layer sees the raw node features; all later layers
        # operate on hidden_dim-wide representations.
        conv_layers = []
        for layer_idx in range(num_layers):
            in_dim = hidden_dim if layer_idx else input_dim
            conv_layers.append(CustomGNNLayer(in_dim, hidden_dim))
        self.convs = nn.ModuleList(conv_layers)

        self.dropout = nn.Dropout(dropout_rate)
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)])
        self.lin = nn.Linear(hidden_dim, out_channels)

    def forward(self, data):
        """Return one row of ``out_channels`` raw scores per graph in ``data``."""
        node_repr, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        for layer_idx in range(self.num_layers):
            conv, norm = self.convs[layer_idx], self.bns[layer_idx]
            node_repr = self.dropout(F.relu(norm(conv(node_repr, edge_index, edge_attr))))

        # Global pooling turns node-level features into one vector per graph.
        graph_repr = global_max_pool(node_repr, data.batch)
        return self.lin(graph_repr)


In [69]:
def train_model(train_loader, num_epochs, input_dim, hidden_dim, num_layers, dropout_rate,out_channels, lr):
    """
    Train a fresh ``GNNModel`` with AdamW and binary cross-entropy on logits.

    Parameters
    ----------
    train_loader : iterable of batches
        Yields batched graphs exposing ``y`` with one label per output channel
        and graph.
    num_epochs : int
        Number of full passes over ``train_loader``.
    input_dim, hidden_dim, num_layers, dropout_rate, out_channels
        Forwarded unchanged to ``GNNModel``.
    lr : float
        Learning rate for AdamW.

    Returns
    -------
    GNNModel
        The trained model (left in training mode).

    The mean batch loss of every epoch is printed; it is reported as ``nan``
    when the loader yields no batches instead of dividing by zero.
    """
    model = GNNModel(input_dim, hidden_dim, num_layers, dropout_rate,out_channels)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            # BCEWithLogitsLoss requires target and logits to have the same
            # shape. Reshaping to the logits' shape is a no-op for labels
            # stored as one row per graph and also covers labels that were
            # stored flat (e.g. a single task with y of shape [batch]).
            target = batch.y.float().reshape(out.shape)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}')

    return model

def predict_with_model(model, test_loader):
    """
    Return sigmoid probabilities for every batch in ``test_loader`` as a flat list.

    The model is switched to evaluation mode and run without gradient
    tracking. Each batch output is flattened in row-major order, so a model
    with several output channels yields all channels of the first sample,
    then all channels of the second sample, and so on.
    """
    model.eval()

    with torch.no_grad():
        per_batch = [
            torch.sigmoid(model(batch)).view(-1).tolist()
            for batch in test_loader
        ]

    return [probability for chunk in per_batch for probability in chunk]

def select_and_save_predictions_with_ids(predictions,test_df,path_test_file_for_ids,output_dir = 'results/',output_file_name ='ids_pred_results.csv'):
    """
    Keep the predictions for the requested proteins, attach ids and save a CSV.

    Parameters
    ----------
    predictions : sequence of float
        Flat predictions, three per row of ``test_df`` in the column order
        is_BRD4, is_HSA, is_sEH (row-major).
    test_df : pandas.DataFrame
        Must contain the flag columns ``is_BRD4``, ``is_HSA`` and ``is_sEH``;
        a prediction is kept only where its flag is non-zero.
    path_test_file_for_ids : str
        CSV file with an ``id`` column. Its first rows are paired by position
        with the kept predictions.
        NOTE(review): this assumes the id file lists one row per kept
        (molecule, protein) pair in the same order as ``test_df`` - confirm
        against how the shrunken test file was produced.
    output_dir, output_file_name : str
        Location of the CSV that is written (the directory is created if
        missing).

    Returns
    -------
    pandas.DataFrame
        Columns ``id`` and ``binds``, one row per kept prediction.

    Raises
    ------
    ValueError
        If the number of predictions differs from the number of flags, e.g.
        because molecules were dropped during featurization.
    """
    # One flag per (molecule, protein) pair, flattened in the same row-major
    # order in which the predictions were flattened.
    keep_flags = np.asarray(test_df[['is_BRD4','is_HSA','is_sEH']]).reshape(-1)
    scores = np.asarray(predictions, dtype=float).reshape(-1)
    if scores.shape[0] != keep_flags.shape[0]:
        raise ValueError(
            f'Got {scores.shape[0]} predictions for {keep_flags.shape[0]} '
            '(molecule, protein) flags; predictions and test_df are misaligned.'
        )

    # Drop predictions of proteins that are not part of the test set.
    y_pred_df = pd.DataFrame({'binds': scores[keep_flags != 0]})

    # Read only the id column and only as many rows as are needed instead of
    # loading the entire raw test file.
    test_ids = pd.read_csv(
        path_test_file_for_ids,
        index_col=False,
        usecols=['id'],
        nrows=len(y_pred_df),
    )[['id']]
    assert len(y_pred_df)==len(test_ids), 'id file has fewer rows than kept predictions'
    y_pred_and_ids_df = pd.concat([test_ids,y_pred_df],axis=1)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir,output_file_name)
    y_pred_and_ids_df.to_csv(output_path,index=False)
    return y_pred_and_ids_df


In [58]:
# Train model
# Seed PyTorch's global RNG before anything stochastic happens so that batch
# shuffling, weight initialisation and dropout can be repeated on a re-run
# (same machine and library versions).
SEED = 42
torch.manual_seed(SEED)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
input_dim = train_loader.dataset[0].num_node_features
hidden_dim = 64
num_epochs = 11
num_layers = 4 #Should ideally be set so that all nodes can communicate with each other
dropout_rate = 0.3
lr = 0.001
out_channels =3  # one output per protein target: BRD4, HSA, sEH
#These are just example values, feel free to play around with them.
model = train_model(train_loader,num_epochs, input_dim, hidden_dim,num_layers, dropout_rate,out_channels, lr)

Epoch 1/11, Loss: 0.053738421076869075
Epoch 2/11, Loss: 0.008369789972772544
Epoch 3/11, Loss: 0.00835864261372332
Epoch 4/11, Loss: 0.008439316725769535
Epoch 5/11, Loss: 0.007954135739016644
Epoch 6/11, Loss: 0.008297347993728645
Epoch 7/11, Loss: 0.007923215554601097
Epoch 8/11, Loss: 0.007673561386020973
Epoch 9/11, Loss: 0.007450859184362034
Epoch 10/11, Loss: 0.007932277248325223
Epoch 11/11, Loss: 0.008421062985980051


In [60]:
# Predict
# shuffle=False keeps the predictions in the same order as the rows of `test`,
# which the id matching in the export step depends on.
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)
predictions = predict_with_model(model=model, test_loader=test_loader)


In [70]:
# The raw test file supplies the ids that are paired with the kept predictions.
path_test_file_for_ids = '../data/test.csv'
output_df = select_and_save_predictions_with_ids(
    predictions=predictions,
    test_df=test,
    path_test_file_for_ids=path_test_file_for_ids,
    output_dir='results',
)

In [71]:
# Display the exported table: one row per kept (id, prediction) pair.
output_df

Unnamed: 0,id,binds
0,295246830,0.125133
1,295246831,0.133689
2,295246832,0.099996
3,295246833,0.124910
4,295246834,0.144258
...,...,...
11995,295258825,0.114499
11996,295258826,0.115671
11997,295258827,0.120788
11998,295258828,0.115426
