Dropout is Slow #13825

thomelane opened this Issue Jan 10, 2019 · 1 comment

thomelane commented Jan 10, 2019

Description

Adding dropout to a network appears to reduce the speed of the forward pass substantially.

I've seen this a few times now; the most recent occasion was on the discussion forum:
https://discuss.mxnet.io/t/training-speed-in-mxnet-is-nearly-2-5x-times-slower-than-pytorch. The user's training was originally 2.5x slower than PyTorch, but they reported a 6x speedup after removing dropout, which ultimately made training faster than PyTorch. The user also reported a significant reduction in memory usage.

Could be related to #12976.

Environment info (Required)

The user above is running MXNet 1.3.1 on Ubuntu 16.04.5 with a GTX 1080.
I have been able to replicate the issue on an AWS EC2 p3.2xlarge.

(mxnet_p36) ubuntu@ip-172-31-14-75:~$ python diagnose.py
----------Python Info----------
Version      : 3.6.5
Compiler     : GCC 7.2.0
Build        : ('default', 'Apr 29 2018 16:14:56')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 10.0.1
Directory    : /home/ubuntu/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/pip
----------MXNet Info-----------
Version      : 1.3.1
Directory    : /home/ubuntu/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet
Commit Hash   : 19c501680183237d52a862e6ae1dc4ddc296305b
----------System Info----------
Platform     : Linux-4.4.0-1074-aws-x86_64-with-debian-stretch-sid
system       : Linux
node         : ip-172-31-14-75
release      : 4.4.0-1074-aws
version      : #84-Ubuntu SMP Thu Dec 6 08:57:58 UTC 2018
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                8
On-line CPU(s) list:   0-7
Thread(s) per core:    2
Core(s) per socket:    4
Socket(s):             1
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 79
Model name:            Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping:              1
CPU MHz:               2699.984
CPU max MHz:           3000.0000
CPU min MHz:           1200.0000
BogoMIPS:              4600.16
Hypervisor vendor:     Xen
Virtualization type:   full
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              46080K
NUMA node0 CPU(s):     0-7
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single kaiser fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0019 sec, LOAD: 0.3825 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0998 sec, LOAD: 0.0959 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.5308 sec, LOAD: 0.2854 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0082 sec, LOAD: 0.1051 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0099 sec, LOAD: 0.3274 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0109 sec, LOAD: 0.0471 sec.

Package used (Python/R/Scala/Julia): Python

Minimum reproducible example
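
All of the snippets below are Jupyter cells and assume MXNet has already been imported in a previous cell:

import mxnet as mx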

Using mx.nd.Dropout

%%timeit
data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
for i in range(100):
    # using mode='always' to force dropout behaviour
    data = mx.nd.Dropout(data, 0.5, mode='always')
mx.nd.waitall()
# 1.44 s ± 512 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Approximately 4.5x slower than the custom dropout below.

Using Custom Dropout

%%timeit
data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
for i in range(100):
    dropout_mask = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu()) > 0.5
    data = data * dropout_mask
mx.nd.waitall()
# 325 ms ± 338 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Approximately 4.5x faster than mx.nd.Dropout.
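
As an aside, the custom mask above only zeroes activations; a complete (inverted) dropout would also rescale the kept activations by 1/(1-p), roughly as in the sketch below. This is an illustrative NDArray version, not what mx.nd.Dropout does internally, and the extra elementwise multiply should not account for a ~4.5x gap.

%%timeit
p = 0.5
data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
for i in range(100):
    # keep each element with probability (1 - p), then rescale so the
    # expected activation magnitude matches the no-dropout case
    mask = mx.nd.random.uniform(shape=data.shape, ctx=mx.gpu()) > p
    data = data * mask / (1 - p)
mx.nd.waitall()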

Other examples, including the backward pass.

Network using mx.gluon.nn.Dropout

net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Conv2D(channels=32, kernel_size=3, strides=2))
net.add(mx.gluon.nn.BatchNorm())
net.add(mx.gluon.nn.Activation('relu'))
net.add(mx.gluon.nn.Dropout(0.5))
net.add(mx.gluon.nn.GlobalMaxPool2D())
net.add(mx.gluon.nn.Dense(units=1000))
net.initialize(ctx=mx.gpu())
%%timeit
for i in range(10):
    data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
    with mx.autograd.record():
        output = net(data)
    output.backward()
    mx.nd.waitall()
# 524 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Approximately 4x slower than the same network without dropout (below).

Network without mx.gluon.nn.Dropout

net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Conv2D(channels=32, kernel_size=3, strides=2))
net.add(mx.gluon.nn.BatchNorm())
net.add(mx.gluon.nn.Activation('relu'))
net.add(mx.gluon.nn.GlobalMaxPool2D())
net.add(mx.gluon.nn.Dense(units=1000))
net.initialize(ctx=mx.gpu())
%%timeit
for i in range(10):
    data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
    with mx.autograd.record():
        output = net(data)
    output.backward()
    mx.nd.waitall()
# 123 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Approximately 4x faster than the same network with dropout.
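
To narrow down where the time is going, one option is the built-in profiler; a rough sketch (the output filename is arbitrary) that could be run once with and once without the Dropout layer:

# collect per-operator timings for a single forward/backward pass
mx.profiler.set_config(profile_all=True, filename='dropout_profile.json')
mx.profiler.set_state('run')
data = mx.nd.random.uniform(shape=(100, 3, 224, 224), ctx=mx.gpu())
with mx.autograd.record():
    output = net(data)
output.backward()
mx.nd.waitall()
mx.profiler.set_state('stop')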

thomelane commented Jan 10, 2019

@mxnet-label-bot add [Bug, Operator, Python, Performance, CUDA]
