diff --git a/flink-core/src/main/java/org/apache/flink/util/MathUtils.java b/flink-core/src/main/java/org/apache/flink/util/MathUtils.java index 074e8aea66f9c..4c52b6e6f2b3e 100644 --- a/flink-core/src/main/java/org/apache/flink/util/MathUtils.java +++ b/flink-core/src/main/java/org/apache/flink/util/MathUtils.java @@ -22,13 +22,13 @@ * Collection of simple mathematical routines. */ public final class MathUtils { - + /** * Computes the logarithm of the given value to the base of 2, rounded down. It corresponds to the * position of the highest non-zero bit. The position is counted, starting with 0 from the least * significant bit to the most significant bit. For example, log2floor(16) = 4, and * log2floor(10) = 3. - * + * * @param value The value to compute the logarithm for. * @return The logarithm (rounded down) to the base of 2. * @throws ArithmeticException Thrown, if the given value is zero. @@ -40,11 +40,11 @@ public static int log2floor(int value) throws ArithmeticException { return 31 - Integer.numberOfLeadingZeros(value); } - + /** * Computes the logarithm of the given value to the base of 2. This method throws an error, * if the given argument is not a power of 2. - * + * * @param value The value to compute the logarithm for. * @return The logarithm to the base of 2. * @throws ArithmeticException Thrown, if the given value is zero. @@ -59,25 +59,25 @@ public static int log2strict(int value) throws ArithmeticException, IllegalArgum } return 31 - Integer.numberOfLeadingZeros(value); } - + /** * Decrements the given number down to the closest power of two. If the argument is a * power of two, it remains unchanged. - * + * * @param value The value to round down. * @return The closest value that is a power of two and less or equal than the given value. */ public static int roundDownToPowerOf2(int value) { return Integer.highestOneBit(value); } - + /** * Casts the given value to a 32 bit integer, if it can be safely done. If the cast would change the numeric * value, this method raises an exception. *

* This method is a protection in places where one expects to be able to safely cast, but where unexpected * situations could make the cast unsafe and would cause hidden problems that are hard to track down. - + * @param value The value to be cast to an integer. * @return The given value as an integer. * @see Math#toIntExact(long) @@ -172,8 +172,37 @@ public static int roundUpToPowerOfTwo(int x) { return x + 1; } + /** + * Pseudo-randomly maps a long (64-bit) to an integer (32-bit) using some bit-mixing for better distribution. + * + * @param in the long (64-bit) input. + * @return the bit-mixed int (32-bit) output + */ + public static int longToIntWithBitMixing(long in) { + in = (in ^ (in >>> 30)) * 0xbf58476d1ce4e5b9L; + in = (in ^ (in >>> 27)) * 0x94d049bb133111ebL; + in = in ^ (in >>> 31); + return (int) in; + } + + /** + * Bit-mixing for pseudo-randomization of integers (e.g., to guard against bad hash functions). Implementation is + * from MurmurHash3's 32-bit finalizer. + * + * @param in the input value + * @return the bit-mixed output value + */ + public static int bitMix(int in) { + in ^= in >>> 16; + in *= 0x85ebca6b; + in ^= in >>> 13; + in *= 0xc2b2ae35; + in ^= in >>> 16; + return in; + } + // ============================================================================================ - + /** * Prevent Instantiation through private constructor. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 2daf89624f301..23c9a49c7943c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -35,6 +35,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.heap.async.AbstractHeapMergingState; +import org.apache.flink.runtime.state.heap.async.InternalKeyContext; import org.apache.flink.util.Preconditions; import java.io.Closeable; @@ -51,7 +53,7 @@ * @param Type of the key by which state is keyed. */ public abstract class AbstractKeyedStateBackend - implements KeyedStateBackend, Snapshotable, Closeable { + implements KeyedStateBackend, Snapshotable, Closeable, InternalKeyContext { /** {@link TypeSerializer} for our key. 
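For orientation, a small hedged demo of the two mixers added to MathUtils above. The `compositeHash` helper is illustrative only, not part of this patch; the copy-on-write state table later in this diff uses bit mixing for exactly this kind of key/namespace hash spreading:

```java
import org.apache.flink.util.MathUtils;

public class BitMixDemo {

	// Hypothetical helper (not in this patch): spreads a composite key/namespace
	// hash so that weak hashCode() implementations still distribute reasonably
	// over power-of-two table sizes.
	static int compositeHash(Object key, Object namespace) {
		return MathUtils.bitMix(key.hashCode() ^ namespace.hashCode());
	}

	public static void main(String[] args) {
		// Consecutive small ints (nearly identical bit patterns) end up far apart.
		for (int i = 1; i <= 4; i++) {
			System.out.printf("%d -> %08x%n", i, MathUtils.bitMix(i));
		}
		// Plain (int) truncation of 1L << 32 would yield 0; the mixer does not.
		System.out.printf("1L << 32 -> %08x%n", MathUtils.longToIntWithBitMixing(1L << 32));
	}
}
```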
*/ protected final TypeSerializer keySerializer; @@ -205,6 +207,7 @@ public int getNumberOfKeyGroups() { /** * @see KeyedStateBackend */ + @Override public KeyGroupRange getKeyGroupRange() { return keyGroupRange; } @@ -293,10 +296,16 @@ public FoldingState createFoldingState(FoldingStateDescriptor> void mergePartitionedStates(final N target, Collection sources, final TypeSerializer namespaceSerializer, final StateDescriptor stateDescriptor) throws Exception { - if (stateDescriptor instanceof ReducingStateDescriptor) { + + State stateRef = getPartitionedState(target, namespaceSerializer, stateDescriptor); + if (stateRef instanceof AbstractHeapMergingState) { + + ((AbstractHeapMergingState) stateRef).mergeNamespaces(target, sources); + } else if (stateDescriptor instanceof ReducingStateDescriptor) { + ReducingStateDescriptor reducingStateDescriptor = (ReducingStateDescriptor) stateDescriptor; + ReducingState state = (ReducingState) stateRef; ReduceFunction reduceFn = reducingStateDescriptor.getReduceFunction(); - ReducingState state = (ReducingState) getPartitionedState(target, namespaceSerializer, stateDescriptor); KvState kvState = (KvState) state; Object result = null; for (N source: sources) { @@ -314,7 +323,8 @@ public FoldingState createFoldingState(FoldingStateDescriptor state = (ListState) getPartitionedState(target, namespaceSerializer, stateDescriptor); + + ListState state = (ListState) stateRef; KvState kvState = (KvState) state; List result = new ArrayList<>(); for (N source: sources) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateTransformationFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateTransformationFunction.java new file mode 100644 index 0000000000000..182b4c8386345 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateTransformationFunction.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.Internal; + +/** + * Interface for a binary function that is used for push-down of state transformation into state backends. The + * function takes as inputs the old state and an element. From those inputs, the function computes the new state. + * + * @param type of the previous state that is the basis for the computation of the new state. + * @param type of the element value that is used to compute the change of state. + */ +@Internal +public interface StateTransformationFunction { + + /** + * Binary function that applies a given value to the given old state to compute the new state. + * + * @param previousState the previous state that is the basis for the transformation. 
+ * @param value the value that the implementation applies to the old state to obtain the new state. + * @return the new state, computed by applying the given value on the given old state. + * @throws Exception if something goes wrong in applying the transformation function. + */ + S apply(S previousState, T value) throws Exception; +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/async/AsyncFsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/async/AsyncFsStateBackend.java new file mode 100644 index 0000000000000..d90ffbd4b7d22 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/async/AsyncFsStateBackend.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.filesystem.async; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.filesystem.FsCheckpointStreamFactory; +import org.apache.flink.runtime.state.heap.async.AsyncHeapKeyedStateBackend; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; + +/** + * The file state backend is a state backend that stores the state of streaming jobs in a file system. + * + *
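To make the push-down contract concrete, here is a minimal sketch (not part of the patch) of a transformation that appends one element to list-shaped state; a backend can apply it in place instead of doing a get-modify-put round trip, which is the hook `CopyOnWriteStateTable#transform` later in this diff provides:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.runtime.state.StateTransformationFunction;

/** Appends a single element to list-shaped state, creating the list lazily. */
public class AppendTransformation<T> implements StateTransformationFunction<List<T>, T> {

	@Override
	public List<T> apply(List<T> previousState, T value) {
		List<T> newState = (previousState != null) ? previousState : new ArrayList<>();
		newState.add(value);
		return newState;
	}
}
```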

The state backend has one core directory into which it puts all checkpoint data. Inside that + * directory, it creates a directory per job, inside which each checkpoint gets a directory, with + * files for each state, for example: + * + * {@code hdfs://namenode:port/flink-checkpoints/<job-id>/chk-17/6ba7b810-9dad-11d1-80b4-00c04fd430c8 } + */ +public class AsyncFsStateBackend extends AbstractStateBackend { + + private static final long serialVersionUID = -8191916350224044011L; + + private static final Logger LOG = LoggerFactory.getLogger(AsyncFsStateBackend.class); + + /** By default, state smaller than 1024 bytes will not be written to files, but + * will be stored directly with the metadata */ + public static final int DEFAULT_FILE_STATE_THRESHOLD = 1024; + + /** Maximum size of state that is stored with the metadata, rather than in files */ + private static final int MAX_FILE_STATE_THRESHOLD = 1024 * 1024; + + /** The path to the directory for the checkpoint data, including the file system + * description via scheme and optional authority */ + private final Path basePath; + + /** State below this size will be stored as part of the metadata, rather than in files */ + private final int fileStateThreshold; + + /** + * Creates a new state backend that stores its checkpoint data in the file system and location + * defined by the given URI. + * + *
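A quick usage sketch, assuming a streaming job (the HDFS address is a placeholder): the backend is configured with the checkpoint URI and, optionally, the small-state threshold described above.

```java
import java.net.URI;

import org.apache.flink.runtime.state.filesystem.async.AsyncFsStateBackend;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class AsyncFsBackendUsage {

	public static void main(String[] args) throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		// State smaller than 2048 bytes is inlined into the checkpoint metadata;
		// larger state goes to files under the checkpoint directory.
		env.setStateBackend(new AsyncFsStateBackend(
				new URI("hdfs://namenode:40010/flink/checkpoints"), 2048));
		env.enableCheckpointing(10_000); // trigger a checkpoint every 10 seconds

		// ... define and execute the streaming job ...
	}
}
```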

A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://') + * must be accessible via {@link FileSystem#get(URI)}. + * + *

For a state backend targeting HDFS, this means that the URI must either specify the authority + * (host and port), or that the Hadoop configuration that describes that information must be in the + * classpath. + * + * @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority), + * and the path to the checkpoint data directory. + * @throws IOException Thrown, if no file system can be found for the scheme in the URI. + */ + public AsyncFsStateBackend(String checkpointDataUri) throws IOException { + this(new Path(checkpointDataUri)); + } + + /** + * Creates a new state backend that stores its checkpoint data in the file system and location + * defined by the given URI. + * + *

A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://') + * must be accessible via {@link FileSystem#get(URI)}. + * + *

For a state backend targeting HDFS, this means that the URI must either specify the authority + * (host and port), or that the Hadoop configuration that describes that information must be in the + * classpath. + * + * @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority), + * and the path to the checkpoint data directory. + * @throws IOException Thrown, if no file system can be found for the scheme in the URI. + */ + public AsyncFsStateBackend(Path checkpointDataUri) throws IOException { + this(checkpointDataUri.toUri()); + } + + /** + * Creates a new state backend that stores its checkpoint data in the file system and location + * defined by the given URI. + * + *

A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://') + * must be accessible via {@link FileSystem#get(URI)}. + * + *

For a state backend targeting HDFS, this means that the URI must either specify the authority + * (host and port), or that the Hadoop configuration that describes that information must be in the + * classpath. + * + * @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority), + * and the path to the checkpoint data directory. + * @throws IOException Thrown, if no file system can be found for the scheme in the URI. + */ + public AsyncFsStateBackend(URI checkpointDataUri) throws IOException { + this(checkpointDataUri, DEFAULT_FILE_STATE_THRESHOLD); + } + + /** + * Creates a new state backend that stores its checkpoint data in the file system and location + * defined by the given URI. + * + *

A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://') + * must be accessible via {@link FileSystem#get(URI)}. + * + *

For a state backend targeting HDFS, this means that the URI must either specify the authority + * (host and port), or that the Hadoop configuration that describes that information must be in the + * classpath. + * + * @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority), + * and the path to the checkpoint data directory. + * @param fileStateSizeThreshold State up to this size will be stored as part of the metadata, + * rather than in files + * + * @throws IOException Thrown, if no file system can be found for the scheme in the URI. + */ + public AsyncFsStateBackend(URI checkpointDataUri, int fileStateSizeThreshold) throws IOException { + if (fileStateSizeThreshold < 0) { + throw new IllegalArgumentException("The threshold for file state size must be zero or larger."); + } + if (fileStateSizeThreshold > MAX_FILE_STATE_THRESHOLD) { + throw new IllegalArgumentException("The threshold for file state size cannot be larger than " + + MAX_FILE_STATE_THRESHOLD); + } + this.fileStateThreshold = fileStateSizeThreshold; + + this.basePath = validateAndNormalizeUri(checkpointDataUri); + } + + /** + * Gets the base directory where all state-containing files are stored. + * The job specific directory is created inside this directory. + * + * @return The base directory. + */ + public Path getBasePath() { + return basePath; + } + + // ------------------------------------------------------------------------ + // initialization and cleanup + // ------------------------------------------------------------------------ + + @Override + public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorIdentifier) throws IOException { + return new FsCheckpointStreamFactory(basePath, jobId, fileStateThreshold); + } + + @Override + public AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry) throws Exception { + return new AsyncHeapKeyedStateBackend<>( + kvStateRegistry, + keySerializer, + env.getUserClassLoader(), + numberOfKeyGroups, + keyGroupRange); + } + + @Override + public String toString() { + return "File State Backend @ " + basePath; + } + + /** + * Checks and normalizes the checkpoint data URI. This method first checks the validity of the + * URI (scheme, path, availability of a matching file system) and then normalizes the URI + * to a path. + * + *

If the URI does not include an authority, but the file system configured for the URI has an + * authority, then the normalized path will include this authority. + * + * @param checkpointDataUri The URI to check and normalize. + * @return A normalized URI as a Path. + * + * @throws IllegalArgumentException Thrown, if the URI misses scheme or path. + * @throws IOException Thrown, if no file system can be found for the URI's scheme. + */ + public static Path validateAndNormalizeUri(URI checkpointDataUri) throws IOException { + final String scheme = checkpointDataUri.getScheme(); + final String path = checkpointDataUri.getPath(); + + // some validity checks + if (scheme == null) { + throw new IllegalArgumentException("The scheme (hdfs://, file://, etc) is null. " + + "Please specify the file system scheme explicitly in the URI."); + } + if (path == null) { + throw new IllegalArgumentException("The path to store the checkpoint data in is null. " + + "Please specify a directory path for the checkpoint data."); + } + if (path.length() == 0 || path.equals("/")) { + throw new IllegalArgumentException("Cannot use the root directory for checkpoints."); + } + + if (!FileSystem.isFlinkSupportedScheme(checkpointDataUri.getScheme())) { + // skip verification checks for non-flink supported filesystem + // this is because the required filesystem classes may not be available to the flink client + return new Path(checkpointDataUri); + } else { + // we do a bit of work to make sure that the URI for the filesystem refers to exactly the same + // (distributed) filesystem on all hosts and includes full host/port information, even if the + // original URI did not include that. We count on the filesystem loading from the configuration + // to fill in the missing data. + + // try to grab the file system for this path/URI + FileSystem filesystem = FileSystem.get(checkpointDataUri); + if (filesystem == null) { + String reason = "Could not find a file system for the given scheme in " + + "the available configurations."; + LOG.warn("Could not verify checkpoint path. This might be caused by a genuine " + + "problem or by the fact that the file system is not accessible from the " + + "client. Reason: {}", reason); + return new Path(checkpointDataUri); + } + + URI fsURI = filesystem.getUri(); + try { + URI baseURI = new URI(fsURI.getScheme(), fsURI.getAuthority(), path, null, null); + return new Path(baseURI); + } catch (URISyntaxException e) { + String reason = String.format( + "Cannot create file system URI for checkpointDataUri %s and filesystem URI %s: " + e.toString(), + checkpointDataUri, + fsURI); + LOG.warn("Could not verify checkpoint path. This might be caused by a genuine " + + "problem or by the fact that the file system is not accessible from the " + + "client. Reason: {}", reason); + return new Path(checkpointDataUri); + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapMergingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapMergingState.java new file mode 100644 index 0000000000000..1b09d9c99dae8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapMergingState.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
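The normalization rules are easiest to see with two example URIs (a hedged sketch; host names and the Hadoop-side configuration are assumptions):

```java
import java.net.URI;

import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.state.filesystem.async.AsyncFsStateBackend;

public class NormalizeUriDemo {

	public static void main(String[] args) throws Exception {
		// Fully specified URI: comes back essentially unchanged.
		Path explicit = AsyncFsStateBackend.validateAndNormalizeUri(
				new URI("hdfs://namenode:40010/flink/checkpoints"));

		// No authority: if the Hadoop configuration on the classpath names the
		// cluster (fs.defaultFS), the returned path is completed with it.
		Path fromConfig = AsyncFsStateBackend.validateAndNormalizeUri(
				new URI("hdfs:///flink/checkpoints"));

		System.out.println(explicit + " | " + fromConfig);

		// A missing scheme or a root path ("/") throws IllegalArgumentException.
	}
}
```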
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.state.MergingState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.StateTransformationFunction; + +import java.util.Collection; + +/** + * Base class for {@link MergingState} that is stored on the heap. + * + * @param The type of the key. + * @param The type of the namespace. + * @param The type of the values in the state. + * @param The type of State + * @param The type of StateDescriptor for the State S + */ +public abstract class AbstractHeapMergingState> + extends AbstractHeapState { + + /** + * The merge transformation function that implements the merge logic. + */ + private final MergeTransformation mergeTransformation; + + /** + * Creates a new key/value state for the given hash map of key/value pairs. + * + * @param stateDesc The state identifier for the state. This contains name + * and can create a default state value. + * @param stateTable The state table to use in this key/value state. May contain initial state. 
+ */ + protected AbstractHeapMergingState( + SD stateDesc, + StateTable stateTable, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) { + + super(stateDesc, stateTable, keySerializer, namespaceSerializer); + this.mergeTransformation = new MergeTransformation(); + } + + public void mergeNamespaces(N target, Collection sources) throws Exception { + if (sources == null || sources.isEmpty()) { + return; // nothing to do + } + + final StateTable map = stateTable; + + SV merged = null; + + // merge the sources + for (N source : sources) { + + // get and remove the next source per namespace/key + SV sourceState = map.removeAndGetOld(source); + + if (merged != null && sourceState != null) { + merged = mergeState(merged, sourceState); + } else if (merged == null) { + merged = sourceState; + } + } + + // merge into the target, if needed + if (merged != null) { + map.transform(target, merged, mergeTransformation); + } + } + + protected abstract SV mergeState(SV a, SV b) throws Exception; + + final class MergeTransformation implements StateTransformationFunction { + + @Override + public SV apply(SV targetState, SV merged) throws Exception { + if (targetState != null) { + return mergeState(targetState, merged); + } else { + return merged; + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapState.java new file mode 100644 index 0000000000000..c93ea6aec8aa1 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractHeapState.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +/** + * Base class for partitioned {@link State} implementations that are backed by a regular + * heap hash map. The concrete implementations define how the state is checkpointed. + * + * @param The type of the key. + * @param The type of the namespace. + * @param The type of the values in the state. 
+ * @param The type of State + * @param The type of StateDescriptor for the State S + */ +public abstract class AbstractHeapState> + implements KvState, State { + + /** Map containing the actual key/value pairs */ + protected final StateTable stateTable; + + /** This holds the name of the state and can create an initial default value for the state. */ + protected final SD stateDesc; + + /** The current namespace, which the access methods will refer to. */ + protected N currentNamespace; + + protected final TypeSerializer keySerializer; + + protected final TypeSerializer namespaceSerializer; + + /** + * Creates a new key/value state for the given hash map of key/value pairs. + * + * @param stateDesc The state identifier for the state. This contains name + * and can create a default state value. + * @param stateTable The state table to use in this key/value state. May contain initial state. + */ + protected AbstractHeapState( + SD stateDesc, + StateTable stateTable, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) { + + this.stateDesc = stateDesc; + this.stateTable = Preconditions.checkNotNull(stateTable, "State table must not be null."); + this.keySerializer = keySerializer; + this.namespaceSerializer = namespaceSerializer; + this.currentNamespace = null; + } + + // ------------------------------------------------------------------------ + + + public final void clear() { + stateTable.remove(currentNamespace); + } + + public final void setCurrentNamespace(N namespace) { + this.currentNamespace = Preconditions.checkNotNull(namespace, "Namespace must not be null."); + } + + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + + Tuple2 keyAndNamespace = KvStateRequestSerializer.deserializeKeyAndNamespace( + serializedKeyAndNamespace, keySerializer, namespaceSerializer); + + return getSerializedValue(keyAndNamespace.f0, keyAndNamespace.f1); + } + + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkState(namespace != null, "No namespace given."); + Preconditions.checkState(key != null, "No key given."); + + SV result = stateTable.get(key, namespace); + + if (result == null) { + return null; + } + + @SuppressWarnings("unchecked,rawtypes") + TypeSerializer serializer = stateDesc.getSerializer(); + return KvStateRequestSerializer.serializeValue(result, serializer); + } + + /** + * This should only be used for testing. + */ + @VisibleForTesting + public StateTable getStateTable() { + return stateTable; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractStateTableSnapshot.java new file mode 100644 index 0000000000000..8a1d3f360a789 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AbstractStateTableSnapshot.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
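The merge flow in AbstractHeapMergingState above (remove each source namespace, fold the values pairwise with `mergeState`, then transform the result into the target) can be mirrored on a toy model that is independent of the Flink classes; here `mergeState` is addition, as a reducing state with a sum function would use:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class MergeNamespacesModel {

	public static void main(String[] args) {
		// Per-namespace partial sums for one key, e.g. session windows being merged.
		Map<String, Long> byNamespace = new HashMap<>();
		byNamespace.put("w1", 3L);
		byNamespace.put("w2", 4L);
		byNamespace.put("target", 10L);

		Long merged = null;
		for (String source : Arrays.asList("w1", "w2")) {
			Long sourceState = byNamespace.remove(source); // removeAndGetOld
			if (merged != null && sourceState != null) {
				merged = merged + sourceState;             // mergeState
			} else if (merged == null) {
				merged = sourceState;
			}
		}
		if (merged != null) {
			// MergeTransformation: combine with existing target state, if any.
			byNamespace.merge("target", merged, Long::sum);
		}
		System.out.println(byNamespace); // {target=17}
	}
}
```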
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.util.Preconditions; + +/** + * Abstract class to encapsulate the logic to take snapshots of {@link StateTable} implementations and also defines how + * the snapshot is written during the serialization phase of checkpointing. + */ +@Internal +abstract class AbstractStateTableSnapshot> implements StateTableSnapshot { + + /** + * The {@link StateTable} from which this snapshot was created. + */ + final T owningStateTable; + + /** + * Creates a new {@link AbstractStateTableSnapshot} for and owned by the given table. + * + * @param owningStateTable the {@link StateTable} for which this object represents a snapshot. + */ + AbstractStateTableSnapshot(T owningStateTable) { + this.owningStateTable = Preconditions.checkNotNull(owningStateTable); + } + + /** + * Optional hook to release resources for this snapshot at the end of its lifecycle. + */ + @Override + public void release() { + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AsyncHeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AsyncHeapKeyedStateBackend.java new file mode 100644 index 0000000000000..e19ed0049200e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/AsyncHeapKeyedStateBackend.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.commons.collections.map.HashedMap; +import org.apache.commons.io.IOUtils; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.state.FoldingState; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.migration.MigrationUtil; +import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable; +import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.ArrayListSerializer; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.DoneFuture; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.KeyedBackendSerializationProxy; +import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.RunnableFuture; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A {@link AbstractKeyedStateBackend} that keeps state on the Java Heap and will serialize state to + * streams provided by a {@link CheckpointStreamFactory} upon + * checkpointing. + * + * @param The key by which state is keyed. + */ +public class AsyncHeapKeyedStateBackend extends AbstractKeyedStateBackend { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncHeapKeyedStateBackend.class); + + /** + * Map of state tables that stores all state of key/value states. We store it centrally so + * that we can easily checkpoint/restore it. + * + *

The actual parameters of StateTable are {@code StateTable>} + * but we can't put them here because different key/value states with different types and + * namespace types share this central list of tables. + */ + private final HashMap> stateTables = new HashMap<>(); + + public AsyncHeapKeyedStateBackend( + TaskKvStateRegistry kvStateRegistry, + TypeSerializer keySerializer, + ClassLoader userCodeClassLoader, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange) { + + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); + LOG.info("Initializing heap keyed state backend with stream factory."); + } + + // ------------------------------------------------------------------------ + // state backend operations + // ------------------------------------------------------------------------ + + private StateTable tryRegisterStateTable( + TypeSerializer namespaceSerializer, StateDescriptor stateDesc) { + + return tryRegisterStateTable( + stateDesc.getName(), stateDesc.getType(), + namespaceSerializer, stateDesc.getSerializer()); + } + + private StateTable tryRegisterStateTable( + String stateName, + StateDescriptor.Type stateType, + TypeSerializer namespaceSerializer, + TypeSerializer valueSerializer) { + + final RegisteredBackendStateMetaInfo newMetaInfo = + new RegisteredBackendStateMetaInfo<>(stateType, stateName, namespaceSerializer, valueSerializer); + + @SuppressWarnings("unchecked") + StateTable stateTable = (StateTable) stateTables.get(stateName); + + if (stateTable == null) { + stateTable = newStateTable(newMetaInfo); + stateTables.put(stateName, stateTable); + } else { + if (!newMetaInfo.isCompatibleWith(stateTable.getMetaInfo())) { + throw new RuntimeException("Trying to access state using incompatible meta info, was " + + stateTable.getMetaInfo() + " trying access with " + newMetaInfo); + } + stateTable.setMetaInfo(newMetaInfo); + } + return stateTable; + } + + private boolean hasRegisteredState() { + return !stateTables.isEmpty(); + } + + @Override + public ValueState createValueState( + TypeSerializer namespaceSerializer, + ValueStateDescriptor stateDesc) throws Exception { + + StateTable stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); + return new HeapValueState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override + public ListState createListState( + TypeSerializer namespaceSerializer, + ListStateDescriptor stateDesc) throws Exception { + + // the list state does some manual mapping, because the state is typed to the generic + // 'List' interface, but we want to use an implementation typed to ArrayList + // using a more specialized implementation opens up runtime optimizations + + StateTable> stateTable = tryRegisterStateTable( + stateDesc.getName(), + stateDesc.getType(), + namespaceSerializer, + new ArrayListSerializer(stateDesc.getSerializer())); + + return new HeapListState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override + public ReducingState createReducingState( + TypeSerializer namespaceSerializer, + ReducingStateDescriptor stateDesc) throws Exception { + + StateTable stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); + return new HeapReducingState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override + public FoldingState createFoldingState( + TypeSerializer namespaceSerializer, + FoldingStateDescriptor stateDesc) throws Exception { + + StateTable stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); + 
return new HeapFoldingState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override + @SuppressWarnings("unchecked") + public RunnableFuture snapshot( + final long checkpointId, + final long timestamp, + final CheckpointStreamFactory streamFactory) throws Exception { + + if (!hasRegisteredState()) { + return DoneFuture.nullValue(); + } + + long syncStartTime = System.currentTimeMillis(); + + Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE, + "Too many KV-States: " + stateTables.size() + + ". Currently at most " + Short.MAX_VALUE + " states are supported"); + + List> metaInfoProxyList = new ArrayList<>(stateTables.size()); + + final Map kVStateToId = new HashMap<>(stateTables.size()); + + final Map, StateTableSnapshot> cowStateStableSnapshots = new HashedMap(stateTables.size()); + + for (Map.Entry> kvState : stateTables.entrySet()) { + RegisteredBackendStateMetaInfo metaInfo = kvState.getValue().getMetaInfo(); + KeyedBackendSerializationProxy.StateMetaInfo metaInfoProxy = new KeyedBackendSerializationProxy.StateMetaInfo( + metaInfo.getStateType(), + metaInfo.getName(), + metaInfo.getNamespaceSerializer(), + metaInfo.getStateSerializer()); + + metaInfoProxyList.add(metaInfoProxy); + kVStateToId.put(kvState.getKey(), kVStateToId.size()); + StateTable stateTable = kvState.getValue(); + if (null != stateTable) { + cowStateStableSnapshots.put(stateTable, stateTable.createSnapshot()); + } + } + + final KeyedBackendSerializationProxy serializationProxy = + new KeyedBackendSerializationProxy(keySerializer, metaInfoProxyList); + + //--------------------------------------------------- this becomes the end of sync part + + // implementation of the async IO operation, based on FutureTask + final AbstractAsyncIOCallable ioCallable = + new AbstractAsyncIOCallable() { + + AtomicBoolean open = new AtomicBoolean(false); + + @Override + public CheckpointStreamFactory.CheckpointStateOutputStream openIOHandle() throws Exception { + if (open.compareAndSet(false, true)) { + CheckpointStreamFactory.CheckpointStateOutputStream stream = + streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); + try { + cancelStreamRegistry.registerClosable(stream); + return stream; + } catch (Exception ex) { + open.set(false); + throw ex; + } + } else { + throw new IOException("Operation already opened."); + } + } + + @Override + public KeyGroupsStateHandle performOperation() throws Exception { + long asyncStartTime = System.currentTimeMillis(); + CheckpointStreamFactory.CheckpointStateOutputStream stream = getIoHandle(); + DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream); + serializationProxy.write(outView); + + long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()]; + + for (int keyGroupPos = 0; keyGroupPos < keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) { + int keyGroupId = keyGroupRange.getKeyGroupId(keyGroupPos); + keyGroupRangeOffsets[keyGroupPos] = stream.getPos(); + outView.writeInt(keyGroupId); + + for (Map.Entry> kvState : stateTables.entrySet()) { + outView.writeShort(kVStateToId.get(kvState.getKey())); + cowStateStableSnapshots.get(kvState.getValue()).writeMappingsInKeyGroup(outView, keyGroupId); + } + } + + if (open.compareAndSet(true, false)) { + StreamStateHandle streamStateHandle = stream.closeAndGetHandle(); + KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets); + final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, 
streamStateHandle); + + LOG.info("Heap backend snapshot ({}, asynchronous part) in thread {} took {} ms.", + streamFactory, Thread.currentThread(), (System.currentTimeMillis() - asyncStartTime)); + + return keyGroupsStateHandle; + } else { + throw new IOException("Checkpoint stream already closed."); + } + } + + @Override + public void done(boolean canceled) { + if (open.compareAndSet(true, false)) { + CheckpointStreamFactory.CheckpointStateOutputStream stream = getIoHandle(); + if (null != stream) { + cancelStreamRegistry.unregisterClosable(stream); + IOUtils.closeQuietly(stream); + } + } + for (StateTableSnapshot snapshot : cowStateStableSnapshots.values()) { + snapshot.release(); + } + } + }; + + AsyncStoppableTaskWithCallback task = AsyncStoppableTaskWithCallback.from(ioCallable); + + LOG.info("Heap backend snapshot (" + streamFactory + ", synchronous part) in thread " + + Thread.currentThread() + " took " + (System.currentTimeMillis() - syncStartTime) + " ms."); + + return task; + } + + @SuppressWarnings("deprecation") + @Override + public void restore(Collection restoredState) throws Exception { + LOG.info("Initializing heap keyed state backend from snapshot."); + + if (LOG.isDebugEnabled()) { + LOG.debug("Restoring snapshot from state handles: {}.", restoredState); + } + + if (MigrationUtil.isOldSavepointKeyedState(restoredState)) { + throw new UnsupportedOperationException( + "This async.HeapKeyedStateBackend does not support restore from old savepoints."); + } else { + restorePartitionedState(restoredState); + } + } + + @SuppressWarnings({"unchecked"}) + private void restorePartitionedState(Collection state) throws Exception { + + final Map kvStatesById = new HashMap<>(); + int numRegisteredKvStates = 0; + stateTables.clear(); + + for (KeyGroupsStateHandle keyGroupsHandle : state) { + + if (keyGroupsHandle == null) { + continue; + } + + FSDataInputStream fsDataInputStream = keyGroupsHandle.openInputStream(); + cancelStreamRegistry.registerClosable(fsDataInputStream); + + try { + DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream); + + KeyedBackendSerializationProxy serializationProxy = + new KeyedBackendSerializationProxy(userCodeClassLoader); + + serializationProxy.read(inView); + + List> metaInfoList = + serializationProxy.getNamedStateSerializationProxies(); + + for (KeyedBackendSerializationProxy.StateMetaInfo metaInfoSerializationProxy : metaInfoList) { + + StateTable stateTable = stateTables.get(metaInfoSerializationProxy.getStateName()); + + // important: only create a new table if we did not already create it previously + if (null == stateTable) { + + RegisteredBackendStateMetaInfo registeredBackendStateMetaInfo = + new RegisteredBackendStateMetaInfo<>(metaInfoSerializationProxy); + + stateTable = newStateTable(registeredBackendStateMetaInfo); + stateTables.put(metaInfoSerializationProxy.getStateName(), stateTable); + kvStatesById.put(numRegisteredKvStates, metaInfoSerializationProxy.getStateName()); + ++numRegisteredKvStates; + } + } + + for (Tuple2 groupOffset : keyGroupsHandle.getGroupRangeOffsets()) { + int keyGroupIndex = groupOffset.f0; + long offset = groupOffset.f1; + fsDataInputStream.seek(offset); + + int writtenKeyGroupIndex = inView.readInt(); + + Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex, + "Unexpected key-group in restore."); + + for (int i = 0; i < metaInfoList.size(); i++) { + int kvStateId = inView.readShort(); + StateTable stateTable = stateTables.get(kvStatesById.get(kvStateId)); + + // Hardcoding 2 as 
version will lead to the right method for the + // serialization format. Due to the backport, we should keep this fix and do + // not allow restore from a different format. + StateTableByKeyGroupReader keyGroupReader = + StateTableByKeyGroupReaders.readerForVersion( + stateTable, + 2); + + keyGroupReader.readMappingsInKeyGroup(inView, keyGroupIndex); + } + } + } finally { + cancelStreamRegistry.unregisterClosable(fsDataInputStream); + IOUtils.closeQuietly(fsDataInputStream); + } + } + } + + @Override + public String toString() { + return "HeapKeyedStateBackend"; + } + + /** + * Returns the total number of state entries across all keys/namespaces. + */ + @VisibleForTesting + @SuppressWarnings("unchecked") + public int numStateEntries() { + int sum = 0; + for (StateTable stateTable : stateTables.values()) { + sum += stateTable.size(); + } + return sum; + } + + /** + * Returns the total number of state entries across all keys for the given namespace. + */ + @VisibleForTesting + public int numStateEntries(Object namespace) { + int sum = 0; + for (StateTable stateTable : stateTables.values()) { + sum += stateTable.sizeOfNamespace(namespace); + } + return sum; + } + + private StateTable newStateTable(RegisteredBackendStateMetaInfo newMetaInfo) { + return new CopyOnWriteStateTable<>(this, newMetaInfo); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTable.java new file mode 100644 index 0000000000000..6c9c14c527389 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTable.java @@ -0,0 +1,1066 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo; +import org.apache.flink.runtime.state.StateTransformationFunction; +import org.apache.flink.util.MathUtils; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.TreeSet; + +/** + * Implementation of Flink's in-memory state tables with copy-on-write support. This map does not support null values + * for key or namespace. +
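For reference, how a caller could drive the two-phase snapshot implemented above (a hedged sketch with a plain executor; in the real runtime the checkpointing machinery owns this flow):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RunnableFuture;

import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.heap.async.AsyncHeapKeyedStateBackend;

public class SnapshotDriverSketch {

	static KeyGroupsStateHandle checkpoint(
			AsyncHeapKeyedStateBackend<?> backend,
			CheckpointStreamFactory streamFactory) throws Exception {

		// Synchronous part: cheap, only captures copy-on-write snapshots of the tables.
		RunnableFuture<KeyGroupsStateHandle> future =
				backend.snapshot(1L, System.currentTimeMillis(), streamFactory);

		// Asynchronous part: serializing the key groups runs off the processing thread.
		ExecutorService asyncOps = Executors.newSingleThreadExecutor();
		try {
			asyncOps.execute(future);
			return future.get(); // state handle of the completed checkpoint
		} finally {
			asyncOps.shutdown();
		}
	}
}
```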

+ * {@link CopyOnWriteStateTable} sacrifices some peak performance and memory efficiency for features like incremental + * rehashing and asynchronous snapshots through copy-on-write. Copy-on-write tries to minimize the amount of copying by + * maintaining version meta data for both the map structure and the state objects. However, we must often proactively + * copy state objects when we hand them to the user. + *

+ * As with any state backend, users should not keep references to state objects that they obtained from state backends + * outside the scope of the user function calls. + *

+ * Some brief maintenance notes: + *

+ * 1) Flattening the underlying data structure from nested maps (namespace) -> (key) -> (state) to one flat map + * (key, namespace) -> (state) brings certain performance trade-offs. In theory, the flat map has one less level of + * indirection compared to the nested map. However, the nested map naturally de-duplicates namespace objects for which + * #equals() is true. This leads to potentially a lot of redundant namespace objects for the flattened version. Those, + * in turn, can again introduce more cache misses because we need to follow the namespace object on all operations to + * ensure entry identities. Obviously, copy-on-write can also add memory overhead. So does the metadata that tracks the + * copy-on-write requirement (state and entry versions on {@link StateTableEntry}). + *

+ * 2) A flat map structure is a lot easier when it comes to tracking copy-on-write of the map structure. + *

+ * 3) The nested structure had the (never used) advantage that we can easily drop and iterate whole namespaces. This could + * give locality advantages for certain access patterns, e.g. iterating a namespace. + *

+ * 4) The serialization format is changed from namespace-prefix compressed (as naturally provided by the old nested + * structure) to making all entries self-contained as (key, namespace, state). + *

+ * 5) We got rid of having multiple nested tables, one for each key-group. Instead, we partition state into key-groups + * on-the-fly, during the asynchronous part of a snapshot. + *

+ * 6) Currently, a state table can only grow, but never shrinks on low load. We could easily add this if required. + *

+ * 7) Heap-based state backends like this can easily cause a lot of GC activity. Besides using G1 as the garbage collector, + * we should provide an additional state backend that operates on off-heap memory. This would sacrifice peak performance + * (due to de/serialization of objects) for a lower, but more constant throughput and potentially huge simplifications + * w.r.t. copy-on-write. + *

+ * 8) We could try a hybrid of serialized and object-based backends, where key and namespace of the entries are both + * serialized in one byte-array. + *

+ * 9) We could consider smaller types (e.g. short) for the version counting and think about some reset strategy before + * overflows, when there is no snapshot running. However, this would have to touch all entries in the map. + *

+ * This class was initially based on the {@link java.util.HashMap} implementation of the Android JDK, but is now heavily + * customized towards the use case of a table for state entries. + * + * IMPORTANT: the contracts for this class rely on the user not holding any references to objects returned by this map + * beyond the life cycle of per-element operations. Or phrased differently, all get-update-put operations on a mapping + * should be within one call of processElement. Otherwise, the user must take care of taking deep copies, e.g. for + * caching purposes. + * + * @param type of key. + * @param type of namespace. + * @param type of value. + */ +public class CopyOnWriteStateTable extends StateTable implements Iterable> { + + /** + * The logger. + */ + private static final Logger LOG = LoggerFactory.getLogger(AsyncHeapKeyedStateBackend.class); + + /** + * Min capacity (other than zero) for a {@link CopyOnWriteStateTable}. Must be a power of two + * greater than 1 (and less than 1 << 30). + */ + private static final int MINIMUM_CAPACITY = 4; + + /** + * Max capacity for a {@link CopyOnWriteStateTable}. Must be a power of two >= MINIMUM_CAPACITY. + */ + private static final int MAXIMUM_CAPACITY = 1 << 30; + + /** + * Minimum number of entries that one step of incremental rehashing migrates from the old to the new sub-table. + */ + private static final int MIN_TRANSFERRED_PER_INCREMENTAL_REHASH = 4; + + /** + * An empty table shared by all zero-capacity maps (typically from default + * constructor). It is never written to, and replaced on first put. Its size + * is set to half the minimum, so that the first resize will create a + * minimum-sized table. + */ + private static final StateTableEntry[] EMPTY_TABLE = new StateTableEntry[MINIMUM_CAPACITY >>> 1]; + + /** + * Empty entry that we use to bootstrap our StateEntryIterator. + */ + private static final StateTableEntry ITERATOR_BOOTSTRAP_ENTRY = new StateTableEntry<>(); + + /** + * Maintains an ordered set of version ids that are still in use by unreleased snapshots. + */ + private final TreeSet snapshotVersions; + + /** + * This is the primary entry array (hash directory) of the state table. If no incremental rehash is ongoing, this + * is the only used table. + **/ + private StateTableEntry[] primaryTable; + + /** + * We maintain a secondary entry array while performing an incremental rehash. The purpose is to slowly migrate + * entries from the primary table to this resized table array. When all entries are migrated, this becomes the new + * primary table. + */ + private StateTableEntry[] incrementalRehashTable; + + /** + * The current number of mappings in the primary table. + */ + private int primaryTableSize; + + /** + * The current number of mappings in the rehash table. + */ + private int incrementalRehashTableSize; + + /** + * The next index for a step of incremental rehashing in the primary table. + */ + private int rehashIndex; + + /** + * The current version of this map. Used for copy-on-write mechanics. + */ + private int stateTableVersion; + + /** + * The highest version of this map that is still required by any unreleased snapshot. + */ + private int highestRequiredSnapshotVersion; + + /** + * The last namespace that was actually inserted. This is a small optimization to reduce duplicate namespace objects. + */ + private N lastNamespace; + + /** + * The {@link CopyOnWriteStateTable} is rehashed when its size exceeds this threshold. 
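The IMPORTANT note above deserves an example. Keeping the state handle across invocations is fine, but values read from it must not be cached beyond one per-element call, or they may race with copy-on-write snapshots (a hedged sketch; names and types are made up):

```java
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

public class BufferingFunction extends RichFlatMapFunction<String, String> {

	// Keeping the state handle in a field is fine ...
	private transient ListState<String> buffer;

	// ... but a field like "private List<String> cachedValues;" that stores what
	// the state returned across flatMap() calls would violate the contract above.

	@Override
	public void open(Configuration parameters) {
		buffer = getRuntimeContext().getListState(
				new ListStateDescriptor<>("buffer", String.class));
	}

	@Override
	public void flatMap(String value, Collector<String> out) throws Exception {
		// All reads and updates of the state happen within this one invocation.
		buffer.add(value);
	}
}
```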
+ * The value of this field is generally .75 * capacity, except when + * the capacity is zero, as described in the EMPTY_TABLE declaration + * above. + */ + private int threshold; + + /** + * Incremented by "structural modifications" to allow (best effort) + * detection of concurrent modification. + */ + private int modCount; + + /** + * Constructs a new {@code StateTable} with default capacity of 1024. + * + * @param keyContext the key context. + * @param metaInfo the meta information, including the type serializer for state copy-on-write. + */ + CopyOnWriteStateTable(InternalKeyContext keyContext, RegisteredBackendStateMetaInfo metaInfo) { + this(keyContext, metaInfo, 1024); + } + + /** + * Constructs a new {@code StateTable} instance with the specified capacity. + * + * @param keyContext the key context. + * @param metaInfo the meta information, including the type serializer for state copy-on-write. + * @param capacity the initial capacity of this hash map. + * @throws IllegalArgumentException when the capacity is less than zero. + */ + @SuppressWarnings("unchecked") + private CopyOnWriteStateTable(InternalKeyContext keyContext, RegisteredBackendStateMetaInfo metaInfo, int capacity) { + super(keyContext, metaInfo); + + // initialized tables to EMPTY_TABLE. + this.primaryTable = (StateTableEntry[]) EMPTY_TABLE; + this.incrementalRehashTable = (StateTableEntry[]) EMPTY_TABLE; + + // initialize sizes to 0. + this.primaryTableSize = 0; + this.incrementalRehashTableSize = 0; + + this.rehashIndex = 0; + this.stateTableVersion = 0; + this.highestRequiredSnapshotVersion = 0; + this.snapshotVersions = new TreeSet<>(); + + if (capacity < 0) { + throw new IllegalArgumentException("Capacity: " + capacity); + } + + if (capacity == 0) { + threshold = -1; + return; + } + + if (capacity < MINIMUM_CAPACITY) { + capacity = MINIMUM_CAPACITY; + } else if (capacity > MAXIMUM_CAPACITY) { + capacity = MAXIMUM_CAPACITY; + } else { + capacity = MathUtils.roundUpToPowerOfTwo(capacity); + } + primaryTable = makeTable(capacity); + } + + // Public API from AbstractStateTable ------------------------------------------------------------------------------ + + /** + * Returns the total number of entries in this {@link CopyOnWriteStateTable}. This is the sum of both sub-tables. + * + * @return the number of entries in this {@link CopyOnWriteStateTable}. 
+	 */
+	@Override
+	public int size() {
+		return primaryTableSize + incrementalRehashTableSize;
+	}
+
+	@Override
+	public S get(K key, N namespace) {
+
+		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
+		final int requiredVersion = highestRequiredSnapshotVersion;
+		final StateTableEntry[] tab = selectActiveTable(hash);
+		int index = hash & (tab.length - 1);
+
+		for (StateTableEntry e = tab[index]; e != null; e = e.next) {
+			final K eKey = e.key;
+			final N eNamespace = e.namespace;
+			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
+
+				// copy-on-write check for state
+				if (e.stateVersion < requiredVersion) {
+					// copy-on-write check for entry
+					if (e.entryVersion < requiredVersion) {
+						e = handleChainedEntryCopyOnWrite(tab, hash & (tab.length - 1), e);
+					}
+					e.stateVersion = stateTableVersion;
+					e.state = getStateSerializer().copy(e.state);
+				}
+
+				return e.state;
+			}
+		}
+
+		return null;
+	}
+
+	@Override
+	public void put(K key, int keyGroup, N namespace, S state) {
+		put(key, namespace, state);
+	}
+
+	@Override
+	public S get(N namespace) {
+		return get(keyContext.getCurrentKey(), namespace);
+	}
+
+	@Override
+	public boolean containsKey(N namespace) {
+		return containsKey(keyContext.getCurrentKey(), namespace);
+	}
+
+	@Override
+	public void put(N namespace, S state) {
+		put(keyContext.getCurrentKey(), namespace, state);
+	}
+
+	@Override
+	public S putAndGetOld(N namespace, S state) {
+		return putAndGetOld(keyContext.getCurrentKey(), namespace, state);
+	}
+
+	@Override
+	public void remove(N namespace) {
+		remove(keyContext.getCurrentKey(), namespace);
+	}
+
+	@Override
+	public S removeAndGetOld(N namespace) {
+		return removeAndGetOld(keyContext.getCurrentKey(), namespace);
+	}
+
+	@Override
+	public void transform(N namespace, T value, StateTransformationFunction transformation) throws Exception {
+		transform(keyContext.getCurrentKey(), namespace, value, transformation);
+	}
+
+	// Private implementation details of the API methods ---------------------------------------------------------------
+
+	/**
+	 * Returns whether this table contains the specified key/namespace composite key.
+	 *
+	 * @param key the key in the composite key to search for. Not null.
+	 * @param namespace the namespace in the composite key to search for. Not null.
+	 * @return {@code true} if this map contains the specified key/namespace composite key,
+	 * {@code false} otherwise.
+	 */
+	boolean containsKey(K key, N namespace) {
+
+		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
+		final StateTableEntry[] tab = selectActiveTable(hash);
+		int index = hash & (tab.length - 1);
+
+		for (StateTableEntry e = tab[index]; e != null; e = e.next) {
+			final K eKey = e.key;
+			final N eNamespace = e.namespace;
+
+			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
+				return true;
+			}
+		}
+		return false;
+	}
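Editor's note: the two version comparisons in {@code get()} above are the heart of the copy-on-write protection. Below is a minimal, self-contained sketch of that read-path decision; the helper name and parameters are illustrative and not part of this PR. State last written before the newest unreleased snapshot version may still be shared with that snapshot, so it must be deep-copied before the caller may mutate it.

```java
import org.apache.flink.api.common.typeutils.TypeSerializer;

// Illustrative sketch, not code from this PR: the decision get() makes before
// handing state back to the caller for potential in-place mutation.
static <S> S ensureWritable(
		S state, int stateVersion, int highestRequiredSnapshotVersion, TypeSerializer<S> serializer) {
	// an object last written before the newest unreleased snapshot may still be
	// referenced by that snapshot, so return a deep copy instead of the original
	return (stateVersion < highestRequiredSnapshotVersion) ? serializer.copy(state) : state;
}
```

+
+	/**
+	 * Maps the specified key/namespace composite key to the specified value. This method should be preferred
+	 * over {@link #putAndGetOld(Object, Object, Object)} when the caller is not interested
+	 * in the old value, because this can potentially reduce copy-on-write activity.
+	 *
+	 * @param key the key. Not null.
+	 * @param namespace the namespace. Not null.
+	 * @param value the value. Can be null.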
+	 */
+	void put(K key, N namespace, S value) {
+		final StateTableEntry e = putEntry(key, namespace);
+
+		e.state = value;
+		e.stateVersion = stateTableVersion;
+	}
+
+	/**
+	 * Maps the specified key/namespace composite key to the specified value. Returns the previous state that was
+	 * registered under the composite key.
+	 *
+	 * @param key the key. Not null.
+	 * @param namespace the namespace. Not null.
+	 * @param value the value. Can be null.
+	 * @return the value of any previous mapping with the specified key or
+	 * {@code null} if there was no such mapping.
+	 */
+	S putAndGetOld(K key, N namespace, S value) {
+
+		final StateTableEntry e = putEntry(key, namespace);
+
+		// copy-on-write check for state
+		S oldState = (e.stateVersion < highestRequiredSnapshotVersion) ?
+				getStateSerializer().copy(e.state) :
+				e.state;
+
+		e.state = value;
+		e.stateVersion = stateTableVersion;
+
+		return oldState;
+	}
+
+	/**
+	 * Removes the mapping with the specified key/namespace composite key from this map. This method should be preferred
+	 * over {@link #removeAndGetOld(Object, Object)} when the caller is not interested in the old value, because this
+	 * can potentially reduce copy-on-write activity.
+	 *
+	 * @param key the key of the mapping to remove. Not null.
+	 * @param namespace the namespace of the mapping to remove. Not null.
+	 */
+	void remove(K key, N namespace) {
+		removeEntry(key, namespace);
+	}
+
+	/**
+	 * Removes the mapping with the specified key/namespace composite key from this map, returning the state that was
+	 * found under the entry.
+	 *
+	 * @param key the key of the mapping to remove. Not null.
+	 * @param namespace the namespace of the mapping to remove. Not null.
+	 * @return the value of the removed mapping or {@code null} if no mapping
+	 * for the specified key was found.
+	 */
+	S removeAndGetOld(K key, N namespace) {
+
+		final StateTableEntry e = removeEntry(key, namespace);
+
+		return e != null ?
+				// copy-on-write check for state
+				(e.stateVersion < highestRequiredSnapshotVersion ?
+						getStateSerializer().copy(e.state) :
+						e.state) :
+				null;
+	}
+
+	/**
+	 * Applies the given transformation function to the state mapped under the given key/namespace composite key,
+	 * using the given value as the second function argument.
+	 *
+	 * @param key the key of the mapping to transform. Not null.
+	 * @param namespace the namespace of the mapping to transform. Not null.
+	 * @param value the value that is the second input for the transformation.
+	 * @param transformation the transformation function to apply on the old state and the given value.
+	 * @param type of the value that is the second input to the {@link StateTransformationFunction}.
+	 * @throws Exception exceptions that happen on applying the function.
+	 * @see #transform(Object, Object, StateTransformationFunction)
+	 */
+	void transform(
+			K key,
+			N namespace,
+			T value,
+			StateTransformationFunction transformation) throws Exception {
+
+		final StateTableEntry entry = putEntry(key, namespace);
+
+		// copy-on-write check for state
+		entry.state = transformation.apply(
+				(entry.stateVersion < highestRequiredSnapshotVersion) ?
+						getStateSerializer().copy(entry.state) :
+						entry.state,
+				value);
+		entry.stateVersion = stateTableVersion;
+	}
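Editor's note: to make the get-update-put shortcut concrete, here is a hypothetical usage sketch; the summing function and the surrounding variables are assumptions, not code from this PR. A single call folds the new value into the existing state, with at most one copy-on-write copy of the old state.

```java
// Hypothetical usage sketch: accumulate a running sum in one table operation
// instead of a separate get-update-put sequence.
StateTransformationFunction<Long, Long> sum = new StateTransformationFunction<Long, Long>() {
	@Override
	public Long apply(Long previousState, Long value) {
		return previousState == null ? value : previousState + value;
	}
};
stateTable.transform(key, namespace, 42L, sum); // assumes an existing table, key and namespace
```

+
+	/**
+	 * Helper method that is the basis for operations that add mappings.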
+	 */
+	private StateTableEntry putEntry(K key, N namespace) {
+
+		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
+		final StateTableEntry[] tab = selectActiveTable(hash);
+		int index = hash & (tab.length - 1);
+
+		for (StateTableEntry e = tab[index]; e != null; e = e.next) {
+			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
+
+				// copy-on-write check for entry
+				if (e.entryVersion < highestRequiredSnapshotVersion) {
+					e = handleChainedEntryCopyOnWrite(tab, index, e);
+				}
+
+				return e;
+			}
+		}
+
+		++modCount;
+		if (size() > threshold) {
+			doubleCapacity();
+		}
+
+		return addNewStateTableEntry(tab, key, namespace, hash);
+	}
+
+	/**
+	 * Helper method that is the basis for operations that remove mappings.
+	 */
+	private StateTableEntry removeEntry(K key, N namespace) {
+
+		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
+		final StateTableEntry[] tab = selectActiveTable(hash);
+		int index = hash & (tab.length - 1);
+
+		for (StateTableEntry e = tab[index], prev = null; e != null; prev = e, e = e.next) {
+			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
+				if (prev == null) {
+					tab[index] = e.next;
+				} else {
+					// copy-on-write check for entry
+					if (prev.entryVersion < highestRequiredSnapshotVersion) {
+						prev = handleChainedEntryCopyOnWrite(tab, index, prev);
+					}
+					prev.next = e.next;
+				}
+				++modCount;
+				if (tab == primaryTable) {
+					--primaryTableSize;
+				} else {
+					--incrementalRehashTableSize;
+				}
+				return e;
+			}
+		}
+		return null;
+	}
+
+	private void checkKeyNamespacePreconditions(K key, N namespace) {
+		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
+		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
+	}
+
+	// Meta data setter / getter and toString --------------------------------------------------------------------------
+
+	@Override
+	public TypeSerializer getStateSerializer() {
+		return metaInfo.getStateSerializer();
+	}
+
+	@Override
+	public TypeSerializer getNamespaceSerializer() {
+		return metaInfo.getNamespaceSerializer();
+	}
+
+	@Override
+	public RegisteredBackendStateMetaInfo getMetaInfo() {
+		return metaInfo;
+	}
+
+	@Override
+	public void setMetaInfo(RegisteredBackendStateMetaInfo metaInfo) {
+		this.metaInfo = metaInfo;
+	}
+
+	// Iteration ------------------------------------------------------------------------------------------------------
+
+	@Override
+	public Iterator> iterator() {
+		return new StateEntryIterator();
+	}
+
+	// Private utility functions for StateTable management -------------------------------------------------------------
+
+	/**
+	 * @see #releaseSnapshot(CopyOnWriteStateTableSnapshot)
+	 */
+	@VisibleForTesting
+	void releaseSnapshot(int snapshotVersion) {
+		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
+		// Only stale reads from the result of #releaseSnapshot calls are ok.
+		synchronized (snapshotVersions) {
+			Preconditions.checkState(snapshotVersions.remove(snapshotVersion), "Attempt to release unknown snapshot version");
+			highestRequiredSnapshotVersion = snapshotVersions.isEmpty() ? 0 : snapshotVersions.last();
+		}
+	}
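Editor's note: a small sketch of the bookkeeping in {@code releaseSnapshot} above, using plain JDK types and illustrative version numbers. The ordered set makes "highest version still in use" a cheap lookup even when snapshots are released out of order.

```java
import java.util.TreeSet;

TreeSet<Integer> snapshotVersions = new TreeSet<>();
snapshotVersions.add(3);    // snapshot for table version 3 created
snapshotVersions.add(5);    // snapshot for table version 5 created
snapshotVersions.remove(5); // version 5 happens to be released first
// copy-on-write must still protect everything up to the largest remaining version
int highestRequired = snapshotVersions.isEmpty() ? 0 : snapshotVersions.last(); // -> 3
```

+
+	/**
+	 * Creates (combined) copy of the table arrays for a snapshot. This method must be called by the same Thread that
+	 * does modifications to the {@link CopyOnWriteStateTable}.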
+	 */
+	@VisibleForTesting
+	@SuppressWarnings("unchecked")
+	StateTableEntry[] snapshotTableArrays() {
+
+		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
+		// Only stale reads from the result of #releaseSnapshot calls are ok. This is why we must call this method
+		// from the same thread that does all the modifications to the table.
+		synchronized (snapshotVersions) {
+
+			// increase the table version for copy-on-write and register the snapshot
+			if (++stateTableVersion < 0) {
+				// this is just a safety net against overflows, but should never happen in practice (i.e., only after 2^31 snapshots)
+				throw new IllegalStateException("Version count overflow in CopyOnWriteStateTable. Enforcing restart.");
+			}
+
+			highestRequiredSnapshotVersion = stateTableVersion;
+			snapshotVersions.add(highestRequiredSnapshotVersion);
+		}
+
+		StateTableEntry[] table = primaryTable;
+		if (isRehashing()) {
+			// consider both tables for the snapshot, the rehash index tells us which part of the two tables we need
+			final int localRehashIndex = rehashIndex;
+			final int localCopyLength = table.length - localRehashIndex;
+			StateTableEntry[] copy = new StateTableEntry[localRehashIndex + table.length];
+			// for the primary table, take every index >= rhIdx.
+			System.arraycopy(table, localRehashIndex, copy, 0, localCopyLength);
+
+			// for the new table, we are sure that two regions contain all the entries:
+			// [0, rhIdx[ AND [table.length / 2, table.length / 2 + rhIdx[
+			table = incrementalRehashTable;
+			System.arraycopy(table, 0, copy, localCopyLength, localRehashIndex);
+			System.arraycopy(table, table.length >>> 1, copy, localCopyLength + localRehashIndex, localRehashIndex);
+
+			return copy;
+		} else {
+			// we only need to copy the primary table
+			return Arrays.copyOf(table, table.length);
+		}
+	}
+
+	/**
+	 * Allocates a table of the given capacity and sets the threshold accordingly.
+	 *
+	 * @param newCapacity must be a power of two
+	 */
+	private StateTableEntry[] makeTable(int newCapacity) {
+
+		if (MAXIMUM_CAPACITY == newCapacity) {
+			LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can lead " +
+					"to more collisions and lower performance. Please consider scaling-out your job or using a " +
+					"different keyed state backend implementation!");
+		}
+
+		threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
+		@SuppressWarnings("unchecked") StateTableEntry[] newTable
+				= (StateTableEntry[]) new StateTableEntry[newCapacity];
+		return newTable;
+	}
+
+	/**
+	 * Creates and inserts a new {@link StateTableEntry}.
+	 */
+	private StateTableEntry addNewStateTableEntry(
+			StateTableEntry[] table,
+			K key,
+			N namespace,
+			int hash) {
+
+		// small optimization that aims to avoid holding references on duplicate namespace objects
+		if (namespace.equals(lastNamespace)) {
+			namespace = lastNamespace;
+		} else {
+			lastNamespace = namespace;
+		}
+
+		int index = hash & (table.length - 1);
+		StateTableEntry newEntry = new StateTableEntry<>(
+				key,
+				namespace,
+				null,
+				hash,
+				table[index],
+				stateTableVersion,
+				stateTableVersion);
+		table[index] = newEntry;
+
+		if (table == primaryTable) {
+			++primaryTableSize;
+		} else {
+			++incrementalRehashTableSize;
+		}
+		return newEntry;
+	}
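Editor's note: a quick check of the shift arithmetic in {@code makeTable} above, with illustrative numbers. For a power-of-two capacity c, (c >> 1) + (c >> 2) is exactly 0.75 * c, so the resize threshold is computed without any floating point.

```java
int newCapacity = 1024;
// half plus a quarter of the capacity equals three quarters of it
int threshold = (newCapacity >> 1) + (newCapacity >> 2); // 512 + 256 = 768 == 0.75 * 1024
```

+
+	/**
+	 * Select the sub-table which is responsible for entries with the given hash code.
+	 *
+	 * @param hashCode the hash code which we use to decide about the table that is responsible.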
+	 * @return the sub-table that is responsible for the entry with the given hash code.
+	 */
+	private StateTableEntry[] selectActiveTable(int hashCode) {
+		return (hashCode & (primaryTable.length - 1)) >= rehashIndex ? primaryTable : incrementalRehashTable;
+	}
+
+	/**
+	 * Doubles the capacity of the hash table. Existing entries are placed in
+	 * the correct bucket on the enlarged table. If the current capacity is
+	 * MAXIMUM_CAPACITY, this method is a no-op.
+	 */
+	private void doubleCapacity() {
+
+		// There can only be one rehash in flight. From the amount of incremental rehash steps we take, this should always hold.
+		Preconditions.checkState(!isRehashing(), "There is already a rehash in progress.");
+
+		StateTableEntry[] oldTable = primaryTable;
+
+		int oldCapacity = oldTable.length;
+
+		if (oldCapacity == MAXIMUM_CAPACITY) {
+			return;
+		}
+
+		incrementalRehashTable = makeTable(oldCapacity * 2);
+	}
+
+	/**
+	 * Returns true, if an incremental rehash is in progress.
+	 */
+	@VisibleForTesting
+	boolean isRehashing() {
+		// if we rehash, the secondary table is not empty
+		return EMPTY_TABLE != incrementalRehashTable;
+	}
+
+	/**
+	 * Computes the hash for the composite of key and namespace and performs some steps of incremental rehash if
+	 * incremental rehashing is in progress.
+	 */
+	private int computeHashForOperationAndDoIncrementalRehash(K key, N namespace) {
+
+		checkKeyNamespacePreconditions(key, namespace);
+
+		if (isRehashing()) {
+			incrementalRehash();
+		}
+
+		return compositeHash(key, namespace);
+	}
+
+	/**
+	 * Runs a number of steps for incremental rehashing.
+	 */
+	@SuppressWarnings("unchecked")
+	private void incrementalRehash() {
+
+		StateTableEntry[] oldTable = primaryTable;
+		StateTableEntry[] newTable = incrementalRehashTable;
+
+		int oldCapacity = oldTable.length;
+		int newMask = newTable.length - 1;
+		int requiredVersion = highestRequiredSnapshotVersion;
+		int rhIdx = rehashIndex;
+		int transferred = 0;
+
+		// we migrate a certain minimum amount of entries from the old to the new table
+		while (transferred < MIN_TRANSFERRED_PER_INCREMENTAL_REHASH) {
+
+			StateTableEntry e = oldTable[rhIdx];
+
+			while (e != null) {
+				// copy-on-write check for entry
+				if (e.entryVersion < requiredVersion) {
+					e = new StateTableEntry<>(e, stateTableVersion);
+				}
+				StateTableEntry n = e.next;
+				int pos = e.hash & newMask;
+				e.next = newTable[pos];
+				newTable[pos] = e;
+				e = n;
+				++transferred;
+			}
+
+			oldTable[rhIdx] = null;
+			if (++rhIdx == oldCapacity) {
+				// here, the rehash is complete and we release resources and reset fields
+				primaryTable = newTable;
+				incrementalRehashTable = (StateTableEntry[]) EMPTY_TABLE;
+				primaryTableSize += incrementalRehashTableSize;
+				incrementalRehashTableSize = 0;
+				rehashIndex = 0;
+				return;
+			}
+		}
+
+		// sync our local bookkeeping with the official bookkeeping fields
+		primaryTableSize -= transferred;
+		incrementalRehashTableSize += transferred;
+		rehashIndex = rhIdx;
+	}
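Editor's note: the rehash above relies on the classic power-of-two split invariant. A self-contained sketch with illustrative values: when the capacity doubles, only one extra hash bit becomes significant, so an entry from old bucket i can land only in new bucket i or i + oldCapacity. This is also why the snapshot code earlier only needs the two regions [0, rhIdx) and [len/2, len/2 + rhIdx) of the new table.

```java
int oldCapacity = 16;                      // power of two
int hash = 0x2A;                           // arbitrary example hash (42)
int oldPos = hash & (oldCapacity - 1);     // bucket in the old table: 10
int newPos = hash & (2 * oldCapacity - 1); // bucket in the doubled table: 10 or 26
assert newPos == oldPos || newPos == oldPos + oldCapacity;
```

+
+	/**
+	 * Perform copy-on-write for entry chains. We iterate the (hopefully and probably) still cached chain, replacing
+	 * all links up to the 'untilEntry' that we actually wanted to modify.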
+	 */
+	private StateTableEntry handleChainedEntryCopyOnWrite(
+			StateTableEntry[] tab,
+			int tableIdx,
+			StateTableEntry untilEntry) {
+
+		final int required = highestRequiredSnapshotVersion;
+
+		StateTableEntry current = tab[tableIdx];
+		StateTableEntry copy;
+
+		if (current.entryVersion < required) {
+			copy = new StateTableEntry<>(current, stateTableVersion);
+			tab[tableIdx] = copy;
+		} else {
+			// nothing to do, just advance copy to current
+			copy = current;
+		}
+
+		// we iterate the chain up to 'until entry'
+		while (current != untilEntry) {
+
+			// advance current
+			current = current.next;
+
+			if (current.entryVersion < required) {
+				// copy and advance the current's copy
+				copy.next = new StateTableEntry<>(current, stateTableVersion);
+				copy = copy.next;
+			} else {
+				// nothing to do, just advance copy to current
+				copy = current;
+			}
+		}
+
+		return copy;
+	}
+
+	@SuppressWarnings("unchecked")
+	private static StateTableEntry getBootstrapEntry() {
+		return (StateTableEntry) ITERATOR_BOOTSTRAP_ENTRY;
+	}
+
+	/**
+	 * Helper function that creates and scrambles a composite hash for key and namespace.
+	 */
+	private static int compositeHash(Object key, Object namespace) {
+		// create composite key through XOR, then apply some bit-mixing for better distribution of skewed keys.
+		return MathUtils.bitMix(key.hashCode() ^ namespace.hashCode());
+	}
+
+	// Snapshotting ----------------------------------------------------------------------------------------------------
+
+	int getStateTableVersion() {
+		return stateTableVersion;
+	}
+
+	/**
+	 * Creates a snapshot of this {@link CopyOnWriteStateTable}, to be written in checkpointing. The snapshot integrity
+	 * is protected through copy-on-write from the {@link CopyOnWriteStateTable}. Users should call
+	 * {@link #releaseSnapshot(CopyOnWriteStateTableSnapshot)} after using the returned object.
+	 *
+	 * @return a snapshot from this {@link CopyOnWriteStateTable}, for checkpointing.
+	 */
+	@Override
+	public CopyOnWriteStateTableSnapshot createSnapshot() {
+		return new CopyOnWriteStateTableSnapshot<>(this);
+	}
+
+	/**
+	 * Releases a snapshot for this {@link CopyOnWriteStateTable}. This method should be called once a snapshot is no
+	 * longer needed, so that the {@link CopyOnWriteStateTable} can stop considering this snapshot for copy-on-write,
+	 * thus avoiding unnecessary object creation.
+	 *
+	 * @param snapshotToRelease the snapshot to release, which was previously created by this state table.
+	 */
+	void releaseSnapshot(CopyOnWriteStateTableSnapshot snapshotToRelease) {
+
+		Preconditions.checkArgument(snapshotToRelease.isOwner(this),
+				"Cannot release snapshot which is owned by a different state table.");
+
+		releaseSnapshot(snapshotToRelease.getSnapshotVersion());
+	}
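Editor's note: putting the snapshot-related methods together, a hypothetical checkpointing lifecycle could look as follows. The variables {@code table}, {@code range} (a KeyGroupRange) and {@code out} (a DataOutputView) are assumptions standing in for the surrounding runtime, not code from this PR. Failing to release keeps copy-on-write active for that version indefinitely.

```java
// Hypothetical lifecycle sketch: create, write out per key-group, then release.
CopyOnWriteStateTableSnapshot snapshot = table.createSnapshot();
try {
	// write one key-group after the other, as the checkpoint format requires
	for (int kg = range.getStartKeyGroup(); kg <= range.getEndKeyGroup(); ++kg) {
		snapshot.writeMappingsInKeyGroup(out, kg);
	}
} finally {
	snapshot.release(); // allows the table to stop copy-on-write for this version
}
```

+
+	// StateTableEntry -------------------------------------------------------------------------------------------------
+
+	/**
+	 * One entry in the {@link CopyOnWriteStateTable}. This is a triplet of key, namespace, and state. Thereby, key and
+	 * namespace together serve as a composite key for the state. This class also contains some management meta data for
+	 * copy-on-write, a pointer to link other {@link StateTableEntry}s to a list, and a cached hash code.
+	 *
+	 * @param type of key.
+	 * @param type of namespace.
+	 * @param type of state.
+	 */
+	static class StateTableEntry implements StateEntry {
+
+		/**
+		 * The key. Assumed to be immutable and not null.
+		 */
+		final K key;
+
+		/**
+		 * The namespace. Assumed to be immutable and not null.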
+ */ + final N namespace; + + /** + * The state. This is not final to allow exchanging the object for copy-on-write. Can be null. + */ + S state; + + /** + * Link to another {@link StateTableEntry}. This is used to resolve collisions in the + * {@link CopyOnWriteStateTable} through chaining. + */ + StateTableEntry next; + + /** + * The version of this {@link StateTableEntry}. This is meta data for copy-on-write of the table structure. + */ + int entryVersion; + + /** + * The version of the state object in this entry. This is meta data for copy-on-write of the state object itself. + */ + int stateVersion; + + /** + * The computed secondary hash for the composite of key and namespace. + */ + final int hash; + + StateTableEntry() { + this(null, null, null, 0, null, 0, 0); + } + + StateTableEntry(StateTableEntry other, int entryVersion) { + this(other.key, other.namespace, other.state, other.hash, other.next, entryVersion, other.stateVersion); + } + + StateTableEntry( + K key, + N namespace, + S state, + int hash, + StateTableEntry next, + int entryVersion, + int stateVersion) { + this.key = key; + this.namespace = namespace; + this.hash = hash; + this.next = next; + this.entryVersion = entryVersion; + this.state = state; + this.stateVersion = stateVersion; + } + + public final void setState(S value, int mapVersion) { + // naturally, we can update the state version every time we replace the old state with a different object + if (value != state) { + this.state = value; + this.stateVersion = mapVersion; + } + } + + @Override + public K getKey() { + return key; + } + + @Override + public N getNamespace() { + return namespace; + } + + @Override + public S getState() { + return state; + } + + @Override + public final boolean equals(Object o) { + if (!(o instanceof CopyOnWriteStateTable.StateTableEntry)) { + return false; + } + + StateEntry e = (StateEntry) o; + return e.getKey().equals(key) + && e.getNamespace().equals(namespace) + && Objects.equals(e.getState(), state); + } + + @Override + public final int hashCode() { + return (key.hashCode() ^ namespace.hashCode()) ^ Objects.hashCode(state); + } + + @Override + public final String toString() { + return "(" + key + "|" + namespace + ")=" + state; + } + } + + // For testing ---------------------------------------------------------------------------------------------------- + + @Override + public int sizeOfNamespace(Object namespace) { + int count = 0; + for (StateEntry entry : this) { + if (null != entry && namespace.equals(entry.getNamespace())) { + ++count; + } + } + return count; + } + + + // StateEntryIterator --------------------------------------------------------------------------------------------- + + /** + * Iterator over the entries in a {@link CopyOnWriteStateTable}. 
+	 */
+	class StateEntryIterator implements Iterator> {
+		private StateTableEntry[] activeTable;
+		private int nextTablePosition;
+		private StateTableEntry nextEntry;
+		private int expectedModCount = modCount;
+
+		StateEntryIterator() {
+			this.activeTable = primaryTable;
+			this.nextTablePosition = 0;
+			this.expectedModCount = modCount;
+			this.nextEntry = getBootstrapEntry();
+			advanceIterator();
+		}
+
+		private StateTableEntry advanceIterator() {
+
+			StateTableEntry entryToReturn = nextEntry;
+			StateTableEntry next = entryToReturn.next;
+
+			// consider both sub-tables to cover the case of rehash
+			while (next == null) {
+
+				StateTableEntry[] tab = activeTable;
+
+				while (nextTablePosition < tab.length) {
+					next = tab[nextTablePosition++];
+
+					if (next != null) {
+						nextEntry = next;
+						return entryToReturn;
+					}
+				}
+
+				if (activeTable == incrementalRehashTable) {
+					break;
+				}
+
+				activeTable = incrementalRehashTable;
+				nextTablePosition = 0;
+			}
+
+			nextEntry = next;
+			return entryToReturn;
+		}
+
+		@Override
+		public boolean hasNext() {
+			return nextEntry != null;
+		}
+
+		@Override
+		public StateTableEntry next() {
+			if (modCount != expectedModCount) {
+				throw new ConcurrentModificationException();
+			}
+
+			if (nextEntry == null) {
+				throw new NoSuchElementException();
+			}
+
+			return advanceIterator();
+		}
+
+		@Override
+		public void remove() {
+			throw new UnsupportedOperationException("Read-only iterator");
+		}
+	}
+}
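Editor's note: a hypothetical read-only iteration over the finished table; the type parameters and the table variable are assumed to be in scope. Per the class contract documented above, entries must not be cached beyond the loop body, and the iterator is fail-fast and does not support remove().

```java
// Hypothetical usage sketch: inspect all mappings without mutating them.
for (StateEntry<K, N, S> entry : stateTable) {
	System.out.println(entry.getKey() + "/" + entry.getNamespace() + " -> " + entry.getState());
}
```

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableSnapshot.java
new file mode 100644
index 0000000000000..db3b1973759f1
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableSnapshot.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+
+import java.io.IOException;
+
+/**
+ * This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. Besides
+ * holding the {@link CopyOnWriteStateTable}s internal entries at the time of the snapshot, this class is also responsible for
+ * preparing and writing the state in the process of checkpointing.
+ *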

+ * IMPORTANT: Please notice that the snapshot integrity of entries in this class relies on proper copy-on-write semantics
+ * through the {@link CopyOnWriteStateTable} that created the snapshot object, but all objects in this snapshot must be considered
+ * as READ-ONLY! The reason is that the objects held by this class may or may not be deep copies of original objects
+ * that may still be used in the {@link CopyOnWriteStateTable}. This depends for each entry on whether or not it was subject to
+ * copy-on-write operations by the {@link CopyOnWriteStateTable}. Phrased differently: the {@link CopyOnWriteStateTable} provides
+ * copy-on-write isolation for this snapshot, but this snapshot does not isolate modifications from the
+ * {@link CopyOnWriteStateTable}!
+ *
+ * @param type of key
+ * @param type of namespace
+ * @param type of state
+ */
+@Internal
+public class CopyOnWriteStateTableSnapshot
+		extends AbstractStateTableSnapshot> {
+
+	/**
+	 * Version of the {@link CopyOnWriteStateTable} when this snapshot was created. This can be used to release the snapshot.
+	 */
+	private final int snapshotVersion;
+
+	/**
+	 * The number of entries in the {@link CopyOnWriteStateTable} at the time of creating this snapshot.
+	 */
+	private final int stateTableSize;
+
+	/**
+	 * The state table entries, as by the time this snapshot was created. Objects in this array may or may not be deep
+	 * copies of the current entries in the {@link CopyOnWriteStateTable} that created this snapshot. This depends for each entry
+	 * on whether or not it was subject to copy-on-write operations by the {@link CopyOnWriteStateTable}.
+	 */
+	private final CopyOnWriteStateTable.StateTableEntry[] snapshotData;
+
+	/**
+	 * Offsets for the individual key-groups. This is lazily created when the snapshot is grouped by key-group during
+	 * the process of writing this snapshot to an output as part of checkpointing.
+	 */
+	private int[] keyGroupOffsets;
+
+	/**
+	 * Creates a new {@link CopyOnWriteStateTableSnapshot}.
+	 *
+	 * @param owningStateTable the {@link CopyOnWriteStateTable} for which this object represents a snapshot.
+	 */
+	CopyOnWriteStateTableSnapshot(CopyOnWriteStateTable owningStateTable) {
+
+		super(owningStateTable);
+		this.snapshotData = owningStateTable.snapshotTableArrays();
+		this.snapshotVersion = owningStateTable.getStateTableVersion();
+		this.stateTableSize = owningStateTable.size();
+		this.keyGroupOffsets = null;
+	}
+
+	/**
+	 * Returns the internal version of the {@link CopyOnWriteStateTable} when this snapshot was created. This value must be used to
+	 * tell the {@link CopyOnWriteStateTable} when to release this snapshot.
+	 */
+	int getSnapshotVersion() {
+		return snapshotVersion;
+	}
+
+	/**
+	 * Partitions the snapshot data by key-group. The algorithm first builds a histogram for the distribution of keys
+	 * into key-groups. Then, the histogram is accumulated to obtain the boundaries of each key-group in an array.
+	 * Last, we use the accumulated counts as write position pointers for the key-group's bins when reordering the
+	 * entries by key-group. This operation is lazily performed before the first writing of a key-group.
+	 *
+	 * As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
+	 * cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint.
+	 */
+	@SuppressWarnings("unchecked")
+	private void partitionEntriesByKeyGroup() {
+
+		// We only have to perform this step once before the first key-group is written
+		if (null != keyGroupOffsets) {
+			return;
+		}
+
+		final KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange();
+		final int totalKeyGroups = owningStateTable.keyContext.getNumberOfKeyGroups();
+		final int baseKgIdx = keyGroupRange.getStartKeyGroup();
+		final int[] histogram = new int[keyGroupRange.getNumberOfKeyGroups() + 1];
+
+		CopyOnWriteStateTable.StateTableEntry[] unfold = new CopyOnWriteStateTable.StateTableEntry[stateTableSize];
+
+		// 1) In this step we i) 'unfold' the linked list of entries to a flat array and ii) build a histogram for key-groups
+		int unfoldIndex = 0;
+		for (CopyOnWriteStateTable.StateTableEntry entry : snapshotData) {
+			while (null != entry) {
+				int effectiveKgIdx =
+						KeyGroupRangeAssignment.computeKeyGroupForKeyHash(entry.key.hashCode(), totalKeyGroups) - baseKgIdx + 1;
+				++histogram[effectiveKgIdx];
+				unfold[unfoldIndex++] = entry;
+				entry = entry.next;
+			}
+		}
+
+		// 2) We accumulate the histogram bins to obtain key-group ranges in the final array
+		for (int i = 1; i < histogram.length; ++i) {
+			histogram[i] += histogram[i - 1];
+		}
+
+		// 3) We repartition the entries by key-group, using the histogram values as write indexes
+		for (CopyOnWriteStateTable.StateTableEntry t : unfold) {
+			int effectiveKgIdx =
+					KeyGroupRangeAssignment.computeKeyGroupForKeyHash(t.key.hashCode(), totalKeyGroups) - baseKgIdx;
+			snapshotData[histogram[effectiveKgIdx]++] = t;
+		}
+
+		// 4) As a byproduct, we also created the key-group offsets
+		this.keyGroupOffsets = histogram;
+	}
+
+	@Override
+	public void release() {
+		owningStateTable.releaseSnapshot(this);
+	}
+
+	@Override
+	public void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException {
+
+		if (null == keyGroupOffsets) {
+			partitionEntriesByKeyGroup();
+		}
+
+		final CopyOnWriteStateTable.StateTableEntry[] groupedOut = snapshotData;
+		KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange();
+		int keyGroupOffsetIdx = keyGroupId - keyGroupRange.getStartKeyGroup() - 1;
+		int startOffset = keyGroupOffsetIdx < 0 ? 0 : keyGroupOffsets[keyGroupOffsetIdx];
+		int endOffset = keyGroupOffsets[keyGroupOffsetIdx + 1];
+
+		TypeSerializer keySerializer = owningStateTable.keyContext.getKeySerializer();
+		TypeSerializer namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
+		TypeSerializer stateSerializer = owningStateTable.metaInfo.getStateSerializer();
+
+		// write number of mappings in key-group
+		dov.writeInt(endOffset - startOffset);
+
+		// write mappings
+		for (int i = startOffset; i < endOffset; ++i) {
+			CopyOnWriteStateTable.StateTableEntry toWrite = groupedOut[i];
+			groupedOut[i] = null; // free asap for GC
+			namespaceSerializer.serialize(toWrite.namespace, dov);
+			keySerializer.serialize(toWrite.key, dov);
+			stateSerializer.serialize(toWrite.state, dov);
+		}
+	}
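Editor's note: the repartitioning in {@code partitionEntriesByKeyGroup} above is a counting sort. A stripped-down, self-contained sketch with plain ints (illustrative data, not from this PR):

```java
int[] keyGroupOfEntry = {2, 0, 1, 0, 2};           // key-group of each entry
int numKeyGroups = 3;
int[] histogram = new int[numKeyGroups + 1];
for (int kg : keyGroupOfEntry) {                   // 1) count entries per key-group
	++histogram[kg + 1];
}
for (int i = 1; i < histogram.length; ++i) {       // 2) prefix sums -> start offsets
	histogram[i] += histogram[i - 1];
}
int[] reordered = new int[keyGroupOfEntry.length]; // 3) scatter, advancing write pointers
for (int entry = 0; entry < keyGroupOfEntry.length; ++entry) {
	reordered[histogram[keyGroupOfEntry[entry]]++] = entry;
}
// reordered now lists entry indexes grouped by key-group: [1, 3, 2, 0, 4],
// and histogram has advanced to the end offsets of each key-group's region.
```

+
+	/**
+	 * Returns true iff the given state table is the owner of this snapshot object.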
+	 */
+	boolean isOwner(CopyOnWriteStateTable stateTable) {
+		return stateTable == owningStateTable;
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapFoldingState.java
new file mode 100644
index 0000000000000..ad955c308c330
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapFoldingState.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.api.common.functions.FoldFunction;
+import org.apache.flink.api.common.state.FoldingState;
+import org.apache.flink.api.common.state.FoldingStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+
+/**
+ * Heap-backed partitioned {@link FoldingState} that is
+ * snapshotted into files.
+ *
+ * @param The type of the key.
+ * @param The type of the namespace.
+ * @param The type of the values that can be folded into the state.
+ * @param The type of the value in the folding state.
+ */
+public class HeapFoldingState
+		extends AbstractHeapState, FoldingStateDescriptor>
+		implements FoldingState {
+
+	/** The function used to fold the state */
+	private final FoldTransformation foldTransformation;
+
+	/**
+	 * Creates a new key/value state for the given hash map of key/value pairs.
+	 *
+	 * @param stateDesc The state identifier for the state. This contains name
+	 *                  and can create a default state value.
+	 * @param stateTable The state table to use in this key/value state.
+ */ + public HeapFoldingState( + FoldingStateDescriptor stateDesc, + StateTable stateTable, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) { + super(stateDesc, stateTable, keySerializer, namespaceSerializer); + this.foldTransformation = new FoldTransformation<>(stateDesc); + } + + // ------------------------------------------------------------------------ + // state access + // ------------------------------------------------------------------------ + + public ACC get() { + return stateTable.get(currentNamespace); + } + + public void add(T value) throws IOException { + + if (value == null) { + clear(); + return; + } + + try { + stateTable.transform(currentNamespace, value, foldTransformation); + } catch (Exception e) { + throw new IOException("Could not add value to folding state.", e); + } + } + + static final class FoldTransformation implements StateTransformationFunction { + + private final FoldingStateDescriptor stateDescriptor; + private final FoldFunction foldFunction; + + FoldTransformation(FoldingStateDescriptor stateDesc) { + this.stateDescriptor = Preconditions.checkNotNull(stateDesc); + this.foldFunction = Preconditions.checkNotNull(stateDesc.getFoldFunction()); + } + + @Override + public ACC apply(ACC previousState, T value) throws Exception { + return foldFunction.fold((previousState != null) ? previousState : stateDescriptor.getDefaultValue(), value); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapListState.java new file mode 100644 index 0000000000000..ab5fff51316d0 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapListState.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.util.Preconditions; + +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; + +/** + * Heap-backed partitioned {@link ListState} that is snapshotted + * into files. + * + * @param The type of the key. + * @param The type of the namespace. + * @param The type of the value. + */ +public class HeapListState + extends AbstractHeapMergingState, ArrayList, ListState, ListStateDescriptor> + implements ListState { + + /** + * Creates a new key/value state for the given hash map of key/value pairs. + * + * @param stateDesc The state identifier for the state. 
This contains name
+	 *                  and can create a default state value.
+	 * @param stateTable The state table to use in this key/value state.
+	 */
+	public HeapListState(
+			ListStateDescriptor stateDesc,
+			StateTable> stateTable,
+			TypeSerializer keySerializer,
+			TypeSerializer namespaceSerializer) {
+		super(stateDesc, stateTable, keySerializer, namespaceSerializer);
+	}
+
+	// ------------------------------------------------------------------------
+	//  state access
+	// ------------------------------------------------------------------------
+
+	@Override
+	public Iterable get() {
+		return stateTable.get(currentNamespace);
+	}
+
+	@Override
+	public void add(V value) {
+		final N namespace = currentNamespace;
+
+		if (value == null) {
+			clear();
+			return;
+		}
+
+		final StateTable> map = stateTable;
+		ArrayList list = map.get(namespace);
+
+		if (list == null) {
+			list = new ArrayList<>();
+			map.put(namespace, list);
+		}
+		list.add(value);
+	}
+
+	@Override
+	public byte[] getSerializedValue(K key, N namespace) throws Exception {
+		Preconditions.checkState(namespace != null, "No namespace given.");
+		Preconditions.checkState(key != null, "No key given.");
+
+		ArrayList result = stateTable.get(key, namespace);
+
+		if (result == null) {
+			return null;
+		}
+
+		TypeSerializer serializer = stateDesc.getSerializer();
+
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		DataOutputViewStreamWrapper view = new DataOutputViewStreamWrapper(baos);
+
+		// write the same as RocksDB writes lists, with one ',' separator
+		for (int i = 0; i < result.size(); i++) {
+			serializer.serialize(result.get(i), view);
+			if (i < result.size() - 1) {
+				view.writeByte(',');
+			}
+		}
+		view.flush();
+
+		return baos.toByteArray();
+	}
+
+	// ------------------------------------------------------------------------
+	//  state merging
+	// ------------------------------------------------------------------------
+
+	@Override
+	protected ArrayList mergeState(ArrayList a, ArrayList b) {
+		a.addAll(b);
+		return a;
+	}
+}
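Editor's note: the ','-separated layout written by {@code getSerializedValue()} above can be sketched as follows. The long-valued list and the direct {@code writeLong} call are illustrative stand-ins for {@code serializer.serialize(...)}, mirroring how the RocksDB backend lays out lists.

```java
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;

import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.List;

ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputViewStreamWrapper view = new DataOutputViewStreamWrapper(baos);
List<Long> values = Arrays.asList(1L, 2L, 3L);
for (int i = 0; i < values.size(); i++) {
	view.writeLong(values.get(i));   // stand-in for serializer.serialize(element, view)
	if (i < values.size() - 1) {
		view.writeByte(',');         // single separator byte between consecutive elements
	}
}
view.flush();
byte[] wireFormat = baos.toByteArray(); // 8 + 1 + 8 + 1 + 8 = 26 bytes here, no trailing separator
```

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapReducingState.java
new file mode 100644
index 0000000000000..b6eed74611a18
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapReducingState.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.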
+ */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.StateTransformationFunction; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; + +/** + * Heap-backed partitioned {@link ReducingState} that is + * snapshotted into files. + * + * @param The type of the key. + * @param The type of the namespace. + * @param The type of the value. + */ +public class HeapReducingState + extends AbstractHeapMergingState, ReducingStateDescriptor> + implements ReducingState { + + private final ReduceTransformation reduceTransformation; + + /** + * Creates a new key/value state for the given hash map of key/value pairs. + * + * @param stateDesc The state identifier for the state. This contains name + * and can create a default state value. + * @param stateTable The state table to use in this kev/value state. May contain initial state. + */ + public HeapReducingState( + ReducingStateDescriptor stateDesc, + StateTable stateTable, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) { + + super(stateDesc, stateTable, keySerializer, namespaceSerializer); + this.reduceTransformation = new ReduceTransformation<>(stateDesc.getReduceFunction()); + } + + // ------------------------------------------------------------------------ + // state access + // ------------------------------------------------------------------------ + + @Override + public V get() { + return stateTable.get(currentNamespace); + } + + @Override + public void add(V value) throws IOException { + + if (value == null) { + clear(); + return; + } + + try { + stateTable.transform(currentNamespace, value, reduceTransformation); + } catch (Exception e) { + throw new IOException("Exception while applying ReduceFunction in reducing state", e); + } + } + + // ------------------------------------------------------------------------ + // state merging + // ------------------------------------------------------------------------ + + @Override + protected V mergeState(V a, V b) throws Exception { + return reduceTransformation.apply(a, b); + } + + static final class ReduceTransformation implements StateTransformationFunction { + + private final ReduceFunction reduceFunction; + + ReduceTransformation(ReduceFunction reduceFunction) { + this.reduceFunction = Preconditions.checkNotNull(reduceFunction); + } + + @Override + public V apply(V previousState, V value) throws Exception { + return previousState != null ? reduceFunction.reduce(previousState, value) : value; + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapValueState.java new file mode 100644 index 0000000000000..436c20e676b53 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/HeapValueState.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+/**
+ * Heap-backed partitioned {@link ValueState} that is snapshotted
+ * into files.
+ *
+ * @param The type of the key.
+ * @param The type of the namespace.
+ * @param The type of the value.
+ */
+public class HeapValueState
+		extends AbstractHeapState, ValueStateDescriptor>
+		implements ValueState {
+
+	/**
+	 * Creates a new key/value state for the given hash map of key/value pairs.
+	 *
+	 * @param stateDesc The state identifier for the state. This contains name
+	 *                  and can create a default state value.
+	 * @param stateTable The state table to use in this key/value state.
+	 */
+	public HeapValueState(
+			ValueStateDescriptor stateDesc,
+			StateTable stateTable,
+			TypeSerializer keySerializer,
+			TypeSerializer namespaceSerializer) {
+		super(stateDesc, stateTable, keySerializer, namespaceSerializer);
+	}
+
+	@Override
+	public V value() {
+		final V result = stateTable.get(currentNamespace);
+
+		if (result == null) {
+			return stateDesc.getDefaultValue();
+		}
+
+		return result;
+	}
+
+	@Override
+	public void update(V value) {
+
+		if (value == null) {
+			clear();
+			return;
+		}
+
+		stateTable.put(currentNamespace, value);
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/InternalKeyContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/InternalKeyContext.java
new file mode 100644
index 0000000000000..bf988ee3f519c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/InternalKeyContext.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyGroupRange;
+
+/**
+ * This interface is the current context of a keyed state. It provides information about the currently selected key in
+ * the context, the corresponding key-group, and other key and key-grouping related information.
+ *

+ * The typical use case for this interface is providing a view on the current-key selection aspects of + * {@link org.apache.flink.runtime.state.KeyedStateBackend}. + */ +@Internal +public interface InternalKeyContext { + + /** + * Used by states to access the current key. + */ + K getCurrentKey(); + + /** + * Returns the key-group to which the current key belongs. + */ + int getCurrentKeyGroupIndex(); + + /** + * Returns the number of key-groups aka max parallelism. + */ + int getNumberOfKeyGroups(); + + /** + * Returns the key groups for this backend. + */ + KeyGroupRange getKeyGroupRange(); + + /** + * {@link TypeSerializer} for the state backend key type. + */ + TypeSerializer getKeySerializer(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateEntry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateEntry.java new file mode 100644 index 0000000000000..d32e82538463c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateEntry.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +/** + * Interface of entries in a state table. Entries are triple of key, namespace, and state. + * + * @param type of key. + * @param type of namespace. + * @param type of state. + */ +public interface StateEntry { + + /** + * Returns the key of this entry. + */ + K getKey(); + + /** + * Returns the namespace of this entry. + */ + N getNamespace(); + + /** + * Returns the state of this entry. + */ + S getState(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTable.java new file mode 100644 index 0000000000000..c1db7e8f8b80b --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTable.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Base class for state tables. Accesses to state are typically scoped by the currently active key, as provided
+ * through the {@link InternalKeyContext}.
+ *
+ * @param type of key
+ * @param type of namespace
+ * @param type of state
+ */
+public abstract class StateTable {
+
+	/**
+	 * The key context view on the backend. This provides information, such as the currently active key.
+	 */
+	protected final InternalKeyContext keyContext;
+
+	/**
+	 * Combined meta information such as name and serializers for this state.
+	 */
+	protected RegisteredBackendStateMetaInfo metaInfo;
+
+	/**
+	 * @param keyContext the key context that provides the key scope for all put/get/delete operations.
+	 * @param metaInfo the meta information, including the type serializer for state copy-on-write.
+	 */
+	public StateTable(InternalKeyContext keyContext, RegisteredBackendStateMetaInfo metaInfo) {
+		this.keyContext = Preconditions.checkNotNull(keyContext);
+		this.metaInfo = Preconditions.checkNotNull(metaInfo);
+	}
+
+	// Main interface methods of StateTable -------------------------------------------------------
+
+	/**
+	 * Returns whether this {@link StateTable} is empty.
+	 *
+	 * @return {@code true} if this {@link StateTable} has no elements, {@code false}
+	 * otherwise.
+	 * @see #size()
+	 */
+	public boolean isEmpty() {
+		return size() == 0;
+	}
+
+	/**
+	 * Returns the total number of entries in this {@link StateTable}.
+	 *
+	 * @return the number of entries in this {@link StateTable}.
+	 */
+	public abstract int size();
+
+	/**
+	 * Returns the state of the mapping for the composite of active key and given namespace.
+	 *
+	 * @param namespace the namespace. Not null.
+	 * @return the states of the mapping with the specified key/namespace composite key, or {@code null}
+	 * if no mapping for the specified key is found.
+	 */
+	public abstract S get(N namespace);
+
+	/**
+	 * Returns whether this table contains a mapping for the composite of active key and given namespace.
+	 *
+	 * @param namespace the namespace in the composite key to search for. Not null.
+	 * @return {@code true} if this map contains the specified key/namespace composite key,
+	 * {@code false} otherwise.
+	 */
+	public abstract boolean containsKey(N namespace);
+
+	/**
+	 * Maps the composite of active key and given namespace to the specified state. This method should be preferred
+	 * over {@link #putAndGetOld(N, S)} when the caller is not interested in the old state.
+	 *
+	 * @param namespace the namespace. Not null.
+	 * @param state the state. Can be null.
+	 */
+	public abstract void put(N namespace, S state);
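Editor's note: a hypothetical sketch of how these key-scoped methods are used by the backend. The {@code backend.setCurrentKey} call and the variables are assumptions standing in for the surrounding runtime; the point is that the table implicitly combines the currently active key with the namespace on every operation.

```java
// While processing an element keyed by "user-17":
backend.setCurrentKey("user-17"); // selects the active key seen via the InternalKeyContext
stateTable.put(window, 42L);      // stored under the composite key (user-17, window)
Long count = stateTable.get(window); // -> 42L, same implicit key scope
```

+
+	/**
+	 * Maps the composite of active key and given namespace to the specified state. Returns the previous state that
+	 * was registered under the composite key.
+	 *
+	 * @param namespace the namespace. Not null.
+	 * @param state the state. Can be null.
+	 * @return the state of any previous mapping with the specified key or
+	 * {@code null} if there was no such mapping.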
+ */ + public abstract S putAndGetOld(N namespace, S state); + + /** + * Removes the mapping for the composite of active key and given namespace. This method should be preferred + * over {@link #removeAndGetOld(N)} when the caller is not interested in the old state. + * + * @param namespace the namespace of the mapping to remove. Not null. + */ + public abstract void remove(N namespace); + + /** + * Removes the mapping for the composite of active key and given namespace, returning the state that was + * found under the entry. + * + * @param namespace the namespace of the mapping to remove. Not null. + * @return the state of the removed mapping or {@code null} if no mapping + * for the specified key was found. + */ + public abstract S removeAndGetOld(N namespace); + + /** + * Applies the given {@link StateTransformationFunction} to the state (1st input argument), using the given value as + * second input argument. The result of {@link StateTransformationFunction#apply(Object, Object)} is then stored as + * the new state. This function is basically an optimization for get-update-put pattern. + * + * @param namespace the namespace. Not null. + * @param value the value to use in transforming the state. Can be null. + * @param transformation the transformation function. + * @throws Exception if some exception happens in the transformation function. + */ + public abstract void transform( + N namespace, + T value, + StateTransformationFunction transformation) throws Exception; + + // For queryable state ------------------------------------------------------------------------ + + /** + * Returns the state for the composite of active key and given namespace. This is typically used by + * queryable state. + * + * @param key the key. Not null. + * @param namespace the namespace. Not null. + * @return the state of the mapping with the specified key/namespace composite key, or {@code null} + * if no mapping for the specified key is found. + */ + public abstract S get(K key, N namespace); + + // Meta data setter / getter and toString ----------------------------------------------------- + + public TypeSerializer getStateSerializer() { + return metaInfo.getStateSerializer(); + } + + public TypeSerializer getNamespaceSerializer() { + return metaInfo.getNamespaceSerializer(); + } + + public RegisteredBackendStateMetaInfo getMetaInfo() { + return metaInfo; + } + + public void setMetaInfo(RegisteredBackendStateMetaInfo metaInfo) { + this.metaInfo = metaInfo; + } + + // Snapshot / Restore ------------------------------------------------------------------------- + + abstract StateTableSnapshot createSnapshot(); + + public abstract void put(K key, int keyGroup, N namespace, S state); + + // For testing -------------------------------------------------------------------------------- + + @VisibleForTesting + public abstract int sizeOfNamespace(Object namespace); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReader.java new file mode 100644 index 0000000000000..41f0abd067db9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReader.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.core.memory.DataInputView;
+
+import java.io.IOException;
+
+/**
+ * Interface for state de-serialization into {@link StateTable}s by key-group.
+ */
+interface StateTableByKeyGroupReader {
+
+	/**
+	 * Reads the data for the specified key-group from the input.
+	 *
+	 * @param div the input
+	 * @param keyGroupId the key-group to read
+	 * @throws IOException on read-related problems
+	 */
+	void readMappingsInKeyGroup(DataInputView div, int keyGroupId) throws IOException;
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReaders.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReaders.java
new file mode 100644
index 0000000000000..2b5f15a70ad36
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableByKeyGroupReaders.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+
+import java.io.IOException;
+
+/**
+ * This class provides a static factory method to create different implementations of {@link StateTableByKeyGroupReader}
+ * depending on the provided serialization format version.
+ * <p>
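+ * Version 1 stores the mappings grouped by namespace, while version 2 stores one (namespace, key, state)
+ * tuple per mapping. A restore then looks roughly like this (sketch, using the factory and reader below):
+ * <pre>{@code
+ * StateTableByKeyGroupReader reader = StateTableByKeyGroupReaders.readerForVersion(table, version);
+ * reader.readMappingsInKeyGroup(inputView, keyGroupId);
+ * }</pre>
+ * <p>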
+ * The implementations are also located here as inner classes.
+ */
+class StateTableByKeyGroupReaders {
+
+	/**
+	 * Creates a new StateTableByKeyGroupReader that inserts de-serialized mappings into the given table, using the
+	 * de-serialization algorithm that matches the given version.
+	 *
+	 * @param table the {@link StateTable} into which de-serialized mappings are inserted.
+	 * @param version version for the de-serialization algorithm.
+	 * @param <K> type of key.
+	 * @param <N> type of namespace.
+	 * @param <S> type of state.
+	 * @return the appropriate reader.
+	 */
+	static <K, N, S> StateTableByKeyGroupReader readerForVersion(StateTable<K, N, S> table, int version) {
+		switch (version) {
+			case 1:
+				return new StateTableByKeyGroupReaderV1<>(table);
+			case 2:
+				return new StateTableByKeyGroupReaderV2<>(table);
+			default:
+				throw new IllegalArgumentException("Unknown version: " + version);
+		}
+	}
+
+	abstract static class AbstractStateTableByKeyGroupReader<K, N, S>
+			implements StateTableByKeyGroupReader {
+
+		protected final StateTable<K, N, S> stateTable;
+
+		AbstractStateTableByKeyGroupReader(StateTable<K, N, S> stateTable) {
+			this.stateTable = stateTable;
+		}
+
+		@Override
+		public abstract void readMappingsInKeyGroup(DataInputView div, int keyGroupId) throws IOException;
+
+		protected TypeSerializer<K> getKeySerializer() {
+			return stateTable.keyContext.getKeySerializer();
+		}
+
+		protected TypeSerializer<N> getNamespaceSerializer() {
+			return stateTable.getNamespaceSerializer();
+		}
+
+		protected TypeSerializer<S> getStateSerializer() {
+			return stateTable.getStateSerializer();
+		}
+	}
+
+	static final class StateTableByKeyGroupReaderV1<K, N, S>
+			extends AbstractStateTableByKeyGroupReader<K, N, S> {
+
+		StateTableByKeyGroupReaderV1(StateTable<K, N, S> stateTable) {
+			super(stateTable);
+		}
+
+		@Override
+		public void readMappingsInKeyGroup(DataInputView inView, int keyGroupId) throws IOException {
+
+			if (inView.readByte() == 0) {
+				return;
+			}
+
+			final TypeSerializer<K> keySerializer = getKeySerializer();
+			final TypeSerializer<N> namespaceSerializer = getNamespaceSerializer();
+			final TypeSerializer<S> stateSerializer = getStateSerializer();
+
+			// V1 uses a namespace-compressed format: entries are grouped by namespace
+			int numNamespaces = inView.readInt();
+			for (int k = 0; k < numNamespaces; k++) {
+				N namespace = namespaceSerializer.deserialize(inView);
+				int numEntries = inView.readInt();
+				for (int l = 0; l < numEntries; l++) {
+					K key = keySerializer.deserialize(inView);
+					S state = stateSerializer.deserialize(inView);
+					stateTable.put(key, keyGroupId, namespace, state);
+				}
+			}
+		}
+	}
+
+	private static final class StateTableByKeyGroupReaderV2<K, N, S>
+			extends AbstractStateTableByKeyGroupReader<K, N, S> {
+
+		StateTableByKeyGroupReaderV2(StateTable<K, N, S> stateTable) {
+			super(stateTable);
+		}
+
+		@Override
+		public void readMappingsInKeyGroup(DataInputView inView, int keyGroupId) throws IOException {
+
+			final TypeSerializer<K> keySerializer = getKeySerializer();
+			final TypeSerializer<N> namespaceSerializer = getNamespaceSerializer();
+			final TypeSerializer<S> stateSerializer = getStateSerializer();
+
+			int numKeys = inView.readInt();
+			for (int i = 0; i < numKeys; ++i) {
+				N namespace = namespaceSerializer.deserialize(inView);
+				K key = keySerializer.deserialize(inView);
+				S state = stateSerializer.deserialize(inView);
+				stateTable.put(key, keyGroupId, namespace, state);
+			}
+		}
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableSnapshot.java
new file mode 100644
index 0000000000000..184cd59c44453
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/async/StateTableSnapshot.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+
+/**
+ * Interface for the snapshots of a {@link StateTable}. Offers a way to serialize the snapshot (by key-group). All
+ * snapshots should be released after usage.
+ */
+interface StateTableSnapshot {
+
+	/**
+	 * Writes the data for the specified key-group to the output.
+	 *
+	 * @param dov the output
+	 * @param keyGroupId the key-group to write
+	 * @throws IOException on write related problems
+	 */
+	void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException;
+
+	/**
+	 * Releases the snapshot. All snapshots should be released when they are no longer used, because some
+	 * implementations can only free their resources once all created snapshots have been released.
+	 */
+	void release();
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/async/AsyncMemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/async/AsyncMemoryStateBackend.java
new file mode 100644
index 0000000000000..54a208a9d7e71
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/async/AsyncMemoryStateBackend.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.memory.async;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.heap.async.AsyncHeapKeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+
+import java.io.IOException;
+
+/**
+ * An {@link AbstractStateBackend} that stores all its data and checkpoints in memory and has no
+ * capabilities to spill to disk. Checkpoints are serialized and the serialized data is
+ * transferred to the JobManager.
+ */
+public class AsyncMemoryStateBackend extends AbstractStateBackend {
+
+	private static final long serialVersionUID = 4109305377809414635L;
+
+	/** The default maximal size that the snapshotted memory state may have (5 MiB). */
+	private static final int DEFAULT_MAX_STATE_SIZE = 5 * 1024 * 1024;
+
+	/** The maximal size that the snapshotted memory state may have. */
+	private final int maxStateSize;
+
+	/**
+	 * Creates a new memory state backend that accepts states whose serialized forms are
+	 * up to the default state size (5 MiB).
+	 */
+	public AsyncMemoryStateBackend() {
+		this(DEFAULT_MAX_STATE_SIZE);
+	}
+
+	/**
+	 * Creates a new memory state backend that accepts states whose serialized forms are
+	 * up to the given number of bytes.
+	 *
+	 * @param maxStateSize The maximal size of the serialized state
+	 */
+	public AsyncMemoryStateBackend(int maxStateSize) {
+		this.maxStateSize = maxStateSize;
+	}
+
+	@Override
+	public String toString() {
+		return "AsyncMemoryStateBackend (data in heap memory / checkpoints to JobManager)";
+	}
+
+	@Override
+	public CheckpointStreamFactory createStreamFactory(
+			JobID jobId, String operatorIdentifier) throws IOException {
+		return new MemCheckpointStreamFactory(maxStateSize);
+	}
+
+	@Override
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
+			Environment env, JobID jobID,
+			String operatorIdentifier,
+			TypeSerializer<K> keySerializer,
+			int numberOfKeyGroups,
+			KeyGroupRange keyGroupRange,
+			TaskKvStateRegistry kvStateRegistry) throws IOException {
+
+		return new AsyncHeapKeyedStateBackend<>(
+			kvStateRegistry,
+			keySerializer,
+			env.getUserClassLoader(),
+			numberOfKeyGroups,
+			keyGroupRange);
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
new file mode 100644
index 0000000000000..255bd46fc613d
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.commons.io.FileUtils; +import org.apache.flink.api.common.JobID; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.state.filesystem.FileStateHandle; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.runtime.state.filesystem.async.AsyncFsStateBackend; +import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class AsyncFileStateBackendTest extends StateBackendTestBase { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Override + protected AsyncFsStateBackend getStateBackend() throws Exception { + File checkpointPath = tempFolder.newFolder(); + return new AsyncFsStateBackend(localFileUri(checkpointPath)); + } + + // disable these because the verification does not work for this state backend + @Override + @Test + public void testValueStateRestoreWithWrongSerializers() {} + + @Override + @Test + public void testListStateRestoreWithWrongSerializers() {} + + @Override + @Test + public void testReducingStateRestoreWithWrongSerializers() {} + + @Test + public void testStateOutputStream() throws IOException { + File basePath = tempFolder.newFolder().getAbsoluteFile(); + + try { + // the state backend has a very low in-mem state threshold (15 bytes) + FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(basePath.toURI(), 15)); + JobID jobId = new JobID(); + + // we know how FsCheckpointStreamFactory is implemented so we know where it + // will store checkpoints + File checkpointPath = new File(basePath.getAbsolutePath(), jobId.toString()); + + CheckpointStreamFactory streamFactory = backend.createStreamFactory(jobId, "test_op"); + + byte[] state1 = new byte[1274673]; + byte[] state2 = new byte[1]; + byte[] state3 = new byte[0]; + byte[] state4 = new byte[177]; + + Random rnd = new Random(); + rnd.nextBytes(state1); + rnd.nextBytes(state2); + rnd.nextBytes(state3); + rnd.nextBytes(state4); + + long checkpointId = 97231523452L; + + CheckpointStreamFactory.CheckpointStateOutputStream stream1 = + streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + + CheckpointStreamFactory.CheckpointStateOutputStream stream2 = + streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + + CheckpointStreamFactory.CheckpointStateOutputStream stream3 = + streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + + stream1.write(state1); + stream2.write(state2); 
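+ // state3 is empty (zero bytes), so stream3 is expected to produce no state handle (see assertNull below)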
+ stream3.write(state3); + + FileStateHandle handle1 = (FileStateHandle) stream1.closeAndGetHandle(); + ByteStreamStateHandle handle2 = (ByteStreamStateHandle) stream2.closeAndGetHandle(); + ByteStreamStateHandle handle3 = (ByteStreamStateHandle) stream3.closeAndGetHandle(); + + // use with try-with-resources + StreamStateHandle handle4; + try (CheckpointStreamFactory.CheckpointStateOutputStream stream4 = + streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis())) { + stream4.write(state4); + handle4 = stream4.closeAndGetHandle(); + } + + // close before accessing handle + CheckpointStreamFactory.CheckpointStateOutputStream stream5 = + streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + stream5.write(state4); + stream5.close(); + try { + stream5.closeAndGetHandle(); + fail(); + } catch (IOException e) { + // uh-huh + } + + validateBytesInStream(handle1.openInputStream(), state1); + handle1.discardState(); + assertFalse(isDirectoryEmpty(basePath)); + ensureLocalFileDeleted(handle1.getFilePath()); + + validateBytesInStream(handle2.openInputStream(), state2); + handle2.discardState(); + + // nothing was written to the stream, so it will return nothing + assertNull(handle3); + + validateBytesInStream(handle4.openInputStream(), state4); + handle4.discardState(); + assertTrue(isDirectoryEmpty(checkpointPath)); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + private static void ensureLocalFileDeleted(Path path) { + URI uri = path.toUri(); + if ("file".equals(uri.getScheme())) { + File file = new File(uri.getPath()); + assertFalse("file not properly deleted", file.exists()); + } + else { + throw new IllegalArgumentException("not a local path"); + } + } + + private static void deleteDirectorySilently(File dir) { + try { + FileUtils.deleteDirectory(dir); + } + catch (IOException ignored) {} + } + + private static boolean isDirectoryEmpty(File directory) { + if (!directory.exists()) { + return true; + } + String[] nested = directory.list(); + return nested == null || nested.length == 0; + } + + private static String localFileUri(File path) { + return path.toURI().toString(); + } + + private static void validateBytesInStream(InputStream is, byte[] data) throws IOException { + try { + byte[] holder = new byte[data.length]; + + int pos = 0; + int read; + while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) { + pos += read; + } + + assertEquals("not enough data", holder.length, pos); + assertEquals("too much data", -1, is.read()); + assertArrayEquals("wrong data", data, holder); + } finally { + is.close(); + } + } + + @Test + public void testConcurrentMapIfQueryable() throws Exception { + //unsupported + } + +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java new file mode 100644 index 0000000000000..b1a323bc9e5db --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.runtime.state.heap.async.AsyncHeapKeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.state.memory.async.AsyncMemoryStateBackend;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.HashMap;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests for the {@link AsyncMemoryStateBackend}.
+ */
+public class AsyncMemoryStateBackendTest extends StateBackendTestBase<AsyncMemoryStateBackend> {
+
+	@Override
+	protected AsyncMemoryStateBackend getStateBackend() throws Exception {
+		return new AsyncMemoryStateBackend();
+	}
+
+	// disable these because the verification does not work for this state backend
+	@Override
+	@Test
+	public void testValueStateRestoreWithWrongSerializers() {}
+
+	@Override
+	@Test
+	public void testListStateRestoreWithWrongSerializers() {}
+
+	@Override
+	@Test
+	public void testReducingStateRestoreWithWrongSerializers() {}
+
+	@Test
+	@SuppressWarnings("unchecked, deprecation")
+	public void testNumStateEntries() throws Exception {
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+
+		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
+		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
+
+		AsyncHeapKeyedStateBackend<Integer> heapBackend = (AsyncHeapKeyedStateBackend<Integer>) backend;
+
+		assertEquals(0, heapBackend.numStateEntries());
+
+		ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+		backend.setCurrentKey(0);
+		state.update("hello");
+		state.update("ciao");
+
+		assertEquals(1, heapBackend.numStateEntries());
+
+		backend.setCurrentKey(42);
+		state.update("foo");
+
+		assertEquals(2, heapBackend.numStateEntries());
+
+		backend.setCurrentKey(0);
+		state.clear();
+
+		assertEquals(1, heapBackend.numStateEntries());
+
+		backend.setCurrentKey(42);
+		state.clear();
+
+		assertEquals(0, heapBackend.numStateEntries());
+
+		backend.dispose();
+	}
+
+	@Test
+	public void testOversizedState() {
+		try {
+			MemoryStateBackend backend = new MemoryStateBackend(10);
+			CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op");
+
+			HashMap<String, Integer> state = new HashMap<>();
+			state.put("hey there", 2);
+			state.put("the crazy brown fox stumbles over a sentence
that does not contain every letter", 77); + + try { + CheckpointStreamFactory.CheckpointStateOutputStream outStream = + streamFactory.createCheckpointStateOutputStream(12, 459); + + ObjectOutputStream oos = new ObjectOutputStream(outStream); + oos.writeObject(state); + + oos.flush(); + + outStream.closeAndGetHandle(); + + fail("this should cause an exception"); + } + catch (IOException e) { + // now darling, isn't that exactly what we wanted? + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testStateStream() { + try { + MemoryStateBackend backend = new MemoryStateBackend(); + CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op"); + + HashMap state = new HashMap<>(); + state.put("hey there", 2); + state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77); + + CheckpointStreamFactory.CheckpointStateOutputStream os = streamFactory.createCheckpointStateOutputStream(1, 2); + ObjectOutputStream oos = new ObjectOutputStream(os); + oos.writeObject(state); + oos.flush(); + StreamStateHandle handle = os.closeAndGetHandle(); + + assertNotNull(handle); + + try (ObjectInputStream ois = new ObjectInputStream(handle.openInputStream())) { + assertEquals(state, ois.readObject()); + assertTrue(ois.available() <= 0); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testOversizedStateStream() { + try { + MemoryStateBackend backend = new MemoryStateBackend(10); + CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op"); + + HashMap state = new HashMap<>(); + state.put("hey there", 2); + state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77); + + CheckpointStreamFactory.CheckpointStateOutputStream os = streamFactory.createCheckpointStateOutputStream(1, 2); + ObjectOutputStream oos = new ObjectOutputStream(os); + + try { + oos.writeObject(state); + oos.flush(); + os.closeAndGetHandle(); + fail("this should cause an exception"); + } + catch (IOException e) { + // oh boy! what an exception! 
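+ // expected: the serialized map exceeds the 10 byte limit configured for this backend above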
+ } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testConcurrentMapIfQueryable() throws Exception { + //unsupported + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index c267afca5ae12..b196e718b5f25 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -61,7 +61,7 @@ public void testListStateRestoreWithWrongSerializers() {} public void testReducingStateRestoreWithWrongSerializers() {} @Test - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked, deprecation") public void testNumStateEntries() throws Exception { KeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index e821bcfed68ee..61de1e38fc478 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -39,6 +39,7 @@ import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.core.testutils.CheckedThread; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; @@ -48,9 +49,13 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.heap.AbstractHeapState; import org.apache.flink.runtime.state.heap.StateTable; +import org.apache.flink.runtime.state.heap.async.AsyncHeapKeyedStateBackend; +import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; import org.apache.flink.types.IntValue; import org.apache.flink.util.FutureUtil; +import org.apache.flink.util.IOUtils; import org.apache.flink.util.TestLogger; +import org.junit.Assert; import org.junit.Test; import java.io.IOException; @@ -60,6 +65,7 @@ import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.RunnableFuture; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -1432,6 +1438,150 @@ public void testEmptyStateCheckpointing() { } } + @Test + public void testAsyncSnapshot() throws Exception { + OneShotLatch waiter = new OneShotLatch(); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); + streamFactory.setWaiterLatch(waiter); + + AbstractKeyedStateBackend backend = null; + KeyGroupsStateHandle stateHandle = null; + + try { + backend = createKeyedBackend(IntSerializer.INSTANCE); + + if (!(backend instanceof AsyncHeapKeyedStateBackend)) { + return; + } + + ValueState valueState = backend.createValueState( + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("test", IntSerializer.INSTANCE)); + + ((KvState)valueState).setCurrentNamespace(VoidNamespace.INSTANCE); + + for (int i = 0; i < 10; ++i) { + backend.setCurrentKey(i); + valueState.update(i); + } + + RunnableFuture 
snapshot = + backend.snapshot(0L, 0L, streamFactory); + Thread runner = new Thread(snapshot); + runner.start(); + for (int i = 0; i < 20; ++i) { + backend.setCurrentKey(i); + valueState.update(i + 1); + if (10 == i) { + waiter.await(); + } + } + + runner.join(); + stateHandle = snapshot.get(); + + // test isolation + for (int i = 0; i < 20; ++i) { + backend.setCurrentKey(i); + Assert.assertEquals(i + 1, (int) valueState.value()); + } + + } finally { + if (null != backend) { + IOUtils.closeQuietly(backend); + backend.dispose(); + } + } + + Assert.assertNotNull(stateHandle); + + backend = createKeyedBackend(IntSerializer.INSTANCE); + try { + backend.restore(Collections.singleton(stateHandle)); + ValueState valueState = backend.createValueState( + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("test", IntSerializer.INSTANCE)); + + ((KvState)valueState).setCurrentNamespace(VoidNamespace.INSTANCE); + + for (int i = 0; i < 10; ++i) { + backend.setCurrentKey(i); + Assert.assertEquals(i, (int) valueState.value()); + } + + backend.setCurrentKey(11); + Assert.assertEquals(null, valueState.value()); + } finally { + if (null != backend) { + IOUtils.closeQuietly(backend); + backend.dispose(); + } + } + } + + @Test + public void testAsyncSnapshotCancellation() throws Exception { + OneShotLatch blocker = new OneShotLatch(); + OneShotLatch waiter = new OneShotLatch(); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); + streamFactory.setWaiterLatch(waiter); + streamFactory.setBlockerLatch(blocker); + streamFactory.setAfterNumberInvocations(100); + + AbstractKeyedStateBackend backend = null; + try { + backend = createKeyedBackend(IntSerializer.INSTANCE); + + if (!(backend instanceof AsyncHeapKeyedStateBackend)) { + return; + } + + ValueState valueState = backend.createValueState( + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("test", IntSerializer.INSTANCE)); + + ((KvState)valueState).setCurrentNamespace(VoidNamespace.INSTANCE); + + for (int i = 0; i < 10; ++i) { + backend.setCurrentKey(i); + valueState.update(i); + } + + RunnableFuture snapshot = + backend.snapshot(0L, 0L, streamFactory); + + Thread runner = new Thread(snapshot); + runner.start(); + + // wait until the code reached some stream read + waiter.await(); + + // close the backend to see if the close is propagated to the stream + backend.close(); + + //unblock the stream so that it can run into the IOException + blocker.trigger(); + + //dispose the backend + backend.dispose(); + + runner.join(); + + try { + snapshot.get(); + fail("Close was not propagated."); + } catch (ExecutionException ex) { + //ignore + } + + } finally { + if (null != backend) { + IOUtils.closeQuietly(backend); + backend.dispose(); + } + } + } + private static class AppendingReduce implements ReduceFunction { @Override public String reduce(String value1, String value2) throws Exception { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableTest.java new file mode 100644 index 0000000000000..fb36d67071558 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/CopyOnWriteStateTableTest.java @@ -0,0 +1,486 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.ArrayListSerializer; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo; +import org.apache.flink.runtime.state.StateTransformationFunction; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +public class CopyOnWriteStateTableTest { + + /** + * Testing the basic map operations. + */ + @Test + public void testPutGetRemoveContainsTransform() throws Exception { + RegisteredBackendStateMetaInfo> metaInfo = + new RegisteredBackendStateMetaInfo<>( + StateDescriptor.Type.UNKNOWN, + "test", + IntSerializer.INSTANCE, + new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects. 
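+ // note: the MockInternalKeyContext (defined at the end of this test) pins all keys to key-group 0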
+ + final MockInternalKeyContext keyContext = new MockInternalKeyContext<>(IntSerializer.INSTANCE); + + final CopyOnWriteStateTable> stateTable = + new CopyOnWriteStateTable<>(keyContext, metaInfo); + + ArrayList state_1_1 = new ArrayList<>(); + state_1_1.add(41); + ArrayList state_2_1 = new ArrayList<>(); + state_2_1.add(42); + ArrayList state_1_2 = new ArrayList<>(); + state_1_2.add(43); + + Assert.assertNull(stateTable.putAndGetOld(1, 1, state_1_1)); + Assert.assertEquals(state_1_1, stateTable.get(1, 1)); + Assert.assertEquals(1, stateTable.size()); + + Assert.assertNull(stateTable.putAndGetOld(2, 1, state_2_1)); + Assert.assertEquals(state_2_1, stateTable.get(2, 1)); + Assert.assertEquals(2, stateTable.size()); + + Assert.assertNull(stateTable.putAndGetOld(1, 2, state_1_2)); + Assert.assertEquals(state_1_2, stateTable.get(1, 2)); + Assert.assertEquals(3, stateTable.size()); + + Assert.assertTrue(stateTable.containsKey(2, 1)); + Assert.assertFalse(stateTable.containsKey(3, 1)); + Assert.assertFalse(stateTable.containsKey(2, 3)); + stateTable.put(2, 1, null); + Assert.assertTrue(stateTable.containsKey(2, 1)); + Assert.assertEquals(3, stateTable.size()); + Assert.assertNull(stateTable.get(2, 1)); + stateTable.put(2, 1, state_2_1); + Assert.assertEquals(3, stateTable.size()); + + Assert.assertEquals(state_2_1, stateTable.removeAndGetOld(2, 1)); + Assert.assertFalse(stateTable.containsKey(2, 1)); + Assert.assertEquals(2, stateTable.size()); + + stateTable.remove(1, 2); + Assert.assertFalse(stateTable.containsKey(1, 2)); + Assert.assertEquals(1, stateTable.size()); + + Assert.assertNull(stateTable.removeAndGetOld(4, 2)); + Assert.assertEquals(1, stateTable.size()); + + StateTransformationFunction, Integer> function = + new StateTransformationFunction, Integer>() { + @Override + public ArrayList apply(ArrayList previousState, Integer value) throws Exception { + previousState.add(value); + return previousState; + } + }; + + final int value = 4711; + stateTable.transform(1, 1, value, function); + state_1_1 = function.apply(state_1_1, value); + Assert.assertEquals(state_1_1, stateTable.get(1, 1)); + } + + /** + * This test triggers incremental rehash and tests for corruptions. + */ + @Test + public void testIncrementalRehash() { + RegisteredBackendStateMetaInfo> metaInfo = + new RegisteredBackendStateMetaInfo<>( + StateDescriptor.Type.UNKNOWN, + "test", + IntSerializer.INSTANCE, + new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects. + + final MockInternalKeyContext keyContext = new MockInternalKeyContext<>(IntSerializer.INSTANCE); + + final CopyOnWriteStateTable> stateTable = + new CopyOnWriteStateTable<>(keyContext, metaInfo); + + int insert = 0; + int remove = 0; + while (!stateTable.isRehashing()) { + stateTable.put(insert++, 0, new ArrayList()); + if (insert % 8 == 0) { + stateTable.remove(remove++, 0); + } + } + Assert.assertEquals(insert - remove, stateTable.size()); + while (stateTable.isRehashing()) { + stateTable.put(insert++, 0, new ArrayList()); + if (insert % 8 == 0) { + stateTable.remove(remove++, 0); + } + } + Assert.assertEquals(insert - remove, stateTable.size()); + + for (int i = 0; i < insert; ++i) { + if (i < remove) { + Assert.assertFalse(stateTable.containsKey(i, 0)); + } else { + Assert.assertTrue(stateTable.containsKey(i, 0)); + } + } + } + + /** + * This test does some random modifications to a state table and a reference (hash map). Then draws snapshots, + * performs more modifications and checks snapshot integrity. 
+ */ + @Test + public void testRandomModificationsAndCopyOnWriteIsolation() throws Exception { + + final RegisteredBackendStateMetaInfo> metaInfo = + new RegisteredBackendStateMetaInfo<>( + StateDescriptor.Type.UNKNOWN, + "test", + IntSerializer.INSTANCE, + new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects. + + final MockInternalKeyContext keyContext = new MockInternalKeyContext<>(IntSerializer.INSTANCE); + + final CopyOnWriteStateTable> stateTable = + new CopyOnWriteStateTable<>(keyContext, metaInfo); + + final HashMap, ArrayList> referenceMap = new HashMap<>(); + + final Random random = new Random(42); + + // holds snapshots from the map under test + CopyOnWriteStateTable.StateTableEntry>[] snapshot = null; + int snapshotSize = 0; + + // holds a reference snapshot from our reference map that we compare against + Tuple3>[] reference = null; + + int val = 0; + + + int snapshotCounter = 0; + int referencedSnapshotId = 0; + + final StateTransformationFunction, Integer> transformationFunction = + new StateTransformationFunction, Integer>() { + @Override + public ArrayList apply(ArrayList previousState, Integer value) throws Exception { + if (previousState == null) { + previousState = new ArrayList<>(); + } + previousState.add(value); + // we give back the original, attempting to spot errors in to copy-on-write + return previousState; + } + }; + + // the main loop for modifications + for (int i = 0; i < 10_000_000; ++i) { + + int key = random.nextInt(20); + int namespace = random.nextInt(4); + Tuple2 compositeKey = new Tuple2<>(key, namespace); + + int op = random.nextInt(7); + + ArrayList state = null; + ArrayList referenceState = null; + + switch (op) { + case 0: + case 1: { + state = stateTable.get(key, namespace); + referenceState = referenceMap.get(compositeKey); + if (null == state) { + state = new ArrayList<>(); + stateTable.put(key, namespace, state); + referenceState = new ArrayList<>(); + referenceMap.put(compositeKey, referenceState); + } + break; + } + case 2: { + stateTable.put(key, namespace, new ArrayList()); + referenceMap.put(compositeKey, new ArrayList()); + break; + } + case 3: { + state = stateTable.putAndGetOld(key, namespace, new ArrayList()); + referenceState = referenceMap.put(compositeKey, new ArrayList()); + break; + } + case 4: { + stateTable.remove(key, namespace); + referenceMap.remove(compositeKey); + break; + } + case 5: { + state = stateTable.removeAndGetOld(key, namespace); + referenceState = referenceMap.remove(compositeKey); + break; + } + case 6: { + final int updateValue = random.nextInt(1000); + stateTable.transform(key, namespace, updateValue, transformationFunction); + referenceMap.put(compositeKey, transformationFunction.apply( + referenceMap.remove(compositeKey), updateValue)); + break; + } + default: { + Assert.fail("Unknown op-code " + op); + } + } + + Assert.assertEquals(referenceMap.size(), stateTable.size()); + + if (state != null) { + // mutate the states a bit... 
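+ // (each mutation is applied to both the table state and the reference state so they stay comparable)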
+ if (random.nextBoolean() && !state.isEmpty()) { + state.remove(state.size() - 1); + referenceState.remove(referenceState.size() - 1); + } else { + state.add(val); + referenceState.add(val); + ++val; + } + } + + Assert.assertEquals(referenceState, state); + + // snapshot triggering / comparison / release + if (i > 0 && i % 500 == 0) { + + if (snapshot != null) { + // check our referenced snapshot + deepCheck(reference, convert(snapshot, snapshotSize)); + + if (i % 1_000 == 0) { + // draw and release some other snapshot while holding on the old snapshot + ++snapshotCounter; + stateTable.snapshotTableArrays(); + stateTable.releaseSnapshot(snapshotCounter); + } + + //release the snapshot after some time + if (i % 5_000 == 0) { + snapshot = null; + reference = null; + snapshotSize = 0; + stateTable.releaseSnapshot(referencedSnapshotId); + } + + } else { + // if there is no more referenced snapshot, we create one + ++snapshotCounter; + referencedSnapshotId = snapshotCounter; + snapshot = stateTable.snapshotTableArrays(); + snapshotSize = stateTable.size(); + reference = manualDeepDump(referenceMap); + } + } + } + } + + /** + * This tests for the copy-on-write contracts, e.g. ensures that no copy-on-write is active after all snapshots are + * released. + */ + @Test + public void testCopyOnWriteContracts() { + RegisteredBackendStateMetaInfo> metaInfo = + new RegisteredBackendStateMetaInfo<>( + StateDescriptor.Type.UNKNOWN, + "test", + IntSerializer.INSTANCE, + new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects. + + final MockInternalKeyContext keyContext = new MockInternalKeyContext<>(IntSerializer.INSTANCE); + + final CopyOnWriteStateTable> stateTable = + new CopyOnWriteStateTable<>(keyContext, metaInfo); + + ArrayList originalState1 = new ArrayList<>(1); + ArrayList originalState2 = new ArrayList<>(1); + ArrayList originalState3 = new ArrayList<>(1); + ArrayList originalState4 = new ArrayList<>(1); + ArrayList originalState5 = new ArrayList<>(1); + + originalState1.add(1); + originalState2.add(2); + originalState3.add(3); + originalState4.add(4); + originalState5.add(5); + + stateTable.put(1, 1, originalState1); + stateTable.put(2, 1, originalState2); + stateTable.put(4, 1, originalState4); + stateTable.put(5, 1, originalState5); + + // no snapshot taken, we get the original back + Assert.assertTrue(stateTable.get(1, 1) == originalState1); + CopyOnWriteStateTableSnapshot> snapshot1 = stateTable.createSnapshot(); + // after snapshot1 is taken, we get a copy... 
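+ // (the returned copy is created via the state serializer from the meta info, hence a deep copy)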
+ final ArrayList copyState = stateTable.get(1, 1); + Assert.assertFalse(copyState == originalState1); + // ...and the copy is equal + Assert.assertEquals(originalState1, copyState); + + // we make an insert AFTER snapshot1 + stateTable.put(3, 1, originalState3); + + // on repeated lookups, we get the same copy because no further snapshot was taken + Assert.assertTrue(copyState == stateTable.get(1, 1)); + + // we take snapshot2 + CopyOnWriteStateTableSnapshot> snapshot2 = stateTable.createSnapshot(); + // after the second snapshot, copy-on-write is active again for old entries + Assert.assertFalse(copyState == stateTable.get(1, 1)); + // and equality still holds + Assert.assertEquals(copyState, stateTable.get(1, 1)); + + // after releasing snapshot2 + stateTable.releaseSnapshot(snapshot2); + // we still get the original of the untouched late insert (after snapshot1) + Assert.assertTrue(originalState3 == stateTable.get(3, 1)); + // but copy-on-write is still active for older inserts (before snapshot1) + Assert.assertFalse(originalState4 == stateTable.get(4, 1)); + + // after releasing snapshot1 + stateTable.releaseSnapshot(snapshot1); + // no copy-on-write is active + Assert.assertTrue(originalState5 == stateTable.get(5, 1)); + } + + @SuppressWarnings("unchecked") + private static Tuple3[] convert(CopyOnWriteStateTable.StateTableEntry[] snapshot, int mapSize) { + + Tuple3[] result = new Tuple3[mapSize]; + int pos = 0; + for (CopyOnWriteStateTable.StateTableEntry entry : snapshot) { + while (null != entry) { + result[pos++] = new Tuple3<>(entry.getKey(), entry.getNamespace(), entry.getState()); + entry = entry.next; + } + } + Assert.assertEquals(mapSize, pos); + return result; + } + + @SuppressWarnings("unchecked") + private Tuple3>[] manualDeepDump( + HashMap, + ArrayList> map) { + + Tuple3>[] result = new Tuple3[map.size()]; + int pos = 0; + for (Map.Entry, ArrayList> entry : map.entrySet()) { + Integer key = entry.getKey().f0; + Integer namespace = entry.getKey().f1; + result[pos++] = new Tuple3<>(key, namespace, new ArrayList<>(entry.getValue())); + } + return result; + } + + private void deepCheck( + Tuple3>[] a, + Tuple3>[] b) { + + if (a == b) { + return; + } + + Assert.assertEquals(a.length, b.length); + + Comparator>> comparator = + new Comparator>>() { + + @Override + public int compare(Tuple3> o1, Tuple3> o2) { + int namespaceDiff = o1.f1 - o2.f1; + return namespaceDiff != 0 ? 
namespaceDiff : o1.f0 - o2.f0; + } + }; + + Arrays.sort(a, comparator); + Arrays.sort(b, comparator); + + for (int i = 0; i < a.length; ++i) { + Tuple3> av = a[i]; + Tuple3> bv = b[i]; + + Assert.assertEquals(av.f0, bv.f0); + Assert.assertEquals(av.f1, bv.f1); + Assert.assertEquals(av.f2, bv.f2); + } + } + + static class MockInternalKeyContext implements InternalKeyContext { + + private T key; + private final TypeSerializer serializer; + private final KeyGroupRange keyGroupRange; + + public MockInternalKeyContext(TypeSerializer serializer) { + this.serializer = serializer; + this.keyGroupRange = new KeyGroupRange(0, 0); + } + + public void setKey(T key) { + this.key = key; + } + + @Override + public T getCurrentKey() { + return key; + } + + @Override + public int getCurrentKeyGroupIndex() { + return 0; + } + + @Override + public int getNumberOfKeyGroups() { + return 1; + } + + @Override + public KeyGroupRange getKeyGroupRange() { + return keyGroupRange; + } + + @Override + public TypeSerializer getKeySerializer() { + return serializer; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapListStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapListStateTest.java new file mode 100644 index 0000000000000..a7c2d150697f0 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapListStateTest.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for the simple Java heap objects implementation of the {@link ListState}. 
+ */ +@SuppressWarnings("unchecked") +public class HeapListStateTest extends HeapStateBackendTestBase { + + @Test + public void testAddAndGet() throws Exception { + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + stateDescr.initializeSerializerUnlessSet(new ExecutionConfig()); + + final AsyncHeapKeyedStateBackend keyedBackend = createKeyedBackend(); + + try { + ListState state = + keyedBackend.createListState(VoidNamespaceSerializer.INSTANCE, stateDescr); + + AbstractHeapMergingState mergingState = + (AbstractHeapMergingState) state; + + mergingState.setCurrentNamespace(VoidNamespace.INSTANCE); + + keyedBackend.setCurrentKey("abc"); + assertNull(state.get()); + + keyedBackend.setCurrentKey("def"); + assertNull(state.get()); + state.add(17L); + state.add(11L); + assertEquals(asList(17L, 11L), state.get()); + + keyedBackend.setCurrentKey("abc"); + assertNull(state.get()); + + keyedBackend.setCurrentKey("g"); + assertNull(state.get()); + state.add(1L); + state.add(2L); + + keyedBackend.setCurrentKey("def"); + assertEquals(asList(17L, 11L), state.get()); + state.clear(); + assertNull(state.get()); + + keyedBackend.setCurrentKey("g"); + state.add(3L); + state.add(2L); + state.add(1L); + + keyedBackend.setCurrentKey("def"); + assertNull(state.get()); + + keyedBackend.setCurrentKey("g"); + assertEquals(asList(1L, 2L, 3L, 2L, 1L), state.get()); + state.clear(); + + // make sure all lists / maps are cleared + + StateTable> stateTable = + ((HeapListState) state).getStateTable(); + + assertTrue(mergingState.getStateTable().isEmpty()); + } + finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + @Test + public void testMerging() throws Exception { + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + stateDescr.initializeSerializerUnlessSet(new ExecutionConfig()); + + final Integer namespace1 = 1; + final Integer namespace2 = 2; + final Integer namespace3 = 3; + + final Set expectedResult = new HashSet<>(asList(11L, 22L, 33L, 44L, 55L)); + + final AsyncHeapKeyedStateBackend keyedBackend = createKeyedBackend(); + + try { + ListState state = keyedBackend.createListState(IntSerializer.INSTANCE, stateDescr); + + AbstractHeapMergingState mergingState = + (AbstractHeapMergingState) state; + + // populate the different namespaces + // - abc spreads the values over three namespaces + // - def spreads teh values over two namespaces (one empty) + // - ghi is empty + // - jkl has all elements already in the target namespace + // - mno has all elements already in one source namespace + + keyedBackend.setCurrentKey("abc"); + mergingState.setCurrentNamespace(namespace1); + state.add(33L); + state.add(55L); + + mergingState.setCurrentNamespace(namespace2); + state.add(22L); + state.add(11L); + + mergingState.setCurrentNamespace(namespace3); + state.add(44L); + + keyedBackend.setCurrentKey("def"); + mergingState.setCurrentNamespace(namespace1); + state.add(11L); + state.add(44L); + + mergingState.setCurrentNamespace(namespace3); + state.add(22L); + state.add(55L); + state.add(33L); + + keyedBackend.setCurrentKey("jkl"); + mergingState.setCurrentNamespace(namespace1); + state.add(11L); + state.add(22L); + state.add(33L); + state.add(44L); + state.add(55L); + + keyedBackend.setCurrentKey("mno"); + mergingState.setCurrentNamespace(namespace3); + state.add(11L); + state.add(22L); + state.add(33L); + state.add(44L); + state.add(55L); + + keyedBackend.setCurrentKey("abc"); + //TODO + 
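// merge namespace2 and namespace3 into namespace1; all five values should then be visible under namespace1 +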
mergingState.mergeNamespaces(namespace1, asList(namespace2, namespace3)); + mergingState.setCurrentNamespace(namespace1); + validateResult(state.get(), expectedResult); + + keyedBackend.setCurrentKey("def"); + mergingState.mergeNamespaces(namespace1, asList(namespace2, namespace3)); + mergingState.setCurrentNamespace(namespace1); + validateResult(state.get(), expectedResult); + + keyedBackend.setCurrentKey("ghi"); + mergingState.mergeNamespaces(namespace1, asList(namespace2, namespace3)); + mergingState.setCurrentNamespace(namespace1); + assertNull(state.get()); + + keyedBackend.setCurrentKey("jkl"); + mergingState.mergeNamespaces(namespace1, asList(namespace2, namespace3)); + mergingState.setCurrentNamespace(namespace1); + validateResult(state.get(), expectedResult); + + keyedBackend.setCurrentKey("mno"); + mergingState.mergeNamespaces(namespace1, asList(namespace2, namespace3)); + mergingState.setCurrentNamespace(namespace1); + validateResult(state.get(), expectedResult); + + // make sure all lists / maps are cleared + + keyedBackend.setCurrentKey("abc"); + mergingState.setCurrentNamespace(namespace1); + state.clear(); + + keyedBackend.setCurrentKey("def"); + mergingState.setCurrentNamespace(namespace1); + state.clear(); + + keyedBackend.setCurrentKey("ghi"); + mergingState.setCurrentNamespace(namespace1); + state.clear(); + + keyedBackend.setCurrentKey("jkl"); + mergingState.setCurrentNamespace(namespace1); + state.clear(); + + keyedBackend.setCurrentKey("mno"); + mergingState.setCurrentNamespace(namespace1); + state.clear(); + + assertTrue(mergingState.getStateTable().isEmpty()); + } + finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + private static void validateResult(Iterable values, Set expected) { + int num = 0; + for (T v : values) { + num++; + assertTrue(expected.contains(v)); + } + + assertEquals(expected.size(), num); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapReducingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapReducingStateTest.java new file mode 100644 index 0000000000000..5da0fef602f04 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapReducingStateTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state.heap.async; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.junit.Test; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for the simple Java heap objects implementation of the {@link ReducingState}. + */ +@SuppressWarnings("unchecked") +public class HeapReducingStateTest extends HeapStateBackendTestBase { + + @Test + public void testAddAndGet() throws Exception { + + final ReducingStateDescriptor stateDescr = + new ReducingStateDescriptor<>("my-state", new AddingFunction(), Long.class); + stateDescr.initializeSerializerUnlessSet(new ExecutionConfig()); + + final AsyncHeapKeyedStateBackend keyedBackend = createKeyedBackend(); + + try { + ReducingState reducingState = + keyedBackend.createReducingState(VoidNamespaceSerializer.INSTANCE, stateDescr); + + AbstractHeapMergingState state = + (AbstractHeapMergingState) reducingState; + + state.setCurrentNamespace(VoidNamespace.INSTANCE); + + keyedBackend.setCurrentKey("abc"); + assertNull(reducingState.get()); + + keyedBackend.setCurrentKey("def"); + assertNull(reducingState.get()); + reducingState.add(17L); + reducingState.add(11L); + assertEquals(28L, reducingState.get().longValue()); + + keyedBackend.setCurrentKey("abc"); + assertNull(reducingState.get()); + + keyedBackend.setCurrentKey("g"); + assertNull(reducingState.get()); + reducingState.add(1L); + reducingState.add(2L); + + keyedBackend.setCurrentKey("def"); + assertEquals(28L, reducingState.get().longValue()); + state.clear(); + assertNull(reducingState.get()); + + keyedBackend.setCurrentKey("g"); + reducingState.add(3L); + reducingState.add(2L); + reducingState.add(1L); + + keyedBackend.setCurrentKey("def"); + assertNull(reducingState.get()); + + keyedBackend.setCurrentKey("g"); + assertEquals(9L, reducingState.get().longValue()); + state.clear(); + + // make sure all lists / maps are cleared + assertTrue(state.getStateTable().isEmpty()); + } + finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + @Test + public void testMerging() throws Exception { + + final ReducingStateDescriptor stateDescr = new ReducingStateDescriptor<>( + "my-state", new AddingFunction(), Long.class); + stateDescr.initializeSerializerUnlessSet(new ExecutionConfig()); + + final Integer namespace1 = 1; + final Integer namespace2 = 2; + final Integer namespace3 = 3; + + final Long expectedResult = 165L; + + final AsyncHeapKeyedStateBackend keyedBackend = createKeyedBackend(); + + try { + final ReducingState reducingState = + keyedBackend.createReducingState(IntSerializer.INSTANCE, stateDescr); + + AbstractHeapMergingState state = + (AbstractHeapMergingState) reducingState; + + // populate the different namespaces + // - abc spreads the values over three namespaces + // - def spreads teh values over two namespaces (one empty) + // - ghi is empty + // - jkl has all elements already in the target namespace + // - mno has all elements already in one source namespace + + keyedBackend.setCurrentKey("abc"); + 
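// key abc spreads its values over all three namespaces; after merging, the AddingFunction reduces them to 11+22+33+44+55 = 165 (the expectedResult above) +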
+			state.setCurrentNamespace(namespace1);
+			reducingState.add(33L);
+			reducingState.add(55L);
+
+			state.setCurrentNamespace(namespace2);
+			reducingState.add(22L);
+			reducingState.add(11L);
+
+			state.setCurrentNamespace(namespace3);
+			reducingState.add(44L);
+
+			keyedBackend.setCurrentKey("def");
+			state.setCurrentNamespace(namespace1);
+			reducingState.add(11L);
+			reducingState.add(44L);
+
+			state.setCurrentNamespace(namespace3);
+			reducingState.add(22L);
+			reducingState.add(55L);
+			reducingState.add(33L);
+
+			keyedBackend.setCurrentKey("jkl");
+			state.setCurrentNamespace(namespace1);
+			reducingState.add(11L);
+			reducingState.add(22L);
+			reducingState.add(33L);
+			reducingState.add(44L);
+			reducingState.add(55L);
+
+			keyedBackend.setCurrentKey("mno");
+			state.setCurrentNamespace(namespace3);
+			reducingState.add(11L);
+			reducingState.add(22L);
+			reducingState.add(33L);
+			reducingState.add(44L);
+			reducingState.add(55L);
+
+			keyedBackend.setCurrentKey("abc");
+			state.mergeNamespaces(namespace1, asList(namespace2, namespace3));
+			state.setCurrentNamespace(namespace1);
+			assertEquals(expectedResult, reducingState.get());
+
+			keyedBackend.setCurrentKey("def");
+			state.mergeNamespaces(namespace1, asList(namespace2, namespace3));
+			state.setCurrentNamespace(namespace1);
+			assertEquals(expectedResult, reducingState.get());
+
+			keyedBackend.setCurrentKey("ghi");
+			state.mergeNamespaces(namespace1, asList(namespace2, namespace3));
+			state.setCurrentNamespace(namespace1);
+			assertNull(reducingState.get());
+
+			keyedBackend.setCurrentKey("jkl");
+			state.mergeNamespaces(namespace1, asList(namespace2, namespace3));
+			state.setCurrentNamespace(namespace1);
+			assertEquals(expectedResult, reducingState.get());
+
+			keyedBackend.setCurrentKey("mno");
+			state.mergeNamespaces(namespace1, asList(namespace2, namespace3));
+			state.setCurrentNamespace(namespace1);
+			assertEquals(expectedResult, reducingState.get());
+
+			// make sure all lists / maps are cleared
+
+			keyedBackend.setCurrentKey("abc");
+			state.setCurrentNamespace(namespace1);
+			state.clear();
+
+			keyedBackend.setCurrentKey("def");
+			state.setCurrentNamespace(namespace1);
+			state.clear();
+
+			keyedBackend.setCurrentKey("ghi");
+			state.setCurrentNamespace(namespace1);
+			state.clear();
+
+			keyedBackend.setCurrentKey("jkl");
+			state.setCurrentNamespace(namespace1);
+			state.clear();
+
+			keyedBackend.setCurrentKey("mno");
+			state.setCurrentNamespace(namespace1);
+			state.clear();
+
+			assertTrue(state.getStateTable().isEmpty());
+		}
+		finally {
+			keyedBackend.close();
+			keyedBackend.dispose();
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  test functions
+	// ------------------------------------------------------------------------
+
+	@SuppressWarnings("serial")
+	private static class AddingFunction implements ReduceFunction<Long> {
+
+		@Override
+		public Long reduce(Long a, Long b) {
+			return a + b;
+		}
+	}
+}
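Note on the merge semantics the two tests above pin down: per key, each namespace holds at most one reduced value; `mergeNamespaces(target, sources)` folds the source values into the target with the `ReduceFunction` and drops the sources, and a merge over only-empty namespaces leaves the target absent (hence the `assertNull` for key "ghi"). A minimal stand-alone sketch of that contract — all names here are hypothetical models, not Flink API:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;

// Minimal model of merge-namespace semantics for a reducing state (one key).
public class MergeSemanticsSketch {

	// state: namespace -> reduced value; folds the sources into the target
	static void mergeNamespaces(
			Map<Integer, Long> state,
			int target,
			Iterable<Integer> sources,
			BinaryOperator<Long> reduceFunction) {

		Long merged = state.get(target);
		for (Integer source : sources) {
			Long value = state.remove(source);
			if (value != null) {
				merged = (merged == null) ? value : reduceFunction.apply(merged, value);
			}
		}
		if (merged != null) {
			state.put(target, merged); // an absent target stays absent if nothing was merged
		}
	}

	public static void main(String[] args) {
		Map<Integer, Long> state = new HashMap<>();
		state.put(1, 33L + 55L); // namespace1, as populated for key "abc"
		state.put(2, 22L + 11L); // namespace2
		state.put(3, 44L);       // namespace3

		mergeNamespaces(state, 1, Arrays.asList(2, 3), Long::sum);
		System.out.println(state.get(1)); // 165, the expectedResult asserted above
	}
}
```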
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapStateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapStateBackendTestBase.java
new file mode 100644
index 0000000000000..0bb3775993abf
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/async/HeapStateBackendTestBase.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap.async;
+
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+
+import static org.mockito.Mockito.mock;
+
+public abstract class HeapStateBackendTestBase {
+
+	public AsyncHeapKeyedStateBackend<String> createKeyedBackend() throws Exception {
+		return new AsyncHeapKeyedStateBackend<>(
+				mock(TaskKvStateRegistry.class),
+				StringSerializer.INSTANCE,
+				HeapStateBackendTestBase.class.getClassLoader(),
+				16,
+				new KeyGroupRange(0, 15));
+	}
+}
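The base class above wires every test backend to 16 key groups covering `KeyGroupRange(0, 15)`. For orientation, here is one plausible way a key could be mapped onto such a range; the mixing step is only an assumption (it mirrors the motivation behind `MathUtils.longToIntWithBitMixing`, which `TimeWindow.hashCode` uses further down), and Flink's actual assignment logic may differ:

```java
import org.apache.flink.util.MathUtils;

public class KeyGroupSketch {

	// Hypothetical helper, for illustration only — not Flink's assignment code.
	static int assignToKeyGroup(Object key, int numberOfKeyGroups) {
		// mix the raw hash first, so keys with poor hashCode() implementations
		// still spread evenly over the groups
		int mixed = MathUtils.longToIntWithBitMixing(key.hashCode());
		return Math.floorMod(mixed, numberOfKeyGroups);
	}

	public static void main(String[] args) {
		System.out.println(assignToKeyGroup("abc", 16)); // some group in [0, 15]
	}
}
```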
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java
new file mode 100644
index 0000000000000..291f3ed024104
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.util;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+
+import java.io.IOException;
+
+/**
+ * {@link CheckpointStreamFactory} for tests that allows testing cancellation in asynchronous I/O.
+ */
+@VisibleForTesting
+@Internal
+public class BlockerCheckpointStreamFactory implements CheckpointStreamFactory {
+
+	private final int maxSize;
+	private volatile int afterNumberInvocations;
+	private volatile OneShotLatch blocker;
+	private volatile OneShotLatch waiter;
+
+	MemCheckpointStreamFactory.MemoryCheckpointOutputStream lastCreatedStream;
+
+	public MemCheckpointStreamFactory.MemoryCheckpointOutputStream getLastCreatedStream() {
+		return lastCreatedStream;
+	}
+
+	public BlockerCheckpointStreamFactory(int maxSize) {
+		this.maxSize = maxSize;
+	}
+
+	public void setAfterNumberInvocations(int afterNumberInvocations) {
+		this.afterNumberInvocations = afterNumberInvocations;
+	}
+
+	public void setBlockerLatch(OneShotLatch latch) {
+		this.blocker = latch;
+	}
+
+	public void setWaiterLatch(OneShotLatch latch) {
+		this.waiter = latch;
+	}
+
+	@Override
+	public MemCheckpointStreamFactory.MemoryCheckpointOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception {
+		this.lastCreatedStream = new MemCheckpointStreamFactory.MemoryCheckpointOutputStream(maxSize) {
+
+			// capture the current configuration, so later setter calls do not affect this stream
+			private int afterNInvocations = afterNumberInvocations;
+			private final OneShotLatch streamBlocker = blocker;
+			private final OneShotLatch streamWaiter = waiter;
+
+			@Override
+			public void write(int b) throws IOException {
+
+				// signal the test that the writer has entered write()
+				if (null != streamWaiter) {
+					streamWaiter.trigger();
+				}
+
+				if (afterNInvocations > 0) {
+					--afterNInvocations;
+				}
+
+				// park the writer once the configured number of bytes was written
+				if (0 == afterNInvocations && null != streamBlocker) {
+					try {
+						streamBlocker.await();
+					} catch (InterruptedException ignored) {
+					}
+				}
+				try {
+					super.write(b);
+				} catch (IOException ex) {
+					if (null != streamWaiter) {
+						streamWaiter.trigger();
+					}
+					throw ex;
+				}
+
+				if (0 == afterNInvocations && null != streamWaiter) {
+					streamWaiter.trigger();
+				}
+			}
+
+			@Override
+			public void close() {
+				super.close();
+				if (null != streamWaiter) {
+					streamWaiter.trigger();
+				}
+			}
+		};
+
+		return lastCreatedStream;
+	}
+
+	@Override
+	public void close() throws Exception {
+	}
+}
\ No newline at end of file
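A hedged sketch of how a test would typically drive the factory above: the waiter latch tells the test that the writer thread has started writing, and the blocker latch holds the writer at the configured byte so the test can cancel the checkpoint mid-write. Only `BlockerCheckpointStreamFactory` and `OneShotLatch` calls shown in this patch are used; the surrounding scaffolding is illustrative:

```java
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;

public class BlockerUsageSketch {

	public static void main(String[] args) throws Exception {
		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);

		OneShotLatch blocker = new OneShotLatch();
		OneShotLatch waiter = new OneShotLatch();
		streamFactory.setAfterNumberInvocations(10); // park the writer on the 10th byte
		streamFactory.setBlockerLatch(blocker);
		streamFactory.setWaiterLatch(waiter);

		final MemCheckpointStreamFactory.MemoryCheckpointOutputStream stream =
				streamFactory.createCheckpointStateOutputStream(1L, 0L);

		Thread writer = new Thread(() -> {
			try {
				for (int i = 0; i < 100; i++) {
					stream.write(i); // blocks on the blocker latch after 10 bytes
				}
			} catch (Exception ignored) {
			}
		});
		writer.start();

		waiter.await();    // returns once the writer has entered write()
		// ... here the test would cancel / close the stream and assert on it ...
		blocker.trigger(); // release the parked writer so the test can finish
		writer.join();
	}
}
```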
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/windows/TimeWindow.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/windows/TimeWindow.java
index 0d5d09130a425..a1adda13f3fa2 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/windows/TimeWindow.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/windows/TimeWindow.java
@@ -23,6 +23,7 @@
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.streaming.api.windowing.assigners.MergingWindowAssigner;
+import org.apache.flink.util.MathUtils;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -35,7 +36,7 @@
 /**
  * A {@link Window} that represents a time interval from {@code start} (inclusive) to
- * {@code start + size} (exclusive).
+ * {@code end} (exclusive).
  */
 @PublicEvolving
 public class TimeWindow extends Window {
@@ -48,14 +49,35 @@ public TimeWindow(long start, long end) {
 		this.end = end;
 	}
 
+	/**
+	 * Gets the starting timestamp of the window. This is the first timestamp that belongs
+	 * to this window.
+	 *
+	 * @return The starting timestamp of this window.
+	 */
 	public long getStart() {
 		return start;
 	}
 
+	/**
+	 * Gets the end timestamp of this window. The end timestamp is exclusive, meaning it
+	 * is the first timestamp that does not belong to this window any more.
+	 *
+	 * @return The exclusive end timestamp of this window.
+	 */
 	public long getEnd() {
 		return end;
 	}
 
+	/**
+	 * Gets the largest timestamp that still belongs to this window.
+	 *
+	 * <p>This timestamp is identical to {@code getEnd() - 1}.
+	 *
+	 * @return The largest timestamp that still belongs to this window.
+	 *
+	 * @see #getEnd()
+	 */
 	@Override
 	public long maxTimestamp() {
 		return end - 1;
@@ -77,17 +99,15 @@ public boolean equals(Object o) {
 
 	@Override
 	public int hashCode() {
-		int result = (int) (start ^ (start >>> 32));
-		result = 31 * result + (int) (end ^ (end >>> 32));
-		return result;
+		return MathUtils.longToIntWithBitMixing(start + end);
 	}
 
 	@Override
 	public String toString() {
 		return "TimeWindow{" +
-			"start=" + start +
-			", end=" + end +
-			'}';
+				"start=" + start +
+				", end=" + end +
+				'}';
 	}
 
 	/**
@@ -104,6 +124,13 @@ public TimeWindow cover(TimeWindow other) {
 		return new TimeWindow(Math.min(start, other.start), Math.max(end, other.end));
 	}
 
+	// ------------------------------------------------------------------------
+	//  Serializer
+	// ------------------------------------------------------------------------
+
+	/**
+	 * The serializer used to write the TimeWindow type.
+	 */
 	public static class Serializer extends TypeSerializer<TimeWindow> {
 		private static final long serialVersionUID = 1L;
 
@@ -152,9 +179,7 @@ public TimeWindow deserialize(DataInputView source) throws IOException {
 
 		@Override
 		public TimeWindow deserialize(TimeWindow reuse, DataInputView source) throws IOException {
-			long start = source.readLong();
-			long end = source.readLong();
-			return new TimeWindow(start, end);
+			return deserialize(source);
 		}
 
 		@Override
@@ -179,6 +204,10 @@ public int hashCode() {
 		}
 	}
 
+	// ------------------------------------------------------------------------
+	//  Utilities
+	// ------------------------------------------------------------------------
+
 	/**
 	 * Merge overlapping {@link TimeWindow}s. For use by merging
 	 * {@link org.apache.flink.streaming.api.windowing.assigners.WindowAssigner WindowAssigners}.
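One property of the new `hashCode` worth keeping in mind: because it mixes the sum `start + end`, distinct windows with equal sums hash identically. That is perfectly legal, since `equals` still tells them apart, but it is useful to know when reasoning about hash-table behavior. A quick illustrative check:

```java
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;

public class TimeWindowHashDemo {

	public static void main(String[] args) {
		TimeWindow a = new TimeWindow(1, 4);
		TimeWindow b = new TimeWindow(2, 3);
		// both mix the sum 5, so the hashes collide; equals() still differs
		System.out.println(a.hashCode() == b.hashCode()); // true
		System.out.println(a.equals(b));                  // false
	}
}
```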
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
index ee417acc4c0c0..b9028c8a84ca5 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
@@ -32,7 +32,9 @@
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointListener;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+import org.apache.flink.runtime.state.filesystem.async.AsyncFsStateBackend;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.state.memory.async.AsyncMemoryStateBackend;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -91,7 +93,7 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 	}
 
 	enum StateBackendEnum {
-		MEM, FILE, ROCKSDB_FULLY_ASYNC
+		MEM, FILE, ROCKSDB_FULLY_ASYNC, MEM_ASYNC, FILE_ASYNC
 	}
 
 	@BeforeClass
@@ -115,12 +117,18 @@ public static void stopTestCluster() {
 	@Before
 	public void initStateBackend() throws IOException {
 		switch (stateBackendEnum) {
+			case MEM_ASYNC:
+				this.stateBackend = new AsyncMemoryStateBackend(MAX_MEM_STATE_SIZE);
+				break;
+			case FILE_ASYNC: {
+				this.stateBackend = new AsyncFsStateBackend(tempFolder.newFolder().toURI());
+				break;
+			}
 			case MEM:
 				this.stateBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE);
 				break;
 			case FILE: {
-				String backups = tempFolder.newFolder().getAbsolutePath();
-				this.stateBackend = new FsStateBackend("file://" + backups);
+				this.stateBackend = new FsStateBackend(tempFolder.newFolder().toURI());
 				break;
 			}
 			case ROCKSDB_FULLY_ASYNC: {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncFileBackendEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncFileBackendEventTimeWindowCheckpointingITCase.java
new file mode 100644
index 0000000000000..a5bf10c619d64
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncFileBackendEventTimeWindowCheckpointingITCase.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+public class AsyncFileBackendEventTimeWindowCheckpointingITCase extends AbstractEventTimeWindowCheckpointingITCase {
+
+	public AsyncFileBackendEventTimeWindowCheckpointingITCase() {
+		super(StateBackendEnum.FILE_ASYNC);
+	}
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncMemBackendEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncMemBackendEventTimeWindowCheckpointingITCase.java
new file mode 100644
index 0000000000000..ef9ad37ebcb03
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AsyncMemBackendEventTimeWindowCheckpointingITCase.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+public class AsyncMemBackendEventTimeWindowCheckpointingITCase extends AbstractEventTimeWindowCheckpointingITCase {
+
+	public AsyncMemBackendEventTimeWindowCheckpointingITCase() {
+		super(StateBackendEnum.MEM_ASYNC);
+	}
+}
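For context on how the two new subclasses take effect: the abstract ITCase stores the backend chosen in `initStateBackend()` and applies it to the job's execution environment. The test body itself is not part of this patch, so the following wiring is only a sketch under the assumption that the test uses the standard `StreamExecutionEnvironment` API; the 20 MB figure is an assumed stand-in for `MAX_MEM_STATE_SIZE`:

```java
import org.apache.flink.runtime.state.memory.async.AsyncMemoryStateBackend;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class BackendWiringSketch {

	public static void main(String[] args) throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.enableCheckpointing(100); // checkpoint interval in ms, illustrative value

		// what the MEM_ASYNC case selects; size is a hypothetical stand-in
		env.setStateBackend(new AsyncMemoryStateBackend(20 * 1024 * 1024));
	}
}
```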