
Adding Adadelta optimization operator #4576

Merged
abhinavarora merged 7 commits into PaddlePaddle:develop from adadelta on Oct 5, 2017

Conversation

abhinavarora
Contributor

No description provided.

@abhinavarora abhinavarora self-assigned this Oct 3, 2017
framework::EigenVector<T>::Flatten(*avg_squared_update_out);
auto place = ctx.GetEigenDevice<Place>();

g_acc_out.device(place) = rho * g_acc + (1 - rho) * g.square();
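
For reference, the line above is the exponentially decayed average of squared gradients. The full Adadelta update rules (the standard form from Zeiler's paper, which the formula in the operator DOC presumably matches; rho is the decay rate and epsilon a small smoothing constant) are:

E[g^2]_t = \rho \, E[g^2]_{t-1} + (1 - \rho) \, g_t^2
\Delta x_t = - \frac{\sqrt{E[\Delta x^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}} \, g_t
E[\Delta x^2]_t = \rho \, E[\Delta x^2]_{t-1} + (1 - \rho) \, (\Delta x_t)^2
x_{t+1} = x_t + \Delta x_t
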
Member

Maybe we can name g_acc_out as avg_squared_grad_eigen to be consistent with the formula written in the DOC; it would be easier to read and understand.

Contributor Author

Sure, I will make the change

Member

Since avg_squared_grad_eigen is a little too long, we could use avg_squared_grad here and rename the variable avg_squared_grad that is obtained from the tensor to avg_squared_grad_t, or something like that.
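
A minimal sketch of the proposed convention, assuming the usual Eigen-based PaddlePaddle kernel structure (the input/output names and context calls below are illustrative assumptions, not the exact code in this PR):

// Illustrative only: "_t" suffix for raw tensors taken from the execution
// context, plain names for the Eigen views used in the update expression.
auto* avg_squared_grad_t = ctx.Input<framework::Tensor>("AvgSquaredGrad");
auto* avg_squared_grad_out_t = ctx.Output<framework::Tensor>("AvgSquaredGradOut");
auto* grad_t = ctx.Input<framework::Tensor>("Grad");
float rho = ctx.Attr<float>("rho");

avg_squared_grad_out_t->mutable_data<T>(ctx.GetPlace());

auto avg_squared_grad = framework::EigenVector<T>::Flatten(*avg_squared_grad_t);
auto avg_squared_grad_out = framework::EigenVector<T>::Flatten(*avg_squared_grad_out_t);
auto grad = framework::EigenVector<T>::Flatten(*grad_t);
auto place = ctx.GetEigenDevice<Place>();

// E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2
avg_squared_grad_out.device(place) = rho * avg_squared_grad + (1 - rho) * grad.square();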

@jacquesqiao (Member) left a comment


Great job, LGTM!

@abhinavarora abhinavarora merged commit 828c5b3 into PaddlePaddle:develop Oct 5, 2017
@abhinavarora abhinavarora deleted the adadelta branch October 5, 2017 20:07