Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/op/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,26 +172,23 @@ class TensorIntrinMatcher final : public IRMutator {
Expr Mutate_(const Reduce* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Reduce>();
Array<IterVar> axis = op->axis;
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
auto it = axis_remap_.find(op->axis[i]);
if (it != axis_remap_.end()) {
axis.Set(i, it->second);
axis.push_back(it->second);
}
}
if (!axis.same_as(op->axis)) {
return Reduce::make(
op->combiner, op->source, axis, op->condition, op->value_index);
} else {
return e;
}
return Reduce::make(
op->combiner, op->source, axis, op->condition, op->value_index);
}

void Init(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
CHECK(self == stage->op.get());
// input remap.
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size());
Expand All @@ -204,7 +201,8 @@ class TensorIntrinMatcher final : public IRMutator {
e.start = e.region.size() - e.tensor.ndim();
for (size_t i = 0; i < e.start; ++i) {
CHECK(is_one(e.region[i]->extent))
<< "Tensorize: Input dimension mismatch with tensor intrin "
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
Expand All @@ -223,6 +221,7 @@ class TensorIntrinMatcher final : public IRMutator {
<< "Tensorize: Output mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->axis.size()
<< ", tensorize-dim=" << self->axis.size();
var_remap_[self->axis[i]->var.get()] = r->min;
}
// Assume we tensorize at regin axis i [min, min + extent)
// The corresponding intrinsic axis is j [0, extent)
Expand All @@ -244,6 +243,7 @@ class TensorIntrinMatcher final : public IRMutator {
<< "Tensorize: Reduction mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->reduce_axis.size()
<< ", tensorize-dim=" << self->reduce_axis.size();
var_remap_[self->reduce_axis[i]->var.get()] = r->min;
}
for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
IterVar iv = self->reduce_axis[i];
Expand Down Expand Up @@ -328,7 +328,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data.
Stmt nop = Evaluate::make(0);
std::vector<Stmt> bind_nest;
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size())
<< "Tensorize failed: input size mismatch ";
Expand All @@ -345,7 +345,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
bind_nest.emplace_back(AttrStmt::make(
input_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
Expand All @@ -365,7 +365,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
bind_nest.emplace_back(AttrStmt::make(
output_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
Expand Down Expand Up @@ -400,11 +400,12 @@ Stmt MakeTensorize(const ComputeOpNode* self,
CHECK_EQ(n.init_predicates.size(), 0U);
CHECK(intrin->body.defined())
<< "Normal store op for intrin " << intrin << " is not defined";
Stmt body = ir::MergeNest(bind_nest, intrin->body);
Stmt body = MergeNest(output_bind_nest, intrin->body);
body = MergeNest(input_bind_nest, body);
body = Substitute(body, vmap);
body = ir::MergeNest(binder.asserts(), body);
body = MergeNest(binder.asserts(), body);
body = Substitute(body, n.main_vmap);
return ir::MergeNest(nest, body);
return MergeNest(nest, body);
} else {
// Need to split reduction
CHECK(intrin->reduce_init.defined())
Expand All @@ -419,14 +420,15 @@ Stmt MakeTensorize(const ComputeOpNode* self,
std::vector<std::vector<Stmt> > init_nest(
n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
Stmt init = MergeNest(bind_nest, intrin->reduce_init);
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
init = Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
std::vector<std::vector<Stmt> > update_nest(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
Stmt update = MergeNest(bind_nest, intrin->reduce_update);
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
update = MergeNest(input_bind_nest, update);
update = Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
Expand Down
4 changes: 2 additions & 2 deletions src/pass/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class StorageFlattener : public IRMutator {
Array<Expr>(), Expr(),
key.GetName(), skey.to_string(),
align, 0);

buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
Expand Down Expand Up @@ -239,7 +238,8 @@ class StorageFlattener : public IRMutator {
CHECK(buffer && tensor);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
CHECK(buf_map_.count(key));
CHECK(buf_map_.count(key))
<< "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
const BufferEntry& be = buf_map_.at(key);
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
Expand Down
13 changes: 13 additions & 0 deletions src/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ class SchedulePostProc : public IRMutator {
return this->Mutate(op->body);
}
}
} else if (op->attr_key == ir::attr::buffer_bind_scope) {
Array<NodeRef> tuple(op->node.node_);
Tensor tensor(tuple[1].node_);
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
return AttrStmt::make(
Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
op->attr_key, op->value, Mutate(op->body));
} else {
return this->Mutate(op->body);
}
}
}
return IRMutator::Mutate_(op, s);
}
Expand Down