Imports needed to load the data, train the model, and plot its performance.


In [1]:
import os
import matplotlib.pyplot as plt 
import numpy as np
import json

import torch
import torch.nn as nn
import torch.utils.data as data
import cv2
import imageio



Load data to train the network: Code was adapted from the official NeRF repository to work with PyTorch. https://github.com/bmild/nerf/blob/master/load_blender.py

In [2]:
# Translation matrix but 't' is only considered to be on the z axis. It basically translates the given point by t in the z axis direction.
# Note : This is weird as in spherical coordinates for which these matrices are used, the given 't' value is usually the radius.
trans_t = lambda t : torch.tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()


# Rx rotation matrix rotation. Rotation around the x axis. Angle given by 'phi'
rot_phi = lambda phi : torch.tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

# Rz rotation matrix. Rotation around the z axis. Angle given by 'th'
rot_theta = lambda th : torch.tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    '''
        not sure about this :- seems to perform some sort of position transformation from a 'camera' reference frame to a general 'world' reference frame in carthesian cooridnates
    '''
    # translate point along the z axis by 'radius'
    c2w = trans_t(radius)

    # rotate point by phi around x axis
    c2w = rot_phi(phi/180.*np.pi) @ c2w

    # rotate point by theta around z(?) axis
    c2w = rot_theta((theta/180.*np.pi)) @ c2w

    # I don't understand this transform. It looks like its scaling the x coordinates by -1
    c2w = torch.Tensor(np.array([[-1,0,0,0],
                                 [ 0,0,1,0],
                                 [ 0,1,0,0],
                                 [ 0,0,0,1]])) @ c2w # Had to call double() on both tensors or the matmul() wouldn't work for some reason
    return c2w
    

def load_blender_data(basedir, half_res=False, testskip=1):
    '''
        inputs : 
                basedir : (str) containing the base directory where the data can be found
                half_res: (bool) reduces the resolution of the images to haf of its pixels if true
                testskip: (int) sep with which images are loaded: 1 => all loaded, 2 => only half of them ...
        outputs:
                imgs: numpy array of images as array of RGB values por each pixel 
                not sure about this :- poses: numpy array of 4x4 transformation matrices giving the position and angle of the object with respect to a general reference frame : more info on how they work on https://www.brainvoyager.com/bv/doc/UsersGuide/CoordsAndTransforms/SpatialTransformationMatrices.html
                render_poses: numpy array of 4x4 transformation matrices giving the position of the camera relative to a general reference frame in a circular path for rendering purposes.
                [H, W, focal] : height of screen on which object is projected (pixels), width of screen on which object is proejected (pixels), focal distance : distance between camera and center of screen in some arbitrary unit
                i_split :array of 3 arrays with the numbers of the indices of [ train , val, test] images
    '''
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip
            
        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)
    
    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
    
    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)
    
    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)
    
    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
    
    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res
        # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()

        
    return imgs, poses, render_poses, [H, W, focal], i_split

Define a function that will return the ray direction and origin for the images given in the dataset. Code adapted from the original NeRF repository : https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py

In [110]:
def get_all_rays(H, W, focal, pose):
    """Get ray origins, and directions from a pinhole camera. given the 'pose' transform matrix to transform the direction and position 
       from standard camera at origin to actual position and direction in world cooridnates"""
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H), indexing='xy')
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    rays_d =  torch.sum((dirs[..., np.newaxis, :]) * pose[:3,:3], -1)
    rays_o = torch.tensor(pose[:3,-1]).expand(rays_d.shape)
    return rays_o, rays_d

def get_random_ray_batch(H,W, focal, pose, x_pixel_coord,y_pixel_coord):
   "get randomly nb_rays from the rays tha go through each pixel."
   rays_o,rays_d = get_all_rays(H,W,focal,pose)

   
   return torch.squeeze(rays_o[x_pixel_coord,y_pixel_coord]), torch.squeeze(rays_d[x_pixel_coord,y_pixel_coord])


Defining the NeRF neural network as defined in the paper

In [43]:
class NeRF(nn.Module):
    def __init__(self):
        super().__init__()
        
        input_position = 60
        input_direction = 24
        output_colour = 3
        hidden_features = 256

        self.l1 = nn.Linear(input_position,  hidden_features)
        self.l2 = nn.Linear(hidden_features, hidden_features)
        self.l3 = nn.Linear(hidden_features, hidden_features)
        self.l4 = nn.Linear(hidden_features, hidden_features)
        self.l5 = nn.Linear(hidden_features, hidden_features)
        self.l6 = nn.Linear(hidden_features + input_position, hidden_features)
        self.l7 = nn.Linear(hidden_features, hidden_features)
        self.l8 = nn.Linear(hidden_features, hidden_features)        
        self.l9 = nn.Linear(hidden_features, hidden_features)
        self.l10 = nn.Linear(hidden_features+input_direction, 128)
        self.l11 = nn.Linear(128, output_colour)

        self.activationReLU = nn.ReLU()
        self.activationSigmoid = nn.Sigmoid()

    def forward(self, pos, dir):

        h1 = self.l1(self.activationReLU(pos))
        h2 = self.l2(self.activationReLU(h1))
        h3 = self.l3(self.activationReLU(h2))
        h4 = self.l4(self.activationReLU(h3))
        h5 = torch.cat((self.l5(self.activationReLU(h4)),pos))
        h6 = self.l6(self.activationReLU(h5))
        h7 = self.l7(self.activationReLU(h6))
        h8 = self.l8(self.activationReLU(h7)) 
        h9 = self.l9(h8)# no activation function before layer 9
        density = h9[0] #output density
        h9 = torch.cat((h9,dir))#### attention possible bug : cat sur la bonne dimension
        h10 = self.l10(self.activationReLU(h9)) 
        colour = self.l11(self.activationSigmoid(h10))

        return density, colour
    

Define the loss function

In [5]:
def loss_fct(rgb_pred_coarse,rgb_pred_fine,rgb_true):
    loss = 0
    for i in range(len(rgb_pred_coarse)):
        loss += (torch.norm(torch.sub(rgb_pred_coarse[i], rgb_true[i]),2) + torch.norm(torch.sub(rgb_pred_fine[i], rgb_true[i]),2))

Define the encoding function that will take the inputs of the neural network and project them to a higher dimension input

In [33]:
def encoding_fct(value,level):

    encoded = torch.zeros(level*2)

    for i in range(0,level*2, 2):
        encoded[i] = torch.sin(torch.tensor(np.power(2,i)*np.pi*value))
        encoded[i+1] = torch.cos(torch.tensor(np.power(2,i)*np.pi*value))
        
    return encoded

Query fct to get a sample from NeRF networks

In [30]:
def query_from_NeRF(network,pos,direction):

    l_pos = 10
    l_dir = 4

    pos_query = torch.empty(0)
    dir_query = torch.empty(0)
    
    for coord in pos:
        pos_query = torch.cat((pos_query, encoding_fct(coord, l_pos)))
    
    for d in direction:
        dir_query = torch.cat((dir_query, encoding_fct(d,l_dir)))

    density, colour = network.forward(pos_query,dir_query)
    return density , colour

Function the returns random samples following a probability distribution function computed from the weights.

In [116]:
def pdf_sampling(weights,min_dist,distance_step,nb_samples_fine,nb_samples_coarse):
    nb_bins = nb_samples_coarse + 2
    cumulative_probability = torch.zeros(nb_bins)
    bound_distance = torch.zeros(nb_bins)
    bound_distance[0] = min_dist

    for i in range(1,nb_bins-1):
        cumulative_probability[i] = cumulative_probability[i-1] + weights[i-1]
        bound_distance[i] = bound_distance[i-1] + distance_step

    samples = torch.zeros(nb_samples_fine)

    #generate as many numbers in the [0,1[ range as there are samples
    random_samples = torch.rand(nb_samples_fine)

    for i, rand_nb in enumerate(random_samples):
        # for each random sample check within each bin of the cumulative pdf it is 
        for j in range(nb_bins - 1 ):
            
            if(rand_nb < cumulative_probability[j+1] and rand_nb >= cumulative_probability[j]):
                # when in appropriate bin, redraw uniformly between lowest and largest distance of the bin
                samples[i] = (bound_distance[j] - bound_distance[j+1]) * torch.rand(1) + bound_distance[j+1]
                break

    return samples

Hierarchical volume sampling

In [115]:
def hierarchical_volume_sampling(nb_coarse_samples,nb_fine_samples,coarse_net,fine_net,origin:torch.TensorType, direction:torch.TensorType,min_dist=0.,max_dist = 1.):
    
    coarse_samples_density = torch.zeros(nb_coarse_samples)
    weights = torch.zeros(nb_coarse_samples)
    weights = torch.zeros(nb_coarse_samples)
    distance_fraction = torch.tensor(min_dist + max_dist) / (nb_coarse_samples + 1)
    
    coarse_samples_distance = torch.zeros(nb_coarse_samples)
    coarse_colour_of_ray = torch.zeros(3)

    for i in range(nb_coarse_samples):
        
        # equidistant samples with sam direction
        coarse_samples_distance[i] = torch.add(min_dist, (torch.mul(distance_fraction,i + 1))) #sorted by definition

        # we only care for the density of the coarse samples
        coarse_samples_density[i],coarse_colour = query_from_NeRF(coarse_net, origin +  (direction * coarse_samples_distance[i]),direction)

        #compute Ti value
        Ti = torch.tensor(0)
        for k in range(i):
            #TODO : this should be simplified with exponent rules to not have to do a loop
            Ti = torch.add(Ti,coarse_samples_density[k]*distance_fraction)
        Ti = torch.exp(-Ti)
        
        #compute weights of importance of sample to generate pdf later
        weights[i] = Ti * (1 - torch.exp(-distance_fraction * coarse_samples_density[i]))
        coarse_colour_of_ray = weights[i] * coarse_colour
    
    # normalize
    total_weights = torch.sum(weights)
    weights = torch.div(weights,total_weights)
    
    # generate sample locations from pdf made with the weights:
    fine_samples_distance = pdf_sampling(weights,min_dist,distance_fraction,nb_fine_samples,nb_coarse_samples)

    #add the coarse locations to the set of lecations and sort them in incresing order
    fine_samples_distance = torch.cat((fine_samples_distance,coarse_samples_distance))
    fine_samples_distance,_ = torch.sort(fine_samples_distance)
    
    fine_samples_density = torch.zeros((nb_coarse_samples + nb_fine_samples))
    fine_samples_rgb = torch.zeros((nb_coarse_samples + nb_fine_samples,3))

    for i in range(nb_fine_samples):
        fine_samples_density[i],fine_samples_rgb[i] = query_from_NeRF(fine_net,origin +  (direction * fine_samples_distance[i]),direction)

    fine_colour_of_ray = torch.zeros(3)
    for i in range(nb_coarse_samples+nb_fine_samples-1):
        Ti = torch.tensor(0)
        for k in range(i-1):
            #TODO : this should be simplified with exponent rules to not have to do a loop
            Ti = torch.add(Ti, fine_samples_density[k]* (fine_samples_distance[k+1] - fine_samples_distance[k]))
        Ti = torch.exp(-Ti)

        fine_colour_of_ray =torch.add(fine_colour_of_ray, Ti * (1 - torch.exp(-fine_samples_density[i] * (fine_samples_distance[i+1] - fine_samples_distance[i]))))

    return coarse_colour_of_ray,fine_colour_of_ray



Function to synthesize a view from a focal distance and a pose (direction and position).

In [10]:
def synthesize_view(coarse_net, fine_net,focal, pose,H,W):

    rays_origin, rays_dir = get_all_rays(H,W,focal,pose)
    pixels = np.zeros((H,W,3))
    for i in range(H):
        for j in range(W):
            
            colour = hierarchical_volume_sampling(64,128,coarse_net,fine_net,rays_origin[i+j],rays_dir[i+j])
            pixels[i,j] = colour

    return pixels

Training step function.

Training data should be in the format [imgs, poses, render_poses, [H, W, focal], data_split_indices]

In [114]:
def training_step(optimizer,criterion,coarse_net,fine_net,training_data):
    
    #randomly select a training img

    image_nb = np.random.choice(training_data[4][0])

    true_pixels = training_data[0][image_nb]
    pose = training_data[1][image_nb]
    H,W,focal = training_data[3]
    #randomly select a batch of rays
    
    max_batch_size = 100
    batch_size = max_batch_size
    if H < batch_size:
        batch_size = H
    if W < batch_size:
        batch_size = W

    x_pixels_coord = np.random.choice(np.arange(W),batch_size,replace=False).reshape((batch_size,1))
    y_pixels_coord = np.random.choice(np.arange(H),batch_size,replace=False).reshape((batch_size,1))

    rays_o,rays_d = get_random_ray_batch(H,W,focal,pose,x_pixels_coord,y_pixels_coord)

    fine_col = torch.zeros(batch_size,3)
    coarse_col = torch.zeros(batch_size,3)

    for i in range(batch_size):
        a = hierarchical_volume_sampling(64,128,coarse_net,fine_net,rays_o[i], rays_d[i])
        coarse_col[i], fine_col[i] = a

    #compute loss and back prop
    optimizer.zero_grad()    
    loss = criterion(coarse_col,fine_col,true_pixels[x_pixels_coord,y_pixels_coord])
    loss.backward()
    optimizer.step()

    return loss

Function to decrease learning rate as specified in paper: "learning rate that begins at 5 ×10−4 and decays exponentially to 5 ×10−5 over the course of optimization"

In [12]:
def compute_lr_decay_factor(total_steps,initial_val,end_val):
    # compute the decay from the inverse of y = a*decayfactor ^ steps

    diff = abs(initial_val - end_val)
    a = torch.tensor(diff/initial_val)
    b = torch.tensor(total_steps)
    decay_factor = torch.log(a) / torch.log(b) # base change rule
    return decay_factor

def update_lr(optimizer,decay,initial_lr,current_step):
    new_lr = initial_lr * decay**current_step

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

    return


MAIN LOOP with data loading,instantiation of NeRF networks and training of the NeRFs

In [13]:
#Load data. Load in halfres to get better performances for testing
synthetic_data = load_blender_data('data/nerf_synthetic/lego',half_res=True) 

In [118]:

#Instantiate NeRF networks
fine_scene1 = NeRF()
coarse_scene1 = NeRF()

# attribute loss fct
criterion = loss_fct

# set optimizer
initial_lr = 5e-04
optimizer = torch.optim.Adam(list(coarse_scene1.parameters()) + list(fine_scene1.parameters()), lr=initial_lr,eps=1e-08) 

#set nb of training steps and lr_decay
nb_steps = 1000
lr_decay = compute_lr_decay_factor(nb_steps,initial_lr,5e-05)

# set array to plot losses
losses = np.zeros(nb_steps)
# start training
for step in range(nb_steps):
    print(step,end='\r')
    losses[step] = training_step(optimizer,criterion,coarse_scene1,fine_scene1,synthetic_data)
    update_lr(optimizer,lr_decay,initial_lr,step)
    


0

  encoded[i] = torch.sin(torch.tensor(np.power(2,i)*np.pi*value))
  encoded[i+1] = torch.cos(torch.tensor(np.power(2,i)*np.pi*value))


Plot losses

Show final trained result from a novel view