In [1]:
import os
import sys
import h5py
import numpy as np
import torch as pt
from glob import glob
from tqdm import tqdm
from multiprocessing import Pool

from src.structure_io import read_pdb, save_pdb
from src.structure import split_by_chain, encode_bfactor
from src.data_encoding import extract_all_contacts, encode_structure, encode_features, extract_topology
from src.dataset import StructuresDataset, collate_batch_features

In [2]:
# model parameters
# save_path = "model/save/i_v3_0_2021-05-27_14-27"  # 89
# save_path = "model/save/i_v3_1_2021-05-28_12-40"  # 90
# save_path = "model/save/i_v4_0_2021-09-07_11-20"  # 89
save_path = "model/save/i_v4_1_2021-09-07_11-21"  # 91

# input structures; files whose name contains "_prot_interf" are excluded
# (the prediction cell saves its annotated structures under that suffix,
# presumably next to the inputs, so re-runs must not feed them back in)
pdb_filepaths = glob("eukaryotic_protein_complexes/*.pdb")
pdb_filepaths = [fp for fp in pdb_filepaths if "_prot_interf" not in fp]

# select saved model
model_filepath = os.path.join(save_path, 'model_ckpt.pt')
#model_filepath = os.path.join(save_path, 'model.pt')

# add module to path: `config` and `model` below are resolved from the
# selected save directory, so these imports have to come after this step
if save_path not in sys.path:
    sys.path.insert(0, save_path)

# load functions
from config import config_model
from model import Model

# define device: use the GPU when one is available, otherwise fall back to
# the CPU so the notebook still runs on machines without CUDA
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# create model
model = Model(config_model)

# reload model (weights are mapped to CPU first, then moved to `device` below)
model.load_state_dict(pt.load(model_filepath, map_location=pt.device("cpu")))

# set model to inference
model = model.eval().to(device)

# create dataset loader with preprocessing
# batch_size=None disables automatic batching: each item is one structure.
# NOTE(review): shuffle=True is not required for inference; kept so the
# processing order behaves as before.
dataset = StructuresDataset(pdb_filepaths, with_preprocessing=True)
dataloader = pt.utils.data.DataLoader(dataset, batch_size=None, shuffle=True, num_workers=8, prefetch_factor=2)

In [3]:
# parameters
r_thr = 5.0  # contact threshold handed to extract_all_contacts (presumably a distance in Angstrom -- confirm)
num_topk = 64  # second argument of extract_topology (presumably neighbours kept per element -- confirm)
chain_ids = ['A:0', 'B:0']  # the two subunits evaluated in every complex
output_filepath = 'datasets/eukaryotic_protein_complexes_predictions.h5'


def contact_labels(resid, contact_ids):
    """Mark which residues of one subunit take part in a contact.

    Parameters
    ----------
    resid : 1-D integer numpy array with the residue id of every entry of
        the subunit.
    contact_ids : 1-D integer numpy array of indices into `resid` for the
        entries that are in contact.

    Returns
    -------
    Float array of length ``max(resid) - min(resid) + 1``; position ``i``
    stands for residue id ``min(resid) + i`` and holds 1.0 when that residue
    has at least one contacting entry, 0.0 otherwise.
    """
    first_resid = np.min(resid)
    labels = np.zeros((np.max(resid) - first_resid + 1,))
    # np.unique removes repeated residue ids, so each position is set once
    labels[np.unique(resid[contact_ids]) - first_resid] = 1.0
    return labels


# run model on all subunits
with h5py.File(output_filepath, 'w') as hf:
    with pt.no_grad():
        for subunits, key in tqdm(dataloader):
            # compute contacts; skip structures for which none are found
            contacts = extract_all_contacts(subunits, r_thr, device=device)
            if len(contacts) == 0:
                continue
            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid
            # should the ids be returned on the GPU
            ctc_ids = contacts['A:0']['B:0']['ids'].cpu().numpy()

            # HDF5 group name: last '-'-separated token of the file name,
            # cut at the first '.'
            hpath = key.split('/')[-1].split('-')[-1].split('.')[0]

            # store comparison: column k of ctc_ids indexes into chain k
            for k, cid in enumerate(chain_ids):
                resid = subunits[cid]['resid'].cpu().numpy()
                hf[hpath+'/y{}'.format(k)] = contact_labels(resid, ctc_ids[:, k])

            # predictions for each subunits
            p_l = []
            for k, cid in enumerate(chain_ids):
                structure = subunits[cid]

                # encode structure and features; only the first element
                # returned by encode_features is fed to the model
                X, M = encode_structure(structure)
                q = encode_features(structure)[0]

                # extract topology (only the top-k indices are used here)
                ids_topk, _, _, _, _ = extract_topology(X.to(device), num_topk)

                # pack data and setup sink (IMPORTANT)
                X, ids_topk, q, M = collate_batch_features([[X, ids_topk, q, M]])

                # run model
                z = model(X.to(device), ids_topk.to(device), q.to(device), M.float().to(device))

                # prediction: sigmoid of the first output column, moved to numpy once
                p = pt.sigmoid(z[:,0]).cpu().numpy()

                # save results
                hf[hpath+'/p{}'.format(k)] = p
                p_l.append(p)

            # save pdb with predicted interfaces; a dedicated name is used so
            # the HDF5 `output_filepath` above is not overwritten
            pdb_output_filepath = key[:-4]+'_prot_interf.pdb'
            save_pdb({cid: encode_bfactor(subunits[cid], p) for cid, p in zip(chain_ids, p_l)}, pdb_output_filepath)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1106/1106 [08:29<00:00,  2.17it/s]
