# In [None]:
##################################################################################
# EXPERIMENT RUNNER (UPDATED FOR preconditioned_pcn_jax)
##################################################################################

import os
import re
import numpy as np

import jax
import jax.numpy as jnp


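# -----------------------------------------------------------------------------
# Project-local dependencies; keep your existing imports for them.
# The runner below also calls inverse_jax, forward_jax,
# apply_boundary_conditions_x_jax and preconditioned_pcn_jax, so those names
# must be in scope before run_experiment() is called (see the commented-out
# imports below).
# -----------------------------------------------------------------------------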
import GaussianMixtureGenerator 
import GaussianMixtureLikelihood
# from your_project.transforms import inverse_jax, forward_jax, apply_boundary_conditions_x_jax
# from your_project.samplers import preconditioned_pcn_jax


# Experiment types accepted by pcn_ExperimentRunner; its constructor raises a
# ValueError for any "experiment_type" that is not listed here.
SUPPORTED_EXPERIMENTS = {"gaussian"}


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def mixture_mean_cov(means: jnp.ndarray, covs: jnp.ndarray, weights: jnp.ndarray, jitter: float = 1e-6):
    """
    Overall mean and covariance of a Gaussian mixture.

    Inputs:
      means:   (K, D) component means
      covs:    (K, D, D) component covariances
      weights: (K,) component weights; they are normalised to sum to one here
      jitter:  constant added to the diagonal of the returned covariance
    Returns:
      mu:  (D,)   weighted average of the component means
      cov: (D, D) weighted average of (component covariance + outer product of
                  the component-mean offset from mu), plus jitter * identity
    """
    probs = weights / jnp.sum(weights)

    # Weighted average of the component means.
    mu = (probs[:, None] * means).sum(axis=0)  # (D,)

    # Offset of every component mean from the mixture mean, and its outer product.
    offset = means - mu[None, :]  # (K, D)
    spread = jnp.expand_dims(offset, 2) * jnp.expand_dims(offset, 1)  # (K, D, D)

    per_component = probs[:, None, None] * (covs + spread)
    cov = per_component.sum(axis=0)

    n_dim = cov.shape[0]
    cov = cov + jitter * jnp.eye(n_dim, dtype=cov.dtype)
    return mu, cov


def make_uniform_box_logprior(low: jnp.ndarray, high: jnp.ndarray):
    """
    Build the log-density of a uniform prior over the box [low, high].

    The returned callable maps a point x to 0 when every coordinate lies
    within the (inclusive) bounds and to -inf otherwise; the result has
    the dtype of x.
    """
    lower = jnp.asarray(low)
    upper = jnp.asarray(high)

    def logprior_fn(x: jnp.ndarray) -> jnp.ndarray:
        above_lower = jnp.all(x >= lower)
        below_upper = jnp.all(x <= upper)
        in_box = jnp.logical_and(above_lower, below_upper)

        zero = jnp.asarray(0.0, dtype=x.dtype)
        neg_inf = jnp.asarray(-jnp.inf, dtype=x.dtype)
        return jnp.where(in_box, zero, neg_inf)

    return logprior_fn


class IdentityBijection:
    """Bijection that maps every input to itself and reports a zero log-determinant."""

    def transform_and_log_det(self, u, condition=None):
        # Identity map: the input is returned untouched; `condition` is ignored.
        log_det = jnp.zeros((), dtype=u.dtype)
        return u, log_det

    def inverse_and_log_det(self, theta, condition=None):
        # The inverse of the identity is the identity; `condition` is ignored.
        log_det = jnp.zeros((), dtype=theta.dtype)
        return theta, log_det


class IdentityFlow:
    """
    Placeholder flow for wiring checks.

    Use this only to validate wiring when you don't have a trained FlowJAX
    flow yet: its ``bijection`` attribute is an IdentityBijection.
    """

    def __init__(self):
        # The only state is the identity bijection exposed as `.bijection`.
        self.bijection = IdentityBijection()


##################################################################################
# Runner
##################################################################################
class pcn_ExperimentRunner:
    """
    Runner that can generate samples using preconditioned_pcn_jax.

    Key design decisions:
      - We run the kernel multiple times (n_outer), each time updating the full ensemble state.
      - We store x across outer iterations -> samples shape (n_outer, N_walkers, D).
        Your diagnostics can flatten via reshape(-1, D).

    The helpers inverse_jax, forward_jax, apply_boundary_conditions_x_jax and
    preconditioned_pcn_jax are resolved as module-level names at the time
    run_experiment() executes; they are not imported by this class.
    """

    def __init__(self, args, *, flow=None, scaler_cfg=None, scaler_masks=None):
        """
        Validate the configuration, reserve a fresh output directory and set up the target.

        Args:
            args: namespace-like object. ``vars(args)`` becomes ``self.params``;
                it must provide "experiment_type", "outdir" and "n_dims" plus the
                generator attributes read by _setup_gaussian_experiment.
            flow: optional flow object handed to the kernel (may be attached later).
            scaler_cfg: optional scaler configuration (may be attached later).
            scaler_masks: optional scaler masks (may be attached later).

        Raises:
            ValueError: if "experiment_type" is not in SUPPORTED_EXPERIMENTS.
            FileExistsError: if the selected output directory already exists
                by the time it is created.
        """
        # vars(args) is the namespace's own attribute dict, so the "outdir"
        # update below is also visible through `args`.
        self.params = vars(args)

        # --- validate experiment ---
        # Done before any filesystem work so that a rejected configuration does
        # not leave an empty results directory behind.
        if self.params["experiment_type"] not in SUPPORTED_EXPERIMENTS:
            raise ValueError(
                f"Experiment type {self.params['experiment_type']} is not supported. "
                f"Supported types are: {SUPPORTED_EXPERIMENTS}"
            )

        # --- unique outdir ---
        base_results_dir = self.params["outdir"]
        unique_outdir = self.get_next_available_outdir(base_results_dir)
        print(f"Using output directory: {unique_outdir}")
        # exist_ok=False: never silently reuse a directory from another run.
        os.makedirs(unique_outdir, exist_ok=False)
        self.params["outdir"] = unique_outdir

        print("Passed parameters:")
        for k, v in self.params.items():
            print(f"{k}: {v}")

        # Attach (or later set) flow + scaler
        self.flow = flow
        self.scaler_cfg = scaler_cfg
        self.scaler_masks = scaler_masks

        # Setup experiment
        if self.params["experiment_type"] == "gaussian":
            self._setup_gaussian_experiment(args)

        # Placeholder for results
        self.samples = None
        self.accept_history = None
        self.sigma_history = None
        self.calls_history = None

    # -------------------------------------------------------------------------
    # Setup
    # -------------------------------------------------------------------------
    def _setup_gaussian_experiment(self, args):
        """
        Build the Gaussian-mixture target, its uniform box prior and the proposal geometry.

        Sets true_samples, mcmc_means, mcmc_covs, mcmc_weights, likelihood,
        prior_low, prior_high, geom_mu, geom_cov, geom_nu and target_fn.

        Optional entries read from self.params: "data_seed" (default 900),
        "prior_inflate", "prior_nsig", "prior_pad", "geom_jitter", "geom_nu".
        """
        print("Setting the target function to a Gaussian mixture distribution.")

        # Seeds NumPy's global RNG before calling the generator. The seed used
        # to be fixed at 900; that remains the default.
        np.random.seed(int(self.params.get("data_seed", 900)))

        D = int(self.params["n_dims"])

        # Generate "true" samples and mixture parameters (your existing generator)
        true_samples, means, covariances, weights = GaussianMixtureGenerator.generate_gaussian_mixture(
            n_dim=D,
            n_gaussians=args.nr_of_components,
            n_samples=args.nr_of_samples,
            width_mean=args.width_mean,
            width_cov=args.width_cov,
            weights=args.weights_of_components,
        )

        self.true_samples = true_samples

        # Convert mixture params to JAX arrays
        self.mcmc_means = jnp.stack(means, axis=0)        # (K, D)
        self.mcmc_covs = jnp.stack(covariances, axis=0)   # (K, D, D)
        self.mcmc_weights = jnp.asarray(weights)          # (K,)

        # Likelihood object you already have.
        # A plain `import GaussianMixtureLikelihood` binds the module, and a
        # module cannot be called; in that case look the class up inside it.
        # NOTE(review): assumes the class carries the same name as its module
        # -- confirm against the project.
        likelihood_factory = GaussianMixtureLikelihood
        if not callable(likelihood_factory):
            likelihood_factory = likelihood_factory.GaussianMixtureLikelihood

        self.likelihood = likelihood_factory(
            means=self.mcmc_means,
            covs=self.mcmc_covs,
            weights=self.mcmc_weights,
        )

        # Prior bounds (uniform box)
        low_np, high_np = self.make_auto_bounds_inflated(
            means=means,
            covs=covariances,
            inflate=float(self.params.get("prior_inflate", 9.0)),
            nsig=float(self.params.get("prior_nsig", 12.0)),
            pad=float(self.params.get("prior_pad", 1e-6)),
        )
        self.prior_low = jnp.asarray(low_np)
        self.prior_high = jnp.asarray(high_np)

        # Student-t geometry in theta-space (use mixture moments as default)
        self.geom_mu, self.geom_cov = mixture_mean_cov(
            self.mcmc_means, self.mcmc_covs, self.mcmc_weights, jitter=float(self.params.get("geom_jitter", 1e-6))
        )
        self.geom_nu = jnp.asarray(self.params.get("geom_nu", 5.0), dtype=self.geom_mu.dtype)

        # Convenience target fn (optional)
        self.target_fn = self.target_normal

    # -------------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------------
    def attach_flow_and_scaler(self, *, flow, scaler_cfg, scaler_masks):
        """Call this if you cannot provide these objects in __init__."""
        self.flow = flow
        self.scaler_cfg = scaler_cfg
        self.scaler_masks = scaler_masks

    def run_experiment(self):
        """
        Run the sampler selected by params["sampler"] (default "precond_pcn").

        Raises:
            ValueError: if the experiment type / sampler combination is not supported.
        """
        sampler = self.params.get("sampler", "precond_pcn")

        if self.params["experiment_type"] == "gaussian" and sampler == "precond_pcn":
            self._run_preconditioned_pcn_gaussian()
            return

        raise ValueError(
            f"Unsupported combination experiment_type={self.params['experiment_type']} sampler={sampler}"
        )

    # -------------------------------------------------------------------------
    # Core run method (your algorithm)
    # -------------------------------------------------------------------------
    def _run_preconditioned_pcn_gaussian(self):
        """
        Initialise the walker ensemble and call preconditioned_pcn_jax n_outer times.

        Parameters read from self.params (defaults in parentheses): n_dims,
        n_walkers (2048), n_outer (50), beta (1.0), n_max (2000), n_steps (100),
        proposal_scale (0.2), seed (0), print_every (10; 0 disables progress output).

        Results are stored on self.samples, self.accept_history,
        self.sigma_history and self.calls_history.

        Raises:
            ValueError: if scaler_cfg / scaler_masks are missing or n_outer < 1.
        """
        # --- required objects ---
        if self.flow is None:
            # If you want a hard error instead, replace with raise ValueError(...)
            print("Warning: self.flow is None; using IdentityFlow() for wiring test.")
            self.flow = IdentityFlow()

        if self.scaler_cfg is None or self.scaler_masks is None:
            raise ValueError(
                "scaler_cfg / scaler_masks are required for inverse_jax/forward_jax. "
                "Attach them via attach_flow_and_scaler(...) or pass into __init__."
            )

        D = int(self.params["n_dims"])
        N = int(self.params.get("n_walkers", 2048))

        # Outer iterations: each call to preconditioned_pcn_jax adapts sigma/mu internally up to n_max
        n_outer = int(self.params.get("n_outer", 50))
        if n_outer < 1:
            # Without this guard the run only fails at the very end, when the
            # empty result lists are stacked.
            raise ValueError(f"n_outer must be >= 1, got {n_outer}")

        # Parsed once instead of on every iteration; 0 turns progress output off
        # (it previously caused a ZeroDivisionError in the modulo below).
        print_every = int(self.params.get("print_every", 10))

        # Kernel parameters
        beta = jnp.asarray(self.params.get("beta", 1.0), dtype=jnp.float32)
        n_max = int(self.params.get("n_max", 2000))
        n_steps = int(self.params.get("n_steps", 100))
        proposal_scale = jnp.asarray(self.params.get("proposal_scale", 0.2), dtype=jnp.float32)

        seed = int(self.params.get("seed", 0))
        key = jax.random.PRNGKey(seed)

        # Prior / likelihood functions in required signatures
        logprior_fn = make_uniform_box_logprior(self.prior_low, self.prior_high)
        # Empty per-sample blob: this target produces no auxiliary outputs.
        blob0 = jnp.zeros((0,), dtype=jnp.float32)

        def loglike_fn(xi):
            ll = self.likelihood.log_prob(xi)  # scalar
            return ll, blob0

        # -----------------------------
        # Initialize ensemble state
        # -----------------------------
        key, k_init = jax.random.split(key, 2)
        u = jax.random.normal(k_init, shape=(N, D), dtype=jnp.float32)

        x, logdetj = inverse_jax(u, self.scaler_cfg, self.scaler_masks)

        # keep boundary-condition roundtrip consistent with your kernel
        x_bc = apply_boundary_conditions_x_jax(x, dict(self.scaler_cfg))
        u_bc = forward_jax(x_bc, self.scaler_cfg, self.scaler_masks)
        x, logdetj = inverse_jax(u_bc, self.scaler_cfg, self.scaler_masks)
        u = u_bc

        # Walkers with a non-finite transform are assigned -inf prior/likelihood.
        finite0 = jnp.isfinite(logdetj) & jnp.all(jnp.isfinite(x), axis=1)

        def _prior_or_neginf(xi, ok):
            return jax.lax.cond(
                ok,
                lambda z: logprior_fn(z),
                lambda z: jnp.asarray(-jnp.inf, dtype=xi.dtype),
                xi,
            )

        logp = jax.vmap(_prior_or_neginf, in_axes=(0, 0), out_axes=0)(x, finite0)
        # Walkers outside the prior box also get -inf likelihood.
        finite1 = finite0 & jnp.isfinite(logp)

        def _like_or_neginf(xi, ok):
            def _do(z):
                return loglike_fn(z)

            def _skip(z):
                return jnp.asarray(-jnp.inf, dtype=xi.dtype), blob0

            return jax.lax.cond(ok, _do, _skip, xi)

        logl, _ = jax.vmap(_like_or_neginf, in_axes=(0, 0), out_axes=(0, 0))(x, finite1)

        # Required by kernel interface; it recomputes logdetj_flow internally anyway
        logdetj_flow = jnp.zeros((N,), dtype=jnp.float32)
        blobs = jnp.zeros((N, 0), dtype=jnp.float32)

        # storage
        xs = []
        accept_hist = []
        sigma_hist = []
        calls_hist = []

        # -----------------------------
        # Outer loop: accumulate samples
        # -----------------------------
        for t in range(n_outer):
            out = preconditioned_pcn_jax(
                key,
                u=u,
                x=x,
                logdetj=logdetj,
                logl=logl,
                logp=logp,
                logdetj_flow=logdetj_flow,
                blobs=blobs,
                beta=beta,
                loglike_fn=loglike_fn,
                logprior_fn=logprior_fn,
                flow=self.flow,
                scaler_cfg=self.scaler_cfg,
                scaler_masks=self.scaler_masks,
                geom_mu=self.geom_mu,
                geom_cov=self.geom_cov,
                geom_nu=self.geom_nu,
                n_max=n_max,
                n_steps=n_steps,
                proposal_scale=proposal_scale,
                condition=None,
            )

            # update state (the kernel's output is the next call's input)
            key = out["key"]
            u = out["u"]
            x = out["x"]
            logdetj = out["logdetj"]
            logdetj_flow = out["logdetj_flow"]
            logl = out["logl"]
            logp = out["logp"]
            blobs = out["blobs"]

            xs.append(x)
            accept_hist.append(out["accept"])
            sigma_hist.append(out["proposal_scale"])
            calls_hist.append(out["calls"])

            if print_every != 0 and (t + 1) % print_every == 0:
                acc = float(np.asarray(out["accept"]))
                sig = float(np.asarray(out["proposal_scale"]))
                calls = int(np.asarray(out["calls"]))
                steps = int(np.asarray(out["steps"]))
                print(f"[outer {t+1:>4d}/{n_outer}] accept={acc:.4f} sigma={sig:.4f} calls={calls} steps={steps}")

        # Store results
        self.samples = np.asarray(jnp.stack(xs, axis=0))  # (n_outer, N, D)
        self.accept_history = np.asarray(jnp.stack(accept_hist))
        self.sigma_history = np.asarray(jnp.stack(sigma_hist))
        self.calls_history = np.asarray(jnp.stack(calls_hist))

        # Convenience summary
        print(
            f"Done. samples shape={self.samples.shape} "
            f"mean_accept={self.accept_history.mean():.4f} "
            f"last_sigma={self.sigma_history[-1]:.4f}"
        )

    # -------------------------------------------------------------------------
    # Existing utilities / diagnostics
    # -------------------------------------------------------------------------
    def target_normal(self, x, data=None):
        """Return self.likelihood.log_prob(x); `data` is accepted but unused."""
        return self.likelihood.log_prob(x)

    def get_next_available_outdir(self, base_dir: str, prefix: str = "results") -> str:
        """
        Return ``<base_dir>/<prefix>_<n>`` with n one above the highest existing run number.

        base_dir is created if it does not exist. Only sub-directories whose
        whole name is ``<prefix>_<digits>`` are counted, so unrelated names such
        as ``results_7_old`` do not influence the numbering. The returned
        directory itself is NOT created here.
        """
        os.makedirs(base_dir, exist_ok=True)

        # Escape the prefix so regex metacharacters in it are matched literally.
        pattern = re.compile(rf"{re.escape(prefix)}_(\d+)")

        numbers = []
        for name in os.listdir(base_dir):
            found = pattern.fullmatch(name)
            if found and os.path.isdir(os.path.join(base_dir, name)):
                numbers.append(int(found.group(1)))

        next_number = max(numbers, default=0) + 1
        return os.path.join(base_dir, f"{prefix}_{next_number}")

    @staticmethod
    def make_auto_bounds_inflated(means, covs, inflate=9.0, nsig=12.0, pad=1e-6,
                                 prior_low=None, prior_high=None):
        """
        Derive per-dimension box bounds that enclose all mixture components.

        Args:
            means: (K, D) component means.
            covs: (K, D, D) component covariances.
            inflate: factor applied to the covariances (i.e. to the variances).
            nsig: number of inflated standard deviations added on each side.
            pad: extra absolute margin added on each side.
            prior_low: optional scalar or (D,) lower bound; the result is the
                element-wise minimum of it and the automatic bound.
            prior_high: optional scalar or (D,) upper bound; the result is the
                element-wise maximum of it and the automatic bound.

        Returns:
            (low, high): two NumPy arrays of shape (D,).
        """
        means = np.asarray(means, dtype=float)                 # (K, D)
        covs = np.asarray(covs, dtype=float) * float(inflate)  # inflate variance

        mu_min = means.min(axis=0)                             # (D,)
        mu_max = means.max(axis=0)                             # (D,)

        # Largest marginal standard deviation over components, per dimension.
        std_max = np.sqrt(np.diagonal(covs, axis1=1, axis2=2)).max(axis=0)  # (D,)

        low = mu_min - nsig * std_max - pad
        high = mu_max + nsig * std_max + pad

        # np.asarray (instead of float) lets callers pass per-dimension bounds
        # as well as a single scalar.
        if prior_low is not None:
            low = np.minimum(low, np.asarray(prior_low, dtype=float))
        if prior_high is not None:
            high = np.maximum(high, np.asarray(prior_high, dtype=float))
        return low, high

    def get_true_and_mcmc_samples(self, discard=0, thin=1):
        """
        Return (true_samples, sampler_samples), each flattened to (n_rows, n_dims).

        discard / thin are applied to the flattened sampler rows as
        ``rows[discard::thin]``.
        NOTE(review): for samples of shape (n_outer, N, D) this drops individual
        walkers of the earliest outer iteration(s), not whole outer iterations
        -- confirm this is the intended burn-in semantics.

        Raises:
            ValueError: if the true samples or the sampler samples are missing.
        """
        dim = int(self.params["n_dims"])

        if getattr(self, "true_samples", None) is None:
            raise ValueError("No true samples found. Ensure self.true_samples is set (gaussian experiment).")

        if getattr(self, "samples", None) is None:
            raise ValueError("No sampler samples found. Run run_experiment() first.")

        true_np = np.asarray(self.true_samples).reshape(-1, dim)

        flat = np.asarray(self.samples).reshape(-1, dim)  # works for (n_outer, N, D) too
        mcmc_np = flat[int(discard)::int(thin), :]

        return true_np, mcmc_np