-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better version of PR #7985 (Modify load() for inference) #8024
Changes from 14 commits
e556d23
503f73a
749e33e
f8570e6
076e0e5
0692742
b17af18
e2b88c8
55f40e5
19343d4
4177474
a6f0b0f
e0f8e74
8fb97bc
e813d13
5430ddd
83b4616
8a002ce
4d61569
805b415
33570f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,18 @@ limitations under the License. */ | |
namespace paddle { | ||
namespace inference { | ||
|
||
void ReadProgramDescFromFile(const std::string& model_filename, | ||
std::string& program_desc_str) { | ||
VLOG(3) << "loading model from " << model_filename; | ||
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
inputfs.seekg(0, std::ios::end); | ||
program_desc_str.resize(inputfs.tellg()); | ||
inputfs.seekg(0, std::ios::beg); | ||
VLOG(3) << "program_desc_str's size: " << program_desc_str.size(); | ||
inputfs.read(&program_desc_str[0], program_desc_str.size()); | ||
inputfs.close(); | ||
} | ||
|
||
bool IsParameter(const framework::VarDesc* var, | ||
const framework::ProgramDesc& main_program) { | ||
if (var->Persistable()) { | ||
|
@@ -44,12 +56,15 @@ bool IsParameter(const framework::VarDesc* var, | |
|
||
void LoadPersistables(framework::Executor& executor, | ||
framework::Scope& scope, | ||
const framework::ProgramDesc& main_program, | ||
const std::string& dirname, | ||
const framework::ProgramDesc& main_program) { | ||
const std::string& param_filename) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
const framework::BlockDesc& global_block = main_program.Block(0); | ||
|
||
framework::ProgramDesc* load_program = new framework::ProgramDesc(); | ||
framework::BlockDesc* load_block = load_program->MutableBlock(0); | ||
std::vector<std::string> paramlist; | ||
|
||
for (auto* var : global_block.AllVars()) { | ||
if (IsParameter(var, main_program)) { | ||
VLOG(3) << "parameter's name: " << var->Name(); | ||
|
@@ -61,36 +76,63 @@ void LoadPersistables(framework::Executor& executor, | |
new_var->SetLoDLevel(var->GetLoDLevel()); | ||
new_var->SetPersistable(true); | ||
|
||
// append_op | ||
framework::OpDesc* op = load_block->AppendOp(); | ||
op->SetType("load"); | ||
op->SetOutput("Out", {new_var->Name()}); | ||
op->SetAttr("file_path", {dirname + "/" + new_var->Name()}); | ||
op->CheckAttrs(); | ||
if (!param_filename.empty()) { | ||
paramlist.push_back(new_var->Name()); | ||
} else { | ||
// append_op | ||
framework::OpDesc* op = load_block->AppendOp(); | ||
op->SetType("load"); | ||
op->SetOutput("Out", {new_var->Name()}); | ||
op->SetAttr("file_path", {dirname + "/" + new_var->Name()}); | ||
op->CheckAttrs(); | ||
} | ||
} | ||
} | ||
|
||
if (!param_filename.empty()) { | ||
// sort paramlist to have consistent ordering | ||
std::sort(paramlist.begin(), paramlist.end()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tensor and LoDTensor have no name. If there is something wrong (for example, user coincidentally gives a wrong param_filename), there is no warning information at all. It will be difficult for users to debug. Can we add another mechanism to check or ensure the correctness? We can refer to other frameworks. Also, this can be done in next PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, will think about it and as you mentioned will do later. |
||
// append just the load_combine op | ||
framework::OpDesc* op = load_block->AppendOp(); | ||
op->SetType("load_combine"); | ||
op->SetOutput("Out", paramlist); | ||
op->SetAttr("file_path", {param_filename}); | ||
op->CheckAttrs(); | ||
} | ||
|
||
executor.Run(*load_program, &scope, 0, true, true); | ||
|
||
VLOG(3) << "Ran loading successfully"; | ||
delete load_program; | ||
} | ||
|
||
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor, | ||
framework::Scope& scope, | ||
const std::string& dirname) { | ||
std::string model_filename = dirname + "/__model__"; | ||
LOG(INFO) << "loading model from " << model_filename; | ||
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
std::string program_desc_str; | ||
inputfs.seekg(0, std::ios::end); | ||
program_desc_str.resize(inputfs.tellg()); | ||
inputfs.seekg(0, std::ios::beg); | ||
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); | ||
inputfs.read(&program_desc_str[0], program_desc_str.size()); | ||
inputfs.close(); | ||
ReadProgramDescFromFile(model_filename, program_desc_str); | ||
|
||
std::unique_ptr<framework::ProgramDesc> main_program( | ||
new framework::ProgramDesc(program_desc_str)); | ||
|
||
LoadPersistables(executor, scope, *main_program, dirname, ""); | ||
return main_program; | ||
} | ||
|
||
std::unique_ptr<framework::ProgramDesc> Load( | ||
framework::Executor& executor, | ||
framework::Scope& scope, | ||
const std::string& prog_filename, | ||
const std::string& param_filename) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to keep two std::unique_ptr<framework::ProgramDesc> Load(
framework::Executor& executor,
framework::Scope& scope,
const std::string& dirname);
std::unique_ptr<framework::ProgramDesc> Load(
framework::Executor& executor,
framework::Scope& scope,
const std::string& prog_filename,
const std::string& params_filename); Users are free to rename the file which saving the program. The argument list of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for not describing my suggest clearly. I mean to define the interface to: std::unique_ptr<framework::ProgramDesc> Load(
framework::Executor& executor,
framework::Scope& scope,
const std::string& prog_filepath,
const std::string& param_filepath); No need of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I think you described your suggestion clearly: #8024 (comment) But, I chose to keep another argument of Hence, I put that extra argument. But it seems that you would prefer an implementation without it, so I will modify it. Thanks. |
||
std::string model_filename = prog_filename; | ||
std::string program_desc_str; | ||
ReadProgramDescFromFile(model_filename, program_desc_str); | ||
|
||
std::unique_ptr<framework::ProgramDesc> main_program( | ||
new framework::ProgramDesc(program_desc_str)); | ||
|
||
LoadPersistables(executor, scope, dirname, *main_program); | ||
LoadPersistables(executor, scope, *main_program, "", param_filename); | ||
return main_program; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -342,7 +342,12 @@ def save_inference_model(dirname, | |
prepend_feed_ops(inference_program, feeded_var_names) | ||
append_fetch_ops(inference_program, fetch_var_names) | ||
|
||
model_file_name = dirname + "/__model__" | ||
model_file_name = "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please delete this line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, thanks. |
||
if save_file_name == None: | ||
model_file_name = dirname + "/__model__" | ||
else: | ||
model_file_name = dirname + "/__model_combined__" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The naming style is not friendly for users. We need to refine this interface. |
||
|
||
with open(model_file_name, "wb") as f: | ||
f.write(inference_program.desc.serialize_to_string()) | ||
|
||
|
@@ -384,7 +389,12 @@ def load_inference_model(dirname, executor, load_file_name=None): | |
if not os.path.isdir(dirname): | ||
raise ValueError("There is no directory named '%s'", dirname) | ||
|
||
model_file_name = dirname + "/__model__" | ||
model_file_name = "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. Please delete this line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
if load_file_name == None: | ||
model_file_name = dirname + "/__model__" | ||
else: | ||
model_file_name = dirname + "/__model_combined__" | ||
|
||
with open(model_file_name, "rb") as f: | ||
program_desc_str = f.read() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this function is a common function to read binary to string. If there is no other use for this reading function, how about change it to
But we can change this in next PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this function will also be useful for "loading-from-buffer" case. I have changed the name to
ReadBinaryFile
.