Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add event broadcasting capability #672

Merged
merged 17 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Jump to:

Description

- Implement asynchronous notifications for shared data
- Filenames conform to snake case
- Update SmartSim environment variables using new naming convention
- Refactor `exception_handler`
Expand Down
11 changes: 7 additions & 4 deletions smartsim/_core/mli/comm/channel/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import typing as t
from abc import ABC, abstractmethod

Expand All @@ -40,20 +41,22 @@ def __init__(self, descriptor: t.Union[str, bytes]) -> None:
self._descriptor = descriptor

@abstractmethod
def send(self, value: bytes) -> None:
def send(self, value: bytes, timeout: float = 0) -> None:
"""Send a message through the underlying communication channel

:param timeout: maximum time to wait (in seconds) for messages to send
:param value: The value to send"""

@abstractmethod
def recv(self) -> t.List[bytes]:
"""Receieve a message through the underlying communication channel
def recv(self, timeout: float = 0) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel

:param timeout: maximum time to wait (in seconds) for messages to arrive
:returns: the received message"""

@property
def descriptor(self) -> bytes:
"""Return the channel descriptor for the underlying dragon channel"""
if isinstance(self._descriptor, str):
return self._descriptor.encode("utf-8")
return base64.b64decode(self._descriptor.encode("utf-8"))
return self._descriptor
131 changes: 113 additions & 18 deletions smartsim/_core/mli/comm/channel/dragon_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,142 @@
import sys
import typing as t

import dragon.channels as dch
import dragon.infrastructure.facts as df
import dragon.infrastructure.parameters as dp
import dragon.managed_memory as dm
import dragon.utils as du

import smartsim._core.mli.comm.channel.channel as cch
from smartsim.error.errors import SmartSimError
from smartsim.log import get_logger

logger = get_logger(__name__)

import dragon.channels as dch

DEFAULT_CHANNEL_BUFFER_SIZE = 500
ankona marked this conversation as resolved.
Show resolved Hide resolved
"""Maximum number of messages that can be buffered. DragonCommChannel will
raise an exception if no clients consume messages before the buffer is filled."""


def create_local(capacity: int = 0) -> dch.Channel:
    """Create a Dragon channel in the memory pool described by
    `dp.this_process.default_pd`

    Candidate channel IDs are tried one after another, starting just above
    `df.BASE_USER_MANAGED_CUID`, until a channel is created.

    :param capacity: the number of messages the channel can buffer; any value
    that is not greater than zero selects `DEFAULT_CHANNEL_BUFFER_SIZE`
    :returns: the newly created channel
    :raises Exception: the creation error is re-raised once 100 candidate IDs
    have been tried without success"""
    pool = dm.MemoryPool.attach(du.B64.str_to_bytes(dp.this_process.default_pd))

    if not capacity > 0:
        capacity = DEFAULT_CHANNEL_BUFFER_SIZE

    channel: t.Optional[dch.Channel] = None
    attempt = 0

    while not channel:
        # walk forward through the user-managed ID range looking for a free slot
        attempt += 1
        cid = df.BASE_USER_MANAGED_CUID + attempt
        try:
            channel = dch.Channel(mem_pool=pool, c_uid=cid, capacity=capacity)
            logger.debug(
                f"Channel {cid} created in pool {pool.serialize()} w/capacity {capacity}"
            )
        except Exception:
            if not attempt < 100:
                # retry budget exhausted; surface the last failure to the caller
                logger.error(f"All attempts to attach local channel have failed")
                raise
            logger.warning(f"Unable to attach to channnel id {cid}. Retrying...")

    return channel


class DragonCommChannel(cch.CommChannelBase):
"""Passes messages by writing to a Dragon channel"""

def __init__(self, key: bytes) -> None:
"""Initialize the DragonCommChannel instance"""
super().__init__(key)
self._channel: dch.Channel = dch.Channel.attach(key)
def __init__(self, channel: "dch.Channel") -> None:
"""Initialize the DragonCommChannel instance

def send(self, value: bytes) -> None:
:param channel: a channel to use for communications
:param recv_timeout: a default timeout to apply to receive calls"""
serialized_ch = channel.serialize()
descriptor = base64.b64encode(serialized_ch).decode("utf-8")
super().__init__(descriptor)
self._channel = channel

@property
def channel(self) -> "dch.Channel":
    """Return the Dragon channel object this instance sends and receives on"""
    return self._channel

def send(self, value: bytes, timeout: float = 0.001) -> None:
"""Send a message throuh the underlying communication channel
:param value: The value to send"""
with self._channel.sendh(timeout=None) as sendh:

:param value: The value to send
:param timeout: maximum time to wait (in seconds) for messages to send"""
with self._channel.sendh(timeout=timeout) as sendh:
sendh.send_bytes(value)
logger.debug(f"DragonCommChannel {self.descriptor!r} sent message")

def recv(self) -> t.List[bytes]:
"""Receieve a message through the underlying communication channel
def recv(self, timeout: float = 0.001) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel

:param timeout: maximum time to wait (in seconds) for messages to arrive
:returns: the received message"""
with self._channel.recvh(timeout=None) as recvh:
message_bytes: bytes = recvh.recv_bytes(timeout=None)
return [message_bytes]
with self._channel.recvh(timeout=timeout) as recvh:
messages: t.List[bytes] = []

try:
message_bytes = recvh.recv_bytes(timeout=timeout)
messages.append(message_bytes)
logger.debug(f"DragonCommChannel {self.descriptor!r} received message")
ankona marked this conversation as resolved.
Show resolved Hide resolved
except dch.ChannelEmpty:
# emptied the queue, ok to swallow this ex
logger.debug(f"DragonCommChannel exhausted: {self.descriptor!r}")
except dch.ChannelRecvTimeout as ex:
logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor!r}")

return messages

@property
def descriptor_string(self) -> str:
    """Return the channel descriptor as a string that can be passed to
    `from_descriptor`

    A descriptor already stored as a string is returned unchanged; a
    descriptor stored as bytes is base64-encoded and decoded as UTF-8.

    :returns: the descriptor in string form
    :raises ValueError: if the stored descriptor is neither `str` nor `bytes`"""
    raw = self._descriptor

    if isinstance(raw, bytes):
        encoded = base64.b64encode(raw)
        return encoded.decode("utf-8")

    if isinstance(raw, str):
        return raw

    raise ValueError(f"Unable to convert channel descriptor: {self._descriptor}")

@classmethod
def from_descriptor(
cls,
descriptor: str,
descriptor: t.Union[bytes, str],
) -> "DragonCommChannel":
"""A factory method that creates an instance from a descriptor string

:param descriptor: The descriptor that uniquely identifies the resource
:param descriptor: The descriptor that uniquely identifies the resource. Output
from `descriptor_string` is correctly encoded.
:returns: An attached DragonCommChannel"""
try:
return DragonCommChannel(base64.b64decode(descriptor))
except:
logger.error(f"Failed to create dragon comm channel: {descriptor}")
raise
utf8_descriptor: t.Union[str, bytes] = descriptor
if isinstance(descriptor, str):
utf8_descriptor = descriptor.encode("utf-8")

# todo: ensure the bytes argument and condition are removed
# after refactoring the RPC models

actual_descriptor = base64.b64decode(utf8_descriptor)
channel = dch.Channel.attach(actual_descriptor)
return DragonCommChannel(channel)
except Exception as ex:
raise SmartSimError(
f"Failed to create dragon comm channel: {descriptor!r}"
) from ex
34 changes: 25 additions & 9 deletions smartsim/_core/mli/comm/channel/dragon_fli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@
# isort: off
from dragon import fli
import dragon.channels as dch
import dragon.infrastructure.facts as df
import dragon.infrastructure.parameters as dp
import dragon.managed_memory as dm
import dragon.utils as du

# isort: on

import base64
import typing as t

import smartsim._core.mli.comm.channel.channel as cch
from smartsim._core.mli.comm.channel.dragon_channel import create_local
from smartsim.log import get_logger

logger = get_logger(__name__)
Expand All @@ -42,37 +47,48 @@
class DragonFLIChannel(cch.CommChannelBase):
"""Passes messages by writing to a Dragon FLI Channel"""

def __init__(self, fli_desc: bytes, sender_supplied: bool = True) -> None:
def __init__(
self,
fli_desc: bytes,
sender_supplied: bool = True,
buffer_size: int = 0,
) -> None:
"""Initialize the DragonFLIChannel instance

:param fli_desc: the descriptor of the FLI channel to attach
:param sender_supplied: flag indicating if the FLI uses sender-supplied streams
:param buffer_size: maximum number of sent messages that can be buffered
"""
super().__init__(fli_desc)
# todo: do we need memory pool information to construct the channel correctly?
self._fli: "fli" = fli.FLInterface.attach(fli_desc)
self._channel: t.Optional["dch"] = (
dch.Channel.make_process_local() if sender_supplied else None
create_local(buffer_size) if sender_supplied else None
)

def send(self, value: bytes) -> None:
def send(self, value: bytes, timeout: float = 0.001) -> None:
"""Send a message through the underlying communication channel

:param timeout: maximum time to wait (in seconds) for messages to send
:param value: The value to send"""
with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
sendh.send_bytes(value)
sendh.send_bytes(value, timeout=timeout)
logger.debug(f"DragonFLIChannel {self.descriptor!r} sent message")

def recv(self) -> t.List[bytes]:
"""Receieve a message through the underlying communication channel
def recv(self, timeout: float = 0.001) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel

:param timeout: maximum time to wait (in seconds) for messages to arrive
:returns: the received message"""
messages = []
eot = False
with self._fli.recvh(timeout=0.001) as recvh:
with self._fli.recvh(timeout=timeout) as recvh:
while not eot:
try:
message, _ = recvh.recv_bytes(timeout=None)
message, _ = recvh.recv_bytes(timeout=timeout)
messages.append(message)
logger.debug(
f"DragonFLIChannel {self.descriptor!r} received message"
)
except fli.FLIEOT:
eot = True
return messages
Expand Down
11 changes: 8 additions & 3 deletions smartsim/_core/mli/infrastructure/control/request_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,10 @@
conditions are satisfied and cooldown is elapsed.
"""
try:
self._perf_timer.set_active(True)
self._perf_timer.is_active = True

Check warning on line 319 in smartsim/_core/mli/infrastructure/control/request_dispatcher.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/request_dispatcher.py#L319

Added line #L319 was not covered by tests
bytes_list: t.List[bytes] = self._incoming_channel.recv()
except Exception:
self._perf_timer.set_active(False)
self._perf_timer.is_active = False

Check warning on line 322 in smartsim/_core/mli/infrastructure/control/request_dispatcher.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/request_dispatcher.py#L322

Added line #L322 was not covered by tests
else:
if not bytes_list:
exception_handler(
Expand Down Expand Up @@ -501,4 +501,9 @@
return False

def __del__(self) -> None:
self._mem_pool.destroy()
"""Destroy allocated memory resources"""
# pool may be null if a failure occurs prior to successful attach
pool: t.Optional[MemoryPool] = getattr(self, "_mem_pool", None)

Check warning on line 506 in smartsim/_core/mli/infrastructure/control/request_dispatcher.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/request_dispatcher.py#L506

Added line #L506 was not covered by tests

if pool:
pool.destroy()

Check warning on line 509 in smartsim/_core/mli/infrastructure/control/request_dispatcher.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/request_dispatcher.py#L508-L509

Added lines #L508 - L509 were not covered by tests
Loading
Loading