Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* RunnerApi.CombinePayload} protos.
*/
public class CombineTranslation {
private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";
public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";

public static CombinePayload toProto(
AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents)
Expand Down Expand Up @@ -86,7 +86,7 @@ private static <K, InputT, AccumT> Coder<AccumT> extractAccumulatorCoder(
.getAccumulatorCoder();
}

private static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) {
public static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) {
return SdkFunctionSpec.newBuilder()
// TODO: Set Java SDK Environment URN
.setSpec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ public static String urnForTransform(PTransform<?, ?> transform) {
*/
public interface TransformPayloadTranslator<T extends PTransform<?, ?>> {
String getUrn(T transform);
FunctionSpec translate(AppliedPTransform<?, ?, T> application, SdkComponents components);
FunctionSpec translate(AppliedPTransform<?, ?, T> application, SdkComponents components)
throws IOException;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
Expand All @@ -46,9 +48,12 @@
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.PTransform;
Expand Down Expand Up @@ -107,7 +112,8 @@ public String getUrn(ParDo.MultiOutput<?, ?> transform) {

@Override
public FunctionSpec translate(
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components) {
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components)
throws IOException {
ParDoPayload payload = toProto(transform.getTransform(), components);
return RunnerApi.FunctionSpec.newBuilder()
.setUrn(PAR_DO_TRANSFORM_URN)
Expand All @@ -128,8 +134,10 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar {
}
}

public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) {
DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass());
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components)
throws IOException {
DoFn<?, ?> doFn = parDo.getFn();
DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
Map<String, StateDeclaration> states = signature.stateDeclarations();
Map<String, TimerDeclaration> timers = signature.timerDeclarations();
List<Parameter> parameters = signature.processElement().extraParameters();
Expand All @@ -146,16 +154,62 @@ public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents
}
}
for (Map.Entry<String, StateDeclaration> state : states.entrySet()) {
StateSpec spec = toProto(state.getValue());
RunnerApi.StateSpec spec =
toProto(getStateSpecOrCrash(state.getValue(), doFn), components);
builder.putStateSpecs(state.getKey(), spec);
}
for (Map.Entry<String, TimerDeclaration> timer : timers.entrySet()) {
TimerSpec spec = toProto(timer.getValue());
RunnerApi.TimerSpec spec =
toProto(getTimerSpecOrCrash(timer.getValue(), doFn));
builder.putTimerSpecs(timer.getKey(), spec);
}
return builder.build();
}

/**
 * Reads the {@link StateSpec} stored in the field behind {@code stateDeclaration} on the given
 * {@link DoFn} instance.
 *
 * @param stateDeclaration the declaration whose {@code field()} is read reflectively
 * @param target the {@link DoFn} instance the field is read from
 * @return the field's value, cast to {@link StateSpec}
 * @throws IllegalStateException if the field's value is not a {@link StateSpec} (raised by
 *     {@code checkState})
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
private static StateSpec<?> getStateSpecOrCrash(
    StateDeclaration stateDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = stateDeclaration.field().get(target);
    checkState(
        fieldValue instanceof StateSpec,
        "Malformed %s class %s: state declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        stateDeclaration.field().getName(),
        StateSpec.class);

    // Reuse the value that was just validated instead of reading the field a second time.
    return (StateSpec<?>) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the reflective failure as the cause so the stack trace shows why access was denied.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: state declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            stateDeclaration.field().getName()),
        exc);
  }
}

/**
 * Reads the {@link TimerSpec} stored in the field behind {@code timerDeclaration} on the given
 * {@link DoFn} instance.
 *
 * @param timerDeclaration the declaration whose {@code field()} is read reflectively
 * @param target the {@link DoFn} instance the field is read from
 * @return the field's value, cast to {@link TimerSpec}
 * @throws IllegalStateException if the field's value is not a {@link TimerSpec} (raised by
 *     {@code checkState})
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
private static TimerSpec getTimerSpecOrCrash(
    TimerDeclaration timerDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = timerDeclaration.field().get(target);
    checkState(
        fieldValue instanceof TimerSpec,
        "Malformed %s class %s: timer declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        timerDeclaration.field().getName(),
        TimerSpec.class);

    // Reuse the value that was just validated instead of reading the field a second time.
    return (TimerSpec) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the reflective failure as the cause so the stack trace shows why access was denied.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: timer declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            timerDeclaration.field().getName()),
        exc);
  }
}

/**
 * Extracts the {@link DoFn} carried by the given {@link ParDoPayload}.
 *
 * <p>Passes the payload's {@code getDoFn()} value to {@code doFnAndMainOutputTagFromProto} and
 * returns the {@code getDoFn()} component of that result.
 *
 * @param payload the payload to read the {@link DoFn} from
 * @return the {@link DoFn} recovered from {@code payload}
 * @throws InvalidProtocolBufferException declared by the signature; presumably raised by the
 *     helper when the payload cannot be unpacked (confirm against the helper)
 */
public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
}
Expand All @@ -179,14 +233,149 @@ public static RunnerApi.PCollection getMainInput(
return components.getPcollectionsOrThrow(ptransform.getInputsOrThrow(mainInputId));
}

// TODO: Implement
private static StateSpec toProto(StateDeclaration state) {
throw new UnsupportedOperationException("Not yet supported");
/**
 * Translates an SDK {@link StateSpec} into a {@link RunnerApi.StateSpec} proto.
 *
 * <p>The spec is dispatched through {@link StateSpec.Cases}; each case registers the coders it
 * is handed with {@code components} and records the returned ids in the matching proto sub-spec.
 *
 * @param stateSpec the SDK state spec to translate
 * @param components the components the coders are registered with
 * @return the proto built by whichever case {@code stateSpec} dispatches to
 * @throws IOException declared by the signature
 */
@VisibleForTesting
static RunnerApi.StateSpec toProto(StateSpec<?> stateSpec, final SdkComponents components)
    throws IOException {
  final RunnerApi.StateSpec.Builder protoBuilder = RunnerApi.StateSpec.newBuilder();

  StateSpec.Cases<RunnerApi.StateSpec> translator =
      new StateSpec.Cases<RunnerApi.StateSpec>() {
        @Override
        public RunnerApi.StateSpec dispatchValue(Coder<?> valueCoder) {
          RunnerApi.ValueStateSpec.Builder valueSpec = RunnerApi.ValueStateSpec.newBuilder();
          valueSpec.setCoderId(registerCoderOrThrow(components, valueCoder));
          return protoBuilder.setValueSpec(valueSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchBag(Coder<?> elementCoder) {
          RunnerApi.BagStateSpec.Builder bagSpec = RunnerApi.BagStateSpec.newBuilder();
          bagSpec.setElementCoderId(registerCoderOrThrow(components, elementCoder));
          return protoBuilder.setBagSpec(bagSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchCombining(
            Combine.CombineFn<?, ?, ?> combineFn, Coder<?> accumCoder) {
          RunnerApi.CombiningStateSpec.Builder combiningSpec =
              RunnerApi.CombiningStateSpec.newBuilder();
          // Register the accumulator coder before translating the CombineFn, as before.
          combiningSpec.setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder));
          combiningSpec.setCombineFn(CombineTranslation.toProto(combineFn));
          return protoBuilder.setCombiningSpec(combiningSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchMap(Coder<?> keyCoder, Coder<?> valueCoder) {
          RunnerApi.MapStateSpec.Builder mapSpec = RunnerApi.MapStateSpec.newBuilder();
          // Key coder is registered first, then the value coder.
          mapSpec.setKeyCoderId(registerCoderOrThrow(components, keyCoder));
          mapSpec.setValueCoderId(registerCoderOrThrow(components, valueCoder));
          return protoBuilder.setMapSpec(mapSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
          RunnerApi.SetStateSpec.Builder setSpec = RunnerApi.SetStateSpec.newBuilder();
          setSpec.setElementCoderId(registerCoderOrThrow(components, elementCoder));
          return protoBuilder.setSetSpec(setSpec).build();
        }
      };

  return stateSpec.match(translator);
}

/**
 * Rebuilds an SDK {@link StateSpec} from its {@link RunnerApi.StateSpec} proto form.
 *
 * <p>Every coder id referenced by the proto is resolved with {@code
 * components.getCodersOrThrow(...)}, so all spec cases treat a dangling coder id the same way
 * instead of some of them handing {@code null} to {@code CoderTranslation.fromProto}.
 *
 * @param stateSpec the proto to translate
 * @param components the components holding the coders referenced by {@code stateSpec}
 * @return the SDK state spec matching the proto's spec case
 * @throws IOException declared by the signature for the translation calls made here
 * @throws UnsupportedOperationException if a combining spec's CombineFn URN is not {@code
 *     CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN}
 * @throws IllegalArgumentException if no spec case is set or the case is not recognized
 */
@VisibleForTesting
static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RunnerApi.Components components)
    throws IOException {
  switch (stateSpec.getSpecCase()) {
    case VALUE_SPEC:
      return StateSpecs.value(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getValueSpec().getCoderId()), components));

    case BAG_SPEC:
      return StateSpecs.bag(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getBagSpec().getElementCoderId()),
              components));

    case COMBINING_SPEC:
      FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec();

      // Only Java-serialized CombineFns can be deserialized below.
      if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) {
        throw new UnsupportedOperationException(
            String.format(
                "Cannot create %s from non-Java %s: %s",
                StateSpec.class.getSimpleName(),
                Combine.CombineFn.class.getSimpleName(),
                combineFnSpec.getUrn()));
      }

      Combine.CombineFn<?, ?, ?> combineFn =
          (Combine.CombineFn<?, ?, ?>)
              SerializableUtils.deserializeFromByteArray(
                  combineFnSpec.getParameter().unpack(BytesValue.class).toByteArray(),
                  Combine.CombineFn.class.getSimpleName());

      // Rawtype coder cast because it is required to be a valid accumulator coder
      // for the CombineFn, by construction
      return StateSpecs.combining(
          (Coder)
              CoderTranslation.fromProto(
                  components.getCodersOrThrow(
                      stateSpec.getCombiningSpec().getAccumulatorCoderId()),
                  components),
          combineFn);

    case MAP_SPEC:
      return StateSpecs.map(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getMapSpec().getKeyCoderId()), components),
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getMapSpec().getValueCoderId()), components));

    case SET_SPEC:
      return StateSpecs.set(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getSetSpec().getElementCoderId()),
              components));

    case SPEC_NOT_SET:
    default:
      throw new IllegalArgumentException(
          String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
  }
}

/**
 * Registers {@code coder} with {@code components} and returns the id it was registered under.
 *
 * <p>Lets callers that cannot declare checked exceptions (such as the anonymous {@code
 * StateSpec.Cases} methods) register coders: an {@link IOException} from registration is rethrown
 * as a {@link RuntimeException} with the original exception as its cause.
 */
private static String registerCoderOrThrow(SdkComponents components, Coder coder) {
  final String coderId;
  try {
    coderId = components.registerCoder(coder);
  } catch (IOException registrationFailure) {
    throw new RuntimeException("Failure to register coder", registrationFailure);
  }
  return coderId;
}

// TODO: Implement
private static TimerSpec toProto(TimerDeclaration timer) {
throw new UnsupportedOperationException("Not yet supported");
/**
 * Translates an SDK {@link TimerSpec} into a {@link RunnerApi.TimerSpec} proto whose time domain
 * is the proto counterpart of {@code timer.getTimeDomain()}.
 */
private static RunnerApi.TimerSpec toProto(TimerSpec timer) {
  RunnerApi.TimerSpec.Builder specBuilder = RunnerApi.TimerSpec.newBuilder();
  RunnerApi.TimeDomain protoDomain = toProto(timer.getTimeDomain());
  specBuilder.setTimeDomain(protoDomain);
  return specBuilder.build();
}

/**
 * Maps an SDK {@link TimeDomain} onto the {@link RunnerApi.TimeDomain} constant of the same name.
 *
 * @param timeDomain the SDK time domain to translate
 * @return the matching proto time domain
 * @throws IllegalArgumentException if {@code timeDomain} is not one of the handled constants; the
 *     message names the offending value
 */
private static RunnerApi.TimeDomain toProto(TimeDomain timeDomain) {
  switch (timeDomain) {
    case EVENT_TIME:
      return RunnerApi.TimeDomain.EVENT_TIME;
    case PROCESSING_TIME:
      return RunnerApi.TimeDomain.PROCESSING_TIME;
    case SYNCHRONIZED_PROCESSING_TIME:
      return RunnerApi.TimeDomain.SYNCHRONIZED_PROCESSING_TIME;
    default:
      // Report which value was unrecognized so the failure can be diagnosed from the message.
      throw new IllegalArgumentException("Unknown time domain: " + timeDomain);
  }
}

@AutoValue
Expand Down
Loading