From 1bf4f32924df6a3a06623dfcd2e06b6749c6ebad Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 17 Sep 2020 13:20:36 +0900
Subject: [PATCH] [SPARK-32906] Struct field names should not change after
 normalizing floats

The struct case in NormalizeFloatingNumbers.normalize rebuilt normalized
structs with CreateStruct, which does not preserve the original field names.
Build the result with CreateNamedStruct and the original field names instead,
and add a regression test checking that the partial and final
HashAggregateExec nodes of a distinct aggregation keep the same output type.
---
 .../sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 6 +++---
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala    | 8 ++++++++
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 10f846cf910f9..bfc36ec477a73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -129,10 +129,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       Coalesce(children.map(normalize))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
-      val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
-        normalize(GetStructField(expr, i))
+      val fields = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+        case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
       }
-      val struct = CreateStruct(fields)
+      val struct = CreateNamedStruct(fields.flatten.toSeq)
       KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
 
     case _ if expr.dataType.isInstanceOf[ArrayType] =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e954e2bf1c46d..353444b664412 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1043,6 +1043,14 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
+
+  test("SPARK-32906: struct field names should not change after normalizing floats") {
+    val df = Seq(Tuple1(Tuple2(-0.0d, Double.NaN)), Tuple1(Tuple2(0.0d, Double.NaN))).toDF("k")
+    val aggs = df.distinct().queryExecution.sparkPlan.collect { case a: HashAggregateExec => a }
+    assert(aggs.length == 2)
+    assert(aggs.head.output.map(_.dataType.simpleString).head ===
+      aggs.last.output.map(_.dataType.simpleString).head)
+  }
 }
 
 case class B(c: Option[Double])