diff --git a/src/galdynamix/potential/_potential/base.py b/src/galdynamix/potential/_potential/base.py index 7ffdae1f..2ef52a0a 100644 --- a/src/galdynamix/potential/_potential/base.py +++ b/src/galdynamix/potential/_potential/base.py @@ -89,7 +89,7 @@ def __call__(self, q: jt.Array, /, t: jt.Array) -> jt.Array: return self.potential_energy(q, t) @partial_jit() - def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: phasespaceposition + def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: """Compute the gradient of the potential at the given position(s). Parameters @@ -127,14 +127,6 @@ def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: The potential energy or value of the potential. """ lap = xp.trace(jax.hessian(self.potential_energy)(q, t)) - """Compute the gradient.""" - return jax.grad(self.potential_energy, argnums=0)(q, t) - - @partial_jit() - def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: - # Note: trace(jacobian(gradient)) is faster than trace(hessian(energy)) - lap = xp.trace(jax.jacfwd(self.gradient)(q, t)) - return lap / (4 * xp.pi * self._G) @partial_jit()