[SPARK-36646][SQL] Push down group by partition column for aggregate #34445

Closed · wants to merge 7 commits
Changes from 2 commits
@@ -81,19 +81,22 @@ object AggregatePushDownUtils {
}
}

if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
if (dataFilters.nonEmpty) {
// Parquet/ORC footer has max/min/count for columns
// e.g. SELECT COUNT(col1) FROM t
// but footer doesn't have max/min/count for a column if max/min/count
// are combined with filter or group by
// e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
// SELECT COUNT(col1) FROM t GROUP BY col2
// However, if the filter is on partition column, max/min/count can still be pushed down
// Todo: add support if groupby column is partition col
// (https://issues.apache.org/jira/browse/SPARK-36646)
return None
}

aggregation.groupByColumns.foreach { col =>
Member:
nit: maybe also add some comments here - it's not that easy to understand, and a comment would help with maintaining this code.

if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
finalSchema = finalSchema.add(getStructFieldForCol(col))
}
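One way to address the nit above - the same check with explanatory comments (a sketch; the rationale spelled out in the comments is an assumption, not wording from the PR):

    // Group-by push down is only supported on partition columns: their values come from
    // the directory layout, so they are constant within each file and the footer's
    // max/min/count statistics stay correct per group.
    aggregation.groupByColumns.foreach { col =>
      // Only a top-level column name that resolves to a partition column qualifies.
      if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
      // Group-by columns are added to the pushed-down schema ahead of the aggregate columns.
      finalSchema = finalSchema.add(getStructFieldForCol(col))
    }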

aggregation.aggregateExpressions.foreach {
case max: Max =>
if (!processMinOrMax(max)) return None
@@ -138,4 +141,19 @@ object AggregatePushDownUtils {
converter.convert(aggregatesAsRow, columnVectors.toArray)
new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
}

/**
* Return the schema for aggregates only (exclude group by columns)
*/
def getSchemaWithoutGroupingExpression(
aggregation: Aggregation,
aggSchema: StructType): StructType = {
Member:
nit: maybe swap the order of aggSchema and aggregation here, as we're modifying the schema with the info from aggregation.

val groupByColNums = aggregation.groupByColumns.length
if (groupByColNums > 0) {
new StructType(aggSchema.fields.drop(groupByColNums))
} else {
aggSchema
}
}

}
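A small usage sketch for the getSchemaWithoutGroupingExpression helper added above (the schema layout is an assumption that follows from the group-by columns being added first; variable names are illustrative):

    // e.g. for SELECT max(value), min(value), p1, p2 FROM t GROUP BY p1, p2 the pushed-down
    // aggSchema is [p1, p2, max(value), min(value)]; dropping the two leading group-by
    // fields leaves [max(value), min(value)], which is what the footer statistics produce.
    val aggOnlySchema =
      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(pushedAggregation, aggSchema)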
@@ -35,11 +35,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, SchemaMergeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, Utils}

@@ -396,9 +397,8 @@ object OrcUtils extends Logging {
dataSchema: StructType,
partitionSchema: StructType,
aggregation: Aggregation,
aggSchema: StructType): InternalRow = {
require(aggregation.groupByColumns.length == 0,
s"aggregate $aggregation with group-by column shouldn't be pushed down")
aggSchema: StructType,
partitionValues: InternalRow): InternalRow = {
var columnsStatistics: OrcColumnStatistics = null
try {
columnsStatistics = OrcFooterReader.readStatistics(reader)
@@ -457,17 +457,22 @@
}
}

// if there are group by columns, we will build result row first,
// and then append group by columns values (partition columns values) to the result row.
val schemaWithoutGroupby =
AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggregation, aggSchema)

val aggORCValues: Seq[WritableComparable[_]] =
aggregation.aggregateExpressions.zipWithIndex.map {
case (max: Max, index) =>
val columnName = max.column.fieldNames.head
val statistics = getColumnStatistics(columnName)
val dataType = aggSchema(index).dataType
val dataType = schemaWithoutGroupby(index).dataType
getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
case (min: Min, index) =>
val columnName = min.column.fieldNames.head
val statistics = getColumnStatistics(columnName)
val dataType = aggSchema.apply(index).dataType
val dataType = schemaWithoutGroupby.apply(index).dataType
getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
case (count: Count, _) =>
val columnName = count.column.fieldNames.head
@@ -490,7 +495,13 @@
s"createAggInternalRowFromFooter should not take $x as the aggregate expression")
}

val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray)
orcValuesDeserializer.deserializeFromValues(aggORCValues)
val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupby,
(0 until schemaWithoutGroupby.length).toArray)
val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
if (aggregation.groupByColumns.length > 0) {
new JoinedRow(partitionValues, resultRow)
} else {
resultRow
}
}
}
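To make the row layout concrete, a minimal sketch of the JoinedRow composition used above (the values and the single partition column are assumptions for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.JoinedRow

    val partitionValues = InternalRow(1)       // p = 1, taken from the file's partition directory
    val aggregatesRow = InternalRow(4L, 9L)    // count(*) and max(id) computed from the footer
    val fullRow = new JoinedRow(partitionValues, aggregatesRow)
    // fullRow reads as [p, count(*), max(id)]: the group-by (partition) column comes first,
    // followed by the aggregate columns, matching the read schema used by the scan.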
@@ -31,8 +31,9 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
import org.apache.spark.sql.types.StructType

@@ -157,17 +158,22 @@ object ParquetUtils {
partitionSchema: StructType,
aggregation: Aggregation,
aggSchema: StructType,
datetimeRebaseMode: LegacyBehaviorPolicy.Value,
isCaseSensitive: Boolean): InternalRow = {
partitionValues: InternalRow,
datetimeRebaseMode: LegacyBehaviorPolicy.Value): InternalRow = {
val (primitiveTypes, values) = getPushedDownAggResult(
footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
footer, filePath, dataSchema, partitionSchema, aggregation)

val builder = Types.buildMessage
primitiveTypes.foreach(t => builder.addField(t))
val parquetSchema = builder.named("root")

// if there are group by columns, we will build result row first,
// and then append group by columns values (partition columns values) to the result row.
val schemaWithoutGroupby =
AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggregation, aggSchema)

val schemaConverter = new ParquetToSparkSchemaConverter
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupby,
None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
primitiveTypeNames.zipWithIndex.foreach {
@@ -195,7 +201,12 @@ object ParquetUtils {
case (_, i) =>
throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
}
converter.currentRecord

if (aggregation.groupByColumns.length > 0) {
new JoinedRow(partitionValues, converter.currentRecord)
} else {
converter.currentRecord
}
}

/**
@@ -211,16 +222,14 @@
filePath: String,
dataSchema: StructType,
partitionSchema: StructType,
aggregation: Aggregation,
isCaseSensitive: Boolean)
aggregation: Aggregation)
: (Array[PrimitiveType], Array[Any]) = {
val footerFileMetaData = footer.getFileMetaData
val fields = footerFileMetaData.getSchema.getFields
val blocks = footer.getBlocks
val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
val valuesBuilder = mutable.ArrayBuilder.make[Any]

assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
aggregation.aggregateExpressions.foreach { agg =>
var value: Any = None
var rowCount = 0L
@@ -250,8 +259,7 @@ object ParquetUtils {
schemaName = "count(" + count.column.fieldNames.head + ")"
rowCount += block.getRowCount
var isPartitionCol = false
if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
.toSet.contains(count.column.fieldNames.head)) {
if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) {
Member:
Don't we need to check case sensitivity anymore?

Contributor Author:
Seems to me there's no need to check case sensitivity, because I have normalized the aggregates and group-by columns in V2ScanRelationPushDown.

isPartitionCol = true
}
isCount = true
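A sketch of the simplified membership test, assuming the aggregates and group-by columns are already normalized against the table schema in V2ScanRelationPushDown (per the exchange above; variable names are illustrative):

    // Previously the partition column names were case-folded before the comparison.
    // After normalization both sides carry the schema's exact casing, so a plain Set
    // membership check is sufficient and no case-insensitive comparison is needed here.
    val partitionColNames = partitionSchema.fields.map(_.name).toSet
    val isPartitionColumn = partitionColNames.contains(count.column.fieldNames.head)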
@@ -86,7 +86,7 @@ case class OrcPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))

if (aggregation.nonEmpty) {
return buildReaderWithAggregates(filePath, conf)
return buildReaderWithAggregates(file, conf)
}
Member:
Seems filePath can be created after the if block:

if (aggregation.nonEmpty) {
  return buildReaderWithAggregates(file, conf)
}

val filePath = new Path(new URI(file.filePath))


val resultedColPruneInfo =
@@ -130,7 +130,7 @@ case class OrcPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))

if (aggregation.nonEmpty) {
return buildColumnarReaderWithAggregates(filePath, conf)
return buildColumnarReaderWithAggregates(file, conf)
}
Member (on lines 129 to 131):
ditto.


val resultedColPruneInfo =
@@ -183,14 +183,16 @@ case class OrcPartitionReaderFactory(
* Build reader with aggregate push down.
*/
private def buildReaderWithAggregates(
filePath: Path,
file: PartitionedFile,
conf: Configuration): PartitionReader[InternalRow] = {
val filePath = new Path(new URI(file.filePath))
new PartitionReader[InternalRow] {
private var hasNext = true
private lazy val row: InternalRow = {
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
OrcUtils.createAggInternalRowFromFooter(
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema)
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
readDataSchema, file.partitionValues)
}
}

@@ -209,15 +211,16 @@
* Build columnar reader with aggregate push down.
*/
private def buildColumnarReaderWithAggregates(
filePath: Path,
file: PartitionedFile,
conf: Configuration): PartitionReader[ColumnarBatch] = {
val filePath = new Path(new URI(file.filePath))
new PartitionReader[ColumnarBatch] {
private var hasNext = true
private lazy val batch: ColumnarBatch = {
Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
val row = OrcUtils.createAggInternalRowFromFooter(
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
readDataSchema)
readDataSchema, file.partitionValues)
AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
}
}
@@ -136,8 +136,8 @@ case class ParquetPartitionReaderFactory(
val footer = getFooter(file)
if (footer != null && footer.getBlocks.size > 0) {
ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
partitionSchema, aggregation.get, readDataSchema,
getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
getDatetimeRebaseMode(footer.getFileMetaData))
} else {
null
}
@@ -179,8 +179,8 @@ case class ParquetPartitionReaderFactory(
val footer = getFooter(file)
if (footer != null && footer.getBlocks.size > 0) {
val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath,
dataSchema, partitionSchema, aggregation.get, readDataSchema,
getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
dataSchema, partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
getDatetimeRebaseMode(footer.getFileMetaData))
AggregatePushDownUtils.convertAggregatesRowToBatch(
row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined)
} else {
@@ -261,6 +261,63 @@ trait FileSourceAggregatePushDownSuite
}
}

test("aggregate with partition group by can be pushed down") {
withTempPath { dir =>
spark.range(10).selectExpr("id", "id % 3 as p")
.write.partitionBy("p").format(format).save(dir.getCanonicalPath)
withTempView("tmp") {
spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp");
Seq("false", "true").foreach { enableVectorizedReader =>
withSQLConf(aggPushDownEnabledKey -> "true",
Member:
Hmm, can you test both aggPushDownEnabledKey as true and false and see if the results are the same? (See the sketch after this test.)

vectorizedReaderEnabledKey -> enableVectorizedReader) {
val df = sql("SELECT count(*), count(id), p, max(id), p, count(p), max(id)," +
" min(id), p FROM tmp group by p")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregation: [COUNT(*), COUNT(id), MAX(id), COUNT(p), MIN(id)], " +
"PushedFilters: [], PushedGroupBy: [p]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(3, 3, 1, 7, 1, 3, 7, 1, 1), Row(3, 3, 2, 8, 2, 3, 8, 2, 2),
Row(4, 4, 0, 9, 0, 4, 9, 0, 0)))
}
}
}
}
}
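A sketch of the variation the reviewer asks for - looping over both push-down settings and checking that the answer does not change (hypothetical test code with a simplified query, not part of this diff):

    Seq("true", "false").foreach { pushDownEnabled =>
      Seq("false", "true").foreach { enableVectorizedReader =>
        withSQLConf(aggPushDownEnabledKey -> pushDownEnabled,
          vectorizedReaderEnabledKey -> enableVectorizedReader) {
          val df = sql("SELECT count(*), count(id), p, max(id), min(id) FROM tmp GROUP BY p")
          // The result must be identical whether or not the aggregate is pushed down.
          checkAnswer(df, Seq(Row(3, 3, 1, 7, 1), Row(3, 3, 2, 8, 2), Row(4, 4, 0, 9, 0)))
        }
      }
    }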

test("aggregate with multi partition group by columns can be pushed down") {
withTempPath { dir =>
Seq((10, 1, 2), (2, 1, 2), (3, 2, 1), (4, 2, 1), (5, 2, 1), (6, 2, 1),
(1, 1, 2), (4, 1, 2), (3, 2, 2), (-4, 2, 2), (6, 2, 2))
.toDF("value", "p1", "p2")
.write
.partitionBy("p1", "p2")
.format(format)
.save(dir.getCanonicalPath)
withTempView("tmp") {
spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp");
Seq("false", "true").foreach { enableVectorizedReader =>
withSQLConf(aggPushDownEnabledKey -> "true",
Member:
Here too - we should make sure aggPushDownEnabledKey won't change the results.

vectorizedReaderEnabledKey -> enableVectorizedReader) {
val df = sql("SELECT count(*), count(value), max(value), min(value), p1, p2 FROM tmp" +
" GROUP BY p1, p2")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregation: [COUNT(*), COUNT(value), MAX(value), MIN(value)]," +
" PushedFilters: [], PushedGroupBy: [p1, p2]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(4, 4, 10, 1, 1, 2), Row(4, 4, 6, 3, 2, 1),
Row(3, 3, 6, -4, 2, 2)))
}
}
}
}
}

test("push down only if all the aggregates can be pushed down") {
val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
(9, "mno", 7), (2, null, 7))