In [1]:
from __future__ import annotations

from typing import Callable, Mapping, Tuple, Any, Optional, Dict

import jax
import jax.numpy as jnp

from scaler_jax import inverse_jax, forward_jax, apply_boundary_conditions_x_jax


Array = jax.Array


def _flow_u_to_theta(flow, u: Array, condition: Optional[Array] = None) -> Tuple[Array, Array]:
    """
        * mapping u into theta using FlowJAX bijection
        * return log|det du/dtheta|.

    FlowJAX:
        * transform_and_log_det(u) returns (theta, log|det dtheta/du|)
        * get log|det du/dtheta|.
    """
    theta, fwd_logdet = flow.bijection.transform_and_log_det(u, condition)
    return theta, -fwd_logdet


def _flow_theta_to_u(flow, theta: Array, condition: Optional[Array] = None) -> Tuple[Array, Array]:
    """
        * mapping theta into u using FlowJAX bijection
        * return log|det du/dtheta| directly.
    """
    u, inv_logdet = flow.bijection.inverse_and_log_det(theta, condition)
    return u, inv_logdet


def preconditioned_pcn_jax(
    key: Array,
    *,
    # --- current state (all arrays; no None) ---
    u: Array,                 # (N, D)
    x: Array,                 # (N, D)
    logdetj: Array,           # (N,)
    logl: Array,              # (N,)
    logp: Array,              # (N,)
    logdetj_flow: Array,      # (N,) ; replaced below by the value recomputed from `u`
    blobs: Array,             # (N, B...) ; use shape (N, 0) if no blobs

    beta: Array,              # scalar

    # --- functions ---
    loglike_fn: Callable[[Array], Tuple[Array, Array]],
    logprior_fn: Callable[[Array], Array],
    flow: Any,                # FlowJAX Transformed-like object with .bijection
    scaler_cfg: Mapping[str, Array],
    scaler_masks: Mapping[str, Array],

    # --- geometry (Student-t) ---
    geom_mu: Array,           # (D,)
    geom_cov: Array,          # (D, D)
    geom_nu: Array,           # scalar

    # --- options ---
    n_max: int,
    n_steps: int,
    proposal_scale: Array,    # scalar
    condition: Optional[Array] = None,
) -> Dict[str, Array]:
    """
    Doubly Preconditioned Crank-Nicolson (PCN), JAX version.

    Runs a ``jax.lax.while_loop``. Every iteration proposes a pCN move for
    each walker in the flow's theta-space around a running mean ``mu`` (noise
    drawn with covariance ``geom_cov`` and a per-walker random scale), maps
    the proposal back through the flow inverse and the scaler, evaluates
    prior and likelihood for finite proposals, and accepts or rejects each
    walker independently. The step size ``sigma`` and the mean ``mu`` are
    adapted on every iteration.

    The loop ends after ``n_max`` iterations, or earlier once the ensemble
    mean of ``logl + logp`` has failed to improve for at least
    ``n_steps * ((2.38 / sqrt(D)) / sigma) ** 2`` consecutive iterations.

    Requirements:
      - logprior_fn(x_i): x_i has shape (D,), returns scalar
      - loglike_fn(x_i): x_i has shape (D,), returns (scalar_loglike, blob_i)
        where blob_i has fixed shape (B...) matching blobs[0].

    Info:
      - All randomness is derived from `key`; the advanced key is returned
        under "key" so the caller can continue the stream.
      - FlowJAX bijections are vmapped across walkers.
      - The `logdetj_flow` argument is not used as given: it is overwritten
        by the log-determinant recomputed from `u` before the loop starts.

    Returns:
      Dict with the final state ("key", "u", "x", "logdetj", "logdetj_flow",
      "logl", "logp", "blobs") and diagnostics: "accept" (mean acceptance
      probability of the last iteration), "steps" (iterations performed),
      "calls" (number of likelihood evaluations counted), and both
      "efficiency" and "proposal_scale", which hold the final sigma.
    """
    u = jnp.asarray(u)
    x = jnp.asarray(x)
    logdetj = jnp.asarray(logdetj)
    logl = jnp.asarray(logl)
    logp = jnp.asarray(logp)
    logdetj_flow = jnp.asarray(logdetj_flow)
    blobs = jnp.asarray(blobs)
    beta = jnp.asarray(beta)
    proposal_scale = jnp.asarray(proposal_scale)

    geom_mu = jnp.asarray(geom_mu)
    geom_cov = jnp.asarray(geom_cov)
    geom_nu = jnp.asarray(geom_nu)

    n_walkers, n_dim = u.shape

    # Precision matrix for the quadratic form; Cholesky factor for drawing
    # noise with covariance geom_cov.
    inv_cov = jnp.linalg.inv(geom_cov)
    chol_cov = jnp.linalg.cholesky(geom_cov)

    # --- Flow: u -> theta (batched via vmap) ---
    def _u2t_single(ui: Array) -> Tuple[Array, Array]:
        return _flow_u_to_theta(flow, ui, condition)

    theta, logdetj_flow0 = jax.vmap(_u2t_single, in_axes=0, out_axes=(0, 0))(u)

    # initial mean and counter and objective
    mu = geom_mu
    # Initial step size is capped at 0.99 (sqrt(1 - sigma^2) is taken in the proposal).
    sigma0 = jnp.minimum(proposal_scale, jnp.asarray(0.99, dtype=u.dtype))
    logp2_best = jnp.mean(logl + logp)
    cnt0 = jnp.asarray(0, dtype=jnp.int32)
    i0 = jnp.asarray(0, dtype=jnp.int32)
    calls0 = jnp.asarray(0, dtype=jnp.int32)
    accept0 = jnp.asarray(0.0, dtype=u.dtype)
    done0 = jnp.asarray(False)

    # Replace the caller-supplied flow log-determinant with the one just computed from u.
    logdetj_flow = logdetj_flow0

    # Placeholder blob returned for walkers whose likelihood is skipped.
    blob_template = jnp.zeros_like(blobs[0])

    # helpers: Student-t form
    # Row-wise quadratic form diff_i^T inv_cov diff_i.
    def _quad(diff_: Array) -> Array:
        tmp = diff_ @ inv_cov
        return jnp.sum(tmp * diff_, axis=1)

    # --- skip invalid walkers ---
    # Log-prior for a walker flagged ok, -inf otherwise.
    def _prior_or_neginf(xi: Array, ok: Array) -> Array:
        return jax.lax.cond(
            ok,
            lambda z: logprior_fn(z),
            lambda z: jnp.asarray(-jnp.inf, dtype=xi.dtype),
            xi,
        )

    # (log-likelihood, blob) for a walker flagged ok, (-inf, blob_template) otherwise.
    def _like_or_neginf(xi: Array, ok: Array) -> Tuple[Array, Array]:
        def _do(z: Array) -> Tuple[Array, Array]:
            ll, bb = loglike_fn(z)
            return ll, bb

        def _skip(z: Array) -> Tuple[Array, Array]:
            return jnp.asarray(-jnp.inf, dtype=xi.dtype), blob_template

        return jax.lax.cond(ok, _do, _skip, xi)

    # Loop carry layout:
    # (key, u, x, theta, logdetj, logdetj_flow, logl, logp, blobs, mu, sigma, logp2_best, cnt, i, calls, accept, done)
    carry0 = (
        key, u, x, theta, logdetj, logdetj_flow, logl, logp, blobs,
        mu, sigma0, logp2_best, cnt0, i0, calls0, accept0, done0
    )

    # Upper bound applied to sigma after each adaptation step: min(2.38 / sqrt(D), 0.99).
    max_sigma_cap = jnp.minimum(jnp.asarray(2.38, dtype=u.dtype) / jnp.sqrt(jnp.asarray(n_dim, dtype=u.dtype)),
                                jnp.asarray(0.99, dtype=u.dtype))

    # Continue while the iteration budget is not exhausted and the stop flag is unset.
    def cond_fn(carry):
        (_, _, _, _, _, _, _, _, _, _, _, _, _, i, _, _, done) = carry
        return (i < jnp.asarray(n_max, dtype=i.dtype)) & (~done)

    # One pCN iteration: propose, evaluate, accept/reject, adapt sigma and mu, update stop flag.
    def body_fn(carry):
        (key, u, x, theta, logdetj, logdetj_flow, logl, logp, blobs,
         mu, sigma, logp2_best, cnt, i, calls, accept, done) = carry

        i1 = i + jnp.asarray(1, dtype=i.dtype)

        # Fresh sub-keys for the gamma, normal and uniform draws of this iteration.
        key, k_gamma, k_norm, k_unif = jax.random.split(key, 4)

        diff = theta - mu
        quad = _quad(diff)

        # Per-walker random scale s = (nu + quad) / (2 z), with z ~ Gamma((D + nu) / 2).
        a = (jnp.asarray(n_dim, dtype=u.dtype) + geom_nu) / jnp.asarray(2.0, dtype=u.dtype)
        z = jax.random.gamma(k_gamma, a, shape=(n_walkers,))  # unit scale
        s = (geom_nu + quad) / (jnp.asarray(2.0, dtype=u.dtype) * z)

        # Standard normal draws coloured with the Cholesky factor of geom_cov.
        eps = jax.random.normal(k_norm, shape=(n_walkers, n_dim), dtype=u.dtype)
        noise = eps @ chol_cov.T

        # pCN proposal: shrink the offset from mu by sqrt(1 - sigma^2), add sigma-scaled noise.
        theta_prime = (
            mu
            + jnp.sqrt(jnp.asarray(1.0, dtype=u.dtype) - sigma * sigma) * diff
            + sigma * jnp.sqrt(s)[:, None] * noise
        )

        # --- Flow: theta into u (batched via vmap) ---
        def _t2u_single(ti: Array) -> Tuple[Array, Array]:
            return _flow_theta_to_u(flow, ti, condition)

        u_prime, logdetj_flow_prime = jax.vmap(_t2u_single, in_axes=0, out_axes=(0, 0))(theta_prime)

        # --- Scaler inverse: u into x, ---
        #TODO check boundary handling here
        x_prime, logdetj_prime = inverse_jax(u_prime, scaler_cfg, scaler_masks)

        # Round trip x -> boundary conditions -> u -> x so that u', x' and the
        # scaler log-determinant all refer to the boundary-adjusted point.
        x_prime_bc = apply_boundary_conditions_x_jax(x_prime, dict(scaler_cfg))
        u_prime_bc = forward_jax(x_prime_bc, scaler_cfg, scaler_masks)
        x_prime, logdetj_prime = inverse_jax(u_prime_bc, scaler_cfg, scaler_masks)

        u_prime = u_prime_bc

        # Walkers with a finite scaler log-determinant and finite coordinates.
        finite0 = jnp.isfinite(logdetj_prime) & jnp.all(jnp.isfinite(x_prime), axis=1)

        logp_prime = jax.vmap(_prior_or_neginf, in_axes=(0, 0), out_axes=0)(x_prime, finite0)
        finite1 = finite0 & jnp.isfinite(logp_prime)

        # The likelihood is only requested where the prior is finite.
        logl_prime, blobs_prime = jax.vmap(_like_or_neginf, in_axes=(0, 0), out_axes=(0, 0))(x_prime, finite1)

        # Count one likelihood call per walker that passed both finiteness checks.
        calls = calls + jnp.sum(finite1.astype(jnp.int32))

        diff_prime = theta_prime - mu
        quad_prime = _quad(diff_prime)

        # Student-t log-density terms (up to a shared constant) at the proposal (A) and current point (B).
        coef = -(jnp.asarray(n_dim, dtype=u.dtype) + geom_nu) / jnp.asarray(2.0, dtype=u.dtype)
        A = coef * jnp.log1p(quad_prime / geom_nu)
        B = coef * jnp.log1p(quad / geom_nu)

        # Log acceptance ratio: tempered likelihood, prior, both Jacobian terms, and the A/B correction.
        log_alpha = (
            (logl_prime - logl) * beta
            + (logp_prime - logp)
            + (logdetj_prime - logdetj)
            + (logdetj_flow_prime - logdetj_flow)
            - A + B
        )

        # alpha = min(1, exp(log_alpha)); NaN values (e.g. from inf - inf) are treated as 0.
        alpha = jnp.exp(jnp.minimum(jnp.asarray(0.0, dtype=u.dtype), log_alpha))
        alpha = jnp.where(jnp.isnan(alpha), jnp.asarray(0.0, dtype=u.dtype), alpha)

        u_rand = jax.random.uniform(k_unif, shape=(n_walkers,), dtype=u.dtype)
        accept_mask = u_rand < alpha

        # accept / reject
        # TODO check this
        theta = jnp.where(accept_mask[:, None], theta_prime, theta)
        u = jnp.where(accept_mask[:, None], u_prime, u)
        x = jnp.where(accept_mask[:, None], x_prime, x)

        logdetj = jnp.where(accept_mask, logdetj_prime, logdetj)
        logdetj_flow = jnp.where(accept_mask, logdetj_flow_prime, logdetj_flow)
        logl = jnp.where(accept_mask, logl_prime, logl)
        logp = jnp.where(accept_mask, logp_prime, logp)
        # Reshape the mask to (N, 1, ..., 1) so it broadcasts over the blob dimensions.
        blobs = jnp.where(accept_mask.reshape((n_walkers,) + (1,) * (blobs.ndim - 1)), blobs_prime, blobs)

        # Mean acceptance probability over walkers (not the realised accept fraction).
        accept = jnp.mean(alpha)

        # TODO check
        # Step-size adaptation with gain (i1 + 1)^-0.75 toward a mean acceptance of 0.234.
        step = jnp.asarray(1.0, dtype=u.dtype) / jnp.power(jnp.asarray(i1 + 1, dtype=u.dtype), jnp.asarray(0.75, dtype=u.dtype))
        sigma = sigma + step * (accept - jnp.asarray(0.234, dtype=u.dtype))
        sigma = jnp.abs(jnp.minimum(sigma, max_sigma_cap))

        # Running-mean update of mu toward the current ensemble mean in theta-space.
        mu_step = jnp.asarray(1.0, dtype=u.dtype) / jnp.asarray(i1 + 1, dtype=u.dtype)
        mu = mu + mu_step * (jnp.mean(theta, axis=0) - mu)

        # Track the best ensemble-mean log-posterior; cnt counts iterations since it last improved.
        logp2_new = jnp.mean(logl + logp)
        improved = logp2_new > logp2_best
        cnt = jnp.where(improved, jnp.asarray(0, dtype=cnt.dtype), cnt + jnp.asarray(1, dtype=cnt.dtype))
        logp2_best = jnp.where(improved, logp2_new, logp2_best)

        # Stop once cnt reaches n_steps * ((2.38 / sqrt(D)) / sigma)^2.
        thresh = jnp.asarray(n_steps, dtype=u.dtype) * jnp.power(
            (jnp.asarray(2.38, dtype=u.dtype) / jnp.sqrt(jnp.asarray(n_dim, dtype=u.dtype))) / sigma,
            jnp.asarray(2.0, dtype=u.dtype),
        )
        done = cnt.astype(u.dtype) >= thresh

        return (
            key, u, x, theta, logdetj, logdetj_flow, logl, logp, blobs,
            mu, sigma, logp2_best, cnt, i1, calls, accept, done
        )

    carry_f = jax.lax.while_loop(cond_fn, body_fn, carry0)

    (key, u, x, theta, logdetj, logdetj_flow, logl, logp, blobs,
     mu, sigma, logp2_best, cnt, i, calls, accept, done) = carry_f

    # "efficiency" and "proposal_scale" both report the final adapted sigma.
    return {
        "key": key,
        "u": u,
        "x": x,
        "logdetj": logdetj,
        "logdetj_flow": logdetj_flow,
        "logl": logl,
        "logp": logp,
        "blobs": blobs,
        "efficiency": sigma,
        "accept": accept,
        "steps": i,
        "calls": calls,
        "proposal_scale": sigma,
    }

---
# test
---

In [2]:
##################################################################################
# EXPERIMENT RUNNER (UPDATED FOR preconditioned_pcn_jax)
##################################################################################
import json
import matplotlib.pyplot as plt
import corner
import os
import re
import numpy as np

import jax
import jax.numpy as jnp

from likelihood import *
from gaussian_mixture import *
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)



# Experiment types accepted by pcn_ExperimentRunner; its __init__ raises
# ValueError for any "experiment_type" not listed here.
SUPPORTED_EXPERIMENTS = {"gaussian"}


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def mixture_mean_cov(means: jnp.ndarray, covs: jnp.ndarray, weights: jnp.ndarray, jitter: float = 1e-6):
    """
    Moment-match a Gaussian mixture with a single mean and covariance.

    Inputs:
      means:   (K, D) component means
      covs:    (K, D, D) component covariances
      weights: (K,) component weights (normalised internally)
      jitter:  value added to the diagonal of the resulting covariance

    Returns:
      mu:  (D,)   weighted mean of the component means
      cov: (D, D) weighted sum of (component covariance + outer product of
                  the component's offset from mu), plus jitter * I
    """
    norm_w = weights / jnp.sum(weights)

    # Overall mean: weighted average of the component means.
    mixture_mean = jnp.sum(norm_w[:, None] * means, axis=0)

    # Offsets of each component mean from the overall mean, and their
    # per-component outer products (K, D, D).
    centered = means - mixture_mean[None, :]
    between = centered[:, :, None] * centered[:, None, :]

    per_component = covs + between
    mixture_cov = jnp.sum(norm_w[:, None, None] * per_component, axis=0)

    # Diagonal loading keeps the matrix safely invertible downstream.
    n_dim = mixture_cov.shape[0]
    mixture_cov = mixture_cov + jitter * jnp.eye(n_dim, dtype=mixture_cov.dtype)
    return mixture_mean, mixture_cov


def make_uniform_box_logprior(low: jnp.ndarray, high: jnp.ndarray):
    """
    Build the log-density of a uniform prior on the box [low, high].

    The returned function maps a point ``x`` to 0 when every coordinate
    lies within the (inclusive) bounds and to -inf otherwise, using the
    dtype of ``x`` for the result.
    """
    lower = jnp.asarray(low)
    upper = jnp.asarray(high)

    def logprior_fn(x: jnp.ndarray) -> jnp.ndarray:
        above_lower = x >= lower
        below_upper = x <= upper
        inside = jnp.all(jnp.logical_and(above_lower, below_upper))

        zero = jnp.asarray(0.0, dtype=x.dtype)
        neg_inf = jnp.asarray(-jnp.inf, dtype=x.dtype)
        return jnp.where(inside, zero, neg_inf)

    return logprior_fn


class IdentityBijection:
    """No-op bijection: both directions return the input with a zero log-determinant."""

    @staticmethod
    def _zero_log_det(reference):
        # Scalar zero carrying the dtype of the transformed array.
        return jnp.asarray(0.0, dtype=reference.dtype)

    def transform_and_log_det(self, u, condition=None):
        """Return ``(u, 0)``; ``condition`` is accepted but ignored."""
        log_det = self._zero_log_det(u)
        return u, log_det

    def inverse_and_log_det(self, theta, condition=None):
        """Return ``(theta, 0)``; ``condition`` is accepted but ignored."""
        log_det = self._zero_log_det(theta)
        return theta, log_det


class IdentityFlow:
    """Use this only to validate wiring when you don't have a trained FlowJAX flow yet.

    Exposes a ``bijection`` attribute holding an ``IdentityBijection``, which
    is the only attribute the flow helpers in this file access.
    """
    def __init__(self) -> None:
        # Identity map with zero log-determinant in both directions.
        self.bijection = IdentityBijection()


##################################################################################
# Runner
##################################################################################
class pcn_ExperimentRunner:
    """
    Runner that can generate samples using preconditioned_pcn_jax.

    Key design decisions:
      - We run the kernel multiple times (n_outer), each time updating the full ensemble state.
      - We store x across outer iterations -> samples shape (n_outer, N_walkers, D).
        Your diagnostics can flatten via reshape(-1, D).
    """

    def __init__(self, args, *, flow=None, scaler_cfg=None, scaler_masks=None):
        """
        Create the runner, reserve a fresh output directory and set up the experiment.

        Args:
            args: namespace-like object; its attributes become ``self.params``.
            flow: optional flow object exposing ``.bijection``.
            scaler_cfg: optional scaler configuration mapping.
            scaler_masks: optional scaler masks mapping.

        Raises:
            ValueError: if ``args.experiment_type`` is not in SUPPORTED_EXPERIMENTS.
        """
        # vars(args) is the namespace's own __dict__. Work on a copy so that
        # storing the resolved "outdir" below does not rewrite args.outdir for
        # the caller (reusing the same args would otherwise nest result folders).
        self.params = dict(vars(args))

        # --- validate experiment (before anything is created on disk) ---
        if self.params["experiment_type"] not in SUPPORTED_EXPERIMENTS:
            raise ValueError(
                f"Experiment type {self.params['experiment_type']} is not supported. "
                f"Supported types are: {SUPPORTED_EXPERIMENTS}"
            )

        # --- unique outdir ---
        base_results_dir = self.params["outdir"]
        unique_outdir = self.get_next_available_outdir(base_results_dir)
        print(f"Using output directory: {unique_outdir}")
        # exist_ok=False: fail loudly rather than write into someone else's run.
        os.makedirs(unique_outdir, exist_ok=False)
        self.params["outdir"] = unique_outdir

        print("Passed parameters:")
        for k, v in self.params.items():
            print(f"{k}: {v}")

        # Attach (or later set) flow + scaler
        self.flow = flow
        self.scaler_cfg = scaler_cfg
        self.scaler_masks = scaler_masks

        # Setup experiment
        if self.params["experiment_type"] == "gaussian":
            self._setup_gaussian_experiment(args)

        # Placeholder for results
        self.samples = None
        self.accept_history = None
        self.sigma_history = None
        self.calls_history = None

    # -------------------------------------------------------------------------
    # Setup
    # -------------------------------------------------------------------------
    def _setup_gaussian_experiment(self, args):
        """
        Prepare the Gaussian-mixture target: reference samples, likelihood,
        uniform-box prior bounds and the Student-t geometry.

        Reads mixture settings from ``args`` and the remaining options from
        ``self.params``; stores everything as attributes on ``self``.
        """
        print("Setting the target function to a Gaussian mixture distribution.")

        # Seed NumPy's global RNG; presumably consumed by the generator below
        # so the mixture is reproducible (verify against the generator).
        np.random.seed(900)

        cfg = self.params
        n_dim = int(cfg["n_dims"])

        # Reference ("true") samples plus the mixture parameters.
        generated = GaussianMixtureGenerator.generate_gaussian_mixture(
            n_dim=n_dim,
            n_gaussians=args.nr_of_components,
            n_samples=args.nr_of_samples,
            width_mean=args.width_mean,
            width_cov=args.width_cov,
            weights=args.weights_of_components,
        )
        true_samples, means, covariances, weights = generated
        self.true_samples = true_samples

        # Mixture parameters as JAX arrays: (K, D), (K, D, D), (K,)
        self.mcmc_means = jnp.stack(means, axis=0)
        self.mcmc_covs = jnp.stack(covariances, axis=0)
        self.mcmc_weights = jnp.asarray(weights)

        self.likelihood = GaussianMixtureLikelihood(
            means=self.mcmc_means,
            covs=self.mcmc_covs,
            weights=self.mcmc_weights,
        )

        # Uniform-box prior bounds derived from the mixture components.
        bounds = self.make_auto_bounds_inflated(
            means=means,
            covs=covariances,
            inflate=float(cfg.get("prior_inflate", 9.0)),
            nsig=float(cfg.get("prior_nsig", 12.0)),
            pad=float(cfg.get("prior_pad", 1e-6)),
        )
        low_np, high_np = bounds
        self.prior_low = jnp.asarray(low_np)
        self.prior_high = jnp.asarray(high_np)

        # Student-t geometry in theta-space, defaulting to the mixture moments.
        geom_jitter = float(cfg.get("geom_jitter", 1e-6))
        self.geom_mu, self.geom_cov = mixture_mean_cov(
            self.mcmc_means, self.mcmc_covs, self.mcmc_weights, jitter=geom_jitter
        )
        self.geom_nu = jnp.asarray(cfg.get("geom_nu", 5.0), dtype=self.geom_mu.dtype)

        # Convenience alias for the target log-density.
        self.target_fn = self.target_normal

    # -------------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------------
    def attach_flow_and_scaler(self, *, flow, scaler_cfg, scaler_masks):
        """Call this if you cannot provide these objects in __init__."""
        self.flow = flow
        self.scaler_cfg = scaler_cfg
        self.scaler_masks = scaler_masks

    def run_experiment(self):
        sampler = self.params.get("sampler", "precond_pcn")

        if self.params["experiment_type"] == "gaussian" and sampler == "precond_pcn":
            self._run_preconditioned_pcn_gaussian()
            return

        raise ValueError(
            f"Unsupported combination experiment_type={self.params['experiment_type']} sampler={sampler}"
        )

    # -------------------------------------------------------------------------
    # Core run method (your algorithm)
    # -------------------------------------------------------------------------
    def _run_preconditioned_pcn_gaussian(self):
        """
        Sample the Gaussian-mixture target with repeated calls to preconditioned_pcn_jax.

        Initialises an ensemble of walkers from a standard normal in u-space,
        evaluates their prior and likelihood, then calls the kernel
        ``n_outer`` times, feeding each call's output state into the next.
        After every call the walker positions ``x`` and the kernel's
        diagnostics are recorded.

        Results are stored on ``self``:
          samples         (n_outer, N, D)
          accept_history  mean acceptance reported by each kernel call
          sigma_history   final proposal scale of each kernel call
          calls_history   likelihood-call count of each kernel call

        Raises:
            ValueError: if scaler_cfg or scaler_masks has not been provided.
        """
        # --- required objects ---
        if self.flow is None:
            # If you want a hard error instead, replace with raise ValueError(...)
            print("Warning: self.flow is None; using IdentityFlow() for wiring test.")
            self.flow = IdentityFlow()

        if self.scaler_cfg is None or self.scaler_masks is None:
            raise ValueError(
                "scaler_cfg / scaler_masks are required for inverse_jax/forward_jax. "
                "Attach them via attach_flow_and_scaler(...) or pass into __init__."
            )

        D = int(self.params["n_dims"])
        N = int(self.params.get("n_walkers", 2048))

        # Outer iterations: each call to preconditioned_pcn_jax adapts sigma/mu internally up to n_max
        n_outer = int(self.params.get("n_outer", 50))

        # Kernel parameters
        beta = jnp.asarray(self.params.get("beta", 1.0), dtype=jnp.float32)
        n_max = int(self.params.get("n_max", 2000))
        n_steps = int(self.params.get("n_steps", 100))
        proposal_scale = jnp.asarray(self.params.get("proposal_scale", 0.2), dtype=jnp.float32)

        seed = int(self.params.get("seed", 0))
        key = jax.random.PRNGKey(seed)

        # Prior / likelihood functions in required signatures
        logprior_fn = make_uniform_box_logprior(self.prior_low, self.prior_high)
        # Empty blob: this target attaches no per-sample metadata.
        blob0 = jnp.zeros((0,), dtype=jnp.float32)

        def loglike_fn(xi):
            ll = self.likelihood.log_prob(xi)  # scalar
            return ll, blob0

        # -----------------------------
        # Initialize ensemble state
        # -----------------------------
        key, k_init = jax.random.split(key, 2)
        u = jax.random.normal(k_init, shape=(N, D), dtype=jnp.float32)

        x, logdetj = inverse_jax(u, self.scaler_cfg, self.scaler_masks)

        # keep boundary-condition roundtrip consistent with your kernel
        x_bc = apply_boundary_conditions_x_jax(x, dict(self.scaler_cfg))
        u_bc = forward_jax(x_bc, self.scaler_cfg, self.scaler_masks)
        x, logdetj = inverse_jax(u_bc, self.scaler_cfg, self.scaler_masks)
        u = u_bc

        # Walkers with a finite scaler log-determinant and finite coordinates.
        finite0 = jnp.isfinite(logdetj) & jnp.all(jnp.isfinite(x), axis=1)

        # Log-prior for walkers flagged ok, -inf otherwise.
        def _prior_or_neginf(xi, ok):
            return jax.lax.cond(
                ok,
                lambda z: logprior_fn(z),
                lambda z: jnp.asarray(-jnp.inf, dtype=xi.dtype),
                xi,
            )

        logp = jax.vmap(_prior_or_neginf, in_axes=(0, 0), out_axes=0)(x, finite0)
        finite1 = finite0 & jnp.isfinite(logp)

        # (log-likelihood, blob) for walkers flagged ok, (-inf, blob0) otherwise.
        def _like_or_neginf(xi, ok):
            def _do(z):
                return loglike_fn(z)
            def _skip(z):
                return jnp.asarray(-jnp.inf, dtype=xi.dtype), blob0
            return jax.lax.cond(ok, _do, _skip, xi)

        logl, _ = jax.vmap(_like_or_neginf, in_axes=(0, 0), out_axes=(0, 0))(x, finite1)

        # Required by kernel interface; it recomputes logdetj_flow internally anyway
        logdetj_flow = jnp.zeros((N,), dtype=jnp.float32)
        blobs = jnp.zeros((N, 0), dtype=jnp.float32)

        # storage
        xs = []
        accept_hist = []
        sigma_hist = []
        calls_hist = []

        # -----------------------------
        # Outer loop: accumulate samples
        # -----------------------------
        for t in range(n_outer):
            # Every call starts from the same initial proposal_scale; only the
            # walker state and the PRNG key are carried between calls.
            out = preconditioned_pcn_jax(
                key,
                u=u,
                x=x,
                logdetj=logdetj,
                logl=logl,
                logp=logp,
                logdetj_flow=logdetj_flow,
                blobs=blobs,
                beta=beta,
                loglike_fn=loglike_fn,
                logprior_fn=logprior_fn,
                flow=self.flow,
                scaler_cfg=self.scaler_cfg,
                scaler_masks=self.scaler_masks,
                geom_mu=self.geom_mu,
                geom_cov=self.geom_cov,
                geom_nu=self.geom_nu,
                n_max=n_max,
                n_steps=n_steps,
                proposal_scale=proposal_scale,
                condition=None,
            )

            # update state
            key = out["key"]
            u = out["u"]
            x = out["x"]
            logdetj = out["logdetj"]
            logdetj_flow = out["logdetj_flow"]
            logl = out["logl"]
            logp = out["logp"]
            blobs = out["blobs"]

            # One snapshot of the whole ensemble per outer iteration.
            xs.append(x)
            accept_hist.append(out["accept"])
            sigma_hist.append(out["proposal_scale"])
            calls_hist.append(out["calls"])

            # Periodic progress report (every "print_every" outer iterations).
            if (t + 1) % int(self.params.get("print_every", 10)) == 0:
                acc = float(np.asarray(out["accept"]))
                sig = float(np.asarray(out["proposal_scale"]))
                calls = int(np.asarray(out["calls"]))
                steps = int(np.asarray(out["steps"]))
                print(f"[outer {t+1:>4d}/{n_outer}] accept={acc:.4f} sigma={sig:.4f} calls={calls} steps={steps}")

        # Store results
        self.samples = np.asarray(jnp.stack(xs, axis=0))  # (n_outer, N, D)
        self.accept_history = np.asarray(jnp.stack(accept_hist))
        self.sigma_history = np.asarray(jnp.stack(sigma_hist))
        self.calls_history = np.asarray(jnp.stack(calls_hist))

        # Convenience summary
        print(
            f"Done. samples shape={self.samples.shape} "
            f"mean_accept={self.accept_history.mean():.4f} "
            f"last_sigma={self.sigma_history[-1]:.4f}"
        )

    # -------------------------------------------------------------------------
    # Existing utilities / diagnostics
    # -------------------------------------------------------------------------
    def target_normal(self, x, data=None):
        return self.likelihood.log_prob(x)

    def get_next_available_outdir(self, base_dir: str, prefix: str = "results") -> str:
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        existing = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
        matches = [re.match(rf"{prefix}_(\d+)", name) for name in existing]
        numbers = [int(m.group(1)) for m in matches if m]
        next_number = max(numbers, default=0) + 1
        return os.path.join(base_dir, f"{prefix}_{next_number}")

    @staticmethod
    def make_auto_bounds_inflated(means, covs, inflate=9.0, nsig=12.0, pad=1e-6,
                                 prior_low=None, prior_high=None):
        means = np.asarray(means, dtype=float)                 # (K, D)
        covs = np.asarray(covs, dtype=float) * float(inflate)  # inflate variance

        mu_min = means.min(axis=0)                             # (D,)
        mu_max = means.max(axis=0)                             # (D,)

        std_max = np.sqrt(np.stack([np.diag(C) for C in covs], axis=0)).max(axis=0)  # (D,)

        low = mu_min - nsig * std_max - pad
        high = mu_max + nsig * std_max + pad

        if prior_low is not None:
            low = np.minimum(low, float(prior_low))
        if prior_high is not None:
            high = np.maximum(high, float(prior_high))
        return low, high




    def get_true_and_mcmc_samples(self, discard=0, thin=1):
        dim = int(self.params["n_dims"])

        if not hasattr(self, "true_samples") or self.true_samples is None:
            raise ValueError("No true samples found. Ensure self.true_samples is set (gaussian experiment).")

        true_np = np.asarray(self.true_samples).reshape(-1, dim)

        if hasattr(self, "samples") and self.samples is not None:
            samp = np.asarray(self.samples).reshape(-1, dim)  # works for (n_outer, N, D) too
            samp = samp[int(discard)::int(thin), :]
            mcmc_np = samp
        else:
            raise ValueError("No sampler samples found. Run run_experiment() first.")

        return true_np, mcmc_np
    


    def plot_true_vs_mcmc_corner(self, seed=2046):
        """
        Save an overlay corner plot of sampler draws (black) and true samples (red).

        The figure is written to ``true_vs_mcmc_corner_plot.pdf`` inside
        ``self.params["outdir"]``. ``seed`` is accepted but not used.
        """
        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        n_dim = int(self.params["n_dims"])
        labels = [f"x{i}" for i in range(n_dim)]

        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        def _draw(data, colour, **extra):
            # Shared corner settings; only the data, colour and target figure vary.
            return corner.corner(
                data,
                color=colour,
                hist_kwargs={"color": colour, "density": True},
                show_titles=True,
                labels=labels,
                **extra,
            )

        # Sampler draws first, then the true samples on the same figure.
        fig = _draw(mcmc_np, "black")
        _draw(true_np, "red", fig=fig)

        # Legend
        legend_entries = (("black", "pocomc"), ("red", "True Normal"))
        handles = [
            plt.Line2D([], [], color=colour, label=text)
            for colour, text in legend_entries
        ]
        fig.legend(handles=handles, loc="upper right")

        save_name = os.path.join(outdir, "true_vs_mcmc_corner_plot.pdf")
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved overlay corner plot to {save_name}")



    def plot_acceptance_rate(self):
        print("Plotting acceptance-rate diagnostic curve...")

        if self.accept_history is None:
            raise ValueError("No accept_history found. Run run_experiment() first.")

        accept = np.asarray(self.accept_history).reshape(-1)

        plt.figure(figsize=(6, 4))
        plt.plot(accept)
        plt.xlabel("Outer iteration")
        plt.ylabel("Mean acceptance (alpha)")
        plt.title("Preconditioned pCN Acceptance")
        save_name = os.path.join(self.params["outdir"], "acceptance_rate_curve.pdf")
        plt.savefig(save_name, bbox_inches="tight")
        plt.close()
        print(f"Saved to {save_name}")



    def plot_sigma(self):
        if self.sigma_history is None:
            raise ValueError("No sigma_history found. Run run_experiment() first.")

        sig = np.asarray(self.sigma_history).reshape(-1)
        plt.figure(figsize=(6, 4))
        plt.plot(sig)
        plt.xlabel("Outer iteration")
        plt.ylabel("proposal_scale (sigma)")
        plt.title("Preconditioned pCN Sigma Adaptation")
        save_name = os.path.join(self.params["outdir"], "sigma_curve.pdf")
        plt.savefig(save_name, bbox_inches="tight")
        plt.close()
        print(f"Saved to {save_name}")





    #-----------------------------------------------------------------------------
    # 3.3. SAMPLE STATISTICS
    #-----------------------------------------------------------------------------
    def save_samples_json(self):
        # output directory 
        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        # get samples once
        true_np, mcmc_np = self.get_true_and_mcmc_samples()

        # save generated samples
        mcmc_path = os.path.join(outdir, "mcmc_samples.json")
        with open(mcmc_path, "w", encoding="utf-8") as f:
            json.dump(mcmc_np.tolist(), f)
        print(f"MCMC samples saved to {mcmc_path}")

        # save true samples
        true_path = os.path.join(outdir, "true_samples.json")
        with open(true_path, "w", encoding="utf-8") as f:
            json.dump(true_np.tolist(), f)
        print(f"True samples saved to {true_path}")




    def compute_and_save_sample_statistics(self):
        """
        Computes and saves per-dimension statistics for:
        - MCMC production samples
        - true samples
        Saves: sample_statistics.txt in self.params["outdir"]
        """

        # get samples 
        true_samples, mcmc_samples = self.get_true_and_mcmc_samples()

        # MCMC stats
        self.pm = mcmc_samples.mean(axis=0)
        self.pv = mcmc_samples.var(axis=0)
        self.ps = mcmc_samples.std(axis=0)

        # True stats
        self.qm = true_samples.mean(axis=0)
        self.qv = true_samples.var(axis=0)
        self.qs = true_samples.std(axis=0)

        # store arrays 
        self.mcmc_samples = mcmc_samples
        self.true_samples_np = true_samples

        np.set_printoptions(precision=4, suppress=True)

        stats_str = (
            "pm (mean of MCMC samples):\n" + str(self.pm) +
            "\n\npv (variance of MCMC samples):\n" + str(self.pv) +
            "\n\nps (std dev of MCMC samples):\n" + str(self.ps) +
            "\n\nqm (mean of true samples):\n" + str(self.qm) +
            "\n\nqv (variance of true samples):\n" + str(self.qv) +
            "\n\nqs (std dev of true samples):\n" + str(self.qs) + "\n"
        )

        outdir = self.params["outdir"]
        os.makedirs(outdir, exist_ok=True)

        stats_path = os.path.join(outdir, "sample_statistics.txt")
        with open(stats_path, "w", encoding="utf-8") as f:
            f.write(stats_str)

        print(f"Sample statistics saved to {stats_path}")



    #-----------------------------------------------------------------------------
    # 3.4. KL DIVERGENCE
    #-----------------------------------------------------------------------------
    import numpy as np, warnings, os
    from typing import Tuple


    @staticmethod
    def gau_kl(pm: np.ndarray, pv: np.ndarray,
               qm: np.ndarray, qv: np.ndarray) -> float:
        """
        Kullback-Liebler divergence from Gaussian pm,pv to Gaussian qm,qv.
        Also computes KL divergence from a single Gaussian pm,pv to a set
         of Gaussians qm,qv.
        Diagonal covariances are assumed.  Divergence is expressed in nats.
        """
        if (len(qm.shape) == 2):
            axis = 1
        else:
            axis = 0
        # Determinants of diagonal covariances pv, qv
        dpv = pv.prod()
        dqv = qv.prod(axis)
        # Inverse of diagonal covariance qv
        iqv = 1. / qv
        # Difference between means pm, qm
        diff = qm - pm
        return (0.5 * (
            np.log(dqv / dpv)                 # log |\Sigma_q| / |\Sigma_p|
            + (iqv * pv).sum(axis)            # + tr(\Sigma_q^{-1} * \Sigma_p)
            + (diff * iqv * diff).sum(axis)   # + (\mu_q-\mu_p)^T\Sigma_q^{-1}(\mu_q-\mu_p)
            - len(pm)                         # - N
        ))
    

    def kl_metrics(
        self,
        outdir: str | None = None,
        filename: str = "kl_metrics.txt",
    ) -> None:
        import os
        import numpy as np

        # define outdir
        outdir = (
            outdir
            or (getattr(self, "params", {}) or {}).get("outdir", None)
            or getattr(self, "outdir", None)
        )
        if outdir is None:
            raise ValueError("No output directory specified (pass outdir=... or set params['outdir']).")
        os.makedirs(outdir, exist_ok=True)

        
        true_np, mcmc_np = self.get_true_and_mcmc_samples() 

        # Parametric Gaussian stats (diagonal covariance assumed)
        pm = mcmc_np.mean(axis=0)
        pv = mcmc_np.var(axis=0)
        qm = true_np.mean(axis=0)
        qv = true_np.var(axis=0)

        kl_val = self.gau_kl(pm, pv, qm, qv)  # scalar for 1D qm/qv

        out_path = os.path.join(outdir, filename)
        with open(out_path, "w", encoding="utf-8") as f:
            if np.isscalar(kl_val):
                f.write(f"Parametric KL (Gaussian): {float(kl_val):.8f}\n")
            else:
                kl_arr = np.asarray(kl_val).ravel()
                f.write("Parametric KL (Gaussian):\n")
                for i, v in enumerate(kl_arr):
                    f.write(f"  [{i}] {float(v):.8f}\n")

        print(f"KL metrics saved to {out_path}")

In [3]:
import jax.numpy as jnp

# Identity "scaler": u <-> x
def inverse_jax(u, scaler_cfg=None, scaler_masks=None):
    """Identity stand-in for the scaler inverse: return ``u`` unchanged.

    ``scaler_cfg`` and ``scaler_masks`` are accepted but never read.

    Returns:
        A pair ``(x, logdet)``: ``x`` is ``u`` converted to a JAX array and
        ``logdet`` is a zero vector with one entry per element of the leading
        axis of ``u``, in the same dtype.
    """
    del scaler_cfg, scaler_masks  # unused by the identity transform
    x = jnp.asarray(u)
    n_rows = x.shape[0]
    zero_logdet = jnp.zeros((n_rows,), dtype=x.dtype)
    return x, zero_logdet

def forward_jax(x, scaler_cfg=None, scaler_masks=None):
    """Identity stand-in for the scaler forward map: return ``x`` as a JAX array.

    ``scaler_cfg`` and ``scaler_masks`` are accepted but never read.
    """
    del scaler_cfg, scaler_masks  # unused by the identity transform
    as_array = jnp.asarray(x)
    return as_array

def apply_boundary_conditions_x_jax(x, cfg_dict=None):
    """No-op boundary handling: return ``x`` as a JAX array, values untouched.

    ``cfg_dict`` is accepted but never read.
    """
    del cfg_dict  # unused: no boundary conditions are applied here
    unchanged = jnp.asarray(x)
    return unchanged

# Objects handed to the experiment runner alongside the identity scaler above.
flow = IdentityFlow()
# Empty mappings suffice here: the identity scaler functions ignore both.
scaler_cfg, scaler_masks = {}, {}

---
# weak
---

In [None]:
from types import SimpleNamespace

# Run configuration for the "weak" setting; key order matches the original
# keyword order so that iteration over vars(args) is unchanged.
args = SimpleNamespace(**{
    "outdir": "./results",
    "experiment_type": "gaussian",

    # Gaussian-mixture target
    "n_dims": 5,
    "nr_of_components": 4,
    "nr_of_samples": 10000,
    "width_mean": 10.0,
    "width_cov": 1.0,
    "weights_of_components": None,

    # sampler settings
    "sampler": "precond_pcn",
    "n_walkers": 4000,  # start smaller for a quick test
    "n_outer": 100,
    "n_max": 9000,
    "n_steps": 300,
    "proposal_scale": 0.2,
    "beta": 1.0,
    "seed": 55,
    "print_every": 10,

    # geometry and prior-box settings
    "geom_nu": 1,
    "prior_inflate": 16.0,
    "prior_nsig": 18.0,
    "prior_pad": 1e-6,
    "geom_jitter": 1e-6,
})

In [5]:
def main():
    """Build the experiment runner, run it, then write plots and reports.

    Configuration is taken from the module-level ``args``, ``flow``,
    ``scaler_cfg`` and ``scaler_masks`` objects defined in earlier cells;
    nothing is parsed from the command line here.
    """
    runner = pcn_ExperimentRunner(
        args,
        flow=flow,
        scaler_cfg=scaler_cfg,
        scaler_masks=scaler_masks,
    )

    # Stages run strictly in this order; each is a no-argument runner method.
    stages = (
        "run_experiment",
        "plot_true_vs_mcmc_corner",
        "plot_acceptance_rate",
        "plot_sigma",
        "save_samples_json",
        "compute_and_save_sample_statistics",
        "kl_metrics",
    )
    for stage in stages:
        getattr(runner, stage)()


if __name__ == "__main__":
    main()

Using output directory: ./results\results_10
Passed parameters:
outdir: ./results\results_10
experiment_type: gaussian
n_dims: 5
nr_of_components: 4
nr_of_samples: 10000
width_mean: 10.0
width_cov: 1.0
weights_of_components: [0.25, 0.25, 0.25, 0.25]
sampler: precond_pcn
n_walkers: 4000
n_outer: 100
n_max: 9000
n_steps: 300
proposal_scale: 0.2
beta: 1.0
seed: 55
print_every: 10
geom_nu: 1
prior_inflate: 16.0
prior_nsig: 18.0
prior_pad: 1e-06
geom_jitter: 1e-06
Setting the target function to a Gaussian mixture distribution.
[-1.9527316 -5.8704066  7.5672913 -5.8052754 -0.1505208]
[-1.7254925   1.308074    2.4973822  -0.90360403 -6.615944  ]
[-1.6282320e+00 -1.3995171e-03  3.6579323e+00 -2.1062398e+00
 -2.6911259e-01]
[ 1.3372827 -3.7496185  6.6618586  6.768627   7.777319 ]
[outer   10/100] accept=0.2454 sigma=0.1493 calls=36000000 steps=9000
[outer   20/100] accept=0.2360 sigma=0.1487 calls=36000000 steps=9000
[outer   30/100] accept=0.2429 sigma=0.1471 calls=36000000 steps=9000
[outer  

---
# strong
---

In [None]:
from types import SimpleNamespace

# Run configuration for the "strong" setting; key order matches the original
# keyword order so that iteration over vars(args) is unchanged.
args = SimpleNamespace(**{
    "outdir": "./results",
    "experiment_type": "gaussian",

    # Gaussian-mixture generator parameters (names match the parser)
    "n_dims": 5,
    "nr_of_components": 4,
    "nr_of_samples": 10000,
    "width_mean": 10.0,
    "width_cov": 1.0,
    "weights_of_components": None,  # or a list such as [0.25, 0.25, 0.25, 0.25]

    "sampler": "precond_pcn",

    # exploration budget
    "n_walkers": 4096,
    "n_outer": 300,
    "n_max": 15000,
    "n_steps": 500,

    # proposal / acceptance
    "proposal_scale": 0.2,
    "beta": 1.0,

    # heavier tails in the Student-t mixture proposal
    "geom_nu": 2.5,

    # prior box: wide enough to avoid clipping the target
    "prior_inflate": 16.0,
    "prior_nsig": 18.0,
    "prior_pad": 1e-6,

    # geometry jitter
    "geom_jitter": 1e-6,

    "seed": 0,
    "print_every": 10,
})

In [5]:
def main():
    """Build the experiment runner, run it, then write plots and reports.

    Configuration is taken from the module-level ``args``, ``flow``,
    ``scaler_cfg`` and ``scaler_masks`` objects defined in earlier cells;
    nothing is parsed from the command line here.
    """
    runner = pcn_ExperimentRunner(
        args,
        flow=flow,
        scaler_cfg=scaler_cfg,
        scaler_masks=scaler_masks,
    )

    # Stages run strictly in this order; each is a no-argument runner method.
    stages = (
        "run_experiment",
        "plot_true_vs_mcmc_corner",
        "plot_acceptance_rate",
        "plot_sigma",
        "save_samples_json",
        "compute_and_save_sample_statistics",
        "kl_metrics",
    )
    for stage in stages:
        getattr(runner, stage)()


if __name__ == "__main__":
    main()

Using output directory: ./results\results_8
Passed parameters:
outdir: ./results\results_8
experiment_type: gaussian
n_dims: 5
nr_of_components: 4
nr_of_samples: 10000
width_mean: 10.0
width_cov: 1.0
weights_of_components: None
sampler: precond_pcn
n_walkers: 4096
n_outer: 300
n_max: 15000
n_steps: 500
proposal_scale: 0.45
beta: 1.0
geom_nu: 2.5
prior_inflate: 16.0
prior_nsig: 18.0
prior_pad: 1e-06
geom_jitter: 1e-06
seed: 0
print_every: 10
Setting the target function to a Gaussian mixture distribution.
[-1.9527316 -5.8704066  7.5672913 -5.8052754 -0.1505208]
[-1.7254925   1.308074    2.4973822  -0.90360403 -6.615944  ]
[-1.6282320e+00 -1.3995171e-03  3.6579323e+00 -2.1062398e+00
 -2.6911259e-01]
[ 1.3372827 -3.7496185  6.6618586  6.768627   7.777319 ]
[outer   10/300] accept=0.2232 sigma=0.1429 calls=61440000 steps=15000
[outer   20/300] accept=0.2282 sigma=0.1388 calls=61440000 steps=15000
[outer   30/300] accept=0.2274 sigma=0.1419 calls=61440000 steps=15000
[outer   40/300] accept=

KeyboardInterrupt: 

# Weak configuration: small 2-D quick test (512 walkers, 20 outer iterations)

from types import SimpleNamespace

# Small 2-D run configuration; key order matches the original keyword order
# so that iteration over vars(args) is unchanged.
args = SimpleNamespace(**{
    "outdir": "./results",
    "experiment_type": "gaussian",

    # Gaussian-mixture target
    "n_dims": 2,
    "nr_of_components": 4,
    "nr_of_samples": 20000,
    "width_mean": 5.0,
    "width_cov": 1.0,
    "weights_of_components": None,

    # sampler settings
    "sampler": "precond_pcn",
    "n_walkers": 512,  # start smaller for a quick test
    "n_outer": 20,
    "n_max": 500,
    "n_steps": 50,
    "proposal_scale": 0.2,
    "beta": 1.0,
    "seed": 0,
    "print_every": 5,

    # geometry and prior-box settings
    "geom_nu": 5.0,
    "prior_inflate": 9.0,
    "prior_nsig": 12.0,
    "prior_pad": 1e-6,
    "geom_jitter": 1e-6,
})