# PyTorch Tutorial: Neural Radiance Fields (NeRF)

Neural Radiance Fields (NeRF) represent 3D scenes as neural networks that map a 3D position and viewing direction to color and density. Given a set of 2D images with known camera poses, NeRF can synthesize novel views of a scene.

## Learning Objectives
- Understand implicit neural representations
- Implement positional encoding (Fourier features)
- Build a NeRF MLP architecture
- Understand volume rendering
- Learn about ray marching and sampling strategies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional

torch.manual_seed(42)
np.random.seed(42)  # The voxel demo below draws from NumPy's RNG
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. The Problem: 3D Scene Representation

Traditional 3D representations:
- **Meshes**: Vertices + faces (discrete, hard to optimize)
- **Voxels**: 3D grids (memory intensive)
- **Point clouds**: Unstructured points (sparse)

**NeRF's approach**: Represent the scene as a continuous function:
$$F: (x, y, z, \theta, \phi) \rightarrow (r, g, b, \sigma)$$

Where:
- $(x, y, z)$: 3D position
- $(\theta, \phi)$: Viewing direction (in practice passed as a 3D unit vector)
- $(r, g, b)$: Color
- $\sigma$: Volume density (how strongly the point blocks light; higher means more opaque)
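
The formula writes the viewing direction as two angles, while the code in this tutorial passes it as a 3D unit vector. A minimal sketch of the conversion, assuming the common convention that $\theta$ is the polar angle from the +z axis and $\phi$ the azimuth from the +x axis; the helper name `angles_to_unit_vector` is hypothetical:

```python
import numpy as np

def angles_to_unit_vector(theta: float, phi: float) -> np.ndarray:
    """Convert polar angle theta and azimuth phi (radians) to a 3D unit vector.

    Assumed convention: theta is measured from the +z axis, phi from the +x axis.
    """
    return np.array([
        np.sin(theta) * np.cos(phi),
        np.sin(theta) * np.sin(phi),
        np.cos(theta),
    ])

d = angles_to_unit_vector(np.pi / 2, 0.0)  # points along +x
print(d, np.linalg.norm(d))
```

Both forms carry two degrees of freedom; the unit vector avoids the wrap-around discontinuity that angles have.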

In [None]:
# Visualize the concept
fig = plt.figure(figsize=(12, 4))

# Traditional representations
ax1 = fig.add_subplot(131, projection='3d')
# Simple mesh (cube)
vertices = np.array([[0,0,0], [1,0,0], [1,1,0], [0,1,0],
                     [0,0,1], [1,0,1], [1,1,1], [0,1,1]])
for i in range(4):
    ax1.plot3D(*zip(vertices[i], vertices[(i+1)%4]), 'b-')
    ax1.plot3D(*zip(vertices[i+4], vertices[(i+1)%4+4]), 'b-')
    ax1.plot3D(*zip(vertices[i], vertices[i+4]), 'b-')
ax1.set_title('Mesh')

# Voxels
ax2 = fig.add_subplot(132, projection='3d')
voxels = np.random.rand(5, 5, 5) > 0.7
ax2.voxels(voxels, alpha=0.5)
ax2.set_title('Voxels')

# NeRF concept (continuous)
ax3 = fig.add_subplot(133, projection='3d')
u = np.linspace(0, 2*np.pi, 50)
v = np.linspace(0, np.pi, 25)
x = np.outer(np.cos(u), np.sin(v))
y = np.outer(np.sin(u), np.sin(v))
z = np.outer(np.ones(np.size(u)), np.cos(v))
ax3.plot_surface(x, y, z, alpha=0.7, cmap='viridis')
ax3.set_title('NeRF (Continuous)')

plt.tight_layout()
plt.show()

print("NeRF represents scenes as continuous functions, enabling:")
print("- Arbitrary resolution rendering")
print("- Smooth interpolation between views")
print("- Compact representation (just network weights)")

## 2. Positional Encoding

MLPs struggle to learn high-frequency functions. **Positional encoding** maps inputs to higher dimensions using sinusoidal functions:

$$\gamma(p) = [\sin(2^0\pi p), \cos(2^0\pi p), ..., \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)]$$

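Without this encoding, the MLP tends to fit only low-frequency variation and renders come out blurry; with it, fine geometry and texture become learnable.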
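
To make the formula concrete, here is a small worked example for a scalar input. It is a sketch only: the function name `fourier_features` is hypothetical and it uses NumPy rather than the `nn.Module` defined in the next cell.

```python
import numpy as np

def fourier_features(p: float, num_frequencies: int) -> np.ndarray:
    """Return [sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(L-1) pi p), cos(2^(L-1) pi p)]."""
    features = []
    for k in range(num_frequencies):
        angle = (2.0 ** k) * np.pi * p
        features.extend([np.sin(angle), np.cos(angle)])
    return np.array(features)

# p = 0.5, L = 2: angles are pi/2 and pi
gamma = fourier_features(0.5, num_frequencies=2)
print(np.round(gamma, 6))  # approximately [1, 0, 0, -1]
```

Each scalar becomes $2L$ numbers, so a 3D position with $L = 10$ becomes 60 features (63 if the raw input is appended).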

In [None]:
class PositionalEncoding(nn.Module):
    """
    Fourier feature positional encoding.
    Maps low-dimensional input to higher dimensions using sinusoids.
    """
    
    def __init__(self, num_frequencies: int = 10, include_input: bool = True):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.include_input = include_input
        
        # Frequency bands: 2^0, 2^1, ..., 2^(L-1)
        self.register_buffer(
            'frequency_bands',
            2.0 ** torch.linspace(0, num_frequencies - 1, num_frequencies)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [..., input_dim]
        Returns: [..., input_dim * (2 * num_frequencies + include_input)]
        """
        encodings = []
        
        if self.include_input:
            encodings.append(x)
        
        for freq in self.frequency_bands:
            encodings.append(torch.sin(freq * np.pi * x))
            encodings.append(torch.cos(freq * np.pi * x))
        
        return torch.cat(encodings, dim=-1)
    
    def output_dim(self, input_dim: int) -> int:
        return input_dim * (2 * self.num_frequencies + int(self.include_input))


# Demonstrate why positional encoding helps
def visualize_positional_encoding():
    pe = PositionalEncoding(num_frequencies=6)
    
    # 1D input
    x = torch.linspace(-1, 1, 200).unsqueeze(-1)
    encoded = pe(x)
    
    fig, axes = plt.subplots(2, 1, figsize=(12, 6))
    
    # Original input
    axes[0].plot(x.numpy())
    axes[0].set_title('Original Input (1D)')
    axes[0].set_ylabel('Value')
    
    # Encoded (show first few dimensions)
    for i in range(min(8, encoded.shape[-1])):
        axes[1].plot(encoded[:, i].numpy(), alpha=0.7, label=f'dim {i}')
    axes[1].set_title('Positional Encoding (First 8 Dimensions)')
    axes[1].set_ylabel('Value')
    axes[1].legend(loc='upper right')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Input dimension: 1")
    print(f"Output dimension: {encoded.shape[-1]}")

visualize_positional_encoding()

## 3. NeRF MLP Architecture

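The NeRF network architecture:
1. Encode position $(x, y, z)$ with positional encoding
2. Pass through MLP layers (with a skip connection that re-injects the encoded position)
3. Output density $\sigma$ from position alone
4. Concatenate a feature vector with the encoded viewing direction
5. Output color $(r, g, b)$ through a small view-dependent head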
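
Before reading the network code, it can help to work out the layer widths by hand. The sketch below assumes the default settings used in the next cell (10 position frequencies, 4 direction frequencies, hidden width 256, raw input appended); `encoded_dim` is a hypothetical helper that mirrors `PositionalEncoding.output_dim`.

```python
def encoded_dim(input_dim: int, num_frequencies: int, include_input: bool = True) -> int:
    """Size of a sinusoidal encoding: sin and cos per frequency, plus the raw input."""
    return input_dim * (2 * num_frequencies + int(include_input))

pos_dim = encoded_dim(3, num_frequencies=10)  # 3 * (2*10 + 1) = 63
dir_dim = encoded_dim(3, num_frequencies=4)   # 3 * (2*4 + 1) = 27
hidden_dim = 256

# Input width of the layer that receives the skip connection
skip_layer_in = hidden_dim + pos_dim
# Input width of the first view-dependent color layer
color_layer_in = hidden_dim + dir_dim

print(pos_dim, dir_dim, skip_layer_in, color_layer_in)
```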

In [None]:
class NeRF(nn.Module):
    """
    Neural Radiance Field network.
    
    Input: 3D position (x, y, z) and viewing direction (3D unit vector)
    Output: RGB color and volume density (sigma)
    """
    
    def __init__(
        self,
        pos_encoding_freqs: int = 10,
        dir_encoding_freqs: int = 4,
        hidden_dim: int = 256,
        num_layers: int = 8,
        skip_connection: int = 4,  # Layer to add skip connection
    ):
        super().__init__()
        
        self.pos_encoder = PositionalEncoding(pos_encoding_freqs)
        self.dir_encoder = PositionalEncoding(dir_encoding_freqs)
        
        pos_dim = self.pos_encoder.output_dim(3)  # 3D position
        dir_dim = self.dir_encoder.output_dim(3)  # 3D direction (unit vector)
        
        self.skip_connection = skip_connection
        
        # Position-dependent layers
        self.pos_layers = nn.ModuleList()
        self.pos_layers.append(nn.Linear(pos_dim, hidden_dim))
        
        for i in range(1, num_layers):
            if i == skip_connection:
                # Skip connection: concat original positional encoding
                self.pos_layers.append(nn.Linear(hidden_dim + pos_dim, hidden_dim))
            else:
                self.pos_layers.append(nn.Linear(hidden_dim, hidden_dim))
        
        # Density output (no view dependence)
        self.density_layer = nn.Linear(hidden_dim, 1)
        
        # Feature vector for color prediction
        self.feature_layer = nn.Linear(hidden_dim, hidden_dim)
        
        # View-dependent color layers
        self.color_layer1 = nn.Linear(hidden_dim + dir_dim, hidden_dim // 2)
        self.color_layer2 = nn.Linear(hidden_dim // 2, 3)
        
        self.pos_dim = pos_dim
    
    def forward(self, pos: torch.Tensor, direction: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        pos: [..., 3] - 3D positions
        direction: [..., 3] - viewing directions (unit vectors)
        
        Returns:
            rgb: [..., 3] - colors
            sigma: [..., 1] - densities
        """
        # Encode inputs
        pos_encoded = self.pos_encoder(pos)
        dir_encoded = self.dir_encoder(direction)
        
        # Position network with skip connection
        h = pos_encoded
        for i, layer in enumerate(self.pos_layers):
            if i == self.skip_connection:
                h = torch.cat([h, pos_encoded], dim=-1)
            h = F.relu(layer(h))
        
        # Density (ReLU ensures non-negative)
        sigma = F.relu(self.density_layer(h))
        
        # Feature for color
        feature = self.feature_layer(h)
        
        # View-dependent color
        h = torch.cat([feature, dir_encoded], dim=-1)
        h = F.relu(self.color_layer1(h))
        rgb = torch.sigmoid(self.color_layer2(h))  # [0, 1] for colors
        
        return rgb, sigma


# Test the network
nerf = NeRF().to(device)
test_pos = torch.randn(100, 3).to(device)  # 100 random 3D positions
test_dir = F.normalize(torch.randn(100, 3), dim=-1).to(device)  # Unit directions

rgb, sigma = nerf(test_pos, test_dir)
print(f"Input position shape: {test_pos.shape}")
print(f"Input direction shape: {test_dir.shape}")
print(f"Output RGB shape: {rgb.shape}")
print(f"Output sigma shape: {sigma.shape}")
print(f"\nModel parameters: {sum(p.numel() for p in nerf.parameters()):,}")

## 4. Volume Rendering

To render an image, we cast rays through each pixel and integrate color along the ray:

$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(\mathbf{r}(t)) \cdot \mathbf{c}(\mathbf{r}(t), \mathbf{d}) \, dt$$

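Where:
- $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$ is the ray with origin $\mathbf{o}$ and direction $\mathbf{d}$, and $t_n$, $t_f$ are the near and far bounds
- $T(t) = \exp(-\int_{t_n}^{t} \sigma(\mathbf{r}(s)) \, ds)$ is the transmittance: the probability that the ray reaches $t$ without being absorbed
- $\sigma$ is density
- $\mathbf{c}$ is color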
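
A quick sanity check on the integral: if density and color are constant along the ray, it has the closed form $\mathbf{c}\,(1 - e^{-\sigma (t_f - t_n)})$. The sketch below compares that with the discrete alpha-compositing sum on equal-width intervals. It is a simplified NumPy version with illustrative values, not the `volume_rendering` function from the next cell (which also pads the last interval).

```python
import numpy as np

sigma = 0.8                          # constant density along the ray (illustrative)
color = np.array([0.2, 0.6, 0.9])    # constant color (illustrative)
t_near, t_far = 2.0, 6.0
num_samples = 64

# Equal-width intervals covering [t_near, t_far]
deltas = np.full(num_samples, (t_far - t_near) / num_samples)

alpha = 1.0 - np.exp(-sigma * deltas)
# Exclusive cumulative product: T_i = prod_{j<i} (1 - alpha_j)
transmittance = np.concatenate([[1.0], np.cumprod(1.0 - alpha)[:-1]])
weights = alpha * transmittance

quadrature = (weights[:, None] * color).sum(axis=0)
closed_form = color * (1.0 - np.exp(-sigma * (t_far - t_near)))

print(quadrature, closed_form)
```

For constant density the two agree up to floating-point error, because the weights telescope to $1 - \prod_i (1 - \alpha_i)$.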

In [None]:
def volume_rendering(
    rgb: torch.Tensor,
    sigma: torch.Tensor,
    z_vals: torch.Tensor,
    rays_d: torch.Tensor,
    white_background: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Volume rendering equation.
    
    Args:
        rgb: [batch, num_samples, 3] - colors at sample points
        sigma: [batch, num_samples, 1] - densities at sample points
        z_vals: [batch, num_samples] - depth values along rays
        rays_d: [batch, 3] - ray directions
        white_background: whether to use white background
    
    Returns:
        rgb_map: [batch, 3] - rendered colors
        depth_map: [batch] - rendered depths
        weights: [batch, num_samples] - integration weights
    """
    # Compute distances between adjacent samples; the last sample has no
    # successor, so it gets an effectively infinite interval (1e10)
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], dim=-1)
    
    # Multiply by ray direction magnitude for actual distance
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
    
    # Alpha values: 1 - exp(-sigma * delta)
    alpha = 1.0 - torch.exp(-sigma.squeeze(-1) * dists)
    
    # Transmittance: cumulative product of (1 - alpha)
    # T_i = prod_{j=1}^{i-1} (1 - alpha_j)
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha + 1e-10], dim=-1),
        dim=-1
    )[..., :-1]
    
    # Weights for each sample
    weights = alpha * transmittance
    
    # Rendered color: weighted sum
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)
    
    # Rendered depth: weighted sum of z values
    depth_map = torch.sum(weights * z_vals, dim=-1)
    
    # Handle background
    if white_background:
        acc_map = torch.sum(weights, dim=-1)
        rgb_map = rgb_map + (1.0 - acc_map[..., None])
    
    return rgb_map, depth_map, weights


# Demonstrate volume rendering
print("Volume Rendering Steps:")
print("1. Sample points along each ray")
print("2. Query NeRF for (rgb, sigma) at each point")
print("3. Compute alpha = 1 - exp(-sigma * delta)")
print("4. Compute transmittance T = cumulative product of (1 - alpha)")
print("5. Compute weights w = alpha * T")
print("6. Final color = sum(w * rgb)")

## 5. Ray Generation

To render from a camera, we need to generate rays for each pixel.
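
For a single pixel, the pinhole model reduces to a few lines. The sketch below uses the same conventions as the `get_rays` cell that follows (principal point at the image center, camera looking down the -z axis, y flipped); the helper name `pixel_to_camera_direction` and the image size are made up for illustration.

```python
import numpy as np

def pixel_to_camera_direction(u: float, v: float, height: int, width: int, focal: float) -> np.ndarray:
    """Direction through pixel (u, v) in camera coordinates, camera looking down -z."""
    return np.array([
        (u - width / 2) / focal,
        -(v - height / 2) / focal,  # image rows grow downward, camera y points up
        -1.0,
    ])

height, width, focal = 100, 100, 50.0
center = pixel_to_camera_direction(50, 50, height, width, focal)
corner = pixel_to_camera_direction(0, 0, height, width, focal)
print(center)  # straight ahead along -z
print(corner)  # top-left pixel: left (-x) and up (+y)
```

These directions are not normalized; multiplying by the rotation part of the camera-to-world matrix turns them into world-space ray directions.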

In [None]:
def get_rays(
    height: int,
    width: int,
    focal: float,
    camera_to_world: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate rays for each pixel.
    
    Args:
        height, width: image dimensions
        focal: focal length
        camera_to_world: [4, 4] camera pose matrix
    
    Returns:
        rays_o: [H, W, 3] - ray origins
        rays_d: [H, W, 3] - ray directions
    """
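    # Create pixel coordinates on the same device as the camera pose
    i, j = torch.meshgrid(
        torch.arange(width, dtype=torch.float32, device=camera_to_world.device),
        torch.arange(height, dtype=torch.float32, device=camera_to_world.device),
        indexing='xy'
    )
    
    # Convert to camera coordinates (pinhole camera model)
    # Principal point assumed at the image center, so the center pixel maps to (0, 0)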
    directions = torch.stack([
        (i - width / 2) / focal,
        -(j - height / 2) / focal,  # Flip y
        -torch.ones_like(i),  # Looking down -z axis
    ], dim=-1)
    
    # Transform to world coordinates
    rays_d = torch.sum(
        directions[..., None, :] * camera_to_world[:3, :3],
        dim=-1
    )
    
    # Ray origins = camera position
    rays_o = camera_to_world[:3, 3].expand(rays_d.shape)
    
    return rays_o, rays_d


def sample_along_rays(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    num_samples: int,
    perturb: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample points along rays.
    
    Args:
        rays_o: [..., 3] - ray origins
        rays_d: [..., 3] - ray directions
        near, far: near and far bounds
        num_samples: number of samples per ray
        perturb: whether to add noise to sample positions
    
    Returns:
        pts: [..., num_samples, 3] - sample positions
        z_vals: [..., num_samples] - depth values
    """
    # Uniform samples between near and far
    t_vals = torch.linspace(0.0, 1.0, num_samples, device=rays_o.device)
    z_vals = near * (1.0 - t_vals) + far * t_vals
    
    # Expand to match ray batch shape
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [num_samples])
    
    # Perturb samples (stratified sampling)
    if perturb:
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
        lower = torch.cat([z_vals[..., :1], mids], dim=-1)
        t_rand = torch.rand_like(z_vals)
        z_vals = lower + (upper - lower) * t_rand
    
    # Compute 3D sample positions: o + t * d
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    
    return pts, z_vals


# Visualize ray sampling
camera_pose = torch.eye(4)
camera_pose[2, 3] = 4.0  # Camera at z=4

rays_o, rays_d = get_rays(height=4, width=4, focal=2.0, camera_to_world=camera_pose)
pts, z_vals = sample_along_rays(rays_o, rays_d, near=2.0, far=6.0, num_samples=8)

print(f"Rays origin shape: {rays_o.shape}")
print(f"Rays direction shape: {rays_d.shape}")
print(f"Sample points shape: {pts.shape}")
print(f"Z values shape: {z_vals.shape}")

## 6. Rendering Pipeline

In [None]:
def render_rays(
    model: NeRF,
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    num_samples: int = 64,
    perturb: bool = True,
) -> dict:
    """
    Render colors for a batch of rays.
    
    Args:
        model: NeRF network
        rays_o: [batch, 3] ray origins
        rays_d: [batch, 3] ray directions
        near, far: rendering bounds
        num_samples: samples per ray
        perturb: whether to jitter samples within their bins (stratified sampling)
    
    Returns:
        Dictionary with keys 'rgb_map', 'depth_map', and 'weights'.
    """
    # Sample points along rays
    pts, z_vals = sample_along_rays(rays_o, rays_d, near, far, num_samples, perturb)
    
    # Flatten for network input
    batch_size = rays_o.shape[0]
    pts_flat = pts.reshape(-1, 3)
    
    # Viewing directions (same for all points on a ray)
    dirs = F.normalize(rays_d, dim=-1)
    dirs_flat = dirs[:, None, :].expand(-1, num_samples, -1).reshape(-1, 3)
    
    # Query network
    rgb, sigma = model(pts_flat, dirs_flat)
    
    # Reshape
    rgb = rgb.reshape(batch_size, num_samples, 3)
    sigma = sigma.reshape(batch_size, num_samples, 1)
    
    # Volume rendering
    rgb_map, depth_map, weights = volume_rendering(rgb, sigma, z_vals, rays_d)
    
    return {
        'rgb_map': rgb_map,
        'depth_map': depth_map,
        'weights': weights,
    }


# Test rendering
nerf = NeRF().to(device)
test_rays_o = torch.randn(16, 3).to(device)
test_rays_d = F.normalize(torch.randn(16, 3), dim=-1).to(device)

with torch.no_grad():
    result = render_rays(nerf, test_rays_o, test_rays_d, near=2.0, far=6.0)

print(f"Rendered RGB shape: {result['rgb_map'].shape}")
print(f"Rendered depth shape: {result['depth_map'].shape}")

## 7. Hierarchical Sampling

NeRF uses a **coarse-to-fine** strategy:
1. **Coarse network**: Uniform sampling to identify important regions
2. **Fine network**: Importance sampling in high-density regions
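
The importance-sampling step is inverse transform sampling from a piecewise-constant PDF. Here is a minimal single-ray sketch in NumPy, with made-up bin edges and weights; the batched PyTorch version follows in the next cell.

```python
import numpy as np

rng = np.random.default_rng(0)

bin_edges = np.array([0.0, 1.0, 2.0, 3.0])   # three bins along a ray (illustrative)
weights = np.array([0.05, 0.90, 0.05])       # coarse-pass weights, one per bin (illustrative)

pdf = weights / weights.sum()
cdf = np.concatenate([[0.0], np.cumsum(pdf)])  # cdf[k] = probability mass left of bin_edges[k]

u = rng.random(1000)
# Inverting a piecewise-linear CDF is linear interpolation with the axes swapped
samples = np.interp(u, cdf, bin_edges)

fraction_in_middle = np.mean((samples >= 1.0) & (samples <= 2.0))
print(f"fraction of samples in the high-weight bin: {fraction_in_middle:.2f}")
```

Roughly 90% of the samples land in the middle bin, which is what "sample more where the coarse weights are high" means in practice.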

In [None]:
def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    """
    Inverse transform sampling from a piecewise-constant PDF.
    
    Args:
        bins: [batch, num_bins + 1] - bin edges
        weights: [batch, num_bins] - weights for each bin
        num_samples: number of samples to draw
    
    Returns:
        samples: [batch, num_samples] - sampled positions
    """
    # Normalize weights to get PDF
    weights = weights + 1e-5  # Prevent division by zero
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    
    # Compute CDF
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)
    
    # Sample uniform values
    u = torch.rand(list(cdf.shape[:-1]) + [num_samples], device=bins.device)
    u = u.contiguous()
    
    # Invert CDF
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    
    # Gather CDF values
    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    
    # Gather bin edges
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)
    
    # Linear interpolation
    denom = cdf_above - cdf_below
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_below) / denom
    samples = bins_below + t * (bins_above - bins_below)
    
    return samples


print("Hierarchical Sampling Strategy:")
print("="*50)
print("1. Coarse pass: 64 uniform samples along ray")
print("2. Evaluate coarse network -> get weights")
print("3. Use weights as PDF for importance sampling")
print("4. Fine pass: 128 additional samples (importance sampled)")
print("5. Evaluate fine network with all samples")
print("\nThis focuses compute on regions with high density!")

## 8. Modern Improvements: Hash Encoding (Instant-NGP)

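Instant-NGP speeds up training dramatically by replacing the sinusoidal encoding with learnable feature vectors stored in multi-resolution hash tables, so a much smaller MLP is enough to decode them. The cell below is a simplified version: it looks up a single grid vertex per level instead of interpolating between neighbors.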
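
"Multi-resolution" here means the grid resolution grows geometrically from a coarse base level to the finest level. A sketch of that progression, assuming the default values used in the next cell (16 levels from 16 to 512). The helper name `level_resolutions` is hypothetical, and it rounds to the nearest integer, which is a choice made for this sketch; the cell below truncates with `int` instead.

```python
import math

def level_resolutions(base_resolution: int, finest_resolution: int, num_levels: int) -> list:
    """Grid resolution per level, growing geometrically from base to finest."""
    growth = math.exp(
        (math.log(finest_resolution) - math.log(base_resolution)) / (num_levels - 1)
    )
    # round() rather than truncation, so floating-point error cannot turn 512 into 511
    return [round(base_resolution * growth ** level) for level in range(num_levels)]

resolutions = level_resolutions(base_resolution=16, finest_resolution=512, num_levels=16)
print(resolutions)
```

With these settings the growth factor is $2^{1/3}$, so the resolution doubles every three levels.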

In [None]:
class HashEncoding(nn.Module):
    """
    Simplified multi-resolution hash encoding (Instant-NGP style).
    Maps 3D positions to features using hash tables at multiple resolutions.
    """
    
    def __init__(
        self,
        num_levels: int = 16,
        features_per_level: int = 2,
        log2_hashmap_size: int = 19,
        base_resolution: int = 16,
        finest_resolution: int = 512,
    ):
        super().__init__()
        
        self.num_levels = num_levels
        self.features_per_level = features_per_level
        self.hashmap_size = 2 ** log2_hashmap_size
        
        # Compute resolution at each level (geometric progression)
        b = np.exp((np.log(finest_resolution) - np.log(base_resolution)) / (num_levels - 1))
        self.resolutions = [int(base_resolution * (b ** i)) for i in range(num_levels)]
        
        # Hash tables for each level
        self.hash_tables = nn.ParameterList([
            nn.Parameter(torch.randn(self.hashmap_size, features_per_level) * 0.01)
            for _ in range(num_levels)
        ])
        
        # Per-axis hash multipliers: 1 for x, large primes for y and z
        self.primes = torch.tensor([1, 2654435761, 805459861], dtype=torch.long)
    
    def hash_function(self, coords: torch.Tensor) -> torch.Tensor:
        """Spatial hash function."""
        # XOR hashing with prime numbers
        result = torch.zeros(coords.shape[0], dtype=torch.long, device=coords.device)
        for i in range(3):
            result = result ^ (coords[:, i].long() * self.primes[i].to(coords.device))
        return result % self.hashmap_size
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch, 3] - 3D positions in [0, 1]^3
        Returns: [batch, num_levels * features_per_level]
        """
        features = []
        
        for level, (resolution, hash_table) in enumerate(zip(self.resolutions, self.hash_tables)):
            # Scale position to grid resolution
            scaled = x * resolution
            
            # Get integer grid coordinates. Instant-NGP interpolates between the
            # 8 surrounding grid vertices; for simplicity, this version uses only the floor vertex
            coords = torch.floor(scaled).long()
            coords = torch.clamp(coords, 0, resolution - 1)
            
            # Hash and lookup
            indices = self.hash_function(coords)
            level_features = hash_table[indices]
            features.append(level_features)
        
        return torch.cat(features, dim=-1)
    
    def output_dim(self) -> int:
        return self.num_levels * self.features_per_level


# Compare encodings
pos_enc = PositionalEncoding(num_frequencies=10)
hash_enc = HashEncoding(num_levels=16, features_per_level=2)

test_input = torch.rand(1000, 3)  # Random 3D positions

pos_output = pos_enc(test_input)
hash_output = hash_enc(test_input)

print("Encoding Comparison:")
print(f"  Positional encoding output dim: {pos_output.shape[-1]}")
print(f"  Hash encoding output dim: {hash_output.shape[-1]}")
print(f"\nHash encoding advantages:")
print("  - Learnable features (adapts to scene)")
print("  - Multi-resolution (coarse + fine details)")
print("  - Much faster training (minutes vs hours)")

## 9. FAANG Interview Questions

### Q1: What is NeRF and how does it represent 3D scenes?

**Answer**:

NeRF (Neural Radiance Field) represents a 3D scene as a continuous function implemented by an MLP:

$$F: (x, y, z, \theta, \phi) \rightarrow (r, g, b, \sigma)$$

**Inputs**:
- $(x, y, z)$: 3D position
- $(\theta, \phi)$: Viewing direction

**Outputs**:
- $(r, g, b)$: View-dependent color (captures specular effects)
- $\sigma$: View-independent density (geometry)

**Key insight**: Density depends only on position (geometry is view-independent), but color depends on viewing direction (specular highlights vary with viewpoint).

---

### Q2: Why is positional encoding important for NeRF?

**Answer**:

**Problem**: MLPs are biased toward learning low-frequency functions. Without positional encoding, NeRF produces blurry results.

**Solution**: Map inputs to higher dimensions using sinusoids:
$$\gamma(p) = [\sin(2^0\pi p), \cos(2^0\pi p), ..., \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)]$$

**Why it works**:
- Sinusoids at different frequencies capture different detail levels
- Network can now represent high-frequency variations
- Similar to Fourier features / random Fourier features
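
One way to see the effect: two inputs that are nearly identical in raw form end up much farther apart after encoding, so the network no longer has to produce sharp changes from tiny input differences. A small sketch with illustrative values (the helper name `encode` is hypothetical):

```python
import numpy as np

def encode(p: float, num_frequencies: int = 10) -> np.ndarray:
    """Sinusoidal encoding of a scalar, without the raw input appended.

    Features are ordered as all sines, then all cosines; this differs from the
    interleaved order in the formula but does not affect distances.
    """
    freqs = (2.0 ** np.arange(num_frequencies)) * np.pi
    return np.concatenate([np.sin(freqs * p), np.cos(freqs * p)])

a, b = 0.50, 0.51
raw_gap = abs(a - b)
encoded_gap = np.linalg.norm(encode(a) - encode(b))
print(f"raw gap: {raw_gap:.3f}, encoded gap: {encoded_gap:.3f}")
```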

---

### Q3: Explain volume rendering in NeRF.

**Answer**:

Volume rendering integrates color along each ray:

$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(t) \cdot \mathbf{c}(t) \, dt$$

**Components**:
- $T(t)$: Transmittance (probability that the ray travels from $t_n$ to $t$ without hitting anything)
- $\sigma(t)$: Density at point $t$
- $\mathbf{c}(t)$: Color at point $t$

**Discrete approximation**:
1. Sample N points along ray
2. Compute $\alpha_i = 1 - \exp(-\sigma_i \cdot \delta_i)$
3. Compute $T_i = \prod_{j<i}(1-\alpha_j)$
4. Compute weights $w_i = \alpha_i \cdot T_i$
5. Final color = $\sum_i w_i \cdot c_i$
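
The steps above can be traced on a toy ray that passes through empty space and then hits a dense surface. All densities, spacings, and colors below are made-up illustrative values.

```python
import numpy as np

# Step 1: five samples along one ray; empty space, then a dense surface
sigma = np.array([0.0, 0.0, 50.0, 50.0, 0.0])
delta = np.full(5, 0.5)  # spacing between samples
colors = np.array(
    [[0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float
)

alpha = 1.0 - np.exp(-sigma * delta)                        # step 2
T = np.concatenate([[1.0], np.cumprod(1.0 - alpha)[:-1]])   # step 3: product over j < i
weights = alpha * T                                         # step 4
pixel = (weights[:, None] * colors).sum(axis=0)             # step 5

print(np.round(weights, 4), np.round(pixel, 4))
```

Almost all of the weight lands on the first dense sample (index 2), so the pixel takes that sample's red color. The green sample behind it is occluded because the transmittance has already dropped to nearly zero.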

---

### Q4: What is hierarchical sampling and why is it used?

**Answer**:

**Problem**: Uniform sampling wastes compute on empty space.

**Solution**: Two-pass sampling:

1. **Coarse pass**: 
   - Uniform samples (e.g., 64)
   - Evaluate coarse network
   - Get density weights

2. **Fine pass**:
   - Use weights as PDF
   - Importance sample additional points (e.g., 128)
   - Sample more in high-density regions
   - Evaluate fine network

**Benefits**:
- Focus compute on surfaces/objects
- Better quality with same sample count
- Or same quality with fewer samples
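
A small sketch of how the two passes combine for one ray. The fine depths here are a stand-in drawn from a normal distribution around an assumed surface at depth 4; in a real pipeline they would come from importance sampling the coarse weights.

```python
import numpy as np

rng = np.random.default_rng(0)
near, far = 2.0, 6.0

z_coarse = np.linspace(near, far, 64)  # coarse pass: uniform depths
# Stand-in for importance-sampled depths clustered near a surface at depth 4 (assumption)
z_fine = np.clip(rng.normal(loc=4.0, scale=0.1, size=128), near, far)

# The fine network is evaluated at the union of both sets, ordered by depth
z_all = np.sort(np.concatenate([z_coarse, z_fine]))

print(z_all.shape)
```

The sort matters because volume rendering uses the distance between adjacent samples, which is only meaningful when depths are in order.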

---

### Q5: How does Instant-NGP improve on original NeRF?

**Answer**:

**Original NeRF problems**:
- Slow training (hours/days)
- Deep networks needed
- Fixed positional encoding

**Instant-NGP improvements**:

1. **Hash encoding**:
   - Multi-resolution hash tables instead of sinusoids
   - Learnable features (adapt to scene)
   - O(1) table lookup per level; most of the scene detail lives in the tables rather than in MLP weights

2. **Smaller network**:
   - Hash features are more expressive
   - Only need 2-3 layer MLP

3. **Fully-fused CUDA kernels**:
   - Optimized GPU implementation

**Result**: Training in seconds/minutes instead of hours.
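
For intuition about the lookup, here is the spatial hash in plain Python, using the same multipliers as the `HashEncoding` cell in Section 8. It is a sketch that assumes non-negative integer grid coordinates; the function name `spatial_hash` is hypothetical.

```python
PRIMES = (1, 2654435761, 805459861)  # per-axis multipliers (same constants as Section 8)

def spatial_hash(ix: int, iy: int, iz: int, table_size: int) -> int:
    """XOR the coordinate-multiplier products, then wrap into the table."""
    h = (ix * PRIMES[0]) ^ (iy * PRIMES[1]) ^ (iz * PRIMES[2])
    return h % table_size

table_size = 2 ** 19
for corner in [(0, 0, 0), (1, 0, 0), (0, 1, 0), (3, 7, 11)]:
    print(corner, "->", spatial_hash(*corner, table_size))
```

Different grid vertices can collide in the same slot. The design accepts that and lets gradient descent sort out which vertex dominates each entry.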

## 10. Key Takeaways

1. **NeRF** represents 3D scenes as continuous neural functions: (x,y,z,dir) -> (rgb, sigma)
2. **Positional encoding** enables MLPs to learn high-frequency details
3. **Volume rendering** integrates color along rays as a sum weighted by transmittance and opacity
4. **Hierarchical sampling** focuses compute on important regions
5. **View-dependent color** enables specular/reflective effects
6. **Hash encoding** (Instant-NGP) dramatically speeds up training
7. NeRF spawned many variants: mip-NeRF, NeRF-W, D-NeRF (dynamic), etc.