In [15]:
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import pandas as pd
import polars as pl
import numpy as np
import os
import torch
import random

from kinpfn.priors import Batch
from kinpfn.model import KINPFN

In [16]:
def set_seed(seed=123):
    """Seed the global torch, NumPy and stdlib ``random`` generators.

    Args:
        seed: Integer seed applied to all three generators. ``None`` re-seeds
            every generator non-deterministically. ``torch.manual_seed`` rejects
            ``None``, so ``torch.seed()`` is used for that case. This matches
            what ``np.random.seed(None)`` / ``random.seed(None)`` already do.
    """
    if seed is None:
        torch.seed()
    else:
        torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    
def get_batch_kinfold_folding_times(paths, seq_len=1000, num_features=1, **kwargs):
    """Load Kinfold folding times from headerless CSV files into a KinPFN ``Batch``.

    For every file, up to ``seq_len`` rows are read. Column index 2 holds the
    folding time and column index 4 holds the sequence. The times are sorted
    and only values in ``[1e-6, 1e15]`` are kept. The result is then tiled
    (oversampled) or truncated to exactly ``seq_len`` values.

    Args:
        paths: CSV file paths, one per sequence.
        seq_len: Number of folding times per sequence in the returned batch.
        num_features: Size of the feature dimension of ``x``.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``(batch, length_order)`` with these contents:

        * ``batch.x`` is an all-zero tensor of shape
          ``(seq_len, len(paths), num_features)``.
        * ``batch.y`` and ``batch.target_y`` are the same
          ``(seq_len, len(paths))`` tensor of log10 folding times.
        * ``length_order[k]`` is the sequence length of the file stored in
          column ``k``.

        Files with no folding time in the valid range are skipped with a
        warning. The trailing columns they would have filled stay 0 before the
        log10, i.e. ``-inf`` in ``batch.y``.
    """
    import warnings

    dataset_size = len(paths)

    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    for path in paths:
        data = pl.read_csv(
            path,
            has_header=False,
            columns=[2, 4],
            n_rows=seq_len,
        )

        # polars names headerless columns column_1, column_2, ... (1-based).
        folding_times = data["column_3"].to_numpy()
        sequence = data["column_5"][0]
        sorted_folding_times = np.sort(folding_times)

        # Keep only points with 1e-6 <= t <= 1e15; everything else is treated as invalid.
        valid_indices = (sorted_folding_times <= 10**15) & (
            sorted_folding_times >= 10**-6
        )
        sorted_folding_times = sorted_folding_times[valid_indices]

        current_seq_len = len(sorted_folding_times)
        if current_seq_len <= 0:
            warnings.warn(
                f"No folding times within [1e-6, 1e15] in {path!r}; skipping it."
            )
            continue

        # Record the length only for files that occupy a column, so that
        # length_order[k] always describes y[:, k].
        length_order.append(len(sequence))

        if current_seq_len < seq_len:
            # Oversample: repeat the sorted times until seq_len values exist.
            repeat_factor = seq_len // current_seq_len + 1
            sorted_folding_times = np.tile(sorted_folding_times, repeat_factor)[
                :seq_len
            ]
        else:
            sorted_folding_times = sorted_folding_times[:seq_len]

        # x is already all zeros from torch.zeros above; only y needs filling.
        y[:, batch_index] = torch.tensor(sorted_folding_times)
        batch_index += 1

    y = torch.log10(y)
    return Batch(x=x, y=y, target_y=y), length_order

In [None]:
# Pretrained KinPFN checkpoint, relative to this notebook's working directory.
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

if trained_model is None:
    print("No trained model found!")
    # Raise instead of calling exit(): inside an IPython/Jupyter kernel exit()
    # terminates the kernel (losing all state) rather than just stopping here.
    raise RuntimeError(f"Could not load a KinPFN model from {model_path!r}")

print("Load trained model!")

In [18]:
def combined_gt_and_pred_plot(trained_model, training_points : list[int], candidate_ids: list[str], seed: int = None,times=1000, ):
    """Plot ground-truth Kinfold folding-time CDFs next to KinPFN approximations.

    The figure has one row per entry of ``candidate_ids``. For candidate
    ``name``, every ``*.csv`` file in ``./data/analysis_<name>`` (sorted by
    path) is one sequence.

    * Column 0 shows the empirical CDF of the first ``times`` folding times of
      each file, annotated with the MFE structure.
    * Column ``j + 1`` shows two things for each sequence: the KinPFN CDF
      predicted from ``training_points[j]`` context times (circles), and those
      context times (crosses at y = 0).

    Args:
        trained_model: KinPFN model. It is called as
            ``trained_model(train_x, train_y, test_x)`` and must provide
            ``criterion.cdf(logits, y)``.
        training_points: Context-set sizes; one prediction column per entry.
        candidate_ids: Candidate names (data sub-directory suffixes). ``"A"``
            and ``"B"`` additionally get dashed reference lines.
        seed: Passed to ``set_seed`` before plotting when not ``None``.
        times: Maximum number of CSV rows read for the ground-truth column.

    NOTE(review): the context indices are drawn right after ``set_seed(123)``
    for every sequence and context size. All of them therefore share the same
    permutation, and ``seed`` does not change which context times are chosen.
    """
    if seed is not None:
        # torch.manual_seed(None) raises, so only seed when a value is given.
        set_seed(seed)

    seq_len = 1000
    colors = plt.cm.tab10.colors

    # squeeze=False keeps ``axes`` 2-D for every grid shape (1x1, 1xN, Nx1, ...).
    # No layout engine is set here: fig.tight_layout() below decides the layout.
    # Combining it with layout="constrained" only triggered a matplotlib warning.
    fig, axes = plt.subplots(
        nrows=len(candidate_ids),
        ncols=len(training_points) + 1,
        figsize=(25, 11),
        squeeze=False,
    )
    fig.set_dpi(300)

    # Dashed reference levels per candidate: (y, label_x, label_y, label).
    gt_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.1, 0.75, "70% folded")],
        "B": [(0.5, 0.1, 0.55, "50% folded"), (0.2, 0.1, 0.25, "20% folded")],
    }
    # The prediction panels put candidate A's 70% label further to the right.
    pred_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.3, 0.75, "70% folded")],
        "B": gt_markers["B"],
    }

    def _csv_paths(name):
        # Sorted so that sequence numbering/colours do not depend on the filesystem.
        candidate_dir = f"./data/analysis_{name}"
        return sorted(
            os.path.join(candidate_dir, f)
            for f in os.listdir(candidate_dir)
            if f.endswith(".csv")
        )

    def _draw_markers(ax, markers):
        # Labels use axes coordinates for x and data coordinates for y.
        for level, label_x, label_y, label in markers:
            ax.axhline(y=level, color="black", linestyle="--")
            ax.text(
                label_x,
                label_y,
                label,
                color="black",
                transform=ax.get_yaxis_transform(),
            )

    for row, name in enumerate(candidate_ids):
        paths = _csv_paths(name)

        # ---- Column 0: empirical ground-truth CDFs ----
        ax = axes[row, 0]
        ax.set_title(rf"Ground Truth $F(t)$ for Sequences {name}: ", pad=60)

        structures = []
        for j, path in enumerate(paths):
            data = pd.read_csv(
                path,
                usecols=[0, 2, 4],
                names=["structure", "folding_time", "sequence"],
                header=None,
                nrows=times,
            )
            structures.append(data["structure"].values[0])
            sequence = data["sequence"].values[0]

            sorted_folding_times = np.sort(data["folding_time"].values)
            cdf = np.arange(1, len(sorted_folding_times) + 1) / len(
                sorted_folding_times
            )
            ax.scatter(sorted_folding_times, cdf, label=rf"${name}_{j}$: {sequence}")

        _draw_markers(ax, gt_markers.get(name, []))

        # One box per distinct MFE structure. Identical structures used to be
        # stacked once per file at the same position.
        for structure in dict.fromkeys(structures):
            anchored_text = AnchoredText(
                f"MFE Structure: {structure}",
                loc="lower right",
                prop=dict(size=8, color="black"),
                frameon=True,
            )
            anchored_text.patch.set_boxstyle("round,pad=0.3,rounding_size=0.2")
            anchored_text.patch.set_edgecolor("black")
            anchored_text.patch.set_linewidth(1.5)
            anchored_text.set_bbox_to_anchor((1.0, 0.05, 0, 0), transform=ax.transAxes)
            ax.add_artist(anchored_text)

        ax.set_xlabel("Time")
        ax.set_xscale("log")
        ax.set_ylabel("Cumulative Population Probability")
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=1)
        ax.grid(True)
        ax.set_axisbelow(True)

        # ---- Columns 1..: KinPFN approximations ----
        batch, _ = get_batch_kinfold_folding_times(paths=paths, seq_len=seq_len)
        x = batch.x
        y_folding_times = batch.y  # log10 folding times, (seq_len, n_sequences)

        for col, training_point in enumerate(training_points, start=1):
            ax = axes[row, col]
            ax.set_title(
                f"KinPFN "
                + r"$\hat{F}(t)$"
                + f" Approximations using {training_point} Context Times: "
            )

            pred_handles = []
            context_handles = []

            for batch_index in range(len(paths)):
                set_seed(123)
                train_indices = torch.randperm(seq_len)[:training_point]

                train_x = x[train_indices, batch_index]
                train_y_folding_times = y_folding_times[train_indices, batch_index]
                test_x = x[:, batch_index]

                with torch.no_grad():
                    logits = trained_model(
                        train_x[:, None],
                        train_y_folding_times[:, None],
                        test_x[:, None],
                    )

                train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

                # Evaluation grid in log10 time, from one decade below the
                # earliest to 0.75 decades above the latest context time.
                linspace_extended = torch.linspace(
                    train_y_folding_times_sorted[0] - 1,
                    train_y_folding_times_sorted[-1] + 0.75,
                    1000,
                )
                pred_cdf_linspace_extended = (
                    trained_model.criterion.cdf(logits, linspace_extended)
                )[0][0]

                chosen_color = colors[batch_index % len(colors)]

                pred_handles.append(
                    ax.scatter(
                        10**linspace_extended,
                        pred_cdf_linspace_extended,
                        marker="o",
                        label=rf"${name}_{batch_index}$ KinPFN " + r"$\hat{F}(t)$",
                        color=chosen_color,
                    )
                )
                context_handles.append(
                    ax.scatter(
                        10**train_y_folding_times_sorted,
                        torch.zeros_like(train_y_folding_times_sorted),
                        color=chosen_color,
                        marker="x",
                        label=rf"${name}_{batch_index}$ Context Times",
                    )
                )

            ax.set_xscale("log")
            ax.set_xlabel("Time")
            _draw_markers(ax, pred_markers.get(name, []))
            ax.grid(True)
            ax.set_axisbelow(True)

            # ax.legend() replaces ax.legend_, so the first legend must be kept
            # as an explicit artist. The second one stays as ax.legend_; adding
            # it as an artist too would draw it twice.
            ax.add_artist(ax.legend(handles=pred_handles, loc="upper left", fontsize=8))
            ax.legend(
                handles=context_handles,
                loc="lower right",
                bbox_to_anchor=(1.0, 0.075),
                fontsize=8,
            )

    fig.tight_layout()
    plt.show()

In [19]:
def combined_gt_and_pred_plot_main_paper(trained_model, training_points : list[int], candidate_ids: list[str], seed: int = None,times=1000, ):
    """Main-paper figure: ground-truth Kinfold CDFs next to KinPFN approximations.

    The figure has one row per entry of ``candidate_ids``. For candidate
    ``name``, every ``*.csv`` file in ``./data/analysis_<name>`` (sorted by
    path) is one sequence, labelled ``phi_j``.

    * Column 0 shows the empirical CDF of the first ``times`` folding times of
      each file, annotated with the MFE structure.
    * Column ``j + 1`` shows the KinPFN CDF predicted from
      ``training_points[j]`` context times (circles), plus those context times
      (crosses at y = 0).

    Panels have no titles and use large fonts for print.

    Args:
        trained_model: KinPFN model. It is called as
            ``trained_model(train_x, train_y, test_x)`` and must provide
            ``criterion.cdf(logits, y)``.
        training_points: Context-set sizes; one prediction column per entry.
        candidate_ids: Candidate names (data sub-directory suffixes). ``"A"``
            and ``"B"`` additionally get dashed reference lines.
        seed: Passed to ``set_seed`` before plotting when not ``None``.
        times: Maximum number of CSV rows read for the ground-truth column.

    NOTE(review): the context indices are drawn right after ``set_seed(123)``
    for every sequence and context size. All of them therefore share the same
    permutation, and ``seed`` does not change which context times are chosen.
    """
    if seed is not None:
        # torch.manual_seed(None) raises, so only seed when a value is given.
        set_seed(seed)

    seq_len = 1000
    # Figure width grows with the number of columns. For one context size this
    # is 22, the value that used to be hard-coded.
    width = 11 * (len(training_points) + 1)
    colors = plt.cm.tab10.colors

    fontsize = 19
    label_fontsize = 26

    # squeeze=False keeps ``axes`` 2-D for every grid shape (1x1, 1xN, Nx1, ...).
    # No layout engine is set here: fig.tight_layout() below decides the layout.
    # Combining it with layout="constrained" only triggered a matplotlib warning.
    fig, axes = plt.subplots(
        nrows=len(candidate_ids),
        ncols=len(training_points) + 1,
        figsize=(width, 9.5),
        squeeze=False,
    )
    fig.set_dpi(300)

    # Dashed reference levels per candidate: (y, label_x, label_y, label).
    gt_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.1, 0.75, "70% folded")],
        "B": [(0.5, 0.1, 0.55, "50% folded"), (0.2, 0.1, 0.25, "20% folded")],
    }
    # The prediction panels put candidate A's 70% label further to the right.
    pred_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.3, 0.75, "70% folded")],
        "B": gt_markers["B"],
    }

    def _csv_paths(name):
        # Sorted so that sequence numbering/colours do not depend on the filesystem.
        candidate_dir = f"./data/analysis_{name}"
        return sorted(
            os.path.join(candidate_dir, f)
            for f in os.listdir(candidate_dir)
            if f.endswith(".csv")
        )

    def _draw_markers(ax, markers):
        # Labels use axes coordinates for x and data coordinates for y.
        for level, label_x, label_y, label in markers:
            ax.axhline(y=level, color="black", linestyle="--")
            ax.text(
                label_x,
                label_y,
                label,
                color="black",
                transform=ax.get_yaxis_transform(),
                size=fontsize,
            )

    for row, name in enumerate(candidate_ids):
        paths = _csv_paths(name)

        # ---- Column 0: empirical ground-truth CDFs ----
        ax = axes[row, 0]

        structures = []
        for j, path in enumerate(paths):
            data = pd.read_csv(
                path,
                usecols=[0, 2, 4],
                names=["structure", "folding_time", "sequence"],
                header=None,
                nrows=times,
            )
            structures.append(data["structure"].values[0])
            sequence = data["sequence"].values[0]

            sorted_folding_times = np.sort(data["folding_time"].values)
            cdf = np.arange(1, len(sorted_folding_times) + 1) / len(
                sorted_folding_times
            )
            ax.scatter(sorted_folding_times, cdf, label=rf"$\phi_{j}$: {sequence}")

        _draw_markers(ax, gt_markers.get(name, []))

        # One box per distinct MFE structure. Identical structures used to be
        # stacked once per file at the same position.
        for structure in dict.fromkeys(structures):
            anchored_text = AnchoredText(
                f"MFE Structure: {structure}",
                loc="lower right",
                prop=dict(size=fontsize, color="black"),
                frameon=True,
            )
            anchored_text.patch.set_boxstyle("round,pad=0.3,rounding_size=0.2")
            anchored_text.patch.set_edgecolor("black")
            anchored_text.patch.set_linewidth(1.5)
            anchored_text.set_bbox_to_anchor((1.0, 0.05, 0, 0), transform=ax.transAxes)
            ax.add_artist(anchored_text)

        ax.set_xlabel("Time", fontsize=label_fontsize)
        ax.set_xscale("log")
        ax.set_ylabel("Cumulative Population Probability", fontsize=label_fontsize)
        ax.tick_params(axis="both", which="major", labelsize=fontsize)
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.21), ncol=1, fontsize=fontsize)
        ax.grid(True)
        ax.set_axisbelow(True)

        # ---- Columns 1..: KinPFN approximations ----
        batch, _ = get_batch_kinfold_folding_times(paths=paths, seq_len=seq_len)
        x = batch.x
        y_folding_times = batch.y  # log10 folding times, (seq_len, n_sequences)

        for col, training_point in enumerate(training_points, start=1):
            ax = axes[row, col]

            pred_handles = []
            context_handles = []

            for batch_index in range(len(paths)):
                set_seed(123)
                train_indices = torch.randperm(seq_len)[:training_point]

                train_x = x[train_indices, batch_index]
                train_y_folding_times = y_folding_times[train_indices, batch_index]
                test_x = x[:, batch_index]

                with torch.no_grad():
                    logits = trained_model(
                        train_x[:, None],
                        train_y_folding_times[:, None],
                        test_x[:, None],
                    )

                train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

                # Evaluation grid in log10 time, from one decade below the
                # earliest to 0.75 decades above the latest context time.
                linspace_extended = torch.linspace(
                    train_y_folding_times_sorted[0] - 1,
                    train_y_folding_times_sorted[-1] + 0.75,
                    1000,
                )
                pred_cdf_linspace_extended = (
                    trained_model.criterion.cdf(logits, linspace_extended)
                )[0][0]

                chosen_color = colors[batch_index % len(colors)]

                pred_handles.append(
                    ax.scatter(
                        10**linspace_extended,
                        pred_cdf_linspace_extended,
                        marker="o",
                        label=rf"$\phi_{batch_index}$ KinPFN",
                        color=chosen_color,
                    )
                )
                context_handles.append(
                    ax.scatter(
                        10**train_y_folding_times_sorted,
                        torch.zeros_like(train_y_folding_times_sorted),
                        color=chosen_color,
                        marker="x",
                        label=rf"$\phi_{batch_index}$ Context Times",
                    )
                )

            ax.set_xscale("log")
            ax.set_xlabel("Time", fontsize=label_fontsize)
            _draw_markers(ax, pred_markers.get(name, []))
            ax.tick_params(axis="both", which="major", labelsize=fontsize)
            ax.grid(True)
            ax.set_axisbelow(True)

            # ax.legend() replaces ax.legend_, so the first legend must be kept
            # as an explicit artist. The second one stays as ax.legend_; adding
            # it as an artist too would draw it twice.
            ax.add_artist(
                ax.legend(handles=pred_handles, loc="upper left", fontsize=fontsize)
            )
            ax.legend(
                handles=context_handles,
                loc="lower right",
                bbox_to_anchor=(1.0, 0.075),
                fontsize=fontsize,
            )

    fig.tight_layout()
    plt.show()

In [20]:
def combined_gt_and_pred_plot_appendix(trained_model, training_points : list[int], candidate_ids: list[str], seed: int = None,times=1000, ):
    """Appendix figure: ground-truth Kinfold CDFs, KinPFN approximations and metrics.

    The figure has one row per entry of ``candidate_ids``. For candidate
    ``name``, every ``*.csv`` file in ``./data/analysis_<name>`` (sorted by
    path) is one sequence, labelled ``phi_j``.

    * Column 0 shows the empirical CDF of the first ``times`` folding times of
      each file, annotated with the MFE structure.
    * Column ``j + 1`` shows the KinPFN CDF predicted from
      ``training_points[j]`` context times (circles), plus those context times
      (crosses at y = 0).

    For every sequence and context size, the following are printed: the number
    of context times, the MAE between predicted and empirical CDF, and the mean
    NLL from ``trained_model.criterion.forward``.

    Args:
        trained_model: KinPFN model. It is called as
            ``trained_model(train_x, train_y, test_x)`` and must provide
            ``criterion.cdf(logits, y)`` and ``criterion.forward(logits=, y=)``.
        training_points: Context-set sizes; one prediction column per entry.
        candidate_ids: Candidate names (data sub-directory suffixes). ``"A"``
            and ``"B"`` additionally get dashed reference lines.
        seed: Passed to ``set_seed`` before plotting when not ``None``.
        times: Maximum number of CSV rows read for the ground-truth column.

    NOTE(review): the context indices are drawn right after ``set_seed(123)``
    for every sequence and context size. All of them therefore share the same
    permutation, and ``seed`` does not change which context times are chosen.
    NOTE(review): MAE/NLL are evaluated on all ``seq_len`` times of a sequence,
    including the context times. They are therefore not held-out metrics.
    """
    if seed is not None:
        # torch.manual_seed(None) raises, so only seed when a value is given.
        set_seed(seed)

    seq_len = 1000
    width = 10 * (len(training_points) + 1)
    colors = plt.cm.tab10.colors

    fontsize = 18
    label_fontsize = 22

    # squeeze=False keeps ``axes`` 2-D for every grid shape (1x1, 1xN, Nx1, ...).
    # No layout engine is set here: fig.tight_layout() below decides the layout.
    # Combining it with layout="constrained" only triggered a matplotlib warning.
    fig, axes = plt.subplots(
        nrows=len(candidate_ids),
        ncols=len(training_points) + 1,
        figsize=(width, 9),
        squeeze=False,
    )
    fig.set_dpi(300)

    # Dashed reference levels per candidate: (y, label_x, label_y, label).
    gt_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.1, 0.75, "70% folded")],
        "B": [(0.5, 0.1, 0.55, "50% folded"), (0.2, 0.1, 0.25, "20% folded")],
    }
    # The prediction panels put candidate A's 70% label further to the right.
    pred_markers = {
        "A": [(0.25, 0.1, 0.3, "25% folded"), (0.7, 0.3, 0.75, "70% folded")],
        "B": gt_markers["B"],
    }

    def _csv_paths(name):
        # Sorted so that sequence numbering/colours do not depend on the filesystem.
        candidate_dir = f"./data/analysis_{name}"
        return sorted(
            os.path.join(candidate_dir, f)
            for f in os.listdir(candidate_dir)
            if f.endswith(".csv")
        )

    def _draw_markers(ax, markers):
        # Labels use axes coordinates for x and data coordinates for y.
        for level, label_x, label_y, label in markers:
            ax.axhline(y=level, color="black", linestyle="--")
            ax.text(
                label_x,
                label_y,
                label,
                color="black",
                transform=ax.get_yaxis_transform(),
                size=fontsize,
            )

    for row, name in enumerate(candidate_ids):
        paths = _csv_paths(name)

        # ---- Column 0: empirical ground-truth CDFs ----
        ax = axes[row, 0]

        structures = []
        for j, path in enumerate(paths):
            data = pd.read_csv(
                path,
                usecols=[0, 2, 4],
                names=["structure", "folding_time", "sequence"],
                header=None,
                nrows=times,
            )
            structures.append(data["structure"].values[0])
            sequence = data["sequence"].values[0]

            sorted_folding_times = np.sort(data["folding_time"].values)
            cdf = np.arange(1, len(sorted_folding_times) + 1) / len(
                sorted_folding_times
            )
            ax.scatter(sorted_folding_times, cdf, label=rf"$\phi_{j}$: {sequence}")

        _draw_markers(ax, gt_markers.get(name, []))

        # One box per distinct MFE structure. Identical structures used to be
        # stacked once per file at the same position.
        for structure in dict.fromkeys(structures):
            anchored_text = AnchoredText(
                f"MFE Structure: {structure}",
                loc="lower right",
                prop=dict(size=fontsize, color="black"),
                frameon=True,
            )
            anchored_text.patch.set_boxstyle("round,pad=0.3,rounding_size=0.2")
            anchored_text.patch.set_edgecolor("black")
            anchored_text.patch.set_linewidth(1.5)
            anchored_text.set_bbox_to_anchor((1.0, 0.05, 0, 0), transform=ax.transAxes)
            ax.add_artist(anchored_text)

        ax.set_xlabel("Time", fontsize=label_fontsize)
        ax.set_xscale("log")
        ax.set_ylabel("Cumulative Population Probability", fontsize=label_fontsize)
        ax.tick_params(axis="both", which="major", labelsize=fontsize)
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.21), ncol=1, fontsize=fontsize)
        ax.grid(True)
        ax.set_axisbelow(True)

        # ---- Columns 1..: KinPFN approximations + metrics ----
        batch, _ = get_batch_kinfold_folding_times(paths=paths, seq_len=seq_len)
        x = batch.x
        y_folding_times = batch.y  # log10 folding times, (seq_len, n_sequences)

        for col, training_point in enumerate(training_points, start=1):
            ax = axes[row, col]

            pred_handles = []
            context_handles = []

            for batch_index in range(len(paths)):
                set_seed(123)
                train_indices = torch.randperm(seq_len)[:training_point]

                train_x = x[train_indices, batch_index]
                train_y_folding_times = y_folding_times[train_indices, batch_index]
                test_x = x[:, batch_index]
                test_y_folding_times = y_folding_times[:, batch_index]

                with torch.no_grad():
                    logits = trained_model(
                        train_x[:, None],
                        train_y_folding_times[:, None],
                        test_x[:, None],
                    )

                # Empirical CDF of all seq_len times (context points included).
                test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
                ground_truth_cdf = torch.arange(
                    1, len(test_y_folding_times_sorted) + 1
                ) / len(test_y_folding_times_sorted)

                train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

                pred_cdf_original = (
                    trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
                )[0][0]

                # Evaluation grid in log10 time, from one decade below the
                # earliest to 0.75 decades above the latest context time.
                linspace_extended = torch.linspace(
                    train_y_folding_times_sorted[0] - 1,
                    train_y_folding_times_sorted[-1] + 0.75,
                    1000,
                )
                pred_cdf_linspace_extended = (
                    trained_model.criterion.cdf(logits, linspace_extended)
                )[0][0]

                # Metrics: MAE between predicted and empirical CDF, mean NLL.
                mae = np.abs(pred_cdf_original - ground_truth_cdf).mean()
                mean_nll_loss = trained_model.criterion.forward(
                    logits=logits, y=test_y_folding_times_sorted
                ).mean()

                print(f"#KinPFN Context Times: {len(train_y_folding_times)}")
                print(f"MAE: {mae:.4f}")
                print(f"NLL: {mean_nll_loss:.4f}")

                chosen_color = colors[batch_index % len(colors)]

                pred_handles.append(
                    ax.scatter(
                        10**linspace_extended,
                        pred_cdf_linspace_extended,
                        marker="o",
                        label=rf"$\phi_{batch_index}$ KinPFN",
                        color=chosen_color,
                    )
                )
                context_handles.append(
                    ax.scatter(
                        10**train_y_folding_times_sorted,
                        torch.zeros_like(train_y_folding_times_sorted),
                        color=chosen_color,
                        marker="x",
                        label=rf"$\phi_{batch_index}$ Context Times",
                    )
                )

            ax.set_xscale("log")
            ax.set_xlabel("Time", fontsize=label_fontsize)
            _draw_markers(ax, pred_markers.get(name, []))
            ax.tick_params(axis="both", which="major", labelsize=fontsize)
            ax.grid(True)
            ax.set_axisbelow(True)

            # ax.legend() replaces ax.legend_, so the first legend must be kept
            # as an explicit artist. The second one stays as ax.legend_; adding
            # it as an artist too would draw it twice.
            ax.add_artist(
                ax.legend(handles=pred_handles, loc="upper left", fontsize=fontsize)
            )
            ax.legend(
                handles=context_handles,
                loc="lower right",
                bbox_to_anchor=(1.0, 0.075),
                fontsize=fontsize,
            )

    fig.tight_layout()
    plt.show()

In [None]:

# Render the main-paper figure (10 context times) and the appendix figure
# (10 / 25 / 50 context times) for candidate "A" with a fixed seed.
seed = 45665
for plot_fn, context_sizes in (
    (combined_gt_and_pred_plot_main_paper, [10]),
    (combined_gt_and_pred_plot_appendix, [10, 25, 50]),
):
    plot_fn(trained_model, context_sizes, ["A"], seed=seed)