[AMP] Disallow converting layer norm to fp16 #9782
Conversation
That's an interesting finding. In general, similar to BatchNorm, LayerNorm itself could be converted to FP16 as long as mean and variance stick to FP32, but we don't have this mechanism for now, so this is a fair workaround. Meanwhile, do you think ToMixedPrecision should be applied later than the expression mutation passes (e.g., after SimplifyExpr, SimplifyInference, etc., and before FuseOps), so that we could make sure all ops (types) in the IR are basically fixed?
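The numeric argument for keeping `mean` and `variance` in FP32 can be demonstrated without TVM at all: Python's `struct` format `'e'` is IEEE 754 half precision. This is a standalone illustration, not TVM code.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision ('e' format)."""
    return struct.unpack("e", struct.pack("e", x))[0]

# Precision: at magnitudes >= 1024 the fp16 spacing is 1.0, so a running
# mean over activations silently drops sub-integer detail.
print(to_fp16(1024.5))  # -> 1024.0

# Range: fp16 tops out near 65504, so the squared terms feeding a variance
# can overflow for perfectly ordinary fp32-sized activations.
try:
    to_fp16(to_fp16(300.0) ** 2)  # 90000.0 does not fit in half precision
except OverflowError:
    print("variance intermediate overflows fp16")
```

This is why accumulating the statistics in FP32 (as BatchNorm-style mixed precision does) matters even when the surrounding elementwise math is safe in FP16.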
(force-pushed from 3671d9b to ba5786e)
I thought about making …
Yeah, we shouldn't have such a dependency, as running SimplifyInference or SimplifyExpr is not a hard requirement of ToMixedPrecision. One way I can think of is putting ToMixedPrecision into the compiler pass sequence, but only applying it via …
One reason for not putting …
Yeah, it's indeed annoying. Anyway, we now have two options: …
I'm fine with either way, so I'll let this PR go first. cc @AndrewZhaoLuo
* [AMP] Disallow converting layer norm to fp16
* black
`nn.layer_norm` is decomposed into `mean` and `variance` during `SimplifyInference`, so if `ToMixedPrecision` is applied before `SimplifyInference`, we end up with fp16 input `mean` and `variance` even though they are on the `NEVER` list! This happens in the Facebook DETR object detection model, but not in BERT.
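The interaction can be mimicked with a toy model of the two passes. Everything here (the list-of-op-names "IR", the dtype tags, the `normalize` placeholder op) is schematic rather than TVM's actual implementation; it only shows why running `ToMixedPrecision` before `SimplifyInference` defeats the `NEVER` list, and why putting `nn.layer_norm` itself on the list works around it.

```python
# Toy sketch of the ordering bug; not TVM's real IR or passes.
NEVER = {"mean", "variance"}  # reduction ops pinned to fp32

def to_mixed_precision(ops, never):
    """Tag every op fp16 unless its name is on the NEVER list."""
    return [(op, "float32" if op in never else "float16") for op in ops]

def simplify_inference(typed_ops):
    """Decompose layer_norm, propagating whatever dtype it already got."""
    out = []
    for op, dtype in typed_ops:
        if op == "nn.layer_norm":
            # Schematic expansion; the real pass emits mean/variance plus
            # the normalization arithmetic.
            out += [("mean", dtype), ("variance", dtype), ("normalize", dtype)]
        else:
            out.append((op, dtype))
    return out

prog = ["nn.layer_norm"]

# ToMixedPrecision first: nn.layer_norm is not on the list, so it becomes
# fp16, and the later decomposition yields fp16 mean/variance despite NEVER.
buggy = simplify_inference(to_mixed_precision(prog, NEVER))
print(buggy)  # [('mean', 'float16'), ('variance', 'float16'), ('normalize', 'float16')]

# This PR's workaround: add nn.layer_norm itself to the NEVER list.
fixed = simplify_inference(to_mixed_precision(prog, NEVER | {"nn.layer_norm"}))
print(fixed)  # [('mean', 'float32'), ('variance', 'float32'), ('normalize', 'float32')]
```

The same reasoning explains the alternative discussed above: running `SimplifyInference` first would let the `mean`/`variance` entries on the `NEVER` list match directly.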
@AndrewZhaoLuo @comaniac