Add some functions to PaddleAPI.h #1013

Merged
merged 5 commits into from
Jan 10, 2017
4 changes: 4 additions & 0 deletions paddle/api/Arguments.cpp
@@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) {
a.cpuSequenceDims = m->cast<paddle::IVector>(vec->getSharedPtr());
}

float Arguments::sumCosts() const {
Collaborator

Function names should be Camel Cased: https://google.github.io/styleguide/cppguide.html#Function_Names

Shouldn't newly added functions, at least, follow the code style? That would remind everyone to pay attention to the conventions. Existing functions could be renamed later with a tool.

Collaborator Author

Could you please review baidu-adu/cpp-primer-digest#1?
It describes Paddle's current naming style.

return paddle::Argument::sumCosts(m->outputs);
}

int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return a.getBatchSize();
6 changes: 6 additions & 0 deletions paddle/api/PaddleAPI.h
@@ -454,6 +454,8 @@ class Arguments {
IVector* vec) throw(RangeError);
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);

float sumCosts() const;
Collaborator

@wangkuiyi wangkuiyi Dec 27, 2016

Please explain in the PR description why these functions need to be added and what benefit they bring:

  • sumCosts
  • load
  • save

Collaborator Author

done.

Member

LGTM


private:
static Arguments* createByPaddleArgumentVector(void* ptr);
void* getInternalArgumentsPtr() const;
@@ -549,6 +551,10 @@ class Parameter {
ParameterConfig* getConfig();
void setValueUpdated();

bool save(const std::string& filename) const;

bool load(const std::string& filename) const;

private:
static Parameter* createFromRawPtr(void* ptr);
static Parameter* createFromSharedPtr(void* ptr);
8 changes: 8 additions & 0 deletions paddle/api/Parameter.cpp
@@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() {
size_t Parameter::getID() const { return m->getPtr()->getID(); }

void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }

bool Parameter::save(const std::string& filename) const {
Collaborator

One of the main reasons I held off approving this PR for so long is that methods like save/load(filename) are not a good design.

First, these methods are hard to unit test unless we have an in-memory mock filesystem, and we really do not need a test facility that complex. In addition, the bodies of these methods often duplicate the bodies of network-transfer methods: both have to serialize/deserialize class members.

A more common design replaces save/load with serialize/deserialize:

std::string serialize();
error deserialize(const std::string& input);

This is easy to unit test, and it works equally well for network transfer and disk I/O:

File f("/tmp/a");
Parameters ps;
f.write(ps.serialize());
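
To make the suggestion above concrete, here is a minimal sketch of a serialize/deserialize pair that can be unit tested entirely in memory. `ToyParameter`, its byte layout, and `checkRoundTrip` are hypothetical illustrations, not Paddle's actual `Parameter` class or on-disk format, and `deserialize` returns `bool` where the comment above uses an unspecified `error` type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a parameter; NOT the real paddle::Parameter.
class ToyParameter {
public:
  ToyParameter() = default;
  ToyParameter(std::string name, std::vector<float> values)
      : name_(std::move(name)), values_(std::move(values)) {}

  const std::string& name() const { return name_; }
  const std::vector<float>& values() const { return values_; }

  // Assumed layout: [u64 name length][name bytes][u64 count][count floats].
  // Host byte order is used, so this is only portable across
  // machines with the same endianness and float representation.
  std::string serialize() const {
    std::string out;
    appendSize(&out, name_.size());
    out.append(name_);
    appendSize(&out, values_.size());
    if (!values_.empty()) {
      out.append(reinterpret_cast<const char*>(values_.data()),
                 values_.size() * sizeof(float));
    }
    return out;
  }

  // Returns false and leaves *this unchanged if the input is malformed.
  bool deserialize(const std::string& input) {
    std::size_t pos = 0;
    std::uint64_t nameLen = 0;
    if (!readSize(input, &pos, &nameLen)) return false;
    if (nameLen > input.size() - pos) return false;
    std::string name = input.substr(pos, static_cast<std::size_t>(nameLen));
    pos += static_cast<std::size_t>(nameLen);

    std::uint64_t count = 0;
    if (!readSize(input, &pos, &count)) return false;
    const std::size_t remaining = input.size() - pos;
    // Compare by division so a huge count cannot overflow a multiplication.
    if (remaining % sizeof(float) != 0) return false;
    if (count != remaining / sizeof(float)) return false;

    std::vector<float> values(static_cast<std::size_t>(count));
    if (!values.empty()) {
      std::memcpy(values.data(), input.data() + pos, remaining);
    }
    name_ = std::move(name);
    values_ = std::move(values);
    return true;
  }

private:
  static void appendSize(std::string* out, std::uint64_t n) {
    out->append(reinterpret_cast<const char*>(&n), sizeof(n));
  }

  static bool readSize(const std::string& in, std::size_t* pos,
                       std::uint64_t* n) {
    if (in.size() - *pos < sizeof(*n)) return false;
    std::memcpy(n, in.data() + *pos, sizeof(*n));
    *pos += sizeof(*n);
    return true;
  }

  std::string name_;
  std::vector<float> values_;
};

// Usage: the round trip needs only an in-memory string, no filesystem.
void checkRoundTrip() {
  const ToyParameter original("w0", {0.1f, 0.2f, 0.3f});
  ToyParameter restored;
  assert(restored.deserialize(original.serialize()));
  assert(restored.name() == original.name());
  assert(restored.values() == original.values());
}
```

The same string can then be handed to a file writer or a network socket, which is the reuse the comment argues for.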

Member

That makes a lot of sense! The current interface is not suited to distributed use; serialize/deserialize would make transfer much easier. What this PR exposes is the old interface, and it is only being exposed as-is. We can consider refactoring in the next step, when we provide the C API.
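
One possible shape for such a refactoring, sketched under assumptions: the filename-based entry points stay as thin wrappers, and all encoding lives in serialize/deserialize. `Blob`, `writeFile`, `readFile`, `saveToFile`, and `loadFromFile` are hypothetical names used for illustration; none of them exist in Paddle.

```cpp
#include <cassert>
#include <fstream>
#include <iterator>
#include <string>

// Hypothetical helper: write raw bytes to a file, replacing its contents.
bool writeFile(const std::string& filename, const std::string& bytes) {
  std::ofstream out(filename, std::ios::binary | std::ios::trunc);
  if (!out) return false;
  out.write(bytes.data(), static_cast<std::streamsize>(bytes.size()));
  out.flush();
  return static_cast<bool>(out);
}

// Hypothetical helper: read a whole file into *bytes.
bool readFile(const std::string& filename, std::string* bytes) {
  std::ifstream in(filename, std::ios::binary);
  if (!in) return false;
  bytes->assign(std::istreambuf_iterator<char>(in),
                std::istreambuf_iterator<char>());
  return !in.bad();
}

// Works for any T that provides `std::string serialize() const`.
template <typename T>
bool saveToFile(const T& obj, const std::string& filename) {
  return writeFile(filename, obj.serialize());
}

// Works for any T that provides `bool deserialize(const std::string&)`.
template <typename T>
bool loadFromFile(T* obj, const std::string& filename) {
  std::string bytes;
  return readFile(filename, &bytes) && obj->deserialize(bytes);
}

// Toy type used only to exercise the wrappers above.
struct Blob {
  std::string payload;
  std::string serialize() const { return payload; }
  bool deserialize(const std::string& input) {
    payload = input;
    return true;
  }
};

// Usage: the file path is the only place the filesystem is involved.
bool roundTripThroughFile(const std::string& path) {
  Blob original;
  original.payload = std::string("a\0b", 3);  // embedded NUL must survive
  Blob restored;
  return saveToFile(original, path) && loadFromFile(&restored, path) &&
         restored.payload == original.payload;
}
```

With this split, disk I/O is tested once in the helpers, while each class's encoding can be tested without touching the filesystem.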

return m->getPtr()->save(filename);
}

bool Parameter::load(const std::string& filename) const {
return m->getPtr()->load(filename);
}
2 changes: 2 additions & 0 deletions paddle/api/test/.gitignore
@@ -0,0 +1,2 @@
*.w0
*.wbias
2 changes: 2 additions & 0 deletions paddle/api/test/testArguments.py
@@ -22,6 +22,8 @@ def test_load_arguments(self):
args = swig_paddle.Arguments.createArguments(1)
args.setSlotValue(0, m)

self.assertAlmostEqual(27.0, args.sumCosts())

mat = args.getSlotValue(0)
assert isinstance(mat, swig_paddle.Matrix)
np_mat = mat.toNumpyMatInplace()
4 changes: 4 additions & 0 deletions paddle/api/test/testGradientMachine.py
@@ -45,6 +45,7 @@ def test_create_gradient_machine(self):
assert isinstance(val, swig_paddle.Vector)
arr = numpy.full((len(val), ), 0.1, dtype="float32")
val.copyFromNumpyArray(arr)
self.assertTrue(param.save(param.getName()))
param_config = param.getConfig().toProto()
assert isinstance(param_config,
paddle.proto.ParameterConfig_pb2.ParameterConfig)
@@ -92,6 +93,9 @@ def backward_callback(param_):

self.assertTrue(self.isCalled)

for param in machine.getParameters():
self.assertTrue(param.load(param.getName()))

def test_train_one_pass(self):
conf_file_path = './testTrainConfig.py'
trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(