In [70]:
import numpy as np
import scipy as sp

from jax import config

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import jax
from jax import vmap, jit, random, tree_util, numpy as jnp, grad, jvp, lax, hessian
import matplotlib.pyplot as plt
import pickle
from functools import reduce

from flax import linen as nn
from functools import partial

print = partial(print, flush=True)

import math
import itertools
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union, Any, Tuple

from abc import ABC, abstractmethod
from enum import Enum
import pytest

np.set_printoptions(suppress=True, precision=5)

<div style="background-color: white; color: black; padding: 14px; font-size: 16px;">

Taking the log of a wave function that can be negative produces complex values, which JAX's `grad` will complain about. So we keep track of the sign and the magnitude separately. Away from the nodes the sign is locally constant, so its gradient is zero; and walkers are never exactly on a node, so we can work with the log of the absolute value alone. Write $\psi = s e^f$, where $f = \log|\psi|$ and $s = \text{sign}(\psi)$. Then

$$\nabla \psi = s e^f \nabla f$$

$$\nabla^2 \psi = s e^f [\nabla^2 f + |\nabla f|^2]$$

$$\frac{\nabla^2 \psi}{\psi} = \nabla^2 f + |\nabla f|^2$$

</div>


In [179]:
@dataclass
class ueg:
    """Uniform electron gas in a cubic periodic box.

    Only unpolarized (n_up == n_dn) 3D systems are supported. The box length
    follows from L^3 = (4/3) pi r_s^3 N, and the other geometry fields are
    derived from it in __post_init__.
    """

    r_s: float
    n_elec: Tuple[int, int]
    # box_length, rec_lattice, volume, n_particles and density are computed in
    # __post_init__; values passed to the constructor are overwritten.
    box_length: float = 0.0
    rec_lattice: Tuple = ()
    dim: int = 3
    volume: float = 0.0
    n_particles: int = 0
    density: float = 0.0
    seed: int = 0  # seeds the PRNG key used by init_walker_data

    def __post_init__(self):
        assert self.dim == 3, "Only 3D systems are supported."
        assert (
            self.n_elec[0] == self.n_elec[1]
        ), "Only unpolarized systems are supported."
        self.box_length = (
            4 / 3 * jnp.pi * self.r_s**3 * (self.n_elec[0] + self.n_elec[1])
        ) ** (1 / 3)
        self.rec_lattice = (2 * jnp.pi / self.box_length,) * 3
        self.volume = self.box_length**3
        self.n_particles = self.n_elec[0] + self.n_elec[1]
        self.density = self.n_particles / self.volume

    def get_occ_k_points(self) -> jax.Array:
        """Get the occupied k-points for one spin channel.

        Enumerates reciprocal-lattice vectors on an integer grid, sorts them by
        norm (Python's sort is stable, so equal-norm vectors keep their grid
        order) and keeps every vector up to the norm of the n_elec[0]-th one.
        For an open shell this returns more than n_elec[0] k-points.
        """
        dk = 1 + 1e-5
        max_k = int(jnp.ceil(self.n_elec[0] * dk) ** (1 / 3.0))
        ordinals = sorted(range(-max_k, max_k + 1), key=abs)
        ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3)))
        kpoints = ordinals @ (jnp.array(self.rec_lattice) * jnp.eye(3)).T
        kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm))
        k_norms = jnp.linalg.norm(kpoints, axis=1)
        return kpoints[k_norms <= k_norms[self.n_elec[0] - 1] * dk]

    @partial(jit, static_argnums=(0,))
    def _calc_dis(self, pos: jax.Array) -> Tuple:
        """Minimum-image pair distances and displacements for one walker.

        Args:
            pos: positions, shape (2, n_elec_s, 3); pos[0] up, pos[1] down.

        Returns:
            (dist, disp): disp[i, j] = r_i - r_j wrapped to the minimum image,
            shape (n_particles, n_particles, 3), with up particles first;
            dist = |disp| with an exactly zero diagonal.
        """
        box_length = jnp.array([self.box_length, self.box_length, self.box_length])
        pos_flat = jnp.concatenate([pos[0], pos[1]], axis=0)
        n_particles = pos_flat.shape[0]

        disp = pos_flat[:, None, :] - pos_flat[None, :, :]
        disp = disp - box_length * jnp.round(disp / box_length)

        r2 = jnp.sum(disp**2, axis=-1)
        mask = ~jnp.eye(n_particles, dtype=bool)
        # Double-where: sqrt has an infinite derivative at 0, so give the
        # diagonal a dummy value before the sqrt and zero it afterwards. This
        # keeps gradients finite without biasing off-diagonal distances (an
        # additive epsilon inside the sqrt distorts small separations).
        safe_r2 = jnp.where(mask, r2, 1.0)
        dist = jnp.where(mask, jnp.sqrt(safe_r2), 0.0)
        return dist, disp

    def init_walker_data(self, n_walkers: int) -> dict:
        """Uniformly random walkers in the box plus their distance data.

        Returns a dict with "pos" (n_walkers, 2, n_elec_s, 3), "dist", "disp"
        (batched outputs of _calc_dis) and the leftover "random_key".
        """

        def walker_init(subkey):
            subkey, subkey_up = random.split(subkey)
            pos_up = random.uniform(subkey_up, (self.n_elec[0], 3)) * self.box_length
            subkey, subkey_dn = random.split(subkey)
            pos_dn = random.uniform(subkey_dn, (self.n_elec[1], 3)) * self.box_length
            pos = jnp.array([pos_up, pos_dn])
            dist, disp = self._calc_dis(pos)
            return pos, dist, disp

        random_key = random.PRNGKey(self.seed)
        random_key, *subkeys = random.split(random_key, n_walkers + 1)
        pos, dist, disp = vmap(walker_init)(jnp.array(subkeys))
        return {
            "pos": pos,
            "dist": dist,
            "disp": disp,
            "random_key": random_key,
        }

    @partial(jit, static_argnums=(0,))
    def update_walker_data(self, new_pos_batch: jax.Array, walker_data: dict) -> dict:
        """Return walker data with new positions and recomputed dist/disp.

        The input dict is not modified; all other entries are carried over.
        """
        assert new_pos_batch.shape == walker_data["pos"].shape

        def update_single_walker(carry, new_pos_i):
            dist, disp = self._calc_dis(new_pos_i)
            return carry, (dist, disp)

        # Sequential over walkers (scan), as before.
        _, (dist, disp) = lax.scan(update_single_walker, None, new_pos_batch)
        return {**walker_data, "pos": new_pos_batch, "dist": dist, "disp": disp}

    def __hash__(self) -> int:
        # Needed because instances are passed as static args to jit.
        return hash(tuple(self.__dict__.values()))

In [128]:
# Small unpolarized fixture system shared by the tests in this cell.
system = ueg(n_elec=(2, 2), r_s=2.0, seed=42)


def test_calc_dis_diagonal():
    """Every particle is at distance zero from itself."""
    pos = jnp.array(
        [
            [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # spin up
            [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]],  # spin down
        ]
    )
    dist, _ = system._calc_dis(pos)
    assert jnp.all(jnp.diagonal(dist) == 0.0)


def test_calc_dis_symmetric():
    """dist is symmetric and disp is antisymmetric under particle exchange."""
    pos = jnp.array(
        [
            [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]],
        ]
    )
    dist, disp = system._calc_dis(pos)
    assert jnp.allclose(dist, dist.T)
    # The diagonal of disp is zero, so including it in the check is harmless.
    assert jnp.allclose(disp, -jnp.swapaxes(disp, 0, 1))


def test_calc_dis_periodic():
    """Particles near opposite corners of the box are minimum-image neighbours.

    Particle 0 sits at (0.1, 0.2, 0.3) and particle 1 at (L-0.1, L-0.2, L-0.3),
    so the minimum-image displacement r_0 - r_1 is exactly (0.2, 0.4, 0.6).
    """
    box_length = system.box_length
    pos_up = jnp.array(
        [[0.1, 0.2, 0.3], [box_length - 0.1, box_length - 0.2, box_length - 0.3]]
    )
    pos_dn = jnp.array([[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]])
    pos = jnp.array([pos_up, pos_dn])
    dist, disp = system._calc_dis(pos)
    # Pin the exact wrapped values instead of a loose upper bound.
    assert jnp.allclose(disp[0, 1], jnp.array([0.2, 0.4, 0.6]))
    assert jnp.isclose(dist[0, 1], jnp.sqrt(0.56))


def test_calc_dis_known_distance():
    """Unit separations along each axis from the origin give unit distances."""
    pos = jnp.array(
        [
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
            [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        ]
    )
    dist, disp = system._calc_dis(pos)
    axes = jnp.eye(3)
    for other in range(1, 4):
        assert jnp.isclose(dist[0, other], 1.0)
    # disp[i, j] = r_i - r_j and particle 0 is at the origin.
    for other in range(1, 4):
        assert jnp.allclose(disp[0, other], -axes[other - 1])


def test_init_walker_data():
    """init_walker_data returns correctly shaped arrays with positions in the box."""
    n_walkers = 3
    walker_data = system.init_walker_data(n_walkers)
    for key in ("pos", "dist", "disp", "random_key"):
        assert key in walker_data
    n_p = system.n_particles
    expected_shapes = {
        "pos": (n_walkers, 2, system.n_elec[0], 3),
        "dist": (n_walkers, n_p, n_p),
        "disp": (n_walkers, n_p, n_p, 3),
    }
    for key, shape in expected_shapes.items():
        assert walker_data[key].shape == shape
    positions = walker_data["pos"]
    assert jnp.all(positions >= 0.0)
    assert jnp.all(positions <= system.box_length)


def test_update_walker_data():
    """update_walker_data replaces pos and recomputes dist/disp for every walker."""
    n_walkers = 2
    walker_data = system.init_walker_data(n_walkers)
    old_pos = walker_data["pos"]
    # new_pos_batch[w] = [up positions, down positions] of walker w.
    new_pos_batch = jnp.array(
        [
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[1.3, 1.4, 1.5], [1.6, 1.7, 1.8]]],
            [[[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], [[1.9, 2.0, 2.1], [2.2, 2.3, 2.4]]],
        ]
    )
    updated_data = system.update_walker_data(new_pos_batch, walker_data)
    assert jnp.allclose(updated_data["pos"], new_pos_batch)
    # Check every walker, not just the first.
    for w in range(n_walkers):
        manual_dist, manual_disp = system._calc_dis(new_pos_batch[w])
        assert jnp.allclose(updated_data["dist"][w], manual_dist)
        assert jnp.allclose(updated_data["disp"][w], manual_disp)
    # Untouched entries are carried over and the caller's data is unchanged.
    assert jnp.array_equal(updated_data["random_key"], walker_data["random_key"])
    assert jnp.array_equal(walker_data["pos"], old_pos)


# Run the checks eagerly; any failing assertion stops the notebook here.
for _check in (
    test_calc_dis_diagonal,
    test_calc_dis_symmetric,
    test_calc_dis_periodic,
    test_calc_dis_known_distance,
    test_init_walker_data,
    test_update_walker_data,
):
    _check()
del _check

In [151]:
@dataclass
class wave_function(ABC):
    """Abstract base class for wave functions.

    Subclasses implement `_evaluate_log` for a single walker, returning
    (log|psi|, sign(psi)). Derivatives are taken of log|psi| only; the sign is
    treated as locally constant away from the nodes.

    Attributes:
        n_elec: (n_up, n_dn) electron counts.
        use_hessian: if True, `laplacian` uses the full Hessian of log|psi|
            (more memory) instead of its diagonal from Hessian-vector products.
    """

    n_elec: Tuple[int, int]
    use_hessian: bool = False

    @partial(jit, static_argnums=(0, 2))
    @abstractmethod
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate both the logarithm of the absolute value and the sign of the wave function.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.

        Returns:
            (log(|psi|), sign)
        """
        pass

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log_abs(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> jax.Array:
        """Get only log(|psi|) for a single walker; this is what gets differentiated."""
        log_abs, _ = self._evaluate_log(pos, ueg_sys, wave_data)
        return log_abs

    @partial(jit, static_argnums=(0, 2))
    def evaluate(
        self, pos_batch: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate the wave function log for a batch of positions.

        Args:
            pos_batch: Batch of positions, shape (n_walkers, 2, n_elec_s, 3).
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation (shared by all walkers).

        Returns:
            (log(|psi|), signs), each with a leading walker axis.
        """
        return vmap(self._evaluate_log, in_axes=(0, None, None))(
            pos_batch, ueg_sys, wave_data
        )

    @partial(jit, static_argnums=(0, 2))
    def _gradient(self, pos: jax.Array, ueg_sys: ueg, wave_data: dict) -> jax.Array:
        """Calculate the gradient of log(|psi|) wrt positions.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.
        Returns:
            The gradient of log(|psi|), same shape as pos.
        """
        return grad(self._evaluate_log_abs)(pos, ueg_sys, wave_data)

    @partial(jit, static_argnums=(0, 2))
    def gradient(
        self, pos_batch: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> jax.Array:
        """Calculate the gradient of log(|psi|) for a batch of positions."""
        return vmap(self._gradient, in_axes=(0, None, None))(
            pos_batch, ueg_sys, wave_data
        )

    @partial(jit, static_argnums=(0, 2))
    def _laplacian(self, pos: jax.Array, ueg_sys: ueg, wave_data: dict) -> jax.Array:
        """Calculate the Laplacian of log(|psi|).

        The Hessian diagonal is obtained one entry per coordinate via
        forward-over-reverse Hessian-vector products, without materialising
        the full Hessian.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.
        Returns:
            The Laplacian of log(|psi|) (scalar).
        """
        orig_shape = pos.shape
        flat_pos = pos.reshape(-1)
        n_coords = flat_pos.size

        def log_abs_fn(p):
            p_reshaped = p.reshape(orig_shape)
            return self._evaluate_log_abs(p_reshaped, ueg_sys, wave_data)

        def flat_grad_fn(flat_p):
            return grad(log_abs_fn)(flat_p).reshape(-1)

        def hess_diag_element(i):
            # jvp requires the tangent dtype to match the primal dtype, so do not
            # rely on the default float dtype (e.g. float32 pos with x64 enabled).
            unit_vec = jnp.zeros(n_coords, dtype=flat_pos.dtype).at[i].set(1.0)
            _, hess_i = jvp(flat_grad_fn, (flat_pos,), (unit_vec,))
            return hess_i[i]

        diagonal = vmap(hess_diag_element)(jnp.arange(n_coords))

        return jnp.sum(diagonal)

    @partial(jit, static_argnums=(0, 2))
    def _laplacian_hessian(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> jax.Array:
        """Laplacian of log(|psi|) using jax Hessian function.

        Constructs the full Hessian matrix of log(|psi|).
        Probably faster for small systems, but uses more memory.
        """
        orig_shape = pos.shape
        flat_pos = pos.reshape(-1)

        def flat_eval_fn(flat_p):
            p_reshaped = flat_p.reshape(orig_shape)
            return self._evaluate_log_abs(p_reshaped, ueg_sys, wave_data)

        return jnp.trace(hessian(flat_eval_fn)(flat_pos))

    @partial(jit, static_argnums=(0, 2))
    def laplacian(
        self, pos_batch: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> jax.Array:
        """Calculate the Laplacian of log(|psi|) for a batch of positions.

        Uses the full Hessian when `use_hessian` is set, otherwise the
        Hessian-vector-product diagonal.
        """
        if self.use_hessian:
            _lap_fun = self._laplacian_hessian
        else:
            _lap_fun = self._laplacian

        return vmap(_lap_fun, in_axes=(0, None, None))(pos_batch, ueg_sys, wave_data)

    def __hash__(self) -> int:
        # Needed because instances are passed as static args to jit.
        return hash(tuple(self.__dict__.values()))


@dataclass
class product_state(wave_function):
    """Product of the wave functions in `states`.

    log|psi| is the sum of the factors' log-magnitudes and sign(psi) is the
    product of their signs. use_hessian is fixed to False and is not an
    __init__ argument.
    """

    n_elec: Tuple[int, int]
    states: Tuple[wave_function, ...]
    use_hessian: bool = field(default=False, init=False)

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jnp.array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Return (log|psi|, sign(psi)) of the product for a single walker."""
        log_terms, sign_terms = [], []
        for factor in self.states:
            factor_log, factor_sign = factor._evaluate_log(pos, ueg_sys, wave_data)
            log_terms.append(factor_log)
            sign_terms.append(factor_sign)
        return jnp.sum(jnp.array(log_terms)), jnp.prod(jnp.array(sign_terms))

    def __hash__(self) -> int:
        # Hashes the sub-states too, so jit can use this as a static arg.
        return hash(tuple(vars(self).values()))


@dataclass
class slater(wave_function):
    """Slater determinant wave function of real plane-wave orbitals.

    With k-points k_j from wave_data["k_points"], orbital j is cos(k_j . r) for
    even j and sin(k_j . r) for odd j. One determinant is formed per spin
    channel and psi = det_up * det_dn.
    """

    n_elec: Tuple[int, int]

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate both log(|psi|) and sign of the Slater determinant.

        Args:
            pos: electron positions, shape (2, n_elec_s, 3); pos[0] up, pos[1] down.
            ueg_sys: system object (not used by this wave function).
            wave_data: must contain "k_points" of shape (n_k, 3); n_k must equal
                the number of electrons per spin so each determinant is square.

        Returns:
            (log|psi|, sign(psi)).

        Raises:
            ValueError: if n_k differs from the number of electrons in a spin
                channel (e.g. an open shell from ueg.get_occ_k_points).
        """
        k_points = wave_data["k_points"]
        n_k = k_points.shape[0]
        # Even rows use cos, odd rows sin. NOTE(review): assumes k_points are ordered
        # so this split spans the occupied real orbitals (as with ueg.get_occ_k_points
        # for the closed shells used here) -- confirm for larger shells.
        use_cos = (jnp.arange(n_k) % 2 == 0)[:, None]

        def orbital_matrix(pos_s):
            """(n_k, n_s) matrix whose [j, i] entry is orbital j at electron i."""
            if pos_s.shape[0] != n_k:
                raise ValueError(
                    f"slater needs one k-point per electron per spin, got "
                    f"{n_k} k-points for {pos_s.shape[0]} electrons"
                )
            phase = k_points @ pos_s.T
            return jnp.where(use_cos, jnp.cos(phase), jnp.sin(phase))

        sign_up, logdet_up = jnp.linalg.slogdet(orbital_matrix(pos[0]))
        sign_dn, logdet_dn = jnp.linalg.slogdet(orbital_matrix(pos[1]))
        log_abs_psi = logdet_up + logdet_dn
        combined_sign = sign_up * sign_dn
        return log_abs_psi, combined_sign

    def __hash__(self) -> int:
        # Needed because instances are passed as static args to jit.
        return hash(tuple(self.__dict__.values()))


@dataclass
class jastrow(wave_function):
    """Jastrow factor wave function.

    Coulomb-Yukawa form of the Jastrow factor, psi_J = exp(-sum_{i<j} u(r_ij)), with
    u(r_ij) = a / r_ij * (1 - exp(-r_ij / f)),
    a = 1 / omega_p where omega_p = sqrt(4 pi n) and n is the electron density,
    f = sqrt(2 a) for same-spin pairs and f = sqrt(a) for opposite-spin pairs.
    The factor is positive, so the returned sign is always 1.
    """

    n_elec: Tuple[int, int]

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate the log of the Jastrow factor.

        Args:
            pos: electron positions, shape (2, n_elec_s, 3); pos[0] up, pos[1] down.
            ueg_sys: system providing minimum-image distances and the density.
            wave_data: unused.

        Returns:
            (-sum_{i<j} u(r_ij), 1.0)
        """
        n_up, n_dn = self.n_elec
        # Particle order in dist: the n_up up electrons first, then the down ones.
        dist, _ = ueg_sys._calc_dis(pos)

        # Same-spin pairs i < j from the strict upper triangle of each spin block;
        # the down block is offset by n_up and sized by n_dn.
        up_i, up_j = jnp.triu_indices(n_up, k=1)
        dn_i, dn_j = jnp.triu_indices(n_dn, k=1)
        pairwise_dist_up = dist[up_i, up_j]
        pairwise_dist_dn = dist[dn_i + n_up, dn_j + n_up]
        # Every (down, up) pair exactly once: shape (n_dn, n_up).
        pairwise_dist_updn = dist[n_up:, :n_up]

        omega_p = (4 * jnp.pi * ueg_sys.density) ** 0.5
        a = 1 / omega_p
        f_ss = (2 * a) ** 0.5
        f_os = a**0.5

        def u_fun(r_ij, f):
            # Elementwise, so it applies directly to arrays of distances.
            return a / r_ij * (1 - jnp.exp(-r_ij / f))

        u_up = jnp.sum(u_fun(pairwise_dist_up, f_ss))
        u_dn = jnp.sum(u_fun(pairwise_dist_dn, f_ss))
        u_updn = jnp.sum(u_fun(pairwise_dist_updn, f_os))

        return -u_up - u_dn - u_updn, jnp.array(1.0)

    def __hash__(self) -> int:
        # Needed because instances are passed as static args to jit.
        return hash(tuple(self.__dict__.values()))

In [130]:
# Sanity check of the Slater determinant alone on one random walker: prints
# |psi| and -1/2 * laplacian(log|psi|) (not the full local kinetic energy,
# which also needs |grad log|psi||^2).
slater_state = slater((7, 7))
system = ueg(0.5, (7, 7), seed=0)
walker_data = system.init_walker_data(1)
k_points = system.get_occ_k_points()
# Orbital parameters live in their own dict rather than being mixed into walker state.
wave_data = {"k_points": k_points}
pos = walker_data["pos"]
print(jnp.exp(slater_state.evaluate(pos, system, wave_data)[0]))
print(-slater_state.laplacian(pos, system, wave_data) / 2)

[37.59767]
[650.32528]


In [180]:
class hamiltonian:
    """Hamiltonian class for the UEG.

    Provides the Ewald-summed electron-electron Coulomb energy and the local
    kinetic energy of a wave_function, per walker and batched.
    """

    # Lattice and reciprocal-lattice sums run over integer ordinals in
    # [-ewald_truncation_limit, ewald_truncation_limit] along each axis.
    ewald_truncation_limit: int = 10

    @partial(jit, static_argnums=(0, 2))
    def _coulomb_ewald(self, ee_disp: jax.Array, ueg_sys: ueg) -> jax.Array:
        """Compute the Ewald sum for the Coulomb potential of one walker.

        Adapted from FermiNet's periodic Coulomb potential.

        Args:
            ee_disp: pairwise displacements, shape (n, n, 3) (as from ueg._calc_dis).
            ueg_sys: system providing box_length, rec_lattice and volume.

        Returns:
            Real scalar 0.5 * sum_{i != j} ewald(r_ij) + 0.5 * n * madelung_const.
        """
        lattice = jnp.array(ueg_sys.box_length) * jnp.eye(3)
        rec = jnp.array(ueg_sys.rec_lattice) * jnp.eye(3)
        volume = ueg_sys.volume
        # Ewald splitting parameter, units of 1/length^2. The parentheses matter:
        # `**` is right-associative, so `2.8 / volume ** (1 / 3) ** 2` would mean
        # 2.8 / volume ** (1 / 9), which does not scale with the box and can leave
        # the reciprocal-space sum unconverged for large boxes.
        gamma = (2.8 / volume ** (1 / 3)) ** 2
        ordinals = sorted(
            range(-self.ewald_truncation_limit, self.ewald_truncation_limit + 1),
            key=abs,
        )
        ordinals = jnp.array(list(itertools.product(ordinals, repeat=3)))
        lat_vectors = jnp.einsum("kj,ij->ik", lattice, ordinals)
        # ordinals[0] is the origin (sorted by abs), so [1:] drops G = 0 / R = 0.
        rec_vectors = jnp.einsum("jk,ij->ik", rec, ordinals[1:])
        rec_vec_square = jnp.einsum("ij,ij->i", rec_vectors, rec_vectors)
        lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1)

        def real_space_ewald(separation: jax.Array):
            displacements = jnp.linalg.norm(separation - lat_vectors, axis=-1)
            return jnp.sum(
                jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements
            )

        def recp_space_ewald(separation: jax.Array):
            return (4 * jnp.pi / volume) * jnp.sum(
                jnp.exp(1.0j * jnp.dot(rec_vectors, separation))
                * jnp.exp(-rec_vec_square / (4 * gamma))
                / rec_vec_square
            )

        def ewald_sum(separation: jax.Array):
            return (
                real_space_ewald(separation)
                + recp_space_ewald(separation)
                - jnp.pi / (volume * gamma)
            )

        madelung_const = (
            jnp.sum(jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm)
            - 2 * gamma**0.5 / jnp.pi**0.5
        )
        madelung_const += (4 * jnp.pi / volume) * jnp.sum(
            jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square
        ) - jnp.pi / (volume * gamma)

        batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,))

        def electron_electron_potential(ee: jax.Array):
            """Evaluates periodic electron-electron potential."""
            nelec = ee.shape[0]
            ee = jnp.reshape(ee, [-1, 3])
            ewald = batch_ewald_sum(ee)
            ewald = jnp.reshape(ewald, [nelec, nelec])
            # Self-pairs (zero separation) are singular; their contribution is
            # the Madelung term instead.
            ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0)
            return 0.5 * jnp.sum(ewald) + 0.5 * nelec * madelung_const

        return jnp.real(electron_electron_potential(ee_disp))

    @partial(jit, static_argnums=(0, 2))
    def coulomb_ewald(self, ee_disp: jax.Array, ueg_sys: ueg) -> jax.Array:
        """Compute the Ewald sum for the Coulomb potential, one value per walker.

        Walkers are processed sequentially with lax.scan.
        """

        def walker_scan(carry, ee_disp_i):
            return None, self._coulomb_ewald(ee_disp_i, ueg_sys)

        _, ene = jax.lax.scan(walker_scan, None, ee_disp)
        return ene

    @partial(jit, static_argnums=(0, 3, 4))
    def _local_kinetic_energy(
        self, pos: jax.Array, wave_data: dict, wave: wave_function, ueg_sys: ueg
    ) -> jax.Array:
        """Calculate the local kinetic energy using log-domain computations.

        The local kinetic energy is -0.5 * (∇²ψ/ψ) = -0.5 * (∇²(log |ψ|) + |∇(log |ψ|)|²)
        """
        grad_log_abs_psi = wave._gradient(pos, ueg_sys, wave_data)
        # Honor the wave function's Laplacian strategy, as wave_function.laplacian does.
        if wave.use_hessian:
            lap_log_abs_psi = wave._laplacian_hessian(pos, ueg_sys, wave_data)
        else:
            lap_log_abs_psi = wave._laplacian(pos, ueg_sys, wave_data)
        grad_squared = jnp.sum(grad_log_abs_psi**2)
        return -0.5 * (lap_log_abs_psi + grad_squared)

    @partial(jit, static_argnums=(0, 3, 4))
    def local_kinetic_energy(
        self, pos_batch: jax.Array, wave_data: dict, wave: wave_function, ueg_sys: ueg
    ) -> jax.Array:
        """Calculate the local kinetic energy for a batch of positions."""
        return vmap(self._local_kinetic_energy, in_axes=(0, None, None, None))(
            pos_batch, wave_data, wave, ueg_sys
        )

    @partial(jit, static_argnums=(0, 4, 5))
    def local_energy(
        self,
        pos_batch: jax.Array,
        walker_data: dict,
        wave_data: dict,
        wave: wave_function,
        ueg_sys: ueg,
    ) -> jax.Array:
        """Compute the local energy of the system, one value per walker.

        The potential uses walker_data["disp"] while the kinetic term uses
        pos_batch, so the caller must keep the two consistent.
        """
        ee_disp = walker_data["disp"]
        pot_ene = self.coulomb_ewald(ee_disp, ueg_sys)
        kin_ene = self.local_kinetic_energy(pos_batch, wave_data, wave, ueg_sys)
        return kin_ene + pot_ene

    def __hash__(self) -> int:
        # Needed because instances are passed as static args to jit.
        return hash(tuple(self.__dict__.values()))

In [185]:
# Compare the Slater and Jastrow-Slater trial states on the same random walkers.
n_elec = (7, 7)
r_s = 1.0
slater_state = slater(n_elec)
system = ueg(r_s, n_elec, seed=0)
walker_data = system.init_walker_data(20)
k_points = system.get_occ_k_points()
wave_data = {"k_points": k_points}
pos = walker_data["pos"]
jastrow_state = jastrow(n_elec)
jastrow_slater = product_state(n_elec, (jastrow_state, slater_state))
ham = hamiltonian()

slater_kin = ham.local_kinetic_energy(pos, wave_data, slater_state, system)
print(f"\nslater local kinetic energy:\n{slater_kin}")

jastrow_slater_kin = ham.local_kinetic_energy(pos, wave_data, jastrow_slater, system)
print(f"\njastrow slater local kinetic energy:\n{jastrow_slater_kin}")

ewald_ene = ham.coulomb_ewald(walker_data["disp"], system)
print(f"\ncoulomb ewald:\n{ewald_ene}")

local_ene_per_elec = (
    ham.local_energy(pos, walker_data, wave_data, jastrow_slater, system)
    / system.n_particles
)
print(f"\njastrow slater local energy per electron:\n{local_ene_per_elec}")

# Disabled cusp check: move electron 0 toward electron 1 and watch the local energy.
# direc = jnp.array(np.random.randn(3))
# direc /= jnp.linalg.norm(direc)
# pos_coinc = pos.at[0, 0, 0].set(pos[0, 0, 1])
# for dr in [10 ** (-n) for n in range(1, 5)]:
#     pos_coinc_dr = pos_coinc.at[0, 0, 0].add(dr * direc)
#     new_walker_data = system.update_walker_data(pos_coinc_dr, walker_data)
#     j_kin_ene = ham.local_kinetic_energy(
#         pos_coinc_dr, wave_data, jastrow_slater, system
#     )
#     j_pot_ene = ham.coulomb_ewald(new_walker_data["disp"], system)
#     j_local_ene = j_kin_ene + j_pot_ene
#     print(f"j_kin_ene: {j_kin_ene}, j_pot_ene: {j_pot_ene}, j_local_ene: {j_local_ene}")


slater local kinetic energy:
[15.69278 15.69278 15.69278 15.69278 15.69278 15.69278 15.69278 15.69278
 15.69278 15.69278 15.69278 15.69278 15.69278 15.69278 15.69278 15.69278
 15.69278 15.69278 15.69278 15.69278]

jastrow slater local kinetic energy:
[ -0.709     3.29035  -4.50766  10.40404   5.38758  -0.83172  -3.91401
   8.4454    4.80808  10.10691   7.69061   3.22545  -1.0655  -20.6299
   5.41867   9.58622   9.65809   6.35987   5.07636   8.67459]

coulomb ewald:
[-1.49381 -2.39022 -2.57552 -8.36747 -4.0468  -1.22493 -3.76854 -5.93575
 -6.89416 -8.2186  -6.56479 -3.4446  -3.46815 -3.13693 -9.40796 -7.60908
 -9.68321 -5.52666 -3.95609 -7.5279 ]

jastrow slater local energy per electron:
[-0.15734  0.0643  -0.50594  0.14547  0.09577 -0.1469  -0.54875  0.17926
 -0.14901  0.13488  0.08042 -0.01565 -0.32383 -1.69763 -0.28495  0.14122
 -0.00179  0.05951  0.08002  0.08191]


In [None]:
@dataclass
class metropolis:
    n_elec: Tuple[int, int]
    step_size: float = 0.1
    seed: int = 0

    @partial(jit, static_argnums=(0, 1, 2))
    def _metropolis_step(
        self,
        ueg_sys: ueg,
        wave: wave_function,
        walker_data: dict,
        wave_data: dict,
        sampling_data: dict,
    ):
        random_key = walker_data["random_key"]
        random_key, subkey = random.split(random_key)
        pos = walker_data["pos"]