# Inference Notebook
- This notebook is a guide to predicting PPI inhibitors on your own custom dataset.
- It includes preprocessing, data loading, and model inference.

In [1]:
import os
import tqdm
import numpy as np
import pandas as pd
import random 
import copy
import pickle

import torch
from torch import nn
from torch.utils.data import DataLoader


from unimol_tools.models import UniMolModel

from src.dataset import PPIInhibitorInferenceDataset
from src.dataset import extract_esm2, process_unimol_inputs, compute_protein_props, compute_compound_props, process_interface
from src.model import PPIInhibitorModel
from src.utils import predict, performance_evaluation, batch_collate_fn

2024-11-01 16:03:17 | unimol_tools/weights/weighthub.py | 17 | INFO | Uni-Mol Tools | Weights will be downloaded to default directory: /data/dongok/anaconda3/envs/ICAN-PPII/lib/python3.9/site-packages/unimol_tools/weights


In [3]:
# Enumerate GPUs in PCI bus order so device indices match `nvidia-smi`.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Respect a CUDA_VISIBLE_DEVICES value already exported by the user or a job
# scheduler; fall back to GPU "5" (the notebook's original choice) only when
# the variable is unset. Unconditionally forcing "5" hides every GPU on
# machines that do not have that index, silently dropping inference to CPU.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "5")

In [4]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Seed every random number generator used here (Python, NumPy, PyTorch CPU
# and all CUDA devices) with the same value so runs are repeatable.
seed = 2022
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

In [6]:
# Number of samples per batch; passed to the inference DataLoader below.
batch_size = 64

### Load Dataset

To begin your analysis, you will need to prepare two distinct datasets:

1. **Query Dataset** contains the following columns:

    - **`SMILES`**: Represents the SMILES (Simplified Molecular Input Line Entry System) notation of the compound. The column name is case-sensitive.
    - **`ppi_label`**: Indicates the label for protein-protein interaction (PPI).      


2. **Interface Dataset** includes detailed information about protein interfaces. It contains the following columns:

    - **`ppi_label`**: Indicates the label for protein-protein interaction (PPI).
    - **`uniprot_id1`**: UniProt identifier for the first protein.
    - **`uniprot_id2`**: UniProt identifier for the second protein.
    - **`uniprot_sequence1`**: Amino acid sequence of the first protein.
    - **`uniprot_sequence2`**: Amino acid sequence of the second protein.
    - **`min_index1`**: Minimum index of interface residues in the first protein.
    - **`min_index2`**: Minimum index of interface residues in the second protein.
    - **`max_index1`**: Maximum index of interface residues in the first protein.
    - **`max_index2`**: Maximum index of interface residues in the second protein.


    **Note:** If the interface information (`min_index` and `max_index` columns) is unavailable, retain only the essential columns:

    - `ppi_label`
    - `uniprot_id1`
    - `uniprot_id2`
    - `uniprot_sequence1`
    - `uniprot_sequence2`

In [7]:
# CSV of query compounds (SMILES + ppi_label) to run inference on.
query_datapath = "data/toy_example/toy_example.csv"
# CSV describing the protein pair / interface for each ppi_label.
interface_datapath = "data/toy_example/ppi_interface.csv"

In [8]:
# Load the query compounds to score.
df = pd.read_csv(query_datapath)

# Fail early with a clear message when a custom dataset does not follow the
# expected schema; otherwise the problem only surfaces later as an obscure
# AttributeError when the preprocessing cell accesses `df.SMILES`.
required_query_columns = {"SMILES", "ppi_label"}
missing_query_columns = required_query_columns - set(df.columns)
if missing_query_columns:
    raise ValueError(
        f"{query_datapath} is missing required column(s): "
        f"{sorted(missing_query_columns)}"
    )

# Preview the first rows as the cell output.
df.head()

Unnamed: 0,SMILES,ppi_label
0,COc1cc(Cc2cnc(N)nc2N)cc(OC)c1N(C)C,10.0
1,O=C(CS(=O)C(c1ccccc1)c1ccccc1)NO,10.0
2,OC1C=C2CCN3Cc4cc5c(cc4C(C1O)C23)OCO5,10.0
3,OC1C=C2CCN3Cc4cc5c(cc4C(C1O)C23)OCO5,10.0
4,NC(CCc1ccc(N(CCCl)CCCl)cc1)C(=O)O,10.0


In [9]:
# Load the protein-pair / interface table and preview its first rows.
interface_df = pd.read_csv(filepath_or_buffer=interface_datapath)
interface_df.head(n=5)

Unnamed: 0,uniprot_id1,uniprot_id2,ppi_label,uniprot_sequence1,uniprot_sequence2,interface_idx1,interface_idx2,min_index1,max_index1,min_index2,max_index2
0,Q01196,Q13951,10,MRIPVDASTSRRFTPPSTALSPGKMSEALPLGAPDAGAALAGKLRS...,MPRVVPDQRSKFENEEFFRKLSRECEIKYTGFRDRPHEERQARFQN...,"[65, 66, 67, 68, 93, 94, 95, 103, 105, 106, 10...","[1, 2, 3, 4, 10, 16, 27, 28, 29, 32, 33, 53, 5...",65,162,1,130


### Preprocessing

In [10]:
# Preprocess compound features from the SMILES strings of the query set.
smiles = df["SMILES"].to_list()
# Uni-Mol inputs (the captured log shows Uni-Mol Tools generating conformers here).
process_unimol_inputs(smiles)
# Compound property features (project helper from src.dataset).
compute_compound_props(smiles)

2024-11-01 16:03:26 | unimol_tools/data/conformer.py | 89 | INFO | Uni-Mol Tools | Start generating conformers...
10it [00:00, 80.74it/s]
2024-11-01 16:03:26 | unimol_tools/data/conformer.py | 93 | INFO | Uni-Mol Tools | Succeed to generate conformers for 100.00% of molecules.
2024-11-01 16:03:26 | unimol_tools/data/conformer.py | 95 | INFO | Uni-Mol Tools | Succeed to generate 3d conformers for 100.00% of molecules.


In [11]:
# Preprocess protein features for every row of the interface table.
# NOTE(review): the three helpers are project-local (src.dataset) and their
# return values are discarded, so they presumably persist their results as a
# side effect (the captured output mentions protein.fa / protein_phy.csv).
# Keep the call order unless the helpers are confirmed to be independent.
process_interface(interface_df)
# ESM-2 extraction is given the selected device (GPU if available).
extract_esm2(interface_df, device)
compute_protein_props(interface_df)

Summary of Parameters:

Input File: /data/dongok/ICAN-PPII/data/toy_example/protein.fa ; Job: PCP ; Output File: /data/dongok/ICAN-PPII/data/toy_example/protein_phy.csv


### Inference

In [12]:
# Build the inference dataset from the query CSV.
dataset = PPIInhibitorInferenceDataset(query_datapath, device)

# Keep the original row order and every sample so that predictions line up
# one-to-one with the rows of the query CSV.
loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=batch_collate_fn,
)
dataloader = DataLoader(dataset, **loader_options)

In [13]:
# Build the model: a Uni-Mol compound encoder wrapped by the PPI inhibitor
# model, moved to the selected device.
compound_model = UniMolModel()
model = PPIInhibitorModel(compound_model).to(device)

model_path = 'src/weights/weights.model'
# map_location remaps the stored tensors onto the selected device. Without it,
# a checkpoint saved from a GPU fails to load on a CPU-only machine, even
# though this notebook otherwise supports the CPU fallback.
best_state_dict = torch.load(model_path, map_location=device)
# strict=False tolerates key mismatches instead of raising; the returned
# result (shown as the cell output) lists any missing or unexpected keys, so
# check that it reports all keys matched.
model.load_state_dict(best_state_dict, strict=False)

2024-11-01 16:03:45 | unimol_tools/models/unimol.py | 120 | INFO | Uni-Mol Tools | Loading pretrained weights from /data/dongok/anaconda3/envs/ICAN-PPII/lib/python3.9/site-packages/unimol_tools/weights/mol_pre_all_h_220816.pt


<All keys matched successfully>

In [14]:
# Run inference; only the second value returned by `predict` is used here.
_, pred = predict(model, dataloader, device)
# `pred` is a NumPy array; wrap it in a tensor and squash it to (0, 1) scores.
pred_tensor = torch.from_numpy(pred)
pred_score = pred_tensor.sigmoid()

100%|██████████| 1/1 [00:00<00:00,  6.64it/s]


In [15]:
# Display the sigmoid scores computed in the previous cell.
pred_score

tensor([8.6942e-04, 8.8283e-01, 5.8728e-01, 5.8728e-01, 9.4787e-01, 3.1926e-02,
        3.1194e-01, 2.0950e-02, 1.2631e-01, 6.1373e-01])