Skip to content

Commit

Permalink
Make services stateful; remove circular dep
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Liam Trask authored and Andrew Liam Trask committed Jul 11, 2020
1 parent a5627b5 commit a25b58f
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 61 deletions.
1 change: 1 addition & 0 deletions docs/api/syft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Subpackages
:maxdepth: 4

syft.ast
syft.common
syft.core
syft.lib
syft.typecheck
12 changes: 6 additions & 6 deletions src/syft/core/worker/service/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# from . import delete_object_service
# from . import get_object_service
# from . import run_class_service
# from . import run_function_or_constructor_service
# from . import save_object_service
# from . import worker_service
from . import delete_object_service
from . import get_object_service
from . import run_class_service
from . import run_function_or_constructor_service
from . import save_object_service
from . import worker_service
8 changes: 0 additions & 8 deletions src/syft/core/worker/service/delete_object_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations
from .... import type_hints
from .worker_service import WorkerService
from .. import message_service_mapping
from ...message import DeleteObjectMessage
from ....common import AbstractWorker

Expand All @@ -17,10 +16,3 @@ def process(worker: AbstractWorker, msg: DeleteObjectMessage) -> None:
def message_type_handler() -> type:
return DeleteObjectMessage

@staticmethod
@type_hints
def register_service() -> None:
message_service_mapping[DeleteObjectMessage] = DeleteObjectService


DeleteObjectService.register_service()
9 changes: 0 additions & 9 deletions src/syft/core/worker/service/get_object_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .... import type_hints
from .worker_service import WorkerService
from .. import message_service_mapping
from ...message import GetObjectMessage
from ....common import AbstractWorker

Expand All @@ -17,11 +16,3 @@ def process(worker: AbstractWorker, msg: GetObjectMessage) -> object: #TODO: ret
@type_hints
def message_type_handler() -> type:
return GetObjectMessage

@staticmethod
@type_hints
def register_service() -> None:
message_service_mapping[GetObjectMessage] = GetObjectService


GetObjectService.register_service()
9 changes: 0 additions & 9 deletions src/syft/core/worker/service/run_class_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from .worker_service import WorkerService
from .. import message_service_mapping
from ...message import RunClassMethodMessage
from .... import type_hints
from ...pointer.pointer import Pointer
Expand Down Expand Up @@ -60,11 +59,3 @@ def process(worker: AbstractWorker, msg: RunClassMethodMessage) -> None:
@type_hints
def message_type_handler() -> type:
return RunClassMethodMessage

@staticmethod
@type_hints
def register_service() -> None:
message_service_mapping[RunClassMethodMessage] = RunClassMethodService


RunClassMethodService.register_service()
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from .worker_service import WorkerService
from .. import message_service_mapping
from ...message import RunFunctionOrConstructorMessage
from .... import type_hints
from ....common import AbstractWorker
Expand All @@ -15,14 +14,4 @@ def process(worker: AbstractWorker, msg: RunFunctionOrConstructorMessage) -> Non
@staticmethod
@type_hints
def message_type_handler() -> type:
return RunFunctionOrConstructorMessage

@staticmethod
@type_hints
def register_service() -> None:
message_service_mapping[
RunFunctionOrConstructorMessage
] = RunFunctionOrConstructorService


RunFunctionOrConstructorService.register_service()
return RunFunctionOrConstructorMessage
9 changes: 0 additions & 9 deletions src/syft/core/worker/service/save_object_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from .worker_service import WorkerService
from .. import message_service_mapping
from ...message import SaveObjectMessage

from .... import type_hints
Expand All @@ -20,11 +19,3 @@ def process(worker: AbstractWorker, msg: SaveObjectMessage) -> None:
@type_hints
def message_type_handler() -> type:
return SaveObjectMessage

@staticmethod
@type_hints
def register_service() -> None:
message_service_mapping[SaveObjectMessage] = SaveObjectService


SaveObjectService.register_service()
7 changes: 1 addition & 6 deletions src/syft/core/worker/service/worker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,4 @@ def process(worker: AbstractWorker, msg: SyftMessage) -> object:
@staticmethod
@type_hints
def message_type_handler() -> SyftMessage:
raise NotImplementedError

@staticmethod
@type_hints
def register_service() -> None:
raise NotImplementedError
raise NotImplementedError
13 changes: 12 additions & 1 deletion src/syft/core/worker/virtual/virtual_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ..worker import Worker
from ..virtual.virtual_client import VirtualClient
from typing import final

from .. import service


class VirtualWorker(Worker):
Expand All @@ -15,3 +15,14 @@ def _recv_msg(self, msg):
def get_client(self, verbose=False):
self.client.verbose = verbose
return self.client

def _register_services(self) -> None:
services = list()
services.append(service.get_object_service.GetObjectService)
services.append(service.save_object_service.SaveObjectService)
services.append(service.run_class_service.RunClassMethodService)
services.append(service.delete_object_service.DeleteObjectService)
services.append(service.run_function_or_constructor_service.RunFunctionOrConstructorService)

for s in services:
self.msg_router[s.message_type_handler()] = s()
8 changes: 7 additions & 1 deletion src/syft/core/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ def __init__(self, id: str, debug: bool = False, supported_frameworks: list = []
)
self.frameworks.attrs[name] = ast

self.msg_router = worker.message_service_mapping
self.msg_router = {}

if debug:
self.worker_stats = WorkerStats()

self._register_services()

@type_hints
def recv_msg(self, msg: SyftMessage) -> object:
return self.msg_router[type(msg)].process(worker=self, msg=msg)
Expand All @@ -62,6 +64,10 @@ def _send_msg(self) -> None:
def _recv_msg(self) -> None:
raise NotImplementedError

@type_hints
def _register_services(self) -> None:
raise NotImplementedError

def __repr__(self):
if self.worker_stats:
return f"Worker: {self.id}\n{self.worker_stats}"
Expand Down

0 comments on commit a25b58f

Please sign in to comment.