[SCHEDULE] Fuse support for 0 rank tensor #1328

Merged 2 commits on Jun 24, 2018
39 changes: 36 additions & 3 deletions include/tvm/schedule.h
@@ -129,6 +129,20 @@ class Stage : public NodeRef {
   * \return reference to self.
   */
  EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
  /*!
   * \brief Fuse all the given axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array; in that case, a singleton IterVar
   *  is created and inserted at the outermost loop.
   *  Fusing an empty array is used to support zero-dimension tensors.
   *
   * \return reference to self.
   */
  EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variable.
@@ -151,9 +165,9 @@
   * \return reference to self.
   */
  EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
                     Expr x_factor, Expr y_factor,
                     IterVar* p_x_outer, IterVar* p_y_outer,
                     IterVar* p_x_inner, IterVar* p_y_inner);
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
@@ -674,6 +688,25 @@ class RebaseNode : public IterVarRelationNode {
};


/*!
 * \brief Singleton iterator [0, 1)
 */
class SingletonNode : public IterVarRelationNode {
 public:
  /*! \brief The singleton iterator */
  IterVar iter;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter", &iter);
  }

  static IterVarRelation make(IterVar iter);

  static constexpr const char* _type_key = "Singleton";
  TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
  return static_cast<const StageNode*>(node_.get());
11 changes: 7 additions & 4 deletions python/tvm/schedule.py
@@ -152,6 +152,12 @@ class Fuse(NodeBase):
    pass


@register_node
class Singleton(NodeBase):
    """Singleton axis."""
    pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
@@ -380,10 +386,7 @@ def fuse(self, *args):
        fused : IterVar
            The fused variable of iteration.
        """
        fused = _api_internal._StageFuse(self, args)
        return fused

    def set_scope(self, scope):
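For context, a minimal sketch of how the reworked fuse behaves from the Python API; the placeholder names, shapes, and compute bodies below are illustrative assumptions, not part of the patch:

    import tvm

    # Rank-2 case: fuse(*args) now hands all axes to the C++ array overload at once.
    A = tvm.placeholder((16, 32), name='A')
    B = tvm.compute((16, 32), lambda i, j: A(i, j) * 2, name='B')
    s = tvm.create_schedule(B.op)
    fused = s[B].fuse(*s[B].op.axis)  # a single IterVar covering i and j

    # Rank-0 case: no axes at all; fuse() now returns a singleton IterVar.
    C = tvm.placeholder((), name='C')
    D = tvm.compute((), lambda: C() + 1, name='D')
    s0 = tvm.create_schedule(D.op)
    single = s0[D].fuse()  # unit-extent axis inserted at the outermost position

Delegating the whole tuple to C++ keeps the pairwise-fusion logic in one place and makes the zero-argument call well defined instead of an assertion failure.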
2 changes: 1 addition & 1 deletion src/api/api_lang.cc
@@ -350,7 +350,7 @@ TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar fused;
    args[0].operator Stage()
        .fuse(args[1], &fused);
    *ret = fused;
  });

11 changes: 11 additions & 0 deletions src/schedule/message_passing.cc
@@ -82,6 +82,8 @@ void PassDownDomain(const Stage& stage,
      Update(p_state, r->rebased,
             Range::make_by_min_extent(
                 0, state.at(r->parent)->extent));
    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
      Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -147,6 +149,7 @@ void PassUpIndex(const Stage& stage,
      } else {
        state[s->parent] = value;
      }
    } else if (rel.as<SingletonNode>()) {
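      // nop: a singleton has no parent index to update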
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -192,6 +195,8 @@ void PassDownIndex(const Stage& stage,
      Expr parent_min = dom_map.at(s->parent)->min;
      CHECK(is_zero(parent_min));
      state[s->rebased] = value;
    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
      state[s->iter] = make_zero(s->iter->var.type());
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -296,6 +301,7 @@ void PassUpDomain(const Stage& stage,
                   state.at(r->rebased),
                   &parent);
      state[r->parent] = parent;
    } else if (rel.as<SingletonNode>()) {
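      // nop: a singleton has no parent domain to update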
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -344,6 +350,7 @@ void PassUpBitMaskOr(const Stage& stage,
      } else {
        state[s->parent] |= state[s->rebased];
      }
    } else if (rel.as<SingletonNode>()) {
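      // nop: a singleton has no parent to propagate flags to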
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -390,6 +397,8 @@ void PassDownBitMaskOr(const Stage& stage,
      } else {
        state[s->rebased] |= state.at(s->parent);
      }
    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
      state[s->iter] = 0;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
@@ -438,6 +447,8 @@ void PassUpBoundCheck(const Stage& s,
    } else if (rel.as<RebaseNode>()) {
      const RebaseNode* s = rel.as<RebaseNode>();
      state[s->parent] = state.at(s->rebased);
    } else if (rel.as<SingletonNode>()) {
      // nop
    } else {
      LOG(FATAL) << "unknown relation type";
    }
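Net effect of these message-passing changes: the downward passes give the singleton a domain of [0, 1), an index of 0, and an empty bit mask, while the upward passes are no-ops because a singleton has no parent axis. A minimal sketch of how to observe this, assuming the 0.x-era tvm.lower API:

    import tvm

    A = tvm.placeholder((), name='A')
    T = tvm.compute((), lambda: A() + 1, name='T')
    s = tvm.create_schedule(T.op)
    s[T].fuse()  # inserts the singleton axis at the outermost position
    # The singleton lowers to a unit-extent loop, which the simplifier
    # is free to elide in the final IR.
    print(tvm.lower(s, [A, T], simple_mode=True))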
38 changes: 37 additions & 1 deletion src/schedule/schedule_lang.cc
@@ -237,7 +237,6 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
  IterVar fused = IterVarNode::make(
      Range(), Var(fused_name, outer->var.type()), iter_type);

  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();

@@ -255,6 +254,31 @@
                         leaf_vars->data.begin() + pos_inner + 1);
  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
                         fused.node_);
  *p_target = fused;
  return *this;
}

Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*)
  if (axes.size() != 0) {
    IterVar fused = axes[0];
    for (size_t i = 1; i < axes.size(); ++i) {
      this->fuse(fused, axes[i], &fused);
    }
    *p_target = std::move(fused);
  } else {
    StageNode* self = operator->();
    // Special case: fusing an empty array of axes.
    // Insert a singleton axis at the outermost loop so that
    // zero-dimension tensors still expose one leaf iter var.
    IterVar singleton = IterVarNode::make(
        Range::make_by_min_extent(0, 1),
        Var("singleton", Int(32)), kDataPar);
    self->relations.push_back(SingletonNode::make(singleton));
    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
    all_vars->data.push_back(singleton.node_);
    leaf_vars->data.insert(leaf_vars->data.begin(), singleton.node_);
    *p_target = singleton;
  }
  return *this;
}

@@ -732,11 +756,18 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
  return IterVarRelation(n);
}

IterVarRelation SingletonNode::make(IterVar iter) {
  auto n = std::make_shared<SingletonNode>();
  n->iter = iter;
  return IterVarRelation(n);
}

TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);

// Printer
@@ -778,6 +809,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
    p->print(op->rebased);
    p->stream << ')';
  })
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
    p->stream << "singleton(";
    p->print(op->iter);
    p->stream << ')';
  })
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
    p->stream << "schedule(" << op << ")";
  });
9 changes: 4 additions & 5 deletions tests/python/integration/test_ewise.py
@@ -44,10 +44,10 @@ def test_multiple_cache_write():
    n = tvm.convert(1024)
    A0 = tvm.placeholder((n,), name='A0', dtype = "float32")
    A1 = tvm.placeholder((n,), name='A1', dtype = "float32")
    B0, B1 = tvm.compute((n,),
                         lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
                         name='B')
    C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
                    name='C')
    s = tvm.create_schedule(C.op)
    # create iter var and assign them tags.
Expand Down Expand Up @@ -76,7 +76,7 @@ def check_device(device, host="stackvm"):
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        func(a0, a1, c)
        np.testing.assert_allclose(
            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
            rtol=1e-5)

    check_device("cuda", "llvm")
@@ -235,7 +235,6 @@ def check_device(device):
        f(a, b)
        np.testing.assert_allclose(
            b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
    check_device("cuda")


14 changes: 14 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
@@ -84,6 +84,19 @@ def test_fuse():
    assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
    assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)


def test_singleton():
    A = tvm.placeholder((), name='A')
    T = tvm.compute((), lambda : A() + 1)
    s = tvm.create_schedule(T.op)
    fused = s[T].fuse()
    assert any(isinstance(x, tvm.schedule.Singleton) for x in s[T].relations)
    assert tuple(s[T].leaf_iter_vars) == (fused,)
    dump = pkl.dumps(s)
    s_loaded = pkl.loads(dump)
    assert isinstance(s_loaded, tvm.schedule.Schedule)


def test_vectorize():
    m = tvm.var('m')
    n = tvm.var('n')
@@ -174,6 +187,7 @@ def intrin_func(ins, outs):


if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_rfactor()
6 changes: 5 additions & 1 deletion topi/tests/python/test_topi_broadcast.py
@@ -93,6 +93,8 @@ def test_broadcast_to():
    verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)

def test_add():
    verify_broadcast_binary_ele(
        (), (), topi.add, np.add)
    verify_broadcast_binary_ele(
        (5, 2, 3), (2, 1), topi.add, np.add)

@@ -113,6 +115,8 @@ def test_multiply():
def test_divide():
    verify_broadcast_binary_ele(
        None, (10,), topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (), None, topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

@@ -157,10 +161,10 @@ def test_shift():


if __name__ == "__main__":
test_add()
test_shift()
test_cmp()
test_mod()
test_add()
test_subtract()
test_multiply()
test_divide()