Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make JIT compilation of Jacobian algorithm independent from body index #17

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions src/jaxsim/physics/algos/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,24 @@ def propagate_kinematics(
Jb = i_X_0[body_index]
J = J.at[0:6, 0:6].set(Jb)

ComputeJacobianCarry = jtp.MatrixJax
compute_jacobian_carry = J
# To make JIT happy, we operate on a boolean version of κ(i).
# Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
κ_bool = model.support_body_array_bool(body_index=body_index)

def compute_jacobian(
carry: ComputeJacobianCarry, i: jtp.Int
) -> Tuple[ComputeJacobianCarry, None]:
def update_jacobian(
carry: Tuple[ComputeJacobianCarry, jtp.Int]
) -> ComputeJacobianCarry:

J, i = carry
def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> Tuple[jtp.MatrixJax, None]:
def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:

ii = i - 1

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())

return J

carry = jax.lax.cond(
pred=(jnp.any(i == model.support_body_array(body_index=body_index))),
true_fun=update_jacobian,
false_fun=lambda carry_i: carry_i[0],
operand=(carry, i),
)

return carry, None
J = jax.lax.select(pred=κ_bool[i], on_true=update_jacobian(J, i), on_false=J)
return J, None

J, _ = jax.lax.scan(
f=compute_jacobian,
init=compute_jacobian_carry,
xs=np.arange(start=1, stop=model.NB),
f=compute_jacobian, init=J, xs=np.arange(start=1, stop=model.NB)
)

return J
31 changes: 18 additions & 13 deletions src/jaxsim/physics/model/physics_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
from typing import Dict, List, Union
from typing import Dict, Union

import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
Expand Down Expand Up @@ -185,25 +186,29 @@ def parent_array(self) -> jtp.Vector:

return jnp.array([-1] + list(self._parent_array_dict.values()))

def support_body_array(self, body_index: int) -> jtp.Vector:
def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
"""Returns κ(i)"""

kappa: List[int] = [body_index]
κ_bool = self.support_body_array_bool(body_index=body_index)
return jnp.array(jnp.where(κ_bool)[0], dtype=int)

if body_index == 0:
return np.array(kappa)
def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:

while True:
active_link = body_index
κ_bool = jnp.zeros(self.NB, dtype=bool)

i = self._parent_array_dict[kappa[-1]]
for i in np.flip(np.arange(start=0, stop=self.NB)):

if i == 0:
break
κ_bool, active_link = jax.lax.cond(
pred=(i == active_link),
false_fun=lambda: (κ_bool, active_link),
true_fun=lambda: (
κ_bool.at[active_link].set(True),
self.parent[active_link],
),
)

kappa.append(i)

kappa.append(0)
return np.array(list(reversed(kappa)), dtype=int)
return κ_bool

@property
def tree_transforms(self) -> jtp.Array:
Expand Down