In [43]:
import jax.numpy as jnp
from jax import grad, jit, vmap

from jax import random

In [67]:
G = {
    "beta": 5.0,
    "lambda0": 0.2,
    "lambda1": 0.3,
    "lambda2": 0.5,
}

    
@jit
def quadruple(a1, a2, b1, b2) -> float: 
    return jnp.dot(jnp.cross(a1, b1), jnp.cross(a2, b2))
@jit
def V(r : float, S : float) -> float:
    return jnp.exp(-r) - S * jnp.exp(- r / G["beta"])

@jit
def S(r, p1, q1, p2, q2) -> float:
    S1 = quadruple(p1, p2, r, r)
    S2 = quadruple(p1, p2, q1, q2)
    S3 = quadruple(q1, q2, r, r)

    return G["lambda0"]*S1 + G["lambda1"]*S2 + G["lambda2"]*S3

@jit
def unpack_cellrow(cellrow : jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    pos = cellrow[0]
    p = cellrow[1]
    q = cellrow[2]
    return pos, p, q

@jit
def U(cellrow1 : jnp.ndarray, cellrow2 : jnp.ndarray) -> float:

    pos1, p1, q1 = unpack_cellrow(cellrow1)
    pos2, p2, q2 = unpack_cellrow(cellrow2)
    
   

    _dir = pos1 - pos2

    r = jnp.sqrt(jnp.dot(_dir, _dir))

    s = S(_dir, p1, q1, p2, q2)

    v = V(r, s)
    # take care of the case where the cells are the same
    print("v", v)
    return jnp.where(jnp.array_equal(pos1, pos2), 0.0, v)



def get_IC(N : int) -> jnp.ndarray:

    # N random points in 3D
    key = random.PRNGKey(0)
    pos = random.normal(key, (N, 3))

    # N vectors pointing up
    p = jnp.repeat(jnp.array([[0, 0, 1]]), N, axis=0)

    # N vectors pointing right
    q = jnp.repeat(jnp.array([[0, 1, 0]]), N, axis=0)

    # combine them
    cells = jnp.stack([pos, p, q], axis=1)

    return cells

# compute the energy of the system
def U_sum(cells : jnp.ndarray):
    print("computing energy")
    print(len(cells))

    # use the jax map function to compute the energy of each cell
    U_cells = vmap(U, (0, None), 0)(cells, cells)

    # sum the energies
    U_sum = jnp.sum(U_cells)


    return U_sum

# jaxify the energy function
# U_sum_jax = jit(U_sum)

cells = get_IC(2)


print(U_sum(cells))



computing energy
2
v Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=2/0)>
nan


In [48]:
cells[0][2]

Array([0., 1., 0.], dtype=float32)