diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 524735caa9d6..bc1f03137456 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -438,13 +438,15 @@ class RelayToTIRVisitor : public MixedModeMutator { int context_buffer_size = 0; PrimExpr context_buffer_var = tir::StringImm("NULL"); - if (pool_name == "cmsisnn.qnn_avg_pool2d") { + if (pool_name == "cmsis-nn.qnn_avg_pool2d") { CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); int32_t input_c = qnn::get_const_int(input_shape[3]); context_buffer_size = AvgPoolBufferSize(flags, input_c); - std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); - context_buffer_var = tir::Var(context_buffer_name, - PointerType(PrimType(DataType::Int(8)), "global.workspace")); + if (context_buffer_size) { + std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + context_buffer_var = tir::Var(context_buffer_name, + PointerType(PrimType(DataType::Int(8)), "global.workspace")); + } } tvm::Array context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)};