In [28]:
import matplotlib.pyplot as plt
import polars as pl
import numpy as np
import os
import torch
import random

from kinpfn.priors import Batch
from scipy.interpolate import interp1d

def set_seed(seed=123):
    """Seed PyTorch, NumPy's legacy global RNG and Python's ``random`` module.

    Args:
        seed: Value handed unchanged to each of the three seeding functions.
    """
    # Same order as before: torch first, then numpy, then the stdlib RNG.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

In [29]:
def get_dataset_size(val_set_dir):
    """Count the files ending in ``.csv`` anywhere below ``val_set_dir``.

    Args:
        val_set_dir: Root directory that is walked recursively with ``os.walk``.

    Returns:
        int: Number of matching files (0 if nothing is found).
    """
    return sum(
        1
        for _root, _dirs, filenames in os.walk(val_set_dir)
        for filename in filenames
        if filename.endswith(".csv")
    )


def get_batch_testing_folding_times(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Build a ``Batch`` of log10 folding times from the CSV files under ``val_set_dir``.

    For every ``.csv`` file found by ``os.walk`` the first ``seq_len`` rows are
    read (no header). Column index 2 is used as the folding times and the first
    entry of column index 4 is used as the sequence whose length is recorded.
    Times outside ``[1e-6, 1e15]`` are dropped; if fewer than ``seq_len`` values
    remain, the remaining values are tiled until ``seq_len`` entries are
    available. A file with no remaining value is skipped entirely.

    Args:
        val_set_dir: Root directory that is searched recursively for CSV files.
        seq_len: Number of folding times per file (rows of the returned tensors).
        num_features: Size of the last dimension of the all-zero ``x`` tensor.
        **kwargs: Accepted and ignored.

    Returns:
        tuple: ``(batch, length_order, file_names_in_order)`` where ``batch.x``
        is an all-zero tensor of shape ``(seq_len, n_kept, num_features)``,
        ``batch.y`` and ``batch.target_y`` are the same
        ``(seq_len, n_kept)`` tensor of log10 folding times, and the two lists
        hold, per retained file and in column order, the sequence length and
        the file name. ``n_kept`` is the number of files that were not skipped.
    """
    # Upper bound for the batch width; files skipped below shrink it at the end.
    dataset_size = get_dataset_size(val_set_dir)

    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue

            data = pl.read_csv(
                os.path.join(subdir, file),
                has_header=False,
                columns=[2, 4],
                n_rows=seq_len,
            )

            sorted_folding_times = np.sort(data["column_3"].to_numpy())

            # Keep only times inside [10^-6, 10^15].
            in_bounds = (sorted_folding_times >= 10**-6) & (
                sorted_folding_times <= 10**15
            )
            sorted_folding_times = sorted_folding_times[in_bounds]

            current_seq_len = len(sorted_folding_times)
            if current_seq_len <= 0:
                # Nothing usable in this file. Metadata is recorded only for
                # retained files so that the lists stay aligned with columns.
                continue

            if current_seq_len < seq_len:
                # Oversample by repetition when filtering removed values.
                # Note: the tiled array is no longer globally sorted.
                repeat_factor = seq_len // current_seq_len + 1
                sorted_folding_times = np.tile(sorted_folding_times, repeat_factor)
            sorted_folding_times = sorted_folding_times[:seq_len]

            # At least one row exists here, so reading the first entry is safe.
            length_order.append(len(data["column_5"][0]))
            file_names_in_order.append(file)

            # x stays all-zero; only y carries data.
            y[:, batch_index] = torch.tensor(sorted_folding_times)
            batch_index += 1

    # Drop the unused trailing columns left by skipped files; they would
    # otherwise turn into -inf under log10.
    x = x[:, :batch_index]
    y = torch.log10(y[:, :batch_index])
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order

In [30]:
def plot_kinpfn_on_random_testing_set_seq(
    seq_len=1000,
    num_sequences=20,
    val_set_dir="../../../../kinpfn_testing_set",
):
    """Scatter the empirical CDFs of randomly chosen testing-set sequences.

    Args:
        seq_len: Number of folding times loaded per sequence.
        num_sequences: Maximum number of sequences drawn into the figure.
        val_set_dir: Directory passed to ``get_batch_testing_folding_times``.

    The selection uses ``random.shuffle`` and therefore depends on the state
    of Python's global ``random`` module. The length of every plotted
    sequence is printed.
    """
    batch, length_order, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    y_folding_times = batch.y

    # Take the number of sequences from the loaded batch instead of walking
    # the directory a second time.
    indices = list(range(y_folding_times.shape[1]))
    random.shuffle(indices)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 5))
    fig.set_dpi(300)

    for batch_index in indices[: max(num_sequences, 0)]:
        print("Lengths of Sequences: ", length_order[batch_index])

        test_y_folding_times = y_folding_times[:, batch_index]

        ground_truth_sorted_folding_times, _ = torch.sort(test_y_folding_times)
        n_points = len(ground_truth_sorted_folding_times)
        # Empirical CDF: the i-th smallest time has cumulative fraction i / n.
        ground_truth_cdf = torch.arange(1, n_points + 1) / n_points

        ax.scatter(
            # The batch stores log10(times); undo the log for the x axis.
            10**ground_truth_sorted_folding_times,
            ground_truth_cdf,
            marker="x",
            label="Target " + r"$F(t)$",
        )

    ax.set_ylim(0, 1)
    ax.set_xscale("log")
    ax.set_xlabel("Time t")
    ax.set_ylabel(
        "Fraction of Molecules " + r"$F(t)$" + " folded into Stop Structure"
    )
    fig.tight_layout()
    plt.show()

In [None]:
# Draw the empirical CDFs of a random subset of the testing set.
plot_kinpfn_on_random_testing_set_seq()

In [32]:
def multi_modal_distribution(x, num_peaks, log_start, log_end, rng=None):
    """Evaluate an unnormalised sum of ``num_peaks`` Gaussian bumps in ``np.log(x)``.

    A shift ``theta`` is drawn uniformly from ``[log_start, log_end)``; each
    peak mean is drawn uniformly from ``[log_start + 1, log_end + 1)`` and
    shifted by ``theta``; each standard deviation is drawn uniformly from
    ``[0.1, (log_end - log_start) / 5)``. The drawn parameters are printed.

    Args:
        x: Array of positive values at which the mixture is evaluated.
        num_peaks: Number of Gaussian components.
        log_start: Lower bound used when drawing ``theta``, means and stds.
        log_end: Upper bound used when drawing ``theta``, means and stds.
        rng: ``numpy.random.Generator`` used for all draws. When ``None`` a
            fresh, OS-seeded generator is created for this call.

    Returns:
        numpy.ndarray: ``float32`` array with the shape of ``x``.

    Note:
        The peaks live on ``np.log(x)`` (natural logarithm).
        NOTE(review): confirm that ``log_start``/``log_end`` are meant to be
        combined with a natural-log axis.
    """
    # Create the generator per call. A generator built in the signature would
    # be constructed once at definition time and shared by every call.
    if rng is None:
        rng = np.random.default_rng()

    theta = rng.uniform(log_start, log_end)
    means = rng.uniform(log_start + 1, log_end + 1, num_peaks) + theta
    stds = rng.uniform(0.1, (log_end - log_start) / 5, num_peaks)
    print("k: ", num_peaks)
    print("Parameters: Means:", means, "Stds:", stds)

    log_x = np.log(x)  # loop-invariant
    distribution = np.zeros_like(x, dtype=np.float32)
    for mean, std in zip(means, stds):
        distribution += np.exp(-((log_x - mean) ** 2) / (2 * std**2))

    return distribution


def sample_from_multi_modal_distribution(batch_size, seq_len, num_features, rng=None):
    """Draw ``seq_len`` samples from ``batch_size`` random multi-modal distributions.

    For every batch entry a mixture with 2 to 5 peaks is evaluated on a
    log-spaced grid from ``1e-6`` to ``1e15`` with ``seq_len`` points, turned
    into a CDF by cumulative summation, and sampled via a linearly
    interpolated inverse CDF.

    Args:
        batch_size: Number of independent distributions.
        seq_len: Number of grid points and of samples per distribution.
        num_features: Size of the last dimension of the all-zero ``xs`` tensor.
        rng: Optional ``numpy.random.Generator``. When ``None`` a fresh,
            OS-seeded generator is created; NumPy's global RNG state is not
            touched.

    Returns:
        tuple: ``(xs, ys)`` with ``xs`` an all-zero ``float32`` tensor of shape
        ``(batch_size, seq_len, num_features)`` and ``ys`` a ``float32`` tensor
        of shape ``(batch_size, seq_len)`` holding the samples.
    """
    xs = torch.zeros(batch_size, seq_len, num_features, dtype=torch.float32)
    ys = torch.empty(batch_size, seq_len, dtype=torch.float32)

    # np.random.seed() would reseed the global NumPy RNG as a side effect and
    # only ever return None, so build the unseeded generator directly.
    if rng is None:
        rng = np.random.default_rng()

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    log_start = -6
    log_end = 15
    x = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

    for i in range(batch_size):
        num_peaks = rng.integers(2, 6)  # upper bound exclusive: 2..5 peaks
        own_pdf = multi_modal_distribution(x, num_peaks, log_start, log_end, rng=rng)
        own_pdf /= integrate(own_pdf, x)
        own_cdf = np.cumsum(own_pdf).astype(np.float32)
        own_cdf /= own_cdf[-1]  # force the CDF to end at exactly 1
        # Inverse CDF: uniform draws outside the tabulated range map to the
        # first / last grid point.
        inverse_cdf = interp1d(
            own_cdf, x, bounds_error=False, fill_value=(x[0], x[-1]), kind="linear"
        )
        uniform_samples = rng.uniform(0, 1, seq_len).astype(np.float32)
        samples = inverse_cdf(uniform_samples).astype(np.float32)
        ys[i] = torch.tensor(samples, dtype=torch.float32)

    return xs, ys


def get_batch_multi_modal_distribution_prior(
    batch_size, seq_len, num_features=1, hyperparameters=None, **kwargs
):
    """Sample a synthetic prior batch and return it with the first two axes swapped.

    Args:
        batch_size: Number of synthetic distributions to sample.
        seq_len: Number of samples per distribution.
        num_features: Size of the feature dimension of ``x``.
        hyperparameters: Accepted but not used.
        **kwargs: Accepted and ignored.

    Returns:
        Batch: ``x`` is the sampled feature tensor and ``y`` / ``target_y`` the
        log10 of the sampled values, each with dimensions 0 and 1 transposed.
    """
    features, raw_times = sample_from_multi_modal_distribution(
        batch_size=batch_size, seq_len=seq_len, num_features=num_features
    )

    # Targets are encoded in log10 space.
    log_times = raw_times.log10()

    return Batch(
        x=torch.transpose(features, 0, 1),
        y=torch.transpose(log_times, 0, 1),
        target_y=torch.transpose(log_times, 0, 1),
    )

In [33]:

def plot_real_and_prior_fpt_cdfs(
    seed=970,
    seq_len=1000,
    num_sequences=10,
    val_set_dir="../../../../kinpfn_testing_set",
):
    """Plot testing-set CDFs and synthetic prior CDFs in two separate figures.

    Args:
        seed: Seed passed to ``set_seed`` and to the NumPy generator used for
            the synthetic distributions.
        seq_len: Folding times loaded per sequence, and number of grid points
            of each synthetic distribution.
        num_sequences: Maximum number of curves drawn per figure.
        val_set_dir: Directory passed to ``get_batch_testing_folding_times``.

    One synthetic curve is drawn for every real curve, with matching colours.
    """
    set_seed(seed)

    batch, _, _ = get_batch_testing_folding_times(
        val_set_dir=val_set_dir, seq_len=seq_len
    )
    y_folding_times = batch.y

    # Take the number of sequences from the loaded batch instead of walking
    # the directory a second time.
    indices = list(range(y_folding_times.shape[1]))
    random.shuffle(indices)
    rng = np.random.default_rng(seed=seed)

    colors = plt.get_cmap("tab10")

    fig_real, ax_real = plt.subplots(nrows=1, ncols=1, figsize=(8, 5))
    fig_synthetic, ax_synthetic = plt.subplots(nrows=1, ncols=1, figsize=(7.295, 5))

    fontsize = 15
    labelsize = 12

    # Defined before the loop: the axis limits below need these values even
    # when no sequence is plotted.
    log_start = -6
    log_end = 15
    x = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    for sequences, batch_index in enumerate(indices[: max(num_sequences, 0)]):
        color = colors(sequences % colors.N)

        # Real data: empirical CDF of the (log10-encoded) folding times.
        test_y_folding_times = y_folding_times[:, batch_index]
        ground_truth_sorted_folding_times, _ = torch.sort(test_y_folding_times)
        n_points = len(ground_truth_sorted_folding_times)
        ground_truth_cdf = torch.arange(1, n_points + 1) / n_points

        ax_real.scatter(
            10**ground_truth_sorted_folding_times,
            ground_truth_cdf,
            color=color,
            marker="x",
            label=f"Target F(t) Seq {sequences + 1}",
        )

        # Synthetic data: CDF of a random multi-modal distribution.
        num_peaks = rng.integers(2, 6)  # upper bound exclusive: 2..5 peaks
        y = multi_modal_distribution(x, num_peaks, log_start, log_end, rng=rng)
        y /= integrate(y, x)

        cdf = np.cumsum(y).astype(np.float32)
        cdf /= cdf[-1]

        ax_synthetic.scatter(
            x,
            cdf,
            color=color,
            marker="x",
            label=f"Synthetic F(t) Seq {sequences + 1}",
        )

    ax_real.set_xscale("log")
    ax_real.set_xlabel("Time t", fontsize=fontsize)
    ax_real.set_ylabel("Cumulative Population Probability", fontsize=fontsize)
    ax_real.tick_params(axis="both", labelsize=labelsize)
    ax_real.grid(True)

    ax_synthetic.set_xscale("log")
    ax_synthetic.set_xlim(10**log_start, 10**log_end)
    ax_synthetic.set_xlabel("Synthetic Time t", fontsize=fontsize)
    ax_synthetic.set_ylabel("Cumulative Population Probability", fontsize=fontsize)
    ax_synthetic.tick_params(axis="both", labelsize=labelsize)
    ax_synthetic.grid(True)

    # plt.tight_layout() only touches the current figure, so lay out both
    # figures explicitly.
    fig_real.tight_layout()
    fig_synthetic.tight_layout()
    plt.show()


In [None]:
# Draw the testing-set CDFs and the synthetic prior CDFs as two figures.
plot_real_and_prior_fpt_cdfs()

In [35]:
def plot_single_cdf_and_sampled_inversed(seq_len=100, seed=39):
    """Plot one synthetic PDF, its CDF and samples drawn through the inverse CDF.

    Args:
        seq_len: Number of grid points of the distribution and number of
            samples drawn from it.
        seed: Seed passed to ``set_seed`` and to the NumPy generator that
            draws the distribution parameters and the samples.

    The CDF is drawn on the left y axis, the PDF (dashed) on a twin right
    y axis, and the samples as black crosses at height 0. The seed is printed.
    """
    fig, ax1 = plt.subplots(figsize=(10, 6))
    fig.set_dpi(300)
    fontsize = 15

    set_seed(seed)
    print("Seed: ", seed)
    rng = np.random.default_rng(seed=seed)

    log_start = -6
    log_end = 15
    x = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # The upper bound is exclusive, so this always yields 2 peaks. The draw is
    # kept so the generator state, and thus the figure, is unchanged.
    num_peaks = rng.integers(2, 3)
    y = multi_modal_distribution(x, num_peaks, log_start, log_end, rng=rng)
    y /= integrate(y, x)

    cdf = np.cumsum(y).astype(np.float32)
    cdf /= cdf[-1]

    color_cdf = "darkred"
    color_pdf = "darkblue"

    # Left axis: CDF.
    ax1.set_xscale("log")
    ax1.set_xlim(10**log_start, 10**log_end)
    ax1.plot(x, cdf, label="CDF", color=color_cdf)
    ax1.set_xlabel("Synthetic Time t", fontsize=fontsize + 2)
    ax1.set_ylabel(
        "Cumulative Population Probability",
        color=color_cdf,
        fontsize=fontsize + 2,
    )
    ax1.grid(True)
    ax1.tick_params(axis="y", labelcolor=color_cdf, labelsize=fontsize)
    ax1.tick_params(axis="x", labelsize=fontsize)

    # Inverse-transform sampling: uniform draws outside the tabulated CDF
    # range map to the first / last grid point.
    inverse_cdf = interp1d(
        cdf, x, bounds_error=False, fill_value=(x[0], x[-1]), kind="linear"
    )
    uniform_samples = rng.uniform(0, 1, seq_len).astype(np.float32)
    samples = inverse_cdf(uniform_samples).astype(np.float32)

    ax1.scatter(
        samples,
        np.zeros_like(samples),
        marker="x",
        label="Samples",
        color="black",
        alpha=0.7,
    )

    # Right axis: PDF on a twin axis sharing the x axis.
    ax2 = ax1.twinx()

    ax2.plot(x, y, label="PDF", color=color_pdf, linestyle="--")
    ax2.set_ylabel("Probability Density Function", color=color_pdf, fontsize=fontsize)
    ax2.tick_params(axis="y", labelcolor=color_pdf, labelsize=fontsize)
    ax1.legend(loc="upper left", fontsize=fontsize)
    ax2.legend(loc="upper right", fontsize=fontsize)

    plt.show()

In [None]:
# Draw a single synthetic PDF/CDF pair with 1000 grid points and samples.
plot_single_cdf_and_sampled_inversed(seq_len=1000)