In [None]:
import sys 
sys.path.append("../")

import nerf_model
import dataloader
import nerf_helpers

import cv2
from PIL import Image

import gc
import torch
import itertools
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

### Visualizing Datasets

In [None]:
def visualize(coords, rgb):
    """Build a Plotly 3-D marker trace for a set of points.

    Args:
        coords: 2-D array-like exposing ``.shape`` and ``.T``. When its second
            dimension is 3 it is transposed, so that rows 0, 1 and 2 of the
            result supply the x, y and z coordinates. Any other layout is used
            as given.
        rgb: Value forwarded to ``marker.color``. A string is passed through
            unchanged. A 2-D array whose first dimension is 3 is transposed.
            Anything else (a list, a 1-D array, a 2-D array with another
            leading dimension) is forwarded untouched.

    Returns:
        The ``go.Scatter3d`` trace in ``'markers'`` mode with marker size 2.
        The trace is not added to any figure here.
    """
    # Unpacking keeps the original requirement that coords is exactly 2-D.
    _, n_cols = coords.shape
    if n_cols == 3:
        coords = coords.T

    # isinstance also covers str subclasses. Only 2-D arrays have a layout that
    # can need flipping, so lists and 1-D arrays no longer hit a shape-unpack
    # error and go to Plotly as they are.
    if not isinstance(rgb, str):
        rgb_shape = getattr(rgb, "shape", None)
        if rgb_shape is not None and len(rgb_shape) == 2 and rgb_shape[0] == 3:
            rgb = rgb.T

    return go.Scatter3d(
        x=coords[0],
        y=coords[1],
        z=coords[2],
        mode='markers',
        marker=dict(size=2, color=rgb),
    )

def line_visualize(coords, rgb):
    """Build a Plotly 3-D line trace through a sequence of points.

    Args:
        coords: 2-D array-like exposing ``.shape`` and ``.T``. When its second
            dimension is 3 it is transposed, so that rows 0, 1 and 2 of the
            result supply the x, y and z coordinates. Any other layout is used
            as given.
        rgb: Value forwarded to ``line.color``. A string is passed through
            unchanged. Any other object that has a ``.T`` attribute is always
            transposed, so a (3, N) array reaches Plotly as (N, 3). Objects
            without ``.T``, such as a plain list, are forwarded untouched.

    Returns:
        The ``go.Scatter3d`` trace in ``'lines'`` mode with line width 1.
        The trace is not added to any figure here.
    """
    # Unpacking keeps the original requirement that coords is exactly 2-D.
    _, n_cols = coords.shape
    if n_cols == 3:
        coords = coords.T

    # Arrays are transposed exactly as before. The hasattr guard only stops a
    # plain list of colours from raising AttributeError.
    if not isinstance(rgb, str) and hasattr(rgb, "T"):
        rgb = rgb.T

    return go.Scatter3d(
        x=coords[0],
        y=coords[1],
        z=coords[2],
        mode='lines',
        line=dict(width=1, color=rgb),
    )

In [None]:
# Plot, for every batch after the first five, a point cloud of ray end points
# coloured by the batch's 'image' values, plus a purple marker at the first
# ray origin.
#
# NOTE(review): `sdl` is only assigned in the "Model Debugging" section further
# down, so on a fresh kernel that cell has to be run before this one.
# NOTE(review): this cell reads 'all_origin' / 'all_direc' / 'image', while the
# debugging cell reads 'origin' / 'direc' / 'rgb' from the same loader.
# Confirm the batch exposes both sets of keys.
# NOTE(review): the `secondary_y` spec describes a 2-D xy subplot, yet only
# Scatter3d traces are added below. Check whether a plain figure was intended.
fig = make_subplots(specs=[[{"secondary_y": True}]])

for i, batch in enumerate(sdl):
    if i >= 5:
        gc.collect()
        # Called for its effect on `batch`; the return value is not used.
        nerf_helpers.fix_batchify(batch)

        # Flatten each field to rows of 3 values and keep every 10th row.
        origins, direcs, images = (
            batch[key].reshape(-1, 3)[::10]
            for key in ('all_origin', 'all_direc', 'image')
        )

        # End point of each ray: its origin plus 6 times its direction vector.
        im_coords = origins + 6 * direcs

        pic = visualize(im_coords, images)
        o = visualize(origins[:1], 'purple')
        fig.add_trace(pic)
        fig.add_trace(o)

fig.show()

In [None]:
# For the first two batches, draw the coarse sample points along a sparse set
# of rays as a line trace, plus a purple marker at the first ray origin.
#
# NOTE(review): `sdl` is only assigned in the "Model Debugging" section further
# down, so on a fresh kernel that cell has to be run before this one.
fig = make_subplots(specs=[[{"secondary_y": True}]])

for i, batch in enumerate(sdl):
    if i >= 2:
        break
    gc.collect()
    # Called for its effect on `batch`; the return value is not used.
    nerf_helpers.fix_batchify(batch)

    # Crop rows/columns 200..599, flatten to rows of 3 values and keep every
    # 300th row.
    origins, direcs, images = (
        batch[key][200:600, 200:600].reshape(-1, 3)[::300]
        for key in ('all_origin', 'all_direc', 'image')
    )

    # 16 is passed as the third argument; presumably the number of samples per
    # ray. Confirm against nerf_helpers.generate_coarse_samples.
    samples, ts = nerf_helpers.generate_coarse_samples(origins, direcs, 16)

    # Repeat each ray's colour across the shape of `samples`, then flatten and
    # transpose to (3, N); line_visualize transposes it again.
    rgb = torch.broadcast_to(images[:, None], samples.shape).reshape(-1, 3).T
    samples = samples.view(-1, 3)

    pic = line_visualize(samples, rgb)
    o = visualize(origins[:1], 'purple')
    fig.add_trace(pic)
    fig.add_trace(o)

fig.show()

### Model Debugging

In [None]:
from importlib import reload

# Re-import the three project modules so that edits made on disk are picked up
# without restarting the kernel. They are reloaded in the order listed.
dataloader, nerf_model, nerf_helpers = (
    reload(dataloader),
    reload(nerf_model),
    reload(nerf_helpers),
)
gc.collect()

In [None]:
# Build the 'val' loader for the dataset under base_dir and pull one batch to
# debug the model with.
base_dir = '../data/lego/'
sdl = dataloader.getSyntheticDataloader(
    base_dir, 'val', 1, num_workers=1, shuffle=True
)

batch = next(iter(sdl))
# Called for its effect on `batch`; the return value is not used.
nerf_helpers.fix_batchify(batch)

# Show which fields the batch carries, then the shape of each one used below.
print(batch.keys())
print(*(batch[key].shape for key in ('rgb', 'origin', 'direc', 'xs')), sep='\n')

# Keep the individual fields under short names for the following cells.
o_rays, d_rays, rgb, xs, ys = (
    batch[key] for key in ('origin', 'direc', 'rgb', 'xs', 'ys')
)

In [None]:
# Instantiate the network under test with 64 coarse and 128 fine samples.
# NOTE(review): position_dim / direction_dim look like positional-encoding
# sizes; confirm their meaning against nerf_model.NeRFNetwork.
network = nerf_model.NeRFNetwork(position_dim=10, direction_dim=4, coarse_samples=64,
                 fine_samples=128)

In [None]:
# Debug step: one forward and backward pass, summing the MSE of the coarse and
# fine predictions against the target colours.
# NOTE(review): assumes NeRFNetwork subclasses torch.nn.Module; confirm.
network.train()

# Drop gradients left by an earlier run of this cell, so that re-running it
# does not accumulate into the parameters' .grad.
network.zero_grad()

# Call the module itself rather than .forward(), so registered hooks are run.
pred_dict = network(o_rays, d_rays)
fine_rgbs = pred_dict['fine_rgb_rays']
coarse_rgbs = pred_dict['coarse_rgb_rays']

fine_loss = F.mse_loss(fine_rgbs, rgb)
coarse_loss = F.mse_loss(coarse_rgbs, rgb)
loss = coarse_loss + fine_loss
loss.backward()
print(loss)

In [None]:
# List the entries returned by the forward pass above.
pred_dict.keys()

In [None]:
# Inspect the 'coarse_ts' output.
# NOTE(review): presumably the coarse sample positions along each ray; confirm
# against nerf_model.
pred_dict['coarse_ts']

In [None]:
# Inspect the 'coarse_deltas' output.
# NOTE(review): presumably the spacing between consecutive coarse samples;
# confirm against nerf_model.
pred_dict['coarse_deltas']

In [None]:
# Small 2 -> 2 linear layer for the shape experiment in the cells below.
# Its weights are randomly initialised; no seed is set in this notebook.
fc = torch.nn.Linear(in_features=2, out_features=2)

In [None]:
# The same 2x2 values, once without and once with a leading dimension of
# size 1, to compare how the linear layer treats the two inputs.
single = torch.tensor([[1., 2.], [3., 4.]])
multi = torch.tensor([[[1., 2.], [3., 4.]]])

In [None]:
# Apply the layer to the (1, 2, 2) input; Linear acts on the last dimension.
fc(multi)

In [None]:
# Apply the same layer to the (2, 2) input, for comparison with the cell above.
fc(single)