In [1]:
import jax
import flax
import optax
import numpy as np
from jax import lax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state, common_utils


import cv2
import rerun as rr
from functools import cache
from matplotlib import pyplot as plt

rr.init("rr_nerf_video")
# rr.start_web_viewer_server()

In [3]:
def load_nerf_data(path):
    data = np.load(path)
    images, poses, focal = data['images'], data['poses'], float(data['focal'])
    timesteps, img_height, img_width, _ = images.shape

    train_images, train_poses = images[:100], poses[:100]
    val_images, val_poses = images[100:], poses[100:]

    return {
        'images': train_images,
        'poses' : train_poses,
    }, {
        'images': val_images,
        'poses' : val_poses,
    }, {
        'height': img_height,
        'width' : img_width,
        'focal' : focal,
    }

train_data, val_data, param_data = load_nerf_data('./tiny_nerf_data.npz')

In [6]:
@cache
def generate_directions(height, width, focal):
    """
    Take the height, width and focal length of the image and generate a tensor of shape [H, W, 3].
    Where each pixel in the image has a vector pointing at it from the focal point. 

    Args:
        height (int): The height of the image.
        width (int) : The width of the image.
        focal (float): The focal length of the camera
    """
    
    i, j = np.meshgrid(np.arange(width), np.arange(height), indexing="xy")
    
    transformed_i = (i - (width/2.)) / focal
    transformed_j = (j - (height/2.))/ focal

    k = -np.ones_like(i)
    directions = np.stack([transformed_i, transformed_j, k], axis=-1)
    return directions
    

def generate_rays(height, width, focal, pose):
    directions = generate_directions(height, width, focal)

    # Equivalent to: np.dot(directions, pose[:3, :3].T)
    ray_directions = np.einsum("ijl,kl", directions, pose[:3, :3])
    
    ray_origins = np.broadcast_to(pose[:3, -1], ray_directions.shape)

    print(ray_origins.shape, ray_directions.shape)

    return {
        'origins': ray_origins,
        'directions': ray_directions
    }

rays = generate_rays(param_data['height'], param_data['width'], param_data['focal'], train_data['poses'][0])

rec = rr.memory_recording()
rr.log("ground_truth/rays", rr.Arrows3D(origins = np.reshape( rays['origins'], [-1, 3] ), vectors = np.reshape( rays['directions'], [-1, 3] )))
rec

(100, 100, 3) (100, 100, 3)


[2023-11-01T23:45:24Z WARN  re_sdk::log_sink] Dropping data in MemorySink


In [8]:
def compute_3d_points(ray_origins, ray_directions, random_number_generator=None, near_bound = 2., far_bound = 6., num_sample_points = 256):
    """
    This function computes 3D query points for volumetric rendering along rays defined by their origins and directions.
    It parametrically samples points along each ray within specified bounds and returns the 3D query points as well as the corresponding parameter values.
    Optionally, uniform noise can be added to the sample space to make it more continuous when a random number generator is provided.

    The ray equation used for point computation is: r(t) = o + t * d, where:
    - r(t) is the 3D point along the ray.
    - o is the ray origin.
    - t is the parameter that varies from near_bound to far_bound.
    - d is the ray direction vector.

    Parameters:
    -----------
    ray_origins : numpy.ndarray
        Array of ray origins with shape (..., 3).

    ray_directions : numpy.ndarray
        Array of ray directions with shape (..., 3).

    random_number_generator : jax.random.PRNGKey, optional
        A random number generator for adding noise to `t_vals`. Default is None.

    near_bound : float, optional
        The lower bound of the parametric distance. Default is 2.0.

    far_bound : float, optional
        The upper bound of the parametric distance. Default is 6.0.

    num_sample_points : int, optional
        The number of sample points along each ray. Default is 256.

    Returns:
    --------
    points : numpy.ndarray
        An array of 3D points computed along each ray.
        Shape: (..., num_sample_points, 3).

    t_vals : numpy.ndarray
        An array of parameter values representing the positions of the 3D points along each ray.
        Shape: (..., num_sample_points).

    Example:
    --------
    >>> rays = generate_rays(...)
    >>> points, t_vals = compute_3d_points(rays['origins'], rays['directions'])
    """
    # Sample space to parametrically compute the ray points
    t_vals = np.linspace(near_bound, far_bound, num_sample_points)
    print(t_vals.shape)
    if random_number_generator is not None:
        # inject a uniform noise into the sample space to make it continuous
        t_shape = ray_origins.shape[:-1] + (num_sample_points,)
        noise = jax.random.uniform(
            random_number_generator, t_shape
        ) * (far_bound - near_bound) / num_sample_points
        t_vals = t_vals + noise
    
    # Compute the ray traversal points using r(t) = o + t * d
    ray_origins = ray_origins[..., None, :]
    ray_directions = ray_directions[..., None, :]
    t_vals_flat = t_vals[..., :, None]
    points = ray_origins + ray_directions * t_vals_flat
    print("XXXXX", points.shape)
    return points, t_vals

points, t_vals = compute_3d_points( rays['origins'], rays['directions'] )


print(points.shape, t_vals.shape)

rec_points = np.concatenate(
    (
        np.reshape(rays['origins'], [-1, 3])[::4,:],
        np.reshape(points[::4, ::4, ::4, :], [-1,3])
    ), axis = 0
)
rec = rr.memory_recording()
rr.log(
    "ground_truth/3d-points", 
    rr.Points3D(rec_points, colors = (0, 255, 255)),
)

print(rays['directions'].shape, points.shape)

rec

(256,)
XXXXX (100, 100, 256, 3)
(100, 100, 256, 3) (256,)
(100, 100, 3) (100, 100, 256, 3)


[2023-11-01T23:47:39Z WARN  re_sdk::log_sink] Dropping data in MemorySink


In [54]:
class Nerf(nn.Module):
    num_layers: int = 8
    width: int = 256
    dtype = jnp.float32
    precision = lax.Precision.DEFAULT

    positional_encoding_dims: int = 6
    
    def setup(self):
        layers = []
        for layer in range(self.num_layers):
            layers.append(
                nn.Dense(self.width, use_bias = True, dtype = self.dtype, precision = self.precision)
            )
        self.ls = layers

        self.out = nn.Dense(4, dtype = self.dtype, precision = self.precision)

        self.pos_enc_mat = jax.vmap(
            lambda x: 2. ** x
        )(jnp.arange(self.positional_encoding_dims))[np.newaxis, :]

    def positional_encoding(self, x):
        bs = x.shape[0]
        inputs_freq = x * self.pos_enc_mat
        periodic_fns = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)], axis = -1)
        periodic_fns = periodic_fns.reshape([bs, -1])
        periodic_fns = jnp.concatenate([x, periodic_fns], axis = -1)
        return periodic_fns
    
    def __call__(self, x):
        y = x
        for i in range(self.num_layers):
            l = self.ls[i]
            y = l(y)
            y = nn.relu(y)
            if i == 4:
                y = jnp.concatenate(
                    [ y, x ], axis = -1
                )
        y = self.out(y)
        return y

# rng = jax.random.PRNGKey(42)
# rng, inp_rng, init_rng = jax.random.split(rng, 3)

# input = jax.random.normal(inp_rng, (10, 6))

# n = Nerf()
# n_params = n.init(init_rng, input)

# o = n.apply(n_params, input, method = Nerf.positional_encoding)
# #o = n.positional_encoding(jnp.array(np.random.rand(10, 6)))
# print(o.shape)
# del n, o, rng, n_params
points.shape

(100, 100, 256, 3)

In [28]:
def forward_pass(model, params, points):
    batch_size = points.shape[0]
    model_output = lax.map( 
        lambda x: 

Array([[0.31241986, 0.5143503 , 0.70449203, 0.15713929, 0.7353689 ,
        0.06814588, 0.50243795, 0.17679498, 0.12858652, 0.23371685],
       [0.08318126, 0.95672715, 0.2124801 , 0.16399641, 0.0363207 ,
        0.9557558 , 0.13908705, 0.8780257 , 0.7389254 , 0.33879337],
       [0.13222605, 0.48189947, 0.7746314 , 0.41167384, 0.3951409 ,
        0.87194014, 0.74472666, 0.6649564 , 0.5595958 , 0.46317744],
       [0.38234416, 0.02431299, 0.07739554, 0.69124   , 0.29505625,
        0.33975714, 0.87353826, 0.8001659 , 0.35694864, 0.72848016],
       [0.36674199, 0.40733728, 0.76831585, 0.15489765, 0.2293656 ,
        0.5034998 , 0.88148266, 0.5861722 , 0.85926497, 0.43841758],
       [0.823101  , 0.05872972, 0.28203857, 0.6043108 , 0.22602063,
        0.87232107, 0.2373956 , 0.64771074, 0.49281663, 0.21767734],
       [0.02035649, 0.7400929 , 0.2502096 , 0.31951928, 0.75557643,
        0.91947335, 0.6760751 , 0.4499451 , 0.58007526, 0.9826965 ],
       [0.5032982 , 0.15859026, 0.5080640

(100, 100, 256, 3)