Skip to content

Commit

Permalink
Merge pull request #17 from ami-iit/feature/optimize_jit_jacobian
Browse files Browse the repository at this point in the history
Make JIT compilation of Jacobian algorithm independent from body index
  • Loading branch information
diegoferigo committed Sep 21, 2022
2 parents 67b10fe + 5a1affc commit 119c228
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 35 deletions.
30 changes: 8 additions & 22 deletions src/jaxsim/physics/algos/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,24 @@ def propagate_kinematics(
Jb = i_X_0[body_index]
J = J.at[0:6, 0:6].set(Jb)

ComputeJacobianCarry = jtp.MatrixJax
compute_jacobian_carry = J
# To make JIT happy, we operate on a boolean version of κ(i).
# Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
κ_bool = model.support_body_array_bool(body_index=body_index)

def compute_jacobian(
carry: ComputeJacobianCarry, i: jtp.Int
) -> Tuple[ComputeJacobianCarry, None]:
def update_jacobian(
carry: Tuple[ComputeJacobianCarry, jtp.Int]
) -> ComputeJacobianCarry:

J, i = carry
def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> Tuple[jtp.MatrixJax, None]:
def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:

ii = i - 1

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())

return J

carry = jax.lax.cond(
pred=(jnp.any(i == model.support_body_array(body_index=body_index))),
true_fun=update_jacobian,
false_fun=lambda carry_i: carry_i[0],
operand=(carry, i),
)

return carry, None
J = jax.lax.select(pred=κ_bool[i], on_true=update_jacobian(J, i), on_false=J)
return J, None

J, _ = jax.lax.scan(
f=compute_jacobian,
init=compute_jacobian_carry,
xs=np.arange(start=1, stop=model.NB),
f=compute_jacobian, init=J, xs=np.arange(start=1, stop=model.NB)
)

return J
31 changes: 18 additions & 13 deletions src/jaxsim/physics/model/physics_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
from typing import Dict, List, Union
from typing import Dict, Union

import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
Expand Down Expand Up @@ -185,25 +186,29 @@ def parent_array(self) -> jtp.Vector:

return jnp.array([-1] + list(self._parent_array_dict.values()))

def support_body_array(self, body_index: int) -> jtp.Vector:
def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
"""Returns κ(i)"""

kappa: List[int] = [body_index]
κ_bool = self.support_body_array_bool(body_index=body_index)
return jnp.array(jnp.where(κ_bool)[0], dtype=int)

if body_index == 0:
return np.array(kappa)
def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:

while True:
active_link = body_index
κ_bool = jnp.zeros(self.NB, dtype=bool)

i = self._parent_array_dict[kappa[-1]]
for i in np.flip(np.arange(start=0, stop=self.NB)):

if i == 0:
break
κ_bool, active_link = jax.lax.cond(
pred=(i == active_link),
false_fun=lambda: (κ_bool, active_link),
true_fun=lambda: (
κ_bool.at[active_link].set(True),
self.parent[active_link],
),
)

kappa.append(i)

kappa.append(0)
return np.array(list(reversed(kappa)), dtype=int)
return κ_bool

@property
def tree_transforms(self) -> jtp.Array:
Expand Down

0 comments on commit 119c228

Please sign in to comment.