From 872febe7f29d90aea011ff782d01f6a6bc992552 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 May 2024 13:20:27 +0200 Subject: [PATCH 1/4] Ensure that link forces are in `VelRepr.Inertial` When passed to forward dynamics Co-authored-by: Alessandro Croci --- src/jaxsim/api/ode.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a2a98043e..cda603de2 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -113,7 +113,7 @@ def system_velocity_dynamics( ).astype(float) # Build link forces if not provided - W_f_L = ( + O_f_L = ( jnp.atleast_2d(link_forces.squeeze()) if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) @@ -125,7 +125,7 @@ def system_velocity_dynamics( # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float) + W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points, # expressed in the world frame. @@ -194,6 +194,17 @@ def system_velocity_dynamics( # Compute the total joint forces τ_total = τ + τ_friction + τ_position_limit + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=τ_total, + link_forces=O_f_L, + data=data, + velocity_representation=data.velocity_representation, + ) + + with references.switch_velocity_representation(VelRepr.Inertial): + W_f_L = references.link_forces(model=model, data=data) + # Compute the total external 6D forces applied to the links W_f_L_total = W_f_L + W_f_Li_terrain From 51ed0cbc3876a1c9fab2241b0556028a46446a0c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 May 2024 18:11:33 +0200 Subject: [PATCH 2/4] Use joint velocities sign for static joint friction Co-authored-by: Alessandro Croci --- src/jaxsim/api/ode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index cda603de2..426934f16 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -183,7 +183,7 @@ def system_velocity_dynamics( # Compute the joint friction torque τ_friction = -( - jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions) + jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities) + jnp.diag(kv) @ data.state.physics_model.joint_velocities ) From 31766ba6549f58b0a2bc8bafa062c6674173a5e2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 27 May 2024 17:17:09 +0200 Subject: [PATCH 3/4] Add test to verify that forces are applied correctly --- tests/test_simulations.py | 81 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 9495c2737..f300a73a1 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp import pytest @@ -90,3 +91,83 @@ def test_box_with_external_forces( assert data.time() == t_ns / 1e9 + dt assert data.base_position() == pytest.approx(data0.base_position()) assert data.base_orientation() == pytest.approx(data0.base_orientation()) + + +def test_box_with_zero_gravity( + jaxsim_model_box: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jnp.ndarray, +): + + model = jaxsim_model_box + + # Split the PRNG key. + key, subkey, subkey2 = jax.random.split(prng_key, num=3) + + # Build the data of the model. + data0 = js.data.JaxSimModelData.build( + model=model, + base_position=jax.random.uniform(subkey2, shape=(3,)), + velocity_representation=velocity_representation, + standard_gravity=0.0, + soft_contacts_params=jaxsim.rbda.SoftContactsParams.build(K=0.0, D=0.0, mu=0.0), + ) + + # Generate a random linear force. + L_f = ( + jax.random.uniform(subkey, shape=(model.number_of_links(), 6)) + .at[:, 3:] + .set(jnp.zeros(3)) + ) + + # Initialize a references object that simplifies handling external forces. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data0, + velocity_representation=velocity_representation, + ) + + # Apply a link forces to the base link. + references = references.apply_link_forces( + forces=jnp.atleast_2d(L_f), + link_names=model.link_names(), + model=model, + data=data0, + additive=False, + ) + + # Create the integrator. + integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build( + dynamics=js.ode.wrap_system_dynamics_for_integration( + model=model, data=data0, system_dynamics=js.ode.system_dynamics + ) + ) + + # Initialize the integrator. + tf = 1.0 + dt = 0.010 + T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int) + integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt) + + # Copy the initial data... + data = data0.copy() + + # ... and step the simulation. + for t_ns in T: + + data, integrator_state = js.model.step( + model=model, + data=data, + dt=dt, + integrator=integrator, + integrator_state=integrator_state, + link_forces=references.link_forces(model=model, data=data), + ) + + # Check that the box moved as expected. + assert data.time() == t_ns / 1e9 + dt + assert data.base_position() == pytest.approx( + data0.base_position() + + 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, + rel=1e-4, + ) From 6a7effbe2c233ea847cd714880a8ab8c470e57e5 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 May 2024 17:37:33 +0200 Subject: [PATCH 4/4] Make `is_force` kwarg-only and remove default value Co-authored-by: Alessandro Croci --- src/jaxsim/api/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index fe958156a..0c9fcf4f5 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -87,7 +87,8 @@ def inertial_to_other_representation( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, - is_force: bool = False, + *, + is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from inertial-fixed to another representation. @@ -153,7 +154,8 @@ def other_representation_to_inertial( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, - is_force: bool = False, + *, + is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from another representation to inertial-fixed.