Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pytest] Sort unit tests before running. #9188

Merged
merged 2 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/tvm/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def pytest_collection_modifyitems(config, items):
# pylint: disable=unused-argument
_count_num_fixture_uses(items)
_remove_global_fixture_definitions(items)
_sort_tests(items)


@pytest.fixture
Expand Down Expand Up @@ -236,6 +237,25 @@ def _remove_global_fixture_definitions(items):
delattr(module, name)


def _sort_tests(items):
"""Sort tests by file/function.

By default, pytest will sort tests to maximize the re-use of
fixtures. However, this assumes that all fixtures have an equal
cost to generate, and no caches outside of those managed by
pytest. A tvm.testing.parameter is effectively free, while
reference data for testing may be quite large. Since most of the
TVM fixtures are specific to a python function, sort the test
ordering by python function, so that
tvm.testing.utils._fixture_cache can be cleared sooner rather than
later.

Should be called from pytest_collection_modifyitems.

"""
items.sort(key=lambda item: item.location)


def _target_to_requirement(target):
if isinstance(target, str):
target = tvm.target.Target(target)
Expand Down
15 changes: 9 additions & 6 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def test_device_parametrization(self, dev):
self.devices_used.append(dev)

def test_all_targets_used(self):
assert self.targets_used == self.enabled_targets
assert self.devices_used == self.enabled_devices
assert sorted(self.targets_used) == sorted(self.enabled_targets)

def test_all_devices_used(self):
sort_key = lambda dev: (dev.device_type, dev.device_id)
assert sorted(self.devices_used, key=sort_key) == sorted(self.enabled_devices, key=sort_key)

targets_with_explicit_list = []

Expand All @@ -70,9 +73,9 @@ def test_exclude_target(self, target):
self.targets_with_exclusion.append(target)

def test_all_nonexcluded_targets_ran(self):
assert self.targets_with_exclusion == [
target for target in self.enabled_targets if not target.startswith("llvm")
]
assert sorted(self.targets_with_exclusion) == sorted(
[target for target in self.enabled_targets if not target.startswith("llvm")]
)

run_targets_with_known_failure = []

Expand All @@ -85,7 +88,7 @@ def test_known_failing_target(self, target):
assert "llvm" not in target

def test_all_targets_ran(self):
assert self.run_targets_with_known_failure == self.enabled_targets
assert sorted(self.run_targets_with_known_failure) == sorted(self.enabled_targets)

@tvm.testing.known_failing_targets("llvm")
@tvm.testing.parametrize_targets("llvm")
Expand Down