From c1589bd5e4ce8b838c819aaa3645fa69a7233fe7 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Thu, 17 Dec 2015 00:46:50 +0800
Subject: [PATCH 1/9] Checks Dataset nullability during resolution

---
 .../catalyst/encoders/ExpressionEncoder.scala | 28 ++++++++++++-----
 .../catalyst/expressions/BoundAttribute.scala |  2 +-
 .../sql/catalyst/expressions/Expression.scala |  1 +
 .../scala/org/apache/spark/sql/Dataset.scala  | 13 +++++++-
 .../org/apache/spark/sql/DatasetSuite.scala   | 30 +++++++++++++++++++
 5 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7a4401cf5810e..fcd0ba0d14d9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -20,17 +20,18 @@ package org.apache.spark.sql.catalyst.encoders
 import java.util.concurrent.ConcurrentMap
 
 import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.{AnalysisException, Encoder}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.{AnalysisException, Encoder}
+import org.apache.spark.util.Utils
 
 /**
  * A factory for constructing encoders that convert objects and primitives to and from the
@@ -284,6 +285,7 @@ case class ExpressionEncoder[T](
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
+    checkNullability(schema)
     val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those declared in a REPL cell),
@@ -304,6 +306,18 @@ case class ExpressionEncoder[T](
     })
   }
 
+  private def checkNullability(output: Seq[Attribute]): Unit = {
+    val logicalPlanSchema = StructType.fromAttributes(output)
+
+    // Checks for schema nullability
+    if (this.schema != logicalPlanSchema && this.schema.sameType(logicalPlanSchema)) {
+      throw new AnalysisException(
+        s"""Dataset nullability doesn't conform to the underlying logical plan:
+           >${sideBySide(logicalPlanSchema.treeString, this.schema.treeString).mkString("\n")}
+         """.stripMargin('>'))
+    }
+  }
+
   /**
    * Returns a copy of this encoder where the expressions used to construct an object from an input
   * row have been bound to the ordinals of the
   * given schema. Note that you need to first call
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 7293d5d4472af..ed8b40b96ec2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {
 
-  override def toString: String = s"input[$ordinal, $dataType]"
+  override def toString: String = s"input[$ordinal, $dataType]$nullabilitySuffix"
 
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6a9c12127d367..921711a075a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -49,6 +49,7 @@ import org.apache.spark.sql.types._
  *
  */
 abstract class Expression extends TreeNode[Expression] {
+  protected def nullabilitySuffix: String = if (nullable) "?" else ""
 
   /**
    * Returns true when an expression is a candidate for static evaluation before the query is
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d201d65238523..d0748b1211f1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
 class Dataset[T] private[sql](
     @transient override val sqlContext: SQLContext,
     @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
 
   /**
   * An unresolved version of the internal encoder for the type of this [[Dataset]].
   * This one is
@@ -83,6 +84,16 @@ class Dataset[T] private[sql](
    */
   private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
+  logTrace(
+    s"""
+       |# unresolvedTEncoder.fromRowExpression
+       |${unresolvedTEncoder.fromRowExpression.treeString}
+       |# resolvedTEncoder.fromRowExpression
+       |${resolvedTEncoder.fromRowExpression.treeString}
+       |# boundTEncoder.fromRowExpression
+       |${boundTEncoder.fromRowExpression.treeString}
+     """.stripMargin)
+
   private implicit def classTag = resolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index de012a9a56454..71fd0dc7af2c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
 
@@ -515,12 +516,41 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }
     assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
   }
+
+  test("check nullability") {
+    val rowRDD = sqlContext.sparkContext.parallelize(Seq(
+      Row(Row("hello", 1: Integer)),
+      Row(Row("world", null))
+    ))
+
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = false)
+    ))
+
+    val message = intercept[AnalysisException] {
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct].collect()
+    }.message
+
+    assert(message.contains(
+      """Dataset nullability doesn't conform to the underlying logical plan:
+        | root                                  root
+        |! |-- f: struct (nullable = false)     |-- f: struct (nullable = true)
+        |  | |-- a: string (nullable = true)     | |-- a: string (nullable = true)
+        |  | |-- b: integer (nullable = false)   | |-- b: integer (nullable = false)
+      """.stripMargin
+    ))
+  }
 }
 
 case class ClassData(a: String, b: Int)
 case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
+case class NestedStruct(f: ClassData)
+
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
 * Java serialization -- so the only way it can be "serialized" is through our encoders.
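The analysis-time check introduced in PATCH 1/9 fires exactly when the encoder's schema and the logical plan's schema agree on every field type but disagree on nullability. A minimal standalone sketch of that rule, using a simplified schema model instead of Spark's StructType (Field, sameType, and checkNullability below are illustrative stand-ins, not Spark APIs):

// Simplified stand-in for StructField: a name, a type name, and a nullability flag.
case class Field(name: String, dataType: String, nullable: Boolean)

object NullabilityCheckSketch {
  // True when both schemas agree on field names and types, ignoring nullability
  // (the role StructType.sameType plays in checkNullability above).
  def sameType(a: Seq[Field], b: Seq[Field]): Boolean =
    a.map(f => (f.name, f.dataType)) == b.map(f => (f.name, f.dataType))

  // Equal field types but unequal schemas can only mean a nullability mismatch.
  def checkNullability(encoderSchema: Seq[Field], planSchema: Seq[Field]): Unit =
    if (encoderSchema != planSchema && sameType(encoderSchema, planSchema)) {
      throw new RuntimeException(
        "Dataset nullability doesn't conform to the underlying logical plan:\n" +
          s"  plan:    $planSchema\n" +
          s"  encoder: $encoderSchema")
    }

  def main(args: Array[String]): Unit = {
    val encoder = Seq(Field("a", "string", nullable = true), Field("b", "int", nullable = false))
    val plan = Seq(Field("a", "string", nullable = true), Field("b", "int", nullable = true))
    checkNullability(encoder, plan) // throws: same field types, mismatched nullability
  }
}

Gating the error on sameType means schemas that differ in actual field types are still left to the regular analyzer errors; only pure nullability drift is reported, with the side-by-side tree dump shown in the test above.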
From 5b62827d4498fe43da475141782eb2b4b2deeb2e Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 21 Dec 2015 02:14:07 +0800
Subject: [PATCH 2/9] Runtime nullability check for NewInstance

---
 .../spark/sql/catalyst/ScalaReflection.scala  |  8 +++-
 .../sql/catalyst/expressions/objects.scala    | 40 +++++++++++++++++++
 .../src/test/resources/log4j.properties      |  4 +-
 sql/core/src/test/resources/log4j.properties  |  4 +-
 .../org/apache/spark/sql/DatasetSuite.scala   | 30 +++++++++++++-
 5 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cc9e6af1818f2..f2544d9e00051 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection {
             Some(addToPathOrdinal(i, dataType, newTypePath)),
             newTypePath)
         } else {
-          constructorFor(
+          val constructor = constructorFor(
            fieldType,
            Some(addToPath(fieldName, dataType, newTypePath)),
            newTypePath)
+
+          if (!nullable) {
+            AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+          } else {
+            constructor
+          }
         }
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 492cc9bf4146c..d2998ff8f794b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
      """
   }
 }
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true`, while we still need to add this
+ * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+    child: Expression, parentType: String, fieldName: String, fieldType: String)
+  extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childGen = child.gen(ctx)
+
+    ev.isNull = "false"
+    ev.value = childGen.value
+
+    s"""
+      ${childGen.code}
+
+      if (${childGen.isNull}) {
+        throw new RuntimeException(
+          "Null value appeared in non-nullable field $fieldType.$fieldName of type $fieldType. " +
+          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+          "please try to use scala.Option[_] or other nullable types " +
+          "(e.g. java.lang.Integer instead of int/scala.Int)."
+        );
+      }
+    """
+  }
+}
diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties
index eb3b1999eb996..c1e256836cf81 100644
--- a/sql/catalyst/src/test/resources/log4j.properties
+++ b/sql/catalyst/src/test/resources/log4j.properties
@@ -16,9 +16,9 @@ #
 #
 
 # Set everything to be logged to the file target/unit-tests.log
-log4j.rootCategory=INFO, file
+log4j.rootCategory=TRACE, file
 log4j.appender.file=org.apache.log4j.FileAppender
-log4j.appender.file.append=true
+log4j.appender.file.append=false
 log4j.appender.file.file=target/unit-tests.log
 log4j.appender.file.layout=org.apache.log4j.PatternLayout
 log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index 12fb128149d32..5b672885a22ff 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -16,7 +16,7 @@ #
 #
 
 # Set everything to be logged to the file core/target/unit-tests.log
-log4j.rootLogger=DEBUG, CA, FA
+log4j.rootLogger=TRACE, CA, FA
 
 #Console Appender
 log4j.appender.CA=org.apache.log4j.ConsoleAppender
@@ -33,7 +33,7 @@ log4j.appender.FA.layout=org.apache.log4j.PatternLayout
 log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n
 
 # Set the logger level of File Appender to WARN
-log4j.appender.FA.Threshold = INFO
+log4j.appender.FA.Threshold = TRACE
 
 # Some packages are noisy for no good reason.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 71fd0dc7af2c6..90a5284affb57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -517,7 +517,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
   }
 
-  test("check nullability") {
+  test("analysis time nullability check") {
     val rowRDD = sqlContext.sparkContext.parallelize(Seq(
       Row(Row("hello", 1: Integer)),
       Row(Row("world", null))
     ))
@@ -543,6 +543,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     ))
   }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains("Null value appeared in non-nullable field"))
+  }
 }
 
 case class ClassData(a: String, b: Int)

From eb3d00572d19f313fecd26ce1afdb446bc69898e Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 21 Dec 2015 02:18:09 +0800
Subject: [PATCH 3/9] Fixes typo in exception message

---
 .../org/apache/spark/sql/catalyst/expressions/objects.scala | 2 +-
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala  | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index d2998ff8f794b..d40cd96905732 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -655,7 +655,7 @@ case class AssertNotNull(
 
       if (${childGen.isNull}) {
         throw new RuntimeException(
-          "Null value appeared in non-nullable field $fieldType.$fieldName of type $fieldType. " +
+          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
           "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
           "please try to use scala.Option[_] or other nullable types " +
           "(e.g. java.lang.Integer instead of int/scala.Int)."
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 90a5284affb57..fde7958b1433c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -569,7 +569,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       buildDataset(Row(Row("hello", null))).collect()
     }.getMessage
 
-    assert(message.contains("Null value appeared in non-nullable field"))
+    assert(message.contains(
+      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
+    ))
   }
 }

From 482e09608e84c1bc88a10c413c2b88b5cc3425ff Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 21 Dec 2015 14:31:06 +0800
Subject: [PATCH 4/9] Reverts log4j.properties changes

---
 sql/catalyst/src/test/resources/log4j.properties | 4 ++--
 sql/core/src/test/resources/log4j.properties     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties
index c1e256836cf81..eb3b1999eb996 100644
--- a/sql/catalyst/src/test/resources/log4j.properties
+++ b/sql/catalyst/src/test/resources/log4j.properties
@@ -16,9 +16,9 @@ #
 #
 
 # Set everything to be logged to the file target/unit-tests.log
-log4j.rootCategory=TRACE, file
+log4j.rootCategory=INFO, file
 log4j.appender.file=org.apache.log4j.FileAppender
-log4j.appender.file.append=false
+log4j.appender.file.append=true
 log4j.appender.file.file=target/unit-tests.log
 log4j.appender.file.layout=org.apache.log4j.PatternLayout
 log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index 5b672885a22ff..12fb128149d32 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -16,7 +16,7 @@ #
 #
 
 # Set everything to be logged to the file core/target/unit-tests.log
-log4j.rootLogger=TRACE, CA, FA
+log4j.rootLogger=DEBUG, CA, FA
 
 #Console Appender
 log4j.appender.CA=org.apache.log4j.ConsoleAppender
@@ -33,7 +33,7 @@ log4j.appender.FA.layout=org.apache.log4j.PatternLayout
 log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n
 
 # Set the logger level of File Appender to WARN
-log4j.appender.FA.Threshold = TRACE
+log4j.appender.FA.Threshold = INFO
 
 # Some packages are noisy for no good reason.
 log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false

From f4cb448b4538fbd0e9dce1182bcbf17ed4a95218 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 21 Dec 2015 15:53:34 +0800
Subject: [PATCH 5/9] Narrows down the scope of this PR

---
 .../catalyst/encoders/ExpressionEncoder.scala | 13 ---------
 .../catalyst/expressions/BoundAttribute.scala |  2 +-
 .../sql/catalyst/expressions/Expression.scala |  1 -
 .../org/apache/spark/sql/DatasetSuite.scala   | 27 -------------------
 4 files changed, 1 insertion(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index fcd0ba0d14d9d..72c156b2c739f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -285,7 +285,6 @@ case class ExpressionEncoder[T](
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    checkNullability(schema)
     val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those declared in a REPL cell),
@@ -306,18 +305,6 @@ case class ExpressionEncoder[T](
     })
   }
 
-  private def checkNullability(output: Seq[Attribute]): Unit = {
-    val logicalPlanSchema = StructType.fromAttributes(output)
-
-    // Checks for schema nullability
-    if (this.schema != logicalPlanSchema && this.schema.sameType(logicalPlanSchema)) {
-      throw new AnalysisException(
-        s"""Dataset nullability doesn't conform to the underlying logical plan:
-           >${sideBySide(logicalPlanSchema.treeString, this.schema.treeString).mkString("\n")}
-         """.stripMargin('>'))
-    }
-  }
-
   /**
    * Returns a copy of this encoder where the expressions used to construct an object from an input
   * row have been bound to the ordinals of the given schema.
   * Note that you need to first call
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ed8b40b96ec2a..7293d5d4472af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {
 
-  override def toString: String = s"input[$ordinal, $dataType]$nullabilitySuffix"
+  override def toString: String = s"input[$ordinal, $dataType]"
 
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 921711a075a7a..6a9c12127d367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -49,7 +49,6 @@ import org.apache.spark.sql.types._
  *
  */
 abstract class Expression extends TreeNode[Expression] {
-  protected def nullabilitySuffix: String = if (nullable) "?" else ""
 
   /**
    * Returns true when an expression is a candidate for static evaluation before the query is
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index fde7958b1433c..3337996309d4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -517,33 +517,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
   }
 
-  test("analysis time nullability check") {
-    val rowRDD = sqlContext.sparkContext.parallelize(Seq(
-      Row(Row("hello", 1: Integer)),
-      Row(Row("world", null))
-    ))
-
-    val schema = StructType(Seq(
-      StructField("f", StructType(Seq(
-        StructField("a", StringType, nullable = true),
-        StructField("b", IntegerType, nullable = false)
-      )), nullable = false)
-    ))
-
-    val message = intercept[AnalysisException] {
-      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct].collect()
-    }.message
-
-    assert(message.contains(
-      """Dataset nullability doesn't conform to the underlying logical plan:
-        | root                                  root
-        |! |-- f: struct (nullable = false)     |-- f: struct (nullable = true)
-        |  | |-- a: string (nullable = true)     | |-- a: string (nullable = true)
-        |  | |-- b: integer (nullable = false)   | |-- b: integer (nullable = false)
-      """.stripMargin
-    ))
-  }
-
   test("runtime nullability check") {
     val schema = StructType(Seq(
       StructField("f", StructType(Seq(

From bff92d563135d12f160f04d55a28ccaa1ad348aa Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 21 Dec 2015 21:02:39 +0800
Subject: [PATCH 6/9] Supports Java bean

---
 .../sql/catalyst/JavaTypeInference.scala   |   9 +-
 .../apache/spark/sql/JavaDatasetSuite.java | 126 +++++++++++++++++-
 2 files changed, 133 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index f566d1b3caebf..a1500cbc305d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -288,7 +288,14 @@ object JavaTypeInference {
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+          val (_, nullable) = inferDataType(fieldType)
+          val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+          val setter = if (nullable) {
+            constructor
+          } else {
+            AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+          }
+          p.getWriteMethod.getName -> setter
         }.toMap
 
         val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb81c7ec9..9f8db39e33d7e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@
 import java.sql.Timestamp;
 import java.util.*;
 
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -39,7 +41,6 @@
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public void testJavaBeanEncoder2() {
       context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", 1
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", null
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
 }

From b66de757bee4618963b265a6b1219e1470a5c10f Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 22 Dec 2015 00:27:14 +0800
Subject: [PATCH 7/9] Fixes test failures

---
 .../encoders/EncoderResolutionSuite.scala | 21 ++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 815a03f7c1a89..764ffdc0947c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest {
     val encoder = ExpressionEncoder[StringLongClass]
     val cls = classOf[StringLongClass]
 
+    {
       val attrs = Seq('a.string, 'b.int)
       val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
       val expected: Expression = NewInstance(
         cls,
-        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+        Seq(
+          toExternalString('a.string),
+          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -52,7 +56,10 @@
       val fromRowExpr =
        encoder.resolve(attrs, null).fromRowExpression
       val expected = NewInstance(
         cls,
-        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+        Seq(
+          toExternalString('a.int.cast(StringType)),
+          AssertNotNull('b.long, cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest {
     val expected: Expression = NewInstance(
       cls,
       Seq(
-        'a.int.cast(LongType),
+        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
         If(
           'b.struct('a.int, 'b.long).isNull,
           Literal.create(null, ObjectType(innerCls)),
           NewInstance(
             innerCls,
             Seq(
               toExternalString(
                 GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-              GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
+              AssertNotNull(
+                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
+                innerCls.getName, "b", "Long")),
             false,
             ObjectType(innerCls))
       )),
@@ -102,7 +111,9 @@
         cls,
         Seq(
           toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-          GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
+          AssertNotNull(
+            GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
+            cls.getName, "b", "Long")),
         false,
         ObjectType(cls)),
       'b.int.cast(LongType)),

From cf3fe16d569e684862361431f71d01ef4c246617 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 22 Dec 2015 09:15:01 +0800
Subject: [PATCH 8/9] Reverts unrelated changes

---
 .../sql/catalyst/encoders/ExpressionEncoder.scala | 15 +++++++--------
 .../main/scala/org/apache/spark/sql/Dataset.scala | 10 ----------
 2 files changed, 7 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 72c156b2c739f..7a4401cf5810e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -20,18 +20,17 @@ package org.apache.spark.sql.catalyst.encoders
 import java.util.concurrent.ConcurrentMap
 
 import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.{AnalysisException, Encoder}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.sql.{AnalysisException, Encoder}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitives to and from the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d0748b1211f1c..a763a951440cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -84,16 +84,6 @@ class Dataset[T] private[sql](
    */
   private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
-  logTrace(
-    s"""
-       |# unresolvedTEncoder.fromRowExpression
-       |${unresolvedTEncoder.fromRowExpression.treeString}
-       |# resolvedTEncoder.fromRowExpression
-       |${resolvedTEncoder.fromRowExpression.treeString}
-       |# boundTEncoder.fromRowExpression
-       |${boundTEncoder.fromRowExpression.treeString}
-     """.stripMargin)
-
   private implicit def classTag = resolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =

From 759c20d173a5a4f0d508e81a8f803d9bee3e2564 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 22 Dec 2015 09:56:39 +0800
Subject: [PATCH 9/9] Fixes compilation error introduced while rebasing

---
 .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f2544d9e00051..becd019caeca4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection {
       val cls = getClassFromType(tpe)
 
       val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-        val dataType = schemaFor(fieldType).dataType
+        val Schema(dataType, nullable) = schemaFor(fieldType)
         val clsName = getClassNameFromType(fieldType)
         val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
         // For tuples, we based grab the inner fields by ordinal instead of name.
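The runtime half of the fix (PATCH 2/9 onwards) takes a different route: rather than comparing schemas up front, the deserializer expression for every non-nullable case-class or bean field is wrapped in AssertNotNull, whose generated code throws a descriptive RuntimeException as soon as a null is decoded into such a field. A minimal sketch of that guard without Spark's code generation, assuming a toy DecodedField model in place of real child expression output (both names are illustrative stand-ins, not Spark APIs):

// Illustrative stand-in for what a child deserializer expression produces.
final case class DecodedField[A](value: A, isNull: Boolean)

object AssertNotNullSketch {
  // Mirrors the generated guard in AssertNotNull.genCode: nulls in non-nullable
  // fields fail fast with a message naming the enclosing type and field.
  def assertNotNull[A](
      decoded: DecodedField[A],
      parentType: String,
      fieldName: String,
      fieldType: String): A = {
    if (decoded.isNull) {
      throw new RuntimeException(
        s"Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
          "please try to use scala.Option[_] or other nullable types.")
    }
    decoded.value
  }

  def main(args: Array[String]): Unit = {
    // Decoding Row("hello", null) into ClassData(a: String, b: Int) fails on `b`.
    val b = DecodedField[Int](0, isNull = true)
    assertNotNull(b, "org.apache.spark.sql.ClassData", "b", "Int") // throws
  }
}

Doing the check inside the decoded expression, rather than against the schema, is what lets a null parent object still decode cleanly (NestedStruct(null) above): the guard on the inner field only runs when the enclosing struct is non-null.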