Copyright 2021 Google LLC.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# iNeRF implementation.

Implementation of "iNeRF: Inverting Neural Radiance Fields for Pose Estimation"
Website: https://yenchenlin.me/inerf/

Note: this implementation needs to be integrated with the public version of jaxnerf: https://github.com/google-research/google-research/tree/master/jaxnerf

In [None]:
# TODO(yenchenl): add pip installs.

In [None]:
import colabtools
import functools
import gc
import time
from absl import app
from absl import flags
from flax import jax_utils
from flax import nn
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints
import getpass
import jax
from jax import config
from jax import random
from jax import numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax.example_libraries import optimizers
import matplotlib.pyplot as plt
import PIL
from PIL import Image as PilImage
import numpy as np
from scipy.spatial.transform import Rotation as R
from six.moves import reload_module
import yaml

In [None]:
# Report the platform JAX is running on (e.g. cpu/gpu/tpu) and the device
# count. `jax.default_backend()` replaces the deprecated
# `jax.lib.xla_bridge.get_backend().platform`.
print(jax.default_backend(), jax.device_count())

In [None]:
# TODO(yenchenl): jaxnerf imports may need attention
from jaxnerf.nerf import datasets
from jaxnerf.nerf import model_utils
from jaxnerf.nerf import models
from jaxnerf.nerf import utils

In [None]:
# TODO(yenchenl): update paths
# Location of the trained jaxnerf experiment (checkpoints are read from here).
flags.DEFINE_string(
    name="train_dir",
    default="/path_to/jaxnerf_models/blender/lego/",
    help="Experiment path.",
)
flags.mark_flag_as_required("train_dir")
# Location of the dataset the test view is taken from.
flags.DEFINE_string(
    name="data_dir",
    default="/path_to/datasets/nerf/nerf_synthetic/lego",
    help="Data path.",
)

In [None]:
FLAGS = flags.FLAGS
flags.DEFINE_integer("n_gpus", 1, "Number of gpus per train worker.")
flags.DEFINE_integer("n_gpus_eval", 1, "Number of gpus per eval worker.")
flags.mark_flag_as_required("data_dir")
flags.DEFINE_enum("config", "blender", ["blender", "llff"],
                  "Choice of the reusable full configuration.")
flags.DEFINE_bool("is_train", True, "The job is in the training mode.")
flags.DEFINE_bool("use_tpu", False, "Whether to use tpu for training.")
flags.DEFINE_bool("use_tpu_eval", False, "Whether to use tpu for evaluation.")
# Adjacent string literals are concatenated, so each piece needs its own
# trailing space to avoid run-together words in --help output.
flags.DEFINE_integer("render_every", 0,
                     "the interval in optimization steps between rendering "
                     "a validation example. 0 is recommended if using "
                     "parallel train and eval jobs.")
flags.DEFINE_integer(
    "chunk", None, "the size of chunks for evaluation inferences, set to "
    "the value that fits your GPU/TPU memory.")

In [None]:
# Dataset flags. Enum choices come from the keys of jaxnerf's dataset registry.
flags.DEFINE_enum("dataset", "blender",
                  list(datasets.dataset_dict),
                  "The type of dataset fed to nerf.")
flags.DEFINE_bool("image_batching", False,
                  "sample rays in a batch from different images.")
flags.DEFINE_bool(
    "white_bkgd", True, "using white color as default background. "
    "(used in the blender dataset only)")
flags.DEFINE_integer("batch_size", 1024,
                     "the number of rays in a mini-batch (for training).")
flags.DEFINE_integer(
    "factor", 4, "the downsample factor of images, 0 for no downsample.")
flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
flags.DEFINE_bool(
    "render_path", False, "render generated path if set true. "
    "(used in the llff dataset only)")
flags.DEFINE_integer(
    "llffhold", 8, "will take every 1/N images as LLFF test set. "
    "(used in the llff dataset only)")

In [None]:
# Model Flags
flags.DEFINE_enum("model", "nerf", list(models.model_dict),
                  "name of model to use.")
flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
flags.DEFINE_integer("net_depth_condition", 1,
                     "depth of the second part of MLP.")
flags.DEFINE_integer("net_width_condition", 128,
                     "width of the second part of MLP.")
flags.DEFINE_enum("activation", "relu", ["relu"],
                  "activation function used in MLP.")
flags.DEFINE_integer(
    "skip_layer", 4, "add a skip connection to the output vector of every "
    "skip_layer layers.")
flags.DEFINE_integer("alpha_channel", 1, "the number of alpha channels.")
flags.DEFINE_integer("rgb_channel", 3, "the number of rgb channels.")
flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
flags.DEFINE_integer("deg_point", 10,
                     "Degree of positional encoding for points.")
flags.DEFINE_integer("deg_view", 4,
                     "degree of positional encoding for viewdirs.")
flags.DEFINE_integer("n_samples", 64, "the number of samples on each ray.")
flags.DEFINE_integer("n_fine_samples", 128,
                     "the number of samples on each ray for the fine model.")
flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
flags.DEFINE_float(
    "noise_std", None, "std dev of noise added to regularize sigma output. "
    "(used in the llff dataset only)")
flags.DEFINE_bool("lindisp", False,
                  "sampling linearly in disparity rather than depth.")

In [None]:
# Train Flags
flags.DEFINE_float("lr", 5e-4, "Learning rate for training.")
flags.DEFINE_integer("lr_decay", 500,
                     "the number of steps (in 1000s) for exponential "
                     "learning rate decay.")
flags.DEFINE_integer("max_steps", 1000000,
                     "the number of optimization steps.")
flags.DEFINE_integer("save_every", 10000,
                     "the number of steps to save a checkpoint.")
flags.DEFINE_integer("gc_every", 10000,
                     "the number of steps to run python garbage collection.")

# No randomization in eval! `randomized` is already defined with the model
# flags; defining it a second time raises absl's DuplicateFlagError, so only
# its default is switched off here.
FLAGS.set_default("randomized", False)


In [None]:
def compute_pose_error(T_estobject_cam, T_gtobject_cam):
  """Compute scalars for rotation and translation error between two poses.

  The relative transform T_est @ inv(T_gt) is formed. The rotation error is the
  angle of its rotation block; the translation error is the norm of its
  translation column.

  Args:
    T_estobject_cam: (4, 4) estimated homogeneous pose (NumPy or JAX array).
    T_gtobject_cam: (4, 4) ground-truth homogeneous pose.

  Returns:
    (rotation_error, translation_error): the rotation error in degrees and the
    translation error in the same length units as the poses.
  """
  T_est = np.asarray(T_estobject_cam)
  T_gt = np.asarray(T_gtobject_cam)
  T_estobject_gtobject = T_est @ np.linalg.inv(T_gt)
  # cos(angle) = (trace(R) - 1) / 2. Floating-point round-off can push this
  # slightly outside [-1, 1], where arccos returns NaN, so clip first.
  cos_angle = (np.trace(T_estobject_gtobject[:3, :3]) - 1) / 2
  rotation_error = np.arccos(np.clip(cos_angle, -1.0, 1.0))
  translation_error = np.linalg.norm(T_estobject_gtobject[:3, -1])
  return rotation_error * 180 / np.pi, translation_error

## Load dataset.

In [None]:
# TODO(yenchenl): blender config
# Load the jaxnerf Blender config. The builtin is `open` (the previous `Open`
# raised NameError), and `yaml.safe_load` is used because `yaml.load` without an
# explicit Loader is rejected by PyYAML >= 6. An empty file yields {}.
with open('/path_to/nerf/blender.yaml') as blender_cfg_file:
  blender_cfg = yaml.safe_load(blender_cfg_file) or {}

In [None]:
rng = random.PRNGKey(20200823)
# Shift the numpy random seed by the process index to shuffle data loaded by
# different hosts. `jax.process_index()` replaces the deprecated
# `jax.host_id()`.
np.random.seed(20201473 + jax.process_index())

# Override the flag values with the loaded config (jaxnerf's convention of
# updating the flag object's __dict__ directly).
if FLAGS.config is not None:
  FLAGS.__dict__.update(blender_cfg)
if FLAGS.batch_size % jax.device_count() != 0:
  raise ValueError("Batch size must be divisible by the number of devices.")
dataset = datasets.get_dataset("test", FLAGS)

# Load pre-trained model.

In [None]:
# Build a freshly initialised model wrapped in a TrainState. This `state` is
# the target passed to checkpoints.restore_checkpoint, which fills in the
# trained weights.
rng, key = random.split(rng)
init_model, init_state = models.get_model(key, FLAGS)

dummy_optimizer_def = optim.Adam(FLAGS.lr)
dummy_optimizer = dummy_optimizer_def.create(init_model)
state = model_utils.TrainState(
    step=0,
    optimizer=dummy_optimizer,
    model_state=init_state,
)

# `state` now holds references to everything that is needed.
del init_model, init_state

In [None]:
# Replace the placeholder state with the checkpoint found in train_dir, then
# take the trained NeRF model (the optimizer's target) out of it.
state = checkpoints.restore_checkpoint(ckpt_dir=FLAGS.train_dir, target=state)
nerf_model = state.optimizer.target

In [None]:
idx = 0  # Index of the test view whose pose is to be recovered.
test_image = dataset.images[idx]
test_pose = dataset.camtoworlds[idx]
# One RGB row per pixel: (dataset.resolution, 3).
test_pixels = test_image.reshape([dataset.resolution, 3])
print(f"Pixels/pose shapes: {test_pixels.shape}, {test_pose.shape}")

# Set the perturbation.

In [None]:
# Perturb the ground-truth pose to obtain the initial estimate iNeRF refines.
# Plain NumPy is used throughout: `jax.ops.index_add` has been removed from JAX.
perturbation = np.zeros((4, 4))
pred_pose = np.array(test_pose)

USE_ROTATION_PERTURBATION = True
USE_TRANSLATION_PERTURBATION = True

if USE_ROTATION_PERTURBATION:
  magnitude = 30.0  # Set the magnitude of rotation perturbation (degree).
  magnitude_rad = magnitude / 180.0 * np.pi
  # Random rotation axis, normalized to unit length.
  direction = np.random.randn(3)
  direction_norm = np.linalg.norm(direction)
  eps = 1e-6
  if direction_norm < eps:  # Prevents divide-by-0.
    direction_norm = eps
  direction = direction / direction_norm
  perturbed_rotvec = direction * magnitude_rad
  pred_rot_mat = R.from_rotvec(perturbed_rotvec).as_matrix()
  delta = np.eye(4)
  delta[:3, :3] = pred_rot_mat
  # Left-multiplying by a pure rotation rotates about the world origin, so both
  # the camera orientation and its position change.
  pred_pose = delta @ pred_pose
if USE_TRANSLATION_PERTURBATION:
  magnitude = 0.05  # Offset added to each of x, y and z.
  perturbation[:3, -1] += magnitude
  pred_pose = pred_pose + perturbation


print(f"Initial pose error: {compute_pose_error(pred_pose, test_pose)}")

In [None]:
pred_pose_init = pred_pose.copy()  # Snapshot of the starting pose estimate.

In [None]:
def RPtoSE3(R: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
  """Assemble a homogeneous transform from a rotation and a translation.

  Args:
    R: (3, 3) An orthonormal rotation matrix.
    p: A 3-vector offset (any shape with 3 elements).

  Returns:
    X: (4, 4) matrix [[R, p], [0, 0, 0, 1]] that rotates by R and then
      translates by p.
  """
  upper = jnp.concatenate([R, jnp.reshape(p, (3, 1))], axis=1)
  last_row = jnp.array([[0.0, 0.0, 0.0, 1.0]])
  return jnp.concatenate([upper, last_row], axis=0)

def DecomposeScrew(V: jnp.ndarray):
  """Decompose a screw V into a normalized axis and a magnitude.

  Following Modern Robotics, the magnitude is |w| when the rotational part w
  is non-zero, otherwise |v|; an all-zero screw gives magnitude 0 and S = V.

  The zero cases are written so that gradients stay finite. `jnp.where`
  differentiates both branches, so an unselected `V / 0` or `norm(0)` would
  otherwise inject NaNs into the gradient.

  Args:
    V: (6,) A spatial vector describing a screw motion, [w, v].

  Returns:
    S: (6,) A unit screw axis.
    theta: An angle of rotation such that S * theta = V.
  """
  w, v = jnp.split(V, 2)
  w_is_zero = jnp.allclose(w, jnp.zeros_like(w))
  v_is_zero = jnp.allclose(v, jnp.zeros_like(v))
  both_zero = jnp.logical_and(w_is_zero, v_is_zero)

  def _safe_norm(x, is_zero):
    # Euclidean norm whose sqrt never sees 0 when the value is unused, keeping
    # its (infinite) derivative at 0 out of the backward pass.
    return jnp.sqrt(jnp.where(is_zero, 1.0, jnp.sum(x * x)))

  dtheta = jnp.where(
      both_zero, 0.0,
      jnp.where(w_is_zero, _safe_norm(v, v_is_zero), _safe_norm(w, w_is_zero)))
  # Divide by 1 instead of 0 when V is zero; that branch is discarded anyway.
  safe_dtheta = jnp.where(both_zero, 1.0, dtheta)
  S = jnp.where(both_zero, V, V / safe_dtheta)

  return (S, dtheta)

def Skew(w: jnp.ndarray) -> jnp.ndarray:
  """Return the cross-product ("skew-symmetric") matrix of a 3-vector.

  Modern Robotics Eqn 3.30.

  Args:
    w: A 3-vector (any shape with 3 elements).

  Returns:
    W: (3, 3) matrix satisfying W @ v == cross(w, v) for every 3-vector v.
  """
  wx, wy, wz = jnp.reshape(w, (3,))
  return jnp.array([
      [0.0, -wz, wy],
      [wz, 0.0, -wx],
      [-wy, wx, 0.0],
  ])

def ExpSO3(w: jnp.ndarray, theta: float) -> jnp.ndarray:
  """Map an axis-angle rotation from so(3) to a rotation matrix in SO(3).

  This is Rodrigues' formula (Modern Robotics Eqn 3.51):
  R = I + sin(theta) [w] + (1 - cos(theta)) [w]^2.

  Args:
    w: (3,) An axis of rotation.
    theta: An angle of rotation.

  Returns:
    R: (3, 3) An orthonormal rotation matrix representing a rotation of
      magnitude theta about axis w.
  """
  W = Skew(w)
  first_order = jnp.sin(theta) * W
  second_order = ((1.0 - jnp.cos(theta)) * W) @ W
  return jnp.eye(3) + first_order + second_order

def ExpSE3(S: jnp.ndarray, theta: float) -> jnp.ndarray:
  """Exponential map from Lie algebra se(3) to Lie group SE(3).

  Modern Robotics Eqn 3.88. With S = [w, v]:
  R = exp([w] theta) and p = G(theta) v, where
  G(theta) = I theta + (1 - cos theta) [w] + (theta - sin theta) [w]^2.

  Args:
    S: (6,) A screw axis of motion, [w, v].
    theta: Magnitude of motion along the screw axis.

  Returns:
    a_X_b: (4, 4) The homogeneous transformation matrix obtained by moving a
      distance theta along the screw axis S.
  """
  w, v = jnp.split(S, 2)
  W = Skew(w)
  R = ExpSO3(w, theta)
  p = (theta * jnp.eye(3) + (1.0 - jnp.cos(theta)) * W +
       (theta - jnp.sin(theta)) * W @ W) @ v
  return RPtoSE3(R, p)

# iNeRF training.

In [None]:
def train_step_exp(screw_delta, test_pixels, hwf, batch_size, nerf_model,
                   rng_key=None):
  """Photometric loss and its gradient with respect to a pose-delta screw.

  The candidate pose is `ExpSE3(*DecomposeScrew(screw_delta)) @
  pred_pose_init`, where `pred_pose_init` is the notebook-level initial pose.
  `batch_size` random pixels are rendered from that pose with `nerf_model` and
  compared to the ground truth with a mean squared error.

  Args:
    screw_delta: (6,) screw of the delta pose, relative to the initial pose.
    test_pixels: (H*W, 3) ground-truth image pixels, in row-major order.
    hwf: (height, width, focal length) of the camera.
    batch_size: number of rays and pixels to sample.
    nerf_model: NeRF model called as `nerf_model(key_0, key_1, rays)`; the
      first element of its last output is used as the predicted RGB.
    rng_key: optional PRNG key to derive the model keys from. Defaults to the
      previously hard-coded key, so every call then uses the same keys.

  Returns:
    (loss, grad): the scalar MSE loss and its (6,) gradient w.r.t.
    `screw_delta`.

  Raises:
    ValueError: if `test_pixels` does not hold exactly H*W pixels.
  """
  if rng_key is None:
    rng_key = random.PRNGKey(20200823)
  _, key_0, key_1 = random.split(rng_key, 3)

  h, w, f = hwf
  resolution = test_pixels.shape[0]
  if resolution != h * w:
    raise ValueError(
        f"test_pixels has {resolution} pixels, expected {h} * {w}.")

  # Sample the pixels up front. np.random is not traced by JAX, so drawing
  # indices outside loss_fn is equivalent and keeps loss_fn side-effect free.
  ray_sample_indices = np.random.randint(0, resolution, (batch_size,))
  batch_pixels = jnp.asarray(test_pixels[ray_sample_indices])

  # Row-major flat index -> (row, column). Directions are only built for the
  # sampled pixels rather than for all H*W pixels on every step.
  rows, cols = jnp.divmod(jnp.asarray(ray_sample_indices), w)
  dirs = jnp.stack([(cols - w * 0.5) / f,
                    -(rows - h * 0.5) / f,
                    -jnp.ones_like(cols)], axis=-1)

  def loss_fn(screw_delta):
    """MSE between rendered and ground-truth colors for a (6,) screw."""
    # (4, 4) SE(3) delta relative to pred_pose_init.
    pred_pose_delta = ExpSE3(*DecomposeScrew(screw_delta))
    pred_pose = pred_pose_delta @ pred_pose_init
    # Rotate camera-frame directions into the world: rays_d[b] = R @ dirs[b].
    rays_d = (dirs[:, None, :] * pred_pose[None, :3, :3]).sum(axis=-1)
    rays_o = jnp.broadcast_to(pred_pose[None, :3, -1], rays_d.shape)
    rays = jnp.concatenate([rays_o, rays_d], axis=-1)
    model_outputs = nerf_model(key_0, key_1, rays)
    rgb = model_outputs[-1][0]
    return ((rgb - batch_pixels[..., :3]) ** 2).mean()

  loss, grad = jax.value_and_grad(loss_fn)(screw_delta)
  return loss, grad

In [None]:
# Start over: reset the relative pose to a tiny random screw, i.e. a pose
# delta that is almost the identity.
delta_init = 1e-6 * np.random.randn(6)
screw_delta = delta_init
print(screw_delta)

In [None]:
# Re-running this cell re-initializes the optimizer only; it starts from
# whatever `screw_delta` currently holds rather than resetting the guess.

initial_learning_rate = 1e-2
decay_steps = 100
decay_rate = 0.6
exp_schedule = optimizers.exponential_decay(
    initial_learning_rate, decay_steps, decay_rate)

# Plot the learning-rate schedule over the first 1000 steps.
step = list(range(1000))
rate = [exp_schedule(s) for s in step]
plt.title('learning rate')
plt.plot(step, rate)
plt.yscale("log")
plt.show()

opt_init, opt_update, get_params = optimizers.adam(step_size=exp_schedule)
opt_state = opt_init(screw_delta)
print(screw_delta)

In [None]:
# Inference loop: optimize the pose-delta screw with Adam.
hwf = (dataset.h, dataset.w, dataset.focal)
inference_batch_size = 2048
n_iters = 300

pred_poses = []
R_errors = []
t_errors = []
losses = []
for i in range(n_iters + 1):
  screw_delta = get_params(opt_state)
  # Pose at which this iteration's loss is evaluated (before the update).
  pred_pose = np.array(ExpSE3(*DecomposeScrew(screw_delta)) @ pred_pose_init)
  # Named `screw_grad` so the notebook-level `from jax import grad` is not
  # shadowed by an array.
  loss, screw_grad = train_step_exp(screw_delta, test_pixels, hwf,
                                    inference_batch_size, nerf_model)
  opt_state = opt_update(i, screw_grad, opt_state)

  losses.append(float(loss))
  pred_poses.append(pred_pose)
  R_error, t_error = compute_pose_error(pred_pose, test_pose)
  R_errors.append(R_error)
  t_errors.append(t_error)
  if i % 50 == 0:
    print(f"{i}/{n_iters} iterations ...")
    print(f"loss: {loss} | R error: {R_error} | t error: {t_error}")

In [None]:
print(len(losses))

fig, axes = plt.subplots(3, 1)
# Use the recorded length so the plot also works after an interrupted run.
iterations = range(len(losses))

axes[0].set_title('MSE Loss')
axes[0].plot(iterations, losses)
axes[1].set_title('Rotation Error')
axes[1].plot(iterations, R_errors, color='r')
axes[2].set_title('Translation Error')
axes[2].plot(iterations, t_errors, color='c')

# To make these each log scale on y axes
for ax in axes:
  ax.set_yscale('log')
# Lay out after titles exist so they don't overlap neighbouring axes.
fig.tight_layout()

# Show video.

In [None]:
def _render_and_gather(key_0, key_1, model, rays):
  """Render `rays` with `model` and all-gather the outputs over "batch"."""
  return jax.lax.all_gather(model(key_0, key_1, rays), axis_name="batch")


# Note rng_keys are useless in eval mode since there's no randomness.
# Keys and model are broadcast (in_axes=None); only `rays` is split across
# devices, and its buffer is donated.
render_fn = jax.pmap(
    _render_and_gather,
    axis_name="batch",
    in_axes=(None, None, None, 0),
    donate_argnums=3,
)

# NOTE(review): pmap already compiles; this outer jit may be redundant.
render_fn_jit = jit(render_fn)

In [None]:
def get_batch(pred_pose, test_pixels, hwf, batch_size):
  """Build a full-image ray batch for rendering from `pred_pose`.

  Args:
    pred_pose: (4, 4) camera-to-world pose to render from.
    test_pixels: unused; kept for interface compatibility.
    hwf: (height, width, focal length) of the camera.
    batch_size: unused; kept for interface compatibility (every pixel gets
      one ray).

  Returns:
    A dict whose 'rays' entry is an (H, W, 6) array: ray origins in [..., :3]
    and unnormalized world-frame ray directions in [..., 3:].
  """
  del test_pixels, batch_size  # Unused: the whole image is rendered.
  h, w, f = hwf
  x, y = jnp.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
      jnp.arange(w),  # X-Axis (columns)
      jnp.arange(h),  # Y-Axis (rows)
      indexing="xy")
  # Camera-frame directions, (H, W, 3), with z fixed at -1.
  dirs = jnp.stack([(x - w * 0.5) / f, -(y - h * 0.5) / f, -jnp.ones_like(x)],
                   axis=-1)
  pose = jnp.asarray(pred_pose)
  # Rotate into the world frame: rays_d[i, j] = R @ dirs[i, j].
  rays_d = (dirs[..., None, :] * pose[:3, :3]).sum(axis=-1)
  rays_o = jnp.broadcast_to(pose[:3, -1], rays_d.shape)
  return {'rays': jnp.concatenate([rays_o, rays_d], axis=-1)}

In [None]:
n_frames = 20
images = []
# Set this different from n_iters to focus on where most of the change happens.
n_iters_reduced = 200
# Never index past the recorded trajectory, and keep the stride >= 1.
last_iter = min(n_iters_reduced, len(pred_poses) - 1)
frame_stride = max(1, last_iter // n_frames)
for idx in range(0, last_iter + 1, frame_stride):
  print("Rendering, ", idx)
  # get_batch renders every pixel; the batch-size argument is not used by it
  # (the previous bare `batch_size` name was undefined here).
  batch = get_batch(pred_poses[idx], test_pixels, hwf, inference_batch_size)
  pred_color, _, _ = utils.render_image(
      state, batch, render_fn_jit, rng, chunk=8192)
  images.append(pred_color)

In [None]:
def save_animation(images, test_image,
                   path='/tmp/optimization_animation.gif', duration=200,
                   alpha=0.5):
  """Save and publish a GIF of rendered frames blended with the target image.

  Args:
    images: non-empty sequence of rendered RGB frames; values are clipped to
      [0, 1] and scaled to 8-bit.
    test_image: target RGB image, same size as the frames, clipped likewise.
    path: output GIF path.
    duration: display time of each frame, in milliseconds.
    alpha: weight of the target image in the blend (0 = rendering only).

  Raises:
    ValueError: if `images` is empty.
  """
  if len(images) == 0:
    raise ValueError("save_animation needs at least one frame.")

  def _to_pil(im):
    # Float image in [0, 1] -> 8-bit PIL image.
    return PilImage.fromarray(
        (np.clip(np.asarray(im), 0.0, 1.0) * 255.0).astype(np.uint8))

  test_im = _to_pil(test_image)
  frames = [PilImage.blend(_to_pil(im), test_im, alpha) for im in images]
  frames[0].save(
      path,
      save_all=True,
      append_images=frames[1:],
      duration=duration,
      loop=0)
  colabtools.publish.image(path)


save_animation(images, test_image)