In [30]:
import os
import random
import torch
import matplotlib.pyplot as plt
import polars as pl
import numpy as np

from kinpfn.priors import Batch
from kinpfn.model import KINPFN


In [31]:
def plot_model_riemann_distribution(logits, trained_model, order_test_x, ax=None):
    """Plot the bucketed (Riemann) predictive density of the first prediction.

    The bucket edges come from ``trained_model.criterion.borders`` (log10
    values); bars are drawn at ``10**borders`` on a log-scaled x-axis, with
    heights equal to bucket probability divided by bucket width in log10 units.

    Args:
        logits: model output. It is indexed as ``logits[order_test_x][0, 0]``
            and then soft-maxed over dim 1. Presumably this yields a
            (1, num_buckets) tensor — TODO confirm against the model's output shape.
        trained_model: object whose ``criterion.borders`` holds the bucket edges.
        order_test_x: index (int or index tensor) applied to ``logits`` first.
        ax: optional matplotlib Axes to draw into. When omitted, the current
            pyplot axes are used and ``plt.show()`` is called, as before.
    """
    probs_of_first_pred = logits[order_test_x][0, 0].softmax(1)
    borders = trained_model.criterion.borders
    widths = borders[1:] - borders[:-1]
    # Probability mass per bucket divided by bucket width -> density.
    density_of_first_pred = probs_of_first_pred / widths

    # Borders are log10 values; decode them to draw the bars in linear time.
    decoded_borders = 10**borders
    decoded_widths = decoded_borders[1:] - decoded_borders[:-1]

    # Only take ownership of showing the figure when the caller gave no Axes.
    show_figure = ax is None
    if ax is None:
        ax = plt.gca()

    ax.bar(
        decoded_borders[:-1].tolist(),
        density_of_first_pred[0].tolist(),
        width=decoded_widths.tolist(),
        align="edge",
        edgecolor="black",
    )
    ax.set_xscale("log")
    if show_figure:
        plt.show()

In [32]:
def set_seed(seed=123):
    """Seed torch, NumPy's global RNG and Python's ``random`` with one value.

    Call this before any stochastic step (e.g. ``torch.randperm``) so the step
    is repeatable.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

In [33]:
def get_dataset_size(val_set_dir):
    """Count the ``.csv`` files anywhere under ``val_set_dir`` (recursively).

    The suffix match is case-sensitive. A missing directory yields 0, because
    ``os.walk`` produces nothing for it.
    """
    return sum(
        1
        for _dirpath, _dirnames, filenames in os.walk(val_set_dir)
        for name in filenames
        if name.endswith(".csv")
    )


def get_batch_testing_folding_times(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Load every ``.csv`` under ``val_set_dir`` into one evaluation ``Batch``.

    Each file contributes one batch column, built as follows:

    * read up to ``seq_len`` rows, taking first passage times from CSV column
      index 2 and the sequence string from CSV column index 4;
    * keep only times inside ``[1e-6, 1e15]`` and sort them;
    * cyclically repeat (or truncate) them to exactly ``seq_len`` values;
    * store them as log10 values.

    Files with no time inside the valid range are skipped entirely. They get
    no batch column and no list entry, so column ``j`` of the batch,
    ``length_order[j]`` and ``file_names_in_order[j]`` always describe the
    same file.

    Args:
        val_set_dir: directory searched recursively (in ``os.walk`` order).
        seq_len: number of rows read per file and rows in the batch.
        num_features: size of the (all-zero) feature dimension of ``x``.
        **kwargs: accepted and ignored.

    Returns:
        ``(batch, length_order, file_names_in_order)``:

        * ``batch.x`` is a zero tensor of shape (seq_len, num_files, num_features).
        * ``batch.y`` and ``batch.target_y`` are the same (seq_len, num_files)
          tensor of log10 times.
        * ``length_order`` holds the sequence length of each kept file.
        * ``file_names_in_order`` holds the file name of each kept file.
    """
    dataset_size = get_dataset_size(val_set_dir)

    # x stays all zeros: the model is only conditioned on the times in y.
    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue
            path = os.path.join(subdir, file)
            data = pl.read_csv(
                path,
                has_header=False,
                columns=[2, 4],
                n_rows=seq_len,
            )

            sorted_folding_times = np.sort(data["column_3"].to_numpy())

            # Keep only points with 1e-6 <= t <= 1e15; drop everything outside.
            valid_indices = (sorted_folding_times >= 10**-6) & (
                sorted_folding_times <= 10**15
            )
            sorted_folding_times = sorted_folding_times[valid_indices]
            if len(sorted_folding_times) == 0:
                # Skip before recording metadata so list indices stay aligned
                # with batch columns (this also avoids indexing an empty frame).
                continue

            sequence = data["column_5"][0]
            length_order.append(len(sequence))
            file_names_in_order.append(file)

            # Oversample by cyclic repetition when too short, truncate when too long.
            y[:, batch_index] = torch.tensor(np.resize(sorted_folding_times, seq_len))
            batch_index += 1

    # Drop columns reserved for skipped files; left in, they would become log10(0) = -inf.
    x = x[:, :batch_index]
    y = torch.log10(y[:, :batch_index])
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order

In [None]:
# Pretrained KinPFN checkpoint, relative to this notebook's working directory.
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

if trained_model is None:
    # Raise instead of calling exit(): inside a Jupyter kernel exit() ends the
    # kernel, whereas an exception just stops "Run All" at this cell.
    raise RuntimeError(f"No trained model found at {model_path!r}")
print("Load trained model!")

In [35]:
def _plot_cdf_panel(ax, trained_model, logits, train_y_folding_times, test_y_folding_times):
    """Draw target vs. predicted CDF (inputs are log10 times) on ``ax``; return (MAE, mean NLL)."""
    ground_truth_sorted, _ = torch.sort(test_y_folding_times)
    train_sorted, _ = torch.sort(train_y_folding_times)

    # Empirical CDF of all loaded (log10) first passage times: the prediction target.
    ground_truth_cdf = torch.arange(1, len(ground_truth_sorted) + 1) / len(
        ground_truth_sorted
    )

    # Predicted CDF at the target points (used for the MAE) ...
    pred_cdf_at_targets = trained_model.criterion.cdf(logits, ground_truth_sorted)[0][0]
    # ... and on a grid extending one decade beyond the context range (for the curve).
    linspace_extended = torch.linspace(train_sorted[0] - 1, train_sorted[-1] + 1, 1000)
    pred_cdf_linspace_extended = trained_model.criterion.cdf(logits, linspace_extended)[0][0]

    ax.scatter(
        10**ground_truth_sorted,
        ground_truth_cdf,
        color="#000000",
        marker="x",
        label="Target",
    )
    ax.scatter(
        10**linspace_extended,
        pred_cdf_linspace_extended,
        color="#cc101fc7",
        marker=".",
        label="KinPFN",
    )
    ax.scatter(
        10**train_sorted,
        torch.zeros_like(train_sorted),
        color="blue",
        marker="o",
        label="Context First Passage Times",
    )

    ax.set_ylim(0, 1)
    ax.set_xscale("log")
    ax.set_xlabel("Time", fontsize=18)
    ax.set_ylabel("Cumulative Population Probability", fontsize=18)
    ax.tick_params(axis="both", which="major", labelsize=16)
    ax.legend(fontsize=16)

    mae = np.abs(pred_cdf_at_targets - ground_truth_cdf).mean()
    mean_nll_loss = trained_model.criterion.forward(logits=logits, y=ground_truth_sorted).mean()
    return mae, mean_nll_loss


def _plot_riemann_panel(ax, trained_model, logits):
    """Draw the bucketed (Riemann) predictive density of the first prediction on ``ax``."""
    # Every test input is the same all-zero x, so the first prediction represents
    # them all. This equals the former ``logits[test_x][0, 0]`` without indexing
    # by a float array. Presumably logits is (num_test, 1, num_buckets) — TODO confirm.
    probs_of_first_pred = logits[0].softmax(1)
    borders = trained_model.criterion.borders
    density_of_first_pred = probs_of_first_pred / (borders[1:] - borders[:-1])

    # Borders are log10 values; decode them to draw the bars in linear time.
    decoded_borders = 10**borders
    decoded_widths = decoded_borders[1:] - decoded_borders[:-1]

    ax.bar(
        decoded_borders[:-1].tolist(),
        density_of_first_pred[0].tolist(),
        width=decoded_widths.tolist(),
        align="edge",
        edgecolor="black",
        label="KinPFN PPD PDF",
    )
    ax.set_xscale("log")
    ax.set_xlabel("Time (Riemann Distribution Buckets)", fontsize=18)
    ax.set_ylabel("Riemann Density", fontsize=18)
    ax.tick_params(axis="both", which="major", labelsize=16)
    ax.legend(fontsize=16)


def plot_model_prediction_with_riemann(
    trained_model,
    val_set_dir="./data",
    seq_len=1000,
    training_point=75,
    seeds=(3105, 8374),
    min_sequence_length=60,
):
    """Plot KinPFN CDF predictions next to their Riemann predictive densities.

    The figure has one row per entry in ``seeds``. Only the first
    ``len(seeds)`` loaded files are considered. File ``i`` is seeded with
    ``seeds[i]`` and skipped if its sequence is shorter than
    ``min_sequence_length``; rows left unused are hidden. Each plotted file
    gets ``training_point`` randomly chosen context times and one figure row:

    * left panel: target CDF, predicted CDF and the context points;
    * right panel: the Riemann predictive density.

    MAE and mean NLL are printed for each plotted file.

    Args:
        trained_model: loaded KinPFN model (callable; its ``criterion`` must
            provide ``cdf``, ``forward`` and ``borders``).
        val_set_dir: directory searched recursively for ``.csv`` files.
        seq_len: number of first passage times loaded per file.
        training_point: number of context points sampled per file.
        seeds: per-file seeds; its length also fixes the number of rows.
        min_sequence_length: files with shorter sequences are skipped.

    Raises:
        ValueError: if ``seeds`` is empty.
    """
    if len(seeds) == 0:
        raise ValueError("seeds must contain at least one seed (one per figure row).")

    batch, length_order, file_names_in_order = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    x = batch.x
    y_folding_times = batch.y

    nrows = len(seeds)
    # squeeze=False keeps axes 2-D even for a single row.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=2, figsize=(16.5, 7.75 * nrows), squeeze=False
    )
    fig.set_dpi(300)

    # Bounding the loop by len(seeds) prevents running past seeds or figure rows.
    row = 0
    for batch_index in range(min(len(length_order), nrows)):
        if length_order[batch_index] < min_sequence_length:
            continue
        seed = seeds[batch_index]
        print("Seed: ", seed)
        set_seed(seed)
        print("Sequence Length: ", length_order[batch_index])
        print("File Name: ", file_names_in_order[batch_index])

        train_indices = torch.randperm(seq_len)[:training_point]
        train_x = x[train_indices, batch_index]
        train_y_folding_times = y_folding_times[train_indices, batch_index]
        test_x = x[:, batch_index]
        test_y_folding_times = y_folding_times[:, batch_index]

        with torch.no_grad():
            logits = trained_model(
                train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
            )

        mae, mean_nll_loss = _plot_cdf_panel(
            axes[row, 0], trained_model, logits, train_y_folding_times, test_y_folding_times
        )
        print(f"MAE: {float(mae):.4f}  Mean NLL: {float(mean_nll_loss):.4f}")
        _plot_riemann_panel(axes[row, 1], trained_model, logits)
        row += 1

    # Hide rows that stayed empty because their files were skipped.
    for unused_row in range(row, nrows):
        for ax in axes[unused_row]:
            ax.set_visible(False)

    plt.show()

In [None]:
# Render the CDF / Riemann-density figure for the validation files.
plot_model_prediction_with_riemann(trained_model=trained_model)