Skip to content

Commit

Permalink
Patch failing tests and rename mark_pack_boundary to mark_pack_bounda…
Browse files Browse the repository at this point in the history
…ries (#493)

Signed-off-by: Towaki Takikawa <ttakikawa@nvidia.com>
  • Loading branch information
tovacinni committed Dec 15, 2021
1 parent 2428ae8 commit 8fe43a0
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/modules/kaolin.ops.spc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ To apply ray tracing we currently only support non-batched version, for instance
>>> point_hierarchy = kaolin.ops.spc.generate_points(octrees, pyramids, exsum)
>>> ridx, pidx, depth = kaolin.render.spc.unbatched_raytrace(octree, point_hierarchy, pyramids[0], exsum,
... origin, direction, max_level)
>>> first_hits_mask = kaolin.render.spc.mark_pack_boundary(ridx)
>>> first_hits_mask = kaolin.render.spc.mark_pack_boundaries(ridx)
>>> first_hits_point = pidx[first_hits_mask]
>>> first_hits_rgb = rgb[first_hits_point - pyramids[max_level - 2]]

Expand Down
2 changes: 1 addition & 1 deletion kaolin/csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::module render_spc = render.def_submodule("spc");
render_spc.def("raytrace_cuda", &raytrace_cuda);
render_spc.def("generate_primary_rays_cuda", &generate_primary_rays_cuda); // Deprecate soon
render_spc.def("mark_pack_boundary_cuda", &mark_pack_boundary_cuda);
render_spc.def("mark_pack_boundaries_cuda", &mark_pack_boundaries_cuda);
render_spc.def("generate_shadow_rays_cuda", &generate_shadow_rays_cuda); // Deprecate soon
render_spc.def("inclusive_sum_cuda", &inclusive_sum_cuda);
render_spc.def("diff_cuda", &diff_cuda);
Expand Down
6 changes: 3 additions & 3 deletions kaolin/csrc/render/spc/raytrace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ uint raytrace_cuda_impl(
bool return_depth,
bool with_exit);

void mark_pack_boundary_cuda_impl(
void mark_pack_boundaries_cuda_impl(
at::Tensor pack_ids,
at::Tensor boundaries);

Expand Down Expand Up @@ -230,7 +230,7 @@ std::vector<at::Tensor> raytrace_cuda(
#endif // WITH_CUDA
}

at::Tensor mark_pack_boundary_cuda(
at::Tensor mark_pack_boundaries_cuda(
at::Tensor pack_ids) {
#ifdef WITH_CUDA
at::TensorArg pack_ids_arg{pack_ids, "pack_ids", 1};
Expand All @@ -240,7 +240,7 @@ at::Tensor mark_pack_boundary_cuda(
at::checkScalarTypes(__func__, pack_ids_arg, {at::kByte, at::kChar, at::kInt, at::kLong, at::kShort});
int num_ids = pack_ids.size(0);
at::Tensor boundaries = at::zeros({num_ids}, pack_ids.options().dtype(at::kInt));
mark_pack_boundary_cuda_impl(pack_ids, boundaries);
mark_pack_boundaries_cuda_impl(pack_ids, boundaries);
return boundaries;
#else
KAOLIN_NO_CUDA_ERROR(__func__);
Expand Down
2 changes: 1 addition & 1 deletion kaolin/csrc/render/spc/raytrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::vector<at::Tensor> raytrace_cuda(
bool with_exit);


at::Tensor mark_pack_boundary_cuda(
at::Tensor mark_pack_boundaries_cuda(
at::Tensor pack_ids);

std::vector<at::Tensor> generate_shadow_rays_cuda(
Expand Down
8 changes: 4 additions & 4 deletions kaolin/csrc/render/spc/raytrace_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ subdivide_cuda_kernel(

template<typename scalar_t>
__global__ void
mark_pack_boundary_cuda_kernel(
mark_pack_boundaries_cuda_kernel(
const int64_t num,
const scalar_t* __restrict__ pack_ids,
uint* __restrict__ boundaries) {
Expand Down Expand Up @@ -546,14 +546,14 @@ uint raytrace_cuda_impl(
return cnt;
}

void mark_pack_boundary_cuda_impl(
void mark_pack_boundaries_cuda_impl(
at::Tensor pack_ids,
at::Tensor boundaries) {
int64_t num = pack_ids.size(0);
AT_DISPATCH_INTEGRAL_TYPES(pack_ids.type(), "mark_pack_boundary_cuda", ([&] {
AT_DISPATCH_INTEGRAL_TYPES(pack_ids.type(), "mark_pack_boundaries_cuda", ([&] {
const at::cuda::OptionalCUDAGuard device_guard(at::device_of(boundaries));
auto stream = at::cuda::getCurrentCUDAStream();
mark_pack_boundary_cuda_kernel<<<(num + 1023) / 1024, 1024, 0, stream>>>(
mark_pack_boundaries_cuda_kernel<<<(num + 1023) / 1024, 1024, 0, stream>>>(
num,
pack_ids.data_ptr<scalar_t>(),
reinterpret_cast<uint*>(boundaries.data_ptr<int>()));
Expand Down
18 changes: 9 additions & 9 deletions kaolin/render/spc/raytrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = [
'unbatched_raytrace',
'mark_pack_boundary',
'mark_pack_boundaries',
'mark_first_hit',
'diff',
'sum_reduce',
Expand Down Expand Up @@ -83,7 +83,7 @@ def unbatched_raytrace(octree, point_hierarchy, pyramid, exsum, origin, directio
else:
return ray_index, point_index

def mark_pack_boundary(pack_ids):
def mark_pack_boundaries(pack_ids):
r"""Mark the boundaries of pack IDs.
Pack IDs are sorted tensors which mark the ID of the pack each element belongs in.
Expand All @@ -99,17 +99,17 @@ def mark_pack_boundary(pack_ids):
first_hits (torch.BoolTensor): the boolean mask marking the boundaries.
Examples:
>>> pack_ids = torch.IntTensor([1,1,1,1,2,2,2])
>>> mark_pack_boundary(pack_ids)
tensor([1,0,0,0,1,0,0])
>>> pack_ids = torch.IntTensor([1,1,1,1,2,2,2]).to('cuda:0')
>>> mark_pack_boundaries(pack_ids)
tensor([ True, False, False, False, True, False, False], device='cuda:0')
"""
return _C.render.spc.mark_pack_boundary_cuda(pack_ids.contiguous()).bool()
return _C.render.spc.mark_pack_boundaries_cuda(pack_ids.contiguous()).bool()

def mark_first_hit(ridx):
r"""Mark the first hit in the nuggets.
.. deprecated:: 0.10.0
This function is deprecated. Use :func:`mark_pack_boundary`.
This function is deprecated. Use :func:`mark_pack_boundaries`.
The nuggets are a packed tensor containing correspondences from ray index to point index, sorted
within each ray pack by depth. This will mark true for each first hit (by depth) for a pack of
Expand All @@ -118,8 +118,8 @@ def mark_first_hit(ridx):
Returns:
first_hits (torch.BoolTensor): the boolean mask marking the first hit by depth.
"""
warnings.warn("mark_first_hit has been deprecated, please use mark_pack_boundary instead")
return mark_pack_boundary(ridx)
warnings.warn("mark_first_hit has been deprecated, please use mark_pack_boundaries instead")
return mark_pack_boundaries(ridx)

def diff(feats, boundaries):
r"""Find the delta between each of the features in a pack.
Expand Down
8 changes: 4 additions & 4 deletions tests/python/kaolin/render/spc/test_rayops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def boundaries(self):
boundary = torch.tensor([1,0,1,0,0,1], device='cuda', dtype=torch.bool)
return boundary

def test_mark_pack_boundary(self):
def test_mark_pack_boundaries(self):
ridx = torch.tensor([1,1,1,1,2,2,3,3,3], device='cuda', dtype=torch.int)

expected_boundary = torch.tensor([1,0,0,0,1,0,1,0,0], device='cuda', dtype=torch.bool)

output = spc_render.mark_pack_boundary(ridx)
output = spc_render.mark_pack_boundaries(ridx)

assert torch.equal(output, expected_boundary)

Expand Down Expand Up @@ -161,7 +161,7 @@ def test_cumprod_big(self, feats_big, boundaries_big):

def test_cumprod_big_backward(self, feats_big, boundaries_big):

feats_big += 1e-7
feats_big += 1e-3
feats_big.requires_grad = True
fdim = feats_big.shape[-1]

Expand All @@ -181,7 +181,7 @@ def test_cumprod_big_backward(self, feats_big, boundaries_big):
loss.backward()
grad1 = feats_big.grad.clone()

assert torch.allclose(grad0, grad1)
assert torch.allclose(grad0, grad1, atol=1e-2)

def test_cumprod_reverse(self, feats, boundaries):
cumprod = spc_render.cumprod(feats, boundaries, reverse=True)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/kaolin/render/spc/test_raytrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from kaolin.ops.spc import scan_octrees, generate_points, bits_to_uint8

from kaolin.render.spc import unbatched_raytrace, mark_pack_boundary
from kaolin.render.spc import unbatched_raytrace, mark_pack_boundaries

class TestRaytrace:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_mark_first_positive(self, octree, point_hierarchy, pyramid, exsum):
origin = self._generate_rays_origin(height, width, -3)
ridx, pidx = unbatched_raytrace(
octree, point_hierarchy, pyramid, exsum, origin, direction, 2, return_depth=False)
first_hits = mark_pack_boundary(ridx)
first_hits = mark_pack_boundaries(ridx)
expected_first_hits = torch.tensor([1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0],
device='cuda', dtype=torch.bool)
assert torch.equal(first_hits, expected_first_hits)
Expand All @@ -288,7 +288,7 @@ def test_mark_first_negative(self, octree, point_hierarchy, pyramid, exsum):
origin = self._generate_rays_origin(height, width, 3)
ridx, pidx = unbatched_raytrace(
octree, point_hierarchy, pyramid, exsum, origin, direction, 2, return_depth=False)
first_hits = mark_pack_boundary(ridx)
first_hits = mark_pack_boundaries(ridx)
expected_first_hits = torch.tensor([1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0],
device='cuda', dtype=torch.bool)
assert torch.equal(first_hits, expected_first_hits)
Expand Down

0 comments on commit 8fe43a0

Please sign in to comment.