Skip to content

Commit

Permalink
fix issue with unbatched_mesh_intersection_cuda, separate tests and i…
Browse files Browse the repository at this point in the history
…mprove coverage of edge cases (#588)

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

add math .cuh files

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
  • Loading branch information
Caenorst committed Jul 6, 2022
1 parent abe0069 commit b3e159b
Show file tree
Hide file tree
Showing 9 changed files with 662 additions and 307 deletions.
97 changes: 97 additions & 0 deletions kaolin/csrc/2d_math.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KAOLIN_2D_MATH_CUH_
#define KAOLIN_2D_MATH_CUH_

#include <type_traits>


namespace kaolin {

// TODO(cfujitsang): at some point we need coverage of fp16 and integers but it might be trickier
template<typename T>
struct ScalarTypeToVec2 { using type = void; };
template <> struct ScalarTypeToVec2<float> { using type = float2; };
template <> struct ScalarTypeToVec2<double> { using type = double2; };

template<typename V>
struct Vec2TypeToScalar { using type = void; };
template <> struct Vec2TypeToScalar<float2> { using type = float; };
template <> struct Vec2TypeToScalar<double2> { using type = double; };

template<typename T>
struct IsVec2Type: std::false_type {};
template <> struct IsVec2Type<float2>: std::true_type {};
template <> struct IsVec2Type<double2>: std::true_type {};

__device__
static __forceinline__ float2 make_vec2(float x, float y) {
return make_float2(x, y);
}

__device__
static __forceinline__ double2 make_vec2(double x, double y) {
return make_double2(x, y);
}

template<typename V,
typename T = typename Vec2TypeToScalar<V>::type,
std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__ __forceinline__ T dot(const V a, const V b) {
return a.x * b.x + a.y * b.y;
}

template<typename V, std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator- (V a, const V& b) {
a.x -= b.x;
a.y -= b.y;
return a;
}

template<typename V, std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator+ (V a, const V& b) {
a.x += b.x;
a.y += b.y;
return a;
}

template<typename V, std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator* (V a, const V& b) {
a.x *= b.x;
a.y *= b.y;
return a;
}

template<typename V, std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator/ (V a, const V& b) {
a.x /= b.x;
a.y /= b.y;
return a;
}

template<typename V, std::enable_if_t<IsVec2Type<V>::value>* = nullptr>
__device__
static __forceinline__ bool operator== (const V& a, const V& b) {
return a.x == b.x && a.y == b.y;
}

}

#endif // KAOLIN_2D_MATH_CUH_
109 changes: 109 additions & 0 deletions kaolin/csrc/3d_math.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KAOLIN_3D_MATH_CUH_
#define KAOLIN_3D_MATH_CUH_

#include <type_traits>


namespace kaolin {

// TODO(cfujitsang): at some point we need coverage of fp16 and integers but it might be trickier
template<typename T>
struct ScalarTypeToVec3 { using type = void; };
template <> struct ScalarTypeToVec3<float> { using type = float3; };
template <> struct ScalarTypeToVec3<double> { using type = double3; };

template<typename V>
struct Vec3TypeToScalar { using type = void; };
template <> struct Vec3TypeToScalar<float3> { using type = float; };
template <> struct Vec3TypeToScalar<double3> { using type = double; };

template<typename T>
struct IsVec3Type: std::false_type {};
template <> struct IsVec3Type<float3>: std::true_type {};
template <> struct IsVec3Type<double3>: std::true_type {};

__device__
static __forceinline__ float3 make_vec3(float x, float y, float z) {
return make_float3(x, y, z);
}

__device__
static __forceinline__ double3 make_vec3(double x, double y, double z) {
return make_double3(x, y, z);
}

template<typename V,
typename T = typename Vec3TypeToScalar<V>::type,
std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__ __forceinline__ T dot(const V a, const V b) {
return a.x * b.x + a.y * b.y + a.z * b.z;
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ V cross(const V& a, const V& b) {
return make_vec3(a.y * b.z - a.z * b.y,
a.z * b.x - a.x * b.z,
a.x * b.y - a.y * b.x);
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator- (V a, const V& b) {
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
return a;
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator+ (V a, const V& b) {
a.x += b.x;
a.y += b.y;
a.z += b.z;
return a;
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator* (V a, const V& b) {
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
return a;
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ V operator/ (V a, const V& b) {
a.x /= b.x;
a.y /= b.y;
a.z /= b.z;
return a;
}

template<typename V, std::enable_if_t<IsVec3Type<V>::value>* = nullptr>
__device__
static __forceinline__ bool operator== (const V& a, const V& b) {
return a.x == b.x && a.y == b.y && a.z == b.z;
}

}

#endif // KAOLIN_3D_MATH_CUH_
83 changes: 36 additions & 47 deletions kaolin/csrc/ops/mesh/mesh_intersection.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2019,20-22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,61 +20,49 @@
namespace kaolin {

#ifdef WITH_CUDA
void UnbatchedMeshIntersectionKernelLauncher(
const float* points,
const float* verts_1,
const float* verts_2,
const float* verts_3,
const int n,
const int m,
float* result);
void unbatched_mesh_intersection_cuda_impl(
const at::Tensor points,
const at::Tensor verts_1,
const at::Tensor verts_2,
const at::Tensor verts_3,
at::Tensor result);
#endif


void unbatched_mesh_intersection_cuda(
at::Tensor unbatched_mesh_intersection_cuda(
const at::Tensor points,
const at::Tensor verts_1,
const at::Tensor verts_2,
const at::Tensor verts_3,
const at::Tensor ints) {
CHECK_CUDA(points);
CHECK_CUDA(verts_1);
CHECK_CUDA(verts_2);
CHECK_CUDA(verts_3);
CHECK_CUDA(ints);
CHECK_CONTIGUOUS(points);
CHECK_CONTIGUOUS(verts_1);
CHECK_CONTIGUOUS(verts_2);
CHECK_CONTIGUOUS(verts_3);
CHECK_CONTIGUOUS(ints);

TORCH_CHECK(verts_1.size(0) == verts_2.size(0), "vert_1 and verts_2 must have the same number of points.");
TORCH_CHECK(verts_1.size(0) == verts_3.size(0), "vert_1 and verts_3 must have the same number of points.");
TORCH_CHECK(ints.size(0) == points.size(0), "ints and points must have the same number of points.");

TORCH_CHECK(verts_1.dim() == 2, "verts_1 must have a dimension of 2.");
TORCH_CHECK(verts_2.dim() == 2, "verts_2 must have a dimension of 2.");
TORCH_CHECK(verts_3.dim() == 2, "verts_3 must have a dimension of 2.");
TORCH_CHECK(points.dim() == 2, "points must have a dimension of 2.");
TORCH_CHECK(ints.dim() == 1, "ints must have a dimension of 1.");

TORCH_CHECK(verts_1.size(1) == 3, "verts_1's last dimension must be 3.");
TORCH_CHECK(verts_2.size(1) == 3, "verts_2's last dimension must be 3.");
TORCH_CHECK(verts_3.size(1) == 3, "verts_3's last dimension must be 3.");
TORCH_CHECK(points.size(1) == 3, "points's last dimension must be 3.");

TORCH_CHECK(verts_1.dtype() == verts_2.dtype(), "verts_1 and verts_2's dtype must be the same.");
TORCH_CHECK(verts_1.dtype() == verts_3.dtype(), "verts_1 and verts_3's dtype must be the same.");
TORCH_CHECK(verts_1.dtype() == points.dtype(), "verts_1 and points's dtype must be the same.");
const at::Tensor verts_3) {

at::TensorArg points_arg{points, "points", 1};
at::TensorArg verts_1_arg{verts_1, "verts_1", 2};
at::TensorArg verts_2_arg{verts_2, "verts_2", 3};
at::TensorArg verts_3_arg{verts_3, "verts_3", 4};

const int num_points = points.size(0);
const int num_vertices = verts_1.size(0);

at::checkAllSameGPU(__func__, {
points_arg, verts_1_arg, verts_2_arg, verts_3_arg});
at::checkAllContiguous(__func__, {
points_arg, verts_1_arg, verts_2_arg, verts_3_arg});
at::checkAllSameType(__func__, {
points_arg, verts_1_arg, verts_2_arg, verts_3_arg});

at::checkSize(__func__, points_arg, {num_points, 3});
at::checkSize(__func__, verts_1_arg, {num_vertices, 3});
at::checkSize(__func__, verts_2_arg, {num_vertices, 3});
at::checkSize(__func__, verts_3_arg, {num_vertices, 3});

at::Tensor out = at::zeros({num_points}, points.options());

#ifdef WITH_CUDA
UnbatchedMeshIntersectionKernelLauncher(points.data<float>(), verts_1.data<float>(),
verts_2.data<float>(),verts_3.data<float>(), points.size(0),
verts_1.size(0),
ints.data<float>());
unbatched_mesh_intersection_cuda_impl(points, verts_1, verts_2, verts_3, out);
#else
AT_ERROR("unbatched_mesh_intersection is not built with CUDA");
#endif
KAOLIN_NO_CUDA_ERROR(__func__);
#endif
return out;
}

} // namespace kaolin
5 changes: 2 additions & 3 deletions kaolin/csrc/ops/mesh/mesh_intersection.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@

namespace kaolin {

void unbatched_mesh_intersection_cuda(
at::Tensor unbatched_mesh_intersection_cuda(
const at::Tensor points,
const at::Tensor verts_1,
const at::Tensor verts_2,
const at::Tensor verts_3,
const at::Tensor ints);
const at::Tensor verts_3);

} // namespace kaolin

Expand Down

0 comments on commit b3e159b

Please sign in to comment.