[SPARK-25557][SQL] Nested column predicate pushdown for ORC
### What changes were proposed in this pull request?

We added nested column predicate pushdown for Parquet in #27728. This patch extends the feature support to ORC.

### Why are the changes needed?

Extending the feature to ORC gives feature parity with Parquet and better performance for queries whose predicates reference nested columns.
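
For illustration, a minimal sketch of the kind of query this benefits (the path and column names are made up for the example): with the default configuration, the filter on `a.b` below can now be pushed into the ORC reader instead of only being applied after the scan.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Write a dataframe with a nested struct column `a` that has a numeric field `b`.
Seq((1, 10L), (2, 20L)).toDF("id", "v")
  .selectExpr("id", "named_struct('b', v) AS a")
  .write.mode("overwrite").orc("/tmp/nested_orc")

// With this change the predicate on `a.b` appears in the scan's PushedFilters
// instead of being evaluated only after the ORC rows have been read.
spark.read.orc("/tmp/nested_orc").filter($"a.b" > 10L).explain()
```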

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests.

Closes #28761 from viirya/SPARK-25557.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
viirya authored and dongjoon-hyun committed Aug 7, 2020
1 parent 6c3d0a4 commit 7b6e1d5
Showing 11 changed files with 460 additions and 310 deletions.
@@ -2108,9 +2108,9 @@ object SQLConf {
.doc("A comma-separated list of data source short names or fully qualified data source " +
"implementation class names for which Spark tries to push down predicates for nested " +
"columns and/or names containing `dots` to data sources. This configuration is only " +
"effective with file-based data source in DSv1. Currently, Parquet implements " +
"both optimizations while ORC only supports predicates for names containing `dots`. The " +
"other data sources don't support this feature yet. So the default value is 'parquet,orc'.")
"effective with file-based data sources in DSv1. Currently, Parquet and ORC implement " +
"both optimizations. The other data sources don't support this feature yet. So the " +
"default value is 'parquet,orc'.")
.version("3.0.0")
.stringConf
.createWithDefault("parquet,orc")
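As a sketch of how this behavior is controlled: the hunk above shows only the doc string and the default value, so the configuration key used below is assumed rather than taken from the diff.

```scala
// Assumed key name; the hunk only shows the doc text and the default "parquet,orc".
val nestedPushdownSources = "spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources"

// Restrict nested predicate pushdown (and pushdown for names containing dots) to Parquet.
spark.conf.set(nestedPushdownSources, "parquet")

// Restore the default, which enables the optimization for both Parquet and ORC.
spark.conf.set(nestedPushdownSources, "parquet,orc")
```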
@@ -668,6 +668,8 @@ abstract class PushableColumnBase {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
def helper(e: Expression): Option[Seq[String]] = e match {
case a: Attribute =>
// An attribute whose name contains a dot (".") is supported only when
// nested predicate pushdown is enabled.
if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
Some(Seq(a.name))
} else {
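A minimal standalone sketch of the gating logic above, outside the real `PushableColumnBase` machinery (the helper name is hypothetical):

```scala
// A top-level attribute whose name literally contains a dot is only treated as pushable
// when nested predicate pushdown is enabled, because only then is the name quoted later
// and therefore distinguishable from a reference to a nested field such as a.b.
def pushableName(name: String, nestedPredicatePushdownEnabled: Boolean): Option[Seq[String]] =
  if (nestedPredicatePushdownEnabled || !name.contains(".")) Some(Seq(name)) else None

pushableName("a.b", nestedPredicatePushdownEnabled = false) // None: ambiguous, no pushdown
pushableName("a.b", nestedPredicatePushdownEnabled = true)  // Some(Seq("a.b")), later quoted as `a.b`
```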
@@ -17,8 +17,11 @@

package org.apache.spark.sql.execution.datasources.orc

import java.util.Locale

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType}

/**
* Methods that can be shared when upgrading the built-in Hive.
@@ -37,12 +40,45 @@ trait OrcFiltersBase {
}

/**
* Return true if this is a searchable type in ORC.
* Both CharType and VarcharType are cleaned at AstBuilder.
* This method returns a map from ORC field names to their data types. Each key
* represents a column; `dots` are used as separators for nested columns. If any part
* of a name contains `dots`, that part is quoted to avoid confusion. See
* `org.apache.spark.sql.connector.catalog.quoted` for implementation details.
*
* BinaryType, UserDefinedType, ArrayType and MapType are ignored.
*/
protected[sql] def isSearchableType(dataType: DataType) = dataType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
protected[sql] def getSearchableTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, DataType] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

def getPrimitiveFields(
fields: Seq[StructField],
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
fields.flatMap { f =>
f.dataType match {
case st: StructType =>
getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
case BinaryType => None
case _: AtomicType =>
Some(((parentFieldNames :+ f.name).quoted, f.dataType))
case _ => None
}
}
}

val primitiveFields = getPrimitiveFields(schema.fields)
if (caseSensitive) {
primitiveFields.toMap
} else {
// Don't try to resolve ambiguity here: if more than one field matches in case-insensitive
// mode, just skip pushdown for those fields, since they will trigger an exception when
// read. See SPARK-25175.
val dedupPrimitiveFields = primitiveFields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.mapValues(_.head._2)
CaseInsensitiveMap(dedupPrimitiveFields)
}
}
}
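To make the new helper concrete, here is the map it is expected to produce for a small schema (the schema is illustrative; since `getSearchableTypeMap` is `protected[sql]`, the call itself is shown only in a comment):

```scala
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("id", IntegerType)
  .add("a", new StructType()
    .add("b", LongType)
    .add("c.d", StringType))
  .add("bin", BinaryType)

// getSearchableTypeMap(schema, caseSensitive = true) should yield roughly:
//   Map(
//     "id"      -> IntegerType,
//     "a.b"     -> LongType,    // nested field, parts joined with a dot
//     "a.`c.d`" -> StringType   // the part containing a dot is back-quoted
//   )
// BinaryType columns (and array/map/UDT columns) are skipped. In case-insensitive mode
// the result is wrapped in a CaseInsensitiveMap, and names that collide after
// lower-casing are dropped entirely (SPARK-25175).
```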
@@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -60,10 +61,8 @@ case class OrcScanBuilder(
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
}
filters
}
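Illustrative effect of the change above on what a DSv2 ORC scan can push (the filter instances are examples, not taken from the diff):

```scala
import org.apache.spark.sql.sources.{Filter, GreaterThan, IsNotNull}

val incoming: Array[Filter] = Array(IsNotNull("a.b"), GreaterThan("a.b", 10L))

// Previously, `filters.filter(!_.containsNestedColumn)` dropped both entries before
// conversion. Now they are handed to OrcFilters.convertibleFilters unchanged, and if
// dataTypeMap contains "a.b" they end up in _pushedFilters and in the ORC SearchArgument.
```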
@@ -22,8 +22,10 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/**
* A helper trait that provides convenient facilities for file-based data source testing.
@@ -103,4 +105,40 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
df: DataFrame, path: File): Unit = {
df.write.mode(SaveMode.Overwrite).format(dataSourceName).save(path.getCanonicalPath)
}

/**
* Takes a single-level `inputDF` dataframe and generates multi-level nested
* dataframes as new test data. It covers both non-nested and nested dataframes,
* which are written and read back with the specified datasource.
*/
protected def withNestedDataFrame(inputDF: DataFrame): Seq[(DataFrame, String, Any => Any)] = {
assert(inputDF.schema.fields.length == 1)
assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
val df = inputDF.toDF("temp")
Seq(
(
df.withColumnRenamed("temp", "a"),
"a", // zero nesting
(x: Any) => x),
(
df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
"a.b", // one level nesting
(x: Any) => Row(x)),
(
df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
"a.b.c", // two level nesting
(x: Any) => Row(Row(x))
),
(
df.withColumnRenamed("temp", "a.b"),
"`a.b`", // zero nesting with column name containing `dots`
(x: Any) => x
),
(
df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
)
}
}
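A hypothetical usage of the new helper from a suite that mixes in this trait (the predicate and values are examples):

```scala
import testImplicits._

withNestedDataFrame(Seq(1, 2, 3).toDF("value")).foreach { case (df, colName, _) =>
  withTempPath { path =>
    // Each generated dataframe is written and read back with the suite's datasource.
    // colName is already quoted where needed (e.g. "`a.b`"), so it can be used directly
    // in a filter expression regardless of the nesting variant.
    df.write.format(dataSourceName).save(path.getCanonicalPath)
    val readBack = spark.read.format(dataSourceName).load(path.getCanonicalPath)
    assert(readBack.filter(s"$colName > 1").count() == 2)
  }
}
```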
@@ -143,4 +143,26 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
FileUtils.copyURLToFile(url, file)
spark.read.orc(file.getAbsolutePath)
}

/**
* Takes a sequence of products `data` and generates multi-level nested
* dataframes as new test data. It covers both non-nested and nested dataframes,
* which are written and read back with the ORC datasource.
*
* This is different from [[withOrcDataFrame]], which does not
* cover nested cases.
*/
protected def withNestedOrcDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T])
(runTest: (DataFrame, String, Any => Any) => Unit): Unit =
withNestedOrcDataFrame(spark.createDataFrame(data))(runTest)

protected def withNestedOrcDataFrame(inputDF: DataFrame)
(runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
withTempPath { file =>
newDF.write.format(dataSourceName).save(file.getCanonicalPath)
readFile(file.getCanonicalPath, true) { df => runTest(df, colName, resultFun) }
}
}
}
}
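A hypothetical test body built on the helper above (the predicate and data are examples):

```scala
test("predicate on nested ORC columns is answered correctly") {
  withNestedOrcDataFrame((1 to 4).map(i => Tuple1(i))) { case (df, colName, _) =>
    // df is the dataframe read back from ORC; colName walks through "a", "a.b", "a.b.c",
    // "`a.b`" and "`a.b`.`c.d`", so one predicate exercises every nesting variant.
    assert(df.where(s"$colName >= 3").count() == 2)
  }
}
```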
@@ -122,34 +122,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared

private def withNestedParquetDataFrame(inputDF: DataFrame)
(runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
assert(inputDF.schema.fields.length == 1)
assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
val df = inputDF.toDF("temp")
Seq(
(
df.withColumnRenamed("temp", "a"),
"a", // zero nesting
(x: Any) => x),
(
df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
"a.b", // one level nesting
(x: Any) => Row(x)),
(
df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
"a.b.c", // two level nesting
(x: Any) => Row(Row(x))
),
(
df.withColumnRenamed("temp", "a.b"),
"`a.b`", // zero nesting with column name containing `dots`
(x: Any) => x
),
(
df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
).foreach { case (newDF, colName, resultFun) =>
withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
withTempPath { file =>
newDF.write.format(dataSourceName).save(file.getCanonicalPath)
readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) }
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._

@@ -68,11 +68,9 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// Combines all convertible filters using `And` to produce a single conjunction
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters))
val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters))
conjunctionOptional.map { conjunction =>
// Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
// The input predicate is fully convertible. There should not be any empty result in the
@@ -228,40 +226,38 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
// call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
// wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
// Since ORC 1.5.0 (ORC-323), column names containing `.` characters must be quoted so
// that they can be distinguished from references to nested columns during pushdown.
expression match {
case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())

case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())

case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())

case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd().isNull(name, getType(name)).end())

case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot().isNull(name, getType(name)).end())

case In(name, values) if isSearchableType(dataTypeMap(name)) =>
case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
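A rough illustration of what a nested predicate now turns into (the schema and filter are examples, and the textual form of the SearchArgument is an approximation):

```scala
import org.apache.spark.sql.sources.GreaterThan
import org.apache.spark.sql.types._

val orcSchema = new StructType().add("a", new StructType().add("b", LongType))
val nestedFilter = GreaterThan("a.b", 10L)

// With dataTypeMap now containing "a.b" -> LongType, the filter passes the
// dataTypeMap.contains(name) guard and OrcFilters.createFilter(orcSchema, Seq(nestedFilter))
// builds approximately: leaf-0 = (LESS_THAN_EQUALS a.b 10), expr = (not leaf-0).
// Before this patch the same filter was discarded because it referenced a nested column.
```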
