Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-4150] Remove fallback case for coder not specified within RemoteGrpcPort. #10755

Merged
merged 2 commits into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ public Node apply(MutableNetwork<Node, Edge> input) {
ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews =
ImmutableMap.builder();
ImmutableMap.Builder<String, NameContext> pcollectionIdToNameContexts = ImmutableMap.builder();
ImmutableMap.Builder<InstructionOutputNode, String> instructionOutputNodeToCoderIdBuilder =
ImmutableMap.builder();

// For each instruction output node:
// 1. Generate new Coder and register it with SDKComponents and ProcessBundleDescriptor.
Expand All @@ -225,6 +227,7 @@ public Node apply(MutableNetwork<Node, Edge> input) {
InstructionOutput instructionOutput = node.getInstructionOutput();

String coderId = "generatedCoder" + idGenerator.getId();
instructionOutputNodeToCoderIdBuilder.put(node, coderId);
try (ByteString.Output output = ByteString.newOutput()) {
try {
Coder<?> javaCoder =
Expand Down Expand Up @@ -274,6 +277,8 @@ public Node apply(MutableNetwork<Node, Edge> input) {
instructionOutput.getName()));
}
processBundleDescriptor.putAllCoders(sdkComponents.toComponents().getCodersMap());
Map<InstructionOutputNode, String> instructionOutputNodeToCoderIdMap =
instructionOutputNodeToCoderIdBuilder.build();

for (ParallelInstructionNode node :
Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
Expand Down Expand Up @@ -408,22 +413,33 @@ public Node apply(MutableNetwork<Node, Edge> input) {
Set<Node> predecessors = input.predecessors(node);
Set<Node> successors = input.successors(node);
if (predecessors.isEmpty() && !successors.isEmpty()) {
Node instructionOutputNode = Iterables.getOnlyElement(successors);
pTransform.putOutputs(
"generatedOutput" + idGenerator.getId(),
nodesToPCollections.get(Iterables.getOnlyElement(successors)));
nodesToPCollections.get(instructionOutputNode));
pTransform.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn(DATA_INPUT_URN)
.setPayload(node.getRemoteGrpcPort().toByteString())
.setPayload(
node.getRemoteGrpcPort()
.toBuilder()
.setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode))
.build()
.toByteString())
.build());
} else if (!predecessors.isEmpty() && successors.isEmpty()) {
Node instructionOutputNode = Iterables.getOnlyElement(predecessors);
pTransform.putInputs(
"generatedInput" + idGenerator.getId(),
nodesToPCollections.get(Iterables.getOnlyElement(predecessors)));
"generatedInput" + idGenerator.getId(), nodesToPCollections.get(instructionOutputNode));
pTransform.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn(DATA_OUTPUT_URN)
.setPayload(node.getRemoteGrpcPort().toByteString())
.setPayload(
node.getRemoteGrpcPort()
.toBuilder()
.setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode))
.build()
.toByteString())
.build());
} else {
throw new IllegalStateException(
Expand Down
49 changes: 14 additions & 35 deletions sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func UnmarshalPlan(desc *fnpb.ProcessBundleDescriptor) (*Plan, error) {
}

u := &DataSource{UID: b.idgen.New()}
u.Coder, err = b.coders.Coder(cid) // Expected to be windowed coder
if err != nil {
return nil, err
}
if !coder.IsW(u.Coder) {
return nil, errors.Errorf("unwindowed coder %v on DataSource %v: %v", cid, id, u.Coder)
}

for key, pid := range transform.GetOutputs() {
u.SID = StreamID{PtransformID: id, Port: port}
Expand All @@ -73,22 +80,6 @@ func UnmarshalPlan(desc *fnpb.ProcessBundleDescriptor) (*Plan, error) {
if err != nil {
return nil, err
}

if cid == "" {
c, wc, err := b.makeCoderForPCollection(pid)
if err != nil {
return nil, err
}
u.Coder = coder.NewW(c, wc)
} else {
u.Coder, err = b.coders.Coder(cid) // Expected to be windowed coder
if err != nil {
return nil, err
}
if !coder.IsW(u.Coder) {
return nil, errors.Errorf("unwindowed coder %v on DataSource %v: %v", cid, id, u.Coder)
}
}
}

b.units = append(b.units, u)
Expand Down Expand Up @@ -500,25 +491,13 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
}

sink := &DataSink{UID: b.idgen.New()}

for _, pid := range transform.GetInputs() {
sink.SID = StreamID{PtransformID: id.to, Port: port}

if cid == "" {
c, wc, err := b.makeCoderForPCollection(pid)
if err != nil {
return nil, err
}
sink.Coder = coder.NewW(c, wc)
} else {
sink.Coder, err = b.coders.Coder(cid) // Expected to be windowed coder
if err != nil {
return nil, err
}
if !coder.IsW(sink.Coder) {
return nil, errors.Errorf("unwindowed coder %v on DataSink %v: %v", cid, id, sink.Coder)
}
}
sink.SID = StreamID{PtransformID: id.to, Port: port}
sink.Coder, err = b.coders.Coder(cid) // Expected to be windowed coder
if err != nil {
return nil, err
}
if !coder.IsW(sink.Coder) {
return nil, errors.Errorf("unwindowed coder %v on DataSink %v: %v", cid, id, sink.Coder)
}
u = sink

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,6 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
BundleSplitListener splitListener)
throws IOException {

RunnerApi.Coder coderSpec;
if (RemoteGrpcPortRead.fromPTransform(pTransform).getPort().getCoderId().isEmpty()) {
LOG.error(
"Missing required coder_id on grpc_port for %s; using deprecated fallback.",
pTransformId);
coderSpec =
coders.get(
pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
} else {
coderSpec = null;
}
FnDataReceiver<WindowedValue<OutputT>> consumer =
(FnDataReceiver<WindowedValue<OutputT>>)
(FnDataReceiver)
Expand All @@ -120,7 +109,6 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
pTransformId,
pTransform,
processBundleInstructionId,
coderSpec,
coders,
beamFnDataClient,
consumer);
Expand Down Expand Up @@ -148,7 +136,6 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
String pTransformId,
RunnerApi.PTransform grpcReadNode,
Supplier<String> processBundleInstructionIdSupplier,
RunnerApi.Coder coderSpec,
Map<String, RunnerApi.Coder> coders,
BeamFnDataClient beamFnDataClient,
FnDataReceiver<WindowedValue<OutputT>> consumer)
Expand All @@ -162,17 +149,9 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(

RehydratedComponents components =
RehydratedComponents.forComponents(Components.newBuilder().putAllCoders(coders).build());
@SuppressWarnings("unchecked")
Coder<WindowedValue<OutputT>> coder;
if (!port.getCoderId().isEmpty()) {
coder =
(Coder<WindowedValue<OutputT>>)
CoderTranslation.fromProto(coders.get(port.getCoderId()), components);
} else {
// TODO: Remove this path once it is no longer used
coder = (Coder<WindowedValue<OutputT>>) CoderTranslation.fromProto(coderSpec, components);
}
this.coder = coder;
this.coder =
(Coder<WindowedValue<OutputT>>)
CoderTranslation.fromProto(coders.get(port.getCoderId()), components);
}

public void registerInputLocation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,10 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
Consumer<ThrowingRunnable> tearDownFunctions,
BundleSplitListener splitListener)
throws IOException {
RunnerApi.Coder coderSpec;
if (RemoteGrpcPortWrite.fromPTransform(pTransform).getPort().getCoderId().isEmpty()) {
LOG.error(
"Missing required coder_id on grpc_port for %s; using deprecated fallback.",
pTransformId);
coderSpec =
coders.get(
pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId());
} else {
coderSpec = null;
}

BeamFnDataWriteRunner<InputT> runner =
new BeamFnDataWriteRunner<>(
pTransformId,
pTransform,
processBundleInstructionId,
coderSpec,
coders,
beamFnDataClient);
pTransformId, pTransform, processBundleInstructionId, coders, beamFnDataClient);
startFunctionRegistry.register(pTransformId, runner::registerForOutput);
pCollectionConsumerRegistry.register(
getOnlyElement(pTransform.getInputsMap().values()),
Expand All @@ -133,7 +118,6 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
String pTransformId,
RunnerApi.PTransform remoteWriteNode,
Supplier<String> processBundleInstructionIdSupplier,
RunnerApi.Coder coderSpec,
Map<String, RunnerApi.Coder> coders,
BeamFnDataClient beamFnDataClientFactory)
throws IOException {
Expand All @@ -145,17 +129,9 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(

RehydratedComponents components =
RehydratedComponents.forComponents(Components.newBuilder().putAllCoders(coders).build());
@SuppressWarnings("unchecked")
Coder<WindowedValue<InputT>> coder;
if (!port.getCoderId().isEmpty()) {
coder =
(Coder<WindowedValue<InputT>>)
CoderTranslation.fromProto(coders.get(port.getCoderId()), components);
} else {
// TODO: remove this path once it is no longer used
coder = (Coder<WindowedValue<InputT>>) CoderTranslation.fromProto(coderSpec, components);
}
this.coder = coder;
this.coder =
(Coder<WindowedValue<InputT>>)
CoderTranslation.fromProto(coders.get(port.getCoderId()), components);
}

public void registerForOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ public void testReuseForMultipleBundles() throws Exception {
INPUT_TRANSFORM_ID,
RemoteGrpcPortRead.readFromPort(PORT_SPEC, "localOutput").toPTransform(),
bundleId::get,
CODER_SPEC,
COMPONENTS.getCodersMap(),
mockBeamFnDataClient,
consumers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ public void testReuseForMultipleBundles() throws Exception {
TRANSFORM_ID,
RemoteGrpcPortWrite.writeToPort("myWrite", PORT_SPEC).toPTransform(),
bundleId::get,
WIRE_CODER_SPEC,
COMPONENTS.getCodersMap(),
mockBeamFnDataClient);

Expand Down
18 changes: 2 additions & 16 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,14 +1153,7 @@ def create_source_runner(
output_consumers[:] = [TimerConsumer(tag, do_op)]
break

if grpc_port.coder_id:
output_coder = factory.get_coder(grpc_port.coder_id)
else:
_LOGGER.info(
'Missing required coder_id on grpc_port for %s; '
'using deprecated fallback.',
transform_id)
output_coder = factory.get_only_output_coder(transform_proto)
output_coder = factory.get_coder(grpc_port.coder_id)
return DataInputOperation(
common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
Expand All @@ -1182,14 +1175,7 @@ def create_sink_runner(
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> DataOutputOperation
if grpc_port.coder_id:
output_coder = factory.get_coder(grpc_port.coder_id)
else:
_LOGGER.info(
'Missing required coder_id on grpc_port for %s; '
'using deprecated fallback.',
transform_id)
output_coder = factory.get_only_input_coder(transform_proto)
output_coder = factory.get_coder(grpc_port.coder_id)
return DataOutputOperation(
common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
Expand Down