Skip to content

Commit

Permalink
introduce remote nodes and dynamic routing
Browse files Browse the repository at this point in the history
  • Loading branch information
Rima committed Jul 21, 2020
1 parent d06f424 commit e7710b2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 36 deletions.
31 changes: 28 additions & 3 deletions src/syft/core/io/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,49 @@
from typing import final


class BaseRoute(object):
def configure_connection(self, protocol, url, port):
pass

def register_broadcast_channel(self, channel_name):
"""
Args:
channel_name: the name of the channel to broadcast on.
"""
self.broadcast_channel = name

def client(self):
# connect to configured connection.
return

def type(self):
if self.vm != '*':
return 'VM'
if self.domain != '*':
return 'Domain'
if self.device != '*':
return 'Device'
if self.network != '*':
return 'Network'

@final
class PublicRoute(object):
class PublicRoute(BaseRoute):
@syft_decorator(typechecking=True)
def __init__(self, network: (str, UID), domain: (str, UID)):
self.network = network
self.domain = domain


@final
class PrivateRoute(object):
class PrivateRoute(BaseRoute):
@syft_decorator(typechecking=True)
def __init__(self, device: (str, UID), vm: (str, UID)):
self.device = device
self.vm = vm


@final
class Route(object):
class Route(BaseRoute):
@syft_decorator(typechecking=True)
def __init__(self, pub_route: PublicRoute, pri_route: PrivateRoute):
self.pub_route = pub_route
Expand Down
5 changes: 3 additions & 2 deletions src/syft/core/message/syft_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ def __init__(self, route: Route, msg_id: UID = None) -> None:
self.route = route
self.msg_id = msg_id


class SyftMessageWithReply(SyftMessage):
def __init__(self, route: Route, reply_to: Route, UID = None) -> None:
msg = super(self, route, UID)
def __init__(self, route: Route, reply_to: Route, msg_id: UID = None) -> None:
super(self).__init__(route, msg_id)
msg.reply_to = reply_to
54 changes: 39 additions & 15 deletions src/syft/core/nodes/abstract/remote_nodes.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,54 @@
import enum
from .client import Client
from ...io.route import Route
from ...message.syft_message import SyftMessage

class RemoteNodeTypes(enum.Enum):
VM = 1
Device = 2
Domain = 3
Network = 4

class RemoteNode(object):
def __init__(self, route: Route, type: RemoteNodeType) -> None:
self.id = id
self.route = route

class RemoteNodes(object):
class MyRemoteNodes(object):
nodes = {}

def __init__(self):
for type in RemoteNodeTypes:
def __init__(self, my_route: Route):
for type in RouteTypes:
self.nodes.update({type.name: []})
self.iam = iam.type
self.my_route = my_route

def register_worker(self, route: Route, type: RemoteNodeType) -> None:
def register_node(self, route: Route, type: RouteTypes) -> None:
if type not in self.nodes:
# log unknown type
return

self.nodes[type].append(route)

def forget_worker(self, route: Route, type: RemoteNodeType) -> None:
def forget_node(self, route: Route, type: RouteTypes) -> None:
self.nodes[type].pop(route)

def route_message_to_relevant_nodes(self, route: Route, message: SyftMessage) -> None:
"""
check if the message should be forwarded.
Network: routes to domains
Domain: routes to devices
Device: routes to VMs
VM: doesn't route
"""
if self.iam == 'network':
if route.domain != '*':
route.domain.client().send(message)
else:
for domain in self.nodes['domain']:
domain.client().send(message)
elif self.iam == 'domain':
if route.device != '*':
route.device.client().send(message)
else:
for device in self.nodes['device']:
device.client().send(message)
elif self.iam == 'device':
if route.vm != '*':
route.vm.client().send(message)
else:
for vm in self.nodes['vm']:
vm.client().send(message)

def broadcast(self, route: Route):
channel = route.broadcast_channel
28 changes: 12 additions & 16 deletions src/syft/core/nodes/abstract/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@

# CORE IMPORTS
from ...store.store import ObjectStore
from ...message import SyftMessage
from ...message import SyftMessage, SyftMessageWithReply
from ...io import Route

# nodes related imports
from ..abstract.remote_nodes import RemoteNodes

from remote_nodes import MyRemoteNodes

class Worker(AbstractWorker):

Expand All @@ -43,12 +40,14 @@ def __init__(self, name: str = None):
self.name = name
self.store = ObjectStore()
self.msg_router = {}
# bootstrap
self.known_workers = RemoteNodes()
self.services_registered = False
self.remote_nodes = MyRemoteNodes()

@type_hints
def recv_msg(self, msg: SyftMessage) -> SyftMessage:
# should the message be forwarded, go for it.
self.remote_nodes.route_message_to_relevant_nodes(msg)

try:
return self.msg_router[type(msg)].process(worker=self, msg=msg)
except KeyError as e:
Expand Down Expand Up @@ -92,12 +91,9 @@ def _register_services(self) -> None:

self.services_registered = True

@type_hints
def listen_on_messages(self, msg: SyftMessage) -> SyftMessage:
"""
Allows workers to connect to open messaging protocols and listen on
messages.
The worker would extend this class to implement the specific protocol.
"""
return self.recv_msg(msg)
def reply_to_message(self, original_msg: SyftMessage, reply_msg: SyftMessage):
# get a client and reply.
if not type(original_msg) is SyftMessageWithReply:
return
route = msg.reply_to.route
route.client().send(msg)
2 changes: 2 additions & 0 deletions src/syft/core/nodes/device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class Device(Worker, AbstractDevice):
def __init__(self, name: str):
super().__init__(name=name)

remote_nodes = DeviceRemoteNodes()

# the VM objects themselves
self._vms = {}

Expand Down

0 comments on commit e7710b2

Please sign in to comment.