Skip to content

Commit

Permalink
Merge pull request #38 from ami-iit/feature/no_experimental
Browse files Browse the repository at this point in the history
ABA and RNEA with `jax.lax.scan`, Python 3.11, fixed bugs in other representations
  • Loading branch information
diegoferigo committed Jun 30, 2023
2 parents 2e50bf1 + 457572e commit 0eab286
Show file tree
Hide file tree
Showing 22 changed files with 1,160 additions and 284 deletions.
34 changes: 31 additions & 3 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
- "3.8"
- "3.9"
- "3.10"
# - "3.11"
- "3.11"

steps:

Expand All @@ -76,13 +76,41 @@ jobs:
path: dist
name: dist

- name: Install wheel
# Workaround: install iDynTree for Python 3.11
- name: iDynTree on Python 3.11
if: contains(matrix.os, 'ubuntu') && matrix.python == '3.11'
shell: bash
run: pip install dist/*.whl
run: pip install --pre idyntree

- name: Install wheel (ubuntu)
if: contains(matrix.os, 'ubuntu')
shell: bash
run: pip install "$(find dist/ -type f -name '*.whl')[all]"

- name: Install wheel (macos)
if: contains(matrix.os, 'macos')
shell: bash
run: pip install "$(find dist/ -type f -name '*.whl')"

- name: Import the package
run: python -c "import jaxsim"

- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Install Gazebo Classic
if: contains(matrix.os, 'ubuntu') && (matrix.python == '3.10' || matrix.python == '3.11')
run: |
sudo apt-get update
sudo apt-get install gazebo
- name: Run the Python tests
if: contains(matrix.os, 'ubuntu') && (matrix.python == '3.10' || matrix.python == '3.11')
run: pytest
env:
JAX_PLATFORM_NAME: cpu

publish:
name: Publish to PyPI
needs: test
Expand Down
12 changes: 6 additions & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ long_description_content_type = text/markdown
author = Diego Ferigo
author_email = diego.ferigo@iit.it
license = BSD
license_file = LICENSE
license_files = LICENSE
platforms = any
url = https://github.com/ami-iit/jaxsim

Expand Down Expand Up @@ -55,12 +55,10 @@ package_dir =
python_requires = >=3.8
install_requires =
coloredlogs
distrax
flax
jax >=0.3.14, <0.3.16
jaxlib == 0.3.15
jax >= 0.4.1
jaxlib
jaxlie
jax_dataclasses >= 1.2.2, < 1.4.0
jax_dataclasses >= 1.4.0
pptree
rod
scipy
Expand All @@ -73,8 +71,10 @@ style =
black
isort
testing =
idyntree
pytest
pytest-icdiff
robot-descriptions
all =
%(style)s
%(testing)s
Expand Down
18 changes: 11 additions & 7 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Tuple
import dataclasses
from typing import Any, Tuple

import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.high_level
import jaxsim.parsers.descriptions as descriptions
import jaxsim.parsers
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass

Expand All @@ -14,10 +15,13 @@ class Joint(JaxsimDataclass):
High-level class to operate on a single joint of a simulated model.
"""

joint_description: descriptions.JointDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
)
joint_description: Static[jaxsim.parsers.descriptions.JointDescription]

_parent_model: Any = dataclasses.field(default=None, repr=False, compare=False)

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
Expand Down
42 changes: 24 additions & 18 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import dataclasses
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
import numpy as np
from jax_dataclasses import Static

import jaxsim.high_level
import jaxsim.parsers.descriptions as descriptions
import jaxsim.parsers
import jaxsim.sixd as sixd
import jaxsim.typing as jtp
from jaxsim.physics.algos.jacobian import jacobian
Expand All @@ -18,10 +21,13 @@ class Link(JaxsimDataclass):
High-level class to operate on a single link of a simulated model.
"""

link_description: descriptions.LinkDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
)
link_description: Static[jaxsim.parsers.descriptions.LinkDescription]

_parent_model: Any = dataclasses.field(default=None, repr=False, compare=False)

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
Expand Down Expand Up @@ -143,22 +149,22 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
raise ValueError(output_vel_repr)

def external_force(self) -> jtp.Vector:
W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]
"""
Return the active external force acting on the link.
if self.parent_model.velocity_representation is VelRepr.Inertial:
return W_f_ext
This external force is a user input and is not computed by the physics engine.
During the simulation, this external force is summed to other terms like those
related to enforce contact constraints.
elif self.parent_model.velocity_representation is VelRepr.Body:
W_H_B = self.parent_model.base_transform()
W_X_B = sixd.se3.SE3.from_matrix(W_H_B).adjoint()
Returns:
The active external 6D force acting on the link in the active representation.
"""

return W_X_B.transpose() @ W_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:
raise NotImplementedError
W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]

else:
raise ValueError(self.parent_model.velocity_representation)
return self.parent_model.inertial_to_active_representation(
array=W_f_ext, is_force=True
)

def add_external_force(
self, force: jtp.Array = None, torque: jtp.Array = None
Expand Down
Loading

0 comments on commit 0eab286

Please sign in to comment.