This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Use depthwise convolution (group convolution) via cuDNN v7 if available #10804

Closed
kice wants to merge 2 commits into apache:master from kice:master

Conversation

@kice (Contributor) commented May 4, 2018

Description

Use grouped convolution via cuDNN v7 to improve GPU memory usage. This commit is based on terrychenism@90cc3d5

Related PR: #7393

@nihui (Contributor) commented May 4, 2018

How about the speed improvement over the MXNet native implementation?

@szha (Member) commented May 4, 2018

@kice would you list some comparisons of speed and memory usage? Also, does the cuDNN API have any additional constraints?

@snflake commented May 4, 2018

Great work! This seems to explain MXNet's current low performance compared to TensorFlow when a dilation rate > 1 is used together with depthwise convolution. PR #7393 only addresses dilation rate = 1. TensorFlow's custom CUDA implementation also only covers dilation rate 1; they use cuDNN otherwise. The reason is that MXNet did not use the group feature of cuDNN v7, which this PR implements.
Would you fix the merge failure? I would like to test this.

@snflake commented May 4, 2018

The CI failure seems unrelated to this PR.

unknown file: Failure
C++ exception with description "[04:51:17] /work/mxnet/tests/cpp/operator/mkldnn.cc:85: Check failed: mkldnn_format_last == 56 (67 vs. 56)

@snflake commented May 4, 2018

About the speed: I used TensorRT with cuDNN 7 for inference, and depthwise convolution is very fast regardless of dilation rate. There is no need for a custom depthwise convolution implementation if cuDNN 7 grouping is used.

@IFeelBloated

I have been working on something recently that makes heavy use of ResNeXt building blocks; it would be nice to have grouped convolutions directly backed by cuDNN 7.

@piiswrong (Contributor)

How does the cuDNN implementation compare to the custom kernels from TF? Should we always use cuDNN?

@snflake commented May 4, 2018

I got similar runtimes with MobileNet v2 on a laptop (NVIDIA Quadro M1000M) using the custom kernel and grouped convolution via cuDNN 7. IMO, we should always use cuDNN if cuDNN 7 is available.

DType *out_ptr = GetNdPtr(out_data[conv::kOut], param_.kernel.ndim() + 2, s);

#if CUDNN_MAJOR >= 7
typename DataType<DType>::ScaleType alpha = 1.0f;
Contributor

Don't indent #if statements.

CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group));
#endif

#if CUDNN_MAJOR <= 6
Contributor

How about creating a new variable effective_num_group and setting it to 1 for cuDNN 7 and num_group otherwise, instead of always using #if tests?

}
}
#if CUDNN_MAJOR >= 7
typename DataType<DType>::ScaleType alpha = 1.0f;
Contributor

This code looks the same as the old version if you change the for loop to 0 -> effective_num_group?

out_ptr));
}
#else
for (uint32_t g = 0; g < param_.num_group; ++g) {
Contributor

Change to for (uint32_t g = 0; g < effective_num_group_; ++g) and you don't need #if tests anymore?

@piiswrong (Contributor)

@kice could you change the algorithm selection test to always prefer cuDNN over the custom kernel when cuDNN >= 7?

@nihui (Contributor) commented May 16, 2018

Hello, some feedback on the speed.

Hardware: Tesla M40 24G x 2
System: CentOS 7
NVIDIA driver: 387.26
CUDA: 9.1
cuDNN: v7.1

Model: MobileNet-v2
Batch size: 256 (128 per GPU)

MXNet implementation: 68 s / 10 iter
cuDNN v7 implementation: 9.5 s / 10 iter (roughly a 7x speedup)

PS: you need to comment out the MXNet DepthwiseConvolutionOp path in src/operator/nn/convolution.cu to enable the cuDNN one.

@HaoLiuHust (Contributor)

@nihui could you explain how you got that improvement?

@piiswrong (Contributor)

@kice Any updates?

@austingg

cuDNN has optimized some special paths for grouped convolution, mostly in cuDNN 7.0.3 and 7.0.4: "Performance improvements for grouped convolutions when input channels and output channels per group are 1, 2, or 4" (cuDNN release notes).

We may reference NVIDIA Caffe's verification:

#if CUDNN_VERSION_MIN(7, 0, 2)
  #define CUDNN_GROUPING
#endif
#if CUDNN_VERSION_MIN(7, 0, 3)
  #define CUDNN_GROUPING2
#endif

bool use_v7grouping() const {
#if defined(CUDNN_GROUPING2)
    return (this->channels_ == this->group_
         || this->channels_ == this->group_ * 2
         || this->channels_ == this->group_ * 4)
        && (this->num_output_ == this->group_
         || this->num_output_ == this->group_ * 2
         || this->num_output_ == this->group_ * 4);
#elif defined(CUDNN_GROUPING)
    return this->channels_ == this->num_output_ && this->channels_ == this->group_;
#else
    return false;
#endif
}

For the old path, it still uses a for loop.

@austingg

I have tested this PR on two 1080 Tis, batch size 96 each, and the speed is about 480 images/sec. However, the original depthwise convolution path gets 650 images/sec, so the speedup may be architecture dependent. Besides, I also tested NVIDIA Caffe on MobileNet-v2-1.0 with cuDNN 7.0.5; it gets 620 images/sec. I believe further fine-tuning is needed.

@piiswrong (Contributor)

Please move to #11076

@piiswrong closed this May 29, 2018