diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 6213267c41c64..abbf2d44431cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -17,9 +17,12 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 object SchemaPruning {
+
+  val sqlConf = SQLConf.get
   /**
    * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
    * and given requested field are "a", the field "b" is pruned in the returned schema.
    * Note that schema field ordering at original schema is still preserved in pruned schema.
@@ -28,6 +31,7 @@ object SchemaPruning {
   def pruneDataSchema(
       dataSchema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
+    val resolver = sqlConf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
@@ -36,7 +40,7 @@ object SchemaPruning {
       .reduceLeft(_ merge _)
     val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
@@ -61,12 +65,15 @@ object SchemaPruning {
           sortLeftFieldsByRight(leftValueType, rightValueType),
           containsNull)
       case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val resolver = sqlConf.resolver
+        val filteredRightFieldNames = rightStruct.fieldNames
+          .filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
         val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
+          val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get
+          val leftFieldType = resolvedLeftStruct.dataType
           val rightFieldType = rightStruct(fieldName).dataType
           val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+          StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
         }
         StructType(sortedLeftFields)
       case _ => left
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index c04f59ebb1b1b..7895f4d5ef400 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -18,9 +18,20 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
-class SchemaPruningSuite extends SparkFunSuite {
+class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
+
+  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
+    requestedFields.map { f =>
+      // `derivedFromAtt` doesn't affect the result of the pruned schema.
+      SchemaPruning.RootField(field = f, derivedFromAtt = true)
+    }
+  }
+
   test("prune schema by the requested fields") {
     def testPrunedSchema(
         schema: StructType,
@@ -59,4 +70,34 @@ class SchemaPruningSuite extends SparkFunSuite {
       StructType.fromDDL("e int, f string")))
     testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
   }
+
+  test("SPARK-35096: test case insensitivity of pruned schema") {
+    Seq(true, false).foreach(isCaseSensitive => {
+      withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
+        if (isCaseSensitive) {
+          // Schema is case-sensitive
+          val requestedFields = getRootFields(StructField("id", IntegerType))
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"), requestedFields)
+          assert(prunedSchema == StructType(Seq.empty))
+          // Root fields are case-sensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(Seq.empty))
+        } else {
+          // Schema is case-insensitive
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"),
+            getRootFields(StructField("id", IntegerType)))
+          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
+          // Root fields are case-insensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+        }
+      }
+    })
+  }
 }
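
Note (not part of the patch): the behavioral core of this change is swapping exact string matching for `SQLConf.resolver`, which Spark binds to case-sensitive or case-insensitive string equality depending on `spark.sql.caseSensitive`. The sketch below illustrates that mechanism using the two resolvers exported by the `org.apache.spark.sql.catalyst.analysis` package object; the object name `ResolverSketch` is ours, not Spark's.

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}

object ResolverSketch {
  def main(args: Array[String]): Unit = {
    val dataSchemaFieldNames = Set("ID", "name")

    // Before this patch, pruning used Set.contains, i.e. exact equality:
    // "id" never matches "ID", regardless of spark.sql.caseSensitive.
    assert(!dataSchemaFieldNames.contains("id"))

    // After this patch, matching goes through the configured resolver.
    // Case-sensitive resolver: plain equality, so "id" still does not match "ID".
    assert(!dataSchemaFieldNames.exists(caseSensitiveResolution(_, "id")))

    // Case-insensitive resolver: equalsIgnoreCase, so "id" now matches "ID".
    assert(dataSchemaFieldNames.exists(caseInsensitiveResolution(_, "id")))
  }
}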
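
For an end-to-end view of the pruning entry point itself, the following sketch mirrors the case-insensitive path of the new test. It assumes the catalyst module on the classpath and a default SQLConf (where spark.sql.caseSensitive is false); `PruneUsageSketch` is a hypothetical name.

import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.types._

object PruneUsageSketch {
  def main(args: Array[String]): Unit = {
    // Requesting "id" against a data schema that declares "ID" now keeps the
    // column, and the pruned schema preserves the data schema's original casing.
    val pruned = SchemaPruning.pruneDataSchema(
      StructType.fromDDL("ID int, name String"),
      Seq(SchemaPruning.RootField(field = StructField("id", IntegerType), derivedFromAtt = true)))
    assert(pruned == StructType(StructField("ID", IntegerType) :: Nil))
  }
}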