Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Step the simulator over horizon #10

Merged
merged 1 commit into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions src/jaxsim/simulation/simulator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import functools
import pathlib
from typing import Dict, List, Union
from typing import Dict, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax_dataclasses

Expand All @@ -14,7 +16,7 @@
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.physics.model.physics_model import PhysicsModel
from jaxsim.simulation import ode_integration
from jaxsim.simulation import ode_integration, simulator_callbacks
from jaxsim.utils import JaxsimDataclass


Expand Down Expand Up @@ -220,3 +222,86 @@ def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:

self._set_mutability(self._mutability())
return step_data

@functools.partial(jax.jit, static_argnames=["horizon_steps"])
def step_over_horizon(
self,
horizon_steps: jtp.Int,
callback_handler: Union[
"simulator_callbacks.SimulatorCallback",
"simulator_callbacks.CallbackHandler",
] = None,
clear_inputs: jtp.Bool = False,
) -> Union[
"JaxSim",
Tuple["JaxSim", Tuple["simulator_callbacks.SimulatorCallback", jtp.PyTree]],
]:

# Process a mutable copy of the simulator
original_mutability = self._mutability()
sim = self.copy().mutable(validate=True)

# Helper to get callbacks from the handler
get_cb = (
lambda h, cb_name: getattr(h, cb_name)
if h is not None and hasattr(h, cb_name)
else None
)

# Get the callbacks
configure_cb: Optional[simulator_callbacks.ConfigureCallbackSignature] = get_cb(
h=callback_handler, cb_name="configure_cb"
)
pre_step_cb: Optional[simulator_callbacks.PreStepCallbackSignature] = get_cb(
h=callback_handler, cb_name="pre_step_cb"
)
post_step_cb: Optional[simulator_callbacks.PostStepCallbackSignature] = get_cb(
h=callback_handler, cb_name="post_step_cb"
)

# Callback: configuration
sim = configure_cb(sim) if configure_cb is not None else sim

# Initialize the carry
Carry = Tuple[JaxSim, simulator_callbacks.CallbackHandler]
carry_init: Carry = (sim, callback_handler)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, jtp.PyTree]:

sim, callback_handler = carry

# Make sure to pass a mutable version of the simulator to the callbacks
sim = sim.mutable(validate=True)

# Callback: pre-step
# TODO: should we allow also producing a pre-step output?
sim = pre_step_cb(sim) if pre_step_cb is not None else sim

# Integrate all models
step_data = sim.step(clear_inputs=clear_inputs)

# Callback: post-step
sim, out_post_step = (
post_step_cb(sim, step_data)
if post_step_cb is not None
else (sim, None)
)

# Pack the carry
carry = (sim, callback_handler)

return carry, out_post_step

# Integrate over the given horizon
(sim, callback_handler), out_cb_horizon = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=horizon_steps
)

# Enforce original mutability of the entire simulator
sim._set_mutability(original_mutability)

return (
sim
if callback_handler is None
else (sim, (callback_handler, out_cb_horizon))
)
55 changes: 55 additions & 0 deletions src/jaxsim/simulation/simulator_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import abc
from typing import Callable, Dict, Tuple

import jaxsim.typing as jtp
from jaxsim.high_level.model import StepData
from jaxsim.simulation.simulator import JaxSim

ConfigureCallbackSignature = Callable[[JaxSim], JaxSim]
PreStepCallbackSignature = Callable[[JaxSim], JaxSim]
PostStepCallbackSignature = Callable[
[JaxSim, Dict[str, StepData]], Tuple[JaxSim, jtp.PyTree]
]


class SimulatorCallback(abc.ABC):
pass


class ConfigureCallback(SimulatorCallback):
@property
def configure_cb(self) -> ConfigureCallbackSignature:

return lambda sim: self.configure(sim=sim)

@abc.abstractmethod
def configure(self, sim: JaxSim) -> JaxSim:
pass


class PreStepCallback(SimulatorCallback):
@property
def pre_step_cb(self) -> PreStepCallbackSignature:

return lambda sim: self.pre_step(sim=sim)

@abc.abstractmethod
def pre_step(self, sim: JaxSim) -> JaxSim:
pass


class PostStepCallback(SimulatorCallback):
@property
def post_step_cb(self) -> PostStepCallbackSignature:

return lambda sim, step_data: self.post_step(sim=sim, step_data=step_data)

@abc.abstractmethod
def post_step(
self, sim: JaxSim, step_data: Dict[str, StepData]
) -> Tuple[JaxSim, jtp.PyTree]:
pass


class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback):
pass