Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delegate resource exceptions #103

Merged
merged 7 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions anydi/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,9 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
) -> bool:
"""Exit the singleton context."""
self.close()
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)

def start(self) -> None:
"""Start the singleton context."""
Expand Down Expand Up @@ -464,9 +464,9 @@ async def __aexit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
) -> bool:
"""Exit the singleton context."""
await self.aclose()
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)

async def astart(self) -> None:
"""Start the singleton context asynchronously."""
Expand Down
17 changes: 8 additions & 9 deletions anydi/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,15 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
) -> bool:
"""Exit the context.

Args:
exc_type: The type of the exception, if any.
exc_val: The exception instance, if any.
exc_tb: The traceback, if any.
"""
self.close()
return
return self._stack.__exit__(exc_type, exc_val, exc_tb)

@abc.abstractmethod
def start(self) -> None:
Expand All @@ -262,7 +261,7 @@ def start(self) -> None:

def close(self) -> None:
"""Close the scoped context."""
self._stack.close()
self._stack.__exit__(None, None, None)

async def __aenter__(self) -> Self:
"""Enter the context asynchronously.
Expand All @@ -278,25 +277,25 @@ async def __aexit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
) -> bool:
"""Exit the context asynchronously.

Args:
exc_type: The type of the exception, if any.
exc_val: The exception instance, if any.
exc_tb: The traceback, if any.
"""
await self.aclose()
return
return await run_async(
self.__exit__, exc_type, exc_val, exc_tb
) or await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)

@abc.abstractmethod
async def astart(self) -> None:
"""Start the scoped context asynchronously."""

async def aclose(self) -> None:
"""Close the scoped context asynchronously."""
await run_async(self._stack.close)
await self._async_stack.aclose()
await self.__aexit__(None, None, None)


@final
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "anydi"
version = "0.26.7"
version = "0.26.8a0"
description = "Dependency Injection library"
authors = ["Anton Ruhlov <antonruhlov@gmail.com>"]
license = "MIT"
Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,19 @@ class Service:
def __init__(self, ident: str) -> None:
self.ident = ident
self.events: list[str] = []


class Resource:
def __init__(self) -> None:
self.called = False
self.committed = False
self.rolled_back = False

def run(self) -> None:
self.called = True

def commit(self) -> None:
self.committed = True

def rollback(self) -> None:
self.rolled_back = True
73 changes: 59 additions & 14 deletions tests/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from anydi import Container, Provider, Scope, auto, dep, request, singleton, transient

from tests.fixtures import Service
from tests.fixtures import Resource, Service


@pytest.fixture
Expand Down Expand Up @@ -877,14 +877,14 @@ def test_release_instance(container: Container) -> None:

def test_override_instance(container: Container) -> None:
origin_name = "origin"
overriden_name = "overriden"
overridden_name = "overridden"

@container.provider(scope="singleton")
def name() -> str:
return origin_name

with container.override(str, overriden_name):
assert container.resolve(str) == overriden_name
with container.override(str, overridden_name):
assert container.resolve(str) == overridden_name

assert container.resolve(str) == origin_name

Expand All @@ -900,42 +900,87 @@ def test_override_instance_provider_not_registered_using_strict_mode() -> None:


def test_override_instance_transient_provider(container: Container) -> None:
overriden_uuid = uuid.uuid4()
overridden_uuid = uuid.uuid4()

@container.provider(scope="transient")
def uuid_provider() -> uuid.UUID:
return uuid.uuid4()

with container.override(uuid.UUID, overriden_uuid):
assert container.resolve(uuid.UUID) == overriden_uuid
with container.override(uuid.UUID, overridden_uuid):
assert container.resolve(uuid.UUID) == overridden_uuid

assert container.resolve(uuid.UUID) != overriden_uuid
assert container.resolve(uuid.UUID) != overridden_uuid


def test_override_instance_resource_provider(container: Container) -> None:
origin = "origin"
overriden = "overriden"
overridden = "overridden"

@container.provider(scope="singleton")
def message() -> Iterator[str]:
yield origin

with container.override(str, overriden):
assert container.resolve(str) == overriden
with container.override(str, overridden):
assert container.resolve(str) == overridden

assert container.resolve(str) == origin


async def test_override_instance_async_resource_provider(container: Container) -> None:
origin = "origin"
overriden = "overriden"
overridden = "overridden"

@container.provider(scope="singleton")
async def message() -> AsyncIterator[str]:
yield origin

with container.override(str, overriden):
assert container.resolve(str) == overriden
with container.override(str, overridden):
assert container.resolve(str) == overridden


def test_resource_delegated_exception(container: Container) -> None:
@container.provider(scope="request")
def resource_provider() -> Iterator[Resource]:
resource = Resource()
try:
yield resource
except Exception: # noqa
resource.rollback()
raise
else:
resource.commit()

with pytest.raises(ValueError), container.request_context():
resource = container.resolve(Resource)
resource.run()
raise ValueError

assert resource.called
assert not resource.committed
assert resource.rolled_back


async def test_async_resource_delegated_exception(container: Container) -> None:
@container.provider(scope="request")
async def resource_provider() -> AsyncIterator[Resource]:
resource = Resource()
try:
yield resource
except Exception: # noqa
resource.rollback()
raise
else:
resource.commit()

with pytest.raises(ValueError):
async with container.arequest_context():
resource = await container.aresolve(Resource)
resource.run()
raise ValueError

assert resource.called
assert not resource.committed
assert resource.rolled_back


# Inspections
Expand Down
Loading