# Generate Input Data for Profiling JaxRenderer using Brax Envs

Self-link: https://colab.research.google.com/drive/1c_83TLtb_pOt4OSlWFQgIKzA9DxdBDTp

Previous profiling version (profile branch head): https://colab.research.google.com/drive/1V7gdTY6ZYz7YhJI_LCWLZM035zSDTYoQ

Previous profiling version: https://colab.research.google.com/drive/1BJQG38IHPaUuMKMnNeq0hjrLtqwx4OfU

All-inlines + minibatch loops: https://colab.research.google.com/drive/1NiCTOCxfU0Mvr818Zqt4y63-S20Gn8yb

All-inlines: https://colab.research.google.com/drive/1Wrt3a0yoVPhYUJYCZltFIYoELWD-ciE0

Baseline link: https://colab.research.google.com/drive/17NSjyJL_Ov9D32Mnrs7miZmjiPNf9U6l

> The majority of the code is adapted from the [Brax team's Brax Training notebook](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb), with minor modifications to the visualisation part.

The sample output is using CPU backend, with standard RAM.

In [None]:
#@title Install brax from pip
!pip install brax -qqq
!pip install jaxrenderer -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.2/151.2 kB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m71.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m44.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.6/101.6 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m639.3/639.3 kB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.8/207.8 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m304.5/304.5 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
#@title ## Imports
#@markdown **⚠️ PLEASE NOTE:**

#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import json
from typing import Iterable, NamedTuple, Optional
import os
import pickle

import jax
from jax import numpy as jp
import numpy as onp
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

import brax

from brax import base, envs, math

import trimesh

from renderer import CameraParameters as Camera
from renderer import LightParameters as Light
from renderer import Model as RendererMesh
from renderer import ModelObject as Instance
from renderer import ShadowParameters as Shadow
from renderer import Renderer, UpAxis, merge_objects, create_capsule, create_cube, transpose_for_display

# Set up the TPU backend only when COLAB_TPU_ADDR is present in the environment.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()



In [None]:
#@title ## Utility Code and Constants

# Canvas size; used as the default `width` and `height` of `get_camera`.
canvas_width: int = 960 #@param {type: "integer"}
canvas_height: int = 540 #@param {type: "integer"}

def grid(grid_size: int, color) -> jp.ndarray:
  """Creates a square single-colour texture with two black border lines.

  Args:
    grid_size: number of texels along each side of the texture.
    color: colour components in the range 0-255; they are divided by 255
      before being stored.

  Returns:
    A `(grid_size, grid_size, 3)` single-precision array filled with the
    normalised colour, whose first row and last column are zeroed.
  """
  normalised = onp.array(color) / 255.0
  texture = onp.full((grid_size, grid_size, 3), normalised, dtype=onp.single)
  # black line along the first row
  texture[0] = 0.0
  # black line along the last column; to reverse texture along y direction
  texture[:, -1] = 0.0
  return jp.asarray(texture)

# 100x100 light-grey grid texture; used as the diffuse map of plane geoms.
_GROUND: jp.ndarray = grid(100, [200, 200, 200])

class Obj(NamedTuple):
  """An object to be rendered in the scene.

  Assume the system is unchanged throughout the rendering.

  `col` below is one geom taken from the batched geoms `sys.geoms`.
  """
  instance: Instance
  """An instance to be rendered in the scene, defined by jaxrenderer."""
  link_idx: int
  """col.link_idx if col.link_idx is not None else -1"""
  off: jp.ndarray
  """col.transform.pos"""
  rot: jp.ndarray
  """col.transform.rot"""

def _build_objects(sys: brax.System) -> list[Obj]:
  """Converts a brax System to a list of Obj.

  Each geom in `sys.geoms` is turned into a jaxrenderer model chosen by its
  collider type. `base.Convex` geoms are skipped because they are not visual.

  Args:
    sys: the brax system whose geoms are converted.

  Returns:
    One `Obj` per visual geom. Geoms are grouped by link name, with the groups
    ordered by the first appearance of each link in `sys.geoms`.

  Raises:
    RuntimeError: if a geom has an unrecognized collider type.
  """
  objs: list[Obj] = []

  def take_i(obj, i):
    # `jax.tree_util.tree_map` is the stable name; the top-level
    # `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(lambda x: jp.take(x, i, axis=0), obj)

  link_names: list[str] = [
      n or f'link {i}' for i, n in enumerate(sys.link_names)
  ]
  # geoms without a link use index -1, which resolves to this last entry
  link_names += ['world']
  link_geoms: dict[str, list] = {}
  for batch in sys.geoms:
    num_geoms = len(batch.friction)
    for i in range(num_geoms):
      link_idx = -1 if batch.link_idx is None else batch.link_idx[i]
      link_geoms.setdefault(link_names[link_idx], []).append(take_i(batch, i))

  for geoms in link_geoms.values():
    for col in geoms:
      # a single-texel texture holding the geom's RGB colour
      tex = col.rgba[:3].reshape((1, 1, 3))
      # reference: https://github.com/erwincoumans/tinyrenderer/blob/89e8adafb35ecf5134e7b17b71b0f825939dc6d9/model.cpp#L215
      specular_map = jax.lax.full(tex.shape[:2], 2.0)

      if isinstance(col, base.Capsule):
        half_height = col.length / 2
        model = create_capsule(
          radius=col.radius,
          half_height=half_height,
          up_axis=UpAxis.Z,
          diffuse_map=tex,
          specular_map=specular_map,
        )
      elif isinstance(col, base.Box):
        model = create_cube(
          half_extents=col.halfsize,
          diffuse_map=tex,
          texture_scaling=jp.array(16.),
          specular_map=specular_map,
        )
      elif isinstance(col, base.Sphere):
        # a sphere is rendered as a capsule of zero height
        model = create_capsule(
          radius=col.radius,
          half_height=jp.array(0.),
          up_axis=UpAxis.Z,
          diffuse_map=tex,
          specular_map=specular_map,
        )
      elif isinstance(col, base.Plane):
        # a plane is rendered as a very large, very thin textured cube
        tex = _GROUND
        model = create_cube(
          half_extents=jp.array([1000.0, 1000.0, 0.0001]),
          diffuse_map=tex,
          texture_scaling=jp.array(8192.),
          specular_map=specular_map,
        )
      elif isinstance(col, base.Convex):
        # convex objects are not visual
        continue
      elif isinstance(col, base.Mesh):
        # trimesh is used to obtain per-vertex normals
        tm = trimesh.Trimesh(vertices=col.vert, faces=col.face)
        model = RendererMesh.create(
          verts=tm.vertices,
          norms=tm.vertex_normals,
          uvs=jp.zeros((tm.vertices.shape[0], 2), dtype=int),
          faces=tm.faces,
          diffuse_map=tex,
        )
      else:
        raise RuntimeError(f'unrecognized collider: {type(col)}')

      i: int = col.link_idx if col.link_idx is not None else -1
      instance = Instance(model=model)
      off = col.transform.pos
      rot = col.transform.rot
      obj = Obj(instance=instance, link_idx=i, off=off, rot=rot)

      objs.append(obj)

  return objs

def _with_state(objs: Iterable[Obj], x: brax.Transform) -> list[Instance]:
  """Places every object according to the link transforms in `x`.

    `x` is `state.x` and must have at least 1 element. This can be ensured by
    calling `x.concatenate(base.Transform.zero((1,)))`.

    No input is modified; a new list of `Instance`s is produced, one per
    object and in the same order.

    Raises a RuntimeError when `x.pos` or `x.rot` is not 2-dimensional.
  """
  if x.pos.ndim != 2 or x.rot.ndim != 2:
    raise RuntimeError('unexpected shape in state')

  def place(obj):
    # transform of the link this object is attached to
    link_pos, link_rot = x.pos[obj.link_idx], x.rot[obj.link_idx]
    # the object's own offset and rotation are applied in the link's frame
    world_pos = link_pos + math.rotate(obj.off, link_rot)
    world_rot = math.quat_mul(link_rot, obj.rot)
    placed = obj.instance.replace_with_position(world_pos)
    return placed.replace_with_orientation(world_rot)

  return [place(obj) for obj in objs]

def _eye(sys: brax.System, state: brax.State) -> jp.ndarray:
  """Determines the camera location for a Brax system.

  The camera is offset from the first entry of `state.x.pos` by a vector
  scaled with the largest distance found between joint positions.
  """
  joint_pos = state.x.vmap().do(sys.link.joint).pos
  # differences between every pair of joint positions, joined along axis 0
  pairwise = jp.concatenate(joint_pos[None, ...] - joint_pos[:, None, ...])
  span = jp.linalg.norm(pairwise, axis=1).max()
  offset = jp.array([2 * span, -2 * span, span])

  return state.x.pos[0, :] + offset

def _up(unused_sys: brax.System) -> jp.ndarray:
  """Returns the camera's up direction: always the +z unit vector."""
  z_axis = [0., 0., 1.]
  return jp.array(z_axis)

def get_target(state: brax.State) -> jp.ndarray:
  """Gets the camera target: x and y of the first entry of `state.x.pos`, with z set to 0."""
  first_x, first_y = state.x.pos[0, 0], state.x.pos[0, 1]
  return jp.array([first_x, first_y, 0])

def get_camera(
    sys: brax.System,
    state: brax.State,
    width: int = canvas_width,
    height: int = canvas_height,
) -> Camera:
  """Builds the camera looking at the system in the given state.

  Args:
    sys: the brax system, used to place the camera.
    state: the state providing the positions the camera is derived from.
    width: view width, defaults to `canvas_width`.
    height: view height, defaults to `canvas_height`.

  Returns:
    A `Camera` with a fixed horizontal field of view of 58 and a vertical
    field of view scaled by `height / width`.
  """
  hfov = 58.0

  return Camera(
      viewWidth=width,
      viewHeight=height,
      position=_eye(sys, state),
      target=get_target(state),
      up=_up(sys),
      hfov=hfov,
      vfov=hfov * height / width,
  )

In [None]:
#@title ## Generate States

# Environment and number of states passed to `generate_state` for the single-environment run.
env_name = 'ant' # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
number_of_frames_to_render: int = 30 #@param {type:"integer"}

def generate_state(env_name, number_of_frames_to_render, backend='positional'):
  """Generates batched renderer inputs for a Brax environment.

  The environment is reset once for each seed in
  `0..number_of_frames_to_render - 1`; the resulting pipeline states are then
  converted, in batch, into instances, cameras and camera targets.

  Args:
    env_name: name of the Brax environment to load.
    number_of_frames_to_render: number of seeds, and thus of states, to
      generate.
    backend: Brax backend passed to `envs.get_environment`.

  Returns:
    A tuple `(batched_instances, batched_camera, batched_target)`.
  """
  def block(tree):
    # Wait for every leaf to be computed. `jax.tree_util.tree_map` is the
    # stable name; the top-level `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(lambda field: field.block_until_ready(), tree)

  env = envs.get_environment(env_name=env_name, backend=backend)
  jit_reset = jax.jit(env.reset)
  vmap_reset = jax.jit(jax.vmap(lambda i: jit_reset(rng=jax.random.PRNGKey(seed=i)).pipeline_state))
  seeds = jax.lax.iota(int, number_of_frames_to_render)
  states = block(vmap_reset.lower(seeds).compile()(seeds))

  sys = env.sys

  with jax.profiler.TraceAnnotation("build inputs"):
    get_cameras = jax.jit(jax.vmap(lambda state: get_camera(sys, state))).lower(states).compile()
    batched_camera = block(get_cameras(states))
    get_targets = jax.jit(jax.vmap(get_target)).lower(states).compile()
    batched_target = get_targets(states).block_until_ready()

    objs = _build_objects(sys)

    # an extra zero transform is appended so that link index -1 is valid
    get_instances = jax.jit(jax.vmap(lambda state: _with_state(objs, state.x.concatenate(base.Transform.zero((1,)))))).lower(states).compile()
    batched_instances = block(get_instances(states))

  return batched_instances, batched_camera, batched_target

In [None]:
batched_instances, batched_camera, batched_target = generate_state(env_name, number_of_frames_to_render)
# a context manager closes the file deterministically once it is written
with open("inputs.pickle", "wb") as f:
  pickle.dump((batched_instances, batched_camera, batched_target), f)
!zip -r9 inputs.zip inputs.pickle

  adding: inputs.pickle (deflated 99%)


In [None]:
# Requires the `google.colab` package; requests a download of the archive.
from google.colab import files
files.download("inputs.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
from typing import NamedTuple

class SceneStats(NamedTuple):
  """Size statistics of one scene, as produced by `get_stats`."""
  # number of objects in the scene
  number_of_objects: int
  # number of faces over all objects
  number_of_triangles: int
  # number of vertices over all objects
  number_of_vertices: int


def get_stats(batched_instances):
  """Computes scene statistics from the first element of a batch.

  Args:
    batched_instances: batched instances such as the first value returned by
      `generate_state`; only index 0 of every leaf is used.

  Returns:
    A `SceneStats` with the sizes of the merged objects of that scene.
  """
  # `jax.tree_util.tree_map` is the stable name; the top-level
  # `jax.tree_map` alias is deprecated.
  first_scene = jax.tree_util.tree_map(lambda a: jp.asarray(a[0]), batched_instances)
  merged = merge_objects(first_scene)

  return SceneStats(
      number_of_objects=merged.texture_shape.shape[0],
      number_of_triangles=merged.faces.shape[0],
      number_of_vertices=merged.verts.shape[0],
  )

# Environments whose scene sizes are tabulated.
envs_names = [
    'ant',
    'halfcheetah',
    'hopper',
    'humanoid',
    'humanoidstandup',
    'inverted_pendulum',
    'inverted_double_pendulum',
    'pusher',
    'reacher',
    'walker2d',
]
# one row per environment, one column per `SceneStats` field
df_buffer = onp.zeros((len(envs_names), 3), dtype=int)

for i, env in enumerate(envs_names):
  # a single state suffices: `get_stats` only reads the first batch element
  stats = get_stats(generate_state(env, 1)[0])
  df_buffer[i] = onp.asarray(stats)

df = pd.DataFrame(
    data=df_buffer,
    index=envs_names,
    columns=["# of objects", "# of triangles", "# of vertices"],
)
df

Unnamed: 0,# of objects,# of triangles,# of vertices
ant,18,3276,9816
halfcheetah,9,1548,4632
hopper,5,780,2328
humanoid,18,3276,9816
humanoidstandup,18,3276,9816
inverted_pendulum,3,576,1728
inverted_double_pendulum,5,780,2328
pusher,21,3852,11544
reacher,10,1740,5208
walker2d,8,1356,4056


In [None]:
envs_names = ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
# Write one pickle of renderer inputs per environment.
for env in envs_names:
  batched_instances, batched_camera, batched_target = generate_state(env, number_of_frames_to_render)
  # a context manager closes each file instead of leaving one open handle per environment
  with open(f"inputs-{env}.pickle", "wb") as f:
    pickle.dump((batched_instances, batched_camera, batched_target), f)

!zip -r9 inputs.zip inputs-*.pickle

  adding: inputs-ant.pickle (deflated 99%)
  adding: inputs-halfcheetah.pickle (deflated 99%)
  adding: inputs-hopper.pickle (deflated 99%)
  adding: inputs-humanoid.pickle (deflated 99%)
  adding: inputs-humanoidstandup.pickle (deflated 99%)
  adding: inputs-inverted_double_pendulum.pickle (deflated 99%)
  adding: inputs-inverted_pendulum.pickle (deflated 99%)
  adding: inputs-pusher.pickle (deflated 99%)
  adding: inputs-reacher.pickle (deflated 99%)
  adding: inputs-walker2d.pickle (deflated 99%)


In [None]:
#@title Terminating Colab Automatically
# Only the import is guarded, so an error raised while terminating is not
# mistaken for "not running in Colab".
try:
  from google.colab import runtime
except ModuleNotFoundError:
  print("Not in Colab, skip termination")
else:
  print("Terminating Colab")
  runtime.unassign()