Skip to content

Commit

Permalink
Update Simulator.step method
Browse files Browse the repository at this point in the history
- Return the StepData of all simulated models
- Introduce support of clearing model inputs after stepping
  • Loading branch information
diegoferigo committed Sep 19, 2022
1 parent 61af2ad commit d258afa
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/jaxsim/simulation/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jaxsim.physics
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model
from jaxsim.high_level.model import Model, StepData
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.physics.model.physics_model import PhysicsModel
Expand Down Expand Up @@ -180,28 +180,33 @@ def remove_model(self, model_name: str) -> None:
self.data.models.pop(model_name)
self._set_mutability(self._mutability())

def step(self) -> None:
def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:

t0_ns = jnp.array(self.data.time_ns, dtype=int)
dt_ns = jnp.array(self.step_size_ns * self.steps_per_run, dtype=int)

tf_ns = t0_ns + dt_ns

# We collect the StepData of all models
step_data = dict()

for model in self.models():

with model.editable(validate=True) as integrated_model:

integrated_model.integrate(
step_data[model.name()] = integrated_model.integrate(
t0=jnp.array(t0_ns, dtype=float) / 1e9,
tf=jnp.array(tf_ns, dtype=float) / 1e9,
sub_steps=self.steps_per_run,
integrator_type=self.integrator_type,
terrain=self.data.terrain,
contact_parameters=self.data.contact_parameters,
clear_inputs=clear_inputs,
)

self.data.models[model.name()].data = integrated_model.data

self.data.time_ns += dt_ns

self._set_mutability(self._mutability())
return step_data

0 comments on commit d258afa

Please sign in to comment.