## Jaxifying QVMC

In [108]:
%load_ext autoreload
%autoreload 3

In [1]:
import jax
import jax.numpy as jnp

### Sampling

In [48]:
key = jax.random.key(seed=2)
n_walkers = 4200
n_dim = 3
pos = jax.random.normal(key, (n_walkers, n_dim))

In [3]:
def f(x):
    """Assuming ndim input (x)"""
    r2 = jnp.sqrt(jnp.sum(x**2))
    return jnp.mean(x) * jnp.exp(1/r2)

f_b = jax.vmap(f)  # batched network

In [225]:
# Metropolis
from jax import lax

def mh_accept(x1, x2, lp_1, lp_2, ratio, key, num_accepts):
  """Given state, proposal, and probabilities, execute MH accept/reject step."""
  key, subkey = jax.random.split(key)
  rnd = jnp.log(jax.random.uniform(subkey, shape=ratio.shape))
  cond = ratio > rnd
  cond = cond[:, None]
  x_new = jnp.where(cond, x2, x1)
  lp_new = jnp.where(cond, lp_2, lp_1)
  num_accepts += jnp.sum(cond)
  return x_new, key, lp_new, num_accepts

def metropolis_step(
      params, f, data, key, mcmc_width=0.12):
    
    n_steps = 10

    def step_fn(i, carry):
      data, key, num_accepts = carry
      key, subkey = jax.random.split(key)

      x1 = data
      x2 = x1 + mcmc_width * jax.random.normal(
        subkey, shape=x1.shape)

      lp_1 = 2.0 * f(params, x1)
      lp_2 = 2.0 * f(params, x2)
      ratio = lp_2 - lp_1
      
      x_new, key, lp_new, num_accepts = mh_accept(
        x1, x2, lp_1, lp_2, ratio, key, num_accepts)
      
      return x_new, key, num_accepts
      
    new_data, key, num_accepts = lax.fori_loop(
        0, n_steps, step_fn, (data, key, jnp.array(0))
    )
    pmove = num_accepts / (n_steps * data.shape[0])

    return new_data, pmove

In [226]:
burn_in = 3
for i in range(burn_in):
   x_new, pmove =  metropolis_step(params, f_b, pos, key, mcmc_width=0.42)

In [227]:
pmove

Array(0.9984286, dtype=float32)

In [50]:
x_new.shape

(4200, 3)

In [23]:
metropolis_step(pos, f_b, key)

(Array([[ 0.41309845,  1.2854261 , -0.78039163],
        [ 1.2078618 , -0.2172334 , -0.79383266],
        [ 0.6312194 , -0.89359784,  0.71911615],
        ...,
        [-1.6597576 , -0.46819875, -0.65987235],
        [-0.5324503 , -0.01884888, -1.8873758 ],
        [-0.3821253 ,  0.17635551, -0.74504066]], dtype=float32),
 Array((), dtype=key<fry>) overlaying:
 [1674946514 3241262372],
 Array([ 1.1644934 ,  0.18934496,  0.6212964 , ..., -3.1774983 ,
        -2.6770978 , -2.130944  ], dtype=float32),
 Array(4000, dtype=int32))

### Network

In [189]:
from networks import MLP, psi_nn

In [190]:
model = MLP(input_dim=3, n_hidden_layers=2, hidden_dim=3, output_size=1)

In [191]:
x = jax.random.normal(key, 3)

In [192]:
variables = model.init(key, jax.random.normal(key, 3))
params = variables["params"]

In [229]:
def f(params, x):
    nn_out = model.apply({"params": params}, x)
    r = jnp.linalg.norm(x)
    # Hydrogenic term (example: Z=2)
    hydrogenic = -1.0 * r
    return jnp.squeeze(nn_out) + hydrogenic

In [206]:
f_b = jax.vmap(f, in_axes=(None, 0), out_axes=0)
f_b(params, jax.random.normal(key, (200,3)))

Array([ 0.0258584 , -0.03850148,  0.00148375,  0.01763745,  0.02979456,
        0.0151001 , -0.07553628, -0.0882896 , -0.08894154, -0.08137566,
        0.06084321, -0.02863689, -0.00715178,  0.05185672, -0.02112361,
        0.04854413, -0.04653417,  0.07055479, -0.05292021,  0.07034159,
        0.01887535,  0.06965564,  0.05661403, -0.00995859,  0.05106074,
       -0.05164292, -0.0085656 ,  0.06881385,  0.04231528,  0.03084136,
       -0.05243913, -0.06109716, -0.00978027, -0.04573929,  0.07336047,
        0.06626271,  0.02963621,  0.06190407, -0.02899716, -0.08802418,
        0.01983056,  0.08222185, -0.00862097,  0.00100727,  0.08454274,
       -0.06329973, -0.07004205, -0.06267394,  0.03833655, -0.06802086,
       -0.09524848,  0.08415908,  0.03329285,  0.03897749,  0.04141916,
        0.08084417, -0.06532656, -0.05613819, -0.01659333,  0.04253322,
       -0.07285166, -0.0250865 , -0.04652925,  0.08676142, -0.04791211,
        0.01011629,  0.07735635,  0.01126209, -0.05565479, -0.07

### Local Energy

In [207]:
def _lapl_over_f(params, data):
    n = data.shape[0]
    eye = jnp.eye(n)

    grad_f = jax.grad(f, argnums=1)
    
    def grad_f_closure(x):
        return grad_f(params, x)

    primal, dgrad_f = jax.linearize(grad_f_closure, data)

    hessian_diagonal = lambda i: dgrad_f(eye[i])[i]

    result = -0.5 * lax.fori_loop(
        0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
    return result - 0.5 * jnp.sum(primal ** 2)

In [208]:
from local_energy import get_local_energy_fn

In [209]:
local_energy = get_local_energy_fn(f)

In [210]:
e_loc = local_energy(params, pos)
mean_energy = jnp.mean(e_loc)

In [211]:
e_loc

Array([-0.02742211, -0.02419916, -0.02406534, ..., -0.08091354,
       -0.01965621, -0.02805981], dtype=float32)

In [212]:
mean_energy

Array(-0.04365415, dtype=float32)

In [213]:
jnp.mean((e_loc - mean_energy)**2)

Array(0.00097862, dtype=float32)

### Optimization

In [221]:
# not doing it via loss and grad

def energy_loss(params, data):
    eloc = local_energy(params, data)   # (B,)
    return jnp.mean(eloc)


def energy_grads(params, data):
    # ----- forward pass -----
    eloc = local_energy(params, data)          # (B,)
    loss = jnp.mean(eloc)
    diff = eloc - loss                         # (B,)

    # ----- log Ïˆ -----
    log_psi = lambda p: jnp.log(f_b(p, data))  # (B,)

    # ----- VJP -----
    _, vjp_fn = jax.vjp(log_psi, params)

    # Contract batch dimension first
    cotangent = diff / diff.shape[0]            # (B,)

    grads = vjp_fn(cotangent)[0]

    return loss, grads


In [222]:
pos.shape

(4200, 3)

In [223]:
flat_params, ravel_fn = ravel_pytree(params)

In [224]:
loss, grads = energy_grads(params, pos)

In [None]:
from jax.flatten_util import ravel_pytree

In [149]:
grad_flat, unravel_fn = ravel_pytree(grads)

In [141]:
energy_grads(params, pos)

(Array(-1.0010053, dtype=float32),
 {'Dense_0': {'bias': Array([ 0.03250854, -0.01686911,  0.02931045], dtype=float32),
   'kernel': Array([[ 0.0003809 , -0.00022431,  0.00037943],
          [ 0.00238728, -0.00078759,  0.00232838],
          [-0.00251525,  0.00076611, -0.00235071]], dtype=float32)},
  'Dense_1': {'bias': Array([-0.65068114,  0.6625924 , -0.21817963], dtype=float32),
   'kernel': Array([[-0.02279942,  0.02266682, -0.0076655 ],
          [ 0.00333157, -0.00304474,  0.00200873],
          [ 0.02362109, -0.02323238,  0.00827753]], dtype=float32)},
  'Dense_2': {'bias': Array([0.833125], dtype=float32),
   'kernel': Array([[0.02169407],
          [0.0214928 ],
          [0.00057703]], dtype=float32)}})

In [232]:
import optax
from jax.flatten_util import ravel_pytree

# ------------------
# Optimizer
# ------------------
learning_rate = 2e-2
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

# ------------------
# Burn-in
# ------------------
burn_in = 10
for i in range(burn_in):
    pos, pmove = metropolis_step(
        params, f_b, pos, key, mcmc_width=0.12
    )

# ------------------
# Optimization loop
# ------------------
nsteps = 100
for step in range(nsteps):

    # --- MCMC move ---
    pos, pmove = metropolis_step(
        params, f_b, pos, key, mcmc_width=0.12
    )

    # --- Energy + gradients ---
    loss, grads = energy_grads(params, pos)

    # --- Optax update ---
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    print(
        f"Step {step:04d} | "
        f"Energy {loss:.6f} | "
        f"Acceptance {pmove:.3f}"
    )


Step 0000 | Energy -0.022312 | Acceptance 0.999
Step 0001 | Energy -0.023441 | Acceptance 0.999
Step 0002 | Energy -0.022859 | Acceptance 0.999
Step 0003 | Energy -0.022845 | Acceptance 0.998
Step 0004 | Energy -0.023420 | Acceptance 0.999
Step 0005 | Energy -0.023349 | Acceptance 0.999
Step 0006 | Energy -0.022783 | Acceptance 0.999
Step 0007 | Energy -0.023073 | Acceptance 0.999
Step 0008 | Energy -0.022771 | Acceptance 0.998
Step 0009 | Energy -0.023299 | Acceptance 0.998
Step 0010 | Energy -0.023575 | Acceptance 0.999
Step 0011 | Energy -0.023139 | Acceptance 0.998
Step 0012 | Energy -0.022602 | Acceptance 0.998
Step 0013 | Energy -0.023437 | Acceptance 0.998
Step 0014 | Energy -0.023166 | Acceptance 0.998
Step 0015 | Energy -0.022666 | Acceptance 0.998
Step 0016 | Energy -0.021841 | Acceptance 0.998
Step 0017 | Energy -0.021418 | Acceptance 0.998
Step 0018 | Energy -0.021625 | Acceptance 0.998
Step 0019 | Energy -0.021572 | Acceptance 0.998
Step 0020 | Energy -0.021181 | Acceptanc

KeyboardInterrupt: 