[Semi-Auto] Add layer_norm infer_backward rule #56505
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
Sorry to inform you that 17ba34b's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Force-pushed from 17ba34b to 7fa7b30
        << "dst_dims_mapping: ["
        << str_join(output_dist_attrs[i].dims_mapping()) << "]";
  }
  VLOG(4) << "*********";
remove meaningless log
Done
// begin_norm_axis=2, x=ij, y=kl)
// ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij)
// ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance,
// begin_norm_axis=2, z=ij, y=kl)
Below, in line 101: mean_axes = "j"; should the notation be "x"?
Otherwise "j" may be confused with the broadcast axis before begin_norm_axis.
Line 101 is the case when begin_norm_axis <= 1. When begin_norm_axis <= 1, the first axis can be propagated to mean and variance, so mean_axes and var_axes are set to the same as the input's first axis, which is "j".
You are right.
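The case being discussed can be reduced to a small sketch. This is a toy illustration with a hypothetical helper name (`ChooseMeanAxes`), not the actual Paddle implementation: mean/variance keep the input's first axis label only when begin_norm_axis <= 1; otherwise the batch axes are represented by a fresh label "z".

```cpp
#include <string>

// Toy sketch (hypothetical helper, not Paddle's actual code) of the rule
// discussed above. x_axes is the einsum-style axis string of the input,
// e.g. "jkl" or "ijkl".
std::string ChooseMeanAxes(const std::string& x_axes, int begin_norm_axis) {
  if (begin_norm_axis <= 1) {
    // the first input axis can still be propagated to mean and variance
    return x_axes.substr(0, 1);
  }
  // all axes before begin_norm_axis flatten into one new axis label
  return "z";
}
```

With x_axes = "jkl" and begin_norm_axis = 1 this yields "j", matching the reply above; with begin_norm_axis = 2 it yields the fresh label "z".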
@@ -127,8 +126,8 @@ LayerNormSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
   for (size_t i = 0; i < out_axes.size(); ++i) {
     if (i < static_cast<size_t>(begin_norm_axis)) {
       out_dims_mapping.push_back(x_dims_mapping[i]);
-      // if ijk,k,k->ijk,x,x (x,scale,bias->out,mean,variance,
-      // begin_norm_axis=2, x=ij), and the dims_mapping of input is (0,1,-1),
+      // if ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance,
The current rule for LN is problematic:
- not all axes before begin_norm_axis can be sharded.
- the first axis after begin_norm_axis can be sharded in the current implementation.
- it would be better to refactor the LN rule using the TransDim algorithm: the axes mapping is like axes-flatten in TransDim.
but it is ok for now, since most usage of LN is DP.
Done. Now only the first axis of the input can be sharded; all other axes will be set to replicated.
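The behavior described in this reply can be sketched as follows. This is a hedged toy version (hypothetical function name, not the actual Paddle code): keep the mesh dimension of the input's first axis and force every other axis to replicated, encoded as -1 in the dims mapping.

```cpp
#include <cstdint>
#include <vector>

// Sketch (not the actual Paddle implementation) of the restriction
// described above: only the first axis keeps its sharding; every other
// axis of the dims mapping is set to -1 (replicated).
std::vector<int64_t> ShardFirstAxisOnly(
    const std::vector<int64_t>& x_dims_mapping) {
  std::vector<int64_t> out(x_dims_mapping.size(), -1);  // replicate all axes
  if (!out.empty()) {
    out[0] = x_dims_mapping[0];  // first axis keeps its mesh dimension
  }
  return out;
}
```

For example, an input dims mapping of (0, 1, -1) would become (0, -1, -1), which is consistent with the DP-dominant usage noted in the review.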
Force-pushed from 7fa7b30 to 117c259
Force-pushed from 117c259 to 5f4c449
LGTM
PR types
Function optimization
PR changes
Others
Description
Pcard-70448
Add infer_backward rule for layer_norm to infer inputs' dims mappings from outputs'.