Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

operator部分代码设计 #300

Closed
RayLiu2015 opened this issue May 28, 2018 · 0 comments
Closed

operator部分代码设计 #300

RayLiu2015 opened this issue May 28, 2018 · 0 comments

Comments

@RayLiu2015
Copy link
Collaborator

RayLiu2015 commented May 28, 2018

Operator 设计

一个operator 可分为两层:

  • 一层为 op 层, 包括参数获取包装成param结构体(传递给kernel层的参数结构体)和InferShape的操作.
  • 另一层为 kernel 层, 包含了op的具体运算部分, kernel层为可特化到具体平台实现

每一个op需要对应着一段op注册代码, 用于上层在实例化op时使用

op层

template <typename Dtype>
class OperatorBase {
 public:
  /*
   *  @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor
   * */
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
  void Run() const;

  std::vector<string> GetOutKeys() const;

  virtual void RunImpl() const = 0;

  virtual void Init() = 0;
  /*
   * @b op 运算所需的输入, 如上一层的输出结果、卷积核
   * */
  const VariableNameMap &Inputs() const { return inputs_; }
  /*
   * @b op 的输出, 内存会提前被分配好, 运算结果会被存到分配好的内存内
   * */
  const VariableNameMap &Outputs() const { return outputs_; }
  /*
   * @b op 类型
   * */
  const std::string &Type() const { return type_; }

  /*
   * @b 根据输入形状和参数计算出输出形状
   * */
  virtual void InferShape() const = 0;

 protected:
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

 private:
  void CheckAllInputOutputSet() const;
};
/*
 * @b 这个类为所有带有运算的 op 的父类, 这个 op 继承与 OperatorBase
 * */
template <typename Dtype, typename ParamType, typename KernelType>
class OperatorWithKernel : public OperatorBase<Dtype> {
 public:
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
        param_(inputs, outputs, attrs, *scope) {}

  virtual void RunImpl() const { this->kernel_.Compute(this->param_); }

  virtual void InferShape() const = 0;

  void Init() {
   // op 实现者可以重写该方法, 对参数进行预处理
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
                          this->type_.c_str());
  }

 protected:
  KernelType kernel_;
  ParamType param_;
};

kernel 层

/*
 * @b 所有kernel的父类
 * */
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  /*
   * @b 所有kernel 需实现 Compute 方法
   * @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体,
   *    所有结构体存在与: paddle-mobile/src/operators/op_param.h
   * */
  virtual void Compute(const P &para) const = 0;
  virtual bool Init(P *para) { return true; };
  virtual ~OpKernelBase() = default;
};

例子: 一个 relu 的实现

//relu_op.h

template <typename DeviceType, typename T>
class ReluOp
    : public framework::OperatorWithKernel<
          DeviceType, ReluParam, operators::ReluKernel<DeviceType, T>> {
 public:
  /*
   * @b op 的实例化方法, 需要调用父类的实例化方法, 以及实例化自己的参数结构体
   * */
  ReluOp(const std::string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
         std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<DeviceType, ReluParam,
                                      operators::ReluKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}

  using framework::OperatorWithKernel<
      DeviceType, ReluParam,
      operators::ReluKernel<DeviceType, T>>::OperatorWithKernel;
  void InferShape() const override;

 protected:
};

.cpp 中给出了 InferShape 的实现, 和op注册部分

//relu_op.cpp
namespace paddle_mobile {
namespace operators {

template <typename Dtype, typename T>
void ReluOp<Dtype, T>::InferShape() const {
  auto input_dims = param_.InputX()->dims();
  param_.Out()->Resize(input_dims);
}
template class ReluOp<CPU, float>;
}  // namespace operators
}  // namespace paddle_mobile

/*
 * @b 每一个 op 都需要注册一下的,
 *    USE_OP的参数 和 REGISTER_OPERATOR的第一个参数 都是需要和model中类型对应起来的
 * */
namespace ops = paddle_mobile::operators;
USE_OP(relu);
REGISTER_OPERATOR(relu, ops::ReluOp);

ReluParam 用于包装参数的结构体

//op_param.h
/*
 * @b op 层实例化好这个 param 传递给 kernel 层使用
 * */
class ReluParam : public OpParam {
 public:
  ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<Tensor>(inputs, scope);
    out_ = OutFrom<Tensor>(outputs, scope);
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }
 private:
  Tensor *input_x_;
  Tensor *out_;
};

kernel 层声明

template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
 public:
  void Compute(const ReluParam& param) const;
  bool Init(ReluParam* param);
};

特化到 arm 平台

/*
 * @b 特化到具体平台的实现, param 从 op 层传入
 * */
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
  // 进行一些预处理操作
  return true;
}

/*
 * @b 特化到具体平台的实现, param 从 op 层传入
 * */
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
      // arm 汇编实现 ...
}

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants