Skip to content

Commit

Permalink
fix trilinear interpolation gradient issue when point is not on voxel (
Browse files Browse the repository at this point in the history
…#684)

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
  • Loading branch information
Caenorst committed Feb 14, 2023
1 parent a7cc9b3 commit 5c50961
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 42 deletions.
25 changes: 15 additions & 10 deletions kaolin/ops/spc/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,14 @@ def backward(ctx, grad_output):
# Shape (N, 8, 3) tensor of @(coeffs)/@(xyz)
grad_coeffs_by_xyz = grad_coeffs_by_xyz.reshape(-1, 8, 3)
# Shape (N, D, 8) tensor of @(feats_out)/@(coeffs)
grad_fout_by_coeffs = feats[selected_trinkets.long()].permute(0,2,1)
grad_fout_by_coeffs = feats[selected_trinkets.long()].permute(0, 2, 1)
# Shape (N, D, 3) tensor of @(feats_out)/@(xyz), after applying chain rule
grad_fout_by_xyz = grad_fout_by_coeffs @ grad_coeffs_by_xyz
_grad_fout_by_xyz = grad_fout_by_coeffs @ grad_coeffs_by_xyz
# Shape (N, 1, 3) tensor of @(out)/@(xyz) applying chain rule again
grad_fout_by_xyz = torch.zeros(
(grad_output.shape[0], *_grad_fout_by_xyz.shape[1:]),
device='cuda', dtype=_grad_fout_by_xyz.dtype)
grad_fout_by_xyz[mask] = _grad_fout_by_xyz
grad_coords = grad_output @ grad_fout_by_xyz
return grad_coords, None, None, None, grad_feats, None

Expand Down Expand Up @@ -291,10 +295,10 @@ def coords_to_trilinear(coords, points, level):
with ``features`` of shape :math:`(\text{num_points}, 8)`
Args:
coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_points}, 3)`
coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_coords}, 3)`
in normalized space [-1, 1].
points (torch.ShortTensor): Quantized 3D points (the 0th bit of the voxel x is in),
of shape :math:`(\text{num_coords}, 3)`.
of shape :math:`(\text{num_points}, 3)`.
level (int): The level of SPC to interpolate on.
Returns:
Expand All @@ -317,23 +321,24 @@ def coords_to_trilinear_coeffs(coords, points, level):
with ``features`` of shape :math:`(\text{num_points}, 8)`
Args:
coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_points}, 3)`
coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_coords}, 3)`
in normalized space [-1, 1].
points (torch.ShortTensor): Quantized 3D points (the 0th bit of the voxel x is in),
of shape :math:`(\text{num_coords}, 3)`.
of shape :math:`(\text{num_points}, 3)`.
level (int): The level of SPC to interpolate on.
Returns:
(torch.FloatTensor):
The trilinear interpolation coefficients of shape :math:`(\text{num_points}, 8)`.
The trilinear interpolation coefficients of shape :math:`(\text{num_coords}, 8)`.
"""
shape = list(points.shape)
shape = list(coords.shape)
shape[-1] = 8
points = points.reshape(-1, 3)
coords = coords.reshape(-1, 3)
coords_ = (2**level) * (coords * 0.5 + 0.5)
coords_ = (2 ** level) * (coords * 0.5 + 0.5)

return _C.ops.spc.coords_to_trilinear_cuda(coords_.contiguous(), points.contiguous()).reshape(*shape)
return _C.ops.spc.coords_to_trilinear_cuda(
coords_.contiguous(), points.contiguous()).reshape(*shape)


def create_dense_spc(level, device):
Expand Down
5 changes: 5 additions & 0 deletions kaolin/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,8 @@ def contained_torch_equal(elem, other):
else:
return elem == other

def check_allclose(tensor, other, rtol=1e-5, atol=1e-8, equal_nan=False):
if not torch.allclose(tensor, other, atol, rtol, equal_nan):
diff_idx = torch.where(~torch.isclose(tensor, other, atol, rtol, equal_nan))
raise ValueError(f"Tensors are not close on indices {diff_idx}:",
f"Example values: {tensor[diff_idx][:10]} vs {other[diff_idx][:10]}.")
92 changes: 60 additions & 32 deletions tests/python/kaolin/ops/spc/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch

from kaolin.utils.testing import check_allclose
from kaolin.ops.spc import points_to_morton, morton_to_points, points_to_corners, \
coords_to_trilinear_coeffs, quantize_points, unbatched_query, \
scan_octrees, unbatched_points_to_octree, generate_points, \
Expand Down Expand Up @@ -90,12 +91,15 @@ def test_coords_to_trilinear_coeffs(self, points):
], dim=-1)

level = 3
coords = (x / (2**level)) * 2.0 - 1.0
assert torch.allclose(coords_to_trilinear_coeffs(coords, points, level), expected_coeffs, atol=1e-5)
coords = (x / (2 ** level)) * 2.0 - 1.0
check_allclose(coords_to_trilinear_coeffs(coords, points, level), expected_coeffs, atol=1e-5)

def test_interpolate_trilinear_forward(self, points):
w = torch.rand(points.shape, device='cuda')
x = points + w
x = torch.cat([
points + w,
-torch.rand((4, 3), device='cuda')
], dim=0)

level = 3

Expand All @@ -108,22 +112,28 @@ def test_interpolate_trilinear_forward(self, points):
point_hierarchy_dual, pyramid_dual = unbatched_make_dual(point_hierarchy, pyramid)
trinkets, parents = unbatched_make_trinkets(point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)

coords = (x / (2**level)) * 2.0 - 1.0
coords = (x / (2 ** level)) * 2.0 - 1.0
pidx = unbatched_query(octree, prefix, coords, level, with_parents=False)

feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')

corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)
expected_results[points.shape[0]:] = 0.

results = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)[:, 0]
results = unbatched_interpolate_trilinear(
coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level
)[:, 0]

assert torch.allclose(results, expected_results, rtol=1e-5, atol=1e-5)
check_allclose(results, expected_results, rtol=1e-5, atol=1e-5)

def test_interpolate_trilinear_forward_dtypes(self, points):
w = torch.rand(points.shape, device='cuda')
x = points + w
x = torch.cat([
points + w,
-torch.rand((4, 3), device='cuda')
], dim=0)

level = 3

Expand All @@ -141,20 +151,19 @@ def test_interpolate_trilinear_forward_dtypes(self, points):

feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')

corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)

results_float = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)[:, 0]
results_double = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats.double(), level)[:, 0]
results_half = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats.half(), level)[:, 0]

assert torch.allclose(results_float, results_double.float(), rtol=1e-4, atol=1e-4)
assert torch.allclose(results_float.half(), results_half, rtol=1e-3, atol=1e-3)
check_allclose(results_float, results_double.float(), rtol=1e-4, atol=1e-4)
check_allclose(results_float.half(), results_half, rtol=1e-3, atol=1e-3)

def test_interpolate_trilinear_backward(self, points):
w = torch.rand(points.shape, device='cuda')
x = points + w
x = torch.cat([
points + w,
-torch.rand((4, 3), device='cuda')
], dim=0)

level = 3

Expand All @@ -167,7 +176,7 @@ def test_interpolate_trilinear_backward(self, points):
point_hierarchy_dual, pyramid_dual = unbatched_make_dual(point_hierarchy, pyramid)
trinkets, parents = unbatched_make_trinkets(point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)

coords = (x / (2**level)) * 2.0 - 1.0
coords = (x / (2 ** level)) * 2.0 - 1.0
pidx = unbatched_query(octree, prefix, coords, level, with_parents=False)

feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')
Expand All @@ -179,6 +188,7 @@ def test_interpolate_trilinear_backward(self, points):
corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
coeffs = coords_to_trilinear_coeffs(coords, points, level)
expected_results = (corner_feats * coeffs[..., None]).sum(-2)
expected_results[points.shape[0]:] = 0.

loss = expected_results.sum()
loss.backward()
Expand All @@ -188,16 +198,21 @@ def test_interpolate_trilinear_backward(self, points):
feats.grad.detach_()
feats.grad.zero_()

results = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)[:, 0]
results = unbatched_interpolate_trilinear(
coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level
)[:, 0]
loss = results.sum()
loss.backward()
grad = feats.grad.clone()

assert torch.allclose(grad, expected_grad, rtol=1e-5, atol=1e-5)
check_allclose(grad, expected_grad, rtol=1e-5, atol=1e-5)

def test_interpolate_trilinear_by_coords_backward(self, points):
w = torch.rand(points.shape, device='cuda')
x = points + w
x = torch.cat([
points + w,
-torch.rand((4, 3), device='cuda')
], dim=0)

level = 3

Expand All @@ -208,43 +223,52 @@ def test_interpolate_trilinear_by_coords_backward(self, points):

pyramid = pyramid[0]
point_hierarchy_dual, pyramid_dual = unbatched_make_dual(point_hierarchy, pyramid)
trinkets, parents = unbatched_make_trinkets(point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)
trinkets, parents = unbatched_make_trinkets(
point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)

coords = (x / (2 ** level)) * 2.0 - 1.0
pidx = unbatched_query(octree, prefix, coords, level, with_parents=False)
feats = torch.rand([pyramid_dual[0, level], 16], device='cuda')

# w is the relative position inside a cell
w = w.detach()
w.requires_grad_(True)
if w.grad is not None:
w.grad.detach()
w.grad.zero_()

# (5, 8, 16)
corner_feats = feats.index_select(0, trinkets[pidx].view(-1)).view(-1, 8, 16)
corner_feats[points.shape[0]:] = 0.

# (5, 8)
expected_coeffs = torch.stack([
(1 - w[:, 0]) * (1 - w[:, 1]) * (1 - w[:, 2]),
(1 - w[:, 0]) * (1 - w[:, 1]) * w[:, 2],
(1 - w[:, 0]) * w[:, 1] * (1 - w[:, 2]),
(1 - w[:, 0]) * w[:, 1] * w[:, 2],
w[:, 0] * (1 - w[:, 1]) * (1 - w[:, 2]),
w[:, 0] * (1 - w[:, 1]) * w[:, 2],
w[:, 0] * w[:, 1] * (1 - w[:, 2]),
w[:, 0] * w[:, 1] * w[:, 2]
], dim=-1)
expected_coeffs = torch.cat([torch.stack([
(1 - w[:, 0]) * (1 - w[:, 1]) * (1 - w[:, 2]),
(1 - w[:, 0]) * (1 - w[:, 1]) * w[:, 2],
(1 - w[:, 0]) * w[:, 1] * (1 - w[:, 2]),
(1 - w[:, 0]) * w[:, 1] * w[:, 2],
w[:, 0] * (1 - w[:, 1]) * (1 - w[:, 2]),
w[:, 0] * (1 - w[:, 1]) * w[:, 2],
w[:, 0] * w[:, 1] * (1 - w[:, 2]),
w[:, 0] * w[:, 1] * w[:, 2]
], dim=-1),
torch.zeros((4, 8), device='cuda', dtype=torch.float)
], dim=0)
expected_coeffs = expected_coeffs.requires_grad_(True) # prevents element0 error
expected_results = (corner_feats * expected_coeffs[..., None]).sum(1)
expected_results[points.shape[0]:] = 0.

loss = expected_results.sum()
loss.backward()
expected_grad = w.grad.clone()
expected_grad = torch.zeros_like(x)
expected_grad[:points.shape[0]] = w.grad.clone()

coords.requires_grad_(True)
if coords.grad is not None:
coords.grad.detach()
coords.grad.zero_()
results = unbatched_interpolate_trilinear(coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)
results = unbatched_interpolate_trilinear(
coords[:, None], pidx.int(), point_hierarchy, trinkets, feats, level)
loss = results[:, 0].sum()
loss.backward()
coords_grad = coords.grad.clone()
Expand All @@ -254,7 +278,11 @@ def test_interpolate_trilinear_by_coords_backward(self, points):
def test_interpolate_trilinear_by_coords_toggleable(self, points):
# Test that features only grad does not generate coords grad
w = torch.rand(points.shape, device='cuda')
x = points + w
x = torch.cat([
points + w,
-torch.rand((4, 3), device='cuda')
], dim=0)


level = 3

Expand Down

0 comments on commit 5c50961

Please sign in to comment.