/
mailbox.py
527 lines (435 loc) · 19.3 KB
/
mailbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
# pylint: disable=redefined-builtin
from concurrent.futures import Future, TimeoutError
import heapq
import sys
import threading
import typing
import logging
from strax.utils import exporter
export, __all__ = exporter()
@export
class MailboxException(Exception):
    """Common base class for every error raised by the mailbox machinery."""
@export
class MailboxReadTimeout(MailboxException):
    """Raised when waiting for the next message (or, in lazy mode, for a driving subscriber to
    request one) takes longer than the mailbox timeout."""
@export
class MailboxFullTimeout(MailboxException):
    """Raised when a sender waits longer than the mailbox timeout for room in a full mailbox."""
@export
class InvalidMessageNumber(MailboxException):
    """Raised when a message number is not an integer, or was already read by all subscribers."""
@export
class MailBoxAlreadyClosed(MailboxException):
    """Raised when a message is sent to a mailbox that has already been closed."""
@export
class MailboxKilled(MailboxException):
    """Raised in readers (and force-killed senders) of a mailbox that has been killed.

    The first argument carries the reason the mailbox was killed.
    """
@export
class Mailbox:
    """Publish/subscribe mailbox for building complex pipelines out of simple iterators, using
    multithreading.

    A sender can be any iterable. To read from the mailbox, either:
     1. Use .subscribe() to get an iterator.
        You can only use one of these per thread.
     2. Use .add_reader(f) to subscribe the function f.
        f should take an iterator as its first argument
        (and actually iterate over it, of course).

    Each sender and receiver is wrapped in a thread, so they can be paused:
     - senders, if the mailbox is full;
     - readers, if they call next() but the next message is not yet available.

    Any futures sent in are awaited before they are passed to receivers.

    Exceptions in a sender cause MailboxKilled to be raised in each reader.
    If the reader doesn't catch this, and it writes to another mailbox,
    this therefore kills that mailbox (raises MailboxKilled for each reader)
    as well. Thus MailboxKilled exceptions travel downstream in pipelines.

    Sender threads are not killed by exceptions raised in readers.
    To kill sender threads too, use .kill(upstream=True). Even this does not
    propagate further upstream than the immediate sender threads.

    """

    # In strax, these are overridden by context options
    # 'timeout' and 'max_messages'. They are here only to support
    # creating mailboxes directly without strax.
    DEFAULT_TIMEOUT = 300
    DEFAULT_MAX_MESSAGES = 4

    class _Condition:
        """Small helper class which wraps "threading.Condition" to get some useful logging
        information for debugging."""

        def __init__(self, name, log, lock):
            self.log = log
            self._lock = lock
            self.name = name
            self.log.debug(f'Initialize "{name}" with lock state: {lock}.')
            self.threading_condition = threading.Condition(lock=lock)

        def notify_all(self):
            """Wake up every thread waiting on this condition."""
            self.log.debug(f"Notifying all for {self.name} with lock state: {self._lock}")
            self.threading_condition.notify_all()

        def wait_for(self, predicate, timeout=None):
            """Block until predicate() is truthy or timeout (seconds) expires.

            Returns the last value of predicate(), i.e. a falsy value on timeout.

            """
            self.log.debug(
                f'Waiting for a change in "{self.name}" with state {predicate} and lock '
                f"state: {self._lock}"
            )
            return self.threading_condition.wait_for(predicate, timeout=timeout)

    def __init__(self, name="mailbox", timeout=None, lazy=False, max_messages=None):
        """Create an empty mailbox without senders or readers.

        :param name: Name used for the logger and for default thread names.
        :param timeout: Seconds to wait on any blocking operation. Defaults to DEFAULT_TIMEOUT.
        :param lazy: If True, senders only fetch a new element once a driving subscriber is
            waiting for it, and the mailbox size is unbounded.
        :param max_messages: Maximum number of buffered messages before senders block.
            Defaults to DEFAULT_MAX_MESSAGES; ignored in lazy mode.

        """
        self.name = name
        self.timeout = self.DEFAULT_TIMEOUT if timeout is None else timeout
        self.max_messages = self.DEFAULT_MAX_MESSAGES if max_messages is None else max_messages
        self.lazy = lazy
        if self.lazy:
            self.max_messages = float("inf")

        self.closed = False
        self.force_killed = False
        self.killed = False
        self.killed_because = None

        # Heap of (msg_number, msg) tuples
        self._mailbox = []
        # Per subscriber: highest message number read so far (-1 = none yet)
        self._subscribers_have_read = []
        # Per subscriber: message number it is blocked on, or None
        self._subscriber_waiting_for = []
        # Per subscriber: whether it may trigger fetching in lazy mode
        self._subscriber_can_drive = []
        self._n_sent = 0
        self._threads = []
        self._lock = threading.RLock()

        self.log = logging.getLogger(self.name)

        # Conditions to wait on
        # Do NOT call notify_all when the condition is False!
        # We use wait_for, which also returns False when the timeout is broken
        # (Is this an odd design decision in the standard library
        # or am I misunderstanding something?)

        # If you're waiting to read a new message that hasn't yet arrived:
        self._read_condition = self._Condition("_read_condition", self.log, lock=self._lock)

        # If you're waiting to write a new message because the mailbox is full
        self._write_condition = self._Condition("_write_condition", self.log, lock=self._lock)

        # If you're waiting to fetch a new element because the subscribers
        # still have other things to do
        self._fetch_new_condition = self._Condition(
            "_fetch_new_condition", self.log, lock=self._lock
        )

        self.log.debug("Initialized")

    def add_sender(self, source, name=None):
        """Configure mailbox to read from an iterable source.

        :param source: Iterable to read from
        :param name: Name of the thread in which the function will run. Defaults to
            source:<mailbox_name>

        """
        if name is None:
            name = f"source:{self.name}"
        t = threading.Thread(target=self._send_from, name=name, args=(source,))
        self._threads.append(t)

    def add_reader(self, subscriber, name=None, can_drive=True, **kwargs):
        """Subscribe a function to the mailbox.

        :param subscriber: Function which accepts a generator over messages as the first argument.
            Any kwargs will also be passed to the function.
        :param name: Name of the thread in which the function will run. Defaults to
            read_<number>:<mailbox_name>
        :param can_drive: Whether this reader can cause new messages to be generated when in lazy
            mode.

        """
        if name is None:
            name = f"read_{self._n_subscribers}:{self.name}"
        t = threading.Thread(
            target=subscriber,
            name=name,
            args=(self.subscribe(can_drive=can_drive),),
            kwargs=kwargs,
        )
        self._threads.append(t)

    def subscribe(self, can_drive=True):
        """Return generator over messages in the mailbox.

        The subscriber is registered immediately, not on first iteration.

        """
        with self._lock:
            subscriber_i = self._n_subscribers
            self._subscriber_can_drive.append(can_drive)
            self._subscribers_have_read.append(-1)
            self._subscriber_waiting_for.append(None)
            self.log.debug(f"Added subscriber {subscriber_i}")
            return self._read(subscriber_i=subscriber_i)

    def start(self):
        """Start all sender and reader threads.

        :raises ValueError: if no subscriber has been registered.

        """
        if not self._n_subscribers:
            raise ValueError(f"Attempt to start mailbox {self.name} without subscribers")
        for t in self._threads:
            t.start()

    def kill_from_exception(self, e, reraise=True):
        """Kill the mailbox following a caught exception e.

        A MailboxKilled exception is propagated (its reason is reused) but never re-raised. Any
        other exception is re-raised after the kill unless reraise is False.

        """
        if isinstance(e, MailboxKilled):
            # Kill this mailbox too.
            self.log.debug("Propagating MailboxKilled exception")
            # Guard against a MailboxKilled constructed without a reason
            self.kill(reason=e.args[0] if e.args else None)
            # Do NOT raise! One traceback on the screen is enough.
        else:
            self.log.debug(f"Killing mailbox due to exception {e}!")
            # e.__traceback__ is also valid when we are not inside an except block
            self.kill(reason=(e.__class__, e, e.__traceback__))
            if reraise:
                raise e

    def kill(self, upstream=True, reason=None):
        """Kill the mailbox and wake up every thread waiting on it.

        :param upstream: If True, senders will get MailboxKilled on their next send instead of
            having their messages silently dropped.
        :param reason: Stored in killed_because and passed on in MailboxKilled exceptions.

        """
        with self._lock:
            self.log.debug(f"Kill received by {self.name}")
            if upstream:
                self.force_killed = True
            if self.killed:
                self.log.debug(f"Double kill on {self.name} = NOP")
                return
            self.killed = True
            self.killed_because = reason
            self._read_condition.notify_all()
            self._write_condition.notify_all()
            self._fetch_new_condition.notify_all()

    def cleanup(self):
        """Join all threads, waiting at most timeout seconds for each.

        :raises RuntimeError: if a thread is still alive after the timeout.

        """
        for t in self._threads:
            t.join(timeout=self.timeout)
            if t.is_alive():
                raise RuntimeError("Thread %s did not terminate!" % t.name)

    def _can_fetch(self):
        """Return whether we can fetch, then send, the next element from the source.

        Only meaningful in lazy mode; must be called with the lock held. Returns True when the
        mailbox is killed, so the sender wakes up and .send() can deal with the kill.

        """
        assert self.lazy

        # The .send() knows how to handle the exception properly
        # (if we raise here we will likely duplicate the exception)
        if self.killed:
            return True

        # If someone is still waiting for a message we already have
        # (so they just haven't woken up yet), don't fetch a new message.
        if self._mailbox and any(
            x is not None and x <= self._lowest_msg_number for x in self._subscriber_waiting_for
        ):
            return False

        # Everyone is waiting for the new chunk or not at all.
        # Fetch only if a driver is waiting.
        return any(
            can_drive and waiting_for is not None
            for can_drive, waiting_for in zip(
                self._subscriber_can_drive, self._subscriber_waiting_for
            )
        )

    def _send_from(self, iterable):
        """Send to mailbox from iterable, exiting appropriately if an exception is thrown.

        Any iterable is accepted. If sending fails and the source supports .throw() (i.e. it is
        a generator), the exception is thrown into it so it can clean up.

        """
        try:
            source = iter(iterable)
            i = 0
            while True:
                if self.lazy:
                    with self._lock:
                        if not self._can_fetch():
                            self.log.debug(
                                f"Waiting to fetch {i}, "
                                f"{self._subscriber_waiting_for}, "
                                f"{self._subscriber_can_drive}"
                            )
                            if not self._fetch_new_condition.wait_for(
                                self._can_fetch, timeout=self.timeout
                            ):
                                raise MailboxReadTimeout(
                                    f"{self} could not progress beyond {i}, "
                                    "no driving subscriber requested it."
                                )

                try:
                    x = next(source)
                except StopIteration:
                    # No need to send this yet, close will do that
                    break

                try:
                    self.send(x)
                except Exception as e:
                    # Inform the source we're going down
                    # (plain iterators have no .throw; skip it for those)
                    throw = getattr(source, "throw", None)
                    if throw is not None:
                        throw(e)
                    raise

                i += 1

        except Exception as e:
            self.kill_from_exception(e)
        else:
            self.log.debug("Producing iterable exhausted, regular stop")
            self.close()

    def send(self, msg, msg_number: typing.Union[int, None] = None):
        """Send a message.

        If the message is a future, receivers will be passed its result. (possibly waiting for
        completion if needed)

        If the mailbox is currently full, sleep until there is room for your message (or timeout
        occurs)

        :param msg: Message to send.
        :param msg_number: Position of the message in the stream. Defaults to the number of
            messages sent so far.
        :raises MailBoxAlreadyClosed: if the mailbox was closed.
        :raises MailboxKilled: if the mailbox was killed with upstream=True.
        :raises InvalidMessageNumber: if msg_number is not integer-valued or was already read
            by all subscribers.
        :raises MailboxFullTimeout: if no room became available within the timeout.

        """
        with self._lock:
            if self.closed:
                raise MailBoxAlreadyClosed(f"Can't send to closed {self.name}")
            if self.force_killed:
                self.log.debug(f"Sender found {self.name} force-killed")
                raise MailboxKilled(self.killed_because)
            if self.killed:
                self.log.debug("Send to killed mailbox: message lost")
                return

            # We accept int numbers or anything which equals its int(...)
            # (like numpy integers)
            if msg_number is None:
                msg_number = self._n_sent
            try:
                is_integer = bool(msg_number == int(msg_number))
            except (TypeError, ValueError, OverflowError):
                is_integer = False
            if not is_integer:
                raise InvalidMessageNumber("Msg numbers must be integers")

            read_until = min(self._subscribers_have_read, default=-1)
            if msg_number <= read_until:
                raise InvalidMessageNumber(
                    f"Attempt to send message {msg_number} while "
                    f"subscribers already read {read_until}."
                )

            def can_write():
                return len(self._mailbox) < self.max_messages or self.killed

            if not can_write():
                self.log.debug("Subscribers have read: " + str(self._subscribers_have_read))
                self.log.debug(f"Mailbox full, wait to send {msg_number}")
                if not self._write_condition.wait_for(can_write, timeout=self.timeout):
                    raise MailboxFullTimeout(f"Mailbox buffer for {self.name} emptied too slow.")

            # We may have been woken up by a kill rather than by free space
            if self.killed:
                self.log.debug(
                    f"Sender found {self.name} killed while waiting for room for new messages."
                )
                if self.force_killed:
                    raise MailboxKilled(self.killed_because)
                return

            heapq.heappush(self._mailbox, (msg_number, msg))
            self.log.debug(f"Sent {msg_number}")
            self._n_sent += 1
            self._read_condition.notify_all()

    def close(self):
        """Send the StopIteration end marker, then refuse any further messages."""
        self.log.debug("Closing; sending StopIteration")
        with self._lock:
            self.send(StopIteration)
            self.closed = True
        self.log.debug("Closed to incoming messages")

    def _read(self, subscriber_i):
        """Iterate over incoming messages in order.

        Your thread will sleep until the next message is available, or timeout expires (in which
        case MailboxReadTimeout is raised)

        :param subscriber_i: Index of this subscriber in the per-subscriber bookkeeping lists.
        :raises MailboxKilled: if the mailbox is killed while reading.

        """
        self.log.debug("Start reading")
        next_number = 0
        last_message = False

        while not last_message:
            with self._lock:
                # Wait for new messages
                def next_ready():
                    return self._has_msg(next_number) or self.killed

                if not next_ready():
                    self.log.debug(f"Checking/waiting for {next_number}")
                    self._subscriber_waiting_for[subscriber_i] = next_number
                    if self.lazy and self._can_fetch():
                        self._fetch_new_condition.notify_all()
                    if not self._read_condition.wait_for(next_ready, self.timeout):
                        raise MailboxReadTimeout(f"{self.name} did not get {next_number} in time.")
                    self._subscriber_waiting_for[subscriber_i] = None

                if self.killed:
                    self.log.debug(f"Reader finds {self.name} killed")
                    raise MailboxKilled(self.killed_because)

                # Grab all messages we can yield
                to_yield = []
                while self._has_msg(next_number):
                    msg = self._get_msg(next_number)
                    if msg is StopIteration:
                        self.log.debug(f"{next_number} is StopIteration")
                        last_message = True
                    to_yield.append((next_number, msg))
                    next_number += 1

                if len(to_yield) > 1:
                    self.log.debug(
                        f"Read {to_yield[0][0]}-{to_yield[-1][0]} in subscriber {subscriber_i}"
                    )
                else:
                    self.log.debug(f"Read {to_yield[0][0]} in subscriber {subscriber_i}")

                self._subscribers_have_read[subscriber_i] = next_number - 1

                # Clean up the mailbox: drop messages every subscriber has read
                while self._mailbox and (
                    min(self._subscribers_have_read) >= self._lowest_msg_number
                ):
                    heapq.heappop(self._mailbox)

                if self.lazy and self._can_fetch():
                    self._fetch_new_condition.notify_all()
                self._write_condition.notify_all()

            # Yield outside the lock, so other threads can progress meanwhile
            for msg_number, msg in to_yield:
                if msg is StopIteration:
                    break
                elif isinstance(msg, Future):
                    if not msg.done():
                        self.log.debug(f"Waiting for future {msg_number}")
                        try:
                            res = msg.result(timeout=self.timeout)
                        except TimeoutError:
                            raise TimeoutError(f"Future {msg_number} timed out!")
                        self.log.debug(f"Future {msg_number} completed")
                    else:
                        res = msg.result()
                        self.log.debug(f"Future {msg_number} was already done")
                else:
                    res = msg

                try:
                    yield res
                except Exception as e:
                    # TODO: Should I also handle timeout errors like this?
                    self.kill_from_exception(e)

        self.log.debug("Done reading")

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    def _get_msg(self, number):
        """Return the buffered message with the given number.

        :raises RuntimeError: if no such message is buffered.

        """
        for msg_number, msg in self._mailbox:
            if msg_number == number:
                return msg
        raise RuntimeError(f"Could not find message {number}")

    def _has_msg(self, number):
        """Return if mailbox has message number.

        Also returns True if mailbox is killed, so be sure to check self.killed after this!

        """
        if self.killed:
            return True
        return any(msg_number == number for msg_number, _ in self._mailbox)

    @property
    def _n_subscribers(self):
        """Number of subscribers registered so far."""
        return len(self._subscribers_have_read)

    @property
    def _lowest_msg_number(self):
        """Lowest message number currently buffered (the mailbox must not be empty)."""
        return self._mailbox[0][0]
@export
def divide_outputs(
    source, mailboxes: typing.Dict[str, "Mailbox"], lazy=False, flow_freely=tuple(), outputs=None
):
    """This code is a 'mail sorter' which gets dicts of arrays from source and sends the right array
    to the right mailbox.

    :param source: Iterable yielding dicts that map a mailbox key to the message for it. If it
        supports .throw() (i.e. it is a generator), send failures are thrown into it.
    :param mailboxes: Mapping from key to the mailbox receiving that key's messages.
    :param lazy: If True, wait before each fetch until every non-free-flowing output mailbox
        reports it can fetch.
    :param flow_freely: Keys of outputs that should never hold up fetching in lazy mode.
    :param outputs: Keys of the mailboxes to wait on, kill and close. Defaults to all keys
        of mailboxes.
    :raises Exception: any exception other than MailboxKilled is re-raised after all output
        mailboxes have been killed.

    """
    # raise ZeroDivisionError   # TODO: check this is handled properly
    if outputs is None:
        outputs = mailboxes.keys()
    # Materialize once: outputs is iterated again on every pass of the loop below,
    # so a one-shot iterator would otherwise be exhausted after the next line.
    outputs = tuple(outputs)

    mbs_to_kill = [mailboxes[d] for d in outputs]
    # TODO: this code duplicates exception handling and cleanup
    # from Mailbox.send_from! Can we avoid that somehow?
    i = 0
    try:
        # Accept any iterable, not just iterators/generators
        source = iter(source)
        while True:
            for d in outputs:
                m = mailboxes[d]
                if d in flow_freely:
                    # Do not block on account of these guys
                    m.log.debug(f"Not locking {d}")
                    continue

                if lazy:
                    with m._lock:
                        if not m._can_fetch():
                            m.log.debug(
                                f"Waiting to fetch {i}, "
                                f"{m._subscriber_waiting_for}, "
                                f"{m._subscriber_can_drive}"
                            )
                            if not m._fetch_new_condition.wait_for(
                                m._can_fetch, timeout=m.timeout
                            ):
                                raise MailboxReadTimeout(
                                    f"{m} could not progress beyond {i}, "
                                    "no driving subscriber requested it."
                                )

            try:
                result = next(source)
            except StopIteration:
                # No need to send this yet, close will do that
                break

            try:
                for d, x in result.items():
                    mailboxes[d].send(x)
            except Exception as e:
                # Inform the source we're going down
                # (plain iterators have no .throw; skip it for those)
                throw = getattr(source, "throw", None)
                if throw is not None:
                    throw(e)
                raise

            i += 1

    except Exception as e:
        for m in mbs_to_kill:
            m.kill_from_exception(e, reraise=False)
        # MailboxKilled only needs to propagate downstream, not crash this thread
        if not isinstance(e, MailboxKilled):
            raise

    else:
        for m in mbs_to_kill:
            m.close()