
[OP] Accelerate GPU version of LayerNorm(axis=-1) #14935

Merged

merged 1 commit into apache:master on May 21, 2019

Conversation

sxjscience
Member

@sxjscience commented May 12, 2019

Description

Accelerate the speed of LayerNorm when (axis=-1).

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Accelerate the GPU version of LayerNorm when axis=-1 (see the reference formula below)
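
For reference, LayerNorm with axis=-1 normalizes each row of a (B, C) input by its own mean and variance and then applies a learned per-channel affine transform:

$$\mu_b = \frac{1}{C}\sum_{c=1}^{C} x_{b,c}, \qquad \sigma_b^2 = \frac{1}{C}\sum_{c=1}^{C}\left(x_{b,c} - \mu_b\right)^2, \qquad \mathrm{LayerNorm}(x)_{b,c} = \gamma_c \, \frac{x_{b,c} - \mu_b}{\sqrt{\sigma_b^2 + \epsilon}} + \beta_c$$

Each of the B rows reduces over only its own C channels, so the per-row reductions are independent, which is what a specialized GPU kernel for axis=-1 can exploit.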

Comments

  • Benchmark results are listed in the comments below

@sxjscience
Member Author

sxjscience commented May 16, 2019

We tested the speed of LayerNorm(axis=-1) with different batch-size, channel, and dtype combinations. The results are listed below. We use nvprof to profile the individual kernels and time.time() to measure the overall running time of the Python script. We also highlight the batch/channel combinations used in BERT-Large, i.e., B = 4096 or 8192 with C = 768 or 1024.

We run the speed test on an AWS p3.2xlarge instance (V100 GPU). Each experiment is repeated 3 times and the average running time is reported. Code: https://github.com/sxjscience/benchmark_ops/blob/master/gen_layernorm_benchmark.py

To reproduce, run the following code:

git clone https://github.com/sxjscience/benchmark_ops.git
cd benchmark_ops
python gen_layernorm_benchmark.py
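
The linked script performs the actual measurement; as a rough illustration of the python-timer methodology, a minimal forward-only sketch (the helper name, shapes, and repeat count here are ours, not from the script) could look like:

```python
# Illustrative python-timer measurement of LayerNorm(axis=-1) forward in MXNet.
# This is a sketch, not the actual gen_layernorm_benchmark.py.
import time
import mxnet as mx

def time_layernorm_forward(B, C, dtype='float32', n_repeat=3, ctx=mx.gpu(0)):
    data = mx.nd.random.normal(shape=(B, C), dtype=dtype, ctx=ctx)
    gamma = mx.nd.ones((C,), dtype=dtype, ctx=ctx)
    beta = mx.nd.zeros((C,), dtype=dtype, ctx=ctx)
    mx.nd.waitall()                # finish setup before starting the clock
    start = time.time()
    for _ in range(n_repeat):
        out = mx.nd.LayerNorm(data, gamma, beta, axis=-1)
    mx.nd.waitall()                # MXNet ops are asynchronous; wait for the GPU
    return (time.time() - start) / n_repeat * 1e6  # average time in us

# e.g. one of the BERT-Large settings:
print(time_layernorm_forward(4096, 768))
```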

PyTorch + Apex + FP32 --> times in microseconds (us). Apex: https://github.com/NVIDIA/apex
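
For context, the Apex baseline is presumably the fused layer norm from apex.normalization, used roughly as follows (a sketch under that assumption, not the exact benchmark code):

```python
# Sketch of a PyTorch + Apex FP32 baseline; requires a CUDA build of
# PyTorch and NVIDIA Apex (https://github.com/NVIDIA/apex).
import torch
from apex.normalization import FusedLayerNorm

B, C = 4096, 768                     # one of the BERT-Large settings
layer_norm = FusedLayerNorm(C).cuda()
x = torch.randn(B, C, device='cuda', requires_grad=True)
y = layer_norm(x)                    # fused CUDA LayerNorm over the last axis
y.sum().backward()                   # exercises the backward kernels too
```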

Forward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 19     | 31     | 56     | 106     |
| C=64   | 8     | 19     | 32     | 56     | 108     |
| C=128  | 6     | 15     | 23     | 42     | 77      |
| C=256  | 6     | 17     | 27     | 49     | 92      |
| C=512  | 7     | 26     | 38     | 70     | 134     |
| C=768  | 8     | 31     | 45     | 82     | 157     |
| C=1024 | 9     | 38     | 58     | 106    | 204     |

Backward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 18     | 25     | 42     | 91      |
| C=64   | 8     | 18     | 25     | 51     | 99      |
| C=128  | 8     | 19     | 31     | 58     | 105     |
| C=256  | 9     | 28     | 41     | 72     | 132     |
| C=512  | 10    | 44     | 64     | 115    | 218     |
| C=768  | 10    | 62     | 92     | 174    | 332     |
| C=1024 | 11    | 82     | 122    | 230    | 450     |

Backward Data (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 2     | 8      | 11     | 20     | 38      |
| C=64   | 3     | 8      | 11     | 21     | 41      |
| C=128  | 3     | 8      | 12     | 23     | 45      |
| C=256  | 3     | 11     | 17     | 32     | 64      |
| C=512  | 4     | 19     | 31     | 63     | 126     |
| C=768  | 4     | 30     | 49     | 103    | 203     |
| C=1024 | 5     | 45     | 72     | 142    | 286     |

Backward Gamma & Beta (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 5     | 11     | 14     | 23     | 53      |
| C=64   | 5     | 11     | 14     | 30     | 58      |
| C=128  | 5     | 11     | 18     | 35     | 60      |
| C=256  | 5     | 17     | 24     | 40     | 69      |
| C=512  | 6     | 25     | 33     | 52     | 92      |
| C=768  | 6     | 32     | 42     | 71     | 129     |
| C=1024 | 6     | 37     | 50     | 89     | 164     |

Forward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 105   | 107    | 107    | 135    | 161     |
| C=64   | 107   | 108    | 109    | 128    | 209     |
| C=128  | 106   | 112    | 145    | 155    | 208     |
| C=256  | 104   | 120    | 149    | 187    | 216     |
| C=512  | 103   | 171    | 186    | 204    | 246     |
| C=768  | 108   | 183    | 190    | 211    | 261     |
| C=1024 | 111   | 191    | 195    | 229    | 294     |

Backward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 236   | 246    | 222    | 261    | 231     |
| C=64   | 215   | 222    | 228    | 250    | 302     |
| C=128  | 219   | 232    | 301    | 273    | 289     |
| C=256  | 238   | 237    | 269    | 310    | 314     |
| C=512  | 221   | 300    | 291    | 311    | 322     |
| C=768  | 222   | 295    | 291    | 305    | 392     |
| C=1024 | 248   | 311    | 329    | 321    | 467     |

@sxjscience
Member Author

sxjscience commented May 16, 2019

MXNet (new kernel) + FP32

According to nvprof, the performance of the new kernel matches that of NVIDIA/Apex. However, the overall running time of the Python script is much slower for MXNet than for PyTorch. This gap comes from framework-side overhead outside the CUDA kernel, not from the kernel itself.
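
The end-to-end (python timer) numbers therefore include operator dispatch and frontend costs on top of the kernel time. A minimal sketch of how such an end-to-end forward+backward measurement can be taken in MXNet (the helper name and setup are ours for illustration, not from the benchmark script):

```python
# Illustrative end-to-end timing of LayerNorm forward + backward in MXNet.
# This is a sketch, not the actual gen_layernorm_benchmark.py.
import time
import mxnet as mx
from mxnet import autograd

def time_layernorm_fwd_bwd(B, C, dtype='float32', n_repeat=3, ctx=mx.gpu(0)):
    data = mx.nd.random.normal(shape=(B, C), dtype=dtype, ctx=ctx)
    gamma = mx.nd.ones((C,), dtype=dtype, ctx=ctx)
    beta = mx.nd.zeros((C,), dtype=dtype, ctx=ctx)
    for arr in (data, gamma, beta):
        arr.attach_grad()          # collect gradients w.r.t. data, gamma, beta
    mx.nd.waitall()                # finish setup before starting the clock
    start = time.time()
    for _ in range(n_repeat):
        with autograd.record():
            out = mx.nd.LayerNorm(data, gamma, beta, axis=-1)
        out.backward()
    mx.nd.waitall()                # wait for the asynchronous GPU kernels
    return (time.time() - start) / n_repeat * 1e6  # average time in us
```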

Forward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 6     | 11     | 15     | 25     | 46      |
| C=64   | 6     | 11     | 15     | 25     | 44      |
| C=128  | 6     | 11     | 16     | 25     | 43      |
| C=256  | 7     | 17     | 24     | 43     | 79      |
| C=512  | 9     | 26     | 35     | 60     | 111     |
| C=768  | 9     | 37     | 54     | 98     | 188     |
| C=1024 | 10    | 44     | 65     | 116    | 221     |

Backward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 17     | 23     | 38     | 87      |
| C=64   | 8     | 18     | 23     | 50     | 102     |
| C=128  | 8     | 18     | 31     | 60     | 110     |
| C=256  | 9     | 31     | 47     | 83     | 150     |
| C=512  | 10    | 46     | 70     | 126    | 237     |
| C=768  | 12    | 71     | 104    | 195    | 372     |
| C=1024 | 13    | 85     | 127    | 237    | 465     |

Backward Data (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 3     | 7      | 10     | 17     | 30      |
| C=64   | 3     | 8      | 10     | 17     | 31      |
| C=128  | 3     | 8      | 11     | 19     | 35      |
| C=256  | 4     | 12     | 18     | 34     | 67      |
| C=512  | 5     | 19     | 32     | 66     | 132     |
| C=768  | 5     | 31     | 48     | 96     | 189     |
| C=1024 | 5     | 43     | 68     | 133    | 265     |

Backward Gamma & Beta (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 5     | 10     | 13     | 21     | 57      |
| C=64   | 5     | 10     | 13     | 32     | 71      |
| C=128  | 5     | 10     | 20     | 42     | 75      |
| C=256  | 5     | 18     | 29     | 49     | 83      |
| C=512  | 6     | 26     | 38     | 60     | 106     |
| C=768  | 7     | 40     | 56     | 99     | 183     |
| C=1024 | 8     | 42     | 59     | 104    | 201     |

Forward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 249   | 367    | 370    | 399    | 463     |
| C=64   | 266   | 311    | 373    | 428    | 469     |
| C=128  | 251   | 371    | 384    | 434    | 487     |
| C=256  | 258   | 396    | 426    | 473    | 513     |
| C=512  | 255   | 410    | 453    | 496    | 533     |
| C=768  | 279   | 441    | 457    | 530    | 621     |
| C=1024 | 253   | 451    | 478    | 550    | 661     |

Backward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 255   | 305    | 352    | 329    | 353     |
| C=64   | 256   | 321    | 342    | 326    | 376     |
| C=128  | 254   | 311    | 322    | 340    | 389     |
| C=256  | 255   | 322    | 333    | 361    | 421     |
| C=512  | 256   | 331    | 351    | 405    | 539     |
| C=768  | 305   | 344    | 386    | 479    | 671     |
| C=1024 | 262   | 371    | 398    | 509    | 741     |

@eric-haibin-lin
Member

Nice work! Can you retrigger CI?

Commits added to this PR: fix lint · fix lint · fix bug · further accelerate · fix · fix bug · fix bug

@sxjscience
Member Author

PyTorch + Apex + FP16
Forward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 20     | 34     | 60     | 115     |
| C=64   | 8     | 20     | 34     | 61     | 116     |
| C=128  | 9     | 20     | 35     | 62     | 118     |
| C=256  | 7     | 17     | 27     | 49     | 92      |
| C=512  | 7     | 22     | 34     | 62     | 119     |
| C=768  | 8     | 28     | 43     | 79     | 153     |
| C=1024 | 8     | 36     | 53     | 98     | 190     |

Backward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 18     | 24     | 41     | 75      |
| C=64   | 8     | 18     | 25     | 42     | 91      |
| C=128  | 8     | 19     | 25     | 51     | 99      |
| C=256  | 8     | 20     | 33     | 62     | 114     |
| C=512  | 9     | 30     | 45     | 79     | 144     |
| C=768  | 10    | 39     | 58     | 102    | 188     |
| C=1024 | 11    | 48     | 71     | 122    | 232     |

Backward Data (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 2     | 8      | 11     | 20     | 37      |
| C=64   | 2     | 8      | 11     | 20     | 39      |
| C=128  | 3     | 8      | 12     | 22     | 43      |
| C=256  | 3     | 9      | 14     | 27     | 52      |
| C=512  | 4     | 13     | 20     | 37     | 72      |
| C=768  | 4     | 18     | 28     | 53     | 103     |
| C=1024 | 5     | 23     | 37     | 69     | 140     |

Backward Gamma & Beta (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 5     | 10     | 13     | 21     | 38      |
| C=64   | 5     | 10     | 14     | 22     | 52      |
| C=128  | 5     | 10     | 14     | 29     | 56      |
| C=256  | 5     | 11     | 19     | 35     | 62      |
| C=512  | 5     | 17     | 25     | 42     | 72      |
| C=768  | 6     | 21     | 30     | 49     | 85      |
| C=1024 | 6     | 25     | 34     | 53     | 92      |

Forward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 106   | 108    | 130    | 133    | 203     |
| C=64   | 108   | 107    | 113    | 153    | 192     |
| C=128  | 107   | 129    | 120    | 184    | 232     |
| C=256  | 106   | 122    | 168    | 199    | 220     |
| C=512  | 108   | 181    | 190    | 203    | 241     |
| C=768  | 119   | 190    | 190    | 212    | 261     |
| C=1024 | 117   | 193    | 195    | 225    | 286     |

Backward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 239   | 239    | 270    | 238    | 297     |
| C=64   | 219   | 243    | 246    | 296    | 273     |
| C=128  | 218   | 272    | 257    | 302    | 297     |
| C=256  | 216   | 256    | 291    | 304    | 302     |
| C=512  | 250   | 300    | 297    | 317    | 326     |
| C=768  | 259   | 318    | 300    | 307    | 316     |
| C=1024 | 236   | 312    | 306    | 320    | 330     |

@sxjscience
Member Author

sxjscience commented May 19, 2019

MXNet (new kernel) + FP16

Forward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 6     | 10     | 14     | 22     | 39      |
| C=64   | 5     | 10     | 13     | 21     | 37      |
| C=128  | 6     | 11     | 15     | 26     | 46      |
| C=256  | 7     | 17     | 24     | 43     | 81      |
| C=512  | 9     | 23     | 33     | 59     | 112     |
| C=768  | 9     | 34     | 52     | 99     | 193     |
| C=1024 | 10    | 40     | 61     | 115    | 224     |

Backward (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 8     | 16     | 20     | 33     | 59      |
| C=64   | 8     | 16     | 21     | 34     | 74      |
| C=128  | 8     | 16     | 21     | 43     | 80      |
| C=256  | 9     | 21     | 33     | 64     | 116     |
| C=512  | 10    | 32     | 47     | 82     | 149     |
| C=768  | 12    | 50     | 74     | 135    | 258     |
| C=1024 | 13    | 57     | 86     | 152    | 286     |

Backward Data (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 3     | 7      | 10     | 16     | 30      |
| C=64   | 3     | 8      | 10     | 17     | 30      |
| C=128  | 3     | 8      | 10     | 17     | 31      |
| C=256  | 4     | 11     | 16     | 30     | 57      |
| C=512  | 4     | 15     | 22     | 40     | 77      |
| C=768  | 5     | 24     | 37     | 70     | 137     |
| C=1024 | 5     | 28     | 42     | 80     | 158     |

Backward Gamma & Beta (nvprof timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 5     | 9      | 11     | 17     | 29      |
| C=64   | 5     | 9      | 11     | 17     | 44      |
| C=128  | 5     | 9      | 11     | 26     | 50      |
| C=256  | 5     | 10     | 17     | 34     | 59      |
| C=512  | 6     | 17     | 26     | 42     | 73      |
| C=768  | 7     | 26     | 37     | 65     | 120     |
| C=1024 | 8     | 30     | 44     | 72     | 128     |

Forward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 271   | 362    | 371    | 390    | 476     |
| C=64   | 259   | 390    | 333    | 406    | 463     |
| C=128  | 225   | 370    | 415    | 419    | 475     |
| C=256  | 261   | 413    | 403    | 455    | 512     |
| C=512  | 252   | 410    | 428    | 472    | 546     |
| C=768  | 256   | 441    | 452    | 509    | 626     |
| C=1024 | 255   | 451    | 474    | 522    | 659     |

Backward (python timer)

|        | B=128 | B=2560 | B=4096 | B=8192 | B=16384 |
|--------|-------|--------|--------|--------|---------|
| C=32   | 295   | 311    | 317    | 321    | 345     |
| C=64   | 274   | 315    | 321    | 329    | 350     |
| C=128  | 256   | 315    | 338    | 330    | 354     |
| C=256  | 257   | 338    | 326    | 346    | 391     |
| C=512  | 308   | 327    | 333    | 395    | 435     |
| C=768  | 261   | 348    | 346    | 412    | 537     |
| C=1024 | 264   | 347    | 359    | 423    | 564     |

@pinaraws

@mxnet-label-bot add[Operator, pr-awaiting-review]

@marcoabreu marcoabreu added Operator pr-awaiting-review PR is waiting for code review labels May 20, 2019
@eric-haibin-lin eric-haibin-lin merged commit 8ae24bf into apache:master May 21, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019