# NeRF: Neural Radiance Fields

In [28]:
import torch
import numpy as np
import imageio.v2 as imageio
import cv2
import argparse
import json
import os
from tqdm import trange

In [29]:
# Hard-coded experiment settings (stand-in for command-line parsing). The dict
# is wrapped in an ``argparse.Namespace`` below so that values can be read as
# attributes, e.g. ``args.N_samples``.
args = {
    'config': 'configs/lego.txt',
    'expname': 'blender_paper_lego',
    'basedir': './logs',
    'datadir': './data/nerf_synthetic/lego',
    'netdepth': 8,
    'netwidth': 256,
    'netdepth_fine': 8,
    'netwidth_fine': 256,
    'N_rand': 1024,
    'lrate': 0.0005,
    'lrate_decay': 500,
    'chunk': 32768,
    'netchunk': 65536,
    'no_batching': True,
    'no_reload': False,
    'ft_path': None,
    'N_samples': 64,
    'N_importance': 128,
    'perturb': 1.0,
    'use_viewdirs': True,
    'i_embed': 0,
    'multires': 10,
    'multires_views': 4,
    'raw_noise_std': 0.0,
    'render_only': False,
    'render_test': False,
    'render_factor': 0,
    'precrop_iters': 500,
    'precrop_frac': 0.5,
    'dataset_type': 'blender',
    'testskip': 8,
    'shape': 'greek',
    'white_bkgd': True,
    'half_res': True,
    'factor': 8,
    'no_ndc': False,
    'lindisp': False,
    'spherify': False,
    'llffhold': 8,
    'i_print': 100,
    'i_img': 500,
    'i_weights': 10000,
    'i_testset': 50000,
    'i_video': 50000,
}
args = argparse.Namespace(**args)

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the detected device as the default for newly created tensors. Hard-coding
# 'cuda' here would make every later tensor creation fail on CPU-only hosts.
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

In [30]:
def trans_t(t):
    """Return the 4x4 homogeneous matrix translating by ``t`` along the z axis."""
    rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)

def rot_phi(phi):
    """Return the 4x4 homogeneous rotation by ``phi`` radians about the x axis."""
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    rows = [
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)

def rot_theta(th):
    """Return the 4x4 homogeneous rotation by ``th`` radians about the y axis.

    The sign convention places ``-sin`` in the first row and ``+sin`` in the
    third row.
    """
    cos_th = np.cos(th)
    sin_th = np.sin(th)
    rows = [
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)


def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera-to-world matrix from spherical parameters.

    Args:
        theta: Angle in degrees passed (as radians) to ``rot_theta``.
        phi: Angle in degrees passed (as radians) to ``rot_phi``.
        radius: Translation distance passed to ``trans_t``.

    Returns:
        A float32 4x4 tensor: a translation, then the two rotations, then a
        fixed matrix that negates x and swaps the y and z axes.
    """
    phi_rotation = rot_phi(phi / 180. * np.pi)
    theta_rotation = rot_theta(theta / 180. * np.pi)
    # Negates x and exchanges the y and z axes.
    axis_swap = torch.tensor(
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        dtype=torch.float32)
    # Applied right-to-left: translate, rotate by phi, rotate by theta, swap axes.
    return axis_swap @ (theta_rotation @ (phi_rotation @ trans_t(radius)))


def get_embedder(multires: int, input_dims: int = 3, include_input: bool = True, log_sampling: bool = True):
    """Build a sinusoidal positional-encoding function.

    The encoder concatenates (optionally) the raw input with ``sin(x * f)`` and
    ``cos(x * f)`` for each of ``multires`` frequencies ``f``.

    Args:
        multires: Number of frequency bands. The highest band is
            ``2 ** (multires - 1)``.
        input_dims: Size of the last dimension of the tensors to be encoded;
            only used to compute ``out_dim``.
        include_input: If True, the un-encoded input is kept as the first
            ``input_dims`` output channels.
        log_sampling: If True, frequencies are powers of two; otherwise they
            are spaced linearly between ``1`` and ``2 ** (multires - 1)``.

    Returns:
        A tuple ``(embed, out_dim)`` where ``embed(x)`` concatenates all
        encodings along the last dimension and ``out_dim`` is the resulting
        size of that dimension.
    """
    max_freq_log2 = multires - 1
    num_freqs = multires
    periodic_fns = [torch.sin, torch.cos]

    if log_sampling:
        freq_bands = 2. ** torch.linspace(0., max_freq_log2, steps=num_freqs)
    else:
        freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, steps=num_freqs)
    # Keep the frequencies as plain Python floats: the encoder then works for
    # inputs on any device, independent of where `freq_bands` was created.
    frequencies = freq_bands.tolist()

    embed_fns = []
    if include_input:
        embed_fns.append(lambda x: x)
    for frequency in frequencies:
        for periodic_fn in periodic_fns:
            # Bind loop variables as defaults to avoid late-binding closures.
            embed_fns.append(lambda x, p_fn=periodic_fn, freq=frequency: p_fn(x * freq))

    # Every encoding function emits `input_dims` channels.
    out_dim = input_dims * len(embed_fns)

    def embed(x):
        return torch.cat([fn(x) for fn in embed_fns], -1)

    return embed, out_dim


def get_rays(H, W, K, pose):
    """Compute one ray (origin and direction) per pixel of an H x W image.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        K: 3x3 intrinsics, indexed as ``K[row][col]``; ``K[0][0]`` and
            ``K[1][1]`` are used as scales and ``K[0][2]``, ``K[1][2]`` as
            offsets.
        pose: Camera-to-world matrix; its upper-left 3x3 block rotates the
            directions and its last column (first three rows) is the origin.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(H, W, 3)``.
    """
    # 'xy' indexing directly yields (H, W) grids: `cols` varies along the last
    # axis and `rows` along the first.
    cols, rows = torch.meshgrid(torch.linspace(0, W - 1, W),
                                torch.linspace(0, H - 1, H), indexing='xy')
    x_cam = (cols - K[0][2]) / K[0][0]
    y_cam = -(rows - K[1][2]) / K[1][1]
    z_cam = -torch.ones_like(cols)
    dirs = torch.stack([x_cam, y_cam, z_cam], -1)

    # Rotate camera-frame directions into the world frame: for every pixel this
    # is the matrix-vector product of the rotation block with the direction.
    rotation = pose[:3, :3]
    rays_d = (dirs[..., None, :] * rotation).sum(-1)

    # All rays share the camera position as their origin.
    rays_o = pose[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def batchify(fn, chunk):
    """Wrap ``fn`` so that it is applied to its input in slices of ``chunk`` rows.

    Args:
        fn: Callable taking a tensor and returning a tensor.
        chunk: Maximum number of leading-dimension rows per call, or ``None``
            to return ``fn`` unchanged.

    Returns:
        ``fn`` itself when ``chunk`` is ``None``; otherwise a callable that
        evaluates ``fn`` slice by slice and concatenates the results along
        dimension 0.
    """
    if chunk is None:
        return fn

    def chunked(inputs):
        pieces = []
        for start in range(0, inputs.shape[0], chunk):
            pieces.append(fn(inputs[start:start + chunk]))
        return torch.cat(pieces, 0)

    return chunked


def run_network(inputs: torch.Tensor, viewdirs: torch.Tensor, network_fn, embed_fn, embeddirs_fn, netchunk: int):
    """Encode inputs (and optional view directions) and evaluate ``network_fn``.

    Args:
        inputs: Tensor whose last dimension holds the features to encode; all
            leading dimensions are flattened for the network call.
        viewdirs: Per-row directions, or ``None`` to skip direction encoding.
            Each direction is broadcast to the shape of ``inputs``.
        network_fn: Callable applied to the flattened, encoded features.
        embed_fn: Encoder applied to the flattened inputs.
        embeddirs_fn: Encoder applied to the flattened directions; only used
            when ``viewdirs`` is not ``None``.
        netchunk: Chunk size handed to ``batchify``.

    Returns:
        The network output reshaped to the leading dimensions of ``inputs``
        plus the network's output channel dimension.
    """
    feature_dim = inputs.shape[-1]
    leading_shape = list(inputs.shape[:-1])

    encoded = embed_fn(inputs.reshape(-1, feature_dim))

    if viewdirs is not None:
        # Repeat each direction across the middle axis so it lines up with inputs.
        dirs_per_input = viewdirs[:, None].expand(inputs.shape)
        dirs_flat = dirs_per_input.reshape(-1, dirs_per_input.shape[-1])
        encoded = torch.cat([encoded, embeddirs_fn(dirs_flat)], -1)

    chunked_network = batchify(network_fn, netchunk)
    flat_out = chunked_network(encoded)
    return flat_out.reshape(leading_shape + [flat_out.shape[-1]])


def render_rays(rays_flat: torch.Tensor):
    """Placeholder for ray rendering; not implemented yet.

    The argument is currently ignored and ``None`` is returned.
    """
    return None


def batchify_rays(rays_flat: torch.Tensor, chunk: int):
    """Placeholder for chunked ray rendering; not implemented yet.

    Walks over ``rays_flat`` in steps of ``chunk`` rows without doing any work
    and returns an empty dict.
    """
    collected = {}
    # The per-chunk body is still to be written; only the offsets are visited.
    for _offset in range(0, rays_flat.shape[0], chunk):
        continue
    return collected

## Set Up the Logging Directory

In [31]:
# Record the run settings under <basedir>/<expname>: every argument goes to
# args.txt and, when a config path is set, the config file is copied alongside.
_log_dir = os.path.normpath(os.path.join(args.basedir, args.expname))
os.makedirs(_log_dir, exist_ok=True)
with open(os.path.join(_log_dir, 'args.txt'), 'w') as file:
    for arg in sorted(vars(args)):
        file.write('{} = {}\n'.format(arg, getattr(args, arg)))
if args.config is not None:
    # Open the source first, inside a context manager: its handle is always
    # closed, and config.txt is not created/truncated if the source is missing.
    with open(args.config, 'r') as _config_src:
        with open(os.path.join(_log_dir, 'config.txt'), 'w') as file:
            file.write(_config_src.read())

## Load Images and Poses

In [32]:
# Read the per-split descriptors (transforms_train/val/test.json) from the
# dataset directory.
_splits = ['train', 'val', 'test']
metas = {}
for _split in _splits:
    with open(os.path.normpath(os.path.join(args.datadir, 'transforms_{}.json'.format(_split))), 'r') as fp:
        metas[_split] = json.load(fp)

# Load the images and poses of each split. `_counts` holds cumulative frame
# counts so that each split can later be addressed as a contiguous index range.
_images = []
_poses = []
_counts = [0]
for _split in _splits:
    # Training keeps every frame; the other splits keep every `testskip`-th one.
    # A testskip of 0 would be an invalid slice step, so it means "no skipping".
    _step = 1 if _split == 'train' or args.testskip == 0 else args.testskip
    _img_array = []
    _pose_array = []
    for frame in metas[_split]['frames'][::_step]:
        _img_array.append(
            imageio.imread(os.path.normpath(os.path.join(args.datadir, frame['file_path'] + '.png'))))
        _pose_array.append(np.array(frame['transform_matrix']))
    _counts.append(_counts[-1] + len(_img_array))
    # Scale pixel values by 1/255 and store as float32.
    _images.append((np.array(_img_array) / 255.).astype(np.float32))
    _poses.append(np.array(_pose_array).astype(np.float32))

images_concatenated = np.concatenate(_images, 0)
poses_concatenated = np.concatenate(_poses, 0)

# The image stack is laid out as (N, height, width, channels): axis 1 is rows
# and axis 2 is columns, matching the (height, width) buffer allocated below
# for cv2.resize results.
height, width = images_concatenated.shape[1:3]
near, far = 2., 6.
focal = .5 * width / np.tan(.5 * float(metas['train']['camera_angle_x']))
# Contiguous index ranges of the train / val / test frames in the stacked arrays.
i_train, i_val, i_test = [np.arange(_counts[i], _counts[i + 1]) for i in range(3)]

if args.half_res:
    width = width // 2
    height = height // 2
    focal = focal / 2.

    # float32 buffer so the downscaled images keep the dtype of the originals.
    _resized = np.zeros((images_concatenated.shape[0], height, width, 4), dtype=np.float32)
    for i, img in enumerate(images_concatenated):
        # cv2.resize takes the target size as (width, height).
        _resized[i] = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    images_concatenated = _resized

# Build the intrinsics only after the optional downscaling so that the focal
# length and principal point match the resolution of the images actually used.
K = np.array([
    [focal, 0, 0.5 * width],
    [0, focal, 0.5 * height],
    [0, 0, 1]
])

if args.white_bkgd:
    # Blend RGBA onto a white background: rgb * alpha + (1 - alpha).
    _alpha = images_concatenated[..., -1:]
    _rgb = images_concatenated[..., :3]
    images_concatenated = _rgb * _alpha + (1. - _alpha)
else:
    # Drop the alpha channel and keep the colour channels only.
    images_concatenated = images_concatenated[..., :3]

# 40 poses generated by sweeping the first pose_spherical argument over
# [-180, 180) degrees with the other two arguments fixed.
_render_angles = np.linspace(-180, 180, 40 + 1)[:-1]
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in _render_angles], 0)
if args.render_test:
    # Render from the held-out test poses instead of the generated ones.
    render_poses = np.array(poses_concatenated[i_test])

## Create NeRF Model 

In [33]:
class NeRF(torch.nn.Module):
    """Fully connected network mapping encoded inputs to per-point outputs.

    The input's last dimension is the concatenation of ``input_channel`` point
    features and ``input_channel_views`` view features. With ``use_viewdirs``
    the output is ``[rgb(3), alpha(1)]``; otherwise it is a single linear head
    with ``output_channel`` channels.
    """

    def __init__(self, netdepth: int, netwidth: int, input_channel: int, output_channel: int, input_channel_views: int,
                 use_viewdirs: bool,
                 skips):
        """Create the layers.

        Args:
            netdepth: Number of linear layers in the point branch.
            netwidth: Hidden width of the point branch.
            input_channel: Number of encoded point channels.
            output_channel: Output width when ``use_viewdirs`` is False.
            input_channel_views: Number of encoded view channels.
            use_viewdirs: Whether to use the view-dependent colour head.
            skips: Layer indices after which the point input is concatenated
                back onto the hidden features.
        """
        super().__init__()
        self.netdepth = netdepth
        self.netwidth = netwidth
        self.use_viewdirs = use_viewdirs
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.input_channel_views = input_channel_views
        self.skips = skips

        # Point branch: the layer that follows a skip index receives the hidden
        # features plus the re-injected point input.
        point_layers = [torch.nn.Linear(self.input_channel, self.netwidth)]
        for idx in range(self.netdepth - 1):
            in_features = self.netwidth
            if idx in self.skips:
                in_features += self.input_channel
            point_layers.append(torch.nn.Linear(in_features, self.netwidth))
        self.pts_linears = torch.nn.ModuleList(point_layers)

        # View branch: a single layer from [features, view input] to half width.
        self.views_linears = torch.nn.ModuleList(
            [torch.nn.Linear(self.input_channel_views + self.netwidth, self.netwidth // 2)])

        if self.use_viewdirs:
            self.feature_linear = torch.nn.Linear(self.netwidth, self.netwidth)
            self.alpha_linear = torch.nn.Linear(self.netwidth, 1)
            self.rgb_linear = torch.nn.Linear(self.netwidth // 2, 3)
        else:
            self.output_linear = torch.nn.Linear(self.netwidth, self.output_channel)

    def forward(self, x):
        """Evaluate the network on ``x`` (last dim = point + view channels)."""
        input_pts, input_views = torch.split(x, [self.input_channel, self.input_channel_views], dim=-1)

        h = input_pts
        for idx, layer in enumerate(self.pts_linears):
            h = torch.nn.functional.relu(layer(h))
            if idx in self.skips:
                # Skip connection: prepend the original point input.
                h = torch.cat([input_pts, h], -1)

        if not self.use_viewdirs:
            return self.output_linear(h)

        # alpha is read from the point features only; colour additionally
        # depends on the view input.
        alpha = self.alpha_linear(h)
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)
        for layer in self.views_linears:
            h = torch.nn.functional.relu(layer(h))

        rgb = self.rgb_linear(h)
        return torch.cat([rgb, alpha], -1)

In [34]:
# Encoder for point inputs: positional encoding when i_embed == 0, otherwise
# the identity on 3 raw channels.
if args.i_embed == 0:
    embed_fn, input_ch = get_embedder(args.multires)
else:
    embed_fn, input_ch = torch.nn.Identity(), 3

# Encoder for view directions, only when view dependence is enabled.
if args.use_viewdirs:
    embeddirs_fn, input_ch_views = get_embedder(args.multires_views)
else:
    embeddirs_fn, input_ch_views = None, 0

output_ch = 4 if args.N_importance <= 0 else 5

# Coarse network.
model_coarse = NeRF(
    netdepth=args.netdepth,
    netwidth=args.netwidth,
    input_channel=input_ch,
    output_channel=output_ch,
    input_channel_views=input_ch_views,
    use_viewdirs=args.use_viewdirs,
    skips=[4],
).to(device)
grad_vars_coarse = list(model_coarse.parameters())

# Fine network, with its own depth/width settings.
model_fine = NeRF(
    netdepth=args.netdepth_fine,
    netwidth=args.netwidth_fine,
    input_channel=input_ch,
    output_channel=output_ch,
    input_channel_views=input_ch_views,
    use_viewdirs=args.use_viewdirs,
    skips=[4],
).to(device)

# A single optimizer updates the parameters of both networks.
grad_vars_fine = grad_vars_coarse + list(model_fine.parameters())
optimizer = torch.optim.Adam(params=grad_vars_fine, lr=args.lrate, betas=(0.9, 0.999))


def network_query_fn(inputs, viewdirs, network_fn):
    """Run ``network_fn`` through ``run_network`` with the encoders and chunk size above."""
    return run_network(inputs, viewdirs, network_fn,
                       embed_fn=embed_fn,
                       embeddirs_fn=embeddirs_fn,
                       netchunk=args.netchunk)

In [35]:
# for i in trange(1, 2):
# Pick one random training image together with its camera pose.
image_index = np.random.choice(i_train)
# Pin the dtype so the target matches the float32 default regardless of the
# dtype of the image array.
target_image = torch.tensor(images_concatenated[image_index], dtype=torch.float32)
target_pose = torch.tensor(poses_concatenated[image_index])
# Rays must be cast from the pose of the sampled image; using any other pose
# would pair rays with target pixels from a different camera.
rays_o, rays_d = get_rays(height, width, K, target_pose)

# All (row, col) pixel coordinates, flattened to shape (height * width, 2).
coords = torch.reshape(torch.stack(
    torch.meshgrid(torch.linspace(0, height - 1, height), torch.linspace(0, width - 1, width), indexing='ij'), -1),
    [-1, 2])
# Draw N_rand distinct pixels and gather their rays and target colours.
select_coords = coords[np.random.choice(coords.shape[0], size=[args.N_rand], replace=False)].long()
select_rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]
select_rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]
target_s = target_image[select_coords[:, 0], select_coords[:, 1]]

# Unit-length view directions.
viewdirs = select_rays_d
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

# Pack per ray: origin (3), direction (3), near (1), far (1), view direction (3).
rays_final = torch.cat([select_rays_o, select_rays_d, near * torch.ones_like(select_rays_d[..., :1]),
                        far * torch.ones_like(select_rays_d[..., :1]), viewdirs], -1)

print(rays_final.shape)

torch.Size([1024, 11])


In [55]:
# Unpack the per-ray rows: columns 0-2 origin, 3-5 direction, 6-7 near/far
# bounds, and (when more than 8 columns are present) the last 3 view direction.
N_rays = rays_final.shape[0]
rays_o = rays_final[:, :3]
rays_d = rays_final[:, 3:6]
if rays_final.shape[-1] > 8:
    viewdirs = rays_final[:, -3:]
else:
    viewdirs = None
bounds = rays_final[..., 6:8].reshape(-1, 1, 2)
near = bounds[..., 0]
far = bounds[..., 1]
print(near[0, 0], far[0, 0])

tensor(2., device='cuda:0') tensor(6., device='cuda:0')


In [56]:
# Preview: N_samples evenly spaced values from 0 to 1 (inclusive). The cell
# only displays them; nothing is assigned.
torch.linspace(0., 1., steps=args.N_samples)

tensor([0.0000, 0.0159, 0.0317, 0.0476, 0.0635, 0.0794, 0.0952, 0.1111, 0.1270,
        0.1429, 0.1587, 0.1746, 0.1905, 0.2063, 0.2222, 0.2381, 0.2540, 0.2698,
        0.2857, 0.3016, 0.3175, 0.3333, 0.3492, 0.3651, 0.3810, 0.3968, 0.4127,
        0.4286, 0.4444, 0.4603, 0.4762, 0.4921, 0.5079, 0.5238, 0.5397, 0.5556,
        0.5714, 0.5873, 0.6032, 0.6190, 0.6349, 0.6508, 0.6667, 0.6825, 0.6984,
        0.7143, 0.7302, 0.7460, 0.7619, 0.7778, 0.7937, 0.8095, 0.8254, 0.8413,
        0.8571, 0.8730, 0.8889, 0.9048, 0.9206, 0.9365, 0.9524, 0.9683, 0.9841,
        1.0000], device='cuda:0')