# 04 - Soft Depth Supervision

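Train a NeRF with soft depth supervision: in addition to the RGB reconstruction loss, an MSE loss between the rendered expected depth (the weight-averaged sample depth along each ray) and the ground-truth depth is added, scaled by `lambda_depth`.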

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path

# Device string used by the cells below: the GPU when CUDA is available,
# otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Soft Depth Rendering Function

In [None]:
def render_rays_soft(nerf, rays_o, rays_d, near, far, N_samples, white_bg=True):
    """Render colour and a soft (expected) depth for a batch of rays.

    Every ray is evaluated at the same ``N_samples`` depths, evenly spaced
    between ``near`` and ``far`` (``torch.linspace``; no per-ray random jitter
    is applied). The sample positions and view directions are encoded with
    ``pe_xyz`` / ``pe_dir`` (defined outside this function), passed through
    ``nerf``, and alpha-composited along each ray.

    Args:
        nerf: Callable invoked as ``nerf(encoded_points, encoded_dirs)`` that
            returns ``(rgb, sigma)``; the outputs are reshaped to
            ``(..., N_samples, 3)`` and ``(..., N_samples)``.
        rays_o: Ray origins with a trailing dimension of 3. The training loop
            passes shape ``(B, 3)``; other leading shapes are not verified.
        rays_d: Ray directions, same shape as ``rays_o``.
        near: Depth of the first sample.
        far: Depth of the last sample.
        N_samples: Number of samples per ray.
        white_bg: When true, ``1 - sum(weights)`` is added to every colour
            channel, i.e. the result is composited onto a white background.

    Returns:
        Tuple ``(rgb_map, depth_soft)``: the composited colour per ray and the
        weight-averaged sample depth ``sum(weights * z)`` per ray. The depth is
        not divided by the total weight, so rays with little accumulated
        opacity yield values pulled toward 0.
    """
    # Sample depths shared by all rays, shape (1, N_samples).
    sample_depths = torch.linspace(near, far, N_samples, device=rays_o.device)[None, :]

    # Points along each ray: origin + direction * depth.
    origins = rays_o[..., None, :]
    directions = rays_d[..., None, :]
    samples = origins + directions * sample_depths[..., :, None]
    flat_samples = samples.reshape(-1, 3)
    # Repeat each ray's direction once per sample so it lines up with flat_samples.
    flat_dirs = rays_d.unsqueeze(1).expand_as(samples).reshape(-1, 3)

    # Encode and query the network, then restore the (ray, sample) layout.
    colour, density = nerf(pe_xyz(flat_samples), pe_dir(flat_dirs))
    per_sample_shape = samples.shape[:-1]
    colour = colour.view(*per_sample_shape, 3)
    density = density.view(*per_sample_shape)

    # Interval lengths between consecutive samples; the last sample has no
    # successor, so it is given a fixed 1e-3 interval.
    # NOTE(review): the reference NeRF code scales these intervals by
    # ||rays_d|| and uses a very large final interval; this function does
    # neither - confirm that is intended for the rays produced by get_rays.
    gaps = sample_depths[..., 1:] - sample_depths[..., :-1]
    final_gap = 1e-3 * torch.ones_like(gaps[..., :1])
    gaps = torch.cat([gaps, final_gap], dim=-1)

    # Per-sample opacity, and the transmittance accumulated *before* each sample
    # (exclusive cumulative product; the 1e-10 keeps the factors away from zero).
    opacity = 1.0 - torch.exp(-density * gaps)
    pass_through = torch.cat(
        [torch.ones_like(opacity[..., :1]), 1.-opacity + 1e-10], dim=-1
    )
    transmittance = torch.cumprod(pass_through, dim=-1)[..., :-1]
    weights = opacity * transmittance

    rgb_map = (weights[..., None] * colour).sum(dim=-2)
    if white_bg:
        leftover = 1 - weights.sum(dim=-1, keepdim=True)
        rgb_map = rgb_map + leftover

    depth_soft = (weights * sample_depths).sum(dim=-1)

    return rgb_map, depth_soft

## Training Loop

In [None]:
def to_tensor(x):
    """Convert a NumPy array to a torch tensor on the global ``device``.

    ``None`` is passed through unchanged so optional inputs can be forwarded
    without a separate check at the call site.
    """
    return None if x is None else torch.from_numpy(x).to(device)

# Load data
# imgs_train, poses_train and depths_train are not defined in this cell and
# must already exist in the session. They go through to_tensor(), which uses
# torch.from_numpy, so they are expected to be NumPy arrays.
images = to_tensor(imgs_train).float()
poses = to_tensor(poses_train).float()
# Depth maps are optional: without them `depths` stays None.
depths = to_tensor(depths_train).float() if depths_train is not None else None

# The first three axes of `images` are read as (image count, height, width).
N_imgs, Ht, Wt = images.shape[:3]
# Per-iteration logs, appended to by the training loop and saved afterwards.
psnr_history = []
loss_history = []

# Training parameters
iters = 20000
batch_rays = 1024
lr = 5e-4
# Weight of the depth term in the total loss; zero when no depth maps exist.
lambda_depth = 0.01 if depths is not None else 0.0
# `model` is defined outside this cell.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def psnr(mse):
    """Return the PSNR in dB for a mean-squared-error tensor.

    Computed as ``-10 * log10(mse)``, which is the PSNR for a peak signal
    value of 1.
    """
    log_mse = torch.log10(mse)
    return log_mse * -10.0

# Optimise `model` for `iters` steps. Each step draws one random training image
# and a random batch of its pixels, renders those rays, and minimises the RGB
# MSE plus `lambda_depth` times the depth MSE.
print('Starting soft depth training...')
pbar = tqdm(range(iters))

for it in pbar:
    # Pick one training view and build rays for every pixel of it.
    # get_rays and focal are defined outside this cell.
    i = torch.randint(0, N_imgs, (1,)).item()
    c2w = poses[i]
    rays_o, rays_d = get_rays(Ht, Wt, focal, c2w)
    target = images[i].view(-1, 3)
    
    # Draw batch_rays flat pixel indices (independently, so repeats are
    # possible) and gather the matching rays and target colours.
    idx = torch.randint(0, target.shape[0], (batch_rays,), device=device)
    rays_o_b = rays_o.view(-1, 3)[idx]
    rays_d_b = rays_d.view(-1, 3)[idx]
    target_b = target[idx]
    
    # The near/far bounds and the sample count are fixed literals here.
    rgb, depth_pred = render_rays_soft(model, rays_o_b, rays_d_b, near=2.0, far=6.0, N_samples=64)
    
    rgb_loss = F.mse_loss(rgb, target_b)
    # Default so depth_loss is always a tensor, even without depth maps.
    depth_loss = torch.tensor(0.0, device=device)
    
    if depths is not None:
        # Reuse the same flat indices so each predicted depth is compared with
        # the ground-truth depth of the same pixel.
        # NOTE(review): assumes depths[i] has the same pixel layout as
        # images[i] - confirm against the data loader.
        depth_gt_full = depths[i]
        depth_gt_b = depth_gt_full.view(-1)[idx]
        depth_loss = ((depth_pred - depth_gt_b)**2).mean()
    
    loss = rgb_loss + lambda_depth * depth_loss
    # Log every iteration; PSNR is derived from the RGB loss only.
    loss_history.append(loss.item())
    psnr_history.append(psnr(rgb_loss).item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Refresh the progress-bar text every 100 iterations; the depth term is
    # only reported when depth supervision is active.
    if (it + 1) % 100 == 0:
        if depths is not None:
            msg = f"it {it+1} | RGB: {rgb_loss.item():.5f} | Depth: {depth_loss.item():.5f} | Total: {loss.item():.5f} | PSNR: {psnr(rgb_loss).item():.2f} dB"
            pbar.set_description(msg)
        else:
            pbar.set_description(f'it {it+1} | loss {loss.item():.5f} | psnr {psnr(rgb_loss).item():.2f} dB')

# Persist the training curves and the trained weights under results/soft,
# creating the directory (and any missing parents) if needed.
save_dir = Path('results/soft')
save_dir.mkdir(parents=True, exist_ok=True)
for history_file, history in (
    ('psnr_history.npy', psnr_history),
    ('loss_history.npy', loss_history),
):
    np.save(save_dir / history_file, np.array(history))
torch.save(model.state_dict(), save_dir / 'model_soft.pth')
print(f'Soft depth training complete! Results saved to {save_dir}')