In [15]:
#neccessary imports
import pandas as pd
import pyarrow.dataset as ds
from loguru import logger
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
import tensorflow as tf
import dgl
from dgl import batch as dgl_batch
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from torch.utils.data import Dataset, DataLoader
from dgl.nn import GraphConv, GlobalAttentionPooling

#wandb.login()

import torch_geometric
from torch_geometric.data import Data
print("PyTorch version:", torch.__version__)
print("DGL version:", dgl.__version__)
print("CUDA available in PyTorch:", torch.cuda.is_available())

# Check CUDA support in DGL
try:
    # Create a simple graph
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))

    # Try to move the graph to GPU
    g = g.to('cuda')
    print("CUDA available in DGL: True")
except:
    print("CUDA available in DGL: False")
print(dgl.__version__)

PyTorch version: 2.2.1+cu121
DGL version: 2.1.0
CUDA available in PyTorch: True
CUDA available in DGL: False
2.1.0


In [16]:
def clean_smi(smi: str | list):
    r""" Clean a SMILES string by removing salts and fragments.
    Parameters
    ----------
    smi : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    str | list
        The cleaned SMILES string.
    """
    if isinstance(smi, list):
        return [clean_smi(s) for s in smi]
    # Remove [Dy] from smiles
    smi = smi.replace("[Dy]", "")

    # Convert SMILES to a RDKit molecule object
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError("Invalid SMILES string")
    
    # Remove any salts or fragments
    mol = Chem.RemoveHs(mol)  # Remove explicit hydrogens
    fragments = Chem.GetMolFrags(mol, asMols=True)
    
    # Keep the largest fragment
    largest_fragment = max(fragments, default=mol, key=lambda m: m.GetNumAtoms())
    
    # Standardize the molecule
    AllChem.Compute2DCoords(largest_fragment)  # Compute 2D coordinates
    
    # Convert the molecule back to a canonical SMILES string
    cleaned_smiles = Chem.MolToSmiles(largest_fragment, canonical=True)
    return cleaned_smiles


def smiles_to_dgl_graph(smiles: str |list):
    r""" Convert a SMILES string to a DGLGraph.
    Parameters
    ----------
    smiles : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    DGLGraph
        A DGLGraph object for the molecule.
    """
    if isinstance(smiles, list):
        return [smiles_to_dgl_graph(s) for s in smiles]
    clean_smiles=clean_smi(smiles)
    mol = Chem.MolFromSmiles(clean_smiles)
    if mol is None:
        return None

    # Node features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization(),
            atom.GetIsAromatic(),
            atom.GetTotalNumHs()
        ])
    
    # Edge features and adjacency list
    src, dst = [], []
    bond_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src.append(start)
        dst.append(end)
        bond_features.append([
            int(bond.GetBondTypeAsDouble()), 
            bond.GetBondType(),
            bond.GetIsConjugated(),
            bond.IsInRing()
        ])
    """
    g = dgl.graph((src, dst))
    g.ndata['h'] = torch.tensor(atom_features, dtype=torch.float)
    g.edata['h'] = torch.tensor(bond_features, dtype=torch.float)
    g= dgl.add_self_loop(g)
    return g
    """
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    node_features = torch.tensor(atom_features, dtype=torch.float)
    edge_features = torch.tensor(bond_features, dtype=torch.float)

    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)
    return data

#create a dictionary with the protein names as keys and the protein sequences as values
protein_sequences = {
    "BRD4": "NPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPDYYKIIKTPMDMGTIKKRLENNYYWNAQECIQDFNTMFTNCYIYNKPGDDIVLMAEALEKLFLQKINELPTEETEIMIVQAKGRGRGRKETGTAKPGVSTVPNTTQASTPPQTQTPQPNPPPVQATPHPFPAVTPDLIVQTPVMTVVPPQPLQTPPPVPPQPQPPPAPAPQPVQSHPPIIAATPQPVKTKKGVKRKADTTTPTTIDPIHEPPSLPPEPKTTKLGQRRESSRPVKPPKKDVPDSQQHPAPEKSSKVSEQLKCCSGILKEMFAKKHAAYAWPFYKPVDVEALGLHDYCDIIKHPMDMSTIKSKLEAREYRDAQEFGADVRLMFSNCYKYNPPDHEVVAMARKLQDVFEMRFAKMPDE",
    "sEH": "TLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGATTRLMKGEITLSQWIPLMEENCRKCSETAKVCLPKNFSIKEIFDKAISARKINRPMLQAALMLRKKGFTTAILTNTWLDDRAERDGLAQLMCELKMHFDFLIESCQVGMVKPEPQIYKFLLDTLKASPSEVVFLDDIGANLKPARDLGMVTILVQDTDTALKELEKVTGIQLLNTPAPLPTSCNPSDMSHGYVTVKPRVRLHFVELGSGPAVCLCHGFPESWYSWRYQIPALAQAGYRVLAMDMKGYGESSAPPEIEEYCMEVLCKEMVTFLDKLGLSQAVFIGHDWGGMLVWYMALFYPERVRAVASLNTPFIPANPNMSPLESIKANPVFDYQLYFQEPGVAEAELEQNLSRTFKSLFRASDESVLSMHKVCEAGGLFVNSPEEPSLSRMVTEEEIQFYVQQFKKSGFRGPLNWYRNMERNWKWACKSLGRKILIPALMVTAEKDFVLVPQMSQHMEDWIPHLKRGHIEDCGHWTQMDKPTEVNQILIKWLDSDARNPPVVSKM",
    "HSA": "DAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTESLVNRRPCFSALEVDETYVPKEFNAETFTFHADICTLSEKERQIKKQTALVELVKHKPKATKEQLKAVMDDFAAFVEKCCKADDKETCFAEEGKKLVAASQAALGL",
}

amino_acids = "ACDEFGHIKLMNPQRSTVWY"
aa_to_index = {aa: idx for idx, aa in enumerate(amino_acids)}


def one_hot_encode_sequence(sequence, aa_to_index=aa_to_index):
    one_hot_encoded = np.zeros((len(sequence), len(amino_acids)), dtype=np.float32)
    for i, aa in enumerate(sequence):
        if aa in aa_to_index:
            one_hot_encoded[i, aa_to_index[aa]] = 1.0
    return one_hot_encoded

# Function to generate one-hot encoded features for proteins
def generate_protein_one_hot_features(protein_sequences):
    features = {}
    for protein_name, sequence in protein_sequences.items():
        one_hot_features = one_hot_encode_sequence(sequence)
        features[protein_name] = one_hot_features
    return features

def protein_to_features(protein_seq: str):
    #one hot encode the amino acid composition
    one_hot_encoded = one_hot_encode_sequence(protein_seq)
    amino_acid_composition = torch.tensor(one_hot_encoded, dtype=torch.float)
    # Hydrophobicity
    hydrophobicity = [0] * len(protein_seq)
    for i, amino_acid in enumerate(protein_seq):
        if amino_acid in ['A', 'I', 'L', 'M', 'F', 'W', 'V']:
            hydrophobicity[i] = 1
    hydrophobicity = torch.tensor(hydrophobicity, dtype=torch.float).unsqueeze(1)
    # Charge
    charge = [0] * len(protein_seq)
    for i, amino_acid in enumerate(protein_seq):
        if amino_acid in ['K', 'R']:
            charge[i] = 1
        elif amino_acid in ['D', 'E']:
            charge[i] = -1
    charge = torch.tensor(charge, dtype=torch.float).unsqueeze(1)
    # Concatenate features
    protein_features = torch.cat((amino_acid_composition, hydrophobicity, charge), dim=1)
    return protein_features

def protein_to_graph(protein_seq, protein_features, neighbor_distance=3):
    node_features = protein_features
    edge_index = []
    for i in range(len(protein_seq)):
        for j in range(i + 1, min(i + neighbor_distance + 1, len(protein_seq))):
            edge_index.append([i, j])
            edge_index.append([j, i])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t()

    data = Data(x=node_features, edge_index=edge_index)
    return data



In [17]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool, global_max_pool

class ProteinLigandGCN(torch.nn.Module):
    def __init__(self, protein_feature_dim, ligand_feature_dim, hidden_dim, output_dim,dropout=0.2):
        super(ProteinLigandGCN, self).__init__()
        # Initialize GCN layers for protein
        self.protein_conv1 = GCNConv(protein_feature_dim, hidden_dim)
        self.protein_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Initialize GCN layers for ligand
        self.ligand_conv1 = GCNConv(ligand_feature_dim, hidden_dim)
        self.ligand_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Dropout layer
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
        # Fully connected layers for each graph after GCN layers
        self.protein_fc = nn.Linear(hidden_dim, hidden_dim)
        self.ligand_fc = nn.Linear(hidden_dim, hidden_dim)

        # Final fully connected layer after combining
        self.final_fc = nn.Linear(2 * hidden_dim, output_dim)

    def forward(self, protein_data, ligand_data):
        # Protein graph processing
        protein_graph, protein_features = protein_data
        ligand_graph, ligand_features = ligand_data
        
        # Apply GCN layers to protein graph
        protein_x = F.relu(self.protein_conv1(protein_features, protein_graph.edge_index))
        protein_x = self.dropout1(protein_x)
        protein_x = F.relu(self.protein_conv2(protein_x, protein_graph.edge_index))
        protein_x = self.dropout2(protein_x)
        
        # Pooling (choose one of the following or implement your own)
        #protein_output = global_add_pool(protein_x, protein_graph.batch)  # Global sum pooling
        protein_output = global_mean_pool(protein_x, protein_graph.batch)  # Global mean pooling
        # protein_output = global_max_pool(protein_x, protein_graph.batch)  # Global max pooling
        
        protein_output = F.relu(self.protein_fc(protein_output))

        # Apply GCN layers to ligand graph
        ligand_x = F.relu(self.ligand_conv1(ligand_features, ligand_graph.edge_index))
        ligand_x = self.dropout1(ligand_x)
        ligand_x = F.relu(self.ligand_conv2(ligand_x, ligand_graph.edge_index))
        ligand_x = self.dropout2(ligand_x)
        
        # Pooling (choose one of the following or implement your own)
        #ligand_output = global_add_pool(ligand_x, ligand_graph.batch)  # Global sum pooling
        ligand_output = global_mean_pool(ligand_x, ligand_graph.batch)  # Global mean pooling
        # ligand_output = global_max_pool(ligand_x, ligand_graph.batch)  # Global max pooling
        
        ligand_output = F.relu(self.ligand_fc(ligand_output))

        # Combine protein and ligand outputs
        combined_output = torch.cat((protein_output, ligand_output), dim=1)
        final_output = self.final_fc(combined_output)

        return final_output


In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch
class ProteinLigandDataset(Dataset):
    def __init__(self, data, protein_sequences, transform=None):
        self.data = data
        self.protein_sequences = protein_sequences
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        protein_name = self.data.iloc[idx]["protein_name"]
        molecule_smiles = self.data.iloc[idx]["molecule_smiles"]
        #check if there is a label in the data
        if "binds" in self.data.columns:
            label = self.data.iloc[idx]["binds"]
        else:
            label = self.data.iloc[idx]["id"]
        #label = self.data.iloc[idx]["binds"]
        
        protein_graph = protein_to_graph(self.protein_sequences[protein_name], protein_to_features(self.protein_sequences[protein_name]))
        smiles_graph = smiles_to_dgl_graph(molecule_smiles)
        """
        protein_features = protein_graph.ndata['h']
        smiles_features = smiles_graph.ndata['h']
                    'protein_features': protein_features,
                                'smiles_features': smiles_features,
        """
        sample = {
            'protein_graph': protein_graph,
            'smiles_graph': smiles_graph,
            'label': label
        }
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample


def custom_collate_fn(batch):
    protein_graphs = [item['protein_graph'] for item in batch]
    #protein_features = [item['protein_features'] for item in batch]
    smiles_graphs = [item['smiles_graph'] for item in batch]
    #smiles_features = [item['smiles_features'] for item in batch]
    labels = torch.tensor([item['label'] for item in batch], dtype=str)
    """
    batched_protein_graph = dgl_batch(protein_graphs)
    batched_smiles_graph = dgl_batch(smiles_graphs)
    
    batched_protein_features = torch.cat(protein_features)
    batched_smiles_features = torch.cat(smiles_features)
    """
    batched_protein_graph = Batch.from_data_list(protein_graphs)
    batched_smiles_graph = Batch.from_data_list(smiles_graphs)
    return {
        'protein_graph': batched_protein_graph,
        'smiles_graph': batched_smiles_graph,
        'label': labels
    }





In [19]:
test_data=ds.dataset(source="../../../data/test.parquet", format="parquet").to_table().to_pandas()


In [20]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")



protein_node_feats = 22
smiles_node_feats = 6
edge_feats = 4
hidden_dim = 128
output_dim = 1

# Initialize the model
model = ProteinLigandGCN(protein_node_feats, smiles_node_feats, hidden_dim, output_dim).to(device)
# Load the model
checkpoint = torch.load('best_model_24.pt')
logger.info(f'keys in checkpoint: {checkpoint.keys()}')
logger.info(f'model state dict: {checkpoint["odel_state_dict"]}')
model.load_state_dict(checkpoint['odel_state_dict'])
model.eval()

# Prepare the test data
# Assuming `test_dataset` is your dataset for test data
test_data=ds.dataset(source="../../../data/test.parquet", format="parquet").to_table().to_pandas()

test_loader =ProteinLigandDataset(test_data, protein_sequences)

# Make predictions
predictions = {}
with torch.no_grad():
    for batch in test_loader:
        protein_graph = batch['protein_graph'].to(device)
        protein_features=protein_graph.x.to(device)
        smiles_graph = batch['smiles_graph'].to(device)
        smiles_features=smiles_graph.x.to(device)
        label = batch['label']
        output = model((protein_graph, protein_features), (smiles_graph, smiles_features))
        #logger.info(f"Predicted label for {label} is {output.squeeze(1).cpu().numpy()[0]}")
        predictions[label] = output.squeeze(1).cpu().numpy()[0]
        
        
        

# Save the predictions to a parquet file with the columns "id" and "binds"
predictions_df = pd.DataFrame(predictions.items(), columns=["id", "binds"])
predictions_df.to_parquet("predictions.parquet", index=False)

[32m2024-07-05 12:36:36.611[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mkeys in checkpoint: dict_keys(['epoch', 'odel_state_dict', 'optimizer_state_dict', 'average_train_loss', 'average_val_loss'])[0m


[32m2024-07-05 12:36:36.697[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mmodel state dict: OrderedDict([('protein_conv1.bias', tensor([-0.6005, -0.4973, -0.6663, -0.4601, -0.6563, -0.4538, -0.4558, -0.4845,
        -0.4897, -0.4733, -0.5136, -0.6005, -0.6004, -0.6003, -0.6000, -0.6873,
        -0.6004, -0.4847, -0.5994, -0.4775, -0.6601, -0.6005, -0.6004, -0.6607,
        -0.5992, -0.6000, -0.4971, -0.5973, -0.4891, -0.5065, -0.4617, -0.6002,
        -0.4900, -0.6004, -0.4612, -0.4735, -0.6004, -0.4861, -0.6005, -0.6616,
        -0.4834, -0.6005, -0.6004, -0.4855, -0.7409, -0.6560, -0.6005, -0.6796,
        -0.4818, -0.5989, -0.6005, -0.4376, -0.6004, -0.4918, -0.6003, -0.4753,
        -0.4994, -0.6563, -0.6004, -0.4780, -0.6004, -0.6003, -0.5180, -0.7271,
        -0.4879, -0.4654, -0.5180, -0.4836, -0.6004, -0.4936, -0.6001, -0.6004,
        -0.6002, -0.6002, -0.6005, -0.4758, -0.5065, -0.4954, -0.6004, -0.4627,
        -0.6002, -0.4680, -0.5998, -0.

Using cuda device


In [23]:
predictions.head()

Unnamed: 0,id,binds
0,295246830,0.499211
1,295246831,0.499211
2,295246832,0.499211
3,295246833,0.499211
4,295246834,0.499211


In [21]:
#read the predictions file
predictions = pd.read_parquet("predictions.parquet")
import polars as pl

pred=pl.read_parquet("predictions.parquet")
pred.head()
#read the binds column and apply the sigmoid function to get the probability using torch.sigmoid
predictions["binds"]=torch.sigmoid(torch.tensor(predictions["binds"].to_numpy())).numpy()
predictions.head()

#save the predictions to a parquet file
predictions.to_parquet("predictions.parquet", index=False)