In [None]:
import sys 
sys.path.append("../")

import nerf_model
import dataloader
import nerf_helpers

import cv2
from PIL import Image

import gc
import torch
import itertools
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
def sample_random_coordinates(N, height, width, alpha=None, proportion=0.5): 
    """Sample N random integer pixel coordinates, optionally biased towards alpha > 0.

    Args:
        N: int representing number of coordinates to sample.
        height: the maximum height (y) value (exclusive).
        width: the maximum width (x) value (exclusive).
        alpha: optional 2-D tensor indexed as alpha[y, x] (e.g. the alpha channel
               of an image). When given, int(N * proportion) coordinates are drawn
               uniformly (with replacement) from the pixels whose alpha is > 0,
               and the remaining ones uniformly from the whole image. If no pixel
               has alpha > 0, all N coordinates are sampled uniformly.
               No weighting if None.
        proportion: proportion of coordinates that have to have alpha values > 0.
    Returns:
        xs: [N,] torch tensor of random ints [0,width)
        ys: [N,] torch tensor of random ints [0,height)
        When alpha is given, the alpha-weighted samples come first.
    """
    if alpha is None:
        xs = torch.randint(0, width, size=(N,))
        ys = torch.randint(0, height, size=(N,))
        return xs, ys

    num_in_alpha = int(N * proportion)
    num_random = N - num_in_alpha

    # (row=y, col=x) index pairs of every pixel with positive alpha. Slicing keeps
    # every candidate inside [0, height) x [0, width); `> 0` makes this work for
    # float alpha channels as well as boolean masks.
    inside = torch.nonzero(alpha[:height, :width] > 0)

    if inside.shape[0] == 0:
        # Nothing to weight towards: fall back to uniform sampling for all N.
        num_in_alpha, num_random = 0, N
        xs_in = torch.empty(0, dtype=torch.long)
        ys_in = torch.empty(0, dtype=torch.long)
    else:
        picks = torch.randint(0, inside.shape[0], size=(num_in_alpha,))
        ys_in = inside[picks, 0]
        xs_in = inside[picks, 1]

    xs_rand = torch.randint(0, width, size=(num_random,))
    ys_rand = torch.randint(0, height, size=(num_random,))

    return torch.cat([xs_in, xs_rand]), torch.cat([ys_in, ys_rand])

# Quick smoke test with a random boolean mask. torch.rand only produces floating
# point values (it rejects dtype=torch.bool), so threshold a float sample instead.
alpha = torch.rand((800, 800)) > 0.5

sample_random_coordinates(4096, 800, 800, alpha)

In [None]:
def visualize(coords, rgb): 
    """Build a plotly 3-D scatter trace for a set of points.

    Args:
        coords: 2-D array-like of points. When its second dimension is 3
            (i.e. [N, 3]) it is transposed to [3, N] before plotting.
        rgb: either a single color string (e.g. 'pink') used for every marker,
            or an array-like that is transposed and handed to plotly as the
            marker colors.
    Returns:
        A go.Scatter3d trace drawn as markers of size 2.
    """
    _, num_cols = coords.shape  # unpacking also insists on a 2-D input
    xyz = coords.T if num_cols == 3 else coords
    marker_color = rgb if type(rgb) is str else rgb.T

    return go.Scatter3d(
        x=xyz[0],
        y=xyz[1],
        z=xyz[2],
        mode='markers',
        marker=dict(size=2, color=marker_color),
    )

In [None]:
# Sanity-check the ray geometry: for every 10th ray of each batch, plot the ray
# origin (black) and the point one direction vector away from it (pink).
# NOTE(review): `sdl` is created in the data-loading cell further down; run that
# cell first (TODO: move this cell below it so Restart & Run All works).
# NOTE(review): relies on batches exposing 'all_origin' / 'all_direc' — confirm the
# current dataloader still returns these keys.
fig = go.Figure()  # plain figure: 3-D traces need no secondary-y subplot grid

for ray_batch in sdl:
    gc.collect()

    origins = ray_batch['all_origin'].view((-1, 3))[::10, :]
    direcs = ray_batch['all_direc'].view((-1, 3))[::10, :]
    points = origins + direcs

    fig.add_trace(visualize(points, 'pink'))
    fig.add_trace(visualize(origins, 'black'))
fig.show()

In [None]:
# Load the synthetic 'lego' training split and pull a single batch of rays to inspect.
base_dir = '../data/lego/'
sdl = dataloader.getSyntheticDataloader(base_dir, 'train', 4096, num_workers=1, shuffle=True)

batch = next(iter(sdl))
# NOTE(review): the return value is discarded, so fix_batchify presumably modifies
# `batch` in place — confirm in nerf_helpers.
nerf_helpers.fix_batchify(batch)
print(batch.keys())
for key in ('rgb', 'origin', 'direc', 'xs'):
    print(batch[key].shape)

# Unpack the per-ray tensors used by the cells below.
o_rays, d_rays, rgb, xs, ys = (
    batch['origin'],
    batch['direc'],
    batch['rgb'],
    batch['xs'],
    batch['ys'],
)

In [None]:
# Sampled pixel coordinates of the loaded batch, one tensor per line.
print(xs, ys, sep='\n')

In [None]:
# NOTE(review): `_crop_rays_outside_center` is defined in the cell below; run it first
# (TODO: reorder so Restart & Run All works). With cropping=0 the helper returns its
# inputs unchanged, so this currently only prints the shape of the uncropped origins.
o_rays, d_rays, rgb = _crop_rays_outside_center(
    cropping=0, xs=xs, ys=ys, o_rays=o_rays, d_rays=d_rays, rgb=rgb, edge_width=150,
)
print(o_rays.shape)

In [None]:
def _crop_rays_outside_center(cropping, xs, ys, o_rays, d_rays, rgb, edge_width=150):
    """Rejects any samples that are outside of the edge with for the next self.cropping iterations. 

    Sampling towards the center when starting training is beneficial because 
    border rgb pixels are (0,0,0) which are not helpful for training. I wanted to
    put this in the dataloader itself, but there's no good ways to keep track of num. 
    iters in the dataloader since it resets every epoch. Hence I put it in the trainer module.

    Args:
        xs: [N,] torch tensor of random ints [0,width)
        ys: [N,] torch tensor of random ints [0,height)
        o_rays: [N x 3] coordinates of the ray origin.
        d_rays: [N x 3] directions of the ray.
        rgb: [N x 3] tensor of colors.
        edge_width: any pixel from the outer edge of the image to the edge_width inwards
                    is removed.
        Returns:
            o_rays, d_rays, and rgb but only in the indices within bounds.
    """
    if cropping > 0:
        IM_HEIGHT = 800 
        IM_WIDTH = 800
        x_idxs = torch.logical_and(xs > edge_width, xs < IM_WIDTH - edge_width)
        y_idxs = torch.logical_and(ys > edge_width, ys < IM_HEIGHT - edge_width)
        idxs = torch.logical_and(x_idxs, y_idxs)
        o_rays = o_rays[idxs,:] 
        d_rays = d_rays[idxs,:]
        rgb = rgb[idxs,:]
        cropping -= 1
    return o_rays, d_rays, rgb

In [None]:
# Over two full passes of the dataloader, print the running batch index (counted
# across both passes) of every batch containing an x coordinate outside [200, 600].
NUM_PASSES = 2
idx = 0
for _ in range(NUM_PASSES):
    # Iterate the loader directly so tqdm can read len(sdl) for the progress bar.
    for ray_batch in tqdm(sdl):
        batch_xs = ray_batch['xs']
        num_outside_center = ((batch_xs > 600) | (batch_xs < 200)).sum().item()
        if num_outside_center > 0:
            print(idx)
        idx += 1

In [None]:
# Exploratory: rebuild the coarse density head with its final layer replaced by a ReLU.
# The result is only displayed, not assigned back to the network.
# NOTE(review): `network` is created in a later cell; run that cell first.
coarse_fn = network.coarse_network
density_layers = list(coarse_fn.density_fn.children())
print(density_layers)
use_relu = [*density_layers[:-1], torch.nn.ReLU()]
torch.nn.Sequential(*use_relu)

In [None]:
# Pick up edits to the local project modules without restarting the kernel.
# (`%load_ext autoreload` + `%autoreload 2` in the setup cell would do this automatically.)
from importlib import reload

dataloader, nerf_model, nerf_helpers = (
    reload(dataloader),
    reload(nerf_model),
    reload(nerf_helpers),
)
gc.collect()

In [None]:
# Coarse + fine NeRF model with 64 coarse and 128 fine samples.
# NOTE(review): position_dim / direction_dim presumably set the positional-encoding
# sizes for points and view directions — confirm in nerf_model.NeRFNetwork.
network = nerf_model.NeRFNetwork(
    position_dim=10,
    direction_dim=4,
    coarse_samples=64,
    fine_samples=128,
)

In [None]:
# One forward/backward pass on the loaded batch to inspect the loss and gradients.
network.train()
# Clear stale gradients first: `.backward()` accumulates into `.grad`, so re-running
# this cell would otherwise inflate the gradient norms printed by the cells below.
network.zero_grad()
pred_dict = network.forward(o_rays, d_rays)
pred_rgbs = pred_dict['pred_rgbs']
loss = F.mse_loss(pred_rgbs, rgb)
loss.backward()
print(loss)

In [None]:
# Gradient norm of every coarse-network parameter after the backward pass.
# Parameters that received no gradient (grad is None) are reported instead of raising.
for name, param in network.coarse_network.named_parameters():
    grad_norm = None if param.grad is None else torch.linalg.norm(param.grad)
    print(name, grad_norm)

In [None]:
# Gradient norm of every fine-network parameter after the backward pass.
# Parameters that received no gradient (grad is None) are reported instead of raising.
for name, param in network.fine_network.named_parameters():
    grad_norm = None if param.grad is None else torch.linalg.norm(param.grad)
    print(name, grad_norm)

In [None]:
# Number of predicted density values that are exactly zero.
num_zero_density = (pred_dict['all_density'] == 0.0).sum()
print(num_zero_density)

In [None]:
# Shape of the predicted density tensor.
pred_dict['all_density'].size()

In [None]:
# Number of predicted density values that are negative.
pred_dict['all_density'].lt(0).sum()