Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement joint position limits #22

Merged
merged 4 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class JointDescription:
friction_static: float = 0.0
friction_viscous: float = 0.0

position_limit_damper: float = 0.0
position_limit_spring: float = 0.0

position_limit: Tuple[float, float] = (0.0, 0.0)
initial_position: Union[float, npt.NDArray] = 0.0

Expand Down
10 changes: 10 additions & 0 deletions src/jaxsim/parsers/sdf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ def extract_data_from_sdf(
and j.axis.dynamics is not None
and j.axis.dynamics.friction is not None
else 0.0,
position_limit_damper=j.axis.limit.dissipation
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.dissipation is not None
else 0.0,
position_limit_spring=j.axis.limit.stiffness
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.stiffness is not None
else 0.0,
)
for j in sdf_tree.model.joints
if j.type in {"revolute", "prismatic", "fixed"}
Expand Down
16 changes: 16 additions & 0 deletions src/jaxsim/physics/model/physics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class PhysicsModel(JaxsimDataclass):
_joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict)
_joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict)

_joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict)
_joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict)

def __post_init__(self):

if self.initial_state is None:
Expand Down Expand Up @@ -102,6 +105,17 @@ def build_from(
joint.index: joint.friction_viscous for joint in model_description.joints
}

# Dicts from the joint index to the spring and damper joint limits parameters.
# Note: the joint index is equal to its child link index.
joint_limit_spring = {
joint.index: joint.position_limit_spring
for joint in model_description.joints
}
joint_limit_damper = {
joint.index: joint.position_limit_damper
for joint in model_description.joints
}

# Transform between model's root and model's base link
# (this is just the pose of the base link in the SDF description)
base_link = model_description.links_dict[model_description.link_names()[0]]
Expand Down Expand Up @@ -160,6 +174,8 @@ def build_from(
_link_inertias_dict=link_spatial_inertias_dict,
_joint_friction_static=joint_friction_static,
_joint_friction_viscous=joint_friction_viscous,
_joint_limit_spring=joint_limit_spring,
_joint_limit_damper=joint_limit_damper,
gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
is_floating_base=True,
gc=GroundContact.build_from(model_description=model_description),
Expand Down
22 changes: 21 additions & 1 deletion src/jaxsim/simulation/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ def dx_dt(
terrain=terrain,
)

# =====================
# Joint position limits
# =====================

# Get the joint position limits
s_min, s_max = jnp.array(
[j.position_limit for j in physics_model.description.joints_dict.values()]
).T

# Get the spring/damper parameters of joint limits enforcement
# k_spring = jnp.array(list(physics_model._joint_limit_spring.values()))
k_damper = jnp.array(list(physics_model._joint_limit_damper.values()))

# Compute the joint torques that enforce joint limits
s = ode_state.physics_model.joint_positions
# sd = ode_state.physics_model.joint_velocities
tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0)
tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0)
tau_limit = tau_max + tau_min

# ==============
# Joint friction
# ==============
Expand All @@ -127,7 +147,7 @@ def dx_dt(
total_forces = ode_input.physics_model.f_ext + contact_forces_links

# Compute the joint torques to actuate
tau = ode_input.physics_model.tau + tau_friction
tau = ode_input.physics_model.tau + tau_friction + tau_limit

W_a_WB, qdd = algos.aba.aba(
model=physics_model,
Expand Down