diff --git a/oneflow/user/kernels/variance_kernel.cpp b/oneflow/user/kernels/variance_kernel.cpp
index d943226f1b0..ad133841c04 100644
--- a/oneflow/user/kernels/variance_kernel.cpp
+++ b/oneflow/user/kernels/variance_kernel.cpp
@@ -35,9 +35,11 @@ class VarKernel final : public user_op::OpKernel {
     const T* in_ptr = input->dptr<T>();
     T* out_ptr = output->mut_dptr<T>();
     const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>("dim");
-    T* tmp_buffer_ptr = axis.size() == input->shape().NumAxes()
-                            ? ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr<T>()
-                            : nullptr;
+    // only all dims cuda case will use tmp buffer.
+    T* tmp_buffer_ptr =
+        (axis.size() == input->shape().NumAxes() && DeviceType::kCUDA == device_type)
+            ? ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr<T>()
+            : nullptr;
     VarParamHelper param_helper(input->shape(), axis, unbiased);
     VarFunctor<device_type, T>()(ctx->stream(), in_ptr, out_ptr, tmp_buffer_ptr,
                                  param_helper.param);
@@ -45,66 +47,40 @@ class VarKernel final : public user_op::OpKernel {
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
 
+#define REGISTER_VAR_CPU_KERNEL(dtype)                                                           \
+  REGISTER_USER_KERNEL("var").SetCreateFn<VarKernel<DeviceType::kCPU, dtype>>().SetIsMatchedHob( \
+      (user_op::HobDeviceType() == DeviceType::kCPU)                                             \
+      && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
+REGISTER_VAR_CPU_KERNEL(float)
+REGISTER_VAR_CPU_KERNEL(double)
+#undef REGISTER_VAR_CPU_KERNEL
+
+#ifdef WITH_CUDA
+
 size_t InferTmpBufferSize(user_op::InferContext* ctx) {
   const TensorDesc& input = ctx->InputTensorDesc("input", 0);
   const Shape& input_shape = input.shape();
   const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>("dim");
   if (axis.size() == input_shape.NumAxes()) {
-    return static_cast<size_t>(std::ceil(std::sqrt(input.shape().elem_cnt())))
-           * GetSizeOfDataType(input.data_type()) * 3;
+    return GetCudaAlignedSize(
+        std::min(static_cast<int32_t>(std::ceil(std::sqrt(input.shape().elem_cnt()))),
+                 kCudaMaxBlocksNum)
+        * GetSizeOfDataType(input.data_type()) * 3);
   }
   return 0;
 }
 
-#define REGISTER_VAR_KERNEL(device, dtype)                                                     \
+#define REGISTER_VAR_CUDA_KERNEL(dtype)                                                        \
   REGISTER_USER_KERNEL("var")                                                                  \
-      .SetCreateFn<VarKernel<device, dtype>>()                                                 \
-      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \
+      .SetCreateFn<VarKernel<DeviceType::kCUDA, dtype>>()                                      \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                         \
                        && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value)) \
       .SetInferTmpSizeFn(InferTmpBufferSize);
 
-#define REGISTER_VAR_KERNELS_WITH_DEVICE(device) \
-  REGISTER_VAR_KERNEL(device, float)             \
-  REGISTER_VAR_KERNEL(device, double)
-
-REGISTER_VAR_KERNELS_WITH_DEVICE(DeviceType::kCPU)
-#ifdef WITH_CUDA
-REGISTER_VAR_KERNELS_WITH_DEVICE(DeviceType::kCUDA)
+REGISTER_VAR_CUDA_KERNEL(float)
+REGISTER_VAR_CUDA_KERNEL(double)
+#undef REGISTER_VAR_CUDA_KERNEL
 #endif
 
-#undef REGISTER_VAR_KERNELS_WITH_DEVICE
-#undef REGISTER_VAR_KERNEL
-
-template<DeviceType device_type, typename T>
-class VarGradKernel final : public user_op::OpKernel {
- public:
-  VarGradKernel() = default;
-  ~VarGradKernel() override = default;
-
- private:
-  void Compute(user_op::KernelComputeContext* ctx) const override {
-    // TODO(liufengwei): Kernel implementation replaces functional::xx
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-#define REGISTER_VAR_GRAD_KERNEL(device, dtype)                                                 \
-  REGISTER_USER_KERNEL("var_grad")                                                              \
-      .SetCreateFn<VarGradKernel<device, dtype>>()                                              \
-      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                     \
-                       && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
-
-#define REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(device) \
-  REGISTER_VAR_GRAD_KERNEL(device, float)             \
-  REGISTER_VAR_GRAD_KERNEL(device, double)
-
-REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(DeviceType::kCPU)
-#ifdef WITH_CUDA
-REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(DeviceType::kCUDA)
-#endif
-
-#undef REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE
-#undef REGISTER_VAR_GRAD_KERNEL
-
 }  // namespace user_op
 }  // namespace oneflow
diff --git a/oneflow/user/kernels/variance_kernel_util.cu b/oneflow/user/kernels/variance_kernel_util.cu
index d69a95b2695..7203624efca 100644
--- a/oneflow/user/kernels/variance_kernel_util.cu
+++ b/oneflow/user/kernels/variance_kernel_util.cu
@@ -160,6 +160,7 @@ __global__ void ComputeVarScalarOut(const T* in_ptr, T* out_ptr, T* tmp_buffer_p
     if (threadIdx.x == 0) {
       *out_ptr = cuda::layer_norm::Div(final_m2,
                                        (var_param.unbiased ? final_count - 1 : final_count));
+      done_block_count = 0;
     }
   }
 }
diff --git a/python/oneflow/test/tensor/test_tensor_part_2.py b/python/oneflow/test/tensor/test_tensor_part_2.py
index 114ec74415b..d0f0cc3e887 100644
--- a/python/oneflow/test/tensor/test_tensor_part_2.py
+++ b/python/oneflow/test/tensor/test_tensor_part_2.py
@@ -363,10 +363,6 @@ def test_floor_tensor_with_random_data(test_case):
         y = x.floor()
         return y
 
-    @unittest.skip(
-        "TODO: probably fail, skip for now and fix it in future."
-        "ref to: https://github.com/Oneflow-Inc/OneTeam/issues/1006#issuecomment-1022768858"
-    )
     @autotest(check_graph=True)
     def test_tensor_var_all_dim_with_random_data(test_case):
         device = random_device()
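Note on the one-line fix in variance_kernel_util.cu: ComputeVarScalarOut relies on a grid-wide counter so that the last block to finish performs the final reduction into the scalar output. If that counter is never reset, a second launch of the same kernel starts from a stale value and the "last block" branch never fires. The sketch below shows that pattern in isolation; it is not OneFlow's kernel, and the name `ReduceSum`, the `partials` buffer, and the 256-thread launch assumption are invented for the example.

```cuda
// Grid-wide completion counter; it persists across launches of the kernel below.
__device__ unsigned int done_block_count = 0;

// Each block writes one partial sum; the last block to finish folds all partials
// into *out and resets the counter so the next launch starts from zero again.
// Assumes a launch with exactly 256 threads per block.
__global__ void ReduceSum(const float* in, float* partials, float* out, int n) {
  __shared__ float cache[256];
  __shared__ bool is_last_block;

  // Block-local sum over a grid-stride range.
  float sum = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    sum += in[i];
  }
  cache[threadIdx.x] = sum;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) { cache[threadIdx.x] += cache[threadIdx.x + s]; }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    partials[blockIdx.x] = cache[0];
    __threadfence();  // publish the partial before signalling completion
    is_last_block = (atomicAdd(&done_block_count, 1) == gridDim.x - 1);
  }
  __syncthreads();

  if (is_last_block) {
    // Fold every block's partial into the final scalar.
    float total = 0.f;
    for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { total += partials[i]; }
    cache[threadIdx.x] = total;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.x < s) { cache[threadIdx.x] += cache[threadIdx.x + s]; }
      __syncthreads();
    }
    if (threadIdx.x == 0) {
      *out = cache[0];
      done_block_count = 0;  // without this reset, a second launch never sees a "last" block
    }
  }
}
```

Usage would be along the lines of `ReduceSum<<<num_blocks, 256>>>(in, partials, out, n)` with `partials` holding `num_blocks` floats; the reset on the last line is what keeps a second, identical launch correct, which is the role `done_block_count = 0;` plays in the diff.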
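Note on the new tmp-buffer size: the CUDA all-dims path launches on the order of sqrt(elem_cnt) blocks, now capped at kCudaMaxBlocksNum, and keeps three values per block (presumably the per-block count, mean, and M2 of the Welford-style reduction, given the `final_count`/`final_m2` names in the .cu file), with the byte count rounded up by GetCudaAlignedSize. A host-side sketch of the same arithmetic follows; `InferVarTmpBufferSize` is a hypothetical helper, and the two constants are stand-in values rather than OneFlow's definitions.

```cuda
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Illustrative stand-ins; the real constants live in OneFlow's headers.
constexpr int32_t kCudaMaxBlocksNum = 8192;
constexpr size_t kCudaMemAllocAlignSize = 512;

// Local stand-in for OneFlow's GetCudaAlignedSize: round a byte count up to the
// assumed allocation alignment.
size_t GetCudaAlignedSize(size_t size) {
  return (size + kCudaMemAllocAlignSize - 1) / kCudaMemAllocAlignSize * kCudaMemAllocAlignSize;
}

// Mirrors the new InferTmpBufferSize math: roughly sqrt(elem_cnt) blocks, capped at
// kCudaMaxBlocksNum, times three per-block statistics of dtype_size bytes each.
size_t InferVarTmpBufferSize(int64_t elem_cnt, size_t dtype_size) {
  const int32_t num_blocks =
      std::min(static_cast<int32_t>(std::ceil(std::sqrt(static_cast<double>(elem_cnt)))),
               kCudaMaxBlocksNum);
  return GetCudaAlignedSize(static_cast<size_t>(num_blocks) * dtype_size * 3);
}
```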