Skip to content

Commit

Permalink
Add interpolation kernels + backprop and trinkets for interpolation (#…
Browse files Browse the repository at this point in the history
…567)

Signed-off-by: Towaki Takikawa <tovacinni@gmail.com>
  • Loading branch information
tovacinni committed May 19, 2022
1 parent 6569ea3 commit 006fbf6
Show file tree
Hide file tree
Showing 10 changed files with 564 additions and 28 deletions.
3 changes: 2 additions & 1 deletion kaolin/csrc/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020,21 NVIDIA CORPORATION & AFFILIATES.
// Copyright (c) 2020,21-22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -50,6 +50,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops_spc.def("query_multiscale_cuda", &query_multiscale_cuda);
ops_spc.def("points_to_morton_cuda", &points_to_morton_cuda);
ops_spc.def("morton_to_points_cuda", &morton_to_points_cuda);
ops_spc.def("interpolate_trilinear_cuda", &interpolate_trilinear_cuda);
ops_spc.def("coords_to_trilinear_cuda", &coords_to_trilinear_cuda);
//ops_spc.def("coord_to_trilinear_jacobian_cuda", &coord_to_trilinear_jacobian_cuda);
ops_spc.def("points_to_corners_cuda", &points_to_corners_cuda);
Expand Down
42 changes: 41 additions & 1 deletion kaolin/csrc/ops/spc/point_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
// Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -25,6 +25,9 @@ namespace kaolin {
void morton_to_points_cuda_impl(at::Tensor morton_codes, at::Tensor points);
void points_to_morton_cuda_impl(at::Tensor points, at::Tensor morton_codes);
void points_to_corners_cuda_impl(at::Tensor points, at::Tensor corners);
// CUDA backend of interpolate_trilinear_cuda (defined in point_utils_cuda.cu).
// Launches the interpolation kernel and writes the result into the
// caller-allocated `feats_out`; `level` selects the grid resolution (2^level).
void interpolate_trilinear_cuda_impl(at::Tensor coords, at::Tensor pidx,
at::Tensor points, at::Tensor trinkets,
at::Tensor feats_in, at::Tensor feats_out, int32_t level);
void coords_to_trilinear_cuda_impl(at::Tensor coord, at::Tensor points, at::Tensor coeffs);
//void coord_to_trilinear_jacobian_cuda_impl(at::Tensor coord);

Expand Down Expand Up @@ -78,6 +81,43 @@ at::Tensor points_to_corners_cuda(at::Tensor points) {
#endif // WITH_CUDA
}

// Trilinearly interpolates per-corner features at sample coordinates.
//
// Arguments (all on the same GPU, all contiguous):
//   coords   : float,  (num_voxels, num_samples, 3)
//   pidx     : int32,  (num_voxels,)  index into `points` / `trinkets`;
//              negative entries are skipped by the kernel
//   points   : int16,  (num_points, 3)
//   trinkets : int32,  (num_points, 8) row indices into `feats`
//   feats    : half / float / double, (num_feats, feat_dim)
//   level    : the grid resolution used by the kernel is 2^level
//
// Returns a tensor of shape (num_voxels, num_samples, feat_dim) with the same
// options as `feats`. It is zero-initialized, so samples whose pidx entry is
// negative come back as zeros.
//
// Without WITH_CUDA the function only reports KAOLIN_NO_CUDA_ERROR.
at::Tensor interpolate_trilinear_cuda(
    at::Tensor coords,
    at::Tensor pidx,
    at::Tensor points,
    at::Tensor trinkets,
    at::Tensor feats,
    int32_t level) {
#ifdef WITH_CUDA
  at::TensorArg coords_arg{coords, "coords", 1};
  at::TensorArg pidx_arg{pidx, "pidx", 2};
  at::TensorArg points_arg{points, "points", 3};
  at::TensorArg trinkets_arg{trinkets, "trinkets", 4};
  at::TensorArg feats_arg{feats, "feats", 5};
  at::checkAllSameGPU(__func__, {coords_arg, pidx_arg, points_arg, trinkets_arg, feats_arg});
  at::checkAllContiguous(__func__, {coords_arg, pidx_arg, points_arg, trinkets_arg, feats_arg});
  at::checkScalarType(__func__, coords_arg, at::kFloat);
  at::checkScalarType(__func__, pidx_arg, at::kInt);
  at::checkScalarType(__func__, points_arg, at::kShort);
  at::checkScalarType(__func__, trinkets_arg, at::kInt);
  at::checkScalarTypes(__func__, feats_arg, {at::kHalf, at::kFloat, at::kDouble});
  at::checkDim(__func__, coords_arg, 3);
  at::checkDim(__func__, pidx_arg, 1);
  at::checkDim(__func__, points_arg, 2);
  at::checkDim(__func__, trinkets_arg, 2);
  at::checkDim(__func__, feats_arg, 2);

  // The kernel reinterprets these buffers as packed 3-vectors / 8-entry rows
  // and reads one pidx entry per voxel, so a wrong extent would silently read
  // out of bounds. Reject it here instead.
  at::checkSize(__func__, coords_arg, 2, 3);
  at::checkSize(__func__, points_arg, 1, 3);
  at::checkSize(__func__, trinkets_arg, 1, 8);
  at::checkSize(__func__, pidx_arg, 0, coords.size(0));

  const int64_t num_voxels = coords.size(0);
  const int64_t num_samples = coords.size(1);
  const int64_t feat_dim = feats.size(1);
  at::Tensor feats_out = at::zeros({num_voxels, num_samples, feat_dim}, feats.options());
  interpolate_trilinear_cuda_impl(coords, pidx, points, trinkets, feats, feats_out, level);
  return feats_out;
#else
  KAOLIN_NO_CUDA_ERROR(__func__);
#endif  // WITH_CUDA
}

at::Tensor coords_to_trilinear_cuda(at::Tensor coords, at::Tensor points) {
#ifdef WITH_CUDA
at::TensorArg coords_arg{coords, "coords", 1};
Expand Down
10 changes: 9 additions & 1 deletion kaolin/csrc/ops/spc/point_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES
* Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES
* All rights reserved.
*
* NVIDIA CORPORATION and its licensors retain all intellectual property
Expand All @@ -19,6 +19,14 @@ at::Tensor points_to_morton_cuda(at::Tensor points);

at::Tensor morton_to_points_cuda(at::Tensor morton_codes);

// Trilinear interpolation of `feats` at `coords`.
// The implementation (point_utils.cpp) requires all tensors to be contiguous
// and on the same GPU, and returns a zero-initialized tensor of shape
// (coords.size(0), coords.size(1), feats.size(1)) with the options of `feats`,
// filled in by the CUDA kernel.
at::Tensor interpolate_trilinear_cuda(
at::Tensor coords,    // float, 3-dimensional
at::Tensor pidx,      // int32, 1-dimensional
at::Tensor points,    // int16, 2-dimensional
at::Tensor trinkets,  // int32, 2-dimensional
at::Tensor feats,     // half / float / double, 2-dimensional
int32_t level);

at::Tensor coords_to_trilinear_cuda(
at::Tensor coords,
at::Tensor points);
Expand Down
114 changes: 106 additions & 8 deletions kaolin/csrc/ops/spc/point_utils_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
// Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,6 +15,7 @@


#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../spc_math.h"
#include "../../spc_utils.cuh"
Expand All @@ -40,10 +41,68 @@ __global__ void points_to_corners_cuda_kernel(
}
}

// Trilinear interpolation kernel.
//
// Each of the `num` = num_voxels * num_samples samples is handled by one
// iteration of a grid-stride loop. For sample i the kernel
//   1. looks up the voxel index pidx[i / num_samples] and skips the sample
//      when it is negative (the corresponding output row is not written),
//   2. maps the coordinate from [-1, 1] to [0, resolution] and subtracts the
//      voxel's point to get the local offset inside the voxel,
//   3. blends the 8 feature rows referenced by the voxel's trinket with the
//      trilinear weights and stores the result at feature_out[i, :].
template<typename scalar_t>
__global__ void interpolate_trilinear_cuda_kernel(
    const float3* coords, // num_voxels, num_samples, 3
    const int32_t* pidx, // num_voxels
    const point_data* points, // point_hierarchy_size, 3
    const int32_t* trinkets, // point_hierarchy_size, 8
    const scalar_t* feature_in, // num_feats, feature_dim
    scalar_t* feature_out, // num_voxels, num_samples, feature_dim
    const int64_t feature_dim,
    const int32_t resolution,
    const int64_t num_samples,
    const int64_t num
){
  // Widen before multiplying: the built-in index variables are 32-bit unsigned.
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // 64-bit loop index so that neither `i` nor `i += stride` can wrap around.
  // Threads with idx >= num simply do not enter the loop.
  for (int64_t i = idx; i < num; i += stride) {

    const int32_t _i = pidx[i / num_samples];
    if (_i < 0) continue;  // sample is not inside any voxel

    const point_data point = points[_i];

    // Local copy of the 8 feature indices of this voxel.
    int32_t trinket[8];
    memcpy(trinket, trinkets + static_cast<int64_t>(_i) * 8, sizeof(trinket));

    // The 0.5 literals are doubles, so the mapping is evaluated in double
    // precision and rounded to float by make_float3 (kept to preserve the
    // existing numerical results).
    const float3 coord = coords[i];
    const float3 x_ = make_float3(resolution * (coord.x * 0.5 + 0.5) - point.x,
                                  resolution * (coord.y * 0.5 + 0.5) - point.y,
                                  resolution * (coord.z * 0.5 + 0.5) - point.z);
    const float3 _x = make_float3(1.0 - x_.x, 1.0 - x_.y, 1.0 - x_.z);

    // Trilinear weights; cXYZ uses x_ (1) or _x (0) along each axis.
    const float c000 = _x.x * _x.y * _x.z;
    const float c001 = _x.x * _x.y * x_.z;
    const float c010 = _x.x * x_.y * _x.z;
    const float c011 = _x.x * x_.y * x_.z;
    const float c100 = x_.x * _x.y * _x.z;
    const float c101 = x_.x * _x.y * x_.z;
    const float c110 = x_.x * x_.y * _x.z;
    const float c111 = x_.x * x_.y * x_.z;

    for (int64_t j = 0; j < feature_dim; ++j) {
      const scalar_t feat =
          feature_in[trinket[0]*feature_dim+j] * c000 +
          feature_in[trinket[1]*feature_dim+j] * c001 +
          feature_in[trinket[2]*feature_dim+j] * c010 +
          feature_in[trinket[3]*feature_dim+j] * c011 +
          feature_in[trinket[4]*feature_dim+j] * c100 +
          feature_in[trinket[5]*feature_dim+j] * c101 +
          feature_in[trinket[6]*feature_dim+j] * c110 +
          feature_in[trinket[7]*feature_dim+j] * c111;
      feature_out[i*feature_dim+j] = feat;
    }
  }
}

template<typename scalar_t>
__global__ void coords_to_trilinear_cuda_kernel(
const float3* coords,
const point_data* points,
float* coeffs,
scalar_t* coeffs,
const int64_t num_coords
){
int64_t idx = blockIdx.x*blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -128,18 +187,57 @@ void points_to_morton_cuda_impl(at::Tensor points, at::Tensor morton_codes) {
num_points);
}

// Host-side launcher for interpolate_trilinear_cuda_kernel.
//
// Reads the problem size from `coords` (num_voxels x num_samples samples) and
// `feats_in` (feature dimension), derives the grid resolution 2^level and
// dispatches the kernel on the dtype of `feats_in`, on the current CUDA stream
// of the device that holds `feats_out`. Results are written into `feats_out`.
void interpolate_trilinear_cuda_impl(
    at::Tensor coords,
    at::Tensor pidx,
    at::Tensor points,
    at::Tensor trinkets,
    at::Tensor feats_in,
    at::Tensor feats_out,
    int32_t level
){
  const int64_t num_voxels = coords.size(0);
  const int64_t num_samples = coords.size(1);
  const int64_t feat_dim = feats_in.size(1);
  const int64_t num = num_voxels * num_samples;

  // Nothing to do; a launch with a grid of 0 blocks is an invalid configuration.
  if (num == 0) return;

  // Integer shift instead of a floating-point pow(); the range check keeps the
  // shift well-defined and the result representable as int32_t.
  TORCH_CHECK(level >= 0 && level < 31,
              "interpolate_trilinear_cuda: level must be in [0, 30], got ", level);
  const int32_t resolution = int32_t{1} << level;

  const int num_threads = 128;
  const int64_t num_blocks = (num + num_threads - 1) / num_threads;

  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(feats_out));
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(feats_in.scalar_type(), "interpolate_trilinear_cuda", ([&] {
    interpolate_trilinear_cuda_kernel<scalar_t><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<float3*>(coords.data_ptr<float>()),
        pidx.data_ptr<int32_t>(),
        reinterpret_cast<point_data*>(points.data_ptr<short>()),
        trinkets.data_ptr<int32_t>(),
        feats_in.data_ptr<scalar_t>(),
        feats_out.data_ptr<scalar_t>(),
        feat_dim,
        resolution,
        num_samples,
        num
    );
  }));
}


// Host-side launcher for coords_to_trilinear_cuda_kernel.
//
// Launches one thread per row of `coords` (1024 threads per block) on the
// current CUDA stream of the device that holds `coeffs`, dispatching on the
// dtype of `coeffs`. The kernel writes its output into `coeffs`.
void coords_to_trilinear_cuda_impl(
    at::Tensor coords,
    at::Tensor points,
    at::Tensor coeffs
) {
  const int64_t num_coords = coords.size(0);

  // Nothing to do; a launch with a grid of 0 blocks is an invalid configuration.
  if (num_coords == 0) return;

  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(coeffs));
  auto stream = at::cuda::getCurrentCUDAStream();

  // Single, dtype-dispatched launch. The earlier untemplated launch on the
  // default stream with a hard-coded data_ptr<float>() is dropped: it would
  // run the kernel a second time and fail for non-float `coeffs`.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(coeffs.scalar_type(), "coords_to_trilinear_cuda", ([&] {
    coords_to_trilinear_cuda_kernel<scalar_t><<<(num_coords + 1023) / 1024, 1024, 0, stream>>>(
        reinterpret_cast<float3*>(coords.data_ptr<float>()),
        reinterpret_cast<point_data*>(points.data_ptr<short>()),
        coeffs.data_ptr<scalar_t>(),
        num_coords
    );
  }));
}

void coords_to_trilinear_jacobian_cuda_impl(
Expand Down
84 changes: 77 additions & 7 deletions kaolin/ops/spc/points.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
# Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,6 +18,7 @@
'points_to_morton',
'morton_to_points',
'points_to_corners',
'unbatched_interpolate_trilinear',
'coords_to_trilinear',
'unbatched_points_to_octree',
'quantize_points'
Expand Down Expand Up @@ -164,25 +165,94 @@ def points_to_corners(points):
shape.insert(-1, 8)
return _C.ops.spc.points_to_corners_cuda(points.contiguous()).reshape(*shape)

def coords_to_trilinear(coords, points):
class InterpolateTrilinear(torch.autograd.Function):
    r"""Autograd wrapper around the fused trilinear interpolation kernel.

    The forward pass calls the CUDA extension. The backward pass only
    produces a gradient for ``feats``; every other input receives ``None``.
    """

    @staticmethod
    def forward(ctx, coords, pidx, point_hierarchy, trinkets, feats, level):
        r"""Interpolates ``feats`` at ``coords``.

        All tensors are made contiguous before being handed to the extension.
        The tensors needed to rebuild the interpolation weights, the level and
        the shape of ``feats`` are stored on ``ctx`` for the backward pass.
        """
        feats_out = _C.ops.spc.interpolate_trilinear_cuda(
            coords.contiguous(), pidx.contiguous(),
            point_hierarchy.contiguous(), trinkets.contiguous(),
            feats.contiguous(), level)

        ctx.save_for_backward(coords, pidx, point_hierarchy, trinkets)
        ctx.level = level
        ctx.feats_shape = feats.shape
        return feats_out

    @staticmethod
    def backward(ctx, grad_output):
        r"""Scatters ``grad_output`` back onto the rows of ``feats``.

        Returns one entry per forward input; only the entry for ``feats``
        (index 4) can be a tensor.
        """
        # TODO(ttakikawa): Support backprop with respect to coords
        grad_feats = None
        if ctx.needs_input_grad[4]:
            coords, pidx, point_hierarchy, trinkets = ctx.saved_tensors
            level = ctx.level

            # Only samples that fall inside a voxel (pidx > -1) contribute.
            mask = pidx > -1
            valid_pidx = pidx[mask]
            selected_points = point_hierarchy.index_select(0, valid_pidx)
            selected_trinkets = trinkets.index_select(0, valid_pidx)

            # TODO(ttakikawa): Write a fused kernel
            grad_feats = torch.zeros(ctx.feats_shape, device=grad_output.device, dtype=grad_output.dtype)
            # (num_valid, num_samples, 8) interpolation weights.
            coeffs = coords_to_trilinear(
                coords[mask],
                selected_points[:, None].repeat(1, coords.shape[1], 1),
                level).type(grad_output.dtype)
            # Weight the incoming gradient per corner, sum over the samples of
            # each voxel, then accumulate into the feature rows each corner
            # refers to.
            corner_grads = (coeffs[..., None] * grad_output[mask][..., None, :]).sum(1)
            grad_feats.index_add_(0, selected_trinkets.reshape(-1),
                                  corner_grads.reshape(-1, ctx.feats_shape[-1]))
        return None, None, None, None, grad_feats, None

def unbatched_interpolate_trilinear(coords, pidx, point_hierarchy, trinkets, feats, level):
    r"""Performs trilinear interpolation on a SPC feature grid.

    Args:
        coords (torch.FloatTensor): 3D coordinates of shape
                                    :math:`(\text{num_voxels}, \text{num_samples}, 3)`
                                    in normalized space [-1, 1].
        pidx (torch.IntTensor): Index into the point hierarchy of the voxel that
                                contains each group of samples, of shape
                                :math:`(\text{num_voxels})`.
                                Groups with a negative index are not interpolated.
        point_hierarchy (torch.ShortTensor):
            The point hierarchy of shape :math:`(\text{num_points}, 3)`.
            See :ref:`point_hierarchies <spc_points>` for a detailed description.
        trinkets (torch.IntTensor): An indirection pointer (in practice, an index) to the feature
                                    tensor, of shape :math:`(\text{num_points}, 8)`.
        feats (torch.Tensor): Floating point feature vectors to interpolate of shape
                              :math:`(\text{num_feats}, \text{feature_dim})`.
        level (int): The level of SPC to interpolate on.

    Returns:
        (torch.Tensor): Interpolated feature vectors of shape
                        :math:`(\text{num_voxels}, \text{num_samples}, \text{feature_dim})`.
    """
    interpolate = InterpolateTrilinear.apply
    return interpolate(coords, pidx, point_hierarchy, trinkets, feats, level)

def coords_to_trilinear(coords, points, level):
    r"""Calculates the coefficients for trilinear interpolation.

    This calculates coefficients with respect to the dual octree, which represent the corners of the octree
    where the features are stored.

    To interpolate with the coefficients, do:
    ``torch.sum(features * coeffs, dim=-1)``
    with ``features`` of shape :math:`(\text{num_coords}, 8)`

    Args:
        coords (torch.FloatTensor): 3D coordinates of shape :math:`(\text{num_coords}, 3)`
                                    in normalized space [-1, 1].
        points (torch.ShortTensor): Quantized 3D points (the 0th bit of the voxel x is in),
                                    of shape :math:`(\text{num_coords}, 3)`.
        level (int): The level of SPC to interpolate on.

    Returns:
        (torch.FloatTensor):
            The trilinear interpolation coefficients of shape :math:`(\text{num_coords}, 8)`.
    """
    # The output keeps the leading dimensions of ``points``; only the last
    # dimension becomes the 8 corner coefficients.
    shape = list(points.shape)
    shape[-1] = 8
    points = points.reshape(-1, 3)
    coords = coords.reshape(-1, 3)
    # Map normalized [-1, 1] coordinates to the [0, 2^level] grid of the level,
    # which is the space the quantized ``points`` live in.
    coords_ = (2**level) * (coords * 0.5 + 0.5)

    return _C.ops.spc.coords_to_trilinear_cuda(coords_.contiguous(), points.contiguous()).reshape(*shape)

0 comments on commit 006fbf6

Please sign in to comment.