support adagrad sparse update #5272
```diff
@@ -19,6 +19,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+template <typename Place, typename T>
+struct SparseAdagradFunctor {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& grad,
+                  const framework::Tensor& learning_rate, T epsilon,
+                  framework::Tensor* moment, framework::Tensor* param);
+};
+
 template <typename Place, typename T>
 class AdagradOpKernel : public framework::OpKernel<T> {
  public:
```
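The functor is only declared here; its CPU and CUDA definitions presumably live in the corresponding .cc/.cu files, which are not shown in this hunk. As a framework-free illustration of the row-wise update it is declared to perform, here is a minimal sketch using plain std::vector instead of Paddle's SelectedRows/Tensor types (all names below are hypothetical, not taken from this PR):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of a sparse Adagrad step. grad_rows[i] names the parameter row
// that row i of grad_values updates; grad_values is row-major with `width`
// columns, and moment/param are height x width, also row-major.
// Duplicate row indices, if any, are simply applied sequentially here.
void SparseAdagradUpdate(const std::vector<int64_t>& grad_rows,
                         const std::vector<float>& grad_values, size_t width,
                         float lr, float epsilon, std::vector<float>* moment,
                         std::vector<float>* param) {
  for (size_t i = 0; i < grad_rows.size(); ++i) {
    const float* g = &grad_values[i * width];
    float* m = &(*moment)[grad_rows[i] * width];
    float* p = &(*param)[grad_rows[i] * width];
    for (size_t j = 0; j < width; ++j) {
      m[j] += g[j] * g[j];                              // accumulate g^2
      p[j] -= lr * g[j] / (std::sqrt(m[j]) + epsilon);  // Adagrad step
    }
  }
}
```

The actual kernel dispatch that calls this functor appears in the next hunk.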
```diff
@@ -29,25 +37,43 @@ class AdagradOpKernel : public framework::OpKernel<T> {
     param_out_tensor->mutable_data<T>(ctx.GetPlace());
     moment_out_tensor->mutable_data<T>(ctx.GetPlace());
```
Review comment: Does it make sense to move lines 26 - 30 into the if/else block, since they are only used for the LoDTensor type?

Reply: Actually, we have to allocate memory before calculating.
```diff
 
-    float epsilon = ctx.Attr<float>("epsilon");
-
-    auto param = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Param"));
-    auto grad = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Grad"));
-    auto moment = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Moment"));
-    auto lr = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("LearningRate"));
-
-    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
-    auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
-    auto place = ctx.GetEigenDevice<Place>();
-
-    moment_out.device(place) = moment + grad * grad;
-    Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-    param_out.device(place) =
-        param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+
+    auto* grad_var = ctx.InputVar("Grad");
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      auto param = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Param"));
+      auto grad = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Grad"));
+      auto moment = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Moment"));
+      auto lr = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("LearningRate"));
+
+      auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
+      auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
+      auto place = ctx.GetEigenDevice<Place>();
+
+      moment_out.device(place) = moment + grad * grad;
+      Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
+      param_out.device(place) =
+          param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      auto* param = ctx.Input<framework::Tensor>("Param");
+      auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+      PADDLE_ENFORCE_EQ(param, param_out);
+
+      auto* moment = ctx.Input<framework::Tensor>("Moment");
+      auto* moment_out = ctx.Output<framework::Tensor>("MomentOut");
+      PADDLE_ENFORCE_EQ(moment, moment_out);
+
+      SparseAdagradFunctor<Place, T> functor;
+      functor(ctx.device_context(), *ctx.Input<framework::SelectedRows>("Grad"),
+              *ctx.Input<framework::Tensor>("LearningRate"), epsilon,
+              moment_out, param_out);
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
```
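Two notes on the new dispatch, for readers skimming the diff. First, the PADDLE_ENFORCE_EQ checks compare tensor pointers, so the sparse path requires Param/ParamOut and Moment/MomentOut to refer to the same variables, i.e. an in-place update. Second, both branches implement the standard Adagrad rule (the textbook formula, nothing PR-specific):

$$ G_t = G_{t-1} + g_t \odot g_t, \qquad \theta_t = \theta_{t-1} - \frac{\eta \, g_t}{\sqrt{G_t} + \epsilon} $$

The dense (LoDTensor) branch applies it elementwise to the whole parameter, while the SelectedRows branch touches only the rows present in the gradient, which is what makes the sparse update cheap for embedding-style parameters.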
Review comment: Does it make sense to place empty lines before and after #include <cmath>, following http://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes ?

Reply: Done.
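For context, the grouping the linked style guide asks for is: the file's own header first, then C system headers, then C++ standard-library headers such as <cmath>, then other libraries' and the project's headers, with a blank line between groups. A hypothetical layout (the paths are illustrative, not taken from this diff):

```cpp
#include "paddle/operators/adagrad_op.h"

#include <cmath>

#include "paddle/framework/selected_rows.h"
```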