[REFACTOR][TE][TIR] Call::Halide => ProducerLoad, DSL/TIR decouple. (#5743)

In the HalideIR design, DSL components and the IR are mixed together.
For example, Call::Halide can contain a reference to a function that is
constructed in the tensor expression language.

While this coupled design simplifies certain aspects of DSL construction,
it prevents the TIR from evolving into a clean, standalone IR:

- The additional tensor expression carried in the function is opaque to the IR
  and may become stale as we transform the IR.
- Duplicating information between the DSL tensor and the IR makes it hard to
  design a standalone text format (when elements are shared between the tensor
  expression and normal statements).

This PR aims to cleanly decouple the TIR from the high-level DSL structures (tensor expressions),
while still providing clear extension points for building DSLs on top of the TIR.

We introduce DataProducer as a base class for high-level tensor expression objects
that produce data. We then introduce ProducerLoad to replace the Call::Halide usage,
so that the Call node is always self-contained and reserved for low-level calls.

The high-level tensor expression DSL can still generate a PrimExpr that contains a ProducerLoad.
These PrimExprs contain fragments of information that can be combined together to
generate a low-level TIR PrimFunc.
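
For example, in the Python DSL (a minimal sketch; the names and shapes are illustrative, assuming the te/tir Python APIs touched by this commit):

    import tvm
    from tvm import te

    A = te.placeholder((16, 16), name="A")
    B = te.compute((16, 16), lambda i, j: A[i, j] + 1.0, name="B")

    # The A[i, j] access inside B's compute body is now a tir.ProducerLoad
    # whose producer is the tensor A, rather than a Call with
    # call_type == Call::Halide and func == A.op.
    load = B.op.body[0].a
    assert isinstance(load, tvm.tir.ProducerLoad)
    assert load.producer.same_as(A)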

We also state clearly that a DataProducer **should not** appear in any TIR PrimFunc.
Instead, the high-level DSL layer should lower DataProducers to Buffers and to the TIR
statements that produce these buffers. We can further provide verification passes to
validate this invariant.
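
One hypothetical way to check this invariant after lowering, continuing the sketch above (not part of this PR; it assumes tvm.lower in this era returns an IRModule whose "main" PrimFunc is fully lowered):

    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B])

    def _no_producer_load(node):
        # every ProducerLoad must have been rewritten into buffer accesses
        assert not isinstance(node, tvm.tir.ProducerLoad)

    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _no_producer_load)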

Changes:
- Introduce DataProducer to serve as a base class for Tensor in tensor expressions.
- Migrate use of Call::Halide to ProducerLoad
- Migrate the other usages of Calls.

We will also create follow-up PRs to migrate the remaining two DSL-related IR nodes (Realize/Provide)
to use DataProducer.
tqchen committed Jun 7, 2020
1 parent 7053546 commit 6ae439c
Showing 44 changed files with 519 additions and 363 deletions.
15 changes: 11 additions & 4 deletions include/tvm/te/tensor.h
@@ -49,11 +49,11 @@ class OperationNode;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public ObjectRef {
class Tensor : public DataProducer {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
explicit Tensor(ObjectPtr<Object> n) : DataProducer(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
@@ -157,7 +157,7 @@ class Operation : public tir::FunctionRef {
};

/*! \brief Node to represent a tensor */
class TensorNode : public Object {
class TensorNode : public DataProducerNode {
public:
/*! \brief The shape of the tensor */
Array<PrimExpr> shape;
@@ -176,10 +176,17 @@ class TensorNode : public Object {
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}

Array<PrimExpr> GetShape() const final { return shape; }

DataType GetDataType() const final { return dtype; }

TVM_DLL String GetNameHint() const final;

TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);

static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
};

// Implementations of inline functions
55 changes: 55 additions & 0 deletions include/tvm/tir/buffer.h
@@ -203,6 +203,61 @@ inline const BufferNode* Buffer::operator->() const {
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "buffer");

/*!
* \brief Base node for data producers.
*
 * A DataProducer stores necessary information (e.g. a tensor expression) to produce
* a multi-dimensional array. The stored information is opaque to the TIR.
* DataProducer can appear in high-level DSLs that are built on top of the TIR.
*
 * A valid TIR PrimFunc should not contain any DataProducer; high-level DSLs should lower
* all DataProducers to Buffers before TIR transformations.
*
* \sa tvm::te::Tensor
*/
class DataProducerNode : public Object {
public:
/*! \brief destructor. */
virtual ~DataProducerNode() {}
/*!
* \brief Get the shape of the result.
* \return The shape.
*/
virtual Array<PrimExpr> GetShape() const = 0;
/*!
* \brief Get the data type of the result.
* \return The data type.
*/
virtual DataType GetDataType() const = 0;
/*!
* \brief Get the name hint of the data producer.
 * \return The name hint.
*/
virtual String GetNameHint() const = 0;

bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
// because buffer producer is opaque, we just do pointer equality.
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const char* _type_key = "DataProducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
};

/*!
* \brief Managed reference to DataProducerNode.
* \sa DataProducerNode
*/
class DataProducer : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
};

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_BUFFER_H_
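
Since te::Tensor now derives from DataProducer (see the include/tvm/te/tensor.h hunk above), the subtyping is visible on the Python side as well; a minimal sketch:

    import tvm
    from tvm import te

    A = te.placeholder((8,), name="A")
    # te.Tensor is registered as a subclass of tir.DataProducer, so
    # DSL-level utilities can handle all producers uniformly.
    assert isinstance(A, tvm.tir.DataProducer)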
81 changes: 55 additions & 26 deletions include/tvm/tir/expr.h
@@ -449,12 +449,64 @@ class BufferLoadNode : public PrimExprNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to BufferLoadNode.
* \sa BufferLoadNode
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
};

/*!
* \brief Load value from the result produced by the producer.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa ProducerLoad, DataProducerNode
*/
class ProducerLoadNode : public PrimExprNode {
public:
/*! \brief The buffer producer. */
DataProducer producer;
/*! \brief The location arguments. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("producer", &producer);
v->Visit("indices", &indices);
}

bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(producer, other->producer) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(producer);
hash_reduce(indices);
}

static constexpr const char* _type_key = "ProducerLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to ProducerLoadNode.
* \sa ProducerLoadNode
*/
class ProducerLoad : public PrimExpr {
public:
TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
};

/*!
* \brief Load the value from buffer_var.
*
@@ -661,11 +713,6 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*!
* \brief Halide-style call, evaluates func(args).
* \note Deprecated, move to BufferLoad in the future.
*/
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
@@ -677,49 +724,31 @@
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
/*!
* \brief The function to be called.
* \note Deprecated, move to BufferLoad in the future.
*/
FunctionRef func;
/*!
* \brief The output value index if func's value is a tuple.
* \note Deprecated, move to BufferLoad in the future.
*/
int value_index{0};

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
v->Visit("func", &func);
v->Visit("value_index", &value_index);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
equal(call_type, other->call_type) && equal(func, other->func) &&
equal(value_index, other->value_index);
equal(call_type, other->call_type);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(name);
hash_reduce(args);
hash_reduce(call_type);
hash_reduce(func);
hash_reduce(value_index);
}

TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array<PrimExpr> args,
CallType call_type, FunctionRef func = FunctionRef(),
int value_index = 0);
CallType call_type);

/*! \return Whether call node is pure. */
bool is_pure() const {
return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide);
}
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }

/*!
* \return Whether call node corresponds to a defined intrinsic.
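
With CallType::Halide and the func/value_index fields removed, a Call is fully described by its dtype, name, args, and call_type. A minimal sketch of the new four-argument construction on the Python side (the extern function name is hypothetical; compare the python/tvm/target/datatype.py hunk below):

    import tvm

    x = tvm.tir.Var("x", "float32")
    call = tvm.tir.Call("float32", "my_extern_fn", [x], tvm.tir.Call.Extern)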
4 changes: 4 additions & 0 deletions include/tvm/tir/expr_functor.h
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -163,6 +164,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
IR_EXPR_FUNCTOR_DISPATCH(AddNode);
@@ -213,6 +215,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const ProducerLoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const AddNode* op) override;
@@ -258,6 +261,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override;
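
The ProducerLoad dispatch hooks above give lowering passes a place to intercept producer accesses. A minimal sketch, in Python, of rewriting ProducerLoads into BufferLoads via stmt_functor (the producer-to-buffer map is a hypothetical input assumed to be built by the DSL layer; compare the python/tvm/te/hybrid/util.py hunk below):

    import tvm

    def lower_producers(stmt, producer_to_buffer):
        """Rewrite every ProducerLoad in stmt into a BufferLoad (sketch)."""
        def rewrite(op):
            if isinstance(op, tvm.tir.ProducerLoad):
                buf = producer_to_buffer[op.producer]
                return tvm.tir.BufferLoad(buf, op.indices)
            return None  # leave other nodes unchanged
        return tvm.tir.stmt_functor.ir_transform(
            stmt, None, rewrite, ["ProducerLoad"])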
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/task.py
@@ -495,11 +495,11 @@ def _count_flop(exp):
if isinstance(exp, expr.Select):
return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
_count_flop(exp.false_value))
if isinstance(exp, expr.Call):
if exp.call_type == expr.Call.Halide:
# Ignore flops from indexing expressions.
return 0
if isinstance(exp, expr.ProducerLoad):
# Ignore flops from indexing expressions.
return 0

if isinstance(exp, expr.Call):
return sum([_count_flop(x) for x in exp.args])

raise FlopCalculationError("Found unsupported operator in the compute expr")
6 changes: 3 additions & 3 deletions python/tvm/target/datatype.py
@@ -88,7 +88,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
op_name : str
The name of the operation which the function computes, given by its
Halide::Internal class name (e.g. Add, LE, Cast).
class name (e.g. Add, LE, Cast).
target : str
The name of codegen target.
@@ -136,8 +136,8 @@ def lower(op):
dtype += "x" + str(t.lanes)
if isinstance(op, (_Cast, _FloatImm)):
return _Call(dtype, extern_func_name, convert([op.value]),
_Call.Extern, None, 0)
_Call.Extern)
return _Call(dtype, extern_func_name, convert([op.a, op.b]),
_Call.Extern, None, 0)
_Call.Extern)

return lower
8 changes: 3 additions & 5 deletions python/tvm/te/hybrid/parser.py
@@ -272,8 +272,7 @@ def visit_Name(self, node):
return entry if isinstance(node.ctx, ast.Load) else None
if ty is Symbol.BufferVar:
if isinstance(node.ctx, ast.Load):
return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \
_expr.Call.Halide, entry.op, entry.value_index)
return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')])
return entry, [tvm.runtime.const(0, 'int32')]
# Do I need any assertion here?
return entry
@@ -305,7 +304,7 @@ def visit_AugAssign(self, node):
args = [tvm.runtime.const(0, 'int32')]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
read = tvm.tir.ProducerLoad(buf, args)
value = HybridParser._binop_maker[type(node.op)](read, rhs)

return tvm.tir.Provide(buf.op, 0, value, args)
@@ -392,8 +391,7 @@ def visit_Subscript(self, node):
arr = arr[i.value]
return arr
if isinstance(node.ctx, ast.Load):
return tvm.tir.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, arr.op, arr.value_index)
return tvm.tir.ProducerLoad(arr, args)
return arr, args

def visit_With(self, node):
7 changes: 3 additions & 4 deletions python/tvm/te/hybrid/util.py
@@ -78,10 +78,9 @@ def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
if isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _expr.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys():
buf = rmap[op.producer.op]
return _expr.ProducerLoad(buf, op.indices)
return None

return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
9 changes: 4 additions & 5 deletions python/tvm/te/tensor.py
@@ -19,7 +19,7 @@
import tvm._ffi

from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr
from tvm.tir import expr as _expr, DataProducer

from . import _ffi_api

@@ -52,7 +52,7 @@ class TensorIntrinCall(Object):


@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
class Tensor(DataProducer, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""

def __call__(self, *indices):
@@ -69,9 +69,8 @@ def __call__(self, *indices):
else:
raise ValueError("The indices must be expression")

return _expr.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
return _expr.ProducerLoad(self, args)


def __getitem__(self, indices):
return TensorSlice(self, indices)
4 changes: 2 additions & 2 deletions python/tvm/tir/__init__.py
@@ -19,12 +19,12 @@
from tvm.ir import PrimExpr
from tvm.runtime import const

from .buffer import Buffer, decl_buffer
from .buffer import Buffer, decl_buffer, DataProducer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
5 changes: 5 additions & 0 deletions python/tvm/tir/buffer.py
@@ -245,3 +245,8 @@ def decl_buffer(shape,
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)


@tvm._ffi.register_object
class DataProducer(Object):
pass
