This repository was archived by the owner on Nov 17, 2023. It is now read-only.

add depthwise convolution's gpu version optimization #7393

Merged
piiswrong merged 8 commits into apache:master from crazy-cat:master on Aug 16, 2017

Conversation

@crazy-cat (Contributor)

As cuDNN is not optimized for depthwise convolution, we optimized the GPU version of depthwise 2D convolution.
The training results are as follows:
cuDNN version: MobileNet training on ImageNet

cd example/image-classification/;
python train_imagenet.py --network mobilenet --gpus=0 --data-train=./train_480_q95.rec --data-nthreads 8

INFO:root:start with arguments Namespace(batch_size=128, benchmark=0, data_nthreads=8, data_train='./train_480_q95.rec', data_val=None, disp_batches=20, dtype='float32', gpus='0', image_shape='3,224,224', kv_store='device', load_epoch=None, lr=0.1, lr_factor=0.1, lr_step_epochs='30,60', max_random_aspect_ratio=0.25, max_random_h=36, max_random_l=50, max_random_rotate_angle=10, max_random_s=50, max_random_scale=1, max_random_shear_ratio=0.1, min_random_scale=1, model_prefix=None, mom=0.9, monitor=0, network='mobilenet', num_classes=1000, num_epochs=80, num_examples=1281167, num_layers=50, optimizer='sgd', pad_size=0, random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=0, wd=0.0001)
[10:03:41] src/io/iter_image_recordio_2.cc:135: ImageRecordIOParser2: ./train_480_q95.rec, use 7 threads for decoding..
[10:03:45] src/operator/././cudnn_algoreg-inl.h:65: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [20]	Speed: 133.85 samples/sec	accuracy=0.000744
INFO:root:Epoch[0] Batch [40]	Speed: 135.98 samples/sec	accuracy=0.001953
INFO:root:Epoch[0] Batch [60]	Speed: 135.47 samples/sec	accuracy=0.000391
INFO:root:Epoch[0] Batch [80]	Speed: 132.32 samples/sec	accuracy=0.001563
INFO:root:Epoch[0] Batch [100]	Speed: 134.01 samples/sec	accuracy=0.001953

Our version: MobileNet training on ImageNet

cd example/image-classification/;
python train_imagenet.py --network mobilenet --gpus=0 --data-train=./train_480_q95.rec --data-nthreads 8

INFO:root:start with arguments Namespace(batch_size=128, benchmark=0, data_nthreads=8, data_train='./train_480_q95.rec', data_val=None, disp_batches=20, dtype='float32', gpus='0', image_shape='3,224,224', kv_store='device', load_epoch=None, lr=0.1, lr_factor=0.1, lr_step_epochs='30,60', max_random_aspect_ratio=0.25, max_random_h=36, max_random_l=50, max_random_rotate_angle=10, max_random_s=50, max_random_scale=1, max_random_shear_ratio=0.1, min_random_scale=1, model_prefix=None, mom=0.9, monitor=0, network='mobilenet', num_classes=1000, num_epochs=80, num_examples=1281167, num_layers=50, optimizer='sgd', pad_size=0, random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=0, wd=0.0001)
[09:59:19] src/io/iter_image_recordio_2.cc:135: ImageRecordIOParser2: ./train_480_q95.rec, use 7 threads for decoding..
[09:59:25] src/operator/././cudnn_algoreg-inl.h:65: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [20]	Speed: 476.02 samples/sec	accuracy=0.000372
INFO:root:Epoch[0] Batch [40]	Speed: 489.77 samples/sec	accuracy=0.001563
INFO:root:Epoch[0] Batch [60]	Speed: 495.26 samples/sec	accuracy=0.000781
INFO:root:Epoch[0] Batch [80]	Speed: 494.94 samples/sec	accuracy=0.001563
INFO:root:Epoch[0] Batch [100]	Speed: 494.81 samples/sec	accuracy=0.002734

By default, depthwise convolution will use the optimized version; change depthwise_conv_off to True in symbols/mobilenet.py if you want to use the cuDNN version.

...
    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, num_group=num_group, stride=stride, pad=pad, no_bias=True,  depthwise_conv_off=True,
                    name='%s%s_conv2d' %(name, suffix))
...

Hardware:
TITAN X (Pascal) + Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz x 16 + 128 GB memory
Software:
CUDA 8.0 + cuDNN 5.1
As described above, we get about a 3-4x speedup compared to the cuDNN version. For correctness, we compared the output of every depthwise layer against the standard convolution version.
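The per-layer equivalence being tested here can be checked independently with a minimal NumPy sketch (all function names and shapes below are illustrative, not from this PR): a depthwise convolution equals a standard convolution whose weight tensor is block-diagonal across channels.

```python
import numpy as np

def conv2d_full(x, w, stride=1, pad=1):
    """Direct cross-correlation. x: (C_in, H, W), w: (C_out, C_in, kh, kw)."""
    c_out, c_in, kh, kw = w.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    h_out = (xp.shape[1] - kh) // stride + 1
    w_out = (xp.shape[2] - kw) // stride + 1
    out = np.zeros((c_out, h_out, w_out))
    for o in range(c_out):
        for i in range(h_out):
            for j in range(w_out):
                patch = xp[:, i*stride:i*stride+kh, j*stride:j*stride+kw]
                out[o, i, j] = np.sum(patch * w[o])
    return out

def depthwise_conv2d(x, w, stride=1, pad=1):
    """Depthwise conv: one (kh, kw) filter per input channel. w: (C, kh, kw)."""
    return np.stack([conv2d_full(x[c:c+1], w[c:c+1, None], stride, pad)[0]
                     for c in range(x.shape[0])])

rng = np.random.RandomState(0)
C, H, W, k = 4, 8, 8, 3
x = rng.randn(C, H, W)
w = rng.randn(C, k, k)
# Equivalent full conv: block-diagonal weights, w_full[o, c] nonzero only for o == c
w_full = np.zeros((C, C, k, k))
for c in range(C):
    w_full[c, c] = w[c]
assert np.allclose(depthwise_conv2d(x, w), conv2d_full(x, w_full))
```

This is why the PR can validate the optimized kernel layer-by-layer against ordinary grouped convolution output.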

.describe("Set layout for input, output and weight. Empty for\n "
"default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.");
DMLC_DECLARE_FIELD(depthwise_conv_off).set_default(false)
.describe("whether to turn off depthwise convolution for this layer");
(Contributor)

any reason we would want to turn this off?

(Contributor Author)

It just provides a choice, like cudnn_off.

(Contributor)

Let's remove this if there are no important reasons. Convolution has too many switches.

(Contributor Author)

OK

param.num_filter == (*in_shape)[conv::kData][1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
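The quoted condition can be read as a simple dispatch predicate. Here is a hedged Python paraphrase of just the lines shown above (names are illustrative; the real check is the C++ fragment): the optimized kernel is chosen only for 2D, undilated, float32 convolutions where the filter count equals the input channel count.

```python
def uses_depthwise_kernel(num_filter, in_channels, kernel_ndim, dilate, dtype):
    # Mirrors the quoted dispatch condition (a sketch, not the actual C++ code)
    return (num_filter == in_channels   # one filter per input channel
            and kernel_ndim == 2        # 2D convolution only
            and dilate == (1, 1)        # no dilation support
            and dtype == "float32")     # fp16 not handled, per the discussion below

assert uses_depthwise_kernel(32, 32, 2, (1, 1), "float32")
assert not uses_depthwise_kernel(32, 32, 2, (2, 2), "float32")  # dilation falls back
```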
(Contributor)

any reason for limiting to float32?

(Contributor Author)

We do not handle dtype == mshadow::kFloat16 in the CUDA kernel yet.


Is there a plan to support dilation with the depthwise kernel? It is used in MobileNet v2 + DeepLabv3 for segmentation. TensorFlow has an efficient implementation; MXNet is much slower in this case.


@crazy-cat Will you please implement support for dilation rate > 1?

for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4)

def test_depthwise_convolution():
(Contributor)

Please move tests to tests/python/gpu/test_operator_gpu.py and use the standard consistency and numerical gradient tests.

(Contributor)

Actually, this is fine. Please reduce 224 to something like 32 or 64 and test 2 or 3 more configs, like different num_base and kernel_size/pad/stride.

(Contributor Author)

OK

kernel = (3, 3)
stride = (1, 1)
pad = (1,1)
shape = (2, num_base, 224, 224)
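The reviewer's suggestion (smaller inputs, a few more configurations) could be expressed as a small parameter grid; the particular values below are illustrative, not the ones the author eventually committed:

```python
from itertools import product

# Hypothetical grid: vary channel count, kernel/pad, stride, and input size.
configs = [
    dict(num_base=nb, kernel=(k, k), stride=(s, s), pad=(p, p),
         shape=(2, nb, hw, hw))
    for nb, (k, p), s, hw in product([4, 8], [(3, 1), (5, 2)], [1, 2], [32, 64])
]
# 2 * 2 * 2 * 2 = 16 configurations, each far cheaper than a 224x224 input
```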
(Contributor)

The input doesn't need to be this big. Use smaller numbers to make the tests run faster.

num_filter=num_filter/num_group, kernel=kernel, stride=stride, pad=pad)
for i in range(num_group)])

dev = mx.gpu(0)
(Contributor)

Use dev = default_context().
This will fail on GPU-less test servers.

<< "cuRAND: " << common::cuda::CurandGetErrorString(e); \
}

#define CUDA_1D_KERNEL_LOOP(i, n) \
(Contributor)

The function CUDA_KERNEL_LOOP is already defined in mxnet_op.h; it is the same as CUDA_1D_KERNEL_LOOP.

(Contributor Author)

OK

namespace cuda {
template<typename DType, int kFilterWidth, int kFilterHeight>
__global__ void __launch_bounds__(1024, 2)
DepthwiseConv2dBackwardFilterKernel(const DepthwiseArgs args,
(Contributor)

Why not put this GPU function into the file depthwise_convolution_tf.cuh?

(Contributor Author)

In depthwise_convolution_tf.cuh, the CUDA kernels' main logic comes from TensorFlow, but this kernel was written by us alone, so we keep it in the mxnet namespace.

(Contributor)

Is this logic faster than the logic in TensorFlow? There are 4 loops in total.
And why is there no bounds check during the filter * input accumulation, like
'if (in_r_start >= 0 && in_c_start >= 0 && in_r_end < in_rows && in_c_end < in_cols)'?

(Contributor Author)

Yes, we have tested that. The TF version just does atomicAdd from all threads.
The bounds check is in lines 135-138.

(Contributor Author)

Also, the filter is very small compared with the input or output shape, so contention becomes serious when threads call atomicAdd frequently.

@solin319 (Contributor)

solin319 commented Aug 9, 2017

@piiswrong When will MXNet support cuDNN v7? Grouped convolution is included in that version.

@austingg

austingg commented Aug 9, 2017

@solin319, it is said that grouped convolution in cuDNN v7 is not efficient when used for depthwise convolution, where the group count equals the number of input channels.

@chinakook (Contributor)

Nice work, it's a good feature for MobileNet!


@crazy-cat (Contributor Author)

@piiswrong All checks have passed.

@piiswrong (Contributor)

Could you rebase to master and push again? Somehow the tests are failing.

@piiswrong (Contributor)

Thanks!

@7oud

7oud commented Aug 23, 2017

@austingg do you have any benchmarks for grouped conv in cuDNN v7?

@terrychenism (Contributor)

I have tested cuDNN 7 for MobileNet on a single Titan X: terrychenism@90cc3d5

cudnn v5 - 66 samples/sec
cudnn v7 - 105 samples/sec
depthwise-conv - 163 samples/sec

@BiranLi

BiranLi commented Aug 29, 2017

@terrychenism Thanks for your code. I'll test it again.

@7oud

7oud commented Aug 29, 2017

@BiranLi Thanks for your test data. Now there are two opposite results: cuDNN v7 faster in your test, and depthwise faster in @terrychenism's test. BTW, in your test, is "cudnn v7 with dw -- 1200samples" per second, including forward and backward?

@BiranLi

BiranLi commented Aug 29, 2017

@7oud Yes, an entire training process. And I think I made a mistake in my code; I'll update my test data after retesting.

OK, I get similar results to @terrychenism.

@chinakook (Contributor)

chinakook commented Aug 29, 2017

I think we also need a CPU version of depthwise conv, as this operator is often used without CUDA in practice.

@austingg

@terrychenism How about convergence? I used the symbol in the image-classification dir and the script train_imagenet.py; it seems not to converge when using depthwise conv.

@leonid-pishchulin

I ran mobilenet.py with num_group=num_filter and with num_group=1 on a GTX 1080 Ti on 1024 640x480 images with batch size 8 (128 batches total) and computed the average run time. Using num_group=num_filter achieves 7.5 ms/frame vs. 10.9 ms/frame for num_group=1. Great!
However, a version of ResNet-18 with grouped convolutions runs at 5.3 ms/frame vs. 3.7 ms/frame for the original convolutions with num_group=1. Why is that?
I used the ResNet-18 implementation from resnet.py as a baseline and modified it in the following ways:

  1. I added 1x1 projections to the beginning of each of the residual blocks 2a, 3a, 4a and 5a, before the first 3x3 convolutional layer, to make the number of input channels and filters equal within each 3x3 block.
  2. I set num_group = num_filter for each 3x3 convolutional layer starting from 2a.

To evaluate the effect of the additional 1x1 projections on total run time, I also measured the version with the projections but num_group=1; the difference compared to using no 1x1 projections is negligible.
Any ideas?

Is there something specific about MobileNet that allows a speed-up with depthwise-factorized convolutions which does not hold for ResNet-18?

@7oud

7oud commented Aug 30, 2017

@leonid-pishchulin Thanks for your experimental data! There are now three implementations of depthwise conv:

  • cuDNN v6: 7.5 ms/frame, batch size 20
  • optimized conv without cuDNN: 2 ms/frame, batch size 20
  • cuDNN v7 (grouped conv): ?

Your measured speeds of 7.5 ms and 10.9 ms seem slow; is that related to your 640x480 image size?
BTW, could you give the computational cost (GFLOPs) comparison between the 7.5 ms/frame and 10.9 ms/frame cases in your test?

@leonid-pishchulin

leonid-pishchulin commented Aug 30, 2017

What is the best way to measure GFLOPs?
How do I invoke the optimized conv without cuDNN?
BTW, I'm testing with cuDNN 6 and CUDA 8.

@7oud

7oud commented Aug 30, 2017

This PR (#7393) is exactly the optimized conv without cuDNN.
The MobileNet paper gives a computational cost of 569 million multiply-adds; could you give the approximate number when using num_group=1?
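As a back-of-the-envelope answer to the GFLOPs question, the multiply-add counts of a standard vs. a depthwise-separable layer can be computed directly. The 14x14, 512-channel, 3x3 layer below is an illustrative size; the reduction factor 1/C_out + 1/k² is the one derived in the MobileNet paper.

```python
def conv_mult_adds(h, w, c_in, c_out, k):
    # Standard convolution: every output pixel sums over c_in * k * k inputs
    return h * w * c_out * c_in * k * k

def depthwise_separable_mult_adds(h, w, c_in, c_out, k):
    depthwise = h * w * c_in * k * k      # one k x k filter per channel
    pointwise = h * w * c_in * c_out      # 1x1 conv to mix channels
    return depthwise + pointwise

# Illustrative layer: 14x14 feature map, 512 -> 512 channels, 3x3 kernel
std = conv_mult_adds(14, 14, 512, 512, 3)
sep = depthwise_separable_mult_adds(14, 14, 512, 512, 3)
# Reduction factor matches the paper's 1/c_out + 1/k^2
assert abs(sep / std - (1 / 512 + 1 / 9)) < 1e-12
```

Summing this over every layer of the architecture is how the paper's 569M figure (and its num_group=1 counterpart) can be reproduced.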

@leonid-pishchulin

I double-checked: DepthwiseConvolutionOp is called when num_group=num_filter for both MobileNet and ResNet. Have you ever measured the speed-up when running ResNet with depthwise-separable convolutions? I get no speed-up; performance is even a bit slower when setting num_group=num_filter.

@leonid-pishchulin

Found a bug in my code. ResNet-18 with num_group=num_filter for the 3x3 conv layers is ~2x faster compared to num_group=1. Thanks for the great feature!
Is there a plan to make this work with fp16 at the same level of efficiency as fp16 CUDA kernels?

Guneet-Dhillon pushed a commit to Guneet-Dhillon/mxnet that referenced this pull request Sep 13, 2017
* add depthwise convolution's gpu version optimization

* add more config for test_depthwise_convolution

* remove CUDA_1D_KERNEL_LOOP

* fix windows compiling error

* add support for kAddTo when cal input's backward

* remove depthwise_conv_off params

* Update convolution.cu

* Update test_operator.py
@kice (Contributor)

kice commented May 4, 2018

Has anyone tested the depthwise conv of cuDNN v7 on a Pascal GPU? I think we can get some performance improvement on the latest architecture.
