/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Compute Op.
* \file compute_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <string>
#include <utility>
#include "compute_op.h"
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"
namespace tvm {
using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode *op);
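// Check whether two Reduce nodes are structurally identical except for
// value_index: same combiner, source, axis and condition.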
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) &&
(a->condition.same_as(b->condition));
}
int ComputeOpNode::num_outputs() const {
return body.size();
}
Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type ComputeOpNode::output_dtype(size_t idx) const {
CHECK_LT(idx, num_outputs());
return body[idx].type();
}
Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
CHECK_LT(idx, num_outputs());
// for now, all outputs of a BaseComputeOp have the same shape
Array<Expr> shape;
for (const auto& ivar : this->axis) {
const Range& r = ivar->dom;
shape.push_back(r->extent);
}
return shape;
}
Tensor compute(Array<Expr> shape,
FCompute fcompute,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = make_node<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
args.push_back(axis.back()->var);
}
return ComputeOpNode::make(
name, tag, attrs, axis, {fcompute(args)}).output(0);
}
Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = make_node<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
args.push_back(axis.back()->var);
}
Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx));
}
return outputs;
}
Operation ComputeOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->axis = std::move(axis);
n->body = std::move(body);
if (n->body[0]->IsInstance<ir::Reduce>()) {
const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
n->reduce_axis = reduce->axis;
}
VerifyComputeOp(n.get());
return Operation(n);
}
// The schedule related logics
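// Collect every tensor read by the compute body by visiting all Call nodes
// that reference an operation output; duplicates are filtered out.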
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
std::unordered_set<Tensor> visited;
for (auto& e : body) {
ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (!visited.count(t)) {
ret.push_back(t);
visited.insert(t);
}
}
});
}
return ret;
}
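// Rewrite the body so that reads of tensors in rmap are redirected to their
// replacements. Reduce bodies are rebuilt as a group so the per-output Reduce
// nodes keep sharing the same components; returns self if nothing changed.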
Operation ComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
Array<Expr> arr;
if (this->body[0]->IsInstance<ir::Reduce>()) {
// Specially handle reduce so the replaced op
// still shares all the components
Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->type = r->source[k].type();
arr.push_back(Expr(n));
}
} else {
arr = this->body;
}
} else {
arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
return op::ReplaceTensor(e, rmap);
});
}
if (!arr.same_as(this->body)) {
return ComputeOpNode::make(
this->name, this->tag, this->attrs, this->axis, arr);
} else {
return self;
}
}
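// Propagate the known bounds of this op's iteration variables to the regions
// of the input tensors touched by the body, clipping each index set to the
// tensor shape when that can be proven tighter.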
void ComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
// We assume that the value of the argument cannot be out of bounds (otherwise it is
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
Expr shape_i_min_value = make_zero(t->shape[i].type());
Expr shape_i_max_value = t->shape[i] - 1;
Expr min_value = arg_interval->min_value;
Expr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
} else {
dom.data[i].push_back(arg_intset);
}
}
}
}
};
for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}
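// Derive the range of each spatial axis from the domains demanded by the
// consumers; reduction axes always keep their declared domains.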
void BaseComputeOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
CHECK(!out_dom_map->count(this->reduce_axis[i]));
(*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
}
}
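// Wrap the body in a Realize node per output tensor and attach
// buffer_dim_align attributes for axes that request storage alignment.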
Stmt BaseComputeOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i-1);
realize = ir::Realize::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
Array<Expr> tuple = {static_cast<int>(i),
attr->dim_align_factor,
attr->dim_align_offset};
realize = ir::AttrStmt::make(
t, ir::attr::buffer_dim_align,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
realize);
}
}
}
}
return realize;
}
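// For a plain ComputeOp every spatial axis can be scheduled.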
size_t ComputeOpNode::num_schedulable_dims() const {
return axis.size();
}
// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
const Array<Tensor>& tensors,
Stmt* init,
Stmt* provide) {
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
std::vector<Stmt> inits, provides;
size_t size = op->body.size();
const Reduce* reduce = op->body[0].as<Reduce>();
CHECK(reduce);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
Array<Expr> lhs;
for (size_t i = 0; i < size; ++i) {
lhs.push_back(tensors[i](args));
}
Array<Expr> init_value = combiner->identity_element;
Array<Expr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(Provide::make(
t->op, t->value_index, init_value[i], args));
provides.emplace_back(Provide::make(
t->op, t->value_index, update_value[i], args));
}
*init = Block::make(inits);
*provide = Block::make(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
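// Build the statement computing the outputs of a ComputeOp: a reduction
// (initialization plus update) when reduce_axis is non-empty, otherwise a
// plain provide nest.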
Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
// grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
if (self->reduce_axis.size() != 0) {
// make reduction.
Stmt init, provide;
Array<Tensor> source;
for (size_t i = 0; i < self->body.size(); ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(self, source, &init, &provide);
init = MergeNest(n.init_nest, init);
init = op::Substitute(init, n.init_vmap);
// common nest
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) {
provide = MergeNest(common, provide);
} else {
provide = MergeNest(common, Block::make(init, provide));
}
// run substitution on the full nest, because the loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = Block::make(provides);
provide = MergeNest(n.main_nest, provide);
// run substitution on the full nest, because the loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
}
}
enum class ComputeType {
kNormal,
kCrossThreadReduction,
kTensorize
};
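// Inspect the leaf iteration variables of the stage to decide whether this op
// lowers to a normal loop nest, a cross-thread reduction, or tensorization.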
ComputeType DetectComputeType(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
for (IterVar iv : stage->leaf_iter_vars) {
IterVarAttr attr;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
attr = (*it).second;
}
if (attr.defined() && attr->iter_type == kTensorized) {
++tensorize;
}
if (iv->iter_type == kCommReduce) {
if (attr.defined() && attr->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
if (tensorize != 0) {
CHECK(thread_red == 0)
<< "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
if (thread_red != 0) {
return ComputeType::kCrossThreadReduction;
} else {
return ComputeType::kNormal;
}
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
} else {
return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
}
}
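// Create the loop nests of a compute op: the main nest evaluating the body
// and, for reductions, an init nest that initializes the outputs, fused with
// the common outer loops where possible.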
ComputeLoopNest ComputeLoopNest::make(
const BaseComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
debug_keep_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
if (stage->store_predicate.defined()) {
ret.main_predicates.push_back(stage->store_predicate);
}
if (self->reduce_axis.size() != 0) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
update_state[self->axis[i]] = 1;
}
// find which iter var is related to reduction and which is related to axis.
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
// find the first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
ret.num_common_loop = begin_loop;
// skip loops that are related to reduction and are unrelated to axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if (flag == 2) skip_iter.insert(kv.first);
}
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
} else {
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
// copy elision here.
return ret;
}
namespace {
/*!
* \brief Verify if ComputeOp is valid with respect to Reduce operations.
*
* The following two properties are verified:
* (1) All Reduce operations must exist at top level.
* (2) For a multi-expression body, if one expression is a Reduce, then
*     the others must be Reduce as well, and they must share the same
*     attributes except value_index.
*/
class ComputeVerifier final : protected ir::IRVisitor {
public:
/// Special member functions
//@{
explicit ComputeVerifier(const ComputeOpNode* compute)
: compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
virtual ~ComputeVerifier() = default;
ComputeVerifier(const ComputeVerifier&) = delete;
ComputeVerifier(ComputeVerifier&&) = delete;
ComputeVerifier& operator=(const ComputeVerifier&) = delete;
ComputeVerifier& operator=(ComputeVerifier&&) = delete;
//@}
/// Interface to perform compute verification
void Run() {
for (const Expr e : compute_->body) {
// Check for consistency of top level reductions
const ir::Reduce* reduce = e.as<ir::Reduce>();
CHECK((reduce && reduce_) || (!reduce && !reduce_))
<< "All expressions in the ComputeOp body should consistently "
<< "be Reduce operations or non-Reduce operations.";
if (reduce && reduce_) {
CHECK(ReduceEqual(reduce, reduce_))
<< "The Reduce expressions in the ComputeOp body should "
<< "have the same attributes except value_index";
}
level_ = 0;
ir::IRVisitor::Visit(e);
}
}
protected:
/// Visitor implementation
//@{
void Visit(const NodeRef& n) final {
++level_;
ir::IRVisitor::Visit(n);
--level_;
}
void Visit_(const ir::Reduce* op) final {
// Check for non-top-level reductions
CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. "
<< "Please create another tensor for further composition.";
}
//@}
private:
const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
const ir::Reduce* reduce_{nullptr}; ///< Top level Reduce operation
int level_{0}; ///< Level of op being processed
};
} // namespace
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
ComputeVerifier v(op);
v.Run();
}
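// Build the tensorize update transform: execute `update` whenever some
// reduction loop variable is past its first iteration, and `body` otherwise.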
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
} // namespace tvm