diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md index ff7408111fa..e03ccf6ccad 100644 --- a/doc/fluid/dev/new_op_cn.md +++ b/doc/fluid/dev/new_op_cn.md @@ -150,8 +150,9 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dim0 = ctx.Input("X")->dims(); - auto dim1 = ctx.Input("Y")->dims(); + // Never use Input<Tensor> or Output<Tensor> if you want to get a LoDTensor. + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("Y")->dims(); PADDLE_ENFORCE_EQ(dim0.size(), 2, "input X(%s) should be a tensor with 2 dims, a matrix", ctx.op_.Input("X")); @@ -161,7 +162,7 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( dim0[1], dim1[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + ctx.Output("Out")->Resize({dim0[0], dim1[1]}); } }; ``` @@ -201,6 +202,8 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, - 与`InferShapeContext`相比,`ExecutionContext`增加了设备类型,同样可获取到输入输出和属性参数。 - `Compute`函数里实现`OpKernel`的具体计算逻辑。 +Op的输入和输出可分别通过`ExecutionContext::Input()`和`ExecutionContext::Output()`获得。注意:若op的输入/输出的变量类型是`LoDTensor`(fluid中所有的Tensor默认都是`LoDTensor`类型),请写成`ExecutionContext::Input<LoDTensor>()`和`ExecutionContext::Output<LoDTensor>()`,不要写`ExecutionContext::Input<Tensor>()`和`ExecutionContext::Output<Tensor>()`。因为若实际的变量类型为`SelectedRows`,`Input<Tensor>()`和`Output<Tensor>()`方法会将`SelectedRows`类型特化为`Tensor`,导致潜在的错误。 + 下面是 `MulKernel` `Compute`的实现: ```cpp @@ -208,9 +211,9 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* Z = context.Output("Out"); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); Z->mutable_data(context.GetPlace()); auto& device_context = context.template
device_context(); math::matmul(*X, false, *Y, false, 1, Z, 0, device_context);