-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
expansion_service.py
139 lines (129 loc) · 6.04 KB
/
expansion_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A PipelineExpansion service.
"""
# pytype: skip-file
import copy
import traceback
from apache_beam import pipeline as beam_pipeline
from apache_beam.options import pipeline_options
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_expansion_api_pb2
from apache_beam.portability.api import beam_expansion_api_pb2_grpc
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import artifact_service
from apache_beam.runners.portability.artifact_service import BeamFilesystemHandler
from apache_beam.transforms import environments
from apache_beam.transforms import external
from apache_beam.transforms import ptransform
class ExpansionServiceServicer(
    beam_expansion_api_pb2_grpc.ExpansionServiceServicer):
  """gRPC servicer implementing the Beam cross-language ExpansionService.

  Expands a transform described in an ``ExpansionRequest`` into a full
  subgraph of pipeline components and returns it in an
  ``ExpansionResponse`` so that a pipeline in another SDK can embed it.
  """
  def __init__(self, options=None, loopback_address=None):
    """Initializes the servicer.

    Args:
      options: ``PipelineOptions`` used for expansion. If ``None``, defaults
        to embedded-Python environment options with
        ``sdk_location='container'``.
      loopback_address: Optional address of a loopback (EXTERNAL) worker
        environment. When given, the default environment becomes an
        ``AnyOfEnvironment`` of the options-derived environment and the
        loopback environment, letting the runner pick either.
    """
    self._options = options or beam_pipeline.PipelineOptions(
        flags=[],
        environment_type=python_urns.EMBEDDED_PYTHON,
        sdk_location='container')
    default_environment = (environments.Environment.from_options(self._options))
    if loopback_address:
      # Build an EXTERNAL environment pointing at the caller-supplied
      # loopback worker address.
      loopback_environment = environments.Environment.from_options(
          beam_pipeline.PipelineOptions(
              environment_type=common_urns.environments.EXTERNAL.urn,
              environment_config=loopback_address))
      default_environment = environments.AnyOfEnvironment(
          [default_environment, loopback_environment])
    self._default_environment = default_environment

  def Expand(self, request, context=None):
    """Handles an ExpansionRequest RPC.

    Reconstructs the requested transform and its inputs in a fresh local
    pipeline, applies it, and returns the expanded components. On any
    failure, returns an ``ExpansionResponse`` carrying the traceback in its
    ``error`` field rather than raising (so the RPC always completes).

    Args:
      request: ``beam_expansion_api_pb2.ExpansionRequest``.
      context: gRPC call context (unused; shadowed by the pipeline context
        below).

    Returns:
      ``beam_expansion_api_pb2.ExpansionResponse``.
    """
    try:
      # Start from a copy of this service's options; selectively overlay
      # options from the request.
      options = copy.deepcopy(self._options)
      request_options = pipeline_options.PipelineOptions.from_runner_api(
          request.pipeline_options)
      # TODO(https://github.com/apache/beam/issues/20090): Figure out the
      # correct subset of options to apply to expansion.
      if request_options.view_as(
          pipeline_options.StreamingOptions).update_compatibility_version:
        options.view_as(
            pipeline_options.StreamingOptions
        ).update_compatibility_version = request_options.view_as(
            pipeline_options.StreamingOptions).update_compatibility_version
      pipeline = beam_pipeline.Pipeline(options=options)

      def with_pipeline(component, pcoll_id=None):
        # Attaches the local pipeline (and, for PCollections, the producing
        # transform/tag) to a deserialized component. Note: `producers` and
        # `context` are bound later in this scope; this closure is only
        # called after they exist.
        component.pipeline = pipeline
        if pcoll_id:
          component.producer, component.tag = producers[pcoll_id]
          # We need the lookup to resolve back to this id.
          context.pcollections._obj_to_id[component] = pcoll_id
        return component

      # Rebinds the (unused) gRPC `context` parameter to the pipeline
      # context used for proto <-> object conversion.
      context = pipeline_context.PipelineContext(
          request.components,
          default_environment=self._default_environment,
          namespace=request.namespace,
          requirements=request.requirements)
      # Maps each PCollection id to its (producing transform, output tag).
      producers = {
          pcoll_id: (context.transforms.get_by_id(t_id), pcoll_tag)
          for t_id,
          t_proto in request.components.transforms.items() for pcoll_tag,
          pcoll_id in t_proto.outputs.items()
      }
      transform = with_pipeline(
          ptransform.PTransform.from_runner_api(request.transform, context))
      # Apply a requested output type annotation, supported only for a
      # single output.
      if len(request.output_coder_requests) == 1:
        output_coder = {
            k: context.element_type_from_coder_id(v)
            for k,
            v in request.output_coder_requests.items()
        }
        transform = transform.with_output_types(list(output_coder.values())[0])
      elif len(request.output_coder_requests) > 1:
        raise ValueError(
            'type annotation for multiple outputs is not allowed yet: %s' %
            request.output_coder_requests)
      # Convert the tag -> PCollection-id map into the "pvaluish" form the
      # transform expects (single PCollection, tuple, dict, ...).
      inputs = transform._pvaluish_from_dict({
          tag:
          with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
          for tag,
          pcoll_id in request.transform.inputs.items()
      })
      if not inputs:
        # No inputs: treat as a root transform applied to the pipeline.
        inputs = pipeline
      with external.ExternalTransform.outer_namespace(request.namespace):
        result = pipeline.apply(
            transform, inputs, request.transform.unique_name)
      # The transform just applied is the last part of the root transform.
      expanded_transform = pipeline._root_transform().parts[-1]
      # TODO(BEAM-1833): Use named outputs internally.
      if isinstance(result, dict):
        expanded_transform.outputs = result
      pipeline_proto = pipeline.to_runner_api(context=context)
      # TODO(BEAM-1833): Use named inputs internally.
      # Pull the expanded transform out of the components and restore the
      # caller's original input ids on it.
      expanded_transform_id = context.transforms.get_id(expanded_transform)
      expanded_transform_proto = pipeline_proto.components.transforms.pop(
          expanded_transform_id)
      expanded_transform_proto.inputs.clear()
      expanded_transform_proto.inputs.update(request.transform.inputs)
      # Root transforms are scaffolding of the local pipeline, not part of
      # the expansion result.
      for transform_id in pipeline_proto.root_transform_ids:
        del pipeline_proto.components.transforms[transform_id]
      return beam_expansion_api_pb2.ExpansionResponse(
          components=pipeline_proto.components,
          transform=expanded_transform_proto,
          requirements=pipeline_proto.requirements)
    except Exception:  # pylint: disable=broad-except
      # Report failures through the response's error field; the RPC itself
      # succeeds.
      return beam_expansion_api_pb2.ExpansionResponse(
          error=traceback.format_exc())

  def artifact_service(self):
    """Returns a service to retrieve artifacts for use in a job."""
    return artifact_service.ArtifactRetrievalService(
        BeamFilesystemHandler(None).file_reader)