fix var bug (#7517)
* fix var bug

* cuda malloc aligned size

* refine

* auto format by CI

* refine var kernel registration

* format

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
3 people authored and marigoold committed Mar 15, 2022
1 parent fc3fcaf commit 1128d55
Showing 3 changed files with 26 additions and 53 deletions.
74 changes: 25 additions & 49 deletions oneflow/user/kernels/variance_kernel.cpp
@@ -35,76 +35,52 @@ class VarKernel final : public user_op::OpKernel {
     const T* in_ptr = input->dptr<T>();
     T* out_ptr = output->mut_dptr<T>();
     const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>("dim");
-    T* tmp_buffer_ptr = axis.size() == input->shape().NumAxes()
-                            ? ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr<T>()
-                            : nullptr;
+    // only all dims cuda case will use tmp buffer.
+    T* tmp_buffer_ptr =
+        (axis.size() == input->shape().NumAxes() && DeviceType::kCUDA == device_type)
+            ? ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr<T>()
+            : nullptr;
     VarParamHelper param_helper(input->shape(), axis, unbiased);
     VarFunctor<device_type, T>()(ctx->stream(), in_ptr, out_ptr, tmp_buffer_ptr,
                                  param_helper.param);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
 
+#define REGISTER_VAR_CPU_KERNEL(dtype) \
+  REGISTER_USER_KERNEL("var").SetCreateFn<VarKernel<DeviceType::kCPU, dtype>>().SetIsMatchedHob( \
+      (user_op::HobDeviceType() == DeviceType::kCPU) \
+      && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
+REGISTER_VAR_CPU_KERNEL(float)
+REGISTER_VAR_CPU_KERNEL(double)
+#undef REGISTER_VAR_CPU_KERNEL
+
+#ifdef WITH_CUDA
+
 size_t InferTmpBufferSize(user_op::InferContext* ctx) {
   const TensorDesc& input = ctx->InputTensorDesc("input", 0);
   const Shape& input_shape = input.shape();
   const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>("dim");
   if (axis.size() == input_shape.NumAxes()) {
-    return static_cast<size_t>(std::ceil(std::sqrt(input.shape().elem_cnt())))
-           * GetSizeOfDataType(input.data_type()) * 3;
+    return GetCudaAlignedSize(
+        std::min(static_cast<int32_t>(std::ceil(std::sqrt(input.shape().elem_cnt()))),
+                 kCudaMaxBlocksNum)
+        * GetSizeOfDataType(input.data_type()) * 3);
   }
   return 0;
 }
 
-#define REGISTER_VAR_KERNEL(device, dtype) \
+#define REGISTER_VAR_CUDA_KERNEL(dtype) \
   REGISTER_USER_KERNEL("var") \
-      .SetCreateFn<VarKernel<device, dtype>>() \
-      .SetIsMatchedHob((user_op::HobDeviceType() == device) \
+      .SetCreateFn<VarKernel<DeviceType::kCUDA, dtype>>() \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
                        && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value)) \
       .SetInferTmpSizeFn(InferTmpBufferSize);
 
-#define REGISTER_VAR_KERNELS_WITH_DEVICE(device) \
-  REGISTER_VAR_KERNEL(device, float) \
-  REGISTER_VAR_KERNEL(device, double)
-
-REGISTER_VAR_KERNELS_WITH_DEVICE(DeviceType::kCPU)
-#ifdef WITH_CUDA
-REGISTER_VAR_KERNELS_WITH_DEVICE(DeviceType::kCUDA)
+REGISTER_VAR_CUDA_KERNEL(float)
+REGISTER_VAR_CUDA_KERNEL(double)
+#undef REGISTER_VAR_CUDA_KERNEL
 #endif
 
-#undef REGISTER_VAR_KERNELS_WITH_DEVICE
-#undef REGISTER_VAR_KERNEL
-
-template<DeviceType device_type, typename T>
-class VarGradKernel final : public user_op::OpKernel {
- public:
-  VarGradKernel() = default;
-  ~VarGradKernel() override = default;
-
- private:
-  void Compute(user_op::KernelComputeContext* ctx) const override {
-    // TODO(liufengwei): Kernel implementation replaces functional::xx
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-#define REGISTER_VAR_GRAD_KERNEL(device, dtype) \
-  REGISTER_USER_KERNEL("var_grad") \
-      .SetCreateFn<VarGradKernel<device, dtype>>() \
-      .SetIsMatchedHob((user_op::HobDeviceType() == device) \
-                       && (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
-
-#define REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(device) \
-  REGISTER_VAR_GRAD_KERNEL(device, float) \
-  REGISTER_VAR_GRAD_KERNEL(device, double)
-
-REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(DeviceType::kCPU)
-#ifdef WITH_CUDA
-REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE(DeviceType::kCUDA)
-#endif
-
-#undef REGISTER_VAR_GRAD_KERNELS_WITH_DEVICE
-#undef REGISTER_VAR_GRAD_KERNEL
-
 }  // namespace user_op
 }  // namespace oneflow
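
For context on the "cuda malloc aligned size" bullet: the reworked InferTmpBufferSize sizes the scratch buffer for the all-dims CUDA reduction as three per-block partial arrays, with the block count now clamped to kCudaMaxBlocksNum and the byte total rounded up by GetCudaAlignedSize instead of being returned raw. Below is a standalone sketch of that arithmetic; the constants (8192 blocks, 512-byte alignment) and the count/mean/m2 reading of the factor of three are assumptions for illustration, not taken from the diff.

```cuda
// Standalone sketch of the new tmp-buffer sizing; illustrative only, not OneFlow code.
// Assumed values: kCudaMaxBlocksNum = 8192 and a 512-byte CUDA allocation alignment.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>

namespace {

constexpr int32_t kCudaMaxBlocksNum = 8192;  // assumed grid-size cap
constexpr size_t kCudaAlignSize = 512;       // assumed allocation alignment in bytes

size_t GetCudaAlignedSizeSketch(size_t size) {
  return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize;  // round up to alignment
}

size_t InferTmpBufferSizeSketch(int64_t elem_cnt, size_t dtype_bytes) {
  // roughly one reduction block per sqrt(elem_cnt) elements, capped at the max grid size
  const int32_t num_blocks =
      std::min(static_cast<int32_t>(std::ceil(std::sqrt(static_cast<double>(elem_cnt)))),
               kCudaMaxBlocksNum);
  // three per-block partial arrays (e.g. count / mean / m2), rounded up to the alignment
  return GetCudaAlignedSizeSketch(num_blocks * dtype_bytes * 3);
}

}  // namespace

int main() {
  // a float tensor with 10^6 elements: 1000 blocks * 4 B * 3 = 12000 B -> 12288 B after alignment
  std::printf("%zu\n", InferTmpBufferSizeSketch(1000000, sizeof(float)));
  return 0;
}
```

The clamp matters for very large tensors: without it, sqrt(elem_cnt) could request more partial slots than blocks that will ever be launched, and the unaligned size could fall outside what the CUDA allocator hands back in aligned chunks.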
1 change: 1 addition & 0 deletions oneflow/user/kernels/variance_kernel_util.cu
@@ -160,6 +160,7 @@ __global__ void ComputeVarScalarOut(const T* in_ptr, T* out_ptr, T* tmp_buffer_p
     if (threadIdx.x == 0) {
       *out_ptr =
           cuda::layer_norm::Div(final_m2, (var_param.unbiased ? final_count - 1 : final_count));
+      done_block_count = 0;
     }
   }
 }
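
The one-line .cu change reads like the standard grid-wide reduction pattern: each block publishes its partial statistics, bumps a device-scope done_block_count, and the last block to arrive combines the partials and writes the scalar output. Because that counter lives in device memory it survives across kernel launches, so it has to be zeroed once it is consumed, which is what the added `done_block_count = 0;` does. The sketch below is a minimal, generic illustration of that pattern (summing instead of variance, made-up names, one thread per block), not OneFlow's actual ComputeVarScalarOut.

```cuda
// Generic "last block finalizes, then resets the counter" sketch; names such as
// SumWithLastBlockFinalize and partials are made up and this is NOT OneFlow's kernel.
#include <cstdio>
#include <cuda_runtime.h>

__device__ int done_block_count = 0;  // device-scope state that survives across launches

__global__ void SumWithLastBlockFinalize(const float* in, float* partials, float* out, int n) {
  // simplified: one thread per block; each block reduces a strided slice of the input
  float local = 0.f;
  for (int i = blockIdx.x; i < n; i += gridDim.x) { local += in[i]; }
  partials[blockIdx.x] = local;
  __threadfence();  // make the partial visible to the block that finalizes

  if (threadIdx.x == 0) {
    const bool is_last = (atomicAdd(&done_block_count, 1) == static_cast<int>(gridDim.x) - 1);
    if (is_last) {
      float sum = 0.f;
      for (int i = 0; i < static_cast<int>(gridDim.x); ++i) { sum += partials[i]; }
      *out = sum;
      done_block_count = 0;  // the point of the one-line fix: start from zero on the next launch
    }
  }
}

int main() {
  const int n = 1 << 16;
  const int blocks = 64;
  float *in = nullptr, *partials = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&partials, blocks * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) { in[i] = 1.f; }
  // Launch twice: without the reset above, the second launch would never see a "last"
  // block (the stale counter keeps growing) and *out would never be rewritten.
  for (int rep = 0; rep < 2; ++rep) {
    SumWithLastBlockFinalize<<<blocks, 1>>>(in, partials, out, n);
    cudaDeviceSynchronize();
    std::printf("sum = %f\n", *out);  // expect 65536.000000 both times
  }
  cudaFree(in);
  cudaFree(partials);
  cudaFree(out);
  return 0;
}
```

The same hazard is presumably what the fix addresses: a second variance launch would otherwise inherit a non-zero counter from the previous one, which also lines up with re-enabling the previously skipped all-dims test below.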
4 changes: 0 additions & 4 deletions python/oneflow/test/tensor/test_tensor_part_2.py
@@ -363,10 +363,6 @@ def test_floor_tensor_with_random_data(test_case):
         y = x.floor()
         return y
 
-    @unittest.skip(
-        "TODO: probably fail, skip for now and fix it in future."
-        "ref to: https://github.com/Oneflow-Inc/OneTeam/issues/1006#issuecomment-1022768858"
-    )
     @autotest(check_graph=True)
     def test_tensor_var_all_dim_with_random_data(test_case):
         device = random_device()