# 02 - Core Components

This notebook defines the core NeRF architecture and rendering functions.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## Ray Generation

In [None]:
def get_rays(H, W, focal, c2w):
    """Generate one camera ray per pixel of an ``H`` x ``W`` image.

    Camera-space directions are built with +x to the right, +y up and the
    camera looking along -z, then rotated by the upper-left 3x3 block of
    ``c2w``. Directions are NOT normalised.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world tensor; ``c2w[:3, :3]`` is used as the rotation
            and ``c2w[:3, 3]`` as the ray origin.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(H, W, 3)``. ``rays_o`` is an
        expanded (broadcast) view of the single camera origin.
    """
    # Build the pixel grid on the pose's own device rather than on a
    # notebook-level global, so the function works wherever c2w lives.
    grid_device = c2w.device
    i, j = torch.meshgrid(
        torch.arange(W, device=grid_device),
        torch.arange(H, device=grid_device),
        indexing='xy'
    )
    x = (i - W*0.5)/focal
    y = -(j - H*0.5)/focal
    # ones_like(x) keeps all three components in the same floating dtype.
    dirs = torch.stack([x, y, -torch.ones_like(x)], -1)
    # Row-wise dot products with the rotation, i.e. R @ dir for every pixel.
    rays_d = (dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    rays_o = c2w[:3, 3].expand(rays_d.shape)
    return rays_o, rays_d

## Positional Encoding

In [None]:
class PosEnc(nn.Module):
    """Sinusoidal positional encoding.

    For each frequency ``f`` the input is mapped to ``sin(f * x)`` and
    ``cos(f * x)``; the results are concatenated along the last dimension as
    ``[x, sin(f0 x), cos(f0 x), sin(f1 x), cos(f1 x), ...]`` (the leading
    ``x`` only when ``include_input`` is true).

    Args:
        n_freqs: Number of frequency bands.
        include_input: Whether to prepend the raw input to the encoding.
        log_sampling: If true, frequencies are ``2**0 .. 2**(n_freqs-1)``;
            otherwise they are spaced linearly between the same end points.
    """

    def __init__(self, n_freqs=10, include_input=True, log_sampling=True):
        super().__init__()
        self.include_input = include_input
        if log_sampling:
            freq_bands = 2.**torch.linspace(0, n_freqs-1, n_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**(n_freqs-1), n_freqs)
        # Registered as a buffer so that ``.to(device)`` actually moves it;
        # non-persistent so state_dict contents are unchanged.
        self.register_buffer("freq_bands", freq_bands, persistent=False)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., C)`` into ``(..., C * (2 * n_freqs [+ 1]))``."""
        out = [x] if self.include_input else []
        # No-op when the module already lives on x's device; kept so inputs on
        # another device are still accepted.
        freq_bands = self.freq_bands.to(x.device)
        for f in freq_bands:
            out += [torch.sin(f * x), torch.cos(f * x)]
        return torch.cat(out, -1)

# Shared encoders: 10 frequency bands for sample positions, 4 for view directions.
pe_xyz, pe_dir = (PosEnc(n_freqs=n_bands).to(device) for n_bands in (10, 4))

## NeRF Model

In [None]:
class NeRF(nn.Module):
    """MLP mapping encoded positions and view directions to colour and density.

    Args:
        D: Number of layers in the point MLP (at least one layer is always built).
        W: Hidden width of the point MLP; the direction branch uses ``W // 2``.
        in_ch_xyz: Size of the encoded position vector.
        in_ch_dir: Size of the encoded view-direction vector.
        skips: Indices of point-MLP layers whose input is concatenated with
            the encoded position before the layer is applied.
    """

    def __init__(self, D=8, W=256, in_ch_xyz=3*2*10+3, in_ch_dir=3*2*4+3, skips=(4,)):
        super().__init__()
        self.skips = set(skips)

        # Point MLP with skip connections. Layer widths mirror forward(): a
        # layer listed in ``skips`` receives its usual input plus the encoded
        # position, and that includes layer 0.
        self.pts_linears = nn.ModuleList()
        first_in = in_ch_xyz + (in_ch_xyz if 0 in self.skips else 0)
        self.pts_linears.append(nn.Linear(first_in, W))
        for i in range(1, D):
            in_features = W + in_ch_xyz if i in self.skips else W
            self.pts_linears.append(nn.Linear(in_features, W))

        # Heads
        self.sigma_linear = nn.Linear(W, 1)
        self.feature_linear = nn.Linear(W, W)

        # View-direction branch
        self.dir_linears = nn.ModuleList([nn.Linear(W + in_ch_dir, W // 2)])
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x, d):
        """Evaluate the field.

        Args:
            x: Encoded positions, shape ``(..., in_ch_xyz)``.
            d: Encoded view directions, shape ``(..., in_ch_dir)``.

        Returns:
            ``(rgb, sigma)`` with shapes ``(..., 3)`` and ``(..., 1)``; ``rgb``
            passes through a sigmoid and ``sigma`` through a softplus.
        """
        h = x
        for i, layer in enumerate(self.pts_linears):
            if i in self.skips:
                # Re-inject the encoded position ahead of this layer.
                h = torch.cat([h, x], dim=-1)
            h = F.relu(layer(h))

        # Density depends on position only.
        sigma = F.softplus(self.sigma_linear(h))
        feat = self.feature_linear(h)

        # Colour additionally depends on the viewing direction.
        h_dir = torch.cat([feat, d], dim=-1)
        for layer in self.dir_linears:
            h_dir = F.relu(layer(h_dir))
        rgb = torch.sigmoid(self.rgb_linear(h_dir))

        return rgb, sigma

# Build the default network, move it to the selected device, and report its size.
model = NeRF()
model.to(device)  # nn.Module.to moves parameters in place
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Volume Rendering

In [None]:
def render_rays(model, rays_o, rays_d, near=2.0, far=6.0, n_samples=64, perturb=True, white_bkgd=True):
    """Render colour and depth along rays by sampling ``model`` and alpha-compositing.

    Points and directions are encoded with the module-level ``pe_xyz`` and
    ``pe_dir`` encoders before being passed to ``model``.

    Args:
        model: Callable ``model(x, d) -> (rgb, sigma)`` evaluated on flattened
            ``(N * n_samples, C)`` encodings.
        rays_o: Ray origins, shape ``(..., 3)``.
        rays_d: Ray directions, shape ``(..., 3)`` with the same leading shape
            as ``rays_o``; they are not required to be unit length.
        near: Ray parameter of the first sample.
        far: Ray parameter of the last sample.
        n_samples: Number of samples per ray.
        perturb: If true, jitter each sample uniformly inside its stratum.
        white_bkgd: If true, composite the result onto a white background.

    Returns:
        ``(rgb_map, depth_map)`` with shapes ``(..., 3)`` and ``(...)``;
        ``rgb_map`` is clamped to ``[0, 1]``.
    """
    # Accept image-shaped ray bundles (e.g. the (H, W, 3) output of get_rays)
    # by flattening to (N, 3) here and restoring the leading shape on return.
    lead_shape = rays_d.shape[:-1]
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    # Stratified sampling
    t_vals = torch.linspace(0., 1., steps=n_samples, device=rays_o.device)
    z_vals = near * (1. - t_vals) + far * t_vals
    z_vals = z_vals[None, :].repeat(rays_o.shape[0], 1)

    if perturb:
        # Each sample is redrawn uniformly between the midpoints to its neighbours.
        mids = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
        upper = torch.cat([mids, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], mids], -1)
        z_vals = lower + (upper - lower) * torch.rand_like(z_vals)

    # Sample points
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., None]

    # Encode
    dirs = F.normalize(rays_d, dim=-1)
    dirs_enc = pe_dir(dirs)
    pts_enc = pe_xyz(pts)

    # Forward pass
    N, S = pts_enc.shape[:2]
    x = pts_enc.reshape(N*S, -1)
    d = dirs_enc[:, None, :].expand(N, S, -1).reshape(N*S, -1)

    rgb, sigma = model(x, d)
    rgb = rgb.view(N, S, 3)
    sigma = sigma.view(N, S)

    # Volume rendering
    # NOTE(review): deltas are in ray-parameter units. rays_d is not normalised
    # before sampling, so they are not scaled by ||rays_d|| (the reference NeRF
    # implementation applies that factor). Left as-is to preserve current outputs.
    deltas = z_vals[:, 1:] - z_vals[:, :-1]
    # The last interval is treated as effectively infinite.
    deltas = torch.cat([deltas, torch.full_like(deltas[:, :1], 1e10)], -1)
    alpha = 1. - torch.exp(-sigma * deltas)
    # Exclusive cumulative product: transmittance reaching each sample.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[:, :1]), 1. - alpha + 1e-10], -1),
        -1
    )[:, :-1]
    weights = alpha * transmittance

    rgb_map = (weights[..., None] * rgb).sum(dim=1)
    depth_map = (weights * z_vals).sum(dim=1)

    if white_bkgd:
        acc_map = weights.sum(dim=1, keepdim=True)
        rgb_map = rgb_map + (1. - acc_map)

    rgb_map = rgb_map.clamp(0, 1)
    if len(lead_shape) != 1:
        # Restore the caller's leading batch shape; 2-D inputs are returned as before.
        rgb_map = rgb_map.reshape(*lead_shape, 3)
        depth_map = depth_map.reshape(lead_shape)
    return rgb_map, depth_map

# Final status line; the emoji was previously stored as mis-decoded UTF-8 bytes.
print("✅ Core components ready!")