In [None]:
#@title

from IPython.display import HTML, Image

try:
  import brax
except ImportError:
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.io import html
import torch
v = torch.ones(1, device='cuda')  # init torch cuda before jax

In [None]:
from typing import Sequence, Union

import jax
from brax.jumpy import _in_jit, X, Optional, Tuple, Callable, onp, Any, _which_np, jnp, ndarray


def while_loop(cond_fun: Callable[[X], Any],
               body_fun: Callable[[X], X],
               init_val: X) -> X:
  """Repeatedly apply ``body_fun`` to a carried value while ``cond_fun`` holds.

  Under jit tracing this defers to ``jax.lax.while_loop``; otherwise it runs a
  plain Python loop. Returns the final carried value.
  """
  if not _in_jit():
    carry = init_val
    while cond_fun(carry):
      carry = body_fun(carry)
    return carry
  return jax.lax.while_loop(cond_fun, body_fun, init_val)


def fori_loop(lower: int, upper: int,
              body_fun: Callable[[int, X], X],
              init_val: X) -> X:
  """Run ``val = body_fun(i, val)`` for ``i`` in ``[lower, upper)`` and return ``val``.

  Mirrors ``jax.lax.fori_loop``: the body receives the loop index as its first
  argument on both the jitted path and the plain-Python path, so the same
  ``body_fun`` works in either mode.
  """
  if _in_jit():
    return jax.lax.fori_loop(lower, upper, body_fun, init_val)
  val = init_val
  for i in range(lower, upper):
    # Pass the index like jax.lax.fori_loop does (previously only `val` was passed).
    val = body_fun(i, val)
  return val


def index_add(x: ndarray, idx: ndarray, y: ndarray) -> ndarray:
  """Pure (non-mutating) equivalent of ``x[idx] += y``.

  Repeated indices in ``idx`` accumulate on both backends. The jax path uses
  ``x.at[idx].add(y)`` and the numpy path uses the unbuffered ``onp.add.at``.
  A plain numpy ``x[idx] += y`` applies each repeated index only once, which
  would make the two backends disagree.

  Returns:
    A new array. ``x`` itself is not modified.
  """
  if _which_np(x) is jnp:
    return x.at[idx].add(y)
  x = onp.copy(x)
  onp.add.at(x, idx, y)
  return x


def index_update(x: ndarray, idx: ndarray, y: ndarray) -> ndarray:
  """Pure (non-mutating) equivalent of ``x[idx] = y``.

  On the jax backend, out-of-bounds indices are dropped (``mode='drop'``).
  The numpy path uses plain assignment, so an out-of-bounds index raises
  ``IndexError`` there instead. Negative indices wrap Python-style on numpy.

  Returns:
    A new array. ``x`` itself is not modified.
  """
  if _which_np(x) is jnp:
    return x.at[idx].set(y, mode='drop')
  x = onp.copy(x)
  x[idx] = y
  return x


def meshgrid(*xi, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> ndarray:
  """Coordinate matrices from coordinate vectors, on the backend of the first input.

  Arguments are forwarded unchanged to ``jnp.meshgrid`` or ``onp.meshgrid``.
  """
  backend = jnp if _which_np(xi[0]) is jnp else onp
  return backend.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)


def randint(rng: ndarray, shape: Tuple[int, ...] = (),
            low: Optional[int] = 0, high: Optional[int] = 1) -> ndarray:
  """Draw integers uniformly from ``[low, high)`` with the given shape.

  A jax PRNG key dispatches to ``jax.random.randint``. Any other seed is fed to
  ``onp.random.default_rng``.
  """
  if _which_np(rng) is not jnp:
    generator = onp.random.default_rng(rng)
    return generator.integers(low=low, high=high, size=shape)
  return jax.random.randint(rng, shape=shape, minval=low, maxval=high)


def maximum(x1: ndarray, x2: ndarray) -> ndarray:
  """Element-wise maximum of ``x1`` and ``x2``, computed on the inputs' backend."""
  backend = _which_np(x1, x2)
  return backend.maximum(x1, x2)


def choice(rng: ndarray, a: Union[int, Any], shape: Tuple[int, ...] = (),
           replace: bool = True, p: Optional[Any] = None, axis: int = 0) -> ndarray:
  """Sample from ``a`` (an int ``n`` means ``arange(n)``) along ``axis``.

  A jax PRNG key dispatches to ``jax.random.choice``. Any other seed goes to
  ``onp.random.default_rng(...).choice``.
  """
  sample_kwargs = dict(replace=replace, p=p, axis=axis)
  if _which_np(rng) is not jnp:
    return onp.random.default_rng(rng).choice(a, size=shape, **sample_kwargs)
  return jax.random.choice(rng, a, shape=shape, **sample_kwargs)


def atleast_1d(*arys) -> ndarray:
  """Promote each input to at least one dimension; new axes are prepended."""
  backend = _which_np(*arys)
  return backend.atleast_1d(*arys)


def atleast_2d(*arys) -> ndarray:
  """Promote each input to at least two dimensions; new axes are prepended."""
  backend = _which_np(*arys)
  return backend.atleast_2d(*arys)


def atleast_3d(*arys) -> ndarray:
  """Promote each input to at least three dimensions (numpy ``atleast_3d`` semantics).

  New axes are not all prepended:
    scalar -> (1, 1, 1)
    (N,)   -> (1, N, 1)
    (M, N) -> (M, N, 1)
  """
  return _which_np(*arys).atleast_3d(*arys)


def cond(pred, true_fun: Callable, false_fun: Callable, *operands: Any):
  """Return ``true_fun(*operands)`` if ``pred`` else ``false_fun(*operands)``.

  Mirrors ``jax.lax.cond``: the operands are unpacked into the branch call on
  both the jitted and the plain-Python path. Previously the non-jit path passed
  the operands tuple as a single argument.
  """
  if _in_jit():
    return jax.lax.cond(pred, true_fun, false_fun, *operands)
  branch = true_fun if pred else false_fun
  return branch(*operands)

In [None]:
import brax
import brax.jumpy as jp
from brax.physics.config_pb2 import Body


def add_box_wall_to_body(body: Body, from_xy: jp.ndarray, to_xy: jp.ndarray, half_height: float = 0.5, wall_width: float = 0.25) -> None:
    """Append a box-shaped wall collider running from ``from_xy`` to ``to_xy``.

    The box is centred on the segment midpoint, its long (x) half-size is half the
    segment length, and ``rotation.z`` (degrees, measured from the +x axis) aligns
    it with the segment. The z placement is left to the body itself.

    Args:
        body: Body from a config_pb2 Config. A collider is appended in place.
        from_xy: xy coordinates of the start of the wall (relative to body), 2 elements.
        to_xy: xy coordinates of the end of the wall (relative to body), 2 elements.
        half_height: Half height of the box (z half-size).
        wall_width: Half thickness of the box (y half-size before rotation).

    Returns:
        None. ``body`` is modified in place.
    """
    vector = to_xy - from_xy  # Wall direction; gives both length and angle
    length = jp.norm(vector)
    midpoint = (from_xy + to_xy) / 2  # xy position of the collider
    # arctan2 keeps the sign of the angle. arccos(dot / length) only returns
    # [0, 180] degrees, which mirrors walls whose direction has a negative y
    # component and yields NaN for a zero-length wall.
    z_rotation = jp.arctan2(vector[1], vector[0]) * 180 / jp.pi
    coll = body.colliders.add()  # Collider carries position and rotation
    # float() so protobuf float fields receive plain Python floats.
    coll.position.x, coll.position.y = float(midpoint[0]), float(midpoint[1])
    coll.rotation.z = float(z_rotation)
    box = coll.box
    box.halfsize.x = float(length / 2)
    box.halfsize.y = float(wall_width)
    box.halfsize.z = float(half_height)


def add_capsule_wall_to_body(body: Body, from_xy: jp.ndarray, to_xy: jp.ndarray, radius: float = 0.5, include_radius: bool = False) -> None:
    """Append an axis-aligned capsule wall collider between two xy points.

    Only horizontal (same y) or vertical (same x) walls are supported; anything
    else fails the assertion before the body is touched.

    Args:
        body: Body from a config_pb2 Config. A collider is appended in place.
        from_xy: xy coordinates of the start of the capsule (relative to body).
        to_xy: xy coordinates of the end of the capsule (relative to body).
        radius: Capsule radius.
        include_radius: If true, ``radius`` is subtracted once from the capsule
            length. NOTE(review): an earlier docstring said the length shrinks by
            ``radius * 2``, but the code subtracts it once. Confirm which is intended.

    Returns:
        None. ``body`` is modified in place.
    """
    assert (from_xy[0] == to_xy[0]) or (from_xy[1] == to_xy[1])
    is_vertical = (from_xy[0] == to_xy[0])  # Same x -> wall runs along y
    collider = body.colliders.add()
    centre = (from_xy + to_xy) / 2
    collider.position.x = centre[0]
    collider.position.y = centre[1]
    # Tilt the capsule out of its default orientation into the ground plane:
    # about x for vertical walls, about y for horizontal ones.
    if is_vertical:
        collider.rotation.x = 90
    else:
        collider.rotation.y = 90
    capsule = collider.capsule
    capsule.radius = radius
    capsule.length = jp.norm(from_xy - to_xy) - (include_radius * radius)


def draw_arena(cfg: brax.Config, cage_x: float, cage_y: float, capsule_radius_or_box_half_height: float = 0.5,
               arena_name: str = "Arena", use_boxes: bool = True) -> None:
    """Add a frozen 4-sided arena (box or capsule walls) that bounds the play area.

    The walls are arranged so that ``cage_x`` and ``cage_y`` are the inner faces of
    the enclosure. Wall centre lines sit at ``±(cage + r)``, where ``r`` is the
    capsule radius or, for boxes, half of ``capsule_radius_or_box_half_height``.
    All walls are colliders of a single frozen body, whose default z position
    equals ``capsule_radius_or_box_half_height``.

    Args:
        cfg: brax Config object (modified in place).
        cage_x: x coordinate of the inner wall faces.
        cage_y: y coordinate of the inner wall faces.
        capsule_radius_or_box_half_height: Capsule radius, or box half height (the
            box is then as thick as this value). >= 0.5 recommended.
        arena_name: Name of the arena body (used to add collide pairs later).
        use_boxes: Use box walls instead of capsules.

    Returns:
        None. ``cfg`` is modified in place.
    """
    x, y, r = cage_x, cage_y, capsule_radius_or_box_half_height
    arena = cfg.bodies.add(name=arena_name, mass=1.)  # One frozen body, many colliders
    arena.frozen.all = True
    aqp = cfg.defaults.add().qps.add(name=arena_name)
    aqp.pos.z = capsule_radius_or_box_half_height
    if use_boxes:
        r /= 2  # Box half-thickness; offset the walls outward so they enclose the cage
    corners = jp.array([[x + r, y + r], [x + r, -y - r], [-x - r, -y - r], [-x - r, y + r]])
    n_corners = len(corners)
    for i in range(n_corners):
        start, end = corners[i], corners[(i + 1) % n_corners]  # Consecutive corners form one wall
        if use_boxes:
            add_box_wall_to_body(arena, start, end, capsule_radius_or_box_half_height, r)
        else:
            add_capsule_wall_to_body(arena, start, end, r, True)

In [None]:
from brax import math as math
import jax
from brax import jumpy as jp
from brax.envs import env
from google.protobuf import text_format

def extend_ant_cfg(cfg: str = brax.envs.ant._SYSTEM_CONFIG,
                   cage_max_xy: jp.ndarray = jp.array([4.5, 4.5]),
                   offset: float = 1,
                   n_apples: int = 8,
                   n_bombs: int = 8) -> brax.Config:
    """Parse the ant text config and add an arena, apples and bombs to it.

    Args:
        cfg: Text-format brax config for the ant.
        cage_max_xy: (x, y) extent of the play area. The arena's inner faces sit at
            ``cage_max_xy + offset``.
        offset: Extra clearance between the play area and the arena walls.
        n_apples: Number of ``Target_i`` bodies to append.
        n_bombs: Number of ``Bomb_i`` bodies to append (after the apples).

    Returns:
        The extended ``brax.Config``. Apple and bomb bodies are the last
        ``n_apples + n_bombs`` bodies, all frozen spheres of radius 0.25, all at
        the same default spot (real positions are chosen on reset).
    """
    config = text_format.Parse(cfg, brax.Config())
    # Snapshot the ant's parts before the arena body exists (everything but the ground).
    ant_parts = [body.name for body in config.bodies if body.name != 'Ground']
    draw_arena(config, cage_max_xy[0] + offset, cage_max_xy[1] + offset, 0.5)
    for part in ant_parts:
        config.collide_include.add(first=part, second='Arena')

    def add_marker(name):
        # Frozen sphere body; it is not added to collide_include here.
        marker = config.bodies.add(name=name, mass=1.)
        marker.frozen.all = True
        marker.colliders.add().sphere.radius = 0.25

    for i in range(n_apples):
        add_marker(f'Target_{i+1}')
    for i in range(n_bombs):
        add_marker(f'Bomb_{i+1}')
    return config


class AntGatherEnv(env.Env):
    """Ant gather task: collect apples (+1 each) and avoid bombs (-1 each) in a walled arena.

    Args:
        n_apples: Number of apples in environment (+1 reward each)
        n_bombs: Number of bombs in environment (-1 reward each)
        cage_xy: Max x and y of the object spawn grid (box from (-x,-y) to (x,y)). The
            arena walls' inner faces sit 1 unit further out.
        robot_object_spacing: Objects only spawn at grid points strictly farther than
            this from the origin (the ant's spawn point)
        catch_range: Planar distance at which the robot "catches" an apple or bomb
        n_bins: Resolution of ant sensor. If several objects of the same type fall in
            one bin, only the closest is reported
        sensor_range: Range of ant sensors
        sensor_span: Arc (in radians) of ant sensors, centred on the torso heading
        dying_cost: Reward on any step where the torso height leaves [0.2, 1.0]
        **kwargs: Accepted for interface compatibility and ignored

    Apples and bombs spawn at distinct integer grid locations within cage_xy, except
    those too close to the origin. Caught objects are parked in a far-away waiting
    area. An episode is done when the torso height leaves [0.2, 1.0] or when every
    object has been caught.

    The observation is the standard ant observation followed by 2 * n_bins sensor
    readings: apple bins first, then bomb bins.
    """
    def __init__(self,
                 n_apples: int = 8,
                 n_bombs: int = 8,
                 cage_xy: Sequence[float] = (6, 6),
                 robot_object_spacing: float = 2.,
                 catch_range: float = 1.,
                 n_bins: int = 10,
                 sensor_range: float = 6.,
                 sensor_span: float = jp.pi,
                 dying_cost: float = -10.,
                 **kwargs
                 ):
        self.cage_xy = jp.array(cage_xy)
        # Add walls, apples, and bombs to the stock ant config.
        cfg = extend_ant_cfg(cage_max_xy=self.cage_xy, offset=1., n_apples=n_apples, n_bombs=n_bombs)
        self.sys = brax.System(cfg)
        self.torso_idx = self.sys.body.index['$ Torso']
        self.n_apples = n_apples
        self.n_bombs = n_bombs
        self.n_objects = n_apples + n_bombs
        self.n_bins = n_bins
        self.dying_cost = dying_cost
        self.sensor_range = sensor_range
        self.half_span = sensor_span / 2
        self.catch_range = catch_range
        # extend_ant_cfg appends apples then bombs last, so they are the final bodies.
        last_ind = self.sys.num_bodies
        first_ind = last_ind - self.n_objects
        self.object_indices = jp.arange(first_ind, last_ind)
        # All integer grid points in the cage that are far enough from the ant spawn point.
        grid_x, grid_y = meshgrid(jp.arange(-self.cage_xy[0], self.cage_xy[0] + 1),
                                  jp.arange(-self.cage_xy[1], self.cage_xy[1] + 1))
        grid_xy = jp.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
        grid_xy = grid_xy[jp.norm(grid_xy, axis=1) > robot_object_spacing]
        # Append z = 0 so rows can be written directly into qp.pos.
        self.possible_grid_positions = jp.concatenate([grid_xy, jp.zeros((grid_xy.shape[0], 1))], axis=1)
        if self.n_objects > self.possible_grid_positions.shape[0]:
            raise ValueError(
                f'{self.n_objects} objects requested but only {self.possible_grid_positions.shape[0]} '
                f'spawn locations are farther than {robot_object_spacing} from the origin')
        # Captured objects are parked here: the last grid point shifted by 2 * sensor_range on x, y and z.
        self.waiting_area = self.possible_grid_positions[-1] + self.sensor_range * 2

    def _object_distances(self, qp: brax.QP) -> jp.ndarray:
        """Planar (x-y) distance from the torso to every apple and bomb, shape (n_objects,)."""
        torso_xy = qp.pos[self.torso_idx][:2]
        return jp.norm(torso_xy - qp.pos[self.object_indices][..., :2], axis=1)

    def reset(self, rng: jp.ndarray) -> env.State:
        """Sample a fresh ant pose and object layout; all metrics start at zero."""
        qp = self.sample_init_qp(rng)
        sys_info = self.sys.info(qp)
        obs = self._get_obs(qp, sys_info, self._object_distances(qp))
        reward, done, zero = jp.zeros(3)
        # Metrics report apples / bombs / objects caught on the latest step.
        metrics = {
            'apples': zero,
            'bombs': zero,
            'objects': zero,
        }
        return env.State(qp, obs, reward, done, metrics, {'rng': rng})

    def sample_init_qp(self, rng: jp.ndarray) -> brax.QP:
        """Jitter the ant's joints and place objects on distinct spawn grid points."""
        _, rng1, rng2, rng3 = jp.random_split(rng, 4)
        qpos = self.sys.default_angle() + jp.random_uniform(
            rng1, (self.sys.num_joint_dof,), -.1, .1)
        qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -.1, .1)
        qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
        # Sampling without replacement means no two objects share a grid point.
        object_pos = choice(rng3, self.possible_grid_positions, (self.n_objects,), replace=False)
        pos = jp.index_update(qp.pos, self.object_indices, object_pos)
        return qp.replace(pos=pos)

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        """Run one timestep of the environment's dynamics."""
        qp, info = self.sys.step(state.qp, action)
        distances = self._object_distances(qp)
        # "Death": torso too low or too high.
        torso_z = qp.pos[self.torso_idx, 2]
        done = jp.where(torso_z < 0.2, x=jp.float32(1), y=jp.float32(0))
        done = jp.where(torso_z > 1.0, x=jp.float32(1), y=done)
        # Catch every object in range and park it in the waiting area.
        caught = distances <= self.catch_range
        tgt_pos = jp.where(caught[:, None], self.waiting_area, qp.pos[self.object_indices])
        qp = qp.replace(pos=jp.index_update(qp.pos, self.object_indices, tgt_pos))
        # float32 so metric dtypes stay consistent from step to step.
        apples_hit = caught[:self.n_apples].sum().astype(jp.float32)
        bombs_hit = caught[self.n_apples:].sum().astype(jp.float32)
        # +1 per apple and -1 per bomb caught this step, unless the ant died.
        reward = jp.where(done > 0, jp.float32(self.dying_cost), apples_hit - bombs_hit)
        # Also done once every object has been caught.
        done = jp.where((qp.pos[self.object_indices] == self.waiting_area).all(), jp.float32(1), done)
        # Observe the post-catch state so obs matches the returned qp.
        obs = self._get_obs(qp, info, self._object_distances(qp))
        # Build a new dict instead of mutating state.metrics (keeps step pure).
        metrics = {**state.metrics,
                   'apples': apples_hit,
                   'bombs': bombs_hit,
                   'objects': apples_hit + bombs_hit}
        return state.replace(qp=qp, obs=obs, reward=reward, done=done, metrics=metrics)

    def _get_readings(self, qp: brax.QP, distances: jp.ndarray) -> jp.ndarray:
        """Per-bin proximity readings for apples and bombs.

        The arc [-half_span, half_span] around the torso heading (projected onto the
        x-y plane) is split into n_bins equal bins. The output has 2 * n_bins entries,
        apple bins first and then bomb bins. Each entry is 1 - distance / sensor_range
        for the closest visible object of that type in that bin, or 0 if there is none.
        """
        bin_res = (2 * self.half_span) / self.n_bins  # FOV for each bin
        # Rotate the body x-axis by the torso quaternion to get the heading in the x-y plane.
        torso_rot = qp.rot[self.torso_idx]
        heading = math.quat_mul(math.quat_mul(torso_rot, jp.array([0., 1., 0., 0.])),
                                math.quat_inv(torso_rot))[1:3]
        heading = jp.arctan2(heading[1], heading[0])
        # Bearing of each object relative to the torso position, measured from the heading.
        rel_xy = qp.pos[self.object_indices][..., :2] - qp.pos[self.torso_idx][:2]
        angles = jp.arctan2(rel_xy[..., 1], rel_xy[..., 0]) - heading
        angles = (angles + jp.pi) % (2 * jp.pi) - jp.pi  # Wrap to [-pi, pi)
        visible = jp.logical_and(jp.abs(angles) <= self.half_span, distances <= self.sensor_range)
        # Clip so that angle == +half_span lands in the last bin, not the next type's first bin.
        bins = jp.clip(((angles + self.half_span) / bin_res).astype(int), 0, self.n_bins - 1)
        # Bomb readings occupy the second half of the output.
        is_bomb = jp.arange(self.n_objects) >= self.n_apples
        bins = bins + jp.where(is_bomb, self.n_bins, 0)
        bins = jp.where(visible, bins, -1)  # -1 matches no bin below
        intensities = jp.where(visible, 1. - distances / self.sensor_range, 0.)
        # Dense (n_objects, 2 * n_bins) membership. The max keeps the closest object per bin
        # and avoids scatter writes (a -1 scatter index would wrap to the last bin).
        hits = bins[:, None] == jp.arange(2 * self.n_bins)[None, :]
        return jp.where(hits, intensities[:, None], 0.).max(axis=0)

    def _get_obs(self, qp: brax.QP, info: 'brax.Info', distances: jp.ndarray) -> jp.ndarray:
        """Observe ant body position and velocities, contacts, and sensor readings."""
        # some pre-processing to pull joint angles and velocities
        (joint_angle,), (joint_vel,) = self.sys.joints[0].angle_vel(qp)

        # qpos:
        # XYZ of the torso (3,)
        # orientation of the torso as quaternion (4,)
        # joint angles (8,)
        qpos = [qp.pos[self.torso_idx], qp.rot[self.torso_idx], joint_angle]

        # qvel:
        # velocity of the torso (3,)
        # angular velocity of the torso (3,)
        # joint angle velocities (8,)
        qvel = [qp.vel[self.torso_idx], qp.ang[self.torso_idx], joint_vel]

        # external contact forces, clipped to [-1, 1]:
        # delta velocity (3,) and delta ang (3,) for each body in the system. This
        # presumably includes the arena, apples and bombs added by extend_ant_cfg
        # (NOTE(review): confirm the shape of info.contact).
        cfrc = [
            jp.clip(info.contact.vel, -1, 1),
            jp.clip(info.contact.ang, -1, 1)
        ]
        # flatten bottom dimension
        cfrc = [jp.reshape(x, x.shape[:-2] + (-1,)) for x in cfrc]
        readings = [self._get_readings(qp, distances)]

        return jp.concatenate(qpos + qvel + cfrc + readings)

In [None]:
# Random-action rollout of AntGatherEnv, rendered as an interactive HTML view.
e = AntGatherEnv()
# Separate keys for reset and the action stream (avoids reusing one PRNG key).
rng, reset_rng = jp.random_split(jp.random_prngkey(0), 2)
state = e.reset(rng=reset_rng)
# jit once, outside the loop, so the compiled step is reused rather than re-wrapped
# (and potentially recompiled) on every iteration.
jit_step = jax.jit(e.step)
states = [state]
for _ in range(200):  # rollout length in steps
    rng, action_rng = jp.random_split(rng, 2)
    action = jp.random_uniform(action_rng, (e.action_size,), -1, 1)
    states.append(jit_step(states[-1], action))
HTML(html.render(e.sys, [s.qp for s in states]))