In [40]:
%%html
<style>
    @page {
        size: A3 landscape;
        margin: 0;
    }
    .jp-RenderedMermaid {
        justify-content: center;
    }
    h2 {
        page-break-before: always;
    }
</style>

# View Synthesis

> Implemented with NeRF in PyTorch

## Task Description

"View synthesis" is the task of
generating images of a 3D scene from a specific point of view.

## Solution Description

"NeRF" (Neural Radiance Fields) solves "view synthesis"
by representing a 3D scene with a neural network.

## Pipeline Description

1. Preprocessing
2. Inference
3. Rendering
4. Training

## Pipeline Description - Preprocessing

{{ True image }} → {{ Position, Direction, True color }}
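As a rough illustration of this step, the sketch below turns one pixel of a true image into a ray origin, a ray direction, and its true color. It assumes a pinhole camera that looks along its local $-z$ axis, a `[4, 4]` camera-to-world pose, and rays through pixel centers (the same conventions `VolumeSampler` uses in the implementation section). The helper name `pixel_to_ray` is hypothetical.

```python
def pixel_to_ray(image, pose, focal, row, col):
    """Map pixel (row, col) to (origin, direction, true_color).

    Hypothetical helper: `image` is an H x W grid of RGB tuples and `pose`
    is a 4 x 4 camera-to-world matrix given as nested lists.
    """
    height, width = len(image), len(image[0])
    # Camera-space direction through the pixel center; the camera looks along -z
    camera_dir = (
        (col + 0.5 - width / 2) / focal,
        (height / 2 - row - 0.5) / focal,
        -1.0,
    )
    # Rotate into world space with the upper-left 3 x 3 block of the pose
    direction = tuple(
        sum(pose[i][j] * camera_dir[j] for j in range(3)) for i in range(3)
    )
    # The ray starts at the camera center (last column of the pose)
    origin = tuple(pose[i][3] for i in range(3))
    return origin, direction, image[row][col]


identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
image = [[(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], [(0.7, 0.8, 0.9), (1.0, 1.0, 1.0)]]
origin, direction, color = pixel_to_ray(image, identity, focal=1.0, row=0, col=0)
print(origin, direction, color)
```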

## Pipeline Description - Inference

{{ Position, Direction }} → {{ Volumetric sampling }} → {{ Positional encoding }} → {{ Network }} → {{ Color, Density }}

## Pipeline Description - Rendering

{{ Color, Density }} → {{ Alpha blending }} → {{ Rendered color }}

## Pipeline Description - Training

{{ True color, Rendered color, Network }} → {{ Network }}

```mermaid
graph TD
    subgraph Preprocessing
        ti[True Image]
        cp[Camera Posture]
        tc[True Color]
        ti --> tc
    end

    subgraph Inference
        cp --> C[Ray Generation]
        C --> D[Ray Batching]
        D --> E[Volume Sampling]
        E --> F[Positional Encoding]
        F --> G[Network]
        G --> H[Color, Density]
    end

    subgraph Rendering
        H --> I[Alpha Blending]
        I --> J[Rendered Color]
    end

    subgraph Training
        tc -.-> L[Loss Calculation]
        J -.-> L
        G -.-> L
        L --> G
    end
```

## Optimization Description

1. Positional Encoding of input coordinates
    - For learning high-frequency features
    - Using Fourier features
2. Stochastic Gradient Descent
    - For minimizing the error between the true and rendered images
    - Choosing a random image from the dataset each iteration
<!-- 3. Hierarchical Sampling
    - For high-frequency representations
    - Using two networks with different sample sizes -->
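One way to draw a training batch under the scheme above is sketched below: choose a random image each iteration, then a random subset of its pixels as rays. The batch size of 1024 comes from the training table, and the sizes in the example (106 images of 100 x 100 pixels) come from the dataset loaded in the implementation section. Sampling pixels without replacement and the helper name `sample_ray_batch` are assumptions.

```python
import random


def sample_ray_batch(image_count, height, width, batch_size=1024, rng=random):
    """Pick one random training image, then `batch_size` distinct pixels (rays) from it."""
    image_index = rng.randrange(image_count)
    # Each pixel defines one ray; sample without replacement within the image
    flat = rng.sample(range(height * width), batch_size)
    pixels = [divmod(i, width) for i in flat]  # (row, col) pairs
    return image_index, pixels


image_index, pixels = sample_ray_batch(106, 100, 100, rng=random.Random(0))
```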

## Inference Details

### Positional Encoding

The raw coordinate values and their encodings are concatenated to form the network input.

Each coordinate value in `Position` and `Direction` is encoded as follows:

$$
\begin{aligned}
\mathrm{Encode}_{N}(p)
&= \left(\sin (2^0 \pi p), \cos (2^0 \pi p), \ldots, \sin (2^{N-1} \pi p), \cos (2^{N-1} \pi p)\right) \\
&= \left(\sin (2^0 \pi p), \sin \left(\tfrac{\pi}{2} + 2^0 \pi p\right), \ldots, \sin (2^{N-1} \pi p), \sin \left(\tfrac{\pi}{2} + 2^{N-1} \pi p\right)\right)
\end{aligned}
$$

where $p \in \mathbb{R}$, $N \in \mathbb{N}$, and $\mathrm{Encode}_{N}(p) \in \mathbb{R}^{2N}$.
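A small pure-Python sketch of the encoding above for a single scalar $p$, written with the shifted-sine form; it checks numerically that $\sin(x + \pi/2)$ reproduces the cosine terms. The function name `encode` is illustrative only.

```python
import math


def encode(p: float, n: int) -> list[float]:
    """Return [sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(n-1) pi p), cos(2^(n-1) pi p)]."""
    features = []
    for k in range(n):
        x = (2 ** k) * math.pi * p
        # cos(x) is computed as sin(x + pi / 2), mirroring the second form above
        features.extend([math.sin(x), math.sin(x + math.pi / 2)])
    return features


p, n = 0.3, 4
shifted = encode(p, n)
direct = [f(2 ** k * math.pi * p) for k in range(n) for f in (math.sin, math.cos)]
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(shifted, direct))
```

Using one `sin` call with a phase offset lets the implementation evaluate every feature with a single vectorized operation.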

Since each coordinate keeps its raw value alongside its $2N$ encoded features, the encoded dimensions are:

| Input     | Dimension | N   | Encoded Dimension |
| --------- | --------- | --- | ----------------- |
| Position  | 3         | 10  | $3 (1 + 2N) = 63$ |
| Direction | 3         | 4   | $3 (1 + 2N) = 27$ |
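The arithmetic in the table can be checked directly: each of the three coordinates contributes one raw value plus $2N$ sinusoidal features. `encoded_dim` is a hypothetical helper name.

```python
def encoded_dim(raw_dim: int, n: int) -> int:
    # Per coordinate: 1 raw value + N sines + N cosines
    return raw_dim * (1 + 2 * n)


position_dim = encoded_dim(3, 10)
direction_dim = encoded_dim(3, 4)
print(position_dim, direction_dim)  # 63 27
```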


## Inference Details (Cont.)

### Network Definition

The neural network is a multi-layer perceptron (MLP) with the following structure:
- The density depends only on the position, not on the viewing direction
- The encoded position is concatenated with the output of the fifth hidden layer as a skip connection

```mermaid
%%{init: {
    "theme": "neutral",
    "themeVariables": {
        "fontFamily": "Menlo, monospace",
        "fontSize": "10px"
    }
}}%%
flowchart TD
    ip1([Input Position 3])
    ep1([Encoded Position 63])
    h1([Hidden Layer 256])
    h2([Hidden Layer 256])
    h3([Hidden Layer 256])
    h4([Hidden Layer 256])
    h5([Hidden Layer 256])
    ip2([Input Position 3])
    ep2([Encoded Position 63])
    h6([Hidden Layer 256])
    h7([Hidden Layer 256])
    h8([Hidden Layer 256])
    od([Output Density 1])
    iof([Input/Output Feature 256])
    id([Input Direction 3])
    ed([Encoded Direction 27])
    ha([Additional Hidden Layer 128])
    oc([Output Color 3])

    ip1 -->|Encode| ep1
    ep1 -->|ReLU| h1
    h1 -->|ReLU| h2
    h2 -->|ReLU| h3
    h3 -->|ReLU| h4
    h4 -->|ReLU| h5
    ip2 -->|Encode| ep2
    ep2 ---|Concatenate| h5
    h5 -->|ReLU| h6
    h6 -->|ReLU| h7
    h7 -->|ReLU| h8
    h8 -->|ReLU| od
    h8 --> iof
    id -->|Encode| ed
    ed ---|Concatenate| iof
    iof -->|ReLU| ha
    ha -->|Sigmoid| oc

    style ip1 fill:palegreen
    style ip2 fill:palegreen
    style id fill:palegreen
    style ep2 fill:mediumaquamarine
    style ep1 fill:mediumaquamarine
    style ed fill:mediumaquamarine
    style h1 fill:deepskyblue
    style h2 fill:deepskyblue
    style h3 fill:deepskyblue
    style h4 fill:deepskyblue
    style h5 fill:deepskyblue
    style h6 fill:deepskyblue
    style h7 fill:deepskyblue
    style h8 fill:deepskyblue
    style ha fill:deepskyblue
    style iof fill:tan
    style od fill:salmon
    style oc fill:salmon
```
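To cross-check the diagram against the module summary printed by `NeRF()` in the implementation section, the sketch below lists `(in_features, out_features)` for every linear layer. It assumes 8 hidden layers of width 256, the skip connection entering hidden layer index 4, and the encoded sizes 63 and 27 from the table. The helper `layer_shapes` is hypothetical and only loosely mirrors the constructor arguments of `NeRF`.

```python
def layer_shapes(layer_count=8, hidden=256, pos_dim=63, dir_dim=27, skip=4):
    """List (name, in_features, out_features) for each linear layer, in forward order."""
    shapes = [("position_input", pos_dim, hidden)]
    for i in range(layer_count):
        # The skip layer sees the previous activation plus the encoded position again
        in_dim = hidden + pos_dim if i == skip else hidden
        shapes.append((f"position_hidden_{i}", in_dim, hidden))
    shapes.append(("density_output", hidden, 1))
    # Color branch: 256-d feature concatenated with the encoded direction
    shapes.append(("direction_input", hidden + dir_dim, hidden // 2))
    shapes.append(("color_output", hidden // 2, 3))
    return shapes


for name, fan_in, fan_out in layer_shapes():
    print(f"{name:18} {fan_in:4} -> {fan_out}")
```

In the code, the last hidden layer (index 7) has no ReLU after it, so its output plays the role of the "Input/Output Feature 256" node, and the density head reads that feature.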

## Inference Details (Cont.)

### Volume Sampling

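To represent a continuous scene, we sample points along each camera ray, $r(t_{i}) = o + d t_{i}$, where:
- $o$ is the ray origin (the camera position)
- $d$ is the ray direction
- $t_{i} \sim U\left[t_{n} + \frac{i - 1}{N} (t_{f} - t_{n}),\ t_{n} + \frac{i}{N} (t_{f} - t_{n})\right]$ is the distance of the $i$-th sample along the ray, drawn uniformly within its bin (stratified sampling); with $t_{n} = 0$ and $t_{f} = 1$ this reduces to $U[\frac{i - 1}{N}, \frac{i}{N}]$
- $t_{n}$ and $t_{f}$ are the near and far bounds of the sampled interval
- $N$ is the number of samples per ray; the implementation below defaults to $N = 96$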
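A stdlib-only sketch of stratified sampling: one distance is drawn uniformly inside each of the $N$ equal-width bins. The `near` and `far` arguments are illustrative; the defaults 0 and 1 give the normalized bins $[\frac{i-1}{N}, \frac{i}{N}]$, while the example uses 2 and 6, which are assumed here from the original tiny NeRF example for this dataset.

```python
import random


def stratified_distances(n: int, near: float = 0.0, far: float = 1.0, rng=random):
    """Draw one t_i uniformly from each of n equal-width bins between near and far."""
    width = (far - near) / n
    # Bin i (0-based here) covers [near + i * width, near + (i + 1) * width)
    return [near + (i + rng.random()) * width for i in range(n)]


ts = stratified_distances(64, near=2.0, far=6.0, rng=random.Random(0))
```

Because the samples are redrawn on every call, the network is queried at different depths along each ray over the course of training, which lets it fit a continuous scene rather than $N$ fixed depths.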

## Rendering Details

### Alpha Blending

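$$
\hat{C} = \sum_{i=1}^{N} T_{i} \alpha_{i} c_{i},
\qquad
\alpha_{i} = 1 - \exp(-\sigma_{i} \delta_{i}),
\qquad
T_{i} = \prod_{j=1}^{i-1} (1 - \alpha_{j})
$$

where $c_{i}$ and $\sigma_{i}$ are the color and density predicted at sample $r(t_{i})$, $\delta_{i} = t_{i+1} - t_{i}$ is the distance between adjacent samples, and $T_{i}$ is the accumulated transmittance, the probability that the ray reaches sample $i$ without being stopped by an earlier sample.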
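The blending rule can be sketched for a single ray in plain Python. The sketch assumes $\alpha_i = 1 - \exp(-\sigma_i \delta_i)$ and $T_i = \prod_{j<i}(1 - \alpha_j)$ as in the NeRF paper, with $\delta_i = t_{i+1} - t_i$; giving the last sample a very large $\delta$ (`far_delta`, 1e10) follows the original tiny NeRF example and is an assumption here. `composite` is a hypothetical name.

```python
import math


def composite(sigmas, colors, ts, far_delta=1e10):
    """Alpha-blend N samples along one ray into a single RGB color."""
    # delta_i = t_{i+1} - t_i; the last sample gets a very large delta (assumption)
    deltas = [b - a for a, b in zip(ts, ts[1:])] + [far_delta]
    rendered = [0.0, 0.0, 0.0]
    transmittance = 1.0  # T_1 = 1: nothing blocks the first sample
    for sigma, color, delta in zip(sigmas, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)
        weight = transmittance * alpha  # T_i * alpha_i
        rendered = [r + weight * c for r, c in zip(rendered, color)]
        transmittance *= 1.0 - alpha  # T_{i+1} = T_i * (1 - alpha_i)
    return rendered


# An opaque red sample in front of a green one hides the green one
print(composite([100.0, 100.0], [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)], [2.0, 3.0]))  # [1.0, 0.0, 0.0]
```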

## Training Details

| Module Name                       | Details                                    |
| --------------------------------- | ------------------------------------------ |
| Optimizer - Adam                  | Learning rate: $5 \times 10^{-4}$          |
| Loss function - Mean Squared Error | Error between the true and rendered colors |
| Data loader                       | 1024 rays per batch                        |

$$
Loss = \frac{1}{|\mathcal{R}|} \sum_{r \in \mathcal{R}} \left\lVert C_{rendered}(r) - C_{true}(r) \right\rVert_2^2
$$

where $\mathcal{R}$ is a batch of rays and the squared norm sums the error over the three color channels.
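A plain-Python sketch of the batch loss: the squared RGB error of each ray, averaged over the rays in the batch. `batch_mse` is a hypothetical helper name.

```python
def batch_mse(rendered, true):
    """Mean over rays of the squared RGB error, as in the loss above."""
    assert len(rendered) == len(true) > 0
    per_ray = [
        sum((r - t) ** 2 for r, t in zip(c_hat, c)) for c_hat, c in zip(rendered, true)
    ]
    return sum(per_ray) / len(per_ray)


rendered = [(0.5, 0.5, 0.5), (1.0, 0.0, 0.0)]
true = [(0.5, 0.5, 0.5), (0.0, 0.0, 0.0)]
loss = batch_mse(rendered, true)  # (0.0 + 1.0) / 2 = 0.5
```

Note that `torch.nn.MSELoss` with its default `reduction="mean"` averages over every element, that is over rays and the three color channels, so for RGB it returns this value divided by 3 (here 1/6); the constant factor only rescales the gradient.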


## References

1. View synthesis. (n.d.). In Wikipedia. Retrieved from https://en.wikipedia.org/wiki/View_synthesis
2. Neural radiance field. (n.d.). In Wikipedia. Retrieved from https://en.wikipedia.org/wiki/Neural_radiance_field
3. Mildenhall, B., Srinivasan, P. P., Tancik, M., Barron, J. T., Ramamoorthi, R., & Ng, R. (2020). NeRF: Representing scenes as neural radiance fields for view synthesis. ECCV. Retrieved from https://arxiv.org/pdf/2003.08934
4. Tancik, M., Srinivasan, P. P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., Barron, J. T., & Ng, R. (2020). Fourier features let networks learn high frequency functions in low dimensional domains. NeurIPS. Retrieved from https://arxiv.org/pdf/2006.10739

## Implementation

----
#### Positional Encoding

In [41]:
from torch import Tensor
from torch.nn import Module
from torch.types import Device


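class PositionalEncoder(Module):
    """
    ### Dimension Transformations
    - Input: `[..., D]`
    - Output: `[..., D * (2 * encoding_factor + 1)]`
    """

    def __init__(self, encoding_factor: int, device: Device | None = None):
        import torch

        super(PositionalEncoder, self).__init__()

        encoding_factor = max(int(encoding_factor), 0)

        # Frequencies 2^k * pi for k = 0..N-1, each repeated for its sin/cos pair
        freq_lvls = torch.arange(encoding_factor, device=device)
        freq = ((1 << freq_lvls) * torch.pi).repeat_interleave(2)
        # cos(x) = sin(x + pi / 2), so a single sin() produces both terms
        sine_offsets = torch.tensor([0.0, torch.pi / 2], device=device)
        offsets = sine_offsets.repeat(encoding_factor)

        # Registered as buffers so they follow the module across `.to(device)`
        self.register_buffer("freq", freq, persistent=False)
        self.register_buffer("offsets", offsets, persistent=False)

    def forward(self, inputs: Tensor) -> Tensor:
        import torch

        # [..., D] -> [..., D, 1] so each coordinate is encoded independently
        inputs = torch.as_tensor(inputs).unsqueeze(-1)

        features = (self.freq * inputs + self.offsets).sin_()
        # Keep each raw coordinate next to its 2N encoded features
        features = torch.concat([inputs, features], dim=-1)
        # [..., D, 1 + 2N] -> [..., D * (1 + 2N)]
        features = features.reshape(*inputs.shape[:-2], -1)
        return features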

----
#### Network Definition

In [42]:
from torch import Tensor
from torch.nn import Module


class NeRF(Module):
    def __init__(
        self,
        layer_count: int | None = None,
        hidden_dim: int | None = None,
        additional_hidden_dim: int | None = None,
        position_encoding_factor: int | None = None,
        direction_encoding_factor: int | None = None,
    ):
        from torch import nn

        super(NeRF, self).__init__()

        layer_count = int(layer_count or 8)
        hidden_dim = int(hidden_dim or 256)
        additional_hidden_dim = int(additional_hidden_dim or hidden_dim // 2)
        position_encoding_factor = int(position_encoding_factor or 10)
        direction_encoding_factor = int(direction_encoding_factor or 4)

        COLOR_DIM = 3
        DENSITY_DIM = 1
        RAW_POSITION_DIM = 3
        RAW_DIRECTION_DIM = 3
        encoded_position_dim = RAW_POSITION_DIM * (1 + 2 * position_encoding_factor)
        encoded_direction_dim = RAW_DIRECTION_DIM * (1 + 2 * direction_encoding_factor)

        self.position_hidden_layer_skip_indexs = set(
            [i for i in range(1, layer_count - 1) if i % 4 == 0]
        )
        self.position_input_layer = nn.Linear(encoded_position_dim, hidden_dim)
        self.position_hidden_layers = nn.ModuleList(
            [
                (
                    nn.Linear(hidden_dim + encoded_position_dim, hidden_dim)
                    if i in self.position_hidden_layer_skip_indexs
                    else nn.Linear(hidden_dim, hidden_dim)
                )
                for i in range(layer_count)
            ]
        )
        self.density_output_layer = nn.Linear(hidden_dim, DENSITY_DIM)
        self.direction_input_layer = nn.Linear(
            hidden_dim + encoded_direction_dim,
            additional_hidden_dim,
        )
        self.color_output_layer = nn.Linear(additional_hidden_dim, COLOR_DIM)

        self.position_input_encoder = PositionalEncoder(position_encoding_factor)
        self.direction_input_encoder = PositionalEncoder(direction_encoding_factor)

    def forward(self, inputs: Tensor):
        import torch

        inputs = torch.as_tensor(inputs)

        raw_positions = inputs[..., :3]
        raw_directions = inputs[..., 3:]
        encoded_positions: Tensor = self.position_input_encoder(raw_positions)
        encoded_directions: Tensor = self.direction_input_encoder(raw_directions)

        hidden_positions: Tensor = self.position_input_layer(encoded_positions)
        for index, layer in enumerate(self.position_hidden_layers):
            hidden_positions.relu_()
            hidden_positions = layer(
                torch.concat([hidden_positions, encoded_positions], dim=-1)
                if index in self.position_hidden_layer_skip_indexs
                else hidden_positions
            )

        density: Tensor = self.density_output_layer(hidden_positions).relu_()
        hidden_directions: Tensor = self.direction_input_layer(
            torch.concat([hidden_positions, encoded_directions], dim=-1)
        ).relu_()
        color: Tensor = self.color_output_layer(hidden_directions).sigmoid_()

        return torch.concat([color, density], dim=-1)

----
#### Weight Initialization

In [43]:
class LogNormalInitializer:
    """
    ## Examples
    ```python
    from torch.nn import Module

    Module().apply(LogNormalInitializer(mean=0.0, std=2.0, seed=1))
    ```
    """
    def __init__(
        self,
        mean: float | None = None,
        std: float | None = None,
        seed: int | None = None,
    ):
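        # `is None` checks so that explicit values such as seed=0 are respected
        mean = 0.0 if mean is None else float(mean)
        std = 2.0 if std is None else float(std)
        seed = 1 if seed is None else int(seed)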

        self.mean = mean
        self.std = std
        self.seed = seed

    def __call__(self, module: Module) -> None:
        import torch

        if isinstance(module, torch.nn.Linear):
            if module.weight is not None:
                with torch.no_grad():
                    module.weight.log_normal_(
                        mean=self.mean,
                        std=self.std,
                        generator=torch.Generator(
                            module.weight.device,
                        ).manual_seed(
                            self.seed,
                        ),
                    ).clamp_min_(
                        torch.finfo(torch.float32).eps,
                    )
                    module.weight.div_(
                        module.weight.max(),
                    ).clamp_min_(
                        torch.finfo(torch.float32).eps,
                    )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

In [44]:
NeRF().apply(LogNormalInitializer())

NeRF(
  (position_input_layer): Linear(in_features=63, out_features=256, bias=True)
  (position_hidden_layers): ModuleList(
    (0-3): 4 x Linear(in_features=256, out_features=256, bias=True)
    (4): Linear(in_features=319, out_features=256, bias=True)
    (5-7): 3 x Linear(in_features=256, out_features=256, bias=True)
  )
  (density_output_layer): Linear(in_features=256, out_features=1, bias=True)
  (direction_input_layer): Linear(in_features=283, out_features=128, bias=True)
  (color_output_layer): Linear(in_features=128, out_features=3, bias=True)
  (position_input_encoder): PositionalEncoder()
  (direction_input_encoder): PositionalEncoder()
)
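To make the effect of `LogNormalInitializer` concrete, here is a stdlib-only mirror of its arithmetic for a flat list of weights (it does not reproduce PyTorch's random stream). Because log-normal samples are positive and are divided by their maximum, every weight ends up in $[\varepsilon, 1]$, with $\varepsilon$ the float32 machine epsilon, so all initialized weights are positive and the largest is exactly 1. The helper `log_normal_weights` is hypothetical.

```python
import random

FLOAT32_EPS = 1.1920928955078125e-07  # torch.finfo(torch.float32).eps == 2 ** -23


def log_normal_weights(count, mean=0.0, std=2.0, seed=1):
    """Plain-Python mirror of LogNormalInitializer for one weight matrix (flattened)."""
    rng = random.Random(seed)
    # exp(Normal(mean, std)) samples are strictly positive
    weights = [max(rng.lognormvariate(mean, std), FLOAT32_EPS) for _ in range(count)]
    peak = max(weights)
    # Rescale so the largest weight is exactly 1, then keep everything >= eps
    return [max(w / peak, FLOAT32_EPS) for w in weights]


weights = log_normal_weights(1000)
```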

In [45]:
!curl -OL 'http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz'

In [46]:
import numpy as np

data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']
H, W = images.shape[1:3]
(images.shape, poses.shape, focal)
# images: [N, H, W, C]; poses: [N, 4, 4] camera-to-world; focal: focal length in pixels

((106, 100, 100, 3), (106, 4, 4), array(138.8888789))

In [47]:
poses[0]

array([[-9.9990219e-01,  4.1922452e-03, -1.3345719e-02, -5.3798322e-02],
       [-1.3988681e-02, -2.9965907e-01,  9.5394367e-01,  3.8454704e+00],
       [-4.6566129e-10,  9.5403719e-01,  2.9968831e-01,  1.2080823e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00]],
      dtype=float32)
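The printed pose can be read directly. Assuming the usual camera-to-world layout (upper-left $3 \times 3$ block is the rotation, last column is the camera center) and a camera that looks along its local $-z$ axis, the sketch below computes how far the camera sits from the world origin and whether it points at the origin.

```python
import math

# Values copied from poses[0] above (as printed)
pose = [
    [-9.9990219e-01, 4.1922452e-03, -1.3345719e-02, -5.3798322e-02],
    [-1.3988681e-02, -2.9965907e-01, 9.5394367e-01, 3.8454704e00],
    [-4.6566129e-10, 9.5403719e-01, 2.9968831e-01, 1.2080823e00],
    [0.0, 0.0, 0.0, 1.0],
]

position = [row[3] for row in pose[:3]]  # camera center in world coordinates
forward = [-row[2] for row in pose[:3]]  # viewing direction: the camera's -z axis
distance = math.sqrt(sum(x * x for x in position))

# Cosine between the viewing direction and the direction from the camera to the origin
to_origin = [-x / distance for x in position]
cosine = sum(f * o for f, o in zip(forward, to_origin)) / math.sqrt(sum(f * f for f in forward))
print(f"distance = {distance:.3f}, cosine = {cosine:.4f}")
```

For this pose the camera is roughly 4 units from the origin and looks almost exactly at it, so samples taken only within the first unit of each ray would not reach a scene centered near the origin.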

----
#### Volume Sampling

In [48]:
from torch import Tensor


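class VolumeSampler:
    """
    ## Dimension Transformations
    - Input: `[4, 4]` camera-to-world posture
    - Output: `[height, width, points_per_ray, 3]`
    """
    def __init__(
        self,
        height: int,
        width: int,
        focal: float,
        points_per_ray: int | None = None,
        near: float | None = None,
        far: float | None = None,
    ):
        import torch

        focal_inverse = 1.0 / float(focal)
        unit_half_norm = focal_inverse / 2
        height_half_norm = height * unit_half_norm
        width_half_norm = width * unit_half_norm
        self.points_per_ray = int(points_per_ray or 96)
        # Assumed sampling bounds: 2.0 and 6.0 follow the tiny NeRF example for this dataset
        self.near = 2.0 if near is None else float(near)
        self.far = 6.0 if far is None else float(far)

        # Camera-space direction through every pixel center; the camera looks along -z
        self.directions = torch.stack(
            torch.meshgrid(
                torch.arange(
                    -width_half_norm + unit_half_norm,
                    width_half_norm,
                    focal_inverse,
                ),
                torch.arange(
                    height_half_norm - unit_half_norm,
                    -height_half_norm,
                    -focal_inverse,
                ),
                torch.tensor(-1.0),
                indexing="xy",
            ),
            dim=-1,
        )  # [height, width, 1, 3]

    def __call__(self, posture: Tensor) -> Tensor:
        import torch

        posture = torch.as_tensor(posture, dtype=torch.float)[:3]

        # Stratified sampling: redraw one uniform sample per bin on every call
        fractions = (
            torch.arange(self.points_per_ray, dtype=torch.float)
            .add_(torch.rand(self.points_per_ray))
            .div_(self.points_per_ray)
        )
        distances = (self.near + (self.far - self.near) * fractions).unsqueeze(-1)

        # Rotate directions into world space (R @ d per pixel); origins are the camera center
        directions = (self.directions * posture[:, :3]).sum(dim=-1)
        origins = posture[:, 3].broadcast_to(directions.shape)
        points = origins.unsqueeze(-2) + directions.unsqueeze(-2) * distances
        return points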

In [50]:
VolumeSampler(H, W, focal)(poses[0]).shape

torch.Size([100, 100, 96, 3])