Skip to content

Commit

Permalink
Hide broadcast state / remove from public API
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanRRichter authored and aljoscha committed Jan 13, 2017
1 parent 1020ba2 commit 6a86e9d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 35 deletions.
Expand Up @@ -56,33 +56,6 @@ public interface OperatorStateStore {
*/ */
<T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception; <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception;


/**
* Creates (or restores) a list state. Each state is registered under a unique name.
* The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
*
* On restore, all items in the list are broadcasted to all parallel operator instances.
*
* @param stateDescriptor The descriptor for this state, providing a name and serializer.
* @param <S> The generic type of the state
*
* @return A list for all state partitions.
* @throws Exception
*/
<S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;

/**
* Creates a state of the given name that uses Java serialization to persist the state. On restore, all items
* in the list are broadcasted to all parallel operator instances.
*
* <p>This is a simple convenience method. For more flexibility on how state serialization
* should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method.
*
* @param stateName The name of state to create
* @return A list state using Java serialization to serialize state objects.
* @throws Exception
*/
<T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception;

/** /**
* Returns a set with the names of all currently registered states. * Returns a set with the names of all currently registered states.
* @return set of names for all registered states. * @return set of names for all registered states.
Expand Down
Expand Up @@ -91,12 +91,10 @@ public <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor)
} }


@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override
public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception { public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception {
return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer)); return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
} }


@Override
public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception { public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception {
return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST); return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST);
} }
Expand Down
Expand Up @@ -45,8 +45,9 @@ static Environment createMockEnvironment() {
return env; return env;
} }


private OperatorStateBackend createNewOperatorStateBackend() throws Exception { private DefaultOperatorStateBackend createNewOperatorStateBackend() throws Exception {
return abstractStateBackend.createOperatorStateBackend( //TODO this is temporarily casted to test already functionality that we do not yet expose through public API
return (DefaultOperatorStateBackend) abstractStateBackend.createOperatorStateBackend(
createMockEnvironment(), createMockEnvironment(),
"test-operator"); "test-operator");
} }
Expand All @@ -60,7 +61,7 @@ public void testCreateNew() throws Exception {


@Test @Test
public void testRegisterStates() throws Exception { public void testRegisterStates() throws Exception {
OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); DefaultOperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
Expand Down Expand Up @@ -143,7 +144,7 @@ public void testRegisterStates() throws Exception {


@Test @Test
public void testSnapshotRestore() throws Exception { public void testSnapshotRestore() throws Exception {
OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); DefaultOperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
Expand Down Expand Up @@ -171,7 +172,8 @@ public void testSnapshotRestore() throws Exception {
operatorStateBackend.close(); operatorStateBackend.close();
operatorStateBackend.dispose(); operatorStateBackend.dispose();


operatorStateBackend = abstractStateBackend.createOperatorStateBackend( //TODO this is temporarily casted to test already functionality that we do not yet expose through public API
operatorStateBackend = (DefaultOperatorStateBackend) abstractStateBackend.createOperatorStateBackend(
createMockEnvironment(), createMockEnvironment(),
"testOperator"); "testOperator");


Expand Down
Expand Up @@ -34,6 +34,7 @@
import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings; import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
Expand Down Expand Up @@ -969,8 +970,10 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception {
public void initializeState(FunctionInitializationContext context) throws Exception { public void initializeState(FunctionInitializationContext context) throws Exception {


if (broadcast) { if (broadcast) {
//TODO this is temporarily casted to test already functionality that we do not yet expose through public API
DefaultOperatorStateBackend operatorStateStore = (DefaultOperatorStateBackend) context.getOperatorStateStore();
this.counterPartitions = this.counterPartitions =
context.getOperatorStateStore().getBroadcastSerializableListState("counter_partitions"); operatorStateStore.getBroadcastSerializableListState("counter_partitions");
} else { } else {
this.counterPartitions = this.counterPartitions =
context.getOperatorStateStore().getSerializableListState("counter_partitions"); context.getOperatorStateStore().getSerializableListState("counter_partitions");
Expand Down

0 comments on commit 6a86e9d

Please sign in to comment.