In [1]:
import os
import random
import torch
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from matplotlib.lines import Line2D
import pandas as pd
import polars as pl
import numpy as np
from pathlib import Path
from collections import defaultdict

from kinpfn.priors import Batch
from kinpfn.model import KINPFN

from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.neighbors import KernelDensity
from scipy.integrate import cumulative_trapezoid

In [2]:
import numpy as np
import torch
from scipy.interpolate import interp1d

def multi_modal_distribution(x, num_peaks, log_start, log_end, rng=None):
    """Evaluate an unnormalised sum of randomly placed box-shaped peaks on ``x``.

    ``num_peaks`` boxes are drawn; each box has a centre drawn uniformly from
    ``[log_start + 1, log_end - 1]`` and a width drawn uniformly from
    ``[0.1, (log_end - log_start) / 5]``. A grid point ``x[i]`` lies inside a
    box when ``np.log(x[i])`` falls within ``centre +/- width / 2``; every box
    that contains the point contributes ``1 / width`` to the result.

    Args:
        x: Array of positive values at which the distribution is evaluated.
        num_peaks: Number of boxes to draw.
        log_start: Lower end of the range used to place centres and size widths.
        log_end: Upper end of that range.
        rng: Optional ``numpy.random.Generator``. When omitted, a fresh
            generator is created for this call.

    Returns:
        ``float32`` array shaped like ``x`` holding the summed box heights.
        The result is not normalised.
    """
    # The previous default, ``np.random.default_rng(seed=np.random.seed())``,
    # was evaluated once when the function was defined: every call without an
    # explicit ``rng`` shared one generator, and defining the function re-seeded
    # NumPy's global RNG as a side effect.
    if rng is None:
        rng = np.random.default_rng()

    # Draw the box parameters (centres first, then widths).
    centers = rng.uniform(log_start + 1, log_end - 1, num_peaks).astype(np.float32)
    widths = rng.uniform(0.1, (log_end - log_start) / 5, num_peaks).astype(np.float32)

    distribution = np.zeros_like(x, dtype=np.float32)

    # NOTE(review): membership is tested with the natural log, while the callers
    # in this file build ``x`` with np.logspace (base 10) from the same
    # log_start/log_end. Left unchanged because switching to log10 would change
    # the generated distributions - confirm which base is intended.
    log_x = np.log(x)  # loop-invariant, so computed once

    for center, width in zip(centers, widths):
        lower_bound = center - width / 2
        upper_bound = center + width / 2

        inside = (log_x >= lower_bound) & (log_x <= upper_bound)
        distribution[inside] += 1.0 / width  # constant height of a uniform box

    return distribution

def sample_from_multi_modal_distribution(batch_size, seq_len, num_features, rng=None):
    """Draw ``batch_size`` synthetic datasets from random multi-modal distributions.

    For each dataset a new distribution with 2 to 5 peaks is generated on a
    base-10 log-spaced grid over ``[1e-6, 1e15]`` with ``seq_len`` points, and
    ``seq_len`` values are drawn from it by inverse-transform sampling.

    Args:
        batch_size: Number of independent datasets to generate.
        seq_len: Number of grid points and number of samples per dataset.
        num_features: Size of the last dimension of the (all-zero) feature tensor.
        rng: Optional ``numpy.random.Generator``. When omitted, a generator is
            seeded from NumPy's global random state, so ``np.random.seed(...)``
            makes the output reproducible.

    Returns:
        Tuple ``(xs, ys)``: ``xs`` is a zero tensor of shape
        ``(batch_size, seq_len, num_features)`` and ``ys`` is a ``float32``
        tensor of shape ``(batch_size, seq_len)`` holding the sampled values.
    """
    if rng is None:
        # The previous code used ``default_rng(seed=np.random.seed())``.
        # ``np.random.seed()`` without an argument re-seeds the *global* NumPy
        # RNG from OS entropy and returns None, which discarded any seed the
        # caller had set and left this generator unseeded. Drawing the seed
        # from the global state respects a prior ``np.random.seed(...)``
        # instead of wiping it.
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))

    xs = torch.zeros(batch_size, seq_len, num_features, dtype=torch.float32)
    ys = torch.empty(batch_size, seq_len, dtype=torch.float32)

    log_start = -6
    log_end = 15
    x = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

    for i in range(batch_size):
        num_peaks = rng.integers(2, 6)  # upper bound is exclusive: 2..5 peaks
        own_pdf = multi_modal_distribution(x, num_peaks, log_start, log_end, rng=rng)
        own_pdf /= np.trapz(own_pdf, x)

        # The CDF is a plain cumulative sum over grid points (no dx weighting),
        # so the trapz scaling above cancels once the sum is normalised here.
        own_cdf = np.cumsum(own_pdf).astype(np.float32)
        own_cdf /= own_cdf[-1]

        # Inverse-transform sampling: map uniform draws through the inverse CDF.
        # Draws outside the tabulated CDF range are clamped to the grid ends.
        inverse_cdf = interp1d(
            own_cdf, x, bounds_error=False, fill_value=(x[0], x[-1]), kind="linear"
        )
        uniform_samples = rng.uniform(0, 1, seq_len).astype(np.float32)
        samples = inverse_cdf(uniform_samples).astype(np.float32)
        ys[i] = torch.from_numpy(samples)

    return xs, ys

def get_batch_new_distribution_prior(
    batch_size, seq_len, num_features=1, hyperparameters=None, **kwargs
):
    """Build a ``Batch`` of log10-encoded samples from the multi-modal prior.

    The sampler returns batch-first tensors; they are swapped to a
    sequence-first layout before being wrapped. ``hyperparameters`` and any
    extra keyword arguments are accepted but not used here.

    Returns:
        ``Batch`` whose ``x`` has shape ``(seq_len, batch_size, num_features)``
        and whose ``y`` / ``target_y`` both hold the log10 targets with shape
        ``(seq_len, batch_size)``.
    """
    features, raw_targets = sample_from_multi_modal_distribution(
        batch_size=batch_size, seq_len=seq_len, num_features=num_features
    )

    # Targets are modelled in log10 space.
    log_targets = raw_targets.log10()

    return Batch(
        x=features.permute(1, 0, 2),
        y=log_targets.t(),
        target_y=log_targets.t(),
    )


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.interpolate import interp1d

# Parameters
batch_size = 1
seq_len = 1000
num_features = 1
log_start = -6
log_end = 15

# Generate data
_, ys = sample_from_multi_modal_distribution(batch_size, seq_len, num_features)
x = np.logspace(log_start, log_end, seq_len, dtype=np.float32)

# Compute the multi-modal PDF and CDF.
# Note: this draws a *new* random mixture, independent of the one the sampler
# used for `ys` above, so the histogram panel illustrates the sampler's output
# rather than samples from the exact PDF/CDF shown in the first two panels.
rng = np.random.default_rng(seed=None)
num_peaks = 4  # Adjust number of peaks as desired
pdf = multi_modal_distribution(x, num_peaks, log_start, log_end, rng=rng)
pdf /= np.trapz(pdf, x)  # Normalize PDF
cdf = np.cumsum(pdf)
cdf /= cdf[-1]  # Normalize CDF

# Plot PDF, CDF, and sampled distribution
fig, (ax_pdf, ax_cdf, ax_hist) = plt.subplots(1, 3, figsize=(18, 6))

# PDF plot
ax_pdf.plot(x, pdf, label="PDF", color="blue")
ax_pdf.set_xscale("log")
ax_pdf.set_title("Multi-Modal Uniform PDF")
ax_pdf.set_xlabel("x (log scale)")
ax_pdf.set_ylabel("Density")
ax_pdf.legend()

# CDF plot
ax_cdf.plot(x, cdf, label="CDF", color="green")
ax_cdf.set_xscale("log")
ax_cdf.set_title("Cumulative Distribution Function (CDF)")
ax_cdf.set_xlabel("x (log scale)")
ax_cdf.set_ylabel("Cumulative Probability")
ax_cdf.legend()

# Histogram of samples.
# The x-axis is logarithmic and spans ~21 decades, so the bin edges must be
# log-spaced too; 50 linear bins would put almost every decade into the first bin.
log_bins = np.logspace(log_start, log_end, 51)  # 51 edges -> 50 bins
ax_hist.hist(
    ys[0].numpy(), bins=log_bins, density=True, alpha=0.7, color="orange", label="Samples"
)
ax_hist.set_xscale("log")
ax_hist.set_title("Sampled Distribution")
ax_hist.set_xlabel("x (log scale)")
ax_hist.set_ylabel("Density")
ax_hist.legend()

fig.tight_layout()
plt.show()



In [4]:
def set_seed(seed=123):
    """Seed the PyTorch, NumPy (legacy global) and ``random`` generators.

    Args:
        seed: Integer seed applied to all three generators.
    """
    # Same order as before: torch, then numpy, then the stdlib generator.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

In [None]:
# Checkpoint of the trained KinPFN model, relative to this notebook's directory.
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

if trained_model is None:
    # Raise instead of calling exit(): in a notebook exit() ends the kernel
    # session rather than reporting why the cell failed, whereas an exception
    # stops "Run All" here and shows the reason.
    raise RuntimeError(f"No trained model found! (model_path={model_path!r})")

print("Load trained model!")

In [6]:
def _gmm_ensemble_density_and_cdf(train_values, eval_values, n_components_list):
    """Fit one 1-D GMM per component count and blend them by training likelihood.

    Each mixture is fitted on ``train_values``. The ensemble weight of a mixture
    is the softmax of its total training log-likelihood (mean log-likelihood
    times the number of training points).

    Args:
        train_values: Array of shape ``(n_train, 1)`` used for fitting.
        eval_values: Array of shape ``(n_eval, 1)`` at which to evaluate.
        n_components_list: Non-empty iterable of component counts.

    Returns:
        Tuple ``(density, cdf)`` of 1-D arrays of length ``n_eval``: the
        ensemble pdf and the ensemble cdf at ``eval_values``.

    Raises:
        ValueError: If ``n_components_list`` is empty.
    """
    from scipy.stats import norm  # local import keeps this block self-contained

    n_components_list = list(n_components_list)
    if not n_components_list:
        raise ValueError("n_components_list must contain at least one entry")

    total_log_likelihoods = []
    densities = []
    cdfs = []

    for n_components in n_components_list:
        gmm = GaussianMixture(n_components=n_components, max_iter=100000)
        gmm.fit(train_values)

        # gmm.score() is the mean per-sample log-likelihood; scale to the total.
        total_log_likelihoods.append(gmm.score(train_values) * len(train_values))
        densities.append(np.exp(gmm.score_samples(eval_values)))

        # Exact mixture CDF: sum_k pi_k * Phi((y - mu_k) / sigma_k).
        # With the default covariance_type="full" and one feature,
        # covariances_ has shape (n_components, 1, 1).
        means = gmm.means_[:, 0]
        stds = np.sqrt(gmm.covariances_[:, 0, 0])
        component_cdfs = norm.cdf((eval_values - means) / stds)  # (n_eval, n_components)
        cdfs.append(component_cdfs @ gmm.weights_)

    # Softmax over the total log-likelihoods (max subtracted for stability).
    log_weights = np.asarray(total_log_likelihoods)
    weights = np.exp(log_weights - log_weights.max())
    weights /= weights.sum()

    combined_density = np.tensordot(weights, np.stack(densities), axes=1)
    combined_cdf = np.tensordot(weights, np.stack(cdfs), axes=1)
    return combined_density, combined_cdf


def evaluate_model_on_new_distr_prior(trained_model, n_components_list):
    """Compare KinPFN with a GMM ensemble on datasets from the multi-modal prior.

    100 datasets of 1000 log10-encoded samples are generated. For each context
    size in ``[10, 25, 50, 75, 100]`` and each dataset, a random subset of that
    size is given to both methods. Both are scored on all 1000 samples of the
    dataset with (a) the mean absolute error between the predicted CDF and the
    empirical CDF and (b) the mean negative log-likelihood. The averages over
    the datasets are printed per context size.

    Args:
        trained_model: Callable model invoked as ``model(train_x, train_y, test_x)``
            that exposes ``criterion.cdf`` and ``criterion.forward``.
        n_components_list: Component counts of the GMMs in the ensemble.

    Returns:
        None. Results are printed.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    dataset_size = 100
    training_points = [10, 25, 50, 75, 100]

    batch = get_batch_new_distribution_prior(dataset_size, seq_len)
    print(f"Dataset Size: {dataset_size}")

    x = batch.x
    y_folding_times = batch.y

    for training_point in training_points:
        print(f"Training Point: {training_point}")
        mae_losses = []
        mean_nll_losses = []
        ensemble_mae_losses = []
        ensemble_nll_losses = []

        for batch_index in range(dataset_size):
            train_indices = torch.randperm(seq_len)[:training_point]

            train_x = x[train_indices, batch_index].to(device)
            train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
            test_x = x[:, batch_index].to(device)
            test_y_folding_times = y_folding_times[:, batch_index].to(device)

            with torch.no_grad():
                # Add a batch dimension as required by the transformer model
                logits = trained_model(
                    train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
                )

            # Empirical CDF of the full dataset: i / n at the i-th sorted value.
            test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
            n_test = len(test_y_folding_times_sorted)

            pred_cdf_original = (
                trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
            )[0][0]

            # Keep the reference CDF on the prediction's device; subtracting a
            # CPU tensor from a CUDA tensor (and calling np.abs on it) fails.
            ground_truth_cdf = (torch.arange(1, n_test + 1) / n_test).to(
                pred_cdf_original.device
            )

            mae = torch.abs(pred_cdf_original - ground_truth_cdf).mean()
            mae_losses.append(mae.item())

            nll_loss = trained_model.criterion.forward(
                logits=logits, y=test_y_folding_times_sorted
            )
            mean_nll_losses.append(nll_loss.mean().item())

            # GMM ensemble baseline. Its CDF is the analytic mixture CDF; a
            # cumulative sum of pdf values at the (non-uniformly spaced) sorted
            # samples is not a CDF and would not match the empirical CDF even
            # for a perfect density estimate.
            combined_density, combined_cdf = _gmm_ensemble_density_and_cdf(
                train_y_folding_times.reshape(-1, 1).cpu().numpy(),
                test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy(),
                n_components_list,
            )

            ensemble_mae = np.abs(combined_cdf - ground_truth_cdf.cpu().numpy()).mean()
            ensemble_mae_losses.append(float(ensemble_mae))

            # The small constant guards against log(0).
            ensemble_nll = -np.mean(np.log(combined_density + 1e-9))
            ensemble_nll_losses.append(float(ensemble_nll))

        print(f"KinPFN Mean MAE Loss: {torch.tensor(mae_losses).mean()}")
        print(f"KinPFN Mean NLL Loss: {torch.tensor(mean_nll_losses).mean()}")
        print(f"GMM Ensemble Mean MAE Loss: {torch.tensor(ensemble_mae_losses).mean()}")
        print(f"GMM Ensemble Mean NLL Loss: {torch.tensor(ensemble_nll_losses).mean()}")

In [7]:
#evaluate_model_on_new_distr_prior(trained_model,n_components_list=[2, 3, 4, 5])

In [8]:
import matplotlib.pyplot as plt

def _gmm_ensemble_cdf_for_plot(train_values, eval_values, n_components_list):
    """Return the likelihood-weighted GMM-ensemble CDF evaluated at ``eval_values``.

    One 1-D Gaussian mixture is fitted on ``train_values`` per entry of
    ``n_components_list``. Mixtures are blended with the softmax of their total
    training log-likelihood, and each mixture contributes its exact CDF
    ``sum_k pi_k * Phi((y - mu_k) / sigma_k)``.

    Args:
        train_values: Array of shape ``(n_train, 1)`` used for fitting.
        eval_values: Array of shape ``(n_eval, 1)`` at which to evaluate.
        n_components_list: Non-empty iterable of component counts.

    Returns:
        1-D array of length ``n_eval`` with the ensemble CDF.

    Raises:
        ValueError: If ``n_components_list`` is empty.
    """
    from scipy.stats import norm  # local import keeps this block self-contained

    n_components_list = list(n_components_list)
    if not n_components_list:
        raise ValueError("n_components_list must contain at least one entry")

    total_log_likelihoods = []
    cdfs = []
    for n_components in n_components_list:
        gmm = GaussianMixture(n_components=n_components, max_iter=100000)
        gmm.fit(train_values)

        # gmm.score() is the mean per-sample log-likelihood; scale to the total.
        total_log_likelihoods.append(gmm.score(train_values) * len(train_values))

        # Default covariance_type="full" with one feature: covariances_ is (k, 1, 1).
        means = gmm.means_[:, 0]
        stds = np.sqrt(gmm.covariances_[:, 0, 0])
        cdfs.append(norm.cdf((eval_values - means) / stds) @ gmm.weights_)

    # Softmax over the total log-likelihoods (max subtracted for stability).
    log_weights = np.asarray(total_log_likelihoods)
    weights = np.exp(log_weights - log_weights.max())
    weights /= weights.sum()

    return np.tensordot(weights, np.stack(cdfs), axes=1)


def evaluate_model_on_new_distr_prior_plotting(
    trained_model, n_components_list, max_plots=1
):
    """Plot KinPFN and GMM-ensemble CDFs against the empirical CDF of one dataset.

    The same 100 seeded datasets as in the numeric evaluation are generated,
    but only the first one is plotted. One figure is drawn per context size
    taken from ``[10, 25, 50, 75, 100]``.

    Args:
        trained_model: Callable model invoked as ``model(train_x, train_y, test_x)``
            that exposes ``criterion.cdf``.
        n_components_list: Component counts of the GMMs in the ensemble.
        max_plots: How many context sizes to plot, taken from the start of the
            list. The default of 1 keeps the previous behaviour, where an
            unconditional ``break`` stopped after the first context size (10).
            Pass ``None`` to plot all of them.

    Returns:
        None. Figures are shown with ``plt.show()``.
    """
    set_seed(seed=123)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = 1000
    dataset_size = 100
    training_points = [10, 25, 50, 75, 100]

    # All datasets are generated even though only the first is used, so the
    # random draws match those of the numeric evaluation under the same seed.
    batch = get_batch_new_distribution_prior(dataset_size, seq_len)
    print(f"Dataset Size: {dataset_size}")

    x = batch.x
    y_folding_times = batch.y

    # Process the first example only for plotting
    batch_index = 0

    for training_point in training_points[:max_plots]:
        print(f"Training Point: {training_point}")

        train_indices = torch.randperm(seq_len)[:training_point]

        train_x = x[train_indices, batch_index].to(device)
        train_y_folding_times = y_folding_times[train_indices, batch_index].to(device)
        test_x = x[:, batch_index].to(device)
        test_y_folding_times = y_folding_times[:, batch_index].to(device)

        with torch.no_grad():
            # Add a batch dimension as required by the transformer model
            logits = trained_model(
                train_x[:, None], train_y_folding_times[:, None], test_x[:, None]
            )

        # Empirical CDF of the full dataset: i / n at the i-th sorted value.
        test_y_folding_times_sorted, _ = torch.sort(test_y_folding_times)
        n_test = len(test_y_folding_times_sorted)
        ground_truth_cdf = torch.arange(1, n_test + 1) / n_test

        pred_cdf_original = (
            trained_model.criterion.cdf(logits, test_y_folding_times_sorted)
        )[0][0]

        # GMM ensemble: analytic mixture CDF. A cumulative sum of pdf values at
        # the non-uniformly spaced sorted samples would not be a CDF.
        combined_cdf = _gmm_ensemble_cdf_for_plot(
            train_y_folding_times.reshape(-1, 1).cpu().numpy(),
            test_y_folding_times_sorted.reshape(-1, 1).cpu().numpy(),
            n_components_list,
        )

        # Plot the CDF approximation for both models
        sorted_times_np = test_y_folding_times_sorted.cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(
            sorted_times_np,
            ground_truth_cdf.cpu().numpy(),
            label="Ground Truth CDF",
            color="green",
            linestyle="--",
        )
        ax.plot(
            sorted_times_np,
            pred_cdf_original.cpu().numpy(),
            label="Model Predicted CDF",
            color="blue",
        )
        ax.plot(
            sorted_times_np,
            combined_cdf,
            label="GMM Ensemble CDF",
            color="orange",
        )
        ax.set_title(f"CDF Approximation (Training Points: {training_point})")
        ax.set_xlabel("Folding Times")
        ax.set_ylabel("Cumulative Probability")
        ax.legend()
        ax.grid()
        plt.show()
        plt.close(fig)  # release the figure when several are drawn in a loop


In [None]:
# Draw the CDF comparison using an ensemble of GMMs with 2 to 5 components.
evaluate_model_on_new_distr_prior_plotting(
    trained_model,
    n_components_list=[2, 3, 4, 5],
)