In [1]:
import os
import sys
import yaml
import pickle
import random
import logging
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Subset, DataLoader

sys.path.append("../../code/modules/")

from common.utils import load_cfg
from pocket_modules.loaders import PocketTestDataset
from pocket_modules.trainers import Pseq2SitesTrainer

### 1. Define settings

In [2]:
# Load the pocket-extractor configuration used by every later cell.
conf_path = "../../code/pocket_extractor_config.yml"
config = load_cfg(conf_path)

# Seed every RNG this notebook has in scope so repeated runs are comparable.
# Python's `random` module is imported above, so it is seeded alongside
# NumPy and PyTorch.
random.seed(config['Train']['seed'])
np.random.seed(config['Train']['seed'])
torch.manual_seed(config['Train']['seed'])

# Use the configured CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(config['Train']['device']))
    torch.cuda.manual_seed_all(config['Train']['seed'])
else:
    device = torch.device("cpu")

### 2. Load data

In [3]:
def load_data(data_path, feature_path):
    """Read an ID/sequence table and its pickled feature object.

    Args:
        data_path: Path to a tab-separated file. The first column (position 0)
            is used as the identifier and the fourth column (position 3) as
            the sequence.
        feature_path: Path to a pickle file; its content is returned unchanged.

    Returns:
        A tuple ``(pid_list, features, seqs_list, seqs_lengths)`` where
        ``pid_list`` holds each identifier once, in order of first appearance,
        ``seqs_list`` holds the matching sequences (when an identifier repeats,
        the sequence from its last row is kept), and ``seqs_lengths`` is a
        NumPy array with ``len()`` of every sequence.
    """
    table = pd.read_csv(f"{data_path}", sep = "\t")

    # NOTE: pickle.load executes arbitrary code from the file; only use
    # feature files from a trusted source.
    with open(f"{feature_path}", "rb") as handle:
        features = pickle.load(handle)

    # A dict keeps the first-seen key order while later rows overwrite the
    # stored sequence, which de-duplicates the identifiers.
    seqs_dict = dict(zip(table.iloc[:, 0].values, table.iloc[:, 3].values))

    pid_list = list(seqs_dict)
    seqs_list = list(seqs_dict.values())
    seqs_lengths = np.array([len(seq) for seq in seqs_list])

    return pid_list, features, seqs_list, seqs_lengths

In [4]:
# Load the IC50 table together with its pickled protein features.
IC50_ID, IC50_features, IC50_seqs, IC50_lengths = load_data(
    "../../input_data/BindingDB/IC50_data.tsv",
    "../../input_data/BindingDB/IC50_protein_features.pkl",
)
# Report how many unique identifiers were loaded.
print(f"IC50: {len(IC50_ID)}")

IC50: 4347


In [5]:
# Load the Ki table together with its pickled protein features.
Ki_ID, Ki_features, Ki_seqs, Ki_lengths = load_data(
    "../../input_data/BindingDB/Ki_data.tsv",
    "../../input_data/BindingDB/Ki_protein_features.pkl",
)
# Report how many unique identifiers were loaded.
print(f"Ki: {len(Ki_ID)}")

Ki: 2431


### 3. Load pocket extractor

In [6]:
# Build the trainer and restore the saved weights into its model.
trainer = Pseq2SitesTrainer(config, device)
# map_location remaps the checkpoint's tensors onto the device selected above,
# so a checkpoint containing CUDA tensors still loads when the notebook has
# fallen back to the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
trainer.model.load_state_dict(torch.load(config['Path']['check_point'], map_location=device))

<All keys matched successfully>

### 4. Get pocket prediction results

In [7]:
def remove_over_lengths(pred_list, length):
    """Keep only the values of ``pred_list`` that are strictly below ``length``.

    Args:
        pred_list: Iterable of values (the caller in this notebook passes the
            index array produced by ``np.where``).
        length: Exclusive upper bound; values ``>= length`` are dropped.

    Returns:
        A NumPy array with the retained values in their original order. An
        empty result is returned with an integer dtype so that it remains
        usable as an index array (``np.array([])`` would default to float64).
    """
    kept = [val for val in pred_list if val < length]

    if not kept:
        # Avoid NumPy's float64 default for empty lists.
        return np.array([], dtype=np.intp)

    return np.array(kept)

def extract_results(pdbid, predictions, lengths_dict, threshold=0.4):
    """Turn per-entry prediction scores into arrays of selected positions.

    Args:
        pdbid: Sequence of identifiers, aligned with ``predictions``.
        predictions: Sequence of score vectors, one per identifier.
        lengths_dict: Lengths looked up by *position* (``lengths_dict[idx]``
            with ``idx`` the running index of the entry), so despite its name
            it must be indexable by 0..n-1 in the same order as ``pdbid``.
        threshold: Minimum score (inclusive) for a position to be selected.
            Defaults to 0.4.

    Returns:
        A dict mapping each identifier to a NumPy array of positions whose
        score is ``>= threshold`` and whose index is below that entry's length.
    """
    results = dict()

    for idx, (pdb, pre) in enumerate(zip(pdbid, predictions)):
        # Positions whose score reaches the cutoff.
        ind = np.where(pre >= threshold)[0]
        # Drop positions at or beyond this entry's length.
        results[pdb] = remove_over_lengths(ind, lengths_dict[idx])

    return results

In [9]:
# IC50: run the pocket extractor over every loaded entry.
IC50_Dataset = PocketTestDataset(
    PID=IC50_ID,
    Pseqs=IC50_seqs,
    Pfeatures=IC50_features,
)
IC50_Loader = DataLoader(
    IC50_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,  # keep loader order aligned with IC50_ID / IC50_lengths
)
IC50_predictions = trainer.test(IC50_Loader)

# Replace the raw predictions with the per-identifier index arrays and save them.
IC50_predictions = extract_results(IC50_ID, IC50_predictions, IC50_lengths)
with open("../../input_data/BindingDB/IC50_pockets.pkl", "wb") as f:
    pickle.dump(IC50_predictions, f)

100%|██████████████████████████████████████████████████████████████████| 136/136 [00:23<00:00,  5.75it/s]


In [11]:
# Ki: run the pocket extractor over every loaded entry.
Ki_Dataset = PocketTestDataset(
    PID=Ki_ID,
    Pseqs=Ki_seqs,
    Pfeatures=Ki_features,
)
Ki_Loader = DataLoader(
    Ki_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,  # keep loader order aligned with Ki_ID / Ki_lengths
)
Ki_predictions = trainer.test(Ki_Loader)

# Replace the raw predictions with the per-identifier index arrays and save them.
Ki_predictions = extract_results(Ki_ID, Ki_predictions, Ki_lengths)
with open("../../input_data/BindingDB/Ki_pockets.pkl", "wb") as f:
    pickle.dump(Ki_predictions, f)

100%|████████████████████████████████████████████████████████████████████| 76/76 [00:13<00:00,  5.81it/s]
