In [1]:
import os
import sys
import h5py
import json
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob

from src.dataset import StructuresDataset, collate_batch_features, select_by_sid, select_by_interface_types
from src.data_encoding import encode_structure, encode_features, extract_topology, categ_to_resnames, resname_to_categ
from src.structure import data_to_structure, encode_bfactor, concatenate_chains, split_by_chain
from src.structure_io import save_pdb, read_pdb
from src.scoring import bc_scoring, bc_score_names
from processing.build_dataset import config_dataset, pack_dataset_items
from src.data_encoding import extract_all_contacts
from src.config import config_data
from src.data_handler import load_interface_labels
from src.scoring import bc_scoring

In [2]:
# model parameters
# R3
#save_path = "model/save/i_v3_0_2021-05-27_14-27"  # 89
#save_path = "model/save/i_v3_1_2021-05-28_12-40"  # 90
# R4
#save_path = "model/save/i_v4_0_2021-09-07_11-20"  # 89
save_path = "model/save/i_v4_1_2021-09-07_11-21"  # 91

# select saved model
model_filepath = os.path.join(save_path, 'model_ckpt.pt')
#model_filepath = os.path.join(save_path, 'model.pt')

In [3]:
# add module to path
if save_path not in sys.path:
    sys.path.insert(0, save_path)
    
# load functions
from config import config_model, config_data
from data_handler import Dataset
from model import Model

In [7]:
# define device
device = pt.device("cpu")

# create model
model = Model(config_model)

# reload model
model.load_state_dict(pt.load(model_filepath, map_location=pt.device(device)))

# set model to inference
model = model.eval().to(device)

In [None]:
# data parameters
data_path = "examples/53"
# find pdb files and ignore already predicted oins
pdb_filepaths = glob(os.path.join(data_path, "*.pdb"), recursive=True)
pdb_filepaths = [fp for fp in pdb_filepaths if "_i" not in fp]

# create dataset loader with preprocessing
dataset = StructuresDataset(pdb_filepaths, with_preprocessing=True)
testing_transient=np.genfromtxt("/home/omokhtar/Desktop/final_atom/data/benchmarks/MaSIF/testing_transient.txt", dtype=np.dtype('U'))
# debug print
print(len(dataset))

results= {}
mids = np.array(config_dataset['molecule_ids'].astype(np.string_)).astype(np.dtype('U'))
t0 = pt.from_numpy(np.where(np.isin(mids, config_data['l_types']))[0])
t1_l = [pt.from_numpy(np.where(np.isin(mids, r_types))[0]) for r_types in config_data['r_types']]

# run model on all subunits
with pt.no_grad():
    for subunits, filepath in tqdm(dataset):
        contacts = extract_all_contacts(subunits, config_dataset['r_thr'], device=device)
        _, contacts_data = pack_dataset_items(subunits, contacts,config_dataset['molecule_ids'],config_dataset['max_num_nn'], device=device)
       
        # concatenate all chains together
        structure = concatenate_chains(subunits)
        for subunit in subunits:
            id_chain = '_'.join([filepath.split('/')[-1][:-4],subunit.split(':')[0]]) 
            if subunit.count(":") >1: continue
            if id_chain not in testing_transient: continue
                
            contact0 = contacts_data[subunit]
            # encode structure and features
            X, M = encode_structure(subunits[subunit])
            #q = pt.cat(encode_features(structure), dim=1)
            q = encode_features(subunits[subunit])[0]

            # extract topology
            ids_topk, _, _, _, _ = extract_topology(X, 64)

            # pack data and setup sink (IMPORTANT)
            X, ids_topk, q, M = collate_batch_features([[X, ids_topk, q, M]])
            y = pt.zeros((M.shape[1], len(t1_l)), dtype=pt.bool)
            for ckey in contact0:
                y |= load_interface_labels(contact0[ckey], t0, t1_l)
            
            # run model
            z = model(X.to(device), ids_topk.to(device), q.to(device), M.float().to(device))
            
            y = y[:,0].float().view(-1,1)
            p = pt.sigmoid(z[:,0]).float().view(-1,1)
            results[id_chain]=bc_scoring(y.to('cuda'),p.to('cuda'))
''' 
            # for all predictions
            for i in range(z.shape[1]):
                # prediction
                p = pt.sigmoid(z[:,i])

                # encode result
                #structure = encode_bfactor(structure, p.cpu().numpy())

                # save results
                #output_filepath = filepath[:-4]+'_i{}.pdb'.format(i)
                #save_pdb(split_by_chain(structure), output_filepath)
'''

In [8]:
for id_chain in results:
    print (f'roc auc for {id_chain} is {results[id_chain][-2][0]}')

roc auc for 2f4m_A is 0.9820747375488281
roc auc for 1wdw_H is 0.8655789494514465
roc auc for 2ayo_A is 0.9088626503944397
roc auc for 1jtd_B is 0.988070547580719
roc auc for 4lvn_A is 0.8595605492591858
roc auc for 4zrj_A is 0.729529619216919
roc auc for 3vv2_A is 0.8555220365524292
roc auc for 1f6m_A is 0.932812511920929
roc auc for 4hdo_A is 0.8906896710395813
roc auc for 3sja_I is 0.7162952423095703
roc auc for 1w1w_B is 0.8545309901237488
roc auc for 2i3t_A is 0.81409752368927
roc auc for 4fzv_A is 0.8393750190734863
roc auc for 1xg2_A is 0.5910776257514954
roc auc for 3qml_D is 0.8383738994598389
roc auc for 4yc7_B is 0.9247331619262695
roc auc for 4x33_B is 0.972842812538147
roc auc for 3bh6_B is 0.9766187071800232
roc auc for 3zwl_B is 0.7412217259407043
roc auc for 1xqs_A is 0.9342710971832275
roc auc for 3wn7_A is 0.9805610179901123
roc auc for 1ewy_A is 0.7579848170280457
roc auc for 2v9t_B is 0.9479190707206726
roc auc for 3h6g_B is 0.9959572553634644
roc auc for 4dvg_B is 

In [10]:
auc = pt.stack([pt.tensor(v) for v in results.values()])[:, -2]
pt.median(auc)

53


  auc = pt.stack([pt.tensor(v) for v in results.values()])[:, -2]


tensor(0.8555, device='cuda:0')