In [None]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
import jax
import jax.numpy as jnp
import time
import pickle
from scipy.stats import entropy
import gc
import sys
import os
import psutil

# Make the project root (three directories up) importable so `src.realnvp` resolves.
# `__file__` is not defined inside a Jupyter kernel, so fall back to the current
# working directory (normally the notebook's own directory) in that case.
_this_file = globals().get("__file__")
_notebook_dir = os.path.dirname(os.path.abspath(_this_file)) if _this_file else os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(_notebook_dir, '..', '..', '..'))
# Guard so re-running the cell does not keep appending duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.realnvp.utils import load

def print_memory_usage():
    """Print the resident set size (RSS) of this Python process in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 ** 2)
    print(f"Memory usage: {rss_mb:.2f} MB")

In [None]:
# Load training data
def load_training_dataset(path_comparison_data, supervised=True):
    """Load a pickled training dataset and return it as a NumPy array.

    Parameters
    ----------
    path_comparison_data : str or os.PathLike
        Path to a pickle file holding an array-like whose rows are samples.
    supervised : bool, default True
        If True, drop the last column (presumably the target/label column --
        TODO confirm against the data-generation script) and return only the
        remaining columns. If False, return every column.

    Returns
    -------
    numpy.ndarray or None
        A fresh array, or ``None`` if the file does not exist (a message is
        printed in that case).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    try:
        with open(path_comparison_data, 'rb') as pf:
            data_raw = pickle.load(pf)
    except FileNotFoundError:
        print(f"File {path_comparison_data} not found. Please check the path.")
        return None

    # Convert before slicing so plain nested lists (not only ndarrays) can be
    # column-sliced; np.array(...) afterwards guarantees a copy is returned.
    data = np.asarray(data_raw)
    if supervised:
        return np.array(data[:, :-1])
    return np.array(data)

In [None]:
# Grid search parameters: every (batch size, number of layers, learning rate) triple.
batch_sizes = [500, 1000, 2000, 5000]
layer_counts = [4, 6, 8, 10, 12]
learning_rates = [1e-3, 5e-4, 1e-4, 5e-5]

combinations = list(itertools.product(batch_sizes, layer_counts, learning_rates))

# One entry per setup; these lists are zipped position by position.
base_paths = ["./eP/"]
corresponding_model_dir_names = ["models_eP/"]
filename_prefixes = ["eP_nf_bs"]
supervised = [True]

# model_names[k] lists the checkpoint paths of setup k, in `combinations` order,
# e.g. "./eP/models_eP/eP_nf_bs500_L4_lr1e-03.pkl".
model_names = [
    [
        f"{base_path}{model_dir}{filename_prefix}{batch_size}_L{layers}_lr{lr:.0e}.pkl"
        for batch_size, layers, lr in combinations
    ]
    for base_path, model_dir, filename_prefix in zip(
        base_paths, corresponding_model_dir_names, filename_prefixes
    )
]

# Check results
print(f"{len(combinations)} combinations per setup.")
print(f"{len(model_names)} total setups.")
for setup_number, (setup, names) in enumerate(zip(base_paths, model_names), start=1):
    print(f"Setup {setup_number}: {setup} → {len(names)} model names")


training_data_directories = [
    "data_eP",
]

# One pickle per setup, located at <base_path><data_dir>/training_data.pkl.
training_datasets = [
    load_training_dataset(f"{base_path}{data_dir}/training_data.pkl", supervised=is_supervised)
    for base_path, data_dir, is_supervised in zip(base_paths, training_data_directories, supervised)
]

# Print the training dataset shapes with their paths
print("Training datasets loaded:")
for base_path, data_dir, training_data in zip(base_paths, training_data_directories, training_datasets):
    shape_text = "Not Loaded" if training_data is None else training_data.shape
    print(f"Path: {base_path}{data_dir}/training_data.pkl, Shape: {shape_text}")


In [None]:
class NFSampler:
    """Sample model parameters from a trained normalizing flow, truncated to a box.

    Latent vectors z ~ N(0, I) are pushed through the flow, and only outputs
    lying inside ``self.bounds`` (inclusive on both ends) are kept. This is
    rejection sampling onto the parameter box.
    """

    def __init__(self, path_NF_model, seed=None):
        """Load a trained flow and set up the PRNG state.

        Parameters
        ----------
        path_NF_model : str
            Checkpoint path, forwarded to ``load`` from ``src.realnvp.utils``.
        seed : int, optional
            Seed for the JAX PRNG. If omitted, ``time.time_ns()`` is used, so
            every instance draws a different stream.
        """
        # NOTE(review): `dim` is not read anywhere in this notebook; sampling
        # uses `self.dimension` from the checkpoint's hyperparameters instead.
        # Kept only for backward compatibility with external readers.
        self.dim = 2
        # Row k holds [lower, upper] for parameter k. Presumably the row order
        # matches the flow's output columns -- TODO confirm.
        self.bounds = np.array([[0.02, 1.2],  # m
                                [1, 10],    # BG
                                [0.05, 3],  # BGq
                                [0, 1.5],   # smearQsWidth
                                [0.05, 1.50], # QsmuRatio
                                [0.02, 1.2], # m_jimwlk
                                [0.0001, 0.28] # Lambda_QCD_jimwlk
                                ])

        # Specify to use CPU, not GPU. This updates JAX's global config.
        jax.config.update('jax_platform_name', 'cpu')

        if seed is None:
            seed = time.time_ns()
        # One key drives sampling; the other is handed to `load` (presumably
        # to initialise the flow's structure before loading weights).
        self.sample_key, self.init_key = jax.random.split(jax.random.PRNGKey(seed), 2)

        # Load the normalizing flow and its hyperparameters
        self.flow, self.hyperparams = load(path_NF_model, self.init_key)
        self.dimension = self.hyperparams['dimension']

    def _in_bounds(self, x):
        """Return a boolean mask marking the rows of ``x`` inside ``self.bounds`` (inclusive)."""
        lower = jnp.array(self.bounds[:, 0])
        upper = jnp.array(self.bounds[:, 1])
        return jnp.all((x >= lower) & (x <= upper), axis=1)

    def sample(self, size=1, max_batches=10_000):
        """Draw ``size`` in-bounds samples from the flow.

        Parameters
        ----------
        size : int
            Number of samples to return. ``size <= 0`` returns an empty
            ``(0, self.dimension)`` array.
        max_batches : int or None
            Maximum number of batches (each of ``size`` latents) to draw before
            giving up. This prevents an endless loop when the flow puts (almost)
            no mass inside ``self.bounds``. ``None`` disables the limit.

        Returns
        -------
        numpy.ndarray of shape (size, self.dimension)

        Raises
        ------
        RuntimeError
            If fewer than ``size`` samples were accepted after ``max_batches`` batches.
        """
        if size <= 0:
            return np.empty((0, self.dimension))

        accepted_chunks = []
        n_accepted = 0
        n_batches = 0
        while n_accepted < size:
            if max_batches is not None and n_batches >= max_batches:
                raise RuntimeError(
                    f"Only {n_accepted}/{size} samples fell inside the bounds after "
                    f"{n_batches} batches; the flow may put little mass in the box."
                )
            self.sample_key, pkey = jax.random.split(self.sample_key)
            z = jax.random.normal(pkey, (size, self.dimension))
            x, _ = self.flow(z)

            mask = np.asarray(self._in_bounds(x))
            kept = np.asarray(x)[mask]
            accepted_chunks.append(kept)
            n_accepted += kept.shape[0]
            n_batches += 1

        # Each batch is drawn at full `size`, so trim the surplus of the last one.
        return np.concatenate(accepted_chunks, axis=0)[:size]

In [None]:
def make_corner_comparison_plot(training_data, model_samples, bounds, labels=None, bins=50, figsize=(12, 12),
                                save_path=None,
                                kl_1d=None, kl_2d=None):
    """
    Create a corner plot comparing training_data and model_samples with fixed bounds.

    Diagonal panels overlay the 1D density histograms of both sample sets.
    Lower-triangle panels overlay their 2D histograms, and the upper triangle
    is left blank. The figure is closed before returning.

    Parameters:
    - training_data: (N, D) numpy array
    - model_samples: (M, D) numpy array with the same D columns
    - bounds: (D, 2) array of [min, max] for each dimension; fixes every
      histogram range and axis limit
    - labels: optional list of length D with parameter names; if None, the
      axes are left unlabeled
    - bins: number of histogram bins (per axis)
    - figsize: figure size
    - save_path: where to save the figure; if None, the figure is shown instead
    - kl_1d: optional sequence of length D with 1D KL divergences (annotated on the diagonal)
    - kl_2d: optional container indexed as kl_2d[i, j] for j < i with 2D KL divergences,
      e.g. a (D, D) array or a dict keyed by (i, j) tuples; NaN entries are not annotated
    """
    D = training_data.shape[1]
    fig, axes = plt.subplots(D, D, figsize=figsize, sharex='col', sharey='row')

    for i in range(D):
        for j in range(D):
            ax = axes[i, j]

            if i == j:
                # Diagonal: 1D histograms. A new Axes is stacked on the grid cell,
                # presumably so the density scale is not tied to the row-shared
                # y axis of the 2D panels.
                ax = fig.add_subplot(D, D, i * D + j + 1)
                ax.hist(training_data[:, i], bins=bins, density=True, alpha=0.6,
                        color='tab:green', label='Training', range=bounds[i])
                ax.hist(model_samples[:, i], bins=bins, density=True, alpha=0.6,
                        color='tab:purple', label='NF', range=bounds[i])
                ax.set_yticks([])
                ax.set_yticklabels([])
                if i != D - 1:
                    ax.set_xticks([])
                    ax.set_xticklabels([])
                else:
                    ax.tick_params(axis='x', rotation=45, labelsize=10)

                ax.set_xlim(bounds[i])
                ax.set_ylim(0, None)

                if kl_1d is not None:
                    ax.annotate(rf"$D_{{KL}}={kl_1d[i]:.3f}$", xy=(0.95, 0.95),
                                xycoords='axes fraction',
                                ha='right', va='top', fontsize=10)
            elif j < i:
                # Lower triangle: 2D histograms
                ax.hist2d(training_data[:, j], training_data[:, i], bins=bins,
                          range=[bounds[j], bounds[i]], cmap='Greens', alpha=0.4)
                ax.hist2d(model_samples[:, j], model_samples[:, i], bins=bins,
                          range=[bounds[j], bounds[i]], cmap='RdPu', alpha=0.4)
                ax.set_xlim(bounds[j])
                ax.set_ylim(bounds[i])

                if kl_2d is not None and not np.isnan(kl_2d[i, j]):
                    ax.annotate(rf"$D_{{KL}}={kl_2d[i, j]:.3f}$", xy=(0.95, 0.95),
                                xycoords='axes fraction',
                                ha='right', va='top', fontsize=10)
            else:
                # Upper triangle stays empty.
                ax.axis('off')

            # Labeling (bottom row and first column only; skipped without labels)
            if labels is not None:
                if i == D - 1:
                    ax.set_xlabel(labels[j])
                if j == 0:
                    ax.set_ylabel(labels[i])

    # Remove space between subplots
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.subplots_adjust(left=0.07, right=0.95, top=0.95, bottom=0.07)

    # Hide ticks of the grid Axes underneath the first and last diagonal panels
    axes[0, 0].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
    axes[D - 1, D - 1].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)

    # Rotate x-axis tick labels
    for ax in axes[-1]:
        ax.tick_params(axis='x', rotation=45, labelsize=14)

    # Rotate y-axis tick labels
    for ax in axes[:, 0]:
        ax.tick_params(axis='y', rotation=45, labelsize=14)

    fig.legend(
        labels=["Training", "NF"],
        loc="upper right",
        fontsize=12,
        frameon=False,
        bbox_to_anchor=(0.35, 0.95)
    )

    # save_path is optional: without one, display the figure instead of crashing in savefig.
    if save_path is not None:
        fig.savefig(save_path)
    else:
        plt.show()
    plt.close(fig)
    del fig, axes
    gc.collect()


In [None]:
def compute_1d_kl(p_samples, q_samples, bounds, bins=100, epsilon=1e-10):
    """
    KL divergence D_KL(P || Q) of every 1D marginal (corner-plot diagonal).

    Column i of both sample sets is histogrammed (as densities) on `bins`
    equal-width bins spanning ``bounds[i]``. The Q histogram reuses P's bin
    edges. `epsilon` is added to every bin to avoid log(0), and
    ``scipy.stats.entropy`` renormalizes both histograms before comparing them.

    Arguments:
        p_samples, q_samples: Arrays of shape (n_samples, n_dims); P is the reference
        bounds: np.ndarray of shape (n_dims, 2) specifying min and max for each dimension
        bins: number of bins per dimension
        epsilon: additive regularizer applied to every bin
    Returns:
        kls: List of KL divergences for each dimension
        mean_kl: Mean KL divergence across all dimensions
    """
    def _marginal_kl(col):
        p_density, edges = np.histogram(p_samples[:, col], bins=bins, range=bounds[col], density=True)
        q_density = np.histogram(q_samples[:, col], bins=edges, density=True)[0]
        return entropy(p_density + epsilon, q_density + epsilon)

    kls = [_marginal_kl(col) for col in range(p_samples.shape[1])]
    return kls, np.mean(kls)


def compute_2d_kl_lower(p_samples, q_samples, bounds, bins=40, epsilon=1e-10):
    """
    KL divergence D_KL(P || Q) of every 2D marginal in the lower triangle of a corner plot.

    For each pair (i, j) with j < i, columns i and j are histogrammed jointly
    (as densities) on a bins x bins grid spanning ``bounds[i]`` x ``bounds[j]``.
    The Q histogram reuses P's edges. Both histograms are flattened, `epsilon`
    is added to every cell, and ``scipy.stats.entropy`` renormalizes them.

    Arguments:
        p_samples, q_samples: Arrays of shape (n_samples, n_dims); P is the reference
        bounds: np.ndarray of shape (n_dims, 2) specifying min and max for each dimension
        bins: number of bins per axis
        epsilon: additive regularizer applied to every cell
    Returns:
        kls_2d: List of tuples ((i, j), kl) for each lower-triangle pair, ordered by i then j
        avg_kl_2d: Mean KL divergence across all pairs (0.0 if there are none)
    """
    n_dims = p_samples.shape[1]
    lower_pairs = [(row, col) for row in range(n_dims) for col in range(row)]

    kls_2d = []
    for row, col in lower_pairs:
        p_density, row_edges, col_edges = np.histogram2d(
            p_samples[:, row], p_samples[:, col],
            bins=bins, range=[bounds[row], bounds[col]], density=True,
        )
        q_density = np.histogram2d(
            q_samples[:, row], q_samples[:, col],
            bins=[row_edges, col_edges], density=True,
        )[0]
        divergence = entropy(p_density.ravel() + epsilon, q_density.ravel() + epsilon)
        kls_2d.append(((row, col), divergence))

    if not kls_2d:
        return kls_2d, 0.0
    return kls_2d, np.mean([value for _, value in kls_2d])

In [None]:
param_labels = [
    r"$m\;[\mathrm{GeV}]$",
    r"$B_G\;[\mathrm{GeV}^{-2}]$",
    r"$B_{q}\;[\mathrm{GeV}^{-2}]$",
    r"$\sigma$",
    r"$Q_s/(g^2\mu)$",
    r"$m_{\mathrm{JIMWLK}}\;[\mathrm{GeV}]$",
    r"$\Lambda_{\mathrm{QCD}}\;[\mathrm{GeV}]$"
]

# For every setup, evaluate each trained flow:
# - sample from it;
# - compare the samples with the training data via histogram KL divergences;
# - save a corner plot;
# - record the model with the lowest overall mean KL.
# best_models gets exactly one entry per setup, so the summary numbering below stays aligned.
best_models = []
for model_setup, training_data in zip(model_names, training_datasets):
    print(f"Processing model setup with {len(model_setup)} models...")
    if training_data is None:
        # load_training_dataset already printed which file was missing.
        print("Training data not loaded; skipping this setup.")
        best_models.append("No training data loaded; setup skipped.")
        continue

    model_setup_found = []
    avg_kl_values = []
    for model_name in model_setup:
        # If model does not exist, skip
        try:
            print(f"Evaluating model: {model_name}")
            sampler = NFSampler(model_name)
            # Draw as many model samples as there are training points.
            model_samples = jax.device_get(sampler.sample(size=training_data.shape[0]))
        except FileNotFoundError:
            print(f"Model file not found: {model_name}")
            continue

        # Compute KL divergences
        kl_1d, mean_kl_1d = compute_1d_kl(
            p_samples=training_data,
            q_samples=model_samples,
            bounds=sampler.bounds,
            bins=75
        )
        kl_2d, mean_kl_2d = compute_2d_kl_lower(
            p_samples=training_data,
            q_samples=model_samples,
            bounds=sampler.bounds,
            bins=75
        )

        # Overall score: average of the mean 1D (diagonal) and mean 2D (lower-triangle) KL
        overall_mean_kl = (mean_kl_1d + mean_kl_2d) / 2.
        avg_kl_values.append(overall_mean_kl)
        model_setup_found.append(model_name)

        # {(i, j): kl} so the plot can look values up as kl_2d[i, j]
        kl_2d_dict = dict(kl_2d)
        make_corner_comparison_plot(
            training_data=training_data,
            model_samples=model_samples,
            bounds=sampler.bounds,
            labels=param_labels,
            bins=75,
            save_path=f"{model_name}_corner_plot.png",
            kl_1d=kl_1d,
            kl_2d=kl_2d_dict
        )

        # Explicitly delete all objects to free memory
        del sampler, model_samples, kl_1d, kl_2d, kl_2d_dict
        gc.collect()
        print_memory_usage()

    # Identify the model with the lowest average KL divergence.
    # np.argmin would raise on an empty list and would pick a NaN score as the
    # minimum (NaN can appear e.g. when a histogram range contains no samples).
    kl_scores = np.asarray(avg_kl_values, dtype=float)
    if kl_scores.size == 0 or np.all(np.isnan(kl_scores)):
        print("No model could be evaluated for this setup.")
        best_models.append("No model could be evaluated for this setup.")
        continue
    best_model_index = int(np.nanargmin(kl_scores))
    best_model_name = model_setup_found[best_model_index]
    best_models.append(f"Best model: {best_model_name} with KL divergence: {avg_kl_values[best_model_index]:.4f}")

# Print the best models for each setup
print("\nBest models for each setup:")
for i, best_model in enumerate(best_models):
    print(f"Setup {i+1}: {best_model}")

## Paper figure: compare the trained normalizing-flow (NF) model with the training distribution

In [None]:
from scipy.stats import gaussian_kde

def _get_contour_levels(Z, levels):
    """Return the density thresholds that enclose given fractions of the grid mass.

    Parameters
    ----------
    Z : numpy.ndarray
        Density values evaluated on a uniform grid.
    levels : sequence of float in (0, 1]
        Mass fractions, e.g. (0.685, 0.955, 0.997).

    Returns
    -------
    list of float
        For each fraction f, the value z* such that grid points with Z >= z*
        carry at least a fraction f of the summed Z. The list is sorted
        ascending because ``Axes.contour`` requires increasing levels.
    """
    z_desc = np.sort(np.ravel(Z))[::-1]
    cumulative = np.cumsum(z_desc)
    cumulative /= cumulative[-1]  # normalize to [0, 1]
    # Clamp so a fraction at (or rounding above) 1 maps to the last grid value
    # instead of indexing past the end of the array.
    idx = np.minimum(np.searchsorted(cumulative, levels), z_desc.size - 1)
    return sorted(z_desc[idx].tolist())


def make_corner_comparison_plot_paper(training_data, model_samples, bounds, labels=None, bins=50, figsize=(12, 12),
                                save_path=None,
                                kl_1d=None, kl_2d=None, mean_kl=None,
                                contour_mass_fractions=(0.685, 0.955, 0.997)):
    """
    Create a publication corner plot comparing training_data and model_samples with fixed bounds.

    Diagonal panels show a filled training histogram and a dotted step
    histogram of the model samples. Lower-triangle panels show the 2D training
    histogram plus Gaussian-KDE contours of both sample sets. The upper
    triangle is left blank, and the figure is closed before returning.

    Parameters:
    - training_data: (N, D) numpy array
    - model_samples: (M, D) numpy array with the same D columns
    - bounds: (D, 2) array of [min, max] for each dimension
    - labels: optional list of length D with parameter names; if None, the axes are left unlabeled
    - bins: number of histogram bins; also the number of KDE grid points per axis
    - figsize: figure size
    - save_path: where to save the figure; if None, the figure is shown instead
    - kl_1d: optional sequence of length D with 1D KL divergences
    - kl_2d: optional container indexed as kl_2d[i, j] for j < i (a (D, D) array or a
      dict keyed by (i, j) tuples); NaN entries are not annotated
    - mean_kl: overall mean KL divergence (optional), written under the legend
    - contour_mass_fractions: probability-mass fractions enclosed by the KDE contours
      (the defaults presumably approximate the 1/2/3-sigma Gaussian fractions)
    """
    D = training_data.shape[1]
    fig, axes = plt.subplots(D, D, figsize=figsize, sharex='col', sharey='row')

    for i in range(D):
        for j in range(D):
            ax = axes[i, j]

            if i == j:
                # Diagonal: 1D histograms. A new Axes is stacked on the grid cell,
                # presumably so the density scale is not tied to the row-shared y axis.
                ax = fig.add_subplot(D, D, i * D + j + 1)
                ax.hist(training_data[:, i], bins=bins, density=True, alpha=0.5,
                        color='tab:green', label='Training', range=bounds[i])
                ax.hist(model_samples[:, i], bins=bins, density=True,
                        color='tab:orange', histtype='step', linewidth=2,
                        label='NF', range=bounds[i], linestyle=':')
                ax.set_yticks([])
                ax.set_yticklabels([])
                if i != D - 1:
                    ax.set_xticks([])
                    ax.set_xticklabels([])
                else:
                    ax.tick_params(axis='x', rotation=45, labelsize=10)

                ax.set_xlim(bounds[i])
                ax.set_ylim(0, None)

                if kl_1d is not None:
                    ax.annotate(rf"$D_{{KL}}={kl_1d[i]:.3f}$", xy=(0.95, 0.95),
                                xycoords='axes fraction',
                                ha='right', va='top', fontsize=10)
            elif j < i:
                # Lower triangle: 2D training histogram plus KDE contours
                ax.hist2d(training_data[:, j], training_data[:, i], bins=bins,
                          range=[bounds[j], bounds[i]], cmap='Greens', alpha=0.5)

                # bins x bins evaluation grid covering the panel
                x = np.linspace(bounds[j][0], bounds[j][1], bins)
                y = np.linspace(bounds[i][0], bounds[i][1], bins)
                X, Y = np.meshgrid(x, y)
                grid_points = np.vstack([X.ravel(), Y.ravel()])

                # Contours are best-effort: a degenerate KDE or duplicate levels
                # should not abort the whole figure, so failures are only reported.
                # Add contour for training_data (solid)
                try:
                    kde_train = gaussian_kde(np.vstack([training_data[:, j], training_data[:, i]]))
                    Z_train = kde_train(grid_points).reshape(X.shape)
                    ax.contour(X, Y, Z_train,
                               levels=_get_contour_levels(Z_train, contour_mass_fractions),
                               colors='green', linewidths=1.2)
                except Exception as e:
                    print(f"Contour failed for training data at ({i},{j}): {e}")

                # Add contour for model_samples (dotted)
                try:
                    kde_model = gaussian_kde(np.vstack([model_samples[:, j], model_samples[:, i]]))
                    Z_model = kde_model(grid_points).reshape(X.shape)
                    ax.contour(X, Y, Z_model,
                               levels=_get_contour_levels(Z_model, contour_mass_fractions),
                               colors='orange', linewidths=1.5, linestyles='dotted')
                except Exception as e:
                    print(f"Contour failed for model data at ({i},{j}): {e}")

                ax.set_xlim(bounds[j])
                ax.set_ylim(bounds[i])

                if kl_2d is not None and not np.isnan(kl_2d[i, j]):
                    ax.annotate(rf"$D_{{KL}}={kl_2d[i, j]:.3f}$", xy=(0.95, 0.95),
                                xycoords='axes fraction',
                                ha='right', va='top', fontsize=10)
            else:
                # Upper triangle stays empty.
                ax.axis('off')

            # Labeling (bottom row and first column only; skipped without labels)
            if labels is not None:
                if i == D - 1:
                    ax.set_xlabel(labels[j], fontsize=14)
                if j == 0:
                    ax.set_ylabel(labels[i], fontsize=14)

    # Remove space between subplots
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.subplots_adjust(left=0.07, right=0.95, top=0.95, bottom=0.07)

    # Hide ticks of the grid Axes underneath the first and last diagonal panels
    axes[0, 0].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
    axes[D - 1, D - 1].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)

    # Rotate x-axis tick labels
    for ax in axes[-1]:
        ax.tick_params(axis='x', rotation=45, labelsize=14)

    # Rotate y-axis tick labels
    for ax in axes[:, 0]:
        ax.tick_params(axis='y', rotation=45, labelsize=14)

    fig.legend(
        labels=["Training", "NF"],
        loc="upper right",
        fontsize=12,
        frameon=False,
        bbox_to_anchor=(0.35, 0.95)
    )

    if mean_kl is not None:
        # Place the mean KL just below the legend (figure coordinates).
        fig.text(0.33, 0.87, rf"$\langle D_{{KL}} \rangle = {mean_kl:.3f}$",
                 ha='right', va='bottom', fontsize=12)

    # save_path is optional: without one, display the figure instead of crashing in savefig.
    if save_path is not None:
        fig.savefig(save_path)
    else:
        plt.show()
    plt.close(fig)
    del fig, axes
    gc.collect()

param_labels = [
    r"$m\;[\mathrm{GeV}]$",
    r"$B_G\;[\mathrm{GeV}^{-2}]$",
    r"$B_{q}\;[\mathrm{GeV}^{-2}]$",
    r"$\sigma$",
    r"$Q_s/(g^2\mu)$",
    r"$m_{\mathrm{JIMWLK}}\;[\mathrm{GeV}]$",
    r"$\Lambda_{\mathrm{QCD}}\;[\mathrm{GeV}]$"
]

# One entry per setup; every per-setup list below must have the same length,
# because zip() silently drops unmatched extras.
base_paths = [
    "./eP/",
]

corresponding_model_dir_names = [
    "models_eP/",
]

filename_prefixes = [
    "eP_nf_bs",
]

supervised = [True]

# Hand-picked checkpoint(s) per setup -- presumably the best model found by the
# grid search above (TODO confirm).
model_names = [
    ["./CLUSTER/eP/models_eP/eP_nf_bs5000_L6_lr1e-03.pkl"],
]

# Check results
print(f"{len(model_names)} total setups.")
for i, (setup, names) in enumerate(zip(base_paths, model_names)):
    print(f"Setup {i+1}: {setup} → {len(names)} model names")


training_data_directories = [
    "data_eP",
]

training_datasets = []
for base_path, data_dir, is_supervised in zip(base_paths, training_data_directories, supervised):
    training_data = load_training_dataset(f"{base_path}{data_dir}/training_data.pkl", supervised=is_supervised)
    training_datasets.append(training_data)

# Print the training dataset shapes with their paths
print("Training datasets loaded:")
for base_path, data_dir, training_data in zip(base_paths, training_data_directories, training_datasets):
    if training_data is not None:
        print(f"Path: {base_path}{data_dir}/training_data.pkl, Shape: {training_data.shape}")
    else:
        print(f"Path: {base_path}{data_dir}/training_data.pkl, Shape: Not Loaded")

# Produce one paper figure per model. The data directory is taken from the setup
# itself rather than from a manually incremented counter, so it cannot drift when a
# setup has several models or a model file is missing.
for setup_index, (model_setup, training_data, data_dir) in enumerate(
        zip(model_names, training_datasets, training_data_directories)):
    print(f"Processing model setup with {len(model_setup)} models...")
    if training_data is None:
        print(f"Skipping setup {setup_index + 1}: training data not loaded.")
        continue

    for model_name in model_setup:
        # If model does not exist, skip
        try:
            print(f"Evaluating model: {model_name}")
            sampler = NFSampler(model_name)
            model_samples = jax.device_get(sampler.sample(size=training_data.shape[0]))
        except FileNotFoundError:
            print(f"Model file not found: {model_name}")
            continue

        # Compute KL divergences
        kl_1d, mean_kl_1d = compute_1d_kl(
            p_samples=training_data,
            q_samples=model_samples,
            bounds=sampler.bounds,
            bins=75
        )
        kl_2d, mean_kl_2d = compute_2d_kl_lower(
            p_samples=training_data,
            q_samples=model_samples,
            bounds=sampler.bounds,
            bins=75
        )

        # Overall mean KL divergence (only this summary is shown in the paper figure)
        overall_mean_kl = (mean_kl_1d + mean_kl_2d) / 2.

        # One PDF per setup; add the checkpoint name only when a setup lists
        # several models, so their figures do not overwrite each other.
        if len(model_setup) == 1:
            save_path = f"./NF_trained_corner_plot_{data_dir}.pdf"
        else:
            model_stem = os.path.splitext(os.path.basename(model_name))[0]
            save_path = f"./NF_trained_corner_plot_{data_dir}_{model_stem}.pdf"

        make_corner_comparison_plot_paper(
            training_data=training_data,
            model_samples=model_samples,
            bounds=sampler.bounds,
            labels=param_labels,
            bins=75,
            save_path=save_path,
            kl_1d=None,
            kl_2d=None,
            mean_kl=overall_mean_kl
        )

        # Explicitly delete all objects to free memory
        del sampler, model_samples, kl_1d, kl_2d
        gc.collect()
        print_memory_usage()
