In [1]:
import numpy as np
import scipy as sp

from jax import config

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import jax
from jax import (
    vmap,
    jit,
    random,
    tree_util,
    numpy as jnp,
    grad,
    jvp,
    lax,
    hessian,
    scipy as jsp,
)
import matplotlib.pyplot as plt
import pickle
from functools import reduce

from flax import linen as nn
from functools import partial

# Shadow the builtin so every print() flushes immediately, keeping progress
# output visible while long-running cells execute.
print = partial(print, flush=True)

import math
import itertools
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union, Any, Tuple

from abc import ABC, abstractmethod
from enum import Enum
import pytest

# NumPy print options: 5 decimals, fixed-point (no scientific notation) for small values.
np.set_printoptions(suppress=True, precision=5)

<div style="background-color: white; color: black; padding: 14px; font-size: 16px;">

Taking $\log\psi$ directly makes JAX's `grad` complain, because the logarithm is complex-valued wherever $\psi < 0$. We therefore track the sign and the magnitude separately. Away from the nodes the sign is locally constant, so its gradient vanishes, and walkers never land exactly on a node, so we can work with $\log|\psi|$ alone. Writing $\psi = s e^f$ with $f = \log|\psi|$ and $s = \operatorname{sign}(\psi)$, we have

$$\nabla \psi = s e^f \nabla f$$

$$\nabla^2 \psi = s e^f [\nabla^2 f + (\nabla f)^2]$$

$$\frac{\nabla^2 \psi}{\psi} = \nabla^2 f + (\nabla f)^2$$

</div>


In [2]:
@dataclass
class ueg:
    """Uniform electron gas (jellium) in a periodic cubic simulation cell.

    Only unpolarized three-dimensional systems are supported. The geometry
    fields (box_length, rec_lattice, volume, n_particles, density) are derived
    from ``r_s`` and ``n_elec`` in ``__post_init__``; any values passed for them
    are overwritten.

    Attributes:
        r_s: Wigner-Seitz radius; the cell satisfies (4/3) pi r_s^3 N = L^3.
        n_elec: Number of (up, down) electrons.
        box_length: Side length L of the cubic cell.
        rec_lattice: Reciprocal-lattice spacing 2 pi / L along each axis.
        dim: Spatial dimension (must be 3).
        volume: Cell volume L^3.
        n_particles: Total number of electrons N.
        density: Electron density N / L^3.
        seed: Seed of the PRNG key created in ``init_walker_data``.
    """

    r_s: float
    n_elec: Tuple[int, int]
    box_length: float = 0.0
    rec_lattice: Tuple = ()
    dim: int = 3
    volume: float = 0.0
    n_particles: int = 0
    density: float = 0.0
    seed: int = 0

    def __post_init__(self):
        assert self.dim == 3, "Only 3D systems are supported."
        assert (
            self.n_elec[0] == self.n_elec[1]
        ), "Only unpolarized systems are supported."
        self.box_length = (
            4 / 3 * jnp.pi * self.r_s**3 * (self.n_elec[0] + self.n_elec[1])
        ) ** (1 / 3)
        self.rec_lattice = (2 * jnp.pi / self.box_length,) * 3
        self.volume = self.box_length**3
        self.n_particles = self.n_elec[0] + self.n_elec[1]
        self.density = self.n_particles / self.volume

    @staticmethod
    def _pair_plus_minus_k(kpoints_list: list) -> list:
        """Reorder k-points as [Gamma, k1, -k1, k2, -k2, ...].

        Args:
            kpoints_list: k-points as lists of floats, sorted by |k| so that the
                first entry is Gamma.

        Returns:
            Gamma followed by (k, -k) tuples, where k is the member whose first
            non-zero component is positive; pairs are sorted by |k|^2.
        """
        result = [kpoints_list[0]]
        # Everything except Gamma; the set makes the -k lookup O(1).
        candidates = [tuple(k) for k in kpoints_list[1:]]
        available = set(candidates)

        pairs = {}
        processed = set()
        for k_tuple in candidates:
            if k_tuple in processed:
                continue
            neg_k = tuple(-x for x in k_tuple)
            processed.add(k_tuple)
            if neg_k in available:
                processed.add(neg_k)
            # The canonical member has a positive first non-zero component.
            for val in k_tuple:
                if abs(val) > 1e-10:
                    if val > 0:
                        pairs[k_tuple] = neg_k
                    else:
                        pairs[neg_k] = k_tuple
                    break

        for canonical in sorted(pairs, key=lambda k: sum(x * x for x in k)):
            result.append(canonical)
            result.append(pairs[canonical])
        return result

    def get_occ_k_points(self) -> jax.Array:
        """Get the occupied k-points of one spin channel.

        Returns:
            Array of shape (n_k, 3) ordered as [Gamma, k1, -k1, k2, -k2, ...]
            with pairs sorted by |k|. Every k-point on the shell containing the
            ``n_elec[0]``-th lowest |k| is included, so n_k equals ``n_elec[0]``
            only for closed-shell electron counts.
        """
        dk = 1 + 1e-5  # relative tolerance when comparing |k| to the Fermi shell
        # Bound on each integer coordinate of the candidate grid. max_k >= n^(1/3)
        # exceeds the Fermi radius, which is about (3n / 4pi)^(1/3) grid units.
        max_k = math.ceil((self.n_elec[0] * dk) ** (1 / 3.0))
        ordinals = sorted(range(-max_k, max_k + 1), key=abs)
        ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3)))
        kpoints = ordinals @ (jnp.array(self.rec_lattice) * jnp.eye(3)).T
        kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm))
        k_norms = jnp.linalg.norm(kpoints, axis=1)
        kpoints = kpoints[k_norms <= k_norms[self.n_elec[0] - 1] * dk]
        return jnp.array(self._pair_plus_minus_k(kpoints.tolist()))

    @partial(jit, static_argnums=(0,))
    def _calc_dis(self, pos: jax.Array) -> Tuple:
        """Minimum-image distances and displacements between all electrons.

        Args:
            pos: Positions of shape (2, n_elec_s, 3): up electrons, then down.

        Returns:
            (dist, disp). disp[i, j] = r_i - r_j mapped to the nearest periodic
            image, shape (N, N, 3). dist[i, j] = |disp[i, j]|, shape (N, N),
            with zeros on the diagonal. Electrons are indexed up-spin first.
        """
        box_length = jnp.array([self.box_length, self.box_length, self.box_length])
        pos_flat = jnp.concatenate([pos[0], pos[1]], axis=0)
        n_particles = pos_flat.shape[0]

        disp = pos_flat[:, None, :] - pos_flat[None, :, :]
        disp = disp - box_length * jnp.round(disp / box_length)
        # The 1e-10 keeps the sqrt derivative finite on the zero-separation
        # diagonal, avoiding NaN gradients through the `where` below.
        dist = jnp.sqrt(jnp.sum(disp**2, axis=-1) + 1e-10)
        off_diagonal = ~jnp.eye(n_particles, dtype=bool)
        dist = jnp.where(off_diagonal, dist, 0.0)
        return dist, disp

    def init_walker_data(self, n_walkers: int) -> dict:
        """Place ``n_walkers`` walkers uniformly at random in the cell.

        Returns:
            Dict with "pos" of shape (n_walkers, 2, n_elec_s, 3), "dist" and
            "disp" (``_calc_dis`` output batched over walkers), and the unused
            remainder of the PRNG key as "random_key".
        """

        def walker_init(subkey):
            subkey, subkey_up = random.split(subkey)
            pos_up = random.uniform(subkey_up, (self.n_elec[0], 3)) * self.box_length
            subkey, subkey_dn = random.split(subkey)
            pos_dn = random.uniform(subkey_dn, (self.n_elec[1], 3)) * self.box_length
            pos = jnp.array([pos_up, pos_dn])
            dist, disp = self._calc_dis(pos)
            return pos, dist, disp

        random_key = random.PRNGKey(self.seed)
        random_key, *subkeys = random.split(random_key, n_walkers + 1)
        pos, dist, disp = vmap(walker_init)(jnp.array(subkeys))
        return {
            "pos": pos,
            "dist": dist,
            "disp": disp,
            "random_key": random_key,
        }

    @partial(jit, static_argnums=(0,))
    def update_walker_data(self, new_pos_batch: jax.Array, walker_data: dict) -> dict:
        """Return walker data for new positions with recomputed distances.

        Args:
            new_pos_batch: New positions, same shape as ``walker_data["pos"]``.
            walker_data: Current walker data; it is not modified.

        Returns:
            A new dict that carries over every entry of ``walker_data`` and
            replaces "pos", "dist" and "disp".
        """
        assert new_pos_batch.shape == walker_data["pos"].shape
        # Sequential map over walkers (same as the scan it replaces).
        dist, disp = lax.map(self._calc_dis, new_pos_batch)
        return {**walker_data, "dist": dist, "disp": disp, "pos": new_pos_batch}

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))

In [3]:
@dataclass
class wave_function(ABC):
    """Abstract base class for wave functions evaluated in the log domain.

    Subclasses implement ``_evaluate_log`` for a single electron configuration.
    This class derives batched evaluation and the gradient and Laplacian of
    log(|psi|) from it with JAX autodiff.

    Attributes:
        n_elec: Number of (up, down) electrons.
        use_hessian: If True, ``laplacian`` builds the full Hessian of
            log(|psi|); otherwise it takes one Hessian-vector product per
            coordinate.
    """

    n_elec: Tuple[int, int]
    use_hessian: bool = False

    @abstractmethod
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate log(|psi|) and sign(psi) for a single configuration.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object; passed as a static (hashable) argument.
            wave_data: Additional data needed for the evaluation.

        Returns:
            (log(|psi|), sign)
        """

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log_abs(
        self, pos: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> jax.Array:
        """Return only log(|psi|); the sign carries no gradient away from nodes."""
        log_abs, _ = self._evaluate_log(pos, ueg_sys, wave_data)
        return log_abs

    @partial(jit, static_argnums=(0, 2))
    def evaluate(
        self, pos_batch: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate log(|psi|) and sign(psi) for a batch of positions.

        Args:
            pos_batch: Batch of positions, shape (n_walkers, 2, n_elec_s, 3).
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.

        Returns:
            (log(|psi|), signs), each of shape (n_walkers,).
        """
        return vmap(self._evaluate_log, in_axes=(0, None, None))(
            pos_batch, ueg_sys, wave_data
        )

    @partial(jit, static_argnums=(0, 2))
    def _gradient(self, pos: jax.Array, ueg_sys: "ueg", wave_data: dict) -> jax.Array:
        """Gradient of log(|psi|) with respect to positions.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.

        Returns:
            The gradient of log(|psi|), same shape as ``pos``.
        """
        return grad(self._evaluate_log_abs)(pos, ueg_sys, wave_data)

    @partial(jit, static_argnums=(0, 2))
    def gradient(
        self, pos_batch: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> jax.Array:
        """Gradient of log(|psi|) for a batch of positions."""
        return vmap(self._gradient, in_axes=(0, None, None))(
            pos_batch, ueg_sys, wave_data
        )

    @partial(jit, static_argnums=(0, 2))
    def _laplacian(self, pos: jax.Array, ueg_sys: "ueg", wave_data: dict) -> jax.Array:
        """Laplacian of log(|psi|) via one Hessian-vector product per coordinate.

        Args:
            pos: The positions of the electrons. shape: (2, n_elec_s, 3)
            ueg_sys: The system object containing system parameters.
            wave_data: Additional data needed for the evaluation.

        Returns:
            The Laplacian of log(|psi|) (a scalar).
        """
        orig_shape = pos.shape
        flat_pos = pos.reshape(-1)
        n_coords = flat_pos.size

        def log_abs_fn(flat_p):
            return self._evaluate_log_abs(flat_p.reshape(orig_shape), ueg_sys, wave_data)

        grad_fn = grad(log_abs_fn)

        def hess_diag_element(i):
            # Forward-over-reverse Hessian-vector product with the i-th unit
            # vector; its i-th entry is d^2 log|psi| / dx_i^2. The tangent must
            # share the primal dtype or jvp rejects it.
            unit_vec = jnp.zeros(n_coords, dtype=flat_pos.dtype).at[i].set(1.0)
            _, hess_col = jvp(grad_fn, (flat_pos,), (unit_vec,))
            return hess_col[i]

        return jnp.sum(vmap(hess_diag_element)(jnp.arange(n_coords)))

    @partial(jit, static_argnums=(0, 2))
    def _laplacian_hessian(
        self, pos: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> jax.Array:
        """Laplacian of log(|psi|) as the trace of the full JAX Hessian.

        Constructs the full Hessian matrix of log(|psi|).
        Probably faster for small systems, but uses more memory.
        """
        orig_shape = pos.shape
        flat_pos = pos.reshape(-1)

        def flat_eval_fn(flat_p):
            return self._evaluate_log_abs(flat_p.reshape(orig_shape), ueg_sys, wave_data)

        return jnp.trace(hessian(flat_eval_fn)(flat_pos))

    @partial(jit, static_argnums=(0, 2))
    def laplacian(
        self, pos_batch: jax.Array, ueg_sys: "ueg", wave_data: dict
    ) -> jax.Array:
        """Laplacian of log(|psi|) for a batch of positions.

        Uses ``_laplacian_hessian`` when ``use_hessian`` is set, else ``_laplacian``.
        """
        lap_fun = self._laplacian_hessian if self.use_hessian else self._laplacian
        return vmap(lap_fun, in_axes=(0, None, None))(pos_batch, ueg_sys, wave_data)

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))


@dataclass
class product_state(wave_function):
    """Product of the wave functions in ``states``: psi = prod_i psi_i.

    In the log domain the log-magnitudes of the factors add and their signs
    multiply.
    """

    n_elec: Tuple[int, int]
    states: Tuple[wave_function, ...]
    # init=False keeps this defaulted field out of __init__, which is what lets
    # the non-default `states` field follow it.
    use_hessian: bool = field(default=False, init=False)

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jnp.array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate both log(|psi|) and sign of the product wave function."""
        evaluations = [
            factor._evaluate_log(pos, ueg_sys, wave_data) for factor in self.states
        ]
        log_abs_terms = jnp.array([log_abs for log_abs, _ in evaluations])
        sign_terms = jnp.array([sign for _, sign in evaluations])
        return jnp.sum(log_abs_terms), jnp.prod(sign_terms)

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))


@dataclass
class slater(wave_function):
    """Real plane-wave Slater determinant, one determinant per spin channel.

    Orbital rows come from ``wave_data["k_points"]``: even-indexed rows use
    cos(k . r) and odd-indexed rows use sin(k . r). With k-points ordered as
    [Gamma, k1, -k1, k2, -k2, ...] (as from ``ueg.get_occ_k_points``) this
    gives the real orbitals 1, sin(k1 . r), cos(k1 . r), ... . The number of
    k-points must equal the electrons per spin so the matrices are square.
    """

    n_elec: Tuple[int, int]

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate both log(|psi|) and sign of the Slater determinant."""
        k_points = wave_data["k_points"]
        n_orbitals = k_points.shape[0]

        def phases(k_subset, electrons):
            # (len(k_subset), n_electrons) matrix of k . r for every pair.
            per_k = lambda k: vmap(lambda r: jnp.dot(k, r))(electrons)
            return vmap(per_k)(k_subset)

        def orbital_matrix(spin):
            electrons = pos[spin]
            mat = jnp.empty((n_orbitals, self.n_elec[spin]))
            mat = mat.at[::2, :].set(jnp.cos(phases(k_points[::2], electrons)))
            mat = mat.at[1::2, :].set(jnp.sin(phases(k_points[1::2], electrons)))
            return mat

        sign_up, logdet_up = jnp.linalg.slogdet(orbital_matrix(0))
        sign_dn, logdet_dn = jnp.linalg.slogdet(orbital_matrix(1))
        return logdet_up + logdet_dn, sign_up * sign_dn

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))


@dataclass
class jastrow(wave_function):
    """Jastrow factor wave function.

    Coulomb-Yukawa form: log J = -sum_{i<j} u(r_ij) with
    u(r_ij) = a / r_ij * (1 - exp(-r_ij / f)), where a = 1 / omega_p and
    omega_p = sqrt(4 pi density). f = sqrt(2a) for same-spin pairs and
    f = sqrt(a) for opposite-spin pairs; since u(r) ~ a/f - a r / (2 f^2) at
    small r, d(log J)/dr -> 1/4 (same spin) and 1/2 (opposite spin).

    NOTE(review): r_ij is the minimum-image distance and u is not flat at half
    the box length, so log J has a derivative kink where the nearest image
    switches; the autodiff Laplacian does not see that contribution.
    """

    n_elec: Tuple[int, int]

    @partial(jit, static_argnums=(0, 2))
    def _evaluate_log(
        self, pos: jax.Array, ueg_sys: ueg, wave_data: dict
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate the log of the Jastrow factor; the sign is always +1."""
        n_up, n_dn = self.n_elec
        # (n_up + n_dn, n_up + n_dn) matrix, up electrons first.
        dist, _ = ueg_sys._calc_dis(pos)

        # Distinct i < j pairs inside each same-spin block.
        up_i, up_j = jnp.triu_indices(n_up, k=1)
        dn_i, dn_j = jnp.triu_indices(n_dn, k=1)
        pairwise_dist_up = dist[up_i, up_j]
        pairwise_dist_dn = dist[dn_i + n_up, dn_j + n_up]
        # Every (down, up) pair.
        pairwise_dist_updn = dist[n_up:, :n_up]

        omega_p = (4 * jnp.pi * ueg_sys.density) ** 0.5
        a = 1 / omega_p
        f_ss = (2 * a) ** 0.5
        f_os = a**0.5

        def u_fun(r_ij, f):
            # Elementwise, so it applies directly to arrays of distances.
            return a / r_ij * (1 - jnp.exp(-r_ij / f))

        u_up = jnp.sum(u_fun(pairwise_dist_up, f_ss))
        u_dn = jnp.sum(u_fun(pairwise_dist_dn, f_ss))
        u_updn = jnp.sum(u_fun(pairwise_dist_updn, f_os))

        return -u_up - u_dn - u_updn, jnp.array(1.0)

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))

In [4]:
class hamiltonian:
    """Hamiltonian class for the UEG.

    The local energy is the kinetic term -1/2 (nabla^2 psi) / psi, evaluated in
    the log domain, plus the periodic (Ewald-summed) Coulomb energy, which
    includes the neutralizing-background Madelung term.
    """

    # Integer ordinals in [-limit, limit] along each axis are used for both the
    # real-space and reciprocal-space Ewald sums.
    ewald_truncation_limit: int = 5

    @partial(jit, static_argnums=(0, 2))
    def _coulomb_ewald(self, ee_disp: jax.Array, ueg_sys: "ueg") -> jax.Array:
        """Ewald-summed Coulomb energy of one configuration (adapted from FermiNet).

        Args:
            ee_disp: Electron-electron displacements, shape (N, N, 3).
            ueg_sys: System providing box_length, rec_lattice and volume.

        Returns:
            0.5 * sum_{i != j} phi(r_ij) + 0.5 * N * madelung_const, where phi is
            the Ewald pair potential.
        """
        lattice = jnp.array(ueg_sys.box_length) * jnp.eye(3)
        rec = jnp.array(ueg_sys.rec_lattice) * jnp.eye(3)
        volume = ueg_sys.volume
        # Ewald splitting parameter kappa^2 = (2.8 / L)^2, which balances real- and
        # reciprocal-space convergence. The parentheses matter: `**` is
        # right-associative, so `2.8 / volume ** (1 / 3) ** 2` would mean
        # 2.8 / volume**(1/9) and under-converge the reciprocal sum for large cells.
        gamma = (2.8 / volume ** (1 / 3)) ** 2
        ordinals = sorted(
            range(-self.ewald_truncation_limit, self.ewald_truncation_limit + 1),
            key=abs,
        )
        ordinals = jnp.array(list(itertools.product(ordinals, repeat=3)))
        lat_vectors = jnp.einsum("kj,ij->ik", lattice, ordinals)
        # ordinals[0] is the origin; drop it for the reciprocal sum.
        rec_vectors = jnp.einsum("jk,ij->ik", rec, ordinals[1:])
        rec_vec_square = jnp.einsum("ij,ij->i", rec_vectors, rec_vectors)
        lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1)

        def real_space_ewald(separation: jax.Array):
            displacements = jnp.linalg.norm(separation - lat_vectors, axis=-1)
            return jnp.sum(
                jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements
            )

        def recp_space_ewald(separation: jax.Array):
            return (4 * jnp.pi / volume) * jnp.sum(
                jnp.exp(1.0j * jnp.dot(rec_vectors, separation))
                * jnp.exp(-rec_vec_square / (4 * gamma))
                / rec_vec_square
            )

        def ewald_sum(separation: jax.Array):
            return (
                real_space_ewald(separation)
                + recp_space_ewald(separation)
                - jnp.pi / (volume * gamma)
            )

        madelung_const = (
            jnp.sum(jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm)
            - 2 * gamma**0.5 / jnp.pi**0.5
        )
        madelung_const += (4 * jnp.pi / volume) * jnp.sum(
            jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square
        ) - jnp.pi / (volume * gamma)

        batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,))

        def electron_electron_potential(ee: jax.Array):
            """Evaluates periodic electron-electron potential."""
            nelec = ee.shape[0]
            ee = jnp.reshape(ee, [-1, 3])
            ewald = batch_ewald_sum(ee)
            ewald = jnp.reshape(ewald, [nelec, nelec])
            # Self-pairs (zero separation) are divergent; their finite part is
            # the Madelung term added below.
            ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0)
            return 0.5 * jnp.sum(ewald) + 0.5 * nelec * madelung_const

        return jnp.real(electron_electron_potential(ee_disp))

    @partial(jit, static_argnums=(0, 2))
    def coulomb_ewald(self, walker_data: dict, ueg_sys: "ueg") -> jax.Array:
        """Ewald Coulomb energy for every walker in ``walker_data["disp"]``."""
        ee_disp = walker_data["disp"]

        def walker_scan(carry, ee_disp_i):
            return None, self._coulomb_ewald(ee_disp_i, ueg_sys)

        _, ene = jax.lax.scan(walker_scan, None, ee_disp)
        return ene

    @partial(jit, static_argnums=(0, 3, 4))
    def _local_kinetic_energy(
        self, pos: jax.Array, wave_data: dict, wave: "wave_function", ueg_sys: "ueg"
    ) -> jax.Array:
        """Calculate the local kinetic energy using log-domain computations.

        The local kinetic energy is -0.5 * (∇²ψ/ψ) = -0.5 * (∇²(log |ψ|) + |∇(log |ψ|)|²).
        Honors ``wave.use_hessian`` when choosing the Laplacian implementation.
        """
        lap_fn = wave._laplacian_hessian if wave.use_hessian else wave._laplacian
        grad_log_abs_psi = wave._gradient(pos, ueg_sys, wave_data)
        lap_log_abs_psi = lap_fn(pos, ueg_sys, wave_data)
        grad_squared = jnp.sum(grad_log_abs_psi**2)
        return -0.5 * (lap_log_abs_psi + grad_squared)

    @partial(jit, static_argnums=(0, 3, 4))
    def local_kinetic_energy(
        self, pos_batch: jax.Array, wave_data: dict, wave: "wave_function", ueg_sys: "ueg"
    ) -> jax.Array:
        """Calculate the local kinetic energy for a batch of positions."""
        return vmap(self._local_kinetic_energy, in_axes=(0, None, None, None))(
            pos_batch, wave_data, wave, ueg_sys
        )

    @partial(jit, static_argnums=(0, 3, 4))
    def local_energy(
        self,
        walker_data: dict,
        wave_data: dict,
        wave: "wave_function",
        ueg_sys: "ueg",
    ) -> jax.Array:
        """Local energy (kinetic + Ewald Coulomb) for every walker."""
        pos_batch = walker_data["pos"]
        pot_ene = self.coulomb_ewald(walker_data, ueg_sys)
        kin_ene = self.local_kinetic_energy(pos_batch, wave_data, wave, ueg_sys)
        return kin_ene + pot_ene

    def __hash__(self) -> int:
        return hash(tuple(self.__dict__.values()))

In [7]:
# Smoke test: local energy of a 2-electron UEG with a Jastrow x Slater trial state.
# Walkers come from `init_walker_data`, i.e. uniform positions in the cell rather
# than samples of |psi|^2, so the mean below is not a VMC energy estimate.
n_elec = (1, 1)
r_s = 1.0
slater_state = slater(n_elec)
system = ueg(r_s, n_elec, seed=10)
walker_data = system.init_walker_data(20)
k_points = system.get_occ_k_points()
wave_data = {"k_points": k_points}
pos = walker_data["pos"]
jastrow_state = jastrow(n_elec)
jastrow_slater = product_state(n_elec, (jastrow_state, slater_state))
ham = hamiltonian()
lke = ham.local_kinetic_energy(pos, wave_data, jastrow_slater, system)
print(f"Local kinetic energy: {lke}")
coulomb_ewald = ham.coulomb_ewald(walker_data, system)
print(f"Coulomb Ewald: {coulomb_ewald}")
le = (
    ham.local_energy(walker_data, wave_data, jastrow_slater, system)
    / system.n_particles
)
# local_energy should equal the sum of the two components printed above.
assert jnp.allclose(le * system.n_particles, lke + coulomb_ewald)
print(f"Local energy per electron: {le}")
print(f"average: {jnp.mean(le)}, std: {jnp.std(le)}")

Local kinetic energy: [-0.2706  -0.88778 -0.11603 -0.48488 -0.16432 -0.22079 -0.30832 -0.4177
 -0.42316 -1.04345 -0.82659 -0.5633  -0.52504 -0.23368 -0.24009 -2.26593
 -0.89956 -0.38441 -0.29832 -0.16373]
Coulomb Ewald: [-1.56715 -0.99577 -1.77185 -1.39715 -1.73192 -1.6524  -1.51077 -1.47608
 -1.43902 -0.83398 -1.05545 -1.34049 -1.36145 -1.6373  -1.62385  0.42933
 -0.97779 -1.45481 -1.54629 -1.69888]
Local energy per electron: [-0.91887 -0.94178 -0.94394 -0.94101 -0.94812 -0.93659 -0.90955 -0.94689
 -0.93109 -0.93872 -0.94102 -0.9519  -0.94325 -0.93549 -0.93197 -0.9183
 -0.93868 -0.91961 -0.9223  -0.9313 ]
average: -0.9345186718101499, std: 0.011210859076980812


In [None]:
@dataclass
class metropolis:
    """Random-walk Metropolis sampler for |psi|^2.

    Attributes:
        n_elec: Number of (up, down) electrons.
        step_size: Standard deviation of the Gaussian displacement proposed for
            every electron coordinate.
        seed: Seed stored with the sampler; ``_metropolis_step`` draws its
            randomness from ``walker_data["random_key"]`` instead.
    """

    n_elec: Tuple[int, int]
    step_size: float = 0.1
    seed: int = 0

    @partial(jit, static_argnums=(0, 1, 2))
    def _metropolis_step(
        self,
        ueg_sys: "ueg",
        wave: "wave_function",
        walker_data: dict,
        wave_data: dict,
        sampling_data: dict,
    ) -> Tuple[dict, jax.Array]:
        """Advance every walker by one all-electron Metropolis move.

        Each walker proposes pos + step_size * N(0, 1), wrapped back into the
        cell. The proposal is accepted with probability
        min(1, |psi(new)|^2 / |psi(old)|^2).

        Args:
            ueg_sys: System; supplies ``box_length`` and ``update_walker_data``.
            wave: Wave function whose ``evaluate`` returns (log|psi|, sign) per walker.
            walker_data: Dict with "pos" of shape (n_walkers, 2, n_elec_s, 3)
                and "random_key"; it is not modified.
            wave_data: Parameters forwarded to ``wave.evaluate``.
            sampling_data: Currently unused; kept for interface compatibility.

        Returns:
            (new_walker_data, acceptance_rate). The dict carries the updated
            positions, distances and random key; the rate is the fraction of
            accepted moves.
        """
        random_key, proposal_key, accept_key = random.split(
            walker_data["random_key"], 3
        )
        pos = walker_data["pos"]
        proposal = pos + self.step_size * random.normal(proposal_key, pos.shape)
        # The wave functions in this notebook (reciprocal-lattice plane waves,
        # minimum-image Jastrow) are periodic in the cell, so wrapping only keeps
        # coordinates bounded.
        proposal = jnp.mod(proposal, ueg_sys.box_length)

        log_abs_old, _ = wave.evaluate(pos, ueg_sys, wave_data)
        log_abs_new, _ = wave.evaluate(proposal, ueg_sys, wave_data)
        # log(|psi_new|^2 / |psi_old|^2)
        log_ratio = 2.0 * (log_abs_new - log_abs_old)
        log_u = jnp.log(random.uniform(accept_key, log_ratio.shape))
        accepted = log_u < log_ratio

        # Broadcast the per-walker decision over the electron/coordinate axes.
        mask = accepted.reshape(accepted.shape + (1,) * (pos.ndim - 1))
        new_pos = jnp.where(mask, proposal, pos)
        new_walker_data = ueg_sys.update_walker_data(new_pos, walker_data)
        new_walker_data = {**new_walker_data, "random_key": random_key}
        return new_walker_data, jnp.mean(accepted)

    def __hash__(self) -> int:
        # Required: the instance is a jit static argument, and @dataclass sets
        # __hash__ to None unless it is defined explicitly.
        return hash(tuple(self.__dict__.values()))

converged SCF energy = -1.39700704811961
escf: -1.3970070481196089
CCSD energy -1.4148953453196933
CCSD(T) energy -1.4148953453196933
CCSD(T) energy per electron -0.7074476726598466




The Hamiltonian for the uniform electron gas is given by

$$\hat{H} = \hat{T} + \hat{V}_{ee} + \hat{V}_{eb} + \hat{V}_{bb}$$

where:

-   $\hat{T}$ is the kinetic energy
-   $\hat{V}_{ee}$ is the electron-electron interaction
-   $\hat{V}_{eb}$ is the electron-background interaction
-   $\hat{V}_{bb}$ is the background-background interaction

We work with quantities per simulation cell. With periodic boundary conditions, we use plane waves as basis functions

$$\phi_{\mathbf{k}}(\mathbf{r}) = \frac{1}{\sqrt{\Omega}}e^{i\mathbf{k}\cdot\mathbf{r}}$$

where $\Omega$ is the simulation cell volume. The field operators are expanded as

$$\hat{\psi}(\mathbf{r}) = \frac{1}{\sqrt{\Omega}}\sum_{\mathbf{k}}e^{i\mathbf{k}\cdot\mathbf{r}}\hat{c}_{\mathbf{k}}$$

$$\hat{\psi}^{\dagger}(\mathbf{r}) = \frac{1}{\sqrt{\Omega}}\sum_{\mathbf{k}}e^{-i\mathbf{k}\cdot\mathbf{r}}\hat{c}^{\dagger}_{\mathbf{k}}$$

where $\hat{c}_{\mathbf{k}}$ and $\hat{c}^{\dagger}_{\mathbf{k}}$ are the annihilation and creation operators for an electron with wave vector $\mathbf{k}$.

The electron-electron interaction is given by

$$\hat{V}_{ee} = \frac{1}{2}\int\int d\mathbf{r}_1 d\mathbf{r}_2 \, \hat{\psi}^{\dagger}(\mathbf{r}_1)\hat{\psi}^{\dagger}(\mathbf{r}_2)v(\mathbf{r}_1-\mathbf{r}_2)\hat{\psi}(\mathbf{r}_2)\hat{\psi}(\mathbf{r}_1)$$

where $v(\mathbf{r}_1-\mathbf{r}_2)$ is the Coulomb potential:

$$v(\mathbf{r}_1-\mathbf{r}_2) = \sum_{\mathbf{R}}\frac{1}{|\mathbf{r}_1-\mathbf{r}_2-\mathbf{R}|}$$

Substituting the field operators into the interaction term we obtain the integral

$$I = \int\int d\mathbf{r}_1 d\mathbf{r}_2 \, e^{-i\mathbf{k}_1\cdot\mathbf{r}_1}e^{-i\mathbf{k}_2\cdot\mathbf{r}_2}v(\mathbf{r}_1-\mathbf{r}_2)e^{i\mathbf{k}_3\cdot\mathbf{r}_2}e^{i\mathbf{k}_4\cdot\mathbf{r}_1}$$

$$I = \int\int d\mathbf{r}_1 d\mathbf{r}_2 \, e^{i(\mathbf{k}_4-\mathbf{k}_1)\cdot\mathbf{r}_1}e^{i(\mathbf{k}_3-\mathbf{k}_2)\cdot\mathbf{r}_2}v(\mathbf{r}_1-\mathbf{r}_2)$$

We make a change of variables:

-   $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$
-   $\mathbf{R} = \mathbf{r}_1$

$$I = \int\int d\mathbf{R} d\mathbf{r} \, e^{i(\mathbf{k}_4-\mathbf{k}_1)\cdot\mathbf{R}}e^{i(\mathbf{k}_3-\mathbf{k}_2)\cdot(\mathbf{R}-\mathbf{r})}v(\mathbf{r})$$

$$I = \int\int d\mathbf{R} d\mathbf{r} \, e^{i[(\mathbf{k}_4-\mathbf{k}_1)+(\mathbf{k}_3-\mathbf{k}_2)]\cdot\mathbf{R}}e^{-i(\mathbf{k}_3-\mathbf{k}_2)\cdot\mathbf{r}}v(\mathbf{r})$$

The integral over $\mathbf{R}$ gives a delta function enforcing momentum conservation

$$\int d\mathbf{R} \, e^{i[(\mathbf{k}_4-\mathbf{k}_1)+(\mathbf{k}_3-\mathbf{k}_2)]\cdot\mathbf{R}} = \Omega \, \delta_{\mathbf{k}_1+\mathbf{k}_2,\mathbf{k}_3+\mathbf{k}_4}$$

This leaves us with

$$I = \Omega \, \delta_{\mathbf{k}_1+\mathbf{k}_2,\mathbf{k}_3+\mathbf{k}_4} \int d\mathbf{r} \, e^{-i(\mathbf{k}_3-\mathbf{k}_2)\cdot\mathbf{r}}v(\mathbf{r})$$

Let's define $\mathbf{G} = \mathbf{k}_3 - \mathbf{k}_2 = \mathbf{k}_1 - \mathbf{k}_4$. The remaining integral is the Fourier transform of the Coulomb potential, given as

$$v(\mathbf{G}) = \int d\mathbf{r} \, e^{-i\mathbf{G}\cdot\mathbf{r}}v(\mathbf{r})$$

For $\mathbf{G} \neq 0$, we have

$$v(\mathbf{G}) = \sum_{\mathbf{L}} \int_{\Omega} \frac{e^{-i\mathbf{G}\cdot\mathbf{r}}}{|\mathbf{r} + \mathbf{L}|} d\mathbf{r}$$

Using the change of variables $\mathbf{r'} = \mathbf{r} + \mathbf{L}$, and noting that $e^{-i\mathbf{G}\cdot\mathbf{L}} = 1$ for any reciprocal lattice vector $\mathbf{G}$ and lattice vector $\mathbf{L}$, we obtain

$$v(\mathbf{G}) = \int_{\text{all space}} \frac{e^{-i\mathbf{G}\cdot\mathbf{r'}}}{|\mathbf{r'}|} d\mathbf{r'}$$

This is the Fourier transform of the Coulomb potential over all space, which gives

$$v(\mathbf{G}) = \frac{4\pi}{|\mathbf{G}|^2} \quad \text{for } \mathbf{G} \neq 0$$

Putting it all together, we get

$$\hat{V}_{ee} = \frac{1}{2\Omega}\sum_{\substack{\mathbf{k}_1,\mathbf{k}_2,\mathbf{k}_3 \\ \sigma_1,\sigma_2}}{}^{'} \frac{4\pi}{|\mathbf{k}_3-\mathbf{k}_2|^2} \, \hat{c}^{\dagger}_{\mathbf{k}_1\sigma_1}\hat{c}^{\dagger}_{\mathbf{k}_2\sigma_2}\hat{c}_{\mathbf{k}_3\sigma_2}\hat{c}_{\mathbf{k}_1+\mathbf{k}_2-\mathbf{k}_3\sigma_1}$$

where the prime on the sum excludes the terms with $\mathbf{k}_3 = \mathbf{k}_2$ (i.e. $\mathbf{G} = 0$).

For $\mathbf{G} = 0$ the integral diverges. Physically, this divergence is canceled by the interactions of the electrons with, and within, the neutralizing background charge. Evaluating these terms requires summing a conditionally convergent series, which is usually done with Ewald's method. What remains after the divergences cancel is the Madelung energy

$$E_M = N\frac{\xi}{2}$$

where $N$ is the number of electrons and $\xi$ is given by

$$\xi = \frac{4\pi}{\Omega}\sum_{\mathbf{G}\neq 0}\frac{e^{-G^2/4\kappa^2}}{G^2} - \frac{\pi}{\kappa^2\Omega} + \sum_{\mathbf{R}\neq 0}\frac{\operatorname{erfc}(\kappa R)}{R} - \frac{2\kappa}{\sqrt{\pi}}$$

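where $\kappa$ is the Ewald splitting parameter ($\kappa^2$ is `gamma` in the code). It balances the convergence of the real- and reciprocal-space sums without changing $\xi$. For a simple cubic cell of side $L = \Omega^{1/3}$ the sum evaluates to $\xi \approx -2.837297/L$. Using $L = (4\pi N/3)^{1/3} r_s$, this gives

$$E_M = \frac{N\xi}{2} \approx -\frac{2.837297}{2}\left(\frac{3}{4\pi}\right)^{1/3}N^{2/3}r_s^{-1}$$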


In [3]:
@dataclass
class ueg_qc(ueg):
    """Quantum chemistry (plane-wave basis) representation of the UEG.

    The basis is every plane wave with kinetic energy |k|^2 / 2 <= e_cut. The
    plane waves are rotated into real cos/sin combinations so the integrals can
    be handed to a real-valued quantum chemistry code.
    """

    e_cut: float = 5.0

    @staticmethod
    def _pair_plus_minus_k(kpoints_list: list) -> list:
        """Reorder k-points as [Gamma, k1, -k1, k2, -k2, ...].

        Args:
            kpoints_list: k-points as lists of floats, sorted by |k| so that the
                first entry is Gamma.

        Returns:
            Gamma followed by (k, -k) tuples, where k is the member whose first
            non-zero component is positive; pairs are sorted by |k|^2.
        """
        result = [kpoints_list[0]]
        # Everything except Gamma; the set makes the -k lookup O(1).
        candidates = [tuple(k) for k in kpoints_list[1:]]
        available = set(candidates)

        pairs = {}
        processed = set()
        for k_tuple in candidates:
            if k_tuple in processed:
                continue
            neg_k = tuple(-x for x in k_tuple)
            processed.add(k_tuple)
            if neg_k in available:
                processed.add(neg_k)
            for val in k_tuple:
                if abs(val) > 1e-10:
                    if val > 0:
                        pairs[k_tuple] = neg_k
                    else:
                        pairs[neg_k] = k_tuple
                    break

        for canonical in sorted(pairs, key=lambda k: sum(x * x for x in k)):
            result.append(canonical)
            result.append(pairs[canonical])
        return result

    def get_k_points(self) -> jax.Array:
        """Get the plane-wave basis: all k-points with |k|^2 / 2 <= e_cut.

        Returns:
            Array of shape (n_k, 3) ordered as [Gamma, k1, -k1, k2, -k2, ...]
            with pairs sorted by |k|.
        """
        # k = (2 pi / L) * n for integer ordinals n, so |k| <= sqrt(2 e_cut)
        # bounds each |n_i| by sqrt(2 e_cut) / (2 pi / L). floor(...) + 1 keeps
        # points lying exactly on the cutoff despite round-off.
        max_k = math.floor(math.sqrt(2 * self.e_cut) / self.rec_lattice[0]) + 1
        ordinals = sorted(range(-max_k, max_k + 1), key=abs)
        ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3)))
        kpoints = ordinals @ (jnp.array(self.rec_lattice) * jnp.eye(3)).T
        kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm))
        kinetic = jnp.linalg.norm(kpoints, axis=1) ** 2 / 2
        kpoints = kpoints[kinetic <= self.e_cut]
        return jnp.array(self._pair_plus_minus_k(kpoints.tolist()))

    def madelung(self):
        """Return N * xi with xi ~ -2.837297 / L (simple cubic Madelung term).

        Written as -2.837297 (3 / 4pi)^(1/3) N^(2/3) / r_s, which equals
        N * (-2.837297 / L) because L = (4 pi N / 3)^(1/3) r_s.
        """
        return (
            -2.837297
            * (3.0 / 4.0 / jnp.pi) ** (1.0 / 3.0)
            * self.n_particles ** (2.0 / 3.0)
            / self.r_s
        )

    @partial(jax.jit, static_argnums=(0,))
    def get_h1(self, k_points: jax.Array) -> jax.Array:
        """One-body Hamiltonian in the plane-wave basis.

        Diagonal kinetic energy |k|^2 / 2 plus a constant shift of
        0.5 * madelung() / N on every orbital. N occupied spin-orbitals
        therefore contribute the Madelung energy 0.5 * madelung().
        """
        h1 = jnp.diag(jnp.sum(k_points**2, axis=-1) / 2)
        madelung = 0.5 * self.madelung() / self.n_particles
        return h1 + madelung * jnp.eye(k_points.shape[0])

    @partial(jax.jit, static_argnums=(0,))
    def get_h1_real(self, k_points: jax.Array) -> jax.Array:
        """One-body Hamiltonian rotated into the real cos/sin basis (real part)."""
        h1_pw = self.get_h1(k_points)
        unitary = self.unitary_pw_to_real(k_points)
        h1 = unitary.conj() @ h1_pw @ unitary.T
        return h1.real

    @partial(jax.jit, static_argnums=(0,))
    def eri_element(
        self,
        k_points: jax.Array,
        p: jax.Array,
        q: jax.Array,
        r: jax.Array,
        s: jax.Array,
    ) -> jax.Array:
        """Periodic Coulomb interaction integral (pq|rs) in chemists' notation.

        Equals 4 pi / (Omega |k_q - k_p|^2) when k_q - k_p == k_r - k_s
        (momentum conservation) and k_q != k_p; otherwise 0 (the G = 0 term is
        handled by the Madelung shift in ``get_h1``).
        """
        g1 = k_points[q] - k_points[p]
        g2 = k_points[r] - k_points[s]
        momentum_conserved = jnp.all(jnp.isclose(g1, g2), axis=-1)
        g1_squared = jnp.sum(g1 * g1, axis=-1)
        non_zero = g1_squared > 1e-10
        element = 4 * jnp.pi / g1_squared / self.volume
        # Division by zero at G = 0 gives inf; mask it before multiplying.
        element = jnp.where(jnp.isinf(element) | jnp.isnan(element), 0.0, element)
        return momentum_conserved * non_zero * element

    @partial(jax.jit, static_argnums=(0,))
    def get_eri_tensor(self, k_points: jax.Array) -> jax.Array:
        """Dense (n_k, n_k, n_k, n_k) ERI tensor in the plane-wave basis.

        Memory scales as n_k^4.
        """
        n_kpts = k_points.shape[0]
        idx = jnp.arange(n_kpts)
        p_idx, q_idx, r_idx, s_idx = jnp.meshgrid(idx, idx, idx, idx, indexing="ij")
        eri_flat = self.eri_element(
            k_points,
            p_idx.flatten(),
            q_idx.flatten(),
            r_idx.flatten(),
            s_idx.flatten(),
        )
        return eri_flat.reshape(n_kpts, n_kpts, n_kpts, n_kpts)

    def unitary_pw_to_real(self, k_points: jax.Array) -> jax.Array:
        """Unitary transformation from the plane-wave basis to real cos/sin orbitals.

        Assumes k_points are laid out as [Gamma, k1, -k1, k2, -k2, ...]. Gamma
        maps to itself; each adjacent (k, -k) pair is mixed by
        [[1, 1], [-i, i]] / sqrt(2). Not jitted itself; it is traced inside the
        jitted callers, where the array shape is static.
        """
        n_kpts = k_points.shape[0]
        assert n_kpts % 2 == 1, "Expected Gamma followed by (+k, -k) pairs."
        unitary = jnp.zeros((n_kpts, n_kpts), dtype=jnp.complex128)
        unitary = unitary.at[0, 0].set(1.0)
        n_blocks = (n_kpts - 1) // 2
        # With only Gamma in the basis there are no pairs to rotate, and
        # block_diag() of nothing would not fit the empty slice.
        if n_blocks > 0:
            unitary_block = jnp.array([[1.0, 1.0], [-1.0j, 1.0j]]) / jnp.sqrt(2.0)
            unitary = unitary.at[1:, 1:].set(
                jsp.linalg.block_diag(*([unitary_block] * n_blocks))
            )
        return unitary

    @partial(jax.jit, static_argnums=(0,))
    def get_eri_tensor_real(self, k_points: jax.Array) -> jax.Array:
        """ERI tensor in the real basis: each index rotated by the unitary."""
        eri = self.get_eri_tensor(k_points)
        unitary = self.unitary_pw_to_real(k_points)
        eri = jnp.einsum("ip,pqrs->iqrs", unitary.conj(), eri, optimize=True)
        eri = jnp.einsum("jq,iqrs->ijrs", unitary, eri, optimize=True)
        eri = jnp.einsum("kr,ijrs->ijks", unitary.conj(), eri, optimize=True)
        eri = jnp.einsum("ls,ijks->ijkl", unitary, eri, optimize=True).real
        return eri

    def __hash__(self):
        return hash(tuple(self.__dict__.values()))

In [4]:
system = ueg_qc(1.0, (1, 1), e_cut=10.0)
k_points = system.get_k_points()
n_kpts = k_points.shape[0]
print(f"Number of k-points: {n_kpts}")
# PySCF works on NumPy arrays; copy the (immutable) JAX integrals into
# writable NumPy arrays once instead of mixing array types inside PySCF.
h1 = np.array(system.get_h1_real(k_points))
eri = np.array(system.get_eri_tensor_real(k_points))

from pyscf import gto, scf, ao2mo, cc

# Empty molecule with a custom Hamiltonian: orthonormal basis, no nuclei.
mol = gto.M(verbose=0)
mol.nelectron = system.n_particles
mol.incore_anyway = True
mol.energy_nuc = lambda *args: 0.0
mol.verbose = 3
mf = scf.RHF(mol)
mf.get_hcore = lambda *args: h1
mf.get_ovlp = lambda *args: np.eye(n_kpts)
mf._eri = ao2mo.restore(8, eri, n_kpts)
mf.init_guess = "1e"
escf = mf.kernel()
print(f"escf: {escf}")

mycc = cc.RCCSD(mf)
mycc.kernel()
print("CCSD energy", mycc.e_tot)
et_correction = mycc.ccsd_t()
e_ccsd_t = mycc.e_tot + et_correction
print("CCSD(T) energy", e_ccsd_t)
print("CCSD(T) energy per electron", e_ccsd_t / system.n_particles)

Number of k-points: 19
converged SCF energy = -1.39700704811961
escf: -1.3970070481196089
E(CCSD) = -1.414895345319693  E_corr = -0.01788829720008443
CCSD energy -1.4148953453196933
CCSD(T) correction = -3.39138626991892e-20
CCSD(T) energy -1.4148953453196933
CCSD(T) energy per electron -0.7074476726598466


Overwritten attributes  get_hcore get_ovlp  of <class 'pyscf.scf.hf.RHF'>
