add conv3d_trans_cudnn_op #5603

Merged

Changes from all commits
33 changes: 20 additions & 13 deletions paddle/operators/CMakeLists.txt
@@ -61,16 +61,30 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()

if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()

# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()

# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()

if ("${TARGET}" STREQUAL "compare_op")
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif()

# pool_with_index_op contains several operators
@@ -80,25 +94,18 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()

# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()

# conv_transpose_op contains several operators
if ("${TARGET}" STREQUAL "conv_transpose_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
endif()
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")

# conv_transpose_cudnn_op contains two operators
if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
endif()

# save_restore_op contains several operators
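For context, op_library appends one USE_OP line per multi-operator target to the generated ${pybind_file}; a single marker per library is enough because linking any one operator's registrar pulls the whole static library into the pybind module. A hypothetical excerpt of what the generated file accumulates to after the branches above run (contents inferred from the file(APPEND ...) calls, not copied from an actual build):

// Hypothetical generated pybind registration file. Each USE_OP forces the
// linker to keep that operator's registration symbols; every other operator
// compiled into the same library target comes along with it.
USE_OP(less_than);
USE_OP(equal);
USE_OP(conv2d);
USE_OP(pool2d);
USE_OP(pool2d_cudnn);
USE_OP(max_pool2d_with_index);
USE_OP(conv2d_transpose);
USE_OP(conv2d_transpose_cudnn);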
10 changes: 4 additions & 6 deletions paddle/operators/conv_cudnn_op.cu.cc
@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset input_grad.

for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
@@ -241,9 +240,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset filter_grad.

for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
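The deleted zero-fill is safe to drop because the cuDNN backward calls blend as dst = alpha * result + beta * dst, and cuDNN documents that when beta is zero the destination is never read. A minimal CPU sketch of that contract (illustrative only, neither Paddle nor cuDNN code):

#include <cstddef>

// Sketch of cuDNN's alpha/beta blend semantics. With beta == 0 the
// destination is overwritten without being read, so pre-zeroing
// input_grad / filter_grad buys nothing.
void AxpbyBlend(float alpha, const float* src, float beta, float* dst,
                size_t n) {
  for (size_t i = 0; i < n; ++i) {
    // cuDNN special-cases beta == 0 so uninitialized dst (even NaN) cannot
    // leak into the result; the branch below mirrors that behavior.
    dst[i] = (beta == 0.0f) ? alpha * src[i] : alpha * src[i] + beta * dst[i];
  }
}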
12 changes: 8 additions & 4 deletions paddle/operators/conv_op.cc
@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);

REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
12 changes: 8 additions & 4 deletions paddle/operators/conv_op.cu.cc
@@ -17,11 +17,15 @@
namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);

REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
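These registration macros are variadic: each extra template instantiation listed registers the same op for one more element type, keyed roughly by (place, dtype), which is why float64 support is a two-line change per op. A minimal sketch of the dispatch idea (illustrative names, not Paddle's real internals):

#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Kernel registry keyed by (op type, element type), mirroring how listing
// GemmConvKernel<Place, double> alongside the float instantiation lets the
// framework pick the double kernel when inputs are float64.
std::map<std::pair<std::string, std::type_index>, std::function<void()>>
    registry;

template <typename T>
void RegisterGemmConv(const std::string& op_type) {
  registry[{op_type, std::type_index(typeid(T))}] = [] {
    // im2col + GEMM at precision T; float and double share this code path.
  };
}

int main() {
  RegisterGemmConv<float>("conv2d");   // pre-existing registration
  RegisterGemmConv<double>("conv2d");  // what this PR adds
  return registry.size() == 2 ? 0 : 1;
}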
paddle/operators/conv_transpose_cudnn_op.cc
@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
framework::OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
.SetDefault({1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};

class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);

REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
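The workspace_size_MB attribute declared above bounds the scratch memory cuDNN's algorithm search may use; the kernel reads it through ExecutionContext's Attr accessor and converts it to bytes. A small standalone sketch of that conversion (WorkspaceLimitBytes is a hypothetical helper; the real kernel does this inline):

#include <cstddef>
#include <vector>

// Hypothetical helper: turn the "workspace_size_MB" attribute value into the
// byte limit handed to cuDNN algorithm selection.
size_t WorkspaceLimitBytes(int workspace_size_mb) {
  return static_cast<size_t>(workspace_size_mb) * 1024 * 1024;
}

int main() {
  std::vector<int> dilations{1, 1, 1};       // 3-D default declared above
  size_t limit = WorkspaceLimitBytes(4096);  // default attribute value
  return (limit == 4096ull * 1024 * 1024 && dilations.size() == 3) ? 0 : 1;
}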
Contributor: REGISTER double for these ops.

Contributor Author: Done
paddle/operators/conv_transpose_cudnn_op.cu.cc
@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;

if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}

// N, M, H, W
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// N, C, O_h, O_w
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// M, C, K_h, K_w
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;

// Input: (N, M, H, W)
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_H, O_W)
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter (M, C, K_H, K_W)
// Filter: (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));

@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
math::set_constant(ctx.device_context(), input_grad, 0);

// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
math::set_constant(ctx.device_context(), filter_grad, 0);

// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);

REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
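The forward kernel now picks its cuDNN tensor layout from the rank of the strides attribute, which is how one kernel template serves both the 2-D and 3-D transpose ops. A condensed sketch of that dispatch (standalone; the enum is redeclared locally for illustration):

#include <vector>

enum class DataLayout { kNCHW, kNCDHW };  // local stand-in for the real enum

// Two stride values (H, W) imply a 4-D NCHW tensor; three (D, H, W) imply
// the 5-D NCDHW layout used by conv3d_transpose_cudnn.
DataLayout LayoutFor(const std::vector<int>& strides) {
  return strides.size() == 2U ? DataLayout::kNCHW : DataLayout::kNCDHW;
}

int main() {
  return LayoutFor({1, 1, 1}) == DataLayout::kNCDHW ? 0 : 1;
}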
Contributor: There is no double for these ops. Please add in next PR.

Contributor Author: Ok.
12 changes: 8 additions & 4 deletions paddle/operators/conv_transpose_op.cc
@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,

REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
12 changes: 8 additions & 4 deletions paddle/operators/conv_transpose_op.cu.cc
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);

REGISTER_OP_GPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
3 changes: 1 addition & 2 deletions paddle/operators/pool_cudnn_op.cu.cc
@@ -135,8 +135,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {

if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset input_grad.

PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
15 changes: 10 additions & 5 deletions paddle/platform/cudnn_helper.h
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} while (false)

enum class DataLayout {
enum class DataLayout { // Not used yet
kNHWC,
kNCHW,
kNCDHW,
kNCHW_VECT_C,
};

@@ -107,12 +108,15 @@ class CudnnDataType<double> {
}
};

inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
inline cudnnTensorFormat_t GetCudnnTensorFormat(
const DataLayout& order) { // Not used yet
switch (order) {
case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH): add CUDNN_TENSOR_NCDHW
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
Expand All @@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// FIXME(typhoonzero): Assume using NCHW order
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
const cudnnDataType_t type,
const std::vector<int>& kernel,
const int groups = 1) {
// filter layout: MCHW, where M is the number of
// filter layout: MCHW(MCDHW), where M is the number of
// output image channels, C is the number of input image channels,
// H and W is height and width of filter.
// D is the depth of the filter, H is the height of the filter, and W is the
// width of the filter.
std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
if (groups > 1) {
// M /= groups
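ScopedTensorDescriptor derives dense strides from whatever dims it is given, so the same code path covers NCHW and NCDHW tensors. A standalone sketch of that stride computation (mirroring the loop above; names are illustrative):

#include <vector>

// Dense strides for a contiguous tensor: each stride is the product of all
// trailing dimensions, exactly as the descriptor-setup loop above computes.
std::vector<int> DenseStrides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];
  }
  return strides;
}

int main() {
  // NCDHW dims {2, 3, 4, 5, 5} -> strides {300, 100, 25, 5, 1}.
  return DenseStrides({2, 3, 4, 5, 5})[0] == 300 ? 0 : 1;
}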
6 changes: 6 additions & 0 deletions python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
@@ -108,5 +108,11 @@ def init_test_case(self):
self.filter_size = [f_c, 6, 3, 3, 3]


# ------------ test_cudnn ------------
class TestCudnn(TestConv3dTransposeOp):
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"


if __name__ == '__main__':
unittest.main()