support adagrad sparse update #5272

Merged: 11 commits merged on Nov 16, 2017

Changes from 4 commits
8 changes: 6 additions & 2 deletions paddle/operators/CMakeLists.txt
@@ -140,15 +140,19 @@ set(DEPS_OPS
pool_with_index_op
nccl_op
sequence_conv_op
lstm_op)
lstm_op
sgd_op
adagrad_op)


op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
op_library(adagrad_op DEPS selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
if(WITH_GPU)
57 changes: 52 additions & 5 deletions paddle/operators/adagrad_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/adagrad_op.h"
#include <cmath>
#include "paddle/operators/math/selected_rows_functor.h"
Contributor:
Does it make sense to place empty lines before and after #include <cmath>, following http://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes?

Member (Author):
Done
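
For reference, a minimal sketch of the grouping the style guide asks for (related header first, then C++ standard library headers, then project headers, with a blank line between groups); presumably this is what the follow-up commit does:

#include "paddle/operators/adagrad_op.h"  // related header

#include <cmath>  // C++ standard library headers

#include "paddle/operators/math/selected_rows_functor.h"  // project headers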


namespace paddle {
namespace operators {
@@ -21,7 +23,7 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -54,8 +56,8 @@ class AdagradOp : public framework::OperatorWithKernel {

class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
AdagradOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
@@ -83,10 +85,55 @@ by avoiding division by zero.
)DOC");
}
};
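
// Both the dense kernel and the sparse functor below implement the update
//   moment_out = moment + grad * grad
//   param_out  = param  - lr * grad / (sqrt(moment_out) + epsilon)
// The sparse version simply restricts it to the rows present in the
// SelectedRows gradient.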

template <typename T>
struct SparseAdagradFunctor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad.rows());
grad_square->set_height(grad.height());
grad_square->mutable_value()->mutable_data<T>(grad.value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto g = framework::EigenVector<T>::Flatten(grad.value());
gs.device(*context.GetEigenDevice<platform::CPUPlace>()) = g * g;

math::SelectedRowsAddToTensor<platform::CPUPlace, T> functor;
functor(context, *grad_square, moment);

auto grad_rows = grad.rows();
auto grad_rows_size = grad_rows.size();

int64_t grad_row_numel = grad.value().numel() / grad_rows_size;

auto* lr = learning_rate.data<T>();
auto* param_data = param->data<T>();
auto* moment_data = moment->data<T>();
auto* grad_data = grad.value().data<T>();

for (size_t i = 0; i < grad_rows_size; i++) {
for (int64_t j = 0; j < grad_row_numel; j++) {
param_data[grad_rows[i] * grad_row_numel + j] -=
lr[0] * grad_data[i * grad_row_numel + j] /
(std::sqrt(moment_data[grad_rows[i] * grad_row_numel + j]) +
epsilon);
}
}
}
};

template struct SparseAdagradFunctor<platform::CPUPlace, float>;
template struct SparseAdagradFunctor<platform::CPUPlace, double>;
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
REGISTER_OP_CPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
adagrad, ops::AdagradOpKernel<paddle::platform::CPUPlace, float>,
ops::AdagradOpKernel<paddle::platform::CPUPlace, double>);
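
To make the row-indexed arithmetic above concrete, here is a small self-contained sketch (plain arrays and made-up values, not Paddle types) of what SparseAdagradFunctor<platform::CPUPlace, T> computes. Note that row indices may repeat, in which case both sparse rows accumulate into the same moment and parameter row:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const float lr = 0.1f, epsilon = 1.0e-6f;
  const int64_t row_numel = 2;                     // columns per sparse row
  std::vector<int64_t> rows = {3, 3};              // duplicate row index
  std::vector<float> grad = {1.f, 2.f, 3.f, 4.f};  // 2 sparse rows x 2 cols
  std::vector<float> param(8, 1.f);                // 4 dense rows x 2 cols
  std::vector<float> moment(8, 0.f);

  // Step 1: moment += grad * grad, scattered by row index (the grad_square /
  // SelectedRowsAddToTensor part of the functor).
  for (size_t i = 0; i < rows.size(); i++)
    for (int64_t j = 0; j < row_numel; j++)
      moment[rows[i] * row_numel + j] +=
          grad[i * row_numel + j] * grad[i * row_numel + j];

  // Step 2: the per-element parameter update from the functor's inner loop.
  for (size_t i = 0; i < rows.size(); i++)
    for (int64_t j = 0; j < row_numel; j++)
      param[rows[i] * row_numel + j] -=
          lr * grad[i * row_numel + j] /
          (std::sqrt(moment[rows[i] * row_numel + j]) + epsilon);

  // Only dense row 3 changes; all other rows are untouched.
  std::printf("param row 3: %f %f\n", param[6], param[7]);
  return 0;
}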
81 changes: 79 additions & 2 deletions paddle/operators/adagrad_op.cu
@@ -14,7 +14,84 @@

#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

namespace {
template <typename T, int block_size>
__global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows,
const T* learning_rate, T* param,
T* moment, int64_t row_numel,
T epsilon) {
const int ty = blockIdx.y;
int tid = threadIdx.x;

grad += ty * row_numel;
param += rows[ty] * row_numel;
moment += rows[ty] * row_numel;

for (int index = tid; index < row_numel; index += block_size) {
// Since indices in the rows of a SelectedRows can be duplicated, we have
// to use an atomic operation to avoid concurrent write errors.
paddle::platform::CudaAtomicAdd(param + index,
-1.0 * learning_rate[0] * grad[index] /
(sqrt(moment[index]) + epsilon));
}
}
} // namespace

template <typename T>
struct SparseAdagradFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad.rows());
grad_square->set_height(grad.height());
grad_square->mutable_value()->mutable_data<T>(grad.value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto g = framework::EigenVector<T>::Flatten(grad.value());
gs.device(*context.GetEigenDevice<platform::GPUPlace>()) = g * g;

math::SelectedRowsAddToTensor<platform::GPUPlace, T> functor;
functor(context, *grad_square, moment);

auto grad_rows = grad.rows();
auto grad_rows_size = grad_rows.size();

int64_t grad_row_numel = grad.value().numel() / grad_rows_size;

auto* lr = learning_rate.data<T>();
auto* param_data = param->data<T>();
auto* moment_data = moment->data<T>();
auto* grad_data = grad.value().data<T>();

const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, grad_rows.size());
SparseAdagradFunctorKernel<
T, 256><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_data, grad.rows().data(),
learning_rate.data<T>(), param_data,
moment_data, grad_row_numel, epsilon);
}
};

template struct SparseAdagradFunctor<platform::GPUPlace, float>;
template struct SparseAdagradFunctor<platform::GPUPlace, double>;

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
adagrad, ops::AdagradOpKernel<paddle::platform::GPUPlace, float>,
ops::AdagradOpKernel<paddle::platform::GPUPlace, double>);
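
The CudaAtomicAdd in the kernel above matters because duplicate row indices assign two different blocks to the same parameter row. A CPU analogy of that hazard, sketched with std::thread and a compare-and-swap loop standing in for the atomic add (made-up values, not Paddle code):

#include <atomic>
#include <cstdio>
#include <thread>

int main() {
  // Two "blocks" target the same parameter element because their row
  // indices are duplicates; neither update may be lost.
  std::atomic<float> param_elem{1.0f};
  auto atomic_add = [&](float delta) {
    float cur = param_elem.load();
    // CAS loop, the classic way atomic float adds are implemented.
    while (!param_elem.compare_exchange_weak(cur, cur + delta)) {
    }
  };
  std::thread t1(atomic_add, -0.1f), t2(atomic_add, -0.2f);
  t1.join();
  t2.join();
  std::printf("%f\n", param_elem.load());  // 0.7 regardless of interleaving
  return 0;
}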
64 changes: 45 additions & 19 deletions paddle/operators/adagrad_op.h
@@ -19,6 +19,14 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename Place, typename T>
struct SparseAdagradFunctor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param);
};
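
// SparseAdagradFunctor is specialized per Place: the CPUPlace implementation
// lives in adagrad_op.cc and the GPUPlace one in adagrad_op.cu. The kernel
// below dispatches to it when Grad is a SelectedRows rather than a dense
// LoDTensor.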

template <typename Place, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
@@ -29,25 +37,43 @@ class AdagradOpKernel : public framework::OpKernel<T> {
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
Contributor:
Does it make sense to move lines 26-30 into the if-else block, since they are only used for the LoDTensor type?

Member (Author):
Actually, we have to allocate memory before calculating: param_out_tensor->mutable_data<T>(ctx.GetPlace()) has to be called in both branches. I will unify the variable names in the two branches to avoid inconsistency.


float epsilon = ctx.Attr<float>("epsilon");

auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));

auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();

moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));

auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));

auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();

moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
PADDLE_ENFORCE_EQ(param, param_out);

auto* moment = ctx.Input<framework::Tensor>("Moment");
auto* moment_out = ctx.Output<framework::Tensor>("MomentOut");
PADDLE_ENFORCE_EQ(moment, moment_out);

SparseAdagradFunctor<Place, T> functor;
functor(ctx.device_context(), *ctx.Input<framework::SelectedRows>("Grad"),
*ctx.Input<framework::Tensor>("LearningRate"), epsilon,
moment_out, param_out);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};

2 changes: 1 addition & 1 deletion paddle/operators/reshape_op.cc
@@ -36,7 +36,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (int i = 1; i < shape.size(); ++i) {
for (size_t i = 1; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0,
"Each dimension of shape "
"must be positive except the first.");
15 changes: 8 additions & 7 deletions paddle/operators/sgd_op.cu
@@ -20,11 +20,11 @@ namespace paddle {
namespace operators {

namespace {
template <typename T>
template <typename T, int block_size>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows,
const T* learning_rate, T* tensor_out,
int64_t row_numel, int block_size) {
int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;

@@ -59,14 +59,15 @@ struct SparseSGDFunctor<platform::GPUPlace, T> {
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();

int block_size = 256;
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
out_data, in_row_numel, block_size);
T, 256><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in_data, in_rows.data(),
learning_rate.data<T>(), out_data,
in_row_numel);
}
};
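
The other change in this hunk moves block_size from a runtime kernel argument into a template parameter, so the striding loop's step is a compile-time constant the compiler can unroll. A generic illustration of that pattern (standalone names, not Paddle code):

#include <cstdio>

// Runtime stride: `step` stays a variable, so the loop is hard to unroll.
void scale_rt(float* v, int n, int step, float s) {
  for (int i = 0; i < n; i += step) v[i] *= s;
}

// Compile-time stride: each instantiation bakes `step` in as a constant,
// enabling unrolling, just like SparseSGDFunctorKernel<T, 256>.
template <int step>
void scale_ct(float* v, int n, float s) {
  for (int i = 0; i < n; i += step) v[i] *= s;
}

int main() {
  float v[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  scale_rt(v, 8, /*step=*/2, 10.f);    // scales v[0], v[2], v[4], v[6]
  scale_ct<2>(v, 8, 0.1f);             // undoes it with a compile-time stride
  std::printf("%f %f\n", v[0], v[1]);  // prints 1.000000 2.000000
  return 0;
}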

1 change: 0 additions & 1 deletion paddle/operators/sum_op.cc
@@ -12,7 +12,6 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
#include "paddle/framework/var_type_inference.h"
#include "paddle/operators/net_op.h"

namespace paddle {
namespace operators {