Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Jacobian algorithm #121

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,62 +233,61 @@ def jacobian(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Compute the doubly left-trivialized free-floating jacobian.
L_J_WL_B = jaxsim.rbda.jacobian(
# Compute the doubly-left free-floating full jacobian.
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
link_index=link_index,
joint_positions=data.joint_positions(),
)

# We want to return the Jacobian O_J_WL_I, where O is the representation of the
# output 6D velocity, and I is the representation of the input generalized velocity.
# Compute the actual doubly-left free-floating jacobian of the link.
κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B

# Change the input representation matching the one of data.
# Adjust the input representation such that `J_WL_I @ I_ν`.
match data.velocity_representation:

case VelRepr.Body:
L_J_WL_I = L_J_WL_B

case VelRepr.Inertial:

W_H_B = data.base_transform()

B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
)

L_J_WL_I = L_J_WL_B @ B_T_W
case VelRepr.Body:
B_J_WL_I = B_J_WL_B

case VelRepr.Mixed:

W_H_B = data.base_transform()
BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))

W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))

L_J_WL_I = L_J_WL_B @ B_T_BW
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_X_BW, jnp.eye(model.dofs())
)

case _:
raise ValueError(data.velocity_representation)

# Change the output representation matching specified one.
B_H_L = B_H_Li[link_index]

# Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
match output_vel_repr:
case VelRepr.Inertial:
W_H_B = data.base_transform()
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I

case VelRepr.Body:
L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
L_J_WL_I = L_X_B @ B_J_WL_I
O_J_WL_I = L_J_WL_I

case VelRepr.Inertial:

W_H_L = transform(model=model, data=data, link_index=link_index)
W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
O_J_WL_I = W_X_L @ L_J_WL_I

case VelRepr.Mixed:

W_H_L = transform(model=model, data=data, link_index=link_index)
LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
LW_X_L = jaxlie.SE3.from_matrix(LW_H_L).adjoint()
O_J_WL_I = LW_X_L @ L_J_WL_I
W_H_B = data.base_transform()
W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
LW_J_WL_I = LW_X_B @ B_J_WL_I
O_J_WL_I = LW_J_WL_I

case _:
raise ValueError(output_vel_repr)
Expand Down
79 changes: 46 additions & 33 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,51 +382,64 @@ def generalized_free_floating_jacobian(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# The body frame of the link.jacobian method is the link frame L.
# In this method, we want instead to use the base link B as body frame.
# Therefore, we always get the link jacobian having Inertial as output
# representation, and then we convert it to the desired output representation.
match output_vel_repr:
# Compute the doubly-left free-floating full jacobian.
B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
)

# Update the input velocity representation such that `J_WL_I @ I_ν`.
match data.velocity_representation:
case VelRepr.Inertial:
to_output = lambda W_J_WL: W_J_WL
W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
)

case VelRepr.Body:

def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
return B_X_W @ W_J_WL
B_J_full_WX_I = B_J_full_WX_B

case VelRepr.Mixed:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_J_full_WX_I = B_J_full_WX_BW = (
B_J_full_WX_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
)

def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
W_H_B = data.base_transform()
W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
return BW_X_W @ W_J_WL
case _:
raise ValueError(data.velocity_representation)

# Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
match output_vel_repr:
case VelRepr.Inertial:
W_H_B = data.base_transform()
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I

case VelRepr.Body:
O_J_full_WX_I = B_J_full_WX_I

case VelRepr.Mixed:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I

case _:
raise ValueError(output_vel_repr)

# Compute first the link jacobians having the active representation of `data`
# as input representation (matching the one of ν), and inertial as output
# representation (i.e. W_J_WL_C where C is C_ν).
# Then, with to_output, we convert this jacobian to the desired output
# representation, that can either be W (inertial), B (body), or B[W] (mixed).
# This is necessary because for example the body-fixed free-floating jacobian
# of a link is L_J_WL, but here being inside model we need B_J_WL.
J_free_floating = jax.vmap(
lambda i: to_output(
W_J_WL=js.link.jacobian(
model=model,
data=data,
link_index=i,
output_vel_repr=VelRepr.Inertial,
)
κ_bool = model.kin_dyn_parameters.support_body_array_bool

O_J_WL_I = jax.vmap(
lambda κ: jnp.where(
jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
)
)(jnp.arange(model.number_of_links()))
)(κ_bool)

return J_free_floating
return O_J_WL_I


@functools.partial(jax.jit, static_argnames=["prefer_aba"])
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics, forward_kinematics_model
from .jacobian import jacobian
from .jacobian import jacobian, jacobian_full_doubly_left
from .rnea import rnea
from .soft_contacts import SoftContacts, SoftContactsParams
93 changes: 90 additions & 3 deletions src/jaxsim/rbda/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def jacobian(
joint_positions: The positions of the joints.

Returns:
The doubly-left free-floating Jacobian of the link.
The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`.
"""

_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
Expand Down Expand Up @@ -105,10 +105,97 @@ def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:

return J, None

W_J_WL_W, _ = jax.lax.scan(
L_J_WL_B, _ = jax.lax.scan(
f=compute_jacobian,
init=J,
xs=np.arange(start=1, stop=model.number_of_links()),
)

return W_J_WL_W
return L_J_WL_B


@jax.jit
def jacobian_full_doubly_left(
model: js.model.JaxSimModel,
*,
joint_positions: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
r"""
Compute the doubly-left full free-floating Jacobian of a model.

The full Jacobian is a 6x(6+n) matrix with all the columns filled.
It is useful to run the algorithm once, and then extract the link Jacobian by
filtering the columns of the full Jacobian using the support parent array
:math:`\kappa(i)` of the link.

Args:
model: The model to consider.
joint_positions: The positions of the joints.

Returns:
The doubly-left full free-floating Jacobian of a model.
"""

_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
model=model, joint_positions=joint_positions
)

# Get the parent array λ(i).
# Note: λ(0) must not be used, it's initialized to -1.
λ = model.kin_dyn_parameters.parent_array

# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=jnp.eye(4)
)

# Allocate the buffer of transforms base -> link.
B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
B_X_i = B_X_i.at[0].set(jnp.eye(6))

# =============================
# Compute doubly-left Jacobian
# =============================

# Allocate the Jacobian matrix.
# The Jbb section of the doubly-left Jacobian is an identity matrix.
J = jnp.zeros(shape=(6, 6 + model.dofs()))
J = J.at[0:6, 0:6].set(jnp.eye(6))

ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)

def compute_full_jacobian(
carry: ComputeFullJacobianCarry, i: jtp.Int
) -> tuple[ComputeFullJacobianCarry, None]:

ii = i - 1
B_X_i, J = carry

# Compute the base (0) to link (i) adjoint matrix.
B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
B_X_i = B_X_i.at[i].set(B_Xi_i)

# Compute the ii-th column of the B_S_BL(s) matrix.
B_Sii_BL = B_Xi_i @ S[i]
J = J.at[0:6, 6 + ii].set(B_Sii_BL.squeeze())

return (B_X_i, J), None

(B_X_i, J), _ = jax.lax.scan(
f=compute_full_jacobian,
init=compute_full_jacobian_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
)

# Convert adjoints to SE(3) transforms.
# Returning them here prevents calling FK in case the output representation
# of the Jacobian needs to be changed.
B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)

# Adjust shape of doubly-left free-floating full Jacobian.
B_J_full_WL_B = J.squeeze().astype(float)

return B_J_full_WL_B, B_H_L
Loading