From c03ae069d738c6aa526cc1a1216d079bc8b5ec3e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 13 Jan 2020 12:19:51 +0300 Subject: [PATCH] Put literal filters in front of others --- .../spark/sql/catalyst/csv/CSVFilters.scala | 17 ++++++++++++----- .../sql/catalyst/csv/CSVFiltersSuite.scala | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala index 28a03ab17f0dd..7222f9f050348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala @@ -74,9 +74,8 @@ class CSVFilters( val len = readSchema.fields.length val groupedPredicates = Array.fill[BasePredicate](len)(null) if (SQLConf.get.csvFilterPushDown) { - val groupedExprs = Array.fill(len)(Seq.empty[Expression]) + val groupedFilters = Array.fill(len)(Seq.empty[sources.Filter]) for (filter <- filters) { - val expr = CSVFilters.filterToExpression(filter, toRef) val refs = filter.references val index = if (refs.isEmpty) { // For example, AlwaysTrue and AlwaysFalse doesn't have any references @@ -89,11 +88,19 @@ class CSVFilters( // Accordingly, fieldIndex() returns a valid index always. refs.map(readSchema.fieldIndex).max } - groupedExprs(index) ++= expr + groupedFilters(index) :+= filter + } + if (len > 0 && !groupedFilters(0).isEmpty) { + // We assume that filters w/o refs like AlwaysTrue and AlwaysFalse + // can be evaluated faster than others. We put them in front of others. 
+ val (literals, others) = groupedFilters(0).partition(_.references.isEmpty) + groupedFilters(0) = literals ++ others } for (i <- 0 until len) { - if (!groupedExprs(i).isEmpty) { - val reducedExpr = groupedExprs(i).reduce(And) + if (!groupedFilters(i).isEmpty) { + val reducedExpr = groupedFilters(i) + .flatMap(CSVFilters.filterToExpression(_, toRef)) + .reduce(And) groupedPredicates(i) = Predicate.create(reducedExpr) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala index 956c3e3c9e068..9268877964398 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala @@ -117,7 +117,7 @@ class CSVFiltersSuite extends SparkFunSuite { check(filters = Seq(AlwaysTrue), row = InternalRow(1), pos = 0, skip = false) check(filters = Seq(AlwaysFalse), row = InternalRow(1), pos = 0, skip = true) check( - filters = Seq(sources.LessThan("d", 10), sources.AlwaysFalse), + filters = Seq(sources.EqualTo("i", 1), sources.LessThan("d", 10), sources.AlwaysFalse), requiredSchema = "i INTEGER, d DOUBLE", row = InternalRow(1, 3.14), pos = 0,