In [1]:
import os
import random
import torch
import matplotlib.pyplot as plt
import polars as pl
import numpy as np
from kinpfn.priors import Batch
from kinpfn.model import KINPFN

In [2]:
def set_seed(seed=123):
    """Seed every global RNG this notebook draws from with the same value.

    Args:
        seed: Integer applied to Python's ``random`` module, NumPy's legacy
            global generator (``np.random``) and PyTorch's default generator.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [3]:
def get_dataset_size(val_set_dir):
    """Return how many ``.csv`` files live anywhere below ``val_set_dir``.

    The directory tree is walked recursively with ``os.walk``; a missing
    directory simply yields 0. The suffix match is case-sensitive.
    """
    return sum(
        1
        for _root, _dirnames, filenames in os.walk(val_set_dir)
        for filename in filenames
        if filename.endswith(".csv")
    )


def get_batch_testing_folding_times(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Load simulated first passage times from every ``.csv`` below ``val_set_dir``.

    For each CSV, at most ``seq_len`` rows are read without a header. The 3rd
    column is taken as the first passage times, and the first entry of the 5th
    column as the RNA sequence (only its length is kept). Times outside the
    inclusive range ``[1e-6, 1e15]`` are dropped. The remaining times are sorted
    and then truncated, or tiled (oversampled), to exactly ``seq_len`` values.
    Files with no time left in range are skipped entirely.

    Args:
        val_set_dir: Directory searched recursively (``os.walk``) for ``.csv`` files.
        seq_len: Number of times kept per file.
        num_features: Size of the (all-zero) feature dimension of ``x``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(batch, length_order, file_names_in_order)``:

        * ``batch`` is a ``Batch`` whose ``x`` has shape
          ``(seq_len, n_used, num_features)`` (all zeros). Its ``y`` and
          ``target_y`` (the same tensor) have shape ``(seq_len, n_used)`` and
          hold ``log10`` of the times.
        * ``length_order[k]`` and ``file_names_in_order[k]`` describe batch
          column ``k``.

        Column order follows ``os.walk`` / filesystem listing order, which is
        not sorted. NOTE(review): callers hard-code column indices, so confirm
        the order is stable on the target filesystem.
    """
    dataset_size = get_dataset_size(val_set_dir)

    # Pre-allocate for every CSV; columns of skipped files are trimmed below.
    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue

            path = os.path.join(subdir, file)
            data = pl.read_csv(
                path,
                has_header=False,
                columns=[2, 4],
                n_rows=seq_len,
            )

            folding_times = data["column_3"].to_numpy()
            sequence = data["column_5"][0]
            sorted_folding_times = np.sort(folding_times)

            # Keep only times inside [1e-6, 1e15]; anything below 1e-6 OR above
            # 1e15 is discarded.
            valid_indices = (sorted_folding_times <= 10**15) & (
                sorted_folding_times >= 10**-6
            )
            sorted_folding_times = sorted_folding_times[valid_indices]

            current_seq_len = len(sorted_folding_times)
            if current_seq_len <= 0:
                # Skip BEFORE recording metadata so that length_order and
                # file_names_in_order stay aligned with the filled batch columns.
                continue

            if current_seq_len < seq_len:
                # Oversample: tile the sorted times, then cut to seq_len.
                repeat_factor = seq_len // current_seq_len + 1
                sorted_folding_times = np.tile(sorted_folding_times, repeat_factor)[
                    :seq_len
                ]
            else:
                sorted_folding_times = sorted_folding_times[:seq_len]

            length_order.append(len(sequence))
            file_names_in_order.append(file)
            # x stays all-zero (it was allocated with torch.zeros).
            y[:, batch_index] = torch.tensor(sorted_folding_times)
            batch_index += 1

    # Drop the columns reserved for skipped files. Left in place, they would
    # stay 0 and become -inf after log10.
    x = x[:, :batch_index]
    y = torch.log10(y[:, :batch_index])
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order

In [None]:
# Load the pretrained KinPFN checkpoint (path is relative to this notebook).
# The filename appears to encode training hyperparameters — TODO confirm
# against the training script.
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

if trained_model is not None:
    print("Load trained model!")
else:
    print("No trained model found!")
    # Raise instead of calling exit(): in an IPython/Jupyter kernel, exit()
    # requests a kernel shutdown (losing all state) rather than just stopping
    # "Run All" at this cell with a visible error.
    raise RuntimeError(f"No trained model could be loaded from {model_path!r}")

In [5]:
def evaluate_model_on_yeast_tRNA_rRNA(trained_model, training_points, seed=None, num_repeats=20):
    """Plot KinPFN CDF predictions against simulated data for two batch columns.

    The simulations under ``./data/simulations`` are loaded with
    ``seq_len = 1000`` times per file. Two columns are evaluated: batch column 1
    first, then column 0 (the notebook's comments call these tRNA and rRNA).

    For every context size in ``training_points``, the model is conditioned
    ``num_repeats`` times on a random subset of that many times. Each draw
    predicts the CDF on one shared log10-time grid, and the mean ± std over
    draws is plotted against the empirical CDF of all simulated times. The
    mean absolute error (MAE) at the simulated times, averaged over draws, is
    printed.

    Layout: if one context size is given, both columns are drawn side by side.
    Otherwise there is one row per batch column and one panel column per
    context size.

    Args:
        trained_model: Called as ``trained_model(train_x, train_y, test_x)`` to
            get logits. It must expose ``criterion.cdf(logits, ys)``, whose
            ``[0][0]`` entry holds the predicted CDF at ``ys``.
        training_points: Non-empty list of context sizes.
        seed: RNG seed. A random one is drawn and printed when ``None``.
        num_repeats: Number of random contexts averaged per panel (must be >= 1).

    Raises:
        ValueError: If ``num_repeats`` < 1.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")
    if seed is None:
        seed = random.randint(0, 100000)
    set_seed(seed=seed)
    print(f"Seed: {seed}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    num_evaluations = 2
    val_set_dir = "./data/simulations"

    batch, _, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    single_context_size = len(training_points) == 1
    # squeeze=False -> axes is always 2-D, so indexing is uniform below.
    if single_context_size:
        fig, axes = plt.subplots(
            nrows=1,
            ncols=num_evaluations,
            figsize=(7.25 * num_evaluations, 6.5),
            layout="constrained",
            squeeze=False,
        )
    else:
        fig, axes = plt.subplots(
            nrows=num_evaluations,
            ncols=len(training_points),
            figsize=(7.75 * len(training_points), 12.5),
            layout="constrained",
            squeeze=False,
        )
    fig.set_dpi(300)

    for i in range(num_evaluations):
        # First tRNA then rRNA, per the original notebook. This assumes batch
        # column 1 is tRNA and column 0 is rRNA, which depends on os.walk order
        # in the loader. NOTE(review): confirm.
        batch_index = 1 if i == 0 else 0

        # The target sample and its empirical CDF do not depend on the context
        # draw, so compute them once per batch column.
        test_x = x[:, batch_index].to(device)
        test_y_folding_times = y_folding_times[:, batch_index].to(device)
        ground_truth_sorted_folding_times, _ = torch.sort(test_y_folding_times)
        n_targets = len(ground_truth_sorted_folding_times)
        ground_truth_cdf = torch.arange(1, n_targets + 1) / n_targets
        ground_truth_cdf_np = ground_truth_cdf.cpu().numpy()

        # One log10-time grid shared by all repeats, so mean/std are taken
        # point-wise at identical x values. Every context is a subset of the
        # target sample, so this covers each context's [min - 1, max + 1].
        # The grid is built on `device` to match the logits.
        eval_grid = torch.linspace(
            ground_truth_sorted_folding_times[0].item() - 1,
            ground_truth_sorted_folding_times[-1].item() + 1,
            1000,
            device=device,
        )
        eval_grid_np = eval_grid.cpu().numpy()

        for j, training_point in enumerate(training_points):
            row, col = (0, i) if single_context_size else (i, j)
            ax = axes[row, col]

            all_pred_cdf_list = []
            all_mae_list = []
            for _ in range(num_repeats):
                train_indices = torch.randperm(seq_len)[:training_point]
                train_x = x[train_indices, batch_index].to(device)
                train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)

                with torch.no_grad():
                    logits = trained_model(
                        train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                    )
                    pred_cdf_original = trained_model.criterion.cdf(
                        logits, ground_truth_sorted_folding_times
                    )[0][0]
                    pred_cdf_grid = trained_model.criterion.cdf(logits, eval_grid)[0][0]

                all_pred_cdf_list.append(pred_cdf_grid.cpu().numpy())
                all_mae_list.append(
                    np.abs(pred_cdf_original.cpu().numpy() - ground_truth_cdf_np).mean()
                )

            mean_pred_cdf = np.mean(all_pred_cdf_list, axis=0)
            std_pred_cdf = np.std(all_pred_cdf_list, axis=0)
            mean_mae = np.mean(all_mae_list)
            # Only the LAST repeat's context points are drawn; the band aggregates all repeats.
            train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

            ax.scatter(
                10**ground_truth_sorted_folding_times.cpu().numpy(),
                ground_truth_cdf_np,
                color="#000000",
                marker="x",
                label="Target",
            )

            ax.plot(
                10**eval_grid_np,
                mean_pred_cdf,
                color="#cc101fc7",
                marker=".",
                label="KinPFN " + r"(mean $\pm$ std)",
            )

            ax.fill_between(
                10**eval_grid_np,
                mean_pred_cdf - std_pred_cdf,
                mean_pred_cdf + std_pred_cdf,
                color="#cc101fc7",
                alpha=0.5,
            )

            ax.scatter(
                10**train_y_folding_times_sorted.cpu().numpy(),
                np.zeros_like(train_y_folding_times_sorted.cpu().numpy()),
                color="blue",
                marker="o",
                label="Context First Passage Times",
            )

            ax.set_ylim(0, 1)
            ax.set_xscale("log")
            ax.set_xlabel("Time", fontsize=18)
            ax.tick_params(axis='both', which='major', labelsize=16)

            if col == 0:
                ax.set_ylabel(
                    "Cumulative Population Probability", fontsize=18
                )

            print(f"#KinPFN Context Times: {len(train_y_folding_times)}")
            print(f"MAE: {mean_mae:.4f}")

            ax.legend(fontsize=16)

    # No plt.tight_layout(): the figure already uses the constrained layout
    # engine, and tight_layout would override it (with a warning).
    plt.show()
    plt.close(fig)

In [6]:
def evaluate_model_on_yeast_tRNA(trained_model, training_points, seed=None, num_repeats=20):
    """Plot KinPFN CDF predictions against simulated data for batch column 1.

    The notebook's comments identify batch column 1 as the tRNA. The simulations
    under ``./data/simulations`` are loaded with ``seq_len = 1000`` times per
    file. For every context size in ``training_points``, the model is
    conditioned ``num_repeats`` times on a random subset of that many times.
    Each draw predicts the CDF on one shared log10-time grid, and the mean ± std
    over draws is plotted (one panel per context size, in a single row) against
    the empirical CDF of all simulated times. The draw-averaged MAE is printed.

    Args:
        trained_model: Called as ``trained_model(train_x, train_y, test_x)`` to
            get logits. It must expose ``criterion.cdf(logits, ys)``, whose
            ``[0][0]`` entry holds the predicted CDF at ``ys``.
        training_points: Non-empty list of context sizes.
        seed: RNG seed. A random one is drawn and printed when ``None``.
        num_repeats: Number of random contexts averaged per panel (must be >= 1).

    Raises:
        ValueError: If ``num_repeats`` < 1.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")
    if seed is None:
        seed = random.randint(0, 100000)
    set_seed(seed=seed)
    print(f"Seed: {seed}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    val_set_dir = "./data/simulations"
    # tRNA column per the notebook's "first tRNA then rRNA" convention. This
    # depends on os.walk order in the loader. NOTE(review): confirm.
    batch_index = 1

    batch, _, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # squeeze=False -> axes is always 2-D (1 x n), so indexing is uniform below.
    if len(training_points) == 1:
        fig, axes = plt.subplots(
            nrows=1,
            ncols=1,
            figsize=(7.5, 6),
            layout="constrained",
            squeeze=False,
        )
    else:
        fig, axes = plt.subplots(
            nrows=1,
            ncols=len(training_points),
            figsize=(8.75 * len(training_points), 6.5),
            layout="constrained",
            squeeze=False,
        )
    fig.set_dpi(300)

    # The target sample and its empirical CDF are independent of the context draw.
    test_x = x[:, batch_index].to(device)
    test_y_folding_times = y_folding_times[:, batch_index].to(device)
    ground_truth_sorted_folding_times, _ = torch.sort(test_y_folding_times)
    n_targets = len(ground_truth_sorted_folding_times)
    ground_truth_cdf = torch.arange(1, n_targets + 1) / n_targets
    ground_truth_cdf_np = ground_truth_cdf.cpu().numpy()

    # One log10-time grid shared by all repeats (built on `device` to match the
    # logits), so mean/std are point-wise at identical x values. Contexts are
    # subsets of the target sample, so this spans every context's
    # [min - 1, max + 1].
    eval_grid = torch.linspace(
        ground_truth_sorted_folding_times[0].item() - 1,
        ground_truth_sorted_folding_times[-1].item() + 1,
        1000,
        device=device,
    )
    eval_grid_np = eval_grid.cpu().numpy()

    for j, training_point in enumerate(training_points):
        ax = axes[0, j]

        all_pred_cdf_list = []
        all_mae_list = []
        for _ in range(num_repeats):
            train_indices = torch.randperm(seq_len)[:training_point]
            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)

            with torch.no_grad():
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )
                pred_cdf_original = trained_model.criterion.cdf(
                    logits, ground_truth_sorted_folding_times
                )[0][0]
                pred_cdf_grid = trained_model.criterion.cdf(logits, eval_grid)[0][0]

            all_pred_cdf_list.append(pred_cdf_grid.cpu().numpy())
            all_mae_list.append(
                np.abs(pred_cdf_original.cpu().numpy() - ground_truth_cdf_np).mean()
            )

        mean_pred_cdf = np.mean(all_pred_cdf_list, axis=0)
        std_pred_cdf = np.std(all_pred_cdf_list, axis=0)
        mean_mae = np.mean(all_mae_list)
        # Only the LAST repeat's context points are drawn; the band aggregates all repeats.
        train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

        ax.scatter(
            10**ground_truth_sorted_folding_times.cpu().numpy(),
            ground_truth_cdf_np,
            color="#000000",
            marker="x",
            label="Target",
        )

        ax.plot(
            10**eval_grid_np,
            mean_pred_cdf,
            color="#cc101fc7",
            marker=".",
            label="KinPFN " + r"(mean $\pm$ std)",
        )

        ax.fill_between(
            10**eval_grid_np,
            mean_pred_cdf - std_pred_cdf,
            mean_pred_cdf + std_pred_cdf,
            color="#cc101fc7",
            alpha=0.5,
        )

        ax.scatter(
            10**train_y_folding_times_sorted.cpu().numpy(),
            np.zeros_like(train_y_folding_times_sorted.cpu().numpy()),
            color="blue",
            marker="o",
            label="Context First Passage Times",
        )

        ax.set_ylim(0, 1)
        ax.set_xscale("log")
        ax.set_xlabel("Time", fontsize=18)
        ax.tick_params(axis='both', which='major', labelsize=16)

        if j == 0:
            ax.set_ylabel(
                "Cumulative Population Probability", fontsize=18
            )

        print(f"#KinPFN Context Times: {len(train_y_folding_times)}")
        print(f"MAE: {mean_mae:.4f}")

        ax.legend(fontsize=16)

    # No plt.tight_layout(): it would override the constrained layout engine
    # chosen above (matplotlib warns when that happens).
    plt.show()
    plt.close(fig)

In [None]:

# Multi-panel figures with fixed seeds so the plotted context draws are reproducible:
# first tRNA + rRNA over four context sizes, then tRNA only over two.
evaluate_model_on_yeast_tRNA_rRNA(
    trained_model=trained_model,
    training_points=[10, 25, 50, 75],
    seed=38924,
)
evaluate_model_on_yeast_tRNA(
    trained_model=trained_model,
    training_points=[25, 50],
    seed=12043,
)

In [8]:
def evaluate_model_on_yeast_tRNA_single_plot(trained_model, training_points, seed=None, num_repeats=20):
    """Like ``evaluate_model_on_yeast_tRNA`` but draws one standalone figure per context size.

    It evaluates batch column 1, which the notebook's comments call the tRNA.
    For each context size, the model is conditioned ``num_repeats`` times on a
    random subset of the 1000 simulated times. Each draw predicts the CDF on one
    shared log10-time grid, and the mean ± std over draws is plotted against
    the empirical CDF. The draw-averaged MAE is printed.

    Args:
        trained_model: Called as ``trained_model(train_x, train_y, test_x)`` to
            get logits. It must expose ``criterion.cdf(logits, ys)``, whose
            ``[0][0]`` entry holds the predicted CDF at ``ys``.
        training_points: Iterable of context sizes; one figure each.
        seed: RNG seed. A random one is drawn and printed when ``None``.
        num_repeats: Number of random contexts averaged per figure (must be >= 1).

    Raises:
        ValueError: If ``num_repeats`` < 1.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")
    if seed is None:
        seed = random.randint(0, 100000)
    set_seed(seed=seed)
    print(f"Seed: {seed}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    val_set_dir = "./data/simulations"
    # tRNA column per the notebook's "first tRNA, then rRNA" convention. This
    # depends on os.walk order in the loader. NOTE(review): confirm.
    batch_index = 1

    batch, _, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    # The target sample and its empirical CDF are independent of the context draw.
    test_x = x[:, batch_index].to(device)
    test_y_folding_times = y_folding_times[:, batch_index].to(device)
    ground_truth_sorted_folding_times, _ = torch.sort(test_y_folding_times)
    n_targets = len(ground_truth_sorted_folding_times)
    ground_truth_cdf = torch.arange(1, n_targets + 1) / n_targets
    ground_truth_cdf_np = ground_truth_cdf.cpu().numpy()

    # One log10-time grid shared by all repeats (built on `device` to match the
    # logits), so mean/std are point-wise at identical x values. Contexts are
    # subsets of the target sample, so this spans every context's
    # [min - 1, max + 1].
    eval_grid = torch.linspace(
        ground_truth_sorted_folding_times[0].item() - 1,
        ground_truth_sorted_folding_times[-1].item() + 1,
        1000,
        device=device,
    )
    eval_grid_np = eval_grid.cpu().numpy()

    for training_point in training_points:
        fig, ax = plt.subplots(figsize=(8.75, 6.5))
        fig.set_dpi(300)

        all_pred_cdf_list = []
        all_mae_list = []
        for _ in range(num_repeats):
            train_indices = torch.randperm(seq_len)[:training_point]
            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)

            with torch.no_grad():
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )
                pred_cdf_original = trained_model.criterion.cdf(
                    logits, ground_truth_sorted_folding_times
                )[0][0]
                pred_cdf_grid = trained_model.criterion.cdf(logits, eval_grid)[0][0]

            all_pred_cdf_list.append(pred_cdf_grid.cpu().numpy())
            all_mae_list.append(
                np.abs(pred_cdf_original.cpu().numpy() - ground_truth_cdf_np).mean()
            )

        mean_pred_cdf = np.mean(all_pred_cdf_list, axis=0)
        std_pred_cdf = np.std(all_pred_cdf_list, axis=0)
        mean_mae = np.mean(all_mae_list)
        # Only the LAST repeat's context points are drawn; the band aggregates all repeats.
        train_y_folding_times_sorted, _ = torch.sort(train_y_folding_times)

        ax.scatter(
            10**ground_truth_sorted_folding_times.cpu().numpy(),
            ground_truth_cdf_np,
            color="#000000",
            marker="x",
            label="Target",
        )

        ax.plot(
            10**eval_grid_np,
            mean_pred_cdf,
            color="#cc101fc7",
            marker=".",
            label="KinPFN " + r"(mean $\pm$ std)",
        )

        ax.fill_between(
            10**eval_grid_np,
            mean_pred_cdf - std_pred_cdf,
            mean_pred_cdf + std_pred_cdf,
            color="#cc101fc7",
            alpha=0.5,
        )

        ax.scatter(
            10**train_y_folding_times_sorted.cpu().numpy(),
            np.zeros_like(train_y_folding_times_sorted.cpu().numpy()),
            color="blue",
            marker="o",
            label="Context First Passage Times",
        )

        ax.set_ylim(0, 1)
        ax.set_xscale("log")
        ax.set_xlabel("Time", fontsize=20)
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.set_ylabel(
            "Cumulative Population Probability", fontsize=20
        )

        print(f"#KinPFN Context Times: {len(train_y_folding_times)}")
        print(f"MAE: {mean_mae:.4f}")

        ax.legend(fontsize=18)

        # This figure has no layout engine set, so tight_layout is safe here.
        fig.tight_layout()
        plt.show()
        # Close each figure so the per-context-size loop does not accumulate figures.
        plt.close(fig)


In [None]:
# One standalone tRNA figure per context size; the fixed seed makes the
# random context draws reproducible.
evaluate_model_on_yeast_tRNA_single_plot(
    trained_model=trained_model,
    training_points=[25, 50],
    seed=12043,
)