In [6]:
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
def get_rank(row, level='family'):
    """Return the 1-based position of the first hit in the query's SCOP group.

    SCOP IDs such as 'a.1.2.3' are dot-separated. The comparison uses the
    whole ID for 'family', the first three components for 'superfamily' and
    the first two components for 'fold'.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing
            'family' (the query's SCOP ID string) and 'hits_family'
            (an iterable of SCOP ID strings; presumably ordered best hit
            first -- confirm against the search output).
        level: One of 'family', 'superfamily' or 'fold'.

    Returns:
        The 1-based index of the first hit whose (truncated) ID equals the
        query's, or ``float('inf')`` when no hit matches.

    Raises:
        ValueError: If ``level`` is not one of the supported names. Without
            this check a misspelt level would silently be scored at family
            level.
    """
    # Number of leading ID components compared at each level; None means the
    # ID is compared as-is, without splitting.
    depth_by_level = {'family': None, 'superfamily': 3, 'fold': 2}
    if level not in depth_by_level:
        raise ValueError(
            f"Unknown level {level!r}; expected one of {sorted(depth_by_level)}"
        )
    depth = depth_by_level[level]

    def truncate(scop_id):
        # Keep only the leading components relevant for the requested level.
        if depth is None:
            return scop_id
        return '.'.join(scop_id.split('.')[:depth])

    target = truncate(row['family'])
    return next(
        (
            position
            for position, hit in enumerate(row['hits_family'], start=1)
            if truncate(hit) == target
        ),
        float('inf'),
    )
    
# Lookup table built from 'scop_lookup.tsv': every non-blank line is split on
# whitespace into a (key, value) pair. Keys are later matched against the
# query/hit IDs of the search results; values are presumably SCOP IDs such as
# 'a.1.2.3' -- confirm against the file.
# The file is opened in a `with` block so the handle is closed once the table
# is built, and blank lines are skipped so they cannot break the pair parsing.
with open('scop_lookup.tsv') as lookup_file:
    scop_lookup = dict(line.split() for line in lookup_file if line.strip())

def parse_search_results(file):
    """Load a search result table and rank its hits at each SCOP level.

    ``file`` is read as a header-less, tab-separated table with the columns
    'id', 'hits' and 'score'. Rows in which a query hits itself are dropped,
    the remaining hits are gathered into one list per query, and both the
    query and its hits are translated through ``scop_lookup``. A query whose
    only hits are itself therefore does not appear in the output.

    Returns:
        DataFrame with one row per query and the columns 'id', 'hits',
        'family', 'hits_family', 'rank_family', 'rank_superfamily' and
        'rank_fold' (ranks as computed by ``get_rank``).
    """
    table = pd.read_csv(file, sep='\t', header=None, names=['id', 'hits', 'score'])

    # A query matching itself carries no information, so discard those rows.
    not_self_hit = table['id'] != table['hits']
    table = table[not_self_hit]

    # One row per query, holding the list of its hit IDs.
    hits_per_query = table.groupby('id')['hits'].apply(list)
    results = hits_per_query.reset_index()

    # Attach the SCOP classification of the query and of every hit.
    results['family'] = results['id'].apply(lambda query_id: scop_lookup[query_id])
    results['hits_family'] = results['hits'].apply(
        lambda hit_ids: [scop_lookup[hit_id] for hit_id in hit_ids]
    )

    # Rank of the first correct hit at each level of the hierarchy.
    for rank_level in ('family', 'superfamily', 'fold'):
        results[f'rank_{rank_level}'] = results.apply(
            lambda record, rank_level=rank_level: get_rank(record, level=rank_level),
            axis=1,
        )

    return results

# Parse the result table of every search method. The paths are processed in
# the order listed, and each parsed table is bound to the matching name.
(
    blastp_results,
    mmseqs_results,
    foldseek_results,
    esm2_results,
    glm2_results,
) = map(
    parse_search_results,
    [
        'search_result/blastp.tsv',
        'search_result/mmseqs2.tsv',
        'search_result/foldseek.tsv',
        'search_result/esm2_t33_650M_UR50D.tsv',
        'search_result/gLM2_650M_embed.tsv',
    ],
)

In [None]:
# Draw one Recall@K figure per SCOP level, comparing all search methods.
for level in ['family', 'superfamily', 'fold']:

    # Helper that turns a results table into a Recall@K curve.
    # NOTE: it reads `level` from the enclosing loop variable at call time
    # rather than taking it as an argument.
    def calculate_recall_scores(df_split, k_values):
        """Return, for each k in ``k_values``, the fraction of rows in
        ``df_split`` whose ``rank_<level>`` value is at most k."""
        recall_scores = []
        for k in k_values:
            # Mean of a boolean column = share of queries with a correct hit
            # within the top k; an infinite rank never satisfies the test.
            recall_at_k = (df_split[f'rank_{level}'] <= k).mean()
            recall_scores.append(recall_at_k)
        return recall_scores

    # Cutoffs K = 1..100 (the x-axis is later restricted to 1..30 via xlim)
    k_values = range(1, 101)

    # Recall@K curve of every method at the current level
    esm2_recall_scores = calculate_recall_scores(esm2_results, k_values)
    glm2_recall_scores = calculate_recall_scores(glm2_results, k_values)
    foldseek_recall_scores = calculate_recall_scores(foldseek_results, k_values)
    mmseqs_recall_scores = calculate_recall_scores(mmseqs_results, k_values)
    blastp_recall_scores = calculate_recall_scores(blastp_results, k_values)

    # New figure per level; the plotting order below also fixes the legend order.
    plt.figure(figsize=(6, 6), dpi=300)
    # The gLM2 results are shown under the label 'Gaia'.
    plt.plot(k_values, glm2_recall_scores, label='Gaia', color='#2EDD97')
    plt.plot(k_values, esm2_recall_scores, label='ESM2', color='#00A5FF')
    plt.plot(k_values, mmseqs_recall_scores, label='MMseqs2', color='#FFA500')
    plt.plot(k_values, blastp_recall_scores, label='BLASTp', color='gray')
    plt.plot(k_values, foldseek_recall_scores, label='Foldseek', color='purple')

    # Axis labels, legend and title
    plt.xlabel('K', fontsize=14)
    plt.ylabel('Recall@K', fontsize=14)
    plt.legend(loc='lower right', fontsize=12)
    plt.title(f'Protein structure {level} sensitivity', fontsize=16)
    plt.grid(False)

    # Tick label size, and show only K = 1..30 of the computed range
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlim(1, 30)