From 80932de281b51af3f55d0036cac60d7196a60c69 Mon Sep 17 00:00:00 2001
From: Chad Dombrova
Date: Sat, 24 Aug 2019 15:26:28 -0700
Subject: [PATCH] Add typing annotations and fix ParameterType/ConstructorFn
 aliases

---
 sdks/python/apache_beam/coders/coder_impl.py  |  8 +++-
 sdks/python/apache_beam/coders/coders.py      | 25 ++++++++++--
 sdks/python/apache_beam/io/iobase.py          |  8 +++-
 .../apache_beam/options/pipeline_options.py   |  5 ++-
 sdks/python/apache_beam/pipeline.py           | 25 +++++++++--
 .../runners/direct/bundle_factory.py          |  1 +
 .../runners/direct/evaluation_context.py      |  4 +-
 .../runners/direct/transform_evaluator.py     | 10 +++--
 .../runners/portability/fn_api_runner.py      |  9 +++--
 .../portability/fn_api_runner_transforms.py   | 20 ++++++++++
 sdks/python/apache_beam/runners/runner.py     | 39 ++++++++++++++++---
 .../runners/worker/bundle_processor.py        | 17 ++++----
 .../python/apache_beam/transforms/external.py |  4 ++
 .../apache_beam/transforms/ptransform.py      |  7 +++-
 sdks/python/apache_beam/utils/urns.py         |  5 ++-
 15 files changed, 149 insertions(+), 38 deletions(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 9a7522125dfa2..ade8df6ae56ff 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -750,8 +750,12 @@ class SequenceCoderImpl(StreamCoderImpl):
   # Default buffer size of 64kB for handling iterables of unknown length.
   _DEFAULT_BUFFER_SIZE = 64 * 1024
 
-  def __init__(self, elem_coder,
-               read_state=None, write_state=None, write_state_threshold=0):
+  def __init__(self,
+               elem_coder,  # type: Coder
+               read_state=None,  # type: Optional[Callable[[bytes, Coder], Iterable]]
+               write_state=None,  # type: Optional[Callable[[Iterable, Coder], bytes]]
+               write_state_threshold=0  # type: int
+              ):
     self._elem_coder = elem_coder
     self._read_state = read_state
     self._write_state = write_state
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 930f1451c6031..2adb014ddeff8 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -82,8 +82,12 @@
 CoderT = TypeVar('CoderT', bound='Coder')
 ProtoCoderT = TypeVar('ProtoCoderT', bound='ProtoCoder')
 
-ParameterType = Union['message.Message', bytes, None]
-ConstructorFn = Callable[[ParameterType, List['Coder'], 'PipelineContext'], Any]
+ParameterType = Union[Type['message.Message'], Type[bytes], None]
+ConstructorFn = Callable[
+    [Union['message.Message', bytes],
+     List['Coder'],
+     'PipelineContext'],
+    Any]
 
 
 def serialize_coder(coder):
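A note on the ParameterType change above: the registry now stores the payload
*class* (`Type[...]`), while the constructor is invoked with a decoded
*instance*, which is why `ConstructorFn` keeps the instance union. A minimal
sketch of that registry shape, with an illustrative `FooPayload` standing in
for a real protobuf message (not Beam's actual code):

    from typing import Any, Callable, Dict, Optional, Tuple, Type

    class FooPayload(object):
        """Stand-in for a generated protobuf message class."""
        @classmethod
        def FromString(cls, data):
            # type: (bytes) -> FooPayload
            return cls()

    ParameterType = Optional[Type[Any]]   # the registry stores the payload class
    ConstructorFn = Callable[[Any], Any]  # the constructor receives an instance

    _known_urns = {}  # type: Dict[str, Tuple[ParameterType, ConstructorFn]]

    def register_urn(urn, parameter_type, fn):
        # type: (str, ParameterType, ConstructorFn) -> None
        _known_urns[urn] = (parameter_type, fn)

    def from_runner_api(urn, serialized):
        # type: (str, bytes) -> Any
        parameter_type, constructor = _known_urns[urn]
        payload = parameter_type.FromString(serialized)  # class -> instance
        return constructor(payload)

    register_urn('beam:fn:foo', FooPayload, lambda payload: type(payload).__name__)
    print(from_runner_api('beam:fn:foo', b''))  # -> 'FooPayload'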
@@ -268,11 +272,26 @@ def __hash__(self):
   _known_urns = {}  # type: Dict[str, Tuple[ParameterType, ConstructorFn]]
 
   @classmethod
+  @typing.overload
   def register_urn(cls,
                    urn,  # type: str
                    parameter_type,  # type: ParameterType
-                   fn=None  # type: Optional[ConstructorFn]
                   ):
+    # type: (...) -> ConstructorFn
+    pass
+
+  @classmethod
+  @typing.overload
+  def register_urn(cls,
+                   urn,  # type: str
+                   parameter_type,  # type: ParameterType
+                   fn  # type: ConstructorFn
+                  ):
+    # type: (...) -> None
+    pass
+
+  @classmethod
+  def register_urn(cls, urn, parameter_type, fn=None):
     """Registers a urn with a constructor.
 
     For example, if 'beam:fn:foo' had parameter type FooPayload, one could
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index d8da27e41ee7e..c303d641e47c8 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -99,6 +99,10 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn, Generic[T]):
   """
   urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE)
 
+  def is_bounded(self):
+    # type: () -> bool
+    raise NotImplementedError
+
 
 class BoundedSource(SourceBase[T]):
   """A source that reads a finite amount of input records.
@@ -861,7 +865,7 @@ class Read(ptransform.PTransform[pvalue.PBeginType, OutT]):
   """A transform that reads a PCollection."""
 
   def __init__(self, source):
-    # type: (BoundedSource) -> None
+    # type: (SourceBase) -> None
     """Initializes a Read transform.
 
     Args:
@@ -926,6 +930,7 @@ def display_data(self):
             'source_dd': self.source}
 
   def to_runner_api_parameter(self, context):
+    # type: (PipelineContext) -> Tuple[str, ptransform.ParameterType]
     return (common_urns.deprecated_primitives.READ.urn,
             beam_runner_api_pb2.ReadPayload(
                 source=self.source.to_runner_api(context),
@@ -935,6 +940,7 @@
 
   @staticmethod
   def from_runner_api_parameter(parameter, context):
+    # type: (beam_runner_api_pb2.ReadPayload, PipelineContext) -> Read
     return Read(SourceBase.from_runner_api(parameter.source, context))
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 3b1abf88db23a..f15fadf7c52bc 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -27,6 +27,7 @@
 from typing import Any
 from typing import Dict
 from typing import List
+from typing import Optional
 
 from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.options.value_provider import StaticValueProvider
@@ -157,7 +158,9 @@ def _add_argparse_args(cls, parser):
   By default the options classes will use command line arguments to initialize
   the options.
   """
-  def __init__(self, flags=None, **kwargs):
+  def __init__(self,
+               flags=None,  # type: Optional[List[str]]
+               **kwargs):
     """Initialize an options class.
 
     The initializer will traverse all subclasses, add all their argparse
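All of the signature rewrites in this patch use the same Python-2-compatible
convention: one `# type:` comment per argument plus a `(...)` form for the
return type, which mypy reads in place of annotation syntax. A small,
self-contained illustration of the convention (`FlagHolder` is hypothetical,
not Beam code):

    from typing import Any, List, Optional

    class FlagHolder(object):
        def __init__(self,
                     flags=None,  # type: Optional[List[str]]
                     **kwargs  # type: Any
                    ):
            # type: (...) -> None
            # A None default forces the Optional[...] spelling; under strict
            # optional checking, `flags=None  # type: List[str]` is an error.
            self.flags = flags or []
            self.extra = dict(kwargs)

    holder = FlagHolder(['--runner=DirectRunner'], save_main_session=True)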
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index a987cade0bd80..d881075a43a3d 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -57,6 +57,7 @@
 from builtins import zip
 from typing import Dict
 from typing import List
+from typing import Optional
 from typing import Union
 
 from future.utils import with_metaclass
@@ -83,6 +84,7 @@
 
 if typing.TYPE_CHECKING:
   from apache_beam.portability.api import beam_runner_api_pb2
+  from apache_beam.runners import PipelineResult
   from apache_beam.runners.pipeline_context import PipelineContext
 
 __all__ = ['Pipeline', 'PTransformOverride']
@@ -103,7 +105,11 @@ class Pipeline(object):
   (e.g. ``input | "label" >> my_transform``).
   """
 
-  def __init__(self, runner=None, options=None, argv=None):
+  def __init__(self,
+               runner=None,  # type: Optional[PipelineRunner]
+               options=None,  # type: Optional[PipelineOptions]
+               argv=None  # type: Optional[List[str]]
+              ):
     """Initialize a pipeline object.
 
     Args:
@@ -405,6 +411,7 @@ def replace_all(self, replacements):
       self._check_replacement(override)
 
   def run(self, test_runner_api=True):
+    # type: (bool) -> PipelineResult
     """Runs the pipeline.
 
     Returns whatever our runner returns after running."""
 
     # When possible, invoke a round trip through the runner API.
@@ -435,6 +442,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
       self.run().wait_until_finish()
 
   def visit(self, visitor):
+    # type: (PipelineVisitor) -> None
     """Visits depth-first every node of a pipeline's DAG.
 
     Runner-internal implementation detail; no backwards-compatibility
     guarantees
@@ -677,8 +685,12 @@ def visit_transform(self, transform_node):
     return proto
 
   @staticmethod
-  def from_runner_api(proto, runner, options, return_context=False,
-                      allow_proto_holders=False):
+  def from_runner_api(proto,  # type: beam_runner_api_pb2.Pipeline
+                      runner,  # type: PipelineRunner
+                      options,  # type: PipelineOptions
+                      return_context=False,
+                      allow_proto_holders=False
+                     ):
+    # type: (...) -> Pipeline
     """For internal use only; no backwards-compatibility guarantees."""
     p = Pipeline(runner=runner, options=options)
@@ -753,7 +765,7 @@ def __init__(self,
                parent,
                transform,  # type: ptransform.PTransform
                full_label,
-               inputs
+               inputs  # type: Iterable[pvalue.PCollection]
               ):
     self.parent = parent
     self.transform = transform
@@ -765,7 +777,7 @@
     self.full_label = full_label
     self.inputs = inputs or ()
     self.side_inputs = () if transform is None else tuple(transform.side_inputs)
-    self.outputs = {}  # type: Dict[Union[str, int, None], pvalue.PValue]
+    self.outputs = {}  # type: Dict[Union[str, int, None], Union[pvalue.PValue, pvalue.DoOutputsTuple]]
     self.parts = []  # type: List[AppliedPTransform]
 
   def __repr__(self):
@@ -870,6 +882,7 @@ def visit(self, visitor, pipeline, visited):
         visitor.visit_value(v, self)
 
   def named_inputs(self):
+    # type: () -> Dict[str, pvalue.PCollection]
     # TODO(BEAM-1833): Push names up into the sdk construction.
     main_inputs = {str(ix): input
                    for ix, input in enumerate(self.inputs)
@@ -879,6 +892,7 @@
     return dict(main_inputs, **side_inputs)
 
   def named_outputs(self):
+    # type: () -> Dict[str, pvalue.PCollection]
     return {str(tag): output
             for tag, output in self.outputs.items()
             if isinstance(output, pvalue.PCollection)}
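The `if typing.TYPE_CHECKING:` block added to pipeline.py is worth pausing
on: `pipeline.py` needs `PipelineResult` (defined in the runners package)
only inside type comments, and importing it for real would create an import
cycle with the runner modules. Because the guard is false at runtime and
type comments are never evaluated by the interpreter, the name can even be
used unquoted. A stripped-down sketch of the pattern:

    import typing

    if typing.TYPE_CHECKING:
        # Only type checkers evaluate this import, so it cannot create a
        # runtime import cycle (and need not be importable to run this file).
        from apache_beam.runners.runner import PipelineResult

    class Pipeline(object):
        def run(self, test_runner_api=True):
            # type: (bool) -> PipelineResult
            # The comment above is an ordinary comment to the interpreter;
            # PipelineResult is only resolved during static analysis.
            raise NotImplementedError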
@@ -912,6 +926,7 @@ def transform_to_runner_api(transform, context):
 
   @staticmethod
   def from_runner_api(proto, context):
+    # type: (beam_runner_api_pb2.PTransform, PipelineContext) -> AppliedPTransform
     def is_side_input(tag):
       # As per named_inputs() above.
       return tag.startswith('side')
diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py
index 990a097127af4..5a10bcfd70319 100644
--- a/sdks/python/apache_beam/runners/direct/bundle_factory.py
+++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py
@@ -45,6 +45,7 @@ def create_bundle(self, output_pcollection):
     return _Bundle(output_pcollection, self._stacked)
 
   def create_empty_committed_bundle(self, output_pcollection):
+    # type: (pvalue.PCollection) -> _Bundle
     bundle = self.create_bundle(output_pcollection)
     bundle.commit(None)
     return bundle
diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py
index c2784965cab84..d10c57baa43a4 100644
--- a/sdks/python/apache_beam/runners/direct/evaluation_context.py
+++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py
@@ -40,7 +40,7 @@
 if typing.TYPE_CHECKING:
   from apache_beam.pipeline import AppliedPTransform
   from apache_beam.pvalue import AsSideInput, PCollection
-  from apache_beam.runners.direct.bundle_factory import BundleFactory
+  from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle
   from apache_beam.utils.timestamp import Timestamp
 
 class _ExecutionContext(object):
@@ -375,10 +375,12 @@ def get_execution_context(self, applied_ptransform):
         self._transform_keyed_states[applied_ptransform])
 
   def create_bundle(self, output_pcollection):
+    # type: (pvalue.PCollection) -> _Bundle
     """Create an uncommitted bundle for the specified PCollection."""
     return self._bundle_factory.create_bundle(output_pcollection)
 
   def create_empty_committed_bundle(self, output_pcollection):
+    # type: (pvalue.PCollection) -> _Bundle
     """Create empty bundle useful for triggering evaluation."""
     return self._bundle_factory.create_empty_committed_bundle(
         output_pcollection)
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index be989273e361e..33aeb64b49a25 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -567,9 +567,13 @@ def __missing__(self, key):
 class _ParDoEvaluator(_TransformEvaluator):
   """TransformEvaluator for ParDo transform."""
 
-  def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs,
-               perform_dofn_pickle_test=True):
+  def __init__(self,
+               evaluation_context,  # type: EvaluationContext
+               applied_ptransform,  # type: AppliedPTransform
+               input_committed_bundle,
+               side_inputs,
+               perform_dofn_pickle_test=True
+              ):
     super(_ParDoEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
         side_inputs)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 842519aef0605..7a2c0583bc4c2 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -85,10 +85,11 @@
   from google.protobuf import message
   from apache_beam.runners.portability import fn_api_runner
 
-ConstructorFn = Callable[[Union['message.Message', bytes],
-                          'FnApiRunner.StateServicer',
-                          Optional['fn_api_runner.ExtendedProvisionInfo']],
-                         Any]
+ConstructorFn = Callable[
+    [Union['message.Message', bytes],
+     'FnApiRunner.StateServicer',
+     Optional['fn_api_runner.ExtendedProvisionInfo']],
+    Any]
 
 # This module is experimental. No backwards-compatibility guarantees.
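The reflowed `ConstructorFn` alias also shows why several argument types are
spelled as strings: `'FnApiRunner.StateServicer'` names a class that is not
yet defined (or not importable) at alias-definition time, and `typing`
stores such strings as forward references. A sketch of the same idea with
invented placeholder types (`StateServicer` and `ProvisionInfo` are
illustrative, not Beam's classes):

    from typing import Any, Callable, Optional, Union

    # 'StateServicer' is defined later in the module, so it appears as a
    # string forward reference inside the alias.
    ConstructorFn = Callable[
        [Union[bytes, str],
         'StateServicer',
         Optional['ProvisionInfo']],
        Any]

    class StateServicer(object):
        pass

    class ProvisionInfo(object):
        pass

    def create_worker(payload, servicer, provision_info=None):
        # type: (Union[bytes, str], StateServicer, Optional[ProvisionInfo]) -> Any
        return (payload, servicer, provision_info)

    factory = create_worker  # type: ConstructorFn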
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 4f3e2f92c8a30..2f2919445bbde 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -24,6 +24,10 @@
 import functools
 import logging
 from builtins import object
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Tuple
 
 from past.builtins import unicode
 
@@ -458,6 +462,7 @@ def create_and_optimize_stages(pipeline_proto,
                                phases,
                                known_runner_urns,
                                use_state_iterables=False):
+  # type: (...) -> Tuple[TransformContext, List[Stage]]
   """Create a set of stages given a pipeline proto, and set of optimizations.
 
   Args:
@@ -513,6 +518,7 @@ def optimize_pipeline(
 
 
 def annotate_downstream_side_inputs(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
   """Annotate each stage with fusion-prohibiting information.
 
   Each stage is annotated with the (transitive) set of pcollections that
@@ -560,6 +566,7 @@ def compute_downstream_side_inputs(stage):
 
 
 def annotate_stateful_dofns_as_roots(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
   for stage in stages:
     for transform in stage.transforms:
       if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
@@ -571,6 +578,7 @@
 
 
 def fix_side_input_pcoll_coders(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
   """Length prefix side input PCollection coders.
   """
   for stage in stages:
@@ -580,6 +588,7 @@
 
 
 def lift_combiners(stages, context):
+  # type: (List[Stage], TransformContext) -> Iterator[Stage]
   """Expands CombinePerKey into pre- and post-grouping stages.
 
   ... -> CombinePerKey -> ...
@@ -709,6 +718,7 @@ def make_stage(base_stage, transform):
 
 
 def expand_sdf(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Transforms splittable DoFns into pair+split+read."""
   for stage in stages:
     assert len(stage.transforms) == 1
@@ -850,6 +860,7 @@ def make_stage(base_stage, transform_id, extra_must_follow=()):
 
 
 def expand_gbk(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Transforms each GBK into a write followed by a read.
   """
   for stage in stages:
@@ -899,6 +910,7 @@ def expand_gbk(stages, pipeline_context):
 
 
 def fix_flatten_coders(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Ensures that the inputs of Flatten have the same coders as the output.
   """
   pcollections = pipeline_context.components.pcollections
@@ -939,6 +951,7 @@ def fix_flatten_coders(stages, pipeline_context):
 
 
 def sink_flattens(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Sink flattens and remove them from the graph.
 
   A flatten that cannot be sunk/fused away becomes multiple writes (to the
@@ -1070,6 +1083,7 @@ def fuse(producer, consumer):
 
 
 def read_to_impulse(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Translates Read operations into Impulse operations."""
   for stage in stages:
     # First map Reads, if any, to Impulse + triggered read op.
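Note the shape shared by the annotations above: almost every optimization
pass is typed `(Iterable[Stage], TransformContext) -> Iterator[Stage]`, i.e.
a generator that consumes the stage graph lazily and yields possibly
rewritten stages, so passes compose by simple nesting. A toy model of that
contract (the `Stage` and pass names here are stand-ins, not Beam's):

    from typing import Iterable, Iterator, List

    class Stage(object):  # minimal stand-in for fn_api_runner_transforms.Stage
        def __init__(self, name):
            self.name = name

    class TransformContext(object):  # stand-in for the shared pipeline context
        pass

    def drop_unnamed(stages, context):
        # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
        # A pass may yield a stage unchanged, rewrite it, replace it with
        # several stages, or drop it entirely.
        for stage in stages:
            if stage.name:
                yield stage

    def sort_by_name(stages, context):
        # type: (Iterable[Stage], TransformContext) -> List[Stage]
        # Terminal passes such as sort_stages materialize the graph, hence
        # the List[Stage] return type instead of Iterator[Stage].
        return sorted(stages, key=lambda s: s.name)

    ctx = TransformContext()
    stages = drop_unnamed([Stage('gbk'), Stage(''), Stage('read')], ctx)
    print([s.name for s in sort_by_name(stages, ctx)])  # -> ['gbk', 'read']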
@@ -1107,6 +1121,7 @@ def read_to_impulse(stages, pipeline_context):
 
 
 def impulse_to_input(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Translates Impulse operations into GRPC reads."""
   for stage in stages:
     for transform in list(stage.transforms):
@@ -1123,6 +1138,7 @@
 
 
 def extract_impulse_stages(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Splits fused Impulse operations into their own stage."""
   for stage in stages:
     for transform in list(stage.transforms):
@@ -1140,6 +1156,7 @@
 
 
 def remove_data_plane_ops(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   for stage in stages:
     for transform in list(stage.transforms):
       if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
@@ -1151,6 +1168,7 @@
 
 
 def inject_timer_pcollections(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
   """Create PCollections for fired timers and to-be-set timers.
 
   At execution time, fired timers and timers-to-set are represented as
@@ -1224,6 +1242,7 @@
 
 
 def sort_stages(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> List[Stage]
   """Order stages suitable for sequential execution.
   """
   all_stages = set(stages)
@@ -1244,6 +1263,7 @@ def process(stage):
 
 
 def window_pcollection_coders(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
   """Wrap all PCollection coders as windowed value coders.
 
   This is required as some SDK workers require windowed coders for their
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index 88ad0f37a583f..d4cb87aea25ba 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -27,6 +27,15 @@
 import tempfile
 import typing
 from builtins import object
+from typing import Optional
+
+if typing.TYPE_CHECKING:
+  from apache_beam import pvalue
+  from apache_beam import PTransform
+  from apache_beam.options.pipeline_options import PipelineOptions
+  from apache_beam.pipeline import AppliedPTransform
+  from apache_beam.pipeline import Pipeline
+  from apache_beam.pipeline import PipelineVisitor
 
 __all__ = ['PipelineRunner', 'PipelineState', 'PipelineResult']
@@ -106,7 +115,11 @@ class PipelineRunner(object):
   materialized values in order to reduce footprint.
   """
 
-  def run(self, transform, options=None):
+  def run(self,
+          transform,  # type: PTransform
+          options=None  # type: Optional[PipelineOptions]
+         ):
+    # type: (...) -> PipelineResult
     """Run the given transform or callable with this runner.
 
     Blocks until the pipeline is complete.  See also `PipelineRunner.run_async`.
@@ -115,7 +128,11 @@ def run(self, transform, options=None):
     result.wait_until_finish()
     return result
 
-  def run_async(self, transform, options=None):
+  def run_async(self,
+                transform,  # type: PTransform
+                options=None  # type: Optional[PipelineOptions]
+               ):
+    # type: (...) -> PipelineResult
     """Run the given transform or callable with this runner.
 
     May return immediately, executing the pipeline in the background.
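The pairing above makes the contract explicit: `run` is `run_async` plus a
wait, so both carry the same `PipelineResult` return type and a subclass
only needs to provide the asynchronous half. A schematic version of that
relationship (deliberately simplified types, not the real signatures):

    from typing import Optional

    class PipelineResult(object):  # simplified stand-in
        def wait_until_finish(self):
            # type: () -> str
            return 'DONE'

    class PipelineRunner(object):
        def run_async(self, transform, options=None):
            # type: (object, Optional[dict]) -> PipelineResult
            # Subclasses override this to start execution in the background.
            return PipelineResult()

        def run(self, transform, options=None):
            # type: (object, Optional[dict]) -> PipelineResult
            # The blocking form is defined in terms of the async form.
            result = self.run_async(transform, options)
            result.wait_until_finish()
            return result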
@@ -134,7 +151,10 @@ def run_async(self, transform, options=None):
       transform(PBegin(p))
     return p.run()
 
-  def run_pipeline(self, pipeline, options):
+  def run_pipeline(self,
+                   pipeline,  # type: Pipeline
+                   options  # type: PipelineOptions
+                  ):
     """Execute the entire pipeline or the sub-DAG reachable from a node.
 
     Runners should override this method.
@@ -147,6 +167,7 @@
     class RunVisitor(PipelineVisitor):
 
       def __init__(self, runner):
+        # type: (PipelineRunner) -> None
         self.runner = runner
 
       def visit_transform(self, transform_node):
@@ -158,7 +179,11 @@ def visit_transform(self, transform_node):
 
     pipeline.visit(RunVisitor(self))
 
-  def apply(self, transform, input, options):
+  def apply(self,
+            transform,  # type: PTransform
+            input,  # type: pvalue.PCollection
+            options  # type: PipelineOptions
+           ):
     """Runner callback for a pipeline.apply call.
 
     Args:
@@ -181,7 +206,10 @@ def apply_PTransform(self, transform, input, options):
     # The base case of apply is to call the transform's expand.
     return transform.expand(input)
 
-  def run_transform(self, transform_node, options):
+  def run_transform(self,
+                    transform_node,  # type: AppliedPTransform
+                    options  # type: PipelineOptions
+                   ):
     """Runner callback for a pipeline.run call.
 
     Args:
@@ -295,6 +323,7 @@ def key(self, pobj):
     return self.to_cache_key(pobj.real_producer, pobj.tag)
 
 
+# FIXME: replace with PipelineState(str, enum.Enum)
 class PipelineState(object):
   """State of the Pipeline, as returned by :attr:`PipelineResult.state`.
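The FIXME planted above suggests turning the constants holder into a
string-valued enum. Assuming the current constants keep their string values
(only a subset of states is shown), the `str` mixin would keep existing
equality checks against raw strings working:

    import enum

    class PipelineState(str, enum.Enum):
        # str mixin: members compare equal to the plain strings used today.
        UNKNOWN = 'UNKNOWN'
        RUNNING = 'RUNNING'
        DONE = 'DONE'
        FAILED = 'FAILED'
        CANCELLED = 'CANCELLED'

    assert PipelineState.DONE == 'DONE'  # old-style comparisons still pass
    assert PipelineState('RUNNING') is PipelineState.RUNNING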
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index ba6be2ad14175..ab2cdc035fb9d 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -73,17 +73,14 @@
 from apache_beam.runners.worker import data_plane
 
 # This module is experimental. No backwards-compatibility guarantees.
 
-ParameterType = Union['message.Message', bytes, None]
+ParameterType = Union[Type['message.Message'], Type[bytes], None]
 ConstructorFn = Callable[
-    [
-        'BeamTransformFactory',
-        ParameterType,
-        beam_runner_api_pb2.PTransform,
-        'PipelineContext',
-        Dict[str, operations.Operation]
-    ],
-    operations.Operation
-]
+    ['BeamTransformFactory',
+     Union['message.Message', bytes],
+     beam_runner_api_pb2.PTransform,
+     'PipelineContext',
+     Dict[str, operations.Operation]],
+    operations.Operation]
 
 DATA_INPUT_URN = 'beam:source:runner:0.1'
 DATA_OUTPUT_URN = 'beam:sink:runner:0.1'
diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py
index a7e9fb504be9c..61988c2f4ab77 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -25,6 +25,7 @@
 import contextlib
 import copy
 import threading
+from typing import Dict
 
 from apache_beam import pvalue
 from apache_beam.portability import common_urns
@@ -64,6 +65,8 @@ def __init__(self, urn, payload, endpoint):
     self._payload = payload
     self._endpoint = endpoint
     self._namespace = self._fresh_namespace()
+    self._inputs = {}  # type: Dict[str, pvalue.PCollection]
+    self._output = {}  # type: Dict[str, pvalue.PCollection]
 
   def default_label(self):
     return '%s(%s)' % (self.__class__.__name__, self._urn)
@@ -86,6 +89,7 @@ def _fresh_namespace(cls):
     return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter)
 
   def expand(self, pvalueish):
+    # type: (pvalue.PCollection) -> pvalue.PCollection
     if isinstance(pvalueish, pvalue.PBegin):
       self._inputs = {}
     elif isinstance(pvalueish, (list, tuple)):
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 7e53428a117dc..432d46b3bc388 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -94,8 +94,10 @@ class and wrapper class that allows lambda functions to be used as
 ]
 
 PTransformT = TypeVar('PTransformT', bound='PTransform')
-ParameterType = Union['message.Message', bytes, None]
-ConstructorFn = Callable[[ParameterType, 'PipelineContext'], Any]
+ParameterType = Union[Type['message.Message'], Type[bytes], None]
+ConstructorFn = Callable[
+    [Union['message.Message', bytes], 'PipelineContext'],
+    Any]
 
 T = TypeVar('T')
 PValueT = TypeVar('PValueT', bound=pvalue.PValue)
@@ -692,6 +694,7 @@ def from_runner_api(cls, proto, context):
       raise
 
   def to_runner_api_parameter(self, unused_context):
+    # type: (PipelineContext) -> Tuple[str, ParameterType]
     # The payload here is just to ease debugging.
     return (python_urns.GENERIC_COMPOSITE_TRANSFORM,
             getattr(self, '_fn_api_payload', str(self)))
diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py
index b7a8468750c21..2a68387d24cd2 100644
--- a/sdks/python/apache_beam/utils/urns.py
+++ b/sdks/python/apache_beam/utils/urns.py
@@ -42,7 +42,10 @@
   from apache_beam.runners.pipeline_context import PipelineContext
 
 ParameterType = Union[Type[message.Message], Type[bytes], None]
-ConstructorFn = Callable[[ParameterType, 'PipelineContext'], Any]
+ConstructorFn = Callable[
+    [Union['message.Message', bytes],
+     'PipelineContext'],
+    Any]
 
 
 class RunnerApiFn(object):