From b463cba858b9873afb585eede3e727b1191f5cc1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jun 2024 21:07:46 -0500 Subject: [PATCH] Fix breakage in unit tests --- src/relax/transform/fuse_tir.cc | 12 ++++++++---- src/relax/transform/rewrite_dataflow_reshape.cc | 17 ++++++++++------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index e712b5022a7d..c36d7e4ac9fe 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -659,21 +659,25 @@ class FusedTIRConstructor : public ExprVisitor { */ void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { Array arg_list; - Array buffer_list; if (const auto* arg_tuple = args.as()) { arg_list = arg_tuple->fields; } else { arg_list = {args}; } + Array relax_tensors; + Array tir_buffers; + ICHECK_GE(func->params.size(), arg_list.size()); for (size_t i = 0; i < arg_list.size(); ++i) { const tir::Var& param = func->params[i]; - const tir::Buffer& buffer = func->buffer_map.at(param); - buffer_list.push_back(buffer); + if (auto buffer = func->buffer_map.Get(param)) { + relax_tensors.push_back(arg_list[i]); + tir_buffers.push_back(buffer.value()); + } } - MapArgsToBuffer(arg_list, buffer_list); + MapArgsToBuffer(relax_tensors, tir_buffers); } static Array GetInplaceOutputIndices(const Array& inplace_indices, diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 8345f3e0b745..5403b7090c53 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -34,12 +34,15 @@ namespace tvm { namespace relax { -std::vector GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) { +std::vector GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) { std::vector indices; for (size_t i = 0; i < num_args; ++i) { - auto buffer_var = fn->buffer_map[fn->params[i]]->data; - if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) { - indices.push_back(i); + if (auto buffer = fn->buffer_map.Get(fn->params[i])) { + auto buffer_var = buffer.value()->data; + if (tir::UsesVar(fn->body, + [=](const tir::VarNode* var) { return var == buffer_var.get(); })) { + indices.push_back(i); + } } } return indices; @@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator { auto prim_fn = Downcast(mod_->Lookup(Downcast(call->args[0]))); auto arg_tuple = Downcast(call->args[1])->fields; - auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size()); + auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size()); // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR // then flattens the tuple input so that the fused TIR reshape function ends up having // multiple input buffers. But only one of them should be accessed and reshaped. - if (used_arg_indices.size() != 1) { + if (used_tensor_arg_indices.size() != 1) { return GetRef(call); } - auto arg = arg_tuple[used_arg_indices[0]]; + auto arg = arg_tuple[used_tensor_arg_indices[0]]; if (!IsCallingTIRReshape(call, arg)) { return GetRef(call);