In [1]:
import os
import sys

from biopandas.pdb import PandasPdb
import numpy as np
import torch
from transformers import EsmTokenizer

sys.path.append("../")
from saprot_utils.foldseek_util import get_struc_seq

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def process_single_file(filename, folder, foldseek_path="../../saprot_utils/bin/foldseek"):
    """Generate the SaProt structure-aware sequence for a single-chain structure file.

    Args:
        filename: Name of the structure file, relative to `folder`.
        folder: Directory containing `filename`.
        foldseek_path: Path to the foldseek binary, forwarded to `get_struc_seq`.

    Returns:
        dict with keys
            'pdb_path':   `filename` exactly as given (not joined with `folder`),
            'seq':        first element returned by `get_struc_seq` for the chain,
            'chain':      the single chain ID found in the file's ATOM records,
            'saprot_seq': combined sequence returned by `get_struc_seq`.

    Raises:
        ValueError: if the ATOM records contain no chain, or more than one chain.
    """
    file = os.path.join(folder, filename)
    # Only ATOM records are inspected; chains present solely as HETATM are ignored.
    chains = PandasPdb().read_pdb(file).df["ATOM"].chain_id.unique()
    if len(chains) == 0:
        # Without this guard an empty file would fail below with a bare IndexError.
        raise ValueError(f"Expected one chain in structure {filename}, got 0")
    if len(chains) > 1:
        raise ValueError(f"Expected one chain in structure {filename}, got {len(chains)}: {chains}")
    chain = chains[0]
    # get_struc_seq returns a mapping keyed by chain ID; each value unpacks into
    # (sequence, foldseek sequence, combined sequence). The foldseek-only part is unused.
    seq, _, combined_seq = get_struc_seq(foldseek_path, file, [chain])[chain]
    return {
        'pdb_path': filename,
        'seq': seq,
        'chain': chain,
        'saprot_seq': combined_seq,
    }

# Get SaProt structural sequences using `foldseek`

In [14]:
# Path to the foldseek binary. Set the FOLDSEEK_PATH environment variable, or edit
# the default below, if your foldseek binary lives somewhere else.
foldseek_path = os.environ.get("FOLDSEEK_PATH", '../../saprot_utils/bin/foldseek')

In [15]:
# Directory holding the example structures; every pair below is
# [antigen file, epitope file], with names relative to this directory.
pdb_folder = "inference_examples/"
pdb_pairs = [
    ["7LM9_A.pdb", "7RBY_C.pdb"],
    ["8FDW_A.pdb", "8U1G_A.pdb"],
]

In [None]:
# For each structure pair, collect the SaProt structure-aware sequence of both files,
# preserving the order of pdb_pairs and of the files within each pair.
saprot_seqs = []
for pdb_pair in pdb_pairs:
    seqs = [
        process_single_file(pdb_file, pdb_folder, foldseek_path=foldseek_path)["saprot_seq"]
        for pdb_file in pdb_pair
    ]
    saprot_seqs.append(seqs)

# Run inference

## Prepare inputs

In [7]:
# SaProt tokenizer used to tokenize the structure-aware sequences.
tokenizer = EsmTokenizer.from_pretrained("westlake-repl/SaProt_35M_AF2")

# Load trained model onto the first GPU when one is available; otherwise fall back
# to CPU so the notebook does not crash on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trained_model_path = "best_model.pt"
# NOTE(review): the checkpoint is a pickled module object (it is used directly via
# .eval()), so torch.load unpickles arbitrary Python — only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True, which rejects such
# full-module pickles; pass weights_only=False there if loading fails.
trained_model = torch.load(trained_model_path, map_location=device)
trained_model.eval()
print("Loaded trained model")

Loaded trained model


In [8]:
predictions = []
for (ag_file, epi_file), (saprot_seq_1, saprot_seq_2) in zip(pdb_pairs, saprot_seqs):
    # Tokenize each structure-aware sequence separately; the model receives both
    # encodings as a [first, second] list under the "antigen_epitope" key.
    inputs = {
        "antigen_epitope": [
            tokenizer(seq, return_tensors="pt") for seq in (saprot_seq_1, saprot_seq_2)
        ]
    }
    with torch.no_grad():
        match_prediction = trained_model(**inputs)
    # Take batch element 0 and channel 0 of the 4-D output, presumably a
    # (first-sequence x second-sequence) score map — TODO confirm against the model.
    prediction = match_prediction[0, :, :, 0].cpu().detach().numpy()
    predictions.append(prediction)

    # Save as prediction_<antigen stem>_<epitope stem>.npy (np.save appends ".npy").
    # splitext keeps the full name for extension-less files instead of yielding "".
    ag_stem = os.path.splitext(ag_file)[0]
    epi_stem = os.path.splitext(epi_file)[0]
    np.save(f"prediction_{ag_stem}_{epi_stem}", prediction)