In [122]:
import jax
import jax.numpy as jnp
from jax import vmap
from jax import random
import numpy as np
class testClass:
    def __init__(self, N, dim):
        self._N = N
        self._dim = dim
        self.la = jax.numpy.linalg

    def log_wfi_dumb(self, r, params): # dumb because JW does not need to be (N, N), it can be (N*(N-1)//2, )
        epsilon = 1e-6  # Small epsilon value
        r_cpy = r.reshape(-1, self._N, self._dim)
        r_diff = r_cpy[:, None, :, :] - r_cpy[:, :, None, :]
        r_dist = self.la.norm(r_diff + epsilon, axis=-1)  # Add epsilon to avoid nan

        rij = jnp.triu(r_dist, k=1)

        x = jnp.einsum('nij,ij->n', rij, params['JW_dumb'])
        
        return x.squeeze(-1)
    
    def log_wfi_smart(self, r, params):
        epsilon = 1e-6  # Small epsilon value for stability
        r_cpy = r.reshape(-1, self._N, self._dim)
        r_diff = r_cpy[:, None, :, :] - r_cpy[:, :, None, :]
        r_dist = self.la.norm(r_diff + epsilon, axis=-1)  # Add epsilon to avoid nan
        
        rij = jnp.triu(r_dist, k=1)
        
        # Generate indices for the upper triangular matrix excluding diagonal
        triu_indices = jnp.triu_indices(self._N, k=1)
        k = triu_indices[0] * (self._N - 1) - triu_indices[0] * (triu_indices[0] + 1) // 2 + triu_indices[1] - triu_indices[0] - 1
        
        # Use these indices to select the corresponding elements from the JW_smart vector
        JW_smart = params['JW_smart'][k]
        
        # Perform the element-wise multiplication and sum
        x = jnp.einsum('nij,i->n', rij, JW_smart)
        
        return x.squeeze(-1)

    
    # jax grad
    def grad_log_wfi_dumb(self, r, params):
        return vmap(jax.grad(self.log_wfi_dumb, argnums=0), in_axes=(0, None))(r, params)


    def grad_log_wfi_smart(self, r, params):
        return vmap(jax.grad(self.log_wfi_smart, argnums=0), in_axes=(0, None))(r, params)
        
        
    
# initialize r with shape (nbatch, particles * dim)
nbatch = 1
N = 2
dim = 1
key = random.PRNGKey(1)
r = random.normal(key, (nbatch, N * dim))

print("r", r.reshape(-1, N, dim))
# initialize params

params_dumb = {'JW_dumb': random.normal(key, (N, N))}
params_smart = {'JW_smart': random.normal(key, (N * (N-1) // 2, ))}

# grad
# instantiate the class
your_class = testClass(N, dim)
your_class.log_wfi_dumb(r, params_dumb)

#print the forward pass
log_wfi = your_class.log_wfi_dumb(r, params_dumb)
print("jax log_wfi dumb", log_wfi, "with shape", log_wfi.shape)
log_wfi = your_class.log_wfi_smart(r, params_smart)
print("jax log_wfi smart", log_wfi, "with shape", log_wfi.shape)

grad_log_wfi = your_class.grad_log_wfi_dumb(r, params_dumb)
print("jax grad", grad_log_wfi, "with shape", grad_log_wfi.shape)



r [[[-0.11617039]
  [ 2.2125063 ]]]
jax log_wfi dumb -1.0731523 with shape ()
jax log_wfi smart -2.7578166 with shape ()
jax grad [[ 0.4608419 -0.4608419]] with shape (1, 2)


In [144]:
nbatch = 1
N = 2
dim = 1
r = np.random.randn(nbatch, N * dim)
#r = np.random.randn(N * dim,)
# mode 1
r_reshaped = r.reshape(-1, N, dim)
for i in range(N):
    # access particle i
    print("particle", i, r_reshaped[:, i, :])
print("=====")
# mode 2
# instead of reshaping, use strided access
for i in range(0, N*dim, dim):
    print("particle", i//dim, r[..., i:i+dim])

particle 0 [[0.25483576]]
particle 1 [[-0.71301476]]
=====
particle 0 [[0.25483576]]
particle 1 [[-0.71301476]]


In [46]:
import numpy.polynomial.hermite as P
def hermite(vals, deg, dim):
    """
    Compute the product of Hermite polynomials for the given values, degree, and dimension.

    Parameters:
    - vals: Array-like of values for which to compute the Hermite polynomial product.
    - deg: The degree of the Hermite polynomial.
    - dim: The dimension, indicating how many values and subsequent polynomials to consider.

    Returns:
    - The product of Hermite polynomials for the given inputs.
    """
    # Error handling for input parameters
    if not isinstance(vals, list) or not isinstance(deg, int) or not isinstance(dim, int):
        raise ValueError("Invalid input types for vals, deg, or dim.")
    if len(vals) != dim:
        raise ValueError("Dimension mismatch between 'vals' and 'dim'.")

    # Compute the product of Hermite polynomials across the given dimensions
    hermite_product = 1
    for val in vals:
        hermite_poly = P.Hermite([0] * deg + [1])(val)
        hermite_product *= hermite_poly

    return hermite_product

print(hermite([2, 2], 1, dim=2))


ValueError: Dimension mismatch between 'vals' and 'dim'.

In [4]:
def generate_degrees(nparticles, dim):
    max_comb = nparticles // 2
    combinations = [[0] * dim]
    seen = {tuple(combinations[0])}
    
    while len(combinations) < max_comb:
        new_combinations = []
        for comb in combinations:
            for i in range(dim):
                # Try incrementing each dimension by 1
                new_comb = comb.copy()
                new_comb[i] += 1
                new_comb_tuple = tuple(new_comb)
                if new_comb_tuple not in seen:
                    seen.add(new_comb_tuple)
                    new_combinations.append(new_comb)
                    if len(seen) == max_comb:
                        return np.array(combinations + new_combinations)
        combinations += new_combinations
    
    return np.array(combinations)
dim = 3
nparticles = 10
combinations_v2 = generate_degrees(nparticles, dim)
combinations_v2


array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0],
       [0, 0, 1],
       [2, 0, 0]])

In [33]:
import jax.numpy as jnp
import jax
from numpy import polynomial as P
import numpy as np
nparticles = 10
dim = 2



def hermite(r, degs):
    """
    Compute the product of Hermite polynomials for the given values, degree, and dimension.

    Parameters:
    - vals: Array-like of values for which to compute the Hermite polynomial product. It should be of shape (nbatch, dim)
    - degs: The degrees of the Hermite polynomials.
    - dim: The dimension, indicating how many values and subsequent polynomials to consider.

    Returns:
    - The product of Hermite polynomials for the given inputs.

    #TODO: move this to a helper function
    """
    # Error handling for input parameters
    #if not isinstance(vals, list) or not isinstance(deg, int) or not isinstance(self.dim, int):
    #    raise ValueError("Invalid input types for vals, deg, or dim.")
    #if len(vals) != dim:
    #    raise ValueError("Dimension mismatch between 'vals' and 'dim'.")


    # Compute the product of Hermite polynomials across the given dimensions
    hermite_product = 1
    for batch in range(r.shape[0]):
        #cartesian of r[batch] and degs
        
        for i in range(len(degs)):
            deg = degs[i]
            r_ = r[batch][i]
            #print(f"print(P.Hermite([0] * {deg} + [1])({r_}))")
            hermite_poly = P.Hermite([0] * int(deg) + [1])(r_)
            hermite_product *= hermite_poly


    return hermite_product


def log_slater(r):
    """
    Decomposed spin Slater determinant in log domain.
    ln psi = ln det (D(up)) + ln det (D(down))
    In our ground state, half of the particles are spin up and half are spin down.
    We will also add the 1/sqrt(N!) normalization factor here.

    D = |phi_1(r_1) phi_2(r_1) ... phi_n(r_1)| 
        |phi_1(r_2) phi_2(r_2) ... phi_n(r_2)| 
        |   ...         ...          ...     | 
        |phi_1(r_n) phi_2(r_n) ... phi_n(r_n)| 

    where phi_i is the i-th single particle wavefunction, in our case it is a hermite polynomial.
    """
    A = nparticles // 2
    r = r.reshape(-1, nparticles, dim)
    r_up = r[:, : A, :]
    r_down = r[:, A :, :]
    slater_fact = np.log(np.sqrt(np.math.factorial(nparticles)))

    # Compute the Slater determinant for the spin up particles
    D_up = jnp.zeros((r.shape[0], A, A))
    D_down = jnp.zeros((r.shape[0], A, A))

    degree_combs = generate_degrees(nparticles, dim)

    for part in range(A):
        for j in range(A):
            degrees = degree_combs[j] 
            
            #print("====hermite", hermite(r_up[:, part, :], degrees))
            
            D_up = D_up.at[:, part, j].set(hermite(r_up[:, part, :], degrees))
            D_down = D_down.at[:, part, j].set(hermite(r_down[:, part, :], degrees))

    # Compute the Slater determinant for the spin down particles
    log_slater_up = jnp.linalg.slogdet(D_up)[0]
    log_slater_down = jnp.linalg.slogdet(D_down)[0]

    return log_slater_up + log_slater_down - 0.5 * slater_fact # we can precompute this. The minus sign is because we are computing the log of the inverse of the factorial


r = jnp.array(np.random.randn(1, nparticles * dim))
log_slater(r)

  slater_fact = np.log(np.sqrt(np.math.factorial(nparticles)))


print(P.Hermite([0] * 0 + [1])(-0.4519834816455841))
print(P.Hermite([0] * 0 + [1])(-0.5464293956756592))
print(P.Hermite([0] * 0 + [1])(-0.35051533579826355))
print(P.Hermite([0] * 0 + [1])(0.44277459383010864))
print(P.Hermite([0] * 1 + [1])(-0.4519834816455841))
print(P.Hermite([0] * 0 + [1])(-0.5464293956756592))
print(P.Hermite([0] * 1 + [1])(-0.35051533579826355))
print(P.Hermite([0] * 0 + [1])(0.44277459383010864))
print(P.Hermite([0] * 0 + [1])(-0.4519834816455841))
print(P.Hermite([0] * 1 + [1])(-0.5464293956756592))
print(P.Hermite([0] * 0 + [1])(-0.35051533579826355))
print(P.Hermite([0] * 1 + [1])(0.44277459383010864))
print(P.Hermite([0] * 2 + [1])(-0.4519834816455841))
print(P.Hermite([0] * 0 + [1])(-0.5464293956756592))
print(P.Hermite([0] * 2 + [1])(-0.35051533579826355))
print(P.Hermite([0] * 0 + [1])(0.44277459383010864))
print(P.Hermite([0] * 1 + [1])(-0.4519834816455841))
print(P.Hermite([0] * 1 + [1])(-0.5464293956756592))
print(P.Hermite([0] * 1 + [1])(-0.35051533

Array([-3.7761033], dtype=float32)

In [93]:
dim = 3
def det(r):
    det_r = jnp.linalg.det(r)
    return det_r

def log_det(r):
    log_det_r = jnp.linalg.slogdet(r)[1]
    return log_det_r

def wf(r):
    return jnp.exp(-jnp.linalg.norm(r)**2)

def log_wf(r):
    return -jnp.linalg.norm(r)**2

r = jnp.array(np.random.randn(dim, dim)) 

grad_det = jax.grad(det)(r)
grad_wf = jax.grad(wf)(r)


print("1/wf * grad_wf", grad_wf / wf(r))

print("grad ln ", jax.grad(log_wf)(r))

print("\n grad ln wf == 1/wf * grad_wf", (abs(jax.grad(log_wf)(r) - grad_wf / wf(r)) < 0.001).all())

print("\n  1/det * grad det", grad_det / det(r))

print("==== det", det(r))
print("grad ln det", jax.grad(log_det)(r))



print("\n grad ln det == 1/det * grad_det", (abs(jax.grad(log_det)(r) - grad_det / det(r)) < 0.001).all())


1/wf * grad_wf [[-2.0443762e-04  3.5990626e-04 -7.4185060e-05]
 [-1.9279356e-05 -5.8702957e-05  4.9267084e-05]
 [ 1.7390294e-04  5.0164206e-05 -2.6005344e-04]]
grad ln  [[-2.0443763e-04  3.5990626e-04 -7.4185060e-05]
 [-1.9279356e-05 -5.8702961e-05  4.9267088e-05]
 [ 1.7390294e-04  5.0164210e-05 -2.6005344e-04]]

 grad ln wf == 1/wf * grad_wf True

  1/det * grad det [[12654.426   3515.1284  9140.335 ]
 [88889.74   65342.605  72046.914 ]
 [13230.24   11376.392  18732.555 ]]
==== det 2.5276656e-13
grad ln det [[12654.425  3515.128  9140.333]
 [88889.734 65342.6   72046.91 ]
 [13230.239 11376.391 18732.553]]

 grad ln det == 1/det * grad_det False


In [None]:
nparticles = 6

pade_aij = jnp.zeros((nparticles, nparticles))
for i in range(nparticles):
    for j in range(i+1, nparticles):
        # first N//2 particles are spin up, the rest are spin down
        # there is a more efficient way to do this for sure
        if i < nparticles // 2 and j < nparticles // 2:
            pade_aij = pade_aij.at[i, j].set(1 / (dim + 1))
        elif i >= nparticles // 2 and j >= nparticles // 2:
            pade_aij = pade_aij.at[i, j].set(1 / (dim + 1))
        else: 
            pade_aij = pade_aij.at[i, j].set(1 / (dim - 1))