Skip to content

Commit

Permalink
[REFACTOR][API-Change] Migrate all Object construction to constructor.
Browse files Browse the repository at this point in the history
This PR migrates all the remaining object constructions to the new constructor style
that is consistent with the rest of the codebase and changes the affected files accordingly.

Other changes:

- ThreadScope::make -> ThreadScope::Create
- StorageScope::make -> StorageScope::Create
  • Loading branch information
tqchen committed Jun 12, 2020
1 parent 54bde85 commit d2c9506
Show file tree
Hide file tree
Showing 49 changed files with 469 additions and 416 deletions.
2 changes: 1 addition & 1 deletion docs/dev/codebase_walkthrough.rst
Expand Up @@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
::

inline Schedule create_schedule(Array<Operation> ops) {
return ScheduleNode::make(ops);
return Schedule(ops);
}

``Schedule`` consists of collections of ``Stage`` and output ``Operation``.
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/relay_add_pass.rst
Expand Up @@ -138,7 +138,7 @@ is shown below.
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
return TupleGetItem(t, g->index);
}
}
Expand Down
10 changes: 5 additions & 5 deletions docs/dev/relay_pass_infra.rst
Expand Up @@ -344,13 +344,13 @@ registration.
.. code:: c++

// Create a simple Relay program.
auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto tensor_type = relay::TensorType({}, tvm::Bool());
auto x = relay::Var("x", relay::Type());
auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto y = relay::Var("y", tensor_type);
auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = IRModule::FromExpr(fx);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Expand Up @@ -97,14 +97,14 @@ class SpanNode : public Object {
equal(col_offset, other->col_offset);
}

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};

class Span : public ObjectRef {
public:
TVM_DLL Span(SourceName source, int lineno, int col_offset);

TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};

Expand Down
90 changes: 75 additions & 15 deletions include/tvm/te/operation.h
Expand Up @@ -177,12 +177,22 @@ class PlaceholderOpNode : public OperationNode {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name, Array<PrimExpr> shape, DataType dtype);

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};

/*!
* \brief Managed reference to PlaceholderOpNode
* \sa PlaceholderOpNode
*/
class PlaceholderOp : public Operation {
public:
TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};

/*!
* \brief A Compute op that compute a tensor on certain domain.
* This is the base class for ComputeOp (operating on a scalar at a time) and
Expand Down Expand Up @@ -237,13 +247,23 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to ComputeOpNode
* \sa ComputeOpNode
*/
class ComputeOp : public Operation {
public:
TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};

/*!
* \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
*/
Expand Down Expand Up @@ -285,15 +305,25 @@ class TensorComputeOpNode : public BaseComputeOpNode {
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to TensorComputeOpNode
* \sa TensorComputeOpNode
*/
class TensorComputeOp : public Operation {
public:
TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};

/*!
* \brief Symbolic scan.
*/
Expand Down Expand Up @@ -353,14 +383,24 @@ class ScanOpNode : public OperationNode {
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
IterVar axis, Array<Tensor> init, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> input);

static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};

/*!
* \brief Managed reference to ScanOpNode
* \sa ScanOpNode
*/
class ScanOp : public Operation {
public:
TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
Array<Tensor> input);

TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};

/*!
* \brief External computation that cannot be splitted.
*/
Expand Down Expand Up @@ -404,14 +444,24 @@ class ExternOpNode : public OperationNode {
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};

/*!
* \brief Managed reference to ExternOpNode
* \sa ExternOpNode
*/
class ExternOp : public Operation {
public:
TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};

/*!
* \brief A computation operator that generated by hybrid script.
*/
Expand Down Expand Up @@ -459,13 +509,23 @@ class HybridOpNode : public OperationNode {
v->Visit("axis", &axis);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};

/*!
* \brief Managed reference to HybridOpNode
* \sa HybridOpNode
*/
class HybridOp : public Operation {
public:
TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};

/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
Expand Down
68 changes: 51 additions & 17 deletions include/tvm/te/schedule.h
Expand Up @@ -277,6 +277,12 @@ class Schedule : public ObjectRef {
public:
Schedule() {}
explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
Expand Down Expand Up @@ -553,13 +559,6 @@ class ScheduleNode : public Object {
*/
TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL static Schedule make(Array<Operation> ops);

static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
Expand All @@ -569,7 +568,7 @@ class ScheduleNode : public Object {
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
inline Schedule create_schedule(Array<Operation> ops) { return ScheduleNode::make(ops); }
inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
Expand Down Expand Up @@ -648,13 +647,21 @@ class SplitNode : public IterVarRelationNode {
v->Visit("nparts", &nparts);
}

static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts);

static constexpr const char* _type_key = "Split";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SplitNode
* \sa SplitNode
*/
class Split : public IterVarRelation {
public:
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};

/*!
* \brief Fuse two domains into one domain.
*/
Expand All @@ -673,12 +680,21 @@ class FuseNode : public IterVarRelationNode {
v->Visit("fused", &fused);
}

static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused);

static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to FuseNode
* \sa FuseNode
*/
class Fuse : public IterVarRelation {
public:
TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};

/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
Expand All @@ -696,12 +712,21 @@ class RebaseNode : public IterVarRelationNode {
v->Visit("rebased", &rebased);
}

static IterVarRelation make(IterVar parent, IterVar rebased);

static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to RebaseNode
* \sa RebaseNode
*/
class Rebase : public IterVarRelation {
public:
TVM_DLL Rebase(IterVar parent, IterVar rebased);

TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
};

/*!
* \brief Singleton iterator [0, 1)
*/
Expand All @@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode {

void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }

static IterVarRelation make(IterVar iter);

static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SingletonNode
* \sa SingletonNode
*/
class Singleton : public IterVarRelation {
public:
TVM_DLL explicit Singleton(IterVar iter);

TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
Expand Down

0 comments on commit d2c9506

Please sign in to comment.