From cf333b0b0c318569a1704ca71121c37dcd12bd3d Mon Sep 17 00:00:00 2001
From: kl0u
Date: Wed, 29 Mar 2017 18:21:02 +0200
Subject: [PATCH] [FLINK-6215] Make the StatefulSequenceSource scalable.

So far this source computed all the elements to be emitted up front and
stored them in memory. This could lead to out-of-memory problems for
large deployments. Now we split the range of elements into partitions
that can be re-shuffled upon rescaling, and upon checkpointing we only
store the next offset and the end of each partition.
---
 .../api/checkpoint/ListCheckpointed.java        |   2 +-
 .../source/StatefulSequenceSource.java          | 116 +++++++++++-------
 .../functions/StatefulSequenceSourceTest.java   |   9 +-
 3 files changed, 81 insertions(+), 46 deletions(-)

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
index 84a9700cf8f3e..e3dfd6b94090a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
@@ -67,7 +67,7 @@
  * +----+   +----+   +----+   +----+   +----+
  *
- * Recovering the checkpoint with parallelism = 5 yields the following state assignment:
+ * Recovering the checkpoint with parallelism = 2 yields the following state assignment:
  * <pre>
  *      func_1          func_2
  * +----+----+----+   +----+----+
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
index bdb12f39c3dcd..d9784e4001f7c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
@@ -20,25 +20,30 @@
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.util.Preconditions;
 
-import java.util.ArrayDeque;
-import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * A stateful streaming source that emits each number from a given interval exactly once,
  * possibly in parallel.
  *
- * <p>For the source to be re-scalable, the first time the job is run, we precompute all the elements
- * that each of the tasks should emit and upon checkpointing, each element constitutes its own
- * partition. When rescaling, these partitions will be randomly re-assigned to the new tasks.
+ * <p>For the source to be re-scalable, the range of elements to be emitted is initially (at the first execution)
+ * split into {@code min(maxParallelism, totalNumberOfElements)} partitions, and for each one, we
+ * store the {@code nextOffset}, i.e. the next element to be emitted, and its {@code end}. Upon rescaling, these
+ * partitions can be reshuffled among the new tasks, and these will resume emitting from where their predecessors
+ * left off.
  *
- * <p>This strategy guarantees that each element will be emitted exactly-once, but elements will not
- * necessarily be emitted in ascending order, even for the same tasks.
+ * <p>Although each element will be emitted exactly-once, elements will not necessarily be emitted in ascending order,
+ * even for the same task.
  */
 @PublicEvolving
 public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements CheckpointedFunction {
@@ -50,9 +55,8 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 
 	private volatile boolean isRunning = true;
 
-	private transient Deque<Long> valuesToEmit;
-
-	private transient ListState<Long> checkpointedState;
+	private transient Map<Long, Long> endToNextOffsetMapping;
+	private transient ListState<Tuple2<Long, Long>> checkpointedState;
 
 	/**
 	 * Creates a source that emits all numbers from the given interval exactly once.
@@ -61,6 +65,7 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 	 * @param end End of the range of numbers to emit.
 	 */
 	public StatefulSequenceSource(long start, long end) {
+		Preconditions.checkArgument(start <= end);
 		this.start = start;
 		this.end = end;
 	}
@@ -68,45 +73,81 @@ public StatefulSequenceSource(long start, long end) {
 
 	@Override
 	public void initializeState(FunctionInitializationContext context) throws Exception {
-		Preconditions.checkState(this.checkpointedState == null,
+		Preconditions.checkState(checkpointedState == null,
 			"The " + getClass().getSimpleName() + " has already been initialized.");
 
 		this.checkpointedState = context.getOperatorStateStore().getOperatorState(
 			new ListStateDescriptor<>(
-				"stateful-sequence-source-state",
-				LongSerializer.INSTANCE
+				"stateful-sequence-source-state",
+				new TupleSerializer<>(
+					(Class<Tuple2<Long, Long>>) (Class<?>) Tuple2.class,
+					new TypeSerializer[] { LongSerializer.INSTANCE, LongSerializer.INSTANCE }
+				)
 			)
 		);
 
-		this.valuesToEmit = new ArrayDeque<>();
+		this.endToNextOffsetMapping = new HashMap<>();
 
 		if (context.isRestored()) {
-			// upon restoring
-
-			for (Long v : this.checkpointedState.get()) {
-				this.valuesToEmit.add(v);
+			for (Tuple2<Long, Long> partitionInfo: checkpointedState.get()) {
+				Long prev = endToNextOffsetMapping.put(partitionInfo.f0, partitionInfo.f1);
+				Preconditions.checkState(prev == null,
+					getClass().getSimpleName() + " : Duplicate entry when restoring.");
 			}
 		} else {
-			// the first time the job is executed
-
-			final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
 			final int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
-			final long congruence = start + taskIdx;
+			final int parallelTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+			final long totalElements = Math.abs(end - start + 1L);
+			final int maxParallelism = getRuntimeContext().getMaxNumberOfParallelSubtasks();
+			final int totalPartitions = totalElements < Integer.MAX_VALUE ? Math.min(maxParallelism, (int) totalElements) : maxParallelism;
 
-			long totalNoOfElements = Math.abs(end - start + 1);
-			final int baseSize = safeDivide(totalNoOfElements, stepSize);
-			final int toCollect = (totalNoOfElements % stepSize > taskIdx) ? baseSize + 1 : baseSize;
+			Tuple2<Integer, Integer> localPartitionRange = getLocalRange(totalPartitions, parallelTasks, taskIdx);
+			int localStartIdx = localPartitionRange.f0;
+			int localEndIdx = localStartIdx + localPartitionRange.f1;
 
-			for (long collected = 0; collected < toCollect; collected++) {
-				this.valuesToEmit.add(collected * stepSize + congruence);
+			for (int partIdx = localStartIdx; partIdx < localEndIdx; partIdx++) {
+				Tuple2<Long, Long> limits = getPartitionLimits(totalElements, totalPartitions, partIdx);
+				endToNextOffsetMapping.put(limits.f1, limits.f0);
 			}
 		}
 	}
 
+	private Tuple2<Integer, Integer> getLocalRange(int totalPartitions, int parallelTasks, int taskIdx) {
+		int minPartitionSliceSize = totalPartitions / parallelTasks;
+		int remainingPartitions = totalPartitions - minPartitionSliceSize * parallelTasks;
+
+		int localRangeStartIdx = taskIdx * minPartitionSliceSize + Math.min(taskIdx, remainingPartitions);
+		int localRangeSize = taskIdx < remainingPartitions ? minPartitionSliceSize + 1 : minPartitionSliceSize;
+
+		return new Tuple2<>(localRangeStartIdx, localRangeSize);
+	}
+
+	private Tuple2<Long, Long> getPartitionLimits(long totalElements, int totalPartitions, long partitionIdx) {
+		long minElementPartitionSize = totalElements / totalPartitions;
+		long remainingElements = totalElements - minElementPartitionSize * totalPartitions;
+		long startOffset = start;
+
+		for (int idx = 0; idx < partitionIdx; idx++) {
+			long partitionSize = idx < remainingElements ? minElementPartitionSize + 1L : minElementPartitionSize;
+			startOffset += partitionSize;
+		}
+
+		long partitionSize = partitionIdx < remainingElements ? minElementPartitionSize + 1L : minElementPartitionSize;
+		return new Tuple2<>(startOffset, startOffset + partitionSize);
+	}
+
 	@Override
 	public void run(SourceContext<Long> ctx) throws Exception {
-		while (isRunning && !this.valuesToEmit.isEmpty()) {
-			synchronized (ctx.getCheckpointLock()) {
-				ctx.collect(this.valuesToEmit.poll());
+		for (Map.Entry<Long, Long> partition: endToNextOffsetMapping.entrySet()) {
+			long endOffset = partition.getKey();
+			long currentOffset = partition.getValue();
+
+			while (isRunning && currentOffset < endOffset) {
+				synchronized (ctx.getCheckpointLock()) {
+					long toSend = currentOffset;
+					endToNextOffsetMapping.put(endOffset, ++currentOffset);
+					ctx.collect(toSend);
+				}
 			}
 		}
 	}
@@ -118,19 +159,12 @@ public void cancel() {
 
 	@Override
 	public void snapshotState(FunctionSnapshotContext context) throws Exception {
-		Preconditions.checkState(this.checkpointedState != null,
+		Preconditions.checkState(checkpointedState != null,
 			"The " + getClass().getSimpleName() + " state has not been properly initialized.");
 
-		this.checkpointedState.clear();
-		for (Long v : this.valuesToEmit) {
-			this.checkpointedState.add(v);
+		checkpointedState.clear();
+		for (Map.Entry<Long, Long> entry : endToNextOffsetMapping.entrySet()) {
+			checkpointedState.add(new Tuple2<>(entry.getKey(), entry.getValue()));
 		}
 	}
-
-	private static int safeDivide(long left, long right) {
-		Preconditions.checkArgument(right > 0);
-		Preconditions.checkArgument(left >= 0);
-		Preconditions.checkArgument(left <= Integer.MAX_VALUE * right);
-		return (int) (left / right);
-	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
index 9030e9dcbd191..83fae82cf7f97 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
@@ -39,7 +39,8 @@ public class StatefulSequenceSourceTest {
 
 	@Test
 	public void testCheckpointRestore() throws Exception {
-		final int initElement = 0;
+		final int maxParallelism = 5;
+		final int initElement = -101;
 		final int maxElement = 100;
 
 		final Set<Long> expectedOutput = new HashSet<>();
@@ -57,14 +58,14 @@ public void testCheckpointRestore() throws Exception {
 		StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1);
 
 		final AbstractStreamOperatorTestHarness<Long> testHarness1 =
-			new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0);
+			new AbstractStreamOperatorTestHarness<>(src1, maxParallelism, 2, 0);
 		testHarness1.open();
 
 		final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement);
 		StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2);
 
 		final AbstractStreamOperatorTestHarness<Long> testHarness2 =
-			new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1);
+			new AbstractStreamOperatorTestHarness<>(src2, maxParallelism, 2, 1);
 		testHarness2.open();
 
 		final Throwable[] error = new Throwable[3];
@@ -117,7 +118,7 @@ public void run() {
 		StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3);
 
 		final AbstractStreamOperatorTestHarness<Long> testHarness3 =
-			new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0);
+			new AbstractStreamOperatorTestHarness<>(src3, maxParallelism, 1, 0);
 		testHarness3.setup();
 		testHarness3.initializeState(snapshot);
 		testHarness3.open();
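
For reviewers who want to sanity-check the partitioning arithmetic without running a Flink job, here is a minimal standalone sketch. It is not part of the patch; the class name PartitioningSketch and its method names are made up for illustration, and it only mirrors the logic of getLocalRange() and getPartitionLimits() above, printing which element ranges each task would own.

public class PartitioningSketch {

	/** Returns {firstPartitionIndex, numberOfPartitions} handled by the given task. */
	static int[] localRange(int totalPartitions, int parallelTasks, int taskIdx) {
		int minSlice = totalPartitions / parallelTasks;
		int remaining = totalPartitions - minSlice * parallelTasks;
		int start = taskIdx * minSlice + Math.min(taskIdx, remaining);
		int size = taskIdx < remaining ? minSlice + 1 : minSlice;
		return new int[] { start, size };
	}

	/** Returns {startOffset, endOffset} (end exclusive) of the given partition. */
	static long[] partitionLimits(long rangeStart, long totalElements, int totalPartitions, int partitionIdx) {
		long minSize = totalElements / totalPartitions;
		long remaining = totalElements - minSize * totalPartitions;
		long offset = rangeStart;
		for (int idx = 0; idx < partitionIdx; idx++) {
			offset += idx < remaining ? minSize + 1L : minSize;
		}
		long size = partitionIdx < remaining ? minSize + 1L : minSize;
		return new long[] { offset, offset + size };
	}

	public static void main(String[] args) {
		// Example: emit [0, 9] with maxParallelism = 4 and 2 parallel tasks.
		long start = 0;
		long end = 9;
		long totalElements = end - start + 1;                                // 10 elements
		int maxParallelism = 4;
		int parallelTasks = 2;
		int totalPartitions = (int) Math.min(maxParallelism, totalElements); // 4 partitions

		for (int task = 0; task < parallelTasks; task++) {
			int[] range = localRange(totalPartitions, parallelTasks, task);
			for (int p = range[0]; p < range[0] + range[1]; p++) {
				long[] limits = partitionLimits(start, totalElements, totalPartitions, p);
				System.out.println("task " + task + " owns partition " + p
						+ " -> [" + limits[0] + ", " + limits[1] + ")");
			}
		}
	}
}

With these assumed numbers the sketch prints task 0 owning partitions 0 and 1 ([0, 3) and [3, 6)) and task 1 owning partitions 2 and 3 ([6, 8) and [8, 10)), i.e. a 3/3/2/2 split of the ten elements, which is the behavior the checkpointed (end, nextOffset) pairs are meant to preserve across rescalings.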