In [1]:
from collections import defaultdict
import pandas as pd
import ast
from matplotlib import pyplot as plt


In [None]:
# Retrieval results for the two embedding models, one CSV each.
glm2_results = pd.read_csv('data/gLM2_embed_k100_retrieval_results.csv')
esm2_results = pd.read_csv('data/ESM2_k100_retrieval_results.csv')

# MMseqs2 hits come from a tab-separated .m8 file; only the first two fields
# (query id, subject id) are used. Hits are grouped per query in file order.
mmseqs = defaultdict(list)
with open('data/mmseqs_results_seq_bm.m8', "r") as m8_file:
    for record in m8_file:
        fields = record.strip().split("\t")
        query_id, subject_id = fields[:2]
        # A sequence matching itself is not a retrieval result.
        if query_id == subject_id:
            continue
        mmseqs[query_id].append(subject_id)

# One row per query that has at least one non-self hit.
mmseqs_results = pd.DataFrame({
    'query': list(mmseqs.keys()),
    'search_result (k=100)': list(mmseqs.values()),
})

def convert_to_list(column):
    """Parse the text form of a list (one CSV cell) into a Python list.

    Parameters
    ----------
    column : str
        Text of a Python list literal, e.g. "['a', 'b']". Doubled
        double-quotes ('""') are collapsed to a single '"' before parsing.

    Returns
    -------
    list
        The parsed items. An empty list is returned (after printing the
        error) when the value is not a string, cannot be parsed, or does not
        describe a list/tuple.
    """
    try:
        collapsed = column.replace('""', '"')
    except Exception as e:
        # Non-string cells (e.g. NaN) have no usable .replace().
        print(f"Error converting to list: {e}")
        return []

    # The collapsed text is tried first (the historical behaviour). The raw
    # text is a fallback because collapsing '""' corrupts legitimate
    # empty-string items such as '["", "a"]'.
    candidates = [collapsed]
    if column != collapsed:
        candidates.append(column)

    first_error = None
    for text in candidates:
        try:
            parsed = ast.literal_eval(text)
        except Exception as e:
            if first_error is None:
                first_error = e
            continue
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        # A valid literal that is not a list (str, int, dict, ...) must not
        # leak out: callers iterate over and call .index() on the result.
        if first_error is None:
            first_error = TypeError(f"expected a list, got {type(parsed).__name__}")

    print(f"Error converting to list: {first_error}")
    return []
    
def get_rank(row):
    """Return the 1-based position of the expected hit in a row's results.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide 'search_result (k=100)' (the ordered hits) and
        'expected' (the hit to look for).

    Returns
    -------
    int or float
        The 1-based rank of row['expected'] within the hits, or
        float('inf') when it is absent or the row has no hit list at all.
    """
    hits = row['search_result (k=100)']
    # Rows created by reindexing onto queries with no hits carry NaN instead
    # of a list; treat them as "not found" rather than raising AttributeError.
    if not isinstance(hits, (list, tuple)):
        return float('inf')
    try:
        return hits.index(row['expected']) + 1
    except ValueError:
        return float('inf')  # Expected hit is not among the results
    

# Recall@k for every requested cutoff k
def calculate_recall_scores(df_split, k_values):
    """Return recall@k for each k in ``k_values``.

    Recall@k is the fraction of rows in ``df_split`` whose 'rank' value is
    less than or equal to k. The returned list is ordered like ``k_values``.
    """
    return [(df_split['rank'] <= cutoff).mean() for cutoff in k_values]

    
def process_results(df):
    """Parse a retrieval-results frame and add the rank of the expected hit.

    The frame is modified in place and also returned. Steps:
      1. 'search_result (k=100)' is parsed with ``convert_to_list``.
      2. Double-quote characters are stripped from both ends of 'expected'.
      3. Items equal to the row's 'query' are dropped from the hit list.
      4. 'rank' is computed per row with ``get_rank``.
    """
    hits_col = 'search_result (k=100)'

    def _drop_self_hit(row):
        # Keep every hit except the query itself, preserving order.
        return [hit for hit in row[hits_col] if hit != row['query']]

    df[hits_col] = df[hits_col].map(convert_to_list)
    df['expected'] = df['expected'].map(lambda label: label.strip('"'))
    df[hits_col] = df.apply(_drop_self_hit, axis=1)
    df['rank'] = df.apply(get_rank, axis=1)
    return df

# Parse both embedding-based result tables and rank the expected hit in each.
esm2_results = process_results(esm2_results)
glm2_results = process_results(glm2_results)

# Sanity check: after processing, no gLM2 query may appear in its own hits.
for _, row in glm2_results.iterrows():
    assert row['query'] not in row['search_result (k=100)'], f"Self-hit found for query {row['query']}"

# Put the MMseqs2 rows in the same query order as the ESM2 table, then attach
# the expected hit (aligned on the fresh positional index) and its rank.
mmseqs_results = (
    mmseqs_results
    .set_index('query')
    .reindex(esm2_results['query'])
    .reset_index()
    .assign(expected=esm2_results['expected'])
)
mmseqs_results['rank'] = mmseqs_results.apply(get_rank, axis=1)

# Recall@k is evaluated for k = 1..100
k_values = range(1, 101)
glm2_recall_scores, esm2_recall_scores, mmseqs_recall_scores = (
    calculate_recall_scores(results, k_values)
    for results in (glm2_results, esm2_results, mmseqs_results)
)
# BLASTp is the ground truth, so its recall is 1.0 at every k
blastp_recall_scores = [1.0] * len(k_values)


plt.figure(figsize=(6, 6), dpi=300)

# One recall curve per method, drawn in this order (which is also the legend order).
curves = [
    (glm2_recall_scores, dict(label='Gaia', color='#2edd97')),
    (esm2_recall_scores, dict(label='ESM2', color='#00A5FF')),
    (mmseqs_recall_scores, dict(label='MMseqs2', color='#FFA500')),
    (blastp_recall_scores, dict(label='BLASTp', color='gray', linestyle='--')),
]
for scores, line_style in curves:
    plt.plot(k_values, scores, **line_style)

# Axis labels, legend and title
plt.xlabel('K', fontsize=14)
plt.ylabel('Recall', fontsize=14)
plt.legend(loc='lower right', fontsize=12)
plt.title('Protein sequence sensitivity', fontsize=16)
plt.grid(False)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Only K = 1..10 is shown, although recall was computed for every k in k_values.
plt.xlim(1, 10)