diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index ad67171a01c87..3c5756bd22af4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -3840,6 +3840,69 @@ public String fold(String acc, Integer value) throws Exception { } } + @Test + public void testMapStateGetKeys() throws Exception { + final int namespace1ElementsNum = 1000; + final int namespace2ElementsNum = 1000; + String fieldName = "get-keys-test"; + AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + try { + final String ns1 = "ns1"; + MapState keyedState1 = backend.getPartitionedState( + ns1, + StringSerializer.INSTANCE, + new MapStateDescriptor<>(fieldName, StringSerializer.INSTANCE, IntSerializer.INSTANCE) + ); + + for (int key = 0; key < namespace1ElementsNum; key++) { + backend.setCurrentKey(key); + keyedState1.put("he", key * 2); + keyedState1.put("ho", key * 2); + } + + final String ns2 = "ns2"; + MapState keyedState2 = backend.getPartitionedState( + ns2, + StringSerializer.INSTANCE, + new MapStateDescriptor<>(fieldName, StringSerializer.INSTANCE, IntSerializer.INSTANCE) + ); + + for (int key = namespace1ElementsNum; key < namespace1ElementsNum + namespace2ElementsNum; key++) { + backend.setCurrentKey(key); + keyedState2.put("he", key * 2); + keyedState2.put("ho", key * 2); + } + + // valid for namespace1 + try (Stream keysStream = backend.getKeys(fieldName, ns1).sorted()) { + PrimitiveIterator.OfInt actualIterator = keysStream.mapToInt(value -> value.intValue()).iterator(); + + for (int expectedKey = 0; expectedKey < namespace1ElementsNum; expectedKey++) { + assertTrue(actualIterator.hasNext()); + assertEquals(expectedKey, actualIterator.nextInt()); + } + + assertFalse(actualIterator.hasNext()); + } + + // valid for namespace2 + try (Stream keysStream = backend.getKeys(fieldName, ns2).sorted()) { + PrimitiveIterator.OfInt actualIterator = keysStream.mapToInt(value -> value.intValue()).iterator(); + + for (int expectedKey = namespace1ElementsNum; expectedKey < namespace1ElementsNum + namespace2ElementsNum; expectedKey++) { + assertTrue(actualIterator.hasNext()); + assertEquals(expectedKey, actualIterator.nextInt()); + } + + assertFalse(actualIterator.hasNext()); + } + } + finally { + IOUtils.closeQuietly(backend); + backend.dispose(); + } + } + @Test public void testCheckConcurrencyProblemWhenPerformingCheckpointAsync() throws Exception { diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index f2430ae19df4c..4c4014ce16c71 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -1652,6 +1652,7 @@ static class RocksIteratorForKeysWrapper implements Iterator, AutoCloseabl private final byte[] namespaceBytes; private final boolean ambiguousKeyPossible; private K nextKey; + private K preKey; RocksIteratorForKeysWrapper( RocksIteratorWrapper iterator, @@ -1666,6 +1667,7 @@ static class RocksIteratorForKeysWrapper implements Iterator, AutoCloseabl this.keyGroupPrefixBytes = Preconditions.checkNotNull(keyGroupPrefixBytes); this.namespaceBytes = Preconditions.checkNotNull(namespaceBytes); this.nextKey = null; + this.preKey = null; this.ambiguousKeyPossible = ambiguousKeyPossible; } @@ -1675,15 +1677,22 @@ public boolean hasNext() { while (nextKey == null && iterator.isValid()) { byte[] key = iterator.key(); - if (isMatchingNameSpace(key)) { - ByteArrayInputStreamWithPos inputStream = - new ByteArrayInputStreamWithPos(key, keyGroupPrefixBytes, key.length - keyGroupPrefixBytes); - DataInputViewStreamWrapper dataInput = new DataInputViewStreamWrapper(inputStream); - K value = RocksDBKeySerializationUtils.readKey( - keySerializer, - inputStream, - dataInput, - ambiguousKeyPossible); + + ByteArrayInputStreamWithPos inputStream = + new ByteArrayInputStreamWithPos(key, keyGroupPrefixBytes, key.length - keyGroupPrefixBytes); + + DataInputViewStreamWrapper dataInput = new DataInputViewStreamWrapper(inputStream); + + K value = RocksDBKeySerializationUtils.readKey( + keySerializer, + inputStream, + dataInput, + ambiguousKeyPossible); + + int namespaceByteStartPos = inputStream.getPosition(); + + if (isMatchingNameSpace(key, namespaceByteStartPos) && !Objects.equals(preKey, value)) { + preKey = value; nextKey = value; } iterator.next(); @@ -1705,12 +1714,12 @@ public K next() { return tmpKey; } - private boolean isMatchingNameSpace(@Nonnull byte[] key) { + private boolean isMatchingNameSpace(@Nonnull byte[] key, int beginPos) { final int namespaceBytesLength = namespaceBytes.length; - final int basicLength = namespaceBytesLength + keyGroupPrefixBytes; + final int basicLength = namespaceBytesLength + beginPos; if (key.length >= basicLength) { - for (int i = 1; i <= namespaceBytesLength; ++i) { - if (key[key.length - i] != namespaceBytes[namespaceBytesLength - i]) { + for (int i = 0; i < namespaceBytesLength; ++i) { + if (key[beginPos + i] != namespaceBytes[i]) { return false; } }