Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and PERCENTILE_DISC #44397

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -21,6 +21,7 @@

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.SortValue;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;

/**
Expand All @@ -41,7 +42,9 @@
* <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
* <li><pre>MODE() WITHIN GROUP (ORDER BY input1 [ASC|DESC])</pre> Since 4.0.0</li>
* <li><pre>PERCENTILE_CONT(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
* <li><pre>PERCENTILE_DISC(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
* </ol>
*
* @since 3.3.0
Expand All @@ -51,11 +54,21 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
private final String name;
private final boolean isDistinct;
private final Expression[] children;
private final SortValue[] orderingWithinGroups;

/**
 * Creates an aggregate function that has no ordering within groups.
 *
 * @param name the aggregate function name
 * @param isDistinct whether the aggregate is applied with DISTINCT
 * @param children the input expressions of the function
 */
public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
  // Reuse the full constructor; "no ordering" is represented by an empty array.
  this(name, isDistinct, children, new SortValue[]{});
}

/**
 * Creates an aggregate function together with the ordering applied within each group.
 *
 * @param name the aggregate function name
 * @param isDistinct whether the aggregate is applied with DISTINCT
 * @param children the input expressions of the function
 * @param orderingWithinGroups the sort values to order by within each group
 */
public GeneralAggregateFunc(
    String name,
    boolean isDistinct,
    Expression[] children,
    SortValue[] orderingWithinGroups) {
  this.orderingWithinGroups = orderingWithinGroups;
  this.children = children;
  this.isDistinct = isDistinct;
  this.name = name;
}

/** Returns the name given to this aggregate function at construction. */
public String name() {
  return name;
}
Expand All @@ -64,6 +77,8 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr
/** Returns the input expressions given to this aggregate function at construction. */
@Override
public Expression[] children() {
  return children;
}

/** Returns the sort values this aggregate function orders by within each group. */
public SortValue[] orderingWithinGroups() {
  return orderingWithinGroups;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand All @@ -73,14 +88,16 @@ public boolean equals(Object o) {

if (isDistinct != that.isDistinct) return false;
if (!name.equals(that.name)) return false;
return Arrays.equals(children, that.children);
if (!Arrays.equals(children, that.children)) return false;
return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
}

/**
 * Combines the name, the distinct flag, the children and the ordering within groups
 * using the usual multiply-by-31 scheme.
 */
@Override
public int hashCode() {
  int distinctBit = isDistinct ? 1 : 0;
  int hash = 31 * name.hashCode() + distinctBit;
  hash = 31 * hash + Arrays.hashCode(children);
  return 31 * hash + Arrays.hashCode(orderingWithinGroups);
}
}
Expand Up @@ -144,8 +144,16 @@ yield visitBinaryArithmetic(
return visitAggregateFunction("AVG", avg.isDistinct(),
expressionsToStringArray(avg.children()));
} else if (expr instanceof GeneralAggregateFunc f) {
return visitAggregateFunction(f.name(), f.isDistinct(),
expressionsToStringArray(f.children()));
if (f.orderingWithinGroups().length == 0) {
return visitAggregateFunction(f.name(), f.isDistinct(),
expressionsToStringArray(f.children()));
} else {
return visitInverseDistributionFunction(
f.name(),
f.isDistinct(),
expressionsToStringArray(f.children()),
expressionsToStringArray(f.orderingWithinGroups()));
}
} else if (expr instanceof UserDefinedScalarFunc f) {
return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
expressionsToStringArray(f.children()));
Expand Down Expand Up @@ -271,6 +279,15 @@ protected String visitAggregateFunction(
}
}

protected String visitInverseDistributionFunction(
String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
assert(isDistinct == false);
String withinGroup =
joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we translate ASC/DESC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to visitSortOrder.

String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
return functionCall + " " + withinGroup;
}

protected String visitUserDefinedScalarFunction(
String funcName, String canonicalName, String[] inputs) {
throw new UnsupportedOperationException(
Expand Down
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
Expand Down Expand Up @@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
// Translate Mode if it is deterministic or reverse is defined.
case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
Some(new GeneralAggregateFunc("MODE", isDistinct,
Array(expr, LiteralValue(reverse, BooleanType))))
Some(new GeneralAggregateFunc(
"MODE", isDistinct, Array.empty, Array(generateSortValue(expr, !reverse))))
case aggregate.Percentile(
PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, _, reverse) =>
Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
Array(right), Array(generateSortValue(left, reverse))))
case aggregate.PercentileDisc(
PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
Array(right), Array(generateSortValue(left, reverse))))
// TODO supports other aggregate functions
case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
val translatedExprs = children.flatMap(PushableExpression.unapply(_))
Expand Down Expand Up @@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
None
}
}

/**
 * Builds the sort value used as the ordering within groups of an aggregate.
 * When `reverse` is true the result is DESCENDING with NULLS_LAST; otherwise it is
 * ASCENDING with NULLS_FIRST.
 */
private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = {
  val (direction, nullOrdering) =
    if (reverse) {
      (SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
    } else {
      (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
    }
  SortValue(expr, direction, nullOrdering)
}
}

object ColumnOrField {
Expand Down
15 changes: 2 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
Expand Up @@ -42,7 +42,7 @@ private[sql] object H2Dialect extends JdbcDialect {

private val distinctUnsupportedAggregateFunctions =
Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY",
"MODE")
"MODE", "PERCENTILE_CONT", "PERCENTILE_DISC")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this change? The H2 dialect deals with these two functions in visitInverseDistributionFunction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the H2 dialect overrides visitInverseDistributionFunction and checks the function name with isSupportedFunction.

     // Renders an inverse distribution function only when isSupportedFunction accepts its name;
     // otherwise fails with UnsupportedOperationException.
     override def visitInverseDistributionFunction(
         funcName: String,
         isDistinct: Boolean,
         inputs: Array[String],
         orderingWithinGroups: Array[String]): String = {
       if (isSupportedFunction(funcName)) {
         // The name is passed through dialectFunctionName before delegating to the parent builder.
         super.visitInverseDistributionFunction(
           dialectFunctionName(funcName), isDistinct, inputs, orderingWithinGroups)
       } else {
         throw new UnsupportedOperationException(
           s"${this.getClass.getSimpleName} does not support " +
             s"inverse distribution function: $funcName")
       }
     }


// Names of every aggregate function this dialect accepts: the basic aggregates listed here
// plus all members of distinctUnsupportedAggregateFunctions.
private val supportedAggregateFunctions =
  Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
    .union(distinctUnsupportedAggregateFunctions)
Expand Down Expand Up @@ -270,18 +270,7 @@ private[sql] object H2Dialect extends JdbcDialect {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT")
} else {
funcName match {
case "MODE" =>
// Support Mode only if it is deterministic or reverse is defined.
assert(inputs.length == 2)
if (inputs.last == "true") {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})"
} else {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)"
}
case _ =>
super.visitAggregateFunction(funcName, isDistinct, inputs)
}
super.visitAggregateFunction(funcName, isDistinct, inputs)
}

override def visitExtract(field: String, source: String): String = {
Expand Down
Expand Up @@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with Logging {
super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs)
} else {
throw new UnsupportedOperationException(
s"${this.getClass.getSimpleName} does not support aggregate function: $funcName");
s"${this.getClass.getSimpleName} does not support aggregate function: $funcName")
}
}

/**
 * Renders an inverse distribution function (a call followed by an ordering within groups).
 *
 * The function name must be accepted by `isSupportedFunction`; it is then translated with
 * `dialectFunctionName` before the parent builder produces the SQL text.
 *
 * @throws UnsupportedOperationException if `funcName` is not supported by this dialect
 */
override def visitInverseDistributionFunction(
    funcName: String,
    isDistinct: Boolean,
    inputs: Array[String],
    orderingWithinGroups: Array[String]): String = {
  // Guard clause: reject unsupported functions before doing any translation.
  if (!isSupportedFunction(funcName)) {
    throw new UnsupportedOperationException(
      s"${this.getClass.getSimpleName} does not support " +
        s"inverse distribution function: $funcName")
  }
  val translatedName = dialectFunctionName(funcName)
  super.visitInverseDistributionFunction(
    translatedName, isDistinct, inputs, orderingWithinGroups)
}

Expand Down
Expand Up @@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df1)
checkPushedInfo(df1,
"""
|PushedAggregates: [MODE(SALARY, true)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
Expand Down Expand Up @@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df3)
checkPushedInfo(df3,
"""
|PushedAggregates: [MODE(SALARY, true)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
Expand All @@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df4)
checkPushedInfo(df4,
"""
|PushedAggregates: [MODE(SALARY, false)],
|PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
}

// Checks the pushed-down plan info and the query results for PERCENTILE, PERCENTILE_CONT and
// PERCENTILE_DISC over h2.test.employee, combined with a filter (dept > 0) and a group by.
// Every query groups by the mixed-case name "DePt"; the expected pushed group-by is DEPT.
// NOTE(review): the test name mentions only PERCENTILE and PERCENTILE_DISC, but df2 below
// also covers PERCENTILE_CONT.
test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with filter and group by") {
// Case 1: PERCENTILE(salary, 0.5) is expected to be reported as
// PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST).
val df1 = sql(
"""
|SELECT
| dept,
| PERCENTILE(salary, 0.5)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df1)
checkAggregateRemoved(df1)
checkPushedInfo(df1,
"""
|PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))

// Case 2: PERCENTILE_CONT with the default ordering and with DESC; the expected pushed
// text is ASC NULLS FIRST for the former and DESC NULLS LAST for the latter.
val df2 = sql(
"""
|SELECT
| dept,
| PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
| PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df2)
checkAggregateRemoved(df2)
checkPushedInfo(df2,
"""
|PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
|PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df2,
Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 12000.0)))

// Case 3: the same two orderings as case 2, but for PERCENTILE_DISC.
val df3 = sql(
"""
|SELECT
| dept,
| PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
| PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
|FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
checkFiltersRemoved(df3)
checkAggregateRemoved(df3)
checkPushedInfo(df3,
"""
|PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
|PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df3,
Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 12000.0)))
}

test("scan with aggregate push-down: aggregate over alias push down") {
val cols = Seq("a", "b", "c", "d", "e")
val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)
Expand Down