# NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

*A deep dive into implementing NeRF from scratch. We start by setting up the environment and then work through the mathematical foundations of ray marching and volume rendering.*

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class NeRFNetwork(nn.Module):
    """MLP mapping a 3D sample position and a view direction to raw (color, density).

    A trunk of ``depth`` Linear+ReLU layers processes the position. A linear
    head produces the density, and a small MLP on ``[trunk features,
    direction]`` produces the color.

    Both outputs are returned *without* activation. The renderer is
    expected to apply the appropriate non-linearities.

    Args:
        input_ch: feature size of the position input ``x`` (3 for raw xyz;
            larger if a positional encoding is applied beforehand).
        input_ch_views: feature size of the direction input ``d``.
        width: hidden width of the position trunk; the color MLP uses
            ``width // 2``.
        depth: number of Linear+ReLU layers in the position trunk (>= 1).

    The defaults reproduce the original 8-layer, 256-wide architecture.
    The parameter names (``density_mlp.0``, ``density_mlp.2``, ...) are
    identical.
    """

    def __init__(self, input_ch=3, input_ch_views=3, width=256, depth=8):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        layers = []
        in_features = input_ch
        for _ in range(depth):
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        self.density_mlp = nn.Sequential(*layers)

        self.density_head = nn.Linear(width, 1)  # sigma (raw, no activation)
        self.color_mlp = nn.Sequential(
            nn.Linear(width + input_ch_views, width // 2),
            nn.ReLU(),
            nn.Linear(width // 2, 3),  # RGB (raw, no activation)
        )

    def forward(self, x, d):
        """Evaluate the network.

        Args:
            x: ``[..., input_ch]`` sample positions.
            d: ``[..., input_ch_views]`` view directions. Leading dims must
                match ``x`` (they are concatenated along the last axis).

        Returns:
            ``(color, sigma)`` with shapes ``[..., 3]`` and ``[..., 1]``.
        """
        h = self.density_mlp(x)
        sigma = self.density_head(h)
        color = self.color_mlp(torch.cat([h, d], dim=-1))
        return color, sigma

In [None]:
def volume_render(raw, z_vals, rays_d, white_bkgd=True):
    """Composite per-sample raw network outputs along each ray into a color.

    Discrete volume-rendering quadrature:
        alpha_i = 1 - exp(-sigma_i * delta_i)
        T_i     = prod_{j<i} (1 - alpha_j)
        C       = sum_i T_i * alpha_i * c_i

    Args:
        raw: ``[..., N_samples, 4]`` raw (pre-activation) outputs. Channels
            0-2 are RGB logits (sigmoid is applied here); channel 3 is the
            density (ReLU is applied here).
        z_vals: ``[..., N_samples]`` sample depths along each ray.
        rays_d: ``[..., 3]`` ray directions; their norm scales the depth
            intervals into world-space distances.
        white_bkgd: if True (the default), composite onto a white
            background by adding the un-absorbed fraction ``1 - acc``.

    Returns:
        rgb_map: ``[..., 3]`` rendered colors.
        weights: ``[..., N_samples]`` per-sample weights ``T_i * alpha_i``.
    """
    # Interval between adjacent samples; the last one is treated as
    # (effectively) infinite. Built from z_vals so it also works when
    # there is only one sample, and so it matches dtype/device.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(z_vals[..., :1], 1e10)], -1)
    # Scale by the ray-direction length to get world-space distances.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Activate the raw outputs: colors into (0, 1), densities >= 0.
    rgb = torch.sigmoid(raw[..., :3])
    sigma = F.relu(raw[..., 3])

    alpha = 1.0 - torch.exp(-sigma * dists)

    # Exclusive cumulative product gives the transmittance before each
    # sample. The small epsilon keeps each factor strictly positive.
    ones = torch.ones_like(alpha[..., :1])
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], -1), -1)[..., :-1]

    weights = trans * alpha
    rgb_map = torch.sum(weights[..., None] * rgb, -2)

    if white_bkgd:
        acc_map = torch.sum(weights, -1)
        rgb_map = rgb_map + (1.0 - acc_map[..., None])

    return rgb_map, weights

In [7]:
def sample_pdf(bins, weights, N_fine=128, noise=True):
    """Sample N_fine points from the piecewise-constant PDF defined by weights.

    This is inverse-transform sampling: it normalizes ``weights`` into a PDF
    and accumulates it into a CDF over the bin edges. Each query ``u`` in
    [0, 1] is located in the CDF and mapped back to a depth by linear
    interpolation within its bin.

    Args:
        bins: ``[..., N_bins]`` bin edges, assumed sorted ascending (e.g.
            mid-points of the coarse z_vals, ``N_bins = N_coarse - 1``).
        weights: ``[..., N_bins - 1]`` non-negative weight per bin (e.g.
            coarse weights with the first and last entries dropped).
        N_fine: number of samples drawn per ray.
        noise: if True, queries are uniform random; otherwise they are
            evenly spaced over [0, 1] (deterministic).

    Returns:
        ``[..., N_fine]`` sampled depths within ``[bins[..., 0], bins[..., -1]]``.

    Raises:
        ValueError: if ``bins`` does not have exactly one more entry than
            ``weights`` along the last axis.
    """
    if bins.shape[-1] != weights.shape[-1] + 1:
        raise ValueError(
            f"bins must have one more entry than weights along the last axis, "
            f"got {bins.shape[-1]} and {weights.shape[-1]}"
        )

    # Small constant keeps the PDF well-defined when all weights are zero.
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    # Prepend 0 so there is one CDF value per bin edge: [..., N_bins].
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Queries u in [0, 1], one set of N_fine per ray, matching cdf's dtype
    # (searchsorted needs the same dtype on both sides).
    query_shape = list(cdf.shape[:-1]) + [N_fine]
    if noise:
        u = torch.rand(query_shape, device=cdf.device, dtype=cdf.dtype)
    else:
        u = torch.linspace(0.0, 1.0, steps=N_fine, device=cdf.device, dtype=cdf.dtype)
        u = u.expand(query_shape)
    u = u.contiguous()

    # Locate the CDF interval [below, above] that contains each query.
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # [..., N_fine, 2]

    # Gather CDF values and bin edges at both ends of each interval.
    gather_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(gather_shape), -1, inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(gather_shape), -1, inds_g)

    # Linear interpolation inside the interval; near-zero-width CDF
    # intervals use denominator 1 to avoid dividing by ~0.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])