Skip to content

Commit

Permalink
inside voxel bug fix (#634)
Browse files Browse the repository at this point in the history
Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
Co-authored-by: Charles Loop <cloop@nvidia.com>
  • Loading branch information
Caenorst and charlesloopNV committed Oct 18, 2022
1 parent 5456fe8 commit bcce7a9
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 59 deletions.
156 changes: 98 additions & 58 deletions kaolin/csrc/render/spc/raytrace_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
// Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -84,14 +84,12 @@ decide_cuda_kernel(
const float3* __restrict__ ray_o,
const float3* __restrict__ ray_d,
const uint2* __restrict__ nuggets,
float* depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
const uint32_t level,
const uint32_t not_done) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
const float eps = 1e-8;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
Expand All @@ -101,7 +99,59 @@ decide_cuda_kernel(
float3 d = ray_d[ridx];

// Radius of voxel
float r = 1.0 / ((float)(0x1 << level)) + eps;
float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
fmaf(r, fmaf(2.0, p.x, 1.0), -1.0f),
fmaf(r, fmaf(2.0, p.y, 1.0), -1.0f),
fmaf(r, fmaf(2.0, p.z, 1.0), -1.0f));

// Compute aux info (precompute to optimize)
float3 sgn = ray_sgn(d);
float3 ray_inv = make_float3(1.0 / d.x, 1.0 / d.y, 1.0 / d.z);

float depth = ray_aabb(o, d, ray_inv, sgn, vc, r);

if (not_done){
if (depth != 0.0)
info[tidx] = __popc(octree[pidx]);
else
info[tidx] = 0;
}
else { // at bottom
if (depth > 0.0)
info[tidx] = 1;
else
info[tidx] = 0;
}
}
}

// Overloaded version of function above that returns depth of voxel/ ray entry points
__global__ void
decide_cuda_kernel(
const uint num,
const point_data* __restrict__ points,
const float3* __restrict__ ray_o,
const float3* __restrict__ ray_d,
const uint2* __restrict__ nuggets,
float* depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
const uint32_t level) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
uint pidx = nuggets[tidx].y;
point_data p = points[pidx];
float3 o = ray_o[ridx];
float3 d = ray_d[ridx];

// Radius of voxel
float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
Expand All @@ -117,14 +167,14 @@ decide_cuda_kernel(

// Perform AABB check
if (depth[tidx] > 0.0){
// Count # of occupied voxels for expansion, if more levels are left
info[tidx] = not_done ? __popc(octree[pidx]) : 1;
info[tidx] = 1; // mark to keep
} else {
info[tidx] = 0;
}
}
}

// Overloaded version of function above that returns depth of voxel/ ray entry and exit points
__global__ void
decide_cuda_kernel(
const uint num,
Expand All @@ -135,11 +185,9 @@ decide_cuda_kernel(
float2* __restrict__ depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
const uint32_t level,
const uint32_t not_done) {
const uint32_t level) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
const float eps = 1e-8;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
Expand All @@ -149,7 +197,7 @@ decide_cuda_kernel(
float3 d = ray_d[ridx];

// Radius of voxel
float r = 1.0 / ((float)(0x1 << level)) + eps;
float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
Expand All @@ -165,8 +213,7 @@ decide_cuda_kernel(

// Perform AABB check
if (depth[tidx].x > 0.0 && depth[tidx].y > 0.0){
// Count # of occupied voxels for expansion, if more levels are left
info[tidx] = not_done ? __popc(octree[pidx]) : 1;
info[tidx] = 1; // mark to keep
} else {
info[tidx] = 0;
}
Expand Down Expand Up @@ -435,10 +482,6 @@ cumsum_reverse_cuda_kernel(
}
}

////////////////////////////////////////////////////////////////////////////////////////////////
/// CUDA Implementations
////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> raytrace_cuda_impl(
at::Tensor octree,
at::Tensor points,
Expand All @@ -455,25 +498,21 @@ std::vector<at::Tensor> raytrace_cuda_impl(

uint8_t* octree_ptr = octree.data_ptr<uint8_t>();
point_data* points_ptr = reinterpret_cast<point_data*>(points.data_ptr<short>());
uint* pyramid_ptr = (uint*)pyramid.data_ptr<int>();
uint* pyramid_sum = pyramid_ptr + max_level + 2;
uint* exclusive_sum_ptr = reinterpret_cast<uint*>(exclusive_sum.data_ptr<int>());
float3* ray_o_ptr = reinterpret_cast<float3*>(ray_o.data_ptr<float>());
float3* ray_d_ptr = reinterpret_cast<float3*>(ray_d.data_ptr<float>());


// allocate local GPU storage
at::Tensor nuggets0 = at::empty({num, 2}, octree.options().dtype(at::kInt));
uint2* nuggets0_ptr = reinterpret_cast<uint2*>(nuggets0.data_ptr<int>());
at::Tensor nuggets1;

uint depth_dim = with_exit ? 2 : 1;
at::Tensor depths0 = at::empty({num, depth_dim}, octree.options().dtype(at::kFloat));
float* depth0_ptr = depths0.data_ptr<float>();
at::Tensor depths0;
at::Tensor depths1;

// Generate proposals (first proposal is root node)
init_nuggets_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(num, nuggets0_ptr);
init_nuggets_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()));

uint cnt, buffer = 0;
for (uint32_t l = 0; l <= target_level; l++) {
Expand All @@ -482,24 +521,31 @@ std::vector<at::Tensor> raytrace_cuda_impl(
uint* info_ptr = reinterpret_cast<uint*>(info.data_ptr<int>());

// Do the proposals hit?
if (with_exit) {
decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, points_ptr, ray_o_ptr, ray_d_ptr, nuggets0_ptr,
reinterpret_cast<float2*>(depth0_ptr),
info_ptr, octree_ptr, l, target_level - l);
} else {
decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, points_ptr, ray_o_ptr, ray_d_ptr, nuggets0_ptr,
depth0_ptr, info_ptr, octree_ptr, l, target_level - l);
if (l == target_level && return_depth) {
depths0 = at::empty({num, depth_dim}, octree.options().dtype(at::kFloat));

if (with_exit) {
decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
reinterpret_cast<float2*>(l == target_level ? depths0.data_ptr<float>() : 0),
info_ptr, octree_ptr, l);
} else {
decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
l == target_level ? depths0.data_ptr<float>() : 0, info_ptr, octree_ptr, l);
}
}
else {
decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
info_ptr, octree_ptr, l, target_level - l);
}


at::Tensor prefix_sum = at::empty({num+1}, octree.options().dtype(at::kInt));
uint* prefix_sum_ptr = reinterpret_cast<uint*>(prefix_sum.data_ptr<int>());

// set first element to zero
CubDebugExit(cudaMemcpy(prefix_sum_ptr, &buffer, sizeof(uint),
cudaMemcpyHostToDevice));
CubDebugExit(cudaMemcpy(prefix_sum_ptr, &buffer, sizeof(uint), cudaMemcpyHostToDevice));

// set up memory for DeviceScan calls
void* temp_storage_ptr = NULL;
Expand All @@ -508,61 +554,55 @@ std::vector<at::Tensor> raytrace_cuda_impl(
at::Tensor temp_storage = at::empty({(int64_t)temp_storage_bytes}, octree.options());
temp_storage_ptr = (void*)temp_storage.data_ptr<uint8_t>();


CubDebugExit(cub::DeviceScan::InclusiveSum(
temp_storage_ptr, temp_storage_bytes, info_ptr,
prefix_sum_ptr + 1, num)); //start sum on second element
cudaMemcpy(&cnt, prefix_sum_ptr + num, sizeof(uint), cudaMemcpyDeviceToHost);
cudaMemcpy(&cnt, prefix_sum_ptr + num, sizeof(uint), cudaMemcpyDeviceToHost);

// allocate local GPU storage
nuggets1 = at::empty({cnt, 2}, octree.options().dtype(at::kInt));
uint2* nuggets1_ptr = reinterpret_cast<uint2*>(nuggets1.data_ptr<int>());

depths1 = at::empty({cnt, depth_dim}, octree.options().dtype(at::kFloat));
float* depth1_ptr = depths1.data_ptr<float>();

if (cnt == 0)
{
num = cnt;
break; // miss everything
// miss everything
if (cnt == 0) {
num = 0;
nuggets0 = nuggets1;
if (return_depth) depths1 = at::empty({0, depth_dim}, octree.options().dtype(at::kFloat));
break;
}

// Subdivide if more levels remain, repeat
if (l < target_level) {
subdivide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, nuggets0_ptr, nuggets1_ptr, ray_o_ptr, points_ptr,
num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()), reinterpret_cast<uint2*>(nuggets1.data_ptr<int>()), ray_o_ptr, points_ptr,
octree_ptr, exclusive_sum_ptr, info_ptr, prefix_sum_ptr, l);
} else {
compactify_cuda_kernel<uint2><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, nuggets0_ptr, nuggets1_ptr,
num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()), reinterpret_cast<uint2*>(nuggets1.data_ptr<int>()),
info_ptr, prefix_sum_ptr);
if (return_depth) {
depths1 = at::empty({cnt, depth_dim}, octree.options().dtype(at::kFloat));

if (with_exit) {
compactify_cuda_kernel<float2><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, reinterpret_cast<float2*>(depth0_ptr),
reinterpret_cast<float2*>(depth1_ptr),
num, reinterpret_cast<float2*>(depths0.data_ptr<float>()),
reinterpret_cast<float2*>(depths1.data_ptr<float>()),
info_ptr, prefix_sum_ptr);
} else {
compactify_cuda_kernel<float><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
num, depth0_ptr, depth1_ptr,
num, depths0.data_ptr<float>(), depths1.data_ptr<float>(),
info_ptr, prefix_sum_ptr);
}
}
}

CubDebugExit(cudaGetLastError());

nuggets0_ptr = nuggets1_ptr;
depth0_ptr = depth1_ptr;

nuggets0 = nuggets1;
num = cnt;
}

if (return_depth) {
return { nuggets1.index({Slice(0, num)}).contiguous(),
return { nuggets0.index({Slice(0, num)}).contiguous(),
depths1.index({Slice(0, num)}).contiguous() };
} else {
return { nuggets1.index({Slice(0, num)}).contiguous() };
return { nuggets0.index({Slice(0, num)}).contiguous() };
}
}

Expand Down
57 changes: 56 additions & 1 deletion tests/python/kaolin/render/spc/test_raytrace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
# Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -230,6 +230,61 @@ def test_raytrace_with_depth_with_exit(self, octree, point_hierarchy, pyramid, e

assert torch.equal(depth, expected_depth)

@pytest.mark.parametrize('return_depth,with_exit', [(False, False), (True, False), (True, True)])
def test_raytrace_inside(self, octree, point_hierarchy, pyramid, exsum, return_depth, with_exit):
height = 4
width = 4
direction = torch.tensor([[0., 0., -1.]], dtype=torch.float,
device='cuda').repeat(height * width , 1)
origin = self._generate_rays_origin(height, width, 0.9)
outputs = unbatched_raytrace(
octree, point_hierarchy, pyramid, exsum, origin, direction, 2,
return_depth=return_depth, with_exit=with_exit)

ridx = outputs[0]
pidx = outputs[1]

expected_nuggets = torch.tensor([
[ 0, 13],
[ 0, 6],
[ 0, 5],
[ 1, 8],
[ 1, 7],
[ 2, 15],
[ 4, 10],
[ 4, 9],
[ 5, 12],
[ 5, 11]], device='cuda', dtype=torch.int)
assert torch.equal(ridx, expected_nuggets[...,0])
assert torch.equal(pidx, expected_nuggets[...,1])
if return_depth:
depth = outputs[2]
if with_exit:
expected_depth = torch.tensor([
[0.4, 0.9],
[0.9, 1.4],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9]], device='cuda', dtype=torch.float)
else:
expected_depth = torch.tensor([
[0.4],
[0.9],
[1.4],
[0.9],
[1.4],
[1.4],
[0.9],
[1.4],
[0.9],
[1.4]], device='cuda', dtype=torch.float)
assert torch.allclose(depth, expected_depth)

def test_ambiguous_raytrace(self):
# TODO(ttakikawa):
# Since 0.10.0, the behaviour of raytracing exactly between voxels
Expand Down

0 comments on commit bcce7a9

Please sign in to comment.