Skip to content

Commit

Permalink
Support pushing ByteArrayDecimalType filters down to the data sources
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Jun 13, 2018
1 parent e76b012 commit 9606670
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
Expand Up @@ -62,6 +62,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.eq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand All @@ -88,6 +93,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.notEq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand All @@ -111,6 +121,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.lt(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.lt(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand All @@ -134,6 +149,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.ltEq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.ltEq(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand All @@ -157,6 +177,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.gt(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.gt(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand All @@ -180,6 +205,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
(n: String, v: Any) => FilterApi.gtEq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
(n: String, v: Any) => FilterApi.gtEq(
binaryColumn(n),
Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
v.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
}

/**
Expand Down
Expand Up @@ -17,13 +17,13 @@

package org.apache.spark.sql.execution.datasources.parquet

import java.math.BigInteger
import java.nio.charset.StandardCharsets
import java.sql.Date

import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -359,6 +359,41 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}

// Verifies that comparison predicates on decimals stored as byte arrays
// (ByteArrayDecimalType, i.e. precision too large for INT32/INT64 physical
// storage) are pushed down to Parquet and still return correct rows.
test("filter pushdown - decimal(ByteArrayDecimalType)") {
// Four distinct values so each comparison operator selects a proper subset.
// NOTE(review): only small non-negative unscaled values are exercised; the
// pushdown path shown in this change prepends zero bytes to
// unscaledValue().toByteArray, which may mis-handle negative decimals —
// negative-value coverage should be confirmed/added.
val one = new java.math.BigDecimal(1)
val two = new java.math.BigDecimal(2)
val three = new java.math.BigDecimal(3)
val four = new java.math.BigDecimal(4)

val data = Seq(one, two, three, four)

withParquetDataFrame(data.map(Tuple1(_))) { implicit df =>
// Null checks: no nulls were written, so isNull matches nothing and
// isNotNull matches every row.
checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i)))

// Equality, null-safe equality, and inequality: each asserts both the
// generated Parquet predicate class and the surviving rows.
checkFilterPredicate('_1 === one, classOf[Eq[_]], one)
checkFilterPredicate('_1 <=> one, classOf[Eq[_]], one)
checkFilterPredicate('_1 =!= one, classOf[NotEq[_]],
(2 to 4).map(i => Row.apply(new java.math.BigDecimal(i))))

// Ordering comparisons with the column on the left.
checkFilterPredicate('_1 < two, classOf[Lt[_]], one)
checkFilterPredicate('_1 > three, classOf[Gt[_]], four)
checkFilterPredicate('_1 <= one, classOf[LtEq[_]], one)
checkFilterPredicate('_1 >= four, classOf[GtEq[_]], four)

// Same comparisons with the literal on the left: the translation must
// flip the operator (e.g. Literal(two) > '_1 becomes a Lt predicate).
checkFilterPredicate(Literal(one) === '_1, classOf[Eq[_]], one)
checkFilterPredicate(Literal(one) <=> '_1, classOf[Eq[_]], one)
checkFilterPredicate(Literal(two) > '_1, classOf[Lt[_]], one)
checkFilterPredicate(Literal(three) < '_1, classOf[Gt[_]], four)
checkFilterPredicate(Literal(one) >= '_1, classOf[LtEq[_]], one)
checkFilterPredicate(Literal(four) <= '_1, classOf[GtEq[_]], four)

// Negation and disjunction: NOT(< four) becomes GtEq; the OR of two
// range predicates must survive pushdown as an Operators.Or node.
checkFilterPredicate(!('_1 < four), classOf[GtEq[_]], four)
checkFilterPredicate('_1 < two || '_1 > three, classOf[Operators.Or],
Seq(Row(one), Row(four)))
}
}

test("SPARK-6554: don't push down predicates which reference partition columns") {
import testImplicits._

Expand Down

0 comments on commit 9606670

Please sign in to comment.