Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute link bias accelerations J̇ν #127

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,27 @@ def velocity(

# Compute the link velocity in the output velocity representation.
return O_J_WL_I @ I_ν


@jax.jit
def bias_acceleration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_index: jtp.IntLike,
) -> jtp.Vector:
"""
Compute the bias acceleration of the link.

Args:
model: The model to consider.
data: The data of the considered model.
link_index: The index of the link.

Returns:
The 6D bias acceleration of the link.
"""

# Compute the bias acceleration of all links in the active representation.
O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index]
return O_v̇_WL
205 changes: 205 additions & 0 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,211 @@ def average_velocity_jacobian(
# ========================


@jax.jit
def link_bias_accelerations(
model: JaxSimModel,
data: js.data.JaxSimModelData,
) -> jtp.Vector:
r"""
Compute the bias accelerations of the links of the model.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The bias accelerations of the links of the model.

Note:
This function computes the component of the total 6D acceleration not due to
the joint or base acceleration.
It is often called :math:`\dot{J} \boldsymbol{\nu}`.
"""

# ================================================
# Compute the body-fixed zero base 6D acceleration
# ================================================

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(
xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
),
translation=data.base_position(),
).as_matrix()

def other_representation_to_inertial(
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the active representation of the base acceleration C_v̇_WB
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)

# Here we initialize a zero 6D acceleration in the active representation, and
# convert it to inertial-fixed. This is a useful intermediate representation
# because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration
# W_a_WB, and intrinsic accelerations can be expressed in different frames through
# a simple C_X_W 6D transform.
match data.velocity_representation:
case VelRepr.Inertial:
W_H_C = W_H_W = jnp.eye(4)
W_v_WC = W_v_WW = jnp.zeros(6)
with data.switch_velocity_representation(VelRepr.Inertial):
C_v_WB = W_v_WB = data.base_velocity()

case VelRepr.Body:
W_H_C = W_H_B
with data.switch_velocity_representation(VelRepr.Inertial):
W_v_WC = W_v_WB = data.base_velocity()
with data.switch_velocity_representation(VelRepr.Body):
C_v_WB = B_v_WB = data.base_velocity()

case VelRepr.Mixed:
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
W_H_C = W_H_BW
with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
with data.switch_velocity_representation(VelRepr.Mixed):
C_v_WB = BW_v_WB = data.base_velocity()
case _:
raise ValueError(data.velocity_representation)

# Convert a zero 6D acceleration from the active representation to inertial-fixed.
W_v̇_WB = other_representation_to_inertial(
C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC
)

# ===================================
# Initialize buffers and prepare data
# ===================================

# Get the parent array λ(i).
# Note: λ(0) must not be used, it's initialized to -1.
λ = model.kin_dyn_parameters.parent_array

# Compute 6D transforms of the base velocity.
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=data.joint_positions(), base_transform=W_H_B
)

# Allocate the buffer to store the body-fixed link velocities.
L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))

# Store the base velocity.
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = L_v_WL.at[0].set(B_v_WB)

# Get the joint velocities.
ṡ = data.joint_velocities(model=model, joint_names=model.joint_names())

# Allocate the buffer to store the body-fixed link accelerations,
# and initialize the base acceleration.
L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))
L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)

# ======================================
# Propagate accelerations and velocities
# ======================================

# The computation of the bias forces is similar to the forward pass of RNEA,
# this time with zero base and joint accelerations. Furthermore, here we do
# not remove gravity during the propagation.

# Initialize the loop.
Carry = tuple[jtp.MatrixJax, jtp.MatrixJax]
carry0: Carry = (L_v_WL, L_v̇_WL)

def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
# Initialize index and unpack the carry.
ii = i - 1
v, a = carry

# Get the motion subspace of the joint.
Si = S[i].squeeze()

# Project the joint velocity into its motion subspace.
vJ = Si * ṡ[ii]

# Propagate the link body-fixed velocity.
v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

# Propagate the link body-fixed acceleration considering zero joint acceleration.
s̈ = 0.0
a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ
a = a.at[i].set(a_i)

return (v, a), None

# Compute the body-fixed velocity and body-fixed apparent acceleration of the links.
(L_v_WL, L_v̇_WL), _ = (
jax.lax.scan(
f=propagate_accelerations,
init=carry0,
xs=jnp.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(L_v_WL, L_v̇_WL), None]
)

# ===================================================================
# Convert the body-fixed 6D acceleration to the active representation
# ===================================================================

def body_to_other_representation(
L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed apparent acceleration L_v̇_WL to
another representation C_v̇_WL expressed in a generic frame C.
"""

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)
return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)

match data.velocity_representation:
case VelRepr.Body:
C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links())
L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6))

case VelRepr.Inertial:
C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
L_v_CL = L_v_WL

case VelRepr.Mixed:
W_H_L = js.model.forward_kinematics(model=model, data=data)
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
C_H_L = LW_H_L
L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL)

case _:
raise ValueError(data.velocity_representation)

# Convert from body-fixed to the active representation.
O_v̇_WL = jax.vmap(body_to_other_representation)(
L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL
)

return O_v̇_WL


@jax.jit
def link_contact_forces(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down
32 changes: 32 additions & 0 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,35 @@ def test_link_jacobians(
v_WL_idt = kin_dyn.frame_velocity(frame_name=link_name)
v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx)
assert v_WL_js == pytest.approx(v_WL_idt), link_name


def test_link_bias_acceleration(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
velocity_representation=velocity_representation,
)

kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model(
model=model, data=data
)

# =====
# Tests
# =====

for name, index in zip(
model.link_names(),
js.link.names_to_idxs(model=model, link_names=model.link_names()),
):
Jν_idt = kin_dyn.frame_bias_acc(frame_name=name)
Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index)
assert pytest.approx(Jν_idt) == Jν_js
7 changes: 7 additions & 0 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def test_model_rbda(
)
assert pytest.approx(HH_idt) == HH_js

# Bias accelerations
Jν_js = js.model.link_bias_accelerations(model=model, data=data)
Jν_idt = jnp.stack(
[kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()]
)
assert pytest.approx(Jν_idt) == Jν_js


def test_model_jacobian(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down
9 changes: 9 additions & 0 deletions tests/utils_idyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ def frame_velocity(self, frame_name: str) -> npt.NDArray:

return v_WF.toNumPy()

def frame_bias_acc(self, frame_name: str) -> npt.NDArray:

if self.kin_dyn.getFrameIndex(frame_name) < 0:
raise ValueError(f"Frame '{frame_name}' does not exist")

J̇ν = self.kin_dyn.getFrameBiasAcc(frame_name)

return J̇ν.toNumPy()

def com_position(self) -> npt.NDArray:

W_p_G = self.kin_dyn.getCenterOfMassPosition()
Expand Down
Loading