Skip to content

Commit

Permalink
NormalizeFloatingNumbers should work on null struct.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 1, 2020
1 parent 20cd47e commit 41a318e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
Expand Down Expand Up @@ -123,7 +123,8 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
normalize(GetStructField(expr, i))
}
CreateStruct(fields)
val struct = CreateStruct(fields)
If(IsNull(expr), Literal(null, struct.dataType), struct)

case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
Expand Down
Expand Up @@ -1028,4 +1028,16 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil)
}
}

test("SPARK-32136: NormalizeFloatingNumbers should work on null struct") {
val df = Seq(
A(None),
A(Some(B(None))),
A(Some(B(Some(1.0))))).toDF
val groupBy = df.groupBy("b").agg(count("*"))
checkAnswer(groupBy, Row(null, 1) :: Row(Row(null), 1) :: Row(Row(1.0), 1) :: Nil)
}
}

case class B(c: Option[Double])
case class A(b: Option[B])

0 comments on commit 41a318e

Please sign in to comment.