Skip to content

Commit

Permalink
[FLINK-7835][cep] Fix duplicate() in NFASerializer.
Browse files Browse the repository at this point in the history
  • Loading branch information
kl0u committed Oct 13, 2017
1 parent 57333c6 commit ff9cefb
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 182 deletions.
Expand Up @@ -29,6 +29,7 @@
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;

/**
* Type serializer which keeps track of the serialized objects so that each object is only
Expand All @@ -53,7 +54,7 @@ public final class NonDuplicatingTypeSerializer<T> extends TypeSerializer<T> {
private transient IdentityHashMap<T, Integer> identityMap;

// here we store the already deserialized objects
private transient ArrayList<T> elementList;
private transient List<T> elementList;

public NonDuplicatingTypeSerializer(final TypeSerializer<T> typeSerializer) {
this.typeSerializer = typeSerializer;
Expand Down Expand Up @@ -82,7 +83,7 @@ public boolean isImmutableType() {

@Override
public TypeSerializer<T> duplicate() {
return new NonDuplicatingTypeSerializer<>(typeSerializer);
return new NonDuplicatingTypeSerializer<>(typeSerializer.duplicate());
}

@Override
Expand Down
Expand Up @@ -28,9 +28,7 @@
import org.apache.flink.api.common.typeutils.base.EnumSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream;
import org.apache.flink.cep.NonDuplicatingTypeSerializer;
import org.apache.flink.cep.nfa.compiler.NFACompiler;
import org.apache.flink.cep.nfa.compiler.NFAStateNameHandler;
Expand All @@ -48,7 +46,6 @@
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OptionalDataException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -489,7 +486,6 @@ private boolean isSelfIgnore(final StateTransition<T> edge) {
}
}


/**
* Computes the next computation states based on the given computation state, the current event,
* its timestamp and the internal state machine. The algorithm is:
Expand Down Expand Up @@ -793,53 +789,6 @@ Map<String, List<T>> extractCurrentMatches(final ComputationState<T> computation
return result;
}

////////////////////// Fault-Tolerance //////////////////////

private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
ois.defaultReadObject();

int numberComputationStates = ois.readInt();

computationStates = new LinkedList<>();

final List<ComputationState<T>> readComputationStates = new ArrayList<>(numberComputationStates);

for (int i = 0; i < numberComputationStates; i++) {
ComputationState<T> computationState = readComputationState(ois);
readComputationStates.add(computationState);
}

this.computationStates.addAll(readComputationStates);
nonDuplicatingTypeSerializer.clearReferences();
}

@SuppressWarnings("unchecked")
private ComputationState<T> readComputationState(ObjectInputStream ois) throws IOException, ClassNotFoundException {
final State<T> state = (State<T>) ois.readObject();
State<T> previousState;
try {
previousState = (State<T>) ois.readObject();
} catch (OptionalDataException e) {
previousState = null;
}

final long timestamp = ois.readLong();
final DeweyNumber version = (DeweyNumber) ois.readObject();
final long startTimestamp = ois.readLong();

final boolean hasEvent = ois.readBoolean();
final T event;

if (hasEvent) {
DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(ois);
event = nonDuplicatingTypeSerializer.deserialize(input);
} else {
event = null;
}

return ComputationState.createState(this, state, previousState, event, 0, timestamp, version, startTimestamp);
}

////////////////////// New Serialization //////////////////////

/**
Expand Down Expand Up @@ -893,8 +842,8 @@ public boolean isImmutableType() {
}

@Override
public TypeSerializer<NFA<T>> duplicate() {
return this;
public NFASerializer<T> duplicate() {
return new NFASerializer<>(eventSerializer.duplicate());
}

@Override
Expand All @@ -906,21 +855,13 @@ public NFA<T> createInstance() {
public NFA<T> copy(NFA<T> from) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);

serialize(from, new DataOutputViewStreamWrapper(oos));

oos.close();
serialize(from, new DataOutputViewStreamWrapper(baos));
baos.close();

byte[] data = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais);

@SuppressWarnings("unchecked")
NFA<T> copy = deserialize(new DataInputViewStreamWrapper(ois));
ois.close();
NFA<T> copy = deserialize(new DataInputViewStreamWrapper(bais));
bais.close();
return copy;
} catch (IOException e) {
Expand Down Expand Up @@ -1236,91 +1177,4 @@ private IterativeCondition<T> deserializeCondition(DataInputView in) throws IOEx
return null;
}
}

////////////////// Old Serialization //////////////////////

/**
* A {@link TypeSerializer} for {@link NFA} that uses Java Serialization.
*/
public static class Serializer<T> extends TypeSerializerSingleton<NFA<T>> {

private static final long serialVersionUID = 1L;

@Override
public boolean isImmutableType() {
return false;
}

@Override
public NFA<T> createInstance() {
return null;
}

@Override
public NFA<T> copy(NFA<T> from) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);

oos.writeObject(from);

oos.close();
baos.close();

byte[] data = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais);

@SuppressWarnings("unchecked")
NFA<T> copy = (NFA<T>) ois.readObject();
ois.close();
bais.close();
return copy;
} catch (IOException | ClassNotFoundException e) {
throw new RuntimeException("Could not copy NFA.", e);
}
}

@Override
public NFA<T> copy(NFA<T> from, NFA<T> reuse) {
return copy(from);
}

@Override
public int getLength() {
return 0;
}

@Override
public void serialize(NFA<T> record, DataOutputView target) throws IOException {
throw new UnsupportedOperationException("This is the deprecated serialization strategy.");
}

@Override
public NFA<T> deserialize(DataInputView source) throws IOException {
try (ObjectInputStream ois = new ObjectInputStream(new DataInputViewStream(source))) {
return (NFA<T>) ois.readObject();
} catch (ClassNotFoundException e) {
throw new RuntimeException("Could not deserialize NFA.", e);
}
}

@Override
public NFA<T> deserialize(NFA<T> reuse, DataInputView source) throws IOException {
return deserialize(source);
}

@Override
public void copy(DataInputView source, DataOutputView target) throws IOException {
int size = source.readInt();
target.writeInt(size);
target.write(source, size);
}

@Override
public boolean canEqual(Object obj) {
return obj instanceof Serializer;
}
}
}
Expand Up @@ -38,8 +38,6 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -829,40 +827,44 @@ public SharedBufferSerializer(
this.versionSerializer = versionSerializer;
}

public TypeSerializer<DeweyNumber> getVersionSerializer() {
return versionSerializer;
}

public TypeSerializer<K> getKeySerializer() {
return keySerializer;
}

public TypeSerializer<V> getValueSerializer() {
return valueSerializer;
}

@Override
public boolean isImmutableType() {
return false;
}

@Override
public TypeSerializer<SharedBuffer<K, V>> duplicate() {
return new SharedBufferSerializer<>(keySerializer, valueSerializer);
public SharedBufferSerializer<K, V> duplicate() {
return new SharedBufferSerializer<>(keySerializer.duplicate(), valueSerializer.duplicate());
}

@Override
public SharedBuffer<K, V> createInstance() {
return new SharedBuffer<>(new NonDuplicatingTypeSerializer<V>(valueSerializer));
return new SharedBuffer<>(new NonDuplicatingTypeSerializer<>(valueSerializer.duplicate()));
}

@Override
public SharedBuffer<K, V> copy(SharedBuffer from) {
public SharedBuffer<K, V> copy(SharedBuffer<K, V> from) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);

serialize(from, new DataOutputViewStreamWrapper(oos));

oos.close();
serialize(from, new DataOutputViewStreamWrapper(baos));
baos.close();

byte[] data = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais);

@SuppressWarnings("unchecked")
SharedBuffer<K, V> copy = deserialize(new DataInputViewStreamWrapper(ois));
ois.close();
SharedBuffer<K, V> copy = deserialize(new DataInputViewStreamWrapper(bais));
bais.close();

return copy;
Expand All @@ -872,7 +874,7 @@ public SharedBuffer<K, V> copy(SharedBuffer from) {
}

@Override
public SharedBuffer<K, V> copy(SharedBuffer from, SharedBuffer reuse) {
public SharedBuffer<K, V> copy(SharedBuffer<K, V> from, SharedBuffer<K, V> reuse) {
return copy(from);
}

Expand All @@ -882,7 +884,7 @@ public int getLength() {
}

@Override
public void serialize(SharedBuffer record, DataOutputView target) throws IOException {
public void serialize(SharedBuffer<K, V> record, DataOutputView target) throws IOException {
Map<K, SharedBufferPage<K, V>> pages = record.pages;
Map<SharedBufferEntry<K, V>, Integer> entryIDs = new HashMap<>();

Expand Down Expand Up @@ -955,7 +957,7 @@ public void serialize(SharedBuffer record, DataOutputView target) throws IOExcep
}

@Override
public SharedBuffer deserialize(DataInputView source) throws IOException {
public SharedBuffer<K, V> deserialize(DataInputView source) throws IOException {
List<SharedBufferEntry<K, V>> entryList = new ArrayList<>();
Map<K, SharedBufferPage<K, V>> pages = new HashMap<>();

Expand Down Expand Up @@ -1013,11 +1015,11 @@ public SharedBuffer deserialize(DataInputView source) throws IOException {
// here we put the old NonDuplicating serializer because this needs to create a copy
// of the buffer, as created by the NFA. There, for compatibility reasons, we have left
// the old serializer.
return new SharedBuffer(new NonDuplicatingTypeSerializer(valueSerializer), pages);
return new SharedBuffer<>(new NonDuplicatingTypeSerializer<>(valueSerializer), pages);
}

@Override
public SharedBuffer deserialize(SharedBuffer reuse, DataInputView source) throws IOException {
public SharedBuffer<K, V> deserialize(SharedBuffer<K, V> reuse, DataInputView source) throws IOException {
return deserialize(source);
}

Expand Down Expand Up @@ -1068,11 +1070,19 @@ public void copy(DataInputView source, DataOutputView target) throws IOException

@Override
public boolean equals(Object obj) {
return obj == this ||
(obj != null && obj.getClass().equals(getClass()) &&
keySerializer.equals(((SharedBufferSerializer<?, ?>) obj).keySerializer) &&
valueSerializer.equals(((SharedBufferSerializer<?, ?>) obj).valueSerializer) &&
versionSerializer.equals(((SharedBufferSerializer<?, ?>) obj).versionSerializer));
if (obj == this) {
return true;
}

if (obj == null || !Objects.equals(obj.getClass(), getClass())) {
return false;
}

SharedBufferSerializer other = (SharedBufferSerializer) obj;
return
Objects.equals(keySerializer, other.getKeySerializer()) &&
Objects.equals(valueSerializer, other.getValueSerializer()) &&
Objects.equals(versionSerializer, other.getVersionSerializer());
}

@Override
Expand Down
Expand Up @@ -310,14 +310,13 @@ public boolean filter(Event value) throws Exception {
NFA.NFASerializer<Event> copySerializer = new NFA.NFASerializer<>(Event.createTypeSerializer());
ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray());
ByteArrayOutputStream out = new ByteArrayOutputStream();
copySerializer.copy(new DataInputViewStreamWrapper(in), new DataOutputViewStreamWrapper(out));
copySerializer.duplicate().copy(new DataInputViewStreamWrapper(in), new DataOutputViewStreamWrapper(out));
in.close();
out.close();

// deserialize
ByteArrayInputStream bais = new ByteArrayInputStream(out.toByteArray());
NFA.NFASerializer<Event> deserializer = new NFA.NFASerializer<>(Event.createTypeSerializer());
NFA<Event> copy = deserializer.deserialize(new DataInputViewStreamWrapper(bais));
NFA<Event> copy = serializer.duplicate().deserialize(new DataInputViewStreamWrapper(bais));
bais.close();

assertEquals(nfa, copy);
Expand Down
Expand Up @@ -160,7 +160,7 @@ public void testSharedBufferSerialization() throws IOException, ClassNotFoundExc
serializer.serialize(sharedBuffer, new DataOutputViewStreamWrapper(baos));

ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
SharedBuffer<String, Event> copy = serializer.deserialize(new DataInputViewStreamWrapper(bais));
SharedBuffer<String, Event> copy = serializer.duplicate().deserialize(new DataInputViewStreamWrapper(bais));

assertEquals(sharedBuffer, copy);
}
Expand Down

0 comments on commit ff9cefb

Please sign in to comment.