Skip to content

Commit

Permalink
Update docs for interpolate, deprecate older name (#575)
Browse files Browse the repository at this point in the history
Signed-off-by: Towaki Takikawa <ttakikawa@nvidia.com>
  • Loading branch information
tovacinni committed May 27, 2022
1 parent ca631b9 commit 2389781
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 12 deletions.
48 changes: 42 additions & 6 deletions kaolin/ops/spc/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'points_to_corners',
'unbatched_interpolate_trilinear',
'coords_to_trilinear',
'coords_to_trilinear_coeffs',
'unbatched_points_to_octree',
'quantize_points'
]
Expand Down Expand Up @@ -193,7 +194,7 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[4]:
# TODO(ttakikawa): Write a fused kernel
grad_feats = torch.zeros(ctx.feats_shape, device=grad_output.device, dtype=grad_output.dtype)
coeffs = coords_to_trilinear(coords[mask], selected_points[:, None].repeat(1, coords.shape[1], 1), level).type(grad_output.dtype)
coeffs = coords_to_trilinear_coeffs(coords[mask], selected_points[:, None].repeat(1, coords.shape[1], 1), level).type(grad_output.dtype)
grad_feats.index_add_(0, selected_trinkets.reshape(-1),
(coeffs[..., None] * grad_output[mask][..., None, :]).sum(1).reshape(-1, ctx.feats_shape[-1]))
return None, None, None, None, grad_feats, None
Expand All @@ -204,33 +205,68 @@ def unbatched_interpolate_trilinear(coords, pidx, point_hierarchy, trinkets, fea
Args:
coords (torch.FloatTensor): 3D coordinates of shape
:math:`(\text{num_coords}, \text{num_samples}, 3)`
in normalized space [-1, 1].
in normalized space [-1, 1]. ``num_samples`` indicates the number of
coordinates that are grouped inside the same SPC node for performance
optimization purposes. In many cases the ``pidx`` is
generated from :func:`kaolin.ops.spc.unbatched_query`
and so the ``num_samples`` will be 1.
pidx (torch.IntTensor): Index to the point hierarchy which contains the voxel
which the coords exists in. Tensor of shape
:math:`(\text{num_coords}, \text{num_samples})`
:math:`(\text{num_coords})`.
This can be computed with :func:`kaolin.ops.spc.unbatched_query`.
point_hierarchy (torch.ShortTensor):
The point hierarchy of shape :math:`(\text{num_points}, 3)`.
See :ref:`point_hierarchies <spc_points>` for a detailed description.
trinkets (torch.IntTensor): An indirection pointer (in practice, an index) to the feature
tensor of shape :math:`(\text{num_points}, 8})`.
tensor of shape :math:`(\text{num_points}, 8)`.
feats (torch.Tensor): Floating point feature vectors to interpolate of shape
:math:`(\text{num_feats}, \text{feature_dim})`.
level (int): The level of SPC to interpolate on.
Returns:
(torch.FloatTensor): Interpolated feature vectors of shape
:math:`(\text{num_voxels}, \text{num_samples}, \text{feature_dim}`.
(torch.FloatTensor):
Interpolated feature vectors of shape :math:`(\text{num_voxels}, \text{num_samples}, \text{feature_dim})`.
"""
return InterpolateTrilinear.apply(coords, pidx, point_hierarchy, trinkets, feats, level)

def coords_to_trilinear(coords, points, level):
r"""Calculates the coefficients for trilinear interpolation.
.. deprecated:: 0.11.0
This function is deprecated. Use :func:`coords_to_trilinear_coeffs`.
This calculates coefficients with respect to the dual octree, which represent the corners of the octree
where the features are stored.
To interpolate with the coefficients, do:
``torch.sum(features * coeffs, dim=-1)``
with ``features`` of shape :math:`(\text{num_points}, 8)`
Args:
coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_points}, 3)`
in normalized space [-1, 1].
points (torch.ShortTensor): Quantized 3D points (the 0th bit of the voxel x is in),
of shape :math:`(\text{num_coords}, 3)`.
level (int): The level of SPC to interpolate on.
Returns:
(torch.FloatTensor):
The trilinear interpolation coefficients of shape :math:`(\text{num_points}, 8)`.
"""
warnings.warn("coords_to_trilinear is deprecated, "
"please use kaolin.ops.spc.coords_to_trilinear_coeffs instead",
DeprecationWarning, stacklevel=2)
return coords_to_trilinear_coeffs(coords, points, level)

def coords_to_trilinear_coeffs(coords, points, level):
r"""Calculates the coefficients for trilinear interpolation.
This calculates coefficients with respect to the dual octree, which represent the corners of the octree
where the features are stored.
Expand Down
12 changes: 6 additions & 6 deletions tests/python/kaolin/ops/spc/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch

from kaolin.ops.spc import points_to_morton, morton_to_points, points_to_corners, \
coords_to_trilinear, quantize_points, unbatched_query, \
coords_to_trilinear_coeffs, quantize_points, unbatched_query, \
scan_octrees, unbatched_points_to_octree, generate_points, \
unbatched_make_trinkets, unbatched_make_dual, unbatched_interpolate_trilinear

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_points_to_corners(self, points):
expected_corners = torch.stack(expected_corners, dim=-2)
assert torch.equal(points_to_corners(points), expected_corners)

def test_coords_to_trilinear(self, points):
def test_coords_to_trilinear_coeffs(self, points):
w = torch.rand(points.shape, device='cuda')
x = points + w
expected_coeffs = torch.stack([
Expand All @@ -91,7 +91,7 @@ def test_coords_to_trilinear(self, points):

level = 3
coords = (x / (2**level)) * 2.0 - 1.0
assert torch.allclose(coords_to_trilinear(coords, points, level), expected_coeffs, atol=1e-5)
assert torch.allclose(coords_to_trilinear_coeffs(coords, points, level), expected_coeffs, atol=1e-5)

def test_interpolate_trilinear_forward(self, points):
w = torch.rand(points.shape, device='cuda')
Expand All @@ -114,7 +114,7 @@ def test_interpolate_trilinear_forward(self, points):
feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')

corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear(coords, points, level)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)

results = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)[:, 0]
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_interpolate_trilinear_forward_dtypes(self, points):
feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')

corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear(coords, points, level)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)

results_float = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)[:, 0]
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_interpolate_trilinear_backward(self, points):
feats.grad.zero_()

corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear(coords, points, level)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)

loss = expected_results.sum()
Expand Down

0 comments on commit 2389781

Please sign in to comment.