# NeRF: Neural Radiance Fields

In [2]:
import torch
import numpy as np
import imageio.v2 as imageio
import cv2
import argparse
import json
import os
import time
from tqdm import tqdm, trange

In [3]:
# Experiment configuration. A plain dict is wrapped in argparse.Namespace so the
# rest of the notebook can use attribute access (args.N_rand, ...) exactly as if
# the values had come from a command-line parser.
args = {
    'config': 'configs/lego.txt',
    'expname': 'blender_paper_lego',
    'basedir': './logs',
    'datadir': './data/nerf_synthetic/lego',
    'netdepth': 8,
    'netwidth': 256,
    'netdepth_fine': 8,
    'netwidth_fine': 256,
    'N_rand': 1024,
    'lrate': 0.0005,
    'lrate_decay': 500,
    'chunk': 32768,
    'netchunk': 65536,
    'no_batching': True,
    'no_reload': False,
    'ft_path': None,
    'N_samples': 64,
    'N_importance': 128,
    'perturb': 1.0,
    'use_viewdirs': True,
    'i_embed': 0,
    'multires': 10,
    'multires_views': 4,
    'raw_noise_std': 0.0,
    'render_only': False,
    'render_test': False,
    'render_factor': 0,
    'precrop_iters': 500,
    'precrop_frac': 0.5,
    'dataset_type': 'blender',
    'testskip': 8,
    'shape': 'greek',
    'white_bkgd': True,
    'half_res': True,
    'factor': 8,
    'no_ndc': False,
    'lindisp': False,
    'spherify': False,
    'llffhold': 8,
    'i_print': 100,
    'i_img': 500,
    'i_weights': 10000,
    'i_testset': 50000,
    'i_video': 50000,
}
args = argparse.Namespace(**args)

# Prefer CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# New tensors are created on the selected device. Using `device` instead of a
# hard-coded 'cuda' keeps the notebook runnable on CPU-only machines.
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

In [4]:
def trans_t(t):
    """Return a 4x4 float32 homogeneous transform that translates by ``t`` along z."""
    return torch.tensor(
        [[1, 0, 0, 0],
         [0, 1, 0, 0],
         [0, 0, 1, t],
         [0, 0, 0, 1]],
        dtype=torch.float32,
    )


def rot_phi(phi):
    """Return a 4x4 float32 rotation by ``phi`` radians about the x axis."""
    c, s = np.cos(phi), np.sin(phi)
    return torch.tensor(
        [[1, 0, 0, 0],
         [0, c, -s, 0],
         [0, s, c, 0],
         [0, 0, 0, 1]],
        dtype=torch.float32,
    )


def rot_theta(th):
    """Return a 4x4 float32 rotation about the y axis by angle ``th`` (radians).

    Note the sign layout: ``-sin`` sits in row 0 and ``+sin`` in row 2.
    """
    c, s = np.cos(th), np.sin(th)
    return torch.tensor(
        [[c, 0, -s, 0],
         [0, 1, 0, 0],
         [s, 0, c, 0],
         [0, 0, 0, 1]],
        dtype=torch.float32,
    )


def pose_spherical(theta, phi, radius):
    """Build a 4x4 float32 camera-to-world matrix on a sphere of given radius.

    Args:
        theta: angle in degrees, applied through ``rot_theta``.
        phi: angle in degrees, applied through ``rot_phi``.
        radius: translation along z applied before both rotations.

    Returns:
        ``axis_swap @ (rot_theta @ (rot_phi @ trans_t(radius)))``, where
        ``axis_swap`` negates x and exchanges the y and z axes.
    """
    left_factors = (
        rot_phi(phi / 180. * np.pi),
        rot_theta(theta / 180. * np.pi),
        torch.tensor(
            [[-1, 0, 0, 0],
             [0, 0, 1, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],
            dtype=torch.float32,
        ),
    )
    # Left-multiply one factor at a time so the float32 products are
    # evaluated innermost-first.
    pose = trans_t(radius)
    for factor in left_factors:
        pose = factor @ pose
    return pose


def get_embedder(multires: int, *, input_dims: int = 3, include_input: bool = True,
                 log_sampling: bool = True):
    """Build a sinusoidal positional-encoding function.

    The encoding of ``x`` is the concatenation, along the last axis, of
    ``x`` itself (when ``include_input``) followed by ``sin(f * x)`` and
    ``cos(f * x)`` for each frequency ``f``, in increasing frequency order.

    Args:
        multires: number of frequency bands. Frequencies run from ``2**0`` to
            ``2**(multires - 1)``.
        input_dims: size of the last axis of the inputs (default 3).
        include_input: prepend the raw input to the encoding (default True).
        log_sampling: if True, frequencies are ``2**k`` for evenly spaced
            ``k``. If False, they are evenly spaced between ``2**0`` and
            ``2**(multires - 1)``.

    Returns:
        ``(embed, out_dim)``. ``embed`` maps a tensor whose last axis has
        ``input_dims`` entries to one whose last axis has ``out_dim`` entries.
    """
    max_freq_log2 = multires - 1
    if log_sampling:
        freq_bands = 2. ** torch.linspace(0., max_freq_log2, steps=multires)
    else:
        freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, steps=multires)
    # Plain Python floats keep the encoder device-agnostic. The linspace result
    # lives on the default device, and a 0-dim tensor there may not combine
    # with an input stored on a different device.
    freqs = freq_bands.tolist()

    embed_fns = []
    out_dim = 0
    if include_input:
        embed_fns.append(lambda x: x)
        out_dim += input_dims
    for freq in freqs:
        for p_fn in (torch.sin, torch.cos):
            # Bind loop variables as defaults to avoid late-binding closures.
            embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            out_dim += input_dims

    def embed(x):
        return torch.cat([fn(x) for fn in embed_fns], -1)

    return embed, out_dim

## Set Up the Experiment Log Directory

In [5]:
# Create the experiment directory <basedir>/<expname> and record the run's
# settings there: every `args` field in args.txt, and a verbatim copy of the
# config file (if any) in config.txt.
_expdir = os.path.normpath(os.path.join(args.basedir, args.expname))
os.makedirs(_expdir, exist_ok=True)
with open(os.path.join(_expdir, 'args.txt'), 'w') as file:
    for arg in sorted(vars(args)):
        attr = getattr(args, arg)
        file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
    # Open the source inside the `with` so its handle is always closed. This
    # also means a missing config raises before an empty config.txt is created.
    with open(args.config, 'r') as src, open(os.path.join(_expdir, 'config.txt'), 'w') as file:
        file.write(src.read())

## Load Images and Poses

In [6]:
# Read per-split camera metadata (transforms_{train,val,test}.json).
metas = {}
for split in ['train', 'val', 'test']:
    with open(os.path.normpath(os.path.join(args.datadir, 'transforms_{}.json'.format(split))), 'r') as fp:
        metas[split] = json.load(fp)

# Load images and camera-to-world matrices split by split. Train keeps every
# frame; val/test keep every `args.testskip`-th frame. `_counts` holds running
# offsets so each split can be indexed inside the concatenated arrays.
_images = []
_poses = []
_counts = [0]
for split in ['train', 'val', 'test']:
    _img_array = []
    _pose_array = []
    skip = 1 if split == 'train' else args.testskip
    for frame in metas[split]['frames'][::skip]:
        _img_array.append(
            imageio.imread(os.path.normpath(os.path.join(args.datadir, frame['file_path'] + '.png'))))
        _pose_array.append(np.array(frame['transform_matrix']))
    _counts.append(_counts[-1] + len(_img_array))
    # Scale pixel values to [0, 1] float32 (assumes 8-bit PNGs).
    _images.append((np.array(_img_array) / 255.).astype(np.float32))
    _poses.append(np.array(_pose_array).astype(np.float32))

images_concatenated = np.concatenate(_images, 0)
poses_concatenated = np.concatenate(_poses, 0)

# Images are stacked as (N, H, W, C), so axes 1 and 2 are height, then width.
height, width = images_concatenated.shape[1:3]
near, far = 2., 6.
# Focal length in pixels, derived from the horizontal field of view.
focal = .5 * width / np.tan(.5 * float(metas['train']['camera_angle_x']))
i_train, i_val, i_test = [np.arange(_counts[i], _counts[i + 1]) for i in range(3)]

if args.half_res:
    width = width // 2
    height = height // 2
    focal = focal / 2.

    # Keep the source dtype (float32) and channel count, so the half-res path
    # yields the same dtype as the full-res path.
    _resized = np.zeros(
        (images_concatenated.shape[0], height, width, images_concatenated.shape[-1]),
        dtype=images_concatenated.dtype)
    for i, img in enumerate(images_concatenated):
        # cv2.resize takes dsize as (width, height).
        _resized[i] = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    images_concatenated = _resized

# Intrinsics are built after any resizing, so they describe the images in use.
K = np.array([
    [focal, 0, 0.5 * width],
    [0, focal, 0.5 * height],
    [0, 0, 1]
])

if args.white_bkgd:
    # Alpha-composite RGB over a white background using the last channel as alpha.
    images_concatenated = images_concatenated[..., :3] * images_concatenated[..., -1:] + (
            1. - images_concatenated[..., -1:])
else:
    images_concatenated = images_concatenated[..., :3]

# Render path: 40 evenly spaced theta values in [-180, 180), phi = -30 degrees, radius 4.
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
if args.render_test:
    # Produce a float32 tensor like the branch above, not a NumPy array.
    render_poses = torch.tensor(poses_concatenated[i_test], dtype=torch.float32)

## Create NeRF Model 

In [None]:
class NeRF(torch.nn.Module):
    """Placeholder for the NeRF network.

    It currently registers no layers and defines no ``forward``.
    """

    def __init__(self):
        super().__init__()

In [7]:
# Choose the input encoders for sample positions and (optionally) view
# directions. This follows the reference NeRF `i_embed` convention:
# -1 means no positional encoding (raw 3-D input); any other value (default 0)
# applies the sinusoidal encoding from get_embedder.
embed_fn = torch.nn.Identity()
input_ch = 3
if args.i_embed != -1:
    embed_fn, input_ch = get_embedder(args.multires)

embeddirs_fn = None
input_ch_views = 0
if args.use_viewdirs:
    # The direction encoding sets the view-input width, not the position width.
    if args.i_embed != -1:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views)
    else:
        embeddirs_fn, input_ch_views = torch.nn.Identity(), 3

# 5 output channels when importance sampling (N_importance > 0) is enabled, else 4.
output_ch = 5 if args.N_importance > 0 else 4