diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 435b4c48..881873a2 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -52,6 +52,13 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None: pass +if sys.version_info < (3, 12): + def markcoroutinefunction(c): # type: ignore + pass +else: + from inspect import markcoroutinefunction # noqa: F401 + + if sys.version_info < (3, 12): from ._asyncio_compat import wait_for as wait_for # noqa: F401 else: diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 06e698df..7b588c27 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -33,7 +33,8 @@ def __new__(mcls, name, bases, dct, *, wrap=False): if not inspect.isfunction(meth): continue - wrapper = mcls._wrap_connection_method(attrname) + iscoroutine = inspect.iscoroutinefunction(meth) + wrapper = mcls._wrap_connection_method(attrname, iscoroutine) wrapper = functools.update_wrapper(wrapper, meth) dct[attrname] = wrapper @@ -43,7 +44,7 @@ def __new__(mcls, name, bases, dct, *, wrap=False): return super().__new__(mcls, name, bases, dct) @staticmethod - def _wrap_connection_method(meth_name): + def _wrap_connection_method(meth_name, iscoroutine): def call_con_method(self, *args, **kwargs): # This method will be owned by PoolConnectionProxy class. if self._con is None: @@ -55,6 +56,9 @@ def call_con_method(self, *args, **kwargs): meth = getattr(self._con.__class__, meth_name) return meth(self._con, *args, **kwargs) + if iscoroutine: + compat.markcoroutinefunction(call_con_method) + return call_con_method