From e69ff12276b6c626cd6020386e9873df1cdd815a Mon Sep 17 00:00:00 2001 From: nstarman Date: Tue, 5 Dec 2023 15:19:33 -0500 Subject: [PATCH] Fixes following phasespaceposition Signed-off-by: nstarman --- src/galdynamix/potential/_potential/base.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/galdynamix/potential/_potential/base.py b/src/galdynamix/potential/_potential/base.py index 7ffdae1f..253edde4 100644 --- a/src/galdynamix/potential/_potential/base.py +++ b/src/galdynamix/potential/_potential/base.py @@ -89,7 +89,7 @@ def __call__(self, q: jt.Array, /, t: jt.Array) -> jt.Array: return self.potential_energy(q, t) @partial_jit() - def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: phasespaceposition + def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: """Compute the gradient of the potential at the given position(s). Parameters @@ -126,15 +126,8 @@ def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: :class:`~jax.Array` The potential energy or value of the potential. """ - lap = xp.trace(jax.hessian(self.potential_energy)(q, t)) - """Compute the gradient.""" - return jax.grad(self.potential_energy, argnums=0)(q, t) - - @partial_jit() - def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: # Note: trace(jacobian(gradient)) is faster than trace(hessian(energy)) lap = xp.trace(jax.jacfwd(self.gradient)(q, t)) - return lap / (4 * xp.pi * self._G) @partial_jit()