In [None]:
# Allow importing from src
import sys
sys.path.insert(0, '../src/')

## Imports

In [None]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

## Example data

In [None]:
# Load the example archive; it provides the images, one camera-to-world (C2W)
# pose per image and a scalar focal length.
data =  np.load("../data/tiny_nerf_data.npz")
images, c2ws, focal = data["images"], data["poses"], data["focal"]

# Report the array shapes, one per line below the "Shapes:" header
print(
    f"Shapes:",
    f"{images.shape=}",
    f"{c2ws.shape=}",
    f"{focal.shape=}",
    sep='\n  ',
)

# Show one sample image together with its pose as a sanity check
plt.imshow(images[2])
plt.title("Image at index 2")

print(f"C2W transform at index 2: \n", c2ws[2])

print(f"Focal length: {focal:.4f}")

# From here on the notebook works with torch tensors; the names are rebound,
# so the NumPy versions are no longer available after this cell.
images = torch.from_numpy(images)
c2ws = torch.from_numpy(c2ws)
focal = torch.from_numpy(focal)

## Utility functions

To be exported into `ROOT_DIR/src/utils/`

In [None]:
def create_rays(height, width, intrinsic, c2w):
    """Build one world-space ray (origin and direction) for every image pixel.

    In camera space the rays point towards -z, with +x to the right and +y up
    (the image row index grows downwards, hence the sign flip on y).

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        intrinsic: Tensor indexed like a 3x3 pinhole matrix. Only the focal
            lengths ``[0, 0]`` / ``[1, 1]`` and the principal point
            ``[0, 2]`` / ``[1, 2]`` are read.
        c2w: Camera-to-world transform. Its top-left 3x3 block is used as the
            rotation and the first three entries of its last column as the
            camera position.

    Returns:
        Tuple ``(ray_origins, ray_directions)``, both of shape
        ``(height, width, 3)``. Directions are not normalised. The origins are
        a broadcast view of the camera position, not a fresh copy.
    """
    focal_x = intrinsic[0, 0]
    focal_y = intrinsic[1, 1]
    # cx and cy handle the misalignment of the principal point with the center of the image
    cx = intrinsic[0, 2]
    cy = intrinsic[1, 2]

    # Build the pixel grid on the pose's device and floating dtype so that the
    # matmul below also works for CUDA or float64 poses.
    dtype = c2w.dtype if c2w.is_floating_point() else torch.float32
    device = c2w.device

    # Index each point on the image, determine ray directions to them.
    # With indexing='xy' both grids have shape (height, width).
    i, j = torch.meshgrid(
        torch.arange(width, dtype=dtype, device=device),
        torch.arange(height, dtype=dtype, device=device),
        indexing='xy',
    )
    directions = torch.stack((
        (i - cx) / focal_x,
        -(j - cy) / focal_y,
        -torch.ones_like(i),  # -1 since ray is cast away from camera
    ), -1)

    # Transform ray directions to World: row-vector form of R @ d
    ray_directions = directions @ c2w[:3, :3].T
    # c2w last column determines position; origins only need broadcasting
    ray_origins = torch.broadcast_to(c2w[:3, -1], ray_directions.shape)

    return ray_origins, ray_directions


# Test on real data
ex_index = 2
ex_img = images[ex_index]
# Pinhole intrinsics: the same focal length on both axes, principal point at
# the integer image centre (width // 2, height // 2).
ex_intr = torch.tensor([
    [focal.item(), 0, ex_img.shape[1] // 2],
    [0, focal.item(), ex_img.shape[0] // 2],
    [0, 0, 1],
], dtype=torch.float32)

# One ray per pixel of the example image
origins, directions = create_rays(ex_img.shape[0], ex_img.shape[1], ex_intr, c2ws[ex_index])

# NOTE(review): the string literal below is a disabled alternative test on a
# hand-made 5x5 camera. It is evaluated as a no-op expression; consider moving
# it to a real test or deleting it.
"""
# Test on example data
ex_intr = torch.tensor([
    [4, 0, 5 // 2],
    [0, 4, 5 // 2],
    [0, 0, 1],
], dtype=torch.float32)

ex_c2w = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
], dtype=torch.float32)

o, d = create_rays(5, 5, ex_intr, ex_c2w)
"""

print(f"{origins.shape=} | {directions.shape=}")

In [None]:
def create_nerf_data(images, c2ws, focal):
    """Turn posed images into flat per-pixel ray and colour tensors.

    Args:
        images: Iterable of images, each indexed as (height, width, channels).
        c2ws: Iterable of camera-to-world transforms, paired with ``images``.
        focal: Scalar tensor holding the focal length shared by all images.

    Returns:
        Tuple ``(rays, colors)``. ``rays`` has shape ``(total_pixels, 2, 3)``
        where ``rays[:, 0]`` are ray origins and ``rays[:, 1]`` are ray
        directions. ``colors`` has shape ``(total_pixels, channels)`` and is
        row-aligned with ``rays``.
    """
    ray_chunks, color_chunks = [], []

    # Per-image chunks are gathered in lists and concatenated once at the end
    for image, c2w in zip(images, c2ws):
        img_height, img_width = image.shape[0], image.shape[1]
        # Principal point sits at the integer centre of the image
        intrinsic = torch.tensor([
            [focal.item(), 0, img_width // 2],
            [0, focal.item(), img_height // 2],
            [0, 0, 1],
        ], dtype=torch.float32)

        ray_origins, ray_directions = create_rays(img_height, img_width, intrinsic, c2w)

        # (height, width, 2, 3): index 0 -> origin, index 1 -> direction
        per_pixel = torch.stack([ray_origins, ray_directions], dim=2)
        # Collapse the pixel grid so every row is one ray / one colour
        ray_chunks.append(per_pixel.flatten(0, 1))
        color_chunks.append(image.flatten(0, 1))

    return torch.cat(ray_chunks, dim=0), torch.cat(color_chunks, dim=0)


# Build the full training set from every image/pose pair loaded above
rays, colors = create_nerf_data(images, c2ws, focal)

print(f"{rays.shape=} | {colors.shape=}")
# Memory footprint = bytes per element * number of elements, reported in MiB
print(f"Training rays' size:   {rays.element_size() * rays.nelement() / (1024**2):.2f} MB")
print(f"Training images' size: {colors.element_size() * colors.nelement() / (1024**2):.2f} MB")

In [None]:
class PositionalEncoding(nn.Module):
    """Appended positional encoding.

    Each input feature ``x`` is kept as-is and extended with ``sin(2^k * x)``
    and ``cos(2^k * x)`` for ``k = 0 .. max_freq - 1``.

    Args:
        max_freq: Number of frequency bands.
    """

    def __init__(self, max_freq: int):
        super().__init__()
        self._max_freq = max_freq
        # Registered as a buffer so that .to()/.cuda()/.double() move it along
        # with the module. persistent=False keeps it out of state_dict, so
        # checkpoints look exactly as they did with a plain attribute.
        self.register_buffer(
            "_freq_bands",
            2.0 ** torch.linspace(0.0, max_freq - 1, steps=max_freq, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, x):
        """Encode ``x`` of shape ``(..., M)`` into ``(..., M + 2 * M * max_freq)``.

        The output layout is ``(x, all sin parts, all cos parts)``; within the
        sin/cos parts the bands of one input feature are adjacent.
        """
        encs = (x[..., None] * self._freq_bands).flatten(-2, -1)
        return torch.cat([x, encs.sin(), encs.cos()], dim=-1)

    def get_out_dim(self, in_dim):
        """Return the size of the last output dimension for ``in_dim`` input features."""
        return in_dim + in_dim * self._max_freq * 2
    
    
# Smoke test: 10 frequency bands on 3 input features
enc = PositionalEncoding(10)
print(f"Output dimension: {enc.get_out_dim(3)}")
# Encode the directions (index 1 of the ray tensor) of the first 1000 rays;
# the resulting shape is shown as the cell output.
enc(rays[:1000, 1, :]).shape

In [None]:
def sample_ray_uniformally(origins, directions, near, far, num_samples, perturb=True):
    """Sample evenly spaced (optionally jittered) points along each ray.

    Args:
        origins: Ray origins of shape ``(..., 3)``.
        directions: Ray directions of shape ``(..., 3)``.
        near: Depth of the first sample.
        far: Depth of the last sample.
        num_samples: Number of samples per ray.
        perturb: If True, add uniform noise to every depth. The noise magnitude
            is at most ``(far - near) / (4 * num_samples)``, which is below
            half of the sample spacing, so the depths stay sorted.

    Returns:
        Tuple ``(points, directions, depths)`` with shapes
        ``(..., num_samples, 3)``, ``(..., num_samples, 3)`` and
        ``(..., num_samples)``. ``directions`` is an expanded view of the input.
    """
    # Depths live on the same device as the rays and cover every leading batch
    # dimension of `origins`, not just the first one.
    batch_shape = origins.shape[:-1]
    depths = torch.linspace(near, far, num_samples, dtype=torch.float32, device=origins.device)
    depths = depths.expand(*batch_shape, num_samples)

    if perturb:
        # Jitter stays below half of the step size, which keeps the depths
        # sorted as required for volume rendering
        noise = (torch.rand(depths.shape, device=depths.device) - 0.5) * (far - near) / num_samples / 2
        # Clamping to stay between near and far
        depths = (depths + noise).clamp(near, far)

    points = origins[..., None, :] + directions[..., None, :] * depths[..., :, None]
    # Expand directions to make NeRF input
    directions = directions[..., None, :].expand(points.shape)
    return points, directions, depths


def plot_ray_sampling(points, cartesian_direction, title):
    """Plot sampled points on two side-by-side 3D axes.

    Both axes show the same data; they differ only in the camera angle, which
    is derived from the negated ``cartesian_direction`` plus a fixed offset.

    Args:
        points: Indexed as ``points[:, :, k]`` for ``k`` in 0..2 — presumably
            ``(num_rays, num_samples, 3)``; confirm against callers.
        cartesian_direction: Tensor that unpacks into three components (x, y, z).
        title: Figure-level title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16,8), subplot_kw={"projection": "3d"})
    axes = axes.flatten()
    fig.suptitle(title)
    plt.tight_layout()

    # Convert to spherical coordinates
    X, Y, Z = -cartesian_direction
    R = torch.sqrt(X**2 + Y**2 + Z**2)
    X, Y, Z = X/R, Y/R, Z/R  # normalization
    azim = torch.rad2deg(torch.atan2(Y, X))
    elev = torch.rad2deg(torch.arcsin(Z))
    # Multiple angles to understand better
    for ax, (mod_elev, mod_azim) in zip(axes, [[0,0],[-10, 60]]):
        # NOTE(review): the third positional argument is presumably the roll
        # angle — confirm it is supported by the installed matplotlib version.
        ax.view_init(elev + mod_elev, azim + mod_azim, 0)
        ax.plot(points[:, :, 0], points[:, :, 1], points[:, :, 2], linewidth=0.2, markersize=2, marker='o')
    plt.show()


# Demo on every 71st ray of the first 10000: 7 perturbed samples between depth 0 and 10
points, directions, depths = sample_ray_uniformally(rays[:10000:71, 0], rays[:10000:71, 1], 0.0, 10.0, 7, perturb=True)

print(f"{points.shape=} | {directions.shape=} | {depths.shape=}")
# The view angle is derived from the mean direction of the plotted rays
plot_ray_sampling(points, directions[:, 0, :].mean(0), "Example of perturbed uniform samples along ray")

In [None]:
def sample_pdf(bins, weights, num_samples, deterministic=False):
    """Draw samples from a piecewise-constant PDF by inverting its CDF.

    Args:
        bins: Bin edges of shape ``(N, B + 1)``.
        weights: Unnormalised bin weights of shape ``(N, B)``. Negative values
            are treated as zero.
        num_samples: Number of samples to draw per row.
        deterministic: If True, use evenly spaced quantiles in ``[0, 1]``
            instead of uniform random ones.

    Returns:
        Tensor of shape ``(N, num_samples)`` with values inside
        ``[bins[:, 0], bins[:, -1]]``. In deterministic mode each row is sorted.
    """
    # A PDF cannot be negative: negative weights would make the CDF
    # non-monotonic, which searchsorted does not support. The epsilon keeps
    # the normalisation finite when every weight is zero.
    weights = weights.clamp(min=0) + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # Normalize PDF
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # Prepend 0 to have cdf->[0,1]

    # Quantiles are created with the CDF's dtype and device so that
    # searchsorted receives matching tensors.
    sample_shape = list(cdf.shape[:-1]) + [num_samples]
    if deterministic:
        u = torch.linspace(0.0, 1.0, steps=num_samples, dtype=cdf.dtype, device=cdf.device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, dtype=cdf.dtype, device=cdf.device)

    # Inverting the CDF
    u = u.contiguous()  # searchsorted needs a contiguous memory layout
    indexes = torch.searchsorted(cdf, u, right=True)  # Finding bins
    # Keep the lower and upper edge indexes inside the valid range
    below = (indexes - 1).clamp(min=0)
    above = indexes.clamp(max=cdf.shape[-1] - 1)
    indexes = torch.stack([below, above], dim=-1)

    # Gathering the CDF values and bin edges on both sides of every sample
    shape = [indexes.shape[0], indexes.shape[1], cdf.shape[-1]]
    cdf_bounds = torch.gather(cdf.unsqueeze(1).expand(shape), dim=2, index=indexes)
    bin_bounds = torch.gather(bins.unsqueeze(1).expand(shape), dim=2, index=indexes)

    # The denominator is the probability mass of the selected bin; (near-)empty
    # bins fall back to 1 to avoid dividing by zero
    denominator = cdf_bounds[..., 1] - cdf_bounds[..., 0]
    denominator = torch.where(denominator < 1e-5, torch.ones_like(denominator), denominator)
    # t gives the relative position inside the bins
    t = (u - cdf_bounds[..., 0]) / denominator

    samples = bin_bounds[..., 0] + t * (bin_bounds[..., 1] - bin_bounds[..., 0])
    return samples


# Toy PDF: 11 evenly spaced edges in [0, 1] with hand-picked weights per bin
bins = torch.linspace(0, 1, steps=11) # 10 intervals (bins)
weights = torch.tensor([0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.3, 0.25, 0.2, 0.1]) # Arbitrary weights
weights = weights / torch.sum(weights, -1)

# Expand the inputs to include a batch dimension
bins = bins.unsqueeze(0) # Shape: [1, 11]
weights = weights.unsqueeze(0) # Shape: [1, 10]
# Draw 5 random (non-deterministic) samples from the toy PDF
samples = sample_pdf(bins, weights, 5)

print(f"{samples=}")

In [None]:
def sample_ray_hierarchically(origins, directions, num_samples, bins, weights, deterministic=False):
    """Sample points along rays at depths drawn from a per-ray PDF.

    Args:
        origins: Ray origins of shape ``(N, 3)``.
        directions: Ray directions of shape ``(N, 3)``.
        num_samples: Number of depths to draw per ray.
        bins: Depth bin edges, forwarded to ``sample_pdf``.
        weights: Per-bin weights, forwarded to ``sample_pdf``.
        deterministic: Forwarded to ``sample_pdf``.

    Returns:
        Tuple ``(points, directions, depths)``: the sampled positions, the ray
        directions expanded to the shape of ``points``, and the drawn depths.
    """
    depths = sample_pdf(bins, weights, num_samples, deterministic=deterministic)

    # Insert a sample axis so origins/directions broadcast against the depths
    ray_origins = origins.unsqueeze(-2)
    ray_directions = directions.unsqueeze(-2)
    points = ray_origins + ray_directions * depths.unsqueeze(-1)

    # Every sample needs its own direction entry to form the NeRF input
    return points, ray_directions.expand(points.shape), depths

# Demo on the same strided ray subset as the uniform sampling example. The
# number of rays is derived from the slice instead of being hard-coded.
demo_rays = rays[:10000:71]
bins = torch.linspace(0, 1, 65).expand((demo_rays.shape[0], 65))  # Computed after uniform sampling
# Stand-in for weights derived from NeRF's sigma values. torch.rand is used
# because PDF weights must be non-negative, which torch.randn does not guarantee.
weights = torch.rand((demo_rays.shape[0], 64))
points, directions, depths = sample_ray_hierarchically(demo_rays[:, 0], demo_rays[:, 1], 64, bins, weights)

print(f"{points.shape=} | {directions.shape=} | {depths.shape=}")
plot_ray_sampling(points, directions[:, 0, :].mean(0), "Example of hierarchical samples along ray")

In [None]:
class NeRF(nn.Module):
    """MLP mapping a 3D position and a viewing direction to colour and density.

    Positions and directions are positionally encoded first. A feature MLP with
    optional skip connections processes the encoded position; density is read
    from its output, while colour additionally depends on the encoded direction.

    Args:
        num_layers: Number of layers in the feature MLP.
        hidden_size: Width of the hidden layers.
        in_coordinates: Number of position features before encoding.
        in_directions: Number of direction features before encoding.
        skips: Indexes of feature layers that get the encoded position
            concatenated to their input.
        coord_encode_freq: Number of encoding frequencies for positions.
        dir_encode_freq: Number of encoding frequencies for directions.
    """

    def __init__(self, num_layers=8, hidden_size=256, in_coordinates=3, in_directions=3,
                 skips=(4,), coord_encode_freq=10, dir_encode_freq=4):
        super().__init__()
        self.in_coordinates = in_coordinates
        self.in_directions = in_directions
        self.skips = tuple(skips)

        self.coordinate_encoder = PositionalEncoding(coord_encode_freq)
        self.direction_encoder = PositionalEncoding(dir_encode_freq)

        coord_dim = self.coordinate_encoder.get_out_dim(self.in_coordinates)
        # NOTE(review): the first layer is a bare Linear with no activation,
        # unlike all following layers — confirm this is intended.
        self.feature_mlp = nn.ModuleList([nn.Linear(coord_dim, hidden_size)])
        # go until num_layers - 1 as we already have the initial layer
        for i in range(num_layers - 1):
            self.feature_mlp.append(nn.Sequential(
                # skip with +1 as we already have the initial layer
                nn.Linear(hidden_size + (coord_dim if i+1 in self.skips else 0), hidden_size),
                nn.ReLU(inplace=True),
            ))

        self.sigma_fc = nn.Linear(hidden_size, 1)

        self.color_preproc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
        )
        dir_dim = self.direction_encoder.get_out_dim(self.in_directions)
        self.rgb_mlp = nn.Sequential(
            nn.Linear(hidden_size + dir_dim, hidden_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // 2, 3),
            nn.Sigmoid(),
        )

    def forward(self, coordinates, directions, skip_colors=False):
        """Evaluate the network.

        Args:
            coordinates: Positions of shape ``(..., in_coordinates)``.
            directions: Viewing directions of shape ``(..., in_directions)``.
            skip_colors: If True, skip the colour branch.

        Returns:
            ``(..., 4)`` tensor of ``(r, g, b, sigma)``, or only the ``(..., 1)``
            sigma tensor when ``skip_colors`` is True. RGB passes through a
            sigmoid; sigma is the raw linear output and may be negative.
        """
        coordinates = self.coordinate_encoder(coordinates)
        features = coordinates
        for i, fc in enumerate(self.feature_mlp):
            if i in self.skips:
                # Skip connection: re-inject the encoded position
                features = torch.cat([features, coordinates], -1)
            features = fc(features)

        sigma = self.sigma_fc(features)

        if skip_colors:
            return sigma

        directions = self.direction_encoder(directions)
        features = self.color_preproc(features)
        features = torch.cat([features, directions], dim=-1)
        rgb = self.rgb_mlp(features)

        return torch.cat([rgb, sigma], dim=-1)

    def compute_along_rays(self, origins, directions, near, far, coarse_samples, fine_samples,
                           deterministic=True, perturb=True):
        """Run a coarse uniform pass, then a fine pass sampled from the coarse densities.

        This function deviates from the original NeRF paper as the coarse and
        fine samples are processed by the same model.

        Args:
            origins: Ray origins of shape ``(N, 3)``.
            directions: Ray directions of shape ``(N, 3)``.
            near: Depth of the near plane.
            far: Depth of the far plane.
            coarse_samples: Number of uniform samples per ray.
            fine_samples: Number of hierarchical samples per ray.
            deterministic: Forwarded to the hierarchical sampler. When False
                the fine samples are sorted by depth afterwards.
            perturb: Whether the coarse depths are jittered.

        Returns:
            Tuple ``(coarse_out, coarse_depths, fine_out, fine_depths)`` where
            the ``*_out`` tensors hold ``(r, g, b, sigma)`` per sample.
        """
        points, expanded_directions, coarse_depths = sample_ray_uniformally(
            origins, directions, near, far, coarse_samples, perturb=perturb)
        coarse_out = self(points, expanded_directions)

        # Bin bounds are halfway between sampled coordinates + near + far plane.
        # The plane edges are created from coarse_depths so that they share its
        # dtype and device.
        edge_shape = coarse_depths[..., :1].shape
        bins = torch.cat([
            coarse_depths.new_full(edge_shape, near),
            (coarse_depths[..., 1:] + coarse_depths[..., :-1]) / 2,
            coarse_depths.new_full(edge_shape, far),
        ], -1)

        # Raw sigma may be negative, but PDF weights must not be: clamp at zero.
        # NOTE(review): gradients still flow through these weights into the fine
        # depths — confirm whether a detach() is wanted here.
        weights = torch.relu(coarse_out[..., -1])

        points, expanded_directions, fine_depths = sample_ray_hierarchically(
            origins, directions, fine_samples, bins, weights, deterministic=deterministic)
        fine_out = self(points, expanded_directions)

        # deterministic ensures depth sorted output, if non-deterministic, sort
        # manually as sortedness is required for volume rendering
        if not deterministic:
            fine_depths, idxs = torch.sort(fine_depths, dim=-1)
            # Reorder the sample axis of the outputs like the depths
            fine_out = torch.gather(fine_out, -2, idxs.unsqueeze(-1).expand_as(fine_out))

        return coarse_out, coarse_depths, fine_out, fine_depths
    

# Smoke test of a freshly initialised (untrained) model on the first 50 rays,
# using the ray origins as query positions
print(f"Output shape for origins + directions: {NeRF()(rays[:50, 0], rays[:50, 1]).shape}")

# Full coarse + fine pass on 500 rays: near=0, far=10, 64 coarse and 5 fine
# samples, non-deterministic fine sampling
coarse_rgba, coarse_depths, fine_rgba, fine_depths = NeRF().compute_along_rays(rays[:500, 0], rays[:500, 1], 0, 10.0, 64, 5, False)
print(f"Coarse output shape for hierarchical sampling: {coarse_rgba.shape=} | {coarse_depths.shape=}")
print(f"Fine output shape for hierarchical sampling:   {fine_rgba.shape=} | {fine_depths.shape=}")