Skip to content

Commit

Permalink
Use the link frame in the stacked model Jacobian for repr body and mixed
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 4, 2024
1 parent a3d10a8 commit c28ed2e
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,27 +450,36 @@ def generalized_free_floating_jacobian(
)

# Compute the doubly-left free-floating full jacobian.
B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
)

# Update the input velocity representation such that `J_WL_I @ I_ν`.
# ======================================================================
# Update the input velocity representation such that v_WL = J_WL_I @ I_ν
# ======================================================================

match data.velocity_representation:

case VelRepr.Inertial:

W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()

B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
)

case VelRepr.Body:

B_J_full_WX_I = B_J_full_WX_B

case VelRepr.Mixed:

W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()

B_J_full_WX_I = B_J_full_WX_BW = (
B_J_full_WX_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
Expand All @@ -479,32 +488,61 @@ def generalized_free_floating_jacobian(
case _:
raise ValueError(data.velocity_representation)

# Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
# ====================================================================
# Create stacked Jacobian for each link by filtering the full Jacobian
# ====================================================================

κ_bool = model.kin_dyn_parameters.support_body_array_bool

# Keep only the columns of the full Jacobian corresponding to the support
# body array of each link.
B_J_WL_I = jax.vmap(
lambda κ: jnp.where(
jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I)
)
)(κ_bool)

# =======================================================================
# Update the output velocity representation such that O_v_WL = O_J_WL @ ν
# =======================================================================

match output_vel_repr:

case VelRepr.Inertial:

W_H_B = data.base_transform()
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)

O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)

case VelRepr.Body:
O_J_full_WX_I = B_J_full_WX_I

O_J_WL_I = L_J_WL_I = jax.vmap(
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
B_H_L, inverse=True
)
@ B_J_WL_I
)(B_H_L, B_J_WL_I)

case VelRepr.Mixed:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I

case _:
raise ValueError(output_vel_repr)
W_H_B = data.base_transform()

κ_bool = model.kin_dyn_parameters.support_body_array_bool
LW_H_L = jax.vmap(
lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
)(B_H_L)

O_J_WL_I = jax.vmap(
lambda κ: jnp.where(
jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
)
)(κ_bool)
LW_H_B = jax.vmap(
lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
)(LW_H_L, B_H_L)

O_J_WL_I = LW_J_WL_I = jax.vmap(
lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
@ B_J_WL_I
)(LW_H_B, B_J_WL_I)

case _:
raise ValueError(output_vel_repr)

return O_J_WL_I

Expand Down

0 comments on commit c28ed2e

Please sign in to comment.