Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow keeping the base pose when a model is reduced #42

Merged
merged 5 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,20 @@ def __post_init__(self):

self._set_mutability(original_mutability)

def reduce(self, considered_joints: List[str]) -> None:
def reduce(
self, considered_joints: List[str], keep_base_pose: bool = False
) -> None:
"""
Reduce the model by lumping together the links connected by removed joints.

Args:
considered_joints: The list of joints to consider.
keep_base_pose: A flag indicating whether to keep the base pose or not.
"""

# Reduce the model description
# Reduce the model description.
# If considered_joints contains joints not existing in the model, the method
# will raise an exception.
reduced_model_description = self.physics_model.description.reduce(
considered_joints=considered_joints
)
Expand All @@ -311,14 +316,22 @@ def reduce(self, considered_joints: List[str]) -> None:
vel_repr=self.velocity_representation,
)

# Replace the current model with the reduced one
original_mutability = self._mutability()
self._set_mutability(mutability=self._mutability().MUTABLE_NO_VALIDATION)
self.physics_model = reduced_model.physics_model
self.data = reduced_model.data
self._links = reduced_model._links
self._joints = reduced_model._joints
self._set_mutability(original_mutability)
# Extract the base pose
W_p_B = self.base_position()
W_Q_B = self.base_orientation(dcm=False)

# Replace the current model with the reduced model.
# Since the structure of the PyTree changes, we disable validation.
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
self.physics_model = reduced_model.physics_model
self.data = reduced_model.data
self._links = reduced_model._links
self._joints = reduced_model._joints

if keep_base_pose:
with self.mutable_context(mutability=Mutability.MUTABLE):
self.reset_base_position(position=W_p_B)
self.reset_base_orientation(orientation=W_Q_B, dcm=False)

def zero(self) -> None:
self.data = ModelData.zero(physics_model=self.physics_model)
Expand Down Expand Up @@ -993,10 +1006,13 @@ def reset_joint_positions(

# TODO: joint position limits

self.data.model_state.joint_positions = (
self.data.model_state.joint_positions.at[
self._joint_indices(joint_names=joint_names)
].set(positions)
self.data.model_state.joint_positions = jnp.atleast_1d(
jnp.array(
self.data.model_state.joint_positions.at[
self._joint_indices(joint_names=joint_names)
].set(positions),
dtype=float,
)
)

def reset_joint_velocities(
Expand All @@ -1013,14 +1029,17 @@ def reset_joint_velocities(

# TODO: joint velocity limits

self.data.model_state.joint_velocities = (
self.data.model_state.joint_velocities.at[
self._joint_indices(joint_names=joint_names)
].set(velocities)
self.data.model_state.joint_velocities = jnp.atleast_1d(
jnp.array(
self.data.model_state.joint_velocities.at[
self._joint_indices(joint_names=joint_names)
].set(velocities),
dtype=float,
)
)

def reset_base_position(self, position: jtp.Vector) -> None:
self.data.model_state.base_position = position
self.data.model_state.base_position = jnp.array(position, dtype=float)

def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> None:
if dcm:
Expand All @@ -1030,7 +1049,7 @@ def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> N
).as_quaternion_xyzw()
orientation = orientation_xyzw[to_wxyz]

self.data.model_state.base_quaternion = orientation
self.data.model_state.base_quaternion = jnp.array(orientation, dtype=float)

def reset_base_transform(self, transform: jtp.Matrix) -> None:
if transform.shape != (4, 4):
Expand Down Expand Up @@ -1074,8 +1093,13 @@ def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None:
else:
raise ValueError(self.velocity_representation)

self.data.model_state.base_linear_velocity = base_velocity_inertial[0:3]
self.data.model_state.base_angular_velocity = base_velocity_inertial[3:6]
self.data.model_state.base_linear_velocity = jnp.array(
base_velocity_inertial[0:3], dtype=float
)

self.data.model_state.base_angular_velocity = jnp.array(
base_velocity_inertial[3:6], dtype=float
)

# ===========
# Integration
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

# Propagate link velocity
vJ = S[i] * qd[ii]
vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)
Expand Down
4 changes: 3 additions & 1 deletion src/jaxsim/simulation/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ class JaxSim(JaxsimDataclass):
steps_per_run: Static[jtp.Int] = dataclasses.field(default=1)

# Default velocity representation (could be overridden for individual models)
velocity_representation: VelRepr = dataclasses.field(default=VelRepr.Inertial)
velocity_representation: Static[VelRepr] = dataclasses.field(
default=VelRepr.Inertial
)

# Integrator type
integrator_type: Static[ode_integration.IntegratorType] = dataclasses.field(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def editable(self: T, validate: bool = True) -> ContextManager[T]:
def mutable_context(self: T, mutability: Mutability) -> ContextManager[T]:
""""""

original_mutability = self._mutability
original_mutability = self._mutability()

self._set_mutability(mutability)
yield self
Expand Down