From d6ed4b7cec98b34ff609df3fdfcd009c1f01c50a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 21 Oct 2018 08:57:18 +0800 Subject: [PATCH 1/3] Datatype of serializers should be accessible. --- .../apache/spark/sql/catalyst/encoders/RowEncoder.scala | 6 +++--- .../spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3340789398f9c..45bbb1451236a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -171,7 +171,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput @@ -187,7 +187,7 @@ object RowEncoder { val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), - Literal.create(null, field.dataType), + Literal.create(null, fieldValue.dataType), fieldValue ) } else { @@ -198,7 +198,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 8d89f9c6c41d4..06c28733f82dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -273,6 +273,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + test("SPARK-25791: Datatype of serializers should be accessible") { + val udtSQLType = new StructType().add("a", IntegerType) + val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass") + val schema = new StructType().add("pythonUDT", pythonUDT, true) + val encoder = RowEncoder(schema) + encoder.serializer.foreach(s => println(s.dataType)) + } + for { elementType <- Seq(IntegerType, StringType) containsNull <- Seq(true, false) From 6d41fe00bfb62dea91632072a3e7abcb143c9183 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 21 Oct 2018 09:55:26 +0800 Subject: [PATCH 2/3] Fix scala style. --- .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 06c28733f82dc..a70a0fd0e1392 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -278,7 +278,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass") val schema = new StructType().add("pythonUDT", pythonUDT, true) val encoder = RowEncoder(schema) + // scalastyle:off println encoder.serializer.foreach(s => println(s.dataType)) + // scalastyle:on println } for { From 61c8b2f3488232d6ec6d134e21c4c5887d489065 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 Oct 2018 17:48:03 +0800 Subject: [PATCH 3/3] Address comments. --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 ++ .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 45bbb1451236a..ae89f98b19025 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -187,6 +187,8 @@ object RowEncoder { val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), + // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. + // We should use `fieldValue.dataType` here. Literal.create(null, fieldValue.dataType), fieldValue ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index a70a0fd0e1392..235732134d4b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -278,9 +278,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass") val schema = new StructType().add("pythonUDT", pythonUDT, true) val encoder = RowEncoder(schema) - // scalastyle:off println - encoder.serializer.foreach(s => println(s.dataType)) - // scalastyle:on println + assert(encoder.serializer(0).dataType == pythonUDT.sqlType) } for {