Commit aee3708

add sgd optimizer

Yang Yang committed Jun 9, 2018
1 parent 5ae4c56 commit aee3708
Showing 5 changed files with 52 additions and 2 deletions.
33 changes: 33 additions & 0 deletions paddle/contrib/tape/function.h
@@ -80,10 +80,43 @@ class Linear {
     return y;
   }
 
+  std::vector<VariableHandle> Params() { return {w_}; }
+
  private:
   VariableHandle w_;
   VariableHandle b_;
   std::string act_;
 };
+
+class SGD {
+ public:
+  SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
+    Tape init_tape;
+
+    std::string initializer = "fill_constant";
+    framework::AttributeMap attrs;
+    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
+    attrs["shape"] = std::vector<int>{1};
+    attrs["value"] = learning_rate;
+    init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);
+
+    init_tape.Forward();
+  }
+
+  void operator()(VariableHandle input) {
+    Tape temp_tape;
+    temp_tape.AddOp("sgd",
+                    {{"Param", {input}},
+                     {"LearningRate", {learning_rate_}},
+                     {"Grad", {input->Grad()}}},
+                    {{"ParamOut", {input}}},
+                    {});
+    temp_tape.Forward();
+    input->ResetGrad();
+  }
+
+ private:
+  VariableHandle learning_rate_;
+};
 }
 }
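For context, the "sgd" operator that operator() puts on the tape performs the plain stochastic-gradient-descent update and writes the result back into the parameter (Param and ParamOut are the same variable). A minimal sketch of that update, assuming a dense float tensor; SgdUpdate and its signature are illustrative, not Paddle's internals:

#include <cstddef>

// Vanilla SGD step, as applied by the "sgd" op above:
//   param <- param - learning_rate * grad
void SgdUpdate(float* param, const float* grad, float learning_rate,
               std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    param[i] -= learning_rate * grad[i];
  }
}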
8 changes: 6 additions & 2 deletions paddle/contrib/tape/tape.cc
@@ -135,12 +135,16 @@ class ScopeWrapper : public framework::Scope {
                const VariableHandleMap &out_vars) {
     for (auto &v : in_vars) {
       for (auto &vv : v.second) {
-        vars_[vv->Name()].reset(vv->Var());
+        if (!vars_.count(vv->Name())) {
+          vars_[vv->Name()].reset(vv->Var());
+        }
       }
     }
     for (auto &v : out_vars) {
       for (auto &vv : v.second) {
-        vars_[vv->Name()].reset(vv->Var());
+        if (!vars_.count(vv->Name())) {
+          vars_[vv->Name()].reset(vv->Var());
+        }
       }
     }
   }
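The new count() guard matters because, with the sgd op, the same variable now appears both as an input (Param) and as an output (ParamOut), so ScopeWrapper sees the same raw framework::Variable* more than once. Resetting a std::unique_ptr to the pointer it already owns destroys the object and leaves the map holding a dangling pointer. A minimal sketch of the failure mode, assuming vars_ maps names to std::unique_ptr as the .reset() call suggests:

#include <memory>
#include <string>
#include <unordered_map>

struct Var {};

int main() {
  std::unordered_map<std::string, std::unique_ptr<Var>> vars;
  Var* raw = new Var;
  vars["x"].reset(raw);     // the map takes ownership of raw
  // An unguarded second insert would free raw and keep a dangling pointer:
  // vars["x"].reset(raw);  // undefined behavior on later access
  if (!vars.count("x")) {   // the guard added in this commit
    vars["x"].reset(raw);
  }
  return 0;
}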
9 changes: 9 additions & 0 deletions paddle/contrib/tape/test_tape.cc
@@ -21,6 +21,8 @@ TEST(Tape, TestMLP) {
   paddle::tape::Linear linear2(3, 3, "relu");
   paddle::tape::Mean mean;
 
+  paddle::tape::SGD sgd(0.001);
+
   for (int i = 0; i < 2; ++i) {
     paddle::tape::reset_global_tape();
 
@@ -36,6 +38,13 @@
     auto loss = mean(linear2(linear1(input)));
 
     paddle::tape::get_global_tape().Backward(loss);
+
+    for (auto w : linear1.Params()) {
+      sgd(w);
+    }
+    for (auto w : linear2.Params()) {
+      sgd(w);
+    }
   }
 }
 
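With these changes each test iteration is a complete training step: reset the global tape, run the forward pass to build loss, call Backward(loss) to populate gradients, then apply sgd to every parameter exposed by Params(), which updates the weight in place and clears its gradient through ResetGrad().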
2 changes: 2 additions & 0 deletions paddle/contrib/tape/variable.h
@@ -52,6 +52,8 @@ class Variable {
     return grad_;
   }
 
+  void ResetGrad() { grad_ = nullptr; }
+
   // Stochastic Gradient Descent with Momentum
   // VariableHandle Momentum ();
 
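ResetGrad() simply drops the handle to the gradient variable. SGD::operator() calls it after every update, so a stale gradient cannot leak into the next iteration; presumably Grad() recreates the gradient variable lazily on the next backward pass.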
2 changes: 2 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -85,6 +85,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
 }
 
 void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
+  VLOG(10) << "- " << DebugStringEx(&scope);
   if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
     PADDLE_THROW("Cannot run operator on place %s", place);
@@ -94,6 +95,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
 #endif
   }
   RunImpl(scope, place);
+  VLOG(10) << "+ " << DebugStringEx(&scope);
 }
 
 bool OperatorBase::HasInputs(const std::string& name) const {
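The paired VLOG(10) calls bracket every operator run, logging the operator's debug string against the scope before ("- ") and after ("+ ") execution. Since this goes through glog, the output is only emitted when verbosity is raised to 10 or higher, e.g. by setting the GLOG_v environment variable.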
