Skip to content

Commit

Permalink
merge baidu/develop
Browse files Browse the repository at this point in the history
  • Loading branch information
QiJune committed Jul 17, 2017
2 parents 2a03e38 + a0caf23 commit 8718966
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 38 deletions.
11 changes: 11 additions & 0 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,32 @@ class OpRegistry {
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
const OpProto& op_proto = protos().at(op_type);
// set op's inputs_ from desc.
op->type_ = op_desc.type();
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
// set op's attr;
for (auto& attr : op_desc.attrs()) {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
// set argument offsets stored in op.
CreateInOutOffsetMap(op, op_proto);
op->Init();
return op;
}

// init op.in_out_idxs_ to accelerate argument's offset lookup.
static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
op->CreateInOutOffsetMap(proto);
}

static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
Expand Down
64 changes: 61 additions & 3 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,83 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

template <>
Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice<
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}

#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice<
platform::GPUPlace, Eigen::GpuDevice>() const {
Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif

void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap");
for (int i = 0; i < proto.inputs_size(); i++) {
const auto& name = proto.inputs()[i].name();
in_out_idxs_[name] = i;
}
for (int i = 0; i < proto.outputs_size(); i++) {
const auto& name = proto.outputs()[i].name();
in_out_idxs_[name] = i;
}
}

const std::string& OperatorBase::Input(const std::string& name) const {
auto it = in_out_idxs_.find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);

if (attrs_.count("input_format") == 0) {
return inputs_[it->second];
} else {
const auto& input_format = GetAttr<std::vector<int>>("input_format");
int idx = input_format[it->second];
return inputs_.at(idx);
}
}

std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_.at(name);

return std::vector<std::string>{
inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
}

const std::string& OperatorBase::Output(const std::string& name) const {
auto it = in_out_idxs_.find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);

if (attrs_.count("output_format") == 0) {
return outputs_[it->second];
} else {
const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second];
return outputs_.at(idx);
}
}

std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_.at(name);

return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
}

std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
Expand Down
98 changes: 72 additions & 26 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
Expand Down Expand Up @@ -77,11 +79,79 @@ class OperatorBase {
virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0;

// Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;

// init in_out_idxs_ to accelerate argument's offset lookup.
void CreateInOutOffsetMap(const OpProto& proto);

public:
std::string type_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
// store the arguments' offset described in op_desc.
std::unordered_map<std::string, int> in_out_idxs_;
};

class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const Variable* Input(const std::string& name) const {
return scope_->GetVariable(op_.Input(name));
}

const Variable* Output(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}

const std::vector<const Variable*> Inputs(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}

const std::vector<const Variable*> Outputs(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;

platform::Place GetPlace() const { return device_context_.GetPlace(); }

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};

class OpKernel {
Expand All @@ -92,31 +162,6 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const ScopePtr& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;

platform::Place GetPlace() const { return device_context_.GetPlace(); }

const OperatorBase& op_;
const ScopePtr& scope_;
const platform::DeviceContext& device_context_;
};

virtual void Compute(const KernelContext& context) const = 0;

Expand Down Expand Up @@ -162,14 +207,15 @@ class OperatorWithKernel : public OperatorBase {
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
}

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}

void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
Expand Down
Loading

0 comments on commit 8718966

Please sign in to comment.