Skip to content

Commit

Permalink
Merge pull request #162 from ami-iit/fix/representations
Browse files Browse the repository at this point in the history
Ensure consistent link forces and fix static joint friction
  • Loading branch information
flferretti committed Jun 4, 2024
2 parents fbd1f08 + 6a7effb commit 2fa8eaa
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def inertial_to_other_representation(
array: jtp.Array,
other_representation: VelRepr,
transform: jtp.Matrix,
is_force: bool = False,
*,
is_force: bool,
) -> jtp.Array:
r"""
Convert a 6D quantity from inertial-fixed to another representation.
Expand Down Expand Up @@ -153,7 +154,8 @@ def other_representation_to_inertial(
array: jtp.Array,
other_representation: VelRepr,
transform: jtp.Matrix,
is_force: bool = False,
*,
is_force: bool,
) -> jtp.Array:
r"""
Convert a 6D quantity from another representation to inertial-fixed.
Expand Down
17 changes: 14 additions & 3 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def system_velocity_dynamics(
).astype(float)

# Build link forces if not provided
W_f_L = (
O_f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
Expand All @@ -125,7 +125,7 @@ def system_velocity_dynamics(

# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)

# Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
# expressed in the world frame.
Expand Down Expand Up @@ -183,7 +183,7 @@ def system_velocity_dynamics(

# Compute the joint friction torque
τ_friction = -(
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions)
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
)

Expand All @@ -194,6 +194,17 @@ def system_velocity_dynamics(
# Compute the total joint forces
τ_total = τ + τ_friction + τ_position_limit

references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=τ_total,
link_forces=O_f_L,
data=data,
velocity_representation=data.velocity_representation,
)

with references.switch_velocity_representation(VelRepr.Inertial):
W_f_L = references.link_forces(model=model, data=data)

# Compute the total external 6D forces applied to the links
W_f_L_total = W_f_L + W_f_Li_terrain

Expand Down
81 changes: 81 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import pytest

Expand Down Expand Up @@ -90,3 +91,83 @@ def test_box_with_external_forces(
assert data.time() == t_ns / 1e9 + dt
assert data.base_position() == pytest.approx(data0.base_position())
assert data.base_orientation() == pytest.approx(data0.base_orientation())


def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jnp.ndarray,
):

model = jaxsim_model_box

# Split the PRNG key.
key, subkey, subkey2 = jax.random.split(prng_key, num=3)

# Build the data of the model.
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jax.random.uniform(subkey2, shape=(3,)),
velocity_representation=velocity_representation,
standard_gravity=0.0,
soft_contacts_params=jaxsim.rbda.SoftContactsParams.build(K=0.0, D=0.0, mu=0.0),
)

# Generate a random linear force.
L_f = (
jax.random.uniform(subkey, shape=(model.number_of_links(), 6))
.at[:, 3:]
.set(jnp.zeros(3))
)

# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=velocity_representation,
)

# Apply a link forces to the base link.
references = references.apply_link_forces(
forces=jnp.atleast_2d(L_f),
link_names=model.link_names(),
model=model,
data=data0,
additive=False,
)

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model, data=data0, system_dynamics=js.ode.system_dynamics
)
)

# Initialize the integrator.
tf = 1.0
dt = 0.010
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)

# Copy the initial data...
data = data0.copy()

# ... and step the simulation.
for t_ns in T:

data, integrator_state = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
link_forces=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
assert data.time() == t_ns / 1e9 + dt
assert data.base_position() == pytest.approx(
data0.base_position()
+ 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
rel=1e-4,
)

0 comments on commit 2fa8eaa

Please sign in to comment.