# A Guided Tour of Neural Radiance Fields (NeRF)

## Environment Initialization

Common Imports

In [None]:
# Automatically reload edited modules (e.g. the local `nerf` package) before each cell runs,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import sys
import os
from tqdm import tqdm
from matplotlib import pyplot as plt

NeRF is computationally expensive, so be sure to run this notebook on a GPU.

In [None]:
# Choose the compute device once: CUDA when a GPU is visible, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
print("Good to go!" if _cuda_ok else "Bad to go!")
DEVICE = torch.device("cuda" if _cuda_ok else "cpu")

## Load Data

Load the scene configuration from the `configs` package.

In [None]:
# Make the repo-local `configs` package importable from the notebook's working directory.
sys.path.append(os.getcwd())
# Available scene configs: hotdog and lego.
import configs.hotdog
import configs.lego

# Ray sampling interval handed to the ray samplers (presumably the (near, far) t-bounds — confirm).
sample_t: tuple = (2, 6)

#### Modify: scale_factor, config ###
# Resolution of the loaded images is set via scale_factor:
#   scale_factor=3 -> 100x100
#   scale_factor=2 -> 200x200
scale_factor = 3
# Pick the scene here (configs.lego or configs.hotdog).
config = configs.lego

Load the training set and display the first image together with its camera pose.

In [None]:
from nerf.data import load_blender

# Training split, kept on the CPU; only the leading (count, height, width) dims are used here.
imgs, poses, int_mat = load_blender(config.datadir, device="cpu", scale_factor=scale_factor)
img_n, img_h, img_w = imgs.shape[:3]

# Show the first training image with axes hidden.
first_img = np.array(imgs[0].to(device="cpu"))
plt.imshow(first_img)
plt.axis("off")
plt.title("demo image")
plt.show()

# ...followed by the camera pose it was captured from.
print("and its pose: ")
first_pose = np.array(poses[0].to(device="cpu"))
print(first_pose)

## Test Functions

In [None]:
# Smoke test: build one ray per pixel for the first training camera.
from nerf.graphics import compute_rays

rays_o, rays_d = compute_rays((img_h, img_w), int_mat, poses[0])

# Origin of the top-left pixel's ray, raw and unit-normalized.
corner_origin = rays_o[0, 0]
print("origin: ", corner_origin)
print("normalized origin: ", F.normalize(corner_origin, dim=0))

# Direction of the ray through the pixel at the image centre.
mid_row, mid_col = img_h // 2, img_w // 2
print("center of ray: ", rays_d[mid_row, mid_col])

In [None]:
# Smoke test: sample query points along every ray inside the sample_t interval.
from nerf.graphics import queries_from_rays

points_per_ray = 8
samples, depths = queries_from_rays(rays_o, rays_d, sample_t, points_per_ray)
print("samples[0, 0]: ", samples[0, 0])

In [None]:
# Smoke test: positional encoding of a single 3-D point with L frequency bands.
# NOTE(review): keep this wildcard import — later cells appear to rely on names it
# brings in (e.g. the bare `Tensor` annotation); confirm before narrowing it.
from nerf.nerf_helper import *

L = 6
x = torch.tensor([[1.8013, -0.6242, 0.7009]])  # shape (1, 3): one point
enc_x = PosEncode(x, L, True)
print(enc_x)

In [None]:
# Smoke test for volume rendering: feed the first ground-truth image in as a fake
# network output with one 4-channel sample per pixel and a single depth value.
from nerf.graphics import render_from_nerf

fake_nerf_output = imgs[0].cpu().reshape(img_h, img_w, 1, 4)
fake_depth = torch.Tensor([1])
rgb, depth = render_from_nerf(fake_nerf_output, fake_depth)

plt.imshow(rgb)
plt.show()

## Training

In [None]:
from nerf.model import NeRF
import os.path

###### reproducibility
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

###### hyper-parameters
L_pos = 10          # encoding frequency bands for sample positions
L_dir = 4           # encoding frequency bands for view directions
num_samples = 32    # samples per ray
batch_size = 8192   # increase batchsize if u have large GPU MEM
fc_width = 128
fc_depth = 4
skips = [2]
lr = 5e-4
num_it = 10001
# NOTE(review): display_every is not passed to train_NeRF in the training cell;
# confirm whether it is still consumed anywhere.
display_every = 200

###### models
# Input widths are 3 raw coordinates + 6*L encoded features; they must match the
# size PosEncode produces for the same L (presumably sin/cos per band per axis).
model = NeRF(ch_in_pos=6*L_pos+3, ch_in_dir=6*L_dir+3, fc_width=fc_width, fc_depth=fc_depth, skips=skips)
model.to(DEVICE)

###### load validation data
imgs_val, poses_val, _ = load_blender(config.datadir, data_type="val", scale_factor=scale_factor, device="cpu")
num_val = imgs_val.shape[0]

###### training state (overwritten below when resuming from a checkpoint)
psnrs = []    # PSNR history handed to train_NeRF
val_its = []  # iterations paired with psnrs (stored under the checkpoint key 'its')
i = 0         # iteration train_NeRF starts from (start_iter_num)

###### optimizer, checkpoint
# Created after model.to(DEVICE) so it references the parameters on the final device.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr)
ckpt_path = 'nerf.pt'

###### check saved checkpoints
if os.path.exists(ckpt_path):
    print("checkpoint found! Loading...")
    # map_location remaps every stored tensor onto DEVICE, so a checkpoint saved on a
    # GPU can be resumed on a CPU-only machine (and vice versa) instead of failing
    # during deserialization.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    i = checkpoint['epoch']
    loss = checkpoint['loss']
    psnrs = checkpoint['psnrs']
    val_its = checkpoint['its']
    print("checkpoint loaded, i =",i)
else:
    print("No checkpoint found")

In [None]:
from nerf.training_logic import train_NeRF

# Train from iteration `i` (0, or the resumed checkpoint iteration) up to `num_it`.
# The existing psnrs / val_its lists are passed in, presumably so the history keeps
# growing across resumes — confirm against train_NeRF.
train_NeRF(
    model=model,
    optimizer=optimizer,
    imgs_train=imgs,
    imgs_val=imgs_val,
    poses_train=poses,
    poses_val=poses_val,
    int_mat=int_mat,
    sample_t=sample_t,
    L_pos=L_pos,
    L_dir=L_dir,
    num_samples=num_samples,
    ckpt_path=ckpt_path,
    batch_size=batch_size,
    psnrs=psnrs,
    val_its=val_its,
    start_iter_num=i,
    end_iter_num=num_it,
)

## Generate Video

In [None]:
# Optional guard: uncomment to stop a top-to-bottom run before video generation.
# raise Exception("nothing wrong")

# Select model checkpoint, e.g. './trained_models/basic_nerf_lego_9800.pt'.
ckpt_path = 'nerf.pt'
# Scene config. Re-assigning it here does NOT reload the dataset: image size and
# intrinsics still come from the earlier load; the video cell reads config.expname.
config = configs.lego


if os.path.exists(ckpt_path):
    print(ckpt_path, "found! Loading...")
    # map_location lets a GPU-trained checkpoint be restored on a CPU-only machine.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    i = checkpoint['epoch']
    loss = checkpoint['loss']
    psnrs = checkpoint['psnrs']
    # Store under `val_its` (the name the training cell uses) so it stays paired with
    # `psnrs`; `its` is kept as an alias for backward compatibility.
    val_its = its = checkpoint['its']
    print("checkpoint loaded, i =",i)
else:
    print("No checkpoint found")

In [None]:
from nerf.graphics import generate_demo_poses
from nerf.nerf_helper import nerf_iter_once
import imageio
import cv2

# Render novel views along generated demo poses and write them to '<expname>.mp4'.
model.eval()
gen_num: int = 120      # number of distinct rendered frames
repeat: int = 2         # how many times the frame sequence is looped in the video
frame_size = (112, 112)  # (width, height) passed to cv2.resize
# .to(poses) matches the dtype/device of the training poses, then move to DEVICE.
gen_poses: torch.Tensor = generate_demo_poses(height=4, num_poses=gen_num).to(poses).to(DEVICE)
# The intrinsics are identical for every frame, so transfer them once.
int_mat_dev = int_mat.to(DEVICE)
gen_imgs: list = []
# Loop variable is `pose_idx`, not `i`: the global `i` holds the resume iteration the
# training cell passes as start_iter_num, and reusing it would silently overwrite it.
for pose_idx in range(gen_num):
    with torch.no_grad():
        pred_rgb, pred_depth = nerf_iter_once(
            model,
            (img_h, img_w),
            int_mat_dev,
            gen_poses[pose_idx],
            sample_t,
            L_pos,
            L_dir,
            num_samples=num_samples,
            batch_size=batch_size,
        )
    # Append depth as a 4th channel after RGB.
    pred_rgbd: torch.Tensor = torch.cat([pred_rgb, pred_depth[..., None]], dim=-1)
    # Scale to [0, 255] and clip BEFORE the uint8 cast: casting out-of-range floats to
    # uint8 wraps/garbles values instead of saturating them.
    # NOTE(review): if pred_depth is in world units (cf. sample_t), the depth channel
    # saturates at 255; consider normalizing it by sample_t — TODO confirm intent.
    img_np = np.clip(np.array(pred_rgbd.detach().cpu() * 255), 0, 255).astype(np.uint8)
    img_np = cv2.resize(img_np, frame_size, interpolation=cv2.INTER_AREA)
    gen_imgs.append(img_np)

# Loop the clip by repeating the list of frames.
gen_imgs = gen_imgs * repeat

imageio.mimwrite('{}.mp4'.format(config.expname), gen_imgs, fps=30, quality=8)