This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-491] Use depthwise convolution by cuDNNv7 if available, updated version #11076

Merged (6 commits) on May 29, 2018

Conversation

@nihui (Contributor) commented May 28, 2018

This pull request is based on #10804, with the following further changes:

  1. reduce indentation changes
  2. prefer the cuDNN depthwise convolution over the MXNet implementation

It still uses explicit #if/#else/#endif statements on the backward code path rather than the new effective_num_group variable, for compatibility: the name effective_num_group may confuse readers by suggesting standard group convolution.
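A minimal sketch of the dispatch being discussed (hypothetical Python pseudologic, not the actual C++ code in this PR): prefer cuDNN's native grouped convolution for depthwise cases when cuDNN >= 7, otherwise fall back to MXNet's own depthwise kernel. The function name and return labels are illustrative only.

```python
def choose_conv_impl(cudnn_version, num_group, in_channels):
    """Return which implementation would handle this convolution.

    cudnn_version uses cuDNN's integer encoding, e.g. 7100 for 7.1.
    """
    is_depthwise = num_group == in_channels and in_channels > 1
    if is_depthwise and cudnn_version >= 7000:
        # cuDNN >= 7 supports grouped convolution natively
        # (cudnnSetConvolutionGroupCount), which covers depthwise conv.
        return "cudnn_group_conv"
    if is_depthwise:
        # Older cuDNN: use MXNet's own depthwise kernel instead.
        return "mxnet_depthwise_kernel"
    return "cudnn_conv"
```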

Some feedback on speed:

hardware: Tesla M40 24 GB x 2
system: CentOS 7
NVIDIA driver 387.26
CUDA 9.1
cuDNN v7.1

model: MobileNet-v2
batch size: 256 (128 per GPU)

MXNet implementation: 68 s / 10 iterations
cuDNN v7 implementation: 9.5 s / 10 iterations
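For reference, the timings above work out to roughly a 7x speedup:

```python
# Derived from the M40 numbers reported above.
mxnet_time = 68.0   # seconds per 10 iterations, MXNet depthwise kernel
cudnn_time = 9.5    # seconds per 10 iterations, cuDNN v7 group conv
speedup = mxnet_time / cudnn_time
print(f"speedup: {speedup:.1f}x")  # roughly 7.2x
```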

@nihui changed the title from "Use depthwise convolution by cuDNNv7 if available, updated version" to "[MXNET-491] Use depthwise convolution by cuDNNv7 if available, updated version" on May 28, 2018
@piiswrong (Contributor)
I still think this is too much duplicated code.
We can add comments to explain what effective_num_group is.

@piiswrong (Contributor)
Also, please correct the indentation. '#if/#endif' statements should not be indented.

@piiswrong (Contributor)
Actually, I guess this is fine, since we'll eventually remove all the logic for older versions of cuDNN.

@piiswrong (Contributor)
@austingg I merged this before seeing your comment. Do you have any concerns?

@austingg
@piiswrong we need more speed benchmarks on different GPU architectures, and a more precise cuDNN version macro check, like NVIDIA Caffe uses.
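On the version macro: as far as I know, cudnn.h encodes the version as a single integer, CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL, so a guard such as `#if CUDNN_VERSION >= 7000` covers every 7.x release. A small Python mirror of that encoding:

```python
def cudnn_version(major, minor, patch):
    # Mirrors how cudnn.h composes CUDNN_VERSION:
    # CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL
    return major * 1000 + minor * 100 + patch

assert cudnn_version(7, 1, 4) == 7104
assert cudnn_version(7, 1, 4) >= 7000   # guard for "any cuDNN 7.x"
assert cudnn_version(6, 0, 21) < 7000   # older cuDNN falls back
```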

rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
…d version (apache#11076)

* Use group convolution by cuDNNv7 if available

* Fix coding style

* ident-- for #if statements

* more ident--

* more ident--

* prefer cudnnv7 depthwise convolution
@BiranLi commented Jun 5, 2018

I have tested MobileNet-v2 on a V100.

hardware: Tesla V100 32 GB
system: CentOS 7.2
NVIDIA driver 396.26
CUDA 9.2
cuDNN v7.1

model: MobileNet-v2 (forward only)
batch size: 1
TF-style depthwise conv: 194 samples/s
cuDNN group conv: 235 samples/s

@austingg commented Jun 5, 2018

@BiranLi could you run some more benchmarks, e.g. with larger batch sizes and with the backward pass included?

@BiranLi commented Jun 5, 2018

@austingg Yes, I have tested the same case with batch size 128.

model: MobileNet-v2 (forward only)
batch size: 128
TF-style depthwise conv: 2000 samples/s
cuDNN group conv: 2335 samples/s
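In relative terms, the V100 numbers in this and the previous comment come out to roughly a 21% (batch 1) and 17% (batch 128) throughput gain:

```python
# samples/s pairs reported above: (tf depthwise conv, cuDNN group conv)
gains = {
    "batch 1":   (194, 235),
    "batch 128": (2000, 2335),
}
for name, (before, after) in gains.items():
    print(f"{name}: +{(after / before - 1) * 100:.0f}%")
```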

@haojin2 (Contributor) commented Jun 8, 2018

@BiranLi can you share some more details on how you're running this benchmark? Thanks!

@haojin2 (Contributor) commented Jun 11, 2018

Ran some extra benchmarks and verified the multi-precision training speed improvement on a single V100 GPU with MobileNet on the ImageNet dataset:
before:
INFO:root:Epoch[0] Batch [20] Speed: 95.60 samples/sec accuracy=0.013765
INFO:root:Epoch[0] Batch [40] Speed: 95.73 samples/sec accuracy=0.148047
INFO:root:Epoch[0] Batch [60] Speed: 95.73 samples/sec accuracy=0.865234
INFO:root:Epoch[0] Batch [80] Speed: 95.75 samples/sec accuracy=1.000000
INFO:root:Epoch[0] Batch [100] Speed: 95.72 samples/sec accuracy=1.000000
after:
INFO:root:Epoch[0] Batch [20] Speed: 1011.35 samples/sec accuracy=0.013765
INFO:root:Epoch[0] Batch [40] Speed: 1032.15 samples/sec accuracy=0.112109
INFO:root:Epoch[0] Batch [60] Speed: 1038.41 samples/sec accuracy=0.832812
INFO:root:Epoch[0] Batch [80] Speed: 1034.26 samples/sec accuracy=1.000000
INFO:root:Epoch[0] Batch [100] Speed: 1032.14 samples/sec accuracy=1.000000
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request Jun 11, 2018
…d version (apache#11076)

* Use group convolution by cuDNNv7 if available

* Fix coding style

* ident-- for #if statements

* more ident--

* more ident--

* prefer cudnnv7 depthwise convolution
piiswrong pushed a commit that referenced this pull request Jun 12, 2018
…d version (#11076) (#11233)

* Use group convolution by cuDNNv7 if available

* Fix coding style

* ident-- for #if statements

* more ident--

* more ident--

* prefer cudnnv7 depthwise convolution
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
…d version (apache#11076)

* Use group convolution by cuDNNv7 if available

* Fix coding style

* ident-- for #if statements

* more ident--

* more ident--

* prefer cudnnv7 depthwise convolution
@shesung (Contributor) commented Aug 28, 2018

I observed almost no improvement when using MXNet 1.2.1 + CUDA 8 + cuDNN 7.2.1 on a 1080 Ti.
When setting MXNET_CUDNN_AUTOTUNE_DEFAULT=0, performance drops sharply.
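For context: MXNET_CUDNN_AUTOTUNE_DEFAULT controls whether MXNet benchmarks the available cuDNN convolution algorithms on first use and caches the fastest one; with it set to 0, MXNet skips autotuning and uses a fixed heuristic choice, which is one plausible explanation for the slowdown. A minimal sketch (the variable must be set before mxnet is imported):

```python
import os

# 1 is the default: autotune cuDNN conv algorithms on first use.
# 0 disables autotuning; can be much slower for grouped/depthwise conv.
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "1"

# import mxnet as mx  # must happen after the env var is set
```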

7 participants