-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
skip rms norm #10036
skip rms norm #10036
Conversation
|
||
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") | ||
@flow.unittest.skip_unless_1n1d() | ||
class TestSkipLayerNorm(flow.unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RMS?
@@ -6959,6 +6959,28 @@ def OneFlow_RmsNormGradOp : OneFlow_BaseOp<"rms_norm_grad", [NoSideEffect, Decla | |||
let has_data_type_infer_fn = 1; | |||
} | |||
|
|||
def OneFlow_SkipRmsNormOp : OneFlow_BaseOp<"skip_rms_norm", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AttrSizedOperandSegments
let attrs = (ins | ||
ShapeAttr:$normalized_shape, | ||
DefaultValuedAttr<F32Attr, "0.00001">:$epsilon, | ||
DefaultValuedAttr<F32Attr, "0.00001">:$alpha |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值不合理,同时把skip_layer_norm的默认值也处理一下
Speed stats:
|
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10036/ |
No description provided.