[BEAM-9340] Populate requirements for Python DoFn properties. #10909

Merged (2 commits, Feb 24, 2020)
4 changes: 4 additions & 0 deletions model/job-management/src/main/proto/beam_expansion_api.proto
@@ -57,6 +57,10 @@ message ExpansionResponse {
  // and subtransforms.
  org.apache.beam.model.pipeline.v1.PTransform transform = 2;

  // A set of requirements that must be appended to this pipeline's
  // requirements.
  repeated string requirements = 3;

  // (Optional) A string representation of any error encountered while
  // attempting to expand this transform.
  string error = 10;
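The new repeated requirements field lets an expansion service report any requirements that its expanded transform introduces, so the calling SDK can fold them into its own pipeline. A minimal client-side sketch, assuming response is an ExpansionResponse and pipeline_proto is the beam_runner_api_pb2.Pipeline under construction (both names are illustrative):

# Sketch: merge requirements returned by an expansion service into the
# calling pipeline's proto, skipping any that are already declared.
for requirement in response.requirements:
  if requirement not in pipeline_proto.requirements:
    pipeline_proto.requirements.append(requirement)

The external.py change at the end of this diff performs the same merge through PipelineContext.add_requirement.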
48 changes: 44 additions & 4 deletions model/pipeline/src/main/proto/beam_runner_api.proto
@@ -443,25 +443,40 @@ message ParDoPayload {
  map<string, SideInput> side_inputs = 3;

  // (Optional) A mapping of local state names to state specifications.
  // If this is set, the stateful processing requirement should also
  // be placed in the pipeline requirements.
  map<string, StateSpec> state_specs = 4;

  // (Optional) A mapping of local timer names to timer specifications.
  // If this is set, the stateful processing requirement should also
  // be placed in the pipeline requirements.
  map<string, TimerSpec> timer_specs = 5;

  // (Optional) A mapping of local timer family names to timer specifications.
  // If this is set, the stateful processing requirement should also
  // be placed in the pipeline requirements.
  map<string, TimerFamilySpec> timer_family_specs = 9;

  // Whether the DoFn is splittable.
  bool splittable = 6;

  // (Required if splittable == true) Id of the restriction coder.
  string restriction_coder_id = 7;

  // (Optional) Only set when this ParDo can request bundle finalization.
  // If this is set, the corresponding standard requirement should also
  // be placed in the pipeline requirements.
  bool requests_finalization = 8;

  // Whether this stage requires time sorted input.
  // If this is set, the corresponding standard requirement should also
  // be placed in the pipeline requirements.
  bool requires_time_sorted_input = 10;

  // Whether this stage requires stable input.
  // If this is set, the corresponding standard requirement should also
  // be placed in the pipeline requirements.
  bool requires_stable_input = 11;
}
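For illustration, a DoFn that declares a state spec, and so populates the state_specs map above, might look like the following sketch (BufferingDoFn and its state name are hypothetical, while BagStateSpec, VarIntCoder, and DoFn.StateParam are existing Beam APIs):

import apache_beam as beam
from apache_beam.coders import VarIntCoder
from apache_beam.transforms.userstate import BagStateSpec

class BufferingDoFn(beam.DoFn):
  # Declaring a StateSpec gives the resulting ParDoPayload a non-empty
  # state_specs map, so the stateful-processing requirement must also be
  # declared on the pipeline.
  BUFFER = BagStateSpec('buffer', VarIntCoder())

  def process(self, element, buffer=beam.DoFn.StateParam(BUFFER)):
    key, value = element  # stateful DoFns consume keyed input
    buffer.add(value)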

// Parameters that a UDF might require.
@@ -1318,6 +1333,31 @@ message StandardProtocols {
  }
}

// These URNs are used to indicate requirements of a pipeline that cannot
// simply be expressed as a component (such as a Coder or PTransform) that the
// runner must understand. In many cases, this indicates a particular field
// of a transform must be inspected and respected (which allows new fields
// to be added in a forwards-compatible way).
message StandardRequirements {
  enum Enum {
    // This requirement indicates the state_specs and timer_specs fields of
    // ParDo transform payloads must be inspected.
    REQUIRES_STATEFUL_PROCESSING = 0 [(beam_urn) = "beam:requirement:pardo:stateful:v1"];

    // This requirement indicates the requests_finalization field of ParDo
    // transform payloads must be inspected.
    REQUIRES_BUNDLE_FINALIZATION = 1 [(beam_urn) = "beam:requirement:pardo:finalization:v1"];

    // This requirement indicates the requires_stable_input field of ParDo
    // transform payloads must be inspected.
    REQUIRES_STABLE_INPUT = 2 [(beam_urn) = "beam:requirement:pardo:stable_input:v1"];

    // This requirement indicates the requires_time_sorted_input field of ParDo
    // transform payloads must be inspected.
    REQUIRES_TIME_SORTED_INPUT = 3 [(beam_urn) = "beam:requirement:pardo:time_sorted_input:v1"];
  }
}
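Since a runner cannot safely ignore a URN it does not understand, the intended pattern is to reject pipelines carrying unknown requirements rather than silently dropping them. A hedged sketch of that check (pipeline_proto and SUPPORTED are illustrative names; the FnApiRunner change further down implements the same idea):

# Sketch: refuse to run a pipeline whose requirements are not all understood.
SUPPORTED = {
    'beam:requirement:pardo:stateful:v1',
    'beam:requirement:pardo:finalization:v1',
}
for requirement in pipeline_proto.requirements:
  if requirement not in SUPPORTED:
    raise ValueError('Unsupported requirement: %s' % requirement)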

extend google.protobuf.EnumValueOptions {
  // An extension to be used for specifying the standard URN of various
  // pipeline entities, e.g. transforms, functions, coders etc.
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/pipeline.py
@@ -778,7 +778,8 @@ def visit_transform(self, transform_node):
    root_transform_id = context.transforms.get_id(self._root_transform())
    proto = beam_runner_api_pb2.Pipeline(
        root_transform_ids=[root_transform_id],
        components=context.to_runner_api(),
        requirements=context.requirements())
    proto.components.transforms[root_transform_id].unique_name = (
        root_transform_id)
    if return_context:
@@ -799,7 +800,9 @@ def from_runner_api(proto,  # type: beam_runner_api_pb2.Pipeline
    p = Pipeline(runner=runner, options=options)
    from apache_beam.runners import pipeline_context
    context = pipeline_context.PipelineContext(
        proto.components,
        allow_proto_holders=allow_proto_holders,
        requirements=proto.requirements)
    root_transform_id, = proto.root_transform_ids
    p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
    # TODO(robertwb): These are only needed to continue construction. Omit?
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/pipeline_test.py
@@ -39,6 +39,7 @@
from apache_beam.pipeline import PipelineOptions
from apache_beam.pipeline import PipelineVisitor
from apache_beam.pipeline import PTransformOverride
from apache_beam.portability import common_urns
from apache_beam.pvalue import AsSingleton
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
@@ -825,6 +826,16 @@ def expand(self, p):
    self.assertEqual(
        p.transforms_stack[0].parts[0].parent, p.transforms_stack[0])

  def test_requirements(self):
    p = beam.Pipeline()
    _ = (
        p | beam.Create([])
        | beam.ParDo(lambda x, finalize=beam.DoFn.BundleFinalizerParam: None))
    proto = p.to_runner_api()
    self.assertIn(
        common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
        proto.requirements)


if __name__ == '__main__':
  unittest.main()
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/portability/common_urns.py
@@ -26,6 +26,7 @@
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardEnvironments
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardProtocols
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardPTransforms
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardRequirements
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardSideInputTypes
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfo
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfoSpecs
@@ -57,3 +58,4 @@
monitoring_info_labels = MonitoringInfo.MonitoringInfoLabels

protocols = StandardProtocols.Enum
requirements = StandardRequirements.Enum
13 changes: 13 additions & 0 deletions sdks/python/apache_beam/runners/common.py
@@ -352,6 +352,19 @@ def has_timers(self):
    _, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
    return bool(all_timer_specs)

  def has_bundle_finalization(self):
    for sig in (self.start_bundle_method,
                self.process_method,
                self.finish_bundle_method):
      for d in sig.defaults:
        try:
          if d == DoFn.BundleFinalizerParam:
            return True
        except Exception:  # pylint: disable=broad-except
          # Default value might be incomparable.
          pass
    return False
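has_bundle_finalization returns True when any of start_bundle, process, or finish_bundle declares DoFn.BundleFinalizerParam as a default. A DoFn like the following sketch (FinalizingDoFn is an illustrative name; BundleFinalizerParam and its register method are existing Beam APIs) would be detected, which in turn sets requests_finalization on the ParDoPayload:

import apache_beam as beam

class FinalizingDoFn(beam.DoFn):
  def process(self, element, finalizer=beam.DoFn.BundleFinalizerParam):
    # The BundleFinalizerParam default is exactly what
    # has_bundle_finalization() scans for; the registered callback runs
    # once the bundle's output has been durably committed.
    finalizer.register(lambda: print('bundle committed'))
    yield element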


class DoFnInvoker(object):
  """An abstraction that can be used to execute DoFn methods.
11 changes: 10 additions & 1 deletion sdks/python/apache_beam/runners/pipeline_context.py
@@ -28,6 +28,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union
@@ -143,7 +144,8 @@ def __init__(self,
               iterable_state_read=None,  # type: Optional[IterableStateReader]
               iterable_state_write=None,  # type: Optional[IterableStateWriter]
               namespace='ref',
               allow_proto_holders=False,
               requirements=(),  # type: Iterable[str]
               ):
    if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
      proto = beam_runner_api_pb2.Components(
@@ -187,6 +189,13 @@ def __init__(self,
    self.iterable_state_read = iterable_state_read
    self.iterable_state_write = iterable_state_write
    self.allow_proto_holders = allow_proto_holders
    self._requirements = set(requirements)

  def add_requirement(self, requirement):
    self._requirements.add(requirement)

  def requirements(self):
    return frozenset(self._requirements)
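Transforms call add_requirement() as they are serialized, and Pipeline.to_runner_api() (above) copies the accumulated set into the proto. A minimal sketch of that flow, assuming only the imports shown:

from apache_beam.portability import common_urns
from apache_beam.runners.pipeline_context import PipelineContext

context = PipelineContext()
context.add_requirement(
    common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn)
# requirements() exposes an immutable snapshot for the Pipeline proto.
assert (common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn
        in context.requirements())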

  # If fake coders are requested, return a pickled version of the element type
  # rather than an actual coder. The element type is required for some runners,
@@ -92,7 +92,8 @@ def with_pipeline(component, pcoll_id=None):
      del pipeline_proto.components.transforms[transform_id]
      return beam_expansion_api_pb2.ExpansionResponse(
          components=pipeline_proto.components,
          transform=expanded_transform_proto,
          requirements=pipeline_proto.requirements)

    except Exception:  # pylint: disable=broad-except
      return beam_expansion_api_pb2.ExpansionResponse(
50 changes: 50 additions & 0 deletions sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -466,6 +466,13 @@ def _next_uid(self):
    self._last_uid += 1
    return str(self._last_uid)

  @staticmethod
  def supported_requirements():
    return (
        common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn,
        common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
    )

  def run_pipeline(self,
                   pipeline,  # type: Pipeline
                   options  # type: pipeline_options.PipelineOptions
@@ -511,6 +518,8 @@ def run_pipeline(self,

  def run_via_runner_api(self, pipeline_proto):
    # type: (beam_runner_api_pb2.Pipeline) -> RunnerResult
    self._validate_requirements(pipeline_proto)
    self._check_requirements(pipeline_proto)
    stage_context, stages = self.create_stages(pipeline_proto)
    # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access to
    # the teststream (if any), and all the stages).
@@ -561,6 +570,47 @@ def maybe_profile(self):
      # Empty context.
      yield

  def _validate_requirements(self, pipeline_proto):
    """As a test runner, validate requirements were set correctly."""
    expected_requirements = set()

    def add_requirements(transform_id):
      transform = pipeline_proto.components.transforms[transform_id]
      if transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        if payload.requests_finalization:
          expected_requirements.add(
              common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn)
        if (payload.state_specs or payload.timer_specs or
            payload.timer_family_specs):
          expected_requirements.add(
              common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn)
        if payload.requires_stable_input:
          expected_requirements.add(
              common_urns.requirements.REQUIRES_STABLE_INPUT.urn)
        if payload.requires_time_sorted_input:
          expected_requirements.add(
              common_urns.requirements.REQUIRES_TIME_SORTED_INPUT.urn)
      else:
        for sub in transform.subtransforms:
          add_requirements(sub)

    for root in pipeline_proto.root_transform_ids:
      add_requirements(root)
    if not expected_requirements.issubset(pipeline_proto.requirements):
      raise ValueError(
          'Missing requirement declaration: %s' %
          (expected_requirements - set(pipeline_proto.requirements)))

  def _check_requirements(self, pipeline_proto):
    """Check that this runner can satisfy all pipeline requirements."""
    supported_requirements = set(self.supported_requirements())
    for requirement in pipeline_proto.requirements:
      if requirement not in supported_requirements:
        raise ValueError(
            'Unable to run pipeline with requirement: %s' % requirement)
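Together these checks make the FnApiRunner strict in both directions: _validate_requirements catches an SDK that sets a payload field without declaring the matching URN, while _check_requirements rejects pipelines this runner cannot honor. For example (a sketch only; runner and pipeline_proto are assumed to be in scope), a pipeline declaring time-sorted input would fail, since that URN is not in supported_requirements():

# Sketch: REQUIRES_TIME_SORTED_INPUT is not supported by this runner.
pipeline_proto.requirements.append(
    common_urns.requirements.REQUIRES_TIME_SORTED_INPUT.urn)
runner._check_requirements(pipeline_proto)  # raises ValueError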

  def create_stages(
      self,
      pipeline_proto  # type: beam_runner_api_pb2.Pipeline
15 changes: 12 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
@@ -1292,21 +1292,30 @@ def to_runner_api_parameter(self, context):
"expected instance of ParDo, but got %s" % self.__class__
picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data())
state_specs, timer_specs = userstate.get_dofn_specs(self.fn)
if state_specs or timer_specs:
context.add_requirement(
common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn)
from apache_beam.runners.common import DoFnSignature
is_splittable = DoFnSignature(self.fn).is_splittable_dofn()
restriction_coder = DoFnSignature(self.fn).get_restriction_coder()
if restriction_coder:
sig = DoFnSignature(self.fn)
is_splittable = sig.is_splittable_dofn()
if is_splittable:
restriction_coder = sig.get_restriction_coder()
restriction_coder_id = context.coders.get_id(
restriction_coder) # type: typing.Optional[str]
else:
restriction_coder_id = None
has_bundle_finalization = sig.has_bundle_finalization()
if has_bundle_finalization:
context.add_requirement(
common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn)
return (
common_urns.primitives.PAR_DO.urn,
beam_runner_api_pb2.ParDoPayload(
do_fn=beam_runner_api_pb2.FunctionSpec(
urn=python_urns.PICKLED_DOFN_INFO,
payload=picked_pardo_fn_data),
splittable=is_splittable,
requests_finalization=has_bundle_finalization,
restriction_coder_id=restriction_coder_id,
state_specs={
spec.name: spec.to_runner_api(context)
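End to end, serializing a pipeline that uses a stateful DoFn now produces a proto whose requirements list the stateful-processing URN, mirroring the bundle-finalization test added above. A hedged sketch, reusing the illustrative BufferingDoFn from earlier:

import apache_beam as beam
from apache_beam.portability import common_urns

p = beam.Pipeline()
_ = p | beam.Create([('key', 1)]) | beam.ParDo(BufferingDoFn())
proto = p.to_runner_api()
assert (common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn
        in proto.requirements)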
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/transforms/external.py
@@ -330,6 +330,7 @@ def expand(self, pvalueish):
      raise RuntimeError(response.error)
    self._expanded_components = response.components
    self._expanded_transform = response.transform
    self._expanded_requirements = response.requirements
    result_context = pipeline_context.PipelineContext(response.components)

    def fix_output(pcoll, tag):
@@ -422,6 +423,9 @@ def _normalize(coder_proto):
          environment_id=proto.environment_id)
      context.transforms.put_proto(id, new_proto)

    for requirement in self._expanded_requirements:
      context.add_requirement(requirement)

    return beam_runner_api_pb2.PTransform(
        unique_name=full_label,
        spec=self._expanded_transform.spec,