-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CustomReader #10872
CustomReader #10872
Changes from 10 commits
e61a38d
e15d616
983c9a2
b48eba1
df8fbf8
4b395b0
2e42b31
239546a
e4e9d36
0457f06
8147063
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/executor.h" | ||
#include "paddle/fluid/operators/reader/reader_op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
namespace reader { | ||
|
||
class CustomReader : public framework::DecoratedReader { | ||
public: | ||
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block, | ||
const platform::Place& dev_place, | ||
const std::vector<std::string>& source_var_names, | ||
const std::vector<std::string>& sink_var_names) | ||
: DecoratedReader(reader), | ||
program_(*sub_block.Program()), | ||
sub_block_id_(sub_block.ID()), | ||
exe_(framework::Executor(dev_place)), | ||
source_var_names_(source_var_names), | ||
sink_var_names_(sink_var_names) {} | ||
|
||
void ReadNext(std::vector<framework::LoDTensor>* out) override; | ||
|
||
private: | ||
const framework::ProgramDesc program_; | ||
int sub_block_id_; | ||
framework::Executor exe_; | ||
|
||
std::vector<std::string> source_var_names_; | ||
std::vector<std::string> sink_var_names_; | ||
}; | ||
|
||
class CreateCustomReaderOp : public framework::OperatorBase { | ||
public: | ||
using framework::OperatorBase::OperatorBase; | ||
|
||
private: | ||
void RunImpl(const framework::Scope& scope, | ||
const platform::Place& dev_place) const override { | ||
auto* out = scope.FindVar(Output("Out")) | ||
->template GetMutable<framework::ReaderHolder>(); | ||
auto* sub_block = Attr<framework::BlockDesc*>("sub_block"); | ||
if (out->Get() != nullptr) { | ||
return; | ||
} | ||
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) | ||
->Get<framework::ReaderHolder>(); | ||
out->Reset( | ||
new CustomReader(underlying_reader.Get(), *sub_block, dev_place, | ||
Attr<std::vector<std::string>>("source_var_names"), | ||
Attr<std::vector<std::string>>("sink_var_names"))); | ||
} | ||
}; | ||
|
||
class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { | ||
protected: | ||
void Apply() override { | ||
AddAttr<framework::BlockDesc*>( | ||
"sub_block", "The block to hold all preprocessing operators."); | ||
AddAttr<std::vector<std::string>>( | ||
"source_var_names", | ||
"Source variables are starting points of data preprocessing. They hold " | ||
"preprocessing's input tensors. Each source variable corresponds to " | ||
"one of underlying reader's output datas."); | ||
AddAttr<std::vector<std::string>>( | ||
"sink_var_names", | ||
"Sink variables are ending points of data preprocessing. They hold " | ||
"preprocessing's output tensors. Each sink variable corresponds to " | ||
"one of custom reader's output datas."); | ||
AddComment(R"DOC( | ||
CreateCustomReader Operator | ||
|
||
A custom reader can be used for input data preprocessing. | ||
A custom reader holds its own sub-block, which will be executed in its | ||
'ReadNext()' function. Users can configurate their own preprocessing | ||
pipelines by inserting operators into custom reader's sub-block. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
// Compile-time shape inference for create_custom_reader: the output reader
// takes the shapes and LoD levels of the sub-block's sink variables.
class CustomReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(!ctx->IsRuntime(),
                   "'CustomReaderInferShape' should only be invoked during "
                   "compile time.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output decorated reader should not be null.");

    const framework::BlockDesc* preprocess_block =
        ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
    const std::vector<std::string> sink_names =
        ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");

    // Collect one shape and one LoD level per sink variable, in order.
    std::vector<std::vector<int64_t>> shapes;
    std::vector<int32_t> lod_levels;
    for (const std::string& sink_name : sink_names) {
      auto* sink_desc = preprocess_block->FindVar(sink_name);
      PADDLE_ENFORCE_NOT_NULL(sink_desc);
      shapes.push_back(sink_desc->GetShape());
      lod_levels.push_back(sink_desc->GetLoDLevel());
    }

    auto* reader_desc =
        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
    reader_desc->SetShapes(shapes);
    reader_desc->SetLoDLevels(lod_levels);
  }
};
|
||
// Variable-type inference for create_custom_reader: marks "Out" as a READER
// and gives it the data types of the sub-block's sink variables.
class CustomReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    framework::VarDesc* reader_desc = block->FindVar(op_desc.Output("Out")[0]);
    PADDLE_ENFORCE_NOT_NULL(reader_desc);
    reader_desc->SetType(framework::proto::VarType::READER);

    // The name list is copied on purpose: boost::get is applied to the value
    // returned by GetAttr, so a reference must not be kept.
    const std::vector<std::string> sink_names =
        boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
    const framework::BlockDesc* preprocess_block =
        boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));

    // One data type per sink variable, in declaration order.
    std::vector<framework::proto::VarType::Type> data_types;
    for (const std::string& sink_name : sink_names) {
      framework::VarDesc* sink_desc = preprocess_block->FindVar(sink_name);
      PADDLE_ENFORCE_NOT_NULL(sink_desc);
      data_types.push_back(sink_desc->GetDataType());
    }
    reader_desc->SetDataTypes(data_types);
  }
};
|
||
// Reads the next batch from the underlying reader, runs the preprocessing
// sub-block on it and stores the sink variables' tensors in `out`.
//
// `out` is cleared first and stays empty when the underlying reader returns
// no data. Otherwise it ends up with one tensor per sink variable, each
// copied to CPUPlace.
//
// The number of source variables must match the number of tensors produced
// by the underlying reader. The number of sink variables is independent of
// it, because the sub-block may produce more or fewer outputs than it
// consumes.
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  out->clear();
  std::vector<framework::LoDTensor> underlying_outs;
  reader_->ReadNext(&underlying_outs);
  if (underlying_outs.empty()) {
    // There is no next data.
    return;
  }
  PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
                 "The size of source_var_names(%d) and the size of "
                 "underlying_outs(%d) are not consistent. Each feeding element "
                 "must have its own source variable.",
                 source_var_names_.size(), underlying_outs.size());
  // The scope for CustomReader's sub-block should be independent and shouldn't
  // be any other computation scope's child. Otherwise, data preprocessing and
  // computation cannot be concurrent.
  // It lives on the stack so that it is released on every exit path,
  // including the ones where an enforce or the executor throws.
  framework::Scope scope;
  // 1. Share LoDTensors from underlying reader's output with source variables.
  for (size_t i = 0; i < source_var_names_.size(); ++i) {
    framework::Variable* var = scope.Var(source_var_names_[i]);
    framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
    tensor->ShareDataWith(underlying_outs[i]);
    tensor->set_lod(underlying_outs[i].lod());
  }
  // 2. Run the sub-block.
  exe_.Run(program_, &scope, sub_block_id_, false, true);
  // 3. Copy LoDTensors from sink variables to out.
  out->resize(sink_var_names_.size());
  for (size_t i = 0; i < sink_var_names_.size(); ++i) {
    framework::Variable* var = scope.FindVar(sink_var_names_[i]);
    PADDLE_ENFORCE_NOT_NULL(var);
    const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
    framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
  }
}
|
||
} // namespace reader | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators::reader; | ||
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp, | ||
ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape, | ||
ops::CustomReaderInferVarType, | ||
paddle::framework::EmptyGradOpMaker) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import contextlib | ||
|
||
from .. import core | ||
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program | ||
|
@@ -21,7 +22,8 @@ | |
|
||
__all__ = [ | ||
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', | ||
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer' | ||
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', | ||
'Preprocessor' | ||
] | ||
|
||
|
||
|
@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None): | |
inputs={'UnderlyingReader': reader}, | ||
outputs={'Out': [new_reader]}, | ||
attrs=attrs) | ||
new_reader.persistable = True | ||
new_reader.stop_gradient = True | ||
return monkey_patch_reader_methods(new_reader) | ||
|
||
|
||
|
@@ -514,3 +514,81 @@ def read_file(file_obj): | |
return out[0] | ||
else: | ||
return out | ||
|
||
|
||
class Preprocessor(object):
    """
    Builds a custom reader that runs a user-defined preprocessing sub-block
    on the output of an existing reader.

    Usage::

        preprocessor = Preprocessor(reader)
        with preprocessor.block():
            ins = preprocessor.inputs()
            # ... append preprocessing operators working on `ins` ...
            preprocessor.outputs(*outs)
        new_reader = preprocessor()

    Args:
        reader: The underlying reader variable whose output is preprocessed.
        name: Optional name of the new reader variable. A unique name is
            generated when it is None.
    """

    # Life-cycle states of the preprocessor definition.
    BEFORE_SUB_BLOCK = 0
    IN_SUB_BLOCK = 1
    AFTER_SUB_BLOCK = 2

    def __init__(self, reader, name=None):
        self.underlying_reader = reader
        new_reader_name = name if name is not None else unique_name(
            "create_custom_reader")
        self.main_prog = default_main_program()
        self.reader = self.main_prog.current_block().create_var(
            name=new_reader_name)
        self.sub_block = None
        self.source_var_names = None
        self.sink_var_names = None
        self.status = Preprocessor.BEFORE_SUB_BLOCK

    def is_completed(self):
        """
        Returns True when the sub-block has been created and both `inputs()`
        and `outputs()` have recorded a non-empty list of variable names.
        """
        return bool(self.sub_block and self.source_var_names and
                    self.sink_var_names)

    @contextlib.contextmanager
    def block(self):
        """
        Context manager that opens the preprocessing sub-block.

        Operators appended inside the `with` body go into the sub-block.
        `inputs()` and `outputs()` must both be invoked inside it.

        Raises:
            RuntimeError: If the body finishes without the inputs and the
                outputs having been set.
        """
        self.status = Preprocessor.IN_SUB_BLOCK
        self.sub_block = self.main_prog.create_block()
        try:
            yield
        finally:
            # Always return to the parent block, even when the body raised,
            # so that the program is not left pointing at the sub-block.
            self.main_prog.rollback()
            self.status = Preprocessor.AFTER_SUB_BLOCK
        # Only reached when the body did not raise, so an error coming from
        # the body is never masked by the completeness check.
        if not self.is_completed():
            raise RuntimeError(
                "The definition of preprocessor is incompleted! "
                "Please make sure that you have set input and output "
                "variables by invoking 'inputs' and 'outputs' in "
                "Preprocessor's sub-block.")

    def inputs(self):
        """
        Creates one source variable per output of the underlying reader,
        using the shapes, dtypes and LoD levels recorded in the reader's
        desc, and returns them as a list.

        Raises:
            RuntimeError: If invoked outside the sub-block.
        """
        if self.status != Preprocessor.IN_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor.inputs() can only be invoked inside the sub-block."
            )

        source_shapes = self.underlying_reader.desc.shapes()
        source_dtypes = self.underlying_reader.desc.dtypes()
        source_lod_levels = self.underlying_reader.desc.lod_levels()
        self.source_var_names = [
            unique_name("preprocessor_source")
            for _ in range(len(source_shapes))
        ]
        current_block = self.main_prog.current_block()
        return [
            current_block.create_var(
                name=var_name, shape=shape, dtype=dtype, lod_level=lod_level)
            for var_name, shape, dtype, lod_level in zip(
                self.source_var_names, source_shapes, source_dtypes,
                source_lod_levels)
        ]

    def outputs(self, *outs):
        """
        Records the given variables as the sink variables of the sub-block.

        Raises:
            RuntimeError: If invoked outside the sub-block.
        """
        if self.status != Preprocessor.IN_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor.outputs() can only be invoked inside the sub-block."
            )
        self.sink_var_names = [var.name for var in outs]

    def __call__(self, *args, **kwargs):
        """
        Appends the `create_custom_reader` operator to the current block and
        returns the new reader variable with the reader methods patched in.
        The positional and keyword arguments are accepted but not used.

        Raises:
            RuntimeError: If invoked before the sub-block has been closed.
        """
        if self.status != Preprocessor.AFTER_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor output can only be retrieved after the "
                "preprocessor's sub-block.")

        self.main_prog.current_block().append_op(
            type="create_custom_reader",
            inputs={'UnderlyingReader': self.underlying_reader},
            outputs={'Out': [self.reader]},
            attrs={
                "sub_block": self.sub_block,
                "source_var_names": self.source_var_names,
                "sink_var_names": self.sink_var_names
            })
        return monkey_patch_reader_methods(self.reader)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sink_var_names.size() may not be equal to source_var_names.size().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Thanks!