-
-
Notifications
You must be signed in to change notification settings - Fork 2k
/
websocket.py
245 lines (196 loc) · 8.6 KB
/
websocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import websockets
import asyncio
import json
from syft.core import utils
from syft.core.workers import BaseWorker
class WebSocketWorker(BaseWorker):
    """A worker capable of performing the functions of a BaseWorker across a
    websocket connection. This worker essentially can replace Socket Worker.

    A worker is created in one of two modes:

    * ``is_pointer=True`` -- a handle to a *remote* worker. No server is
      started; each message is sent over a new websocket connection to
      ``ws://<hostname>:<port>``.
    * ``is_pointer=False`` -- a local worker. If ``is_client_worker`` is
      False, a websocket server is started on ``hostname:port`` and the
      asyncio event loop is run forever, so the constructor does not return.

    :Parameters:

    * **hook (**:class:`.hooks.BaseHook` **)** This is a reference to
      the hook object which overloaded the underlying deep learning framework.

    * **hostname (string, optional)** host name used to build the worker's
      ``ws://`` URI and, in server mode, passed to ``websockets.serve``.

    * **port (int, optional)** port used to build the worker's URI and, in
      server mode, passed to ``websockets.serve``.

    * **max_connections (int, optional)** stored on the instance; not used
      by this class itself.

    * **id (int or string, optional)** the integer or string identifier
      for this node

    * **is_client_worker (bool, optional)** a boolean which determines
      whether this worker is associated with an end user client. If so,
      it assumes that the client will maintain control over when
      tensors/variables/models are instantiated or deleted as opposed to
      handling tensor/variable/model lifecycle internally. A non-pointer
      worker only starts a server when this is False.

    * **objects (dict, optional)** When the worker is NOT a client worker,
      it stores all tensors it receives or creates in this dictionary.
      The key to each object is its id. When omitted, a new empty dict is
      passed to BaseWorker for this instance.

    * **tmp_objects (dict, optional)** When the worker IS a client worker,
      it stores some tensors temporarily in this dictionary simply to ensure
      that they do not get deallocated by the Python garbage collector while
      in the process of being registered. This dictionary can be emptied
      using the clear_tmp_objects method. When omitted, a new empty dict is
      passed to BaseWorker for this instance.

    * **known_workers (dict, optional)** This dictionary can include all
      known workers. When omitted, a new empty dict is passed to BaseWorker
      for this instance.

    * **verbose (bool, optional)** A flag for whether or not to print events
      to stdout. Some status lines are printed regardless of this flag.

    * **is_pointer (bool, optional)** if True, this object is a pointer to a
      remote worker at ``hostname:port`` and no server is started.

    * **queue_size (int, optional)** when non-zero, messages passed to
      :meth:`send_msg` are queued and only sent (as one composite message)
      once more than ``queue_size`` messages are waiting.

    :Example Server:

    >>> from syft.core.hooks import TorchHook
    >>> from syft.core.workers import WebSocketWorker
    >>> hook = TorchHook()
    Hooking into Torch...
    Overloading complete.
    >>> local_worker = WebSocketWorker(hook=hook,
    ...                                id=2,
    ...                                port=8181,
    ...                                is_pointer=False,
    ...                                is_client_worker=False)
    Starting a Websocket Worker....
    Ready to recieve commands....
    Server Socket has been initialized

    (The call above then blocks, serving requests until the loop is stopped.)

    :Example Client:

    >>> import torch
    >>> from syft.core.hooks import TorchHook
    >>> from syft.core.workers import WebSocketWorker
    >>> hook = TorchHook(local_worker=WebSocketWorker(id=0, port=8182))
    Starting a Websocket Worker....
    Ready...
    Hooking into Torch...
    Overloading complete.
    >>> remote_client = WebSocketWorker(hook=hook, id=2, port=8181, is_pointer=True)
    Attaching Pointer to WebSocket Worker....
    >>> hook.local_worker.add_worker(remote_client)
    >>> x = torch.FloatTensor([1,2,3,4,5]).send(remote_client)
    >>> x2 = torch.FloatTensor([1,2,3,4,4]).send(remote_client)
    >>> y = x + x2 + x
    >>> y
    [torch.FloatTensor - Locations:[<syft.core.workers.WebSocketWorker object at 0x7f94eaaa6630>]]
    >>> y.get()
      3
      6
      9
     12
     14
    [torch.FloatTensor of size 5]
    """

    def __init__(
        self,
        hook=None,
        hostname="localhost",
        port=8110,
        max_connections=5,
        id=0,
        is_client_worker=True,
        objects=None,
        tmp_objects=None,
        known_workers=None,
        verbose=True,
        is_pointer=False,
        queue_size=0,
    ):
        # Build fresh dicts per instance: a ``{}`` default is evaluated once
        # at definition time and would be shared by every worker created
        # without an explicit argument.
        super().__init__(
            hook=hook,
            id=id,
            is_client_worker=is_client_worker,
            objects={} if objects is None else objects,
            tmp_objects={} if tmp_objects is None else tmp_objects,
            known_workers={} if known_workers is None else known_workers,
            verbose=verbose,
            queue_size=queue_size,
        )
        self.is_asyncronous = True
        self.hook = hook
        self.hostname = hostname
        self.port = port
        self.uri = "ws://" + self.hostname + ":" + str(self.port)
        self.max_connections = max_connections
        self.is_pointer = is_pointer

        if self.is_pointer:
            if self.verbose:
                print("Attaching Pointer to WebSocket Worker....")
            self.serversocket = None
            # NOTE(review): the object returned here is never awaited within
            # this class; sends open their own connection in
            # ``_client_socket_connect``. Kept for callers that read
            # ``self.clientsocket``.
            self.clientsocket = websockets.client.connect(self.uri)
        else:
            if self.verbose:
                print("Starting a Websocket Worker....")
            # ``self.is_pointer`` is always False in this branch, so only
            # ``is_client_worker`` decides whether to start a server.
            if not is_client_worker:
                print("Ready to recieve commands....")
                self.serversocket = websockets.serve(
                    self._server_socket_listener, self.hostname, self.port
                )
                print("Server Socket has been initialized")
                loop = asyncio.get_event_loop()
                loop.run_until_complete(self.serversocket)
                # Blocks here: the server keeps handling requests on this loop.
                loop.run_forever()
            else:
                print("Ready...")

    async def _client_socket_connect(self, json_request):
        """Open a connection to ``self.uri``, send one request and return the reply.

        A new connection is opened (and closed when the ``async with`` exits)
        for every call.

        :Parameters:

        * **json_request (bytes or str)** the serialized request to send to
          the server socket.

        * **out (bytes or str)** the first message received back from the
          server, returned as-is (not decoded or parsed).
        """
        async with websockets.connect(self.uri) as client_socket:
            await client_socket.send(json_request)
            received_msg = await client_socket.recv()
        return received_msg

    async def _server_socket_listener(self, websocket, path=None):
        """Handle one request arriving at the server socket.

        Reads a single message from ``websocket``, decodes it with
        ``utils.PythonJSONDecoder``, passes the result to
        ``self.process_message_type`` and sends that return value back on the
        same connection.

        :Parameters:

        * **websocket** the incoming connection, which messages are received
          from and sent to.

        * **path (optional)** the request path. It is unused, and it defaults to
          ``None`` so that websockets versions which call handlers with only
          the connection can also use this method.
        """
        raw_msg = await websocket.recv()
        # ``send_msg`` sends UTF-8 encoded bytes, but also accept a text
        # message rather than failing on ``str.decode``.
        if isinstance(raw_msg, bytes):
            msg_wrapper_str = raw_msg.decode("utf-8")
        else:
            msg_wrapper_str = raw_msg
        if self.verbose:
            # NOTE(review): this prints the server's own URI, not the sender's.
            print("Recieved Command From:", self.uri)
        decoder = utils.PythonJSONDecoder(self)
        msg_wrapper = decoder.decode(msg_wrapper_str)
        await websocket.send(self.process_message_type(msg_wrapper))

    def whoami(self):
        """Returns metadata information about the worker.

        This method returns the default which is the id and uri of the
        worker, serialized as a JSON string.
        """
        return json.dumps({"uri": self.uri, "id": self.id})

    async def _send_msg(self, message_wrapper_json_binary, recipient):
        """Asynchronously send an already-serialized message to ``recipient``.

        This awaits the recipient's ``_client_socket_connect`` coroutine
        directly. ``_client_socket_listener`` is synchronous: it drives the
        event loop itself, which fails inside an already-running loop, and it
        returns a plain value that cannot be awaited.

        :Parameters:

        * **message_wrapper_json_binary (bytes)** the encoded message.

        * **recipient (** :class:`WebSocketWorker` **)** the worker to send to.

        * **out (str)** the reply, trimmed by :meth:`_process_buffer`.
        """
        response = await recipient._client_socket_connect(message_wrapper_json_binary)
        return self._process_buffer(response=response)

    def send_msg(self, message, message_type, recipient):
        """Sends a string message to another worker with message_type
        information indicating how the message should be processed.

        :Parameters:

        * **recipient (** :class:`WebSocketWorker` **)** the worker being sent a message.

        * **message (string)** the message being sent

        * **message_type (string)** the type of message being sent. This affects how
          the message is processed by the recipient. The types of message are described
          in :func:`receive_msg`.

        * **out (object)** the reply from the recipient, trimmed by
          :meth:`_process_buffer`. When ``queue_size`` is set and the queue is
          not yet full, nothing is sent and ``None`` is returned.
        """
        message_wrapper = {"message": message, "type": message_type}
        self.message_queue.append(message_wrapper)

        if self.queue_size:
            # Batch mode: hold messages until more than ``queue_size`` are queued,
            # then send them all as one composite message.
            if len(self.message_queue) <= self.queue_size:
                return None
            message_wrapper = self.compile_composite_message()

        # Newline-terminated JSON, encoded to bytes for the wire.
        message_wrapper_json_binary = (json.dumps(message_wrapper) + "\n").encode()
        self.message_queue = []

        response = recipient._client_socket_listener(message_wrapper_json_binary)
        return self._process_buffer(response=response)

    def _process_buffer(self, response, delimiter="\n"):
        """Return ``response`` up to and including the first ``delimiter``.

        Anything after the first delimiter is discarded. If the delimiter does
        not occur, ``response`` is returned unchanged. A ``bytes`` response
        combined with a ``str`` delimiter is first decoded as UTF-8; without
        that, the ``str``/``bytes`` mix would raise ``TypeError``.
        """
        if isinstance(response, bytes) and isinstance(delimiter, str):
            response = response.decode("utf-8")
        line, sep, _rest = response.partition(delimiter)
        return line + sep

    def _client_socket_listener(self, message_wrapper_json_binary):
        """Synchronously send a message to this worker's URI and return the reply.

        Runs :meth:`_client_socket_connect` to completion on the event loop
        returned by ``asyncio.get_event_loop()``. It therefore must not be
        called while that loop is already running; use :meth:`_send_msg` from
        async code instead.
        """
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self._client_socket_connect(message_wrapper_json_binary)
        )