In [16]:
import pickle
from matplotlib import pyplot as plt

In [19]:
def _read_pickle(path):
    """Open ``path`` in binary mode and return the object unpickled from it.

    NOTE: pickle can execute arbitrary code while loading, so only point
    this at files you trust.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# One pickled result set per search method, loaded in the same order as before.
results_mmseqs = _read_pickle("data/mmseqs_results_3k.pkl")
results_glm2 = _read_pickle("data/glm2_results_3k.pkl")
results_esm2 = _read_pickle("data/esm2_results_3k.pkl")
results_blastp = _read_pickle("data/blastp_results_3k.pkl")


In [None]:
# A query counts as a hit when one of its top-K scores is >= this value.
match_pct_threshold = 0.7
# Cut-offs K = 1..10 at which recall is evaluated.
k_values = range(1, 11)
# Empty placeholder list; nothing in this cell reads it.
plot_df = []

def has_context_matches(result, k, threshold):
    """Return 1 if any of the first ``k`` entries of ``result`` reaches ``threshold``.

    Args:
        result: Sliceable sequence of comparable scores for one query
            (presumably already ordered by rank -- verify against the
            code that produced the pickles). ``None`` or an empty
            sequence counts as "no match".
        k: Number of leading entries to inspect (``result[:k]``).
        threshold: Minimum score (inclusive) for an entry to count as a match.

    Returns:
        int: 1 when at least one inspected score is ``>= threshold``, else 0.
        An empty slice (for example ``k == 0``) yields 0 instead of raising.
    """
    if not result:
        return 0
    # any() replaces max(): it short-circuits on the first hit, tolerates an
    # empty slice, and is not thrown off by a NaN that happens to come first
    # (max([nan, 0.9]) evaluates to nan, which never passes the comparison).
    return int(any(score >= threshold for score in result[:k]))

def get_recall_at_k(results, k, threshold):
    """Fraction of queries with at least one match among their first ``k`` entries.

    Args:
        results: Sized collection with one score sequence per query; each
            item is passed unchanged to ``has_context_matches``.
        k: Number of leading entries inspected per query.
        threshold: Minimum score (inclusive) for an entry to count as a match.

    Returns:
        float: Number of queries with a match divided by ``len(results)``.
        Returns 0.0 when ``results`` is empty, where the ratio is undefined,
        instead of raising ``ZeroDivisionError``.
    """
    total = len(results)
    if total == 0:
        return 0.0
    # Generator expression: no need to build a throwaway list just to sum it.
    hits = sum(has_context_matches(result, k, threshold) for result in results)
    return hits / total


# One recall@K list per search method, filled in step with k_values.
mmseqs_recalls, glm2_recalls, esm2_recalls, blastp_recalls = [], [], [], []

# Pair each method's results with the list collecting its recall values.
curves = (
    (results_mmseqs, mmseqs_recalls),
    (results_glm2, glm2_recalls),
    (results_esm2, esm2_recalls),
    (results_blastp, blastp_recalls),
)
for k in k_values:
    for method_results, recalls in curves:
        recalls.append(get_recall_at_k(method_results, k, match_pct_threshold))

# Recall@K curves for the four search methods on a single 6x6 inch, 300 dpi figure.
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
for recalls, label, color in (
    (mmseqs_recalls, 'MMseqs2', '#FFA500'),
    (glm2_recalls, 'Gaia', '#2edd97'),
    (esm2_recalls, 'ESM2', '#00A5FF'),
    (blastp_recalls, 'BLASTp', 'gray'),
):
    ax.plot(k_values, recalls, label=label, color=color)

ax.set_xlabel('K', fontsize=14)
ax.set_ylabel('Recall', fontsize=14)
ax.set_title('Genomic context sensitivity', fontsize=16)
ax.legend(loc='lower right', fontsize=12)
ax.grid(False)
# tick_params sets the label size on the axis itself, so it also applies to
# ticks created after the x-limits change below; plt.xticks(fontsize=...)
# only restyles the tick labels that exist at the moment it is called.
ax.tick_params(axis='both', labelsize=14)
ax.set_xlim(1, 10)

# Render explicitly so the cell does not echo the (1.0, 10.0) tuple returned
# by the x-limit call as its output.
plt.show()

