Commit 57caed8
[tests] Reinforce StateCheckpoinedITCase to make sure actual checkpointing has happened before a failure.

StephanEwen committed Aug 16, 2015
1 parent beed1d4 commit 57caed8
Showing 2 changed files with 55 additions and 41 deletions.
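
In essence, the commit makes the induced failure "count" only if a checkpoint has actually completed first: the once-failing operator now also implements CheckpointNotifier, remembers whether notifyCheckpointComplete() has fired, and records that flag at the moment it throws; postSubmit() then asserts the flag, so a run whose failure happened before the first checkpoint is reported as inconclusive instead of passing vacuously. A minimal sketch of that pattern, using the same Checkpointed/CheckpointNotifier interfaces that appear in the diff below (class and field names here are illustrative, not the committed code; imports as listed at the top of the changed file):

	// Sketch only -- mirrors the pattern from the diff, not the committed class.
	private static class OnceFailingOperator extends RichFlatMapFunction<String, String>
			implements Checkpointed<Long>, CheckpointNotifier {

		// set when the failure is thrown; true only if a checkpoint completed before it
		static volatile boolean wasCheckpointedBeforeFailure;

		private static volatile boolean hasFailed;

		private final long failurePos;   // element count at which to fail, chosen by the test driver
		private long count;
		private boolean wasCheckpointed; // flipped by notifyCheckpointComplete()

		OnceFailingOperator(long failurePos) {
			this.failurePos = failurePos;
		}

		@Override
		public void flatMap(String value, Collector<String> out) throws Exception {
			count++;
			if (!hasFailed && count >= failurePos) {
				// record whether a checkpoint completed before the task is killed
				wasCheckpointedBeforeFailure = wasCheckpointed;
				hasFailed = true;
				throw new Exception("Test Failure");
			}
			out.collect(value);
		}

		@Override
		public Long snapshotState(long checkpointId, long checkpointTimestamp) {
			return count;
		}

		@Override
		public void restoreState(Long state) {
			count = state;
		}

		@Override
		public void notifyCheckpointComplete(long checkpointId) {
			wasCheckpointed = true;
		}
	}

The added assertTrue in postSubmit() (visible in the diff) then fails the run with "Test inconclusive" if this flag was never set.
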
StateCheckpoinedITCase.java
@@ -22,37 +22,35 @@
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;
import org.apache.flink.util.Collector;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
* A simple test that runs a streaming topology with checkpointing enabled.
*
* The test triggers a failure after a while and verifies that, after completion, the
* state defined with either the {@link OperatorState} or the {@link Checkpointed}
* interface reflects the "exactly once" semantics.
*
* The test throttles the input until at least two checkpoints are completed, to make sure that
* the recovery does not fall back to "square one" (which would naturally lead to correct
* results without testing the checkpointing).
*/
@SuppressWarnings("serial")
public class StateCheckpoinedITCase extends StreamFaultToleranceTestBase {
@@ -63,18 +61,25 @@ public class StateCheckpoinedITCase extends StreamFaultToleranceTestBase {
* Runs the following program:
*
* <pre>
* [ (source)->(filter)->(map) ] -> [ (map) ] -> [ (groupBy/reduce)->(sink) ]
* [ (source)->(filter)] -> [ (map) -> (map) ] -> [ (groupBy/reduce)->(sink) ]
* </pre>
*/
@Override
public void testProgram(StreamExecutionEnvironment env) {
assertTrue("Broken test setup", NUM_STRINGS % 40 == 0);

final long failurePosMin = (long) (0.4 * NUM_STRINGS / PARALLELISM);
final long failurePosMax = (long) (0.7 * NUM_STRINGS / PARALLELISM);

final long failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;

DataStream<String> stream = env.addSource(new StringGeneratingSourceFunction(NUM_STRINGS));

stream
// -------------- first vertex, chained to the source ----------------
.filter(new StringRichFilterFunction())
// first vertex, chained to the source
// this filter throttles the flow until at least one checkpoint
// is complete, to make sure this program does not run without
.filter(new StringRichFilterFunction(failurePos))

// -------------- seconds vertex - one-to-one connected ----------------
.map(new StringPrefixCountRichMapFunction())
@@ -83,12 +88,16 @@ public void testProgram(StreamExecutionEnvironment env) {

// -------------- third vertex - reducer and the sink ----------------
.partitionByHash("prefix")
.flatMap(new OnceFailingAggregator(NUM_STRINGS))
.flatMap(new OnceFailingAggregator(failurePos))
.addSink(new ValidatingSink());
}

@Override
public void postSubmit() {

assertTrue("Test inconclusive: failure occurred before first checkpoint",
OnceFailingAggregator.wasCheckpointedBeforeFailure);

long filterSum = 0;
for (long l : StringRichFilterFunction.counts) {
filterSum += l;
@@ -189,15 +198,27 @@ public void restoreState(Integer state) {
}
}

private static class StringRichFilterFunction extends RichFilterFunction<String> implements Checkpointed<Long> {
private static class StringRichFilterFunction extends RichFilterFunction<String>
implements Checkpointed<Long>, CheckpointNotifier
{

static final long[] counts = new long[PARALLELISM];

private final long failurePos;
private long count;
private int numTimesCheckpointed;

private StringRichFilterFunction(long failurePos) {
this.failurePos = failurePos;
}


@Override
public boolean filter(String value) {
public boolean filter(String value) throws Exception {
count++;
if (count < failurePos && numTimesCheckpointed < 2) {
Thread.sleep(1);
}
return value.length() < 100; // should be always true
}

@@ -215,6 +236,11 @@ public Long snapshotState(long checkpointId, long checkpointTimestamp) {
public void restoreState(Long state) {
count = state;
}

@Override
public void notifyCheckpointComplete(long checkpointId) {
numTimesCheckpointed++;
}
}

private static class StringPrefixCountRichMapFunction extends RichMapFunction<String, PrefixCount>
@@ -271,35 +297,34 @@ public void close() throws IOException {
}

private static class OnceFailingAggregator extends RichFlatMapFunction<PrefixCount, PrefixCount>
implements Checkpointed<HashMap<String, PrefixCount>> {
implements Checkpointed<HashMap<String, PrefixCount>>, CheckpointNotifier {

static boolean wasCheckpointedBeforeFailure = false;

private static volatile boolean hasFailed = false;

private final HashMap<String, PrefixCount> aggregationMap = new HashMap<String, PrefixCount>();

private final long numElements;

private long failurePos;
private long count;

private boolean wasCheckpointed;


OnceFailingAggregator(long numElements) {
this.numElements = numElements;
OnceFailingAggregator(long failurePos) {
this.failurePos = failurePos;
}

@Override
public void open(Configuration parameters) {
long failurePosMin = (long) (0.4 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
long failurePosMax = (long) (0.7 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());

failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
count = 0;
}

@Override
public void flatMap(PrefixCount value, Collector<PrefixCount> out) throws Exception {
count++;
if (!hasFailed && count >= failurePos) {
if (!hasFailed && count >= failurePos && getRuntimeContext().getIndexOfThisSubtask() == 1) {
wasCheckpointedBeforeFailure = wasCheckpointed;
hasFailed = true;
throw new Exception("Test Failure");
}
@@ -324,6 +349,11 @@ public HashMap<String, PrefixCount> snapshotState(long checkpointId, long checkpointTimestamp) {
public void restoreState(HashMap<String, PrefixCount> state) {
aggregationMap.putAll(state);
}

@Override
public void notifyCheckpointComplete(long checkpointId) {
this.wasCheckpointed = true;
}
}

private static class ValidatingSink extends RichSinkFunction<PrefixCount>
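
The other half of the change is the throttling described in the class Javadoc above: the filter in the first vertex sleeps one millisecond per element while fewer than two checkpoints have completed and the failure position has not yet been reached, so the induced failure cannot overtake checkpointing. A sketch of that idea under the same interfaces (again with illustrative names, not the committed StringRichFilterFunction):

	// Sketch only -- illustrates the throttling idea, not the committed class.
	private static class ThrottlingFilter extends RichFilterFunction<String>
			implements Checkpointed<Long>, CheckpointNotifier {

		private final long failurePos;    // same failure position handed to the failing operator
		private long count;               // elements seen by this subtask
		private int numTimesCheckpointed; // completed checkpoints observed so far

		ThrottlingFilter(long failurePos) {
			this.failurePos = failurePos;
		}

		@Override
		public boolean filter(String value) throws Exception {
			count++;
			// slow the stream down until at least two checkpoints have completed,
			// but never past the failure position
			if (count < failurePos && numTimesCheckpointed < 2) {
				Thread.sleep(1);
			}
			return true;
		}

		@Override
		public Long snapshotState(long checkpointId, long checkpointTimestamp) {
			return count;
		}

		@Override
		public void restoreState(Long state) {
			count = state;
		}

		@Override
		public void notifyCheckpointComplete(long checkpointId) {
			numTimesCheckpointed++;
		}
	}
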
StreamFaultToleranceTestBase.java
@@ -18,38 +18,22 @@

package org.apache.flink.test.checkpointing;


import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;


/**
* Test base for fault tolerant streaming programs
*/
@SuppressWarnings("serial")
public abstract class StreamFaultToleranceTestBase {

protected static final int NUM_TASK_MANAGERS = 2;
@@ -127,6 +111,7 @@ public void runCheckpointedProgram() {
// Frequently used utilities
// --------------------------------------------------------------------------------------------

@SuppressWarnings("serial")
public static class PrefixCount implements Serializable {

public String prefix;
@@ -146,5 +131,4 @@ public String toString() {
return prefix + " / " + value;
}
}

}
