In [2]:
import copy
import os
from glob import glob

import numpy as np
import torch as pt
from tqdm import tqdm

import src as sp
import runtime as rt

In [3]:
# parameters
device = pt.device("cpu")
min_id = True    # enable the "minimize sequence identity" step of the design loop
min_sim = False  # decode with rt.minimize_sequence_similarity instead of rt.max_pred_to_seq
r_noise = 0.0    # scale of the Gaussian noise added to the input coordinates
n_sample = 1     # number of noisy coordinate samples drawn per design
seed = 0         # seed of the torch RNG used to draw the coordinate noise

# make the noise sampling reproducible across kernel restarts
pt.manual_seed(seed)

# locate filepaths
# glob() returns paths in arbitrary order; sort them so that the processing
# order (and therefore which design is left in scope after the loop) is stable
pdb_filepaths = sorted(glob("examples/tem1/*.pdb"))

# output path
output_path = "examples/tem1/"

# model parameters
# r6
save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial
#save_path = "model/save/s_v6_5_2022-09-16_11-52"  # virtual Cb, partial & noise

# r7
#save_path = "model/save/s_v7_0_2023-04-25"  # partial chain
#save_path = "model/save/s_v7_1_2023-04-25"  # partial secondary structure
#save_path = "model/save/s_v7_2_2023-04-25"  # partial chain high coverage

# create runtime model from the checkpoint "model.pt" stored under save_path
model = rt.SequenceModel(save_path, "model.pt", device=device)

# create confidence mapping from the CDF table named after the selected model
conf = rt.ConfidenceMap("results/{}_cdf.csv".format(os.path.basename(save_path)))

In [4]:
for pdb_filepath in tqdm(pdb_filepaths):
    # load structure; this reference is never modified so that every chain
    # design below starts from the same, unperturbed input
    structure_ref = rt.load_structure(pdb_filepath)

    # chain id of every entry: the part of the chain name before ':'
    chain_ids = np.array([cn.split(':')[0] for cn in structure_ref['chain_name']])

    # find all chains; uniqueness is taken AFTER stripping the suffix so that
    # a chain whose names differ only after ':' is designed exactly once
    cids = np.unique(chain_ids).tolist()

    # design one chain at a time
    for cid in cids:
        # private copy: the noise step and the model call below both overwrite
        # `structure`, which must not leak into the next chain's design
        structure = copy.deepcopy(structure_ref)

        # known chains: everything that is not the chain being designed
        m_known = ~np.isin(chain_ids, [cid])

        # apply noise: one Gaussian perturbation per sample, broadcast against
        # the coordinates through the inserted sample axis (axis 1)
        r = pt.randn((structure['xyz'].shape[0], n_sample, structure['xyz'].shape[1])).numpy()
        structure['xyz'] = np.expand_dims(structure['xyz'], 1) + r_noise * r

        # apply model
        structure, P, y = model(structure, m_known=m_known)

        # get highest average confidence prediction
        if len(P.shape) > 2:
            # several predictions: map each to confidences and keep the one
            # whose per-position maximum confidence is highest on average
            C = np.stack([conf(P[i].numpy()) for i in range(P.shape[0])])
            iopt = np.argmax(np.mean(np.max(C, axis=2), axis=1))
            p = P[iopt]
        else:
            p = P

        # amino acid only
        p, y = rt.aa_only(p, y)

        # find residues per chain and known residues
        mr_chains, chain_names = rt.chain_masks(structure)
        mr_known = pt.any(mr_chains[:, pt.from_numpy(~np.isin(chain_names, [cid]))], dim=1)

        # overwrite prediction of the known residues with the reference
        p[mr_known] = y[mr_known]

        if min_id:
            # minimize sequence identity: where more than one option rounds to
            # one, drop the option that matches the reference
            m_choice = (pt.sum(pt.round(p), dim=1) > 1.0)
            p[m_choice, pt.argmax(y, dim=1)[m_choice]] = 0.0

        # get sequence (one per chain column of mr_chains)
        if min_sim:
            seqs = [rt.minimize_sequence_similarity(p[mr_chains[:, i]], y[mr_chains[:, i]])[0] for i in range(mr_chains.shape[1])]
        else:
            seqs = [rt.max_pred_to_seq(p[mr_chains[:, i]]) for i in range(mr_chains.shape[1])]

        # subunits ids: append the chain id unless the file name already carries it
        sid = os.path.basename(pdb_filepath).split('.')[0]
        if "_{}".format(cid) in sid:
            tag = sid
        else:
            tag = "{}_{}".format(sid, cid)

        # write fasta (chains joined with ':', model name stored as info)
        rt.write_fasta(os.path.join(output_path, "{}.fasta".format(tag)), ':'.join(seqs), info=os.path.basename(save_path))

        # debug
        print("{}, {:.3f}".format(tag, rt.recovery_rate(y, p)))

100%|██████████| 1/1 [00:04<00:00,  4.99s/it]

1JTG_A, 0.553





In [5]:
# reference sequence decoded from `y` as left in scope by the last design iteration
seq_ref = rt.max_pred_to_seq(y)
# designed sequence: all chains of that last design concatenated without separator
seq_design = ''.join(seqs)
# displayed as (identity, similarity) of the design against the reference
rt.sequence_identity(seq_ref, seq_design), rt.sequence_similarity(seq_ref, seq_design)

(0.5534351145038168, 0.7480916030534351)