This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[PERFORMANCE] [v1.x] Layer normalization code from Marian for CPU #19601

Merged
merged 13 commits into from Jan 5, 2021

kpuatamazon
Contributor

Description

Adds a CPU kernel for LayerNorm that handles the common case of axis = -1. This is based upon the implementation from Marian at https://github.com/marian-nmt/marian-dev/blob/3b468e462809fe42a01a717c8d9307c465e6c35e/src/tensors/cpu/tensor_operators.cpp#L1047-L1087 .

Compared to the MXNet-internal generic implementation, the kernel is 1.6-29x as fast. When used in Sockeye, end-to-end translation is 14% faster.
Compared to the MKL implementation, the kernel runs at 0.91-2.28x the speed. Marian's is faster than MKL for all channels tested wider than 32.
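For readers unfamiliar with the operator, here is a minimal plain-Python sketch of what LayerNorm over the last axis computes for each row. It is not the Marian or MXNet C++ code; the function name is hypothetical and the `eps=1e-5` default is an assumption made for illustration.

```python
import math

def layer_norm_rows(rows, gamma, beta, eps=1e-5):
    """Normalize each row to zero mean and unit variance, then scale and shift.

    rows  : list of equal-length lists (batch x channel)
    gamma : per-channel scale, length == channel
    beta  : per-channel shift, length == channel
    eps   : small constant added to the variance (assumed default)
    """
    out = []
    for row in rows:  # rows are independent of each other
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([g * (v - mean) * inv_std + b
                    for v, g, b in zip(row, gamma, beta)])
    return out

y = layer_norm_rows([[1.0, 2.0, 3.0, 4.0]], [1.0] * 4, [0.0] * 4)
print(y)
```

Because every row is handled independently, the batch dimension is the natural place to split work across threads.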

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage. There's already a test_operator.py:test_layer_norm that covers this well and it passes.
  • Code is well-documented---more documented than the baseline

Changes

  • Copy Marian optimized CPU LayerNorm implementation and adapt to MXNet.
  • Refactor dispatch of optimized versions using bool return value.
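The second change can be pictured with a small sketch. This is an illustration of the general pattern only, assuming each candidate kernel returns a bool saying whether it handled the request; all function names are hypothetical and the real code is C++ inside MXNet.

```python
import math

def _normalize(values, gamma, beta, eps):
    """Layer-normalize one vector and apply the per-element scale and shift."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(values, gamma, beta)]

def try_last_axis(data, gamma, beta, axis, eps, out):
    """Hypothetical fast path: 2-D input, normalized over the last axis only."""
    if axis not in (-1, 1):
        return False  # declined; the caller moves on to the next kernel
    out.extend(_normalize(row, gamma, beta, eps) for row in data)
    return True

def try_generic(data, gamma, beta, axis, eps, out):
    """Hypothetical fallback: also covers axis 0 by transposing twice."""
    if axis in (-1, 1):
        out.extend(_normalize(row, gamma, beta, eps) for row in data)
        return True
    if axis in (0, -2):
        columns = [_normalize(list(col), gamma, beta, eps) for col in zip(*data)]
        out.extend(list(row) for row in zip(*columns))
        return True
    return False

def layer_norm(data, gamma, beta, axis=-1, eps=1e-5):
    out = []
    # Kernels are listed from most to least specialized; the first one that
    # returns True has written the result.
    for kernel in (try_last_axis, try_generic):
        if kernel(data, gamma, beta, axis, eps, out):
            return out
    raise ValueError("no kernel accepted axis={}".format(axis))

print(layer_norm([[1.0, 2.0], [3.0, 5.0]], [1.0, 1.0], [0.0, 0.0], axis=0))
```

The appeal of the bool return is that adding or removing an optimized path only touches the list of candidates, not the fallback logic.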

Benchmarks

Speed

  • Shapes borrowed from [OP] Accelerate GPU version of LayerNorm(axis=-1) #14935
  • c5.12xlarge
  • Based on db08005 (v1.x)
  • Ubuntu 18
  • cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja except for the MKL case when -DUSE_MKL_IF_AVAILABLE=ON
  • MKL 20190005 when used.
  • Time in seconds.
  • export OMP_NUM_THREADS=4

Benchmark program

#!/usr/bin/env python3
import mxnet as mx
import time

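def time_procedure(shape, count):
  # Random input plus a per-channel vector reused as both gamma and beta.
  data = mx.nd.random_uniform(shape=shape, low=-1.0, high=1.0)
  factors = mx.nd.random_uniform(shape=(shape[-1],))
  mx.nd.waitall()  # finish the asynchronous setup before starting the clock
  begin = time.time()
  for i in range(0, count):
    out = mx.nd.LayerNorm(data, factors, factors)
    mx.nd.waitall()  # block until the kernel has actually run
  return (time.time() - begin) / count  # mean seconds per call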

count = 200

for channel in [32, 64, 128, 256, 512, 768, 1024]:
  for batch in [1, 128, 2560, 4096, 8192, 16384]:
    s = (batch, channel)
    timing = time_procedure(s, count)
    print("{:5d}x{:5d} | {:.7f}".format(s[0], s[1], timing))

Here are the results (in seconds). Yes, I included first run. Make your JIT faster.
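The figures in this PR deliberately fold the first call into the average. For anyone who wants to look at the first call separately, a minimal sketch follows. The helper name is hypothetical, and the workload is a stand-in so the snippet runs without MXNet installed.

```python
import time

def time_first_and_rest(fn, count):
    """Return (first_call_seconds, mean_seconds_of_the_remaining_calls)."""
    begin = time.perf_counter()
    fn()
    first = time.perf_counter() - begin

    begin = time.perf_counter()
    for _ in range(count - 1):
        fn()
    rest = (time.perf_counter() - begin) / max(count - 1, 1)
    return first, rest

# Stand-in workload; replace with the operator call being measured.
first, rest = time_first_and_rest(lambda: sum(range(1000)), count=200)
print("first call: {:.7f} s, later calls: {:.7f} s".format(first, rest))
```

With MXNet, `fn` would wrap the `mx.nd.LayerNorm` call followed by `mx.nd.waitall()`, as in the benchmark program.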

| Shape | Marian | MKL | MXNet Generic | Marian speedup v MKL | Marian speedup v MXNet |
|---|---|---|---|---|---|
| 1x 32 | 0.0000254 | 0.0000267 | 0.0000409 | 1.05x | 1.61x |
| 128x 32 | 0.0000318 | 0.0000308 | 0.0000632 | 0.97x | 1.99x |
| 2560x 32 | 0.0000690 | 0.0000679 | 0.0004944 | 0.98x | 7.17x |
| 4096x 32 | 0.0000952 | 0.0000907 | 0.0007636 | 0.95x | 8.02x |
| 8192x 32 | 0.0001591 | 0.0001503 | 0.0015753 | 0.94x | 9.90x |
| 16384x 32 | 0.0002900 | 0.0002633 | 0.0030074 | 0.91x | 10.37x |
| 1x 64 | 0.0000240 | 0.0000249 | 0.0000399 | 1.04x | 1.66x |
| 128x 64 | 0.0000311 | 0.0000327 | 0.0000837 | 1.05x | 2.69x |
| 2560x 64 | 0.0000826 | 0.0000984 | 0.0009193 | 1.19x | 11.13x |
| 4096x 64 | 0.0001142 | 0.0001366 | 0.0015389 | 1.20x | 13.48x |
| 8192x 64 | 0.0001985 | 0.0002446 | 0.0029263 | 1.23x | 14.74x |
| 16384x 64 | 0.0003815 | 0.0004561 | 0.0056857 | 1.20x | 14.90x |
| 1x 128 | 0.0000243 | 0.0000254 | 0.0000401 | 1.05x | 1.65x |
| 128x 128 | 0.0000342 | 0.0000397 | 0.0001280 | 1.16x | 3.74x |
| 2560x 128 | 0.0001063 | 0.0001594 | 0.0018591 | 1.50x | 17.49x |
| 4096x 128 | 0.0001501 | 0.0002355 | 0.0028828 | 1.57x | 19.21x |
| 8192x 128 | 0.0002695 | 0.0004378 | 0.0055950 | 1.62x | 20.76x |
| 16384x 128 | 0.0005846 | 0.0008852 | 0.0110546 | 1.51x | 18.91x |
| 1x 256 | 0.0000252 | 0.0000272 | 0.0000424 | 1.08x | 1.68x |
| 128x 256 | 0.0000381 | 0.0000446 | 0.0002133 | 1.17x | 5.60x |
| 2560x 256 | 0.0001542 | 0.0002870 | 0.0035257 | 1.86x | 22.86x |
| 4096x 256 | 0.0002241 | 0.0004369 | 0.0055310 | 1.95x | 24.68x |
| 8192x 256 | 0.0005067 | 0.0008487 | 0.0109084 | 1.67x | 21.53x |
| 16384x 256 | 0.0011817 | 0.0017543 | 0.0217319 | 1.48x | 18.39x |
| 1x 512 | 0.0000262 | 0.0000306 | 0.0000475 | 1.17x | 1.81x |
| 128x 512 | 0.0000405 | 0.0000549 | 0.0003818 | 1.36x | 9.43x |
| 2560x 512 | 0.0002462 | 0.0005229 | 0.0068302 | 2.12x | 27.74x |
| 4096x 512 | 0.0003823 | 0.0008172 | 0.0108432 | 2.14x | 28.36x |
| 8192x 512 | 0.0008764 | 0.0017205 | 0.0216015 | 1.96x | 24.65x |
| 16384x 512 | 0.0057181 | 0.0072662 | 0.0464290 | 1.27x | 8.12x |
| 1x 768 | 0.0000274 | 0.0000309 | 0.0000519 | 1.13x | 1.89x |
| 128x 768 | 0.0000439 | 0.0000675 | 0.0005498 | 1.54x | 12.52x |
| 2560x 768 | 0.0003469 | 0.0007757 | 0.0101437 | 2.24x | 29.24x |
| 4096x 768 | 0.0005857 | 0.0013381 | 0.0161946 | 2.28x | 27.65x |
| 8192x 768 | 0.0014930 | 0.0026524 | 0.0322792 | 1.78x | 21.62x |
| 16384x 768 | 0.0088047 | 0.0110582 | 0.0698267 | 1.26x | 7.93x |
| 1x 1024 | 0.0000275 | 0.0000330 | 0.0000573 | 1.20x | 2.08x |
| 128x 1024 | 0.0000486 | 0.0000790 | 0.0007189 | 1.63x | 14.79x |
| 2560x 1024 | 0.0004582 | 0.0010214 | 0.0135037 | 2.23x | 29.47x |
| 4096x 1024 | 0.0008070 | 0.0017359 | 0.0215496 | 2.15x | 26.70x |
| 8192x 1024 | 0.0057007 | 0.0073134 | 0.0463280 | 1.28x | 8.13x |
| 16384x 1024 | 0.0116098 | 0.0147560 | 0.0935520 | 1.27x | 8.06x |

AWS Sockeye

Observed a 14% speed up in end-to-end machine translation with Sockeye. Sockeye 2.2 (29795b82) on a c5.12xlarge with export OMP_NUM_THREADS=4 translating a test set.

Compiled on Ubuntu 18 with cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja .. Note: no MKL.

Before

[INFO:__main__] Processed 2964 lines. Total time: 133.3097, sec/sent: 0.0450, sent/sec: 22.2339

real	2m15.716s
user	9m52.988s
sys	0m13.504s

After

[INFO:__main__] Processed 2964 lines. Total time: 116.6679, sec/sent: 0.0394, sent/sec: 25.4054

real	1m58.858s
user	8m45.803s
sys	0m13.823s
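As a quick cross-check of the 14% figure, the short sketch below recomputes the gain from the numbers logged above. The variable and function names are only for illustration.

```python
# Values copied from the Before/After logs above.
before_total_s, after_total_s = 133.3097, 116.6679      # "Total time"
before_sent_per_s, after_sent_per_s = 22.2339, 25.4054  # "sent/sec"
before_real_s = 2 * 60 + 15.716                         # real 2m15.716s
after_real_s = 1 * 60 + 58.858                          # real 1m58.858s

def percent_faster(before_seconds, after_seconds):
    """Throughput gain implied by a drop in elapsed time, in percent."""
    return (before_seconds / after_seconds - 1.0) * 100.0

gain_total = percent_faster(before_total_s, after_total_s)
gain_real = percent_faster(before_real_s, after_real_s)
gain_throughput = (after_sent_per_s / before_sent_per_s - 1.0) * 100.0

print("from total time: {:.1f}%".format(gain_total))
print("from wall clock: {:.1f}%".format(gain_real))
print("from sent/sec:   {:.1f}%".format(gain_throughput))
```

All three measures land at roughly 14%.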

The above runs were done as normal, without the profiler. I then turned the profiler on. We can see that LayerNorm is consuming a substantial amount of time:
Before

operator
=================
Name                          Total Count        Time (ms)    Min Time (ms)    Max Time (ms)    Avg Time (ms)
----                          -----------        ---------    -------------    -------------    -------------
_contrib_intgemm_fully_connected          822520       26357.8887           0.0090           0.3390           0.0320
LayerNorm                          459522       20225.8086           0.0230           0.4860           0.0440
elemwise_add                       601340        7813.2148           0.0040           0.1970           0.0130
_contrib_interleaved_matmul_encdec_qk          155884        7557.1152           0.0050           0.3560           0.0485
_contrib_interleaved_matmul_encdec_valatt          155884        6168.3472           0.0040           0.4120           0.0396
FullyConnected                      48262        4070.1250           0.0260           4.7480           0.0843
DeleteVariable                    1577462        3830.7241           0.0000           0.3660           0.0024
Concat                             107622        3493.2380           0.0100           0.2970           0.0325
take                               386096        3484.5449           0.0020           1.5600           0.0090
SliceChannel                        65296        3468.1431           0.0060           0.4370           0.0531
where                              144786        3203.5801           0.0030           0.2090           0.0221
Activation                         252408        3095.2820           0.0060           0.1750           0.0123

After

operator
=================
Name                          Total Count        Time (ms)    Min Time (ms)    Max Time (ms)    Avg Time (ms)
----                          -----------        ---------    -------------    -------------    -------------
_contrib_intgemm_fully_connected          822316       25351.8438           0.0090           0.4190           0.0308
elemwise_add                       601170        8229.7861           0.0040           0.1650           0.0137
_contrib_interleaved_matmul_encdec_qk          155850        7577.9399           0.0050           0.4030           0.0486
_contrib_interleaved_matmul_encdec_valatt          155850        6169.1318           0.0040           0.4310           0.0396
FullyConnected                      48245        4170.0972           0.0240           4.8480           0.0864
DeleteVariable                    1576986        3935.9939           0.0000           0.3490           0.0025
take                               385960        3624.0161           0.0020           2.4180           0.0094
Concat                             107605        3561.9041           0.0100           0.3540           0.0331
SliceChannel                        65296        3475.8010           0.0060           0.5690           0.0532
where                              144735        3241.1169           0.0030           0.2080           0.0224
Activation                         252340        2855.7710           0.0050           0.2440           0.0113
LayerNorm                          459403        2791.0029           0.0040           0.0540           0.0061

The new implementation is 7.21x as fast on average according to the profiler.

The number of LayerNorm invocations changes by about 0.03% (459522 to 459403) because beam search iterations are impacted by tie breaking.
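The profiler figures can be rechecked with a few lines of arithmetic. The sketch below uses only the LayerNorm rows from the two tables above; the dictionary layout is just for illustration.

```python
# LayerNorm rows copied from the Before and After profiler tables.
before = {"count": 459522, "total_ms": 20225.8086, "avg_ms": 0.0440}
after = {"count": 459403, "total_ms": 2791.0029, "avg_ms": 0.0061}

# Speedup measured on the per-call average and on the summed time.
avg_ratio = before["avg_ms"] / after["avg_ms"]
total_ratio = before["total_ms"] / after["total_ms"]

# Relative change in the number of calls, in percent.
count_change_pct = abs(after["count"] - before["count"]) / before["count"] * 100.0

print("speedup by average time: {:.2f}x".format(avg_ratio))
print("speedup by total time:   {:.2f}x".format(total_ratio))
print("call-count change:       {:.3f}%".format(count_change_pct))
```

The average-time ratio reproduces the 7.21x figure, the total-time ratio is marginally higher, and the call count differs by 119 calls, a few hundredths of a percent.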

Unit test

Before: 62.210s
After: 61.321s

But note unit tests spend most of their time comparing things rather than running the kernels.

Comments

  • LayerNorm is just one of those kernels that changes slightly with any implementation, so outputs that depend on near-ties will change.
  • float16 support on CPU accumulates in float32. Since float16 only exists in conversion on CPU, this is faster anyway. Also, there wasn't an OMP reduction for float16.
  • There is no threaded parallelization within a channel (i.e. to sum). I doubt existing channel sizes justify this given the cost of threading on CPUs.
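To illustrate why the accumulator type matters for float16 input, here is a small sketch that emulates rounding with the standard-library `struct` module (`'e'` is IEEE half precision, `'f'` is single precision). It is an illustration only, not the kernel's code, and the 4096-element row is an assumed size chosen to make the effect obvious.

```python
import struct

def round_to(fmt, value):
    """Round a Python float to the nearest value representable in `fmt`."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]

def accumulate(values, fmt):
    """Sum `values`, rounding the running total to `fmt` after every add."""
    total = 0.0
    for v in values:
        total = round_to(fmt, total + v)
    return total

# One row of 4096 float16 inputs, all equal to float16(0.1).
row = [round_to("e", 0.1)] * 4096

sum_fp16 = accumulate(row, "e")  # float16 accumulator
sum_fp32 = accumulate(row, "f")  # float32 accumulator, as described above

print("float16 accumulator:", sum_fp16)
print("float32 accumulator:", sum_fp32)
```

Once the float16 running total grows large enough that its spacing exceeds twice the addend, further additions round away and the sum stops growing, while the float32 total stays close to the true value. A mean computed from the stalled sum would be badly off, which is one reason to accumulate in float32.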

Kenneth Heafield added 2 commits November 30, 2020 14:34
Experiment with OMP_NUM_THREADS=4, times in s, c5.12xlarge

|batchxchannel| New code | MKL      |
|    1x   32 | 0.0000288| 0.0000278|
|  128x   32 | 0.0000308| 0.0000311|
| 2560x   32 | 0.0000712| 0.0000672|
| 4096x   32 | 0.0000946| 0.0000910|
| 8192x   32 | 0.0001597| 0.0001523|
|16384x   32 | 0.0002905| 0.0002619|
|    1x   64 | 0.0000264| 0.0000256|
|  128x   64 | 0.0000339| 0.0000330|
| 2560x   64 | 0.0000829| 0.0000972|
| 4096x   64 | 0.0001137| 0.0001356|
| 8192x   64 | 0.0002027| 0.0002435|
|16384x   64 | 0.0003715| 0.0004639|
|    1x  128 | 0.0000262| 0.0000263|
|  128x  128 | 0.0000325| 0.0000389|
| 2560x  128 | 0.0001074| 0.0001580|
| 4096x  128 | 0.0001505| 0.0002336|
| 8192x  128 | 0.0002861| 0.0004481|
|16384x  128 | 0.0005648| 0.0008613|
|    1x  256 | 0.0000273| 0.0000276|
|  128x  256 | 0.0000390| 0.0000431|
| 2560x  256 | 0.0001533| 0.0002811|
| 4096x  256 | 0.0002258| 0.0004300|
| 8192x  256 | 0.0004300| 0.0008464|
|16384x  256 | 0.0010436| 0.0017613|
|    1x  512 | 0.0000256| 0.0000302|
|  128x  512 | 0.0000408| 0.0000551|
| 2560x  512 | 0.0002444| 0.0005225|
| 4096x  512 | 0.0003828| 0.0008147|
| 8192x  512 | 0.0008832| 0.0017192|
|16384x  512 | 0.0058463| 0.0074497|
|    1x  768 | 0.0000252| 0.0000308|
|  128x  768 | 0.0000450| 0.0000676|
| 2560x  768 | 0.0003440| 0.0007719|
| 4096x  768 | 0.0005890| 0.0013346|
| 8192x  768 | 0.0014946| 0.0026145|
|16384x  768 | 0.0089495| 0.0113557|
|    1x 1024 | 0.0000285| 0.0000308|
|  128x 1024 | 0.0000487| 0.0000786|
| 2560x 1024 | 0.0004614| 0.0010190|
| 4096x 1024 | 0.0008083| 0.0017376|
| 8192x 1024 | 0.0059020| 0.0075588|
|16384x 1024 | 0.0116553| 0.0146855|

Benchmark program
```python
import mxnet as mx
import time

def time_procedure(shape, count):
  data = mx.nd.random_uniform(shape=shape, low=-1.0, high = 1.0)
  factors = mx.nd.random_uniform(shape=(shape[-1],))
  mx.nd.waitall()
  begin = time.time()
  for i in range(0, count):
    out = mx.nd.LayerNorm(data, factors, factors)
    mx.nd.waitall()
  return (time.time() - begin) / count

count = 200

for channel in [32, 64, 128, 256, 512, 768, 1024]:
  for batch in [1, 128, 2560, 4096, 8192, 16384]:
    s = (batch, channel)
    timing = time_procedure(s, count)
    print("{:5d}x{:5d} | {:.7f}".format(s[0], s[1], timing))
```
@mxnet-bot

Hey @kpuatamazon , Thanks for submitting the PR
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [clang, windows-cpu, centos-cpu, miscellaneous, windows-gpu, unix-gpu, edge, sanity, unix-cpu, website, centos-gpu]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@lanking520 added the pr-awaiting-testing and pr-work-in-progress labels and removed the pr-awaiting-testing label on Nov 30, 2020
@kpuatamazon
Contributor Author

Lint broken?
@mxnet-bot run ci [sanity]

[2020-11-30T17:51:32.903Z] + pip3 install -r /work/requirements
[2020-11-30T17:51:33.157Z] DEPRECATION: Python 3.5 reached the end of its life on September 13th, 2020. Please upgrade your Python as Python 3.5 is no longer maintained. pip 21.0 will drop support for Python 3.5 in January 2021. pip 21.0 will remove support for this functionality.
[2020-11-30T17:51:35.665Z] Collecting astroid==2.3.3
[2020-11-30T17:51:35.665Z]   Downloading astroid-2.3.3-py3-none-any.whl (205 kB)
[2020-11-30T17:51:35.665Z] 
[2020-11-30T17:51:35.665Z] The conflict is caused by:
[2020-11-30T17:51:35.665Z]     The user requested six==1.11.0
[2020-11-30T17:51:35.665Z]     astroid 2.3.3 depends on six~=1.12
[2020-11-30T17:51:35.665Z] 
[2020-11-30T17:51:35.665Z] To fix this you could try to:
[2020-11-30T17:51:35.665Z] 1. loosen the range of package versions you've specified
[2020-11-30T17:51:35.665Z] 2. remove package versions to allow pip attempt to solve the dependency conflict
[2020-11-30T17:51:35.665Z] 
[2020-11-30T17:51:35.665Z] ERROR: Cannot install -r /work/requirements (line 31) and six==1.11.0 because these package versions have conflicting dependencies.
[2020-11-30T17:51:35.665Z] ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/user_guide/#fixing-conflicting-dependencies
[2020-11-30T17:51:36.222Z] The command '/bin/sh -c /work/ubuntu_python.sh' returned a non-zero code: 1
[2020-11-30T17:51:36.222Z] Traceback (most recent call last):
[2020-11-30T17:51:36.222Z]   File "ci/build.py", line 456, in <module>
[2020-11-30T17:51:36.222Z]     sys.exit(main())
[2020-11-30T17:51:36.222Z]   File "ci/build.py", line 366, in main
[2020-11-30T17:51:36.222Z]     cache_intermediate=args.cache_intermediate)
[2020-11-30T17:51:36.222Z]   File "ci/build.py", line 114, in build_docker
[2020-11-30T17:51:36.222Z]     run_cmd()
[2020-11-30T17:51:36.222Z]   File "/home/jenkins_slave/workspace/sanity-lint/ci/util.py", line 84, in f_retry
[2020-11-30T17:51:36.222Z]     return f(*args, **kwargs)
[2020-11-30T17:51:36.222Z]   File "ci/build.py", line 112, in run_cmd
[2020-11-30T17:51:36.222Z]     check_call(cmd)
[2020-11-30T17:51:36.222Z]   File "/usr/lib/python3.6/subprocess.py", line 311, in check_call
[2020-11-30T17:51:36.222Z]     raise CalledProcessError(retcode, cmd)
[2020-11-30T17:51:36.222Z] subprocess.CalledProcessError: Command '['docker', 'build', '-f', 'docker/Dockerfile.build.ubuntu_cpu', '--build-arg', 'USER_ID=1001', '--build-arg', 'GROUP_ID=1001', '--cache-from', 'mxnetci/build.ubuntu_cpu', '-t', 'mxnetci/build.ubuntu_cpu', 'docker']' returned non-zero exit status 1.
script returned exit code 1

@mxnet-bot

Jenkins CI successfully triggered : [sanity]

@lanking520 added the pr-awaiting-testing label and removed the pr-work-in-progress label on Nov 30, 2020
@lanking520 added the pr-work-in-progress label and removed the pr-awaiting-testing label on Nov 30, 2020
@kpuatamazon changed the title from "[PERFORMANCE] Layer normalization code from Marian for CPU" to "[PERFORMANCE] [v1.x] Layer normalization code from Marian for CPU" on Nov 30, 2020
@lanking520 added the pr-awaiting-testing and pr-work-in-progress labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Dec 1, 2020
@kpuatamazon
Contributor Author

@mxnet-bot run ci [sanity]

Maybe #19604 fixed lint?

@mxnet-bot

Jenkins CI successfully triggered : [sanity]

@lanking520 added the pr-awaiting-testing label and removed the pr-work-in-progress label on Dec 4, 2020
@kpuatamazon
Contributor Author

@mxnet-bot run ci [centos-cpu, centos-gpu, clang, edge, miscellaneous, unix-cpu, unix-gpu, website, windows-cpu, windows-gpu]

These have been "Expected" for days, seems the results got lost.

@kpuatamazon
Contributor Author

@mxnet-bot run ci [website]

Bot didn't respond, is anybody home?

@lanking520 added the pr-work-in-progress label and removed the pr-awaiting-testing label on Dec 4, 2020
@lanking520 added the pr-work-in-progress label on Dec 7, 2020
@kpuatamazon
Contributor Author

@mxnet-bot run ci [unix-cpu]

#19081 seed 675318784 causes the test to fail in v1.x as well.

@mxnet-bot

Jenkins CI successfully triggered : [unix-cpu]

@lanking520 added the pr-awaiting-testing and pr-awaiting-review labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Dec 9, 2020
@samskalicky
Contributor

I restarted the CI jobs a few times, looks like it's passing now.

Is it possible that the MKL implementation's performance might improve in the future? Should we keep that and hide it behind a build flag, making the Marian implementation default?

@lanking520 added the pr-awaiting-testing label and removed the pr-awaiting-review label on Dec 21, 2020
@kpuatamazon
Contributor Author

Hi @samskalicky, as requested there is now a -DUSE_MKL_LAYERNORM=ON build option that calls MKL, and the old wrapper is still there.

My one-day-a-week contract ends 31 December 2020 so this is partly a goodbye and hope to get this in. I will be in today and probably 28 December. Afterwards, I am just @kpu.

@lanking520 added the pr-awaiting-review label and removed the pr-awaiting-testing label on Dec 21, 2020
Contributor

@samskalicky samskalicky left a comment

Thanks for adding the MKL option, LGTM!

@lanking520 added the pr-awaiting-testing label and removed the pr-awaiting-review label on Dec 28, 2020
@kpuatamazon
Contributor Author

I've merged the latest v1.x in, added the USE_MKL_LAYERNORM to Makefile, and tested that option (for @szha) in build_ubuntu_cpu_mkl CI for consistency with master.

Today is my last day. Hope it works.

@lanking520 added the pr-work-in-progress label and removed the pr-awaiting-testing label on Dec 28, 2020
@fhieber
Contributor

fhieber commented Jan 4, 2021

What are the next steps for this PR? Is this ready to be merged?

@lanking520 added the pr-awaiting-testing and pr-awaiting-review labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Jan 4, 2021
@szha merged commit 99420a0 into apache:v1.x on Jan 5, 2021
Labels
pr-awaiting-review PR is waiting for code review

6 participants