[BEAM-2914] Add portable merging window support to Python. #12995

Merged (14 commits, Feb 11, 2021)
@@ -398,6 +398,9 @@ def test_callbacks_with_exception(self):
def test_register_finalizations(self):
raise unittest.SkipTest("BEAM-11021")

def test_custom_merging_window(self):
Contributor:
we may need to add this to spark_runner_test.py as well.

Contributor (Author):
Done.

raise unittest.SkipTest("BEAM-11004")

# Inherits all other tests.


241 changes: 229 additions & 12 deletions sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -24,6 +24,8 @@
import collections
import copy
import itertools
import uuid
import weakref
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
@@ -55,6 +57,7 @@
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import core
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.window import GlobalWindow
@@ -69,7 +72,6 @@
from apache_beam.runners.portability.fn_api_runner.fn_runner import DataOutput
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
from apache_beam.transforms import core
from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
@@ -338,6 +340,222 @@ def from_runner_api_parameter(window_coder_id, context):
context.coders[window_coder_id.decode('utf-8')])


class GenericMergingWindowFn(window.WindowFn):
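"""A WindowFn that delegates merging to the SDK worker over the Fn API.

Used by the runner for merging WindowFns it cannot evaluate itself: the
user's window_fn runs inside a dedicated MergeWindows bundle on the SDK
worker, and the results are fed back through merge().
"""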

URN = 'internal-generic-merging'

TO_SDK_TRANSFORM = 'read'
FROM_SDK_TRANSFORM = 'write'

_HANDLES = {} # type: Dict[str, GenericMergingWindowFn]

def __init__(self, execution_context, windowing_strategy_proto):
# type: (FnApiRunnerExecutionContext, beam_runner_api_pb2.WindowingStrategy) -> None
self._worker_handler = None # type: Optional[worker_handlers.WorkerHandler]
self._handle_id = handle_id = uuid.uuid4().hex
self._HANDLES[handle_id] = self
# ExecutionContexts are expensive; we don't want to keep them in the
# static dictionary forever. Instead we hold a weakref and pop self
# out of the dict once this context goes away.
self._execution_context_ref_obj = weakref.ref(
execution_context, lambda _: self._HANDLES.pop(handle_id, None))
self._windowing_strategy_proto = windowing_strategy_proto
self._counter = 0
# Lazily created in make_process_bundle_descriptor()
self._process_bundle_descriptor = None
self._bundle_processor_id = None # type: Optional[str]
self.windowed_input_coder_impl = None # type: Optional[CoderImpl]
self.windowed_output_coder_impl = None # type: Optional[CoderImpl]

def _execution_context_ref(self):
# type: () -> FnApiRunnerExecutionContext
result = self._execution_context_ref_obj()
assert result is not None
return result

def payload(self):
# type: () -> bytes
return self._handle_id.encode('utf-8')

@staticmethod
@window.urns.RunnerApiFn.register_urn(URN, bytes)
def from_runner_api_parameter(handle_id, unused_context):
# type: (bytes, Any) -> GenericMergingWindowFn
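# The payload is only a handle id; deserializing on the runner side looks
# up the live instance registered in __init__.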
return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')]

def assign(self, assign_context):
# type: (window.WindowFn.AssignContext) -> Iterable[window.BoundedWindow]
raise NotImplementedError()

def merge(self, merge_context):
# type: (window.WindowFn.MergeContext) -> None
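# Delegate merging to the SDK worker: stream the candidate windows over
# the data plane, run a MergeWindows bundle, and apply the returned
# results via merge_context below.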
worker_handler = self.worker_handle()

assert self.windowed_input_coder_impl is not None
assert self.windowed_output_coder_impl is not None
process_bundle_id = self.uid('process')
to_worker = worker_handler.data_conn.output_stream(
process_bundle_id, self.TO_SDK_TRANSFORM)
to_worker.write(
self.windowed_input_coder_impl.encode_nested(
window.GlobalWindows.windowed_value((b'', merge_context.windows))))
to_worker.close()

process_bundle_req = beam_fn_api_pb2.InstructionRequest(
instruction_id=process_bundle_id,
process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
process_bundle_descriptor_id=self._bundle_processor_id))
result_future = worker_handler.control_conn.push(process_bundle_req)
for output in worker_handler.data_conn.input_elements(
process_bundle_id, [self.FROM_SDK_TRANSFORM],
abort_callback=lambda: bool(result_future.is_done() and result_future.
get().error)):
if isinstance(output, beam_fn_api_pb2.Elements.Data):
windowed_result = self.windowed_output_coder_impl.decode_nested(
output.data)
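# The decoded value mirrors the output coder built in
# make_process_bundle_descriptor: (key, (windows kept unmerged,
# [(merged window, original windows), ...])); only the merge pairs at
# value[1][1] are consumed here.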
for merge_result, originals in windowed_result.value[1][1]:
merge_context.merge(originals, merge_result)
else:
raise RuntimeError("Unexpected data: %s" % output)

result = result_future.get()
if result.error:
raise RuntimeError(result.error)
# The result was "returned" via the merge callbacks on merge_context above.

def get_window_coder(self):
# type: () -> coders.Coder
return self._execution_context_ref().pipeline_context.coders[
self._windowing_strategy_proto.window_coder_id]

def worker_handle(self):
# type: () -> worker_handlers.WorkerHandler
if self._worker_handler is None:
worker_handler_manager = self._execution_context_ref(
).worker_handler_manager
self._worker_handler = worker_handler_manager.get_worker_handlers(
Contributor:
self._worker_handler -> self._worker_handle?

Contributor (Author):
Good catch. Fixed.

self._windowing_strategy_proto.environment_id, 1)[0]
process_bundle_descriptor = self.make_process_bundle_descriptor(
self._worker_handler.data_api_service_descriptor(),
self._worker_handler.state_api_service_descriptor())
worker_handler_manager.register_process_bundle_descriptor(
process_bundle_descriptor)
return self._worker_handler

def make_process_bundle_descriptor(
self, data_api_service_descriptor, state_api_service_descriptor):
# type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor

"""Creates a ProcessBundleDescriptor for invoking the WindowFn's
merge operation.
"""
def make_channel_payload(coder_id):
# type: (str) -> bytes
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
if data_api_service_descriptor:
data_spec.api_service_descriptor.url = (data_api_service_descriptor.url)
return data_spec.SerializeToString()

pipeline_context = self._execution_context_ref().pipeline_context
global_windowing_strategy_id = self.uid('global_windowing_strategy')
global_windowing_strategy_proto = core.Windowing(
window.GlobalWindows()).to_runner_api(pipeline_context)
coders = dict(pipeline_context.coders.get_id_to_proto_map())

def make_coder(urn, *components):
# type: (str, *str) -> str
coder_proto = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
component_coder_ids=components)
coder_id = self.uid('coder')
coders[coder_id] = coder_proto
pipeline_context.coders.put_proto(coder_id, coder_proto)
return coder_id

bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
window_coder_id = self._windowing_strategy_proto.window_coder_id
global_window_coder_id = make_coder(common_urns.coders.GLOBAL_WINDOW.urn)
iter_window_coder_id = make_coder(
common_urns.coders.ITERABLE.urn, window_coder_id)
input_coder_id = make_coder(
common_urns.coders.KV.urn, bytes_coder_id, iter_window_coder_id)
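# The output carries, per key, the windows left unmerged plus
# (merged window, original windows) pairs, hence the nested KV coder below.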
output_coder_id = make_coder(
common_urns.coders.KV.urn,
bytes_coder_id,
make_coder(
common_urns.coders.KV.urn,
iter_window_coder_id,
make_coder(
common_urns.coders.ITERABLE.urn,
make_coder(
common_urns.coders.KV.urn,
window_coder_id,
iter_window_coder_id))))
windowed_input_coder_id = make_coder(
common_urns.coders.WINDOWED_VALUE.urn,
input_coder_id,
global_window_coder_id)
windowed_output_coder_id = make_coder(
common_urns.coders.WINDOWED_VALUE.urn,
output_coder_id,
global_window_coder_id)

self.windowed_input_coder_impl = pipeline_context.coders[
windowed_input_coder_id].get_impl()
self.windowed_output_coder_impl = pipeline_context.coders[
windowed_output_coder_id].get_impl()

self._bundle_processor_id = self.uid('merge_windows')
return beam_fn_api_pb2.ProcessBundleDescriptor(
id=self._bundle_processor_id,
transforms={
self.TO_SDK_TRANSFORM: beam_runner_api_pb2.PTransform(
unique_name='MergeWindows/Read',
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN,
payload=make_channel_payload(windowed_input_coder_id)),
outputs={'input': 'input'}),
'Merge': beam_runner_api_pb2.PTransform(
unique_name='MergeWindows/Merge',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.MERGE_WINDOWS.urn,
payload=self._windowing_strategy_proto.window_fn.
SerializeToString()),
inputs={'input': 'input'},
outputs={'output': 'output'}),
self.FROM_SDK_TRANSFORM: beam_runner_api_pb2.PTransform(
unique_name='MergeWindows/Write',
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_OUTPUT_URN,
payload=make_channel_payload(windowed_output_coder_id)),
inputs={'output': 'output'}),
},
pcollections={
'input': beam_runner_api_pb2.PCollection(
unique_name='input',
windowing_strategy_id=global_windowing_strategy_id,
coder_id=input_coder_id),
'output': beam_runner_api_pb2.PCollection(
unique_name='output',
windowing_strategy_id=global_windowing_strategy_id,
coder_id=output_coder_id),
},
coders=coders,
windowing_strategies={
global_windowing_strategy_id: global_windowing_strategy_proto,
},
environments=dict(
self._execution_context_ref().pipeline_components.environments.
items()),
state_api_service_descriptor=state_api_service_descriptor,
timer_api_service_descriptor=data_api_service_descriptor)

def uid(self, name=''):
# type: (str) -> str
self._counter += 1
return '%s_%s_%s' % (self._handle_id, name, self._counter)


class FnApiRunnerExecutionContext(object):
"""
:var pcoll_buffers: (dict): Mapping of
@@ -443,23 +661,22 @@ def _make_safe_windowing_strategy(self, id):
windowing_strategy_proto = self.pipeline_components.windowing_strategies[id]
if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS:
return id
elif (windowing_strategy_proto.merge_status ==
beam_runner_api_pb2.MergeStatus.NON_MERGING) or True:
else:
safe_id = id + '_safe'
while safe_id in self.pipeline_components.windowing_strategies:
safe_id += '_'
safe_proto = copy.copy(windowing_strategy_proto)
safe_proto.window_fn.urn = GenericNonMergingWindowFn.URN
safe_proto.window_fn.payload = (
windowing_strategy_proto.window_coder_id.encode('utf-8'))
if (windowing_strategy_proto.merge_status ==
beam_runner_api_pb2.MergeStatus.NON_MERGING):
safe_proto.window_fn.urn = GenericNonMergingWindowFn.URN
safe_proto.window_fn.payload = (
windowing_strategy_proto.window_coder_id.encode('utf-8'))
else:
window_fn = GenericMergingWindowFn(self, windowing_strategy_proto)
safe_proto.window_fn.urn = GenericMergingWindowFn.URN
safe_proto.window_fn.payload = window_fn.payload()
self.pipeline_context.windowing_strategies.put_proto(safe_id, safe_proto)
return safe_id
elif windowing_strategy_proto.window_fn.urn == python_urns.PICKLED_WINDOWFN:
return id
else:
raise NotImplementedError(
'[BEAM-10119] Unknown merging WindowFn: %s' %
windowing_strategy_proto)

@property
def state_servicer(self):
@@ -20,6 +20,7 @@
from __future__ import print_function

import collections
import gc
import logging
import os
import random
@@ -46,6 +47,7 @@
from tenacity import stop_after_attempt

import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.io import restriction_trackers
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
@@ -780,6 +782,21 @@ def test_windowing(self):
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))

def test_custom_merging_window(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([1, 2, 100, 101, 102])
| beam.Map(lambda t: window.TimestampedValue(('k', t), t))
| beam.WindowInto(CustomMergingWindowFn())
| beam.GroupByKey()
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(
res, equal_to([('k', [1]), ('k', [101]), ('k', [2, 100, 102])]))
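# The runner registers one GenericMergingWindowFn handle per merging
# strategy; once the execution context is garbage collected, the weakref
# callback should have emptied the registry.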
gc.collect()
from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn
self.assertEqual(GenericMergingWindowFn._HANDLES, {})

@unittest.skip('BEAM-9119: test is flaky')
def test_large_elements(self):
with self.create_pipeline() as p:
@@ -2002,6 +2019,26 @@ def test_gbk_many_values(self):
assert_that(r, equal_to([VALUES_PER_ELEMENT * NUM_OF_ELEMENTS]))


# TODO(robertwb): Why does pickling break when this is inlined?
class CustomMergingWindowFn(window.WindowFn):
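"""Assigns each element to a one-second window and merges every window
with an even start timestamp into a single interval."""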
def assign(self, assign_context):
return [
window.IntervalWindow(
assign_context.timestamp, assign_context.timestamp + 1)
]

def merge(self, merge_context):
evens = [w for w in merge_context.windows if w.start % 2 == 0]
if evens:
merge_context.merge(
evens,
window.IntervalWindow(
min(w.start for w in evens), max(w.end for w in evens)))

def get_window_coder(self):
return coders.IntervalWindowCoder()


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
@@ -181,6 +181,9 @@ def test_flattened_side_input(self):
super(SparkRunnerTest,
self).test_flattened_side_input(with_transcoding=False)

def test_custom_merging_window(self):
raise unittest.SkipTest("BEAM-11004")

# Inherits all other tests from PortableRunnerTest.

