From 381e86c6d863dcbde3d626abebc1519131d94c13 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Fri, 10 Nov 2017 11:28:43 -0800 Subject: [PATCH] CP #4112, #4122: Properly handle side input exception when all reader threads complete --- .../apache_beam/runners/worker/sideinputs.py | 21 +++++++++++-------- .../runners/worker/sideinputs_test.py | 19 +++++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/sideinputs.py b/sdks/python/apache_beam/runners/worker/sideinputs.py index bdf9f4e71f5e..6c7831d41365 100644 --- a/sdks/python/apache_beam/runners/worker/sideinputs.py +++ b/sdks/python/apache_beam/runners/worker/sideinputs.py @@ -116,6 +116,7 @@ def _reader_thread(self): self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL) def __iter__(self): + # pylint: disable=too-many-nested-blocks if self.already_iterated: raise RuntimeError( 'Can only iterate once over PrefetchingSourceSetIterable instance.') @@ -128,15 +129,17 @@ def __iter__(self): num_readers_finished = 0 try: while True: - element = self.element_queue.get() - if element is READER_THREAD_IS_DONE_SENTINEL: - num_readers_finished += 1 - if num_readers_finished == self.num_reader_threads: - return - elif self.has_errored: - raise self.reader_exceptions.get() - else: - yield element + try: + element = self.element_queue.get() + if element is READER_THREAD_IS_DONE_SENTINEL: + num_readers_finished += 1 + if num_readers_finished == self.num_reader_threads: + return + else: + yield element + finally: + if self.has_errored: + raise self.reader_exceptions.get() except GeneratorExit: self.has_errored = True raise diff --git a/sdks/python/apache_beam/runners/worker/sideinputs_test.py b/sdks/python/apache_beam/runners/worker/sideinputs_test.py index d243bbe4e6ee..bb688dd1c927 100644 --- a/sdks/python/apache_beam/runners/worker/sideinputs_test.py +++ b/sdks/python/apache_beam/runners/worker/sideinputs_test.py @@ -91,6 +91,24 @@ def test_multiple_sources_single_reader_iterator_fn(self): sources, max_reader_threads=1) assert list(strip_windows(iterator_fn())) == range(11) + def test_source_iterator_single_source_exception(self): + class MyException(Exception): + pass + + def exception_generator(): + yield 0 + raise MyException('I am an exception!') + + sources = [ + FakeSource(exception_generator()), + ] + iterator_fn = sideinputs.get_iterator_fn_for_sources(sources) + seen = set() + with self.assertRaises(MyException): + for value in iterator_fn(): + seen.add(value.value) + self.assertEqual(sorted(seen), [0]) + def test_source_iterator_fn_exception(self): class MyException(Exception): pass @@ -103,6 +121,7 @@ def exception_generator(): def perpetual_generator(value): while True: yield value + time.sleep(0.1) sources = [ FakeSource(perpetual_generator(1)),