## Tiny NeRF
This is a simplified version of the method presented in *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*.

[Project Website](http://www.matthewtancik.com/nerf)

[arXiv Paper](https://arxiv.org/abs/2003.08934)

[Full Code](https://github.com/bmild/nerf)

Components not included in this notebook:
*   5D input including view directions
*   Hierarchical Sampling



In [None]:
import os, sys
import torch
from torch import nn

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Fetch the tiny NeRF dataset archive (loaded below as images, poses and focal)
# only when it is not already present in the working directory.
# NOTE: `!wget` is an IPython shell escape; this line only runs inside Jupyter/IPython.
if not os.path.exists('tiny_nerf_data.npz'):
    !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

# Load Input Images and Poses

In [None]:
data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']
H, W = images.shape[1:3]
print(images.shape, poses.shape, focal)

# Hold out view 101 for evaluation and train on the first 100 views.
# The hold-out image is sliced to RGB exactly like the training images so it
# always lines up with the (H, W, 3) colour output of render_rays when the
# PSNR is computed (an extra alpha channel would otherwise fail to broadcast).
testimg, testpose = images[101, ..., :3], poses[101]
images = images[:100, ..., :3]
poses = poses[:100]

plt.imshow(testimg)
plt.show()

# Optimize NeRF

In [None]:


def posenc(x):
    """Positionally encode ``x`` with sin/cos features at octave frequencies.

    The output concatenates, along the last dimension, ``x`` itself followed by
    ``sin(2**k * x)`` and ``cos(2**k * x)`` for k = 0 .. L_embed-1 (``L_embed``
    is read from module scope at call time). A last dimension of size C thus
    grows to C * (1 + 2*L_embed). The result is moved to ``device``.
    """
    features = [x]
    for k in range(L_embed):
        scaled = 2. ** k * x
        features += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(features, -1).to(device)

# Number of positional-encoding frequency bands: posenc maps each 3-D point to
# 3 + 3*2*L_embed features, and NeRF() sizes its input layer the same way.
L_embed = 6
embed_fn = posenc
# To disable the encoding instead (PyTorch equivalent of the TF identity):
# L_embed = 0
# embed_fn = nn.Identity()
  
class NeRF(nn.Module):
    """Fully connected NeRF MLP mapping encoded points to 4 raw outputs.

    The network has ``D`` ReLU linear layers of width ``W`` (one input layer
    plus ``D - 1`` hidden layers) followed by a linear output layer with 4
    outputs. After every 4th hidden layer the original network input is
    concatenated back onto the features (skip connection). The layer that
    consumes those features, whether hidden or output, is sized to match, so
    any ``D >= 1`` produces a valid network.

    Args:
        D: Number of ReLU linear layers, input layer included.
        W: Width of each hidden layer.
        input_ch: Size of the input's last dimension. Defaults to
            ``3 + 3*2*L_embed`` (the width ``posenc`` produces), using the
            module-level ``L_embed`` at construction time.
    """

    def __init__(self, D=8, W=256, input_ch=None):
        super().__init__()
        if input_ch is None:
            input_ch = 3 + 3 * 2 * L_embed
        self.inputLayer = nn.Linear(input_ch, W)
        self.hiddenLayers = nn.ModuleList()
        in_features = W
        for i in range(D - 1):
            self.hiddenLayers.append(nn.Linear(in_features, W))
            # forward() concatenates the raw input after hidden layer i when
            # (i + 1) % 4 == 0, so whatever layer comes next must accept it.
            in_features = W + input_ch if (i + 1) % 4 == 0 else W
        # Sized from the last hidden layer's output width: if the final hidden
        # layer is followed by a skip concat (e.g. D=5 or D=9) the output layer
        # must accept W + input_ch features, not W.
        self.outputLayer = nn.Linear(in_features, 4)

    def forward(self, x):
        """Map inputs of shape (..., input_ch) to raw outputs of shape (..., 4)."""
        x_initial = x
        x = nn.functional.relu(self.inputLayer(x))
        for i, layer in enumerate(self.hiddenLayers):
            x = nn.functional.relu(layer(x))
            if (i + 1) % 4 == 0:
                x = torch.cat([x, x_initial], dim=-1)
        return self.outputLayer(x)


def get_rays(H, W, focal, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        H, W: Image height and width in pixels.
        focal: Focal length in pixels. May be a Python number, a NumPy scalar or
            array, or a tensor.
        c2w: Pose matrix (NumPy array or tensor, at least 3x4). Its top-left 3x3
            block rotates camera-space directions and the first three entries
            of its last column give the ray origin.

    Returns:
        (rays_o, rays_d): float32 tensors of shape (H, W, 3) on ``device``,
        holding the (broadcast) ray origin and the unnormalized ray directions.
    """
    # torch.as_tensor accepts ndarrays, Python/NumPy scalars and tensors alike
    # (torch.from_numpy only takes ndarrays). Cast to float32 to match the
    # float32 pixel grid built below.
    c2w = torch.as_tensor(c2w, dtype=torch.float32, device=device)
    focal = torch.as_tensor(focal, dtype=torch.float32, device=device)
    # Both grids are (H, W): i holds the column (x) index, j the row (y) index.
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32, device=device),
        torch.arange(H, dtype=torch.float32, device=device),
        indexing="xy",
    )
    # Camera-space directions: x to the right, y up (image rows grow downward),
    # looking along -z.
    dirs = torch.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -torch.ones_like(i)], -1)
    # Rotate into world space: rays_d[h, w] = c2w[:3, :3] @ dirs[h, w].
    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)
    # All rays share one origin: the translation column of c2w.
    rays_o = torch.broadcast_to(c2w[:3, -1], rays_d.size())
    return rays_o, rays_d


def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):
    """Volume-render rays by querying ``network_fn`` at sample points.

    Takes ``N_samples`` depths evenly spaced in [near, far] along every ray. When
    ``rand`` is True, each depth is shifted by uniform noise of up to one bin
    width ((far-near)/N_samples). The sample points are encoded with
    ``embed_fn`` and evaluated in chunks of 32768 points, and the results are
    alpha-composited along each ray.

    Args:
        network_fn: Callable mapping encoded points (N, C) to raw outputs (N, 4).
            Channels 0-2 pass through a sigmoid to give colour; channel 3 passes
            through a ReLU to give density.
        rays_o, rays_d: Ray origins and directions with a trailing dimension of 3.
        near, far: Depth bounds for sampling.
        N_samples: Number of samples per ray.
        rand: Whether to jitter the sample depths.

    Returns:
        (rgb_map, depth_map, acc_map): composited colour (..., 3), weighted
        depth (...), and total compositing weight (...).
    """
    chunk = 1024 * 32

    # Sample depths along each ray (optionally jittered per ray and sample).
    z_vals = torch.linspace(near, far, N_samples, device=device)
    if rand:
        sample_shape = list(rays_o.shape[:-1]) + [N_samples]
        jitter = torch.rand(sample_shape, device=device) * (far-near)/N_samples
        z_vals = torch.broadcast_to(z_vals, sample_shape) + jitter
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    # Encode every sample point and evaluate the network in fixed-size chunks.
    encoded = embed_fn(pts.reshape(-1, 3))
    raw = torch.cat(
        [network_fn(encoded[start:start + chunk])
         for start in range(0, encoded.shape[0], chunk)],
        0,
    )
    raw = raw.reshape(*pts.shape[:-1], 4)

    # Non-negative density and [0, 1] colour per sample.
    sigma_a = torch.relu(raw[..., 3])
    rgb = torch.sigmoid(raw[..., :3])

    # Interval lengths between consecutive samples; the last one is effectively infinite.
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([deltas, torch.full(z_vals[..., :1].shape, 1e10, device=device)], -1)
    alpha = 1.-torch.exp(-sigma_a * dists)

    # Exclusive cumulative product: the first sample gets 1, sample k gets
    # prod_{j<k}(1 - alpha_j + 1e-10).
    transmittance = torch.cumprod(1.-alpha + 1e-10, -1)
    transmittance = torch.cat(
        [torch.ones_like(transmittance[..., :1]), transmittance[..., :-1]], -1
    )
    weights = alpha * transmittance

    rgb_map = (weights[..., None] * rgb).sum(-2)
    depth_map = (weights * z_vals).sum(-1)
    acc_map = weights.sum(-1)
    return rgb_map, depth_map, acc_map

Here we optimize the model. Every `i_plot` (25) iterations we render the held-out view and plot it next to the PSNR curve so far.

In [None]:
model = NeRF().to(device)
lr = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25  # render the hold-out view and log PSNR every i_plot iterations
loss_f = torch.nn.MSELoss()

# The hold-out target never changes, so move it to the device once.
test_target = torch.from_numpy(testimg).to(device)

import time
t = time.time()
for i in range(N_iters+1):
    model.train()
    # Each step trains on every ray of one randomly chosen training image.
    img_i = np.random.randint(images.shape[0])
    target = torch.from_numpy(images[img_i]).to(device)
    pose = poses[img_i]
    rays_o, rays_d = get_rays(H, W, focal, pose)

    optimizer.zero_grad()
    rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=False)
    loss = loss_f(rgb, target)
    loss.backward()
    optimizer.step()

    if i % i_plot == 0:
        model.eval()
        print(i, (time.time() - t) / i_plot, 'secs per iter')
        t = time.time()

        # Render the holdout view for logging. This is evaluation only, so
        # disable autograd rather than building a graph for the whole image.
        with torch.no_grad():
            rays_o, rays_d = get_rays(H, W, focal, testpose)
            rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
            loss = loss_f(rgb, test_target)
            psnr = -10. * torch.log10(loss)

        psnrs.append(psnr.cpu().numpy())
        iternums.append(i)

        plt.figure(figsize=(10,4))
        plt.subplot(121)
        plt.imshow(rgb.cpu().numpy())
        plt.title(f'Iteration: {i}')
        plt.subplot(122)
        plt.plot(iternums, psnrs)
        plt.title('PSNR')
        plt.show()

print('Done')

# Interactive Visualization

In [None]:
%matplotlib inline
from ipywidgets import interactive, widgets

def trans_t(t):
    """Return a 4x4 float32 homogeneous transform translating by ``t`` along z."""
    mat = np.eye(4, dtype=np.float32)
    mat[2, 3] = t
    return mat


def rot_phi(phi):
    """Return a 4x4 float32 rotation by ``phi`` radians about the x axis."""
    c, s = np.cos(phi), np.sin(phi)
    mat = np.eye(4, dtype=np.float32)
    mat[1, 1], mat[1, 2] = c, -s
    mat[2, 1], mat[2, 2] = s, c
    return mat


def rot_theta(th):
    """Return a 4x4 float32 rotation by ``th`` radians about the y axis."""
    c, s = np.cos(th), np.sin(th)
    mat = np.eye(4, dtype=np.float32)
    mat[0, 0], mat[0, 2] = c, -s
    mat[2, 0], mat[2, 2] = s, c
    return mat


def pose_spherical(theta, phi, radius):
    """Build a 4x4 float32 pose matrix (used as ``c2w`` by get_rays).

    The pose translates by ``radius`` along z, rotates by ``phi`` degrees about
    x, then by ``theta`` degrees about y, and finally applies a fixed axis
    flip/swap matrix.
    """
    axis_swap = np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]], dtype=np.float32)
    c2w = rot_phi(phi/180.*np.pi) @ trans_t(radius)
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    return axis_swap @ c2w


def f(**kwargs):
    """Render the trained model from a spherical pose and display the image.

    Keyword arguments are forwarded unchanged to ``pose_spherical``.
    """
    c2w = pose_spherical(**kwargs)
    rays_o, rays_d = get_rays(H, W, focal, c2w[:3,:4])
    # Inference only: skip building an autograd graph for the full-image render.
    with torch.no_grad():
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
    img = np.clip(rgb.cpu().numpy(), 0, 1)

    plt.figure(2, figsize=(20,6))
    plt.imshow(img)
    plt.show()
    

def sldr(v, mi, ma):
    """Return a FloatSlider starting at ``v`` with range [mi, ma] and step 0.01."""
    return widgets.FloatSlider(value=v, min=mi, max=ma, step=.01)


# One slider per pose_spherical keyword, as [name, [initial, min, max]].
names = [
    ['theta', [100., 0., 360]],
    ['phi', [-30., -90, 0]],
    ['radius', [4., 3., 5.]],
]

interactive_plot = interactive(
    f, **{label: sldr(*bounds) for label, bounds in names}
)
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot

# Render 360 Video

In [None]:
frames = []
for th in tqdm(np.linspace(0., 360., 120, endpoint=False)):
    c2w = pose_spherical(th, -30., 4.)
    rays_o, rays_d = get_rays(H, W, focal, c2w[:3,:4])
    rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
    rgb = rgb.cpu().detach().numpy()
    frames.append((255*np.clip(rgb,0,1)).astype(np.uint8))

!pip install imageio-ffmpeg
import imageio
f = 'video.mp4'
imageio.mimwrite(f, frames, fps=30, quality=7)

In [None]:
from IPython.display import HTML
from base64 import b64encode
# Embed the rendered MP4 as a base64 data URL so it plays inline in the notebook.
# The `with` block closes the file handle instead of leaking it.
with open('video.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls autoplay loop>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)