## Jaxifying QVMC

In [108]:
%load_ext autoreload
%autoreload 3

In [1]:
import jax
import jax.numpy as jnp

### Sampling

In [48]:
key = jax.random.key(seed=2)
n_walkers = 4200
n_dim = 3
pos = jax.random.normal(key, (n_walkers, n_dim))

In [3]:
def f(x):
    """Assuming ndim input (x)"""
    r2 = jnp.sqrt(jnp.sum(x**2))
    return jnp.mean(x) * jnp.exp(1/r2)

f_b = jax.vmap(f)  # batched network

In [None]:
# Metropolis
from jax import lax

def mh_accept(x1, x2, lp_1, lp_2, ratio, key, num_accepts):
  """Given state, proposal, and probabilities, execute MH accept/reject step."""
  key, subkey = jax.random.split(key)
  rnd = jnp.log(jax.random.uniform(subkey, shape=ratio.shape))
  cond = ratio > rnd
  x_new = jnp.where(cond, x2, x1)
  lp_new = jnp.where(cond, lp_2, lp_1)
  num_accepts += jnp.sum(cond)
  return x_new, key, lp_new, num_accepts

def metropolis_step(
      params, f, data, key, mcmc_width=0.12, num_accepts=jnp.array(0)):
    
    n_steps = 10

    def step_fn(i, carry):
      data, key, num_accepts = carry
      key, subkey = jax.random.split(key)

      x1 = data
      x2 = x1 + mcmc_width * jax.random.normal(
        subkey, shape=x1.shape)

      lp_1 = 2.0 * f(params, x1)
      lp_2 = 2.0 * f(params, x2)
      ratio = lp_2 - lp_1
      
      x_new, key, lp_new, num_accepts = mh_accept(
        x1, x2, lp_1, lp_2, ratio, key, num_accepts)
      
      return x_new, key, num_accepts
      
    new_data, key, num_accepts = lax.fori_loop(
        0, n_steps, step_fn, (data, key, num_accepts)
    )
    pmove = num_accepts / (n_steps * data.shape[0])

    return new_data, pmove

In [None]:
burn_in = 3
for i in range(burn_in):
   x_new, pmove =  metropolis_step(params, f_b, pos, key, mcmc_width=0.42)

In [54]:
pmove

Array(0.9860238, dtype=float32)

In [50]:
x_new.shape

(4200, 3)

In [23]:
metropolis_step(pos, f_b, key)

(Array([[ 0.41309845,  1.2854261 , -0.78039163],
        [ 1.2078618 , -0.2172334 , -0.79383266],
        [ 0.6312194 , -0.89359784,  0.71911615],
        ...,
        [-1.6597576 , -0.46819875, -0.65987235],
        [-0.5324503 , -0.01884888, -1.8873758 ],
        [-0.3821253 ,  0.17635551, -0.74504066]], dtype=float32),
 Array((), dtype=key<fry>) overlaying:
 [1674946514 3241262372],
 Array([ 1.1644934 ,  0.18934496,  0.6212964 , ..., -3.1774983 ,
        -2.6770978 , -2.130944  ], dtype=float32),
 Array(4000, dtype=int32))

### Network

In [4]:
from networks import MLP, psi_nn

In [5]:
model = MLP(input_dim=3, n_hidden_layers=2, hidden_dim=3, output_size=1)

In [12]:
x = jax.random.normal(key, 3)

In [8]:
variables = model.init(key, jax.random.normal(key, 3))
params = variables["params"]

In [79]:
def f(params, x):
    nn_out = model.apply({"params": params}, x)
    r = jnp.linalg.norm(x)
    # Hydrogenic term (example: Z=2)
    hydrogenic = -2.0 * r
    return jnp.squeeze(nn_out) + hydrogenic

In [None]:
f_b = jax.vmap(f, in_axes=(None, 0), out_axes=0)
f_b(params, jax.random.normal(key, (200,3)))

### Local Energy

In [74]:
x = jax.random.normal(key, 3)

In [92]:
def _lapl_over_f(params, data):
    n = data.shape[0]
    eye = jnp.eye(n)

    grad_f = jax.grad(f, argnums=1)
    
    def grad_f_closure(x):
        return grad_f(params, x)

    primal, dgrad_f = jax.linearize(grad_f_closure, data)

    hessian_diagonal = lambda i: dgrad_f(eye[i])[i]

    result = -0.5 * lax.fori_loop(
        0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
    return result - 0.5 * jnp.sum(primal ** 2)

In [100]:
vmapped_ke = jax.vmap(_lapl_over_f, in_axes=(None, 0), out_axes=0)

In [102]:
vmapped_ke(params, pos)

Array([-0.67479753, -0.7058146 , -0.45219326, ...,  0.9200618 ,
       -0.12960386, -1.2327738 ], dtype=float32)

In [103]:
def pe(x):
    r = jnp.linalg.norm(x)
    return -1/r

In [104]:
def te(params, x):
    return pe(x) + _lapl_over_f(params, x)

In [115]:
local_energy = jax.vmap(_lapl_over_f, in_axes=(None, 0), out_axes=0)

In [117]:
jnp.mean(local_energy(params, pos))

Array(-0.404899, dtype=float32)

In [121]:
from local_energy import get_local_energy_fn

In [124]:
local_energy = get_local_energy_fn(f)

In [125]:
jnp.mean(local_energy(params, pos))

Array(-1.0010053, dtype=float32)

### Optimization