In [1]:
import os
import subprocess
import time

import datasets
from tqdm import tqdm
import torch
import faiss
from transformers import AutoModel, AutoTokenizer, EsmForProteinFolding
# Runtime configuration shared by every benchmark cell below.
datasets.disable_caching()  # do not reload on-disk cache files for dataset transforms
faiss.omp_set_num_threads(4)  # cap the OpenMP thread count FAISS uses

# Locations of the external search executables.
MMSEQS_PATH = "/path/to/mmseqs2"  # Replace with the actual path to your mmseqs2 installation
FOLDSEEK_PATH = "/path/to/foldseek"  # Replace with the actual path to your Foldseek installation

In [2]:
# Download the benchmark dataset and keep only its train split, held in memory.
dataset = datasets.load_dataset(
    "tattabio/bac_arch_bigene",
    keep_in_memory=True,
)["train"]

In [None]:
def _write_fasta(path, seq_col, accession_col):
    """Write one FASTA record per dataset row: a '>accession' line followed by the sequence line."""
    with open(path, "w") as fasta_file:
        for accession, seq in zip(dataset[accession_col], dataset[seq_col]):
            fasta_file.write(f">{accession}\n{seq}\n")


# Seq1 sequences become the search queries (bac_test.fasta) and Seq2 sequences the
# search targets (arc_test.fasta) for the command-line benchmarks below.
_write_fasta("bac_test.fasta", "Seq1", "Seq1_accession")
_write_fasta("arc_test.fasta", "Seq2", "Seq2_accession")

# Ground truth: each Seq1 accession maps to the Seq2 accession on the same row.
# If a Seq1 accession repeats, the last row wins (same as an explicit loop would do).
correct_matches = dict(
    tqdm(zip(dataset["Seq1_accession"], dataset["Seq2_accession"]), total=len(dataset))
)


In [None]:
# BlastP benchmark: bac_test.fasta (Seq1) queries against arc_test.fasta (Seq2) subjects.
start = time.time()
# check=True: stop here instead of silently scoring a stale or missing blast_results.txt.
subprocess.run(
    [
        "blastp", "-query", "bac_test.fasta", "-subject", "arc_test.fasta",
        "-outfmt", "6", "-out", "blast_results.txt",
    ],
    check=True,
)
end = time.time()
time_taken = end - start
print(f"BlastP Time taken: {time_taken} seconds")

# Top-1 recall: only the first line reported for each query (its top hit) is scored,
# and it counts as correct when it is the query's paired Seq2 accession.
correct = 0
seen = set()  # queries whose top hit has already been scored
with open("blast_results.txt", "r") as blast_file:
    for line in blast_file:
        fields = line.split()
        if not fields:  # tolerate blank lines
            continue
        query = fields[0]
        if query in seen:
            continue
        seen.add(query)
        if correct_matches.get(query) == fields[1]:
            correct += 1

recall = correct / len(correct_matches)
print(f"BlastP recall: {recall:.4f}")

In [None]:
# MMseqs2 benchmark: build a target DB from arc_test.fasta (Seq2), then search
# bac_test.fasta (Seq1) against it. The timing includes DB creation.
start = time.time()
# check=True: stop here instead of silently scoring a stale or missing mmseqs_results.m8.
subprocess.run([MMSEQS_PATH, "createdb", "arc_test.fasta", "arc_test_db"], check=True)
subprocess.run(
    [
        MMSEQS_PATH, "easy-search", "bac_test.fasta", "arc_test_db", "mmseqs_results.m8", "tmp",
        "--threads", "4", "-v", "0",
    ],
    check=True,
)
end = time.time()
time_taken = end - start
print(f"Mmseqs Time taken: {time_taken} seconds")

# Top-1 recall: only the first line reported for each query (its top hit) is scored.
correct = 0
seen = set()
with open("mmseqs_results.m8", "r") as mmseqs_file:
    for line in mmseqs_file:
        fields = line.split()
        if not fields:  # tolerate blank lines
            continue
        query = fields[0]
        if query in seen:
            continue
        seen.add(query)
        if correct_matches.get(query) == fields[1]:
            correct += 1

recall = correct / len(correct_matches)
print(f"Mmseqs recall: {recall:.4f}")

In [None]:
# Protein structure prediction for Foldseek using ESMFold
def standardize_amino_acids(sequence: str) -> str:
    """Rewrite non-canonical residue letters as canonical ones before folding.

    O (pyrrolysine) -> K, U (selenocysteine) -> C, J (isoleucine/leucine) -> L.
    Every other character, including lowercase letters, is returned unchanged.
    """
    # None of the replacement letters is itself a source letter, so a single
    # simultaneous translation gives the same result as replacing one letter at a time.
    substitutions = str.maketrans({"O": "K", "U": "C", "J": "L"})
    return sequence.translate(substitutions)

class ESMFoldModel:
    """Wrapper around the Hugging Face `facebook/esmfold_v1` checkpoint that folds one sequence into a PDB string."""

    MODEL_NAME = "facebook/esmfold_v1"
    NUM_RECYCLES = 4  # passed as `num_recycles` to the folding model's forward call

    def __init__(self, device = 'cuda'):
        """Load the tokenizer and folding model; put the model in eval mode on `device`.

        Args:
            device: torch device string (or torch.device) for the model and its inputs.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = EsmForProteinFolding.from_pretrained(self.MODEL_NAME).eval().to(device)
        self.device = device

    def inference(self, protein_str):
        """Fold a single protein sequence and return the predicted structure as PDB text.

        Non-canonical residue letters are standardized first. The sequence is tokenized
        without special tokens, as a batch of one.

        Returns:
            The first (and only) PDB string produced by `output_to_pdb`.
        """
        protein_str = standardize_amino_acids(protein_str)
        input_ids = self.tokenizer(
            [protein_str], return_tensors="pt", add_special_tokens=False
        )['input_ids'].to(self.device)
        # fp16 autocast only applies to CUDA; on any other device run in the model's default precision.
        on_cuda = torch.device(self.device).type == "cuda"
        with torch.inference_mode(), torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=on_cuda,
        ):
            output = self.model(input_ids, num_recycles=self.NUM_RECYCLES)
        pdb = self.model.output_to_pdb(output)[0]
        return pdb

def write_pdb_str_to_file(pdb_str: str, file_path: str):
    """Write `pdb_str` to `file_path`, preceded by a unit-cell CRYST1 record.

    Every 'PARENT N/A' header in `pdb_str` is removed. The CRYST1 line is always
    newline-terminated, so it never merges with the first record of `pdb_str`.
    When `pdb_str` starts with 'PARENT N/A\\n', the output is byte-identical to
    writing the header followed by the stripped string.
    """
    cryst1 = 'CRYST1    1.000    1.000    1.000  90.00  90.00  90.00 P 1           1          '
    # Replace the ESMFold header with the expected header.
    body = pdb_str.replace('PARENT N/A', '')
    # Removing a leading PARENT line leaves its newline behind; drop it because the
    # CRYST1 line below supplies its own terminator.
    if body.startswith('\n'):
        body = body[1:]
    with open(file_path, 'w') as f:
        f.write(cryst1 + '\n')
        f.write(body)
    

esmfold_model = ESMFoldModel(device='cuda')

# Output folders for the predicted Seq1 and Seq2 structures.
for structure_dir in ('structures_seq1', 'structures_seq2'):
    os.makedirs(structure_dir, exist_ok=True)

def process_row(row):
    """Fold Seq1 and Seq2 of one dataset row with ESMFold and write their PDB files.

    Each structure is folded only if its own PDB file is missing. An interrupted run
    therefore resumes without refolding structures that were already written.
    Seq1 goes to structures_seq1/<Seq1_accession>.pdb and Seq2 to
    structures_seq2/<Seq2_accession>.pdb. Returns None; it is called for its side effects.
    """
    jobs = (
        (row['Seq1'], f"structures_seq1/{row['Seq1_accession']}.pdb"),
        (row['Seq2'], f"structures_seq2/{row['Seq2_accession']}.pdb"),
    )
    for sequence, pdb_path in jobs:
        if os.path.exists(pdb_path):
            continue
        write_pdb_str_to_file(esmfold_model.inference(sequence), pdb_path)

# Time the whole folding pass over the dataset (process_row may skip structures already on disk).
start_time = time.time()
_ = dataset.map(process_row)
end_time = time.time()
fold_seconds = end_time - start_time
print(f"Structure prediction (ESMFold) for Foldseek Time taken: {fold_seconds} seconds")


In [None]:
# Foldseek benchmark: Seq1 structures (queries, structures_seq1/) searched against
# Seq2 structures (targets, structures_seq2/), both written by the ESMFold cell.
def _foldseek_entry_id(name):
    """Map a Foldseek entry name back to a dataset accession.

    NOTE(review): strips a trailing ".pdb" in case Foldseek keeps the file extension in
    entry names; it is a no-op otherwise. Confirm against a real foldseek_results.m8.
    """
    suffix = ".pdb"
    return name[: -len(suffix)] if name.endswith(suffix) else name


start = time.time()
# check=True: stop here instead of silently scoring a stale or missing foldseek_results.m8.
subprocess.run([FOLDSEEK_PATH, "createdb", "structures_seq2", "targetDB"], check=True)
subprocess.run([FOLDSEEK_PATH, "createdb", "structures_seq1", "queryDB"], check=True)
subprocess.run(
    [
        FOLDSEEK_PATH, "easy-search", "queryDB", "targetDB", "foldseek_results.m8", "tmp",
        "--threads", "4", "-v", "0",
    ],
    check=True,
)
end = time.time()
time_taken = end - start

# Top-1 recall: only the first (top) hit listed for each query is scored.
correct = 0
seen = set()
with open("foldseek_results.m8", "r") as foldseek_file:
    for line in foldseek_file:
        fields = line.split()
        if not fields:  # tolerate blank lines
            continue
        query = _foldseek_entry_id(fields[0])
        if query in seen:
            continue
        seen.add(query)
        expected = correct_matches.get(query)
        if expected is not None and _foldseek_entry_id(fields[1]) == expected:
            correct += 1

os.remove("foldseek_results.m8")
recall = correct / len(correct_matches)
print(f"Foldseek Time taken: {time_taken} seconds")
print(f"Foldseek recall: {recall:.4f}")

In [None]:
# gLM2 Benchmark
glm2_model_name = "tattabio/gLM2_650M_embed"
# trust_remote_code=True executes the modeling/tokenizer code shipped in the Hub repository;
# only enable it for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(glm2_model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(glm2_model_name, trust_remote_code=True).eval().cuda()

def infer_fn(examples, seq_col="Seq1"):
    """Embed one batch of sequences with gLM2 and return L2-normalised pooled features.

    Args:
        examples: batched dataset slice passed by `Dataset.map(..., batched=True)`;
            `examples[seq_col]` is the list of sequences to embed.
        seq_col: column holding the sequences.

    Returns:
        dict mapping f"{seq_col}_features" to a float32 numpy array built from the
        model's `pooler_output` (presumably one row per sequence — confirm).

    NOTE(review): gLM2 documentation shows protein inputs prefixed with a strand token
    such as '<+>'; confirm whether the dataset's raw sequences need it.
    """
    sequences = examples[seq_col]
    inputs = tokenizer(sequences, return_tensors="pt", padding=True)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    attention_mask = inputs["attention_mask"].bool()
    # torch.autocast(device_type="cuda", ...) replaces the deprecated torch.cuda.amp.autocast().
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs['input_ids'], attention_mask=attention_mask)
        hidden = outputs.pooler_output.float()
        hidden = torch.nn.functional.normalize(hidden, dim=-1)
    return {f"{seq_col}_features": hidden.cpu().numpy()}

start_time = time.time()
glm2_ds = dataset
for column in ("Seq1", "Seq2"):
    # `seq_col=column` binds the current column into the lambda when it is defined.
    glm2_ds = glm2_ds.map(
        lambda batch, seq_col=column: infer_fn(batch, seq_col=seq_col),
        batched=True,
        batch_size=64,
    )
total_time = time.time() - start_time
print(f"GLM2 inference time: {total_time:.2f} seconds")
glm2_ds.set_format(type="numpy")

# Retrieval with gLM2 embeddings.
# - Index the Seq1 features; inner product equals cosine similarity because the features
#   are L2-normalised.
# - Query with every row's Seq2 features.
# - A hit is correct when the nearest Seq1 neighbour is that row's own Seq1 accession.
# NOTE(review): this direction (Seq2 queries -> Seq1 index) is the reverse of the
# BlastP/MMseqs/Foldseek cells (Seq1 queries -> Seq2 targets); confirm it is intended.
start_time = time.time()
glm2_ds = glm2_ds.add_faiss_index(column="Seq1_features", metric_type=faiss.METRIC_INNER_PRODUCT)
seq2_features = glm2_ds["Seq2_features"]
scores, hits = glm2_ds.get_nearest_examples_batch("Seq1_features", seq2_features, k=1)
# Each entry of `hits` holds the k=1 retrieved rows, so its accession column has a single
# element. Index it explicitly instead of relying on a 1-element array comparing truthily
# to a string.
hits = [
    hit['Seq1_accession'][0] if len(hit['Seq1_accession']) > 0 else None
    for hit in hits
]
gt_accessions = glm2_ds['Seq1_accession']
accuracy = sum(hit == gt for hit, gt in zip(hits, gt_accessions)) / len(gt_accessions)
total_time = time.time() - start_time

print(f"GLM2 accuracy: {accuracy:.4f}")
print(f"GLM2 search time: {total_time:.2f} seconds")


In [None]:
# ESM2 Benchmark
esm2_model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm2_model_name)
model = AutoModel.from_pretrained(esm2_model_name).eval().cuda()

def infer_fn(examples, seq_col="Seq1", layer='mid'):
    """Embed one batch of sequences with ESM2 and return L2-normalised mean-pooled features.

    Args:
        examples: batched dataset slice passed by `Dataset.map(..., batched=True)`;
            `examples[seq_col]` is the list of sequences to embed.
        seq_col: column holding the sequences.
        layer: which entry of `outputs.hidden_states` to pool.
            - 'mid' selects index `len(model.encoder.layer) // 2 - 1`.
            - An int is used directly as the index.
            - Any other value selects the last hidden state (-1).

    Returns:
        dict mapping f"{seq_col}_features" to a float32 numpy array of pooled embeddings.
    """
    sequences = examples[seq_col]
    inputs = tokenizer(sequences, return_tensors="pt", padding=True)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    attention_mask = inputs["attention_mask"].bool()

    if isinstance(layer, int) and not isinstance(layer, bool):
        layer_idx = layer
    elif layer == 'mid':
        # NOTE(review): HF models put the embedding output at hidden_states[0], so this picks
        # the output of encoder layer len//2 - 1; confirm this is the intended "middle" layer.
        layer_idx = len(model.encoder.layer) // 2 - 1
    else:
        layer_idx = -1

    # torch.autocast(device_type="cuda", ...) replaces the deprecated torch.cuda.amp.autocast().
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs['input_ids'], attention_mask=attention_mask, output_hidden_states=True)
        hiddens = outputs.hidden_states[layer_idx].float()
        mask = attention_mask.unsqueeze(-1)
        # Mean pool over positions the attention mask marks as tokens (padding excluded;
        # presumably the tokenizer's special tokens are included — confirm if that matters).
        hidden = torch.where(mask, hiddens, 0.0)
        hidden = torch.sum(hidden, 1) / torch.sum(mask, dim=1, dtype=hidden.dtype)
        hidden = torch.nn.functional.normalize(hidden, dim=-1)
    return {f"{seq_col}_features": hidden.cpu().numpy()}

start_time = time.time()
esm2_ds = dataset
for column in ("Seq1", "Seq2"):
    # `seq_col=column` binds the current column into the lambda when it is defined.
    esm2_ds = esm2_ds.map(
        lambda batch, seq_col=column: infer_fn(batch, seq_col=seq_col),
        batched=True,
        batch_size=64,
    )
total_time = time.time() - start_time
print(f"ESM2 inference time: {total_time:.2f} seconds")
esm2_ds.set_format(type="numpy")

# Retrieval with ESM2 embeddings.
# - Index the Seq1 features; inner product equals cosine similarity because the features
#   are L2-normalised.
# - Query with every row's Seq2 features.
# - A hit is correct when the nearest Seq1 neighbour is that row's own Seq1 accession.
# NOTE(review): this direction (Seq2 queries -> Seq1 index) is the reverse of the
# BlastP/MMseqs/Foldseek cells (Seq1 queries -> Seq2 targets); confirm it is intended.
start_time = time.time()
esm2_ds = esm2_ds.add_faiss_index(column="Seq1_features", metric_type=faiss.METRIC_INNER_PRODUCT)
seq2_features = esm2_ds["Seq2_features"]
scores, hits = esm2_ds.get_nearest_examples_batch("Seq1_features", seq2_features, k=1)
# Each entry of `hits` holds the k=1 retrieved rows, so its accession column has a single
# element. Index it explicitly instead of relying on a 1-element array comparing truthily
# to a string.
hits = [
    hit['Seq1_accession'][0] if len(hit['Seq1_accession']) > 0 else None
    for hit in hits
]
gt_accessions = esm2_ds['Seq1_accession']
accuracy = sum(hit == gt for hit, gt in zip(hits, gt_accessions)) / len(gt_accessions)
total_time = time.time() - start_time

print(f"ESM2 accuracy: {accuracy:.4f}")
print(f"ESM2 search time: {total_time:.2f} seconds")
