Imports needed to load the data, train the model, and plot its performance.


In [1]:
import os
import matplotlib.pyplot as plt 
import numpy as np
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from PIL import Image
import torchvision



Load data to train the network: Code was adapted from the official NeRF repository to work with PyTorch. https://github.com/bmild/nerf/blob/master/load_blender.py

In [2]:
# Translation matrix but 't' is only considered to be on the z axis. It basically translates the given point by t in the z axis direction.
# Note : This is weird as in spherical coordinates for which these matrices are used, the given 't' value is usually the radius.
trans_t = lambda t : torch.tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1],
], dtype=torch.float32)


# Rx rotation matrix rotation. Rotation around the x axis. Angle given by 'phi'
rot_phi = lambda phi : torch.tensor([
    [1,0,0,0],
    [0,torch.cos(phi),-torch.sin(phi),0],
    [0,torch.sin(phi), torch.cos(phi),0],
    [0,0,0,1],
], dtype=torch.float32)

# Rz rotation matrix. Rotation around the z axis. Angle given by 'th'
rot_theta = lambda th : torch.tensor([
    [torch.cos(th),0,-torch.sin(th),0],
    [0,1,0,0],
    [torch.sin(th),0, torch.cos(th),0],
    [0,0,0,1],
], dtype=torch.float32)


def pose_spherical(theta, phi, radius):
    '''
        not sure about this :- seems to perform some sort of position transformation from a 'camera' reference frame to a general 'world' reference frame in carthesian cooridnates
    '''
    # translate point along the z axis by 'radius'
    c2w = trans_t(torch.tensor(radius))

    # rotate point by phi around x axis
    c2w = torch.matmul(rot_phi(torch.tensor(phi/180.*np.pi)) , c2w)

    # rotate point by theta around z(?) axis
    c2w = torch.matmul(rot_theta(torch.tensor((theta/180.*np.pi))) , c2w)

    # I don't understand this transform. It looks like its scaling the x coordinates by -1
    c2w = torch.matmul(torch.tensor(np.array([[-1,0,0,0],
                                              [ 0,0,1,0],
                                              [ 0,1,0,0],
                                              [ 0,0,0,1]])).double() , c2w.double()) # Had to call double() on both tensors or the matmul() wouldn't work for some reason
    return c2w
    


def load_blender_data(basedir, half_res=False, testskip=1):
    '''
        inputs : 
                basedir : (str) containing the base directory where the data can be found
                half_res: (bool) reduces the resolution of the images to haf of its pixels if true
                testskip: (int) sep with which images are loaded: 1 => all loaded, 2 => only half of them ...
        outputs:
                imgs: numpy array of images as array of RGB values por each pixel 
                not sure about this :- poses: numpy array of 4x4 transformation matrices giving the position and angle of the object with respect to a general reference frame : more info on how they work on https://www.brainvoyager.com/bv/doc/UsersGuide/CoordsAndTransforms/SpatialTransformationMatrices.html
                render_poses: numpy array of 4x4 transformation matrices giving the position of the camera relative to a general reference frame in a circular path for rendering purposes.
                [H, W, focal] : height of screen on which object is projected (pixels), width of screen on which object is proejected (pixels), focal distance : distance between camera and center of screen in some arbitrary unit
                i_split :array of 3 arrays with the numbers of the indices of [ train , val, test] images
    '''
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip
            
        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(torchvision.io.read_image(fname)) # Load images as RGB torch tensors
            poses.append(np.array(frame['transform_matrix']))
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + len(imgs))
        all_imgs.append(imgs)
        all_poses.append(poses)
    
    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)
    
    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x']) # FOV angle of camera
    focal = .5 * W / np.tan(.5 * camera_angle_x) # focal distance of camera to the screen.
    
    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]],0)

    if half_res:
        for i,img in enumerate(imgs):
            imgs[i] = torchvision.transforms.functional.resize(img, [400, 400]).numpy()
        H = H//2
        W = W//2
        focal = focal/2.
        
    return imgs, poses, render_poses, [H, W, focal], i_split 

synthetic_data = load_blender_data('data/nerf_synthetic/lego',half_res=True) # load in halfres to get better performances at dev



Define a function that will return the ray direction and origin for the images given in the dataset. Code adapted from the original NeRF repository : https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py

In [3]:
def get_rays(H, W, focal, pose):
    """Get ray origins, and directions from a pinhole camera. given the 'pose' transform matrix to transform the direction and position 
       from standard camera at origin to actual position and direction in world cooridnates"""
    i, j = torch.meshgrid(torch.range(W, dtype=torch.float32),
                       torch.range(H, dtype=torch.float32), indexing='xy')
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    rays_d = torch.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1)
    rays_o = torch.broadcast_to(pose[:3, -1], torch.size(rays_d))
    return rays_o, rays_d


Defining the NeRF neural network as defined in the paper

In [4]:
class NeRF(nn.Module):
    def __init__(self):
        super().__init__()
        
        input_position = 60
        input_direction = 24
        output_density = 1
        output_colour = 3
        hidden_features = 256

        self.l1 = nn.Linear(input_position,  hidden_features)
        self.l2 = nn.Linear(hidden_features, hidden_features)
        self.l3 = nn.Linear(hidden_features, hidden_features)
        self.l4 = nn.Linear(hidden_features, hidden_features)
        self.l5 = nn.Linear(hidden_features + input_position, hidden_features)
        self.l6 = nn.Linear(hidden_features, hidden_features)
        self.l7 = nn.Linear(hidden_features, hidden_features)
        self.l8 = nn.Linear(hidden_features, hidden_features)        
        self.l9 = nn.Linear(hidden_features+input_direction, hidden_features+output_density)
        self.l10 = nn.Linear(hidden_features, 128)
        self.l11 = nn.Linear(128, output_colour)

        self.activationReLU = nn.ReLU()
        self.activationSigmoid = nn.Sigmoid()

    def forward(self, pos, dir):

        h1 = self.activationReLU(self.l1(pos))
        h2 = self.activationReLU(self.l2(h1))
        h3 = self.activationReLU(self.l3(h2))
        h4 = self.activationReLU(self.l4(h3))
        h5 = self.activationReLU(self.l5(torch.cat([h4, pos]))) 
        h6 = self.activationReLU(self.l6(h5))
        h7 = self.activationReLU(self.l7(h6))
        h8 = self.l8(h7) # no activation function before layer 9
        partial_h9 = self.l9(h8)
        density = partial_h9[:,0]
        h9 = self.activationReLU(torch.cat([partial_h9[:,1:] + dir])) #### cat sur la bonne dimension
        h10 = self.activationReLU(self.l10(h9))
        colour = self.activationReLU(self.l11(h10))

        return density, colour
    

Instantiate a coarse and fine network to start the scene function approximation

In [5]:
fine_scene1 = NeRF()
coarse_scene1 = NeRF()

Define the loss function and optimizer

TODO: make it support tensors

In [6]:
def loss_fct(rgb_pred_coarse,rgb_pred_fine,rgb_true):
    loss = 0
    for i in range(len(rgb_pred_coarse)):
        loss += (torch.norm(torch.sub(rgb_pred_coarse[i], rgb_true[i]),2) + torch.norm(torch.sub(rgb_pred_fine[i], rgb_true[i]),2))

criterion = loss_fct

optimizer = torch.optim.Adam(list(coarse_scene1.parameters()) + list(fine_scene1.parameters()), lr=5e-04,eps=1e-08) 
### NOTE : do not forget to reduce the learning rate afterwards when doing the optimization steps

Define the encoding function that will take the inputs of the neural network and project them to a higher dimension input

In [7]:
def encoding_fct(value,max_dim):
    encoded = torch.zeros(max_dim*2)
    for i in range(0,max_dim*2, 2):
        encoded[i] = torch.sin(torch.pow(torch.Tensor([2]),i)*torch.pi*value)
        encoded[i+1] = torch.cos(torch.pow(torch.Tensor([2]),i)*torch.pi*value)
    return encoded

TODO : Define Hierarchical volume sampling

In [8]:
def hierarchical_volume_sampling(nb_coarse_samples,nb_fine_samples,coarse_net,fine_net,min_dist,max_dist, origin,direction):
    
    for i in range(nb_coarse_samples):
        # equidistant samples

        for j in range(nb_fine_samples):
            # samples randomly taken from a probability law made from the coarse samples weights

    return colour

IndentationError: expected an indented block (3591139706.py, line 9)