Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35096][SQL] SchemaPruning should adhere to spark.sql.caseSensitive config #32194

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.types._

object SchemaPruning {
object SchemaPruning extends SQLConfHelper {
/**
* Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
* and the given requested field is "a", the field "b" is pruned in the returned schema.
Expand All @@ -28,6 +29,7 @@ object SchemaPruning {
def pruneDataSchema(
dataSchema: StructType,
requestedRootFields: Seq[RootField]): StructType = {
val resolver = conf.resolver
// Merge the requested root fields into a single schema. Note the ordering of the fields
// in the resulting schema may differ from their ordering in the logical relation's
// original schema
Expand All @@ -36,7 +38,7 @@ object SchemaPruning {
.reduceLeft(_ merge _)
val dataSchemaFieldNames = dataSchema.fieldNames.toSet
val mergedDataSchema =
StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
Expand All @@ -61,12 +63,15 @@ object SchemaPruning {
sortLeftFieldsByRight(leftValueType, rightValueType),
containsNull)
case (leftStruct: StructType, rightStruct: StructType) =>
val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
val resolver = conf.resolver
val filteredRightFieldNames = rightStruct.fieldNames
.filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
val leftFieldType = leftStruct(fieldName).dataType
val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get
val leftFieldType = resolvedLeftStruct.dataType
val rightFieldType = rightStruct(fieldName).dataType
val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
}
StructType(sortedLeftFields)
case _ => left
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,20 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
import org.apache.spark.sql.types._

class SchemaPruningSuite extends SparkFunSuite {
class SchemaPruningSuite extends SparkFunSuite with SQLHelper {

/**
 * Wraps each requested [[StructField]] in a [[RootField]] so it can be handed to
 * `SchemaPruning.pruneDataSchema`.
 *
 * @param requestedFields the fields the caller wants to keep after pruning
 * @return one `RootField` per requested field, in the same order
 */
def getRootFields(requestedFields: StructField*): Seq[RootField] =
  for (requested <- requestedFields) yield {
    // The value of `derivedFromAtt` has no influence on the pruned schema, so any
    // constant works here.
    RootField(field = requested, derivedFromAtt = true)
  }

test("prune schema by the requested fields") {
def testPrunedSchema(
schema: StructType,
Expand Down Expand Up @@ -59,4 +70,34 @@ class SchemaPruningSuite extends SparkFunSuite {
StructType.fromDDL("e int, f string")))
testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
}

test("SPARK-35096: test case insensitivity of pruned schema") {
  // Exercise pruning once with CASE_SENSITIVE enabled and once with it disabled.
  Seq(true, false).foreach { isCaseSensitive =>
    withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
      // Data schema spells the column "ID", the requested root field spells it "id".
      val prunedSchema = SchemaPruning.pruneDataSchema(
        StructType.fromDDL("ID int, name String"),
        getRootFields(StructField("id", IntegerType)))
      // Data schema spells the column "id", the requested root field spells it "ID".
      val rootFieldsSchema = SchemaPruning.pruneDataSchema(
        StructType.fromDDL("id int, name String"),
        getRootFields(StructField("ID", IntegerType)))

      if (isCaseSensitive) {
        // Names differing only by case must not match: everything is pruned away.
        val emptySchema = StructType(Seq.empty)
        // Schema is case-sensitive
        assert(prunedSchema == emptySchema)
        // Root fields are case-sensitive
        assert(rootFieldsSchema == emptySchema)
      } else {
        // Names differing only by case match; the expected field carries the data
        // schema's spelling, not the requested one.
        // Schema is case-insensitive
        assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
        // Root fields are case-insensitive
        assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
      }
    }
  }
}
}