Skip to content

Commit

Permalink
[streaming] Allow multiple operator states + stateful function test r…
Browse files Browse the repository at this point in the history
…ework
  • Loading branch information
gyfora committed Jun 25, 2015
1 parent a7e2458 commit ef11e63
Show file tree
Hide file tree
Showing 23 changed files with 369 additions and 299 deletions.
Expand Up @@ -166,17 +166,21 @@ public interface RuntimeContext {
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------


/** /**
* Returns the {@link OperatorState} of this operator instance, which can be * Returns the {@link OperatorState} with the given name of the underlying
* used to store and update user state in a fault tolerant fashion. The * operator instance, which can be used to store and update user state in a
* state will be initialized by the provided default value, and the * fault tolerant fashion. The state will be initialized by the provided
* {@link StateCheckpointer} will be used to draw the state snapshots. * default value, and the {@link StateCheckpointer} will be used to draw the
* state snapshots.
* *
* <p> * <p>
* When storing a {@link Serializable} state the user can omit the * When storing a {@link Serializable} state the user can omit the
* {@link StateCheckpointer} in which case the full state will be written as * {@link StateCheckpointer} in which case the full state will be written as
* the snapshot. * the snapshot.
* </p> * </p>
* *
* @param name
* Identifier for the state allowing that more operator states
* can be used by the same operator.
* @param defaultState * @param defaultState
* Default value for the operator state. This will be returned * Default value for the operator state. This will be returned
* the first time {@link OperatorState#getState()} (for every * the first time {@link OperatorState#getState()} (for every
Expand All @@ -185,26 +189,30 @@ public interface RuntimeContext {
* @param checkpointer * @param checkpointer
* The {@link StateCheckpointer} that will be used to draw * The {@link StateCheckpointer} that will be used to draw
* snapshots from the user state. * snapshots from the user state.
* @return The {@link OperatorState} for this instance. * @return The {@link OperatorState} for the underlying operator.
*/ */
<S,C extends Serializable> OperatorState<S> getOperatorState(S defaultState, StateCheckpointer<S,C> checkpointer); <S,C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState, StateCheckpointer<S,C> checkpointer);


/** /**
* Returns the {@link OperatorState} of this operator instance, which can be * Returns the {@link OperatorState} with the given name of the underlying
* used to store and update user state in a fault tolerant fashion. The * operator instance, which can be used to store and update user state in a
* state will be initialized by the provided default value. * fault tolerant fashion. The state will be initialized by the provided
* default value.
* *
* <p> * <p>
* When storing a non-{@link Serializable} state the user needs to specify a * When storing a non-{@link Serializable} state the user needs to specify a
* {@link StateCheckpointer} for drawing snapshots. * {@link StateCheckpointer} for drawing snapshots.
* </p> * </p>
* *
* @param name
* Identifier for the state allowing that more operator states can be
* used by the same operator.
* @param defaultState * @param defaultState
* Default value for the operator state. This will be returned * Default value for the operator state. This will be returned
* the first time {@link OperatorState#getState()} (for every * the first time {@link OperatorState#getState()} (for every
* state partition) is called before * state partition) is called before
* {@link OperatorState#updateState(Object)}. * {@link OperatorState#updateState(Object)}.
* @return The {@link OperatorState} for this instance. * @return The {@link OperatorState} for the underlying operator.
*/ */
<S extends Serializable> OperatorState<S> getOperatorState(S defaultState); <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState);
} }
Expand Up @@ -174,12 +174,12 @@ private <V, A extends Serializable> Accumulator<V, A> getAccumulator(String name
} }


@Override @Override
public <S, C extends Serializable> OperatorState<S> getOperatorState(S defaultState, StateCheckpointer<S, C> checkpointer) { public <S, C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState, StateCheckpointer<S, C> checkpointer) {
throw new UnsupportedOperationException("Operator state is only accessible for streaming operators."); throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
} }


@Override @Override
public <S extends Serializable> OperatorState<S> getOperatorState(S defaultState) { public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState) {
throw new UnsupportedOperationException("Operator state is only accessible for streaming operators."); throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
} }
} }
Expand Up @@ -21,7 +21,7 @@
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapFunction;


/** /**
* Base class for all streaming operator states. It can represent both * Base interface for all streaming operator states. It can represent both
* partitioned (when state partitioning is defined in the program) or * partitioned (when state partitioning is defined in the program) or
* non-partitioned user states. * non-partitioned user states.
* *
Expand All @@ -30,7 +30,7 @@
* transformation call the operator represents, for instance inside * transformation call the operator represents, for instance inside
* {@link MapFunction#map()} and invalid in * {@link MapFunction#map()} and invalid in
* {@link #open(org.apache.flink.configuration.Configuration)} or * {@link #open(org.apache.flink.configuration.Configuration)} or
* {@link #close()}. * {@link #close()} methods.
* *
* @param <T> * @param <T>
* Type of the operator state * Type of the operator state
Expand All @@ -42,13 +42,7 @@ public interface OperatorState<T> {
* partitioned the returned state is the same for all inputs. If state * partitioned the returned state is the same for all inputs. If state
* partitioning is applied the state returned depends on the current * partitioning is applied the state returned depends on the current
* operator input, as the operator maintains an independent state for each * operator input, as the operator maintains an independent state for each
* partitions. * partition.
*
* <p>
* {@link #getState()} returns <code>null</code> if there is no state stored
* in the operator. This is the expected behaviour before initializing the
* state with {@link #updateState(T)}.
* </p>
* *
* @return The operator state corresponding to the current input. * @return The operator state corresponding to the current input.
*/ */
Expand All @@ -60,7 +54,7 @@ public interface OperatorState<T> {
* partition) the returned state will represent the updated value. * partition) the returned state will represent the updated value.
* *
* @param state * @param state
* The updated state. * The new state.
*/ */
void updateState(T state); void updateState(T state);


Expand Down
Expand Up @@ -21,6 +21,8 @@
import java.io.Serializable; import java.io.Serializable;
import java.util.Map; import java.util.Map;


import org.apache.flink.api.common.state.StateCheckpointer;

/** /**
* Interface for storing and accessing partitioned state. The interface is * Interface for storing and accessing partitioned state. The interface is
* designed in a way that allows implementations for lazily state access. * designed in a way that allows implementations for lazily state access.
Expand All @@ -43,5 +45,7 @@ public interface PartitionedStateStore<S, C extends Serializable> {
void restoreStates(Map<Serializable, StateHandle<C>> snapshots) throws Exception; void restoreStates(Map<Serializable, StateHandle<C>> snapshots) throws Exception;


boolean containsKey(Serializable key); boolean containsKey(Serializable key);

void setCheckPointer(StateCheckpointer<S, C> checkpointer);


} }
Expand Up @@ -17,8 +17,6 @@


package org.apache.flink.streaming.api.datastream; package org.apache.flink.streaming.api.datastream;


import org.apache.flink.streaming.api.collector.selector.OutputSelector;

/** /**
* The iterative data stream represents the start of an iteration in a * The iterative data stream represents the start of an iteration in a
* {@link DataStream}. * {@link DataStream}.
Expand Down
Expand Up @@ -79,4 +79,9 @@ public boolean isInputCopyingDisabled() {
public void disableInputCopy() { public void disableInputCopy() {
this.inputCopyDisabled = true; this.inputCopyDisabled = true;
} }

@Override
public StreamingRuntimeContext getRuntimeContext(){
return runtimeContext;
}
} }
Expand Up @@ -19,7 +19,9 @@
package org.apache.flink.streaming.api.operators; package org.apache.flink.streaming.api.operators;


import java.io.Serializable; import java.io.Serializable;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;


import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.functions.util.FunctionUtils;
Expand Down Expand Up @@ -67,24 +69,32 @@ public void close() throws Exception {
FunctionUtils.closeFunction(userFunction); FunctionUtils.closeFunction(userFunction);
} }


@SuppressWarnings("unchecked") @SuppressWarnings({ "unchecked", "rawtypes" })
public void restoreInitialState(Serializable state) throws Exception { public void restoreInitialState(Serializable state) throws Exception {


Map<Serializable, StateHandle<Serializable>> snapshots = (Map<Serializable, StateHandle<Serializable>>) state; Map<String, Map<Serializable, StateHandle<Serializable>>> snapshots = (Map<String, Map<Serializable, StateHandle<Serializable>>>) state;


StreamOperatorState<?, Serializable> operatorState = (StreamOperatorState<?, Serializable>) runtimeContext Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
.getOperatorState();

for (Entry<String, Map<Serializable, StateHandle<Serializable>>> snapshot : snapshots.entrySet()) {
operatorState.restoreState(snapshots); StreamOperatorState restoredState = runtimeContext.createRawState();
restoredState.restoreState(snapshot.getValue());
operatorStates.put(snapshot.getKey(), restoredState);
}


} }


public Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp) @SuppressWarnings({ "rawtypes", "unchecked" })
throws Exception { public Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception {


StreamOperatorState<?,?> operatorState = (StreamOperatorState<?,?>) runtimeContext.getOperatorState(); Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();

Map<String, Map<Serializable, StateHandle<Serializable>>> snapshots = new HashMap<String, Map<Serializable, StateHandle<Serializable>>>();
return (Serializable) operatorState.snapshotState(checkpointId, timestamp);
for (Entry<String, StreamOperatorState> state : operatorStates.entrySet()) {
snapshots.put(state.getKey(), state.getValue().snapshotState(checkpointId, timestamp));
}

return (Serializable) snapshots;
} }


public void confirmCheckpointCompleted(long checkpointId, long timestamp, public void confirmCheckpointCompleted(long checkpointId, long timestamp,
Expand Down
Expand Up @@ -48,6 +48,8 @@ public interface StreamOperator<OUT> extends Serializable {
* This method is called after no more elements for can arrive for processing. * This method is called after no more elements for can arrive for processing.
*/ */
public void close() throws Exception; public void close() throws Exception;

public StreamingRuntimeContext getRuntimeContext();


/** /**
* An operator can return true here to disable copying of its input elements. This overrides * An operator can return true here to disable copying of its input elements. This overrides
Expand Down
Expand Up @@ -30,10 +30,10 @@


public class EagerStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> { public class EagerStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {


protected StateCheckpointer<S, C> checkpointer; private StateCheckpointer<S, C> checkpointer;
protected StateHandleProvider<C> provider; private final StateHandleProvider<C> provider;


Map<Serializable, S> fetchedState; private Map<Serializable, S> fetchedState;


public EagerStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) { public EagerStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
this.checkpointer = checkpointer; this.checkpointer = checkpointer;
Expand Down Expand Up @@ -83,4 +83,9 @@ public boolean containsKey(Serializable key) {
return fetchedState.containsKey(key); return fetchedState.containsKey(key);
} }


@Override
public void setCheckPointer(StateCheckpointer<S, C> checkpointer) {
this.checkpointer = checkpointer;
}

} }
Expand Up @@ -47,10 +47,10 @@
public class LazyStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> { public class LazyStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {


protected StateCheckpointer<S, C> checkpointer; protected StateCheckpointer<S, C> checkpointer;
protected StateHandleProvider<C> provider; protected final StateHandleProvider<C> provider;


Map<Serializable, StateHandle<C>> unfetchedState; private final Map<Serializable, StateHandle<C>> unfetchedState;
Map<Serializable, S> fetchedState; private final Map<Serializable, S> fetchedState;


public LazyStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) { public LazyStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
this.checkpointer = checkpointer; this.checkpointer = checkpointer;
Expand Down Expand Up @@ -114,4 +114,9 @@ public boolean containsKey(Serializable key) {
return fetchedState.containsKey(key) || unfetchedState.containsKey(key); return fetchedState.containsKey(key) || unfetchedState.containsKey(key);
} }


@Override
public void setCheckPointer(StateCheckpointer<S, C> checkpointer) {
this.checkpointer = checkpointer;
}

} }
Expand Up @@ -45,9 +45,9 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
StreamOperatorState<S, C> { StreamOperatorState<S, C> {


// KeySelector for getting the state partition key for each input // KeySelector for getting the state partition key for each input
private KeySelector<IN, Serializable> keySelector; private final KeySelector<IN, Serializable> keySelector;


private PartitionedStateStore<S, C> stateStore; private final PartitionedStateStore<S, C> stateStore;


private S defaultState; private S defaultState;


Expand Down
Expand Up @@ -44,7 +44,7 @@ public class StreamOperatorState<S, C extends Serializable> implements OperatorS


private S state; private S state;
private StateCheckpointer<S, C> checkpointer; private StateCheckpointer<S, C> checkpointer;
private StateHandleProvider<C> provider; private final StateHandleProvider<C> provider;


public StreamOperatorState(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) { public StreamOperatorState(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
this.checkpointer = checkpointer; this.checkpointer = checkpointer;
Expand All @@ -66,8 +66,11 @@ public void updateState(S state) {
this.state = state; this.state = state;
} }


public void setDefaultState(S defaultState){ public void setDefaultState(S defaultState) {
updateState(defaultState); // reconsider this as it might cause issues when setting the state to null
if (getState() == null) {
updateState(defaultState);
}
} }


public StateCheckpointer<S, C> getCheckpointer() { public StateCheckpointer<S, C> getCheckpointer() {
Expand Down
Expand Up @@ -99,6 +99,7 @@ public void invoke() throws Exception {


StreamRecord<IN> nextRecord; StreamRecord<IN> nextRecord;
while (isRunning && (nextRecord = readNext()) != null) { while (isRunning && (nextRecord = readNext()) != null) {
headContext.setNextInput(nextRecord);
streamOperator.processElement(nextRecord.getObject()); streamOperator.processElement(nextRecord.getObject());
} }


Expand Down
Expand Up @@ -226,16 +226,18 @@ public void clearWriters() {
} }


private static class OperatorCollector<T> implements Output<T> { private static class OperatorCollector<T> implements Output<T> {
protected OneInputStreamOperator operator;


public OperatorCollector(OneInputStreamOperator<?, T> operator) { protected OneInputStreamOperator<Object, T> operator;

public OperatorCollector(OneInputStreamOperator<Object, T> operator) {
this.operator = operator; this.operator = operator;
} }


@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void collect(T record) { public void collect(T record) {
try { try {
operator.getRuntimeContext().setNextInput(record);
operator.processElement(record); operator.processElement(record);
} catch (Exception e) { } catch (Exception e) {
if (LOG.isErrorEnabled()) { if (LOG.isErrorEnabled()) {
Expand Down

0 comments on commit ef11e63

Please sign in to comment.