In [1]:
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import commentjson as json
import tinycudann as tcnn

import matplotlib.pyplot as plt

In [2]:
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
from load_LINEMOD import load_LINEMOD_data
from utils import *

In [3]:
device = torch.device("cuda")
np.random.seed(0)
torch.set_default_tensor_type('torch.cuda.FloatTensor')

## Load Data

In [4]:
parser = config_parser()
args = parser.parse_args('--config C:/Users/chuzh/Study/CIS565/final/Neural-Radiance-Fields-with-Refractions/code/configs/lego.txt')

In [5]:
with open(args.nnconfig) as config_file:
    nnconfig = json.load(config_file)

In [6]:
# Load data
K = None
if args.dataset_type == 'llff':
    images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]
    print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
    if not isinstance(i_test, list):
        i_test = [i_test]

    if args.llffhold > 0:
        print('Auto LLFF holdout,', args.llffhold)
        i_test = np.arange(images.shape[0])[::args.llffhold]

    i_val = i_test
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                        (i not in i_test and i not in i_val)])

    print('DEFINING BOUNDS')
    if args.no_ndc:
        near = np.ndarray.min(bds) * .9
        far = np.ndarray.max(bds) * 1.
            
    else:
        near = 0.
        far = 1.
    print('NEAR FAR', near, far)

elif args.dataset_type == 'blender':
    images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
    print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
    i_train, i_val, i_test = i_split

    near = 2.
    far = 6.

    if args.white_bkgd:
        images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    else:
        images = images[...,:3]

elif args.dataset_type == 'LINEMOD':
    images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
    print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
    print(f'[CHECK HERE] near: {near}, far: {far}.')
    i_train, i_val, i_test = i_split

    if args.white_bkgd:
        images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    else:
        images = images[...,:3]

elif args.dataset_type == 'deepvoxels':

    images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

    print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
    i_train, i_val, i_test = i_split

    hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
    near = hemi_R-1.
    far = hemi_R+1.

else:
    print('Unknown dataset type', args.dataset_type, 'exiting')

# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]

if K is None:
    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])

if args.render_test:
    render_poses = np.array(poses[i_test])

# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
    for arg in sorted(vars(args)):
        attr = getattr(args, arg)
        file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
    f = os.path.join(basedir, expname, 'config.txt')
    with open(f, 'w') as file:
        file.write(open(args.config, 'r').read())

Loaded blender (138, 800, 800, 4) torch.Size([40, 4, 4]) [800, 800, 1111.1110311937682] ./data/nerf_synthetic/lego


## Create DataSet

In [7]:
N_rand = args.N_rand

In [8]:
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)

print('done')
i_batch = 0

images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
rays_rgb = torch.Tensor(rays_rgb).to(device)

get rays
done, concats
shuffle rays
done


## Define Model

In [9]:
class instant_NeRF(nn.Module):
    def __init__(self, pos_in = 3, dir_in = 3, pos_out = 16, nnconfig = None):
        super(instant_NeRF, self).__init__()
        self.pos_in, self.dir_in = pos_in, dir_in
        self.pos_encoding = tcnn.Encoding(pos_in , nnconfig["encoding"])
        self.density = tcnn.Network(n_input_dims=self.pos_encoding.n_output_dims, n_output_dims=16, network_config=nnconfig["network"])
        self.dir_encoding = tcnn.Encoding(dir_in, nnconfig["dir_encoding"])
        self.rgb = tcnn.Network(n_input_dims=self.dir_encoding.n_output_dims + pos_out, n_output_dims=3, network_config=nnconfig["rgb_network"])

    
    def forward(self, x):
        pos, view = torch.split(x, [self.pos_in, self.dir_in], dim=-1)
        encoded_pos = self.pos_encoding(pos)
        encoded_dir = self.dir_encoding(view)
        density = self.density(encoded_pos)
        rgb = self.rgb(torch.cat([density, encoded_dir], -1))
        
        return torch.cat([rgb,density[...,:1]], -1)

## Create Model

In [10]:
in_pos = 3
in_view = 3
model = instant_NeRF(pos_in = in_pos, dir_in = in_view, nnconfig = nnconfig)
grad_vars = list(model.parameters())

optimizer = torch.optim.Adam([
{'params':[grad_vars[0]],'weight_decay': 0},
{'params':grad_vars[1:],'weight_decay': 1e-6}
], lr=1e-2, betas=(0.9, 0.99), eps = 1e-15)

start = 0
basedir = args.basedir
expname = args.expname

render_kwargs_train = {
        'perturb' : args.perturb,
        'N_samples' : args.N_samples,
        'network' : model,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

if args.dataset_type != 'llff' or args.no_ndc:
    print('Not ndc!')
    render_kwargs_train['ndc'] = False
    render_kwargs_train['lindisp'] = args.lindisp
    
render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.

global_step = start
bds_dict = {
    'near' : near,
    'far' : far,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)

render_poses = torch.Tensor(render_poses).to(device)

Not ndc!


## Render Image

In [11]:
def render(H, W, K, rays = None, ndc = True, near = 0., far = 1.,
           N_samples = 0,
           network = None,
           retraw=False,
           lindisp=False,
           perturb=0.,
           white_bkgd=False,
           raw_noise_std=0.,
           verbose=False):

    rays_o, rays_d = rays
        
    viewdirs = rays_d
    viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
    viewdirs = torch.reshape(viewdirs, [-1,3]).float()
    
    shape = rays_d.shape
    if ndc:
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
        
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()
    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    
    N_rays = rays_o.shape[0]
    
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals[:-1]) + 1./far * (t_vals[:-1]))
    z_vals = z_vals.expand([N_rays, N_samples])
    
    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)
        
        z_vals = lower + (upper - lower) * t_rand
        
    pos = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]
    pos_flat = torch.reshape(pos, [-1, pos.shape[-1]])
    
    dirs = viewdirs[:,None].expand(pos.shape)
    dirs_flat = torch.reshape(dirs, [-1, dirs.shape[-1]])
    
    inputs = torch.cat([pos_flat, dirs_flat], -1)
     
    rgbs = network(inputs)

    rgbs = torch.reshape(rgbs, list(pos.shape[:-1]) + [rgbs.shape[-1]])
    
    rgbs2alpha = lambda rgbs, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(rgbs)*dists)
    
    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)
    
    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
    rgb = torch.sigmoid(rgbs[...,:3])
    
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(rgbs[...,3].shape) * raw_noise_std
    
    alpha = rgbs2alpha(rgbs[...,3] + noise, dists)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)
    
    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    others = {}
    
    return rgb_map, disp_map, acc_map, others

In [12]:
def render_path(render_poses, hwf, K, render_kwargs, chunk = 1024, gt_imgs=None, savedir=None, render_factor=0):

    H, W, focal = hwf

    if render_factor!=0:
        # Render downsampled for speed
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor

    rgbs = []
    disps = []
    
    total_time = 0

    t = time.time()
    for i, c2w in enumerate(render_poses):
        t = time.time()
        rays_o, rays_d = get_rays(H, W, K, c2w[:3,:4])
        rays_o = rays_o.reshape(H * W, 3)
        rays_d = rays_d.reshape(H * W, 3)
        rgb = torch.zeros(H * W, 3)
        disp = torch.zeros(H * W)
        acc = torch.zeros(H * W)
        for j in range(0, H * W, chunk):
            rgb[j:j+chunk], disp[j:j+chunk], acc[j:j+chunk], _ = render(H, W, K, (rays_o[j:j+chunk], rays_d[j:j+chunk]), **render_kwargs)
        disp = disp / disp.max()
        rgbs.append(rgb.reshape(H, W, 3).cpu().numpy())
        disps.append(disp.reshape(H, W).cpu().numpy())
        
        total_time += time.time() - t

        """
        if gt_imgs is not None and render_factor==0:
            p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
            print(p)
        """

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)


    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)
    
    print('average render time:', total_time / len(render_poses))

    return rgbs, disps


In [13]:
inputs.shape[:-1]

NameError: name 'inputs' is not defined

## Train model

In [14]:
N_iters = 200000 + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)

Begin
TRAIN views are [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
TEST views are [113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
 131 132 133 134 135 136 137]
VAL views are [100 101 102 103 104 105 106 107 108 109 110 111 112]


In [15]:
start = start + 1
for i in trange(start, N_iters):
    time0 = time.time()
    batch = rays_rgb[i_batch:i_batch+N_rand]
    batch = torch.transpose(batch, 0, 1)
    batch_rays, target_s = batch[:2], batch[2]
    
    i_batch += N_rand
    if i_batch >= rays_rgb.shape[0]:
        #print("Shuffle data after an epoch!")
        rand_idx = torch.randperm(rays_rgb.shape[0])
        rays_rgb = rays_rgb[rand_idx]
        i_batch = 0
    rgb, disp, acc, extras = render(H, W, K, rays=batch_rays,
                                    verbose=i < 10, retraw=True,
                                    **render_kwargs_train)
    optimizer.zero_grad()
    img_loss = img2mse(rgb, target_s)
    loss = img_loss
    psnr = mse2psnr(img_loss)
    loss.backward()
    optimizer.step()
        
    if (i >= 20000) and (i % 10000 == 0):
        new_lrate = args.lrate * 0.33
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        
    if i % args.i_print == 0:
        tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            
    if (i % args.i_video == 0) and (i > 0):
        # Turn on testing mode
        with torch.no_grad():
            rgbs, disps = render_path(render_poses, hwf, K, render_kwargs_test)
        print('Done, saving', rgbs.shape, disps.shape)
        moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
        imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

    if i%args.i_testset==0 and i > 0:
        testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', poses[i_test].shape)
        with torch.no_grad():
            render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
        print('Saved test set')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[TRAIN] Iter: 10000 Loss: 0.018529970198869705  PSNR: 17.321252822875977
test poses shape torch.Size([25, 4, 4])


  5%|████████▏                                                                                                                                                           | 10009/200000 [02:45<60:30:09,  1.15s/it]

average render time: 1.412659034729004
Saved test set


 10%|████████████████▋                                                                                                                                                      | 19996/200000 [04:43<36:18, 82.64it/s]

[TRAIN] Iter: 20000 Loss: 0.016631726175546646  PSNR: 17.790626525878906
test poses shape torch.Size([25, 4, 4])


 10%|████████████████▍                                                                                                                                                   | 20009/200000 [05:36<69:53:13,  1.40s/it]

average render time: 1.44331711769104
Saved test set


 15%|█████████████████████████                                                                                                                                              | 29996/200000 [07:36<32:57, 85.96it/s]

[TRAIN] Iter: 30000 Loss: 0.011511610820889473  PSNR: 19.388639450073242
test poses shape torch.Size([25, 4, 4])


 15%|████████████████████████▌                                                                                                                                           | 30018/200000 [08:26<42:36:12,  1.11it/s]

average render time: 1.4063135528564452
Saved test set


 20%|█████████████████████████████████▍                                                                                                                                     | 39997/200000 [10:24<32:40, 81.61it/s]

[TRAIN] Iter: 40000 Loss: 0.012791832908987999  PSNR: 18.93067169189453
test poses shape torch.Size([25, 4, 4])


 20%|████████████████████████████████▊                                                                                                                                   | 40018/200000 [11:13<39:59:49,  1.11it/s]

average render time: 1.4032281684875487
Saved test set


 25%|█████████████████████████████████████████▋                                                                                                                             | 49997/200000 [13:11<29:31, 84.67it/s]

[TRAIN] Iter: 50000 Loss: 0.009811382740736008  PSNR: 20.082698822021484


 25%|█████████████████████████████████████████▋                                                                                                                             | 49997/200000 [13:30<29:31, 84.67it/s]

average render time: 1.4517602801322937
Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])


 25%|████████████████████████████████████████▊                                                                                                                          | 50008/200000 [15:02<130:08:36,  3.12s/it]

average render time: 1.461791000366211
Saved test set


 30%|██████████████████████████████████████████████████                                                                                                                     | 59994/200000 [17:02<27:42, 84.21it/s]

[TRAIN] Iter: 60000 Loss: 0.010239677503705025  PSNR: 19.897138595581055
test poses shape torch.Size([25, 4, 4])


 30%|█████████████████████████████████████████████████▏                                                                                                                  | 60009/200000 [17:52<49:11:35,  1.27s/it]

average render time: 1.4642861366271973
Saved test set


 35%|██████████████████████████████████████████████████████████▍                                                                                                            | 69994/200000 [20:07<25:59, 83.34it/s]

[TRAIN] Iter: 70000 Loss: 0.010131310671567917  PSNR: 19.943344116210938
test poses shape torch.Size([25, 4, 4])


 35%|█████████████████████████████████████████████████████████▍                                                                                                          | 70008/200000 [20:57<47:58:09,  1.33s/it]

average render time: 1.4624547004699706
Saved test set


 40%|██████████████████████████████████████████████████████████████████▊                                                                                                    | 79992/200000 [23:08<24:22, 82.04it/s]

[TRAIN] Iter: 80000 Loss: 0.011115392670035362  PSNR: 19.54075050354004
test poses shape torch.Size([25, 4, 4])


 40%|█████████████████████████████████████████████████████████████████▌                                                                                                  | 80008/200000 [24:02<46:31:57,  1.40s/it]

average render time: 1.6480175018310548
Saved test set


 45%|███████████████████████████████████████████████████████████████████████████▏                                                                                           | 89991/200000 [26:23<21:24, 85.65it/s]

[TRAIN] Iter: 90000 Loss: 0.010407477617263794  PSNR: 19.82654571533203
test poses shape torch.Size([25, 4, 4])


 45%|█████████████████████████████████████████████████████████████████████████▊                                                                                          | 90018/200000 [27:13<25:31:15,  1.20it/s]

average render time: 1.4972771644592284
Saved test set


 50%|███████████████████████████████████████████████████████████████████████████████████▍                                                                                   | 99998/200000 [29:21<25:06, 66.39it/s]

[TRAIN] Iter: 100000 Loss: 0.009664545767009258  PSNR: 20.14818572998047


 50%|███████████████████████████████████████████████████████████████████████████████████▍                                                                                   | 99998/200000 [29:31<25:06, 66.39it/s]

average render time: 1.5959432184696198
Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])


 50%|█████████████████████████████████████████████████████████████████████████████████                                                                                 | 100008/200000 [31:20<106:44:38,  3.84s/it]

average render time: 1.6108227920532228
Saved test set


 55%|██████████████████████████████████████████████████████████████████████████████████████████▋                                                                           | 109199/200000 [33:38<27:58, 54.09it/s]


KeyboardInterrupt: 

In [None]:
with torch.no_grad():
    rgbs, disps = render_path(render_poses, (1024,1024,hwf[2]), K, render_kwargs_test)
    print('Done, saving', rgbs.shape, disps.shape)
    moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
    imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
    imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)