In [12]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE: Hierarchical (coarse-to-fine) sampling from the NeRF paper is not implemented; each ray uses a single stratified sampling pass.

In [13]:
# Define the NeRF MLP architecture
class NerfModel(nn.Module):
    """NeRF MLP: maps a 3D position and a viewing direction to an RGB color and a density.

    Layout (hidden width is ``output_dim``):
      * block1: 4 Linear+ReLU layers on the positionally encoded position.
      * block2: 4 Linear layers on [block1 features, encoded position] (a skip connection);
        the last layer has no activation and emits ``output_dim`` features plus one raw
        density value.
      * block3/block4: the features are concatenated with the encoded viewing direction
        and mapped to an RGB color through a Sigmoid.
    """

    def __init__(self, posEncCoord_dim=10, posEncDirection_dim=4, output_dim=128):
        """
        Args:
            posEncCoord_dim: number of frequency bands used to positionally encode the 3D
                position (X); the encoding has posEncCoord_dim * 3 * 2 + 3 entries
            posEncDirection_dim: number of frequency bands used to encode the 3D viewing
                direction (d); the encoding has posEncDirection_dim * 3 * 2 + 3 entries
            output_dim: width of the hidden layers
        """
        super(NerfModel, self).__init__()

        # Widths of the positional encodings of a 3D position / direction (see positional_encoding).
        coord_enc_dim = posEncCoord_dim * 3 * 2 + 3
        direction_enc_dim = posEncDirection_dim * 3 * 2 + 3

        # Position trunk, first half.
        self.block1 = nn.Sequential(nn.Linear(coord_enc_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim), nn.ReLU(), )
        # Density (sigma) estimation: input is [block1 features, encoded position]; the last
        # layer emits output_dim features plus one raw density value.
        self.block2 = nn.Sequential(nn.Linear(coord_enc_dim + output_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim), nn.ReLU(),
                                    nn.Linear(output_dim, output_dim + 1), )
        # Color estimation: features concatenated with the encoded viewing direction.
        self.block3 = nn.Sequential(nn.Linear(direction_enc_dim + output_dim, output_dim // 2), nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(output_dim // 2, 3), nn.Sigmoid(), )

        self.posEncCoord_dim = posEncCoord_dim
        self.posEncDirection_dim = posEncDirection_dim
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
        """
        Encode x as [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)].

        Args:
            x: tensor whose last dimension holds the values to encode (e.g. [batch_size, 3])
            L: number of frequency bands; frequencies 2**0 ... 2**(L-1) are used

        Returns:
            Tensor with the same leading dimensions as x and a last dimension of
            x.shape[-1] * (2 * L + 1), i.e. L * 3 * 2 + 3 for 3D input. The raw x
            occupies the first x.shape[-1] entries.
        """
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=-1)

    def forward(self, o, d):
        """
        Predict the color and density of each (position, direction) pair.

        Args:
            o: [..., 3] 3D sample locations
            d: [..., 3] viewing directions (described as unit vectors; not normalized here)

        Returns:
            c: [..., 3] RGB values from a Sigmoid, so within [0, 1]
            sigma: [...] densities, made non-negative by a ReLU
        """
        emb_x = self.positional_encoding(o, self.posEncCoord_dim)  # [..., posEncCoord_dim * 6 + 3]
        emb_d = self.positional_encoding(d, self.posEncDirection_dim)  # [..., posEncDirection_dim * 6 + 3]
        features = self.block1(emb_x)  # [..., output_dim]
        # Skip connection: the encoded position is fed in again halfway through the trunk.
        tmp = self.block2(torch.cat((features, emb_x), dim=-1))  # [..., output_dim + 1]
        # The last channel is the raw density; the rest are features for the color head.
        features, sigma = tmp[..., :-1], self.relu(tmp[..., -1])  # [..., output_dim], [...]
        features = self.block3(torch.cat((features, emb_d), dim=-1))  # [..., output_dim // 2]
        c = self.block4(features)  # [..., 3]
        return c, sigma

In [14]:
def compute_accumulated_transmittance(alphas):
    """
    Exclusive cumulative product along the last dimension.

    Despite the parameter name, ``render_rays`` passes ``1 - alpha`` here, i.e. the
    fraction of light that survives each segment. For inputs a_0 .. a_{N-1} the result is
    [1, a_0, a_0*a_1, ..., a_0*...*a_{N-2}]: entry i is the transmittance accumulated over
    all segments strictly before sample i.

    Args:
        alphas: per-segment transmittances, shape [..., N] (e.g. [batch_size, nb_bins])

    Returns:
        Tensor with the same shape, dtype and device as ``alphas``.
    """
    accumulated_transmittance = torch.cumprod(alphas, dim=-1)
    # Shift right by one so each ray starts at full transmittance (1).
    return torch.cat((torch.ones_like(accumulated_transmittance[..., :1]),
                      accumulated_transmittance[..., :-1]), dim=-1)

In [15]:
def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192, perturb=True):
    """
    Volume-render a batch of rays with a NeRF model.

    Args:
        nerf_model: callable taking (points [N, 3], directions [N, 3]) and returning
            (colors [N, 3], sigma [N])
        ray_origins: [batch_size, 3] ray origins
        ray_directions: [batch_size, 3] ray directions
        hn: near bound of the sampled depth interval
        hf: far bound of the sampled depth interval
        nb_bins: number of samples per ray
        perturb: if True (default), jitter each depth uniformly within its bin (stratified
            sampling); if False, use the evenly spaced depths, which makes the render
            deterministic (useful for evaluation).

    Returns:
        [batch_size, 3] pixel colors composited over a white background.
    """
    device = ray_origins.device
    batch_size = ray_origins.shape[0]

    # Evenly spaced depths between the near (hn) and far (hf) bounds, shared by every ray.
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(batch_size, nb_bins)

    if perturb:
        # Stratified sampling: draw one depth uniformly inside each bin so the model is not
        # always queried at the same fixed grid of depths.
        mid = (t[:, :-1] + t[:, 1:]) / 2.
        lower = torch.cat((t[:, :1], mid), -1)
        upper = torch.cat((mid, t[:, -1:]), -1)
        u = torch.rand(t.shape, device=device)
        t = lower + (upper - lower) * u  # [batch_size, nb_bins]

    # Spacing between consecutive samples; the last sample gets an effectively infinite
    # interval (1e10).
    # NOTE(review): delta is measured in units of t and is not multiplied by
    # ||ray_directions||, so this presumably assumes unit-length directions — confirm.
    delta = torch.cat((t[:, 1:] - t[:, :-1], t.new_full((batch_size, 1), 1e10)), -1)

    # 3D sample positions: origin + depth * direction.
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    # Every sample along a ray shares that ray's viewing direction.
    sample_directions = ray_directions.unsqueeze(1).expand(-1, nb_bins, -1)  # [batch_size, nb_bins, 3]

    # Query the model on the flattened samples, then restore the per-ray layout.
    colors, sigma = nerf_model(x.reshape(-1, 3), sample_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)  # [batch_size, nb_bins, 3]
    sigma = sigma.reshape(x.shape[:-1])  # [batch_size, nb_bins]

    # Per-segment opacity, then compositing weight T_i * alpha_i, where T_i is the product
    # of (1 - alpha_j) over the preceding samples.
    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)  # [batch_size, nb_bins, 1]

    # Pixel color is the weighted sum of the sample colors along each ray.
    c = (weights * colors).sum(dim=1)  # [batch_size, 3]
    weight_sum = weights.sum(dim=(1, 2))  # [batch_size]

    # White background: the fraction of light not absorbed along the ray shows as white.
    return c + 1 - weight_sum.unsqueeze(-1)

In [16]:
def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
          nb_bins=192, H=400, W=400, test_dataset=None, nb_test_images=200):
    """
    Optimize ``nerf_model`` on batches of rays and render test views after every epoch.

    Args:
        nerf_model: model passed to ``render_rays``
        optimizer: optimizer over the model's parameters
        scheduler: learning-rate scheduler, stepped once per epoch
        data_loader: yields batches whose columns 0:3 are ray origins, 3:6 ray directions
            and 6: the ground-truth pixel values
        device: device each batch is moved to before rendering
        hn: near sampling bound, forwarded to ``render_rays`` and ``test``
        hf: far sampling bound, forwarded to ``render_rays`` and ``test``
        nb_epochs: number of passes over ``data_loader``
        nb_bins: samples per ray
        H: test image height, forwarded to ``test``
        W: test image width, forwarded to ``test``
        test_dataset: rays rendered after each epoch; when None, the module-level
            ``testing_dataset`` is used
        nb_test_images: maximum number of test images rendered per epoch; capped at the
            number of complete H*W images in the test dataset

    Returns:
        List of per-batch losses (sum of squared errors over the batch).
    """
    if test_dataset is None:
        test_dataset = testing_dataset
    # Never index past the last complete image: a partial slice cannot be reshaped to H x W.
    nb_test_images = min(nb_test_images, len(test_dataset) // (H * W))

    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        for batch in data_loader:
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()

        # NOTE: test() is not given nerf_model/device here; it renders with whatever model
        # and device it resolves on its own.
        for img_index in range(nb_test_images):
            test(hn, hf, test_dataset, img_index=img_index, nb_bins=nb_bins, H=H, W=W)
    return training_loss

In [20]:
if __name__ == '__main__':
    # Use the GPU when one is available; `test` reads this module-level `device` as well.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)  # reproducible weight init, ray shuffling and depth jitter

    # NOTE(review): allow_pickle=True lets np.load unpickle the file, which can execute
    # arbitrary code — only load these files from a trusted source.
    training_dataset = torch.from_numpy(np.load('training_data.pkl', allow_pickle=True))
    testing_dataset = torch.from_numpy(np.load('testing_data.pkl', allow_pickle=True))

    model = NerfModel(output_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    # Halve the learning rate after epochs 2, 4 and 8 (train() steps the scheduler once per epoch).
    scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
    data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True, pin_memory=False)

    # NOTE: train() calls test() after every epoch, and test() is defined in the next cell;
    # on a fresh kernel, run that cell before this one.
    training_loss = train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2, hf=6,
                          nb_bins=192, H=400, W=400)

In [21]:
@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400,
         nerf_model=None, output_dir='novel_views'):
    """
    Render one H x W image from ``dataset`` and save it as ``<output_dir>/img_<img_index>.png``.

    Args:
        hn: near plane distance
        hf: far plane distance
        dataset: ray tensor; rows img_index*H*W up to (img_index+1)*H*W hold this image's
            rays, with columns 0:3 the origins and 3:6 the directions
        chunk_size (int, optional): number of image rows rendered per forward pass, to
            bound memory use. Defaults to 10.
        img_index (int, optional): image index to render. Defaults to 0.
        nb_bins (int, optional): number of samples per ray. Defaults to 192.
        H (int, optional): image height. Defaults to 400.
        W (int, optional): image width. Defaults to 400.
        nerf_model (optional): model to render with; when None the module-level ``model``
            is used. Rays are moved to the module-level ``device``, so the model is
            expected to live there.
        output_dir (str, optional): directory for the PNG; created if missing.
            Defaults to 'novel_views'.

    Returns:
        None: None

    Raises:
        ValueError: if ``dataset`` does not contain a complete image at ``img_index``.
    """
    renderer = model if nerf_model is None else nerf_model

    first_ray = img_index * H * W
    image_rays = dataset[first_ray: first_ray + H * W]
    if image_rays.shape[0] != H * W:
        raise ValueError(f'dataset has no complete {H}x{W} image at index {img_index}')
    ray_origins = image_rays[:, :3]
    ray_directions = image_rays[:, 3:6]

    rays_per_chunk = chunk_size * W  # chunk_size full image rows
    data = []  # regenerated pixel values, one tensor per chunk
    for i in range(int(np.ceil(H / chunk_size))):
        chunk = slice(i * rays_per_chunk, (i + 1) * rays_per_chunk)
        regenerated_px_values = render_rays(renderer, ray_origins[chunk].to(device),
                                            ray_directions[chunk].to(device),
                                            hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values)
    img = torch.cat(data).cpu().numpy().reshape(H, W, 3)

    os.makedirs(output_dir, exist_ok=True)
    fig, ax = plt.subplots()
    ax.imshow(img)
    fig.savefig(os.path.join(output_dir, f'img_{img_index}.png'), bbox_inches='tight')
    plt.close(fig)  # free the figure; this runs once per test image every epoch
