In [None]:
import sys 
sys.path.append("../")

import nerf_model
import dataloader

import plotly
import torch 
import cv2
from PIL import Image
import itertools 
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

In [None]:
# Re-import the project-local `dataloader` module so that edits made to it on
# disk are picked up without restarting the kernel.
from importlib import reload 
dataloader = reload(dataloader)

In [None]:
# Build the 'train' dataset and a dataloader over the same directory.
# The third argument (4096) is presumably the number of rays/pixels sampled per
# item -- TODO confirm against dataloader.SyntheticDataset.
ds = dataloader.SyntheticDataset('../data/toy_data/', 'train', 4096)
dl = dataloader.getSyntheticDataloader('../data/toy_data/', 'train', 4096)
# Pull one batch and list the keys it exposes.
batch = next(iter(dl))
print(batch.keys())
# NOTE(review): later cells read `images`, `poses`, `N`, `H` and `W`; in the
# visible part of this notebook they are only assigned by the commented-out
# lines below, so those cells depend on leftover kernel state.
# images = batch['image']
# poses = batch['cam_to_world']
# N, C, H, W = images.shape

In [None]:
# Display the shape of the batch's 'origin' entry (presumably the ray origins
# -- TODO confirm against the dataloader).
batch['origin'].shape

In [None]:
def visualize(coords, rgb):
    """Build a Plotly 3D scatter trace for a set of points.

    Args:
        coords: 2-D array-like of point coordinates. It is read as one point
            per column (rows 0, 1, 2 are x, y, z). If it has exactly 4
            columns it is treated as one point per row and transposed first.
        rgb: marker colour. A string is handed to Plotly unchanged; any other
            value is transposed with ``.T`` before being used as the marker
            colour.

    Returns:
        A ``go.Scatter3d`` trace in ``'markers'`` mode with marker size 2.
    """
    # Unpacking doubles as a check that ``coords`` is 2-D.
    _, n_cols = coords.shape
    if n_cols == 4:
        # One point per row -> one point per column.
        coords = coords.T
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses as colour names.
    if not isinstance(rgb, str):
        rgb = rgb.T
    return go.Scatter3d(
        x=coords[0], y=coords[1], z=coords[2],
        mode='markers',
        marker=dict(size=2, color=rgb),
    )

## Figuring out how to convert images to rays

In [None]:
def convert_to_ndc_rays(o_rays, d_rays, focal, width, height, near=1.0):
    """Map rays to normalized device coordinates (NDC).

    Each origin is first moved along its ray onto the plane ``z = -near``; the
    shifted origin and the direction are then projected.

    Args:
        o_rays: [..., 3] tensor of ray origins (any number of leading dims,
            e.g. [N, 3] or [B, N, 3]).
        d_rays: [..., 3] tensor of ray directions, same shape as ``o_rays``.
            The z component is used as a divisor and is not checked for zero.
        focal: focal length; x is scaled by ``focal / (width / 2)`` and y by
            ``focal / (height / 2)``.
        width: the maximum width
        height: the maximum height
        near: the near depth bound (i.e. 1)

    Returns:
        ``(o_rays_new, d_rays_new)``: ray origins and directions in NDC, each
        of shape [..., 3].
    """
    # Ellipsis indexing keeps this usable for unbatched [N, 3] rays as well as
    # batched [B, N, 3] ones.
    t_near = -(near + o_rays[..., 2]) / d_rays[..., 2]
    o_rays = o_rays + t_near[..., None] * d_rays
    ox, oy, oz = o_rays[..., 0], o_rays[..., 1], o_rays[..., 2]
    dx, dy, dz = d_rays[..., 0], d_rays[..., 1], d_rays[..., 2]

    ox_new = -1.0 * focal / (width / 2) * (ox / oz)
    oy_new = -1.0 * focal / (height / 2) * (oy / oz)
    oz_new = 1.0 + (2 * near) / oz

    dx_new = -1.0 * focal / (width / 2) * ((dx / dz) - (ox / oz))
    dy_new = -1.0 * focal / (height / 2) * ((dy / dz) - (oy / oz))
    dz_new = (-2 * near) / oz

    o_rays_new = torch.stack([ox_new, oy_new, oz_new], dim=-1)
    d_rays_new = torch.stack([dx_new, dy_new, dz_new], dim=-1)

    return o_rays_new, d_rays_new

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Convert rays to normalized device coordinates (NDC).

    NDC is the space in which the canvas is a cube with sides [-1, 1] along
    every axis.

    Args:
      H: int. Image height in pixels.
      W: int. Image width in pixels.
      focal: float. Focal length of the pinhole camera.
      near: float or array of shape [batch_size]. Near depth bound of the scene.
      rays_o: array of shape [batch_size, 3]. Ray origins.
      rays_d: array of shape [batch_size, 3]. Ray directions.
    Returns:
      rays_o: array of shape [batch_size, 3]. Ray origins in NDC.
      rays_d: array of shape [batch_size, 3]. Ray directions in NDC.
    """
    # Slide every origin along its ray until it sits on the z = -near plane.
    shift = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + shift[..., None] * rays_d

    ox, oy, oz = rays_o[..., 0], rays_o[..., 1], rays_o[..., 2]
    dx, dy, dz = rays_d[..., 0], rays_d[..., 1], rays_d[..., 2]

    # Per-axis projection scales shared by origins and directions.
    x_scale = -1. / (W / (2. * focal))
    y_scale = -1. / (H / (2. * focal))

    origins_ndc = torch.stack(
        [x_scale * ox / oz,
         y_scale * oy / oz,
         1. + 2. * near / oz],
        -1)
    directions_ndc = torch.stack(
        [x_scale * (dx / dz - ox / oz),
         y_scale * (dy / dz - oy / oz),
         -2. * near / oz],
        -1)

    return origins_ndc, directions_ndc



In [None]:
# Sanity check: convert_to_ndc_rays must agree with the reference ndc_rays on
# random rays.
torch.manual_seed(0)  # fixed seed so the check is reproducible across runs
rays_o = torch.rand((1,10,3))
rays_d = torch.rand((1,10,3))
ndc_o_rays, ndc_d_rays = convert_to_ndc_rays(rays_o, rays_d, 0.6, 4, 4)

bmild_o_rays, bmild_d_rays = ndc_rays(4, 4, 0.6, 1, rays_o.squeeze(0), rays_d.squeeze(0))
# The two functions order their float32 multiplications/divisions differently,
# so results can differ in the last bit; the default rtol=1e-7 is below
# float32 resolution, hence the explicit tolerance.
np.testing.assert_allclose(ndc_o_rays.squeeze(0).numpy(), bmild_o_rays.numpy(), rtol=1e-6)
# Directions are checked as well as origins.
np.testing.assert_allclose(ndc_d_rays.squeeze(0).numpy(), bmild_d_rays.numpy(), rtol=1e-6)

In [None]:
# Repeat the comparison without squeezing: both functions receive `rays_o` /
# `rays_d` with their leading dimension left in place.
print(rays_o.shape)
ndc_o_rays, ndc_d_rays = convert_to_ndc_rays(rays_o, rays_d, 0.6, 4, 4)

print(ndc_o_rays.shape)

bmild_o_rays, bmild_d_rays = ndc_rays(4, 4, 0.6, 1, rays_o, rays_d)

# Only the origins are compared here; the NDC directions are not checked.
np.testing.assert_allclose(ndc_o_rays.numpy(), bmild_o_rays.numpy(), rtol=1e-6)

In [None]:
def np_get_rays(H, W, focal, c2w):
    """Compute per-pixel ray origins and directions for a pinhole camera (NumPy).

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length used to scale the pixel offsets.
        c2w: matrix whose top-left 3x3 block rotates the per-pixel directions
            and whose last column (first three rows) is used as the ray origin.

    Returns:
        ``(rays_o, rays_d)``, both of shape [H, W, 3]. ``rays_o`` is a
        broadcast view of the single origin taken from ``c2w``.
    """
    cols = np.arange(W, dtype=np.float32)
    rows = np.arange(H, dtype=np.float32)
    px, py = np.meshgrid(cols, rows, indexing='xy')

    # Direction through each pixel before rotation: x/y are offsets from the
    # image centre divided by focal (y negated), z is fixed at -1.
    cam_dirs = np.stack(
        [(px - W * .5) / focal, -(py - H * .5) / focal, -np.ones_like(px)],
        axis=-1,
    )

    rotation = c2w[:3, :3]
    # Broadcast multiply + sum over the last axis == rotation @ direction.
    rays_d = (cam_dirs[..., np.newaxis, :] * rotation).sum(axis=-1)
    rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
    return rays_o, rays_d

# `pose` was previously only assigned in a later cell, so this cell failed on
# a fresh kernel; load the first item's camera-to-world matrix here.
pose = ds[0]['cam_to_world']
np_rays_o, np_rays_d = np_get_rays(4, 4, 0.6, pose.numpy())

def get_rays(H, W, focal, c2w):
    """Compute per-pixel ray origins and directions for a pinhole camera (torch).

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length used to scale the pixel offsets.
        c2w: tensor whose top-left 3x3 block rotates the per-pixel directions
            and whose last column (first three rows) is used as the ray origin.

    Returns:
        ``(rays_o, rays_d)``, both of shape [H, W, 3].
    """
    cols = torch.arange(W, dtype=torch.float32)
    rows = torch.arange(H, dtype=torch.float32)
    px, py = torch.meshgrid(cols, rows, indexing='xy')

    # Direction through each pixel before rotation: x/y are offsets from the
    # image centre divided by focal (y negated), z is fixed at -1.
    x_dir = (px - W * .5) / focal
    y_dir = -(py - H * .5) / focal
    z_dir = -torch.ones_like(px)
    cam_dirs = torch.stack([x_dir, y_dir, z_dir], -1)

    rotation = c2w[:3, :3]
    # Broadcast multiply + sum over the last axis == rotation @ direction.
    rays_d = torch.sum(cam_dirs[..., None, :] * rotation, -1)
    rays_o = torch.broadcast_to(c2w[:3, -1], rays_d.shape)
    return rays_o, rays_d

# Run the torch implementation on the same inputs as the NumPy call above and
# check that both return the same origins and directions.
# Note: this rebinds the notebook-level `rays_o` / `rays_d`.
rays_o, rays_d = get_rays(4, 4, 0.6, pose)

np.testing.assert_allclose(rays_o.numpy(), np_rays_o)
np.testing.assert_allclose(rays_d.numpy(), np_rays_d)



In [None]:
# Take the camera-to-world matrix of the first dataset item and print two
# things for comparison: the homogeneous origin (0, 0, 0, 1) mapped through
# the matrix, and the first three rows of the matrix's last column.
pose = ds[0]['cam_to_world']
origin = torch.Tensor([[0,0,0,1]]).T  # column vector, shape [4 x 1]
print(pose @ origin)
print(pose[:3, -1])

In [None]:
# Plot, for every dataset item, the sampled points that hit a non-transparent
# pixel (coloured by that pixel) together with the item's camera position.
# Scene camera settings: only the "up" direction (+y) is set.
camera = dict(
    up=dict(x=0, y=1, z=0),
#     center=dict(x=0, y=0, z=0),
#     eye=dict(x=1.25, y=1.25, z=1.25)
)

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.update_layout(scene_camera=camera)
# Note: the loop rebinds the notebook-level `batch`; the index `i` is unused.
for i, batch in enumerate(ds):
    world_coords = batch['world_coords'] # [4096 x 4]
    pixel_coords = batch['pixel_coords'] # [4096 x 4]
    # First two columns of the pixel coordinates, cast to integer indices.
    pixels = pixel_coords[:,:2].long() # [4096 x 2]
    image = batch['image'] # [4 x 800 x800]
    # Gather the RGBA value at each sampled pixel location.
    rgba = image[:, pixels[:,0], pixels[:,1]] # [4 x 4096]
    # Keep only samples whose alpha channel is at least 1e-5.
    idxs = rgba[3,:] >= 1e-5
    rgb = rgba[:3, idxs]
    world_coords = world_coords[idxs, :]
    plot = visualize(world_coords, rgb)
    # Camera position for this item, drawn in blue.
    cam = visualize(batch['cam_in_world'], 'blue')
    
    fig.add_trace(plot)
    fig.add_trace(cam)

fig.show()

## Batch Training Simulation

In [None]:
# NOTE(review): `torch_generate_ih_coordinates`, `H`, `W`, `cam_angle` and
# `poses` are not defined in the visible part of this notebook; this cell
# relies on state from elsewhere -- confirm before running on a fresh kernel.
pixel_coords = torch_generate_ih_coordinates(H, W, cam_angle)
# Flatten to a [1 x 4 x P] stack of homogeneous column vectors.
pixel_coords = pixel_coords.reshape((-1, 4)).T.unsqueeze(0)
# Batched matrix product with the poses, then swap the last two axes.
world_coords = torch.bmm(poses, pixel_coords).swapaxes(1,2)
# Reshape to 4 rows and keep the first three of them.
world_coords = world_coords.reshape((4, -1))[:3,:] # [::10]
# Keep every 10th column to thin out the plot.
world_coords = world_coords[:,::10]

fig = make_subplots(specs=[[{"secondary_y": True}]])

# Despite its name, this trace holds the transformed points, not the origin.
origin_fig = go.Scatter3d(x=world_coords[0], y=world_coords[1], z=world_coords[2], 
   mode='markers', marker=dict(
       size=2,
       color='purple'
  ))
fig.add_trace(origin_fig)
fig.show()    

# fig.write_html("./batch.html")

## Plot all Datapoints
I can't really tell whether it's correct, but it looks fairly close.

In [None]:
# NOTE(review): `cam_angles`, `generate_ih_coordinates`, `H`, `W`, `N`,
# `images` and `poses` are not defined in the visible part of this notebook --
# confirm where they come from before running on a fresh kernel.
cam_angle = cam_angles[0]
# Homogeneous coordinates flattened to one row of 4 values per pixel.
cam_coords = generate_ih_coordinates(H, W, cam_angle).reshape((-1, 4))
# Homogeneous origin as a [4 x 1] column.
camera = np.array([0,0,0,1]).reshape((4,1))
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Marker at (0, 0, 0).
origin_fig = go.Scatter3d(x=[0], y=[0], z=[0], 
   mode='markers', marker=dict(
       size=2,
       color='purple'
  ))
fig.add_trace(origin_fig)

for i in range(N):
    image = images[i]; pose = poses[i]
    # Flatten the image to [4 x pixels].
    image = image.numpy().reshape((4,-1))
    # Keep pixels whose fourth channel exceeds 1e-5, with their RGB values.
    opacity = image[3,:] > 1e-5; rgb = image[:3,opacity].T
    hit_coords = cam_coords[opacity,:]
    # Plot only every 20th remaining point.
    decimate = 20
    rgb = rgb[::decimate]
    hit_coords = hit_coords[::decimate].T
    # Transform the kept coordinates with this image's pose.
    w_coords = (pose @ hit_coords)
    # Rows 1 and 2 are swapped when plotting (plot y <- row 2, plot z <- row 1).
    img_fig = go.Scatter3d(x=w_coords[0], y=w_coords[2], z=w_coords[1],
               mode='markers', marker=dict(
                   size=1,
                   color=rgb
               ))
    # Camera position: the homogeneous origin mapped through the pose, plotted
    # with the same axis swap.
    cam_pose = pose @ camera
    cam_fig = go.Scatter3d(x=cam_pose[0], y=cam_pose[2], z=cam_pose[1],
               mode='markers', marker=dict(
                   size=5,
                   color='red'
               ))
    fig.add_trace(img_fig)
    fig.add_trace(cam_fig)
    
fig.show()    

# Save the interactive figure next to the notebook.
fig.write_html("./project_all.html")

## Plotting in Camera Frame

In [None]:
# Re-create the dataset, this time without the third positional argument, and
# rebind the notebook-level `ds`. `frame` is its first item.
ds = dataloader.SyntheticDataset('../data/toy_data/', 'train')
frame = ds[0]

print(frame.keys())

In [None]:
# Collect camera positions: start with the homogeneous origin itself, then
# append the origin mapped through every frame's 'transform_matrix'.
camera = np.array([0,0,0,1]).reshape((4,1))
cam_points = [camera]
# After the loop, `frame` and `c2w` hold the LAST frame and its matrix; later
# cells read both.
for frame in ds:
    c2w = frame['transform_matrix']
    camera = np.array([0,0,0,1]).reshape((4,1))
    point = c2w @ camera
    cam_points.append(point)
# One row of 4 homogeneous values per collected point.
cam_points = np.stack(cam_points).reshape((-1,4))
print(cam_points)

In [None]:

# Scatter the collected camera positions; the fourth (homogeneous) component
# is discarded.
x, y, z, _ = cam_points.T
# NOTE(review): `color` is a one-element list while several points are
# plotted; a plain 'red' string may have been intended -- confirm.
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                   mode='markers', marker=dict(
                                       size=5,
                                       color=['red'],
                                   ))])

fig.show()

## Plotting in World Frame

In [None]:
# Value used below as the (negated) z coordinate of the image plane.
# NOTE(review): a pinhole focal length is usually derived from the field of
# view as 0.5 * width / tan(0.5 * angle); arctan of half the angle looks
# unintended -- confirm.
focal_length = np.arctan(0.5 * ds.camera_angle)
print(ds.camera_angle)
print(focal_length)

In [None]:
# Build a [4 x height x width] grid of homogeneous coordinates for an
# 800 x 800 image plane.
height = 800; width = 800
# Note the naming: `x_center` is derived from the height and `y_center` from
# the width.
x_center = height // 2
y_center = width // 2
row_coords, col_coords = np.mgrid[0:height, 0:width]
# Centre the indices and divide by the image size.
row_coords = (row_coords - x_center) / height
col_coords = (col_coords - y_center) / width
# Every pixel gets z = -focal_length and a homogeneous component of 1.
z_axis = np.full((height, width), - focal_length)
perspective = np.ones((height, width))

# Row coordinates go into component 0 and column coordinates into component 1.
cam_coords = np.stack([row_coords, col_coords, z_axis,  perspective], axis=0)
print(cam_coords.shape)
print(frame['image'].shape)

In [None]:
# Plot the non-transparent pixels of `frame` at their image-plane coordinates,
# together with a marker at (0, 0, 0).
# `frame` is whatever the last cell that assigned it left behind.
image = frame['image'].numpy().reshape((4,-1))
# Keep pixels whose fourth channel exceeds 1e-5.
opacity = image[3,:] > 1e-5
rgb = image[:3,opacity].T #reshape((-1,3))

# Flatten the grid to [4 x pixels]; this rebinds the notebook-level
# `cam_coords`, which the next cell reads in this flattened form.
cam_coords = cam_coords.reshape((4, -1))
x, y, z, _ = cam_coords[:,opacity]

img_fig = go.Scatter3d(x=x, y=y, z=z,
                       mode='markers', marker=dict(
                           size=2,
                           color=rgb
                       ))
origin_fig = go.Scatter3d(x=[0], y=[0], z=[0], 
                       mode='markers', marker=dict(
                           size=2,
                           color='purple'
                      ))

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(img_fig)
fig.add_trace(origin_fig)
fig.show()

# Save the interactive figure next to the notebook.
fig.write_html("./cam_perspective_4.html")

In [None]:
# Transform the flattened image-plane coordinates with `c2w` and plot the
# result. `c2w`, `cam_coords`, `opacity`, `rgb` and `cam_points` all come from
# earlier cells.
world_coords = c2w @ cam_coords
# Divide by the homogeneous component, keeping only the non-transparent pixels.
x, y, z = world_coords[:3, opacity] / world_coords[3, opacity]

img_fig = go.Scatter3d(x=x, y=y, z=z,
                       mode='markers', marker=dict(
                           size=2,
                           color=rgb
                       ))
# Last collected camera position (purple marker).
cam_x, cam_y, cam_z, _ = cam_points[-1]
cam_fig = go.Scatter3d(x=[cam_x], y=[cam_y], z=[cam_z], 
                       mode='markers', marker=dict(
                           size=2,
                           color='purple'
                      ))


# Marker at (0, 0, 0), in red.
origin_fig = go.Scatter3d(x=[0], y=[0], z=[0], 
                       mode='markers', marker=dict(
                           size=2,
                           color='red'
                      ))

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(img_fig)
fig.add_trace(origin_fig)
fig.add_trace(cam_fig)
fig.show()

# Save the interactive figure next to the notebook.
fig.write_html("./world_perspective_4.html")