In [1]:
import os
import json
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob

import src as sp
import src.runtime as rt

In [2]:
# parameters
device = pt.device("cpu")

# locate input structures; sort so the dataset (and therefore the row order of the
# results table) is reproducible — glob() returns files in arbitrary order
pdb_filepaths = sorted(glob("examples/tem1/*.pdb"))
# skip the "<name>_<chain>_bb.pdb" structures written by the prediction loop, so a
# re-run of the notebook does not feed its own outputs back in as inputs
pdb_filepaths = [fp for fp in pdb_filepaths if not fp.endswith("_bb.pdb")]
# fail early (e.g. wrong working directory) instead of silently producing an empty table
assert pdb_filepaths, "no input .pdb files found in examples/tem1"

# model parameters
save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial

# create dataset
dataset = rt.StructuresDataset(pdb_filepaths)

# create runtime model
model = rt.SequenceModel(save_path, "model.pt", device=device)

# create confidence mapping
# (currently disabled; the code that would use `conf` in the prediction loop is commented out too)
#conf = rt.ConfidenceMap("results/{}_cdf.csv".format(os.path.basename(save_path)))

In [3]:
# parameters
r_noise = 0.0  # NOTE(review): not referenced in this cell — confirm whether it is still needed


def output_prefix(pdb_filepath, chain_name):
    """Return the output path prefix for one chain of an input structure.

    The input path's extension is stripped (via ``os.path.splitext`` rather than
    assuming a 4-character suffix) and ``_<chain_name>`` is appended, e.g.
    ``examples/tem1/1JTG.pdb`` with chain ``A`` -> ``examples/tem1/1JTG_A``.
    """
    stem, _ = os.path.splitext(pdb_filepath)
    return "{}_{}".format(stem, chain_name)


def assess_predictions(y, p, seq_ref, seq):
    """Score predictions ``p`` against reference ``y``; returns a dict of plain Python numbers.

    ``.item()`` is called directly on the metric tensors (instead of ``.numpy().item()``)
    so scoring does not require the tensors to live on the CPU.
    """
    return {
        "recovery_rate": rt.recovery_rate(y, p).item(),
        "maximum_recovery_rate": rt.maximum_recovery_rate(y, p).item(),
        "average_multiplicity": rt.average_multiplicity(p).item(),
        "sequence_similarity": rt.sequence_similarity(seq_ref, seq),
    }


# apply model on all subunits alone
results = []
for i in tqdm(range(len(dataset))):
    # load structure
    structure, pdb_filepath = dataset[i]

    # split into one subunit per chain and process each independently
    subunits = sp.split_by_chain(structure)
    for cid, subunit in subunits.items():
        # chain name = part of the subunit id before ':'
        chain_name = cid.split(':')[0]
        prefix = output_prefix(pdb_filepath, chain_name)

        # label every row of xyz with the subunit id
        # NOTE(review): this stores the full cid, not chain_name — confirm that is intended
        subunit['chain_name'] = np.array([cid] * subunit['xyz'].shape[0])

        # apply model; `subunit` is rebound to the processed structure the model returns
        subunit, p, y = model(subunit)
        p, y = rt.aa_only(p, y)

        # compute confidence probability (disabled; needs `conf` from the setup cell)
        #c = pt.from_numpy(conf(p.numpy()))
        #h = c / pt.sum(c, dim=1).reshape(-1,1)
        #kstar = rt.kstar(c)

        # decode sequences from reference (y) and prediction (p)
        seq_ref = rt.max_pred_to_seq(y)
        seq = rt.max_pred_to_seq(p)

        # write predicted sequence as FASTA, tagged with the model name
        rt.write_fasta(prefix + "_p.fasta", seq, info=os.path.basename(save_path))

        # store processed input structure (the setup cell's "_bb.pdb" filter skips these on re-run)
        sp.save_pdb({cid: subunit}, prefix + "_bb.pdb")

        # assess predictions; key order fixes the column order of the results table
        results.append({
            "pdb_filepath": pdb_filepath,
            "chain_name": chain_name,
            **assess_predictions(y, p, seq_ref, seq),
        })

# pack results
df = pd.DataFrame(results)
df

100%|██████████| 1/1 [00:08<00:00,  8.34s/it]


Unnamed: 0,pdb_filepath,chain_name,recovery_rate,maximum_recovery_rate,average_multiplicity,sequence_similarity
0,examples/tem1/1JTG.pdb,A,0.557252,0.881679,3.39313,0.751908
1,examples/tem1/1JTG.pdb,B,0.557576,0.890909,3.448485,0.709091
