# DD2424 - Project
This is an altered version of the method presented in *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*.

[Project website](http://www.matthewtancik.com/nerf)

Components not included in the notebook

* 5D input including view directions
* Hierarchical Sampling

In [None]:
import os, sys
import tensorflow as tf

from tqdm import tqdm_notebook as tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Fetch the tiny NeRF dataset only when it is not already in the working
# directory (the `!` line is an IPython shell escape, not plain Python).
if not os.path.exists('tiny_nerf_data.npz'):
    !wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz

In [None]:
from tensorflow.python.client import device_lib
"""
Checks what gpu has been allocated
"""
# Ask TensorFlow which local devices it can see.
device_lib.list_local_devices()

# Load input images and poses. 
The input images have size (100,100,3)

In [None]:
# Load the dataset archive and pull out the three arrays used below.
data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']
# Image height and width are taken from axes 1 and 2 of the image array.
H, W = images.shape[1:3]
print(images.shape, poses.shape, focal)

# Keep view 101 aside as the hold-out view BEFORE truncating the arrays;
# after the next two lines only indices 0..99 remain in `images`/`poses`.
testimg, testpose = images[101], poses[101]
# Training set: first 100 views, first three channels only.
images = images[:100,...,:3]
poses = poses[:100]

# Show the hold-out view.
plt.imshow(testimg)
plt.show()

# NeRF network

In [None]:
def posenc(x):
    """Positional encoding of `x` along its last axis.

    The result is the concatenation (last axis) of `x` itself followed by,
    for every band k in 0..L_embed-1, sin(2**k * x) and cos(2**k * x).
    Reads the module-level `L_embed` at call time.
    """
    encoded = [x]
    for band in range(L_embed):
        scaled = 2.**band * x
        encoded.append(tf.sin(scaled))
        encoded.append(tf.cos(scaled))
    return tf.concat(encoded, -1)

def get_rays(H, W, focal, c2w):
    """Compute one ray origin and one ray direction per pixel.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length used to scale pixel offsets.
        c2w: pose matrix; only `c2w[:3, :3]` and `c2w[:3, -1]` are read
            (presumably a camera-to-world transform - confirm with caller).

    Returns:
        (rays_o, rays_d): two tensors of identical shape; `rays_o` is the
        last column of `c2w` broadcast to the shape of `rays_d`.
    """
    cols = tf.range(W, dtype=tf.float32)
    rows = tf.range(H, dtype=tf.float32)
    i, j = tf.meshgrid(cols, rows, indexing='xy')

    # Per-pixel direction before applying the pose: offsets from the image
    # centre scaled by the focal length, with a constant -1 third component.
    x_dir = (i-W*.5)/focal
    y_dir = -(j-H*.5)/focal
    z_dir = -tf.ones_like(i)
    dirs = tf.stack([x_dir, y_dir, z_dir], -1)

    # Multiply every direction by the upper-left 3x3 block of the pose
    # (broadcast multiply followed by a sum over the last axis).
    rotation = c2w[:3,:3]
    rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * rotation, -1)

    # Every ray starts at the same point: the last column of the pose.
    origin = c2w[:3,-1]
    rays_o = tf.broadcast_to(origin, tf.shape(rays_d))
    return rays_o, rays_d

# Number of frequency bands used by posenc; the network input width is
# 3 + 3*2*L_embed.
L_embed = 6
# Embedding applied to the flattened sample points before the network call.
embed_fn = posenc

def init_model(D=8, W=256):
    """Build the plain fully-connected network as a functional Keras model.

    Args:
        D: number of hidden Dense layers.
        W: number of units in each hidden layer.

    Returns:
        A `tf.keras.Model` taking inputs of width 3 + 3*2*L_embed and
        producing 4 values per input row (no activation on the last layer).
    """
    relu = tf.keras.layers.ReLU()

    def make_dense(units=W, act=relu):
        return tf.keras.layers.Dense(units, activation=act)

    net_in = tf.keras.Input(shape=(3 + 3*2*L_embed))
    x = net_in
    for depth in range(D):
        x = make_dense()(x)
        # Skip connection: re-attach the network input after every fourth
        # layer except the very first one.
        if depth > 0 and depth % 4 == 0:
            x = tf.concat([x, net_in], -1)
    head = make_dense(4, act=None)(x)

    return tf.keras.Model(inputs=net_in, outputs=head)



# Custom model and layer
- 1 (Might work, but super slow)

In [None]:

class MLP1(tf.keras.layers.Dense):
    """Dense layer that tries to reuse outputs per hash bucket (experimental).

    Every input row is projected onto `H` random directions; the sign pattern
    of the projection is folded into an integer key, and `cls_map` is used as
    a lookup table from key to a previously computed row.

    Args:
        W: number of output units of the underlying Dense layer.
        activation: activation passed to the underlying Dense layer.
        H: number of random projection directions (hash bits).
        CR: stored on the instance; not read anywhere in this class.
    """

    def __init__(self, W, activation, H=8, CR=False):
        # Must name MLP1 (this class); naming another class here fails
        # because that class is not an ancestor of MLP1.
        super(MLP1, self).__init__(W, activation=activation)

        self.cls_map = {}
        self.H = H
        self.CR = CR

    def build(self, inputs):
        """Create the Dense weights and the random projection matrix.

        `inputs` is the input shape; its second entry is used as the
        feature width of the projection matrix.
        """
        super(MLP1, self).build(inputs)
        self.hash_mat = tf.random.uniform((inputs[1], self.H), minval=-1, maxval=1)

    @tf.function
    def call(self, inputs):
        """Apply the layer row by row through `tf.map_fn`."""

        @tf.function
        def func(vec):
            vec = tf.reshape(vec, (1, -1))
            # Sign pattern of the random projections -> integer bucket key.
            digest = tf.matmul(vec, self.hash_mat)
            digest = tf.cast(tf.greater(digest, 0), dtype=tf.int32)
            key = tf.reduce_sum(digest * 2 ** tf.range(0, self.H)).ref()
            # NOTE(review): `key` is a tensor reference, so the membership
            # test compares object identity rather than the bucket value -
            # confirm this lookup can ever hit.
            cond = tf.cast(key in self.cls_map, tf.bool)

            def in_table():
                # NOTE(review): stores/returns the *input* row when the key
                # is missing; its width differs from the Dense output unless
                # in_dim == units - verify intent.
                self.cls_map.setdefault(key, vec)
                return self.cls_map[key]

            def not_in_table():
                # Explicit two-argument super: the zero-argument form is not
                # usable inside this nested, argument-less function.
                y = super(MLP1, self).call(vec)
                self.cls_map[key] = y
                return y

            row = tf.cond(cond, in_table, not_in_table)

            return row

        print(inputs.shape)
        outputs = tf.map_fn(func, inputs, dtype=tf.float32)
        # Rows coming out of the Dense layer have `units` columns, not the
        # input width.
        return tf.reshape(outputs, (-1, self.units))
  
class neRF1(tf.keras.Model):
    """Functional-style Keras model whose layers are all `MLP1` instances.

    Args:
        H: number of hash bits forwarded to every `MLP1` layer.
        CR: flag forwarded to every `MLP1` layer.
    """

    def init_model(self, D=8, W=256):
        """Wire up the layer graph and return its `(inputs, outputs)` tensors.

        Args:
            D: number of hidden layers.
            W: number of units per hidden layer.
        """
        relu = tf.keras.layers.ReLU()
        # Use MLP1, the layer that belongs to this variant of the model.
        dense = lambda W=W, act=relu: MLP1(W, activation=act, H=self.H, CR=self.CR)

        inputs = tf.keras.Input(shape=(3 + 3*2*L_embed))
        outputs = inputs
        for i in range(D):
            outputs = dense()(outputs)
            # Skip connection after every fourth layer except the first.
            if i % 4 == 0 and i > 0:
                outputs = tf.concat([outputs, inputs], -1)
        outputs = dense(4, act=None)(outputs)

        return inputs, outputs

    def __init__(self, H=8, CR=False):
        # init_model reads these, so they are set before the graph is built.
        self.H = H
        self.CR = CR
        inputs, outputs = self.init_model()
        # Must name neRF1 (this class) for the super() lookup to be valid.
        super(neRF1, self).__init__(inputs=inputs, outputs=outputs)

    def call(self, inputs):
        """Delegate to the functional model's forward pass."""
        return super(neRF1, self).call(inputs)




    

# Custom model and layer
- 2 Integrated the cls into the layers, not tested

In [None]:
class MLP(tf.keras.layers.Dense):
    """Dense layer evaluated on hash-bucket centroids instead of on every row.

    Input rows are assigned to one of 2**H buckets by the sign pattern of H
    random projections. The Dense transform is applied to the running bucket
    centroids, and every input row then receives the output of its bucket.

    Args:
        W: number of output units of the underlying Dense layer.
        activation: activation passed to the underlying Dense layer.
        H: number of random projection directions (hash bits).
        CR: accepted so the constructor matches its callers; not used.
    """

    def __init__(self, W, activation, H=10, CR=False):
        super(MLP, self).__init__(W, activation=activation)
        # Bit weights 1, 2, 4, ... used to fold a sign pattern into an int.
        self.i = 2 ** tf.range(0, H)
        self.H = H

    def build(self, inputs):
        """Create the Dense weights, the projection matrix and the centroids.

        `inputs` is the input shape; its second entry is the feature width.
        """
        super(MLP, self).build(inputs)
        self.hash_mat = tf.keras.utils.normalize(
            tf.random.uniform((inputs[1], self.H), minval=-1, maxval=1), axis=0)
        # One centroid row per bucket (was an undefined bare `H`).
        self.cluster_centers = tf.zeros((2**self.H, inputs[1]))

    def call(self, inputs):
        """Bucket the rows, run Dense on the centroids, scatter results back.

        NOTE(review): relies on `inputs.shape[0]` being a concrete number, so
        a symbolic batch dimension is presumably unsupported - confirm.
        """
        n_rows = inputs.shape[0]

        # Bucket id per row from the sign pattern of the projections.
        digest = tf.matmul(inputs, self.hash_mat)
        digest = tf.cast(tf.greater(digest, 0), dtype=tf.int32)
        keys = tf.reshape(tf.reduce_sum(digest * self.i, axis=1), (n_rows, 1))
        # (bucket, row) index pairs of a sparse 0/1 assignment matrix.
        keys = tf.concat([keys, tf.reshape(tf.range(0, n_rows), (n_rows, 1))], axis=1)
        keys = tf.cast(keys, dtype=tf.int64)
        st = tf.sparse.SparseTensor(keys, values=tf.ones(keys.shape[0]),
                                    dense_shape=(2**self.H, n_rows))

        # Mean of the rows in each bucket; empty buckets divide to zero.
        counts = tf.reshape(tf.sparse.reduce_sum(st, axis=1),
                            (self.cluster_centers.shape[0], 1))
        centroids = tf.math.divide_no_nan(tf.sparse.sparse_dense_matmul(st, inputs), counts)
        # Running average with the centroids kept from earlier calls.
        self.cluster_centers = (self.cluster_centers + centroids) / 2

        y = super(MLP, self).call(self.cluster_centers)
        # Give every row the Dense output of its bucket: the transposed
        # assignment matrix has one column per bucket, matching the rows of y.
        outputs = tf.sparse.sparse_dense_matmul(tf.sparse.transpose(st), y)
        return outputs

  
class neRF(tf.keras.Model):
    """Functional-style Keras model whose layers are all `MLP` instances.

    Args:
        H: number of hash bits forwarded to every `MLP` layer.
        CR: flag forwarded to every `MLP` layer.
    """

    def init_model(self, D=8, W=256):
        """Wire up the layer graph and return its `(inputs, outputs)` tensors.

        Args:
            D: number of hidden layers.
            W: number of units per hidden layer.
        """
        relu = tf.keras.layers.ReLU()

        def hashed_dense(units=W, act=relu):
            return MLP(units, activation=act, H=self.H, CR=self.CR)

        net_in = tf.keras.Input(shape=(3 + 3*2*L_embed))
        x = net_in
        for depth in range(D):
            x = hashed_dense()(x)
            # Skip connection after every fourth layer except the first.
            if depth > 0 and depth % 4 == 0:
                x = tf.concat([x, net_in], -1)
        head = hashed_dense(4, act=None)(x)

        return net_in, head

    def __init__(self, H=10, CR=False):
        # init_model reads both attributes, so set them before building.
        self.H = H
        self.CR = CR
        graph_in, graph_out = self.init_model()
        super().__init__(inputs=graph_in, outputs=graph_out)

    def call(self, inputs):
        """Delegate to the functional model's forward pass."""
        return super().call(inputs)


# Clustering function

In [None]:
"""
N = C if input data has size (D,C)
H = number of hash functions (hyperplanes)
"""
class CLS:
    """Hash-based clustering of feature rows into 2**H buckets.

    Args:
        N: number of features per input row.
        H: number of random projection directions (hash bits).
    """

    def __init__(self, N, H):
        self.set(N, H)

    def hash(self, inputs):
        """Assign every row of `inputs` to a bucket and update the centroids.

        The row-to-bucket assignment is remembered in `self.st` so that
        `dehash` can map per-bucket values back to rows.

        Returns:
            The running bucket centroids, one row per bucket (2**H rows,
            same number of columns as `inputs`).
        """
        n_points = inputs.shape[0]

        # Sign pattern of the projections, folded into one integer per row.
        projections = tf.matmul(inputs, self.hash_mat)
        bits = tf.cast(tf.greater(projections, 0), dtype=tf.int32)
        bucket_ids = tf.reshape(tf.reduce_sum(bits * self.i, axis=1), (n_points, 1))
        point_ids = tf.reshape(tf.range(0, n_points), (n_points, 1))

        # Sparse 0/1 matrix with a one at (bucket, row) for every row.
        indices = tf.cast(tf.concat([bucket_ids, point_ids], axis=1), dtype=tf.int64)
        self.st = tf.sparse.SparseTensor(
            indices,
            values=tf.ones(indices.shape[0]),
            dense_shape=(2**self.H, n_points),
        )

        # Per-bucket mean; buckets without members divide to zero.
        bucket_sums = tf.sparse.sparse_dense_matmul(self.st, inputs)
        bucket_counts = tf.reshape(
            tf.sparse.reduce_sum(self.st, axis=1),
            (self.cluster_centers.shape[0], 1),
        )
        centroids = tf.math.divide_no_nan(bucket_sums, bucket_counts)

        # Average with the centroids kept from previous calls.
        self.cluster_centers = (self.cluster_centers + centroids) / 2
        return self.cluster_centers

    def dehash(self, inputs):
        """Give every row seen by the last `hash` call its bucket's row of `inputs`."""
        row_to_bucket = tf.sparse.transpose(self.st)
        return tf.sparse.sparse_dense_matmul(row_to_bucket, inputs)

    def set(self, N, H=10):
        """(Re)initialise the projection matrix, centroids and bit weights."""
        self.H = H
        random_dirs = tf.random.uniform((N, self.H), minval=-1, maxval=1)
        self.hash_mat = tf.keras.utils.normalize(random_dirs, axis=0)
        self.cluster_centers = tf.zeros((2**H, N))
        self.i = 2 ** tf.range(0, H)

"""
Decorated with @tf.function
"""
@tf.function
def hash(self, inputs, hash_mat, idx, H, cluster_centers):
    """Stateless variant of the bucket-assignment step.

    Args:
        self: unused; kept so existing call sites keep working.
        inputs: rows to cluster, one row per point.
        hash_mat: projection matrix multiplied with `inputs`.
        idx: per-bit weights used to fold a sign pattern into a bucket id.
        H: number of hash bits; there are 2**H buckets.
        cluster_centers: centroids from the previous call, one row per bucket.

    Returns:
        (cluster_centers, st): the updated centroids and the sparse
        (bucket, row) assignment matrix needed to map results back to rows.
    """
    n_rows = inputs.shape[0]
    digest = tf.matmul(inputs, hash_mat)
    digest = tf.cast(tf.greater(digest, 0), dtype=tf.int32)
    # Use the `idx` argument; the bare name `i` is not defined in this
    # function and would silently pick up an unrelated global if one exists.
    keys = tf.reshape(tf.reduce_sum(digest * idx, axis=1), (n_rows, 1))
    keys = tf.concat([keys, tf.reshape(tf.range(0, n_rows), (n_rows, 1))], axis=1)
    keys = tf.cast(keys, dtype=tf.int64)
    st = tf.sparse.SparseTensor(keys, values=tf.ones(keys.shape[0]), dense_shape=(2**H, n_rows))
    # Per-bucket mean; buckets without members divide to zero.
    counts = tf.reshape(tf.sparse.reduce_sum(st, axis=1), (cluster_centers.shape[0], 1))
    centroids = tf.math.divide_no_nan(tf.sparse.sparse_dense_matmul(st, inputs), counts)
    cluster_centers = (cluster_centers + centroids) / 2
    return cluster_centers, st
  
def dehash(self, inputs, st):
    """Multiply the transpose of the sparse matrix `st` with `inputs`.

    `self` is not used by this function.
    """
    transposed = tf.sparse.transpose(st)
    return tf.sparse.sparse_dense_matmul(transposed, inputs)

# Training function
`render_rays()` takes the model and rays as input. It draws `N_samples` points along each ray, runs the network on these points, and uses the output to compute the color and opacity (density).

In [None]:
def render_rays_profile(network_fn, rays_o, rays_d, near, far, N_samples, cls=None, rand=False):
    """Render rays like `render_rays`, profiling the two main stages.

    The network forward pass is profiled into the file "[Profile]forward"
    and the volume-rendering arithmetic into "[Profile]rendering".

    Args:
        network_fn: callable mapping embedded points to 4 values per point.
        rays_o: ray origins, last axis of size 3.
        rays_d: ray directions, same shape as `rays_o`.
        near: depth of the first sample along each ray.
        far: depth of the last sample along each ray.
        N_samples: number of samples per ray.
        cls: optional object with `hash`/`dehash`; when given, the network
            is run on `cls.hash(points)` and mapped back with `cls.dehash`.
        rand: when true, add uniform jitter to the sample depths.

    Returns:
        (rgb_map, depth_map, acc_map) reduced over the samples of each ray.
    """
    # Imported here so the function does not depend on a later cell having
    # been executed first.
    import cProfile
    import pstats

    pr = cProfile.Profile()
    pr2 = cProfile.Profile()

    # Compute 3D query points
    z_vals = tf.linspace(near, far, N_samples)
    if rand:
        z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

    # Run network
    pts_flat = tf.reshape(pts, [-1,3])
    pts_flat = embed_fn(pts_flat)

    pr.enable()
    # `cls` is an ordinary Python object, so branch in Python; this never
    # touches `cls.hash` when no clustering object was supplied.
    if cls is not None:
        hashed = cls.hash(pts_flat)
        raw = cls.dehash(network_fn(hashed))
    else:
        raw = network_fn(pts_flat)
    pr.disable()
    pstats.Stats(pr).sort_stats('tottime').dump_stats("[Profile]forward")

    raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])

    # Compute opacities and colors
    sigma_a = tf.nn.relu(raw[...,3])
    rgb = tf.math.sigmoid(raw[...,:3])

    # Do volume rendering
    pr2.enable()
    # Distance between consecutive samples; the last sample gets a very
    # large distance so it absorbs whatever is left of the ray.
    dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1)
    alpha = 1.-tf.exp(-sigma_a * dists)
    weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)

    rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2)
    depth_map = tf.reduce_sum(weights * z_vals, -1)
    acc_map = tf.reduce_sum(weights, -1)
    pr2.disable()
    pstats.Stats(pr2).sort_stats('tottime').dump_stats("[Profile]rendering")

    return rgb_map, depth_map, acc_map

In [None]:
def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, cls=None, rand=False):
    """Sample points along rays, query the network and volume-render them.

    Args:
        network_fn: callable mapping embedded points to 4 values per point.
        rays_o: ray origins, last axis of size 3.
        rays_d: ray directions, same shape as `rays_o`.
        near: depth of the first sample along each ray.
        far: depth of the last sample along each ray.
        N_samples: number of samples per ray.
        cls: optional object with `hash`/`dehash`; when given, the network
            is run on `cls.hash(points)` and mapped back with `cls.dehash`.
        rand: when true, add uniform jitter to the sample depths.

    Returns:
        (rgb_map, depth_map, acc_map) reduced over the samples of each ray.
    """
    # Compute 3D query points
    z_vals = tf.linspace(near, far, N_samples)
    if rand:
        z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

    # Run network
    pts_flat = tf.reshape(pts, [-1,3])
    pts_flat = embed_fn(pts_flat)

    # `cls` is an ordinary Python object, so branch in Python. The clustered
    # path must map the per-bucket outputs back to one row per point and
    # return them; previously it returned nothing.
    if cls is not None:
        hashed = cls.hash(pts_flat)
        raw = cls.dehash(network_fn(hashed))
    else:
        raw = network_fn(pts_flat)
    raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])

    # Compute opacities and colors
    sigma_a = tf.nn.relu(raw[...,3])
    rgb = tf.math.sigmoid(raw[...,:3])

    # Do volume rendering
    # Distance between consecutive samples; the last sample gets a very
    # large distance so it absorbs whatever is left of the ray.
    dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1)
    alpha = 1.-tf.exp(-sigma_a * dists)
    weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)

    rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2)
    depth_map = tf.reduce_sum(weights * z_vals, -1)
    acc_map = tf.reduce_sum(weights, -1)

    return rgb_map, depth_map, acc_map

def render_rays_original(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):
    """Render rays by querying the network on every sample (no clustering).

    Args:
        network_fn: callable mapping embedded points to 4 values per point.
        rays_o: ray origins, last axis of size 3.
        rays_d: ray directions, same shape as `rays_o`.
        near: depth of the first sample along each ray.
        far: depth of the last sample along each ray.
        N_samples: number of samples per ray.
        rand: when true, add uniform jitter to the sample depths.

    Returns:
        (rgb_map, depth_map, acc_map) reduced over the samples of each ray.
    """
    # Sample depths, optionally jittered, and the 3D points they correspond to.
    depths = tf.linspace(near, far, N_samples)
    if rand:
        jitter = tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples])
        depths = depths + jitter * (far-near)/N_samples
    origins = rays_o[...,None,:]
    directions = rays_d[...,None,:]
    samples = origins + directions * depths[...,:,None]

    # Embed the flattened points, query the network, restore the ray layout.
    embedded = embed_fn(tf.reshape(samples, [-1,3]))
    prediction = tf.reshape(network_fn(embedded), list(samples.shape[:-1]) + [4])

    # Channel 3 is the density (clamped at zero); channels 0-2 the colour.
    density = tf.nn.relu(prediction[...,3])
    colour = tf.math.sigmoid(prediction[...,:3])

    # Spacing between consecutive samples; the last one gets a huge spacing.
    gaps = depths[..., 1:] - depths[..., :-1]
    last_gap = tf.broadcast_to([1e10], depths[...,:1].shape)
    spacing = tf.concat([gaps, last_gap], -1)

    # Per-sample opacity and the weight it contributes along the ray.
    opacity = 1.-tf.exp(-density * spacing)
    passthrough = tf.math.cumprod(1.-opacity + 1e-10, -1, exclusive=True)
    contribution = opacity * passthrough

    rgb_map = tf.reduce_sum(contribution[...,None] * colour, -2)
    depth_map = tf.reduce_sum(contribution * depths, -1)
    acc_map = tf.reduce_sum(contribution, -1)

    return rgb_map, depth_map, acc_map

# Render a single image of the scene 

In [None]:
import time, cProfile, pstats
# Train the plain network; every 10th step is rendered through the
# clustering path, all other steps through the unclustered renderer.
model = init_model()
#model = neRF()
optimizer = tf.keras.optimizers.Adam(5e-4)

N_samples = 64
N_iters = 2000
psnrs = []
iternums = []
losses = []
i_plot = 50  # log/plot every i_plot iterations (only when `log` is True)

tf.random.set_seed(0)

log = False
profile = False

# Profile the whole training loop; the result is dumped after the loop.
pr = cProfile.Profile()
pr.enable()

h = 10
# Clustering helper sized to the embedded point width (3 + 3*2*L_embed = 39
# with the default L_embed). The class defined in this notebook is `CLS`.
clus = CLS(3 + 3*2*L_embed, h)

# Start of the current timing window; needed by the logging branch below.
t = time.time()
for i in range(N_iters+1):
    # One randomly chosen training view per step.
    img_i = np.random.randint(images.shape[0])
    target = images[img_i]
    pose = poses[img_i]
    rays_o, rays_d = get_rays(H, W, focal, pose)
    with tf.GradientTape() as tape:
        if i % 10 == 0:  # run with clustering every 10th iteration
            rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, cls=clus, rand=True)
        else:
            rgb, depth, acc = render_rays_original(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=True)
        loss = tf.reduce_mean(tf.square(rgb - target))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    if i % i_plot == 0 and log:

        print(i, (time.time() - t) / i_plot, 'secs per iter')

        if profile:
            # Only print profiles that have actually been dumped; nothing in
            # this cell writes them, so a missing file is skipped, not fatal.
            for stage in ('forward', 'rendering', 'backprop'):
                dump_path = f'[Profile]{stage}'
                if os.path.exists(dump_path):
                    p = pstats.Stats(dump_path)
                    p.strip_dirs().sort_stats('tottime').print_stats(.01)

        # Render the holdout view for logging
        rays_o, rays_d = get_rays(H, W, focal, testpose)
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
        loss = tf.reduce_mean(tf.square(rgb - testimg))
        psnr = -10. * tf.math.log(loss) / tf.math.log(10.)
        psnrs.append(psnr.numpy())
        losses.append(loss.numpy())
        iternums.append(i)
        print("PSNR: {}".format(psnr))
        plt.figure(figsize=(10,4))
        plt.subplot(131)
        plt.imshow(rgb)
        plt.title(f'Iteration: {i}')
        plt.subplot(132)
        plt.plot(iternums, psnrs)

        plt.title('PSNR')
        plt.subplot(133)
        plt.plot(iternums, losses)
        plt.title('Loss')
        plt.grid()
        if i % 100 == 0:
            plt.savefig("outputs_{}.png".format(i))
        plt.show()

        # Restart the timing window for the next logging interval.
        t = time.time()

pr.disable()
# dump_stats returns None, so there is nothing useful to bind here.
pstats.Stats(pr).sort_stats('tottime').dump_stats("[Profile]training")
p = pstats.Stats('[Profile]training')
p.strip_dirs().sort_stats('tottime').print_stats(.01)
model.save('nerf_default')

# Test
# `images` and `poses` were cut down to the first 100 views when the data was
# loaded, so view 102 has to be read from the raw archive; indexing the
# truncated arrays with 102 is out of range.
testimg = data['images'][102, ..., :3]
testpose = data['poses'][102]
rays_o, rays_d = get_rays(H, W, focal, testpose)
rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
loss = tf.reduce_mean(tf.square(rgb - testimg))
psnr = -10. * tf.math.log(loss) / tf.math.log(10.)
print("PSNR: {}".format(psnr))
print("Loss: {}".format(loss))

plt.figure(figsize=(10,4))
plt.imshow(rgb)
plt.savefig("output.png")
plt.show()
print('Done')