In [1]:
import viser, time  # pip install viser
import numpy as np
from data import data_loader
from utils import sample_along_rays
from utils import get_intrinsic_matrix
from rays import RaysData
from utils import transform

In [6]:

# Load training images and camera-to-world poses; the other returned values are unused here.
images_train, c2ws_train, _, _, _, focal = data_loader()
H, W = images_train.shape[1:3]
K = get_intrinsic_matrix(focal, H, W)
dataset = RaysData(images_train, K, c2ws_train, focal=focal)

# Sample 100 rays from the dataset, then sample points along each of them.
rays_o, rays_d, pixels = dataset.sample_rays(100)
points = sample_along_rays(rays_o, rays_d, perturb=True)

In [2]:
import torch

def volrend(sigmas, rgbs, step_size):
    """Composite per-sample colors along each ray with the discrete volume-rendering equation.

    With alpha_i = 1 - exp(-sigma_i * delta_i) and transmittance
    T_i = prod_{j<i} (1 - alpha_j), the rendered color of a ray is
    sum_i T_i * alpha_i * c_i.

    Args:
        sigmas: per-sample densities, shape (N, n_samples, 1).
        rgbs: per-sample colors, shape (N, n_samples, 3).
        step_size: distance between consecutive samples; a tensor of shape
            (N, n_samples, 1) or anything that broadcasts against ``sigmas``
            (e.g. a Python float).

    Returns:
        Tensor of shape (N, 3) with the composited color of each ray.
        Extra leading batch dimensions are carried through unchanged.
    """
    alpha = 1.0 - torch.exp(-sigmas * step_size)
    # Remove only the trailing singleton channel. A bare ``squeeze()`` would
    # also drop the ray axis when N == 1 (or the sample axis when
    # n_samples == 1), which breaks the indexing below.
    if alpha.dim() == rgbs.dim():
        alpha = alpha.squeeze(-1)
    # Exclusive cumulative product: T_0 = 1, T_i = prod_{j<i} (1 - alpha_j).
    shifted = torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha], dim=-1)[..., :-1]
    transmittance = torch.cumprod(shifted, dim=-1)
    weights = alpha * transmittance
    # Weighted sum over the sample axis.
    return torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)

def pos_encoding(L, x, include_input=False):
    """Sinusoidal positional encoding that expands the dimensionality of coordinates.

    Each coordinate c is multiplied by the frequencies 2^k * pi for
    k = 0..L-1, and both sin and cos are taken:

        [sin(2^0 pi c), ..., sin(2^(L-1) pi c), cos(2^0 pi c), ..., cos(2^(L-1) pi c)]

    The per-coordinate blocks are laid out one coordinate after another.

    Args:
        L: number of frequency bands.
        x: coordinates, shape [..., D] (e.g. [N, 3]).
        include_input: if True, prepend the raw coordinates ``x``, giving
            {x, sin(pi x), cos(pi x), ...}. Defaults to False, which preserves
            the historical output layout.

    Returns:
        Tensor of shape [..., 2 * L * D], or [..., D + 2 * L * D] when
        ``include_input`` is True. For [N, 3] input and L = 10 this is
        [N, 60] (or [N, 63]).
    """
    freqs = 2 ** torch.arange(L, dtype=torch.float32, device=x.device)
    scaled = x.unsqueeze(-1) * freqs * torch.pi                          # [..., D, L]
    encoded = torch.cat([scaled.sin(), scaled.cos()], dim=-1).flatten(-2)  # [..., D * 2L]
    if include_input:
        # Cast so concatenation works for integer / low-precision inputs.
        encoded = torch.cat([x.to(encoded.dtype), encoded], dim=-1)
    return encoded


In [6]:
# Sanity check of the output width: 3 coordinates x 10 bands x (sin, cos).
pos_encoding(L=10, x=torch.rand((1, 3))).shape

torch.Size([1, 60])

In [32]:
import torch

# Seeded check of volrend against reference colors: 10 rays, 64 samples per
# ray, constant step size (6.0 - 2.0) / 64 between samples.
torch.manual_seed(42)
sigmas = torch.rand((10, 64, 1))
rgbs = torch.rand((10, 64, 3))
step_size = torch.full((10, 64, 1), (6.0 - 2.0) / 64)
rendered_colors = volrend(sigmas, rgbs, step_size)

correct = torch.tensor(
    [
        [0.5006, 0.3728, 0.4728],
        [0.4322, 0.3559, 0.4134],
        [0.4027, 0.4394, 0.4610],
        [0.4514, 0.3829, 0.4196],
        [0.4002, 0.4599, 0.4103],
        [0.4471, 0.4044, 0.4069],
        [0.4285, 0.4072, 0.3777],
        [0.4152, 0.4190, 0.4361],
        [0.4051, 0.3651, 0.3969],
        [0.3253, 0.3587, 0.4215],
    ]
)
torch.allclose(rendered_colors, correct, rtol=1e-4, atol=1e-4)

True