# websocket_server.py
import binascii
from typing import Union
from typing import List
import asyncio
import torch
import websockets
import ssl
import sys
import tblib.pickling_support
import socket
import logging
tblib.pickling_support.install()
import syft as sy
from syft.frameworks.torch.tensors.interpreters import AbstractTensor
from syft.workers.virtual import VirtualWorker
from syft.exceptions import GetNotPermittedError
from syft.exceptions import ResponseSignatureError
from syft.federated import FederatedClient
class WebsocketServerWorker(VirtualWorker, FederatedClient):
    def __init__(
        self,
        hook,
        host: str,
        port: int,
        id: Union[int, str] = 0,
        log_msgs: bool = False,
        verbose: bool = False,
        data: List[Union[torch.Tensor, AbstractTensor]] = None,
        loop=None,
        cert_path: str = None,
        key_path: str = None,
    ):
        """This is a simple extension to normal workers wherein
        all messages are passed over websockets. Note that because
        BaseWorker assumes a request/response paradigm, this worker
        enforces this paradigm by default.

        Args:
            hook (sy.TorchHook): a normal TorchHook object
            host (str): the host on which the server should be run
            port (int): the port on which the server should be run
            id (str or int): the unique id of the worker
            log_msgs (bool): whether or not all messages should be
                saved locally for later inspection.
            verbose (bool): a verbose option - will print all messages
                sent/received to stdout
            data (list): any initial tensors the server should be
                initialized with (such as datasets)
            loop: the asyncio event loop the server runs on when ``start()``
                is called; a new event loop is created when None.
            cert_path (str): path to the TLS certificate chain; only needed for
                secure connections (TLS is enabled only if key_path is also set)
            key_path (str): path to the TLS private key; only needed for
                secure connections (TLS is enabled only if cert_path is also set)
        """
        self.port = port
        self.host = host
        self.cert_path = cert_path
        self.key_path = key_path

        if loop is None:
            loop = asyncio.new_event_loop()

        # Default queue used by _consumer_handler/_producer_handler when they are
        # called without an explicit queue. _handler does not use it: every
        # connection gets its own queue so replies cannot reach the wrong client.
        self.broadcast_queue = asyncio.Queue()

        # the asyncio event loop start() serves on
        self.loop = loop

        # call the base worker constructor(s) via the MRO
        super().__init__(hook=hook, id=id, data=data, log_msgs=log_msgs, verbose=verbose)

    async def _consumer_handler(
        self, websocket: websockets.WebSocketCommonProtocol, queue: asyncio.Queue = None
    ):
        """Receive messages from a WebsocketClientWorker and enqueue them.

        Loops until ``websocket.recv()`` raises (e.g. when the connection closes).

        Args:
            websocket: the connection object to receive messages from.
            queue: the queue received messages are put into; defaults to
                ``self.broadcast_queue``.
        """
        if queue is None:
            queue = self.broadcast_queue
        while True:
            msg = await websocket.recv()
            await queue.put(msg)

    async def _producer_handler(
        self, websocket: websockets.WebSocketCommonProtocol, queue: asyncio.Queue = None
    ):
        """Process queued messages as they arrive and send back the responses.

        Incoming messages are expected to be the ``str()`` of a hex-encoded
        bytes object (``"b'<hex>'"``); responses are encoded the same way.

        Args:
            websocket: the connection object we use to send responses
                back to the client.
            queue: the queue incoming messages are read from; defaults to
                ``self.broadcast_queue``.
        """
        if queue is None:
            queue = self.broadcast_queue
        while True:
            # get a message from the queue
            message = await queue.get()
            # drop the leading "b'" and trailing "'" of the str(bytes) form,
            # then decode the hex payload back to the raw binary message
            message = binascii.unhexlify(message[2:-1])
            # process the message
            response = self._recv_msg(message)
            # encode the binary response as the str() of its hex bytes
            # (a text representation is needed for the websocket library)
            response = str(binascii.hexlify(response))
            # send the response
            await websocket.send(response)

    def _recv_msg(self, message: bytes) -> bytes:
        """Process one raw binary request and return the binary response.

        ResponseSignatureError and GetNotPermittedError raised by ``recv_msg``
        are serialized with ``sy.serde.serialize`` and returned as the
        response instead of propagating; any other exception propagates.
        """
        try:
            return self.recv_msg(message)
        except (ResponseSignatureError, GetNotPermittedError) as e:
            return sy.serde.serialize(e)

    async def _handler(self, websocket: websockets.WebSocketCommonProtocol, *unused_args):
        """Serve one client connection.

        Runs a consumer and a producer concurrently over a queue private to
        this connection and returns as soon as either of them stops; the
        other one is then cancelled.

        Args:
            websocket: the websocket connection to the client
            *unused_args: extra arguments passed by ``websockets.serve``
                (e.g. the request path); ignored.
        """
        # Per-connection queue: with a queue shared by all connections, the
        # producer of one client could dequeue and answer another client's request.
        queue = asyncio.Queue()
        consumer_task = asyncio.ensure_future(self._consumer_handler(websocket, queue))
        producer_task = asyncio.ensure_future(self._producer_handler(websocket, queue))
        done, pending = await asyncio.wait(
            [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()

        # Retrieve the finished task's exception so asyncio does not report
        # "Task exception was never retrieved"; a closed connection is expected.
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is None:
                continue
            if isinstance(exc, websockets.ConnectionClosed):
                logging.debug("Websocket connection closed: %s", exc)
            else:
                logging.error("Websocket connection handler failed", exc_info=exc)

    def start(self):
        """Start the server on ``self.loop`` and block until CTRL-C.

        When both ``cert_path`` and ``key_path`` are set, connections are
        wrapped in TLS using that certificate chain; otherwise the server
        accepts plain (insecure) websocket connections.
        """
        ssl_context = None
        # Secure behavior: adds a secure layer applying cryptography and authentication
        if self.cert_path is not None and self.key_path is not None:
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.cert_path, self.key_path)

        # Install self.loop as the current loop before calling websockets.serve(),
        # which may capture the current event loop when it is called.
        asyncio.set_event_loop(self.loop)
        start_server = websockets.serve(
            self._handler,
            self.host,
            self.port,
            ssl=ssl_context,  # None -> insecure server
            max_size=None,
            ping_timeout=None,
            close_timeout=None,
        )

        self.loop.run_until_complete(start_server)
        print("Serving. Press CTRL-C to stop.")
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            logging.info("Websocket server stopped.")

    def list_objects(self, *args):
        """Return ``str(self._objects)``; extra positional arguments are ignored."""
        return str(self._objects)

    def objects_count(self, *args):
        """Return ``len(self._objects)``; extra positional arguments are ignored."""
        return len(self._objects)