In [None]:
# Allow importing from src
import sys
sys.path.insert(0, '../src/')

## Imports

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

## Example data

In [None]:
# Index of the training view that is previewed below.
PREVIEW_INDEX = 2

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# The context manager closes it once the three arrays have been read, so any
# later access has to go through the arrays extracted here, not through `data`.
with np.load("../data/tiny_nerf_data.npz") as data:
    images_np, c2ws_np, focal_np = data["images"], data["poses"], data["focal"]

print(
    "Shapes:",
    f"Images: {images_np.shape}",
    f"C2ws:   {c2ws_np.shape}",
    f"Focal:  {focal_np.shape}",
    sep='\n  ',
)

plt.imshow(images_np[PREVIEW_INDEX])
plt.title(f"Image at index {PREVIEW_INDEX}")

print(f"C2W transform at index {PREVIEW_INDEX}: \n", c2ws_np[PREVIEW_INDEX])

print(f"Focal length: {focal_np:.4f}")

# The following cells work on torch tensors under these names; the numpy
# originals stay available under the *_np names.
images = torch.from_numpy(images_np)
c2ws = torch.from_numpy(c2ws_np)
focal = torch.from_numpy(focal_np)

## Utility functions

The functions below are meant to be exported into `ROOT_DIR/src/utils/`.

In [None]:
def create_rays(height, width, intrinsic, c2w):
    """Create one camera ray per pixel, expressed in world coordinates.

    Args:
        height: Number of pixel rows in the image.
        width: Number of pixel columns in the image.
        intrinsic: Camera intrinsics, indexed as ``intrinsic[row, col]``.
            Entries ``[0, 0]`` and ``[1, 1]`` are read as the focal lengths
            along x and y, entries ``[0, 2]`` and ``[1, 2]`` as the principal
            point (cx, cy).
        c2w: Camera-to-world transform tensor. Its upper-left 3x3 block is
            used as the rotation and the first three entries of its last
            column as the camera position.

    Returns:
        Tuple ``(ray_origins, ray_directions)``, both of shape
        ``(height, width, 3)``. Directions are not normalised. The origins
        are a broadcast view of the camera position, so clone them before
        writing to them in place.
    """
    # Build the pixel grid on the same device and floating dtype as c2w so the
    # matmul below does not fail for float64 or non-CPU poses. Non-floating
    # poses fall back to float32.
    dtype = c2w.dtype if c2w.is_floating_point() else torch.float32
    device = c2w.device
    rotation = c2w[:3, :3].to(dtype)
    position = c2w[:3, -1].to(dtype)  # c2w last column determines position

    focal_x = intrinsic[0, 0]
    focal_y = intrinsic[1, 1]
    # cx and cy handle the misalignment of the principal point with the center of the image
    cx = intrinsic[0, 2]
    cy = intrinsic[1, 2]

    # Index each point on the image: with indexing='xy' both grids have shape
    # (height, width); `cols` varies along the width, `rows` along the height.
    cols, rows = torch.meshgrid(
        torch.arange(width, dtype=dtype, device=device),
        torch.arange(height, dtype=dtype, device=device),
        indexing='xy',
    )

    # Ray direction towards each pixel in the camera frame
    directions = torch.stack((
        (cols - cx) / focal_x,
        -(rows - cy) / focal_y,
        -torch.ones_like(cols),  # -1 since ray is cast away from camera
    ), -1)

    # Transform ray directions to World (row-vector form of R @ d);
    # origins just need to be broadcasted accordingly
    ray_directions = directions @ rotation.T
    ray_origins = torch.broadcast_to(position, ray_directions.shape)

    return ray_origins, ray_directions


# Sanity check on a synthetic camera whose rays can be verified by hand:
# 5x5 image, focal length 4, principal point (2, 2), identity rotation and a
# camera position of (0, 0, 1).
toy_intr = torch.tensor([
    [4, 0, 5 // 2],
    [0, 4, 5 // 2],
    [0, 0, 1],
], dtype=torch.float32)

toy_c2w = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
], dtype=torch.float32)

toy_o, toy_d = create_rays(5, 5, toy_intr, toy_c2w)
assert toy_o.shape == toy_d.shape == (5, 5, 3)
# Every ray starts at the camera position
assert torch.equal(toy_o, toy_c2w[:3, -1].expand(5, 5, 3))
# The ray through the principal point looks straight down -z
assert torch.allclose(toy_d[2, 2], torch.tensor([0.0, 0.0, -1.0]))
# Top-left pixel: (0 - 2) / 4 along x, -(0 - 2) / 4 along y
assert torch.allclose(toy_d[0, 0], torch.tensor([-0.5, 0.5, -1.0]))

# Test on real data
ex_index = 2
ex_img = images[ex_index]
ex_height, ex_width = ex_img.shape[:2]
# Same focal length on both axes, principal point at the (integer) image center
ex_intr = torch.tensor([
    [focal.item(), 0, ex_width // 2],
    [0, focal.item(), ex_height // 2],
    [0, 0, 1],
], dtype=torch.float32)

o, d = create_rays(ex_height, ex_width, ex_intr, c2ws[ex_index])
assert o.shape == d.shape == (ex_height, ex_width, 3)

print(o.shape, d.shape)