In [1]:
import jax.numpy as jnp
from jax.config import config

from jaxsce import densities

config.update("jax_enable_x64", True)

name = "sqrt_r"
kwargs = {"Nel": 10}



In [2]:
def test_Nel_density(name: str, kwargs: dict):
    # Intitialize density
    density = densities.get_density(name, **kwargs)

    # Check that density has all the required attributes
    assert density.Nel
    assert jnp.all(density.a >= 0.0)
    assert density.a.shape == (density.Nel,)
    assert density.LDA_int
    assert density.GEA_int

    # Test positivity of density
    rho = density.rho(density.a)
    assert jnp.all(rho >= 0.0)

    # Test that cumulant integrates to the correct number of electrons
    Ne = density.Ne(density.a)
    assert jnp.allclose(Ne, jnp.arange(density.Nel))

    # Test that the derivative of the cumulant is correct
    Ne_deriv = density.Ne_deriv(density.a)
    assert jnp.allclose(Ne_deriv, 4 * jnp.pi * density.a**2 * rho)

    # Test if the second derivative works
    Ne_deriv2 = density.Ne_deriv2(density.a)

    # Test that the co-Cumulant is positive
    coNe = density.coNe(density.a)
    assert jnp.all(coNe >= 0.0)

    # Test that the derivative of the co-Cumulant works
    density.coNe_deriv(density.a)

    # Test that the inverse of the cumulant is correct
    invNe = density.invNe(Ne)
    assert jnp.allclose(invNe, density.a)

    # Test that the derivative of the inverse of the cumulant is correct
    invNe_deriv = density.invNe_deriv(Ne, invNe)
    assert jnp.allclose(invNe_deriv, 1.0 / Ne_deriv)

    # Test that the co-motion functions are positive
    co_motion_function = density.co_motion_function(density.a[1])
    assert jnp.all(co_motion_function >= 0.0)
    co_motion_functions = density.co_motion_functions(density.a)
    assert jnp.all(co_motion_functions >= 0.0)
    assert jnp.allclose(co_motion_function, co_motion_functions[1])

    # Test that the co-motion functions and its derivatives computed from Ne
    # are the same as those computed from a
    co_motion_function_Ne = density.co_motion_function_Ne(Ne[1])
    assert jnp.allclose(co_motion_function_Ne, co_motion_function)
    co_motion_functions_Ne = density.co_motion_functions_Ne(Ne)
    assert jnp.allclose(co_motion_functions_Ne, co_motion_functions)
    co_motion_function_deriv = density.co_motion_function_deriv(density.a[1])
    co_motion_function_deriv_Ne = density.co_motion_function_deriv_Ne(Ne[1])
    assert jnp.allclose(co_motion_function_deriv, co_motion_function_deriv_Ne*Ne_deriv[1])
    co_motion_functions_deriv = density.co_motion_functions_deriv(density.a)
    co_motion_functions_deriv_Ne = density.co_motion_functions_deriv_Ne(Ne)
    # only test derivatives for a[1:] because co-motion functions are ill behaved at zero
    assert jnp.allclose(co_motion_functions_deriv[1:], co_motion_functions_deriv_Ne[1:]*Ne_deriv[1:, None])
    co_motion_function_deriv2 = density.co_motion_function_deriv2(density.a[1])
    co_motion_function_deriv2_Ne = density.co_motion_function_deriv2_Ne(Ne[1])
    assert jnp.allclose(co_motion_function_deriv2, co_motion_function_deriv2_Ne*Ne_deriv[1]**2+co_motion_function_deriv_Ne*Ne_deriv2[1])
    co_motion_functions_deriv2 = density.co_motion_functions_deriv2(density.a)
    co_motion_functions_deriv2_Ne = density.co_motion_functions_deriv2_Ne(Ne)
    assert jnp.allclose(co_motion_functions_deriv2[1:], co_motion_functions_deriv2_Ne[1:]*Ne_deriv[1:, None]**2+co_motion_functions_deriv_Ne[1:, :]*Ne_deriv2[1:, None])

    # Test that vH is positive
    vH = density.vH(density.a)
    assert jnp.all(vH >= 0.0)

    # Test that the derivative of vH works
    density.vH_deriv(density.a)


In [5]:
test_Nel_density(name, kwargs)

AssertionError: 

In [22]:
# Intitialize density
density = densities.get_density(name, **kwargs)

# Check that density has all the required attributes
assert density.Nel
assert jnp.all(density.a >= 0.0)
assert density.a.shape == (density.Nel,)
assert density.LDA_int
assert density.GEA_int

# Test positivity of density
rho = density.rho(density.a)
assert jnp.all(rho >= 0.0)

# Test that cumulant integrates to the correct number of electrons
Ne = density.Ne(density.a)
assert jnp.allclose(Ne, jnp.arange(density.Nel))

# Test that the derivative of the cumulant is correct
Ne_deriv = density.Ne_deriv(density.a)
assert jnp.allclose(Ne_deriv, 4 * jnp.pi * density.a**2 * rho)

# Test if the second derivative works
Ne_deriv2 = density.Ne_deriv2(density.a)

# Test that the co-Cumulant is positive
coNe = density.coNe(density.a)
assert jnp.all(coNe >= 0.0)

# Test that the derivative of the co-Cumulant works
density.coNe_deriv(density.a)

# Test that the inverse of the cumulant is correct
invNe = density.invNe(Ne)
assert jnp.allclose(invNe, density.a)

# Test that the derivative of the inverse of the cumulant is correct
invNe_deriv = density.invNe_deriv(Ne, invNe)
assert jnp.allclose(invNe_deriv, 1.0 / Ne_deriv)

# Test that the co-motion functions are positive
co_motion_function = density.co_motion_function(density.a[1])
assert jnp.all(co_motion_function >= 0.0)
co_motion_functions = density.co_motion_functions(density.a)
assert jnp.all(co_motion_functions >= 0.0)
assert jnp.allclose(co_motion_function, co_motion_functions[1])

# Test that the co-motion functions and its derivatives computed from Ne
# are equal to the ones computed from r=a
co_motion_function_Ne = density.co_motion_function_Ne(Ne[1])
assert jnp.allclose(co_motion_function_Ne, co_motion_function)
co_motion_functions_Ne = density.co_motion_functions_Ne(Ne)
assert jnp.allclose(co_motion_functions_Ne, co_motion_functions)

# We don't want to compute the derivatives at a, because the co-motions
# there are singular
r = jnp.array([0.5, 1., 1.5, 2.])
Ne = density.Ne(r)
Ne_deriv = density.Ne_deriv(r)
density.Ne_deriv2(r)
co_motion_function_deriv = density.co_motion_function_deriv(r[1])
co_motion_function_deriv_Ne = density.co_motion_function_deriv_Ne(Ne[1])
assert jnp.allclose(co_motion_function_deriv, co_motion_function_deriv_Ne*Ne_deriv[1])
co_motion_functions_deriv = density.co_motion_functions_deriv(r)
co_motion_functions_deriv_Ne = density.co_motion_functions_deriv_Ne(Ne)
# only test derivatives for a[1:] because co-motion functions are ill behaved at zero
assert jnp.allclose(co_motion_functions_deriv, co_motion_functions_deriv_Ne*Ne_deriv[:, None])

# TODO: Second derivatives need a sign fix!
# co_motion_function_deriv2 = density.co_motion_function_deriv2(r[1])
# co_motion_function_deriv2_Ne = density.co_motion_function_deriv2_Ne(Ne[1])
# assert jnp.allclose(co_motion_function_deriv2, co_motion_function_deriv2_Ne*Ne_deriv[1]**2+co_motion_function_deriv_Ne*Ne_deriv2[1])
# co_motion_functions_deriv2 = density.co_motion_functions_deriv2(density.a)
# co_motion_functions_deriv2_Ne = density.co_motion_functions_deriv2_Ne(Ne)
# assert jnp.allclose(co_motion_functions_deriv2[:], co_motion_functions_deriv2_Ne[:]*Ne_deriv[:, None]**2+co_motion_functions_deriv_Ne*Ne_deriv2[:, None])

# Test that vH is positive
vH = density.vH(density.a)
assert jnp.all(vH >= 0.0)

# Test that the derivative of vH works
density.vH_deriv(density.a)

AssertionError: 

In [21]:
# Test that the co-motion functions and its derivatives computed from Ne
r = jnp.array([0.5, 1., 1.5, 2.])
Ne = density.Ne(r)
Ne_deriv = density.Ne_deriv(r)
Ne_deriv2 = density.Ne_deriv2(r)
co_motion_function_deriv = density.co_motion_function_deriv(r[0])
co_motion_function_deriv_Ne = density.co_motion_function_deriv_Ne(Ne[0])
assert jnp.allclose(co_motion_function_deriv, co_motion_function_deriv_Ne*Ne_deriv[0])
co_motion_functions_deriv = density.co_motion_functions_deriv(r)
co_motion_functions_deriv_Ne = density.co_motion_functions_deriv_Ne(Ne)
# only test derivatives for a[1:] because co-motion functions are ill behaved at zero
assert jnp.allclose(co_motion_functions_deriv, co_motion_functions_deriv_Ne*Ne_deriv[:, None])
# co_motion_function_deriv2 = density.co_motion_function_deriv2(r[0])
# co_motion_function_deriv2_Ne = density.co_motion_function_deriv2_Ne(Ne[0])
# assert jnp.allclose(co_motion_function_deriv2, co_motion_function_deriv2_Ne*Ne_deriv[0]**2+co_motion_function_deriv_Ne*Ne_deriv2[0])
co_motion_functions_deriv2 = density.co_motion_functions_deriv2(r)
co_motion_functions_deriv2_Ne = density.co_motion_functions_deriv2_Ne(Ne)
assert jnp.allclose(co_motion_functions_deriv2, co_motion_functions_deriv2_Ne*Ne_deriv[:, None]**2+co_motion_functions_deriv_Ne*Ne_deriv2[:, None])

# Test that vH is positive
vH = density.vH(density.a)
assert jnp.all(vH >= 0.0)

# Test that the derivative of vH works
density.vH_deriv(density.a)

AssertionError: 

In [14]:
co_motion_function_deriv2

Array([ 0.        , -0.66705519,  0.66295086, -0.68763517,  0.73280774,
       -0.86073055,  1.01538529, -1.49647233,  2.34653081,  4.17741074],      dtype=float64)

In [19]:
co_motion_function_deriv2_Ne*Ne_deriv[1]**2+co_motion_function_deriv_Ne*Ne_deriv2[1]

Array([-1.77635684e-15,  9.16752340e-01,  6.62950865e-01,  6.74666021e-01,
        7.32807737e-01,  7.10161210e-01,  1.01538529e+00,  9.10457494e-01,
        2.34653081e+00,  4.17741074e+00], dtype=float64)

In [13]:
co_motion_function_deriv2-co_motion_function_deriv2_Ne*Ne_deriv[1]**2-co_motion_function_deriv_Ne*Ne_deriv2[1]

Array([ 1.77635684e-15, -1.58380753e+00,  2.22044605e-16, -1.36230119e+00,
        0.00000000e+00, -1.57089176e+00,  0.00000000e+00, -2.40692983e+00,
        2.22044605e-16, -1.77635684e-15], dtype=float64)

In [29]:
density.co_motion_function_deriv2(1.)

TypeError: 'DynamicJaxprTracer' object is not callable

In [12]:
print(jnp.where(jnp.isnan(co_motion_functions_deriv - co_motion_functions_deriv_Ne*Ne_deriv)))

(Array([0, 0, 2, 2, 8, 8], dtype=int64), Array([0, 9, 1, 8, 2, 7], dtype=int64))


In [20]:
jnp.allclose(co_motion_functions_deriv-co_motion_functions_deriv_Ne*Ne_deriv, jnp.zeros_like(co_motion_functions_deriv), equal_nan=True)

Array(False, dtype=bool)

In [22]:
co_motion_functions_deriv-co_motion_functions_deriv_Ne*Ne_deriv

Array([[            nan,  7.75612086e-01, -1.00000000e+00,
         1.00578541e+00, -1.00000000e+00,  1.13230699e+00,
        -1.00000000e+00,  1.37543184e+00, -1.00000000e+00,
                    nan],
       [ 1.00000000e+00,  0.00000000e+00, -2.07766726e-01,
         2.81840179e-01, -2.96510506e-01,  2.28717305e-01,
        -1.54468748e-01, -6.49775961e-02,  8.45121880e-01,
        -1.66335432e+00],
       [ 1.00000000e+00,             nan,  1.11022302e-16,
         7.99992761e-02, -8.30831434e-02,  5.22751362e-03,
         2.11312892e-01, -3.05942281e-01,             nan,
        -1.33859353e+00],
       [ 1.00000000e+00,  3.92447713e-01,  7.95527279e-02,
         0.00000000e+00,  8.52998632e-03, -6.88760092e-02,
         6.58923404e-01, -3.49746486e-01, -1.89034919e+00,
        -1.08306234e+00],
       [ 1.00000000e+00,  2.98174890e-01,  8.30831434e-02,
        -2.12135200e+10,  0.00000000e+00, -6.81737399e-02,
         1.72809428e+15, -3.21755994e-01, -1.02793070e+00,
        -8.