[SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate push-down `AVG`

### What changes were proposed in this pull request?
apache#35130 supports partial aggregate push-down of `AVG` for DS V2.
Its behavior is not consistent with `Average` when overflow occurs in ANSI mode.
This PR closely follows the implementation of `Average` so that overflow is respected in ANSI mode.
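
A minimal sketch of the idea (a hypothetical standalone helper; `pushedSum`/`pushedCount` stand in for the SUM/COUNT aggregate expressions built for push-down, mirroring the `V2ScanRelationPushDown` change below):

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Average

// Rebuild AVG's result from the pushed-down SUM and COUNT by reusing
// Average's own result expression, so ANSI overflow handling is inherited
// instead of being re-implemented (the old TODO) in the push-down rule.
def rewriteAvg(avg: Average, pushedSum: Expression, pushedCount: Expression): Expression =
  avg.evaluateExpression transform {
    // Swap Average's aggregation-buffer attributes for the pushed aggregates.
    case a: Attribute if a.semanticEquals(avg.sum) => pushedSum
    case a: Attribute if a.semanticEquals(avg.count) => pushedCount
  }
```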

### Why are the changes needed?
Make the behavior consistent with `Average` when overflow occurs in ANSI mode.

### Does this PR introduce _any_ user-facing change?
Yes. Users can now see the overflow exception thrown in ANSI mode.
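
Concretely, the observable behavior change (condensed from the new test below; assumes the suite's `h2` catalog and its `test.item` table, whose DECIMAL values sum past the result precision):

```scala
import org.apache.spark.sql.functions.{avg, col}

spark.conf.set("spark.sql.ansi.enabled", "true")

// Two input partitions force a *partial* push-down: H2 computes SUM/COUNT
// per partition and Spark evaluates the final AVG, where overflow surfaces.
val df = spark.read
  .option("partitionColumn", "id")
  .option("lowerBound", "0")
  .option("upperBound", "2")
  .option("numPartitions", "2")
  .table("h2.test.item")
  .agg(avg(col("PRICE")))

df.collect()
// ANSI on  => SparkException caused by ArithmeticException (decimal overflow)
// ANSI off => Row(null), matching Average's non-ANSI behavior
```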

### How was this patch tested?
New tests.

Closes apache#35320 from beliefer/SPARK-37839_followup.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and chenzhx committed Apr 18, 2022
1 parent 614cb93 commit 3034070
Showing 3 changed files with 52 additions and 17 deletions.
Average.scala
@@ -76,8 +76,8 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sum = AttributeReference("sum", sumDataType)()
-  private lazy val count = AttributeReference("count", LongType)()
+  lazy val sum = AttributeReference("sum", sumDataType)()
+  lazy val count = AttributeReference("count", LongType)()
 
   override lazy val aggBufferAttributes = sum :: count :: Nil

V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
@@ -129,18 +129,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
         case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
           val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
           val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
-          // Closely follow `Average.evaluateExpression`
-          avg.dataType match {
-            case _: YearMonthIntervalType =>
-              If(EqualTo(count, Literal(0L)),
-                Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
-            case _: DayTimeIntervalType =>
-              If(EqualTo(count, Literal(0L)),
-                Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
-            case _ =>
-              // TODO deal with the overflow issue
-              Divide(addCastIfNeeded(sum, avg.dataType),
-                addCastIfNeeded(count, avg.dataType), false)
+          avg.evaluateExpression transform {
+            case a: Attribute if a.semanticEquals(avg.sum) =>
+              addCastIfNeeded(sum, avg.sum.dataType)
+            case a: Attribute if a.semanticEquals(avg.count) =>
+              addCastIfNeeded(count, avg.count.dataType)
           }
       }.asInstanceOf[Seq[NamedExpression]]
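
Aside: `Average.evaluateExpression` is where `Average` implements its overflow handling, so the substitution above inherits it for free. The `transform`/`semanticEquals` pattern itself behaves like this self-contained toy, with hand-built attributes rather than the planner's actual tree:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DoubleType

val sum = AttributeReference("sum", DoubleType)()
val count = AttributeReference("count", DoubleType)()
val avgExpr: Expression = Divide(sum, count)

// Rewrites exactly the attribute that is semantically equal to `sum`,
// leaving `count` untouched.
val rewritten = avgExpr transform {
  case a: Attribute if a.semanticEquals(sum) => Literal(42.0)
}
// rewritten is Divide(Literal(42.0), count)
```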
JDBCV2Suite.scala
@@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
"""CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate()
conn.prepareStatement(
"""CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate()

conn.prepareStatement(
"CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
"(1, 'bottle', 11111111111111111111.123)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
"(1, 'bottle', 99999999999999999999.123)").executeUpdate()
}
}

@@ -484,7 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
Seq(Row("test", "people", false), Row("test", "empty_table", false),
Row("test", "employee", false)))
Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false)))
}

test("SQL API: create table as select") {
@@ -1105,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
   }
+
+  test("scan with aggregate push-down: partial push-down AVG with overflow") {
+    def createDataFrame: DataFrame = spark.read
+      .option("partitionColumn", "id")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.item")
+      .agg(avg($"PRICE").as("avg"))
+
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+        val df = createDataFrame
+        checkAggregateRemoved(df, false)
+        df.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+            checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        }
+        if (ansiEnabled) {
+          val e = intercept[SparkException] {
+            df.collect()
+          }
+          assert(e.getCause.isInstanceOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+            e.getCause.getMessage.contains("Overflow in sum of decimals"))
+        } else {
+          checkAnswer(df, Seq(Row(null)))
+        }
+      }
+    }
+  }
 }
