Skip to content

Commit

Permalink
[FLINK-2713] [streaming] Set state restore to lazy to avoid StateChec…
Browse files Browse the repository at this point in the history
…kpointer issues and reduce checkpoint overhead

Closes #1154
  • Loading branch information
gyfora committed Sep 21, 2015
1 parent b9663c4 commit 63d9800
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 37 deletions.
Expand Up @@ -55,6 +55,8 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
private IN currentInput; private IN currentInput;


private ClassLoader cl; private ClassLoader cl;
private boolean restored = true;
private StateHandle<Serializable> checkpoint = null;


public PartitionedStreamOperatorState(StateCheckpointer<S, C> checkpointer, public PartitionedStreamOperatorState(StateCheckpointer<S, C> checkpointer,
StateHandleProvider<C> provider, KeySelector<IN, Serializable> keySelector, ClassLoader cl) { StateHandleProvider<C> provider, KeySelector<IN, Serializable> keySelector, ClassLoader cl) {
Expand All @@ -76,6 +78,10 @@ public S value() throws IOException {
if (currentInput == null) { if (currentInput == null) {
throw new IllegalStateException("Need a valid input for accessing the state."); throw new IllegalStateException("Need a valid input for accessing the state.");
} else { } else {
if (!restored) {
// If the state is not restored yet, restore now
restoreWithCheckpointer();
}
Serializable key; Serializable key;
try { try {
key = keySelector.getKey(currentInput); key = keySelector.getKey(currentInput);
Expand All @@ -100,6 +106,10 @@ public void update(S state) throws IOException {
if (currentInput == null) { if (currentInput == null) {
throw new IllegalStateException("Need a valid input for updating a state."); throw new IllegalStateException("Need a valid input for updating a state.");
} else { } else {
if (!restored) {
// If the state is not restored yet, restore now
restoreWithCheckpointer();
}
Serializable key; Serializable key;
try { try {
key = keySelector.getKey(currentInput); key = keySelector.getKey(currentInput);
Expand Down Expand Up @@ -131,18 +141,38 @@ public void setCurrentInput(IN input) {


@Override @Override
public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return stateStore.snapshotStates(checkpointId, checkpointTimestamp); // If the state is restored we take a snapshot, otherwise return the last checkpoint
return restored ? stateStore.snapshotStates(checkpointId, checkpointTimestamp) : provider
.createStateHandle(checkpoint.getState(cl));
} }

@Override @Override
public void restoreState(StateHandle<Serializable> snapshots, ClassLoader userCodeClassLoader) throws Exception { public void restoreState(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {
stateStore.restoreStates(snapshots, userCodeClassLoader); // We store the snapshot for lazy restore
checkpoint = snapshot;
restored = false;
}

private void restoreWithCheckpointer() throws IOException {
try {
stateStore.restoreStates(checkpoint, cl);
} catch (Exception e) {
throw new IOException(e);
}
restored = true;
checkpoint = null;
} }


@Override @Override
public Map<Serializable, S> getPartitionedState() throws Exception { public Map<Serializable, S> getPartitionedState() throws Exception {
return stateStore.getPartitionedState(); return stateStore.getPartitionedState();
} }

/**
 * Installs the given checkpointer on this state and forwards the same instance to
 * {@code stateStore}, so that both use the same checkpointer.
 *
 * @param checkpointer the checkpointer to install
 */
@Override
public void setCheckpointer(StateCheckpointer<S, C> checkpointer) {
super.setCheckpointer(checkpointer);
stateStore.setCheckPointer(checkpointer);
}


@Override @Override
public String toString() { public String toString() {
Expand Down
Expand Up @@ -44,7 +44,10 @@ public class StreamOperatorState<S, C extends Serializable> implements OperatorS


private S state; private S state;
protected StateCheckpointer<S, C> checkpointer; protected StateCheckpointer<S, C> checkpointer;
private final StateHandleProvider<Serializable> provider; protected final StateHandleProvider<Serializable> provider;

private boolean restored = true;
private Serializable checkpoint = null;


@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public StreamOperatorState(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) { public StreamOperatorState(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
Expand All @@ -59,6 +62,10 @@ public StreamOperatorState(StateHandleProvider<C> provider) {


@Override @Override
public S value() throws IOException { public S value() throws IOException {
if (!restored) {
// If the state is not restored yet, restore it at this point
restoreWithCheckpointer();
}
return state; return state;
} }


Expand All @@ -67,6 +74,11 @@ public void update(S state) throws IOException {
if (state == null) { if (state == null) {
throw new RuntimeException("Cannot set state to null."); throw new RuntimeException("Cannot set state to null.");
} }
if (!restored) {
    // An explicit update before the lazy restore supersedes the pending checkpoint:
    // mark the state as restored and drop the stored snapshot. The "no pending
    // checkpoint" value of this field is null (its initial value); assigning the
    // literal false would only compile through auto-boxing and leave a stray
    // Boolean in the field.
    restored = true;
    checkpoint = null;
}
this.state = state; this.state = state;
} }


Expand All @@ -90,14 +102,22 @@ protected StateHandleProvider<Serializable> getStateHandleProvider() {


public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp) public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp)
throws Exception { throws Exception {
return provider.createStateHandle(checkpointer.snapshotState(value(), checkpointId, // If the state is restored we take a snapshot, otherwise return the last checkpoint
checkpointTimestamp)); return provider.createStateHandle(restored ? checkpointer.snapshotState(value(), checkpointId,

checkpointTimestamp) : checkpoint);
} }


@SuppressWarnings("unchecked")
public void restoreState(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception { public void restoreState(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {
update(checkpointer.restoreState((C) snapshot.getState(userCodeClassLoader))); // We set the checkpoint for lazy restore
checkpoint = snapshot.getState(userCodeClassLoader);
restored = false;
}

@SuppressWarnings("unchecked")
private void restoreWithCheckpointer() throws IOException {
update(checkpointer.restoreState((C) checkpoint));
restored = true;
checkpoint = null;
} }


public Map<Serializable, S> getPartitionedState() throws Exception { public Map<Serializable, S> getPartitionedState() throws Exception {
Expand Down
Expand Up @@ -83,11 +83,14 @@ public void simpleStateTest() throws Exception {
assertEquals("12345", context.getOperatorState("concat", "", false).value()); assertEquals("12345", context.getOperatorState("concat", "", false).value());
assertEquals((Integer) 5, ((StatefulMapper) map.getUserFunction()).checkpointedCounter); assertEquals((Integer) 5, ((StatefulMapper) map.getUserFunction()).checkpointedCounter);


byte[] serializedState = InstantiationUtil.serializeObject(map.getStateSnapshotFromFunction(1, 1)); byte[] serializedState0 = InstantiationUtil.serializeObject(map.getStateSnapshotFromFunction(1, 1));
// Restore state but snapshot again before calling the value
byte[] serializedState = InstantiationUtil.serializeObject(createOperatorWithContext(out,
new ModKey(2), serializedState0).getStateSnapshotFromFunction(1, 1));


StreamMap<Integer, String> restoredMap = createOperatorWithContext(out, new ModKey(2), serializedState); StreamMap<Integer, String> restoredMap = createOperatorWithContext(out, new ModKey(2), serializedState);
StreamingRuntimeContext restoredContext = restoredMap.getRuntimeContext(); StreamingRuntimeContext restoredContext = restoredMap.getRuntimeContext();

assertEquals((Integer) 5, restoredContext.getOperatorState("counter", 0, false).value()); assertEquals((Integer) 5, restoredContext.getOperatorState("counter", 0, false).value());
assertEquals(ImmutableMap.of(0, new MutableInt(2), 1, new MutableInt(3)), context.getOperatorStates().get("groupCounter").getPartitionedState()); assertEquals(ImmutableMap.of(0, new MutableInt(2), 1, new MutableInt(3)), context.getOperatorStates().get("groupCounter").getPartitionedState());
assertEquals("12345", restoredContext.getOperatorState("concat", "", false).value()); assertEquals("12345", restoredContext.getOperatorState("concat", "", false).value());
Expand Down Expand Up @@ -227,7 +230,7 @@ public String map(Integer value) throws Exception {


@Override @Override
public void open(Configuration conf) throws IOException { public void open(Configuration conf) throws IOException {
counter = getRuntimeContext().getOperatorState("counter", 0, false); counter = getRuntimeContext().getOperatorState("counter", 0, false, intCheckpointer);
groupCounter = getRuntimeContext().getOperatorState("groupCounter", new MutableInt(0), true); groupCounter = getRuntimeContext().getOperatorState("groupCounter", new MutableInt(0), true);
concat = getRuntimeContext().getOperatorState("concat", "", false); concat = getRuntimeContext().getOperatorState("concat", "", false);
try { try {
Expand Down Expand Up @@ -279,19 +282,7 @@ public String map(Integer value) throws Exception {


@Override @Override
public void open(Configuration conf) throws IOException { public void open(Configuration conf) throws IOException {
groupCounter = getRuntimeContext().getOperatorState("groupCounter", 0, true, groupCounter = getRuntimeContext().getOperatorState("groupCounter", 0, true, intCheckpointer);
new StateCheckpointer<Integer, String>() {

@Override
public String snapshotState(Integer state, long checkpointId, long checkpointTimestamp) {
return state.toString();
}

@Override
public Integer restoreState(String stateSnapshot) {
return Integer.parseInt(stateSnapshot);
}
});
} }


@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Expand All @@ -308,6 +299,21 @@ public void close() throws Exception {
} }


} }

/**
 * Shared {@link StateCheckpointer} for {@code Integer} state: a snapshot is the value's
 * decimal {@code String} form, and a restore parses that string back into an {@code Integer}.
 *
 * <p>Declared {@code final} so the shared instance cannot be reassigned.
 */
public static final StateCheckpointer<Integer, String> intCheckpointer = new StateCheckpointer<Integer, String>() {

    private static final long serialVersionUID = 1L;

    /**
     * Snapshots the state as its decimal string; the checkpoint id and timestamp are not used.
     */
    @Override
    public String snapshotState(Integer state, long checkpointId, long checkpointTimestamp) {
        return state.toString();
    }

    /**
     * Parses the snapshot back into an {@code Integer}.
     *
     * @throws NumberFormatException if the snapshot is not a parsable integer
     */
    @Override
    public Integer restoreState(String stateSnapshot) {
        return Integer.parseInt(stateSnapshot);
    }
};


public static class PStateKeyRemovalTestMapper extends RichMapFunction<Integer, String> { public static class PStateKeyRemovalTestMapper extends RichMapFunction<Integer, String> {


Expand Down
Expand Up @@ -20,7 +20,6 @@


import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;


import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
Expand All @@ -30,18 +29,14 @@


import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState; import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.StateCheckpointer;
import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;


/** /**
* A simple test that runs a streaming topology with checkpointing enabled. * A simple test that runs a streaming topology with checkpointing enabled.
Expand Down Expand Up @@ -184,22 +179,47 @@ private static class CounterSink extends RichSinkFunction<Tuple2<Integer, Long>>


private static Map<Integer, Long> allCounts = new ConcurrentHashMap<Integer, Long>(); private static Map<Integer, Long> allCounts = new ConcurrentHashMap<Integer, Long>();


private OperatorState<Long> counts; private OperatorState<NonSerializableLong> counts;


@Override @Override
public void open(Configuration parameters) throws IOException { public void open(Configuration parameters) throws IOException {
counts = getRuntimeContext().getOperatorState("count", 0L, true); counts = getRuntimeContext().getOperatorState("count", NonSerializableLong.of(0L), true,
new StateCheckpointer<NonSerializableLong, String>() {

@Override
public String snapshotState(NonSerializableLong state, long id, long ts) {
return state.value.toString();
}

@Override
public NonSerializableLong restoreState(String stateSnapshot) {
return NonSerializableLong.of(Long.parseLong(stateSnapshot));
}

});
} }


@Override @Override
public void invoke(Tuple2<Integer, Long> value) throws Exception { public void invoke(Tuple2<Integer, Long> value) throws Exception {
long currentCount = counts.value() + 1; long currentCount = counts.value().value + 1;
counts.update(currentCount); counts.update(NonSerializableLong.of(currentCount));
allCounts.put(value.f0, currentCount); allCounts.put(value.f0, currentCount);


} }
} }


/**
 * Minimal holder for a {@code Long} that does not implement {@code java.io.Serializable}.
 * In this file it is used as operator state together with a custom checkpointer that
 * converts it to and from a {@code String}.
 *
 * <p>Instances are immutable and are created through {@link #of(long)}.
 */
private static final class NonSerializableLong {
    /** The wrapped value; kept boxed because callers invoke {@code toString()} on it. */
    public final Long value;

    private NonSerializableLong(long value) {
        this.value = value;
    }

    /**
     * Creates a holder for the given value.
     *
     * @param value the value to wrap
     * @return a new holder containing {@code value}
     */
    public static NonSerializableLong of(long value) {
        return new NonSerializableLong(value);
    }
}

private static class IdentityKeySelector<T> implements KeySelector<T, T> { private static class IdentityKeySelector<T> implements KeySelector<T, T> {


@Override @Override
Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.api.common.functions.RichFilterFunction; import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState; import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.StateCheckpointer;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStream;
Expand Down Expand Up @@ -308,7 +309,20 @@ public PrefixCount map(String value) throws IOException {


@Override @Override
public void open(Configuration conf) throws IOException { public void open(Configuration conf) throws IOException {
this.count = getRuntimeContext().getOperatorState("count", 0L, false); this.count = getRuntimeContext().getOperatorState("count", 0L, false,
new StateCheckpointer<Long, String>() {

@Override
public String snapshotState(Long state, long id, long ts) {
return state.toString();
}

@Override
public Long restoreState(String stateSnapshot) {
return Long.parseLong(stateSnapshot);
}

});
} }


@Override @Override
Expand Down

0 comments on commit 63d9800

Please sign in to comment.