diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 27133f0a43f2e..15c0ac7361168 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.DataSourceScanExec.PUSHED_FILTERS
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils, ExecutedCommandExec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -55,10 +56,10 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi
 
   // The access modifier is used to expose this method to tests.
   private[sql] def convertStaticPartitions(
-    sourceAttributes: Seq[Attribute],
-    providedPartitions: Map[String, Option[String]],
-    targetAttributes: Seq[Attribute],
-    targetPartitionSchema: StructType): Seq[NamedExpression] = {
+      sourceAttributes: Seq[Attribute],
+      providedPartitions: Map[String, Option[String]],
+      targetAttributes: Seq[Attribute],
+      targetPartitionSchema: StructType): Seq[NamedExpression] = {
 
     assert(providedPartitions.exists(_._2.isDefined))
 
@@ -290,7 +291,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
   }
 
   // Based on Public API.
-  protected def pruneFilterProject(
+  private def pruneFilterProject(
       relation: LogicalRelation,
       projects: Seq[NamedExpression],
       filterPredicates: Seq[Expression],
@@ -318,11 +319,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
   //     `PrunedFilteredScan` and `HadoopFsRelation`).
   //
   // Note that 2 and 3 shouldn't be used together.
-  protected def pruneFilterProjectRaw(
+  private def pruneFilterProjectRaw(
       relation: LogicalRelation,
       projects: Seq[NamedExpression],
       filterPredicates: Seq[Expression],
-      scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]) = {
+      scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]): SparkPlan = {
 
     val projectSet = AttributeSet(projects.flatMap(_.references))
     val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
@@ -331,8 +332,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
     }}
 
-    val (unhandledPredicates, pushedFilters) =
-      selectFilters(relation.relation, candidatePredicates)
+    val (unhandledPredicates, pushedFilters) = selectFilters(relation.relation, candidatePredicates)
 
     // A set of column attributes that are only referenced by pushed down filters. We can eliminate
     // them from requested columns.
@@ -349,11 +349,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
 
     val metadata: Map[String, String] = {
       val pairs = ArrayBuffer.empty[(String, String)]
-
       if (pushedFilters.nonEmpty) {
         pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]"))
       }
-
       pairs.toMap
     }
 
@@ -500,47 +498,30 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
    * will be pushed down to the data source.
    */
   protected[sql] def selectFilters(
-      relation: BaseRelation,
-      predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = {
-
+      relation: BaseRelation, predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = {
     // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
     // called `predicate`s, while all data source filters of type `sources.Filter` are simply called
     // `filter`s.
 
-    val translated: Seq[(Expression, Filter)] =
-      for {
-        predicate <- predicates
-        filter <- translateFilter(predicate)
-      } yield predicate -> filter
-
     // A map from original Catalyst expressions to corresponding translated data source filters.
-    val translatedMap: Map[Expression, Filter] = translated.toMap
-
-    // Catalyst predicate expressions that cannot be translated to data source filters.
-    val unrecognizedPredicates = predicates.filterNot(translatedMap.contains)
+    // If a predicate is not in this map, it means it cannot be pushed down.
+    val translatedMap: Map[Expression, Filter] = predicates.flatMap { p =>
+      translateFilter(p).map(f => p -> f)
+    }.toMap
 
-    // Data source filters that cannot be handled by `relation`. The semantic of a unhandled filter
-    // at here is that a data source may not be able to apply this filter to every row
-    // of the underlying dataset.
-    val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
-
-    val (unhandled, handled) = translated.partition {
-      case (predicate, filter) =>
-        unhandledFilters.contains(filter)
-    }
+    val pushedFilters: Seq[Filter] = translatedMap.values.toSeq
 
-    // Catalyst predicate expressions that can be translated to data source filters, but cannot be
-    // handled by `relation`.
-    val (unhandledPredicates, _) = unhandled.unzip
+    // Catalyst predicate expressions that cannot be converted to data source filters.
+    val nonconvertiblePredicates = predicates.filterNot(translatedMap.contains)
 
-    // Translated data source filters that can be handled by `relation`
-    val (_, handledFilters) = handled.unzip
-
-    // translated contains all filters that have been converted to the public Filter interface.
-    // We should always push them to the data source no matter whether the data source can apply
-    // a filter to every row or not.
-    val (_, translatedFilters) = translated.unzip
+    // Data source filters that cannot be handled by `relation`. An unhandled filter means
+    // the data source cannot guarantee the rows returned can pass the filter.
+    // As a result we must return it so Spark can plan an extra filter operator.
+    val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
+    val unhandledPredicates = translatedMap.filter { case (p, f) =>
+      unhandledFilters.contains(f)
+    }.keys
 
-    (unrecognizedPredicates ++ unhandledPredicates, translatedFilters)
+    (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters)
   }
 }
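For reviewers, here is a minimal standalone sketch of the control flow the rewritten `selectFilters` implements: build the predicate-to-filter map, push every translated filter, and return the union of non-convertible predicates and predicates whose pushed filter the source cannot fully guarantee. The types and helpers below (`Pred`, `SFilter`, `toFilter`, `cannotGuarantee`, `SelectFiltersSketch`) are invented stand-ins for Catalyst expressions, `sources.Filter`, `translateFilter` and `BaseRelation.unhandledFilters`; this is not Spark code.

```scala
// Self-contained sketch of the new selectFilters logic; all names here are
// made-up stand-ins, not Spark's Catalyst or data source APIs.
object SelectFiltersSketch {

  // Stand-in for a Catalyst predicate expression.
  final case class Pred(sql: String, translatable: Boolean)

  // Stand-in for a data source filter (sources.Filter).
  final case class SFilter(sql: String)

  // Stand-in for translateFilter: None means the predicate cannot be pushed down.
  def toFilter(p: Pred): Option[SFilter] =
    if (p.translatable) Some(SFilter(p.sql)) else None

  // Stand-in for relation.unhandledFilters: filters the source applies only best-effort.
  def cannotGuarantee(filters: Seq[SFilter]): Set[SFilter] =
    filters.filter(_.sql.contains("LIKE")).toSet

  // Mirrors the rewritten selectFilters: returns (predicates Spark must still
  // evaluate after the scan, filters pushed down to the data source).
  def selectFilters(predicates: Seq[Pred]): (Seq[Pred], Seq[SFilter]) = {
    // Predicates that translate to a data source filter; anything missing here
    // cannot be pushed down at all.
    val translatedMap: Map[Pred, SFilter] =
      predicates.flatMap(p => toFilter(p).map(f => p -> f)).toMap

    // Every translated filter is pushed, even if the source applies it best-effort only.
    val pushedFilters: Seq[SFilter] = translatedMap.values.toSeq

    // Predicates with no data source equivalent.
    val nonconvertible: Seq[Pred] = predicates.filterNot(translatedMap.contains)

    // Predicates whose pushed filter the source cannot fully guarantee; Spark must
    // re-evaluate them in an extra Filter operator above the scan.
    val unhandled: Set[SFilter] = cannotGuarantee(pushedFilters)
    val unhandledPreds: Seq[Pred] =
      translatedMap.collect { case (p, f) if unhandled(f) => p }.toSeq

    (nonconvertible ++ unhandledPreds, pushedFilters)
  }

  def main(args: Array[String]): Unit = {
    val preds = Seq(
      Pred("a = 1", translatable = true),
      Pred("name LIKE 'x%'", translatable = true),
      Pred("udf(b) > 0", translatable = false))
    val (postScan, pushed) = selectFilters(preds)
    println(s"pushed down: ${pushed.map(_.sql)}")   // both translatable filters
    println(s"post-scan:   ${postScan.map(_.sql)}") // the UDF predicate and the LIKE predicate
  }
}
```

The key behavioral point matches the new comments in the diff: a filter that is translated but unhandled is still pushed to the source (it may prune some rows early), while its originating predicate is also returned so the planner adds a Filter operator that re-checks every surviving row.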