Skip to content

Commit

Permalink
Move r/w methods of Link to Model
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 9, 2023
1 parent d38be87 commit 5a975d4
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 88 deletions.
88 changes: 0 additions & 88 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,94 +250,6 @@ def external_force(self) -> jtp.Vector:

return f_ext

@functools.partial(oop.jax_tf.method_rw, static_argnames=["additive"])
def apply_external_force(
self, force: jtp.Array = None, torque: jtp.Array = None, additive: bool = True
) -> None:
""""""

force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

f_ext = jnp.hstack([force, torque])

if self.parent_model.velocity_representation is VelRepr.Inertial:
W_f_ext = f_ext

elif self.parent_model.velocity_representation is VelRepr.Body:
L_f_ext = f_ext
W_H_L = self.transform()
L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint()

W_f_ext = L_X_W.transpose() @ L_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:
LW_f_ext = f_ext

W_p_L = self.transform()[0:3, 3]
W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint()

W_f_ext = LW_X_W.transpose() @ LW_f_ext

else:
raise ValueError(self.parent_model.velocity_representation)

# Compute the new 6D force
W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :]
new_force = W_f_ext_current + W_f_ext if additive else W_f_ext

# Apply the new 6D force to the link frame
self.parent_model.data.model_input.f_ext = (
self.parent_model.data.model_input.f_ext.at[self.index(), :].set(new_force)
)

@functools.partial(oop.jax_tf.method_rw, static_argnames=["additive"])
def apply_com_external_force(
self, force: jtp.Array = None, torque: jtp.Array = None, additive: bool = True
) -> None:
""""""

force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

f_ext = jnp.hstack([force, torque])

if self.parent_model.velocity_representation is VelRepr.Inertial:
W_f_ext = f_ext

elif self.parent_model.velocity_representation is VelRepr.Body:
GL_f_ext = f_ext

W_H_L = self.transform()
L_p_CoM = self.com_position(in_link_frame=True)
L_H_GL = jnp.eye(4).at[0:3, 3].set(L_p_CoM)
W_H_GL = W_H_L @ L_H_GL
GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint()

W_f_ext = GL_X_W.transpose() @ GL_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:
GW_f_ext = f_ext

W_p_CoM = self.com_position(in_link_frame=False)
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint()

W_f_ext = GW_X_W.transpose() @ GW_f_ext

else:
raise ValueError(self.parent_model.velocity_representation)

# Compute the new 6D force
W_f_ext_current = self.parent_model.data.model_input.f_ext[self.index(), :]
new_force = W_f_ext_current + W_f_ext if additive else W_f_ext

# Apply the new 6D force to the link frame
self.parent_model.data.model_input.f_ext = (
self.parent_model.data.model_input.f_ext.at[self.index(), :].set(new_force)
)

@functools.partial(oop.jax_tf.method_ro)
def in_contact(self) -> jtp.Bool:
""""""
Expand Down
118 changes: 118 additions & 0 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,124 @@ def external_forces(self) -> jtp.Matrix:

return jax.vmap(inertial_to_active, in_axes=0)(W_f_ext)

# =======================
# Single link r/w methods
# =======================

@functools.partial(
oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link(
self,
link_name: str,
force: jtp.Array = None,
torque: jtp.Array = None,
additive: bool = True,
) -> None:
""""""

# Get the target link with the correct mutability
link = self.get_link(link_name=link_name)
link._set_mutability(mutability=self._mutability())

# Initialize zero force components if not set
force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

# Build the target 6D force in the active representation
f_ext = jnp.hstack([force, torque])

# Convert the 6D force to the inertial representation
if self.velocity_representation is VelRepr.Inertial:
W_f_ext = f_ext

elif self.velocity_representation is VelRepr.Body:
L_f_ext = f_ext
W_H_L = link.transform()
L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint()

W_f_ext = L_X_W.transpose() @ L_f_ext

elif self.velocity_representation is VelRepr.Mixed:
LW_f_ext = f_ext

W_p_L = link.transform()[0:3, 3]
W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint()

W_f_ext = LW_X_W.transpose() @ LW_f_ext

else:
raise ValueError(self.velocity_representation)

# Obtain the new 6D force considering the 'additive' flag
W_f_ext_current = self.data.model_input.f_ext[link.index(), :]
new_force = W_f_ext_current + W_f_ext if additive else W_f_ext

# Update the model data
self.data.model_input.f_ext = self.data.model_input.f_ext.at[
link.index(), :
].set(new_force)

@functools.partial(
oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link_com(
self,
link_name: str,
force: jtp.Array = None,
torque: jtp.Array = None,
additive: bool = True,
) -> None:
""""""

# Get the target link with the correct mutability
link = self.get_link(link_name=link_name)
link._set_mutability(mutability=self._mutability())

# Initialize zero force components if not set
force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

# Build the target 6D force in the active representation
f_ext = jnp.hstack([force, torque])

# Convert the 6D force to the inertial representation
if self.velocity_representation is VelRepr.Inertial:
W_f_ext = f_ext

elif self.velocity_representation is VelRepr.Body:
GL_f_ext = f_ext

W_H_L = link.transform()
L_p_CoM = link.com_position(in_link_frame=True)
L_H_GL = jnp.eye(4).at[0:3, 3].set(L_p_CoM)
W_H_GL = W_H_L @ L_H_GL
GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint()

W_f_ext = GL_X_W.transpose() @ GL_f_ext

elif self.velocity_representation is VelRepr.Mixed:
GW_f_ext = f_ext

W_p_CoM = link.com_position(in_link_frame=False)
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint()

W_f_ext = GW_X_W.transpose() @ GW_f_ext

else:
raise ValueError(self.velocity_representation)

# Obtain the new 6D force considering the 'additive' flag
W_f_ext_current = self.data.model_input.f_ext[link.index(), :]
new_force = W_f_ext_current + W_f_ext if additive else W_f_ext

# Update the model data
self.data.model_input.f_ext = self.data.model_input.f_ext.at[
link.index(), :
].set(new_force)

# ================================================
# Generalized methods and free-floating quantities
# ================================================
Expand Down

0 comments on commit 5a975d4

Please sign in to comment.