# Render a HyperNeRF video!

**Author**: [Keunhong Park](https://keunhong.com)

[[Project Page](https://hypernerf.github.io)]
[[Paper](https://arxiv.org/abs/2106.13228)]
[[GitHub](https://github.com/google/hypernerf)]

This notebook renders a video using the test cameras generated in the capture processing notebook.

You can also load your own custom cameras by modifying the code slightly.

### Instructions

1. Convert a video into our dataset format using the [capture processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).
2. Train a HyperNeRF model using the [training notebook](https://colab.sandbox.google.com/github/google/hypernerf/blob/main/notebooks/HyperNeRF_Training.ipynb).
3. Run this notebook!


### Notes
 * Please report issues on the [GitHub issue tracker](https://github.com/google/hypernerf/issues).

## Environment Setup

In [None]:
!pip install flax immutabledict mediapy
!pip install git+https://github.com/google/hypernerf

In [1]:
# @title Configure notebook runtime
# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`.
# @markdown You will have to use a smaller batch size on GPU.

import jax

# Accelerator backend, selected through the Colab form field.
runtime_type = 'gpu'  # @param ['gpu', 'tpu']

# The Colab TPU backend is set up explicitly; nothing extra is done for 'gpu'.
if runtime_type == 'tpu':
  import jax.tools.colab_tpu as colab_tpu
  colab_tpu.setup_tpu()

# Show which devices JAX ended up with.
print('Detected Devices:', jax.devices())

Detected Devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]


In [None]:
# @title Mount Google Drive
# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.

# `google.colab` can only be imported inside a Colab runtime. Guard the import so
# that "Run all" also works on a local kernel, where there is no Drive to mount.
try:
  from google.colab import drive
except ImportError:
  drive = None
  print('google.colab is not available; skipping Google Drive mount.')
else:
  drive.mount('/content/gdrive')

In [2]:
# @title Define imports and utility functions.

import jax
from jax.config import config as jax_config
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import flax
import flax.linen as nn
from flax import jax_utils
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints

from absl import logging
from io import BytesIO
import random as pyrandom
import numpy as np
import PIL
import IPython
import tempfile
import imageio
import mediapy
from IPython.display import display, HTML
from base64 import b64encode


# Monkey patch logging.
def myprint(msg, *args, **kwargs):
  """Prints a logging-style message to stdout.

  Args:
    msg: Message text, optionally containing %-style format specifiers.
    *args: Values interpolated into `msg` with the `%` operator.
    **kwargs: Accepted for call compatibility and ignored.
  """
  # Interpolate only when arguments were given; otherwise a message containing a
  # literal '%' (e.g. "100%") would raise on `msg % ()`.
  print(msg % args if args else msg)

# Route info, warn and error log calls through `myprint` so they are printed
# in the cell output.
logging.info = logging.warn = logging.error = myprint

In [3]:
# @title Model and dataset configuration
# @markdown Change the directories to where you saved your capture and experiment.

import os

import sys


def _first_existing_dir(candidates, description):
  """Returns the first path in `candidates` that exists.

  Args:
    candidates: Directory paths to probe, in priority order.
    description: Human-readable name of what is being located, used in the
      error message.

  Returns:
    The first candidate for which `os.path.exists` is true.

  Raises:
    FileNotFoundError: If none of the candidates exists.
  """
  for candidate in candidates:
    if os.path.exists(candidate):
      return candidate
  # `raise NotImplemented` is not a valid raise (NotImplemented is not an
  # exception); fail with a real exception that says what was looked for.
  raise FileNotFoundError(
      f'Could not find the {description}; tried: {list(candidates)}')


# Locate the local project checkout and put it first on `sys.path`, presumably
# so that the `hypernerf` imports below resolve to that checkout.
project_root = _first_existing_dir(
    [os.path.expanduser('~/hypernerf-barf/'),
     os.path.expanduser('~/3d_cv/repos/hypernerf_barf/')],
    'project checkout')
sys.path.insert(0, project_root)

# Locate the directory that holds the raw captures.
data_root = _first_existing_dir(
    ['/hdd/zhiwen/data/hypernerf/raw/',
     '/home/zwyan/3d_cv/data/hypernerf/raw/'],
    'dataset root')

from pathlib import Path
from pprint import pprint
import gin
from IPython.display import display, Markdown

from hypernerf import models
from hypernerf import modules
from hypernerf import warping
from hypernerf import datasets
from hypernerf import configs


# @markdown The working directory where the trained model is.
train_dir = os.path.join(project_root, 'experiments/spec_exp01')  # @param {type: "string"}

# @markdown The directory to the dataset capture.
data_dir = os.path.join(data_root, 'americano/')  # @param {type: "string"}

checkpoint_dir = Path(train_dir, 'checkpoints')
checkpoint_dir.mkdir(exist_ok=True, parents=True)

# Read the gin config stored with the experiment and apply it. The file is only
# read: writing the same text straight back would be a no-op that risks
# truncating the training config if interrupted.
config_path = Path(train_dir, 'config.gin')
with open(config_path, 'r') as f:
  logging.info('Loading config from %s', config_path)
  config_str = f.read()
gin.parse_config(config_str)

# Instantiate the config objects after the gin config has been parsed.
exp_config = configs.ExperimentConfig()
train_config = configs.TrainConfig()
eval_config = configs.EvalConfig()
spec_config = configs.SpecularConfig()

# Show the effective gin configuration as formatted markdown.
display(Markdown(
    gin.config.markdown(gin.config_str())))

Loading config from /home/zwyan/3d_cv/repos/hypernerf_barf/experiments/spec_exp01/config.gin
Saving config to /home/zwyan/3d_cv/repos/hypernerf_barf/experiments/spec_exp01/config.gin


#### Macros:

    batch_size = 512
    CONSTANT_ELASTIC_LOSS_SCHED = {'type': 'constant', 'value': %elastic_init_weight}
    data_dir = '/hdd/zhiwen/data/hypernerf/raw/americano/'
    DEFAULT_LR_SCHEDULE = \
        {'final_value': %final_lr,
         'initial_value': %init_lr,
         'num_steps': %max_steps,
         'type': 'exponential'}
    elastic_init_weight = 0.01
    final_lr = 1e-05
    hyper_num_dims = 4
    hyper_point_max_deg = 1
    hyper_point_min_deg = 0
    hyper_sheet_max_deg = 6
    hyper_sheet_min_deg = 0
    image_scale = 4
    init_lr = 0.001
    max_steps = 250000
    spatial_point_max_deg = 8
    spatial_point_min_deg = 0
    warp_max_deg = 4
    warp_min_deg = 0
    
#### Parameters for ExperimentConfig:

    ExperimentConfig.datasource_cls = @NerfiesDataSource
    ExperimentConfig.image_scale = %image_scale
    ExperimentConfig.random_seed = 0
    ExperimentConfig.subname = None
    
#### Parameters for warp/GLOEmbed:

    warp/GLOEmbed.num_dims = 8
    
#### Parameters for HyperSheetMLP:

    HyperSheetMLP.depth = 6
    HyperSheetMLP.max_deg = %hyper_sheet_max_deg
    HyperSheetMLP.min_deg = %hyper_sheet_min_deg
    HyperSheetMLP.output_channels = %hyper_num_dims
    HyperSheetMLP.skips = (4,)
    HyperSheetMLP.use_residual = False
    HyperSheetMLP.width = 64
    
#### Parameters for NerfiesDataSource:

    NerfiesDataSource.camera_type = 'json'
    NerfiesDataSource.data_dir = %data_dir
    NerfiesDataSource.shuffle_pixels = False
    NerfiesDataSource.test_camera_trajectory = 'orbit-mild'
    
#### Parameters for NerfModel:

    NerfModel.activation = @jax.nn.relu
    NerfModel.alpha_channels = 1
    NerfModel.hyper_embed_cls = @hyper/GLOEmbed
    NerfModel.hyper_embed_key = 'appearance'
    NerfModel.hyper_point_max_deg = %hyper_point_max_deg
    NerfModel.hyper_point_min_deg = %hyper_point_min_deg
    NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP
    NerfModel.hyper_sheet_use_input_points = True
    NerfModel.hyper_slice_method = 'bendy_sheet'
    NerfModel.hyper_use_warp_embed = True
    NerfModel.nerf_embed_cls = @nerf/GLOEmbed
    NerfModel.nerf_embed_key = 'appearance'
    NerfModel.nerf_rgb_branch_depth = 1
    NerfModel.nerf_rgb_branch_width = 128
    NerfModel.nerf_skips = (4,)
    NerfModel.nerf_trunk_depth = 8
    NerfModel.nerf_trunk_width = 256
    NerfModel.noise_std = None
    NerfModel.norm_type = 'none'
    NerfModel.num_coarse_samples = 64
    NerfModel.num_fine_samples = 64
    NerfModel.rgb_channels = 3
    NerfModel.spatial_point_max_deg = %spatial_point_max_deg
    NerfModel.spatial_point_min_deg = %spatial_point_min_deg
    NerfModel.use_alpha_condition = False
    NerfModel.use_linear_disparity = False
    NerfModel.use_nerf_embed = False
    NerfModel.use_posenc_identity = False
    NerfModel.use_rgb_condition = False
    NerfModel.use_sample_at_infinity = True
    NerfModel.use_stratified_sampling = True
    NerfModel.use_viewdirs = True
    NerfModel.use_warp = True
    NerfModel.use_white_background = False
    NerfModel.viewdir_max_deg = 4
    NerfModel.viewdir_min_deg = 0
    NerfModel.warp_embed_cls = @warp/GLOEmbed
    NerfModel.warp_embed_key = 'warp'
    NerfModel.warp_field_cls = @SE3Field
    
#### Parameters for SE3Field:

    SE3Field.hyper_depth = 0
    SE3Field.hyper_init = None
    SE3Field.hyper_width = 0
    SE3Field.max_deg = %warp_max_deg
    SE3Field.min_deg = %warp_min_deg
    SE3Field.norm = None
    SE3Field.num_hyper_dims = 0
    SE3Field.pivot_depth = 0
    SE3Field.pivot_width = 128
    SE3Field.rotation_depth = 0
    SE3Field.rotation_width = 128
    SE3Field.skips = (4,)
    SE3Field.translation_depth = 0
    SE3Field.translation_width = 128
    SE3Field.trunk_depth = 6
    SE3Field.trunk_width = 128
    SE3Field.use_posenc_identity = False
    
#### Parameters for SpecularConfig:

    SpecularConfig.screw_input_mode = 'rotation'
    
#### Parameters for TrainConfig:

    TrainConfig.background_loss_weight = 1.0
    TrainConfig.background_points_batch_size = 16384
    TrainConfig.batch_size = %batch_size
    TrainConfig.curvature_loss_alpha = 0
    TrainConfig.curvature_loss_scale = 0
    TrainConfig.curvature_loss_spacing = 0
    TrainConfig.curvature_loss_weight_schedule = None
    TrainConfig.elastic_loss_type = 'log_svals'
    TrainConfig.elastic_loss_weight_schedule = %CONSTANT_ELASTIC_LOSS_SCHED
    TrainConfig.elastic_reduce_method = 'weight'
    TrainConfig.histogram_every = 100
    TrainConfig.hyper_alpha_schedule = \
        {'schedules': [(1000, ('constant', 0.0)),
                       (0, ('linear', 0.0, %hyper_point_max_deg, 10000))],
         'type': 'piecewise'}
    TrainConfig.hyper_reg_loss_weight = 0.001
    TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg)
    TrainConfig.log_every = 100
    TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE
    TrainConfig.max_steps = %max_steps
    TrainConfig.nerf_alpha_schedule = ('constant', %spatial_point_max_deg)
    TrainConfig.print_every = 10
    TrainConfig.save_every = 1000
    TrainConfig.shuffle_buffer_size = 5000000
    TrainConfig.use_background_loss = False
    TrainConfig.use_curvature_loss = False
    TrainConfig.use_elastic_loss = False
    TrainConfig.use_hyper_reg_loss = False
    TrainConfig.use_warp_reg_loss = True
    TrainConfig.use_weight_norm = False
    TrainConfig.warp_alpha_schedule = \
        {'final_value': %warp_max_deg,
         'initial_value': %warp_min_deg,
         'num_steps': 50000,
         'type': 'linear'}
    TrainConfig.warp_reg_loss_alpha = -2.0
    TrainConfig.warp_reg_loss_scale = 0.001
    TrainConfig.warp_reg_loss_weight = 0.001

In [4]:
# @title Create datasource and show an example.

from hypernerf import datasets
from hypernerf import image_utils

# Throwaway model instance, used only to read the configured attributes that
# decide which metadata the datasource has to provide.
dummy_model = models.NerfModel({}, 0, 0)

# Appearance ids are needed when either embedding is keyed on 'appearance'.
needs_appearance_id = 'appearance' in (
    dummy_model.nerf_embed_key, dummy_model.hyper_embed_key)

datasource = exp_config.datasource_cls(
    data_dir=data_dir,
    image_scale=exp_config.image_scale,
    random_seed=exp_config.random_seed,
    # Enable metadata based on model needs.
    use_warp_id=dummy_model.use_warp,
    use_appearance_id=needs_appearance_id,
    use_camera_id=(dummy_model.nerf_embed_key == 'camera'),
    use_time=(dummy_model.warp_embed_key == 'time'))

# Display the first training image.
first_train_id = datasource.train_ids[0]
mediapy.show_image(datasource.load_rgb(first_train_id))

*** Loading dataset IDs from /home/zwyan/3d_cv/data/hypernerf/raw/americano/dataset.json
Creating datasource of type NerfiesDataSource with use_appearance_id=True, use_camera_id=False, use_warp_id=True, use_depth=False, use_time=False, train_stride=1, val_stride=1


In [5]:
# @title Load model
# @markdown Defines the model and initializes its parameters.

from flax.training import checkpoints
from hypernerf import models
from hypernerf import model_utils
from hypernerf import schedules
from hypernerf import training

rng = random.PRNGKey(exp_config.random_seed)
np.random.seed(exp_config.random_seed + jax.process_index())
devices_to_use = jax.devices()

# Build the schedules from the training config. Only their value at step 0 is
# used below, to create the state that the checkpoint is restored into.
learning_rate_sched = schedules.from_config(train_config.lr_schedule)
nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)
warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
elastic_loss_weight_sched = schedules.from_config(
    train_config.elastic_loss_weight_schedule)
hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)
hyper_sheet_alpha_sched = schedules.from_config(
    train_config.hyper_sheet_alpha_schedule)

# Construct the model and freshly initialised parameters.
rng, key = random.split(rng)
params = {}
model, params['model'] = models.construct_nerf(
      key,
      batch_size=train_config.batch_size,
      embeddings_dict=datasource.embeddings_dict,
      near=datasource.near,
      far=datasource.far,
      screw_input_mode=spec_config.screw_input_mode)

optimizer_def = optim.Adam(learning_rate_sched(0))
optimizer = optimizer_def.create(params)

state = model_utils.TrainState(
    optimizer=optimizer,
    nerf_alpha=nerf_alpha_sched(0),
    warp_alpha=warp_alpha_sched(0),
    hyper_alpha=hyper_alpha_sched(0),
    hyper_sheet_alpha=hyper_sheet_alpha_sched(0))
scalar_params = training.ScalarParams(
    learning_rate=learning_rate_sched(0),
    elastic_loss_weight=elastic_loss_weight_sched(0),
    warp_reg_loss_weight=train_config.warp_reg_loss_weight,
    warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,
    warp_reg_loss_scale=train_config.warp_reg_loss_scale,
    background_loss_weight=train_config.background_loss_weight,
    hyper_reg_loss_weight=train_config.hyper_reg_loss_weight
)

# `restore_checkpoint` hands back the target unchanged when the directory holds
# no checkpoint, which would silently render a randomly initialised model.
# Fail loudly instead.
if checkpoints.latest_checkpoint(checkpoint_dir) is None:
  raise FileNotFoundError(
      f'No checkpoint found in {checkpoint_dir}; train the model first or '
      'point `train_dir` at an experiment that has checkpoints.')

logging.info('Restoring checkpoint from %s', checkpoint_dir)
state = checkpoints.restore_checkpoint(checkpoint_dir, state)
step = state.optimizer.state.step + 1
# Replicate the restored state onto every device used for rendering.
state = jax_utils.replicate(state, devices=devices_to_use)
del params



Restoring checkpoint from /home/zwyan/3d_cv/repos/hypernerf_barf/experiments/spec_exp01/checkpoints
Restoring checkpoint from /home/zwyan/3d_cv/repos/hypernerf_barf/experiments/spec_exp01/checkpoints/checkpoint_250000


In [6]:
# @title Define pmapped render function.

import functools
from hypernerf import evaluation

# The devices reported by JAX; this name is also read by later cells.
devices = jax.devices()


def _model_fn(key_0, key_1, params, rays_dict, extra_params):
  """Applies the model to a batch of rays and gathers the result across devices.

  Args:
    key_0: RNG key handed to the model as its 'coarse' stream.
    key_1: RNG key handed to the model as its 'fine' stream.
    params: Model parameters, passed to `model.apply` under the 'params' key.
    rays_dict: Ray batch forwarded to the model.
    extra_params: Forwarded to the model as `extra_params`.

  Returns:
    The model output, all-gathered over the 'batch' axis.
  """
  variables = {'params': params}
  sampling_rngs = {'coarse': key_0, 'fine': key_1}
  model_out = model.apply(
      variables,
      rays_dict,
      extra_params=extra_params,
      rngs=sampling_rngs,
      mutable=False,
      screw_input_mode=spec_config.screw_input_mode)
  return jax.lax.all_gather(model_out, axis_name='batch')

# Parallel version of `_model_fn` that runs on the devices in `devices_to_use`.
pmodel_fn = jax.pmap(
    # Note rng_keys are useless in eval mode since there's no randomness.
    _model_fn,
    # Every argument is mapped over its leading axis.
    in_axes=(0, 0, 0, 0, 0),
    devices=devices_to_use,
    axis_name='batch',
)

# Take the device count from the same list that `pmap` runs on, so the two
# cannot disagree if `devices_to_use` is ever restricted to a subset.
render_fn = functools.partial(evaluation.render_image,
                              model_fn=pmodel_fn,
                              device_count=len(devices_to_use),
                              chunk=eval_config.chunk)

In [7]:
# @title Load cameras.

from hypernerf import utils

# Sub-directory of `data_dir` that holds the cameras to render.
camera_path = 'fix_camera_93'  # @param {type: 'string'}

camera_dir = Path(data_dir) / camera_path
print(f'Loading cameras from {camera_dir}')

# Collect the camera files, then load them via `parallel_map` with a progress bar.
test_camera_paths = datasource.glob_cameras(camera_dir)
test_cameras = utils.parallel_map(
    datasource.load_camera,
    test_camera_paths,
    show_pbar=True)

Loading cameras from /home/zwyan/3d_cv/data/hypernerf/raw/americano/fix_camera_93


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 526/526 [00:00<00:00, 4557.35it/s]


In [8]:
# @title Render video frames.
from hypernerf import visualization as viz


# Make random seed separate across hosts. The key is derived with `fold_in`
# rather than by adding an integer to the key, and is kept under its own name so
# that re-running this cell does not keep modifying `rng`.
render_rng = random.fold_in(rng, jax.process_index())

num_frames = len(test_cameras)
results = []
for i, camera in enumerate(test_cameras):
  print(f'Rendering frame {i+1}/{num_frames}')
  batch = datasets.camera_to_rays(camera)
  # One id per ray, equal to the frame index; the same ids are used for both
  # the 'appearance' and the 'warp' metadata.
  frame_ids = jnp.ones_like(
      batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * i
  batch['metadata'] = {
      'appearance': frame_ids,
      'warp': frame_ids,
  }

  render = render_fn(state, batch, rng=render_rng)
  rgb = np.array(render['rgb'])
  depth_med = np.array(render['med_depth'])

  # Halve and offset by 0.5; presumably maps a [-1, 1] gradient into [0, 1] for
  # display -- TODO confirm the value range of 'ray_sigma_gradient'.
  sigma_gradient = np.array(render['ray_sigma_gradient'])
  sigma_gradient = sigma_gradient / 2.0 + 0.5

  # Keep the raw outputs for the video cell and preview this frame inline.
  results.append((rgb, depth_med, sigma_gradient))
  depth_viz = viz.colorize(depth_med.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)
  mediapy.show_images([rgb, depth_viz, sigma_gradient])

Rendering frame 1/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192


  return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
  return jax.tree_map(lambda x: x[0], tree)
  ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map)
  ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps)


Rendering took 95.5


Rendering frame 2/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.98


Rendering frame 3/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.46


Rendering frame 4/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.60


Rendering frame 5/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.88


Rendering frame 6/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.80


Rendering frame 7/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192
Rendering took 9.86


Rendering frame 8/526
Rendering: num_batches = 16, num_rays = 128640, chunk = 8192


KeyboardInterrupt: 

In [None]:
# @title Show rendered video.

fps = 30  # @param {type:'number'}


def _compose_frame(rgb, depth, sigma_gradient):
  """Returns one uint8 frame: RGB, colorized depth and sigma gradient joined along axis 1."""
  depth_viz = viz.colorize(
      depth.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)
  panels = [rgb, depth_viz, sigma_gradient]
  return image_utils.image_to_uint8(np.concatenate(panels, axis=1))


# One composed frame per rendered result, in render order.
frames = [_compose_frame(*result) for result in results]

# Save the video under the experiment directory and display it.
mediapy.set_show_save_dir(train_dir)
mediapy.show_video(frames, fps=fps, title="result_{}_with_norm".format(camera_path))