Optimize ReduceSumFloatCudaKernel with GEMM (#7684)
* Optimize ReduceSumFloatCudaKernel with GEMM

* Fix WITH_CUDA
liujuncheng committed Mar 7, 2022
1 parent 61ee046 commit 687dcdd
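The idea behind the change: when the reduced axes of a float tensor are contiguous, the sum can be written as a matrix product against a constant vector of ones, which the new ReduceSumFloatCudaKernel hands to the cuBLAS-backed Matmul primitive instead of the generic NdarrayReduce path. A minimal CPU sketch of that equivalence (illustration only, not code from this commit; the names and shapes are hypothetical):

#include <vector>

// Reduce-sum over the trailing axis of an [outer, reduce] view, written as out = in * ones.
// This mirrors the GEMM the kernel launches with m = outer, n = 1, k = reduce.
std::vector<float> ReduceSumAsGemv(const std::vector<float>& in, int outer, int reduce) {
  const std::vector<float> ones(reduce, 1.0f);  // stand-in for the device's const-ones buffer
  std::vector<float> out(outer, 0.0f);
  for (int m = 0; m < outer; ++m) {
    for (int k = 0; k < reduce; ++k) { out[m] += in[m * reduce + k] * ones[k]; }
  }
  return out;
}

On the GPU this lets the reduction run at GEMM throughput; the kernel also pins the cuBLAS math mode to CUBLAS_DEFAULT_MATH so that TF32 does not alter the float results.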
Showing 1 changed file with 91 additions and 6 deletions.
oneflow/user/kernels/reduce_kernel.cpp: 97 changes (91 additions, 6 deletions)
@@ -23,6 +23,7 @@ limitations under the License.

#ifdef WITH_CUDA
#include "oneflow/core/ep/cuda/cuda_device.h"
#include "oneflow/core/ep/include/primitive/matmul.h"
#endif // WITH_CUDA

namespace oneflow {
@@ -86,7 +87,6 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
#define REGISTER_REDUCE_ARITHMETIC_KERNELS(device, dtype) \
REGISTER_REDUCE_XPU_KERNEL("reduce_prod", BinaryFuncProd, device, dtype) \
REGISTER_REDUCE_XPU_KERNEL("reduce_min", BinaryFuncMin, device, dtype) \
REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype) \
REGISTER_REDUCE_XPU_KERNEL("reduce_max", BinaryFuncMax, device, dtype)

#define REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(device) \
@@ -102,6 +102,22 @@ REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCPU)
REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCUDA)
#endif

#define REGISTER_REDUCE_SUM_KERNELS(device, dtype) \
REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype)

#define REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(device) \
REGISTER_REDUCE_SUM_KERNELS(device, double) \
REGISTER_REDUCE_SUM_KERNELS(device, int8_t) \
REGISTER_REDUCE_SUM_KERNELS(device, uint8_t) \
REGISTER_REDUCE_SUM_KERNELS(device, int32_t) \
REGISTER_REDUCE_SUM_KERNELS(device, int64_t)

REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCUDA)
#endif
REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float)

#define REGISTER_REDUCE_LOGICAL_KERNELS(device) \
REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, bool) \
REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, bool) \
@@ -133,10 +149,12 @@ std::vector<int32_t> RegularAxis(const std::vector<int32_t>& axis) {
void GetReduceSumLayout(const std::vector<int32_t>& axis, const ShapeView& in_shape,
bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size,
int64_t* reduce_size) {
*is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
*outer_size = in_shape.Count(0, axis.front());
*inner_size = in_shape.Count(axis.back() + 1);
*reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
if (!axis.empty()) {
*is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
*outer_size = in_shape.Count(0, axis.front());
*inner_size = in_shape.Count(axis.back() + 1);
*reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
}
}

} // namespace
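A concrete reading of GetReduceSumLayout (hypothetical shape, not taken from the commit): with in_shape = (2, 3, 4, 5) and axis = {1, 2}, the reduced axes are contiguous (2 - 1 + 1 == axis.size()), so the tensor is viewed as [outer_size, reduce_size, inner_size] = [2, 3 * 4, 5] = [2, 12, 5]. The new axis.empty() guard simply leaves the outputs at their caller-initialized values when there is nothing to reduce.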
@@ -236,6 +254,73 @@ REGISTER_USER_KERNEL("reduce_sum")
return tmp_bytes;
});

#endif
class ReduceSumFloatCudaKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  ReduceSumFloatCudaKernel() = default;
  ~ReduceSumFloatCudaKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>("axis"));
    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0);
    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    const ShapeView& in_shape = input_tensor->shape();
    if (input_tensor->shape().elem_cnt() == 0) {
      if (output_tensor->shape().elem_cnt() != 0) {
        Memset<DeviceType::kCUDA>(
            ctx->stream(), output_tensor->mut_dptr<float>(), 0,
            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));
      }
      return;
    }
    bool is_axis_contiguous = false;
    int64_t outer_size = 0, inner_size = 0, reduce_size = 0;
    GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size);
    const float* ones = nullptr;
    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());
    if (cuda_device != nullptr) {
      ones = static_cast<const float*>(cuda_device->GetConstOnes(DataType::kFloat, reduce_size));
    }
    if ((!axis.empty()) && in_shape.NumAxes() > 0 && is_axis_contiguous
        && (outer_size == 1 || inner_size == 1) && ones != nullptr) {
      ep::primitive::BlasTransposeType trans_a = (inner_size == 1)
                                                     ? ep::primitive::BlasTransposeType::N
                                                     : ep::primitive::BlasTransposeType::T;
      ep::primitive::BlasTransposeType trans_b = ep::primitive::BlasTransposeType::N;
      const int32_t m = (inner_size == 1) ? outer_size : inner_size;
      const int32_t n = 1;
      const int32_t k = reduce_size;
#if CUDA_VERSION >= 11000
      CublasMathModeGuard guard(ctx->stream()->As<ep::CudaStream>()->cublas_handle());
      // disable tf32
      guard.SetMathMode(CUBLAS_DEFAULT_MATH);
#endif  // defined(WITH_CUDA) && CUDA_VERSION >= 11000
      auto matmul = ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
          DeviceType::kCUDA, DataType::kFloat, trans_a, trans_b);
      CHECK(matmul);
      matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0,
                     output_tensor->mut_dptr());
    } else {
      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});
      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(
          ctx->stream(), XpuVarNdarray<float>(reduced_shape, output_tensor->mut_dptr<float>()),
          XpuVarNdarray<const float>(input_tensor->shape(), input_tensor->dptr<float>()),
          XpuVarNdarray<float>(tmp_buffer->shape(), tmp_buffer->mut_dptr<float>()));
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("reduce_sum")
    .SetCreateFn<ReduceSumFloatCudaKernel>()
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
                     && (user_op::HobDataType("output_tensor", 0) == DataType::kFloat))
    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {
      const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape();
      return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));
    });

#endif // WITH_CUDA

} // namespace oneflow
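How the matched branch sets up the GEMM, worked through from the code above with hypothetical shapes: reducing the trailing axes gives inner_size == 1, so the input is an [outer_size, reduce_size] matrix with trans_a = N and m = outer_size; reducing the leading axes gives outer_size == 1, so the input is a [reduce_size, inner_size] matrix with trans_a = T and m = inner_size. For example, summing axis 0 of a (4, 6) float tensor yields outer_size = 1, reduce_size = 4, inner_size = 6, hence a launch with trans_a = T, m = 6, n = 1, k = 4 against a ones vector of length 4. Inputs that fit neither case (non-contiguous axes, both outer_size and inner_size greater than 1, or no const-ones buffer available) still fall back to NdarrayReduce, which is why the tmp-buffer size is kept at the full element count.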
