## JAX VMC
Code taken from - https://teddykoker.com/2024/11/neural-vmc-jax/

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import math

### Local energy

In [None]:
def local_energy(wavefunction, atoms, charges, pos):
    return kinetic_energy(wavefunction, pos) + potential_energy(atoms, charges, pos)

def kinetic_energy(wavefunction, pos):
    """Kinetic energy term of Hamiltonian"""
    laplacian = jnp.trace(jax.hessian(wavefunction)(pos))
    return -0.5 * laplacian / wavefunction(pos)

def potential_energy(pos):

    r1 = jnp.linalg.norm(pos[:, :, :3], axis=-1)
    r2 = jnp.linalg.norm(pos[:, :, 3:], axis=-1)
    r12 = jnp.linalg.norm(pos[:, :, :3] - pos[:, :, 3:], axis=-1)
    eps = 1e-8
    return -2 / (r1 + eps) - 2 / (r2 + eps) + 1 / (r12 + eps)

### Metropolis sampling

In [3]:
import equinox as eqx
from functools import partial
from collections.abc import Callable

@eqx.filter_jit
@partial(jax.vmap, in_axes=(None, 0, None, None, 0))
def metropolis(
    wavefunction: Callable,
    pos: jax.Array,
    step_size: float,
    mcmc_steps: int,
    key: jax.Array,
):
    """MCMC step

    Args:
        wavefunction: neural wavefunction
        pos: [3N] current electron positions flattened
        step_size: std of proposal for metropolis sampling
        mcmc_steps: number of steps to perform
        key: random key
    """

    def step(_, carry):
        pos, prob, num_accepts, key = carry
        key, subkey = jax.random.split(key)
        pos_proposal = pos + step_size * jax.random.normal(subkey, shape=pos.shape)
        prob_proposal = wavefunction(pos_proposal) ** 2

        key, subkey = jax.random.split(key)
        accept = jax.random.uniform(subkey) < prob_proposal / prob
        prob = jnp.where(accept, prob_proposal, prob)
        pos = jnp.where(accept, pos_proposal, pos)
        num_accepts = num_accepts + jnp.sum(accept)

        return pos, prob, num_accepts, key

    prob = wavefunction(pos) ** 2
    carry = (pos, prob, 0, key)
    pos, prob, num_accepts, key = jax.lax.fori_loop(0, mcmc_steps, step, carry)
    return pos, num_accepts / mcmc_steps

In [6]:
def wavefunction_h(pos):
    return jnp.exp(-jnp.linalg.norm(pos))

### Neural network, gradients

In [7]:
def make_loss(atoms, charges):
    # Based on implementation in https://github.com/google-deepmind/ferminet/

    @eqx.filter_custom_jvp
    def total_energy(wavefunction, pos):
        """Define L()"""
        batch_local_energy = jax.vmap(local_energy, (None, None, None, 0))
        e_l = batch_local_energy(wavefunction, atoms, charges, pos)
        loss = jnp.mean(e_l)
        return loss, e_l

    @total_energy.def_jvp
    def total_energy_jvp(primals, tangents):
        """Define the gradient of L()"""
        wavefunction, pos = primals
        log_wavefunction = lambda psi, pos: jnp.log(psi(pos))
        batch_wavefunction = jax.vmap(log_wavefunction, (None, 0))
        psi_primal, psi_tangent = eqx.filter_jvp(batch_wavefunction, primals, tangents)
        loss, local_energy = total_energy(wavefunction, pos)
        primals_out = loss, local_energy
        batch_size = jnp.shape(local_energy)[0]
        tangents_out = (jnp.dot(psi_tangent, local_energy - loss) / batch_size, local_energy)
        return primals_out, tangents_out

    return total_energy

In [8]:
class Linear(eqx.Module):
    """Linear layer"""

    weights: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        lim = math.sqrt(1 / (in_size + out_size))
        self.weights = jax.random.uniform(key, (in_size, out_size), minval=-lim, maxval=lim)
        self.bias = jnp.zeros(out_size)

    def __call__(self, x):
        return jnp.dot(x, self.weights) + self.bias


class PsiMLP(eqx.Module):
    """Simple MLP-based model using Slater determinant"""

    spins: tuple[int, int]
    linears: list[Linear]
    orbitals: Linear
    sigma: jax.Array 
    pi: jax.Array

    def __init__(
        self,
        hidden_sizes: list[int],
        spins: tuple[int, int],
        determinants: int,
        key: jax.Array,
    ):
        num_atoms = 1  # assume one atom
        sizes = [5] + hidden_sizes  # 5 input features
        key, *keys = jax.random.split(key, len(sizes))
        self.linears = []
        for i in range(len(sizes) - 1):
            self.linears.append(Linear(sizes[i], sizes[i + 1], keys[i]))
        self.orbitals = Linear(sizes[-1], sum(spins) * determinants, key)
        self.sigma = jnp.ones((num_atoms, sum(spins) * determinants))
        self.pi = jnp.ones((num_atoms, sum(spins) * determinants))
        self.spins = spins

    def __call__(self, pos):
        # atom electron displacement [electron, atom, 3]
        ae = pos.reshape(-1, 1, 3)
        # atom electron distance [electron, atom, 1]
        r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True)
        # feature for spins; 1 for up, -1 for down [atom, 1]
        spins = jnp.concatenate([jnp.ones(self.spins[0]), jnp.ones(self.spins[1]) * -1])

        # combine into features
        h = jnp.concatenate([ae, r_ae], axis=2)
        h = h.reshape([h.shape[0], -1])
        h = jnp.concatenate([h, spins[:, None]], axis=1)

        # multi-layer perceptron with tanh activations
        for linear in self.linears:
            h = jnp.tanh(linear(h))

        phi = self.orbitals(h) * jnp.sum(self.pi * jnp.exp(-self.sigma * r_ae), axis=1)

        # [electron, electron * determinants] -> [determinants, electron, electron]
        phi = phi.reshape(phi.shape[0], -1, phi.shape[0]).transpose(1, 0, 2)
        det = jnp.linalg.det(phi)
        return jnp.sum(det)

### Main function

In [9]:
import optax
from tqdm import tqdm

def vmc(
    wavefunction: Callable,
    atoms: jax.Array,
    charges: jax.Array,
    spins: tuple[int, int],
    *,
    batch_size: int = 4096,
    mcmc_steps: int = 50,
    warmup_steps: int = 200,
    init_width: float = 0.4,
    step_size: float = 0.2,
    learning_rate: float = 3e-3,
    iterations: int = 2_000,
    key: jax.Array,
):
    """Perform variational Monte Carlo

    Args:
        wavefunction: neural wavefunction
        atoms: [M, 3] atomic positions
        charges: [M] atomic charges
        spins: number spin-up, spin-down electrons
        batch_size: number of electron configurations to sample
        mcmc_steps: number of mcmc steps to perform between neural network
            updates (lessens autocorrelation)
        warmup_steps: number of mcmc steps to perform before starting training
        step_size: std of proposal for metropolis sampling
        learning_rate: learning rate
        iterations: number of neural network updates
        key: random key
    """
    total_energy = make_loss(atoms, charges)

    # initialize electron positions and perform warmup mcmc steps
    key, subkey = jax.random.split(key)
    pos = init_width * jax.random.normal(subkey, shape=(batch_size, sum(spins) * 3))
    key, *subkeys = jax.random.split(key, batch_size + 1)
    pos, _ = metropolis(wavefunction, pos, step_size, warmup_steps, jnp.array(subkeys))

    # Adam optimizer with gradient clipping
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(learning_rate))
    opt_state = optimizer.init(eqx.filter(wavefunction, eqx.is_array))

    @eqx.filter_jit
    def train_step(wavefunction, pos, key, opt_state):
        key, *subkeys = jax.random.split(key, batch_size + 1)
        pos, accept = metropolis(wavefunction, pos, step_size, mcmc_steps, jnp.array(subkeys))
        (loss, _), grads = eqx.filter_value_and_grad(total_energy, has_aux=True)(wavefunction, pos)
        updates, opt_state = optimizer.update(grads, opt_state, wavefunction)
        wavefunction = eqx.apply_updates(wavefunction, updates)
        return wavefunction, pos, key, opt_state, loss, accept

    losses, pmoves = [], []
    pbar = tqdm(range(iterations))
    for _ in pbar:
        wavefunction, pos, key, opt_state, loss, pmove = train_step(wavefunction, pos, key, opt_state)
        pmove = pmove.mean()
        losses.append(loss)
        pmoves.append(pmove)
        pbar.set_description(f"Energy: {loss:.4f}, P(move): {pmove:.2f}")

    return losses, pmoves

In [10]:
# Lithium at origin
atoms = jnp.zeros((1, 2))
charges = jnp.array([3.0])
spins = (2, 1) # 2 spin-up, 1 spin-down electrons

In [13]:
pos = 0.4 * jax.random.normal(key, shape=(4096, sum(spins) * 3))

In [15]:
pos.shape

(4096, 9)

In [None]:
key = jax.random.key(0)
key, subkey = jax.random.split(key)
model = PsiMLP(hidden_sizes=[64, 64, 64], determinants=4, spins=spins, key=key)
losses, _ = vmc(wavefunction_h, atoms, charges, spins, key=subkey)

Energy: -6.3426, P(move): 0.61:   2%|▏         | 46/2000 [00:38<27:15,  1.20it/s] 


KeyboardInterrupt: 