Add tests configurations #41

Merged · 5 commits · Jun 18, 2021

Changes from 2 commits
docs/library/testing.md (40 changes: 38 additions, 2 deletions)
@@ -25,11 +25,21 @@ In addition, it creates a test called `test_auto_model_library` that iterates th…

### Defining test cases

Any `modelkit.core.Model` can define its own test cases, which are discoverable by the test created by `make_modellibrary_test` (a sketch of the wiring follows).
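A minimal sketch of a test module using it (the import path and keyword arguments here are assumptions for illustration, not a confirmed signature):

```python
# test_models.py: hypothetical test module layout
# NOTE: the import path and arguments below are assumptions;
# check modelkit's testing utilities for the exact signature
from modelkit.testing import make_modellibrary_test

# generates a model-library fixture and a test named
# test_auto_model_library that runs every discovered test case
make_modellibrary_test(
    models=["my_package.models"],  # hypothetical package containing Model subclasses
)
```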

There are two ways of defining test cases.

#### Adding `TEST_CASES` as a class attribute

Test cases added to the `TEST_CASES` class attribute are shared across the different models defined in the `CONFIGURATIONS` map.

In the following example, 4 test cases will be run:
- 2 for `some_model_a`
- 2 for `some_model_b`

```python
class TestableModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {"some_model_a": {}, "some_model_b": {}}

    TEST_CASES = {
        "cases": [
            {"item": {"x": 1}, "result": {"x": 1}},
            {"item": {"x": 2}, "result": {"x": 2}},
        ]
    }

    def _predict(self, item):
        return item
```

#### Adding `test_cases` to the `CONFIGURATIONS` map

Test cases added under a model's entry in the `CONFIGURATIONS` map apply only to that model.

In the following example, 2 test cases will be run for `some_model_a` (and none for `some_model_b`):

```python
class TestableModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {
        "some_model_a": {
            "test_cases": {
                "cases": [
                    {"item": {"x": 1}, "result": {"x": 1}},
                    {"item": {"x": 2}, "result": {"x": 2}},
                ],
            }
        },
        "some_model_b": {},
    }

    def _predict(self, item):
        return item
```
Both mechanisms can be used simultaneously, as in the sketch below: class-level `TEST_CASES` run for every configuration, while `test_cases` entries in `CONFIGURATIONS` run only for their own model.
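A minimal sketch combining both (same echoing `_predict` as above; the case values are illustrative):

```python
class TestableModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {
        "some_model_a": {
            # this case runs only for some_model_a
            "test_cases": {
                "cases": [{"item": {"x": 3}, "result": {"x": 3}}],
            }
        },
        "some_model_b": {},
    }

    # these cases run for both some_model_a and some_model_b
    TEST_CASES = {
        "cases": [{"item": {"x": 1}, "result": {"x": 1}}],
    }

    def _predict(self, item):
        return item
```

With this layout, `some_model_a` ends up with two test cases and `some_model_b` with one.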

Each test case is defined by an `item` value and an expected `result` value; the automatic test iterates through them and runs the equivalent of:

```python
# sketch of the generated check, for each (item, result) test case
assert model.predict(item) == result
```
modelkit/core/library.py (7 changes: 0 additions, 7 deletions)
@@ -366,13 +366,6 @@ async def aclose(self):
            if isinstance(model, AsyncModel):
                await model.close()

-    def _iterate_test_cases(self):
-        model_types = {type(model_type) for model_type in self._models.values()}
-        for model_type in model_types:
-            for model_key, item, result in model_type._iterate_test_cases():
-                if model_key in self.models:
-                    yield self.get(model_key), item, result

    def describe(self, console=None) -> None:
        if not console:
            console = Console()
modelkit/core/model.py (53 changes: 39 additions, 14 deletions)
@@ -23,7 +23,12 @@
from structlog import get_logger

from modelkit.core.settings import LibrarySettings
-from modelkit.core.types import ItemType, ModelTestingConfiguration, ReturnType
+from modelkit.core.types import (
+    ItemType,
+    ModelTestingConfiguration,
+    ReturnType,
+    TestCases,
+)
from modelkit.utils import traceback
from modelkit.utils.cache import Cache, CacheItem
from modelkit.utils.memory import PerformanceTracker
@@ -256,20 +261,40 @@ def __setstate__(self, state):
        self.initialize_validation_models()

    @classmethod
-    def _iterate_test_cases(cls, model_keys=None):
-        if not hasattr(cls, "TEST_CASES"):
-            logger.debug("No TEST_CASES defined", model_type=cls.__name__)
+    def _iterate_test_cases(cls, model_key: Optional[str] = None):
+        if (
+            not hasattr(cls, "TEST_CASES")
+            and not (
+                model_key
+                or any("test_cases" in conf for conf in cls.CONFIGURATIONS.values())
+            )
+            and (model_key and "test_cases" not in cls.CONFIGURATIONS[model_key])
+        ) or (model_key and model_key not in cls.CONFIGURATIONS):
+            logger.debug("No test cases defined", model_type=cls.__name__)
            return
-        if isinstance(cls.TEST_CASES, dict):
-            # This used to be OK with type instantiation but fails with a pydantic
-            # error since 1.18
-            # test_cases = ModelTestingConfiguration[ItemType, ReturnType]
-            test_cases = ModelTestingConfiguration(**cls.TEST_CASES)
-        else:
-            test_cases = cls.TEST_CASES
-        model_keys = model_keys or test_cases.model_keys or cls.CONFIGURATIONS.keys()
+        model_keys = [model_key] if model_key else cls.CONFIGURATIONS.keys()
+        cls_test_cases: List[TestCases] = []
+
+        if hasattr(cls, "TEST_CASES"):
+            if isinstance(cls.TEST_CASES, dict):
+                # This used to be OK with type instantiation but fails with a pydantic
+                # error since 1.18
+                # test_cases = ModelTestingConfiguration[ItemType, ReturnType]
+                cls_test_cases = ModelTestingConfiguration(**cls.TEST_CASES).cases
+            else:
+                cls_test_cases = cls.TEST_CASES.cases

        for model_key in model_keys:
-            for case in test_cases.cases:
+            for case in cls_test_cases:
                yield model_key, case.item, case.result, case.keyword_args
+
+            conf = cls.CONFIGURATIONS[model_key]
+            if "test_cases" not in conf:
+                continue
+            for case in conf["test_cases"]["cases"]:
+                if isinstance(case, dict):
+                    case = TestCases(**case)
+                yield model_key, case.item, case.result, case.keyword_args

def describe(self, t=None):
Expand Down Expand Up @@ -350,7 +375,7 @@ def predict(self, item: ItemType, **kwargs):
    def test(self):
        console = Console()
        for i, (model_key, item, expected, keyword_args) in enumerate(
-            self._iterate_test_cases(model_keys=[self.configuration_key])
+            self._iterate_test_cases(model_key=self.configuration_key)
        ):
            result = None
            try:
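As a usage note, the `test()` method shown above can be called directly on an instantiated model. A minimal sketch, assuming the `TestableModel` from the documentation example is importable:

```python
from modelkit import ModelLibrary

# TestableModel as defined in the docs example above
library = ModelLibrary(models=[TestableModel])
model = library.get("some_model_a")
model.test()  # runs this model's test cases and reports pass/fail on the console
```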
tests/test_auto_testing.py (33 changes: 28 additions, 5 deletions)
@@ -33,6 +33,8 @@ def _predict(self, item, add_one=False):


def test_list_cases():
+    expected = [("some_model", {"x": 1}, {"x": 1}, {})]

    class SomeModel(Model[ModelItemType, ModelItemType]):
        CONFIGURATIONS = {"some_model": {}}

@@ -43,9 +45,9 @@ class SomeModel(Model[ModelItemType, ModelItemType]):
        def _predict(self, item):
            return item

-    assert list(SomeModel._iterate_test_cases()) == [
-        ("some_model", {"x": 1}, {"x": 1}, {})
-    ]
+    assert list(SomeModel._iterate_test_cases()) == expected
+    assert list(SomeModel._iterate_test_cases("some_model")) == expected
+    assert list(SomeModel._iterate_test_cases("unknown_model")) == []

    class TestableModel(Model[ModelItemType, ModelItemType]):
        CONFIGURATIONS = {"some_model": {}}
@@ -55,8 +57,29 @@ class TestableModel(Model[ModelItemType, ModelItemType]):
        def _predict(self, item):
            return item

-    assert list(TestableModel._iterate_test_cases()) == [
-        ("some_model", {"x": 1}, {"x": 1}, {})
-    ]
+    assert list(TestableModel._iterate_test_cases()) == expected
+    assert list(TestableModel._iterate_test_cases("some_model")) == expected
+    assert list(TestableModel._iterate_test_cases("unknown_model")) == []

+    class TestableModel(Model[ModelItemType, ModelItemType]):
+        CONFIGURATIONS = {
+            "some_model": {
+                "test_cases": {"cases": [{"item": {"x": 1}, "result": {"x": 1}}]}
+            },
+            "some_other_model": {},
+        }
+        TEST_CASES = {"cases": [{"item": {"x": 1}, "result": {"x": 1}}]}
+
+        def _predict(self, item):
+            return item
+
+    assert list(TestableModel._iterate_test_cases()) == expected * 2 + [
+        ("some_other_model", {"x": 1}, {"x": 1}, {})
+    ]
+    assert list(TestableModel._iterate_test_cases("some_model")) == expected * 2
+    assert list(TestableModel._iterate_test_cases("unknown_model")) == []
+    assert list(TestableModel._iterate_test_cases("some_other_model")) == [
+        ("some_other_model", {"x": 1}, {"x": 1}, {})
+    ]

