# 06 - Hybrid Strategy

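Combines a soft depth loss (MSE between rendered and ground-truth depth) with hard depth supervision: depth-guided ray sampling plus free-space and surface-concentration losses.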

## Hybrid Rendering Function

In [None]:
import torch
import torch.nn.functional as F

def _hybrid_z_vals(N, near, far, N_samples, depth_gt, depth_mask, eps, device):
    """Return sorted per-ray sample depths of shape (N, N_samples) for render_rays_hybrid."""
    if depth_gt is None or depth_mask is None or not bool(depth_mask.any()):
        # No usable depth: N_samples evenly spaced samples over [near, far].
        t_vals = torch.linspace(0., 1., steps=N_samples, device=device)
        z_vals = near * (1. - t_vals) + far * t_vals
        return z_vals[None, :].repeat(N, 1)

    N_guided = N_samples // 2
    N_uniform = N_samples - N_guided

    # First part (ceil of half): evenly spaced over [near, far] for every ray.
    t_uniform = torch.linspace(near, far, N_uniform, device=device)
    t_uniform = t_uniform[None, :].expand(N, -1)

    # Second part (floor of half): evenly spaced over a per-ray window [lo, hi].
    # Masked rays use [depth - eps/2, depth + eps/2]; both ends are clipped into
    # [near, far] so an out-of-range depth cannot create a reversed window or
    # samples outside the bounds (for near <= depth <= far this equals the
    # original max(near, .) / min(far, .) clipping). Unmasked rays use the full
    # [near, far] range instead.
    s = torch.linspace(0., 1., N_guided, device=device)
    mask = depth_mask.reshape(N).to(device=device, dtype=torch.bool)
    center = depth_gt.reshape(N).to(device=device, dtype=s.dtype)
    lo = (center - eps / 2).clamp(near, far)
    hi = (center + eps / 2).clamp(near, far)
    lo = torch.where(mask, lo, torch.full_like(lo, near))
    hi = torch.where(mask, hi, torch.full_like(hi, far))
    z_guided = lo[:, None] * (1. - s) + hi[:, None] * s

    # Merge both sets so consecutive differences (deltas) are non-negative.
    z_vals = torch.cat([t_uniform, z_guided], dim=1)
    z_vals, _ = torch.sort(z_vals, dim=1)
    return z_vals


def render_rays_hybrid(nerf, rays_o, rays_d, near, far, N_samples,
                       depth_gt=None, depth_mask=None, eps=0.3,
                       white_bg=True, return_extras=False):
    """Volume-render a batch of rays, optionally concentrating samples near known depth.

    Sampling: when ``depth_gt`` and ``depth_mask`` are given and at least one
    ray is masked in, each ray gets ``N_samples - N_samples // 2`` samples
    evenly spaced over ``[near, far]``. It also gets ``N_samples // 2`` samples
    evenly spaced over a window of width ``eps`` centred on its ground-truth
    depth (clipped to ``[near, far]``); rays with a False mask get those
    samples over ``[near, far]`` instead. The combined samples are sorted.
    Otherwise all ``N_samples`` are evenly spaced over ``[near, far]``.

    Args:
        nerf: callable ``nerf(x, d) -> (rgb, sigma)`` evaluated on the
            flattened encoded points/directions; outputs are reshaped to
            ``(N_rays, S, 3)`` and ``(N_rays, S)``.
        rays_o: ray origins, indexed as ``(N_rays, C)``.
        rays_d: ray directions, same layout as ``rays_o``.
        near: lower sampling bound along the ray parameter ``t``.
        far: upper sampling bound along the ray parameter ``t``.
        N_samples: number of samples per ray.
        depth_gt: optional ground-truth depth, one value per ray.
        depth_mask: optional per-ray flag; True where ``depth_gt`` should guide sampling.
        eps: width of the depth-guided sampling window.
        white_bg: if True, composite onto white using the accumulated opacity.
        return_extras: if True, also return ``weights``, ``z_vals`` and ``sigma``.

    Returns:
        ``(rgb_map, depth_map)``, with ``rgb_map`` clamped to ``[0, 1]``, or
        ``(rgb_map, depth_map, weights, z_vals, sigma)`` when ``return_extras``.

    Note:
        Relies on the notebook-level encoders ``pe_xyz`` and ``pe_dir``.
    """
    N = rays_o.shape[0]
    z_vals = _hybrid_z_vals(N, near, far, N_samples, depth_gt, depth_mask,
                            eps, rays_o.device)

    # Sample points o + t * d along each ray: (N_rays, S, C).
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., None]

    # Directions are normalized before encoding; points are encoded as-is.
    dirs = F.normalize(rays_d, dim=-1)
    dirs_enc = pe_dir(dirs)
    pts_enc = pe_xyz(pts)

    # Flatten to one row per sample and broadcast each ray's direction to its samples.
    N_rays, S = pts_enc.shape[:2]
    x = pts_enc.reshape(N_rays * S, -1)
    d = dirs_enc[:, None, :].expand(N_rays, S, -1).reshape(N_rays * S, -1)

    rgb, sigma = nerf(x, d)
    rgb = rgb.view(N_rays, S, 3)
    sigma = sigma.view(N_rays, S)

    # Interval lengths between samples; the last interval is effectively
    # infinite so the final sample absorbs any remaining transmittance.
    # NOTE(review): deltas are in units of t and are not multiplied by
    # ||rays_d||; this matches world distance only if rays_d is unit length —
    # confirm against get_rays.
    deltas = z_vals[:, 1:] - z_vals[:, :-1]
    deltas = torch.cat([deltas, 1e10 * torch.ones_like(deltas[:, :1])], -1)
    alpha = 1. - torch.exp(-sigma * deltas)
    # Exclusive cumulative product of (1 - alpha): transmittance before each
    # sample; 1e-10 guards against an exact zero wiping out later terms.
    T = torch.cumprod(
        torch.cat([torch.ones((N_rays, 1), device=alpha.device), 1. - alpha + 1e-10], -1),
        -1
    )[:, :-1]
    weights = alpha * T

    rgb_map = (weights[..., None] * rgb).sum(dim=1)
    # Expected depth; not normalized by accumulated opacity.
    depth_map = (weights * z_vals).sum(dim=1)

    if white_bg:
        # Fill un-accumulated opacity with white.
        acc_map = weights.sum(dim=1, keepdim=True)
        rgb_map = rgb_map + (1. - acc_map)

    if return_extras:
        return rgb_map.clamp(0, 1), depth_map, weights, z_vals, sigma
    else:
        return rgb_map.clamp(0, 1), depth_map

## Training Loop

In [None]:
import os
import numpy as np
import torch
from tqdm.auto import tqdm
from pathlib import Path

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def to_tensor(x):
    """Convert a NumPy array to a torch tensor on ``device``; ``None`` passes through.

    Uses ``torch.from_numpy``, so the array's dtype is kept and ``x`` must be a
    NumPy array (or ``None``).
    """
    return None if x is None else torch.from_numpy(x).to(device)

# Training split as float tensors on the target device.
images = to_tensor(imgs_train).float()
poses = to_tensor(poses_train).float()
if depths_train is None:
    depths = None
else:
    depths = to_tensor(depths_train).float()

# Number of views and image height/width.
N_imgs, Ht, Wt = images.shape[:3]

# Metric histories, saved to disk after training.
psnr_history, loss_history = [], []

# Optimisation settings.
iters, batch_rays, lr = 20000, 1024, 5e-4
N_samples = 64

# Weights for the soft depth MSE and for the hard (free-space + surface) terms.
lambda_soft, lambda_hard = 0.01, 0.005

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def psnr(mse):
    """Return PSNR in dB for a mean-squared-error tensor, assuming a peak value of 1.0."""
    return mse.log10().mul(-10.0)

print('Starting hybrid training...')
pbar = tqdm(range(iters))

for it in pbar:
    # Pick one training view and a random batch of its rays/pixels.
    i = torch.randint(0, N_imgs, (1,)).item()
    c2w = poses[i]
    rays_o, rays_d = get_rays(Ht, Wt, focal, c2w)
    target = images[i].view(-1, 3)

    idx = torch.randint(0, target.shape[0], (batch_rays,), device=device)
    rays_o_b = rays_o.view(-1, 3)[idx]
    rays_d_b = rays_d.view(-1, 3)[idx]
    target_b = target[idx]

    if depths is not None:
        depth_gt_full = depths[i]
        depth_gt_b = depth_gt_full.view(-1)[idx]
        # Trust ground-truth depth only strictly inside the render bounds
        # (near=2.0, far=6.0 in the call below).
        depth_mask_b = (depth_gt_b > 2.0) & (depth_gt_b < 6.0)
    else:
        depth_gt_b = None
        depth_mask_b = None

    rgb_pred, depth_pred, weights, z_vals, sigma = render_rays_hybrid(
        model, rays_o_b, rays_d_b,
        near=2.0, far=6.0, N_samples=N_samples,
        depth_gt=depth_gt_b, depth_mask=depth_mask_b,
        eps=0.3, return_extras=True
    )

    rgb_loss = F.mse_loss(rgb_pred, target_b)

    if depths is not None:
        # Soft depth MSE only over rays whose depth passed the mask, matching
        # the masked hard losses below. Without this, out-of-range ground-truth
        # depths (e.g. background, if stored as 0 or inf) would dominate or
        # poison the loss. An empty mask yields 0 instead of NaN.
        valid = depth_mask_b
        sq_err = (depth_pred[valid] - depth_gt_b[valid]) ** 2
        soft_loss = sq_err.sum() / valid.sum().clamp(min=1)
        freespace = free_space_loss(sigma, z_vals, depth_gt_b, depth_mask_b)
        surface = surface_concentration_loss(weights, z_vals, depth_gt_b, depth_mask_b)
        loss = rgb_loss + lambda_soft * soft_loss + lambda_hard * (freespace + surface)
    else:
        soft_loss = torch.tensor(0.0, device=device)
        loss = rgb_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record per-iteration metrics; previously these lists were saved empty.
    loss_val = loss.item()
    psnr_val = psnr(rgb_loss.detach()).item()
    loss_history.append(loss_val)
    psnr_history.append(psnr_val)

    if (it + 1) % 100 == 0:
        if depths is not None:
            msg = f"it {it+1} | RGB: {rgb_loss.item():.5f} | Soft: {soft_loss.item():.5f} | PSNR: {psnr_val:.2f} dB"
            pbar.set_description(msg)
        else:
            pbar.set_description(f'it {it+1} | loss {loss_val:.5f}')

save_dir = Path('results/hybrid')
save_dir.mkdir(parents=True, exist_ok=True)
np.save(save_dir / 'psnr_history.npy', np.array(psnr_history))
np.save(save_dir / 'loss_history.npy', np.array(loss_history))
torch.save(model.state_dict(), save_dir / 'model_hybrid.pth')
print(f'Hybrid training complete! Results saved to {save_dir}')