In [1]:
import os
import numpy as np
import torch as pt
import pandas as pd
from glob import glob
from tqdm import tqdm
from time import time

import esm.inverse_folding

In [2]:
# parameters
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded "cuda" device fails at model.to(device)).
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# load model
# Pretrained ESM-IF1 inverse-folding model and its alphabet; eval() puts the
# model in inference mode before it is moved to the selected device.
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval().to(device)



### Monomers

In [3]:
# locate benchmark structures
# (sorted: glob returns paths in arbitrary order, which would make the
# processing order and the profiling table differ between runs/machines)
pdb_filepaths = sorted(glob("benchmark_data/wt/monomers/*.pdb"))

# make sure the output locations exist before anything is written to them
monomer_output_dir = "benchmark_data/esm/monomers"
os.makedirs(monomer_output_dir, exist_ok=True)
os.makedirs("results", exist_ok=True)

# near-zero sampling temperature, so sampling is effectively deterministic
temperature = 1e-6

profiling = []
for pdb_filepath in tqdm(pdb_filepaths):
    t0 = time()
    # load structure
    structure = esm.inverse_folding.util.load_structure(pdb_filepath)
    coords, seq = esm.inverse_folding.util.extract_coords_from_structure(structure)

    # sample sequence (inference only, so gradient tracking is disabled)
    with pt.no_grad():
        sampled_seq = model.sample(coords, temperature=temperature, device=device)

    # save sequence as a single-record FASTA named after the structure id
    # (the file name up to its first '.')
    sid = os.path.basename(pdb_filepath).split('.')[0]
    with open(os.path.join(monomer_output_dir, "{}.fasta".format(sid)), 'w') as fs:
        fs.write(">{}\n{}".format(sid, sampled_seq))

    # profiling: wall-clock time for load + sample + write of this structure
    t1 = time()
    profiling.append({'sid': sid, 'dt': t1-t0})

# save profiling
# (explicit columns keep the CSV header even when no structure was found)
dfp = pd.DataFrame(profiling, columns=['sid', 'dt'])
dfp.to_csv("results/profiling_esm.csv", index=False)

100%|██████████| 142/142 [03:14<00:00,  1.37s/it]


### Dimers

In [None]:
# locate benchmark structures
# (sorted: glob returns paths in arbitrary order, so sorting keeps the
# processing order reproducible between runs/machines)
pdb_filepaths = sorted(glob("benchmark_data/wt/dimers/*.pdb"))

# make sure the output location exists before anything is written to it
dimer_output_dir = "benchmark_data/esm/dimers"
os.makedirs(dimer_output_dir, exist_ok=True)

# near-zero sampling temperature, so sampling is effectively deterministic
temperature = 1e-6

for pdb_filepath in tqdm(pdb_filepaths):
    # load structure
    structure = esm.inverse_folding.util.load_structure(pdb_filepath)
    coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)

    # get chain ids; this cell only handles two-chain complexes, so fail with
    # a message that names the offending file instead of an unpacking error
    chain_ids = list(coords)
    if len(chain_ids) != 2:
        raise ValueError(
            "expected exactly 2 chains in {}, found {}: {}".format(
                pdb_filepath, len(chain_ids), chain_ids
            )
        )
    cid0, cid1 = chain_ids

    # sample sequences: redesign one chain at a time within the complex
    # (inference only, so gradient tracking is disabled)
    with pt.no_grad():
        sampled0 = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
            model, coords, cid0, temperature=temperature
        )
        sampled1 = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
            model, coords, cid1, temperature=temperature
        )

    # each record keeps the chain order (cid0:cid1), pairing the sampled chain
    # with the native sequence of its partner
    seq0 = ':'.join([sampled0, native_seqs[cid1]])
    seq1 = ':'.join([native_seqs[cid0], sampled1])

    # save sequences: one FASTA per redesigned chain, named <sid>_<chain id>
    sid = os.path.basename(pdb_filepath).split('.')[0]
    with open(os.path.join(dimer_output_dir, "{}_{}.fasta".format(sid, cid0)), 'w') as fs:
        fs.write(">{}\n{}".format(sid, seq0))
    with open(os.path.join(dimer_output_dir, "{}_{}.fasta".format(sid, cid1)), 'w') as fs:
        fs.write(">{}\n{}".format(sid, seq1))