Skip to content

Commit

Permalink
[NNVM] Move FTVMCompute registration of the elementwise operator to c…
Browse files Browse the repository at this point in the history
…++ (#1351)
  • Loading branch information
nishi-t authored and tqchen committed Jun 29, 2018
1 parent c9c031a commit ebdde3c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 39 deletions.
39 changes: 0 additions & 39 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,69 +182,30 @@ def compute_cast(attrs, inputs, _):
reg.register_schedule("clip", _fschedule_elemwise)

# elemwise sum
@reg.register_compute("elemwise_sum")
def compute_elemwise_sum(attrs, inputs, _):
"""Compute definition of elemwise sum"""
num_args = attrs.get_int("num_args")
assert num_args == len(inputs), "Number of tensors does not match num_args."
return topi.tensor.elemwise_sum(inputs)
reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
reg.register_schedule("elemwise_sum", _fschedule_elemwise)

# full
@reg.register_compute("full")
def compute_full(attrs, inputs, _):
"""Compute definition of full"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
fill_value = attrs.get_float("fill_value")
return topi.tensor.full(shape, dtype, fill_value)
reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("full", _fschedule_elemwise)

# full_like
@reg.register_compute("full_like")
def compute_full_like(attrs, inputs, _):
"""Compute definition of full_like"""
fill_value = attrs.get_float("fill_value")
return topi.tensor.full_like(inputs[0], fill_value)
reg.register_pattern("full_like", OpPattern.ELEMWISE)
reg.register_schedule("full_like", _fschedule_elemwise)

# zeros
@reg.register_compute("zeros")
def compute_zeros(attrs, inputs, _):
"""Compute definition of zeros"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
return topi.tensor.full(shape, dtype, 0)
reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("zeros", _fschedule_elemwise)

# zeros_like
@reg.register_compute("zeros_like")
def compute_zeros_like(_, inputs, out_info):
"""Compute definition of zeros_like"""
return topi.tensor.full_like(inputs[0], 0)
reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
reg.register_schedule("zeros_like", _fschedule_elemwise)

# ones
@reg.register_compute("ones")
def compute_ones(attrs, inputs, _):
"""Compute definition of ones"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
#tvm.tensor.Tensor()
return topi.tensor.full(shape, dtype, 1)
reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("ones", _fschedule_elemwise)

# ones_like
@reg.register_compute("ones_like")
def compute_ones_like(_, inputs, out_info):
"""Compute definition of ones_like"""
return topi.tensor.full_like(inputs[0], 1)
reg.register_pattern("ones_like", OpPattern.ELEMWISE)
reg.register_schedule("ones_like", _fschedule_elemwise)

Expand Down
62 changes: 62 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <cmath>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
#include "topi/elemwise.h"
#include "topi/tags.h"
#include "../../compiler/compile_engine.h"

namespace nnvm {
namespace top {
Expand Down Expand Up @@ -382,6 +384,16 @@ NNVM_REGISTER_INIT_OP(full)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpWithScalarParam& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, param.fill_value);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);

NNVM_REGISTER_INIT_OP(zeros)
Expand All @@ -395,6 +407,16 @@ NNVM_REGISTER_INIT_OP(zeros)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 0);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);

NNVM_REGISTER_INIT_OP(ones)
Expand All @@ -408,6 +430,16 @@ NNVM_REGISTER_INIT_OP(ones)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 1);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);

// full_like
Expand All @@ -419,20 +451,42 @@ as the input array
.add_arguments(FillValueParam::__FIELDS__())
.set_attr_parser(ParamParser<FillValueParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const FillValueParam& param = nnvm::get<FillValueParam>(attrs.parsed);
const Expr fill_value = tvm::make_const(out_info[0]->dtype, param.fill_value);
return Array<Tensor> { topi::full_like(inputs[0], fill_value) };
})
.set_support_level(4);

NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
.describe(R"code(Return an array of zeros with the same shape and type
as the input array.
)code")
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor> { topi::full_like(inputs[0],
tvm::make_const(out_info[0]->dtype, 0)) };
})
.set_support_level(4);

NNVM_REGISTER_INIT_LIKE_OP(ones_like)
.describe(R"code(Return an array of ones with the same shape and type
as the input array.
)code")
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor> { topi::full_like(inputs[0],
tvm::make_const(out_info[0]->dtype, 1)) };
})
.set_support_level(4);

// unary scalar op
Expand Down Expand Up @@ -684,6 +738,14 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
.describe(R"code(Adds all input arguments element-wise.
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ElementWiseReduceParam& param = nnvm::get<ElementWiseReduceParam>(attrs.parsed);
CHECK_EQ(param.num_args, inputs.size()) << """Compute definition of elemwise sum""";
return Array<Tensor>{ topi::elemwise_sum(inputs) };
})
.set_attr<nnvm::FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
Expand Down

0 comments on commit ebdde3c

Please sign in to comment.