CustomReader #10872

Merged: 11 commits, May 25, 2018
3 changes: 1 addition & 2 deletions paddle/fluid/framework/shape_inference.h
@@ -63,6 +63,7 @@ class InferShapeContext {

std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;

// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
@@ -81,8 +82,6 @@
const std::vector<std::string> &names) const;

virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;

virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
};

} // namespace framework
1 change: 1 addition & 0 deletions paddle/fluid/operators/reader/CMakeLists.txt
@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)

cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
# Export local libraries to parent
190 changes: 190 additions & 0 deletions paddle/fluid/operators/reader/create_custom_reader_op.cc
@@ -0,0 +1,190 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
const platform::Place& dev_place,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
program_(*sub_block.Program()),
sub_block_id_(sub_block.ID()),
exe_(framework::Executor(dev_place)),
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}

void ReadNext(std::vector<framework::LoDTensor>* out) override;

private:
const framework::ProgramDesc program_;
int sub_block_id_;
framework::Executor exe_;

std::vector<std::string> source_var_names_;
std::vector<std::string> sink_var_names_;
};

class CreateCustomReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;

private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
};

class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
protected:
void Apply() override {
AddAttr<framework::BlockDesc*>(
"sub_block", "The block to hold all preprocessing operators.");
AddAttr<std::vector<std::string>>(
"source_var_names",
"Source variables are starting points of data preprocessing. They hold "
"preprocessing's input tensors. Each source variable corresponds to "
"one of underlying reader's output datas.");
AddAttr<std::vector<std::string>>(
"sink_var_names",
"Sink variables are ending points of data preprocessing. They hold "
"preprocessing's output tensors. Each sink variable corresponds to "
"one of custom reader's output datas.");
AddComment(R"DOC(
CreateCustomReader Operator

A custom reader can be used for input data preprocessing.
It holds its own sub-block, which will be executed in its
'ReadNext()' function. Users can configure their own preprocessing
pipelines by inserting operators into the custom reader's sub-block.
)DOC");
}
};

class CustomReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'CustomReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
const auto* sub_block =
ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
const auto sink_var_names =
ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
std::vector<std::vector<int64_t>> res_dims;
std::vector<int32_t> res_lod_levels;
for (const std::string& var_name : sink_var_names) {
auto* sink_var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(sink_var);
res_dims.emplace_back(sink_var->GetShape());
res_lod_levels.push_back(sink_var->GetLoDLevel());
}
auto* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetShapes(res_dims);
out_reader->SetLoDLevels(res_lod_levels);
}
};

class CustomReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
PADDLE_ENFORCE_NOT_NULL(out_reader);
out_reader->SetType(framework::proto::VarType::READER);

auto sink_var_names =
boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
const auto* sub_block =
boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
std::vector<framework::proto::VarType::Type> res_data_types;
for (const std::string& var_name : sink_var_names) {
framework::VarDesc* var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
res_data_types.emplace_back(var->GetDataType());
}
out_reader->SetDataTypes(res_data_types);
}
};

void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs);
if (underlying_outs.empty()) {
// There is no next data.
return;
}
PADDLE_ENFORCE(
source_var_names_.size() == underlying_outs.size() &&
sink_var_names_.size() == underlying_outs.size(),
[Review comment, Collaborator] sink_var_names.size() may not be equal to source_var_names.size().

[Review comment, Author] Fixed. Thanks!

"The size of source_var_names(%d), the size of sink_var_names(%d) and "
"the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable.",
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// computation cannot run concurrently.
auto* scope = new framework::Scope();
[Review comment, Collaborator] There is no need to new a Scope(). Just

    Scope scope;

is fine.

[Review comment, Author] Good idea.

// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
exe_.Run(program_, scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
framework::Variable* var = scope->FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var);
[Review comment, Collaborator] Maybe fluid/operators/detail/safe_ref.h would be good here?

    auto& tensor = detail::Ref(scope->FindVar(sink_var_names_[i])).Get<framework::LoDTensor>();

[Review comment, Author] Cool!

const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
delete scope;
}

} // namespace reader
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp,
ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape,
ops::CustomReaderInferVarType,
paddle::framework::EmptyGradOpMaker)
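
The DOC comment above describes the workflow only in prose. Below is a minimal Python-side sketch of that workflow using the Preprocessor helper added later in this diff; the recordio file name, the shapes/dtypes, and the scale op are illustrative assumptions rather than part of this change.

    import paddle.fluid as fluid

    # Underlying reader; file name, shapes and dtypes are placeholders.
    reader = fluid.layers.open_recordio_file(
        filename='./mnist.recordio',
        shapes=[[-1, 784], [-1, 1]],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'])

    # Decorate it with a custom preprocessing sub-block.
    preprocessor = fluid.layers.Preprocessor(reader=reader)
    with preprocessor.block():
        img, lbl = preprocessor.inputs()   # source variables
        # Any fluid operators can be inserted here, e.g. rescaling the image.
        scaled_img = fluid.layers.scale(x=img, scale=1.0 / 255.0)
        # The label is passed through unchanged.
        preprocessor.outputs(scaled_img, lbl)  # sink variables

Each call to ReadNext() then copies the underlying reader's tensors into the source variables, runs the sub-block, and returns the sink variables' tensors.
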
1 change: 1 addition & 0 deletions paddle/fluid/operators/reader/reader_op_registry.cc
@@ -115,6 +115,7 @@ void DecoratedReaderInferShape::operator()(
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}

void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
84 changes: 81 additions & 3 deletions python/paddle/fluid/layers/io.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib

from .. import core
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program
@@ -21,7 +22,8 @@

__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer'
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'Preprocessor'
]


@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)


@@ -514,3 +514,81 @@ def read_file(file_obj):
return out[0]
else:
return out


class Preprocessor(object):
BEFORE_SUB_BLOCK = 0
IN_SUB_BLOCK = 1
AFTER_SUB_BLOCK = 2

def __init__(self, reader, name=None):
self.underlying_reader = reader
new_reader_name = name if name is not None else unique_name(
"create_custom_reader")
self.main_prog = default_main_program()
self.reader = self.main_prog.current_block().create_var(
name=new_reader_name)
self.sub_block = None
self.source_var_names = None
self.sink_var_names = None
self.status = Preprocessor.BEFORE_SUB_BLOCK

def is_completed(self):
return self.sub_block and self.source_var_names and self.sink_var_names

@contextlib.contextmanager
def block(self):
self.status = Preprocessor.IN_SUB_BLOCK
self.sub_block = self.main_prog.create_block()
yield
self.main_prog.rollback()
self.status = Preprocessor.AFTER_SUB_BLOCK
if not self.is_completed():
raise RuntimeError(
"The definition of preprocessor is incompleted! "
"Please make sure that you have set input and output "
"variables by invoking 'inputs' and 'outputs' in "
"Preprocessor's sub-block.")

def inputs(self):
if self.status != Preprocessor.IN_SUB_BLOCK:
raise RuntimeError(
"Preprocessor.inputs() can only be invoked inside the sub-block."
)

source_shapes = self.underlying_reader.desc.shapes()
source_dtypes = self.underlying_reader.desc.dtypes()
source_lod_levels = self.underlying_reader.desc.lod_levels()
self.source_var_names = []
source_vars = []
for idx in xrange(len(source_shapes)):
[Review comment, reyoung (Collaborator), May 25, 2018] Maybe a more Pythonic way:

    source_var_names = (unique_name("preprocessor_source")
                        for _ in xrange(len(source_shapes)))
    for var_name, shape, dtype, lod_level in zip(
            source_var_names, source_shapes, source_dtypes, source_lod_levels):
        pass

[Review comment, Author] Amazing!

self.source_var_names.append(unique_name("preprocessor_source"))
source_vars.append(self.main_prog.current_block().create_var(
name=self.source_var_names[-1],
shape=source_shapes[idx],
dtype=source_dtypes[idx],
lod_level=source_lod_levels[idx]))
return source_vars

def outputs(self, *outs):
if self.status != Preprocessor.IN_SUB_BLOCK:
raise RuntimeError(
"Preprocessor.outputs() can only be invoked inside the sub-block."
)
self.sink_var_names = [var.name for var in outs]

def __call__(self, *args, **kwargs):
if self.status != Preprocessor.AFTER_SUB_BLOCK:
raise RuntimeError(
"Preprocessor output can only be retrieved after rnn block.")

self.main_prog.current_block().append_op(
type="create_custom_reader",
inputs={'UnderlyingReader': self.underlying_reader},
outputs={'Out': [self.reader]},
attrs={
"sub_block": self.sub_block,
"source_var_names": self.source_var_names,
"sink_var_names": self.sink_var_names
})
return monkey_patch_reader_methods(self.reader)
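
A minimal sketch of how the reader returned by __call__ might then be consumed, reusing the read_file and double_buffer helpers already defined in this module (the double_buffer wrapping is optional and the variable names are illustrative):

    # Assumes `preprocessor` has been fully defined: sub-block created,
    # inputs() and outputs() already called inside preprocessor.block().
    custom_reader = preprocessor()                     # appends create_custom_reader
    custom_reader = fluid.layers.double_buffer(custom_reader)  # optional decoration
    img, lbl = fluid.layers.read_file(custom_reader)   # preprocessed output Variables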