[SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and PERCENTILE_DISC

### What changes were proposed in this pull request?
This PR translates the aggregate functions `PERCENTILE_CONT` and `PERCENTILE_DISC` for pushdown.

- This PR adds `SortValue[] orderingWithinGroups` to `GeneralAggregateFunc`, so that the DS V2 pushdown framework can compile the `WITHIN GROUP (ORDER BY ...)` clause easily (see the sketch after this list).

- This PR also splits `visitInverseDistributionFunction` out of `visitAggregateFunction`, so that the DS V2 pushdown framework can generate the `WITHIN GROUP (ORDER BY ...)` syntax easily.

- This PR also fixes a bug where `JdbcUtils` could not handle the precision and scale of decimal values returned from JDBC.
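
For illustration only (not part of this commit), here is a minimal sketch of how a caller might build the new expression, assuming Spark's DS V2 connector expression API (`LiteralValue`, `FieldReference`, `SortValue`) and the four-argument `GeneralAggregateFunc` constructor added by this PR. The expected string in the comment follows the `WITHIN GROUP` syntax generated by `V2ExpressionSQLBuilder`:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.types.DoubleType

object WithinGroupSketch {
  def main(args: Array[String]): Unit = {
    // The percentile point (0.5) is the function input; the ordering
    // column goes into the new orderingWithinGroups argument.
    val percentage: Expression = LiteralValue(0.5, DoubleType)
    val ordering = SortValue(
      FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)

    // New four-argument constructor added by this PR.
    val percentileCont = new GeneralAggregateFunc(
      "PERCENTILE_CONT", false, Array(percentage), Array(ordering))

    // toString delegates to the SQL builder, so this should print:
    //   PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)
    println(percentileCont)
  }
}
```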

### Why are the changes needed?
This enables DS V2 to push `PERCENTILE_CONT` and `PERCENTILE_DISC` down to data sources.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New test cases.

### Was this patch authored or co-authored using generative AI tooling?
'No'.

Closes #44397 from beliefer/SPARK-46442.

Lead-authored-by: Jiaan Geng <beliefer@163.com>
Co-authored-by: beliefer <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and cloud-fan committed Jan 10, 2024
1 parent d2f5724 commit 85b504d
Showing 6 changed files with 132 additions and 24 deletions.
@@ -21,6 +21,7 @@

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.SortValue;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;

/**
@@ -41,7 +42,9 @@
* <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
* <li><pre>MODE() WITHIN GROUP (ORDER BY input1 [ASC|DESC])</pre> Since 4.0.0</li>
* <li><pre>PERCENTILE_CONT(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
* <li><pre>PERCENTILE_DISC(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
* </ol>
*
* @since 3.3.0
@@ -51,11 +54,21 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
private final String name;
private final boolean isDistinct;
private final Expression[] children;
private final SortValue[] orderingWithinGroups;

public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
this.name = name;
this.isDistinct = isDistinct;
this.children = children;
this.orderingWithinGroups = new SortValue[]{};
}

public GeneralAggregateFunc(
String name, boolean isDistinct, Expression[] children, SortValue[] orderingWithinGroups) {
this.name = name;
this.isDistinct = isDistinct;
this.children = children;
this.orderingWithinGroups = orderingWithinGroups;
}

public String name() { return name; }
@@ -64,6 +77,8 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr
@Override
public Expression[] children() { return children; }

public SortValue[] orderingWithinGroups() { return orderingWithinGroups; }

@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -73,14 +88,16 @@ public boolean equals(Object o) {

if (isDistinct != that.isDistinct) return false;
if (!name.equals(that.name)) return false;
return Arrays.equals(children, that.children);
if (!Arrays.equals(children, that.children)) return false;
return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
}

@Override
public int hashCode() {
int result = name.hashCode();
result = 31 * result + (isDistinct ? 1 : 0);
result = 31 * result + Arrays.hashCode(children);
result = 31 * result + Arrays.hashCode(orderingWithinGroups);
return result;
}
}
@@ -146,8 +146,16 @@ yield visitBinaryArithmetic(
return visitAggregateFunction("AVG", avg.isDistinct(),
expressionsToStringArray(avg.children()));
} else if (expr instanceof GeneralAggregateFunc f) {
return visitAggregateFunction(f.name(), f.isDistinct(),
expressionsToStringArray(f.children()));
if (f.orderingWithinGroups().length == 0) {
return visitAggregateFunction(f.name(), f.isDistinct(),
expressionsToStringArray(f.children()));
} else {
return visitInverseDistributionFunction(
f.name(),
f.isDistinct(),
expressionsToStringArray(f.children()),
expressionsToStringArray(f.orderingWithinGroups()));
}
} else if (expr instanceof UserDefinedScalarFunc f) {
return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
expressionsToStringArray(f.children()));
@@ -273,6 +281,15 @@ protected String visitAggregateFunction(
}
}

protected String visitInverseDistributionFunction(
String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
assert(isDistinct == false);
String withinGroup =
joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
return functionCall + " " + withinGroup;
}

protected String visitUserDefinedScalarFunction(
String funcName, String canonicalName, String[] inputs) {
throw new SparkUnsupportedOperationException(
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
@@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
// Translate Mode if it is deterministic or reverse is defined.
case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
Some(new GeneralAggregateFunc("MODE", isDistinct,
Array(expr, LiteralValue(reverse, BooleanType))))
Some(new GeneralAggregateFunc(
"MODE", isDistinct, Array.empty, Array(generateSortValue(expr, !reverse))))
case aggregate.Percentile(
PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, _, reverse) =>
Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
Array(right), Array(generateSortValue(left, reverse))))
case aggregate.PercentileDisc(
PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
Array(right), Array(generateSortValue(left, reverse))))
// TODO supports other aggregate functions
case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
val translatedExprs = children.flatMap(PushableExpression.unapply(_))
@@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
None
}
}

private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = if (reverse) {
SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
} else {
SortValue(expr, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
}
}

object ColumnOrField {
15 changes: 2 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -43,7 +43,7 @@ private[sql] object H2Dialect extends JdbcDialect {

private val distinctUnsupportedAggregateFunctions =
Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY",
"MODE")
"MODE", "PERCENTILE_CONT", "PERCENTILE_DISC")

private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
@@ -271,18 +271,7 @@ private[sql] object H2Dialect extends JdbcDialect {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT")
} else {
funcName match {
case "MODE" =>
// Support Mode only if it is deterministic or reverse is defined.
assert(inputs.length == 2)
if (inputs.last == "true") {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})"
} else {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)"
}
case _ =>
super.visitAggregateFunction(funcName, isDistinct, inputs)
}
super.visitAggregateFunction(funcName, isDistinct, inputs)
}

override def visitExtract(field: String, source: String): String = {
@@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with Logging {
super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs)
} else {
throw new UnsupportedOperationException(
s"${this.getClass.getSimpleName} does not support aggregate function: $funcName");
s"${this.getClass.getSimpleName} does not support aggregate function: $funcName")
}
}

override def visitInverseDistributionFunction(
funcName: String,
isDistinct: Boolean,
inputs: Array[String],
orderingWithinGroups: Array[String]): String = {
if (isSupportedFunction(funcName)) {
super.visitInverseDistributionFunction(
dialectFunctionName(funcName), isDistinct, inputs, orderingWithinGroups)
} else {
throw new UnsupportedOperationException(
s"${this.getClass.getSimpleName} does not support " +
s"inverse distribution function: $funcName")
}
}

@@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df1)
checkPushedInfo(df1,
"""
|PushedAggregates: [MODE(SALARY, true)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
@@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df3)
checkPushedInfo(df3,
"""
|PushedAggregates: [MODE(SALARY, true)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
@@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df4)
checkPushedInfo(df4,
"""
|PushedAggregates: [MODE(SALARY, false)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
}

test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with filter and group by") {
val df1 = sql(
"""
|SELECT
| dept,
| PERCENTILE(salary, 0.5)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df1)
checkAggregateRemoved(df1)
checkPushedInfo(df1,
"""
|PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))

val df2 = sql(
"""
|SELECT
| dept,
| PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
| PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df2)
checkAggregateRemoved(df2)
checkPushedInfo(df2,
"""
|PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
|PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df2,
Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 12000.0)))

val df3 = sql(
"""
|SELECT
| dept,
| PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
| PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df3)
checkAggregateRemoved(df3)
checkPushedInfo(df3,
"""
|PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
|PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df3,
Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 12000.0)))
}

test("scan with aggregate push-down: aggregate over alias push down") {
val cols = Seq("a", "b", "c", "d", "e")
val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)
