diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index faee9ce56a0a1..a42fa9ba6bc85 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -110,6 +110,17 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( this } + /** + * Check if a key exists at the provided position using object equality rather than + * cooperative equality. Otherwise, hash sets will mishandle values for which `==` + * and `equals` return different results, like 0.0/-0.0 and NaN/NaN. + * + * See: https://issues.apache.org/jira/browse/SPARK-45599 + */ + @annotation.nowarn("cat=other-non-cooperative-equals") + private def keyExistsAtPos(k: T, pos: Int) = + _data(pos) equals k + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. @@ -130,8 +141,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( _bitset.set(pos) _size += 1 return pos | NONEXISTENCE_MASK - } else if (_data(pos) == k) { - // Found an existing key. + } else if (keyExistsAtPos(k, pos)) { return pos } else { // quadratic probing with values increase by 1, 2, 3, ... @@ -165,7 +175,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( while (true) { if (!_bitset.get(pos)) { return INVALID_POS - } else if (k == _data(pos)) { + } else if (keyExistsAtPos(k, pos)) { return pos } else { // quadratic probing with values increase by 1, 2, 3, ... diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 155c855c8723e..ae9fb54bddfbe 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -249,4 +249,34 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { map(null) = null assert(map.get(null) === Some(null)) } + + test("SPARK-45599: 0.0 and -0.0 should count distinctly; NaNs should count together") { + // Exactly these elements provided in roughly this order trigger a condition where lookups of + // 0.0 and -0.0 in the bitset happen to collide, causing their counts to be merged incorrectly + // and inconsistently if `==` is used to check for key equality. + val spark45599Repro = Seq( + Double.NaN, + 2.0, + 168.0, + Double.NaN, + Double.NaN, + -0.0, + 153.0, + 0.0 + ) + + val map1 = new OpenHashMap[Double, Int]() + spark45599Repro.foreach(map1.changeValue(_, 1, {_ + 1})) + assert(map1(0.0) == 1) + assert(map1(-0.0) == 1) + assert(map1(Double.NaN) == 3) + + val map2 = new OpenHashMap[Double, Int]() + // Simply changing the order in which the elements are added to the map should not change the + // counts for 0.0 and -0.0. + spark45599Repro.reverse.foreach(map2.changeValue(_, 1, {_ + 1})) + assert(map2(0.0) == 1) + assert(map2(-0.0) == 1) + assert(map2(Double.NaN) == 3) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 89a308556d5df..0bc8aa067f57a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -269,4 +269,43 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(pos1 == pos2) } } + + test("SPARK-45599: 0.0 and -0.0 are equal but not the same") { + // Therefore, 0.0 and -0.0 should get separate entries in the hash set. + // + // Exactly these elements provided in roughly this order will trigger the following scenario: + // When probing the bitset in `getPos(-0.0)`, the loop will happen upon the entry for 0.0. + // In the old logic pre-SPARK-45599, the loop will find that the bit is set and, because + // -0.0 == 0.0, it will think that's the position of -0.0. But in reality this is the position + // of 0.0. So -0.0 and 0.0 will be stored at different positions, but `getPos()` will return + // the same position for them. This can cause users of OpenHashSet, like OpenHashMap, to + // return the wrong value for a key based on whether or not this bitset lookup collision + // happens. + val spark45599Repro = Seq( + Double.NaN, + 2.0, + 168.0, + Double.NaN, + Double.NaN, + -0.0, + 153.0, + 0.0 + ) + val set = new OpenHashSet[Double]() + spark45599Repro.foreach(set.add) + assert(set.size == 6) + val zeroPos = set.getPos(0.0) + val negZeroPos = set.getPos(-0.0) + assert(zeroPos != negZeroPos) + } + + test("SPARK-45599: NaN and NaN are the same but not equal") { + // Any mathematical comparison to NaN will return false, but when we place it in + // a hash set we want the lookup to work like a "normal" value. + val set = new OpenHashSet[Double]() + set.add(Double.NaN) + set.add(Double.NaN) + assert(set.contains(Double.NaN)) + assert(set.size == 1) + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out index f4fd42d6adea3..727025d6fc83b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out @@ -747,3 +747,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) -- !query analysis Project [array_prepend(array(cast(null as string)), cast(null as string)) AS array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING))#x] +- OneRowRelation + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query analysis +Project [array_union(array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double)), array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double))) AS array_union(array(0.0, 0.0, NaN), array(0.0, 0.0, NaN))#x] ++- OneRowRelation + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query analysis +Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out index 83d0ff3f2edf7..fece926834b7c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out @@ -699,3 +699,10 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query analysis +Project [0 AS 0#x, 0.0 AS 0.0#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out index c26bb210b0fff..216d978546b26 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out @@ -747,3 +747,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) -- !query analysis Project [array_prepend(array(cast(null as string)), cast(null as string)) AS array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING))#x] +- OneRowRelation + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query analysis +Project [array_union(array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double)), array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double))) AS array_union(array(0.0, 0.0, NaN), array(0.0, 0.0, NaN))#x] ++- OneRowRelation + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query analysis +Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 88517449760d9..f14c87fc6ae3e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1171,3 +1171,22 @@ Aggregate [c#x], [(c#x * 2) AS d#x] +- Project [if ((a#x < 0)) 0 else a#x AS b#x] +- SubqueryAlias t1 +- LocalRelation [a#x] + + +-- !query +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +-- !query analysis +Sort [col1#x ASC NULLS FIRST], true ++- Aggregate [col1#x], [col1#x, count(1) AS cnt#xL] + +- LocalRelation [col1#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out index 83d0ff3f2edf7..fece926834b7c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out @@ -699,3 +699,10 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query analysis +Project [0 AS 0#x, 0.0 AS 0.0#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 52a0906ea7392..865dc8bac4ea5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -177,3 +177,7 @@ select array_prepend(CAST(null AS ARRAY), CAST(null as String)); select array_prepend(array(), 1); select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- SPARK-45599: Confirm 0.0, -0.0, and NaN are handled appropriately. +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))); +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index ea1e2f323151a..7d6116ac1e3f1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -259,3 +259,18 @@ FROM ( GROUP BY b ) t3 GROUP BY c; + +-- SPARK-45599: Check that "weird" doubles group and sort as desired. +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 9f0eefc16a8cd..e1e4a370bffdc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -118,3 +118,6 @@ select +X'1'; select -date '1999-01-01'; select -timestamp '1999-01-01'; select -x'2379ACFe'; + +-- normalize -0 and -0.0 +select -0, -0.0; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 49e18411ffa37..6a07d659e39b5 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -907,3 +907,19 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index 6e2c8a65206ef..9c3a0cce023b7 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -777,3 +777,11 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query schema +struct<0:int,0.0:decimal(1,1)> +-- !query output +0 0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index e568f5fa7796d..d33fc62f0d9a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -788,3 +788,19 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index e9addb9631536..0735752947222 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1102,3 +1102,25 @@ struct -- !query output 0 2 + + +-- !query +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +-- !query schema +struct +-- !query output +-Infinity 2 +0.0 2 +Infinity 2 +NaN 2 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 6e2c8a65206ef..9c3a0cce023b7 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -777,3 +777,11 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query schema +struct<0:int,0.0:decimal(1,1)> +-- !query output +0 0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5c7cf874bd793..ab665d8943b33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1089,6 +1089,39 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-45599: Neither 0.0 nor -0.0 should be dropped when computing percentile") { + // To reproduce the bug described in SPARK-45599, we need exactly these rows in roughly + // this order in a DataFrame with exactly 1 partition. + // scalastyle:off line.size.limit + // See: https://issues.apache.org/jira/browse/SPARK-45599?focusedCommentId=17806954&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17806954 + // scalastyle:on line.size.limit + val spark45599Repro: DataFrame = Seq( + 0.0, + 2.0, + 153.0, + 168.0, + 3252411229536261.0, + 7.205759403792794e+16, + 1.7976931348623157e+308, + 0.25, + Double.NaN, + Double.NaN, + -0.0, + -128.0, + Double.NaN, + Double.NaN + ).toDF("val").coalesce(1) + + checkAnswer( + spark45599Repro.agg( + percentile(col("val"), lit(0.1)) + ), + // With the buggy implementation of OpenHashSet, this returns `0.050000000000000044` + // instead of `-0.0`. + List(Row(-0.0)) + ) + } + test("any_value") { checkAnswer( courseSales.groupBy("course").agg(