Skip to content

Commit

Permalink
Put literal filters in front of others
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxGekk committed Jan 13, 2020
1 parent 4a25815 commit c03ae06
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ class CSVFilters(
val len = readSchema.fields.length
val groupedPredicates = Array.fill[BasePredicate](len)(null)
if (SQLConf.get.csvFilterPushDown) {
val groupedExprs = Array.fill(len)(Seq.empty[Expression])
val groupedFilters = Array.fill(len)(Seq.empty[sources.Filter])
for (filter <- filters) {
val expr = CSVFilters.filterToExpression(filter, toRef)
val refs = filter.references
val index = if (refs.isEmpty) {
// For example, AlwaysTrue and AlwaysFalse don't have any references
Expand All @@ -89,11 +88,19 @@ class CSVFilters(
// Accordingly, fieldIndex() returns a valid index always.
refs.map(readSchema.fieldIndex).max
}
groupedExprs(index) ++= expr
groupedFilters(index) :+= filter
}
if (len > 0 && !groupedFilters(0).isEmpty) {
// We assume that filters w/o refs like AlwaysTrue and AlwaysFalse
// can be evaluated faster than others. We put them in front of others.
val (literals, others) = groupedFilters(0).partition(_.references.isEmpty)
groupedFilters(0) = literals ++ others
}
for (i <- 0 until len) {
if (!groupedExprs(i).isEmpty) {
val reducedExpr = groupedExprs(i).reduce(And)
if (!groupedFilters(i).isEmpty) {
val reducedExpr = groupedFilters(i)
.flatMap(CSVFilters.filterToExpression(_, toRef))
.reduce(And)
groupedPredicates(i) = Predicate.create(reducedExpr)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class CSVFiltersSuite extends SparkFunSuite {
check(filters = Seq(AlwaysTrue), row = InternalRow(1), pos = 0, skip = false)
check(filters = Seq(AlwaysFalse), row = InternalRow(1), pos = 0, skip = true)
check(
filters = Seq(sources.LessThan("d", 10), sources.AlwaysFalse),
filters = Seq(sources.EqualTo("i", 1), sources.LessThan("d", 10), sources.AlwaysFalse),
requiredSchema = "i INTEGER, d DOUBLE",
row = InternalRow(1, 3.14),
pos = 0,
Expand Down

0 comments on commit c03ae06

Please sign in to comment.