In [6]:
import warnings

from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, NeighborSearch, Selection
import numpy as np
from transformers.models.esm.openfold_utils.protein import Protein
import torch_geometric.nn as pyg_nn

In [7]:
# Amino-acid sequence (one-letter codes) of the sEH target protein.
# NOTE(review): presumably the canonical isoform of UniProt P34913 (human EPHX2) — confirm source.
sEH = 'TLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGATTRLMKGEITLSQWIPLMEENCRKCSETAKVCLPKNFSIKEIFDKAISARKINRPMLQAALMLRKKGFTTAILTNTWLDDRAERDGLAQLMCELKMHFDFLIESCQVGMVKPEPQIYKFLLDTLKASPSEVVFLDDIGANLKPARDLGMVTILVQDTDTALKELEKVTGIQLLNTPAPLPTSCNPSDMSHGYVTVKPRVRLHFVELGSGPAVCLCHGFPESWYSWRYQIPALAQAGYRVLAMDMKGYGESSAPPEIEEYCMEVLCKEMVTFLDKLGLSQAVFIGHDWGGMLVWYMALFYPERVRAVASLNTPFIPANPNMSPLESIKANPVFDYQLYFQEPGVAEAELEQNLSRTFKSLFRASDESVLSMHKVCEAGGLFVNSPEEPSLSRMVTEEEIQFYVQQFKKSGFRGPLNWYRNMERNWKWACKSLGRKILIPALMVTAEKDFVLVPQMSQHMEDWIPHLKRGHIEDCGHWTQMDKPTEVNQILIKWLDSDARNPPVVSKM'
# Identifier for the sequence above; the "-1" suffix looks like UniProt isoform notation.
sEH_id = 'P34913-1'

# Example library molecule in SMILES form. The [Dy] atom appears to be a placeholder for the
# linker attachment point — TODO confirm. Not referenced elsewhere in this part of the notebook.
example_mol = 'C#CCOc1ccc(CNc2nc(NCc3ccc4[nH]c(C)cc4c3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1'


In [8]:
import pandas as pd
import pickle
from pathlib import Path
import polars as pl

# Load data
# NOTE(review): `types` is not used anywhere in this part of the notebook; presumably meant as a
# dtype map for the building-block / label columns — confirm or remove.
types = {'buildingblock1_smiles': np.int16, 'buildingblock2_smiles': np.int16, 'buildingblock3_smiles': np.int16,
          'binds_BRD4':np.byte, 'binds_HSA':np.byte, 'binds_sEH':np.byte}

# Seed for the row subsampling below, so a rerun sees the same train/test subsets.
SEED = 42

directory = Path("../data/shrunken/")
train = pl.read_parquet(directory / "train.parquet")
train = train.to_pandas(use_pyarrow_extension_array=True)
test = pl.read_parquet(directory / "test.parquet")
test = test.to_pandas(use_pyarrow_extension_array=True)
print(test.columns)


def _load_pickle(path):
    """Unpickle the object stored at ``path`` and close the file handle afterwards.

    NOTE(review): pickle can execute arbitrary code on load — only use trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Load building blocks
BBs_dict_reverse_1 = _load_pickle(directory / 'train_dicts/BBs_dict_reverse_1.p')
BBs_dict_reverse_2 = _load_pickle(directory / 'train_dicts/BBs_dict_reverse_2.p')
BBs_dict_reverse_3 = _load_pickle(directory / 'train_dicts/BBs_dict_reverse_3.p')

# Training subset: 100k sEH binders plus 300k molecules that bind none of the three targets.
train_positives = train[train.binds_sEH == 1]
train_negatives = train[(train.binds_BRD4 == 0) & (train.binds_HSA == 0) & (train.binds_sEH == 0)]
train_df = pd.concat([
    train_positives.sample(n=100000, random_state=SEED),
    train_negatives.sample(n=300000, random_state=SEED),
]).reset_index(drop=True)

# Test subset: every sEH positive plus 300k sEH negatives.
test_df = pd.concat([
    test[test['is_sEH'] == 1],
    test[test['is_sEH'] == 0].sample(n=300000, random_state=SEED),
]).reset_index(drop=True)



Index(['buildingblock1_smiles', 'buildingblock2_smiles',
       'buildingblock3_smiles', 'molecule_smiles', 'is_BRD4', 'is_HSA',
       'is_sEH'],
      dtype='object')


In [9]:
test.loc[test['is_sEH'] == 0].shape[0]  # number of test rows labelled as sEH non-binders

319880

In [10]:
import torch
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, Linear
import torch.nn.functional as F

def smiles_to_graph(smiles):
    """Convert a SMILES string into an undirected PyTorch Geometric molecular graph.

    Nodes are the atoms of the molecule as parsed (hydrogens are not added), each with
    features ``[atomic number, degree, formal charge, hybridization as int]``.
    Every bond contributes two directed edges (i->j and j->i), both carrying the bond
    order from ``GetBondTypeAsDouble()`` as their single edge feature.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``Data(x=[num_atoms, 4], edge_index=[2, 2*num_bonds], edge_attr=[2*num_bonds, 1])``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string")

    # Features are taken from `mol` itself; the previous H-added, re-sanitized copy was never
    # used (and SanitizeMol returns flags, not a molecule, so its None check could not fire).
    node_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization().real,
        ]
        for atom in mol.GetAtoms()
    ]

    edges = []
    edge_features = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        order = bond.GetBondTypeAsDouble()
        # GCNConv passes messages along edge_index direction only; store both directions
        # so information flows both ways across each bond (PyG's undirected convention).
        edges += [(begin, end), (end, begin)]
        edge_features += [[order], [order]]

    x = torch.tensor(node_features, dtype=torch.float).reshape(-1, 4)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
    else:
        # Bond-free molecule: keep the (2, 0) / (0, 1) shapes PyG expects instead of (0,).
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    
class MoleculeDataset(Dataset):
    """Map-style dataset turning a DataFrame of SMILES into labelled PyG graphs.

    Each item is ``graph_fn(row['molecule_smiles'])`` with ``y`` set to a 1-element float
    tensor holding ``row[target_column]``. Rows whose SMILES cannot be converted yield
    ``None`` so that a ``None``-filtering ``collate_fn`` can drop them.

    Args:
        dataframe: frame with a ``molecule_smiles`` column and the target column.
        target_column: name of the label column.
        graph_fn: callable mapping a SMILES string to a ``Data`` object (or ``None``);
            defaults to ``smiles_to_graph``.
    """

    def __init__(self, dataframe, target_column, graph_fn=None):
        super(MoleculeDataset, self).__init__()
        self.dataframe = dataframe
        self.target_column = target_column
        self.graph_fn = smiles_to_graph if graph_fn is None else graph_fn

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        smiles = row['molecule_smiles']
        target = row[self.target_column]
        try:
            data = self.graph_fn(smiles)
        except ValueError:
            # smiles_to_graph raises ValueError for unusable SMILES; map that to None so the
            # skip path below (and collate_fn's None filtering) is actually reachable.
            return None
        if data is None:
            return None
        data.y = torch.tensor([target], dtype=torch.float)
        return data


In [11]:
import torch_geometric

from torch.utils.data import DataLoader as TorchDataLoader

dataset = MoleculeDataset(train_df, 'binds_sEH')


def collate_fn(batch):
    """Merge PyG graphs into one Batch, dropping items the dataset returned as ``None``.

    Returns ``None`` when every item in the batch was dropped.
    """
    graphs = [item for item in batch if item is not None]
    if not graphs:
        return None
    return torch_geometric.data.Batch.from_data_list(graphs)


# torch_geometric's DataLoader discards any user-supplied collate_fn and installs its own
# Collater (which cannot handle None items). Use the plain PyTorch loader so collate_fn runs.
train_loader = TorchDataLoader(dataset, batch_size=20000, shuffle=True, collate_fn=collate_fn)
if len(train_loader) == 0:
    raise ValueError("DataLoader is empty. Check the dataset and collation settings.")

test_dataset = MoleculeDataset(test_df, 'is_sEH')
test_loader = TorchDataLoader(test_dataset, batch_size=2000, shuffle=False, collate_fn=collate_fn)



In [13]:

class GCN(torch.nn.Module):
    """Two-layer graph convolutional classifier producing one probability per graph.

    GCNConv(num_node_features -> 16) -> ReLU -> GCNConv(16 -> 16) -> ReLU
    -> global mean pool per graph -> Linear(16 -> 1) -> sigmoid.

    Args:
        num_node_features: number of columns expected in ``data.x``.
    """

    def __init__(self, num_node_features):
        super(GCN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(num_node_features, 16)
        self.conv2 = pyg_nn.GCNConv(16, 16)
        self.fc = Linear(16, 1)

    def forward(self, data):
        """Score a (batched) PyG graph.

        Args:
            data: object exposing ``x``, ``edge_index`` and, for batched input, ``batch``.

        Returns:
            1-D tensor with one sigmoid probability per graph (length 1 for a single graph).
        """
        device = next(self.parameters()).device
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        # A lone graph has no batch vector; global_mean_pool treats None as a single graph.
        batch = getattr(data, 'batch', None)
        if batch is not None:
            batch = batch.to(device)

        if x.size(1) != self.conv1.in_channels:
            # Kept for backward compatibility, but no longer silent: the replacement layer's
            # parameters are unknown to any optimizer built earlier, so they never train.
            warnings.warn(
                f"GCN: input has {x.size(1)} node features but conv1 expects "
                f"{self.conv1.in_channels}; rebuilding conv1. The new layer is not registered "
                "with any existing optimizer and will not be trained.",
                RuntimeWarning,
            )
            self.conv1 = pyg_nn.GCNConv(x.size(1), 16).to(device)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = pyg_nn.global_mean_pool(x, batch)
        # squeeze(-1) keeps a 1-D result even when the batch holds a single graph.
        x = self.fc(x).squeeze(-1)
        return torch.sigmoid(x)



In [18]:
from sklearn.metrics import average_precision_score
import warnings
from tqdm import tqdm

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# NOTE(review): torch is already imported at this point; CUDA_VISIBLE_DEVICES only takes effect
# if CUDA has not been initialised yet in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the first convolution from a real sample instead of a hard-coded 16: on a mismatch
# GCN.forward swaps in a fresh conv1 whose parameters this optimizer would never update.
num_node_features = train_loader.dataset[0].x.size(1)
model = GCN(num_node_features=num_node_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# GCN.forward already returns sigmoid probabilities, so the loss must take probabilities;
# BCEWithLogitsLoss would apply a second sigmoid on top of them.
criterion = torch.nn.BCELoss()

# NOTE(review): UFFTYPER messages come from RDKit's own logger; these Python warning filters
# probably do not silence them (rdkit.RDLogger.DisableLog would) — confirm.
warnings.filterwarnings('ignore', message=r'UFFTYPER: Unrecognized charge state for atom:*')
warnings.filterwarnings('ignore', message=r'UFFTYPER: Unrecognized atom type:*')

for epoch in range(10):
    total_loss = 0.0
    n_batches = 0
    y_true = []
    y_scores = []
    model.train()
    for data in tqdm(train_loader, desc="Training"):
        # collate_fn returns None when every graph in the batch was dropped.
        if data is None or len(data.y) == 0:
            continue
        optimizer.zero_grad()
        # view(-1) keeps a 1-D tensor even for a single-graph batch.
        output = model(data).view(-1)
        labels = data.y.view(-1).to(output.device)
        loss = criterion(output, labels.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        y_true.extend(labels.cpu().numpy())
        # output is already a probability; no extra sigmoid.
        y_scores.extend(output.detach().cpu().numpy())
    # Average over batches actually used so skipped batches do not deflate the loss.
    average_loss = total_loss / max(n_batches, 1)

    # Average precision of the training-set scores (single task, so AP is the "mAP" here).
    mAP = average_precision_score(y_true, y_scores)

    print(f"Epoch {epoch+1}, Loss: {average_loss}, mAP: {mAP}")


Training:   5%|▌         | 1/20 [00:17<05:40, 17.92s/it]


KeyboardInterrupt: 

In [19]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in tqdm(test_loader, desc="Testing"):
        # collate_fn returns None when every graph in the batch was dropped.
        if data is None:
            continue
        # view(-1) keeps a 1-D tensor even for a single-graph batch.
        outputs = model(data).view(-1)
        # The model returns probabilities, so 0.5 is the decision threshold.
        predicted = (outputs > 0.5).float()
        # Compare on the model's device; data.y stays on the CPU otherwise.
        labels = data.y.view(-1).to(outputs.device)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Compute accuracy
# NOTE(review): the test subset is heavily imbalanced (all positives + 300k negatives), so
# accuracy alone is a weak signal — consider average precision as in training.
accuracy = 100 * correct / total if total else float('nan')
print('Accuracy on the test dataset: {:.2f}%'.format(accuracy))


Testing:   1%|▏         | 6/430 [00:06<07:15,  1.03s/it]


KeyboardInterrupt: 

In [20]:
def construct_interaction_graph(smiles, protein, protein_structure, cutoff=5.0):
    """Build a ligand–protein contact graph as a PyTorch Geometric ``Data`` object.

    Node ids ``0 .. n_lig-1`` are the atoms of the ligand parsed from ``smiles``. Node ids
    ``n_lig .. n_lig+n_res-1`` are the non-water residues of the first model in
    ``protein_structure``. An undirected edge (stored in both directions) joins ligand
    atom ``i`` and a residue node when any atom of that residue lies strictly within
    ``cutoff`` of atom ``i``.

    Args:
        smiles: ligand SMILES string.
        protein: structure id passed to ``PDBParser.get_structure``.
        protein_structure: PDB file passed to ``PDBParser.get_structure``.
        cutoff: contact distance threshold, in the structure's coordinate units.

    Returns:
        ``Data(x=[n_lig + n_res, 4], edge_index=[2, E])``. Returns ``None`` if the SMILES is
        invalid or no 3D conformer could be generated for it.

    NOTE(review): the ligand conformer is generated de novo and is NOT placed in the
    protein's binding site (no docking/alignment), so the contacts depend on an arbitrary
    coordinate frame. TODO: supply docked ligand coordinates.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Handle invalid SMILES input
        return None

    # A SMILES-derived molecule carries no coordinates, so GetConformer() would raise.
    # Embed with explicit hydrogens, then strip them; heavy-atom indices are preserved.
    if mol.GetNumConformers() == 0:
        mol_h = Chem.AddHs(mol)
        if AllChem.EmbedMolecule(mol_h, randomSeed=0) != 0:
            return None
        mol = Chem.RemoveHs(mol_h)

    structure = PDBParser(QUIET=True).get_structure(protein, protein_structure)
    model = structure[0]

    n_lig = mol.GetNumAtoms()
    # Exclude water; residue node ids follow ligand atom ids.
    residues = [res for res in model.get_residues() if res.get_resname() not in ['HOH']]
    residue_node = {res.get_full_id(): n_lig + k for k, res in enumerate(residues)}
    protein_atoms = [atom for res in residues for atom in res.get_atoms()]

    # Edges point at the residue node that owns each nearby atom. Previously the protein
    # *atom* index was used, producing node ids that matched no residue node.
    contacts = set()
    if protein_atoms and n_lig:
        search = NeighborSearch(protein_atoms)
        positions = mol.GetConformer().GetPositions()
        for i, pos in enumerate(positions):
            for atom in search.search(pos, cutoff, level='A'):
                if np.linalg.norm(pos - atom.coord) < cutoff:
                    contacts.add((i, residue_node[atom.get_parent().get_full_id()]))

    pairs = sorted(contacts)
    if pairs:
        src, dst = zip(*pairs)
        edge_index = torch.tensor([src + dst, dst + src], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    ligand_x = torch.tensor(
        [
            [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetFormalCharge(),
                atom.GetHybridization().real,
            ]
            for atom in mol.GetAtoms()
        ],
        dtype=torch.float,
    ).reshape(-1, 4)
    # Residues have no numeric featurisation yet. Zero rows keep x aligned with node ids so
    # every index in edge_index is valid. TODO: real residue features.
    protein_x = torch.zeros((len(residues), ligand_x.size(1)), dtype=torch.float)
    x = torch.cat([ligand_x, protein_x], dim=0)
    return Data(x=x, edge_index=edge_index)



class MoleculeDataset(Dataset):
    """Map-style dataset producing a ligand–protein interaction graph for each SMILES row.

    NOTE(review): this redefines the name of the SMILES-only ``MoleculeDataset`` defined
    earlier in the notebook and shadows it from here on.

    Args:
        dataframe: frame with a ``molecule_smiles`` column and the target column.
        target_column: name of the label column.
        protein_id: structure id passed to ``construct_interaction_graph``.
        pdb_path: protein PDB file. The default is a machine-specific absolute path kept
            for backward compatibility — TODO move it to a configurable location.
        cutoff: contact distance passed to ``construct_interaction_graph``. The default 0.1
            is far below that function's own default of 5.0 — confirm it is intended.
    """

    def __init__(self, dataframe, target_column, protein_id='P34913',
                 pdb_path='/home/daniel-de-dios/Downloads/3i28.pdb', cutoff=0.1):
        super(MoleculeDataset, self).__init__()
        self.dataframe = dataframe
        self.target_column = target_column
        self.protein_id = protein_id
        self.pdb_path = pdb_path
        self.cutoff = cutoff

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        smiles = row['molecule_smiles']
        target = row[self.target_column]
        # NOTE(review): construct_interaction_graph re-parses the PDB file on every call,
        # which is likely the dominant per-item cost.
        data = construct_interaction_graph(smiles, self.protein_id, self.pdb_path, cutoff=self.cutoff)
        if data is None:
            return None
        data.y = torch.tensor([target], dtype=torch.float)
        return data


In [21]:
import torch_geometric

from torch.utils.data import DataLoader as TorchDataLoader

dataset = MoleculeDataset(train_df, 'binds_sEH')


def collate_fn(batch):
    """Merge PyG graphs into one Batch, dropping items the dataset returned as ``None``.

    Returns ``None`` when every item in the batch was dropped.
    """
    graphs = [item for item in batch if item is not None]
    if not graphs:
        return None
    return torch_geometric.data.Batch.from_data_list(graphs)


# torch_geometric's DataLoader discards any user-supplied collate_fn and installs its own
# Collater (which cannot handle None items). Use the plain PyTorch loader so collate_fn runs.
train_loader = TorchDataLoader(dataset, batch_size=20000, shuffle=True, collate_fn=collate_fn)
if len(train_loader) == 0:
    raise ValueError("DataLoader is empty. Check the dataset and collation settings.")

test_dataset = MoleculeDataset(test_df, 'is_sEH')
test_loader = TorchDataLoader(test_dataset, batch_size=2000, shuffle=False, collate_fn=collate_fn)



In [22]:
from torch.nn.functional import sigmoid
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch.nn import Linear

class GCN(torch.nn.Module):
    """Two-layer graph convolutional classifier producing one probability per graph.

    GCNConv(num_node_features -> 16) -> ReLU -> GCNConv(16 -> 16) -> ReLU
    -> global mean pool per graph -> Linear(16 -> 1) -> sigmoid.

    Args:
        num_node_features: number of columns expected in ``data.x``.
    """

    def __init__(self, num_node_features):
        super(GCN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(num_node_features, 16)
        self.conv2 = pyg_nn.GCNConv(16, 16)
        self.fc = Linear(16, 1)

    def forward(self, data):
        """Score a (batched) PyG graph.

        Args:
            data: object exposing ``x``, ``edge_index`` and, for batched input, ``batch``.

        Returns:
            1-D tensor with one sigmoid probability per graph (length 1 for a single graph).
        """
        device = next(self.parameters()).device
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        # A lone graph has no batch vector; global_mean_pool treats None as a single graph.
        batch = getattr(data, 'batch', None)
        if batch is not None:
            batch = batch.to(device)

        if x.size(1) != self.conv1.in_channels:
            # Kept for backward compatibility, but no longer silent: the replacement layer's
            # parameters are unknown to any optimizer built earlier, so they never train.
            warnings.warn(
                f"GCN: input has {x.size(1)} node features but conv1 expects "
                f"{self.conv1.in_channels}; rebuilding conv1. The new layer is not registered "
                "with any existing optimizer and will not be trained.",
                RuntimeWarning,
            )
            self.conv1 = pyg_nn.GCNConv(x.size(1), 16).to(device)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = pyg_nn.global_mean_pool(x, batch)
        # squeeze(-1) keeps a 1-D result even when the batch holds a single graph.
        x = self.fc(x).squeeze(-1)
        return sigmoid(x)