Skip to content

Commit

Permalink
Step the simulator over horizon
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Sep 20, 2022
1 parent 1a89f43 commit e14490a
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
82 changes: 80 additions & 2 deletions src/jaxsim/simulation/simulator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import functools
import pathlib
from typing import Dict, List, Union
from typing import Dict, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax_dataclasses

Expand All @@ -14,7 +16,7 @@
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.physics.model.physics_model import PhysicsModel
from jaxsim.simulation import ode_integration
from jaxsim.simulation import ode_integration, simulator_callbacks
from jaxsim.utils import JaxsimDataclass


Expand Down Expand Up @@ -220,3 +222,79 @@ def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:

self._set_mutability(self._mutability())
return step_data

@functools.partial(jax.jit, static_argnames=["horizon_steps"])
def step_over_horizon(
self,
horizon_steps: jtp.Int,
callback_handler: "simulator_callbacks.CallbackHandler" = None,
clear_inputs: jtp.Bool = False,
) -> Union["JaxSim", Tuple["JaxSim", jtp.PyTree]]:

# Process a mutable copy of the simulator
sim = self.copy().mutable(validate=True)

# Helper to get callbacks from the handler
get_cb = (
lambda h, cb_name: getattr(h, cb_name)
if h is not None and hasattr(h, cb_name)
else None
)

# Get the callbacks
configure_cb: Optional[simulator_callbacks.ConfigureCallbackSignature] = get_cb(
h=callback_handler, cb_name="configure_cb"
)
pre_step_cb: Optional[simulator_callbacks.PreStepCallbackSignature] = get_cb(
h=callback_handler, cb_name="pre_step_cb"
)
post_step_cb: Optional[simulator_callbacks.PostStepCallbackSignature] = get_cb(
h=callback_handler, cb_name="post_step_cb"
)

# Callback: configuration
sim = configure_cb(sim) if configure_cb is not None else sim

# Initialize the carry
Carry = Tuple[JaxSim, simulator_callbacks.CallbackHandler]
carry_init: Carry = (sim, callback_handler)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, jtp.PyTree]:

sim, callback_handler = carry

# Make sure to pass a mutable version of the simulator to the callbacks
sim = sim.mutable(validate=True)

# Callback: pre-step
# TODO: should we allow also producing a pre-step output?
sim = pre_step_cb(sim) if pre_step_cb is not None else sim

# Integrate all models
step_data = sim.step(clear_inputs=clear_inputs)

# Callback: post-step
sim, out_post_step = (
post_step_cb(sim, step_data)
if post_step_cb is not None
else (sim, None)
)

# Pack the carry
carry = (sim, callback_handler)

return carry, out_post_step

# Integrate over the given horizon
(sim, callback_handler), out_cb_horizon = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=horizon_steps
)

# Enforce original mutability of the entire simulator
sim._set_mutability(self._mutability())

return (
sim
if callback_handler is None
else (sim, (callback_handler, out_cb_horizon))
)
56 changes: 56 additions & 0 deletions src/jaxsim/simulation/simulator_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import abc
from typing import Callable, Dict, Tuple

import jaxsim.typing as jtp
from jaxsim.high_level.model import StepData
from jaxsim.simulation.simulator import JaxSim


ConfigureCallbackSignature = Callable[[JaxSim], JaxSim]
PreStepCallbackSignature = Callable[[JaxSim], JaxSim]
PostStepCallbackSignature = Callable[
[JaxSim, Dict[str, StepData]], Tuple[JaxSim, jtp.PyTree]
]


class SimulatorCallback(abc.ABC):
pass


class ConfigureCallback(SimulatorCallback):
@property
def configure_cb(self) -> ConfigureCallbackSignature:

return lambda sim: self.configure(sim=sim)

@abc.abstractmethod
def configure(self, sim: JaxSim) -> JaxSim:
pass


class PreStepCallback(SimulatorCallback):
@property
def pre_step_cb(self) -> PreStepCallbackSignature:

return lambda sim: self.pre_step(sim=sim)

@abc.abstractmethod
def pre_step(self, sim: JaxSim) -> JaxSim:
pass


class PostStepCallback(SimulatorCallback):
@property
def post_step_cb(self) -> PostStepCallbackSignature:

return lambda sim, step_data: self.post_step(sim=sim, step_data=step_data)

@abc.abstractmethod
def post_step(
self, sim: JaxSim, step_data: Dict[str, StepData]
) -> Tuple[JaxSim, jtp.PyTree]:
pass


class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback):
pass

0 comments on commit e14490a

Please sign in to comment.