# Inference Notebook
- This notebook shows how to predict protein–protein interaction (PPI) inhibitors for your own custom dataset.
- It includes preprocessing, data loading, and model inference.

In [1]:
import os
import tqdm
import numpy as np
import pandas as pd
import random 
import copy
import pickle

import torch
from torch import nn
from torch.utils.data import DataLoader

from unimol_tools import utils
from unimol_tools.data import DataHub
from unimol_tools.models import UniMolModel

from src.dataset import PPIInhibitorInferenceDataset, process_interface, calculate_physicochemical_properties
from src.model import PPIInhibitorModel
from src.utils import predict, performance_evaluation, batch_collate_fn

2024-10-29 17:11:36 | unimol_tools/weights/weighthub.py | 17 | INFO | Uni-Mol Tools | Weights will be downloaded to default directory: /data/dongok/anaconda3/envs/unimol/lib/python3.9/site-packages/unimol_tools/weights


In [2]:
# Number GPUs by PCI bus ID so that indices match the ones shown by `nvidia-smi`.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Default to GPU 1, but respect a CUDA_VISIBLE_DEVICES value already set by the user or a
# job scheduler. Forcing "1" unconditionally hides the only GPU on single-GPU machines, and
# the device cell below then silently falls back to CPU.
# CUDA reads these variables when it initializes, so this cell must run before any CUDA call.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

In [3]:
# Use the (first visible) GPU when CUDA is available, otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
seed = 2022
# Seed every RNG source used here (Python, NumPy, PyTorch CPU and all CUDA devices),
# in the same order as before, so that repeated runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

In [5]:
batch_size: int = 64  # number of samples per inference batch passed to the DataLoader

In [6]:
# Load your dataset: a CSV with (at least) a `SMILES` column, used by the preprocessing below.
datapath = 'data/toy_example/toy_example.csv'
df = pd.read_csv(filepath_or_buffer=datapath)
df.head(5)

Unnamed: 0,SMILES,ppi_label,uniprot_id1,uniprot_id2
0,COc1cc(Cc2cnc(N)nc2N)cc(OC)c1N(C)C,10.0,Q01196,Q13951
1,O=C(CS(=O)C(c1ccccc1)c1ccccc1)NO,10.0,Q01196,Q13951
2,OC1C=C2CCN3Cc4cc5c(cc4C(C1O)C23)OCO5,10.0,Q01196,Q13951
3,OC1C=C2CCN3Cc4cc5c(cc4C(C1O)C23)OCO5,10.0,Q01196,Q13951
4,NC(CCc1ccc(N(CCCl)CCCl)cc1)C(=O)O,10.0,Q01196,Q13951


In [7]:
# Preprocess compounds once for efficient inference with Uni-Mol: DataHub generates
# conformers/tokens, and each Uni-Mol input field is cached as a {SMILES: value} pickle.
smiles_list = df.SMILES.tolist()

datahub = DataHub(data=smiles_list, task='repr')
unimol_input = datahub.data['unimol_input']
# The dicts below pair inputs with SMILES by position, so both lists must line up.
assert len(unimol_input) == len(smiles_list), (
    f'DataHub returned {len(unimol_input)} inputs for {len(smiles_list)} SMILES'
)

# NOTE(review): presumably this must be the directory PPIInhibitorInferenceDataset
# reads the pickles from — confirm before pointing it elsewhere.
data_dir = 'data/toy_example'
unimol_fields = ('src_tokens', 'src_distance', 'src_coord', 'src_edge_type')
# Duplicate SMILES keep the entry of their last occurrence (plain dict semantics).
unimol_feature_dicts = {
    field: {smiles: features[field] for smiles, features in zip(smiles_list, unimol_input)}
    for field in unimol_fields
}
for field, field_dict in unimol_feature_dicts.items():
    # Writes e.g. data/toy_example/src_tokens_dict.pickle
    with open(f'{data_dir}/{field}_dict.pickle', 'wb') as f:
        pickle.dump(field_dict, f)

2024-10-29 17:11:38 | unimol_tools/data/conformer.py | 89 | INFO | Uni-Mol Tools | Start generating conformers...
10it [00:00, 75.95it/s]
2024-10-29 17:11:38 | unimol_tools/data/conformer.py | 93 | INFO | Uni-Mol Tools | Failed to generate conformers for 0.00% of molecules.
2024-10-29 17:11:38 | unimol_tools/data/conformer.py | 95 | INFO | Uni-Mol Tools | Failed to generate 3d conformers for 0.00% of molecules.


In [8]:
# Compute physicochemical properties for every SMILES and cache them as a pickle.
compound_phy = calculate_physicochemical_properties(smiles_list)
phy_path = 'data/toy_example/compound_phy.pickle'
with open(phy_path, 'wb') as phy_file:
    pickle.dump(compound_phy, phy_file)

In [9]:
# Preprocess the PPI interface information and save it as processed_interface.csv.
processed_interface = process_interface('data/toy_example/ppi_interface.csv')
processed_interface.to_csv('data/toy_example/processed_interface.csv')

In [10]:
# Peek at the raw (unprocessed) interface table.
interface_raw = pd.read_csv('data/toy_example/ppi_interface.csv')
interface_raw.head(5)

Unnamed: 0,uniprot_id1,uniprot_id2,ppi_label,uniprot_sequence1,uniprot_sequence2,interface_idx1,interface_idx2,min_index1,max_index1,min_index2,max_index2
0,Q01196,Q13951,10,MRIPVDASTSRRFTPPSTALSPGKMSEALPLGAPDAGAALAGKLRS...,MPRVVPDQRSKFENEEFFRKLSRECEIKYTGFRDRPHEERQARFQN...,"[65, 66, 67, 68, 93, 94, 95, 103, 105, 106, 10...","[1, 2, 3, 4, 10, 16, 27, 28, 29, 32, 33, 53, 5...",65,162,1,130


In [11]:
# Build the inference dataset and batch it in dataset order: no shuffling, and the
# final partial batch is kept so every sample gets a prediction.
dataset = PPIInhibitorInferenceDataset(datapath, device)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=batch_collate_fn,
)

In [12]:
# Build the model (Uni-Mol compound encoder wrapped by PPIInhibitorModel) and load trained weights.
compound_model = UniMolModel()
model = PPIInhibitorModel(compound_model).to(device)

model_path = 'src/weights/weights.model'
# map_location lets weights saved on a GPU load on a CPU-only machine (the device cell
# falls back to CPU when CUDA is unavailable).
# NOTE: torch.load unpickles the file — only load weight files from a trusted source.
best_state_dict = torch.load(model_path, map_location=device)
# strict=False tolerates missing/unexpected keys; the returned report is displayed below
# so any mismatch is visible instead of silently ignored.
load_report = model.load_state_dict(best_state_dict, strict=False)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
load_report

2024-10-29 17:11:39 | unimol_tools/models/unimol.py | 120 | INFO | Uni-Mol Tools | Loading pretrained weights from /data/dongok/anaconda3/envs/unimol/lib/python3.9/site-packages/unimol_tools/weights/mol_pre_all_h_220816.pt


<All keys matched successfully>

In [13]:
# `predict` returns a pair; only the second element (a NumPy array of raw model outputs,
# presumably logits) is used. A sigmoid maps those outputs to scores in (0, 1).
_, pred = predict(model, dataloader, device)
pred_score = torch.from_numpy(pred).sigmoid()

100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


In [14]:
# Sigmoid scores, one per sample yielded by `dataloader` (shuffle=False, so in dataset order).
pred_score

tensor([0.2695, 0.8828, 0.5868, 0.5868, 0.9703, 0.0319, 0.4410, 0.0217, 0.1442,
        0.6137])