Skip to content

Commit

Permalink
feat: reduce overhead to dispatch method handlers (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Aug 18, 2023
1 parent 8f4f945 commit b222552
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
44 changes: 33 additions & 11 deletions src/dbus_fast/aio/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
from collections import deque
from copy import copy
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Set, Tuple

from .. import introspection as intr
from ..auth import Authenticator, AuthExternal
Expand All @@ -18,7 +18,7 @@
)
from ..errors import AuthError
from ..message import Message
from ..message_bus import BaseMessageBus
from ..message_bus import BaseMessageBus, _block_unexpected_reply
from ..service import ServiceInterface
from .message_reader import build_message_reader
from .proxy_object import ProxyObject
Expand Down Expand Up @@ -173,7 +173,7 @@ class MessageBus(BaseMessageBus):
:vartype connected: bool
"""

__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future", "_pending_futures")

def __init__(
self,
Expand All @@ -193,6 +193,7 @@ def __init__(
self._auth = auth

self._disconnect_future = self._loop.create_future()
self._pending_futures: Set[asyncio.Future] = set()

async def connect(self) -> "MessageBus":
"""Connect this message bus to the DBus daemon.
Expand Down Expand Up @@ -431,24 +432,45 @@ def _make_method_handler(self, interface, method):
if not asyncio.iscoroutinefunction(method.fn):
return super()._make_method_handler(interface, method)

def _coro_method_handler(msg, send_reply):
def done(fut):
negotiate_unix_fd = self._negotiate_unix_fd
msg_body_to_args = ServiceInterface._msg_body_to_args
fn_result_to_body = ServiceInterface._fn_result_to_body

def _coroutine_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
"""A coroutine method handler."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
fut = asyncio.ensure_future(method.fn(interface, *args))
# Hold a strong reference to the future to ensure
# it is not garbage collected before it is done.
self._pending_futures.add(fut)
if (
send_reply is _block_unexpected_reply
or msg.flags.value & NO_REPLY_EXPECTED_VALUE
):
fut.add_done_callback(self._pending_futures.discard)
return

# We only create the closure function if we are actually going to reply
def _done(fut: asyncio.Future) -> None:
"""The callback for when the method is done."""
with send_reply:
result = fut.result()
body, unix_fds = ServiceInterface._fn_result_to_body(
result, method.out_signature_tree
body, unix_fds = fn_result_to_body(
result, method.out_signature_tree, replace_fds=negotiate_unix_fd
)
send_reply(
Message.new_method_return(
msg, method.out_signature, body, unix_fds
)
)

args = ServiceInterface._msg_body_to_args(msg)
fut = asyncio.ensure_future(method.fn(interface, *args))
fut.add_done_callback(done)
fut.add_done_callback(_done)
# Discard the future only after running the done callback
fut.add_done_callback(self._pending_futures.discard)

return _coro_method_handler
return _coroutine_method_handler

async def _auth_readline(self) -> str:
buf = b""
Expand Down
23 changes: 14 additions & 9 deletions src/dbus_fast/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,24 +770,27 @@ def _call(
if not msg.serial:
msg.serial = self.next_serial()

def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err) # type: ignore[misc]

no_reply_expected = not _expects_reply(msg)

# Make sure the return reply handler is installed
# before sending the message to avoid a race condition
# where the reply is lost in case the backend can
# send it right away.
if not no_reply_expected:
self._method_return_handlers[msg.serial] = reply_notify

def _reply_notify(
reply: Optional[Message], err: Optional[Exception]
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)

self._method_return_handlers[msg.serial] = _reply_notify

self.send(msg)

if no_reply_expected:
callback(None, None) # type: ignore[misc]
callback(None, None)

@staticmethod
def _check_callback_type(callback: Callable) -> None:
Expand Down Expand Up @@ -921,7 +924,9 @@ def _make_method_handler(
def _callback_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
result = method_fn(interface, *msg_body_to_args(msg))
"""This is the callback that will be called when a method call is."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
result = method_fn(interface, *args)
if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg):
return
body, fds = fn_result_to_body(
Expand Down
4 changes: 1 addition & 3 deletions src/dbus_fast/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,7 @@ def _remove_bus(interface: "ServiceInterface", bus: "BaseMessageBus") -> None:

@staticmethod
def _msg_body_to_args(msg: Message) -> List[Any]:
if not msg.unix_fds or not signature_contains_type(
msg.signature_tree, msg.body, "h"
):
if not signature_contains_type(msg.signature_tree, msg.body, "h"):
return msg.body

# XXX: This deep copy could be expensive if messages are very
Expand Down

0 comments on commit b222552

Please sign in to comment.