[SCHEDULE] Mutate dataflow in schedule, refactor Stage
tqchen committed Feb 15, 2017
1 parent 820a859 commit 11f300f
Showing 14 changed files with 433 additions and 95 deletions.
22 changes: 11 additions & 11 deletions include/tvm/operation.h
@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
-Tensor Placeholder(Array<Expr> shape,
+Tensor placeholder(Array<Expr> shape,
                    Type dtype = Float(32),
                    std::string name = "placeholder");

@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scan over scan_axis.
@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");

// same as compute, specialized for different fcompute function
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                      std::function<Expr(Var)> f,
                      std::string name = "tensor") {
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
}
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                      std::function<Expr(Var, Var)> f,
                      std::string name = "tensor") {
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
}
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                      std::function<Expr(Var, Var, Var)> f,
                      std::string name = "tensor") {
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
}
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                      std::function<Expr(Var, Var, Var, Var)> f,
                      std::string name = "tensor") {
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
}

} // namespace tvm
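For reference, a minimal sketch of how the renamed lowercase API reads from the Python bindings (tvm.var, the tensor names, and the shapes are illustrative assumptions, not part of this diff):

import tvm

# placeholder declares a symbolic data input; compute builds a tensor from
# a per-index expression, matching the FCompute overloads above.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2, name="B")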
55 changes: 48 additions & 7 deletions include/tvm/schedule.h
@@ -179,6 +179,28 @@ class Schedule : public NodeRef {
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
+  /*!
+   * \brief create a cache read of original tensor for readers.
+   *  This will mutate the body of the readers.
+   *  A new stage will be created for the tensor.
+   * \param tensor The tensor cached.
+   * \param scope The scope of the cache.
+   * \param readers The readers to redirect to the tensor.
+   * \return The created tensor.
+   */
+  Tensor cache_read(const Tensor& tensor,
+                    const std::string& scope,
+                    const Array<Operation>& readers);
+  /*!
+   * \brief Create a cache write tensor for producing tensor.
+   *  The tensor will take over the body of the original tensor op.
+   *  The original tensor's body will be changed to an identity read
+   *  from the corresponding cache.
+   * \param tensor The tensor to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensor.
+   */
+  Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
@@ -193,6 +215,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline ScheduleNode* operator->();
// declare container type
using ContainerType = ScheduleNode;
};
@@ -244,8 +271,14 @@ class IterVarAttr : public NodeRef {
*/
class StageNode : public Node {
public:
-  /*! \brief The operation to be scheduled */
+  /*! \brief The operation of stage, can be different from original op. */
   Operation op;
+  /*!
+   * \brief The original operator, can be null
+   *  If defined, records the operation in original data flow.
+   *  This means the stage's op has been changed.
+   */
+  Operation origin_op;
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief All the nodes in the iter var */
@@ -265,17 +298,21 @@ class StageNode : public Node {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
+  /*! \brief Whether this is an output stage */
+  bool is_output{false};

void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
}

static constexpr const char* _type_key = "Stage";
@@ -285,18 +322,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
-  /*! \brief The root operations */
-  Array<Operation> roots;
+  /*! \brief The output operations in original data flow graph */
+  Array<Operation> outputs;
/*!
-   * \brief list of all stages for non-placeholder ops
-   *  The stage are ordered in PostDFS order of their op.
+   * \brief list of all stages for non-placeholder ops.
+   *  The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
@@ -412,12 +449,16 @@ inline StageNode* Stage::operator->() {

inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
-  return !(n->relations.empty() && n->attach_type == kNone);
+  return !(n->relations.empty() && n->attach_type == kNone &&
+           n->all_iter_vars.same_as(n->leaf_iter_vars));
}

inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
+inline ScheduleNode* Schedule::operator->() {
+  return static_cast<ScheduleNode*>(node_.get());
+}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
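To make the new primitives concrete, a sketch of cache_read driven from Python (the tvm.Schedule constructor, the tensor names, and the "shared" scope string are illustrative assumptions, not taken from this commit):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n, n), name="A")
B = tvm.compute((n, n), lambda i, j: A[i][j] + 1, name="B")

s = tvm.Schedule(B.op)
# Redirect B's reads of A through a cache in "shared" scope. Per the doc
# comment above, this mutates B's body and creates a new stage for the
# cache tensor in the schedule's dataflow.
AA = s.cache_read(A, "shared", [B])

This mutation is what motivates the new origin_op and is_output fields on StageNode: a stage whose op was rewritten keeps its original operation on record, and output stages are now flagged rather than tracked through a roots array.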
1 change: 0 additions & 1 deletion python/tvm/build.py
@@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
-
# lowering
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
47 changes: 47 additions & 0 deletions python/tvm/schedule.py
@@ -41,6 +41,53 @@ def normalize(self):
"""
_api_internal._ScheduleNormalize(self)

+    def cache_read(self, tensor, scope, readers):
+        """Create a cache read of original tensor for readers.
+
+        This will mutate the body of the readers.
+        A new cache stage will be created for the tensor.
+        Call this before doing any split/fuse schedule.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be cached.
+        scope : str
+            The scope of the cache.
+        readers : list of Tensor or Operation
+            The readers to read the cache.
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
+            readers = [readers]
+        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
+        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
+
+    def cache_write(self, tensor, scope):
+        """Create a cache write of original tensor, before storing into tensor.
+
+        This will mutate the body of the tensor.
+        A new cache stage will be created before feeding into the tensor.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be fed to.
+        scope : str
+            The scope of the cache.
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        return _api_internal._ScheduleCacheWrite(self, tensor, scope)


@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
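A matching sketch for cache_write (again illustrative: the schedule constructor and the tensor names are assumptions, not part of this diff):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n, n), name="A")
B = tvm.compute((n, n), lambda i, j: A[i][j] * 2, name="B")

s = tvm.Schedule(B.op)
# Per the docstring above, B's compute body moves into a new cache stage
# in "local" scope, and the original B becomes an identity read from it.
BL = s.cache_write(B, "local")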
14 changes: 13 additions & 1 deletion src/api/api_lang.cc
@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)

TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = Placeholder(args[0],
+    *ret = placeholder(args[0],
                        args[1],
                        args[2]);
});
@@ -280,4 +280,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});

+TVM_REGISTER_API(_ScheduleCacheRead)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_read(args[1], args[2], args[3]);
+  });
+
+TVM_REGISTER_API(_ScheduleCacheWrite)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_write(args[1], args[2]);
+  });

} // namespace tvm
6 changes: 3 additions & 3 deletions src/lang/operation.cc
@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,



-Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
+Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}

@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}

-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}

-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
25 changes: 6 additions & 19 deletions src/schedule/auto_inline_elem_wise.cc
@@ -6,9 +6,11 @@
#include <tvm/ir_visitor.h>

namespace tvm {
-namespace ir {
+namespace schedule {

+using namespace ir;
+
-class ElemWiseDetector : public IRVisitor {
+class ElemWiseDetector : public ir::IRVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}

@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
}

for (size_t i = 0; i < axis_.size(); ++i) {
-    // const Variable *v1 = axis_[i]->var.as<Variable>();
-    // const Variable *v2 = axis[i].as<Variable>();
     if (!axis[i].same_as(axis_[i]->var)) {
-    // if (!(v1 && v2) || (v1 != v2)) {
       is_elem_wise_ = false;
       return;
     }
@@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) {
return false;
}

-} // namespace ir
-
-namespace schedule {
-
void AutoInlineElemWise(Schedule sch) {
for (Stage s : sch->stages) {
-    if (!s.is_scheduled() && ir::IsElemWise(s->op)) {
-      bool is_root = false;
-      for (auto r : sch->roots) {
-        if (r == s->op) {
-          is_root = true;
-          break;
-        }
-      }
-      if (!is_root)
-        s.compute_inline();
+    if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
+      s.compute_inline();
     }
}
}
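The rewritten pass is equivalent to calling compute_inline by hand on every unscheduled, non-output, elementwise stage; a sketch under the same assumed Python API as above:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")  # elementwise, not an output
C = tvm.compute((n,), lambda i: B[i] * 2, name="C")  # the output stage

s = tvm.Schedule(C.op)
# AutoInlineElemWise would do this automatically for B: B is elementwise,
# unscheduled, and its stage has is_output == false.
s[B].compute_inline()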
12 changes: 9 additions & 3 deletions src/schedule/bound.cc
@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
-  LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
-    Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
+    Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r;
}
@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
direct_consume_by_parent = true;
}
}
+    } else {
+      LOG(INFO) << "not in feed graph consumer = " << stage->op;
}
}
// The relax set
@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
}

FeedGraph CreateFeedGraph(const Schedule& sch) {
-  auto g = CreateReadGraph(sch->roots);
+  Array<Operation> roots;
+  for (Operation op : sch->outputs) {
+    roots.push_back(sch->stage_map[op]->op);
+  }
+  auto g = CreateReadGraph(roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
AttachPath attach_path = CreateAttachPath(sch);
+
std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
Expand Down
