From 687dcdd125ec46f3fa1414b75edd467e3be90dd6 Mon Sep 17 00:00:00 2001
From: Juncheng
Date: Mon, 7 Mar 2022 15:15:36 +0800
Subject: [PATCH] Optimize ReduceSumFloatCudaKernel with GEMM (#7684)

* Optimize ReduceSumFloatCudaKernel with GEMM

* Fix WITH_CUDA
---
 oneflow/user/kernels/reduce_kernel.cpp | 97 ++++++++++++++++++++++++--
 1 file changed, 91 insertions(+), 6 deletions(-)

diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp
index 17dfbef87d1..8d1499b2717 100644
--- a/oneflow/user/kernels/reduce_kernel.cpp
+++ b/oneflow/user/kernels/reduce_kernel.cpp
@@ -23,6 +23,7 @@ limitations under the License.
 
 #ifdef WITH_CUDA
 #include "oneflow/core/ep/cuda/cuda_device.h"
+#include "oneflow/core/ep/include/primitive/matmul.h"
 #endif  // WITH_CUDA
 
 namespace oneflow {
@@ -86,7 +87,6 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
 #define REGISTER_REDUCE_ARITHMETIC_KERNELS(device, dtype)                  \
   REGISTER_REDUCE_XPU_KERNEL("reduce_prod", BinaryFuncProd, device, dtype) \
   REGISTER_REDUCE_XPU_KERNEL("reduce_min", BinaryFuncMin, device, dtype)   \
-  REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype)   \
   REGISTER_REDUCE_XPU_KERNEL("reduce_max", BinaryFuncMax, device, dtype)
 
 #define REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(device) \
@@ -102,6 +102,22 @@ REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCPU)
 REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCUDA)
 #endif
 
+#define REGISTER_REDUCE_SUM_KERNELS(device, dtype) \
+  REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype)
+
+#define REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(device) \
+  REGISTER_REDUCE_SUM_KERNELS(device, double)         \
+  REGISTER_REDUCE_SUM_KERNELS(device, int8_t)         \
+  REGISTER_REDUCE_SUM_KERNELS(device, uint8_t)        \
+  REGISTER_REDUCE_SUM_KERNELS(device, int32_t)        \
+  REGISTER_REDUCE_SUM_KERNELS(device, int64_t)
+
+REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCPU)
+#ifdef WITH_CUDA
+REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCUDA)
+#endif
+REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float)
+
 #define REGISTER_REDUCE_LOGICAL_KERNELS(device)                                 \
   REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, bool) \
   REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, bool) \
@@ -133,10 +149,12 @@ std::vector<int32_t> RegularAxis(const std::vector<int32_t>& axis) {
 void GetReduceSumLayout(const std::vector<int32_t>& axis, const ShapeView& in_shape,
                         bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size,
                         int64_t* reduce_size) {
-  *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
-  *outer_size = in_shape.Count(0, axis.front());
-  *inner_size = in_shape.Count(axis.back() + 1);
-  *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
+  if (!axis.empty()) {
+    *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
+    *outer_size = in_shape.Count(0, axis.front());
+    *inner_size = in_shape.Count(axis.back() + 1);
+    *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
+  }
 }
 
 }  // namespace
@@ -236,6 +254,73 @@ REGISTER_USER_KERNEL("reduce_sum")
         return tmp_bytes;
       });
 
-#endif
+class ReduceSumFloatCudaKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
+ public:
+  ReduceSumFloatCudaKernel() = default;
+  ~ReduceSumFloatCudaKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>("axis"));
+    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0);
+    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
+    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+    const ShapeView& in_shape = input_tensor->shape();
+    if (input_tensor->shape().elem_cnt() == 0) {
+      if (output_tensor->shape().elem_cnt() != 0) {
+        Memset<DeviceType::kCUDA>(
+            ctx->stream(), output_tensor->mut_dptr(), 0,
+            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));
+      }
+      return;
+    }
+    bool is_axis_contiguous = false;
+    int64_t outer_size = 0, inner_size = 0, reduce_size = 0;
+    GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size);
+    const float* ones = nullptr;
+    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());
+    if (cuda_device != nullptr) {
+      ones = static_cast<const float*>(cuda_device->GetConstOnes(DataType::kFloat, reduce_size));
+    }
+    if ((!axis.empty()) && in_shape.NumAxes() > 0 && is_axis_contiguous
+        && (outer_size == 1 || inner_size == 1) && ones != nullptr) {
+      ep::primitive::BlasTransposeType trans_a = (inner_size == 1)
+                                                     ? ep::primitive::BlasTransposeType::N
+                                                     : ep::primitive::BlasTransposeType::T;
+      ep::primitive::BlasTransposeType trans_b = ep::primitive::BlasTransposeType::N;
+      const int32_t m = (inner_size == 1) ? outer_size : inner_size;
+      const int32_t n = 1;
+      const int32_t k = reduce_size;
+#if CUDA_VERSION >= 11000
+      CublasMathModeGuard guard(ctx->stream()->As<ep::CudaStream>()->cublas_handle());
+      // disable tf32
+      guard.SetMathMode(CUBLAS_DEFAULT_MATH);
+#endif  // defined(WITH_CUDA) && CUDA_VERSION >= 11000
+      auto matmul = ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
+          DeviceType::kCUDA, DataType::kFloat, trans_a, trans_b);
+      CHECK(matmul);
+      matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0,
+                     output_tensor->mut_dptr());
+    } else {
+      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});
+      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(
+          ctx->stream(), XpuVarNdarray<float>(reduced_shape, output_tensor->mut_dptr<float>()),
+          XpuVarNdarray<const float>(input_tensor->shape(), input_tensor->dptr<float>()),
+          XpuVarNdarray<float>(tmp_buffer->shape(), tmp_buffer->mut_dptr<float>()));
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+REGISTER_USER_KERNEL("reduce_sum")
+    .SetCreateFn<ReduceSumFloatCudaKernel>()
+    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
+                     && (user_op::HobDataType("output_tensor", 0) == DataType::kFloat))
+    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {
+      const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape();
+      return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));
+    });
+
+#endif  // WITH_CUDA
 
 }  // namespace oneflow
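
A note on the if (!axis.empty()) guard added to GetReduceSumLayout above: the axis attribute can arrive empty (the kernel's GEMM branch also tests !axis.empty()), and calling front()/back() on an empty std::vector is undefined behavior. The host-side sketch below mirrors the same decomposition outside oneflow; ReduceLayout, ComputeLayout, and the plain std::vector shape are illustrative stand-ins, not oneflow API.

// Illustrative sketch only: mirrors GetReduceSumLayout with std::vector in
// place of oneflow's ShapeView. Struct and function names are made up.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct ReduceLayout {
  bool contiguous = false;
  // Left at zero when axis is empty, matching the kernel's local initializers.
  int64_t outer = 0;
  int64_t reduce = 0;
  int64_t inner = 0;
};

ReduceLayout ComputeLayout(const std::vector<int64_t>& dims, const std::vector<int32_t>& axis) {
  ReduceLayout layout;
  // The guard the patch adds: front()/back() on an empty vector is UB.
  if (axis.empty()) { return layout; }
  // axis is assumed sorted, as RegularAxis guarantees in the kernel above.
  auto count = [&](int64_t begin, int64_t end) {
    return std::accumulate(dims.begin() + begin, dims.begin() + end, int64_t{1},
                           std::multiplies<int64_t>());
  };
  layout.contiguous = (axis.back() - axis.front() + 1) == static_cast<int64_t>(axis.size());
  layout.outer = count(0, axis.front());                // axes before the reduced run
  layout.reduce = count(axis.front(), axis.back() + 1); // the reduced run itself
  layout.inner = count(axis.back() + 1, dims.size());   // axes after the reduced run
  return layout;
}

For example, dims = {2, 3, 4} with axis = {1, 2} yields outer = 2, reduce = 12, inner = 1, the row-sum case the kernel dispatches with trans_a == N; axis = {0, 1} yields outer = 1, reduce = 6, inner = 4, the transposed (column-sum) case.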
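And the core of the optimization itself: a sum over a contiguous run of axes with outer_size == 1 or inner_size == 1 is a matrix product with a vector of ones, y = X · 1, which Compute hands to the Matmul primitive so it runs as a cuBLAS GEMM instead of a hand-rolled reduction. Below is a self-contained sketch of the inner_size == 1 case written directly against cuBLAS rather than oneflow's primitives; the sizes, names, and CHECK_CUDA macro are illustrative only, and cuBLAS status checks are omitted for brevity.

// Standalone illustration (not part of the patch): row sums of a contiguous
// row-major [rows x cols] float buffer as a GEMM against a vector of ones.
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(expr)                                                            \
  do {                                                                              \
    cudaError_t _e = (expr);                                                        \
    if (_e != cudaSuccess) {                                                        \
      std::fprintf(stderr, "%s:%d CUDA error %d\n", __FILE__, __LINE__, (int)_e);   \
      std::exit(EXIT_FAILURE);                                                      \
    }                                                                               \
  } while (0)

int main() {
  const int rows = 4;     // plays the role of outer_size
  const int cols = 1024;  // plays the role of reduce_size
  std::vector<float> host(static_cast<size_t>(rows) * cols, 1.0f);
  std::vector<float> ones(cols, 1.0f);  // stands in for CudaDevice::GetConstOnes

  float *d_in = nullptr, *d_ones = nullptr, *d_out = nullptr;
  CHECK_CUDA(cudaMalloc(&d_in, host.size() * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_ones, ones.size() * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_out, rows * sizeof(float)));
  CHECK_CUDA(cudaMemcpy(d_in, host.data(), host.size() * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_ones, ones.data(), ones.size() * sizeof(float), cudaMemcpyHostToDevice));

  cublasHandle_t handle;
  cublasCreate(&handle);
  // Mirrors the patch's CublasMathModeGuard: keep this FP32 GEMM out of TF32.
  cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);

  // cuBLAS is column-major, so compute the transposed product:
  // C(1 x rows) = ones(1 x cols) * X(cols x rows), where the column-major
  // X(cols x rows) is byte-identical to the row-major input [rows x cols].
  // Each C[0][j] is therefore the sum of input row j.
  const float alpha = 1.0f, beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, /*m=*/1, /*n=*/rows, /*k=*/cols,
              &alpha, d_ones, 1, d_in, cols, &beta, d_out, 1);

  std::vector<float> result(rows);
  CHECK_CUDA(cudaMemcpy(result.data(), d_out, rows * sizeof(float), cudaMemcpyDeviceToHost));
  for (int i = 0; i < rows; ++i) { std::printf("row %d sum = %g\n", i, result[i]); }

  cublasDestroy(handle);
  CHECK_CUDA(cudaFree(d_in));
  CHECK_CUDA(cudaFree(d_ones));
  CHECK_CUDA(cudaFree(d_out));
  return 0;
}

Compiled as host code and linked with -lcublas -lcudart, this prints 1024 for each of the four rows. Pinning the math mode to CUBLAS_DEFAULT_MATH, as the patch's CublasMathModeGuard does, keeps the FP32 GEMM from being demoted to TF32, trading tensor-core throughput for full-precision FP32 multiplies in the reduction.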