Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ name: CI

on:
push:
branches: [main]
tags: [v*.*.*]
branches: [ main ]
tags: [ v*.*.* ]

pull_request:
branches: [ "main" ]
branches: [ main ]
types:
- synchronize
- opened
Expand Down
22 changes: 13 additions & 9 deletions hello/advertizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ def stop(self) -> None:
self._sender.stop()

def advertise(self, info: ServiceInfo | None = None) -> None:
if info:
self._info = info

if self._group:
if info:
self._info = info
if self._info:
self._sender.send(self._info)
log.info('Service advertised', service=self._info, group=self._group)
else:
log.warning('Cannot advertise service, no service info provided', group=self._group)
else:
log.warning('Cannot advertise service, advertizer not started', service=info)

Expand All @@ -72,21 +75,22 @@ def start(self, group: Group, info: ServiceInfo | None = None) -> None:
self._receiver.register(self._handle_message)

def stop(self) -> None:
super().stop()
self._receiver.deregister(self._handle_message)
self._receiver.stop()
super().stop()

def _handle_message(self, message: dict[str, Any]) -> None:
if self._info:
try:
query = ServiceQuery(**message)
log.debug('Query received', group=self._group, query=query)
self._handle_query(query, self._info)
matcher = ServiceMatcher(query)
log.debug('Service query received', group=self._group, query=query)
self._handle_query(matcher, self._info)
except Exception as error:
log.warning('Invalid query message received', group=self._group, received=message, error=error)
log.warning('Invalid service query received', group=self._group, received=message, error=error)

def _handle_query(self, query: ServiceQuery, info: ServiceInfo) -> None:
matcher = ServiceMatcher(query)
if matcher and matcher.matches(info):
def _handle_query(self, matcher: ServiceMatcher, info: ServiceInfo) -> None:
if matcher.matches(info):
delay = round(self._max_delay * random.random(), 3)
log.info('Responding to query', group=self._group, query=matcher.query, service=info, delay=delay)
time.sleep(delay)
Expand Down
88 changes: 47 additions & 41 deletions hello/discoverer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: 2024 Attila Gombos <attila.gombos@effective-range.com>
# SPDX-License-Identifier: MIT

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol
Expand All @@ -22,6 +23,8 @@ class DiscoveryEventType(Enum):

@dataclass
class DiscoveryEvent:
group: Group
query: ServiceQuery
service: ServiceInfo
type: DiscoveryEventType

Expand All @@ -41,28 +44,26 @@ def stop(self) -> None:
def discover(self, query: ServiceQuery | None = None) -> None:
raise NotImplementedError()

def get_services(self) -> dict[UUID, ServiceInfo]:
raise NotImplementedError()

def register(self, handler: OnDiscoveryEvent) -> None:
raise NotImplementedError()

def deregister(self, handler: OnDiscoveryEvent) -> None:
raise NotImplementedError()

def get_handlers(self) -> list[OnDiscoveryEvent]:
def get_services(self) -> dict[UUID, ServiceInfo]:
raise NotImplementedError()


class DefaultDiscoverer(Discoverer):

def __init__(self, sender: Sender, receiver: Receiver) -> None:
def __init__(self, sender: Sender, receiver: Receiver, max_workers: int = 8) -> None:
self._sender = sender
self._receiver = receiver
self._group: Group | None = None
self._matcher: ServiceMatcher | None = None
self._cache: dict[UUID, ServiceInfo] = {}
self._services: dict[UUID, ServiceInfo] = {}
self._handlers: list[OnDiscoveryEvent] = []
self._handler_executor = ThreadPoolExecutor(max_workers=max_workers)

def __enter__(self) -> Discoverer:
return self
Expand All @@ -80,65 +81,73 @@ def start(self, group: Group, query: ServiceQuery | None = None) -> None:

def stop(self) -> None:
self._group = None
self._matcher = None
self._sender.stop()
self._receiver.deregister(self._handle_message)
self._receiver.stop()

def discover(self, query: ServiceQuery | None = None) -> None:
if query:
self._matcher = ServiceMatcher(query)

if self._group:
if query:
self._matcher = ServiceMatcher(query)
if self._matcher:
self._sender.send(self._matcher.query)
log.info('Service discovery initiated', query=self._matcher.query, group=self._group)
log.info('Service discovery initiated', group=self._group, query=self._matcher.query)
else:
log.warning('Cannot discover services, no query provided', group=self._group)
else:
log.warning('Cannot discover services, discoverer not started', query=query)

def get_services(self) -> dict[UUID, ServiceInfo]:
return self._cache.copy()

def register(self, handler: OnDiscoveryEvent) -> None:
self._handlers.append(handler)

def deregister(self, handler: OnDiscoveryEvent) -> None:
self._handlers.remove(handler)

def get_handlers(self) -> list[OnDiscoveryEvent]:
return self._handlers.copy()
def get_services(self) -> dict[UUID, ServiceInfo]:
return self._services.copy()

def _handle_message(self, message: dict[str, Any]) -> None:
try:
service = ServiceInfo(UUID(message['uuid']), message['name'], message['role'], message.get('urls', {}))
self._handle_service(service)
except Exception as error:
log.warn('Failed to handle received message', data=message, error=error)
if self._group and self._matcher:
try:
service = ServiceInfo(UUID(message['uuid']), message['name'], message['role'], message.get('urls', {}))
log.debug('Service info received', service=service, group=self._group)
self._handle_service(service, self._group, self._matcher)
except Exception as error:
log.warn('Invalid service info received', group=self._group, data=message, error=error)

def _handle_service(self, service: ServiceInfo) -> None:
if self._matcher and self._matcher.matches(service):
cached = self._cache.get(service.uuid)
def _handle_service(self, service: ServiceInfo, group: Group, matcher: ServiceMatcher) -> None:
if matcher.matches(service):
stored = self._services.get(service.uuid)

if event := self._create_event(cached, service):
if event := self._create_event(group, matcher, stored, service):
self._handle_event(event)

def _create_event(self, cached: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None:
if cached:
if cached != service:
log.info('Service updated', old_service=cached, new_service=service)
return DiscoveryEvent(service, DiscoveryEventType.UPDATED)
def _create_event(self, group: Group, matcher: ServiceMatcher,
stored: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None:
if stored:
if stored != service:
log.info('Service updated', group=group, old_service=stored, new_service=service)
return DiscoveryEvent(group, matcher.query, service, DiscoveryEventType.UPDATED)
else:
log.debug('Service unchanged', service=service)
log.debug('Service unchanged', group=group, service=service)
return None
else:
log.info('Service discovered', service=service)
return DiscoveryEvent(service, DiscoveryEventType.DISCOVERED)
log.info('New service discovered', group=group, service=service)
return DiscoveryEvent(group, matcher.query, service, DiscoveryEventType.DISCOVERED)

def _handle_event(self, event: DiscoveryEvent) -> None:
service = event.service
self._cache[service.uuid] = service
for callback in self._handlers:
try:
callback(event)
except Exception as error:
log.warn('Error in event handler execution', event=event, error=error)
self._services[event.service.uuid] = event.service

for handler in self._handlers:
self._handler_executor.submit(self._execute_handler, handler, event)

def _execute_handler(self, handler: OnDiscoveryEvent, event: DiscoveryEvent) -> None:
try:
handler(event)
except Exception as error:
log.warn('Error in event handler execution', event=event, error=error)


class ScheduledDiscoverer(DefaultScheduler[ServiceQuery], Discoverer):
Expand Down Expand Up @@ -172,8 +181,5 @@ def register(self, handler: OnDiscoveryEvent) -> None:
def deregister(self, handler: OnDiscoveryEvent) -> None:
self._discoverer.deregister(handler)

def get_handlers(self) -> list[OnDiscoveryEvent]:
return self._discoverer.get_handlers()

def _execute(self, query: ServiceQuery | None = None) -> None:
self.discover(query)
8 changes: 1 addition & 7 deletions hello/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ def register(self, handler: OnMessage) -> None:
def deregister(self, handler: OnMessage) -> None:
raise NotImplementedError()

def get_handlers(self) -> list[OnMessage]:
raise NotImplementedError()


class DishReceiver(Receiver):

Expand Down Expand Up @@ -83,9 +80,6 @@ def register(self, handler: OnMessage) -> None:
def deregister(self, handler: OnMessage) -> None:
self._handlers.remove(handler)

def get_handlers(self) -> list[OnMessage]:
return self._handlers.copy()

def _receive_loop(self) -> None:
while self._group:
try:
Expand All @@ -105,4 +99,4 @@ def _execute_handler(self, handler: OnMessage, message: dict[str, Any]) -> None:
try:
handler(message)
except Exception as error:
log.warn('Error in message handler execution', data=message, group=self._group, error=error)
log.warn('Handler failed to process message', data=message, group=self._group, error=error)
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ description = "A service advertizer/discovery protocol library using ZeroMQ"
authors = [
{ name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" }
]
maintainers = [
{ name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" }
]
dependencies = [
"pyzmq @ git+https://github.com/EffectiveRange/pyzmq.git@v27.1.1",
"python-context-logger @ git+https://github.com/EffectiveRange/python-context-logger.git@latest",
Expand All @@ -25,3 +28,8 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
version_scheme = "guess-next-dev"
local_scheme = "node-and-date"

[tool.pytest]
addopts = ["--verbose", "--capture=no"]
python_files = ["*Test.py"]
python_classes = ["*Test"]
17 changes: 12 additions & 5 deletions tests/defaultDiscovererTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from uuid import uuid4

from context_logger import setup_logging
from test_utility import wait_for_assertion

from hello import ServiceInfo, Group, ServiceQuery, DefaultDiscoverer, Sender, Receiver, OnDiscoveryEvent, \
DiscoveryEventType, DiscoveryEvent
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_registers_event_handler(self):
discoverer.register(handler)

# Then
self.assertIn(handler, discoverer.get_handlers())
self.assertIn(handler, discoverer._handlers)

def test_deregisters_event_handler(self):
# Given
Expand All @@ -88,7 +89,7 @@ def test_deregisters_event_handler(self):
discoverer.deregister(handler)

# Then
self.assertNotIn(handler, discoverer.get_handlers())
self.assertNotIn(handler, discoverer._handlers)

def test_caches_service_and_calls_handler_when_receives_matching_info(self):
# Given
Expand All @@ -104,7 +105,9 @@ def test_caches_service_and_calls_handler_when_receives_matching_info(self):

# Then
self.assertEqual({SERVICE_INFO.uuid: SERVICE_INFO}, discoverer.get_services())
handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED))
wait_for_assertion(1, lambda: handler.assert_called_once_with(
DiscoveryEvent(GROUP, SERVICE_QUERY, SERVICE_INFO, DiscoveryEventType.DISCOVERED)
))

def test_updates_service_and_calls_handler_when_receives_matching_info(self):
# Given
Expand All @@ -125,7 +128,9 @@ def test_updates_service_and_calls_handler_when_receives_matching_info(self):

# Then
self.assertEqual({SERVICE_INFO.uuid: new_service_info}, discoverer.get_services())
handler.assert_called_once_with(DiscoveryEvent(new_service_info, DiscoveryEventType.UPDATED))
wait_for_assertion(1, lambda: handler.assert_called_once_with(
DiscoveryEvent(GROUP, SERVICE_QUERY, new_service_info, DiscoveryEventType.UPDATED)
))

def test_does_not_call_handler_when_service_info_not_changed(self):
# Given
Expand Down Expand Up @@ -159,7 +164,9 @@ def test_handles_handler_error_gracefully(self):

# Then
self.assertEqual({SERVICE_INFO.uuid: SERVICE_INFO}, discoverer.get_services())
handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED))
wait_for_assertion(1, lambda: handler.assert_called_once_with(
DiscoveryEvent(GROUP, SERVICE_QUERY, SERVICE_INFO, DiscoveryEventType.DISCOVERED)
))

def test_handles_invalid_message_gracefully(self):
# Given
Expand Down
17 changes: 6 additions & 11 deletions tests/dishReceiverTest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from itertools import chain, repeat
from unittest import TestCase
from unittest.mock import MagicMock
from uuid import uuid4
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_registers_handler(self):
receiver.register(handler)

# Then
self.assertIn(handler, receiver.get_handlers())
self.assertIn(handler, receiver._handlers)

def test_deregisters_handler(self):
# Given
Expand All @@ -107,7 +108,7 @@ def test_deregisters_handler(self):
receiver.deregister(handler)

# Then
self.assertNotIn(handler, receiver.get_handlers())
self.assertNotIn(handler, receiver._handlers)

def test_calls_registered_handler_on_message(self):
# Given
Expand All @@ -118,9 +119,7 @@ def test_calls_registered_handler_on_message(self):

with DishReceiver(context) as receiver:
receiver._poller = MagicMock(spec=Poller)
receiver._poller.poll.side_effect = [
{context.socket.return_value: POLLIN},
]
receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({}))
receiver.register(handler)

# When
Expand All @@ -138,9 +137,7 @@ def test_handles_message_receive_error_gracefully(self):

with DishReceiver(context) as receiver:
receiver._poller = MagicMock(spec=Poller)
receiver._poller.poll.side_effect = [
{context.socket.return_value: POLLIN},
]
receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({}))
receiver.register(handler)

# When
Expand All @@ -159,9 +156,7 @@ def test_handles_handler_execution_error_gracefully(self):

with DishReceiver(context) as receiver:
receiver._poller = MagicMock(spec=Poller)
receiver._poller.poll.side_effect = [
{context.socket.return_value: POLLIN},
]
receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({}))
receiver.register(handler)

# When
Expand Down
Loading
Loading