
FusedLayerNorm vs torch.nn.LayerNorm #449

Open
dhpollack opened this issue Aug 23, 2019 · 14 comments

@dhpollack

What's the advantage of using the FusedLayerNorm over torch.nn.LayerNorm? I'm running into an issue with using TorchScript and I'm wondering if I can replace the former with the latter.

The deeper question is: is the apex version of layer norm significantly more optimized than the standard PyTorch version, or is it simply a legacy of when PyTorch did not have a built-in layer norm function?
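
Concretely, the replacement in question would be roughly the following (a sketch, assuming apex is installed with its CUDA extensions; the hidden size and shapes are arbitrary):

```python
import torch
from apex.normalization import FusedLayerNorm

hidden = 1024
x = torch.randn(16, 128, hidden, device="cuda")

native_norm = torch.nn.LayerNorm(hidden).cuda()
fused_norm = FusedLayerNorm(hidden).cuda()   # same constructor arguments

# Both normalize over the trailing `hidden` dimension; with default affine
# parameters the outputs should match up to small numerical differences.
print(torch.allclose(native_norm(x), fused_norm(x), atol=1e-5))
```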

@ptrblck
Contributor

ptrblck commented Sep 2, 2019

FusedLayerNorm should give a speedup compared to torch.nn.LayerNorm.
Gist for profiling: https://gist.github.com/ptrblck/8b1c6a7efd97604a7dedbf2c3edd1019
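
The core of such a profiling script looks roughly like this (a sketch, not the gist itself; the shape, iteration count, and the `benchmark` helper name are placeholders):

```python
import time
import torch
from apex.normalization import FusedLayerNorm

def benchmark(module, x, iters=1000):
    # Warm up so one-time CUDA costs don't skew the measurement.
    for _ in range(10):
        module(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()
    return time.time() - start

x = torch.randn(32, 128, 768, device="cuda")   # illustrative shape
norm = torch.nn.LayerNorm(768).cuda()
fused_norm = FusedLayerNorm(768).cuda()

print("upstream layernorm", benchmark(norm, x))
print("apex layernorm", benchmark(fused_norm, x))
```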

@bryant1410

I tried that gist with 10,000 iterations on a V100, and torch.nn.LayerNorm is faster:

upstream layernorm 32.502
apex layernorm 33.823

The gap is even larger if I convert the input to half precision and call norm.half() and fused_norm.half():

upstream layernorm 23.555
apex layernorm 31.152
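
For reference, the half-precision setup is just the same comparison with the input and both modules converted via .half(), along these lines (a sketch with illustrative sizes):

```python
import torch
from apex.normalization import FusedLayerNorm

# fp16 variant: convert the input tensor and both modules with .half(),
# then time them exactly as in the fp32 case.
x = torch.randn(32, 128, 1024, device="cuda").half()
norm = torch.nn.LayerNorm(1024).cuda().half()
fused_norm = FusedLayerNorm(1024).cuda().half()

out_native, out_fused = norm(x), fused_norm(x)   # both now run in fp16
```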

@ptrblck
Contributor

ptrblck commented Sep 23, 2019

@bryant1410 Thanks for reporting!
I'm not sure where the perf regression comes from.
However, @zasdfgbnm is porting the FusedLayerNorm approach to PyTorch in this PR, which should hopefully land soon.

@zasdfgbnm

zasdfgbnm commented Sep 23, 2019

This is what I get on the DGX station (Tesla V100 GPU) with the master-py3-devel docker:

upstream layernorm 31.495
apex layernorm 32.434
upstream half layernorm 22.503
apex half layernorm 29.342

With PyTorch and apex master compiled from source:

upstream layernorm 26.903
apex layernorm 32.299
upstream half layernorm 20.610
apex half layernorm 29.691

With pull request pytorch/pytorch#26201 (here, "upstream" means the layernorm implementation ported from APEX):

upstream layernorm 39.049
apex layernorm 32.504
upstream half layernorm 35.924
apex half layernorm 29.523

@bryant1410

By the way, I ran mine with commit 880ab92, not with master.

@ngoyal2707

Is the fused LN heavily optimized for transformer-like applications? I do get a big speedup for the standard NLP representation (T, B, C) when taking the norm across C.

upstream layernorm 2.036
apex layernorm 0.620
upstream half layernorm 1.470
apex half layernorm 0.473
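
For context, the (T, B, C) layout with the norm taken across C looks like this (a sketch; the sizes are made up):

```python
import torch
from apex.normalization import FusedLayerNorm

T, B, C = 512, 32, 1024   # sequence length, batch size, channels (illustrative)
x = torch.randn(T, B, C, device="cuda")

# Both modules normalize over the trailing C dimension, independently for
# every (t, b) position.
norm = torch.nn.LayerNorm(C).cuda()
fused_norm = FusedLayerNorm(C).cuda()
out = fused_norm(x)       # output shape stays (T, B, C)
```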

@zasdfgbnm

@ngoyal2707 Which upstream version are you using? The layernorm in upstream has been improved a lot recently.

@myleott
Contributor

myleott commented May 1, 2020

I observe the same issue as @ngoyal2707 on PyTorch 1.5 -- torch.nn.LayerNorm is slower than apex.FusedLayerNorm for shapes typical in NLP models, for example (512, 16, 1024) with normalization over the last dimension.
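
A quick way to reproduce that comparison on a newer PyTorch (torch.utils.benchmark is available from 1.6 onward) might be a sketch like this:

```python
import torch
from torch.utils import benchmark
from apex.normalization import FusedLayerNorm

# (seq_len, batch, hidden) = (512, 16, 1024), normalized over the last dim.
x = torch.randn(512, 16, 1024, device="cuda")
modules = {
    "torch.nn.LayerNorm": torch.nn.LayerNorm(1024).cuda(),
    "apex FusedLayerNorm": FusedLayerNorm(1024).cuda(),
}
for label, module in modules.items():
    # Timer handles CUDA synchronization when timing GPU work.
    t = benchmark.Timer(stmt="module(x)", globals={"module": module, "x": x})
    print(label, t.timeit(1000))
```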

@vgoklani

I also see a performance boost using the FusedLayerNorm for our NLP-based transformer.

@pommedeterresautee

pommedeterresautee commented Jun 16, 2020

I just replaced all LayerNorm modules with the apex version in a model from the Transformers library (RoBERTa-based), on a real dataset with an average sequence length of 200 tokens. So in a basically real-life setup, I can't measure any difference. I have also run the benchmark, and on the same machine I get:

upstream layernorm 2.132
apex layernorm 2.745

@vgoklani is it a custom transformer or from an OSS library?
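
For anyone wanting to reproduce the swap, a generic module walk along these lines should work (a sketch; `replace_layernorm` and the checkpoint name are illustrative, and the parameter copy assumes elementwise affine is enabled, which is the default):

```python
import torch
from apex.normalization import FusedLayerNorm
from transformers import AutoModel

def replace_layernorm(module: torch.nn.Module) -> None:
    # Recursively swap every torch.nn.LayerNorm for a FusedLayerNorm with the
    # same normalized shape and eps, copying the affine parameters over.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.LayerNorm):
            fused = FusedLayerNorm(child.normalized_shape, eps=child.eps)
            fused.weight.data.copy_(child.weight.data)
            fused.bias.data.copy_(child.bias.data)
            setattr(module, name, fused)
        else:
            replace_layernorm(child)

model = AutoModel.from_pretrained("roberta-base")
replace_layernorm(model)
model = model.cuda()
```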

@hitvoice

hitvoice commented Jul 16, 2020

I ran the gist with shape (32, 128, 768), which is common in Transformers, on a V100 with CUDA 10. What I got:

upstream layernorm 0.136
apex layernorm 0.040
upstream layernorm(half) 0.106
apex layernorm(half) 0.047

After changing the sequence length to 256:

upstream layernorm 0.258
apex layernorm 0.070
upstream layernorm(half) 0.203
apex layernorm(half) 0.045

@pommedeterresautee I suggest you provide your device and CUDA version; that would be more helpful.

@pommedeterresautee

@hitvoice A 2080 Ti, and apex from the master branch at the time of my previous message, so June 16th.

@hellojialee

I ran the gist provided by @ptrblck on a 2080 Ti GPU; the following is the result:

upstream layernorm 2.282
apex layernorm 3.018
