Skip to content

Commit

Permalink
[BEAM-3450] Log when the runner is not properly setting the coder.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb committed Oct 18, 2018
1 parent 22f59de commit 8d7ad50
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,17 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
.setPrimitiveTransformReference(pTransformId)
.setName(getOnlyElement(pTransform.getOutputsMap().keySet()))
.build();
RunnerApi.Coder coderSpec =
coders.get(
pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
RunnerApi.Coder coderSpec;
// Select the coder spec for this read transform. The grpc_port is expected to carry a
// coder_id; when it is empty, fall back to the coder of the transform's only output
// PCollection and log the omission.
if (RemoteGrpcPortRead.fromPTransform(pTransform).getPort().getCoderId().isEmpty()) {
  // SLF4J only substitutes arguments at "{}" anchors; a printf-style "%s" would be
  // emitted literally and the transform id would be dropped from the message.
  LOG.error(
      "Missing required coder_id on grpc_port for {}; using deprecated fallback.",
      pTransformId);
  coderSpec =
      coders.get(
          pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
} else {
  // NOTE(review): presumably the runner resolves the coder from the port's coder_id when
  // coderSpec is null -- confirm against the BeamFnDataReadRunner constructor.
  coderSpec = null;
}
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
(Collection)
pCollectionIdsToConsumers.get(getOnlyElement(pTransform.getOutputsMap().values()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Registers as a consumer with the Beam Fn Data Api. Consumes elements and encodes them for
Expand All @@ -57,6 +59,8 @@
*/
public class BeamFnDataWriteRunner<InputT> {

private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataWriteRunner.class);

/** A registrar which provides a factory to handle writing to the Fn Api Data Plane. */
@AutoService(PTransformRunnerFactory.Registrar.class)
public static class Registrar implements PTransformRunnerFactory.Registrar {
Expand Down Expand Up @@ -91,9 +95,17 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
.setPrimitiveTransformReference(pTransformId)
.setName(getOnlyElement(pTransform.getInputsMap().keySet()))
.build();
RunnerApi.Coder coderSpec =
coders.get(
pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId());
RunnerApi.Coder coderSpec;
// Select the coder spec for this write transform. The grpc_port is expected to carry a
// coder_id; when it is empty, fall back to the coder of the transform's only input
// PCollection and log the omission.
if (RemoteGrpcPortWrite.fromPTransform(pTransform).getPort().getCoderId().isEmpty()) {
  // SLF4J only substitutes arguments at "{}" anchors; a printf-style "%s" would be
  // emitted literally and the transform id would be dropped from the message.
  LOG.error(
      "Missing required coder_id on grpc_port for {}; using deprecated fallback.",
      pTransformId);
  coderSpec =
      coders.get(
          pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId());
} else {
  // NOTE(review): presumably the runner resolves the coder from the port's coder_id when
  // coderSpec is null -- confirm against the BeamFnDataWriteRunner constructor.
  coderSpec = null;
}
BeamFnDataWriteRunner<InputT> runner =
new BeamFnDataWriteRunner<>(
pTransform, processBundleInstructionId, target, coderSpec, coders, beamFnDataClient);
Expand Down
24 changes: 18 additions & 6 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,15 +582,21 @@ def create(factory, transform_id, transform_proto, grpc_port, consumers):
target = beam_fn_api_pb2.Target(
primitive_transform_reference=transform_id,
name=only_element(list(transform_proto.outputs.keys())))
if grpc_port.coder_id:
output_coder = factory.get_coder(grpc_port.coder_id)
else:
logging.error(
'Missing required coder_id on grpc_port for %s; '
'using deprecated fallback.',
transform_id)
output_coder = factory.get_only_output_coder(transform_proto)
return DataInputOperation(
transform_proto.unique_name,
transform_proto.unique_name,
consumers,
factory.counter_factory,
factory.state_sampler,
factory.get_coder(grpc_port.coder_id)
if grpc_port.coder_id
else factory.get_only_output_coder(transform_proto),
output_coder,
input_target=target,
data_channel=factory.data_channel_factory.create_data_channel(grpc_port))

Expand All @@ -601,15 +607,21 @@ def create(factory, transform_id, transform_proto, grpc_port, consumers):
target = beam_fn_api_pb2.Target(
primitive_transform_reference=transform_id,
name=only_element(list(transform_proto.inputs.keys())))
if grpc_port.coder_id:
output_coder = factory.get_coder(grpc_port.coder_id)
else:
logging.error(
'Missing required coder_id on grpc_port for %s; '
'using deprecated fallback.',
transform_id)
output_coder = factory.get_only_input_coder(transform_proto)
return DataOutputOperation(
transform_proto.unique_name,
transform_proto.unique_name,
consumers,
factory.counter_factory,
factory.state_sampler,
factory.get_coder(grpc_port.coder_id)
if grpc_port.coder_id
else factory.get_only_input_coder(transform_proto),
output_coder,
target=target,
data_channel=factory.data_channel_factory.create_data_channel(grpc_port))

Expand Down

0 comments on commit 8d7ad50

Please sign in to comment.