Skip to content

Commit

Permalink
Add support for mocking async instance methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ollipa committed Dec 16, 2021
1 parent 58fe2b9 commit ef3df8d
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 7 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Expand Up @@ -53,13 +53,14 @@ python = "^3.6.2"

[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
pytest-cov = "^2.12.1"
pytest-asyncio = "^0.16.0"
mypy = "^0.910"
pylint = "^2.9.3"
black = "^21.6b0"
isort = "^5.9.1"
tox = "^3.23.1"
Twisted = "^21.2.0"
pytest-cov = "^2.12.1"
mkdocs-material = "^7.3.4"
markdown-include = "^0.6.0"
mkdocstrings = "^0.16.2"
Expand Down
19 changes: 14 additions & 5 deletions src/flexmock/_api.py
Expand Up @@ -269,12 +269,16 @@ def _create_placeholder_mock_for_proper_teardown(
FlexmockContainer.add_expectation(mock, expectation)

def _update_method(self, expectation: "Expectation", name: str) -> None:
method_instance = self._create_mock_method(name)
if self._hasattr(self._object, name) and not hasattr(expectation, "_original"):
expectation._update_original(name, self._object)
expectation._method_type = self._get_method_type(name, expectation._original)
if expectation._method_type in SPECIAL_METHODS:
expectation._original_function = getattr(self._object, name)

is_async = hasattr(expectation, "_original") and inspect.iscoroutinefunction(
expectation._original
)
method_instance = self._create_mock_method(name, is_async)
if not inspect.isclass(self._object) or expectation._method_type in SPECIAL_METHODS:
method_instance = types.MethodType(method_instance, self._object)
expectation._local_override = _setattr(self._object, name, method_instance)
Expand Down Expand Up @@ -377,7 +381,7 @@ def updated(self: Any) -> Any:
setattr(obj, new_name, original)
self._create_placeholder_mock_for_proper_teardown(obj, name, original)

def _create_mock_method(self, name: str) -> Callable[..., Any]:
def _create_mock_method(self, name: str, is_async: bool) -> Callable[..., Any]:
def _handle_exception_matching(expectation: Expectation) -> None:
# pylint: disable=misplaced-bare-raise
return_values = expectation._return_values
Expand Down Expand Up @@ -503,11 +507,11 @@ def _handle_matched_expectation(
raise return_value.raises # pylint: disable=raising-bad-type
return return_value.value

def mock_method(runtime_self: Any, *kargs: Any, **kwargs: Any) -> Any:
arguments = {"kargs": kargs, "kwargs": kwargs}
def mock_method(runtime_self: Any, *args: Any, **kwargs: Any) -> Any:
arguments = {"kargs": args, "kwargs": kwargs}
expectation = FlexmockContainer.get_flexmock_expectation(self, name, arguments)
if expectation:
return _handle_matched_expectation(expectation, runtime_self, *kargs, **kwargs)
return _handle_matched_expectation(expectation, runtime_self, *args, **kwargs)
# inform the user which expectation(s) for the method were _not_ matched
saved_expectations = reversed(FlexmockContainer.get_expectations_with_name(self, name))
error_msg = (
Expand All @@ -521,6 +525,11 @@ def mock_method(runtime_self: Any, *kargs: Any, **kwargs: Any) -> Any:
)
raise MethodSignatureError(error_msg)

async def async_mock_method(*args: Any, **kwargs: Any) -> Any:
return mock_method(*args, **kwargs)

if is_async:
return async_mock_method
return mock_method


Expand Down
4 changes: 4 additions & 0 deletions tests/some_module.py
Expand Up @@ -50,6 +50,10 @@ def module_function(x, y):
return x - y


async def async_function(x, y):
return x - y


def kwargs_only_func1(foo, *, bar, baz=5):
return foo + bar + baz

Expand Down
127 changes: 127 additions & 0 deletions tests/test_pytest.py
Expand Up @@ -10,6 +10,8 @@
from tests.features import FlexmockTestCase
from tests.utils import assert_raises

from . import some_module


def test_module_level_test_for_pytest():
flexmock(foo="bar").should_receive("foo").once()
Expand Down Expand Up @@ -81,3 +83,128 @@ def test_flexmock_doesnt_override_existing_exception(self):
def test_flexmock_doesnt_override_assertion(self):
self._setup_failing_expectation()
assert False, "Flexmock shouldn't suppress this assertion"


class TestAsync:
@pytest.mark.asyncio
async def test_mock_async_instance_method(self):
class Class:
async def method(self):
return "method"

assert await Class().method() == "method"
flexmock(Class).should_receive("method").and_return("mocked_method")
assert await Class().method() == "mocked_method"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_spy_async_instance_method(self):
class Class:
async def method(self):
return "method"

assert await Class().method() == "method"
flexmock(Class).should_call("method").and_return("method").once()
assert await Class().method() == "method"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_mock_async_class_method(self):
class Class:
@classmethod
async def classmethod(cls):
return "classmethod"

assert await Class.classmethod() == "classmethod"
flexmock(Class).should_receive("classmethod").and_return("mocked_classmethod")
assert await Class.classmethod() == "mocked_classmethod"

instance = Class()
assert await instance.classmethod() == "classmethod"
flexmock(instance).should_receive("classmethod").and_return("mocked_classmethod")
assert await instance.classmethod() == "mocked_classmethod"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_spy_async_class_method(self):
class Class:
@classmethod
async def classmethod(cls):
return "classmethod"

assert await Class.classmethod() == "classmethod"
flexmock(Class).should_call("classmethod").and_return("classmethod").once()
assert await Class.classmethod() == "classmethod"

instance = Class()
assert await instance.classmethod() == "classmethod"
flexmock(instance).should_call("classmethod").and_return("classmethod").once()
assert await instance.classmethod() == "classmethod"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_mock_async_static_method(self):
class Class:
@staticmethod
async def staticmethod():
return "staticmethod"

assert await Class.staticmethod() == "staticmethod"
flexmock(Class).should_receive("staticmethod").and_return("mocked_staticmethod")
assert await Class.staticmethod() == "mocked_staticmethod"

instance = Class()
assert await instance.staticmethod() == "staticmethod"
flexmock(instance).should_receive("staticmethod").and_return("mocked_staticmethod")
assert await instance.staticmethod() == "mocked_staticmethod"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_spy_async_static_method(self):
class Class:
@staticmethod
async def staticmethod():
return "staticmethod"

assert await Class.staticmethod() == "staticmethod"
flexmock(Class).should_call("staticmethod").and_return("staticmethod").once()
assert await Class.staticmethod() == "staticmethod"

instance = Class()
assert await instance.staticmethod() == "staticmethod"
flexmock(instance).should_call("staticmethod").and_return("staticmethod").once()
assert await instance.staticmethod() == "staticmethod"

@pytest.mark.asyncio
async def test_mock_async_function(self):
assert await some_module.async_function(15, 5) == 10
flexmock(some_module).should_receive("async_function").and_return(20)
assert await some_module.async_function(15, 5) == 20

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_spy_async_function(self):
assert await some_module.async_function(15, 5) == 10
flexmock(some_module).should_call("async_function").and_return(10).once()
assert await some_module.async_function(15, 5) == 10

@pytest.mark.asyncio
async def test_mock_async_method_with_args(self):
class Class:
async def method(self, arg1):
return arg1

assert await Class().method("value") == "value"
flexmock(Class).should_receive("method").with_args("value").and_return("mocked").once()
assert await Class().method("value") == "mocked"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_spy_async_method_with_args(self):
class Class:
async def method(self, arg1):
return arg1

assert await Class().method("value") == "value"
flexmock(Class).should_call("method").with_args("value").and_return("value").once()
assert await Class().method("value") == "value"

0 comments on commit ef3df8d

Please sign in to comment.