In [1]:
import os
import sys
import yaml
import pickle
import random
import logging
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Subset, DataLoader

sys.path.append("../../code/modules/")

from common.utils import load_cfg
from pocket_modules.loaders import PocketTestDataset
from pocket_modules.trainers import Pseq2SitesTrainer

### 1. Define settings

In [2]:
conf_path = "../../code/pocket_extractor_config.yml"
config = load_cfg(conf_path)

# Seed every random number generator this notebook imports (Python's `random`,
# NumPy and torch) from the single seed stored in the config, so that a
# Restart & Run All reproduces the same results.
seed = config['Train']['seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU index named in the config when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(config['Train']['device']))
    torch.cuda.manual_seed_all(seed)
else:
    device = torch.device("cpu")

### 2. Load data

In [3]:
def load_data(data_path, feature_path, seqs_index):
    """Read a dataset table together with its pickled protein features.

    The file at ``data_path`` is parsed as tab-separated text. Column 1
    (positional) supplies the protein IDs and column ``seqs_index``
    (positional) supplies the sequences. Repeated IDs are collapsed: an ID
    keeps the position of its first occurrence and the sequence of its last
    occurrence.

    Args:
        data_path: Location of the tab-separated table.
        feature_path: Location of the pickle file holding the features object.
        seqs_index: Positional index of the sequence column.

    Returns:
        A tuple ``(pid_list, features, seqs_list, seqs_lengths)``.
        ``pid_list`` and ``seqs_list`` are parallel lists over the unique IDs,
        ``features`` is the unpickled object returned untouched, and
        ``seqs_lengths`` is a NumPy array holding ``len()`` of each sequence.
    """
    table = pd.read_csv(str(data_path), sep="\t")

    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    with open(str(feature_path), "rb") as handle:
        features = pickle.load(handle)

    # Building a dict deduplicates the IDs: insertion order follows the first
    # occurrence of each ID, while the stored sequence is the last one seen.
    pid_to_seq = dict(zip(table.iloc[:, 1].values, table.iloc[:, seqs_index].values))

    pid_list = list(pid_to_seq)
    seqs_list = list(pid_to_seq.values())
    seqs_lengths = np.array([len(seq) for seq in seqs_list])

    return pid_list, features, seqs_list, seqs_lengths

In [4]:
# Directory that holds every dataset table and pickled protein-feature file used below.
BA_DATA_DIR = "../../input_data/PDB/BA"


def _load_benchmark(name, table_suffix="BA", seqs_index=4):
    """Load one dataset from ``BA_DATA_DIR`` through ``load_data``.

    Args:
        name: Dataset prefix used in the file names, e.g. ``"CASF2016"``.
        table_suffix: Middle part of the table file name
            (``<name>_<table_suffix>_data.tsv``).
        seqs_index: Positional index of the sequence column in that table.

    Returns:
        The ``(ids, features, seqs, seq_lengths)`` tuple produced by ``load_data``.
    """
    return load_data(
        f"{BA_DATA_DIR}/{name}_{table_suffix}_data.tsv",
        f"{BA_DATA_DIR}/{name}_protein_features.pkl",
        seqs_index,
    )


# Datasets stored as `*_BA_data.tsv`: the sequence column is read from index 4.
Training_ID, Training_features, Training_seqs, Training_seq_lengths = _load_benchmark("Training")
CASF2016_ID, CASF2016_features, CASF2016_seqs, CASF2016_seq_lengths = _load_benchmark("CASF2016")
CASF2013_ID, CASF2013_features, CASF2013_seqs, CASF2013_seq_lengths = _load_benchmark("CASF2013")
CSAR2014_ID, CSAR2014_features, CSAR2014_seqs, CSAR2014_seq_lengths = _load_benchmark("CSAR2014")
CSAR2012_ID, CSAR2012_features, CSAR2012_seqs, CSAR2012_seq_lengths = _load_benchmark("CSAR2012")
CSARset1_ID, CSARset1_features, CSARset1_seqs, CSARset1_seq_lengths = _load_benchmark("CSARset1")
CSARset2_ID, CSARset2_features, CSARset2_seqs, CSARset2_seq_lengths = _load_benchmark("CSARset2")
Astex_ID, Astex_features, Astex_seqs, Astex_seq_lengths = _load_benchmark("Astex")

# Datasets stored as `*_IS_data.tsv`: the sequence column is read from index 3.
COACH420_ID, COACH420_features, COACH420_seqs, COACH420_seq_lengths = _load_benchmark(
    "COACH420", table_suffix="IS", seqs_index=3
)
HOLO4K_ID, HOLO4K_features, HOLO4K_seqs, HOLO4K_seq_lengths = _load_benchmark(
    "HOLO4K", table_suffix="IS", seqs_index=3
)

### 3. Load Pocket extractor

In [5]:
# Build the trainer from the loaded config and the device chosen above.
trainer = Pseq2SitesTrainer(config, device)

# map_location remaps the tensors stored in the checkpoint onto `device`, so a
# checkpoint written on a GPU still loads when this notebook falls back to the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint_state = torch.load(config['Path']['check_point'], map_location=device)
trainer.model.load_state_dict(checkpoint_state)

<All keys matched successfully>

### 4. Get pocket prediction results

In [6]:
def remove_over_lengths(pred_list, length):
    """Keep only the predicted positions that fall inside the sequence.

    Args:
        pred_list: Sequence or array of predicted positions (indices).
        length: Sequence length; every position ``>= length`` is dropped.

    Returns:
        A 1-D NumPy array with the kept positions in their original order.
        The result keeps the dtype of the input positions, and an empty input
        yields an empty integer array, so the result can always be used as an
        index array.
    """
    positions = np.asarray(pred_list)

    if positions.size == 0:
        # np.array([]) defaults to float64, which NumPy refuses to accept as an
        # index array; return an explicitly integer-typed empty array instead.
        return np.empty(0, dtype=np.intp)

    # A boolean mask filters in one vectorised step and preserves the dtype,
    # even when nothing survives the filter.
    return positions[positions < length]

def extract_results(pdbid, predictions, lengths_dict, threshold=0.4):
    """Turn per-entry prediction scores into selected positions per ID.

    For every ``(id, scores)`` pair, the positions whose score is
    ``>= threshold`` are collected and then passed through
    ``remove_over_lengths`` with that entry's sequence length.

    Args:
        pdbid: Iterable of IDs, paired with ``predictions`` by position.
        predictions: Iterable of score arrays, one per ID.
            Assumes each entry is a 1-D score vector -- TODO confirm against
            the trainer's output.
        lengths_dict: Sequence lengths indexed by position (0, 1, 2, ...),
            i.e. in the same order as ``pdbid``. Despite the name it is
            indexed by integer position, not by ID.
        threshold: Minimum score for a position to be selected. Defaults to
            0.4, the value that was previously hard-coded.

    Returns:
        dict mapping each ID to the array returned by ``remove_over_lengths``.
        If an ID occurs more than once, the last occurrence wins.
    """
    results = dict()

    for position, (pdb, scores) in enumerate(zip(pdbid, predictions)):
        # np.where(...)[0] yields the indices along the first axis that pass the cut.
        selected = np.where(scores >= threshold)[0]
        results[pdb] = remove_over_lengths(selected, lengths_dict[position])

    return results

In [7]:
# PDBbind
# shuffle=False keeps the predictions in Training_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
Training_Dataset = PocketTestDataset(
    PID=Training_ID,
    Pseqs=Training_seqs,
    Pfeatures=Training_features,
)
Training_Loader = DataLoader(
    Training_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
Training_predictions = extract_results(
    Training_ID,
    trainer.test(Training_Loader),
    Training_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/Training_pockets.pkl", "wb") as pockets_file:
    pickle.dump(Training_predictions, pockets_file)

100%|████████████████████████████████████████████████████████████████████| 76/76 [00:14<00:00,  5.32it/s]


In [8]:
# CASF2016
# shuffle=False keeps the predictions in CASF2016_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CASF2016_Dataset = PocketTestDataset(
    PID=CASF2016_ID,
    Pseqs=CASF2016_seqs,
    Pfeatures=CASF2016_features,
)
CASF2016_Loader = DataLoader(
    CASF2016_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CASF2016_predictions = extract_results(
    CASF2016_ID,
    trainer.test(CASF2016_Loader),
    CASF2016_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CASF2016_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CASF2016_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.75it/s]


In [9]:
# CASF2013
# shuffle=False keeps the predictions in CASF2013_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CASF2013_Dataset = PocketTestDataset(
    PID=CASF2013_ID,
    Pseqs=CASF2013_seqs,
    Pfeatures=CASF2013_features,
)
CASF2013_Loader = DataLoader(
    CASF2013_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CASF2013_predictions = extract_results(
    CASF2013_ID,
    trainer.test(CASF2013_Loader),
    CASF2013_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CASF2013_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CASF2013_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.84it/s]


In [10]:
# CSAR2014
# shuffle=False keeps the predictions in CSAR2014_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CSAR2014_Dataset = PocketTestDataset(
    PID=CSAR2014_ID,
    Pseqs=CSAR2014_seqs,
    Pfeatures=CSAR2014_features,
)
CSAR2014_Loader = DataLoader(
    CSAR2014_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CSAR2014_predictions = extract_results(
    CSAR2014_ID,
    trainer.test(CSAR2014_Loader),
    CSAR2014_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CSAR2014_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CSAR2014_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 57.07it/s]


In [11]:
# CSAR2012
# shuffle=False keeps the predictions in CSAR2012_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CSAR2012_Dataset = PocketTestDataset(
    PID=CSAR2012_ID,
    Pseqs=CSAR2012_seqs,
    Pfeatures=CSAR2012_features,
)
CSAR2012_Loader = DataLoader(
    CSAR2012_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CSAR2012_predictions = extract_results(
    CSAR2012_ID,
    trainer.test(CSAR2012_Loader),
    CSAR2012_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CSAR2012_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CSAR2012_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 25.77it/s]


In [12]:
# CSARset1
# shuffle=False keeps the predictions in CSARset1_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CSARset1_Dataset = PocketTestDataset(
    PID=CSARset1_ID,
    Pseqs=CSARset1_seqs,
    Pfeatures=CSARset1_features,
)
CSARset1_Loader = DataLoader(
    CSARset1_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CSARset1_predictions = extract_results(
    CSARset1_ID,
    trainer.test(CSARset1_Loader),
    CSARset1_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CSARset1_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CSARset1_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.90it/s]


In [13]:
# CSARset2
# shuffle=False keeps the predictions in CSARset2_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
CSARset2_Dataset = PocketTestDataset(
    PID=CSARset2_ID,
    Pseqs=CSARset2_seqs,
    Pfeatures=CSARset2_features,
)
CSARset2_Loader = DataLoader(
    CSARset2_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
CSARset2_predictions = extract_results(
    CSARset2_ID,
    trainer.test(CSARset2_Loader),
    CSARset2_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/CSARset2_pockets.pkl", "wb") as pockets_file:
    pickle.dump(CSARset2_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  6.54it/s]


In [14]:
# Astex
# shuffle=False keeps the predictions in Astex_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
Astex_Dataset = PocketTestDataset(
    PID=Astex_ID,
    Pseqs=Astex_seqs,
    Pfeatures=Astex_features,
)
Astex_Loader = DataLoader(
    Astex_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
Astex_predictions = extract_results(
    Astex_ID,
    trainer.test(Astex_Loader),
    Astex_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/Astex_pockets.pkl", "wb") as pockets_file:
    pickle.dump(Astex_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  7.67it/s]


In [15]:
# COACH420
# shuffle=False keeps the predictions in COACH420_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
COACH420_Dataset = PocketTestDataset(
    PID=COACH420_ID,
    Pseqs=COACH420_seqs,
    Pfeatures=COACH420_features,
)
COACH420_Loader = DataLoader(
    COACH420_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
COACH420_predictions = extract_results(
    COACH420_ID,
    trainer.test(COACH420_Loader),
    COACH420_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/COACH420_pockets.pkl", "wb") as pockets_file:
    pickle.dump(COACH420_predictions, pockets_file)

100%|██████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  6.34it/s]


In [16]:
# HOLO4K
# shuffle=False keeps the predictions in HOLO4K_ID order, which extract_results
# relies on when it pairs IDs, predictions and lengths by position.
HOLO4K_Dataset = PocketTestDataset(
    PID=HOLO4K_ID,
    Pseqs=HOLO4K_seqs,
    Pfeatures=HOLO4K_features,
)
HOLO4K_Loader = DataLoader(
    HOLO4K_Dataset,
    batch_size=config['Train']['batch_size'],
    shuffle=False,
)
HOLO4K_predictions = extract_results(
    HOLO4K_ID,
    trainer.test(HOLO4K_Loader),
    HOLO4K_seq_lengths,
)

# Persist the {ID: selected positions} mapping next to the input data.
with open("../../input_data/PDB/BA/HOLO4K_pockets.pkl", "wb") as pockets_file:
    pickle.dump(HOLO4K_predictions, pockets_file)

100%|████████████████████████████████████████████████████████████████████| 34/34 [00:05<00:00,  5.90it/s]
