# Run MCMC for JIMWLK

In [None]:
from src import workdir, parse_model_parameter_file
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from scipy.stats import entropy
import jax
import jax.numpy as jnp
import time
import sys
import os

# Get the path three levels up and add it to sys.path
# NOTE(review): `__file__` is not defined in interactive/notebook sessions, so
# this line raises NameError there -- confirm how this code is executed.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))

# Project-local loader; used by NFSampler to restore the stored flow.
from src.realnvp.utils import load

# Global matplotlib default font size for every figure created afterwards.
plt.rcParams.update({'font.size': 24})

In [None]:
# LaTeX axis labels, one per parameter, in the same order as the rows of
# `bounds` below.
param_labels = [
    r"$m\;[\mathrm{GeV}]$",
    r"$B_G\;[\mathrm{GeV}^{-2}]$",
    r"$B_{q}\;[\mathrm{GeV}^{-2}]$",
    r"$\sigma$",
    r"$Q_s/(g^2\mu)$",
    r"$m_{\mathrm{JIMWLK}}\;[\mathrm{GeV}]$",
    r"$\Lambda_{\mathrm{QCD}}\;[\mathrm{GeV}]$",
]

# [min, max] per parameter; used as histogram ranges and axis limits.
bounds = np.array(
    [
        [0.02, 1.2],     # m
        [1, 10],         # BG
        [0.05, 3],       # BGq
        [0, 1.5],        # smearQsWidth
        [0.05, 1.50],    # QsmuRatio
        [0.02, 1.2],     # m_jimwlk
        [0.0001, 0.28],  # Lambda_QCD_jimwlk
    ]
)

def read_pkl_file_chain_pocoMC(PATH_pklfile_chains):
    """
    Load a pickled result dictionary and trim its chain.

    The unpickled object is expected to be a dictionary containing:
    - 'chain'
    - 'weights'
    - 'logl'
    - 'logp'
    - 'logz'
    - 'logz_err'

    Columns 3, 8, 9 and 10 of data['chain'] are removed; every other entry
    is returned as it was stored.

    NOTE(review): pickle.load executes arbitrary code from the file; only
    use this on trusted chain files.
    """
    dropped_columns = [3, 8, 9, 10]

    with open(PATH_pklfile_chains, 'rb') as handle:
        result = pickle.load(handle)

    # Replace the stored chain by a copy without the dropped columns.
    trimmed_chain = np.delete(result['chain'], dropped_columns, axis=1)
    result['chain'] = trimmed_chain
    return result

In [None]:
def compute_1d_kl(p_samples, q_samples, bounds, bins=100, epsilon=1e-10):
    """
    KL divergence D(p || q) of every 1D marginal (diagonal of a corner plot).

    Both sample sets are histogrammed on the same bin edges, which are
    derived from the p samples within bounds[i]; `epsilon` is added to each
    bin before the divergence is evaluated so that empty bins do not give
    log(0).

    Arguments:
        p_samples, q_samples: Arrays of shape (n_samples, n_dims)
        bounds: np.ndarray of shape (n_dims, 2) specifying min and max for each dimension
        bins: number of histogram bins per dimension
        epsilon: regularizer added to every histogram bin
    Returns:
        kls: List of KL divergences for each dimension
        mean_kl: Mean KL divergence across all dimensions
    """
    def marginal_kl(axis):
        # Histogram p first; q reuses p's edges so the bins line up.
        p_density, edges = np.histogram(
            p_samples[:, axis], bins=bins, range=bounds[axis], density=True
        )
        q_density, _ = np.histogram(q_samples[:, axis], bins=edges, density=True)
        return entropy(p_density + epsilon, q_density + epsilon)

    n_dims = p_samples.shape[1]
    kls = [marginal_kl(axis) for axis in range(n_dims)]
    return kls, np.mean(kls)


def compute_2d_kl_lower(p_samples, q_samples, bounds, bins=40, epsilon=1e-10):
    """
    KL divergence D(p || q) of every 2D marginal (lower triangle of a corner plot).

    For each pair (i, j) with j < i, both sample sets are binned on the same
    2D grid (edges taken from the p histogram within bounds[i] x bounds[j]);
    `epsilon` is added to every cell before the divergence is evaluated.

    Arguments:
        p_samples, q_samples: Arrays of shape (n_samples, n_dims)
        bounds: np.ndarray of shape (n_dims, 2) specifying min and max for each dimension
        bins: number of histogram bins along each axis
        epsilon: regularizer added to every histogram cell
    Returns:
        kls_2d: List of tuples ((i, j), kl) for each lower-triangle pair
        avg_kl_2d: Mean KL divergence across all pairs (0.0 if there are none)
    """
    n_dims = p_samples.shape[1]
    lower_pairs = [(row, col) for row in range(n_dims) for col in range(row)]

    kls_2d = []
    for row, col in lower_pairs:
        p_density, row_edges, col_edges = np.histogram2d(
            p_samples[:, row],
            p_samples[:, col],
            bins=bins,
            range=[bounds[row], bounds[col]],
            density=True,
        )
        q_density, _, _ = np.histogram2d(
            q_samples[:, row],
            q_samples[:, col],
            bins=[row_edges, col_edges],
            density=True,
        )

        # Flatten to 1D distributions and regularize empty cells.
        divergence = entropy(p_density.ravel() + epsilon, q_density.ravel() + epsilon)
        kls_2d.append(((row, col), divergence))

    if not kls_2d:
        return kls_2d, 0.0
    return kls_2d, np.mean([item[1] for item in kls_2d])

In [None]:
class NFSampler:
    """
    Rejection sampler on top of a stored normalizing flow.

    Standard-normal latent points are pushed through the loaded flow and
    only outputs that lie inside `bounds` (inclusive) are kept.
    """

    def __init__(self, path_NF_model, seed=None, bounds=None):
        """
        Parameters:
        - path_NF_model: path handed to `load` to restore the flow
        - seed: integer seed for the JAX PRNG; the current time in
          nanoseconds is used when None
        - bounds: optional (n_dims, 2) array of [min, max] per dimension;
          defaults to the seven-parameter table below
        """
        # NOTE(review): `dim` is not read anywhere in this class and does not
        # match the seven-row default bounds; kept only so that existing
        # attribute access keeps working.
        self.dim = 2
        if bounds is None:
            bounds = [[0.02, 1.2],     # m
                      [1, 10],         # BG
                      [0.05, 3],       # BGq
                      [0, 1.5],        # smearQsWidth
                      [0.05, 1.50],    # QsmuRatio
                      [0.02, 1.2],     # m_jimwlk
                      [0.0001, 0.28],  # Lambda_QCD_jimwlk
                      ]
        self.bounds = np.asarray(bounds, dtype=float)
        if self.bounds.ndim != 2 or self.bounds.shape[1] != 2:
            raise ValueError("bounds must have shape (n_dims, 2)")

        # Specify to use CPU, not GPU.
        jax.config.update('jax_platform_name', 'cpu')

        if seed is None:
            seed = time.time_ns()
        self.sample_key, self.init_key = jax.random.split(jax.random.PRNGKey(seed), 2)

        # Load the normalizing flow and its hyperparameters
        self.flow, self.hyperparams = load(path_NF_model, self.init_key)
        self.dimension = self.hyperparams['dimension']

    def _in_bounds(self, x):
        """Return a boolean mask, one entry per row of x, True when every
        component lies inside the inclusive bounds."""
        lower = jnp.array(self.bounds[:, 0])
        upper = jnp.array(self.bounds[:, 1])
        return jnp.all((x >= lower) & (x <= upper), axis=1)

    def sample(self, size=1, max_batches=None):
        """
        Draw `size` in-bounds samples from the flow.

        Each batch pushes `size` latent points through the flow and keeps the
        in-bounds outputs; batches are repeated until enough rows are
        collected.

        Parameters:
        - size: number of samples to return
        - max_batches: optional cap on the number of batches; None (default)
          retries without limit, which never terminates if the flow never
          produces an in-bounds point

        Returns:
        - array of shape (size, dimension); shape (0, dimension) when
          size <= 0

        Raises:
        - RuntimeError if `max_batches` batches did not yield `size` samples
        """
        if size <= 0:
            return np.empty((0, self.dimension))

        kept_batches = []
        n_kept = 0
        n_batches = 0
        while n_kept < size:
            if max_batches is not None and n_batches >= max_batches:
                raise RuntimeError(
                    f"Only {n_kept} of {size} requested samples fell inside "
                    f"the bounds after {n_batches} batches"
                )
            # Advance the stored key so successive calls give fresh draws.
            self.sample_key, pkey = jax.random.split(self.sample_key)
            z = jax.random.normal(pkey, (size, self.dimension))
            x, _ = self.flow(z)

            mask = np.asarray(self._in_bounds(x))
            kept = np.asarray(x)[mask]
            kept_batches.append(kept)
            n_kept += kept.shape[0]
            n_batches += 1

        # Batches are concatenated in draw order, then truncated to `size`.
        return np.concatenate(kept_batches, axis=0)[:size]

In [None]:
# Probability-mass fractions enclosed by the drawn KDE contours.
_CONTOUR_MASS_LEVELS = (0.685, 0.955, 0.997)


def _credible_contour_levels(Z, levels):
    """Return, in ascending order, the density thresholds of Z whose
    super-level sets contain the given fractions of the summed density."""
    Z_sorted = np.sort(Z.flatten())[::-1]
    cumsum = np.cumsum(Z_sorted)
    cumsum /= cumsum[-1]  # normalize to [0,1]
    return sorted(Z_sorted[np.searchsorted(cumsum, level)] for level in levels)


def _draw_kde_contours(ax, x_samples, y_samples, grid, what, cell, **contour_kwargs):
    """Overlay Gaussian-KDE contours of (x_samples, y_samples) on `ax`.

    `grid` is the (X, Y) meshgrid the KDE is evaluated on. A failure (in the
    KDE or in the contour call) is reported on stdout and otherwise ignored,
    so one bad panel does not abort the whole figure.
    """
    X, Y = grid
    try:
        kde = gaussian_kde(np.vstack([x_samples, y_samples]))
        Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
        thresholds = _credible_contour_levels(Z, _CONTOUR_MASS_LEVELS)
        ax.contour(X, Y, Z, levels=thresholds, **contour_kwargs)
    except Exception as e:
        print(f"Contour failed for {what} at ({cell[0]},{cell[1]}): {e}")


def make_corner_comparison_plot_paper(data1, data2, NF_model_path, bounds,
                                      labels=None,
                                      density_labels=None,
                                      bins=50, figsize=(12, 12),
                                      save_path=None,
                                      kl_1d=None,
                                      kl_2d=None,
                                      mean_kl=None,
                                      title=None):
    """
    Create a corner plot comparing two sample sets and samples drawn from a
    normalizing flow, with fixed axis bounds.

    Layout: the diagonal shows 1D histograms of all three sets; the lower
    triangle shows data1 (2D histogram and contours) with data2 contours;
    the upper triangle shows the flow samples (2D histogram and contours)
    with data2 contours.

    Parameters:
    - data1: (N, D) numpy array
    - data2: (M, D) numpy array
    - NF_model_path: path to the NF model file; N samples are drawn from it
    - bounds: (D, 2) array of [min, max] for each dimension
    - labels: list of length D with parameter names (axis labels are
      omitted when None)
    - density_labels: list of length 3 with legend labels for data1, the NF
      samples and data2, in that order (legend is omitted when None)
    - bins: number of histogram bins; also the KDE grid size per axis
    - figsize: figure size
    - save_path: path to save the plot (nothing is written when None)
    - kl_1d: optional sequence of length D with 1D KL divergences
    - kl_2d: optional container indexable as kl_2d[i, j] for j < i
      (2D array or dict keyed by (i, j)) with 2D KL divergences
    - mean_kl: optional mean KL divergence, printed above the panels
    - title: optional title for the plot
    """
    sampler = NFSampler(NF_model_path)
    model_samples = jax.device_get(sampler.sample(size=data1.shape[0]))

    # With no legend labels the histograms are simply drawn unlabeled.
    hist_labels = density_labels if density_labels is not None else (None, None, None)

    D = data1.shape[1]
    # squeeze=False keeps `axes` two-dimensional even when D == 1.
    fig, axes = plt.subplots(D, D, figsize=figsize, sharex='col', sharey='row',
                             squeeze=False)

    for i in range(D):
        for j in range(D):
            ax = axes[i, j]

            if i == j:
                # Diagonal: 1D histograms, drawn on a fresh axes placed over
                # the grid cell so they do not take part in the row-wise
                # shared y axis of the grid.
                ax = fig.add_subplot(D, D, i * D + j + 1)
                ax.hist(data1[:, i], bins=bins, density=True, alpha=0.5,
                        color='tab:green', label=hist_labels[0], range=bounds[i])
                ax.hist(model_samples[:, i], bins=bins, density=True,
                        color='tab:orange', histtype='step', linewidth=2,
                        label=hist_labels[1], range=bounds[i], linestyle=':')
                ax.hist(data2[:, i], bins=bins, density=True,
                        color='tab:purple', histtype='step', linewidth=2,
                        label=hist_labels[2], range=bounds[i], linestyle='--')
                ax.set_yticks([])
                ax.set_yticklabels([])
                if i != D-1:
                    ax.set_xticks([])
                    ax.set_xticklabels([])
                if i == D-1:
                    ax.tick_params(axis='x', rotation=45, labelsize=10)

                ax.set_xlim(bounds[i])
                ax.set_ylim(0, None)

                if kl_1d is not None:
                    ax.annotate(rf"$D_{{KL}}={kl_1d[i]:.3f}$", xy=(0.95, 0.95),
                                xycoords='axes fraction',
                                ha='right', va='top', fontsize=10)
            else:
                # Off-diagonal panels share one KDE evaluation grid.
                grid = np.meshgrid(
                    np.linspace(bounds[j][0], bounds[j][1], bins),
                    np.linspace(bounds[i][0], bounds[i][1], bins),
                )

                if j < i:
                    # Lower triangle: data1 as 2D histogram + solid contours
                    ax.hist2d(data1[:, j], data1[:, i], bins=bins,
                              range=[bounds[j], bounds[i]], cmap='Greens', alpha=0.5)
                    _draw_kde_contours(ax, data1[:, j], data1[:, i], grid,
                                       "training data", (i, j),
                                       colors='green', linewidths=1.2)
                else:
                    # Upper triangle: NF samples as 2D histogram + dotted contours
                    ax.hist2d(model_samples[:, j], model_samples[:, i], bins=bins,
                              range=[bounds[j], bounds[i]], cmap='Oranges', alpha=0.5)
                    _draw_kde_contours(ax, model_samples[:, j], model_samples[:, i], grid,
                                       "training data", (i, j),
                                       colors='orange', linewidths=1.5,
                                       linestyles='dotted')

                # Both triangles: data2 as dashed contours
                _draw_kde_contours(ax, data2[:, j], data2[:, i], grid,
                                   "model data", (i, j),
                                   colors='purple', linewidths=1.2,
                                   linestyles='dashed')

                if j < i:
                    ax.set_xlim(bounds[j])
                    ax.set_ylim(bounds[i])

                    if kl_2d is not None and not np.isnan(kl_2d[i, j]):
                        ax.annotate(rf"$D_{{KL}}={kl_2d[i, j]:.3f}$", xy=(0.95, 0.95),
                                    xycoords='axes fraction',
                                    ha='right', va='top', fontsize=10)

            # Labeling
            if labels is not None:
                if i == D - 1:
                    ax.set_xlabel(labels[j], fontsize=14)
                if j == 0:
                    ax.set_ylabel(labels[i], fontsize=14)

    # Remove space between subplots
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.subplots_adjust(left=0.07, right=0.95, top=0.95, bottom=0.07)

    # Remove ticks and labels for the first and last plots
    axes[0, 0].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
    axes[D - 1, D - 1].tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)

    # Rotate x-axis tick labels
    for ax in axes[-1]:
        ax.tick_params(axis='x', rotation=45, labelsize=14)

    # Rotate y-axis tick labels
    for ax in axes[:, 0]:
        ax.tick_params(axis='y', rotation=45, labelsize=14)

    if density_labels is not None:
        fig.legend(
            labels=density_labels,
            loc='center left',
            fontsize=12,
            frameon=False,
            bbox_to_anchor=(0.4, 0.96),
            ncol=3
        )
    if mean_kl is not None:
        # Mean KL divergence, placed to the left of the legend
        fig.text(0.37, 0.955, rf"$\langle D_{{KL}} \rangle = {mean_kl:.3f}$",
                 ha='right', va='bottom', fontsize=12)
    if title is not None:
        # Title text, placed in the top-left corner of the figure
        fig.text(0.15, 0.955, title,
                 ha='center', va='bottom', fontsize=12)

    if save_path is not None:
        fig.savefig(save_path)
    plt.close(fig)

In [None]:
def wrapper_corner_plot(path1, path2, path3, dens_labels, title, filename,
                        bins=75, annotate_kl=False):
    """
    Wrapper function to read data, compute KL divergences and create the
    corner plot.

    The KL divergences are computed between the chains in path1 (p) and
    path2 (q). The number printed on the figure is the average of the mean
    1D and the mean 2D divergence.

    Parameters:
    - path1: path to first data file (pickled chain)
    - path2: path to second data file (pickled chain)
    - path3: path to the NF model file passed on to the plotting routine
    - dens_labels: labels for the density plots
    - title: title text for the figure
    - filename: name of the output file
    - bins: number of bins used for the KL histograms and for the plot
    - annotate_kl: when True, the individual 1D/2D KL values are also
      written into the panels; when False (default) only the mean is shown
    """
    data1 = read_pkl_file_chain_pocoMC(path1)
    data2 = read_pkl_file_chain_pocoMC(path2)

    kl_1d, mean_kl_1d = compute_1d_kl(
        p_samples=data1['chain'],
        q_samples=data2['chain'],
        bounds=bounds,
        bins=bins
    )
    kl_2d, mean_kl_2d = compute_2d_kl_lower(
        p_samples=data1['chain'],
        q_samples=data2['chain'],
        bounds=bounds,
        bins=bins
    )
    # Overall mean KL divergence
    overall_mean_kl = (mean_kl_1d + mean_kl_2d) / 2.

    # The plotting routine looks up 2D values as kl_2d[i, j], which a dict
    # keyed by (i, j) tuples supports.
    panel_kl_1d = kl_1d if annotate_kl else None
    panel_kl_2d = dict(kl_2d) if annotate_kl else None

    make_corner_comparison_plot_paper(
        data1=data1['chain'],
        data2=data2['chain'],
        NF_model_path=path3,
        bounds=bounds,
        labels=param_labels,
        density_labels=dens_labels,
        bins=bins,
        save_path=filename,
        kl_1d=panel_kl_1d,
        kl_2d=panel_kl_2d,
        mean_kl=overall_mean_kl,
        title=title
    )

In [None]:
# Inputs for one comparison figure; all paths are relative to the current
# working directory.
# path1 / path2: pickled chains read by read_pkl_file_chain_pocoMC
# (plotted as data1 and data2 respectively).
path1 = "./mcmc_PCGP_full_vanilla_flat_prior/chain.pkl"
path2 = "./CLUSTER/start_eP/mcmc_PCGP_eP_prior_vanilla/chain.pkl"
# path3: stored normalizing-flow model that NFSampler draws samples from.
path3 = "./eP_nf_bs5000_L6_lr1e-03.pkl"
# Legend entries in plotting order: data1, NF samples, data2.
dens_labels = ["Posterior: joint", "Posterior: stage 1", "Posterior: stage 2"]
title = r"Start: $\gamma+\mathrm{p}$"
# Output file written by the plotting routine.
filename = "./corner_plot_comparison_flat_vs_eP_prior.pdf"
wrapper_corner_plot(path1, path2, path3, dens_labels, title, filename)