SchemaPruning should adhere to the spark.sql.caseSensitive config
sandeep-katta committed Apr 15, 2021
1 parent a153efa commit f6e4b6b
Showing 2 changed files with 49 additions and 8 deletions.
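
To illustrate the behavior this commit targets, here is a minimal standalone sketch (not part of the commit; the object name is made up for the example, and it assumes a spark-catalyst dependency on the classpath with the default spark.sql.caseSensitive=false):

import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.types._

object PruneCaseInsensitiveExample extends App {
  // The data schema declares the column as "ID", but the query requests "id".
  val dataSchema = StructType.fromDDL("ID int, name string")
  val requestedRootFields = Seq(
    SchemaPruning.RootField(field = StructField("id", IntegerType), derivedFromAtt = true))

  // With spark.sql.caseSensitive=false (the default), the pruned schema should keep
  // the "ID" column; before this commit the case-sensitive Set.contains dropped it.
  val pruned = SchemaPruning.pruneDataSchema(dataSchema, requestedRootFields)
  println(pruned) // expected after this change: StructType(StructField("ID", IntegerType) :: Nil)
}
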
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.types._
 
-object SchemaPruning {
+object SchemaPruning extends SQLConfHelper {
   /**
    * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
    * and given requested field are "a", the field "b" is pruned in the returned schema.
@@ -28,6 +29,7 @@ object SchemaPruning {
   def pruneDataSchema(
       dataSchema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
+    val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
@@ -36,7 +38,8 @@
       .reduceLeft(_ merge _)
     val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))
+      ))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
@@ -61,12 +64,16 @@
           sortLeftFieldsByRight(leftValueType, rightValueType),
           containsNull)
       case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val resolver = conf.resolver
+        val filteredRightFieldNames = rightStruct.fieldNames
+          .filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
         val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
-          val rightFieldType = rightStruct(fieldName).dataType
+          val resolvedLeftStruct = leftStruct.filter(p => resolver(p.name, fieldName)).head
+          val leftFieldType = resolvedLeftStruct.dataType
+          val resolvedRightStruct = rightStruct.filter(p => resolver(p.name, fieldName)).head
+          val rightFieldType = resolvedRightStruct.dataType
           val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+          StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
         }
         StructType(sortedLeftFields)
       case _ => left
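
The core of the change above is replacing exact-name lookups (Set.contains, StructType field access by name) with the Resolver obtained from conf.resolver. A Resolver is simply a (String, String) => Boolean. The following self-contained sketch (illustrative names, not Spark source) models how that comparison flips with spark.sql.caseSensitive:

object ResolverSketch extends App {
  // A Resolver decides whether two names refer to the same column.
  type Resolver = (String, String) => Boolean
  val caseSensitiveResolution: Resolver = (a, b) => a == b
  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // conf.resolver returns one of the two depending on spark.sql.caseSensitive.
  def resolverFor(caseSensitiveAnalysis: Boolean): Resolver =
    if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution

  // The pruning filter from the hunk above, reduced to its essence.
  val dataSchemaFieldNames = Set("ID", "name")
  val requestedName = "id"

  val insensitive = resolverFor(caseSensitiveAnalysis = false)
  val sensitive = resolverFor(caseSensitiveAnalysis = true)
  println(dataSchemaFieldNames.exists(insensitive(_, requestedName))) // true: "ID" matches "id"
  println(dataSchemaFieldNames.exists(sensitive(_, requestedName)))   // false: no exact "id"
}
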
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala

@@ -18,9 +18,20 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
-class SchemaPruningSuite extends SparkFunSuite {
+class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
+
+  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
+    requestedFields.map { f =>
+      // `derivedFromAtt` doesn't affect the result of pruned schema.
+      SchemaPruning.RootField(field = f, derivedFromAtt = true)
+    }
+  }
+
   test("prune schema by the requested fields") {
     def testPrunedSchema(
         schema: StructType,
@@ -32,7 +43,6 @@ class SchemaPruningSuite extends SparkFunSuite {
       val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
       assert(expectedSchema == StructType(requestedFields))
     }
-
     testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
     testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
 
@@ -59,4 +69,28 @@
       StructType.fromDDL("e int, f string")))
     testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
   }
+
+  test("SPARK-35096: test case insensitivity of pruned schema ") {
+    Seq(true, false).foreach(isCaseSensitive => {
+      withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
+        if (isCaseSensitive) {
+          val requestedFields = getRootFields(StructField("id", IntegerType))
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"), requestedFields)
+          assert(prunedSchema == StructType(Seq.empty))
+        } else {
+          // Schema is case insensitive
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"),
+            getRootFields(StructField("id", IntegerType)))
+          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
+          // Root fields are insensitive
+          val prunedSchema_1 = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(prunedSchema_1 == StructType(StructField("id", IntegerType) :: Nil))
+        }
+      }
+    })
+  }
 }
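
To verify the new test locally, the suite can be run on its own, for example with build/sbt "catalyst/testOnly *SchemaPruningSuite" (assuming the standard Spark build, where catalyst is the sbt project containing this suite).
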
