In [23]:
import itertools
import os
import sys
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest, parameterized

# Make the directory two levels above the notebook's working directory
# importable (so the local `ferminet` package can be found). The path is
# appended, so any copy already earlier on sys.path still takes precedence.
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

from ferminet import base_config, hamiltonian, networks

In [24]:
def h_atom_log_psi(param, xs, spins, atoms=None, charges=None):
    """Log-amplitude of the exact hydrogen ground state, log psi = -|xs|.

    Only ``xs`` is used. ``param``, ``spins``, ``atoms`` and ``charges`` exist
    to match the network call signature and are ignored, so the distance is
    measured from the origin.
    """
    del param, spins, atoms, charges
    radius = jnp.linalg.norm(xs)
    return -jnp.abs(radius)


def h_atom_log_psi_signed(param, xs, spins, atoms=None, charges=None):
    """Signed form of ``h_atom_log_psi``: returns (sign, log|psi|) with sign = 1."""
    log_abs_psi = h_atom_log_psi(param, xs, spins, atoms, charges)
    # The hydrogen ground state is positive everywhere.
    positive_sign = jnp.ones_like(log_abs_psi)
    return positive_sign, log_abs_psi


def kinetic_from_hessian(log_f):
    """Builds a reference local kinetic energy -0.5 * (laplacian f) / f.

    The wavefunction is recovered as f = exp(log_f) and its Laplacian is taken
    as the trace of the full Hessian of f.

    Args:
        log_f: callable (params, x, spins, atoms, charges) -> log f(x).

    Returns:
        Callable (params, pos, spins, atoms, charges) -> local kinetic energy.
        Assumes ``pos`` is 1-D (flattened coordinates) so the Hessian is a
        square matrix for ``jnp.trace``.
    """

    def kinetic_operator(params, pos, spins, atoms, charges):
        def wavefunction(x):
            return jnp.exp(log_f(params, x, spins, atoms, charges))

        psi_value = wavefunction(pos)
        laplacian = jnp.trace(jax.hessian(wavefunction)(pos))
        return -0.5 * laplacian / psi_value

    return kinetic_operator


def kinetic_from_hessian_log(log_f):
    """Builds a reference local kinetic energy directly from log f.

    Uses (laplacian f) / f = laplacian(log f) + |grad log f|^2, with the
    Laplacian taken as the trace of the full Hessian of log f.

    Args:
        log_f: callable (params, x, spins, atoms, charges) -> log|f(x)|.

    Returns:
        Callable (params, pos, spins, atoms, charges) -> local kinetic energy.
        Assumes ``pos`` is 1-D so the Hessian is a square matrix.
    """

    def kinetic_operator(params, pos, spins, atoms, charges):
        def log_wavefunction(x):
            return log_f(params, x, spins, atoms, charges)

        gradient = jax.grad(log_wavefunction)(pos)
        laplacian_of_log = jnp.trace(jax.hessian(log_wavefunction)(pos))
        gradient_sq = jnp.sum(gradient**2)
        return -0.5 * (laplacian_of_log + gradient_sq)

    return kinetic_operator


class HamiltonianTest(parameterized.TestCase):
    """Checks molecular kinetic, potential and local energies against closed forms."""

    def setUp(self):
        super().setUp()
        # Random geometries come from numpy's global RNG; seed it so that any
        # failure is reproducible.
        np.random.seed(0)

    @parameterized.parameters(["default", "folx"])
    def test_local_kinetic_energy(self, laplacian):
        """Kinetic energy of psi = exp(-|x|) equals -(1 - 2/r) / 2 in 3D."""
        dummy_params = {}
        xs = np.random.normal(size=(3,))
        spins = np.ones(shape=(1,))
        atoms = np.random.normal(size=(1, 3))
        charges = 2 * np.ones(shape=(1,))
        # h_atom_log_psi ignores `atoms`, so r is measured from the origin.
        expected_kinetic_energy = -(1 - 2 / np.linalg.norm(xs)) / 2

        kinetic = hamiltonian.local_kinetic_energy(
            h_atom_log_psi_signed, laplacian_method=laplacian
        )
        kinetic_energy = kinetic(
            dummy_params,
            networks.FermiNetData(
                positions=xs, spins=spins, atoms=atoms, charges=charges
            ),
        )
        np.testing.assert_allclose(kinetic_energy, expected_kinetic_energy, rtol=1.0e-5)

    def test_potential_energy_null(self):
        """One electron and a zero nuclear charge give zero potential energy."""
        xs = np.random.normal(size=(1, 3))
        r_ae = jnp.linalg.norm(xs, axis=-1)
        r_ee = jnp.zeros(shape=(1, 1, 1))
        atoms = jnp.zeros(shape=(1, 3))
        charges = jnp.zeros(shape=(1,))
        v = hamiltonian.potential_energy(r_ae, r_ee, atoms, charges)
        # The expected value is 0, so a relative tolerance would demand exact
        # equality; use an absolute tolerance instead.
        np.testing.assert_allclose(v, 0.0, atol=1e-10)

    def test_potential_energy_ee(self):
        """With zero nuclear charge only the electron-electron term remains."""
        xs = np.random.normal(size=(5, 3))
        r_ae = jnp.linalg.norm(xs, axis=-1)
        r_ee = jnp.linalg.norm(xs[None, ...] - xs[:, None, :], axis=-1)
        atoms = jnp.zeros(shape=(1, 3))
        charges = jnp.zeros(shape=(1,))
        # Sum 1/r over off-diagonal pairs; each pair appears twice, hence 0.5.
        mask = ~jnp.eye(r_ee.shape[0], dtype=bool)
        expected_v_ee = 0.5 * np.sum(1.0 / r_ee[mask])
        v = hamiltonian.potential_energy(r_ae, r_ee[..., None], atoms, charges)
        np.testing.assert_allclose(v, expected_v_ee, rtol=1e-5)

    def test_potential_energy_he2_ion(self):
        """One electron and two charge-2 nuclei separated along z."""
        xs = np.random.normal(size=(1, 3))
        atoms = jnp.array([[0.0, 0.0, -1.0], [0.0, 0.0, 1.0]])
        r_ae = jnp.linalg.norm(xs - atoms, axis=-1)
        r_ee = jnp.zeros(shape=(1, 1, 1))
        charges = jnp.array([2.0, 2.0])
        # Electron-nucleus attraction.
        v_ae = -jnp.sum(charges / r_ae)
        # Nucleus-nucleus repulsion between the two atoms.
        v_aa = jnp.prod(charges) / jnp.linalg.norm(jnp.diff(atoms, axis=0))
        expected_v = v_ae + v_aa
        v = hamiltonian.potential_energy(r_ae[..., None], r_ee, atoms, charges)
        np.testing.assert_allclose(v, expected_v, rtol=1e-5)

    def test_local_energy(self):
        """exp(-r) with a unit charge at the origin has local energy -1/2."""
        spins = np.ones(shape=(1,))
        atoms = np.zeros(shape=(1, 3))
        charges = np.ones(shape=(1,))
        dummy_params = {}
        local_energy = hamiltonian.local_energy(
            h_atom_log_psi_signed, charges, nspins=(1, 0), use_scan=False
        )

        xs = np.random.normal(size=(100, 3))
        key = jax.random.PRNGKey(4)
        keys = jax.random.split(key, num=xs.shape[0])
        batch_local_energy = jax.vmap(
            local_energy,
            in_axes=(
                None,
                0,
                networks.FermiNetData(
                    positions=0, spins=None, atoms=None, charges=None
                ),
            ),
        )
        energies, _ = batch_local_energy(
            dummy_params,
            keys,
            networks.FermiNetData(
                positions=xs, spins=spins, atoms=atoms, charges=charges
            ),
        )

        np.testing.assert_allclose(energies, -0.5 * np.ones_like(energies), rtol=1e-5)


class LaplacianTest(parameterized.TestCase):
    """Compares hamiltonian.local_kinetic_energy with Hessian-based references."""

    def setUp(self):
        super().setUp()
        # Seed numpy's global RNG so random inputs are reproducible; individual
        # tests may reseed explicitly.
        np.random.seed(0)

    @parameterized.parameters(["default", "folx"])
    def test_laplacian(self, laplacian):
        """Hydrogen-like wavefunction: Laplacian method vs. full Hessian of f."""
        xs = np.random.uniform(size=(100, 3))
        spins = np.ones(shape=(1,))
        atoms = np.random.normal(size=(1, 3))
        charges = 3 * np.ones(shape=(1,))
        data = networks.FermiNetData(
            positions=xs, spins=spins, atoms=atoms, charges=charges
        )
        dummy_params = {}
        t_l_fn = jax.vmap(
            hamiltonian.local_kinetic_energy(
                h_atom_log_psi_signed, laplacian_method=laplacian
            ),
            in_axes=(
                None,
                networks.FermiNetData(
                    positions=0, spins=None, atoms=None, charges=None
                ),
            ),
        )
        t_l = t_l_fn(dummy_params, data)
        hess_t = jax.vmap(
            kinetic_from_hessian(h_atom_log_psi),
            in_axes=(None, 0, None, None, None),
        )(dummy_params, xs, spins, atoms, charges)
        np.testing.assert_allclose(t_l, hess_t, rtol=1e-5)

    @parameterized.parameters(itertools.product([True, False], ["default", "folx"]))
    def test_fermi_net_laplacian(self, full_det, laplacian):
        """Small FermiNet: Laplacian method vs. Hessian of log|psi|."""
        natoms = 2
        np.random.seed(12)
        atoms = np.random.uniform(low=-5.0, high=5.0, size=(natoms, 3))
        nspins = (2, 3)
        charges = 2 * np.ones(shape=(natoms,))
        batch = 4
        cfg = base_config.default()
        cfg.network.full_det = full_det
        cfg.network.ferminet.hidden_dims = ((8, 4),) * 2
        cfg.network.determinants = 2
        # Build the features for the spin configuration actually used below
        # rather than the config's default electron counts.
        feature_layer = networks.make_ferminet_features(
            natoms,
            nspins,
            cfg.system.ndim,
        )
        network = networks.make_fermi_net(
            nspins,
            charges,
            # Without this the `determinants = 2` setting above has no effect.
            determinants=cfg.network.determinants,
            full_det=full_det,
            feature_layer=feature_layer,
            **cfg.network.ferminet
        )
        log_network = lambda *args, **kwargs: network.apply(*args, **kwargs)[1]
        key = jax.random.PRNGKey(47)
        params = network.init(key)
        xs = np.random.normal(scale=5, size=(batch, sum(nspins) * 3))
        spins = np.sign(np.random.normal(scale=1, size=(batch, sum(nspins))))
        t_l_fn = jax.jit(
            jax.vmap(
                hamiltonian.local_kinetic_energy(
                    network.apply, laplacian_method=laplacian
                ),
                in_axes=(
                    None,
                    networks.FermiNetData(
                        positions=0, spins=0, atoms=None, charges=None
                    ),
                ),
            )
        )
        t_l = t_l_fn(
            params,
            networks.FermiNetData(
                positions=xs, spins=spins, atoms=atoms, charges=charges
            ),
        )
        hess_t_fn = jax.jit(
            jax.vmap(
                kinetic_from_hessian_log(log_network),
                in_axes=(None, 0, 0, None, None),
            )
        )
        hess_t = hess_t_fn(params, xs, spins, atoms, charges)
        if hess_t.dtype == jnp.float64:
            atol, rtol = 1.0e-10, 1.0e-10
        else:
            # This needs a loose tolerance because fast-math optimizations on
            # CPU can substantially affect floating point expressions. See
            # https://github.com/google/jax/issues/6566.
            atol, rtol = 4.0e-3, 4.0e-3
        np.testing.assert_allclose(t_l, hess_t, atol=atol, rtol=rtol)

In [3]:
from ferminet.pbc import envelopes
from ferminet.utils import system


def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray:
    """Returns simple cubic lattice vectors with Wigner-Seitz radius rs."""
    volume = (4 / 3) * np.pi * (rs**3) * nelec
    length = volume ** (1 / 3)
    return length * np.eye(3)


# Unpolarised 14-electron gas in a simple-cubic cell with Wigner-Seitz radius 1.
cfg = base_config.default()
cfg.system.electrons = (7, 7)
lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons))
kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons)

# A ghost atom at the origin defines the one-electron coordinate system;
# element 'X' is a dummy nucleus with zero charge.
cfg.system.molecule = [system.Atom("X", (0.0, 0.0, 0.0))]
# Pretraining is not currently implemented for systems in PBC.
cfg.pretrain.method = None

# Periodic local energy, feature layer and plane-wave envelope, all built
# from the same lattice.
cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy"
cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True}
cfg.network.make_feature_layer_fn = "ferminet.pbc.feature_layer.make_pbc_feature_layer"
cfg.network.make_feature_layer_kwargs = {"lattice": lattice, "include_r_ae": False}
cfg.network.make_envelope_fn = "ferminet.pbc.envelopes.make_multiwave_envelope"
cfg.network.make_envelope_kwargs = {"kpoints": kpoints}
cfg.network.full_det = True
cfg.network.full_det = True

In [15]:
import attr
import chex
import jax
import jax.numpy as jnp
from typing import Tuple, Optional, Sequence


@attr.s(auto_attribs=True, kw_only=True)
class SlaterOptions:
    """Configuration for the plane-wave (k-space) Slater determinant.

    Attributes:
        ndim: spatial dimension of each electron coordinate.
        full_det: if True, use a single determinant over all electrons rather
            than separate spin-up / spin-down blocks.
    """

    ndim: int = attr.ib(default=3)
    full_det: bool = attr.ib(default=True)


def construct_orbital_matrix(
    pos: jnp.ndarray,
    k_vectors: jnp.ndarray,
    nspins: Tuple[int, int],
    options: "SlaterOptions",
) -> jnp.ndarray:
    """Builds the plane-wave orbital matrix M[i, j] = exp(i k_j . r_i).

    Args:
        pos: electron positions, either flattened with shape (nelec * ndim,)
            or already shaped (nelec, ndim).
        k_vectors: array of shape (nk, ndim); row j is the wavevector of
            orbital j.
        nspins: spin occupations. Unused here because every electron sees the
            same set of plane waves; kept for interface compatibility.
        options: object providing ``ndim``.

    Returns:
        Complex array of shape (nelec, nk).
    """
    del nspins  # Spin blocks, if any, are selected by the caller.
    # reshape(-1, ndim) accepts both flat and (nelec, ndim) positions; the
    # previous `pos.shape[0] // ndim` only worked for flat input.
    electron_coords = pos.reshape(-1, options.ndim)
    k_dot_r = jnp.einsum("ik,jk->ij", electron_coords, k_vectors)
    return jnp.exp(1j * k_dot_r)


@attr.s(auto_attribs=True)
class Network:
    """Bundles a wavefunction's options with its init / apply / orbitals functions.

    Attributes:
        options: the SlaterOptions the functions were built with.
        init: callable taking a PRNG key and returning the parameters.
        apply: callable (params, pos, spins, atoms, charges) returning a
            (phase, log|psi|) pair.
        orbitals: callable (params, pos, spins, atoms, charges) returning the
            orbital matrix.
    """

    # `callable` is the builtin predicate function, not a type; annotate with
    # typing.Callable instead.
    options: SlaterOptions
    init: Callable
    apply: Callable
    orbitals: Callable


def make_kspace_slater(
    nspins: Tuple[int, int],
    k_vectors: jnp.ndarray,
    *,
    ndim: int = 3,
    full_det: bool = True,
) -> Network:
    """Creates a plane-wave (k-space) Slater-determinant wavefunction.

    Args:
        nspins: (n_up, n_down) electron counts. Electrons are assumed to be
            ordered with all spin-up electrons first.
        k_vectors: wavevectors; the first axis must equal the total number of
            electrons. They are stored as the (fixed) network parameters.
        ndim: spatial dimension of each electron coordinate.
        full_det: if True, one determinant over all electrons; otherwise the
            product of the spin-up and spin-down block determinants.

    Returns:
        Network whose ``apply`` returns (phase angle, real log|det|).

    Raises:
        ValueError: if no electrons are present or the k-vector count does not
            match the number of electrons.
    """
    if not any(count > 0 for count in nspins):
        raise ValueError("No electrons present!")

    total_electrons = sum(nspins)
    n_kvecs = k_vectors.shape[0]
    if n_kvecs != total_electrons:
        raise ValueError(
            f"Number of k-vectors ({n_kvecs}) must match "
            f"number of electrons ({total_electrons})"
        )

    options = SlaterOptions(ndim=ndim, full_det=full_det)

    def init(key: chex.PRNGKey) -> dict:
        """Returns the parameters: just the stored k-vectors (key is unused)."""
        return {"k_vectors": k_vectors}

    def apply(
        params,
        pos: jnp.ndarray,
        spins: jnp.ndarray,
        atoms: jnp.ndarray,
        charges: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Returns (phase angle, log|det|) of the plane-wave Slater matrix."""
        matrix = construct_orbital_matrix(pos, params["k_vectors"], nspins, options)

        if options.full_det:
            # For a complex matrix slogdet's "sign" is a unit-modulus phase
            # factor; return its angle.
            full_sign, full_logdet = jnp.linalg.slogdet(matrix)
            return jnp.angle(full_sign), full_logdet.real

        # Block-diagonal case: spin-up electrons use the first n_up k-vectors,
        # spin-down electrons the remaining ones.
        split = nspins[0]
        up_sign, up_logdet = jnp.linalg.slogdet(matrix[:split, :split])
        down_sign, down_logdet = jnp.linalg.slogdet(matrix[split:, split:])
        return jnp.angle(up_sign * down_sign), (up_logdet + down_logdet).real

    def orbitals(
        params,
        pos: jnp.ndarray,
        spins: jnp.ndarray,
        atoms: jnp.ndarray,
        charges: jnp.ndarray,
    ) -> jnp.ndarray:
        """Returns the orbital matrix without taking a determinant."""
        return construct_orbital_matrix(pos, params["k_vectors"], nspins, options)

    return Network(options=options, init=init, apply=apply, orbitals=orbitals)

In [25]:
from ferminet.pbc import feature_layer as pbc_feature_layer
# Aliased so the molecular `hamiltonian` module used by the test classes above
# is not shadowed in the notebook namespace.
from ferminet.pbc import hamiltonian as pbc_hamiltonian

# Separate name so the HEG `cfg` built in an earlier cell is not overwritten.
default_cfg = base_config.default()

# Periodicity check: the local energy must not change when one electron is
# translated by a lattice vector.
unit_lattice = jnp.eye(3)
nspins = (7, 0)
atoms = jnp.asarray([[0.0, 0.0, 0.2], [1.2, 1.0, -0.2], [2.5, -0.8, 0.6]])
natom = atoms.shape[0]
charges = jnp.asarray([2, 5, 7])
# One spin label per electron.
spins = np.ones(shape=(sum(nspins),))
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
xs = jax.random.uniform(subkey, shape=(sum(nspins), 3))

# Only needed by the FermiNet alternative below.
feature_layer = pbc_feature_layer.make_pbc_feature_layer(
    natom, nspins, ndim=3, lattice=unit_lattice, include_r_ae=False
)

kpoints = envelopes.make_kpoints(unit_lattice, nspins, min_kpoints=sum(nspins))

# Alternative wavefunction (FermiNet with a multiwave envelope):
# network = networks.make_fermi_net(
#     nspins,
#     charges,
#     envelope=envelopes.make_multiwave_envelope(kpoints),
#     feature_layer=feature_layer,
#     bias_orbitals=default_cfg.network.bias_orbitals,
#     full_det=default_cfg.network.full_det,
#     **default_cfg.network.ferminet,
# )
network = make_kspace_slater(nspins, kpoints)

key, subkey = jax.random.split(key)
params = network.init(subkey)

# NOTE(review): network.apply returns a (phase, log|psi|) pair from a complex
# determinant, but complex_output is not requested here, so the phase may not
# contribute to the kinetic energy -- confirm against
# ferminet.pbc.hamiltonian.local_energy.
local_energy = pbc_hamiltonian.local_energy(
    f=network.apply,
    charges=charges,
    nspins=nspins,
    use_scan=False,
    lattice=unit_lattice,
    heg=False,
)

data = networks.FermiNetData(
    positions=xs.flatten(), spins=spins, atoms=atoms, charges=charges
)

key, subkey = jax.random.split(key)
e1, _ = local_energy(params, subkey, data)

# Translate one randomly chosen electron by a random integer combination of
# lattice vectors (rows of `unit_lattice`).
key, subkey = jax.random.split(key)
e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0])
key, subkey = jax.random.split(key)
lattice_coeffs = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32)
randvec = lattice_coeffs @ unit_lattice
xs_shifted = xs.at[e_idx].add(randvec)

data_shifted = networks.FermiNetData(
    positions=xs_shifted.flatten(), spins=spins, atoms=atoms, charges=charges
)

key, subkey = jax.random.split(key)
e2, _ = local_energy(params, subkey, data_shifted)

# Large translations cost float32 precision, hence the loose tolerance.
np.testing.assert_allclose(e1, e2, rtol=1e-4)

In [26]:
# Local energies before and after the lattice translation; they should agree
# up to float32 rounding.
(e1, e2)

(Array(65.6517, dtype=float32), Array(65.65158, dtype=float32))