[SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder #11035

Closed · wants to merge 1 commit
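
What the title means in code, distilled from the new test this patch adds to EncoderResolutionSuite (shown below): resolving an encoder for Seq[Int] against an array column used to fail analysis, because the column's array elements are nullable while Seq[Int] expects non-nullable Int elements. With this patch, analysis passes and a null element is instead rejected at runtime. A minimal sketch, assuming the spark-catalyst test classpath:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.IntegerType

// 'a.array(IntegerType) yields an attribute whose array elements are
// nullable by default; the Seq[Int] encoder wants non-nullable elements.
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil

// Before this patch, resolve() failed analysis on the nullability
// mismatch; after it, resolution and binding succeed and any null
// element fails at runtime via AssertNotNull instead.
val bound = encoder.resolve(attrs, null).bind(attrs)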
@@ -292,7 +292,7 @@ object JavaTypeInference {
         val setter = if (nullable) {
           constructor
         } else {
-          AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+          AssertNotNull(constructor, Seq("currently no type path record in java"))
         }
         p.getWriteMethod.getName -> setter
       }.toMap

@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+
+        // TODO: add runtime null check for primitive array

Contributor commented:
So this does still silently corrupt values.

scala> Seq(("a", Seq(null, new Integer(1)))).toDS().as[(String, Array[Int])].collect()
res5: Array[(String, Array[Int])] = Array((a,Array(0, 1)))

Since this isn't a regression, it's probably okay as long as we open another JIRA.
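
For intuition, a minimal model (illustrative only; the class and method below are assumptions, not Spark's ArrayData API) of why the null surfaces as 0: the serialized array keeps primitive slots and null bits separately, and a primitive fast path like toIntArray copies the raw slots without consulting the null bits, so a null element's default-initialized slot reads back as 0.

// Hypothetical model, not Spark code.
class ModelIntArrayData(elements: Array[java.lang.Integer]) {
  // A null element leaves its slot at the primitive default (0);
  // nullness is tracked separately.
  private val slots: Array[Int] =
    elements.map(e => if (e == null) 0 else e.intValue())
  private val nullBits: Array[Boolean] = elements.map(_ == null)

  def isNullAt(i: Int): Boolean = nullBits(i)

  // Fast path: copies raw slots, ignoring nullBits -- which is how
  // Array(null, 1) comes back as Array(0, 1) in the collect() above.
  def toIntArray: Array[Int] = slots.clone()
}

val data = new ModelIntArrayData(Array[java.lang.Integer](null, Integer.valueOf(1)))
assert(data.toIntArray.sameElements(Array(0, 1)))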

         val primitiveMethod = elementType match {
           case t if t <:< definitions.IntTpe => Some("toIntArray")
           case t if t <:< definitions.LongTpe => Some("toLongArray")

@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
-        val arrayData =
-          Invoke(
-            MapObjects(
-              p => constructorFor(elementType, Some(p), newTypePath),
-              getPath,
-              schemaFor(elementType).dataType),
-            "array",
-            ObjectType(classOf[Array[Any]]))
 
+        val mapFunction: Expression => Expression = p => {
+          val converter = constructorFor(elementType, Some(p), newTypePath)
+          if (nullable) {
+            converter
+          } else {
+            AssertNotNull(converter, newTypePath)
+          }
+        }
+
+        val array = Invoke(
+          MapObjects(mapFunction, getPath, dataType),
+          "array",
+          ObjectType(classOf[Array[Any]]))
+
         StaticInvoke(
           scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
-          arrayData :: Nil)
+          array :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map

@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
             newTypePath)
 
         if (!nullable) {
-          AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+          AssertNotNull(constructor, newTypePath)
         } else {
           constructor
         }

@@ -1346,7 +1346,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
           fail(child, DateType, walkedTypePath)
         case (StringType, to: NumericType) =>
           fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType)
+        case _ => Cast(child, dataType.asNullable)
       }
     }
   }

@@ -365,7 +365,7 @@ object MapObjects {
  * to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
  */
-case class MapObjects(
+case class MapObjects private(
     loopVar: LambdaVariable,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression {

@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
  * non-null `s`, `s.i` can't be null.
  */
-case class AssertNotNull(
-    child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
   extends UnaryExpression {
 
   override def dataType: DataType = child.dataType

@@ -651,19 +650,22 @@ case class AssertNotNull(
   override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     val childGen = child.gen(ctx)
 
+    val errMsg = "Null value appeared in non-nullable field:" +
+      walkedTypePath.mkString("\n", "\n", "\n") +
+      "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+      "please try to use scala.Option[_] or other nullable types " +
+      "(e.g. java.lang.Integer instead of int/scala.Int)."
+    val idx = ctx.references.length
+    ctx.references += errMsg
+
     ev.isNull = "false"
     ev.value = childGen.value
 
     s"""
       ${childGen.code}
 
       if (${childGen.isNull}) {
-        throw new RuntimeException(
-          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
-          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
-          "please try to use scala.Option[_] or other nullable types " +
-          "(e.g. java.lang.Integer instead of int/scala.Int)."
-        );
+        throw new RuntimeException((String) references[$idx]);
       }
     """
   }

@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class StringLongClass(a: String, b: Long)

@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
 case class ComplexClass(a: Long, b: StringLongClass)
 
 class EncoderResolutionSuite extends PlanTest {
+  private val str = UTF8String.fromString("hello")
 
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
-    val cls = classOf[StringLongClass]

Contributor Author commented:
The previous tests were very verbose and hard to maintain: every time we changed some unrelated analysis behavior, they had to be updated. For these tests we only care that analysis passes and runtime execution succeeds, so I simplified them to check exactly that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1
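
To make the new style concrete, here is a distilled sketch of the pattern used throughout the rewritten tests below (assuming the spark-catalyst test classpath and the StringLongClass defined in this file):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.unsafe.types.UTF8String

val encoder = ExpressionEncoder[StringLongClass]
// Real input schema is (string, int); the class expects (String, Long).
val attrs = Seq('a.string, 'b.int)
// Analysis must not throw despite the int-vs-long mismatch...
val bound = encoder.resolve(attrs, null).bind(attrs)
// ...and execution must succeed, up casting int to long at runtime.
bound.fromRow(InternalRow(UTF8String.fromString("hello"), 1))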

-    {
-      val attrs = Seq('a.string, 'b.int)
-      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-      val expected: Expression = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.string),
-          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to long type
+    val attrs1 = Seq('a.string, 'b.int)
+    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
 
-    {
-      val attrs = Seq('a.int, 'b.long)
-      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
-      val expected = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.int.cast(StringType)),
-          AssertNotNull('b.long, cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to string type
+    val attrs2 = Seq('a.int, 'b.long)
+    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
   }

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val innerCls = classOf[StringLongClass]
val cls = classOf[ComplexClass]

val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
cls,
Seq(
AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
If(
'b.struct('a.int, 'b.long).isNull,
Literal.create(null, ObjectType(innerCls)),
NewInstance(
innerCls,
Seq(
toExternalString(
GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
AssertNotNull(
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
innerCls.getName, "b", "Long")),
ObjectType(innerCls),
propagateNull = false)
)),
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val cls = classOf[StringLongClass]

val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
classOf[Tuple2[_, _]],
Seq(
NewInstance(
cls,
Seq(
toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
AssertNotNull(
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
cls.getName, "b", "Long")),
ObjectType(cls),
propagateNull = false),
'b.int.cast(LongType)),
ObjectType(classOf[Tuple2[_, _]]),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}

test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil

// It should pass analysis
val bound = encoder.resolve(attrs, null).bind(attrs)

// If no null values appear, it should works fine
bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))

// If there is null value, it should throw runtime exception
val e = intercept[RuntimeException] {
bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}

test("the real number of fields doesn't match encoder schema: tuple encoder") {
Expand Down Expand Up @@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
}
}

private def toExternalString(e: Expression): Expression = {
Invoke(e, "toString", ObjectType(classOf[String]), Nil)
}

test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
Expand Down
@@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() {
     }
 
     nullabilityCheck.expect(RuntimeException.class);
-    nullabilityCheck.expectMessage(
-      "Null value appeared in non-nullable field " +
-        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+    nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
 
     {
       Row row = new GenericRow(new Object[] {

13 changes: 5 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -45,13 +45,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
-
   test("SPARK-12404: Datatype Helper Serializablity") {
     val ds = sparkContext.parallelize((
-        new Timestamp(0),
-        new Date(0),
-        java.math.BigDecimal.valueOf(1),
-        scala.math.BigDecimal(1)) :: Nil).toDS()
+      new Timestamp(0),
+      new Date(0),
+      java.math.BigDecimal.valueOf(1),
+      scala.math.BigDecimal(1)) :: Nil).toDS()
 
     ds.collect()
   }

@@ -553,9 +552,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       buildDataset(Row(Row("hello", null))).collect()
     }.getMessage
 
-    assert(message.contains(
-      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
-    ))
+    assert(message.contains("Null value appeared in non-nullable field"))
   }
 
   test("SPARK-12478: top level null field") {