Skip to content

Commit

Permalink
Merge pull request #167 from ami-iit/fix_reference_frame_of_model_jac…
Browse files Browse the repository at this point in the history
…obian

Always use the link frame in body-fixed and mixed output representations of model Jacobian
  • Loading branch information
diegoferigo committed Jun 4, 2024
2 parents 2fa8eaa + 9beb8c4 commit 44b9aee
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 42 deletions.
74 changes: 56 additions & 18 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,27 +451,36 @@ def generalized_free_floating_jacobian(
)

# Compute the doubly-left free-floating full jacobian.
B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
)

# Update the input velocity representation such that `J_WL_I @ I_ν`.
# ======================================================================
# Update the input velocity representation such that v_WL = J_WL_I @ I_ν
# ======================================================================

match data.velocity_representation:

case VelRepr.Inertial:

W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()

B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
)

case VelRepr.Body:

B_J_full_WX_I = B_J_full_WX_B

case VelRepr.Mixed:

W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()

B_J_full_WX_I = B_J_full_WX_BW = (
B_J_full_WX_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
Expand All @@ -480,32 +489,61 @@ def generalized_free_floating_jacobian(
case _:
raise ValueError(data.velocity_representation)

# Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
# ====================================================================
# Create stacked Jacobian for each link by filtering the full Jacobian
# ====================================================================

κ_bool = model.kin_dyn_parameters.support_body_array_bool

# Keep only the columns of the full Jacobian corresponding to the support
# body array of each link.
B_J_WL_I = jax.vmap(
lambda κ: jnp.where(
jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I)
)
)(κ_bool)

# =======================================================================
# Update the output velocity representation such that O_v_WL = O_J_WL @ ν
# =======================================================================

match output_vel_repr:

case VelRepr.Inertial:

W_H_B = data.base_transform()
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)

O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)

case VelRepr.Body:
O_J_full_WX_I = B_J_full_WX_I

O_J_WL_I = L_J_WL_I = jax.vmap(
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
B_H_L, inverse=True
)
@ B_J_WL_I
)(B_H_L, B_J_WL_I)

case VelRepr.Mixed:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I

case _:
raise ValueError(output_vel_repr)
W_H_B = data.base_transform()

κ_bool = model.kin_dyn_parameters.support_body_array_bool
LW_H_L = jax.vmap(
lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
)(B_H_L)

O_J_WL_I = jax.vmap(
lambda κ: jnp.where(
jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
)
)(κ_bool)
LW_H_B = jax.vmap(
lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
)(LW_H_L, B_H_L)

O_J_WL_I = LW_J_WL_I = jax.vmap(
lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
@ B_J_WL_I
)(LW_H_B, B_J_WL_I)

case _:
raise ValueError(output_vel_repr)

return O_J_WL_I

Expand Down
64 changes: 43 additions & 21 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,22 @@ def link_forces(
if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the active representation.
def convert(f_L: jtp.Vector) -> jtp.Vector:
return JaxSimModelReferences.inertial_to_other_representation(
array=f_L,
other_representation=self.velocity_representation,
transform=data.base_transform(),
is_force=True,
)
# Helper function to convert a single 6D force to the active representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:

return jax.vmap(
lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
array=W_f_L,
other_representation=self.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(W_f_L, W_H_L)

# Convert to the desired representation.
f_L = jax.vmap(convert)(W_f_L[link_idxs, :])
# The f_L output is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])

return f_L

Expand Down Expand Up @@ -319,7 +324,7 @@ def apply_link_forces(
forces: jtp.MatrixLike,
model: js.model.JaxSimModel | None = None,
data: js.data.JaxSimModelData | None = None,
link_names: tuple[str, ...] | None = None,
link_names: tuple[str, ...] | str | None = None,
additive: bool = False,
) -> Self:
"""
Expand All @@ -345,7 +350,7 @@ def apply_link_forces(
Then, we always convert and store forces in inertial-fixed representation.
"""

f_L = jnp.array(forces)
f_L = jnp.atleast_2d(forces).astype(float)

# Helper function to replace the link forces.
def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
Expand Down Expand Up @@ -380,6 +385,15 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:

# If we have the model, we can extract the link names if not provided.
link_names = link_names if link_names is not None else model.link_names()

# Make sure that the link names are a tuple if they are provided by the user.
link_names = (link_names,) if isinstance(link_names, str) else link_names

if len(link_names) != f_L.shape[0]:
msg = "The number of link names ({}) must match the number of forces ({})"
raise ValueError(msg.format(len(link_names), f_L.shape[0]))

# Extract the link indices.
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

# Compute the bias depending on whether we either set or add the link forces.
Expand All @@ -405,16 +419,24 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
if not_tracing(forces) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the inertial representation.
def convert(f_L: jtp.Vector) -> jtp.Vector:
return JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=data.base_transform(),
is_force=True,
)
# Helper function to convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert_using_link_frame(
f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
) -> jtp.Matrix:

return jax.vmap(
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(f_L, W_H_L)

W_f_L = jax.vmap(convert)(f_L)
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])

return replace(
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
Expand Down
27 changes: 24 additions & 3 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,8 @@ def test_link_jacobians(
), link_name

# The following is true only in inertial-fixed representation.
if data.velocity_representation is VelRepr.Inertial:
J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data)
assert J_WL_model == pytest.approx(J_WL_links)
J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data)
assert J_WL_model == pytest.approx(J_WL_links)

for link_name, link_idx in zip(
model.link_names(),
Expand All @@ -159,6 +158,28 @@ def test_link_jacobians(
v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx)
assert v_WL_js == pytest.approx(v_WL_idt), link_name

# Test conversion to a different output velocity representation.
for other_repr in {VelRepr.Inertial, VelRepr.Body, VelRepr.Mixed}.difference(
{data.velocity_representation}
):

with data.switch_velocity_representation(other_repr):
kin_dyn_other_repr = (
utils_idyntree.build_kindyncomputations_from_jaxsim_model(
model=model, data=data
)
)

for link_name, link_idx in zip(
model.link_names(),
js.link.names_to_idxs(model=model, link_names=model.link_names()),
):
v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name)
v_WL_js = js.link.velocity(
model=model, data=data, link_index=link_idx, output_vel_repr=other_repr
)
assert v_WL_js == pytest.approx(v_WL_idt), link_name


def test_link_bias_acceleration(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down

0 comments on commit 44b9aee

Please sign in to comment.