Skip to content

Commit

Permalink
[UnitTests] Copy cached values when using tvm.testing.fixture(cache_r…
Browse files Browse the repository at this point in the history
…eturn_value=True)

To avoid unit tests being able to influence each other through a
shared cache, all cached fixtures are passed through copy.deepcopy
prior to use.
  • Loading branch information
Lunderberg committed Jun 18, 2021
1 parent 2870483 commit 0da04a8
Showing 1 changed file with 52 additions and 10 deletions.
62 changes: 52 additions & 10 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def test_something():
`TVM_TEST_TARGETS` environment variable in the CI.
"""
import collections
import copy
import functools
import logging
import os
import sys
import time
import pickle
import pytest
import _pytest
import numpy as np
Expand Down Expand Up @@ -1134,20 +1136,60 @@ def wraps(func):


def _fixture_cache(func):
cache = functools.lru_cache(maxsize=None)(func)
cache = {}
num_uses = 0

# Using functools.lru_cache would require the function arguments
# to be hashable, which wouldn't allow caching fixtures that
# depend on numpy arrays. For example, a fixture that takes a
# numpy array as input, then calculates uses a slow method to
# compute a known correct output for that input. Therefore,
# including a fallback for serializable types.
def get_cache_key(*args, **kwargs):
try:
hash((args, kwargs))
return (args, kwargs)
except TypeError as e:
pass

try:
return pickle.dumps((args, kwargs))
except TypeError as e:
raise TypeError(
"TVM caching of fixtures requires arguments to the fixture "
"to be either hashable or serializable"
) from e

@functools.wraps(func)
def wrapper(*args, **kwargs):
yield cache(*args, **kwargs)

nonlocal num_uses
num_uses += 1

# Clear the cache once all tests that use a particular fixture
# have completed.
if num_uses == wrapper.num_tests_use_this:
cache.cache_clear()
try:
cache_key = get_cache_key(*args, **kwargs)

try:
cached_value = cache[cache_key]
except KeyError:
cached_value = cache[cache_key] = func(*args, **kwargs)

try:
yield copy.deepcopy(cached_value)
except TypeError as e:
rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/"
"0007-parametrized-unit-tests.md#unresolved-questions"
)
message = (
"TVM caching of fixtures can only be used on serializable data types, not {}.\n"
"Please see {} for details/discussion."
).format(type(cached_value), rfc_url)
raise TypeError(message) from e

finally:
# Clear the cache once all tests that use a particular fixture
# have completed.
nonlocal num_uses
num_uses += 1
if num_uses == wrapper.num_tests_use_this:
cache.clear()

# Set in the pytest_collection_modifyitems()
wrapper.num_tests_use_this = 0
Expand Down

0 comments on commit 0da04a8

Please sign in to comment.