[BEAM-9639] Separate Stage and Bundle execution. Improve typing annotations.
pabloem committed Apr 7, 2020
1 parent 26090a6 commit 32e8965
Showing 3 changed files with 108 additions and 79 deletions.
sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -22,15 +22,26 @@
import collections
import itertools
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam.coders import BytesCoder
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.coders.coders import GlobalWindowCoder
from apache_beam.coders.coders import WindowedValueCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner.translations import only_element
@@ -44,15 +55,18 @@

if TYPE_CHECKING:
from apache_beam.coders.coder_impl import CoderImpl
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
GlobalWindows.windowed_value(b''))
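The impulse payload above is a single empty byte-string in the global window, encoded once at import time. A sketch of reading it back with the same coder, using only names imported above (decode_nested is the CoderImpl counterpart of encode_nested):

    impulse_coder_impl = WindowedValueCoder(
        BytesCoder(), GlobalWindowCoder()).get_impl()
    decoded = impulse_coder_impl.decode_nested(ENCODED_IMPULSE_VALUE)
    assert decoded.value == b''  # one empty element in the global window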

DataOutput = Dict[str, bytes]

DataSideInput = Dict[translations.SideInputId,
Tuple[bytes, beam_runner_api_pb2.FunctionSpec]]
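A hypothetical instance of the two aliases above, with invented transform ids, tags, and buffer ids (real values come from the pipeline proto):

    data_output = {'write_transform_id': b'buffer_id_1'}  # type: DataOutput
    data_side_input = {
        ('pardo_transform_id', 'side0'): (
            b'buffer_id_2',
            beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.side_inputs.ITERABLE.urn)),
    }  # type: DataSideInput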


class Buffer(Protocol):
def __iter__(self):
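The rest of the protocol body is collapsed in this view. As a structural (duck-typed) interface, any class with the right methods conforms without inheriting from Buffer; a minimal sketch, assuming the protocol asks for iteration and append, as ListBuffer's usage elsewhere in this diff suggests:

    class InMemoryBuffer(object):
      """Hypothetical buffer that satisfies Buffer structurally."""
      def __init__(self):
        self._items = []  # type: List[bytes]

      def __iter__(self):
        # type: () -> Iterator[bytes]
        return iter(self._items)

      def append(self, item):
        # type: (bytes) -> None
        self._items.append(item)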
@@ -209,7 +223,7 @@ class WindowGroupingBuffer(object):
def __init__(
self,
access_pattern,
coder # type: coders.WindowedValueCoder
coder # type: WindowedValueCoder
):
# type: (...) -> None
# Here's where we would use a different type of partitioning
@@ -256,7 +270,7 @@ def encoded_items(self):

class FnApiRunnerExecutionContext(object):
"""
:var pcoll_buffers: (collections.defaultdict of str: list): Mapping of
:var pcoll_buffers: (dict): Mapping of
PCollection IDs to list that functions as buffer for the
``beam.PCollection``.
"""
150 changes: 84 additions & 66 deletions sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -64,15 +64,13 @@
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import environments
from apache_beam.utils import profiler
from apache_beam.utils import proto_utils
from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor

if TYPE_CHECKING:
from apache_beam.pipeline import Pipeline
from apache_beam.coders.coder_impl import CoderImpl
from apache_beam.portability.api import metrics_pb2

_LOGGER = logging.getLogger(__name__)
@@ -383,27 +381,20 @@ def _store_side_inputs_in_state(self,
def _run_bundle_multiple_times_for_testing(
self,
runner_execution_context, # type: execution.FnApiRunnerExecutionContext
bundle_context_manager, # type: execution.BundleContextManager
bundle_manager, # type: BundleManager
data_input,
data_output, # type: DataOutput
cache_token_generator
data_output, # type: execution.DataOutput
):
# type: (...) -> None

"""
If bundle_repeat > 0, replay every bundle for profiling and debugging.
"""
# all workers share state, so use any worker_handler.
for k in range(self._bundle_repeat):
for _ in range(self._bundle_repeat):
try:
runner_execution_context.state_servicer.checkpoint()
testing_bundle_manager = ParallelBundleManager(
bundle_context_manager,
self._progress_frequency,
k,
cache_token_generator=cache_token_generator)
testing_bundle_manager.process_bundle(
data_input, data_output, dry_run=True)
bundle_manager.process_bundle(data_input, data_output, dry_run=True)
finally:
runner_execution_context.state_servicer.restore()
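The replay is side-effect free because state is checkpointed before each repeat and restored afterwards; the same pattern in miniature (checkpoint and restore stand in for the state servicer calls above):

    def replay_for_profiling(times, run_once, checkpoint, restore):
      # Re-run identical work without letting state changes leak between runs.
      for _ in range(times):
        try:
          checkpoint()
          run_once()
        finally:
          restore()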

@@ -447,6 +438,17 @@ def _collect_written_timers_and_add_to_deferred_inputs(
deferred_inputs[transform_id].append(out.get())
written_timers.clear()

def _add_sdk_delayed_applications_to_deferred_inputs(
self, bundle_context_manager, bundle_result, deferred_inputs):
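"""Adds the SDK-initiated residual roots of a bundle to deferred_inputs,
keyed by the name of the transform input that will replay them."""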
for delayed_application in bundle_result.process_bundle.residual_roots:
name = bundle_context_manager.input_for(
delayed_application.application.transform_id,
delayed_application.application.input_id)
if name not in deferred_inputs:
deferred_inputs[name] = ListBuffer(
coder_impl=bundle_context_manager.get_input_coder_impl(name))
deferred_inputs[name].append(delayed_application.application.element)

def _add_residuals_and_channel_splits_to_deferred_inputs(
self,
splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
@@ -510,7 +512,8 @@ def _run_stage(self,
Args:
runner_execution_context (execution.FnApiRunnerExecutionContext): An
object containing execution information for the pipeline.
stage (translations.Stage): A description of the stage to execute.
bundle_context_manager (execution.BundleContextManager): A description of
the stage to execute, and its context.
"""
data_input, data_output = bundle_context_manager.extract_bundle_inputs()

@@ -519,72 +522,88 @@
worker_handler_manager.register_process_bundle_descriptor(
bundle_context_manager.process_bundle_descriptor)

# Change cache token across bundle repeats
# We create the bundle manager here, as it can be reused for bundles of the
# same stage, but it may have to be created by-bundle later on.
cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)

self._run_bundle_multiple_times_for_testing(
runner_execution_context,
bundle_context_manager,
data_input,
data_output,
cache_token_generator=cache_token_generator)

bundle_manager = ParallelBundleManager(
bundle_context_manager,
self._progress_frequency,
skip_registration=False,
cache_token_generator=cache_token_generator)

result, splits = bundle_manager.process_bundle(data_input, data_output)
final_result = None

last_result = result
last_sent = data_input
def merge_results(last_result):
""" Merge the latest result with other accumulated results. """
return (
last_result
if final_result is None else beam_fn_api_pb2.InstructionResponse(
process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
monitoring_infos=monitoring_infos.consolidate(
itertools.chain(
final_result.process_bundle.monitoring_infos,
last_result.process_bundle.monitoring_infos))),
error=final_result.error or last_result.error))

while True:
deferred_inputs = {} # type: Dict[str, PartitionableBuffer]

self._collect_written_timers_and_add_to_deferred_inputs(
runner_execution_context, bundle_context_manager, deferred_inputs)
# Queue any SDK-initiated delayed bundle applications.
for delayed_application in last_result.process_bundle.residual_roots:
name = bundle_context_manager.input_for(
delayed_application.application.transform_id,
delayed_application.application.input_id)
if name not in deferred_inputs:
deferred_inputs[name] = ListBuffer(
coder_impl=bundle_context_manager.get_input_coder_impl(name))
deferred_inputs[name].append(delayed_application.application.element)
# Queue any runner-initiated delayed bundle applications.
self._add_residuals_and_channel_splits_to_deferred_inputs(
splits, bundle_context_manager, last_sent, deferred_inputs)

if deferred_inputs:
# The worker will be waiting on these inputs as well.
for other_input in data_input:
if other_input not in deferred_inputs:
deferred_inputs[other_input] = ListBuffer(
coder_impl=bundle_context_manager.get_input_coder_impl(
other_input))
# TODO(robertwb): merge results
# TODO(BEAM-8486): this should be changed to _registered
bundle_manager._skip_registration = True # type: ignore[attr-defined]
last_result, splits = bundle_manager.process_bundle(
deferred_inputs, data_output)
last_sent = deferred_inputs
result = beam_fn_api_pb2.InstructionResponse(
process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
monitoring_infos=monitoring_infos.consolidate(
itertools.chain(
result.process_bundle.monitoring_infos,
last_result.process_bundle.monitoring_infos))),
error=result.error or last_result.error)
else:
last_result, deferred_inputs = self._run_bundle(runner_execution_context,
bundle_context_manager,
data_input,
data_output,
bundle_manager)

final_result = merge_results(last_result)
if not deferred_inputs:
break
else:
data_input = deferred_inputs
bundle_manager._registered = True

# Store the required downstream side inputs into state so it is accessible
# for the worker when it runs bundles that consume this stage's output.
bundle_context_manager.commit_output_views_to_state()
return final_result
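A condensed sketch of the stage-level loop introduced above, which is the heart of the Stage/Bundle separation (run_bundle and merge stand in for _run_bundle and merge_results; bodies elided):

    final_result = None
    while True:
      last_result, deferred_inputs = run_bundle(data_input)  # one bundle
      final_result = merge(final_result, last_result)  # fold monitoring infos
      if not deferred_inputs:  # stage fully drained
        break
      data_input = deferred_inputs  # feed residuals into the next bundle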

def _run_bundle(
self,
runner_execution_context,
bundle_context_manager,
data_input,
data_output,
bundle_manager):
"""Execute a bundle, and return a result object, and deferred inputs."""
self._run_bundle_multiple_times_for_testing(
runner_execution_context, bundle_manager, data_input, data_output)

return result
result, splits = bundle_manager.process_bundle(data_input, data_output)

# Now we collect all the deferred inputs remaining from bundle execution.
# Deferred inputs can be:
# - timers
# - SDK-initiated deferred applications of root elements
# - Runner-initiated deferred applications of root elements
deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer]

self._collect_written_timers_and_add_to_deferred_inputs(
runner_execution_context, bundle_context_manager, deferred_inputs)

self._add_sdk_delayed_applications_to_deferred_inputs(
bundle_context_manager, result, deferred_inputs)

self._add_residuals_and_channel_splits_to_deferred_inputs(
splits, bundle_context_manager, data_input, deferred_inputs)

# After collecting deferred inputs, we 'pad' the structure with empty
# buffers for other expected inputs.
if deferred_inputs:
# The worker will be waiting on these inputs as well.
for other_input in data_input:
if other_input not in deferred_inputs:
deferred_inputs[other_input] = ListBuffer(
coder_impl=bundle_context_manager.get_input_coder_impl(
other_input))

return result, deferred_inputs
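The 'padding' step above matters because a re-executed bundle waits on every declared input channel; a miniature of the idea, with plain lists standing in for the buffer types (pad_missing_inputs is hypothetical):

    def pad_missing_inputs(deferred_inputs, expected_inputs):
      # Give every expected channel a buffer, even if it carries no data.
      for name in expected_inputs:
        deferred_inputs.setdefault(name, [])
      return deferred_inputs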

@staticmethod
def get_cache_token_generator(static=True):
@@ -928,7 +947,6 @@ def execute(part_map):
merged_result.process_bundle.monitoring_infos))),
error=result.error or merged_result.error)
assert merged_result is not None

return merged_result, split_result_list


sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -26,6 +26,7 @@
import functools
import logging
from builtins import object
from typing import Any
from typing import Container
from typing import DefaultDict
from typing import Dict
@@ -78,9 +79,6 @@
# SideInputId is identified by a consumer ParDo + tag.
SideInputId = Tuple[str, str]

DataSideInput = Dict[SideInputId,
Tuple[bytes, beam_runner_api_pb2.FunctionSpec]]


class Stage(object):
"""A set of Transforms that can be sent to the worker for processing."""
@@ -199,7 +197,6 @@ def side_inputs(self):
yield (
transform.inputs[si_tag], (transform.unique_name, si_tag),
payload.side_inputs[si_tag].access_pattern)
return []

def has_as_main_input(self, pcoll):
for transform in self.transforms:
@@ -648,10 +645,10 @@ def get_all_side_inputs():
all_side_inputs = get_all_side_inputs()

downstream_side_inputs_by_stage = {
} # type: Dict[Stage, DefaultDict[str, Set[SideInputId]]]
} # type: Dict[Stage, DefaultDict[str, Dict[SideInputId, Any]]]

def compute_downstream_side_inputs(stage):
# type: (Stage) -> Dict[str, Dict[Tuple[Stage, str], Any]]
# type: (Stage) -> Dict[str, Dict[SideInputId, Any]]
if stage not in downstream_side_inputs_by_stage:
downstream_side_inputs = collections.defaultdict(dict)
for transform in stage.transforms:
@@ -666,7 +663,7 @@ def compute_downstream_side_inputs(stage):
downstream_side_inputs.update(
compute_downstream_side_inputs(consumer))
downstream_side_inputs_by_stage[stage] = downstream_side_inputs
return downstream_side_inputs_by_stage[stage]
return dict(downstream_side_inputs_by_stage[stage])

for stage in stages:
stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
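compute_downstream_side_inputs is a memoized post-order walk over the stage graph; the same shape in miniature (consumers_of and own_side_inputs are hypothetical stand-ins for the per-transform lookups above):

    def downstream_side_inputs(stage, memo, consumers_of, own_side_inputs):
      # Side inputs consumed at `stage` or downstream of it, cached per stage.
      if stage not in memo:
        acc = dict(own_side_inputs(stage))
        for consumer in consumers_of(stage):
          acc.update(
              downstream_side_inputs(
                  consumer, memo, consumers_of, own_side_inputs))
        memo[stage] = acc
      return memo[stage]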
@@ -1059,7 +1056,7 @@ def expand_gbk(stages, pipeline_context):
urn=bundle_processor.DATA_OUTPUT_URN,
payload=grouping_buffer))
],
downstream_side_inputs=frozenset(),
downstream_side_inputs={},
must_follow=stage.must_follow)
yield gbk_write

@@ -1112,7 +1109,7 @@ def fix_flatten_coders(stages, pipeline_context):
urn=bundle_processor.IDENTITY_DOFN_URN),
environment_id=transform.environment_id)
],
downstream_side_inputs=frozenset(),
downstream_side_inputs={},
must_follow=stage.must_follow)
pcollections[transcoded_pcollection].CopyFrom(pcollections[pcoll_in])
pcollections[transcoded_pcollection].unique_name = (
@@ -1150,7 +1147,7 @@ def sink_flattens(stages, pipeline_context):
urn=bundle_processor.DATA_OUTPUT_URN,
payload=buffer_id))
],
downstream_side_inputs=frozenset(),
downstream_side_inputs={},
must_follow=stage.must_follow)
flatten_writes.append(flatten_write)
yield flatten_write
