Skip to content

Commit

Permalink
Merge pull request NVIDIAGameWorks#52 from chrischoy/cchoy/fix_densit…
Browse files Browse the repository at this point in the history
…y_rendering

 ERROR FIX: fix density rendering error for packed tracer
  • Loading branch information
Caenorst committed Sep 23, 2022
2 parents 048f267 + ebee68c commit 951f506
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 7 deletions.
28 changes: 28 additions & 0 deletions tests/test_packed_rf_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import unittest

from wisp.datasets import RandomViewDataset
from wisp.models.nefs import NeuralRadianceField
from wisp.tracers import PackedRFTracer
from wisp.models import Pipeline


class Test(unittest.TestCase):
def test_extra_channels(self):
device = "cuda:0"
nef = NeuralRadianceField(grid_type="HashGrid", multiscale_type="cat")
nef.grid.init_from_geometric(min_width=2, max_width=4, num_lods=1)
nef.to(device)
tracer = PackedRFTracer()
pipeline = Pipeline(nef, tracer)
dataset = RandomViewDataset(num_rays=128, device=device)
datum = dataset[0]
rb = pipeline(rays=datum.rays, channels=["rgb", "density"])
assert hasattr(rb, "density")
assert rb.rgb.shape[0] == rb.density.shape[0]
1 change: 1 addition & 0 deletions wisp/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@

from .sdf_dataset import SDFDataset
from .multiview_dataset import MultiviewDataset
from .random_view_dataset import RandomViewDataset
from .utils import default_collate
144 changes: 144 additions & 0 deletions wisp/datasets/random_view_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

from typing import Callable, Tuple, Union
from copy import deepcopy
import unittest
import random
import numpy as np

import torch

from kaolin.render.camera import Camera
from kaolin.render.camera.extrinsics import CameraExtrinsics
from torch.utils.data import Dataset
from wisp.utils import DotDict
from wisp.ops.raygen import generate_pinhole_rays, generate_centered_pixel_coords


def spherical_eye(
radius=1,
theta=np.pi / 3,
phi=0,
):
return torch.FloatTensor(
[
radius * np.sin(theta) * np.sin(phi),
radius * np.cos(theta),
radius * np.sin(theta) * np.cos(phi),
],
) # [3]


# TODO (cchoy): move the radius/theta initialization to kaolin CameraExtrinsics
def spherical_coord_to_pose(
radius=1, theta=np.pi / 3, phi=0, up=torch.FloatTensor([0, 1, 0])
):
"""generate camera pose from a spherical coordinate
Args:
size: batch size of generated poses.
device: where to allocate the output.
radius: camera radius
theta_range: [min, max], should be in [0, pi]
phi_range: [min, max], should be in [0, 2 * pi]
Return:
poses: [size, 4, 4] in OpenGL convention
"""

eye = spherical_eye(radius, theta, phi)

# lookat
def normalize(vec):
return torch.nn.functional.normalize(vec, dim=-1)

backward = -normalize(eye)
right = normalize(torch.cross(backward, up, dim=-1))
up = normalize(torch.cross(right, backward, dim=-1))

world_rot = torch.stack((right, up, -backward), dim=1)
world_tran = -world_rot @ eye.unsqueeze(-1)

return CameraExtrinsics._from_world_in_cam_coords(
rotation=world_rot, translation=world_tran, device="cpu", requires_grad=False
)


class RandomViewDataset(Dataset):
def __init__(
self,
# TODO(cchoy) add different random view types e.g. forward_facing, inward
n_size=100, # length of this dataset. Used to define number of iterations per epoch
view_radius_range: Tuple = (2, 4),
view_theta_range: Tuple = (np.pi / 4, np.pi / 2 - np.pi / 8),
view_phi_range: Tuple = (0, 2 * np.pi),
viewport_height: int = 320,
viewport_width: int = 320,
fov: float = 30 * np.pi / 180,
ray_dist_range: Tuple = (0.01, 8),
look_at: Tuple = (0, 0, 0),
num_rays: int = -1, # number of rays. If -1, return all rays
transform: Callable = None,
**kwargs,
):
self.n_size = n_size
self.cam = DotDict(
dict(
fov=fov,
width=viewport_width,
height=viewport_height,
)
)
self.view_radius_range = view_radius_range
self.view_theta_range = view_theta_range
self.view_phi_range = view_phi_range
self.ray_dist_range = ray_dist_range
assert len(look_at) == 3
self.look_at = look_at

self.num_rays = num_rays
self.transform = transform

def __len__(self):
"""Length of the dataset in number of rays."""
return self.n_size

def __getitem__(self, idx: int):
"""Returns a ray."""
# TODO (cchoy): uniform sphere sampling (http://corysimon.github.io/articles/uniformdistn-on-sphere/)
radius = random.uniform(*self.view_radius_range)
theta = random.uniform(*self.view_theta_range)
phi = random.uniform(*self.view_phi_range)
cam = Camera.from_args(
eye=spherical_eye(radius, theta, phi),
at=torch.tensor([0.0, 0.0, 0.0]),
up=torch.tensor([0.0, 1.0, 0.0]),
fov=self.cam.fov,
width=self.cam.width,
height=self.cam.height,
device="cpu",
)

ray_grid = generate_centered_pixel_coords(
cam.width, cam.height, cam.width, cam.height, device="cpu"
)
out = DotDict(dict(rays=generate_pinhole_rays(cam, ray_grid), cam=cam))
if self.num_rays > 0:
ray_idx = random.sample(range(len(out.rays)), self.num_rays)
out.rays = out.rays[ray_idx]

if self.transform is not None:
out = self.transform(out)

return out


class TestRandViewDataset(unittest.TestCase):
def load(self):
dataset = RandomViewDataset()
print(dataset[0])
13 changes: 7 additions & 6 deletions wisp/tracers/packed_rf_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def trace(self, nef, channels, extra_channels, rays,
color, density = nef(coords=samples, ray_d=hit_ray_d, pidx=pidx, lod_idx=lod_idx,
channels=["rgb", "density"])

timer.check("RGBA")
timer.check("RGBA")
del ridx, rays

# Compute optical thickness
Expand Down Expand Up @@ -126,7 +126,7 @@ def trace(self, nef, channels, extra_channels, rays,
color = alpha * ray_colors

rgb[ridx_hit.long()] = color

timer.check("Composit")

extra_outputs = {}
Expand All @@ -136,12 +136,13 @@ def trace(self, nef, channels, extra_channels, rays,
pidx=pidx,
lod_idx=lod_idx,
channels=channel)
ray_feats, transmittance = spc_render.exponential_integration(feats.reshape(-1, 3), tau, boundary, exclusive=True)
num_channels = feats.shape[-1]
ray_feats, transmittance = spc_render.exponential_integration(
feats.view(-1, num_channels), tau, boundary, exclusive=True
)
composited_feats = alpha * ray_feats
out_feats = torch.zeros(N, feats.shape[-1], device=feats.device)
out_feats = torch.zeros(N, num_channels, device=feats.device)
out_feats[ridx_hit.long()] = composited_feats
# TODO(ttakikawa): Right now the extra_channels are assumed to be dim 3. Think about how we can make this more generic...
assert(out_feats.shape[-1] == 3)
extra_outputs[channel] = out_feats

return RenderBuffer(depth=depth, hit=hit, rgb=rgb, alpha=out_alpha, **extra_outputs)
2 changes: 1 addition & 1 deletion wisp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.


from .helper_classes import DotDict
from .debug import PsDebugger
from .perf import PerfTimer, print_gpu_memory
21 changes: 21 additions & 0 deletions wisp/utils/helper_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
class DotDict(dict):
"""
from https://stackoverflow.com/questions/13520421/recursive-dotdict
a dictionary that supports dot notation
as well as dictionary access notation
usage: d = DotDict() or d = DotDict({'val1':'first'})
set attributes: d.val2 = 'second' or d['val2'] = 'second'
get attributes: d.val2 or d['val2']
"""

__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__

def __init__(self, dct=None):
if dct is not None:
for key, value in dct.items():
if hasattr(value, "keys"):
value = DotDict(value)
self[key] = value

0 comments on commit 951f506

Please sign in to comment.