diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index ba039272fdb26..00e05f60b3e3e 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -139,23 +139,16 @@ def done(self): self._state = self.DONE_STATE -class _PartitionBuffer(object): - """This class is created to support partition(n) function.""" - def __init__(self, inputs): - self._inputs = inputs - +class _ListBuffer(list): + """Used to support partitioning of a list.""" def partition(self, n): - v = list(self._inputs.values())[0] - if isinstance(v, list): - return [self._inputs] - elif isinstance(v, _GroupingBuffer): - partitions = [] - for name, input in self._inputs.items(): - for part in input.partition(n): - partitions.append({name : part}) - return partitions - else: - raise NotImplementedError(type(self._inputs)) + n = min(n, len(self)) + # for empty list, return iter([[]]) + groups = [[] for _ in range(max(n, 1))] + for idx, input in enumerate(self): + groups[idx % n].append(input) + + return iter(groups) class _GroupingBuffer(object): @@ -184,7 +177,11 @@ def append(self, elements_data): value if is_trivial_windowing else windowed_key_value.with_value(value)) - def _output_ready(self): + def partition(self, n): + """ It is used to partition _GroupingBuffer to N parts. Once it is + partitioned, it would not be re-partitioned with diff N. Re-partition + is not supported now. + """ if len(self._grouped_output) == 0: if self._windowing.is_default(): globally_window = GlobalWindows.windowed_value(None).with_value @@ -197,32 +194,33 @@ def _output_ready(self): # May need to revise. 
trigger_driver = trigger.create_trigger_driver(self._windowing, True) windowed_key_values = trigger_driver.process_entire_key - - key_coder_impl = self._key_coder.get_impl() coder_impl = self._post_grouped_coder.get_impl() - for encoded_key, windowed_values in self._table.items(): + key_coder_impl = self._key_coder.get_impl() + n = min(n, len(self._table)) + output_stream_list = [] + for _ in range(n): + output_stream_list.append(create_OutputStream()) + for idx, (encoded_key, windowed_values) in enumerate(self._table.items()): key = key_coder_impl.decode(encoded_key) - output_stream = create_OutputStream() for wkvs in windowed_key_values(key, windowed_values): - coder_impl.encode_to_stream(wkvs, output_stream, True) - - self._grouped_output.append(output_stream.get()) + coder_impl.encode_to_stream(wkvs, output_stream_list[idx % n], True) + for output_stream in output_stream_list: + self._grouped_output.append([output_stream.get()]) self._table = None - - def __iter__(self): - self._output_ready() return iter(self._grouped_output) - def partition(self, n): - self._output_ready() - - n = min(n, len(self._grouped_output)) - partitions = [[] for _ in range(n)] - - for idx, out in enumerate(self._grouped_output): - partitions[idx % n].append(out) - - return partitions + def __iter__(self): + """ Since partition() returns a list of list, added this __iter__ to return + a list to simplify code when we need to iterate through ALL elements of + _GroupingBuffer. + """ + if len(self._grouped_output) == 0: + self.partition(1) + iter_result = [] + for output in self._grouped_output: + for out in output: + iter_result.append(out) + return iter(iter_result) class _WindowGroupingBuffer(object): @@ -268,13 +266,11 @@ def encoded_items(self): class FnApiRunner(runner.PipelineRunner): - _num_workers = 1 def __init__( self, default_environment=None, bundle_repeat=0, - num_workers=1, use_state_iterables=False, provision_info=None): """Creates a new Fn API Runner. 
@@ -288,12 +284,12 @@ def __init__( provision_info: provisioning info to make available to workers, or None """ super(FnApiRunner, self).__init__() - FnApiRunner._num_workers = num_workers self._last_uid = -1 self._default_environment = ( default_environment or beam_runner_api_pb2.Environment(urn=python_urns.EMBEDDED_PYTHON)) self._bundle_repeat = bundle_repeat + self._num_workers = 1 self._progress_frequency = None self._profiler_factory = None self._use_state_iterables = use_state_iterables @@ -322,8 +318,8 @@ def run_pipeline(self, pipeline, options): pipeline.visit(DataflowRunner.group_by_key_input_visitor()) self._bundle_repeat = self._bundle_repeat or options.view_as( pipeline_options.DirectOptions).direct_runner_bundle_repeat - FnApiRunner._num_workers = max(FnApiRunner._num_workers, options.view_as( - pipeline_options.DirectOptions).direct_num_workers) + self._num_workers = options.view_as( + pipeline_options.DirectOptions).direct_num_workers or self._num_workers self._profiler_factory = profiler.Profile.factory_from_options( options.view_as(pipeline_options.ProfilingOptions)) @@ -413,7 +409,7 @@ def run_stages(self, stage_context, stages): try: with self.maybe_profile(): - pcoll_buffers = collections.defaultdict(list) + pcoll_buffers = collections.defaultdict(_ListBuffer) for stage in stages: stage_results = self._run_stage( worker_handler_manager.get_worker_handler, @@ -455,6 +451,7 @@ def _store_side_inputs_in_state(self, def _run_bundle_multiple_times_for_testing(self, controller, process_bundle_descriptor, + num_workers, data_input, data_output, get_input_coder_callable): @@ -463,7 +460,7 @@ def _run_bundle_multiple_times_for_testing(self, controller.state.checkpoint() ParallelBundleManager( controller, lambda pcoll_id: [], get_input_coder_callable, - process_bundle_descriptor, self._progress_frequency, k + process_bundle_descriptor, num_workers, self._progress_frequency, k ).process_bundle(data_input, data_output) finally: controller.state.restore() 
@@ -499,7 +496,7 @@ def _collect_written_timers_and_add_to_deferred_inputs(self, for windowed_key_timer in timers_by_key_and_window.values(): windowed_timer_coder_impl.encode_to_stream( windowed_key_timer, out, True) - deferred_inputs[transform_id] = [out.get()] + deferred_inputs[transform_id] = _ListBuffer([out.get()]) written_timers[:] = [] def _add_residuals_and_channel_splits_to_deferred_inputs( @@ -554,7 +551,7 @@ def _extract_stage_data_endpoints( if transform.spec.urn == bundle_processor.DATA_INPUT_URN: target = transform.unique_name, only_element(transform.outputs) if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER: - data_input[target] = [ENCODED_IMPULSE_VALUE] + data_input[target] = _ListBuffer([ENCODED_IMPULSE_VALUE]) else: data_input[target] = pcoll_buffers[pcoll_id] coder_id = pipeline_components.pcollections[ @@ -645,7 +642,7 @@ def get_buffer(buffer_id): """Returns the buffer for a given (operation_type, PCollection ID). For grouping-typed operations, we produce a ``_GroupingBuffer``. For - others, we just use a list. + others, we produce a ``_ListBuffer``. 
""" kind, name = split_buffer_id(buffer_id) if kind in ('materialize', 'timers'): @@ -684,13 +681,14 @@ def get_input_coder_impl(transform_id): self._run_bundle_multiple_times_for_testing(controller, process_bundle_descriptor, + self._num_workers, data_input, data_output, get_input_coder_impl) bundle_manager = ParallelBundleManager( controller, get_buffer, get_input_coder_impl, process_bundle_descriptor, - self._progress_frequency) + self._num_workers, self._progress_frequency) result, splits = bundle_manager.process_bundle(data_input, data_output) @@ -708,7 +706,7 @@ def input_for(ptransform_id, input_id): last_sent = data_input while True: - deferred_inputs = collections.defaultdict(list) + deferred_inputs = collections.defaultdict(_ListBuffer) self._collect_written_timers_and_add_to_deferred_inputs( context, pipeline_components, stage, get_buffer, deferred_inputs) @@ -729,15 +727,14 @@ def input_for(ptransform_id, input_id): # The worker will be waiting on these inputs as well. for other_input in data_input: if other_input not in deferred_inputs: - deferred_inputs[other_input] = [] + deferred_inputs[other_input] = _ListBuffer([]) # TODO(robertwb): merge results - last_result, splits = ParallelBundleManager( - controller, - get_buffer, - get_input_coder_impl, - process_bundle_descriptor, - self._progress_frequency, - True).process_bundle(deferred_inputs, data_output) + bundle_manager._skip_registration = True + # We cannot split deferred_input until we include residual_roots to + # merged results. Without residual_roots, pipeline stops earlier and we + # may miss some data. 
+ last_result, splits = bundle_manager.process_bundle( + deferred_inputs, data_output, num_workers=1) last_sent = deferred_inputs result = beam_fn_api_pb2.InstructionResponse( process_bundle=beam_fn_api_pb2.ProcessBundleResponse( @@ -783,7 +780,8 @@ def _extract_endpoints(stage, pcoll_id = transform.spec.payload if transform.spec.urn == bundle_processor.DATA_INPUT_URN: if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER: - data_input[transform.unique_name] = [ENCODED_IMPULSE_VALUE] + data_input[transform.unique_name] = _ListBuffer( + [ENCODED_IMPULSE_VALUE]) else: data_input[transform.unique_name] = pcoll_buffers[pcoll_id] coder_id = pipeline_components.pcollections[ @@ -1027,7 +1025,6 @@ def push(self, request): if not request.instruction_id: self._uid_counter += 1 request.instruction_id = 'control_%s' % self._uid_counter - response = self.worker.do_instruction(request) return ControlFuture(request.instruction_id, response) @@ -1370,7 +1367,7 @@ class BundleManager(object): def __init__( self, controller, get_buffer, get_input_coder_impl, bundle_descriptor, - progress_frequency=None, skip_registration=False): + num_workers, progress_frequency=None, skip_registration=False): """Set up a bundle manager. Args: @@ -1384,6 +1381,7 @@ def __init__( self._controller = controller self._get_buffer = get_buffer self._get_input_coder_impl = get_input_coder_impl + self._num_workers = num_workers self._bundle_descriptor = bundle_descriptor self._registered = skip_registration self._progress_frequency = progress_frequency @@ -1485,12 +1483,9 @@ def _generate_splits_for_testing(self, break return split_results - def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None): + def process_bundle(self, inputs, expected_outputs): # Unique id for the instruction processing this bundle. 
- if parallel_uid_counter: - BundleManager._uid_counter = parallel_uid_counter - else: - BundleManager._uid_counter += 1 + BundleManager._uid_counter += 1 process_bundle_id = 'bundle_%s' % BundleManager._uid_counter # Register the bundle descriptor, if needed - noop if already registered. @@ -1500,7 +1495,8 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None): if not split_manager: # If there is no split_manager, write all input data to the channel. for transform_id, elements in inputs.items(): - self._send_input_to_worker(process_bundle_id, transform_id, elements) + self._send_input_to_worker( + process_bundle_id, transform_id, elements) # Check that the bundle was successfully registered. if registration_future and registration_future.get().error: @@ -1518,8 +1514,8 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None): self._controller, process_bundle_id, self._progress_frequency): if split_manager: - split_results = self._generate_splits_for_testing( - split_manager, inputs, process_bundle_id) + split_results = self._generate_splits_for_testing(split_manager, inputs, + process_bundle_id) # Gather all output data. for output in self._controller.data_plane_handler.input_elements( @@ -1550,20 +1546,30 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None): class ParallelBundleManager(BundleManager): - _uid_counter = 0 - def process_bundle(self, inputs, expected_outputs): - num_workers = FnApiRunner._num_workers + def _check_inputs_split(self, expected_outputs): + # We skip splitting inputs when timer is set, because operations are not + # triggered until we send inputs for timers. 
+ for _, pcoll_id in expected_outputs.items(): + kind = split_buffer_id(pcoll_id)[0] + if kind in ['timers']: + return False + + return True + + def process_bundle(self, inputs, expected_outputs, num_workers=None): + num_workers = num_workers or self._num_workers param_list = [] - for part in _PartitionBuffer(inputs).partition(num_workers): - ParallelBundleManager._uid_counter += 1 - param_list.append((part, expected_outputs, - ParallelBundleManager._uid_counter)) + if self._check_inputs_split(expected_outputs): + for name, input in inputs.items(): + for part in input.partition(num_workers): + param_list.append(({name : part}, expected_outputs)) + else: + param_list.append((inputs, expected_outputs)) merged_result = None split_result_list = [] - with futures.ThreadPoolExecutor(max_workers=num_workers) as executor: for result, split_result in executor.map(lambda p: BundleManager( self._controller, self._get_buffer, self._get_input_coder_impl, @@ -1571,7 +1577,6 @@ def process_bundle(self, inputs, expected_outputs): self._registered).process_bundle(*p), param_list): split_result_list += split_result - if merged_result is None: merged_result = result else: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 4c12d95fb1ccd..72819cbba2e29 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -1179,11 +1179,13 @@ def create_pipeline(self): class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest): def create_pipeline(self): - return beam.Pipeline( + from apache_beam.options.pipeline_options import DirectOptions + p = beam.Pipeline( runner=fn_api_runner.FnApiRunner( - num_workers=2, default_environment=beam_runner_api_pb2.Environment( urn=python_urns.EMBEDDED_PYTHON_GRPC))) + p._options.view_as(DirectOptions).direct_num_workers = 2 + return p class 
FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest): @@ -1199,15 +1201,20 @@ def test_register_finalizations(self): class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest): def create_pipeline(self): - return beam.Pipeline( - runner=fn_api_runner.FnApiRunner(num_workers=2)) + from apache_beam.options.pipeline_options import DirectOptions + p = beam.Pipeline(runner=fn_api_runner.FnApiRunner()) + p._options.view_as(DirectOptions).direct_num_workers = 2 + return p class FnApiRunnerTestWithMultiWorkersAndBundleRepeat(FnApiRunnerTest): def create_pipeline(self): - return beam.Pipeline( - runner=fn_api_runner.FnApiRunner(num_workers=2, bundle_repeat=2)) + from apache_beam.options.pipeline_options import DirectOptions + p = beam.Pipeline( + runner=fn_api_runner.FnApiRunner(bundle_repeat=2)) + p._options.view_as(DirectOptions).direct_num_workers = 2 + return p def test_register_finalizations(self): raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.") @@ -1521,6 +1528,24 @@ def restriction_size(self, element, restriction): return restriction[1] - restriction[0] +class FnApiRunnerTestWithMultiWorkers(FnApiRunnerSplitTest): + + def create_pipeline(self): + from apache_beam.options.pipeline_options import DirectOptions + p = beam.Pipeline( + runner=fn_api_runner.FnApiRunner( + default_environment=beam_runner_api_pb2.Environment( + urn=python_urns.EMBEDDED_PYTHON_GRPC))) + p._options.view_as(DirectOptions).direct_num_workers = 2 + return p + + def test_checkpoint(self): + raise unittest.SkipTest("This test is for a single worker only.") + + def test_split_half(self): + raise unittest.SkipTest("This test is for a single worker only.") + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()