In [1]:
%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp
from sparse_wf.model.wave_function import MoonLikeWaveFunction
from sparse_wf.system import get_molecule
from sparse_wf.mcmc import init_electrons

In [68]:
molecule_args = {
    'method': 'chain',
    'chain_args': {
        'element': 'H',
        'distance': 1.8,
        'n': 200
    },
    'basis': 'sto-3g'
}
model_args = data = {
    "n_determinants": 16,
    "n_envelopes": 8,
    "embedding": {
        "cutoff": 10.0,
        "feature_dim": 256,
        "nuc_mlp_depth": 4,
        "pair_mlp_widths": [16, 8],
        "pair_n_envelopes": 32
    },
    "jastrow": {
        "e_e_cusps": "psiformer",
        "use_log_jastrow": True,
        "use_mlp_jastrow": True,
        "mlp_depth": 2,
        "mlp_width": 64
    }
}
mol = get_molecule(molecule_args)
wf = MoonLikeWaveFunction.create(mol, **model_args)

In [69]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
electrons = init_electrons(key, mol, 1)[0]
params = wf.init(subkey, electrons)
static = wf.get_static_input(electrons)

In [70]:
logpsi, state = wf.logpsi_with_state(params, electrons, static)

In [73]:
delta = jnp.concatenate([jnp.array([[0.1, 0, 0]]), jnp.zeros((199, 3))], axis=0)
changed_idx = jnp.array([0])

In [74]:
new_logpsi, new_state = wf.low_rank_update(params, electrons + delta, changed_idx, static, state)

In [75]:
real_logpsi, real_state = wf.logpsi_with_state(params, electrons + delta, static)

In [76]:
jax.tree_map(lambda x, y: jnp.linalg.norm(x - y), new_state, real_state)

LowRankState(embedding=MoonState(h_init=Array(4.3259074e-06, dtype=float32), h_init_same=Array(2.863675e-06, dtype=float32), h_init_diff=Array(2.865884e-06, dtype=float32), h0=Array(3.7399934e-06, dtype=float32), h1=Array(4.7234503e-06, dtype=float32), HL_up=Array(1.1350817e-05, dtype=float32), HL_dn=Array(1.1127907e-05, dtype=float32), h_out=Array(8.334951e-06, dtype=float32)), orbitals=Array(1.8817977e-05, dtype=float32), determinant=LogPsiState(matrices=[Array(1.8817977e-05, dtype=float32)], inverses=[Array(18.013235, dtype=float32)], slogdets=[(Array(0., dtype=float32), Array(0.00099631, dtype=float32))]), jastrow=Array(2.3823843e-07, dtype=float32))

In [84]:
(new_logpsi - real_logpsi) / real_logpsi, real_logpsi, new_logpsi, logpsi

(Array(-0., dtype=float32),
 Array(-135.00717, dtype=float32),
 Array(-135.00717, dtype=float32),
 Array(-135.126, dtype=float32))

# Benchmarking

In [78]:
electrons = init_electrons(key, mol, 32)

In [79]:
import functools
@functools.partial(jax.jit, static_argnums=1)
@functools.partial(jax.vmap, in_axes=(0, None))
def normal_fwd(electrons, static):
    return wf.logpsi_with_state(params, electrons, static)[0]

In [80]:
static = wf.get_static_input(electrons)
%time jax.block_until_ready(normal_fwd(electrons, static));
%timeit jax.block_until_ready(normal_fwd(electrons, static));

CPU times: user 9.85 s, sys: 369 ms, total: 10.2 s
Wall time: 11.4 s
98.8 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [81]:
import functools
@functools.partial(jax.jit, static_argnums=1)
@functools.partial(jax.vmap, in_axes=(0, None, 0))
def fast_fwd(electrons, static, state):
    return wf.low_rank_update(params, electrons, jnp.array([0]), static, state)

In [82]:
static = wf.get_static_input(electrons)
_, state = jax.vmap(wf.logpsi_with_state, in_axes=(None, 0, None))(params, electrons, static)
%time jax.block_until_ready(fast_fwd(electrons, static, state));
%timeit jax.block_until_ready(fast_fwd(electrons, static, state));

CPU times: user 8.18 s, sys: 194 ms, total: 8.37 s
Wall time: 10.1 s
12.3 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [83]:
normal_fwd(electrons, static) - fast_fwd(electrons, static, state)[0]

Array([-1.5258789e-05,  0.0000000e+00,  0.0000000e+00, -3.0517578e-05,
       -1.5258789e-05, -3.0517578e-05,  0.0000000e+00,  0.0000000e+00,
        6.1035156e-05, -3.0517578e-05,  2.1972656e-03, -1.5258789e-05,
        1.5258789e-05, -1.5258789e-05,  6.1035156e-05,  0.0000000e+00,
       -6.1035156e-05, -1.5258789e-05, -4.5776367e-05,  3.0517578e-05,
        4.5776367e-05,  1.5258789e-05, -1.5258789e-05,  0.0000000e+00,
        2.4414062e-04,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        1.1901855e-03,  1.5258789e-05,  1.5258789e-05, -9.9182129e-05],      dtype=float32)