# NeRF

## 0. Hyperparameters

In [11]:
import torch
import numpy as np
import imageio.v2 as imageio
import cv2
import argparse
import json
import os
from typing import Tuple

# Notebook-wide settings, exposed through attribute access (e.g. ``args.datadir``).
args = argparse.Namespace(
    # 1. Model: width/depth of the coarse and fine networks, optimizer step size.
    coarse_net_width=256,
    coarse_net_depth=8,
    fine_net_width=256,
    fine_net_depth=8,
    learning_rate=5e-4,

    # 2. Data: directory holding the transforms_*.json files and the images.
    datadir='./data/nerf_synthetic/lego',
)

## 1. Load Data and Preprocess

In [12]:
def load_data_blender(in_type: str, in_skip: int = 1) -> Tuple[np.ndarray, np.ndarray, int, float]:
    """Load one split of the dataset stored under ``args.datadir``.

    Reads ``transforms_<in_type>.json`` and, for every ``in_skip``-th entry of its
    ``'frames'`` list, loads ``<file_path>.png`` and the frame's
    ``'transform_matrix'``.

    :param in_type: split name substituted into the JSON file name.
    :param in_skip: stride applied to the frame list (1 keeps every frame).
    :return: ``(images, poses, frame_count, camera_angle_x)`` where ``images`` is
        the stacked image data divided by 255 and cast to float32, ``poses`` is
        the stacked transform matrices, ``frame_count`` is the number of frames
        listed in the JSON file (counted before ``in_skip`` is applied) and
        ``camera_angle_x`` is copied verbatim from the JSON file.
    """
    meta_path = os.path.normpath(os.path.join(args.datadir, 'transforms_{}.json'.format(in_type)))
    with open(meta_path, 'r') as meta_file:
        meta = json.load(meta_file)

    images, poses = [], []
    for frame in meta['frames'][::in_skip]:
        image_path = os.path.normpath(os.path.join(args.datadir, frame['file_path'] + '.png'))
        images.append(imageio.imread(image_path))
        poses.append(np.array(frame['transform_matrix']))

    # Pixel values are rescaled by 1/255 before the float32 cast.
    scaled_images = (np.array(images) / 255.).astype(np.float32)
    stacked_poses = np.array(poses)
    return scaled_images, stacked_poses, len(meta['frames']), meta['camera_angle_x']


def resample_images(in_origin: np.ndarray, in_focal, in_rate: float) -> Tuple[np.ndarray, int, int, float]:
    """Rescale a batch of images by ``in_rate`` and scale the focal length to match.

    :param in_origin: image batch indexed as ``(image, row, column[, channel])``.
    :param in_focal: focal length belonging to ``in_origin``.
    :param in_rate: scale factor applied to both spatial axes and to the focal length.
    :return: ``(images, width, height, focal)`` for the rescaled batch. ``images``
        is allocated with ``np.zeros`` and is therefore float64.
    """
    _target_height = int(in_origin.shape[1] * in_rate)
    _target_width = int(in_origin.shape[2] * in_rate)
    _target_focal = in_focal * in_rate

    # Take the trailing (channel) dimensions from the input instead of assuming
    # four channels, so RGB and single-channel batches work as well as RGBA.
    _ret = np.zeros((in_origin.shape[0], _target_height, _target_width) + tuple(in_origin.shape[3:]))
    for _idx, _image in enumerate(in_origin):
        # cv2.resize expects the target size as (width, height).
        _resized = cv2.resize(_image, (_target_width, _target_height), interpolation=cv2.INTER_AREA)
        # cv2.resize can return a 2-D array for a single-channel image; the
        # reshape restores the per-image shape of the output buffer.
        _ret[_idx] = np.reshape(_resized, _ret.shape[1:])
    return _ret, _target_width, _target_height, _target_focal


def make_white_background(in_origin: np.ndarray) -> np.ndarray:
    """Blend the first three channels over white, using the last channel as alpha.

    :param in_origin: array whose last axis holds the colour channels followed by alpha.
    :return: array with a 3-element last axis: ``colour * alpha + (1 - alpha)``.
    """
    alpha = in_origin[..., -1:]
    colour = in_origin[..., :3]
    return colour * alpha + (1. - alpha)

## 2. Model

In [13]:
class NeRF(torch.nn.Module):
    """Container for the linear layers of a NeRF-style network.

    Only the layers are built here; ``forward`` is still an empty stub.

    :param net_width: number of features of every hidden point layer.
    :param net_depth: number of layers in ``points_layers``.
    :param input_channel: number of features of the point input.
    :param output_channel: output size of ``output_linear`` (used when ``use_view`` is false).
    :param input_channel_views: number of features of the view input.
    :param use_view: build the feature/alpha/rgb heads instead of ``output_linear``.
    :param skips: indices ``i``; the layer created for index ``i`` (stored at
        ``points_layers[i + 1]``) takes ``net_width + input_channel`` inputs.
    """

    def __init__(self, net_width: int, net_depth: int, input_channel: int, output_channel: int,
                 input_channel_views: int, use_view: bool, skips: list):
        super().__init__()
        self.net_width = net_width
        self.net_depth = net_depth
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.input_channel_views = input_channel_views
        self.use_view = use_view
        self.skips = skips

        # Point branch: one input layer followed by ``net_depth - 1`` hidden layers.
        point_stack = [torch.nn.Linear(in_features=input_channel, out_features=net_width)]
        for layer_idx in range(net_depth - 1):
            if layer_idx in skips:
                fan_in = net_width + input_channel
            else:
                fan_in = net_width
            point_stack.append(torch.nn.Linear(in_features=fan_in, out_features=net_width))
        self.points_layers = torch.nn.ModuleList(point_stack)

        # View branch: built regardless of ``use_view``.
        view_layer = torch.nn.Linear(in_features=input_channel_views + net_width, out_features=net_width // 2)
        self.views_layers = torch.nn.ModuleList([view_layer])

        if not use_view:
            self.output_linear = torch.nn.Linear(net_width, output_channel)
        else:
            self.feature_layer = torch.nn.Linear(net_width, net_width)
            self.alpha_layer = torch.nn.Linear(net_width, 1)
            self.rgb_layer = torch.nn.Linear(net_width // 2, 3)

    def forward(self, x):
        """Not implemented yet; returns ``None``."""
        return None

## 3. Position Encoding

In [14]:
def generate_rays(in_width: int, in_height: int, in_focal: float, in_pose: torch.Tensor) -> Tuple[
    torch.Tensor, torch.Tensor]:
    """Build one ray (origin and direction) per pixel for a single pose.

    The pixel grid is created with the dtype and device of ``in_pose`` so that
    the rotation below also works for float64 or non-CPU poses.

    :param in_width: image width in pixels.
    :param in_height: image height in pixels.
    :param in_focal: focal length in pixels.
    :param in_pose: pose matrix; its upper-left 3x3 block is used as the rotation
        and the first three entries of its last column as the ray origin.
    :return: ``(origins, directions)``, both of shape ``(in_height, in_width, 3)``.
        Directions are not normalized.
    """
    _device = in_pose.device
    # An integer pose cannot drive linspace sensibly; fall back to the default float type.
    _dtype = in_pose.dtype if in_pose.is_floating_point() else torch.get_default_dtype()

    # indexing='xy' yields (in_height, in_width) grids: _x varies along columns, _y along rows.
    _x, _y = torch.meshgrid(
        torch.linspace(0, in_width - 1, in_width, dtype=_dtype, device=_device),
        torch.linspace(0, in_height - 1, in_height, dtype=_dtype, device=_device),
        indexing='xy')

    # Per-pixel direction before rotation: x to the right, y flipped, looking along -z.
    _local_dirs = torch.stack(
        [(_x - in_width / 2.) / in_focal, -(_y - in_height / 2.) / in_focal, -torch.ones_like(_x)], -1)

    # Row-vector form of ``rotation @ direction``.
    _rays_dirs = torch.matmul(_local_dirs, in_pose[0:3, 0:3].T)  # seems no need to normalize
    # Every ray starts at the pose translation; expand() broadcasts without copying.
    _rays_oris = in_pose[0:3, -1].expand(_rays_dirs.shape)
    return _rays_oris, _rays_dirs


def generate_query_points(in_rays_oris: torch.Tensor, in_rays_dirs: torch.Tensor, in_near: float, in_far: float,
                          in_batch_size: int, in_sample_size: int, in_perturb: bool = True) -> torch.Tensor:
    """Pick random rays and sample 3D points along each of them.

    ``in_batch_size`` distinct pixels are drawn with ``np.random.choice``. For
    every selected ray, ``in_sample_size`` depths are placed between ``in_near``
    and ``in_far``, spaced evenly in inverse depth. With ``in_perturb`` each
    depth is additionally jittered uniformly inside its own bin, where bin
    edges are the midpoints between neighbouring depths.

    :param in_rays_oris: per-pixel ray origins, indexed as ``(row, column, 3)``.
    :param in_rays_dirs: per-pixel ray directions, same layout as ``in_rays_oris``.
    :param in_near: nearest sampling depth.
    :param in_far: farthest sampling depth.
    :param in_batch_size: number of rays to draw (without replacement).
    :param in_sample_size: number of depth samples per ray.
    :param in_perturb: jitter the depths randomly (default, as before) or keep
        them deterministic.
    :return: points of shape ``(in_batch_size, in_sample_size, 3)``.
    :raises ValueError: from ``np.random.choice`` when ``in_batch_size`` exceeds
        the number of pixels.
    """
    _height, _width = in_rays_oris.shape[0], in_rays_oris.shape[1]
    _device = in_rays_oris.device

    # Draw flat pixel indices and unravel them to (row, column); this selects the
    # same pixels as indexing a row-major coordinate grid, without building the grid.
    _flat = np.random.choice(_height * _width, size=[in_batch_size], replace=False)
    _rows, _cols = np.unravel_index(_flat, (_height, _width))
    _rows = torch.as_tensor(_rows, dtype=torch.long, device=_device)
    _cols = torch.as_tensor(_cols, dtype=torch.long, device=_device)
    _selected_rays_oris = in_rays_oris[_rows, _cols]
    _selected_rays_dirs = in_rays_dirs[_rows, _cols]

    # Depths evenly spaced in inverse depth between near and far.
    _t_vals = torch.linspace(0., 1., steps=in_sample_size, device=_device)
    _z_vals = 1. / (1. / in_near * (1. - _t_vals) + 1. / in_far * _t_vals)
    _z_vals = _z_vals.expand([in_batch_size, in_sample_size])

    if in_perturb:
        # Stratified jitter: one uniform draw per bin keeps the samples ordered.
        _mid = .5 * (_z_vals[..., 1:] + _z_vals[..., :-1])
        _upper = torch.cat([_mid, _z_vals[..., -1:]], -1)
        _lower = torch.cat([_z_vals[..., :1], _mid], -1)
        _t_rand = torch.rand(_z_vals.shape, device=_device)
        _z_vals = _lower + (_upper - _lower) * _t_rand

    # point = origin + depth * direction, broadcast over the sample axis.
    _sampled_points = _selected_rays_oris[..., None, :] + _selected_rays_dirs[..., None, :] * _z_vals[..., :, None]
    return _sampled_points


def generate_position_encoding_fn(in_multires: int):
    """Create a sinusoidal encoding function for 3-component inputs.

    The frequencies are ``2 ** linspace(0, in_multires - 1, in_multires)``. The
    returned function concatenates, along the last axis, the input itself followed
    by ``sin(x * f)`` and ``cos(x * f)`` for each frequency ``f`` in ascending order.

    :param in_multires: number of frequencies.
    :return: ``(encode, out_dim)`` where ``out_dim = 3 + 3 * 2 * in_multires`` is
        the size of the encoded last axis for a 3-component input.
    """
    point_dims = 3
    frequencies = 2. ** torch.linspace(0., in_multires - 1, steps=in_multires)

    def encode(x):
        # Identity term first, then a sin/cos pair per frequency.
        pieces = [x]
        for frequency in frequencies:
            scaled = x * frequency
            pieces.append(torch.sin(scaled))
            pieces.append(torch.cos(scaled))
        return torch.cat(pieces, -1)

    encoded_dims = point_dims * (1 + 2 * len(frequencies))
    return encode, encoded_dims

## 4. Volume Rendering

In [17]:
def render(in_points: torch.Tensor):
    """Volume-rendering entry point (placeholder).

    Not implemented yet: ``in_points`` is ignored and ``None`` is returned.
    """
    return None

## 5. Training

In [16]:
def train():
    """Training entry point (placeholder); not implemented yet, returns ``None``."""
    return None

In [15]:
# Load the three dataset splits; the validation and test splits keep every 8th frame.
IMAGES_TRAIN, POSE_TRAIN, FRAMES_TRAIN, ANGLE_TRAIN = load_data_blender('train', 1)
IMAGES_VAL, POSE_VAL, FRAMES_VAL, ANGLE_VAL = load_data_blender('val', 8)
IMAGES_TEST, POSE_TEST, FRAMES_TEST, ANGLE_TEST = load_data_blender('test', 8)

# Image size is taken from the training array (axis 2 -> width, axis 1 -> height).
WIDTH_ORIGIN, HEIGHT_ORIGIN = IMAGES_TRAIN.shape[2], IMAGES_TRAIN.shape[1]
# Focal length in pixels: f = 0.5 * W / tan(0.5 * angle).
# Presumably 'camera_angle_x' is the horizontal field of view in radians -- TODO confirm.
FOCAL_ORIGIN = 0.5 * WIDTH_ORIGIN / np.tan(0.5 * ANGLE_TRAIN)

# Half-resolution copy of every split. The focal length derived from the training
# split is reused for the validation and test splits.
IMAGES_TRAIN_RESAMPLED, WIDTH_TRAIN_RESAMPLED, HEIGHT_TRAIN_RESAMPLED, FOCAL_TRAIN_RESAMPLED = \
    resample_images(IMAGES_TRAIN, FOCAL_ORIGIN, 0.5)
IMAGES_VAL_RESAMPLED, WIDTH_VAL_RESAMPLED, HEIGHT_VAL_RESAMPLED, FOCAL_VAL_RESAMPLED = \
    resample_images(IMAGES_VAL, FOCAL_ORIGIN, 0.5)
IMAGES_TEST_RESAMPLED, WIDTH_TEST_RESAMPLED, HEIGHT_TEST_RESAMPLED, FOCAL_TEST_RESAMPLED = \
    resample_images(IMAGES_TEST, FOCAL_ORIGIN, 0.5)

# Composite the images onto a white background.
# NOTE(review): this uses the full-resolution arrays, not the *_RESAMPLED ones -- confirm intended.
IMAGES_TRAIN_WHITE_BACKGROUND = make_white_background(IMAGES_TRAIN)
IMAGES_VAL_WHITE_BACKGROUND = make_white_background(IMAGES_VAL)
IMAGES_TEST_WHITE_BACKGROUND = make_white_background(IMAGES_TEST)

# Coarse network (input_channel=3, output_channel=4, input_channel_views=8, skip index 4)
# and an Adam optimizer over all of its parameters.
# NOTE(review): input_channel=3 does not match the 63-dim output of
# generate_position_encoding_fn(10) created below -- confirm once forward() is implemented.
MODEL_COARSE = NeRF(args.coarse_net_width, args.coarse_net_depth, 3, 4, 8, True, [4])
GRAD_VARS = list(MODEL_COARSE.parameters())
OPTIMIZER = torch.optim.Adam(params=GRAD_VARS, lr=args.learning_rate, betas=(0.9, 0.999))

# Rays for the first training pose at the original resolution, then 1024 random rays
# with 64 samples each between depths 2 and 6.
RAY_ORIS, RAY_DIRS = generate_rays(WIDTH_ORIGIN, HEIGHT_ORIGIN, FOCAL_ORIGIN, torch.tensor(POSE_TRAIN[0]).float())
POINTS = generate_query_points(RAY_ORIS, RAY_DIRS, 2., 6., 1024, 64)
# Collapse the ray and sample axes so that each row is one point.
POINTS_FLAT = torch.reshape(POINTS, [-1, POINTS.shape[-1]])
# Position-encoding function with 10 frequencies and the size of its output.
EMBED_fn, OUT_DIM = generate_position_encoding_fn(10)