In [None]:
from IPython.display import HTML, Image

# Import brax; if it is not installed, pip-install it from GitHub and retry.
try:
  import brax
except ImportError:
  from IPython.display import clear_output
  # NOTE(review): this installs the unpinned `main` branch, so the brax
  # version (and therefore the benchmark results) can differ between runs.
  !pip install git+https://github.com/google/brax.git@main
  # Clear the cell output produced by the install step.
  clear_output()
  import brax

from brax import envs
from brax.io import html
# import timeit

In [None]:
from typing import Sequence, Union

import jax
from brax.jumpy import _in_jit, X, Optional, Tuple, Callable, onp, Any, _which_np, jnp, ndarray


def while_loop(cond_fun: Callable[[X], Any],
               body_fun: Callable[[X], X],
               init_val: X) -> X:
    """Apply ``body_fun`` to a carried value for as long as ``cond_fun`` holds.

    When ``_in_jit()`` is truthy the loop is delegated to
    ``jax.lax.while_loop``; otherwise a plain Python loop is executed.

    Reference semantics (both paths)::

      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

    Args:
      cond_fun: predicate evaluated on the carried value before each iteration.
      body_fun: maps the carried value to its next value.
      init_val: the initial carried value.

    Returns:
      The carried value at the first point where ``cond_fun`` is falsy.
    """
    if not _in_jit():
        carry = init_val
        while True:
            # Test before every body call, so a falsy predicate on the
            # initial value returns it untouched.
            if not cond_fun(carry):
                return carry
            carry = body_fun(carry)
    return jax.lax.while_loop(cond_fun, body_fun, init_val)


def index_add(x: ndarray, idx: ndarray, y: ndarray) -> ndarray:
    """Return the result of ``x[idx] += y`` without modifying ``x``.

    Args:
      x: array to add into.
      idx: index (or indices) selecting the entries of ``x`` to update.
      y: value(s) added at ``idx``.

    Returns:
      A new array; the input ``x`` is left unchanged.
    """
    if _which_np(x) is not jnp:
        # NumPy path: work on a copy so the caller's array is never mutated.
        updated = onp.copy(x)
        updated[idx] += y
        return updated
    # JAX path: functional indexed update.
    return x.at[idx].add(y)


def meshgrid(*xi, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> Sequence[ndarray]:
    """Build coordinate grids from coordinate vectors.

    The backend (``jnp`` or ``onp``) is selected by passing *all* inputs to
    ``_which_np``, the same way ``atleast_1d`` does, instead of inspecting
    only the first input. This also makes a call with no inputs valid
    rather than failing on ``xi[0]``.

    Args:
      *xi: coordinate vectors.
      copy: forwarded unchanged to the backend ``meshgrid``.
      sparse: forwarded unchanged to the backend ``meshgrid``.
      indexing: forwarded unchanged to the backend ``meshgrid``
        (``'xy'`` by default).

    Returns:
      Whatever the selected backend's ``meshgrid`` returns: a sequence of
      grids, one per input.
    """
    np_module = _which_np(*xi)
    return np_module.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)


def randint(rng: ndarray, shape: Tuple[int, ...] = (),
            low: Optional[int] = 0, high: Optional[int] = 1) -> ndarray:
    """Draw random integers from the half-open interval ``[low, high)``.

    Args:
      rng: a JAX PRNG key on the ``jnp`` path; on the NumPy path it is used
        as the seed for ``onp.random.default_rng``.
      shape: shape of the returned sample.
      low: inclusive lower bound.
      high: exclusive upper bound.

    Returns:
      Integer samples with the requested shape.
    """
    if _which_np(rng) is not jnp:
        generator = onp.random.default_rng(rng)
        return generator.integers(low=low, high=high, size=shape)
    return jax.random.randint(rng, shape=shape, minval=low, maxval=high)


def choice(rng: ndarray, a: Union[int, Any], shape: Tuple[int, ...] = (),
           replace: bool = True, p: Optional[Any] = None, axis: int = 0) -> ndarray:
    """Draw a random sample of the given shape from ``a``.

    All sampling arguments are forwarded unchanged to ``jax.random.choice``
    (``jnp`` path) or to ``Generator.choice`` of a NumPy generator seeded
    with ``rng`` (NumPy path).

    Args:
      rng: a JAX PRNG key on the ``jnp`` path; on the NumPy path it is used
        as the seed for ``onp.random.default_rng``.
      a: population to sample from (an array, or an int as accepted by the
        underlying ``choice`` routine).
      shape: shape of the returned sample.
      replace: whether to sample with replacement.
      p: optional probabilities associated with the entries of ``a``.
      axis: axis of ``a`` along which the selection is made.

    Returns:
      The sampled entries.
    """
    if _which_np(rng) is not jnp:
        generator = onp.random.default_rng(rng)
        return generator.choice(a, size=shape, replace=replace, p=p, axis=axis)
    return jax.random.choice(rng, a, shape=shape, replace=replace, p=p, axis=axis)


def atleast_1d(*arys) -> ndarray:
    """Promote each input to at least one dimension.

    The array module is chosen by ``_which_np`` from all inputs, and its
    ``atleast_1d`` is applied to them.
    """
    np_module = _which_np(*arys)
    return np_module.atleast_1d(*arys)


def cond(pred, true_fun: Callable, false_fun: Callable, *operands: Any):
    """Apply ``true_fun`` or ``false_fun`` to ``operands`` depending on ``pred``.

    When ``_in_jit()`` is truthy the call is delegated to ``jax.lax.cond``;
    otherwise the branch is selected with ordinary Python truthiness. Both
    paths call the chosen branch as ``fun(*operands)``.

    Args:
      pred: condition selecting the branch.
      true_fun: called when ``pred`` is truthy.
      false_fun: called when ``pred`` is falsy.
      *operands: positional arguments passed to the selected branch.

    Returns:
      The return value of the selected branch.
    """
    if _in_jit():
        return jax.lax.cond(pred, true_fun, false_fun, *operands)
    # Unpack the operands so the eager path uses the same calling convention
    # as jax.lax.cond; passing the packed tuple would hand the branch a
    # single tuple argument.
    branch = true_fun if pred else false_fun
    return branch(*operands)

In [None]:
from brax.envs import State, Wrapper
import brax.jumpy as jp
class RandomizedAutoResetWrapperNaive(Wrapper):
    """Auto-resets finished Brax envs using a freshly sampled initial state.

    ``self.reset`` is called on every single step, whether or not any
    environment is done, which makes this variant inefficient.
    """
    def step(self, state: State, action: jp.ndarray) -> State:
        """Step the wrapped env, then swap in reset ``qp``/``obs`` where done.

        Args:
          state: current state; its ``info`` dict is updated in place when
            it contains a ``'steps'`` entry.
          action: action forwarded to the wrapped environment.

        Returns:
          The post-step state with ``qp`` and ``obs`` taken from a fresh
          reset wherever the post-step ``done`` flag is set.
        """
        info = state.info
        if 'steps' in info:
            # Zero the step counter of environments that finished last step.
            step_count = info['steps']
            info.update(steps=jp.where(state.done, jp.zeros_like(step_count), step_count))
        cleared = state.replace(done=jp.zeros_like(state.done))
        stepped = self.env.step(cleared, action)
        # NOTE(review): the reset key is read from info['rng'] and is not
        # split here; confirm the environment advances it, otherwise
        # successive resets reuse the same key.
        fresh = self.reset(stepped.info['rng'])
        done_flags = stepped.done

        def pick(reset_leaf, live_leaf):
            mask = done_flags
            if mask.shape:
                # Add trailing singleton axes so the mask broadcasts against the leaf.
                trailing = [1] * (len(reset_leaf.shape) - 1)
                mask = jp.reshape(mask, [reset_leaf.shape[0]] + trailing)  # type: ignore
            return jp.where(mask, reset_leaf, live_leaf)

        merged_qp = jp.tree_map(pick, fresh.qp, stepped.qp)
        merged_obs = pick(fresh.obs, stepped.obs)
        return stepped.replace(qp=merged_qp, obs=merged_obs)


class RandomizedAutoResetWrapperOnTerminal(Wrapper):
    """Auto-resets finished Brax envs, sampling only when something is done.

    ``self.reset`` is invoked only on steps where at least one environment
    reports done. When it is invoked it still resamples every environment;
    the fresh values are then kept only where ``done`` is set.
    """
    def step(self, state: State, action: jp.ndarray) -> State:
        """Step the wrapped env, then swap in reset ``qp``/``obs`` where done.

        Args:
          state: current state; its ``info`` dict is updated in place when
            it contains a ``'steps'`` entry.
          action: action forwarded to the wrapped environment.

        Returns:
          The post-step state with ``qp`` and ``obs`` taken from a fresh
          reset wherever the post-step ``done`` flag is set.
        """
        info = state.info
        if 'steps' in info:
            # Zero the step counter of environments that finished last step.
            step_count = info['steps']
            info.update(steps=jp.where(state.done, jp.zeros_like(step_count), step_count))
        cleared = state.replace(done=jp.zeros_like(state.done))
        stepped = self.env.step(cleared, action)
        done_flags = stepped.done
        # Only pay for a reset when some environment finished; otherwise the
        # "reset" candidate is simply the post-step state itself.
        candidate = cond(done_flags.any(), self.reset, lambda _unused_key: stepped, stepped.info['rng'])

        def pick(reset_leaf, live_leaf):
            mask = done_flags
            if mask.shape:
                # Add trailing singleton axes so the mask broadcasts against the leaf.
                trailing = [1] * (len(reset_leaf.shape) - 1)
                mask = jp.reshape(mask, [reset_leaf.shape[0]] + trailing)  # type: ignore
            return jp.where(mask, reset_leaf, live_leaf)

        merged_qp = jp.tree_map(pick, candidate.qp, stepped.qp)
        merged_obs = pick(candidate.obs, stepped.obs)
        return stepped.replace(qp=merged_qp, obs=merged_obs)

In [None]:
from brax.envs.wrappers import EpisodeWrapper, AutoResetWrapper, VmapWrapper
from time import time

ENV_NAME = 'fetch'  # Basic fetch environment (one of the few with 'rng' in its state.info)
NUM_ENVS = 2048  # number of PRNG keys, i.e. environments stepped in one batch
EPISODE_LENGTH = 1000  # passed to EpisodeWrapper
T = 10000  # number of env steps executed per wrapper in the benchmark

# Vmap wrapper requires us to pass key of batch size (vs VectorWrapper)
BASE_KEY = jax.random.PRNGKey(0)
MULTI_KEY = jax.random.split(BASE_KEY, NUM_ENVS)

# NOTE(review): the trailing 1 is presumably EpisodeWrapper's action_repeat — confirm.
base_env = VmapWrapper(EpisodeWrapper(envs._envs[ENV_NAME](), EPISODE_LENGTH, 1))
# The vmapped step maps over the leading axis of both state and action, so the
# action needs one row per environment; the width comes from the env itself
# rather than a hard-coded constant.
action = jax.numpy.ones((NUM_ENVS, base_env.action_size))

In [None]:
# Time T jitted steps for each auto-reset strategy.
for reset_wrapper_class in [AutoResetWrapper, RandomizedAutoResetWrapperNaive, RandomizedAutoResetWrapperOnTerminal]:
    name = reset_wrapper_class.__name__
    e = reset_wrapper_class(base_env)
    # Build the jitted callables once per wrapper instead of once per step.
    jit_reset = jax.jit(e.reset)
    jit_step = jax.jit(e.step)
    state = jit_reset(MULTI_KEY)
    # Warm-up step: triggers tracing/compilation so it is excluded from the timing.
    state = jit_step(state, action)
    state.obs.block_until_ready()
    # `%%timeit` is a cell magic and cannot be used inside a loop body, so the
    # loop is timed manually.
    start = time()
    for _ in range(T):
        state = jit_step(state, action)
    # JAX dispatches asynchronously; wait for the last result before stopping the clock.
    state.obs.block_until_ready()
    elapsed = time() - start
    print(f'{name}: {elapsed:.3f} s for {T} steps ({1e3 * elapsed / T:.3f} ms/step)')
