From be09a162e32d158f5ae043e064223bb4f3742648 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 14 Jun 2017 16:14:50 -0700 Subject: [PATCH] Migrate DirectRunner evaluators to use Beam state API --- .../runners/dataflow/native_io/iobase_test.py | 39 +++++++++- .../runners/direct/evaluation_context.py | 56 ++++++++++---- .../runners/direct/transform_evaluator.py | 74 ++++++++++--------- .../runners/direct/transform_result.py | 3 +- 4 files changed, 122 insertions(+), 50 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py index 7610baff6b47..3d8c24f5651c 100644 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py +++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py @@ -20,7 +20,9 @@ import unittest -from apache_beam import error, pvalue +from apache_beam import Create +from apache_beam import error +from apache_beam import pvalue from apache_beam.runners.dataflow.native_io.iobase import ( _dict_printable_fields, _NativeWrite, @@ -28,10 +30,12 @@ DynamicSplitRequest, DynamicSplitResultWithPosition, NativeSink, + NativeSinkWriter, NativeSource, ReaderPosition, ReaderProgress ) +from apache_beam.testing.test_pipeline import TestPipeline class TestHelperFunctions(unittest.TestCase): @@ -154,6 +158,39 @@ def __init__(self, validate=False, dataset=None, project=None, fake_sink = FakeSink() self.assertEqual(fake_sink.__repr__(), "") + def test_on_direct_runner(self): + class FakeSink(NativeSink): + """A fake sink outputing a number of elements.""" + + def __init__(self): + self.written_values = [] + self.writer_instance = FakeSinkWriter(self.written_values) + + def writer(self): + return self.writer_instance + + class FakeSinkWriter(NativeSinkWriter): + """A fake sink writer for testing.""" + + def __init__(self, written_values): + self.written_values = written_values + + def __enter__(self): + return self + + def __exit__(self, *unused_args): + pass + + def Write(self, value): + self.written_values.append(value) + + p = TestPipeline() + sink = FakeSink() + p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned + p.run() + + self.assertEqual(['a', 'b', 'c'], sink.written_values) + class Test_NativeWrite(unittest.TestCase): diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 68d99d373a7e..8fa8e06922d0 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -27,22 +27,22 @@ from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.runners.direct.executor import TransformExecutor from apache_beam.runners.direct.direct_metrics import DirectMetrics +from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.utils import counters class _ExecutionContext(object): - def __init__(self, watermarks, existing_state): - self._watermarks = watermarks - self._existing_state = existing_state + def __init__(self, watermarks, keyed_states): + self.watermarks = watermarks + self.keyed_states = keyed_states - @property - def watermarks(self): - return self._watermarks + self._step_context = None - @property - def existing_state(self): - return self._existing_state + def get_step_context(self): + if not self._step_context: + self._step_context = DirectStepContext(self.keyed_states) + return self._step_context class _SideInputView(object): @@ -145,9 +145,8 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._pcollection_to_views = collections.defaultdict(list) for view in views: self._pcollection_to_views[view.pvalue].append(view) - - # AppliedPTransform -> Evaluator specific state objects - self._application_state_interals = {} + self._transform_keyed_states = self._initialize_keyed_states( + root_transforms, value_to_consumers) self._watermark_manager = WatermarkManager( Clock(), root_transforms, value_to_consumers) self._side_inputs_container = _SideInputsContainer(views) @@ -158,6 +157,15 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._lock = threading.Lock() + def _initialize_keyed_states(self, root_transforms, value_to_consumers): + transform_keyed_states = {} + for transform in root_transforms: + transform_keyed_states[transform] = {} + for consumers in value_to_consumers.values(): + for consumer in consumers: + transform_keyed_states[consumer] = {} + return transform_keyed_states + def use_pvalue_cache(self, cache): assert not self._cache self._cache = cache @@ -231,7 +239,6 @@ def handle_result( counter.name, counter.combine_fn) merged_counter.accumulator.merge([counter.accumulator]) - self._application_state_interals[result.transform] = result.state return committed_bundles def get_aggregator_values(self, aggregator_or_name): @@ -256,7 +263,7 @@ def _commit_bundles(self, uncommitted_bundles): def get_execution_context(self, applied_ptransform): return _ExecutionContext( self._watermark_manager.get_watermarks(applied_ptransform), - self._application_state_interals.get(applied_ptransform)) + self._transform_keyed_states[applied_ptransform]) def create_bundle(self, output_pcollection): """Create an uncommitted bundle for the specified PCollection.""" @@ -296,3 +303,24 @@ def get_value_or_schedule_after_output(self, side_input, task): assert isinstance(task, TransformExecutor) return self._side_inputs_container.get_value_or_schedule_after_output( side_input, task) + + +class DirectUnmergedState(InMemoryUnmergedState): + """UnmergedState implementation for the DirectRunner.""" + + def __init__(self): + super(DirectUnmergedState, self).__init__(defensive_copy=False) + + +class DirectStepContext(object): + """Context for the currently-executing step.""" + + def __init__(self, keyed_existing_state): + self.keyed_existing_state = keyed_existing_state + + def get_keyed_state(self, key): + # TODO(ccy): consider implementing transactional copy on write semantics + # for state so that work items can be safely retried. + if not self.keyed_existing_state.get(key): + self.keyed_existing_state[key] = DirectUnmergedState() + return self.keyed_existing_state[key] diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index b1cb626ca0cb..f5b5db5c0a77 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -33,6 +33,8 @@ from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue +from apache_beam.transforms.trigger import _CombiningValueStateTag +from apache_beam.transforms.trigger import _ListStateTag from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn @@ -207,7 +209,7 @@ def _read_values_to_bundles(reader): bundles = _read_values_to_bundles(reader) return TransformResult( - self._applied_ptransform, bundles, None, None, None, None) + self._applied_ptransform, bundles, None, None, None) class _FlattenEvaluator(_TransformEvaluator): @@ -231,7 +233,7 @@ def process_element(self, element): def finish_bundle(self): bundles = [self.bundle] return TransformResult( - self._applied_ptransform, bundles, None, None, None, None) + self._applied_ptransform, bundles, None, None, None) class _TaggedReceivers(dict): @@ -320,7 +322,7 @@ def finish_bundle(self): bundles = self._tagged_receivers.values() result_counters = self._counter_factory.get_counters() return TransformResult( - self._applied_ptransform, bundles, None, None, result_counters, None, + self._applied_ptransform, bundles, None, result_counters, None, self._tagged_receivers.undeclared_in_memory_tag_values) @@ -328,13 +330,8 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator): """TransformEvaluator for _GroupByKeyOnly transform.""" MAX_ELEMENT_PER_BUNDLE = None - - class _GroupByKeyOnlyEvaluatorState(object): - - def __init__(self): - # output: {} key -> [values] - self.output = collections.defaultdict(list) - self.completed = False + ELEMENTS_TAG = _ListStateTag('elements') + COMPLETION_TAG = _CombiningValueStateTag('completed', any) def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): @@ -349,9 +346,8 @@ def _is_final_bundle(self): == WatermarkManager.WATERMARK_POS_INF) def start_bundle(self): - self.state = (self._execution_context.existing_state - if self._execution_context.existing_state - else _GroupByKeyOnlyEvaluator._GroupByKeyOnlyEvaluatorState()) + self.step_context = self._execution_context.get_step_context() + self.global_state = self.step_context.get_keyed_state(None) assert len(self._outputs) == 1 self.output_pcollection = list(self._outputs)[0] @@ -362,12 +358,15 @@ def start_bundle(self): self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) def process_element(self, element): - assert not self.state.completed + assert not self.global_state.get_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG) if (isinstance(element, WindowedValue) and isinstance(element.value, collections.Iterable) and len(element.value) == 2): k, v = element.value - self.state.output[self.key_coder.encode(k)].append(v) + encoded_k = self.key_coder.encode(k) + state = self.step_context.get_keyed_state(encoded_k) + state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v) else: raise TypeCheckError('Input to _GroupByKeyOnly must be a PCollection of ' 'windowed key-value pairs. Instead received: %r.' @@ -375,15 +374,23 @@ def process_element(self, element): def finish_bundle(self): if self._is_final_bundle: - if self.state.completed: + if self.global_state.get_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG): # Ignore empty bundles after emitting output. (This may happen because # empty bundles do not affect input watermarks.) bundles = [] else: - gbk_result = ( - map(GlobalWindows.windowed_value, ( - (self.key_coder.decode(k), v) - for k, v in self.state.output.iteritems()))) + gbk_result = [] + # TODO(ccy): perhaps we can clean this up to not use this + # internal attribute of the DirectStepContext. + for encoded_k in self.step_context.keyed_existing_state: + # Ignore global state. + if encoded_k is None: + continue + k = self.key_coder.decode(encoded_k) + state = self.step_context.get_keyed_state(encoded_k) + vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG) + gbk_result.append(GlobalWindows.windowed_value((k, vs))) def len_element_fn(element): _, v = element.value @@ -393,21 +400,22 @@ def len_element_fn(element): self.output_pcollection, gbk_result, _GroupByKeyOnlyEvaluator.MAX_ELEMENT_PER_BUNDLE, len_element_fn) - self.state.completed = True - state = self.state + self.global_state.add_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG, True) hold = WatermarkManager.WATERMARK_POS_INF else: bundles = [] - state = self.state hold = WatermarkManager.WATERMARK_NEG_INF return TransformResult( - self._applied_ptransform, bundles, state, None, None, hold) + self._applied_ptransform, bundles, None, None, hold) class _NativeWriteEvaluator(_TransformEvaluator): """TransformEvaluator for _NativeWrite transform.""" + ELEMENTS_TAG = _ListStateTag('elements') + def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): assert not side_inputs @@ -429,12 +437,12 @@ def _has_already_produced_output(self): == WatermarkManager.WATERMARK_POS_INF) def start_bundle(self): - # state: [values] - self.state = (self._execution_context.existing_state - if self._execution_context.existing_state else []) + self.step_context = self._execution_context.get_step_context() + self.global_state = self.step_context.get_keyed_state(None) def process_element(self, element): - self.state.append(element) + self.global_state.add_state( + None, _NativeWriteEvaluator.ELEMENTS_TAG, element) def finish_bundle(self): # finish_bundle will append incoming bundles in memory until all the bundles @@ -444,19 +452,19 @@ def finish_bundle(self): # ignored and would not generate additional output files. # TODO(altay): Do not wait until the last bundle to write in a single shard. if self._is_final_bundle: + elements = self.global_state.get_state( + None, _NativeWriteEvaluator.ELEMENTS_TAG) if self._has_already_produced_output: # Ignore empty bundles that arrive after the output is produced. - assert self.state == [] + assert elements == [] else: self._sink.pipeline_options = self._evaluation_context.pipeline_options with self._sink.writer() as writer: - for v in self.state: + for v in elements: writer.Write(v.value) - state = None hold = WatermarkManager.WATERMARK_POS_INF else: - state = self.state hold = WatermarkManager.WATERMARK_NEG_INF return TransformResult( - self._applied_ptransform, [], state, None, None, hold) + self._applied_ptransform, [], None, None, hold) diff --git a/sdks/python/apache_beam/runners/direct/transform_result.py b/sdks/python/apache_beam/runners/direct/transform_result.py index febdd202aa0a..51593e3a434b 100644 --- a/sdks/python/apache_beam/runners/direct/transform_result.py +++ b/sdks/python/apache_beam/runners/direct/transform_result.py @@ -25,12 +25,11 @@ class TransformResult(object): The result of evaluating an AppliedPTransform with a TransformEvaluator.""" - def __init__(self, applied_ptransform, uncommitted_output_bundles, state, + def __init__(self, applied_ptransform, uncommitted_output_bundles, timer_update, counters, watermark_hold, undeclared_tag_values=None): self.transform = applied_ptransform self.uncommitted_output_bundles = uncommitted_output_bundles - self.state = state # TODO: timer update is currently unused. self.timer_update = timer_update self.counters = counters