diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 05f7edaeb5d48..533f7f20b2530 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -76,8 +76,8 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sum = AttributeReference("sum", sumDataType)()
-  private lazy val count = AttributeReference("count", LongType)()
+  lazy val sum = AttributeReference("sum", sumDataType)()
+  lazy val count = AttributeReference("count", LongType)()
 
   override lazy val aggBufferAttributes = sum :: count :: Nil
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 24e3a6c91b13d..cdcae15ef4e24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
@@ -129,18 +129,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
             val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
             val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
-            // Closely follow `Average.evaluateExpression`
-            avg.dataType match {
-              case _: YearMonthIntervalType =>
-                If(EqualTo(count, Literal(0L)),
-                  Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
-              case _: DayTimeIntervalType =>
-                If(EqualTo(count, Literal(0L)),
-                  Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
-              case _ =>
-                // TODO deal with the overflow issue
-                Divide(addCastIfNeeded(sum, avg.dataType),
-                  addCastIfNeeded(count, avg.dataType), false)
+            avg.evaluateExpression transform {
+              case a: Attribute if a.semanticEquals(avg.sum) =>
+                addCastIfNeeded(sum, avg.sum.dataType)
+              case a: Attribute if a.semanticEquals(avg.count) =>
+                addCastIfNeeded(count, avg.count.dataType)
             }
         }.asInstanceOf[Seq[NamedExpression]]
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 8553774055665..67a02904660c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate()
       conn.prepareStatement(
         """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate()
+
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))")
+        .executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 11111111111111111111.123)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
     }
   }
 
@@ -484,7 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"), Seq(Row("test", "people", false),
       Row("test", "empty_table", false),
-      Row("test", "employee", false)))
+      Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
+      Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false)))
   }
 
   test("SQL API: create table as select") {
@@ -1105,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
   }
+
+  test("scan with aggregate push-down: partial push-down AVG with overflow") {
+    def createDataFrame: DataFrame = spark.read
+      .option("partitionColumn", "id")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.item")
+      .agg(avg($"PRICE").as("avg"))
+
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+        val df = createDataFrame
+        checkAggregateRemoved(df, false)
+        df.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+            checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        }
+        if (ansiEnabled) {
+          val e = intercept[SparkException] {
+            df.collect()
+          }
+          assert(e.getCause.isInstanceOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+            e.getCause.getMessage.contains("Overflow in sum of decimals"))
+        } else {
+          checkAnswer(df, Seq(Row(null)))
+        }
+      }
+    }
+  }
 }
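
Reviewer note, not part of the patch: the rewrite in V2ScanRelationPushDown leans on Average.evaluateExpression, the final projection a DeclarativeAggregate computes over its buffer attributes. That expression already encodes the type-specific endgame (interval division guarded against count == 0, decimal handling, plain division otherwise), so the rule no longer has to mirror it by hand; it only swaps the buffer slots avg.sum and avg.count for the pushed-down SUM and COUNT, which is also why the first hunk drops `private` from those two attributes. Below is a minimal, self-contained sketch of that substitution. It is illustrative only: it uses a plain Cast where the rule calls its private addCastIfNeeded helper, invents a `price` column, and assumes a Spark version where Average, Sum, and Count can each be constructed from a single child expression.

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, Sum}
import org.apache.spark.sql.types.DoubleType

// Sketch of the buffer-attribute substitution performed by the rule above.
object AvgPushDownSketch extends App {
  // Hypothetical column standing in for avg.child.
  val price = AttributeReference("price", DoubleType)()
  val avg = Average(price)

  // The final projection Average evaluates over its (sum, count) buffer;
  // for doubles this is roughly Divide(sum, count) with casts to the result type.
  val finalProjection = avg.evaluateExpression

  // The aggregates the data source will actually compute and hand back.
  val pushedSum = Sum(price).toAggregateExpression()
  val pushedCount = Count(price).toAggregateExpression()

  // Same shape as the rule: replace each buffer attribute with the pushed-down
  // aggregate, cast back to the buffer attribute's type. (The rule's
  // addCastIfNeeded presumably skips the cast when the types already match.)
  val rewritten = finalProjection transform {
    case a: Attribute if a.semanticEquals(avg.sum) => Cast(pushedSum, avg.sum.dataType)
    case a: Attribute if a.semanticEquals(avg.count) => Cast(pushedCount, avg.count.dataType)
  }

  // Prints the rewritten division over sum(price) and count(price).
  println(rewritten.sql)
}

The new JDBCV2Suite case then pins down the behavioral consequence: the two NUMERIC(23, 3) rows in `item` sum past the decimal's precision, so with spark.sql.ansi.enabled=true the partially pushed-down AVG surfaces an ArithmeticException, while with ANSI off it degrades to a NULL result, consistent with Spark's usual decimal-overflow semantics. That resolves the "TODO deal with the overflow issue" the removed branch carried.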