Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion runners/samza/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ task validatesRunner(type: Test) {
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher'
excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,10 @@ public interface SamzaPipelineOptions extends PipelineOptions {
Boolean getStateDurable();

void setStateDurable(Boolean stateDurable);

@Description("The maximum number of event-time timers buffered in memory for a transform.")
@Default.Integer(50000)
int getTimerBufferSize();

void setTimerBufferSize(int timerBufferSize);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricResults;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.samza.application.StreamApplication;
import org.apache.samza.job.ApplicationStatus;
import org.apache.samza.runtime.ApplicationRunner;
Expand Down Expand Up @@ -98,7 +99,8 @@ private StateInfo getStateInfo() {
case UnsuccessfulFinish:
LOG.error(status.getThrowable().getMessage(), status.getThrowable());
return new StateInfo(
State.FAILED, new Pipeline.PipelineExecutionException(status.getThrowable()));
State.FAILED,
new Pipeline.PipelineExecutionException(getUserCodeException(status.getThrowable())));
default:
return new StateInfo(State.UNKNOWN);
}
Expand All @@ -117,4 +119,21 @@ private StateInfo(State state, Pipeline.PipelineExecutionException error) {
this.error = error;
}
}

/**
* Some of the Beam unit tests relying on the exception message to do assertion. This function
* will find the original UserCodeException so the message will be exposed directly.
*/
private static Throwable getUserCodeException(Throwable throwable) {
Throwable t = throwable;
while (t != null) {
if (t instanceof UserCodeException) {
return t;
}

t = t.getCause();
}

return throwable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.beam.runners.samza.SamzaExecutionContext;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
Expand All @@ -56,7 +55,6 @@
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators;
import org.apache.samza.config.Config;
import org.apache.samza.operators.TimerRegistry;
import org.apache.samza.storage.kv.KeyValueStore;
import org.apache.samza.task.TaskContext;
import org.joda.time.Instant;
import org.slf4j.Logger;
Expand Down Expand Up @@ -141,12 +139,11 @@ public DoFnOp(
public void open(
Config config,
TaskContext context,
TimerRegistry<TimerKey<Void>> timerRegistry,
TimerRegistry<KeyedTimerData<Void>> timerRegistry,
OpEmitter<OutT> emitter) {
this.inputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
this.sideInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
this.pushbackWatermarkHold = BoundedWindow.TIMESTAMP_MAX_VALUE;
this.timerInternalsFactory = new SamzaTimerInternalsFactory(keyCoder, timerRegistry);

final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
final SamzaPipelineOptions pipelineOptions =
Expand All @@ -156,7 +153,18 @@ public void open(
.as(SamzaPipelineOptions.class);

final SamzaStoreStateInternals.Factory<?> nonKeyedStateInternalsFactory =
createStateInternalFactory(null, context, pipelineOptions, signature, mainOutputTag);
SamzaStoreStateInternals.createStateInternalFactory(
null, context, pipelineOptions, signature, mainOutputTag);

this.timerInternalsFactory =
SamzaTimerInternalsFactory.createTimerInternalFactory(
keyCoder,
(TimerRegistry) timerRegistry,
getTimerStateId(signature),
nonKeyedStateInternalsFactory,
windowingStrategy,
pipelineOptions);

this.sideInputHandler =
new SideInputHandler(sideInputs, nonKeyedStateInternalsFactory.stateInternalsForKey(null));

Expand Down Expand Up @@ -208,34 +216,12 @@ public void open(
doFnInvoker.invokeSetup();
}

static SamzaStoreStateInternals.Factory createStateInternalFactory(
Coder<?> keyCoder,
TaskContext context,
SamzaPipelineOptions pipelineOptions,
DoFnSignature signature,
TupleTag<?> mainOutputTag) {
final int batchGetSize = pipelineOptions.getStoreBatchGetSize();
final Map<String, KeyValueStore<byte[], byte[]>> stores =
new HashMap<>(SamzaStoreStateInternals.getBeamStore(context));

final Coder stateKeyCoder;
if (keyCoder != null) {
signature
.stateDeclarations()
.keySet()
.forEach(
stateId ->
stores.put(stateId, (KeyValueStore<byte[], byte[]>) context.getStore(stateId)));
stateKeyCoder = keyCoder;
} else {
stateKeyCoder = VoidCoder.of();
private String getTimerStateId(DoFnSignature signature) {
final StringBuilder builder = new StringBuilder("timer");
if (signature.usesTimers()) {
signature.timerDeclarations().keySet().forEach(key -> builder.append(key));
}
return new SamzaStoreStateInternals.Factory<>(
// TODO: ??? what to do with empty output?
mainOutputTag == null ? "null" : mainOutputTag.getId(),
stores,
stateKeyCoder,
batchGetSize);
return builder.toString();
}

@Override
Expand Down Expand Up @@ -322,6 +308,8 @@ public void processTimer(KeyedTimerData<Void> keyedTimerData) {
pushbackFnRunner.startBundle();
fireTimer(keyedTimerData);
pushbackFnRunner.finishBundle();

this.timerInternalsFactory.removeProcessingTimer((KeyedTimerData) keyedTimerData);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
public class GroupByKeyOp<K, InputT, OutputT>
implements Op<KeyedWorkItem<K, InputT>, KV<K, OutputT>, K> {
private static final Logger LOG = LoggerFactory.getLogger(GroupByKeyOp.class);
private static final String TIMER_STATE_ID = "timer";

private final TupleTag<KV<K, OutputT>> mainOutputTag;
private final KeyedWorkItemCoder<K, InputT> inputCoder;
Expand Down Expand Up @@ -97,25 +98,37 @@ public GroupByKeyOp(
public void open(
Config config,
TaskContext context,
TimerRegistry<TimerKey<K>> timerRegistry,
TimerRegistry<KeyedTimerData<K>> timerRegistry,
OpEmitter<KV<K, OutputT>> emitter) {
this.pipelineOptions =
Base64Serializer.deserializeUnchecked(
config.get("beamPipelineOptions"), SerializablePipelineOptions.class)
.get()
.as(SamzaPipelineOptions.class);

final SamzaStoreStateInternals.Factory<?> nonKeyedStateInternalsFactory =
SamzaStoreStateInternals.createStateInternalFactory(
null, context, pipelineOptions, null, mainOutputTag);

final DoFnRunners.OutputManager outputManager = outputManagerFactory.create(emitter);

this.stateInternalsFactory =
new SamzaStoreStateInternals.Factory<>(
mainOutputTag.getId(),
SamzaStoreStateInternals.getBeamStore(context),
Collections.singletonMap(
SamzaStoreStateInternals.BEAM_STORE,
SamzaStoreStateInternals.getBeamStore(context)),
keyCoder,
pipelineOptions.getStoreBatchGetSize());

this.timerInternalsFactory =
new SamzaTimerInternalsFactory<>(inputCoder.getKeyCoder(), timerRegistry);
SamzaTimerInternalsFactory.createTimerInternalFactory(
keyCoder,
timerRegistry,
TIMER_STATE_ID,
nonKeyedStateInternalsFactory,
windowingStrategy,
pipelineOptions);

final DoFn<KeyedWorkItem<K, InputT>, KV<K, OutputT>> doFn =
GroupAlsoByWindowViaWindowSetNewDoFn.create(
Expand Down Expand Up @@ -192,6 +205,8 @@ public void processTimer(KeyedTimerData<K> keyedTimerData) {
fnRunner.startBundle();
fireTimer(keyedTimerData.getKey(), keyedTimerData.getTimerData());
fnRunner.finishBundle();

timerInternalsFactory.removeProcessingTimer(keyedTimerData);
}

private void fireTimer(K key, TimerData timer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,24 @@
*/
package org.apache.beam.runners.samza.runtime;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.joda.time.Instant;

/**
* {@link TimerInternals.TimerData} with key, used by {@link SamzaTimerInternalsFactory}. Implements
Expand All @@ -29,7 +45,7 @@ public class KeyedTimerData<K> implements Comparable<KeyedTimerData<K>> {
private final K key;
private final TimerInternals.TimerData timerData;

public KeyedTimerData(byte[] keyBytes, K key, TimerInternals.TimerData timerData) {
public KeyedTimerData(byte[] keyBytes, K key, TimerData timerData) {
this.keyBytes = keyBytes;
this.key = key;
this.timerData = timerData;
Expand Down Expand Up @@ -102,4 +118,73 @@ public int hashCode() {
result = 31 * result + timerData.hashCode();
return result;
}

/**
* Coder for {@link KeyedTimerData}. Note we don't use the {@link
* org.apache.beam.runners.core.TimerInternals.TimerDataCoder} here directly since we want to
* en/decode timestamp first so the timers will be sorted in the state.
*/
public static class KeyedTimerDataCoder<K> extends StructuredCoder<KeyedTimerData<K>> {
private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of();
private static final InstantCoder INSTANT_CODER = InstantCoder.of();

private final Coder<K> keyCoder;
private final Coder<? extends BoundedWindow> windowCoder;

KeyedTimerDataCoder(Coder<K> keyCoder, Coder<? extends BoundedWindow> windowCoder) {
this.keyCoder = keyCoder;
this.windowCoder = windowCoder;
}

@Override
public void encode(KeyedTimerData<K> value, OutputStream outStream)
throws CoderException, IOException {

final TimerData timer = value.getTimerData();
// encode the timestamp first
INSTANT_CODER.encode(timer.getTimestamp(), outStream);
STRING_CODER.encode(timer.getTimerId(), outStream);
STRING_CODER.encode(timer.getNamespace().stringKey(), outStream);
STRING_CODER.encode(timer.getDomain().name(), outStream);

if (keyCoder != null) {
keyCoder.encode(value.key, outStream);
}
}

@Override
public KeyedTimerData<K> decode(InputStream inStream) throws CoderException, IOException {
// decode the timestamp first
final Instant timestamp = INSTANT_CODER.decode(inStream);
final String timerId = STRING_CODER.decode(inStream);
final StateNamespace namespace =
StateNamespaces.fromString(STRING_CODER.decode(inStream), windowCoder);
final TimeDomain domain = TimeDomain.valueOf(STRING_CODER.decode(inStream));
final TimerData timer = TimerData.of(timerId, namespace, timestamp, domain);

byte[] keyBytes = null;
K key = null;
if (keyCoder != null) {
key = keyCoder.decode(inStream);

final ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
keyCoder.encode(key, baos);
} catch (IOException e) {
throw new RuntimeException("Could not encode key: " + key, e);
}
keyBytes = baos.toByteArray();
}

return new KeyedTimerData(keyBytes, key, timer);
}

@Override
public List<? extends Coder<?>> getCoderArguments() {
return Arrays.asList(keyCoder, windowCoder);
}

@Override
public void verifyDeterministic() throws NonDeterministicException {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public interface Op<InT, OutT, K> extends Serializable {
default void open(
Config config,
TaskContext taskContext,
TimerRegistry<TimerKey<K>> timerRegistry,
TimerRegistry<KeyedTimerData<K>> timerRegistry,
OpEmitter<OutT> emitter) {}

void processElement(WindowedValue<InT> inputElement, OpEmitter<OutT> emitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.samza.config.Config;
Expand All @@ -39,7 +37,7 @@
public class OpAdapter<InT, OutT, K>
implements FlatMapFunction<OpMessage<InT>, OpMessage<OutT>>,
WatermarkFunction<OpMessage<OutT>>,
TimerFunction<TimerKey<K>, OpMessage<OutT>>,
TimerFunction<KeyedTimerData<K>, OpMessage<OutT>>,
Serializable {
private static final Logger LOG = LoggerFactory.getLogger(OpAdapter.class);

Expand Down Expand Up @@ -68,7 +66,7 @@ public final void init(Config config, TaskContext context) {
}

@Override
public final void registerTimer(TimerRegistry<TimerKey<K>> timerRegistry) {
public final void registerTimer(TimerRegistry<KeyedTimerData<K>> timerRegistry) {
assert taskContext != null;

op.open(config, taskContext, timerRegistry, emitter);
Expand Down Expand Up @@ -126,19 +124,10 @@ public Long getOutputWatermark() {
}

@Override
public Collection<OpMessage<OutT>> onTimer(TimerKey<K> timerKey, long time) {
public Collection<OpMessage<OutT>> onTimer(KeyedTimerData<K> keyedTimerData, long time) {
assert outputList.isEmpty();

try {
final TimerData timerData =
TimerData.of(
timerKey.getTimerId(),
timerKey.getStateNamespace(),
new Instant(time),
TimeDomain.PROCESSING_TIME);
final KeyedTimerData<K> keyedTimerData =
new KeyedTimerData<>(timerKey.getKeyBytes(), timerKey.getKey(), timerData);

op.processTimer(keyedTimerData);
} catch (Exception e) {
LOG.error("Op {} threw an exception during processing timer", this.getClass().getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public static <InT, FnOutT> DoFnRunner<InT, FnOutT> create(
final StateInternals stateInternals;
final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
final SamzaStoreStateInternals.Factory<?> stateInternalsFactory =
DoFnOp.createStateInternalFactory(
SamzaStoreStateInternals.createStateInternalFactory(
keyCoder, taskContext, pipelineOptions, signature, mainOutputTag);

final SamzaExecutionContext executionContext =
Expand Down
Loading