Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
from apache_beam.runners.portability.fn_api_runner.watermark_manager import WatermarkManager
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import core
from apache_beam.transforms import trigger
Expand All @@ -72,6 +73,7 @@
from apache_beam.runners.portability.fn_api_runner.fn_runner import DataOutput
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId
from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
Expand Down Expand Up @@ -563,20 +565,19 @@ class FnApiRunnerExecutionContext(object):
``beam.PCollection``.
"""
def __init__(self,
stages, # type: List[translations.Stage]
worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
safe_coders, # type: Dict[str, str]
data_channel_coders, # type: Dict[str, str]
):
# type: (...) -> None

stages, # type: List[translations.Stage]
worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
safe_coders: translations.SafeCoderMapping,
data_channel_coders: Dict[str, str],
) -> None:
"""
:param worker_handler_manager: This class manages the set of worker
handlers, and the communication with state / control APIs.
:param pipeline_components: (beam_runner_api_pb2.Components): TODO
:param safe_coders:
:param data_channel_coders:
:param pipeline_components: (beam_runner_api_pb2.Components)
:param safe_coders: A map from Coder ID to Safe Coder ID.
:param data_channel_coders: A map from PCollection ID to the ID of the Coder
for that PCollection.
"""
self.stages = stages
self.side_input_descriptors_by_stage = (
Expand All @@ -588,6 +589,12 @@ def __init__(self,
self.safe_coders = safe_coders
self.data_channel_coders = data_channel_coders

self.input_transform_to_buffer_id = {
t.unique_name: t.spec.payload
for s in stages for t in s.transforms
if t.spec.urn == bundle_processor.DATA_INPUT_URN
}
self.watermark_manager = WatermarkManager(stages)
self.pipeline_context = pipeline_context.PipelineContext(
self.pipeline_components,
iterable_state_write=self._iterable_state_write)
Expand Down Expand Up @@ -816,7 +823,7 @@ def _build_process_bundle_descriptor(self):
timer_api_service_descriptor=self.data_api_service_descriptor())

def extract_bundle_inputs_and_outputs(self):
# type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], bytes]]
# type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]

"""Returns maps of transform names to PCollection identifiers.

Expand Down Expand Up @@ -956,8 +963,8 @@ def get_buffer(self, buffer_id, transform_id):
raise NotImplementedError(buffer_id)
return self.execution_context.pcoll_buffers[buffer_id]

def input_for(self, transform_id, input_id):
# type: (str, str) -> str
def input_for(self, transform_id: str, input_id: str) -> str:
"""Returns the name of the transform producing the given PCollection."""
input_pcoll = self.process_bundle_descriptor.transforms[
transform_id].inputs[input_id]
for read_id, proto in self.process_bundle_descriptor.transforms.items():
Expand Down
Loading