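This notebook analyzes the results of the model sweep: it compares training sets, clip lengths, loss functions, and protein language model (PLM) x aggregation-method combinations, then saves summary tables, per-class plots, and confusion matrices for the best models.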

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import Counter
from matplotlib.lines import Line2D
from scipy.stats import linregress
from sklearn.metrics import precision_recall_curve, auc
import dotenv

import sys
import os
sys.path.append(os.path.abspath('../..'))
from utils.plotting_utils import *


plt.rcParams['pdf.fonttype'] = 42  # For PDF: embed text as text, not paths
plt.rcParams['svg.fonttype'] = 'none'  # For SVG: embed text as text, not paths
#plt.rcParams['font.family'] = 'Arial'

In [None]:
# Disabled legacy cell: the code is wrapped in a string literal, so it never runs.
# It concatenated the per-level test metrics, dropped the "TransformerPool" runs
# and wrote the result to avg_metrics_ankit.csv.
# NOTE(review): the valid_level* frames are read from the same test/ paths as the
# test_level* frames and are never used afterwards - confirm before re-enabling.
'''
dir_overall = "/scratch/groups/emmalu/seq2loc/experiments/sweep_analysis"
dir_runs = "/scratch/groups/emmalu/seq2loc/sweep_experiments"

test_level1 = pd.read_csv(f"{dir_overall}/test/test_metrics_level1/overall_metrics.csv")
test_level2 = pd.read_csv(f"{dir_overall}/test/test_metrics_level2/overall_metrics.csv")
test_level3 = pd.read_csv(f"{dir_overall}/test/test_metrics_level3/overall_metrics.csv")


valid_level1 = pd.read_csv(f"{dir_overall}/test/test_metrics_level1/overall_metrics.csv")
valid_level2 = pd.read_csv(f"{dir_overall}/test/test_metrics_level2/overall_metrics.csv")
valid_level3 = pd.read_csv(f"{dir_overall}/test/test_metrics_level3/overall_metrics.csv")

avg_metrics = pd.concat([test_level1, test_level2, test_level3])
avg_metrics = avg_metrics[avg_metrics.agg_method != "TransformerPool"].reset_index(drop=True)

avg_metrics.to_csv("avg_metrics_ankit.csv", index=False)
'''

In [None]:
#GET NECESSARY .ENV VARIABLES
# Paths to the sweep outputs come from the repository-level .env file.
dotenv.load_dotenv('../../.env')
SWEEP_EXP_DIR = os.getenv("SWEEP_EXP_DIR")
SWEEP_ANALYSIS_DIR = os.getenv("SWEEP_ANALYSIS_DIR")
if SWEEP_ANALYSIS_DIR is None:
    # Without this check the f-string below silently becomes "None/overall_metrics.csv".
    raise RuntimeError("SWEEP_ANALYSIS_DIR is not set; define it in ../../.env")

# All figures and tables produced by this notebook are written here.
FIG_DIR = "../../figures/sweep_models"
os.makedirs(FIG_DIR, exist_ok=True)

# Label orderings per hierarchy level; the last entry of each level (plastid) is dropped.
mappings = load_config("../../datasets/final/hierarchical_label_set.yaml")
ordered_labels_level1 = mappings["level1"][:-1] #get rid of plastid
ordered_labels_level2 = mappings["level2"][:-1] #get rid of plastid
ordered_labels_level3 = mappings["level3"][:-1] #get rid of plastid
orders = [ordered_labels_level1, ordered_labels_level2, ordered_labels_level3]

# One fixed colour per protein language model.
models = ["ESM2", "ESM3", "ProtT5", "ProtBert"]
model_palette = sns.color_palette("plasma", len(models))
model_colors = {model: model_palette[i] for i, model in enumerate(models)}

avg_metrics = pd.read_csv(f"{SWEEP_ANALYSIS_DIR}/overall_metrics.csv")

# Fixed: the original paths read "../..datasets/..." (missing "/"), which does not
# resolve to the datasets directory used for the label-set YAML above.
hou_testset = pd.read_csv("../../datasets/final/hou_testset.csv")
hpa_uniprot_combined_trainset = pd.read_csv("../../datasets/final/hpa_uniprot_combined_trainset.csv")

In [5]:
#DROP DUPLICATE RUNS
# When one hyperparameter configuration appears more than once, keep only the
# row with the highest macro AP.
run_key_columns = [
    "exp_name",
    "category_level",
    "metadata_file",
    "clip_len",
    "agg_method",
    "mlp_dropout",
    "loss",
]
best_row_labels = avg_metrics.groupby(run_key_columns)["macro_ap"].idxmax()
avg_metrics = avg_metrics.loc[best_row_labels].reset_index(drop=True)

In [6]:
#CHECK IF ANY RUNS ARE MISSING
# Full hyperparameter grid that the sweep is expected to cover.
exp_names = ["ESM2", "ESM3", "ProtT5", "ProtBert"]
category_levels = ["level1", "level2", "level3"]
metadata_files = ["hpa_trainset", "uniprot_trainset", "hpa_uniprot_combined_trainset", "hpa_uniprot_combined_human_trainset"]
clip_lens = [512, 1024, 2048]
agg_methods = ["MeanPool", "MaxPool", "LightAttentionPool", "MultiHeadAttentionPool"]
losses = ["BCEWithLogitsLoss", "SigmoidFocalLoss"]
mlp_dropouts = [0, 0.25, 0.5]

# Column names in the same order as the entries of each combination tuple.
grid_columns = [
    "exp_name",
    "category_level",
    "metadata_file",
    "clip_len",
    "agg_method",
    "loss",
    "mlp_dropout",
]

# Enumerate every combination, varying the last hyperparameter fastest.
combs = []
for exp_name in exp_names:
    for category_level in category_levels:
        for metadata_file in metadata_files:
            for clip_len in clip_lens:
                for agg_method in agg_methods:
                    for loss_name in losses:
                        for dropout in mlp_dropouts:
                            combs.append(
                                (exp_name, category_level, metadata_file,
                                 clip_len, agg_method, loss_name, dropout)
                            )

# Report each combination that has no matching row in avg_metrics.
for comb in combs:
    matches = pd.Series(True, index=avg_metrics.index)
    for column, wanted in zip(grid_columns, comb):
        matches &= avg_metrics[column] == wanted
    if not matches.any():
        print(f"Missing {comb}")

In [None]:
#COMPARING TRAINING SETS
# Top row: macro vs. micro AP; bottom row: macro vs. micro F1; one column per level.
custom_palette = {
    'hpa_trainset': 'red',
    'uniprot_trainset': 'blue',
    'hpa_uniprot_combined_trainset': 'green',
    'hpa_uniprot_combined_human_trainset': 'orange'
}

# Raw value -> legend text. The key order is handed to seaborn as hue/style
# order, so legend text and handles cannot drift apart if the row order of
# avg_metrics changes (the original relied on the implicit order of appearance).
trainset_legend = {
    "hpa_trainset": "HPA trainset",
    "hpa_uniprot_combined_human_trainset": "HPA UniProt Combined (human) trainset",
    "hpa_uniprot_combined_trainset": "HPA UniProt Combined trainset",
    "uniprot_trainset": "UniProt Trainset",
}
agg_legend = {
    "LightAttentionPool": "Light Attention",
    "MaxPool": "Max Pool",
    "MeanPool": "Mean Pool",
    "MultiHeadAttentionPool": "Multihead Attention",
}

# (x column, y column, x label, y label) for each subplot row.
metric_rows = [
    ("macro_ap", "micro_ap", "Macro AP", "Micro AP"),
    ("f1_macro", "f1_micro", "Macro F1", "Micro F1"),
]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
handles = None

for row_idx, (x_col, y_col, x_label, y_label) in enumerate(metric_rows):
    for level in [1, 2, 3]:
        ax = axes[row_idx][level - 1]
        capture_legend = handles is None  # only the first subplot draws a legend
        sns.scatterplot(
            data=avg_metrics[avg_metrics.category_level == f"level{level}"],
            x=x_col,
            y=y_col,
            hue="metadata_file",
            hue_order=list(trainset_legend),
            style="agg_method",
            style_order=list(agg_legend),
            ax=ax,
            legend="auto" if capture_legend else False,
            alpha=0.6,
            palette=custom_palette
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if row_idx == 0:
            ax.set_title(f"Level {level}")
        if capture_legend:
            handles, _ = ax.get_legend_handles_labels()

# Remove individual legends from subplots
for ax in axes.flat:
    if ax.get_legend() is not None:
        ax.get_legend().remove()

# Create a single legend for the entire figure, placed outside the figure.
# The two section titles match the subtitle entries seaborn emits for hue and style.
labels = [
    "Training Set",
    *trainset_legend.values(),
    "Aggregation Method",
    *agg_legend.values(),
]

fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.01, 0.9), fontsize=9)

plt.tight_layout()
plt.savefig(f"{FIG_DIR}/trainset.pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
#COMPARING CLIP LENGTHS
# --> No clear advantage of longer or shorter clip lengths
# Top row: macro vs. micro AP; bottom row: macro vs. micro F1; one column per level.
custom_palette = {
    512: 'red',
    1024: 'blue',
    2048: 'green'
}

# Raw value -> legend text; key order doubles as the explicit hue/style order.
clip_legend = {512: "512", 1024: "1024", 2048: "2048"}
agg_legend = {
    "LightAttentionPool": "Light Attention",
    "MaxPool": "Max Pool",
    "MeanPool": "Mean Pool",
    "MultiHeadAttentionPool": "Multihead Attention",
}

# (x column, y column, x label, y label) for each subplot row.
metric_rows = [
    ("macro_ap", "micro_ap", "Macro AP", "Micro AP"),
    ("f1_macro", "f1_micro", "Macro F1", "Micro F1"),
]

# Fixed: the original F1 row was filtered on "uniprot_trainset" while the AP row
# used "hpa_uniprot_combined_trainset"; both rows now show the same set of runs.
combined_runs = avg_metrics[avg_metrics.metadata_file == "hpa_uniprot_combined_trainset"]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
handles = None

for row_idx, (x_col, y_col, x_label, y_label) in enumerate(metric_rows):
    for level in [1, 2, 3]:
        ax = axes[row_idx][level - 1]
        capture_legend = handles is None  # only the first subplot draws a legend
        sns.scatterplot(
            data=combined_runs[combined_runs.category_level == f"level{level}"],
            x=x_col,
            y=y_col,
            hue="clip_len",
            hue_order=list(clip_legend),
            style="agg_method",
            style_order=list(agg_legend),
            ax=ax,
            legend="auto" if capture_legend else False,
            alpha=0.8,
            palette=custom_palette
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if row_idx == 0:
            ax.set_title(f"Level {level}")
        if capture_legend:
            handles, _ = ax.get_legend_handles_labels()

# Remove individual legends from subplots
for ax in axes.flat:
    if ax.get_legend() is not None:
        ax.get_legend().remove()

# Create a single legend for the entire figure, placed outside the figure
labels = [
    "Clip Length",
    *clip_legend.values(),
    "Aggregation Method",
    *agg_legend.values(),
]

fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.01, 0.9), fontsize=9)

# Fixed: lay out the figure before saving so the PDF matches what is shown.
plt.tight_layout()
plt.savefig(f"{FIG_DIR}/clip_len_compare.pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
#COMPARING LOSS FUNCTIONS
# No clear advantage of one loss over the other
# Top row: macro vs. micro AP; bottom row: macro vs. micro F1; one column per level.
custom_palette = {
    'BCEWithLogitsLoss': 'red',
    'SigmoidFocalLoss': 'blue'
}

# Raw value -> legend text; key order doubles as the explicit hue/style order.
loss_legend = {
    "BCEWithLogitsLoss": "BCEWithLogitsLoss",
    "SigmoidFocalLoss": "SigmoidFocalLoss",
}
agg_legend = {
    "LightAttentionPool": "Light Attention",
    "MaxPool": "Max Pool",
    "MeanPool": "Mean Pool",
    "MultiHeadAttentionPool": "Multihead Attention",
}

# (x column, y column, x label, y label) for each subplot row.
metric_rows = [
    ("macro_ap", "micro_ap", "Macro AP", "Micro AP"),
    ("f1_macro", "f1_micro", "Macro F1", "Micro F1"),
]

# Only runs trained on the combined trainset are compared.
combined_runs = avg_metrics[avg_metrics.metadata_file == "hpa_uniprot_combined_trainset"]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
handles = None

for row_idx, (x_col, y_col, x_label, y_label) in enumerate(metric_rows):
    for level in [1, 2, 3]:
        ax = axes[row_idx][level - 1]
        capture_legend = handles is None  # only the first subplot draws a legend
        sns.scatterplot(
            data=combined_runs[combined_runs.category_level == f"level{level}"],
            x=x_col,
            y=y_col,
            hue="loss",
            hue_order=list(loss_legend),
            style="agg_method",
            style_order=list(agg_legend),
            ax=ax,
            legend="auto" if capture_legend else False,
            alpha=0.8,
            palette=custom_palette
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if row_idx == 0:
            ax.set_title(f"Level {level}")
        if capture_legend:
            handles, _ = ax.get_legend_handles_labels()

# Remove individual legends from subplots
for ax in axes.flat:
    if ax.get_legend() is not None:
        ax.get_legend().remove()

# Create a single legend for the entire figure, placed outside the figure
labels = [
    "Loss Function",
    *loss_legend.values(),
    "Aggregation Method",
    *agg_legend.values(),
]

fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.01, 0.9), fontsize=9)

# Fixed: lay out the figure before saving so the PDF matches what is shown.
plt.tight_layout()
plt.savefig(f"{FIG_DIR}/loss_compare.pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
#COMPARING LOSS FUNCTIONS: PAIRED DIFFERENCE (SigmoidFocalLoss - BCEWithLogitsLoss)

# Filter data for hpa_uniprot_combined_trainset
data_combined = avg_metrics[avg_metrics.metadata_file == "hpa_uniprot_combined_trainset"]

# Pivot the data to have loss types as columns.
# mlp_dropout is not part of the index, so pivot_table's default aggregation
# (mean) averages each metric over the dropout settings of a configuration.
pivoted = data_combined.pivot_table(
    index=["exp_name", "category_level", "clip_len", "agg_method"],
    columns="loss",
    values=["macro_ap", "micro_ap", "f1_macro", "f1_micro"]
)

# Positive values mean SigmoidFocalLoss scored higher than BCEWithLogitsLoss.
diff = pivoted.xs("SigmoidFocalLoss", level=1, axis=1) - pivoted.xs("BCEWithLogitsLoss", level=1, axis=1)
diff = diff.reset_index()

# Raw value -> legend text; key order doubles as the explicit hue/style order
# (same order the sorted pivot index produced implicitly, so colours and
# markers are unchanged).
model_legend = {
    "ESM2": "ESM2",
    "ESM3": "ESM3",
    "ProtBert": "ProtBert",
    "ProtT5": "ProtT5",
}
agg_legend = {
    "LightAttentionPool": "Light Attention",
    "MaxPool": "Max Pool",
    "MeanPool": "Mean Pool",
    "MultiHeadAttentionPool": "Multihead Attention",
}

# (x column, y column, x label, y label) for each subplot row.
metric_rows = [
    ("macro_ap", "micro_ap", "Macro AP difference", "Micro AP difference"),
    ("f1_macro", "f1_micro", "Macro F1 difference", "Micro F1 difference"),
]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
handles = None

for row_idx, (x_col, y_col, x_label, y_label) in enumerate(metric_rows):
    for level in [1, 2, 3]:
        ax = axes[row_idx][level - 1]
        capture_legend = handles is None  # only the first subplot draws a legend
        sns.scatterplot(
            data=diff[diff.category_level == f"level{level}"],
            x=x_col,
            y=y_col,
            hue="exp_name",
            hue_order=list(model_legend),
            style="agg_method",
            style_order=list(agg_legend),
            ax=ax,
            legend="auto" if capture_legend else False,
            alpha=0.8,
            s=100
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if row_idx == 0:
            ax.set_title(f"Level {level}")
        if capture_legend:
            handles, _ = ax.get_legend_handles_labels()

        # Add axes at y=0 and x=0 lines
        ax.axhline(0, color='gray', linestyle='--', linewidth=0.7)
        ax.axvline(0, color='gray', linestyle='--', linewidth=0.7)

# Remove individual legends from subplots
for ax in axes.flat:
    if ax.get_legend() is not None:
        ax.get_legend().remove()

# Create a single legend for the entire figure, placed outside the figure
labels = [
    "Protein Language Model",
    *model_legend.values(),
    "Aggregation Method",
    *agg_legend.values(),
]

fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.01, 0.9), fontsize=12)
fig.suptitle("Improvement using SigmoidFocalLoss over BCEWithLogitsLoss", fontsize=16)
plt.tight_layout()
plt.savefig(f"{FIG_DIR}/loss_diff_compare.pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
#COMPARE ALL PLM x AGGREGATION STRATEGIES
# Only looking at models trained on hpa_uniprot_combined_trainset
data_uniprot_combined_trainset = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
]

# Get models with hyperparams that resulted in best macro_ap.
# The string alias "max" replaces the builtin max, which pandas deprecates as a
# transform argument.
idx = data_uniprot_combined_trainset.groupby(
    ["exp_name", "agg_method", "category_level"]
).macro_ap.transform("max") == data_uniprot_combined_trainset['macro_ap']
temp = data_uniprot_combined_trainset[idx]

# Raw value -> legend text; key order doubles as the explicit hue/style order.
model_legend = {
    "ProtBert": "ProtBert",
    "ProtT5": "ProtT5",
    "ESM2": "ESM2",
    "ESM3": "ESM3",
}
agg_legend = {
    "MaxPool": "Max",
    "MeanPool": "Mean",
    "LightAttentionPool": "LightAttention",
    "MultiHeadAttentionPool": "MultiHeadAttention",
}

# (x column, y column, x label, y label) for each subplot row.
metric_rows = [
    ("micro_ap", "macro_ap", "Micro AP", "Macro AP"),
    ("f1_micro", "f1_macro", "Micro F1", "Macro F1"),
]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
handles = None
for row_idx, (x_col, y_col, x_label, y_label) in enumerate(metric_rows):
    for level in [1, 2, 3]:
        ax = axes[row_idx][level - 1]
        capture_legend = handles is None  # only the first subplot draws a legend
        sns.scatterplot(
            data=temp[temp.category_level == f"level{level}"],
            x=x_col,
            y=y_col,
            hue="exp_name",
            hue_order=list(model_legend),
            style="agg_method",
            style_order=list(agg_legend),
            ax=ax,
            legend="auto" if capture_legend else False,
            alpha=0.8,
            s=200,
            # Fixed: the list palette was assigned in hue_order, which gave each
            # model a different colour than model_colors uses in the other plots.
            palette=model_colors
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if row_idx == 0:
            ax.set_title(f"Level {level}")
        if capture_legend:
            handles, _ = ax.get_legend_handles_labels()

# Create a single legend for the entire figure, placed outside the figure
custom_labels = [
    "Protein Language Model",
    *model_legend.values(),
    'Aggregation Method',
    *agg_legend.values(),
]

legend = fig.legend(
    handles,
    custom_labels,
    loc='upper left',
    bbox_to_anchor=(1.01, 0.9),
    fontsize=9,
    title=None)

# Remove individual legends from subplots
for ax in axes.flat:
    if ax.get_legend() is not None:
        ax.get_legend().remove()

plt.tight_layout()

plt.savefig(f"{FIG_DIR}/PLMxAGG_compare.pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
#COMPARE ALL PLM x AGGREGATION STRATEGIES - Level 1 only
data_uniprot_combined_trainset = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
]

# Get models with hyperparams that resulted in best macro_ap.
# The string alias "max" replaces the builtin max, which pandas deprecates as a
# transform argument.
idx = data_uniprot_combined_trainset.groupby(
    ["exp_name", "agg_method", "category_level"]
).macro_ap.transform("max") == data_uniprot_combined_trainset['macro_ap']
temp = data_uniprot_combined_trainset[idx]

# Raw value -> legend text; key order doubles as the explicit hue/style order.
model_legend = {
    "ProtBert": "ProtBert",
    "ProtT5": "ProtT5",
    "ESM2": "ESM2",
    "ESM3": "ESM3",
}
agg_legend = {
    "MaxPool": "Max",
    "MeanPool": "Mean",
    "LightAttentionPool": "LightAttention",
    "MultiHeadAttentionPool": "MultiHeadAttention",
}

fig, ax = plt.subplots(figsize=(4.5, 4))
sns.scatterplot(
    data=temp[temp.category_level == "level1"],
    x="micro_ap",
    y="macro_ap",
    hue="exp_name",
    hue_order=list(model_legend),
    style="agg_method",
    style_order=list(agg_legend),
    ax=ax,
    legend="auto",
    alpha=0.8,
    s=200,
    # Fixed: use the per-model colour dict so colours match the other figures
    # (the list palette was assigned in hue_order instead).
    palette=model_colors
)
ax.set_xlabel("Micro AP")
ax.set_ylabel("Macro AP")

# Fixed: take the legend handles from this plot. The original reused `handles`
# left over from the previous cell, so this cell failed when run on its own.
handles, _ = ax.get_legend_handles_labels()
if ax.get_legend() is not None:
    ax.get_legend().remove()

# Create a single legend for the entire figure, placed outside the figure
custom_labels = [
    "Protein Language Model",
    *model_legend.values(),
    'Aggregation Method',
    *agg_legend.values(),
]

legend = fig.legend(
    handles,
    custom_labels,
    loc='upper left',
    bbox_to_anchor=(1.01, 0.9),
    fontsize=9,
    title=None)

plt.tight_layout()
plt.savefig(f"{FIG_DIR}/PLMxAGG_macroap_level1.pdf", bbox_inches='tight', dpi=300)
plt.show()

Save the average and per-class metrics for the best models.

In [None]:
#Display best results for each PLM to add to table in manuscript

#Only look at models trained on the combined trainset
data_uniprot_combined_trainset = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
]

# Get models with hyperparams that resulted in best macro_ap per PLM and level.
# The string alias "max" replaces the builtin max, which pandas deprecates as a
# transform argument.
idx = data_uniprot_combined_trainset.groupby(
    ["exp_name", "category_level"]
).macro_ap.transform("max") == data_uniprot_combined_trainset['macro_ap']
temp = data_uniprot_combined_trainset[idx]

#----- SAVE PERCLASS METRICS ------
dfs = []
for _, row in temp.iterrows():
    # Fixed: the original built this path from `d`, which is not defined until a
    # later cell (and there points at the ProtT5 run directory). Runs live under
    # SWEEP_EXP_DIR/<model>_<metadata>/<run_id>.
    path = (
        f"{SWEEP_EXP_DIR}/{row.exp_name}_{row.metadata_file}/"
        f"{row.run_id}/all_folds_perclass_metrics.csv"
    )

    df = pd.read_csv(path).round(3)
    df["Model"] = row.exp_name
    df["Agg. Method"] = row.agg_method
    df["Level"] = row.category_level[-1]  # trailing character of "levelN"
    dfs.append(df)

# NOTE(review): "cutom" in the file name looks like a typo for "custom"; it is
# left unchanged in case other code reads this file - confirm before renaming.
pd.concat(dfs).to_csv(f"{FIG_DIR}/cutom_models_perclass_metrics.csv")


#----- SAVE AVG METRICS ------
values = ["acc_samples", "macro_ap", "micro_ap",
            "f1_macro", "f1_micro", "num_labels",
            "jaccard_macro", "jaccard_micro", "rocauc_macro",
            "rocauc_micro",	"mlrap", "coverage_error"
        ]
temp_for_display = temp.pivot_table(index='exp_name', columns='category_level', values=values).round(3)
temp_for_display = temp_for_display.reindex(["ESM2", "ESM3", "ProtT5", "ProtBert"])
display(temp_for_display)
temp_for_display.to_csv(f"{FIG_DIR}/custom_models_avg_metrics.csv") #THIS FILE DOESNT ACTUALLY EXIST NEED TO MAKE IT FROM WHAT I SAVED

Next we look at which hyperparameters resulted in the best model performance for each PLM

In [None]:
#Only looking at models trained on hpa_uniprot_combined_trainset
data_uniprot_combined_trainset = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
]

#Get models with hyperparams that resulted in best macro_ap
# The string alias "max" replaces the builtin max, which pandas deprecates as a
# transform argument.
idx = data_uniprot_combined_trainset.groupby(
    ["exp_name", "agg_method", "category_level"]
).macro_ap.transform("max") == data_uniprot_combined_trainset['macro_ap']
# Fixed: copy the selection so the column assignment below modifies an
# independent frame instead of a slice of avg_metrics (SettingWithCopyWarning).
temp = data_uniprot_combined_trainset[idx].copy()

#Sort agg_methods in a fixed display order
ordering = ["MaxPool", "MeanPool", "LightAttentionPool", "MultiHeadAttentionPool"]
agg_method_dtype = pd.CategoricalDtype(categories=ordering, ordered=True)
temp['agg_method'] = temp['agg_method'].astype(agg_method_dtype)
temp = temp.sort_values('agg_method')

# One row per (level, aggregation method), one column group per reported value.
temp = temp.pivot(
    index=["category_level", "agg_method"],
    columns="exp_name",
    values=["clip_len", "mlp_dropout", "acc_samples", "macro_ap", "micro_ap", "f1_macro", "f1_micro"])

display(temp)


Per-class precision/recall bar plot, with one bar per PLM. Only models trained on the hpa_uniprot_combined_trainset with their best hyperparameters are used.

In [None]:
# Row labels of the best run (highest macro AP) per PLM and level, restricted to
# models trained on the combined trainset.
idx = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
].groupby(["exp_name", "category_level"])["macro_ap"].idxmax().values

# Load the per-class metrics of each selected run and tag them with the run's settings.
perclass_metrics = []
# Fixed: idxmax() returns index labels, so the rows must be selected with .loc.
# (.iloc only gave the same rows while the index happened to be 0..n-1.)
for _, row in avg_metrics.loc[idx].iterrows():
    path = f"{SWEEP_EXP_DIR}/{row.exp_name}_{row.metadata_file}/{row.run_id}/all_folds_perclass_metrics.csv"
    df = pd.read_csv(path)
    df["exp_name"] = row.exp_name
    df["category_level"] = row.category_level
    df["metadata_file"] = row.metadata_file
    df["clip_len"] = row.clip_len
    df["agg_method"] = row.agg_method
    df["loss"] = row.loss
    perclass_metrics.append(df)

perclass_metrics = pd.concat(perclass_metrics)

# Rename to the column names used by the plotting code below.
perclass_metrics = perclass_metrics.rename({"category": "label", "exp_name": "model"}, axis=1)
for level in [1, 2, 3]:
    temp = perclass_metrics[perclass_metrics.category_level == f"level{level}"]
    fig, ax = plot_perclas_double_bar(
        temp,
        {"Recall": "recall",
        "Precision": "precision"},
        models,
        model_colors,
        orders[level-1]
    )
    plt.savefig(f"{FIG_DIR}/custom_models_pr_barplot_level{level}.pdf", bbox_inches='tight', dpi=300)


In [None]:
# Relationship between the number of training examples per label and each
# per-class metric, per model and level; also tabulates the R^2 of a linear fit
# of metric vs. log10(count).
r2_rows = []

metrics = ["mcc", "acc", "recall", "precision", "f1", "jaccard", "rocauc"]
# Fixed: built with dict(zip(...)) instead of a comprehension whose loop
# variable reused the name `metrics`.
metric_title = dict(zip(metrics, ["MCC", "ACC", "Recall", "Precision", "F1", "Jaccard", "ROC-AUC"]))

for level in range(1, 4):
    fig, axes = plt.subplots(1, len(metrics), figsize=(27, 4), sharex=False, sharey=False)
    # Count labels for current level (labels of one protein are ";"-separated).
    # A flat comprehension replaces the quadratic sum(list_of_lists, []);
    # rows without labels are skipped instead of raising.
    label_lists = hpa_uniprot_combined_trainset[f"level{level}"].str.split(";").dropna()
    counter = Counter(loc for locs in label_lists for loc in locs)
    labels = list(counter.keys())
    counts = list(counter.values())

    # Custom legend entries, one per model
    handles_list = []
    labels_list = []

    for model in models:
        temp = perclass_metrics[(perclass_metrics.category_level == f"level{level}") &
                            (perclass_metrics.model == model)].set_index("label").reindex(labels)
        temp["counts"] = counts
        temp["model"] = model
        temp["level"] = f"Level {level}"

        r2_values = []
        for i, metric in enumerate(metrics):
            sns.scatterplot(
                data=temp,
                x="counts",
                y=metric,
                color=model_colors[model],
                ax=axes[i],
                legend=False  # Disable legend for scatter plot
            )

            sns.regplot(
                data=temp,
                x="counts",
                y=metric,
                scatter=False,
                logx=True,
                line_kws={'color': model_colors[model], 'lw': 1.5},
                ax=axes[i],
            )
            axes[i].set_xscale("log")  # <-- Set x-axis to log scale

            # R^2 of metric vs. log10(count). Labels without a metric value
            # (NaN after the reindex above) are excluded, matching the points
            # regplot fits; previously a single NaN made R^2 NaN.
            x = np.log10(temp["counts"].to_numpy(dtype=float))
            y = temp[metric].to_numpy(dtype=float)
            usable = np.isfinite(x) & np.isfinite(y)
            if usable.sum() >= 2:
                r_value = linregress(x[usable], y[usable]).rvalue
                r2_values.append(r_value**2)
            else:
                r2_values.append(np.nan)

        # One legend marker per model
        handles_list.append(
            Line2D([0], [0], marker='o', color='w', label=model,
                   markerfacecolor=model_colors[model], markersize=8)
        )
        labels_list.append(model)

        r2_rows.append([model, level] + r2_values)
    fig.suptitle(f"Level {level}", fontsize=16, y=1.02)  # Add figure-wide title indicating the level

    # Title and layout
    for ax, metric in zip(axes, metrics):
        ax.set_title(metric_title[metric])
        ax.set_xlabel("Count (log-scale)")
        ax.set_ylabel("Score")

    # Create a custom legend outside the figure
    fig.legend(
        handles_list, labels_list, title="Model",
        loc="upper right",
        bbox_to_anchor=(1.04, 0.85)
    )

    plt.tight_layout()
    fig.savefig(f"{FIG_DIR}/custom_CountsvsMetrics_level{level}.pdf", bbox_inches='tight', dpi=300)
    plt.show()

r2_values_df = pd.DataFrame(r2_rows, columns=["model", "level"] + metrics)
r2_values_df = r2_values_df.set_index("model").round(3)
display(r2_values_df)
r2_values_df.to_csv(f"{FIG_DIR}/r2_value_custom.csv")

In [22]:
#Extract laprott5 results
# LAProtT5 here means: ProtT5 + LightAttentionPool, clip length 1024,
# BCEWithLogitsLoss, MLP dropout 0.25, trained on the combined trainset.
laprott5_avg_metrics = avg_metrics[(avg_metrics.metadata_file=="hpa_uniprot_combined_trainset") &
                                    (avg_metrics.clip_len==1024) &
                                    (avg_metrics.exp_name=="ProtT5") &
                                    (avg_metrics.agg_method=="LightAttentionPool") &
                                    (avg_metrics.loss=="BCEWithLogitsLoss") &
                                    (avg_metrics.mlp_dropout == 0.25)]
# Report the sample-wise accuracy under the plain "acc" column name.
laprott5_avg_metrics = laprott5_avg_metrics.drop("acc", axis=1).rename({"acc_samples": "acc"}, axis=1)

run_dir = f"{SWEEP_EXP_DIR}/ProtT5_hpa_uniprot_combined_trainset"

# Per-class metrics keyed by category level ("level1", ...).
# Fixed: a separate name is used so the `perclass_metrics` DataFrame built by
# the earlier cell is no longer overwritten with a list.
laprott5_perclass_metrics = {}
for _, row in laprott5_avg_metrics.sort_values("category_level").iterrows():
    laprott5_perclass_metrics[row.category_level] = pd.read_csv(
        f"{run_dir}/{row.run_id}/all_folds_perclass_metrics.csv"
    )

laprott5_avg_metrics.to_csv("../../Benchmark-Models/LAProtT5/LAProtT5_avg_metrics.csv", index=False)
# Fixed: the file name now uses the run's own category level rather than the
# enumeration position, so a missing level can no longer mislabel the others.
for category_level, df in laprott5_perclass_metrics.items():
    df.to_csv(f"../../Benchmark-Models/LAProtT5/LAProtT5_perclass_metrics_{category_level}.csv")

In [None]:
#Model to use for downstream analysis: the level-1 run with the highest macro AP
idx = avg_metrics[avg_metrics.category_level=="level1"]["macro_ap"].idxmax()
# Fixed: idxmax() returns an index label, so look the row up with .loc (not .iloc).
avg_metrics.loc[idx]

In [25]:
def _pair_replacements(missed, spurious):
    """Pair each missed true label with each spurious predicted label; "-" fills an empty side."""
    if missed and spurious:
        return [(true_loc, pred_loc) for true_loc in missed for pred_loc in spurious]
    if missed:
        return [(true_loc, "-") for true_loc in missed]
    if spurious:
        return [("-", pred_loc) for pred_loc in spurious]
    return []


def _ensemble_fold_predictions(d, run_id, n_folds):
    """Load every fold's test predictions and combine them into mean probabilities and a majority vote."""
    fold_thresholds = np.load(f"{d}/{run_id}/all_thresholds.npy")
    fold_probs = []
    fold_votes = []
    locations = None
    targets = None
    for fold in range(n_folds):
        preds_df = pd.read_csv(f"{d}/{run_id}/fold_{fold}/fold_{fold}_test_predictions.csv")
        true_cols = [c for c in preds_df.columns if "true" in c]
        pred_cols = [c for c in preds_df.columns if "pred" in c]
        # The class name is the part of the column name before the first "_".
        true_names = np.array([c.split("_")[0] for c in true_cols])
        pred_names = np.array([c.split("_")[0] for c in pred_cols])
        fold_targets = preds_df[true_cols].to_numpy()
        if fold == 0:
            locations = true_names
            targets = np.array(fold_targets)
        else:
            # Every fold must describe the same samples and classes in the same order.
            assert np.array_equal(targets, fold_targets)
            assert np.array_equal(locations, true_names)
        # Prediction columns must line up with the target columns (now also checked for fold 0).
        assert np.array_equal(locations, pred_names)

        # The prediction columns are passed through a sigmoid to get probabilities.
        probs = torch.sigmoid(torch.from_numpy(preds_df[pred_cols].to_numpy())).numpy()
        fold_probs.append(probs)
        fold_votes.append((probs > fold_thresholds[fold]).astype(np.int16))

    preds_all = np.array(fold_probs).mean(axis=0)
    # A class is predicted when more than half of the folds predict it.
    preds_bin = (np.stack(fold_votes).mean(axis=0) > 0.5).astype(np.int16)
    return locations, targets, preds_all, preds_bin


def make_matrix(d, run_id, name, n_folds=5):
    """Plot precision-recall curves and label-replacement heatmaps for one run.

    Predictions of all folds under ``{d}/{run_id}`` are ensembled (mean
    probability for the PR curves, majority vote of the per-fold thresholded
    predictions for the heatmaps). Two PDFs are written to ``FIG_DIR``:
    ``{name}_pr_curve.pdf`` and ``{name}.pdf``.

    Args:
        d: Directory that contains the run directory.
        run_id: Name of the run directory inside ``d``.
        name: Stem of the output file names; also printed as a progress message.
        n_folds: Number of ``fold_<i>`` sub-directories to read (default 5, the
            value that used to be hard-coded).

    Raises:
        AssertionError: If the folds disagree on targets or class columns.
    """
    rare_classes = { #classes with <50 samples in HOU
        "actin-filaments",
        "intermediate-filaments",
        "microtubules",
        "peroxisomes",
        "lipid-droplets",
        "nuclear-membrane",
        "nuclear-speckles",
        "nucleoli-fibrillar-center",
        "plastid",
    }
    print(name)
    locations, targets, preds_all, preds_bin = _ensemble_fold_predictions(d, run_id, n_folds)

    #Make Precision Recall Line Plot:
    palette = sns.color_palette("tab20", n_colors=13)
    colors = {
        'cytoskeleton': palette[0],
        'centrosome': palette[1],
        'plasma-membrane': palette[2],
        'cytosol': palette[3],
        'endoplasmic-reticulum': palette[4],
        'endomembrane-system': palette[4],
        'golgi-apparatus': palette[5],
        'vesicles': palette[6],
        'endosomes': palette[7],
        'lysosomes': palette[8],
        'mitochondria': palette[9],
        'nucleoplasm': palette[10],
        'nucleus': palette[10],
        'nuclear-bodies': palette[11],
        'nucleoli': palette[12]
    }
    fig, ax = plt.subplots(figsize=(8, 4))
    for class_idx, location in enumerate(locations):
        if location in rare_classes:
            continue
        # The curve's own thresholds are not needed (the original stored them in
        # `thresholds`, shadowing the per-fold thresholds loaded earlier).
        precision, recall, _ = precision_recall_curve(targets[:, class_idx], preds_all[:, class_idx])
        sns.lineplot(
            x=recall[:-1],
            y=precision[:-1],
            label=location,
            # Classes without a fixed colour fall back to the default colour cycle.
            color=colors.get(location),
            errorbar=None,
            ax=ax
        )
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    # Move legend outside the plot (right side)
    ax.legend(
        title="Class",
        bbox_to_anchor=(1.02, 1),  # right outside
        loc='upper left',
        borderaxespad=0.
    )

    # Adjust layout so that legend doesn't overlap plot
    fig.tight_layout(rect=[0, 0, 0.85, 1])  # Leave space on right for legend

    plt.savefig(f"{FIG_DIR}/{name}_pr_curve.pdf", dpi=300)
    plt.show()

    # Per-sample sets of true and predicted locations.
    true_locs = [set(locations[row == 1]) for row in targets]
    pred_locs = [set(locations[row == 1]) for row in preds_bin]

    # "-" is an extra pseudo-class meaning "no counterpart label".
    locations = list(locations) + ["-"]
    position = {}
    for k, location in enumerate(locations):
        position.setdefault(location, k)  # first occurrence, like list.index

    # replacement_counter[j, k]: number of samples where true label j was missed
    # and label k was predicted instead; the diagonal counts correct labels.
    replacement_counter = np.zeros((len(locations), len(locations)))
    for true_set, pred_set in zip(true_locs, pred_locs):
        for location in true_set & pred_set:
            replacement_counter[position[location], position[location]] += 1
        for true_loc, pred_loc in _pair_replacements(true_set - pred_set, pred_set - true_set):
            replacement_counter[position[true_loc], position[pred_loc]] += 1

    # Normalise by the number of true (rows) / predicted (columns) samples per class.
    true_counts = targets.sum(axis=0)[:, None]
    pred_counts = preds_bin.sum(axis=0)[None, :]
    plastid = position.get("plastid")
    if plastid is not None:
        # There are no true plastids; a count of 1 avoids dividing by zero.
        true_counts[plastid, :] = 1
        pred_counts[:, plastid] = 1
    true_replacement_counter = replacement_counter[:-1, :] / true_counts
    pred_replacement_counter = replacement_counter[:, :-1] / pred_counts

    print(true_replacement_counter.shape)
    if plastid is not None:
        #There are no true plastids, so delete this row
        true_replacement_counter = np.delete(true_replacement_counter, plastid, axis=0)
        pred_replacement_counter = np.delete(pred_replacement_counter, plastid, axis=0)
    print(true_replacement_counter.shape)

    # Row labels: every real class except plastid, then the "-" pseudo-class
    # (no longer assumes that plastid is the last class).
    row_labels = [loc for loc in locations[:-1] if loc != "plastid"] + ["-"]

    # Fixed: both heatmaps drew into the same colorbar axis with independent
    # colour scales, so the colorbar only described the second one. They now
    # share one scale and the colorbar is drawn once.
    finite_values = np.concatenate([
        matrix[np.isfinite(matrix)]
        for matrix in (true_replacement_counter, pred_replacement_counter)
    ])
    vmax = finite_values.max() if finite_values.size else 1.0

    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(16, 8), sharey=True, gridspec_kw={'width_ratios': [1, 0.93]})
    cbar_ax = fig.add_axes([.91, .5, .03, .4])  # Position for the color bar

    sns.heatmap(
        true_replacement_counter,
        xticklabels=locations,
        yticklabels=row_labels,
        ax=ax1,
        cmap="coolwarm",
        vmin=0,
        vmax=vmax,
        cbar=False
        )
    ax1.set_xlabel("Predicted label")
    ax1.set_ylabel("True label")
    ax1.set_title("True Confusion Matrix", fontsize=16)
    ax1.set_xticklabels(ax1.get_xticklabels(), rotation=60, ha='right')

    sns.heatmap(
        pred_replacement_counter,
        xticklabels=locations[:-1],
        yticklabels=row_labels,
        ax=ax2,
        cmap="coolwarm",
        vmin=0,
        vmax=vmax,
        cbar_ax=cbar_ax,
        )
    ax2.set_xlabel("Predicted label")
    ax2.set_title("Predicted Confusion Matrix", fontsize=16)
    ax2.set_xticklabels(ax2.get_xticklabels(), rotation=60, ha='right')

    plt.tight_layout(rect=[0, 0, .9, 1])  # Adjust layout to make space for the color bar
    plt.savefig(f"{FIG_DIR}/{name}.pdf", dpi=300)
    plt.show()

In [None]:
#ProtT5 is best in general, so get best ProtT5 models for each level
data_uniprot_combined_trainset = avg_metrics[
    (avg_metrics.metadata_file == "hpa_uniprot_combined_trainset")
]
# Best macro AP per PLM and level. The string alias "max" replaces the builtin
# max, which pandas deprecates as a transform argument.
idx = data_uniprot_combined_trainset.groupby(
    ["exp_name", "category_level"]
).macro_ap.transform("max") == data_uniprot_combined_trainset['macro_ap']
temp = data_uniprot_combined_trainset[idx]
temp = temp[temp.exp_name == "ProtT5"]

# Draw the PR curves and confusion matrices for each selected run.
d = f"{SWEEP_EXP_DIR}/ProtT5_hpa_uniprot_combined_trainset"
for _, row in temp.iterrows():
    name = f"{row.exp_name}_{row.category_level}_confusionmatrices"
    run_id = row.run_id
    make_matrix(d, run_id, name)

In [None]:
# LAProtT5 is the best benchmark model overall, so its confusion matrices are
# drawn as well, one set per run in laprott5_avg_metrics.
d = f"{SWEEP_EXP_DIR}/ProtT5_hpa_uniprot_combined_trainset"
temp = laprott5_avg_metrics

for _, laprott5_run in temp.iterrows():
    make_matrix(
        d,
        laprott5_run.run_id,
        f"LAProtT5_{laprott5_run.category_level}_confusionmatrices",
    )