Skip to content

Commit

Permalink
Replace ThrowingConsumer with FnDataReceiver
Browse files Browse the repository at this point in the history
  • Loading branch information
tgroh committed Dec 12, 2017
1 parent f58dab3 commit d799fe0
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
Expand Down Expand Up @@ -83,11 +84,11 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
BeamFnDataClient beamFnDataClient,
BeamFnStateClient beamFnStateClient,
String pTransformId,
RunnerApi.PTransform pTransform,
PTransform pTransform,
Supplier<String> processBundleInstructionId,
Map<String, RunnerApi.PCollection> pCollections,
Map<String, PCollection> pCollections,
Map<String, RunnerApi.Coder> coders,
Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction) throws IOException {

Expand All @@ -98,7 +99,7 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
RunnerApi.Coder coderSpec =
coders.get(
pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers =
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
(Collection) pCollectionIdsToConsumers.get(
getOnlyElement(pTransform.getOutputsMap().values()));

Expand Down Expand Up @@ -132,7 +133,7 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
RunnerApi.Coder coderSpec,
Map<String, RunnerApi.Coder> coders,
BeamFnDataClient beamFnDataClientFactory,
Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers)
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers)
throws IOException {
this.apiServiceDescriptor =
BeamFnApi.RemoteGrpcPort.parseFrom(functionSpec.getPayload()).getApiServiceDescriptor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
Expand Down Expand Up @@ -76,11 +78,11 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
BeamFnDataClient beamFnDataClient,
BeamFnStateClient beamFnStateClient,
String pTransformId,
RunnerApi.PTransform pTransform,
PTransform pTransform,
Supplier<String> processBundleInstructionId,
Map<String, RunnerApi.PCollection> pCollections,
Map<String, PCollection> pCollections,
Map<String, RunnerApi.Coder> coders,
Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
Expand All @@ -100,8 +102,8 @@ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
addStartFunction.accept(runner::registerForOutput);
pCollectionIdsToConsumers.put(
getOnlyElement(pTransform.getInputsMap().values()),
(ThrowingConsumer)
(ThrowingConsumer<WindowedValue<InputT>>) runner::consume);
(FnDataReceiver)
(FnDataReceiver<WindowedValue<InputT>>) runner::consume);
addFinishFunction.accept(runner::close);
return runner;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Source.Reader;
import org.apache.beam.sdk.options.PipelineOptions;
Expand Down Expand Up @@ -67,15 +70,15 @@ public BoundedSourceRunner<InputT, OutputT> createRunnerForPTransform(
BeamFnDataClient beamFnDataClient,
BeamFnStateClient beamFnStateClient,
String pTransformId,
RunnerApi.PTransform pTransform,
PTransform pTransform,
Supplier<String> processBundleInstructionId,
Map<String, RunnerApi.PCollection> pCollections,
Map<String, RunnerApi.Coder> coders,
Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
Map<String, PCollection> pCollections,
Map<String, Coder> coders,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction) {

ImmutableList.Builder<ThrowingConsumer<WindowedValue<?>>> consumers = ImmutableList.builder();
ImmutableList.Builder<FnDataReceiver<WindowedValue<?>>> consumers = ImmutableList.builder();
for (String pCollectionId : pTransform.getOutputsMap().values()) {
consumers.addAll(pCollectionIdsToConsumers.get(pCollectionId));
}
Expand All @@ -89,8 +92,8 @@ public BoundedSourceRunner<InputT, OutputT> createRunnerForPTransform(
// TODO: Remove and replace with source being sent across gRPC port
addStartFunction.accept(runner::start);

ThrowingConsumer runReadLoop =
(ThrowingConsumer<WindowedValue<InputT>>) runner::runReadLoop;
FnDataReceiver runReadLoop =
(FnDataReceiver<WindowedValue<InputT>>) runner::runReadLoop;
for (String pCollectionId : pTransform.getInputsMap().values()) {
pCollectionIdsToConsumers.put(
pCollectionId,
Expand All @@ -103,12 +106,12 @@ public BoundedSourceRunner<InputT, OutputT> createRunnerForPTransform(

private final PipelineOptions pipelineOptions;
private final RunnerApi.FunctionSpec definition;
private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers;

BoundedSourceRunner(
PipelineOptions pipelineOptions,
RunnerApi.FunctionSpec definition,
Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers) {
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers) {
this.pipelineOptions = pipelineOptions;
this.definition = definition;
this.consumers = consumers;
Expand Down Expand Up @@ -151,7 +154,7 @@ public void runReadLoop(WindowedValue<InputT> value) throws Exception {
// TODO: Should this use the input window as the window for all the outputs?
WindowedValue<OutputT> nextValue = WindowedValue.timestampedValueInGlobalWindow(
reader.getCurrent(), reader.getCurrentTimestamp());
for (ThrowingConsumer<WindowedValue<OutputT>> consumer : consumers) {
for (FnDataReceiver<WindowedValue<OutputT>> consumer : consumers) {
consumer.accept(nextValue);
}
} while (reader.advance());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.state.BagUserState;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
Expand Down Expand Up @@ -123,23 +125,23 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
BeamFnDataClient beamFnDataClient,
BeamFnStateClient beamFnStateClient,
String pTransformId,
RunnerApi.PTransform pTransform,
PTransform pTransform,
Supplier<String> processBundleInstructionId,
Map<String, RunnerApi.PCollection> pCollections,
Map<String, PCollection> pCollections,
Map<String, RunnerApi.Coder> coders,
Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction) {

// For every output PCollection, create a map from output name to Consumer
ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>>
ImmutableMap.Builder<String, Collection<FnDataReceiver<WindowedValue<?>>>>
outputMapBuilder = ImmutableMap.builder();
for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
outputMapBuilder.put(
entry.getKey(),
pCollectionIdsToConsumers.get(entry.getValue()));
}
ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap =
ImmutableMap<String, Collection<FnDataReceiver<WindowedValue<?>>>> outputMap =
outputMapBuilder.build();

// Get the DoFnInfo from the serialized blob.
Expand All @@ -158,16 +160,16 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
doFnInfo.getOutputMap());

ImmutableMultimap.Builder<TupleTag<?>,
ThrowingConsumer<WindowedValue<?>>> tagToOutputMapBuilder =
FnDataReceiver<WindowedValue<?>>> tagToOutputMapBuilder =
ImmutableMultimap.builder();
for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) {
@SuppressWarnings({"unchecked", "rawtypes"})
Collection<ThrowingConsumer<WindowedValue<?>>> consumers =
Collection<FnDataReceiver<WindowedValue<?>>> consumers =
outputMap.get(Long.toString(entry.getKey()));
tagToOutputMapBuilder.putAll(entry.getValue(), consumers);
}

ImmutableMultimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> tagToOutputMap =
ImmutableMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap =
tagToOutputMapBuilder.build();

@SuppressWarnings({"unchecked", "rawtypes"})
Expand All @@ -180,7 +182,7 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
WindowedValue.getFullCoder(
doFnInfo.getInputCoder(),
doFnInfo.getWindowingStrategy().getWindowFn().windowCoder()),
(Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection)
(Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())),
tagToOutputMap,
doFnInfo.getWindowingStrategy());
Expand All @@ -190,7 +192,7 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
for (String pcollectionId : pTransform.getInputsMap().values()) {
pCollectionIdsToConsumers.put(
pcollectionId,
(ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement);
(FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
}
addFinishFunction.accept(runner::finishBundle);
return runner;
Expand All @@ -205,8 +207,8 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
private final Supplier<String> processBundleInstructionId;
private final DoFn<InputT, OutputT> doFn;
private final WindowedValueCoder<InputT> inputCoder;
private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers;
private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
private final WindowingStrategy windowingStrategy;
private final DoFnSignature doFnSignature;
private final DoFnInvoker<InputT, OutputT> doFnInvoker;
Expand Down Expand Up @@ -243,8 +245,8 @@ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
Supplier<String> processBundleInstructionId,
DoFn<InputT, OutputT> doFn,
WindowedValueCoder<InputT> inputCoder,
Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers,
Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap,
Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
WindowingStrategy windowingStrategy) {
this.pipelineOptions = pipelineOptions;
this.beamFnStateClient = beamFnStateClient;
Expand Down Expand Up @@ -316,11 +318,11 @@ public void finishBundle() {
* Outputs the given element to the specified set of consumers wrapping any exceptions.
*/
private <T> void outputTo(
Collection<ThrowingConsumer<WindowedValue<T>>> consumers,
Collection<FnDataReceiver<WindowedValue<T>>> consumers,
WindowedValue<T> output) {
Iterator<ThrowingConsumer<WindowedValue<T>>> consumerIterator;
Iterator<FnDataReceiver<WindowedValue<T>>> consumerIterator;
try {
for (ThrowingConsumer<WindowedValue<T>> consumer : consumers) {
for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
consumer.accept(output);
}
} catch (Throwable t) {
Expand Down Expand Up @@ -492,7 +494,7 @@ public void outputWithTimestamp(OutputT output, Instant timestamp) {

@Override
public <T> void output(TupleTag<T> tag, T output) {
Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
Expand All @@ -506,7 +508,7 @@ public <T> void output(TupleTag<T> tag, T output) {

@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
Expand Down Expand Up @@ -622,7 +624,7 @@ public void output(OutputT output, Instant timestamp, BoundedWindow window) {

@Override
public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
Expand Down

0 comments on commit d799fe0

Please sign in to comment.