From e166bf87cdb1f538c91029f2764a07bf2aa785ca Mon Sep 17 00:00:00 2001 From: simonJJJ <821898965@qq.com> Date: Fri, 9 Jul 2021 18:56:08 +0800 Subject: [PATCH] fix pool gpu kernel --- oneflow/user/kernels/pool_gpu_kernel.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/pool_gpu_kernel.cpp b/oneflow/user/kernels/pool_gpu_kernel.cpp index aa4aa8c2e07..53206e544f7 100644 --- a/oneflow/user/kernels/pool_gpu_kernel.cpp +++ b/oneflow/user/kernels/pool_gpu_kernel.cpp @@ -43,9 +43,9 @@ class CudnnPoolDesc final { class GPUPoolOpKernelState final : public user_op::OpKernelState { public: - GPUPoolOpKernelState(const int32_t dim, const std::string& pooling_type, const Shape& x_shape, - const Shape& y_shape, const std::string& data_format, const DataType& dtype, - const Params3D& params_3d) + GPUPoolOpKernelState(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, + const ShapeView& y_shape, const std::string& data_format, + const DataType& dtype, const Params3D& params_3d) : dim_(dim), pooling_type_(pooling_type) { Reset(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); } @@ -82,7 +82,7 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState { const int32_t& dim, const std::string& pooling_type, user_op::KernelComputeContext* ctx) { if (pooling_type != "MAX" && pooling_type != "AVG") { UNIMPLEMENTED(); } const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - const Shape& x_shape = x_desc->shape(); + const ShapeView& x_shape = ctx->Tensor4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); const auto& padding_before = ctx->Attr>("padding_before"); @@ -92,8 +92,8 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState { const bool ceil_mode = ctx->Attr("ceil_mode"); const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); - const Shape y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); - const DataType dtype = x_desc->data_type(); + const ShapeView& y_shape = ctx->Tensor4ArgNameAndIndex("y", 0)->shape(); + const DataType dtype = ctx->Tensor4ArgNameAndIndex("x", 0)->data_type(); return std::make_shared(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); }