Skip to content

Commit

Permalink
fix data layout
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Nov 15, 2017
1 parent e825a49 commit 7c2fd61
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions paddle/operators/pool_cudnn_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;

if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}

cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
Expand Down Expand Up @@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;

if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}

cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
Expand Down

0 comments on commit 7c2fd61

Please sign in to comment.