In [None]:
"""
Code for generating and running toy problems with pocomc sampler
"""

##################################################################################
# 1. PACKAGES
##################################################################################

# diagnostics
import os
import json
import re
import argparse
import numpy as np
import matplotlib.pyplot as plt
import corner
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

# Preconditioned Monte Carlo
import numpy as np
import pocomc as pc
from scipy.stats import uniform
import numpy as np
import jax
import jax.numpy as jnp
from scipy.stats import norm




import jax.numpy as jnp
import jax
from jax.scipy.special import logsumexp
from jax.scipy.linalg import solve_triangular


# my classes
from likelihood import *
from gaussian_mixture import *





SUPPORTED_EXPERIMENTS = ["gaussian"]
##################################################################################
# 2. ARGUMENTS
##################################################################################
# argparse collects and validates every user-supplied setting for a run.
# NOTE: --experiment-type accepts more names than SUPPORTED_EXPERIMENTS lists;
# the experiment runner rejects the unsupported ones with a ValueError.
parser = argparse.ArgumentParser(description="Run experiment with specified parameters.")
parser.add_argument(
    "--experiment-type",
    choices=["gaussian", "dualmoon", "rosenbrock"],
    required=True,
    help="Which experiment to run."
)
parser.add_argument(
    "--n-dims",
    type=int,
    required=True,
    help="Number of dimensions."
)
parser.add_argument(
    "--outdir",
    type=str,
    required=True,
    help="The output directory, where things will be stored"
)

# Everything below here are hyperparameters for the Gaussian experiment.
parser.add_argument(
    "--nr-of-samples",
    type=int,
    default=10000,
    help="Number of samples to be generated"
)
parser.add_argument(
    "--nr-of-components",
    type=int,
    default=2,
    help="Number of components to be generated"
)
parser.add_argument(
    "--width-mean",
    type=float,
    default=10.0,
    help="The width of mean"
)
parser.add_argument(
    "--width-cov",
    type=float,
    default=3.0,
    help="The width of cov"
)
parser.add_argument(
    "--weights-of-components",
    nargs="+",
    type=float,
    default=None,
    help="Mixture weights (--weights-of-components 0.3 0.7). If omitted, uses equal weights."
)

# Everything below here are hyperparameters for sampler
parser.add_argument(
    "--prior-low",
    type=float,
    default=-20.0,
    help="Prior lower bound."
)
parser.add_argument(
    "--prior-high",
    type=float,
    default=20.0,
    help="Prior upper bound."
)
parser.add_argument(
    "--n-effective",
    type=int,
    default=4096,
    help="Effective number"
)
parser.add_argument(
    "--n-active",
    type=int,
    default=2048,
    help="Active number"
)
parser.add_argument(
    "--n-prior",
    type=int,
    default=50000,
    help="Number of prior samples"
)










##################################################################################
# 3. EXPERIMENT RUNNER
##################################################################################
class SequentialMCExperimentRunner:
    """
    Base class storing everything shared between different run experiments
    """
    def __init__(self, args):
        """
        Validate the run settings, create a fresh output directory and build
        the target distribution.

        Args:
            args: parsed argparse namespace. Its attribute dict becomes
                self.params, so the "outdir" entry is rewritten in place to
                the newly created results directory.

        Raises:
            ValueError: if the experiment type is not in SUPPORTED_EXPERIMENTS.
        """
        # Process the argparse args into params:
        self.params = vars(args)

        # Check the experiment type first, so that a rejected run does not
        # leave an empty results directory behind.
        if self.params["experiment_type"] not in SUPPORTED_EXPERIMENTS:
            raise ValueError(
                f"Experiment type {self.params['experiment_type']} is not supported. "
                f"Supported types are: {SUPPORTED_EXPERIMENTS}"
            )

        # Automatically create a unique output directory: results_1, results_2, ...
        base_results_dir = self.params["outdir"]
        unique_outdir = self.get_next_available_outdir(base_results_dir)
        print(f"Using output directory: {unique_outdir}")
        os.makedirs(unique_outdir, exist_ok=False)
        self.params["outdir"] = unique_outdir

        # Show the parameters to the screen/log file
        print("Passed parameters:")
        for key, value in self.params.items():
            print(f"{key}: {value}")

        # Specify the desired target function based on the experiment type
        if self.params["experiment_type"] == "gaussian":
            print("Setting the target function to a Gaussian mixture distribution.")

            # Fixed seed for the ground-truth generator
            np.random.seed(900)

            true_samples, means, covariances, weights = GaussianMixtureGenerator.generate_gaussian_mixture(
                n_dim=self.params["n_dims"],
                n_gaussians=self.params["nr_of_components"],
                n_samples=self.params["nr_of_samples"],
                width_mean=self.params["width_mean"],
                width_cov=self.params["width_cov"],
                weights=self.params["weights_of_components"],
            )

            # Store true samples for diagnostics later on
            self.true_samples = true_samples

            # Convert to JAX arrays
            self.mcmc_means   = jnp.stack(means, axis=0)         # (K, D)
            self.mcmc_covs    = jnp.stack(covariances, axis=0)   # (K, D, D)
            self.mcmc_weights = jnp.asarray(weights)             # (K,)

            # Define Likelihood
            self.likelihood = GaussianMixtureLikelihood(
                means=self.mcmc_means,
                covs=self.mcmc_covs,
                weights=self.mcmc_weights,
            )

            self.target_fn = self.target_normal

        

    def get_next_available_outdir(self, base_dir: str, prefix: str = "results") -> str:
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        existing = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
        matches = [re.match(rf"{prefix}_(\d+)", name) for name in existing]
        numbers = [int(m.group(1)) for m in matches if m]
        next_number = max(numbers, default=0) + 1
        return os.path.join(base_dir, f"{prefix}_{next_number}")



    def target_normal(self, x, data):
        # x can be shape (D,) or (..., D); GaussianMixtureLikelihood.log_prob supports both
        return self.likelihood.log_prob(x)
    

    @staticmethod
    def make_auto_bounds_inflated(means, covs, inflate=9.0, nsig=12.0, pad=1e-6,
                                    prior_low=None, prior_high=None):
        """
        Per-dimension uniform bounds [low[d], high[d]] that 
        contain all mass of a Gaussian mixture:
            - Bounds per component are wide enough to cover every single component;
            - Created by means and cov and not 
        It is achieved by:
            - find smallest and largest component means in dim k:
            - for each component k, take the marginal sd in dim k 
            - set bounds
        """
        means = np.asarray(means, dtype=float)           # (K, D)
        covs  = np.asarray(covs,  dtype=float) * float(inflate)   # make variance bigger

        # find smallest and largest component means in dim k
        mu_min = means.min(axis=0)                       # (D,)
        mu_max = means.max(axis=0)                       # (D,)

        # for each component k, take the marginal sd in dim k 
        std_max = np.sqrt(np.stack([np.diag(C) for C in covs], axis=0)).max(axis=0)  # (D,)

        # set bounds
        low  = mu_min - nsig * std_max - pad
        high = mu_max + nsig * std_max + pad

        # if bounds are provided, make bigger bounds, if necessary
        if prior_low  is not None: low  = np.minimum(low,  float(prior_low))
        if prior_high is not None: high = np.maximum(high, float(prior_high))
        return low, high
    

    #-----------------------------------------------------------------------------
    # 3.1. SAMPLER 
    #-----------------------------------------------------------------------------
    def run_experiment(self):
        """
        Run the pocoMC sampler on the configured target and collect its output.

        A uniform prior is built per dimension with make_auto_bounds_inflated
        from the mixture means/covariances and the prior_low / prior_high
        parameters. The likelihood is self.target_normal wrapped so that it
        returns a float64 NumPy array.

        Stores on self: prior, sampler, samples, logl, logp, logZ, logZerr
        and results.

        Returns:
            dict with keys "samples", "logl", "logp", "logZ", "logZerr",
            "sampler_results" and "params".
        """
        dim = int(self.params["n_dims"])
        # "data" placeholder to preserve the target_fn(x, data) signature pattern
        data = {}

        # Prior: independent uniform per dimension, wide enough to cover the mixture
        means = np.asarray(self.mcmc_means)
        covs  = np.asarray(self.mcmc_covs)

        low, high = self.make_auto_bounds_inflated(
            means, covs,
            inflate=9.0,   # or smaller like 4.0 if you want tighter
            nsig=12.0,     # conservative
            prior_low=float(self.params["prior_low"]),
            prior_high=float(self.params["prior_high"]),
        )

        self.prior = pc.Prior([
            uniform(loc=low[d], scale=high[d] - low[d]) for d in range(dim)
        ])

        # Vectorized log-likelihood wrapper
        # - pocoMC passes x as np.ndarray with shape (n_active, dim) when vectorize=True
        # - must return np.ndarray shape (n_active,)
        def log_likelihood(x, data):
            out = self.target_normal(jnp.asarray(x), data)
            out = np.asarray(out)                 # convert to numpy
            out = np.atleast_1d(out).astype(np.float64)
            return out

        # -------------------------------------------------------------------------
        # pocoMC sampler configuration (read from params with fallbacks)
        # None of the optional keys below are defined by the argument parser in
        # this file, so the fallbacks apply unless params is extended.
        # -------------------------------------------------------------------------
        # Early stopping and max steps (optional)
        pc_n_steps = self.params.get("pc_n_steps", None)
        pc_n_max_steps = self.params.get("pc_n_max_steps", None)
        if pc_n_steps is not None:
            pc_n_steps = int(pc_n_steps)
        if pc_n_max_steps is not None:
            pc_n_max_steps = int(pc_n_max_steps)

        # Run control
        n_total    = int(self.params.get("n_total", 4096))
        n_evidence = int(self.params.get("n_evidence", 4096))
        progress   = bool(self.params.get("progress", True))
        save_every = self.params.get("save_every", None)
        if save_every is not None:
            save_every = int(save_every)

        random_state = int(self.params.get("random_state", 0))

        # Save sampler states into the experiment outdir
        output_dir = self.params.get("outdir", None)

        sampler = pc.Sampler(
            prior=self.prior,
            likelihood=log_likelihood,
            likelihood_args=[data],
            vectorize=True,  # likelihood maps (n_active, n_dim) -> (n_active,)
            n_dim=dim,
            n_effective=int(self.params["n_effective"]),
            n_active=int(self.params["n_active"]),
            n_prior=int(self.params["n_prior"]),
            precondition=bool(self.params.get("precondition", True)),
            flow=self.params.get("flow", "nsf6"),
            sample=self.params.get("sample_kernel", "tpcn"),  # "tpcn" or "rwm"
            n_steps=pc_n_steps,
            n_max_steps=pc_n_max_steps,
            output_dir=output_dir,
            output_label="pmc",
            random_state=random_state,
        )
        # Keep the sampler so later diagnostics (plot_acceptance_rate) can use it
        self.sampler = sampler

        sampler.run(
            n_total=n_total,
            n_evidence=n_evidence,
            progress=progress,
            save_every=save_every,
        )

        # Posterior samples (resampled -> unweighted)
        samples, logl, logp = sampler.posterior(resample=True)

        # Evidence estimate and its error
        logZ, logZerr = sampler.evidence()

        # Store results
        self.samples = samples
        self.logl = logl
        self.logp = logp
        self.logZ = logZ
        self.logZerr = logZerr

        self.results = {
            "samples": samples,
            "logl": logl,
            "logp": logp,
            "logZ": logZ,
            "logZerr": logZerr,
            "sampler_results": sampler.results,
            "params": self.params,
        }

        print("Sampling complete!")
        print(f"samples.shape = {samples.shape}")
        print(f"logZ = {logZ} ± {logZerr}")

        return self.results
    



    #-----------------------------------------------------------------------------
    # 3.2. PLOT DIAGNOSTICS 
    #-----------------------------------------------------------------------------
    def get_true_and_mcmc_samples(self, discard=0, thin=1):
        dim = int(self.params["n_dims"])

        # True samples
        if not hasattr(self, "true_samples") or self.true_samples is None:
            raise ValueError("No true samples found. Ensure self.true_samples is set (gaussian experiment).")

        true_np = np.asarray(self.true_samples).reshape(-1, dim)

        # Sampler samples
        if hasattr(self, "samples") and self.samples is not None:
            # pocoMC: self.samples is already (N, dim)
            samp = np.asarray(self.samples).reshape(-1, dim)
            samp = samp[int(discard)::int(thin), :]
            mcmc_np = samp
        else:
            raise ValueError(
                "No sampler samples found. Run run_experiment() first. "
                "Expected either self.samples (pocoMC) or self.chains (BlackJAX)."
            )

        return true_np, mcmc_np




    def plot_true_vs_mcmc_corner(self, seed=2046):
        """
        Draw an overlaid corner plot of sampler output against the true samples.

        - sampler samples in black
        - true samples in red
        Saves: true_vs_mcmc_corner_plot.pdf in self.params["outdir"]

        The `seed` argument is accepted but not used by this method.
        """
        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        n_dims = int(self.params["n_dims"])
        axis_labels = [f"x{i}" for i in range(n_dims)]

        target_dir = self.params["outdir"]
        os.makedirs(target_dir, exist_ok=True)

        # Options common to both layers of the plot
        shared = {"show_titles": True, "labels": axis_labels}

        # Sampler output goes first and creates the figure
        fig = corner.corner(
            mcmc_np,
            color="black",
            hist_kwargs={"color": "black", "density": True},
            **shared,
        )

        # True samples are drawn onto the same figure
        corner.corner(
            true_np,
            fig=fig,
            color="red",
            hist_kwargs={"color": "red", "density": True},
            **shared,
        )

        # Legend
        legend_entries = [
            plt.Line2D([], [], color=colour, label=text)
            for colour, text in (("black", "pocomc"), ("red", "True Normal"))
        ]
        fig.legend(handles=legend_entries, loc="upper right")

        save_name = os.path.join(target_dir, "true_vs_mcmc_corner_plot.pdf")
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved overlay corner plot to {save_name}")



    def plot_acceptance_rate(self):
        """
        Plot the sampler's per-iteration acceptance rate.

        Reads the "accept" entry from self.sampler.results when a sampler is
        stored on self, otherwise from self.results["sampler_results"].
        Saves: acceptance_rate_curve.pdf in self.params["outdir"]
        """
        print("Plotting acceptance-rate diagnostic curve...")

        stored_sampler = getattr(self, "sampler", None)
        if stored_sampler is None:
            results = self.results["sampler_results"]
        else:
            results = stored_sampler.results

        accept = np.asarray(results["accept"]).reshape(-1)

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(accept)
        ax.set_xlabel("PMC/SMC Iteration")
        ax.set_ylabel("Acceptance rate")
        ax.set_title("pocoMC Acceptance Rate")

        save_name = os.path.join(self.params["outdir"], "acceptance_rate_curve.pdf")
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved to {save_name}")




    #-----------------------------------------------------------------------------
    # 3.3. SAMPLE STATISTICS
    #-----------------------------------------------------------------------------
    def save_samples_json(self):
        # output directory 
        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        # get samples once
        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        # save generated samples
        mcmc_path = os.path.join(outdir, "mcmc_samples.json")
        with open(mcmc_path, "w", encoding="utf-8") as f:
            json.dump(mcmc_np.tolist(), f)
        print(f"MCMC samples saved to {mcmc_path}")

        # save true samples
        true_path = os.path.join(outdir, "true_samples.json")
        with open(true_path, "w", encoding="utf-8") as f:
            json.dump(true_np.tolist(), f)
        print(f"True samples saved to {true_path}")




    def compute_and_save_sample_statistics(self):
        """
        Computes and saves per-dimension statistics for:
        - MCMC production samples
        - true samples
        Saves: sample_statistics.txt in self.params["outdir"]
        """

        # get samples 
        true_samples, mcmc_samples = self.get_true_and_mcmc_samples()

        # MCMC stats
        self.pm = mcmc_samples.mean(axis=0)
        self.pv = mcmc_samples.var(axis=0)
        self.ps = mcmc_samples.std(axis=0)

        # True stats
        self.qm = true_samples.mean(axis=0)
        self.qv = true_samples.var(axis=0)
        self.qs = true_samples.std(axis=0)

        # store arrays 
        self.mcmc_samples = mcmc_samples
        self.true_samples_np = true_samples

        np.set_printoptions(precision=4, suppress=True)

        stats_str = (
            "pm (mean of MCMC samples):\n" + str(self.pm) +
            "\n\npv (variance of MCMC samples):\n" + str(self.pv) +
            "\n\nps (std dev of MCMC samples):\n" + str(self.ps) +
            "\n\nqm (mean of true samples):\n" + str(self.qm) +
            "\n\nqv (variance of true samples):\n" + str(self.qv) +
            "\n\nqs (std dev of true samples):\n" + str(self.qs) + "\n"
        )

        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        stats_path = os.path.join(outdir, "sample_statistics.txt")
        with open(stats_path, "w", encoding="utf-8") as f:
            f.write(stats_str)

        print(f"Sample statistics saved to {stats_path}")



    #-----------------------------------------------------------------------------
    # 3.4. KL DIVERGENCE
    #-----------------------------------------------------------------------------
    import numpy as np, warnings, os
    from typing import Tuple


    @staticmethod
    def gau_kl(pm: np.ndarray, pv: np.ndarray,
               qm: np.ndarray, qv: np.ndarray) -> float:
        """
        Kullback-Liebler divergence from Gaussian pm,pv to Gaussian qm,qv.
        Also computes KL divergence from a single Gaussian pm,pv to a set
         of Gaussians qm,qv.
        Diagonal covariances are assumed.  Divergence is expressed in nats.
        """
        if (len(qm.shape) == 2):
            axis = 1
        else:
            axis = 0
        # Determinants of diagonal covariances pv, qv
        dpv = pv.prod()
        dqv = qv.prod(axis)
        # Inverse of diagonal covariance qv
        iqv = 1. / qv
        # Difference between means pm, qm
        diff = qm - pm
        return (0.5 * (
            np.log(dqv / dpv)                 # log |\Sigma_q| / |\Sigma_p|
            + (iqv * pv).sum(axis)            # + tr(\Sigma_q^{-1} * \Sigma_p)
            + (diff * iqv * diff).sum(axis)   # + (\mu_q-\mu_p)^T\Sigma_q^{-1}(\mu_q-\mu_p)
            - len(pm)                         # - N
        ))
    

    def kl_metrics(
        self,
        outdir: str | None = None,
        filename: str = "kl_metrics.txt",
    ) -> None:
        import os
        import numpy as np

        # define outdir
        outdir = (
            outdir
            or (getattr(self, "params", {}) or {}).get("outdir", None)
            or getattr(self, "outdir", None)
        )
        if outdir is None:
            raise ValueError("No output directory specified (pass outdir=... or set params['outdir']).")
        os.makedirs(outdir, exist_ok=True)

        
        true_np, mcmc_np = self.get_true_and_mcmc_samples() 

        # Parametric Gaussian stats (diagonal covariance assumed)
        pm = mcmc_np.mean(axis=0)
        pv = mcmc_np.var(axis=0)
        qm = true_np.mean(axis=0)
        qv = true_np.var(axis=0)

        kl_val = self.gau_kl(pm, pv, qm, qv)  # scalar for 1D qm/qv

        out_path = os.path.join(outdir, filename)
        with open(out_path, "w", encoding="utf-8") as f:
            if np.isscalar(kl_val):
                f.write(f"Parametric KL (Gaussian): {float(kl_val):.8f}\n")
            else:
                kl_arr = np.asarray(kl_val).ravel()
                f.write("Parametric KL (Gaussian):\n")
                for i, v in enumerate(kl_arr):
                    f.write(f"  [{i}] {float(v):.8f}\n")

        print(f"KL metrics saved to {out_path}")


In [None]:
import sys

# Simulate a command line for the notebook: main() below calls
# parser.parse_args(), which reads its options from sys.argv.
sys.argv = [
    # First entry is the program-name slot; argparse does not treat it as an option.
    "notebook",
    "--experiment-type", "gaussian",
    # NOTE(review): the directory name says "2d" while --n-dims below is 8 -- confirm intended.
    "--outdir", "./runs/gaussian_2d",

    # Everything below here are hyperparameters for the gaussians.
    "--n-dims", "8",
    "--nr-of-samples", "10000",
    "--nr-of-components", "4",
    "--width-mean", "10.0",
    "--width-cov", "1.0",
    "--weights-of-components", "0.25", "0.25", "0.25", "0.25",

    # Everything below here are hyperparameters for the samplers.
    "--prior-low", "-25.0",
    "--prior-high", "25.0",
    "--n-effective", "16000",
    "--n-active", "8000",
    "--n-prior", "150000"
]


def main():
    """
    Parse the command-line arguments, build the experiment runner, then run
    sampling followed by every diagnostic/output step in order.
    """
    runner = SequentialMCExperimentRunner(parser.parse_args())

    # Sampling must come first; the remaining stages consume its results
    pipeline = (
        runner.run_experiment,
        runner.plot_true_vs_mcmc_corner,
        runner.plot_acceptance_rate,
        runner.save_samples_json,
        runner.compute_and_save_sample_statistics,
        runner.kl_metrics,
    )
    for stage in pipeline:
        stage()


if __name__ == "__main__":
    main()

Using output directory: ./runs/gaussian_2d\results_13
Passed parameters:
experiment_type: gaussian
n_dims: 10
outdir: ./runs/gaussian_2d\results_13
nr_of_samples: 10000
nr_of_components: 4
width_mean: 10.0
width_cov: 1.0
weights_of_components: [0.25, 0.25, 0.25, 0.25]
prior_low: -25.0
prior_high: 25.0
n_effective: 16000
n_active: 8000
n_prior: 150000
Setting the target function to a standard Gaussian distribution.
[-1.9527 -5.8704  7.5673 -5.8053 -0.1505 -0.7373 -3.4366  5.8232  1.8345
 -1.4392]
[-1.7255  1.3081  2.4974 -0.9036 -6.6159  3.2771  6.3532 -2.7347  7.1564
  3.3161]
[-1.6282 -0.0014  3.6579 -2.1062 -0.2691  4.459   9.6636  8.9419 -2.9114
 -5.4834]
[ 1.3373 -3.7496  6.6619  6.7686  7.7773 -0.3607 -8.9235 -7.398   7.5108
 -6.5496]


Iter: 52it [1:07:27, 77.84s/it, beta=1, calls=1740096, ESS=18044, logZ=-51.4, logP=-68, acc=0.82, steps=5, eff=1]          


Sampling complete!
samples.shape = (44561, 10)
logZ = -51.59464312615643 ± 0.008917035592940308
Saved overlay corner plot to ./runs/gaussian_2d\results_13\true_vs_mcmc_corner_plot.pdf
Plotting acceptance-rate diagnostic curve...
Saved to ./runs/gaussian_2d\results_13\acceptance_rate_curve.pdf
MCMC samples saved to ./runs/gaussian_2d\results_13\mcmc_samples.json
True samples saved to ./runs/gaussian_2d\results_13\true_samples.json
Sample statistics saved to ./runs/gaussian_2d\results_13\sample_statistics.txt
KL metrics saved to ./runs/gaussian_2d\results_13\kl_metrics.txt
