In [1]:
import os
import json
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob

import src as sp
import runtime as rt

In [49]:
# parameters
pdb_filepath = "examples/tem1/1JTG.pdb"
#pdb_filepath = "examples/the_triangle/2TRG.pdb"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
# NOTE(review): assumes rt.SequenceModel accepts a CPU device - confirm.
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# model parameters
# r6
save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial
#save_path = "model/save/s_v6_5_2022-09-16_11-52"  # virtual Cb, partial & noise

# r7
#save_path = "model/save/s_v7_0_2023-04-25"  # partial chain
#save_path = "model/save/s_v7_1_2023-04-25"  # partial secondary structure
#save_path = "model/save/s_v7_2_2023-04-25"  # partial chain high coverage

# create runtime model from the checkpoint "model.pt" under save_path
model = rt.SequenceModel(save_path, "model.pt", device=device)

# create confidence mapping (disabled; needs results/<save_path basename>_cdf.csv)
#conf = rt.ConfidenceMap("results/{}_cdf.csv".format(os.path.basename(save_path)))

In [86]:
# parameters
known_chains = [""]  # chain identifiers treated as known; [""] is the placeholder used here
r_noise = 0.0        # scale factor applied to the random noise added to the coordinates

# load structure
structure = rt.load_structure(pdb_filepath)

# known chains: compare the part of each chain name before ':' with known_chains
chain_ids = [name.split(':')[0] for name in structure['chain_name']]
m_known = np.isin(chain_ids, known_chains)

# apply noise: standard-normal sample with the shape of the coordinates, scaled by r_noise
r = pt.randn(structure['xyz'].shape).numpy()
structure['xyz'] = structure['xyz'] + r_noise * r

# apply model, then reduce prediction and reference with rt.aa_only
structure, p, y = model(structure, m_known=m_known)
p, y = rt.aa_only(p, y)

# find residues per chain and known residues
mr_chains, chain_names = rt.chain_masks(structure)
known_columns = pt.from_numpy(np.isin(chain_names, known_chains))
mr_known = mr_chains[:, known_columns].any(dim=1)

# overwrite prediction with the reference on residues of known chains
p[mr_known] = y[mr_known]

# compute confidence probability
#c = pt.from_numpy(conf(p.numpy()))
#h = c / pt.sum(c, dim=1).reshape(-1,1)
#kstar = rt.kstar(c)

# get sequence: one reference and one predicted sequence per column of mr_chains
chain_columns = [mr_chains[:, k] for k in range(mr_chains.shape[1])]
seqs_ref = [rt.max_pred_to_seq(y[m_chain]) for m_chain in chain_columns]
seqs = [rt.max_pred_to_seq(p[m_chain]) for m_chain in chain_columns]

# tag for files: add the "_given_<chains>" suffix only when at least one
# non-empty chain identifier was supplied; the placeholder [""] (or an empty
# list) yields no suffix instead of a dangling "_given_"
known_tag = ''.join(known_chains)
tag = "_given_{}".format(known_tag) if known_tag else ""

# common prefix of the output files; splitext removes the extension whatever
# its length instead of assuming exactly four characters (".pdb")
output_prefix = os.path.splitext(pdb_filepath)[0] + tag

# write fasta: per-chain predicted sequences joined with ':'
rt.write_fasta(output_prefix + "_p.fasta", ':'.join(seqs), info=os.path.basename(save_path))

# store processed input structure
sp.save_pdb(sp.split_by_chain(structure), output_prefix + "_bb.pdb")

# assess predictions: one record of metrics per column of mr_chains
results = []
for i in range(mr_chains.shape[1]):
    # get chain predictions (reference yi, prediction pi) for the residues of chain i
    m_chain = mr_chains[:, i]
    yi = y[m_chain]
    pi = p[m_chain]

    # store results; Tensor.item() extracts the Python scalar directly, without
    # the NumPy round-trip that fails for tensors that are not on the CPU
    # NOTE(review): assumes the rt metric helpers return single-element torch tensors - confirm
    results.append({
        "pdb_filepath": pdb_filepath,
        "chain_name": chain_names[i],
        "recovery_rate": rt.recovery_rate(yi, pi).item(),
        "maximum_recovery_rate": rt.maximum_recovery_rate(yi, pi).item(),
        "average_multiplicity": rt.average_multiplicity(pi).item(),
        "sequence_similarity": rt.sequence_similarity(seqs_ref[i], seqs[i]),
    })

# pack results into a table and display it as the cell output
df = pd.DataFrame(results)
df

Unnamed: 0,pdb_filepath,chain_name,recovery_rate,maximum_recovery_rate,average_multiplicity,sequence_similarity
0,examples/tem1/1JTG.pdb,A,0.48855,0.885496,3.961832,0.706107
1,examples/tem1/1JTG.pdb,B,0.581818,0.860606,3.581818,0.690909
