In [74]:
import os
import random
import torch
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from matplotlib.lines import Line2D
import pandas as pd
import polars as pl
import numpy as np
from pathlib import Path
from collections import defaultdict

from kinpfn.priors import Batch
from kinpfn.model import KINPFN

from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.neighbors import KernelDensity
from scipy.integrate import cumulative_trapezoid

In [75]:
def set_seed(seed=123):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed``.

    Args:
        seed: Value handed unchanged to each of the three seeding functions.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seed_generator in (torch.manual_seed, np.random.seed, random.seed):
        seed_generator(seed)

In [76]:
def get_dataset_size(val_set_dir):
    """Count the files ending in ``.csv`` anywhere below ``val_set_dir``.

    Args:
        val_set_dir: Root directory that is walked recursively with ``os.walk``.

    Returns:
        The number of file names with a ``.csv`` suffix.
    """
    return sum(
        1
        for _, _, filenames in os.walk(val_set_dir)
        for filename in filenames
        if filename.endswith(".csv")
    )


def get_batch_testing_folding_times(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Load log10 folding times from every CSV below ``val_set_dir`` into one batch.

    Each CSV is read without a header and limited to the first ``seq_len``
    rows. The third column is read as the folding times and the first entry
    of the fifth column as the sequence string. Times outside
    ``[1e-6, 1e15]`` are dropped; files with no remaining time are skipped.
    Files with fewer than ``seq_len`` remaining times are padded by repeating
    the sorted times (so a padded column is no longer globally sorted).

    Args:
        val_set_dir: Root directory that is walked recursively for ``.csv`` files.
        seq_len: Number of time points stored per file.
        num_features: Size of the last dimension of ``x`` (``x`` is all zeros).
        **kwargs: Accepted and ignored.

    Returns:
        A tuple ``(batch, length_order, file_names_in_order)`` where ``batch``
        is ``Batch(x=x, y=y, target_y=y)`` with ``x`` of shape
        ``(seq_len, n_csv_files, num_features)`` and ``y`` of shape
        ``(seq_len, n_csv_files)`` holding log10 times. ``length_order[k]``
        and ``file_names_in_order[k]`` describe column ``k``. Columns beyond
        ``len(file_names_in_order)`` (one per skipped file) were never
        written, so they hold ``log10(0) = -inf``.
    """
    dataset_size = get_dataset_size(val_set_dir)

    # x carries no information (all zeros); y receives one column per stored file.
    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue

            path = os.path.join(subdir, file)
            data = pl.read_csv(
                path,
                has_header=False,
                columns=[2, 4],
                n_rows=seq_len,
            )

            sorted_folding_times = np.sort(data["column_3"].to_numpy())

            # Filter out points where x > 10^15 and x < 10^-6
            valid_indices = (sorted_folding_times <= 10**15) & (
                sorted_folding_times >= 10**-6
            )
            sorted_folding_times = sorted_folding_times[valid_indices]

            current_seq_len = len(sorted_folding_times)
            if current_seq_len <= 0:
                continue

            # Record metadata only once the file is known to get a column;
            # appending before the skip check would shift every later entry
            # relative to its column index.
            sequence = data["column_5"][0]
            length_order.append(len(sequence))
            file_names_in_order.append(file)

            if current_seq_len < seq_len:
                # Oversample: repeat the sorted times until seq_len is reached.
                repeat_factor = seq_len // current_seq_len + 1
                sorted_folding_times = np.tile(sorted_folding_times, repeat_factor)[
                    :seq_len
                ]
            else:
                sorted_folding_times = sorted_folding_times[:seq_len]

            # x is already zero-initialised, so only y needs to be written.
            y[:, batch_index] = torch.tensor(sorted_folding_times)
            batch_index += 1

    y = torch.log10(y)
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order

In [None]:
# Checkpoint of the trained KinPFN model, relative to this notebook.
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

# Every later cell needs the model. Raise instead of calling exit(): inside a
# notebook exit() asks the kernel to shut down, whereas an exception simply
# fails this cell and stops a "Run All".
if trained_model is None:
    raise RuntimeError("No trained model found!")
print("Load trained model!")

In [78]:
def plot_kinpfn_on_selected_testing_set_seq(trained_model, training_points, seed=None):
    """Plot KinPFN CDF predictions against the empirical CDF of two test sequences.

    The first two columns of the batch loaded from ``./data`` are plotted, one
    per row; each column of the figure uses one context size from
    ``training_points``. For every panel the model is queried 20 times with a
    fresh random context subset, and the mean and standard deviation of the
    predicted CDF over those repetitions are drawn.

    Args:
        trained_model: Callable as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf(logits, y)``.
        training_points: Sequence of context sizes, one figure column each.
        seed: Seed passed to ``set_seed``; a random one is drawn (and printed)
            when ``None``.
    """
    if seed is None:
        seed = random.randint(0, 100000)
    set_seed(seed=seed)
    print(f"Seed: {seed}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    num_evaluations = 2
    num_repetitions = 20
    val_set_dir = "./data"

    batch, _, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    width = 7.25 * len(training_points)
    # squeeze=False keeps `axes` two-dimensional even for a single training
    # point, so axes[row, col] below is always valid.
    fig, axes = plt.subplots(
        nrows=num_evaluations,
        ncols=len(training_points),
        figsize=(width, 12),
        layout="constrained",
        squeeze=False,
    )
    fig.set_dpi(300)

    for row in range(num_evaluations):
        batch_index = row

        # Everything that depends only on the sequence is computed once per row.
        test_x = x[:, batch_index].to(device)
        test_y_folding_times = y_folding_times[:, batch_index].to(device)
        test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)

        num_test_points = len(test_y_folding_times_sorted)
        ground_truth_cdf = (
            torch.arange(1, num_test_points + 1) / num_test_points
        ).numpy()
        ground_truth_times = 10 ** test_y_folding_times_sorted.cpu().numpy()

        # One evaluation grid per sequence, shared by all repetitions. The
        # mean/std below are taken point-wise across repetitions, which is
        # only meaningful when every repetition is evaluated at the same
        # abscissae. The grid spans the full sample (in log10 units) with the
        # same -3 / +1 margins as before and lives on the model's device.
        linspace_extended = torch.linspace(
            test_y_folding_times_sorted[0].item() - 3,
            test_y_folding_times_sorted[-1].item() + 1,
            1000,
            device=device,
        )
        grid_times = 10 ** linspace_extended.cpu().numpy()

        for col, training_point in enumerate(training_points):
            ax = axes[row, col]

            all_pred_cdf_list = []
            all_mae_list = []
            for _ in range(num_repetitions):
                train_indices = torch.randperm(seq_len)[:training_point]

                train_x = x[train_indices, batch_index].to(device)
                train_y_folding_times = y_folding_times[
                    train_indices, batch_index
                ].to(device)

                with torch.no_grad():
                    # Index with None to add the batch dimension the model is called with.
                    logits = trained_model(
                        train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                    )

                pred_cdf_original = (
                    trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
                )[0][0]
                pred_cdf_linspace_extended = (
                    trained_model.criterion.cdf(logits, linspace_extended)
                )[0][0]

                all_pred_cdf_list.append(pred_cdf_linspace_extended.cpu().numpy())
                # Compare on the CPU as NumPy arrays: the prediction may live
                # on the GPU while the ground-truth CDF is a CPU array.
                single_absolute_error = np.abs(
                    pred_cdf_original.cpu().numpy() - ground_truth_cdf
                )
                all_mae_list.append(single_absolute_error.mean())

            mean_pred_cdf = np.mean(all_pred_cdf_list, axis=0)
            std_pred_cdf = np.std(all_pred_cdf_list, axis=0)
            mean_mae = np.mean(all_mae_list)

            # The context markers show the subset drawn in the last repetition.
            train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)
            context_times = 10 ** train_y_folding_times_sorted.cpu().numpy()

            ax.scatter(
                ground_truth_times,
                ground_truth_cdf,
                color="#000000",
                marker="x",
                label="Target",
            )

            ax.plot(
                grid_times,
                mean_pred_cdf,
                color="#cc101fc7",
                marker=".",
                label="KinPFN " + r"(mean $\pm$ std)",
            )

            ax.fill_between(
                grid_times,
                mean_pred_cdf - std_pred_cdf,
                mean_pred_cdf + std_pred_cdf,
                color="#cc101fc7",
                alpha=0.5,
            )

            ax.scatter(
                context_times,
                np.zeros_like(context_times),
                color="blue",
                marker="o",
                label="Context First Passage Times",
            )

            ax.set_ylim(0, 1)
            ax.set_xscale("log")
            ax.set_xlabel("Time", fontsize=18)
            ax.tick_params(axis='both', which='major', labelsize=16)

            if col == 0:
                ax.set_ylabel(
                    "Cumulative Population Probability", fontsize=18
                )

            print(f"#KinPFN Context Times: {len(train_y_folding_times)}")
            print(f"MAE: {mean_mae:.4f}")

            ax.legend(fontsize=15)

    plt.tight_layout()
    plt.show()

In [79]:
def evaluate_model_on_testing_set_gmm_bgmm_kde(trained_model):
    """Print mean KinPFN MAE/NLL and mean KDE, GMM and DP-GMM NLL on the testing set.

    For each context size in a fixed list and each loaded sequence, a random
    subset of the sequence's log10 times is used as context. KinPFN, a
    Gaussian KDE, and GMM / DP-GMM models with 1 to 5 components are then
    scored on the full sorted sequence. Results are printed per context size;
    nothing is returned.

    Args:
        trained_model: Callable as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf`` and ``criterion.forward``.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    training_points = [10, 25, 50, 75, 100, 250, 500, 750, 1000]

    val_set_dir = "../../../../kinpfn_testing_set"

    batch, _, file_names_in_order = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # The batch has one column per CSV file, so its width is the dataset size
    # (no second directory walk needed).
    dataset_size = y_folding_times.shape[1]
    print(f"Dataset Size: {dataset_size}")

    # The loader returns one file name per entry; evaluate exactly those.
    num_examples = len(file_names_in_order)

    # Iterate over different values of n_components for GMM and DP-GMM
    n_components_list = [1, 2, 3, 4, 5]

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        mae_losses = []
        mean_nll_losses = []
        kde_nll_losses = []
        gmm_results = {n: [] for n in n_components_list}
        bgmm_results = {n: [] for n in n_components_list}

        for batch_index in range(num_examples):
            train_indices = torch.randperm(seq_len)[:training_point]

            # Create the training and test data
            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_folding_times = y_folding_times[:, batch_index].to(device)

            with torch.no_grad():
                # we add our batch dimension, as our transformer always expects that
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )

            # Empirical CDF of the full sequence: the k-th smallest time has CDF k/n.
            test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
            num_test_points = len(test_y_folding_times_sorted)
            ground_truth_cdf = torch.arange(1, num_test_points + 1) / num_test_points

            pred_cdf_original = (
                trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
            )[0][0]

            # MAE: bring the prediction to the CPU first, because the
            # ground-truth CDF is a CPU tensor and mixing devices raises.
            mae = torch.abs(pred_cdf_original.cpu() - ground_truth_cdf).mean()
            mae_losses.append(mae)

            # NLL
            nll_loss = trained_model.criterion.forward(
                logits=logits, y=test_y_folding_times_sorted
            )
            mean_nll_losses.append(nll_loss.mean().cpu())

            # scikit-learn works on NumPy arrays, which requires CPU data.
            train_y_np = train_y_folding_times.reshape(-1, 1).cpu().numpy()
            test_y_sorted_np = test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy()

            # KDE
            kde = KernelDensity(kernel='gaussian', bandwidth=0.3520031472796679).fit(train_y_np)
            log_likelihood = kde.score_samples(test_y_sorted_np)  # per-point log-density
            kde_nll_losses.append(-np.mean(log_likelihood))

            # Competitor GMM and DP-GMM for multiple n_components
            for n_components in n_components_list:
                # GMM
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_y_np)
                gmm_results[n_components].append(-gmm.score(test_y_sorted_np))

                # DP-GMM
                bgmm = BayesianGaussianMixture(
                    n_components=n_components,
                    weight_concentration_prior_type="dirichlet_process",
                    weight_concentration_prior=0.0009794696670695395,
                    max_iter=100000,
                )
                bgmm.fit(train_y_np)
                bgmm_results[n_components].append(-bgmm.score(test_y_sorted_np))

        # Calculate mean MAE and NLL losses for the model
        mae_losses = torch.stack(mae_losses)
        mae_losses = mae_losses.mean()
        print(f"KinPFN Mean MAE: {mae_losses}")
        mean_nll_losses = torch.stack(mean_nll_losses)
        mean_nll_losses = mean_nll_losses.mean()
        print(f"KinPFN Mean NLL Loss: {mean_nll_losses}")

        # Calculate mean KDE NLL loss for the model
        kde_nll_losses = torch.tensor(kde_nll_losses)
        kde_nll_losses = kde_nll_losses.mean()
        print(f"KDE Mean NLL Loss: {kde_nll_losses}")

        # Calculate and print mean NLL losses for GMM and DP-GMM for each n_component
        for n_components in n_components_list:
            gmm_mean_nll_losses = torch.tensor(gmm_results[n_components])
            gmm_mean_nll_losses = gmm_mean_nll_losses.mean()
            print(f"GMM Mean NLL Loss (n_components={n_components}): {gmm_mean_nll_losses}")

            bgmm_mean_nll_losses = torch.tensor(bgmm_results[n_components])
            bgmm_mean_nll_losses = bgmm_mean_nll_losses.mean()
            print(f"DP-GMM Mean NLL Loss (n_components={n_components}): {bgmm_mean_nll_losses}")


In [80]:
# KinPFN, KDE, GMM and DP-GMM on the test set
#evaluate_model_on_testing_set_gmm_bgmm_kde(trained_model)

In [81]:
def evaluate_model_on_testing_set_gmm_bgmm_nll_mae(trained_model):
    """Print mean MAE and NLL for KinPFN, KDE, GMM and DP-GMM on the testing set.

    For each context size in a fixed list and each loaded sequence, a random
    subset of the sequence's log10 times is used as context. KinPFN, a
    Gaussian KDE, and GMM / DP-GMM models with 2 to 5 components are scored on
    the full sorted sequence. Results are printed per context size; nothing is
    returned.

    Args:
        trained_model: Callable as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf`` and ``criterion.forward``.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    training_points = [10, 25, 50, 75, 100, 250, 500, 750, 1000]

    val_set_dir = "../../../../kinpfn_testing_set"

    batch, _, file_names_in_order = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # The batch has one column per CSV file, so its width is the dataset size
    # (no second directory walk needed).
    dataset_size = y_folding_times.shape[1]
    print(f"Dataset Size: {dataset_size}")

    # The loader returns one file name per entry; evaluate exactly those.
    num_examples = len(file_names_in_order)

    n_components_list = [2, 3, 4, 5]

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        mae_losses = []
        mean_nll_losses = []
        kde_mae_losses = []
        kde_nll_losses = []
        gmm_mae_losses = {n: [] for n in n_components_list}
        bgmm_mae_losses = {n: [] for n in n_components_list}
        gmm_nll_losses = {n: [] for n in n_components_list}
        bgmm_nll_losses = {n: [] for n in n_components_list}

        for batch_index in range(num_examples):
            train_indices = torch.randperm(seq_len)[:training_point]

            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_folding_times = y_folding_times[:, batch_index].to(device)

            with torch.no_grad():
                # we add our batch dimension, as our transformer always expects that
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )

            # Empirical CDF of the full sequence: the k-th smallest time has CDF k/n.
            test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
            num_test_points = len(test_y_folding_times_sorted)
            ground_truth_cdf = torch.arange(1, num_test_points + 1) / num_test_points
            ground_truth_cdf_np = ground_truth_cdf.numpy()

            pred_cdf_original = (
                trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
            )[0][0]

            # MAE for KinPFN: bring the prediction to the CPU first, because
            # the ground-truth CDF is a CPU tensor and mixing devices raises.
            mae = torch.abs(pred_cdf_original.cpu() - ground_truth_cdf).mean()
            mae_losses.append(mae.item())

            # NLL for KinPFN
            nll_loss = trained_model.criterion.forward(
                logits=logits, y=test_y_folding_times_sorted
            )
            mean_nll_losses.append(nll_loss.mean().item())

            # scikit-learn works on NumPy arrays, which requires CPU data.
            train_y_np = train_y_folding_times.reshape(-1, 1).cpu().numpy()
            test_y_sorted_np = test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy()

            # NOTE(review): the baseline CDFs below are the normalised
            # cumulative sum of density values at the sorted test points, not
            # an integral of the fitted density — confirm this approximation
            # is intended.

            # KDE MAE and NLL
            kde = KernelDensity(kernel='gaussian', bandwidth=0.3520031472796679).fit(train_y_np)
            kde_log_density = kde.score_samples(test_y_sorted_np)
            kde_pdf = np.exp(kde_log_density)
            kde_cdf = np.cumsum(kde_pdf) / np.sum(kde_pdf)
            kde_mae_losses.append(np.abs(kde_cdf - ground_truth_cdf_np).mean())
            kde_nll_losses.append(-np.mean(kde_log_density))

            # GMM and DP-GMM MAE and NLL
            for n_components in n_components_list:
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_y_np)
                gmm_pdf = np.exp(gmm.score_samples(test_y_sorted_np))
                gmm_cdf = np.cumsum(gmm_pdf) / np.sum(gmm_pdf)
                gmm_mae_losses[n_components].append(
                    np.abs(gmm_cdf - ground_truth_cdf_np).mean()
                )
                gmm_nll_losses[n_components].append(-gmm.score(test_y_sorted_np))

                bgmm = BayesianGaussianMixture(
                    n_components=n_components,
                    weight_concentration_prior_type="dirichlet_process",
                    weight_concentration_prior=0.0009794696670695395,
                    max_iter=100000
                )
                bgmm.fit(train_y_np)
                bgmm_pdf = np.exp(bgmm.score_samples(test_y_sorted_np))
                bgmm_cdf = np.cumsum(bgmm_pdf) / np.sum(bgmm_pdf)
                bgmm_mae_losses[n_components].append(
                    np.abs(bgmm_cdf - ground_truth_cdf_np).mean()
                )
                bgmm_nll_losses[n_components].append(-bgmm.score(test_y_sorted_np))

        # Calculate and print losses
        mae_losses = torch.tensor(mae_losses).mean()
        print(f"KinPFN Mean MAE Loss: {mae_losses}")

        mean_nll_losses = torch.tensor(mean_nll_losses).mean()
        print(f"KinPFN Mean NLL Loss: {mean_nll_losses}")

        kde_mae_losses = torch.tensor(kde_mae_losses).mean()
        print(f"KDE Mean MAE: {kde_mae_losses}")

        kde_nll_losses = torch.tensor(kde_nll_losses).mean()
        print(f"KDE Mean NLL Loss: {kde_nll_losses}")

        for n_components in n_components_list:
            gmm_mae_losses_tensor = torch.tensor(gmm_mae_losses[n_components]).mean()
            print(f"GMM Mean MAE (n_components={n_components}): {gmm_mae_losses_tensor}")

            bgmm_mae_losses_tensor = torch.tensor(bgmm_mae_losses[n_components]).mean()
            print(f"DP-GMM Mean MAE (n_components={n_components}): {bgmm_mae_losses_tensor}")

            gmm_nll_losses_tensor = torch.tensor(gmm_nll_losses[n_components]).mean()
            print(f"GMM Mean NLL Loss (n_components={n_components}): {gmm_nll_losses_tensor}")

            bgmm_nll_losses_tensor = torch.tensor(bgmm_nll_losses[n_components]).mean()
            print(f"DP-GMM Mean NLL Loss (n_components={n_components}): {bgmm_nll_losses_tensor}")

In [82]:
#evaluate_model_on_testing_set_gmm_bgmm_nll_mae(trained_model)

In [83]:
def evaluate_model_on_testing_set_gmm_bgmm_rebuttal_ks(trained_model):
    """Print the mean maximum CDF deviation (KS) for KinPFN, KDE, GMM and DP-GMM.

    For each context size in a fixed list and each loaded sequence, a random
    subset of the sequence's log10 times is used as context. The largest
    absolute difference between each model's CDF and the empirical CDF of the
    full sorted sequence is recorded, and the mean per context size is
    printed. Nothing is returned.

    Args:
        trained_model: Callable as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf``.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    training_points = [10, 25, 50, 75, 100, 250, 500, 750, 1000]

    val_set_dir = "../../../../kinpfn_testing_set"

    batch, _, file_names_in_order = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # The batch has one column per CSV file, so its width is the dataset size
    # (no second directory walk needed).
    dataset_size = y_folding_times.shape[1]
    print(f"Dataset Size: {dataset_size}")

    # The loader returns one file name per entry; evaluate exactly those.
    num_examples = len(file_names_in_order)
    n_components_list = [2, 3, 4, 5]

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        kinpfn_ks_losses = []
        kde_ks_losses = []
        gmm_ks_losses = {n: [] for n in n_components_list}
        bgmm_ks_losses = {n: [] for n in n_components_list}

        for batch_index in range(num_examples):
            train_indices = torch.randperm(seq_len)[:training_point]

            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_folding_times = y_folding_times[:, batch_index].to(device)

            with torch.no_grad():
                # Index with None to add the batch dimension the model is called with.
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )

            # Empirical CDF of the full sequence: the k-th smallest time has CDF k/n.
            test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
            num_test_points = len(test_y_folding_times_sorted)
            ground_truth_cdf_np = (
                torch.arange(1, num_test_points + 1) / num_test_points
            ).numpy()

            pred_cdf_original = (
                trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
            )[0][0]

            # KS for KinPFN
            ks_loss = np.max(np.abs(pred_cdf_original.cpu().numpy() - ground_truth_cdf_np))
            kinpfn_ks_losses.append(ks_loss)

            # scikit-learn works on NumPy arrays, which requires CPU data;
            # passing the (possibly CUDA) tensors directly fails on a GPU.
            train_y_np = train_y_folding_times.reshape(-1, 1).cpu().numpy()
            test_y_sorted_np = test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy()

            # NOTE(review): the baseline CDFs below are the normalised
            # cumulative sum of density values at the sorted test points, not
            # an integral of the fitted density — confirm this approximation
            # is intended.

            # KDE
            kde = KernelDensity(kernel="gaussian", bandwidth=0.3520031472796679).fit(
                train_y_np
            )
            kde_pdf = np.exp(kde.score_samples(test_y_sorted_np))
            kde_cdf = np.cumsum(kde_pdf) / np.sum(kde_pdf)
            kde_ks_losses.append(np.max(np.abs(kde_cdf - ground_truth_cdf_np)))

            # GMM and DP-GMM
            for n_components in n_components_list:
                # GMM
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_y_np)
                gmm_pdf = np.exp(gmm.score_samples(test_y_sorted_np))
                gmm_cdf = np.cumsum(gmm_pdf) / np.sum(gmm_pdf)
                gmm_ks_losses[n_components].append(
                    np.max(np.abs(gmm_cdf - ground_truth_cdf_np))
                )

                # DP-GMM
                bgmm = BayesianGaussianMixture(
                    n_components=n_components,
                    weight_concentration_prior_type="dirichlet_process",
                    weight_concentration_prior=0.0009794696670695395,
                    max_iter=100000,
                )
                bgmm.fit(train_y_np)
                bgmm_pdf = np.exp(bgmm.score_samples(test_y_sorted_np))
                bgmm_cdf = np.cumsum(bgmm_pdf) / np.sum(bgmm_pdf)
                bgmm_ks_losses[n_components].append(
                    np.max(np.abs(bgmm_cdf - ground_truth_cdf_np))
                )

        print(f"KinPFN Mean KS Loss: {torch.tensor(kinpfn_ks_losses).mean()}")
        print(f"KDE Mean KS Loss: {torch.tensor(kde_ks_losses).mean()}")
        for n_components in n_components_list:
            gmm_mean_ks = torch.tensor(gmm_ks_losses[n_components]).mean()
            print(f"GMM Mean KS Loss (n_components={n_components}): {gmm_mean_ks}")
            bgmm_mean_ks = torch.tensor(bgmm_ks_losses[n_components]).mean()
            print(f"DP-GMM Mean KS Loss (n_components={n_components}): {bgmm_mean_ks}")


In [84]:
#evaluate_model_on_testing_set_gmm_bgmm_rebuttal_ks(trained_model)

In [None]:
def get_batch_testing_folding_times_with_seq_struc(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Load log10 folding times plus per-file metadata from every CSV below ``val_set_dir``.

    Each CSV is read without a header and limited to the first ``seq_len``
    rows. The third column is read as the folding times; the first entries of
    the fifth, first and second columns are returned as sequence, structure
    and MFE respectively. Times outside ``[1e-6, 1e15]`` are dropped; files
    with no remaining time are skipped. Files with fewer than ``seq_len``
    remaining times are padded by repeating the sorted times (so a padded
    column is no longer globally sorted).

    Args:
        val_set_dir: Root directory that is walked recursively for ``.csv`` files.
        seq_len: Number of time points stored per file.
        num_features: Size of the last dimension of ``x`` (``x`` is all zeros).
        **kwargs: Accepted and ignored.

    Returns:
        A tuple ``(batch, length_order, file_names_in_order, sequence_order,
        structure_order, mfe_order)`` where ``batch`` is
        ``Batch(x=x, y=y, target_y=y)`` with ``x`` of shape
        ``(seq_len, n_csv_files, num_features)`` and ``y`` of shape
        ``(seq_len, n_csv_files)`` holding log10 times. Entry ``k`` of every
        list describes column ``k``. Columns beyond the length of the lists
        (one per skipped file) were never written, so they hold
        ``log10(0) = -inf``.
    """
    dataset_size = get_dataset_size(val_set_dir)

    # x carries no information (all zeros); y receives one column per stored file.
    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    sequence_order = []
    structure_order = []
    mfe_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue

            path = os.path.join(subdir, file)
            data = pl.read_csv(
                path,
                has_header=False,
                columns=[0,1, 2, 4],
                n_rows=seq_len,
            )

            sorted_folding_times = np.sort(data["column_3"].to_numpy())

            # Filter out points where x > 10^15 and x < 10^-6
            valid_indices = (sorted_folding_times <= 10**15) & (
                sorted_folding_times >= 10**-6
            )
            sorted_folding_times = sorted_folding_times[valid_indices]

            current_seq_len = len(sorted_folding_times)
            if current_seq_len <= 0:
                continue

            # Record metadata only once the file is known to get a column;
            # appending before the skip check would shift every later entry
            # relative to its column index.
            sequence = data["column_5"][0]
            length_order.append(len(sequence))
            sequence_order.append(sequence)
            structure_order.append(data["column_1"][0])
            mfe_order.append(data["column_2"][0])
            file_names_in_order.append(file)

            if current_seq_len < seq_len:
                # Oversample: repeat the sorted times until seq_len is reached.
                repeat_factor = seq_len // current_seq_len + 1
                sorted_folding_times = np.tile(sorted_folding_times, repeat_factor)[
                    :seq_len
                ]
            else:
                sorted_folding_times = sorted_folding_times[:seq_len]

            # x is already zero-initialised, so only y needs to be written.
            y[:, batch_index] = torch.tensor(sorted_folding_times)
            batch_index += 1

    y = torch.log10(y)
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order, sequence_order, structure_order, mfe_order



def evaluate_model_on_testing_set_individual_metrics(trained_model):
    """Collect per-sequence KS, MAE and NLL for KinPFN, KDE, GMM and DP-GMM.

    For each loaded sequence, 25 randomly chosen log10 times are used as
    context. KinPFN, a Gaussian KDE, and GMM / DP-GMM models with 2 to 5
    components are scored against the full sorted sequence.

    Args:
        trained_model: Callable as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf`` and ``criterion.forward``.

    Returns:
        A list with one dict per sequence and context size. Keys are
        ``example_index``, ``training_point``, ``kinpfn``, ``kde``, ``gmm``,
        ``bgmm``, ``sequence``, ``stop_structure`` and ``mfe``. ``kinpfn`` and
        ``kde`` map to ``{"KS", "MAE", "NLL"}`` dicts; ``gmm`` and ``bgmm``
        map the component count to such a dict.
    """
    set_seed(seed=123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    training_points = [25]

    val_set_dir = "../../../../kinpfn_testing_set"
    batch, _, _, sequence_order, structure_order, mfe_order = get_batch_testing_folding_times_with_seq_struc(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # The batch has one column per CSV file, so its width is the dataset size
    # (no second directory walk needed).
    dataset_size = y_folding_times.shape[1]
    print(f"Dataset Size: {dataset_size}")

    # Each result looks up its metadata by index, so iterate over exactly the
    # entries the loader returned metadata for.
    num_examples = len(sequence_order)
    n_components_list = [2, 3, 4, 5]

    # Store individual results
    all_results = []

    for training_point in training_points:
        print(f"Training Point: {training_point}")

        for batch_index in range(num_examples):
            result = {"example_index": batch_index, "training_point": training_point}

            train_indices = torch.randperm(seq_len)[:training_point]
            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_folding_times = y_folding_times[:, batch_index].to(device)

            with torch.no_grad():
                # Index with None to add the batch dimension the model is called with.
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )

            # Empirical CDF of the full sequence: the k-th smallest time has CDF k/n.
            test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
            num_test_points = len(test_y_folding_times_sorted)
            ground_truth_cdf = torch.arange(1, num_test_points + 1) / num_test_points
            ground_truth_cdf_np = ground_truth_cdf.numpy()

            pred_cdf_original = (
                trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
            )[0][0]

            # Metrics for KinPFN. The prediction is brought to the CPU first,
            # because the ground-truth CDF is a CPU tensor and mixing devices
            # raises.
            pred_cdf_cpu = pred_cdf_original.cpu()
            kinpfn_ks = np.max(np.abs(pred_cdf_cpu.numpy() - ground_truth_cdf_np))
            kinpfn_mae = torch.abs(pred_cdf_cpu - ground_truth_cdf).mean()
            kinpfn_nll = trained_model.criterion.forward(
                logits=logits, y=test_y_folding_times_sorted
            ).mean().item()

            result["kinpfn"] = {"KS": kinpfn_ks, "MAE": kinpfn_mae.item(), "NLL": kinpfn_nll}

            # scikit-learn works on NumPy arrays, which requires CPU data.
            train_y_np = train_y_folding_times.reshape(-1, 1).cpu().numpy()
            test_y_sorted_np = test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy()

            # NOTE(review): the baseline CDFs below are the normalised
            # cumulative sum of density values at the sorted test points, not
            # an integral of the fitted density — confirm this approximation
            # is intended.

            # KDE Metrics
            kde = KernelDensity(kernel="gaussian", bandwidth=0.3520031472796679).fit(train_y_np)
            kde_log_density = kde.score_samples(test_y_sorted_np)
            kde_pdf = np.exp(kde_log_density)
            kde_cdf = np.cumsum(kde_pdf) / np.sum(kde_pdf)
            kde_abs_error = np.abs(kde_cdf - ground_truth_cdf_np)
            kde_ks = np.max(kde_abs_error)
            kde_mae = kde_abs_error.mean()
            kde_nll = -np.mean(kde_log_density)

            result["kde"] = {"KS": kde_ks, "MAE": kde_mae, "NLL": kde_nll}

            # GMM and DP-GMM Metrics
            result["gmm"] = {}
            result["bgmm"] = {}

            for n_components in n_components_list:
                # GMM
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_y_np)
                gmm_pdf = np.exp(gmm.score_samples(test_y_sorted_np))
                gmm_cdf = np.cumsum(gmm_pdf) / np.sum(gmm_pdf)
                gmm_abs_error = np.abs(gmm_cdf - ground_truth_cdf_np)
                gmm_ks = np.max(gmm_abs_error)
                gmm_mae = gmm_abs_error.mean()
                gmm_nll = -gmm.score(test_y_sorted_np)

                result["gmm"][n_components] = {"KS": gmm_ks, "MAE": gmm_mae, "NLL": gmm_nll}

                # DP-GMM
                bgmm = BayesianGaussianMixture(
                    n_components=n_components,
                    weight_concentration_prior_type="dirichlet_process",
                    weight_concentration_prior=0.0009794696670695395,
                    max_iter=100000,
                )
                bgmm.fit(train_y_np)
                bgmm_pdf = np.exp(bgmm.score_samples(test_y_sorted_np))
                bgmm_cdf = np.cumsum(bgmm_pdf) / np.sum(bgmm_pdf)
                bgmm_abs_error = np.abs(bgmm_cdf - ground_truth_cdf_np)
                bgmm_ks = np.max(bgmm_abs_error)
                bgmm_mae = bgmm_abs_error.mean()
                bgmm_nll = -bgmm.score(test_y_sorted_np)

                result["bgmm"][n_components] = {"KS": bgmm_ks, "MAE": bgmm_mae, "NLL": bgmm_nll}

            # Append sequence and stop structure and mfe
            result["sequence"] = sequence_order[batch_index]
            result["stop_structure"] = structure_order[batch_index]
            result["mfe"] = mfe_order[batch_index]

            # Add to all results
            all_results.append(result)

    return all_results


# Run the per-sequence evaluation and print the raw list of result records.
results = evaluate_model_on_testing_set_individual_metrics(
    trained_model=trained_model
)
print(results)

In [None]:
# Appendix figure: KinPFN on the 97 and 119 nucleotide long RNAs,
# one column per context size, with a fixed seed.
plot_kinpfn_on_selected_testing_set_seq(
    trained_model,
    training_points=[10, 25, 50, 75],
    seed=79539,
)