In [1]:
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

In [2]:
from load_llff import *
from utils import *

In [3]:
device = torch.device("cuda")
# Seed every RNG the notebook draws from: numpy shuffles the precomputed ray
# buffer, while torch drives stratified sampling (torch.rand), density noise
# (torch.randn) and the per-epoch reshuffle (torch.randperm) during training.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
torch.set_default_tensor_type('torch.cuda.FloatTensor')

## Load Data

In [4]:
# Experiment config. Set the NERF_CONFIG environment variable to point at a
# different config instead of editing this cell; the default is the original path.
CONFIG_PATH = os.environ.get(
    'NERF_CONFIG',
    'C:/Users/chuzh/Study/CIS565/final/Neural-Radiance-Fields-with-Refractions/code/configs/fern.txt')
parser = config_parser()
# Pass an explicit argv list so a path containing spaces is not split apart.
args = parser.parse_args(['--config', CONFIG_PATH])

In [5]:
# Load data
K = None
if args.dataset_type == 'llff':
    images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]
    print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
    if not isinstance(i_test, list):
        i_test = [i_test]

    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]

    i_val = i_test
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                        (i not in i_test and i not in i_val)])

    print('DEFINING BOUNDS')
    if args.no_ndc:
        near = np.ndarray.min(bds) * .9
        far = np.ndarray.max(bds) * 1.
            
    else:
        near = 0.
        far = 1.
    print('NEAR FAR', near, far)

elif args.dataset_type == 'blender':
    images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
    print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
    i_train, i_val, i_test = i_split

    near = 2.
    far = 6.

    if args.white_bkgd:
        images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    else:
        images = images[...,:3]

elif args.dataset_type == 'LINEMOD':
    images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
    print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
    print(f'[CHECK HERE] near: {near}, far: {far}.')
    i_train, i_val, i_test = i_split

    if args.white_bkgd:
        images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    else:
        images = images[...,:3]

elif args.dataset_type == 'deepvoxels':

    images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

    print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
    i_train, i_val, i_test = i_split

    hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
    near = hemi_R-1.
    far = hemi_R+1.

else:
    print('Unknown dataset type', args.dataset_type, 'exiting')

# Cast intrinsics to right types: integer image size, focal kept as loaded.
H, W, focal = hwf
H = int(H)
W = int(W)
hwf = [H, W, focal]

# Fall back to a pinhole camera with the principal point at the image centre
# when the loader did not provide K.
if K is None:
    K = np.array([[focal, 0, 0.5*W],
                  [0, focal, 0.5*H],
                  [0, 0, 1]])

# Optionally render the held-out test poses instead of the generated path.
if args.render_test:
    render_poses = np.array(poses[i_test])

# Create the experiment log dir, then record the resolved arguments and a copy
# of the raw config file next to the outputs.
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
    for arg in sorted(vars(args)):
        attr = getattr(args, arg)
        file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
    f = os.path.join(basedir, expname, 'config.txt')
    # Open the source config in a context manager too, so its handle is closed.
    with open(args.config, 'r') as src, open(f, 'w') as file:
        file.write(src.read())

Loaded image data (378, 504, 3, 20) [378.         504.         407.56579161]
Loaded ./data/nerf_llff_data/fern 16.985296178676084 80.00209740336334
recentered (3, 5)
[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00  1.4901161e-09]
 [ 0.0000000e+00  1.0000000e+00 -1.8730975e-09 -9.6857544e-09]
 [-0.0000000e+00  1.8730975e-09  1.0000000e+00  0.0000000e+00]]
Data:
(20, 3, 5) (20, 378, 504, 3) (20, 2)
HOLDOUT view is 12
Loaded llff (20, 378, 504, 3) (120, 3, 5) [378.     504.     407.5658] ./data/nerf_llff_data/fern
Auto LLFF holdout, 8
DEFINING BOUNDS
NEAR FAR 0.0 1.0


## Create Dataset

In [6]:
# Number of rays drawn from the shuffled ray buffer per training step.
N_rand = args.N_rand

In [7]:
# Precompute one ray per pixel for every view, pair it with its ground-truth
# colour, keep the training views only, and shuffle so the training loop can
# take contiguous random batches.
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb = rays_rgb[np.asarray(i_train)] # train images only: [N_train, H, W, ro+rd+rgb, 3]
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N_train*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)

print('done')
i_batch = 0  # read cursor into rays_rgb; the training loop advances it by N_rand

images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
rays_rgb = torch.Tensor(rays_rgb).to(device)

get rays
done, concats
shuffle rays
done


## Define Model

In [8]:
class NeRF(nn.Module):
    """NeRF MLP whose query positions are first shifted by a learned offset.

    ``forward`` takes ``x`` of shape ``[..., in_pos + in_view]`` (position
    followed by view direction). It returns ``[..., 4]``: sigmoid RGB in the
    first three channels and an un-activated density in the last.

    The offset branch sees the full input and predicts a 3-vector, which is
    added to the position before the radiance field is queried. This is
    presumably meant to model ray bending/refraction (TODO confirm). Because of
    that addition, ``in_pos`` must be 3.
    """

    def __init__(self, in_pos = 3, in_view = 3, hidden = 256):
        super(NeRF, self).__init__()
        self.in_pos = in_pos
        self.in_view = in_view
        # Radiance-field trunk. l5 re-injects the raw position (skip connection).
        # Layers are created in the original order, so parameter initialisation
        # and parameter ordering are unchanged.
        self.l1 = nn.Linear(in_pos, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, hidden)
        self.l4 = nn.Linear(hidden, hidden)
        self.l5 = nn.Linear(in_pos + hidden, hidden)
        self.l6 = nn.Linear(hidden, hidden)
        self.l7 = nn.Linear(hidden, hidden)
        self.l8 = nn.Linear(hidden, hidden)
        self.l9 = nn.Linear(hidden, hidden)  # feature layer, used without activation
        # Colour head: trunk feature + view direction -> hidden//2 -> RGB.
        self.l0 = nn.Linear(hidden + in_view, hidden//2)
        self.rgb = nn.Linear(hidden//2, 3)
        # Density head, read from the trunk before the view direction enters.
        self.sigma = nn.Linear(hidden, 1)
        self.r = nn.ReLU()
        self.s = nn.Sigmoid()

        # Offset branch over the full [pos, view] input; o4 re-injects that input.
        self.o1 = nn.Linear(in_pos + in_view, hidden)
        self.o2 = nn.Linear(hidden, hidden)
        self.o3 = nn.Linear(hidden, hidden)
        self.o4 = nn.Linear(in_pos + in_view + hidden, hidden)
        self.o5 = nn.Linear(hidden, hidden)
        self.o6 = nn.Linear(hidden, hidden//2)
        self.out = nn.Linear(hidden//2, 3)

    def nerf(self, pos, view):
        """Query the radiance field; returns ``cat([rgb, sigma], -1)``.

        ``rgb`` is sigmoid-activated. ``sigma`` is raw, and the renderer applies
        its own activation.
        """
        out = self.r(self.l1(pos))
        out = self.r(self.l2(out))
        out = self.r(self.l3(out))
        out = self.r(self.l4(out))
        out = self.r(self.l5(torch.cat([pos, out], -1)))
        out = self.r(self.l6(out))
        out = self.r(self.l7(out))
        out = self.r(self.l8(out))
        sigma = self.sigma(out)
        out = self.l9(out)
        out = self.r(self.l0(torch.cat([view, out], -1)))
        rgb = self.s(self.rgb(out))
        return torch.cat([rgb, sigma], -1)

    def offset(self, x):
        """Predict a 3-D position offset from the full ``[pos, view]`` input ``x``."""
        out = self.r(self.o1(x))
        out = self.r(self.o2(out))
        out = self.r(self.o3(out))
        out = self.r(self.o4(torch.cat([x, out], -1)))
        out = self.r(self.o5(out))
        out = self.r(self.o6(out))
        return self.out(out)

    def forward(self, x):
        """Split ``x`` into position/view, shift the position, then query the field."""
        pos, view = torch.split(x, [self.in_pos, self.in_view], dim=-1)
        offset = self.offset(x)
        return self.nerf(pos + offset, view)

## Render Image

In [9]:
def render(H, W, K, chunk = 1024*32, rays = None, c2w = None, ndc = True, near = 0., far = 1., **kwargs):
    """Render either a whole image (from ``c2w``) or a batch of rays.

    Args:
        H, W, K: image size and 3x3 intrinsics.
        chunk: rays per call of ``render_rays``.
        rays: ``(rays_o, rays_d)`` pair, used when ``c2w`` is None.
        c2w: camera-to-world matrix; when given, one ray per pixel is generated.
        ndc: warp rays with ``ndc_rays`` before sampling.
        near, far: scalar sampling bounds applied to every ray.
        **kwargs: forwarded to ``render_rays``.

    Returns:
        ``[rgb_map, disp_map, acc_map, extras]``. Each output is reshaped to the
        ray grid's leading shape; ``extras`` holds any remaining outputs.
    """
    if c2w is None:
        rays_o, rays_d = rays
    else:
        rays_o, rays_d = get_rays(H, W, K, c2w)

    # Unit view directions are taken from the rays *before* any NDC warp.
    unit_dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    unit_dirs = torch.reshape(unit_dirs, [-1, 3]).float()

    grid_shape = rays_d.shape
    if ndc:
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    # Per-ray layout: [origin(3), direction(3), near(1), far(1), viewdir(3)].
    ones = torch.ones_like(rays_d[..., :1])
    packed = torch.cat([rays_o, rays_d, near * ones, far * ones, unit_dirs], -1)

    outputs = batchify_rays(packed, chunk, **kwargs)
    for key, value in outputs.items():
        outputs[key] = torch.reshape(value, list(grid_shape[:-1]) + list(value.shape[1:]))

    main_keys = ['rgb_map', 'disp_map', 'acc_map']
    extras = {key: value for key, value in outputs.items() if key not in main_keys}
    return [outputs[key] for key in main_keys] + [extras]
    

In [10]:
def get_rays(H, W, K, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Returns ``(rays_o, rays_d)``, both ``[H, W, 3]``. ``rays_o`` is the camera
    translation ``c2w[:3, -1]`` broadcast to every pixel. ``rays_d`` is the
    un-normalised camera-frame direction ``((x-cx)/fx, -(y-cy)/fy, -1)``,
    rotated by ``c2w[:3, :3]``.
    """
    xs = torch.linspace(0, W-1, W)
    ys = torch.linspace(0, H-1, H)
    # meshgrid uses 'ij' indexing here, so transpose to get [H, W] grids.
    cols, rows = torch.meshgrid(xs, ys)
    cols = cols.t()
    rows = rows.t()

    fx, fy = K[0][0], K[1][1]
    cx, cy = K[0][2], K[1][2]
    dirs = torch.stack([(cols - cx) / fx,
                        -(rows - cy) / fy,
                        -torch.ones_like(cols)], -1)

    # Per-pixel rotation R @ d, written as a broadcast multiply-and-sum.
    rotation = c2w[:3, :3]
    rays_d = torch.sum(dirs[..., np.newaxis, :] * rotation, -1)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d

In [11]:
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Run ``render_rays`` over ``rays_flat`` in slices of ``chunk`` rays.

    Rendering in slices avoids running out of memory. Per-key outputs are
    concatenated back along dim 0.
    """
    pieces = {}
    total = rays_flat.shape[0]
    for begin in range(0, total, chunk):
        partial = render_rays(rays_flat[begin:begin + chunk], **kwargs)
        for key, value in partial.items():
            pieces.setdefault(key, []).append(value)
    return {key: torch.cat(parts, 0) for key, parts in pieces.items()}

In [12]:
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Sample points along each ray, query the network and composite the results.

    ``ray_batch`` is ``[N_rays, 11]``: origin(3), direction(3), near(1),
    far(1), view direction(3). Returns a dict with ``rgb_map``, ``disp_map`` and
    ``acc_map`` (plus ``raw`` when ``retraw``). A message is printed if any
    returned tensor contains NaN/Inf. ``verbose`` is accepted but unused.
    """
    n_rays = ray_batch.shape[0]
    rays_o = ray_batch[:, 0:3]
    rays_d = ray_batch[:, 3:6]
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near = bounds[..., 0]  # [N_rays, 1]
    far = bounds[..., 1]   # [N_rays, 1]
    viewdirs = ray_batch[:, 8:11]

    # Evenly spaced samples, either linear in depth or linear in inverse depth.
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if lindisp:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    else:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    z_vals = z_vals.expand([n_rays, N_samples])

    if perturb > 0.:
        # Stratified sampling: jitter each sample uniformly within its bin.
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)

    points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    raw = network_query_fn(points, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, _weights, _depth = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw

    for name, tensor in ret.items():
        has_bad = torch.isnan(tensor).any() or torch.isinf(tensor).any()
        if has_bad:
            print(f"! [Numerical Error] {name} contains nan or inf.")

    return ret

In [13]:
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False):
    """Alpha-composite per-sample network outputs along each ray.

    Args:
        raw: ``[..., N_samples, 4]``. Channels 0-2 are colours, used as-is.
            Channel 3 is the raw density; ReLU is applied here.
        z_vals: ``[..., N_samples]`` sample depths along each ray.
        rays_d: ``[..., 3]`` ray directions (need not be unit length).
        raw_noise_std: std of Gaussian noise added to the density before ReLU.
        white_bkgd: composite the result onto a white background.

    Returns:
        ``(rgb_map, disp_map, acc_map, weights, depth_map)``.

    All constant tensors follow the inputs' device and dtype, not the global
    default tensor type. Any number of leading batch dims is supported.
    NOTE(review): if every weight of a ray is 0, ``disp_map`` is NaN (0/0).
    """
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

    # Spacing between consecutive samples; the last sample gets a ~infinite interval.
    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.full_like(dists[...,:1], 1e10)], -1)
    # Scale depth spacing by the direction length to get world-space distance.
    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = raw[...,:3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn_like(raw[...,3]) * raw_noise_std

    alpha = raw2alpha(raw[...,3] + noise, dists)
    # Transmittance: running product of (1 - alpha) over earlier samples.
    # The 1e-10 keeps the product from hitting exactly zero.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[...,:1]), 1.-alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * transmittance
    rgb_map = torch.sum(weights[...,None] * rgb, -2)

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map

## Batchify Network

In [14]:
def batchify(fn, chunk):
    """Wrap ``fn`` so it is applied to dim-0 slices of at most ``chunk`` rows.

    The slice outputs are concatenated along dim 0. With ``chunk=None``, ``fn``
    itself is returned unchanged.
    """
    if chunk is None:
        return fn

    def ret(inputs):
        collected = []
        for begin in range(0, inputs.shape[0], chunk):
            collected.append(fn(inputs[begin:begin + chunk]))
        return torch.cat(collected, 0)
    return ret


def run_network(inputs, viewdirs, fn, netchunk=1024*64):
    """Query ``fn`` at every sample point together with its ray's view direction.

    Args:
        inputs: ``[N_rays, N_samples, C]`` sample positions.
        viewdirs: ``[N_rays, C]`` per-ray directions, broadcast over samples.
        fn: network applied to rows of ``[position, direction]``.
        netchunk: max rows per call of ``fn`` (None = one call).

    Returns:
        ``[N_rays, N_samples, out]``, where ``out`` is ``fn``'s last dim.
    """
    n_channels = inputs.shape[-1]
    points_flat = torch.reshape(inputs, [-1, n_channels])
    dirs = viewdirs[:, None].expand(inputs.shape)
    dirs_flat = torch.reshape(dirs, [-1, dirs.shape[-1]])

    net_in = torch.cat([points_flat, dirs_flat], -1)
    net_out = batchify(fn, netchunk)(net_in)

    target_shape = list(inputs.shape[:-1]) + [net_out.shape[-1]]
    return torch.reshape(net_out, target_shape)

## Create Model

In [15]:
# Build the model and optimizer, plus the keyword bundles passed to `render`
# during training and evaluation.
in_pos, in_view = 3, 3
hidden = args.netwidth
model = NeRF(in_pos, in_view, hidden)
grad_vars = list(model.parameters())

def network_query_fn(inputs, viewdirs, network_fn):
    return run_network(inputs, viewdirs, network_fn, netchunk=args.netchunk)

optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

start = 0
basedir = args.basedir
expname = args.expname

render_kwargs_train = dict(
    network_query_fn=network_query_fn,
    perturb=args.perturb,
    N_samples=args.N_samples,
    network_fn=model,
    white_bkgd=args.white_bkgd,
    raw_noise_std=args.raw_noise_std,
)

# Only LLFF scenes are rendered in NDC; everything else samples in world space.
if args.dataset_type != 'llff' or args.no_ndc:
    print('Not ndc!')
    render_kwargs_train.update(ndc=False, lindisp=args.lindisp)

# Evaluation: same settings, but deterministic (no jitter, no density noise).
render_kwargs_test = dict(render_kwargs_train)
render_kwargs_test.update(perturb=False, raw_noise_std=0.)

global_step = start
bds_dict = {'near': near, 'far': far}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)

render_poses = torch.Tensor(render_poses).to(device)

## Train Model

In [16]:
# Exclusive upper bound for the training loop's iteration index.
N_iters = 100_000 + 1
print('Begin')
for split_name, split_views in (('TRAIN', i_train), ('TEST', i_test), ('VAL', i_val)):
    print(split_name + ' views are', split_views)

Begin
TRAIN views are [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19]
TEST views are [ 0  8 16]
VAL views are [ 0  8 16]


In [17]:
# Main optimisation loop. Each step takes one batch of N_rand shuffled rays,
# computes MSE against the ground-truth colours and applies exponential lr
# decay; video and test-set renders are produced periodically.
# NOTE(review): `render_path` is defined in a later cell. On a fresh kernel,
# run that cell first or the first i_video / i_testset step raises NameError.
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
start = start + 1
for i in trange(start, N_iters):
    batch = rays_rgb[i_batch:i_batch+N_rand]
    batch = torch.transpose(batch, 0, 1)  # [ro+rd+rgb, B, 3]
    batch_rays, target_s = batch[:2], batch[2]

    i_batch += N_rand
    if i_batch >= rays_rgb.shape[0]:
        print("Shuffle data after an epoch!")
        rand_idx = torch.randperm(rays_rgb.shape[0])
        rays_rgb = rays_rgb[rand_idx]
        i_batch = 0
    rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                    verbose=i < 10, retraw=True,
                                    **render_kwargs_train)
    optimizer.zero_grad()
    img_loss = img2mse(rgb, target_s)
    loss = img_loss
    psnr = mse2psnr(img_loss)
    loss.backward()
    optimizer.step()

    # lr is multiplied by decay_rate every decay_steps steps. global_step must
    # advance each iteration; previously it stayed at 0, so the lr never decayed.
    new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate
    global_step += 1

    if i%args.i_print==0:
        tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

    if i%args.i_video==0 and i > 0:
        # Turn on testing mode
        with torch.no_grad():
            rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
        print('Done, saving', rgbs.shape, disps.shape)
        moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
        imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

    if i%args.i_testset==0 and i > 0:
        testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', poses[i_test].shape)
        with torch.no_grad():
            # poses is already a tensor on `device`; index it directly.
            render_path(poses[i_test], hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
        print('Saved test set')
        print('Saved test set')

  0%|▏                                                                                                                                                                      | 101/100000 [00:10<2:17:36, 12.10it/s]

[TRAIN] Iter: 100 Loss: 0.03330076485872269  PSNR: 14.775458335876465


  0%|▎                                                                                                                                                                      | 201/100000 [00:20<2:34:06, 10.79it/s]

[TRAIN] Iter: 200 Loss: 0.027960598468780518  PSNR: 15.534534454345703


  0%|▌                                                                                                                                                                      | 302/100000 [00:31<2:19:41, 11.89it/s]

[TRAIN] Iter: 300 Loss: 0.02517564967274666  PSNR: 15.990192413330078


  0%|▋                                                                                                                                                                      | 401/100000 [00:41<2:23:24, 11.58it/s]

[TRAIN] Iter: 400 Loss: 0.02228948101401329  PSNR: 16.518999099731445


  1%|▊                                                                                                                                                                      | 501/100000 [00:51<2:17:59, 12.02it/s]

[TRAIN] Iter: 500 Loss: 0.025650298222899437  PSNR: 15.909074783325195


  1%|█                                                                                                                                                                      | 601/100000 [01:01<2:18:27, 11.96it/s]

[TRAIN] Iter: 600 Loss: 0.022476783022284508  PSNR: 16.48265838623047


  1%|█▏                                                                                                                                                                     | 701/100000 [01:10<2:41:55, 10.22it/s]

[TRAIN] Iter: 700 Loss: 0.02206098660826683  PSNR: 16.563751220703125


  1%|█▎                                                                                                                                                                     | 800/100000 [01:21<2:41:47, 10.22it/s]

[TRAIN] Iter: 800 Loss: 0.019963767379522324  PSNR: 16.997573852539062


  1%|█▍                                                                                                                                                                     | 833/100000 [01:24<2:47:56,  9.84it/s]


KeyboardInterrupt: 

In [None]:
def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    """Render one full image per pose, optionally saving PNGs.

    Args:
        render_poses: iterable of camera-to-world matrices (the first 3x4 is used).
        hwf: ``[H, W, focal]``.
        K: 3x3 intrinsics. Assumed to be convertible with ``np.array``
            (a numpy array or CPU tensor).
        chunk: rays per ``render`` batch.
        render_kwargs: forwarded to ``render``.
        gt_imgs: optional ground-truth images indexed like ``render_poses``
            (numpy array or tensor). When given and ``render_factor == 0``, the
            per-view PSNR is printed.
        savedir: if set, each frame is written there as ``{index:03d}.png``.
        render_factor: if non-zero, render at ``1/render_factor`` resolution.

    Returns:
        ``(rgbs, disps)`` stacked as numpy arrays.
    """
    H, W, focal = hwf

    if render_factor!=0:
        # Render downsampled for speed. Rays are generated from K, so K's focal
        # lengths and principal point must shrink with the image as well.
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor
        K = np.array(K, dtype=np.float64)  # copy; never mutate the caller's K
        K[:2, :] = K[:2, :] / render_factor

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i==0:
            print(rgb.shape, disp.shape)

        if gt_imgs is not None and render_factor==0:
            gt = gt_imgs[i]
            if torch.is_tensor(gt):
                gt = gt.cpu().numpy()
            p = -10. * np.log10(np.mean(np.square(rgbs[-1] - gt)))
            print(p)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps


In [None]:
# Render the spiral path with the current model and save RGB and normalised
# disparity videos, tagged with the last training iteration `i`.
with torch.no_grad():
    rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
    print('Done, saving', rgbs.shape, disps.shape)
    moviebase = os.path.join(basedir, expname, f'{expname}_spiral_{i:06d}_')
    imageio.mimwrite(f'{moviebase}rgb.mp4', to8b(rgbs), fps=30, quality=8)
    imageio.mimwrite(f'{moviebase}disp.mp4', to8b(disps / disps.max()), fps=30, quality=8)

In [None]:
# Render the held-out test views and write them as PNGs under testset_<iter>/.
testsavedir = os.path.join(basedir, expname, f'testset_{i:06d}')
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape)
with torch.no_grad():
    render_path(torch.Tensor(poses[i_test]).to(device),
                hwf, K, args.chunk, render_kwargs_test,
                gt_imgs=images[i_test],
                savedir=testsavedir)

## Scratch: Step-by-Step Checks of the Rendering Pipeline

In [None]:
# Scratch walkthrough of one training step: take the next N_rand shuffled rays
# starting at cursor i_batch, as the training loop does.
batch = rays_rgb[i_batch:i_batch+N_rand]

In [None]:
# [B, ro+rd+rgb, 3] -> [ro+rd+rgb, B, 3]
batch = torch.transpose(batch, 0, 1)

In [None]:
# Expect [3, B, 3] after the transpose.
batch.shape

In [None]:
# Rows 0-1 are ray origins/directions; row 2 is the target RGB.
batch_rays, target_s = batch[:2], batch[2]

In [None]:
# NOTE: this advances the global training cursor, so the next training step
# skips the rays inspected here.
i_batch += N_rand

In [None]:
# Expect [2, B, 3] (origins, directions).
batch_rays.shape

In [None]:
# Mirror the arguments the training loop passes to `render`.
chunk, rays = args.chunk, batch_rays
verbose = retraw = True

In [None]:
# Split the [2, B, 3] ray tensor into origins and directions.
rays_o, rays_d = rays

In [None]:
# Keep the pre-flatten shape, like `shape` inside render.
sh = rays_d.shape

In [None]:
# Display the saved shape.
sh

In [None]:
# Warp into NDC with the same arguments `render` uses when ndc=True.
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

In [None]:
# Flatten to [-1, 3] float32, as `render` does.
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()

In [None]:
# Flattened to [-1, 3] by the previous cell.
rays_o.shape

In [None]:
# WARNING: this rebinds the global `near`/`far` floats to per-ray tensors.
# Re-run the Load Data cell before reusing them (e.g. in bds_dict).
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])

In [None]:
# Pack [rays_o, rays_d, near, far] per ray. `render` also appends viewdirs,
# which is omitted here.
rays = torch.cat([rays_o, rays_d, near, far], -1)

In [None]:
# Name used by the render_rays walkthrough below.
ray_batch = rays

In [None]:
# Unpack origins/directions the way render_rays does. The original line had a
# stray `rays[:,-3:].shape` fused onto it, which set a bogus tensor attribute.
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6]

In [None]:
# Sanity check: all zeros, since rays_d was unpacked from ray_batch[:,3:6].
ray_batch[:,3:6]-rays_d

In [None]:
# Per-ray [near, far] pair, as in render_rays.
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])

In [None]:
# Expect [N_rays, 1, 2] from the reshape above.
bounds.shape

In [None]:
# Config flag check only: this notebook's `render` always builds viewdirs,
# regardless of this value.
args.use_viewdirs

In [None]:
# NOTE: rays_d has already been NDC-warped and flattened above, whereas
# `render` derives viewdirs from rays_d *before* ndc_rays.
viewdirs = rays_d

In [None]:
# Normalise to unit length.
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)

In [None]:
# Flatten to [-1, 3] float32, as in render.
viewdirs = torch.reshape(viewdirs, [-1,3]).float()

In [None]:
# Display the unit view directions.
viewdirs

In [None]:
# Compare with the (NDC, un-normalised) directions they were derived from.
rays_d

In [None]:
# WARNING: rebinds the global `model` to a freshly initialised NeRF. The trained
# network is still referenced by render_kwargs_train['network_fn'].
model = NeRF()

In [None]:
# NeRF.forward takes a single [N, in_pos + in_view] tensor (position, then view
# direction), so concatenate instead of passing two arguments (TypeError before).
model(torch.cat([rays_o, viewdirs], -1))