Skip to content

Commit

Permalink
fix scan_octrees (#653)
Browse files Browse the repository at this point in the history
Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
  • Loading branch information
Caenorst committed Nov 14, 2022
1 parent 17491c8 commit 3f0ee1e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions kaolin/csrc/ops/spc/scan_octrees.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
// Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -54,7 +54,7 @@ int scan_octrees_cuda_impl(

void* temp_storage_ptr = NULL;
uint64_t temp_storage_bytes = get_cub_storage_bytes(
temp_storage_ptr, num_childrens_per_node_ptr, prefix_sum_ptr, num_childrens_per_node.size(0) + 1);
temp_storage_ptr, num_childrens_per_node_ptr, prefix_sum_ptr + 1, num_childrens_per_node.size(0));
at::Tensor temp_storage = at::zeros({(int64_t) temp_storage_bytes },
octrees.options().dtype(at::kByte));
temp_storage_ptr = (void*) temp_storage.data_ptr<uint8_t>();
Expand All @@ -73,9 +73,9 @@ int scan_octrees_cuda_impl(
// compute exclusive sum 1 element beyond end of list to get inclusive sum starting at prefix_sum_ptr+1
scan_nodes_cuda_kernel<<< (osize + (THREADS_PER_BLOCK - 1)) / THREADS_PER_BLOCK, THREADS_PER_BLOCK >>>(
osize, O0, num_childrens_per_node_ptr);
CubDebugExit(cub::DeviceScan::ExclusiveSum(
CubDebugExit(cub::DeviceScan::InclusiveSum(
temp_storage_ptr, temp_storage_bytes, num_childrens_per_node_ptr,
EX0, osize + 1)); // careful with the +1
EX0 + 1, osize));

int* Pmid = h0;
int* PmidSum = h0 + KAOLIN_SPC_MAX_LEVELS + 2;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/kaolin/ops/spc/test_spc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_unbatched_make_trinkets(self, octrees, lengths, max_level):

for i in range(1, max_level+1):
parent = point_hierarchy.index_select(0, unbatched_get_level_points(parents, pyramid, i))
assert torch.equal(parent, unbatched_get_level_points(point_hierarchy, pyramid, i)//2)
assert torch.equal(parent, torch.div(unbatched_get_level_points(point_hierarchy, pyramid, i), 2, rounding_mode='trunc'))


class TestQuery:
Expand Down

0 comments on commit 3f0ee1e

Please sign in to comment.