diff --git a/alluka/_self_injecting.py b/alluka/_self_injecting.py index f5753d2f..2b7a3edb 100644 --- a/alluka/_self_injecting.py +++ b/alluka/_self_injecting.py @@ -69,15 +69,18 @@ async def callback(database: Database = alluka.inject(type=Database)) -> None: ``` """ - __slots__ = ("_callback", "_client") + __slots__ = ("_callback", "_client", "_get_client") - def __init__(self, client: alluka.Client, callback: _CallbackSigT, /) -> None: + def __init__( + self, client: typing.Union[alluka.Client, collections.Callable[[], alluka.Client]], callback: _CallbackSigT, / + ) -> None: """Initialise a self injecting callback. Parameters ---------- client - The injection client to use to resolve dependencies. + Either the injection client instance to use to resolve dependencies + or a callback used to get the client instance. callback : alluka.abc.CallbackSig The callback to make self-injecting. @@ -90,7 +93,14 @@ def __init__(self, client: alluka.Client, callback: _CallbackSigT, /) -> None: positionally. """ self._callback = callback - self._client = client + + if isinstance(client, alluka.Client): + self._client: typing.Optional[alluka.Client] = client + self._get_client: typing.Optional[collections.Callable[[], alluka.Client]] = None + + else: + self._client = None + self._get_client = client @typing.overload async def __call__( @@ -112,7 +122,12 @@ async def __call__( **kwargs: typing.Any, ) -> _T: # <>. - return await self._client.call_with_async_di(self._callback, *args, **kwargs) + client = self._client + if not client: + assert self._get_client + client = self._get_client() + + return await client.call_with_async_di(self._callback, *args, **kwargs) @property def callback(self) -> _CallbackSigT: @@ -150,15 +165,18 @@ async def callback(database: Database = alluka.inject(type=Database)) -> None: ``` """ - __slots__ = ("_callback", "_client") + __slots__ = ("_callback", "_client", "_get_client") - def __init__(self, client: alluka.Client, callback: _SyncCallbackT, /) -> None: + def __init__( + self, client: typing.Union[alluka.Client, collections.Callable[[], alluka.Client]], callback: _SyncCallbackT, / + ) -> None: """Initialise a sync self injecting callback. Parameters ---------- client - The injection client to use to resolve dependencies. + Either the injection client instance to use to resolve dependencies + or a callback used to get the client instance. callback : collections.abc.Callable The callback to make self-injecting. @@ -169,11 +187,23 @@ def __init__(self, client: alluka.Client, callback: _SyncCallbackT, /) -> None: positionally. """ self._callback = callback - self._client = client + + if isinstance(client, alluka.Client): + self._client: typing.Optional[alluka.Client] = client + self._get_client: typing.Optional[collections.Callable[[], alluka.Client]] = None + + else: + self._client = None + self._get_client = client def __call__(self: SelfInjecting[collections.Callable[..., _T]], *args: typing.Any, **kwargs: typing.Any) -> _T: # <>. - return self._client.call_with_di(self._callback, *args, **kwargs) + client = self._client + if not client: + assert self._get_client + client = self._get_client() + + return client.call_with_di(self._callback, *args, **kwargs) @property def callback(self) -> _SyncCallbackT: diff --git a/alluka/local.py b/alluka/local.py new file mode 100644 index 00000000..31fe7e02 --- /dev/null +++ b/alluka/local.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# BSD 3-Clause License +# +# Copyright (c) 2020-2022, Faster Speeding +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Standard functions for using a context local dependency injection client. + +.. note:: + This module's functionality will only work if `initialize` has been called + to set the DI client for the local scope and you will most likely want to + call this in your `__init__.py` file to set the DI client for the main + thread. +""" +from __future__ import annotations + +__all__ = ["as_self_async_injecting", "as_self_injecting", "call_with_async_di", "call_with_di", "get", "initialize"] + +import contextvars +import typing + +from . import _client +from . import _self_injecting +from . import abc + +if typing.TYPE_CHECKING: + from collections import abc as collections + + _CallbackSigT = typing.TypeVar("_CallbackSigT", bound=abc.CallbackSig[typing.Any]) + _DefaultT = typing.TypeVar("_DefaultT") + _SyncCallbackSigT = typing.TypeVar("_SyncCallbackSigT", bound=collections.Callable[..., typing.Any]) + _T = typing.TypeVar("_T") + + +_CVAR_NAME: typing.Final[str] = "alluka_injector" +_injector = contextvars.ContextVar[abc.Client](_CVAR_NAME) + + +def initialize(client: typing.Optional[abc.Client] = None, /) -> None: + """Link or initialise an injection client for the current context. + + This uses the contextvars package to store the client. + + Parameters + ---------- + client + If provided, this will be set as the client for the current context. + If not provided, a new client will be created. + + Raises + ------ + RuntimeError + If the local client is already initialised. + """ + if _injector.get(None) is not None: + raise RuntimeError("Alluka client already initialised in the current context") + + client = client or _client.Client() + _injector.set(client) + + +@typing.overload +def get() -> abc.Client: + ... + + +@typing.overload +def get(*, default: _DefaultT) -> typing.Union[abc.Client, _DefaultT]: + ... + + +def get(*, default: _DefaultT = ...) -> typing.Union[abc.Client, _DefaultT]: + """Get the local client for the current context. + + Parameters + ---------- + default + The value to return if the client is not initialised. + + If not provided, a RuntimeError will be raised instead. + + Returns + ------- + alluka.abc.Client | _DefaultT + The client for the local context, or the default value if the client + is not initialised. + + Raises + ------ + RuntimeError + If the client is not initialised and no default value was provided. + """ + client = _injector.get(None) + if client is None: + if default is not ...: + return default + + raise RuntimeError("Alluka client not initialised in the current context") + + return client + + +def call_with_di(callback: collections.Callable[..., _T], *args: typing.Any, **kwargs: typing.Any) -> _T: + """Use the local client to call a callback with DI. + + Parameters + ---------- + callback + The callback to call. + *args + Positional arguments to passthrough to the callback. + **kwargs + Keyword arguments to passthrough to the callback. + + Returns + ------- + _T + The result of the call. + """ + return get().call_with_di(callback, *args, **kwargs) + + +@typing.overload +async def call_with_async_di( + callback: collections.Callable[..., collections.Coroutine[typing.Any, typing.Any, _T]], + *args: typing.Any, + **kwargs: typing.Any, +) -> _T: + ... + + +@typing.overload +async def call_with_async_di(callback: collections.Callable[..., _T], *args: typing.Any, **kwargs: typing.Any) -> _T: + ... + + +async def call_with_async_di(callback: abc.CallbackSig[_T], *args: typing.Any, **kwargs: typing.Any) -> _T: + """Use the local client to call a callback with async DI. + + Parameters + ---------- + callback + The callback to call. + *args + Positional arguments to passthrough to the callback. + **kwargs + Keyword arguments to passthrough to the callback. + + Returns + ------- + _T + The result of the call. + """ + return await get().call_with_async_di(callback, *args, **kwargs) + + +def as_self_async_injecting(callback: _CallbackSigT, /) -> _self_injecting.AsyncSelfInjecting[_CallbackSigT]: + """Mark a callback as self async injecting using the local DI client. + + Parameters + ---------- + callback + The callback to mark as self-injecting. + + Returns + ------- + alluka.self_injecting.AsyncSelfInjecting + The self-injecting callback. + """ + return _self_injecting.AsyncSelfInjecting(get, callback) + + +def as_self_injecting(callback: _SyncCallbackSigT, /) -> _self_injecting.SelfInjecting[_SyncCallbackSigT]: + """Mark a callback as self-injecting using the local DI client. + + Parameters + ---------- + callback + The callback to mark as self-injecting. + + Returns + ------- + alluka.self_injecting.SelfInjecting + The self-injecting callback. + """ + return _self_injecting.SelfInjecting(get, callback) diff --git a/tests/test__self_injecting.py b/tests/test__self_injecting.py index 218a30b5..816b68f0 100644 --- a/tests/test__self_injecting.py +++ b/tests/test__self_injecting.py @@ -39,7 +39,7 @@ class TestAsyncSelfInjecting: @pytest.mark.anyio() async def test_call_dunder_method(self): mock_callback = mock.Mock() - mock_client = mock.AsyncMock() + mock_client = mock.AsyncMock(alluka.abc.Client) self_injecting = alluka.AsyncSelfInjecting(mock_client, mock_callback) result = await self_injecting() @@ -47,6 +47,18 @@ async def test_call_dunder_method(self): assert result is mock_client.call_with_async_di.return_value mock_client.call_with_async_di.assert_awaited_once_with(mock_callback) + @pytest.mark.anyio() + async def test_call_dunder_method_with_client_getter(self): + mock_callback = mock.Mock() + mock_client_getter = mock.Mock(return_value=mock.AsyncMock()) + self_injecting = alluka.AsyncSelfInjecting(mock_client_getter, mock_callback) + + result = await self_injecting() + + assert result is mock_client_getter.return_value.call_with_async_di.return_value + mock_client_getter.return_value.call_with_async_di.assert_awaited_once_with(mock_callback) + mock_client_getter.assert_called_once_with() + def test_callback_property(self): mock_callback = mock.Mock() @@ -58,7 +70,7 @@ def test_callback_property(self): class TestSelfInjecting: def test_call_dunder_method(self): mock_callback = mock.Mock() - mock_client = mock.Mock() + mock_client = mock.Mock(alluka.abc.Client) self_injecting = alluka.SelfInjecting(mock_client, mock_callback) result = self_injecting() @@ -66,6 +78,17 @@ def test_call_dunder_method(self): assert result is mock_client.call_with_di.return_value mock_client.call_with_di.assert_called_once_with(mock_callback) + def test_call_dunder_method_with_client_getter(self): + mock_callback = mock.Mock() + mock_client_getter = mock.Mock() + self_injecting = alluka.SelfInjecting(mock_client_getter, mock_callback) + + result = self_injecting() + + assert result is mock_client_getter.return_value.call_with_di.return_value + mock_client_getter.return_value.call_with_di.assert_called_once_with(mock_callback) + mock_client_getter.assert_called_once_with() + def test_callback_property(self): mock_callback = mock.Mock() diff --git a/tests/test_local.py b/tests/test_local.py new file mode 100644 index 00000000..0b5e7490 --- /dev/null +++ b/tests/test_local.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# BSD 3-Clause License +# +# Copyright (c) 2020-2022, Faster Speeding +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from unittest import mock + +import pytest + +import alluka +import alluka.local + + +def test_initialize(): + alluka.local.initialize() + + assert isinstance(alluka.local.get(), alluka.Client) + + +def test_initialize_when_passed_through(): + mock_client = mock.Mock() + alluka.local.initialize(mock_client) + + assert alluka.local.get() is mock_client + + +def test_initialize_when_already_set(): + alluka.local.initialize() + + with pytest.raises(RuntimeError, match="Alluka client already initialised in the current context"): + alluka.local.initialize() + + +def test_initialize_when_passed_through_and_already_set(): + alluka.local.initialize() + + with pytest.raises(RuntimeError, match="Alluka client already initialised in the current context"): + alluka.local.initialize(mock.Mock()) + + +def test_get(): + mock_client = mock.Mock() + alluka.local.initialize(mock_client) + + assert alluka.local.get() is mock_client + + +def test_get_when_not_set(): + with pytest.raises(RuntimeError, match="Alluka client not initialised in the current context"): + alluka.local.get() + + +def test_get_when_not_set_and_default(): + result = alluka.local.get(default=None) + + assert result is None + + +def test_call_with_di(): + mock_client = mock.Mock() + alluka.local.initialize(mock_client) + mock_callback = mock.Mock() + + result = alluka.local.call_with_di(mock_callback, 123, 321, 123, 321, hello="Ok", bye="meow") + + assert result is mock_client.call_with_di.return_value + mock_client.call_with_di.assert_called_once_with(mock_callback, 123, 321, 123, 321, hello="Ok", bye="meow") + + +@pytest.mark.anyio() +async def test_call_with_async_di(): + mock_client = mock.AsyncMock() + alluka.local.initialize(mock_client) + mock_callback = mock.Mock() + + result = await alluka.local.call_with_async_di(mock_callback, 69, 320, hello="goodbye") + + assert result is mock_client.call_with_async_di.return_value + mock_client.call_with_async_di.assert_awaited_once_with(mock_callback, 69, 320, hello="goodbye") + + +def test_as_self_async_injecting(): + mock_callback = mock.Mock() + + result = alluka.local.as_self_async_injecting(mock_callback) + + assert isinstance(result, alluka.AsyncSelfInjecting) + assert result._get_client is alluka.local.get + assert result.callback is mock_callback + + +def test_as_self_injecting(): + mock_callback = mock.Mock() + + result = alluka.local.as_self_injecting(mock_callback) + + assert isinstance(result, alluka.SelfInjecting) + assert result._get_client is alluka.local.get + assert result.callback is mock_callback