[AMP] Disallow converting layer norm to fp16#9782

Merged
masahi merged 2 commits into apache:main from masahi:fp16-disable-layernorm
Dec 22, 2021
Conversation

@masahi (Member) commented Dec 21, 2021
nn.layer_norm is decomposed into mean and variance during SimplifyInference, so if ToMixedPrecision is applied before SimplifyInference, we end up with fp16 inputs to mean and variance even though those ops are on the NEVER list!
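To make the ordering bug concrete, here is a toy sketch (pure Python, not the actual TVM pass API): the mixed-precision pass consults a never-convert list by op name, but if decomposition runs afterwards, the decomposed ops were never checked against that list. The op names and dtype strings below are illustrative only.

```python
# Toy model of the pass-ordering problem described above (NOT TVM code).
def to_mixed_precision(ops, never):
    # Mark every op fp16 unless its name is on the NEVER list.
    return [(name, "float32" if name in never else "float16")
            for name, _ in ops]

def simplify_inference(ops):
    # Decompose layer_norm into constituent ops, which inherit its dtype.
    out = []
    for name, dtype in ops:
        if name == "layer_norm":
            out += [("mean", dtype), ("variance", dtype), ("normalize", dtype)]
        else:
            out.append((name, dtype))
    return out

model = [("dense", "float32"), ("layer_norm", "float32")]

# Conversion before decomposition, layer_norm NOT on the NEVER list:
# mean/variance inherit fp16 even though they are themselves "never" ops.
buggy = simplify_inference(to_mixed_precision(model, {"mean", "variance"}))

# This PR's fix: put layer_norm itself on the NEVER list, so its
# decomposed mean/variance stay fp32.
fixed = simplify_inference(
    to_mixed_precision(model, {"mean", "variance", "layer_norm"}))

print(dict(buggy)["mean"])  # float16  (the bug)
print(dict(fixed)["mean"])  # float32  (after the fix)
```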

This happens in the facebook DETR object detection model, but not in BERT.

@AndrewZhaoLuo @comaniac

@comaniac (Contributor)
That's an interesting finding. In general, similar to BatchNorm, LayerNorm itself could be converted to FP16 as long as its mean and variance computations stick to FP32, but we don't have that mechanism for now, so this is a fair workaround.
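The mechanism described here (fp16 data, fp32 reductions) can be sketched in NumPy; this is an illustration of the numerics, not TVM code:

```python
import numpy as np

# Sketch of a mixed-precision LayerNorm: inputs/outputs are fp16, but the
# mean/variance reductions accumulate in fp32 for numerical stability.
def layer_norm_mixed(x_fp16, gamma, beta, eps=1e-5):
    x = x_fp16.astype(np.float32)           # upcast for the reductions
    mean = x.mean(axis=-1, keepdims=True)   # fp32 accumulation
    var = x.var(axis=-1, keepdims=True)     # fp32 accumulation
    y = (x - mean) / np.sqrt(var + eps)
    return (gamma * y + beta).astype(np.float16)  # fp16 output

x = np.random.randn(2, 8).astype(np.float16)
out = layer_norm_mixed(x, np.float32(1.0), np.float32(0.0))
print(out.dtype)  # float16
```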

Meanwhile, do you think ToMixedPrecision should be applied after the expression mutation passes (e.g., after SimplifyExpr, SimplifyInference, etc., and before FuseOps), so that we could make sure all ops (and their types) in the IR are basically fixed?

masahi force-pushed the fp16-disable-layernorm branch from 3671d9b to ba5786e on December 21, 2021
@masahi (Member, Author) commented Dec 21, 2021

> Meanwhile, do you think ToMixedPrecision should be applied after the expression mutation passes (e.g., after SimplifyExpr, SimplifyInference, etc., and before FuseOps), so that we could make sure all ops (and their types) in the IR are basically fixed?

I thought about making SimplifyInference a prerequisite pass for ToMixedPrecision, but I think putting ToMixedPrecision later in the pipeline makes the pass a bit inconvenient to use (because it is fundamentally an optional feature). So it is more natural to invoke ToMixedPrecision somewhere in a user script. Always running SimplifyInference before ToMixedPrecision may also break someone's workflow that depends on a particular pattern being present in a module.

@comaniac (Contributor)
Yeah, we shouldn't have such a dependency, since running SimplifyInference or SimplifyExpr is not a hard requirement of ToMixedPrecision. One way I can think of is putting ToMixedPrecision into the compiler pass sequence, but only applying it via PassContext(config={"ToMixedPrecision": "float16"}).
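The opt-in idea could look roughly like the toy sketch below. Note the `"ToMixedPrecision"` config key is hypothetical (it is comaniac's suggestion, not an existing TVM option), and the stub passes just record their names:

```python
# Toy sketch of a config-gated pass (hypothetical config key, NOT TVM code):
# the pass sits in the default pipeline but only fires when the user opts in.

def simplify_inference_stub(m):
    return m + ["SimplifyInference"]

def to_mixed_precision_stub(m, dtype):
    return m + ["ToMixedPrecision({})".format(dtype)]

def run_pipeline(module, config=None):
    config = config or {}
    passes = [simplify_inference_stub]
    # ToMixedPrecision participates only when enabled in the config,
    # and always runs after the expression mutation passes.
    if "ToMixedPrecision" in config:
        dtype = config["ToMixedPrecision"]
        passes.append(lambda m: to_mixed_precision_stub(m, dtype))
    for p in passes:
        module = p(module)
    return module

print(run_pipeline([]))                                   # pass skipped
print(run_pipeline([], {"ToMixedPrecision": "float16"}))  # pass applied
```

The downside masahi raises next still applies: with the pass buried in the pipeline, inspecting the converted fp16 module requires dumping IR mid-pipeline.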

@masahi (Member, Author) commented Dec 22, 2021

One reason for not putting ToMixedPrecision in the pipeline: in most cases I want to inspect the converted fp16 model, which is hard to do if the conversion happens in the middle of the pipeline.

@comaniac (Contributor)
Yeah, it's indeed annoying. Anyway, we now have two options:

  1. Keep the current approach. In this case, the best practice I'm thinking of is asking users to run these passes manually before ToMixedPrecision.
  2. Put it into the pipeline. In this case, we have to hack the pipeline to print the IR if needed.

I'm fine with either way, so I'll let this PR go first. cc @AndrewZhaoLuo

@masahi masahi merged commit 8fa5464 into apache:main Dec 22, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* [AMP] Disallow converting layer norm to fp16

* black
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* [AMP] Disallow converting layer norm to fp16

* black
qsqqsqqsq-intellif pushed a commit to qsqqsqqsq-intellif/tvm that referenced this pull request Apr 29, 2022
* [AMP] Disallow converting layer norm to fp16

* black