Merge pull request #41 from clustree/add-tests-configurations
Add tests configurations
antoinejeannot committed Jun 18, 2021
2 parents 449360f + 5a54ca5 commit b21999e
Showing 6 changed files with 110 additions and 62 deletions.
42 changes: 38 additions & 4 deletions docs/library/testing.md
@@ -25,23 +25,57 @@ In addition, it creates a test called `test_auto_model_library` that iterates th

### Defining test cases

Any `modelkit.core.Model` can define its own test cases, which are discoverable by the test created by `make_modellibrary_test`.

There are two ways of defining test cases.

#### Adding `TEST_CASES` as a class attribute

Test cases added to the `TEST_CASES` class attribute are shared across all the models defined in the `CONFIGURATIONS` map.

In the following example, four test cases will be run:
- 2 for `some_model_a`
- 2 for `some_model_b`

```python
class TestableModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {"some_model_a": {}, "some_model_b": {}}

    TEST_CASES = [
        {"item": {"x": 1}, "result": {"x": 1}},
        {"item": {"x": 2}, "result": {"x": 2}},
    ]

    def _predict(self, item):
        return item
```
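
These shared cases can also be run interactively through `Model.test()` (visible in the `modelkit/core/model.py` diff below). A minimal sketch, assuming the usual `ModelLibrary` entry point:

```python
from modelkit import ModelLibrary

# Instantiate a library containing the model above, then run the class-level
# test cases for one configuration; outcomes are reported on the console.
library = ModelLibrary(models=[TestableModel])
library.get("some_model_a").test()
```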

#### Adding `test_cases` to the `CONFIGURATIONS` map

Test cases added to an entry of the `CONFIGURATIONS` map are restricted to that entry's model.

In the following example, two test cases will be run for `some_model_a`:

```python
class TestableModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {
        "some_model_a": {
            "test_cases": [
                {"item": {"x": 1}, "result": {"x": 1}},
                {"item": {"x": 2}, "result": {"x": 2}},
            ],
        },
        "some_model_b": {},
    }

    def _predict(self, item):
        return item

```
Both ways of defining test cases can be used simultaneously: for a given model, the shared `TEST_CASES` are run in addition to its configuration-specific `test_cases`, as in the sketch below.
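
A minimal sketch combining both mechanisms, adapted from this commit's `tests/test_auto_testing.py` (the class name `CombinedModel` is illustrative): three cases are collected in total, two for `some_model_a` and one for `some_model_b`.

```python
class CombinedModel(Model[ModelItemType, ModelItemType]):
    CONFIGURATIONS: Dict[str, Dict] = {
        # shared case + one specific case -> two cases for some_model_a
        "some_model_a": {"test_cases": [{"item": {"x": 2}, "result": {"x": 2}}]},
        # shared case only -> one case for some_model_b
        "some_model_b": {},
    }

    # shared across every configuration of this class
    TEST_CASES = [{"item": {"x": 1}, "result": {"x": 1}}]

    def _predict(self, item):
        return item
```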

Each test case is instantiated with an `item` value, a `result` value and optional `keyword_args`; the automatic test iterates through them and runs a check equivalent to the sketch below.
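
A minimal sketch, assuming a plain equality assertion between each prediction and its expected result:

```python
from modelkit import ModelLibrary

library = ModelLibrary(models=[TestableModel])
for model_key, item, expected, keyword_args in TestableModel._iterate_test_cases():
    model = library.get(model_key)
    assert model.predict(item, **keyword_args) == expected
```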

7 changes: 0 additions & 7 deletions modelkit/core/library.py
@@ -366,13 +366,6 @@ async def aclose(self):
if isinstance(model, AsyncModel):
await model.close()

def _iterate_test_cases(self):
model_types = {type(model_type) for model_type in self._models.values()}
for model_type in model_types:
for model_key, item, result in model_type._iterate_test_cases():
if model_key in self.models:
yield self.get(model_key), item, result

def describe(self, console=None) -> None:
if not console:
console = Console()
48 changes: 31 additions & 17 deletions modelkit/core/model.py
@@ -23,7 +23,7 @@
from structlog import get_logger

from modelkit.core.settings import LibrarySettings
from modelkit.core.types import ItemType, ModelTestingConfiguration, ReturnType
from modelkit.core.types import ItemType, ReturnType, TestCase
from modelkit.utils import traceback
from modelkit.utils.cache import Cache, CacheItem
from modelkit.utils.memory import PerformanceTracker
@@ -187,9 +187,7 @@ class AbstractModel(Asset, Generic[ItemType, ReturnType]):
that either take items or lists of items.
"""

# The correct type below raises an error with pydantic after version 0.18
# TEST_CASES: Union[ModelTestingConfiguration[ItemType, ReturnType], Dict]
TEST_CASES: Any
TEST_CASES: List[Union[TestCase[ItemType, ReturnType], Dict]]

def __init__(
self,
@@ -256,20 +254,36 @@ def __setstate__(self, state):
self.initialize_validation_models()

@classmethod
def _iterate_test_cases(cls, model_keys=None):
if not hasattr(cls, "TEST_CASES"):
logger.debug("No TEST_CASES defined", model_type=cls.__name__)
def _iterate_test_cases(cls, model_key: Optional[str] = None):
if (
not hasattr(cls, "TEST_CASES")
and not (
model_key
or any("test_cases" in conf for conf in cls.CONFIGURATIONS.values())
)
and (model_key and "test_cases" not in cls.CONFIGURATIONS[model_key])
) or (model_key and model_key not in cls.CONFIGURATIONS):
logger.debug("No test cases defined", model_type=cls.__name__)
return
if isinstance(cls.TEST_CASES, dict):
# This used to be OK with type instantiation but fails with a pydantic
# error since 1.18
# test_cases = ModelTestingConfiguration[ItemType, ReturnType]
test_cases = ModelTestingConfiguration(**cls.TEST_CASES)
else:
test_cases = cls.TEST_CASES
model_keys = model_keys or test_cases.model_keys or cls.CONFIGURATIONS.keys()

model_keys = [model_key] if model_key else cls.CONFIGURATIONS.keys()
cls_test_cases: List[Union[TestCase[ItemType, ReturnType], Dict]] = []

if hasattr(cls, "TEST_CASES"):
cls_test_cases = cls.TEST_CASES

for model_key in model_keys:
for case in test_cases.cases:
for case in cls_test_cases:
if isinstance(case, dict):
case = TestCase(**case)
yield model_key, case.item, case.result, case.keyword_args

conf = cls.CONFIGURATIONS[model_key]
if "test_cases" not in conf:
continue
for case in conf["test_cases"]:
if isinstance(case, dict):
case = TestCase(**case)
yield model_key, case.item, case.result, case.keyword_args

def describe(self, t=None):
@@ -350,7 +364,7 @@ def predict(self, item: ItemType, **kwargs):
def test(self):
console = Console()
for i, (model_key, item, expected, keyword_args) in enumerate(
self._iterate_test_cases(model_keys=[self.configuration_key])
self._iterate_test_cases(model_key=self.configuration_key)
):
result = None
try:
11 changes: 2 additions & 9 deletions modelkit/core/types.py
@@ -1,5 +1,5 @@
from types import ModuleType
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Type, TypeVar, Union

import pydantic
import pydantic.generics
@@ -17,14 +17,7 @@
LibraryModelsType = Union[ModuleType, Type, List, str]


class TestCases(pydantic.generics.GenericModel, Generic[TestItemType, TestReturnType]):
class TestCase(pydantic.generics.GenericModel, Generic[TestItemType, TestReturnType]):
item: TestItemType
result: TestReturnType
keyword_args: Dict[str, Any] = {}


class ModelTestingConfiguration(
pydantic.generics.GenericModel, Generic[TestItemType, TestReturnType]
):
model_keys: Optional[List[str]]
cases: List[TestCases[TestItemType, TestReturnType]]
50 changes: 33 additions & 17 deletions tests/test_auto_testing.py
@@ -3,7 +3,6 @@
import pydantic

from modelkit.core.model import Model
from modelkit.core.types import ModelTestingConfiguration
from modelkit.testing import modellibrary_auto_test, modellibrary_fixture


@@ -18,13 +17,11 @@ class ModelReturnType(pydantic.BaseModel):
class TestableModel(Model[ModelItemType, ModelItemType]):
CONFIGURATIONS: Dict[str, Dict] = {"some_model": {}}

TEST_CASES = {
"cases": [
{"item": {"x": 1}, "result": {"x": 1}},
{"item": {"x": 2}, "result": {"x": 2}},
{"item": {"x": 1}, "result": {"x": 2}, "keyword_args": {"add_one": True}},
]
}
TEST_CASES = [
{"item": {"x": 1}, "result": {"x": 1}},
{"item": {"x": 2}, "result": {"x": 2}},
{"item": {"x": 1}, "result": {"x": 2}, "keyword_args": {"add_one": True}},
]

def _predict(self, item, add_one=False):
if add_one:
@@ -33,30 +30,49 @@ def _predict(self, item, add_one=False):


def test_list_cases():
expected = [("some_model", {"x": 1}, {"x": 1}, {})]

class SomeModel(Model[ModelItemType, ModelItemType]):
CONFIGURATIONS = {"some_model": {}}

TEST_CASES = ModelTestingConfiguration(
cases=[{"item": {"x": 1}, "result": {"x": 1}}]
)
TEST_CASES = [{"item": {"x": 1}, "result": {"x": 1}}]

def _predict(self, item):
return item

assert list(SomeModel._iterate_test_cases()) == [
("some_model", {"x": 1}, {"x": 1}, {})
]
assert list(SomeModel._iterate_test_cases()) == expected
assert list(SomeModel._iterate_test_cases("some_model")) == expected
assert list(SomeModel._iterate_test_cases("unknown_model")) == []

class TestableModel(Model[ModelItemType, ModelItemType]):
CONFIGURATIONS = {"some_model": {}}

TEST_CASES = {"cases": [{"item": {"x": 1}, "result": {"x": 1}}]}
TEST_CASES = [{"item": {"x": 1}, "result": {"x": 1}}]

def _predict(self, item):
return item

assert list(TestableModel._iterate_test_cases()) == expected
assert list(TestableModel._iterate_test_cases("some_model")) == expected
assert list(TestableModel._iterate_test_cases("unknown_model")) == []

class TestableModel(Model[ModelItemType, ModelItemType]):
CONFIGURATIONS = {
"some_model": {"test_cases": [{"item": {"x": 1}, "result": {"x": 1}}]},
"some_other_model": {},
}
TEST_CASES = [{"item": {"x": 1}, "result": {"x": 1}}]

def _predict(self, item):
return item

assert list(TestableModel._iterate_test_cases()) == [
("some_model", {"x": 1}, {"x": 1}, {})
assert list(TestableModel._iterate_test_cases()) == expected * 2 + [
("some_other_model", {"x": 1}, {"x": 1}, {})
]
assert list(TestableModel._iterate_test_cases("some_model")) == expected * 2
assert list(TestableModel._iterate_test_cases("unknown_model")) == []
assert list(TestableModel._iterate_test_cases("some_other_model")) == [
("some_other_model", {"x": 1}, {"x": 1}, {})
]


14 changes: 6 additions & 8 deletions tests/test_core.py
@@ -427,14 +427,12 @@ def _predict(self, item, **kwargs):
assert prediction.endswith(os.path.join("category", "override-asset", "1.0"))


SYNC_ASYNC_TEST_CASES = {
"cases": [
{"item": "", "result": 0},
{"item": "a", "result": 1},
{"item": ["a", "b", "c"], "result": 3},
{"item": range(100), "result": 100},
]
}
SYNC_ASYNC_TEST_CASES = [
{"item": "", "result": 0},
{"item": "a", "result": 1},
{"item": ["a", "b", "c"], "result": 3},
{"item": range(100), "result": 100},
]


def test_model_sync_test():
