diff --git a/poetry.lock b/poetry.lock index 7662106..65299bb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -699,6 +699,20 @@ toml = "*" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.16.0" +description = "Pytest support for asyncio." +category = "dev" +optional = false +python-versions = ">= 3.6" + +[package.dependencies] +pytest = ">=5.4.0" + +[package.extras] +testing = ["coverage", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "2.12.1" @@ -1156,7 +1170,7 @@ test = ["zope.testing"] [metadata] lock-version = "1.1" python-versions = "^3.6.2" -content-hash = "9452274a46e475f240c1f0fb20aaa1d15c4aa75797a14346f4257646496c98ea" +content-hash = "91d9384aa625f97bbcf6eb4b42872397ecfa8c1d3d7dc5f55b392aa14b2971f1" [metadata.files] alabaster = [ @@ -1504,6 +1518,10 @@ pytest = [ {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, ] +pytest-asyncio = [ + {file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"}, + {file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"}, +] pytest-cov = [ {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"}, {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"}, diff --git a/pyproject.toml b/pyproject.toml index 97a9614..d5da76c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,13 +53,14 @@ python = "^3.6.2" [tool.poetry.dev-dependencies] pytest = "^6.2.4" +pytest-cov = "^2.12.1" +pytest-asyncio = "^0.16.0" mypy = "^0.910" pylint = "^2.9.3" black = "^21.6b0" isort = "^5.9.1" tox = "^3.23.1" Twisted = "^21.2.0" -pytest-cov = "^2.12.1" mkdocs-material = "^7.3.4" markdown-include = "^0.6.0" mkdocstrings = "^0.16.2" diff --git a/src/flexmock/_api.py b/src/flexmock/_api.py index 95bc6d3..00b54b5 100644 --- a/src/flexmock/_api.py +++ b/src/flexmock/_api.py @@ -269,12 +269,16 @@ def _create_placeholder_mock_for_proper_teardown( FlexmockContainer.add_expectation(mock, expectation) def _update_method(self, expectation: "Expectation", name: str) -> None: - method_instance = self._create_mock_method(name) if self._hasattr(self._object, name) and not hasattr(expectation, "_original"): expectation._update_original(name, self._object) expectation._method_type = self._get_method_type(name, expectation._original) if expectation._method_type in SPECIAL_METHODS: expectation._original_function = getattr(self._object, name) + + is_async = hasattr(expectation, "_original") and inspect.iscoroutinefunction( + expectation._original + ) + method_instance = self._create_mock_method(name, is_async) if not inspect.isclass(self._object) or expectation._method_type in SPECIAL_METHODS: method_instance = types.MethodType(method_instance, self._object) expectation._local_override = _setattr(self._object, name, method_instance) @@ -377,7 +381,7 @@ def updated(self: Any) -> Any: setattr(obj, new_name, original) self._create_placeholder_mock_for_proper_teardown(obj, name, original) - def _create_mock_method(self, name: str) -> Callable[..., Any]: + def _create_mock_method(self, name: str, is_async: bool) -> Callable[..., Any]: def _handle_exception_matching(expectation: Expectation) -> None: # pylint: disable=misplaced-bare-raise return_values = expectation._return_values @@ -503,11 +507,11 @@ def _handle_matched_expectation( raise return_value.raises # pylint: disable=raising-bad-type return return_value.value - def mock_method(runtime_self: Any, *kargs: Any, **kwargs: Any) -> Any: - arguments = {"kargs": kargs, "kwargs": kwargs} + def mock_method(runtime_self: Any, *args: Any, **kwargs: Any) -> Any: + arguments = {"kargs": args, "kwargs": kwargs} expectation = FlexmockContainer.get_flexmock_expectation(self, name, arguments) if expectation: - return _handle_matched_expectation(expectation, runtime_self, *kargs, **kwargs) + return _handle_matched_expectation(expectation, runtime_self, *args, **kwargs) # inform the user which expectation(s) for the method were _not_ matched saved_expectations = reversed(FlexmockContainer.get_expectations_with_name(self, name)) error_msg = ( @@ -521,6 +525,11 @@ def mock_method(runtime_self: Any, *kargs: Any, **kwargs: Any) -> Any: ) raise MethodSignatureError(error_msg) + async def async_mock_method(*args: Any, **kwargs: Any) -> Any: + return mock_method(*args, **kwargs) + + if is_async: + return async_mock_method return mock_method diff --git a/tests/some_module.py b/tests/some_module.py index 61715c4..977a82b 100644 --- a/tests/some_module.py +++ b/tests/some_module.py @@ -50,6 +50,10 @@ def module_function(x, y): return x - y +async def async_function(x, y): + return x - y + + def kwargs_only_func1(foo, *, bar, baz=5): return foo + bar + baz diff --git a/tests/test_pytest.py b/tests/test_pytest.py index b787364..5c3d9a2 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -10,6 +10,8 @@ from tests.features import FlexmockTestCase from tests.utils import assert_raises +from . import some_module + def test_module_level_test_for_pytest(): flexmock(foo="bar").should_receive("foo").once() @@ -81,3 +83,128 @@ def test_flexmock_doesnt_override_existing_exception(self): def test_flexmock_doesnt_override_assertion(self): self._setup_failing_expectation() assert False, "Flexmock shouldn't suppress this assertion" + + +class TestAsync: + @pytest.mark.asyncio + async def test_mock_async_instance_method(self): + class Class: + async def method(self): + return "method" + + assert await Class().method() == "method" + flexmock(Class).should_receive("method").and_return("mocked_method") + assert await Class().method() == "mocked_method" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_spy_async_instance_method(self): + class Class: + async def method(self): + return "method" + + assert await Class().method() == "method" + flexmock(Class).should_call("method").and_return("method").once() + assert await Class().method() == "method" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_mock_async_class_method(self): + class Class: + @classmethod + async def classmethod(cls): + return "classmethod" + + assert await Class.classmethod() == "classmethod" + flexmock(Class).should_receive("classmethod").and_return("mocked_classmethod") + assert await Class.classmethod() == "mocked_classmethod" + + instance = Class() + assert await instance.classmethod() == "classmethod" + flexmock(instance).should_receive("classmethod").and_return("mocked_classmethod") + assert await instance.classmethod() == "mocked_classmethod" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_spy_async_class_method(self): + class Class: + @classmethod + async def classmethod(cls): + return "classmethod" + + assert await Class.classmethod() == "classmethod" + flexmock(Class).should_call("classmethod").and_return("classmethod").once() + assert await Class.classmethod() == "classmethod" + + instance = Class() + assert await instance.classmethod() == "classmethod" + flexmock(instance).should_call("classmethod").and_return("classmethod").once() + assert await instance.classmethod() == "classmethod" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_mock_async_static_method(self): + class Class: + @staticmethod + async def staticmethod(): + return "staticmethod" + + assert await Class.staticmethod() == "staticmethod" + flexmock(Class).should_receive("staticmethod").and_return("mocked_staticmethod") + assert await Class.staticmethod() == "mocked_staticmethod" + + instance = Class() + assert await instance.staticmethod() == "staticmethod" + flexmock(instance).should_receive("staticmethod").and_return("mocked_staticmethod") + assert await instance.staticmethod() == "mocked_staticmethod" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_spy_async_static_method(self): + class Class: + @staticmethod + async def staticmethod(): + return "staticmethod" + + assert await Class.staticmethod() == "staticmethod" + flexmock(Class).should_call("staticmethod").and_return("staticmethod").once() + assert await Class.staticmethod() == "staticmethod" + + instance = Class() + assert await instance.staticmethod() == "staticmethod" + flexmock(instance).should_call("staticmethod").and_return("staticmethod").once() + assert await instance.staticmethod() == "staticmethod" + + @pytest.mark.asyncio + async def test_mock_async_function(self): + assert await some_module.async_function(15, 5) == 10 + flexmock(some_module).should_receive("async_function").and_return(20) + assert await some_module.async_function(15, 5) == 20 + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_spy_async_function(self): + assert await some_module.async_function(15, 5) == 10 + flexmock(some_module).should_call("async_function").and_return(10).once() + assert await some_module.async_function(15, 5) == 10 + + @pytest.mark.asyncio + async def test_mock_async_method_with_args(self): + class Class: + async def method(self, arg1): + return arg1 + + assert await Class().method("value") == "value" + flexmock(Class).should_receive("method").with_args("value").and_return("mocked").once() + assert await Class().method("value") == "mocked" + + @pytest.mark.xfail + @pytest.mark.asyncio + async def test_spy_async_method_with_args(self): + class Class: + async def method(self, arg1): + return arg1 + + assert await Class().method("value") == "value" + flexmock(Class).should_call("method").with_args("value").and_return("value").once() + assert await Class().method("value") == "value"