@@ -33,8 +33,8 @@ object SchemaPruning extends SQLConfHelper {
    * 1. The schema field ordering at original schema is still preserved in pruned schema.
    * 2. The top-level fields are not pruned here.
    */
-  def pruneDataSchema(
-      dataSchema: StructType,
+  def pruneSchema(
+      schema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
     val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
@@ -44,10 +44,10 @@ object SchemaPruning extends SQLConfHelper {
       .map { root: RootField => StructType(Array(root.field)) }
       .reduceLeft(_ merge _)
     val mergedDataSchema =
-      StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
+      StructType(schema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
-    sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
+    sortLeftFieldsByRight(mergedDataSchema, schema).asInstanceOf[StructType]
   }

   /**
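Side note for reviewers: a minimal sketch (not part of the PR) of what the renamed `pruneSchema` does, constructing `RootField`s the same way `SchemaPruningSuite` does below. The schema and variable names are illustrative only.

```scala
import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.types._

// Full schema: one nested struct plus a plain top-level column.
val schema = StructType.fromDDL("info STRUCT<id: BIGINT, name: STRING>, age INT")

// Request only `info.id`. Per the scaladoc above, field ordering is preserved
// and top-level fields (here `age`) are never dropped by this method.
val requested = Seq(SchemaPruning.RootField(
  field = StructField("info", StructType.fromDDL("id BIGINT")),
  derivedFromAtt = false))

SchemaPruning.pruneSchema(schema, requested)
// => struct<info: struct<id: bigint>, age: int>
```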
@@ -342,7 +342,7 @@ case class AttributeReference(
     AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
   }

-  override def withDataType(newType: DataType): Attribute = {
+  override def withDataType(newType: DataType): AttributeReference = {
     AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
   }

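The narrowed return type is what lets the new `getPrunedOutput` further down swap in a pruned data type without a cast. A small sketch, using a hypothetical attribute (not from the PR):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.StructType

val attr = AttributeReference("info", StructType.fromDDL("id BIGINT, name STRING"))()
val prunedType = StructType.fromDDL("id BIGINT")

// Previously typed as Attribute, which forced an asInstanceOf at
// AttributeReference call sites; now the result stays an AttributeReference
// with the same exprId, qualifier, and metadata.
val narrowed: AttributeReference = attr.withDataType(prunedType)
assert(narrowed.exprId == attr.exprId)
```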
@@ -222,7 +222,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {

     if (conf.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) {
       // Prunes nested fields in serializers.
-      val prunedSchema = SchemaPruning.pruneDataSchema(
+      val prunedSchema = SchemaPruning.pruneSchema(
         StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields)
       val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) =>
         pruneSerializer(serializer, prunedSchema(idx).dataType)
@@ -31,7 +31,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
       // `derivedFromAtt` doesn't affect the result of pruned schema.
       SchemaPruning.RootField(field = f, derivedFromAtt = true)
     }
-    val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
+    val prunedSchema = SchemaPruning.pruneSchema(schema, requestedRootFields)
     assert(prunedSchema === expectedSchema)
   }

@@ -140,7 +140,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
     assert(field.metadata.getString("foo") == "bar")

     val schema = StructType(Seq(field))
-    val prunedSchema = SchemaPruning.pruneDataSchema(schema, rootFields)
+    val prunedSchema = SchemaPruning.pruneSchema(schema, rootFields)
     assert(prunedSchema.head.metadata.getString("foo") == "bar")
   }
 }
@@ -213,11 +213,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       val outputSchema = readDataColumns.toStructType
       logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")

-      val metadataStructOpt = requiredAttributes.collectFirst {
+      val metadataStructOpt = l.output.collectFirst {
         case MetadataAttribute(attr) => attr
       }

-      // TODO (yaohua): should be able to prune the metadata struct only containing what needed
       val metadataColumns = metadataStructOpt.map { metadataStruct =>
         metadataStruct.dataType.asInstanceOf[StructType].fields.map { field =>
           MetadataAttribute(field.name, field.dataType)
@@ -31,58 +31,68 @@ import org.apache.spark.sql.util.SchemaUtils._
* By "physical column", we mean a column as defined in the data source format like Parquet format
* or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
* column, and a nested Parquet column corresponds to a [[StructField]].
*
* Also prunes the unnecessary metadata columns if any for all file formats.
*/
object SchemaPruning extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

override def apply(plan: LogicalPlan): LogicalPlan =
if (conf.nestedSchemaPruningEnabled) {
apply0(plan)
} else {
plan
}

private def apply0(plan: LogicalPlan): LogicalPlan =
plan transformDown {
case op @ PhysicalOperation(projects, filters,
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
if canPruneRelation(hadoopFsRelation) =>

prunePhysicalColumns(l.output, projects, filters, hadoopFsRelation.dataSchema,
prunedDataSchema => {
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) =>
prunePhysicalColumns(l, projects, filters, hadoopFsRelation,
(prunedDataSchema, prunedMetadataSchema) => {
val prunedHadoopRelation =
hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession)
buildPrunedRelation(l, prunedHadoopRelation)
buildPrunedRelation(l, prunedHadoopRelation, prunedMetadataSchema)
}).getOrElse(op)
}

   /**
    * This method returns an optional logical plan. `None` is returned if no nested field is
    * required or all nested fields are required.
+   *
+   * This method will prune both the data schema and the metadata schema.
    */
   private def prunePhysicalColumns(
-      output: Seq[AttributeReference],
+      relation: LogicalRelation,
       projects: Seq[NamedExpression],
       filters: Seq[Expression],
-      dataSchema: StructType,
-      leafNodeBuilder: StructType => LeafNode): Option[LogicalPlan] = {
+      hadoopFsRelation: HadoopFsRelation,
+      leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = {

     val (normalizedProjects, normalizedFilters) =
-      normalizeAttributeRefNames(output, projects, filters)
+      normalizeAttributeRefNames(relation.output, projects, filters)
     val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

     // If requestedRootFields includes a nested field, continue. Otherwise,
     // return op
     if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
-      val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
-
-      // If the data schema is different from the pruned data schema, continue. Otherwise,
-      // return op. We effect this comparison by counting the number of "leaf" fields in
-      // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
-      // in dataSchema.
-      if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
-        val prunedRelation = leafNodeBuilder(prunedDataSchema)
-        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+      val prunedDataSchema = if (canPruneDataSchema(hadoopFsRelation)) {
+        pruneSchema(hadoopFsRelation.dataSchema, requestedRootFields)
+      } else {
+        hadoopFsRelation.dataSchema
+      }
+
+      val metadataSchema =
+        relation.output.collect { case MetadataAttribute(attr) => attr }.toStructType
+      val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
+        pruneSchema(metadataSchema, requestedRootFields)
+      } else {
+        metadataSchema
+      }
+
+      // If the data schema is different from the pruned data schema
+      // OR
+      // the metadata schema is different from the pruned metadata schema, continue.
+      // Otherwise, return None.
+      if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
+        countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
+        val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
+        val projectionOverSchema =
+          ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {
@@ -96,9 +106,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
   /**
    * Checks to see if the given relation can be pruned. Currently we support Parquet and ORC v1.
    */
-  private def canPruneRelation(fsRelation: HadoopFsRelation) =
-    fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
-      fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
+  private def canPruneDataSchema(fsRelation: HadoopFsRelation): Boolean =
+    conf.nestedSchemaPruningEnabled && (
+      fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
+      fsRelation.fileFormat.isInstanceOf[OrcFileFormat])

   /**
    * Normalizes the names of the attribute references in the given projects and filters to reflect
@@ -162,29 +173,25 @@
    */
   private def buildPrunedRelation(
       outputRelation: LogicalRelation,
-      prunedBaseRelation: HadoopFsRelation) = {
-    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
-    // also add the metadata output if any
-    // TODO: should be able to prune the metadata schema
-    val metaOutput = outputRelation.output.collect {
-      case MetadataAttribute(attr) => attr
-    }
-    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput ++ metaOutput)
+      prunedBaseRelation: HadoopFsRelation,
+      prunedMetadataSchema: StructType) = {
+    val finalSchema = prunedBaseRelation.schema.merge(prunedMetadataSchema)
+    val prunedOutput = getPrunedOutput(outputRelation.output, finalSchema)
+    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
   }

   // Prune the given output to make it consistent with `requiredSchema`.
   private def getPrunedOutput(
       output: Seq[AttributeReference],
       requiredSchema: StructType): Seq[AttributeReference] = {
-    // We need to replace the expression ids of the pruned relation output attributes
-    // with the expression ids of the original relation output attributes so that
-    // references to the original relation's output are not broken
-    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+    // We need to update the data types of the output attributes to use the pruned
+    // ones, so that references to the original relation's output are not broken.
+    val nameAttributeMap = output.map(att => (att.name, att)).toMap
     requiredSchema
       .toAttributes
       .map {
-        case att if outputIdMap.contains(att.name) =>
-          att.withExprId(outputIdMap(att.name))
+        case att if nameAttributeMap.contains(att.name) =>
+          nameAttributeMap(att.name).withDataType(att.dataType)
         case att => att
       }
   }
@@ -203,6 +210,4 @@ object SchemaPruning extends Rule[LogicalPlan] {
       case _ => 1
     }
   }
-
-
 }
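To make the new continue-condition concrete, here is a worked example (not from the PR) using a local copy of the private `countLeaves`, the tail of which appears just above:

```scala
import org.apache.spark.sql.types._

// Local re-implementation of countLeaves, for illustration only.
def countLeaves(dataType: DataType): Int = dataType match {
  case a: ArrayType => countLeaves(a.elementType)
  case m: MapType => countLeaves(m.keyType) + countLeaves(m.valueType)
  case s: StructType => s.map(f => countLeaves(f.dataType)).sum
  case _ => 1
}

val dataSchema = StructType.fromDDL(
  "name STRING, age INT, info STRUCT<id: BIGINT, university: STRING>")
val prunedDataSchema = StructType.fromDDL(
  "name STRING, age INT, info STRUCT<id: BIGINT>")

// 4 leaves vs 3 leaves: the data schema shrank, so the rule rebuilds the
// relation even when the metadata struct could not be pruned, and vice
// versa, thanks to the new || in the condition.
assert(countLeaves(dataSchema) == 4)
assert(countLeaves(prunedDataSchema) == 3)
```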
@@ -187,7 +187,7 @@ object PushDownUtils extends PredicateHelper {
       case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
         val rootFields = SchemaPruning.identifyRootFields(projects, filters)
         val prunedSchema = if (rootFields.nonEmpty) {
-          SchemaPruning.pruneDataSchema(relation.schema, rootFields)
+          SchemaPruning.pruneSchema(relation.schema, rootFields)
         } else {
           new StructType()
         }
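For orientation before the new tests: roughly the end-user shape of metadata-struct pruning. The path and data columns are hypothetical; `_metadata` is the hidden file-metadata struct, and after this change the scan should carry only the selected leaf field.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Hypothetical Parquet data with columns name, age, and a nested info.id.
// After this PR, FileSourceScanExec.metadataColumns should contain only
// file_name rather than the whole _metadata struct.
val prunedDF = spark.read.parquet("/tmp/data")
  .select($"name", $"age", $"info.id", $"_metadata.file_name")
```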
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import java.text.SimpleDateFormat

 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -384,4 +385,51 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  metadataColumnsTest("prune metadata schema in projects", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_NAME)
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 1)
+    assert(fileSourceScanMetaCols.head.name == "file_name")
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_NAME)),
+        Row("lily", 31, 54321L, f1(METADATA_FILE_NAME)))
+    )
+  }
+
+  metadataColumnsTest("prune metadata schema in filters", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id")
+      .where(col(METADATA_FILE_PATH).contains("data/f0"))
+
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 1)
+    assert(fileSourceScanMetaCols.head.name == "file_path")
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L))
+    )
+  }
+
+  metadataColumnsTest("prune metadata schema in projects and filters", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_SIZE)
+      .where(col(METADATA_FILE_PATH).contains("data/f0"))
+
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 2)
+    assert(fileSourceScanMetaCols.map(_.name).toSet == Set("file_size", "file_path"))
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_SIZE)))
+    )
+  }
 }