In [1]:
%load_ext autoreload
%autoreload 2
# %%
import functools
import os

from sparse_wf.model.utils import get_relative_tolerance
from utils import build_atom_chain, build_model, change_float_dtype

# ruff: noqa: E402 # Allow setting environment variables before importing jax
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pytest
from folx.api import FwdJacobian, FwdLaplArray
from jax import config as jax_config
from sparse_wf.jax_utils import fwd_lap
from sparse_wf.mcmc import init_electrons

# The tests below are parametrized over jnp.float64, so 64-bit types must be enabled in JAX
# (by default JAX works in 32-bit precision).
jax_config.update("jax_enable_x64", True)
# Request the highest matmul precision so reduced-precision matmuls do not blur the
# relative-error comparisons made in these tests.
jax_config.update("jax_default_matmul_precision", "highest")


@functools.lru_cache()
def setup_inputs(dtype):
    """Create the shared test fixture for a given floating-point dtype.

    The result is memoised per ``dtype`` via ``functools.lru_cache`` so that the
    model is built only once for each precision.

    Args:
        dtype: Floating-point dtype that model, parameters and electrons are cast to.

    Returns:
        Tuple ``(model, electrons, params, static_args)``.
    """

    def cast(tree):
        # Apply the float-dtype conversion to every leaf of the pytree.
        return jtu.tree_map(lambda leaf: change_float_dtype(leaf, dtype), tree)

    # Fixed seed; one key for the electron positions, one for the parameters.
    key_electrons, key_params = jax.random.split(jax.random.PRNGKey(0))

    molecule = build_atom_chain(10, 2)
    wavefunction = cast(build_model(molecule))

    # init_electrons is called with batch_size=1; keep only the first entry.
    walkers = init_electrons(key_electrons, molecule, batch_size=1)[0]
    weights = wavefunction.init(key_params, walkers)

    # Cast everything again now that parameters and electrons exist.
    wavefunction, weights, walkers = cast((wavefunction, weights, walkers))
    static = wavefunction.get_static_input(walkers)
    return wavefunction, walkers, weights, static


def to_zero_padded(x, dependencies):
    """Scatter a dependency-indexed Jacobian into a zero-initialised array over all electrons.

    The leading axis of ``x.jacobian.data`` is split into ``(-1, 3)``. For every
    index ``i`` along the second-to-last axis, the slice ``[:, ..., i, :]`` is
    written to the rows ``dependencies[i]`` of a zero array whose leading shape is
    ``(n_el, 3)``. Writes use ``mode="drop"``, so out-of-range indices in
    ``dependencies`` are ignored (presumably these are padding entries - TODO confirm).

    Args:
        x: Object exposing ``.x``, ``.shape``, ``.laplacian`` and ``.jacobian.data``.
        dependencies: Indexable; ``dependencies[i]`` gives the target rows for slice ``i``.

    Returns:
        A ``FwdLaplArray`` with the same ``x`` and ``laplacian`` and the scattered Jacobian.
    """
    num_electrons = x.shape[-2]
    packed = x.jacobian.data
    num_centers = packed.shape[-2]

    # Split the leading axis into (dependency slot, 3).
    packed = packed.reshape([-1, 3, *packed.shape[1:]])
    trailing_shape = packed.shape[2:]

    scattered = jnp.zeros([num_electrons, 3, *trailing_shape], packed.dtype)
    for center in range(num_centers):
        target_rows = dependencies[center]
        source = packed[:, ..., center, :]
        scattered = scattered.at[target_rows, ..., center, :].set(source, mode="drop")

    # Merge (n_el, 3) back into a single leading axis.
    scattered = scattered.reshape([num_electrons * 3, *trailing_shape])
    return FwdLaplArray(x.x, FwdJacobian(data=scattered), x.laplacian)


def assert_close(x: FwdLaplArray, y: FwdLaplArray, rtol=None):
    """Assert that value, Laplacian and dense Jacobian of ``x`` match those of ``y``.

    Each quantity is compared through the relative error ``||a - b|| / ||b||``,
    with ``y`` acting as the reference in the denominator.

    Args:
        x: Candidate result.
        y: Reference result.
        rtol: Relative tolerance; if ``None`` it is looked up from the dtype of ``x.x``
            via ``get_relative_tolerance``.

    Raises:
        AssertionError: If any of the three relative errors is not below ``rtol``.
    """
    if rtol is None:
        rtol = get_relative_tolerance(x.x.dtype)

    def _relative_norm(candidate, reference):
        return jnp.linalg.norm(candidate - reference) / jnp.linalg.norm(reference)

    err_value = _relative_norm(x.x, y.x)
    err_laplacian = _relative_norm(x.laplacian, y.laplacian)
    err_jacobian = _relative_norm(x.jacobian.dense_array, y.jacobian.dense_array)

    errors = (err_value, err_laplacian, err_jacobian)
    assert all(err < rtol for err in errors), f"Rel. errors: {err_value}, {err_laplacian}, {err_jacobian}"


@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_embedding(dtype):
    """Compare the embedding's own forward-Laplacian path against ``fwd_lap`` of its plain apply."""
    model, electrons, params, static_args = setup_inputs(dtype)
    embedding_params = params.embedding

    # Model-internal path: returns the result together with its dependency indices.
    internal, deps = model.embedding.apply_with_fwd_lap(embedding_params, electrons, static_args)
    internal = to_zero_padded(internal, deps)

    # Reference path: wrap the ordinary apply function with fwd_lap.
    def plain_embedding(r):
        return model.embedding.apply(embedding_params, r, static_args)

    external = fwd_lap(plain_embedding)(electrons)

    assert external.dtype == dtype
    assert internal.dtype == dtype
    assert_close(internal, external)


@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_orbitals(dtype):
    """Compare ``model.orbitals_with_fwd_lap`` against ``fwd_lap`` applied to ``model.orbitals``."""
    model, electrons, params, static_args = setup_inputs(dtype)

    # Model-internal path: returns the result together with its dependency indices.
    internal, deps = model.orbitals_with_fwd_lap(params, electrons, static_args)

    # Reference path: only the first element returned by model.orbitals is compared.
    def first_orbital_output(r):
        return model.orbitals(params, r, static_args)[0]

    external = fwd_lap(first_orbital_output)(electrons)
    internal = to_zero_padded(internal, deps)

    assert internal.dtype == dtype
    assert external.dtype == dtype
    assert_close(internal, external)


@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_energy(dtype):
    """Check that ``model.local_energy`` agrees with ``model.local_energy_dense``.

    Both energies must have the requested dtype and be finite, and their relative
    difference ``|E_sparse - E_dense| / |E_dense|`` must be below the tolerance.
    """
    # Use higher tolerance for energy due to possibly ill-conditioned orbital matrix
    # TODO: find a way to get samples/parameters which don't lead to ill-conditioned matrices and allow better testing
    rtol = get_relative_tolerance(dtype) * 1e3
    model, electrons, params, static_args = setup_inputs(dtype)

    E_dense = model.local_energy_dense(params, electrons, static_args)
    E_sparse = model.local_energy(params, electrons, static_args)
    for E, label in zip([E_sparse, E_dense], ["sparse", "dense"]):
        assert E.dtype == dtype, f"energy {label}: {E.dtype} != {dtype}"
        # Report the offending value; the previous message wrongly described a dtype mismatch.
        assert np.isfinite(E), f"energy {label}: {E} is not finite"

    rel_error = jnp.abs(E_sparse - E_dense) / jnp.abs(E_dense)

    assert rel_error < rtol, f"Rel. error |E_sparse - E_dense| / |E_dense|: {rel_error}"

In [5]:
# Interactive debugging: build the same fixture the tests use, for the float32 case.
dtype = jnp.float32
model, electrons, params, static_args = setup_inputs(dtype)

In [14]:
# Reference result: fwd_lap applied to the plain embedding apply (same call as in test_embedding).
embedding_ext = fwd_lap(lambda r: model.embedding.apply(params.embedding, r, static_args))(electrons)

In [15]:
# Model-internal forward-Laplacian result, scattered into the zero-padded layout, then
# compared against embedding_ext exactly as test_embedding does.
# NOTE(review): the recorded output below shows assert_close failing for float32
# (relative errors ~2.9e-7 value, ~6.1e-3 Laplacian, ~0.19 Jacobian).
embedding_int, dependencies = model.embedding.apply_with_fwd_lap(params.embedding, electrons, static_args)
embedding_int = to_zero_padded(embedding_int, dependencies)
assert embedding_ext.dtype == dtype
assert embedding_int.dtype == dtype
assert_close(embedding_int, embedding_ext)

14
[[      0       1      10 1000000 1000000 1000000 1000000 1000000 1000000
  1000000 1000000 1000000 1000000 1000000]
 [      1       0       2      10      11      12 1000000 1000000 1000000
  1000000 1000000 1000000 1000000 1000000]
 [      2       1       3       4      10      11      12      13      14
       15 1000000 1000000 1000000 1000000]
 [      3       2       4       5       6      11      12      13      14
       15      16 1000000 1000000 1000000]
 [      4       2       3       5       6       7      11      12      13
       14      15      16 1000000 1000000]
 [      5       3       4       6       7      13      14      15      16
       17      19 1000000 1000000 1000000]
 [      6       3       4       5       7       8      13      14      15
       16      17      18      19 1000000]
 [      7       4       5       6       8       9      15      16      17
       18      19 1000000 1000000 1000000]
 [      8       6       7       9      15      16      17    

AssertionError: Rel. errors: 2.930747200480255e-07, 0.006100487895309925, 0.19240273535251617

In [None]:
# Element-wise difference of the embedding values from the two code paths.
embedding_ext.x - embedding_int.x

Array([[ 0.0000000e+00,  1.4901161e-07, -1.1920929e-07, ...,
         2.3841858e-07,  4.1723251e-07,  1.4901161e-07],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00, -2.3841858e-07],
       [ 1.9073486e-06, -7.6293945e-06, -1.9073486e-06, ...,
        -4.7683716e-07,  2.5033951e-06, -1.1444092e-05],
       ...,
       [ 2.3841858e-07,  0.0000000e+00, -9.5367432e-07, ...,
        -9.5367432e-07,  2.3841858e-07, -9.5367432e-07],
       [ 2.3841858e-07, -2.9802322e-07,  0.0000000e+00, ...,
        -2.3841858e-06, -7.1525574e-07,  1.6689301e-06],
       [ 0.0000000e+00, -5.9604645e-07,  0.0000000e+00, ...,
         3.8146973e-06,  3.8146973e-06, -1.9073486e-06]], dtype=float32)

In [34]:
# Norm of the Jacobian difference taken jointly over axes 0 and 2, leaving one value per
# index of axis 1.
# NOTE(review): this compares `.jacobian.data`, whereas assert_close compares
# `.jacobian.dense_array` - confirm both refer to the same layout.
jnp.linalg.norm((embedding_ext.jacobian.data - embedding_int.jacobian.data), axis=(0, 2))

Array([ 9.148434,  6.346167, 20.676302, 24.918758, 29.507683, 31.473238,
       29.961163, 27.230335, 30.571577, 26.17066 , 16.336601, 16.525473,
       35.308834, 30.023981, 32.112034, 18.83667 , 27.498219, 27.408297,
       22.192966, 19.152874], dtype=float32)

In [36]:
# Norm of the Laplacian difference along axis 1, leaving one value per index of axis 0.
jnp.linalg.norm((embedding_ext.laplacian - embedding_int.laplacian), axis=1)

Array([ 6.7464504,  7.561483 , 12.229936 , 17.249681 , 10.963287 ,
       16.064144 , 28.501211 , 76.87907  , 20.849226 , 42.943302 ,
        8.318078 ,  7.748808 , 13.701675 , 27.14995  , 18.695047 ,
       13.574827 , 18.355412 , 41.021065 , 21.978632 , 25.455654 ],      dtype=float32)