Skip to content

Commit

Permalink
[SPARK-34876][SQL] Fill defaultResult of non-nullable aggregates
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Filled the `defaultResult` field on non-nullable aggregates

### Why are the changes needed?

The `defaultResult` defaults to `None` and in some situations (like correlated scalar subqueries) it is used for the value of the aggregation.

The UT result before the fix:
```
-- !query
SELECT t1a,
   (SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
   (SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
   (SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
   (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
   (SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
    (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
FROM t1
-- !query schema
struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string>
-- !query output
val1a	0	0	NULL	NULL	NULL	NULL
val1a	0	0	NULL	NULL	NULL	NULL
val1a	0	0	NULL	NULL	NULL	NULL
val1a	0	0	NULL	NULL	NULL	NULL
val1b	6	6	3	[19,119,319,19,19,19]	[19,119,319]	0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001
val1c	2	2	2	[219,19]	[219,19]	0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001
val1d	0	0	NULL	NULL	NULL	NULL
val1d	0	0	NULL	NULL	NULL	NULL
val1d	0	0	NULL	NULL	NULL	NULL
val1e	1	1	1	[19]	[19]	0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e	1	1	1	[19]	[19]	0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e	1	1	1	[19]	[19]	0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
```

### Does this PR introduce _any_ user-facing change?

Bugfix

### How was this patch tested?

UT

Closes #31973 from tanelk/SPARK-34876_non_nullable_agg_subquery.

Authored-by: Tanel Kiis <tanel.kiis@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
  • Loading branch information
tanelk authored and HyukjinKwon committed Mar 29, 2021
1 parent 4fceef0 commit 4b9e94c
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 4 deletions.
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
Expand Down Expand Up @@ -142,6 +142,9 @@ case class CountMinSketchAgg(

override def dataType: DataType = BinaryType

override def defaultResult: Option[Literal] =
Option(Literal.create(eval(createAggregationBuffer()), dataType))

override def children: Seq[Expression] =
Seq(child, epsExpression, confidenceExpression, seedExpression)

Expand Down
Expand Up @@ -90,6 +90,8 @@ case class HyperLogLogPlusPlus(

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

override def defaultResult: Option[Literal] = Option(Literal.create(0L, dataType))

val hllppHelper = new HyperLogLogPlusPlusHelper(relativeSD)

/** Allocate enough words to store all registers. */
Expand Down
Expand Up @@ -47,6 +47,8 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
// actual order of input rows.
override lazy val deterministic: Boolean = false

override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))

protected def convertToBufferElement(value: Any): Any

override def update(buffer: T, input: InternalRow): T = {
Expand Down
Expand Up @@ -203,8 +203,7 @@ abstract class AggregateFunction extends Expression {
def inputAggBufferAttributes: Seq[AttributeReference]

/**
* Result of the aggregate function when the input is empty. This is currently only used for the
* proper rewriting of distinct aggregate functions.
* Result of the aggregate function when the input is empty.
*/
def defaultResult: Option[Literal] = None

Expand Down
Expand Up @@ -128,3 +128,13 @@ WHERE NOT EXISTS (SELECT (SELECT max(t2b)
ON t2a = t1a
WHERE t2c = t3c)
AND t3a = t1a);

-- SPARK-34876: Non-nullable aggregates should not return NULL in a correlated subquery
SELECT t1a,
(SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
(SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
(SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
(SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
FROM t1;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 11
-- Number of queries: 12


-- !query
Expand Down Expand Up @@ -196,3 +196,29 @@ val1d NULL
val1e 10
val1e 10
val1e 10


-- !query
SELECT t1a,
(SELECT count(t2d) FROM t2 WHERE t2a = t1a) count_t2,
(SELECT count_if(t2d > 0) FROM t2 WHERE t2a = t1a) count_if_t2,
(SELECT approx_count_distinct(t2d) FROM t2 WHERE t2a = t1a) approx_count_distinct_t2,
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
(SELECT collect_set(t2d) FROM t2 WHERE t2a = t1a) collect_set_t2,
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
FROM t1
-- !query schema
struct<t1a:string,count_t2:bigint,count_if_t2:bigint,approx_count_distinct_t2:bigint,collect_list_t2:array<bigint>,collect_set_t2:array<bigint>,collect_set_t2:string>
-- !query output
val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1a 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1b 6 6 3 [19,119,319,19,19,19] [19,119,319] 0000000100000000000000060000000100000004000000005D8D6AB90000000000000000000000000000000400000000000000010000000000000001
val1c 2 2 2 [219,19] [219,19] 0000000100000000000000020000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000001
val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB90000000000000000000000000000000000000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000

0 comments on commit 4b9e94c

Please sign in to comment.