Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New high-level APIs with OOP wrappers #59

Merged
merged 62 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
f1fbe68
Revert "Perform a deep copy when copying pytrees"
diegoferigo Jun 30, 2023
825ff5a
Always restore mutability before exiting the context manager
diegoferigo Jul 27, 2023
719eb40
Convert utils.py module to jaxsim.utils package
diegoferigo Jul 27, 2023
68ac4b5
Create abstract class that can be automatically vectorized with vmap
diegoferigo Aug 3, 2023
b005b8a
Add decorators to apply jit and vmap on jax dataclasses methods
diegoferigo Oct 6, 2023
d922ccf
Add tests for applying jax transformations in OOP on jax dataclasses
diegoferigo Aug 4, 2023
ce844c7
Add typing_extension dependency for Python < 3.11
diegoferigo Jul 31, 2023
d93b4ca
Use typing.Self in jaxsim_dataclass
diegoferigo Jul 31, 2023
a3bc673
Initialize __mutability__ in jaxsim_dataclass
diegoferigo Jul 31, 2023
84f5ef5
Restore the original pytree content if mutable_context raises
diegoferigo Aug 3, 2023
6b424bf
Extend the tracing detection checking also the Tracer type
diegoferigo Aug 3, 2023
75b8d27
Remove resources deprecated in upstream
diegoferigo Oct 6, 2023
c1f5b63
Run tests in subprocess and move pytest configuration to pyproject.toml
diegoferigo Aug 3, 2023
cf3575c
Merge pull request #44 from ami-iit/fix/trace_leak
diegoferigo Oct 9, 2023
fd18c0d
Enforce physics model data to be jax numpy arrays
diegoferigo Aug 4, 2023
5e322e4
Make JaxSim class inherit from Vmappable with decorated methods
diegoferigo Oct 6, 2023
e9ee4a7
New JaxSim.step_size method
diegoferigo Aug 4, 2023
ca22cc8
Enforce dtype of data returned by JaxSim objects
diegoferigo Aug 4, 2023
3f61287
Make high-level Link inherit from Vmappable with decorated methods
diegoferigo Aug 4, 2023
2ae562e
Make high-level Joint inherit from Vmappable with decorated methods
diegoferigo Aug 4, 2023
4e1e790
Mitigate circular import in high-level Model
diegoferigo Aug 4, 2023
342ccad
Make high-level Model inherit from Vmappable with decorated methods
diegoferigo Oct 9, 2023
31f903d
Ensure that quaternion has unary norm in high-level model
diegoferigo Aug 4, 2023
8be1fef
Update existing tests
diegoferigo Aug 4, 2023
d28574d
Update top-level init
diegoferigo Aug 4, 2023
e6edbfc
Fix representation of Link.external_force
diegoferigo Aug 4, 2023
a3e5220
Fix mixed adjoint in Link.add_external force
diegoferigo Aug 4, 2023
5b95a6b
Update jnp.block usage to prevent IDE warnings
diegoferigo Aug 4, 2023
91b5282
Invert transform without jnp.linalg.inv
diegoferigo Aug 4, 2023
acc4681
Fix computation of Model.com_position
diegoferigo Aug 4, 2023
79879b7
Prefer using tuple[str, ...] instead of list[str]
diegoferigo Aug 4, 2023
32445e4
Rename Link.add_{,com_}external_force o Link.apply_{,com_}external_force
diegoferigo Aug 4, 2023
d38be87
Rename Model sections
diegoferigo Oct 4, 2023
5a975d4
Move r/w methods of Link to Model
diegoferigo Oct 4, 2023
a539b44
Clarify that Link and Joint are r/o helpers
diegoferigo Oct 4, 2023
b6043b6
Disable jit when gathering links and joint objects
diegoferigo Oct 4, 2023
d63b181
Set mutability when joints and links are extracted
diegoferigo Oct 4, 2023
91723d7
Minor typing update in Joint
diegoferigo Oct 4, 2023
9b4d65d
Import VelRepr in high_level package
diegoferigo Oct 4, 2023
106c3dc
Make JaxsimDataclass a dataclass
diegoferigo Oct 4, 2023
7384d41
Fix existing tests after API changes
diegoferigo Oct 5, 2023
309582f
Merge pull request #48 from ami-iit/feature/oop_with_jax
diegoferigo Oct 9, 2023
03af4f4
Allow pre-step callbacks to generate output
diegoferigo Oct 9, 2023
90a1eab
Update jaxsim.simulation.__init__.py
diegoferigo Aug 4, 2023
35b5159
Remove leftover methods no longer valid
diegoferigo Aug 4, 2023
380d20c
Rename Joint.{joint_,}force to Joint.{joint_,}force_target
diegoferigo Aug 4, 2023
c2ff048
Fix descriptions.{Link|Joint} behavior as jax dataclasses
diegoferigo Oct 9, 2023
207a4a1
Fix jnp.block usage
diegoferigo Oct 9, 2023
f3d6af8
Do not hash parent model in Link and Joint
diegoferigo Oct 9, 2023
0a1c38d
Fix height calculation of PlaneTerrain
diegoferigo Oct 9, 2023
ab51d02
Removed unused scipy dependency
diegoferigo Oct 9, 2023
5573944
Merge pull request #53 from ami-iit/minor_api_update
diegoferigo Oct 10, 2023
1b26141
Fix differentiating through Quaternion.derivative
diegoferigo Oct 9, 2023
f4db5dd
Allow differentiating through the SoftContacts algorithm
diegoferigo Oct 9, 2023
c5c799d
Add test for automatic differentiation of RBDAs
diegoferigo Oct 10, 2023
e9bb166
Merge pull request #54 from ami-iit/automatic_differentiation
diegoferigo Oct 11, 2023
0d59772
Make all fields of simulator.JaxSim static, excluding SimulatorData
diegoferigo Oct 20, 2023
804d525
Merge pull request #55 from ami-iit/fix/jaxsim_static_fields
diegoferigo Oct 24, 2023
e0c0bc7
Merge remote-tracking branch 'ami/main' into new_api
diegoferigo Dec 6, 2023
458eeac
Remove pinnings necessary on old jax versions
diegoferigo Dec 6, 2023
50aea06
Update bool typing
diegoferigo Dec 6, 2023
b18b930
Update typing of random key
diegoferigo Dec 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ line-length = 88
[tool.isort]
profile = "black"
multi_line_output = 3

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rsxX -v --strict-markers --forked"
testpaths = [
"tests",
]
9 changes: 3 additions & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ install_requires =
ml-dtypes < 0.3.0
pptree
rod
scipy
typing_extensions; python_version < "3.11"

[options.packages.find]
where = src
Expand All @@ -71,13 +71,10 @@ style =
isort
testing =
idyntree
pytest
pytest >= 6.0
pytest-forked
pytest-icdiff
robot-descriptions
all =
%(style)s
%(testing)s

[tool:pytest]
addopts = -rsxX -v --strict-markers
testpaths = tests
4 changes: 3 additions & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, sixd
from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
1 change: 1 addition & 0 deletions src/jaxsim/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import common, joint, link, model
from .common import VelRepr
114 changes: 84 additions & 30 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,127 @@
import dataclasses
from typing import Any, Tuple
import functools
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.parsers
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.utils import Vmappable, not_tracing, oop


@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):
class Joint(Vmappable):
"""
High-level class to operate on a single joint of a simulated model.
High-level class to operate in r/o on a single joint of a simulated model.
"""

joint_description: Static[jaxsim.parsers.descriptions.JointDescription]

_parent_model: Any = dataclasses.field(default=None, repr=False, compare=False)
_parent_model: Any = dataclasses.field(
default=None, repr=False, compare=False, hash=False
)

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
""""""

return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
""""""

return jnp.array(self.parent_model is not None, dtype=bool)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def index(self) -> jtp.Int:
""""""

return jnp.array(self.joint_description.index, dtype=int)

def index(self) -> int:
return self.joint_description.index
@functools.partial(oop.jax_tf.method_ro)
def dofs(self) -> jtp.Int:
""""""

def dofs(self) -> int:
return 1
return jnp.array(1, dtype=int)

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
""""""

return self.joint_description.name

def position(self, dof: int = 0) -> float:
return self.parent_model.joint_positions(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def position(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def velocity(self, dof: int = None) -> jtp.Float:
""""""

def velocity(self, dof: int = 0) -> float:
return self.parent_model.joint_velocities(joint_names=[self.name()])[dof]
dof = dof if dof is not None else 0

def acceleration(self, dof: int = 0) -> float:
return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof]
return jnp.array(
self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
dtype=float,
)

def force(self, dof: int = 0) -> float:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[
dof
]
@functools.partial(oop.jax_tf.method_ro)
def force_target(self, dof: int = None) -> jtp.Float:
""""""

def position_limit(self, dof: int = 0) -> Tuple[float, float]:
if dof != 0:
dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_generalized_forces_targets(
joint_names=(self.name(),)
)[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]:
""""""

dof = dof if dof is not None else 0

if not_tracing(dof) and dof != 0:
msg = "Only joints with 1 DoF are currently supported"
raise ValueError(msg)

return self.joint_description.position_limit
low, high = self.joint_description.position_limit

return jnp.array(low, dtype=float), jnp.array(high, dtype=float)

# =================
# Multi-DoF methods
# =================

@functools.partial(oop.jax_tf.method_ro)
def joint_position(self) -> jtp.Vector:
return self.parent_model.joint_positions(joint_names=[self.name()])
""""""

return self.parent_model.joint_positions(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_velocity(self) -> jtp.Vector:
return self.parent_model.joint_velocities(joint_names=[self.name()])
""""""

return self.parent_model.joint_velocities(joint_names=(self.name(),))

def joint_acceleration(self) -> jtp.Vector:
return self.parent_model.joint_accelerations(joint_names=[self.name()])
@functools.partial(oop.jax_tf.method_ro)
def joint_force_target(self) -> jtp.Vector:
""""""

def joint_force(self) -> jtp.Vector:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])
return self.parent_model.joint_generalized_forces_targets(
joint_names=(self.name(),)
)
Loading
Loading