Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify lin/ang serialization of 6D quantities #18

Merged
merged 5 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 15 additions & 42 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import numpy.typing as npt

import jaxsim.high_level
import jaxsim.parsers.descriptions as descriptions
import jaxsim.sixd as sixd
import jaxsim.typing as jtp
from jaxsim.parsers.sdf.utils import flip_velocity_serialization
from jaxsim.physics.algos.forward_kinematics import forward_kinematics
from jaxsim.physics.algos.jacobian import jacobian
from jaxsim.utils import JaxsimDataclass

Expand Down Expand Up @@ -47,31 +44,23 @@ def mass(self) -> jtp.Float:

return self.link_description.mass

def spatial_inertia(self) -> npt.NDArray:
def spatial_inertia(self) -> jtp.Matrix:

return flip_velocity_serialization(self.link_description.inertia)

def com(self) -> jtp.VectorJax:

from jaxsim.math.skew import Skew

skew_mc1 = self.spatial_inertia()[3:6, 0:3]

com_wrt_link_frame = (Skew.vee(skew_mc1) / self.mass()).squeeze()
return com_wrt_link_frame
return self.link_description.inertia

def com_position(self, in_link_frame: bool = True) -> jtp.VectorJax:

from jaxsim.math.skew import Skew
from jaxsim.math.inertia import Inertia

skew_mc1 = self.spatial_inertia()[3:6, 0:3]
L_p_CoM = (Skew.vee(skew_mc1) / self.mass()).squeeze()
_, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia())

if in_link_frame:
return L_p_CoM
return L_p_CoM.squeeze()

W_H_L = self.transform()
return W_H_L @ L_p_CoM
W_ph_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1])

return W_ph_CoM[0:3].squeeze()

# ==========
# Kinematics
Expand All @@ -90,12 +79,7 @@ def orientation(self, dcm: bool = False) -> jtp.Vector:

def transform(self) -> jtp.Matrix:

return forward_kinematics(
model=self.parent_model.physics_model,
body_index=self.index(),
q=self.parent_model.data.model_state.joint_positions,
xfb=self.parent_model.data.model_state.xfb(),
)
return self.parent_model.forward_kinematics()[self.index()]

def velocity(self, vel_repr: VelRepr = None) -> jtp.Vector:

Expand All @@ -119,20 +103,12 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
output_vel_repr = self.parent_model.velocity_representation

# Return the doubly left-trivialized free-floating jacobian
J_body_anglin = jacobian(
L_J_WL_B = jacobian(
model=self.parent_model.physics_model,
body_index=self.index(),
q=self.parent_model.data.model_state.joint_positions,
)

# Convert ang-lin to lin-ang serialization
L_J_WL_B = jnp.zeros_like(J_body_anglin)
L_J_WL_B = L_J_WL_B.at[0:6, 0:6].set(
flip_velocity_serialization(J_body_anglin[0:6, 0:6])
)
L_J_WL_B = L_J_WL_B.at[0:3, 6:].set(J_body_anglin[3:6, 6:])
L_J_WL_B = L_J_WL_B.at[3:6, 6:].set(J_body_anglin[0:3, 6:])

if self.parent_model.velocity_representation is VelRepr.Body:

L_J_WL_target = L_J_WL_B
Expand Down Expand Up @@ -183,8 +159,7 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:

def external_force(self) -> jtp.Vector:

W_f_ext_anglin = self.parent_model.data.model_input.f_ext[self.index()]
W_f_ext = jnp.hstack([W_f_ext_anglin[3:6], W_f_ext_anglin[0:3]])
W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]

if self.parent_model.velocity_representation is VelRepr.Inertial:
return W_f_ext
Expand Down Expand Up @@ -235,12 +210,11 @@ def add_external_force(
else:
raise ValueError(self.parent_model.velocity_representation)

W_f_ext_anglin = jnp.hstack([W_f_ext[3:6], W_f_ext[0:3]])
W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :]

self.parent_model.data.model_input.f_ext = (
self.parent_model.data.model_input.f_ext.at[self.index(), :].set(
W_f_ext_current + W_f_ext_anglin
W_f_ext_current + W_f_ext
)
)

Expand All @@ -267,7 +241,7 @@ def add_com_external_force(
W_H_GL = W_H_L @ L_H_GL
GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint()

W_f_ext = GL_X_W.T @ GL_f_ext
W_f_ext = GL_X_W.transpose() @ GL_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:

Expand All @@ -277,17 +251,16 @@ def add_com_external_force(
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint()

W_f_ext = GW_X_W.T @ GW_f_ext
W_f_ext = GW_X_W.transpose() @ GW_f_ext

else:
raise ValueError(self.parent_model.velocity_representation)

W_f_ext_anglin = jnp.hstack([W_f_ext[3:6], W_f_ext[0:3]])
W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :]

self.parent_model.data.model_input.f_ext = (
self.parent_model.data.model_input.f_ext.at[self.index(), :].set(
W_f_ext_current + W_f_ext_anglin
W_f_ext_current + W_f_ext
)
)

Expand Down
55 changes: 11 additions & 44 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import jaxsim.physics.model.physics_model_state
import jaxsim.typing as jtp
from jaxsim import high_level, physics, sixd
from jaxsim.parsers.sdf.utils import flip_velocity_serialization
from jaxsim.physics.algos import soft_contacts
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.simulation import ode_data, ode_integration
Expand Down Expand Up @@ -205,7 +204,7 @@ def reduce(self, considered_joints: List[str]) -> None:

physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
model_description=reduced_model_description,
gravity=self.physics_model.gravity[3:6],
gravity=self.physics_model.gravity[0:3],
)

reduced_model = Model.build(
Expand Down Expand Up @@ -281,7 +280,7 @@ def nr_of_joints(self) -> int:

return len(self._joints)

def total_mass(self) -> float:
def total_mass(self) -> jtp.Float:

return jnp.sum(jnp.array([l.mass() for l in self.links()]))

Expand Down Expand Up @@ -419,7 +418,6 @@ def base_transform(self) -> jtp.MatrixJax:

def base_velocity(self) -> jtp.Vector:

# Get the base 6D velocity expressed in inertial representation
W_v_WB = jnp.hstack(
[
self.data.model_state.base_linear_velocity,
Expand All @@ -431,8 +429,7 @@ def base_velocity(self) -> jtp.Vector:

def external_forces(self) -> jtp.Matrix:

f_ext_anglin = self.data.model_input.f_ext
W_f_ext = jnp.hstack([f_ext_anglin[:, 3:6], f_ext_anglin[:, 0:3]])
W_f_ext = self.data.model_input.f_ext

inertial_to_active = lambda f: self.inertial_to_active_representation(
f, is_force=True
Expand Down Expand Up @@ -465,26 +462,11 @@ def generalized_jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:

def free_floating_mass_matrix(self) -> jtp.Matrix:

M_body_anglin = jaxsim.physics.algos.crba.crba(
M_body = jaxsim.physics.algos.crba.crba(
model=self.physics_model,
q=self.data.model_state.joint_positions,
)

Mbb = flip_velocity_serialization(M_body_anglin[0:6, 0:6])
Mbs_ang = M_body_anglin[0:3, 6:]
Mbs_lin = M_body_anglin[3:6, 6:]
Msb_ang = M_body_anglin[6:, 0:3]
Msb_lin = M_body_anglin[6:, 3:6]

M_body_linang = jnp.zeros_like(M_body_anglin)
M_body_linang = M_body_linang.at[0:6, 0:6].set(Mbb)
M_body_linang = M_body_linang.at[0:6, 6:].set(jnp.vstack([Mbs_lin, Mbs_ang]))
M_body_linang = M_body_linang.at[6:, 0:6].set(jnp.hstack([Msb_lin, Msb_ang]))
M_body_linang = M_body_linang.at[6:, 6:].set(M_body_anglin[6:, 6:])

# This is M in body-fixed velocity representation
M_body = M_body_linang

if self.velocity_representation is VelRepr.Body:
return M_body

Expand Down Expand Up @@ -570,9 +552,9 @@ def com_position(self) -> jtp.Vector:
for l in self.links()
]

B_c_homo = (1 / m) * jnp.sum(jnp.array(com_links), axis=0)
B_ph_CoM = (1 / m) * jnp.sum(jnp.array(com_links), axis=0)

return (W_H_B @ B_c_homo)[0:3]
return (W_H_B @ B_ph_CoM)[0:3]

# ==========
# Algorithms
Expand Down Expand Up @@ -618,23 +600,19 @@ def inverse_dynamics(
# Express base_acceleration in inertial representation
W_a_WB = self.active_to_inertial_representation(array=base_acceleration)

# Convert to ang-lin serialization
W_a_WB_anglin = self.flip_lin_ang_6D(array=W_a_WB)

# Compute RNEA
f_B_anglin, tau = jaxsim.physics.algos.rnea.rnea(
W_f_B, tau = jaxsim.physics.algos.rnea.rnea(
model=self.physics_model,
xfb=self.data.model_state.xfb(),
q=self.data.model_state.joint_positions,
qd=self.data.model_state.joint_velocities,
qdd=joint_accelerations,
a0fb=W_a_WB_anglin,
a0fb=W_a_WB,
f_ext=self.data.model_input.f_ext,
)

# Adjust shape and convert to lin-ang serialization
# Adjust shape
tau = jnp.atleast_1d(tau.squeeze())
W_f_B = self.flip_lin_ang_6D(array=f_B_anglin)

# Express W_f_B in the active representation
f_B = self.inertial_to_active_representation(array=W_f_B, is_force=True)
Expand Down Expand Up @@ -667,7 +645,7 @@ def forward_dynamics_aba(
tau = tau if tau is not None else jnp.zeros_like(self.joint_positions())

# Compute ABA
W_a_WB_anglin, sdd = jaxsim.physics.algos.aba.aba(
W_a_WB, sdd = jaxsim.physics.algos.aba.aba(
model=self.physics_model,
xfb=self.data.model_state.xfb(),
q=self.data.model_state.joint_positions,
Expand All @@ -676,9 +654,8 @@ def forward_dynamics_aba(
f_ext=self.data.model_input.f_ext,
)

# Adjust shape and convert to lin-ang serialization
# Adjust shape
sdd = jnp.atleast_1d(sdd.squeeze())
W_a_WB = self.flip_lin_ang_6D(array=W_a_WB_anglin)

# Express W_a_WB in the active representation
a_WB = self.inertial_to_active_representation(array=W_a_WB)
Expand Down Expand Up @@ -969,16 +946,6 @@ def integrate(
# Private methods
# ===============

@staticmethod
def flip_lin_ang_6D(array: jtp.Array) -> jtp.Array:

array = jnp.array(array).squeeze()

if array.size != 6:
raise ValueError(array.size)

return jnp.flip(jnp.array(jnp.split(array, 2)), axis=0).flatten()

def inertial_to_active_representation(
self, array: jtp.Array, is_force: bool = False
) -> jtp.Array:
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def translate(direction: jtp.Vector) -> jtp.Matrix:

return jnp.array(
[
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, z, -y, 1, 0, 0],
[-z, 0, x, 0, 1, 0],
[y, -x, 0, 0, 0, 1],
[1, 0, 0, 0, z, -y],
[0, 1, 0, -z, 0, x],
[0, 0, 1, y, -x, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
]
)
6 changes: 3 additions & 3 deletions src/jaxsim/math/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def coordinates_tf(X: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
raise ValueError(p.shape)

R = X[0:3, 0:3]
r = -Skew.vee(R.T @ X[3:6, 0:3])
r = -Skew.vee(R.T @ X[0:3, 3:6])

if cols_p > 1:
r = jnp.tile(r, (1, cols_p))
Expand All @@ -47,7 +47,7 @@ def velocities_threed(v_6d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
v = jnp.repeat(v, cols_p, axis=1)

if rows_v == 6:
vp = v[3:6, :] + jnp.cross(v[0:3, :], p, axis=0)
vp = v[0:3, :] + jnp.cross(v[3:6, :], p, axis=0)
else:
raise ValueError(v.shape)

Expand All @@ -71,7 +71,7 @@ def forces_sixd(f_3d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
raise ValueError(cols_p, cols_fp)

if rows_fp == 3:
f = jnp.vstack([jnp.cross(p, fp, axis=0), fp])
f = jnp.vstack([fp, jnp.cross(p, fp, axis=0)])
else:
raise ValueError(fp.shape)

Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/math/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ class Cross:
@staticmethod
def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:

ω, v = jnp.split(velocity_sixd.squeeze(), 2)
v, ω = jnp.split(velocity_sixd.squeeze(), 2)

v_cross = jnp.block(
[
[Skew.wedge(vector=ω), jnp.zeros(shape=(3, 3))],
[Skew.wedge(vector=v), Skew.wedge(vector=ω)],
[Skew.wedge(vector=ω), Skew.wedge(vector=v)],
[jnp.zeros(shape=(3, 3)), Skew.wedge(vector=ω)],
]
)

Expand Down
18 changes: 9 additions & 9 deletions src/jaxsim/math/inertia.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@

class Inertia:
@staticmethod
def to_sixd(mass: float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:

if I.shape != (3, 3):
raise ValueError(I, I.shape)

C = Skew.wedge(vector=com)
c = Skew.wedge(vector=com)

M = jnp.vstack(
M = jnp.block(
[
jnp.hstack([I + mass * C @ C.T, mass * C]),
jnp.hstack([mass * C.T, mass * jnp.eye(3)]),
[mass * jnp.eye(3), mass * c.T],
[mass * c, I + mass * c @ c.T],
]
)

return M

@staticmethod
def to_params(M: jtp.Matrix) -> Tuple[float, jtp.Vector, jtp.Matrix]:
def to_params(M: jtp.Matrix) -> Tuple[jtp.Float, jtp.Vector, jtp.Matrix]:

m = M[5, 5]
m = jnp.diag(M[0:3, 0:3]).sum() / 3

mC = M[0:3, 3:6]
mC = M[3:6, 0:3]
c = Skew.vee(mC) / m
I = M[0:3, 0:3] - mC @ mC.T / m
I = M[3:6, 3:6] - (mC @ mC.T / m)

return m, c, I
Loading