From a4d2731eb75032f958e96323a075eb8bc7d11c73 Mon Sep 17 00:00:00 2001
From: Shiti
Date: Thu, 25 Jun 2015 16:06:10 +0530
Subject: [PATCH] [FLINK-2230] Handling null values for TupleSerializer
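
The serializers now write a null mask ahead of the field data: a
java.util.BitSet with one bit per field, set when that field is non-null.
Null fields are skipped on write; on read, every field whose bit is unset
is restored as null instead of being handed to the field serializer, so
NullFieldException is no longer thrown for tuples containing nulls.

Rough standalone sketch of the encoding, for illustration only: plain
java.io streams and an Object[] stand in for Flink's DataOutputView /
DataInputView and Tuple, every non-null field is written with writeUTF,
and the sketch pads the mask to a fixed (arity / 8) + 1 bytes so the
reader knows how many bytes to consume.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.Arrays;
    import java.util.BitSet;

    public class NullMaskSketch {

        public static void main(String[] args) throws IOException {
            Object[] fields = { 42, null, "flink" };
            int arity = fields.length;
            int maskBytes = (arity / 8) + 1;

            // write: the null mask first, then only the non-null fields
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            BitSet mask = new BitSet(arity);
            for (int i = 0; i < arity; i++) {
                mask.set(i, fields[i] != null);
            }
            // pad the mask to a fixed length so the reader can consume it blindly
            out.write(Arrays.copyOf(mask.toByteArray(), maskBytes));
            for (int i = 0; i < arity; i++) {
                if (fields[i] != null) {
                    out.writeUTF(fields[i].toString());
                }
            }

            // read: unset bits become null, set bits are deserialized
            DataInputStream in = new DataInputStream(
                    new ByteArrayInputStream(bytes.toByteArray()));
            byte[] buffer = new byte[maskBytes];
            in.readFully(buffer);
            BitSet readMask = BitSet.valueOf(buffer);
            Object[] restored = new Object[arity];
            for (int i = 0; i < arity; i++) {
                restored[i] = readMask.get(i) ? in.readUTF() : null;
            }
            System.out.println(Arrays.toString(restored)); // [42, null, flink]
        }
    }
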
---
 .../typeutils/runtime/TupleSerializer.java         | 111 +++++++++++++-----
 .../runtime/TupleSerializerTest.java               |  10 ++
 .../runtime/TupleSerializerTestInstance.java       |  14 +--
 .../scala/typeutils/CaseClassSerializer.scala      |  60 +++++++++-
 .../scala/runtime/TupleSerializerTest.scala        |   6 +
 .../runtime/TupleSerializerTestInstance.scala      |   8 +-
 6 files changed, 160 insertions(+), 49 deletions(-)

diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializer.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializer.java
index 231486d13130d..6808550779c76 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializer.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializer.java
@@ -18,19 +18,18 @@
 
 package org.apache.flink.api.java.typeutils.runtime;
 
-import java.io.IOException;
-
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.types.NullFieldException;
 
+import java.io.IOException;
+import java.util.BitSet;
 
 public final class TupleSerializer<T extends Tuple> extends TupleSerializerBase<T> {
 
 	private static final long serialVersionUID = 1L;
-	
+
 	public TupleSerializer(Class<T> tupleClass, TypeSerializer<?>[] fieldSerializers) {
 		super(tupleClass, fieldSerializers);
 	}
@@ -38,9 +37,10 @@ public TupleSerializer(Class<T> tupleClass, TypeSerializer<?>[] fieldSerializers
 	@Override
 	public TupleSerializer<T> duplicate() {
 		boolean stateful = false;
-		TypeSerializer<?>[] duplicateFieldSerializers = new TypeSerializer[fieldSerializers.length];
+		int fieldCount = fieldSerializers.length;
+		TypeSerializer<?>[] duplicateFieldSerializers = new TypeSerializer[fieldCount];
 
-		for (int i = 0; i < fieldSerializers.length; i++) {
+		for (int i = 0; i < fieldCount; i++) {
 			duplicateFieldSerializers[i] = fieldSerializers[i].duplicate();
 			if (duplicateFieldSerializers[i] != fieldSerializers[i]) {
 				// at least one of them is stateful
@@ -59,11 +59,11 @@ public TupleSerializer<T> duplicate() {
 	public T createInstance() {
 		try {
 			T t = tupleClass.newInstance();
-			
+
 			for (int i = 0; i < arity; i++) {
 				t.setField(fieldSerializers[i].createInstance(), i);
 			}
-			
+
 			return t;
 		}
 		catch (Exception e) {
@@ -88,57 +88,106 @@ public T createInstance(Object[] fields) {
 		}
 	}
 
-	@Override
+	@Override
+	public void copy(DataInputView source, DataOutputView target) throws IOException {
+		int bitSetSize = (arity / 8) + 1;
+		byte[] buffer = new byte[bitSetSize];
+		source.read(buffer);
+		BitSet bitIndicator = BitSet.valueOf(buffer);
+
+		target.write(bitIndicator.toByteArray());
+
+		for (int i = 0; i < arity; i++) {
+			if (bitIndicator.get(i)) {
+				fieldSerializers[i].copy(source, target);
+			}
+		}
+	}
+
+	@Override
 	public T copy(T from) {
 		T target = instantiateRaw();
 		for (int i = 0; i < arity; i++) {
-			Object copy = fieldSerializers[i].copy(from.getField(i));
-			target.setField(copy, i);
-		}
+			Object field = from.getField(i);
+			if (field != null) {
+				Object copy = fieldSerializers[i].copy(field);
+				target.setField(copy, i);
+			} else {
+				target.setField(null, i);
+			}
+		}
 		return target;
 	}
-	
+
 	@Override
 	public T copy(T from, T reuse) {
 		for (int i = 0; i < arity; i++) {
-			Object copy = fieldSerializers[i].copy(from.getField(i), reuse.getField(i));
-			reuse.setField(copy, i);
+			Object field = from.getField(i);
+			if (field != null) {
+				Object copy = fieldSerializers[i].copy(field);
+				reuse.setField(copy, i);
+			} else {
+				reuse.setField(null, i);
+			}
 		}
-		
 		return reuse;
 	}
 
 	@Override
 	public void serialize(T value, DataOutputView target) throws IOException {
-		for (int i = 0; i < arity; i++) {
-			Object o = value.getField(i);
-			try {
-				fieldSerializers[i].serialize(o, target);
-			} catch (NullPointerException npex) {
-				throw new NullFieldException(i);
-			}
-		}
-	}
+		BitSet bitIndicator = new BitSet(arity);
+		for (int i = 0; i < arity; i++) {
+			Object o = value.getField(i);
+			bitIndicator.set(i, o != null);
+		}
+		target.write(bitIndicator.toByteArray());
+
+		for (int i = 0; i < arity; i++) {
+			Object o = value.getField(i);
+			if (o != null) {
+				fieldSerializers[i].serialize(o, target);
+			}
+		}
+	}
 
 	@Override
 	public T deserialize(DataInputView source) throws IOException {
 		T tuple = instantiateRaw();
+
+		int bitSetSize = (arity / 8) + 1;
+		byte[] buffer = new byte[bitSetSize];
+		source.read(buffer);
+		BitSet bitIndicator = BitSet.valueOf(buffer);
+
 		for (int i = 0; i < arity; i++) {
-			Object field = fieldSerializers[i].deserialize(source);
-			tuple.setField(field, i);
+			if (!bitIndicator.get(i)) {
+				tuple.setField(null, i);
+			} else {
+				Object field = fieldSerializers[i].deserialize(source);
+				tuple.setField(field, i);
+			}
 		}
 		return tuple;
 	}
-	
+
 	@Override
 	public T deserialize(T reuse, DataInputView source) throws IOException {
+		int bitSetSize = (arity / 8) + 1;
+		byte[] buffer = new byte[bitSetSize];
+		source.read(buffer);
+		BitSet bitIndicator = BitSet.valueOf(buffer);
+
 		for (int i = 0; i < arity; i++) {
-			Object field = fieldSerializers[i].deserialize(reuse.getField(i), source);
-			reuse.setField(field, i);
+			if (!bitIndicator.get(i)) {
+				reuse.setField(null, i);
+			} else {
+				Object field = fieldSerializers[i].deserialize(reuse.getField(i), source);
+				reuse.setField(field, i);
+			}
 		}
 		return reuse;
 	}
-	
+
 	private T instantiateRaw() {
 		try {
 			return tupleClass.newInstance();
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java
index 96f8306ca8fc5..d5662ab76d649 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java
@@ -207,6 +207,16 @@ public void testTuple5CustomObjects() {
 		runTests(testTuples);
 	}
 	
+	@Test
+	public void testTupleWithNull() {
+		@SuppressWarnings("unchecked")
+		Tuple2[] testTuples = new Tuple2[] {
+			new Tuple2(0, "a"), new Tuple2(1, "b"), new Tuple2(-1, null)
+		};
+
+		runTests(testTuples);
+	}
+
 	private <T extends Tuple> void runTests(T... instances) {
 		try {
 			TupleTypeInfo<T> tupleTypeInfo = (TupleTypeInfo<T>) TypeExtractor.getForObject(instances[0]);
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTestInstance.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTestInstance.java
index a196984d4e513..02791a8462bf8 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTestInstance.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTestInstance.java
@@ -19,16 +19,16 @@
 
 package org.apache.flink.api.java.typeutils.runtime;
 
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-
-import java.util.Arrays;
-
 import org.apache.flink.api.common.typeutils.SerializerTestInstance;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.junit.Assert;
 
+import java.util.Arrays;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
 public class TupleSerializerTestInstance<T extends Tuple> extends SerializerTestInstance<T> {
 
 	public TupleSerializerTestInstance(TypeSerializer<T> serializer, Class<T> typeClass, int length, T[] testData) {
@@ -41,8 +41,8 @@ protected void deepEquals(String message, T shouldTuple, T isTuple) {
 		for (int i = 0; i < shouldTuple.getArity(); i++) {
 			Object should = shouldTuple.getField(i);
 			Object is = isTuple.getField(i);
-			
-			if (should.getClass().isArray()) {
+
+			if (should != null && should.getClass().isArray()) {
 				if (should instanceof boolean[]) {
 					Assert.assertTrue(message, Arrays.equals((boolean[]) should, (boolean[]) is));
 				}
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassSerializer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassSerializer.scala
index 2a76c379fa425..1683530d0538f 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassSerializer.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassSerializer.scala
@@ -17,10 +17,11 @@
  */
 package org.apache.flink.api.scala.typeutils
 
-import org.apache.commons.lang.SerializationUtils
+import java.util.BitSet
+
 import org.apache.flink.api.common.typeutils.TypeSerializer
 import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase
-import org.apache.flink.core.memory.{DataOutputView, DataInputView}
+import org.apache.flink.core.memory.{DataInputView, DataOutputView}
 
 /**
  * Serializer for Case Classes. Creation and access is different from
@@ -76,17 +77,37 @@
     initArray()
     var i = 0
     while (i < arity) {
-      fields(i) = fieldSerializers(i).copy(from.productElement(i).asInstanceOf[AnyRef])
+      val fieldValue: AnyRef = from.productElement(i).asInstanceOf[AnyRef]
+      if (fieldValue != null) {
+        fields(i) = fieldSerializers(i).copy(fieldValue)
+      } else {
+        fields(i) = null
+      }
       i += 1
     }
     createInstance(fields)
   }
 
   def serialize(value: T, target: DataOutputView) {
-    var i = 0
+    val bitIndicator: BitSet = new BitSet(arity)
+    var i: Int = 0
+    while (i < arity) {
+      {
+        val element: Any = value.productElement(i)
+        bitIndicator.set(i, element != null)
+      }
+      i += 1
+    }
+
+    target.write(bitIndicator.toByteArray)
+
+    i = 0
     while (i < arity) {
       val serializer = fieldSerializers(i).asInstanceOf[TypeSerializer[Any]]
-      serializer.serialize(value.productElement(i), target)
+      val element: Any = value.productElement(i)
+      if (element != null) {
+        serializer.serialize(element, target)
+      }
       i += 1
     }
   }
@@ -97,9 +118,19 @@
   def deserialize(source: DataInputView): T = {
     initArray()
+
+    val bitSetSize: Int = (arity / 8) + 1
+    val buffer: Array[Byte] = new Array[Byte](bitSetSize)
+    source.read(buffer)
+    val bitIndicator: BitSet = BitSet.valueOf(buffer)
+
     var i = 0
     while (i < arity) {
-      fields(i) = fieldSerializers(i).deserialize(source)
+      if (bitIndicator.get(i)) {
+        fields(i) = fieldSerializers(i).deserialize(source)
+      } else {
+        fields(i) = null
+      }
       i += 1
     }
     createInstance(fields)
   }
@@ -110,4 +141,21 @@
   def initArray() = {
     fields = new Array[AnyRef](arity)
   }
+
+  override def copy(source: DataInputView, target: DataOutputView): Unit = {
+    val bitSetSize: Int = (arity / 8) + 1
+    val buffer: Array[Byte] = new Array[Byte](bitSetSize)
+    source.read(buffer)
+    val bitIndicator: BitSet = BitSet.valueOf(buffer)
+
+    target.write(bitIndicator.toByteArray)
+
+    var i: Int = 0
+    while (i < arity) {
+      if (bitIndicator.get(i)) {
+        fieldSerializers(i).copy(source, target)
+      }
+      i += 1
+    }
+  }
 }
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTest.scala
index c436d6236d1f8..08540dc842ab7 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTest.scala
@@ -205,6 +205,12 @@ class TupleSerializerTest {
     runTests(testTuples)
   }
 
+  @Test
+  def testTupleWithNull(): Unit = {
+    val testTuples = Array((0, "a"), (1, "b"), (-1, null))
+    runTests(testTuples)
+  }
+
   private final def runTests[T <: Product : TypeInformation](instances: Array[T]) {
     try {
       // Register the custom Kryo Serializer
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTestInstance.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTestInstance.scala
index 7a425900c96c9..1fc44e203b419 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTestInstance.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/runtime/TupleSerializerTestInstance.scala
@@ -17,11 +17,9 @@
  */
 package org.apache.flink.api.scala.runtime
 
+import org.apache.flink.api.common.typeutils.{SerializerTestInstance, TypeSerializer}
 import org.junit.Assert._
-import org.apache.flink.api.common.typeutils.SerializerTestInstance
-import org.apache.flink.api.common.typeutils.TypeSerializer
-import org.junit.Assert
-import org.junit.Test
+import org.junit.{Assert, Test}
 
 
 class TupleSerializerTestInstance[T <: Product] (
@@ -57,7 +55,7 @@ class TupleSerializerTestInstance[T <: Product] (
     for (i <- 0 until shouldTuple.productArity) {
       val should = shouldTuple.productElement(i)
       val is = isTuple.productElement(i)
-      if (should.getClass.isArray) {
+      if (should != null && should.getClass.isArray) {
         should match {
           case booleans: Array[Boolean] =>
             Assert.assertTrue(message, booleans.sameElements(is.asInstanceOf[Array[Boolean]]))