First, install and import the dependencies needed to run this notebook. As of August 1, 2023, it runs in a high-RAM Google Colab instance.

In [None]:
!pip install jaxopt

In [None]:
!pip install --upgrade scipy

In [None]:
!pip install optax

In [None]:
# The sampling is better behaved in 64 bit mode.
# `from jax import config` works on both older and current JAX releases; the
# `from jax.config import config` submodule import no longer exists in newer JAX.
from jax import config
config.update("jax_enable_x64", True)

In [None]:
import graphviz
import os
import arviz as az
import xarray as xr
from tqdm.autonotebook import tqdm
# import tqdm

import time

from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp, logit, expit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from google.colab import files
import jax
from jax.scipy import stats
from jax.scipy.linalg import cholesky
import sympy
import sympy as sp
from functools import partial
from typing import Callable
import scipy
import matplotlib.tri as tri

In [None]:
from tensorflow_probability.substrates import jax as tfp

In [None]:
import jaxopt
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_box
import optax

In [None]:
seed = 42
# Root PRNG key for the notebook, split once so `subkey` is ready for a first draw.
key, subkey = jax.random.split(jax.random.PRNGKey(seed))

In [None]:
# Written by Enzo Michelangeli, style changes by josef-pktd; ported to JAX.
# Student's t random variable.
@partial(jax.jit, static_argnums=(4,))
def multivariate_t_rvs(key, m, S, df=10000.0, n=1):
    '''Generate random variables of a multivariate Student's t distribution.

    Uses the scale-mixture representation t = m + z / sqrt(u / df) with
    z ~ N(0, S) and u ~ chi^2(df). A chi^2(df) variate equals
    2 * Gamma(df / 2, 1), which is what is sampled below.

    Parameters
    ----------
    key : PRNG key
        split internally so the gamma and the normal draws are independent
    m : array_like
        mean of random variable, length determines dimension of random variable
    S : array_like
        square (d, d) shape matrix (covariance of the Gaussian part)
    df : int or float
        degrees of freedom
    n : int
        number of observations; static under jit, so it must be a hashable Python int.
        The returned array has shape (n, len(m)).

    Returns
    -------
    rvs : ndarray, (n, len(m))
        each row is an independent draw of a multivariate t distributed
        random variable
    '''
    m = jnp.asarray(m)
    d = m.shape[0]
    # Independent keys: reusing one key for both draws correlates the mixing scale with z.
    key_chi, key_norm = jax.random.split(key)
    # u / df with u ~ chi^2(df) = 2 * Gamma(df / 2, 1).
    x = 2.0 * jax.random.gamma(key_chi, df / 2, shape=(n,)) / df
    x = x.reshape(-1, 1)
    z = jax.random.multivariate_normal(key_norm, jnp.zeros(d), S, (n,))
    return m + z / jnp.sqrt(x)

@partial(jax.jit)
def mvStudent_logpdf(x, mean, shape, df):
    """Log-density of a multivariate Student's t distribution.

    Args:
      x: evaluation point(s); the last axis has length d.
      mean: location vector of length d.
      shape: (d, d) scale matrix; 1e-7 is added to its eigenvalues before use.
      df: degrees of freedom.

    Returns:
      The log-density, reduced over the last axis of x (one value per point).
    """
    n_dim = mean.shape[0]

    # Eigen-decompose the scale matrix once; the jitter keeps log and inverse finite.
    eigvals, eigvecs = jnp.linalg.eigh(shape)
    eigvals = eigvals + 1e-7
    log_det = jnp.sum(jnp.log(eigvals))
    # Whitening map: ||(x - mean) @ whiten||^2 is the Mahalanobis distance.
    whiten = eigvecs * jnp.sqrt(1. / eigvals)
    mahalanobis = jnp.sum(jnp.dot(x - mean, whiten) ** 2, axis=-1)

    half_total = 0.5 * (df + n_dim)
    log_norm = (jax.scipy.special.gammaln(half_total)
                - jax.scipy.special.gammaln(0.5 * df)
                - n_dim / 2. * jnp.log(df * np.pi)
                - 0.5 * log_det)
    return log_norm - half_total * jnp.log(1 + (1. / df) * mahalanobis)

In [None]:
def fill_diagonal(a, val):
  """Return a copy of `a` whose main diagonal (over the last two axes) is set to `val`.

  Functional stand-in for `numpy.fill_diagonal`, which mutates in place
  (https://github.com/google/jax/issues/2680). Leading axes are batch axes; for
  non-square trailing axes only the first min(rows, cols) diagonal entries are written.
  """
  assert a.ndim >= 2
  diag_len = min(a.shape[-2], a.shape[-1])
  idx = jnp.arange(diag_len)
  return a.at[..., idx, idx].set(val)



# K-Means++ Seed Selection

Used to initialize the GMM fitting process.

In [None]:
@jax.jit
def assign(x, means):
  """Assign each point to its nearest mean under Euclidean distance.

  x: shape (n, d) points.
  means: shape (k, d) cluster centers.
  Returns (assignment, dists): the index of each point's closest mean, shape (n,),
  and the distance to that mean, shape (n,).
  """
  def nearest(point):
    return jnp.argmin(jnp.linalg.norm(means - point, axis=-1))

  assignment = jax.vmap(nearest)(x)
  dists = jnp.linalg.norm(means[assignment] - x, axis=-1)
  return assignment, dists

In [None]:
@partial(jax.jit, static_argnums=(1,))
def kpp_seeds(X, K, key, w=None):
  """Pick K seed points from X with (weighted) k-means++.

  X: has shape (n, d) for the dataset
  K: the number of seeds we want (static under jit)
  key: is a PRNG key
  w: optional (n,) weights; seeds after the first are drawn with probability
     proportional to w * D(x)^2. Defaults to uniform.

  Returns a tuple (assignment, dists, seeds, key):
    assignment: int32 (n,), slot in `seeds` of each point's nearest seed
    dists: (n,) distance from each point to its nearest seed
    seeds: int32 (K,), indices into X of the chosen seeds
    key: the advanced PRNG key
  """
  key, subkey = jax.random.split(key)  # subkey picks the first seed; key drives the loop

  n, _ = X.shape
  if w is None:
    w = jnp.ones(n)/n

  seeds = jnp.zeros(K, dtype=jnp.int32)
  # randint's maxval is exclusive: maxval=n makes every index 0..n-1 eligible.
  seeds = seeds.at[0].set(jax.random.randint(subkey, minval=0, maxval=n, shape=(1,))[0])
  dists = jnp.linalg.norm(X-X[seeds[0],:], axis=-1)
  assignment = jnp.zeros(n, dtype=jnp.int32)
  uniform = jnp.ones(n)/n

  def kpp_update(k_, val):
    assignment, dists, seeds, key = val

    # Draw the next seed with probability proportional to w * D(x)^2.
    key, subkey = jax.random.split(key)
    d2 = w*dists*dists
    total = d2.sum()
    # If every point already coincides with a seed (total == 0), draw uniformly
    # instead of handing NaN probabilities to jax.random.choice.
    prob_select = jnp.where(total > 0, d2/jnp.where(total > 0, total, 1.0), uniform)
    to_add = jax.random.choice(subkey, n, shape=(1,), p=prob_select)[0]
    seeds = seeds.at[k_].set(to_add)

    # Distance from every point to the new seed.
    new_mean_dists = jnp.linalg.norm(X-X[seeds[k_],:], axis=-1)

    # Points closer to the new seed switch ownership to slot k_.
    changed_ownership = new_mean_dists < dists
    dists = jnp.where(changed_ownership, new_mean_dists, dists)
    assignment = jnp.where(changed_ownership, k_, assignment)

    return assignment, dists, seeds, key

  init_values = assignment, dists, seeds, key
  return jax.lax.fori_loop(1, K, kpp_update, init_values)

# GMM

In [None]:
def fit_gmm(key, data, K, w=None, student=False, tol=0.001, max_iters=400, fullCov=True):
  """Fit a K-component Gaussian mixture to (optionally weighted) data with EM.

  Args:
    key: PRNG key for the k-means++ initialization.
    data: (n, d) array of points.
    K: number of components.
    w: optional (n,) per-point weights; all ones when None.
    student: unused, kept for backward compatibility; components are always Gaussian.
    tol: stop once fewer than n * tol points change their most likely component.
    max_iters: maximum number of EM iterations.
    fullCov: fit full covariances if True, otherwise diagonal ones (still returned
      as (d, d) diagonal matrices).

  Returns:
    (𝜇, Σ, weights): means (K, d), covariances (K, d, d) and unnormalized component
    weights (K,), i.e. the per-component sum of responsibilities over all points.

  Note: the E-step treats every component as equally likely a priori (no mixing-weight
  term). Variances are floored (1e-3 full, 1e-5 diagonal) to keep covariances usable.
  """
  d = data.shape[1]
  n = data.shape[0]

  W_ = jnp.ones(n) if w is None else w

  # Weighted covariance of all the data; its largest entry replaces NaN covariance entries.
  S_all = jnp.cov(data, rowvar=False, aweights=W_).reshape(d,d)
  S_all = fill_diagonal(S_all, jnp.maximum(jnp.diag(S_all), 1e-3))

  def e_step(args):
    # Per-point log-likelihood under one component (full or diagonal Gaussian).
    m, s = args
    log_pdfs = jax.lax.cond(fullCov,
      lambda z: jax.scipy.stats.multivariate_normal.logpdf(data, m, s), #True
      lambda z: jax.scipy.stats.norm.logpdf(data, loc=m, scale=jnp.diag(s)**0.5).sum(axis=-1),
      m
      )
    return log_pdfs

  def m_step(args):
    # Weighted mean / full covariance of one component from its responsibilities.
    w = args*W_
    cov = jnp.cov(data, rowvar=False, aweights=w).reshape(d,d)
    cov = jnp.nan_to_num(cov, nan=S_all.max())
    cov = fill_diagonal(cov, jnp.maximum(jnp.diag(cov), 1e-3))
    return jnp.average(data, axis=0, weights=w), cov.reshape((d,d))

  def m_step_diag(args):
    # Weighted mean / per-dimension variance of one component.
    w = args*W_
    avg = jnp.average(data, axis=0, weights=w)
    cov = jnp.average((data-avg.reshape(1, -1))**2, axis=0, weights=w)
    cov = jnp.nan_to_num(cov, nan=S_all.max())
    cov = jnp.maximum(cov, 1e-5)
    return avg, jnp.diag(cov)

  @partial(jax.jit, static_argnames=())
  def whole_step(𝜇, Σ, assignments):
    # One EM iteration; also counts how many points changed their argmax component.
    log_probs = jax.vmap(e_step)((𝜇, Σ))
    weights = jax.nn.softmax(log_probs, axis=0)
    𝜇, Σ = jax.lax.cond(fullCov,
      lambda weights: jax.vmap(m_step)(weights), #True
      lambda weights: jax.vmap(m_step_diag)(weights),
      weights
      )

    new_assignments = jnp.argmax(log_probs, axis=0)
    changes = (assignments != new_assignments).sum()
    return 𝜇, Σ, new_assignments, changes

  # k-means++ hard assignments -> one-hot responsibilities -> initial M-step.
  assignments, *_ = kpp_seeds(data, K, key, w=W_)
  weights = jax.nn.one_hot(assignments, K).T
  𝜇, Σ = jax.vmap(m_step)(weights)

  changes = n
  iters = 0
  with tqdm(total=max_iters, leave=False) as pbar:
    while changes >= n*tol and iters < max_iters:
      𝜇, Σ, assignments, changes = whole_step(𝜇, Σ, assignments)
      iters += 1
      # `changes` is a count of points; report it as a percentage of n.
      pbar.set_description(f"changes {100.0 * float(changes) / n:.3f}%")
      pbar.update(1)
    pbar.update(max_iters-iters)

  return 𝜇, Σ, jax.nn.softmax(jax.vmap(e_step)((𝜇, Σ)), axis=0).sum(axis=1)

In [None]:
@partial(jax.jit)
def mm_log_pdf(data, 𝜇, Σ, weights, student=False, df=1e10, fullCov=True):
  """Log-density of a mixture model evaluated at every row of `data`.

  𝜇 must be shaped as (clusters, d)
  Σ must be shaped as (clusters, d, d)
  weights: (clusters,) component weights; normalized here, with 1e-7 added before the
    log so that zero-weight components stay finite.
  student: if true, components are multivariate Student's t with `df` degrees of freedom.
  fullCov: for Gaussian components, use the full covariance when true, else only diag(Σ).

  Returns one log-density per point (data is reshaped to (-1, d)).
  """
  n_dim = 𝜇.shape[1]
  points = data.reshape(-1, n_dim)

  def component_log_pdf(params):
    mean, cov = params

    def student_t(pts):
      return mvStudent_logpdf(pts, mean, cov, df=df)

    def gaussian(pts):
      # The cond operand is ignored; both branches close over `pts`.
      return jax.lax.cond(
          fullCov,
          lambda _: jax.scipy.stats.multivariate_normal.logpdf(pts, mean, cov),
          lambda _: jax.scipy.stats.norm.logpdf(pts, loc=mean, scale=jnp.diag(cov)**0.5).sum(axis=-1),
          mean)

    return jax.lax.cond(student, student_t, gaussian, points)

  mix = weights / weights.sum()
  per_component = jax.vmap(component_log_pdf)((𝜇, Σ)) + jnp.log(mix+1e-7).reshape(-1, 1)
  return jax.scipy.special.logsumexp(per_component, axis=0)

Function to draw samples from our GMM within the provided bounds. It also supports full covariance matrices via a naive rejection approach, but that path is extremely slow with bounded support, so we advise against it.

In [None]:

def getSamples(key, n, 𝜇, Σ, weights=None, low_bounds=None, hi_bounds=None, student=False, df=None, fullCov=True):
  """Draw exactly `n` samples from a mixture restricted to the box [low_bounds, hi_bounds].

  Per-component counts follow a multinomial over the normalized `weights`. Diagonal
  Gaussians (fullCov=False, student=False) are sampled per dimension from truncated
  normals. Otherwise draws come from the unbounded component and out-of-box (or NaN)
  draws are rejected, oversampling adaptively until enough remain (slow for tight boxes).

  Args:
    key: PRNG key.
    n: total number of samples.
    𝜇: (K, d) component means.
    Σ: (K, d, d) component covariance / shape matrices.
    weights: (K,) non-negative component weights; equal weights when None.
    low_bounds, hi_bounds: (d,) box bounds; -inf / +inf when None.
    student: sample multivariate t components (multivariate_t_rvs) instead of Gaussians.
    df: degrees of freedom for the t components; 1 when None.
    fullCov: use the full covariance for Gaussian components.

  Returns:
    (n, d) array of samples, grouped by component.
  """
  if low_bounds is None:
    low_bounds = jax.lax.full_like(𝜇[0,:], -jnp.inf)
  if hi_bounds is None:
    hi_bounds = jax.lax.full_like(𝜇[0,:], jnp.inf)
  if weights is None:
    weights = jax.lax.full_like(𝜇[:,0], 1.0)
  if df is None:
    df = 1

  num_components, d = 𝜇.shape

  # Dedicated subkey for the counts so that `key` is never consumed twice.
  key, count_key = jax.random.split(key)
  # Multinomial(n, p) counts == histogram of n i.i.d. categorical draws.
  component_ids = jax.random.choice(count_key, num_components, shape=(int(n),), p=weights/weights.sum())
  samples_per_cluster = jnp.bincount(component_ids, length=num_components).tolist()

  all_samples = []
  for i, needed_total in enumerate(samples_per_cluster):
    cur_samples = jnp.zeros((0,d))

    key, subkey = jax.random.split(key)
    sample_efficency = 1.0 #How many extra samples do we need to get where we want to be?
    while cur_samples.shape[0] < needed_total:
      samples_needed = needed_total-cur_samples.shape[0]
      # Python int: jnp scalars are unhashable and cannot be the static `n` of multivariate_t_rvs.
      n_ = int(jnp.round(sample_efficency*samples_needed))
      if student:
        new_samples = multivariate_t_rvs(subkey, 𝜇[i,:], Σ[i,:], df=df, n=n_)
      else:
        if fullCov:
          new_samples = jax.random.multivariate_normal(subkey, 𝜇[i,:], Σ[i,:], (n_,))
        else:
          m = 𝜇[i,:]
          s = jnp.diag(Σ[i,:])**0.5
          new_samples = jax.random.truncated_normal(subkey, (low_bounds-m)/s, (hi_bounds-m)/s, shape=(n_,d))*s+m

      #find out-of-bound (or NaN) samples and drop them
      bad_samples = jnp.logical_or(jnp.logical_or(new_samples < low_bounds, new_samples > hi_bounds), jnp.isnan(new_samples)).sum(axis=1)
      new_samples = new_samples[bad_samples == 0, :]
      sample_efficency = jnp.maximum(samples_needed/(new_samples.shape[0]+1), sample_efficency)
      cur_samples = jnp.vstack([cur_samples, new_samples])

      key, subkey = jax.random.split(key)

    # Oversampling can overshoot; keep exactly the multinomial count so mixture proportions hold.
    all_samples.append(cur_samples[:needed_total])
  return jnp.vstack(all_samples)

In [None]:
@jax.jit
def norm_invcdf(p):
  """Standard-normal quantile function, Phi^-1(p) = sqrt(2) * erfinv(2p - 1).

  p is clipped to [1e-16, 1 - 1e-16] first to keep the result away from +-inf.
  """
  clipped = jnp.clip(p, 1e-16, 1-1e-16)
  return jax.scipy.special.erfinv(2*clipped-1) * jnp.sqrt(2)

def rnd_trunk_normal(key, l, h, m, s, n):
  """Sample n rows from N(m, s^2) truncated to [l, h] per dimension, by inverse-CDF.

  l, h, m, s broadcast against each other; the row width is m.shape[-1].
  Returns an array of shape (n, m.shape[-1]).
  """
  def standardize(bound):
    return (bound-m)/s

  lower_mass = jax.scipy.stats.norm.cdf(standardize(l))
  upper_mass = jax.scipy.stats.norm.cdf(standardize(h))
  u = jax.random.uniform(key, (n, m.shape[-1]))
  # Map uniforms into [Phi(a), Phi(b)] and invert back to the original scale.
  z = norm_invcdf(lower_mass + u*(upper_mass-lower_mass))
  return z*s + m



# Target Functions

These are the target functions from the paper that we use as benchmarks.

In [None]:
@jax.jit
def log_f_Erraqabi(x):
  """Log of the Erraqabi benchmark target, prod_i (sin(4*pi*x_i - pi/2) + 1), over the last axis.

  Each factor lies in [0, 2], so a per-dimension log can be -inf. NaN terms are set to
  -1e38 before summing, and the final nan_to_num maps -inf to the most negative finite float.
  """
  per_dim = jnp.log(jnp.sin(4*jnp.pi*x - jnp.pi/2) + 1)
  per_dim = jnp.nan_to_num(per_dim, nan=-1e38)
  return jnp.nan_to_num(jnp.sum(per_dim, axis=-1))

In [None]:
@jax.jit
def log_f_Maddison(x, a):
  """Unnormalized log-density sum_i [-x_i - a*log(1 + x_i)] supported on x >= 0.

  x: points with coordinates on the last axis, (..., d); a single point may be (d,).
  a: shape parameter.
  Returns a flat array with one value per point; points with any negative coordinate
  (outside the support) get -1e38.
  """
  z = (-x - a*jnp.log(1+jnp.maximum(x, 0.0))).sum(axis=-1)
  # A point is in the support only if *all* of its coordinates are >= 0. Reducing the mask
  # over the last axis keeps one output per point for any d (it used to broadcast to (n, d)).
  in_support = (x >= 0).all(axis=-1)
  return jnp.where(in_support, z, -1e38).reshape(-1)

In [None]:
@partial(jax.jit, static_argnums=())
def log_f_clutter(x, data, sigma=2, pi=0.5):
  """Unnormalized log-posterior of the clutter benchmark.

  Prior x ~ N(0, sigma^2 I). Each observation comes from N(x, I) with probability `pi`,
  otherwise from the clutter component N(0, 100^2 I).

  x: query points with last axis of length d; a single point (d,) or a batch (n, d).
  data: (N, d) observations.
  sigma: prior scale, a scalar or a length-d vector.
  pi: mixing probability of the signal component.
  Returns: (n,) log prior + log likelihood for each query point.
  """
  dim = data.shape[1]
  # Reshape with the known dimension so a single (d,) point becomes (1, d), not (d, 1).
  x = x.reshape(-1, dim)
  # Broadcast log(sigma) per dimension so a scalar sigma contributes dim * log(sigma).
  log_sigma = jnp.broadcast_to(jnp.log(sigma), (dim,))
  log_prior = -0.5*(x ** 2/sigma**2).sum(axis=-1) - dim*0.5*jnp.log(2*jnp.pi) - log_sigma.sum()

  # (n, N): log N(data_j | x_i, I)
  model = (-0.5*((data[jnp.newaxis,:,:]-x[:,jnp.newaxis,:]) ** 2).sum(axis=-1) - dim * 0.5 * jnp.log(2*jnp.pi))
  # (N,): log N(data_j | 0, 100^2 I), shared by every query point
  noise = -0.5*(data ** 2 / 100. ** 2).sum(axis=-1) - dim * 0.5 * jnp.log(2*jnp.pi) - dim * jnp.log(100)
  noise = jnp.broadcast_to(noise, model.shape)
  log_likelihood = logsumexp(jnp.stack([model + jnp.log(pi), noise + jnp.log(1-pi)], axis=-1), axis=-1).sum(axis=-1)
  return log_likelihood + log_prior


# Easy Rejection Sampling

In [None]:
def getSubsetIDs(key, s, to_grab): #Not effectively used in current code, we always pick all of them
  """Pick a subset of indices into the 1-D score array `s`.

  If to_grab >= len(s), every index is returned. Otherwise to_grab // 3 indices are taken
  from each of: softmax(s)-weighted sampling, uniform sampling, and the highest scores.
  The sorted, de-duplicated union is returned, so it may hold fewer than to_grab indices.
  """
  size = s.shape[0]
  if to_grab >= size:
    return jnp.arange(size)
  per_group = int(to_grab)//3
  # Two independent subkeys; previously both draws used the same subkey and were correlated.
  key_high, key_uniform = jax.random.split(key)
  stochastic_highs = jax.random.choice(key_high, int(size), shape=(per_group,), p=jax.nn.softmax(s))
  stochastic_others = jax.random.choice(key_uniform, int(size), shape=(per_group,))
  highest = jnp.argsort(-s)[:per_group]
  return jnp.unique(jnp.hstack((stochastic_highs, stochastic_others, highest)))

In [None]:
def ers(key, log_func, d, target_samples=10000, low_val=-jnp.inf, hi_val=jnp.inf, samples_at_a_time = 500):
  """Adaptive rejection sampling from exp(log_func) with a Gaussian-mixture proposal g(x).

  The proposal starts from an initial mode search. It is periodically re-fit (fit_gmm) and
  refined by gradient descent (optax adabelief) on the accepted and rejected samples, to
  shrink the bound constant C in f(x) <= C * g(x).

  Args:
    key: PRNG key.
    log_func: the (unnormalized) log PDF of the target f. It is called on batches of shape
      (m, d) and, when the box is unbounded somewhere, on a single start point of shape (d,).
    d: the number of dimensions to be sampling.
    target_samples: sampling continues until at least this many candidates are accepted.
    low_val, hi_val: lower/upper bound broadcast to every dimension (may be +-inf).
    samples_at_a_time: candidates drawn per round, per mixture component.

  Returns:
    (acceptance_rate, cur_samples, cur_samples_f, rej_samples, rej_samples_f, (𝜇, Σ, weights), history)
    acceptance_rate: accepted samples divided by an estimated count of log_func evaluations.
    cur_samples, cur_samples_f, rej_samples, rej_samples_f: one-element lists holding the
      stacked accepted/rejected points and their log f values.
    (𝜇, Σ, weights): the final proposal.
    history: per-round log C, acceptance rate, batch size and fit attempts, plus the
      initial proposal and wall-clock timings.

  NOTE(review): getSamples truncates each component to the box separately, while log g comes
  from the untruncated mixture density (mm_log_pdf). Unless every component loses the same mass
  to truncation, f/(C g) does not match the actual proposal. C is also estimated on the fly,
  so accepted samples may be slightly biased. Confirm this is acceptable for the intended use.
  """
  key, subkey = jax.random.split(key)

  history = {}
  history["logC"] = []
  history["rate"] = []
  history["samples"] = []
  history["fits"] = []

  fullCov = False # d < 4

  f_eval_total = 0 #Keep track of how many times f has been evaluated

  low_box = jnp.ones(d)*low_val
  hi_box = jnp.ones(d)*hi_val

  gmm_time = 0.0
  refine_time = 0.0
  init_time = 0.0
  sampling_time = 0.0

  start = time.time()
  w_init = jax.random.uniform(subkey, (d,), minval=low_box, maxval=hi_box)
  # Coordinates drawn from an unbounded interval come out as +-inf (-> 0.5) or NaN (-> 0).
  w_init = jnp.nan_to_num(w_init, neginf=0.5, posinf=0.5)
  step_size_search = 0.0 #By default we will use a line search indicated by 0

  if jnp.isfinite(low_box).all() and jnp.isfinite(hi_box).all():
    # Bounded support: start from a single broad Gaussian at the centre of the box.
    f_eval_total += 1
    box_width = hi_box-low_box
    # Box centre. (Previously `box_width/2`, which is the centre only when low_box == 0.)
    f_mode = low_box + box_width/2
    Sigma_est = jnp.diag(box_width/3) #Should cover the whole box
    𝜇 = f_mode.reshape(1, d) #shape with (1, *) b/c we will treat it as a GMM for code simplicity, later on we add more means
    Σ = Sigma_est.reshape(1, d, d)
    weights = jnp.ones((1))
  else:
    while log_func(w_init).max() < -1e36: #If discontinuous we could by chance start in a bad spot
      key, subkey = jax.random.split(key)
      f_eval_total += 1
      step_size_search = 1e-4 # Line searches are a bad idea with discontinuities, lets use normal SGD
      w_init = jax.random.uniform(subkey, (d,), minval=low_box, maxval=hi_box)
      w_init = jnp.nan_to_num(w_init, neginf=0.5, posinf=0.5)

    #First we find the mode(s) of the target from several starts. This branch only runs when the
    #box is unbounded somewhere, so the starts are Gaussian perturbations of w_init.
    searchers = jnp.concatenate([w_init.reshape(1, d), w_init+jax.random.normal(subkey, (2**d+d+1,d))], axis=0)
    searchers = jnp.clip(searchers, low_box+1e-2, hi_box-1e-2)
    key, subkey = jax.random.split(key)
    pg = ProjectedGradient(fun=lambda x: -log_func(x).sum(), projection=projection_box, stepsize=step_size_search, jit=True, maxiter=50, maxls=15, tol=0.01)
    pg_sol_mode = pg.run(searchers, hyperparams_proj=(low_box+1e-2, hi_box-1e-2))
    f_mode = pg_sol_mode.params
    f_eval_total += pg_sol_mode.state.iter_num*15*searchers.shape[0]
    f_mode = jnp.clip(f_mode, low_box+1e-2, hi_box-1e-2)

    if jnp.cov(f_mode, rowvar=False).max() < 0.01: #Looks unimodal, we need to figure out a decent stnd. dev.
      f_mode = f_mode[0,:]

      #Now we find a set of "seed" points that are used to estimate a covariance of the
      #  distribution. The seeds are started from a randomly perturbed mode, and
      # optimized to have a lower PDF to encourage them to go away from the mode
      initial_spread = f_mode + jax.random.normal(key, (d*2+10, d))
      initial_spread = jnp.clip(initial_spread, low_box+1e-2, hi_box-1e-2)

      def spreadLoss(centers, spread, shift=-5):
        # Push each spread point towards log f(spread) = log f(mode) + shift.
        f_main = log_func(centers)
        f_spread = log_func(spread)
        closest =  (f_main.flatten() - f_spread.reshape(-1, 1)).min(axis=-1) #what was closest
        return jnp.mean(jnp.power((closest+shift), 2))

      pg = ProjectedGradient(fun=lambda z: spreadLoss(f_mode, z, -5), projection=projection_box, jit=True, stepsize=0.0, maxiter=10, maxls=5, tol=0.01)
      pg_sol = pg.run(initial_spread, hyperparams_proj=(low_box+1e-2, hi_box-1e-2))
      f_eval_total += pg_sol.state.iter_num*10*initial_spread.shape[0]

      seed_points = jnp.vstack((f_mode, pg_sol.params))
      logPDF_seeds = log_func(seed_points)
      weights = jax.nn.softmax(logPDF_seeds.flatten())
      Sigma_est = jnp.cov(seed_points, rowvar=False, aweights=weights) + jnp.diag(jnp.ones(d))*0.1
      #We have now chosen an initial mean and covariance Sigma_est to use as our initial function g(x)
      𝜇 = f_mode.reshape(1, d) #shape with (1, *) b/c we will treat it as a GMM for code simplicity, later on we add more means
      Σ = Sigma_est.reshape(1, d, d)
      weights = jnp.ones((1))
    else: #We found more than one mode, lets just use those.
      modes = f_mode
      pair_dist = jnp.sqrt(jnp.sum((modes[:, None, :] - modes[None, :, :])**2, axis=-1))
      i, j = jnp.unravel_index(jnp.argmax(pair_dist, axis=None), pair_dist.shape)

      # Greedy farthest-point selection of distinct modes, starting from the farthest pair.
      selected_modes = jnp.concatenate([modes[i:i+1,:], modes[j:j+1,:]], axis=0)
      pair_dist = jnp.sqrt(jnp.sum((selected_modes[:, None, :] - modes[None, :, :])**2, axis=-1)).min(axis=0)
      next_fathest = jnp.argmax(pair_dist)
      while pair_dist[next_fathest] > 0.01: #We have another valid mode
        selected_modes = jnp.concatenate([selected_modes, modes[next_fathest:next_fathest+1,:]], axis=0)
        pair_dist = jnp.sqrt(jnp.sum((selected_modes[:, None, :] - modes[None, :, :])**2, axis=-1)).min(axis=0)
        next_fathest = jnp.argmax(pair_dist)
      #We place a gaussian at each mode with a shared isotropic Σ.
      𝜇 = selected_modes
      # Variance = largest nearest-neighbour distance between selected modes / (3 * #modes).
      # A two-mode special case guarded by `selected_modes.shape == 2` (a tuple-vs-int
      # comparison that was always False) was removed; this rule already applied to every case.
      pair_dist = jnp.sqrt(jnp.sum((selected_modes[:, None, :] - selected_modes[None, :, :])**2, axis=-1))
      v = fill_diagonal(pair_dist, 1e20).min(axis=0).max()
      v = v/(selected_modes.shape[0]*3)
      Sigma_est = jnp.eye(d)*v
      Σ = jnp.repeat(Sigma_est.reshape((1, d, d)), selected_modes.shape[0], axis=0)
      weights = jnp.ones((selected_modes.shape[0]))/selected_modes.shape[0]

  init_time += time.time() - start
  history["init_g"] = (𝜇, Σ)
  #Now we have initial sigma and mu
  total_sampled = 0
  total_accepted =0
  log_C = -1e30

  cur_samples = []
  cur_samples_f = []

  rej_samples = []
  rej_samples_f = []

  refit_cycle = 2 #How many X times more data before we refit/refine the model?

  keya = key
  keya, subkey = jax.random.split(keya)
  ignore_threshold = -1e36

  𝜇, Σ, weights = 𝜇.astype(jnp.float64), Σ.astype(jnp.float64), weights.astype(jnp.float64)
  K_orig = 𝜇.shape[0]
  last_gmm_fit_size = 10
  refit = True
  last_fit_sample_size = 5
  #used for 'undoing' a bad GMM attempt
  logC_old = 1e20
  𝜇_old, Σ_old = 𝜇, Σ
  gmm_ban = 0

  log_C_best = 1000.0

  regression = False
  prev_rate = 0.0
  with tqdm(total=target_samples, leave=False) as pbar:
    while total_accepted < target_samples:
      pbar.set_description("Sampling at "+str(prev_rate))
      pbar.refresh()
      start = time.time()
      to_grab = int(samples_at_a_time*𝜇.shape[0])
      candidates = getSamples(subkey, to_grab, 𝜇, Σ, weights, low_bounds=low_box, hi_bounds=hi_box, fullCov=fullCov)
      keya, subkey = jax.random.split(keya)
      total_sampled += candidates.shape[0]
      gmm_ban -= 1

      f_eval_total += candidates.shape[0]
      log_f = log_func(candidates)
      valid = log_f > ignore_threshold #Candidates in a zero-density area of the target are ignored for C
      log_g = mm_log_pdf(candidates, 𝜇, Σ, weights, fullCov=fullCov)
      #We now add a constant to g(x) so that g(x) >= f(x), but done in log-space

      diff = log_f-log_g
      log_C_hat = diff[valid].max()
      log_C_best = min(log_C_best*(1+jnp.sign(log_C_best)*0.05), log_C_hat)
      if log_C_hat - log_C_best > 0.05:
        refit = True
      if log_C_hat - log_C > 0.1 and total_accepted > 10: #Big change, lets adjust the model
        refit = True
        regression = log_C_hat - log_C > 0.2
      else:
        regression = False
      log_C = logC_old = jnp.maximum(log_C_hat, log_C)
      U = jnp.log(jax.random.uniform(subkey, (candidates.shape[0],)))
      keya, subkey = jax.random.split(keya)

      # Accept with probability f / (C g), i.e. log U <= log f - log g - log C.
      to_accept = U <= diff-log_C
      cur_samples.append(candidates[to_accept,:])
      cur_samples_f.append(log_f[to_accept].ravel())
      rej_samples.append(candidates[~to_accept,:])
      rej_samples_f.append(log_f[~to_accept].ravel())
      total_accepted += to_accept.sum()

      sampling_time += time.time() - start
      prev_rate = float(to_accept.sum()/float(valid.shape[0]))
      history["logC"] += [float(log_C)]
      history["rate"] += [prev_rate]
      history["samples"] += [int(valid.shape[0])]

      # Number of extra GMM components to fit given x accepted samples.
      sizeToK = lambda x : jnp.minimum(jnp.log(x+1)/jnp.log(2), x/(d*15))

      gmm_try = False
      refine_try = False
      refine_accpt = False

      pbar.update(int(to_accept.sum()))
      𝜇_old, Σ_old, weights_old, log_C_old = 𝜇, Σ, weights, log_C
      𝜇_pre, Σ_pre, weights_pre = 𝜇, Σ, weights

      if ((last_gmm_fit_size*1.5 < total_accepted) and total_accepted >= d**2 and gmm_ban < 0) or regression:
        gmm_try = True
        start = time.time()
        pbar.set_description("Fitting new GMM")
        pbar.refresh()
        #Compress everything into one tensor
        last_gmm_fit_size = total_accepted
        cur_samples = [jnp.vstack(cur_samples)]
        cur_samples_f =  [jnp.hstack(cur_samples_f)]
        rej_samples = [jnp.vstack(rej_samples)]
        rej_samples_f =  [jnp.hstack(rej_samples_f)]
        refit = True

        K = jnp.round(sizeToK(total_accepted)).astype(jnp.int32)

        samples = jnp.concatenate([cur_samples[0], rej_samples[0] ], axis=0)
        # Point weights: softmax of log f, with accepted points up-weighted 10x over rejected ones.
        w = jax.nn.softmax(jnp.hstack([cur_samples_f[0], rej_samples_f[0] ]))
        w = w * jnp.hstack([jnp.ones_like(cur_samples_f[0])*10, jnp.ones_like(rej_samples_f[0])])
        𝜇, Σ, weights = fit_gmm(subkey, samples, int(K+K_orig), w=w, tol=1e-5, fullCov=fullCov)
        keya, subkey = jax.random.split(keya)

        log_C = -1e30
        gmm_time += time.time() - start

      if total_accepted > max(last_fit_sample_size*refit_cycle, d*5):
        refit = True

      start = time.time()
      if refit:
        refine_try = True

        cur_samples = [jnp.vstack(cur_samples)]
        cur_samples_f =  [jnp.hstack(cur_samples_f)]
        rej_samples = [jnp.vstack(rej_samples)]
        rej_samples_f =  [jnp.hstack(rej_samples_f)]

        pbar.set_description("Adjusting the model to rejected samples")
        pbar.refresh() # to show immediately the update

        # Map unconstrained optimisation parameters back to a valid covariance.
        if fullCov:
          def safeSigma(s):
            s = jnp.nan_to_num(s)
            S = s @ s.T
            diag = jnp.diag(S)
            S = fill_diagonal(S, jnp.maximum(diag, 1e-5))
            return S
        else:
          def safeSigma(s):
            return jnp.diag(jnp.exp(s))

        def optGMM(𝜇, Σ, weights, log_C_old):
          # Refine the mixture by gradient descent; return the best (𝜇, Σ, weights, log C) found.
          if fullCov:
            Σ_opt = jax.vmap(lambda X : jnp.linalg.cholesky(X))(Σ)
          else:
            Σ_opt = jax.vmap(lambda X : jnp.log(jnp.diag(X)+1e-16))(Σ)

          g_c = mm_log_pdf(cur_samples[0].reshape(-1, d), 𝜇, Σ, weights, fullCov=fullCov)
          g_r = mm_log_pdf(rej_samples[0].reshape(-1, d), 𝜇, Σ, weights, fullCov=fullCov)

          to_select_size = 10000000
          cur_IDs = getSubsetIDs(subkey, cur_samples_f[0]-g_c, to_select_size)
          rej_IDs = getSubsetIDs(subkey, rej_samples_f[0]-g_r, to_select_size)

          cur_s_f = cur_samples_f[0][cur_IDs]
          cur_s   = cur_samples[0][cur_IDs]
          rej_s_f = rej_samples_f[0][rej_IDs]
          rej_s   = rej_samples[0][rej_IDs]

          #Refinement based on accepted & rejected samples
          @jax.jit
          def lossDiff(args):
            # Softmax-weighted mean of log f - log g: a smooth surrogate for max (i.e. log C).
            m, S, w = args
            S = jax.vmap(safeSigma)(S)
            m = jnp.clip(m, low_val, hi_val)
            w = jnp.nan_to_num(jax.nn.softmax(w), nan=0, neginf=0)
            rej_g = mm_log_pdf(rej_s.reshape(-1, m.shape[-1]), m, S, w, fullCov=fullCov)
            cur_g = mm_log_pdf(cur_s.reshape(-1, m.shape[-1]), m, S, w, fullCov=fullCov)

            diff_rej =  rej_s_f-rej_g
            diff_cur =  cur_s_f-cur_g
            diff_a = jnp.hstack([diff_rej, diff_cur ])
            #Weights of each item
            diff_a_w = jax.nn.softmax(diff_a, axis=0)
            loss = jnp.sum(diff_a_w * diff_a)
            return jnp.nan_to_num(loss, nan=1000.0)

          𝜇_opt = 𝜇
          w_opt = jnp.log(weights)
          logC_new = 1e30
          opt = optax.adabelief(0.1)
          pgLD = jaxopt.OptaxSolver(opt=opt, fun=lossDiff, maxiter=800)
          params = (𝜇_opt, Σ_opt, w_opt)
          state = pgLD.init_state(params)

          best_sol = 𝜇, Σ, weights, log_C_old+1e-16 #No improvement

          for steps in [100, 100, 200, 400]:
            pbar.set_description("Adjusting the model to rejected samples, trying " + str(steps))
            pbar.refresh() # to show immediately the update

            for _ in tqdm(range(steps), desc='Opt', leave=False):
              params, state = pgLD.update(params, state)

            𝜇2, Σ2, w2 = params
            𝜇_opt, Σ_opt, w_opt = 𝜇2, Σ2, w2

            if jnp.isnan(𝜇2).any() or jnp.isnan(w2).any():
              break

            Σ2 = jax.vmap(safeSigma)(Σ2)
            𝜇2 = jnp.clip(𝜇2, low_val, hi_val)
            w2 = jax.nn.softmax(w2)

            #evaluate g(x) on all points where f(x) is known, to see if the bound improved
            rej_g = mm_log_pdf(rej_samples[0].reshape(-1, d), 𝜇2, Σ2, w2, fullCov=fullCov)
            cur_g = mm_log_pdf(cur_samples[0].reshape(-1, d), 𝜇2, Σ2, w2, fullCov=fullCov)
            diff_rej =  rej_samples_f[0]-rej_g
            diff_cur =  cur_samples_f[0]-cur_g
            logC_new = jnp.maximum(diff_rej.max(), diff_cur.max())
            if jnp.isnan(logC_new) :
              break
            if logC_new < best_sol[3]:
              log_C_old = logC_new

              best_sol = 𝜇2, Σ2, w2, logC_new #We did it, we improved the model
          return best_sol

        if gmm_try:
          𝜇_, Σ_, weights_, newC = optGMM(𝜇, Σ, weights, log_C_old)
          if newC < log_C_old:
            refine_accpt = True
            log_C_old = newC
            𝜇, Σ, weights, log_C_old = 𝜇_, Σ_, weights_, newC
        if refit: #gmm failed or no GMM round: also refine the pre-GMM model
          𝜇_, Σ_, weights_, newC = optGMM(𝜇_pre, Σ_pre, weights_pre, log_C_old)
        if newC < log_C_old:
          𝜇, Σ, weights, log_C_old = 𝜇_, Σ_, weights_, newC
          refine_accpt = True
        if not refine_accpt:
          𝜇, Σ, weights, log_C_old = 𝜇_old, Σ_old, weights_old, log_C_old
        log_C = log_C_old

        refit = False
        last_fit_sample_size = total_accepted
      refine_time += time.time() - start

      keya, subkey = jax.random.split(keya)
      history["fits"] += [(gmm_try, refine_try, refine_accpt)]
  cur_samples = [jnp.vstack(cur_samples)]
  cur_samples_f =  [jnp.hstack(cur_samples_f)]
  rej_samples = [jnp.vstack(rej_samples)]
  rej_samples_f =  [jnp.hstack(rej_samples_f)]
  history['time'] = {"gmm": gmm_time, "refine": refine_time, "init":init_time, "sampling":sampling_time}
  return total_accepted/f_eval_total, cur_samples, cur_samples_f, rej_samples, rej_samples_f, (𝜇, Σ, weights), history

# Testing

In [None]:
# Maddison benchmark: 1-D target on [0, inf), sweeping the shape parameter `a`, 10 seeds each.
results_Maddison = {}
for a in tqdm([1, 5, 10, 15, 20], desc="Parameter a"):
  results_Maddison[a] = []
  for seed in tqdm(range(10), desc="Seed", leave=False):
    # The lambda is consumed within this iteration, so late binding of `a` is harmless.
    accpt_rate, samples, samples_f, r, r_f, g_, history = ers(jax.random.PRNGKey(seed), lambda x: log_f_Maddison(x, a), d=1,
      target_samples=100000, samples_at_a_time=500, low_val=0, hi_val=jnp.inf)
    results_Maddison[a].append(float(accpt_rate))

  print("a: ", a)
  print("\t", np.mean(results_Maddison[a]), np.std(results_Maddison[a]))
  print("\t", results_Maddison[a])

In [None]:
# Erraqabi benchmark on the unit box [0, 1]^d for d = 1..7, 10 seeds per dimension.
results_Erraqabi = {}
for d in tqdm(range(1, 8), desc="Dimension Size"): #Testing 1-7
  results_Erraqabi[d] = []
  for seed in tqdm(range(10), desc="Seed", leave=False):
    accpt_rate, samples, samples_f, r, r_f, g_, history = ers(
        jax.random.PRNGKey(seed), log_f_Erraqabi, d=d,
        target_samples=100000, samples_at_a_time=500, low_val=0, hi_val=1)
    results_Erraqabi[d].append(float(accpt_rate))

  rates = results_Erraqabi[d]
  print("d: ", d)
  print("\t", np.mean(rates), np.std(rates))
  print("\t", rates)

In [None]:
# Clutter benchmark for d = 1, 2 with 10 seeds each. Setup follows
# https://github.com/cmaddis/astar-sampling/tree/c65c5ff4cd779d5b528500f14428c4f216e0482c/examples
results_clutter = {}
for d in tqdm([1, 2], desc="Dimension Size"): #Testing 1-2
  results_clutter[d] = []

  # The dataset depends only on d, so build it once per dimension.
  sigma = 2
  pi = 0.5
  num_points = 10
  points = np.concatenate((np.linspace(-5, -3, num_points), np.linspace(2, 4, num_points)))
  data = jnp.array(np.tile(points[:, None], (1, d)))
  clutter = lambda x: log_f_clutter(x, data, sigma=sigma, pi=pi)

  for seed in tqdm(range(10), desc="Seed", leave=False):
    # A leftover `break` used to exit here before ers ran, leaving results empty (NaN stats).
    accpt_rate, samples, samples_f, r, r_f, g_, h = ers(jax.random.PRNGKey(seed), clutter, d=d,
      target_samples=100000, samples_at_a_time=500)
    results_clutter[d].append(float(accpt_rate))

  print("d: ", d)
  print("\t", np.mean(results_clutter[d]), np.std(results_clutter[d]))
  print("\t", results_clutter[d])

