In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
import numpy as np
import multiprocessing
from math import sin, cos, sqrt, pi
from scipy.spatial.transform import Rotation
import cv2
from PIL import Image
import scipy.interpolate as interpolate

In [None]:
# Switch the working directory to the Colab notebooks folder so that the
# relative dataset paths used below (e.g. 'south_building/cameras.txt') resolve.
# NOTE(review): assumes Google Drive is already mounted at /content/drive — confirm.
import os
os.chdir("/content/drive/MyDrive/Colab Notebooks")
# Echo the current directory as a sanity check (IPython shell escape).
!pwd

/content/drive/MyDrive/Colab Notebooks


# Preprocessing

In [None]:
# Parse the reconstruction's text export into two lookup tables. The field
# layout read here matches COLMAP's cameras.txt / images.txt format.
#   cameras: camera_id -> {'W', 'H', 'f', 'cx', 'cy', 'k'}
#   images:  image name (last field of the line, trailing newline included)
#            -> {'R': rotation matrix, 't': 3x1 translation, 'c_id': camera id}
cameras = {}
images = {}

with open('south_building/cameras.txt') as f:
    lines = f.readlines()
    # The first 3 lines are treated as header and skipped.
    for i in range(3, len(lines)):
        if not lines[i].strip():
            continue  # tolerate blank lines instead of failing in int('')
        vals = lines[i].split(' ')

        camera_id = int(vals[0])
        intrinsics = {}

        # vals[1] is not read (presumably the camera model name).
        intrinsics['W'] = int(vals[2])
        intrinsics['H'] = int(vals[3])

        intrinsics['f'] = float(vals[4])
        # The principal point is parsed as float: it is a real-valued quantity
        # and int() would raise ValueError on a value such as "1536.5".
        intrinsics['cx'] = float(vals[5])
        intrinsics['cy'] = float(vals[6])

        # Single distortion coefficient (used as a radial term downstream).
        intrinsics['k'] = float(vals[7])

        cameras[camera_id] = intrinsics

with open('south_building/images.txt') as f:
    lines = f.readlines()
    # The first 4 lines are treated as header; after that only every second
    # line is read (the lines in between are skipped).
    for i in range(4, len(lines), 2):
        vals = lines[i].split(' ')
        # The name keeps its trailing newline; consumers must strip it.
        image_name = vals[-1]
        extrinsics = {}

        # The file stores the quaternion scalar-first (qw, qx, qy, qz) ...
        qw = float(vals[1])
        qx = float(vals[2])
        qy = float(vals[3])
        qz = float(vals[4])
        # ... while SciPy's from_quat expects scalar-last (x, y, z, w).
        R = Rotation.from_quat([qx, qy, qz, qw]).as_matrix()

        tx = float(vals[5])
        ty = float(vals[6])
        tz = float(vals[7])
        t = np.array([tx, ty, tz]).reshape((3, 1))

        extrinsics['R'] = R
        extrinsics['t'] = t
        extrinsics['c_id'] = int(vals[8])

        images[image_name] = extrinsics



### TODOs


*   Need to convert points in world coordinate space to NDC space.



In [None]:
import warnings
warnings.filterwarnings('ignore')

# Build the ray dataset: one (camera_pos, d, color) tuple per pixel, where
# camera_pos is the 3x1 camera centre in world space, d is the 3x1 unit ray
# direction in world space and color is the pixel value read from the image.
dataset = []
basepath = '/content/drive/MyDrive/Colab Notebooks/south_building/'
for i, image in enumerate(images):
    fullpath = basepath + image[:-1] # get rid of \n character
    img = cv2.imread(fullpath) # H x W x c
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising; fail here with a clear message instead of a
        # "'NoneType' object is not subscriptable" error deep in the loop.
        raise FileNotFoundError('Could not read image: ' + fullpath)

    extrinsics = images[image]
    intrinsics = cameras[extrinsics['c_id']]

    R = extrinsics['R']
    t = extrinsics['t']
    R_inv = R.T  # loop-invariant; camera-to-world rotation
    camera_pos = - R_inv @ t

    f = intrinsics['f']
    cx = intrinsics['cx']
    cy = intrinsics['cy']

    # Radial distortion coefficient (currently unused, see below)
    k = intrinsics['k']

    # Cast one ray through every pixel of the image
    for u in range(intrinsics['W']):
        # Apply inverse intrinsic matrix
        xpp = (u - cx) / f
        for v in range(intrinsics['H']):
            ypp = (v - cy) / f

            # Radial distortion correction (disabled)
            #roots = np.roots([1, 2*k, 1, -(xpp**2 + ypp**2)])
            #r_sq = roots.max().astype(float).item(0)

            #assert(r_sq >= 0)

            # Pixel (u,v) in 3D camera coordinate space
            xp = xpp #/ (1 + k * r_sq)
            yp = ypp #/ (1 + k * r_sq)
            zp = 1.0

            # Pixel (u,v) in 3D world coordinate space
            x = R_inv @ (np.array([xp, yp, zp]).reshape(3,1) - t)

            # Ray direction
            d = x - camera_pos
            d = d / np.linalg.norm(d, axis=0) # Normalize

            # Stored exactly as OpenCV loaded it (cv2.imread yields BGR order)
            color = img[v, u] # size 3
            dataset.append((camera_pos, d, color))
    # Only the first image is processed for now
    break

# Dataset

### TODOs



*   When training NeRF, each data point is a ray cast from a camera origin, and its label is the color of the pixel that the ray passes through.



In [None]:
class NeRFDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items from a pre-built container.

    Args:
        data: sized, indexable container; its items are handed back untouched.
        transforms: kept on the instance as ``self.transforms``; this class
            never applies it.
    """

    def __init__(self, data, transforms=None):
        super().__init__()

        self.transforms = transforms
        self.data = data

    def __len__(self):
        """Return how many items the wrapped container holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at ``idx`` without any processing."""
        item = self.data[idx]
        return item

# Model

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal frequency encoding of each input coordinate.

    Every scalar ``p`` is expanded into ``2 * L`` values laid out as
    ``[sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(L-1) pi p), cos(2^(L-1) pi p)]``
    and the expansions of a row's coordinates are concatenated in order.

    Args:
        L: number of frequency bands.
    """

    def __init__(self, L):
        super(PositionalEncoding, self).__init__()
        self.L = L

    def forward(self, x):
        """Encode a 2-D tensor.

        Args:
            x: tensor of shape (N, D).

        Returns:
            Tensor of shape (N, 2 * L * D) in the default floating dtype,
            located on the same device as ``x``.
        """
        out_dtype = torch.get_default_dtype()
        x = x.to(out_dtype)

        # Frequencies 2^k * pi for k = 0 .. L-1, created on x's device so the
        # result can be fed straight into layers living on that device.
        freqs = (2.0 ** torch.arange(self.L, dtype=out_dtype, device=x.device)) * pi

        angles = x.unsqueeze(-1) * freqs  # (N, D, L)

        # Stacking on a new last axis interleaves sin/cos per frequency, so
        # band k of coordinate i lands at i*2L + 2k (sin) and i*2L + 2k + 1
        # (cos). The previous indexing used i*2L + k and i*2L + k + 1, which
        # made each sin overwrite the preceding cos and left slots at zero.
        encoded = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)  # (N, D, L, 2)

        return encoded.reshape(x.shape[0], 2 * self.L * x.shape[1])


class NeRF(nn.Module):
    """MLP that maps 6-column rows to a density and a color.

    The first three input columns are encoded with ``L1`` frequency bands and
    passed through the trunk MLP; the remaining columns are encoded with
    ``L2`` bands and joined with the trunk features before the color head.

    Args:
        L1: frequency bands for the first three input columns.
        L2: frequency bands for the remaining input columns.
        input_dim: number of coordinates in each encoded group.
        layers: number of (ReLU, Linear) pairs after the input projection.
        feature_dim: width of the trunk.
    """

    def __init__(self, L1, L2, input_dim=3, layers=8, feature_dim=256):
        super().__init__()

        self.pos_enc1 = PositionalEncoding(L1)
        self.pos_enc2 = PositionalEncoding(L2)

        # Trunk: an input projection followed by `layers` (ReLU, Linear) pairs.
        # The last Linear is one unit wider; that extra unit is the density.
        trunk = [nn.Linear(2 * L1 * input_dim, feature_dim)]
        for depth in range(layers):
            is_last = depth == layers - 1
            width_out = feature_dim + 1 if is_last else feature_dim
            trunk.extend((nn.ReLU(), nn.Linear(feature_dim, width_out)))

        self.MLP = nn.Sequential(*trunk)

        # Color head: encoded second group + trunk features -> 128 -> 3
        self.linear = nn.Linear(2 * L2 * input_dim + feature_dim, 128)
        self.output = nn.Linear(128, 3)

    def forward(self, x):
        """Return ``(density, color)`` for a batch of rows.

        ``density`` has one column and is passed through ReLU; ``color`` has
        three columns and is passed through a sigmoid.
        """
        direction_enc = self.pos_enc2(x[:, 3:])
        position_enc = self.pos_enc1(x[:, :3])

        trunk_out = self.MLP(position_enc)
        raw_density = trunk_out[:, :1]
        features = trunk_out[:, 1:]

        joined = torch.cat((direction_enc, features), dim=-1)
        hidden = torch.relu(self.linear(joined))
        raw_color = self.output(hidden)

        return F.relu(raw_density), raw_color.sigmoid()

# Training

### TODOs


*   Currently only a single ray is processed at a time, and both sampling points along the ray and integrating color are executed sequentially. The code needs to be modified to process a batch of rays at a time and to use tensor operations to parallelize the computations. This initial inefficient/slow version was written to better understand the algorithm.



In [None]:
def integrate_color(N, density, color, ts, weights=None):
    """Accumulate the color of one ray from its samples.

    Computes ``C = sum_i T_i * (1 - exp(-sigma_i * delta_i)) * c_i`` with
    ``T_i = exp(-sum_{j<i} sigma_j * delta_j)`` and ``delta_j = ts[j+1] - ts[j]``.
    The last sample has no successor, so it reuses the previous spacing
    ``ts[N-1] - ts[N-2]``.

    Args:
        N: number of samples to integrate.
        density: indexable of per-sample densities (``density[i]``).
        color: indexable of per-sample colors (``color[i]``).
        ts: indexable of sample depths, in the same order as the samples.
        weights: optional list; when given, each sample's weight ``w_i`` is
            appended to it as a Python float.

    Returns:
        The accumulated color tensor (``torch.tensor(0.0)`` when ``N`` is 0).
    """
    C = torch.tensor(0.0)

    # Running sum of sigma_j * delta_j over the samples already visited. The
    # terms are added in the same order as the former inner loop, so the
    # result is unchanged while the work drops from O(N^2) to O(N).
    optical_depth = torch.tensor(0.0)

    for i in range(N):
        T_i = torch.exp(-optical_depth)
        sigma_i = density[i]
        delta_i = ts[i + 1] - ts[i] if i != N - 1 else ts[i] - ts[i - 1]
        c_i = color[i]

        w_i = T_i * (1 - torch.exp(-sigma_i * delta_i))
        if weights is not None:
            weights.append(w_i.item())

        C = C + w_i * c_i

        # Contribution of this sample to the transmittance of later samples.
        optical_depth = optical_depth + sigma_i * delta_i

    return C

def inverse_transform_sampling(X, pdf, n, t_near=None):
    """Draw ``n`` samples from a piecewise-uniform distribution.

    Bin ``j`` spans ``[X[j-1], X[j]]`` (bin 0 spans ``[t_near, X[0]]``) and is
    chosen with probability proportional to ``pdf[j]``; the sample is then
    drawn uniformly inside the chosen bin.

    Args:
        X: upper edges of the bins (one per ``pdf`` entry).
        pdf: non-negative bin weights; they are normalised internally, so they
            need not sum to 1. If their sum is zero or not finite, all bins
            are treated as equally likely.
        n: number of samples to draw.
        t_near: lower edge of the first bin. When omitted, the notebook-level
            variable ``tn`` is used, as before.

    Returns:
        A list of exactly ``n`` floats.

    Raises:
        ValueError: if ``X`` and ``pdf`` are empty or differ in length.
    """
    if t_near is None:
        t_near = tn  # backward-compatible fallback to the notebook global

    upper_edges = np.asarray(X, dtype=float)
    mass = np.asarray(pdf, dtype=float)
    if mass.size == 0 or mass.size != upper_edges.size:
        raise ValueError('X and pdf must be non-empty and of equal length')

    total = mass.sum()
    if not np.isfinite(total) or total <= 0:
        mass = np.ones_like(mass)
        total = mass.sum()

    # Normalised CDF; previously a draw larger than sum(pdf) matched no bin
    # and was silently dropped, returning fewer than n samples.
    cdf = np.cumsum(mass) / total
    edges = np.concatenate(([t_near], upper_edges))

    U = np.random.uniform(size=n)
    # First bin whose cumulative mass exceeds u; zero-mass bins are skipped.
    bins = np.searchsorted(cdf, U, side='right')
    # Guard against rounding leaving cdf[-1] marginally below 1.
    bins = np.minimum(bins, mass.size - 1)

    samples = np.random.uniform(edges[bins], edges[bins + 1])
    return samples.tolist()

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device', device)

coarse_model = NeRF(10, 4)
fine_model = NeRF(10, 4)
optim = torch.optim.Adam(list(coarse_model.parameters()) + list(fine_model.parameters()), lr=0.01)
coarse_model = coarse_model.to(device)
fine_model = fine_model.to(device)

# Points to sample
Nc = 64   # coarse samples per ray
Nf = 128  # additional fine samples per ray
tn = 1    # near bound along the ray
tf = 50   # far bound along the ray


def _ray_points(origin, direction, t_values):
    """Return a (len(t_values), 6) float32 tensor on `device` whose rows are
    [origin + t * direction, direction]."""
    o = np.asarray(origin, dtype=np.float64).reshape(1, 3)
    v = np.asarray(direction, dtype=np.float64).reshape(1, 3)
    t = np.asarray(t_values, dtype=np.float64).reshape(-1, 1)
    rows = np.hstack((o + t * v, np.broadcast_to(v, (t.shape[0], 3))))
    return torch.from_numpy(rows.astype(np.float32)).to(device)


# Nc equal-width bins covering [tn, tf]. The previous bounds used
# (bin - 1) / Nc .. bin / Nc, which placed the first sample in front of tn and
# never sampled the last bin.
bin_edges = np.linspace(tn, tf, Nc + 1)

for idx, ray in enumerate(dataset):
    optim.zero_grad()
    camera_pos, d, ground_truth_color = ray
    ground_truth_color = (torch.from_numpy(ground_truth_color).unsqueeze(0) / 255.0).to(device)

    # Stratified sampling: one uniform draw per bin, so ts is already sorted
    ts = np.random.uniform(bin_edges[:-1], bin_edges[1:]).tolist()
    points = _ray_points(camera_pos, d, ts) # Nc x 6

    density, color = coarse_model(points)

    weights = []
    Cc = integrate_color(Nc, density, color, ts, weights)

    # Use weights as a PDF(Probability Density Function) to inverse transform sample Nf points.
    # A PDF must sum to 1, so normalise by the sum (not the L2 norm); fall back
    # to a uniform PDF when the coarse pass assigns no weight anywhere.
    weights = np.array(weights)
    total = weights.sum()
    weights = weights / total if total > 0 else np.full(Nc, 1.0 / Nc)

    new_ts = inverse_transform_sampling(ts, weights, Nf)

    # Sort the merged depths first and build the points from them, so that
    # row i of the network output really corresponds to all_ts[i].
    all_ts = sorted(ts + list(new_ts))
    all_points = _ray_points(camera_pos, d, all_ts) # (Nc + number of fine samples) x 6

    density, color = fine_model(all_points)
    # Use the actual sample count rather than assuming Nc + Nf
    Cf = integrate_color(len(all_ts), density, color, all_ts)

    loss = F.mse_loss(Cc.unsqueeze(0), ground_truth_color) + F.mse_loss(Cf.unsqueeze(0), ground_truth_color)
    loss.backward()
    optim.step()

    if idx % 10 == 0:
      print('Idx:', idx)
      print('Loss:', loss.item())