In [1]:

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import torch
import sys
import pandas as pd
import numpy as np
import torch

from torch_geometric.loader import DataLoader as GeometricDataloader
sys.path.insert(0, './src')

from datasets.PPIMI_datasets import CustomSmilesDataset
from compound_gnn_model import GNNComplete
from MultiPPIMI import DualMultiPPIMI

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#Class for initializing the bioactivity model and predicting a list of scores
class PPIReward():
    """Ensemble scorer for candidate modulators of one protein-protein interaction.

    On construction, ``N_MODELS`` ``DualMultiPPIMI`` models are built (each on
    top of a ``GNNComplete`` backbone initialised from ``BACKBONE_CHECKPOINT``)
    and their weights are loaded from ``ENSEMBLE_CHECKPOINT_TEMPLATE``.
    Calling the instance with one SMILES string or a sequence of SMILES
    strings returns one score per molecule: the median of the per-model
    predictions minus their standard deviation.

    Attributes:
        prot1, prot2: UniProt accessions of the two interaction partners.
        device: torch device string all models and batches are placed on.
        model_list: the loaded ensemble, in checkpoint order.
    """

    # Checkpoint locations, ensemble size and inference batch size.
    BACKBONE_CHECKPOINT = './src/GraphMVP_C.model'
    ENSEMBLE_CHECKPOINT_TEMPLATE = "./final_models/final_filter_{}.model"
    N_MODELS = 10
    BATCH_SIZE = 1024 * 3

    def __init__(self,
                 prot1 = "P62166", #NCS-1
                 prot2 = "Q9NPQ8", #Ric8
                 device = None):
        """Build the ensemble and load every checkpoint.

        Args:
            prot1: UniProt accession of the first protein of the PPI.
            prot2: UniProt accession of the second protein of the PPI.
            device: torch device string (e.g. ``"cuda:0"`` or ``"cpu"``).
                When omitted, ``"cuda:0"`` is used if CUDA is available and
                ``"cpu"`` otherwise.
        """
        self.prot1 = prot1
        self.prot2 = prot2
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device

        # The backbone checkpoint is identical for every ensemble member, so
        # read it from disk once. load_state_dict copies the values into each
        # model's own parameters, so sharing the dict does not alias weights.
        backbone_state = torch.load(self.BACKBONE_CHECKPOINT, map_location="cpu")

        self.model_list = []
        for i in range(self.N_MODELS):
            modulator_model = GNNComplete(5, 300, JK='last', drop_ratio=0, gnn_type="gin")
            modulator_model.load_state_dict(backbone_state)
            PPIMI_model = DualMultiPPIMI(
                modulator_model,
                modulator_emb_dim=310,
                ppi_emb_dim=1318,
                device=self.device,
                h_dim=512, n_heads=2
                ).to(self.device)
            # map_location lets checkpoints saved on a GPU load on any device.
            PPIMI_model.load_state_dict(
                torch.load(self.ENSEMBLE_CHECKPOINT_TEMPLATE.format(i),
                           map_location=self.device))
            self.model_list.append(PPIMI_model)

    def multippimi_predicting(self, PPIMI_model, device, dataloader, regression=True):
        """Run one model over every batch of ``dataloader`` and collect its predictions.

        Args:
            PPIMI_model: model returning a pair ``(pred_1, pred_2)`` when
                called with ``(modulator, rdkit_descriptors, ppi_esm)``.
            device: device the batch tensors are moved to before the forward
                pass; should be the device ``PPIMI_model`` lives on.
            dataloader: iterable yielding
                ``(modulator, rdkit_descriptors, ppi_esm)`` batches.
            regression: if true, the first model output is used, otherwise
                the second (presumably a classification head -- confirm
                against ``DualMultiPPIMI``).

        Returns:
            1-D ``numpy.ndarray`` with one value per sample, in loader order.
        """
        PPIMI_model.eval()
        total_preds = []
        with torch.no_grad():
            for modulator, rdkit_descriptors, ppi_esm in dataloader:
                pred_1, pred_2 = PPIMI_model(
                    modulator.to(device),
                    rdkit_descriptors.to(device),
                    ppi_esm.to(device),
                )
                pred = pred_1 if regression else pred_2
                # Flattening handles every batch size, including a batch of
                # one, so no squeeze/unsqueeze juggling is required.
                total_preds += pred.detach().cpu().numpy().flatten().tolist()
        return np.array(total_preds)

    def score_predict(self, smiles, full_array = False):
        """Score molecules against the configured PPI with the whole ensemble.

        Args:
            smiles: a single SMILES string or any sized iterable of SMILES
                strings (list, tuple, pandas Series, ...).
            full_array: if true, return the raw per-model predictions instead
                of the aggregated score.

        Returns:
            If ``full_array`` is true, a list with one 1-D array per model.
            Otherwise a 1-D array holding, for each molecule, the median of
            the model predictions minus their standard deviation, so
            molecules on which the models disagree are scored lower.
        """
        if isinstance(smiles, str):
            smiles = [smiles]
        # A plain list gives the frame a 0..n-1 index, so rows are positional
        # no matter what index an incoming Series carried; the output array is
        # positional as well.
        smiles = list(smiles)
        n_molecules = len(smiles)

        df = pd.DataFrame({
            "SMILES": smiles,
            "uniprot_id1": [self.prot1] * n_molecules,
            "uniprot_id2": [self.prot2] * n_molecules,
        })

        sample_dataset = CustomSmilesDataset(df, labels=False)
        sample_dataloader = GeometricDataloader(
            sample_dataset, batch_size=self.BATCH_SIZE, shuffle=False, drop_last=False)

        score_list = [
            self.multippimi_predicting(PPIMI_model, self.device, sample_dataloader)
            for PPIMI_model in self.model_list
        ]

        if full_array:
            return score_list

        stacked = np.array(score_list)  # shape: (n_models, n_molecules)
        return np.median(stacked, axis=0) - np.std(stacked, axis=0)

    def __call__(self, smiles):
        """Return the aggregated ensemble score for one SMILES string or a sequence of them."""
        return self.score_predict(smiles)

# Initialize a model for a PPI, selected by the UniProt IDs of its two proteins

In [3]:
# All three scorers share NCS-1 as the first protein and differ in the partner.
NCS1_UNIPROT_ID = "P62166"

# NCS1/Ric8
ric8_model = PPIReward(prot1=NCS1_UNIPROT_ID, prot2="Q9NPQ8")
# NCS1/D2R
d2r_model = PPIReward(prot1=NCS1_UNIPROT_ID, prot2="P14416")
# NCS1/CB1
cb1_model = PPIReward(prot1=NCS1_UNIPROT_ID, prot2="Q506J9")

In [13]:
# Example of use: score a single SMILES string with the NCS1/D2R scorer.
example_smiles = "OCCN(CCO)C1=NC2=C(N=C(N=C2N2CCCCC2)N(CCO)CCO)C(=N1)N1CCCCC1"
d2r_model(example_smiles)

0.0

array([-0.67953842])

# Usage example: predicting scores for a DataFrame of FDA compounds

## FDA

In [None]:
# Load the FDA structures table (tab-separated text file).
fda_structures_path = "data/FDA/structures_FDA.txt"
fda_df = pd.read_csv(fda_structures_path, sep="\t")

In [10]:
# Preview the first five rows of the loaded table.
fda_df.head(n=5)

Unnamed: 0,Id,Smiles
0,DB00006,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...
1,DB00007,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...
2,DB00014,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...
3,DB00027,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...
4,DB00035,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...


In [None]:
# Get scores: add one column per PPI scorer, each fed the same SMILES column.
score_columns = {
    "NCS1/Ric8 score": ric8_model,
    "NCS1/D2R score": d2r_model,
}
for column_name, scorer in score_columns.items():
    fda_df[column_name] = scorer(fda_df["Smiles"])

0.9996127033307514665

In [12]:
# Save scores. `index` is documented as a bool, so pass False explicitly
# instead of relying on None being falsy; the row index is not written.
fda_df.to_csv("data/FDA/FDA_PPI_scoring.csv", index=False)

### Show the list of top 20 NCS1/Ric8 Molecules

In [14]:
# Keep only the display columns, rank by NCS1/Ric8 score (highest first),
# and show the first 20 rows.
ric8_columns = ["Id", "Smiles", "NCS1/Ric8 score"]
fda_df[ric8_columns].sort_values(by="NCS1/Ric8 score", ascending=False).head(20)

Unnamed: 0,Id,Smiles,NCS1/Ric8 score
1601,DB08910,NC1=CC=CC2=C1C(=O)N(C1CCC(=O)NC1=O)C2=O,0.707691
347,DB00480,NC1=CC=CC2=C1CN(C1CCC(=O)NC1=O)C2=O,0.564974
2226,DB13170,[H][C@@]12CSSC[C@H](NC(=O)CNC(=O)[C@@]([H])(NC...,0.434086
803,DB00970,[H][C@@]12CCCN1C(=O)[C@H](NC(=O)[C@@H](NC(=O)C...,0.359665
869,DB01041,O=C1N(C2CCC(=O)NC2=O)C(=O)C2=CC=CC=C12,0.359389
1463,DB06699,CC(C)C[C@H](NC(=O)[C@@H](CC1=CC=C(NC(N)=O)C=C1...,0.248085
1692,DB09099,C[C@@H](O)[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=...,0.245516
1545,DB06827,[H][C@@]1(C[C@H](O)NC(=N)N1)[C@]1([H])NC(=O)\C...,0.215232
808,DB00975,OCCN(CCO)C1=NC2=C(N=C(N=C2N2CCCCC2)N(CCO)CCO)C...,0.17846
1425,DB06283,[H][C@]12CSSC[C@]3([H])NC(=O)[C@H](CCCCN)NC(=O...,0.140455


### Show the list of top 20 NCS1/D2R Molecules

In [15]:
# Keep only the display columns, rank by NCS1/D2R score (highest first),
# and show the first 20 rows.
d2r_columns = ["Id", "Smiles", "NCS1/D2R score"]
fda_df[d2r_columns].sort_values(by="NCS1/D2R score", ascending=False).head(20)

Unnamed: 0,Id,Smiles,NCS1/D2R score
347,DB00480,NC1=CC=CC2=C1CN(C1CCC(=O)NC1=O)C2=O,0.278175
1601,DB08910,NC1=CC=CC2=C1C(=O)N(C1CCC(=O)NC1=O)C2=O,0.081731
2226,DB13170,[H][C@@]12CSSC[C@H](NC(=O)CNC(=O)[C@@]([H])(NC...,0.04229
1562,DB08822,CCOC1=NC2=C(N1CC1=CC=C(C=C1)C1=CC=CC=C1C1=NOC(...,-0.014455
869,DB01041,O=C1N(C2CCC(=O)NC2=O)C(=O)C2=CC=CC=C12,-0.019304
1653,DB09053,NC1=NC=NC2=C1C(=NN2[C@@H]1CCCN(C1)C(=O)C=C)C1=...,-0.147446
1463,DB06699,CC(C)C[C@H](NC(=O)[C@@H](CC1=CC=C(NC(N)=O)C=C1...,-0.152829
1692,DB09099,C[C@@H](O)[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=...,-0.198331
230,DB00357,CCC1(CCC(=O)NC1=O)C1=CC=C(N)C=C1,-0.24405
451,DB00593,CCC1(C)CC(=O)NC1=O,-0.247444
