Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1040,13 +1040,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
child = addMetadataCol(p.child, requiredAttrIds))
newProj.copyTagsFrom(p)
newProj
case u: Union if u.metadataOutput.exists(a => requiredAttrIds.contains(a.exprId)) =>
u.withNewChildren(u.children.map { child =>
// The children of a Union will have the same attributes with different expression IDs
val exprIdMap = u.metadataOutput.map(_.exprId)
.zip(child.metadataOutput.map(_.exprId)).toMap
addMetadataCol(child, requiredAttrIds.map(a => exprIdMap.getOrElse(a, a)))
})
case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, requiredAttrIds)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,55 +450,23 @@ case class Union(
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
}

/**
 * Combines the attributes that occupy the same ordinal position in every child into a
 * single output attribute. The merged attribute is nullable if any input is nullable,
 * and its data type is the union-compatible merge of all input types. The first
 * child's attribute supplies the name, exprId, qualifier, and metadata.
 */
private def mergeAttributes(attributes: Seq[Attribute]): Attribute = {
  val template = attributes.head
  val anyNullable = attributes.exists(_.nullable)
  val mergedType = attributes.map(_.dataType).reduce(StructType.unionLikeMerge)
  if (template.dataType == mergedType) {
    // Type unchanged: only the nullability may need widening.
    template.withNullability(anyNullable)
  } else {
    AttributeReference(template.name, mergedType, anyNullable, template.metadata)(
      template.exprId, template.qualifier)
  }
}

override def output: Seq[Attribute] = children.map(_.output).transpose.map(mergeAttributes)

override def metadataOutput: Seq[Attribute] = {
val childrenMetadataOutput = children.map(_.metadataOutput)
// This follows similar code in `CheckAnalysis` to check if the output of a Union is correct,
// but just silently doesn't return an output instead of throwing an error. It also ensures
// that the attribute and data type names are the same.
val refDataTypes = childrenMetadataOutput.head.map(_.dataType)
val refAttrNames = childrenMetadataOutput.head.map(_.name)
childrenMetadataOutput.tail.foreach { childMetadataOutput =>
// We can only propagate the metadata output correctly if every child has the same
// number of columns
if (childMetadataOutput.length != refDataTypes.length) return Nil
// Check if the data types match by name and type
val childDataTypes = childMetadataOutput.map(_.dataType)
childDataTypes.zip(refDataTypes).foreach { case (dt1, dt2) =>
if (!DataType.equalsStructurally(dt1, dt2, true) ||
!DataType.equalsStructurallyByName(dt1, dt2, conf.resolver)) {
return Nil
}
}
// Check that the names of the attributes match
val childAttrNames = childMetadataOutput.map(_.name)
childAttrNames.zip(refAttrNames).foreach { case (attrName1, attrName2) =>
if (!conf.resolver(attrName1, attrName2)) {
return Nil
}
// updating nullability to make all the children consistent
override def output: Seq[Attribute] = {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
firstAttr.exprId, firstAttr.qualifier)
}
}
// If the metadata output matches, merge the attributes and return them
childrenMetadataOutput.transpose.map(mergeAttributes)
}

override def metadataOutput: Seq[Attribute] = Nil

override lazy val resolved: Boolean = {
// allChildrenCompatible needs to be evaluated after childrenResolved
def allChildrenCompatible: Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ abstract class InMemoryBaseTable(

// purposely exposes a metadata column that conflicts with a data column in some tests
override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
private lazy val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name)
private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name)

private val allowUnsupportedTransforms =
properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,13 @@

package org.apache.spark.sql.connector

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog, InMemoryTable, MetadataColumn, Table}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

class MetadataColumnSuite extends DatasourceV2SQLBase {
import testImplicits._

// Registers the four test catalogs used by the Union metadata-output tests. Each
// catalog creates tables whose metadata columns deliberately agree with or conflict
// with `MetadataTestTable` (in data type, attribute name, or struct field name).
override protected def beforeEach(): Unit = {
super.beforeEach()
spark.conf.set("spark.sql.catalog.testCatalog", classOf[MetadataTestCatalog].getName)
spark.conf.set("spark.sql.catalog.typeMismatch", classOf[MetadataTypeMismatchCatalog].getName)
spark.conf.set(
"spark.sql.catalog.nameMismatch", classOf[MetadataAttrNameMismatchCatalog].getName)
spark.conf.set(
"spark.sql.catalog.fieldNameMismatch", classOf[MetadataFieldNameMismatchCatalog].getName)
}

private val tbl = "testcat.t"

private def prepareTable(): Unit = {
Expand Down Expand Up @@ -264,272 +249,4 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
}
}
}

test("SPARK-41498: Metadata column is propagated through union") {
  withTable(tbl) {
    prepareTable()
    // The metadata columns (`index`, `_partition`) must still resolve after a self-union.
    val base = spark.table(tbl)
    val unioned = base.union(base).select("id", "data", "index", "_partition")
    val singleSide = Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
    checkAnswer(unioned, singleSide ++ singleSide)
  }
}

test("SPARK-41498: Nested metadata column is propagated through union") {
  withTempDir { dir =>
    // Write everything into a single parquet file so the expected
    // `_metadata.file_path` value is unambiguous.
    spark.range(start = 0, end = 10, step = 1, numPartitions = 1)
      .write.mode("overwrite").save(dir.getAbsolutePath)
    val loaded = spark.read.load(dir.getAbsolutePath)
    val unionQuery = loaded.union(loaded).select("_metadata.file_path")

    val parquetPaths = dir.listFiles(new FilenameFilter {
      override def accept(dir: File, name: String): Boolean = name.endsWith(".parquet")
    }).map(_.getAbsolutePath)
    assert(parquetPaths.length == 1)
    // Ten rows per branch of the union, all originating from the same file.
    val expected = Seq.fill(20)(Row("file:" ++ parquetPaths.head))
    checkAnswer(unionQuery, expected)
  }
}

test("SPARK-41498: Metadata column is not propagated when children of Union " +
  "have metadata output of different size") {
  withTable(tbl) {
    prepareTable()
    withTempDir { dir =>
      spark.range(start = 10, end = 20).selectExpr("bigint(id) as id", "string(id) as data")
        .write.mode("overwrite").save(dir.getAbsolutePath)
      val tableDf = spark.table(tbl)
      val fileDf = spark.read.load(dir.getAbsolutePath)

      // Sanity check: only the file-based side carries a `_metadata` column.
      assert(!tableDf.queryExecution.analyzed.metadataOutput.exists(_.name == "_metadata"))
      assert(fileDf.queryExecution.analyzed.metadataOutput.exists(_.name == "_metadata"))

      // With mismatched metadata widths the union must expose no metadata output.
      assert(tableDf.union(fileDf).queryExecution.analyzed.metadataOutput.isEmpty)
    }
  }
}

test("SPARK-41498: Metadata column is not propagated when children of Union " +
  "have a type mismatch in a metadata column") {
  val tbl = "testCatalog.t"
  val typeMismatchTbl = "typeMismatch.t"
  withTable(tbl, typeMismatchTbl) {
    // Both catalogs expose `_metadata`, but with different struct field types,
    // so the union must not surface any metadata output.
    spark.range(10).write.saveAsTable(tbl)
    spark.range(10).write.saveAsTable(typeMismatchTbl)
    val left = spark.table(tbl)
    val right = spark.table(typeMismatchTbl)
    assert(left.union(right).queryExecution.analyzed.metadataOutput.isEmpty)
  }
}

test("SPARK-41498: Metadata column is not propagated when children of Union " +
  "have an attribute name mismatch in a metadata column") {
  val tbl = "testCatalog.t"
  val nameMismatchTbl = "nameMismatch.t"
  withTable(tbl, nameMismatchTbl) {
    // The two catalogs expose metadata columns under different attribute names,
    // so the union must not surface any metadata output.
    spark.range(10).write.saveAsTable(tbl)
    spark.range(10).write.saveAsTable(nameMismatchTbl)
    val left = spark.table(tbl)
    val right = spark.table(nameMismatchTbl)
    assert(left.union(right).queryExecution.analyzed.metadataOutput.isEmpty)
  }
}

test("SPARK-41498: Metadata column is not propagated when children of Union " +
  "have a field name mismatch in a metadata column") {
  val tbl = "testCatalog.t"
  val fieldNameMismatchTbl = "fieldNameMismatch.t"
  withTable(tbl, fieldNameMismatchTbl) {
    // The metadata structs disagree on an inner field name, so the union
    // must not surface any metadata output.
    spark.range(10).write.saveAsTable(tbl)
    spark.range(10).write.saveAsTable(fieldNameMismatchTbl)
    val left = spark.table(tbl)
    val right = spark.table(fieldNameMismatchTbl)
    assert(left.union(right).queryExecution.analyzed.metadataOutput.isEmpty)
  }
}

// Regression test: a metadata column appended to a Project must go at the END of the
// projection list, otherwise column resolution through views/CTEs with USING joins
// shifts ordinals and breaks the query below.
test("SPARK-41538: Metadata column should be appended at the end of project") {
val tableName = "table_1"
val viewName = "view_1"
withTable(tableName) {
withView(viewName) {
sql(s"CREATE TABLE $tableName (a ARRAY<STRING>, s STRUCT<id: STRING>) USING parquet")
val id = "id1"
sql(s"INSERT INTO $tableName values(ARRAY('a'), named_struct('id', '$id'))")
// The view projects a nested struct field under a new top-level name `id`.
sql(
s"""
|CREATE VIEW $viewName (id)
|AS WITH source AS (
| SELECT * FROM $tableName
|),
|renamed AS (
| SELECT s.id FROM source
|)
|SELECT id FROM renamed
|""".stripMargin)
// Chained FULL OUTER JOIN ... USING(id) over CTEs plus the view exercises the
// metadata-column append position; the single matching row must survive the filter.
val query =
s"""
|with foo AS (
| SELECT '$id' as id
|),
|bar AS (
| SELECT '$id' as id
|)
|SELECT
| 1
|FROM foo
|FULL OUTER JOIN bar USING(id)
|FULL OUTER JOIN $viewName USING(id)
|WHERE foo.id IS NOT NULL
|""".stripMargin
checkAnswer(sql(query), Row(1))
}
}
}

test("SPARK-42331: Fix metadata col can not been resolved") {
  withTable(tbl) {
    prepareTable()

    // Filtering on a metadata column and then selecting it (or a sibling metadata
    // column) must resolve correctly.
    val filtered = spark.table(tbl).where("index = 0")
    checkAnswer(filtered.select("index"), Seq.fill(3)(Row(0)))
    checkAnswer(filtered.select("_partition"), Seq(Row("3/1"), Row("0/2"), Row("1/3")))
  }
}
}

/** In-memory table exposing a single struct-typed metadata column named `_metadata`. */
class MetadataTestTable(
    name: String,
    schema: StructType,
    partitioning: Array[Transform],
    properties: java.util.Map[String, String])
  extends InMemoryTable(name, schema, partitioning, properties) {

  override val metadataColumns: Array[MetadataColumn] = {
    val metadataCol = new MetadataColumn {
      override def name: String = "_metadata"
      override def dataType: DataType = StructType(StructField("index", IntegerType) :: Nil)
      override def comment: String = ""
    }
    Array(metadataCol)
  }
}

/**
 * In-memory table whose `_metadata.index` field is a String (vs Integer in
 * [[MetadataTestTable]]) to provoke a metadata-column type mismatch.
 */
class TypeMismatchTable(
    name: String,
    schema: StructType,
    partitioning: Array[Transform],
    properties: java.util.Map[String, String])
  extends InMemoryTable(name, schema, partitioning, properties) {

  override val metadataColumns: Array[MetadataColumn] = {
    val mismatchedCol = new MetadataColumn {
      override def name: String = "_metadata"
      override def dataType: DataType = StructType(StructField("index", StringType) :: Nil)
      override def comment: String =
        "Used to create a type mismatch with the metadata col in `MetadataTestTable`"
    }
    Array(mismatchedCol)
  }
}

/**
 * In-memory table whose metadata column is named `wrongName` (vs `_metadata` in
 * [[MetadataTestTable]]) to provoke a metadata attribute-name mismatch.
 */
class AttrNameMismatchTable(
    name: String,
    schema: StructType,
    partitioning: Array[Transform],
    properties: java.util.Map[String, String])
  extends InMemoryTable(name, schema, partitioning, properties) {

  override val metadataColumns: Array[MetadataColumn] = {
    val mismatchedCol = new MetadataColumn {
      override def name: String = "wrongName"
      override def dataType: DataType = StructType(StructField("index", IntegerType) :: Nil)
      override def comment: String =
        "Used to create a name mismatch with the metadata col in `MetadataTestTable`"
    }
    Array(mismatchedCol)
  }
}

/**
 * In-memory table whose `_metadata` struct field is named `wrongName` (vs `index` in
 * [[MetadataTestTable]]) to provoke a metadata struct-field-name mismatch.
 */
class FieldNameMismatchTable(
    name: String,
    schema: StructType,
    partitioning: Array[Transform],
    properties: java.util.Map[String, String])
  extends InMemoryTable(name, schema, partitioning, properties) {

  override val metadataColumns: Array[MetadataColumn] = {
    val mismatchedCol = new MetadataColumn {
      override def name: String = "_metadata"
      override def dataType: DataType = StructType(StructField("wrongName", IntegerType) :: Nil)
      override def comment: String =
        "Used to create a name mismatch with the struct field in the metadata col of " +
        "`MetadataTestTable`"
    }
    Array(mismatchedCol)
  }
}

/** Catalog that registers a [[MetadataTestTable]] instead of a plain in-memory table. */
class MetadataTestCatalog extends InMemoryCatalog {
  override def createTable(
      ident: Identifier,
      schema: StructType,
      partitions: Array[Transform],
      properties: java.util.Map[String, String]): Table = {
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

    val qualifiedName = s"$name.${ident.quoted}"
    val table = new MetadataTestTable(qualifiedName, schema, partitions, properties)
    tables.put(ident, table)
    namespaces.putIfAbsent(ident.namespace.toList, Map())
    table
  }
}

/** Catalog that registers a [[TypeMismatchTable]] instead of a plain in-memory table. */
class MetadataTypeMismatchCatalog extends InMemoryCatalog {
  override def createTable(
      ident: Identifier,
      schema: StructType,
      partitions: Array[Transform],
      properties: java.util.Map[String, String]): Table = {
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

    val qualifiedName = s"$name.${ident.quoted}"
    val table = new TypeMismatchTable(qualifiedName, schema, partitions, properties)
    tables.put(ident, table)
    namespaces.putIfAbsent(ident.namespace.toList, Map())
    table
  }
}

/** Catalog that registers an [[AttrNameMismatchTable]] instead of a plain in-memory table. */
class MetadataAttrNameMismatchCatalog extends InMemoryCatalog {
  override def createTable(
      ident: Identifier,
      schema: StructType,
      partitions: Array[Transform],
      properties: java.util.Map[String, String]): Table = {
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

    val qualifiedName = s"$name.${ident.quoted}"
    val table = new AttrNameMismatchTable(qualifiedName, schema, partitions, properties)
    tables.put(ident, table)
    namespaces.putIfAbsent(ident.namespace.toList, Map())
    table
  }
}

/** Catalog that registers a [[FieldNameMismatchTable]] instead of a plain in-memory table. */
class MetadataFieldNameMismatchCatalog extends InMemoryCatalog {
  override def createTable(
      ident: Identifier,
      schema: StructType,
      partitions: Array[Transform],
      properties: java.util.Map[String, String]): Table = {
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

    val qualifiedName = s"$name.${ident.quoted}"
    val table = new FieldNameMismatchTable(qualifiedName, schema, partitions, properties)
    tables.put(ident, table)
    namespaces.putIfAbsent(ident.namespace.toList, Map())
    table
  }
}