From 5b8c505aa10327997a70783d867d4daf2baedcc5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 7 Sep 2022 18:45:20 +0800 Subject: [PATCH] [SPARK-40149][SQL] Propagate metadata columns through Project This PR fixes a regression caused by https://github.com/apache/spark/pull/32017 . In https://github.com/apache/spark/pull/32017 , we tried to be more conservative and decided to not propagate metadata columns in certain operators, including `Project`. However, the decision was made only considering SQL API, not DataFrame API. In fact, it's very common to chain `Project` operators in DataFrame, e.g. `df.withColumn(...).withColumn(...)...`, and it's very inconvenient if metadata columns are not propagated through `Project`. This PR makes 2 changes: 1. Project should propagate metadata columns 2. SubqueryAlias should only propagate metadata columns if the child is a leaf node or also a SubqueryAlias The second change is needed to still forbid weird queries like `SELECT m from (SELECT a from t)`, which is the main motivation of https://github.com/apache/spark/pull/32017 . After propagating metadata columns, a problem from https://github.com/apache/spark/pull/31666 is exposed: the natural join metadata columns may confuse the analyzer and lead to wrong analyzed plan. For example, `SELECT t1.value FROM t1 LEFT JOIN t2 USING (key) ORDER BY key`, how shall we resolve `ORDER BY key`? It should be resolved to `t1.key` via the rule `ResolveMissingReferences`, which is in the output of the left join. However, if `Project` can propagate metadata columns, `ORDER BY key` will be resolved to `t2.key`. To solve this problem, this PR only allows qualified access for metadata columns of natural join. This has no breaking change, as people can only do qualified access for natural join metadata columns before, in the `Project` right after `Join`. This actually enables more use cases, as people can now access natural join metadata columns in ORDER BY. I've added a test for it. fix a regression For SQL API, there is no change, as a `SubqueryAlias` always comes with a `Project` or `Aggregate`, so we still don't propagate metadata columns through a SELECT group. For DataFrame API, the behavior becomes more lenient. The only breaking case is an operator that can propagate metadata columns then follows a `SubqueryAlias`, e.g. `df.filter(...).as("t").select("t.metadata_col")`. But this is a weird use case and I don't think we should support it at the first place. new tests Closes #37758 from cloud-fan/metadata. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan (cherry picked from commit 99ae1d9a897909990881f14c5ea70a0d1a0bf456) Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../sql/catalyst/expressions/package.scala | 13 +- .../plans/logical/basicLogicalOperators.scala | 13 +- .../spark/sql/catalyst/util/package.scala | 15 +- .../resources/sql-tests/inputs/using-join.sql | 2 + .../sql-tests/results/using-join.sql.out | 11 + .../sql/connector/DataSourceV2SQLSuite.scala | 218 ----------------- .../sql/connector/MetadataColumnSuite.scala | 219 ++++++++++++++++++ 9 files changed, 263 insertions(+), 238 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 30aaaa5184d3a..8d6261a784753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1043,9 +1043,11 @@ class Analyzer(override val catalogManager: CatalogManager) private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match { case r: DataSourceV2Relation => r.withMetadataColumns() case p: Project => - p.copy( + val newProj = p.copy( projectList = p.metadataOutput ++ p.projectList, child = addMetadataCol(p.child)) + newProj.copyTagsFrom(p) + newProj case _ => plan.withNewChildren(plan.children.map(addMetadataCol)) } } @@ -3480,8 +3482,8 @@ class Analyzer(override val catalogManager: CatalogManager) val project = Project(projectList, Join(left, right, joinType, newCondition, hint)) project.setTagValue( Project.hiddenOutputTag, - hiddenList.map(_.markAsSupportsQualifiedStar()) ++ - project.child.metadataOutput.filter(_.supportsQualifiedStar)) + hiddenList.map(_.markAsQualifiedAccessOnly()) ++ + project.child.metadataOutput.filter(_.qualifiedAccessOnly)) project } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 9db038dbf350b..cd02b03e2d00e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -386,7 +386,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (target.isEmpty) return input.output // If there is a table specified, use hidden input attributes as well - val hiddenOutput = input.metadataOutput.filter(_.supportsQualifiedStar) + val hiddenOutput = input.metadataOutput.filter(_.qualifiedAccessOnly) val expandedAttributes = (hiddenOutput ++ input.output).filter( matchedQualifier(_, target.get, resolver)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 6a4fb099c8b78..7913f396120f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -23,6 +23,7 @@ import com.google.common.collect.Maps import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.types.{StructField, StructType} /** @@ -265,7 +266,7 @@ package object expressions { case (Seq(), _) => val name = nameParts.head val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) - (attributes, nameParts.tail) + (attributes.filterNot(_.qualifiedAccessOnly), nameParts.tail) case _ => matches } } @@ -314,10 +315,12 @@ package object expressions { var i = nameParts.length - 1 while (i >= 0 && candidates.isEmpty) { val name = nameParts(i) - candidates = collectMatches( - name, - nameParts.take(i), - direct.get(name.toLowerCase(Locale.ROOT))) + val attrsToLookup = if (i == 0) { + direct.get(name.toLowerCase(Locale.ROOT)).map(_.filterNot(_.qualifiedAccessOnly)) + } else { + direct.get(name.toLowerCase(Locale.ROOT)) + } + candidates = collectMatches(name, nameParts.take(i), attrsToLookup) if (candidates.nonEmpty) { nestedFields = nameParts.takeRight(nameParts.length - i - 1) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 8c005e9980f24..53c95a4ffd3fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -88,7 +88,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) getAllValidConstraints(projectList) override def metadataOutput: Seq[Attribute] = - getTagValue(Project.hiddenOutputTag).getOrElse(Nil) + getTagValue(Project.hiddenOutputTag).getOrElse(child.metadataOutput) override protected def withNewChildInternal(newChild: LogicalPlan): Project = copy(child = newChild) @@ -1307,9 +1307,14 @@ case class SubqueryAlias( } override def metadataOutput: Seq[Attribute] = { - val qualifierList = identifier.qualifier :+ alias - val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.supportsQualifiedStar) - nonHiddenMetadataOutput.map(_.withQualifier(qualifierList)) + // Propagate metadata columns from leaf nodes through a chain of `SubqueryAlias`. + if (child.isInstanceOf[LeafNode] || child.isInstanceOf[SubqueryAlias]) { + val qualifierList = identifier.qualifier :+ alias + val nonHiddenMetadataOutput = child.metadataOutput.filter(!_.qualifiedAccessOnly) + nonHiddenMetadataOutput.map(_.withQualifier(qualifierList)) + } else { + Nil + } } override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 33fe48d44dadb..d1a0aa52f6757 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -206,22 +206,23 @@ package object util extends Logging { implicit class MetadataColumnHelper(attr: Attribute) { /** - * If set, this metadata column is a candidate during qualified star expansions. + * If set, this metadata column can only be accessed with qualifiers, e.g. `qualifiers.col` or + * `qualifiers.*`. If not set, metadata columns cannot be accessed via star. */ - val SUPPORTS_QUALIFIED_STAR = "__supports_qualified_star" + val QUALIFIED_ACCESS_ONLY = "__qualified_access_only" def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) && attr.metadata.getBoolean(METADATA_COL_ATTR_KEY) - def supportsQualifiedStar: Boolean = attr.isMetadataCol && - attr.metadata.contains(SUPPORTS_QUALIFIED_STAR) && - attr.metadata.getBoolean(SUPPORTS_QUALIFIED_STAR) + def qualifiedAccessOnly: Boolean = attr.isMetadataCol && + attr.metadata.contains(QUALIFIED_ACCESS_ONLY) && + attr.metadata.getBoolean(QUALIFIED_ACCESS_ONLY) - def markAsSupportsQualifiedStar(): Attribute = attr.withMetadata( + def markAsQualifiedAccessOnly(): Attribute = attr.withMetadata( new MetadataBuilder() .withMetadata(attr.metadata) .putBoolean(METADATA_COL_ATTR_KEY, true) - .putBoolean(SUPPORTS_QUALIFIED_STAR, true) + .putBoolean(QUALIFIED_ACCESS_ONLY, true) .build() ) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/using-join.sql b/sql/core/src/test/resources/sql-tests/inputs/using-join.sql index 336d19f0f2a3d..87390b388764f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/using-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/using-join.sql @@ -19,6 +19,8 @@ SELECT nt1.*, nt2.* FROM nt1 left outer join nt2 using (k); SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k); +SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k; + SELECT k, nt1.k FROM nt1 left outer join nt2 using (k); SELECT k, nt2.k FROM nt1 left outer join nt2 using (k); diff --git a/sql/core/src/test/resources/sql-tests/results/using-join.sql.out b/sql/core/src/test/resources/sql-tests/results/using-join.sql.out index 1d2ae9d96ecad..db9ac1f10bb00 100644 --- a/sql/core/src/test/resources/sql-tests/results/using-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/using-join.sql.out @@ -71,6 +71,17 @@ three NULL two two +-- !query +SELECT nt1.k, nt2.k FROM nt1 left outer join nt2 using (k) ORDER BY nt2.k +-- !query schema +struct +-- !query output +three NULL +one one +one one +two two + + -- !query SELECT k, nt1.k FROM nt1 left outer join nt2 using (k) -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ba4828dbf0264..a910004277baa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2524,100 +2524,6 @@ class DataSourceV2SQLSuite } } - test("SPARK-31255: Project a metadata column") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1") - val dfQuery = spark.table(t1).select("id", "data", "index", "_partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) - } - } - } - - test("SPARK-31255: Projects data column when metadata column has the same name") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, index), index)") - sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')") - - val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1") - val dfQuery = spark.table(t1).select("index", "data", "_partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) - } - } - } - - test("SPARK-31255: * expansion does not include metadata columns") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')") - - val sqlQuery = spark.sql(s"SELECT * FROM $t1") - val dfQuery = spark.table(t1) - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a"))) - } - } - } - - test("SPARK-31255: metadata column should only be produced when necessary") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - - val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0") - val dfQuery = spark.table(t1).filter("index = 0") - - Seq(sqlQuery, dfQuery).foreach { query => - assert(query.schema.fieldNames.toSeq == Seq("id", "data")) - } - } - } - - test("SPARK-34547: metadata columns are resolved last") { - val t1 = s"${catalogAndNamespace}tableOne" - val t2 = "t2" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - withTempView(t2) { - sql(s"CREATE TEMPORARY VIEW $t2 AS SELECT * FROM " + - s"VALUES (1, -1), (2, -2), (3, -3) AS $t2(id, index)") - - val sqlQuery = spark.sql(s"SELECT $t1.id, $t2.id, data, index, $t1.index, $t2.index FROM " + - s"$t1 JOIN $t2 WHERE $t1.id = $t2.id") - val t1Table = spark.table(t1) - val t2Table = spark.table(t2) - val dfQuery = t1Table.join(t2Table, t1Table.col("id") === t2Table.col("id")) - .select(s"$t1.id", s"$t2.id", "data", "index", s"$t1.index", s"$t2.index") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, - Seq( - Row(1, 1, "a", -1, 0, -1), - Row(2, 2, "b", -2, 0, -2), - Row(3, 3, "c", -3, 0, -3) - ) - ) - } - } - } - } - test("SPARK-33505: insert into partitioned table") { val t = "testpart.ns1.ns2.tbl" withTable(t) { @@ -2702,27 +2608,6 @@ class DataSourceV2SQLSuite } } - test("SPARK-34555: Resolve DataFrame metadata column") { - val tbl = s"${catalogAndNamespace}table" - withTable(tbl) { - sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") - val table = spark.table(tbl) - val dfQuery = table.select( - table.col("id"), - table.col("data"), - table.col("index"), - table.col("_partition") - ) - - checkAnswer( - dfQuery, - Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")) - ) - } - } - test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") { val tbl = s"${catalogAndNamespace}tbl" withTable(tbl) { @@ -2785,109 +2670,6 @@ class DataSourceV2SQLSuite } } - test("SPARK-34923: do not propagate metadata columns through Project") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - assertThrows[AnalysisException] { - sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)") - } - assertThrows[AnalysisException] { - spark.table(t1).select("id", "data").select("index", "_partition") - } - } - } - - test("SPARK-34923: do not propagate metadata columns through View") { - val t1 = s"${catalogAndNamespace}table" - val view = "view" - - withTable(t1) { - withTempView(view) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - sql(s"CACHE TABLE $view AS SELECT * FROM $t1") - assertThrows[AnalysisException] { - sql(s"SELECT index, _partition FROM $view") - } - } - } - } - - test("SPARK-34923: propagate metadata columns through Filter") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1") - val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) - } - } - } - - test("SPARK-34923: propagate metadata columns through Sort") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id") - val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) - } - } - } - - test("SPARK-34923: propagate metadata columns through RepartitionBy") { - val t1 = s"${catalogAndNamespace}table" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - val sqlQuery = spark.sql( - s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1") - val tbl = spark.table(t1) - val dfQuery = tbl.repartitionByRange(3, tbl.col("id")) - .select("id", "data", "index", "_partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) - } - } - } - - test("SPARK-34923: propagate metadata columns through SubqueryAlias") { - val t1 = s"${catalogAndNamespace}table" - val sbq = "sbq" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") - sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") - - val sqlQuery = spark.sql( - s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq") - val dfQuery = spark.table(t1).as(sbq).select( - s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition") - - Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) - } - } - } - private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala new file mode 100644 index 0000000000000..95b9c4f72356a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.functions.struct + +class MetadataColumnSuite extends DatasourceV2SQLBase { + import testImplicits._ + + private val tbl = "testcat.t" + + private def prepareTable(): Unit = { + sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (bucket(4, id), id)") + sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") + } + + test("SPARK-31255: Project a metadata column") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl") + val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + } + } + } + + test("SPARK-31255: Projects data column when metadata column has the same name") { + withTable(tbl) { + sql(s"CREATE TABLE $tbl (index bigint, data string) PARTITIONED BY (bucket(4, index), index)") + sql(s"INSERT INTO $tbl VALUES (3, 'c'), (2, 'b'), (1, 'a')") + + val sqlQuery = sql(s"SELECT index, data, _partition FROM $tbl") + val dfQuery = spark.table(tbl).select("index", "data", "_partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) + } + } + } + + test("SPARK-31255: * expansion does not include metadata columns") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql(s"SELECT * FROM $tbl") + val dfQuery = spark.table(tbl) + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + } + + test("SPARK-31255: metadata column should only be produced when necessary") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql(s"SELECT * FROM $tbl WHERE index = 0") + val dfQuery = spark.table(tbl).filter("index = 0") + + Seq(sqlQuery, dfQuery).foreach { query => + assert(query.schema.fieldNames.toSeq == Seq("id", "data")) + } + } + } + + test("SPARK-34547: metadata columns are resolved last") { + withTable(tbl) { + prepareTable() + withTempView("v") { + sql(s"CREATE TEMPORARY VIEW v AS SELECT * FROM " + + s"VALUES (1, -1), (2, -2), (3, -3) AS v(id, index)") + + val sqlQuery = sql(s"SELECT $tbl.id, v.id, data, index, $tbl.index, v.index " + + s"FROM $tbl JOIN v WHERE $tbl.id = v.id") + val tableDf = spark.table(tbl) + val viewDf = spark.table("v") + val dfQuery = tableDf.join(viewDf, tableDf.col("id") === viewDf.col("id")) + .select(s"$tbl.id", "v.id", "data", "index", s"$tbl.index", "v.index") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, + Seq( + Row(1, 1, "a", -1, 0, -1), + Row(2, 2, "b", -2, 0, -2), + Row(3, 3, "c", -3, 0, -3) + ) + ) + } + } + } + } + + test("SPARK-34555: Resolve DataFrame metadata column") { + withTable(tbl) { + prepareTable() + val table = spark.table(tbl) + val dfQuery = table.select( + table.col("id"), + table.col("data"), + table.col("index"), + table.col("_partition") + ) + + checkAnswer( + dfQuery, + Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")) + ) + } + } + + test("SPARK-34923: propagate metadata columns through Project") { + withTable(tbl) { + prepareTable() + checkAnswer( + spark.table(tbl).select("id", "data").select("index", "_partition"), + Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3")) + ) + } + } + + test("SPARK-34923: do not propagate metadata columns through View") { + val view = "view" + withTable(tbl) { + withTempView(view) { + prepareTable() + sql(s"CACHE TABLE $view AS SELECT * FROM $tbl") + assertThrows[AnalysisException] { + sql(s"SELECT index, _partition FROM $view") + } + } + } + } + + test("SPARK-34923: propagate metadata columns through Filter") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl WHERE id > 1") + val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + } + } + } + + test("SPARK-34923: propagate metadata columns through Sort") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl ORDER BY id") + val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + } + } + } + + test("SPARK-34923: propagate metadata columns through RepartitionBy") { + withTable(tbl) { + prepareTable() + val sqlQuery = sql( + s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $tbl") + val dfQuery = spark.table(tbl).repartitionByRange(3, $"id") + .select("id", "data", "index", "_partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + } + } + } + + test("SPARK-34923: propagate metadata columns through SubqueryAlias if child is leaf node") { + val sbq = "sbq" + withTable(tbl) { + prepareTable() + val sqlQuery = sql( + s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $tbl $sbq") + val dfQuery = spark.table(tbl).as(sbq).select( + s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition") + + Seq(sqlQuery, dfQuery).foreach { query => + checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + } + + assertThrows[AnalysisException] { + sql(s"SELECT $sbq.index FROM (SELECT id FROM $tbl) $sbq") + } + assertThrows[AnalysisException] { + spark.table(tbl).select($"id").as(sbq).select(s"$sbq.index") + } + } + } + + test("SPARK-40149: select outer join metadata columns with DataFrame API") { + val df1 = Seq(1 -> "a").toDF("k", "v").as("left") + val df2 = Seq(1 -> "b").toDF("k", "v").as("right") + val dfQuery = df1.join(df2, Seq("k"), "outer") + .withColumn("left_all", struct($"left.*")) + .withColumn("right_all", struct($"right.*")) + checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b"))) + } +}