Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor changes for the upcoming release #173

Merged
merged 26 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
acedffd
Update notation for support body array
diegoferigo Jun 11, 2024
8952d90
Minor fixes of jaxsim.api.model
diegoferigo Jun 11, 2024
439d6a9
Expose the Baumgarte regularization coefficient
diegoferigo Jun 11, 2024
9d98def
Simplify condition to check JAX_ENABLE_X64 env var
diegoferigo Jun 11, 2024
0e55a1a
Add jaxsim.typing.Scalar and jaxsim.typing.ScalarLike
diegoferigo Jun 11, 2024
19552f0
Update condition to exclude body-fixed representation in CoM test
diegoferigo Jun 11, 2024
5e2897b
Update link pose check of reduced models
diegoferigo Jun 11, 2024
c4e5d33
Add frame pose check of reduced models
diegoferigo Jun 11, 2024
f383d8b
Fix EOFs and trailing whitespaces
flferretti Jun 11, 2024
3564357
Allow overriding default logging verbosity
diegoferigo Jun 11, 2024
4d62b6f
Extend model reduction test
diegoferigo Jun 11, 2024
b201859
Remove unused variables
flferretti Jun 11, 2024
a19f18e
Avoid redefining existing functions
flferretti Jun 11, 2024
e25ec0a
Remove double import
flferretti Jun 11, 2024
741cf67
Use set comprehensions
flferretti Jun 11, 2024
1063e2e
Update deprecated `typing.Hashable` for Python 3.12
flferretti Jun 11, 2024
aefe446
Avoid non-assigned expressions
flferretti Jun 11, 2024
24e86ed
Update typing
flferretti Jun 11, 2024
15ab7c5
Add `KinematicGraph.joints_removed` property
flferretti Jun 11, 2024
baf07bc
Install only "gz sdf" command instead of full Gazebo Sim in CI
diegoferigo Jun 11, 2024
b7be8cc
Update `.gitignore`
flferretti Jun 11, 2024
e9a3f83
Make jaxsim.description classes hashable
diegoferigo Jun 12, 2024
3f41551
Include ModelDescription in JaxSimModel hash
diegoferigo Jun 12, 2024
bf9d7a2
Make sure that mutable attributes do not get cross-altered in reduction
diegoferigo Jun 12, 2024
8947957
Minor import changes
diegoferigo Jun 12, 2024
82502a3
Minor typing update
diegoferigo Jun 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# GitHub syntax highlighting
pixi.lock linguist-language=YAML

5 changes: 2 additions & 3 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ jobs:
- *src
- *tests

# https://gazebosim.org/docs/harmonic/install_ubuntu
- name: Install Gazebo Sim
- name: Install 'gz sdf' system command
if: |
contains(matrix.os, 'ubuntu') &&
(github.event_name != 'pull_request' ||
Expand All @@ -130,7 +129,7 @@ jobs:
sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
sudo apt-get update
sudo apt-get install gz-harmonic
sudo apt-get install --no-install-recommends libsdformat13 gz-tools2

- name: Run the Python tests
if: |
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,10 @@ src/jaxsim/_version.py

# ruff
.ruff_cache/

# pixi environments
.pixi

# data
.mp4
.png
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.

Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.

## Features

Expand All @@ -25,7 +25,7 @@ Its design facilitates research and accelerates prototyping in the intersection

### JaxSim as a multibody dynamics library

- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
- Exposes all the necessary quantities to develop controllers in centroidal coordinates.
Expand Down Expand Up @@ -132,10 +132,10 @@ The main differences between MJX/Brax and JaxSim are as follows:

- JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
function of the state $(\mathbf{q}, \boldsymbol{\nu})$, it doesn't require running any optimization algorithm
when stepping the simulation forward.
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.

[brax]: https://github.com/google/brax
[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
Expand Down
46 changes: 39 additions & 7 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ def _jnp_options() -> None:

import jax

# Enable by default
if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
# Enable by default 64bit precision in JAX.
if os.environ.get("JAX_ENABLE_X64", "1") != "0":

logging.info("Enabling JAX to use 64bit precision")
jax.config.update("jax_enable_x64", True)

Expand All @@ -27,6 +28,7 @@ def _np_options() -> None:


def _is_editable() -> bool:

import importlib.util
import pathlib
import site
Expand All @@ -45,11 +47,40 @@ def _is_editable() -> bool:
return jaxsim_package_dir not in site.getsitepackages()


# Initialize the logging verbosity
if _is_editable():
logging.configure(level=logging.LoggingLevel.DEBUG)
else:
logging.configure(level=logging.LoggingLevel.WARNING)
def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
"""
Get the default logging level.

Args:
env_var: The environment variable to check.

Returns:
The logging level to set.
"""

import os

# Define the default logging level depending on the installation mode.
default_logging_level = (
logging.LoggingLevel.DEBUG
if _is_editable() # noqa: F821
else logging.LoggingLevel.WARNING
)

# Allow to override the default logging level with an environment variable.
try:
return logging.LoggingLevel[
os.environ.get(env_var, default_logging_level.name).upper()
]

except KeyError as exc:
msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
raise RuntimeError(msg) from exc


# Configure the logger with the default logging level.
logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))


# Configure JAX
_jnp_options()
Expand All @@ -59,6 +90,7 @@ def _is_editable() -> bool:

del _jnp_options
del _np_options
del _get_default_logging_level
del _is_editable

from . import terrain # isort:skip
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,20 @@ def jacobian(

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
C_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_C, inverse=True
)
C_J_WC = C_X_W @ W_J_WC
return C_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)

case VelRepr.Mixed:

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:

W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))

Expand All @@ -389,7 +389,7 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
CW_J_WC = CW_X_W @ W_J_WC
return CW_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)

case _:
raise ValueError(output_vel_repr)
Expand Down
5 changes: 4 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
import numpy as np
from jax_dataclasses import Static

import jaxsim.typing as jtp
Expand Down Expand Up @@ -220,7 +221,9 @@ def __hash__(self) -> int:
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def jacobian(
)

# Compute the actual doubly-left free-floating jacobian of the link.
κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B

# Adjust the input representation such that `J_WL_I @ I_ν`.
match data.velocity_representation:
Expand Down
39 changes: 16 additions & 23 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
from jaxsim.utils import JaxsimDataclass, Mutability

from .common import VelRepr

Expand All @@ -33,6 +33,7 @@ class JaxSimModel(JaxsimDataclass):
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
Expand All @@ -41,13 +42,9 @@ class JaxSimModel(JaxsimDataclass):
default=None, repr=False, compare=False, hash=False
)

_description: Static[
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

@property
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()
description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)

def __eq__(self, other: JaxSimModel) -> bool:

Expand All @@ -61,6 +58,7 @@ def __hash__(self) -> int:
return hash(
(
hash(self.model_name),
hash(self.description),
hash(self.kin_dyn_parameters),
)
)
Expand Down Expand Up @@ -157,7 +155,7 @@ def build(
# Build the model
model = JaxSimModel(
model_name=model_name,
_description=HashlessObject(obj=model_description),
description=model_description,
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
Expand Down Expand Up @@ -302,7 +300,7 @@ def reduce(
locked_joint_positions:
A dictionary containing the positions of the joints to be considered
in the reduction process. The removed joints in the reduced model
will have their position locked to their value in this dictionary.
will have their position locked to their value of this dictionary.
If a joint is not part of the dictionary, its position is set to zero.
"""

Expand All @@ -315,10 +313,9 @@ def reduce(
new_joints = set(model.joint_names()) - set(locked_joint_positions)
raise ValueError(f"Passed joints not existing in the model: {new_joints}")

# Copy the model description with a deep copy of the joints.
intermediate_description = dataclasses.replace(
model.description, joints=copy.deepcopy(model.description.joints)
)
# Operate on a deep copy of the model description in order to prevent problems
# when mutable attributes are updated.
intermediate_description = copy.deepcopy(model.description)

# Update the initial position of the joints.
# This is necessary to compute the correct pose of the link pairs connected
Expand Down Expand Up @@ -686,8 +683,6 @@ def to_active(
another representation C_v̇_WB expressed in a generic frame C.
"""

from jaxsim.math import Cross

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
Expand Down Expand Up @@ -1483,12 +1478,7 @@ def link_bias_accelerations(
# ================================================

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(
xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
),
translation=data.base_position(),
).as_matrix()
W_H_B = data.base_transform()

def other_representation_to_inertial(
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
Expand Down Expand Up @@ -1529,9 +1519,12 @@ def other_representation_to_inertial(
W_H_C = W_H_BW
with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
with data.switch_velocity_representation(VelRepr.Mixed):
C_v_WB = BW_v_WB = data.base_velocity()

case _:
raise ValueError(data.velocity_representation)

Expand Down
17 changes: 15 additions & 2 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,18 @@ def system_velocity_dynamics(

@jax.jit
def system_position_dynamics(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
"""
Compute the dynamics of the system position.

Args:
model: The model to consider.
data: The data of the considered model.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient for adjusting the quaternion norm.

Returns:
A tuple containing the derivative of the base position, the derivative of the
Expand All @@ -250,6 +254,7 @@ def system_position_dynamics(
quaternion=W_Q_B,
omega=W_ω_WB,
omega_in_body_fixed=False,
K=baumgarte_quaternion_regularization,
).squeeze()

return W_ṗ_B, W_Q̇_B, ṡ
Expand All @@ -262,6 +267,7 @@ def system_dynamics(
*,
joint_forces: jtp.Vector | None = None,
link_forces: jtp.Vector | None = None,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[ODEState, dict[str, Any]]:
"""
Compute the dynamics of the system.
Expand All @@ -271,6 +277,9 @@ def system_dynamics(
data: The data of the considered model.
joint_forces: The joint forces to apply.
link_forces: The 6D forces to apply to the links.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient used to adjust the norm of the
quaternion (only used in integrators not operating on the SO(3) manifold).

Returns:
A tuple with an `ODEState` object storing in each of its attributes the
Expand All @@ -287,7 +296,11 @@ def system_dynamics(
)

# Extract the velocities.
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data)
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
model=model,
data=data,
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
)

# Create an ODEState object populated with the derivative of each leaf.
# Our integrators, operating on generic pytrees, will be able to handle it
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class JointModel:
already in a vectorized form. In other words, it cannot be created using vmap.
"""

λ_H_pre: jax.Array
suc_H_i: jax.Array
λ_H_pre: jtp.Array
suc_H_i: jtp.Array

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
Expand Down
Loading