# View Synthesis

> Implemented with NeRF in PyTorch

## Task Description

"View synthesis" is the task of
generating images of a 3D scene from a specific point of view.

## Solution Description

"NeRF" (Neural Radiance Field) solves "View synthesis"
by representing a 3D scene with a neural network that maps a position and a viewing direction to a color and a density.

## Pipeline Description

1. Preprocessing
2. Inference
3. Rendering
4. Training

## Pipeline Description - Preprocessing

{{ True image }} → {{ Position, Direction, True color }}
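
Preprocessing turns each pixel of a training image into one ray: a start position, a viewing direction, and the pixel's true color. A minimal sketch in plain Python, assuming a pinhole camera with focal length `focal` in pixels, a 3×4 camera-to-world matrix `c2w` given as nested lists, and the camera convention of the original NeRF code (the camera looks down its local -z axis, +y points up). The helper names `pixel_to_ray` and `image_to_rays` are hypothetical.

```python
import math


def pixel_to_ray(i, j, width, height, focal, c2w):
    """Return (origin, unit direction) of the ray through pixel column i, row j.

    Hypothetical helper: assumes a pinhole camera looking down its local -z axis
    with +y up, and a 3x4 camera-to-world matrix c2w = [R | t] as nested lists.
    """
    # Direction in camera coordinates, using integer pixel coordinates
    # (no +0.5 pixel-center offset, as in the original NeRF code).
    d_cam = [(i - width * 0.5) / focal, -(j - height * 0.5) / focal, -1.0]
    # Rotate into world coordinates: d_world = R @ d_cam.
    d_world = [sum(c2w[r][k] * d_cam[k] for k in range(3)) for r in range(3)]
    # Normalize so the network always receives a unit viewing direction.
    norm = math.sqrt(sum(x * x for x in d_world))
    d_world = [x / norm for x in d_world]
    # All rays of one image start at the camera center t.
    origin = [c2w[r][3] for r in range(3)]
    return origin, d_world


def image_to_rays(image, focal, c2w):
    """Map an image (rows of RGB tuples) to (position, direction, true color) triples."""
    height, width = len(image), len(image[0])
    samples = []
    for j in range(height):
        for i in range(width):
            origin, direction = pixel_to_ray(i, j, width, height, focal, c2w)
            samples.append((origin, direction, image[j][i]))
    return samples
```

Points on a ray are then `origin + t * direction` for sample distances `t` between near and far bounds chosen for the scene.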

## Pipeline Description - Inference

{{ Position, Direction }} → {{ Network }} → {{ Color, Density }}

## Pipeline Description - Rendering

{{ Color, Density }} → {{ Ray sampling }} → {{ Rendered color }}
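
Rendering composites the predicted colors along each ray with the quadrature rule from the NeRF paper: with $\delta_i = t_{i+1} - t_i$, each sample gets opacity $\alpha_i = 1 - e^{-\sigma_i \delta_i}$ and weight $T_i \alpha_i$, where $T_i = \prod_{j<i} (1 - \alpha_j)$ is the transmittance. The sketch below is plain Python for a single ray; it assumes the densities are already non-negative, and `far_delta` (the spacing assumed after the last sample, 1e10 in the original NeRF code) and the name `render_ray` are illustrative choices.

```python
import math


def render_ray(densities, colors, t_values, far_delta=1e10):
    """Alpha-composite the samples of one ray.

    densities: non-negative sigma_i, one per sample
    colors:    (r, g, b) per sample
    t_values:  sorted sample distances along the ray
    Returns (rendered_rgb, weights).
    """
    n = len(t_values)
    deltas = [t_values[i + 1] - t_values[i] for i in range(n - 1)] + [far_delta]
    transmittance = 1.0  # T_i: probability the ray reaches sample i unoccluded
    weights = []
    rgb = [0.0, 0.0, 0.0]
    for sigma, color, delta in zip(densities, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)  # opacity of this segment
        weight = transmittance * alpha
        weights.append(weight)
        for k in range(3):
            rgb[k] += weight * color[k]
        transmittance *= 1.0 - alpha  # equals exp(-sigma * delta)
    return rgb, weights
```

The returned weights are also what hierarchical sampling uses as its coarse distribution along the ray.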

## Pipeline Description - Training

{{ True color, Rendered color, Network }} → {{ Network }}

## Optimization Description

1. Positional Encoding of input coordinates
    - For learning high-frequency features
    - Using Fourier features
2. Hierarchical Sampling
    - For high-frequency representations, by concentrating samples near visible content
    - Using two networks ("coarse" and "fine") with different sample sizes
3. Gradient Descent
    - For minimizing the error between the true and rendered images
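
Hierarchical sampling treats the coarse network's rendering weights along a ray as a piecewise-constant PDF and draws the fine samples from it by inverse transform sampling; the fine network is then evaluated on the union of the coarse and fine samples. A plain-Python sketch for one ray, where `sample_pdf` is a hypothetical name and the `1e-5` padding of the weights follows the original NeRF code:

```python
import bisect
import random


def sample_pdf(bin_edges, weights, n_samples, rng=None):
    """Draw n_samples distances from the piecewise-constant PDF given by weights.

    bin_edges: len(weights) + 1 sorted distances along the ray
    weights:   non-negative coarse weights, one per bin
    """
    rng = rng or random.Random()
    eps = 1e-5  # keeps every bin reachable and avoids dividing by zero
    padded = [w + eps for w in weights]
    total = sum(padded)
    # CDF over the bin boundaries: cdf[0] = 0, cdf[-1] ~= 1.
    cdf = [0.0]
    for w in padded:
        cdf.append(cdf[-1] + w / total)
    samples = []
    for _ in range(n_samples):
        u = rng.random()
        # Bin whose CDF interval contains u (clamped against rounding at the top).
        idx = min(bisect.bisect_right(cdf, u) - 1, len(weights) - 1)
        lo, hi = cdf[idx], cdf[idx + 1]
        frac = min((u - lo) / (hi - lo), 1.0)
        # Invert the CDF linearly inside the chosen bin.
        samples.append(bin_edges[idx] + frac * (bin_edges[idx + 1] - bin_edges[idx]))
    return sorted(samples)
```

Bins with large coarse weights receive proportionally more fine samples, which is how the fine network spends its samples near visible surfaces.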

## Inference Details

$$
Predict(Position, Direction) = \{Color, Density\} \\
\text{where }
\begin{cases}
    Predict \text{ is MLP}, \\
    Position \in \mathbb{R}^{3}, \\
    Direction \in \mathbb{R}^{3}, \\
    Color \in \mathbb{R}^{3}, \\
    Density \in \mathbb{R}
\end{cases}
$$

### Positional Encoding

The raw and encoded coordinate values will be concatenated to form the network input.

Each coordinate value in `Position` and `Direction` is encoded as follows:

$$
Encode_{N}(p) \\
= \{\sin (2^0 \pi p), \cos (2^0 \pi p), \ldots, \sin (2^{N-1} \pi p), \cos (2^{N-1} \pi p)\} \\
= \{\sin (2^0 \pi p), \sin (\frac{\pi}{2} + 2^0 \pi p), \ldots, \sin (2^{N-1} \pi p), \sin (\frac{\pi}{2} + 2^{N-1} \pi p)\} \\
\text{where } p \in \mathbb{R}, \ N \in \mathbb{N}, \ Encode_{N}(p) \in \mathbb{R}^{2N}
$$

The encoded dimensions are calculated as follows:

| Input     | Dimension | N   | Encoded Dimension |
| --------- | --------- | --- | ----------------- |
| Position  | 3         | 10  | $3 (1 + 2N) = 63$ |
| Direction | 3         | 4   | $3 (1 + 2N) = 27$ |
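
A direct plain-Python transcription of $Encode_N$ and of the raw-plus-encoded concatenation, useful for checking the encoded dimensions in the table. The names `encode` and `encode_vector` are hypothetical; this sketch orders the features coordinate by coordinate, while the PyTorch implementation groups them by frequency, which does not matter to the MLP as long as the ordering is fixed.

```python
import math


def encode(p, n):
    """Encode_N(p): sin and cos of p at frequencies 2^0 * pi, ..., 2^(N-1) * pi."""
    features = []
    for k in range(n):
        features.append(math.sin(2**k * math.pi * p))
        features.append(math.cos(2**k * math.pi * p))
    return features


def encode_vector(v, n):
    """Raw coordinates followed by the 2N features of each coordinate."""
    out = list(v)
    for p in v:
        out.extend(encode(p, n))
    return out
```

With 3 coordinates this yields $3 (1 + 2N)$ values: 63 for $N = 10$ and 27 for $N = 4$.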


## Inference Details (Cont.)

### Network

The network is a multi-layer perceptron (MLP) with the following architecture:

| Layer                                  | Input Dimension | Output Dimension | Activation |
| -------------------------------------- | --------------- | ---------------- | ---------- |
| Input Position                         | 63              | 256              | ReLU       |
| Hidden 1                               | 256             | 256              | ReLU       |
| Hidden 2                               | 256             | 256              | ReLU       |
| Hidden 3                               | 256             | 256              | ReLU       |
| Hidden 4                               | 256             | 256              | ReLU       |
| Hidden 5 + Input Position (Skip conn.) | 256 + 63        | 256              | ReLU       |
| Hidden 6                               | 256             | 256              | ReLU       |
| Hidden 7                               | 256             | 256              | ReLU       |
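| Hidden 8 + Input Direction             | 256 + 27        | 128              | ReLU       |
| Output Density (from Hidden 7\*)       | 256             | 1                | None\*\*   |
| Output Color (from Hidden 8)           | 128             | 3                | Sigmoid    |

<small>\* The density does not depend on the viewing direction.<br>\*\* The density is output without an activation; as in the original NeRF, it should be rectified (e.g. with ReLU) before rendering so that it is non-negative.</small>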
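
The layer sizes in the table follow from a few hyperparameters. A plain-Python sketch that derives the (input, output) size of each linear layer and the resulting parameter count, assuming the table's defaults: 8 position layers of width 256, a skip connection at every fourth hidden layer, and a 128-unit layer that also takes the 27 encoded direction values. The function names are hypothetical.

```python
def nerf_layer_shapes(layer_count=8, hidden_dim=256, extra_hidden_dim=128,
                      position_dim=63, direction_dim=27, skip_every=4):
    """List (name, in_features, out_features) for each linear layer of the MLP."""
    shapes = [("input_position", position_dim, hidden_dim)]
    for i in range(layer_count - 1):
        # The skip connection re-feeds the encoded position into this layer.
        skip = i > 0 and i % skip_every == 0
        in_dim = hidden_dim + position_dim if skip else hidden_dim
        shapes.append((f"hidden_{i + 1}", in_dim, hidden_dim))
    # The last hidden layer also receives the encoded direction.
    shapes.append((f"hidden_{layer_count}", hidden_dim + direction_dim, extra_hidden_dim))
    shapes.append(("output_density", hidden_dim, 1))
    shapes.append(("output_color", extra_hidden_dim, 3))
    return shapes


def parameter_count(shapes):
    # Each linear layer has in * out weights plus out biases.
    return sum(i * o + o for _, i, o in shapes)
```

For the defaults, `parameter_count(nerf_layer_shapes())` gives 530,052, which should match `sum(p.numel() for p in NeRF().parameters())` for the default `NeRF()` in the Implementation section, since the positional encoders hold no trainable parameters.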


## Training Details

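The loss is the squared error between the true color and both the coarse and fine rendered colors, summed over the batch of rays $Rays$:

$$
Loss = \sum_{r \in Rays} \left( \lVert Color_{coarse}(r) - Color_{true}(r) \rVert_2^2 + \lVert Color_{fine}(r) - Color_{true}(r) \rVert_2^2 \right)
$$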
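
Written out for a batch of rays, the loss is a plain sum of squared color errors. A plain-Python sketch where each color is an (r, g, b) tuple and the names `squared_error` and `nerf_loss` are hypothetical. Both terms are kept because, per the NeRF paper, the coarse network is also optimized so that its weights can guide the fine samples. Many implementations average instead of summing (for example `torch.nn.functional.mse_loss` defaults to the mean), which only rescales the loss.

```python
def squared_error(rendered, true):
    # ||rendered - true||_2^2 for one RGB color.
    return sum((r - t) ** 2 for r, t in zip(rendered, true))


def nerf_loss(coarse_colors, fine_colors, true_colors):
    """Sum over rays of the coarse and fine squared color errors."""
    return sum(
        squared_error(c, t) + squared_error(f, t)
        for c, f, t in zip(coarse_colors, fine_colors, true_colors)
    )
```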

## References

1. View synthesis. (n.d.). In Wikipedia. Retrieved from https://en.wikipedia.org/wiki/View_synthesis
2. Neural radiance field. (n.d.). In Wikipedia. Retrieved from https://en.wikipedia.org/wiki/Neural_radiance_field
3. Mildenhall, B., Srinivasan, P. P., Tancik, M., Barron, J. T., Ramamoorthi, R., & Ng, R. (2020). NeRF: Representing scenes as neural radiance fields for view synthesis. arXiv preprint arXiv:2003.08934. Retrieved from https://arxiv.org/pdf/2003.08934
4. Tancik, M., Srinivasan, P. P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., Barron, J. T., & Ng, R. (2020). Fourier features let networks learn high frequency functions in low dimensional domains. NeurIPS. Retrieved from https://arxiv.org/pdf/2006.10739

## Implementation

Positional Encoding:

```python
import torch
from torch import Tensor
from torch.nn import Module
from torch.types import Device


class PositionalEncoder(Module):
    """Concatenates the raw input with sin/cos features at N frequency levels."""

    def __init__(self, encoding_factor: int, device: Device | None = None):
        super().__init__()

        encoding_factor = max(int(encoding_factor), 0)

        # Frequencies 2^0 * pi, ..., 2^(N-1) * pi, each repeated twice (sin and cos).
        freq_lvls = torch.arange(encoding_factor, device=device)
        freq = ((2.0**freq_lvls) * torch.pi).repeat_interleave(2).reshape(-1, 1)
        # Offsets 0 and pi/2 turn every second sine into a cosine: cos(x) = sin(x + pi/2).
        offsets = torch.tensor([0.0, torch.pi / 2], device=device)
        offsets = offsets.repeat(encoding_factor).reshape(-1, 1)

        # Buffers move with .to(device) / .cuda() but are not trainable parameters.
        self.register_buffer("freq", freq, persistent=False)
        self.register_buffer("offsets", offsets, persistent=False)

    def forward(self, inputs: Tensor) -> Tensor:
        # inputs: (..., D) -> scaled: (..., 2N, D)
        scaled = inputs.unsqueeze(-2) * self.freq + self.offsets
        # Flatten the (2N, D) features into 2N * D values per input row.
        features = torch.sin(scaled).flatten(start_dim=-2)
        return torch.concat([inputs, features], dim=-1)
```

Network:

```python
import torch
from torch import Tensor, nn
from torch.nn import Module


class NeRF(Module):
    """MLP mapping (position, direction) to (density, color), following the table above."""

    def __init__(
        self,
        layer_count: int | None = None,
        hidden_dim: int | None = None,
        extra_hidden_dim: int | None = None,
        position_encoding_factor: int | None = None,
        direction_encoding_factor: int | None = None,
    ):
        super().__init__()

        # Defaults follow the architecture table: 8 layers, 256 units, N = 10 and N = 4.
        layer_count = int(layer_count or 8)
        hidden_dim = int(hidden_dim or 256)
        extra_hidden_dim = int(extra_hidden_dim or hidden_dim // 2)
        position_encoding_factor = int(position_encoding_factor or 10)
        direction_encoding_factor = int(direction_encoding_factor or 4)

        COLOR_DIM = 3
        DENSITY_DIM = 1
        RAW_POSITION_DIM = 3
        RAW_DIRECTION_DIM = 3
        encoded_position_dim = RAW_POSITION_DIM * (1 + 2 * position_encoding_factor)
        encoded_direction_dim = RAW_DIRECTION_DIM * (1 + 2 * direction_encoding_factor)

        # Hidden layers that re-receive the encoded position (skip connection every 4 layers).
        self.position_hidden_layer_skip_indices = {
            i for i in range(1, layer_count - 1) if i % 4 == 0
        }
        self.position_input_layer = nn.Linear(encoded_position_dim, hidden_dim)
        self.position_hidden_layers = nn.ModuleList(
            [
                (
                    nn.Linear(hidden_dim + encoded_position_dim, hidden_dim)
                    if i in self.position_hidden_layer_skip_indices
                    else nn.Linear(hidden_dim, hidden_dim)
                )
                for i in range(layer_count - 1)
            ]
        )
        self.direction_input_layer = nn.Linear(
            hidden_dim + encoded_direction_dim, extra_hidden_dim
        )
        self.density_output_layer = nn.Linear(hidden_dim, DENSITY_DIM)
        self.color_output_layer = nn.Linear(extra_hidden_dim, COLOR_DIM)

        self.position_input_encoder = PositionalEncoder(position_encoding_factor)
        self.direction_input_encoder = PositionalEncoder(direction_encoding_factor)

    def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
        # inputs: (..., 6) = raw position (x, y, z) followed by raw direction.
        raw_positions = inputs[..., :3]
        raw_directions = inputs[..., 3:]
        encoded_positions: Tensor = self.position_input_encoder(raw_positions)
        encoded_directions: Tensor = self.direction_input_encoder(raw_directions)

        hidden_position: Tensor = self.position_input_layer(encoded_positions)
        for i, layer in enumerate(self.position_hidden_layers):
            hidden_position = torch.relu(hidden_position)
            if i in self.position_hidden_layer_skip_indices:
                hidden_position = torch.concat([hidden_position, encoded_positions], dim=-1)
            hidden_position = layer(hidden_position)
        hidden_position = torch.relu(hidden_position)

        # Density depends on position only and is returned raw (no activation).
        density: Tensor = self.density_output_layer(hidden_position)
        # Color also sees the encoded viewing direction.
        hidden_direction = torch.relu(
            self.direction_input_layer(
                torch.concat([hidden_position, encoded_directions], dim=-1)
            )
        )
        color: Tensor = torch.sigmoid(self.color_output_layer(hidden_direction))

        return density, color
```

```python
NeRF()
```

```
NeRF(
  (position_input_layer): Linear(in_features=63, out_features=256, bias=True)
  (position_hidden_layers): ModuleList(
    (0-3): 4 x Linear(in_features=256, out_features=256, bias=True)
    (4): Linear(in_features=319, out_features=256, bias=True)
    (5-6): 2 x Linear(in_features=256, out_features=256, bias=True)
  )
  (direction_input_layer): Linear(in_features=283, out_features=128, bias=True)
  (density_output_layer): Linear(in_features=256, out_features=1, bias=True)
  (color_output_layer): Linear(in_features=128, out_features=3, bias=True)
  (position_input_encoder): PositionalEncoder()
  (direction_input_encoder): PositionalEncoder()
)
```