From 6c6ad6c43ad3cb378866683ed322316b173267aa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 13 Jun 2023 04:07:04 -0400 Subject: [PATCH] [AOT] Remove duplication in tvm.testing.aot.compile_models (#15032) Previously, the body of an if/else were nearly identical, except for a `with apply_fixed_config` context. This commit removes the duplication using `contextlib.ExitStack` to conditionally enter the context. --- python/tvm/testing/aot.py | 83 +++++++++++++++------------------------ 1 file changed, 31 insertions(+), 52 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index a13fc4b7b26a..b2814aff2d77 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=use-list-literal, consider-using-with, f-string-without-interpolation """Common functions for AOT test cases""" +import contextlib import sys import datetime import os @@ -657,64 +658,42 @@ def compile_models( compiled_mods = list() for model in models: - if schedule_name: - # Testing with deterministic schedule - task_list = autotvm.task.extract_from_program( - model.module, target=target, params=model.params + with contextlib.ExitStack() as context_stack: + if schedule_name: + # Testing with deterministic schedule + task_list = autotvm.task.extract_from_program( + model.module, target=target, params=model.params + ) + context_stack.enter_context( + tvm.autotvm.apply_fixed_config(task_list, schedule_name) + ) + + context_stack.enter_context(tvm.transform.PassContext(opt_level=3, config=config)) + + build_kwargs = dict( + ir_mod=model.module, + params=model.params, + mod_name=model.name, ) - with tvm.autotvm.apply_fixed_config(task_list, schedule_name): - with tvm.transform.PassContext(opt_level=3, config=config): - if use_runtime_executor: - executor_factory = tvm.relay.build( - model.module, - target, - executor=executor, - runtime=runtime, - workspace_memory_pools=workspace_memory_pools, - constant_memory_pools=constant_memory_pools, - params=model.params, - mod_name=model.name, - ) - compiled_mods.append( - AOTCompiledTestModel(model=model, executor_factory=executor_factory) - ) - else: - executor_factory = tvm.relay.build( - model.module, - tvm.target.Target(target, host=target), - params=model.params, - mod_name=model.name, - ) - compiled_mods.append( - AOTCompiledTestModel(model=model, executor_factory=executor_factory) - ) - else: - with tvm.transform.PassContext(opt_level=3, config=config): - # TODO(Mousius) - Remove once executor/runtime are fully removed from Target - if use_runtime_executor: - executor_factory = tvm.relay.build( - model.module, - target, + + # TODO(Mousius) - Remove once executor/runtime are fully removed from Target + if use_runtime_executor: + build_kwargs.update( + dict( + target=target, executor=executor, runtime=runtime, workspace_memory_pools=workspace_memory_pools, constant_memory_pools=constant_memory_pools, - params=model.params, - mod_name=model.name, - ) - compiled_mods.append( - AOTCompiledTestModel(model=model, executor_factory=executor_factory) - ) - else: - executor_factory = tvm.relay.build( - model.module, - tvm.target.Target(target, host=target), - params=model.params, - mod_name=model.name, - ) - compiled_mods.append( - AOTCompiledTestModel(model=model, executor_factory=executor_factory) ) + ) + else: + build_kwargs.update(dict(target=tvm.target.Target(target, host=target))) + + executor_factory = tvm.relay.build(**build_kwargs) + compiled_mods.append( + AOTCompiledTestModel(model=model, executor_factory=executor_factory) + ) return compiled_mods