In [None]:
"""
Code for generating and running toy problems with blackjax mala mcmc sampler
"""

##################################################################################
# 1. PACKAGES
##################################################################################

# diagnostics
import os
import re
import argparse
import numpy as np
import matplotlib.pyplot as plt
import corner
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

# blackjax 
import blackjax
import jax
import jax.numpy as jnp

# my classes
from likelihood import *
from gaussian_mixture import *


SUPPORTED_EXPERIMENTS = ["gaussian"]
##################################################################################
# 2. ARGUMENTS
##################################################################################
### The argparse is used to store and process any user input we want to pass on
parser = argparse.ArgumentParser(description="Run experiment with specified parameters.")

# General experiment settings (all required).
parser.add_argument(
    "--experiment-type",
    choices=["gaussian", "dualmoon", "rosenbrock"],
    required=True,
    help="Which experiment to run."
)
parser.add_argument(
    "--n-dims",
    type=int,
    required=True,
    help="Number of dimensions."
)
parser.add_argument(
    "--outdir",
    type=str,
    required=True,
    help="The output directory, where things will be stored"
)

# Hyperparameters for the BlackJAX MALA sampler.
parser.add_argument(
    "--mala-step-size",
    type=float,
    default=1e-1,
    help="Step size for the MALA proposal (local sampler)."
)
parser.add_argument(
    "--n-chains",
    type=int,
    default=20,
    help="Number of Markov chains to process in parallel."
)
parser.add_argument(
    "--n-steps",
    type=int,
    default=5000,
    # The previous text was a copy of the --n-chains help.
    help="Number of MALA steps to run per chain."
)

parser.add_argument(
    "--show-initial-positions",
    action="store_true",
    help="Show initial chain positions."
)

# Hyperparameters for the Gaussian (mixture) experiment.
parser.add_argument(
    "--nr-of-samples",
    type=int,
    default=10000,
    help="Number of samples to be generated"
)
parser.add_argument(
    "--nr-of-components",
    type=int,
    default=2,
    help="Number of components to be generated"
)
parser.add_argument(
    "--width-mean",
    type=float,
    default=10.0,
    help="The width of mean"
)
parser.add_argument(
    "--width-cov",
    type=float,
    default=3.0,
    help="The width of cov"
)
parser.add_argument(
    "--weights-of-components",
    nargs="+",
    type=float,
    default=None,
    help="Mixture weights (--weights-of-components 0.3 0.7). If omitted, uses equal weights."
)






##################################################################################
# 3. EXPERIMENT RUNNER
##################################################################################
class BlackjaxExperimentRunner:
    """
    Base class storing everything shared between different run experiments
    """
    def __init__(self, args):
        """
        Validate the run configuration, reserve an output directory and build the target.

        Parameters
        ----------
        args : argparse.Namespace
            Parsed command-line arguments. The keys read here are
            ``experiment_type``, ``outdir``, ``n_dims`` and, for the gaussian
            experiment, ``nr_of_components``, ``nr_of_samples``, ``width_mean``,
            ``width_cov`` and ``weights_of_components``.

        Raises
        ------
        ValueError
            If ``experiment_type`` is not listed in ``SUPPORTED_EXPERIMENTS``.
            This is checked before anything is created on disk.
        """
        # vars() returns the namespace's own dict, so the "outdir" rewrite below
        # is also visible through `args.outdir`.
        self.params = vars(args)
        experiment_type = self.params["experiment_type"]

        # Reject unsupported experiments first, so that a bad request does not
        # leave an empty results_N directory behind.
        if experiment_type not in SUPPORTED_EXPERIMENTS:
            raise ValueError(
                f"Experiment type {self.params['experiment_type']} is not supported. "
                f"Supported types are: {SUPPORTED_EXPERIMENTS}"
            )

        # Automatically create a unique output directory: results_1, results_2, ...
        base_results_dir = self.params["outdir"]
        unique_outdir = self.get_next_available_outdir(base_results_dir)
        print(f"Using output directory: {unique_outdir}")
        # exist_ok=False: fail loudly rather than write into another run's directory.
        os.makedirs(unique_outdir, exist_ok=False)
        self.params["outdir"] = unique_outdir

        # Show the parameters to the screen/log file
        print("Passed parameters:")
        for key, value in self.params.items():
            print(f"{key}: {value}")

        # Specify the desired target function based on the experiment type
        if experiment_type == "gaussian":
            print("Setting the target function to a standard Gaussian distribution.")

            # Fixed NumPy seed set before generating the mixture.
            # NOTE(review): presumably the generator draws from NumPy's global RNG - confirm.
            np.random.seed(900)

            true_samples, means, covariances, weights = GaussianMixtureGenerator.generate_gaussian_mixture(
                n_dim=self.params["n_dims"],
                n_gaussians=self.params["nr_of_components"],
                n_samples=self.params["nr_of_samples"],
                width_mean=self.params["width_mean"],
                width_cov=self.params["width_cov"],
                weights=self.params["weights_of_components"],
            )

            # Store true samples for diagnostics
            self.true_samples = true_samples

            # Convert to JAX arrays
            mcmc_means = jnp.stack(means, axis=0)          # (K, D)
            mcmc_covs = jnp.stack(covariances, axis=0)     # (K, D, D)
            mcmc_weights = jnp.asarray(weights)            # (K,)

            # Define Likelihood
            self.likelihood = GaussianMixtureLikelihood(
                means=mcmc_means,
                covs=mcmc_covs,
                weights=mcmc_weights,
            )

            self.target_fn = self.target_normal

        # The two branches below are only reachable once the corresponding names
        # are added to SUPPORTED_EXPERIMENTS; the target methods they reference
        # must exist on the instance by then.
        elif experiment_type == "dualmoon":
            self.target_fn = self.target_dual_moon
            print(f"Setting the target function to a dual moon distribution.")
        elif experiment_type == "rosenbrock":
            self.target_fn = self.target_rosenbrock
            print(f"Setting the target function to a Rosenbrock distribution.")
            



    def get_next_available_outdir(self, base_dir: str, prefix: str = "results") -> str:
        """
        Return the path of the next unused ``<prefix>_<N>`` entry inside ``base_dir``.

        ``base_dir`` is created if it does not exist. The returned path itself is
        NOT created; that is left to the caller.

        Parameters
        ----------
        base_dir : str
            Directory that holds the numbered run directories.
        prefix : str
            Name stem of the numbered entries.

        Returns
        -------
        str
            ``os.path.join(base_dir, f"{prefix}_{N}")`` where N is one more than the
            highest number already present (1 when there is none).
        """
        # exist_ok=True replaces the racy "check, then create" pair.
        os.makedirs(base_dir, exist_ok=True)

        # Only names that are exactly "<prefix>_<digits>" count, so "results_12_backup"
        # is ignored. The prefix is escaped so regex metacharacters in it are literal.
        numbered = re.compile(rf"{re.escape(prefix)}_(\d+)")

        # Plain files are counted as well as directories: a file named "results_3"
        # would make a later os.makedirs on that path fail just the same.
        taken = [
            int(found.group(1))
            for found in map(numbered.fullmatch, os.listdir(base_dir))
            if found
        ]
        next_number = max(taken, default=0) + 1
        return os.path.join(base_dir, f"{prefix}_{next_number}")



    def target_normal(self, x, data):
        # x can be shape (D,) or (..., D); GaussianMixtureLikelihood.log_prob supports both
        return self.likelihood.log_prob(x)



    # cover all components 
    def init_chains_gmm_cover_all(self, key, n_chains, means, covs, cov_scale=1.0):
        """
        Draw one starting point per chain from the mixture components.

        Components are handed out round-robin and then shuffled, so every
        component seeds at least one chain whenever ``n_chains >= K``. With fewer
        chains than components, ``n_chains`` distinct components are picked at random.

        means: (K, D)
        covs:  (K, D, D)
        cov_scale: factor applied to each component covariance before sampling
        returns: (n_chains, D)
        """
        n_components, _ = means.shape

        # Both cases consume exactly one key for the shuffle.
        key, shuffle_key = jax.random.split(key)
        if n_chains < n_components:
            # Cannot cover everything: take the head of a random ordering.
            assignment = jax.random.permutation(shuffle_key, n_components)[:n_chains]
        else:
            full_rounds, leftover = divmod(n_chains, n_components)
            component_range = jnp.arange(n_components)
            round_robin = jnp.concatenate([
                jnp.repeat(component_range, full_rounds),
                component_range[:leftover],
            ])
            assignment = round_robin[jax.random.permutation(shuffle_key, n_chains)]

        # Independent key for every chain's draw.
        _, draw_key = jax.random.split(key)
        per_chain_keys = jax.random.split(draw_key, n_chains)

        def sample_start(chain_key, component):
            return jax.random.multivariate_normal(
                chain_key, means[component], covs[component] * cov_scale
            )

        return jax.vmap(sample_start)(per_chain_keys, assignment)



    def nearest_mode_counts(self, x, means):
        """
        Assign each sample to its closest mean (squared Euclidean distance).

        x: (N, D) samples, means: (K, D) component means.
        Returns (counts, ids): counts has length K and holds the number of samples
        per mean, ids has length N and holds the index of the closest mean.
        """
        offsets = x[:, None, :] - means[None, :, :]       # (N, K, D)
        sq_dist = jnp.sum(offsets ** 2, axis=-1)          # (N, K)
        closest = jnp.argmin(sq_dist, axis=1)             # (N,)
        per_mode = jnp.bincount(closest, length=means.shape[0])
        return per_mode, closest



    #-----------------------------------------------------------------------------
    # 3.1. SAMPLER 
    #-----------------------------------------------------------------------------
    def run_experiment(self, seed: int = 42):
        """
        Run the BlackJAX MALA sampler for the chosen experiment.

        All chains are advanced in parallel (``jax.vmap`` over chains,
        ``jax.lax.scan`` over steps).

        Parameters
        ----------
        seed : int
            Seed for the JAX PRNG key. The default reproduces the previously
            hard-coded value.

        Returns
        -------
        dict
            ``initial_position`` (n_chains, dim), ``chains`` (n_chains, n_steps, dim),
            ``acc_rates`` (n_chains, n_steps) and ``params``. The same dict is
            stored on ``self.results``; ``chains`` and ``acc_rates`` are also
            stored as attributes of the same name.
        """
        dim = int(self.params["n_dims"])
        n_chains = int(self.params["n_chains"])
        step_size = float(self.params["mala_step_size"])
        num_steps = int(self.params["n_steps"])

        rng_key = jax.random.PRNGKey(seed)
        # The split stays three-way (third key unused) so that the random stream,
        # and therefore the samples, match earlier runs with the same seed.
        rng_key, key_init, _ = jax.random.split(rng_key, 3)

        if hasattr(self, "likelihood"):
            # Start chains spread over the mixture components.
            initial_position = self.init_chains_gmm_cover_all(
                key_init,
                n_chains,
                self.likelihood.means,
                self.likelihood.covs,
                cov_scale=1.0,
            )
        else:
            # Without a mixture likelihood there are no components to cover; fall
            # back to standard-normal starts instead of failing on an unbound name.
            initial_position = jax.random.normal(key_init, shape=(n_chains, dim))

        if self.params["show_initial_positions"]:
            print("Initial chain positions were:")
            print(initial_position)

        # Passed through to the target function; empty for the current targets.
        data = {}

        def logdensity_fn(x):
            return self.target_fn(x, data)

        mala = blackjax.mala(logdensity_fn, step_size=step_size)
        step = jax.jit(mala.step)

        def run_one_chain(chain_key, init_pos):
            state = mala.init(init_pos)

            def one_step(state, key):
                state, info = step(key, state)
                # Record the position after the step and that step's acceptance rate.
                return state, (state.position, info.acceptance_rate)

            keys = jax.random.split(chain_key, num_steps)
            _, (positions, acc_rates) = jax.lax.scan(one_step, state, keys)
            return positions, acc_rates

        chain_keys = jax.random.split(rng_key, n_chains)
        positions, acc_rates = jax.vmap(run_one_chain)(chain_keys, initial_position)

        self.chains = positions          # (n_chains, n_steps, dim)
        self.acc_rates = acc_rates       # (n_chains, n_steps)

        self.results = {
            "initial_position": initial_position,
            "chains": positions,        # (n_chains, n_steps, dim)
            "acc_rates": acc_rates,     # (n_chains, n_steps)
            "params": self.params,
        }

        print("Sampling complete!")
        print("Mean acceptance rate per chain:", jnp.mean(acc_rates, axis=1))
        print("Overall mean acceptance rate:", jnp.mean(acc_rates))

        return self.results







    #-----------------------------------------------------------------------------
    # 3.2. PLOT DIAGNOSTICS 
    #-----------------------------------------------------------------------------

    def get_true_and_mcmc_samples(self, discard=0, thin=1):
        """
        Returns:
        true_np: (N_true, dim)
        mcmc_np: (N_mcmc, dim)

        Notes:
        - This runner uses BlackJAX directly, so samples live in self.chains
        with shape (n_chains, n_steps, dim).
        - discard: number of initial steps per chain to drop (burn-in)
        - thin: keep every `thin`-th sample
        """
        if not hasattr(self, "true_samples") or self.true_samples is None:
            raise ValueError("No true samples found. Ensure self.true_samples is set (gaussian experiment).")

        if not hasattr(self, "chains") or self.chains is None:
            raise ValueError("No MCMC chains found. Run run_experiment() before extracting samples.")

        dim = int(self.params["n_dims"])

        true_np = np.asarray(self.true_samples).reshape(-1, dim)

        chains = np.asarray(self.chains)  # (n_chains, n_steps, dim)
        chains = chains[:, int(discard)::int(thin), :]
        mcmc_np = chains.reshape(-1, dim)

        return true_np, mcmc_np
    


    def plot_true_vs_mcmc_corner(self, seed=2046):
        """
        Draw the MCMC samples and the true samples on one corner plot.

        - MCMC samples are drawn first, in green
        - true samples are overlaid on the same figure, in red
        Saves: true_vs_mcmc_corner_plot.pdf in self.params["outdir"]

        ``seed`` is accepted but not used by this method.
        """
        import os
        import matplotlib.pyplot as plt
        import corner

        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        axis_labels = [f"x{i}" for i in range(int(self.params["n_dims"]))]

        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        # (samples, colour) in drawing order; the second layer reuses the figure
        # returned by the first call.
        layers = ((mcmc_np, "green"), (true_np, "red"))
        fig = None
        for samples, colour in layers:
            fig = corner.corner(
                samples,
                fig=fig,
                color=colour,
                hist_kwargs={"color": colour, "density": True},
                show_titles=True,
                labels=axis_labels,
            )

        # Proxy artists: the figure has no labelled lines of its own for a legend.
        legend_entries = [
            plt.Line2D([], [], color="green", label="blackjax"),
            plt.Line2D([], [], color="red", label="True Normal"),
        ]
        fig.legend(handles=legend_entries, loc="upper right")

        save_name = os.path.join(outdir, "true_vs_mcmc_corner_plot.pdf")
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved overlay corner plot to {save_name}")



    def plot_acceptance_rate(self):
        """
        Plot the acceptance rate per step, averaged over chains.

        Saves: acceptance_rate_curve.pdf in self.params["outdir"]
        Raises ValueError when no acceptance-rate history is stored yet.
        """
        print("Plotting BlackJAX diagnostic curve (acceptance rate)...")
        history = getattr(self, "acc_rates", None)
        if history is None:
            raise ValueError("No acceptance-rate history found. Run run_experiment() first.")

        # (n_chains, n_steps) -> (n_steps,)
        per_step_mean = np.asarray(history).mean(axis=0)

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(per_step_mean)
        ax.set_xlabel("MCMC Step")
        ax.set_ylabel("Mean acceptance rate (across chains)")
        ax.set_title("BlackJAX MALA Acceptance Rate")

        save_name = os.path.join(self.params["outdir"], "acceptance_rate_curve.pdf")
        print(f"Saving diagnostic curve to {save_name}")
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)

        print("Acceptance rate curve plot saved.")


    #-----------------------------------------------------------------------------
    # 3.3. SAMPLE STATISTICS
    #-----------------------------------------------------------------------------
    def save_samples_json(self):
        import os, json

        # output directory 
        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        # get samples
        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        # save MCMC
        mcmc_path = os.path.join(outdir, "mcmc_samples.json")
        with open(mcmc_path, "w", encoding="utf-8") as f:
            json.dump(mcmc_np.tolist(), f)
        print(f"MCMC samples saved to {mcmc_path}")

        # save TRUE
        true_path = os.path.join(outdir, "true_samples.json")
        with open(true_path, "w", encoding="utf-8") as f:
            json.dump(true_np.tolist(), f)
        print(f"True samples saved to {true_path}")


    def compute_and_save_sample_statistics(self):
        """
        Computes and saves per-dimension statistics for:
        - MCMC production samples
        - true samples
        Saves: sample_statistics.txt in self.params["outdir"]
        """
        import os
        import numpy as np

        # get samples
        true_samples, mcmc_samples = self.get_true_and_mcmc_samples()

        # MCMC stats
        self.pm = mcmc_samples.mean(axis=0)
        self.pv = mcmc_samples.var(axis=0)
        self.ps = mcmc_samples.std(axis=0)

        # True stats
        self.qm = true_samples.mean(axis=0)
        self.qv = true_samples.var(axis=0)
        self.qs = true_samples.std(axis=0)

        # store arrays 
        self.mcmc_samples = mcmc_samples
        self.true_samples_np = true_samples

        np.set_printoptions(precision=4, suppress=True)

        stats_str = (
            "pm (mean of MCMC samples):\n" + str(self.pm) +
            "\n\npv (variance of MCMC samples):\n" + str(self.pv) +
            "\n\nps (std dev of MCMC samples):\n" + str(self.ps) +
            "\n\nqm (mean of true samples):\n" + str(self.qm) +
            "\n\nqv (variance of true samples):\n" + str(self.qv) +
            "\n\nqs (std dev of true samples):\n" + str(self.qs) + "\n"
        )

        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        stats_path = os.path.join(outdir, "sample_statistics.txt")
        with open(stats_path, "w", encoding="utf-8") as f:
            f.write(stats_str)

        print(f"Sample statistics saved to {stats_path}")



    #-----------------------------------------------------------------------------
    # 3.4. KL DIVERGENCE
    #-----------------------------------------------------------------------------
    import numpy as np, warnings, os
    from typing import Tuple


    @staticmethod
    def gau_kl(pm: np.ndarray, pv: np.ndarray,
               qm: np.ndarray, qv: np.ndarray) -> float:
        """
        Kullback-Liebler divergence from Gaussian pm,pv to Gaussian qm,qv.
        Also computes KL divergence from a single Gaussian pm,pv to a set
         of Gaussians qm,qv.
        Diagonal covariances are assumed.  Divergence is expressed in nats.
        """
        if (len(qm.shape) == 2):
            axis = 1
        else:
            axis = 0
        # Determinants of diagonal covariances pv, qv
        dpv = pv.prod()
        dqv = qv.prod(axis)
        # Inverse of diagonal covariance qv
        iqv = 1. / qv
        # Difference between means pm, qm
        diff = qm - pm
        return (0.5 * (
            np.log(dqv / dpv)                 # log |\Sigma_q| / |\Sigma_p|
            + (iqv * pv).sum(axis)            # + tr(\Sigma_q^{-1} * \Sigma_p)
            + (diff * iqv * diff).sum(axis)   # + (\mu_q-\mu_p)^T\Sigma_q^{-1}(\mu_q-\mu_p)
            - len(pm)                         # - N
        ))
    

    def kl_metrics(
        self,
        outdir: str | None = None,
        filename: str = "kl_metrics.txt",
    ) -> None:
        import os
        import numpy as np

        # Define outdir
        outdir = (
            outdir
            or (getattr(self, "params", {}) or {}).get("outdir", None)
            or getattr(self, "outdir", None)
        )
        if outdir is None:
            raise ValueError("No output directory specified (pass outdir=... or set params['outdir']).")
        os.makedirs(outdir, exist_ok=True)

        # Get samples
        true_np, mcmc_np = self.get_true_and_mcmc_samples() 

        # Parametric Gaussian stats (diagonal covariance assumed)
        pm = mcmc_np.mean(axis=0)
        pv = mcmc_np.var(axis=0)
        qm = true_np.mean(axis=0)
        qv = true_np.var(axis=0)

        kl_val = self.gau_kl(pm, pv, qm, qv)  # scalar for 1D qm/qv

        out_path = os.path.join(outdir, filename)
        with open(out_path, "w", encoding="utf-8") as f:
            if np.isscalar(kl_val):
                f.write(f"Parametric KL (Gaussian): {float(kl_val):.8f}\n")
            else:
                kl_arr = np.asarray(kl_val).ravel()
                f.write("Parametric KL (Gaussian):\n")
                for i, v in enumerate(kl_arr):
                    f.write(f"  [{i}] {float(v):.8f}\n")

        print(f"KL metrics saved to {out_path}")








In [6]:
import sys

sys.argv = [
    # Simulated command line for running inside a notebook: argparse reads
    # sys.argv[1:], so the first entry only stands in for the program name.
    "notebook",
    "--experiment-type", "gaussian",
    # NOTE(review): the directory name says "5d" but --n-dims below is 4.
    "--outdir", "./runs/gaussian_5d",

    # Hyperparameters for the MALA sampler.
    "--n-chains", "10",
    "--n-steps", "1000",                 
    "--mala-step-size", "1e-1",
    "--show-initial-positions",

    # Hyperparameters for the Gaussian-mixture target.
    "--n-dims", "4",
    "--nr-of-samples", "10000",
    "--nr-of-components", "4",
    "--width-mean", "10.0",
    "--width-cov", "1.0",
    # Optional mixture weights; left disabled here.
    #"--weights-of-components", "0.4", "0.3", "0.3"
]


def main(argv=None):
    """
    Parse the arguments, run the sampler and write every diagnostic.

    Parameters
    ----------
    argv : list[str] | None
        Argument strings to parse (without the program name). With the default
        ``None`` argparse reads ``sys.argv[1:]``, as before; passing a list lets
        callers such as a notebook run an experiment without mutating ``sys.argv``.
    """
    args = parser.parse_args(argv)
    runner = BlackjaxExperimentRunner(args)

    # Sampling has to come first: every step below reads the stored chains.
    runner.run_experiment()
    runner.plot_true_vs_mcmc_corner()
    runner.plot_acceptance_rate()
    runner.save_samples_json()
    runner.compute_and_save_sample_statistics()
    runner.kl_metrics()


if __name__ == "__main__":
    main()

Using output directory: ./runs/gaussian_5d\results_9
Passed parameters:
experiment_type: gaussian
n_dims: 4
outdir: ./runs/gaussian_5d\results_9
mala_step_size: 0.1
n_chains: 10
n_steps: 1000
show_initial_positions: True
nr_of_samples: 10000
nr_of_components: 4
width_mean: 10.0
width_cov: 1.0
weights_of_components: None
Setting the target function to a standard Gaussian distribution.
[-1.9527 -5.8704  7.5673 -5.8053]
[-1.7255  1.3081  2.4974 -0.9036]
[-1.6282 -0.0014  3.6579 -2.1062]
[ 1.3373 -3.7496  6.6619  6.7686]
Initial chain positions were:
[[-5.5477 -0.8543  1.7817 -1.3673]
 [ 0.7932 -3.0263  8.1518  5.2368]
 [-0.3866 -7.0235  6.7723 -5.8014]
 [ 0.4926 -4.5503  6.4392  7.4438]
 [-4.2156 -2.3204  8.3174 -7.5876]
 [-1.7456 -1.1393  2.6385 -0.5381]
 [-3.264   1.1644  4.5745 -0.838 ]
 [-1.6756 -0.8838  3.7719 -0.9709]
 [-1.8493  1.8194  3.3371 -0.956 ]
 [-1.4142 -7.7683  7.4747 -5.1494]]
Sampling complete!
Mean acceptance rate per chain: [0.586  0.7683 0.3144 0.7795 0.3057 0.7322 0.