diff --git a/src/jaxsim/simulation/simulator.py b/src/jaxsim/simulation/simulator.py index 75def8a28..354153b41 100644 --- a/src/jaxsim/simulation/simulator.py +++ b/src/jaxsim/simulation/simulator.py @@ -21,15 +21,19 @@ @jax_dataclasses.pytree_dataclass class SimulatorData: + # Simulation time stored in ns in order to prevent floats approximation time_ns: jtp.Int = jnp.array(0, dtype=int) + # Terrain and contact parameters terrain: Terrain = jax_dataclasses.field(default_factory=lambda: FlatTerrain()) contact_parameters: SoftContactsParams = jax_dataclasses.field( default_factory=lambda: SoftContactsParams() ) + # Dictionary containing all handled models models: Dict[str, Model] = jax_dataclasses.field(default_factory=dict) + # Default gravity vector (could be overridden for individual models) gravity: jtp.Vector = jax_dataclasses.field( default_factory=lambda: jaxsim.physics.default_gravity() ) @@ -38,17 +42,23 @@ class SimulatorData: @jax_dataclasses.pytree_dataclass class JaxSim(JaxsimDataclass): + # Step size stored in ns in order to prevent floats approximation step_size_ns: jtp.Int = jax_dataclasses.field( default_factory=lambda: jnp.array(0.001, dtype=int) ) + + # Number of substeps performed at each integration step steps_per_run: jtp.Int = jax_dataclasses.static_field(default=1) + # Default velocity representation (could be overridden for individual models) velocity_representation: VelRepr = jax_dataclasses.field(default=VelRepr.Mixed) + # Integrator type integrator_type: ode_integration.IntegratorType = jax_dataclasses.static_field( default=ode_integration.IntegratorType.EulerForward ) + # Simulator data data: SimulatorData = dataclasses.field(default_factory=lambda: SimulatorData()) @staticmethod