Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use explicit integration of quaternions also in schemes on SO(3) #205

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 24 additions & 33 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability

try:
Expand Down Expand Up @@ -539,48 +540,38 @@ class ExplicitRungeKuttaSO3Mixin:
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
"""

@classmethod
def integrate_rk_stage(
cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
) -> js.ode_data.ODEState:

op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)

W_Q_B_tf = xf.physics_model.base_quaternion

return xf.replace(
physics_model=xf.physics_model.replace(
base_quaternion=W_Q_B_tf / jnp.linalg.norm(W_Q_B_tf)
)
)

@classmethod
def post_process_state(
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:

# Indices to convert quaternions between serializations.
to_xyzw = jnp.array([1, 2, 3, 0])
# Extract the initial base quaternion.
W_Q_B_t0 = x0.physics_model.base_quaternion

# Get the initial rotation.
W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
xyzw=x0.physics_model.base_quaternion[to_xyzw]
# We assume that the initial quaternion is already unary.
exceptions.raise_runtime_error_if(
condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
)

# Get the final angular velocity.
# This is already computed by averaging the kᵢ in RK-based schemes.
# Therefore, by using the ω at tf, we obtain a RK scheme operating
# on the SO(3) manifold.
W_ω_WB_tf = xf.physics_model.base_angular_velocity

# Integrate the orientation on SO(3).
# Note that we left-multiply with the exponential map since the angular
# velocity is expressed in the inertial frame.
W_R_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_R_B_t0
# Get the angular velocity ω to integrate the quaternion.
# This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ
# corresponding to the active RK-based scheme. Therefore, by using the ω[t0],
traversaro marked this conversation as resolved.
Show resolved Hide resolved
# we obtain an explicit RK scheme operating on the SO(3) manifold.
# Note that the current integrator is not a semi-implicit scheme, therefore
# using the final ω[tf] would be not correct.
W_ω_WB_t0 = x0.physics_model.base_angular_velocity

# Integrate the quaternion on SO(3).
W_Q_B_tf = jaxsim.math.Quaternion.integration(
quaternion=W_Q_B_t0,
dt=dt,
omega=W_ω_WB_t0,
omega_in_body_fixed=False,
)
traversaro marked this conversation as resolved.
Show resolved Hide resolved

# Replace the quaternion in the final state.
return xf.replace(
physics_model=xf.physics_model.replace(base_quaternion=W_R_B_tf.wxyz),
physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf),
validate=True,
)
8 changes: 8 additions & 0 deletions src/jaxsim/rbda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import StandardGravity


Expand Down Expand Up @@ -131,6 +132,13 @@ def process_inputs(
if W_Q_B.shape != (4,):
raise ValueError(W_Q_B.shape, (4,))

# Check that the quaternion is unary since our RBDAs make this assumption in order
# to prevent introducing additional normalizations that would affect AD.
exceptions.raise_value_error_if(
condition=jnp.logical_not(jnp.allclose(W_Q_B.dot(W_Q_B), 1.0)),
msg="A RBDA received a quaternion that is not normalized.",
)
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved

# Pack the 6D base velocity and acceleration.
W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])
Expand Down
6 changes: 3 additions & 3 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_ad_aba(
aba = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, g: jaxsim.rbda.aba(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
joint_positions=s,
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_ad_rnea(
rnea = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, W_f_L, g: jaxsim.rbda.rnea(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
joint_positions=s,
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_ad_fk(
fk = lambda W_p_B, W_Q_B, s: jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
joint_positions=s,
)

Expand Down