Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ case class LogicalRelation(
override def withMetadataColumns(): LogicalRelation = {
val newMetadata = metadataOutput.filterNot(outputSet.contains)
if (newMetadata.nonEmpty) {
this.copy(output = output ++ newMetadata)
val newRelation = this.copy(output = output ++ newMetadata)
newRelation.copyTagsFrom(this)
newRelation
} else {
this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ object SchemaPruning extends Rule[LogicalPlan] {
prunedMetadataSchema: StructType) = {
val finalSchema = prunedBaseRelation.schema.merge(prunedMetadataSchema)
val prunedOutput = getPrunedOutput(outputRelation.output, finalSchema)
outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
val prunedRelation = outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
prunedRelation.copyTagsFrom(outputRelation)
prunedRelation
}

// Prune the given output to make it consistent with `requiredSchema`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import java.text.SimpleDateFormat
import org.apache.spark.TestUtils
import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1027,4 +1029,30 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
}
}
}

test("SPARK-43422: Keep tags during optimization when adding metadata columns") {
  withTempPath { dir =>
    spark.range(end = 10).write.format("parquet").save(dir.toString)

    // Tag every LogicalRelation in the analyzed plan of the base DataFrame
    // before any metadata column is selected.
    val markerTag = TreeNodeTag[Boolean]("customTag")
    val df = spark.read.format("parquet").load(dir.toString)
    val taggedRelations = df.queryExecution.analyzed.collect {
      case relation: LogicalRelation => relation.setTagValue(markerTag, true)
    }

    // Sanity check: the plan actually contained at least one LogicalRelation.
    assert(taggedRelations.nonEmpty)

    val metadataDf = df.select("_metadata.row_index")

    // A plan node counts as tagged only when it is a LogicalRelation carrying
    // the marker tag with value true.
    def carriesTag(node: LogicalPlan): Boolean = node match {
      case relation: LogicalRelation => relation.getTagValue(markerTag).getOrElse(false)
      case _ => false
    }

    // The tag must survive both analysis and optimization once a metadata
    // column has been queried.
    assert(metadataDf.queryExecution.analyzed.exists(carriesTag))
    assert(metadataDf.queryExecution.optimizedPlan.exists(carriesTag))
  }
}
}