Skip to content

Commit

Permalink
[FLINK-2664] [streaming] Allow partitioned state removal
Browse files Browse the repository at this point in the history
Closes #1126
  • Loading branch information
gyfora committed Sep 14, 2015
1 parent ce68cbd commit 8a75025
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 50 deletions.
Expand Up @@ -55,7 +55,9 @@ public interface OperatorState<T> {
/**
* Updates the operator state accessible by {@link #value()} to the given
* value. The next time {@link #value()} is called (for the same state
* partition) the returned state will represent the updated value.
* partition) the returned state will represent the updated value. When a
* partitioned state is updated with null, the state for the current key
* will be removed and the default value is returned on the next access.
*
* @param state
* The new value for the state.
Expand Down
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.streaming.api.state;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -29,7 +30,7 @@

public class EagerStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {

private StateCheckpointer<S,C> checkpointer;
private StateCheckpointer<S, C> checkpointer;
private final StateHandleProvider<Serializable> provider;

private Map<Serializable, S> fetchedState;
Expand All @@ -43,7 +44,7 @@ public EagerStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider
}

@Override
public S getStateForKey(Serializable key) throws Exception {
public S getStateForKey(Serializable key) throws IOException {
return fetchedState.get(key);
}

Expand All @@ -53,7 +54,12 @@ public void setStateForKey(Serializable key, S state) {
}

@Override
public Map<Serializable, S> getPartitionedState() throws Exception {
/**
 * Removes the entry for the given key from the {@code fetchedState} map.
 * Removing a key that has no entry leaves the map unchanged.
 *
 * @param key
 *            The partition key whose state should be dropped.
 */
public void removeStateForKey(Serializable key) {
fetchedState.remove(key);
}

/**
 * Returns the map holding the state of every key currently kept by this
 * store. The backing {@code fetchedState} map itself is returned, not a
 * copy, so changes made through the returned reference affect the store.
 *
 * @return The key-to-state map of this store.
 */
@Override
public Map<Serializable, S> getPartitionedState() throws IOException {
return fetchedState;
}

Expand All @@ -69,11 +75,12 @@ public StateHandle<Serializable> snapshotStates(long checkpointId, long checkpoi
}

@Override
public void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {

public void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader)
throws Exception {

@SuppressWarnings("unchecked")
Map<Serializable, C> checkpoints = (Map<Serializable, C>) snapshot.getState(userCodeClassLoader);

// we map the values back to the state from the checkpoints
for (Entry<Serializable, C> snapshotEntry : checkpoints.entrySet()) {
fetchedState.put(snapshotEntry.getKey(), (S) checkpointer.restoreState(snapshotEntry.getValue()));
Expand All @@ -89,10 +96,9 @@ public boolean containsKey(Serializable key) {
/**
 * Replaces the checkpointer held by this store.
 *
 * @param checkpointer
 *            The checkpointer to use from now on.
 */
public void setCheckPointer(StateCheckpointer<S, C> checkpointer) {
this.checkpointer = checkpointer;
}

/**
 * Returns the string representation of the {@code fetchedState} map.
 */
@Override
public String toString() {
return fetchedState.toString();
}

}
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.streaming.api.state;

import java.io.IOException;
import java.io.Serializable;
import java.util.Map;

Expand All @@ -35,13 +36,15 @@
*/
public interface PartitionedStateStore<S, C extends Serializable> {

S getStateForKey(Serializable key) throws Exception;
S getStateForKey(Serializable key) throws IOException;

void setStateForKey(Serializable key, S state);

void removeStateForKey(Serializable key);

Map<Serializable, S> getPartitionedState() throws Exception;
Map<Serializable, S> getPartitionedState() throws IOException;

StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws Exception;
StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws IOException;

void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception;

Expand Down
Expand Up @@ -42,8 +42,7 @@
* @param <C>
* Type of the state snapshot.
*/
public class PartitionedStreamOperatorState<IN, S, C extends Serializable> extends
StreamOperatorState<S, C> {
public class PartitionedStreamOperatorState<IN, S, C extends Serializable> extends StreamOperatorState<S, C> {

// KeySelector for getting the state partition key for each input
private final KeySelector<IN, Serializable> keySelector;
Expand Down Expand Up @@ -77,41 +76,50 @@ public S value() throws IOException {
if (currentInput == null) {
throw new IllegalStateException("Need a valid input for accessing the state.");
} else {
Serializable key;
try {
Serializable key = keySelector.getKey(currentInput);
if (stateStore.containsKey(key)) {
return stateStore.getStateForKey(key);
} else {
key = keySelector.getKey(currentInput);
} catch (Exception e) {
throw new RuntimeException("User-defined key selector threw an exception.", e);
}
if (stateStore.containsKey(key)) {
return stateStore.getStateForKey(key);
} else {
try {
return (S) checkpointer.restoreState((C) InstantiationUtil.deserializeObject(
defaultState, cl));
} catch (ClassNotFoundException e) {
throw new RuntimeException("Could not deserialize default state value.", e);
}
} catch (Exception e) {
throw new RuntimeException("User-defined key selector threw an exception.", e);
}
}
}

/**
 * Updates the state of the partition that the current input belongs to.
 * <p>
 * The partition key is extracted from the current input with the key
 * selector. A non-null {@code state} is stored for that key; a
 * {@code null} state removes the entry for that key from the state store.
 *
 * @param state
 *            The new value for the state, or {@code null} to remove the
 *            state of the current key.
 * @throws IllegalStateException
 *             If no current input has been set.
 * @throws RuntimeException
 *             If the user-defined key selector fails; the original
 *             exception is attached as the cause.
 */
@Override
public void update(S state) throws IOException {
    if (currentInput == null) {
        throw new IllegalStateException("Need a valid input for updating a state.");
    }

    // Extract the key exactly once so the user function is not invoked twice.
    final Serializable key;
    try {
        key = keySelector.getKey(currentInput);
    } catch (Exception e) {
        // Keep the cause so failures in user code remain diagnosable.
        throw new RuntimeException("User-defined key selector threw an exception.", e);
    }

    if (state == null) {
        // Remove state if set to null
        stateStore.removeStateForKey(key);
    } else {
        stateStore.setStateForKey(key, state);
    }
}

@Override
public void setDefaultState(S defaultState) {
try {
this.defaultState = InstantiationUtil.serializeObject(checkpointer.snapshotState(
defaultState, 0, 0));
this.defaultState = InstantiationUtil.serializeObject(checkpointer.snapshotState(defaultState, 0, 0));
} catch (IOException e) {
throw new RuntimeException("Default state must be serializable.");
}
Expand All @@ -122,8 +130,7 @@ public void setCurrentInput(IN input) {
}

@Override
public StateHandle<Serializable> snapshotState(long checkpointId,
long checkpointTimestamp) throws Exception {
public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return stateStore.snapshotStates(checkpointId, checkpointTimestamp);
}

Expand Down

0 comments on commit 8a75025

Please sign in to comment.