# Benchmark Brax Rendering Example using JaxRenderer backend

Self-link: https://colab.research.google.com/drive/1A7PzhG3vn6oNzrWTxE5E3dmu8xQTcNnH

Benchmarking using A100 on smaller canvas: https://colab.research.google.com/drive/1Mr2pRetdobZdgdtNT_DJuS4YZq3-C4oY

Benchmarking using T4, and baseline (CPU PyTinyrenderer): https://colab.research.google.com/drive/1xhkYNz5WjvUCjQWpp72CLf9SIy3i5PnN

Generate data: https://colab.research.google.com/drive/1c_83TLtb_pOt4OSlWFQgIKzA9DxdBDTp

<details>
  <summary>Other related profiling/benchmarking Colabs</summary>

Profile with pre-generated data: https://colab.research.google.com/drive/12yNBVOdwUqUOBRgmQHF0gl8eMzzIi_BH

Previous profile version (profile branch head): https://colab.research.google.com/drive/1V7gdTY6ZYz7YhJI_LCWLZM035zSDTYoQ

Previous profiling version: https://colab.research.google.com/drive/1BJQG38IHPaUuMKMnNeq0hjrLtqwx4OfU

All-inlines + minibatch loops: https://colab.research.google.com/drive/1NiCTOCxfU0Mvr818Zqt4y63-S20Gn8yb

All-inlines: https://colab.research.google.com/drive/1Wrt3a0yoVPhYUJYCZltFIYoELWD-ciE0

Baseline link: https://colab.research.google.com/drive/17NSjyJL_Ov9D32Mnrs7miZmjiPNf9U6l

</details>

The sample output was produced on an A100 GPU runtime with High-RAM enabled (this was the only combination available).

In [None]:
#@title Reinstall jaxrenderer from tag `v0.3.0`; Install pytinyrenderer
# Uninstall any existing jaxrenderer first, then install the copy pinned to the
# `v0.3.0` tag straight from the GitHub repository.
!pip uninstall jaxrenderer -y -qqq
!pip install git+https://github.com/JoeyTeng/jaxrenderer.git@v0.3.0 -qqq

# pytinyrenderer is imported by the "Imports" cell of this notebook.
!pip install pytinyrenderer --upgrade -qqq

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for jaxrenderer (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 23.6 MB/s eta 0:00:00

In [None]:
#@title ## GPU Spec
# Print the `nvidia-smi` summary so the recorded timings can be tied to the
# GPU and driver of the runtime that produced them.
!nvidia-smi

Mon Jun 12 09:56:56 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    45W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
#@title ## Imports
#@markdown **⚠️ PLEASE NOTE:**

#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

from functools import partial
from typing import Optional
import os
import pickle

import jax
from jax import numpy as jp
import numpy as onp
import pandas as pd
from scipy.spatial.transform import Rotation as R

import pytinyrenderer

from renderer import CameraParameters as Camera
from renderer import LightParameters as Light
from renderer import ModelObject as Instance
from renderer import ShadowParameters as Shadow
from renderer import Renderer, transpose_for_display

# Only when the COLAB_TPU_ADDR environment variable is set: import JAX's Colab
# TPU helper and run its setup routine.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  import jax.tools.colab_tpu as colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title ## Load data; Constants
!rm inputs-30.zip inputs.pickle
!wget https://github.com/JoeyTeng/jaxrenderer/raw/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
!unzip inputs-30.zip

_batched_instances, _batched_camera, _batched_target = pickle.load(open("inputs.pickle", "rb"))

canvas_width: int = 960 #@param {type:"integer"}
canvas_height: int = 540 #@param {type:"integer"}
number_of_frames_to_render: int = 2 #@param {type:"slider", min: 1, max: 30, step: 1}

batched_instances = jax.tree_map(lambda a: jp.asarray(a[:number_of_frames_to_render]).block_until_ready(), _batched_instances)
batched_camera = jax.tree_map(lambda a: jp.asarray(a[:number_of_frames_to_render]).block_until_ready(), _batched_camera)
batched_target = jax.tree_map(lambda a: jp.asarray(a[:number_of_frames_to_render]).block_until_ready(), _batched_target)

rm: cannot remove 'inputs-30.zip': No such file or directory
rm: cannot remove 'inputs.pickle': No such file or directory
--2023-06-12 09:57:04--  https://github.com/JoeyTeng/jaxrenderer/raw/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/JoeyTeng/jaxrenderer/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip [following]
--2023-06-12 09:57:04--  https://raw.githubusercontent.com/JoeyTeng/jaxrenderer/92904b74f4d760cd66b8940ed41c2d854cfbebe9/test_resources/pre-gen-brax/inputs-30.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP req

In [None]:
#@title ## Utility Function
import gc
import itertools
import time
from typing import Any, Callable


def benchmark(
    f: Callable[[], Any],
    repeat: int = 2,
    number: int = 1,
    sort: bool = True,
) -> list[int]:
  """Time `f` in the manner of `timeit.repeat`.

  One measurement calls `f` back-to-back `number` times and records the total
  elapsed wall-clock time; `repeat` such measurements are taken, so `f` runs
  `repeat * number` times overall. The garbage collector is switched off for
  the duration of each measurement and re-enabled afterwards if it was enabled
  before, even when `f` raises.

  Parameters:
    - f: Callable taking no arguments. Its return value is discarded. It must
      be safe to call repeatedly.
    - repeat: int, default 2. Number of measurements to take.
    - number: int, default 1. Number of calls to `f` per measurement.
    - sort: bool, default True. Return the measurements in ascending order
      rather than in the order they were taken.

  Return:
    A list with `repeat` entries, each the total time of one measurement in
    nanoseconds (from `time.perf_counter_ns`).
  """

  def _measure_once() -> int:
    """Run `f` `number` times with the GC paused; return elapsed ns."""
    calls = itertools.repeat(None, number)
    gc_was_enabled = gc.isenabled()
    gc.disable()

    try:
      started: int = time.perf_counter_ns()
      for _ in calls:
        f()
      finished: int = time.perf_counter_ns()
    finally:
      # Restore the collector only if this function was the one to disable it.
      if gc_was_enabled:
        gc.enable()

    return finished - started

  measurements: list[int] = [_measure_once() for _ in range(repeat)]

  return sorted(measurements) if sort else measurements

## Experiment - JaxRenderer

In [None]:
#@title ### Render function
def profile_compile(batched_instances, batched_camera, batched_target, loop_unroll: int = 1):
  """Lower and compile the batched render function for the given inputs.

  The per-frame render function is wrapped in `jax.vmap` and `jax.jit`, lowered
  against the supplied batched arguments and compiled ahead of time. The canvas
  size is read from the module-level `canvas_width` and `canvas_height`.

  Parameters:
    - batched_instances, batched_camera, batched_target: batched inputs that
      `jax.vmap` maps over; also used as the example arguments for lowering.
    - loop_unroll: int, default 1. Forwarded to `Renderer.get_camera_image`.

  Return:
    The object produced by `.lower(...).compile()`; call it with the same three
    batched arguments to render.
  """
  @jax.default_matmul_precision("float32")
  def render_instances(
    instances: list[Instance],
    width: int,
    height: int,
    camera: Camera,
    light: Optional[Light] = None,
    shadow: Optional[Shadow] = None,
    camera_target: Optional[jp.ndarray] = None,
    enable_shadow: bool = True,
  ) -> jp.ndarray:
    """Render one set of instances and clamp the image to [0, 1].

    The result is returned as produced by `Renderer.get_camera_image` apart
    from the clamp: it is not passed through `transpose_for_display` and is
    not converted to `uint8`.
    """
    if light is None:
      # Default light used when the caller does not supply one.
      light = Light(
          direction=jp.array([0.57735, -0.57735, 0.57735]),
          ambient=0.8,
          diffuse=0.8,
          specular=0.6,
      )

    if not enable_shadow:
      # Shadows switched off: discard any shadow parameters passed in.
      shadow = None
    elif shadow is None:
      # Shadows requested without explicit parameters: centre them on the
      # camera target.
      assert camera_target is not None, 'camera_target is None'
      shadow = Shadow(centre=camera_target)

    image = Renderer.get_camera_image(
      objects=instances,
      light=light,
      camera=camera,
      width=width,
      height=height,
      shadow_param=shadow,
      loop_unroll=loop_unroll,
    )

    return jax.lax.clamp(0., image, 1.)

  def _render(instances, camera, target) -> jp.ndarray:
    """Render a single frame as a display-ready `uint8` array."""
    jitted_render = jax.jit(
      render_instances,
      static_argnames=("width", "height", "enable_shadow"),
      inline=True,
    )
    frame = jitted_render(
      instances=instances,
      width=canvas_width,
      height=canvas_height,
      camera=camera,
      camera_target=target,
    )
    frame_u8 = (frame * 255).astype(jp.uint8)

    return transpose_for_display(frame_u8)

  # vmap over the leading (batch) axis of all three arguments, then lower and
  # compile against the concrete inputs.
  render_batch = jax.jit(jax.vmap(jax.jit(_render, inline=True)))
  lowered = render_batch.lower(batched_instances, batched_camera, batched_target)

  return lowered.compile()

### Branchless + loop unroll

In [None]:
#@title loop unroll options
# Candidate `loop_unroll` values; the benchmarking cell passes each one to
# `profile_compile` in turn.
loop_unroll_cases: list[int] = [1, 2, 3, 4, 16, 64, 192]  # 192 = 960 / 5

In [None]:
#@title benchmarking
# All timings are nanosecond totals as returned by `benchmark`. The dtype is
# an explicit 64-bit integer: compilation alone takes several seconds, which
# already exceeds the ~2.1 s range of the 32-bit integer that plain `int`
# maps to on some platforms.
timings = onp.zeros((len(loop_unroll_cases), 4), dtype=onp.int64)
column_indices = ["compilation (min)", "compilation (max)", "execution (min)", "execution (max)"]
row_indices = loop_unroll_cases.copy()

for i, loop_unroll in enumerate(loop_unroll_cases):
  # Compilation: time the whole build + lower + compile step.
  ts = benchmark(lambda: profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll))
  # `benchmark` sorts ascending by default: first is the minimum, last the maximum.
  timings[i, 0] = ts[0]
  timings[i, 1] = ts[-1]

  # Execution: compile once outside the timed region, then time only the calls.
  _render_batch_compiled = profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll)

  # `block_until_ready()` makes each timed call wait for the result.
  ts = benchmark(lambda: _render_batch_compiled(batched_instances, batched_camera, batched_target).block_until_ready())
  # only take the minimum/maximum time
  timings[i, 2] = ts[0]
  timings[i, 3] = ts[-1]

In [None]:
# Show the collected timings, one row per `loop_unroll` value; the values are
# the nanosecond totals returned by `benchmark`.
pd.DataFrame(timings, index=row_indices, columns=column_indices)

Unnamed: 0,compilation (min),compilation (max),execution (min),execution (max)
1,4655743276,7128431012,258989703,339372853
2,5355998136,6352320156,260569160,280682752
3,5969125779,7177999421,256503592,284426666
4,6569332576,7875871619,237489787,258422602
16,13811213933,15514422280,229073328,261196325
64,45607638476,47535025952,181243144,268247449
192,153797714268,156491627664,197630345,480351913


### Unroll = 64, number of frames

In [None]:
#@title Benchmarking
from tqdm.auto import tqdm

loop_unroll: int = 64
num_frames = list(range(1, 31))

# Nanosecond timings (see `benchmark`); explicit 64-bit dtype because plain
# `int` is only 32 bits wide on some platforms, i.e. it tops out at ~2.1 s.
num_frames_timings = onp.zeros((len(num_frames), 2), dtype=onp.int64)
column_indices = ["execution (min)", "execution (max)"]
row_indices = num_frames.copy()

for i, num_frame in tqdm(enumerate(num_frames), total=len(num_frames)):
  # prepare data: keep the first `num_frame` entries of every leaf. As in the
  # "Load data" cell, each slice goes through `jp.asarray` so that
  # `block_until_ready()` is available even if the unpickled leaves are not
  # JAX arrays. `jax.tree_util.tree_map` replaces the deprecated
  # `jax.tree_map` alias.
  batched_instances, batched_camera, batched_target = (
      jax.tree_util.tree_map(
          lambda field: jp.asarray(field[:num_frame]).block_until_ready(),
          _tree,
      )
      for _tree in (_batched_instances, _batched_camera, _batched_target)
  )

  # benchmark: compile outside the timed region, then time execution only.
  _render_batch_compiled = profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll)

  ts = benchmark(lambda: _render_batch_compiled(batched_instances, batched_camera, batched_target).block_until_ready())
  # only take the minimum/maximum time (`benchmark` sorts ascending)
  num_frames_timings[i, 0] = ts[0]
  num_frames_timings[i, 1] = ts[-1]


pd.DataFrame(num_frames_timings, index=row_indices, columns=column_indices)

  0%|          | 0/30 [00:00<?, ?it/s]

Unnamed: 0,execution (min),execution (max)
1,134303425,428374016
2,215165081,535381431
3,280743384,575425603
4,352555107,640407796
5,415671180,710903950
6,483401680,797335161
7,549624319,834071602
8,615964585,916684731
9,682909404,980797175
10,749187381,1036435361


### Unroll = 1, number of frames

In [None]:
#@title Benchmarking
from tqdm.auto import tqdm

loop_unroll: int = 1
num_frames = list(range(1, 16))

# Nanosecond timings (see `benchmark`); explicit 64-bit dtype because plain
# `int` is only 32 bits wide on some platforms, i.e. it tops out at ~2.1 s.
num_frames_timings = onp.zeros((len(num_frames), 2), dtype=onp.int64)
column_indices = ["execution (min)", "execution (max)"]
row_indices = num_frames.copy()

for i, num_frame in tqdm(enumerate(num_frames), total=len(num_frames)):
  # prepare data: keep the first `num_frame` entries of every leaf. As in the
  # "Load data" cell, each slice goes through `jp.asarray` so that
  # `block_until_ready()` is available even if the unpickled leaves are not
  # JAX arrays. `jax.tree_util.tree_map` replaces the deprecated
  # `jax.tree_map` alias.
  batched_instances, batched_camera, batched_target = (
      jax.tree_util.tree_map(
          lambda field: jp.asarray(field[:num_frame]).block_until_ready(),
          _tree,
      )
      for _tree in (_batched_instances, _batched_camera, _batched_target)
  )

  # benchmark: compile outside the timed region, then time execution only.
  _render_batch_compiled = profile_compile(batched_instances, batched_camera, batched_target, loop_unroll=loop_unroll)

  # Five measurements per frame count here.
  ts = benchmark(lambda: _render_batch_compiled(batched_instances, batched_camera, batched_target).block_until_ready(), repeat=5)
  # only take the minimum/maximum time (`benchmark` sorts ascending)
  num_frames_timings[i, 0] = ts[0]
  num_frames_timings[i, 1] = ts[-1]


pd.DataFrame(num_frames_timings, index=row_indices, columns=column_indices)

  0%|          | 0/15 [00:00<?, ?it/s]

Unnamed: 0,execution (min),execution (max)
1,161219556,521410760
2,258736647,333003456
3,340829376,387884737
4,412454606,460350032
5,481284209,507172379
6,555310795,594454967
7,624655384,655084040
8,693651470,730945688
9,762756986,798568660
10,832026866,873455576


In [None]:
#@title # Terminating Colab Automatically
# Once everything above has run, call `runtime.unassign()` when the Colab
# helper is importable (presumably this releases the runtime -- confirm
# against the Colab docs); otherwise just report that we are not in Colab.
try:
  from google.colab import runtime
except ImportError:
  # ImportError also covers ModuleNotFoundError (its subclass), so this handles
  # both a missing `google.colab` package and one that lacks `runtime`.
  print("Not in Colab, skip termination")
else:
  # Kept outside the `try` so an error raised here is not mistaken for
  # "not in Colab".
  print("Terminating Colab")
  runtime.unassign()