Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply new OOP pattern compatible with jax transformations #48

Merged
merged 27 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fd18c0d
Enforce physics model data to be jax numpy arrays
diegoferigo Aug 4, 2023
5e322e4
Make JaxSim class inherit from Vmappable with decorated methods
diegoferigo Oct 6, 2023
e9ee4a7
New JaxSim.step_size method
diegoferigo Aug 4, 2023
ca22cc8
Enforce dtype of data returned by JaxSim objects
diegoferigo Aug 4, 2023
3f61287
Make high-level Link inherit from Vmappable with decorated methods
diegoferigo Aug 4, 2023
2ae562e
Make high-level Joint inherit from Vmappable with decorated methods
diegoferigo Aug 4, 2023
4e1e790
Mitigate circular import in high-level Model
diegoferigo Aug 4, 2023
342ccad
Make high-level Model inherit from Vmappable with decorated methods
diegoferigo Oct 9, 2023
31f903d
Ensure that quaternion has unary norm in high-level model
diegoferigo Aug 4, 2023
8be1fef
Update existing tests
diegoferigo Aug 4, 2023
d28574d
Update top-level init
diegoferigo Aug 4, 2023
e6edbfc
Fix representation of Link.external_force
diegoferigo Aug 4, 2023
a3e5220
Fix mixed adjoint in Link.add_external force
diegoferigo Aug 4, 2023
5b95a6b
Update jnp.block usage to prevent IDE warnings
diegoferigo Aug 4, 2023
91b5282
Invert transform without jnp.linalg.inv
diegoferigo Aug 4, 2023
acc4681
Fix computation of Model.com_position
diegoferigo Aug 4, 2023
79879b7
Prefer using tuple[str, ...] instead of list[str]
diegoferigo Aug 4, 2023
32445e4
Rename Link.add_{,com_}external_force o Link.apply_{,com_}external_force
diegoferigo Aug 4, 2023
d38be87
Rename Model sections
diegoferigo Oct 4, 2023
5a975d4
Move r/w methods of Link to Model
diegoferigo Oct 4, 2023
a539b44
Clarify that Link and Joint are r/o helpers
diegoferigo Oct 4, 2023
b6043b6
Disable jit when gathering links and joint objects
diegoferigo Oct 4, 2023
d63b181
Set mutability when joints and links are extracted
diegoferigo Oct 4, 2023
91723d7
Minor typing update in Joint
diegoferigo Oct 4, 2023
9b4d65d
Import VelRepr in high_level package
diegoferigo Oct 4, 2023
106c3dc
Make JaxsimDataclass a dataclass
diegoferigo Oct 4, 2023
7384d41
Fix existing tests after API changes
diegoferigo Oct 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, sixd
from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
1 change: 1 addition & 0 deletions src/jaxsim/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import common, joint, link, model
from .common import VelRepr
117 changes: 91 additions & 26 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import dataclasses
from typing import Any, Tuple
import functools
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.parsers
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.utils import Vmappable, not_tracing, oop


@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):
class Joint(Vmappable):
"""
High-level class to operate on a single joint of a simulated model.
High-level class to operate in r/o on a single joint of a simulated model.
"""

joint_description: Static[jaxsim.parsers.descriptions.JointDescription]
Expand All @@ -21,53 +23,116 @@ class Joint(JaxsimDataclass):

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
""""""

return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
""""""

return jnp.array(self.parent_model is not None, dtype=bool)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def index(self) -> jtp.Int:
""""""

def index(self) -> int:
return self.joint_description.index
return jnp.array(self.joint_description.index, dtype=int)

def dofs(self) -> int:
return 1
@functools.partial(oop.jax_tf.method_ro)
def dofs(self) -> jtp.Int:
""""""

return jnp.array(1, dtype=int)

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
""""""

return self.joint_description.name

def position(self, dof: int = 0) -> float:
return self.parent_model.joint_positions(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def position(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def velocity(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
dtype=float,
)

def velocity(self, dof: int = 0) -> float:
return self.parent_model.joint_velocities(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def acceleration(self, dof: int = None) -> jtp.Float:
""""""

def acceleration(self, dof: int = 0) -> float:
return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof]
dof = dof if dof is not None else 0

def force(self, dof: int = 0) -> float:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[
dof
]
return jnp.array(
self.parent_model.joint_accelerations(joint_names=[self.name()])[dof],
dtype=float,
)

def position_limit(self, dof: int = 0) -> Tuple[float, float]:
if dof != 0:
@functools.partial(oop.jax_tf.method_ro)
def force(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_generalized_forces(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]:
""""""

dof = dof if dof is not None else 0

if not_tracing(dof) and dof != 0:
msg = "Only joints with 1 DoF are currently supported"
raise ValueError(msg)

return self.joint_description.position_limit
low, high = self.joint_description.position_limit

return jnp.array(low, dtype=float), jnp.array(high, dtype=float)

# =================
# Multi-DoF methods
# =================

@functools.partial(oop.jax_tf.method_ro)
def joint_position(self) -> jtp.Vector:
return self.parent_model.joint_positions(joint_names=[self.name()])
""""""

return self.parent_model.joint_positions(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_velocity(self) -> jtp.Vector:
return self.parent_model.joint_velocities(joint_names=[self.name()])
""""""

return self.parent_model.joint_velocities(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_acceleration(self) -> jtp.Vector:
""""""

return self.parent_model.joint_accelerations(joint_names=[self.name()])

@functools.partial(oop.jax_tf.method_ro)
def joint_force(self) -> jtp.Vector:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])
""""""

return self.parent_model.joint_generalized_forces(joint_names=(self.name(),))
Loading
Loading