In [None]:
import glob
import os
import pickle
from math import log2

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from statsmodels.nonparametric.smoothers_lowess import lowess

from headers import headers

In [None]:
# Human-readable name for each search type, used for figure legends, bar-chart
# tick labels and the "search_type" column of the summary TSV.
# The tool prefix (text before the first "_") is what `colors` is keyed on.
tools_full_names = dict([
    ("fs_pref", "Foldseek"),
    ("fs_exh", "Foldseek, without prefilter"),
    ("hmmscan_pref", "HMMER"),
    ("hmmscan_exh", "HMMER, without prefilter"),
    ("mm_pref", "MMseqs"),
    ("mm_exh", "MMseqs, without prefilter"),
    ("reseek_fast", "Reseek (fast)"),
    ("reseek_10_fast", "Reseek (fast) evalue 10"),
    ("reseek_sens", "Reseek, sensitive"),
    ("reseek_exh", "Reseek, without prefilter"),
    ("tm_exh", "TM-align"),
])

In [None]:
# Curve/bar colour per tool, keyed by the search-type prefix before the first "_".
colors = dict(
    fs='#1f77b4',
    mm='#ff7f0e',
    hmmscan='#2ca02c',
    reseek='#d62728',
    tm='#9467bd',
)
# Line style per search mode: solid for "_exh" runs (without prefilter),
# dashed for "_pref" runs (with prefilter).
line_style = dict(exh="-", pref="--")

In [None]:
# Base directories: figures are written under fig_dir, processed inputs are
# read from (and summaries written to) data_dir.
fig_dir = "../figures/"
data_dir = "../data/"

# Per-search-type results (precision/recall table, max F1 and its CI, etc.)
# precomputed by an earlier pipeline step.
# NOTE(review): pickle.load can execute arbitrary code -- only load pickles
# produced by this project's own pipeline, never downloaded/untrusted ones.
with open(f"{data_dir}/processed/f1_data.pkl", mode='rb') as pkl_handle:
    f1_data = pickle.load(pkl_handle)

In [None]:
# Precision-recall curve per search method, plus a TSV summarising each
# method's max-F1 operating point.
# Reads (from earlier cells): f1_data, tools_full_names, colors, line_style,
# data_dir, fig_dir. Defines sig_thresh_max_f1: search type -> E-value
# (or TM-score for TM-align) at which that method reaches its max F1.

# E-value bins are log2(E-value) (they are converted back with 2 ** bin below);
# keep only bins with E-value >= 1e-9.
MIN_LOG2_EVALUE_BIN = -9 * log2(10)
# TM-align bins are converted to a TM-score with -bin / 100 below; keep bins >= -90.
MIN_TM_BIN = -90

SUMMARY_COLUMNS = ["search_type", "max_f1", "max f1 lower confidence interval",
                   "max f1 upper confidence interval",
                   "precision at max f1", "recall at max f1",
                   "evalue(or TM score) bin of highest f1"]
# Derived from data_dir (the same base directory the f1_data input comes from)
# instead of a separately hard-coded "../data/" path.
summary_path = os.path.join(data_dir, "processed", "precision_recall_summary.tsv")

fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
sig_thresh_max_f1 = {}

# The context manager flushes and closes the summary even if the loop raises,
# e.g. on a search type missing from tools_full_names or colors.
with open(summary_path, "w") as summary_file:
    summary_file.write("\t".join(SUMMARY_COLUMNS) + "\n")

    for search_type, result in f1_data.items():
        if "reseek_fast" in search_type:
            continue  # Reseek fast data looked odd -- deliberately excluded.

        tool_name = search_type.split("_")[0]
        is_tm = tool_name == "tm"

        # Drop bins beyond the thresholds defined above (the most significant hits).
        precision_recall_df = result["precision_vs_recall"]
        min_bin = MIN_TM_BIN if is_tm else MIN_LOG2_EVALUE_BIN
        precision_recall_df = precision_recall_df[precision_recall_df["evalue_bin"] >= min_bin]

        # Solid line for "_exh" runs (without prefilter), dashed otherwise.
        linestyle = line_style["exh" if "_exh" in search_type else "pref"]
        ax.plot(precision_recall_df["precision"], precision_recall_df["recall"],
                label=f"{tools_full_names[search_type]} (Max F1={result['max_f1']:.3f})",
                linestyle=linestyle, color=colors[tool_name])

        # Convert the max-F1 bin back to an E-value (or a TM-score for TM-align).
        best_bin = result["max_f1_evalue_bin"]
        evalue_bin = -best_bin / 100 if is_tm else 2 ** best_bin

        row = [tools_full_names[search_type], result["max_f1"],
               result["f1_ci_lower"], result["f1_ci_upper"],
               result["max_f1_precision"], result["max_f1_recall"], evalue_bin]
        summary_file.write("\t".join(f"{value}" for value in row) + "\n")

        sig_thresh_max_f1[search_type] = evalue_bin

ax.legend()
ax.set_xlabel("Precision")
ax.set_ylabel("Recall")
ax.grid(True, alpha=0.3)
fig.savefig(f"{fig_dir}/precision_recall_split_vs_split.png")
plt.show()

In [None]:
# Bar chart of the threshold at which each search method reaches its maximum
# F1: an E-value for most tools, a TM-score for TM-align. Values come from
# `sig_thresh_max_f1`, built by the precision-recall cell above.
# NOTE(review): E-values and TM-scores share one linear y-axis, so very small
# E-values render as near-zero bars; the text labels carry the actual value.

# Fixed display order of the search types.
sorted_x = ['hmmscan_pref', 'hmmscan_exh', 'mm_pref',
            'mm_exh', 'fs_pref', 'fs_exh', 'reseek_sens',
            'reseek_exh', 'tm_exh']

# Fail with an explicit message (instead of a bare KeyError mid-comprehension)
# if the precision-recall cell produced no threshold for an expected search type.
missing = [key for key in sorted_x if key not in sig_thresh_max_f1]
if missing:
    raise KeyError(f"no max-F1 threshold for search types {missing}; re-run the precision-recall cell")

x = [tools_full_names[key] for key in sorted_x]
y = [sig_thresh_max_f1[key] for key in sorted_x]

fig, ax = plt.subplots(dpi=300)
bars = ax.bar(x, y)

# Write each bar's value just above it. '.3g' keeps float reprs such as
# 2**-20 == 9.5367431640625e-07 short enough not to overlap neighbouring labels.
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2,  # horizontal centre of the bar
            height,                             # top of the bar
            f"{height:.3g}",
            ha="center", va="bottom")

ax.set_xlabel("Search method")
ax.set_ylabel("E-value bin or TM-score")
plt.xticks(rotation=30, ha="right")  # long tool names: rotate and right-align
fig.savefig(f"{fig_dir}/ali_sig_with_highest_f1.png", bbox_inches="tight")
plt.show()