From acec27f2e1c73c7f3c8322f89ba188ba65a1d4b3 Mon Sep 17 00:00:00 2001 From: twalthr Date: Tue, 6 Dec 2016 15:07:45 +0100 Subject: [PATCH 1/2] [FLINK-5011] [types] TraversableSerializer does not perform a deep copy of the elements it is traversing --- .../typeutils/TraversableSerializer.scala | 4 +- .../ScalaSpecialTypesSerializerTest.scala | 18 ++++- .../runtime/TraversableSerializerTest.scala | 78 ++++++++++--------- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/TraversableSerializer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/TraversableSerializer.scala index 7d14dc14e622a..d1b908591f2ae 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/TraversableSerializer.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/TraversableSerializer.scala @@ -58,14 +58,14 @@ abstract class TraversableSerializer[T <: TraversableOnce[E], E]( cbf().result() } - override def isImmutableType: Boolean = true + override def isImmutableType: Boolean = false override def getLength: Int = -1 override def copy(from: T): T = { val builder = cbf() builder.sizeHint(from.size) - from foreach { e => builder += e } + from foreach { e => builder += elementSerializer.copy(e) } builder.result() } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala index 155160c08d988..1ff2b0745dc28 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala @@ -22,7 +22,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeutils.{SerializerTestInstance, TypeSerializer} import 
org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer import org.apache.flink.api.scala._ -import org.apache.flink.api.scala.typeutils.EnumValueTypeInfo import org.junit.Assert._ import org.junit.{Assert, Test} @@ -92,6 +91,23 @@ class ScalaSpecialTypesSerializerTest { runTests(testData) } + @Test + def testStringArray(): Unit = { + val testData = Array(Array("Foo", "Bar"), Array("Hello")) + runTests(testData) + } + + @Test + def testIntArray(): Unit = { + val testData = Array(Array(1,3,3,7), Array(4,7)) + runTests(testData) + } + + @Test + def testArrayWithCaseClass(): Unit = { + val testData = Array(Array((1, "String"), (2, "Foo")), Array((4, "String"), (3, "Foo"))) + runTests(testData) + } private final def runTests[T : TypeInformation](instances: Array[T]) { try { diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala index 65648b6928a40..4987f72186dee 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala @@ -18,16 +18,15 @@ package org.apache.flink.api.scala.runtime import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.functions.InvalidTypesException -import org.junit.Assert._ - -import org.apache.flink.api.common.typeutils.{TypeSerializer, SerializerTestInstance} import org.apache.flink.api.common.typeinfo.TypeInformation -import org.junit.{Ignore, Assert, Test} - +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.typeutils.TraversableSerializer +import org.junit.Assert._ +import org.junit.{Assert, Ignore, Test} -import scala.collection.immutable.{BitSet, SortedSet, 
LinearSeq} +import scala.collection.immutable.{BitSet, LinearSeq, SortedSet} import scala.collection.{SortedMap, mutable} class TraversableSerializerTest { @@ -57,6 +56,7 @@ class TraversableSerializerTest { } @Test + @Ignore def testSortedMap(): Unit = { // SortedSet is not supported right now. val testData = Array(SortedMap("Hello" -> 1, "World" -> 2), SortedMap("Foo" -> 42)) @@ -70,6 +70,7 @@ class TraversableSerializerTest { } @Test + @Ignore def testSortedSet(): Unit = { // SortedSet is not supported right now. val testData = Array(SortedSet(1,2,3), SortedSet(2,3)) @@ -88,24 +89,6 @@ class TraversableSerializerTest { runTests(testData) } - @Test - def testStringArray(): Unit = { - val testData = Array(Array("Foo", "Bar"), Array("Hello")) - runTests(testData) - } - - @Test - def testIntArray(): Unit = { - val testData = Array(Array(1,3,3,7), Array(4,7)) - runTests(testData) - } - - @Test - def testArrayWithCaseClass(): Unit = { - val testData = Array(Array((1, "String"), (2, "Foo")), Array((4, "String"), (3, "Foo"))) - runTests(testData) - } - @Test def testWithCaseClass(): Unit = { val testData = Array(Seq((1, "String"), (2, "Foo")), Seq((4, "String"), (3, "Foo"))) @@ -124,31 +107,27 @@ class TraversableSerializerTest { // have a typeClass of Object, and therefore not deserialize the elements correctly. // It does work when used in a Job, though. Because the Objects get cast to // the correct type in the user function. 
- val testData = Array(Seq(1,1L,1d,true,"Hello"), Seq(2,2L,2d,false,"Ciao")) + val testData = Array(Seq(1, 1L, 1d, true, "Hello"), Seq(2, 2L, 2d, false, "Ciao")) runTests(testData) } - - private final def runTests[T : TypeInformation](instances: Array[T]) { try { val typeInfo = implicitly[TypeInformation[T]] val serializer = typeInfo.createSerializer(new ExecutionConfig) val typeClass = typeInfo.getTypeClass - val test = - new ScalaSpecialTypesSerializerTestInstance[T](serializer, typeClass, -1, instances) + val test = new TraversableSerializerTestInstance[T](serializer, typeClass, -1, instances) test.testAll() } catch { - case e: Exception => { + case e: Exception => System.err.println(e.getMessage) e.printStackTrace() Assert.fail(e.getMessage) - } } } } -class Pojo(val name: String, val count: Int) { +class Pojo(var name: String, var count: Int) { def this() = this("", -1) override def equals(other: Any): Boolean = { @@ -159,12 +138,38 @@ class Pojo(val name: String, val count: Int) { } } -class ScalaCollectionSerializerTestInstance[T]( +class TraversableSerializerTestInstance[T]( serializer: TypeSerializer[T], typeClass: Class[T], length: Int, testData: Array[T]) - extends SerializerTestInstance[T](serializer, typeClass, length, testData: _*) { + extends ScalaSpecialTypesSerializerTestInstance[T](serializer, typeClass, length, testData) { + + @Test + override def testAll(): Unit = { + super.testAll() + testTraversableDeepCopy() + } + + @Test + def testTraversableDeepCopy(): Unit = { + val serializer = getSerializer + val elementSerializer = serializer.asInstanceOf[TraversableSerializer[_, _]].elementSerializer + val data = getTestData + + // check for deep copy if type is mutable and not serialized with Kryo + // elements of traversable should not have reference equality + if (!elementSerializer.isImmutableType && !elementSerializer.isInstanceOf[KryoSerializer[_]]) { + data.foreach { datum => + val original = datum.asInstanceOf[Traversable[_]].toIterable + val
copy = serializer.copy(datum).asInstanceOf[Traversable[_]].toIterable + copy.zip(original).foreach { case (c: AnyRef, o: AnyRef) => + assertTrue("Copy of mutable element has reference equality.", c ne o) + case _ => // ok + } + } + } + } @Test override def testInstantiate(): Unit = { @@ -179,11 +184,10 @@ class ScalaCollectionSerializerTestInstance[T]( // assertEquals("Type of the instantiated object is wrong.", tpe, instance.getClass) } catch { - case e: Exception => { + case e: Exception => System.err.println(e.getMessage) e.printStackTrace() fail("Exception in test: " + e.getMessage) - } } } From b61adcd6a8151d4a1ee75445e0cb3d0130f1bff2 Mon Sep 17 00:00:00 2001 From: twalthr Date: Thu, 15 Dec 2016 15:49:29 +0100 Subject: [PATCH 2/2] Move unrelated tests --- .../ScalaSpecialTypesSerializerTest.scala | 13 +++++++++++++ .../runtime/TraversableSerializerTest.scala | 18 ++---------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala index 1ff2b0745dc28..555359f89abf6 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/ScalaSpecialTypesSerializerTest.scala @@ -25,6 +25,7 @@ import org.apache.flink.api.scala._ import org.junit.Assert._ import org.junit.{Assert, Test} +import scala.collection.{SortedMap, SortedSet} import scala.util.{Failure, Success} class ScalaSpecialTypesSerializerTest { @@ -109,6 +110,18 @@ class ScalaSpecialTypesSerializerTest { runTests(testData) } + @Test + def testSortedMap(): Unit = { + val testData = Array(SortedMap("Hello" -> 1, "World" -> 2), SortedMap("Foo" -> 42)) + runTests(testData) + } + + @Test + def testSortedSet(): Unit = { + val testData = Array(SortedSet(1,2,3), SortedSet(2,3)) + 
runTests(testData) + } + private final def runTests[T : TypeInformation](instances: Array[T]) { try { val typeInfo = implicitly[TypeInformation[T]] diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala index 4987f72186dee..e177e7c4ac42f 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TraversableSerializerTest.scala @@ -31,6 +31,8 @@ import scala.collection.{SortedMap, mutable} class TraversableSerializerTest { + // Note: SortedMap and SortedSet are serialized with Kryo + @Test def testSeq(): Unit = { val testData = Array(Seq(1,2,3), Seq(2,3)) @@ -55,28 +57,12 @@ class TraversableSerializerTest { runTests(testData) } - @Test - @Ignore - def testSortedMap(): Unit = { - // SortedSet is not supported right now. - val testData = Array(SortedMap("Hello" -> 1, "World" -> 2), SortedMap("Foo" -> 42)) - runTests(testData) - } - @Test def testSet(): Unit = { val testData = Array(Set(1,2,3,3), Set(2,3)) runTests(testData) } - @Test - @Ignore - def testSortedSet(): Unit = { - // SortedSet is not supported right now. - val testData = Array(SortedSet(1,2,3), SortedSet(2,3)) - runTests(testData) - } - @Test def testBitSet(): Unit = { val testData = Array(BitSet(1,2,3,4), BitSet(2,3,2))