Optimize ReduceSumFloatCudaKernel with GEMM #7684

Merged: 4 commits, Mar 7, 2022
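The optimization in this PR: when the axes being summed form one contiguous block at either end of the tensor, `reduce_sum` is exactly a matrix-vector product against a ones vector, out = In · 1, which cuBLAS executes far more efficiently than a generic elementwise reduction. The diff routes the float-on-CUDA case through OneFlow's `ep::primitive::Matmul`. For illustration only, here is a standalone sketch of the same trick using raw cuBLAS (`cublasSgemv` instead of the primitive layer, made-up toy sizes, error checks omitted) — it is not part of the PR:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  // Toy layout: a row-major (outer x reduce) buffer, summed over its last axis.
  const int outer = 4, reduce = 3;
  std::vector<float> h_in(outer * reduce);
  for (int i = 0; i < outer * reduce; ++i) { h_in[i] = static_cast<float>(i); }
  std::vector<float> h_ones(reduce, 1.0f), h_out(outer, 0.0f);

  float *d_in = nullptr, *d_ones = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_ones, h_ones.size() * sizeof(float));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ones, h_ones.data(), h_ones.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;
  // cuBLAS is column-major, so the row-major (outer x reduce) buffer is a
  // (reduce x outer) column-major matrix A with lda = reduce; the per-row
  // sums are then out = A^T * ones.
  cublasSgemv(handle, CUBLAS_OP_T, reduce, outer, &alpha, d_in, reduce, d_ones, 1, &beta, d_out, 1);

  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < outer; ++i) { printf("row %d sum = %g\n", i, h_out[i]); }  // 3 12 21 30

  cublasDestroy(handle);
  cudaFree(d_in);
  cudaFree(d_ones);
  cudaFree(d_out);
  return 0;
}

Build with, e.g., `nvcc reduce_sum_gemv.cu -lcublas` (the file name is hypothetical). The kernel in the diff below applies the same idea as a GEMM with n = 1, choosing the transpose flag by whether the contiguous reduced block is innermost or outermost. In the diff, lines prefixed with `-` were removed; unprefixed lines are the resulting code.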
97 changes: 91 additions & 6 deletions oneflow/user/kernels/reduce_kernel.cpp
@@ -23,6 +23,7 @@ limitations under the License.

#ifdef WITH_CUDA
#include "oneflow/core/ep/cuda/cuda_device.h"
#include "oneflow/core/ep/include/primitive/matmul.h"
#endif // WITH_CUDA

namespace oneflow {
@@ -86,7 +87,6 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
#define REGISTER_REDUCE_ARITHMETIC_KERNELS(device, dtype)                    \
  REGISTER_REDUCE_XPU_KERNEL("reduce_prod", BinaryFuncProd, device, dtype)   \
  REGISTER_REDUCE_XPU_KERNEL("reduce_min", BinaryFuncMin, device, dtype)     \
-  REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype)    \
  REGISTER_REDUCE_XPU_KERNEL("reduce_max", BinaryFuncMax, device, dtype)

#define REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(device) \
@@ -102,6 +102,22 @@ REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCPU)
REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCUDA)
#endif

#define REGISTER_REDUCE_SUM_KERNELS(device, dtype) \
  REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype)

#define REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(device) \
  REGISTER_REDUCE_SUM_KERNELS(device, double)         \
  REGISTER_REDUCE_SUM_KERNELS(device, int8_t)         \
  REGISTER_REDUCE_SUM_KERNELS(device, uint8_t)        \
  REGISTER_REDUCE_SUM_KERNELS(device, int32_t)        \
  REGISTER_REDUCE_SUM_KERNELS(device, int64_t)

REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCUDA)
#endif
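// float is registered only for CPU here; on CUDA, float reduce_sum is served
// by the GEMM-based ReduceSumFloatCudaKernel added below.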
REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float)

#define REGISTER_REDUCE_LOGICAL_KERNELS(device)                                 \
  REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, bool) \
  REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, bool) \
@@ -133,10 +149,12 @@ std::vector<int32_t> RegularAxis(const std::vector<int32_t>& axis) {
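// e.g. for in_shape = (2, 3, 4, 5) and axis = {1, 2}: the reduced axes are
// contiguous, so outer_size = 2, reduce_size = 3 * 4 = 12, inner_size = 5.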
void GetReduceSumLayout(const std::vector<int32_t>& axis, const ShapeView& in_shape,
                        bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size,
                        int64_t* reduce_size) {
-  *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
-  *outer_size = in_shape.Count(0, axis.front());
-  *inner_size = in_shape.Count(axis.back() + 1);
-  *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
  if (!axis.empty()) {
    *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());
    *outer_size = in_shape.Count(0, axis.front());
    *inner_size = in_shape.Count(axis.back() + 1);
    *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);
  }
}

} // namespace
@@ -236,6 +254,73 @@ REGISTER_USER_KERNEL("reduce_sum")
      return tmp_bytes;
    });

-#endif
class ReduceSumFloatCudaKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  ReduceSumFloatCudaKernel() = default;
  ~ReduceSumFloatCudaKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>("axis"));
    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0);
    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    const ShapeView& in_shape = input_tensor->shape();
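    // elem_cnt == 0 means some input axis is zero-sized; the sum over an empty
    // slice is 0, so clear the output (if it is non-empty) and return early.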
    if (input_tensor->shape().elem_cnt() == 0) {
      if (output_tensor->shape().elem_cnt() != 0) {
        Memset<DeviceType::kCUDA>(
            ctx->stream(), output_tensor->mut_dptr<float>(), 0,
            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));
      }
      return;
    }
    bool is_axis_contiguous = false;
    int64_t outer_size = 0, inner_size = 0, reduce_size = 0;
    GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size);
    const float* ones = nullptr;
    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());
    if (cuda_device != nullptr) {
      ones = static_cast<const float*>(cuda_device->GetConstOnes(DataType::kFloat, reduce_size));
    }
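    // If the reduced axes form one contiguous block at either end of the shape
    // (inner_size == 1 or outer_size == 1) and a cached ones vector is available,
    // the sum is a (m x k) x (k x 1) GEMM; otherwise fall back to NdarrayReduce.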
    if ((!axis.empty()) && in_shape.NumAxes() > 0 && is_axis_contiguous
        && (outer_size == 1 || inner_size == 1) && ones != nullptr) {
      ep::primitive::BlasTransposeType trans_a = (inner_size == 1)
                                                     ? ep::primitive::BlasTransposeType::N
                                                     : ep::primitive::BlasTransposeType::T;
      ep::primitive::BlasTransposeType trans_b = ep::primitive::BlasTransposeType::N;
      const int32_t m = (inner_size == 1) ? outer_size : inner_size;
      const int32_t n = 1;
      const int32_t k = reduce_size;
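      // inner_size == 1: out(outer x 1) = in(outer x reduce)   * ones(reduce x 1)
      // outer_size == 1: out(inner x 1) = in(reduce x inner)^T * ones(reduce x 1)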
#if CUDA_VERSION >= 11000
      CublasMathModeGuard guard(ctx->stream()->As<ep::CudaStream>()->cublas_handle());
      // Disable TF32 so the float GEMM keeps full FP32 accuracy.
      guard.SetMathMode(CUBLAS_DEFAULT_MATH);
#endif  // CUDA_VERSION >= 11000
      auto matmul = ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
          DeviceType::kCUDA, DataType::kFloat, trans_a, trans_b);
      CHECK(matmul);
      matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0,
                     output_tensor->mut_dptr());
    } else {
      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});
      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(
          ctx->stream(), XpuVarNdarray<float>(reduced_shape, output_tensor->mut_dptr<float>()),
          XpuVarNdarray<const float>(input_tensor->shape(), input_tensor->dptr<float>()),
          XpuVarNdarray<float>(tmp_buffer->shape(), tmp_buffer->mut_dptr<float>()));
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("reduce_sum")
.SetCreateFn<ReduceSumFloatCudaKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
&& (user_op::HobDataType("output_tensor", 0) == DataType::kFloat))
.SetInferTmpSizeFn([](user_op::InferContext* ctx) {
const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape();
return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));
});
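// Note: the tmp buffer (one float per input element) only backs the
// NdarrayReduce fallback; the GEMM path does not touch it.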

#endif // WITH_CUDA

} // namespace oneflow