Skip to content

Commit

Permalink
[SPARK-32906][SQL] Struct field names should not change after normali…
Browse files Browse the repository at this point in the history
…zing floats

### What changes were proposed in this pull request?

This PR intends to fix a minor bug when normalizing floats for struct types;
```
scala> import org.apache.spark.sql.execution.aggregate.HashAggregateExec
scala> val df = Seq(Tuple1(Tuple1(-0.0d)), Tuple1(Tuple1(0.0d))).toDF("k")
scala> val agg = df.distinct()
scala> agg.explain()
== Physical Plan ==
*(2) HashAggregate(keys=[k#40], functions=[])
+- Exchange hashpartitioning(k#40, 200), true, [id=#62]
   +- *(1) HashAggregate(keys=[knownfloatingpointnormalized(if (isnull(k#40)) null else named_struct(col1, knownfloatingpointnormalized(normalizenanandzero(k#40._1)))) AS k#40], functions=[])
      +- *(1) LocalTableScan [k#40]

scala> val aggOutput = agg.queryExecution.sparkPlan.collect { case a: HashAggregateExec => a.output.head }
scala> aggOutput.foreach { attr => println(attr.prettyJson) }
### Final Aggregate ###
[ {
  "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
  "num-children" : 0,
  "name" : "k",
  "dataType" : {
    "type" : "struct",
    "fields" : [ {
      "name" : "_1",
                ^^^
      "type" : "double",
      "nullable" : false,
      "metadata" : { }
    } ]
  },
  "nullable" : true,
  "metadata" : { },
  "exprId" : {
    "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId",
    "id" : 40,
    "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366"
  },
  "qualifier" : [ ]
} ]

### Partial Aggregate ###
[ {
  "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
  "num-children" : 0,
  "name" : "k",
  "dataType" : {
    "type" : "struct",
    "fields" : [ {
      "name" : "col1",
                ^^^^
      "type" : "double",
      "nullable" : true,
      "metadata" : { }
    } ]
  },
  "nullable" : true,
  "metadata" : { },
  "exprId" : {
    "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId",
    "id" : 40,
    "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366"
  },
  "qualifier" : [ ]
} ]
```

### Why are the changes needed?

bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests.

Closes #29780 from maropu/FixBugInNormalizedFloatingNumbers.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit b49aaa3)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
maropu authored and viirya committed Sep 18, 2020
1 parent 5581a92 commit 2d55de5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))

case _ if expr.dataType.isInstanceOf[StructType] =>
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
normalize(GetStructField(expr, i))
val fields = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
}
val struct = CreateStruct(fields)
val struct = CreateNamedStruct(fields.flatten.toSeq)
KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))

case _ if expr.dataType.isInstanceOf[ArrayType] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,14 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(sql(queryTemplate("FIRST")), Row(1))
checkAnswer(sql(queryTemplate("LAST")), Row(3))
}

test("SPARK-32906: struct field names should not change after normalizing floats") {
val df = Seq(Tuple1(Tuple2(-0.0d, Double.NaN)), Tuple1(Tuple2(0.0d, Double.NaN))).toDF("k")
val aggs = df.distinct().queryExecution.sparkPlan.collect { case a: HashAggregateExec => a }
assert(aggs.length == 2)
assert(aggs.head.output.map(_.dataType.simpleString).head ===
aggs.last.output.map(_.dataType.simpleString).head)
}
}

case class B(c: Option[Double])
Expand Down

0 comments on commit 2d55de5

Please sign in to comment.