In [1]:
import os
import random
import torch
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from matplotlib.lines import Line2D
import pandas as pd
import polars as pl
import numpy as np
from pathlib import Path
from collections import defaultdict

from kinpfn.priors import Batch
from kinpfn.model import KINPFN

from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.neighbors import KernelDensity
from scipy.integrate import cumulative_trapezoid
from scipy.stats import norm

In [2]:
def set_seed(seed=123):
    """Seed Python's `random`, NumPy's legacy global RNG and torch with one value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [3]:
def get_dataset_size(val_set_dir):
    """Return how many files ending in ``.csv`` exist anywhere below ``val_set_dir``.

    The search is recursive (``os.walk``) and the suffix match is case-sensitive.
    A missing directory yields 0.
    """
    return sum(
        1
        for _root, _dirs, filenames in os.walk(val_set_dir)
        for name in filenames
        if name.endswith(".csv")
    )


def get_batch_testing_folding_times(val_set_dir, seq_len=100, num_features=1, **kwargs):
    """Load every ``.csv`` below ``val_set_dir`` as one column of a testing batch.

    From each file, the first ``seq_len`` rows of columns 2 (folding times) and
    4 (sequence) are read. Folding times are sorted and restricted to
    [1e-6, 1e15]. If fewer than ``seq_len`` remain, they are tiled (oversampled);
    otherwise they are truncated to exactly ``seq_len`` values. Files with no
    folding time left after filtering are skipped entirely.

    Args:
        val_set_dir: Directory searched recursively for ``.csv`` files.
        seq_len: Number of folding times per column.
        num_features: Size of the (all-zero) feature dimension of ``x``.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        ``(batch, length_order, file_names_in_order)``.
        ``batch.x`` is zeros of shape ``(seq_len, n_kept, num_features)``.
        ``batch.y`` (also ``target_y``) has shape ``(seq_len, n_kept)`` and holds
        log10 folding times. ``length_order[i]`` (sequence length) and
        ``file_names_in_order[i]`` describe column ``i``; ``n_kept`` counts only
        files that were not skipped.
    """
    dataset_size = get_dataset_size(val_set_dir)

    x = torch.zeros(seq_len, dataset_size, num_features)
    y = torch.zeros(seq_len, dataset_size)

    batch_index = 0
    length_order = []
    file_names_in_order = []
    for subdir, _, files in os.walk(val_set_dir):
        for file in files:
            if not file.endswith(".csv"):
                continue
            data = pl.read_csv(
                os.path.join(subdir, file),
                has_header=False,
                columns=[2, 4],
                n_rows=seq_len,
            )

            folding_times = np.sort(data["column_3"].to_numpy())
            sequence = data["column_5"][0]

            # Keep only folding times within [1e-6, 1e15].
            valid = (folding_times >= 10**-6) & (folding_times <= 10**15)
            folding_times = folding_times[valid]
            if folding_times.size == 0:
                # Nothing usable: skip WITHOUT recording metadata, so that
                # length_order / file_names_in_order stay aligned with the columns.
                continue

            if folding_times.size < seq_len:
                # Oversample by repeating the sorted values until seq_len is reached.
                repeat_factor = seq_len // folding_times.size + 1
                folding_times = np.tile(folding_times, repeat_factor)
            folding_times = folding_times[:seq_len]

            y[:, batch_index] = torch.as_tensor(folding_times, dtype=y.dtype)
            length_order.append(len(sequence))
            file_names_in_order.append(file)
            batch_index += 1

    # Drop columns reserved for skipped files; left at 0 they would become -inf after log10.
    x = x[:, :batch_index]
    y = torch.log10(y[:, :batch_index])
    return Batch(x=x, y=y, target_y=y), length_order, file_names_in_order

In [None]:
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

# Fail loudly if KINPFN did not provide a model. `exit()` inside a Jupyter kernel
# shuts the kernel down (losing all state); raising instead stops "Run All" at
# this cell with a readable traceback.
if trained_model is None:
    raise RuntimeError(f"No trained model found! (model_path={model_path!r})")
print("Load trained model!")

In [5]:
from sklearn.mixture import GaussianMixture
import numpy as np
import torch

def evaluate_model_on_testing_set_gmm_ensemble_weighted(
    trained_model,
    n_components_list,
    seq_len=1000,
    training_points=(10, 25, 50, 75, 100, 250, 500, 750, 1000),
    val_set_dir="../../../../kinpfn_testing_set",
):
    """Compare KinPFN with a weighted GMM ensemble on the testing set and print the results.

    For each context size in ``training_points``, every testing column (log10
    folding times) is sub-sampled to that many values. Both models are then
    evaluated on all values of the column:

    * KinPFN is conditioned on the sub-sample.
    * One ``GaussianMixture`` per entry of ``n_components_list`` is fitted to the
      sub-sample and combined into a weighted ensemble.

    The mean CDF MAE (against the empirical CDF) and the mean NLL of both models
    are printed.

    Args:
        trained_model: KinPFN model; called as ``trained_model(train_x, train_y, test_x)``
            and exposing ``criterion.cdf`` and ``criterion.forward``.
        n_components_list: Component counts of the GMMs forming the ensemble.
        seq_len: Values loaded per testing file.
        training_points: Context sizes to evaluate.
        val_set_dir: Directory containing the testing ``.csv`` files.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch, *_ = get_batch_testing_folding_times(val_set_dir=val_set_dir, seq_len=seq_len)
    x = batch.x
    y_folding_times = batch.y
    # Take the size from the batch itself (not a second directory walk) so the
    # loop only visits columns that actually exist.
    dataset_size = y_folding_times.shape[1]
    print(f"Dataset Size: {dataset_size}")

    # Empirical CDF at the sorted test points: the i-th smallest of n values gets i/n.
    n_test = y_folding_times.shape[0]
    ground_truth_cdf = torch.arange(1, n_test + 1) / n_test
    ground_truth_cdf_np = ground_truth_cdf.numpy()

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        mae_losses = []
        mean_nll_losses = []
        ensemble_mae_losses = []
        ensemble_nll_losses = []

        for batch_index in range(dataset_size):
            train_indices = torch.randperm(seq_len)[:training_point]

            train_x = x[train_indices, batch_index].to(device)
            train_y = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_sorted, _ = torch.sort(y_folding_times[:, batch_index].to(device))

            # --- KinPFN ---
            with torch.no_grad():
                # Add a batch dimension as required by the transformer model
                logits = trained_model(
                    train_x[:, None], train_y[:, None], test_x[:, None]
                )
                pred_cdf = trained_model.criterion.cdf(logits, test_y_sorted)[0][0]
                nll_loss = trained_model.criterion.forward(logits=logits, y=test_y_sorted)

            # Bring the reference CDF onto the prediction's device: subtracting a
            # CPU tensor from a CUDA tensor raises.
            mae_losses.append(
                (pred_cdf - ground_truth_cdf.to(pred_cdf.device)).abs().mean().item()
            )
            mean_nll_losses.append(nll_loss.mean().item())

            # --- Weighted GMM ensemble ---
            train_np = train_y.reshape(-1, 1).cpu().numpy()
            test_np = test_y_sorted.reshape(-1, 1).cpu().numpy()
            log_likelihoods, member_pdfs, member_cdfs = [], [], []
            for n_components in n_components_list:
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_np)

                # NOTE(review): this is the maximised training log-likelihood, not a
                # marginal likelihood. It carries no complexity penalty, so it tends
                # to favour the largest mixture; -0.5 * gmm.bic(train_np) is the usual
                # (Schwarz) approximation if model averaging is the intent.
                log_likelihoods.append(gmm.score(train_np) * len(train_np))
                member_pdfs.append(np.exp(gmm.score_samples(test_np)))
                # Exact mixture CDF: sum_k w_k * Phi((t - mu_k) / sigma_k). The default
                # covariance_type="full" gives covariances_ of shape (k, 1, 1) here.
                # A normalised cumulative sum of pdf values would ignore the uneven
                # spacing of the sorted test points and be forced to 1 at the largest.
                member_cdfs.append(
                    norm.cdf(
                        test_np,
                        loc=gmm.means_.ravel(),
                        scale=np.sqrt(gmm.covariances_.ravel()),
                    )
                    @ gmm.weights_
                )

            # Softmax of the log-likelihoods (shifted by the max for stability).
            log_likelihoods = np.asarray(log_likelihoods)
            weights = np.exp(log_likelihoods - log_likelihoods.max())
            weights /= weights.sum()

            ensemble_density = weights @ np.vstack(member_pdfs)
            ensemble_cdf = weights @ np.vstack(member_cdfs)

            ensemble_mae_losses.append(np.abs(ensemble_cdf - ground_truth_cdf_np).mean())
            ensemble_nll_losses.append(-np.mean(np.log(ensemble_density + 1e-9)))

        print(f"KinPFN Mean MAE Loss: {torch.tensor(mae_losses).mean()}")
        print(f"KinPFN Mean NLL Loss: {torch.tensor(mean_nll_losses).mean()}")
        print(
            "GMM Ensemble Mean MAE Loss: "
            f"{torch.tensor(ensemble_mae_losses, dtype=torch.float32).mean()}"
        )
        print(
            "GMM Ensemble Mean NLL Loss: "
            f"{torch.tensor(ensemble_nll_losses, dtype=torch.float32).mean()}"
        )


In [6]:
#evaluate_model_on_testing_set_gmm_ensemble_weighted(trained_model, n_components_list=[2, 3, 4, 5])
#evaluate_model_on_testing_set_gmm_ensemble_weighted(trained_model, n_components_list=[2, 3, 4])

In [7]:
from scipy.interpolate import interp1d

## Family of multi-modal Gaussian distributions
def multi_modal_distribution(x, num_peaks, log_start, log_end, rng=None):
    """Evaluate an unnormalised sum of ``num_peaks`` Gaussian bumps in ``ln(x)``.

    A random shift ``theta ~ U(log_start, log_end)`` is drawn first. Peak ``j`` then has
    ``mean_j = U(log_start + 1, log_end + 1) + theta`` and
    ``std_j = U(0.1, (log_end - log_start) / 5)``, both cast to float32. The result is
    ``sum_j exp(-(ln x - mean_j)^2 / (2 std_j^2))``.

    NOTE(review): ``log_start``/``log_end`` are used elsewhere as base-10 exponents
    (``np.logspace`` in the sampler), while the peaks here are placed in natural-log
    space. Confirm that this is intended.

    Args:
        x: Positive evaluation points (``np.log`` is applied).
        num_peaks: Number of Gaussian bumps.
        log_start, log_end: Bounds used to draw the peak locations and widths.
        rng: ``np.random.Generator``. If ``None``, a fresh unseeded generator is created
            per call. Unlike the previous default, this does not touch NumPy's global RNG.

    Returns:
        float32 array shaped like ``x``; each value lies in ``[0, num_peaks]``.
    """
    if rng is None:
        rng = np.random.default_rng()

    theta = rng.uniform(log_start, log_end)
    means = (rng.uniform(log_start + 1, log_end + 1, num_peaks) + theta).astype(
        np.float32
    )
    stds = rng.uniform(0.1, (log_end - log_start) / 5, num_peaks).astype(np.float32)

    log_x = np.log(x)  # loop-invariant
    distribution = np.zeros_like(x, dtype=np.float32)
    for mean, std in zip(means, stds):
        distribution += np.exp(-((log_x - mean) ** 2) / (2 * std**2))

    return distribution


def sample_from_multi_modal_distribution(batch_size, seq_len, num_features, rng=None):
    """Draw ``batch_size`` sequences of ``seq_len`` samples from random multi-modal priors.

    Each sequence gets its own distribution from ``multi_modal_distribution`` with
    2-5 peaks, evaluated on a log-spaced grid over [1e-6, 1e15]. Samples are drawn
    by inverse-transform sampling, linearly interpolating the grid CDF. Out-of-range
    uniforms are clamped to the grid ends.

    Args:
        batch_size: Number of sequences.
        seq_len: Samples per sequence (also the number of grid points).
        num_features: Size of the feature dimension of the all-zero ``xs``.
        rng: Optional ``np.random.Generator``. If ``None``, one is seeded from NumPy's
            legacy global RNG, so results are reproducible after ``np.random.seed(...)``.

    Returns:
        ``(xs, ys)``: ``xs`` is zeros of shape ``(batch_size, seq_len, num_features)``,
        ``ys`` holds float32 samples of shape ``(batch_size, seq_len)``.
    """
    if rng is None:
        # Derive the generator from the global RNG instead of calling
        # np.random.seed() with no argument. That call reseeds the *global* state
        # from OS entropy and silently undoes any earlier set_seed(), which also
        # makes later NumPy/sklearn randomness irreproducible.
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))

    xs = torch.zeros(batch_size, seq_len, num_features, dtype=torch.float32)
    ys = torch.empty(batch_size, seq_len, dtype=torch.float32)

    log_start = -6
    log_end = 15
    grid = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

    for i in range(batch_size):
        num_peaks = rng.integers(2, 6)
        pdf = multi_modal_distribution(grid, num_peaks, log_start, log_end, rng=rng)
        # The grid is evenly spaced in log(x), so the normalised cumulative sum of the
        # pdf values is a discretised CDF over log(x). Any constant rescaling of the pdf
        # (such as the former np.trapz normalisation) cancels in the division below.
        cdf = np.cumsum(pdf).astype(np.float32)
        cdf /= cdf[-1]
        inverse_cdf = interp1d(
            cdf, grid, bounds_error=False, fill_value=(grid[0], grid[-1]), kind="linear"
        )
        uniform_samples = rng.uniform(0, 1, seq_len).astype(np.float32)
        samples = inverse_cdf(uniform_samples).astype(np.float32)
        ys[i] = torch.tensor(samples, dtype=torch.float32)

    return xs, ys


def get_batch_multi_modal_distribution_prior(
    batch_size, seq_len, num_features=1, hyperparameters=None, **kwargs
):
    """Sample a synthetic prior batch with log10-encoded targets, sequence-first.

    ``hyperparameters`` and ``**kwargs`` are accepted for interface compatibility
    and are ignored. The sampled tensors are transposed on their first two axes,
    so ``x`` is ``(seq_len, batch_size, num_features)`` and ``y`` / ``target_y``
    are ``(seq_len, batch_size)``.
    """
    features, samples = sample_from_multi_modal_distribution(
        batch_size=batch_size, seq_len=seq_len, num_features=num_features
    )
    log_samples = torch.log10(samples)  # log encoding of the targets

    features_seq_first = features.transpose(0, 1)
    y_seq_first = log_samples.transpose(0, 1)
    target_seq_first = log_samples.transpose(0, 1)
    return Batch(x=features_seq_first, y=y_seq_first, target_y=target_seq_first)


In [8]:
def evaluate_model_gmm_ensemble_weighted_on_prior(
    n_components_list,
    seq_len=1000,
    training_points=(250, 500, 750, 1000),
    dataset_size=10000,
):
    """Evaluate a weighted GMM ensemble on synthetic data sampled from the prior.

    ``dataset_size`` sequences of ``seq_len`` log10 values are drawn with
    ``get_batch_multi_modal_distribution_prior``. For each context size in
    ``training_points``, every sequence is sub-sampled to that many values. One
    ``GaussianMixture`` per entry of ``n_components_list`` is fitted, and the
    weighted ensemble is scored on all ``seq_len`` values of the sequence. The
    mean CDF MAE (against the empirical CDF) and the mean NLL are printed.

    Args:
        n_components_list: Component counts of the GMMs forming the ensemble.
        seq_len: Values per synthetic sequence.
        training_points: Context sizes to evaluate.
        dataset_size: Number of synthetic sequences.
    """
    set_seed(seed=123)

    batch = get_batch_multi_modal_distribution_prior(dataset_size, seq_len)
    y_values = batch.y
    print(f"Dataset Size: {dataset_size}")

    # Empirical CDF at the sorted test points: the i-th smallest of n values gets i/n.
    n_test = y_values.shape[0]
    ground_truth_cdf_np = (torch.arange(1, n_test + 1) / n_test).numpy()

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        ensemble_mae_losses = []
        ensemble_nll_losses = []

        for batch_index in range(dataset_size):
            train_indices = torch.randperm(seq_len)[:training_point]

            train_np = y_values[train_indices, batch_index].reshape(-1, 1).numpy()
            test_sorted, _ = torch.sort(y_values[:, batch_index])
            test_np = test_sorted.reshape(-1, 1).numpy()

            log_likelihoods, member_pdfs, member_cdfs = [], [], []
            for n_components in n_components_list:
                gmm = GaussianMixture(n_components=n_components, max_iter=100000)
                gmm.fit(train_np)

                # NOTE(review): this is the maximised training log-likelihood, not a
                # marginal likelihood. It carries no complexity penalty, so it tends
                # to favour the largest mixture; -0.5 * gmm.bic(train_np) is the usual
                # (Schwarz) approximation if model averaging is the intent.
                log_likelihoods.append(gmm.score(train_np) * len(train_np))
                member_pdfs.append(np.exp(gmm.score_samples(test_np)))
                # Exact mixture CDF: sum_k w_k * Phi((t - mu_k) / sigma_k). The default
                # covariance_type="full" gives covariances_ of shape (k, 1, 1) here.
                # A normalised cumulative sum of pdf values would ignore the uneven
                # spacing of the sorted test points and be forced to 1 at the largest.
                member_cdfs.append(
                    norm.cdf(
                        test_np,
                        loc=gmm.means_.ravel(),
                        scale=np.sqrt(gmm.covariances_.ravel()),
                    )
                    @ gmm.weights_
                )

            # Softmax of the log-likelihoods (shifted by the max for stability).
            log_likelihoods = np.asarray(log_likelihoods)
            weights = np.exp(log_likelihoods - log_likelihoods.max())
            weights /= weights.sum()

            ensemble_density = weights @ np.vstack(member_pdfs)
            ensemble_cdf = weights @ np.vstack(member_cdfs)

            ensemble_mae_losses.append(np.abs(ensemble_cdf - ground_truth_cdf_np).mean())
            ensemble_nll_losses.append(-np.mean(np.log(ensemble_density + 1e-9)))

        print(
            "GMM Ensemble Mean MAE Loss: "
            f"{torch.tensor(ensemble_mae_losses, dtype=torch.float32).mean()}"
        )
        print(
            "GMM Ensemble Mean NLL Loss: "
            f"{torch.tensor(ensemble_nll_losses, dtype=torch.float32).mean()}"
        )

In [None]:
# Run the weighted GMM-ensemble baseline on synthetic data sampled from the prior,
# with 2-5 component mixtures. (A smaller ensemble: n_components_list=[2, 3, 4].)
evaluate_model_gmm_ensemble_weighted_on_prior(
    n_components_list=list(range(2, 6)),
)