From 9f6fa21128c5edc86579c31e9382baa406f53b6a Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 3 Jun 2024 11:57:56 +0200 Subject: [PATCH 1/6] Allow passing a single link name to apply_link_forces --- src/jaxsim/api/references.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index a0e067c7c..8d2888ee2 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -319,7 +319,7 @@ def apply_link_forces( forces: jtp.MatrixLike, model: js.model.JaxSimModel | None = None, data: js.data.JaxSimModelData | None = None, - link_names: tuple[str, ...] | None = None, + link_names: tuple[str, ...] | str | None = None, additive: bool = False, ) -> Self: """ @@ -345,7 +345,7 @@ def apply_link_forces( Then, we always convert and store forces in inertial-fixed representation. """ - f_L = jnp.array(forces) + f_L = jnp.atleast_2d(forces).astype(float) # Helper function to replace the link forces. def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: @@ -380,6 +380,15 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # If we have the model, we can extract the link names if not provided. link_names = link_names if link_names is not None else model.link_names() + + # Make sure that the link names are a tuple if they are provided by the user. + link_names = (link_names,) if isinstance(link_names, str) else link_names + + if len(link_names) != f_L.shape[0]: + msg = "The number of link names ({}) must match the number of forces ({})" + raise ValueError(msg.format(len(link_names), f_L.shape[0])) + + # Extract the link indices. link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) # Compute the bias depending on whether we either set or add the link forces. From 479fe3e506d344a8279614f43ce76bbbe41feb71 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 4 Jun 2024 11:48:20 +0200 Subject: [PATCH 2/6] Allow to apply link forces using the link frame in body-fixed and mixed --- src/jaxsim/api/references.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 8d2888ee2..37e4bc26b 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -414,16 +414,24 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") - # Helper function to convert a single 6D force to the inertial representation. - def convert(f_L: jtp.Vector) -> jtp.Vector: - return JaxSimModelReferences.other_representation_to_inertial( - array=f_L, - other_representation=self.velocity_representation, - transform=data.base_transform(), - is_force=True, - ) + # Helper function to convert a single 6D force to the inertial representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert_using_link_frame( + f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike + ) -> jtp.Matrix: + + return jax.vmap( + lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( + array=f_L, + other_representation=self.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(f_L, W_H_L) - W_f_L = jax.vmap(convert)(f_L) + # The f_L input is either L_f_L or LW_f_L, depending on the representation. + W_H_L = js.model.forward_kinematics(model=model, data=data) + W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L) From 2681109b94c5dadec0347db846bd6b863b92d557 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 4 Jun 2024 11:47:55 +0200 Subject: [PATCH 3/6] Allow to get link forces using the link frame in body-fixed and mixed --- src/jaxsim/api/references.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 37e4bc26b..de55daab7 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -202,17 +202,22 @@ def link_forces( if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") - # Helper function to convert a single 6D force to the active representation. - def convert(f_L: jtp.Vector) -> jtp.Vector: - return JaxSimModelReferences.inertial_to_other_representation( - array=f_L, - other_representation=self.velocity_representation, - transform=data.base_transform(), - is_force=True, - ) + # Helper function to convert a single 6D force to the active representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: - # Convert to the desired representation. - f_L = jax.vmap(convert)(W_f_L[link_idxs, :]) + return jax.vmap( + lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( + array=W_f_L, + other_representation=self.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(W_f_L, W_H_L) + + # The f_L output is either L_f_L or LW_f_L, depending on the representation. + W_H_L = js.model.forward_kinematics(model=model, data=data) + f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L From 791d4740e8c24724f2cf1bb06f8f3705a518c883 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 4 Jun 2024 11:29:17 +0200 Subject: [PATCH 4/6] Use the link frame in the stacked model Jacobian for repr body and mixed --- src/jaxsim/api/model.py | 74 +++++++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a3a154268..06b1027bd 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -450,27 +450,36 @@ def generalized_free_floating_jacobian( ) # Compute the doubly-left free-floating full jacobian. - B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left( + B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), ) - # Update the input velocity representation such that `J_WL_I @ I_ν`. + # ====================================================================== + # Update the input velocity representation such that v_WL = J_WL_I @ I_ν + # ====================================================================== + match data.velocity_representation: + case VelRepr.Inertial: + W_H_B = data.base_transform() B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) ) case VelRepr.Body: + B_J_full_WX_I = B_J_full_WX_B case VelRepr.Mixed: + W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_J_full_WX_I = B_J_full_WX_BW = ( B_J_full_WX_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) @@ -479,32 +488,61 @@ def generalized_free_floating_jacobian( case _: raise ValueError(data.velocity_representation) - # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. + # ==================================================================== + # Create stacked Jacobian for each link by filtering the full Jacobian + # ==================================================================== + + κ_bool = model.kin_dyn_parameters.support_body_array_bool + + # Keep only the columns of the full Jacobian corresponding to the support + # body array of each link. + B_J_WL_I = jax.vmap( + lambda κ: jnp.where( + jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I) + ) + )(κ_bool) + + # ======================================================================= + # Update the output velocity representation such that O_v_WL = O_J_WL @ ν + # ======================================================================= + match output_vel_repr: + case VelRepr.Inertial: + W_H_B = data.base_transform() - W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint() - O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I + W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) + + O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I) case VelRepr.Body: - O_J_full_WX_I = B_J_full_WX_I + + O_J_WL_I = L_J_WL_I = jax.vmap( + lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( + B_H_L, inverse=True + ) + @ B_J_WL_I + )(B_H_L, B_J_WL_I) case VelRepr.Mixed: - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint() - O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I - case _: - raise ValueError(output_vel_repr) + W_H_B = data.base_transform() - κ_bool = model.kin_dyn_parameters.support_body_array_bool + LW_H_L = jax.vmap( + lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)) + )(B_H_L) - O_J_WL_I = jax.vmap( - lambda κ: jnp.where( - jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I) - ) - )(κ_bool) + LW_H_B = jax.vmap( + lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) + )(LW_H_L, B_H_L) + + O_J_WL_I = LW_J_WL_I = jax.vmap( + lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) + @ B_J_WL_I + )(LW_H_B, B_J_WL_I) + + case _: + raise ValueError(output_vel_repr) return O_J_WL_I From 761b3311a5611c9f7b45c055b7e61109b69ad7d0 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 4 Jun 2024 11:31:03 +0200 Subject: [PATCH 5/6] Update test of link APIs --- tests/test_api_link.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_api_link.py b/tests/test_api_link.py index bd80ea027..84c842246 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -147,9 +147,8 @@ def test_link_jacobians( ), link_name # The following is true only in inertial-fixed representation. - if data.velocity_representation is VelRepr.Inertial: - J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data) - assert J_WL_model == pytest.approx(J_WL_links) + J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data) + assert J_WL_model == pytest.approx(J_WL_links) for link_name, link_idx in zip( model.link_names(), From 9beb8c4c4648b86f992eb5343a3616ed97afb6b1 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 4 Jun 2024 12:10:00 +0200 Subject: [PATCH 6/6] Test output representation of link jacobian --- tests/test_api_link.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 84c842246..dc619fd80 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -158,6 +158,28 @@ def test_link_jacobians( v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx) assert v_WL_js == pytest.approx(v_WL_idt), link_name + # Test conversion to a different output velocity representation. + for other_repr in {VelRepr.Inertial, VelRepr.Body, VelRepr.Mixed}.difference( + {data.velocity_representation} + ): + + with data.switch_velocity_representation(other_repr): + kin_dyn_other_repr = ( + utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + ) + + for link_name, link_idx in zip( + model.link_names(), + js.link.names_to_idxs(model=model, link_names=model.link_names()), + ): + v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name) + v_WL_js = js.link.velocity( + model=model, data=data, link_index=link_idx, output_vel_repr=other_repr + ) + assert v_WL_js == pytest.approx(v_WL_idt), link_name + def test_link_bias_acceleration( jaxsim_models_types: js.model.JaxSimModel,