/
stim_server_client.py
316 lines (247 loc) · 9.9 KB
/
stim_server_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# Author: Mainak Jas <mainak@neuro.hut.fi>
# License: BSD (3-clause)
from ..externals.six.moves import queue
import time
import socket
from ..externals.six.moves import socketserver
import threading
import numpy as np
from ..utils import logger, verbose
class _ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that serves every connection in its own thread.

    It behaves like a plain ``TCPServer`` but additionally carries a
    reference to the owning stimulation server, so request handlers can
    reach it through ``self.server.stim_server``.

    Parameters
    ----------
    server_address : tuple of (str, int)
        ``(host, port)`` pair the server will listen on.
    request_handler_class : subclass of BaseRequestHandler
        Handler class instantiated for each incoming connection
        (``_TriggerHandler``, which defines the ``handle`` method).
    stim_server : instance of StimServer
        The stimulation server that owns this TCP server.
    """

    def __init__(self, server_address, request_handler_class,
                 stim_server):  # noqa: D102
        # Back-reference used by the request handlers.
        self.stim_server = stim_server

        # Only create the listening socket here; binding and activation are
        # deferred (bind_and_activate=False) and performed by the owner.
        socketserver.TCPServer.__init__(
            self,
            server_address,
            request_handler_class,
            bind_and_activate=False,
        )
class _TriggerHandler(socketserver.BaseRequestHandler):
    """Request handler on the server side.

    One handler instance serves one client connection. It understands two
    text commands sent by the client:

    ``'add client'``
        Register this connection with the stimulation server and reply
        ``'Client added'``.
    ``'get trigger'``
        Reply with the next queued trigger (as text), or ``'Empty'`` when
        no trigger is waiting.
    """

    def handle(self):
        """Handle requests on the server side.

        Loops for as long as the stimulation server is running and the
        client keeps the connection open.
        """
        stim_server = self.server.stim_server

        # Block indefinitely while waiting for the next command.
        self.request.settimeout(None)

        while stim_server._running:
            data = self.request.recv(1024)  # clip input at 1Kb

            # recv() returns an empty payload once the peer has closed the
            # connection; without this check the loop would spin forever.
            if not data:
                break

            data = data.decode()  # need to turn it into a string (Py3k)

            if data == 'add client':
                # Instantiate queue for communication between threads
                # Note: new queue for each handler. It is created *before*
                # the handler is registered so that a concurrent
                # add_trigger() never sees a client without a queue.
                if not hasattr(self, '_tx_queue'):
                    self._tx_queue = queue.Queue()

                # Add stim_server._client
                client_id = stim_server._add_client(self.client_address[0],
                                                    self)

                self.request.sendall("Client added".encode('utf-8'))

                # Mark the client as running
                for client in stim_server._clients:
                    if client['id'] == client_id:
                        client['running'] = True

            elif data == 'get trigger':
                reply = "Empty"

                # A client may ask for triggers before registering, in which
                # case there is no queue yet and the answer is "Empty".
                tx_queue = getattr(self, '_tx_queue', None)
                if tx_queue is not None and hasattr(stim_server, '_clients'):
                    try:
                        # Non-blocking pop: a blocking get() would stall
                        # this handler until a trigger shows up.
                        reply = str(tx_queue.get_nowait())
                    except queue.Empty:
                        pass

                self.request.sendall(reply.encode('utf-8'))
class StimServer(object):
    """Stimulation Server.

    Server to communicate with StimClient(s).

    The server is meant to be used as a context manager: entering the
    context binds the socket and starts the serving thread, leaving it
    shuts everything down.

    Parameters
    ----------
    port : int
        The port to which the stimulation server must bind to.
    n_clients : int
        The number of clients which will connect to the server.

    See Also
    --------
    StimClient
    """

    def __init__(self, port=4218, n_clients=1):  # noqa: D102
        # Create a threaded TCP server for the given port. The empty host
        # string means "all interfaces"; the socket is bound later, in
        # __enter__.
        self._data = _ThreadedTCPServer(('', port),
                                        _TriggerHandler, self)
        self.n_clients = n_clients

        # Initialise the state here so that start()/shutdown() and the
        # request handlers never hit a missing attribute.
        self._running = False
        self._clients = list()
        self._thread = None

    def __enter__(self):  # noqa: D105
        # This is done to avoid "[Errno 98] Address already in use"
        self._data.allow_reuse_address = True
        self._data.server_bind()
        self._data.server_activate()

        # Reset the state *before* the serving thread exists, so a handler
        # can never observe a half-initialised server.
        self._running = False
        self._clients = list()

        # Start a thread for the server
        self._thread = threading.Thread(target=self._data.serve_forever)

        # Ctrl-C will cleanly kill all spawned threads
        # Once the main thread exits, other threads will exit
        self._thread.daemon = True
        self._thread.start()

        return self

    def __exit__(self, type, value, traceback):  # noqa: D105
        self.shutdown()

    @verbose
    def start(self, timeout=np.inf, verbose=None):
        """Start the server.

        Marks the server as running and blocks until ``n_clients`` clients
        have registered. Calling it again while the server is already
        marked as running does nothing.

        Parameters
        ----------
        timeout : float
            Maximum time to wait for clients to be added.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see
            :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
            for more).

        Raises
        ------
        StopIteration
            If fewer than ``n_clients`` clients were added within
            ``timeout`` seconds.
        """
        # Start server
        if not self._running:
            logger.info('RtServer: Start')
            self._running = True

            deadline = time.time() + timeout

            # wait till n_clients are added
            while len(self._clients) < self.n_clients:
                if time.time() > deadline:
                    raise StopIteration

                time.sleep(0.1)

    @verbose
    def _add_client(self, ip, sock, verbose=None):
        """Add client.

        Parameters
        ----------
        ip : str
            IP address of the client.
        sock : instance of _TriggerHandler
            The request handler serving the client. Triggers for this
            client are placed into its ``_tx_queue``.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see
            :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
            for more).

        Returns
        -------
        client_id : int
            Identifier assigned to the newly added client.
        """
        logger.info("Adding client with ip = %s" % ip)

        # add_trigger() puts into the handler's queue, so make sure it exists
        # before the client becomes visible in self._clients.
        if not hasattr(sock, '_tx_queue'):
            sock._tx_queue = queue.Queue()

        client = dict(ip=ip, id=len(self._clients), running=False, socket=sock)
        self._clients.append(client)

        return client['id']

    @verbose
    def shutdown(self, verbose=None):
        """Shutdown the client and server.

        Parameters
        ----------
        verbose : bool, str, int, or None
            If not None, override default verbose level (see
            :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
            for more).
        """
        logger.info("Shutting down ...")

        # stop running all the clients
        for client in getattr(self, '_clients', ()):
            client['running'] = False

        self._running = False

        # BaseServer.shutdown() waits for serve_forever() to return and
        # deadlocks if serve_forever() is not running, so only call it when
        # the serving thread was actually started.
        if getattr(self, '_thread', None) is not None:
            self._data.shutdown()
            self._thread = None

        self._data.server_close()
        self._data.socket.close()

    @verbose
    def add_trigger(self, trigger, verbose=None):
        """Add a trigger.

        The trigger is queued once for every registered client.

        Parameters
        ----------
        trigger : int
            The trigger to be added to the queue for sending to StimClient.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see
            :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
            for more).

        See Also
        --------
        StimClient.get_trigger
        """
        for client in self._clients:
            client_id = client['id']
            logger.info("Sending trigger %d to client %d"
                        % (trigger, client_id))
            # client['socket'] is the request handler of that client.
            client['socket']._tx_queue.put(trigger)
class StimClient(object):
    """Stimulation Client.

    Client to communicate with StimServer

    Parameters
    ----------
    host : str
        Hostname (or IP address) of the host where StimServer is running.
    port : int
        Port to use for the connection.
    timeout : float
        Communication timeout in seconds.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).

    See Also
    --------
    StimServer
    """

    @verbose
    def __init__(self, host, port=4218, timeout=5.0,
                 verbose=None):  # noqa: D102
        self._sock = None
        try:
            logger.info("Setting up client socket")
            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._sock.settimeout(timeout)
            self._sock.connect((host, port))

            logger.info("Establishing connection with server")
            data = "add client".encode('utf-8')
            n_sent = self._sock.send(data)
            if n_sent != len(data):
                raise RuntimeError('Could not communicate with server')
            resp = self._sock.recv(1024).decode()  # turn bytes into str (Py3k)

            if resp == 'Client added':
                logger.info("Connection established")
            else:
                raise RuntimeError('Client not added')

        except Exception:
            # Do not leak the socket when the handshake fails.
            if self._sock is not None:
                self._sock.close()
            raise RuntimeError('Setting up acquisition <-> stimulation '
                               'computer connection (host: %s '
                               'port: %d) failed. Make sure StimServer '
                               'is running.' % (host, port))

    def close(self):
        """Close the socket object."""
        self._sock.close()

    @verbose
    def get_trigger(self, timeout=5.0, verbose=None):
        """Get triggers from StimServer.

        Repeatedly asks the server for a trigger until one is received or
        ``timeout`` seconds have elapsed.

        Parameters
        ----------
        timeout : float
            maximum time to wait for a valid trigger from the server
        verbose : bool, str, int, or None
            If not None, override default verbose level (see
            :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
            for more).

        Returns
        -------
        trigger : int | None
            The received trigger, or None if nothing was received before
            the timeout or the server closed the connection.

        See Also
        --------
        StimServer.add_trigger
        """
        start_time = time.time()  # init delay counter. Will stop iterations

        while True:
            try:
                current_time = time.time()

                # Give up once the deadline has passed
                if current_time > (start_time + timeout):
                    logger.info("received nothing")
                    return None

                self._sock.send("get trigger".encode('utf-8'))

                # recv() returns bytes; decode before comparing with the
                # text marker, otherwise the comparison is never equal on
                # Python 3 and "Empty" would be passed to int().
                trigger = self._sock.recv(1024).decode('utf-8')

                # An empty payload means the server closed the connection.
                if not trigger:
                    logger.info("received nothing")
                    return None

                if trigger != 'Empty':
                    logger.info("received trigger %s" % str(trigger))
                    return int(trigger)

            except RuntimeError as err:
                logger.info('Cannot receive triggers: %s' % (err))