In [None]:
import glob
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter, MaxNLocator
from scipy.interpolate import interp1d
from sklearn.utils import resample

from cve import cve
from find_optimal_bins import find_optimal_bins
from sens_up2_first_fp import auc_tps_bef_1st_fp

In [None]:
# Output (figures) and input (data) locations, relative to the notebook's working directory.
fig_dir, data_dir = "../figures/", "../data/"

In [None]:
# Load the pickled SFFP results from data/processed.
# NOTE: pickle.load can execute arbitrary code - only load trusted, locally generated files.
with open(f"{data_dir}/processed/sffp.pkl", mode='rb') as file:
    loaded_data = pickle.load(file)

# Unpack the individual result objects stored under fixed keys in the pickle.
cum_sens_up2first_fp_dict_fam, cum_sens_up2first_fp_dict_clan = loaded_data["sffp"], loaded_data["sffp_clan"]
auc_sffp, ci_info = loaded_data["auc_sffp"], loaded_data["ci_info"]
frac_sens_up2first_fp_dict_fam = loaded_data["frac_info"]

In [None]:
# Display name for each tool key; the insertion order is the plotting order used below.
db_name = dict(
    mm="MMseqs",
    fs_cut="Foldseek (fs_cut)",
    cif_cut="Foldseek (cif_cut)",
    reseek="Reseek",
    tm="TM-align",
)
# Tool keys present in the loaded family-level results.
tools = cum_sens_up2first_fp_dict_fam.keys()

# Overall performance plots
Sensitivity up to the first false positive (SFFP), evaluated at the Pfam family level and at the clan level.

## Sensitivity up to first FP

In [None]:
def plot_sens_plot_from_cum_data(sens_dict, tool, column, auc_sffp=None, sffp_ci_lower=None, sffp_ci_upper=None):
    """Plot one tool's cumulative SFFP curve (optionally with a CI band) on the current axes.

    Parameters
    ----------
    sens_dict : dict
        Maps tool key -> DataFrame with a ``"cum_sum_frac"`` column (x axis) and ``column``.
    tool : str
        Key into ``sens_dict`` and into the notebook-level ``db_name`` (legend label).
    column : str
        Column of ``sens_dict[tool]`` plotted on the y axis.
    auc_sffp : float, optional
        AUC shown in the legend. If None, it is computed from the plotted curve
        with the trapezoidal rule. (This parameter shadows the notebook-level
        ``auc_sffp`` dict inside the function.)
    sffp_ci_lower, sffp_ci_upper : array-like, optional
        Lower/upper confidence bounds aligned with ``cum_sum_frac``. They must be
        given together; when given, a shaded band is drawn in the line's colour.

    Raises
    ------
    ValueError
        If exactly one of the two CI bounds is provided.
    """
    # Validate before drawing anything so a bad call does not leave a half-drawn plot.
    if (sffp_ci_lower is None) != (sffp_ci_upper is None):
        raise ValueError("sffp_ci_lower and sffp_ci_upper must be given together")

    sens_df = sens_dict[tool]
    x_axis = sens_df["cum_sum_frac"]
    y_axis = sens_df[column]
    if auc_sffp is None:
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; fall back for older NumPy.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        auc_sffp = trapezoid(y_axis, x_axis)
    (line,) = plt.plot(x_axis, y_axis, label=f"{db_name[tool]}: AUC = {auc_sffp:.2f}")
    if sffp_ci_lower is not None:
        # Reuse the line colour explicitly so the band always matches its curve,
        # independent of how matplotlib cycles colours between lines and patches.
        plt.fill_between(x_axis, sffp_ci_lower, sffp_ci_upper, color=line.get_color(), alpha=0.2)

### Family level plots

In [None]:
plt.figure(dpi=300)

# Plot each tool's family-level SFFP curve with its 95% CI band, and report the
# AUC confidence interval both on screen and in a text file.
# The `with` block guarantees the file is closed even if plotting raises.
# NOTE(review): despite the .tsv suffix, lines are written as "<tool> : <lo>-<hi>".
with open(f"{data_dir}/processed/ci_auc_sffp_fam.tsv", 'w') as outfile:
    print("AUC values with 95% CI:")
    for tool in db_name.keys():
        tool_ci = ci_info[tool]
        plot_sens_plot_from_cum_data(cum_sens_up2first_fp_dict_fam, tool, "tp_bef_fp_frac_pfam",
                                     auc_sffp[tool], tool_ci['sffp_ci_lower'], tool_ci['sffp_ci_upper'])

        ci_line = f"{db_name[tool]} : {tool_ci['auc_ci_lower']:.3f}-{tool_ci['auc_ci_upper']:.3f}"
        outfile.write(ci_line + "\n")
        print(ci_line)
plt.xlabel("Fraction of queries")
plt.ylabel("SFFP")
plt.legend()
plt.savefig(f"{fig_dir}/sens_up2_1stfp_family_level.png")
plt.show()

### Clan-level plots

In [None]:
plt.figure(dpi=300)

# Clan-level curves: no AUC/CI arguments are passed, so the helper computes the AUC itself
# and draws no confidence band.
for tool in db_name:
    plot_sens_plot_from_cum_data(cum_sens_up2first_fp_dict_clan, tool, "tp_bef_fp_frac_clan")

plt.xlabel("Fraction of queries")
plt.ylabel("SFFP")
plt.legend()
plt.savefig(f"{fig_dir}/sens_up2_1stfp_clan_level.png")
plt.show()

## Coverage vs. Error plots

In [None]:
# Coverage-vs-error data keyed by tool; each value is unpacked as (x, y, y_lower, y_upper) when plotted.
# NOTE: pickle.load can execute arbitrary code - only load trusted, locally generated files.
with open(f"{data_dir}/processed/cve_ci.pkl", mode='rb') as file:
    cve_dict = pickle.load(file)

In [None]:
plt.figure(dpi=300)

def format_func(value, tick_number):
    """Render a tick at log10 position ``value`` as a power of ten (e.g. ``$10^{-3}$``)."""
    # int() would truncate a fractional tick such as -1.5 into the wrong decade (10^-1);
    # integer ticks print as integers, anything else is shown exactly.
    exponent = int(value) if float(value).is_integer() else f"{value:g}"
    return f'$10^{{{exponent}}}$'

for tool in db_name.keys():
    x, y_pre, y_lower_pre, y_upper_pre = cve_dict[tool]
    # Plot the error axis (and its CI) in log10 space; ticks are relabelled as powers of ten below.
    y, y_lower, y_upper = np.log10(y_pre), np.log10(y_lower_pre), np.log10(y_upper_pre)

    (line,) = plt.plot(x, y, label=f"{db_name[tool]}")
    # Explicit colour keeps each CI band matched to its curve.
    plt.fill_between(x, y_lower, y_upper, color=line.get_color(), alpha=0.2)

plt.xlabel('Sensitivity')
plt.ylabel('FPEPQ')
# Only whole decades get ticks, so every label is an exact power of ten.
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
plt.gca().yaxis.set_major_formatter(FuncFormatter(format_func))
plt.legend()
plt.savefig(f"{fig_dir}/cve.png")
plt.show()

# Stratified performance plots

### Different seed characteristics
For pLDDT, the bin edges are known in advance (the standard AlphaFold confidence bands at 50, 70 and 90). For the other characteristics they have to be determined from the data: we start with 20 bins and repeatedly merge close bins until 4–5 bins are left. Performance is stratified by:
* size
* average plddt
* secondary structure state
* average contact number
* length normalized number of secondary structure state transitions
* sequence identity of the query with the other members of its family

Before binning, we average the per-query results of MMseqs (mm), TM-align (tm), Reseek and Foldseek (cif_cut), and search for bin edges on this average.

In [None]:
# Query (seed) identifiers, taken from the MMseqs per-query results.
query_ids = frac_sens_up2first_fp_dict_fam["mm"].loc[:, "seed_id"]

In [None]:

tools4stratification = ["mm", "reseek", "cif_cut", "tm"]
markers = ["o", "s", "v", "*"]  # one marker per entry of tools4stratification

# Per-query mean of the four stratified tools, aligned on seed_id.
# The summation order (mm, tm, reseek, cif_cut) is fixed so floating-point results are reproducible.
mean_tool_order = ["mm", "tm", "reseek", "cif_cut"]
summed_frac = frac_sens_up2first_fp_dict_fam[mean_tool_order[0]].set_index("seed_id")
for mean_tool in mean_tool_order[1:]:
    summed_frac = summed_frac + frac_sens_up2first_fp_dict_fam[mean_tool].set_index("seed_id")
mean_df = (summed_frac / 4).reset_index()

In [None]:
def bin_column_and_plot(df, stratifier_col, perf_func, bin_edges, plot_options):
    """Bin ``df`` by ``stratifier_col`` and plot ``perf_func`` evaluated on each bin.

    Bins are half-open ``[edge_i, edge_{i+1})`` (``pd.cut`` with ``right=False``).
    Values equal to the last edge, which ``pd.cut`` would drop, are put in the
    last bin. Rows below ``bin_edges[0]``, above ``bin_edges[-1]`` or with a
    missing value get no bin and are ignored.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-query results containing ``stratifier_col``. Not modified (a copy is binned).
    stratifier_col : str
        Numeric column used for binning.
    perf_func : callable
        Called on each bin's sub-DataFrame; its (assumed scalar) result is plotted.
    bin_edges : sequence of numbers
        Increasing bin edges (at least two).
    plot_options : dict
        Keyword arguments forwarded to ``plt.plot``.
    """
    df_plot = df.copy()
    # Labels must describe the half-open [lo, hi) intervals produced by right=False.
    labels = [f"[{lo}, {hi})" for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]
    if len(labels) > 1:
        labels[0] = f"<{bin_edges[1]}"
        labels[-1] = f">={bin_edges[-2]}"
    df_plot["bin"] = pd.cut(df_plot[stratifier_col], bins=bin_edges, labels=labels, right=False)
    # right=False excludes the upper edge itself; assign the maximum to the last bin instead of dropping it.
    df_plot.loc[df_plot[stratifier_col] == bin_edges[-1], "bin"] = labels[-1]
    # observed=False keeps empty bins (and pins pandas' historical default), so all tools share the x categories.
    bins_and_performances = (
        df_plot.groupby("bin", observed=False)
        .apply(perf_func)
        .reset_index()
        .rename(columns={0: "performance"})
    )
    plt.plot(bins_and_performances["bin"].astype(str), bins_and_performances["performance"], **plot_options)

### Stratification by size

In [None]:
def add_size(df):
    """Add a ``size`` column (end - start + 1) parsed from ``df["seed_id"]``.

    ``seed_id`` is expected to end in ``-<start>-<end>``. Only the last two
    '-'-separated fields are parsed, so a '-' inside the leading identifier is
    tolerated. The column is added to ``df`` in place and ``df`` is returned.
    """
    # Split once, from the right: (prefix, start, end).
    parts = df["seed_id"].str.rsplit("-", n=2, expand=True)
    df["size"] = parts[2].astype(int) - parts[1].astype(int) + 1
    return df

In [None]:
# Averaged per-query results with each seed's size (end - start + 1 from seed_id) attached.
mean_df_size = add_size(mean_df.copy())

In [None]:
def auc_1st_fp_func(df):
    """Performance metric per query subset: ``auc_tps_bef_1st_fp`` on the family-level column."""
    return auc_tps_bef_1st_fp(df, "tp_bef_fp_frac_pfam")


# Search for 5 size bins, starting from 20, on the averaged per-query results.
optimal_bins_size = find_optimal_bins(mean_df_size, "size", auc_1st_fp_func, initial_bins=20, final_bins=5)

In [None]:
# Show the suggested size ranges.
# NOTE(review): size_bin_edges below appears to be transcribed by hand from this output - re-check if the data change.
optimal_bins_size["range"]

In [None]:
# Hard-coded size bin edges (presumably taken from optimal_bins_size above); bins are built with pd.cut(right=False).
size_bin_edges = [12, 32, 47, 104, 175, 1728]

In [None]:
# Echo the tools compared in the stratified plots below (inspection only).
tools4stratification

In [None]:
plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP within each size bin.
for idx, tool in enumerate(tools4stratification):
    tool_frac = frac_sens_up2first_fp_dict_fam[tool].copy()
    df_size = add_size(tool_frac)
    bin_column_and_plot(
        df_size,
        "size",
        perf_func=auc_1st_fp_func,
        bin_edges=size_bin_edges,
        plot_options=dict(marker=markers[idx], linestyle="", label=db_name[tool]),
    )

plt.xlabel("Size range")
plt.ylabel("AUC_SFFP")
plt.legend()
plt.savefig(f"{fig_dir}/size_stratified_performance_sample_pf.png")
plt.show()

### Stratification by pLDDT

In [None]:
# Fixed pLDDT bins at the AlphaFold confidence cutoffs 50/70/90, so no bin search is needed.
plddt_bin_edges = [0, 50, 70, 90, 100]
plddt_df_all = pd.read_csv(f"{data_dir}/processed/pfam_avg_plddt.tsv", delimiter="\t")

In [None]:
plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP within each average-pLDDT band.
for idx, tool in enumerate(tools4stratification):
    with_plddt = frac_sens_up2first_fp_dict_fam[tool].copy().merge(plddt_df_all, on="seed_id")
    series_style = {"marker": markers[idx], "linestyle": "", "label": db_name[tool]}
    bin_column_and_plot(with_plddt, "avg_plddt", perf_func=auc_1st_fp_func,
                        bin_edges=plddt_bin_edges, plot_options=series_style)

plt.xlabel("avg_pLDDT")
plt.ylabel("AUC_SFFP")
plt.legend()

plt.savefig(f"{fig_dir}/plddt_stratified_performance_sample_pf.png")
plt.show()

### Stratification by contact number

In [None]:
# Average contact number per seed, attached to the averaged per-query results for the bin search.
cn_df = pd.read_csv(f"{data_dir}/processed/avg_contact_num.tsv", delimiter="\t")
mean_df_cn = pd.merge(mean_df, cn_df, on="seed_id")

In [None]:
# Search for 5 contact-number bins, starting from 20, on the averaged per-query results.
optimal_bins_cn = find_optimal_bins(mean_df_cn, "avg_contact_num", auc_1st_fp_func, initial_bins=20, final_bins=5)

In [None]:
# Inspect the suggested contact-number bins (cn_bin_edges below appears to be taken from this output).
optimal_bins_cn

In [None]:
# Hard-coded contact-number bin edges (presumably rounded from optimal_bins_cn above) - update if the data change.
cn_bin_edges = [3.64, 6.4, 7.58, 7.85, 8.55, 14.3]

In [None]:
plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP within each average-contact-number bin.
for idx, tool in enumerate(tools4stratification):
    with_cn = frac_sens_up2first_fp_dict_fam[tool].copy().merge(cn_df, on="seed_id")
    bin_column_and_plot(
        with_cn,
        "avg_contact_num",
        perf_func=auc_1st_fp_func,
        bin_edges=cn_bin_edges,
        plot_options=dict(marker=markers[idx], linestyle="", label=db_name[tool]),
    )

plt.xlabel("Average contact number")
plt.ylabel("AUC_SFFP")
plt.legend()

plt.savefig(f"{fig_dir}/contact_number_stratified_performance_sample_pfam.png")
plt.show()

### Stratification by secondary structure

In [None]:
# Per-seed secondary-structure summary (h/e/c fractions and transition counts are used below).
ss_info_df = pd.read_csv(f"{data_dir}/processed/ss_info_pfam.tsv", delimiter="\t")

In [None]:
# Label each seed with its main secondary-structure state: the state covering >= 50% of it.
# If several qualify, priority is Coil > Sheet > Helix; seeds with no majority state are dropped.
state_conditions = [
    ss_info_df["c_frac"] >= 0.5,
    ss_info_df["e_frac"] >= 0.5,
    ss_info_df["h_frac"] >= 0.5,
]
ss_state_df = ss_info_df.assign(
    main_state=np.select(state_conditions, ["Coil", "Sheet", "Helix"], default="NA")
)
ss_state_df = ss_state_df.loc[ss_state_df["main_state"] != "NA", ["seed_id", "main_state"]].reset_index(drop=True)

In [None]:
# NOTE(review): mean_df_ss is not used in the cells shown here; kept in case later cells rely on it.
mean_df_ss = mean_df.merge(ss_state_df, on="seed_id")
plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP for each main secondary-structure state.
for i, tool in enumerate(tools4stratification):
    stratified_perf = (
        frac_sens_up2first_fp_dict_fam[tool]
        .merge(ss_state_df, on="seed_id")
        .groupby("main_state")
        .apply(auc_1st_fp_func)
        .reset_index()
        .rename(columns={0: "performance"})
    )

    plot_options = {"marker": markers[i], "linestyle": "", "label": db_name[tool]}
    plt.plot(stratified_perf["main_state"], stratified_perf["performance"], **plot_options)
plt.xlabel("Major secondary structure state")
plt.ylabel("AUC_SFFP")
plt.legend()

plt.savefig(f"{fig_dir}/ss_state_stratified_performance_sample_pf.png")
plt.show()

### Stratification by transition between secondary structure states

In [None]:
# Length-normalised number of secondary-structure state transitions per seed.
tr_df = ss_info_df.loc[:, ["seed_id", "len_norm_tr_count"]]

In [None]:
# Attach transition frequencies to the averaged results, then search for 5 bins starting from 20.
mean_df_tr = pd.merge(mean_df, tr_df, on="seed_id")
optimal_bins_tr = find_optimal_bins(mean_df_tr, "len_norm_tr_count", auc_1st_fp_func, initial_bins=20, final_bins=5)

In [None]:
# Inspect the suggested transition-frequency bins (tr_bin_edges below appears to be based on this output).
optimal_bins_tr

In [None]:
# Hard-coded transition-frequency edges (presumably derived from optimal_bins_tr above).
tr_bin_edges = [0, 0.056, 0.076, 0.1, 0.25, 0.5]

plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP within each transition-frequency bin.
for idx, tool in enumerate(tools4stratification):
    with_tr = frac_sens_up2first_fp_dict_fam[tool].copy().merge(tr_df, on="seed_id")
    series_style = {"marker": markers[idx], "linestyle": "", "label": db_name[tool]}
    bin_column_and_plot(with_tr, "len_norm_tr_count", perf_func=auc_1st_fp_func,
                        bin_edges=tr_bin_edges, plot_options=series_style)

plt.xlabel("Length normalized secondary structure state transition frequency")
plt.ylabel("AUC_SFFP")
plt.legend()
plt.savefig(f"{fig_dir}/transition_stratified_performance_sample_pf.png")
plt.show()

### Stratification by intrafamily sequence identity

In [None]:
# Average intra-family percentage identity ("pi") per seed.
pi_df = pd.read_csv(f"{data_dir}/processed/avg_intra_fam_pident.tsv", delimiter="\t")

In [None]:
# Attach intra-family identity to the averaged results, then search for 4 bins starting from 20.
mean_df_pi = pd.merge(mean_df, pi_df, on="seed_id")
optimal_bins_pi = find_optimal_bins(mean_df_pi, "avg_intra_fam_pident", auc_1st_fp_func, initial_bins=20, final_bins=4)

In [None]:
# Inspect the suggested identity bins (pi_bin_edges below appears to be based on this output).
optimal_bins_pi

In [None]:
# Hard-coded identity edges in percent (presumably derived from optimal_bins_pi above).
pi_bin_edges = [5, 12, 13.5, 19.5, 100]

plt.figure(dpi=300)

# One marker series per tool: AUC_SFFP within each intra-family identity bin.
for idx, tool in enumerate(tools4stratification):
    with_pi = frac_sens_up2first_fp_dict_fam[tool].copy().merge(pi_df, on="seed_id")
    bin_column_and_plot(
        with_pi,
        "avg_intra_fam_pident",
        perf_func=auc_1st_fp_func,
        bin_edges=pi_bin_edges,
        plot_options=dict(marker=markers[idx], linestyle="", label=db_name[tool]),
    )

plt.xlabel("Average sequence identity with other members of family")
plt.ylabel("AUC_SFFP")
plt.legend()
plt.savefig(f"{fig_dir}/pident_stratified_performance_sample_pf.png")
plt.show()