# Benchmark Brax Rendering Example using JaxRenderer backend

Self-link: https://colab.research.google.com/drive/1X_IbG5SGdQ_GnsjjlBaF9pJQR7vu8Aml

Benchmarking using A100 on smaller canvas: https://colab.research.google.com/drive/1Mr2pRetdobZdgdtNT_DJuS4YZq3-C4oY

Benchmarking using A100 on larger canvas: https://colab.research.google.com/drive/1A7PzhG3vn6oNzrWTxE5E3dmu8xQTcNnH

Benchmarking using T4, and baseline (CPU PyTinyrenderer): https://colab.research.google.com/drive/1xhkYNz5WjvUCjQWpp72CLf9SIy3i5PnN

Generate data: https://colab.research.google.com/drive/1c_83TLtb_pOt4OSlWFQgIKzA9DxdBDTp

<details>
  <summary>Other related profiling/benchmarking Colabs</summary>

Profile with pre-generated data: https://colab.research.google.com/drive/12yNBVOdwUqUOBRgmQHF0gl8eMzzIi_BH

Previous profile version (profile branch head): https://colab.research.google.com/drive/1V7gdTY6ZYz7YhJI_LCWLZM035zSDTYoQ

Previous profiling version: https://colab.research.google.com/drive/1BJQG38IHPaUuMKMnNeq0hjrLtqwx4OfU

All-inlines + minibatch loops: https://colab.research.google.com/drive/1NiCTOCxfU0Mvr818Zqt4y63-S20Gn8yb

All-inlines: https://colab.research.google.com/drive/1Wrt3a0yoVPhYUJYCZltFIYoELWD-ciE0

Baseline link: https://colab.research.google.com/drive/17NSjyJL_Ov9D32Mnrs7miZmjiPNf9U6l

</details>

The sample outputs below were produced on a T4 GPU runtime with Standard RAM.

In [None]:
#@title Reinstall jaxrenderer from tag `v0.3.0`; Install pytinyrenderer
!pip uninstall jaxrenderer -y -qqq
!pip install git+https://github.com/JoeyTeng/jaxrenderer.git@v0.3.0 -qqq

!pip install pytinyrenderer --upgrade -qqq

[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for jaxrenderer (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#@title ## GPU Spec
# Show the runtime's GPU model, driver/CUDA versions and current memory usage.
!nvidia-smi

Wed Jun 21 02:04:28 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   64C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
#@title ## Imports
#@markdown **⚠️ PLEASE NOTE:**

#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

from functools import partial
from typing import Optional
import os
import pickle

import jax
from jax import numpy as jp
import numpy as onp
import pandas as pd
from scipy.spatial.transform import Rotation as R

from tqdm.auto import tqdm

import pytinyrenderer

from renderer import CameraParameters as Camera
from renderer import LightParameters as Light
from renderer import ModelObject as Instance
from renderer import ShadowParameters as Shadow
from renderer import Renderer, transpose_for_display

# On Colab TPU runtimes (detected via the COLAB_TPU_ADDR environment variable),
# run JAX's Colab TPU setup helper before any computation.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title ## Load data; Constants
!rm inputs-30.zip inputs.pickle
!wget https://github.com/JoeyTeng/jaxrenderer/raw/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
!unzip inputs-30.zip

FRAMES_AVAILABLE: int = 30

_batched_instances, _batched_camera, _batched_target = pickle.load(open("inputs.pickle", "rb"))

canvas_width: int = 480 #@param {type:"integer"}
canvas_height: int = 270 #@param {type:"integer"}

_batched_instances = jax.tree_map(lambda a: jp.asarray(a).block_until_ready(), _batched_instances)
_batched_camera = jax.tree_map(lambda a: jp.asarray(a).block_until_ready(), _batched_camera)
_batched_target = jax.tree_map(lambda a: jp.asarray(a).block_until_ready(), _batched_target)

rm: cannot remove 'inputs-30.zip': No such file or directory
rm: cannot remove 'inputs.pickle': No such file or directory
--2023-06-21 02:04:38--  https://github.com/JoeyTeng/jaxrenderer/raw/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/JoeyTeng/jaxrenderer/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip [following]
--2023-06-21 02:04:38--  https://raw.githubusercontent.com/JoeyTeng/jaxrenderer/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP

In [None]:
#@title ## Utility Function
import gc
import itertools
import time
from typing import Any, Callable


def benchmark(
    f: Callable[[], Any],
    repeat: int = 2,
    number: int = 1,
    sort: bool = True,
) -> list[int]:
  """Time `f` in the style of `timeit.repeat`.

  Each of the `repeat` rounds calls `f` `number` times back to back with the
  garbage collector paused, and records the duration of the whole round.
  `f` therefore runs `repeat * number` times overall.

  Parameters:
    - f: zero-argument callable under test; whatever it returns is discarded.
      It is invoked many times, so it must tolerate being called repeatedly.
    - repeat: int, default 2. Number of rounds, i.e. length of the result.
    - number: int, default 1. Calls to `f` per round.
    - sort: bool, default True. If set, the timings are returned in ascending
      order (fastest round first).

  Return:
    list of per-round durations in nanoseconds (from `time.perf_counter_ns`),
    with `repeat` entries.
  """

  def _one_round() -> int:
    # Build the iterator outside the timed region, mirroring `timeit`.
    calls = itertools.repeat(None, number)
    # Pause GC so collections do not pollute the measurement; restore the
    # caller's GC state afterwards even if `f` raises.
    was_enabled = gc.isenabled()
    gc.disable()
    try:
      start_ns = time.perf_counter_ns()
      for _ in calls:
        f()
      stop_ns = time.perf_counter_ns()
    finally:
      if was_enabled:
        gc.enable()
    return stop_ns - start_ns

  durations = [_one_round() for _ in range(repeat)]
  return sorted(durations) if sort else durations

## Baseline - Pytinyrenderer (CPU only)

In [None]:
#@title Prepare data
def prepare_data(
  number_of_frames_to_render: int,
  batched_instances,
  batched_camera,
  batched_target,
):
  """Convert the leading frames of batched JaxRenderer inputs to PyTinyRenderer.

  One `pytinyrenderer.TinySceneRenderer` is built per frame. It holds a mesh and
  an object instance for every instance of that frame, with the instance's
  position, orientation (from the rotation part of its transform) and local
  scaling applied.

  Parameters:
    - number_of_frames_to_render: int. How many leading frames of the batch to
      convert.
    - batched_instances: instances whose leaves are indexed per frame
      (`field[i]` selects frame i).
    - batched_camera: camera whose leaves are indexed per frame.
    - batched_target: per-frame targets, used as shadow-map centres by the
      renderer; frame i is `batched_target[i]`.

  Return:
    (scenes, instance ids per scene, cameras, targets), each a list with one
    entry per converted frame.
  """
  _py_scenes = [pytinyrenderer.TinySceneRenderer() for _ in range(number_of_frames_to_render)]
  _py_instances = []
  _py_cameras = []
  # Keep only the frames actually converted so all four lists stay aligned.
  _py_targets = onp.asarray(batched_target)[:number_of_frames_to_render].tolist()

  # for each frame
  for i in range(number_of_frames_to_render):
    # process camera
    _camera = jax.tree_util.tree_map(lambda field: field[i], batched_camera)
    camera = pytinyrenderer.TinyRenderCamera(
        viewWidth=int(_camera.viewWidth),
        viewHeight=int(_camera.viewHeight),
        near=float(_camera.near),
        far=float(_camera.far),
        hfov=float(_camera.hfov),
        vfov=float(_camera.vfov),
        position=onp.asarray(_camera.position).tolist(),
        target=onp.asarray(_camera.target).tolist(),
        up=onp.asarray(_camera.up).tolist(),
    )
    _py_cameras.append(camera)
    # process instances
    instances = jax.tree_util.tree_map(lambda field: field[i], batched_instances)
    instances_ids = []
    for instance in instances:
      # Flip the second axis and swap the first two axes.
      # NOTE(review): presumably converts JaxRenderer's texture layout to
      # PyTinyRenderer's (width, height, channels); confirm.
      _texture = onp.asarray(instance.model.diffuse_map)[:, ::-1, :].swapaxes(0, 1)

      _model_id = _py_scenes[i].create_mesh(
          onp.asarray(instance.model.verts).flatten().tolist(),  # vertices
          onp.asarray(instance.model.norms).flatten().tolist(),  # normals
          onp.asarray(instance.model.uvs, dtype=float).flatten().tolist(),  # uvs
          onp.asarray(instance.model.faces, dtype=int).flatten().tolist(),  # indices
          (_texture * 255).astype(int).flatten().tolist(),  # texture
          _texture.shape[0],  # texture_width
          _texture.shape[1],  # texture_height
          1.,  # texture_scaling
      )
      _instance_id = _py_scenes[i].create_object_instance(_model_id)
      _py_scenes[i].set_object_position(_instance_id, onp.asarray(instance.transform[:3, 3]).tolist())
      _py_scenes[i].set_object_orientation(_instance_id, onp.asarray(R.from_matrix(instance.transform[:3, :3]).as_quat()).tolist())
      _py_scenes[i].set_object_local_scaling(_instance_id, onp.asarray(instance.local_scaling).tolist())

      instances_ids.append(_instance_id)

    _py_instances.append(instances_ids)

  return _py_scenes, _py_instances, _py_cameras, _py_targets


# Bind the pre-loaded batch so callers only choose the number of frames.
prepare = partial(
    prepare_data,
    batched_instances=_batched_instances,
    batched_camera=_batched_camera,
    batched_target=_batched_target,
)

In [None]:
#@title Benchmarking
num_frames = [30, 300, 3000]

def py_render(scenes, instances, cameras, targets):
  """Render every prepared PyTinyRenderer scene once, discarding the images.

  A light whose shadow map is centred on the frame's target is built for each
  frame, inside the timed region (as for the JaxRenderer runs).
  """
  for scene, scene_instance_ids, camera, target in zip(scenes, instances, cameras, targets):
    light = pytinyrenderer.TinyRenderLight(shadowmap_center=target)
    _ = scene.get_camera_image(scene_instance_ids, light, camera)

timings = onp.zeros((len(num_frames), 2), dtype=int)
scenes, instances, cameras, targets = prepare(FRAMES_AVAILABLE)

for i, num_frame in tqdm(enumerate(num_frames), total=len(num_frames)):
  # Only FRAMES_AVAILABLE distinct frames exist, so larger counts re-render them
  # `batches` times. The FPS cell divides by `num_frame`, so it must be exact.
  assert num_frame % FRAMES_AVAILABLE == 0, f"{num_frame} is not a multiple of {FRAMES_AVAILABLE}"
  batches = num_frame // FRAMES_AVAILABLE

  def f(batches: int = batches):
    for _ in range(batches):
      py_render(scenes, instances, cameras, targets)

  ts = benchmark(f)
  # `benchmark` returns timings sorted ascending: fastest first, slowest last.
  timings[i, 0] = ts[0]
  timings[i, 1] = ts[-1]


pd.DataFrame(timings, index=num_frames, columns=("min", "max"))

  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,min,max
30,9377981409,9800225626
300,96434341083,98666547654
3000,979216455444,985778668595


In [None]:
# Frames per second for each frame count; the fastest run (min time) gives max fps.
pd.DataFrame(
    timings / 10**9 / onp.asarray(num_frames)[:, None],  # seconds per frame
    index=num_frames,
    columns=("max fps", "min fps"),
).rdiv(1)

Unnamed: 0,max fps,min fps
30,3.198983,3.061154
300,3.110925,3.040544
3000,3.063674,3.043279


## Experiment - JaxRenderer

In [None]:
#@title ### Render function
def profile_compile(
    batched_instances,
    batched_camera,
    batched_target,
    loop_unroll: int = 1,
    width: Optional[int] = None,
    height: Optional[int] = None,
):
  """Build and ahead-of-time compile a batched JaxRenderer render function.

  The per-frame renderer is `jax.vmap`-ed over the leading axis of the batched
  inputs, then wrapped in `jax.jit`. It is lowered with the given batched inputs
  as example arguments and compiled. Nothing is rendered here; call the
  returned executable to render.

  Parameters:
    - batched_instances, batched_camera, batched_target: per-frame instances,
      cameras and shadow-centre targets, batched along the leading axis. They
      are the example arguments for `lower`, so the executable is specialised
      to their structure, shapes and dtypes.
    - loop_unroll: int, default 1. Forwarded to `Renderer.get_camera_image`.
    - width, height: Optional[int], default None. Canvas size in pixels; when
      None, the notebook-level `canvas_width` / `canvas_height` are used.

  Return:
    The compiled executable. Called with batched inputs of the same
    structure, it returns `uint8` frames already passed through
    `transpose_for_display`.
  """
  if width is None:
    width = canvas_width
  if height is None:
    height = canvas_height

  @jax.default_matmul_precision("float32")
  def render_instances(
    instances: list[Instance],
    width: int,
    height: int,
    camera: Camera,
    light: Optional[Light] = None,
    shadow: Optional[Shadow] = None,
    camera_target: Optional[jp.ndarray] = None,
    enable_shadow: bool = True,
  ) -> jp.ndarray:
    """Renders an RGB array of sequence of instances.

    Uses a fixed default light when `light` is None. When shadows are enabled
    and `shadow` is None, a shadow centred on `camera_target` is used, which
    must then be given.

    Rendered result is not transposed with `transpose_for_display`; it is in
    floating numbers in [0, 1], not `uint8` in [0, 255].
    """
    if light is None:
      direction = jp.array([0.57735, -0.57735, 0.57735])
      light = Light(
          direction=direction,
          ambient=0.8,
          diffuse=0.8,
          specular=0.6,
      )
    if shadow is None and enable_shadow:
      assert camera_target is not None, 'camera_target is None'
      shadow = Shadow(centre=camera_target)
    elif not enable_shadow:
      shadow = None

    img = Renderer.get_camera_image(
      objects=instances,
      light=light,
      camera=camera,
      width=width,
      height=height,
      shadow_param=shadow,
      loop_unroll=loop_unroll,
    )
    arr = jax.lax.clamp(0., img, 1.)

    return arr

  # `enable_shadow` drives a Python `if`, and width/height are plain Python ints
  # (the canvas size), so all three are marked static.
  render_single = jax.jit(
    render_instances,
    static_argnames=("width", "height", "enable_shadow"),
    inline=True,
  )

  def _render(instances, camera, target) -> jp.ndarray:
    """Render one frame and convert it to display-ready `uint8`."""
    img = render_single(
      instances=instances,
      width=width,
      height=height,
      camera=camera,
      camera_target=target,
    )
    # `render_instances` clamps to [0, 1]; scale to [0, 255] before casting.
    return transpose_for_display((img * 255).astype(jp.uint8))

  _render_batch = jax.jit(jax.vmap(jax.jit(_render, inline=True)))
  _render_batch_lowered = _render_batch.lower(batched_instances, batched_camera, batched_target)
  return _render_batch_lowered.compile()

### Branchless + loop unroll

In [None]:
#@title loop unroll options
# Unroll factors to benchmark; larger values (e.g. 40) ran out of VRAM.
loop_unroll_cases: list[int] = [
    1, 2, 4, 16, 32,
]

In [None]:
#@title benchmarking
import time

# render 30 frames (FRAMES_AVAILABLE) per batch
batched_instances = _batched_instances
batched_camera = _batched_camera
batched_target = _batched_target

timings = onp.zeros((len(loop_unroll_cases), 4), dtype=int)
column_indices = ["compilation (min)", "compilation (max)", "execution (min)", "execution (max)"]
row_indices = loop_unroll_cases.copy()

for i, loop_unroll in tqdm(enumerate(loop_unroll_cases), total=len(loop_unroll_cases)):
  # Time lowering + compilation. Compiling is expensive, so it is measured once
  # per case, and min and max both hold that single sample.
  _t_compile_start = time.perf_counter_ns()
  _render_batch_compiled = profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll)
  _t_compile_end = time.perf_counter_ns()
  timings[i, 0] = timings[i, 1] = _t_compile_end - _t_compile_start

  # `block_until_ready` makes each timing include the full device execution.
  ts = benchmark(lambda: _render_batch_compiled(batched_instances, batched_camera, batched_target).block_until_ready(), repeat=7)
  # only take the minimum/maximum time
  timings[i, 2] = ts[0]
  timings[i, 3] = ts[-1]

  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# Raw timings in nanoseconds, one row per loop-unroll factor.
pd.DataFrame(data=timings, columns=column_indices, index=row_indices)

Unnamed: 0,compilation (min),compilation (max),execution (min),execution (max)
1,0,0,3499834290,3564664037
2,0,0,4018799304,4091981912
4,0,0,4512406777,4541740212
16,0,0,4879007576,4906181689
32,0,0,4355216093,4397172162


In [None]:
# Frames per second per unroll factor; the fastest run (min time) gives max fps.
pd.DataFrame(
    timings[:, 2:] / 10**9 / FRAMES_AVAILABLE,  # seconds per frame
    index=loop_unroll_cases,
    columns=("max fps", "min fps"),
).rdiv(1)

Unnamed: 0,max fps,min fps
1,8.571834,8.41594
2,7.464916,7.331411
4,6.648337,6.605398
16,6.148791,6.114735
32,6.888292,6.822567


### Unroll = 1, varying number of frames

In [None]:
#@title Benchmarking
loop_unroll: int = 1
num_frames = [30, 300, 1500, 3000]
frames_per_batch = FRAMES_AVAILABLE

num_frames_timings = onp.zeros((len(num_frames), 2), dtype=int)
column_indices = ["execution (min)", "execution (max)"]
row_indices = num_frames.copy()

batched_instances = _batched_instances
batched_camera = _batched_camera
batched_target = _batched_target

# The executable depends only on the batched inputs and `loop_unroll`, none of
# which change between iterations, so compile once outside the loop.
_render_batch_compiled = profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll)

for i, num_frame in tqdm(enumerate(num_frames), total=len(num_frames)):
  # The FPS cell divides by `num_frame`, so exactly that many frames must be rendered.
  assert num_frame % frames_per_batch == 0, f"{num_frame} is not a multiple of {frames_per_batch}"
  times = num_frame // frames_per_batch

  def _render_multiple_times(times: int = times):
    for _ in range(times):
      _render_batch_compiled(batched_instances, batched_camera, batched_target).block_until_ready()

  # benchmark
  ts = benchmark(_render_multiple_times)
  # only take the minimum/maximum time
  num_frames_timings[i, 0] = ts[0]
  num_frames_timings[i, 1] = ts[-1]


pd.DataFrame(num_frames_timings, index=row_indices, columns=column_indices)

  0%|          | 0/4 [00:00<?, ?it/s]

Unnamed: 0,execution (min),execution (max)
30,3521039139,3599547950
300,35870824734,36252594237
1500,179091632579,179552647623
3000,358249628419,358742021206


In [None]:
# Frames per second for each total frame count; the fastest run gives max fps.
pd.DataFrame(
    num_frames_timings / 10**9 / onp.array(num_frames)[:, None],  # seconds per frame
    index=row_indices,
    columns=("max fps", "min fps"),
).rdiv(1)

Unnamed: 0,max fps,min fps
30,8.520212,8.33438
300,8.363343,8.27527
1500,8.375601,8.354096
3000,8.374049,8.362555


In [None]:
#@title # Terminating Colab Automatically
# Release the Colab runtime when the notebook finishes. Only a failed import of
# `google.colab` means "not in Colab"; errors raised by `unassign()` itself are
# no longer misreported as such.
try:
  from google.colab import runtime
except ModuleNotFoundError:
  print("Not in Colab, skip termination")
else:
  print("Terminating Colab")
  runtime.unassign()