Imports needed to load the data, train the model, and plot its performance.


In [14]:
import os
import matplotlib.pyplot as plt 
import numpy as np
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from PIL import Image
import imageio 
from torchvision import datasets, transforms, utils


Load data to train the network: Code was adapted from the official NeRF repository to work with PyTorch. https://github.com/bmild/nerf/blob/master/load_blender.py

In [19]:
def trans_t(t):
    """Return the 4x4 homogeneous matrix translating by ``t`` along the z axis.

    ``t`` may be a Python number or a 0-d tensor; the result is always a
    float32 tensor.
    """
    rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)

def rot_phi(phi):
    """Return the 4x4 homogeneous rotation by ``phi`` radians about the x axis.

    ``phi`` must be a tensor (it is passed to ``torch.cos`` / ``torch.sin``);
    the result is a float32 tensor.
    """
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    rows = [
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)

def rot_theta(th):
    """Return the 4x4 homogeneous rotation by ``th`` radians about the y axis.

    The sign convention places ``-sin(th)`` in the first row and ``+sin(th)``
    in the third. ``th`` must be a tensor (it is passed to ``torch.cos`` /
    ``torch.sin``); the result is a float32 tensor.
    """
    cos_th = torch.cos(th)
    sin_th = torch.sin(th)
    rows = [
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)


def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera-to-world pose from spherical coordinates.

    The pose is composed, right to left, of a translation by ``radius`` along
    z, a rotation by ``phi`` (see ``rot_phi``), a rotation by ``theta`` (see
    ``rot_theta``) and a final fixed axis change.

    Args:
        theta: Angle in degrees passed to ``rot_theta`` (converted to radians here).
        phi: Angle in degrees passed to ``rot_phi`` (converted to radians here).
        radius: Translation distance passed to ``trans_t``.

    Returns:
        A float32 torch tensor of shape (4, 4).
    """
    c2w = trans_t(torch.tensor(radius))
    c2w = rot_phi(torch.tensor(phi/180.*np.pi)) @ c2w
    c2w = rot_theta(torch.tensor(theta/180.*np.pi)) @ c2w
    # Final axis change: negate x and swap y/z. It is built as a float32 torch
    # tensor rather than a NumPy array: mixing an ndarray with a tensor in `@`
    # does not yield a tensor, and the callers stack these poses with torch.stack.
    axis_change = torch.tensor(
        [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
        dtype=torch.float32,
    )
    return axis_change @ c2w
    


def load_blender_data(basedir, half_res=False, testskip=1):
    """Load a Blender-rendered NeRF scene (images, camera poses, intrinsics).

    Reads ``transforms_train.json``, ``transforms_val.json`` and
    ``transforms_test.json`` from ``basedir`` together with the PNG frames they
    reference, and concatenates the three splits in that order.

    Args:
        basedir: Directory containing the ``transforms_*.json`` files; frame
            paths inside the JSON are resolved relative to it.
        half_res: When True, images are area-downsampled to half their height
            and width and the focal length is halved accordingly.
        testskip: Stride applied to the frames of the 'val' and 'test' splits.
            0 is treated like 1 (keep every frame). The 'train' split is never
            subsampled.

    Returns:
        A tuple ``(imgs, poses, render_poses, [H, W, focal], i_split)`` where
        ``imgs`` is a float32 array of all frames divided by 255 (every channel
        is kept), ``poses`` is a float32 array of the frames'
        ``transform_matrix`` entries, ``render_poses`` is a tensor stacking 40
        poses produced by ``pose_spherical``, ``H``/``W``/``focal`` are the image
        size and focal length in pixels, and ``i_split`` is a list of three
        index arrays (train, val, test) into ``imgs``/``poses``.
    """
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]  # running frame totals; consecutive pairs delimit each split
    for s in splits:
        meta = metas[s]
        # A step of 0 is not a valid slice stride, so it is mapped to 1.
        skip = 1 if (s == 'train' or testskip == 0) else testskip

        imgs = []
        poses = []
        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(np.float32)  # keep all 4 channels (RGBA)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    # The field of view is read from the last split processed ('test').
    camera_angle_x = float(metas[splits[-1]]['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack(
        [pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40+1)[:-1]],
        0,
    )

    if half_res:
        H = H // 2
        W = W // 2
        focal = focal / 2.
        # Area-average downsampling. interpolate() expects channels-first
        # (N, C, H, W), so the channels-last image stack is permuted before
        # resizing and permuted back afterwards.
        # Assumes imgs is 4-D (N, H, W, C), i.e. the PNGs have a channel axis.
        imgs_nchw = torch.from_numpy(imgs).permute(0, 3, 1, 2)
        imgs_nchw = torch.nn.functional.interpolate(imgs_nchw, size=(H, W), mode='area')
        imgs = imgs_nchw.permute(0, 2, 3, 1).contiguous().numpy()

    return imgs, poses, render_poses, [H, W, focal], i_split

# Load the 'lego' scene at full resolution (half_res defaults to False). The
# result is the tuple (imgs, poses, render_poses, [H, W, focal], i_split).
synthetic_data = load_blender_data('data/nerf_synthetic/lego')

Define the NeRF neural network as described in the paper.

In [3]:
class NeRF(nn.Module):
    """NeRF MLP mapping an encoded position and view direction to density and colour.

    Architecture (following the NeRF paper): eight fully connected ReLU layers
    of width ``hidden_features`` process the encoded position, with the encoded
    position concatenated again to the input of the fifth layer. A ninth layer
    without activation outputs the raw density plus a feature vector. That
    feature vector is concatenated with the encoded direction, passed through a
    128-unit ReLU layer, and a final sigmoid layer produces the RGB colour.

    Args:
        input_position: Size of the encoded position (default 60 = 3 coordinates
            x 2 functions x 10 frequency bands).
        input_direction: Size of the encoded view direction (default 24 = 3
            coordinates x 2 functions x 4 frequency bands).
        hidden_features: Width of the trunk layers.
    """

    def __init__(self, input_position=60, input_direction=24, hidden_features=256):
        super().__init__()

        output_density = 1
        output_colour = 3

        self.l1 = nn.Linear(input_position,  hidden_features)
        self.l2 = nn.Linear(hidden_features, hidden_features)
        self.l3 = nn.Linear(hidden_features, hidden_features)
        self.l4 = nn.Linear(hidden_features, hidden_features)
        # Skip connection: layer 5 also receives the encoded position.
        self.l5 = nn.Linear(hidden_features + input_position, hidden_features)
        self.l6 = nn.Linear(hidden_features, hidden_features)
        self.l7 = nn.Linear(hidden_features, hidden_features)
        self.l8 = nn.Linear(hidden_features, hidden_features)
        # Layer 9 only sees the trunk output and emits [density | feature vector].
        self.l9 = nn.Linear(hidden_features, hidden_features + output_density)
        # Layer 10 consumes the feature vector concatenated with the encoded direction.
        self.l10 = nn.Linear(hidden_features + input_direction, 128)
        self.l11 = nn.Linear(128, output_colour)

        self.activationReLU = nn.ReLU()
        self.activationSigmoid = nn.Sigmoid()

    def forward(self, pos, dir):
        """Evaluate the network.

        Args:
            pos: Encoded positions, shape ``(..., input_position)``.
            dir: Encoded view directions, shape ``(..., input_direction)`` with
                the same leading dimensions as ``pos``.

        Returns:
            ``(density, colour)``: ``density`` has shape ``(...)`` and is the raw
            (unrectified) output of layer 9, so the caller is responsible for
            making it non-negative; ``colour`` has shape ``(..., 3)`` with
            values in [0, 1].
        """
        h1 = self.activationReLU(self.l1(pos))
        h2 = self.activationReLU(self.l2(h1))
        h3 = self.activationReLU(self.l3(h2))
        h4 = self.activationReLU(self.l4(h3))
        # Concatenate along the feature axis (last dim), not the batch axis.
        h5 = self.activationReLU(self.l5(torch.cat([h4, pos], dim=-1)))
        h6 = self.activationReLU(self.l6(h5))
        h7 = self.activationReLU(self.l7(h6))
        # All eight trunk layers are ReLU layers; without an activation here
        # l8 and l9 would collapse into a single linear map.
        h8 = self.activationReLU(self.l8(h7))
        partial_h9 = self.l9(h8)  # no activation on layer 9
        density = partial_h9[..., 0]
        features = partial_h9[..., 1:]
        # Concatenate on the feature dimension. No ReLU is applied to this
        # concatenation because it would zero the negative sin/cos components
        # of the direction encoding.
        h9 = torch.cat([features, dir], dim=-1)
        h10 = self.activationReLU(self.l10(h9))
        # Sigmoid keeps the emitted RGB inside [0, 1].
        colour = self.activationSigmoid(self.l11(h10))

        return density, colour
    

Instantiate a coarse and a fine network to start approximating the scene function.

In [4]:
# Two separately initialised NeRF instances: one used as the fine network and
# one as the coarse network.
fine_scene1 = NeRF()
coarse_scene1 = NeRF()

Define the encoding function that takes the inputs of the neural network and projects them into a higher-dimensional space using sines and cosines of increasing frequency.

In [11]:
def encoding_fct(value,max_dim):
    """Positional encoding of a scalar with ``max_dim`` frequency bands.

    Computes
    ``(sin(2^0*pi*v), cos(2^0*pi*v), ..., sin(2^(L-1)*pi*v), cos(2^(L-1)*pi*v))``
    with ``L = max_dim``.

    Args:
        value: Scalar to encode (Python number, 0-d tensor or 1-element tensor).
        max_dim: Number of frequency bands ``L``.

    Returns:
        A float32 tensor of shape ``(2 * max_dim,)`` with sine values at even
        indices and cosine values at odd indices.
    """
    # The exponent is the band index 0..L-1, not the position in the output
    # vector (which advances by 2 per band).
    band_exponents = torch.arange(max_dim, dtype=torch.float32)
    angles = torch.pow(2.0, band_exponents) * torch.pi * value
    encoded = torch.zeros(max_dim*2)
    encoded[0::2] = torch.sin(angles)
    encoded[1::2] = torch.cos(angles)
    return encoded

tensor([0., 1., 0., 1., 0., 1., 0., 1.])

TODO: Define hierarchical volume sampling

In [None]:
def hierarchical_volume_sampling(nb_coarse_samples,nb_fine_samples,coarse_net,fine_net,min_dist,max_dist, origin,direction):
    """Placeholder for hierarchical (coarse-to-fine) volume sampling along a ray.

    Not implemented yet. The planned steps are:

    1. Take ``nb_coarse_samples`` equidistant samples and evaluate them with
       ``coarse_net``.
    2. Draw ``nb_fine_samples`` samples randomly from a probability law built
       from the coarse samples' weights and evaluate them with ``fine_net``.
    3. Return the resulting colour.

    Args:
        nb_coarse_samples: Number of equidistant coarse samples.
        nb_fine_samples: Number of samples drawn from the coarse weights.
        coarse_net: Network evaluated on the coarse samples.
        fine_net: Network evaluated on the fine samples.
        min_dist: Presumably the near bound of the sampled interval (TODO confirm).
        max_dist: Presumably the far bound of the sampled interval (TODO confirm).
        origin: Presumably the ray origin (TODO confirm).
        direction: Presumably the ray direction (TODO confirm).

    Raises:
        NotImplementedError: Always, until the sampling is implemented.
    """
    # An explicit error keeps the cell runnable; the earlier stub had loops with
    # comment-only bodies and returned an undefined name.
    raise NotImplementedError("hierarchical_volume_sampling is not implemented yet")