From 86162521d71462559cbd5886de491b57020b1518 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 5 Apr 2015 21:45:22 +0800 Subject: [PATCH] Eliminate duplicate filters from pushdown predicates. --- .../apache/spark/sql/jdbc/JDBCRelation.scala | 14 ++- .../sql/sources/DataSourceStrategy.scala | 19 +++- .../apache/spark/sql/sources/interfaces.scala | 1 + .../spark/sql/sources/FilteredScanSuite.scala | 94 ++++++++++++++++++- 4 files changed, 124 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7..806f93a59c3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -25,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -145,4 +146,15 @@ private[sql] case class JDBCRelation( filters, parts) } + + override def supportPredicate(filter: Expression): Boolean = { + filter match { + case e: expressions.EqualTo => true + case e: expressions.LessThan => true + case e: expressions.GreaterThan => true + case e: expressions.LessThanOrEqual => true + case e: expressions.GreaterThanOrEqual => true + case o => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index e13759b7feb7..7ab039ba2ec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -85,8 +85,11 @@ private[sql] object DataSourceStrategy extends Strategy { scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = { val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(expressions.And) + + // eliminate the filters which are supported in the BaseRelation + val filteredPredicates = eliminateFilters(filterPredicates, relation) + val filterSet = AttributeSet(filteredPredicates.flatMap(_.references)) + val filterCondition = filteredPredicates.reduceLeftOption(expressions.And) val pushedFilters = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. @@ -116,6 +119,18 @@ private[sql] object DataSourceStrategy extends Strategy { } } + // Eliminate the filters supported in the BaseRelation + protected[sql] def eliminateFilters( + filters: Seq[Expression], relation: LogicalRelation): Seq[Expression] = { + val newFilters = relation match { + case LogicalRelation(t: PrunedFilteredScan) => + filters.filterNot(t.supportPredicate(_)) + case _ => + filters + } + newFilters + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8f9946a5a801..7a17606e03be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -162,6 +162,7 @@ trait PrunedScan { @DeveloperApi trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] + def supportPredicate(predicate: Expression): Boolean } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 773bd1602d5e..9aa68b534175 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.sources import scala.language.existentials -import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution import org.apache.spark.sql.types._ @@ -41,6 +44,28 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("b", IntegerType, nullable = false) :: StructField("c", StringType, nullable = false) :: Nil) + override def supportPredicate(filter: Expression): Boolean = { + filter match { + case expressions.EqualTo(l: NamedExpression, _) if l.name == "a" => true + case expressions.LessThan(l: NamedExpression, _) if l.name == "a" => true + case expressions.LessThanOrEqual(l: NamedExpression, _) if l.name == "a" => true + case expressions.GreaterThan(l: NamedExpression, _) if l.name == "a" => true + case expressions.GreaterThanOrEqual(l: NamedExpression, _) if l.name == "a" => true + case expressions.InSet(l: 
NamedExpression, _) if l.name == "a" => true + case expressions.IsNull(e: NamedExpression) if e.name == "a" => true + case expressions.IsNotNull(e: NamedExpression) if e.name == "a" => true + case expressions.Not(pred) => supportPredicate(pred) + case expressions.And(left, right) => + supportPredicate(left) && supportPredicate(right) + case expressions.Or(left, right) => + supportPredicate(left) && supportPredicate(right) + case e: expressions.StartsWith => true + case e: expressions.EndsWith => true + case e: expressions.Contains => true + case o => false + } + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) @@ -233,6 +258,73 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + + // Filters should not be duplicated by pushdown predicates + // This query should not have a Filter node and should return 1 row + testNotFilter("SELECT a, b FROM oneToTenFiltered WHERE a = 1", 1) + // This query has a Filter node which filters 10 rows down to 1 row + testFilter("SELECT a, b FROM oneToTenFiltered WHERE b = 2", 10, 1) + + def testNotFilter(sqlString: String, count: Int): Unit = { + test(s"Without Filter: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val filterPlan = queryExecution.executedPlan.collect { + case f: execution.Filter => f + } match { + case Seq(f) => fail(s"Shouldn't find Filter\n$queryExecution") + case _ => + } + + val physicalRDDPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"Can't find PhysicalRDD\n$queryExecution") + } + + val rawCount = physicalRDDPlan.execute().count() + + if (rawCount != count) { + fail( + s"Wrong # of results for filter. 
Got $rawCount, Expected $count\n" + + queryExecution) + } + } + } + + def testFilter(sqlString: String, countRDD: Int, countFilter: Int): Unit = { + test(s"Filter: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val filterPlan = queryExecution.executedPlan.collect { + case f: execution.Filter => f + } match { + case Seq(f) => Some(f) + case _ => fail(s"Can't find Filter\n$queryExecution") + } + + val physicalRDDPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"Can't find PhysicalRDD\n$queryExecution") + } + + val rawCount = filterPlan.get.execute().count() + val rawCountRDD = physicalRDDPlan.execute().count() + + if (rawCount != countFilter) { + fail( + s"Wrong # of results for Filter Plan. Got $rawCount, Expected $countFilter\n" + + queryExecution) + } + if (rawCountRDD != countRDD) { + fail( + s"Wrong # of results for PhysicalRDD Plan. Got $rawCountRDD, Expected $countRDD\n" + + queryExecution) + } + + } + } def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") {