From 9f2e3a1d23971d4533a4f961871a80e44df9c8a2 Mon Sep 17 00:00:00 2001 From: David Arroyo Cazorla Date: Mon, 31 Oct 2016 18:53:44 +0100 Subject: [PATCH] [CROSSDATA-808] Fix Map serialization (#760) * fix 'MapType' serialization: serialize any kind of key supported by spark appart from string * Merge Fix --- .../common/serializers/RowSerializer.scala | 21 ++++++++++--------- .../serializers/RowSerializerSpec.scala | 7 +++++-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/common/src/main/scala/com/stratio/crossdata/common/serializers/RowSerializer.scala b/common/src/main/scala/com/stratio/crossdata/common/serializers/RowSerializer.scala index cb8bd7f55..2e8f6cfcc 100644 --- a/common/src/main/scala/com/stratio/crossdata/common/serializers/RowSerializer.scala +++ b/common/src/main/scala/com/stratio/crossdata/common/serializers/RowSerializer.scala @@ -60,10 +60,10 @@ case class RowSerializer(providedSchema: StructType) extends Serializer[Row] { case (ArrayType(ty, _), JArray(arr)) => mutable.WrappedArray make arr.map(extractField(ty, _)).toArray /* Maps will be serialized as sub-objects so keys are constrained to be strings */ - case (MapType(StringType, vt, _), JObject(fields)) => - val (keys, values) = fields.unzip - val unserValues = values map (jval => extractField(vt, jval)) - ArrayBasedMapDataNotDeprecated(keys.toArray, unserValues.toArray) + case (MapType(kt, vt, _), JObject(JField("map", JObject(JField("keys", JArray(mapKeys)) :: JField("values", JArray(mapValues)) :: _) ) :: _)) => + val unserKeys = mapKeys map (jval => extractField(kt, jval)) + val unserValues = mapValues map (jval => extractField(vt, jval)) + ArrayBasedMapDataNotDeprecated(unserKeys.toArray, unserValues.toArray) case (st: StructType, JObject(JField("values",JArray(values))::_)) => deserializeWithSchema(st, values, true) } @@ -106,14 +106,15 @@ case class RowSerializer(providedSchema: StructType) extends Serializer[Row] { case v: ArrayDataNotDeprecated => JArray(v.array.toList.map(v => Extraction.decompose(v))) case v: mutable.WrappedArray[_] => JArray(v.toList.map(v => Extraction.decompose(v))) } - case (MapType(StringType, vt, _), v: MapDataNotDeprecated) => + case (MapType(kt, vt, _), v: MapDataNotDeprecated) => /* Maps will be serialized as sub-objects so keys are constrained to be strings */ - val serKeys = v.keyArray().array.map(v => v.toString) + val serKeys = v.keyArray().array.map(v => serializeField(kt -> v)) val serValues = v.valueArray.array.map(v => serializeField(vt -> v)) - JObject( - (v.keyArray.array zip serValues) map { - case (k: String, v) => JField(k, v) - } toList + JField("map", + JObject( + JField("keys", JArray(serKeys.toList)), + JField("values", JArray(serValues.toList)) + ) ) case (st: StructType, v: Row) => serializeWithSchema(st, v, true) } diff --git a/common/src/test/scala/com/stratio/crossdata/common/serializers/RowSerializerSpec.scala b/common/src/test/scala/com/stratio/crossdata/common/serializers/RowSerializerSpec.scala index 4f205096e..a7666e5dd 100644 --- a/common/src/test/scala/com/stratio/crossdata/common/serializers/RowSerializerSpec.scala +++ b/common/src/test/scala/com/stratio/crossdata/common/serializers/RowSerializerSpec.scala @@ -19,6 +19,7 @@ import com.stratio.crossdata.common.serializers.XDSerializationTest.TestCase import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @@ -47,6 +48,7 @@ class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSer StructField("arraystring",ArrayType(StringType,true),true), StructField("mapstringint",MapType(StringType,IntegerType,true),true), StructField("mapstringstring",MapType(StringType,StringType,true),true), + StructField("maptimestampinteger",MapType(TimestampType,IntegerType,true),true), StructField("struct",StructType(StructField("field1",IntegerType,true)::StructField("field2",IntegerType,true) ::Nil), true), StructField("arraystruct",ArrayType(StructType(StructField("field1",IntegerType,true)::StructField("field2", IntegerType,true)::Nil),true),true), StructField("structofstruct",StructType(StructField("field1",TimestampType,true)::StructField("field2", IntegerType, true)::StructField("struct1",StructType(StructField("structField1",StringType,true)::StructField("structField2",IntegerType,true)::Nil),true)::Nil),true) @@ -72,6 +74,7 @@ class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSer WrappedArray make Array("hello", "world"), ArrayBasedMapData(Map("b" -> 2)), ArrayBasedMapData(Map("a" -> "A", "b" -> "B")), + ArrayBasedMapData(Map(java.sql.Timestamp.valueOf("2015-11-30 10:00:00.0") -> 25, java.sql.Timestamp.valueOf("2015-11-30 10:00:00.0") -> 12)), new GenericRowWithSchema(Array(99,98), StructType(StructField("field1", IntegerType) ::StructField("field2", IntegerType)::Nil)), WrappedArray make Array( @@ -102,11 +105,11 @@ class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSer implicit val formats = json4sJacksonFormats + new RowSerializer(schema) - + override def testCases: Seq[TestCase] = Seq( TestCase("marshall & unmarshall a row with no schema", rowWithNoSchema), TestCase("marshall & unmarshall a row with schema", rowWithSchema) ) -} +} \ No newline at end of file