[FLINK-21167] Make StateTable snapshots iterable
In order to implement an iterator required by a unified binary savepoint, we need a way to iterate over a snapshot.
dawidwys committed Jan 28, 2021
1 parent 8ae563a commit 351d6e9
Showing 13 changed files with 582 additions and 39 deletions.
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.state;

import org.apache.flink.annotation.Internal;

import java.util.Iterator;

/**
* A {@link StateSnapshot} that can return an iterator over all contained {@link StateEntry
* StateEntries}.
*/
@Internal
public interface IterableStateSnapshot<K, N, S> extends StateSnapshot {
Iterator<StateEntry<K, N, S>> getIterator(int keyGroup);
}
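
A hedged, caller-side sketch (not code from this commit; the helper class, method name and key-group bounds are invented for illustration) of how a consumer such as a unified savepoint writer could drain an IterableStateSnapshot key-group by key-group:

import org.apache.flink.runtime.state.IterableStateSnapshot;
import org.apache.flink.runtime.state.StateEntry;

import java.util.Iterator;

class IterableSnapshotExample {

    // Hypothetical consumer: walks every key-group of the snapshot and reads the
    // entries as objects instead of going through the key-group writer's serialized output.
    static <K, N, S> void drain(
            IterableStateSnapshot<K, N, S> snapshot, int firstKeyGroup, int lastKeyGroup) {
        for (int keyGroup = firstKeyGroup; keyGroup <= lastKeyGroup; keyGroup++) {
            Iterator<StateEntry<K, N, S>> entries = snapshot.getIterator(keyGroup);
            while (entries.hasNext()) {
                StateEntry<K, N, S> entry = entries.next();
                System.out.println(
                        entry.getKey() + " / " + entry.getNamespace() + " -> " + entry.getState());
            }
        }
    }
}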
@@ -36,6 +36,15 @@ public interface StateEntry<K, N, S> {
/** Returns the state of this entry. */
S getState();

default StateEntry<K, N, S> filterOrTransform(StateSnapshotTransformer<S> transformer) {
S newState = transformer.filterOrTransform(getState());
if (newState != null) {
return new SimpleStateEntry<>(getKey(), getNamespace(), newState);
} else {
return null;
}
}

class SimpleStateEntry<K, N, S> implements StateEntry<K, N, S> {
private final K key;
private final N namespace;
@@ -21,6 +21,8 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.IterableStateSnapshot;
import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateSnapshot;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
@@ -30,14 +32,15 @@
import javax.annotation.Nullable;

import java.io.IOException;
import java.util.Iterator;

/**
* Abstract base class for snapshots of a {@link StateTable}. Offers a way to serialize the snapshot
* (by key-group). All snapshots should be released after usage.
*/
@Internal
abstract class AbstractStateTableSnapshot<K, N, S>
implements StateSnapshot, StateSnapshot.StateKeyGroupWriter {
implements IterableStateSnapshot<K, N, S>, StateSnapshot.StateKeyGroupWriter {

/** The {@link StateTable} from which this snapshot was created. */
protected final StateTable<K, N, S> owningStateTable;
@@ -88,6 +91,17 @@ public StateKeyGroupWriter getKeyGroupWriter() {
return this;
}

@Override
public Iterator<StateEntry<K, N, S>> getIterator(int keyGroupId) {
StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> stateMapSnapshot =
getStateMapSnapshotForKeyGroup(keyGroupId);
return stateMapSnapshot.getIterator(
localKeySerializer,
localNamespaceSerializer,
localStateSerializer,
stateSnapshotTransformer);
}

/**
* Implementation note: we currently chose the same format between {@link NestedMapsStateTable}
* and {@link CopyOnWriteStateTable}.
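The Javadoc of AbstractStateTableSnapshot above notes that snapshots are written per key-group and must be released after usage. A minimal sketch of that lifecycle (assumed usage, not part of this commit; the helper name is invented), combining the new per-key-group iteration with the release() that StateSnapshot already provides:

import org.apache.flink.runtime.state.IterableStateSnapshot;
import org.apache.flink.runtime.state.StateEntry;

import java.util.Iterator;

class SnapshotReleaseExample {

    // Hypothetical consumer: counts the entries of a single key-group and releases
    // the snapshot afterwards, even if the iteration fails.
    static <K, N, S> int countEntries(IterableStateSnapshot<K, N, S> snapshot, int keyGroup) {
        try {
            int count = 0;
            Iterator<StateEntry<K, N, S>> entries = snapshot.getIterator(keyGroup);
            while (entries.hasNext()) {
                entries.next();
                count++;
            }
            return count;
        } finally {
            snapshot.release();
        }
    }
}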
@@ -110,6 +110,19 @@ int getSnapshotVersion() {
return snapshotVersion;
}

@Override
public SnapshotIterator<K, N, S> getIterator(
@Nonnull TypeSerializer<K> keySerializer,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull TypeSerializer<S> stateSerializer,
@Nullable final StateSnapshotTransformer<S> stateSnapshotTransformer) {

return stateSnapshotTransformer == null
? new NonTransformSnapshotIterator<>(numberOfEntriesInSnapshotData, snapshotData)
: new TransformedSnapshotIterator<>(
numberOfEntriesInSnapshotData, snapshotData, stateSnapshotTransformer);
}

@Override
public void writeState(
TypeSerializer<K> keySerializer,
@@ -119,13 +132,11 @@ public void writeState(
@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer)
throws IOException {
SnapshotIterator<K, N, S> snapshotIterator =
stateSnapshotTransformer == null
? new NonTransformSnapshotIterator<>(
numberOfEntriesInSnapshotData, snapshotData)
: new TransformedSnapshotIterator<>(
numberOfEntriesInSnapshotData,
snapshotData,
stateSnapshotTransformer);
getIterator(
keySerializer,
namespaceSerializer,
stateSerializer,
stateSnapshotTransformer);

int size = snapshotIterator.size();
dov.writeInt(size);
@@ -20,14 +20,19 @@

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateSnapshotTransformer;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

/**
* This class represents the snapshot of a {@link NestedStateMap}.
@@ -48,6 +53,24 @@ public NestedStateMapSnapshot(NestedStateMap<K, N, S> owningStateMap) {
super(owningStateMap);
}

@Override
public Iterator<StateEntry<K, N, S>> getIterator(
@Nonnull TypeSerializer<K> keySerializer,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull TypeSerializer<S> stateSerializer,
@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
if (stateSnapshotTransformer == null) {
return owningStateMap.iterator();
} else {
return StreamSupport.stream(
Spliterators.spliteratorUnknownSize(owningStateMap.iterator(), 0),
false)
.map(entry -> entry.filterOrTransform(stateSnapshotTransformer))
.filter(Objects::nonNull)
.iterator();
}
}

@Override
public void writeState(
TypeSerializer<K> keySerializer,
@@ -20,13 +20,15 @@

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.io.IOException;
import java.util.Iterator;

/**
* Base class for snapshots of a {@link StateMap}.
@@ -52,6 +54,12 @@ public boolean isOwner(T stateMap) {
/** Release the snapshot. */
public void release() {}

public abstract Iterator<StateEntry<K, N, S>> getIterator(
@Nonnull TypeSerializer<K> keySerializer,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull TypeSerializer<S> stateSerializer,
@Nullable final StateSnapshotTransformer<S> stateSnapshotTransformer);

/**
* Writes the state in this snapshot to output. The state needs to be transformed with the given
* transformer if the transformer is non-null.
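The writeState contract above treats the transformer as optional: when it is non-null, every entry is passed through it, and a null result means the entry is dropped. The new StateEntry#filterOrTransform default method (added earlier in this commit) encodes exactly that rule. A small sketch with made-up values (not taken from this commit) showing both outcomes:

import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateSnapshotTransformer;

import javax.annotation.Nullable;

class FilterOrTransformExample {

    // Hypothetical transformer: drops empty strings and upper-cases everything else.
    static final StateSnapshotTransformer<String> UPPERCASE_NON_EMPTY =
            new StateSnapshotTransformer<String>() {
                @Nullable
                @Override
                public String filterOrTransform(@Nullable String value) {
                    return (value == null || value.isEmpty()) ? null : value.toUpperCase();
                }
            };

    static void demo() {
        // Non-null result from the transformer -> a new entry wrapping "STATE".
        StateEntry<Integer, Integer, String> transformed =
                entryOf(1, 1, "state").filterOrTransform(UPPERCASE_NON_EMPTY);
        System.out.println(transformed.getState()); // STATE

        // Null result from the transformer -> the whole entry is filtered out.
        StateEntry<Integer, Integer, String> dropped =
                entryOf(2, 1, "").filterOrTransform(UPPERCASE_NON_EMPTY);
        System.out.println(dropped); // null
    }

    // Helper for the demo; avoids depending on any concrete StateEntry implementation.
    static <K, N, S> StateEntry<K, N, S> entryOf(K key, N namespace, S state) {
        return new StateEntry<K, N, S>() {
            @Override
            public K getKey() {
                return key;
            }

            @Override
            public N getNamespace() {
                return namespace;
            }

            @Override
            public S getState() {
                return state;
            }
        };
    }
}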
@@ -19,10 +19,13 @@
package org.apache.flink.runtime.state.heap;

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.ListSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.runtime.state.ArrayListSerializer;
import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
import org.apache.flink.util.TestLogger;
@@ -31,13 +34,22 @@
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.apache.flink.runtime.state.testutils.StateEntryMatcher.entry;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertThat;

/** Test for {@link CopyOnWriteStateMap}. */
public class CopyOnWriteStateMapTest extends TestLogger {

@@ -398,6 +410,82 @@ public void testCopyOnWriteContracts() {
Assert.assertSame(originalState5, stateMap.get(5, 1));
}

@Test
public void testIteratingOverSnapshot() {
ListSerializer<Integer> stateSerializer = new ListSerializer<>(IntSerializer.INSTANCE);
final CopyOnWriteStateMap<Integer, Integer, List<Integer>> stateMap =
new CopyOnWriteStateMap<>(stateSerializer);

List<Integer> originalState1 = new ArrayList<>(1);
List<Integer> originalState2 = new ArrayList<>(1);
List<Integer> originalState3 = new ArrayList<>(1);
List<Integer> originalState4 = new ArrayList<>(1);
List<Integer> originalState5 = new ArrayList<>(1);

originalState1.add(1);
originalState2.add(2);
originalState3.add(3);
originalState4.add(4);
originalState5.add(5);

stateMap.put(1, 1, originalState1);
stateMap.put(2, 1, originalState2);
stateMap.put(3, 1, originalState3);
stateMap.put(4, 1, originalState4);
stateMap.put(5, 1, originalState5);

CopyOnWriteStateMapSnapshot<Integer, Integer, List<Integer>> snapshot =
stateMap.stateSnapshot();

Iterator<StateEntry<Integer, Integer, List<Integer>>> iterator =
snapshot.getIterator(
IntSerializer.INSTANCE, IntSerializer.INSTANCE, stateSerializer, null);
assertThat(
() -> iterator,
containsInAnyOrder(
entry(1, 1, originalState1),
entry(2, 1, originalState2),
entry(3, 1, originalState3),
entry(4, 1, originalState4),
entry(5, 1, originalState5)));
}

@Test
public void testIteratingOverSnapshotWithTransform() {
final CopyOnWriteStateMap<Integer, Integer, Long> stateMap =
new CopyOnWriteStateMap<>(LongSerializer.INSTANCE);

stateMap.put(1, 1, 10L);
stateMap.put(2, 1, 11L);
stateMap.put(3, 1, 12L);
stateMap.put(4, 1, 13L);
stateMap.put(5, 1, 14L);

StateMapSnapshot<Integer, Integer, Long, ? extends StateMap<Integer, Integer, Long>>
snapshot = stateMap.stateSnapshot();

Iterator<StateEntry<Integer, Integer, Long>> iterator =
snapshot.getIterator(
IntSerializer.INSTANCE,
IntSerializer.INSTANCE,
LongSerializer.INSTANCE,
new StateSnapshotTransformer<Long>() {
@Nullable
@Override
public Long filterOrTransform(@Nullable Long value) {
if (value == 12L) {
return null;
} else {
return value + 2L;
}
}
});
assertThat(
() -> iterator,
containsInAnyOrder(
entry(1, 1, 12L), entry(2, 1, 13L), entry(4, 1, 15L), entry(5, 1, 16L)));
}

/** This tests that snapshot can be released correctly. */
@Test
public void testSnapshotRelease() {
@@ -410,16 +498,15 @@ public void testSnapshotRelease() {

CopyOnWriteStateMapSnapshot<Integer, Integer, Integer> snapshot = stateMap.stateSnapshot();
Assert.assertFalse(snapshot.isReleased());
Assert.assertThat(
stateMap.getSnapshotVersions(), Matchers.contains(snapshot.getSnapshotVersion()));
assertThat(stateMap.getSnapshotVersions(), contains(snapshot.getSnapshotVersion()));

snapshot.release();
Assert.assertTrue(snapshot.isReleased());
Assert.assertThat(stateMap.getSnapshotVersions(), Matchers.empty());
assertThat(stateMap.getSnapshotVersions(), Matchers.empty());

// verify that snapshot will release itself only once
snapshot.release();
Assert.assertThat(stateMap.getSnapshotVersions(), Matchers.empty());
assertThat(stateMap.getSnapshotVersions(), Matchers.empty());
}

@SuppressWarnings("unchecked")
