In [1]:
import os
import subprocess
from glob import glob
from time import time

import pandas as pd
from tqdm import tqdm

import src as sp

In [2]:
# parameters
# Input structures to redesign. The list is sorted because glob returns files
# in a filesystem-dependent order; sorting keeps the processing order (and the
# row order of the profiling table) reproducible across machines and runs.
pdb_filepaths = sorted(glob("benchmark_data/wt/monomers/*.pdb"))
#pdb_filepaths = sorted(glob("benchmark_data/wt/dimers/*.pdb"))
# Folder passed to ProteinMPNN as --out_folder; the per-run FASTA files are
# written here as well.
output_dir = "benchmark_data/mpnn/monomers"
#output_dir = "benchmark_data/mpnn/dimers"
#output_dir = "/tmp"

In [3]:
# Redesign every chain of every input structure with ProteinMPNN.
#
# For each PDB file the structure is split into subunits and ProteinMPNN is
# run once per chain id (one designed chain at a time). After each run:
#   * the FASTA file written by ProteinMPNN is renamed to "<key>.fa",
#   * a "<key>.fasta" file is written to output_dir holding all chain
#     sequences joined with ':', where the designed chain replaces its
#     reference sequence,
#   * one profiling record (sid, dt) is stored for that run.
# The profiling records are finally saved to "results/profiling_mpnn.csv".
profiling = []
for pdb_filepath in tqdm(pdb_filepaths):
    pdbid = os.path.splitext(os.path.basename(pdb_filepath))[0]

    # find subunits
    subunits = sp.split_by_chain(sp.read_pdb(pdb_filepath))
    cids = [cid.split(':')[0] for cid in subunits]

    # reference sequences keyed by chain id; they do not depend on which chain
    # is being designed, so they are computed once per structure
    ref_seqs = {cid.split(':')[0]: sp.subunit_to_sequence(subunits[cid]) for cid in subunits}

    # predict sequence on one chain at a time
    for cid in cids:
        # define key
        # NOTE(review): when pdbid already contains "_" every chain of the
        # structure shares the same key, so a later chain overwrites the
        # files of an earlier one -- confirm such inputs are single-chain.
        key = pdbid if "_" in pdbid else "{}_{}".format(pdbid, cid)

        # start from the reference sequences; only the designed chain changes
        seqs = dict(ref_seqs)

        # run ProteinMPNN
        # The command is passed as an argument list (no shell), so paths with
        # spaces or shell metacharacters are handled correctly, and a non-zero
        # exit status raises CalledProcessError instead of going unnoticed.
        command = [
            "python", "ProteinMPNN/protein_mpnn_run.py",
            "--pdb_path", pdb_filepath,
            "--pdb_path_chains", cid,
            "--out_folder", output_dir,
            "--num_seq_per_target", "1",
            "--sampling_temp", "0.000001",
            "--seed", "37",
            "--batch_size", "1",
        ]
        output = subprocess.check_output(command, universal_newlines=True)

        # locate and rename output
        raw_fa_filepath = os.path.join(output_dir, "seqs", pdbid + ".fa")
        fa_filepath = os.path.join(output_dir, "seqs", "{}.fa".format(key))
        if raw_fa_filepath != fa_filepath:
            # os.replace overwrites an existing target on every platform
            os.replace(raw_fa_filepath, fa_filepath)

        # read file
        with open(fa_filepath, 'r') as fs:
            fa_lines = fs.read().split('\n')

        # parse file and update sequence
        # assumes the 4th line of the FASTA file is the designed sequence
        # (reference header, reference sequence, sample header, sample
        # sequence) -- TODO confirm against the ProteinMPNN output format
        sid = os.path.basename(fa_filepath).split('.')[0]
        seqs[cid] = fa_lines[3]

        # get multimer sequences
        seq = ':'.join([seqs[chain_id] for chain_id in cids])

        # write corresponding file
        with open(os.path.join(output_dir, "{}.fasta".format(key)), 'w') as fs:
            fs.write('>{}\n{}'.format(sid, seq))

        # profiling
        # One record per ProteinMPNN run, stored inside the chain loop so that
        # no run is dropped for multi-chain structures and no stale values are
        # reused when a structure yields no chains.
        # assumes the second-to-last whitespace-separated token printed by
        # ProteinMPNN is the elapsed time in seconds -- TODO confirm
        profiling.append({'sid': sid, 'dt': float(output.split()[-2])})

# save profiling
# make sure the target folder exists so the results of the run are not lost
os.makedirs("results", exist_ok=True)
dfp = pd.DataFrame(profiling)
dfp.to_csv("results/profiling_mpnn.csv", index=False)

100%|██████████| 142/142 [07:02<00:00,  2.97s/it]
