Skip to content

Commit

Permalink
[AOT] Remove duplication in tvm.testing.aot.compile_models (#15032)
Browse files Browse the repository at this point in the history
Previously, the body of an if/else were nearly identical, except for a
`with apply_fixed_config` context.  This commit removes the
duplication using `contextlib.ExitStack` to conditionally enter the
context.
  • Loading branch information
Lunderberg committed Jun 13, 2023
1 parent 21361a6 commit 6c6ad6c
Showing 1 changed file with 31 additions and 52 deletions.
83 changes: 31 additions & 52 deletions python/tvm/testing/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=use-list-literal, consider-using-with, f-string-without-interpolation
"""Common functions for AOT test cases"""
import contextlib
import sys
import datetime
import os
Expand Down Expand Up @@ -657,64 +658,42 @@ def compile_models(

compiled_mods = list()
for model in models:
if schedule_name:
# Testing with deterministic schedule
task_list = autotvm.task.extract_from_program(
model.module, target=target, params=model.params
with contextlib.ExitStack() as context_stack:
if schedule_name:
# Testing with deterministic schedule
task_list = autotvm.task.extract_from_program(
model.module, target=target, params=model.params
)
context_stack.enter_context(
tvm.autotvm.apply_fixed_config(task_list, schedule_name)
)

context_stack.enter_context(tvm.transform.PassContext(opt_level=3, config=config))

build_kwargs = dict(
ir_mod=model.module,
params=model.params,
mod_name=model.name,
)
with tvm.autotvm.apply_fixed_config(task_list, schedule_name):
with tvm.transform.PassContext(opt_level=3, config=config):
if use_runtime_executor:
executor_factory = tvm.relay.build(
model.module,
target,
executor=executor,
runtime=runtime,
workspace_memory_pools=workspace_memory_pools,
constant_memory_pools=constant_memory_pools,
params=model.params,
mod_name=model.name,
)
compiled_mods.append(
AOTCompiledTestModel(model=model, executor_factory=executor_factory)
)
else:
executor_factory = tvm.relay.build(
model.module,
tvm.target.Target(target, host=target),
params=model.params,
mod_name=model.name,
)
compiled_mods.append(
AOTCompiledTestModel(model=model, executor_factory=executor_factory)
)
else:
with tvm.transform.PassContext(opt_level=3, config=config):
# TODO(Mousius) - Remove once executor/runtime are fully removed from Target
if use_runtime_executor:
executor_factory = tvm.relay.build(
model.module,
target,

# TODO(Mousius) - Remove once executor/runtime are fully removed from Target
if use_runtime_executor:
build_kwargs.update(
dict(
target=target,
executor=executor,
runtime=runtime,
workspace_memory_pools=workspace_memory_pools,
constant_memory_pools=constant_memory_pools,
params=model.params,
mod_name=model.name,
)
compiled_mods.append(
AOTCompiledTestModel(model=model, executor_factory=executor_factory)
)
else:
executor_factory = tvm.relay.build(
model.module,
tvm.target.Target(target, host=target),
params=model.params,
mod_name=model.name,
)
compiled_mods.append(
AOTCompiledTestModel(model=model, executor_factory=executor_factory)
)
)
else:
build_kwargs.update(dict(target=tvm.target.Target(target, host=target)))

executor_factory = tvm.relay.build(**build_kwargs)
compiled_mods.append(
AOTCompiledTestModel(model=model, executor_factory=executor_factory)
)
return compiled_mods


Expand Down

0 comments on commit 6c6ad6c

Please sign in to comment.