[SPARK-32646][SQL] ORC predicate pushdown should work with case-insensitive analysis

### What changes were proposed in this pull request?

This PR proposes to fix ORC predicate pushdown under case-insensitive analysis. With case-insensitive analysis enabled, the field names in pushed-down predicates no longer need to match the physical field names in ORC files in exact letter case.

### Why are the changes needed?

Currently, ORC predicate pushdown doesn't work with case-insensitive analysis: a predicate such as `a < 0` cannot be pushed down to an ORC file whose physical field name is `A`.

Parquet predicate pushdown already handles this case, so ORC predicate pushdown should work with case-insensitive analysis too.
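
A minimal repro sketch of the scenario (illustrative only: the path, column names, and the `spark.read.schema(...)` pattern are assumptions for this example, not taken from the PR):

```scala
// Illustrative repro (e.g. in spark-shell), assuming the default spark.sql.caseSensitive=false.
// The ORC files store the column with the physical name "A", while the reader declares it
// in lowercase, so the pushed-down predicate refers to "a".
val path = "/tmp/spark-32646-demo"  // hypothetical path
spark.range(10).toDF("A").write.mode("overwrite").orc(path)

val df = spark.read.schema("a LONG").orc(path).filter("a < 5")
// Before this change, the SearchArgument was built from the query-side schema ("a") and did
// not match the file's physical field "A", so the ORC reader could not apply it. After this
// change, the filter is rebuilt per file against the file's own schema, so the SearchArgument
// uses the physical name "A".
df.show()
```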

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, ORC predicate pushdown works under case-insensitive analysis.

### How was this patch tested?

Unit tests.

Closes #29457 from viirya/fix-orc-pushdown.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and cloud-fan committed Aug 21, 2020
1 parent bf221de commit e277ef1
Showing 10 changed files with 243 additions and 60 deletions.
@@ -153,11 +153,6 @@ class OrcFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(dataSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
}
}

val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
// ORC predicate pushdown
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}

val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
@@ -39,6 +39,8 @@ trait OrcFiltersBase
}
}

case class OrcPrimitiveField(fieldName: String, fieldType: DataType)

/**
* This method returns a map which contains ORC field name and data type. Each key
* represents a column; `dots` are used as separators for nested columns. If any part
@@ -49,19 +51,21 @@ trait OrcFiltersBase
*/
protected[sql] def getSearchableTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, DataType] = {
caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

def getPrimitiveFields(
fields: Seq[StructField],
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = {
fields.flatMap { f =>
f.dataType match {
case st: StructType =>
getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
case BinaryType => None
case _: AtomicType =>
Some(((parentFieldNames :+ f.name).quoted, f.dataType))
val fieldName = (parentFieldNames :+ f.name).quoted
val orcField = OrcPrimitiveField(fieldName, f.dataType)
Some((fieldName, orcField))
case _ => None
}
}
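
For illustration, a hedged sketch of what the updated `getSearchableTypeMap` is intended to produce for a nested schema (the exact case-insensitive keying and de-duplication live outside the hunk shown above):

```scala
import org.apache.spark.sql.types._

// Hypothetical schema: a top-level int column plus a nested struct.
val schema = StructType(Seq(
  StructField("cint", IntegerType),
  StructField("a", StructType(Seq(StructField("cint", IntegerType))))))

// Conceptually, the map now keeps the physical field name next to its type,
// using dots to separate nesting levels:
//   "cint"   -> OrcPrimitiveField("cint", IntegerType)
//   "a.cint" -> OrcPrimitiveField("a.cint", IntegerType)
// Under case-insensitive analysis, a lookup such as dataTypeMap("A.CINT") resolves
// to the "a.cint" entry, and the SearchArgument is then built with the physical
// name "a.cint" (see the OrcFilters changes further down).
```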
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
}
}

def readCatalystSchema(
file: Path,
conf: Configuration,
ignoreCorruptFiles: Boolean): Option[StructType] = {
readSchema(file, conf, ignoreCorruptFiles) match {
case Some(schema) =>
Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])

case None =>
        // Field names are empty or `FileFormatException` was thrown but ignoreCorruptFiles is true.
None
}
}

/**
* Reads ORC file schemas in multi-threaded manner, using native version of ORC.
* This is visible for testing.
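
To make the conversion in `readCatalystSchema` above concrete: an ORC `TypeDescription` prints in a `struct<...>` form that Catalyst's DDL parser understands, preserving the physical field casing from the file footer. A small standalone sketch (the schema string is made up for this example):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType

// A schema string in the shape an ORC TypeDescription.toString would produce.
val orcSchemaString = "struct<A:bigint,nested:struct<CInt:int>>"

// Parse it back into a Catalyst StructType, keeping the original field casing.
val fileSchema = CatalystSqlParser.parseDataType(orcSchemaString).asInstanceOf[StructType]
// fileSchema.fieldNames => Array(A, nested)
```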
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,24 +53,39 @@ case class OrcPartitionReaderFactory(
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType) extends FilePartitionReaderFactory {
partitionSchema: StructType,
filters: Array[Filter]) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
private val orcFilterPushDown = sqlConf.orcFilterPushDown
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}

private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}
}

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -116,6 +132,8 @@ case class OrcPartitionReaderFactory(

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -48,7 +48,7 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema)
dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
}

override def equals(obj: Any): Boolean = obj match {
@@ -56,11 +56,6 @@ case class OrcScanBuilder(

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(schema, filters).foreach { f =>
// The pushed filters will be set in `hadoopConf`. After that, we can simply use the
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
}
@@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

def convertibleFilters(
schema: StructType,
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._

@@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -215,7 +215,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
@@ -228,38 +228,44 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
expression match {
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd().isNull(name, getType(name)).end())
Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())

case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot().isNull(name, getType(name)).end())
Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())

case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

case _ => None
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._

import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -586,8 +587,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
// TODO: ORC predicate pushdown should work under case-insensitive analysis.
// assert(actual.count() == 1)
assert(actual.count() == 1)
}
}

@@ -606,5 +606,71 @@
}
}
}

test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
import org.apache.spark.sql.sources._

def getOrcFilter(
schema: StructType,
filters: Seq[Filter],
caseSensitive: String): Option[SearchArgument] = {
var orcFilter: Option[SearchArgument] = None
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
orcFilter =
OrcFilters.createFilter(schema, filters)
}
orcFilter
}

def testFilter(
schema: StructType,
filters: Seq[Filter],
expected: SearchArgument): Unit = {
val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")

assert(caseSensitiveFilters.isEmpty)
assert(caseInsensitiveFilters.isDefined)

assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
(0 until expected.getLeaves().size()).foreach { index =>
assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
}
}

val schema1 = StructType(Seq(StructField("cint", IntegerType)))
testFilter(schema1, Seq(GreaterThan("CINT", 1)),
newBuilder.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema1, Seq(
And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())

// Nested column case
val schema2 = StructType(Seq(StructField("a",
StructType(Seq(StructField("cint", IntegerType))))))

testFilter(schema2, Seq(GreaterThan("A.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("a.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("A.cint", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(
And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())
}
}
