Skip to content

Commit

Permalink
Fixes following phasespaceposition
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 5, 2023
1 parent 625988b commit cd7a736
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
return self.potential_energy(q, t)

@partial_jit()
def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: phasespaceposition
def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
"""Compute the gradient of the potential at the given position(s).
Parameters
Expand Down Expand Up @@ -127,14 +127,6 @@ def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
The potential energy or value of the potential.
"""
lap = xp.trace(jax.hessian(self.potential_energy)(q, t))
"""Compute the gradient."""
return jax.grad(self.potential_energy, argnums=0)(q, t)

@partial_jit()
def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# Note: trace(jacobian(gradient)) is faster than trace(hessian(energy))
lap = xp.trace(jax.jacfwd(self.gradient)(q, t))

return lap / (4 * xp.pi * self._G)

@partial_jit()
Expand Down

0 comments on commit cd7a736

Please sign in to comment.