In [1]:
import os
import numpy as np
import torch as pt
import pandas as pd
from tqdm import tqdm
from glob import glob
from time import time

import src as sp
import runtime as rt

In [2]:
# Configuration cell: everything a reader may want to tune for a benchmark run.

# parameters
device = pt.device("cpu")
# when True, the design loop zeroes the native residue at positions where
# more than one residue type is predicted (minimize sequence identity)
min_id = False
# when True, sequences are built with rt.minimize_sequence_similarity
# instead of rt.max_pred_to_seq
min_sim = False
# scale applied to the standard-normal noise added to the coordinates
r_noise = 0.0
# number of noisy copies generated per structure
n_sample = 1

# seed the torch RNG so the coordinate noise drawn with pt.randn is reproducible
seed = 0
pt.manual_seed(seed)

# locate filepaths (sorted: glob returns files in arbitrary order)
pdb_filepaths = sorted(glob("benchmark_data/wt/monomers/*.pdb"))
#pdb_filepaths = sorted(glob("benchmark_data/wt/dimers/*.pdb"))

# output path (keep in sync with min_id / min_sim and the dataset chosen above)
output_path = "benchmark_data/carbonara/monomers/maxseqid"
#output_path = "benchmark_data/carbonara/monomers/minseqid"
#output_path = "benchmark_data/carbonara/monomers/minseqsim"
#output_path = "benchmark_data/carbonara/dimers/maxseqid"
#output_path = "benchmark_data/carbonara/dimers/minseqid"
#output_path = "benchmark_data/carbonara/dimers/minseqsim"
#output_path = "benchmark_data/carbonara7/dimers/maxseqid"
#output_path = "benchmark_data/carbonara7d/dimers/maxseqid"

# model parameters
# r6
save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial
#save_path = "model/save/s_v6_5_2022-09-16_11-52"  # virtual Cb, partial & noise

# r7
#save_path = "model/save/s_v7_0_2023-04-25"  # partial chain
#save_path = "model/save/s_v7_1_2023-04-25"  # partial chain and noise
#save_path = "model/save/s_v7_2_2023-04-25"  # partial chain high coverage
#save_path = "model/save/s_v7_3_2023-04-25"  # partial chain and noise and high coverage

# create runtime model
model = rt.SequenceModel(save_path, "model.pt", device=device)

# create confidence mapping
# The design loop calls `conf` to rank predictions when the model returns more
# than one of them; leaving it undefined makes that path fail with a NameError.
# NOTE(review): assumes several predictions are only returned for n_sample > 1
# -- confirm against rt.SequenceModel.
if n_sample > 1:
    conf = rt.ConfidenceMap("results/{}_cdf.csv".format(os.path.basename(save_path)))
else:
    conf = None

In [3]:
# Design every chain of every input structure, write one fasta per designed
# chain into `output_path` and record the wall-clock time spent per chain.

# make sure the output locations exist before any (slow) inference is done
os.makedirs(output_path, exist_ok=True)
os.makedirs("results", exist_ok=True)

# structures whose 'xyz' array has more rows than this are skipped
max_size = 1024*8

# name of the model checkpoint directory, stored as info in each fasta file
model_name = os.path.basename(save_path)

profiling = []
for pdb_filepath in tqdm(pdb_filepaths):
    # start of the timed interval; for the first chain it includes the
    # initial load and chain discovery below
    t0 = time()

    # load structure once to check its size and list its chains
    structure = rt.load_structure(pdb_filepath)

    # max size
    if structure['xyz'].shape[0] > max_size:
        continue

    # find all chains: keep the part before ':' of each chain name and drop
    # duplicates (order preserved) so that a chain is not designed twice
    cids = list(dict.fromkeys(cid.split(':')[0] for cid in np.unique(structure['chain_name'])))

    # structure identifier: file name up to the first '.'
    sid = os.path.basename(pdb_filepath).split('.')[0]

    # design one chain at a time
    for cid in cids:
        # reload the structure: it is modified below (noise, model output)
        structure = rt.load_structure(pdb_filepath)

        # known chains: every entry that does not belong to the designed chain
        m_known = ~np.isin([cn.split(':')[0] for cn in structure['chain_name']], [cid])

        # apply noise: n_sample noisy copies of the coordinates, inserted on axis 1
        r = pt.randn((structure['xyz'].shape[0], n_sample, structure['xyz'].shape[1])).numpy()
        structure['xyz'] = np.expand_dims(structure['xyz'],1) + r_noise * r

        # apply model
        # (presumably P holds the predictions and y the native sequence -- TODO confirm)
        structure, P, y = model(structure, m_known=m_known)

        # get highest average confidence prediction
        if len(P.shape) > 2:
            # NOTE: `conf` must be defined (an rt.ConfidenceMap, see the
            # configuration cell) for this branch to work
            C = np.stack([conf(P[i].numpy()) for i in range(P.shape[0])])
            iopt = np.argmax(np.mean(np.max(C, axis=2), axis=1))
            p = P[iopt]
        else:
            p = P

        # amino acid only
        p, y = rt.aa_only(p, y)

        # find residues per chain and known residues
        mr_chains, chain_names = rt.chain_masks(structure)
        mr_known = pt.any(mr_chains[:,pt.from_numpy(~np.isin(chain_names, [cid]))], dim=1)

        # overwrite prediction of the known residues with y
        p[mr_known] = y[mr_known]

        if min_id:
            # minimize sequence identity: where more than one class rounds
            # to 1, zero the class that is the argmax of y
            m_choice = (pt.sum(pt.round(p), dim=1) > 1.0)
            p[m_choice, pt.argmax(y, dim=1)[m_choice]] = 0.0

        # get sequence, one per chain
        if min_sim:
            seqs = [rt.minimize_sequence_similarity(p[mr_chains[:,i]], y[mr_chains[:,i]])[0] for i in range(mr_chains.shape[1])]
        else:
            seqs = [rt.max_pred_to_seq(p[mr_chains[:,i]]) for i in range(mr_chains.shape[1])]

        # subunits ids: append the chain id unless "_<cid>" already occurs in sid
        if "_{}".format(cid) in sid:
            tag = sid
        else:
            tag = "{}_{}".format(sid, cid)

        # write fasta (chains joined with ':')
        rt.write_fasta(os.path.join(output_path, "{}.fasta".format(tag)), ':'.join(seqs), info=model_name)

        # debug
        #print("{}, {:.3f}".format(tag, rt.recovery_rate(y, p)))

        # profiling: time spent on this chain only; the timer is restarted so
        # that the next chain of the same file does not include this one
        t1 = time()
        profiling.append({'sid': tag, 'dt': t1-t0})
        t0 = t1

# save profiling
dfp = pd.DataFrame(profiling)
dfp.to_csv("results/profiling_carbonara.csv", index=False)

100%|██████████| 142/142 [07:17<00:00,  3.08s/it]
