Skip to content

Commit

Permalink
Merge pull request #29 from ami-iit/documentation
Browse files Browse the repository at this point in the history
Documentation enhancements
  • Loading branch information
diegoferigo committed Apr 21, 2023
2 parents 6561730 + d1849b3 commit 7568c2c
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/jaxsim/high_level/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@


class VelRepr(enum.IntEnum):
"""
Enumeration of all supported 6D velocity representations.
"""

Body = enum.auto()
Mixed = enum.auto()
Inertial = enum.auto()
4 changes: 4 additions & 0 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):
"""
High-level class to operate on a single joint of a simulated model.
"""

joint_description: descriptions.JointDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
Expand Down
4 changes: 4 additions & 0 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

@jax_dataclasses.pytree_dataclass
class Link(JaxsimDataclass):
"""
High-level class to operate on a single link of a simulated model.
"""

link_description: descriptions.LinkDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
Expand Down
22 changes: 22 additions & 0 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,26 @@

@jax_dataclasses.pytree_dataclass
class ModelData(JaxsimDataclass):
"""
Class used to store the model state and input at a given time.
"""

model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState
model_input: jaxsim.physics.model.physics_model_state.PhysicsModelInput
contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState

@staticmethod
def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData":
"""
Return a ModelData object with all fields set to zero and initialized with the right shape.
Args:
physics_model: The considered physics model.
Returns:
The zero ModelData object of the given physics model.
"""

return ModelData(
model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState.zero(
physics_model=physics_model
Expand All @@ -48,6 +62,10 @@ def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData"

@jax_dataclasses.pytree_dataclass
class StepData(JaxsimDataclass):
"""
Class used to store the data computed at each step of the simulation.
"""

t0: float
tf: float
dt: float
Expand Down Expand Up @@ -77,6 +95,10 @@ class StepData(JaxsimDataclass):

@jax_dataclasses.pytree_dataclass
class Model(JaxsimDataclass):
"""
High-level class to operate on a simulated model.
"""

model_name: str = jax_dataclasses.static_field()
physics_model: physics.model.physics_model.PhysicsModel = dataclasses.field(
repr=False
Expand Down
102 changes: 102 additions & 0 deletions src/jaxsim/simulation/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ def odeint_euler_one_step(
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
"""
Forward Euler integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
The final state and a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
Expand Down Expand Up @@ -79,6 +93,20 @@ def odeint_rk4_one_step(
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
"""
Runge-Kutta 4 integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
The final state and a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
Expand Down Expand Up @@ -135,6 +163,20 @@ def odeint_euler_semi_implicit_one_step(
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
"""
Semi-implicit Euler integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
The final state and a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
Expand Down Expand Up @@ -245,6 +287,18 @@ def integrate_single_step_over_horizon(
t: TimeHorizon,
x0: State,
) -> Tuple[State, Dict[str, Any]]:
"""
Integrate a single-step integrator over a given horizon.
Args:
integrator_single_step: A single-step integrator.
t: The vector of time instants of the integration horizon.
x0: The initial state of the integration horizon.
Returns:
The final state and auxiliary data produced by the integrator.
"""

# Initialize the carry
carry_init = (x0, t)

Expand Down Expand Up @@ -286,6 +340,22 @@ def odeint_euler(
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Euler method.
Args:
func: A function that computes the time-derivative of the state.
y0: The initial state.
t: The vector of time instants of the integration horizon.
*args: Additional arguments to be passed to the function func.
num_sub_steps: The number of sub-steps to be performed within each integration step.
return_aux: Whether to return the auxiliary data produced by the integrator.
Returns:
The state of the system at the end of the integration horizon, and optionally
the auxiliary data produced by the integrator.
"""

# Close func over additional inputs and parameters
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)

Expand All @@ -310,6 +380,22 @@ def odeint_euler_semi_implicit(
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Semi-Implicit Euler method.
Args:
func: A function that computes the time-derivative of the state.
y0: The initial state.
t: The vector of time instants of the integration horizon.
*args: Additional arguments to be passed to the function func.
num_sub_steps: The number of sub-steps to be performed within each integration step.
return_aux: Whether to return the auxiliary data produced by the integrator.
Returns:
The state of the system at the end of the integration horizon, and optionally
the auxiliary data produced by the integrator.
"""

# Close func over additional inputs and parameters
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)

Expand All @@ -334,6 +420,22 @@ def odeint_rk4(
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Runge-Kutta 4 method.
Args:
func: A function that computes the time-derivative of the state.
y0: The initial state.
t: The vector of time instants of the integration horizon.
*args: Additional arguments to be passed to the function func.
num_sub_steps: The number of sub-steps to be performed within each integration step.
return_aux: Whether to return the auxiliary data produced by the integrator.
Returns:
The state of the system at the end of the integration horizon, and optionally
the auxiliary data produced by the integrator.
"""

# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: func(x, ts, *args)

Expand Down
42 changes: 41 additions & 1 deletion src/jaxsim/simulation/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ def compute_contact_forces(
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
) -> Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
"""
Compute the contact forces acting on the collidable points of the model.
Args:
physics_model: The physics model to consider.
ode_state: The state of the ODE corresponding to the physics model.
soft_contacts_params: The parameters of the soft contacts model.
terrain: The terrain model.
Returns:
A tuple containing:
- The contact forces expressed in the world frame acting on the model's links.
- The derivative of the tangential deformation of the terrain dynamics.
- The contact forces expressed in the world frame acting on the model's collidable points.
"""

# Compute position and linear mixed velocity of all model's collidable points
# collidable_points_kinematics
pos_cp, vel_cp = collidable_points_pos_vel(
Expand Down Expand Up @@ -64,6 +80,23 @@ def dx_dt(
ode_input: ode_data.ODEInput = None,
terrain: Terrain = FlatTerrain(),
) -> Tuple[ode_data.ODEState, Dict[str, Any]]:
"""
Compute the state derivative of the ODE corresponding to the physics model.
Args:
x: The state of the ODE.
t: The current time.
physics_model: The physics model to consider.
soft_contacts_params: The parameters of the soft contacts model.
ode_input: The input of the ODE.
terrain: The terrain model.
Returns:
A tuple containing:
- The state derivative of the ODE.
- A dictionary containing auxiliary information.
"""

if t is not None and isinstance(t, np.ndarray) and t.size != 1:
raise ValueError(t.size)

Expand Down Expand Up @@ -152,6 +185,7 @@ def dx_dt(
# Compute the joint torques to actuate
tau = ode_input.physics_model.tau + tau_friction + tau_limit

# Compute forward dynamics with the ABA algorithm
W_a_WB, qdd = algos.aba.aba(
model=physics_model,
xfb=ode_state.physics_model.xfb(),
Expand Down Expand Up @@ -198,11 +232,15 @@ def dx_dt(
# =====================================

def fix_one_dof(vector: jtp.Vector) -> Optional[jtp.Vector]:
"""Fix the shape of computed quantities for models with just 1 DoF."""

if vector is None:
return None

return jnp.array([vector]) if vector.shape == () else vector

# Fill the PhysicsModelState object included in the input ODEState to store the
# returned PhysicsModelState derivative
physics_model_state_derivative = ode_state.physics_model.replace(
joint_positions=fix_one_dof(ode_state.physics_model.joint_velocities.squeeze()),
joint_velocities=fix_one_dof(qdd.squeeze()),
Expand All @@ -212,12 +250,14 @@ def fix_one_dof(vector: jtp.Vector) -> Optional[jtp.Vector]:
base_linear_velocity=xd_fb.squeeze()[7:10],
)

# Fill the SoftContactsState object included in the input ODEState to store the
# returned SoftContactsState derivative
soft_contacts_state_derivative = ode_state.soft_contacts.replace(
tangential_deformation=tangential_deformation_dot.squeeze(),
)

# We store the state derivative using the ODEState class so that the pytree
# structure remains consistent, and it allows using our generic pytree integrators
# structure remains consistent, allowing to use our generic pytree integrators
state_derivative = ode_data.ODEState(
physics_model=physics_model_state_derivative,
soft_contacts=soft_contacts_state_derivative,
Expand Down
Loading

0 comments on commit 7568c2c

Please sign in to comment.