Unbatched PCL to SPC - docstring (#498)
Signed-off-by: operel <operel@nvidia.com>

Co-authored-by: operel <operel@nvidia.com>
Caenorst and operel committed Dec 17, 2021
1 parent 5422898 commit 3b8ce97
Showing 3 changed files with 286 additions and 5 deletions.
60 changes: 59 additions & 1 deletion kaolin/ops/conversions/pointcloud.py
@@ -13,8 +13,10 @@
# limitations under the License.

import torch
from kaolin.ops.spc.points import quantize_points, points_to_morton, morton_to_points, unbatched_points_to_octree
from kaolin.rep.spc import Spc

__all__ = ['pointclouds_to_voxelgrids']
__all__ = ['pointclouds_to_voxelgrids', 'unbatched_pointcloud_to_spc']

def _base_points_to_voxelgrids(points, resolution, return_sparse=False):
r"""Converts points to voxelgrids. This is the base function for both trianglemeshes_to_voxelgrids
@@ -132,3 +134,59 @@ def pointclouds_to_voxelgrids(pointclouds, resolution, origin=None, scale=None,
vg = _base_points_to_voxelgrids(pointclouds, resolution, return_sparse=return_sparse)

return vg

def unbatched_pointcloud_to_spc(pointcloud, level, features=None):
r"""This function takes as input a single point-cloud - a set of continuous coordinates in 3D,
and coverts it into a :ref:`Structured Point Cloud (SPC)<spc>`, a compressed octree representation where
the point cloud coordinates are quantized to integer coordinates.
Point coordinates are expected to be normalized to the range [-1, 1];
points outside this range will be clipped to it.
If ``features`` is specified, the current implementation averages the features
of points that fall into the same quantized cell.
Args:
pointcloud (torch.Tensor):
An unbatched pointcloud with shape :math:`(\text{num_points}, 3)`.
Coordinates are expected to be normalized to the range [-1, 1].
level (int):
Maximum number of levels to use in the octree hierarchy.
features (torch.Tensor, optional):
Per-point feature vectors, of shape
:math:`(\text{num_points}, \text{feat_dim})`.
Returns:
(kaolin.rep.Spc):
A Structured Point Cloud (SPC) object, holding a single-item batch.
"""
points = quantize_points(pointcloud.contiguous(), level)

# Avoid duplicates when a cell contains more than one point
unique, unique_keys, unique_counts = torch.unique(points.contiguous(), dim=0,
return_inverse=True, return_counts=True)

# Create octree hierarchy
morton, keys = torch.sort(points_to_morton(unique.contiguous()).contiguous())
points = morton_to_points(morton.contiguous())
octree = unbatched_points_to_octree(points, level, sorted=True)

# Organize features for octree leaf nodes
feat = None
if features is not None:
# Features of multiple points sharing the same cell are consolidated here
# by mean averaging.
feat_dtype = features.dtype
is_fp = features.is_floating_point()

# Promote to double precision dtype to avoid rounding errors
feat = torch.zeros(unique.shape[0], features.shape[1], device=features.device).double()
feat = feat.index_add_(0, unique_keys, features.double()) / unique_counts[..., None].double()
if not is_fp:
feat = torch.round(feat)
feat = feat.to(feat_dtype)
feat = feat[keys]

# A full SPC requires octree hierarchy + auxiliary data structures
lengths = torch.tensor([len(octree)], dtype=torch.int32) # Single entry batch
return Spc(octrees=octree, lengths=lengths, features=feat)
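
For context, a minimal usage sketch of the new conversion (input values are illustrative; a CUDA device is assumed, as in the tests below):

    import torch
    from kaolin.ops.conversions import unbatched_pointcloud_to_spc

    # Toy input: 4 points normalized to [-1, 1]; two of them are identical, so they
    # quantize to the same cell and their features are averaged.
    pointcloud = torch.tensor([[-1.0, -1.0, -1.0],
                               [ 0.5,  0.5,  0.5],
                               [ 0.5,  0.5,  0.5],
                               [ 0.9, -0.2,  0.3]], device='cuda')
    colors = torch.tensor([[255,   0,   0],
                           [  0, 255,   0],
                           [  0,   0, 255],
                           [255, 255, 255]], device='cuda', dtype=torch.uint8)

    spc = unbatched_pointcloud_to_spc(pointcloud, level=3, features=colors)
    print(spc.octrees)   # packed octree bytes (single-item batch)
    print(spc.lengths)   # number of bytes used by that octree
    print(spc.features)  # one feature per occupied leaf cell; duplicates are averaged
                         # and rounded back for integer dtypes
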
71 changes: 69 additions & 2 deletions kaolin/rep/spc.py
@@ -22,12 +22,72 @@
from ..ops.batch import list_to_packed

class Spc(object):
"""Class containing all the Structured point clouds information"""
"""Data class holding all :ref:`Structured Point Cloud (SPC)<spc>` information.
This class supports batching through :ref:`packed<packed>` representation:
a single Spc object can pack multiple SPC structures of variable sizes.
SPC data structures are represented through the combination of various tensors detailed below:
``octrees`` compress the information required to build a full SPC.
In practice, they are a low-level structure which also constitutes the
:ref:`core part<spc_octree>` of the SPC data structure.
``octrees`` are kept as a torch.ByteTensor, where each byte represents a single octree parent cell,
and each bit represents the occupancy of a child octree cell,
e.g. 8 bits for 8 cells.
Bits describe the octree cells in Morton Order: the 8 child cells of a parent cube,
indexed 0 to 7, viewed as two 2x2 slices::

    3 7        2 6
    1 5        0 4
If a cell is occupied, an additional cell byte may be generated in the next level,
up to the argument ``level``.
For example, an ``SPC.octrees`` field may look as follows::

    tensor([255, 128, 64, 32, 16, 8, 4, 2, 23], dtype=torch.uint8)

Here ``octrees`` represents an octree of 9 nodes (bytes).
The binary representation should be interpreted as follows::
Level #1, Path*, 11111111 (All cells are occupied, therefore 8 bytes are allocated for level 2)
Level #2, Path*-1, 10000000
Level #2, Path*-2, 01000000
Level #2, Path*-3, 00100000
Level #2, Path*-4, 00010000
Level #2, Path*-5, 00001000
Level #2, Path*-6, 00000100
Level #2, Path*-7, 00000010
Level #2, Path*-8, 00010111
``lengths`` is a tensor of integers required to support batching. Since we assume a packed representation,
all octree cells are shaped as a single stacked 1D tensor. ``lengths`` specifies the number of cells (bytes) each
octree uses.
``features`` is an optional tensor of per-point feature vectors.
When ``features`` is not ``None``, a feature is kept for each point at the highest-resolution level in the octree.
``max_level`` is an integer which specifies how many recursive levels an octree should have.
``point_hierarchies``, ``pyramids``, ``exsum`` are auxiliary structures, which are generated upon request and
enable efficient indexing into SPC entries.
"""

KEYS = {'octrees', 'lengths', 'max_level', 'pyramids', 'exsum', 'point_hierarchies'}

def __init__(self, octrees, lengths, max_level=None, pyramids=None,
exsum=None, point_hierarchies=None):
exsum=None, point_hierarchies=None, features=None):
assert (isinstance(octrees, torch.Tensor) and octrees.dtype == torch.uint8 and
octrees.ndim == 1), "octrees must be a 1D ByteTensor."
assert (isinstance(lengths, torch.Tensor) and lengths.dtype == torch.int and
@@ -64,12 +124,19 @@ def __init__(self, octrees, lengths, max_level=None, pyramids=None,
assert point_hierarchies.device == octrees.device, \
"point_hierarchies must be on the same device than octrees."

if features is not None:
assert isinstance(features, torch.Tensor), \
"features must be a torch.Tensor"
assert features.device == octrees.device, \
"features must be on the same device as octrees."

self.octrees = octrees
self.lengths = lengths
self._max_level = max_level
self._pyramids = pyramids
self._exsum = exsum
self._point_hierarchies = point_hierarchies
self.features = features

# TODO(cfujitsang): could be interesting to separate into multiple functions
def _apply_scan_octrees(self):
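
A minimal construction sketch for ``Spc.__init__`` as extended above, exercising the new ``features`` argument (values are illustrative; only ``octrees`` and ``lengths`` are required):

    import torch
    from kaolin.rep.spc import Spc

    # One byte == one level-1 parent cell; the value 1 sets a single bit,
    # i.e. exactly one child cell is occupied.
    octrees = torch.tensor([1], dtype=torch.uint8)
    lengths = torch.tensor([1], dtype=torch.int32)  # single-item batch: this octree uses 1 byte
    features = torch.tensor([[0.2, 0.4, 0.6]])      # optional: one feature per occupied leaf cell

    spc = Spc(octrees=octrees, lengths=lengths, features=features)
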
160 changes: 158 additions & 2 deletions tests/python/kaolin/ops/conversions/test_pointcloud.py
@@ -16,8 +16,8 @@

import torch

from kaolin.ops.conversions import pointclouds_to_voxelgrids
from kaolin.utils.testing import FLOAT_TYPES
from kaolin.ops.conversions import pointclouds_to_voxelgrids, unbatched_pointcloud_to_spc
from kaolin.utils.testing import FLOAT_TYPES, BOOL_DTYPES, INT_DTYPES, FLOAT_DTYPES, ALL_DTYPES, check_spc_octrees

@pytest.mark.parametrize('device, dtype', FLOAT_TYPES)
class TestPointcloudToVoxelgrid:
@@ -139,3 +139,159 @@ def test_pointclouds_to_voxelgrids_scale(self, device, dtype):
output_vg = pointclouds_to_voxelgrids(pointclouds, 3, scale=torch.ones((2), device=device, dtype=dtype) * 4)

assert torch.equal(output_vg, expected_vg)


@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('level', list(range(1, 6)))
class TestUnbatchedPointcloudToSpc:

@pytest.fixture
def pointcloud(self, device):
return torch.tensor([[-1, -1, -1],
[-1, -1, 0],
[0, -1, -1],
[-1, 0, -1],
[0, 0, 0],
[1, 1, 1],
[0.999, 0.999, 0.999]], device=device)


@pytest.fixture
def typed_pointcloud(self, pointcloud, dtype):
return pointcloud.to(dtype)

@pytest.fixture(autouse=True)
def expected_octree(self, device, level):
level_cutoff_mapping = [1, 6, 12, 18, 24]
level_cutoff = level_cutoff_mapping[level-1]
full_octree = torch.tensor([151,
1, 1, 1, 1, 129,
1, 1, 1, 1, 1, 128,
1, 1, 1, 1, 1, 128,
1, 1, 1, 1, 1, 128], device=device)
expected_octree = full_octree[:level_cutoff]
return expected_octree.byte()

@pytest.fixture(autouse=True)
def bool_features(self):
def _bool_features(device, booltype):
return torch.tensor([[0],
[1],
[1],
[1],
[0],
[1],
[1]], device=device).to(booltype)
return _bool_features

@pytest.fixture(autouse=True)
def expected_bool_features(self):
def _expected_bool_features(device, booltype, level):
if level == 1:
return torch.tensor([[0],
[1],
[1],
[1],
[1]], device=device).to(booltype)
else:
return torch.tensor([[0],
[1],
[1],
[1],
[0],
[1]], device=device).to(booltype)
return _expected_bool_features

@pytest.fixture(autouse=True)
def int_features(self):
def _int_features(device, inttype):
return torch.tensor([[1],
[4],
[7],
[10],
[20],
[37],
[1]], device=device).to(inttype)
return _int_features

@pytest.fixture(autouse=True)
def expected_int_features(self):
def _expected_int_features(device, inttype, level):
if level == 1:
return torch.tensor([[1],
[4],
[10],
[7],
[19]], device=device).to(inttype)
else:
return torch.tensor([[1],
[4],
[10],
[7],
[20],
[19]], device=device).to(inttype)
return _expected_int_features

@pytest.fixture(autouse=True)
def fp_features(self):
def _fp_features(device, fptype):
return torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 10, 10],
[20, 20, 20],
[37, 37, 37],
[1, 2, 3]], device=device).to(fptype)
return _fp_features

@pytest.fixture(autouse=True)
def expected_fp_features(self):
def _expected_fp_features(device, fptype, level):
if level == 1:
return torch.tensor([[1, 2, 3],
[4, 5, 6],
[10, 10, 10],
[7, 8, 9],
[58/3, 59/3, 60/3]], device=device).to(fptype)
else:
return torch.tensor([[1, 2, 3],
[4, 5, 6],
[10, 10, 10],
[7, 8, 9],
[20, 20, 20],
[19, 19.5, 20]], device=device).to(fptype)
return _expected_fp_features


@pytest.mark.parametrize('dtype', FLOAT_DTYPES)
def test_unbatched_pointcloud_to_spc(self, typed_pointcloud, level, expected_octree):
output_spc = unbatched_pointcloud_to_spc(typed_pointcloud, level)
assert check_spc_octrees(output_spc.octrees, output_spc.lengths,
batch_size=output_spc.batch_size,
level=level,
device=typed_pointcloud.device.type)
assert torch.equal(output_spc.octrees, expected_octree)

@pytest.mark.parametrize('booltype', BOOL_DTYPES)
def test_unbatched_pointcloud_to_spc_with_bool_features(self, pointcloud, device, booltype, level,
bool_features, expected_bool_features):
features_arg = bool_features(device, booltype)
expected_features_arg = expected_bool_features(device, booltype, level)
output_spc = unbatched_pointcloud_to_spc(pointcloud, level, features_arg)
assert torch.equal(output_spc.features, expected_features_arg)

@pytest.mark.parametrize('inttype', INT_DTYPES)
def test_unbatched_pointcloud_to_spc_with_int_features(self, pointcloud, device, inttype, level,
int_features, expected_int_features):
features_arg = int_features(device, inttype)
expected_features_arg = expected_int_features(device, inttype, level)
output_spc = unbatched_pointcloud_to_spc(pointcloud, level, features_arg)
assert torch.equal(output_spc.features, expected_features_arg)

@pytest.mark.parametrize('fptype', FLOAT_DTYPES)
def test_unbatched_pointcloud_to_spc_with_fp_features(self, pointcloud, device, fptype, level,
fp_features, expected_fp_features):
features_arg = fp_features(device, fptype)
expected_features_arg = expected_fp_features(device, fptype, level)
output_spc = unbatched_pointcloud_to_spc(pointcloud, level, features_arg)
assert torch.allclose(output_spc.features, expected_features_arg)
