In [1]:
%load_ext autoreload
%autoreload 2
# %%
import functools
import os

# ruff: noqa: E402 # Allow setting environment variables before importing jax
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"


import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pytest
from jax import config as jax_config
from sparse_wf.mcmc import init_electrons
from sparse_wf.model.utils import get_relative_tolerance

from utils import build_atom_chain, build_model, change_float_dtype

jax_config.update("jax_enable_x64", True)
jax_config.update("jax_default_matmul_precision", "highest")


@functools.lru_cache()
def setup_inputs(dtype):
    """Build the model and matching inputs used by the low-rank-update checks.

    The result is memoised per ``dtype`` by ``functools.lru_cache``, so repeated
    calls with the same dtype return the same objects.

    Args:
        dtype: Floating-point dtype that parameters and electron positions are
            converted to via ``change_float_dtype``.

    Returns:
        Tuple ``(model, electrons, params, static_args)``.
    """

    def cast(leaf):
        return change_float_dtype(leaf, dtype)

    # Fixed seed: one key for the electron positions, one for the parameters.
    key_electrons, key_params = jax.random.split(jax.random.PRNGKey(0))

    mol = build_atom_chain(10, 2)
    model = build_model(mol)
    # Only params and electrons are cast below; the model object is left as built.

    # init_electrons is called with batch_size=1; keep the single configuration.
    electrons = init_electrons(key_electrons, mol, batch_size=1)[0]
    params = model.init(key_params, electrons)
    params, electrons = jtu.tree_map(cast, (params, electrons))

    # Static inputs are derived from the already-cast electrons.
    static_args = model.get_static_input(electrons)
    return model, electrons, params, static_args


# TODO: add separate testcases for embedding, jastrow, determinant, total_logpsi
@pytest.mark.parametrize("dtype", [jnp.float64])
def test_low_rank_update_logpsi(dtype):
    """Low-rank updates of log|psi| must agree with a full re-evaluation.

    The middle electron is displaced by ``dr = (2, 0, 0)`` twice in a row.
    After each move the value returned by ``model.log_psi_low_rank_update`` is
    compared with ``model(params, electrons, static_args)`` evaluated from
    scratch at the same configuration. The second move starts from the state
    returned by the first update, so it also checks that the updated state is
    usable as input for a further update.
    """
    model, electrons, params, static_args = setup_inputs(dtype)
    rtol = get_relative_tolerance(dtype)

    (_, logpsi_old), state = model.log_psi_with_state(params, electrons, static_args)
    assert logpsi_old.dtype == dtype

    ind_move = np.array(len(electrons) // 2)
    idx_changed = ind_move[None]  # 0-d index -> shape (1,) list of changed electrons
    dr = np.array([2, 0, 0]).astype(dtype)

    # First move: update from the freshly computed state.
    electrons_new = electrons.at[ind_move].add(dr)
    logpsi_new = model(params, electrons_new, static_args)
    (_, logpsi_new_update), state_new = model.log_psi_low_rank_update(
        params, electrons_new, idx_changed, static_args, state
    )

    assert logpsi_new.dtype == dtype
    assert logpsi_new_update.dtype == dtype
    assert jnp.allclose(logpsi_new, logpsi_new_update, rtol=rtol)

    # Second move: update from the state produced by the first update. The
    # reference must be a full evaluation at the *new* configuration, not the
    # value from the previous step (which belongs to different electrons).
    electrons_new_new = electrons_new.at[ind_move].add(dr)
    logpsi_new_new = model(params, electrons_new_new, static_args)
    (_, logpsi_new_new_update), _ = model.log_psi_low_rank_update(
        params, electrons_new_new, idx_changed, static_args, state_new
    )
    assert logpsi_new_new_update.dtype == dtype
    assert jnp.allclose(logpsi_new_new_update, logpsi_new_new, rtol=rtol)


In [2]:
# Interactive replay of the test setup: fetch the (cached) float64 inputs and
# evaluate log|psi| together with the state that later cells feed into
# model.log_psi_low_rank_update.
dtype = jnp.float64
model, electrons, params, static_args = setup_inputs(dtype)
(sign_old, logpsi_old), state = model.log_psi_with_state(params, electrons, static_args)
assert logpsi_old.dtype == dtype

2024-07-07 18:31:52.149127: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [None]:
# Displace the middle electron (index len(electrons) // 2) by dr = (2, 0, 0).
ind_move = np.array(len(electrons) // 2)
idx_changed = ind_move[None]  # same index with a leading axis: shape (1,)
dr = np.array([2, 0, 0]).astype(dtype)
electrons_new = electrons.at[ind_move].add(dr)  # returns a new array; `electrons` is not modified

In [10]:
# Full evaluation at the moved configuration; only the resulting state is kept.
# NOTE(review): `new_state` is reassigned by the random-move loop cell as well,
# so its meaning in later cells depends on execution order.
_, new_state = model.log_psi_with_state(params, electrons_new, static_args)

In [11]:
# First move: compare a from-scratch evaluation at electrons_new with the
# low-rank update started from `state`.
logpsi_new = model(params, electrons_new, static_args)
(sign_new, logpsi_new_update), state_new = model.log_psi_low_rank_update(
    params, electrons_new, idx_changed, static_args, state
)
assert logpsi_new.dtype == dtype
assert logpsi_new_update.dtype == dtype
assert jnp.allclose(logpsi_new, logpsi_new_update, rtol=get_relative_tolerance(dtype))

In [12]:
# Second move of the same electron by the same dr, this time updating from
# `state_new` (the state returned by the first low-rank update) and comparing
# against a from-scratch evaluation at the new configuration.
electrons_new_new = electrons_new.at[ind_move].add(dr)
logpsi_new_new = model(params, electrons_new_new, static_args)
(sign_new_new, logpsi_new_new_update), state_new_new = model.log_psi_low_rank_update(
    params, electrons_new_new, idx_changed, static_args, state_new
)
assert jnp.allclose(logpsi_new_new_update, logpsi_new_new, rtol=get_relative_tolerance(dtype))

In [52]:
# Random single-electron moves: after every move, the low-rank update (started
# from the previous fully recomputed state) must reproduce the log|psi| of a
# full evaluation at the same configuration.
#
# A seeded generator makes a failing step reproducible, and the loop counter is
# named `step` so IPython's last-output variable `_` is not overwritten.
move_rng = np.random.default_rng(0)

e = electrons
(sign, log_psi), state = model.log_psi_with_state(params, e, static_args)
for step in range(10):
    ind_move = move_rng.integers(0, len(electrons), 1)  # shape (1,), high is exclusive
    dr = move_rng.normal(size=3)
    e = e.at[ind_move].add(dr)
    prev_state = state  # state before this move, kept for the diagnostic cells below
    (new_sign, new_log), new_state = model.log_psi_low_rank_update(params, e, ind_move, static_args, state)
    # Recompute from scratch so each update starts from an exact reference state.
    (sign, log_psi), state = model.log_psi_with_state(params, e, static_args)
    assert jnp.allclose(log_psi, new_log, rtol=get_relative_tolerance(dtype)), (log_psi, new_log, step)

[      1       7       9      10      12      19 1000000 1000000 1000000
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      7       9      16      17      18      19 1000000 1000000 1000000
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      3       4       7       9      13      14      15      17      19
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      1       2       3       4       7       9      11      12      14
      17      19 1000000 1000000 1000000 1000000 1000000]
[      1       7       9      10      17      19 1000000 1000000 1000000
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      7       8       9      16      17      18      19 1000000 1000000
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      4       6       7       9      15      16      17      19 1000000
 1000000 1000000 1000000 1000000 1000000 1000000 1000000]
[      2       3       4       7       9      11      12      14      15
      17  

In [26]:
# Norm of the difference in embedding.h_init between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.h_init - state.embedding.h_init)

Array(1.71278835e-14, dtype=float64)

In [27]:
# Norm of the difference in embedding.h0 between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.h0 - state.embedding.h0)

Array(5.61455003e-15, dtype=float64)

In [28]:
# Norm of the difference in embedding.h1 between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.h1 - state.embedding.h1)

Array(1.03040626e-14, dtype=float64)

In [29]:
# Norm of the difference in embedding.HL_dn between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.HL_dn - state.embedding.HL_dn)

Array(1.90916342e-15, dtype=float64)

In [30]:
# Norm of the difference in embedding.HL_up between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.HL_up - state.embedding.HL_up)

Array(2.4490009e-15, dtype=float64)

In [31]:
# Norm of the difference in embedding.h_out between new_state and state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(new_state.embedding.h_out - state.embedding.h_out)

Array(1.88909266e-14, dtype=float64)

In [None]:
# Norm of the difference in orbitals.envelopes between state and new_state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.orbitals.envelopes - new_state.orbitals.envelopes)

Array(7.25748187e-16, dtype=float64)

In [41]:
# Norm of the difference in orbitals.orbitals between state and new_state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.orbitals.orbitals - new_state.orbitals.orbitals)

Array(6.11356715e-15, dtype=float64)

In [44]:
# Norm of the difference in the first entry of determinant.inverses between state and new_state.
# The recorded output (~81) is many orders of magnitude above the round-off-level
# values recorded for the other state fields.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.determinant.inverses[0] - new_state.determinant.inverses[0])

Array(81.27725255, dtype=float64)

In [45]:
# Norm of the difference in the first entry of determinant.matrices between state and new_state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.determinant.matrices[0] - new_state.determinant.matrices[0])

Array(6.11356715e-15, dtype=float64)

In [47]:
# Norm of the difference in determinant.slogdets[0][1] between state and new_state
# (presumably the log|det| part of a (sign, logdet) pair — TODO confirm).
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.determinant.slogdets[0][1] - new_state.determinant.slogdets[0][1])

Array(0.08233283, dtype=float64)

In [46]:
# Norm of the difference in the jastrow field between state and new_state.
# NOTE(review): which cell last assigned these names depends on execution order — confirm.
jnp.linalg.norm(state.jastrow - new_state.jastrow )

Array(0., dtype=float64)

In [32]:
# Splice the rows listed in new_state.embedding.extras into a copy of the
# previous h_out, then report the per-row norm of the difference to
# state.embedding.h_out.
idx = np.asarray(new_state.embedding.extras)
idx = idx[idx < 1000]  # drop large entries — presumably padding sentinels; TODO confirm
np_a = np.array(prev_state.embedding.h_out)  # NumPy copy, so prev_state is not modified
np_a[idx] = np.array(new_state.embedding.h_out[idx])
np.linalg.norm(np_a - state.embedding.h_out,axis=-1)

array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.07556884e-14,
       1.84365445e-15, 1.31614063e-14, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       7.97748611e-15, 0.00000000e+00, 0.00000000e+00, 9.59327177e-16])

In [33]:
# Norm along the last axis of the h_out difference between state and prev_state;
# nonzero entries mark rows that differ between the two states.
np.linalg.norm(state.embedding.h_out - prev_state.embedding.h_out, axis=-1)

array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.58056314e-01,
       1.36269349e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       4.08076676e-03, 0.00000000e+00, 0.00000000e+00, 3.18222863e-03])

In [34]:
# Norm along the last axis of the h_out difference between new_state and prev_state;
# nonzero entries mark rows that differ between the two states.
np.linalg.norm(new_state.embedding.h_out - prev_state.embedding.h_out, axis=-1)

array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.58056314e-01,
       1.36269349e-04, 1.31614063e-14, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       4.08076676e-03, 0.00000000e+00, 0.00000000e+00, 3.18222863e-03])

In [28]:
# Display the current value of ind_move (last assigned by whichever cell ran most recently).
ind_move

array([5])

In [23]:
# Display the current value of idx (the filtered new_state.embedding.extras entries).
idx

array([ 9, 19,  7], dtype=int32)

In [22]:
# Sum over the last axis of the h_out difference between new_state and state.
# NOTE(review): a sum can cancel to zero even when entries differ; the norm is the safer check.
(new_state.embedding.h_out - state.embedding.h_out).sum(-1)

Array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -2.59762239e-01,  0.00000000e+00, -6.93889390e-15,
        0.00000000e+00, -3.89965837e-15,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -2.04523898e-15],      dtype=float64)

In [189]:
# Sum over the last axis of the h_out difference between new_state and prev_state.
# NOTE(review): a sum can cancel to zero even when entries differ; the norm is the safer check.
(new_state.embedding.h_out - prev_state.embedding.h_out).sum(-1)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float64)

In [190]:
# Sum over the last axis of the h_out difference between state and prev_state.
# NOTE(review): a sum can cancel to zero even when entries differ; the norm is the safer check.
(state.embedding.h_out - prev_state.embedding.h_out).sum(-1)

Array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.24866167e+01,
       -3.66057865e-06,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        1.93650218e-03,  0.00000000e+00,  0.00000000e+00,  1.93818585e-03],      dtype=float64)

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-2.08213926e-02,  3.87600502e-02,  2.00341720e-02, ...,
         -4.01084727e-04,  6.15049317e-02, -7.62647006e-02],
        [ 1.81537854e-06,  1.82269497e-05,  1.29904522e-05, ...,
         -9.27508612e-06,  1.09319851e-05, -1.33177839e-05]],

       [[-4.00713198e-02,  1.61914462e-01,  1.00967009e-01, ...,
         -4.09011428e-01, -2.02309909e-01,  2.88780137e-01],
        [ 3.50825101e-06,  1.63049997e-05,  1.52910623e-05, ...,
          3.46505555e-07,  1.76055921e-05, -1.90350749e-05]],

       ...,

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000