## Jaxifying QVMC

In [1]:
import jax
import jax.numpy as jnp

In [3]:
key = jax.random.key(seed=2)

In [8]:
n_walkers = 4000
n_dim = 3

In [10]:
pos = jax.random.normal(key, (n_walkers, n_dim))

In [13]:
def f(x):
    """Assuming ndim input (x)"""
    r2 = jnp.sqrt(jnp.sum(x**2))
    return jnp.mean(x) * jnp.exp(1/r2) 

In [16]:
f_b = jax.vmap(f)

In [None]:
# Metropolis
from jax import lax

def mh_accept(x1, x2, lp_1, lp_2, ratio, key, num_accepts):
  """Given state, proposal, and probabilities, execute MH accept/reject step."""
  key, subkey = jax.random.split(key)
  rnd = jnp.log(jax.random.uniform(subkey, shape=ratio.shape))
  cond = ratio > rnd
  x_new = jnp.where(cond[..., None], x2, x1)
  lp_new = jnp.where(cond, lp_2, lp_1)
  num_accepts += jnp.sum(cond)
  return x_new, key, lp_new, num_accepts

def metropolis_step(params, f, data, key, mcmc_width=0.12, num_accepts=0):
    
    n_steps = 10

    def step_fn(x1, key):
      key, subkey = jax.random.split(key)

      x1 = data
      mcmc_width = 0.02
      x2 = x1 + mcmc_width * jax.random.normal(
        subkey, shape=x1.shape)

      lp_1 = 2.0 * f(pos)
      lp_2 = 2.0 * f(pos)
      ratio = lp_2 - lp_1
      
      x_new, key, lp_new, num_accepts = mh_accept(
        x1, x2, lp_1, lp_2, ratio, key, num_accepts)
      
    new_data, key, _, num_accepts = lax.fori_loop(
        0, n_steps, step_fn, (data, key)
    )
    pmove = jnp.sum(num_accepts) / (nsteps)

    return new_data, pmove

In [27]:
burn_in = 10000
for i in range(burn_in):
   pos, key, lp_new, num_accepts =  metropolis_step(pos, f_b, key, mcmc_width=0.12, num_accepts=0)

In [23]:
metropolis_step(pos, f_b, key)

(Array([[ 0.41309845,  1.2854261 , -0.78039163],
        [ 1.2078618 , -0.2172334 , -0.79383266],
        [ 0.6312194 , -0.89359784,  0.71911615],
        ...,
        [-1.6597576 , -0.46819875, -0.65987235],
        [-0.5324503 , -0.01884888, -1.8873758 ],
        [-0.3821253 ,  0.17635551, -0.74504066]], dtype=float32),
 Array((), dtype=key<fry>) overlaying:
 [1674946514 3241262372],
 Array([ 1.1644934 ,  0.18934496,  0.6212964 , ..., -3.1774983 ,
        -2.6770978 , -2.130944  ], dtype=float32),
 Array(4000, dtype=int32))