Skip to content

Commit

Permalink
Merge pull request #34 from ami-iit/refactor/soft_contacts
Browse files Browse the repository at this point in the history
Rewrite soft-contacts algorithm
  • Loading branch information
diegoferigo committed Jun 9, 2023
2 parents 06c4e26 + 706f5d2 commit 2e50bf1
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 201 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,4 @@ def add_com_external_force(
)

def in_contact(self) -> jtp.Bool:
return not jnp.allclose(self.external_force(), 0)
return self.parent_model.in_contact()[self.index()]
43 changes: 43 additions & 0 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,41 @@ def joints(self, joint_names: List[str] = None) -> List[high_level.joint.Joint]:

return [self._joints[name] for name in joint_names]

def in_contact(
self,
link_names: Optional[List[str]] = None,
terrain: Terrain = FlatTerrain(),
) -> jtp.Vector:
""""""

link_names = link_names if link_names is not None else self.link_names()

if set(link_names) - set(self._links.keys()) != set():
raise ValueError("One or more link names are not part of the model")

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

W_p_Ci, _ = collidable_points_pos_vel(
model=self.physics_model,
q=self.data.model_state.joint_positions,
qd=self.data.model_state.joint_velocities,
xfb=self.data.model_state.xfb(),
)

terrain_height = jax.vmap(terrain.height)(W_p_Ci[0, :], W_p_Ci[1, :])

below_terrain = W_p_Ci[2, :] <= terrain_height

links_in_contact = jax.vmap(
lambda link_index: jnp.where(
self.physics_model.gc.body == link_index,
below_terrain,
jnp.zeros_like(below_terrain, dtype=bool),
).any()
)(jnp.array([link.index() for link in self.links(link_names=link_names)]))

return links_in_contact

# ==================
# Vectorized methods
# ==================
Expand Down Expand Up @@ -441,6 +476,14 @@ def joint_velocities(self, joint_names: List[str] = None) -> jtp.Vector:
self._joint_indices(joint_names=joint_names)
]

def joint_generalized_forces_targets(
self, joint_names: List[str] = None
) -> jtp.Vector:
if self.dofs() == 0 and (joint_names is None or len(joint_names) == 0):
return jnp.array([])

return self.data.model_input.tau[self._joint_indices(joint_names=joint_names)]

def joint_limits(
self, joint_names: List[str] = None
) -> Tuple[jtp.Vector, jtp.Vector]:
Expand Down
Loading

0 comments on commit 2e50bf1

Please sign in to comment.