Skip to content

Commit

Permalink
add sgd group
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Mar 8, 2018
1 parent a9b9ec4 commit cfa1f9b
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 0 deletions.
80 changes: 80 additions & 0 deletions paddle/fluid/operators/sgd_group_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sgd_group_op.h"

namespace paddle {
namespace operators {

class SGDGroupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Params"),
"Inputs(Param) of SGDGroupOp should not be null.");
PADDLE_ENFORCE(ctx->HasInputs("Grads"),
"Inputs(Grad) of SGDGroupOp should not be null.");
PADDLE_ENFORCE(ctx->HasInputs("LearningRates"),
"Inputs(LearningRates) of SGDGroupOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs("ParamOuts"),
"Outputs(ParamOut) of SGDGroupOp should not be null.");

auto params = ctx->GetInputsDim("Params");
auto grads = ctx->GetInputsDim("Grads");
auto learning_rates = ctx->GetInputsDim("LearningRates");

auto param_num = params.size();

PADDLE_ENFORCE_EQ(param_num, grads.size(),
"The number of param and grads should be equal.");
PADDLE_ENFORCE_EQ(
param_num, learning_rates.size(),
"The number of param and learning_rates should be equal.");

for (size_t i = 0; i < param_num; ++i) {
PADDLE_ENFORCE_EQ(framework::product(learning_rates[i]), 1,
"Learning rate should have 1 element");
}

auto param_dims = ctx->GetInputsDim("Params");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
ctx->SetOutputsDim("ParamOuts", param_dims);
}
};

class SGDGroupOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDGroupOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Params", "(vector<Tensor>) Input parameter").AsDuplicable();
AddInput("LearningRates", "(vector<Tensor>) Learning rate of SGD")
.AsDuplicable();
AddInput("Grads", "(vector<Tensor>) Input gradient").AsDuplicable();
AddOutput("ParamOuts", "(vector<Tensor>) Output parameter").AsDuplicable();
AddComment(R"DOC(
SGDGroup operator
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd_group, ops::SGDGroupOp, ops::SGDGroupOpMaker);
REGISTER_OP_CPU_KERNEL(sgd_group, ops::SGDGroupOpKernel<float>,
ops::SGDGroupOpKernel<double>);
73 changes: 73 additions & 0 deletions paddle/fluid/operators/sgd_group_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sgd_group_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

namespace {

template <typename T>
__global__ void SGDGroupKernel(const T* g, const T* p, const T* learning_rate,
const int num, T* p_out) {
T lr = learning_rate[0];
int grid_size = blockDim.x * gridDim.x;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
T g_data = g[i];
T p_data = p[i];
p_out[i] = p_data - lr * g_data;
}
}

} // namespace

template <typename T>
class SGDGroupOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto params = ctx.MultiInput<framework::Tensor>("Params");
auto learning_rates = ctx.MultiInput<framework::Tensor>("LearningRates");
auto grads = ctx.MultiInput<framework::Tensor>("Grads");

auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOuts");

auto grad_var = ctx.MultiInputVar("Grads");

if (grad_var[0]->IsType<framework::LoDTensor>()) {
for (size_t j = 0; j < params.size(); ++j) {
auto* param_out_data = param_outs[j]->mutable_data<T>(ctx.GetPlace());
auto* grad_data = grads[j]->data<T>();
auto* param_data = params[j]->data<T>();
int param_num = params[j]->numel();
int block = 512;
int grid = (param_num + block - 1) / block;
SGDGroupKernel<
T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
grad_data, param_data, learning_rates[j]->data<T>(), param_num,
param_out_data);
}
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sgd_group, ops::SGDGroupOpCUDAKernel<float>,
ops::SGDGroupOpCUDAKernel<double>);
78 changes: 78 additions & 0 deletions paddle/fluid/operators/sgd_group_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"

namespace paddle {
namespace operators {

template <typename T>
class SGDGroupOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
//
// auto* param = ctx.Input<framework::Tensor>("Param");
// auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
// auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
//
// auto* grad_var = ctx.InputVar("Grad");
// // Actually, all tensors are LoDTensor except SelectedRows.
// if (grad_var->IsType<framework::LoDTensor>()) {
// param_out->mutable_data<T>(ctx.GetPlace());
// auto* grad = ctx.Input<framework::Tensor>("Grad");
//
// auto p = framework::EigenVector<T>::Flatten(*param);
// auto g = framework::EigenVector<T>::Flatten(*grad);
// auto o = framework::EigenVector<T>::Flatten(*param_out);
// auto* lr = learning_rate->data<T>();
//
// o = p - lr[0] * g;
// } else if (grad_var->IsType<framework::SelectedRows>()) {
// // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// // This manual optimization brings difficulty to track data
// dependency.
// // It's better to find a more elegant solution.
// PADDLE_ENFORCE_EQ(param, param_out);
// auto* grad = ctx.Input<framework::SelectedRows>("Grad");
//
// auto in_height = grad->height();
// auto out_dims = param_out->dims();
// PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
//
// auto& in_value = grad->value();
// auto& in_rows = grad->rows();
//
// int64_t in_row_numel = in_value.numel() / in_rows.size();
// PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
//
// auto* in_data = in_value.data<T>();
// auto* out_data = param_out->data<T>();
// auto* lr = learning_rate->data<T>();
//
// for (size_t i = 0; i < in_rows.size(); i++) {
// for (int64_t j = 0; j < in_row_numel; j++) {
// out_data[in_rows[i] * in_row_numel + j] -=
// lr[0] * in_data[i * in_row_numel + j];
// }
// }
// } else {
// PADDLE_THROW("Unsupported Variable Type of Grad");
// }
}
};
} // namespace operators
} // namespace paddle
54 changes: 54 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sgd_group_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest


class TestSGDOp(OpTest):
def setUp(self):
self.op_type = "sgd_group"
w0 = np.random.random((1, 124)).astype('float32')
w1 = np.random.random((3, 24)).astype('float32')
w2 = np.random.random((4, 104)).astype('float32')

g0 = np.random.random((1, 124)).astype('float32')
g1 = np.random.random((3, 24)).astype('float32')
g2 = np.random.random((4, 104)).astype('float32')

lr0 = np.array([0.1]).astype("float32")
lr1 = np.array([0.2]).astype("float32")
lr2 = np.array([0.3]).astype("float32")

o0 = w0 - lr0 * g0
o1 = w1 - lr1 * g1
o2 = w2 - lr2 * g2

self.inputs = {
"Params": [("w0", w0), ("w1", w1), ("w2", w2)],
"Grads": [("g0", g0), ("g1", g1), ("g2", g2)],
'LearningRates': [("lr0", lr0), ("lr1", lr1), ("lr2", lr2)]
}

self.outputs = {'ParamOuts': [("o0", o0), ("o1", o1), ("o2", o2)]}

def test_check_output(self):
self.check_output()


if __name__ == "__main__":
unittest.main()

0 comments on commit cfa1f9b

Please sign in to comment.