In [None]:
#@title

from IPython.display import HTML, Image

try:
  import brax
except ImportError:
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.io import html
import torch
# Initialise torch (on CUDA when available) before JAX is used; fall back to the CPU so
# this setup cell does not crash on machines without a CUDA device.
v = torch.ones(1, device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from typing import Sequence, Union

import jax
from brax.jumpy import _in_jit, X, Optional, Tuple, Callable, onp, Any, _which_np, jnp, ndarray

def while_loop(cond_fun: Callable[[X], Any],
               body_fun: Callable[[X], X],
               init_val: X) -> X:
    """Keep applying ``body_fun`` to a carried value for as long as ``cond_fun`` is True.

    In Haskell-style notation the signature is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    and the result matches this plain Python loop::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Args:
        cond_fun: Predicate evaluated on the carried value before every iteration.
        body_fun: Maps the carried value to its next value.
        init_val: Value the loop starts from.

    Returns:
        The carried value at the first point where ``cond_fun`` is falsy.
    """
    if _in_jit():
        # When ``_in_jit()`` reports true, delegate to the JAX loop primitive.
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    # Otherwise run an ordinary Python loop.
    carry = init_val
    while cond_fun(carry):
        carry = body_fun(carry)
    return carry


def index_add(x: ndarray, idx: ndarray, y: ndarray) -> ndarray:
    """Functional form of ``x[idx] += y``: the input ``x`` is left untouched.

    Args:
        x: Array to add into.
        idx: Index (anything accepted by ``x[idx]``) selecting the entries to update.
        y: Value(s) added at ``idx``.

    Returns:
        A new array equal to ``x`` with ``y`` added at ``idx``.
    """
    if _which_np(x) is not jnp:
        # numpy path: work on a copy so the caller's array is not modified.
        updated = onp.copy(x)
        updated[idx] += y
        return updated
    # jax path: use the functional ``.at[...]`` update syntax.
    return x.at[idx].add(y)


def meshgrid(*xi, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> Sequence[ndarray]:
    """Build coordinate grids from 1-D coordinate vectors, dispatching to numpy or jax.numpy.

    The backend is selected from *all* inputs (as ``atleast_1d`` does) rather than only the
    first one, so a call with no inputs no longer raises ``IndexError`` and a jax array in any
    position selects the jax implementation.

    Args:
        *xi: Coordinate vectors.
        copy: Forwarded to the backend ``meshgrid``.
        sparse: Forwarded to the backend ``meshgrid``.
        indexing: 'xy' (Cartesian) or 'ij' (matrix) indexing, forwarded to the backend.

    Returns:
        The sequence of coordinate arrays produced by the backend ``meshgrid`` (one per input).
    """
    backend = _which_np(*xi)
    return backend.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)


def randint(rng: ndarray, shape: Tuple[int, ...] = (),
            low: Optional[int] = 0, high: Optional[int] = 1) -> ndarray:
    """Draw integers from the half-open interval [low, high).

    Args:
        rng: Random state. Passed as the key to ``jax.random.randint`` on the jax path, or as
            the seed of ``numpy.random.default_rng`` on the numpy path.
        shape: Shape of the returned sample.
        low: Inclusive lower bound.
        high: Exclusive upper bound.

    Returns:
        Array of the requested shape holding the sampled integers.
    """
    if _which_np(rng) is not jnp:
        generator = onp.random.default_rng(rng)
        return generator.integers(low=low, high=high, size=shape)
    return jax.random.randint(rng, shape=shape, minval=low, maxval=high)


def choice(rng: ndarray, a: Union[int, Any], shape: Tuple[int, ...] = (),
           replace: bool = True, p: Optional[Any] = None, axis: int = 0) -> ndarray:
    """Randomly sample entries of ``a`` with the given output shape.

    Args:
        rng: Random state. Passed as the key to ``jax.random.choice`` on the jax path, or as
            the seed of ``numpy.random.default_rng`` on the numpy path.
        a: Population to sample from (an array, or an int as accepted by the backend).
        shape: Shape of the returned sample.
        replace: Whether sampling is done with replacement.
        p: Optional probabilities associated with the entries of ``a``.
        axis: Axis of ``a`` along which entries are selected.

    Returns:
        The sampled entries.
    """
    if _which_np(rng) is not jnp:
        generator = onp.random.default_rng(rng)
        return generator.choice(a, size=shape, replace=replace, p=p, axis=axis)
    return jax.random.choice(rng, a, shape=shape, replace=replace, p=p, axis=axis)


def atleast_1d(*arys) -> ndarray:
    """Apply ``atleast_1d`` from whichever array module ``_which_np`` selects for the inputs."""
    backend = _which_np(*arys)
    return backend.atleast_1d(*arys)

import brax
import brax.jumpy as jp
from brax.physics.config_pb2 import Body


def add_box_wall_to_body(body: Body, from_xy: jp.ndarray, to_xy: jp.ndarray, half_height: float = 0.5, wall_width: float = 0.25) -> None:
    """Add a box wall collider to a body.

    The box is centred on the midpoint of the segment ``from_xy -> to_xy`` and rotated about
    the z axis so that its local x axis lies along the segment.

    Args:
        body: Body contained from a config_pb2 object. Assume body handles z-axis
        from_xy: xy coordinates of start of wall (relative to body)
        to_xy: xy coordinates of end of wall (relative to body)
        half_height: Half height of box (box halfsize along z)
        wall_width: Half thickness of box (box halfsize along its local y)

    Returns:
        Nothing. Body is modified in place
    """
    vector = to_xy - from_xy  # Used for angle and length
    length = jp.norm(vector)
    midpoint = (from_xy + to_xy) / 2  # xy midpoint (position)
    # Angle to the +x axis from the dot product with the x unit vector
    # (a dot b = |a| * |b| * cos theta, and [1, 0] dot v = v_x), converted to degrees.
    unsigned_angle = jp.arccos(vector[0] / length) * 180 / jp.pi
    # arccos only yields angles in [0, 180], which mirrors any wall heading towards -y.
    # Restore the sign so diagonal walls are oriented correctly.
    z_rotation = jp.where(vector[1] < 0, -unsigned_angle, unsigned_angle)
    coll = body.colliders.add()  # Add collider (for position and rotation)
    coll.position.x, coll.position.y = float(midpoint[0]), float(midpoint[1])
    coll.rotation.z = float(z_rotation)  # W.r.t. unit x-vector
    box = coll.box  # Actual box object
    box.halfsize.x, box.halfsize.y, box.halfsize.z = length / 2, wall_width, half_height


def add_capsule_wall_to_body(body: Body, from_xy: jp.ndarray, to_xy: jp.ndarray, radius: float = 0.5, include_radius: bool = False) -> None:
    """Attach a capsule-shaped wall collider to ``body``.

    Only axis-aligned walls are supported: the two end points must share either their x or
    their y coordinate (enforced with an assert).

    Args:
        body: Body contained from a config_pb2 object
        from_xy: xy coordinates of start of capsule (relative to body)
        to_xy: xy coordinates of end of capsule (relative to body)
        radius: Radius of capsule
        include_radius: If true, ``radius`` is subtracted from the end-to-end distance when
            setting the capsule length.
            NOTE(review): only one radius is subtracted, not one per end - confirm intended.
    Returns:
        Nothing. Body is modified in place
    """
    span = jp.norm(from_xy - to_xy)  # Distance between the two end points
    centre = (from_xy + to_xy) / 2  # Capsule xy position
    shrink = include_radius * radius  # 0 when include_radius is False
    # Same x at both ends -> wall runs along y ("vertical"); otherwise it must run along x.
    vertical = (from_xy[0] == to_xy[0])
    assert vertical or (from_xy[1] == to_xy[1])
    collider = body.colliders.add()  # Collider carries position and rotation
    collider.position.x = centre[0]
    collider.position.y = centre[1]
    # Walls along y are rotated 90 degrees about x, walls along x 90 degrees about y.
    # Presumably this lays the capsule axis down along the wall - TODO confirm against brax.
    if vertical:
        collider.rotation.x = 90
    else:
        collider.rotation.y = 90
    capsule = collider.capsule  # Actual capsule object
    capsule.radius = radius
    capsule.length = span - shrink


def draw_arena(cfg: brax.Config, cage_x: float, cage_y: float, capsule_radius_or_box_half_height: float = 0.5, arena_name: str = "Arena", use_boxes: bool = True) -> None:
    """Add frozen 4-sided arena using capsule or box walls to enforce bounds of play

    Arranged such that cage_x and cage_y are the bounds of the inner area (i.e., at radius edge of capsule facing inward)
    Defines a rectangle from [-cage_x - rad, -cage_y - rad] to [cage_x + rad, cage_y + rad]. Determine additional space needs elsewhere!
    Args:
        cfg: brax Config object
        cage_x: Max x size
        cage_y: Max y size
        capsule_radius_or_box_half_height: thickness of wall (capsule) or half height of wall (box). >= 0.5 recommended.
            For boxes, half of this value is used both as the corner offset and as the box half thickness
        arena_name: Name given to arena (used to include collide pairs later)
        use_boxes: If true, build the walls from box colliders, otherwise from capsule colliders
    Returns:
        Nothing, in-place
    """
    x, y, r = cage_x, cage_y, capsule_radius_or_box_half_height
    arena = cfg.bodies.add(name=arena_name, mass=1.)  # 1 frozen body, many colliders
    arena.frozen.all = True
    # System.default_qp() reads a single defaults entry (index 0 unless told otherwise), so
    # attach the arena to the first existing entry; appending a new entry would leave the
    # arena height silently ignored whenever the config already has defaults.
    defaults = cfg.defaults[0] if len(cfg.defaults) > 0 else cfg.defaults.add()
    aqp = defaults.qps.add(name=arena_name)  # Default height such that walls just touch the ground
    aqp.pos.z = capsule_radius_or_box_half_height
    if use_boxes:
        r /= 2  # Wall halfsize, expand coordinates so that we *enclose* this space
    corners = jp.array([[x + r, y + r], [x + r, -y - r], [-x - r, -y - r], [-x - r, y + r]])
    num_corners = len(corners)
    for i in range(num_corners):
        start, end = corners[i], corners[(i + 1) % num_corners]  # Wrap around to close the loop
        if use_boxes:
            add_box_wall_to_body(arena, start, end, capsule_radius_or_box_half_height, r)
        else:
            add_capsule_wall_to_body(arena, start, end, r, True)



def draw_t_maze(cfg: brax.Config, t_x: float, t_y: float, hallway_width: float = 2., capsule_radius_or_box_half_height: float = 0.5, arena_name: str = "Arena", use_boxes: bool = True) -> None:
    """Draw a T (like in TMaze or heaven hell)

    Adds one frozen body whose wall colliders trace the outline of a T. Wall centre lines are
    offset outward by ``capsule_radius_or_box_half_height`` from the coordinates below.
    The top bar spans x in [-t_x, t_x] and y in [t_y - hallway_width, t_y]; the stem spans
    y in [0, t_y - hallway_width] and x in [-hallway_width, hallway_width].
    NOTE(review): the stem is therefore 2 * hallway_width wide while the bar is hallway_width
    tall - confirm this is intended.

    Args:
        cfg: brax Config object
        t_x: Rightmost x coordinate of top of T
        t_y: Top of T y coordinate
        hallway_width: Width parameter of the hallways within T
        capsule_radius_or_box_half_height: thickness of wall. >=0.5 recommended
        arena_name: Name given to arena (used to include collide pairs later)
        use_boxes: If true, build the walls from box colliders, otherwise from capsule colliders
    Returns:
        Nothing, in-place
    """
    r = capsule_radius_or_box_half_height
    arena = cfg.bodies.add(name=arena_name, mass=1.)  # 1 frozen body, many colliders
    arena.frozen.all = True
    # System.default_qp() reads a single defaults entry (index 0 unless told otherwise), so
    # attach the arena to the first existing entry; appending a new entry would leave the
    # arena height silently ignored whenever the config already has defaults.
    defaults = cfg.defaults[0] if len(cfg.defaults) > 0 else cfg.defaults.add()
    aqp = defaults.qps.add(name=arena_name)  # Default height such that walls just touch the ground
    aqp.pos.z = capsule_radius_or_box_half_height
    # Top-left point, clockwise around T
    corners = jp.array([
        [-t_x - r, t_y + r],
        [t_x + r, t_y + r],
        [t_x + r, t_y - hallway_width - r],
        [hallway_width + r, t_y - hallway_width - r],
        [hallway_width + r, -r],
        [-hallway_width - r, -r],
        [-hallway_width - r, t_y - hallway_width - r],
        [-t_x - r, t_y - hallway_width - r]
    ])
    num_corners = corners.shape[0]
    for i in range(num_corners):
        start, end = corners[i], corners[(i + 1) % num_corners]  # Wrap around to close the outline
        if use_boxes:
            add_box_wall_to_body(arena, start, end, capsule_radius_or_box_half_height, r)
        else:
            add_capsule_wall_to_body(arena, start, end, r, True)


In [None]:
"""Trains an ant to go to heaven by following the advice of a priest"""
import brax
from brax import jumpy as jp
from brax.envs import env
from google.protobuf import text_format


def extend_ant_cfg(cfg: str = brax.envs.ant._SYSTEM_CONFIG, hhp: jp.ndarray = jp.array([[-5.25, 7.], [5.25, 7.], [0., 7.]]), hallway_width: float = 2) -> brax.Config:
    """Parse an ant config and extend it with priest, heaven, hell and T-maze walls.

    Args:
        cfg: Text-format brax config to start from (defaults to the brax ant config)
        hhp: Positions, one row per landmark with x and y in the first two columns. The last
            row is used as the priest position; the column maxima size the T-maze
        hallway_width: Hallway width passed on to ``draw_t_maze``

    Returns:
        The extended brax Config
    """
    cfg = text_format.Parse(cfg, brax.Config())  # Get ant config
    ant_body_names = [b.name for b in cfg.bodies if b.name != 'Ground']
    # Priest, heaven ('Target') and hell are identical frozen spheres. Keep this creation
    # order: the environment relies on the resulting body indices.
    for landmark_name in ('Priest', 'Target', 'Hell'):
        landmark = cfg.bodies.add(name=landmark_name, mass=1.)
        landmark.frozen.all = True
        landmark.colliders.add().sphere.radius = 0.5
    # System.default_qp() reads a single defaults entry (index 0 unless told otherwise), so
    # reuse the first entry if the parsed config already has one instead of appending another.
    defaults = cfg.defaults[0] if len(cfg.defaults) > 0 else cfg.defaults.add()
    aqp = defaults.qps.add(name='Priest')  # Default priest position, never changes
    aqp.pos.x, aqp.pos.y, aqp.pos.z = float(hhp[-1, 0]), float(hhp[-1, 1]), 1.
    # Add walls
    draw_t_maze(cfg, t_x=hhp[:, 0].max() + hallway_width / 2, t_y=hhp[:, 1].max() + hallway_width / 2, hallway_width=hallway_width)
    for body_name in ant_body_names:
        cfg.collide_include.add(first=body_name, second='Arena')
    # Need to match control frequency with Hai's. He uses 15 frame skip, timestep = 0.02, so 0.3 seconds between actions
    # Default is timestep = 0.05, substeps = 10
    # self.unwrapped.sys.config.dt *= action_repeat
    # self.unwrapped.sys.config.substeps *= action_repeat
    return cfg


class AntHeavenHellEnv(env.Env):
    """AntHeavenHell. Basically TMaze with partial observability

    Heaven ('Target') and hell are randomly assigned to the two ``heaven_hell`` positions at
    reset. The ant only observes which side heaven is on while it is within
    ``visible_radius`` of the priest. Reaching heaven gives +1, reaching hell gives -1 and
    "dying" (torso height below 0.2 or above 1.0) gives ``dying_cost``; any non-zero reward
    ends the episode.

    Args:
        heaven_hell: xy positions of heaven hell. By convention, both at same y, left + right
        priest_position: Position of priest. Typically at the top of T
        visible_radius: Radius within which ant can see priest. The same radius is used to
            decide whether the ant has reached heaven or hell
        dying_cost: Cost for death (undoable locomotion error)
        **kwargs: Accepted for interface compatibility; not used
    """
    def __init__(self,
                 heaven_hell: Sequence[Sequence[float]] = ((-5.25, 7.), (5.25, 7.)),
                 priest_position: Sequence[float] = (0, 7.),
                 visible_radius: float = 2.,
                 dying_cost: float = -2.,
                 **kwargs):
        # Preliminaries
        self.heaven_hell_xy = jp.array(heaven_hell)
        self.priest_pos = jp.array(priest_position)
        # Rows: the two heaven/hell positions, then the priest; a column of ones is appended as z
        xy = jp.concatenate((self.heaven_hell_xy, self.priest_pos[None, ...]), axis=0)
        self._hhp = jp.concatenate((xy, jp.ones((3, 1))), axis=1)
        self.visible_radius = visible_radius
        self.dying_cost = dying_cost
        cfg = extend_ant_cfg(hhp=self._hhp, hallway_width=2.)
        self.sys = brax.System(cfg)
        # Ant and target indexes
        self.target_idx = self.sys.body.index['Target']
        self.hell_idx = self.sys.body.index['Hell']
        self.priest_idx = self.sys.body.index['Priest']
        self.torso_idx = self.sys.body.index['$ Torso']
        # Every body from the torso up to (not including) the priest
        self.ant_indices = jp.arange(self.torso_idx, self.priest_idx)
        self.ant_l = self.ant_indices.shape[0]
        # (row, column) index grids selecting the x and y columns of every ant body
        self.ant_mg = tuple(meshgrid(self.ant_indices, jp.arange(0, 2)))
        self._init_ant_pos = jp.array([[-0.5, 0.5], [0.5, 1.5]])  # Low and high xy for ant position

    def reset(self, rng: jp.ndarray) -> env.State:
        """Sample a fresh initial state and return it with zeroed reward, done and metrics."""
        rng, qp = self.sample_init_qp(rng)
        info = self.sys.info(qp)
        obs = self._get_obs(qp, info, jp.float32(0))
        reward, done, zero = jp.zeros(3)
        # Declare every metric that step() writes so reset and step states share one structure
        metrics = {
            'heavens': zero,
            'hells': zero,
            'hits': zero
        }
        info = {'rng': rng}
        return env.State(qp, obs, reward, done, metrics, info)

    def sample_init_qp(self, rng: jp.ndarray):
        """Sample joint state, ant xy position and the heaven/hell assignment.

        Returns:
            Tuple of (new rng, initial QP).
        """
        rng, rng1, rng2, rng3, rng4 = jp.random_split(rng, 5)
        qpos = self.sys.default_angle() + jp.random_uniform(
            rng1, (self.sys.num_joint_dof,), -.1, .1)
        qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -.1, .1)
        # initial ant position
        ant_pos = jp.random_uniform(rng3, (2,), *self._init_ant_pos)  # Sample ant torso position
        # Set default qp with the sampled joints
        qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
        # Add ant xy to all ant part positions (otherwise they spring back hilariously)
        pos = index_add(qp.pos, self.ant_mg, ant_pos[..., None])
        # Pick heaven and hell positions. Use a dedicated key (rng4): reusing rng3 would
        # correlate the side heaven is on with the sampled ant position.
        target_pos, hell_pos = choice(rng4, self._hhp[:2], (2,), False)
        # Update heaven and hell positions
        pos = jp.index_update(pos, jp.stack([self.target_idx, self.hell_idx]), jp.stack([target_pos, hell_pos]))
        # Actually update qpos
        return rng, qp.replace(pos=pos)

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        """Run one timestep of the environment's dynamics."""
        qp, info = self.sys.step(state.qp, action)
        # "Death" and associated rewards
        dead = jp.where(qp.pos[self.torso_idx, 2] < 0.2, x=jp.float32(1), y=jp.float32(0))
        dead = jp.where(qp.pos[self.torso_idx, 2] > 1.0, x=jp.float32(1), y=dead)
        reward = jp.where(dead > 0, jp.float32(self.dying_cost), jp.float32(0))
        heaven_hell_priest = jp.stack([qp.pos[self.target_idx], qp.pos[self.hell_idx], qp.pos[self.priest_idx]])
        # Are we in range of heaven/hell (done+reward) or priest (extra observation)
        in_range = (jp.norm(heaven_hell_priest[:, :2] - qp.pos[self.torso_idx, :2], axis=-1) <= self.visible_radius)
        priest_in_range = in_range[-1]
        reward = jp.where(in_range[0], jp.float32(1), reward)  # +1 for heaven
        reward = jp.where(in_range[1], jp.float32(-1), reward)  # -1 for hell
        done = jp.where(reward != 0, jp.float32(1), jp.float32(0))  # Done if any reward
        # Get observation
        obs = self._get_obs(qp, info, priest_in_range)
        # Build a new metrics dict (no in-place mutation of the incoming state) and fill in
        # the counters declared by reset()
        metrics = dict(
            state.metrics,
            heavens=jp.where(in_range[0], jp.float32(1), jp.float32(0)),
            hells=jp.where(in_range[1], jp.float32(1), jp.float32(0)),
            hits=done)
        return state.replace(qp=qp, obs=obs, reward=reward, done=done, metrics=metrics)

    def _get_obs(self, qp: brax.QP, info: brax.Info, priest_in_range: jp.float32) -> jp.ndarray:
        """Observe ant body position and velocities, contacts and the heaven direction.

        Args:
            qp: Current system state
            info: Step info holding the contact impulses
            priest_in_range: Truthy (> 0) when the ant is close enough to the priest

        Returns:
            Flat observation vector
        """
        # some pre-processing to pull joint angles and velocities
        (joint_angle,), (joint_vel,) = self.sys.joints[0].angle_vel(qp)
        # 0 obs if not in range, else -1/1 for heaven in negative/positive x direction
        tgt_x = atleast_1d(qp.pos[self.target_idx][0])
        heaven_direction = jp.where(priest_in_range > 0, jp.sign(tgt_x), jp.zeros_like(tgt_x))

        # qpos:
        # XYZ of the torso (3,)
        # orientation of the torso as quaternion (4,)
        # joint angles (8,)
        qpos = [qp.pos[self.torso_idx], qp.rot[self.torso_idx], joint_angle]

        # qvel:
        # velocity of the torso (3,)
        # angular velocity of the torso (3,)
        # joint angle velocities (8,)
        qvel = [qp.vel[self.torso_idx], qp.ang[self.torso_idx], joint_vel]

        # external contact forces:
        # delta velocity (3,), delta ang (3,) per body in the system
        # NOTE(review): the stock ant has 10 bodies; this config adds Priest, Target, Hell
        # and Arena, so this part is presumably larger than in the stock ant env - confirm
        cfrc = [
            jp.clip(info.contact.vel, -1, 1),
            jp.clip(info.contact.ang, -1, 1)
        ]
        # flatten bottom dimension
        cfrc = [jp.reshape(x, x.shape[:-2] + (-1,)) for x in cfrc]

        # heaven direction (1,)
        return jp.concatenate(qpos + qvel + cfrc + [heaven_direction])

In [None]:
from IPython.display import display

NUM_ROLLOUT_STEPS = 200  # Number of random-action steps to simulate

e = AntHeavenHellEnv()
rng = jp.random_prngkey(0)
state = e.reset(rng=rng)
# A bare expression in the middle of a cell is discarded, so display the initial state explicitly
display(HTML(html.render(e.sys, [state.qp])))
states = [state]
step_fn = jax.jit(e.step)  # Wrap the step function once instead of on every iteration
for t in range(NUM_ROLLOUT_STEPS):
    rng, rng1 = jp.random_split(rng, 2)
    # Uniform random action in [-1, 1] for each of the 8 actuators
    states.append(step_fn(states[-1], jp.random_uniform(rng1, (8,), -1, 1)))
# Last expression of the cell: renders the whole trajectory
HTML(html.render(e.sys, [s.qp for s in states]))