
Add layer normalization operator #7789

Merged: 10 commits merged into PaddlePaddle:develop from feature/layer_norm on Jan 31, 2018

Conversation

@chengduoZH (Contributor) commented on Jan 23, 2018

fix #7174
Writing this op turned out to be a bit complicated and I ran into some rare problems. I am writing them down here in the hope that others can learn something from them.
Layer normalization is a new op. Unlike batch normalization, it normalizes over the features of a single sample rather than across the mini-batch, so I wrote a new op instead of reusing batch normalization. This paper gives an introduction to it.
The problem appeared while computing the gradient. When dy is a constant, the formula says dx should be zero or nearly zero (about 1e-10), but the numerical gradient does not agree: its dx is much larger (about 1e-3). So I think the numerical gradient's dx is wrong.

[screenshot showing the gradient values described above]

Because of the above, I assign random data to dy, compute dx with a Python reference implementation, and compare that result against the dx produced by the C++ kernel. The comparison criterion is the same one used by op_test.
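
To make the issue concrete, here is a minimal NumPy sketch written for this write-up (not taken from the PR) of the forward pass and the analytic dx. With a constant grad_y and the scale fixed to 1, the analytic dx collapses to numerically zero, which is the behaviour described above.

```python
import numpy as np

def layer_norm_forward(x, eps=1e-5):
    # x: [N, D]; each row (one sample) is normalized over its D features.
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat, var

def layer_norm_grad_x(x, grad_y, eps=1e-5):
    # Analytic gradient w.r.t. x of the normalized output (scale == 1).
    x_hat, var = layer_norm_forward(x, eps)
    inv_std = 1.0 / np.sqrt(var + eps)
    g = grad_y
    return inv_std * (g
                      - g.mean(axis=1, keepdims=True)
                      - x_hat * (g * x_hat).mean(axis=1, keepdims=True))

np.random.seed(0)
x = np.random.rand(4, 8)

# Constant grad_y: every term cancels, so the analytic dx is ~0 up to
# round-off, whereas the finite-difference gradient in op_test reportedly
# produced values around 1e-3.
dx_const = layer_norm_grad_x(x, np.ones((4, 8)))
print(np.abs(dx_const).max())   # ~1e-16

# Random grad_y, as used for the actual check: compare this reference dx
# against the dx produced by the C++ kernel, using the op_test tolerance.
dx_ref = layer_norm_grad_x(x, np.random.rand(4, 8))
```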

@chengduoZH force-pushed the feature/layer_norm branch 5 times, most recently from 0bbc2c7 to 2b9ac13 on January 24, 2018 05:33
@chengduoZH force-pushed the feature/layer_norm branch 5 times, most recently from f35cae2 to 681a95a on January 25, 2018 16:53
@chengduoZH changed the title from "[WIP] Add layer normalization operator" to "Add layer normalization operator" on Jan 25, 2018
@chengduoZH force-pushed the feature/layer_norm branch 3 times, most recently from d6c2df6 to 2ad0642 on January 29, 2018 14:59

auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), left, 1);
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), left, 1);
Contributor:

Should the size of scale and bias be `right` rather than `left`? If it is `left`, the size of scale would be coupled with the batch size.

Contributor (Author):

You are right, thanks!
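
For context, a minimal NumPy illustration of the shape issue discussed above; the names and sizes are made up for illustration and are not the PR's C++ code.

```python
import numpy as np

# Hypothetical sizes: `left` is the product of the batch-like dims,
# `right` is the product of the normalized dims (the feature size H).
left, right = 32, 128
x = np.random.rand(left, right)

# Scale and bias need `right` elements, one per normalized feature,
# so they stay independent of the batch size.
scale = np.random.rand(right)
bias = np.random.rand(right)

mean = x.mean(axis=1, keepdims=True)
var = x.var(axis=1, keepdims=True)
y = scale * (x - mean) / np.sqrt(var + 1e-5) + bias  # broadcasts over rows

# A scale of shape [left] would instead tie the parameter count to the
# batch-like dimension, which is the coupling pointed out above.
```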

@lcy-seso (Contributor) left a comment:

Thank you for this work.

PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
PADDLE_ENFORCE(ctx->HasInput("Bias"), "");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
Contributor:

Lines 36 ~ 39: would you please complete the comments before merging?

Contributor (Author):

Done

AddInput("X", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size H "
"that is applied to the output");
Contributor:

Add a full stop at the end of the comment. The same applies to X, Bias, and Y.

Contributor (Author):

Done

});
AddAttr<int>("begin_norm_axis",
"(int default:1), the "
"axis of `begin_norm_axis ... Rank(X) - 1` will be normalized")
Contributor:

Add a full stop at the end of the comment.

Contributor (Author):

Done
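
A small NumPy sketch of what `begin_norm_axis` means here, assuming the usual flattening of X into a 2-D [left, right] matrix as in the kernel code above; the helper name is made up for illustration.

```python
import numpy as np

def flatten_for_layer_norm(x, begin_norm_axis=1):
    # Hypothetical helper: dims [0, begin_norm_axis) form the batch-like
    # "left" part, dims [begin_norm_axis, rank) are normalized together
    # as the "right" part.
    left = int(np.prod(x.shape[:begin_norm_axis]))
    right = int(np.prod(x.shape[begin_norm_axis:]))
    return x.reshape(left, right)

x = np.random.rand(8, 16, 32)
print(flatten_for_layer_norm(x, begin_norm_axis=1).shape)  # (8, 512)
print(flatten_for_layer_norm(x, begin_norm_axis=2).shape)  # (128, 32)
# begin_norm_axis must be > 0: with 0 the batch dimension itself would be
# folded into the normalization, which is what the checker above rejects.
```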

.AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0,
"'begin_norm_axis' should be greater than zero.");
});
Contributor:

We may need two attributes here to let the user decide whether to apply the scale and bias to the normalized output.

Contributor (Author):

I don't think that is necessary; we can mark Bias and Scale with AsDispensable.
If the inputs don't include `Bias` or `Scale`, the program will ignore them.

Contributor:

It's OK. I just mean marking Bias and Scale as internal variables of the implementation, because most users do not care about them. Admittedly, this is not very important.

Contributor (Author):

> because for most users they do not care about these.

You are right. The Python interface of layer_norm can take two arguments that let the user decide whether to apply the scale or the bias to the normalized output.
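
A hypothetical sketch of such a Python interface, written in NumPy for illustration only; the function name, signature, and defaults are assumptions, not the actual fluid API.

```python
import numpy as np

def layer_norm(x, scale=True, shift=True, begin_norm_axis=1, epsilon=1e-5):
    # Two boolean arguments let the user decide whether the learned scale
    # and bias are applied to the normalized output at all.
    shape = x.shape
    left = int(np.prod(shape[:begin_norm_axis]))
    right = int(np.prod(shape[begin_norm_axis:]))
    x2 = x.reshape(left, right)
    mean = x2.mean(axis=1, keepdims=True)
    var = x2.var(axis=1, keepdims=True)
    y = (x2 - mean) / np.sqrt(var + epsilon)
    if scale:
        g = np.ones(right)    # stands in for the learned Scale parameter
        y = y * g
    if shift:
        b = np.zeros(right)   # stands in for the learned Bias parameter
        y = y + b
    return y.reshape(shape)

y = layer_norm(np.random.rand(4, 16, 8), begin_norm_axis=2)
```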

PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
PADDLE_ENFORCE(ctx->HasInput("Mean"), "");
PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
Contributor:

Would you please complete the comment before merging?

Contributor (Author):

Done

AddInput("Bias",
"Bias is a 1-dimensional tensor of size H "
"that is applied to the output");
AddOutput("Y", "result after normalization");
Contributor:

Add a full stop at the end of the comment. The same applies to X, Bias, and Y.

Contributor (Author):

Done

"that is applied to the output");
AddOutput("Y", "result after normalization");
AddOutput("Mean", "Mean of the current mini batch.");
AddOutput("Variance", "Variance of the current mini batch.");
Contributor:

Mark Mean and Variance with .AsIntermediate(); users will not use them directly in layer norm.

Contributor (Author):

Done

d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y, axis=1).reshape([N, 1])
# d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
# [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
# np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
Contributor:

Are lines 67 ~ 69 unused? If so, they should be removed.

Contributor (Author):

Done

lcy-seso previously approved these changes on Jan 30, 2018
@lcy-seso (Contributor) left a comment:

LGTM. We can merge this implementation first.

@lcy-seso (Contributor) left a comment:

LGTM

@lcy-seso merged commit e261c79 into PaddlePaddle:develop on Jan 31, 2018

Successfully merging this pull request may close these issues: Implement the layer normalization operator.
3 participants