Commit

add interface SupportsPushDownAggregates

huaxingao committed May 15, 2021
1 parent 743bc8a commit 346485e
Showing 15 changed files with 336 additions and 293 deletions.
@@ -113,22 +113,4 @@ default CustomMetric[] supportedCustomMetrics() {
CustomMetric[] NO_METRICS = {};
return NO_METRICS;
}

/**
* Pushes down Aggregation to scan.
* The Aggregation can be pushed down only if all the Aggregate Functions can
* be pushed down.
*/
default void pushAggregation(Aggregation aggregation) {
throw new UnsupportedOperationException(description() +
": Push down Aggregation is not supported");
}

/*
* Returns the aggregation that is pushed to the Scan
*/
default Aggregation pushedAggregation() {
throw new UnsupportedOperationException(description() +
": pushedAggregation is not supported");
}
}
@@ -28,5 +28,11 @@
*/
@Evolving
public interface ScanBuilder {
enum orders { FILTER, AGGREGATE, COLUMNS };

  // The order in which operators are pushed down. Spark pushes down filters first, then
  // aggregates, and finally applies column pruning (if applicable).
static orders[] PUSH_DOWN_ORDERS = {orders.FILTER, orders.AGGREGATE, orders.COLUMNS};

Scan build();
}
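For illustration only (not part of this commit), a minimal Scala sketch of how a planner-side helper could honor the declared order, applying the push-down mix-ins as filters, then aggregates, then column pruning. The object and method names and all arguments are hypothetical.

import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.{Aggregation, Filter}
import org.apache.spark.sql.types.StructType

object PushDownOrderExample {
  // Hypothetical helper: applies the three push-down mix-ins in the order declared
  // by PUSH_DOWN_ORDERS -- FILTER, then AGGREGATE, then COLUMNS.
  def pushInDeclaredOrder(
      builder: ScanBuilder,
      filters: Array[Filter],
      aggregation: Aggregation,
      prunedSchema: StructType): Unit = {
    builder match {
      case f: SupportsPushDownFilters => f.pushFilters(filters)               // FILTER
      case _ =>
    }
    builder match {
      case a: SupportsPushDownAggregates => a.pushAggregation(aggregation)    // AGGREGATE
      case _ =>
    }
    builder match {
      case c: SupportsPushDownRequiredColumns => c.pruneColumns(prunedSchema) // COLUMNS
      case _ =>
    }
  }
}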
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Aggregation;
import org.apache.spark.sql.types.StructType;

/**
 * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
 * push down aggregates to the data source.
*
* @since 3.2.0
*/
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
   * Pushes down an Aggregation to the data source.
   * The Aggregation can be pushed down only if all of its aggregate functions can
   * be pushed down.
*/
void pushAggregation(Aggregation aggregation);

/**
   * Returns the aggregation that was pushed to the data source via
   * {@link #pushAggregation(Aggregation)}.
*/
Aggregation pushedAggregation();

/**
   * Returns the schema of the pushed down aggregates.
*/
StructType getPushDownAggSchema();

/**
   * Indicates whether the data source supports only global (no GROUP BY) aggregate push down.
*/
boolean supportsGlobalAggregatePushDownOnly();

/**
   * Indicates whether the data source supports pushing down aggregates together with filters.
*/
boolean supportsPushDownAggregateWithFilter();
}
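For context, a rough Scala sketch (not part of this commit) of what a data source's ScanBuilder implementing the new mix-in might look like. The class name MyScanBuilder, the supported-function check, and the LONG-typed output schema are all assumptions for illustration.

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
import org.apache.spark.sql.sources.{Aggregation, Count, Max, Min}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class MyScanBuilder extends SupportsPushDownAggregates {
  private var pushed: Aggregation = Aggregation.empty

  override def pushAggregation(aggregation: Aggregation): Unit = {
    // Per the contract above: accept the aggregation only if every aggregate function
    // is supported; otherwise leave `pushed` empty so pushedAggregation() reports
    // that nothing was pushed.
    val allSupported = aggregation.aggregateExpressions.forall(_.forall {
      case _: Min | _: Max | _: Count => true
      case _ => false
    })
    if (allSupported) pushed = aggregation
  }

  override def pushedAggregation(): Aggregation = pushed

  override def getPushDownAggSchema(): StructType =
    // Assumed output shape: one LONG column per pushed aggregate function.
    StructType(pushed.aggregateExpressions.flatten.zipWithIndex.map {
      case (_, i) => StructField(s"agg_$i", LongType)
    })

  // This hypothetical source can only aggregate the whole scan (no GROUP BY) ...
  override def supportsGlobalAggregatePushDownOnly(): Boolean = true

  // ... and cannot combine aggregate push down with pushed filters.
  override def supportsPushDownAggregateWithFilter(): Boolean = false

  override def build(): Scan = throw new UnsupportedOperationException("sketch only")
}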
@@ -20,21 +20,19 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.types.DataType

// Aggregate Functions in SQL statement.
// e.g. SELECT COUNT(EmployeeID), AVG(salary), deptID FROM dept GROUP BY deptID
// aggregateExpressions are (COUNT(EmployeeID), AVG(salary)), groupByColumns are (deptID)
case class Aggregation(aggregateExpressions: Seq[AggregateFunc],
// e.g. SELECT COUNT(EmployeeID), Max(salary), deptID FROM dept GROUP BY deptID
// aggregateExpressions are (COUNT(EmployeeID), Max(salary)), groupByColumns are (deptID)
case class Aggregation(aggregateExpressions: Seq[Seq[AggregateFunc]],
groupByColumns: Seq[String])

abstract class AggregateFunc

// Avg and Sum are only supported by JDBC agg pushdown, not supported by parquet agg pushdown yet
case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Min(column: String, dataType: DataType) extends AggregateFunc
case class Max(column: String, dataType: DataType) extends AggregateFunc
case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Count(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc

object Aggregation {
// Returns an empty Aggregate
def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
def empty: Aggregation = Aggregation(Seq.empty[Seq[AggregateFunc]], Seq.empty[String])
}
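As a quick illustration (not part of this commit), the example query in the comment above would be described by an Aggregation value like the one below, e.g. in a Spark shell. Each element of aggregateExpressions is itself a Seq, presumably so that a single Catalyst aggregate can translate into more than one source-level function; the column types are assumptions.

import org.apache.spark.sql.sources.{Aggregation, Count, Max}
import org.apache.spark.sql.types.{IntegerType, LongType}

// SELECT COUNT(EmployeeID), MAX(salary), deptID FROM dept GROUP BY deptID
// (salary is assumed to be an INT column; COUNT produces a LONG)
val agg = Aggregation(
  aggregateExpressions = Seq(
    Seq(Count("EmployeeID", LongType, isDistinct = false)),
    Seq(Max("salary", IntegerType))),
  groupByColumns = Seq("deptID"))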
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PartialAggregatePushDown
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes}
@@ -38,7 +37,7 @@ class SparkOptimizer(

override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
SchemaPruning :: PartialAggregatePushDown :: V2ScanRelationPushDown :: V2Writes ::
SchemaPruning :: V2ScanRelationPushDown :: V2Writes ::
PruneFileSourcePartitions :: Nil

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
@@ -679,19 +679,19 @@ object DataSourceStrategy

protected[sql] def translateAggregate(
aggregates: AggregateExpression,
pushableColumn: PushableColumnBase): Option[AggregateFunc] = {
pushableColumn: PushableColumnBase): Option[Seq[AggregateFunc]] = {
aggregates.aggregateFunction match {
case min@aggregate.Min(pushableColumn(name)) =>
Some(Min(name, min.dataType))
Some(Seq(Min(name, min.dataType)))
case max@aggregate.Max(pushableColumn(name)) =>
Some(Max(name, max.dataType))
Some(Seq(Max(name, max.dataType)))
case count: aggregate.Count =>
val columnName = count.children.head match {
// SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
case Literal(_, _) => "1"
case pushableColumn(name) => name
}
Some(Count(columnName, count.dataType, aggregates.isDistinct))
Some(Seq(Count(columnName, count.dataType, aggregates.isDistinct)))
case _ => None
}
}
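To illustrate the signature change above, a small sketch (not part of this commit) of what translateAggregate now produces for a Catalyst MIN over a pushable column. translateAggregate is protected[sql], so this only compiles from within Spark's own sql packages; the column name and type are assumptions.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn}
import org.apache.spark.sql.types.IntegerType

val salary = AttributeReference("salary", IntegerType)()
val minExpr = aggregate.Min(salary).toAggregateExpression()

// With the new return type this yields Some(Seq(Min("salary", IntegerType)))
// rather than Some(Min(...)); PushableColumn(false) disables nested-column matching.
val translated = DataSourceStrategy.translateAggregate(minExpr, PushableColumn(false))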

This file was deleted.

@@ -357,23 +357,23 @@ object ParquetUtils {
blocks.forEach { block =>
val blockMetaData = block.getColumns()
aggregation.aggregateExpressions(i) match {
case Max(col, _) =>
case Seq(Max(col, _)) =>
index = dataSchema.fieldNames.toList.indexOf(col)
val currentMax = getCurrentBlockMaxOrMin(footer, blockMetaData, index, true)
if (currentMax != None &&
(value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) {
value = currentMax
}

case Min(col, _) =>
case Seq(Min(col, _)) =>
index = dataSchema.fieldNames.toList.indexOf(col)
val currentMin = getCurrentBlockMaxOrMin(footer, blockMetaData, index, false)
if (currentMin != None &&
(value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) {
value = currentMin
}

case Count(col, _, _) =>
case Seq(Count(col, _, _)) =>
index = dataSchema.fieldNames.toList.indexOf(col)
rowCount += block.getRowCount
if (!col.equals("1")) { // "1" is for count(*)
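For background, the Max/Min branches above read the answer from row-group statistics in the Parquet footer rather than scanning data. A rough standalone sketch of that idea (not part of this commit), using parquet-mr's metadata API and assuming every row group carries valid statistics for the column:

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.metadata.ParquetMetadata

object FooterStatsExample {
  // Returns the column's max across all row groups, taken purely from footer statistics.
  def maxFromFooter(footer: ParquetMetadata, columnIndex: Int): Option[Any] = {
    footer.getBlocks.asScala.foldLeft(Option.empty[Any]) { (acc, block) =>
      val stats = block.getColumns.get(columnIndex).getStatistics
      val blockMax: Any = stats.genericGetMax()
      acc match {
        case Some(current) if current.asInstanceOf[Comparable[Any]].compareTo(blockMax) >= 0 => acc
        case _ => Some(blockMax)
      }
    }
  }
}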
@@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.sources.Aggregation
import org.apache.spark.sql.types.StructType

object PushDownUtils extends PredicateHelper {
@@ -70,6 +72,37 @@ object PushDownUtils extends PredicateHelper {
}
}

  /**
   * Pushes down aggregates to the data source reader.
   *
   * @return the pushed aggregation, or Aggregation.empty if nothing was pushed down.
   */
def pushAggregates(
scanBuilder: ScanBuilder,
aggregates: Seq[AggregateExpression],
groupBy: Seq[Expression]): Aggregation = {

def columnAsString(e: Expression): String = e match {
case AttributeReference(name, _, _, _) => name
case _ => ""
}

scanBuilder match {
case r: SupportsPushDownAggregates =>
val translatedAggregates = aggregates.map(DataSourceStrategy
.translateAggregate(_, PushableColumn(false)))
val translatedGroupBys = groupBy.map(columnAsString)

if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) {
Aggregation.empty
} else {
r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys))
r.pushedAggregation
}
case _ => Aggregation.empty
}
}

/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
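Finally, a short usage sketch (not part of this commit) of the new helper: a hypothetical rule-side call pushing MAX(salary) grouped by deptID into a data source's ScanBuilder. If the builder does not implement SupportsPushDownAggregates, or any aggregate or grouping expression fails to translate, Aggregation.empty comes back and nothing is pushed. Column names and types are assumptions.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.datasources.v2.PushDownUtils
import org.apache.spark.sql.sources.Aggregation
import org.apache.spark.sql.types.IntegerType

object PushAggregatesExample {
  // SELECT MAX(salary), deptID FROM ... GROUP BY deptID, pushed into `scanBuilder`.
  def pushMaxSalaryByDept(scanBuilder: ScanBuilder): Aggregation = {
    val salary = AttributeReference("salary", IntegerType)()
    val deptID = AttributeReference("deptID", IntegerType)()
    PushDownUtils.pushAggregates(
      scanBuilder,
      Seq(aggregate.Max(salary).toAggregateExpression()), // MAX(salary)
      Seq(deptID))                                        // GROUP BY deptID
  }
}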