
[SPARK-36646][SQL] Push down group by partition column for aggregate #34445

Closed · wants to merge 7 commits
AggregatePushDownUtils.scala:
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.execution.RowToColumnConverter
@@ -81,19 +81,37 @@ object AggregatePushDownUtils {
}
}

if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
if (dataFilters.nonEmpty) {
// Parquet/ORC footers have max/min/count statistics for columns,
// e.g. SELECT COUNT(col1) FROM t,
// but a footer doesn't have max/min/count for a column if max/min/count
// are combined with a filter or a group by,
// e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
//      SELECT COUNT(col1) FROM t GROUP BY col2
// However, if the filter is on a partition column, max/min/count can still be pushed down.
// TODO: add support when the group by columns are partition columns
// (https://issues.apache.org/jira/browse/SPARK-36646)
return None
}
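For illustration, a hedged sketch of which queries can rely on footer statistics under the check above; the table, columns, and `spark` session below are assumptions, not part of this change:

// Hypothetical examples; t is a Parquet table partitioned by column p, c and c2 are data columns,
// and `spark` is an active SparkSession with aggregate push down enabled.
spark.sql("SELECT MIN(c), MAX(c), COUNT(c) FROM t")   // footer statistics are sufficient
spark.sql("SELECT COUNT(c) FROM t WHERE p = 1")       // filter only on the partition column: still eligible
spark.sql("SELECT COUNT(c) FROM t WHERE c2 = 8")      // filter on a data column: dataFilters is non-empty, no push down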

if (aggregation.groupByColumns.nonEmpty &&
Member: nit: maybe add some comments explaining the reasoning why we have this check and only support the case when the group by columns are the same as the partition columns. What if the number of group by columns is smaller than the number of partition columns?

partitionNames.size != aggregation.groupByColumns.length) {
// If there are group by columns, we only push down if the group by columns are the same as
// the partition columns. In theory, if the group by columns are a subset of the partition
// columns, we should still be able to push down. e.g. if table t has partition columns p1,
// p2, and p3, SELECT MAX(c) FROM t GROUP BY p1, p2 could still be pushed down. However, the
// partial aggregation pushed down to the data source would need to be
// SELECT p1, p2, p3, MAX(c) FROM t GROUP BY p1, p2, p3, and the Spark layer would need a
// final aggregation such as SELECT MAX(c) FROM t GROUP BY p1, p2, so the pushed down query
// schema would differ from the query schema at the Spark layer. We keep aggregate push down
// simple and don't handle this complicated case for now.
return None
}
aggregation.groupByColumns.foreach { col =>
Member: nit: maybe also add some comments here; it's not that easy to understand, and comments would help with maintaining this code.

// Don't push down if the group by columns are not the same as the partition columns (order
// doesn't matter because the reordering can be done at the data source layer).
if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
finalSchema = finalSchema.add(getStructFieldForCol(col))
}
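To make the two checks above concrete, a minimal sketch under assumed names (whether push down actually happens also depends on the remaining conditions in this method):

// Hypothetical examples; t is a Parquet/ORC table partitioned by (p1, p2), c and c2 are data columns.
spark.sql("SELECT p1, p2, MAX(c) FROM t GROUP BY p1, p2")   // group by == partition columns: eligible
spark.sql("SELECT p2, p1, MAX(c) FROM t GROUP BY p2, p1")   // same columns, different order: still eligible
spark.sql("SELECT p1, MAX(c) FROM t GROUP BY p1")           // strict subset of the partition columns: not pushed down
spark.sql("SELECT c2, MAX(c) FROM t GROUP BY c2")           // group by on a data column: not pushed down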

aggregation.aggregateExpressions.foreach {
case max: Max =>
if (!processMinOrMax(max)) return None
@@ -138,4 +156,44 @@ object AggregatePushDownUtils {
converter.convert(aggregatesAsRow, columnVectors.toArray)
new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
}

/**
* Return the schema for the aggregates only (excluding the group by columns)
*/
def getSchemaWithoutGroupingExpression(
aggSchema: StructType,
aggregation: Aggregation): StructType = {
val numOfGroupByColumns = aggregation.groupByColumns.length
if (numOfGroupByColumns > 0) {
new StructType(aggSchema.fields.drop(numOfGroupByColumns))
} else {
aggSchema
}
}
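A small usage sketch of this helper, assuming the aggregate schema produced above lays out the group by (partition) columns first and the aggregate columns last; the schema below is made up:

import org.apache.spark.sql.types._

// aggSchema as built earlier: group by (partition) columns first, then aggregates.
val aggSchema = new StructType()
  .add("p1", IntegerType)
  .add("p2", StringType)
  .add("max(c)", LongType)
// With two group by columns, dropping the leading fields leaves only the aggregates:
val aggsOnly = new StructType(aggSchema.fields.drop(2))   // contains just max(c): LongType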

/**
* Reorder the partition columns if they are not in the same order as the group by columns
*/
def reOrderPartitionCol(
partitionSchema: StructType,
aggregation: Aggregation,
partitionValues: InternalRow): InternalRow = {
val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
assert(groupByColNames.length == partitionSchema.length &&
groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
s"${groupByColNames.length} should be the same as partition schema length " +
s"${partitionSchema.length} and the number of fields ${partitionValues.numFields} " +
s"in partitionValues")
var reorderedPartColValues = Array.empty[Any]
if (!partitionSchema.names.sameElements(groupByColNames)) {
groupByColNames.foreach { col =>
val index = partitionSchema.names.indexOf(col)
val v = partitionValues.asInstanceOf[GenericInternalRow].values(index)
Member: Just curious: is this always guaranteed to be a GenericInternalRow?

Contributor (author): It seems to me that partitionValues comes from PartitionPath, which always contains a GenericInternalRow.

reorderedPartColValues = reorderedPartColValues :+ v
}
new GenericInternalRow(reorderedPartColValues)
} else {
partitionValues
}
}
}
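A usage sketch of reOrderPartitionCol with made-up values; the Aggregation construction is omitted and the expected result is described in comments rather than asserted:

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

val partitionSchema = new StructType().add("p1", IntegerType).add("p2", StringType)
val partitionValues = new GenericInternalRow(Array[Any](1, UTF8String.fromString("a")))
// If the aggregation groups by (p2, p1), reOrderPartitionCol is expected to return a row
// equivalent to ["a", 1], i.e. the partition values permuted into the group by column order.
// Note that partitionValues is a GenericInternalRow here, consistent with the discussion above:
// rows built via InternalRow.fromSeq (as in PartitionPath) are GenericInternalRow instances.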
OrcUtils.scala:
@@ -35,11 +35,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, SchemaMergeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, Utils}

@@ -396,9 +397,8 @@ object OrcUtils extends Logging {
dataSchema: StructType,
partitionSchema: StructType,
aggregation: Aggregation,
aggSchema: StructType): InternalRow = {
require(aggregation.groupByColumns.length == 0,
s"aggregate $aggregation with group-by column shouldn't be pushed down")
aggSchema: StructType,
partitionValues: InternalRow): InternalRow = {
var columnsStatistics: OrcColumnStatistics = null
try {
columnsStatistics = OrcFooterReader.readStatistics(reader)
@@ -457,17 +457,22 @@
}
}

// If there are group by columns, we build the aggregate result row first and then prepend
// the group by column values (i.e. the partition column values) to it via JoinedRow.
val schemaWithoutGroupBy =
AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)

val aggORCValues: Seq[WritableComparable[_]] =
aggregation.aggregateExpressions.zipWithIndex.map {
case (max: Max, index) =>
val columnName = max.column.fieldNames.head
val statistics = getColumnStatistics(columnName)
val dataType = aggSchema(index).dataType
val dataType = schemaWithoutGroupBy(index).dataType
getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
case (min: Min, index) =>
val columnName = min.column.fieldNames.head
val statistics = getColumnStatistics(columnName)
val dataType = aggSchema.apply(index).dataType
val dataType = schemaWithoutGroupBy.apply(index).dataType
getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
case (count: Count, _) =>
val columnName = count.column.fieldNames.head
@@ -490,7 +495,15 @@
s"createAggInternalRowFromFooter should not take $x as the aggregate expression")
}

val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray)
orcValuesDeserializer.deserializeFromValues(aggORCValues)
val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
(0 until schemaWithoutGroupBy.length).toArray)
val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
if (aggregation.groupByColumns.nonEmpty) {
val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
partitionSchema, aggregation, partitionValues)
new JoinedRow(reOrderedPartitionValues, resultRow)
} else {
resultRow
}
}
}
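A brief sketch of how the final row is assembled in the group by case; the values are illustrative, and the real code uses the reordered partition values and the row deserialized from the ORC footer statistics:

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}

val partValues = new GenericInternalRow(Array[Any](1, 2))   // reordered group by (partition) values
val aggValues = new GenericInternalRow(Array[Any](42L))     // e.g. max(c) taken from the footer
val result = new JoinedRow(partValues, aggValues)
// result reads as [1, 2, 42]: partition values first, then aggregate values, which matches the
// pushed-down output schema (group by columns ++ aggregate columns).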
ParquetUtils.scala:
@@ -31,8 +31,9 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
import org.apache.spark.sql.types.StructType

@@ -157,17 +158,22 @@ object ParquetUtils {
partitionSchema: StructType,
aggregation: Aggregation,
aggSchema: StructType,
datetimeRebaseMode: LegacyBehaviorPolicy.Value,
isCaseSensitive: Boolean): InternalRow = {
partitionValues: InternalRow,
datetimeRebaseMode: LegacyBehaviorPolicy.Value): InternalRow = {
val (primitiveTypes, values) = getPushedDownAggResult(
footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
footer, filePath, dataSchema, partitionSchema, aggregation)

val builder = Types.buildMessage
primitiveTypes.foreach(t => builder.addField(t))
val parquetSchema = builder.named("root")

// If there are group by columns, we build the aggregate result row first and then prepend
// the group by column values (i.e. the partition column values) to it via JoinedRow.
val schemaWithoutGroupBy =
AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)

val schemaConverter = new ParquetToSparkSchemaConverter
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupBy,
None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
primitiveTypeNames.zipWithIndex.foreach {
Expand Down Expand Up @@ -195,7 +201,14 @@ object ParquetUtils {
case (_, i) =>
throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
}
converter.currentRecord

if (aggregation.groupByColumns.nonEmpty) {
val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
partitionSchema, aggregation, partitionValues)
new JoinedRow(reorderedPartitionValues, converter.currentRecord)
} else {
converter.currentRecord
}
}

/**
@@ -211,16 +224,14 @@
filePath: String,
dataSchema: StructType,
partitionSchema: StructType,
aggregation: Aggregation,
isCaseSensitive: Boolean)
aggregation: Aggregation)
: (Array[PrimitiveType], Array[Any]) = {
val footerFileMetaData = footer.getFileMetaData
val fields = footerFileMetaData.getSchema.getFields
val blocks = footer.getBlocks
val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
val valuesBuilder = mutable.ArrayBuilder.make[Any]

assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
aggregation.aggregateExpressions.foreach { agg =>
var value: Any = None
var rowCount = 0L
@@ -250,8 +261,7 @@
schemaName = "count(" + count.column.fieldNames.head + ")"
rowCount += block.getRowCount
var isPartitionCol = false
if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
.toSet.contains(count.column.fieldNames.head)) {
if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) {
Member: Do we no longer need to check case sensitivity here?

Contributor (author): It seems to me there is no need to check case sensitivity here, because the aggregates and group by columns have already been normalized in V2ScanRelationPushDown.

isPartitionCol = true
}
isCount = true
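One detail worth spelling out (a hedged reading of the count handling above): partition columns are not stored in Parquet column chunks, so COUNT on a partition column cannot come from column statistics; the code instead accumulates the row count of every block, which works because the partition value is non-null for each row. A small sketch with made-up numbers:

// Hypothetical per-block row counts for a single Parquet file.
val blockRowCounts = Seq(1000L, 500L)
// COUNT(partitionCol) for the file is simply the total row count, because every row carries
// the (non-null) partition value taken from the directory path rather than from a column chunk.
val countOfPartitionCol = blockRowCounts.sum   // 1500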
OrcPartitionReaderFactory.scala:
@@ -83,11 +83,10 @@ case class OrcPartitionReaderFactory(

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value
val filePath = new Path(new URI(file.filePath))

if (aggregation.nonEmpty) {
return buildReaderWithAggregates(filePath, conf)
return buildReaderWithAggregates(file, conf)
}
val filePath = new Path(new URI(file.filePath))

val resultedColPruneInfo =
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
@@ -127,11 +126,10 @@

override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
val conf = broadcastedConf.value.value
val filePath = new Path(new URI(file.filePath))

if (aggregation.nonEmpty) {
return buildColumnarReaderWithAggregates(filePath, conf)
return buildColumnarReaderWithAggregates(file, conf)
}
Member (on lines 129 to 131): ditto.

val filePath = new Path(new URI(file.filePath))

val resultedColPruneInfo =
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
@@ -183,14 +181,16 @@
* Build reader with aggregate push down.
*/
private def buildReaderWithAggregates(
filePath: Path,
file: PartitionedFile,
conf: Configuration): PartitionReader[InternalRow] = {
val filePath = new Path(new URI(file.filePath))
new PartitionReader[InternalRow] {
private var hasNext = true
private lazy val row: InternalRow = {
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
OrcUtils.createAggInternalRowFromFooter(
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema)
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
readDataSchema, file.partitionValues)
}
}

@@ -209,15 +209,16 @@
* Build columnar reader with aggregate push down.
*/
private def buildColumnarReaderWithAggregates(
filePath: Path,
file: PartitionedFile,
conf: Configuration): PartitionReader[ColumnarBatch] = {
val filePath = new Path(new URI(file.filePath))
new PartitionReader[ColumnarBatch] {
private var hasNext = true
private lazy val batch: ColumnarBatch = {
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
val row = OrcUtils.createAggInternalRowFromFooter(
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
readDataSchema)
readDataSchema, file.partitionValues)
AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
}
}
ParquetPartitionReaderFactory.scala:
@@ -134,10 +134,11 @@ case class ParquetPartitionReaderFactory(
private var hasNext = true
private lazy val row: InternalRow = {
val footer = getFooter(file)

if (footer != null && footer.getBlocks.size > 0) {
ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
partitionSchema, aggregation.get, readDataSchema,
getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
getDatetimeRebaseMode(footer.getFileMetaData))
} else {
null
}
@@ -179,8 +180,8 @@
val footer = getFooter(file)
if (footer != null && footer.getBlocks.size > 0) {
val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath,
dataSchema, partitionSchema, aggregation.get, readDataSchema,
getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
dataSchema, partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
getDatetimeRebaseMode(footer.getFileMetaData))
AggregatePushDownUtils.convertAggregatesRowToBatch(
row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined)
} else {