Skip to content

Commit

Permalink
[FLINK-5163] Port the StatefulSequenceSource to the new state abstrac…
Browse files Browse the repository at this point in the history
…tions.
  • Loading branch information
kl0u committed Dec 9, 2016
1 parent fd5d87a commit 5401b6c
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand All @@ -18,25 +18,42 @@
package org.apache.flink.streaming.api.functions.source;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.util.Preconditions;

import java.util.ArrayDeque;
import java.util.Deque;

/**
* A stateful streaming source that emits each number from a given interval exactly once,
* possibly in parallel.
* <p>
* For the source to be re-scalable, the first time the job is run, we precompute all the elements
* that each of the tasks should emit and upon checkpointing, each element constitutes its own
* partition. When rescaling, these partitions will be randomly re-assigned to the new tasks.
* <p>
* This strategy guarantees that each element will be emitted exactly once, but elements will not
* necessarily be emitted in ascending order, even within a single task.
*/
@PublicEvolving
public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> {
public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements CheckpointedFunction {

private static final long serialVersionUID = 1L;

private final long start;
private final long end;

private long collected;

private volatile boolean isRunning = true;

private transient Deque<Long> valuesToEmit;

private transient ListState<Long> checkpointedState;

/**
* Creates a source that emits all numbers from the given interval exactly once.
*
Expand All @@ -49,24 +66,47 @@ public StatefulSequenceSource(long start, long end) {
}

@Override
public void run(SourceContext<Long> ctx) throws Exception {
final Object checkpointLock = ctx.getCheckpointLock();
public void initializeState(FunctionInitializationContext context) throws Exception {

Preconditions.checkState(this.checkpointedState == null,
"The " + getClass().getSimpleName() + " has already been initialized.");

this.checkpointedState = context.getOperatorStateStore().getOperatorState(
new ListStateDescriptor<>(
"stateful-sequence-source-state",
LongSerializer.INSTANCE
)
);

RuntimeContext context = getRuntimeContext();
this.valuesToEmit = new ArrayDeque<>();
if (context.isRestored()) {
// upon restoring

for (Long v : this.checkpointedState.get()) {
this.valuesToEmit.add(v);
}
} else {
// the first time the job is executed

final long stepSize = context.getNumberOfParallelSubtasks();
final long congruence = start + context.getIndexOfThisSubtask();
final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
final int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
final long congruence = start + taskIdx;

final long toCollect =
((end - start + 1) % stepSize > (congruence - start)) ?
((end - start + 1) / stepSize + 1) :
((end - start + 1) / stepSize);

long totalNoOfElements = Math.abs(end - start + 1);
final int baseSize = safeDivide(totalNoOfElements, stepSize);
final int toCollect = (totalNoOfElements % stepSize > taskIdx) ? baseSize + 1 : baseSize;

while (isRunning && collected < toCollect) {
synchronized (checkpointLock) {
ctx.collect(collected * stepSize + congruence);
collected++;
for (long collected = 0; collected < toCollect; collected++) {
this.valuesToEmit.add(collected * stepSize + congruence);
}
}
}

/**
 * Emits the queued values one at a time until the queue is drained or the source is cancelled.
 *
 * <p>Each element is removed from the queue and handed to the context while holding the
 * context's checkpoint lock, so removal and emission happen as one step with respect to
 * anything else synchronizing on that lock.
 *
 * @param ctx the context used to emit elements and to obtain the checkpoint lock
 */
@Override
public void run(SourceContext<Long> ctx) throws Exception {
	while (true) {
		// Stop as soon as the source was cancelled or nothing is left to emit.
		if (!isRunning || valuesToEmit.isEmpty()) {
			return;
		}

		synchronized (ctx.getCheckpointLock()) {
			final Long next = valuesToEmit.poll();
			ctx.collect(next);
		}
	}
}
Expand All @@ -77,12 +117,20 @@ public void cancel() {
}

@Override
public Long snapshotState(long checkpointId, long checkpointTimestamp) {
return collected;
public void snapshotState(FunctionSnapshotContext context) throws Exception {
	// The state handle is assigned in initializeState(); refuse to snapshot without it.
	final boolean stateInitialized = checkpointedState != null;
	Preconditions.checkState(stateInitialized,
		"The " + getClass().getSimpleName() + " state has not been properly initialized.");

	// Replace the previously stored contents with the values that are still waiting to be
	// emitted, adding them one by one as separate list-state entries.
	checkpointedState.clear();
	for (Long pendingValue : valuesToEmit) {
		checkpointedState.add(pendingValue);
	}
}

@Override
public void restoreState(Long state) {
collected = state;
/**
 * Divides a non-negative {@code long} by a positive {@code long} and narrows the quotient
 * to an {@code int}, rejecting inputs whose quotient would not fit.
 *
 * @param left the dividend; must be non-negative
 * @param right the divisor; must be strictly positive
 * @return the integer quotient {@code left / right}
 * @throws IllegalArgumentException if a precondition above is violated or the quotient exceeds
 *         {@code Integer.MAX_VALUE} (as raised by {@code Preconditions.checkArgument})
 */
private static int safeDivide(long left, long right) {
	Preconditions.checkArgument(right > 0);
	Preconditions.checkArgument(left >= 0);

	// Check the quotient itself instead of comparing against Integer.MAX_VALUE * right:
	// that product is computed in long arithmetic and can overflow for a large divisor,
	// which would reject inputs whose quotient fits into an int perfectly well.
	final long quotient = left / right;
	Preconditions.checkArgument(quotient <= Integer.MAX_VALUE);
	return (int) quotient;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,4 @@ public void fromCollectionTest() throws Exception {
Arrays.asList(1, 2, 3))));
assertEquals(expectedList, actualList);
}

@Test
public void generateSequenceTest() throws Exception {
	// Run a sequence source over [1, 7] through the source test utility and compare the
	// collected output against the full list 1..7.
	final StatefulSequenceSource sequenceSource = new StatefulSequenceSource(1, 7);
	final List<Long> emitted = SourceFunctionUtil.runSourceFunction(sequenceSource);

	final List<Long> wanted = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L);
	assertEquals(wanted, emitted);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,238 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.functions;

import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Tests for the {@link StatefulSequenceSource}.
 *
 * <p>The test runs two source instances until each has emitted a fixed number of elements,
 * snapshots both, merges the snapshots and restores them into a third instance. The union of
 * everything the three instances emitted must be exactly the configured interval, without
 * duplicates.
 */
public class StatefulSequenceSourceTest {

	@Test
	public void testCheckpointRestore() throws Exception {
		final int initElement = 0;
		final int maxElement = 100;

		final Set<Long> expectedOutput = new HashSet<>();
		for (long i = initElement; i <= maxElement; i++) {
			expectedOutput.add(i);
		}

		final ConcurrentHashMap<String, List<Long>> outputCollector = new ConcurrentHashMap<>();
		final OneShotLatch latchToTrigger1 = new OneShotLatch();
		final OneShotLatch latchToWait1 = new OneShotLatch();
		final OneShotLatch latchToTrigger2 = new OneShotLatch();
		final OneShotLatch latchToWait2 = new OneShotLatch();

		// NOTE(review): the harness arguments are presumably
		// (operator, maxParallelism, numSubtasks, subtaskIndex) -- confirm against the harness.
		final StatefulSequenceSource source1 = new StatefulSequenceSource(initElement, maxElement);
		StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1);

		final AbstractStreamOperatorTestHarness<Long> testHarness1 =
			new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0);
		testHarness1.open();

		final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement);
		StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2);

		final AbstractStreamOperatorTestHarness<Long> testHarness2 =
			new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1);
		testHarness2.open();

		// one slot per runner thread; each slot is written by exactly one thread
		final Throwable[] error = new Throwable[3];

		// run the source asynchronously
		Thread runner1 = new Thread() {
			@Override
			public void run() {
				try {
					source1.run(new BlockingSourceContext("1", latchToTrigger1, latchToWait1, outputCollector, 21));
				}
				catch (Throwable t) {
					t.printStackTrace();
					error[0] = t;
					// wake up the test thread so that it reports the failure instead of hanging
					latchToTrigger1.trigger();
				}
			}
		};

		// run the source asynchronously
		Thread runner2 = new Thread() {
			@Override
			public void run() {
				try {
					source2.run(new BlockingSourceContext("2", latchToTrigger2, latchToWait2, outputCollector, 32));
				}
				catch (Throwable t) {
					t.printStackTrace();
					error[1] = t;
					// wake up the test thread so that it reports the failure instead of hanging
					latchToTrigger2.trigger();
				}
			}
		};

		runner1.start();
		runner2.start();

		try {
			// wait until both sources have emitted their share and are parked inside collect()
			if (!latchToTrigger1.isTriggered()) {
				latchToTrigger1.await();
			}

			if (!latchToTrigger2.isTriggered()) {
				latchToTrigger2.await();
			}

			// a runner that died before reaching its threshold also releases the latch
			failOnSourceError(error);

			OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState(
				testHarness1.snapshot(0L, 0L),
				testHarness2.snapshot(0L, 0L)
			);

			final StatefulSequenceSource source3 = new StatefulSequenceSource(initElement, maxElement);
			StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3);

			final AbstractStreamOperatorTestHarness<Long> testHarness3 =
				new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0);
			testHarness3.setup();
			testHarness3.initializeState(snapshot);
			testHarness3.open();

			// the restored source must never block: its wait latch is released up front
			final OneShotLatch latchToTrigger3 = new OneShotLatch();
			final OneShotLatch latchToWait3 = new OneShotLatch();
			latchToWait3.trigger();

			// run the source asynchronously
			Thread runner3 = new Thread() {
				@Override
				public void run() {
					try {
						source3.run(new BlockingSourceContext("3", latchToTrigger3, latchToWait3, outputCollector, 3));
					}
					catch (Throwable t) {
						t.printStackTrace();
						error[2] = t;
					}
				}
			};
			runner3.start();
			runner3.join();

			failOnSourceError(error);

			Assert.assertEquals(3, outputCollector.size()); // we have 3 tasks.

			// test for at-most-once
			Set<Long> dedupRes = new HashSet<>(Math.abs(maxElement - initElement) + 1);
			for (Map.Entry<String, List<Long>> elementsPerTask: outputCollector.entrySet()) {
				List<Long> elements = elementsPerTask.getValue();

				// this tests the correctness of the latches in the test
				Assert.assertTrue(elements.size() > 0);

				for (Long elem : elements) {
					if (!dedupRes.add(elem)) {
						Assert.fail("Duplicate entry: " + elem);
					}

					if (!expectedOutput.contains(elem)) {
						Assert.fail("Unexpected element: " + elem);
					}
				}
			}

			// test for exactly-once
			Assert.assertEquals(Math.abs(initElement - maxElement) + 1, dedupRes.size());
		} finally {
			// Release the two parked sources even when an assertion above failed, and wait
			// for their threads so that the test does not leave blocked threads behind.
			latchToWait1.trigger();
			latchToWait2.trigger();
			runner1.join();
			runner2.join();
		}

		// failures that happened after the first two sources were released
		failOnSourceError(error);
	}

	/**
	 * Fails the test if any of the runner threads recorded a {@link Throwable}.
	 *
	 * @param errors the per-runner error slots; a {@code null} slot means "no failure"
	 */
	private static void failOnSourceError(Throwable[] errors) {
		for (int i = 0; i < errors.length; i++) {
			if (errors[i] != null) {
				throw new AssertionError("Source runner " + (i + 1) + " failed.", errors[i]);
			}
		}
	}

	/**
	 * A source context that records every collected element and, once a configured number of
	 * elements has been collected, triggers one latch and then blocks on a second one.
	 */
	private static class BlockingSourceContext implements SourceFunction.SourceContext<Long> {

		private final String name;

		private final Object lock;
		private final OneShotLatch latchToTrigger;
		private final OneShotLatch latchToWait;
		private final ConcurrentHashMap<String, List<Long>> collector;

		/** Number of collected elements at which the context triggers and then blocks. */
		private final int threshold;
		private int counter = 0;

		private final List<Long> localOutput;

		/**
		 * Creates the context and registers its output list under {@code name}.
		 *
		 * @param name key under which this context's output is registered; must be unused
		 * @param latchToTrigger latch triggered when the threshold is reached
		 * @param latchToWait latch awaited right after the threshold is reached
		 * @param output shared map from context name to the elements that context collected
		 * @param elemToFire number of collected elements that constitutes the threshold
		 */
		public BlockingSourceContext(String name, OneShotLatch latchToTrigger, OneShotLatch latchToWait,
									ConcurrentHashMap<String, List<Long>> output, int elemToFire) {
			this.name = name;
			this.lock = new Object();
			this.latchToTrigger = latchToTrigger;
			this.latchToWait = latchToWait;
			this.collector = output;
			this.threshold = elemToFire;

			this.localOutput = new ArrayList<>();
			List<Long> prev = collector.put(name, localOutput);
			if (prev != null) {
				Assert.fail();
			}
		}

		@Override
		public void collectWithTimestamp(Long element, long timestamp) {
			collect(element);
		}

		@Override
		public void collect(Long element) {
			localOutput.add(element);
			if (++counter == threshold) {
				latchToTrigger.trigger();
				try {
					if (!latchToWait.isTriggered()) {
						latchToWait.await();
					}
				} catch (InterruptedException e) {
					// keep the interruption visible to the calling thread instead of swallowing it
					Thread.currentThread().interrupt();
				}
			}
		}

		@Override
		public void emitWatermark(Watermark mark) {
		}

		@Override
		public Object getCheckpointLock() {
			return lock;
		}

		@Override
		public void close() {
		}
	}
}

0 comments on commit 5401b6c

Please sign in to comment.