diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 310626197a763..f437e0025de7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -62,6 +62,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.eq(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -88,6 +93,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.notEq(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -111,6 +121,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.lt(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.lt(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -134,6 +149,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.ltEq(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.ltEq(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -157,6 +177,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.gt(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.gt(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -180,6 +205,11 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) {
       (n: String, v: Any) => FilterApi.gtEq(
         intColumn(n),
         Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+    case decimal: DecimalType if DecimalType.isByteArrayDecimalType(decimal) =>
+      (n: String, v: Any) => FilterApi.gtEq(
+        binaryColumn(n),
+        Option(v).map(d => Binary.fromReusedByteArray(new Array[Byte](8) ++
+          d.asInstanceOf[java.math.BigDecimal].unscaledValue().toByteArray)).orNull)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 90da7eb8c4fb5..d2333abd95c82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
+import java.math.BigInteger
 import java.nio.charset.StandardCharsets
 import java.sql.Date
 
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
-
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -359,6 +359,41 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 
+  test("filter pushdown - decimal(ByteArrayDecimalType)") {
+    val one = new java.math.BigDecimal(1)
+    val two = new java.math.BigDecimal(2)
+    val three = new java.math.BigDecimal(3)
+    val four = new java.math.BigDecimal(4)
+
+    val data = Seq(one, two, three, four)
+
+    withParquetDataFrame(data.map(Tuple1(_))) { implicit df =>
+      checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+      checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i)))
+
+      checkFilterPredicate('_1 === one, classOf[Eq[_]], one)
+      checkFilterPredicate('_1 <=> one, classOf[Eq[_]], one)
+      checkFilterPredicate('_1 =!= one, classOf[NotEq[_]],
+        (2 to 4).map(i => Row.apply(new java.math.BigDecimal(i))))
+
+      checkFilterPredicate('_1 < two, classOf[Lt[_]], one)
+      checkFilterPredicate('_1 > three, classOf[Gt[_]], four)
+      checkFilterPredicate('_1 <= one, classOf[LtEq[_]], one)
+      checkFilterPredicate('_1 >= four, classOf[GtEq[_]], four)
+
+      checkFilterPredicate(Literal(one) === '_1, classOf[Eq[_]], one)
+      checkFilterPredicate(Literal(one) <=> '_1, classOf[Eq[_]], one)
+      checkFilterPredicate(Literal(two) > '_1, classOf[Lt[_]], one)
+      checkFilterPredicate(Literal(three) < '_1, classOf[Gt[_]], four)
+      checkFilterPredicate(Literal(one) >= '_1, classOf[LtEq[_]], one)
+      checkFilterPredicate(Literal(four) <= '_1, classOf[GtEq[_]], four)
+
+      checkFilterPredicate(!('_1 < four), classOf[GtEq[_]], four)
+      checkFilterPredicate('_1 < two || '_1 > three, classOf[Operators.Or],
+        Seq(Row(one), Row(four)))
+    }
+  }
+
   test("SPARK-6554: don't push down predicates which reference partition columns") {
     import testImplicits._
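
A note on the value conversion shared by every branch above (an editor's sketch, not part of the patch): each pushed-down decimal literal is encoded by taking the two's-complement byte array of the BigDecimal's unscaled value and prefixing eight zero bytes before wrapping the result in a Parquet Binary. The standalone snippet below traces that conversion and builds the predicate the makeEq branch would produce; the column name "price" and the literal 123.45 are invented for illustration.

    import java.math.BigDecimal

    import org.apache.parquet.filter2.predicate.FilterApi
    import org.apache.parquet.io.api.Binary

    object DecimalPushDownSketch {
      def main(args: Array[String]): Unit = {
        val d = new BigDecimal("123.45")
        // Two's-complement bytes of the unscaled value 12345: [0x30, 0x39].
        val unscaled = d.unscaledValue().toByteArray
        // As in the patch, prefix eight zero bytes, yielding a 10-byte array here.
        val padded = new Array[Byte](8) ++ unscaled
        // The kind of predicate makeEq builds for a binary-backed decimal column.
        val pred = FilterApi.eq(FilterApi.binaryColumn("price"),
          Binary.fromReusedByteArray(padded))
        println(pred)
      }
    }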