Layernorm one_plus support in fast_layer_norm #1557
base: master
Conversation
@@ -80,7 +80,8 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 std::vector<at::Tensor> ln_fwd(const at::Tensor &x,     // BxSxhidden_size
                                const at::Tensor &gamma, // hidden_size
                                const at::Tensor &beta,  // hidden_size
-                               const float epsilon
+                               const float epsilon,
+                               const float one_plus
Since this is a float anyway, we could call it something like gamma_shift and allow values other than 0 and 1.
This is not needed, so this is just a side note.
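For context, a minimal pure-PyTorch sketch of what such a generalized parameterization could look like (this is illustrative only and not part of the PR; `gamma_shift` and the function name are hypothetical):

```python
import torch

def layer_norm_shifted_gamma(x, gamma, beta, eps=1e-5, gamma_shift=1.0):
    """Reference LayerNorm with a shifted gamma.

    gamma_shift=0.0 reproduces the standard parameterization;
    gamma_shift=1.0 corresponds to the "one_plus" variant, where gamma is
    stored zero-initialized and the +1 is added at compute time.
    """
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x_hat = (x - mean) * torch.rsqrt(var + eps)
    return (gamma + gamma_shift) * x_hat + beta
```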
for is_1p in use_1p:
    for h in hidden_sizes:
        with self.subTest(f"hidden_size={h}, use_1p={is_1p}"):
            self.assertAll(_test_impl(256, 2, h, fp32, fp32, is_1p))
Is there an existing test (already in Apex) that checks that a custom LayerNorm (like FastLayerNormFN) behaves exactly the same as e.g. PyTorch LayerNorm? It would be nice to double-check that our modification of adding +1 is mathematically correct. It seems that _test_impl uses backward_ to check correctness, which uses the same +1 logic as the kernels.
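One way such a cross-check could look (a sketch only; `fast_layer_norm_1p` is a hypothetical stand-in for the kernel entry point under test, not an existing Apex function) is to compare against torch.nn.functional.layer_norm with the +1 folded into the weight:

```python
import torch
import torch.nn.functional as F

def check_one_plus_against_pytorch(fast_layer_norm_1p, hidden_size=1024, eps=1e-5):
    # fast_layer_norm_1p is a placeholder: it should take (x, gamma, beta, eps)
    # and internally apply (gamma + 1) as the scale.
    x = torch.randn(256, 2, hidden_size, device="cuda", dtype=torch.float32)
    gamma = torch.randn(hidden_size, device="cuda")  # zero-initialized in practice; random here to stress the check
    beta = torch.randn(hidden_size, device="cuda")

    out_kernel = fast_layer_norm_1p(x, gamma, beta, eps)
    # Reference: plain PyTorch LayerNorm with the +1 folded into the weight.
    out_ref = F.layer_norm(x, (hidden_size,), weight=gamma + 1.0, bias=beta, eps=eps)

    torch.testing.assert_close(out_kernel, out_ref, rtol=1e-5, atol=1e-5)
```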
In this implementation, the gamma tensor is initialized to 0 (instead of 1) and the +1 is handled numerically in the kernel. This behavior is enabled by setting use_1p=True (default: False).
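For illustration, a pure-PyTorch module mirroring that description (the name `OnePlusLayerNorm` and the `use_1p` wiring are assumptions for this sketch, not the module this PR modifies):

```python
import torch
import torch.nn as nn

class OnePlusLayerNorm(nn.Module):
    """Reference-only sketch: with use_1p=True, gamma is zero-initialized
    and (gamma + 1) is used as the scale in the forward pass."""

    def __init__(self, hidden_size, eps=1e-5, use_1p=False):
        super().__init__()
        self.eps = eps
        self.use_1p = use_1p
        # gamma starts at 0 when use_1p is enabled, at 1 otherwise.
        init = torch.zeros(hidden_size) if use_1p else torch.ones(hidden_size)
        self.gamma = nn.Parameter(init)
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        weight = self.gamma + 1.0 if self.use_1p else self.gamma
        return nn.functional.layer_norm(
            x, self.gamma.shape, weight=weight, bias=self.beta, eps=self.eps
        )
```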