In [None]:
from IPython.display import HTML, Image

try:
  import brax
except ImportError:
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.io import html
# import timeit

In [None]:
import jax
from brax.jumpy import _in_jit, Callable,Any


def cond(pred, true_fun: Callable, false_fun: Callable, *operands: Any):
    """Apply ``true_fun`` or ``false_fun`` to ``operands`` depending on ``pred``.

    When ``_in_jit()`` reports that we are inside a jit trace the call is
    delegated to ``jax.lax.cond``; otherwise the branch is chosen with an
    ordinary Python conditional.

    Args:
        pred: Condition selecting the branch; a truthy value picks ``true_fun``.
        true_fun: Callable invoked as ``true_fun(*operands)`` when ``pred`` holds.
        false_fun: Callable invoked as ``false_fun(*operands)`` otherwise.
        *operands: Positional arguments forwarded to the selected branch.

    Returns:
        The return value of the selected branch.
    """
    if _in_jit():
        return jax.lax.cond(pred, true_fun, false_fun, *operands)
    # Unpack the operands so the eager path calls the branches exactly like
    # the jax.lax.cond path does; passing ``operands`` itself would hand each
    # branch a single tuple argument.
    chosen = true_fun if pred else false_fun
    return chosen(*operands)

In [None]:
from brax.envs import State, Wrapper
import brax.jumpy as jp
class RandomizedAutoResetWrapperNaive(Wrapper):
    """Auto-reset wrapper that draws a new reset state on every single step.

    ``self.reset`` is called unconditionally after each environment step, using
    the key found in ``state.info['rng']``; the result is only kept for entries
    whose ``done`` flag is set. Simple, but wasteful.
    """

    def step(self, state: State, action: jp.ndarray) -> State:
        """Step the wrapped env, swapping in a fresh reset where ``done`` is set.

        Args:
            state: State to advance. Its ``info`` dict is updated in place when
                it carries a ``'steps'`` entry.
            action: Action forwarded to the wrapped environment.

        Returns:
            The stepped state with ``qp`` and ``obs`` taken from the reset state
            wherever the stepped ``done`` flag is set.
        """
        info = state.info
        if 'steps' in info:
            # Entries that were done on the previous step restart their counter.
            elapsed = info['steps']
            info.update(steps=jp.where(state.done, jp.zeros_like(elapsed), elapsed))

        # Clear the done flag before stepping, then step the wrapped env.
        stepped = self.env.step(state.replace(done=jp.zeros_like(state.done)), action)
        fresh = self.reset(stepped.info['rng'])

        def pick(reset_leaf, stepped_leaf):
            """Choose ``reset_leaf`` where the stepped state is done."""
            mask = stepped.done
            if mask.shape:
                # Non-scalar done: reshape to [leading, 1, ..., 1] so it
                # broadcasts against the leaf.
                trailing_ones = [1] * (len(reset_leaf.shape) - 1)
                mask = jp.reshape(mask, [reset_leaf.shape[0]] + trailing_ones)  # type: ignore
            return jp.where(mask, reset_leaf, stepped_leaf)

        return stepped.replace(
            qp=jp.tree_map(pick, fresh.qp, stepped.qp),
            obs=pick(fresh.obs, stepped.obs),
        )


class RandomizedAutoResetWrapperOnTerminal(Wrapper):
    """Auto-reset wrapper that only resamples when some entry is done.

    ``self.reset`` is invoked (through ``cond``) only when ``done.any()`` is
    true after the step; when it is invoked, it still runs on the whole
    ``state.info['rng']`` value rather than just the finished entries.
    """

    def step(self, state: State, action: jp.ndarray) -> State:
        """Step the wrapped env, swapping in a fresh reset where ``done`` is set.

        Args:
            state: State to advance. Its ``info`` dict is updated in place when
                it carries a ``'steps'`` entry.
            action: Action forwarded to the wrapped environment.

        Returns:
            The stepped state with ``qp`` and ``obs`` taken from the reset state
            wherever the stepped ``done`` flag is set.
        """
        info = state.info
        if 'steps' in info:
            # Entries that were done on the previous step restart their counter.
            elapsed = info['steps']
            info.update(steps=jp.where(state.done, jp.zeros_like(elapsed), elapsed))

        # Clear the done flag before stepping, then step the wrapped env.
        stepped = self.env.step(state.replace(done=jp.zeros_like(state.done)), action)

        # Resample only if something finished; otherwise reuse the stepped
        # state so the selection below leaves it unchanged.
        fresh = cond(stepped.done.any(), self.reset, lambda rng: stepped, stepped.info['rng'])

        def pick(reset_leaf, stepped_leaf):
            """Choose ``reset_leaf`` where the stepped state is done."""
            mask = stepped.done
            if mask.shape:
                # Non-scalar done: reshape to [leading, 1, ..., 1] so it
                # broadcasts against the leaf.
                trailing_ones = [1] * (len(reset_leaf.shape) - 1)
                mask = jp.reshape(mask, [reset_leaf.shape[0]] + trailing_ones)  # type: ignore
            return jp.where(mask, reset_leaf, stepped_leaf)

        return stepped.replace(
            qp=jp.tree_map(pick, fresh.qp, stepped.qp),
            obs=pick(fresh.obs, stepped.obs),
        )

In [None]:
from brax.envs.wrappers import EpisodeWrapper, AutoResetWrapper, VmapWrapper
import time

# Benchmark configuration: environment, batch size, episode length and the
# number of steps taken per wrapper.
ENV_NAME = 'fetch' # Create basic fetch environment (one of the few with 'rng' in its state.info)
NUM_ENVS = 2048
EPISODE_LENGTH = 40
T = 1000

# Vmap wrapper requires us to pass key of batch size (vs VectorWrapper)
BASE_KEY = jax.random.PRNGKey(0)
MULTI_KEY = jax.random.split(BASE_KEY, NUM_ENVS)

# The registry is indexed directly so the env can be wrapped by hand here.
# NOTE(review): the trailing 1 is presumably EpisodeWrapper's action_repeat - confirm.
base_env = VmapWrapper(EpisodeWrapper(envs._envs[ENV_NAME](), EPISODE_LENGTH, 1))
# Size the action batch from the environment rather than hard-coding 10, so
# changing ENV_NAME does not silently produce a wrongly shaped action.
action = jax.numpy.ones((NUM_ENVS, base_env.action_size))  # Action on GPU/TPU to save transfer

In [None]:
WARMUP_STEPS = 10  # leading steps excluded from the average step time


def _block_until_ready(tree):
    """Wait for every array leaf of ``tree`` to finish computing, then return it.

    JAX dispatches work asynchronously, so without this wait the timestamps
    below would record how long it took to enqueue a step, not to run it.
    """
    for leaf in jax.tree_util.tree_leaves(tree):
        if hasattr(leaf, 'block_until_ready'):
            leaf.block_until_ready()
    return tree


for reset_wrapper_class in [AutoResetWrapper, RandomizedAutoResetWrapperNaive, RandomizedAutoResetWrapperOnTerminal]:
    name = reset_wrapper_class.__name__
    print(f'testing runtime for {name}')
    e = reset_wrapper_class(base_env)
    # Build the jitted callables once instead of re-wrapping on every step.
    reset_fn = jax.jit(e.reset)
    step_fn = jax.jit(e.step)
    times = [time.time()]
    state = _block_until_ready(reset_fn(MULTI_KEY))
    for _ in range(T):
        state = _block_until_ready(step_fn(state, action))
        times.append(time.time())
    # The first interval covers the reset plus the first step, i.e. compilation.
    print(f'jit times (rough){times[1] - times[0]}')
    # Difference consecutive timestamps in plain Python: lists cannot be
    # subtracted directly, and keeping float64 avoids losing precision on
    # epoch-sized values.
    step_durations = [
        later - earlier
        for earlier, later in zip(times[WARMUP_STEPS:-1], times[WARMUP_STEPS + 1:])
    ]
    avg_step_time = sum(step_durations) / len(step_durations) if step_durations else float('nan')
    print(f'avg step time{avg_step_time}')
