In [1]:
import os
import time
import imageio
import torch
import torch.nn.functional as F
import numpy as np
import tqdm
from main import config_parser
from data_loader.load_llff import load_llff_data
from model import create_nerf, img2mse, mse2psnr
from render import get_rays_np, render

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = config_parser()
args = parser.parse_args(args='--config configs/ckc.txt')

In [3]:
K = None
device = torch.device("cpu")
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                          recenter=True, bd_factor=.75,
                                                          spherify=args.spherify)
hwf = poses[0,:3,-1]
poses = poses[:,:3,:4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
if not isinstance(i_test, list):
    i_test = [i_test]
if args.llffhold > 0:                   # take every 1/N images as LLFF test set,
    print('Auto LLFF holdout,', args.llffhold)
    i_test = np.arange(images.shape[0])[::args.llffhold]
i_val = i_test
i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                (i not in i_test and i not in i_val)])
print('DEFINING BOUNDS')
if args.no_ndc:
    near = np.ndarray.min(bds) * .9
    far = np.ndarray.max(bds) * 1.
    
else:
    near = 0.
    far = 1.
print('NEAR FAR', near, far)

Loaded image data (292, 135, 3, 95) [292.         135.         202.50087625]
Loaded ./data/nerf_llff_data/ckc 2.6607503791705396 105.38726428465726
Data:
(95, 3, 5) (95, 292, 135, 3) (95, 2)
HOLDOUT view is 41
Loaded llff (95, 292, 135, 3) (120, 3, 5) [292.      135.      202.50087] ./data/nerf_llff_data/ckc
Auto LLFF holdout, 8
DEFINING BOUNDS
NEAR FAR 0.1832474336028099 8.064537048339844


In [4]:
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
if K is None:
    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])
# Create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args, device)
global_step = start         # ckpt number 
bds_dict = {
    'near' : near,
    'far' : far,
}
render_kwargs_train.update(bds_dict)            # insert near and far values
render_kwargs_test.update(bds_dict)
# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)

Found ckpts []
Not ndc!


In [5]:
N_rand = args.N_rand
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0)   # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1)                    # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4])                          # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)                  # train images only
rays_rgb = np.reshape(rays_rgb, [-1,3,3])                               # [(N-1)*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)
print('done')
i_batch = 0

get rays
done, concats
shuffle rays
done


In [6]:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
rays_rgb = torch.Tensor(rays_rgb).to(device)

In [7]:
print(device)
for i in range(1, 25):
    print(i)
    time0 = time.time()
    # Sample random ray batch
    # Random over all images
    batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
    batch = torch.transpose(batch, 0, 1)
    batch_rays, target_s = batch[:2], batch[2]
    i_batch += N_rand
    if i_batch >= rays_rgb.shape[0]:
        print("Shuffle data after an epoch!")
        rand_idx = torch.randperm(rays_rgb.shape[0])
        rays_rgb = rays_rgb[rand_idx]
        i_batch = 0
            

    rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                            verbose=i < 10, retraw=True,
                                            **render_kwargs_train)
    optimizer.zero_grad()
    img_loss = img2mse(rgb, target_s)
    trans = extras['raw'][...,-1]
    loss = img_loss
    psnr = mse2psnr(img_loss)
    if 'rgb0' in extras:
        img_loss0 = img2mse(extras['rgb0'], target_s)
        loss = loss + img_loss0
        psnr0 = mse2psnr(img_loss0)
    loss.backward()
    optimizer.step()
    # NOTE: IMPORTANT!
    ###   update learning rate   ###
    decay_rate = 0.1
    decay_steps = args.lrate_decay * 1000
    new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate
    ################################

cuda
1


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!