diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 1580be208f37..7c1e5fb91436 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -172,19 +172,15 @@ class TensorIntrinMatcher final : public IRMutator { Expr Mutate_(const Reduce* op, const Expr& e) final { Expr expr = IRMutator::Mutate_(op, e); op = expr.as(); - Array axis = op->axis; + Array axis; for (size_t i = 0; i < op->axis.size(); ++i) { auto it = axis_remap_.find(op->axis[i]); if (it != axis_remap_.end()) { - axis.Set(i, it->second); + axis.push_back(it->second); } } - if (!axis.same_as(op->axis)) { - return Reduce::make( - op->combiner, op->source, axis, op->condition, op->value_index); - } else { - return e; - } + return Reduce::make( + op->combiner, op->source, axis, op->condition, op->value_index); } void Init(const ComputeOpNode* self, @@ -192,6 +188,7 @@ class TensorIntrinMatcher final : public IRMutator { const std::unordered_map& out_dom, const std::unordered_map >& in_region, const TensorIntrin& intrin) { + CHECK(self == stage->op.get()); // input remap. Array inputs = self->InputTensors(); CHECK_EQ(inputs.size(), intrin->inputs.size()); @@ -204,7 +201,8 @@ class TensorIntrinMatcher final : public IRMutator { e.start = e.region.size() - e.tensor.ndim(); for (size_t i = 0; i < e.start; ++i) { CHECK(is_one(e.region[i]->extent)) - << "Tensorize: Input dimension mismatch with tensor intrin " + << "Tensorize " << intrin->name << ":" + << " Input dimension mismatch with tensor intrin " << " expected shape=" << e.tensor->shape << ", given region=" << e.region; } @@ -223,6 +221,7 @@ class TensorIntrinMatcher final : public IRMutator { << "Tensorize: Output mismatch with tensor intrin " << " intrin-dim=" << intrin_compute->axis.size() << ", tensorize-dim=" << self->axis.size(); + var_remap_[self->axis[i]->var.get()] = r->min; } // Assume we tensorize at regin axis i [min, min + extent) // The corresponding intrinsic axis is j [0, extent) @@ -244,6 +243,7 @@ class TensorIntrinMatcher final : public IRMutator { << "Tensorize: Reduction mismatch with tensor intrin " << " intrin-dim=" << intrin_compute->reduce_axis.size() << ", tensorize-dim=" << self->reduce_axis.size(); + var_remap_[self->reduce_axis[i]->var.get()] = r->min; } for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; @@ -328,7 +328,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, VerifyTensorizeBody(self, stage, out_dom, in_region, intrin); // Start bind data. Stmt nop = Evaluate::make(0); - std::vector bind_nest; + std::vector input_bind_nest, output_bind_nest; Array inputs = self->InputTensors(); CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch "; @@ -345,7 +345,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, tuple.push_back(r->min); tuple.push_back(r->extent); } - bind_nest.emplace_back(AttrStmt::make( + input_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } @@ -365,7 +365,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; Array bind_spec{buffer, tensor}; - bind_nest.emplace_back(AttrStmt::make( + output_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } @@ -400,11 +400,12 @@ Stmt MakeTensorize(const ComputeOpNode* self, CHECK_EQ(n.init_predicates.size(), 0U); CHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined"; - Stmt body = ir::MergeNest(bind_nest, intrin->body); + Stmt body = MergeNest(output_bind_nest, intrin->body); + body = MergeNest(input_bind_nest, body); body = Substitute(body, vmap); - body = ir::MergeNest(binder.asserts(), body); + body = MergeNest(binder.asserts(), body); body = Substitute(body, n.main_vmap); - return ir::MergeNest(nest, body); + return MergeNest(nest, body); } else { // Need to split reduction CHECK(intrin->reduce_init.defined()) @@ -419,14 +420,15 @@ Stmt MakeTensorize(const ComputeOpNode* self, std::vector > init_nest( n.init_nest.begin(), n.init_nest.begin() + tloc + 1); init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - Stmt init = MergeNest(bind_nest, intrin->reduce_init); + Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); init = Substitute(init, n.init_vmap); init = MergeNest(init_nest, init); // The update std::vector > update_nest( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - Stmt update = MergeNest(bind_nest, intrin->reduce_update); + Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); + update = MergeNest(input_bind_nest, update); update = Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = Substitute(update, n.main_vmap); diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 2911293832fb..4c22ee767029 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -117,7 +117,6 @@ class StorageFlattener : public IRMutator { Array(), Expr(), key.GetName(), skey.to_string(), align, 0); - buf_map_[key] = e; Stmt body = this->Mutate(op->body); buf_map_[key].released = true; @@ -239,7 +238,8 @@ class StorageFlattener : public IRMutator { CHECK(buffer && tensor); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); TensorKey key{tensor->op, tensor->value_index}; - CHECK(buf_map_.count(key)); + CHECK(buf_map_.count(key)) + << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; const BufferEntry& be = buf_map_.at(key); CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 5bbc565695b7..b49b86bbe074 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -181,6 +181,19 @@ class SchedulePostProc : public IRMutator { return this->Mutate(op->body); } } + } else if (op->attr_key == ir::attr::buffer_bind_scope) { + Array tuple(op->node.node_); + Tensor tensor(tuple[1].node_); + auto it = replace_op_.find(tensor->op.get()); + if (it != replace_op_.end()) { + if (it->second.defined()) { + return AttrStmt::make( + Array{tuple[0], it->second.output(tensor->value_index)}, + op->attr_key, op->value, Mutate(op->body)); + } else { + return this->Mutate(op->body); + } + } } return IRMutator::Mutate_(op, s); }