Skip to content

Commit

Permalink
BEAM-3645 add changes from review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Hannah-Jiang committed Jun 25, 2019
1 parent 1b4df97 commit fafd01c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 83 deletions.
159 changes: 82 additions & 77 deletions sdks/python/apache_beam/runners/portability/fn_api_runner.py
Expand Up @@ -139,23 +139,16 @@ def done(self):
self._state = self.DONE_STATE


class _PartitionBuffer(object):
"""This class is created to support partition(n) function."""
def __init__(self, inputs):
self._inputs = inputs

class _ListBuffer(list):
"""Used to support parititioning of a list."""
def partition(self, n):
v = list(self._inputs.values())[0]
if isinstance(v, list):
return [self._inputs]
elif isinstance(v, _GroupingBuffer):
partitions = []
for name, input in self._inputs.items():
for part in input.partition(n):
partitions.append({name : part})
return partitions
else:
raise NotImplementedError(type(self._inputs))
n = min(n, len(self))
# for empty list, return iter([[]])
groups = [[] for _ in range(max(n, 1))]
for idx, input in enumerate(self):
groups[idx % n].append(input)

return iter(groups)


class _GroupingBuffer(object):
Expand Down Expand Up @@ -184,7 +177,11 @@ def append(self, elements_data):
value if is_trivial_windowing
else windowed_key_value.with_value(value))

def _output_ready(self):
def partition(self, n):
""" It is used to partition _GroupingBuffer to N parts. Once it is
partitioned, it would not be re-partitioned with diff N. Re-partition
is not supported now.
"""
if len(self._grouped_output) == 0:
if self._windowing.is_default():
globally_window = GlobalWindows.windowed_value(None).with_value
Expand All @@ -197,32 +194,33 @@ def _output_ready(self):
# May need to revise.
trigger_driver = trigger.create_trigger_driver(self._windowing, True)
windowed_key_values = trigger_driver.process_entire_key

key_coder_impl = self._key_coder.get_impl()
coder_impl = self._post_grouped_coder.get_impl()
for encoded_key, windowed_values in self._table.items():
key_coder_impl = self._key_coder.get_impl()
n = min(n, len(self._table))
output_stream_list = []
for _ in range(n):
output_stream_list.append(create_OutputStream())
for idx, (encoded_key, windowed_values) in enumerate(self._table.items()):
key = key_coder_impl.decode(encoded_key)
output_stream = create_OutputStream()
for wkvs in windowed_key_values(key, windowed_values):
coder_impl.encode_to_stream(wkvs, output_stream, True)

self._grouped_output.append(output_stream.get())
coder_impl.encode_to_stream(wkvs, output_stream_list[idx % n], True)
for output_stream in output_stream_list:
self._grouped_output.append([output_stream.get()])
self._table = None

def __iter__(self):
self._output_ready()
return iter(self._grouped_output)

def partition(self, n):
self._output_ready()

n = min(n, len(self._grouped_output))
partitions = [[] for _ in range(n)]

for idx, out in enumerate(self._grouped_output):
partitions[idx % n].append(out)

return partitions
def __iter__(self):
""" Since partition() returns a list of list, added this __iter__ to return
a list to simplify code when we need to iterate through ALL elements of
_GroupingBuffer.
"""
if len(self._grouped_output) == 0:
self.partition(1)
iter_result = []
for output in self._grouped_output:
for out in output:
iter_result.append(out)
return iter(iter_result)


class _WindowGroupingBuffer(object):
Expand Down Expand Up @@ -268,13 +266,11 @@ def encoded_items(self):


class FnApiRunner(runner.PipelineRunner):
_num_workers = 1

def __init__(
self,
default_environment=None,
bundle_repeat=0,
num_workers=1,
use_state_iterables=False,
provision_info=None):
"""Creates a new Fn API Runner.
Expand All @@ -288,12 +284,12 @@ def __init__(
provision_info: provisioning info to make available to workers, or None
"""
super(FnApiRunner, self).__init__()
FnApiRunner._num_workers = num_workers
self._last_uid = -1
self._default_environment = (
default_environment
or beam_runner_api_pb2.Environment(urn=python_urns.EMBEDDED_PYTHON))
self._bundle_repeat = bundle_repeat
self._num_workers = 1
self._progress_frequency = None
self._profiler_factory = None
self._use_state_iterables = use_state_iterables
Expand Down Expand Up @@ -322,8 +318,8 @@ def run_pipeline(self, pipeline, options):
pipeline.visit(DataflowRunner.group_by_key_input_visitor())
self._bundle_repeat = self._bundle_repeat or options.view_as(
pipeline_options.DirectOptions).direct_runner_bundle_repeat
FnApiRunner._num_workers = max(FnApiRunner._num_workers, options.view_as(
pipeline_options.DirectOptions).direct_num_workers)
self._num_workers = options.view_as(
pipeline_options.DirectOptions).direct_num_workers or self._num_workers
self._profiler_factory = profiler.Profile.factory_from_options(
options.view_as(pipeline_options.ProfilingOptions))

Expand Down Expand Up @@ -413,7 +409,7 @@ def run_stages(self, stage_context, stages):

try:
with self.maybe_profile():
pcoll_buffers = collections.defaultdict(list)
pcoll_buffers = collections.defaultdict(_ListBuffer)
for stage in stages:
stage_results = self._run_stage(
worker_handler_manager.get_worker_handler,
Expand Down Expand Up @@ -455,6 +451,7 @@ def _store_side_inputs_in_state(self,
def _run_bundle_multiple_times_for_testing(self,
controller,
process_bundle_descriptor,
num_workers,
data_input,
data_output,
get_input_coder_callable):
Expand All @@ -463,7 +460,7 @@ def _run_bundle_multiple_times_for_testing(self,
controller.state.checkpoint()
ParallelBundleManager(
controller, lambda pcoll_id: [], get_input_coder_callable,
process_bundle_descriptor, self._progress_frequency, k
process_bundle_descriptor, num_workers, self._progress_frequency, k
).process_bundle(data_input, data_output)
finally:
controller.state.restore()
Expand Down Expand Up @@ -499,7 +496,7 @@ def _collect_written_timers_and_add_to_deferred_inputs(self,
for windowed_key_timer in timers_by_key_and_window.values():
windowed_timer_coder_impl.encode_to_stream(
windowed_key_timer, out, True)
deferred_inputs[transform_id] = [out.get()]
deferred_inputs[transform_id] = _ListBuffer([out.get()])
written_timers[:] = []

def _add_residuals_and_channel_splits_to_deferred_inputs(
Expand Down Expand Up @@ -554,7 +551,7 @@ def _extract_stage_data_endpoints(
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
target = transform.unique_name, only_element(transform.outputs)
if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
data_input[target] = [ENCODED_IMPULSE_VALUE]
data_input[target] = _ListBuffer([ENCODED_IMPULSE_VALUE])
else:
data_input[target] = pcoll_buffers[pcoll_id]
coder_id = pipeline_components.pcollections[
Expand Down Expand Up @@ -645,7 +642,7 @@ def get_buffer(buffer_id):
"""Returns the buffer for a given (operation_type, PCollection ID).
For grouping-typed operations, we produce a ``_GroupingBuffer``. For
others, we just use a list.
others, we produce a ``_ListBuffer``.
"""
kind, name = split_buffer_id(buffer_id)
if kind in ('materialize', 'timers'):
Expand Down Expand Up @@ -684,13 +681,14 @@ def get_input_coder_impl(transform_id):

self._run_bundle_multiple_times_for_testing(controller,
process_bundle_descriptor,
self._num_workers,
data_input,
data_output,
get_input_coder_impl)

bundle_manager = ParallelBundleManager(
controller, get_buffer, get_input_coder_impl, process_bundle_descriptor,
self._progress_frequency)
self._num_workers, self._progress_frequency)

result, splits = bundle_manager.process_bundle(data_input, data_output)

Expand All @@ -708,7 +706,7 @@ def input_for(ptransform_id, input_id):
last_sent = data_input

while True:
deferred_inputs = collections.defaultdict(list)
deferred_inputs = collections.defaultdict(_ListBuffer)

self._collect_written_timers_and_add_to_deferred_inputs(
context, pipeline_components, stage, get_buffer, deferred_inputs)
Expand All @@ -729,15 +727,14 @@ def input_for(ptransform_id, input_id):
# The worker will be waiting on these inputs as well.
for other_input in data_input:
if other_input not in deferred_inputs:
deferred_inputs[other_input] = []
deferred_inputs[other_input] = _ListBuffer([])
# TODO(robertwb): merge results
last_result, splits = ParallelBundleManager(
controller,
get_buffer,
get_input_coder_impl,
process_bundle_descriptor,
self._progress_frequency,
True).process_bundle(deferred_inputs, data_output)
bundle_manager._skip_registration = True
# We cannot split deferred_input until we include residual_roots to
# merged results. Without residual_roots, pipeline stops earlier and we
# may miss some data.
last_result, splits = bundle_manager.process_bundle(
deferred_inputs, data_output, num_workers=1)
last_sent = deferred_inputs
result = beam_fn_api_pb2.InstructionResponse(
process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
Expand Down Expand Up @@ -783,7 +780,8 @@ def _extract_endpoints(stage,
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
data_input[transform.unique_name] = [ENCODED_IMPULSE_VALUE]
data_input[transform.unique_name] = _ListBuffer(
[ENCODED_IMPULSE_VALUE])
else:
data_input[transform.unique_name] = pcoll_buffers[pcoll_id]
coder_id = pipeline_components.pcollections[
Expand Down Expand Up @@ -1027,7 +1025,6 @@ def push(self, request):
if not request.instruction_id:
self._uid_counter += 1
request.instruction_id = 'control_%s' % self._uid_counter

response = self.worker.do_instruction(request)
return ControlFuture(request.instruction_id, response)

Expand Down Expand Up @@ -1370,7 +1367,7 @@ class BundleManager(object):

def __init__(
self, controller, get_buffer, get_input_coder_impl, bundle_descriptor,
progress_frequency=None, skip_registration=False):
num_workers, progress_frequency=None, skip_registration=False):
"""Set up a bundle manager.
Args:
Expand All @@ -1384,6 +1381,7 @@ def __init__(
self._controller = controller
self._get_buffer = get_buffer
self._get_input_coder_impl = get_input_coder_impl
self._num_workers = num_workers
self._bundle_descriptor = bundle_descriptor
self._registered = skip_registration
self._progress_frequency = progress_frequency
Expand Down Expand Up @@ -1485,12 +1483,9 @@ def _generate_splits_for_testing(self,
break
return split_results

def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None):
def process_bundle(self, inputs, expected_outputs):
# Unique id for the instruction processing this bundle.
if parallel_uid_counter:
BundleManager._uid_counter = parallel_uid_counter
else:
BundleManager._uid_counter += 1
BundleManager._uid_counter += 1
process_bundle_id = 'bundle_%s' % BundleManager._uid_counter

# Register the bundle descriptor, if needed - noop if already registered.
Expand All @@ -1500,7 +1495,8 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None):
if not split_manager:
# If there is no split_manager, write all input data to the channel.
for transform_id, elements in inputs.items():
self._send_input_to_worker(process_bundle_id, transform_id, elements)
self._send_input_to_worker(
process_bundle_id, transform_id, elements)

# Check that the bundle was successfully registered.
if registration_future and registration_future.get().error:
Expand All @@ -1518,8 +1514,8 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None):
self._controller, process_bundle_id, self._progress_frequency):

if split_manager:
split_results = self._generate_splits_for_testing(
split_manager, inputs, process_bundle_id)
split_results = self._generate_splits_for_testing(split_manager, inputs,
process_bundle_id)

# Gather all output data.
for output in self._controller.data_plane_handler.input_elements(
Expand Down Expand Up @@ -1550,28 +1546,37 @@ def process_bundle(self, inputs, expected_outputs, parallel_uid_counter=None):


class ParallelBundleManager(BundleManager):
_uid_counter = 0

def process_bundle(self, inputs, expected_outputs):
num_workers = FnApiRunner._num_workers
def _check_inputs_split(self, expected_outputs):
# We skip splitting inputs when timer is set, because operations are not
# triggered until we sent inputs for timers.
for _, pcoll_id in expected_outputs.items():
kind = split_buffer_id(pcoll_id)[0]
if kind in ['timers']:
return False

return True

def process_bundle(self, inputs, expected_outputs, num_workers=None):
num_workers = num_workers or self._num_workers
param_list = []

for part in _PartitionBuffer(inputs).partition(num_workers):
ParallelBundleManager._uid_counter += 1
param_list.append((part, expected_outputs,
ParallelBundleManager._uid_counter))
if self._check_inputs_split(expected_outputs):
for name, input in inputs.items():
for part in input.partition(num_workers):
param_list.append(({name : part}, expected_outputs))
else:
param_list.append((inputs, expected_outputs))

merged_result = None
split_result_list = []

with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
for result, split_result in executor.map(lambda p: BundleManager(
self._controller, self._get_buffer, self._get_input_coder_impl,
self._bundle_descriptor, self._progress_frequency,
self._registered).process_bundle(*p), param_list):

split_result_list += split_result

if merged_result is None:
merged_result = result
else:
Expand Down

0 comments on commit fafd01c

Please sign in to comment.