In [1]:
import jax
import jax.numpy  as jnp
import jax.random as jrnd
import matplotlib.pyplot as plt

In [2]:
def elliptical_slice(x0, log_lh_func, chol, num_samples, rng_key):
  """Run an elliptical slice sampler (Murray, Adams & MacKay, 2010).

  The target is proportional to N(x; 0, chol @ chol.T) * exp(log_lh_func(x)):
  a zero-mean Gaussian prior times a likelihood. Each step proposes points on
  the ellipse x*cos(theta) + nu*sin(theta), with nu a fresh prior draw, and
  shrinks the angle bracket until the proposal lands above the slice level.

  Args:
    x0: initial state. Assumed 1-D with length chol.shape[0] (TODO confirm for
      other shapes). Integer/bool inputs are promoted to the default float
      dtype, because the scan carry must keep a fixed dtype.
    log_lh_func: JAX-traceable function mapping a state to a scalar
      log-likelihood.
    chol: Cholesky factor of the prior covariance. Assumed lower-triangular
      (the convention of jnp.linalg.cholesky), so prior draws are chol @ z
      with z ~ N(0, I).
    num_samples: number of sampler steps, i.e. number of returned samples.
    rng_key: JAX PRNG key.

  Returns:
    Array of shape (num_samples, *x0.shape). Row i is the state after step
    i + 1; x0 itself is not included.
  """
  x0 = jnp.asarray(x0)
  if not jnp.issubdtype(x0.dtype, jnp.inexact):
    # The rotation below always produces floats, and lax.scan rejects a carry
    # whose dtype changes between input and output.
    x0 = x0.astype(jnp.result_type(float))
  chol = jnp.asarray(chol)

  def ess_step_condfun(state):
    # Keep shrinking while the current proposal lies below the slice level.
    _, new_x, _, thresh, _, _, _ = state
    return log_lh_func(new_x) < thresh

  def ess_step_bodyfun(state):
    x, new_x, nu, thresh, lower, upper, rng_key = state
    theta_rng, rng_key = jrnd.split(rng_key, 2)
    theta = jrnd.uniform(theta_rng, minval=lower, maxval=upper)
    new_x = x*jnp.cos(theta) + nu*jnp.sin(theta)
    # Shrink the bracket toward theta = 0 (the current state). This only has
    # an effect if the next condition check rejects new_x.
    lower = jnp.where(theta < 0, theta, lower)
    upper = jnp.where(theta < 0, upper, theta)
    return x, new_x, nu, thresh, lower, upper, rng_key

  def ess_step(x, rng_key):
    nu_rng, u_rng, theta_rng, rng_key = jrnd.split(rng_key, 4)
    # Auxiliary draw from the prior N(0, chol @ chol.T). Cast to x's dtype so
    # the while_loop / scan carries keep a stable type.
    nu = (chol @ jrnd.normal(nu_rng, shape=x.shape)).astype(x.dtype)
    # Slice level log(u * L(x)) with u ~ U(0, 1).
    u = jrnd.uniform(u_rng)
    thresh = log_lh_func(x) + jnp.log(u)
    # The initial angle defines the bracket [theta - 2*pi, theta], which contains 0.
    theta = jrnd.uniform(theta_rng, minval=0, maxval=2*jnp.pi)
    upper = theta
    lower = theta - 2*jnp.pi
    new_x = x*jnp.cos(theta) + nu*jnp.sin(theta)
    _, new_x, _, _, _, _, _ = jax.lax.while_loop(
      ess_step_condfun,
      ess_step_bodyfun,
      (x, new_x, nu, thresh, lower, upper, rng_key)
    )
    return new_x

  def scanfunc(state, _):
    # One Markov-chain step. The new state is both the carry and the output.
    x, rng_key = state
    step_key, rng_key = jrnd.split(rng_key, 2)
    x = ess_step(x, step_key)
    return (x, rng_key), x

  _, samples = jax.lax.scan(scanfunc, (x0, rng_key), None, length=num_samples)

  return samples

In [3]:
# Fixed seed so the sampler output below is reproducible on Restart & Run All.
SEED = 1
rng_key = jrnd.PRNGKey(SEED)

In [4]:
elliptical_slice(
  x0=jnp.ones(3),
  # log of an indicator: 0 when every coordinate is positive, -inf otherwise
  log_lh_func=lambda x: jnp.log(jnp.all(x > 0)),
  chol=jnp.eye(3),
  num_samples=50,
  rng_key=rng_key,
)

DeviceArray([[1.035393  , 0.94472396, 0.9813875 ],
             [0.5938187 , 0.93434364, 1.3151302 ],
             [0.50829834, 0.74971354, 1.3903047 ],
             [1.0801375 , 1.1392491 , 0.66086346],
             [0.8132634 , 1.3561089 , 0.11626202],
             [0.74596715, 0.25006914, 0.22099783],
             [0.26355907, 0.19394132, 0.42499414],
             [0.40187183, 0.44570813, 0.37252545],
             [0.2937717 , 0.757076  , 0.00482692],
             [0.6341446 , 0.0391647 , 0.1550641 ],
             [0.135018  , 0.9054813 , 1.4817276 ],
             [0.16691567, 0.7839114 , 1.4554222 ],
             [0.45004037, 0.6983323 , 0.21856627],
             [0.32851306, 0.8990807 , 0.8106946 ],
             [0.46958178, 0.945887  , 0.8042547 ],
             [1.035112  , 1.0443119 , 0.37906447],
             [1.1944005 , 1.4324472 , 0.10578489],
             [1.7111918 , 1.58958   , 0.47994557],
             [0.50697035, 1.5274148 , 1.1078554 ],
             [0.7584119 , 1.171