diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
index bb840e69d99a3..5516c736e94ed 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -36,7 +36,8 @@ case class AvroScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -54,6 +55,9 @@ case class AvroScan(
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
 
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
+
   override def equals(obj: Any): Boolean = obj match {
     case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema &&
       options == a.options
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 7fd154ccac445..e43662f37ccbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -28,20 +28,25 @@ import org.apache.spark.sql.types.StructType
 
 private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
 
-  private def getPartitionKeyFilters(
+  private def getPartitionKeyFiltersAndDataFilters(
       sparkSession: SparkSession,
       relation: LeafNode,
       partitionSchema: StructType,
       filters: Seq[Expression],
-      output: Seq[AttributeReference]): ExpressionSet = {
+      output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
       f.references.subsetOf(partitionSet)
     })
+
+    val dataFilters =
+      normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty)
+
+    (partitionKeyFilters, dataFilters)
   }
 
   private def rebuildPhysicalOperation(
@@ -72,7 +77,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
             _,
             _))
         if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val partitionKeyFilters = getPartitionKeyFilters(
+      val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
         fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
@@ -92,11 +97,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
     case op @ PhysicalOperation(projects, filters,
        v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
        if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
-      val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
-        v2Relation, scan.readPartitionSchema, filters, output)
-      if (partitionKeyFilters.nonEmpty) {
+      val (partitionKeyFilters, dataFilters) =
+        getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
+          scan.readPartitionSchema, filters, output)
+      if (partitionKeyFilters.nonEmpty || dataFilters.nonEmpty) {
        val prunedV2Relation =
-          v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
+          v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq)
+            .withDataFilters(dataFilters))
        // The pushed down partition filters don't need to be reevaluated.
        val afterScanFilters =
          ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
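For readers tracing the new helper: the normalized filters are split into two disjoint groups. Predicates that reference only partition columns become partition-key filters, and predicates that reference no partition column at all become data filters; mixed predicates fall into neither group. Below is a minimal, self-contained sketch of that splitting rule, using plain Scala collections in place of Catalyst's Expression and AttributeSet types (the Pred class and the column names are made up for illustration):

  // Simplified stand-in for a Catalyst predicate: just the columns it references.
  case class Pred(sql: String, references: Set[String])

  def splitFilters(filters: Seq[Pred], partitionCols: Set[String]): (Seq[Pred], Seq[Pred]) = {
    // Usable for partition pruning: every referenced column is a partition column.
    val partitionKeyFilters = filters.filter(_.references.subsetOf(partitionCols))
    // Usable as data filters for file listing: no partition column is referenced.
    val dataFilters = filters.filter(_.references.intersect(partitionCols).isEmpty)
    (partitionKeyFilters, dataFilters)
  }

  // "p = 1" prunes partitions, "id > 10" becomes a data filter,
  // and the mixed predicate "p = id" lands in neither group.
  val (pk, data) = splitFilters(
    Seq(Pred("p = 1", Set("p")), Pred("id > 10", Set("id")), Pred("p = id", Set("p", "id"))),
    partitionCols = Set("p"))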
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index a22e1ccfe4515..01458d63cd463 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -65,6 +65,16 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
    */
   def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan
 
+  /**
+   * Returns the filters that can be used for file listing.
+   */
+  def dataFilters: Seq[Expression]
+
+  /**
+   * Create a new `FileScan` instance from the current one with different `dataFilters`.
+   */
+  def withDataFilters(dataFilters: Seq[Expression]): FileScan
+
   /**
    * If a file with `path` is unsplittable, return the unsplittable reason,
    * otherwise return `None`.
@@ -79,7 +89,8 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan =>
       fileIndex == f.fileIndex && readSchema == f.readSchema &&
-        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)
+        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
+        ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
 
     case _ => false
   }
@@ -92,6 +103,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
     val metadata: Map[String, String] = Map(
       "ReadSchema" -> readDataSchema.catalogString,
       "PartitionFilters" -> seqToString(partitionFilters),
+      "DataFilters" -> seqToString(dataFilters),
       "Location" -> locationDesc)
     val metadataStr = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -103,7 +115,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
   }
 
   protected def partitions: Seq[FilePartition] = {
-    val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
+    val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
     val partitionAttributes = fileIndex.partitionSchema.toAttributes
     val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
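The user-visible effect of the FileScan changes: `partitions` now forwards the data filters to `fileIndex.listFiles`, and the scan description gains a DataFilters entry next to PartitionFilters. A rough way to observe this from spark-shell, assuming the V2 file sources are enabled via the existing `spark.sql.sources.useV1SourceList` setting (the path and column names below are made up):

  // Route parquet through the DataSource V2 read path.
  spark.conf.set("spark.sql.sources.useV1SourceList", "")

  spark.range(100)
    .withColumn("p", $"id" % 4)
    .write.mode("overwrite").partitionBy("p").parquet("/tmp/datafilters_demo")

  // "p = 1" should show up under PartitionFilters and "id > 50" under DataFilters
  // in the scan node printed by explain().
  spark.read.parquet("/tmp/datafilters_demo")
    .where($"p" === 1 && $"id" > 50)
    .explain()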
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 78b04aa811e09..cc22036bb7a45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -38,7 +38,8 @@ case class CSVScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
@@ -92,6 +93,9 @@ case class CSVScan(
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
 
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
+
   override def equals(obj: Any): Boolean = obj match {
     case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema &&
       options == c.options
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
index 153b402476c40..42b4c6a63b6d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
@@ -39,7 +39,8 @@ case class JsonScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val parsedOptions = new JSONOptionsInRead(
@@ -91,6 +92,9 @@ case class JsonScan(
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
 
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
+
   override def equals(obj: Any): Boolean = obj match {
     case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema &&
       options == j.options
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index f0595cb6d09c3..9f582c62624e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -38,7 +38,8 @@ case class OrcScan(
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -66,4 +67,7 @@ case class OrcScan(
 
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
+
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 44179e2e42a4c..d2db3813db2df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -41,7 +41,8 @@ case class ParquetScan(
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -94,4 +95,7 @@ case class ParquetScan(
 
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
+
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
 }
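One distinction worth keeping in mind when reading the Orc and Parquet changes: those scans now carry both `pushedFilters` and `dataFilters`, and the two fields serve different consumers. `pushedFilters` holds public `org.apache.spark.sql.sources.Filter` instances handed to the readers for predicate pushdown, while the new `dataFilters` keeps the original Catalyst `Expression`s so the file index can use them when listing files. A small illustration of the two representations of one predicate (the column name is made up):

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, GreaterThan => CatalystGreaterThan}
  import org.apache.spark.sql.sources.GreaterThan
  import org.apache.spark.sql.types.IntegerType

  // Catalyst form: what partitionFilters and the new dataFilters hold.
  val idAttr = AttributeReference("id", IntegerType)()
  val asExpression = CatalystGreaterThan(idAttr, Literal(10))

  // Data source Filter form: what pushedFilters holds for Orc and Parquet.
  val asSourceFilter = GreaterThan("id", 10)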
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
index cf6595e5c126c..bb0d480867b70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
@@ -36,7 +36,8 @@ case class TextScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val optionsAsScala = options.asScala.toMap
@@ -73,6 +74,9 @@ case class TextScan(
   override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
     this.copy(partitionFilters = partitionFilters)
 
+  override def withDataFilters(dataFilters: Seq[Expression]): FileScan =
+    this.copy(dataFilters = dataFilters)
+
   override def equals(obj: Any): Boolean = obj match {
     case t: TextScan => super.equals(t) && options == t.options
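Since `FileScan.partitions` now passes `dataFilters` to `FileIndex.listFiles` instead of `Seq.empty`, a file index can finally see data predicates at listing time. A hypothetical sketch of a delegating index that takes advantage of this; the `FilterAwareFileIndex` name is made up, and a realistic implementation would consult per-file statistics rather than only checking for a literal-false filter:

  import org.apache.hadoop.fs.Path
  import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
  import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
  import org.apache.spark.sql.types.StructType

  class FilterAwareFileIndex(delegate: FileIndex) extends FileIndex {
    override def rootPaths: Seq[Path] = delegate.rootPaths
    override def partitionSchema: StructType = delegate.partitionSchema
    override def inputFiles: Array[String] = delegate.inputFiles
    override def refresh(): Unit = delegate.refresh()
    override def sizeInBytes: Long = delegate.sizeInBytes

    override def listFiles(
        partitionFilters: Seq[Expression],
        dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
      // If any data filter is the literal `false`, no file can match, so skip listing.
      if (dataFilters.contains(Literal(false))) Seq.empty
      else delegate.listFiles(partitionFilters, dataFilters)
    }
  }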