In [None]:
# Allow importing from src
import sys
sys.path.insert(0, '../src/')

In [None]:
from nerf import LNeRF
from utils.rays import create_rays, render_rays
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import make_grid

# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
# The MPS check uses torch.backends.mps, the documented availability API;
# torch.mps.is_available is not present in every PyTorch release.
COMPUTE_DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    COMPUTE_DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    COMPUTE_DEVICE = torch.device('mps')
print(f"{COMPUTE_DEVICE=}")


def float_to_rad(f):
    """Convert an angle given in degrees to a float32 tensor holding radians."""
    return torch.deg2rad(torch.tensor(f, dtype=torch.float32))

# Setup

## Model loading

In [None]:
# Locations of the trained run's checkpoint and the hyper-parameter file saved next to it.
CHECKPOINT_PATH = "../lightning_logs/shurtape200x200_decay=1e-06_exp/checkpoints/best_val_psnr_epoch=9.ckpt"
HPARAMS_PATH = "../lightning_logs/shurtape200x200_decay=1e-06_exp/hparams.yaml"

# Restore the LNeRF module with its weights mapped onto the selected device.
model = LNeRF.load_from_checkpoint(
    CHECKPOINT_PATH,
    hparams_file=HPARAMS_PATH,
    map_location=COMPUTE_DEVICE,
)
# The notebook only runs inference, so the module is frozen right after loading.
model.freeze()

# Focal length stored in the run's hyper-parameters, as a float32 tensor.
focal = torch.tensor(model.hparams.focal, dtype=torch.float32)

## Rendering parameter calculation

In [None]:
# In accordance with mitsuba's conventions
def look_at(radius: float, theta: Tensor, phi: Tensor, target: Tensor = torch.tensor([0,0,0], dtype=torch.float32)):
    """Build a 4x4 camera-to-world matrix for a camera placed on a sphere.

    The camera position is given in spherical coordinates around the world
    origin (z is the polar axis); the frame is oriented relative to ``target``.

    Args:
        radius: Distance of the camera from the world origin.
        theta: Polar angle in radians (0-d tensor), measured from the +z axis.
        phi: Azimuth in radians (0-d tensor), measured from the +x axis.
        target: Point the camera frame is oriented towards.

    Returns:
        A 4x4 tensor whose first three columns are the unit ``right``, ``up``
        and ``forward`` axes and whose last column is the camera position.
        ``forward`` points from ``target`` to the camera.
    """
    origin = torch.tensor([
        radius * torch.sin(theta) * torch.cos(phi),
        radius * torch.sin(theta) * torch.sin(phi),
        radius * torch.cos(theta),
    ])

    forward = F.normalize(origin - target, p=2.0, dim=0)
    world_up = torch.tensor([0, 0, 1], dtype=torch.float32)
    right = torch.cross(world_up, forward, dim=0)
    if torch.linalg.vector_norm(right) < 1e-8:
        # The view direction is parallel to world-up (camera exactly on the
        # polar axis), so the cross product vanishes and would give an all-zero
        # rotation. Fall back to the limit of the regular formula as theta -> 0,
        # which is (-sin(phi), cos(phi), 0) and is perpendicular to forward.
        right = torch.stack([-torch.sin(phi), torch.cos(phi), torch.zeros_like(phi)]).to(forward)
    right = F.normalize(right, p=2.0, dim=0)
    up = F.normalize(torch.cross(forward, right, dim=0), p=2.0, dim=0)

    return torch.tensor([
        [right[0], up[0], forward[0], origin[0]],
        [right[1], up[1], forward[1], origin[1]],
        [right[2], up[2], forward[2], origin[2]],
        [0, 0, 0, 1],
    ])

# Sanity check: display the camera-to-world matrix for radius 4, theta = 45 deg, phi = 10 deg.
look_at(4, float_to_rad(45), float_to_rad(10))

In [None]:
def intrinsic(focal: Tensor, size: float):
    """Return the 3x3 float32 pinhole intrinsic matrix for a square image.

    Args:
        focal: Single-element tensor with the focal length, used for both axes.
        size: Image side length; the principal point is ``size // 2`` on both axes.
    """
    focal_length = focal.item()
    principal_point = size // 2

    matrix = torch.eye(3, dtype=torch.float32)
    matrix[0, 0] = focal_length
    matrix[1, 1] = focal_length
    matrix[0, 2] = principal_point
    matrix[1, 2] = principal_point
    return matrix

In [None]:
# Render one 200x200 view with the camera at radius 4, a hair off the +z pole
# (theta = 1e-4 deg rather than exactly 0).
pole_pose = look_at(4, float_to_rad(1e-4), float_to_rad(0))
c2w = pole_pose.to(COMPUTE_DEVICE)
img = model.render_image(200, 200, c2w, focal.to(COMPUTE_DEVICE))

# Clamp to the displayable [0, 1] range before showing the image.
displayable = img.cpu().clamp(0,1)
plt.imshow(displayable)
plt.axis('off')

# Point cloud using volume rendering depths

In [None]:
img_shape = (200, 200)
# Camera angles in degrees, paired element-wise: two near-polar views
# (theta ~ 0 and ~ 180) plus four views at theta = 90 with phi = 0/90/180/270.
thetas = [1e-4, 90, 90, 90, 90, 180 - 1e-4]
phis = [0, 0, 90, 180, 270, 0]

# One (origins, directions) pair per view, all cameras at radius 4.
per_view_rays = [
    create_rays(img_shape[0], img_shape[1], intrinsic(focal, 200), look_at(4, float_to_rad(theta), float_to_rad(phi)))
    for theta, phi in zip(thetas, phis)
]

# Merge the two leading dims of every view, then stack all views into one ray list.
origins = torch.cat([view_origins.flatten(0, 1) for view_origins, _ in per_view_rays], dim=0)
directions = torch.cat([view_dirs.flatten(0, 1) for _, view_dirs in per_view_rays], dim=0)

# Query the model in batches of 2048 rays.
od_loader = DataLoader(TensorDataset(origins, directions), batch_size=2**11)

batch_outputs = []
with torch.no_grad():
    for o, d in od_loader:
        batch_outputs.append(
            model.compute_along_rays(o.to(COMPUTE_DEVICE), d.to(COMPUTE_DEVICE), coarse_samples=64, fine_samples=128)
        )

# Regroup the per-batch 4-tuples into four tensors concatenated along the ray axis.
coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = (
    torch.cat(parts, dim=0) for parts in zip(*batch_outputs)
)

In [None]:
# Composite the fine samples of every ray into colour, depth and accumulated opacity.
rgb, depth, acc, alpha, weights = render_rays(fine_rgbs, fine_depths, far=model.hparams.far)

# Min-max rescale the accumulated opacity so it can serve as an alpha channel.
rescaled_acc = acc - acc.min()
rescaled_acc /= rescaled_acc.max()

rgba = torch.cat([rgb.cpu(), rescaled_acc.cpu()], dim=-1).clamp(0,1)
images = rgba.reshape((-1,) + img_shape + (4,))
image_grid = make_grid(images.permute(0, 3, 1, 2), 3).permute(1, 2, 0)

# clone(): when `depth` already lives on the CPU, .cpu() returns the same tensor and
# reshape can return a view, so the in-place masking below would silently overwrite
# `depth`, which the point-cloud cell reads afterwards.
depth_map = depth.reshape((-1,) + img_shape + (1,)).cpu().clone()
# Depths at or below the near plane are pushed to the far plane for display.
depth_map[depth_map <= model.hparams.near] = model.hparams.far
# Normalise depth to [0, 1] over the near/far range and replicate it to three channels.
depth_imgs = (depth_map.expand((-1, -1, -1, 3)) - model.hparams.near) / (model.hparams.far - model.hparams.near)
# Add a fully opaque alpha channel so the depth grid matches the RGBA image grid.
depth_imgs = torch.cat([depth_imgs, torch.ones_like(depth_map)], dim=-1)
depth_grid = make_grid(depth_imgs.permute(0, 3, 1, 2), 3).permute(1, 2, 0)

# Colour grid and depth grid side by side.
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
ax.imshow(torch.cat([image_grid, depth_grid], dim=1))
ax.axis('off')
plt.show()

# torch.min(rescaled_acc), torch.max(rescaled_acc), torch.quantile(rescaled_acc, 0.4)

In [None]:
def farthest_point_sampling(points, num_samples, return_indices=False, generator=None):
    """Greedily pick ``num_samples`` points that are spread far apart.

    Starting from a random point, each step selects the point whose distance to
    the closest already-selected point is largest.

    Args:
        points: ``(N, D)`` floating-point tensor of candidate points.
        num_samples: Number of points to select. If it exceeds ``N``, the extra
            picks repeat already-selected points.
        return_indices: When True, also return the selected row indices.
        generator: Optional ``torch.Generator`` used to draw the starting point,
            for reproducible results. The global RNG is used when None.

    Returns:
        The ``(num_samples, D)`` selected points, or a ``(points, indices)``
        tuple when ``return_indices`` is True.
    """
    num_points = points.shape[0]
    # Keep the bookkeeping tensors on the same device as the points so that
    # torch.minimum below does not mix devices.
    sampled_indices = torch.zeros(num_samples, dtype=torch.long, device=points.device)

    # Nothing to select: skip the seeding step, which would index an empty tensor.
    if num_samples > 0:
        distances = torch.full((num_points,), float('inf'), dtype=points.dtype, device=points.device)

        # Start with a random point
        sampled_indices[0] = torch.randint(0, num_points, (1,), generator=generator).item()

        for i in range(1, num_samples):
            last_selected = points[sampled_indices[i - 1]]
            dist = torch.linalg.vector_norm(points - last_selected, dim=1)
            # Running distance of every point to its nearest selected point.
            distances = torch.minimum(distances, dist)
            sampled_indices[i] = torch.argmax(distances)

    if return_indices:
        return points[sampled_indices], sampled_indices

    return points[sampled_indices]

# Keep rays whose rendered depth lies beyond the near plane and whose rescaled
# accumulated opacity (alpha channel of `rgba`) is above 0.98.
mask = (depth > model.hparams.near).cpu() & (rgba[..., -1] > 0.98)

# Back-project every kept ray to a 3D point: origin + depth * direction.
points = origins.cpu()[mask] + depth.unsqueeze(-1).cpu()[mask] * directions.cpu()[mask]
# Thin the cloud to 10,000 well-spread points and fetch the matching colours.
points, sampled_indices = farthest_point_sampling(points, 10_000, True)
colors = rgba[mask][sampled_indices]

# Order points far-to-near as seen from a viewpoint at the plot's azimuth (110 deg),
# so that nearer points are drawn last.
og = look_at(2, float_to_rad(70), float_to_rad(110))[:3, -1]
sorting = torch.argsort(torch.norm(og - points, p="fro", dim=-1), descending=True)

fig, ax = plt.subplots(1, 1, figsize=(8,8), subplot_kw={"projection": "3d"})
# A single scatter call instead of one ax.plot call (and one artist) per point.
sorted_points = points[sorting].numpy()
ax.scatter(
    sorted_points[:, 0], sorted_points[:, 1], sorted_points[:, 2],
    c=colors[sorting].numpy(), s=1, marker='o', linewidths=0,
    depthshade=False,  # keep the rendered colours unmodified
)
ax.view_init(20, 110, 0)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)

points.shape

# Point cloud using density-gradient-based ray exploration

## Ray and depth visualization

In [None]:
img_shape = (200, 200)
# Camera angles in degrees for a single side view.
theta, phi = 90, 0

# Rays for the whole image, with the two leading dims merged into one ray axis.
camera_pose = look_at(4, float_to_rad(theta), float_to_rad(phi))
origins, directions = create_rays(img_shape[0], img_shape[1], intrinsic(focal, 200), camera_pose)
origins = origins.flatten(0, 1)
directions = directions.flatten(0, 1)

# Query the model in batches of 2048 rays.
od_loader = DataLoader(TensorDataset(origins, directions), batch_size=2**11)

batch_outputs = []
with torch.no_grad():
    for o, d in od_loader:
        batch_outputs.append(
            model.compute_along_rays(o.to(COMPUTE_DEVICE), d.to(COMPUTE_DEVICE), coarse_samples=64, fine_samples=128)
        )

# Regroup the per-batch 4-tuples into four tensors concatenated along the ray axis.
coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = (
    torch.cat(parts, dim=0) for parts in zip(*batch_outputs)
)

# Composite the fine samples into per-ray colour, depth and accumulated opacity.
rgb, depth, acc, alpha, weights = render_rays(fine_rgbs, fine_depths, far=model.hparams.far)

# Min-max rescale the accumulated opacity and use it as the alpha channel.
rescaled_acc = acc - acc.min()
rescaled_acc /= rescaled_acc.max()

rgba = torch.cat([rgb.cpu(), rescaled_acc.cpu()], dim=-1).clamp(0,1)
plt.imshow(rgba.reshape(img_shape + (4,)))
plt.show()

In [None]:
def colored_line(ax, x, y, c, **kwargs):
    """Draw the polyline ``(x, y)`` on ``ax`` one segment at a time, each in its own colour.

    Segment ``i`` joins points ``i`` and ``i + 1`` and uses colour ``c[i]``.
    A colour whose components are all above 0.9 is replaced by light grey
    ``(0.9, 0.9, 0.9)``. Extra keyword arguments are forwarded to ``ax.plot``.
    """
    palette = c.numpy() if isinstance(c, torch.Tensor) else c

    segment_count = len(x) - 1
    for start in range(segment_count):
        stop = start + 2
        if (palette[start] <= 0.9).any():
            segment_color = palette[start]
        else:
            # Every component exceeds 0.9: draw in light grey instead.
            segment_color = np.array([0.9, 0.9, 0.9])
        ax.plot(x[start:stop], y[start:stop], color=segment_color, **kwargs)

In [None]:
img_shape = (200, 200)
theta, phi = 90, 0
# Pixel index used for both image axes when picking the ray to inspect.
idx = 85
origins, directions = create_rays(img_shape[0], img_shape[1], intrinsic(focal, 200), look_at(4, float_to_rad(theta), float_to_rad(phi)))
# Slice (rather than index) the first axis so the single ray keeps a leading batch dim.
origin, direction = origins[idx:idx+1, idx], directions[idx:idx+1, idx]

# Densely sample this one ray: 256 coarse and 256 fine samples.
with torch.no_grad():
    coarse_rgbs, coarse_depths, fine_rgbs, fine_depths = model.compute_along_rays(origin.to(COMPUTE_DEVICE), direction.to(COMPUTE_DEVICE), coarse_samples=256, fine_samples=256)
rgb, depth, acc, alpha, weights = render_rays(fine_rgbs, fine_depths, far=model.hparams.far)

# CPU copies of the per-sample series. Channel 3 is plotted on the y axis
# (presumably the density; confirm against compute_along_rays), channels 0-2 colour the line.
coarse_x = coarse_depths[0].cpu()
coarse_y = coarse_rgbs[0, ..., 3].cpu()
coarse_colors = coarse_rgbs[0, ..., :3].cpu()
fine_x = fine_depths[0].cpu()
fine_y = fine_rgbs[0, ..., 3].cpu()
fine_colors = fine_rgbs[0, ..., :3].cpu()

# Overview of the coarse samples over the full ray.
fig, ax = plt.subplots(1, 1, figsize=(14, 4))
ax.set_title("Colored coarse samples for full ray")
colored_line(ax, coarse_x, coarse_y, coarse_colors, linewidth=4)
plt.show()

# Four panels sharing the same zoomed depth window.
fig, axs = plt.subplots(1, 4, figsize=(20, 4))
for panel, panel_title in zip(axs, ("Coarses", "Fines", "Alphas", "Weights")):
    panel.set_title(panel_title)

colored_line(axs[0], coarse_x, coarse_y, coarse_colors)
colored_line(axs[1], fine_x, fine_y, fine_colors)
axs[2].plot(fine_x, alpha[0].cpu())
axs[3].plot(fine_x, weights[0].cpu())

# Mark the rendered depth in every panel and zoom to depths 3.0-3.5.
for ax in axs.flatten():
    ax.axvline(depth.item(), color="red")
    ax.set_xlim(3.0, 3.5)

## Gradient-based ray exploration

In [None]:
def sample_ray_uniformally(origins: Tensor, directions: Tensor, near: float, far: float,
                           num_samples: int) -> tuple[Tensor, Tensor, Tensor]:
    """Sample ``num_samples`` evenly spaced points between ``near`` and ``far`` on every ray.

    Args:
        origins: Ray origins; the first dim is the ray count, the last the coordinates.
        directions: Ray directions, same layout as ``origins``.
        near: Depth of the first sample.
        far: Depth of the last sample.
        num_samples: Number of samples per ray.

    Returns:
        ``(points, directions, depths)``: the sampled positions
        ``origin + depth * direction``, the directions expanded to the shape of
        ``points``, and the per-ray depths. ``depths`` takes part in autograd and
        retains its gradient, so ``depths.grad`` is available after a backward pass.
    """
    ray_count = origins.shape[0]
    shared_depths = torch.linspace(
        near, far, num_samples,
        dtype=torch.float32, device=origins.device, requires_grad=True,
    )
    # The expanded tensor is not a leaf, so its gradient must be retained explicitly.
    depths = shared_depths.expand(ray_count, -1)
    depths.retain_grad()

    offsets = directions.unsqueeze(-2) * depths.unsqueeze(-1)
    points = origins.unsqueeze(-2) + offsets
    per_point_dirs = directions.unsqueeze(-2).expand(points.shape)
    return points, per_point_dirs, depths

origin, direction = origin.to(COMPUTE_DEVICE), direction.to(COMPUTE_DEVICE)

# Evaluate sigma at 1024 evenly spaced depths along the selected ray.
points, point_dirs, depths = sample_ray_uniformally(origin, direction, model.hparams.near, model.hparams.far, num_samples=2**10)
sigma = model.nerf(points, point_dirs, skip_colors=True)
# Back-propagate with a gradient of ones so that depths.grad holds d(sigma)/d(depth) per sample.
sigma.backward(torch.ones_like(sigma))
sigma = sigma.detach()

# Rescale the gradient curve to sigma's peak so both fit on one axis.
grads = depths.grad[0]
sigma_peak = torch.max(sigma)
display_grads = grads / torch.max(torch.abs(grads)) * sigma_peak

depths = depths.detach()
depth_axis = depths[0].cpu()

plt.plot(depth_axis, sigma[0, ..., 0].cpu(), label="sigma")
plt.plot(depth_axis, display_grads.cpu(), label="depth grad rescaled")
plt.plot(depth_axis, torch.zeros_like(depth_axis), label="zeroline", color="red")
plt.xlabel("depth")
# Zoom into the depth window 3.0-3.4.
plt.xlim(3.00, 3.4)
plt.ylim(-5.0, sigma_peak.item())
plt.legend()

In [None]:
sigma_limit = 5.0  # When sigma is larger than this, we use gradient ascent
gamma = 2e-5  # Step size multiplier (analogous to learning rate for gradient descent)
non_grad_step_size = 3e-2  # Step size when sigma is below limit
grad_epsilon = 5e-2  # If the magnitude of the gradient falls below this, we consider the depth for the surface point to be found
max_iters = 300  # Limit for iteration count

# Per-iteration history; entries left at +inf mark iterations that never ran.
depths = torch.full((max_iters,), torch.inf, dtype=torch.float32)
sigmas = torch.full((max_iters,), torch.inf, dtype=torch.float32)
grads = torch.full((max_iters,), torch.inf, dtype=torch.float32)

verbose = True
if verbose:
    print(f"{'depth':^8} | {'sigma':^9} | {'grad_step?':^10} | {'grad':^10} | {'step_size':^10}")

# March along the single ray selected earlier (`origin`, `direction`), starting at the near plane.
depth = torch.full((1,), model.hparams.near, dtype=torch.float32, requires_grad=True, device=COMPUTE_DEVICE)
iters, stop = 0, False
while not stop and depth < model.hparams.far and iters < max_iters:
    # delta_sigma/delta_depth calculated
    sigma = model.nerf(origin + depth * direction, direction, skip_colors=True)
    sigma.backward(torch.ones_like(sigma))  # Backward without target function

    # fixed step size if sigma is below a limit (aka. empty space areas)
    grad_step = (sigma > sigma_limit).all()
    grad = depth.grad
    step_size = grad * gamma
    step_size[~grad_step] = non_grad_step_size

    depths[iters] = depth.item()
    sigmas[iters] = sigma.item()
    grads[iters] = grad.item()

    if verbose:
        print(f"{depth.item():8.6f} | {sigma.item():9.5f} | {str(grad_step.item()):<10} | {grad.item():10.5f} | {step_size.item():10.7f}")

    # approximated local maximum specified by epsilon gradient.
    # Only meaningful in the gradient-ascent regime: a flat gradient while sigma is
    # still below the limit is empty space, not a surface.
    if bool(grad_step) and bool(grad.abs() < grad_epsilon):
        stop = True

    # Re-create depth as a fresh leaf so each backward pass only covers the current
    # step instead of an autograd chain that grows with every iteration.
    depth = (depth + step_size).detach().requires_grad_(True)
    iters += 1

mask = depths != torch.inf
# Normalise by the largest magnitude (as in the sigma/gradient plot of the previous
# cell) so the sign of the curve is preserved; the clamp avoids a division by zero.
grad_scale = grads[mask].abs().max().clamp_min(1e-12)

fig, ax = plt.subplots(1, 1, figsize=(10, 4))
ax.set_xlim(depths[mask].min() * 0.999, depths[mask].max() * 1.001)
ax.set_ylim(-3, torch.max(sigmas[mask]) * 1.2)
ax.set_title(f"Local max found at depth {depths[mask][-1]:8.6f} after {iters:03d} iters")
ax.plot(depths[mask], sigmas[mask], label="sigma", marker='o')
ax.plot(depths[mask], grads[mask] / grad_scale, label="grad normalized")
ax.plot(depths[mask], torch.zeros_like(depths[mask]), label="zeroline", linestyle='--', color="red")
ax.plot(depths[mask], torch.full_like(depths[mask], sigma_limit), label="sigma limit for grad step", linestyle='--', color="purple")
ax.set_xlabel("depth")
ax.legend()