From 67b06e7e57f3ad3d50c72bd6a0c521eea494bdb4 Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Mon, 11 Jun 2018 19:34:47 +0200 Subject: [PATCH 1/6] [FLINK-9513] Implement TTL state wrappers factory and serializer for value with TTL --- .../common/typeutils/CompositeSerializer.java | 204 +++++++++++++++++ .../state/AbstractKeyedStateBackend.java | 18 +- .../runtime/state/KeyedStateFactory.java | 41 ++++ .../runtime/state/ttl/TtlStateFactory.java | 207 ++++++++++++++++++ 4 files changed, 454 insertions(+), 16 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java new file mode 100644 index 0000000000000..15ffff242b524 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -0,0 +1,204 @@ +package org.apache.flink.api.common.typeutils; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * Base class for composite serializers. + * + *

This class serializes a list of objects + * + * @param type of custom serialized value + */ +@SuppressWarnings("unchecked") +public abstract class CompositeSerializer extends TypeSerializer { + private final List originalSerializers; + + protected CompositeSerializer(List originalSerializers) { + Preconditions.checkNotNull(originalSerializers); + this.originalSerializers = originalSerializers; + } + + protected abstract T composeValue(List values); + + protected abstract List decomposeValue(T v); + + protected abstract CompositeSerializer createSerializerInstance(List originalSerializers); + + private T composeValueInternal(List values) { + Preconditions.checkArgument(values.size() == originalSerializers.size()); + return composeValue(values); + } + + private List decomposeValueInternal(T v) { + List values = decomposeValue(v); + Preconditions.checkArgument(values.size() == originalSerializers.size()); + return values; + } + + private CompositeSerializer createSerializerInstanceInternal(List originalSerializers) { + Preconditions.checkArgument(originalSerializers.size() == originalSerializers.size()); + return createSerializerInstance(originalSerializers); + } + + @Override + public CompositeSerializer duplicate() { + return createSerializerInstanceInternal(originalSerializers.stream() + .map(TypeSerializer::duplicate) + .collect(Collectors.toList())); + } + + @Override + public boolean isImmutableType() { + return originalSerializers.stream().allMatch(TypeSerializer::isImmutableType); + } + + @Override + public T createInstance() { + return composeValueInternal(originalSerializers.stream() + .map(TypeSerializer::createInstance) + .collect(Collectors.toList())); + } + + @Override + public T copy(T from) { + List originalValues = decomposeValueInternal(from); + return composeValueInternal( + IntStream.range(0, originalSerializers.size()) + .mapToObj(i -> originalSerializers.get(i).copy(originalValues.get(i))) + .collect(Collectors.toList())); + } + + @Override + public T copy(T from, T reuse) { + List originalFromValues = decomposeValueInternal(from); + List originalReuseValues = decomposeValueInternal(reuse); + return composeValueInternal( + IntStream.range(0, originalSerializers.size()) + .mapToObj(i -> originalSerializers.get(i).copy(originalFromValues.get(i), originalReuseValues.get(i))) + .collect(Collectors.toList())); + } + + @Override + public int getLength() { + return originalSerializers.stream().allMatch(s -> s.getLength() >= 0) ? + originalSerializers.stream().mapToInt(TypeSerializer::getLength).sum() : -1; + } + + @Override + public void serialize(T record, DataOutputView target) throws IOException { + List originalValues = decomposeValueInternal(record); + for (int i = 0; i < originalSerializers.size(); i++) { + originalSerializers.get(i).serialize(originalValues.get(i), target); + } + } + + @Override + public T deserialize(DataInputView source) throws IOException { + List originalValues = new ArrayList(); + for (TypeSerializer typeSerializer : originalSerializers) { + originalValues.add(typeSerializer.deserialize(source)); + } + return composeValueInternal(originalValues); + } + + @Override + public T deserialize(T reuse, DataInputView source) throws IOException { + List originalValues = new ArrayList(); + List originalReuseValues = decomposeValueInternal(reuse); + for (int i = 0; i < originalSerializers.size(); i++) { + originalValues.add(originalSerializers.get(i).deserialize(originalReuseValues.get(i), source)); + } + return composeValueInternal(originalValues); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + for (TypeSerializer typeSerializer : originalSerializers) { + typeSerializer.copy(source, target); + } + } + + @Override + public int hashCode() { + return originalSerializers.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof CompositeSerializer) { + CompositeSerializer other = (CompositeSerializer) obj; + return other.canEqual(this) && originalSerializers.equals(other.originalSerializers); + } else { + return false; + } + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CompositeSerializer; + } + + @Override + public TypeSerializerConfigSnapshot snapshotConfiguration() { + return new CompositeTypeSerializerConfigSnapshot(originalSerializers.toArray(new TypeSerializer[]{ })) { + @Override + public int getVersion() { + return 0; + } + }; + } + + @SuppressWarnings("unchecked") + @Override + public CompatibilityResult ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { + if (configSnapshot instanceof CompositeTypeSerializerConfigSnapshot) { + List, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs = + ((CompositeTypeSerializerConfigSnapshot) configSnapshot).getNestedSerializersAndConfigs(); + + if (previousSerializersAndConfigs.size() == originalSerializers.size()) { + + List convertSerializers = new ArrayList<>(); + boolean requiresMigration = false; + CompatibilityResult compatResult; + int i = 0; + for (Tuple2, TypeSerializerConfigSnapshot> f : previousSerializersAndConfigs) { + compatResult = CompatibilityUtil.resolveCompatibilityResult( + f.f0, + UnloadableDummyTypeSerializer.class, + f.f1, + originalSerializers.get(i)); + + if (compatResult.isRequiresMigration()) { + requiresMigration = true; + + if (compatResult.getConvertDeserializer() != null) { + convertSerializers.add(new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer())); + } else { + return CompatibilityResult.requiresMigration(); + } + } + + i++; + } + + if (!requiresMigration) { + return CompatibilityResult.compatible(); + } else { + return CompatibilityResult.requiresMigration( + createSerializerInstanceInternal(convertSerializers)); + } + } + } + + return CompatibilityResult.requiresMigration(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 1690240a171d8..210b845b8b485 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -48,7 +48,8 @@ public abstract class AbstractKeyedStateBackend implements KeyedStateBackend, Snapshotable, Collection>, Closeable, - CheckpointListener { + CheckpointListener, + KeyedStateFactory{ /** {@link TypeSerializer} for our key. */ protected final TypeSerializer keySerializer; @@ -133,21 +134,6 @@ public void dispose() { keyValueStatesByName.clear(); } - /** - * Creates and returns a new {@link State}. - * - * @param namespaceSerializer TypeSerializer for the state namespace. - * @param stateDesc The {@code StateDescriptor} that contains the name of the state. - * - * @param The type of the namespace. - * @param The type of the stored state value. - * @param The type of the public API state. - * @param The type of internal state. - */ - public abstract IS createState( - TypeSerializer namespaceSerializer, - StateDescriptor stateDesc) throws Exception; - /** * @see KeyedStateBackend */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java new file mode 100644 index 0000000000000..c8944fd2710c4 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +/** This factory produces concrete state objects in backends. */ +public interface KeyedStateFactory { + /** + * Creates and returns a new {@link State}. + * + * @param namespaceSerializer TypeSerializer for the state namespace. + * @param stateDesc The {@code StateDescriptor} that contains the name of the state. + * + * @param The type of the namespace. + * @param The type of the stored state value. + * @param The type of the public API state. + * @param The type of internal state. + */ + IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java new file mode 100644 index 0000000000000..fcdb9a1eaf83c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl; + +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.CompositeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyedStateFactory; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * This state factory wraps state objects, produced by backends, with TTL logic. + */ +public class TtlStateFactory { + public static IS createStateAndWrapWithTtlIfEnabled( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc, + KeyedStateFactory originalStateFactory, + TtlConfig ttlConfig, + TtlTimeProvider timeProvider) throws Exception { + return ttlConfig.getTtlUpdateType() == TtlUpdateType.Disabled ? + originalStateFactory.createState(namespaceSerializer, stateDesc) : + new TtlStateFactory(originalStateFactory, ttlConfig, timeProvider) + .createState(namespaceSerializer, stateDesc); + } + + private final Map, StateFactory> stateFactories; + + private final KeyedStateFactory originalStateFactory; + private final TtlConfig ttlConfig; + private final TtlTimeProvider timeProvider; + + private TtlStateFactory(KeyedStateFactory originalStateFactory, TtlConfig ttlConfig, TtlTimeProvider timeProvider) { + this.originalStateFactory = originalStateFactory; + this.ttlConfig = ttlConfig; + this.timeProvider = timeProvider; + this.stateFactories = createStateFactories(); + } + + private Map, StateFactory> createStateFactories() { + return Stream.of( + Tuple2.of(ValueStateDescriptor.class, (StateFactory) this::createValueState), + Tuple2.of(ListStateDescriptor.class, (StateFactory) this::createListState), + Tuple2.of(MapStateDescriptor.class, (StateFactory) this::createMapState), + Tuple2.of(ReducingStateDescriptor.class, (StateFactory) this::createReducingState), + Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) this::createAggregatingState), + Tuple2.of(FoldingStateDescriptor.class, (StateFactory) this::createFoldingState) + ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); + } + + private interface StateFactory { + IS create( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception; + } + + private IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + StateFactory stateFactory = stateFactories.get(stateDesc.getClass()); + if (stateFactory == null) { + String message = String.format("State %s is not supported by %s", + stateDesc.getClass(), TtlStateFactory.class); + throw new FlinkRuntimeException(message); + } + return stateFactory.create(namespaceSerializer, stateDesc); + } + + @SuppressWarnings("unchecked") + private IS createValueState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + SV defVal = stateDesc.getDefaultValue(); + TtlValue ttlDefVal = defVal == null ? null : new TtlValue<>(defVal, Long.MAX_VALUE); + ValueStateDescriptor> ttlDescriptor = new ValueStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()), ttlDefVal); + return (IS) new TtlValueState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private IS createListState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + ListStateDescriptor listStateDesc = (ListStateDescriptor) stateDesc; + ListStateDescriptor> ttlDescriptor = new ListStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer())); + return (IS) new TtlListState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, listStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private IS createMapState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + MapStateDescriptor mapStateDesc = (MapStateDescriptor) stateDesc; + MapStateDescriptor> ttlDescriptor = new MapStateDescriptor<>( + stateDesc.getName(), + mapStateDesc.getKeySerializer(), + new TtlSerializer<>(mapStateDesc.getValueSerializer())); + return (IS) new TtlMapState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, mapStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private IS createReducingState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + ReducingStateDescriptor reducingStateDesc = (ReducingStateDescriptor) stateDesc; + ReducingStateDescriptor> ttlDescriptor = new ReducingStateDescriptor<>( + stateDesc.getName(), + new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlReducingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private IS createAggregatingState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + AggregatingStateDescriptor aggregatingStateDescriptor = + (AggregatingStateDescriptor) stateDesc; + TtlAggregateFunction ttlAggregateFunction = new TtlAggregateFunction<>( + aggregatingStateDescriptor.getAggregateFunction(), ttlConfig, timeProvider); + AggregatingStateDescriptor, OUT> ttlDescriptor = new AggregatingStateDescriptor<>( + stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlAggregatingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction); + } + + @SuppressWarnings("unchecked") + private IS createFoldingState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + FoldingStateDescriptor foldingStateDescriptor = (FoldingStateDescriptor) stateDesc; + SV initAcc = stateDesc.getDefaultValue(); + TtlValue ttlInitAcc = initAcc == null ? null : new TtlValue<>(initAcc, Long.MAX_VALUE); + FoldingStateDescriptor> ttlDescriptor = new FoldingStateDescriptor<>( + stateDesc.getName(), + ttlInitAcc, + new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlFoldingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + private static class TtlSerializer extends CompositeSerializer> { + TtlSerializer(TypeSerializer userValueSerializer) { + super(Arrays.asList(userValueSerializer, new LongSerializer())); + } + + @Override + @SuppressWarnings("unchecked") + protected TtlValue composeValue(List values) { + return new TtlValue<>((T) values.get(0), (Long) values.get(1)); + } + + @Override + protected List decomposeValue(TtlValue v) { + return Arrays.asList(v.getUserValue(), v.getExpirationTimestamp()); + } + + @Override + @SuppressWarnings("unchecked") + protected CompositeSerializer> createSerializerInstance(List typeSerializers) { + return new TtlSerializer<>(typeSerializers.get(0)); + } + } +} From 7faacbf184a5f6a4faf057f2830d9f269bfc05ab Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Fri, 22 Jun 2018 13:47:26 +0200 Subject: [PATCH 2/6] prefer loops over streams in CompositeSerializer and add tests for it --- .../common/typeutils/CompositeSerializer.java | 244 +++++++++++------- .../typeutils/CompositeSerializerTest.java | 208 +++++++++++++++ .../state/AbstractKeyedStateBackend.java | 3 +- .../runtime/state/KeyedStateBackend.java | 6 +- .../runtime/state/KeyedStateFactory.java | 5 +- .../runtime/state/ttl/TtlStateFactory.java | 79 +++--- 6 files changed, 405 insertions(+), 140 deletions(-) create mode 100644 flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java index 15ffff242b524..2cb1e4435e632 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.flink.api.common.typeutils; import org.apache.flink.api.java.tuple.Tuple2; @@ -5,138 +23,178 @@ import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.util.Preconditions; +import javax.annotation.Nonnull; + import java.io.IOException; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import java.util.Objects; /** * Base class for composite serializers. * - *

This class serializes a list of objects + *

This class serializes a composite type using array of its field serializers. + * Fields are indexed the same way as their serializers. * * @param type of custom serialized value */ -@SuppressWarnings("unchecked") public abstract class CompositeSerializer extends TypeSerializer { - private final List originalSerializers; + private static final long serialVersionUID = 1L; - protected CompositeSerializer(List originalSerializers) { - Preconditions.checkNotNull(originalSerializers); - this.originalSerializers = originalSerializers; + protected final TypeSerializer[] fieldSerializers; + final boolean isImmutableTargetType; + private final int length; + + @SuppressWarnings("unchecked") + protected CompositeSerializer(boolean isImmutableTargetType, TypeSerializer ... fieldSerializers) { + Preconditions.checkNotNull(fieldSerializers); + Preconditions.checkArgument(Arrays.stream(fieldSerializers).allMatch(Objects::nonNull)); + this.isImmutableTargetType = isImmutableTargetType; + this.fieldSerializers = (TypeSerializer[]) fieldSerializers; + this.length = calcLength(); } - protected abstract T composeValue(List values); + /** Create new instance from its fields. */ + public abstract T createInstance(@Nonnull Object ... values); - protected abstract List decomposeValue(T v); + /** Modify field of existing instance. Supported only by mutable types. */ + protected abstract void setField(@Nonnull T value, int index, Object fieldValue); - protected abstract CompositeSerializer createSerializerInstance(List originalSerializers); + /** Get field of existing instance. */ + protected abstract Object getField(@Nonnull T value, int index); - private T composeValueInternal(List values) { - Preconditions.checkArgument(values.size() == originalSerializers.size()); - return composeValue(values); - } - - private List decomposeValueInternal(T v) { - List values = decomposeValue(v); - Preconditions.checkArgument(values.size() == originalSerializers.size()); - return values; - } - - private CompositeSerializer createSerializerInstanceInternal(List originalSerializers) { - Preconditions.checkArgument(originalSerializers.size() == originalSerializers.size()); - return createSerializerInstance(originalSerializers); - } + /** Factory for concrete serializer. */ + protected abstract CompositeSerializer createSerializerInstance(TypeSerializer ... originalSerializers); @Override public CompositeSerializer duplicate() { - return createSerializerInstanceInternal(originalSerializers.stream() - .map(TypeSerializer::duplicate) - .collect(Collectors.toList())); + TypeSerializer[] duplicatedSerializers = new TypeSerializer[fieldSerializers.length]; + boolean stateful = false; + for (int index = 0; index < fieldSerializers.length; index++) { + duplicatedSerializers[index] = fieldSerializers[index].duplicate(); + if (fieldSerializers[index] != duplicatedSerializers[index]) { + stateful = true; + } + } + return stateful ? createSerializerInstance(duplicatedSerializers) : this; } @Override public boolean isImmutableType() { - return originalSerializers.stream().allMatch(TypeSerializer::isImmutableType); + for (TypeSerializer fieldSerializer : fieldSerializers) { + if (!fieldSerializer.isImmutableType()) { + return false; + } + } + return isImmutableTargetType; } @Override public T createInstance() { - return composeValueInternal(originalSerializers.stream() - .map(TypeSerializer::createInstance) - .collect(Collectors.toList())); + Object[] fields = new Object[fieldSerializers.length]; + for (int index = 0; index < fieldSerializers.length; index++) { + fields[index] = fieldSerializers[index].createInstance(); + } + return createInstance(fields); } @Override public T copy(T from) { - List originalValues = decomposeValueInternal(from); - return composeValueInternal( - IntStream.range(0, originalSerializers.size()) - .mapToObj(i -> originalSerializers.get(i).copy(originalValues.get(i))) - .collect(Collectors.toList())); + Preconditions.checkNotNull(from); + Object[] fields = new Object[fieldSerializers.length]; + for (int index = 0; index < fieldSerializers.length; index++) { + fields[index] = fieldSerializers[index].copy(getField(from, index)); + } + return createInstance(fields); } @Override public T copy(T from, T reuse) { - List originalFromValues = decomposeValueInternal(from); - List originalReuseValues = decomposeValueInternal(reuse); - return composeValueInternal( - IntStream.range(0, originalSerializers.size()) - .mapToObj(i -> originalSerializers.get(i).copy(originalFromValues.get(i), originalReuseValues.get(i))) - .collect(Collectors.toList())); + Preconditions.checkNotNull(from); + Preconditions.checkNotNull(reuse); + Object[] fields = new Object[fieldSerializers.length]; + for (int index = 0; index < fieldSerializers.length; index++) { + fields[index] = fieldSerializers[index].copy(getField(from, index), getField(reuse, index)); + } + return fromFields(fields, reuse); } @Override public int getLength() { - return originalSerializers.stream().allMatch(s -> s.getLength() >= 0) ? - originalSerializers.stream().mapToInt(TypeSerializer::getLength).sum() : -1; + return length; + } + + private int calcLength() { + int totalLength = 0; + for (TypeSerializer fieldSerializer : fieldSerializers) { + if (fieldSerializer.getLength() < 0) { + return -1; + } + totalLength += fieldSerializer.getLength(); + } + return totalLength; } @Override public void serialize(T record, DataOutputView target) throws IOException { - List originalValues = decomposeValueInternal(record); - for (int i = 0; i < originalSerializers.size(); i++) { - originalSerializers.get(i).serialize(originalValues.get(i), target); + Preconditions.checkNotNull(record); + Preconditions.checkNotNull(target); + for (int index = 0; index < fieldSerializers.length; index++) { + fieldSerializers[index].serialize(getField(record, index), target); } } @Override public T deserialize(DataInputView source) throws IOException { - List originalValues = new ArrayList(); - for (TypeSerializer typeSerializer : originalSerializers) { - originalValues.add(typeSerializer.deserialize(source)); + Preconditions.checkNotNull(source); + Object[] fields = new Object[fieldSerializers.length]; + for (int i = 0; i < fieldSerializers.length; i++) { + fields[i] = fieldSerializers[i].deserialize(source); } - return composeValueInternal(originalValues); + return createInstance(fields); } @Override public T deserialize(T reuse, DataInputView source) throws IOException { - List originalValues = new ArrayList(); - List originalReuseValues = decomposeValueInternal(reuse); - for (int i = 0; i < originalSerializers.size(); i++) { - originalValues.add(originalSerializers.get(i).deserialize(originalReuseValues.get(i), source)); + Preconditions.checkNotNull(reuse); + Preconditions.checkNotNull(source); + Object[] fields = new Object[fieldSerializers.length]; + for (int index = 0; index < fieldSerializers.length; index++) { + fields[index] = fieldSerializers[index].deserialize(getField(reuse, index), source); + } + return fromFields(fields, reuse); + } + + private T fromFields(Object[] fields, T reuse) { + if (isImmutableTargetType) { + return createInstance(fields); + } else { + for (int index = 0; index < fields.length; index++) { + setField(reuse, index, fields[index]); + } + return reuse; } - return composeValueInternal(originalValues); } @Override public void copy(DataInputView source, DataOutputView target) throws IOException { - for (TypeSerializer typeSerializer : originalSerializers) { + Preconditions.checkNotNull(source); + Preconditions.checkNotNull(target); + for (TypeSerializer typeSerializer : fieldSerializers) { typeSerializer.copy(source, target); } } @Override public int hashCode() { - return originalSerializers.hashCode(); + return Arrays.hashCode(fieldSerializers); } @Override public boolean equals(Object obj) { if (obj instanceof CompositeSerializer) { CompositeSerializer other = (CompositeSerializer) obj; - return other.canEqual(this) && originalSerializers.equals(other.originalSerializers); + return other.canEqual(this) && Arrays.equals(fieldSerializers, other.fieldSerializers); } else { return false; } @@ -149,7 +207,7 @@ public boolean canEqual(Object obj) { @Override public TypeSerializerConfigSnapshot snapshotConfiguration() { - return new CompositeTypeSerializerConfigSnapshot(originalSerializers.toArray(new TypeSerializer[]{ })) { + return new CompositeTypeSerializerConfigSnapshot(fieldSerializers) { @Override public int getVersion() { return 0; @@ -157,48 +215,44 @@ public int getVersion() { }; } - @SuppressWarnings("unchecked") @Override public CompatibilityResult ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { if (configSnapshot instanceof CompositeTypeSerializerConfigSnapshot) { List, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs = ((CompositeTypeSerializerConfigSnapshot) configSnapshot).getNestedSerializersAndConfigs(); + if (previousSerializersAndConfigs.size() == fieldSerializers.length) { + return ensureFieldCompatibility(previousSerializersAndConfigs); + } + } + return CompatibilityResult.requiresMigration(); + } - if (previousSerializersAndConfigs.size() == originalSerializers.size()) { - - List convertSerializers = new ArrayList<>(); - boolean requiresMigration = false; - CompatibilityResult compatResult; - int i = 0; - for (Tuple2, TypeSerializerConfigSnapshot> f : previousSerializersAndConfigs) { - compatResult = CompatibilityUtil.resolveCompatibilityResult( - f.f0, - UnloadableDummyTypeSerializer.class, - f.f1, - originalSerializers.get(i)); - - if (compatResult.isRequiresMigration()) { - requiresMigration = true; - - if (compatResult.getConvertDeserializer() != null) { - convertSerializers.add(new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer())); - } else { - return CompatibilityResult.requiresMigration(); - } - } - - i++; - } - - if (!requiresMigration) { - return CompatibilityResult.compatible(); + @SuppressWarnings("unchecked") + private CompatibilityResult ensureFieldCompatibility( + List, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs) { + TypeSerializer[] convertSerializers = new TypeSerializer[fieldSerializers.length]; + boolean requiresMigration = false; + for (int index = 0; index < previousSerializersAndConfigs.size(); index++) { + CompatibilityResult compatResult = + resolveFieldCompatibility(previousSerializersAndConfigs, index); + if (compatResult.isRequiresMigration()) { + requiresMigration = true; + if (compatResult.getConvertDeserializer() != null) { + convertSerializers[index] = new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer()); } else { - return CompatibilityResult.requiresMigration( - createSerializerInstanceInternal(convertSerializers)); + return CompatibilityResult.requiresMigration(); } } } + return requiresMigration ? + CompatibilityResult.requiresMigration(createSerializerInstance(convertSerializers)) : + CompatibilityResult.compatible(); + } - return CompatibilityResult.requiresMigration(); + private CompatibilityResult resolveFieldCompatibility( + List, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs, int index) { + return CompatibilityUtil.resolveCompatibilityResult( + previousSerializersAndConfigs.get(index).f0, UnloadableDummyTypeSerializer.class, + previousSerializersAndConfigs.get(index).f1, fieldSerializers[index]); } } diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java new file mode 100644 index 0000000000000..054218a5d3ab4 --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.typeutils; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TypeExtractor; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.function.IntFunction; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; + +/** Test suite for {@link CompositeSerializer}. */ +public class CompositeSerializerTest { + private static final ExecutionConfig execConf = new ExecutionConfig(); + + private static final List, Object[]>> TEST_FIELD_SERIALIZERS = Arrays.asList( + Tuple2.of(BooleanSerializer.INSTANCE, new Object[] { true, false }), + Tuple2.of(LongSerializer.INSTANCE, new Object[] { 1L, 23L }), + Tuple2.of(StringSerializer.INSTANCE, new Object[] { "teststr1", "teststr2" }), + Tuple2.of(TypeInformation.of(Pojo.class).createSerializer(execConf), + new Object[] { new Pojo(3, new String[] { "123", "456" }), new Pojo(6, new String[] { }) }) + ); + + @Test + public void testSingleFieldSerializer() { + TEST_FIELD_SERIALIZERS.forEach(t -> { + @SuppressWarnings("unchecked") + TypeSerializer[] fieldSerializers = new TypeSerializer[] { t.f0 }; + List[] instances = Arrays.stream(t.f1) + .map(Arrays::asList) + .toArray((IntFunction[]>) List[]::new); + runTests(t.f0.getLength(), fieldSerializers, instances); + }); + } + + @Test + public void testPairFieldSerializer() { + TEST_FIELD_SERIALIZERS.forEach(t1 -> + TEST_FIELD_SERIALIZERS.forEach(t2 -> { + @SuppressWarnings("unchecked") + TypeSerializer[] fieldSerializers = new TypeSerializer[] { t1.f0, t2.f0 }; + List[] instances = IntStream.range(0, t1.f1.length) + .mapToObj(i -> Arrays.asList(t1.f1[i], t2.f1[i])) + .toArray((IntFunction[]>) List[]::new); + runTests(getLength(fieldSerializers), fieldSerializers, instances); + })); + } + + @Test + public void testAllFieldSerializer() { + @SuppressWarnings("unchecked") + TypeSerializer[] fieldSerializers = TEST_FIELD_SERIALIZERS.stream() + .map(t -> (TypeSerializer) t.f0) + .toArray((IntFunction[]>) TypeSerializer[]::new); + List[] instances = IntStream.range(0, TEST_FIELD_SERIALIZERS.get(0).f1.length) + .mapToObj(CompositeSerializerTest::getTestCase) + .toArray((IntFunction[]>) List[]::new); + runTests(getLength(fieldSerializers), fieldSerializers, instances); + } + + // needs to be Arrays.ArrayList for all tests + private static List getTestCase(int index) { + return Arrays.asList(TEST_FIELD_SERIALIZERS.stream() + .map(t -> t.f1[index]) + .toArray(Object[]::new)); + } + + private static int getLength(TypeSerializer[] fieldSerializers) { + return Arrays.stream(fieldSerializers).allMatch(fs -> fs.getLength() > 0) ? + Arrays.stream(fieldSerializers).mapToInt(TypeSerializer::getLength).sum() : -1; + } + + @SuppressWarnings("unchecked") + private void runTests( + int length, + TypeSerializer[] fieldSerializers, + List ... instances) { + try { + for (boolean immutability : Arrays.asList(true, false)) { + TypeSerializer> serializer = new TestListCompositeSerializer(immutability, fieldSerializers); + CompositeSerializerTestInstance test = new CompositeSerializerTestInstance(serializer, length, instances); + test.testAll(); + } + } + catch (Exception e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + Assert.fail(e.getMessage()); + } + } + + private static class Pojo { + public int f1; + public String[] f2; + + private Pojo(int f1, String[] f2) { + this.f1 = f1; + this.f2 = f2; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Pojo pojo = (Pojo) o; + return f1 == pojo.f1 && + Arrays.equals(f2, pojo.f2); + } + + @Override + public int hashCode() { + + int result = Objects.hash(f1); + result = 31 * result + Arrays.hashCode(f2); + return result; + } + + @Override + public String toString() { + return "Pojo{" + + "f1=" + f1 + + ", f2=" + Arrays.toString(f2) + + '}'; + } + } + + private static class TestListCompositeSerializer extends CompositeSerializer> { + TestListCompositeSerializer(boolean isImmutableTargetType, TypeSerializer... fieldSerializers) { + super(isImmutableTargetType, fieldSerializers); + } + + @Override + public List createInstance(@Nonnull Object... values) { + return Arrays.asList(values); + } + + @Override + protected void setField(@Nonnull List value, int index, Object fieldValue) { + if (isImmutableTargetType) { + throw new UnsupportedOperationException("Type is immutable"); + } else { + value.set(index, fieldValue); + } + } + + @Override + protected Object getField(@Nonnull List value, int index) { + return value.get(index); + } + + @Override + protected CompositeSerializer> createSerializerInstance(TypeSerializer... originalSerializers) { + return new TestListCompositeSerializer(isImmutableTargetType, originalSerializers); + } + } + + private static class CompositeSerializerTestInstance extends SerializerTestInstance> { + @SuppressWarnings("unchecked") + CompositeSerializerTestInstance( + TypeSerializer> serializer, + int length, + List ... testData) { + super(serializer, getCls(testData[0]), length, testData); + } + + private static Class> getCls(List instance) { + return TypeExtractor.getForObject(instance).getTypeClass(); + } + + protected void deepEquals(String message, List should, List is) { + assertEquals(message, should, is); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 210b845b8b485..8ce25b6fa0f8d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -48,8 +48,7 @@ public abstract class AbstractKeyedStateBackend implements KeyedStateBackend, Snapshotable, Collection>, Closeable, - CheckpointListener, - KeyedStateFactory{ + CheckpointListener { /** {@link TypeSerializer} for our key. */ protected final TypeSerializer keySerializer; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index 3326a814609ff..ad75a1f86c0df 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -31,7 +31,7 @@ * * @param The key by which state is keyed. */ -public interface KeyedStateBackend extends InternalKeyContext, Disposable { +public interface KeyedStateBackend extends InternalKeyContext, KeyedStateFactory, Disposable { /** * Sets the current key that is used for partitioned state. @@ -70,7 +70,7 @@ void applyToAllKeys( * * @param namespaceSerializer The serializer used for the namespace type of the state * @param stateDescriptor The identifier for the state. This contains name and can create a default state value. - * + * * @param The type of the namespace. * @param The type of the state. * @@ -84,7 +84,7 @@ S getOrCreateKeyedState( /** * Creates or retrieves a partitioned state backed by this state backend. - * + * * TODO: NOTE: This method does a lot of work caching / retrieving states just to update the namespace. * This method should be removed for the sake of namespaces being lazily fetched from the keyed * state backend, or being set on the state directly. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java index c8944fd2710c4..dd251bd9fc99a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java @@ -21,11 +21,12 @@ import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalKvState; -/** This factory produces concrete state objects in backends. */ +/** This factory produces concrete internal state objects. */ public interface KeyedStateFactory { /** - * Creates and returns a new {@link State}. + * Creates and returns a new {@link InternalKvState}. * * @param namespaceSerializer TypeSerializer for the state namespace. * @param stateDesc The {@code StateDescriptor} that contains the name of the state. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java index fcdb9a1eaf83c..1ad5f26a497aa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java @@ -32,9 +32,10 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.state.KeyedStateFactory; import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nonnull; -import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -42,6 +43,7 @@ /** * This state factory wraps state objects, produced by backends, with TTL logic. */ +@SuppressWarnings("unchecked") public class TtlStateFactory { public static IS createStateAndWrapWithTtlIfEnabled( TypeSerializer namespaceSerializer, @@ -49,13 +51,18 @@ public static IS createStateAndWrapWithTt KeyedStateFactory originalStateFactory, TtlConfig ttlConfig, TtlTimeProvider timeProvider) throws Exception { - return ttlConfig.getTtlUpdateType() == TtlUpdateType.Disabled ? + Preconditions.checkNotNull(namespaceSerializer); + Preconditions.checkNotNull(stateDesc); + Preconditions.checkNotNull(originalStateFactory); + Preconditions.checkNotNull(ttlConfig); + Preconditions.checkNotNull(timeProvider); + return ttlConfig.getTtlUpdateType() == TtlConfig.TtlUpdateType.Disabled ? originalStateFactory.createState(namespaceSerializer, stateDesc) : new TtlStateFactory(originalStateFactory, ttlConfig, timeProvider) .createState(namespaceSerializer, stateDesc); } - private final Map, StateFactory> stateFactories; + private final Map, KeyedStateFactory> stateFactories; private final KeyedStateFactory originalStateFactory; private final TtlConfig ttlConfig; @@ -68,49 +75,40 @@ private TtlStateFactory(KeyedStateFactory originalStateFactory, TtlConfig ttlCon this.stateFactories = createStateFactories(); } - private Map, StateFactory> createStateFactories() { + @SuppressWarnings("deprecation") + private Map, KeyedStateFactory> createStateFactories() { return Stream.of( - Tuple2.of(ValueStateDescriptor.class, (StateFactory) this::createValueState), - Tuple2.of(ListStateDescriptor.class, (StateFactory) this::createListState), - Tuple2.of(MapStateDescriptor.class, (StateFactory) this::createMapState), - Tuple2.of(ReducingStateDescriptor.class, (StateFactory) this::createReducingState), - Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) this::createAggregatingState), - Tuple2.of(FoldingStateDescriptor.class, (StateFactory) this::createFoldingState) + Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) this::createValueState), + Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) this::createListState), + Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) this::createMapState), + Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) this::createReducingState), + Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) this::createAggregatingState), + Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) this::createFoldingState) ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); } - private interface StateFactory { - IS create( - TypeSerializer namespaceSerializer, - StateDescriptor stateDesc) throws Exception; - } - private IS createState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { - StateFactory stateFactory = stateFactories.get(stateDesc.getClass()); + KeyedStateFactory stateFactory = stateFactories.get(stateDesc.getClass()); if (stateFactory == null) { String message = String.format("State %s is not supported by %s", stateDesc.getClass(), TtlStateFactory.class); throw new FlinkRuntimeException(message); } - return stateFactory.create(namespaceSerializer, stateDesc); + return stateFactory.createState(namespaceSerializer, stateDesc); } - @SuppressWarnings("unchecked") private IS createValueState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { - SV defVal = stateDesc.getDefaultValue(); - TtlValue ttlDefVal = defVal == null ? null : new TtlValue<>(defVal, Long.MAX_VALUE); ValueStateDescriptor> ttlDescriptor = new ValueStateDescriptor<>( - stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()), ttlDefVal); + stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer())); return (IS) new TtlValueState<>( originalStateFactory.createState(namespaceSerializer, ttlDescriptor), ttlConfig, timeProvider, stateDesc.getSerializer()); } - @SuppressWarnings("unchecked") private IS createListState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -122,7 +120,6 @@ private IS createListState( ttlConfig, timeProvider, listStateDesc.getSerializer()); } - @SuppressWarnings("unchecked") private IS createMapState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -136,7 +133,6 @@ private IS createMapState( ttlConfig, timeProvider, mapStateDesc.getSerializer()); } - @SuppressWarnings("unchecked") private IS createReducingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -150,7 +146,6 @@ private IS createReducingState( ttlConfig, timeProvider, stateDesc.getSerializer()); } - @SuppressWarnings("unchecked") private IS createAggregatingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -165,7 +160,7 @@ private IS createAggregatingStat ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction); } - @SuppressWarnings("unchecked") + @SuppressWarnings("deprecation") private IS createFoldingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -175,33 +170,41 @@ private IS createFoldingState( FoldingStateDescriptor> ttlDescriptor = new FoldingStateDescriptor<>( stateDesc.getName(), ttlInitAcc, - new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider), + new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider, initAcc), new TtlSerializer<>(stateDesc.getSerializer())); return (IS) new TtlFoldingState<>( originalStateFactory.createState(namespaceSerializer, ttlDescriptor), ttlConfig, timeProvider, stateDesc.getSerializer()); } + /** Serializer for user state value with TTL. */ private static class TtlSerializer extends CompositeSerializer> { + TtlSerializer(TypeSerializer userValueSerializer) { - super(Arrays.asList(userValueSerializer, new LongSerializer())); + super(true, userValueSerializer, LongSerializer.INSTANCE); + } + + @Override + public TtlValue createInstance(@Nonnull Object ... values) { + Preconditions.checkArgument(values.length == 2); + return new TtlValue<>((T) values[0], (long) values[1]); } @Override - @SuppressWarnings("unchecked") - protected TtlValue composeValue(List values) { - return new TtlValue<>((T) values.get(0), (Long) values.get(1)); + protected void setField(@Nonnull TtlValue v, int index, Object fieldValue) { + throw new UnsupportedOperationException("TtlValue is immutable"); } @Override - protected List decomposeValue(TtlValue v) { - return Arrays.asList(v.getUserValue(), v.getExpirationTimestamp()); + protected Object getField(@Nonnull TtlValue v, int index) { + return index == 0 ? v.getUserValue() : v.getLastAccessTimestamp(); } @Override - @SuppressWarnings("unchecked") - protected CompositeSerializer> createSerializerInstance(List typeSerializers) { - return new TtlSerializer<>(typeSerializers.get(0)); + protected CompositeSerializer> createSerializerInstance(TypeSerializer ... originalSerializers) { + Preconditions.checkNotNull(originalSerializers); + Preconditions.checkArgument(originalSerializers.length == 2); + return new TtlSerializer<>((TypeSerializer) originalSerializers[0]); } } } From f8097350c1a604c3ede1ecb64e4595c8620155ae Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Tue, 3 Jul 2018 12:07:21 +0200 Subject: [PATCH 3/6] Precompute immutability flag, statefulness and hashcode in CompositeSerializer constructor. Include TtlStateFactory in TTL tests. --- .../common/typeutils/CompositeSerializer.java | 67 ++++++++++------ .../typeutils/CompositeSerializerTest.java | 4 +- .../runtime/state/ttl/TtlStateFactory.java | 10 ++- .../state/ttl/TtlAggregatingStateTest.java | 43 +++-------- .../state/ttl/TtlFoldingStateTest.java | 33 ++------ .../runtime/state/ttl/TtlListStateTest.java | 55 +++----------- .../state/ttl/TtlMapStatePerElementTest.java | 7 +- .../runtime/state/ttl/TtlMapStateTest.java | 7 +- .../state/ttl/TtlMapStateTestBase.java | 33 ++++++++ .../state/ttl/TtlReducingStateTest.java | 41 +++------- .../runtime/state/ttl/TtlStateTestBase.java | 17 +++++ .../runtime/state/ttl/TtlValueStateTest.java | 25 ++---- .../mock/MockInternalAggregatingState.java | 60 +++++++++++++++ .../ttl/mock/MockInternalFoldingState.java | 55 ++++++++++++++ .../ttl/{ => mock}/MockInternalKvState.java | 3 +- .../state/ttl/mock/MockInternalListState.java | 76 +++++++++++++++++++ .../ttl/{ => mock}/MockInternalMapState.java | 17 ++++- .../{ => mock}/MockInternalMergingState.java | 3 +- .../ttl/mock/MockInternalReducingState.java | 59 ++++++++++++++ .../ttl/mock/MockInternalValueState.java | 46 +++++++++++ .../state/ttl/mock/MockKeyedStateFactory.java | 64 ++++++++++++++++ 21 files changed, 524 insertions(+), 201 deletions(-) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTestBase.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalAggregatingState.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalFoldingState.java rename flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/{ => mock}/MockInternalKvState.java (96%) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalListState.java rename flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/{ => mock}/MockInternalMapState.java (77%) rename flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/{ => mock}/MockInternalMergingState.java (93%) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalReducingState.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalValueState.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateFactory.java diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java index 2cb1e4435e632..49f9ef121b94a 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -29,6 +29,7 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.stream.IntStream; /** * Base class for composite serializers. @@ -41,17 +42,36 @@ public abstract class CompositeSerializer extends TypeSerializer { private static final long serialVersionUID = 1L; + /** Serializers for fields which constitute T. */ protected final TypeSerializer[] fieldSerializers; - final boolean isImmutableTargetType; + + /** Whether T is an immutable type. */ + final boolean immutableTargetType; + + /** Byte length of target object in serialized form. */ private final int length; + /** Whether any field serializer is stateful. */ + private final boolean stateful; + + private final int hashCode; + @SuppressWarnings("unchecked") - protected CompositeSerializer(boolean isImmutableTargetType, TypeSerializer ... fieldSerializers) { + protected CompositeSerializer(boolean immutableTargetType, TypeSerializer ... fieldSerializers) { Preconditions.checkNotNull(fieldSerializers); Preconditions.checkArgument(Arrays.stream(fieldSerializers).allMatch(Objects::nonNull)); - this.isImmutableTargetType = isImmutableTargetType; + this.immutableTargetType = immutableTargetType && + Arrays.stream(fieldSerializers).allMatch(TypeSerializer::isImmutableType); this.fieldSerializers = (TypeSerializer[]) fieldSerializers; this.length = calcLength(); + this.stateful = isStateful(); + this.hashCode = Arrays.hashCode(fieldSerializers); + } + + private boolean isStateful() { + TypeSerializer[] duplicatedSerializers = duplicateFieldSerializers(); + return IntStream.range(0, fieldSerializers.length) + .anyMatch(i -> fieldSerializers[i] != duplicatedSerializers[i]); } /** Create new instance from its fields. */ @@ -68,25 +88,20 @@ protected CompositeSerializer(boolean isImmutableTargetType, TypeSerializer . @Override public CompositeSerializer duplicate() { + return stateful ? createSerializerInstance(duplicateFieldSerializers()) : this; + } + + private TypeSerializer[] duplicateFieldSerializers() { TypeSerializer[] duplicatedSerializers = new TypeSerializer[fieldSerializers.length]; - boolean stateful = false; for (int index = 0; index < fieldSerializers.length; index++) { duplicatedSerializers[index] = fieldSerializers[index].duplicate(); - if (fieldSerializers[index] != duplicatedSerializers[index]) { - stateful = true; - } } - return stateful ? createSerializerInstance(duplicatedSerializers) : this; + return duplicatedSerializers; } @Override public boolean isImmutableType() { - for (TypeSerializer fieldSerializer : fieldSerializers) { - if (!fieldSerializer.isImmutableType()) { - return false; - } - } - return isImmutableTargetType; + return immutableTargetType; } @Override @@ -101,6 +116,9 @@ public T createInstance() { @Override public T copy(T from) { Preconditions.checkNotNull(from); + if (isImmutableType()) { + return from; + } Object[] fields = new Object[fieldSerializers.length]; for (int index = 0; index < fieldSerializers.length; index++) { fields[index] = fieldSerializers[index].copy(getField(from, index)); @@ -112,11 +130,14 @@ public T copy(T from) { public T copy(T from, T reuse) { Preconditions.checkNotNull(from); Preconditions.checkNotNull(reuse); + if (isImmutableType()) { + return from; + } Object[] fields = new Object[fieldSerializers.length]; for (int index = 0; index < fieldSerializers.length; index++) { fields[index] = fieldSerializers[index].copy(getField(from, index), getField(reuse, index)); } - return fromFields(fields, reuse); + return createInstanceWithReuse(fields, reuse); } @Override @@ -162,18 +183,14 @@ public T deserialize(T reuse, DataInputView source) throws IOException { for (int index = 0; index < fieldSerializers.length; index++) { fields[index] = fieldSerializers[index].deserialize(getField(reuse, index), source); } - return fromFields(fields, reuse); + return immutableTargetType ? createInstance(fields) : createInstanceWithReuse(fields, reuse); } - private T fromFields(Object[] fields, T reuse) { - if (isImmutableTargetType) { - return createInstance(fields); - } else { - for (int index = 0; index < fields.length; index++) { - setField(reuse, index, fields[index]); - } - return reuse; + private T createInstanceWithReuse(Object[] fields, T reuse) { + for (int index = 0; index < fields.length; index++) { + setField(reuse, index, fields[index]); } + return reuse; } @Override @@ -187,7 +204,7 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public int hashCode() { - return Arrays.hashCode(fieldSerializers); + return hashCode; } @Override diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java index 054218a5d3ab4..7d8ac20c82a16 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java @@ -170,7 +170,7 @@ public List createInstance(@Nonnull Object... values) { @Override protected void setField(@Nonnull List value, int index, Object fieldValue) { - if (isImmutableTargetType) { + if (immutableTargetType) { throw new UnsupportedOperationException("Type is immutable"); } else { value.set(index, fieldValue); @@ -184,7 +184,7 @@ protected Object getField(@Nonnull List value, int index) { @Override protected CompositeSerializer> createSerializerInstance(TypeSerializer... originalSerializers) { - return new TestListCompositeSerializer(isImmutableTargetType, originalSerializers); + return new TestListCompositeSerializer(immutableTargetType, originalSerializers); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java index 1ad5f26a497aa..f37bc529df78e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java @@ -43,7 +43,6 @@ /** * This state factory wraps state objects, produced by backends, with TTL logic. */ -@SuppressWarnings("unchecked") public class TtlStateFactory { public static IS createStateAndWrapWithTtlIfEnabled( TypeSerializer namespaceSerializer, @@ -99,6 +98,7 @@ private IS createState( return stateFactory.createState(namespaceSerializer, stateDesc); } + @SuppressWarnings("unchecked") private IS createValueState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -109,6 +109,7 @@ private IS createValueState( ttlConfig, timeProvider, stateDesc.getSerializer()); } + @SuppressWarnings("unchecked") private IS createListState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -120,6 +121,7 @@ private IS createListState( ttlConfig, timeProvider, listStateDesc.getSerializer()); } + @SuppressWarnings("unchecked") private IS createMapState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -133,6 +135,7 @@ private IS createMapState( ttlConfig, timeProvider, mapStateDesc.getSerializer()); } + @SuppressWarnings("unchecked") private IS createReducingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -146,6 +149,7 @@ private IS createReducingState( ttlConfig, timeProvider, stateDesc.getSerializer()); } + @SuppressWarnings("unchecked") private IS createAggregatingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -160,7 +164,7 @@ private IS createAggregatingStat ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction); } - @SuppressWarnings("deprecation") + @SuppressWarnings({"deprecation", "unchecked"}) private IS createFoldingState( TypeSerializer namespaceSerializer, StateDescriptor stateDesc) throws Exception { @@ -184,6 +188,7 @@ private static class TtlSerializer extends CompositeSerializer> { super(true, userValueSerializer, LongSerializer.INSTANCE); } + @SuppressWarnings("unchecked") @Override public TtlValue createInstance(@Nonnull Object ... values) { Preconditions.checkArgument(values.length == 2); @@ -200,6 +205,7 @@ protected Object getField(@Nonnull TtlValue v, int index) { return index == 0 ? v.getUserValue() : v.getLastAccessTimestamp(); } + @SuppressWarnings("unchecked") @Override protected CompositeSerializer> createSerializerInstance(TypeSerializer ... originalSerializers) { Preconditions.checkNotNull(originalSerializers); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAggregatingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAggregatingStateTest.java index 477f0576ab03a..5d9c682601648 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAggregatingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAggregatingStateTest.java @@ -19,8 +19,9 @@ package org.apache.flink.runtime.state.ttl; import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.runtime.state.internal.InternalAggregatingState; import java.util.HashSet; import java.util.List; @@ -31,15 +32,6 @@ public class TtlAggregatingStateTest extends TtlMergingStateBase.TtlIntegerMergingStateBase, Integer, String> { private static final long DEFAULT_ACCUMULATOR = 3L; - @Override - TtlAggregatingState createState() { - TtlAggregateFunction ttlAggregateFunction = - new TtlAggregateFunction<>(AGGREGATE, ttlConfig, timeProvider); - return new TtlAggregatingState<>( - new MockInternalTtlAggregatingState<>(ttlAggregateFunction), - ttlConfig, timeProvider, null, ttlAggregateFunction); - } - @Override void initTestValues() { updater = v -> ttlState.add(v); @@ -55,6 +47,13 @@ void initTestValues() { getUpdateExpired = "9"; } + @Override + TtlAggregatingState createState() { + AggregatingStateDescriptor aggregatingStateDes = + new AggregatingStateDescriptor<>("TtlTestAggregatingState", AGGREGATE, LongSerializer.INSTANCE); + return (TtlAggregatingState) wrapMockState(aggregatingStateDes); + } + @Override String getMergeResult( List> unexpiredUpdatesToMerge, @@ -66,30 +65,6 @@ String getMergeResult( namespaces.size() * (int) DEFAULT_ACCUMULATOR); } - private static class MockInternalTtlAggregatingState - extends MockInternalMergingState implements InternalAggregatingState { - private final AggregateFunction aggregateFunction; - - private MockInternalTtlAggregatingState(AggregateFunction aggregateFunction) { - this.aggregateFunction = aggregateFunction; - } - - @Override - public OUT get() { - return aggregateFunction.getResult(getInternal()); - } - - @Override - public void add(IN value) { - updateInternal(aggregateFunction.add(value, getInternal())); - } - - @Override - ACC mergeState(ACC acc, ACC nAcc) { - return aggregateFunction.merge(acc, nAcc); - } - } - private static final AggregateFunction AGGREGATE = new AggregateFunction() { @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlFoldingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlFoldingStateTest.java index 01d5ee1d81dcd..8dac8ca4328ef 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlFoldingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlFoldingStateTest.java @@ -19,17 +19,12 @@ package org.apache.flink.runtime.state.ttl; import org.apache.flink.api.common.functions.FoldFunction; -import org.apache.flink.runtime.state.internal.InternalFoldingState; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; /** Test suite for {@link TtlFoldingState}. */ +@SuppressWarnings("deprecation") public class TtlFoldingStateTest extends TtlStateTestBase, Long, String> { - @Override - TtlFoldingState createState() { - FoldFunction> ttlFoldFunction = new TtlFoldFunction<>(FOLD, ttlConfig, timeProvider, "1"); - return new TtlFoldingState<>( - new MockInternalFoldingState<>(ttlFoldFunction), ttlConfig, timeProvider, null); - } - @Override void initTestValues() { updater = v -> ttlState.add(v); @@ -45,23 +40,11 @@ void initTestValues() { getUpdateExpired = "7"; } - private static class MockInternalFoldingState - extends MockInternalKvState implements InternalFoldingState { - private final FoldFunction foldFunction; - - private MockInternalFoldingState(FoldFunction foldFunction) { - this.foldFunction = foldFunction; - } - - @Override - public ACC get() { - return getInternal(); - } - - @Override - public void add(T value) throws Exception { - updateInternal(foldFunction.fold(get(), value)); - } + @Override + TtlFoldingState createState() { + FoldingStateDescriptor foldingStateDesc = + new FoldingStateDescriptor<>("TtlTestFoldingState", "1", FOLD, StringSerializer.INSTANCE); + return (TtlFoldingState) wrapMockState(foldingStateDesc); } private static final FoldFunction FOLD = (acc, val) -> { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlListStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlListStateTest.java index 5128ff21a4373..893f9aeff062c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlListStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlListStateTest.java @@ -18,8 +18,9 @@ package org.apache.flink.runtime.state.ttl; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.runtime.state.internal.InternalListState; import java.util.ArrayList; import java.util.Arrays; @@ -32,11 +33,6 @@ /** Test suite for {@link TtlListState}. */ public class TtlListStateTest extends TtlMergingStateBase, List, Iterable> { - @Override - TtlListState createState() { - return new TtlListState<>(new MockInternalListState<>(), ttlConfig, timeProvider, null); - } - @Override void initTestValues() { updater = v -> ttlState.addAll(v); @@ -54,6 +50,13 @@ void initTestValues() { getUpdateExpired = updateExpired; } + @Override + TtlListState createState() { + ListStateDescriptor listStateDesc = + new ListStateDescriptor<>("TtlTestListState", IntSerializer.INSTANCE); + return (TtlListState) wrapMockState(listStateDesc); + } + @Override List generateRandomUpdate() { int size = RANDOM.nextInt(5); @@ -69,44 +72,4 @@ Iterable getMergeResult( return result; } - private static class MockInternalListState - extends MockInternalMergingState, Iterable> - implements InternalListState { - - MockInternalListState() { - super(ArrayList::new); - } - - @Override - public void update(List elements) { - updateInternal(elements); - } - - @Override - public void addAll(List elements) { - getInternal().addAll(elements); - } - - @Override - List mergeState(List acc, List nAcc) { - acc = new ArrayList<>(acc); - acc.addAll(nAcc); - return acc; - } - - @Override - public Iterable get() { - return getInternal(); - } - - @Override - public void add(T element) { - getInternal().add(element); - } - - @Override - public void clear() { - getInternal().clear(); - } - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStatePerElementTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStatePerElementTest.java index ac9b03858abd3..d6949e7c1c15a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStatePerElementTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStatePerElementTest.java @@ -19,17 +19,12 @@ package org.apache.flink.runtime.state.ttl; /** Test suite for per element methods of {@link TtlMapState}. */ -public class TtlMapStatePerElementTest extends TtlStateTestBase, String, String> { +public class TtlMapStatePerElementTest extends TtlMapStateTestBase { private static final int TEST_KEY = 1; private static final String TEST_VAL1 = "test value1"; private static final String TEST_VAL2 = "test value2"; private static final String TEST_VAL3 = "test value3"; - @Override - TtlMapState createState() { - return new TtlMapState<>(new MockInternalMapState<>(), ttlConfig, timeProvider, null); - } - @Override void initTestValues() { updater = v -> ttlState.put(TEST_KEY, v); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTest.java index 535962b759c5c..bac2f41ea55ec 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTest.java @@ -29,12 +29,7 @@ /** Test suite for collection methods of {@link TtlMapState}. */ public class TtlMapStateTest extends - TtlStateTestBase, Map, Set>> { - - @Override - TtlMapState createState() { - return new TtlMapState<>(new MockInternalMapState<>(), ttlConfig, timeProvider, null); - } + TtlMapStateTestBase, Set>> { @Override void initTestValues() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTestBase.java new file mode 100644 index 0000000000000..dab319470f45b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlMapStateTestBase.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl; + +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; + +abstract class TtlMapStateTestBase + extends TtlStateTestBase, UV, GV> { + @Override + TtlMapState createState() { + MapStateDescriptor mapStateDesc = + new MapStateDescriptor<>("TtlTestMapState", IntSerializer.INSTANCE, StringSerializer.INSTANCE); + return (TtlMapState) wrapMockState(mapStateDesc); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlReducingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlReducingStateTest.java index 44043a1faa7e6..bc5f67fc69997 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlReducingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlReducingStateTest.java @@ -19,21 +19,15 @@ package org.apache.flink.runtime.state.ttl; import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.runtime.state.internal.InternalReducingState; import java.util.List; /** Test suite for {@link TtlReducingState}. */ public class TtlReducingStateTest extends TtlMergingStateBase.TtlIntegerMergingStateBase, Integer, Integer> { - @Override - TtlReducingState createState() { - ReduceFunction> ttlReduceFunction = new TtlReduceFunction<>(REDUCE, ttlConfig, timeProvider); - return new TtlReducingState<>( - new MockInternalReducingState<>(ttlReduceFunction), ttlConfig, timeProvider, null); - } - @Override void initTestValues() { updater = v -> ttlState.add(v); @@ -49,6 +43,13 @@ void initTestValues() { getUpdateExpired = 6; } + @Override + TtlReducingState createState() { + ReducingStateDescriptor aggregatingStateDes = + new ReducingStateDescriptor<>("TtlTestReducingState", REDUCE, IntSerializer.INSTANCE); + return (TtlReducingState) wrapMockState(aggregatingStateDes); + } + @Override Integer getMergeResult( List> unexpiredUpdatesToMerge, @@ -56,30 +57,6 @@ Integer getMergeResult( return getIntegerMergeResult(unexpiredUpdatesToMerge, finalUpdatesToMerge); } - private static class MockInternalReducingState - extends MockInternalMergingState implements InternalReducingState { - private final ReduceFunction reduceFunction; - - private MockInternalReducingState(ReduceFunction reduceFunction) { - this.reduceFunction = reduceFunction; - } - - @Override - public T get() { - return getInternal(); - } - - @Override - public void add(T value) throws Exception { - updateInternal(reduceFunction.reduce(get(), value)); - } - - @Override - T mergeState(T t, T nAcc) throws Exception { - return reduceFunction.reduce(t, nAcc); - } - } - private static final ReduceFunction REDUCE = (v1, v2) -> { if (v1 == null && v2 == null) { return null; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java index a1bd72a578c52..13eac3258f534 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java @@ -18,8 +18,14 @@ package org.apache.flink.runtime.state.ttl; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.time.Time; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.runtime.state.KeyedStateFactory; import org.apache.flink.runtime.state.internal.InternalKvState; +import org.apache.flink.runtime.state.ttl.mock.MockKeyedStateFactory; +import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.util.function.SupplierWithException; import org.apache.flink.util.function.ThrowingConsumer; @@ -29,6 +35,7 @@ abstract class TtlStateTestBase, UV, GV> { private static final long TTL = 100; + private static final KeyedStateFactory MOCK_ORIGINAL_STATE_FACTORY = new MockKeyedStateFactory(); S ttlState; MockTimeProvider timeProvider; @@ -69,6 +76,16 @@ private void initTest(TtlConfig.TtlUpdateType updateType, TtlConfig.TtlStateVisi abstract S createState(); + IS wrapMockState(StateDescriptor stateDesc) { + try { + return TtlStateFactory.createStateAndWrapWithTtlIfEnabled( + StringSerializer.INSTANCE, stateDesc, + MOCK_ORIGINAL_STATE_FACTORY, ttlConfig, timeProvider); + } catch (Exception e) { + throw new FlinkRuntimeException("Unexpected exception wrapping mock state", e); + } + } + abstract void initTestValues(); @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlValueStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlValueStateTest.java index f98c981e20aa9..8d9a4b4140207 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlValueStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlValueStateTest.java @@ -18,7 +18,8 @@ package org.apache.flink.runtime.state.ttl; -import org.apache.flink.runtime.state.internal.InternalValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; /** Test suite for {@link TtlValueState}. */ public class TtlValueStateTest extends TtlStateTestBase, String, String> { @@ -26,11 +27,6 @@ public class TtlValueStateTest extends TtlStateTestBase createState() { - return new TtlValueState<>(new MockInternalValueState<>(), ttlConfig, timeProvider, null); - } - @Override void initTestValues() { updater = v -> ttlState.update(v); @@ -46,17 +42,10 @@ void initTestValues() { getUpdateExpired = TEST_VAL3; } - private static class MockInternalValueState - extends MockInternalKvState implements InternalValueState { - - @Override - public T value() { - return getInternal(); - } - - @Override - public void update(T value) { - updateInternal(value); - } + @Override + TtlValueState createState() { + ValueStateDescriptor valueStateDesc = + new ValueStateDescriptor<>("TtlValueTestState", StringSerializer.INSTANCE); + return (TtlValueState) wrapMockState(valueStateDesc); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalAggregatingState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalAggregatingState.java new file mode 100644 index 0000000000000..ce8bc836777f4 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalAggregatingState.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalAggregatingState; + +/** In memory mock internal aggregating state. */ +class MockInternalAggregatingState + extends MockInternalMergingState implements InternalAggregatingState { + private final AggregateFunction aggregateFunction; + + private MockInternalAggregatingState(AggregateFunction aggregateFunction) { + this.aggregateFunction = aggregateFunction; + } + + @Override + public OUT get() { + return aggregateFunction.getResult(getInternal()); + } + + @Override + public void add(IN value) { + updateInternal(aggregateFunction.add(value, getInternal())); + } + + @Override + ACC mergeState(ACC acc, ACC nAcc) { + return aggregateFunction.merge(acc, nAcc); + } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + AggregatingStateDescriptor aggregatingStateDesc = + (AggregatingStateDescriptor) stateDesc; + return (IS) new MockInternalAggregatingState<>(aggregatingStateDesc.getAggregateFunction()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalFoldingState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalFoldingState.java new file mode 100644 index 0000000000000..259ce331fba9d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalFoldingState.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.functions.FoldFunction; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalFoldingState; + +/** In memory mock internal folding state. */ +@SuppressWarnings("deprecation") +class MockInternalFoldingState + extends MockInternalKvState implements InternalFoldingState { + private final FoldFunction foldFunction; + + private MockInternalFoldingState(FoldFunction foldFunction) { + this.foldFunction = foldFunction; + } + + @Override + public ACC get() { + return getInternal(); + } + + @Override + public void add(T value) throws Exception { + updateInternal(foldFunction.fold(get(), value)); + } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + FoldingStateDescriptor foldingStateDesc = (FoldingStateDescriptor) stateDesc; + return (IS) new MockInternalFoldingState<>(foldingStateDesc.getFoldFunction()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalKvState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalKvState.java similarity index 96% rename from flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalKvState.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalKvState.java index 878d88814308b..439ca7f2747fb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalKvState.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalKvState.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.state.ttl; +package org.apache.flink.runtime.state.ttl.mock; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.state.internal.InternalKvState; @@ -25,6 +25,7 @@ import java.util.Map; import java.util.function.Supplier; +/** In memory mock internal state base class. */ class MockInternalKvState implements InternalKvState { private Map namespacedValues = new HashMap<>(); private T defaultNamespaceValue; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalListState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalListState.java new file mode 100644 index 0000000000000..132c7e94cfdee --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalListState.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalListState; + +import java.util.ArrayList; +import java.util.List; + +/** In memory mock internal list state. */ +class MockInternalListState + extends MockInternalMergingState, Iterable> + implements InternalListState { + + private MockInternalListState() { + super(ArrayList::new); + } + + @Override + public void update(List elements) { + updateInternal(elements); + } + + @Override + public void addAll(List elements) { + getInternal().addAll(elements); + } + + @Override + List mergeState(List acc, List nAcc) { + acc = new ArrayList<>(acc); + acc.addAll(nAcc); + return acc; + } + + @Override + public Iterable get() { + return getInternal(); + } + + @Override + public void add(T element) { + getInternal().add(element); + } + + @Override + public void clear() { + getInternal().clear(); + } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + return (IS) new MockInternalListState<>(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMapState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMapState.java similarity index 77% rename from flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMapState.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMapState.java index c548017c3da45..386ef9772e530 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMapState.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMapState.java @@ -16,19 +16,23 @@ * limitations under the License. */ -package org.apache.flink.runtime.state.ttl; +package org.apache.flink.runtime.state.ttl.mock; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.state.internal.InternalMapState; import java.util.HashMap; import java.util.Iterator; import java.util.Map; -class MockInternalMapState +/** In memory mock internal map state. */ +public class MockInternalMapState extends MockInternalKvState> implements InternalMapState { - MockInternalMapState() { + private MockInternalMapState() { super(HashMap::new); } @@ -85,4 +89,11 @@ public Iterable values() { public Iterator> iterator() { return entries().iterator(); } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + return (IS) new MockInternalMapState<>(); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMergingState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMergingState.java similarity index 93% rename from flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMergingState.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMergingState.java index 582faf3252a64..c25eb33c86ff3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/MockInternalMergingState.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalMergingState.java @@ -16,13 +16,14 @@ * limitations under the License. */ -package org.apache.flink.runtime.state.ttl; +package org.apache.flink.runtime.state.ttl.mock; import org.apache.flink.runtime.state.internal.InternalMergingState; import java.util.Collection; import java.util.function.Supplier; +/** In memory mock internal merging state base class. */ abstract class MockInternalMergingState extends MockInternalKvState implements InternalMergingState { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalReducingState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalReducingState.java new file mode 100644 index 0000000000000..09efea65ffd10 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalReducingState.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalReducingState; + +/** In memory mock internal reducing state. */ +class MockInternalReducingState + extends MockInternalMergingState implements InternalReducingState { + private final ReduceFunction reduceFunction; + + private MockInternalReducingState(ReduceFunction reduceFunction) { + this.reduceFunction = reduceFunction; + } + + @Override + public T get() { + return getInternal(); + } + + @Override + public void add(T value) throws Exception { + updateInternal(reduceFunction.reduce(get(), value)); + } + + @Override + T mergeState(T t, T nAcc) throws Exception { + return reduceFunction.reduce(t, nAcc); + } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + ReducingStateDescriptor reducingStateDesc = (ReducingStateDescriptor) stateDesc; + return (IS) new MockInternalReducingState<>(reducingStateDesc.getReduceFunction()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalValueState.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalValueState.java new file mode 100644 index 0000000000000..9dceafaef2c7d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockInternalValueState.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.internal.InternalValueState; + +/** In memory mock internal value state. */ +class MockInternalValueState + extends MockInternalKvState implements InternalValueState { + + @Override + public T value() { + return getInternal(); + } + + @Override + public void update(T value) { + updateInternal(value); + } + + @SuppressWarnings({"unchecked", "unused"}) + static IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) { + return (IS) new MockInternalValueState<>(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateFactory.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateFactory.java new file mode 100644 index 0000000000000..d8433529e0715 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.ttl.mock; + +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyedStateFactory; +import org.apache.flink.runtime.state.ttl.TtlStateFactory; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** State factory which produces in memory mock state objects. */ +public class MockKeyedStateFactory implements KeyedStateFactory { + @SuppressWarnings("deprecation") + private static final Map, KeyedStateFactory> STATE_FACTORIES = + Stream.of( + Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) MockInternalValueState::createState), + Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) MockInternalListState::createState), + Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) MockInternalMapState::createState), + Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) MockInternalReducingState::createState), + Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) MockInternalAggregatingState::createState), + Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) MockInternalFoldingState::createState) + ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); + + @Override + public IS createState( + TypeSerializer namespaceSerializer, + StateDescriptor stateDesc) throws Exception { + KeyedStateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass()); + if (stateFactory == null) { + String message = String.format("State %s is not supported by %s", + stateDesc.getClass(), TtlStateFactory.class); + throw new FlinkRuntimeException(message); + } + return stateFactory.createState(namespaceSerializer, stateDesc); + } +} From 855016211e9e93c173103cbd1d616ad220e904ea Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Tue, 3 Jul 2018 17:24:25 +0200 Subject: [PATCH 4/6] Precompute CompositeSerializer parameters in one loop --- .../common/typeutils/CompositeSerializer.java | 124 ++++++++++++------ .../typeutils/CompositeSerializerTest.java | 11 +- .../runtime/state/ttl/TtlStateFactory.java | 10 +- 3 files changed, 97 insertions(+), 48 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java index 49f9ef121b94a..5844c7682f9bc 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -26,10 +26,9 @@ import javax.annotation.Nonnull; import java.io.IOException; +import java.io.Serializable; import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.IntStream; /** * Base class for composite serializers. @@ -45,33 +44,21 @@ public abstract class CompositeSerializer extends TypeSerializer { /** Serializers for fields which constitute T. */ protected final TypeSerializer[] fieldSerializers; - /** Whether T is an immutable type. */ - final boolean immutableTargetType; - - /** Byte length of target object in serialized form. */ - private final int length; - - /** Whether any field serializer is stateful. */ - private final boolean stateful; - - private final int hashCode; + final PrecomputedParameters precomputed; + /** Can be used for user facing constructor. */ @SuppressWarnings("unchecked") protected CompositeSerializer(boolean immutableTargetType, TypeSerializer ... fieldSerializers) { - Preconditions.checkNotNull(fieldSerializers); - Preconditions.checkArgument(Arrays.stream(fieldSerializers).allMatch(Objects::nonNull)); - this.immutableTargetType = immutableTargetType && - Arrays.stream(fieldSerializers).allMatch(TypeSerializer::isImmutableType); - this.fieldSerializers = (TypeSerializer[]) fieldSerializers; - this.length = calcLength(); - this.stateful = isStateful(); - this.hashCode = Arrays.hashCode(fieldSerializers); + this( + new PrecomputedParameters(immutableTargetType, (TypeSerializer[]) fieldSerializers), + fieldSerializers); } - private boolean isStateful() { - TypeSerializer[] duplicatedSerializers = duplicateFieldSerializers(); - return IntStream.range(0, fieldSerializers.length) - .anyMatch(i -> fieldSerializers[i] != duplicatedSerializers[i]); + /** Can be used in createSerializerInstance for internal operations. */ + @SuppressWarnings("unchecked") + protected CompositeSerializer(PrecomputedParameters precomputed, TypeSerializer ... fieldSerializers) { + this.fieldSerializers = (TypeSerializer[]) fieldSerializers; + this.precomputed = new PrecomputedParameters(precomputed, this.fieldSerializers); } /** Create new instance from its fields. */ @@ -84,24 +71,28 @@ private boolean isStateful() { protected abstract Object getField(@Nonnull T value, int index); /** Factory for concrete serializer. */ - protected abstract CompositeSerializer createSerializerInstance(TypeSerializer ... originalSerializers); + protected abstract CompositeSerializer createSerializerInstance( + PrecomputedParameters precomputed, + TypeSerializer ... originalSerializers); @Override public CompositeSerializer duplicate() { - return stateful ? createSerializerInstance(duplicateFieldSerializers()) : this; + return precomputed.stateful ? + createSerializerInstance(precomputed, duplicateFieldSerializers(fieldSerializers)) : this; } - private TypeSerializer[] duplicateFieldSerializers() { + private static TypeSerializer[] duplicateFieldSerializers(TypeSerializer[] fieldSerializers) { TypeSerializer[] duplicatedSerializers = new TypeSerializer[fieldSerializers.length]; for (int index = 0; index < fieldSerializers.length; index++) { duplicatedSerializers[index] = fieldSerializers[index].duplicate(); + assert duplicatedSerializers[index] != null; } return duplicatedSerializers; } @Override public boolean isImmutableType() { - return immutableTargetType; + return precomputed.immutable; } @Override @@ -142,18 +133,7 @@ public T copy(T from, T reuse) { @Override public int getLength() { - return length; - } - - private int calcLength() { - int totalLength = 0; - for (TypeSerializer fieldSerializer : fieldSerializers) { - if (fieldSerializer.getLength() < 0) { - return -1; - } - totalLength += fieldSerializer.getLength(); - } - return totalLength; + return precomputed.length; } @Override @@ -183,7 +163,7 @@ public T deserialize(T reuse, DataInputView source) throws IOException { for (int index = 0; index < fieldSerializers.length; index++) { fields[index] = fieldSerializers[index].deserialize(getField(reuse, index), source); } - return immutableTargetType ? createInstance(fields) : createInstanceWithReuse(fields, reuse); + return precomputed.immutable ? createInstance(fields) : createInstanceWithReuse(fields, reuse); } private T createInstanceWithReuse(Object[] fields, T reuse) { @@ -204,7 +184,7 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public int hashCode() { - return hashCode; + return precomputed.hashCode; } @Override @@ -261,8 +241,10 @@ private CompatibilityResult ensureFieldCompatibility( } } } + PrecomputedParameters precomputed = + new PrecomputedParameters(this.precomputed.immutableTargetType, convertSerializers); return requiresMigration ? - CompatibilityResult.requiresMigration(createSerializerInstance(convertSerializers)) : + CompatibilityResult.requiresMigration(createSerializerInstance(precomputed, convertSerializers)) : CompatibilityResult.compatible(); } @@ -272,4 +254,60 @@ private CompatibilityResult resolveFieldCompatibility( previousSerializersAndConfigs.get(index).f0, UnloadableDummyTypeSerializer.class, previousSerializersAndConfigs.get(index).f1, fieldSerializers[index]); } + + /** This class holds composite serializer parameters which can be precomputed in advanced for better performance. */ + protected static class PrecomputedParameters implements Serializable { + /** Whether target type is immutable. */ + final boolean immutableTargetType; + + /** Whether target type and its fields are immutable. */ + final boolean immutable; + + /** Byte length of target object in serialized form. */ + private final int length; + + /** Whether any field serializer is stateful. */ + final boolean stateful; + + final int hashCode; + + PrecomputedParameters( + boolean immutableTargetType, + TypeSerializer[] fieldSerializers) { + Preconditions.checkNotNull(fieldSerializers); + int totalLength = 0; + boolean fieldsImmutable = true; + boolean stateful = false; + int hashCode = 1; + for (TypeSerializer fieldSerializer : fieldSerializers) { + Preconditions.checkNotNull(fieldSerializer); + if (fieldSerializer != fieldSerializer.duplicate()) { + stateful = true; + } + if (!fieldSerializer.isImmutableType()) { + fieldsImmutable = false; + } + if (fieldSerializer.getLength() < 0) { + totalLength = -1; + } + totalLength = totalLength >= 0 ? totalLength + fieldSerializer.getLength() : totalLength; + hashCode = 31 * hashCode + (fieldSerializer.hashCode()); + } + + this.immutableTargetType = immutableTargetType; + this.immutable = immutableTargetType && fieldsImmutable; + this.length = totalLength; + this.stateful = stateful; + this.hashCode = hashCode; + } + + /** This constructor recomputes only hash code. */ + PrecomputedParameters(PrecomputedParameters other, TypeSerializer[] fieldSerializers) { + this.immutableTargetType = other.immutableTargetType; + this.immutable = other.immutable; + this.length = other.length; + this.stateful = other.stateful; + this.hashCode = Arrays.hashCode(fieldSerializers); + } + } } diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java index 7d8ac20c82a16..fc5c241c20566 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/CompositeSerializerTest.java @@ -163,6 +163,10 @@ private static class TestListCompositeSerializer extends CompositeSerializer... fieldSerializers) { + super(precomputed, fieldSerializers); + } + @Override public List createInstance(@Nonnull Object... values) { return Arrays.asList(values); @@ -170,7 +174,7 @@ public List createInstance(@Nonnull Object... values) { @Override protected void setField(@Nonnull List value, int index, Object fieldValue) { - if (immutableTargetType) { + if (precomputed.immutable) { throw new UnsupportedOperationException("Type is immutable"); } else { value.set(index, fieldValue); @@ -183,8 +187,9 @@ protected Object getField(@Nonnull List value, int index) { } @Override - protected CompositeSerializer> createSerializerInstance(TypeSerializer... originalSerializers) { - return new TestListCompositeSerializer(immutableTargetType, originalSerializers); + protected CompositeSerializer> createSerializerInstance( + PrecomputedParameters precomputed, TypeSerializer... originalSerializers) { + return new TestListCompositeSerializer(precomputed, originalSerializers); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java index f37bc529df78e..2f7593a93eeb2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java @@ -188,6 +188,10 @@ private static class TtlSerializer extends CompositeSerializer> { super(true, userValueSerializer, LongSerializer.INSTANCE); } + TtlSerializer(PrecomputedParameters precomputed, TypeSerializer ... fieldSerializers) { + super(precomputed, fieldSerializers); + } + @SuppressWarnings("unchecked") @Override public TtlValue createInstance(@Nonnull Object ... values) { @@ -207,10 +211,12 @@ protected Object getField(@Nonnull TtlValue v, int index) { @SuppressWarnings("unchecked") @Override - protected CompositeSerializer> createSerializerInstance(TypeSerializer ... originalSerializers) { + protected CompositeSerializer> createSerializerInstance( + PrecomputedParameters precomputed, + TypeSerializer ... originalSerializers) { Preconditions.checkNotNull(originalSerializers); Preconditions.checkArgument(originalSerializers.length == 2); - return new TtlSerializer<>((TypeSerializer) originalSerializers[0]); + return new TtlSerializer<>(precomputed, (TypeSerializer) originalSerializers[0]); } } } From e39cd95b4bb3552c0e4104351ad75b138e65b7e5 Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Wed, 4 Jul 2018 08:42:43 +0200 Subject: [PATCH 5/6] do not precompute hash code in CompositeSerializer --- .../common/typeutils/CompositeSerializer.java | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java index 5844c7682f9bc..66e5bdd1bbedd 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -58,7 +58,7 @@ protected CompositeSerializer(boolean immutableTargetType, TypeSerializer ... @SuppressWarnings("unchecked") protected CompositeSerializer(PrecomputedParameters precomputed, TypeSerializer ... fieldSerializers) { this.fieldSerializers = (TypeSerializer[]) fieldSerializers; - this.precomputed = new PrecomputedParameters(precomputed, this.fieldSerializers); + this.precomputed = precomputed; } /** Create new instance from its fields. */ @@ -184,7 +184,7 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public int hashCode() { - return precomputed.hashCode; + return Arrays.hashCode(fieldSerializers); } @Override @@ -257,6 +257,8 @@ private CompatibilityResult resolveFieldCompatibility( /** This class holds composite serializer parameters which can be precomputed in advanced for better performance. */ protected static class PrecomputedParameters implements Serializable { + private static final long serialVersionUID = 1L; + /** Whether target type is immutable. */ final boolean immutableTargetType; @@ -269,8 +271,6 @@ protected static class PrecomputedParameters implements Serializable { /** Whether any field serializer is stateful. */ final boolean stateful; - final int hashCode; - PrecomputedParameters( boolean immutableTargetType, TypeSerializer[] fieldSerializers) { @@ -278,7 +278,6 @@ protected static class PrecomputedParameters implements Serializable { int totalLength = 0; boolean fieldsImmutable = true; boolean stateful = false; - int hashCode = 1; for (TypeSerializer fieldSerializer : fieldSerializers) { Preconditions.checkNotNull(fieldSerializer); if (fieldSerializer != fieldSerializer.duplicate()) { @@ -291,23 +290,12 @@ protected static class PrecomputedParameters implements Serializable { totalLength = -1; } totalLength = totalLength >= 0 ? totalLength + fieldSerializer.getLength() : totalLength; - hashCode = 31 * hashCode + (fieldSerializer.hashCode()); } this.immutableTargetType = immutableTargetType; this.immutable = immutableTargetType && fieldsImmutable; this.length = totalLength; this.stateful = stateful; - this.hashCode = hashCode; - } - - /** This constructor recomputes only hash code. */ - PrecomputedParameters(PrecomputedParameters other, TypeSerializer[] fieldSerializers) { - this.immutableTargetType = other.immutableTargetType; - this.immutable = other.immutable; - this.length = other.length; - this.stateful = other.stateful; - this.hashCode = Arrays.hashCode(fieldSerializers); } } } From 9e7429ac72a34491e9fd4b4498b780bd324469d5 Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Wed, 4 Jul 2018 08:56:09 +0200 Subject: [PATCH 6/6] add target type immutability computation to hashCode and equals of CompositeSerializer --- .../flink/api/common/typeutils/CompositeSerializer.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java index 66e5bdd1bbedd..b603626e83f60 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeSerializer.java @@ -184,14 +184,16 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public int hashCode() { - return Arrays.hashCode(fieldSerializers); + return 31 * Boolean.hashCode(precomputed.immutableTargetType) + Arrays.hashCode(fieldSerializers); } @Override public boolean equals(Object obj) { if (obj instanceof CompositeSerializer) { CompositeSerializer other = (CompositeSerializer) obj; - return other.canEqual(this) && Arrays.equals(fieldSerializers, other.fieldSerializers); + return other.canEqual(this) && + Arrays.equals(fieldSerializers, other.fieldSerializers) && + precomputed.immutableTargetType == other.precomputed.immutableTargetType; } else { return false; }