Skip to content

Commit

Permalink
Merge pull request #173 from ami-iit/minor_changes_for_new_release
Browse files Browse the repository at this point in the history
Minor changes for the upcoming release
  • Loading branch information
diegoferigo authored Jun 12, 2024
2 parents 0422047 + 82502a3 commit e74e24b
Show file tree
Hide file tree
Showing 29 changed files with 369 additions and 156 deletions.
1 change: 0 additions & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# GitHub syntax highlighting
pixi.lock linguist-language=YAML

5 changes: 2 additions & 3 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ jobs:
- *src
- *tests
# https://gazebosim.org/docs/harmonic/install_ubuntu
- name: Install Gazebo Sim
- name: Install 'gz sdf' system command
if: |
contains(matrix.os, 'ubuntu') &&
(github.event_name != 'pull_request' ||
Expand All @@ -130,7 +129,7 @@ jobs:
sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
sudo apt-get update
sudo apt-get install gz-harmonic
sudo apt-get install --no-install-recommends libsdformat13 gz-tools2
- name: Run the Python tests
if: |
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,10 @@ src/jaxsim/_version.py

# ruff
.ruff_cache/

# pixi environments
.pixi

# data
.mp4
.png
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.

Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.

## Features

Expand All @@ -25,7 +25,7 @@ Its design facilitates research and accelerates prototyping in the intersection

### JaxSim as a multibody dynamics library

- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
- Exposes all the necessary quantities to develop controllers in centroidal coordinates.
Expand Down Expand Up @@ -132,10 +132,10 @@ The main differences between MJX/Brax and JaxSim are as follows:

- JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
function of the state $(\mathbf{q}, \boldsymbol{\nu})$, it doesn't require running any optimization algorithm
when stepping the simulation forward.
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.

[brax]: https://github.com/google/brax
[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
Expand Down
46 changes: 39 additions & 7 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ def _jnp_options() -> None:

import jax

# Enable by default
if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
# Enable by default 64bit precision in JAX.
if os.environ.get("JAX_ENABLE_X64", "1") != "0":

logging.info("Enabling JAX to use 64bit precision")
jax.config.update("jax_enable_x64", True)

Expand All @@ -27,6 +28,7 @@ def _np_options() -> None:


def _is_editable() -> bool:

import importlib.util
import pathlib
import site
Expand All @@ -45,11 +47,40 @@ def _is_editable() -> bool:
return jaxsim_package_dir not in site.getsitepackages()


# Initialize the logging verbosity
if _is_editable():
logging.configure(level=logging.LoggingLevel.DEBUG)
else:
logging.configure(level=logging.LoggingLevel.WARNING)
def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
"""
Get the default logging level.
Args:
env_var: The environment variable to check.
Returns:
The logging level to set.
"""

import os

# Define the default logging level depending on the installation mode.
default_logging_level = (
logging.LoggingLevel.DEBUG
if _is_editable() # noqa: F821
else logging.LoggingLevel.WARNING
)

# Allow to override the default logging level with an environment variable.
try:
return logging.LoggingLevel[
os.environ.get(env_var, default_logging_level.name).upper()
]

except KeyError as exc:
msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
raise RuntimeError(msg) from exc


# Configure the logger with the default logging level.
logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))


# Configure JAX
_jnp_options()
Expand All @@ -59,6 +90,7 @@ def _is_editable() -> bool:

del _jnp_options
del _np_options
del _get_default_logging_level
del _is_editable

from . import terrain # isort:skip
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,20 @@ def jacobian(

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
C_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_C, inverse=True
)
C_J_WC = C_X_W @ W_J_WC
return C_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)

case VelRepr.Mixed:

W_H_C = transforms(model=model, data=data)

def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:

W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))

Expand All @@ -389,7 +389,7 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
CW_J_WC = CW_X_W @ W_J_WC
return CW_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)

case _:
raise ValueError(output_vel_repr)
Expand Down
5 changes: 4 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
import numpy as np
from jax_dataclasses import Static

import jaxsim.typing as jtp
Expand Down Expand Up @@ -220,7 +221,9 @@ def __hash__(self) -> int:
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def jacobian(
)

# Compute the actual doubly-left free-floating jacobian of the link.
κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B

# Adjust the input representation such that `J_WL_I @ I_ν`.
match data.velocity_representation:
Expand Down
39 changes: 16 additions & 23 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
from jaxsim.utils import JaxsimDataclass, Mutability

from .common import VelRepr

Expand All @@ -33,6 +33,7 @@ class JaxSimModel(JaxsimDataclass):
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
Expand All @@ -41,13 +42,9 @@ class JaxSimModel(JaxsimDataclass):
default=None, repr=False, compare=False, hash=False
)

_description: Static[
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

@property
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()
description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)

def __eq__(self, other: JaxSimModel) -> bool:

Expand All @@ -61,6 +58,7 @@ def __hash__(self) -> int:
return hash(
(
hash(self.model_name),
hash(self.description),
hash(self.kin_dyn_parameters),
)
)
Expand Down Expand Up @@ -157,7 +155,7 @@ def build(
# Build the model
model = JaxSimModel(
model_name=model_name,
_description=HashlessObject(obj=model_description),
description=model_description,
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
Expand Down Expand Up @@ -302,7 +300,7 @@ def reduce(
locked_joint_positions:
A dictionary containing the positions of the joints to be considered
in the reduction process. The removed joints in the reduced model
will have their position locked to their value in this dictionary.
will have their position locked to their value of this dictionary.
If a joint is not part of the dictionary, its position is set to zero.
"""

Expand All @@ -315,10 +313,9 @@ def reduce(
new_joints = set(model.joint_names()) - set(locked_joint_positions)
raise ValueError(f"Passed joints not existing in the model: {new_joints}")

# Copy the model description with a deep copy of the joints.
intermediate_description = dataclasses.replace(
model.description, joints=copy.deepcopy(model.description.joints)
)
# Operate on a deep copy of the model description in order to prevent problems
# when mutable attributes are updated.
intermediate_description = copy.deepcopy(model.description)

# Update the initial position of the joints.
# This is necessary to compute the correct pose of the link pairs connected
Expand Down Expand Up @@ -686,8 +683,6 @@ def to_active(
another representation C_v̇_WB expressed in a generic frame C.
"""

from jaxsim.math import Cross

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
Expand Down Expand Up @@ -1483,12 +1478,7 @@ def link_bias_accelerations(
# ================================================

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(
xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
),
translation=data.base_position(),
).as_matrix()
W_H_B = data.base_transform()

def other_representation_to_inertial(
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
Expand Down Expand Up @@ -1529,9 +1519,12 @@ def other_representation_to_inertial(
W_H_C = W_H_BW
with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
with data.switch_velocity_representation(VelRepr.Mixed):
C_v_WB = BW_v_WB = data.base_velocity()

case _:
raise ValueError(data.velocity_representation)

Expand Down
17 changes: 15 additions & 2 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,18 @@ def system_velocity_dynamics(

@jax.jit
def system_position_dynamics(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
"""
Compute the dynamics of the system position.
Args:
model: The model to consider.
data: The data of the considered model.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient for adjusting the quaternion norm.
Returns:
A tuple containing the derivative of the base position, the derivative of the
Expand All @@ -250,6 +254,7 @@ def system_position_dynamics(
quaternion=W_Q_B,
omega=W_ω_WB,
omega_in_body_fixed=False,
K=baumgarte_quaternion_regularization,
).squeeze()

return W_ṗ_B, W_Q̇_B,
Expand All @@ -262,6 +267,7 @@ def system_dynamics(
*,
joint_forces: jtp.Vector | None = None,
link_forces: jtp.Vector | None = None,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[ODEState, dict[str, Any]]:
"""
Compute the dynamics of the system.
Expand All @@ -271,6 +277,9 @@ def system_dynamics(
data: The data of the considered model.
joint_forces: The joint forces to apply.
link_forces: The 6D forces to apply to the links.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient used to adjust the norm of the
quaternion (only used in integrators not operating on the SO(3) manifold).
Returns:
A tuple with an `ODEState` object storing in each of its attributes the
Expand All @@ -287,7 +296,11 @@ def system_dynamics(
)

# Extract the velocities.
W_ṗ_B, W_Q̇_B, = system_position_dynamics(model=model, data=data)
W_ṗ_B, W_Q̇_B, = system_position_dynamics(
model=model,
data=data,
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
)

# Create an ODEState object populated with the derivative of each leaf.
# Our integrators, operating on generic pytrees, will be able to handle it
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class JointModel:
already in a vectorized form. In other words, it cannot be created using vmap.
"""

λ_H_pre: jax.Array
suc_H_i: jax.Array
λ_H_pre: jtp.Array
suc_H_i: jtp.Array

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
Expand Down
Loading

0 comments on commit e74e24b

Please sign in to comment.