# NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

*A deep dive into implementing NeRF from scratch. We start by setting up the environment and understanding the mathematical foundations of Ray Marching.* 

## Step 0: Necessary Imports

In [7]:
### basics

from os import path
import json

 
### base 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

### torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader,Dataset


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

from tqdm import tqdm
import imageio


Using device: cpu


## Step 1: The Data

### 1.1 Source

We are using the **Synthetic Lego Dataset** (the `lego` scene of the `nerf_synthetic` data).

  * **Download:** [Google Drive Link](https://drive.google.com/drive/folders/1cK3UDIJqKAAm7zyrxRYVFJ0BRMgrwhh4)
  * **wget:** `wget https://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/nerf_example_data.zip`

### 1.2 Directory Structure

After extraction, ensure your folder looks like this:

```text
nerf_synthetic/
  └── lego/
      ├── train/                # Training Images
      ├── val/                  # Validation Images
      ├── test/                 # Test Images
      ├── transforms_train.json # Camera Poses for training
      ├── transforms_val.json
      └── transforms_test.json
```


## Step 2: Understanding Camera Metadata (`transforms.json`)

The JSON file contains the **Intrinsic** parameters (the focal length, via the field of view), shared by all images, and the **Extrinsic** parameters (position/rotation) of every image.

### 2.1 Focal Length & Intrinsic Matrix ($K$)

At the top of the JSON, `camera_angle_x` gives us the horizontal Field of View (FOV). We use it to calculate the Focal Length ($f$) and build the Intrinsic Matrix ($K$).

  * **`camera_angle_x`**: Horizontal FOV in radians.
  * **`W`, `H`**: Image width and height in pixels. These are not stored in the JSON; they come from the images themselves (800 × 800 for the Lego scene).

$$f = \frac{W / 2}{\tan(\text{camera\_angle\_x} / 2)}$$

**The Intrinsic Matrix ($K$):**
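This maps 3D camera coordinates to 2D image pixels. It assumes the usual computer-vision camera frame ($+z$ forward, $+y$ down); Step 3 flips the $y$ and $z$ signs for NeRF's OpenGL-style frame.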

$$
K = \begin{bmatrix}
f & 0 & W/2 \\
0 & f & H/2 \\
0 & 0 & 1
\end{bmatrix}
$$

![Image of pinhole camera focal length field of view diagram](focal-angle.png)
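As a quick numerical check of the focal-length formula and $K$, here is a minimal NumPy sketch. The helper `intrinsics_from_fov` is hypothetical (not part of the notebook), and the value `0.6911112070083618` is an assumed example of what the Lego `transforms_*.json` stores for `camera_angle_x`; substitute the value from your own file.

```python
import numpy as np

def intrinsics_from_fov(camera_angle_x, W, H):
    """Hypothetical helper: return (focal, K) for a pinhole camera.

    camera_angle_x is the horizontal field of view in radians.
    """
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
    K = np.array([
        [focal, 0.0,   W / 2],
        [0.0,   focal, H / 2],
        [0.0,   0.0,   1.0],
    ])
    return focal, K

# Assumed example values: an 800 x 800 image and the Lego scene's horizontal FOV.
focal, K = intrinsics_from_fov(0.6911112070083618, W=800, H=800)
print(round(focal, 2))  # roughly 1111.11
print(K)
```

Because the image is square and the pixels are square, the same $f$ appears on both diagonal entries.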

### 2.2 Camera Poses (Extrinsics)

For every frame, the `transform_matrix` represents the **Camera-to-World ($c2w$)** transformation.

```json
"frames": [
    {
        "transform_matrix": [
            [r11, r12, r13, tx],
            [r21, r22, r23, ty],
            [r31, r32, r33, tz],
            [0.0, 0.0, 0.0, 1.0]
        ]
    }
]
```

**The Pose Matrix ($T_{c2w}$):**
This places the camera into the 3D world.

$$
T_{c2w} = \left[ \begin{array}{c|c}
\mathbf{R} & \mathbf{t} \\
\hline
\mathbf{0}^\top & 1
\end{array} \right]
$$

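  * **$\mathbf{R}$ (Rotation $3 \times 3$):** Defines the camera's orientation. Its columns are the camera's Right, Up, and Back axes expressed in world coordinates (Back is $+z_c$; the camera looks along $-z_c$).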
  * **$\mathbf{t}$ (Translation $3 \times 1$):** Defines the camera's position $(x, y, z)$ in the world.
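A small NumPy sketch of reading $\mathbf{R}$ and $\mathbf{t}$ out of a pose, and of inverting it. The helper names `decompose_c2w` and `invert_c2w` and the example pose are hypothetical; the inverse uses $\mathbf{R}^{-1} = \mathbf{R}^\top$, which holds because $\mathbf{R}$ is a rotation.

```python
import numpy as np

def decompose_c2w(c2w):
    """Hypothetical helper: split a 4x4 camera-to-world matrix into R, t and the viewing axis."""
    R = c2w[:3, :3]      # columns: camera Right, Up, Back axes in world coordinates
    t = c2w[:3, 3]       # camera center in world coordinates
    forward = -R[:, 2]   # the camera looks down its own -z axis
    return R, t, forward

def invert_c2w(c2w):
    """Hypothetical helper: world-to-camera matrix, using R^-1 = R^T for a rotation."""
    R, t = c2w[:3, :3], c2w[:3, 3]
    w2c = np.eye(4)
    w2c[:3, :3] = R.T
    w2c[:3, 3] = -R.T @ t
    return w2c

# Hypothetical pose: camera 4 units along +z, axes aligned with the world, looking at the origin.
c2w = np.eye(4)
c2w[:3, 3] = [0.0, 0.0, 4.0]
R, t, forward = decompose_c2w(c2w)
print(np.allclose(invert_c2w(c2w) @ c2w, np.eye(4)))  # True
```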

-----


## Step 3: Ray Generation Math

To render the scene, we must convert 2D pixels $(u, v)$ into 3D Rays $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$.



### 3.1 The Three Coordinate Spaces

1.  **Pixel Space $(u, v)$**: The 2D grid of the image.
2.  **Camera Space $(x_c, y_c, z_c)$**: 3D coordinates relative to the camera lens.
3.  **World Space $(x_w, y_w, z_w)$**: The absolute 3D coordinates of the scene.

### 3.2 The Algorithm (Pixel $\rightarrow$ Ray)

**Step A: Pixel to Camera Space**
We project the 2D pixel to a 3D direction vector. Two signs matter: $z_c = -1$ because the camera looks down its negative $z$-axis, and $y_c$ is negated because image rows ($v$) grow downward while the camera's $+y$ axis points up.

$$
x_c = \frac{(u - W/2)}{f}, \quad y_c = \frac{-(v - H/2)}{f}, \quad z_c = -1
$$

**Step B: Camera Space to World Space**
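We use the `transform_matrix` (pose) to move the ray into the world: its translation gives the ray origin, and its rotation turns the camera-space direction into a world-space direction. Directions are only rotated, never translated.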

  * **Ray Origin ($\mathbf{o}$):** The camera's position.
      * `rays_o = transform_matrix[:3, 3]`
  * **Ray Direction ($\mathbf{d}$):** The pixel vector rotated by the camera's orientation (then normalized to unit length in the code).
      * `rays_d = transform_matrix[:3, :3] @ [x_c, y_c, z_c]`
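Steps A and B for a single pixel, as a minimal NumPy sketch. The helper `pixel_to_ray` and the example camera are hypothetical; `get_rays` in Step 4 does the same computation for every pixel at once.

```python
import numpy as np

def pixel_to_ray(u, v, W, H, focal, c2w):
    """Hypothetical helper: world-space (origin, unit direction) for pixel column u, row v."""
    # Step A: pixel -> camera-space direction (camera looks down -z, image rows grow downward).
    d_cam = np.array([(u - W / 2) / focal, -(v - H / 2) / focal, -1.0])
    # Step B: rotate into world space; directions are not translated.
    d_world = c2w[:3, :3] @ d_cam
    d_world = d_world / np.linalg.norm(d_world)
    origin = c2w[:3, 3]
    return origin, d_world

# Hypothetical camera 4 units along +z, looking at the world origin.
c2w = np.eye(4)
c2w[:3, 3] = [0.0, 0.0, 4.0]
o, d = pixel_to_ray(400, 400, W=800, H=800, focal=1111.11, c2w=c2w)
print(o, d)  # the image center maps onto the camera's viewing axis
```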

### 3.3 Practical Correspondence Table

| Concept | In Math / Code | In JSON Data |
| :--- | :--- | :--- |
| **Focal Length** | $f$ | Calculated from `camera_angle_x` |
| **Rotation** | $R$ (Orientation) | `transform_matrix[:3, :3]` |
| **Translation** | $t$ (Position) | `transform_matrix[:3, 3]` |
| **Ray Origin** | $\mathbf{o}$ | Matches **Translation** exactly |
| **Ray Direction** | $\mathbf{d}$ | Calculated using **Rotation** and $f$ |

## Step 4: Create Dataset And DataLoader

### Step 4.1: Dataset

In [8]:
class LegoRayDataset(Dataset):
    def __init__(self,base_path='./nerf_synthetic/lego',split='train'):
        self.base_path = base_path
        self.images = path.join(base_path,split)

        json_file_path = path.join(base_path, f'transforms_{split}.json')
        with open(json_file_path, 'r') as f:
            self.transforms = json.load(f)
        self.fov = self.transforms['camera_angle_x']

    
    def __len__(self):
        return len(self.transforms['frames'])

    
    def __getitem__(self, idx):
        frame = self.transforms['frames'][idx]
        
        fname = frame['file_path']
        if fname.startswith('./'):
            fname = fname[2:] 
        
        image_path = path.join(self.base_path, f"{fname}.png")

        img = plt.imread(image_path) 
        img = torch.from_numpy(img).float()

        if img.shape[-1] == 4:
            rgb = img[..., :3]
            alpha = img[..., 3:4]
            
            img = rgb * alpha + (1 - alpha)
        
        pose = torch.tensor(frame['transform_matrix'], dtype=torch.float32)
        return img,pose,self.fov
    
def get_rays(H, W, fov, c2w):
    focal = 0.5 * W / np.tan(0.5 * fov)

    i, j = torch.meshgrid(torch.arange(W, dtype=torch.float32), 
                        torch.arange(H, dtype=torch.float32), indexing='xy')
    # i[r, c] = c  (column / x index, 0 ... W-1)
    # j[r, c] = r  (row / y index, 0 ... H-1)
    # both have shape (H, W)
    
    device = c2w.device
    i = i.to(device)
    j = j.to(device)

    dirs = torch.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -torch.ones_like(i)], -1)
    # Camera-space direction of every pixel, shape (H, W, 3):
    #   (i - W*0.5) / focal: centers x on the image center and normalizes by the focal length.
    #   -(j - H*0.5) / focal: the same for y, negated because image rows grow downward
    #                         while +y points up in the OpenGL/NeRF camera frame.
    #   -1: by convention the camera looks down its negative z-axis.

    # d_world = R · d_camera, applied to all pixels at once by broadcasting:
    # dirs (H, W, 3) -> (H, W, 1, 3); multiplied by R (3, 3) -> (H, W, 3, 3);
    # summing over the last axis gives (H, W, 3).
    rays_d = torch.sum(dirs.unsqueeze(-2) * c2w[:3, :3], -1)
    rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)  # unit-length directions

    # Every ray starts at the camera center c2w[:3, 3], shape (3,), expanded to (H, W, 3)
    rays_o = c2w[:3, 3].expand(rays_d.shape)

    return rays_o, rays_d
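The broadcast-and-sum line in `get_rays` is a batched matrix-vector product. This NumPy check, on random stand-in data rather than real Lego poses, shows that it agrees with `einsum` and with right-multiplication by $\mathbf{R}^\top$. `torch.einsum` and `@` accept the same expressions, so either could replace it in the PyTorch code if you find them more readable.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 5                                         # tiny illustrative image size
dirs = rng.standard_normal((H, W, 3))               # stand-in camera-space directions
R = np.linalg.qr(rng.standard_normal((3, 3)))[0]    # a random orthonormal 3x3 matrix

# The trick used in get_rays: (H, W, 1, 3) * (3, 3) -> (H, W, 3, 3), then sum the last axis.
via_broadcast = np.sum(dirs[..., None, :] * R, axis=-1)
# Equivalent formulations: an explicit einsum, or right-multiplying by R^T.
via_einsum = np.einsum("hwj,ij->hwi", dirs, R)
via_matmul = dirs @ R.T

print(np.allclose(via_broadcast, via_einsum), np.allclose(via_broadcast, via_matmul))  # True True
```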



### Step 4.2: DataLoader

In [9]:
def ray_collate_fn(batch):
    all_rays_o = []
    all_rays_d = []
    all_target_colors = []
    
    for img, pose, fov in batch:
        H, W = img.shape[:2]
        
        N_RAYS_PER_IMAGE = 4096

        coords = torch.stack(torch.meshgrid(
            torch.arange(H, dtype=torch.float32),
            torch.arange(W, dtype=torch.float32),
            indexing='ij'
        ), -1).reshape(-1, 2)
        # HWx2
        # [[0,0],
        #  [0,1],
        #  [0,2],
        #  ...
        #  [H-1, W-1]]


        
        select_inds = np.random.choice(coords.shape[0], size=[N_RAYS_PER_IMAGE], replace=False)
        #select_inds = [501, 12, 9999, 42, ...]

        select_coords = coords[select_inds].long() 
        
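        i = select_coords[:, 1]  # column (x) index
        j = select_coords[:, 0]  # row (y) index
        
        focal = 0.5 * W / np.tan(0.5 * fov)
        
        # i and j are integer tensors, so build the constant -1 as float32 to keep every component the same dtype
        dirs = torch.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -torch.ones_like(i, dtype=torch.float32)], -1)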
        
        rays_d = torch.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1)
        rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
        rays_o = pose[:3, 3].expand(rays_d.shape)
        
        target_colors = img[j, i] 

        all_rays_o.append(rays_o)
        all_rays_d.append(rays_d)
        all_target_colors.append(target_colors)

    return torch.cat(all_rays_o, 0), torch.cat(all_rays_d, 0), torch.cat(all_target_colors, 0)
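`ray_collate_fn` builds an explicit (row, col) table with `meshgrid` and indexes into it. The same pixels can be selected straight from flat indices with `divmod`, without materializing the $HW \times 2$ table. Below is a NumPy sketch with a hypothetical helper `sample_pixels` and a random stand-in image; `rows` plays the role of `j` and `cols` the role of `i` in the collate function.

```python
import numpy as np

def sample_pixels(img, n_rays, rng):
    """Hypothetical helper: draw n_rays distinct pixels; return rows, cols and their colors."""
    H, W = img.shape[:2]
    flat = rng.choice(H * W, size=n_rays, replace=False)  # distinct flat indices
    rows, cols = np.divmod(flat, W)                        # row-major: flat = row * W + col
    return rows, cols, img[rows, cols]

rng = np.random.default_rng(0)
img = rng.random((6, 8, 3))  # stand-in 6x8 RGB image
rows, cols, colors = sample_pixels(img, n_rays=10, rng=rng)
print(colors.shape)  # (10, 3)
```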

In [10]:
train_dataset = LegoRayDataset(split='train')
val_dataset = LegoRayDataset(split='val')
test_dataset = LegoRayDataset(split='test')

BATCH_SIZE=2
train_loader = DataLoader(train_dataset,BATCH_SIZE,shuffle=True,collate_fn=ray_collate_fn,num_workers=8,prefetch_factor=16,persistent_workers=True,pin_memory=True)
val_loader = DataLoader(val_dataset,BATCH_SIZE,shuffle=True,collate_fn=ray_collate_fn,num_workers=8,prefetch_factor=16,persistent_workers=True,pin_memory=True)
test_loader = DataLoader(test_dataset,BATCH_SIZE,collate_fn=ray_collate_fn)

## Step 6: NeRF Model

In [11]:
# Number of positional-encoding frequencies (the defaults of NeRF(L1, L2) below)
L1 = 10  # x: 3D location
L2 = 4   # d: viewing direction

In [13]:

class PositionalEncoder(nn.Module):
    def __init__(self, d_input, n_freqs):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.d_output = d_input * (1 + 2 * n_freqs)  # identity + (sin, cos) per frequency
        self.embed_fns = [lambda x: x]  # keep the raw input as the first block

        freq_bands = 2.**torch.linspace(0., n_freqs - 1, n_freqs)  # 2^0, 2^1, ..., 2^(n_freqs-1)
        for freq in freq_bands:
            # freq=freq binds the current value (avoids Python's late-binding closure pitfall)
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq * np.pi))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq * np.pi))

    def forward(self, x):
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
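`PositionalEncoder` concatenates $p$ with $\sin(2^k\pi p)$ and $\cos(2^k\pi p)$ for $k = 0, \dots, L-1$, per coordinate, so the output width is $3(1 + 2L)$: 63 for $L=10$ and 27 for $L=4$. Here is a vectorized NumPy sketch of the same mapping (hypothetical function name, exact $\pi$, same output ordering) that computes all frequencies in one broadcast instead of a Python list of lambdas.

```python
import numpy as np

def positional_encoding(x, n_freqs):
    """Hypothetical vectorized encoder: [x, sin(2^0 pi x), cos(2^0 pi x), ..., sin(2^(L-1) pi x), cos(2^(L-1) pi x)].

    x: (N, D) -> (N, D * (1 + 2 * n_freqs)), same block ordering as PositionalEncoder.
    """
    freqs = 2.0 ** np.arange(n_freqs) * np.pi                     # (L,)
    scaled = x[:, None, :] * freqs[:, None]                        # (N, L, D)
    sincos = np.stack([np.sin(scaled), np.cos(scaled)], axis=2)    # (N, L, 2, D)
    return np.concatenate([x, sincos.reshape(x.shape[0], -1)], axis=-1)

pts = np.random.default_rng(0).uniform(-1, 1, size=(5, 3))
print(positional_encoding(pts, 10).shape)  # (5, 63), matches d_output for L1=10
print(positional_encoding(pts, 4).shape)   # (5, 27), matches d_output for L2=4
```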


In [14]:

class NeRF(nn.Module):
    def __init__(self, L1=10, L2=4):
        super().__init__()
        self.embed_pos = PositionalEncoder(3, L1)
        self.embed_dir = PositionalEncoder(3, L2)
        
        input_ch_pos = self.embed_pos.d_output 
        input_ch_dir = self.embed_dir.d_output

        self.density_mlp1 = nn.Sequential(
            nn.Linear(input_ch_pos, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        
        self.skip_layer = nn.Linear(256 + input_ch_pos, 256)

        self.density_mlp2 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        
        self.density_head = nn.Linear(256, 1)
        self.feature_head = nn.Sequential( 
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        
        self.color_mlp = nn.Sequential(
            nn.Linear(256 + input_ch_dir, 128),
            nn.ReLU(),
            nn.Linear(128, 3) 
        )

    def forward(self, x, d):
        x = x / 4.0  # rescale world coordinates to a smaller range before encoding (the render orbit uses radius 4.0)

        x_embedded = self.embed_pos(x)
        d_embedded = self.embed_dir(d)
        
        h = self.density_mlp1(x_embedded)
        
        h = torch.cat([h, x_embedded], dim=-1)
        
        h = F.relu(self.skip_layer(h))  # ReLU after the skip layer, like every other hidden layer
        
        h = self.density_mlp2(h)
        
        sigma = self.density_head(h)
        
        feature = self.feature_head(h)
        h_color = torch.cat([feature, d_embedded], dim=-1)
        
        rgb = torch.sigmoid(self.color_mlp(h_color)) 
        
        return sigma, rgb

## Step 7: Set Up the Training Loop

In [16]:
coarse = NeRF().to(device)
fine = NeRF().to(device)
params = list(coarse.parameters()) + list(fine.parameters())

optimizer = optim.Adam(params, lr=5e-4,  betas=(0.9, 0.999))

TARGET_STEPS = 200000 
steps_per_epoch = len(train_loader)  # 100 training images / BATCH_SIZE = 50 batches

num_epochs = int(TARGET_STEPS / steps_per_epoch) + 1
print(num_epochs," Epochs")

# Decay the learning rate exponentially from 5e-4 to 5e-5 over num_epochs (scheduler.step() runs once per epoch)
gamma = (5e-5 / 5e-4) ** (1 / num_epochs)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

4001  Epochs
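The scheduler is stepped once per epoch (`scheduler.step()` sits outside the batch loop in `train_hierarchical`), so `gamma` is chosen so that $\text{lr}_0 \cdot \gamma^{\text{num\_epochs}} = 5 \times 10^{-5}$. A plain-Python check of that arithmetic, assuming 100 training images and `BATCH_SIZE = 2`:

```python
import math

lr_start, lr_end = 5e-4, 5e-5
steps_per_epoch = math.ceil(100 / 2)  # assumption: 100 training images, BATCH_SIZE = 2
num_epochs = int(200000 / steps_per_epoch) + 1
gamma = (lr_end / lr_start) ** (1 / num_epochs)

# ExponentialLR multiplies the learning rate by gamma once per scheduler.step().
lr_final = lr_start * gamma ** num_epochs
print(num_epochs)  # 4001, matching the printed output above
print(gamma, lr_final)
```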


In [None]:
def volume_render(raw, z_vals, rays_d):
    # raw: [N_rays, N_samples, 4]  (RGB in channels 0-2, raw density sigma in channel 3)
    # z_vals: [N_rays, N_samples] 
    # rays_d: [N_rays, 3] 
    
    # delta
    dists = z_vals[..., 1:] - z_vals[..., :-1] 
    
    # The last sample goes to infinity
    last_dist = torch.tensor([1e10], device=raw.device).expand(dists[..., :1].shape)
    
    dists = torch.cat([dists, last_dist], -1)

    # Convert depth differences to distances along the ray (a no-op for the unit-length rays_d from get_rays)
    dists = dists * torch.norm(rays_d.unsqueeze(1), dim=-1)

    # rgb and sigma
    rgb = raw[..., :3] 
    sigma = F.relu(raw[..., 3])        

    # opacity
    opacity = 1.0 - torch.exp(-sigma * dists)

    # T
    p = 1.0 - opacity + 1e-10
    T = torch.cumprod(torch.cat([torch.ones((opacity.shape[0], 1), device=raw.device), p], -1), -1)[:, :-1]

    # weights
    weights = T * opacity
    acc_map = torch.sum(weights, -1) 
    rgb_map = torch.sum(weights.unsqueeze(-1) * rgb, -2)

    # Add White Background
    rgb_map = rgb_map + (1. - acc_map.unsqueeze(-1) )
    
    return rgb_map,weights
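`volume_render` is the discrete quadrature from the NeRF paper: $\hat{C} = \sum_i T_i \alpha_i \mathbf{c}_i$ with $\alpha_i = 1 - e^{-\sigma_i \delta_i}$ and $T_i = \prod_{j<i} (1 - \alpha_j) = \exp\big(-\sum_{j<i} \sigma_j \delta_j\big)$. This NumPy sketch, on random stand-in densities for one ray, checks that the exclusive `cumprod` used above equals the closed form for $T_i$, and that the weights sum to $1 - e^{-\sum_i \sigma_i \delta_i}$. That sum is exactly the `acc_map` that decides how much white background shows through.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = rng.uniform(0.0, 5.0, size=64)  # stand-in densities along one ray
delta = np.full(64, 4.0 / 64)           # stand-in sample spacings

alpha = 1.0 - np.exp(-sigma * delta)    # per-sample opacity
# Exclusive cumulative product, as in volume_render: T_0 = 1, T_i = prod_{j<i} (1 - alpha_j)
T = np.cumprod(np.concatenate([[1.0], 1.0 - alpha[:-1]]))
T_closed_form = np.exp(-np.concatenate([[0.0], np.cumsum(sigma * delta)[:-1]]))
weights = T * alpha

print(np.allclose(T, T_closed_form))                                     # True
print(np.isclose(weights.sum(), 1.0 - np.exp(-np.sum(sigma * delta))))  # True
```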

In [None]:
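def sample_pdf(bins, weights, N_fine=128, noise=True):
    """
    Sample N_fine depths from the piecewise-constant distribution defined by weights
    (inverse-transform sampling).
    bins:    [Batch, N_coarse-1] mid-points of the coarse z_vals
    weights: [Batch, N_coarse-1] coarse weights (the callers pass weights_c[..., :-1])
    returns: [Batch, N_fine] sampled depths
    """
    # 1. Normalize the weights into a PDF (the small constant avoids division by zero)
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)

    # 2. Integrate into a CDF that starts at 0; repeat the last bin so bins and cdf have equal length
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # [Batch, N_coarse]
    bins = torch.cat([bins, bins[..., -1:]], -1)                # [Batch, N_coarse]

    # 3. Generate queries u: random during training, evenly spaced when noise=False
    shape_of_queries = list(cdf.shape[:-1]) + [N_fine]
    # (Batch, N_fine)
    if noise:
        u = torch.rand(shape_of_queries, device=weights.device)
    else:
        u = torch.linspace(0., 1., steps=N_fine, device=weights.device)
        u = u.expand(shape_of_queries)

    # 4. Invert the CDF: locate the interval [cdf_below, cdf_above] that contains each u
    inds = torch.searchsorted(cdf, u.contiguous(), right=True)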
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)

    inds_g = torch.stack([below, above], -1)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    # [4096, 128, 64]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
    # cdf_g[..., 0] = CDF Low 
    # cdf_g[..., 1] = CDF High 
   
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
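`sample_pdf` is inverse-transform sampling: treat the weights as a piecewise-constant PDF over depth bins, integrate it into a CDF, draw $u \in [0, 1)$, and map each $u$ back through the CDF by linear interpolation inside its bin. Below is a single-ray NumPy sketch under a simplified interface (hypothetical name `sample_pdf_1d`; it takes $K+1$ bin edges and $K$ weights rather than the padded bins used above). With all mass in one bin, nearly every sample lands there.

```python
import numpy as np

def sample_pdf_1d(bins, weights, n_samples, rng=None):
    """Hypothetical single-ray sampler. bins: (K+1,) edges; weights: (K,) mass per bin."""
    rng = np.random.default_rng() if rng is None else rng
    pdf = (weights + 1e-5) / np.sum(weights + 1e-5)
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])   # (K+1,); cdf[k] sits at bins[k]
    u = rng.random(n_samples)                        # uniform queries in [0, 1)
    idx = np.searchsorted(cdf, u, side="right")      # first edge whose cdf exceeds u
    below = np.clip(idx - 1, 0, len(cdf) - 1)
    above = np.clip(idx, 0, len(cdf) - 1)
    denom = cdf[above] - cdf[below]
    denom = np.where(denom < 1e-5, 1.0, denom)       # guard against empty bins
    t = (u - cdf[below]) / denom                     # position inside the bin, in [0, 1)
    return bins[below] + t * (bins[above] - bins[below])

bins = np.linspace(2.0, 6.0, 9)  # 8 bins between near=2 and far=6
weights = np.zeros(8)
weights[3] = 1.0                 # all mass in [3.5, 4.0]
samples = sample_pdf_1d(bins, weights, 1000, rng=np.random.default_rng(0))
print(np.mean((samples >= 3.5) & (samples <= 4.0)))  # close to 1.0
```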

In [None]:
def train_hierarchical():
    criterion = nn.MSELoss()
    coarse.train()
    fine.train()
    
    loss_history = []
    epoch_bar = tqdm(range(num_epochs), desc="Epochs")
    
    for epoch in epoch_bar:
        total_loss = 0
        
        for rays_o, rays_d, targets in tqdm(train_loader, desc="Training", leave=False):
            rays_o, rays_d, targets = rays_o.to(device), rays_d.to(device), targets.to(device)
            
            N_c = 64
            N_f = 128
            near, far = 2.0, 6.0
            
            z_vals = torch.linspace(near, far, steps=N_c, device=device)
            z_vals = z_vals.repeat(rays_o.shape[0], 1)

            m_rand = (torch.rand(z_vals.shape, device=device) - 0.5) * (far - near) / N_c
            z_vals = z_vals + m_rand

            pts = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(-1)
            
            pts_flat = pts.reshape(-1, 3)
            dirs_flat = rays_d.unsqueeze(1).expand(pts.shape).reshape(-1, 3)
            
            sigma_c, rgb_c = coarse(pts_flat, dirs_flat)
            raw_c = torch.cat([rgb_c, sigma_c], -1).reshape(rays_o.shape[0], N_c, 4)
            
            rgb_map_c, weights_c = volume_render(raw_c, z_vals, rays_d)

       
            z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            
            z_samples = sample_pdf(z_vals_mid, weights_c[..., :-1], N_f)
            z_samples = z_samples.detach() 
            
            z_vals_fine, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
            
            pts_fine = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals_fine.unsqueeze(-1)
            
            pts_flat_f = pts_fine.reshape(-1, 3)
            dirs_flat_f = rays_d.unsqueeze(1).expand(pts_fine.shape).reshape(-1, 3)
            
            sigma_f, rgb_f = fine(pts_flat_f, dirs_flat_f)
            raw_f = torch.cat([rgb_f, sigma_f], -1).reshape(rays_o.shape[0], N_c + N_f, 4)
            
            rgb_map_f, _ = volume_render(raw_f, z_vals_fine, rays_d)

            loss = criterion(rgb_map_c, targets) + criterion(rgb_map_f, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        loss_history.append(avg_train_loss)
        scheduler.step()
        
        if (epoch + 1) % 10 == 0:
            coarse.eval()
            fine.eval()
            total_val_loss = 0
            with torch.no_grad():
                for rays_o, rays_d, targets in val_loader:
                    rays_o, rays_d, targets = rays_o.to(device), rays_d.to(device), targets.to(device)
                    
                    # Coarse Pass
                    N_c, near, far = 64, 2.0, 6.0
                    t_vals = torch.linspace(0., 1., steps=N_c, device=device)
                    z_vals = near * (1.-t_vals) + far * (t_vals)
                    z_vals = z_vals.expand([rays_o.shape[0], N_c])
                    
                    pts = rays_o.unsqueeze(-2)+ rays_d.unsqueeze(-2)* z_vals.unsqueeze(-1)
                    pts_flat = pts.reshape(-1, 3)
                    dirs_flat = rays_d[:, None, :].expand(pts.shape).reshape(-1, 3)
                    sigma_c, rgb_c = coarse(pts_flat, dirs_flat)
                    raw_c = torch.cat([rgb_c, sigma_c], -1).reshape(rays_o.shape[0], N_c, 4)
                    rgb_map_c, weights_c = volume_render(raw_c, z_vals, rays_d)
                    
                    N_f = 128
                    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
                    z_samples = sample_pdf(z_vals_mid, weights_c[..., :-1], N_f, noise=False) 
                    z_vals_fine, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
                    
                    pts_fine = rays_o.unsqueeze(-2)+ rays_d.unsqueeze(-2)* z_vals_fine.unsqueeze(-1)
                    sigma_f, rgb_f = fine(pts_fine.reshape(-1, 3), rays_d[:, None, :].expand(pts_fine.shape).reshape(-1, 3))
                    raw_f = torch.cat([rgb_f, sigma_f], -1).reshape(rays_o.shape[0], N_c + N_f, 4)
                    rgb_map_f, _ = volume_render(raw_f, z_vals_fine, rays_d)
                    
                    loss = criterion(rgb_map_f, targets)
                    total_val_loss += loss.item()
                torch.save({
                    'coarse': coarse.state_dict(),
                    'fine': fine.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'nerf_checkpoint.pth')
                print("\nMODEL SAVED\n")

            val_psnr = -10. * torch.log10(torch.tensor(total_val_loss / len(val_loader)))
            epoch_bar.set_postfix({"Loss": f"{avg_train_loss:.4f}", "Val PSNR": f"{val_psnr:.2f}"})

            # Switch back to training mode after validation
            coarse.train()
            fine.train()

    return loss_history
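The training loop jitters evenly spaced `linspace` depths by up to half of $(far - near)/N_c$. The NeRF paper's stratified sampling instead splits $[near, far]$ into $N$ equal bins and draws one uniform sample per bin, $t_i \sim \mathcal{U}\big[t_n + \tfrac{i-1}{N}(t_f - t_n),\ t_n + \tfrac{i}{N}(t_f - t_n)\big]$. A NumPy sketch of that variant (hypothetical helper name), in case you want to match the paper more closely:

```python
import numpy as np

def stratified_samples(near, far, n_samples, n_rays, rng):
    """Hypothetical helper: one uniform draw inside each of n_samples equal bins of [near, far], per ray."""
    edges = np.linspace(near, far, n_samples + 1)  # bin edges
    lower, upper = edges[:-1], edges[1:]
    u = rng.random((n_rays, n_samples))            # uniform in [0, 1)
    return lower + (upper - lower) * u             # (n_rays, n_samples), sorted along each ray

z = stratified_samples(2.0, 6.0, 64, n_rays=4, rng=np.random.default_rng(0))
print(z.shape)  # (4, 64)
```

Because each sample stays inside its own bin, the depths come out already sorted, so no `torch.sort` is needed for the coarse pass.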
        

In [None]:
loss_history = train_hierarchical()

plt.figure(figsize=(10, 5))
plt.plot(loss_history)
plt.title("Training Loss (coarse + fine MSE)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.yscale("log")
plt.grid(True, alpha=0.3)
plt.show()
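The validation metric in `train_hierarchical` is $-10 \log_{10}(\text{MSE})$, the general PSNR formula $10 \log_{10}(\text{MAX}^2 / \text{MSE})$ with $\text{MAX} = 1$ because pixel values lie in $[0, 1]$. A tiny helper (hypothetical name) makes the scale concrete: every tenfold drop in MSE adds 10 dB.

```python
import math

def psnr_from_mse(mse, max_val=1.0):
    """Hypothetical helper: peak signal-to-noise ratio in dB, 10 * log10(max_val^2 / mse)."""
    return 10.0 * math.log10(max_val ** 2 / mse)

print(psnr_from_mse(1e-3))  # about 30 dB
print(psnr_from_mse(1e-2))  # about 20 dB
```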

In [None]:

def pose_spherical(theta, phi, radius):
    trans_t = lambda t : torch.Tensor([
        [1,0,0,0],
        [0,1,0,0],
        [0,0,1,t],
        [0,0,0,1]]).float()

    rot_phi = lambda phi : torch.Tensor([
        [1,0,0,0],
        [0,np.cos(phi),-np.sin(phi),0],
        [0,np.sin(phi), np.cos(phi),0],
        [0,0,0,1]]).float()

    rot_theta = lambda th : torch.Tensor([
        [np.cos(th),0,-np.sin(th),0],
        [0,1,0,0],
        [np.sin(th),0, np.cos(th),0],
        [0,0,0,1]]).float()
    
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w


def render_minimal(checkpoint_path, save_path='./video.mp4'):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Rendering on {device}...")

    coarse = NeRF(L1=10, L2=4).to(device)
    fine = NeRF(L1=10, L2=4).to(device)
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    coarse.load_state_dict(checkpoint['coarse'])
    fine.load_state_dict(checkpoint['fine'])
    
    coarse.eval()
    fine.eval()

    test_dataset = LegoRayDataset(split='test')
    H, W = 800, 800
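    fov = test_dataset.fov  # horizontal FOV in radians; get_rays converts it to a focal length
    
    # 39 orbit poses: 40 evenly spaced angles in [0, 360], dropping the duplicate 360°
    poses = [pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(0, 360, 40)[:-1]]
    frames = []

    for c2w in tqdm(poses):
        c2w = c2w.to(device)
        rays_o, rays_d = get_rays(H, W, fov, c2w)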
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
        
        chunk_size = 4096 
        img_flat = []
        
        with torch.no_grad():
            for i in range(0, rays_o.shape[0], chunk_size):
                bo = rays_o[i:i+chunk_size]
                bd = rays_d[i:i+chunk_size]
                
                N_c = 64
                t = torch.linspace(0., 1., steps=N_c, device=device)
                z = 2.0 * (1.-t) + 6.0 * t
                z = z.expand([bo.shape[0], N_c])
                
                pts = bo.unsqueeze(1) + bd.unsqueeze(1) * z.unsqueeze(-1)
                
                pts_flat = pts.reshape(-1, 3)
                dirs_flat = bd.unsqueeze(1).expand(pts.shape).reshape(-1, 3)
                
                sigma_c, rgb_c = coarse(pts_flat, dirs_flat)
                raw_c = torch.cat([rgb_c, sigma_c], -1).reshape(bo.shape[0], N_c, 4)
                
                _, w_c = volume_render(raw_c, z, bd)
                
                N_f = 128
                z_mid = .5 * (z[..., 1:] + z[..., :-1])
                z_samp = sample_pdf(z_mid, w_c[..., :-1], N_f, noise=False)
                z_fine, _ = torch.sort(torch.cat([z, z_samp], -1), -1)
                
                pts_f = bo.unsqueeze(1) + bd.unsqueeze(1) * z_fine.unsqueeze(-1)
                
                pts_flat_f = pts_f.reshape(-1, 3)
                dirs_flat_f = bd.unsqueeze(1).expand(pts_f.shape).reshape(-1, 3)
                
                sigma_f, rgb_f = fine(pts_flat_f, dirs_flat_f)
                raw_f = torch.cat([rgb_f, sigma_f], -1).reshape(bo.shape[0], N_c + N_f, 4)
                
                rgb, _ = volume_render(raw_f, z_fine, bd)
                img_flat.append(rgb.cpu())

        img = torch.cat(img_flat, 0).reshape(H, W, 3)
        img = torch.clamp(img, 0, 1)
        frames.append((img.numpy() * 255).astype(np.uint8))

    imageio.mimsave(save_path, frames, fps=30, quality=8)
    print(f"Video saved to {save_path}")

In [None]:
render_minimal('./nerf_checkpoint.pth')
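`pose_spherical` places each rendering camera on a sphere around the scene. A NumPy port (hypothetical name `pose_spherical_np`, mirroring the matrices above) can check two properties of the 39 orbit poses used by `render_minimal`: every camera sits `radius` units from the origin, and its viewing axis (the $-z$ column of $\mathbf{R}$) points straight at the origin.

```python
import numpy as np

def pose_spherical_np(theta, phi, radius):
    """Hypothetical NumPy port of pose_spherical (angles in degrees)."""
    trans_t = np.eye(4)
    trans_t[2, 3] = radius
    ph, th = np.deg2rad(phi), np.deg2rad(theta)
    rot_phi = np.array([[1, 0, 0, 0],
                        [0, np.cos(ph), -np.sin(ph), 0],
                        [0, np.sin(ph),  np.cos(ph), 0],
                        [0, 0, 0, 1]])
    rot_theta = np.array([[np.cos(th), 0, -np.sin(th), 0],
                          [0, 1, 0, 0],
                          [np.sin(th), 0,  np.cos(th), 0],
                          [0, 0, 0, 1]])
    swap = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=float)
    return swap @ rot_theta @ rot_phi @ trans_t

for theta in np.linspace(0, 360, 40)[:-1]:
    c2w = pose_spherical_np(theta, -30.0, 4.0)
    t, forward = c2w[:3, 3], -c2w[:3, 2]
    # Every camera sits on a sphere of the given radius and looks straight at the origin.
    assert np.isclose(np.linalg.norm(t), 4.0)
    assert np.allclose(forward, -t / np.linalg.norm(t))
print("all 39 orbit cameras face the origin")
```

This works because every factor is a rotation about the origin, so the camera's position and its $-z$ axis are rotated together from the starting pose at $(0, 0, r)$, which looks along $-z$.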