In [122]:
import jax
import jax.numpy as jnp
from jax import vmap
from jax import random
import numpy as np
class testClass:
    def __init__(self, N, dim):
        self._N = N
        self._dim = dim
        self.la = jax.numpy.linalg

    def log_wfi_dumb(self, r, params): # dumb because JW does not need to be (N, N), it can be (N*(N-1)//2, )
        epsilon = 1e-6  # Small epsilon value
        r_cpy = r.reshape(-1, self._N, self._dim)
        r_diff = r_cpy[:, None, :, :] - r_cpy[:, :, None, :]
        r_dist = self.la.norm(r_diff + epsilon, axis=-1)  # Add epsilon to avoid nan

        rij = jnp.triu(r_dist, k=1)

        x = jnp.einsum('nij,ij->n', rij, params['JW_dumb'])
        
        return x.squeeze(-1)
    
    def log_wfi_smart(self, r, params):
        epsilon = 1e-6  # Small epsilon value for stability
        r_cpy = r.reshape(-1, self._N, self._dim)
        r_diff = r_cpy[:, None, :, :] - r_cpy[:, :, None, :]
        r_dist = self.la.norm(r_diff + epsilon, axis=-1)  # Add epsilon to avoid nan
        
        rij = jnp.triu(r_dist, k=1)
        
        # Generate indices for the upper triangular matrix excluding diagonal
        triu_indices = jnp.triu_indices(self._N, k=1)
        k = triu_indices[0] * (self._N - 1) - triu_indices[0] * (triu_indices[0] + 1) // 2 + triu_indices[1] - triu_indices[0] - 1
        
        # Use these indices to select the corresponding elements from the JW_smart vector
        JW_smart = params['JW_smart'][k]
        
        # Perform the element-wise multiplication and sum
        x = jnp.einsum('nij,i->n', rij, JW_smart)
        
        return x.squeeze(-1)

    
    # jax grad
    def grad_log_wfi_dumb(self, r, params):
        return vmap(jax.grad(self.log_wfi_dumb, argnums=0), in_axes=(0, None))(r, params)


    def grad_log_wfi_smart(self, r, params):
        return vmap(jax.grad(self.log_wfi_smart, argnums=0), in_axes=(0, None))(r, params)
        
        
    
# initialize r with shape (nbatch, particles * dim)
nbatch = 1
N = 2
dim = 1
key = random.PRNGKey(1)
r = random.normal(key, (nbatch, N * dim))

print("r", r.reshape(-1, N, dim))
# initialize params

params_dumb = {'JW_dumb': random.normal(key, (N, N))}
params_smart = {'JW_smart': random.normal(key, (N * (N-1) // 2, ))}

# grad
# instantiate the class
your_class = testClass(N, dim)
your_class.log_wfi_dumb(r, params_dumb)

#print the forward pass
log_wfi = your_class.log_wfi_dumb(r, params_dumb)
print("jax log_wfi dumb", log_wfi, "with shape", log_wfi.shape)
log_wfi = your_class.log_wfi_smart(r, params_smart)
print("jax log_wfi smart", log_wfi, "with shape", log_wfi.shape)

grad_log_wfi = your_class.grad_log_wfi_dumb(r, params_dumb)
print("jax grad", grad_log_wfi, "with shape", grad_log_wfi.shape)



r [[[-0.11617039]
  [ 2.2125063 ]]]
jax log_wfi dumb -1.0731523 with shape ()
jax log_wfi smart -2.7578166 with shape ()
jax grad [[ 0.4608419 -0.4608419]] with shape (1, 2)


In [13]:
r = np.random.randn(nbatch, N * dim)
# mode 1
r_reshaped = r.reshape(-1, N, dim)
for i in range(N):
    # access particle i
    print("particle", i, r_reshaped[:, i, :])

# mode 2
# instead of reshaping, use strided access
for i in range(0, N*dim, dim):
    print("particle", i // dim, r[i:i+dim])

Array([[ 0.18784384, -1.2833426 , -0.2710917 ,  1.2490594 ,  0.24447003]],      dtype=float32)