Skip to content

Commit

Permalink
StringStartsWith support push down
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Jun 23, 2018
1 parent 15747cf commit 5b52ace
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import java.sql.Date
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.io.api.Binary
import org.apache.parquet.schema.PrimitiveComparator

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Some utility function to convert Spark data source filters to Parquet filters.
Expand Down Expand Up @@ -270,6 +272,29 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
case sources.Not(pred) =>
createFilter(schema, pred).map(FilterApi.not)

case sources.StringStartsWith(name, prefix) if canMakeFilterOn(name) =>
Option(prefix).map { v =>
FilterApi.userDefined(binaryColumn(name),
new UserDefinedPredicate[Binary] with Serializable {
private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
private val size = strToBinary.length

override def canDrop(statistics: Statistics[Binary]): Boolean = {
val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
val max = statistics.getMax
val min = statistics.getMin
comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 ||
comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0
}

override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false

override def keep(value: Binary): Boolean =
UTF8String.fromBytes(value.getBytes).startsWith(UTF8String.fromString(v))
}
)
}

case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
assert(df.where("col > 0").count() === 2)
}
}

test("filter pushdown - StringStartsWith") {
withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df =>
Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix =>
checkFilterPredicate(
'_1.startsWith(prefix).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"2str2")
}

Seq("2S", "null", "2str22").foreach { prefix =>
checkFilterPredicate(
'_1.startsWith(prefix).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq.empty[Row])
}

assertResult(None) {
parquetFilters.createFilter(
df.schema,
sources.StringStartsWith("_1", null))
}
}
}
}

class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
Expand Down

0 comments on commit 5b52ace

Please sign in to comment.