Skip to content

Commit

Permalink
Merge pull request #48 from ami-iit/feature/oop_with_jax
Browse files Browse the repository at this point in the history
Apply new OOP pattern compatible with jax transformations
  • Loading branch information
diegoferigo committed Oct 9, 2023
2 parents cf3575c + 7384d41 commit 309582f
Show file tree
Hide file tree
Showing 13 changed files with 799 additions and 398 deletions.
4 changes: 3 additions & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, sixd
from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
1 change: 1 addition & 0 deletions src/jaxsim/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import common, joint, link, model
from .common import VelRepr
117 changes: 91 additions & 26 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import dataclasses
from typing import Any, Tuple
import functools
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.parsers
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.utils import Vmappable, not_tracing, oop


@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):
class Joint(Vmappable):
"""
High-level class to operate on a single joint of a simulated model.
High-level class to operate in r/o on a single joint of a simulated model.
"""

joint_description: Static[jaxsim.parsers.descriptions.JointDescription]
Expand All @@ -21,53 +23,116 @@ class Joint(JaxsimDataclass):

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
""""""

return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
""""""

return jnp.array(self.parent_model is not None, dtype=bool)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def index(self) -> jtp.Int:
""""""

def index(self) -> int:
return self.joint_description.index
return jnp.array(self.joint_description.index, dtype=int)

def dofs(self) -> int:
return 1
@functools.partial(oop.jax_tf.method_ro)
def dofs(self) -> jtp.Int:
""""""

return jnp.array(1, dtype=int)

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
""""""

return self.joint_description.name

def position(self, dof: int = 0) -> float:
return self.parent_model.joint_positions(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def position(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def velocity(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
dtype=float,
)

def velocity(self, dof: int = 0) -> float:
return self.parent_model.joint_velocities(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def acceleration(self, dof: int = None) -> jtp.Float:
""""""

def acceleration(self, dof: int = 0) -> float:
return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof]
dof = dof if dof is not None else 0

def force(self, dof: int = 0) -> float:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[
dof
]
return jnp.array(
self.parent_model.joint_accelerations(joint_names=[self.name()])[dof],
dtype=float,
)

def position_limit(self, dof: int = 0) -> Tuple[float, float]:
if dof != 0:
@functools.partial(oop.jax_tf.method_ro)
def force(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_generalized_forces(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]:
""""""

dof = dof if dof is not None else 0

if not_tracing(dof) and dof != 0:
msg = "Only joints with 1 DoF are currently supported"
raise ValueError(msg)

return self.joint_description.position_limit
low, high = self.joint_description.position_limit

return jnp.array(low, dtype=float), jnp.array(high, dtype=float)

# =================
# Multi-DoF methods
# =================

@functools.partial(oop.jax_tf.method_ro)
def joint_position(self) -> jtp.Vector:
return self.parent_model.joint_positions(joint_names=[self.name()])
""""""

return self.parent_model.joint_positions(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_velocity(self) -> jtp.Vector:
return self.parent_model.joint_velocities(joint_names=[self.name()])
""""""

return self.parent_model.joint_velocities(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_acceleration(self) -> jtp.Vector:
""""""

return self.parent_model.joint_accelerations(joint_names=[self.name()])

@functools.partial(oop.jax_tf.method_ro)
def joint_force(self) -> jtp.Vector:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])
""""""

return self.parent_model.joint_generalized_forces(joint_names=(self.name(),))
Loading

0 comments on commit 309582f

Please sign in to comment.