In [None]:
import os
import random
from glob import glob

import numpy as np
import torch as pt
from tqdm import tqdm

from data.utils.feature_extraction import extract_dynamic_features, encode_sequence, mean_coordinates, extract_topology
from data.utils.PDB_processing import split_nmr_pdb, make_pdb, read_pdb, get_sasa_unbound, fill_nan_with_neighbors
from data.utils.protein_chemistry import list_aa, std_elements

from model.utils.data_handler import collate_batch_features
from model.utils.model import Model
from model.utils.configs import config_model, config_data, config_runtime
from model.utils.for_visualization import p_to_bfactor

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = pt.device("cuda") if pt.cuda.is_available() else pt.device("cpu")
print("Using device: {}".format(device))

In [None]:
# Input structures (single-chain PDB files expected) and the trained weights to apply.
# glob() yields files in arbitrary, filesystem-dependent order; sorting makes the
# processing order (and the insertion order of `results`) reproducible across runs.
pdb_files = sorted(glob('input_structures/*.pdb'))
model_path = 'model_307.pt'

In [None]:
# Featurize every input structure and run the trained model on it.
# Output: results[pdb_path] = [sigmoid(model logits) as a CPU tensor, sequence string].

# Upper bound on the number of nearest neighbours per node (capped at the node count).
MAX_NEIGHBORS = 64

# Load the trained weights once: the model does not depend on the input structure,
# so rebuilding it and re-reading the checkpoint for every file was wasted work.
model = Model(config_model)
model.load_state_dict(pt.load(model_path, map_location=device, weights_only=True))
model = model.eval().to(device)

results = {}

for each_pdb in tqdm(pdb_files):
    # File name without directory; basename respects the OS path separator.
    chain_key = os.path.basename(each_pdb).split('.')[0]
    # tqdm.write keeps the progress bar intact, unlike a bare print().
    tqdm.write(chain_key)

    # Parse the structure; only the first entry returned by split_nmr_pdb is used.
    pdb_chains = split_nmr_pdb(each_pdb)
    models = list(pdb_chains.values())[0]  # Single chain PDB structures
    pdb_file = make_pdb(models)
    aa_map, seq, atom_type, atoms_xyz = read_pdb(pdb_file)
    seq = ''.join(seq)

    # Geometric features, computed on the output of mean_coordinates
    # (presumably coordinates averaged over models — TODO confirm).
    mean_xyz = mean_coordinates(atoms_xyz)
    R, D = extract_topology(mean_xyz)
    # Indices of nearest neighbours: the k smallest entries of each row of D.
    knn = min(MAX_NEIGHBORS, D.shape[0])
    D_nn, nn_topk = pt.topk(pt.tensor(D), knn, dim=1, largest=False)
    R_nn = pt.gather(pt.tensor(R), 1, nn_topk.unsqueeze(2).repeat(1, 1, R.shape[2])).to(pt.float32)
    motion_v_nn, motion_s_nn, rmsf, de, CP_nn = extract_dynamic_features(atoms_xyz, nn_topk.numpy())

    # Relative solvent accessibility of the unbound chain.
    sasa_dic_unbound, labeled_seqs2 = get_sasa_unbound(each_pdb)
    assert list(labeled_seqs2.values())[0] == seq, (
        f"{chain_key}: sequence from get_sasa_unbound does not match the parsed PDB sequence"
    )
    rsa = fill_nan_with_neighbors(np.array(list(sasa_dic_unbound.values())[0]))
    # NOTE(review): aa_map looks like 1-based residue indices, one per row of the
    # feature tensors, so this expands per-residue RSA to that granularity — confirm.
    rsa = np.array(rsa)[np.array(aa_map) - 1]
    onehot_seq = pt.tensor(encode_sequence(atom_type, std_elements))

    # Convert to tensors with the dtypes/trailing singleton dims the model was fed before.
    rmsf = pt.tensor(rmsf).unsqueeze(1)
    de = pt.tensor(de).unsqueeze(1)
    rsa = pt.tensor(rsa).to(pt.float64).unsqueeze(1)
    D_nn = D_nn.to(pt.float32).unsqueeze(2)
    nn_topk = nn_topk.to(pt.int64)
    R_nn = R_nn.to(pt.float32)
    motion_v_nn = pt.tensor(motion_v_nn).to(pt.float32)
    motion_s_nn = pt.tensor(motion_s_nn).to(pt.float32).unsqueeze(2)
    CP_nn = pt.tensor(CP_nn).to(pt.float32).unsqueeze(2)

    features = collate_batch_features([[onehot_seq, rmsf, de, rsa, nn_topk, D_nn, R_nn, motion_v_nn, motion_s_nn, CP_nn, pt.tensor(aa_map)]])
    onehot_seq, rmsf, de, rsa, nn_topk, D_nn, R_nn, motion_v_nn, motion_s_nn, CP_nn, aa_map = features

    # Apply the model; the argument order is exactly the model's positional order.
    model_inputs = (onehot_seq, rmsf, de, rsa, nn_topk, D_nn, R_nn, motion_v_nn, motion_s_nn, CP_nn, aa_map)
    with pt.no_grad():
        z, _, _ = model(*(t.to(device) for t in model_inputs))
        # Move predictions to the CPU so they do not pile up in GPU memory across files
        # (no .detach() needed: nothing computed under no_grad tracks gradients).
        results[each_pdb] = [pt.sigmoid(z).cpu(), seq]