[SPARK-36351][SQL] Refactor filter push down in file source v2 #33650

Closed
wants to merge 22 commits
Commits (22)
3eec42e
[SPARK-36351][SQL] Separate partition filters and data filters in Pus…
huaxingao Aug 5, 2021
c3e07ac
fix build error
huaxingao Aug 5, 2021
f95b4a0
separate partition filters and data filters in pushFilters
huaxingao Aug 7, 2021
feea946
Separate partition filter and data filter in PushDownUtil
huaxingao Aug 7, 2021
e9d598f
fix scala 2.13 build error
huaxingao Aug 7, 2021
a6ae1c5
fix test failure for text file format
huaxingao Aug 9, 2021
ff8a9d4
separate partition filters and data filter in pushFilters
huaxingao Aug 13, 2021
8f06107
separate partition filters and data filters in PushDownUtils
huaxingao Aug 18, 2021
c541315
make FileScanBuilder NOT implement SupportsPushDownFilters
huaxingao Aug 18, 2021
eaebb4c
remove unused import
huaxingao Aug 18, 2021
f61caa0
fix AvroScanBuilder
huaxingao Aug 19, 2021
e04428b
address comments
huaxingao Aug 19, 2021
ab6187c
address comments
huaxingao Aug 19, 2021
5b41c61
change FileScanBuilder.translateDataFilter to protected
huaxingao Aug 19, 2021
3b9e2c6
inline translateDataFilter()
huaxingao Aug 19, 2021
73eea33
inline translate filter
huaxingao Aug 19, 2021
095a7b4
add comment to pushFilters
huaxingao Aug 20, 2021
68ace26
add org.apache.spark.sql.internal.connector.SupportsPushDownCatalystF…
huaxingao Aug 31, 2021
f3b4d22
split partition and data filters in file source
huaxingao Sep 2, 2021
4700c08
follow SupportsPushDownFilters
huaxingao Sep 2, 2021
3085bdf
address comments
huaxingao Sep 2, 2021
da9fe2f
address comments
huaxingao Sep 2, 2021
Diff view
AvroScan.scala
@@ -62,10 +62,6 @@ case class AvroScan(
pushedFilters)
}

override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options &&
equivalentFilters(pushedFilters, a.pushedFilters)
AvroScanBuilder.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.v2.avro

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
@@ -31,7 +31,7 @@ class AvroScanBuilder (
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

override def build(): Scan = {
AvroScan(
@@ -41,17 +41,16 @@ class AvroScanBuilder (
readDataSchema(),
readPartitionSchema(),
options,
pushedFilters())
pushedDataFilters,
partitionFilters,
dataFilters)
}

private var _pushedFilters: Array[Filter] = Array.empty

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.avroFilterPushDown) {
_pushedFilters = StructFilters.pushedFilters(filters, dataSchema)
StructFilters.pushedFilters(dataFilters, dataSchema)
} else {
Array.empty[Filter]
}
filters
}

override def pushedFilters(): Array[Filter] = _pushedFilters
}
SupportsPushDownCatalystFilters.scala (new file)
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.Filter

/**
* A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to
* push down filters to the file source. The pushed down filters will be separated into partition
* filters and data filters. Partition filters are used for partition pruning and data filters are
* used to reduce the size of the data to be read.
*/
trait SupportsPushDownCatalystFilters {

/**
* Pushes down catalyst Expression filters (which will be separated into partition filters and
* data filters), and returns data filters that need to be evaluated after scanning.
*/
def pushFilters(filters: Seq[Expression]): Seq[Expression]

/**
* Returns the data filters that are pushed to the data source via
* {@link #pushFilters(Expression[])}.
*/
def pushedFilters: Array[Filter]
}
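
Note: a small sketch (not part of this PR) may help illustrate the contract of this new trait from the caller's side. The object and parameter names below are hypothetical; the flow mirrors what PushDownUtils does further down in this diff.

// Illustrative sketch only (not part of this PR): how a planner-side caller
// is expected to drive the new trait.
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
import org.apache.spark.sql.sources.Filter

object FilterPushDownSketch {
  def pushToFileSource(
      builder: SupportsPushDownCatalystFilters,
      filterExprs: Seq[Expression]): (Array[Filter], Seq[Expression]) = {
    // pushFilters internally splits the expressions into partition filters and
    // data filters, and returns the data filters that Spark must still
    // evaluate after the scan.
    val postScanFilters = builder.pushFilters(filterExprs)
    // pushedFilters reports the translated data filters the file format accepted.
    (builder.pushedFilters, postScanFilters)
  }
}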
DataSourceUtils.scala
@@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization
import org.apache.spark.SparkUpgradeException
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils


object DataSourceUtils {
object DataSourceUtils extends PredicateHelper {
/**
* The key to use for storing partitionBy columns as options.
*/
@@ -242,4 +243,22 @@ object DataSourceUtils {
options
}
}

def getPartitionFiltersAndDataFilters(
partitionSchema: StructType,
normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val partitionColumns = normalizedFilters.flatMap { expr =>
expr.collect {
case attr: AttributeReference if partitionSchema.names.contains(attr.name) =>
attr
}
}
val partitionSet = AttributeSet(partitionColumns)
val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
f.references.subsetOf(partitionSet)
)
val extraPartitionFilter =
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
}
}
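
A quick worked example of the split (illustrative only; the schema, attribute names, and literal values are made up, and the call assumes the filters are already normalized):

// Illustrative example (not part of this PR).
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object FilterSplitExample {
  // A table partitioned by string column `p`, with data column `v`.
  val p = AttributeReference("p", StringType)()
  val v = AttributeReference("v", IntegerType)()
  val partitionSchema = StructType(Seq(StructField("p", StringType)))

  val (partitionFilters, dataFilters) =
    DataSourceUtils.getPartitionFiltersAndDataFilters(
      partitionSchema,
      Seq(EqualTo(p, Literal("2021-08-05")), GreaterThan(v, Literal(10))))

  // partitionFilters contains `p = 2021-08-05`, used for partition pruning;
  // dataFilters contains `v > 10`, which the file source may additionally
  // push down to reduce the amount of data read.
}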
PruneFileSourcePartitions.scala
@@ -17,52 +17,24 @@

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan}
import org.apache.spark.sql.types.StructType

/**
* Prune the partitions of file source based table using partition filters. Currently, this rule
* is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]]
* with [[FileScan]].
* is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]].
*
* For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding
* statistics will be updated. And the partition filters will be kept in the filters of returned
* logical plan.
*
* For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to
* its underlying [[FileScan]]. And the partition filters will be removed in the filters of
* returned logical plan.
*/
private[sql] object PruneFileSourcePartitions
extends Rule[LogicalPlan] with PredicateHelper {

private def getPartitionKeyFiltersAndDataFilters(
sparkSession: SparkSession,
relation: LeafNode,
partitionSchema: StructType,
filters: Seq[Expression],
output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
val partitionColumns =
relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
f.references.subsetOf(partitionSet)
)
val extraPartitionFilter =
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))

(ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters)
}

private def rebuildPhysicalOperation(
projects: Seq[NamedExpression],
filters: Seq[Expression],
@@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
fsRelation.sparkSession, logicalRelation, partitionSchema, filters,
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
logicalRelation.output)
val (partitionKeyFilters, _) = DataSourceUtils
.getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)

if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters)
val prunedFsRelation =
fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession)
// Change table stats based on the sizeInBytes of pruned files
@@ -117,23 +91,5 @@
} else {
op
}

case op @ PhysicalOperation(projects, filters,
v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
if filters.nonEmpty =>
val (partitionKeyFilters, dataFilters) =
getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
scan.readPartitionSchema, filters, output)
// The dataFilters are pushed down only once
if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) {
val prunedV2Relation =
v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters))
// The pushed down partition filters don't need to be reevaluated.
val afterScanFilters =
ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation)
} else {
op
}
}
}
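
For context, a rough end-to-end sketch of the surviving HadoopFsRelation path (illustrative only; it assumes an active `spark` session, a made-up table name, and default configs where parquet stays on the v1 read path and catalog-managed file source partitions are enabled):

// Illustrative sketch (not part of this PR); whether this rule fires depends
// on the session configuration noted above.
spark.sql("CREATE TABLE t (v INT, p INT) USING parquet PARTITIONED BY (p)")
spark.sql("INSERT INTO t VALUES (10, 1), (20, 2)")
spark.sql("SELECT * FROM t WHERE p = 1 AND v > 5").explain(true)
// In the optimized plan, the relation's CatalogFileIndex should be replaced by
// a file index pruned to the p=1 partition, while `p = 1` is still kept in the
// Filter above the relation, as the class comment above describes.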
FileScan.scala
@@ -71,12 +71,6 @@ trait FileScan extends Scan
*/
def dataFilters: Seq[Expression]

/**
* Create a new `FileScan` instance from the current one
* with different `partitionFilters` and `dataFilters`
*/
def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan

/**
* If a file with `path` is unsplittable, return the unsplittable reason,
* otherwise return `None`.
FileScanBuilder.scala
@@ -16,19 +16,30 @@
*/
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.SparkSession
import scala.collection.mutable

import org.apache.spark.sql.{sources, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

abstract class FileScanBuilder(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
dataSchema: StructType)
extends ScanBuilder
with SupportsPushDownRequiredColumns
with SupportsPushDownCatalystFilters {
private val partitionSchema = fileIndex.partitionSchema
private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
protected val supportsNestedSchemaPruning = false
protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)
protected var partitionFilters = Seq.empty[Expression]
protected var dataFilters = Seq.empty[Expression]
protected var pushedDataFilters = Array.empty[Filter]

override def pruneColumns(requiredSchema: StructType): Unit = {
// [SPARK-30107] While `requiredSchema` might have pruned nested columns,
@@ -48,7 +59,7 @@ abstract class FileScanBuilder(
StructType(fields)
}

protected def readPartitionSchema(): StructType = {
def readPartitionSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
val fields = partitionSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
@@ -57,6 +68,31 @@
StructType(fields)
}

override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
val (partitionFilters, dataFilters) =
DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters)
this.partitionFilters = partitionFilters
this.dataFilters = dataFilters
val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
for (filterExpr <- dataFilters) {
val translated = DataSourceStrategy.translateFilter(filterExpr, true)
if (translated.nonEmpty) {
translatedFilters += translated.get
}
}
pushedDataFilters = pushDataFilters(translatedFilters.toArray)
dataFilters
}

override def pushedFilters: Array[Filter] = pushedDataFilters

/*
* Push down data filters to the file source, so the data filters can be evaluated there to
* reduce the size of the data to be read. By default, data filters are not pushed down.
* File source needs to implement this method to push down data filters.
*/
protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]

private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet

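
To show how the new `pushDataFilters` hook is meant to be used, here is a hypothetical scan builder for a format that can only evaluate IsNotNull natively. The class name and behavior are invented; the Avro and CSV builders in this diff are the real examples.

// Hypothetical sketch only (not part of this PR).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.{Filter, IsNotNull}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class MyFormatScanBuilder(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    dataSchema: StructType,
    options: CaseInsensitiveStringMap)
  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

  // The filters returned here become `pushedDataFilters`; build() can pass
  // them on to the format's readers. Spark still evaluates all data filters
  // after the scan, so this is purely an optimization.
  override protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] =
    dataFilters.filter(_.isInstanceOf[IsNotNull])

  override def build(): Scan = {
    // Placeholder: a real builder constructs its FileScan here, passing along
    // readDataSchema(), readPartitionSchema(), pushedDataFilters,
    // partitionFilters and dataFilters, as AvroScanBuilder does above.
    throw new UnsupportedOperationException("sketch only")
  }
}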
PushDownUtils.scala
@@ -25,9 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@@ -71,6 +69,9 @@ object PushDownUtils extends PredicateHelper {
}
(r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq)

case f: FileScanBuilder =>
val postScanFilters = f.pushFilters(filters)
(f.pushedFilters, postScanFilters)
case _ => (Nil, filters)
}
}
CSVScan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -84,10 +84,6 @@ case class CSVScan(
dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
}

override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options &&
equivalentFilters(pushedFilters, c.pushedFilters)