From b5470ae294b81107443dec81648b847e0b58aca5 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 23 Feb 2021 11:18:47 +0900 Subject: [PATCH 01/60] [MINOR][DOCS] Replace http to https when possible in PySpark documentation ### What changes were proposed in this pull request? This PR proposes: - Change http to https for better security - Change http://apache-spark-developers-list.1001551.n3.nabble.com/ to official mailing list link (https://mail-archives.apache.org/mod_mbox/spark-dev/) ### Why are the changes needed? For better security, and to use official link. ### Does this PR introduce _any_ user-facing change? Yes, It exposes more secure and correct links to the PySpark end users in PySpark documentation. ### How was this patch tested? I manually checked if each link works Closes #31616 from HyukjinKwon/minor-https. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- python/docs/source/development/contributing.rst | 6 +++--- python/docs/source/getting_started/index.rst | 4 ++-- python/docs/source/migration_guide/index.rst | 8 ++++---- python/docs/source/user_guide/arrow_pandas.rst | 2 +- python/docs/source/user_guide/index.rst | 12 ++++++------ 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/docs/source/development/contributing.rst b/python/docs/source/development/contributing.rst index 8100bcbafbaee..4f0f9ae998f34 100644 --- a/python/docs/source/development/contributing.rst +++ b/python/docs/source/development/contributing.rst @@ -21,14 +21,14 @@ Contributing to PySpark There are many types of contribution, for example, helping other users, testing releases, reviewing changes, documentation contribution, bug reporting, JIRA maintenance, code changes, etc. -These are documented at `the general guidelines `_. +These are documented at `the general guidelines `_. This page focuses on PySpark and includes additional details specifically for PySpark. Contributing by Testing Releases -------------------------------- -Before the official release, PySpark release candidates are shared in the `dev@spark.apache.org `_ mailing list to vote on. +Before the official release, PySpark release candidates are shared in the `dev@spark.apache.org `_ mailing list to vote on. This release candidates can be easily installed via pip. For example, in case of Spark 3.0.0 RC1, you can install as below: .. code-block:: bash @@ -71,7 +71,7 @@ under ``python/docs/source/reference``. Otherwise, they would not be documented Preparing to Contribute Code Changes ------------------------------------ -Before starting to work on codes in PySpark, it is recommended to read `the general guidelines `_. +Before starting to work on codes in PySpark, it is recommended to read `the general guidelines `_. There are a couple of additional notes to keep in mind when contributing to codes in PySpark: * Be Pythonic. diff --git a/python/docs/source/getting_started/index.rst b/python/docs/source/getting_started/index.rst index 38b9c935fc623..f6d7a92ced03f 100644 --- a/python/docs/source/getting_started/index.rst +++ b/python/docs/source/getting_started/index.rst @@ -22,8 +22,8 @@ Getting Started This page summarizes the basic steps required to setup and get started with PySpark. There are more guides shared with other languages such as -`Quick Start `_ in Programming Guides -at `the Spark documentation `_. +`Quick Start `_ in Programming Guides +at `the Spark documentation `_. .. 
toctree:: :maxdepth: 2 diff --git a/python/docs/source/migration_guide/index.rst b/python/docs/source/migration_guide/index.rst index 88e768dc464df..d309d44780d1d 100644 --- a/python/docs/source/migration_guide/index.rst +++ b/python/docs/source/migration_guide/index.rst @@ -36,8 +36,8 @@ This page describes the migration guide specific to PySpark. Many items of other migration guides can also be applied when migrating PySpark to higher versions because PySpark internally shares other components. Please also refer other migration guides: -- `Migration Guide: Spark Core `_ -- `Migration Guide: SQL, Datasets and DataFrame `_ -- `Migration Guide: Structured Streaming `_ -- `Migration Guide: MLlib (Machine Learning) `_ +- `Migration Guide: Spark Core `_ +- `Migration Guide: SQL, Datasets and DataFrame `_ +- `Migration Guide: Structured Streaming `_ +- `Migration Guide: MLlib (Machine Learning) `_ diff --git a/python/docs/source/user_guide/arrow_pandas.rst b/python/docs/source/user_guide/arrow_pandas.rst index 91d8155523391..12b772f62abe2 100644 --- a/python/docs/source/user_guide/arrow_pandas.rst +++ b/python/docs/source/user_guide/arrow_pandas.rst @@ -408,5 +408,5 @@ This will instruct PyArrow >= 0.15.0 to use the legacy IPC format with the older is in Spark 2.3.x and 2.4.x. Not setting this environment variable will lead to a similar error as described in `SPARK-29367 `_ when running ``pandas_udf``\s or :meth:`DataFrame.toPandas` with Arrow enabled. More information about the Arrow IPC change can -be read on the Arrow 0.15.0 release `blog `_. +be read on the Arrow 0.15.0 release `blog `_. diff --git a/python/docs/source/user_guide/index.rst b/python/docs/source/user_guide/index.rst index 704156b11d985..3897ab2ea9086 100644 --- a/python/docs/source/user_guide/index.rst +++ b/python/docs/source/user_guide/index.rst @@ -30,11 +30,11 @@ This page is the guide for PySpark users which contains PySpark specific topics. There are more guides shared with other languages in Programming Guides -at `the Spark documentation `_. +at `the Spark documentation `_. -- `RDD Programming Guide `_ -- `Spark SQL, DataFrames and Datasets Guide `_ -- `Structured Streaming Programming Guide `_ -- `Spark Streaming Programming Guide `_ -- `Machine Learning Library (MLlib) Guide `_ +- `RDD Programming Guide `_ +- `Spark SQL, DataFrames and Datasets Guide `_ +- `Structured Streaming Programming Guide `_ +- `Spark Streaming Programming Guide `_ +- `Machine Learning Library (MLlib) Guide `_ From 612d52315b8476dd588d75ce3001dee5786db747 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 23 Feb 2021 11:22:02 +0900 Subject: [PATCH 02/60] [SPARK-34500][DOCS][EXAMPLES] Replace symbol literals with $"" in examples and documents ### What changes were proposed in this pull request? This PR replaces all the occurrences of symbol literals (`'name`) with string interpolation (`$"name"`) in examples and documents. ### Why are the changes needed? Symbol literals are used to represent columns in Spark SQL but the Scala community seems to remove `Symbol` completely. As we discussed in #31569, first we should replacing symbol literals with `$"name"` in user facing examples and documents. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Build docs. Closes #31615 from sarutak/replace-symbol-literals-in-doc-and-examples. 
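For reference, the two column-reference styles look like this in a minimal, self-contained sketch (the object name, local-mode session, and sample data are illustrative only and not part of this patch; the old-style line mirrors the `SimpleTypedAggregator` example being rewritten):

```scala
import org.apache.spark.sql.SparkSession

// Contrasts the deprecated Symbol-literal column syntax with the $"name"
// string interpolation that the examples and documents now use.
object SymbolVsInterpolation {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")            // local mode, assumed for this sketch
      .appName("symbol-vs-interpolation")
      .getOrCreate()
    import spark.implicits._         // provides both the 'name and $"name" conversions

    val ds = spark.range(20)

    // Old style: Scala symbol literals, which the Scala community is removing.
    val oldStyle = ds.select(('id % 3).as("key"), 'id)

    // New style: $"name" interpolation, producing equivalent Column expressions.
    val newStyle = ds.select(($"id" % 3).as("key"), $"id")

    newStyle.show()
    spark.stop()
  }
}
```

Both forms resolve to the same `Column` expressions, so the rewrite is purely syntactic.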
Authored-by: Kousuke Saruta Signed-off-by: HyukjinKwon --- docs/sql-data-sources-avro.md | 4 ++-- .../apache/spark/examples/sql/SimpleTypedAggregator.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- .../scala/org/apache/spark/sql/expressions/Window.scala | 8 ++++---- .../org/apache/spark/sql/expressions/WindowSpec.scala | 8 ++++---- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index da2a90e3ae027..928b3d021a172 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -107,9 +107,9 @@ val df = spark // 2. Filter by column `favorite_color`; // 3. Encode the column `name` in Avro format. val output = df - .select(from_avro('value, jsonFormatSchema) as 'user) + .select(from_avro($"value", jsonFormatSchema) as $"user") .where("user.favorite_color == \"red\"") - .select(to_avro($"user.name") as 'value) + .select(to_avro($"user.name") as $"value") val query = output .writeStream diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala index 5510f0019353b..5d11fb2fc96e5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala @@ -31,7 +31,7 @@ object SimpleTypedAggregator { .getOrCreate() import spark.implicits._ - val ds = spark.range(20).select(('id % 3).as("key"), 'id).as[(Long, Long)] + val ds = spark.range(20).select(($"id" % 3).as("key"), $"id").as[(Long, Long)] println("input data:") ds.show() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6557ff3b2fc45..db008821584dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1108,7 +1108,7 @@ class Column(val expr: Expression) extends Logging { * Gives the column an alias. * {{{ * // Renames colA to colB in select output. 
- * df.select($"colA".as('colB)) + * df.select($"colA".as("colB")) * }}} * * If the current column has metadata associated with it, this metadata will be propagated diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7ccf6dc872206..fd02d0b131587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2307,7 +2307,7 @@ class Dataset[T] private[sql]( * case class Book(title: String, words: String) * val ds: Dataset[Book] * - * val allWords = ds.select('title, explode(split('words, " ")).as("word")) + * val allWords = ds.select($"title", explode(split($"words", " ")).as("word")) * * val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title")) * }}} @@ -2346,7 +2346,7 @@ class Dataset[T] private[sql]( * `functions.explode()`: * * {{{ - * ds.select(explode(split('words, " ")).as("word")) + * ds.select(explode(split($"words", " ")).as("word")) * }}} * * or `flatMap()`: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d13baaedbaeff..93bf738a53daf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -136,8 +136,8 @@ object Window { * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * Window.partitionBy($"category").orderBy($"id").rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum($"id") over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -188,8 +188,8 @@ object Window { * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * Window.partitionBy($"category").orderBy($"id").rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum($"id") over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 58227f075f2c7..09a945f162a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -100,8 +100,8 @@ class WindowSpec private[sql]( * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * Window.partitionBy($"category").orderBy($"id").rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum($"id") over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -168,8 +168,8 @@ class WindowSpec private[sql]( * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * 
Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * Window.partitionBy($"category").orderBy($"id").rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum($"id") over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| From be675a052c38a36ce5e33ba56bdc69cc8972b3e8 Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Tue, 23 Feb 2021 15:51:02 +0800 Subject: [PATCH 03/60] [SPARK-34490][SQL] Analysis should fail if the view refers a dropped table ### What changes were proposed in this pull request? When resolving a view, we use the captured view name in `AnalysisContext` to distinguish whether a relation name is a view or a table. But if the resolution failed, other rules (e.g. `ResolveTables`) will try to resolve the relation again but without `AnalysisContext`. So, in this case, the resolution may be incorrect. For example, if the view refers to a dropped table while a view with the same name exists, the dropped table will be resolved as a view rather than an unresolved exception. ### Why are the changes needed? bugfix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? newly added test cases Closes #31606 from linhongliu-db/fix-temp-view-master. Lead-authored-by: Linhong Liu Co-authored-by: Linhong Liu <67896261+linhongliu-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 35 ++++++++++++------- .../analysis/TableLookupCacheSuite.scala | 13 +++++-- .../sql/execution/SQLViewTestSuite.scala | 20 +++++++++++ 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 29341aecc1842..38259c234c262 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -871,24 +871,24 @@ class Analyzer(override val catalogManager: CatalogManager) object ResolveTempViews extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case u @ UnresolvedRelation(ident, _, isStreaming) => - lookupTempView(ident, isStreaming).getOrElse(u) + lookupTempView(ident, isStreaming, performCheck = true).getOrElse(u) case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _, _) => - lookupTempView(ident) + lookupTempView(ident, performCheck = true) .map(view => i.copy(table = view)) .getOrElse(i) case c @ CacheTable(UnresolvedRelation(ident, _, false), _, _, _) => - lookupTempView(ident) + lookupTempView(ident, performCheck = true) .map(view => c.copy(table = view)) .getOrElse(c) case c @ UncacheTable(UnresolvedRelation(ident, _, false), _, _) => - lookupTempView(ident) + lookupTempView(ident, performCheck = true) .map(view => c.copy(table = view, isTempView = true)) .getOrElse(c) // TODO (SPARK-27484): handle streaming write commands when we have them. 
case write: V2WriteCommand => write.table match { case UnresolvedRelation(ident, _, false) => - lookupTempView(ident).map(EliminateSubqueryAliases(_)).map { + lookupTempView(ident, performCheck = true).map(EliminateSubqueryAliases(_)).map { case r: DataSourceV2Relation => write.withNewTable(r) case _ => throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(ident.quoted) }.getOrElse(write) @@ -921,7 +921,9 @@ class Analyzer(override val catalogManager: CatalogManager) } def lookupTempView( - identifier: Seq[String], isStreaming: Boolean = false): Option[LogicalPlan] = { + identifier: Seq[String], + isStreaming: Boolean = false, + performCheck: Boolean = false): Option[LogicalPlan] = { // Permanent View can't refer to temp views, no need to lookup at all. if (isResolvingView && !referredTempViewNames.contains(identifier)) return None @@ -934,7 +936,7 @@ class Analyzer(override val catalogManager: CatalogManager) if (isStreaming && tmpView.nonEmpty && !tmpView.get.isStreaming) { throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) } - tmpView.map(ResolveRelations.resolveViews) + tmpView.map(ResolveRelations.resolveViews(_, performCheck)) } } @@ -1098,7 +1100,7 @@ class Analyzer(override val catalogManager: CatalogManager) // look at `AnalysisContext.catalogAndNamespace` when resolving relations with single-part name. // If `AnalysisContext.catalogAndNamespace` is non-empty, analyzer will expand single-part names // with it, instead of current catalog and namespace. - def resolveViews(plan: LogicalPlan): LogicalPlan = plan match { + def resolveViews(plan: LogicalPlan, performCheck: Boolean = false): LogicalPlan = plan match { // The view's child should be a logical plan parsed from the `desc.viewText`, the variable // `viewText` should be defined, or else we throw an error on the generation of the View // operator. @@ -1115,9 +1117,18 @@ class Analyzer(override val catalogManager: CatalogManager) executeSameContext(child) } } + // Fail the analysis eagerly because outside AnalysisContext, the unresolved operators + // inside a view maybe resolved incorrectly. + // But for commands like `DropViewCommand`, resolving view is unnecessary even though + // there is unresolved node. So use the `performCheck` flag to skip the analysis check + // for these commands. 
+ // TODO(SPARK-34504): avoid unnecessary view resolving and remove the `performCheck` flag + if (performCheck) { + checkAnalysis(newChild) + } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => - p.copy(child = resolveViews(view)) + p.copy(child = resolveViews(view, performCheck)) case _ => plan } @@ -1137,14 +1148,14 @@ class Analyzer(override val catalogManager: CatalogManager) case c @ CacheTable(u @ UnresolvedRelation(_, _, false), _, _, _) => lookupRelation(u.multipartIdentifier, u.options, false) - .map(resolveViews) + .map(resolveViews(_, performCheck = true)) .map(EliminateSubqueryAliases(_)) .map(relation => c.copy(table = relation)) .getOrElse(c) case c @ UncacheTable(u @ UnresolvedRelation(_, _, false), _, _) => lookupRelation(u.multipartIdentifier, u.options, false) - .map(resolveViews) + .map(resolveViews(_, performCheck = true)) .map(EliminateSubqueryAliases(_)) .map(relation => c.copy(table = relation)) .getOrElse(c) @@ -1170,7 +1181,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u: UnresolvedRelation => lookupRelation(u.multipartIdentifier, u.options, u.isStreaming) - .map(resolveViews).getOrElse(u) + .map(resolveViews(_, performCheck = true)).getOrElse(u) case u @ UnresolvedTable(identifier, cmd, relationTypeMismatchHint) => lookupTableOrView(identifier).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala index 3e9a8b71a8fb6..ec9480514ba2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.io.File +import scala.collection.JavaConverters._ + import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -27,8 +29,8 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.connector.InMemoryTableCatalog -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, V1Table} +import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table} import org.apache.spark.sql.types._ class TableLookupCacheSuite extends AnalysisTest with Matchers { @@ -46,7 +48,12 @@ class TableLookupCacheSuite extends AnalysisTest with Matchers { ignoreIfExists = false) val v2Catalog = new InMemoryTableCatalog { override def loadTable(ident: Identifier): Table = { - V1Table(externalCatalog.getTable("default", ident.name)) + val catalogTable = externalCatalog.getTable("default", ident.name) + new InMemoryTable( + catalogTable.identifier.table, + catalogTable.schema, + Array.empty, + Map.empty[String, String].asJava) } override def name: String = CatalogManager.SESSION_CATALOG_NAME } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index 
68e1a682562ac..84a20bb16ad86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -258,6 +258,26 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { checkViewOutput(viewName, Seq(Row(2))) } } + + test("SPARK-34490 - query should fail if the view refers a dropped table") { + withTable("t") { + Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") + val viewName = createView("testView", "SELECT * FROM t") + withView(viewName) { + // Always create a temp view in this case, not use `createView` on purpose + sql("CREATE TEMP VIEW t AS SELECT 1 AS c1") + withTempView("t") { + checkViewOutput(viewName, Seq(Row(2), Row(3), Row(1))) + // Manually drop table `t` to see if the query will fail + sql("DROP TABLE IF EXISTS default.t") + val e = intercept[AnalysisException] { + sql(s"SELECT * FROM $viewName").collect() + }.getMessage + assert(e.contains("Table or view not found: t")) + } + } + } + } } class LocalTempViewTestSuite extends SQLViewTestSuite with SharedSparkSession { From 8f994cbb4a18558c2e81516ef1e339d9c8fa0d41 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 23 Feb 2021 12:04:31 +0000 Subject: [PATCH 04/60] [SPARK-34475][SQL] Rename logical nodes of v2 `ALTER` commands ### What changes were proposed in this pull request? In the PR, I propose to rename logical nodes of v2 commands in the form: ` + ` like: - AlterTableAddPartition -> AddPartition - AlterTableSetLocation -> SetTableLocation ### Why are the changes needed? 1. For simplicity and readability of logical plans 2. For consistency with other logical nodes. For example, the logical node `RenameTable` for `ALTER TABLE .. RENAME TO` was added before `AlterTableRenamePartition`. ### Does this PR introduce _any_ user-facing change? Should not since this is non-public APIs. ### How was this patch tested? 1. Check scala style: `./dev/scalastyle` 2. Affected test suites: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *AlterTableRenamePartitionSuite" ``` Closes #31596 from MaxGekk/rename-alter-table-logic-nodes. 
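Since only non-public logical nodes are renamed, observable behavior is unchanged; a quick sketch of what callers that pattern-match on these plans see after the rename (the object name and SQL text are illustrative; the node shape and imports follow the parser-suite changes in this patch):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedTable
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
import org.apache.spark.sql.catalyst.plans.logical.AddPartitions

// ALTER TABLE ... ADD PARTITION now parses into AddPartitions
// (formerly AlterTableAddPartition); the field layout is unchanged.
object RenamedNodeSketch {
  def main(args: Array[String]): Unit = {
    val plan = parsePlan("ALTER TABLE a.b.c ADD PARTITION (dt='2008-08-08') LOCATION 'loc'")
    plan match {
      case AddPartitions(_: UnresolvedTable, parts, ifNotExists) =>
        println(s"AddPartitions: parts=$parts, ifNotExists=$ifNotExists")
      case other =>
        sys.error(s"unexpected plan: $other")
    }
  }
}
```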
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +-- .../analysis/ResolvePartitionSpec.scala | 8 ++-- .../analysis/ResolveTableProperties.scala | 6 +-- .../sql/catalyst/parser/AstBuilder.scala | 44 ++++++++--------- .../catalyst/plans/logical/v2Commands.scala | 24 +++++----- .../sql/catalyst/parser/DDLParserSuite.scala | 48 +++++++++---------- .../analysis/ResolveSessionCatalog.scala | 24 +++++----- .../datasources/v2/DataSourceV2Strategy.scala | 20 ++++---- .../spark/sql/internal/CatalogImpl.scala | 4 +- .../sql/connector/DataSourceV2SQLSuite.scala | 6 +-- .../AlterTableAddPartitionParserSuite.scala | 6 +-- .../AlterTableDropPartitionParserSuite.scala | 10 ++-- ...terTableRecoverPartitionsParserSuite.scala | 10 ++-- ...AlterTableRenamePartitionParserSuite.scala | 6 +-- .../command/PlanResolutionSuite.scala | 31 ++++++------ 15 files changed, 128 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 29bd2c256df8b..1d44e6ba298e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -565,13 +565,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // no validation needed for set and remove property } - case AlterTableAddPartition(r: ResolvedTable, parts, _) => + case AddPartitions(r: ResolvedTable, parts, _) => checkAlterTablePartition(r.table, parts) - case AlterTableDropPartition(r: ResolvedTable, parts, _, _) => + case DropPartitions(r: ResolvedTable, parts, _, _) => checkAlterTablePartition(r.table, parts) - case AlterTableRenamePartition(r: ResolvedTable, from, _) => + case RenamePartitions(r: ResolvedTable, from, _) => checkAlterTablePartition(r.table, Seq(from)) case showPartitions: ShowPartitions => checkShowPartitions(showPartitions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 2307152b17375..72298b285f2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, AlterTableRenamePartition, LogicalPlan, ShowPartitions} +import org.apache.spark.sql.catalyst.plans.logical.{AddPartitions, DropPartitions, LogicalPlan, RenamePartitions, ShowPartitions} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement @@ -33,7 +33,7 @@ import org.apache.spark.sql.util.PartitioningUtils.{normalizePartitionSpec, requ object ResolvePartitionSpec extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case r @ AlterTableAddPartition( + case r @ AddPartitions( ResolvedTable(_, _, table: 
SupportsPartitionManagement, _), partSpecs, _) => val partitionSchema = table.partitionSchema() r.copy(parts = resolvePartitionSpecs( @@ -42,7 +42,7 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { partitionSchema, requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames))) - case r @ AlterTableDropPartition( + case r @ DropPartitions( ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs, _, _) => val partitionSchema = table.partitionSchema() r.copy(parts = resolvePartitionSpecs( @@ -51,7 +51,7 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { partitionSchema, requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames))) - case r @ AlterTableRenamePartition( + case r @ RenamePartitions( ResolvedTable(_, _, table: SupportsPartitionManagement, _), from, to) => val partitionSchema = table.partitionSchema() val Seq(resolvedFrom, resolvedTo) = resolvePartitionSpecs( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala index 2fe6e20614524..12b1502dd7c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.plans.logical.{AlterTableUnsetProperties, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnsetTableProperties} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.catalog.TableCatalog /** - * A rule for resolving AlterTableUnsetProperties to handle non-existent properties. + * A rule for resolving [[UnsetTableProperties]] to handle non-existent properties. */ object ResolveTableProperties extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case a @ AlterTableUnsetProperties(r: ResolvedTable, props, ifExists) if !ifExists => + case a @ UnsetTableProperties(r: ResolvedTable, props, ifExists) if !ifExists => val tblProperties = r.table.properties.asScala props.foreach { p => if (!tblProperties.contains(p) && p != TableCatalog.PROP_COMMENT) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8bb73702365b0..595a3a5ba5332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2767,7 +2767,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterNamespaceSetProperties]] logical plan. + * Create an [[SetNamespaceProperties]] logical plan. 
* * For example: * {{{ @@ -2778,14 +2778,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg override def visitSetNamespaceProperties(ctx: SetNamespacePropertiesContext): LogicalPlan = { withOrigin(ctx) { val properties = cleanNamespaceProperties(visitPropertyKeyValues(ctx.tablePropertyList), ctx) - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(visitMultipartIdentifier(ctx.multipartIdentifier)), properties) } } /** - * Create an [[AlterNamespaceSetLocation]] logical plan. + * Create an [[SetNamespaceLocation]] logical plan. * * For example: * {{{ @@ -2794,7 +2794,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitSetNamespaceLocation(ctx: SetNamespaceLocationContext): LogicalPlan = { withOrigin(ctx) { - AlterNamespaceSetLocation( + SetNamespaceLocation( UnresolvedNamespace(visitMultipartIdentifier(ctx.multipartIdentifier)), visitLocationSpec(ctx.locationSpec)) } @@ -3477,7 +3477,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Parse [[AlterViewSetProperties]] or [[AlterTableSetProperties]] commands. + * Parse [[SetViewProperties]] or [[SetTableProperties]] commands. * * For example: * {{{ @@ -3490,7 +3490,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val properties = visitPropertyKeyValues(ctx.tablePropertyList) val cleanedTableProperties = cleanTableProperties(ctx, properties) if (ctx.VIEW != null) { - AlterViewSetProperties( + SetViewProperties( createUnresolvedView( ctx.multipartIdentifier, commandName = "ALTER VIEW ... SET TBLPROPERTIES", @@ -3498,7 +3498,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg relationTypeMismatchHint = alterViewTypeMismatchHint), cleanedTableProperties) } else { - AlterTableSetProperties( + SetTableProperties( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... SET TBLPROPERTIES", @@ -3508,7 +3508,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Parse [[AlterViewUnsetProperties]] or [[AlterTableUnsetProperties]] commands. + * Parse [[UnsetViewProperties]] or [[UnsetTableProperties]] commands. * * For example: * {{{ @@ -3523,7 +3523,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val ifExists = ctx.EXISTS != null if (ctx.VIEW != null) { - AlterViewUnsetProperties( + UnsetViewProperties( createUnresolvedView( ctx.multipartIdentifier, commandName = "ALTER VIEW ... UNSET TBLPROPERTIES", @@ -3532,7 +3532,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg cleanedProperties, ifExists) } else { - AlterTableUnsetProperties( + UnsetTableProperties( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... UNSET TBLPROPERTIES", @@ -3543,7 +3543,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableSetLocation]] command. + * Create an [[SetTableLocation]] command. * * For example: * {{{ @@ -3551,7 +3551,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetLocation( + SetTableLocation( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... 
SET LOCATION ...", @@ -3810,7 +3810,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableRecoverPartitions]] + * Create an [[RecoverPartitions]] * * For example: * {{{ @@ -3819,7 +3819,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitRecoverPartitions( ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { - AlterTableRecoverPartitions( + RecoverPartitions( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... RECOVER PARTITIONS", @@ -3827,7 +3827,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableAddPartition]]. + * Create an [[AddPartitions]]. * * For example: * {{{ @@ -3849,7 +3849,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val location = Option(splCtx.locationSpec).map(visitLocationSpec) UnresolvedPartitionSpec(spec, location) } - AlterTableAddPartition( + AddPartitions( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... ADD PARTITION ...", @@ -3859,7 +3859,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableRenamePartition]] + * Create an [[RenamePartitions]] * * For example: * {{{ @@ -3868,7 +3868,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitRenameTablePartition( ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableRenamePartition( + RenamePartitions( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... RENAME TO PARTITION", @@ -3878,7 +3878,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableDropPartition]] + * Create an [[DropPartitions]] * * For example: * {{{ @@ -3897,7 +3897,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } val partSpecs = ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec) .map(spec => UnresolvedPartitionSpec(spec)) - AlterTableDropPartition( + DropPartitions( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... DROP PARTITION ...", @@ -3908,7 +3908,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create an [[AlterTableSerDeProperties]] + * Create an [[SetTableSerDeProperties]] * * For example: * {{{ @@ -3918,7 +3918,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) { - AlterTableSerDeProperties( + SetTableSerDeProperties( createUnresolvedTable( ctx.multipartIdentifier, "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index f9341714881fa..8797b107f945a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -305,7 +305,7 @@ case class DescribeNamespace( * The logical plan of the ALTER (DATABASE|SCHEMA|NAMESPACE) ... SET (DBPROPERTIES|PROPERTIES) * command. 
*/ -case class AlterNamespaceSetProperties( +case class SetNamespaceProperties( namespace: LogicalPlan, properties: Map[String, String]) extends Command { override def children: Seq[LogicalPlan] = Seq(namespace) @@ -314,7 +314,7 @@ case class AlterNamespaceSetProperties( /** * The logical plan of the ALTER (DATABASE|SCHEMA|NAMESPACE) ... SET LOCATION command. */ -case class AlterNamespaceSetLocation( +case class SetNamespaceLocation( namespace: LogicalPlan, location: String) extends Command { override def children: Seq[LogicalPlan] = Seq(namespace) @@ -676,7 +676,7 @@ case class AnalyzeColumn( * PARTITION spec1 [LOCATION 'loc1'][, PARTITION spec2 [LOCATION 'loc2'], ...]; * }}} */ -case class AlterTableAddPartition( +case class AddPartitions( child: LogicalPlan, parts: Seq[PartitionSpec], ifNotExists: Boolean) extends Command { @@ -698,7 +698,7 @@ case class AlterTableAddPartition( * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; * }}} */ -case class AlterTableDropPartition( +case class DropPartitions( child: LogicalPlan, parts: Seq[PartitionSpec], ifExists: Boolean, @@ -712,7 +712,7 @@ case class AlterTableDropPartition( /** * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command. */ -case class AlterTableRenamePartition( +case class RenamePartitions( child: LogicalPlan, from: PartitionSpec, to: PartitionSpec) extends Command { @@ -727,7 +727,7 @@ case class AlterTableRenamePartition( /** * The logical plan of the ALTER TABLE ... RECOVER PARTITIONS command. */ -case class AlterTableRecoverPartitions(child: LogicalPlan) extends Command { +case class RecoverPartitions(child: LogicalPlan) extends Command { override def children: Seq[LogicalPlan] = child :: Nil } @@ -819,7 +819,7 @@ case class AlterViewAs( /** * The logical plan of the ALTER VIEW ... SET TBLPROPERTIES command. */ -case class AlterViewSetProperties( +case class SetViewProperties( child: LogicalPlan, properties: Map[String, String]) extends Command { override def children: Seq[LogicalPlan] = child :: Nil @@ -828,7 +828,7 @@ case class AlterViewSetProperties( /** * The logical plan of the ALTER VIEW ... UNSET TBLPROPERTIES command. */ -case class AlterViewUnsetProperties( +case class UnsetViewProperties( child: LogicalPlan, propertyKeys: Seq[String], ifExists: Boolean) extends Command { @@ -838,7 +838,7 @@ case class AlterViewUnsetProperties( /** * The logical plan of the ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] command. */ -case class AlterTableSerDeProperties( +case class SetTableSerDeProperties( child: LogicalPlan, serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], @@ -876,7 +876,7 @@ case class UncacheTable( /** * The logical plan of the ALTER TABLE ... SET LOCATION command. */ -case class AlterTableSetLocation( +case class SetTableLocation( table: LogicalPlan, partitionSpec: Option[TablePartitionSpec], location: String) extends Command { @@ -886,7 +886,7 @@ case class AlterTableSetLocation( /** * The logical plan of the ALTER TABLE ... SET TBLPROPERTIES command. */ -case class AlterTableSetProperties( +case class SetTableProperties( table: LogicalPlan, properties: Map[String, String]) extends Command { override def children: Seq[LogicalPlan] = table :: Nil @@ -895,7 +895,7 @@ case class AlterTableSetProperties( /** * The logical plan of the ALTER TABLE ... UNSET TBLPROPERTIES command. 
*/ -case class AlterTableUnsetProperties( +case class UnsetTableProperties( table: LogicalPlan, propertyKeys: Seq[String], ifExists: Boolean) extends Command { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index acf2bfeeaabd2..cb9dda8260a50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -741,16 +741,16 @@ class DDLParserSuite extends AnalysisTest { val hint = Some("Please use ALTER TABLE instead.") comparePlans(parsePlan(sql1_view), - AlterViewSetProperties( + SetViewProperties( UnresolvedView(Seq("table_name"), "ALTER VIEW ... SET TBLPROPERTIES", false, hint), Map("test" -> "test", "comment" -> "new_comment"))) comparePlans(parsePlan(sql2_view), - AlterViewUnsetProperties( + UnsetViewProperties( UnresolvedView(Seq("table_name"), "ALTER VIEW ... UNSET TBLPROPERTIES", false, hint), Seq("comment", "test"), ifExists = false)) comparePlans(parsePlan(sql3_view), - AlterViewUnsetProperties( + UnsetViewProperties( UnresolvedView(Seq("table_name"), "ALTER VIEW ... UNSET TBLPROPERTIES", false, hint), Seq("comment", "test"), ifExists = true)) @@ -767,18 +767,18 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan(sql1_table), - AlterTableSetProperties( + SetTableProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... SET TBLPROPERTIES", hint), Map("test" -> "test", "comment" -> "new_comment"))) comparePlans( parsePlan(sql2_table), - AlterTableUnsetProperties( + UnsetTableProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... UNSET TBLPROPERTIES", hint), Seq("comment", "test"), ifExists = false)) comparePlans( parsePlan(sql3_table), - AlterTableUnsetProperties( + UnsetTableProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... UNSET TBLPROPERTIES", hint), Seq("comment", "test"), ifExists = true)) @@ -876,14 +876,14 @@ class DDLParserSuite extends AnalysisTest { val hint = Some("Please use ALTER VIEW instead.") comparePlans( parsePlan("ALTER TABLE a.b.c SET LOCATION 'new location'"), - AlterTableSetLocation( + SetTableLocation( UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... SET LOCATION ...", hint), None, "new location")) comparePlans( parsePlan("ALTER TABLE a.b.c PARTITION(ds='2017-06-10') SET LOCATION 'new location'"), - AlterTableSetLocation( + SetTableLocation( UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... 
SET LOCATION ...", hint), Some(Map("ds" -> "2017-06-10")), "new location")) @@ -1765,49 +1765,49 @@ class DDLParserSuite extends AnalysisTest { test("set namespace properties") { comparePlans( parsePlan("ALTER DATABASE a.b.c SET PROPERTIES ('a'='a', 'b'='b', 'c'='c')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a", "b" -> "b", "c" -> "c"))) comparePlans( parsePlan("ALTER SCHEMA a.b.c SET PROPERTIES ('a'='a')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a"))) comparePlans( parsePlan("ALTER NAMESPACE a.b.c SET PROPERTIES ('b'='b')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("b" -> "b"))) comparePlans( parsePlan("ALTER DATABASE a.b.c SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a", "b" -> "b", "c" -> "c"))) comparePlans( parsePlan("ALTER SCHEMA a.b.c SET DBPROPERTIES ('a'='a')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a"))) comparePlans( parsePlan("ALTER NAMESPACE a.b.c SET DBPROPERTIES ('b'='b')"), - AlterNamespaceSetProperties( + SetNamespaceProperties( UnresolvedNamespace(Seq("a", "b", "c")), Map("b" -> "b"))) } test("set namespace location") { comparePlans( parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"), - AlterNamespaceSetLocation( + SetNamespaceLocation( UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) comparePlans( parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"), - AlterNamespaceSetLocation( + SetNamespaceLocation( UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) comparePlans( parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"), - AlterNamespaceSetLocation( + SetNamespaceLocation( UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) } @@ -2064,7 +2064,7 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "ALTER TABLE table_name SET SERDE 'org.apache.class'" val hint = Some("Please use ALTER VIEW instead.") val parsed1 = parsePlan(sql1) - val expected1 = AlterTableSerDeProperties( + val expected1 = SetTableSerDeProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), Some("org.apache.class"), None, @@ -2077,7 +2077,7 @@ class DDLParserSuite extends AnalysisTest { |WITH SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed2 = parsePlan(sql2) - val expected2 = AlterTableSerDeProperties( + val expected2 = SetTableSerDeProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -2090,7 +2090,7 @@ class DDLParserSuite extends AnalysisTest { |SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed3 = parsePlan(sql3) - val expected3 = AlterTableSerDeProperties( + val expected3 = SetTableSerDeProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... 
SET [SERDE|SERDEPROPERTIES]", hint), None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -2104,7 +2104,7 @@ class DDLParserSuite extends AnalysisTest { |WITH SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed4 = parsePlan(sql4) - val expected4 = AlterTableSerDeProperties( + val expected4 = SetTableSerDeProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -2117,7 +2117,7 @@ class DDLParserSuite extends AnalysisTest { |SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed5 = parsePlan(sql5) - val expected5 = AlterTableSerDeProperties( + val expected5 = SetTableSerDeProperties( UnresolvedTable(Seq("table_name"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -2130,7 +2130,7 @@ class DDLParserSuite extends AnalysisTest { |WITH SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed6 = parsePlan(sql6) - val expected6 = AlterTableSerDeProperties( + val expected6 = SetTableSerDeProperties( UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -2143,7 +2143,7 @@ class DDLParserSuite extends AnalysisTest { |SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed7 = parsePlan(sql7) - val expected7 = AlterTableSerDeProperties( + val expected7 = SetTableSerDeProperties( UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]", hint), None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 4d974b4515bcb..7a8f4dd39080e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -167,25 +167,25 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) createAlterTable(nameParts, catalog, tbl, changes) } - case AlterTableSetProperties(ResolvedV1TableIdentifier(ident), props) => + case SetTableProperties(ResolvedV1TableIdentifier(ident), props) => AlterTableSetPropertiesCommand(ident.asTableIdentifier, props, isView = false) - case AlterTableUnsetProperties(ResolvedV1TableIdentifier(ident), keys, ifExists) => + case UnsetTableProperties(ResolvedV1TableIdentifier(ident), keys, ifExists) => AlterTableUnsetPropertiesCommand(ident.asTableIdentifier, keys, ifExists, isView = false) - case AlterViewSetProperties(ResolvedView(ident, _), props) => + case SetViewProperties(ResolvedView(ident, _), props) => AlterTableSetPropertiesCommand(ident.asTableIdentifier, props, isView = true) - case AlterViewUnsetProperties(ResolvedView(ident, _), keys, ifExists) => + case UnsetViewProperties(ResolvedView(ident, _), keys, ifExists) => AlterTableUnsetPropertiesCommand(ident.asTableIdentifier, keys, ifExists, isView = true) case d @ DescribeNamespace(DatabaseInSessionCatalog(db), _) => DescribeDatabaseCommand(db, d.extended) - case AlterNamespaceSetProperties(DatabaseInSessionCatalog(db), properties) => + case SetNamespaceProperties(DatabaseInSessionCatalog(db), properties) => 
AlterDatabasePropertiesCommand(db, properties) - case AlterNamespaceSetLocation(DatabaseInSessionCatalog(db), location) => + case SetNamespaceLocation(DatabaseInSessionCatalog(db), location) => AlterDatabaseSetLocationCommand(db, location) case s @ ShowNamespaces(ResolvedNamespace(cata, _), _, output) if isSessionCatalog(cata) => @@ -417,24 +417,24 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } ShowColumnsCommand(db, v1TableName, output) - case AlterTableRecoverPartitions(ResolvedV1TableIdentifier(ident)) => + case RecoverPartitions(ResolvedV1TableIdentifier(ident)) => AlterTableRecoverPartitionsCommand( ident.asTableIdentifier, "ALTER TABLE RECOVER PARTITIONS") - case AlterTableAddPartition(ResolvedV1TableIdentifier(ident), partSpecsAndLocs, ifNotExists) => + case AddPartitions(ResolvedV1TableIdentifier(ident), partSpecsAndLocs, ifNotExists) => AlterTableAddPartitionCommand( ident.asTableIdentifier, partSpecsAndLocs.asUnresolvedPartitionSpecs.map(spec => (spec.spec, spec.location)), ifNotExists) - case AlterTableRenamePartition( + case RenamePartitions( ResolvedV1TableIdentifier(ident), UnresolvedPartitionSpec(from, _), UnresolvedPartitionSpec(to, _)) => AlterTableRenamePartitionCommand(ident.asTableIdentifier, from, to) - case AlterTableDropPartition( + case DropPartitions( ResolvedV1TableIdentifier(ident), specs, ifExists, purge) => AlterTableDropPartitionCommand( ident.asTableIdentifier, @@ -443,7 +443,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) purge, retainData = false) - case AlterTableSerDeProperties( + case SetTableSerDeProperties( ResolvedV1TableIdentifier(ident), serdeClassName, serdeProperties, @@ -454,7 +454,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) serdeProperties, partitionSpec) - case AlterTableSetLocation(ResolvedV1TableIdentifier(ident), partitionSpec, location) => + case SetTableLocation(ResolvedV1TableIdentifier(ident), partitionSpec, location) => AlterTableSetLocationCommand(ident.asTableIdentifier, partitionSpec, location) case AlterViewAs(ResolvedView(ident, _), originalText, query) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 976c7df841dd9..c633442d7b2aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -306,10 +306,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat invalidateTableCache(r), session.sharedState.cacheManager.cacheQuery) :: Nil - case AlterNamespaceSetProperties(ResolvedNamespace(catalog, ns), properties) => + case SetNamespaceProperties(ResolvedNamespace(catalog, ns), properties) => AlterNamespaceSetPropertiesExec(catalog.asNamespaceCatalog, ns, properties) :: Nil - case AlterNamespaceSetLocation(ResolvedNamespace(catalog, ns), location) => + case SetNamespaceLocation(ResolvedNamespace(catalog, ns), location) => AlterNamespaceSetPropertiesExec( catalog.asNamespaceCatalog, ns, @@ -352,7 +352,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AnalyzeTable(_: ResolvedTable, _, _) | AnalyzeColumn(_: ResolvedTable, _, _) => throw new AnalysisException("ANALYZE TABLE is not supported for v2 tables.") - case AlterTableAddPartition( + case AddPartitions( r @ ResolvedTable(_, _, 
table: SupportsPartitionManagement, _), parts, ignoreIfExists) => AddPartitionExec( table, @@ -360,7 +360,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat ignoreIfExists, recacheTable(r)) :: Nil - case AlterTableDropPartition( + case DropPartitions( r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _), parts, ignoreIfNotExists, @@ -372,7 +372,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat purge, recacheTable(r)) :: Nil - case AlterTableRenamePartition( + case RenamePartitions( r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _), from, to) => RenamePartitionExec( table, @@ -380,11 +380,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat Seq(to).asResolvedPartitionSpecs.head, recacheTable(r)) :: Nil - case AlterTableRecoverPartitions(_: ResolvedTable) => + case RecoverPartitions(_: ResolvedTable) => throw new AnalysisException( "ALTER TABLE ... RECOVER PARTITIONS is not supported for v2 tables.") - case AlterTableSerDeProperties(_: ResolvedTable, _, _, _) => + case SetTableSerDeProperties(_: ResolvedTable, _, _, _) => throw new AnalysisException( "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] is not supported for v2 tables.") @@ -421,20 +421,20 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case r: UncacheTable => UncacheTableExec(r.table, cascade = !r.isTempView) :: Nil - case AlterTableSetLocation(table: ResolvedTable, partitionSpec, location) => + case SetTableLocation(table: ResolvedTable, partitionSpec, location) => if (partitionSpec.nonEmpty) { throw QueryCompilationErrors.alterV2TableSetLocationWithPartitionNotSupportedError } val changes = Seq(TableChange.setProperty(TableCatalog.PROP_LOCATION, location)) AlterTableExec(table.catalog, table.identifier, changes) :: Nil - case AlterTableSetProperties(table: ResolvedTable, props) => + case SetTableProperties(table: ResolvedTable, props) => val changes = props.map { case (key, value) => TableChange.setProperty(key, value) }.toSeq AlterTableExec(table.catalog, table.identifier, changes) :: Nil - case AlterTableUnsetProperties(table: ResolvedTable, keys, _) => + case UnsetTableProperties(table: ResolvedTable, keys, _) => val changes = keys.map(key => TableChange.removeProperty(key)) AlterTableExec(table.catalog, table.identifier, changes) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 884a389f94969..96ac3a6f7e5cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdenti import org.apache.spark.sql.catalyst.analysis.UnresolvedTable import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.{AlterTableRecoverPartitions, LocalRelation, LogicalPlan, SubqueryAlias, View} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, RecoverPartitions, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType @@ -448,7 +448,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { 
override def recoverPartitions(tableName: String): Unit = { val multiPartIdent = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) sparkSession.sessionState.executePlan( - AlterTableRecoverPartitions( + RecoverPartitions( UnresolvedTable(multiPartIdent, "recoverPartitions()", None))).toRdd } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d0441ac28631a..533428f9504b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1222,7 +1222,7 @@ class DataSourceV2SQLSuite } } - test("AlterNamespaceSetProperties using v2 catalog") { + test("ALTER NAMESPACE .. SET PROPERTIES using v2 catalog") { withNamespace("testcat.ns1.ns2") { sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " + "'test namespace' LOCATION '/tmp/ns_test' WITH PROPERTIES ('a'='a','b'='b','c'='c')") @@ -1238,7 +1238,7 @@ class DataSourceV2SQLSuite } } - test("AlterNamespaceSetProperties: reserved properties") { + test("ALTER NAMESPACE .. SET PROPERTIES reserved properties") { import SupportsNamespaces._ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => @@ -1269,7 +1269,7 @@ class DataSourceV2SQLSuite } } - test("AlterNamespaceSetLocation using v2 catalog") { + test("ALTER NAMESPACE .. SET LOCATION using v2 catalog") { withNamespace("testcat.ns1.ns2") { sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " + "'test namespace' LOCATION '/tmp/ns_test_1'") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionParserSuite.scala index 1ec0f45f66118..1694c73b10f2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionParserSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedPartitionSpec, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.plans.logical.AlterTableAddPartition +import org.apache.spark.sql.catalyst.plans.logical.AddPartitions import org.apache.spark.sql.test.SharedSparkSession class AlterTableAddPartitionParserSuite extends AnalysisTest with SharedSparkSession { @@ -29,7 +29,7 @@ class AlterTableAddPartitionParserSuite extends AnalysisTest with SharedSparkSes |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION |(dt='2009-09-09', country='uk')""".stripMargin val parsed = parsePlan(sql) - val expected = AlterTableAddPartition( + val expected = AddPartitions( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ...", @@ -44,7 +44,7 @@ class AlterTableAddPartitionParserSuite extends AnalysisTest with SharedSparkSes test("add partition") { val sql = "ALTER TABLE a.b.c ADD PARTITION (dt='2008-08-08') LOCATION 'loc'" val parsed = parsePlan(sql) - val expected = AlterTableAddPartition( + val expected = AddPartitions( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... 
ADD PARTITION ...", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala index b48ca16a6bb45..4c60c80f4e054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedPartitionSpec, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AlterTableDropPartition +import org.apache.spark.sql.catalyst.plans.logical.DropPartitions import org.apache.spark.sql.test.SharedSparkSession class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSession { @@ -29,7 +29,7 @@ class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSe |ALTER TABLE table_name DROP PARTITION |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') """.stripMargin - val expected = AlterTableDropPartition( + val expected = DropPartitions( UnresolvedTable( Seq("table_name"), "ALTER TABLE ... DROP PARTITION ...", @@ -49,7 +49,7 @@ class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSe |PARTITION (dt='2008-08-08', country='us'), |PARTITION (dt='2009-09-09', country='uk') """.stripMargin - val expected = AlterTableDropPartition( + val expected = DropPartitions( UnresolvedTable( Seq("table_name"), "ALTER TABLE ... DROP PARTITION ...", @@ -64,7 +64,7 @@ class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSe test("drop partition in a table with multi-part identifier") { val sql = "ALTER TABLE a.b.c DROP IF EXISTS PARTITION (ds='2017-06-10')" - val expected = AlterTableDropPartition( + val expected = DropPartitions( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... DROP PARTITION ...", @@ -78,7 +78,7 @@ class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSe test("drop partition with PURGE") { val sql = "ALTER TABLE table_name DROP PARTITION (p=1) PURGE" - val expected = AlterTableDropPartition( + val expected = DropPartitions( UnresolvedTable( Seq("table_name"), "ALTER TABLE ... 
DROP PARTITION ...", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala index 04251b665c05e..ebc1bd3468837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AlterTableRecoverPartitions +import org.apache.spark.sql.catalyst.plans.logical.RecoverPartitions import org.apache.spark.sql.test.SharedSparkSession class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSparkSession { @@ -35,7 +35,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa test("recover partitions of a table") { comparePlans( parsePlan("ALTER TABLE tbl RECOVER PARTITIONS"), - AlterTableRecoverPartitions( + RecoverPartitions( UnresolvedTable( Seq("tbl"), "ALTER TABLE ... RECOVER PARTITIONS", @@ -45,7 +45,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa test("recover partitions of a table in a database") { comparePlans( parsePlan("alter table db.tbl recover partitions"), - AlterTableRecoverPartitions( + RecoverPartitions( UnresolvedTable( Seq("db", "tbl"), "ALTER TABLE ... RECOVER PARTITIONS", @@ -55,7 +55,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa test("recover partitions of a table spark_catalog") { comparePlans( parsePlan("alter table spark_catalog.db.TBL recover partitions"), - AlterTableRecoverPartitions( + RecoverPartitions( UnresolvedTable( Seq("spark_catalog", "db", "TBL"), "ALTER TABLE ... RECOVER PARTITIONS", @@ -65,7 +65,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa test("recover partitions of a table in nested namespaces") { comparePlans( parsePlan("Alter Table ns1.ns2.ns3.ns4.ns5.ns6.ns7.ns8.t Recover Partitions"), - AlterTableRecoverPartitions( + RecoverPartitions( UnresolvedTable( Seq("ns1", "ns2", "ns3", "ns4", "ns5", "ns6", "ns7", "ns8", "t"), "ALTER TABLE ... 
RECOVER PARTITIONS", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala index 5f2856f071df7..4148798d6cdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedPartitionSpec, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.plans.logical.AlterTableRenamePartition +import org.apache.spark.sql.catalyst.plans.logical.RenamePartitions import org.apache.spark.sql.test.SharedSparkSession class AlterTableRenamePartitionParserSuite extends AnalysisTest with SharedSparkSession { @@ -29,7 +29,7 @@ class AlterTableRenamePartitionParserSuite extends AnalysisTest with SharedSpark |RENAME TO PARTITION (ds='2018-06-10') """.stripMargin val parsed = parsePlan(sql) - val expected = AlterTableRenamePartition( + val expected = RenamePartitions( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... RENAME TO PARTITION", @@ -45,7 +45,7 @@ class AlterTableRenamePartitionParserSuite extends AnalysisTest with SharedSpark |RENAME TO PARTITION (dt='2008-09-09', country='uk') """.stripMargin val parsed = parsePlan(sql) - val expected = AlterTableRenamePartition( + val expected = RenamePartitions( UnresolvedTable( Seq("table_name"), "ALTER TABLE ... RENAME TO PARTITION", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 848f15c0a6a9e..17c44bc9ac768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunc import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AlterTableSetLocation, AlterTableSetProperties, AlterTableUnsetProperties, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import 
org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table} @@ -780,23 +780,23 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed3, expected3) } else { parsed1 match { - case AlterTableSetProperties(_: ResolvedTable, properties) => + case SetTableProperties(_: ResolvedTable, properties) => assert(properties == Map(("test", "test"), ("comment", "new_comment"))) - case _ => fail("expect AlterTableSetProperties") + case _ => fail(s"expect ${SetTableProperties.getClass.getName}") } parsed2 match { - case AlterTableUnsetProperties(_: ResolvedTable, propertyKeys, ifExists) => + case UnsetTableProperties(_: ResolvedTable, propertyKeys, ifExists) => assert(propertyKeys == Seq("comment", "test")) assert(!ifExists) - case _ => fail("expect AlterTableUnsetProperties") + case _ => fail(s"expect ${UnsetTableProperties.getClass.getName}") } parsed3 match { - case AlterTableUnsetProperties(_: ResolvedTable, propertyKeys, ifExists) => + case UnsetTableProperties(_: ResolvedTable, propertyKeys, ifExists) => assert(propertyKeys == Seq("comment", "test")) assert(ifExists) - case _ => fail("expect AlterTableUnsetProperties") + case _ => fail(s"expect ${UnsetTableProperties.getClass.getName}") } } } @@ -808,12 +808,14 @@ class PlanResolutionSuite extends AnalysisTest { // For non-existing tables, we convert it to v2 command with `UnresolvedV2Table` parsed4 match { - case AlterTableSetProperties(_: UnresolvedTable, _) => // OK - case _ => fail("Expect AlterTableSetProperties, but got:\n" + parsed4.treeString) + case SetTableProperties(_: UnresolvedTable, _) => // OK + case _ => + fail(s"Expect ${SetTableProperties.getClass.getName}, but got:\n" + parsed4.treeString) } parsed5 match { - case AlterTableUnsetProperties(_: UnresolvedTable, _, _) => // OK - case _ => fail("Expect AlterTableUnsetProperties, but got:\n" + parsed5.treeString) + case UnsetTableProperties(_: UnresolvedTable, _, _) => // OK + case _ => + fail(s"Expect ${UnsetTableProperties.getClass.getName}, but got:\n" + parsed5.treeString) } } @@ -835,7 +837,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed, expected) } else { parsed match { - case AlterTableSetProperties(_: ResolvedTable, changes) => + case SetTableProperties(_: ResolvedTable, changes) => assert(changes == Map(("a", "1"), ("b", "0.1"), ("c", "true"))) case _ => fail("Expect AlterTable, but got:\n" + parsed.treeString) } @@ -856,9 +858,10 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed, expected) } else { parsed match { - case AlterTableSetLocation(_: ResolvedTable, _, location) => + case SetTableLocation(_: ResolvedTable, _, location) => assert(location === "new location") - case _ => fail("Expect AlterTableSetLocation, but got:\n" + parsed.treeString) + case _ => + fail(s"Expect ${SetTableLocation.getClass.getName}, but got:\n" + parsed.treeString) } } } From 429f8af9b683935151c2379bc80b27162cd1c8bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Feb 2021 02:38:22 +0800 Subject: [PATCH 05/60] Revert "[SPARK-34380][SQL] Support ifExists for ALTER TABLE ... UNSET TBLPROPERTIES for v2 command" This reverts commit 9a566f83a0e126742473574476c6381f58394aed. 
--- .../sql/catalyst/analysis/Analyzer.scala | 1 - .../analysis/ResolveTableProperties.scala | 43 ------------------- .../datasources/v2/DataSourceV2Strategy.scala | 1 + .../spark/sql/connector/AlterTableTests.scala | 30 ------------- .../command/PlanResolutionSuite.scala | 9 ---- 5 files changed, 1 insertion(+), 83 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 38259c234c262..182f456afa9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -278,7 +278,6 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveRandomSeed :: ResolveBinaryArithmetic :: ResolveUnion :: - ResolveTableProperties :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Apply Char Padding", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala deleted file mode 100644 index 12b1502dd7c38..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableProperties.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnsetTableProperties} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper -import org.apache.spark.sql.connector.catalog.TableCatalog - -/** - * A rule for resolving [[UnsetTableProperties]] to handle non-existent properties. 
- */ -object ResolveTableProperties extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case a @ UnsetTableProperties(r: ResolvedTable, props, ifExists) if !ifExists => - val tblProperties = r.table.properties.asScala - props.foreach { p => - if (!tblProperties.contains(p) && p != TableCatalog.PROP_COMMENT) { - throw new AnalysisException( - s"Attempted to unset non-existent property '$p' in table '${r.identifier.quoted}'") - } - } - a - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c633442d7b2aa..3eed7160b6dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -434,6 +434,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toSeq AlterTableExec(table.catalog, table.identifier, changes) :: Nil + // TODO: v2 `UNSET TBLPROPERTIES` should respect the ifExists flag. case UnsetTableProperties(table: ResolvedTable, keys, _) => val changes = keys.map(key => TableChange.removeProperty(key)) AlterTableExec(table.catalog, table.identifier, changes) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 3fc23db9e7000..afc51f45c54ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -1141,36 +1141,6 @@ trait AlterTableTests extends SharedSparkSession { } } - test("SPARK-34380: unset nonexistent table property") { - val t = s"${catalogAndNamespace}table_name" - withTable(t) { - sql(s"CREATE TABLE $t (id int) USING $v2Format TBLPROPERTIES('test' = '34')") - - val tableName = fullTableName(t) - val table = getTableMetadata(tableName) - - assert(table.name === tableName) - assert(table.properties === - withDefaultOwnership(Map("provider" -> v2Format, "test" -> "34")).asJava) - - val exc = intercept[AnalysisException] { - sql(s"ALTER TABLE $t UNSET TBLPROPERTIES ('unknown')") - } - assert(exc.getMessage.contains("Attempted to unset non-existent property 'unknown'")) - - // Reserved property "comment" should be allowed regardless. - sql(s"ALTER TABLE $t UNSET TBLPROPERTIES ('comment')") - - // The following becomes a no-op because "IF EXISTS" is set. 
- sql(s"ALTER TABLE $t UNSET TBLPROPERTIES IF EXISTS ('unknown')") - - val updated = getTableMetadata(tableName) - assert(updated.name === tableName) - assert(updated.properties === - withDefaultOwnership(Map("provider" -> v2Format, "test" -> "34")).asJava) - } - } - test("AlterTable: replace columns") { val t = s"${catalogAndNamespace}table_name" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 17c44bc9ac768..1b090369f2a23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.command import java.net.URI import java.util.{Collections, Locale} -import scala.collection.JavaConverters._ - import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock @@ -54,7 +52,6 @@ class PlanResolutionSuite extends AnalysisTest { val t = mock(classOf[Table]) when(t.schema()).thenReturn(new StructType().add("i", "int").add("s", "string")) when(t.partitioning()).thenReturn(Array.empty[Transform]) - when(t.properties()).thenReturn(Map("test" ->"test", "comment" -> "new_comment").asJava) t } @@ -70,7 +67,6 @@ class PlanResolutionSuite extends AnalysisTest { when(t.schema).thenReturn(new StructType().add("i", "int").add("s", "string")) when(t.tableType).thenReturn(CatalogTableType.MANAGED) when(t.provider).thenReturn(Some(v1Format)) - when(t.properties).thenReturn(Map("test" ->"test", "comment" -> "new_comment")) V1Table(t) } @@ -756,15 +752,10 @@ class PlanResolutionSuite extends AnalysisTest { "'comment' = 'new_comment')" val sql2 = s"ALTER TABLE $tblName UNSET TBLPROPERTIES ('comment', 'test')" val sql3 = s"ALTER TABLE $tblName UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - val sql4 = s"ALTER TABLE $tblName UNSET TBLPROPERTIES ('unknown')" val parsed1 = parseAndResolve(sql1) val parsed2 = parseAndResolve(sql2) val parsed3 = parseAndResolve(sql3) - val e = intercept[AnalysisException] { - parseAndResolve(sql4) - } - e.getMessage.contains("Attempted to unset non-existent property 'unknown'") if (useV1Command) { val tableIdent = TableIdentifier(tblName, Some("default")) From 443139b601ca87cb0e0c9c2f906d4d4c1e624e35 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 23 Feb 2021 12:18:43 -0800 Subject: [PATCH 06/60] [SPARK-34502][SQL] Remove unused parameters in join methods ### What changes were proposed in this pull request? Remove unused parameters in `CoalesceBucketsInJoin`, `UnsafeCartesianRDD` and `ShuffledHashJoinExec`. ### Why are the changes needed? Clean up ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #31617 from huaxingao/join-minor. 
Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../execution/bucketing/CoalesceBucketsInJoin.scala | 8 +++----- .../sql/execution/joins/CartesianProductExec.scala | 2 -- .../sql/execution/joins/ShuffledHashJoinExec.scala | 10 ++++------ 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala index a4e5be01b45a2..d50c7fd283a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala @@ -50,7 +50,6 @@ object CoalesceBucketsInJoin extends Rule[SparkPlan] { private def updateNumCoalescedBuckets( join: BaseJoinExec, numLeftBuckets: Int, - numRightBucket: Int, numCoalescedBuckets: Int): BaseJoinExec = { if (numCoalescedBuckets != numLeftBuckets) { val leftCoalescedChild = @@ -72,7 +71,6 @@ object CoalesceBucketsInJoin extends Rule[SparkPlan] { private def isCoalesceSHJStreamSide( join: ShuffledHashJoinExec, numLeftBuckets: Int, - numRightBucket: Int, numCoalescedBuckets: Int): Boolean = { if (numCoalescedBuckets == numLeftBuckets) { join.buildSide != BuildRight @@ -93,12 +91,12 @@ object CoalesceBucketsInJoin extends Rule[SparkPlan] { val numCoalescedBuckets = math.min(numLeftBuckets, numRightBuckets) join match { case j: SortMergeJoinExec => - updateNumCoalescedBuckets(j, numLeftBuckets, numRightBuckets, numCoalescedBuckets) + updateNumCoalescedBuckets(j, numLeftBuckets, numCoalescedBuckets) case j: ShuffledHashJoinExec // Only coalesce the buckets for shuffled hash join stream side, // to avoid OOM for build side. 
- if isCoalesceSHJStreamSide(j, numLeftBuckets, numRightBuckets, numCoalescedBuckets) => - updateNumCoalescedBuckets(j, numLeftBuckets, numRightBuckets, numCoalescedBuckets) + if isCoalesceSHJStreamSide(j, numLeftBuckets, numCoalescedBuckets) => + updateNumCoalescedBuckets(j, numLeftBuckets, numCoalescedBuckets) case other => other } case other => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index a71bf94c45034..b6386d0d11b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -35,7 +35,6 @@ import org.apache.spark.util.CompletionIterator class UnsafeCartesianRDD( left : RDD[UnsafeRow], right : RDD[UnsafeRow], - numFieldsOfRight: Int, inMemoryBufferThreshold: Int, spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { @@ -81,7 +80,6 @@ case class CartesianProductExec( val pair = new UnsafeCartesianRDD( leftResults, rightResults, - right.output.size, sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, sqlContext.conf.cartesianProductExecBufferSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 1dc7a3b7eecb3..cd57408e7972d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -117,10 +117,10 @@ case class ShuffledHashJoinExec( val iter = if (hashedRelation.keyIsUnique) { fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, - joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow) } else { fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, - joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow) } val resultProj = UnsafeProjection.create(output, output) @@ -146,8 +146,7 @@ case class ShuffledHashJoinExec( joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, streamNullJoinRowWithBuild: => InternalRow => JoinedRow, - buildNullRow: GenericInternalRow, - streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + buildNullRow: GenericInternalRow): Iterator[InternalRow] = { val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) longMetric("buildDataSize") += matchedKeys.capacity / 8 @@ -213,8 +212,7 @@ case class ShuffledHashJoinExec( joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, streamNullJoinRowWithBuild: => InternalRow => JoinedRow, - buildNullRow: GenericInternalRow, - streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + buildNullRow: GenericInternalRow): Iterator[InternalRow] = { val matchedRows = new OpenHashSet[Long] TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's memory usage for this From 0d5d248bdc4cdc71627162a3d20c42ad19f24ef4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Feb 2021 13:35:29 -0800 
Subject: [PATCH 07/60] [SPARK-34508][SQL][TEST] Skip HiveExternalCatalogVersionsSuite if network is down ### What changes were proposed in this pull request? It's possible that the network is down when running Spark tests, and it's annoying to see `HiveExternalCatalogVersionsSuite` keep failing. This PR proposes to skip this test suite if we can't get the latest Spark version from the Apache website. ### Why are the changes needed? Make the Spark tests more robust. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? N/A Closes #31627 from cloud-fan/test. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 766edcae6e4a1..c8473bf28b746 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -194,7 +194,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { // scalastyle:on line.size.limit if (PROCESS_TABLES.testingVersions.isEmpty) { - fail("Fail to get the lates Spark versions to test.") + logError("Fail to get the latest Spark versions to test.") } PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => @@ -232,7 +232,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", unusedJar.toString) - runSparkSubmit(args) + if (PROCESS_TABLES.testingVersions.nonEmpty) runSparkSubmit(args) } } @@ -251,8 +251,8 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .map("""""".r.findFirstMatchIn(_).get.group(1)) .filter(_ < org.apache.spark.SPARK_VERSION) } catch { - // do not throw exception during object initialization. - case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version + // Do not throw exception during object initialization. + case NonFatal(_) => Nil } versions .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) From 95e45c6257a614754e132f92b7b7239573d42b7a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Feb 2021 13:41:24 -0800 Subject: [PATCH 08/60] [SPARK-34168][SQL][FOLLOWUP] Improve DynamicPartitionPruningSuiteBase ### What changes were proposed in this pull request? A few minor improvements for `DynamicPartitionPruningSuiteBase`. ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #31625 from cloud-fan/followup. 
Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../sql/DynamicPartitionPruningSuite.scala | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index cd7c4415d6f2b..bc9c3006cddc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, DisableAdaptiveExecution, EliminateJoinToEmptyRelation} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec} +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.functions._ @@ -44,14 +44,9 @@ abstract class DynamicPartitionPruningSuiteBase import testImplicits._ - val adaptiveExecutionOn: Boolean - override def beforeAll(): Unit = { super.beforeAll() - spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, adaptiveExecutionOn) - spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY, true) - val factData = Seq[(Int, Int, Int, Int)]( (1000, 1, 1, 10), (1010, 2, 1, 10), @@ -195,8 +190,8 @@ abstract class DynamicPartitionPruningSuiteBase subqueryBroadcast.foreach { s => s.child match { case _: ReusedExchangeExec => // reuse check ok. - case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => - case b: BroadcastExchangeExec => + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok. 
+ case b: BroadcastExchangeLike => val hasReuse = plan.find { case ReusedExchangeExec(_, e) => e eq b case _ => false @@ -337,7 +332,7 @@ abstract class DynamicPartitionPruningSuiteBase def getFactScan(plan: SparkPlan): SparkPlan = { val scanOption = - plan.find { + find(plan) { case s: FileSourceScanExec => s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined) case _ => false @@ -1261,7 +1256,7 @@ abstract class DynamicPartitionPruningSuiteBase val countSubqueryBroadcasts = collectWithSubqueries(plan)({ case _: SubqueryBroadcastExec => 1 }).sum - if (adaptiveExecutionOn) { + if (conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) { val countReusedSubqueryBroadcasts = collectWithSubqueries(plan)({ case ReusedSubqueryExec(_: SubqueryBroadcastExec) => 1}).sum @@ -1390,10 +1385,8 @@ abstract class DynamicPartitionPruningSuiteBase } } -class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase { - override val adaptiveExecutionOn: Boolean = false -} +class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase + with DisableAdaptiveExecutionSuite -class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase { - override val adaptiveExecutionOn: Boolean = true -} +class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase + with EnableAdaptiveExecutionSuite From 7f27d33a3c538da6754a6c011b29aa7eb0dafe2c Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 23 Feb 2021 13:45:15 -0800 Subject: [PATCH 09/60] [SPARK-31891][SQL] Support `MSCK REPAIR TABLE .. [{ADD|DROP|SYNC} PARTITIONS]` ### What changes were proposed in this pull request? In the PR, I propose to extend the `MSCK REPAIR TABLE` command, and support new options `{ADD|DROP|SYNC} PARTITIONS`. In particular: 1. Extend the logical node `RepairTable`, and add two new flags `enableAddPartitions` and `enableDropPartitions`. 2. Add similar flags to the v1 execution node `AlterTableRecoverPartitionsCommand`. 3. Add a new method `dropPartitions()` to `AlterTableRecoverPartitionsCommand` which drops partitions from the catalog if their locations in the file system don't exist. 4. Update the public docs about the `MSCK REPAIR TABLE` command: Screenshot 2021-02-16 at 13 46 39 Closes #31097 ### Why are the changes needed? - The changes allow recovering tables with removed partitions. The example below illustrates the problem: ```sql spark-sql> create table tbl2 (col int, part int) partitioned by (part); spark-sql> insert into tbl2 partition (part=1) select 1; spark-sql> insert into tbl2 partition (part=0) select 0; spark-sql> show table extended like 'tbl2' partition (part = 0); default tbl2 false Partition Values: [part=0] Location: file:/Users/maximgekk/proj/apache-spark/spark-warehouse/tbl2/part=0 ... ``` Remove the partition (part = 0) from the filesystem: ``` $ rm -rf /Users/maximgekk/proj/apache-spark/spark-warehouse/tbl2/part=0 ``` Even after recovering, we cannot query the table: ```sql spark-sql> msck repair table tbl2; spark-sql> select * from tbl2; 21/01/08 22:49:13 ERROR SparkSQLDriver: Failed in [select * from tbl2] org.apache.hadoop.mapred.InvalidInputException: Input path does not exist: file:/Users/maximgekk/proj/apache-spark/spark-warehouse/tbl2/part=0 ``` - To have feature parity with Hive: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-RecoverPartitions(MSCKREPAIRTABLE) ### Does this PR introduce _any_ user-facing change? Yes.
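To summarize the user-facing syntax, here is a minimal sketch assembled from the grammar and doc changes in this patch, reusing the example table `tbl2` from above:

```sql
-- Register partition folders that exist on disk but not in the session catalog
-- (this is also the behavior when no clause is given).
MSCK REPAIR TABLE tbl2 ADD PARTITIONS;
-- Drop catalog partitions whose locations no longer exist in the file system.
MSCK REPAIR TABLE tbl2 DROP PARTITIONS;
-- Combination of ADD and DROP.
MSCK REPAIR TABLE tbl2 SYNC PARTITIONS;
```

These correspond to the `[{ADD|DROP|SYNC} PARTITIONS]` forms documented in `sql-ref-syntax-ddl-repair-table.md` below.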
After the changes, we can query recovered table: ```sql spark-sql> msck repair table tbl2 sync partitions; spark-sql> select * from tbl2; 1 1 spark-sql> show partitions tbl2; part=1 ``` ### How was this patch tested? - By running the modified test suite: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *MsckRepairTableParserSuite" $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *PlanResolutionSuite" $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *AlterTableRecoverPartitionsSuite" $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *AlterTableRecoverPartitionsParallelSuite" ``` - Added unified v1 and v2 tests for `MSCK REPAIR TABLE`: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *MsckRepairTableSuite" ``` Closes #31499 from MaxGekk/repair-table-drop-partitions. Authored-by: Max Gekk Signed-off-by: Dongjoon Hyun --- docs/sql-ref-ansi-compliance.md | 1 + docs/sql-ref-syntax-ddl-repair-table.md | 9 ++- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../sql/catalyst/parser/AstBuilder.scala | 17 +++- .../catalyst/plans/logical/v2Commands.scala | 5 +- .../sql/catalyst/parser/DDLParserSuite.scala | 6 -- .../analysis/ResolveSessionCatalog.scala | 10 ++- .../command/createDataSourceTables.scala | 5 +- .../spark/sql/execution/command/ddl.scala | 72 +++++++++++------ .../datasources/v2/DataSourceV2Strategy.scala | 2 +- .../command/DDLCommandTestUtils.scala | 25 ++++++ .../command/MsckRepairTableParserSuite.scala | 69 +++++++++++++++++ .../command/MsckRepairTableSuiteBase.scala | 38 +++++++++ .../v1/AlterTableAddPartitionSuite.scala | 22 ------ .../command/v1/MsckRepairTableSuite.scala | 77 +++++++++++++++++++ .../command/v2/MsckRepairTableSuite.scala | 41 ++++++++++ .../command/MsckRepairTableSuite.scala | 26 +++++++ 17 files changed, 372 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableParserSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableSuiteBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/MsckRepairTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/MsckRepairTableSuite.scala diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index f97b166206396..48f6a6e508a8c 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -397,6 +397,7 @@ Below is a list of all the keywords in Spark SQL. 
|STRUCT|non-reserved|non-reserved|non-reserved| |SUBSTR|non-reserved|non-reserved|non-reserved| |SUBSTRING|non-reserved|non-reserved|non-reserved| +|SYNC|non-reserved|non-reserved|non-reserved| |TABLE|reserved|non-reserved|reserved| |TABLES|non-reserved|non-reserved|non-reserved| |TABLESAMPLE|non-reserved|non-reserved|reserved| diff --git a/docs/sql-ref-syntax-ddl-repair-table.md b/docs/sql-ref-syntax-ddl-repair-table.md index 36145126d2496..41499c3314c45 100644 --- a/docs/sql-ref-syntax-ddl-repair-table.md +++ b/docs/sql-ref-syntax-ddl-repair-table.md @@ -28,7 +28,7 @@ If the table is cached, the command clears cached data of the table and all its ### Syntax ```sql -MSCK REPAIR TABLE table_identifier +MSCK REPAIR TABLE table_identifier [{ADD|DROP|SYNC} PARTITIONS] ``` ### Parameters @@ -39,6 +39,13 @@ MSCK REPAIR TABLE table_identifier **Syntax:** `[ database_name. ] table_name` +* **`{ADD|DROP|SYNC} PARTITIONS`** + + * If specified, `MSCK REPAIR TABLE` only adds partitions to the session catalog. + * **ADD**, the command adds new partitions to the session catalog for all sub-folder in the base table folder that don't belong to any table partitions. + * **DROP**, the command drops all partitions from the session catalog that have non-existing locations in the file system. + * **SYNC** is the combination of **DROP** and **ADD**. + ### Examples ```sql diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ab4b7833503fb..50ef3764f3994 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -229,7 +229,8 @@ statement | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE multipartIdentifier partitionSpec? #loadData | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable - | MSCK REPAIR TABLE multipartIdentifier #repairTable + | MSCK REPAIR TABLE multipartIdentifier + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable | op=(ADD | LIST) identifier (STRING | .*?) #manageResource | SET ROLE .*? #failNativeCommand | SET TIME ZONE interval #setTimeZone @@ -1173,6 +1174,7 @@ ansiNonReserved | STRUCT | SUBSTR | SUBSTRING + | SYNC | TABLES | TABLESAMPLE | TBLPROPERTIES @@ -1429,6 +1431,7 @@ nonReserved | STRUCT | SUBSTR | SUBSTRING + | SYNC | TABLE | TABLES | TABLESAMPLE @@ -1687,6 +1690,7 @@ STRATIFY: 'STRATIFY'; STRUCT: 'STRUCT'; SUBSTR: 'SUBSTR'; SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; TABLE: 'TABLE'; TABLES: 'TABLES'; TABLESAMPLE: 'TABLESAMPLE'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 595a3a5ba5332..23f9c8398b727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3659,11 +3659,24 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * * For example: * {{{ - * MSCK REPAIR TABLE multi_part_name + * MSCK REPAIR TABLE multi_part_name [{ADD|DROP|SYNC} PARTITIONS] * }}} */ override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) { - RepairTable(createUnresolvedTable(ctx.multipartIdentifier, "MSCK REPAIR TABLE")) + val (enableAddPartitions, enableDropPartitions, option) = + if (ctx.SYNC() != null) { + (true, true, " ... 
SYNC PARTITIONS") + } else if (ctx.DROP() != null) { + (false, true, " ... DROP PARTITIONS") + } else if (ctx.ADD() != null) { + (true, false, " ... ADD PARTITIONS") + } else { + (true, false, "") + } + RepairTable( + createUnresolvedTable(ctx.multipartIdentifier, s"MSCK REPAIR TABLE$option"), + enableAddPartitions, + enableDropPartitions) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 8797b107f945a..12f13e73eadf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -802,7 +802,10 @@ case class DropView( /** * The logical plan of the MSCK REPAIR TABLE command. */ -case class RepairTable(child: LogicalPlan) extends Command { +case class RepairTable( + child: LogicalPlan, + enableAddPartitions: Boolean, + enableDropPartitions: Boolean) extends Command { override def children: Seq[LogicalPlan] = child :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index cb9dda8260a50..870ff388edc1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1913,12 +1913,6 @@ class DDLParserSuite extends AnalysisTest { "missing 'COLUMNS' at ''") } - test("MSCK REPAIR TABLE") { - comparePlans( - parsePlan("MSCK REPAIR TABLE a.b.c"), - RepairTable(UnresolvedTable(Seq("a", "b", "c"), "MSCK REPAIR TABLE", None))) - } - test("LOAD DATA INTO table") { comparePlans( parsePlan("LOAD DATA INPATH 'filepath' INTO TABLE a.b.c"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 7a8f4dd39080e..55e8c5fba0d3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -376,8 +376,12 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case AnalyzeColumn(ResolvedV1TableOrViewIdentifier(ident), columnNames, allColumns) => AnalyzeColumnCommand(ident.asTableIdentifier, columnNames, allColumns) - case RepairTable(ResolvedV1TableIdentifier(ident)) => - AlterTableRecoverPartitionsCommand(ident.asTableIdentifier, "MSCK REPAIR TABLE") + case RepairTable(ResolvedV1TableIdentifier(ident), addPartitions, dropPartitions) => + AlterTableRecoverPartitionsCommand( + ident.asTableIdentifier, + addPartitions, + dropPartitions, + "MSCK REPAIR TABLE") case LoadData(ResolvedV1TableIdentifier(ident), path, isLocal, isOverwrite, partition) => LoadDataCommand( @@ -420,6 +424,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case RecoverPartitions(ResolvedV1TableIdentifier(ident)) => AlterTableRecoverPartitionsCommand( ident.asTableIdentifier, + enableAddPartitions = true, + enableDropPartitions = false, "ALTER TABLE RECOVER PARTITIONS") case AddPartitions(ResolvedV1TableIdentifier(ident), partSpecsAndLocs, ifNotExists) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index be7fa7b1b447e..b3e48e37c66e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -189,7 +189,10 @@ case class CreateDataSourceTableAsSelectCommand( case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && sparkSession.sqlContext.conf.manageFilesourcePartitions => // Need to recover partitions into the metastore so our saved data is visible. - sessionState.executePlan(AlterTableRecoverPartitionsCommand(table.identifier)).toRdd + sessionState.executePlan(AlterTableRecoverPartitionsCommand( + table.identifier, + enableAddPartitions = true, + enableDropPartitions = false)).toRdd case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 7b4feb4af35fa..f0219efbf9a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -597,11 +597,13 @@ case class PartitionStatistics(numFiles: Int, totalSize: Long) * The syntax of this command is: * {{{ * ALTER TABLE table RECOVER PARTITIONS; - * MSCK REPAIR TABLE table; + * MSCK REPAIR TABLE table [{ADD|DROP|SYNC} PARTITIONS]; * }}} */ case class AlterTableRecoverPartitionsCommand( tableName: TableIdentifier, + enableAddPartitions: Boolean, + enableDropPartitions: Boolean, cmd: String = "ALTER TABLE RECOVER PARTITIONS") extends RunnableCommand { // These are list of statistics that can be collected quickly without requiring a scan of the data @@ -645,34 +647,40 @@ case class AlterTableRecoverPartitionsCommand( val hadoopConf = spark.sessionState.newHadoopConf() val fs = root.getFileSystem(hadoopConf) - val threshold = spark.sparkContext.conf.get(RDD_PARALLEL_LISTING_THRESHOLD) - val pathFilter = getPathFilter(hadoopConf) + val droppedAmount = if (enableDropPartitions) { + dropPartitions(catalog, fs) + } else 0 + val addedAmount = if (enableAddPartitions) { + val threshold = spark.sparkContext.conf.get(RDD_PARALLEL_LISTING_THRESHOLD) + val pathFilter = getPathFilter(hadoopConf) + + val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) + val partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)] = + try { + scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, + spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + } finally { + evalPool.shutdown() + } + val total = partitionSpecsAndLocs.length + logInfo(s"Found $total partitions in $root") - val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) - val partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)] = - try { - scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, - spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq - } finally { - evalPool.shutdown() + val partitionStats = if (spark.sqlContext.conf.gatherFastStats) { + gatherPartitionStats(spark, partitionSpecsAndLocs, fs, pathFilter, threshold) + } else { + GenMap.empty[String, PartitionStatistics] } - val total = partitionSpecsAndLocs.length - logInfo(s"Found $total partitions in $root") - - val partitionStats = if (spark.sqlContext.conf.gatherFastStats) { - 
gatherPartitionStats(spark, partitionSpecsAndLocs, fs, pathFilter, threshold) - } else { - GenMap.empty[String, PartitionStatistics] - } - logInfo(s"Finished to gather the fast stats for all $total partitions.") + logInfo(s"Finished to gather the fast stats for all $total partitions.") - addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + total + } else 0 // Updates the table to indicate that its partition metadata is stored in the Hive metastore. // This is always the case for Hive format tables, but is not true for Datasource tables created // before Spark 2.1 unless they are converted via `msck repair table`. spark.sessionState.catalog.alterTable(table.copy(tracksPartitionsInCatalog = true)) spark.catalog.refreshTable(tableIdentWithDB) - logInfo(s"Recovered all partitions ($total).") + logInfo(s"Recovered all partitions: added ($addedAmount), dropped ($droppedAmount).") Seq.empty[Row] } @@ -791,8 +799,28 @@ case class AlterTableRecoverPartitionsCommand( logDebug(s"Recovered ${parts.length} partitions ($done/$total so far)") } } -} + // Drops the partitions that do not exist in the file system + private def dropPartitions(catalog: SessionCatalog, fs: FileSystem): Int = { + val dropPartSpecs = ThreadUtils.parmap( + catalog.listPartitions(tableName), + "AlterTableRecoverPartitionsCommand: non-existing partitions", + maxThreads = 8) { partition => + partition.storage.locationUri.flatMap { uri => + if (fs.exists(new Path(uri))) None else Some(partition.spec) + } + }.flatten + catalog.dropPartitions( + tableName, + dropPartSpecs, + ignoreIfNotExists = true, + purge = false, + // Since we have already checked that partition directories do not exist, we can avoid + // additional calls to the file system at the catalog side by setting this flag. + retainData = true) + dropPartSpecs.length + } +} /** * A command that sets the location of a table or a partition. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 3eed7160b6dfe..16a6b2ef2f2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -409,7 +409,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat table, pattern.map(_.asInstanceOf[ResolvedPartitionSpec])) :: Nil - case RepairTable(_: ResolvedTable) => + case RepairTable(_: ResolvedTable, _, _) => throw new AnalysisException("MSCK REPAIR TABLE is not supported for v2 tables.") case r: CacheTable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala index 547fef6acac1c..f9e26f8277d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.command +import java.io.File + +import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.scalactic.source.Position import org.scalatest.Tag @@ -144,4 +147,26 @@ trait DDLCommandTestUtils extends SQLTestUtils { val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) f(fs, root) } + + def getPartitionLocation(tableName: String, part: String): String = { + val idents = tableName.split('.') + val table = idents.last + val catalogAndNs = idents.init + val in = if (catalogAndNs.isEmpty) "" else s"IN ${catalogAndNs.mkString(".")}" + val information = sql(s"SHOW TABLE EXTENDED $in LIKE '$table' PARTITION ($part)") + .select("information") + .first().getString(0) + information + .split("\\r?\\n") + .filter(_.startsWith("Location:")) + .head + .replace("Location: file:", "") + } + + def copyPartition(tableName: String, from: String, to: String): String = { + val part0Loc = getPartitionLocation(tableName, from) + val part1Loc = part0Loc.replace(from, to) + FileUtils.copyDirectory(new File(part0Loc), new File(part1Loc)) + part1Loc + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableParserSuite.scala new file mode 100644 index 0000000000000..458b3a4fc3c8d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableParserSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.RepairTable + +class MsckRepairTableParserSuite extends AnalysisTest { + test("repair a table") { + comparePlans( + parsePlan("MSCK REPAIR TABLE a.b.c"), + RepairTable( + UnresolvedTable(Seq("a", "b", "c"), "MSCK REPAIR TABLE", None), + enableAddPartitions = true, + enableDropPartitions = false)) + } + + test("add partitions") { + comparePlans( + parsePlan("msck repair table ns.tbl add partitions"), + RepairTable( + UnresolvedTable( + Seq("ns", "tbl"), + "MSCK REPAIR TABLE ... ADD PARTITIONS", + None), + enableAddPartitions = true, + enableDropPartitions = false)) + } + + test("drop partitions") { + comparePlans( + parsePlan("MSCK repair table TBL Drop Partitions"), + RepairTable( + UnresolvedTable( + Seq("TBL"), + "MSCK REPAIR TABLE ... DROP PARTITIONS", + None), + enableAddPartitions = false, + enableDropPartitions = true)) + } + + test("sync partitions") { + comparePlans( + parsePlan("MSCK REPAIR TABLE spark_catalog.ns.tbl SYNC PARTITIONS"), + RepairTable( + UnresolvedTable( + Seq("spark_catalog", "ns", "tbl"), + "MSCK REPAIR TABLE ... SYNC PARTITIONS", + None), + enableAddPartitions = true, + enableDropPartitions = true)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableSuiteBase.scala new file mode 100644 index 0000000000000..b8b0d003a314c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/MsckRepairTableSuiteBase.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.QueryTest + +/** + * This base suite contains unified tests for the `MSCK REPAIR TABLE` command that + * check V1 and V2 table catalogs. 
The tests that cannot run for all supported catalogs are + * located in more specific test suites: + * + * - V2 table catalog tests: + * `org.apache.spark.sql.execution.command.v2.MsckRepairTableSuite` + * - V1 table catalog tests: + * `org.apache.spark.sql.execution.command.v1.MsckRepairTableSuiteBase` + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.MsckRepairTableSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.MsckRepairTableSuite` + */ +trait MsckRepairTableSuiteBase extends QueryTest with DDLCommandTestUtils { + override val command = "MSCK REPAIR TABLE" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala index 4013f623e074c..b2e626be1b180 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.execution.command.v1 -import java.io.File - -import org.apache.commons.io.FileUtils - import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf @@ -47,24 +43,6 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit } } - private def copyPartition(tableName: String, from: String, to: String): String = { - val idents = tableName.split('.') - val table = idents.last - val catalogAndNs = idents.init - val in = if (catalogAndNs.isEmpty) "" else s"IN ${catalogAndNs.mkString(".")}" - val information = sql(s"SHOW TABLE EXTENDED $in LIKE '$table' PARTITION ($from)") - .select("information") - .first().getString(0) - val part0Loc = information - .split("\\r?\\n") - .filter(_.startsWith("Location:")) - .head - .replace("Location: file:", "") - val part1Loc = part0Loc.replace(from, to) - FileUtils.copyDirectory(new File(part0Loc), new File(part1Loc)) - part1Loc - } - test("SPARK-34055: refresh cache in partition adding") { withTable("t") { sql(s"CREATE TABLE t (id int, part int) $defaultUsing PARTITIONED BY (part)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/MsckRepairTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/MsckRepairTableSuite.scala new file mode 100644 index 0000000000000..45dc9e0e00f63 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/MsckRepairTableSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command.v1 + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.command + +/** + * This base suite contains unified tests for the `MSCK REPAIR TABLE` command that + * check V1 table catalogs. The tests that cannot run for all V1 catalogs are located in more + * specific test suites: + * + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.MsckRepairTableSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.MsckRepairTableSuite` + */ +trait MsckRepairTableSuiteBase extends command.MsckRepairTableSuiteBase { + def deletePartitionDir(tableName: String, part: String): Unit = { + val partLoc = getPartitionLocation(tableName, part) + FileUtils.deleteDirectory(new File(partLoc)) + } + + test("drop partitions") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (col INT, part INT) $defaultUsing PARTITIONED BY (part)") + sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") + sql(s"INSERT INTO $t PARTITION (part=1) SELECT 1") + + checkAnswer(spark.table(t), Seq(Row(0, 0), Row(1, 1))) + deletePartitionDir(t, "part=1") + sql(s"MSCK REPAIR TABLE $t DROP PARTITIONS") + checkPartitions(t, Map("part" -> "0")) + checkAnswer(spark.table(t), Seq(Row(0, 0))) + } + } + + test("sync partitions") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (col INT, part INT) $defaultUsing PARTITIONED BY (part)") + sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") + sql(s"INSERT INTO $t PARTITION (part=1) SELECT 1") + + checkAnswer(sql(s"SELECT col, part FROM $t"), Seq(Row(0, 0), Row(1, 1))) + copyPartition(t, "part=0", "part=2") + deletePartitionDir(t, "part=0") + sql(s"MSCK REPAIR TABLE $t SYNC PARTITIONS") + checkPartitions(t, Map("part" -> "1"), Map("part" -> "2")) + checkAnswer(sql(s"SELECT col, part FROM $t"), Seq(Row(1, 1), Row(0, 2))) + } + } +} + +/** + * The class contains tests for the `MSCK REPAIR TABLE` command to check + * V1 In-Memory table catalog. + */ +class MsckRepairTableSuite extends MsckRepairTableSuiteBase with CommandSuiteBase diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala new file mode 100644 index 0000000000000..d4b23e50786eb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.command + +/** + * The class contains tests for the `MSCK REPAIR TABLE` command + * to check V2 table catalogs. + */ +class MsckRepairTableSuite + extends command.MsckRepairTableSuiteBase + with CommandSuiteBase { + + // TODO(SPARK-34397): Support v2 `MSCK REPAIR TABLE` + test("repairing of v2 tables is not supported") { + withNamespaceAndTable("ns", "tbl") { t => + spark.sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + val errMsg = intercept[AnalysisException] { + sql(s"MSCK REPAIR TABLE $t") + }.getMessage + assert(errMsg.contains("MSCK REPAIR TABLE is not supported for v2 tables")) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/MsckRepairTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/MsckRepairTableSuite.scala new file mode 100644 index 0000000000000..fc40aa2b82f93 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/MsckRepairTableSuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution.command + +import org.apache.spark.sql.execution.command.v1 + +/** + * The class contains tests for the `MSCK REPAIR TABLE` command to check + * V1 Hive external table catalog. + */ +class MsckRepairTableSuite extends v1.MsckRepairTableSuiteBase with CommandSuiteBase From 2e31e2c5f30742c312767f26b17396c4ecfbef72 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 23 Feb 2021 16:37:29 -0800 Subject: [PATCH 10/60] [SPARK-34503][CORE] Use zstd for spark.eventLog.compression.codec by default ### What changes were proposed in this pull request? Apache Spark 3.0 introduced `spark.eventLog.compression.codec` configuration. For Apache Spark 3.2, this PR aims to set `zstd` as the default value for `spark.eventLog.compression.codec` configuration. This only affects creating a new log file. ### Why are the changes needed? The main purpose of event logs is archiving. Many logs are generated and occupy the storage, but most of them are never accessed by users. **1. Save storage resources (and money)** In general, ZSTD is much smaller than LZ4. For example, in case of TPCDS (Scale 200) log, ZSTD generates about 3 times smaller log files than LZ4. | CODEC | SIZE (bytes) | |---------|-------------| | LZ4 | 184001434| | ZSTD | 64522396| And, the plain file is 17.6 times bigger. ``` -rw-r--r-- 1 dongjoon staff 1135464691 Feb 21 22:31 spark-a1843ead29834f46b1125a03eca32679 -rw-r--r-- 1 dongjoon staff 64522396 Feb 21 22:31 spark-a1843ead29834f46b1125a03eca32679.zstd ``` **2. 
Better Usability** We cannot decompress Spark-generated LZ4 event log files via CLI while we can for ZSTD event log files. Spark's LZ4 event log files are inconvenient to some users who want to uncompress and access them. ``` $ lz4 -d spark-d3deba027bd34435ba849e14fc2c42ef.lz4 Decoding file spark-d3deba027bd34435ba849e14fc2c42ef Error 44 : Unrecognized header : file cannot be decoded ``` ``` $ zstd -d spark-a1843ead29834f46b1125a03eca32679.zstd spark-a1843ead29834f46b1125a03eca32679.zstd: 1135464691 bytes ``` **3. Speed** The following results are collected by running [lzbench](https://github.com/inikep/lzbench) on the above Spark event log. Note that - This is not a direct comparison of Spark compression/decompression codec. - `lzbench` is an in-memory benchmark. So, it doesn't show the benefit of the reduced network traffic due to the small size of ZSTD. Here, - To get ZSTD 1.4.8-1 result, `lzbench` `master` branch is used because Spark is using ZSTD 1.4.8. - To get LZ4 1.7.5 result, `lzbench` `v1.7` branch is used because Spark is using LZ4 1.7.1. ``` Compressor name Compress. Decompress. Compr. size Ratio Filename memcpy 7393 MB/s 7166 MB/s 1135464691 100.00 spark-a1843ead29834f46b1125a03eca32679 zstd 1.4.8 -1 1344 MB/s 3351 MB/s 56665767 4.99 spark-a1843ead29834f46b1125a03eca32679 lz4 1.7.5 1385 MB/s 4782 MB/s 127662168 11.24 spark-a1843ead29834f46b1125a03eca32679 ``` ### Does this PR introduce _any_ user-facing change? - No for the apps which doesn't use `spark.eventLog.compress` because `spark.eventLog.compress` is disabled by default. - No for the apps using `spark.eventLog.compression.codec` explicitly because this is a change of the default value. - Yes for the apps using `spark.eventLog.compress` without setting `spark.eventLog.compression.codec`. In this case, previously `spark.io.compression.codec` value was used whose default is `lz4`. So this JIRA issue, SPARK-34503, is labeled with `releasenotes`. ### How was this patch tested? Pass the updated UT. Closes #31618 from dongjoon-hyun/SPARK-34503. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/internal/config/package.scala | 5 +++-- .../deploy/history/EventLogFileWritersSuite.scala | 10 ++-------- docs/configuration.md | 5 ++--- docs/core-migration-guide.md | 2 ++ 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 3101bb663263e..4a2281a4e8785 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1726,9 +1726,10 @@ package object config { ConfigBuilder("spark.eventLog.compression.codec") .doc("The codec used to compress event log. By default, Spark provides four codecs: " + "lz4, lzf, snappy, and zstd. You can also use fully qualified class names to specify " + - "the codec. 
If this is not given, spark.io.compression.codec will be used.") + "the codec.") .version("3.0.0") - .fallbackConf(IO_COMPRESSION_CODEC) + .stringConf + .createWithDefault("zstd") private[spark] val BUFFER_SIZE = ConfigBuilder("spark.buffer.size") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala index e9b739ce7a4c6..e6dd9ae4224d9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala @@ -99,7 +99,7 @@ abstract class EventLogFileWritersSuite extends SparkFunSuite with LocalSparkCon } } - test("spark.eventLog.compression.codec overrides spark.io.compression.codec") { + test("Use the defalut value of spark.eventLog.compression.codec") { val conf = new SparkConf conf.set(EVENT_LOG_COMPRESS, true) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) @@ -107,14 +107,8 @@ abstract class EventLogFileWritersSuite extends SparkFunSuite with LocalSparkCon val appId = "test" val appAttemptId = None - // The default value is `spark.io.compression.codec`. val writer = createWriter(appId, appAttemptId, testDirPath.toUri, conf, hadoopConf) - assert(writer.compressionCodecName.contains("lz4")) - - // `spark.eventLog.compression.codec` overrides `spark.io.compression.codec`. - conf.set(EVENT_LOG_COMPRESSION_CODEC, "zstd") - val writer2 = createWriter(appId, appAttemptId, testDirPath.toUri, conf, hadoopConf) - assert(writer2.compressionCodecName.contains("zstd")) + assert(writer.compressionCodecName === EVENT_LOG_COMPRESSION_CODEC.defaultValue) } protected def readLinesFromEventLogFile(log: Path, fs: FileSystem): List[String] = { diff --git a/docs/configuration.md b/docs/configuration.md index 612d62a96f305..b7b00dd42db1b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1040,10 +1040,9 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.compression.codec - + zstd - The codec to compress logged events. If this is not given, - spark.io.compression.codec will be used. + The codec to compress logged events. 3.0.0 diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md index ec7c3ab9cb568..232b9e31adb88 100644 --- a/docs/core-migration-guide.md +++ b/docs/core-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Core 3.1 to 3.2 +- Since Spark 3.2, `spark.eventLog.compression.codec` is set to `zstd` by default which means Spark will not fallback to use `spark.io.compression.codec` anymore. + - Since Spark 3.2, `spark.storage.replication.proactive` is enabled by default which means Spark tries to replenish in case of the loss of cached RDD block replicas due to executor failures. To restore the behavior before Spark 3.2, you can set `spark.storage.replication.proactive` to `false`. ## Upgrading from Core 3.0 to 3.1 From b5afff59fa389b30312914ff141e97d5bc511359 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 24 Feb 2021 09:50:13 +0800 Subject: [PATCH 11/60] [SPARK-26138][SQL] Pushdown limit through InnerLike when condition is empty ### What changes were proposed in this pull request? This pr pushdown limit through InnerLike when condition is empty(Origin pr: #23104). 
For example: ```sql CREATE TABLE t1 using parquet AS SELECT id AS a, id AS b FROM range(2); CREATE TABLE t2 using parquet AS SELECT id AS d FROM range(2); SELECT * FROM t1 CROSS JOIN t2 LIMIT 10; ``` Before this pr: ``` == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- CollectLimit 10 +- BroadcastNestedLoopJoin BuildRight, Cross :- FileScan parquet default.t1[a#5L,b#6L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/tg/f5mz46090wg7swzgdc69f8q03965_0/T/warehous..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange IdentityBroadcastMode, [id=#43] +- FileScan parquet default.t2[d#7L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/tg/f5mz46090wg7swzgdc69f8q03965_0/T/warehous..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` After this pr: ``` == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- CollectLimit 10 +- BroadcastNestedLoopJoin BuildRight, Cross :- LocalLimit 10 : +- FileScan parquet default.t1[a#5L,b#6L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/tg/f5mz46090wg7swzgdc69f8q03965_0/T/warehous..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange IdentityBroadcastMode, [id=#51] +- LocalLimit 10 +- FileScan parquet default.t2[d#7L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/tg/f5mz46090wg7swzgdc69f8q03965_0/T/warehous..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` ### Why are the changes needed? Improve query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #31567 from wangyum/SPARK-26138. Authored-by: Yuming Wang Signed-off-by: Yuming Wang --- .../sql/catalyst/optimizer/Optimizer.scala | 13 ++++++++---- .../optimizer/LimitPushdownSuite.scala | 20 ++++++++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 17 ++++++++++++++-- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e2b9cc65ebded..46a90f600b2a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -539,17 +539,22 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit. case LocalLimit(exp, u: Union) => LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) - // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to - // the left and right sides, respectively. It's not safe to push limits below FULL OUTER - // JOIN in the general case without a more invasive rewrite. + // Add extra limits below JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to + // the left and right sides, respectively. For INNER and CROSS JOIN we push limits to + // both the left and right sides if join condition is empty. It's not safe to push limits + // below FULL OUTER JOIN in the general case without a more invasive rewrite. // We also need to ensure that this limit pushdown rule will not eventually introduce limits // on both sides if it is applied multiple times. 
Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) => + case LocalLimit(exp, join @ Join(left, right, joinType, conditionOpt, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) + case _: InnerLike if conditionOpt.isEmpty => + join.copy( + left = maybePushLocalLimit(exp, left), + right = maybePushLocalLimit(exp, right)) case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index bb23b63c03cea..5c760264ff219 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Add -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -194,4 +194,22 @@ class LimitPushdownSuite extends PlanTest { LocalLimit(1, y.groupBy(Symbol("b"))(count(1))))).analyze comparePlans(expected2, optimized2) } + + test("SPARK-26138: pushdown limit through InnerLike when condition is empty") { + Seq(Cross, Inner).foreach { joinType => + val originalQuery = x.join(y, joinType).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(LocalLimit(1, y), joinType)).analyze + comparePlans(optimized, correctAnswer) + } + } + + test("SPARK-26138: Should not pushdown limit through InnerLike when condition is not empty") { + Seq(Cross, Inner).foreach { joinType => + val originalQuery = x.join(y, joinType, Some("x.a".attr === "y.b".attr)).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(y, joinType, Some("x.a".attr === "y.b".attr))).analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e6a18c3894497..fe8a080ac5aeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -29,14 +29,14 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} -import org.apache.spark.sql.catalyst.plans.logical.{Project, RepartitionByExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression} import org.apache.spark.sql.catalyst.util.StringUtils import 
org.apache.spark.sql.execution.UnionExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.FunctionsCommand -import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException +import org.apache.spark.sql.execution.datasources.{LogicalRelation, SchemaColumnConvertNotSupportedException} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan @@ -4021,6 +4021,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-26138 Pushdown limit through InnerLike when condition is empty") { + withTable("t1", "t2") { + spark.range(5).repartition(1).write.saveAsTable("t1") + spark.range(5).repartition(1).write.saveAsTable("t2") + val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 LIMIT 3") + val pushedLocalLimits = df.queryExecution.optimizedPlan.collect { + case l @ LocalLimit(_, _: LogicalRelation) => l + } + assert(pushedLocalLimits.length === 2) + checkAnswer(df, Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil) + } + } } case class Foo(bar: Option[String]) From a6dcd5544dd7c21da1b93b43e5d3d7b67d097dc7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 24 Feb 2021 11:34:29 +0900 Subject: [PATCH 12/60] [MINOR][DOCS][K8S] Use hadoop-aws 3.2.2 in K8s example ### What changes were proposed in this pull request? This PR aims to update `Hadoop` dependency in K8S doc example. ### Why are the changes needed? Apache Spark 3.2.0 is using Apache Hadoop 3.2.2 by default. ### Does this PR introduce _any_ user-facing change? No. This is a doc-only change. ### How was this patch tested? N/A Closes #31628 from dongjoon-hyun/minor-doc. Authored-by: Dongjoon Hyun Signed-off-by: HyukjinKwon --- docs/running-on-kubernetes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index cf5c01629a240..6d2e61f8fd60c 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -212,7 +212,7 @@ A typical example of this using S3 is via passing the following options: ``` ... ---packages com.amazonaws:aws-java-sdk:1.7.4,org.apache.hadoop:hadoop-aws:2.7.6 +--packages org.apache.hadoop:hadoop-aws:3.2.2 --conf spark.kubernetes.file.upload.path=s3a:///path --conf spark.hadoop.fs.s3a.access.key=... --conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem From 80bad086c806fd507b1fb197b171f87333f2fb08 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 24 Feb 2021 11:36:54 +0900 Subject: [PATCH 13/60] Revert "[SPARK-32703][SQL] Replace deprecated API calls from SpecificParquetRecordReaderBase" This reverts commit 27873280ffbd73be6df230b4497701794ac81d91. 
--- .../SpecificParquetRecordReaderBase.java | 93 +++++++++++++------ .../parquet/ParquetFileFormat.scala | 10 +- .../ParquetPartitionReaderFactory.scala | 18 ++-- 3 files changed, 85 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 8d7a294f12311..0c82c0333aba0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -31,33 +32,36 @@ import scala.Option; +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; -import org.apache.parquet.HadoopReadOptions; -import org.apache.parquet.ParquetReadOptions; import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; import org.apache.parquet.hadoop.BadConfigurationException; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; import org.apache.parquet.hadoop.api.InitContext; import org.apache.parquet.hadoop.api.ReadSupport; import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.hadoop.util.ConfigurationUtil; -import org.apache.parquet.hadoop.util.HadoopInputFile; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Types; import org.apache.spark.TaskContext; import org.apache.spark.TaskContext$; -import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; @@ -88,16 +92,58 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader blocks; - ParquetReadOptions options = HadoopReadOptions - .builder(configuration) - .withRange(split.getStart(), split.getStart() + split.getLength()) - .build(); - this.reader = new ParquetFileReader(HadoopInputFile.fromPath(file, configuration), options); - this.fileSchema = reader.getFileMetaData().getSchema(); - Map fileMetadata = reader.getFileMetaData().getKeyValueMetaData(); + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) 
{ + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // SPARK-33532: After SPARK-13883 and SPARK-13989, the parquet read process will + // no longer enter this branch because `ParquetInputSplit` only be constructed in + // `ParquetFileFormat.buildReaderWithPartitionValues` and + // `ParquetPartitionReaderFactory.buildReaderBase` method, + // and the `rowGroupOffsets` in `ParquetInputSplit` set to null explicitly. + // We didn't delete this branch because PARQUET-131 wanted to move this to the + // parquet-mr project. + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + this.fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); ReadSupport readSupport = getReadSupportInstance(getReadSupportClass(configuration)); ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); @@ -105,6 +151,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont String sparkRequestedSchemaString = configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA()); this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); + this.reader = new ParquetFileReader( + configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); this.totalRowCount = reader.getFilteredRecordCount(); // For test purpose. 
@@ -117,7 +165,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { @SuppressWarnings("unchecked") AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); - intAccum.add(reader.getRowGroups().size()); + intAccum.add(blocks.size()); } } } @@ -151,21 +199,12 @@ public static List listDirectory(File path) { */ protected void initialize(String path, List columns) throws IOException { Configuration config = new Configuration(); - config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false); - config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false); + config.set("spark.sql.parquet.binaryAsString", "false"); + config.set("spark.sql.parquet.int96AsTimestamp", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); - ParquetReadOptions options = HadoopReadOptions - .builder(config) - .withRange(0, length) - .build(); - - ParquetMetadata footer; - try (ParquetFileReader reader = ParquetFileReader - .open(HadoopInputFile.fromPath(file, config), options)) { - footer = reader.getFooter(); - } + ParquetMetadata footer = readFooter(config, file, range(0, length)); List blocks = footer.getBlocks(); this.fileSchema = footer.getFileMetaData().getSchema(); @@ -188,8 +227,6 @@ protected void initialize(String path, List columns) throws IOException } } this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); - // unfortunately we'd have to create the reader again since there is no column projection - // in the new API. this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); this.totalRowCount = reader.getFilteredRecordCount(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index fa6e124c4e115..64a1ac8675104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -26,7 +26,6 @@ import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat @@ -260,7 +259,14 @@ class ParquetFileFormat assert(file.partitionValues.numFields == partitionSchema.size) val filePath = new Path(new URI(file.filePath)) - val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) val sharedConf = broadcastedHadoopConf.value.value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index af0100c73234e..20d0de45ba352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -20,13 +20,12 @@ import java.net.URI import java.time.ZoneId import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat, ParquetInputSplit, ParquetRecordReader} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast @@ -122,14 +121,21 @@ case class ParquetPartitionReaderFactory( private def buildReaderBase[T]( file: PartitionedFile, buildReaderFunc: ( - FileSplit, InternalRow, TaskAttemptContextImpl, + ParquetInputSplit, InternalRow, TaskAttemptContextImpl, Option[FilterPredicate], Option[ZoneId], LegacyBehaviorPolicy.Value, LegacyBehaviorPolicy.Value) => RecordReader[Void, T]): RecordReader[Void, T] = { val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) - val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) lazy val footerFileMetaData = ParquetFileReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData @@ -193,7 +199,7 @@ case class ParquetPartitionReaderFactory( } private def createRowBaseParquetReader( - split: FileSplit, + split: ParquetInputSplit, partitionValues: InternalRow, hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], @@ -228,7 +234,7 @@ case class ParquetPartitionReaderFactory( } private def createParquetVectorizedReader( - split: FileSplit, + split: ParquetInputSplit, partitionValues: InternalRow, hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], From f542ecdb0d968af9ef66b1ec7270767f4ec42c41 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 24 Feb 2021 05:05:04 +0000 Subject: [PATCH 14/60] [SPARK-34245][CORE] Ensure Master removes executors that failed to send finished state ### What changes were proposed in this pull request? Use `ask` instead of `send` to sync the `ExecutorStateChanged` between Worker and Master and retry(up to 5 times) on the failure until the message is successfully handled by the Master. And the Worker would exit itself if the message can not be sent after 5 times retry. ### Why are the changes needed? If the Worker fails to send ExecutorStateChanged to the Master due to some unexpected errors, e.g., temporary network error, then the Master can't remove the finished executor normally and think the executor is still alive. In the worst case, if the executor is the only executor for the application, the application can get hang. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass existing tests. Closes #31348 from Ngone51/periodically-trigger-master-schedule. 
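The change boils down to a bounded retry of the state-change RPC on the Worker side. A minimal sketch of that idea only (not the actual Worker code; `sendAndAwaitAck` below is a placeholder standing in for `masterRef.ask[Boolean](...)`, and the names are illustrative):

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object ExecutorStateSyncSketch {
  val maxAttempts = 5

  // Keep re-sending the state change until the master acknowledges it; give up
  // and exit the process after `maxAttempts` consecutive failures.
  def syncWithRetry(sendAndAwaitAck: () => Future[Boolean], attempt: Int = 1)(
      implicit ec: ExecutionContext): Unit = {
    sendAndAwaitAck().onComplete {
      case Success(_) =>
        // acknowledged; nothing more to do
      case Failure(_) if attempt < maxAttempts =>
        // temporary failure (e.g. a network hiccup): retry the same message
        syncWithRetry(sendAndAwaitAck, attempt + 1)
      case Failure(_) =>
        // likely a severe network issue; stop the worker
        sys.exit(1)
    }
  }
}
```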
Authored-by: yi.wu Signed-off-by: Wenchen Fan --- .../apache/spark/deploy/master/Master.scala | 97 ++++++++++--------- .../apache/spark/deploy/worker/Worker.scala | 55 ++++++++++- .../spark/deploy/master/MasterSuite.scala | 2 +- 3 files changed, 102 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9f1b36ad1c8c1..471a3c1b45c39 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -310,54 +310,6 @@ private[deploy] class Master( schedule() } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) - execOption match { - case Some(exec) => - val appInfo = idToApp(appId) - val oldState = exec.state - exec.state = state - - if (state == ExecutorState.RUNNING) { - assert(oldState == ExecutorState.LAUNCHING, - s"executor $execId state transfer from $oldState to RUNNING is illegal") - appInfo.resetRetryCount() - } - - exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus, None)) - - if (ExecutorState.isFinished(state)) { - // Remove this executor from the worker and app - logInfo(s"Removing executor ${exec.fullId} because it is $state") - // If an application has already finished, preserve its - // state to display its information properly on the UI - if (!appInfo.isFinished) { - appInfo.removeExecutor(exec) - } - exec.worker.removeExecutor(exec) - - val normalExit = exitStatus == Some(0) - // Only retry certain number of times so we don't go into an infinite loop. - // Important note: this code path is not exercised by tests, so be very careful when - // changing this `if` condition. - // We also don't count failures from decommissioned workers since they are "expected." 
- if (!normalExit - && oldState != ExecutorState.DECOMMISSIONED - && appInfo.incrementRetryCount() >= maxExecutorRetries - && maxExecutorRetries >= 0) { // < 0 disables this application-killing path - val execs = appInfo.executors.values - if (!execs.exists(_.state == ExecutorState.RUNNING)) { - logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + - s"${appInfo.retryCount} times; removing it") - removeApplication(appInfo, ApplicationState.FAILED) - } - } - } - schedule() - case None => - logWarning(s"Got status update for unknown executor $appId/$execId") - } - case DriverStateChanged(driverId, state, exception) => state match { case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => @@ -550,6 +502,55 @@ private[deploy] class Master( } else { context.reply(0) } + + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => + val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) + execOption match { + case Some(exec) => + val appInfo = idToApp(appId) + val oldState = exec.state + exec.state = state + + if (state == ExecutorState.RUNNING) { + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") + appInfo.resetRetryCount() + } + + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus, None)) + + if (ExecutorState.isFinished(state)) { + // Remove this executor from the worker and app + logInfo(s"Removing executor ${exec.fullId} because it is $state") + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } + exec.worker.removeExecutor(exec) + + val normalExit = exitStatus == Some(0) + // Only retry certain number of times so we don't go into an infinite loop. + // Important note: this code path is not exercised by tests, so be very careful when + // changing this `if` condition. + // We also don't count failures from decommissioned workers since they are "expected." 
+ if (!normalExit + && oldState != ExecutorState.DECOMMISSIONED + && appInfo.incrementRetryCount() >= maxExecutorRetries + && maxExecutorRetries >= 0) { // < 0 disables this application-killing path + val execs = appInfo.executors.values + if (!execs.exists(_.state == ExecutorState.RUNNING)) { + logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + + s"${appInfo.retryCount} times; removing it") + removeApplication(appInfo, ApplicationState.FAILED) + } + } + } + schedule() + case None => + logWarning(s"Got status update for unknown executor $appId/$execId") + } + context.reply(true) } override def onDisconnected(address: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index cb36207d2ffc4..adc953286625a 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -26,7 +26,7 @@ import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext -import scala.util.Random +import scala.util.{Failure, Random, Success} import scala.util.control.NonFatal import org.apache.spark.{SecurityManager, SparkConf} @@ -159,6 +159,18 @@ private[deploy] class Worker( val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + // Record the consecutive failure attempts of executor state change syncing with Master, + // so we don't try it endless. We will exit the Worker process at the end if the failure + // attempts reach the max attempts(5). In that case, it's highly possible the Worker + // suffers a severe network issue, and the Worker would exit finally either reaches max + // re-register attempts or max state syncing attempts. + // Map from executor fullId to its consecutive failure attempts number. It's supposed + // to be very small since it's only used for the temporary network drop, which doesn't + // happen frequently and recover soon. + private val executorStateSyncFailureAttempts = new HashMap[String, Int]() + lazy private val executorStateSyncFailureHandler = ExecutionContext.fromExecutor( + ThreadUtils.newDaemonSingleThreadExecutor("executor-state-sync-failure-handler")) + val retainedExecutors = conf.get(WORKER_UI_RETAINED_EXECUTORS) val retainedDrivers = conf.get(WORKER_UI_RETAINED_DRIVERS) @@ -620,7 +632,7 @@ private[deploy] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + syncExecutorStateWithMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(e.toString), None)) } } @@ -750,6 +762,43 @@ private[deploy] class Worker( } } + /** + * Send `ExecutorStateChanged` to the current master. Unlike `sendToMaster`, we use `askSync` + * to send the message in order to ensure Master can receive the message. + */ + private def syncExecutorStateWithMaster(newState: ExecutorStateChanged): Unit = { + master match { + case Some(masterRef) => + val fullId = s"${newState.appId}/${newState.execId}" + // SPARK-34245: We used async `send` to send the state previously. In that case, the + // finished executor can be leaked if Worker fails to send `ExecutorStateChanged` + // message to Master due to some unexpected errors, e.g., temporary network error. 
+ // In the worst case, the application can get hang if the leaked executor is the only + // or last executor for the application. Therefore, we switch to `ask` to ensure + // the state is handled by Master. + masterRef.ask[Boolean](newState).onComplete { + case Success(_) => + executorStateSyncFailureAttempts.remove(fullId) + + case Failure(t) => + val failures = executorStateSyncFailureAttempts.getOrElse(fullId, 0) + 1 + if (failures < 5) { + logError(s"Failed to send $newState to Master $masterRef, " + + s"will retry ($failures/5).", t) + executorStateSyncFailureAttempts(fullId) = failures + self.send(newState) + } else { + logError(s"Failed to send $newState to Master $masterRef for 5 times. Giving up.") + System.exit(1) + } + }(executorStateSyncFailureHandler) + + case None => + logWarning( + s"Dropping $newState because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } @@ -825,7 +874,7 @@ private[deploy] class Worker( private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): Unit = { - sendToMaster(executorStateChanged) + syncExecutorStateWithMaster(executorStateChanged) val state = executorStateChanged.state if (ExecutorState.isFinished(state)) { val appId = executorStateChanged.appId diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 3a4a125a9a470..562075bb63dcd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -135,7 +135,7 @@ class MockExecutorLaunchFailWorker(master: Master, conf: SparkConf = new SparkCo assert(master.idToApp.contains(appId)) appIdsToLaunchExecutor += appId failedCnt += 1 - master.self.send(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None)) + master.self.askSync(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None)) case otherMsg => super.receive(otherMsg) } From f64fc224665a3dd1c1581fc1966cf9924be156db Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 24 Feb 2021 05:21:11 +0000 Subject: [PATCH 15/60] [SPARK-34290][SQL] Support v2 `TRUNCATE TABLE` ### What changes were proposed in this pull request? Implement the v2 execution node for the `TRUNCATE TABLE` command. ### Why are the changes needed? To have feature parity with DS v1, and support truncation of v2 tables. ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? By running the unified tests for v1 and v2 tables: ``` $ build/sbt -Phive -Phive-thriftserver "test:testOnly *TruncateTableSuite" ``` Closes #31605 from MaxGekk/truncate-table-v2. 
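As a rough usage sketch of the user-facing effect (assuming a v2 catalog is already registered as `testcat` and `spark` is an existing `SparkSession`; the names are illustrative, and the statements mirror the style of the unified command tests):

```scala
// Before this change, TRUNCATE TABLE on a v2 table failed with
// "TRUNCATE TABLE is not supported for v2 tables."
spark.sql("CREATE TABLE testcat.ns.tbl (id INT, part INT) USING parquet PARTITIONED BY (part)")
spark.sql("INSERT INTO testcat.ns.tbl PARTITION (part = 0) SELECT 0")
spark.sql("INSERT INTO testcat.ns.tbl PARTITION (part = 1) SELECT 1")

// Truncate a single partition ...
spark.sql("TRUNCATE TABLE testcat.ns.tbl PARTITION (part = 0)")
// ... or the whole table.
spark.sql("TRUNCATE TABLE testcat.ns.tbl")
```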
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 25 ++- .../analysis/ResolvePartitionSpec.scala | 8 +- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../catalyst/plans/logical/v2Commands.scala | 2 +- .../analysis/ResolveSessionCatalog.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../datasources/v2/TruncateTableExec.scala | 52 +++++ .../command/TruncateTableParserSuite.scala | 6 +- .../command/TruncateTableSuiteBase.scala | 201 +++++++++++++++++- .../command/v1/TruncateTableSuite.scala | 190 +---------------- .../command/v2/TruncateTableSuite.scala | 25 +-- 11 files changed, 308 insertions(+), 214 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1d44e6ba298e1..59e37e8a9bfaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils} -import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table} +import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table, TruncatableTable} import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -576,6 +576,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case showPartitions: ShowPartitions => checkShowPartitions(showPartitions) + case truncateTable: TruncateTable => checkTruncateTable(truncateTable) + case _ => // Falls back to the following checks } @@ -1012,11 +1014,30 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { private def checkShowPartitions(showPartitions: ShowPartitions): Unit = showPartitions match { case ShowPartitions(rt: ResolvedTable, _, _) if !rt.table.isInstanceOf[SupportsPartitionManagement] => - failAnalysis(s"SHOW PARTITIONS cannot run for a table which does not support partitioning") + failAnalysis("SHOW PARTITIONS cannot run for a table which does not support partitioning") case ShowPartitions(ResolvedTable(_, _, partTable: SupportsPartitionManagement, _), _, _) if partTable.partitionSchema().isEmpty => failAnalysis( s"SHOW PARTITIONS is not allowed on a table that is not partitioned: ${partTable.name()}") case _ => } + + private def checkTruncateTable(truncateTable: TruncateTable): Unit = truncateTable match { + case TruncateTable(rt: ResolvedTable, None) if !rt.table.isInstanceOf[TruncatableTable] => + failAnalysis(s"The table ${rt.table.name()} does not support truncation") + case TruncateTable(rt: ResolvedTable, Some(_)) + if !rt.table.isInstanceOf[SupportsPartitionManagement] => + failAnalysis("TRUNCATE TABLE cannot run for a table which does not support partitioning") + case 
TruncateTable( + ResolvedTable(_, _, _: SupportsPartitionManagement, _), + Some(_: UnresolvedPartitionSpec)) => + failAnalysis("Partition spec is not resolved") + case TruncateTable( + ResolvedTable(_, _, table: SupportsPartitionManagement, _), + Some(spec: ResolvedPartitionSpec)) + if spec.names.length < table.partitionSchema.length && + !table.isInstanceOf[SupportsAtomicPartitionManagement] => + failAnalysis(s"The table ${table.name()} does not support truncation of multiple partitions") + case _ => + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 72298b285f2b6..e68c9793fa6a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddPartitions, DropPartitions, LogicalPlan, RenamePartitions, ShowPartitions} +import org.apache.spark.sql.catalyst.plans.logical.{AddPartitions, DropPartitions, LogicalPlan, RenamePartitions, ShowPartitions, TruncateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement @@ -67,6 +67,12 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { table.name, partSpecs.toSeq, table.partitionSchema()).headOption) + + case r @ TruncateTable(ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs) => + r.copy(partitionSpec = resolvePartitionSpecs( + table.name, + partSpecs.toSeq, + table.partitionSchema()).headOption) } private def resolvePartitionSpecs( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 23f9c8398b727..25e6cbeaa524c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3759,7 +3759,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { TruncateTable( createUnresolvedTable(ctx.multipartIdentifier, "TRUNCATE TABLE"), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) + Option(ctx.partitionSpec).map { spec => + UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(spec)) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 12f13e73eadf5..ea67c5571ec9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -769,7 +769,7 @@ object ShowColumns { */ case class TruncateTable( child: LogicalPlan, - partitionSpec: Option[TablePartitionSpec]) extends Command { + partitionSpec: Option[PartitionSpec]) extends 
Command { override def children: Seq[LogicalPlan] = child :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 55e8c5fba0d3c..7ddd2ab6d913c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -401,7 +401,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case TruncateTable(ResolvedV1TableIdentifier(ident), partitionSpec) => TruncateTableCommand( ident.asTableIdentifier, - partitionSpec) + partitionSpec.toSeq.asUnresolvedPartitionSpecs.map(_.spec).headOption) case s @ ShowPartitions( ResolvedV1TableOrViewIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 16a6b2ef2f2d2..a5b092a1aa491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -394,8 +394,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ShowCreateTable(_: ResolvedTable, _) => throw new AnalysisException("SHOW CREATE TABLE is not supported for v2 tables.") - case TruncateTable(_: ResolvedTable, _) => - throw new AnalysisException("TRUNCATE TABLE is not supported for v2 tables.") + case TruncateTable(r: ResolvedTable, parts) => + TruncateTableExec( + r.table, + parts.toSeq.asResolvedPartitionSpecs.headOption, + recacheTable(r)) :: Nil case ShowColumns(_: ResolvedTable, _, _) => throw new AnalysisException("SHOW COLUMNS is not supported for v2 tables.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala new file mode 100644 index 0000000000000..17f86e26074a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolvedPartitionSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table, TruncatableTable} + +/** + * Physical plan node for table truncation. + */ +case class TruncateTableExec( + table: Table, + partSpecs: Option[ResolvedPartitionSpec], + refreshCache: () => Unit) extends V2CommandExec { + + override def output: Seq[Attribute] = Seq.empty + + override protected def run(): Seq[InternalRow] = { + val isTableAltered = (table, partSpecs) match { + case (truncatableTable: TruncatableTable, None) => + truncatableTable.truncateTable() + case (partTable: SupportsPartitionManagement, Some(resolvedPartSpec)) + if partTable.partitionSchema.length == resolvedPartSpec.names.length => + partTable.truncatePartition(resolvedPartSpec.ident) + case (atomicPartTable: SupportsAtomicPartitionManagement, Some(resolvedPartitionSpec)) => + val partitionIdentifiers = atomicPartTable.listPartitionIdentifiers( + resolvedPartitionSpec.names.toArray, resolvedPartitionSpec.ident) + atomicPartTable.truncatePartitions(partitionIdentifiers) + case _ => throw new IllegalArgumentException( + s"Truncation of ${table.getClass.getName} is not supported") + } + if (isTableAltered) refreshCache() + Seq.empty + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala index 8499422501155..39531c84a63d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedPartitionSpec, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.TruncateTable @@ -35,7 +35,7 @@ class TruncateTableParserSuite extends AnalysisTest with SharedSparkSession { parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"), TruncateTable( UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE", None), - Some(Map("ds" -> "2017-06-10")))) + Some(UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"), None)))) } test("truncate a multi parts partition") { @@ -43,7 +43,7 @@ class TruncateTableParserSuite extends AnalysisTest with SharedSparkSession { parsePlan("TRUNCATE TABLE ns.tbl PARTITION(a = 1, B = 'ABC')"), TruncateTable( UnresolvedTable(Seq("ns", "tbl"), "TRUNCATE TABLE", None), - Some(Map("a" -> "1", "B" -> "ABC")))) + Some(UnresolvedPartitionSpec(Map("a" -> "1", "B" -> "ABC"), None)))) } test("empty values in non-optional partition specs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala index d31ff04c07757..001ec8e250def 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +import org.apache.spark.sql.internal.SQLConf /** * This base suite contains unified tests for the `TRUNCATE TABLE` command that check V1 and V2 @@ -31,4 +33,201 @@ import org.apache.spark.sql.QueryTest */ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils { override val command = "TRUNCATE TABLE" + + test("table does not exist") { + withNamespaceAndTable("ns", "does_not_exist") { t => + val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t") + }.getMessage + assert(errMsg.contains("Table not found")) + } + } + + test("truncate non-partitioned table") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c0 INT, c1 INT) $defaultUsing") + sql(s"INSERT INTO $t SELECT 0, 1") + + sql(s"TRUNCATE TABLE $t") + QueryTest.checkAnswer(sql(s"SELECT * FROM $t"), Nil) + } + } + + protected def createPartTable(t: String): Unit = { + sql(s""" + |CREATE TABLE $t (width INT, length INT, height INT) + |$defaultUsing + |PARTITIONED BY (width, length)""".stripMargin) + sql(s"INSERT INTO $t PARTITION (width = 0, length = 0) SELECT 0") + sql(s"INSERT INTO $t PARTITION (width = 1, length = 1) SELECT 1") + sql(s"INSERT INTO $t PARTITION (width = 1, length = 2) SELECT 3") + } + + test("SPARK-34418: truncate partitioned tables") { + withNamespaceAndTable("ns", "partTable") { t => + createPartTable(t) + sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 1)") + checkAnswer(sql(s"SELECT width, length, height FROM $t"), Seq(Row(0, 0, 0), Row(1, 2, 3))) + checkPartitions(t, + Map("width" -> "0", "length" -> "0"), + Map("width" -> "1", "length" -> "1"), + Map("width" -> "1", "length" -> "2")) + } + + withNamespaceAndTable("ns", "partTable") { t => + createPartTable(t) + // support partial partition spec + sql(s"TRUNCATE TABLE $t PARTITION (width = 1)") + QueryTest.checkAnswer(sql(s"SELECT * FROM $t"), Row(0, 0, 0) :: Nil) + checkPartitions(t, + Map("width" -> "0", "length" -> "0"), + Map("width" -> "1", "length" -> "1"), + Map("width" -> "1", "length" -> "2")) + } + + withNamespaceAndTable("ns", "partTable") { t => + createPartTable(t) + // do nothing if no partition is matched for the given partial partition spec + sql(s"TRUNCATE TABLE $t PARTITION (width = 100)") + QueryTest.checkAnswer( + sql(s"SELECT width, length, height FROM $t"), + Seq(Row(0, 0, 0), Row(1, 1, 1), Row(1, 2, 3))) + + // throw exception if no partition is matched for the given non-partial partition spec. + intercept[NoSuchPartitionException] { + sql(s"TRUNCATE TABLE $t PARTITION (width = 100, length = 100)") + } + + // throw exception if the column in partition spec is not a partition column. 
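+      // (createPartTable partitions the table by `width` and `length`, so `unknown` below
+      // cannot match any partition column and the command is expected to fail analysis)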
+ val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t PARTITION (unknown = 1)") + }.getMessage + assert(errMsg.contains("unknown is not a valid partition column")) + } + } + + protected def invalidPartColumnError: String + + test("truncate a partition of non partitioned table") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c0 INT) $defaultUsing") + sql(s"INSERT INTO $t SELECT 0") + + val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t PARTITION (c0=1)") + }.getMessage + assert(errMsg.contains(invalidPartColumnError)) + } + } + + test("SPARK-34418: preserve partitions in truncated table") { + withNamespaceAndTable("ns", "partTable") { t => + createPartTable(t) + checkAnswer( + sql(s"SELECT width, length, height FROM $t"), + Seq(Row(0, 0, 0), Row(1, 1, 1), Row(1, 2, 3))) + sql(s"TRUNCATE TABLE $t") + checkAnswer(sql(s"SELECT width, length, height FROM $t"), Nil) + checkPartitions(t, + Map("width" -> "0", "length" -> "0"), + Map("width" -> "1", "length" -> "1"), + Map("width" -> "1", "length" -> "2")) + } + } + + test("case sensitivity in resolving partition specs") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)") + sql(s"INSERT INTO $t PARTITION (id=0) SELECT 'abc'") + sql(s"INSERT INTO $t PARTITION (id=1) SELECT 'def'") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t PARTITION (ID=1)") + }.getMessage + assert(errMsg.contains("ID is not a valid partition column")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql(s"TRUNCATE TABLE $t PARTITION (ID=1)") + QueryTest.checkAnswer(sql(s"SELECT id, data FROM $t"), Row(0, "abc") :: Nil) + } + } + } + + test("SPARK-34215: keep table cached after truncation") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c0 int) $defaultUsing") + sql(s"INSERT INTO $t SELECT 0") + sql(s"CACHE TABLE $t") + assert(spark.catalog.isCached(t)) + QueryTest.checkAnswer(sql(s"SELECT * FROM $t"), Row(0) :: Nil) + sql(s"TRUNCATE TABLE $t") + assert(spark.catalog.isCached(t)) + QueryTest.checkAnswer(sql(s"SELECT * FROM $t"), Nil) + } + } + + test("truncation of views is not allowed") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (id int, part int) $defaultUsing PARTITIONED BY (part)") + sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") + + withView("v0") { + sql(s"CREATE VIEW v0 AS SELECT * FROM $t") + val errMsg = intercept[AnalysisException] { + sql("TRUNCATE TABLE v0") + }.getMessage + assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) + } + + withTempView("v1") { + sql(s"CREATE TEMP VIEW v1 AS SELECT * FROM $t") + val errMsg = intercept[AnalysisException] { + sql("TRUNCATE TABLE v1") + }.getMessage + assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) + } + + val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2" + withGlobalTempView("v2") { + sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t") + val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $v2") + }.getMessage + assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) + } + } + } + + test("keep dependents as cached after table truncation") { + withNamespaceAndTable("ns", "tbl") { t => + createPartTable(t) + cacheRelation(t) + QueryTest.checkAnswer( + sql(s"SELECT width, length, height FROM $t"), + Seq(Row(0, 0, 0), Row(1, 1, 1), Row(1, 2, 3))) + + withView("v0") { + sql(s"CREATE VIEW v0 
AS SELECT * FROM $t") + cacheRelation("v0") + sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 2)") + checkCachedRelation("v0", Seq(Row(0, 0, 0), Row(1, 1, 1))) + } + + withTempView("v1") { + sql(s"CREATE TEMP VIEW v1 AS SELECT * FROM $t") + cacheRelation("v1") + sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 1)") + checkCachedRelation("v1", Seq(Row(0, 0, 0))) + } + + val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2" + withGlobalTempView("v2") { + sql(s"INSERT INTO $t PARTITION (width = 10, length = 10) SELECT 10") + sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t") + cacheRelation(v2) + sql(s"TRUNCATE TABLE $t PARTITION (width = 10, length = 10)") + checkCachedRelation(v2, Seq(Row(0, 0, 0))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala index 5e67f2bc122be..7da03db6f7371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala @@ -24,7 +24,6 @@ import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, F import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.execution.command import org.apache.spark.sql.execution.command.FakeLocalFsFileSystem import org.apache.spark.sql.internal.SQLConf @@ -39,99 +38,8 @@ import org.apache.spark.sql.internal.SQLConf */ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { - test("table does not exist") { - withNamespaceAndTable("ns", "does_not_exist") { t => - val errMsg = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t") - }.getMessage - assert(errMsg.contains("Table not found")) - } - } - - test("truncate non-partitioned table") { - withNamespaceAndTable("ns", "tbl") { t => - sql(s"CREATE TABLE $t (c0 INT, c1 INT) $defaultUsing") - sql(s"INSERT INTO $t SELECT 0, 1") - - sql(s"TRUNCATE TABLE $t") - checkAnswer(sql(s"SELECT * FROM $t"), Nil) - - // not supported since the table is not partitioned - val errMsg = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t PARTITION (width=1)") - }.getMessage - assert(errMsg.contains( - "TRUNCATE TABLE ... 
PARTITION is not supported for tables that are not partitioned")) - } - } - - private def createPartTable(t: String): Unit = { - sql(s""" - |CREATE TABLE $t (width INT, length INT, height INT) - |$defaultUsing - |PARTITIONED BY (width, length)""".stripMargin) - sql(s"INSERT INTO $t PARTITION (width = 0, length = 0) SELECT 0") - sql(s"INSERT INTO $t PARTITION (width = 1, length = 1) SELECT 1") - sql(s"INSERT INTO $t PARTITION (width = 1, length = 2) SELECT 3") - } - - test("SPARK-34418: truncate partitioned tables") { - withNamespaceAndTable("ns", "partTable") { t => - createPartTable(t) - sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 1)") - checkAnswer(sql(s"SELECT width, length, height FROM $t"), Seq(Row(0, 0, 0), Row(1, 2, 3))) - checkPartitions(t, - Map("width" -> "0", "length" -> "0"), - Map("width" -> "1", "length" -> "1"), - Map("width" -> "1", "length" -> "2")) - } - - withNamespaceAndTable("ns", "partTable") { t => - createPartTable(t) - // support partial partition spec - sql(s"TRUNCATE TABLE $t PARTITION (width = 1)") - checkAnswer(sql(s"SELECT * FROM $t"), Row(0, 0, 0)) - checkPartitions(t, - Map("width" -> "0", "length" -> "0"), - Map("width" -> "1", "length" -> "1"), - Map("width" -> "1", "length" -> "2")) - } - - withNamespaceAndTable("ns", "partTable") { t => - createPartTable(t) - // do nothing if no partition is matched for the given partial partition spec - sql(s"TRUNCATE TABLE $t PARTITION (width = 100)") - checkAnswer( - sql(s"SELECT width, length, height FROM $t"), - Seq(Row(0, 0, 0), Row(1, 1, 1), Row(1, 2, 3))) - - // throw exception if no partition is matched for the given non-partial partition spec. - intercept[NoSuchPartitionException] { - sql(s"TRUNCATE TABLE $t PARTITION (width = 100, length = 100)") - } - - // throw exception if the column in partition spec is not a partition column. - val errMsg = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t PARTITION (unknown = 1)") - }.getMessage - assert(errMsg.contains("unknown is not a valid partition column")) - } - } - - test("SPARK-34418: preserve partitions in truncated table") { - withNamespaceAndTable("ns", "partTable") { t => - createPartTable(t) - checkAnswer( - sql(s"SELECT width, length, height FROM $t"), - Seq(Row(0, 0, 0), Row(1, 1, 1), Row(1, 2, 3))) - sql(s"TRUNCATE TABLE $t") - checkAnswer(sql(s"SELECT width, length, height FROM $t"), Nil) - checkPartitions(t, - Map("width" -> "0", "length" -> "0"), - Map("width" -> "1", "length" -> "1"), - Map("width" -> "1", "length" -> "2")) - } - } + override val invalidPartColumnError = + "TRUNCATE TABLE ... 
PARTITION is not supported for tables that are not partitioned" test("SPARK-30312: truncate table - keep acl/permission") { Seq(true, false).foreach { ignore => @@ -246,24 +154,6 @@ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { } } - test("case sensitivity in resolving partition specs") { - withNamespaceAndTable("ns", "tbl") { t => - sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)") - sql(s"INSERT INTO $t PARTITION (id=0) SELECT 'abc'") - sql(s"INSERT INTO $t PARTITION (id=1) SELECT 'def'") - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val errMsg = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t PARTITION (ID=1)") - }.getMessage - assert(errMsg.contains("ID is not a valid partition column")) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - sql(s"TRUNCATE TABLE $t PARTITION (ID=1)") - checkAnswer(sql(s"SELECT id, data FROM $t"), Row(0, "abc")) - } - } - } - test("change stats after truncate command") { withNamespaceAndTable("ns", "tbl") { t => sql(s"CREATE TABLE $t (id INT, value INT) $defaultUsing") @@ -293,82 +183,6 @@ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { } } } - - test("SPARK-34215: keep table cached after truncation") { - withNamespaceAndTable("ns", "tbl") { t => - sql(s"CREATE TABLE $t (c0 int) $defaultUsing") - sql(s"INSERT INTO $t SELECT 0") - sql(s"CACHE TABLE $t") - assert(spark.catalog.isCached(t)) - checkAnswer(sql(s"SELECT * FROM $t"), Row(0)) - sql(s"TRUNCATE TABLE $t") - assert(spark.catalog.isCached(t)) - checkAnswer(sql(s"SELECT * FROM $t"), Seq.empty) - } - } - - test("keep dependents as cached after table truncation") { - withNamespaceAndTable("ns", "tbl") { t => - createPartTable(t) - cacheRelation(t) - checkCachedRelation(t, Seq(Row(0, 0, 0), Row(1, 1, 1), Row(3, 1, 2))) - - withView("v0") { - sql(s"CREATE VIEW v0 AS SELECT * FROM $t") - cacheRelation("v0") - sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 2)") - checkCachedRelation("v0", Seq(Row(0, 0, 0), Row(1, 1, 1))) - } - - withTempView("v1") { - sql(s"CREATE TEMP VIEW v1 AS SELECT * FROM $t") - cacheRelation("v1") - sql(s"TRUNCATE TABLE $t PARTITION (width = 1, length = 1)") - checkCachedRelation("v1", Seq(Row(0, 0, 0))) - } - - val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2" - withGlobalTempView("v2") { - sql(s"INSERT INTO $t PARTITION (width = 10, length = 10) SELECT 10") - sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t") - cacheRelation(v2) - sql(s"TRUNCATE TABLE $t PARTITION (width = 10, length = 10)") - checkCachedRelation(v2, Seq(Row(0, 0, 0))) - } - } - } - - test("truncation of views is not allowed") { - withNamespaceAndTable("ns", "tbl") { t => - sql(s"CREATE TABLE $t (id int, part int) $defaultUsing PARTITIONED BY (part)") - sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") - - withView("v0") { - sql(s"CREATE VIEW v0 AS SELECT * FROM $t") - val errMsg = intercept[AnalysisException] { - sql("TRUNCATE TABLE v0") - }.getMessage - assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) - } - - withTempView("v1") { - sql(s"CREATE TEMP VIEW v1 AS SELECT * FROM $t") - val errMsg = intercept[AnalysisException] { - sql("TRUNCATE TABLE v1") - }.getMessage - assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) - } - - val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2" - withGlobalTempView("v2") { - sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t") - val errMsg = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $v2") - 
}.getMessage - assert(errMsg.contains("'TRUNCATE TABLE' expects a table")) - } - } - } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala index 90e99fbac5f94..1e14a080bf042 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala @@ -25,21 +25,18 @@ import org.apache.spark.sql.execution.command */ class TruncateTableSuite extends command.TruncateTableSuiteBase with CommandSuiteBase { - // TODO(SPARK-34290): Support v2 TRUNCATE TABLE - test("truncation of v2 tables is not supported") { - withNamespaceAndTable("ns", "tbl") { t => - sql(s"CREATE TABLE $t (id int, part int) $defaultUsing PARTITIONED BY (part)") - sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") - sql(s"INSERT INTO $t PARTITION (part=1) SELECT 1") + override val invalidPartColumnError = "not a valid partition column in table" - Seq( - s"TRUNCATE TABLE $t PARTITION (part=1)", - s"TRUNCATE TABLE $t").foreach { truncateCmd => - val errMsg = intercept[AnalysisException] { - sql(truncateCmd) - }.getMessage - assert(errMsg.contains("TRUNCATE TABLE is not supported for v2 tables")) - } + test("truncate a partition of a table which does not support partitions") { + withNamespaceAndTable("ns", "tbl", s"non_part_$catalog") { t => + sql(s"CREATE TABLE $t (c0 INT) $defaultUsing") + sql(s"INSERT INTO $t SELECT 0") + + val errMsg = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t PARTITION (c0=1)") + }.getMessage + assert(errMsg.contains( + "TRUNCATE TABLE cannot run for a table which does not support partitioning")) } } } From 5d9cfd727c21f0f50f92c0236cdeb20f3b9111dc Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 24 Feb 2021 13:40:58 +0800 Subject: [PATCH 16/60] [SPARK-34246][SQL] New type coercion syntax rules in ANSI mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In Spark ANSI mode, the type coercion rules are based on the type precedence lists of the input data types. As per the section "Type precedence list determination" of "ISO/IEC 9075-2:2011 Information technology — Database languages - SQL — Part 2: Foundation (SQL/Foundation)", the type precedence lists of primitive data types are as following: - Byte: Byte, Short, Int, Long, Decimal, Float, Double - Short: Short, Int, Long, Decimal, Float, Double - Int: Int, Long, Decimal, Float, Double - Long: Long, Decimal, Float, Double - Decimal: Any wider Numeric type - Float: Float, Double - Double: Double - String: String - Date: Date, Timestamp - Timestamp: Timestamp - Binary: Binary - Boolean: Boolean - Interval: Interval As for complex data types, Spark will determine the precedent list recursively based on their sub-types. With the definition of type precedent list, the general type coercion rules are as following: - Data type S is allowed to be implicitly cast as type T iff T is in the precedence list of S - Comparison is allowed iff the data type precedence list of both sides has at least one common element. When evaluating the comparison, Spark casts both sides as the tightest common data type of their precedent lists. - There should be at least one common data type among all the children's precedence lists for the following operators. 
The data type of the operator is the tightest common precedent data type. ``` In, Except(odd), Intersect, Greatest, Least, Union, If, CaseWhen, CreateArray, Array Concat,Sequence, MapConcat, CreateMap ``` - For complex types (struct, array, map), Spark recursively looks into the element type and applies the rules above. If the element nullability is converted from true to false, add runtime null check to the elements. Note: this new type coercion system will allow implicit converting String type literals as other primitive types, in case of breaking too many existing Spark SQL queries. This is a special rule and it is not from the ANSI SQL standard. ### Why are the changes needed? The current type coercion rules are complex. Also, they are very hard to describe and understand. For details please refer the attached documentation "Default Type coercion rules of Spark" [Default Type coercion rules of Spark.pdf](https://github.com/apache/spark/files/5874362/Default.Type.coercion.rules.of.Spark.pdf) This PR is to create a new and strict type coercion system under ANSI mode. The rules are simple and clean, so that users can follow them easily ### Does this PR introduce _any_ user-facing change? Yes, new implicit cast syntax rules in ANSI mode. All the details are in the first section of this description. ### How was this patch tested? Unit tests Closes #31349 from gengliangwang/ansiImplicitConversion. Authored-by: Gengliang Wang Signed-off-by: Gengliang Wang --- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 267 ++++ .../ResolveTableValuedFunctions.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 855 +++++----- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../analysis/AnsiTypeCoercionSuite.scala | 1408 +++++++++++++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 13 +- .../expressions/MathExpressionsSuite.scala | 2 +- .../sql-tests/inputs/postgreSQL/with.sql | 2 +- .../sql-tests/results/ansi/datetime.sql.out | 6 +- .../sql-tests/results/ansi/interval.sql.out | 4 +- .../results/ansi/string-functions.sql.out | 24 +- .../results/postgreSQL/float4.sql.out | 20 +- .../results/postgreSQL/strings.sql.out | 20 +- .../sql-tests/results/postgreSQL/text.sql.out | 60 +- .../results/postgreSQL/timestamp.sql.out | 9 +- .../results/postgreSQL/union.sql.out | 6 +- .../sql-tests/results/postgreSQL/with.sql.out | 2 +- 18 files changed, 2214 insertions(+), 496 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 182f456afa9e4..b351d76411ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -219,6 +219,12 @@ class Analyzer(override val catalogManager: CatalogManager) */ val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + private def typeCoercionRules(): List[Rule[LogicalPlan]] = if (conf.ansiEnabled) { + AnsiTypeCoercion.typeCoercionRules + } else { + TypeCoercion.typeCoercionRules + } + override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization 
rule. @@ -278,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveRandomSeed :: ResolveBinaryArithmetic :: ResolveUnion :: - TypeCoercion.typeCoercionRules ++ + typeCoercionRules ++ extendedResolutionRules : _*), Batch("Apply Char Padding", Once, ApplyCharTypePadding), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala new file mode 100644 index 0000000000000..67eef5b857f24 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +/** + * In Spark ANSI mode, the type coercion rules are based on the type precedence lists of the input + * data types. + * As per the section "Type precedence list determination" of "ISO/IEC 9075-2:2011 + * Information technology - Database languages - SQL - Part 2: Foundation (SQL/Foundation)", + * the type precedence lists of primitive data types are as following: + * * Byte: Byte, Short, Int, Long, Decimal, Float, Double + * * Short: Short, Int, Long, Decimal, Float, Double + * * Int: Int, Long, Decimal, Float, Double + * * Long: Long, Decimal, Float, Double + * * Decimal: Float, Double, or any wider Numeric type + * * Float: Float, Double + * * Double: Double + * * String: String + * * Date: Date, Timestamp + * * Timestamp: Timestamp + * * Binary: Binary + * * Boolean: Boolean + * * Interval: Interval + * As for complex data types, Spark will determine the precedent list recursively based on their + * sub-types and nullability. + * + * With the definition of type precedent list, the general type coercion rules are as following: + * * Data type S is allowed to be implicitly cast as type T iff T is in the precedence list of S + * * Comparison is allowed iff the data type precedence list of both sides has at least one common + * element. When evaluating the comparison, Spark casts both sides as the tightest common data + * type of their precedent lists. + * * There should be at least one common data type among all the children's precedence lists for + * the following operators. The data type of the operator is the tightest common precedent + * data type. 
+ * * In + * * Except + * * Intersect + * * Greatest + * * Least + * * Union + * * If + * * CaseWhen + * * CreateArray + * * Array Concat + * * Sequence + * * MapConcat + * * CreateMap + * * For complex types (struct, array, map), Spark recursively looks into the element type and + * applies the rules above. + * Note: this new type coercion system will allow implicit converting String type literals as other + * primitive types, in case of breaking too many existing Spark SQL queries. This is a special + * rule and it is not from the ANSI SQL standard. + */ +object AnsiTypeCoercion extends TypeCoercionBase { + override def typeCoercionRules: List[Rule[LogicalPlan]] = + InConversion :: + WidenSetOperationTypes :: + PromoteStringLiterals :: + DecimalPrecision :: + FunctionArgumentConversion :: + ConcatCoercion :: + MapZipWithCoercion :: + EltCoercion :: + CaseWhenCoercion :: + IfCoercion :: + StackCoercion :: + Division :: + IntegralDivision :: + ImplicitTypeCasts :: + DateTimeOperations :: + WindowFrameCoercion :: + StringLiteralCoercion :: + Nil + + override def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + (t1, t2) match { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + case (t1: NumericType, t2: NumericType) + if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + val widerType = numericPrecedence(index) + if (widerType == FloatType) { + // If the input type is an Integral type and a Float type, simply return Double type as + // the tightest common type to avoid potential precision loss on converting the Integral + // type as Float type. + Some(DoubleType) + } else { + Some(widerType) + } + + case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => + Some(TimestampType) + + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) + } + + } + + override def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) + } + + override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) + } + + override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { + implicitCast(e.dataType, expectedType, e.foldable).map { dt => + if (dt == e.dataType) e else Cast(e, dt) + } + } + + /** + * In Ansi mode, the implicit cast is only allow when `expectedType` is in the type precedent + * list of `inType`. + */ + private def implicitCast( + inType: DataType, + expectedType: AbstractDataType, + isInputFoldable: Boolean): Option[DataType] = { + (inType, expectedType) match { + // If the expected type equals the input type, no need to cast. + case _ if expectedType.acceptsType(inType) => Some(inType) + + // Cast null type (usually from null literals) into target types + case (NullType, target) => Some(target.defaultConcreteType) + + // This type coercion system will allow implicit converting String type literals as other + // primitive types, in case of breaking too many existing Spark SQL queries. 
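+      // For example, a foldable string literal such as '123' or '2021-01-01' can still be
+      // implicitly cast to an expected numeric or datetime type by the cases below, whereas a
+      // non-foldable string (e.g. a string column) matches none of them and, for a plain
+      // atomic target type, falls through to `case _ => None`.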
+ case (StringType, a: AtomicType) if isInputFoldable => + Some(a) + + // If the target type is any Numeric type, convert the String type literal as Double type. + case (StringType, NumericType) if isInputFoldable => + Some(DoubleType) + + // If the target type is any Decimal type, convert the String type literal as Double type. + case (StringType, DecimalType) if isInputFoldable => + Some(DecimalType.SYSTEM_DEFAULT) + + // If input is a numeric type but not decimal, and we expect a decimal type, + // cast the input to decimal. + case (d: NumericType, DecimalType) => Some(DecimalType.forType(d)) + + case (n1: NumericType, n2: NumericType) => + val widerType = findWiderTypeForTwo(n1, n2) + widerType match { + // if the expected type is Float type, we should still return Float type. + case Some(DoubleType) if n1 != DoubleType && n2 == FloatType => Some(FloatType) + + case Some(dt) if dt == n2 => Some(dt) + + case _ => None + } + + case (DateType, TimestampType) => Some(TimestampType) + + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. + case (_, TypeCollection(types)) => + types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption + + // Implicit cast between array types. + // + // Compare the nullabilities of the from type and the to type, check whether the cast of + // the nullability is resolvable by the following rules: + // 1. If the nullability of the to type is true, the cast is always allowed; + // 2. If the nullabilities of both the from type and the to type are false, the cast is + // allowed. + // 3. Otherwise, the cast is not allowed + case (ArrayType(fromType, containsNullFrom), ArrayType(toType: DataType, containsNullTo)) + if Cast.resolvableNullability(containsNullFrom, containsNullTo) => + implicitCast(fromType, toType, isInputFoldable).map(ArrayType(_, containsNullTo)) + + // Implicit cast between Map types. + // Follows the same semantics of implicit casting between two array types. + // Refer to documentation above. + case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) + if Cast.resolvableNullability(fn, tn) => + val newKeyType = implicitCast(fromKeyType, toKeyType, isInputFoldable) + val newValueType = implicitCast(fromValueType, toValueType, isInputFoldable) + if (newKeyType.isDefined && newValueType.isDefined) { + Some(MapType(newKeyType.get, newValueType.get, tn)) + } else { + None + } + + case _ => None + } + } + + override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to) + + /** + * Promotes string literals that appear in arithmetic and comparison expressions. + */ + object PromoteStringLiterals extends TypeCoercionRule { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Skip nodes who's children have not been resolved yet. 
+ case e if !e.childrenResolved => e + + case b @ BinaryOperator(left @ StringType(), right @ AtomicType()) if left.foldable => + b.makeCopy(Array(castExpr(left, right.dataType), right)) + + case b @ BinaryOperator(left @ AtomicType(), right @ StringType()) if right.foldable => + b.makeCopy(Array(left, castExpr(right, left.dataType))) + + case Abs(e @ StringType()) if e.foldable => Abs(Cast(e, DoubleType)) + case m @ UnaryMinus(e @ StringType(), _) if e.foldable => + m.withNewChildren(Seq(Cast(e, DoubleType))) + case UnaryPositive(e @ StringType()) if e.foldable => UnaryPositive(Cast(e, DoubleType)) + + // Promotes string literals in `In predicate`. + case p @ In(a, b) + if a.dataType != StringType && b.exists( e => e.foldable && e.dataType == StringType) => + val newList = b.map { + case e @ StringType() if e.foldable => Cast(e, a.dataType) + case other => other + } + p.makeCopy(Array(a, newList)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 983e4b0e901cf..75c7fad74f29f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -40,7 +40,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { if (args.length == values.length) { val casted = values.zip(args).map { case (value, (_, expectedType)) => - TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + TypeCoercion.implicitCast(value, expectedType) } if (casted.forall(_.isDefined)) { return Some(casted.map(_.get)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1e1564829d706..8876719ec23a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -31,126 +31,52 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - -/** - * A collection of [[Rule]] that can be used to coerce differing types that participate in - * operations into compatible ones. - * - * Notes about type widening / tightest common types: Broadly, there are two cases when we need - * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common - * data type for two or more data types, and in this case no loss of precision is allowed. Examples - * include type inference in JSON (e.g. what's the column's data type if one row is an integer - * while the other row is a long?). In case 2, we are looking for a widened data type with - * some acceptable loss of precision (e.g. there is no common type for double and decimal because - * double's range is larger than decimal, and yet decimal is more precise than double, but in - * union we would cast the decimal into double). 
- */ -object TypeCoercion { - - def typeCoercionRules: List[Rule[LogicalPlan]] = - InConversion :: - WidenSetOperationTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanEquality :: - FunctionArgumentConversion :: - ConcatCoercion :: - MapZipWithCoercion :: - EltCoercion :: - CaseWhenCoercion :: - IfCoercion :: - StackCoercion :: - Division :: - IntegralDivision :: - ImplicitTypeCasts :: - DateTimeOperations :: - WindowFrameCoercion :: - StringLiteralCoercion :: - Nil - - // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. - // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - IndexedSeq( - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType) +abstract class TypeCoercionBase { + /** + * A collection of [[Rule]] that can be used to coerce differing types that participate in + * operations into compatible ones. + */ + def typeCoercionRules: List[Rule[LogicalPlan]] /** - * Case 1 type widening (see the classdoc comment above for TypeCoercion). - * * Find the tightest common type of two types that might be used in a binary expression. * This handles all numeric types except fixed-precision decimals interacting with each other or * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { - case (t1, t2) if t1 == t2 => Some(t1) - case (NullType, t1) => Some(t1) - case (t1, NullType) => Some(t1) - - case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => - Some(t2) - case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => - Some(t1) - - // Promote numeric types to the highest of the two - case (t1: NumericType, t2: NumericType) - if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] => - val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) - Some(numericPrecedence(index)) - - case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => - Some(TimestampType) - - case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) - } - - /** Promotes all the way to StringType. */ - private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) - case _ => None - } + def findTightestCommonType(type1: DataType, type2: DataType): Option[DataType] /** - * This function determines the target type of a comparison operator when one operand - * is a String and the other is not. It also handles when one op is a Date and the - * other is a Timestamp by making the target type to be String. + * Looking for a widened data type of two given data types with some acceptable loss of precision. + * E.g. there is no common type for double and decimal because double's range + * is larger than decimal, and yet decimal is more precise than double, but in + * union we would cast the decimal into double. 
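+   * For instance, widening DecimalType(38, 18) with DoubleType is expected to produce
+   * DoubleType rather than a common decimal type.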
*/ - private def findCommonTypeForBinaryComparison( - dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { - case (StringType, DateType) - => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) - case (DateType, StringType) - => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) - case (StringType, TimestampType) - => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) - case (TimestampType, StringType) - => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) - case (StringType, NullType) => Some(StringType) - case (NullType, StringType) => Some(StringType) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] - // Cast to TimestampType when we compare DateType with TimestampType - // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true - case (TimestampType, DateType) => Some(TimestampType) - case (DateType, TimestampType) => Some(TimestampType) + /** + * Looking for a widened data type of a given sequence of data types with some acceptable loss + * of precision. + * E.g. there is no common type for double and decimal because double's range + * is larger than decimal, and yet decimal is more precise than double, but in + * union we would cast the decimal into double. + */ + def findWiderCommonType(types: Seq[DataType]): Option[DataType] - // There is no proper decimal type we can pick, - // using double type is the best we can do. - // See SPARK-22469 for details. - case (n: DecimalType, s: StringType) => Some(DoubleType) - case (s: StringType, n: DecimalType) => Some(DoubleType) + /** + * Given an expected data type, try to cast the expression and return the cast expression. + * + * If the expression already fits the input type, we simply return the expression itself. + * If the expression has an incompatible type that cannot be implicitly cast, return None. + */ + def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] - case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if l != StringType => Some(l) - case (l, r) => None - } + /** + * Whether casting `from` as `to` is valid. + */ + def canCast(from: DataType, to: DataType): Boolean - private def findTypeForComplex( + protected def findTypeForComplex( t1: DataType, t2: DataType, findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { @@ -182,65 +108,24 @@ object TypeCoercion { } /** - * The method finds a common type for data types that differ only in nullable flags, including - * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. - * If the input types are different besides nullable flags, None is returned. + * Finds a wider type when one or both types are decimals. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type. If a decimal and other fractional + * types are compared, returns a double type. 
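+   * For instance, IntegerType widened with DecimalType(10, 2) should give a decimal wide
+   * enough to hold both operands (DecimalType(12, 2)) via DecimalPrecision.widerDecimalType.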
*/ - def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { - if (t1 == t2) { - Some(t1) - } else { - findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) - } - } - - def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { - if (types.isEmpty) { - None - } else { - types.tail.foldLeft[Option[DataType]](Some(types.head)) { - case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) - case _ => None - } + protected def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = { + (dt1, dt2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None } } - /** - * Case 2 type widening (see the classdoc comment above for TypeCoercion). - * - * i.e. the main difference with [[findTightestCommonType]] is that here we allow some - * loss of precision when widening decimal and double, and promotion to string. - */ - def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) - .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse(stringPromotion(t1, t2)) - .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) - } - - /** - * Whether the data type contains StringType. - */ - def hasStringType(dt: DataType): Boolean = dt match { - case StringType => true - case ArrayType(et, _) => hasStringType(et) - // Add StructType if we support string promotion for struct fields in the future. - case _ => false - } - - private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal - // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. - // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, - // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. - val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) - (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => - r match { - case Some(d) => findWiderTypeForTwo(d, c) - case _ => None - }) - } - /** * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to * string. If the wider decimal type exceeds system limitation, this rule will truncate @@ -261,25 +146,6 @@ object TypeCoercion { }) } - /** - * Finds a wider type when one or both types are decimals. If the wider decimal type exceeds - * system limitation, this rule will truncate the decimal type. If a decimal and other fractional - * types are compared, returns a double type. 
- */ - private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = { - (dt1, dt2) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) - case _ => None - } - } - /** * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. */ @@ -301,30 +167,32 @@ object TypeCoercion { } /** - * Widens numeric types and converts strings to numbers when appropriate. - * - * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White + * Widens the data types of the children of Union/Except/Intersect. + * 1. When ANSI mode is off: + * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White * - * The implicit conversion rules can be summarized as follows: - * - Any integral numeric type can be implicitly converted to a wider type. - * - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be implicitly - * converted to DOUBLE. - * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. - * - BOOLEAN types cannot be converted to any other type. - * - Any integral numeric type can be implicitly converted to decimal type. - * - two different decimal types will be converted into a wider decimal type for both of them. - * - decimal type will be converted into double if there float or double together with it. + * The implicit conversion rules can be summarized as follows: + * - Any integral numeric type can be implicitly converted to a wider type. + * - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be + * implicitly converted to DOUBLE. + * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. + * - BOOLEAN types cannot be converted to any other type. + * - Any integral numeric type can be implicitly converted to decimal type. + * - two different decimal types will be converted into a wider decimal type for both of them. + * - decimal type will be converted into double if there float or double together with it. * - * Additionally, all types when UNION-ed with strings will be promoted to strings. - * Other string conversions are handled by PromoteStrings. + * All types when UNION-ed with strings will be promoted to + * strings. Other string conversions are handled by PromoteStrings. * - * Widening types might result in loss of precision in the following cases: - * - IntegerType to FloatType - * - LongType to FloatType - * - LongType to DoubleType - * - DecimalType to Double + * Widening types might result in loss of precision in the following cases: + * - IntegerType to FloatType + * - LongType to FloatType + * - LongType to DoubleType + * - DecimalType to Double * - * This rule is only applied to Union/Except/Intersect + * 2. When ANSI mode is on: + * The implicit conversion is determined by the closest common data type from the precedent + * lists from left and right child. See the comments of Object `AnsiTypeCoercion` for details. */ object WidenSetOperationTypes extends TypeCoercionRule { @@ -411,62 +279,6 @@ object TypeCoercion { } } - /** - * Promotes strings that appear in arithmetic expressions. 
- */ - object PromoteStrings extends TypeCoercionRule { - private def castExpr(expr: Expression, targetType: DataType): Expression = { - (expr.dataType, targetType) match { - case (NullType, dt) => Literal.create(null, targetType) - case (l, dt) if (l != dt) => Cast(expr, targetType) - case _ => expr - } - } - - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - case a @ BinaryArithmetic(left @ StringType(), right) - if right.dataType != CalendarIntervalType => - a.makeCopy(Array(Cast(left, DoubleType), right)) - case a @ BinaryArithmetic(left, right @ StringType()) - if left.dataType != CalendarIntervalType => - a.makeCopy(Array(left, Cast(right, DoubleType))) - - // For equality between string and timestamp we cast the string to a timestamp - // so that things like rounding of subsecond precision does not affect the comparison. - case p @ Equality(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) - case p @ Equality(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) - - case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get - p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) - - case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) - case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case s @ StddevPop(e @ StringType(), _) => - s.withNewChildren(Seq(Cast(e, DoubleType))) - case s @ StddevSamp(e @ StringType(), _) => - s.withNewChildren(Seq(Cast(e, DoubleType))) - case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType))) - case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) - case v @ VariancePop(e @ StringType(), _) => - v.withNewChildren(Seq(Cast(e, DoubleType))) - case v @ VarianceSamp(e @ StringType(), _) => - v.withNewChildren(Seq(Cast(e, DoubleType))) - case s @ Skewness(e @ StringType(), _) => - s.withNewChildren(Seq(Cast(e, DoubleType))) - case k @ Kurtosis(e @ StringType(), _) => - k.withNewChildren(Seq(Cast(e, DoubleType))) - } - } - /** * Handles type coercion for both IN expression with subquery and IN * expressions without subquery. @@ -526,65 +338,21 @@ object TypeCoercion { } /** - * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. + * This ensure that the types for various functions are as expected. */ - object BooleanEquality extends Rule[LogicalPlan] { - private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) - private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) + object FunctionArgumentConversion extends TypeCoercionRule { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Hive treats (true = 1) as true and (false = 0) as true, - // all other cases are considered as false. - - // We may simplify the expression if one side is literal numeric values - // TODO: Maybe these rules should go into the optimizer. 
- case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => bool - case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => Not(bool) - case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) - if trueValues.contains(value) => bool - case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) - if falseValues.contains(value) => Not(bool) - case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => And(IsNotNull(bool), bool) - case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) - case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) - if trueValues.contains(value) => And(IsNotNull(bool), bool) - case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) - if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) - - case EqualTo(left @ BooleanType(), right @ NumericType()) => - EqualTo(Cast(left, right.dataType), right) - case EqualTo(left @ NumericType(), right @ BooleanType()) => - EqualTo(left, Cast(right, left.dataType)) - case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => - EqualNullSafe(Cast(left, right.dataType), right) - case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => - EqualNullSafe(left, Cast(right, left.dataType)) - } - } - - /** - * This ensure that the types for various functions are as expected. - */ - object FunctionArgumentConversion extends TypeCoercionRule { - - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - // Skip nodes who's children have not been resolved yet. 
- case e if !e.childrenResolved => e - - case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) => - val types = children.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType))) - case None => a - } + case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType))) + case None => a + } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(c.inputTypesForMerging) => @@ -597,7 +365,7 @@ object TypeCoercion { case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull - ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + implicitCast(arr, ArrayType(StringType, containsNull)) match { case Some(castedArr) => ArrayJoin(castedArr, d, nr) case None => aj } @@ -644,8 +412,10 @@ object TypeCoercion { case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) - case None => c + case Some(finalDataType) => + Coalesce(es.map(castIfNotSameType(_, finalDataType))) + case None => + c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if @@ -785,7 +555,7 @@ object TypeCoercion { case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } c.copy(children = newChildren) } @@ -832,11 +602,11 @@ object TypeCoercion { case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c case c @ Elt(children, _) => val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } } else { children.tail @@ -962,100 +732,6 @@ object TypeCoercion { case (_, other) => other } } - - /** - * Given an expected data type, try to cast the expression and return the cast expression. - * - * If the expression already fits the input type, we simply return the expression itself. - * If the expression has an incompatible type that cannot be implicitly cast, return None. - */ - def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - implicitCast(e.dataType, expectedType).map { dt => - if (dt == e.dataType) e else Cast(e, dt) - } - } - - private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = { - // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. - // We wrap immediately an Option after this. - @Nullable val ret: DataType = (inType, expectedType) match { - // If the expected type is already a parent of the input type, no need to cast. 
- case _ if expectedType.acceptsType(inType) => inType - - // Cast null type (usually from null literals) into target types - case (NullType, target) => target.defaultConcreteType - - // If the function accepts any numeric type and the input is a string, we follow the hive - // convention and cast that input into a double - case (StringType, NumericType) => NumericType.defaultConcreteType - - // Implicit cast among numeric types. When we reach here, input type is not acceptable. - - // If input is a numeric type but not decimal, and we expect a decimal type, - // cast the input to decimal. - case (d: NumericType, DecimalType) => DecimalType.forType(d) - // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) => target - - // Implicit cast between date time types - case (DateType, TimestampType) => TimestampType - case (TimestampType, DateType) => DateType - - // Implicit cast from/to string - case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT - case (StringType, target: NumericType) => target - case (StringType, DateType) => DateType - case (StringType, TimestampType) => TimestampType - case (StringType, BinaryType) => BinaryType - // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => StringType - - // When we reach here, input type is not acceptable for any types in this type collection, - // try to find the first one we can implicitly cast. - case (_, TypeCollection(types)) => - types.flatMap(implicitCast(inType, _)).headOption.orNull - - // Implicit cast between array types. - // - // Compare the nullabilities of the from type and the to type, check whether the cast of - // the nullability is resolvable by the following rules: - // 1. If the nullability of the to type is true, the cast is always allowed; - // 2. If the nullability of the to type is false, and the nullability of the from type is - // true, the cast is never allowed; - // 3. If the nullabilities of both the from type and the to type are false, the cast is - // allowed only when Cast.forceNullable(fromType, toType) is false. - case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) => - implicitCast(fromType, toType).map(ArrayType(_, true)).orNull - - case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null - - case (ArrayType(fromType, false), ArrayType(toType: DataType, false)) - if !Cast.forceNullable(fromType, toType) => - implicitCast(fromType, toType).map(ArrayType(_, false)).orNull - - // Implicit cast between Map types. - // Follows the same semantics of implicit casting between two array types. - // Refer to documentation above. Make sure that both key and values - // can not be null after the implicit cast operation by calling forceNullable - // method. 
- case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) - if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) => - if (Cast.forceNullable(fromValueType, toValueType) && !tn) { - null - } else { - val newKeyType = implicitCast(fromKeyType, toKeyType).orNull - val newValueType = implicitCast(fromValueType, toValueType).orNull - if (newKeyType != null && newValueType != null) { - MapType(newKeyType, newValueType, tn) - } else { - null - } - } - - case _ => null - } - Option(ret) - } } /** @@ -1077,7 +753,7 @@ object TypeCoercion { case (e: SpecialFrameBoundary, _) => e case (e, _: DateType) => e case (e, _: TimestampType) => e - case (e: Expression, t) if e.dataType != t && Cast.canCast(e.dataType, t) => + case (e: Expression, t) if e.dataType != t && canCast(e.dataType, t) => Cast(e, t) case _ => boundary } @@ -1113,6 +789,363 @@ object TypeCoercion { } } +/** + * A collection of [[Rule]] that can be used to coerce differing types that participate in + * operations into compatible ones. + * + * Notes about type widening / tightest common types: Broadly, there are two cases when we need + * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common + * data type for two or more data types, and in this case no loss of precision is allowed. Examples + * include type inference in JSON (e.g. what's the column's data type if one row is an integer + * while the other row is a long?). In case 2, we are looking for a widened data type with + * some acceptable loss of precision (e.g. there is no common type for double and decimal because + * double's range is larger than decimal, and yet decimal is more precise than double, but in + * union we would cast the decimal into double). + */ +object TypeCoercion extends TypeCoercionBase { + + override def typeCoercionRules: List[Rule[LogicalPlan]] = + InConversion :: + WidenSetOperationTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + FunctionArgumentConversion :: + ConcatCoercion :: + MapZipWithCoercion :: + EltCoercion :: + CaseWhenCoercion :: + IfCoercion :: + StackCoercion :: + Division :: + IntegralDivision :: + ImplicitTypeCasts :: + DateTimeOperations :: + WindowFrameCoercion :: + StringLiteralCoercion :: + Nil + + override def canCast(from: DataType, to: DataType): Boolean = Cast.canCast(from, to) + + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
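The scaladoc above distinguishes widening with no loss of precision (the tightest common type) from widening with acceptable loss, where a decimal unioned with a double falls back to double. A hedged spark-shell sketch of that second case, assuming a session bound to `spark` with the default non-ANSI settings (the schema noted in the comment is the expected outcome, not a captured run):

    val widened = spark.sql(
      "SELECT CAST(1 AS DECIMAL(38, 18)) AS c UNION ALL SELECT CAST(1.0 AS DOUBLE) AS c")
    widened.printSchema() // expected: c resolves to double, accepting some loss of precision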
+ // The conversion for integral and floating point types have a linear widening hierarchy: + val numericPrecedence = + IndexedSeq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + override def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + (t1, t2) match { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // Promote numeric types to the highest of the two + case (t1: NumericType, t2: NumericType) + if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => + Some(TimestampType) + + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) + } + } + + /** Promotes all the way to StringType. */ + private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None + } + + /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. + */ + def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { + case (StringType, DateType) + => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) + case (DateType, StringType) + => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) + case (StringType, TimestampType) + => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) + case (TimestampType, StringType) + => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + + // Cast to TimestampType when we compare DateType with TimestampType + // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) => Some(TimestampType) + case (DateType, TimestampType) => Some(TimestampType) + + // There is no proper decimal type we can pick, + // using double type is the best we can do. + // See SPARK-22469 for details. + case (n: DecimalType, s: StringType) => Some(DoubleType) + case (s: StringType, n: DecimalType) => Some(DoubleType) + + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) + case (l, r) => None + } + + override def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse(stringPromotion(t1, t2)) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) + } + + override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. 
+ // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) + } + + override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { + implicitCast(e.dataType, expectedType).map { dt => + if (dt == e.dataType) e else Cast(e, dt) + } + } + + private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = { + // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. + // We wrap immediately an Option after this. + @Nullable val ret: DataType = (inType, expectedType) match { + // If the expected type is already a parent of the input type, no need to cast. + case _ if expectedType.acceptsType(inType) => inType + + // Cast null type (usually from null literals) into target types + case (NullType, target) => target.defaultConcreteType + + // If the function accepts any numeric type and the input is a string, we follow the hive + // convention and cast that input into a double + case (StringType, NumericType) => NumericType.defaultConcreteType + + // Implicit cast among numeric types. When we reach here, input type is not acceptable. + + // If input is a numeric type but not decimal, and we expect a decimal type, + // cast the input to decimal. + case (d: NumericType, DecimalType) => DecimalType.forType(d) + // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long + case (_: NumericType, target: NumericType) => target + + // Implicit cast between date time types + case (DateType, TimestampType) => TimestampType + case (TimestampType, DateType) => DateType + + // Implicit cast from/to string + case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT + case (StringType, target: NumericType) => target + case (StringType, DateType) => DateType + case (StringType, TimestampType) => TimestampType + case (StringType, BinaryType) => BinaryType + // Cast any atomic type to string. + case (any: AtomicType, StringType) if any != StringType => StringType + + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. + case (_, TypeCollection(types)) => + types.flatMap(implicitCast(inType, _)).headOption.orNull + + // Implicit cast between array types. + // + // Compare the nullabilities of the from type and the to type, check whether the cast of + // the nullability is resolvable by the following rules: + // 1. If the nullability of the to type is true, the cast is always allowed; + // 2. If the nullability of the to type is false, and the nullability of the from type is + // true, the cast is never allowed; + // 3. If the nullabilities of both the from type and the to type are false, the cast is + // allowed only when Cast.forceNullable(fromType, toType) is false. 
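+      // Editorial illustration (not part of the rules), with integer elements cast to long:
+      //   Array(int, containsNull = false) -> Array(long, containsNull = true):  allowed by rule 1
+      //   Array(int, containsNull = true)  -> Array(long, containsNull = false): rejected by rule 2
+      //   Array(int, containsNull = false) -> Array(long, containsNull = false): allowed by rule 3,
+      //   since Cast.forceNullable(IntegerType, LongType) is false.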
+ case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) => + implicitCast(fromType, toType).map(ArrayType(_, true)).orNull + + case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null + + case (ArrayType(fromType, false), ArrayType(toType: DataType, false)) + if !Cast.forceNullable(fromType, toType) => + implicitCast(fromType, toType).map(ArrayType(_, false)).orNull + + // Implicit cast between Map types. + // Follows the same semantics of implicit casting between two array types. + // Refer to documentation above. Make sure that both key and values + // can not be null after the implicit cast operation by calling forceNullable + // method. + case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) + if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) => + if (Cast.forceNullable(fromValueType, toValueType) && !tn) { + null + } else { + val newKeyType = implicitCast(fromKeyType, toKeyType).orNull + val newValueType = implicitCast(fromValueType, toValueType).orNull + if (newKeyType != null && newValueType != null) { + MapType(newKeyType, newValueType, tn) + } else { + null + } + } + + case _ => null + } + Option(ret) + } + + /** + * The method finds a common type for data types that differ only in nullable flags, including + * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. + * If the input types are different besides nullable flags, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + + /** + * Promotes strings that appear in arithmetic expressions. + */ + object PromoteStrings extends TypeCoercionRule { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case a @ BinaryArithmetic(left @ StringType(), right) + if right.dataType != CalendarIntervalType => + a.makeCopy(Array(Cast(left, DoubleType), right)) + case a @ BinaryArithmetic(left, right @ StringType()) + if left.dataType != CalendarIntervalType => + a.makeCopy(Array(left, Cast(right, DoubleType))) + + // For equality between string and timestamp we cast the string to a timestamp + // so that things like rounding of subsecond precision does not affect the comparison. 
+ case p @ Equality(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ Equality(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + + case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) + case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case s @ StddevPop(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) + case s @ StddevSamp(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) + case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType))) + case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) + case v @ VariancePop(e @ StringType(), _) => + v.withNewChildren(Seq(Cast(e, DoubleType))) + case v @ VarianceSamp(e @ StringType(), _) => + v.withNewChildren(Seq(Cast(e, DoubleType))) + case s @ Skewness(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) + case k @ Kurtosis(e @ StringType(), _) => + k.withNewChildren(Seq(Cast(e, DoubleType))) + } + } + + /** + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. + */ + object BooleanEquality extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. + + // We may simplify the expression if one side is literal numeric values + // TODO: Maybe these rules should go into the optimizer. 
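The combined effect of PromoteStrings above and the BooleanEquality cases that follow can be seen from two small queries. A hedged spark-shell sketch, assuming a session bound to `spark` with ANSI mode off (the commented results are the expected non-ANSI behaviour, not a captured run):

    spark.sql("SELECT '3' + 2 AS r").printSchema() // expected: r is double (the string is promoted)
    spark.sql("SELECT true = 1 AS r").show()       // expected: true (Hive-style boolean/numeric equality)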
+ case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => bool + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(bool) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => bool + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => Not(bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + + case EqualTo(left @ BooleanType(), right @ NumericType()) => + EqualTo(Cast(left, right.dataType), right) + case EqualTo(left @ NumericType(), right @ BooleanType()) => + EqualTo(left, Cast(right, left.dataType)) + case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => + EqualNullSafe(Cast(left, right.dataType), right) + case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => + EqualNullSafe(left, Cast(right, left.dataType)) + } + } +} + trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { /** * Applies any changes to [[AttributeReference]] data types that are made by the transform method diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9a10de72e5e47..901afd0440075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2687,7 +2687,7 @@ object SQLConf { buildConf("spark.sql.legacy.typeCoercion.datetimeToString.enabled") .internal() .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + - "with String") + s"with String when ${ANSI_ENABLED.key} is false.") .version("3.0.0") .booleanConf .createWithDefault(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala new file mode 100644 index 0000000000000..88e082f1580ef --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -0,0 +1,1408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class AnsiTypeCoercionSuite extends AnalysisTest { + import TypeCoercionSuite._ + + // scalastyle:off line.size.limit + // The following table shows all implicit data type conversions that are not visible to the user. + // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | + // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | + // | ShortType | X | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | + // | IntegerType | X | X | IntegerType | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | + // | LongType | X | X | X | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | + // | FloatType | X | X | X | X | FloatType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | X | + // | DoubleType | X | X | X | X | X | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | X | + // | Dec(10, 2) | X | X | X | X | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | X | + // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | X | X | X | X | X | X | X | X | X | X | X | + // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | X | X | X | X | X | X | X | X | X | X | X | + // | StringType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | + // | DateType | X | X | X | X | X | X | X | X | X | X | DateType | TimestampType | X | X | X | X | X | X | X | X | + // | TimestampType | X | X | X | X | X | X | X | X | X | X | X | TimestampType | X | X | X | X | X | X | X | X | + // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | + // | MapType | X | X | X | X | X | X | X | X | X | X | X | X | X | MapType* | X | X | X | X | X | X | + // | StructType | X 
| X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | + // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | + // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | + // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // Note: StructType* is castable when all the internal child types are castable according to the table. + // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. + // scalastyle:on line.size.limit + + private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { + // Check default value + val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) + assert(DataType.equalsIgnoreCompatibleNullability( + castDefault.map(_.dataType).getOrElse(null), expected), + s"Failed to cast $from to $to") + + // Check null value + val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) + assert(DataType.equalsIgnoreCaseAndNullability( + castNull.map(_.dataType).getOrElse(null), expected), + s"Failed to cast $from to $to") + } + + private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + // Check default value + val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) + assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") + + // Check null value + val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) + assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") + } + + private def shouldCastStringLiteral(to: AbstractDataType, expected: DataType): Unit = { + val input = Literal("123") + val castResult = AnsiTypeCoercion.implicitCast(input, to) + assert(DataType.equalsIgnoreCaseAndNullability( + castResult.map(_.dataType).getOrElse(null), expected), + s"Failed to cast String literal to $to") + } + + private def shouldNotCastStringLiteral(to: AbstractDataType): Unit = { + val input = Literal("123") + val castResult = AnsiTypeCoercion.implicitCast(input, to) + assert(castResult.isEmpty, s"Should not be able to cast String literal to $to") + } + + private def shouldNotCastStringInput(to: AbstractDataType): Unit = { + val input = AttributeReference("s", StringType)() + val castResult = AnsiTypeCoercion.implicitCast(input, to) + assert(castResult.isEmpty, s"Should not be able to cast non-foldable String input to $to") + } + + private def default(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + CreateArray(Seq(Literal.default(internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) + case _ => Literal.default(dataType) + } + + private def createNull(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + 
CreateArray(Seq(Literal.create(null, internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) + case _ => Literal.create(null, dataType) + } + + // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, + // but cannot be cast to the other types in `allTypes`. + private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { + val nonCastableTypes = allTypes.filterNot(castableTypes.contains) + + castableTypes.foreach { tpe => + shouldCast(checkedType, tpe, tpe) + } + nonCastableTypes.foreach { tpe => + shouldNotCast(checkedType, tpe) + } + } + + private def checkWidenType( + widenFunc: (DataType, DataType) => Option[DataType], + t1: DataType, + t2: DataType, + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + var found = widenFunc(t1, t2) + assert(found == expected, + s"Expected $expected as wider common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. + if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } + } + + test("implicit type cast - ByteType") { + val checkedType = ByteType + checkTypeCasting(checkedType, castableTypes = numericTypes) + shouldCast(checkedType, DecimalType, DecimalType.ByteDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldCast(checkedType, IntegralType, checkedType) + } + + test("implicit type cast - ShortType") { + val checkedType = ShortType + checkTypeCasting(checkedType, castableTypes = numericTypes.filterNot(_ == ByteType)) + shouldCast(checkedType, DecimalType, DecimalType.ShortDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldCast(checkedType, IntegralType, checkedType) + } + + test("implicit type cast - IntegerType") { + val checkedType = IntegerType + checkTypeCasting(checkedType, castableTypes = + Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) + shouldCast(IntegerType, DecimalType, DecimalType.IntDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldCast(checkedType, IntegralType, checkedType) + } + + test("implicit type cast - LongType") { + val checkedType = LongType + checkTypeCasting(checkedType, castableTypes = + Seq(LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) + shouldCast(checkedType, DecimalType, DecimalType.LongDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldCast(checkedType, IntegralType, checkedType) + } + + test("implicit type cast - FloatType") { + val checkedType = FloatType + checkTypeCasting(checkedType, castableTypes = Seq(FloatType, DoubleType)) + shouldCast(checkedType, DecimalType, DecimalType.FloatDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - DoubleType") { + val checkedType = DoubleType + checkTypeCasting(checkedType, castableTypes = Seq(DoubleType)) + shouldCast(checkedType, DecimalType, DecimalType.DoubleDecimal) + shouldCast(checkedType, NumericType, checkedType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - DecimalType(10, 2)") { + val checkedType = DecimalType(10, 2) + checkTypeCasting(checkedType, castableTypes = fractionalTypes) + shouldCast(checkedType, DecimalType, checkedType) + shouldCast(checkedType, NumericType, checkedType) + 
shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - BinaryType") { + val checkedType = BinaryType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - BooleanType") { + val checkedType = BooleanType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + shouldNotCast(checkedType, StringType) + } + + test("implicit type cast - unfoldable StringType") { + val nonCastableTypes = allTypes.filterNot(_ == StringType) + nonCastableTypes.foreach { dt => + shouldNotCastStringInput(dt) + } + shouldNotCastStringInput(DecimalType) + shouldNotCastStringInput(NumericType) + } + + test("implicit type cast - foldable StringType") { + atomicTypes.foreach { dt => + shouldCastStringLiteral(dt, dt) + } + allTypes.filterNot(atomicTypes.contains).foreach { dt => + shouldNotCastStringLiteral(dt) + } + shouldCastStringLiteral(DecimalType, DecimalType.defaultConcreteType) + shouldCastStringLiteral(NumericType, DoubleType) + } + + test("implicit type cast - DateType") { + val checkedType = DateType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType, TimestampType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + shouldNotCast(checkedType, StringType) + } + + test("implicit type cast - TimestampType") { + val checkedType = TimestampType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - unfoldable ArrayType(StringType)") { + val input = AttributeReference("a", ArrayType(StringType))() + val nonCastableTypes = allTypes.filterNot(_ == StringType) + nonCastableTypes.map(ArrayType(_)).foreach { dt => + assert(AnsiTypeCoercion.implicitCast(input, dt).isEmpty) + } + assert(AnsiTypeCoercion.implicitCast(input, DecimalType).isEmpty) + assert(AnsiTypeCoercion.implicitCast(input, NumericType).isEmpty) + } + + test("implicit type cast - foldable arrayType(StringType)") { + val input = Literal(Array("1")) + assert(AnsiTypeCoercion.implicitCast(input, ArrayType(StringType)) == Some(input)) + (numericTypes ++ datetimeTypes ++ Seq(BinaryType)).foreach { dt => + assert(AnsiTypeCoercion.implicitCast(input, ArrayType(dt)) == + Some(Cast(input, ArrayType(dt)))) + } + } + + test("implicit type cast between two Map types") { + val sourceType = MapType(IntegerType, IntegerType, true) + val castableTypes = + Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT) + val targetTypes = castableTypes.map { t => + MapType(t, sourceType.valueType, valueContainsNull = true) + } + val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t => + MapType(t, sourceType.valueType, valueContainsNull = true) + } + + // Tests that its possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Byte, Short, + // Long, Double, Float, Decimal(38, 18) or String. 
+ targetTypes.foreach { targetType => + shouldCast(sourceType, targetType, targetType) + } + + // Tests that its not possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Binary, + // Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType + nonCastableTargetTypes.foreach { targetType => + shouldNotCast(sourceType, targetType) + } + + // Tests that its not possible to cast from nullable map type to not nullable map type. + val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t => + MapType(t, sourceType.valueType, valueContainsNull = false) + } + val sourceMapExprWithValueNull = + CreateMap(Seq(Literal.default(sourceType.keyType), + Literal.create(null, sourceType.valueType))) + targetNotNullableTypes.foreach { targetType => + val castDefault = + AnsiTypeCoercion.implicitCast(sourceMapExprWithValueNull, targetType) + assert(castDefault.isEmpty, + s"Should not be able to cast $sourceType to $targetType, but got $castDefault") + } + } + + test("implicit type cast - StructType().add(\"a1\", StringType)") { + val checkedType = new StructType().add("a1", StringType) + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - NullType") { + val checkedType = NullType + checkTypeCasting(checkedType, castableTypes = allTypes) + shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) + shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) + shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType) + } + + test("implicit type cast - CalendarIntervalType") { + val checkedType = CalendarIntervalType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("eligible implicit type cast - TypeCollection") { + shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) + + shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) + shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) + + shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType) + shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) + shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) + shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) + + shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType)) + shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType)) + + shouldCast(DecimalType.SYSTEM_DEFAULT, + TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) + shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) + shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) + shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType)) + + shouldNotCastStringInput(TypeCollection(NumericType, BinaryType)) + shouldCastStringLiteral(TypeCollection(NumericType, BinaryType), DoubleType) + + shouldCast( + ArrayType(StringType, false), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, false)) + + shouldCast( + 
ArrayType(StringType, true), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, true)) + } + + test("ineligible implicit type cast - TypeCollection") { + shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) + } + + test("tightest common bound for types") { + def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = + checkWidenType(AnsiTypeCoercion.findTightestCommonType, t1, t2, expected) + + // Null + widenTest(NullType, NullType, Some(NullType)) + + // Boolean + widenTest(NullType, BooleanType, Some(BooleanType)) + widenTest(BooleanType, BooleanType, Some(BooleanType)) + widenTest(IntegerType, BooleanType, None) + widenTest(LongType, BooleanType, None) + + // Integral + widenTest(NullType, ByteType, Some(ByteType)) + widenTest(NullType, IntegerType, Some(IntegerType)) + widenTest(NullType, LongType, Some(LongType)) + widenTest(ShortType, IntegerType, Some(IntegerType)) + widenTest(ShortType, LongType, Some(LongType)) + widenTest(IntegerType, LongType, Some(LongType)) + widenTest(LongType, LongType, Some(LongType)) + + // Floating point + widenTest(NullType, FloatType, Some(FloatType)) + widenTest(NullType, DoubleType, Some(DoubleType)) + widenTest(FloatType, DoubleType, Some(DoubleType)) + widenTest(FloatType, FloatType, Some(FloatType)) + widenTest(DoubleType, DoubleType, Some(DoubleType)) + + // Integral mixed with floating point. + widenTest(IntegerType, FloatType, Some(DoubleType)) + widenTest(IntegerType, DoubleType, Some(DoubleType)) + widenTest(IntegerType, DoubleType, Some(DoubleType)) + widenTest(LongType, FloatType, Some(DoubleType)) + widenTest(LongType, DoubleType, Some(DoubleType)) + + widenTest(DecimalType(2, 1), DecimalType(3, 2), None) + widenTest(DecimalType(2, 1), DoubleType, None) + widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DoubleType, DecimalType(2, 1), None) + + // StringType + widenTest(NullType, StringType, Some(StringType)) + widenTest(StringType, StringType, Some(StringType)) + widenTest(IntegerType, StringType, None) + widenTest(LongType, StringType, None) + + // TimestampType + widenTest(NullType, TimestampType, Some(TimestampType)) + widenTest(TimestampType, TimestampType, Some(TimestampType)) + widenTest(DateType, TimestampType, Some(TimestampType)) + widenTest(IntegerType, TimestampType, None) + widenTest(StringType, TimestampType, None) + + // ComplexType + widenTest(NullType, + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) + widenTest(StringType, MapType(IntegerType, StringType, true), None) + widenTest(ArrayType(IntegerType), StructType(Seq()), None) + + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("b", IntegerType))), + None) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", DoubleType, nullable = false))), + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) + + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = false))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + 
StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("A", IntegerType))), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkWidenType( + AnsiTypeCoercion.findTightestCommonType, + StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))), + StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))), + Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), + isSymmetric = false) + } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + ArrayType(NullType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + MapType(NullType, NullType, true), + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) + + widenTest( + new StructType() + .add("null", NullType, nullable = true), + new StructType() + .add("null", IntegerType, nullable = false), + Some(new StructType() + .add("null", IntegerType, nullable = true))) + + widenTest( + ArrayType(NullType, containsNull = false), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = false))) + + widenTest(MapType(NullType, NullType, false), + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) + + widenTest( + new StructType() + .add("null", NullType, nullable = false), + new StructType() + .add("null", IntegerType, nullable = false), + Some(new StructType() + .add("null", IntegerType, nullable = false))) + } + + test("wider common type for decimal and array") { + def widenTestWithoutStringPromotion( + t1: DataType, + t2: DataType, + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + AnsiTypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) + } + + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } + + // Without string promotion + widenTestWithoutStringPromotion(IntegerType, StringType, None) + 
widenTestWithoutStringPromotion(StringType, TimestampType, None) + widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) + widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) + } + + private def ruleTest(rule: Rule[LogicalPlan], + initial: Expression, transformed: Expression): Unit = { + ruleTest(Seq(rule), initial, transformed) + } + + private def ruleTest( + rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) + } + + comparePlans( + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + + test("cast NullType for expressions that implement ExpectsInputTypes") { + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + AnyTypeUnaryExpression(Literal.create(null, NullType)), + AnyTypeUnaryExpression(Literal.create(null, NullType))) + + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + NumericTypeUnaryExpression(Literal.create(null, NullType)), + NumericTypeUnaryExpression(Literal.create(null, DoubleType))) + } + + test("cast NullType for binary operators") { + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) + + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) + } + + test("coalesce casts") { + val rule = AnsiTypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val doubleNullLit = Cast(floatNullLit, DoubleType) + val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, floatLit)), + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, 
DecimalType(22, 0)), decimalLit))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(timestampLit, stringLit))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), doubleNullLit, Cast(intLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), doubleLit))) + + // There is no a common type among Float/Double/String + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit))) + + // There is no a common type among Timestamp/Int/String + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(timestampLit, intLit, stringLit))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit))) + } + + test("CreateArray casts") { + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0f, FloatType) + :: Nil), + CreateArray(Literal(1.0) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0f, FloatType), DoubleType) + :: Nil)) + + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal(1.0) + :: Literal(1) + :: Literal("a") + :: Nil), + CreateArray(Literal(1.0) + :: Literal(1) + :: Literal("a") + :: Nil)) + + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal(1) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3)) + :: Literal(1).cast(DecimalType(13, 3)) + :: Nil)) + + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal.create(null, DecimalType(22, 10)) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil)) + } + + test("CreateMap casts") { + // type coercion for map keys + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal.create(2.0f, FloatType) + :: Literal("b") + :: Nil), + CreateMap(Cast(Literal(1), DoubleType) + :: Literal("a") + :: Cast(Literal.create(2.0f, FloatType), DoubleType) + :: Literal("b") + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal.create(null, DecimalType(5, 3)) + :: Literal("a") + :: Literal.create(2.0f, FloatType) + :: Literal("b") + :: Nil), + CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType) + :: Literal("a") + :: Literal.create(2.0f, FloatType).cast(DoubleType) + :: Literal("b") + :: Nil)) + // type coercion for map values + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal(2) + :: Literal(3.0) + :: Nil), + CreateMap(Literal(1) + :: Literal("a") + :: Literal(2) + :: Literal(3.0) + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal.create(null, DecimalType(38, 0)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateMap(Literal(1) + :: Literal.create(null, 
DecimalType(38, 0)).cast(DecimalType(38, 38)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil)) + // type coercion for both map keys and values + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal(2.0) + :: Literal(3.0) + :: Nil), + CreateMap(Cast(Literal(1), DoubleType) + :: Literal("a") + :: Literal(2.0) + :: Literal(3.0) + :: Nil)) + } + + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0f, FloatType) + :: Nil), + operator(Literal(1.0) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0f, FloatType), DoubleType) + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal(1.0) + :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) + :: Literal(1).cast(DoubleType) + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + operator(Literal.create(null, DecimalType(15, 0)) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) + :: Literal(1).cast(DecimalType(20, 5)) + :: Nil)) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + operator(Literal.create(2L, LongType) + :: Literal(1) + :: Literal.create(null, DecimalType(10, 5)) + :: Nil), + operator(Literal.create(2L, LongType).cast(DecimalType(25, 5)) + :: Literal(1).cast(DecimalType(25, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5)) + :: Nil)) + } + } + + test("nanvl casts") { + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) + } + + test("type coercion for If") { + val rule = AnsiTypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = 
Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + If(Literal(true), Literal(1), Literal(1L)), + If(Literal(true), Cast(Literal(1), LongType), Literal(1L))) + + ruleTest(rule, + If(Literal.create(null, NullType), Literal(1), Literal(1)), + If(Literal.create(null, BooleanType), Literal(1), Literal(1))) + + ruleTest(rule, + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, doubleLit)) + + ruleTest(rule, + If(trueLit, timestampLit, stringLit), + If(trueLit, timestampLit, stringLit)) + } + + test("type coercion for CaseKeyWhen") { + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + ruleTest(AnsiTypeCoercion.CaseWhenCoercion, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + ruleTest(AnsiTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq((Literal(true), Literal(1.2))), + Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Literal(1.2))), + Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType)) + ) + ruleTest(AnsiTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq((Literal(true), Literal(100L))), + Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), + Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2))) + ) + } + + test("type coercion for Stack") { + val rule = AnsiTypeCoercion.StackCoercion + + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))), + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))), + Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))), + Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3")))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))), + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal(null), + Literal(null), Literal("2"))), + Stack(Seq(Literal(2), + Literal(1), Literal.create(null, StringType), + 
Literal.create(null, IntegerType), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(1), + Literal("2"), Literal(null))), + Stack(Seq(Literal(2), + Literal.create(null, StringType), Literal(1), + Literal("2"), Literal.create(null, IntegerType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(null), + Literal(1), Literal("2"))), + Stack(Seq(Literal(2), + Literal.create(null, IntegerType), Literal.create(null, StringType), + Literal(1), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + } + + test("type coercion for Concat") { + val rule = AnsiTypeCoercion.ConcatCoercion + + ruleTest(rule, + Concat(Seq(Literal("ab"), Literal("cde"))), + Concat(Seq(Literal("ab"), Literal("cde")))) + ruleTest(rule, + Concat(Seq(Literal(null), Literal("abc"))), + Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Concat(Seq(Literal(1), Literal("234"))), + Concat(Seq(Literal(1), Literal("234")))) + ruleTest(rule, + Concat(Seq(Literal("1"), Literal("234".getBytes()))), + Concat(Seq(Literal("1"), Literal("234".getBytes())))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1)))) + ruleTest(rule, + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(0.1))), + Concat(Seq(Literal(1L), Literal(0.1)))) + ruleTest(rule, + Concat(Seq(Literal(Decimal(10)))), + Concat(Seq(Literal(Decimal(10))))) + ruleTest(rule, + Concat(Seq(Literal(BigDecimal.valueOf(10)))), + Concat(Seq(Literal(BigDecimal.valueOf(10))))) + ruleTest(rule, + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10))))) + ruleTest(rule, + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) + + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) + } + + test("type coercion for Elt") { + val rule = AnsiTypeCoercion.EltCoercion + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Literal(1), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Literal(1L), Literal(0.1)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), 
Literal(Decimal(10))))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10))))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10))))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + + private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + private val timeZoneResolver = ResolveTimeZone + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(AnsiTypeCoercion.WidenSetOperationTypes(plan)) + } + + test("WidenSetOperationTypes for except and intersect") { + val firstTable = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", LongType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + + val expectedTypes = Seq(LongType, DecimalType.SYSTEM_DEFAULT, DoubleType, DoubleType) + + val r1 = widenSetOperationTypes( + Except(firstTable, secondTable, isAll = false)).asInstanceOf[Except] + val r2 = widenSetOperationTypes( + Intersect(firstTable, secondTable, isAll = false)).asInstanceOf[Intersect] + checkOutput(r1.left, expectedTypes) + checkOutput(r1.right, expectedTypes) + checkOutput(r2.left, expectedTypes) + checkOutput(r2.right, expectedTypes) + + // Check if a Project is added + assert(r1.left.isInstanceOf[Project]) + assert(r1.right.isInstanceOf[Project]) + assert(r2.left.isInstanceOf[Project]) + assert(r2.right.isInstanceOf[Project]) + } + + test("WidenSetOperationTypes for union") { + val firstTable = LocalRelation( + AttributeReference("i", DateType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", DateType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + val thirdTable = LocalRelation( + AttributeReference("m", TimestampType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", FloatType)(), + AttributeReference("q", DoubleType)()) + val forthTable = LocalRelation( + AttributeReference("m", DateType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", ByteType)(), + AttributeReference("q", DoubleType)()) + + val expectedTypes = Seq(TimestampType, DecimalType.SYSTEM_DEFAULT, DoubleType, DoubleType) + + val unionRelation = widenSetOperationTypes( + Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] + assert(unionRelation.children.length == 4) + checkOutput(unionRelation.children.head, expectedTypes) + checkOutput(unionRelation.children(1), expectedTypes) + checkOutput(unionRelation.children(2), expectedTypes) + 
checkOutput(unionRelation.children(3), expectedTypes) + + assert(unionRelation.children.head.isInstanceOf[Project]) + assert(unionRelation.children(1).isInstanceOf[Project]) + assert(unionRelation.children(2).isInstanceOf[Project]) + assert(unionRelation.children(3).isInstanceOf[Project]) + } + + test("Transform Decimal precision/scale for union except and intersect") { + def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val left1 = LocalRelation( + AttributeReference("l", DecimalType(10, 8))()) + val right1 = LocalRelation( + AttributeReference("r", DecimalType(5, 5))()) + val expectedType1 = Seq(DecimalType(10, 8)) + + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes( + Except(left1, right1, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(left1, right1, isAll = false)).asInstanceOf[Intersect] + + checkOutput(r1.children.head, expectedType1) + checkOutput(r1.children.last, expectedType1) + checkOutput(r2.left, expectedType1) + checkOutput(r2.right, expectedType1) + checkOutput(r3.left, expectedType1) + checkOutput(r3.right, expectedType1) + + val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))()) + + val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), + DecimalType(25, 5), DoubleType, DoubleType) + + rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) => + val plan2 = LocalRelation( + AttributeReference("r", rType)()) + + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes( + Except(plan1, plan2, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(plan1, plan2, isAll = false)).asInstanceOf[Intersect] + + checkOutput(r1.children.last, Seq(expectedType)) + checkOutput(r2.right, Seq(expectedType)) + checkOutput(r3.right, Seq(expectedType)) + + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes( + Except(plan2, plan1, isAll = false)).asInstanceOf[Except] + val r6 = widenSetOperationTypes( + Intersect(plan2, plan1, isAll = false)).asInstanceOf[Intersect] + + checkOutput(r4.children.last, Seq(expectedType)) + checkOutput(r5.left, Seq(expectedType)) + checkOutput(r6.left, Seq(expectedType)) + } + } + + test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") { + val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))()) + val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))()) + val p1 = t1.select(t1.output.head).as("p1") + val p2 = t2.select(t2.output.head).as("p2") + val union = p1.union(p2) + val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v")) + assert(wp1.isInstanceOf[Project]) + // The attribute `p1.output.head` should be replaced in the root `Project`. + assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty)) + val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union)) + assert(wp2.isInstanceOf[Aggregate]) + assert(wp2.missingInput.isEmpty) + } + + /** + * There are rules that need to not fire before child expressions get resolved. + * We use this test to make sure those rules do not fire early. 
+ */ + test("make sure rules do not fire early") { + // InConversion + val inConversion = AnsiTypeCoercion.InConversion + ruleTest(inConversion, + In(UnresolvedAttribute("a"), Seq(Literal(1))), + In(UnresolvedAttribute("a"), Seq(Literal(1))) + ) + ruleTest(inConversion, + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))) + ) + ruleTest(inConversion, + In(Literal("a"), Seq(Literal(1), Literal("b"))), + In(Literal("a"), Seq(Literal(1), Literal("b"))) + ) + } + + test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + + "in aggregation function like sum") { + val rules = Seq(FunctionArgumentConversion, Division) + // Casts Integer to Double + ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) + // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will + // cast the right expression to Double. + ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Left expression is Int, right expression is Double + ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType)))) + // Casts Float to Double + ruleTest( + rules, + sum(Divide(4.0f, 3)), + sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) + // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast + // the right expression to Decimal. + ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + } + + test("SPARK-17117 null type coercion in divide") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val nullLit = Literal.create(null, NullType) + ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) + ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) + } + + test("cast WindowFrame boundaries to the type they operate upon") { + // Can cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(3), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Cast(3, LongType), Literal(2147483648L))) + ) + // Cannot cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))) + ) + // Should not cast SpecialFrameBoundary. 
+ ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) + ) + } + + test("SPARK-29000: skip to handle decimals in ImplicitTypeCasts") { + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), + Cast(100, DecimalType(34, 24))), Literal(1)), + Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), + Cast(100, DecimalType(34, 24))), Literal(1))) + + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), + Cast(100, DecimalType(34, 24))), Cast(1, IntegerType)), + Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), + Cast(100, DecimalType(34, 24))), Cast(1, IntegerType))) + } + + test("SPARK-31468: null types should be casted to decimal types in ImplicitTypeCasts") { + Seq(AnyTypeBinaryOperator(_, _), NumericTypeBinaryOperator(_, _)).foreach { binaryOp => + // binaryOp(decimal, null) case + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + binaryOp(Literal.create(null, DecimalType.SYSTEM_DEFAULT), + Literal.create(null, NullType)), + binaryOp(Literal.create(null, DecimalType.SYSTEM_DEFAULT), + Cast(Literal.create(null, NullType), DecimalType.SYSTEM_DEFAULT))) + + // binaryOp(null, decimal) case + ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, + binaryOp(Literal.create(null, NullType), + Literal.create(null, DecimalType.SYSTEM_DEFAULT)), + binaryOp(Cast(Literal.create(null, NullType), DecimalType.SYSTEM_DEFAULT), + Literal.create(null, DecimalType.SYSTEM_DEFAULT))) + } + } + + test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + // Casts Byte to Long + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte), + IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType))) + // Casts Short to Long + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort), + IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType))) + // Casts Integer to Long + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2, 1), + IntegralDivide(Cast(2, LongType), Cast(1, LongType))) + // should not be any change for Long data types + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L)) + // one of the operand is byte + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte), + IntegralDivide(2L, Cast(1.toByte, LongType))) + // one of the operand is short + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L), + IntegralDivide(Cast(2.toShort, LongType), 1L)) + // one of the operand is int + ruleTest(AnsiTypeCoercion.IntegralDivision, IntegralDivide(2, 1L), + IntegralDivide(Cast(2, LongType), 1L)) + } + + test("Promote string literals") { + val rule = AnsiTypeCoercion.PromoteStringLiterals + val stringLiteral = Literal("123") + val castStringLiteralAsInt = Cast(stringLiteral, IntegerType) + val castStringLiteralAsDouble = Cast(stringLiteral, DoubleType) + val castStringLiteralAsDate = Cast(stringLiteral, DateType) + val castStringLiteralAsTimestamp = Cast(stringLiteral, TimestampType) + ruleTest(rule, + 
GreaterThan(stringLiteral, Literal(1)), + GreaterThan(castStringLiteralAsInt, Literal(1))) + ruleTest(rule, + LessThan(Literal(true), stringLiteral), + LessThan(Literal(true), Cast(stringLiteral, BooleanType))) + ruleTest(rule, + EqualTo(Literal(Array(1, 2)), stringLiteral), + EqualTo(Literal(Array(1, 2)), stringLiteral)) + ruleTest(rule, + GreaterThan(stringLiteral, Literal(0.5)), + GreaterThan(castStringLiteralAsDouble, Literal(0.5))) + + val dateLiteral = Literal(java.sql.Date.valueOf("2021-01-01")) + ruleTest(rule, + EqualTo(stringLiteral, dateLiteral), + EqualTo(castStringLiteralAsDate, dateLiteral)) + + val timestampLiteral = Literal(Timestamp.valueOf("2021-01-01 00:00:00")) + ruleTest(rule, + EqualTo(stringLiteral, timestampLiteral), + EqualTo(castStringLiteralAsTimestamp, timestampLiteral)) + + ruleTest(rule, Add(stringLiteral, Literal(1)), + Add(castStringLiteralAsInt, Literal(1))) + ruleTest(rule, Divide(stringLiteral, Literal(1)), + Divide(castStringLiteralAsInt, Literal(1))) + + ruleTest(rule, + In(Literal(1), Seq(stringLiteral, Literal(2))), + In(Literal(1), Seq(castStringLiteralAsInt, Literal(2)))) + ruleTest(rule, + In(Literal(1.0), Seq(stringLiteral, Literal(2.2))), + In(Literal(1.0), Seq(castStringLiteralAsDouble, Literal(2.2)))) + ruleTest(rule, + In(dateLiteral, Seq(stringLiteral)), + In(dateLiteral, Seq(castStringLiteralAsDate))) + ruleTest(rule, + In(timestampLiteral, Seq(stringLiteral)), + In(timestampLiteral, Seq(castStringLiteralAsTimestamp))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 5c4d45b5394f7..a6145c5421d48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -61,13 +61,13 @@ class TypeCoercionSuite extends AnalysisTest { private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { // Check default value - val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + val castDefault = TypeCoercion.implicitCast(default(from), to) assert(DataType.equalsIgnoreCompatibleNullability( castDefault.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") // Check null value - val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + val castNull = TypeCoercion.implicitCast(createNull(from), to) assert(DataType.equalsIgnoreCaseAndNullability( castNull.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") @@ -75,11 +75,11 @@ class TypeCoercionSuite extends AnalysisTest { private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { // Check default value - val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + val castDefault = TypeCoercion.implicitCast(default(from), to) assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") // Check null value - val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + val castNull = TypeCoercion.implicitCast(createNull(from), to) assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") } @@ -274,7 +274,7 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, sourceType.valueType))) targetNotNullableTypes.foreach { targetType => val castDefault = - 
TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType) + TypeCoercion.implicitCast(sourceMapExprWithValueNull, targetType) assert(castDefault.isEmpty, s"Should not be able to cast $sourceType to $targetType, but got $castDefault") } @@ -1607,8 +1607,9 @@ object TypeCoercionSuite { val fractionalTypes: Seq[DataType] = Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)) val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes + val datetimeTypes: Seq[DataType] = Seq(DateType, TimestampType) val atomicTypes: Seq[DataType] = - numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType) + numericTypes ++ datetimeTypes ++ Seq(BinaryType, BooleanType, StringType) val complexTypes: Seq[DataType] = Seq(ArrayType(IntegerType), ArrayType(StringType), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 4c4df9ef83de9..72d15e8abef6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -23,7 +23,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts.implicitCast +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.implicitCast import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/with.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/with.sql index 83c6fd8cbac91..a3e0b15b582f5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/with.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/with.sql @@ -931,7 +931,7 @@ SELECT * FROM outermost ORDER BY 1; -- data-modifying WITH containing INSERT...ON CONFLICT DO UPDATE -- [ORIGINAL SQL] --CREATE TABLE withz AS SELECT i AS k, (i || ' v')::text v FROM generate_series(1, 16, 3) i; -CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i || ' v' AS string) v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i); +CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i AS string) || ' v' AS v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i); -- [NOTE] Spark SQL doesn't support UNIQUE constraints --ALTER TABLE withz ADD UNIQUE (k); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out index 9c0a740ec96df..b954bf71b9434 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out @@ -536,7 +536,7 @@ select date'2011-11-11' + '1' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(DATE '2011-11-11', CAST('1' AS DOUBLE))' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'CAST('1' AS DOUBLE)' is of double type.; line 1 pos 7 +cannot resolve 'date_add(DATE '2011-11-11', CAST('1' AS DATE))' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'CAST('1' AS DATE)' 
is of date type.; line 1 pos 7 -- !query @@ -576,8 +576,8 @@ select date '2001-10-01' - '7' -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(DATE '2001-10-01', CAST('7' AS DOUBLE))' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'CAST('7' AS DOUBLE)' is of double type.; line 1 pos 7 +java.time.DateTimeException +Cannot cast 7 to DateType. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 2e2c621c14dc1..217eb2b7235a2 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -196,8 +196,8 @@ select make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -Decimal(expanded,1234567890123456789,20,0}) cannot be represented as Decimal(18, 6). +org.apache.spark.sql.AnalysisException +cannot resolve 'make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789L)' due to data type mismatch: argument 7 requires decimal(18,6) type, however, '1234567890123456789L' is of bigint type.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 7769cd7069048..a1f1d87f5a594 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -80,9 +80,10 @@ invalid input syntax for type numeric: a -- !query select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) -- !query schema -struct +struct<> -- !query output -cd abcd cd NULL +org.apache.spark.sql.AnalysisException +cannot resolve 'substring('abcd', (- CAST('2' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('2' AS DOUBLE))' is of double type.; line 1 pos 43 -- !query @@ -90,8 +91,8 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') -- !query schema struct<> -- !query output -java.lang.NumberFormatException -invalid input syntax for type numeric: a +org.apache.spark.sql.AnalysisException +cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 61 -- !query @@ -289,25 +290,28 @@ trim -- !query SELECT btrim(encode(" xyz ", 'utf-8')) -- !query schema -struct +struct<> -- !query output -xyz +org.apache.spark.sql.AnalysisException +cannot resolve 'trim(encode(' xyz ', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode(' xyz ', 'utf-8')' is of binary type.; line 1 pos 7 -- !query SELECT btrim(encode('yxTomxx', 'utf-8'), encode('xyz', 'utf-8')) -- !query schema -struct +struct<> -- !query output -Tom +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH encode('xyz', 'utf-8') FROM encode('yxTomxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('yxTomxx', 'utf-8')' is of binary type. 
argument 2 requires string type, however, 'encode('xyz', 'utf-8')' is of binary type.; line 1 pos 7 -- !query SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8')) -- !query schema -struct +struct<> -- !query output -bar +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH encode('x', 'utf-8') FROM encode('xxxbarxxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('xxxbarxxx', 'utf-8')' is of binary type. argument 2 requires string type, however, 'encode('x', 'utf-8')' is of binary type.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index 4e5725746d44a..c9ff3c8debc21 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -227,22 +227,22 @@ struct SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output - 1.2345679E-20 -1.2345678720289608E-19 - 1.2345679E20 -1.2345678955701443E21 - 1004.3 -10042.999877929688 + 1.2345679E-20 -1.2345678E-19 + 1.2345679E20 -1.2345678E21 + 1004.3 -10043.0 -- !query SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output 1.2345679E-20 -10.0 - 1.2345679E20 1.2345678955701443E20 - 1004.3 994.2999877929688 + 1.2345679E20 1.2345679E20 + 1004.3 994.3 -- !query @@ -260,11 +260,11 @@ struct SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output 1.2345679E-20 10.0 - 1.2345679E20 1.2345678955701443E20 - 1004.3 1014.2999877929688 + 1.2345679E20 1.2345679E20 + 1004.3 1014.3 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out index 28904629df373..253a5e49b81fa 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out @@ -977,33 +977,37 @@ struct -- !query SELECT trim(binary('\\000') from binary('\\000Tom\\000')) -- !query schema -struct +struct<> -- !query output -Tom +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000Tom\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000Tom\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 -- !query SELECT btrim(binary('\\000trim\\000'), binary('\\000')) -- !query schema -struct +struct<> -- !query output -trim +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 -- !query SELECT btrim(binary(''), binary('\\000')) -- !query schema -struct +struct<> -- !query output - +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('' AS BINARY)' is of binary type. 
argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 -- !query SELECT btrim(binary('\\000trim\\000'), binary('')) -- !query schema -struct +struct<> -- !query output -\000trim\000 +org.apache.spark.sql.AnalysisException +cannot resolve 'TRIM(BOTH CAST('' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('' AS BINARY)' is of binary type.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index f16d7e29bdf4c..b8a7571a33731 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -54,9 +54,10 @@ struct -- !query select length(42) -- !query schema -struct +struct<> -- !query output -2 +org.apache.spark.sql.AnalysisException +cannot resolve 'length(42)' due to data type mismatch: argument 1 requires (string or binary) type, however, '42' is of int type.; line 1 pos 7 -- !query @@ -64,8 +65,8 @@ select string('four: ') || 2+2 -- !query schema struct<> -- !query output -java.lang.NumberFormatException -invalid input syntax for type numeric: four: 2 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat(CAST('four: ' AS STRING), 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7 -- !query @@ -73,16 +74,17 @@ select 'four: ' || 2+2 -- !query schema struct<> -- !query output -java.lang.NumberFormatException -invalid input syntax for type numeric: four: 2 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat('four: ', 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7 -- !query select 3 || 4.0 -- !query schema -struct +struct<> -- !query output -34.0 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat(3, 4.0BD)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, decimal(2,1)]; line 1 pos 7 -- !query @@ -99,9 +101,10 @@ one -- !query select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct +struct<> -- !query output -123hellotruefalse2010-03-09 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat(1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, int, int, string, boolean, boolean, date]; line 1 pos 7 -- !query @@ -115,33 +118,37 @@ one -- !query select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct +struct<> -- !query output -1#x#x#hello#true#false#x-03-09 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat_ws('#', 1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: argument 2 requires (array or string) type, however, '1' is of int type. argument 3 requires (array or string) type, however, '2' is of int type. argument 4 requires (array or string) type, however, '3' is of int type. argument 6 requires (array or string) type, however, 'true' is of boolean type. 
argument 7 requires (array or string) type, however, 'false' is of boolean type. argument 8 requires (array or string) type, however, 'to_date('20100309', 'yyyyMMdd')' is of date type.; line 1 pos 7 -- !query select concat_ws(',',10,20,null,30) -- !query schema -struct +struct<> -- !query output -10,20,30 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query select concat_ws('',10,20,null,30) -- !query schema -struct +struct<> -- !query output -102030 +org.apache.spark.sql.AnalysisException +cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query select concat_ws(NULL,10,20,null,30) is null -- !query schema -struct<(concat_ws(NULL, 10, 20, NULL, 30) IS NULL):boolean> +struct<> -- !query output -true +org.apache.spark.sql.AnalysisException +cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query @@ -155,19 +162,10 @@ edcba -- !query select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i -- !query schema -struct --- !query output --5 --4 --3 --2 --1 -0 -1 a j -2 ah oj -3 aho hoj -4 ahoj ahoj -5 ahoj ahoj +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'substring('ahoj', 1, t.`i`)' due to data type mismatch: argument 3 requires int type, however, 't.`i`' is of bigint type.; line 1 pos 10 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/timestamp.sql.out index 4fa1759c78192..9c507a7cec2bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/timestamp.sql.out @@ -256,13 +256,10 @@ SELECT '' AS `54`, d1 as `timestamp`, date_part( 'minute', d1) AS `minute`, date_part( 'second', d1) AS `second` FROM TIMESTAMP_TBL WHERE d1 BETWEEN '1902-01-01' AND '2038-01-01' -- !query schema -struct<54:string,timestamp:timestamp,year:int,month:int,day:int,hour:int,minute:int,second:decimal(8,6)> +struct<> -- !query output - 1969-12-31 16:00:00 1969 12 31 16 0 0.000000 - 1997-01-02 00:00:00 1997 1 2 0 0 0.000000 - 1997-01-02 03:04:05 1997 1 2 3 4 5.000000 - 1997-02-10 17:32:01 1997 2 10 17 32 1.000000 - 2001-09-22 18:19:20 2001 9 22 18 19 20.000000 +org.apache.spark.sql.AnalysisException +cannot resolve 'year(spark_catalog.default.timestamp_tbl.`d1`)' due to data type mismatch: argument 1 requires date type, however, 'spark_catalog.default.timestamp_tbl.`d1`' is of timestamp type.; line 2 pos 4 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out index 2fe53055cf656..4d7bec74f4b0a 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out @@ -682,10 +682,10 @@ struct -- !query SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo' -- !query schema -struct +struct<> -- !query output -3.400000000000000000 -foo +org.apache.spark.SparkException +Failed to merge incompatible data types decimal(38,18) and string -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/with.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/with.sql.out index 167c5fde08882..21bad134706bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/with.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/with.sql.out @@ -223,7 +223,7 @@ Table or view not found: outermost; line 4 pos 23 -- !query -CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i || ' v' AS string) v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i) +CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i AS string) || ' v' AS v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i) -- !query schema struct<> -- !query output From 714ff73d4aec317fddf32720d5a7a1c283921983 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 24 Feb 2021 06:50:11 +0000 Subject: [PATCH 17/60] [SPARK-34152][SQL] Make CreateViewStatement.child to be LogicalPlan's children so that it's resolved in analyze phase ### What changes were proposed in this pull request? This PR proposes to expose `CreateViewStatement.child` through `LogicalPlan`'s `children` so that it is resolved during the analysis phase. ### Why are the changes needed? Currently, `CreateViewStatement.child` is resolved only when the `CreateViewCommand` runs, which is inconsistent with how other plans are resolved. For example, you may see the following in the physical plan: ``` == Physical Plan == Execute CreateViewCommand (1) +- CreateViewCommand (2) +- Project (4) +- UnresolvedRelation (3) ``` ### Does this PR introduce _any_ user-facing change? Yes. For the example above, you will now see the resolved plan: ``` == Physical Plan == Execute CreateViewCommand (1) +- CreateViewCommand (2) +- Project (5) +- SubqueryAlias (4) +- LogicalRelation (3) ``` ### How was this patch tested? Updated existing tests. Closes #31273 from imback82/spark-34152.
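A minimal sketch of the underlying idea, using toy stand-in classes rather than the real Catalyst types (the names below are made up for illustration): a tree-transforming analyzer only rewrites nodes it can reach through `children`, so a query kept in a plain field is never resolved, while a query exposed as a child is.
```
// Toy model (not Spark code): the "analyzer" only rewrites nodes reachable via children.
sealed trait Plan {
  def children: Seq[Plan]
  def withChildren(newChildren: Seq[Plan]): Plan
  // Bottom-up resolution: replace UnresolvedRelation with ResolvedRelation wherever reachable.
  def resolve: Plan = this match {
    case UnresolvedRelation(name) => ResolvedRelation(name)
    case other => other.withChildren(other.children.map(_.resolve))
  }
}

case class UnresolvedRelation(name: String) extends Plan {
  val children: Seq[Plan] = Nil
  def withChildren(c: Seq[Plan]): Plan = this
}

case class ResolvedRelation(name: String) extends Plan {
  val children: Seq[Plan] = Nil
  def withChildren(c: Seq[Plan]): Plan = this
}

// "Before": the view's query sits in a plain field, so the analyzer never visits it.
case class CreateViewStatementBefore(query: Plan) extends Plan {
  val children: Seq[Plan] = Nil
  def withChildren(c: Seq[Plan]): Plan = this
}

// "After": the query is exposed as a child, so it gets resolved during analysis.
case class CreateViewStatementAfter(query: Plan) extends Plan {
  def children: Seq[Plan] = Seq(query)
  def withChildren(c: Seq[Plan]): Plan = CreateViewStatementAfter(c.head)
}

object ResolutionDemo extends App {
  val query = UnresolvedRelation("t")
  println(CreateViewStatementBefore(query).resolve) // query stays unresolved
  println(CreateViewStatementAfter(query).resolve)  // query becomes ResolvedRelation(t)
}
```
In Spark itself the same effect comes from the one-line `override def children: Seq[LogicalPlan] = Seq(child)` added to `CreateViewStatement` in the `statements.scala` hunk later in this patch.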
Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 26 ++++-- .../UnsupportedOperationChecker.scala | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 18 ++-- .../sql/catalyst/catalog/interface.scala | 13 ++- .../plans/logical/basicLogicalOperators.scala | 20 +++-- .../catalyst/plans/logical/statements.scala | 5 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 12 +-- .../sql/catalyst/analysis/AnalysisTest.scala | 13 +++ .../catalyst/analysis/ResolveHintsSuite.scala | 56 ++++++------ .../catalog/SessionCatalogSuite.scala | 35 +++++--- .../analysis/ResolveSessionCatalog.scala | 2 +- .../spark/sql/execution/command/views.scala | 90 +++++++++++-------- .../sql-tests/results/explain-aqe.sql.out | 30 +++++-- .../sql-tests/results/explain.sql.out | 30 +++++-- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- 15 files changed, 230 insertions(+), 123 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b351d76411ff2..3952cc063b73c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -867,7 +867,20 @@ class Analyzer(override val catalogManager: CatalogManager) } private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty - private def referredTempViewNames: Seq[Seq[String]] = AnalysisContext.get.referredTempViewNames + private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { + AnalysisContext.get.referredTempViewNames.exists { n => + (n.length == nameParts.length) && n.zip(nameParts).forall { + case (a, b) => resolver(a, b) + } + } + } + + private def unwrapRelationPlan(plan: LogicalPlan): LogicalPlan = { + EliminateSubqueryAliases(plan) match { + case v: View if v.isDataFrameTempView => v.child + case other => other + } + } /** * Resolve relations to temp views. This is not an actual rule, and is called by @@ -893,7 +906,7 @@ class Analyzer(override val catalogManager: CatalogManager) case write: V2WriteCommand => write.table match { case UnresolvedRelation(ident, _, false) => - lookupTempView(ident, performCheck = true).map(EliminateSubqueryAliases(_)).map { + lookupTempView(ident, performCheck = true).map(unwrapRelationPlan).map { case r: DataSourceV2Relation => write.withNewTable(r) case _ => throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(ident.quoted) }.getOrElse(write) @@ -930,7 +943,7 @@ class Analyzer(override val catalogManager: CatalogManager) isStreaming: Boolean = false, performCheck: Boolean = false): Option[LogicalPlan] = { // Permanent View can't refer to temp views, no need to lookup at all. - if (isResolvingView && !referredTempViewNames.contains(identifier)) return None + if (isResolvingView && !isReferredTempViewName(identifier)) return None val tmpView = identifier match { case Seq(part1) => v1SessionCatalog.lookupTempView(part1) @@ -948,7 +961,7 @@ class Analyzer(override val catalogManager: CatalogManager) // If we are resolving relations insides views, we need to expand single-part relation names with // the current catalog and namespace of when the view was created. 
private def expandRelationName(nameParts: Seq[String]): Seq[String] = { - if (!isResolvingView || referredTempViewNames.contains(nameParts)) return nameParts + if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts if (nameParts.length == 1) { AnalysisContext.get.catalogAndNamespace :+ nameParts.head @@ -1145,7 +1158,10 @@ class Analyzer(override val catalogManager: CatalogManager) case other => other } - EliminateSubqueryAliases(relation) match { + // Inserting into a file-based temporary view is allowed. + // (e.g., spark.read.parquet("path").createOrReplaceTempView("t"). + // Thus, we need to look at the raw plan if `relation` is a temporary view. + unwrapRelationPlan(relation) match { case v: View => throw QueryCompilationErrors.insertIntoViewNotAllowedError(v.desc.identifier, table) case other => i.copy(table = other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ab7d90098bfd3..42ccc45cec62f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -393,6 +393,7 @@ object UnsupportedOperationChecker extends Logging { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => + case v: View if v.isDataFrameTempView => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 689e7c3733180..74a80f566f94a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -622,8 +622,7 @@ class SessionCatalog( } /** - * Generate a [[View]] operator from the view description if the view stores sql text, - * otherwise, it is same to `getRawTempView` + * Generate a [[View]] operator from the temporary view stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { getRawTempView(name).map(getTempViewPlan) @@ -641,8 +640,7 @@ class SessionCatalog( } /** - * Generate a [[View]] operator from the view description if the view stores sql text, - * otherwise, it is same to `getRawGlobalTempView` + * Generate a [[View]] operator from the global temporary view stored. 
*/ def getGlobalTempView(name: String): Option[LogicalPlan] = { getRawGlobalTempView(name).map(getTempViewPlan) @@ -683,7 +681,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (name.database.isEmpty) { tempViews.get(table).map { - case TemporaryViewRelation(metadata) => metadata + case TemporaryViewRelation(metadata, _) => metadata case plan => CatalogTable( identifier = TableIdentifier(table), @@ -693,7 +691,7 @@ class SessionCatalog( }.getOrElse(getTableMetadata(name)) } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { globalTempViewManager.get(table).map { - case TemporaryViewRelation(metadata) => metadata + case TemporaryViewRelation(metadata, _) => metadata case plan => CatalogTable( identifier = TableIdentifier(table, Some(globalTempViewManager.database)), @@ -838,9 +836,11 @@ class SessionCatalog( private def getTempViewPlan(plan: LogicalPlan): LogicalPlan = { plan match { - case viewInfo: TemporaryViewRelation => - fromCatalogTable(viewInfo.tableMeta, isTempView = true) - case v => v + case TemporaryViewRelation(tableMeta, None) => + fromCatalogTable(tableMeta, isTempView = true) + case TemporaryViewRelation(tableMeta, Some(plan)) => + View(desc = tableMeta, isTempView = true, child = plan) + case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 89cb103a7bf73..b6a23214c9084 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_CREATED_FROM_DATAFRAME import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils @@ -467,6 +468,8 @@ object CatalogTable { val VIEW_REFERRED_TEMP_VIEW_NAMES = VIEW_PREFIX + "referredTempViewNames" val VIEW_REFERRED_TEMP_FUNCTION_NAMES = VIEW_PREFIX + "referredTempFunctionsNames" + val VIEW_CREATED_FROM_DATAFRAME = VIEW_PREFIX + "createdFromDataFrame" + def splitLargeTableProp( key: String, value: String, @@ -779,9 +782,15 @@ case class UnresolvedCatalogRelation( /** * A wrapper to store the temporary view info, will be kept in `SessionCatalog` - * and will be transformed to `View` during analysis + * and will be transformed to `View` during analysis. If the temporary view was + * created from a dataframe, `plan` is set to the analyzed plan for the view. 
*/ -case class TemporaryViewRelation(tableMeta: CatalogTable) extends LeafNode { +case class TemporaryViewRelation( + tableMeta: CatalogTable, + plan: Option[LogicalPlan] = None) extends LeafNode { + require(plan.isEmpty || + (plan.get.resolved && tableMeta.properties.contains(VIEW_CREATED_FROM_DATAFRAME))) + override lazy val resolved: Boolean = false override def output: Seq[Attribute] = Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index fad1457ac1403..3f20d8f67b44d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_CREATED_FROM_DATAFRAME import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -443,21 +444,25 @@ case class InsertIntoDir( } /** - * A container for holding the view description(CatalogTable), and the output of the view. The - * child should be a logical plan parsed from the `CatalogTable.viewText`, should throw an error - * if the `viewText` is not defined. + * A container for holding the view description(CatalogTable) and info whether the view is temporary + * or not. If it's a SQL (temp) view, the child should be a logical plan parsed from the + * `CatalogTable.viewText`. Otherwise, the view is a temporary one created from a dataframe and the + * view description should contain a `VIEW_CREATED_FROM_DATAFRAME` property; in this case, the child + * must be already resolved. + * * This operator will be removed at the end of analysis stage. * * @param desc A view description(CatalogTable) that provides necessary information to resolve the * view. - * we are able to decouple the output from the underlying structure. - * @param child The logical plan of a view operator, it should be a logical plan parsed from the - * `CatalogTable.viewText`, should throw an error if the `viewText` is not defined. + * @param isTempView A flag to indicate whether the view is temporary or not. + * @param child The logical plan of a view operator. If the view description is available, it should + * be a logical plan parsed from the `CatalogTable.viewText`. */ case class View( desc: CatalogTable, isTempView: Boolean, child: LogicalPlan) extends UnaryNode { + require(!isDataFrameTempView || child.resolved) override def output: Seq[Attribute] = child.output @@ -470,6 +475,9 @@ case class View( case _ => child.canonicalized } + def isDataFrameTempView: Boolean = + isTempView && desc.properties.contains(VIEW_CREATED_FROM_DATAFRAME) + // When resolving a SQL view, we use an extra Project to add cast and alias to make sure the view // output schema doesn't change even if the table referenced by the view is changed after view // creation. We should remove this extra Project during canonicalize if it does nothing. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 26f36bfe9b970..cc6e387d0f600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -179,7 +179,10 @@ case class CreateViewStatement( child: LogicalPlan, allowExisting: Boolean, replace: Boolean, - viewType: ViewType) extends ParsedStatement + viewType: ViewType) extends ParsedStatement { + + override def children: Seq[LogicalPlan] = Seq(child) +} /** * A REPLACE TABLE command, as parsed from SQL. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0e0142eb76894..a3c26ecdaba2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -95,7 +95,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Project(Seq(UnresolvedAttribute("a")), testRelation), Project(testRelation.output, testRelation)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( Project(Seq(UnresolvedAttribute("TbL.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation)) @@ -105,13 +105,13 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Seq("cannot resolve")) - checkAnalysis( + checkAnalysisWithoutViewWrapper( Project(Seq(UnresolvedAttribute("TbL.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( Project(Seq(UnresolvedAttribute("tBl.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), @@ -203,10 +203,10 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("resolve relations") { assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe")), Seq()) - checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE")), testRelation) - checkAnalysis( + checkAnalysisWithoutViewWrapper(UnresolvedRelation(TableIdentifier("TaBlE")), testRelation) + checkAnalysisWithoutViewWrapper( UnresolvedRelation(TableIdentifier("tAbLe")), testRelation, caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedRelation(TableIdentifier("TaBlE")), testRelation, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 37db4be502a83..7248424a68ad9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -59,6 +59,19 @@ trait AnalysisTest extends PlanTest { } } + protected def checkAnalysisWithoutViewWrapper( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val actualPlan = getAnalyzer.executeAndCheck(inputPlan, new QueryPlanningTracker) + 
val transformed = actualPlan transformUp { + case v: View if v.isDataFrameTempView => v.child + } + comparePlans(transformed, expectedPlan) + } + } + protected override def comparePlans( plan1: LogicalPlan, plan2: LogicalPlan, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 513f1d001f757..9db64c684c40f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -31,36 +31,36 @@ class ResolveHintsSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("invalid hints should be ignored") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), testRelation, caseSensitive = false) } test("case-sensitive or insensitive parameters") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = true) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), testRelation, caseSensitive = true) } test("multiple broadcast hint aliases") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))), @@ -69,7 +69,7 @@ class ResolveHintsSuite extends AnalysisTest { } test("do not traverse past existing broadcast hints") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table"), ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))), ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze, @@ -77,32 +77,32 @@ class ResolveHintsSuite extends AnalysisTest { } test("should work for subqueries") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. 
- checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), testRelation, caseSensitive = false) } test("do not traverse past subquery alias") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), testRelation.where('a > 1).analyze, caseSensitive = false) } test("should work for CTE") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( CatalystSqlParser.parsePlan( """ |WITH ctetable AS (SELECT * FROM table WHERE a > 1) @@ -115,7 +115,7 @@ class ResolveHintsSuite extends AnalysisTest { } test("should not traverse down CTE") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( CatalystSqlParser.parsePlan( """ |WITH ctetable AS (SELECT * FROM table WHERE a > 1) @@ -127,16 +127,16 @@ class ResolveHintsSuite extends AnalysisTest { } test("coalesce and repartition hint") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("COALESCE", Seq(Literal(10)), table("TaBlE")), Repartition(numPartitions = 10, shuffle = false, child = testRelation)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("coalesce", Seq(Literal(20)), table("TaBlE")), Repartition(numPartitions = 20, shuffle = false, child = testRelation)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("REPARTITION", Seq(Literal(100)), table("TaBlE")), Repartition(numPartitions = 100, shuffle = true, child = testRelation)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")), Repartition(numPartitions = 200, shuffle = true, child = testRelation)) @@ -152,15 +152,15 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")), Seq(errMsg)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("RePartition", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")), RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("REPARTITION", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")), RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")), RepartitionByExpression( Seq(AttributeReference("a", IntegerType)()), testRelation, None)) @@ -176,13 +176,13 @@ class ResolveHintsSuite extends AnalysisTest { } e.getMessage.contains("For range partitioning use REPARTITION_BY_RANGE instead") - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint( "REPARTITION_BY_RANGE", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")), RepartitionByExpression( Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), testRelation, 10)) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint( "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")), RepartitionByExpression( @@ -225,7 +225,7 @@ class ResolveHintsSuite extends AnalysisTest { test("log warnings for invalid hints") { val logAppender = new LogAppender("invalid hints") withLogAppender(logAppender) { - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE")), testRelation, caseSensitive = false) @@ -236,7 +236,7 @@ class ResolveHintsSuite extends AnalysisTest { } test("SPARK-30003: Do not throw stack 
overflow exception in non-root unknown hint resolution") { - checkAnalysis( + checkAnalysisWithoutViewWrapper( Project(testRelation.output, UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE"))), Project(testRelation.output, testRelation), caseSensitive = false) @@ -248,7 +248,7 @@ class ResolveHintsSuite extends AnalysisTest { ("SHUFFLE_HASH", SHUFFLE_HASH), ("SHUFFLE_REPLICATE_NL", SHUFFLE_REPLICATE_NL)).foreach { case (hintName, st) => // local temp table (single-part identifier case) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint(hintName, Seq("table", "table2"), table("TaBlE").join(table("TaBlE2"))), Join( @@ -259,7 +259,7 @@ class ResolveHintsSuite extends AnalysisTest { JoinHint.NONE), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint(hintName, Seq("TaBlE", "table2"), table("TaBlE").join(table("TaBlE2"))), Join( @@ -271,7 +271,7 @@ class ResolveHintsSuite extends AnalysisTest { caseSensitive = true) // global temp table (multi-part identifier case) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint(hintName, Seq("GlOBal_TeMP.table4", "table5"), table("global_temp", "table4").join(table("global_temp", "table5"))), Join( @@ -282,7 +282,7 @@ class ResolveHintsSuite extends AnalysisTest { JoinHint.NONE), caseSensitive = false) - checkAnalysis( + checkAnalysisWithoutViewWrapper( UnresolvedHint(hintName, Seq("global_temp.TaBlE4", "table5"), table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), Join( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 57b728aa5eb95..635fea9114434 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import org.scalatest.concurrent.Eventually import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -84,6 +84,12 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.reset() } } + + private def getTempViewRawPlan(plan: Option[LogicalPlan]): Option[LogicalPlan] = plan match { + case Some(v: View) if v.isDataFrameTempView => Some(v.child) + case other => other + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- @@ -301,16 +307,16 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempView("tbl1") == Option(tempTable1)) - assert(catalog.getTempView("tbl2") == Option(tempTable2)) - assert(catalog.getTempView("tbl3").isEmpty) + assert(getTempViewRawPlan(catalog.getTempView("tbl1")) == Option(tempTable1)) + assert(getTempViewRawPlan(catalog.getTempView("tbl2")) == Option(tempTable2)) + 
assert(getTempViewRawPlan(catalog.getTempView("tbl3")).isEmpty) // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempView("tbl1") == Option(tempTable2)) + assert(getTempViewRawPlan(catalog.getTempView("tbl1")) == Option(tempTable2)) } } @@ -352,7 +358,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { val tempTable = Range(1, 10, 2, 10) catalog.createTempView("tbl1", tempTable, overrideIfExists = false) catalog.setCurrentDatabase("db2") - assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(getTempViewRawPlan(catalog.getTempView("tbl1")) == Some(tempTable)) assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be dropped first catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) @@ -366,7 +372,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, purge = false) - assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(getTempViewRawPlan(catalog.getTempView("tbl1")) == Some(tempTable)) assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) } } @@ -419,16 +425,16 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { val tempTable = Range(1, 10, 2, 10) catalog.createTempView("tbl1", tempTable, overrideIfExists = false) catalog.setCurrentDatabase("db2") - assert(catalog.getTempView("tbl1") == Option(tempTable)) + assert(getTempViewRawPlan(catalog.getTempView("tbl1")) == Option(tempTable)) assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be renamed first catalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) assert(catalog.getTempView("tbl1").isEmpty) - assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(getTempViewRawPlan(catalog.getTempView("tbl3")) == Option(tempTable)) assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is specified, temp tables are never renamed catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) - assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(getTempViewRawPlan(catalog.getTempView("tbl3")) == Option(tempTable)) assert(catalog.getTempView("tbl4").isEmpty) assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) } @@ -625,8 +631,9 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head .asInstanceOf[UnresolvedCatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name - assert(catalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1)) + val tbl1 = catalog.lookupRelation(TableIdentifier("tbl1")).asInstanceOf[SubqueryAlias] + assert(tbl1.identifier == AliasIdentifier("tbl1")) + assert(getTempViewRawPlan(Some(tbl1.child)).get == tempTable1) // Then, if that does not exist, look up the relation in the current database catalog.dropTable(TableIdentifier("tbl1"), 
ignoreIfNotExists = false, purge = false) assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head @@ -1581,11 +1588,11 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { original.copyStateTo(clone) assert(original ne clone) - assert(clone.getTempView("copytest1") == Some(tempTable1)) + assert(getTempViewRawPlan(clone.getTempView("copytest1")) == Some(tempTable1)) // check if clone and original independent clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) - assert(original.getTempView("copytest1") == Some(tempTable1)) + assert(getTempViewRawPlan(original.getTempView("copytest1")) == Some(tempTable1)) val tempTable2 = Range(1, 20, 2, 10) original.createTempView("copytest2", tempTable2, overrideIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 7ddd2ab6d913c..232e8a16cdd76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -471,7 +471,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case CreateViewStatement( tbl, userSpecifiedColumns, comment, properties, - originalText, child, allowExisting, replace, viewType) => + originalText, child, allowExisting, replace, viewType) if child.resolved => val v1TableName = if (viewType != PersistedView) { // temp view doesn't belong to any catalog and we shouldn't resolve catalog in the name. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index a14f247515773..bae7c54f0b99d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -23,12 +23,12 @@ import org.json4s.JsonAST.{JArray, JString} import org.json4s.jackson.JsonMethods._ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, UnresolvedFunction, UnresolvedRelation, ViewType} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, ViewType} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, SessionCatalog, TemporaryViewRelation} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression, UserDefinedExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View, With} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -111,12 +111,11 @@ case class CreateViewCommand( // When creating a permanent view, not allowed to reference temporary objects. 
// This should be called after `qe.assertAnalyzed()` (i.e., `child` can be resolved) - verifyTemporaryObjectsNotExists(catalog, isTemporary, name, child) + verifyTemporaryObjectsNotExists(catalog, isTemporary, name, analyzedPlan) if (viewType == LocalTempView) { val aliasedPlan = aliasPlan(sparkSession, analyzedPlan) - if (replace && catalog.getRawTempView(name.table).isDefined && - !catalog.getRawTempView(name.table).get.sameResult(aliasedPlan)) { + if (replace && !isSamePlan(catalog.getRawTempView(name.table), aliasedPlan)) { logInfo(s"Try to uncache ${name.quotedString} before replacing.") checkCyclicViewReference(analyzedPlan, Seq(name), name) CommandUtils.uncacheTableOrView(sparkSession, name.quotedString) @@ -129,18 +128,18 @@ case class CreateViewCommand( sparkSession, analyzedPlan, aliasedPlan.schema, - originalText, - child)) + originalText)) } else { - aliasedPlan + TemporaryViewRelation( + prepareTemporaryViewFromDataFrame(name, aliasedPlan), + Some(aliasedPlan)) } catalog.createTempView(name.table, tableDefinition, overrideIfExists = replace) } else if (viewType == GlobalTempView) { val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) val viewIdent = TableIdentifier(name.table, Option(db)) val aliasedPlan = aliasPlan(sparkSession, analyzedPlan) - if (replace && catalog.getRawGlobalTempView(name.table).isDefined && - !catalog.getRawGlobalTempView(name.table).get.sameResult(aliasedPlan)) { + if (replace && !isSamePlan(catalog.getRawGlobalTempView(name.table), aliasedPlan)) { logInfo(s"Try to uncache ${viewIdent.quotedString} before replacing.") checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) CommandUtils.uncacheTableOrView(sparkSession, viewIdent.quotedString) @@ -152,10 +151,11 @@ case class CreateViewCommand( sparkSession, analyzedPlan, aliasedPlan.schema, - originalText, - child)) + originalText)) } else { - aliasedPlan + TemporaryViewRelation( + prepareTemporaryViewFromDataFrame(name, aliasedPlan), + Some(aliasedPlan)) } catalog.createGlobalTempView(name.table, tableDefinition, overrideIfExists = replace) } else if (catalog.tableExists(name)) { @@ -192,6 +192,18 @@ case class CreateViewCommand( Seq.empty[Row] } + /** + * Checks if the temp view (the result of getTempViewRawPlan or getRawGlobalTempView) is storing + * the same plan as the given aliased plan. + */ + private def isSamePlan( + rawTempView: Option[LogicalPlan], + aliasedPlan: LogicalPlan): Boolean = rawTempView match { + case Some(TemporaryViewRelation(_, Some(p))) => p.sameResult(aliasedPlan) + case Some(p) => p.sameResult(aliasedPlan) + case _ => false + } + /** * If `userSpecifiedColumns` is defined, alias the analyzed plan to the user specified columns, * else return the analyzed plan directly. @@ -280,7 +292,7 @@ case class AlterViewAsCommand( checkCyclicViewReference(analyzedPlan, Seq(name), name) TemporaryViewRelation( prepareTemporaryView( - name, session, analyzedPlan, analyzedPlan.schema, Some(originalText), query)) + name, session, analyzedPlan, analyzedPlan.schema, Some(originalText))) } session.sessionState.catalog.alterTempViewDefinition(name, tableDefinition) } @@ -541,40 +553,32 @@ object ViewHelper { } /** - * Collect all temporary views and functions and return the identifiers separately - * This func traverses the unresolved plan `child`. Below are the reasons: - * 1) Analyzer replaces unresolved temporary views by a SubqueryAlias with the corresponding - * logical plan. 
After replacement, it is impossible to detect whether the SubqueryAlias is - * added/generated from a temporary view. - * 2) The temp functions are represented by multiple classes. Most are inaccessible from this - * package (e.g., HiveGenericUDF). + * Collect all temporary views and functions and return the identifiers separately. */ private def collectTemporaryObjects( catalog: SessionCatalog, child: LogicalPlan): (Seq[Seq[String]], Seq[String]) = { def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = { child.flatMap { - case UnresolvedRelation(nameParts, _, _) if catalog.isTempView(nameParts) => - Seq(nameParts) - case w: With if !w.resolved => w.innerChildren.flatMap(collectTempViews) - case plan if !plan.resolved => plan.expressions.flatMap(_.flatMap { + case view: View if view.isTempView => + val ident = view.desc.identifier + Seq(ident.database.toSeq :+ ident.table) + case plan => plan.expressions.flatMap(_.flatMap { case e: SubqueryExpression => collectTempViews(e.plan) case _ => Seq.empty }) - case _ => Seq.empty }.distinct } def collectTempFunctions(child: LogicalPlan): Seq[String] = { child.flatMap { - case w: With if !w.resolved => w.innerChildren.flatMap(collectTempFunctions) - case plan if !plan.resolved => + case plan => plan.expressions.flatMap(_.flatMap { case e: SubqueryExpression => collectTempFunctions(e.plan) - case e: UnresolvedFunction if catalog.isTemporaryFunction(e.name) => - Seq(e.name.funcName) + case e: UserDefinedExpression + if catalog.isTemporaryFunction(FunctionIdentifier(e.name)) => + Seq(e.name) case _ => Seq.empty }) - case _ => Seq.empty }.distinct } (collectTempViews(child), collectTempFunctions(child)) @@ -592,11 +596,10 @@ object ViewHelper { session: SparkSession, analyzedPlan: LogicalPlan, viewSchema: StructType, - originalText: Option[String], - child: LogicalPlan): CatalogTable = { + originalText: Option[String]): CatalogTable = { val catalog = session.sessionState.catalog - val (tempViews, tempFunctions) = collectTemporaryObjects(catalog, child) + val (tempViews, tempFunctions) = collectTemporaryObjects(catalog, analyzedPlan) // TBLPROPERTIES is not allowed for temporary view, so we don't use it for // generating temporary view properties val newProperties = generateViewProperties( @@ -610,4 +613,19 @@ object ViewHelper { viewText = originalText, properties = newProperties) } + + /** + * Returns a [[CatalogTable]] that contains information for the temporary view created + * from a dataframe. 
+ */ + def prepareTemporaryViewFromDataFrame( + viewName: TableIdentifier, + analyzedPlan: LogicalPlan): CatalogTable = { + CatalogTable( + identifier = viewName, + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = analyzedPlan.schema, + properties = Map((VIEW_CREATED_FROM_DATAFRAME, "true"))) + } } diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 26c85a0241e57..bcb98396b3028 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -904,8 +904,9 @@ struct == Physical Plan == Execute CreateViewCommand (1) +- CreateViewCommand (2) - +- Project (4) - +- UnresolvedRelation (3) + +- Project (5) + +- SubqueryAlias (4) + +- LogicalRelation (3) (1) Execute CreateViewCommand @@ -914,11 +915,26 @@ Output: [] (2) CreateViewCommand Arguments: `default`.`explain_view`, SELECT key, val FROM explain_temp1, false, false, PersistedView -(3) UnresolvedRelation -Arguments: [explain_temp1], [], false - -(4) Project -Arguments: ['key, 'val] +(3) LogicalRelation +Arguments: parquet, [key#x, val#x], CatalogTable( +Database: default +Table: explain_temp1 +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type: MANAGED +Provider: PARQUET +Location [not included in comparison]/{warehouse_dir}/explain_temp1 +Schema: root +-- key: integer (nullable = true) +-- val: integer (nullable = true) +), false + +(4) SubqueryAlias +Arguments: spark_catalog.default.explain_temp1 + +(5) Project +Arguments: [key#x, val#x] -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index fb43da8b5b604..a72a5f0a2aa86 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -849,8 +849,9 @@ struct == Physical Plan == Execute CreateViewCommand (1) +- CreateViewCommand (2) - +- Project (4) - +- UnresolvedRelation (3) + +- Project (5) + +- SubqueryAlias (4) + +- LogicalRelation (3) (1) Execute CreateViewCommand @@ -859,11 +860,26 @@ Output: [] (2) CreateViewCommand Arguments: `default`.`explain_view`, SELECT key, val FROM explain_temp1, false, false, PersistedView -(3) UnresolvedRelation -Arguments: [explain_temp1], [], false - -(4) Project -Arguments: ['key, 'val] +(3) LogicalRelation +Arguments: parquet, [key#x, val#x], CatalogTable( +Database: default +Table: explain_temp1 +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type: MANAGED +Provider: PARQUET +Location [not included in comparison]/{warehouse_dir}/explain_temp1 +Schema: root +-- key: integer (nullable = true) +-- val: integer (nullable = true) +), false + +(4) SubqueryAlias +Arguments: spark_catalog.default.explain_temp1 + +(5) Project +Arguments: [key#x, val#x] -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 533428f9504b1..2f57298856fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1992,7 +1992,7 @@ class DataSourceV2SQLSuite 
test("CREATE VIEW") { val v = "testcat.ns1.ns2.v" val e = intercept[AnalysisException] { - sql(s"CREATE VIEW $v AS SELECT * FROM tab1") + sql(s"CREATE VIEW $v AS SELECT 1") } assert(e.message.contains("CREATE VIEW is only supported with v1 tables")) } From 14934f42d066ca147df508b0a97bfa03223046ba Mon Sep 17 00:00:00 2001 From: beliefer Date: Wed, 24 Feb 2021 07:28:44 +0000 Subject: [PATCH 18/60] [SPARK-33599][SQL][FOLLOWUP] Group exception messages in catalyst/analysis ### What changes were proposed in this pull request? This PR follows up https://github.com/apache/spark/pull/30717 Maybe some contributors don't know the job and added some exception by the old way. ### Why are the changes needed? It will largely help with standardization of error messages and its maintenance. ### Does this PR introduce _any_ user-facing change? No. Error messages remain unchanged. ### How was this patch tested? No new tests - pass all original tests to make sure it doesn't break any existing behavior. Closes #31316 from beliefer/SPARK-33599-followup. Lead-authored-by: beliefer Co-authored-by: gengjiaan Co-authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/analysis/ResolveHints.scala | 7 +++---- .../spark/sql/errors/QueryCompilationErrors.scala | 13 +++++++++++++ .../sql/catalyst/analysis/ResolveHintsSuite.scala | 3 ++- .../catalyst/analysis/ResolveSessionCatalog.scala | 12 ++++++------ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index ab7a59d4588ea..4544ee72733e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -187,10 +187,9 @@ object ResolveHints { def createRepartitionByExpression( numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = { val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder]) - if (sortOrders.nonEmpty) throw new IllegalArgumentException( - s"""Invalid partitionExprs specified: $sortOrders - |For range partitioning use REPARTITION_BY_RANGE instead. - """.stripMargin) + if (sortOrders.nonEmpty) { + throw QueryCompilationErrors.invalidRepartitionExpressionsError(sortOrders) + } val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute]) if (invalidParams.nonEmpty) { throw QueryCompilationErrors.invalidHintParameterError(hintName, invalidParams) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 87542c64effdc..e599d1223b52c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -741,4 +741,17 @@ private[spark] object QueryCompilationErrors { new InvalidUDFClassException(s"No handler for UDAF '$name'. " + "Use sparkSession.udf.register(...) 
instead.") } + + def databaseFromV1SessionCatalogNotSpecifiedError(): Throwable = { + new AnalysisException("Database from v1 session catalog is not specified") + } + + def nestedDatabaseUnsupportedByV1SessionCatalogError(catalog: String): Throwable = { + new AnalysisException(s"Nested databases are not supported by v1 session catalog: $catalog") + } + + def invalidRepartitionExpressionsError(sortOrders: Seq[Any]): Throwable = { + new AnalysisException(s"Invalid partitionExprs specified: $sortOrders For range " + + "partitioning use REPARTITION_BY_RANGE instead.") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 9db64c684c40f..4b9a2ca94ea50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.log4j.Level +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, SortOrder} @@ -165,7 +166,7 @@ class ResolveHintsSuite extends AnalysisTest { RepartitionByExpression( Seq(AttributeReference("a", IntegerType)()), testRelation, None)) - val e = intercept[IllegalArgumentException] { + val e = intercept[AnalysisException] { checkAnalysis( UnresolvedHint("REPARTITION", Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 232e8a16cdd76..68c608310e214 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} @@ -217,8 +217,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case Alias(child, _) => throw QueryCompilationErrors.commandNotSupportNestedColumnError( "DESC TABLE COLUMN", toPrettySQL(child)) - case other => - throw new AnalysisException(s"[BUG] unexpected column expression: $other") + case _ => + throw new IllegalStateException(s"[BUG] unexpected column expression: $column") } // For CREATE TABLE [AS SELECT], we should use the v1 command if the catalog is resolved to the @@ -708,12 +708,12 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) def unapply(resolved: ResolvedNamespace): Option[String] = resolved match { case ResolvedNamespace(catalog, _) if !isSessionCatalog(catalog) => None case ResolvedNamespace(_, Seq()) => - throw new AnalysisException("Database from v1 session catalog is not specified") + throw QueryCompilationErrors.databaseFromV1SessionCatalogNotSpecifiedError() case ResolvedNamespace(_, Seq(dbName)) => 
Some(dbName) case _ => assert(resolved.namespace.length > 1) - throw new AnalysisException("Nested databases are not supported by " + - s"v1 session catalog: ${resolved.namespace.map(quoteIfNeeded).mkString(".")}") + throw QueryCompilationErrors.nestedDatabaseUnsupportedByV1SessionCatalogError( + resolved.namespace.map(quoteIfNeeded).mkString(".")) } } } From 6ef57d31cde110d9740ba6fb646818388feb8054 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 24 Feb 2021 10:23:01 +0000 Subject: [PATCH 19/60] [SPARK-34514][SQL] Push down limit for LEFT SEMI and LEFT ANTI join ### What changes were proposed in this pull request? I found out during code review of https://github.com/apache/spark/pull/31567#discussion_r577379572, where we can push down limit to the left side of LEFT SEMI and LEFT ANTI join, if the join condition is empty. Why it's safe to push down limit: The semantics of LEFT SEMI join without condition: (1). if right side is non-empty, output all rows from left side. (2). if right side is empty, output nothing. The semantics of LEFT ANTI join without condition: (1). if right side is non-empty, output nothing. (2). if right side is empty, output all rows from left side. With the semantics of output all rows from left side or nothing (all or nothing), it's safe to push down limit to left side. NOTE: LEFT SEMI / LEFT ANTI join with non-empty condition is not safe for limit push down, because output can be a portion of left side rows. Reference: physical operator implementation for LEFT SEMI / LEFT ANTI join without condition - https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala#L200-L204 . ### Why are the changes needed? Better performance. Save CPU and IO for these joins, as limit being pushed down before join. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit test in `LimitPushdownSuite.scala` and `SQLQuerySuite.scala`. Closes #31630 from c21/limit-pushdown. Authored-by: Cheng Su Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 20 ++++++++----- .../optimizer/LimitPushdownSuite.scala | 20 ++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 30 +++++++++++++++++++ 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 46a90f600b2a3..b08187d0bc3be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -502,7 +502,7 @@ object RemoveNoopOperators extends Rule[LogicalPlan] { } /** - * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. + * Pushes down [[LocalLimit]] beneath UNION ALL and joins. */ object LimitPushDown extends Rule[LogicalPlan] { @@ -539,12 +539,16 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit. case LocalLimit(exp, u: Union) => LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) - // Add extra limits below JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to - // the left and right sides, respectively. For INNER and CROSS JOIN we push limits to - // both the left and right sides if join condition is empty. 
It's not safe to push limits - // below FULL OUTER JOIN in the general case without a more invasive rewrite. - // We also need to ensure that this limit pushdown rule will not eventually introduce limits - // on both sides if it is applied multiple times. Therefore: + // Add extra limits below JOIN: + // 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides, + // respectively. + // 2. For INNER and CROSS JOIN, we push limits to both the left and right sides if join + // condition is empty. + // 3. For LEFT SEMI and LEFT ANTI JOIN, we push limits to the left side if join condition + // is empty. + // It's not safe to push limits below FULL OUTER JOIN in the general case without a more + // invasive rewrite. We also need to ensure that this limit pushdown rule will not eventually + // introduce limits on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. case LocalLimit(exp, join @ Join(left, right, joinType, conditionOpt, _)) => @@ -555,6 +559,8 @@ object LimitPushDown extends Rule[LogicalPlan] { join.copy( left = maybePushLocalLimit(exp, left), right = maybePushLocalLimit(exp, right)) + case LeftSemi | LeftAnti if conditionOpt.isEmpty => + join.copy(left = maybePushLocalLimit(exp, left)) case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 5c760264ff219..7a33b5b4b53df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Add -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -212,4 +212,22 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") { + // Push down when condition is empty + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.join(y, joinType).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, joinType)).analyze + comparePlans(optimized, correctAnswer) + } + + // No push down when condition is not empty + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.join(y, joinType, Some("x.a".attr === "y.b".attr)).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(y, joinType, Some("x.a".attr === "y.b".attr))).analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fe8a080ac5aeb..82c49f9cbf29a 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4034,6 +4034,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df, Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil) } } + + test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") { + withTable("left_table", "nonempty_right_table", "empty_right_table") { + spark.range(5).toDF().repartition(1).write.saveAsTable("left_table") + spark.range(3).write.saveAsTable("nonempty_right_table") + spark.range(0).write.saveAsTable("empty_right_table") + Seq("LEFT SEMI", "LEFT ANTI").foreach { joinType => + val joinWithNonEmptyRightDf = spark.sql( + s"SELECT * FROM left_table $joinType JOIN nonempty_right_table LIMIT 3") + val joinWithEmptyRightDf = spark.sql( + s"SELECT * FROM left_table $joinType JOIN empty_right_table LIMIT 3") + + Seq(joinWithNonEmptyRightDf, joinWithEmptyRightDf).foreach { df => + val pushedLocalLimits = df.queryExecution.optimizedPlan.collect { + case l @ LocalLimit(_, _: LogicalRelation) => l + } + assert(pushedLocalLimits.length === 1) + } + + val expectedAnswer = Seq(Row(0), Row(1), Row(2)) + if (joinType == "LEFT SEMI") { + checkAnswer(joinWithNonEmptyRightDf, expectedAnswer) + checkAnswer(joinWithEmptyRightDf, Seq.empty) + } else { + checkAnswer(joinWithNonEmptyRightDf, Seq.empty) + checkAnswer(joinWithEmptyRightDf, expectedAnswer) + } + } + } + } } case class Foo(bar: Option[String]) From 87409c42bcca5d9b73b6a472017c5dd65da0718d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Feb 2021 21:13:58 +0900 Subject: [PATCH 20/60] [SPARK-31891][SQL][DOCS][FOLLOWUP] Fix typo in the description of `MSCK REPAIR TABLE` ### What changes were proposed in this pull request? Fix typo and highlight that `ADD PARTITIONS` is the default. ### Why are the changes needed? Fix a typo which can mislead users. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? n/a Closes #31633 from MaxGekk/repair-table-drop-partitions-followup. Lead-authored-by: Wenchen Fan Co-authored-by: Max Gekk Signed-off-by: HyukjinKwon --- docs/sql-ref-syntax-ddl-repair-table.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sql-ref-syntax-ddl-repair-table.md b/docs/sql-ref-syntax-ddl-repair-table.md index 41499c3314c45..2e3711c260282 100644 --- a/docs/sql-ref-syntax-ddl-repair-table.md +++ b/docs/sql-ref-syntax-ddl-repair-table.md @@ -41,7 +41,8 @@ MSCK REPAIR TABLE table_identifier [{ADD|DROP|SYNC} PARTITIONS] * **`{ADD|DROP|SYNC} PARTITIONS`** - * If specified, `MSCK REPAIR TABLE` only adds partitions to the session catalog. + Specifies how to recover partitions. If not specified, **ADD** is the default. + * **ADD**, the command adds new partitions to the session catalog for all sub-folder in the base table folder that don't belong to any table partitions. * **DROP**, the command drops all partitions from the session catalog that have non-existing locations in the file system. * **SYNC** is the combination of **DROP** and **ADD**. From 999d3b89b6df14a5ccb94ffc2ffadb82964e9f7d Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Wed, 24 Feb 2021 21:32:19 +0800 Subject: [PATCH 21/60] [SPARK-34515][SQL] Fix NPE if InSet contains null value during getPartitionsByFilter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? 
Skip null values when rewriting `InSet` to `>= and <=` in getPartitionsByFilter.

### Why are the changes needed?
Spark converts `InSet` to `>= and <=` when the number of values exceeds `spark.sql.hive.metastorePartitionPruningInSetThreshold` during partition pruning. In this case, if the values contain a null, we get an exception such as:
```
java.lang.NullPointerException
 at org.apache.spark.unsafe.types.UTF8String.compareTo(UTF8String.java:1389)
 at org.apache.spark.unsafe.types.UTF8String.compareTo(UTF8String.java:50)
 at scala.math.LowPriorityOrderingImplicits$$anon$3.compare(Ordering.scala:153)
 at java.util.TimSort.countRunAndMakeAscending(TimSort.java:355)
 at java.util.TimSort.sort(TimSort.java:220)
 at java.util.Arrays.sort(Arrays.java:1438)
 at scala.collection.SeqLike.sorted(SeqLike.scala:659)
 at scala.collection.SeqLike.sorted$(SeqLike.scala:647)
 at scala.collection.AbstractSeq.sorted(Seq.scala:45)
 at org.apache.spark.sql.hive.client.Shim_v0_13.convert$1(HiveShim.scala:772)
 at org.apache.spark.sql.hive.client.Shim_v0_13.$anonfun$convertFilters$4(HiveShim.scala:826)
 at scala.collection.immutable.Stream.flatMap(Stream.scala:489)
 at org.apache.spark.sql.hive.client.Shim_v0_13.convertFilters(HiveShim.scala:826)
 at org.apache.spark.sql.hive.client.Shim_v0_13.getPartitionsByFilter(HiveShim.scala:848)
 at org.apache.spark.sql.hive.client.HiveClientImpl.$anonfun$getPartitionsByFilter$1(HiveClientImpl.scala:750)
```

### Does this PR introduce _any_ user-facing change?
Yes, bug fix.

### How was this patch tested?
Added a test.

Closes #31632 from ulysses-you/SPARK-34515.

Authored-by: ulysses-you
Signed-off-by: Wenchen Fan
---
 .../scala/org/apache/spark/sql/hive/client/HiveShim.scala | 4 +++-
 .../org/apache/spark/sql/hive/client/FiltersSuite.scala   | 8 ++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index ed088648bc20a..8ccb17ce35925 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -769,7 +769,9 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       case InSet(child, values) if useAdvanced && values.size > inSetThreshold =>
         val dataType = child.dataType
-        val sortedValues = values.toSeq.sorted(TypeUtils.getInterpretedOrdering(dataType))
+        // Skip null here is safe, more details could see at ExtractableLiterals.
+ val sortedValues = values.filter(_ != null).toSeq + .sorted(TypeUtils.getInterpretedOrdering(dataType)) convert(And(GreaterThanOrEqual(child, Literal(sortedValues.head, dataType)), LessThanOrEqual(child, Literal(sortedValues.last, dataType)))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 12ed0e5305299..6962f9dd6b186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -179,5 +179,13 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { } } + test("SPARK-34515: Fix NPE if InSet contains null value during getPartitionsByFilter") { + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "2") { + val filter = InSet(a("p", IntegerType), Set(null, 1, 2)) + val converted = shim.convertFilters(testTable, Seq(filter), conf.sessionLocalTimeZone) + assert(converted == "(p >= 1 and p <= 2)") + } + } + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() } From b17754a8cbd2593eb2b1952e95a7eeb0f8e09cdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 24 Feb 2021 11:46:27 -0800 Subject: [PATCH 22/60] [SPARK-32617][K8S][TESTS] Configure kubernetes client based on kubeconfig settings in kubernetes integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? From [minikube version v1.1.0](https://github.com/kubernetes/minikube/blob/v1.1.0/CHANGELOG.md) kubectl is available as a command. So the kubeconfig settings can be accessed like: ``` $ minikube kubectl config view apiVersion: v1 clusters: - cluster: certificate-authority: /Users/attilazsoltpiros/.minikube/ca.crt server: https://127.0.0.1:32788 name: minikube contexts: - context: cluster: minikube namespace: default user: minikube name: minikube current-context: minikube kind: Config preferences: {} users: - name: minikube user: client-certificate: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.crt client-key: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.key ``` Here the vm-driver was docker and the server port (https://127.0.0.1:32788) is different from the hardcoded 8443. So the main part of this PR is introducing kubernetes client configuration based on the kubeconfig (output of `minikube kubectl config view`) in case of minikube versions after v1.1.0 and the old legacy way of configuration is also kept as minikube version should be supported back to v0.34.1 . Moreover as the old style of config parsing pattern wasn't sufficient in my case as when the `minikube kubectl config view` is called kubectl downloading message might be included before the first key I changed it even for the existent keys to be a consistent pattern in this file. 
The old parsing in an example:
```
private val HOST_PREFIX = "host:"
val hostString = statusString.find(_.contains(s"$HOST_PREFIX "))
val status1 = hostString.get.replaceFirst(s"$HOST_PREFIX ", "")
```
The new parsing:
```
private val HOST_PREFIX = "host: "
val hostString = statusString.find(_.contains(HOST_PREFIX))
hostString.get.split(HOST_PREFIX)(1)
```
So the PREFIX is extended with the extra space at the declaration (this way the two separate string operations are safer and consistent with each other) and the replace is changed to a split, taking the 2nd string of the result (which is guaranteed to contain only the text after the PREFIX when the PREFIX is a contained substring).

Finally, there is a tiny change in `dev-run-integration-tests.sh` to introduce `--skip-building-dependencies`, which switches off building the Maven dependencies of `kubernetes-integration-tests` from the Spark project. This can be used when only `kubernetes-integration-tests` needs to be rebuilt because only the tests are modified.

### Why are the changes needed?
Kubernetes client configuration based on kubeconfig settings is more reliable and provides a solution which is minikube version independent.

### Does this PR introduce _any_ user-facing change?
No. This is only test code.

### How was this patch tested?
Tested manually on two minikube versions.

Minikube v0.34.1:
```
$ minikube version
minikube version: v0.34.1

$ grep "version\|building" resource-managers/kubernetes/integration-tests/target/integration-tests.log
20/12/12 12:52:25.135 ScalaTest-main-running-DiscoverySuite INFO Minikube: minikube version: v0.34.1
20/12/12 12:52:25.761 ScalaTest-main-running-DiscoverySuite INFO Minikube: building kubernetes config with apiVersion: v1, masterUrl: https://192.168.99.103:8443, caCertFile: /Users/attilazsoltpiros/.minikube/ca.crt, clientCertFile: /Users/attilazsoltpiros/.minikube/apiserver.crt, clientKeyFile: /Users/attilazsoltpiros/.minikube/apiserver.key
```

Minikube v1.15.1:
```
$ minikube version
minikube version: v1.15.1
commit: 23f40a012abb52eff365ff99a709501a61ac5876

$ grep "version\|building" resource-managers/kubernetes/integration-tests/target/integration-tests.log
20/12/13 06:25:55.086 ScalaTest-main-running-DiscoverySuite INFO Minikube: minikube version: v1.15.1
20/12/13 06:25:55.597 ScalaTest-main-running-DiscoverySuite INFO Minikube: building kubernetes config with apiVersion: v1, masterUrl: https://192.168.64.4:8443, caCertFile: /Users/attilazsoltpiros/.minikube/ca.crt, clientCertFile: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.crt, clientKeyFile: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.key

$ minikube kubectl config view
apiVersion: v1
clusters:
- cluster:
    certificate-authority: /Users/attilazsoltpiros/.minikube/ca.crt
    server: https://192.168.64.4:8443
  name: minikube
contexts:
- context:
    cluster: minikube
    namespace: default
    user: minikube
  name: minikube
current-context: minikube
kind: Config
preferences: {}
users:
- name: minikube
  user:
    client-certificate: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.crt
    client-key: /Users/attilazsoltpiros/.minikube/profiles/minikube/client.key
```

Closes #30751 from attilapiros/SPARK-32617.
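A minimal, standalone sketch of the two ideas described in this message is shown below: extracting the value that follows a "key: " prefix via split, and gating the kubectl-based configuration on a three-part minikube version. It is not the actual Minikube test helper; the names `valueAfter` and `supportsKubectl` are illustrative stand-ins.

```
// Illustrative sketch only, assuming each interesting line contains "prefix: value"
// and that some text follows the prefix.
object MinikubeParsingSketch extends App {
  private val HostPrefix = "host: "

  // split(prefix)(1) keeps exactly the text after the prefix, even if unrelated
  // output (e.g. a kubectl download message) precedes the matching line.
  def valueAfter(prefix: String, lines: Seq[String]): Option[String] =
    lines.find(_.contains(prefix)).map(_.split(prefix)(1))

  // True when the version is at least 1.1.0, i.e. `minikube kubectl` can be used to
  // read the kubeconfig; older versions fall back to the legacy hardcoded paths.
  def supportsKubectl(versionString: String): Boolean =
    "\\d+\\.\\d+\\.\\d+".r.findFirstIn(versionString)
      .map(_.split('.').map(_.toInt)) match {
      case Some(Array(x, y, z)) => Ordering.Tuple3[Int, Int, Int].gteq((x, y, z), (1, 1, 0))
      case _ => false
    }

  val statusLines = Seq("some noise", "host: Running", "kubelet: Running")
  println(valueAfter(HostPrefix, statusLines))          // Some(Running)
  println(supportsKubectl("minikube version: v1.15.1")) // true
  println(supportsKubectl("minikube version: v0.34.1")) // false
}
```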
Authored-by: “attilapiros” Signed-off-by: Holden Karau --- .../dev/dev-run-integration-tests.sh | 6 +- .../backend/minikube/Minikube.scala | 138 ++++++++++++------ 2 files changed, 101 insertions(+), 43 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index b72a4f74918ba..c87437e48f589 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -35,6 +35,7 @@ CONTEXT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= JAVA_VERSION="8" +BUILD_DEPENDENCIES_MVN_FLAG="-am" HADOOP_PROFILE="hadoop-3.2" MVN="$TEST_ROOT_DIR/build/mvn" @@ -117,6 +118,9 @@ while (( "$#" )); do HADOOP_PROFILE="$2" shift ;; + --skip-building-dependencies) + BUILD_DEPENDENCIES_MVN_FLAG="" + ;; *) echo "Unexpected command line flag $2 $1." exit 1 @@ -176,4 +180,4 @@ properties+=( -Dlog4j.logger.org.apache.spark=DEBUG ) -$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -P$HADOOP_PROFILE -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests $BUILD_DEPENDENCIES_MVN_FLAG -Pscala-$SCALA_VERSION -P$HADOOP_PROFILE -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala index c33875243c598..5cb068545ef37 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.deploy.k8s.integrationtest.backend.minikube -import java.nio.file.{Files, Paths} +import java.nio.file.Paths -import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} +import io.fabric8.kubernetes.client.{Config, ConfigBuilder, DefaultKubernetesClient} import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils import org.apache.spark.internal.Logging @@ -26,18 +26,26 @@ import org.apache.spark.internal.Logging // TODO support windows private[spark] object Minikube extends Logging { private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 - private val HOST_PREFIX = "host:" - private val KUBELET_PREFIX = "kubelet:" - private val APISERVER_PREFIX = "apiserver:" - private val KUBECTL_PREFIX = "kubectl:" - private val KUBECONFIG_PREFIX = "kubeconfig:" + private val VERSION_PREFIX = "minikube version: " + private val HOST_PREFIX = "host: " + private val KUBELET_PREFIX = "kubelet: " + private val APISERVER_PREFIX = "apiserver: " + private val KUBECTL_PREFIX = "kubectl: " + private val KUBECONFIG_PREFIX = "kubeconfig: " private val MINIKUBE_VM_PREFIX = "minikubeVM: " private val MINIKUBE_PREFIX = "minikube: " private val MINIKUBE_PATH = ".minikube" + private val APIVERSION_PREFIX = "apiVersion: " + private val SERVER_PREFIX = "server: " + private val CA_PREFIX = "certificate-authority: " + private val CLIENTCERT_PREFIX = 
"client-certificate: " + private val CLIENTKEY_PREFIX = "client-key: " - def logVersion(): Unit = { - logInfo(executeMinikube("version").mkString("\n")) - } + lazy val minikubeVersionString = + executeMinikube("version").find(_.contains(VERSION_PREFIX)).get + + def logVersion(): Unit = + logInfo(minikubeVersionString) def getMinikubeIp: String = { val outputs = executeMinikube("ip") @@ -56,60 +64,106 @@ private[spark] object Minikube extends Logging { if (oldMinikube.isEmpty) { getIfNewMinikubeStatus(statusString) } else { - val finalStatusString = oldMinikube - .head - .replaceFirst(MINIKUBE_VM_PREFIX, "") - .replaceFirst(MINIKUBE_PREFIX, "") + val statusLine = oldMinikube.head + val finalStatusString = if (statusLine.contains(MINIKUBE_VM_PREFIX)) { + statusLine.split(MINIKUBE_VM_PREFIX)(1) + } else { + statusLine.split(MINIKUBE_PREFIX)(1) + } MinikubeStatus.unapply(finalStatusString) .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) } } def getKubernetesClient: DefaultKubernetesClient = { + // only the three-part version number is matched (the optional suffix like "-beta.0" is dropped) + val versionArrayOpt = "\\d+\\.\\d+\\.\\d+".r + .findFirstIn(minikubeVersionString.split(VERSION_PREFIX)(1)) + .map(_.split('.').map(_.toInt)) + + assert(versionArrayOpt.isDefined && versionArrayOpt.get.size == 3, + s"Unexpected version format detected in `$minikubeVersionString`." + + "For minikube version a three-part version number is expected (the optional non-numeric " + + "suffix is intentionally dropped)") + + val kubernetesConf = versionArrayOpt.get match { + case Array(x, y, z) => + // comparing the versions as the kubectl command is only introduced in version v1.1.0: + // https://github.com/kubernetes/minikube/blob/v1.1.0/CHANGELOG.md + if (Ordering.Tuple3[Int, Int, Int].gteq((x, y, z), (1, 1, 0))) { + kubectlBasedKubernetesClientConf + } else { + legacyKubernetesClientConf + } + } + new DefaultKubernetesClient(kubernetesConf) + } + + private def legacyKubernetesClientConf: Config = { val kubernetesMaster = s"https://${getMinikubeIp}:8443" val userHome = System.getProperty("user.home") - val minikubeBasePath = Paths.get(userHome, MINIKUBE_PATH).toString - val profileDir = if (Files.exists(Paths.get(minikubeBasePath, "apiserver.crt"))) { - // For Minikube <1.9 - "" - } else { - // For Minikube >=1.9 - Paths.get("profiles", executeMinikube("profile")(0)).toString - } - val apiServerCertPath = Paths.get(minikubeBasePath, profileDir, "apiserver.crt") - val apiServerKeyPath = Paths.get(minikubeBasePath, profileDir, "apiserver.key") - val kubernetesConf = new ConfigBuilder() - .withApiVersion("v1") - .withMasterUrl(kubernetesMaster) - .withCaCertFile( - Paths.get(userHome, MINIKUBE_PATH, "ca.crt").toFile.getAbsolutePath) - .withClientCertFile(apiServerCertPath.toFile.getAbsolutePath) - .withClientKeyFile(apiServerKeyPath.toFile.getAbsolutePath) + buildKubernetesClientConf( + "v1", + kubernetesMaster, + Paths.get(userHome, MINIKUBE_PATH, "ca.crt").toFile.getAbsolutePath, + Paths.get(userHome, MINIKUBE_PATH, "apiserver.crt").toFile.getAbsolutePath, + Paths.get(userHome, MINIKUBE_PATH, "apiserver.key").toFile.getAbsolutePath) + } + + private def kubectlBasedKubernetesClientConf: Config = { + val outputs = executeMinikube("kubectl config view") + val apiVersionString = outputs.find(_.contains(APIVERSION_PREFIX)) + val serverString = outputs.find(_.contains(SERVER_PREFIX)) + val caString = outputs.find(_.contains(CA_PREFIX)) + val clientCertString = 
outputs.find(_.contains(CLIENTCERT_PREFIX)) + val clientKeyString = outputs.find(_.contains(CLIENTKEY_PREFIX)) + + assert(!apiVersionString.isEmpty && !serverString.isEmpty && !caString.isEmpty && + !clientKeyString.isEmpty && !clientKeyString.isEmpty, + "The output of 'minikube kubectl config view' does not contain all the neccesary attributes") + + buildKubernetesClientConf( + apiVersionString.get.split(APIVERSION_PREFIX)(1), + serverString.get.split(SERVER_PREFIX)(1), + caString.get.split(CA_PREFIX)(1), + clientCertString.get.split(CLIENTCERT_PREFIX)(1), + clientKeyString.get.split(CLIENTKEY_PREFIX)(1)) + } + + private def buildKubernetesClientConf(apiVersion: String, masterUrl: String, caCertFile: String, + clientCertFile: String, clientKeyFile: String): Config = { + logInfo(s"building kubernetes config with apiVersion: $apiVersion, masterUrl: $masterUrl, " + + s"caCertFile: $caCertFile, clientCertFile: $clientCertFile, clientKeyFile: $clientKeyFile") + new ConfigBuilder() + .withApiVersion(apiVersion) + .withMasterUrl(masterUrl) + .withCaCertFile(caCertFile) + .withClientCertFile(clientCertFile) + .withClientKeyFile(clientKeyFile) .build() - new DefaultKubernetesClient(kubernetesConf) } // Covers minikube status output after Minikube V0.30. private def getIfNewMinikubeStatus(statusString: Seq[String]): MinikubeStatus.Value = { - val hostString = statusString.find(_.contains(s"$HOST_PREFIX ")) - val kubeletString = statusString.find(_.contains(s"$KUBELET_PREFIX ")) - val apiserverString = statusString.find(_.contains(s"$APISERVER_PREFIX ")) - val kubectlString = statusString.find(_.contains(s"$KUBECTL_PREFIX ")) - val kubeconfigString = statusString.find(_.contains(s"$KUBECONFIG_PREFIX ")) + val hostString = statusString.find(_.contains(HOST_PREFIX)) + val kubeletString = statusString.find(_.contains(KUBELET_PREFIX)) + val apiserverString = statusString.find(_.contains(APISERVER_PREFIX)) + val kubectlString = statusString.find(_.contains(KUBECTL_PREFIX)) + val kubeconfigString = statusString.find(_.contains(KUBECONFIG_PREFIX)) val hasConfigStatus = kubectlString.isDefined || kubeconfigString.isDefined if (hostString.isEmpty || kubeletString.isEmpty || apiserverString.isEmpty || !hasConfigStatus) { MinikubeStatus.NONE } else { - val status1 = hostString.get.replaceFirst(s"$HOST_PREFIX ", "") - val status2 = kubeletString.get.replaceFirst(s"$KUBELET_PREFIX ", "") - val status3 = apiserverString.get.replaceFirst(s"$APISERVER_PREFIX ", "") + val status1 = hostString.get.split(HOST_PREFIX)(1) + val status2 = kubeletString.get.split(KUBELET_PREFIX)(1) + val status3 = apiserverString.get.split(APISERVER_PREFIX)(1) val isConfigured = if (kubectlString.isDefined) { - val cfgStatus = kubectlString.get.replaceFirst(s"$KUBECTL_PREFIX ", "") + val cfgStatus = kubectlString.get.split(KUBECTL_PREFIX)(1) cfgStatus.contains("Correctly Configured:") } else { - kubeconfigString.get.replaceFirst(s"$KUBECONFIG_PREFIX ", "") == "Configured" + kubeconfigString.get.split(KUBECONFIG_PREFIX)(1) == "Configured" } if (isConfigured) { val stats = List(status1, status2, status3) From 44eadb943bbcec48e90398731f57a32a967d81bb Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 25 Feb 2021 09:25:17 +0900 Subject: [PATCH 23/60] [SPARK-34497][SQL] Fix built-in JDBC connection providers to restore JVM security context changes ### What changes were proposed in this pull request? Some of the built-in JDBC connection providers are changing the JVM security context to do the authentication which is fine. 
The problematic part is that executors can be reused by another query. The following situation leads to incorrect behaviour:

* Query1 opens a JDBC connection and changes the JVM security context in Executor1
* Query2 tries to open a JDBC connection but realizes there is already an entry for that DB type in Executor1
* Query2 does not change the JVM security context and uses Query1's keytab and principal
* Query2 fails with an authentication error

In this PR I've changed the code in such a way that the JVM security context is still changed every time, but only temporarily, until the connection is built up, and then it is rolled back. Since `getConnection` is synchronised on `SecurityConfigurationLock`, this ends up with correct behaviour without any race.

### Why are the changes needed?

Incorrect JVM security context handling.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit + integration tests.

Closes #31622 from gaborgsomogyi/SPARK-34497.

Authored-by: Gabor Somogyi Signed-off-by: HyukjinKwon --- .../jdbc/connection/ConnectionProvider.scala | 11 ++++++++- .../connection/DB2ConnectionProvider.scala | 9 +------ .../connection/MSSQLConnectionProvider.scala | 16 +------------ .../MariaDBConnectionProvider.scala | 17 +------------ .../connection/OracleConnectionProvider.scala | 9 +------ .../PostgresConnectionProvider.scala | 7 ------ .../connection/SecureConnectionProvider.scala | 20 ++-------------- .../connection/ConnectionProviderSuite.scala | 24 ++++++++++++------- .../ConnectionProviderSuiteBase.scala | 8 ++----- .../DB2ConnectionProviderSuite.scala | 2 +- .../MSSQLConnectionProviderSuite.scala | 4 ++-- .../MariaDBConnectionProviderSuite.scala | 2 +- .../OracleConnectionProviderSuite.scala | 2 +- .../PostgresConnectionProviderSuite.scala | 2 +- 14 files changed, 39 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index e81add4df960a..fbc69704f1479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection import java.sql.{Connection, Driver} import java.util.ServiceLoader +import javax.security.auth.login.Configuration import scala.collection.mutable @@ -60,7 +61,15 @@ private[jdbc] object ConnectionProvider extends Logging { "JDBC connection initiated but not exactly one connection provider found which can handle " + s"it.
Found active providers: ${filteredProviders.mkString(", ")}") SecurityConfigurationLock.synchronized { - filteredProviders.head.getConnection(driver, options) + // Inside getConnection it's safe to get parent again because SecurityConfigurationLock + // makes sure it's untouched + val parent = Configuration.getConfiguration + try { + filteredProviders.head.getConnection(driver, options) + } finally { + logDebug("Restoring original security configuration") + Configuration.setConfiguration(parent) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala index 775c3ae4a533a..060653c5a8b79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala @@ -34,7 +34,7 @@ private[sql] class DB2ConnectionProvider extends SecureConnectionProvider { override def getConnection(driver: Driver, options: Map[String, String]): Connection = { val jdbcOptions = new JDBCOptions(options) - setAuthenticationConfigIfNeeded(driver, jdbcOptions) + setAuthenticationConfig(driver, jdbcOptions) UserGroupInformation.loginUserFromKeytabAndReturnUGI(jdbcOptions.principal, jdbcOptions.keytab) .doAs( new PrivilegedExceptionAction[Connection]() { @@ -52,11 +52,4 @@ private[sql] class DB2ConnectionProvider extends SecureConnectionProvider { result.put("KerberosServerPrincipal", options.principal) result } - - override def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit = { - val (parent, configEntry) = getConfigWithAppEntry(driver, options) - if (configEntry == null || configEntry.isEmpty) { - setAuthenticationConfig(parent, driver, options) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProvider.scala index e3d3e1a43d510..aa8c9227377c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProvider.scala @@ -61,7 +61,7 @@ private[sql] class MSSQLConnectionProvider extends SecureConnectionProvider { override def getConnection(driver: Driver, options: Map[String, String]): Connection = { val jdbcOptions = new JDBCOptions(options) - setAuthenticationConfigIfNeeded(driver, jdbcOptions) + setAuthenticationConfig(driver, jdbcOptions) UserGroupInformation.loginUserFromKeytabAndReturnUGI(jdbcOptions.principal, jdbcOptions.keytab) .doAs( new PrivilegedExceptionAction[Connection]() { @@ -79,18 +79,4 @@ private[sql] class MSSQLConnectionProvider extends SecureConnectionProvider { result.put("authenticationScheme", "JavaKerberos") result } - - override def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit = { - val (parent, configEntry) = getConfigWithAppEntry(driver, options) - /** - * Couple of things to mention here (v8.2.2 client): - * 1. MS SQL supports JAAS application name configuration - * 2. 
MS SQL sets a default JAAS config if "java.security.auth.login.config" is not set - */ - val entryUsesKeytab = configEntry != null && - configEntry.exists(_.getOptions().get("useKeyTab") == "true") - if (configEntry == null || configEntry.isEmpty || !entryUsesKeytab) { - setAuthenticationConfig(parent, driver, options) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala index 29a08d0b5f269..6a53c663a2773 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala @@ -26,20 +26,5 @@ private[jdbc] class MariaDBConnectionProvider extends SecureConnectionProvider { override val name: String = "mariadb" - override def appEntry(driver: Driver, options: JDBCOptions): String = - "Krb5ConnectorContext" - - override def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit = { - val (parent, configEntry) = getConfigWithAppEntry(driver, options) - /** - * Couple of things to mention here (v2.5.4 client): - * 1. MariaDB doesn't support JAAS application name configuration - * 2. MariaDB sets a default JAAS config if "java.security.auth.login.config" is not set - */ - val entryUsesKeytab = configEntry != null && - configEntry.exists(_.getOptions().get("useKeyTab") == "true") - if (configEntry == null || configEntry.isEmpty || !entryUsesKeytab) { - setAuthenticationConfig(parent, driver, options) - } - } + override def appEntry(driver: Driver, options: JDBCOptions): String = "Krb5ConnectorContext" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala index 0d43851bb255e..ef8d91b5aa8f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala @@ -34,7 +34,7 @@ private[sql] class OracleConnectionProvider extends SecureConnectionProvider { override def getConnection(driver: Driver, options: Map[String, String]): Connection = { val jdbcOptions = new JDBCOptions(options) - setAuthenticationConfigIfNeeded(driver, jdbcOptions) + setAuthenticationConfig(driver, jdbcOptions) UserGroupInformation.loginUserFromKeytabAndReturnUGI(jdbcOptions.principal, jdbcOptions.keytab) .doAs( new PrivilegedExceptionAction[Connection]() { @@ -53,11 +53,4 @@ private[sql] class OracleConnectionProvider extends SecureConnectionProvider { result.put("oracle.net.authentication_services", "(KERBEROS5)"); result } - - override def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit = { - val (parent, configEntry) = getConfigWithAppEntry(driver, options) - if (configEntry == null || configEntry.isEmpty) { - setAuthenticationConfig(parent, driver, options) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala index 
f26a11e34dc38..ec9ac66147e99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala @@ -32,11 +32,4 @@ private[jdbc] class PostgresConnectionProvider extends SecureConnectionProvider val properties = parseURL.invoke(driver, options.url, null).asInstanceOf[Properties] properties.getProperty("jaasApplicationName", "pgjdbc") } - - override def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit = { - val (parent, configEntry) = getConfigWithAppEntry(driver, options) - if (configEntry == null || configEntry.isEmpty) { - setAuthenticationConfig(parent, driver, options) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala index 80c795957dac8..4138c7216970f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala @@ -40,7 +40,7 @@ private[jdbc] abstract class SecureConnectionProvider extends BasicConnectionPro override def getConnection(driver: Driver, options: Map[String, String]): Connection = { val jdbcOptions = new JDBCOptions(options) - setAuthenticationConfigIfNeeded(driver, jdbcOptions) + setAuthenticationConfig(driver, jdbcOptions) super.getConnection(driver: Driver, options: Map[String, String]) } @@ -49,24 +49,8 @@ private[jdbc] abstract class SecureConnectionProvider extends BasicConnectionPro */ def appEntry(driver: Driver, options: JDBCOptions): String - /** - * Sets database specific authentication configuration when needed. If configuration already set - * then later calls must be no op. When the global JVM security configuration changed then the - * related code parts must be synchronized properly. 
- */ - def setAuthenticationConfigIfNeeded(driver: Driver, options: JDBCOptions): Unit - - protected def getConfigWithAppEntry( - driver: Driver, - options: JDBCOptions): (Configuration, Array[AppConfigurationEntry]) = { + private[connection] def setAuthenticationConfig(driver: Driver, options: JDBCOptions) = { val parent = Configuration.getConfiguration - (parent, parent.getAppConfigurationEntry(appEntry(driver, options))) - } - - protected def setAuthenticationConfig( - parent: Configuration, - driver: Driver, - options: JDBCOptions) = { val config = new SecureConnectionProvider.JDBCConfiguration( parent, appEntry(driver, options), options.keytab, options.principal) logDebug("Adding database specific security configuration") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala index 0e9498b2681e2..71b0325f93732 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala @@ -57,17 +57,23 @@ class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSpa val db2AppEntry = db2Provider.appEntry(db2Driver, db2Options) // Make sure no authentication for the databases are set - val oldConfig = Configuration.getConfiguration - assert(oldConfig.getAppConfigurationEntry(postgresAppEntry) == null) - assert(oldConfig.getAppConfigurationEntry(db2AppEntry) == null) + val rootConfig = Configuration.getConfiguration + assert(rootConfig.getAppConfigurationEntry(postgresAppEntry) == null) + assert(rootConfig.getAppConfigurationEntry(db2AppEntry) == null) - postgresProvider.setAuthenticationConfigIfNeeded(postgresDriver, postgresOptions) - db2Provider.setAuthenticationConfigIfNeeded(db2Driver, db2Options) + postgresProvider.setAuthenticationConfig(postgresDriver, postgresOptions) + val postgresConfig = Configuration.getConfiguration + + db2Provider.setAuthenticationConfig(db2Driver, db2Options) + val db2Config = Configuration.getConfiguration // Make sure authentication for the databases are set - val newConfig = Configuration.getConfiguration - assert(oldConfig != newConfig) - assert(newConfig.getAppConfigurationEntry(postgresAppEntry) != null) - assert(newConfig.getAppConfigurationEntry(db2AppEntry) != null) + assert(rootConfig != postgresConfig) + assert(rootConfig != db2Config) + // The topmost config in the chain is linked with all the subsequent entries + assert(db2Config.getAppConfigurationEntry(postgresAppEntry) != null) + assert(db2Config.getAppConfigurationEntry(db2AppEntry) != null) + + Configuration.setConfiguration(null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala index a299841b3c149..f42b17abf31bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala @@ -59,16 +59,12 @@ abstract class ConnectionProviderSuiteBase extends SparkFunSuite with BeforeAndA // Make sure no authentication for the database is set 
assert(Configuration.getConfiguration.getAppConfigurationEntry(providerAppEntry) == null) - // Make sure the first call sets authentication properly + // Make sure setAuthenticationConfig call sets authentication properly val savedConfig = Configuration.getConfiguration - provider.setAuthenticationConfigIfNeeded(driver, options) + provider.setAuthenticationConfig(driver, options) val config = Configuration.getConfiguration assert(savedConfig != config) val appEntry = config.getAppConfigurationEntry(providerAppEntry) assert(appEntry != null) - - // Make sure a second call is not modifying the existing authentication - provider.setAuthenticationConfigIfNeeded(driver, options) - assert(config.getAppConfigurationEntry(providerAppEntry) === appEntry) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProviderSuite.scala index 5885af82532d4..895b3d85d960b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProviderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection class DB2ConnectionProviderSuite extends ConnectionProviderSuiteBase { - test("setAuthenticationConfigIfNeeded must set authentication if not set") { + test("setAuthenticationConfig must set authentication all the time") { val provider = new DB2ConnectionProvider() val driver = registerDriver(provider.driverClass) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProviderSuite.scala index a5704e842e018..a0b9af2d82e13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MSSQLConnectionProviderSuite.scala @@ -22,7 +22,7 @@ import java.sql.Driver import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions class MSSQLConnectionProviderSuite extends ConnectionProviderSuiteBase { - test("setAuthenticationConfigIfNeeded default parser must set authentication if not set") { + test("setAuthenticationConfig default parser must set authentication all the time") { val provider = new MSSQLConnectionProvider() val driver = registerDriver(provider.driverClass) @@ -30,7 +30,7 @@ class MSSQLConnectionProviderSuite extends ConnectionProviderSuiteBase { options("jdbc:sqlserver://localhost/mssql;jaasConfigurationName=custommssql")) } - test("setAuthenticationConfigIfNeeded custom parser must set authentication if not set") { + test("setAuthenticationConfig custom parser must set authentication all the time") { val provider = new MSSQLConnectionProvider() { override val parserMethod: String = "IntentionallyNotExistingMethod" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala index f450662fcbe74..d8bdf26b35c7d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection class MariaDBConnectionProviderSuite extends ConnectionProviderSuiteBase { - test("setAuthenticationConfigIfNeeded must set authentication if not set") { + test("setAuthenticationConfig must set authentication all the time") { val provider = new MariaDBConnectionProvider() val driver = registerDriver(provider.driverClass) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala index 40e7f1191dccc..4aaaf8168eb53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection class OracleConnectionProviderSuite extends ConnectionProviderSuiteBase { - test("setAuthenticationConfigIfNeeded must set authentication if not set") { + test("setAuthenticationConfig must set authentication all the time") { val provider = new OracleConnectionProvider() val driver = registerDriver(provider.driverClass) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala index ee43a7d9708c5..5006bf4091380 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection class PostgresConnectionProviderSuite extends ConnectionProviderSuiteBase { - test("setAuthenticationConfigIfNeeded must set authentication if not set") { + test("setAuthenticationConfig must set authentication all the time") { val provider = new PostgresConnectionProvider() val defaultOptions = options("jdbc:postgresql://localhost/postgres") val customOptions =

From 22383e312d00ec8888cdc2d12750b7d1e7e21d99 Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Wed, 24 Feb 2021 18:11:25 -0800
Subject: [PATCH 24/60] [SPARK-34531][CORE] Remove Experimental API tag in PrometheusServlet

### What changes were proposed in this pull request?

The endpoints of the Prometheus metrics are properly marked and documented as experimental (SPARK-31674). The class `PrometheusServlet` itself is not part of the API, so this PR proposes to remove the tag from it.

### Why are the changes needed?

To avoid marking a non-API as an API.

### Does this PR introduce _any_ user-facing change?

No, the class is already `private[spark]`.

### How was this patch tested?

Existing tests should cover this.

Closes #31640 from HyukjinKwon/SPARK-34531.
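The Prometheus executor-metrics endpoint itself stays opt-in, as the `docs/monitoring.md` change below spells out. A minimal sketch of turning it on from application code, with the config key and endpoint path taken from that doc; the app name and local master are made up for illustration:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local application that enables the conditional Prometheus endpoint.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("prometheus-endpoint-demo")
  .config("spark.ui.prometheus.enabled", "true") // the default is false
  .getOrCreate()

// Executor metrics in Prometheus format are then served by the driver UI under
// /metrics/executors/prometheus (port 4040 by default for a local application).
```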
Lead-authored-by: HyukjinKwon Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/metrics/sink/PrometheusServlet.scala | 3 --- docs/monitoring.md | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala index 0f8fbd3ba2e9e..7cc2665ee7eee 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala @@ -24,18 +24,15 @@ import com.codahale.metrics.MetricRegistry import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.annotation.Experimental import org.apache.spark.ui.JettyUtils._ /** - * :: Experimental :: * This exposes the metrics of the given registry with Prometheus format. * * The output is consistent with /metrics/json result in terms of item ordering * and with the previous result of Spark JMX Sink + Prometheus JMX Converter combination * in terms of key string format. */ -@Experimental private[spark] class PrometheusServlet( val property: Properties, val registry: MetricRegistry, diff --git a/docs/monitoring.md b/docs/monitoring.md index 5b3278bca031d..930f91f9a5c2f 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -755,7 +755,7 @@ A list of the available metrics, with a short description: Executor-level metrics are sent from each executor to the driver as part of the Heartbeat to describe the performance metrics of Executor itself like JVM heap memory, GC information. Executor metric values and their measured memory peak values per executor are exposed via the REST API in JSON format and in Prometheus format. The JSON end point is exposed at: `/applications/[app-id]/executors`, and the Prometheus endpoint at: `/metrics/executors/prometheus`. -The Prometheus endpoint is experimental and conditional to a configuration parameter: `spark.ui.prometheus.enabled=true` (the default is `false`). +The Prometheus endpoint is conditional to a configuration parameter: `spark.ui.prometheus.enabled=true` (the default is `false`). In addition, aggregated per-stage peak values of the executor memory metrics are written to the event log if `spark.eventLog.logStageExecutorMetrics` is true. Executor memory metrics are also exposed via the Spark metrics system based on the [Dropwizard metrics library](http://metrics.dropwizard.io/4.1.1).

From 8a1e172b513ba58763336de83f94e00ceaa69255 Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Wed, 24 Feb 2021 20:38:03 -0800
Subject: [PATCH 25/60] [SPARK-34520][CORE] Remove unused SecurityManager references

### What changes were proposed in this pull request?

This is kind of a follow-up to https://github.com/apache/spark/pull/24033 and https://github.com/apache/spark/pull/30945. Many of the `SecurityManager` references were introduced in SPARK-1189, and the related usages were removed later in https://github.com/apache/spark/pull/24033 and https://github.com/apache/spark/pull/30945. This PR proposes to remove the now-unused references.

### Why are the changes needed?

For better code readability.

### Does this PR introduce _any_ user-facing change?

No, dev-only.

### How was this patch tested?

Manually compiled. GitHub Actions and Jenkins builds should test it out as well.

Closes #31636 from HyukjinKwon/SPARK-34520.
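To make the simplified constructor shape concrete, here is a minimal sketch of a metrics sink after this change; the `NoOpSink` class is hypothetical, and since the `Sink` trait is `private[spark]` such a class only compiles inside Spark's own source tree:

```scala
package org.apache.spark.metrics.sink

import java.util.Properties

import com.codahale.metrics.MetricRegistry

// Hypothetical no-op sink, for illustration only. After this change the
// MetricsSystem instantiates sinks reflectively through a
// (Properties, MetricRegistry) constructor, so no SecurityManager is passed in.
private[spark] class NoOpSink(
    val property: Properties, val registry: MetricRegistry) extends Sink {
  override def start(): Unit = { }
  override def stop(): Unit = { }
  override def report(): Unit = { }
}
```

Such a sink would be wired in through the metrics properties file, e.g. `*.sink.noop.class=org.apache.spark.metrics.sink.NoOpSink`, the same way the built-in sinks touched by this patch are configured.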
Authored-by: HyukjinKwon Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/SparkEnv.scala | 7 +++-- .../spark/broadcast/BroadcastFactory.scala | 3 +-- .../spark/broadcast/BroadcastManager.scala | 9 +++---- .../broadcast/TorrentBroadcastFactory.scala | 5 ++-- .../spark/deploy/ExternalShuffleService.scala | 3 +-- .../apache/spark/deploy/master/Master.scala | 4 +-- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/metrics/MetricsSystem.scala | 27 +++++++------------ .../spark/metrics/sink/ConsoleSink.scala | 5 ++-- .../apache/spark/metrics/sink/CsvSink.scala | 5 ++-- .../spark/metrics/sink/GraphiteSink.scala | 5 ++-- .../apache/spark/metrics/sink/JmxSink.scala | 5 ++-- .../spark/metrics/sink/MetricsServlet.scala | 7 ++--- .../metrics/sink/PrometheusServlet.scala | 7 ++--- .../apache/spark/metrics/sink/Slf4jSink.scala | 6 +---- .../spark/metrics/sink/StatsdSink.scala | 6 +---- .../apache/spark/MapOutputTrackerSuite.scala | 3 +-- .../spark/metrics/MetricsSystemSuite.scala | 26 +++++++++--------- .../metrics/sink/GraphiteSinkSuite.scala | 8 +++--- .../metrics/sink/PrometheusServletSuite.scala | 2 +- .../spark/metrics/sink/StatsdSinkSuite.scala | 5 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../cluster/mesos/MesosClusterScheduler.scala | 5 ++-- .../spark/deploy/yarn/ApplicationMaster.scala | 3 +-- .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- 27 files changed, 64 insertions(+), 102 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9fc60ac3990fc..ed8dc43b16c96 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -314,7 +314,7 @@ object SparkEnv extends Logging { } } - val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + val broadcastManager = new BroadcastManager(isDriver, conf) val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf, broadcastManager, isLocal) @@ -397,14 +397,13 @@ object SparkEnv extends Logging { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. // Then we can start the metrics system. - MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf, securityManager) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf) } else { // We need to set the executor ID before the MetricsSystem is created because sources and // sinks specified in the metrics configuration file will want to incorporate this executor's // ID into the metrics they report. 
conf.set(EXECUTOR_ID, executorId) - val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf, - securityManager) + val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf) ms.start(conf.get(METRICS_STATIC_SOURCES_ENABLED)) ms } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index ece4ae6ab0310..9891582501b8b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -19,7 +19,6 @@ package org.apache.spark.broadcast import scala.reflect.ClassTag -import org.apache.spark.SecurityManager import org.apache.spark.SparkConf /** @@ -29,7 +28,7 @@ import org.apache.spark.SparkConf */ private[spark] trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf): Unit /** * Creates a new broadcast variable. diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c93cadf1ab3e8..989a1941d1791 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -24,15 +24,12 @@ import scala.reflect.ClassTag import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.internal.Logging private[spark] class BroadcastManager( - val isDriver: Boolean, - conf: SparkConf, - securityManager: SecurityManager) - extends Logging { + val isDriver: Boolean, conf: SparkConf) extends Logging { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -44,7 +41,7 @@ private[spark] class BroadcastManager( synchronized { if (!initialized) { broadcastFactory = new TorrentBroadcastFactory - broadcastFactory.initialize(isDriver, conf, securityManager) + broadcastFactory.initialize(isDriver, conf) initialized = true } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 65fb5186afae1..6846e1967c4d6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -19,7 +19,7 @@ package org.apache.spark.broadcast import scala.reflect.ClassTag -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf /** * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like @@ -28,8 +28,7 @@ import org.apache.spark.{SecurityManager, SparkConf} */ private[spark] class TorrentBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, - securityMgr: SecurityManager): Unit = { } + override def initialize(isDriver: Boolean, conf: SparkConf): Unit = { } override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 
ebfff89308886..eff1e15659fc4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -44,8 +44,7 @@ private[deploy] class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { protected val masterMetricsSystem = - MetricsSystem.createMetricsSystem(MetricsSystemInstances.SHUFFLE_SERVICE, - sparkConf, securityManager) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.SHUFFLE_SERVICE, sparkConf) private val enabled = sparkConf.get(config.SHUFFLE_SERVICE_ENABLED) private val port = sparkConf.get(config.SHUFFLE_SERVICE_PORT) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 471a3c1b45c39..c964e343ca6c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -87,9 +87,9 @@ private[deploy] class Master( Utils.checkHost(address.host) private val masterMetricsSystem = - MetricsSystem.createMetricsSystem(MetricsSystemInstances.MASTER, conf, securityMgr) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.MASTER, conf) private val applicationMetricsSystem = - MetricsSystem.createMetricsSystem(MetricsSystemInstances.APPLICATIONS, conf, securityMgr) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.APPLICATIONS, conf) private val masterSource = new MasterSource(this) // After onStart, webUi will be set diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index adc953286625a..05e8e5a6b6766 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -190,7 +190,7 @@ private[deploy] class Worker( private var connectionAttemptCount = 0 private val metricsSystem = - MetricsSystem.createMetricsSystem(MetricsSystemInstances.WORKER, conf, securityMgr) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.WORKER, conf) private val workerSource = new WorkerSource(this) val reverseProxy = conf.get(UI_REVERSE_PROXY) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 48f816f649d36..b0c424bdc3f99 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import com.codahale.metrics.{Metric, MetricRegistry} import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.metrics.sink.{MetricsServlet, PrometheusServlet, Sink} @@ -68,10 +68,7 @@ import org.apache.spark.util.Utils * [options] represent the specific property of this source or sink. 
*/ private[spark] class MetricsSystem private ( - val instance: String, - conf: SparkConf, - securityMgr: SecurityManager) - extends Logging { + val instance: String, conf: SparkConf) extends Logging { private[this] val metricsConfig = new MetricsConfig(conf) @@ -200,21 +197,18 @@ private[spark] class MetricsSystem private ( try { if (kv._1 == "servlet") { val servlet = Utils.classForName[MetricsServlet](classPath) - .getConstructor( - classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) - .newInstance(kv._2, registry, securityMgr) + .getConstructor(classOf[Properties], classOf[MetricRegistry]) + .newInstance(kv._2, registry) metricsServlet = Some(servlet) } else if (kv._1 == "prometheusServlet") { val servlet = Utils.classForName[PrometheusServlet](classPath) - .getConstructor( - classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) - .newInstance(kv._2, registry, securityMgr) + .getConstructor(classOf[Properties], classOf[MetricRegistry]) + .newInstance(kv._2, registry) prometheusServlet = Some(servlet) } else { val sink = Utils.classForName[Sink](classPath) - .getConstructor( - classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) - .newInstance(kv._2, registry, securityMgr) + .getConstructor(classOf[Properties], classOf[MetricRegistry]) + .newInstance(kv._2, registry) sinks += sink } } catch { @@ -242,9 +236,8 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem( - instance: String, conf: SparkConf, securityMgr: SecurityManager): MetricsSystem = { - new MetricsSystem(instance, conf, securityMgr) + def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = { + new MetricsSystem(instance, conf) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index bfd23168e4003..c8a3e4488a019 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -22,11 +22,10 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} -import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -private[spark] class ConsoleSink(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class ConsoleSink( + val property: Properties, val registry: MetricRegistry) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 579b8e0c0e984..101691f640029 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -23,11 +23,10 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{CsvReporter, MetricRegistry} -import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -private[spark] class CsvSink(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class CsvSink( + val property: Properties, val registry: MetricRegistry) extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" val CSV_KEY_DIR = "directory" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala 
b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 6ce64cd3543fe..1c59e191db531 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -23,11 +23,10 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP} -import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -private[spark] class GraphiteSink(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class GraphiteSink( + val property: Properties, val registry: MetricRegistry) extends Sink { val GRAPHITE_DEFAULT_PERIOD = 10 val GRAPHITE_DEFAULT_UNIT = "SECONDS" val GRAPHITE_DEFAULT_PREFIX = "" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index a7b7b5573cfe8..7ca581aee6ba6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -22,10 +22,9 @@ import java.util.Properties import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jmx.JmxReporter -import org.apache.spark.SecurityManager -private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class JmxSink( + val property: Properties, val registry: MetricRegistry) extends Sink { val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 7dd27d4fb9bf3..46d2c6821fea1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -26,14 +26,11 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( - val property: Properties, - val registry: MetricRegistry, - securityMgr: SecurityManager) - extends Sink { + val property: Properties, val registry: MetricRegistry) extends Sink { val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala index 7cc2665ee7eee..c087ee7c000c3 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.ui.JettyUtils._ /** @@ -34,10 +34,7 @@ import org.apache.spark.ui.JettyUtils._ * in terms of key string format. 
*/ private[spark] class PrometheusServlet( - val property: Properties, - val registry: MetricRegistry, - securityMgr: SecurityManager) - extends Sink { + val property: Properties, val registry: MetricRegistry) extends Sink { val SERVLET_KEY_PATH = "path" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 968d5ca809e72..728687f8f78ba 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -22,14 +22,10 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} -import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, - val registry: MetricRegistry, - securityMgr: SecurityManager) - extends Sink { + val property: Properties, val registry: MetricRegistry) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala index 61e74e05169cc..c6e7bcccd4ce9 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala @@ -22,7 +22,6 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import org.apache.spark.SecurityManager import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem @@ -41,10 +40,7 @@ private[spark] object StatsdSink { } private[spark] class StatsdSink( - val property: Properties, - val registry: MetricRegistry, - securityMgr: SecurityManager) - extends Sink with Logging { + val property: Properties, val registry: MetricRegistry) extends Sink with Logging { import StatsdSink._ val host = property.getProperty(STATSD_KEY_HOST, STATSD_DEFAULT_HOST) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b5b68f639ffc9..20b040f7c810d 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -35,8 +35,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf private def newTrackerMaster(sparkConf: SparkConf = conf) = { - val broadcastManager = new BroadcastManager(true, sparkConf, - new SecurityManager(sparkConf)) + val broadcastManager = new BroadcastManager(true, sparkConf) new MapOutputTrackerMaster(sparkConf, broadcastManager, true) } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 70b6c9a112142..31d8492510f06 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -40,7 +40,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM } test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]](Symbol("sources")) val sinks = PrivateMethod[ArrayBuffer[Sink]](Symbol("sinks")) @@ -51,7 
+51,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]](Symbol("sources")) val sinks = PrivateMethod[ArrayBuffer[Sink]](Symbol("sinks")) @@ -77,7 +77,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = MetricsSystemInstances.DRIVER - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = driverMetricsSystem.buildRegistryName(source) assert(metricName === s"$appId.$executorId.${source.sourceName}") @@ -93,7 +93,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = MetricsSystemInstances.DRIVER - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = driverMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) @@ -109,7 +109,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.app.id", appId) val instanceName = MetricsSystemInstances.DRIVER - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = driverMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) @@ -127,7 +127,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appId.$executorId.${source.sourceName}") @@ -143,7 +143,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) @@ -159,7 +159,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.app.id", appId) val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) @@ -177,7 +177,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val testMetricsSystem = 
MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = testMetricsSystem.buildRegistryName(source) @@ -201,7 +201,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appName.$executorId.${source.sourceName}") @@ -219,7 +219,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, namespaceToResolve) val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) // If the user set the spark.metrics.namespace property to an expansion of another property @@ -239,7 +239,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = MetricsSystemInstances.EXECUTOR - val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) @@ -260,7 +260,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf) val metricName = testMetricsSystem.buildRegistryName(source) diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala index 2369218830215..cf34121fe73dc 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import com.codahale.metrics._ -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite class GraphiteSinkSuite extends SparkFunSuite { @@ -32,9 +32,8 @@ class GraphiteSinkSuite extends SparkFunSuite { props.put("host", "127.0.0.1") props.put("port", "54321") val registry = new MetricRegistry - val securityMgr = new SecurityManager(new SparkConf(false)) - val sink = new GraphiteSink(props, registry, securityMgr) + val sink = new GraphiteSink(props, registry) val gauge = new Gauge[Double] { override def getValue: Double = 1.23 @@ -55,9 +54,8 @@ class GraphiteSinkSuite extends SparkFunSuite { props.put("port", "54321") props.put("regex", "local-[0-9]+.driver.(CodeGenerator|BlockManager)") val registry = new MetricRegistry - val securityMgr = new SecurityManager(new SparkConf(false)) - val sink = new GraphiteSink(props, registry, securityMgr) + val sink = new GraphiteSink(props, registry) val gauge = new 
Gauge[Double] { override def getValue: Double = 1.23 diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala index 080ca0e41f793..4b5b41c14a21e 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala @@ -69,5 +69,5 @@ class PrometheusServletSuite extends SparkFunSuite with PrivateMethodTester { } private def createPrometheusServlet(): PrometheusServlet = - new PrometheusServlet(new Properties, new MetricRegistry, securityMgr = null) + new PrometheusServlet(new Properties, new MetricRegistry) } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala index 3d4b8c868d6fc..ff883633d5e7a 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala @@ -24,11 +24,10 @@ import java.util.concurrent.TimeUnit._ import com.codahale.metrics._ -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.sink.StatsdSink._ class StatsdSinkSuite extends SparkFunSuite { - private val securityMgr = new SecurityManager(new SparkConf(false)) private val defaultProps = Map( STATSD_KEY_PREFIX -> "spark", STATSD_KEY_PERIOD -> "1", @@ -61,7 +60,7 @@ class StatsdSinkSuite extends SparkFunSuite { defaultProps.foreach(e => props.put(e._1, e._2)) props.put(STATSD_KEY_PORT, socket.getLocalPort.toString) val registry = new MetricRegistry - val sink = new StatsdSink(props, registry, securityMgr) + val sink = new StatsdSink(props, registry) try { testCode(socket, sink) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6b332ec1298f5..4c74e4fbb3728 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -319,7 +319,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti cacheLocations.clear() results.clear() securityMgr = new SecurityManager(sc.getConf) - broadcastManager = new BroadcastManager(true, sc.getConf, securityMgr) + broadcastManager = new BroadcastManager(true, sc.getConf) mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, broadcastManager)) blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf)) scheduler = new DAGScheduler( diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1e9b48102616f..495747b2c7c11 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -55,7 +55,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) - protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) + protected lazy val bcastManager = new BroadcastManager(true, conf) protected lazy val mapOutputTracker = new 
MapOutputTrackerMaster(conf, bcastManager, true) protected lazy val shuffleManager = new SortShuffleManager(conf) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 82d7abfddd82b..055ee0debeb12 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -79,7 +79,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var master: BlockManagerMaster = null var liveListenerBus: LiveListenerBus = null val securityMgr = new SecurityManager(new SparkConf(false)) - val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) + val bcastManager = new BroadcastManager(true, new SparkConf(false)) val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) val shuffleManager = new SortShuffleManager(new SparkConf(false)) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index c7e0869e4bd5c..16cffd03135df 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,7 +30,7 @@ import org.apache.mesos.Protos.{SlaveID => AgentID, TaskState => MesosTaskState, import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} +import org.apache.spark.{SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.internal.config._ @@ -125,8 +125,7 @@ private[spark] class MesosClusterScheduler( extends Scheduler with MesosSchedulerUtils with MesosScheduler { var frameworkUrl: String = _ private val metricsSystem = - MetricsSystem.createMetricsSystem(MetricsSystemInstances.MESOS_CLUSTER, conf, - new SecurityManager(conf)) + MetricsSystem.createMetricsSystem(MetricsSystemInstances.MESOS_CLUSTER, conf) private val master = conf.get("spark.master") private val appName = conf.get("spark.app.name") private val queuedCapacity = conf.get(config.MAX_DRIVERS) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index eb927a3c296c0..e7377e05479c5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -481,8 +481,7 @@ private[spark] class ApplicationMaster( rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef)) allocator.allocateResources() - val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.APPLICATION_MASTER, - sparkConf, securityMgr) + val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.APPLICATION_MASTER, sparkConf) val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) ms.registerSource(new ApplicationMasterSource(prefix, allocator)) // do 
not register static sources in this case as per SPARK-25277 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 913ab1f46d59e..425e39c5980a1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -71,7 +71,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf, encryptionKey) - val broadcastManager = new BroadcastManager(true, conf, securityMgr) + val broadcastManager = new BroadcastManager(true, conf) val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) From c56af69cdf3cc68821e69fa4ef0213b5cc281ab0 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 25 Feb 2021 09:32:41 +0000 Subject: [PATCH 26/60] [SPARK-34518][SQL] Rename `AlterTableRecoverPartitionsCommand` to `RepairTableCommand` ### What changes were proposed in this pull request? Rename the execution node `AlterTableRecoverPartitionsCommand` for the commands: - `MSCK REPAIR TABLE table [{ADD|DROP|SYNC} PARTITIONS]` - `ALTER TABLE table RECOVER PARTITIONS` to `RepairTableCommand`. ### Why are the changes needed? 1. After the PR https://github.com/apache/spark/pull/31499, `ALTER TABLE table RECOVER PARTITIONS` is equal to `MSCK REPAIR TABLE table ADD PARTITIONS`. And mapping of the generic command `MSCK REPAIR TABLE` to the more specific execution node `AlterTableRecoverPartitionsCommand` can confuse devs in the future. 2. `ALTER TABLE table RECOVER PARTITIONS` does not support any options/extensions. So, additional parameters `enableAddPartitions` and `enableDropPartitions` in `AlterTableRecoverPartitionsCommand` confuse as well. ### Does this PR introduce _any_ user-facing change? No because this is internal API. ### How was this patch tested? By running the existing test suites: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *AlterTableRecoverPartitionsSuite" $ build/sbt "test:testOnly *AlterTableRecoverPartitionsParserSuite" $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *MsckRepairTableSuite" $ build/sbt "test:testOnly *MsckRepairTableParserSuite" ``` Closes #31635 from MaxGekk/rename-recover-partitions. 
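For illustration, here is a minimal Scala sketch of the SQL forms that now all resolve to the renamed `RepairTableCommand` execution node; the table name `t`, the ambient `spark` session, and the parquet source are assumptions made for this example and are not part of the patch:

```scala
// A sketch only: `t` is a hypothetical partitioned table and `spark` an active
// SparkSession (e.g. from spark-shell); this is not code from the patch itself.
spark.sql("CREATE TABLE IF NOT EXISTS t (id INT, part INT) USING parquet PARTITIONED BY (part)")

// Both statements discover partition directories added directly on the file system
// and now map to the same RepairTableCommand execution node.
spark.sql("MSCK REPAIR TABLE t")               // same as the ADD PARTITIONS form
spark.sql("ALTER TABLE t RECOVER PARTITIONS")  // equivalent command, no extra options

// The extended syntax listed above can additionally drop partitions whose
// directories no longer exist.
spark.sql("MSCK REPAIR TABLE t DROP PARTITIONS")
```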
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/ResolveSessionCatalog.scala | 8 ++------ .../sql/execution/command/createDataSourceTables.scala | 2 +- .../org/apache/spark/sql/execution/command/ddl.scala | 10 +++++----- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 68c608310e214..dde31f62e06b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -377,11 +377,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) AnalyzeColumnCommand(ident.asTableIdentifier, columnNames, allColumns) case RepairTable(ResolvedV1TableIdentifier(ident), addPartitions, dropPartitions) => - AlterTableRecoverPartitionsCommand( - ident.asTableIdentifier, - addPartitions, - dropPartitions, - "MSCK REPAIR TABLE") + RepairTableCommand(ident.asTableIdentifier, addPartitions, dropPartitions) case LoadData(ResolvedV1TableIdentifier(ident), path, isLocal, isOverwrite, partition) => LoadDataCommand( @@ -422,7 +418,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) ShowColumnsCommand(db, v1TableName, output) case RecoverPartitions(ResolvedV1TableIdentifier(ident)) => - AlterTableRecoverPartitionsCommand( + RepairTableCommand( ident.asTableIdentifier, enableAddPartitions = true, enableDropPartitions = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index b3e48e37c66e2..995d6273ea588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -189,7 +189,7 @@ case class CreateDataSourceTableAsSelectCommand( case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && sparkSession.sqlContext.conf.manageFilesourcePartitions => // Need to recover partitions into the metastore so our saved data is visible. - sessionState.executePlan(AlterTableRecoverPartitionsCommand( + sessionState.executePlan(RepairTableCommand( table.identifier, enableAddPartitions = true, enableDropPartitions = false)).toRdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index f0219efbf9a98..2fc6d6fd85322 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -591,7 +591,7 @@ case class AlterTableDropPartitionCommand( case class PartitionStatistics(numFiles: Int, totalSize: Long) /** - * Recover Partitions in ALTER TABLE: recover all the partition in the directory of a table and + * Repair a table by recovering all the partition in the directory of the table and * update the catalog. 
* * The syntax of this command is: @@ -600,11 +600,11 @@ case class PartitionStatistics(numFiles: Int, totalSize: Long) * MSCK REPAIR TABLE table [{ADD|DROP|SYNC} PARTITIONS]; * }}} */ -case class AlterTableRecoverPartitionsCommand( +case class RepairTableCommand( tableName: TableIdentifier, enableAddPartitions: Boolean, enableDropPartitions: Boolean, - cmd: String = "ALTER TABLE RECOVER PARTITIONS") extends RunnableCommand { + cmd: String = "MSCK REPAIR TABLE") extends RunnableCommand { // These are list of statistics that can be collected quickly without requiring a scan of the data // see https://github.com/apache/hive/blob/master/ @@ -654,7 +654,7 @@ case class AlterTableRecoverPartitionsCommand( val threshold = spark.sparkContext.conf.get(RDD_PARALLEL_LISTING_THRESHOLD) val pathFilter = getPathFilter(hadoopConf) - val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) + val evalPool = ThreadUtils.newForkJoinPool("RepairTableCommand", 8) val partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)] = try { scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, @@ -804,7 +804,7 @@ case class AlterTableRecoverPartitionsCommand( private def dropPartitions(catalog: SessionCatalog, fs: FileSystem): Int = { val dropPartSpecs = ThreadUtils.parmap( catalog.listPartitions(tableName), - "AlterTableRecoverPartitionsCommand: non-existing partitions", + "RepairTableCommand: non-existing partitions", maxThreads = 8) { partition => partition.storage.locationUri.flatMap { uri => if (fs.exists(new Path(uri))) None else Some(partition.spec) From 4a3200b08ac3e7733b5a3dc7271d35e6872c5967 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 25 Feb 2021 18:07:39 +0800 Subject: [PATCH 27/60] [SPARK-34436][SQL] DPP support LIKE ANY/ALL expression ### What changes were proposed in this pull request? This pr make DPP support LIKE ANY/ALL expression: ```sql SELECT date_id, product_id FROM fact_sk f JOIN dim_store s ON f.store_id = s.store_id WHERE s.country LIKE ANY ('%D%E%', '%A%B%') ``` ### Why are the changes needed? Improve query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #31563 from wangyum/SPARK-34436. 
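As a usage sketch of the optimization above (it assumes an active `spark` session, that the `fact_sk`/`dim_store` tables from the example query exist, and that `store_id` is the partition column of `fact_sk`), the pruning can be observed in the explain output:

```scala
// A sketch only: enable DPP explicitly and look for a dynamicpruning subquery on
// fact_sk's partition column in the physical plan. Table layout is assumed here.
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

val pruned = spark.sql(
  """
    |SELECT date_id, product_id FROM fact_sk f
    |JOIN dim_store s
    |ON f.store_id = s.store_id WHERE s.country LIKE ANY ('%D%E%', '%A%B%')
  """.stripMargin)

// Before this change, LIKE ANY/ALL was not treated as a selective filter, so no
// pruning subquery was inserted; now the plan should prune fact_sk's partitions.
pruned.explain(true)
```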
Lead-authored-by: Yuming Wang Co-authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../dynamicpruning/PartitionPruning.scala | 1 + .../sql/DynamicPartitionPruningSuite.scala | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 182c2aaad581c..4b341ec3c762f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -164,6 +164,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { case _: BinaryComparison => true case _: In | _: InSet => true case _: StringPredicate => true + case _: MultiLikeBase => true case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index bc9c3006cddc9..aceb97207878a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -1383,6 +1383,26 @@ abstract class DynamicPartitionPruningSuiteBase checkAnswer(df, Nil) } } + + test("SPARK-34436: DPP support LIKE ANY/ALL expression") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + val df = sql( + """ + |SELECT date_id, product_id FROM fact_sk f + |JOIN dim_store s + |ON f.store_id = s.store_id WHERE s.country LIKE ANY ('%D%E%', '%A%B%') + """.stripMargin) + + checkPartitionPruningPredicate(df, false, true) + + checkAnswer(df, + Row(1030, 2) :: + Row(1040, 2) :: + Row(1050, 2) :: + Row(1060, 2) :: Nil + ) + } + } } class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase From f7ac2d655c756100c33e652402cefc507d2493b7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Feb 2021 12:41:07 -0800 Subject: [PATCH 28/60] [SPARK-34474][SQL] Remove unnecessary Union under Distinct/Deduplicate ### What changes were proposed in this pull request? This patch proposes to let optimizer to remove unnecessary `Union` under `Distinct`/`Deduplicate`. ### Why are the changes needed? For an `Union` under `Distinct`/`Deduplicate`, if its children are all the same, we can just keep one among them and remove the `Union`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. Closes #31595 from viirya/remove-union. 
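To illustrate the pattern this rule targets, a small sketch follows (it assumes a `spark` session whose `implicits` are importable, as in `spark-shell`; the DataFrame contents are made up for the example):

```scala
// A sketch only: a Union whose branches are identical is redundant under distinct().
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
val deduped = df.union(df).distinct()

// With RemoveNoopUnion the optimized plan keeps a single scan of df and drops the
// Union node; the result is unchanged, only the duplicate branch disappears.
deduped.explain()
assert(deduped.count() == 2L)
```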
Authored-by: Liang-Chi Hsieh Signed-off-by: Liang-Chi Hsieh --- .../sql/catalyst/optimizer/Optimizer.scala | 43 ++++ .../optimizer/RemoveNoopUnionSuite.scala | 70 ++++++ .../resources/sql-tests/inputs/explain.sql | 2 +- .../sql-tests/results/explain-aqe.sql.out | 6 +- .../sql-tests/results/explain.sql.out | 6 +- .../sql/DataFrameSetOperationsSuite.scala | 222 +++++++++++------- .../org/apache/spark/sql/SubquerySuite.scala | 4 +- 7 files changed, 259 insertions(+), 94 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopUnionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b08187d0bc3be..717770f9fa1be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -157,6 +157,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, RemoveNoopOperators, + RemoveNoopUnion, CombineUnions) :: Batch("OptimizeLimitZero", Once, OptimizeLimitZero) :: @@ -501,6 +502,48 @@ object RemoveNoopOperators extends Rule[LogicalPlan] { } } +/** + * Remove no-op `Union` from the query plan that do not make any modifications. + */ +object RemoveNoopUnion extends Rule[LogicalPlan] { + /** + * This only removes the `Project` that has only attributes or aliased attributes + * from its child. + */ + private def removeAliasOnlyProject(plan: LogicalPlan): LogicalPlan = plan match { + case p @ Project(projectList, child) => + val aliasOnly = projectList.length == child.output.length && + projectList.zip(child.output).forall { + case (Alias(left: Attribute, _), right) => left.semanticEquals(right) + case (left: Attribute, right) => left.semanticEquals(right) + case _ => false + } + if (aliasOnly) { + child + } else { + p + } + case _ => plan + } + + private def removeUnion(u: Union): Option[LogicalPlan] = { + val unionChildren = u.children.map(removeAliasOnlyProject) + if (unionChildren.tail.forall(unionChildren.head.sameResult(_))) { + Some(u.children.head) + } else { + None + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case d @ Distinct(u: Union) => + removeUnion(u).map(c => d.withNewChildren(Seq(c))).getOrElse(d) + + case d @ Deduplicate(_, u: Union) => + removeUnion(u).map(c => d.withNewChildren(Seq(c))).getOrElse(d) + } +} + /** * Pushes down [[LocalLimit]] beneath UNION ALL and joins. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopUnionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopUnionSuite.scala new file mode 100644 index 0000000000000..807a3eb87298e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopUnionSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class RemoveNoopUnionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CollapseProject", Once, + CollapseProject) :: + Batch("RemoveNoopUnion", Once, + RemoveNoopUnion) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("SPARK-34474: Remove redundant Union under Distinct") { + val union = Union(testRelation :: testRelation :: Nil) + val distinct = Distinct(union) + val optimized = Optimize.execute(distinct) + comparePlans(optimized, Distinct(testRelation)) + } + + test("SPARK-34474: Remove redundant Union under Deduplicate") { + val union = Union(testRelation :: testRelation :: Nil) + val deduplicate = Deduplicate(testRelation.output, union) + val optimized = Optimize.execute(deduplicate) + comparePlans(optimized, Deduplicate(testRelation.output, testRelation)) + } + + test("SPARK-34474: Do not remove necessary Project 1") { + val child1 = Project(Seq(testRelation.output(0), testRelation.output(1), + (testRelation.output(0) + 1).as("expr")), testRelation) + val child2 = Project(Seq(testRelation.output(0), testRelation.output(1), + (testRelation.output(0) + 2).as("expr")), testRelation) + val union = Union(child1 :: child2 :: Nil) + val distinct = Distinct(union) + val optimized = Optimize.execute(distinct) + comparePlans(optimized, distinct) + } + + test("SPARK-34474: Do not remove necessary Project 2") { + val child1 = Project(Seq(testRelation.output(0), testRelation.output(1)), testRelation) + val child2 = Project(Seq(testRelation.output(1), testRelation.output(0)), testRelation) + val union = Union(child1 :: child2 :: Nil) + val distinct = Distinct(union) + val optimized = Optimize.execute(distinct) + comparePlans(optimized, distinct) + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index fdff1b4eef941..736084597eb79 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -34,7 +34,7 @@ EXPLAIN FORMATTED EXPLAIN FORMATTED SELECT key, val FROM explain_temp1 WHERE key > 0 UNION - SELECT key, val FROM explain_temp1 WHERE key > 0; + SELECT key, val FROM explain_temp1 WHERE key > 1; -- Join EXPLAIN FORMATTED diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index bcb98396b3028..ddfab99b481c2 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -205,7 +205,7 @@ Arguments: isFinalPlan=false EXPLAIN FORMATTED SELECT key, val FROM explain_temp1 WHERE key > 0 UNION - SELECT key, val FROM explain_temp1 WHERE key > 0 + SELECT key, val FROM explain_temp1 WHERE key > 1 -- 
!query schema struct -- !query output @@ -236,12 +236,12 @@ Condition : (isnotnull(key#x) AND (key#x > 0)) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), GreaterThan(key,0)] +PushedFilters: [IsNotNull(key), GreaterThan(key,1)] ReadSchema: struct (4) Filter Input [2]: [key#x, val#x] -Condition : (isnotnull(key#x) AND (key#x > 0)) +Condition : (isnotnull(key#x) AND (key#x > 1)) (5) Union diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index a72a5f0a2aa86..1f7f8f6615727 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -203,7 +203,7 @@ Input [3]: [key#x, max(val)#x, max(val#x)#x] EXPLAIN FORMATTED SELECT key, val FROM explain_temp1 WHERE key > 0 UNION - SELECT key, val FROM explain_temp1 WHERE key > 0 + SELECT key, val FROM explain_temp1 WHERE key > 1 -- !query schema struct -- !query output @@ -238,7 +238,7 @@ Condition : (isnotnull(key#x) AND (key#x > 0)) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), GreaterThan(key,0)] +PushedFilters: [IsNotNull(key), GreaterThan(key,1)] ReadSchema: struct (5) ColumnarToRow [codegen id : 2] @@ -246,7 +246,7 @@ Input [2]: [key#x, val#x] (6) Filter [codegen id : 2] Input [2]: [key#x, val#x] -Condition : (isnotnull(key#x) AND (key#x > 0)) +Condition : (isnotnull(key#x) AND (key#x > 1)) (7) Union diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 816c7c4aab775..baf8d99a5092e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.optimizer.RemoveNoopUnion import org.apache.spark.sql.catalyst.plans.logical.Union import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -387,100 +388,104 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { test("SPARK-34283: SQL-style union using Dataset, " + "remove unnecessary deduplicate in multiple unions") { - val unionDF = testData.union(testData).distinct().union(testData).distinct() - .union(testData).distinct().union(testData).distinct() - - // Before optimizer, there are three 'union.deduplicate' operations should be combined. - assert(unionDF.queryExecution.analyzed.collect { - case u: Union if u.children.size == 4 => u - }.size === 1) - - // After optimizer, four 'union.deduplicate' operations should be combined. - assert(unionDF.queryExecution.optimizedPlan.collect { - case u: Union if u.children.size == 5 => u - }.size === 1) - - checkAnswer( - unionDF.agg(avg("key"), max("key"), min("key"), - sum("key")), Row(50.5, 100, 1, 5050) :: Nil - ) + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> RemoveNoopUnion.ruleName) { + val unionDF = testData.union(testData).distinct().union(testData).distinct() + .union(testData).distinct().union(testData).distinct() + + // Before optimizer, there are three 'union.deduplicate' operations should be combined. 
+ assert(unionDF.queryExecution.analyzed.collect { + case u: Union if u.children.size == 4 => u + }.size === 1) + + // After optimizer, four 'union.deduplicate' operations should be combined. + assert(unionDF.queryExecution.optimizedPlan.collect { + case u: Union if u.children.size == 5 => u + }.size === 1) + + checkAnswer( + unionDF.agg(avg("key"), max("key"), min("key"), + sum("key")), Row(50.5, 100, 1, 5050) :: Nil + ) - // The result of SQL-style union - val unionSQLResult = sql( - """ - | select key, value from testData - | union - | select key, value from testData - | union - | select key, value from testData - | union - | select key, value from testData - | union - | select key, value from testData - |""".stripMargin) - checkAnswer(unionDF, unionSQLResult) + // The result of SQL-style union + val unionSQLResult = sql( + """ + | select key, value from testData + | union + | select key, value from testData + | union + | select key, value from testData + | union + | select key, value from testData + | union + | select key, value from testData + |""".stripMargin) + checkAnswer(unionDF, unionSQLResult) + } } test("SPARK-34283: SQL-style union using Dataset, " + "keep necessary deduplicate in multiple unions") { - val df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - var df2 = Seq((6, 2, 5)).toDF("a", "b", "c") - var df3 = Seq((2, 4, 3)).toDF("c", "a", "b") - var df4 = Seq((1, 4, 5)).toDF("b", "a", "c") - - val unionDF = df1.unionByName(df2).dropDuplicates(Seq("a")) - .unionByName(df3).dropDuplicates("c").unionByName(df4) - .dropDuplicates("b") - - // In this case, there is no 'union.deduplicate' operation will be combined. - assert(unionDF.queryExecution.analyzed.collect { - case u: Union if u.children.size == 2 => u - }.size === 3) - - assert(unionDF.queryExecution.optimizedPlan.collect { - case u: Union if u.children.size == 2 => u - }.size === 3) - - checkAnswer( - unionDF, - Row(4, 3, 2) :: Row(4, 1, 5) :: Row(1, 2, 3) :: Nil - ) + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> RemoveNoopUnion.ruleName) { + val df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((6, 2, 5)).toDF("a", "b", "c") + var df3 = Seq((2, 4, 3)).toDF("c", "a", "b") + var df4 = Seq((1, 4, 5)).toDF("b", "a", "c") + + val unionDF = df1.unionByName(df2).dropDuplicates(Seq("a")) + .unionByName(df3).dropDuplicates("c").unionByName(df4) + .dropDuplicates("b") + + // In this case, there is no 'union.deduplicate' operation will be combined. + assert(unionDF.queryExecution.analyzed.collect { + case u: Union if u.children.size == 2 => u + }.size === 3) + + assert(unionDF.queryExecution.optimizedPlan.collect { + case u: Union if u.children.size == 2 => u + }.size === 3) + + checkAnswer( + unionDF, + Row(4, 3, 2) :: Row(4, 1, 5) :: Row(1, 2, 3) :: Nil + ) - val unionDF1 = df1.unionByName(df2).dropDuplicates(Seq("B", "A", "c")) - .unionByName(df3).dropDuplicates().unionByName(df4) - .dropDuplicates("A") - - // In this case, there are two 'union.deduplicate' operations will be combined. 
- assert(unionDF1.queryExecution.analyzed.collect { - case u: Union if u.children.size == 2 => u - }.size === 1) - assert(unionDF1.queryExecution.analyzed.collect { - case u: Union if u.children.size == 3 => u - }.size === 1) - - assert(unionDF1.queryExecution.optimizedPlan.collect { - case u: Union if u.children.size == 2 => u - }.size === 1) - assert(unionDF1.queryExecution.optimizedPlan.collect { - case u: Union if u.children.size == 3 => u - }.size === 1) + val unionDF1 = df1.unionByName(df2).dropDuplicates(Seq("B", "A", "c")) + .unionByName(df3).dropDuplicates().unionByName(df4) + .dropDuplicates("A") + + // In this case, there are two 'union.deduplicate' operations will be combined. + assert(unionDF1.queryExecution.analyzed.collect { + case u: Union if u.children.size == 2 => u + }.size === 1) + assert(unionDF1.queryExecution.analyzed.collect { + case u: Union if u.children.size == 3 => u + }.size === 1) + + assert(unionDF1.queryExecution.optimizedPlan.collect { + case u: Union if u.children.size == 2 => u + }.size === 1) + assert(unionDF1.queryExecution.optimizedPlan.collect { + case u: Union if u.children.size == 3 => u + }.size === 1) + + checkAnswer( + unionDF1, + Row(4, 3, 2) :: Row(6, 2, 5) :: Row(1, 2, 3) :: Nil + ) - checkAnswer( - unionDF1, - Row(4, 3, 2) :: Row(6, 2, 5) :: Row(1, 2, 3) :: Nil - ) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + df2 = Seq((6, 2, 5)).toDF("a", "B", "C") + df3 = Seq((2, 1, 3)).toDF("b", "a", "c") + df4 = Seq((1, 4, 5)).toDF("b", "a", "c") - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - df2 = Seq((6, 2, 5)).toDF("a", "B", "C") - df3 = Seq((2, 1, 3)).toDF("b", "a", "c") - df4 = Seq((1, 4, 5)).toDF("b", "a", "c") + val unionDF2 = df1.unionByName(df2, true).distinct() + .unionByName(df3, true).dropDuplicates(Seq("a")).unionByName(df4, true).distinct() - val unionDF2 = df1.unionByName(df2, true).distinct() - .unionByName(df3, true).dropDuplicates(Seq("a")).unionByName(df4, true).distinct() - - checkAnswer(unionDF2, - Row(4, 1, 5, null, null) :: Row(1, 2, 3, null, null) :: Row(6, null, null, 2, 5) :: Nil) - assert(unionDF2.schema.fieldNames === Array("a", "b", "c", "B", "C")) + checkAnswer(unionDF2, + Row(4, 1, 5, null, null) :: Row(1, 2, 3, null, null) :: Row(6, null, null, 2, 5) :: Nil) + assert(unionDF2.schema.fieldNames === Array("a", "b", "c", "B", "C")) + } } } @@ -808,6 +813,53 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { // scalastyle:on checkAnswer(union, row1 :: row2 :: Nil) } + + test("SPARK-34474: Remove unnecessary Union under Distinct") { + Seq(RemoveNoopUnion.ruleName, "").map { ruleName => + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ruleName) { + val distinctUnionDF1 = testData.union(testData).distinct() + checkAnswer(distinctUnionDF1, testData.distinct()) + + + val distinctUnionDF2 = testData.union(testData).dropDuplicates(Seq("key")) + checkAnswer(distinctUnionDF2, testData.dropDuplicates(Seq("key"))) + + val distinctUnionDF3 = sql( + """ + |select key, value from testData + |union + |select key, value from testData + |""".stripMargin) + checkAnswer(distinctUnionDF3, testData.distinct()) + + val distinctUnionDF4 = sql( + """ + |select distinct key, expr + |from + |( + | select key, key + 1 as expr + | from testData + | union all + | select key, key + 2 as expr + | from testData + |) + |""".stripMargin) + val expected = sql( + """ + |select key, expr + |from + |( + | select key, key + 1 as expr + | from testData + | union all + | select key, key + 2 as expr + | from 
testData + |) group by key, expr + |""".stripMargin) + checkAnswer(distinctUnionDF4, expected) + } + } + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 629a06b9df6dc..aa3bf85044488 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1115,12 +1115,12 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |SELECT c1 FROM t1 |WHERE |c1 IN (( - | SELECT c1 FROM t2 + | SELECT c1 + 1 AS c1 FROM t2 | ORDER BY c1 | ) | UNION | ( - | SELECT c1 FROM t2 + | SELECT c1 + 2 AS c1 FROM t2 | ORDER BY c1 | )) """.stripMargin From dffb01f28a1f5fb6445e59e1c8eefdd683fe4a29 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 Feb 2021 15:25:41 -0800 Subject: [PATCH 29/60] [SPARK-34152][SQL][FOLLOWUP] Do not uncache the temp view if it doesn't exist ### What changes were proposed in this pull request? This PR fixes a mistake in https://github.com/apache/spark/pull/31273. When CREATE OR REPLACE a temp view, we need to uncache the to-be-replaced existing temp view. However, we shouldn't uncache if there is no existing temp view. This doesn't cause real issues because the uncache action is failure-safe. But it produces a lot of warning messages. ### Why are the changes needed? Avoid unnecessary warning logs. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually run tests and check the warning messages. Closes #31650 from cloud-fan/warnning. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/command/views.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index bae7c54f0b99d..5f5f0099e2bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -115,7 +115,7 @@ case class CreateViewCommand( if (viewType == LocalTempView) { val aliasedPlan = aliasPlan(sparkSession, analyzedPlan) - if (replace && !isSamePlan(catalog.getRawTempView(name.table), aliasedPlan)) { + if (replace && needsToUncache(catalog.getRawTempView(name.table), aliasedPlan)) { logInfo(s"Try to uncache ${name.quotedString} before replacing.") checkCyclicViewReference(analyzedPlan, Seq(name), name) CommandUtils.uncacheTableOrView(sparkSession, name.quotedString) @@ -139,7 +139,7 @@ case class CreateViewCommand( val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) val viewIdent = TableIdentifier(name.table, Option(db)) val aliasedPlan = aliasPlan(sparkSession, analyzedPlan) - if (replace && !isSamePlan(catalog.getRawGlobalTempView(name.table), aliasedPlan)) { + if (replace && needsToUncache(catalog.getRawGlobalTempView(name.table), aliasedPlan)) { logInfo(s"Try to uncache ${viewIdent.quotedString} before replacing.") checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) CommandUtils.uncacheTableOrView(sparkSession, viewIdent.quotedString) @@ -193,15 +193,17 @@ case class CreateViewCommand( } /** - * Checks if the temp view (the result of getTempViewRawPlan or getRawGlobalTempView) is storing - * the same plan as the given aliased plan. 
+ * Checks if need to uncache the temp view being replaced. */ - private def isSamePlan( + private def needsToUncache( rawTempView: Option[LogicalPlan], aliasedPlan: LogicalPlan): Boolean = rawTempView match { - case Some(TemporaryViewRelation(_, Some(p))) => p.sameResult(aliasedPlan) - case Some(p) => p.sameResult(aliasedPlan) - case _ => false + // The temp view doesn't exist, no need to uncache. + case None => false + // Do not need to uncache if the to-be-replaced temp view plan and the new plan are the + // same-result plans. + case Some(TemporaryViewRelation(_, Some(p))) => !p.sameResult(aliasedPlan) + case Some(p) => !p.sameResult(aliasedPlan) } /** From 1967760277595af5c842402c6f2d1f28dfb18728 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 25 Feb 2021 15:27:46 -0800 Subject: [PATCH 30/60] [SPARK-34505][BUILD] Upgrade Scala to 2.13.5 ### What changes were proposed in this pull request? This PR aims to update from Scala 2.13.4 to Scala 2.13.5 for Apache Spark 3.2. ### Why are the changes needed? Scala 2.13.5 is a maintenance release for 2.13 line and improves Java 13, 14, 15, 16, and 17 support. - https://github.com/scala/scala/releases/tag/v2.13.5 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass the GitHub Action `Scala 2.13` job and manual test. I verified the following locally and all passed. ``` $ dev/change-scala-version.sh 2.13 $ build/sbt test -Pscala-2.13 ``` Closes #31620 from dongjoon-hyun/SPARK-34505. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index d497ed176bd35..b7b9f8d826dbb 100644 --- a/pom.xml +++ b/pom.xml @@ -3278,7 +3278,7 @@ scala-2.13 - 2.13.4 + 2.13.5 2.13 From 0d3a9cd3c9d25fdd35ddf04d0fe2ed1f8fead2a5 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 26 Feb 2021 09:20:40 +0900 Subject: [PATCH 31/60] [SPARK-34535][SQL] Cleanup unused symbol in Orc related code ### What changes were proposed in this pull request? Cleanup unused symbol in Orc related code as follows: - `OrcDeserializer` : parameter `dataSchema` in constructor - `OrcFilters` : parameter `schema ` in method `convertibleFilters`. - `OrcPartitionReaderFactory`: ignore return value of `OrcUtils.orcResultSchemaString` in method `buildReader(file: PartitionedFile)` ### Why are the changes needed? Cleanup code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass the Jenkins or GitHub Action Closes #31644 from LuciferYang/cleanup-orc-unused-symbol. 
Authored-by: yangjie01 Signed-off-by: HyukjinKwon --- .../sql/execution/datasources/orc/OrcDeserializer.scala | 1 - .../spark/sql/execution/datasources/orc/OrcFileFormat.scala | 2 +- .../spark/sql/execution/datasources/orc/OrcFilters.scala | 3 +-- .../datasources/v2/orc/OrcPartitionReaderFactory.scala | 5 ++--- .../sql/execution/datasources/v2/orc/OrcScanBuilder.scala | 2 +- 5 files changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 32ce7185f7381..22374c59e5059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -30,7 +30,6 @@ import org.apache.spark.unsafe.types.UTF8String * A deserializer to deserialize ORC structs to Spark rows. */ class OrcDeserializer( - dataSchema: StructType, requiredSchema: StructType, requestedColIds: Array[Int]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 83504d8c4458a..8f4d1e5098029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -230,7 +230,7 @@ class OrcFileFormat val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + val deserializer = new OrcDeserializer(requiredSchema, requestedColIds) if (partitionSchema.length == 0) { iter.map(value => unsafeProjection(deserializer.deserialize(value))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 9511fc31f4ac5..61090966f4778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -70,7 +70,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + val conjunctionOptional = buildTree(convertibleFilters(dataTypeMap, filters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. 
There should not be any empty result in the @@ -80,7 +80,6 @@ private[sql] object OrcFilters extends OrcFiltersBase { } def convertibleFilters( - schema: StructType, dataTypeMap: Map[String, OrcPrimitiveField], filters: Seq[Filter]): Seq[Filter] = { import org.apache.spark.sql.sources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 6f9a3ae4c67fe..414252cc12481 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -98,8 +98,7 @@ case class OrcPartitionReaderFactory( new EmptyPartitionReader[InternalRow] } else { val (requestedColIds, canPruneCols) = resultedColPruneInfo.get - val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols, - dataSchema, resultSchema, partitionSchema, conf) + OrcUtils.orcResultSchemaString(canPruneCols, dataSchema, resultSchema, partitionSchema, conf) assert(requestedColIds.length == readDataSchema.length, "[BUG] requested column IDs do not match required schema") @@ -111,7 +110,7 @@ case class OrcPartitionReaderFactory( val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) - val deserializer = new OrcDeserializer(dataSchema, readDataSchema, requestedColIds) + val deserializer = new OrcDeserializer(readDataSchema, requestedColIds) val fileReader = new PartitionReader[InternalRow] { override def next(): Boolean = orcRecordReader.nextKeyValue() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 0dbc74395afb1..a8c813a03e0ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -54,7 +54,7 @@ case class OrcScanBuilder( override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) - _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray + _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray } filters } From 4d428a821b2117789d0a2c61c7229d00af1704eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 25 Feb 2021 17:10:58 -0800 Subject: [PATCH 32/60] Revert "[SPARK-32617][K8S][TESTS] Configure kubernetes client based on kubeconfig settings in kubernetes integration tests" This reverts commit b17754a8cbd2593eb2b1952e95a7eeb0f8e09cdb. 
--- .../dev/dev-run-integration-tests.sh | 6 +- .../backend/minikube/Minikube.scala | 138 ++++++------------ 2 files changed, 43 insertions(+), 101 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index c87437e48f589..b72a4f74918ba 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -35,7 +35,6 @@ CONTEXT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= JAVA_VERSION="8" -BUILD_DEPENDENCIES_MVN_FLAG="-am" HADOOP_PROFILE="hadoop-3.2" MVN="$TEST_ROOT_DIR/build/mvn" @@ -118,9 +117,6 @@ while (( "$#" )); do HADOOP_PROFILE="$2" shift ;; - --skip-building-dependencies) - BUILD_DEPENDENCIES_MVN_FLAG="" - ;; *) echo "Unexpected command line flag $2 $1." exit 1 @@ -180,4 +176,4 @@ properties+=( -Dlog4j.logger.org.apache.spark=DEBUG ) -$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests $BUILD_DEPENDENCIES_MVN_FLAG -Pscala-$SCALA_VERSION -P$HADOOP_PROFILE -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -P$HADOOP_PROFILE -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala index 5cb068545ef37..c33875243c598 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.deploy.k8s.integrationtest.backend.minikube -import java.nio.file.Paths +import java.nio.file.{Files, Paths} -import io.fabric8.kubernetes.client.{Config, ConfigBuilder, DefaultKubernetesClient} +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils import org.apache.spark.internal.Logging @@ -26,26 +26,18 @@ import org.apache.spark.internal.Logging // TODO support windows private[spark] object Minikube extends Logging { private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 - private val VERSION_PREFIX = "minikube version: " - private val HOST_PREFIX = "host: " - private val KUBELET_PREFIX = "kubelet: " - private val APISERVER_PREFIX = "apiserver: " - private val KUBECTL_PREFIX = "kubectl: " - private val KUBECONFIG_PREFIX = "kubeconfig: " + private val HOST_PREFIX = "host:" + private val KUBELET_PREFIX = "kubelet:" + private val APISERVER_PREFIX = "apiserver:" + private val KUBECTL_PREFIX = "kubectl:" + private val KUBECONFIG_PREFIX = "kubeconfig:" private val MINIKUBE_VM_PREFIX = "minikubeVM: " private val MINIKUBE_PREFIX = "minikube: " private val MINIKUBE_PATH = ".minikube" - private val APIVERSION_PREFIX = "apiVersion: " - private val SERVER_PREFIX = "server: " - private val CA_PREFIX = "certificate-authority: " - private val CLIENTCERT_PREFIX = "client-certificate: " - private val CLIENTKEY_PREFIX = "client-key: 
" - lazy val minikubeVersionString = - executeMinikube("version").find(_.contains(VERSION_PREFIX)).get - - def logVersion(): Unit = - logInfo(minikubeVersionString) + def logVersion(): Unit = { + logInfo(executeMinikube("version").mkString("\n")) + } def getMinikubeIp: String = { val outputs = executeMinikube("ip") @@ -64,106 +56,60 @@ private[spark] object Minikube extends Logging { if (oldMinikube.isEmpty) { getIfNewMinikubeStatus(statusString) } else { - val statusLine = oldMinikube.head - val finalStatusString = if (statusLine.contains(MINIKUBE_VM_PREFIX)) { - statusLine.split(MINIKUBE_VM_PREFIX)(1) - } else { - statusLine.split(MINIKUBE_PREFIX)(1) - } + val finalStatusString = oldMinikube + .head + .replaceFirst(MINIKUBE_VM_PREFIX, "") + .replaceFirst(MINIKUBE_PREFIX, "") MinikubeStatus.unapply(finalStatusString) .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) } } def getKubernetesClient: DefaultKubernetesClient = { - // only the three-part version number is matched (the optional suffix like "-beta.0" is dropped) - val versionArrayOpt = "\\d+\\.\\d+\\.\\d+".r - .findFirstIn(minikubeVersionString.split(VERSION_PREFIX)(1)) - .map(_.split('.').map(_.toInt)) - - assert(versionArrayOpt.isDefined && versionArrayOpt.get.size == 3, - s"Unexpected version format detected in `$minikubeVersionString`." + - "For minikube version a three-part version number is expected (the optional non-numeric " + - "suffix is intentionally dropped)") - - val kubernetesConf = versionArrayOpt.get match { - case Array(x, y, z) => - // comparing the versions as the kubectl command is only introduced in version v1.1.0: - // https://github.com/kubernetes/minikube/blob/v1.1.0/CHANGELOG.md - if (Ordering.Tuple3[Int, Int, Int].gteq((x, y, z), (1, 1, 0))) { - kubectlBasedKubernetesClientConf - } else { - legacyKubernetesClientConf - } - } - new DefaultKubernetesClient(kubernetesConf) - } - - private def legacyKubernetesClientConf: Config = { val kubernetesMaster = s"https://${getMinikubeIp}:8443" val userHome = System.getProperty("user.home") - buildKubernetesClientConf( - "v1", - kubernetesMaster, - Paths.get(userHome, MINIKUBE_PATH, "ca.crt").toFile.getAbsolutePath, - Paths.get(userHome, MINIKUBE_PATH, "apiserver.crt").toFile.getAbsolutePath, - Paths.get(userHome, MINIKUBE_PATH, "apiserver.key").toFile.getAbsolutePath) - } - - private def kubectlBasedKubernetesClientConf: Config = { - val outputs = executeMinikube("kubectl config view") - val apiVersionString = outputs.find(_.contains(APIVERSION_PREFIX)) - val serverString = outputs.find(_.contains(SERVER_PREFIX)) - val caString = outputs.find(_.contains(CA_PREFIX)) - val clientCertString = outputs.find(_.contains(CLIENTCERT_PREFIX)) - val clientKeyString = outputs.find(_.contains(CLIENTKEY_PREFIX)) - - assert(!apiVersionString.isEmpty && !serverString.isEmpty && !caString.isEmpty && - !clientKeyString.isEmpty && !clientKeyString.isEmpty, - "The output of 'minikube kubectl config view' does not contain all the neccesary attributes") - - buildKubernetesClientConf( - apiVersionString.get.split(APIVERSION_PREFIX)(1), - serverString.get.split(SERVER_PREFIX)(1), - caString.get.split(CA_PREFIX)(1), - clientCertString.get.split(CLIENTCERT_PREFIX)(1), - clientKeyString.get.split(CLIENTKEY_PREFIX)(1)) - } - - private def buildKubernetesClientConf(apiVersion: String, masterUrl: String, caCertFile: String, - clientCertFile: String, clientKeyFile: String): Config = { - logInfo(s"building kubernetes config with apiVersion: $apiVersion, masterUrl: 
$masterUrl, " + - s"caCertFile: $caCertFile, clientCertFile: $clientCertFile, clientKeyFile: $clientKeyFile") - new ConfigBuilder() - .withApiVersion(apiVersion) - .withMasterUrl(masterUrl) - .withCaCertFile(caCertFile) - .withClientCertFile(clientCertFile) - .withClientKeyFile(clientKeyFile) + val minikubeBasePath = Paths.get(userHome, MINIKUBE_PATH).toString + val profileDir = if (Files.exists(Paths.get(minikubeBasePath, "apiserver.crt"))) { + // For Minikube <1.9 + "" + } else { + // For Minikube >=1.9 + Paths.get("profiles", executeMinikube("profile")(0)).toString + } + val apiServerCertPath = Paths.get(minikubeBasePath, profileDir, "apiserver.crt") + val apiServerKeyPath = Paths.get(minikubeBasePath, profileDir, "apiserver.key") + val kubernetesConf = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(kubernetesMaster) + .withCaCertFile( + Paths.get(userHome, MINIKUBE_PATH, "ca.crt").toFile.getAbsolutePath) + .withClientCertFile(apiServerCertPath.toFile.getAbsolutePath) + .withClientKeyFile(apiServerKeyPath.toFile.getAbsolutePath) .build() + new DefaultKubernetesClient(kubernetesConf) } // Covers minikube status output after Minikube V0.30. private def getIfNewMinikubeStatus(statusString: Seq[String]): MinikubeStatus.Value = { - val hostString = statusString.find(_.contains(HOST_PREFIX)) - val kubeletString = statusString.find(_.contains(KUBELET_PREFIX)) - val apiserverString = statusString.find(_.contains(APISERVER_PREFIX)) - val kubectlString = statusString.find(_.contains(KUBECTL_PREFIX)) - val kubeconfigString = statusString.find(_.contains(KUBECONFIG_PREFIX)) + val hostString = statusString.find(_.contains(s"$HOST_PREFIX ")) + val kubeletString = statusString.find(_.contains(s"$KUBELET_PREFIX ")) + val apiserverString = statusString.find(_.contains(s"$APISERVER_PREFIX ")) + val kubectlString = statusString.find(_.contains(s"$KUBECTL_PREFIX ")) + val kubeconfigString = statusString.find(_.contains(s"$KUBECONFIG_PREFIX ")) val hasConfigStatus = kubectlString.isDefined || kubeconfigString.isDefined if (hostString.isEmpty || kubeletString.isEmpty || apiserverString.isEmpty || !hasConfigStatus) { MinikubeStatus.NONE } else { - val status1 = hostString.get.split(HOST_PREFIX)(1) - val status2 = kubeletString.get.split(KUBELET_PREFIX)(1) - val status3 = apiserverString.get.split(APISERVER_PREFIX)(1) + val status1 = hostString.get.replaceFirst(s"$HOST_PREFIX ", "") + val status2 = kubeletString.get.replaceFirst(s"$KUBELET_PREFIX ", "") + val status3 = apiserverString.get.replaceFirst(s"$APISERVER_PREFIX ", "") val isConfigured = if (kubectlString.isDefined) { - val cfgStatus = kubectlString.get.split(KUBECTL_PREFIX)(1) + val cfgStatus = kubectlString.get.replaceFirst(s"$KUBECTL_PREFIX ", "") cfgStatus.contains("Correctly Configured:") } else { - kubeconfigString.get.split(KUBECONFIG_PREFIX)(1) == "Configured" + kubeconfigString.get.replaceFirst(s"$KUBECONFIG_PREFIX ", "") == "Configured" } if (isConfigured) { val stats = List(status1, status2, status3) From 5c7d019b609c87a9427fa9309f3aa03d02f61878 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 26 Feb 2021 15:28:57 +0800 Subject: [PATCH 33/60] [SPARK-34543][SQL] Respect the `spark.sql.caseSensitive` config while resolving partition spec in v1 `SET LOCATION` ### What changes were proposed in this pull request? Preprocess the partition spec passed to the V1 `ALTER TABLE .. 
SET LOCATION` implementation `AlterTableSetLocationCommand`, and normalize the passed spec according to the partition columns w.r.t the case sensitivity flag **spark.sql.caseSensitive**. ### Why are the changes needed? V1 `ALTER TABLE .. SET LOCATION` is case sensitive in fact, and doesn't respect the SQL config **spark.sql.caseSensitive** which is false by default, for instance: ```sql spark-sql> CREATE TABLE tbl (id INT, part INT) PARTITIONED BY (part); spark-sql> INSERT INTO tbl PARTITION (part=0) SELECT 0; spark-sql> SHOW TABLE EXTENDED LIKE 'tbl' PARTITION (part=0); Location: file:/Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/part=0 spark-sql> ALTER TABLE tbl ADD PARTITION (part=1); spark-sql> SELECT * FROM tbl; 0 0 ``` Create new partition folder in the file system: ``` $ cp -r /Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/part=0 /Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/aaa ``` Set new location for the partition part=1: ```sql spark-sql> ALTER TABLE tbl PARTITION (part=1) SET LOCATION '/Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/aaa'; spark-sql> SELECT * FROM tbl; 0 0 0 1 spark-sql> ALTER TABLE tbl ADD PARTITION (PART=2); spark-sql> SELECT * FROM tbl; 0 0 0 1 ``` Set location for a partition in the upper case: ``` $ cp -r /Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/part=0 /Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/bbb ``` ```sql spark-sql> ALTER TABLE tbl PARTITION (PART=2) SET LOCATION '/Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/bbb'; Error in query: Partition spec is invalid. The spec (PART) must match the partition spec (part) defined in table '`default`.`tbl`' ``` ### Does this PR introduce _any_ user-facing change? Yes. After the changes, the command above works as expected: ```sql spark-sql> ALTER TABLE tbl PARTITION (PART=2) SET LOCATION '/Users/maximgekk/proj/set-location-case-sense/spark-warehouse/tbl/bbb'; spark-sql> SELECT * FROM tbl; 0 0 0 1 0 2 ``` ### How was this patch tested? By running the modified test suite: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *CatalogedDDLSuite" ``` Closes #31651 from MaxGekk/set-location-case-sense. Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/execution/command/ddl.scala | 7 ++++++- .../apache/spark/sql/execution/command/DDLSuite.scala | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 2fc6d6fd85322..3e91dab6a5bea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -848,7 +848,12 @@ case class AlterTableSetLocationCommand( DDLUtils.verifyPartitionProviderIsHive( sparkSession, table, "ALTER TABLE ... 
SET LOCATION") // Partition spec is specified, so we set the location only for this partition - val part = catalog.getPartition(table.identifier, spec) + val normalizedSpec = PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionSchema, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + val part = catalog.getPartition(table.identifier, normalizedSpec) val newPart = part.copy(storage = part.storage.copy(locationUri = Some(locUri))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 7086c31082df7..19b5118a232ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1261,6 +1261,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // set table partition location sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") verifyLocation(new URI("/path/to/part/ways"), Some(partSpec)) + // set location for partition spec in the upper case + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("ALTER TABLE dbx.tab1 PARTITION (A='1', B='2') SET LOCATION '/path/to/part/ways2'") + verifyLocation(new URI("/path/to/part/ways2"), Some(partSpec)) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (A='1', B='2') SET LOCATION '/path/to/part/ways3'") + }.getMessage + assert(errMsg.contains("not a valid partition column")) + } // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") From 5b925319374b11fa30f6d00b9c2e92fbee3aa343 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 26 Feb 2021 20:19:33 +0900 Subject: [PATCH 34/60] [SPARK-34551][INFRA] Fix credit related scripts to recover, drop Python 2 and work with Python 3 ### What changes were proposed in this pull request? This PR proposes to make the scripts work by: - Recovering the credit-related scripts that were broken by https://github.com/apache/spark/pull/29563 (`raw_input` does not exist in `releaseutils` but only in Python 2) - Dropping Python 2 in these scripts because we dropped Python 2 in https://github.com/apache/spark/pull/28957 - Making these scripts work with Python 3 ### Why are the changes needed? To unblock the release. ### Does this PR introduce _any_ user-facing change? No, it's a dev-only change. ### How was this patch tested? I manually tested against Spark 3.1.1 RC3. Closes #31660 from HyukjinKwon/SPARK-34551. 
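For context, two of the Python 3 incompatibilities addressed in the diff below are the removal of `raw_input` (replaced by `input`) and the fact that `dict.keys()` no longer returns a list. A minimal sketch of the latter, using made-up data rather than anything from the patch itself:

```python
# Hypothetical contributor data, for illustration only.
author_info = {"carol": ["bug fixes"], "alice": ["docs"]}

# Python 2: author_info.keys() returned a list, so it could be sorted in place.
# Python 3: keys() returns a dict_keys view with no sort() method, hence the
# list(...) materialization before sorting in generate-contributors.py.
authors = list(author_info.keys())
authors.sort()
print(authors)  # ['alice', 'carol']
```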
Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- dev/create-release/generate-contributors.py | 8 +++---- dev/create-release/releaseutils.py | 15 +++--------- dev/create-release/translate-contributors.py | 24 ++++---------------- dev/requirements.txt | 1 - 4 files changed, 12 insertions(+), 36 deletions(-) diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 4e07bd79f8ac3..75965a6a26201 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -22,7 +22,7 @@ import re import sys -from releaseutils import tag_exists, raw_input, get_commits, yesOrNoPrompt, get_date, \ +from releaseutils import tag_exists, get_commits, yesOrNoPrompt, get_date, \ is_valid_author, capitalize_author, JIRA, find_components, translate_issue_type, \ translate_component, CORE_COMPONENT, contributors_file_name, nice_join @@ -33,10 +33,10 @@ # If the release tags are not provided, prompt the user to provide them while not tag_exists(RELEASE_TAG): - RELEASE_TAG = raw_input("Please provide a valid release tag: ") + RELEASE_TAG = input("Please provide a valid release tag: ") while not tag_exists(PREVIOUS_RELEASE_TAG): print("Please specify the previous release tag.") - PREVIOUS_RELEASE_TAG = raw_input( + PREVIOUS_RELEASE_TAG = input( "For instance, if you are releasing v1.2.0, you should specify v1.1.0: ") # Gather commits found in the new tag but not in the old tag. @@ -236,7 +236,7 @@ def populate(issue_type, components): # e.g. * Andrew Or -- Bug fixes in Windows, Core, and Web UI; improvements in Core # e.g. * Tathagata Das -- Bug fixes and new features in Streaming contributors_file = open(contributors_file_name, "w") -authors = author_info.keys() +authors = list(author_info.keys()) authors.sort() for author in authors: contribution = "" diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index a0e9695d58361..94e255bd440b8 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -42,13 +42,6 @@ print("Install using 'sudo pip install PyGithub'") sys.exit(-1) -try: - import unidecode -except ImportError: - print("This tool requires the unidecode library to decode obscure github usernames") - print("Install using 'sudo pip install unidecode'") - sys.exit(-1) - # Contributors list file name contributors_file_name = "contributors.txt" @@ -64,11 +57,11 @@ def yesOrNoPrompt(msg): # Utility functions run git commands (written with Git 1.8.5) def run_cmd(cmd): - return Popen(cmd, stdout=PIPE).communicate()[0] + return Popen(cmd, stdout=PIPE).communicate()[0].decode("utf8") def run_cmd_error(cmd): - return Popen(cmd, stdout=PIPE, stderr=PIPE).communicate()[1] + return Popen(cmd, stdout=PIPE, stderr=PIPE).communicate()[1].decode("utf8") def get_date(commit_hash): @@ -149,9 +142,7 @@ def get_commits(tag): # username so we can translate it properly later if not is_valid_author(author): author = github_username - # Guard against special characters - author = str(author) - author = unidecode.unidecode(author).strip() + author = author.strip() commit = Commit(_hash, author, title, pr_number) commits.append(commit) return commits diff --git a/dev/create-release/translate-contributors.py b/dev/create-release/translate-contributors.py index be5611ce65a7d..0736917de394d 100755 --- a/dev/create-release/translate-contributors.py +++ b/dev/create-release/translate-contributors.py @@ -32,14 +32,7 @@ import sys from releaseutils import JIRA, JIRAError, 
get_jira_name, Github, get_github_name, \ - contributors_file_name, is_valid_author, raw_input, capitalize_author, yesOrNoPrompt - -try: - import unidecode -except ImportError: - print("This tool requires the unidecode library to decode obscure github usernames") - print("Install using 'sudo pip install unidecode'") - sys.exit(-1) + contributors_file_name, is_valid_author, capitalize_author, yesOrNoPrompt # You must set the following before use! JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") @@ -139,15 +132,8 @@ def generate_candidates(author, issues): (NOT_FOUND, "No full name found for %s assignee %s" % (issue, user_name))) else: candidates.append((NOT_FOUND, "No assignee found for %s" % issue)) - # Guard against special characters in candidate names - # Note that the candidate name may already be in unicode (JIRA returns this) for i, (candidate, source) in enumerate(candidates): - try: - candidate = unicode(candidate, "UTF-8") # noqa: F821 - except TypeError: - # already in unicode - pass - candidate = unidecode.unidecode(candidate).strip() + candidate = candidate.strip() candidates[i] = (candidate, source) return candidates @@ -209,13 +195,13 @@ def generate_candidates(author, issues): if INTERACTIVE_MODE: print(" [%d] %s - Raw GitHub username" % (raw_index, author)) print(" [%d] Custom" % custom_index) - response = raw_input(" Your choice: ") + response = input(" Your choice: ") last_index = custom_index while not response.isdigit() or int(response) > last_index: - response = raw_input(" Please enter an integer between 0 and %d: " % last_index) + response = input(" Please enter an integer between 0 and %d: " % last_index) response = int(response) if response == custom_index: - new_author = raw_input(" Please type a custom name for this author: ") + new_author = input(" Please type a custom name for this author: ") elif response != raw_index: new_author = candidate_names[response] # In non-interactive mode, just pick the first candidate diff --git a/dev/requirements.txt b/dev/requirements.txt index c1546c8b8d4d3..ddb3e1729f03f 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -1,7 +1,6 @@ flake8==3.5.0 jira==1.0.3 PyGithub==1.26.0 -Unidecode==0.04.19 sphinx pydata_sphinx_theme ipython From ac774ec0c2cdc9a5d2e20e5f751ef2e753df352f Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 26 Feb 2021 20:20:31 +0900 Subject: [PATCH 35/60] [SPARK-34553][INFRA] Rename GITHUB_API_TOKEN to GITHUB_OAUTH_KEY in translate-contributors.py ### What changes were proposed in this pull request? This PR proposes to add an alias environment variable `GITHUB_OAUTH_KEY` for `GITHUB_API_TOKEN` in `translate-contributors.py` script. ### Why are the changes needed? ``` dev/github_jira_sync.py:GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") dev/github_jira_sync.py: request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) dev/github_jira_sync.py: request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) dev/merge_spark_pr.py:GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") dev/merge_spark_pr.py: if GITHUB_OAUTH_KEY: dev/merge_spark_pr.py: request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) dev/run-tests-jenkins.py: github_oauth_key = os.environ["GITHUB_OAUTH_KEY"] ``` Spark uses `GITHUB_OAUTH_KEY` for GitHub token, but `translate-contributors.py` script alone uses `GITHUB_API_TOKEN`. We should better match to make it easier to run the script ### Does this PR introduce _any_ user-facing change? No, it's dev-only. 
### How was this patch tested? I manually tested by running this script. Closes #31662 from HyukjinKwon/minor-gh-token-name. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- dev/create-release/translate-contributors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/create-release/translate-contributors.py b/dev/create-release/translate-contributors.py index 0736917de394d..6af975916ec49 100755 --- a/dev/create-release/translate-contributors.py +++ b/dev/create-release/translate-contributors.py @@ -38,11 +38,11 @@ JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") JIRA_USERNAME = os.environ.get("JIRA_USERNAME", None) JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", None) -GITHUB_API_TOKEN = os.environ.get("GITHUB_API_TOKEN", None) +GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY", os.environ.get("GITHUB_API_TOKEN", None)) if not JIRA_USERNAME or not JIRA_PASSWORD: sys.exit("Both JIRA_USERNAME and JIRA_PASSWORD must be set") -if not GITHUB_API_TOKEN: - sys.exit("GITHUB_API_TOKEN must be set") +if not GITHUB_OAUTH_KEY: + sys.exit("GITHUB_OAUTH_KEY must be set") # Write new contributors list to .final if not os.path.isfile(contributors_file_name): @@ -64,7 +64,7 @@ # Setup GitHub and JIRA clients jira_options = {"server": JIRA_API_BASE} jira_client = JIRA(options=jira_options, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) -github_client = Github(GITHUB_API_TOKEN) +github_client = Github(GITHUB_OAUTH_KEY) # Load known author translations that are cached locally known_translations = {} From 73857cdd87757d2888bd92f6b7c2fad709701484 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Feb 2021 11:44:42 +0000 Subject: [PATCH 36/60] [SPARK-34524][SQL] Simplify v2 partition commands resolution ### What changes were proposed in this pull request? This PR simplifies the resolution of v2 partition commands: 1. Add a common trait for v2 partition commands, so that we don't need to match them one by one in the rules. 2. Make partition spec an expression, so that it's easier to resolve them via tree node transformation. 3. Add `TruncatePartition` so that `TruncateTable` doesn't need to be a v2 partition command. 4. Simplify `CheckAnalysis` to only check if the table is partitioned. For partitioned tables, partition spec is always resolved, so we don't need to check it. The `SupportsAtomicPartitionManagement` check is also done in the runtime. Since Spark eagerly executes commands, exception in runtime will also be thrown at analysis time. ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #31637 from cloud-fan/simplify. 
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 63 ++++--------- .../analysis/ResolvePartitionSpec.scala | 92 +++++++------------ .../catalyst/analysis/v2ResolutionPlans.scala | 16 +++- .../sql/catalyst/parser/AstBuilder.scala | 9 +- .../catalyst/plans/logical/v2Commands.scala | 62 ++++++------- .../v2/DataSourceV2Implicits.scala | 10 +- .../analysis/ResolveSessionCatalog.scala | 7 +- .../datasources/v2/DataSourceV2Strategy.scala | 14 +-- .../v2/TruncatePartitionExec.scala | 52 +++++++++++ .../datasources/v2/TruncateTableExec.scala | 21 +---- .../command/ShowPartitionsSuiteBase.scala | 4 +- .../command/TruncateTableParserSuite.scala | 12 +-- .../v2/AlterTableAddPartitionSuite.scala | 2 +- .../v2/AlterTableDropPartitionSuite.scala | 2 +- .../command/v2/ShowPartitionsSuite.scala | 3 +- .../command/v2/TruncateTableSuite.scala | 2 +- 16 files changed, 188 insertions(+), 183 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 59e37e8a9bfaf..389bbb828da6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils} -import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table, TruncatableTable} +import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table} import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -150,6 +150,23 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case AlterTable(_, _, u: UnresolvedV2Relation, _) => failAnalysis(s"Table not found: ${u.originalNameParts.quoted}") + case command: V2PartitionCommand => + command.table match { + case r @ ResolvedTable(_, _, table, _) => table match { + case t: SupportsPartitionManagement => + if (t.partitionSchema.isEmpty) { + failAnalysis(s"Table ${r.name} is not partitioned.") + } + case _ => + failAnalysis(s"Table ${r.name} does not support partition management.") + } + case _ => + } + + // `ShowTableExtended` should have been converted to the v1 command if the table is v1. + case _: ShowTableExtended => + throw new AnalysisException("SHOW TABLE EXTENDED is not supported for v2 tables.") + case operator: LogicalPlan => // Check argument data types of higher-order functions downwards first. 
// If the arguments of the higher-order functions are resolved but the type check fails, @@ -565,19 +582,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // no validation needed for set and remove property } - case AddPartitions(r: ResolvedTable, parts, _) => - checkAlterTablePartition(r.table, parts) - - case DropPartitions(r: ResolvedTable, parts, _, _) => - checkAlterTablePartition(r.table, parts) - - case RenamePartitions(r: ResolvedTable, from, _) => - checkAlterTablePartition(r.table, Seq(from)) - - case showPartitions: ShowPartitions => checkShowPartitions(showPartitions) - - case truncateTable: TruncateTable => checkTruncateTable(truncateTable) - case _ => // Falls back to the following checks } @@ -1009,35 +1013,4 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case _ => } } - - // Make sure that the `SHOW PARTITIONS` command is allowed for the table - private def checkShowPartitions(showPartitions: ShowPartitions): Unit = showPartitions match { - case ShowPartitions(rt: ResolvedTable, _, _) - if !rt.table.isInstanceOf[SupportsPartitionManagement] => - failAnalysis("SHOW PARTITIONS cannot run for a table which does not support partitioning") - case ShowPartitions(ResolvedTable(_, _, partTable: SupportsPartitionManagement, _), _, _) - if partTable.partitionSchema().isEmpty => - failAnalysis( - s"SHOW PARTITIONS is not allowed on a table that is not partitioned: ${partTable.name()}") - case _ => - } - - private def checkTruncateTable(truncateTable: TruncateTable): Unit = truncateTable match { - case TruncateTable(rt: ResolvedTable, None) if !rt.table.isInstanceOf[TruncatableTable] => - failAnalysis(s"The table ${rt.table.name()} does not support truncation") - case TruncateTable(rt: ResolvedTable, Some(_)) - if !rt.table.isInstanceOf[SupportsPartitionManagement] => - failAnalysis("TRUNCATE TABLE cannot run for a table which does not support partitioning") - case TruncateTable( - ResolvedTable(_, _, _: SupportsPartitionManagement, _), - Some(_: UnresolvedPartitionSpec)) => - failAnalysis("Partition spec is not resolved") - case TruncateTable( - ResolvedTable(_, _, table: SupportsPartitionManagement, _), - Some(spec: ResolvedPartitionSpec)) - if spec.names.length < table.partitionSchema.length && - !table.isInstanceOf[SupportsAtomicPartitionManagement] => - failAnalysis(s"The table ${table.name()} does not support truncation of multiple partitions") - case _ => - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index e68c9793fa6a0..79b7b0c5ba35e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AddPartitions, DropPartitions, LogicalPlan, RenamePartitions, ShowPartitions, TruncateTable} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2PartitionCommand} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CharVarcharUtils import 
org.apache.spark.sql.connector.catalog.SupportsPartitionManagement @@ -33,70 +33,42 @@ import org.apache.spark.sql.util.PartitioningUtils.{normalizePartitionSpec, requ object ResolvePartitionSpec extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case r @ AddPartitions( - ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs, _) => - val partitionSchema = table.partitionSchema() - r.copy(parts = resolvePartitionSpecs( - table.name, - partSpecs, - partitionSchema, - requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames))) - - case r @ DropPartitions( - ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs, _, _) => - val partitionSchema = table.partitionSchema() - r.copy(parts = resolvePartitionSpecs( - table.name, - partSpecs, - partitionSchema, - requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames))) - - case r @ RenamePartitions( - ResolvedTable(_, _, table: SupportsPartitionManagement, _), from, to) => - val partitionSchema = table.partitionSchema() - val Seq(resolvedFrom, resolvedTo) = resolvePartitionSpecs( - table.name, - Seq(from, to), - partitionSchema, - requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames)) - r.copy(from = resolvedFrom, to = resolvedTo) - - case r @ ShowPartitions( - ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs, _) => - r.copy(pattern = resolvePartitionSpecs( - table.name, - partSpecs.toSeq, - table.partitionSchema()).headOption) - - case r @ TruncateTable(ResolvedTable(_, _, table: SupportsPartitionManagement, _), partSpecs) => - r.copy(partitionSpec = resolvePartitionSpecs( - table.name, - partSpecs.toSeq, - table.partitionSchema()).headOption) + case command: V2PartitionCommand if command.childrenResolved && !command.resolved => + command.table match { + case r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _) => + command.transformExpressions { + case partSpecs: UnresolvedPartitionSpec => + val partitionSchema = table.partitionSchema() + resolvePartitionSpec( + r.name, + partSpecs, + partitionSchema, + command.allowPartialPartitionSpec) + } + case _ => command + } } - private def resolvePartitionSpecs( + private def resolvePartitionSpec( tableName: String, - partSpecs: Seq[PartitionSpec], + partSpec: UnresolvedPartitionSpec, partSchema: StructType, - checkSpec: TablePartitionSpec => Unit = _ => ()): Seq[ResolvedPartitionSpec] = - partSpecs.map { - case unresolvedPartSpec: UnresolvedPartitionSpec => - val normalizedSpec = normalizePartitionSpec( - unresolvedPartSpec.spec, - partSchema, - tableName, - conf.resolver) - checkSpec(normalizedSpec) - val partitionNames = normalizedSpec.keySet - val requestedFields = partSchema.filter(field => partitionNames.contains(field.name)) - ResolvedPartitionSpec( - requestedFields.map(_.name), - convertToPartIdent(normalizedSpec, requestedFields), - unresolvedPartSpec.location) - case resolvedPartitionSpec: ResolvedPartitionSpec => - resolvedPartitionSpec + allowPartitionSpec: Boolean): ResolvedPartitionSpec = { + val normalizedSpec = normalizePartitionSpec( + partSpec.spec, + partSchema, + tableName, + conf.resolver) + if (!allowPartitionSpec) { + requireExactMatchedPartitionSpec(tableName, normalizedSpec, partSchema.fieldNames) } + val partitionNames = normalizedSpec.keySet + val requestedFields = partSchema.filter(field => partitionNames.contains(field.name)) + ResolvedPartitionSpec( + requestedFields.map(_.name), + 
convertToPartIdent(normalizedSpec, requestedFields), + partSpec.location) + } private[sql] def convertToPartIdent( partitionSpec: TablePartitionSpec, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index f7e08bdb73ec0..b50c306805436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.types.DataType /** * Holds the name of a namespace that has yet to be looked up in a catalog. It will be resolved to @@ -73,11 +75,18 @@ case class UnresolvedTableOrView( override def output: Seq[Attribute] = Nil } -sealed trait PartitionSpec +sealed trait PartitionSpec extends LeafExpression with Unevaluable { + override def dataType: DataType = throw new IllegalStateException( + "PartitionSpec.dataType should not be called.") + override def nullable: Boolean = throw new IllegalStateException( + "PartitionSpec.nullable should not be called.") +} case class UnresolvedPartitionSpec( spec: TablePartitionSpec, - location: Option[String] = None) extends PartitionSpec + location: Option[String] = None) extends PartitionSpec { + override lazy val resolved = false +} /** * Holds the name of a function that has yet to be looked up in a catalog. 
It will be resolved to @@ -109,6 +118,7 @@ case class ResolvedTable( val qualifier = catalog.name +: identifier.namespace :+ identifier.name outputAttributes.map(_.withQualifier(qualifier)) } + def name: String = (catalog.name +: identifier.namespace() :+ identifier.name()).quoted } object ResolvedTable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 25e6cbeaa524c..a43d28b045d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3757,11 +3757,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { - TruncateTable( - createUnresolvedTable(ctx.multipartIdentifier, "TRUNCATE TABLE"), - Option(ctx.partitionSpec).map { spec => - UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(spec)) - }) + val table = createUnresolvedTable(ctx.multipartIdentifier, "TRUNCATE TABLE") + Option(ctx.partitionSpec).map { spec => + TruncatePartition(table, UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(spec))) + }.getOrElse(TruncateTable(table)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index ea67c5571ec9b..847d7ae0117e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, ResolvedPartitionSpec, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, UnresolvedException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema @@ -59,6 +59,11 @@ trait V2WriteCommand extends Command { def withNewTable(newTable: NamedRelation): V2WriteCommand } +trait V2PartitionCommand extends Command { + def table: LogicalPlan + def allowPartialPartitionSpec: Boolean = false +} + /** * Append data to an existing table. 
*/ @@ -677,13 +682,10 @@ case class AnalyzeColumn( * }}} */ case class AddPartitions( - child: LogicalPlan, + table: LogicalPlan, parts: Seq[PartitionSpec], - ifNotExists: Boolean) extends Command { - override lazy val resolved: Boolean = - childrenResolved && parts.forall(_.isInstanceOf[ResolvedPartitionSpec]) - - override def children: Seq[LogicalPlan] = child :: Nil + ifNotExists: Boolean) extends V2PartitionCommand { + override def children: Seq[LogicalPlan] = table :: Nil } /** @@ -699,29 +701,21 @@ case class AddPartitions( * }}} */ case class DropPartitions( - child: LogicalPlan, + table: LogicalPlan, parts: Seq[PartitionSpec], ifExists: Boolean, - purge: Boolean) extends Command { - override lazy val resolved: Boolean = - childrenResolved && parts.forall(_.isInstanceOf[ResolvedPartitionSpec]) - - override def children: Seq[LogicalPlan] = child :: Nil + purge: Boolean) extends V2PartitionCommand { + override def children: Seq[LogicalPlan] = table :: Nil } /** * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command. */ case class RenamePartitions( - child: LogicalPlan, + table: LogicalPlan, from: PartitionSpec, - to: PartitionSpec) extends Command { - override lazy val resolved: Boolean = - childrenResolved && - from.isInstanceOf[ResolvedPartitionSpec] && - to.isInstanceOf[ResolvedPartitionSpec] - - override def children: Seq[LogicalPlan] = child :: Nil + to: PartitionSpec) extends V2PartitionCommand { + override def children: Seq[LogicalPlan] = table :: Nil } /** @@ -767,23 +761,29 @@ object ShowColumns { /** * The logical plan of the TRUNCATE TABLE command. */ -case class TruncateTable( - child: LogicalPlan, - partitionSpec: Option[PartitionSpec]) extends Command { - override def children: Seq[LogicalPlan] = child :: Nil +case class TruncateTable(table: LogicalPlan) extends Command { + override def children: Seq[LogicalPlan] = table :: Nil +} + +/** + * The logical plan of the TRUNCATE TABLE ... PARTITION command. + */ +case class TruncatePartition( + table: LogicalPlan, + partitionSpec: PartitionSpec) extends V2PartitionCommand { + override def children: Seq[LogicalPlan] = table :: Nil + override def allowPartialPartitionSpec: Boolean = true } /** * The logical plan of the SHOW PARTITIONS command. 
*/ case class ShowPartitions( - child: LogicalPlan, + table: LogicalPlan, pattern: Option[PartitionSpec], - override val output: Seq[Attribute] = ShowPartitions.OUTPUT) extends Command { - override def children: Seq[LogicalPlan] = child :: Nil - - override lazy val resolved: Boolean = - childrenResolved && pattern.forall(_.isInstanceOf[ResolvedPartitionSpec]) + override val output: Seq[Attribute] = ShowPartitions.OUTPUT) extends V2PartitionCommand { + override def children: Seq[LogicalPlan] = table :: Nil + override def allowPartialPartitionSpec: Boolean = true } object ShowPartitions { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 4326c730f88dd..daa2c0468370e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionSpec, UnresolvedPartitionSpec} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability, TruncatableTable} import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -57,6 +57,14 @@ object DataSourceV2Implicits { } } + def asTruncatable: TruncatableTable = { + table match { + case t: TruncatableTable => t + case _ => + throw new AnalysisException(s"Table does not support truncates: ${table.name}") + } + } + def asPartitionable: SupportsPartitionManagement = { table match { case support: SupportsPartitionManagement => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index dde31f62e06b9..290833d6a41ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -394,10 +394,13 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) ShowCreateTableCommand(ident.asTableIdentifier) } - case TruncateTable(ResolvedV1TableIdentifier(ident), partitionSpec) => + case TruncateTable(ResolvedV1TableIdentifier(ident)) => + TruncateTableCommand(ident.asTableIdentifier, None) + + case TruncatePartition(ResolvedV1TableIdentifier(ident), partitionSpec) => TruncateTableCommand( ident.asTableIdentifier, - partitionSpec.toSeq.asUnresolvedPartitionSpecs.map(_.spec).headOption) + Seq(partitionSpec).asUnresolvedPartitionSpecs.map(_.spec).headOption) case s @ ShowPartitions( ResolvedV1TableOrViewIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index a5b092a1aa491..135de2ad4c5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -337,9 +337,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ShowTables(ResolvedNamespace(catalog, ns), pattern, output) => ShowTablesExec(output, catalog.asTableCatalog, ns, pattern) :: Nil - case _: ShowTableExtended => - throw new AnalysisException("SHOW TABLE EXTENDED is not supported for v2 tables.") - case SetCatalogAndNamespace(catalogManager, catalogName, ns) => SetCatalogAndNamespaceExec(catalogManager, catalogName, ns) :: Nil @@ -394,10 +391,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ShowCreateTable(_: ResolvedTable, _) => throw new AnalysisException("SHOW CREATE TABLE is not supported for v2 tables.") - case TruncateTable(r: ResolvedTable, parts) => + case TruncateTable(r: ResolvedTable) => TruncateTableExec( - r.table, - parts.toSeq.asResolvedPartitionSpecs.headOption, + r.table.asTruncatable, + recacheTable(r)) :: Nil + + case TruncatePartition(r: ResolvedTable, part) => + TruncatePartitionExec( + r.table.asPartitionable, + Seq(part).asResolvedPartitionSpecs.head, recacheTable(r)) :: Nil case ShowColumns(_: ResolvedTable, _, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala new file mode 100644 index 0000000000000..135005b64973d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolvedPartitionSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement} + +/** + * Physical plan node for table partition truncation. 
+ */ +case class TruncatePartitionExec( + table: SupportsPartitionManagement, + partSpec: ResolvedPartitionSpec, + refreshCache: () => Unit) extends V2CommandExec { + + override def output: Seq[Attribute] = Seq.empty + + override protected def run(): Seq[InternalRow] = { + val isTableAltered = if (table.partitionSchema.length != partSpec.names.length) { + table match { + case atomicPartTable: SupportsAtomicPartitionManagement => + val partitionIdentifiers = atomicPartTable.listPartitionIdentifiers( + partSpec.names.toArray, partSpec.ident) + atomicPartTable.truncatePartitions(partitionIdentifiers) + case _ => + throw new UnsupportedOperationException( + s"The table ${table.name()} does not support truncation of multiple partition.") + } + } else { + table.truncatePartition(partSpec.ident) + } + if (isTableAltered) refreshCache() + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala index 17f86e26074a4..69261b3084776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala @@ -18,35 +18,20 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.ResolvedPartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table, TruncatableTable} +import org.apache.spark.sql.connector.catalog.TruncatableTable /** * Physical plan node for table truncation. 
*/ case class TruncateTableExec( - table: Table, - partSpecs: Option[ResolvedPartitionSpec], + table: TruncatableTable, refreshCache: () => Unit) extends V2CommandExec { override def output: Seq[Attribute] = Seq.empty override protected def run(): Seq[InternalRow] = { - val isTableAltered = (table, partSpecs) match { - case (truncatableTable: TruncatableTable, None) => - truncatableTable.truncateTable() - case (partTable: SupportsPartitionManagement, Some(resolvedPartSpec)) - if partTable.partitionSchema.length == resolvedPartSpec.names.length => - partTable.truncatePartition(resolvedPartSpec.ident) - case (atomicPartTable: SupportsAtomicPartitionManagement, Some(resolvedPartitionSpec)) => - val partitionIdentifiers = atomicPartTable.listPartitionIdentifiers( - resolvedPartitionSpec.names.toArray, resolvedPartitionSpec.ident) - atomicPartTable.truncatePartitions(partitionIdentifiers) - case _ => throw new IllegalArgumentException( - s"Truncation of ${table.getClass.getName} is not supported") - } - if (isTableAltered) refreshCache() + if (table.truncateTable()) refreshCache() Seq.empty } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsSuiteBase.scala index 29edb8fb51cf8..27d2eb9854302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsSuiteBase.scala @@ -69,7 +69,9 @@ trait ShowPartitionsSuiteBase extends QueryTest with DDLCommandTestUtils { val errMsg = intercept[AnalysisException] { sql(s"SHOW PARTITIONS $t") }.getMessage - assert(errMsg.contains("not allowed on a table that is not partitioned")) + assert(errMsg.contains("not allowed on a table that is not partitioned") || + // V2 error message. 
+ errMsg.contains(s"Table $t is not partitioned")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala index 39531c84a63d0..7f4a48023c16e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala @@ -20,30 +20,30 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedPartitionSpec, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.TruncateTable +import org.apache.spark.sql.catalyst.plans.logical.{TruncatePartition, TruncateTable} import org.apache.spark.sql.test.SharedSparkSession class TruncateTableParserSuite extends AnalysisTest with SharedSparkSession { test("truncate table") { comparePlans( parsePlan("TRUNCATE TABLE a.b.c"), - TruncateTable(UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE", None), None)) + TruncateTable(UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE", None))) } test("truncate a single part partition") { comparePlans( parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"), - TruncateTable( + TruncatePartition( UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE", None), - Some(UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"), None)))) + UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"), None))) } test("truncate a multi parts partition") { comparePlans( parsePlan("TRUNCATE TABLE ns.tbl PARTITION(a = 1, B = 'ABC')"), - TruncateTable( + TruncatePartition( UnresolvedTable(Seq("ns", "tbl"), "TRUNCATE TABLE", None), - Some(UnresolvedPartitionSpec(Map("a" -> "1", "B" -> "ABC"), None)))) + UnresolvedPartitionSpec(Map("a" -> "1", "B" -> "ABC"), None))) } test("empty values in non-optional partition specs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala index a33eb0e4628bc..fabe399c340ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala @@ -33,7 +33,7 @@ class AlterTableAddPartitionSuite val errMsg = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD PARTITION (id=1)") }.getMessage - assert(errMsg.contains(s"Table $t can not alter partitions")) + assert(errMsg.contains(s"Table $t does not support partition management")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala index 3515fa3390206..b03c8fb17f542 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala @@ -36,7 +36,7 @@ class AlterTableDropPartitionSuite val errMsg = intercept[AnalysisException] { sql(s"ALTER TABLE $t DROP PARTITION (id=1)") }.getMessage - assert(errMsg.contains("can not alter partitions")) + 
assert(errMsg.contains(s"Table $t does not support partition management")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala index 42f05ee55504a..8ae8171924c69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala @@ -33,8 +33,7 @@ class ShowPartitionsSuite extends command.ShowPartitionsSuiteBase with CommandSu val errMsg = intercept[AnalysisException] { sql(s"SHOW PARTITIONS $table") }.getMessage - assert(errMsg.contains( - "SHOW PARTITIONS cannot run for a table which does not support partitioning")) + assert(errMsg.contains(s"Table $table does not support partition management")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala index 1e14a080bf042..f125a72bd32a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala @@ -36,7 +36,7 @@ class TruncateTableSuite extends command.TruncateTableSuiteBase with CommandSuit sql(s"TRUNCATE TABLE $t PARTITION (c0=1)") }.getMessage assert(errMsg.contains( - "TRUNCATE TABLE cannot run for a table which does not support partitioning")) + "Table non_part_test_catalog.ns.tbl does not support partition management")) } } } From 7d5021f5eed2b9c48bd02b92cce1535edc46d0e4 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 26 Feb 2021 11:46:27 +0000 Subject: [PATCH 37/60] [SPARK-34533][SQL] Eliminate LEFT ANTI join to empty relation in AQE ### What changes were proposed in this pull request? From the review discussion in https://github.com/apache/spark/pull/31630#discussion_r581774000, I discovered that we can eliminate a LEFT ANTI join (with no join condition) to an empty relation if the right side is known to be non-empty. So with AQE, this is doable similarly to https://github.com/apache/spark/pull/29484. ### Why are the changes needed? This can help eliminate the join operator during logical plan optimization. Before this PR, [the left side's physical plan `execute()` would be called](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala#L192), so if the left side is complicated (e.g. contains a broadcast exchange operator), some computation would happen. However, after this PR the join operator is removed during logical plan optimization, and nothing is computed from the left side. Potentially this can save resources for these kinds of queries. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests for positive and negative queries in `AdaptiveQueryExecSuite.scala`. Closes #31641 from c21/left-anti-aqe. 
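As a rough, hypothetical illustration of the query shape this targets (table names are made up; the real coverage is the `AdaptiveQueryExecSuite` test in the diff below):

```sql
-- With no join condition, a LEFT ANTI join keeps only the rows of t1 that
-- have no "match" in t2; since every row trivially matches when there is no
-- condition, a non-empty t2 makes the result empty. Once AQE has run the
-- broadcast stage for t2 and observes rowCount > 0, the whole join can be
-- replaced by an empty LocalRelation and the left side is never executed.
SELECT /*+ BROADCAST(t2) */ * FROM t1 LEFT ANTI JOIN t2;
```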
Authored-by: Cheng Su Signed-off-by: Wenchen Fan --- .../EliminateJoinToEmptyRelation.scala | 16 +++++++++++++- .../adaptive/AdaptiveQueryExecSuite.scala | 22 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala index cfdd20ec7565d..d6df52278e079 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin -import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} +import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys} @@ -33,6 +33,8 @@ import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation * This applies to all Joins (sort merge join, shuffled hash join, and broadcast hash join), * because sort merge join and shuffled hash join will be changed to broadcast hash join with AQE * at the first place. + * + * 3. Join is left anti join without condition, and join right side is non-empty. */ object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] { @@ -53,5 +55,17 @@ object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] { case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) => LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) + + case j @ Join(_, _, LeftAnti, None, _) => + val isNonEmptyRightSide = j.right match { + case LogicalQueryStage(_, stage: QueryStageExec) if stage.resultOption.get().isDefined => + stage.getRuntimeStatistics.rowCount.exists(_ > 0) + case _ => false + } + if (isNonEmptyRightSide) { + LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) + } else { + j + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 122bc2d1e59a6..d7a1d5d26eae1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1230,6 +1230,28 @@ class AdaptiveQueryExecSuite } } + test("SPARK-34533: Eliminate left anti join to empty relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withTable("emptyTestData") { + spark.range(0).write.saveAsTable("emptyTestData") + Seq( + // broadcast non-empty right side + ("SELECT /*+ broadcast(testData3) */ * FROM testData LEFT ANTI JOIN testData3", true), + // broadcast empty right side + ("SELECT /*+ broadcast(emptyTestData) */ * FROM testData LEFT ANTI JOIN emptyTestData", + false), + // broadcast left side + ("SELECT /*+ broadcast(testData) */ * FROM testData LEFT ANTI JOIN testData3", false) + ).foreach { case (query, isEliminated) => + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(findTopLevelBaseJoin(plan).size == 1) 
+ assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated) + } + } + } + } + test("SPARK-32753: Only copy tags to node with no tags") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { withTempView("v1") { From 82267acfe8c78a70d56a6ae6ab9a1135c0dc0836 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 26 Feb 2021 21:29:14 +0900 Subject: [PATCH 38/60] [SPARK-34550][SQL] Skip InSet null value during push filter to Hive metastore ### What changes were proposed in this pull request? Skip `InSet` null value during push filter to Hive metastore. ### Why are the changes needed? If `InSet` contains a null value, we should skip it and push other values to metastore. To keep same behavior with `In`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Add test. Closes #31659 from ulysses-you/SPARK-34550. Authored-by: ulysses-you Signed-off-by: HyukjinKwon --- .../apache/spark/sql/hive/client/HiveShim.scala | 4 ++-- .../spark/sql/hive/client/FiltersSuite.scala | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8ccb17ce35925..db67480ceb77a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -700,7 +700,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - val extractables = values.toSeq.map(valueToLiteralString.lift) + val extractables = values.filter(_ != null).toSeq.map(valueToLiteralString.lift) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { @@ -715,7 +715,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - val extractables = values.toSeq.map(valueToLiteralString.lift) + val extractables = values.filter(_ != null).toSeq.map(valueToLiteralString.lift) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 6962f9dd6b186..79b34bd141de3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -187,5 +187,20 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { } } + test("SPARK-34538: Skip InSet null value during push filter to Hive metastore") { + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "3") { + val intFilter = InSet(a("p", IntegerType), Set(null, 1, 2)) + val intConverted = shim.convertFilters(testTable, Seq(intFilter), conf.sessionLocalTimeZone) + assert(intConverted == "(p = 1 or p = 2)") + } + + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "3") { + val dateFilter = InSet(a("p", DateType), Set(null, + Literal(Date.valueOf("2020-01-01")).eval(), Literal(Date.valueOf("2021-01-01")).eval())) + val dateConverted = shim.convertFilters(testTable, Seq(dateFilter), conf.sessionLocalTimeZone) + assert(dateConverted == "(p = 2020-01-01 or p = 2021-01-01)") + } + } + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() } 
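To make the SPARK-34550 change above concrete, here is a hypothetical partition-pruning example (table and column names are made up): when a predicate on a partition column reaches Hive metastore filter push-down as an `InSet` containing a `NULL`, the null entry is now skipped instead of blocking the push-down of the whole predicate, as exercised by the `FiltersSuite` test above.

```sql
-- Hypothetical partitioned table; p is a partition column, and the optimizer
-- is assumed to have turned the IN list into an InSet containing a null.
-- Before this change the null made the whole filter non-convertible, so no
-- partition predicate was pushed to the Hive metastore; now the null is
-- skipped and "(p = 1 or p = 2 or p = 3)" is still pushed, matching the
-- behavior of a plain In filter.
SELECT * FROM partitioned_tbl WHERE p IN (1, 2, 3, NULL);
```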
From a9e8e0528a52d19103463bae0a9420127a99bf59 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 Feb 2021 21:30:24 +0900 Subject: [PATCH 39/60] [SPARK-34549][BUILD] Upgrade aws kinesis to 1.14.0 and java sdk 1.11.844 ### What changes were proposed in this pull request? This patch upgrades the AWS Kinesis client and AWS Java SDK versions. ### Why are the changes needed? Upgrade the AWS Kinesis client and Java SDK to catch up with the minimum requirements for new features such as IAM roles for service accounts: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-minimum-sdk.html ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests. Closes #31658 from viirya/upgrade-aws-sdk. Authored-by: Liang-Chi Hsieh Signed-off-by: HyukjinKwon --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index b7b9f8d826dbb..39f2792d29d87 100644 --- a/pom.xml +++ b/pom.xml @@ -149,9 +149,9 @@ --> 4.1.1 1.10.1 - 1.12.0 + 1.14.0 - 1.11.655 + 1.11.844 0.12.8 From c1beb16cc8db9f61f1b86b5bfa4cd4d603c9b990 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 26 Feb 2021 21:33:14 +0900 Subject: [PATCH 40/60] [SPARK-34554][SQL] Implement the copy() method in ColumnarMap ### What changes were proposed in this pull request? Implement `ColumnarMap.copy()` by using the `copy()` method of `ColumnarArray`. ### Why are the changes needed? To eliminate `java.lang.UnsupportedOperationException` while using `ColumnarMap`. ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? By running new tests in `ColumnarBatchSuite`. Closes #31663 from MaxGekk/columnar-map-copy. Authored-by: Max Gekk Signed-off-by: HyukjinKwon --- .../org/apache/spark/sql/vectorized/ColumnarMap.java | 5 +++-- .../execution/vectorized/ColumnarBatchSuite.scala | 12 +++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java index 35648e386c4f1..6b3d518746dc3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.vectorized; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; import org.apache.spark.sql.catalyst.util.MapData; /** @@ -47,7 +48,7 @@ public ColumnarArray valueArray() { } @Override - public ColumnarMap copy() { - throw new UnsupportedOperationException(); + public MapData copy() { + return new ArrayBasedMapData(keys.copy(), values.copy()); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index a369b2d6900f9..bd69bab6f5da2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, DateTimeUtils, 
GenericArrayData, MapData} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -896,6 +896,16 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(a4.asScala == Map()) assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + def toScalaMap(mapData: MapData): Map[Int, Int] = { + val keys = mapData.keyArray().toSeq[Int](IntegerType) + val values = mapData.valueArray().toSeq[Int](IntegerType) + (keys zip values).toMap + } + assert(toScalaMap(column.getMap(0).copy()) === Map(0 -> 0)) + assert(toScalaMap(column.getMap(1).copy()) === Map(1 -> 2, 2 -> 4)) + assert(toScalaMap(column.getMap(3).copy()) === Map()) + assert(toScalaMap(column.getMap(4).copy()) === Map(3 -> 6, 4 -> 8, 5 -> 10)) + column.close() } } From 67ec4f7f67dc494c2619b7faf1b1145f2200b65c Mon Sep 17 00:00:00 2001 From: "tanel.kiis@gmail.com" Date: Fri, 26 Feb 2021 21:59:02 +0900 Subject: [PATCH 41/60] [SPARK-33971][SQL] Eliminate distinct from more aggregates ### What changes were proposed in this pull request? Add more aggregate expressions to `EliminateDistinct` rule. ### Why are the changes needed? Distinct aggregation can add a significant overhead. It's better to remove distinct whenever possible. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #30999 from tanelk/SPARK-33971_eliminate_distinct. Authored-by: tanel.kiis@gmail.com Signed-off-by: Takeshi Yamamuro --- .../sql/catalyst/optimizer/Optimizer.scala | 16 +++++--- .../optimizer/EliminateDistinctSuite.scala | 41 ++++++++++--------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 717770f9fa1be..cb24180c57842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -352,11 +352,17 @@ abstract class Optimizer(catalogManager: CatalogManager) */ object EliminateDistinct extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions { - case ae: AggregateExpression if ae.isDistinct => - ae.aggregateFunction match { - case _: Max | _: Min => ae.copy(isDistinct = false) - case _ => ae - } + case ae: AggregateExpression if ae.isDistinct && isDuplicateAgnostic(ae.aggregateFunction) => + ae.copy(isDistinct = false) + } + + private def isDuplicateAgnostic(af: AggregateFunction): Boolean = af match { + case _: Max => true + case _: Min => true + case _: BitAndAgg => true + case _: BitOrAgg => true + case _: CollectSet => true + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala index 51c751923e414..0848d5609ff02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate._ 
import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -32,25 +34,24 @@ class EliminateDistinctSuite extends PlanTest { val testRelation = LocalRelation('a.int) - test("Eliminate Distinct in Max") { - val query = testRelation - .select(maxDistinct('a).as('result)) - .analyze - val answer = testRelation - .select(max('a).as('result)) - .analyze - assert(query != answer) - comparePlans(Optimize.execute(query), answer) - } - - test("Eliminate Distinct in Min") { - val query = testRelation - .select(minDistinct('a).as('result)) - .analyze - val answer = testRelation - .select(min('a).as('result)) - .analyze - assert(query != answer) - comparePlans(Optimize.execute(query), answer) + Seq( + Max(_), + Min(_), + BitAndAgg(_), + BitOrAgg(_), + CollectSet(_: Expression) + ).foreach { + aggBuilder => + val agg = aggBuilder('a) + test(s"Eliminate Distinct in ${agg.prettyName}") { + val query = testRelation + .select(agg.toAggregateExpression(isDistinct = true).as('result)) + .analyze + val answer = testRelation + .select(agg.toAggregateExpression(isDistinct = false).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } } } From 8d68f3f74658c7e0c12ee1d6f09a1aae14d9e04f Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Sat, 27 Feb 2021 01:02:11 +0900 Subject: [PATCH 42/60] [MINOR] Add more known translations of contributors ### What changes were proposed in this pull request? This PR adds some more known translations of contributors who contributed multiple times in Spark 3.1.1. ### Why are the changes needed? To make release process easier. ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? N/A (auto-generated) Closes #31665 from HyukjinKwon/minor-add-known-translations. 
Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- dev/create-release/known_translations | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 64bd9ada1bf61..3b599b98c199c 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -25,6 +25,7 @@ GenTang - Gen TANG GraceH - Jie Huang Gschiavon - German Schiavon Matteo GuoPhilipse - Philipse Guo +HeartSaVioR - Jungtaek Lim Hellsen83 - Erik Christiansen HyukjinKwon - Hyukjin Kwon Icysandwich - Icysandwich @@ -66,6 +67,7 @@ SaintBacchus - Huang Zhaowei Sephiroth-Lin - Sephiroth Lin Shiti - Shiti Saxena SongYadong - Yadong Song +TJX2014 - Jinxin Tang TigerYang414 - David Yang TomokoKomiyama - Tomoko Komiyama TopGunViper - TopGunViper @@ -89,6 +91,7 @@ ajithme - Ajith S akonopko - Alexander Konopko alexdebrie - Alex DeBrie alicegugu - Gu Huiqin Alice +allisonwang-db - Allison Wang alokito - Alok Saldanha alyaxey - Alex Slusarenko amanomer - Aman Omer @@ -96,6 +99,7 @@ ameyc - Amey Chaugule anabranch - Bill Chambers anantasty - Anant Asthana ancasarb - Anca Sarb +anchovYu - Xinyi Yu andrewor14 - Andrew Or aniketbhatnagar - Aniket Bhatnagar animeshbaranawal - Animesh Baranawal @@ -204,6 +208,7 @@ igorcalabria - Igor Calabria imback82 - Terry Kim industrial-sloth - Jascha Swisher invkrh - Hao Ren +itholic - Haejoon Lee ivoson - Tengfei Huang jackylk - Jacky Li jagadeesanas2 - Jagadeesan A S @@ -234,6 +239,7 @@ laskfla - Keith Sun lazyman500 - Dong Xu lcqzte10192193 - Chaoqun Li leahmcguire - Leah McGuire +leanken - Leanken Lin lee19 - Lee leoluan2009 - Xuedong Luan liangxs - Xuesen Liang @@ -357,6 +363,7 @@ surq - Surong Quan suxingfate - Xinglong Wang suyanNone - Su Yan szheng79 - Shuai Zheng +tanelk - Tanel Kiis tedyu - Ted Yu teeyog - Yong Tian texasmichelle - Michelle Casbon @@ -383,6 +390,7 @@ watermen - Yadong Qi weixiuli - XiuLi Wei wenfang6 - wenfang6 wenxuanguan - wenxuanguan +williamhyun - William Hyun windpiger - Song Jun witgo - Guoqiang Li woudygao - Woudy Gao From 56e664c7179eadeb5134b4418f3aaa6a9d742ef6 Mon Sep 17 00:00:00 2001 From: ShiKai Wang Date: Fri, 26 Feb 2021 11:03:20 -0600 Subject: [PATCH 43/60] [SPARK-34392][SQL] Support ZoneOffset +h:mm in DateTimeUtils. getZoneId ### What changes were proposed in this pull request? To support +8:00 in Spark3 when execute sql `select to_utc_timestamp("2020-02-07 16:00:00", "GMT+8:00")` ### Why are the changes needed? +8:00 this format is supported in PostgreSQL,hive, presto, but not supported in Spark3 https://issues.apache.org/jira/browse/SPARK-34392 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? unit test Closes #31624 from Karl-WangSK/zone. 
Lead-authored-by: ShiKai Wang Co-authored-by: Karl-WangSK Signed-off-by: Sean Owen --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 5 ++++- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++++++ .../apache/spark/sql/internal/SQLConfSuite.scala | 5 ++--- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index b4f12db439f7f..2ffccdd06e504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -50,7 +50,10 @@ object DateTimeUtils { val TIMEZONE_OPTION = "timeZone" - def getZoneId(timeZoneId: String): ZoneId = ZoneId.of(timeZoneId, ZoneId.SHORT_IDS) + def getZoneId(timeZoneId: String): ZoneId = { + // To support the (+|-)h:mm format because it was supported before Spark 3.0. + ZoneId.of(timeZoneId.replaceFirst("(\\+|\\-)(\\d):", "$10$2:"), ZoneId.SHORT_IDS) + } def getTimeZone(timeZoneId: String): TimeZone = TimeZone.getTimeZone(getZoneId(timeZoneId)) /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index b9b55da5a2080..46e333a660600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -471,6 +471,13 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("2011-12-25 09:00:00.123456", JST.getId, "2011-12-25 18:00:00.123456") test("2011-12-25 09:00:00.123456", LA.getId, "2011-12-25 01:00:00.123456") test("2011-12-25 09:00:00.123456", "Asia/Shanghai", "2011-12-25 17:00:00.123456") + test("2011-12-25 09:00:00.123456", "-7", "2011-12-25 02:00:00.123456") + test("2011-12-25 09:00:00.123456", "+8:00", "2011-12-25 17:00:00.123456") + test("2011-12-25 09:00:00.123456", "+8:00:00", "2011-12-25 17:00:00.123456") + test("2011-12-25 09:00:00.123456", "+0800", "2011-12-25 17:00:00.123456") + test("2011-12-25 09:00:00.123456", "-071020", "2011-12-25 01:49:40.123456") + test("2011-12-25 09:00:00.123456", "-07:10:20", "2011-12-25 01:49:40.123456") + } } @@ -496,6 +503,12 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("2011-12-25 18:00:00.123456", JST.getId, "2011-12-25 09:00:00.123456") test("2011-12-25 01:00:00.123456", LA.getId, "2011-12-25 09:00:00.123456") test("2011-12-25 17:00:00.123456", "Asia/Shanghai", "2011-12-25 09:00:00.123456") + test("2011-12-25 02:00:00.123456", "-7", "2011-12-25 09:00:00.123456") + test("2011-12-25 17:00:00.123456", "+8:00", "2011-12-25 09:00:00.123456") + test("2011-12-25 17:00:00.123456", "+8:00:00", "2011-12-25 09:00:00.123456") + test("2011-12-25 17:00:00.123456", "+0800", "2011-12-25 09:00:00.123456") + test("2011-12-25 01:49:40.123456", "-071020", "2011-12-25 09:00:00.123456") + test("2011-12-25 01:49:40.123456", "-07:10:20", "2011-12-25 09:00:00.123456") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 93b785952768d..b85a668e5b8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -414,13 +414,12 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Chicago") assert(sql(s"set ${SQLConf.SESSION_LOCAL_TIMEZONE.key}").head().getString(1) === "America/Chicago") + spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "GMT+8:00") + assert(sql(s"set ${SQLConf.SESSION_LOCAL_TIMEZONE.key}").head().getString(1) === "GMT+8:00") intercept[IllegalArgumentException] { spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "pst") } - intercept[IllegalArgumentException] { - spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "GMT+8:00") - } val e = intercept[IllegalArgumentException] { spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "Asia/shanghai") } From 05069ff4ce1bbacd88b0b8497c97d8a8ca23d5a7 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 27 Feb 2021 16:48:20 +0900 Subject: [PATCH 44/60] [SPARK-34353][SQL] CollectLimitExec avoid shuffle if input rdd has 0/1 partition ### What changes were proposed in this pull request? if child rdd has only one partition or zero partition, skip the shuffle ### Why are the changes needed? skip shuffle if possible ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuites Closes #31468 from zhengruifeng/collect_limit_single_partition. Authored-by: Ruifeng Zheng Signed-off-by: Takeshi Yamamuro --- .../apache/spark/sql/execution/limit.scala | 76 +++++++++++-------- .../TakeOrderedAndProjectSuite.scala | 56 ++++++++------ 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 671de65ff1089..0b74a2667a273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.util.collection.Utils /** * The operator takes limited number of elements from its child operator. 
@@ -52,16 +53,25 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec { SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { - val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) - val shuffled = new ShuffledRowRDD( - ShuffleExchangeExec.prepareShuffleDependency( - locallyLimited, - child.output, - SinglePartition, - serializer, - writeMetrics), - readMetrics) - shuffled.mapPartitionsInternal(_.take(limit)) + val childRDD = child.execute() + if (childRDD.getNumPartitions == 0) { + new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty) + } else { + val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { + childRDD + } else { + val locallyLimited = childRDD.mapPartitionsInternal(_.take(limit)) + new ShuffledRowRDD( + ShuffleExchangeExec.prepareShuffleDependency( + locallyLimited, + child.output, + SinglePartition, + serializer, + writeMetrics), + readMetrics) + } + singlePartitionRDD.mapPartitionsInternal(_.take(limit)) + } } } @@ -200,28 +210,32 @@ case class TakeOrderedAndProjectExec( protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val childRDD = child.execute() - val singlePartitionRDD = if (childRDD.getNumPartitions > 1) { - val localTopK = childRDD.mapPartitions { iter => - org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) - } - new ShuffledRowRDD( - ShuffleExchangeExec.prepareShuffleDependency( - localTopK, - child.output, - SinglePartition, - serializer, - writeMetrics), - readMetrics) + if (childRDD.getNumPartitions == 0) { + new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty) } else { - childRDD - } - singlePartitionRDD.mapPartitions { iter => - val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) - if (projectList != child.output) { - val proj = UnsafeProjection.create(projectList, child.output) - topK.map(r => proj(r)) + val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { + childRDD } else { - topK + val localTopK = childRDD.mapPartitions { iter => + Utils.takeOrdered(iter.map(_.copy()), limit)(ord) + } + new ShuffledRowRDD( + ShuffleExchangeExec.prepareShuffleDependency( + localTopK, + child.output, + SinglePartition, + serializer, + writeMetrics), + readMetrics) + } + singlePartitionRDD.mapPartitions { iter => + val topK = Utils.takeOrdered(iter.map(_.copy()), limit)(ord) + if (projectList != child.output) { + val proj = UnsafeProjection.create(projectList, child.output) + topK.map(r => proj(r)) + } else { + topK + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 376d330ebeb70..6ec5c6287eed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -37,12 +37,18 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { rand = new Random(seed) } - private def generateRandomInputData(): DataFrame = { + private def generateRandomInputData(numRows: Int, numParts: Int): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable 
= false) - val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + val rdd = if (numParts > 0) { + val inputData = Seq.fill(numRows)(Row(rand.nextInt(), rand.nextInt())) + sparkContext.parallelize(Random.shuffle(inputData), numParts) + } else { + sparkContext.emptyRDD[Row] + } + assert(rdd.getNumPartitions == numParts) + spark.createDataFrame(rdd, schema) } /** @@ -56,31 +62,35 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) + Seq((0, 0), (10000, 1), (10000, 10)).foreach { case (n, m) => + checkThatPlansAgree( + generateRandomInputData(n, m), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) + } } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) + Seq((0, 0), (10000, 1), (10000, 10)).foreach { case (n, m) => + checkThatPlansAgree( + generateRandomInputData(n, m), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) + } } } } From d75821038f88144918b0814830ba4cb03f739433 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Feb 2021 23:58:45 -0800 Subject: [PATCH 45/60] [SPARK-34557][BUILD] Exclude Avro's transitive zstd-jni dependency ### What changes were proposed in this pull request? This PR aims to exclude `Apache Avro`'s transitive zstd-jni dependency. ### Why are the changes needed? While SPARK-27733 upgrades Apache Avro from 1.8 to 1.10, `zstd-jni` transitive dependency is created. This PR explicitly prevents dependency conflicts. **BEFORE** ``` $ build/sbt "core/evicted" | grep zstd [info] * com.github.luben:zstd-jni:1.4.8-5 is selected over 1.4.5-12 ``` **AFTER** ``` $ build/sbt "core/evicted" | grep zstd ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #31670 from dongjoon-hyun/SPARK-34557. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index 39f2792d29d87..f9dc229181dda 100644 --- a/pom.xml +++ b/pom.xml @@ -1235,6 +1235,10 @@ javax.annotation javax.annotation-api + + com.github.luben + zstd-jni + From 1aeafb485298f87c64c5c09ec3a70aad4171209f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 27 Feb 2021 03:04:12 -0800 Subject: [PATCH 46/60] [SPARK-34559][BUILD] Upgrade to ZSTD JNI 1.4.8-6 ### What changes were proposed in this pull request? This PR aims to upgrade ZSTD JNI to 1.4.8-6. ### Why are the changes needed? 
This fixes the following issue and will unblock SPARK-34479 (Support ZSTD at Avro data source). - https://github.com/luben/zstd-jni/issues/161 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #31674 from dongjoon-hyun/SPARK-34559. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 2 +- dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 228bb94d22920..51a0d8b0f09f5 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -243,4 +243,4 @@ xz/1.8//xz-1.8.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar zookeeper/3.6.2//zookeeper-3.6.2.jar -zstd-jni/1.4.8-5//zstd-jni-1.4.8-5.jar +zstd-jni/1.4.8-6//zstd-jni-1.4.8-6.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 71ef4c1ef998e..977fc4b1210f1 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -210,4 +210,4 @@ xz/1.8//xz-1.8.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar zookeeper/3.6.2//zookeeper-3.6.2.jar -zstd-jni/1.4.8-5//zstd-jni-1.4.8-5.jar +zstd-jni/1.4.8-6//zstd-jni-1.4.8-6.jar diff --git a/pom.xml b/pom.xml index f9dc229181dda..3bd5ef74a9336 100644 --- a/pom.xml +++ b/pom.xml @@ -700,7 +700,7 @@ com.github.luben zstd-jni - 1.4.8-5 + 1.4.8-6 com.clearspring.analytics From 397b843890db974a0534394b1907d33d62c2b888 Mon Sep 17 00:00:00 2001 From: Phillip Henry Date: Sat, 27 Feb 2021 08:34:39 -0600 Subject: [PATCH 47/60] [SPARK-34415][ML] Randomization in hyperparameter optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Code in the PR generates random parameters for hyperparameter tuning. A discussion with Sean Owen can be found on the dev mailing list here: http://apache-spark-developers-list.1001551.n3.nabble.com/Hyperparameter-Optimization-via-Randomization-td30629.html All code is entirely my own work and I license the work to the project under the project’s open source license. ### Why are the changes needed? Randomization can be a more effective techinique than a grid search since min/max points can fall between the grid and never be found. Randomisation is not so restricted although the probability of finding minima/maxima is dependent on the number of attempts. Alice Zheng has an accessible description on how this technique works at https://www.oreilly.com/library/view/evaluating-machine-learning/9781492048756/ch04.html Although there are Python libraries with more sophisticated techniques, not every Spark developer is using Python. ### Does this PR introduce _any_ user-facing change? A new class (`ParamRandomBuilder.scala`) and its tests have been created but there is no change to existing code. This class offers an alternative to `ParamGridBuilder` and can be dropped into the code wherever `ParamGridBuilder` appears. Indeed, it extends `ParamGridBuilder` and is completely compatible with its interface. It merely adds one method that provides a range over which a hyperparameter will be randomly defined. ### How was this patch tested? Tests `ParamRandomBuilderSuite.scala` and `RandomRangesSuite.scala` were added. 
`ParamRandomBuilderSuite` is the analogue of the already existing `ParamGridBuilderSuite` which tests the user-facing interface. `RandomRangesSuite` uses ScalaCheck to test the random ranges over which hyperparameters are distributed. Closes #31535 from PhillHenry/ParamRandomBuilder. Authored-by: Phillip Henry Signed-off-by: Sean Owen --- docs/ml-tuning.md | 36 +++- ...ectionViaRandomHyperparametersExample.java | 83 +++++++++ ...ctionViaRandomHyperparametersExample.scala | 79 ++++++++ .../spark/ml/tuning/ParamRandomBuilder.scala | 160 +++++++++++++++++ .../ml/tuning/ParamRandomBuilderSuite.scala | 123 +++++++++++++ .../spark/ml/tuning/RandomRangesSuite.scala | 168 ++++++++++++++++++ python/docs/source/reference/pyspark.ml.rst | 1 + python/pyspark/ml/tests/test_tuning.py | 106 ++++++++++- python/pyspark/ml/tuning.py | 48 ++++- python/pyspark/ml/tuning.pyi | 5 + 10 files changed, 806 insertions(+), 3 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaRandomHyperparametersExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaRandomHyperparametersExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tuning/ParamRandomBuilderSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tuning/RandomRangesSuite.scala diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 3ddd185d19ff4..e7940a3493685 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -71,10 +71,44 @@ for multiclass problems, a [`MultilabelClassificationEvaluator`](api/scala/org/a [`RankingEvaluator`](api/scala/org/apache/spark/ml/evaluation/RankingEvaluator.html) for ranking problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` method in each of these evaluators. -To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/org/apache/spark/ml/tuning/ParamGridBuilder.html) utility. +To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/org/apache/spark/ml/tuning/ParamGridBuilder.html) utility (see the *Cross-Validation* section below for an example). By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit`. The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources, and larger values may not always lead to improved performance. Generally speaking, a value up to 10 should be sufficient for most clusters. +Alternatively, users can use the [`ParamRandomBuilder`](api/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.html) utility. +This has the same properties of `ParamGridBuilder` mentioned above, but hyperparameters are chosen at random within a user-defined range. +The mathematical principle behind this is that given enough samples, the probability of at least one sample *not* being near the optimum within a range tends to zero. +Irrespective of machine learning model, the expected number of samples needed to have at least one within 5% of the optimum is about 60. +If this 5% volume lies between the parameters defined in a grid search, it will *never* be found by `ParamGridBuilder`. + +
+ +
+ +Refer to the [`ParamRandomBuilder` Scala docs](api/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.html) for details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaRandomHyperparametersExample.scala %} +
+ +
+ +Refer to the [`ParamRandomBuilder` Java docs](api/java/org/apache/spark/ml/tuning/ParamRandomBuilder.html) for details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaRandomHyperparametersExample.java %} +
+ +
+ +Python users are recommended to look at Python libraries that are specifically for hyperparameter tuning such as Hyperopt. + +Refer to the [`ParamRandomBuilder` Python docs](api/python/reference/api/pyspark.ml.tuning.ParamRandomBuilder.html) for details on the API. + +{% include_example python/ml/model_selection_random_hyperparameters_example.py %} + +
+ +
+ # Cross-Validation `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets. E.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular `ParamMap`, `CrossValidator` computes the average evaluation metric for the 3 `Model`s produced by fitting the `Estimator` on the 3 different (training, test) dataset pairs. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaRandomHyperparametersExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaRandomHyperparametersExample.java new file mode 100644 index 0000000000000..086920f775362 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaRandomHyperparametersExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off$ + +/** + * A simple example demonstrating model selection using ParamRandomBuilder. + * + * Run with + * {{{ + * bin/run-example ml.JavaModelSelectionViaRandomHyperparametersExample + * }}} + */ +public class JavaModelSelectionViaRandomHyperparametersExample { + + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaModelSelectionViaTrainValidationSplitExample") + .getOrCreate(); + + // $example on$ + Dataset data = spark.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + LinearRegression lr = new LinearRegression(); + + // We sample the regularization parameter logarithmically over the range [0.01, 1.0]. + // This means that values around 0.01, 0.1 and 1.0 are roughly equally likely. + // Note that both parameters must be greater than zero as otherwise we'll get an infinity. + // We sample the the ElasticNet mixing parameter uniformly over the range [0, 1] + // Note that in real life, you'd choose more than the 5 samples we see below. 
+ ParamMap[] hyperparameters = new ParamRandomBuilder() + .addLog10Random(lr.regParam(), 0.01, 1.0, 5) + .addRandom(lr.elasticNetParam(), 0.0, 1.0, 5) + .addGrid(lr.fitIntercept()) + .build(); + + System.out.println("hyperparameters:"); + for (ParamMap param : hyperparameters) { + System.out.println(param); + } + + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(hyperparameters) + .setEvaluator(new RegressionEvaluator()) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(data); + LinearRegression parent = (LinearRegression)cvModel.bestModel().parent(); + + System.out.println("Optimal model has\n" + lr.regParam() + " = " + parent.getRegParam() + + "\n" + lr.elasticNetParam() + " = "+ parent.getElasticNetParam() + + "\n" + lr.fitIntercept() + " = " + parent.getFitIntercept()); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaRandomHyperparametersExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaRandomHyperparametersExample.scala new file mode 100644 index 0000000000000..9d2c58bbf9c7f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaRandomHyperparametersExample.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, Limits, ParamRandomBuilder} +import org.apache.spark.ml.tuning.RandomRanges._ +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * A simple example demonstrating model selection using ParamRandomBuilder. + * + * Run with + * {{{ + * bin/run-example ml.ModelSelectionViaRandomHyperparametersExample + * }}} + */ +object ModelSelectionViaRandomHyperparametersExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("ModelSelectionViaTrainValidationSplitExample") + .getOrCreate() + // scalastyle:off println + // $example on$ + // Prepare training and test data. + val data = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + + val lr = new LinearRegression().setMaxIter(10) + + // We sample the regularization parameter logarithmically over the range [0.01, 1.0]. + // This means that values around 0.01, 0.1 and 1.0 are roughly equally likely. + // Note that both parameters must be greater than zero as otherwise we'll get an infinity. + // We sample the the ElasticNet mixing parameter uniformly over the range [0, 1] + // Note that in real life, you'd choose more than the 5 samples we see below. 
+ val hyperparameters = new ParamRandomBuilder() + .addLog10Random(lr.regParam, Limits(0.01, 1.0), 5) + .addGrid(lr.fitIntercept) + .addRandom(lr.elasticNetParam, Limits(0.0, 1.0), 5) + .build() + + println(s"hyperparameters:\n${hyperparameters.mkString("\n")}") + + val cv: CrossValidator = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(hyperparameters) + .setEvaluator(new RegressionEvaluator) + .setNumFolds(3) + val cvModel: CrossValidatorModel = cv.fit(data) + val parent: LinearRegression = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + + println(s"""Optimal model has: + |${lr.regParam} = ${parent.getRegParam} + |${lr.elasticNetParam} = ${parent.getElasticNetParam} + |${lr.fitIntercept} = ${parent.getFitIntercept}""".stripMargin) + // $example off$ + + spark.stop() + } + // scalastyle:on println +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.scala new file mode 100644 index 0000000000000..9c296bbc95224 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamRandomBuilder.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.param._ +import org.apache.spark.ml.tuning.RandomRanges._ + +case class Limits[T: Numeric](x: T, y: T) + +private[ml] abstract class RandomT[T: Numeric] { + def randomT(): T + def randomTLog(n: Int): T +} + +abstract class Generator[T: Numeric] { + def apply(lim: Limits[T]): RandomT[T] +} + +object RandomRanges { + + private val rnd = new scala.util.Random + + private[tuning] def randomBigInt0To(x: BigInt): BigInt = { + var randVal = BigInt(x.bitLength, rnd) + while (randVal > x) { + randVal = BigInt(x.bitLength, rnd) + } + randVal + } + + private[ml] def bigIntBetween(lower: BigInt, upper: BigInt): BigInt = { + val diff: BigInt = upper - lower + randomBigInt0To(diff) + lower + } + + private def randomBigDecimalBetween(lower: BigDecimal, upper: BigDecimal): BigDecimal = { + val zeroCenteredRnd: BigDecimal = BigDecimal(rnd.nextDouble() - 0.5) + val range: BigDecimal = upper - lower + val halfWay: BigDecimal = lower + range / 2 + (zeroCenteredRnd * range) + halfWay + } + + implicit object DoubleGenerator extends Generator[Double] { + def apply(limits: Limits[Double]): RandomT[Double] = new RandomT[Double] { + import limits._ + val lower: Double = math.min(x, y) + val upper: Double = math.max(x, y) + + override def randomTLog(n: Int): Double = + RandomRanges.randomLog(lower, upper, n) + + override def randomT(): Double = + randomBigDecimalBetween(BigDecimal(lower), BigDecimal(upper)).doubleValue + } + } + + implicit object FloatGenerator extends Generator[Float] { + def apply(limits: Limits[Float]): RandomT[Float] = new RandomT[Float] { + import limits._ + val lower: Float = math.min(x, y) + val upper: Float = math.max(x, y) + + override def randomTLog(n: Int): Float = + RandomRanges.randomLog(lower, upper, n).toFloat + + override def randomT(): Float = + randomBigDecimalBetween(BigDecimal(lower), BigDecimal(upper)).floatValue + } + } + + implicit object IntGenerator extends Generator[Int] { + def apply(limits: Limits[Int]): RandomT[Int] = new RandomT[Int] { + import limits._ + val lower: Int = math.min(x, y) + val upper: Int = math.max(x, y) + + override def randomTLog(n: Int): Int = + RandomRanges.randomLog(lower, upper, n).toInt + + override def randomT(): Int = + bigIntBetween(BigInt(lower), BigInt(upper)).intValue + } + } + + private[ml] def logN(x: Double, base: Int): Double = math.log(x) / math.log(base) + + private[ml] def randomLog(lower: Double, upper: Double, n: Int): Double = { + val logLower: Double = logN(lower, n) + val logUpper: Double = logN(upper, n) + val logLimits: Limits[Double] = Limits(logLower, logUpper) + val rndLogged: RandomT[Double] = RandomRanges(logLimits) + math.pow(n, rndLogged.randomT()) + } + + private[ml] def apply[T: Generator](lim: Limits[T])(implicit t: Generator[T]): RandomT[T] = t(lim) + +} + +/** + * "For any distribution over a sample space with a finite maximum, the maximum of 60 random + * observations lies within the top 5% of the true maximum, with 95% probability" + * - Evaluating Machine Learning Models by Alice Zheng + * https://www.oreilly.com/library/view/evaluating-machine-learning/9781492048756/ch04.html + * + * Note: if you want more sophisticated hyperparameter tuning, consider Python libraries + * such as Hyperopt. 
+ */ +@Since("3.2.0") +class ParamRandomBuilder extends ParamGridBuilder { + def addRandom[T: Generator](param: Param[T], lim: Limits[T], n: Int): this.type = { + val gen: RandomT[T] = RandomRanges(lim) + addGrid(param, (1 to n).map { _: Int => gen.randomT() }) + } + + def addLog10Random[T: Generator](param: Param[T], lim: Limits[T], n: Int): this.type = + addLogRandom(param, lim, n, 10) + + private def addLogRandom[T: Generator](param: Param[T], lim: Limits[T], + n: Int, base: Int): this.type = { + val gen: RandomT[T] = RandomRanges(lim) + addGrid(param, (1 to n).map { _: Int => gen.randomTLog(base) }) + } + + // specialized versions for Java. + + def addRandom(param: DoubleParam, x: Double, y: Double, n: Int): this.type = + addRandom(param, Limits(x, y), n)(DoubleGenerator) + + def addLog10Random(param: DoubleParam, x: Double, y: Double, n: Int): this.type = + addLogRandom(param, Limits(x, y), n, 10)(DoubleGenerator) + + def addRandom(param: FloatParam, x: Float, y: Float, n: Int): this.type = + addRandom(param, Limits(x, y), n)(FloatGenerator) + + def addLog10Random(param: FloatParam, x: Float, y: Float, n: Int): this.type = + addLogRandom(param, Limits(x, y), n, 10)(FloatGenerator) + + def addRandom(param: IntParam, x: Int, y: Int, n: Int): this.type = + addRandom(param, Limits(x, y), n)(IntGenerator) + + def addLog10Random(param: IntParam, x: Int, y: Int, n: Int): this.type = + addLogRandom(param, Limits(x, y), n, 10)(IntGenerator) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamRandomBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamRandomBuilderSuite.scala new file mode 100644 index 0000000000000..e17c48e4d991d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamRandomBuilderSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.scalatest.matchers.must.Matchers +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param._ + +class ParamRandomBuilderSuite extends SparkFunSuite with ScalaCheckDrivenPropertyChecks + with Matchers { + + val solver = new TestParams() { + private val randomColName = "randomVal" + val DummyDoubleParam = new DoubleParam(this, randomColName, "doc") + val DummyFloatParam = new FloatParam(this, randomColName, "doc") + val DummyIntParam = new IntParam(this, randomColName, "doc") + } + import solver._ + + val DoubleLimits: Limits[Double] = Limits(1d, 100d) + val FloatLimits: Limits[Float] = Limits(1f, 100f) + val IntLimits: Limits[Int] = Limits(1, 100) + val nRandoms: Int = 5 + + // Java API + + test("Java API random Double linear params mixed with fixed values") { + checkRangeAndCardinality( + _.addRandom(DummyDoubleParam, DoubleLimits.x, DoubleLimits.y, nRandoms), + DoubleLimits, + DummyDoubleParam) + } + + test("Java API random Double log10 params mixed with fixed values") { + checkRangeAndCardinality( + _.addLog10Random(DummyDoubleParam, DoubleLimits.x, DoubleLimits.y, nRandoms), + DoubleLimits, + DummyDoubleParam) + } + + test("Java API random Float linear params mixed with fixed values") { + checkRangeAndCardinality( + _.addRandom(DummyFloatParam, FloatLimits.x, FloatLimits.y, nRandoms), + FloatLimits, + DummyFloatParam) + } + + test("Java API random Float log10 params mixed with fixed values") { + checkRangeAndCardinality( + _.addLog10Random(DummyFloatParam, FloatLimits.x, FloatLimits.y, nRandoms), + FloatLimits, + DummyFloatParam) + } + + test("Java API random Int linear params mixed with fixed values") { + checkRangeAndCardinality( + _.addRandom(DummyIntParam, IntLimits.x, IntLimits.y, nRandoms), + IntLimits, + DummyIntParam) + } + + test("Java API random Int log10 params mixed with fixed values") { + checkRangeAndCardinality( + _.addLog10Random(DummyIntParam, IntLimits.x, IntLimits.y, nRandoms), + IntLimits, + DummyIntParam) + } + + // Scala API + + test("random linear params mixed with fixed values") { + import RandomRanges._ + checkRangeAndCardinality(_.addRandom(DummyDoubleParam, DoubleLimits, nRandoms), + DoubleLimits, + DummyDoubleParam) + } + + test("random log10 params mixed with fixed values") { + import RandomRanges._ + checkRangeAndCardinality(_.addLog10Random(DummyDoubleParam, DoubleLimits, nRandoms), + DoubleLimits, + DummyDoubleParam) + } + + def checkRangeAndCardinality[T: Numeric](addFn: ParamRandomBuilder => ParamRandomBuilder, + lim: Limits[T], + randomCol: Param[T]): Unit = { + val maxIterations: Int = 10 + val basedOn: Array[ParamPair[_]] = Array(maxIter -> maxIterations) + val inputCols: Array[String] = Array("input0", "input1") + val ops: Numeric[T] = implicitly[Numeric[T]] + + val builder: ParamRandomBuilder = new ParamRandomBuilder() + .baseOn(basedOn: _*) + .addGrid(inputCol, inputCols) + val paramMap: Array[ParamMap] = addFn(builder).build() + assert(paramMap.length == inputCols.length * nRandoms * basedOn.length) + paramMap.foreach { m: ParamMap => + assert(m(maxIter) == maxIterations) + assert(inputCols contains m(inputCol)) + assert(ops.gteq(m(randomCol), lim.x)) + assert(ops.lteq(m(randomCol), lim.y)) + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/RandomRangesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/RandomRangesSuite.scala new file mode 100644 index 
0000000000000..afcbc033956b5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/RandomRangesSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.reflect.runtime.universe.TypeTag + +import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen.Choose +import org.scalatest.{Assertion, Succeeded} +import org.scalatest.matchers.must.Matchers +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks + +import org.apache.spark.SparkFunSuite + +class RandomRangesSuite extends SparkFunSuite with ScalaCheckDrivenPropertyChecks with Matchers { + + import RandomRanges._ + + test("log of any base") { + assert(logN(16, 4) == 2d) + assert(logN(1000, 10) === (3d +- 0.000001)) + assert(logN(256, 2) == 8d) + } + + test("random doubles in log space") { + val gen: Gen[(Double, Double, Int)] = for { + x <- Gen.choose(0d, Double.MaxValue) + y <- Gen.choose(0d, Double.MaxValue) + n <- Gen.choose(0, Int.MaxValue) + } yield (x, y, n) + forAll(gen) { case (x, y, n) => + val lower = math.min(x, y) + val upper = math.max(x, y) + val result = randomLog(x, y, n) + assert(result >= lower && result <= upper) + } + } + + test("random BigInt generation does not go into infinite loop") { + assert(randomBigInt0To(0) == BigInt(0)) + } + + test("random ints") { + checkRange(Linear[Int]) + } + + test("random log ints") { + checkRange(Log10[Int]) + } + + test("random int distribution") { + checkDistributionOf(1000) + } + + test("random doubles") { + checkRange(Linear[Double]) + } + + test("random log doubles") { + checkRange(Log10[Double]) + } + + test("random double distribution") { + checkDistributionOf(1000d) + } + + test("random floats") { + checkRange(Linear[Float]) + } + + test("random log floats") { + checkRange(Log10[Float]) + } + + test("random float distribution") { + checkDistributionOf(1000f) + } + + private abstract class RandomFn[T: Numeric: Generator] { + def apply(genRandom: RandomT[T]): T = genRandom.randomT() + def appropriate(x: T, y: T): Boolean + } + + private def Linear[T: Numeric: Generator]: RandomFn[T] = new RandomFn { + override def apply(genRandom: RandomT[T]): T = genRandom.randomT() + override def appropriate(x: T, y: T): Boolean = true + } + + private def Log10[T: Numeric: Generator]: RandomFn[T] = new RandomFn { + override def apply(genRandom: RandomT[T]): T = genRandom.randomTLog(10) + val ops: Numeric[T] = implicitly[Numeric[T]] + override def appropriate(x: T, y: T): Boolean = { + ops.gt(x, ops.zero) && ops.gt(y, ops.zero) && x != y + } + } + + private def checkRange[T: Numeric: Generator: Choose: TypeTag: Arbitrary] + (rand: RandomFn[T]): Assertion = + forAll { (x: T, y: T) => + if (rand.appropriate(x, 
y)) { + val ops: Numeric[T] = implicitly[Numeric[T]] + val limit: Limits[T] = Limits(x, y) + val gen: RandomT[T] = RandomRanges(limit) + val result: T = rand(gen) + val ordered: (T, T) = lowerUpper(x, y) + assert(ops.gteq(result, ordered._1) && ops.lteq(result, ordered._2)) + } else Succeeded + } + + private def checkDistributionOf[T: Numeric: Generator: Choose](range: T): Unit = { + val ops: Numeric[T] = implicitly[Numeric[T]] + import ops._ + val gen: Gen[(T, T)] = for { + x <- Gen.choose(negate(range), range) + y <- Gen.choose(range, times(range, plus(one, one))) + } yield (x, y) + forAll(gen) { case (x, y) => + assertEvenDistribution(10000, Limits(x, y)) + } + } + + private def meanAndStandardDeviation[T: Numeric](xs: Seq[T]): (Double, Double) = { + val ops: Numeric[T] = implicitly[Numeric[T]] + val n: Int = xs.length + val mean: Double = ops.toDouble(xs.sum) / n + val squaredDiff: Seq[Double] = xs.map { x: T => math.pow(ops.toDouble(x) - mean, 2) } + val stdDev: Double = math.pow(squaredDiff.sum / n - 1, 0.5) + (mean, stdDev) + } + + private def lowerUpper[T: Numeric](x: T, y: T): (T, T) = { + val ops: Numeric[T] = implicitly[Numeric[T]] + (ops.min(x, y), ops.max(x, y)) + } + + private def midPointOf[T: Numeric : Generator](lim: Limits[T]): Double = { + val ordered: (T, T) = lowerUpper(lim.x, lim.y) + val ops: Numeric[T] = implicitly[Numeric[T]] + val range: T = ops.minus(ordered._2, ordered._1) + (ops.toDouble(range) / 2) + ops.toDouble(ordered._1) + } + + private def assertEvenDistribution[T: Numeric: Generator](n: Int, lim: Limits[T]): Assertion = { + val gen: RandomT[T] = RandomRanges(lim) + val xs: Seq[T] = (0 to n).map { _: Int => gen.randomT() } + val (mean, stdDev) = meanAndStandardDeviation(xs) + val tolerance: Double = 4 * stdDev + val halfWay: Double = midPointOf(lim) + assert(mean > halfWay - tolerance && mean < halfWay + tolerance) + } + +} diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst index 7837d609ecb96..fc6060c979d1e 100644 --- a/python/docs/source/reference/pyspark.ml.rst +++ b/python/docs/source/reference/pyspark.ml.rst @@ -288,6 +288,7 @@ Tuning :toctree: api/ ParamGridBuilder + ParamRandomBuilder CrossValidator CrossValidatorModel TrainValidationSplit diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index 3cde34facbf9a..9f6c8192e0ff3 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -16,8 +16,10 @@ # import tempfile +import math import unittest +import numpy as np from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml import Estimator, Pipeline, Model from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest @@ -26,7 +28,7 @@ from pyspark.ml.linalg import Vectors from pyspark.ml.param import Param, Params from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \ - TrainValidationSplit, TrainValidationSplitModel + TrainValidationSplit, TrainValidationSplitModel, ParamRandomBuilder from pyspark.sql.functions import rand from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression, \ DummyLogisticRegressionModel, SparkSessionTestCase @@ -65,6 +67,108 @@ def _fit(self, dataset): return model +class DummyParams(Params): + + def __init__(self): + super(DummyParams, self).__init__() + self.test_param = Param(self, "test_param", "dummy parameter for testing") + self.another_test_param = Param(self, "another_test_param", "second 
parameter for testing") + + +class ParamRandomBuilderTests(unittest.TestCase): + + def __init__(self, methodName): + super(ParamRandomBuilderTests, self).__init__(methodName=methodName) + self.dummy_params = DummyParams() + self.to_test = ParamRandomBuilder() + self.n = 100 + + def check_ranges(self, params, lowest, highest, expected_type): + self.assertEqual(self.n, len(params)) + for param in params: + for v in param.values(): + self.assertGreaterEqual(v, lowest) + self.assertLessEqual(v, highest) + self.assertEqual(type(v), expected_type) + + def check_addRandom_ranges(self, x, y, expected_type): + params = self.to_test.addRandom(self.dummy_params.test_param, x, y, self.n).build() + self.check_ranges(params, x, y, expected_type) + + def check_addLog10Random_ranges(self, x, y, expected_type): + params = self.to_test.addLog10Random(self.dummy_params.test_param, x, y, self.n).build() + self.check_ranges(params, x, y, expected_type) + + @staticmethod + def counts(xs): + key_to_count = {} + for v in xs: + k = int(v) + if key_to_count.get(k) is None: + key_to_count[k] = 1 + else: + key_to_count[k] = key_to_count[k] + 1 + return key_to_count + + @staticmethod + def raw_values_of(params): + values = [] + for param in params: + for v in param.values(): + values.append(v) + return values + + def check_even_distribution(self, vs, bin_function): + binned = map(lambda x: bin_function(x), vs) + histogram = self.counts(binned) + values = list(histogram.values()) + sd = np.std(values) + mu = np.mean(values) + for k, v in histogram.items(): + self.assertLess(abs(v - mu), 5 * sd, "{} values for bucket {} is unlikely " + "when the mean is {} and standard deviation {}" + .format(v, k, mu, sd)) + + def test_distribution(self): + params = self.to_test.addRandom(self.dummy_params.test_param, 0, 20000, 10000).build() + values = self.raw_values_of(params) + self.check_even_distribution(values, lambda x: x // 1000) + + def test_logarithmic_distribution(self): + params = self.to_test.addLog10Random(self.dummy_params.test_param, 1, 1e10, 10000).build() + values = self.raw_values_of(params) + self.check_even_distribution(values, lambda x: math.log10(x)) + + def test_param_cardinality(self): + num_random_params = 7 + values = [1, 2, 3] + self.to_test.addRandom(self.dummy_params.test_param, 1, 10, num_random_params) + self.to_test.addGrid(self.dummy_params.another_test_param, values) + self.assertEqual(len(self.to_test.build()), num_random_params * len(values)) + + def test_add_random_integer_logarithmic_range(self): + self.check_addLog10Random_ranges(100, 200, int) + + def test_add_logarithmic_random_float_and_integer_yields_floats(self): + self.check_addLog10Random_ranges(100, 200., float) + + def test_add_random_float_logarithmic_range(self): + self.check_addLog10Random_ranges(100., 200., float) + + def test_add_random_integer_range(self): + self.check_addRandom_ranges(100, 200, int) + + def test_add_random_float_and_integer_yields_floats(self): + self.check_addRandom_ranges(100, 200., float) + + def test_add_random_float_range(self): + self.check_addRandom_ranges(100., 200., float) + + def test_unexpected_type(self): + with self.assertRaises(TypeError): + self.to_test.addRandom(self.dummy_params.test_param, 1, "wrong type", 1).build() + + class ParamGridBuilderTests(SparkSessionTestCase): def test_addGrid(self): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 2bddfe822f29e..85174c8cd02f2 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,8 @@ 
import os import sys import itertools +import random +import math from multiprocessing.pool import ThreadPool import numpy as np @@ -35,7 +37,7 @@ from pyspark.sql.types import BooleanType __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', - 'TrainValidationSplitModel'] + 'TrainValidationSplitModel', 'ParamRandomBuilder'] def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): @@ -152,6 +154,50 @@ def to_key_value_pairs(keys, values): return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] +class ParamRandomBuilder(ParamGridBuilder): + r""" + Builder for random value parameters used in search-based model selection. + + + .. versionadded:: 3.2.0 + """ + + @since("3.2.0") + def addRandom(self, param, x, y, n): + """ + Adds n random values between x and y. + The arguments x and y can be integers, floats or a combination of the two. If either + x or y is a float, the domain of the random value will be float. + """ + if type(x) == int and type(y) == int: + values = map(lambda _: random.randrange(x, y), range(n)) + elif type(x) == float or type(y) == float: + values = map(lambda _: random.uniform(x, y), range(n)) + else: + raise TypeError("unable to make range for types %s and %s" % type(x) % type(y)) + self.addGrid(param, values) + return self + + @since("3.2.0") + def addLog10Random(self, param, x, y, n): + """ + Adds n random values scaled logarithmically (base 10) between x and y. + For instance, a distribution for x=1.0, y=10000.0 and n=5 might reasonably look like + [1.6, 65.3, 221.9, 1024.3, 8997.5] + """ + def logarithmic_random(): + rand = random.uniform(math.log10(x), math.log10(y)) + value = 10 ** rand + if type(x) == int and type(y) == int: + value = int(value) + return value + + values = map(lambda _: logarithmic_random(), range(n)) + self.addGrid(param, values) + + return self + + class _ValidatorParams(HasSeed): """ Common params for TrainValidationSplit and CrossValidator. diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi index 912abd4d7124a..028cebdccac92 100644 --- a/python/pyspark/ml/tuning.pyi +++ b/python/pyspark/ml/tuning.pyi @@ -35,6 +35,11 @@ class ParamGridBuilder: def baseOn(self, *args: Tuple[Param, Any]) -> ParamGridBuilder: ... def build(self) -> List[ParamMap]: ... +class ParamRandomBuilder(ParamGridBuilder): + def __init__(self) -> None: ... + def addRandom(self, param: Param, x: Any, y: Any, n: int) -> ParamRandomBuilder: ... + def addLog10Random(self, param: Param, x: Any, y: Any, n: int) -> ParamRandomBuilder: ... + class _ValidatorParams(HasSeed): estimator: Param[Estimator] estimatorParamMaps: Param[List[ParamMap]] From 54c053afb0c9d3fcc7ac311100c8db9deeb163c0 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 27 Feb 2021 10:31:42 -0800 Subject: [PATCH 48/60] [SPARK-34479][SQL] Add zstandard codec to Avro compression codec list ### What changes were proposed in this pull request? Avro add zstandard codec since AVRO-2195. This pr add zstandard codec to Avro compression codec list. ### Why are the changes needed? To make Avro support zstandard codec. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #31673 from wangyum/SPARK-34479. 
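As an illustration of the user-facing effect, here is a minimal sketch of selecting the new codec (assuming an active `SparkSession` named `spark`, a DataFrame `df`, and the external Avro module on the classpath; the output path is illustrative):

```scala
// Opt into the newly allowed codec value for Avro output.
spark.conf.set("spark.sql.avro.compression.codec", "zstandard")

// Subsequent Avro writes are compressed with zstandard.
df.write.format("avro").save("/tmp/zstd_avro_output")
```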
Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala | 5 +++++ .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 4d2cb8c19fff1..74f4d0e649587 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.file.{DataFileReader, FileReader} -import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC} +import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC, ZSTANDARD_CODEC} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob @@ -109,7 +109,7 @@ private[sql] object AvroUtils extends Logging { logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.") job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) DEFLATE_CODEC - case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec + case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC | ZSTANDARD_CODEC) => codec case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown") } job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index d356f1f2c199d..b31f1f9274a52 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -479,6 +479,7 @@ abstract class AvroSuite val xzDir = s"$dir/xz" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" + val zstandardDir = s"$dir/zstandard" val df = spark.read.format("avro").load(testAvro) spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed") @@ -492,17 +493,21 @@ abstract class AvroSuite df.write.format("avro").save(deflateDir) spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy") df.write.format("avro").save(snappyDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "zstandard") + df.write.format("avro").save(zstandardDir) val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) val bzip2Size = FileUtils.sizeOfDirectory(new File(bzip2Dir)) val xzSize = FileUtils.sizeOfDirectory(new File(xzDir)) val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + val zstandardSize = FileUtils.sizeOfDirectory(new File(zstandardDir)) assert(uncompressSize > deflateSize) assert(snappySize > deflateSize) assert(snappySize > bzip2Size) assert(bzip2Size > xzSize) + assert(uncompressSize > zstandardSize) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 901afd0440075..6cb5fbfbe5e68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2470,10 +2470,10 @@ object SQLConf { val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + - "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") + "uncompressed, deflate, snappy, bzip2, xz and zstandard. Default codec is snappy.") .version("2.4.0") .stringConf - .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz")) + .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz", "zstandard")) .createWithDefault("snappy") val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level") From 5a48eb8d00faee3a7c8f023c0699296e22edb893 Mon Sep 17 00:00:00 2001 From: Phillip Henry Date: Sun, 28 Feb 2021 17:01:13 -0600 Subject: [PATCH 49/60] [SPARK-34415][ML] Python example Missing Python example file for [SPARK-34415][ML] Randomization in hyperparameter optimization (https://github.com/apache/spark/pull/31535) ### What changes were proposed in this pull request? For some reason (probably me being silly) a examples/src/main/python/ml/model_selection_random_hyperparameters_example.py was not pushed in a previous PR. This PR restores that file. ### Why are the changes needed? A single file (examples/src/main/python/ml/model_selection_random_hyperparameters_example.py) that should have been pushed as part of SPARK-34415 but was not. This was causing Lint errors as highlighted by dongjoon-hyun. Consequently, srowen asked for a new PR. ### Does this PR introduce _any_ user-facing change? No, it merely restores a file that was overlook in SPARK-34415. ### How was this patch tested? By running: `bin/spark-submit examples/src/main/python/ml/model_selection_random_hyperparameters_example.py` Closes #31687 from PhillHenry/SPARK-34415_model_selection_random_hyperparameters_example. Authored-by: Phillip Henry Signed-off-by: Sean Owen --- ...election_random_hyperparameters_example.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 examples/src/main/python/ml/model_selection_random_hyperparameters_example.py diff --git a/examples/src/main/python/ml/model_selection_random_hyperparameters_example.py b/examples/src/main/python/ml/model_selection_random_hyperparameters_example.py new file mode 100644 index 0000000000000..b436341b19665 --- /dev/null +++ b/examples/src/main/python/ml/model_selection_random_hyperparameters_example.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This example uses random hyperparameters to perform model selection. 
+Run with: + + bin/spark-submit examples/src/main/python/ml/model_selection_random_hyperparameters_example.py +""" +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamRandomBuilder, CrossValidator +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("TrainValidationSplit") \ + .getOrCreate() + + # $example on$ + data = spark.read.format("libsvm") \ + .load("data/mllib/sample_linear_regression_data.txt") + + lr = LinearRegression(maxIter=10) + + # We sample the regularization parameter logarithmically over the range [0.01, 1.0]. + # This means that values around 0.01, 0.1 and 1.0 are roughly equally likely. + # Note that both parameters must be greater than zero as otherwise we'll get an infinity. + # We sample the the ElasticNet mixing parameter uniformly over the range [0, 1] + # Note that in real life, you'd choose more than the 5 samples we see below. + hyperparameters = ParamRandomBuilder() \ + .addLog10Random(lr.regParam, 0.01, 1.0, 5) \ + .addRandom(lr.elasticNetParam, 0.0, 1.0, 5) \ + .addGrid(lr.fitIntercept, [False, True]) \ + .build() + + cv = CrossValidator(estimator=lr, + estimatorParamMaps=hyperparameters, + evaluator=RegressionEvaluator(), + numFolds=2) + + model = cv.fit(data) + bestModel = model.bestModel + print("Optimal model has regParam = {}, elasticNetParam = {}, fitIntercept = {}" + .format(bestModel.getRegParam(), bestModel.getElasticNetParam(), + bestModel.getFitIntercept())) + + # $example off$ + spark.stop() From d07fc3076b296e642ce321f2b2435f3059eeed4c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 1 Mar 2021 09:06:47 +0900 Subject: [PATCH 50/60] [SPARK-33687][SQL] Support analyze all tables in a specific database MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This pr add support analyze all tables in a specific database: ```g4 ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS (identifier)? ``` ### Why are the changes needed? 1. Make it easy to analyze all tables in a specific database. 2. PostgreSQL has a similar implementation: https://www.postgresql.org/docs/12/sql-analyze.html. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The feature tested by unit test. The documentation tested by regenerating the documentation: menu-sql.yaml |  sql-ref-syntax-aux-analyze-tables.md -- | -- ![image](https://user-images.githubusercontent.com/5399861/109098769-dc33a200-775c-11eb-86b1-55531e5425e0.png) | ![image](https://user-images.githubusercontent.com/5399861/109098841-02594200-775d-11eb-8588-de8da97ec94a.png) Closes #30648 from wangyum/SPARK-33687. 
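To show the new statement end to end, a minimal sketch run through `spark.sql` (the database name is illustrative and an active `SparkSession` named `spark` is assumed):

```scala
// Size-only statistics (no scan) for every table in school_db.
spark.sql("ANALYZE TABLES IN school_db COMPUTE STATISTICS NOSCAN")

// Full statistics, including row counts, for all tables in the current database.
spark.sql("ANALYZE TABLES COMPUTE STATISTICS")
```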
Authored-by: Yuming Wang Signed-off-by: Takeshi Yamamuro --- docs/_data/menu-sql.yaml | 2 + docs/sql-ref-syntax-aux-analyze-table.md | 6 +- docs/sql-ref-syntax-aux-analyze-tables.md | 110 ++++++++++++++++++ docs/sql-ref-syntax-aux-analyze.md | 1 + docs/sql-ref-syntax.md | 1 + .../spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../sql/catalyst/analysis/Analyzer.scala | 2 + .../sql/catalyst/parser/AstBuilder.scala | 19 +++ .../catalyst/plans/logical/v2Commands.scala | 9 ++ .../sql/catalyst/parser/DDLParserSuite.scala | 9 ++ .../analysis/ResolveSessionCatalog.scala | 3 + .../command/AnalyzeTableCommand.scala | 36 +----- .../command/AnalyzeTablesCommand.scala | 46 ++++++++ .../sql/execution/command/CommandUtils.scala | 37 +++++- .../spark/sql/StatisticsCollectionSuite.scala | 37 ++++++ 15 files changed, 285 insertions(+), 35 deletions(-) create mode 100644 docs/sql-ref-syntax-aux-analyze-tables.md create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index cda2a1a5139a1..a9ea6fed92d4d 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -198,6 +198,8 @@ subitems: - text: ANALYZE TABLE url: sql-ref-syntax-aux-analyze-table.html + - text: ANALYZE TABLES + url: sql-ref-syntax-aux-analyze-tables.html - text: CACHE url: sql-ref-syntax-aux-cache.html subitems: diff --git a/docs/sql-ref-syntax-aux-analyze-table.md b/docs/sql-ref-syntax-aux-analyze-table.md index 8f43d7388d7db..da5338564cda9 100644 --- a/docs/sql-ref-syntax-aux-analyze-table.md +++ b/docs/sql-ref-syntax-aux-analyze-table.md @@ -50,7 +50,7 @@ ANALYZE TABLE table_identifier [ partition_spec ] * If no analyze option is specified, `ANALYZE TABLE` collects the table's number of rows and size in bytes. * **NOSCAN** - Collects only the table's size in bytes ( which does not require scanning the entire table ). + Collects only the table's size in bytes (which does not require scanning the entire table). * **FOR COLUMNS col [ , ... ] `|` FOR ALL COLUMNS** Collects column statistics for each column specified, or alternatively for every column, as well as table statistics. @@ -122,3 +122,7 @@ DESC EXTENDED students name; | histogram| NULL| +--------------+----------+ ``` + +### Related Statements + +* [ANALYZE TABLES](sql-ref-syntax-aux-analyze-tables.html) diff --git a/docs/sql-ref-syntax-aux-analyze-tables.md b/docs/sql-ref-syntax-aux-analyze-tables.md new file mode 100644 index 0000000000000..f70cfa4d7de2e --- /dev/null +++ b/docs/sql-ref-syntax-aux-analyze-tables.md @@ -0,0 +1,110 @@ +--- +layout: global +title: ANALYZE TABLES +displayTitle: ANALYZE TABLES +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+--- + +### Description + +The `ANALYZE TABLES` statement collects statistics about all the tables in a specified database to be used by the query optimizer to find a better query execution plan. + +### Syntax + +```sql +ANALYZE TABLES [ { FROM | IN } database_name ] COMPUTE STATISTICS [ NOSCAN ] +``` + +### Parameters + +* **{ FROM `|` IN } database_name** + + Specifies the name of the database to be analyzed. Without a database name, `ANALYZE` collects all tables in the current database that the current user has permission to analyze. + +* **[ NOSCAN ]** + + Collects only the table's size in bytes (which does not require scanning the entire table). + +### Examples + +```sql +CREATE DATABASE school_db; +USE school_db; + +CREATE TABLE teachers (name STRING, teacher_id INT); +INSERT INTO teachers VALUES ('Tom', 1), ('Jerry', 2); + +CREATE TABLE students (name STRING, student_id INT, age SHORT); +INSERT INTO students VALUES ('Mark', 111111, 10), ('John', 222222, 11); + +ANALYZE TABLES IN school_db COMPUTE STATISTICS NOSCAN; + +DESC EXTENDED teachers; ++--------------------+--------------------+-------+ +| col_name| data_type|comment| ++--------------------+--------------------+-------+ +| name| string| null| +| teacher_id| int| null| +| ...| ...| ...| +| Provider| parquet| | +| Statistics| 1382 bytes| | +| ...| ...| ...| ++--------------------+--------------------+-------+ + +DESC EXTENDED students; ++--------------------+--------------------+-------+ +| col_name| data_type|comment| ++--------------------+--------------------+-------+ +| name| string| null| +| student_id| int| null| +| age| smallint| null| +| ...| ...| ...| +| Statistics| 1828 bytes| | +| ...| ...| ...| ++--------------------+--------------------+-------+ + +ANALYZE TABLES COMPUTE STATISTICS; + +DESC EXTENDED teachers; ++--------------------+--------------------+-------+ +| col_name| data_type|comment| ++--------------------+--------------------+-------+ +| name| string| null| +| teacher_id| int| null| +| ...| ...| ...| +| Provider| parquet| | +| Statistics| 1382 bytes, 2 rows| | +| ...| ...| ...| ++--------------------+--------------------+-------+ + +DESC EXTENDED students; ++--------------------+--------------------+-------+ +| col_name| data_type|comment| ++--------------------+--------------------+-------+ +| name| string| null| +| student_id| int| null| +| age| smallint| null| +| ...| ...| ...| +| Provider| parquet| | +| Statistics| 1828 bytes, 2 rows| | +| ...| ...| ...| ++--------------------+--------------------+-------+ +``` + +### Related Statements + +* [ANALYZE TABLE](sql-ref-syntax-aux-analyze-table.html) diff --git a/docs/sql-ref-syntax-aux-analyze.md b/docs/sql-ref-syntax-aux-analyze.md index 4c68e6b9ec974..7808966ffe145 100644 --- a/docs/sql-ref-syntax-aux-analyze.md +++ b/docs/sql-ref-syntax-aux-analyze.md @@ -20,3 +20,4 @@ license: | --- * [ANALYZE TABLE statement](sql-ref-syntax-aux-analyze-table.html) + * [ANALYZE TABLES statement](sql-ref-syntax-aux-analyze-tables.html) diff --git a/docs/sql-ref-syntax.md b/docs/sql-ref-syntax.md index f3d35b57d90cd..4cff2123e20b0 100644 --- a/docs/sql-ref-syntax.md +++ b/docs/sql-ref-syntax.md @@ -77,6 +77,7 @@ Spark SQL is Apache Spark's module for working with structured data. 
The SQL Syn * [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) * [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) * [ANALYZE TABLE](sql-ref-syntax-aux-analyze-table.html) + * [ANALYZE TABLES](sql-ref-syntax-aux-analyze-tables.html) * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 50ef3764f3994..aa6adb3b9b481 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -134,6 +134,8 @@ statement (AS? query)? #replaceTable | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS + (identifier)? #analyzeTables | ALTER TABLE multipartIdentifier ADD (COLUMN | COLUMNS) columns=qualifiedColTypeWithPositionList #addTableColumns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3952cc063b73c..282cb37d514f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -859,6 +859,8 @@ class Analyzer(override val catalogManager: CatalogManager) s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) case s @ ShowViews(UnresolvedNamespace(Seq()), _, _) => s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + case a @ AnalyzeTables(UnresolvedNamespace(Seq()), _) => + a.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) case UnresolvedNamespace(Seq()) => ResolvedNamespace(currentCatalog, Seq.empty[String]) case UnresolvedNamespace(CatalogAndNamespace(catalog, ns)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a43d28b045d09..8ea0f2a750365 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3654,6 +3654,25 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } + /** + * Create an [[AnalyzeTables]]. + * Example SQL for analyzing all tables in default database: + * {{{ + * ANALYZE TABLES IN default COMPUTE STATISTICS; + * }}} + */ + override def visitAnalyzeTables(ctx: AnalyzeTablesContext): LogicalPlan = withOrigin(ctx) { + if (ctx.identifier != null && + ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { + throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", + ctx.identifier()) + } + val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier) + AnalyzeTables( + UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])), + noScan = ctx.identifier != null) + } + /** * Create a [[RepairTable]]. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 847d7ae0117e5..7316e615ff8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -660,6 +660,15 @@ case class AnalyzeTable( override def children: Seq[LogicalPlan] = child :: Nil } +/** + * The logical plan of the ANALYZE TABLES command. + */ +case class AnalyzeTables( + namespace: LogicalPlan, + noScan: Boolean) extends Command { + override def children: Seq[LogicalPlan] = Seq(namespace) +} + /** * The logical plan of the ANALYZE TABLE FOR COLUMNS command. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 870ff388edc1f..f1557daedc9f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1873,6 +1873,15 @@ class DDLParserSuite extends AnalysisTest { "Expected `NOSCAN` instead of `xxxx`") } + test("SPARK-33687: analyze tables statistics") { + comparePlans(parsePlan("ANALYZE TABLES IN a.b.c COMPUTE STATISTICS"), + AnalyzeTables(UnresolvedNamespace(Seq("a", "b", "c")), noScan = false)) + comparePlans(parsePlan("ANALYZE TABLES FROM a COMPUTE STATISTICS NOSCAN"), + AnalyzeTables(UnresolvedNamespace(Seq("a")), noScan = true)) + intercept("ANALYZE TABLES IN a.b.c COMPUTE STATISTICS xxxx", + "Expected `NOSCAN` instead of `xxxx`") + } + test("analyze table column statistics") { intercept("ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR COLUMNS", "") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 290833d6a41ee..5748bc8f1f430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -373,6 +373,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) AnalyzePartitionCommand(ident.asTableIdentifier, partitionSpec, noScan) } + case AnalyzeTables(DatabaseInSessionCatalog(db), noScan) => + AnalyzeTablesCommand(Some(db), noScan) + case AnalyzeColumn(ResolvedV1TableOrViewIdentifier(ident), columnNames, allColumns) => AnalyzeColumnCommand(ident.asTableIdentifier, columnNames, allColumns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 67cfcebec187c..d114ca015d7ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTableType /** @@ -27,39 +26,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType 
*/ case class AnalyzeTableCommand( tableIdent: TableIdentifier, - noscan: Boolean = true) extends RunnableCommand { + noScan: Boolean = true) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - val sessionState = sparkSession.sessionState - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) - if (tableMeta.tableType == CatalogTableType.VIEW) { - // Analyzes a catalog view if the view is cached - val table = sparkSession.table(tableIdent.quotedString) - val cacheManager = sparkSession.sharedState.cacheManager - if (cacheManager.lookupCachedData(table.logicalPlan).isDefined) { - if (!noscan) { - // To collect table stats, materializes an underlying columnar RDD - table.count() - } - } else { - throw new AnalysisException("ANALYZE TABLE is not supported on views.") - } - } else { - // Compute stats for the whole table - val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta) - val newRowCount = - if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) - - // Update the metastore if the above statistics of the table are different from those - // recorded in the metastore. - val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) - if (newStats.isDefined) { - sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) - } - } - + CommandUtils.analyzeTable(sparkSession, tableIdent, noScan) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala new file mode 100644 index 0000000000000..ef0701909de2e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.util.control.NonFatal + +import org.apache.spark.sql.{Row, SparkSession} + + +/** + * Analyzes all tables in the given database to generate statistics. 
+ */ +case class AnalyzeTablesCommand( + databaseName: Option[String], + noScan: Boolean) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val db = databaseName.getOrElse(catalog.getCurrentDatabase) + catalog.listTables(db).foreach { tbl => + try { + CommandUtils.analyzeTable(sparkSession, tbl, noScan) + } catch { + case NonFatal(e) => + logWarning(s"Failed to analyze table ${tbl.table} in the " + + s"database $db because of ${e.toString}", e) + } + } + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 98912966db474..da5d00c595cb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -199,6 +199,41 @@ object CommandUtils extends Logging { newStats } + def analyzeTable( + sparkSession: SparkSession, + tableIdent: TableIdentifier, + noScan: Boolean): Unit = { + val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + // Analyzes a catalog view if the view is cached + val table = sparkSession.table(tableIdent.quotedString) + val cacheManager = sparkSession.sharedState.cacheManager + if (cacheManager.lookupCachedData(table.logicalPlan).isDefined) { + if (!noScan) { + // To collect table stats, materializes an underlying columnar RDD + table.count() + } + } else { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") + } + } else { + // Compute stats for the whole table + val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta) + val newRowCount = + if (noScan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) + + // Update the metastore if the above statistics of the table are different from those + // recorded in the metastore. + val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) + if (newStats.isDefined) { + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) + } + } + } + /** * Compute stats for the given columns. 
* @return (row count, map from column name to CatalogColumnStats) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 87e7641c87f6a..481bc66f902a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -678,4 +678,41 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("SPARK-33687: analyze all tables in a specific database") { + withTempDatabase { database => + spark.catalog.setCurrentDatabase(database) + withTempDir { dir => + withTable("t1", "t2") { + spark.range(10).write.saveAsTable("t1") + sql(s"CREATE EXTERNAL TABLE t2 USING parquet LOCATION '${dir.toURI}' " + + "AS SELECT * FROM range(20)") + withView("v1", "v2") { + sql("CREATE VIEW v1 AS SELECT 1 c1") + sql("CREATE VIEW v2 AS SELECT 2 c2") + sql("CACHE TABLE v1") + sql("CACHE LAZY TABLE v2") + + sql(s"ANALYZE TABLES IN $database COMPUTE STATISTICS NOSCAN") + checkTableStats("t1", hasSizeInBytes = true, expectedRowCounts = None) + checkTableStats("t2", hasSizeInBytes = true, expectedRowCounts = None) + assert(getCatalogTable("v1").stats.isEmpty) + checkOptimizedPlanStats(spark.table("v1"), 4, Some(1), Seq.empty) + checkOptimizedPlanStats(spark.table("v2"), 1, None, Seq.empty) + + sql("ANALYZE TABLES COMPUTE STATISTICS") + checkTableStats("t1", hasSizeInBytes = true, expectedRowCounts = Some(10)) + checkTableStats("t2", hasSizeInBytes = true, expectedRowCounts = Some(20)) + checkOptimizedPlanStats(spark.table("v1"), 4, Some(1), Seq.empty) + checkOptimizedPlanStats(spark.table("v2"), 4, Some(1), Seq.empty) + } + } + } + } + + val errMsg = intercept[AnalysisException] { + sql(s"ANALYZE TABLES IN db_not_exists COMPUTE STATISTICS") + }.getMessage + assert(errMsg.contains("Database 'db_not_exists' not found")) + } } From 0216051acadedcc7e9bcd840aa78776159b200d1 Mon Sep 17 00:00:00 2001 From: Shardul Mahadik Date: Mon, 1 Mar 2021 09:10:20 +0900 Subject: [PATCH 51/60] [SPARK-34506][CORE] ADD JAR with ivy coordinates should be compatible with Hive transitive behavior ### What changes were proposed in this pull request? SPARK-33084 added the ability to use ivy coordinates with `SparkContext.addJar`. PR #29966 claims to mimic Hive behavior although I found a few cases where it doesn't 1) The default value of the transitive parameter is false, both in case of parameter not being specified in coordinate or parameter value being invalid. The Hive behavior is that transitive is [true if not specified](https://github.com/apache/hive/blob/cb2ac3dcc6af276c6f64ee00f034f082fe75222b/ql/src/java/org/apache/hadoop/hive/ql/util/DependencyResolver.java#L169) in the coordinate and [false for invalid values](https://github.com/apache/hive/blob/cb2ac3dcc6af276c6f64ee00f034f082fe75222b/ql/src/java/org/apache/hadoop/hive/ql/util/DependencyResolver.java#L124). Also, regardless of Hive, I think a default of true for the transitive parameter also matches [ivy's own defaults](https://ant.apache.org/ivy/history/2.5.0/ivyfile/dependency.html#_attributes). 2) The parameter value for transitive parameter is regarded as case-sensitive [based on the understanding](https://github.com/apache/spark/pull/29966#discussion_r547752259) that Hive behavior is case-sensitive. 
However, this is not correct, Hive [treats the parameter value case-insensitively](https://github.com/apache/hive/blob/cb2ac3dcc6af276c6f64ee00f034f082fe75222b/ql/src/java/org/apache/hadoop/hive/ql/util/DependencyResolver.java#L122). I propose that we be compatible with Hive for these behaviors ### Why are the changes needed? To make `ADD JAR` with ivy coordinates compatible with Hive's transitive behavior ### Does this PR introduce _any_ user-facing change? The user-facing changes here are within master as the feature introduced in SPARK-33084 has not been released yet 1. Previously an ivy coordinate without `transitive` parameter specified did not resolve transitive dependency, now it does. 2. Previously an `transitive` parameter value was treated case-sensitively. e.g. `transitive=TRUE` would be treated as false as it did not match exactly `true`. Now it will be treated case-insensitively. ### How was this patch tested? Modified existing unit tests to test new behavior Add new unit test to cover usage of `exclude` with unspecified `transitive` Closes #31623 from shardulm94/spark-34506. Authored-by: Shardul Mahadik Signed-off-by: Takeshi Yamamuro --- .../apache/spark/util/DependencyUtils.scala | 14 ++++---- .../org/apache/spark/SparkContextSuite.scala | 33 +++++++++++++------ ...ql-ref-syntax-aux-resource-mgmt-add-jar.md | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++--- .../sql/hive/execution/HiveQuerySuite.scala | 7 +++- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala index 60e866a556796..f7135edd2129d 100644 --- a/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala @@ -59,8 +59,9 @@ private[spark] object DependencyUtils extends Logging { * @param uri Ivy URI need to be downloaded. * @return Tuple value of parameter `transitive` and `exclude` value. * - * 1. transitive: whether to download dependency jar of Ivy URI, default value is false - * and this parameter value is case-sensitive. Invalid value will be treat as false. + * 1. transitive: whether to download dependency jar of Ivy URI, default value is true + * and this parameter value is case-insensitive. This mimics Hive's behaviour for + * parsing the transitive parameter. Invalid value will be treat as false. * Example: Input: exclude=org.mortbay.jetty:jetty&transitive=true * Output: true * @@ -72,7 +73,7 @@ private[spark] object DependencyUtils extends Logging { private def parseQueryParams(uri: URI): (Boolean, String) = { val uriQuery = uri.getQuery if (uriQuery == null) { - (false, "") + (true, "") } else { val mapTokens = uriQuery.split("&").map(_.split("=")) if (mapTokens.exists(isInvalidQueryString)) { @@ -81,14 +82,15 @@ private[spark] object DependencyUtils extends Logging { } val groupedParams = mapTokens.map(kv => (kv(0), kv(1))).groupBy(_._1) - // Parse transitive parameters (e.g., transitive=true) in an Ivy URI, default value is false + // Parse transitive parameters (e.g., transitive=true) in an Ivy URI, default value is true val transitiveParams = groupedParams.get("transitive") if (transitiveParams.map(_.size).getOrElse(0) > 1) { logWarning("It's best to specify `transitive` parameter in ivy URI query only once." 
+ " If there are multiple `transitive` parameter, we will select the last one") } val transitive = - transitiveParams.flatMap(_.takeRight(1).map(_._2 == "true").headOption).getOrElse(false) + transitiveParams.flatMap(_.takeRight(1).map(_._2.equalsIgnoreCase("true")).headOption) + .getOrElse(true) // Parse an excluded list (e.g., exclude=org.mortbay.jetty:jetty,org.eclipse.jetty:jetty-http) // in an Ivy URI. When download Ivy URI jar, Spark won't download transitive jar @@ -125,7 +127,7 @@ private[spark] object DependencyUtils extends Logging { * `parameter=value¶meter=value...` * Note that currently Ivy URI query part support two parameters: * 1. transitive: whether to download dependent jars related to your Ivy URI. - * transitive=false or `transitive=true`, if not set, the default value is false. + * transitive=false or `transitive=true`, if not set, the default value is true. * 2. exclude: exclusion list when download Ivy URI jar and dependency jars. * The `exclude` parameter content is a ',' separated `group:module` pair string : * `exclude=group:module,group:module...` diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 7a4970e60e932..0ba2a030dac12 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -1035,13 +1035,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } - test("SPARK-33084: Add jar support Ivy URI -- default transitive = false") { + test("SPARK-33084: Add jar support Ivy URI -- default transitive = true") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0") assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) - assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) - - sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=true") assert(sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) } @@ -1083,6 +1080,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("SPARK-34506: Add jar support Ivy URI -- transitive=false will not download " + + "dependency jars") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=false") + assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) + assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) + } + + test("SPARK-34506: Add jar support Ivy URI -- test exclude param when transitive unspecified") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?exclude=commons-lang:commons-lang") + assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) + assert(sc.listJars().exists(_.contains("org.slf4j_slf4j-api-1.7.10.jar"))) + assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) + } + test("SPARK-33084: Add jar support Ivy URI -- test exclude param when transitive=true") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0" + @@ -1131,24 +1144,24 @@ class 
SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("SPARK-33084: Add jar support Ivy URI -- test param key case sensitive") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) - sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?TRANSITIVE=true") + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=false") assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) - sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=true") + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?TRANSITIVE=false") assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) assert(sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) } - test("SPARK-33084: Add jar support Ivy URI -- test transitive value case sensitive") { + test("SPARK-33084: Add jar support Ivy URI -- test transitive value case insensitive") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) - sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=TRUE") + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=FALSE") assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) - sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=true") + sc.addJar("ivy://org.apache.hive:hive-storage-api:2.7.0?transitive=false") assert(sc.listJars().exists(_.contains("org.apache.hive_hive-storage-api-2.7.0.jar"))) - assert(sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) + assert(!sc.listJars().exists(_.contains("commons-lang_commons-lang-2.6.jar"))) } test("SPARK-34346: hadoop configuration priority for spark/hive/hadoop configs") { diff --git a/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md b/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md index 6d31125fd612d..e5ac58ba8195f 100644 --- a/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md +++ b/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md @@ -36,7 +36,7 @@ ADD JAR file_name The name of the JAR file to be added. It could be either on a local file system or a distributed file system or an Ivy URI. Apache Ivy is a popular dependency manager focusing on flexibility and simplicity. Now we support two parameter in URI query string: - * transitive: whether to download dependent jars related to your ivy URL. It is case-sensitive and only take last one if multiple transitive parameters are specified. + * transitive: whether to download dependent jars related to your ivy URL. The parameter name is case-sensitive, and the parameter value is case-insensitive. If multiple transitive parameters are specified, the last one wins. * exclude: exclusion list during downloading Ivy URI jar and dependent jars. 
User can write Ivy URI such as: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82c49f9cbf29a..98af68b3f4cb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3726,13 +3726,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark test("SPARK-33084: Add jar support Ivy URI in SQL") { val sc = spark.sparkContext val hiveVersion = "2.3.8" - // default transitive=false, only download specified jar - sql(s"ADD JAR ivy://org.apache.hive.hcatalog:hive-hcatalog-core:$hiveVersion") + // transitive=false, only download specified jar + sql(s"ADD JAR ivy://org.apache.hive.hcatalog:hive-hcatalog-core:$hiveVersion?transitive=false") assert(sc.listJars() .exists(_.contains(s"org.apache.hive.hcatalog_hive-hcatalog-core-$hiveVersion.jar"))) - // test download ivy URL jar return multiple jars - sql("ADD JAR ivy://org.scala-js:scalajs-test-interface_2.12:1.2.0?transitive=true") + // default transitive=true, test download ivy URL jar return multiple jars + sql("ADD JAR ivy://org.scala-js:scalajs-test-interface_2.12:1.2.0") assert(sc.listJars().exists(_.contains("scalajs-library_2.12"))) assert(sc.listJars().exists(_.contains("scalajs-test-interface_2.12"))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index d4bcba4128db9..87c2541dc7555 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1224,7 +1224,12 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd test("SPARK-33084: Add jar support Ivy URI in SQL") { val testData = TestHive.getHiveFile("data/files/sample.json").toURI withTable("t") { - sql(s"ADD JAR ivy://org.apache.hive.hcatalog:hive-hcatalog-core:$hiveVersion") + // hive-catalog-core has some transitive dependencies which dont exist on maven central + // and hence cannot be found in the test environment or are non-jar (.pom) which cause + // failures in tests. Use transitive=false as it should be good enough to test the Ivy + // support in Hive ADD JAR + sql(s"ADD JAR ivy://org.apache.hive.hcatalog:hive-hcatalog-core:$hiveVersion" + + "?transitive=false") sql( """CREATE TABLE t(a string, b string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) From d574308864816b74372346b1f0b497f2e71c2000 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sun, 28 Feb 2021 16:21:42 -0800 Subject: [PATCH 52/60] [SPARK-34579][SQL][TEST] Fix wrong UT in SQLQuerySuite ### What changes were proposed in this pull request? Some UT in SQLQuerySuite is not incorrect, it have wrong table name in `withTable`, this pr to make it correct. ### Why are the changes needed? Fix UT ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existed UT Closes #31681 from AngersZhuuuu/SPARK-34569. 
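For context, `withTable` drops exactly the tables named in its argument once the enclosed block finishes, so that name has to match the table the block actually creates; a minimal sketch of the corrected pattern as it appears inside the suite (`withTable`, `sql` and `checkAnswer` are helpers provided by the test framework):

```scala
// The cleanup only works if the name given to withTable matches the created table.
withTable("test_ctas_123") {
  sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src")
  checkAnswer(
    sql("SELECT key, value FROM test_ctas_123 ORDER BY key"),
    sql("SELECT key, value FROM src ORDER BY key").collect().toSeq)
}
```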
Authored-by: Angerszhuuuu Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 96c5bf7e27279..f3aad782cebc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -775,7 +775,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi } test("test CTAS") { - withTable("test_ctas_1234") { + withTable("test_ctas_123") { sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") checkAnswer( sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), @@ -1969,7 +1969,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi for (i <- 1 to 3) { Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) } - withTable("load_t_folder_wildcard") { + withTable("load_t") { sql("CREATE TABLE load_t (a STRING) USING hive") sql(s"LOAD DATA LOCAL INPATH '${ path.substring(0, path.length - 1) From 1afe284ed899792a5230b0635ae11ff56ebd8f1b Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 1 Mar 2021 09:30:18 +0900 Subject: [PATCH 53/60] [SPARK-34570][SQL] Remove dead code from constructors of [Hive]SessionStateBuilder ### What changes were proposed in this pull request? the parameter - `options` is never used. The changes here was part of https://github.com/apache/spark/pull/30642, It got reverted for easier backporting #30642 as a hotfix by https://github.com/apache/spark/pull/30642/commits/dad24543aa7bb7cc81d2a8522112eb797b015633, this PR brings it back to master. ### Why are the changes needed? remove unless dead code ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? Passing CI is enough. Closes #31683 from yaooqinn/SPARK-34570. 
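For reference, after this change the builders are constructed from just the session and an optional parent state; a minimal sketch as used inside Spark's own code (assuming a `SparkSession` value named `session`):

```scala
// The unused options map is gone; only (session, parentState) remains.
val sessionState = new SessionStateBuilder(session, None).build()
```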
Authored-by: Kent Yao Signed-off-by: Takeshi Yamamuro --- .../scala/org/apache/spark/sql/SparkSession.scala | 11 ++++------- .../spark/sql/internal/BaseSessionStateBuilder.scala | 3 +-- .../org/apache/spark/sql/internal/SessionState.scala | 7 +++---- .../org/apache/spark/sql/test/TestSQLContext.scala | 9 ++++----- .../spark/sql/hive/HiveSessionStateBuilder.scala | 7 +++---- .../org/apache/spark/sql/hive/test/TestHive.scala | 9 ++++----- 6 files changed, 19 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 0fada5500edde..678233dfea16b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -152,8 +152,7 @@ class SparkSession private( .getOrElse { val state = SparkSession.instantiateSessionState( SparkSession.sessionStateClassName(sparkContext.conf), - self, - initialSessionOptions) + self) state } } @@ -1134,16 +1133,14 @@ object SparkSession extends Logging { */ private def instantiateSessionState( className: String, - sparkSession: SparkSession, - options: Map[String, String]): SessionState = { + sparkSession: SparkSession): SessionState = { try { // invoke new [Hive]SessionStateBuilder( // SparkSession, - // Option[SessionState], - // Map[String, String]) + // Option[SessionState]) val clazz = Utils.classForName(className) val ctor = clazz.getConstructors.head - ctor.newInstance(sparkSession, None, options).asInstanceOf[BaseSessionStateBuilder].build() + ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build() } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index da1782cff5412..eb769340d7c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -56,8 +56,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager @Unstable abstract class BaseSessionStateBuilder( val session: SparkSession, - val parentState: Option[SessionState], - val options: Map[String, String]) { + val parentState: Option[SessionState]) { type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 258c9bbac7b80..12ec732beead3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -135,10 +135,9 @@ private[sql] object SessionState { @Unstable class SessionStateBuilder( session: SparkSession, - parentState: Option[SessionState], - options: Map[String, String]) - extends BaseSessionStateBuilder(session, parentState, options) { - override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _, Map.empty) + parentState: Option[SessionState]) + extends BaseSessionStateBuilder(session, parentState) { + override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 380723029b8a8..47a6f3617da63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -39,7 +39,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) @transient override lazy val sessionState: SessionState = { - new TestSQLSessionStateBuilder(this, None, Map.empty).build() + new TestSQLSessionStateBuilder(this, None).build() } // Needed for Java tests @@ -66,9 +66,8 @@ private[sql] object TestSQLContext { private[sql] class TestSQLSessionStateBuilder( session: SparkSession, - state: Option[SessionState], - options: Map[String, String]) - extends SessionStateBuilder(session, state, options) with WithTestConf { + state: Option[SessionState]) + extends SessionStateBuilder(session, state) with WithTestConf { override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs - override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _, Map.empty) + override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 95dec02ef172c..b98a956dfcc50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -39,9 +39,8 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo */ class HiveSessionStateBuilder( session: SparkSession, - parentState: Option[SessionState], - options: Map[String, String]) - extends BaseSessionStateBuilder(session, parentState, options) { + parentState: Option[SessionState]) + extends BaseSessionStateBuilder(session, parentState) { private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog @@ -116,7 +115,7 @@ class HiveSessionStateBuilder( } } - override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _, Map.empty) + override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) } class HiveSessionResourceLoader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index cbba9be32b77c..061ab9b6f38cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -223,7 +223,7 @@ private[hive] class TestHiveSparkSession( @transient override lazy val sessionState: SessionState = { - new TestHiveSessionStateBuilder(this, parentSessionState, Map.empty).build() + new TestHiveSessionStateBuilder(this, parentSessionState).build() } lazy val metadataHive: HiveClient = { @@ -651,9 +651,8 @@ private[hive] object TestHiveContext { private[sql] class TestHiveSessionStateBuilder( session: SparkSession, - state: Option[SessionState], - options: Map[String, String]) - extends HiveSessionStateBuilder(session, state, options) + state: Option[SessionState]) + extends HiveSessionStateBuilder(session, state) with WithTestConf { override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs @@ -662,7 +661,7 @@ private[sql] class TestHiveSessionStateBuilder( new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan) 
} - override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _, Map.empty) + override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _) } private[hive] object HiveTestJars { From f494c5cff9d56744f8e7a2b646be6d01de8a09f4 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Sun, 28 Feb 2021 16:37:49 -0800 Subject: [PATCH 54/60] [SPARK-33212][FOLLOWUP] Add hadoop-yarn-server-web-proxy for Hadoop 3.x profile ### What changes were proposed in this pull request? This adds `hadoop-yarn-server-web-proxy` as a dependency for Yarn and the Hadoop 3.x profile (it is already a dependency for 2.x). Also excludes some dependencies from the module which are already covered by other Hadoop jars used by Spark. ### Why are the changes needed? The class `org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter` is used by `ApplicationMaster`: ```scala private def addAmIpFilter(driver: Option[RpcEndpointRef], proxyBase: String) = { val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" val params = client.getAmIpFilterParams(yarnConf, proxyBase) driver match { case Some(d) => d.send(AddWebUIFilter(amFilter, params, proxyBase)) ... ``` and will be loaded at runtime. Therefore, without the above jar a Spark Yarn app will fail with `ClassNotFoundError`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests. Also tested manually; it worked with the fix, while it was failing previously. Closes #31642 from sunchao/SPARK-33212-followup-2. Authored-by: Chao Sun Signed-off-by: Dongjoon Hyun --- assembly/pom.xml | 4 ++++ dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 1 + pom.xml | 24 ++++++++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index 6aa97710f7307..d662aae96c4af 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -136,6 +136,10 @@
      <artifactId>spark-yarn_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-web-proxy</artifactId>
+    </dependency>
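For reference, one quick way to confirm the jar actually ends up in a local distribution (an illustrative check only; the `dist/` layout and profile names below assume the standard `dev/make-distribution.sh` flow with the Hadoop 3.x profile):

```bash
# Build a YARN + Hadoop 3.x distribution and confirm the web-proxy jar is bundled.
./dev/make-distribution.sh --name verify-webproxy --tgz -Pyarn -Phadoop-3.2
ls dist/jars | grep hadoop-yarn-server-web-proxy   # expect hadoop-yarn-server-web-proxy-3.2.2.jar
```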
diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 977fc4b1210f1..39951c2aec3e8 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -58,6 +58,7 @@ gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar hadoop-client-api/3.2.2//hadoop-client-api-3.2.2.jar hadoop-client-runtime/3.2.2//hadoop-client-runtime-3.2.2.jar +hadoop-yarn-server-web-proxy/3.2.2//hadoop-yarn-server-web-proxy-3.2.2.jar hive-beeline/2.3.8//hive-beeline-2.3.8.jar hive-cli/2.3.8//hive-cli-2.3.8.jar hive-common/2.3.8//hive-common-2.3.8.jar diff --git a/pom.xml b/pom.xml index 3bd5ef74a9336..4c300e475700d 100644 --- a/pom.xml +++ b/pom.xml @@ -1407,6 +1407,26 @@ ${yarn.version} ${hadoop.deps.scope} + + org.apache.hadoop + hadoop-yarn-server-common + + + org.apache.hadoop + hadoop-yarn-common + + + org.apache.hadoop + hadoop-yarn-api + + + org.bouncycastle + bcprov-jdk15on + + + org.bouncycastle + bcpkix-jdk15on + org.fusesource.leveldbjni leveldbjni-all @@ -1427,6 +1447,10 @@ javax.servlet servlet-api + + javax.servlet + javax.servlet-api + commons-logging commons-logging From 3d0ee9604eab3c01af469049d80b053bf2aaa636 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Mon, 1 Mar 2021 11:18:57 +0900 Subject: [PATCH 55/60] [SPARK-34520][CORE][FOLLOW-UP] Remove SecurityManager in GangliaSink ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/31636. There was one place missed in `GangliaSink`, and we should also remove `SecurityManager`. ### Why are the changes needed? To make `GangliaSink` work. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? It was found in the internal it tests in the company I work for. Closes #31688 from HyukjinKwon/SPARK-34520-followup. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- .../scala/org/apache/spark/metrics/sink/GangliaSink.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 7266187597589..2b48a34abb8fe 100644 --- a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -25,11 +25,10 @@ import com.codahale.metrics.ganglia.GangliaReporter import info.ganglia.gmetric4j.gmetric.GMetric import info.ganglia.gmetric4j.gmetric.GMetric.UDPAddressingMode -import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GangliaSink(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +class GangliaSink( + val property: Properties, val registry: MetricRegistry) extends Sink { val GANGLIA_KEY_PERIOD = "period" val GANGLIA_DEFAULT_PERIOD = 10 From 62737e140c7b04805726a33c392c297335db7b45 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 1 Mar 2021 13:55:35 +0900 Subject: [PATCH 56/60] [SPARK-34556][SQL] Checking duplicate static partition columns should respect case sensitive conf ### What changes were proposed in this pull request? This PR makes partition spec parsing respect case sensitive conf. ### Why are the changes needed? 
When parsing the partition spec, Spark will call `org.apache.spark.sql.catalyst.parser.ParserUtils.checkDuplicateKeys` to check if there are duplicate partition column names in the list. But this method is always case sensitive and doesn't detect duplicate partition column names when using different cases. ### Does this PR introduce _any_ user-facing change? Yep. This prevents users from writing incorrect queries such as `INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)` when they don't enable case sensitive conf. ### How was this patch tested? The new added test will fail without this change. Closes #31669 from zsxwing/SPARK-34556. Authored-by: Shixiong Zhu Signed-off-by: HyukjinKwon --- .../sql/catalyst/parser/AstBuilder.scala | 6 ++++- .../apache/spark/sql/SQLInsertTestSuite.scala | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8ea0f2a750365..c56426b2d42cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -482,7 +482,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for // partition columns will be done in analyzer. - checkDuplicateKeys(parts.toSeq, ctx) + if (conf.caseSensitiveAnalysis) { + checkDuplicateKeys(parts.toSeq, ctx) + } else { + checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx) + } parts.toMap } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index c7446c7a9f443..67c5f12dc71dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -208,6 +208,28 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("t"), Row("1", null)) } } + + test("SPARK-34556: " + + "checking duplicate static partition columns should respect case sensitive conf") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c)") + val e = intercept[AnalysisException] { + sql("INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)") + } + assert(e.getMessage.contains("Found duplicate keys 'c'")) + } + // The following code is skipped for Hive because columns stored in Hive Metastore is always + // case insensitive and we cannot create such table in Hive Metastore. 
+ if (!format.startsWith("hive")) { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTable("t") { + sql(s"CREATE TABLE t(i int, c string, C string) USING PARQUET PARTITIONED BY (c, C)") + sql("INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)") + checkAnswer(spark.table("t"), Row(1, "2", "3")) + } + } + } + } } class FileSourceSQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession { From a6cc5e625fcba2ef889f759207a01075cce3b38b Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 1 Mar 2021 15:36:33 +0900 Subject: [PATCH 57/60] [SPARK-34574][DOCS] Jekyll fails to generate Scala API docs for Scala 2.13 ### What changes were proposed in this pull request? This PR fixes an issue that `bundler exec jekyll` build fails to generate Scala API docs even though after `dev/change-scala-version.sh 2.13` run. ### Why are the changes needed? The reason of this issue is that `build/sbt` in `copy_api_dirs.rb` runs without `-Pscala-2.13`. So, it's a bug. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? I tested the following patterns manually. * `dev/change-scala-version 2.13` and then `bundler exec jekyll build` * `dev/change-scala-version 2.12` to change back to Scala 2.12 and then `bundler exec jekyll build` * `dev/change-scala-version 2.13` two times to confirm the idempotency and then `bundler exec jekyll build` * `dev/change-scala-version 2.12` two times to confirm the idempotency and then `bundler exec jekyll build` Closes #31690 from sarutak/jekyll-scala-2.13. Authored-by: Kousuke Saruta Signed-off-by: HyukjinKwon --- dev/change-scala-version.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh index 06411b9b12a0d..a69405e0cca89 100755 --- a/dev/change-scala-version.sh +++ b/dev/change-scala-version.sh @@ -67,4 +67,9 @@ sed_i '1,/[0-9]*\.[0-9]*[0-9 # Update source of scaladocs echo "$BASEDIR/docs/_plugins/copy_api_dirs.rb" +if [ $TO_VERSION = "2.13" ]; then + sed_i '/\-Pscala-'$TO_VERSION'/!s:build/sbt:build/sbt \-Pscala\-'$TO_VERSION':' "$BASEDIR/docs/_plugins/copy_api_dirs.rb" +else + sed_i 's:build/sbt \-Pscala\-'$FROM_VERSION':build/sbt:' "$BASEDIR/docs/_plugins/copy_api_dirs.rb" +fi sed_i 's/scala\-'$FROM_VERSION'/scala\-'$TO_VERSION'/' "$BASEDIR/docs/_plugins/copy_api_dirs.rb" From 984ff396a2eeea98169575228dc00513cdca85ea Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 1 Mar 2021 22:20:28 +0800 Subject: [PATCH 58/60] [SPARK-34561][SQL] Fix drop/add columns from/to a dataset of v2 `DESCRIBE TABLE` ### What changes were proposed in this pull request? In the PR, I propose to generate "stable" output attributes per the logical node of the `DESCRIBE TABLE` command. ### Why are the changes needed? This fixes the issue demonstrated by the example: ```scala val tbl = "testcat.ns1.ns2.tbl" sql(s"CREATE TABLE $tbl (c0 INT) USING _") val description = sql(s"DESCRIBE TABLE $tbl") description.drop("comment") ``` The `drop()` method fails with the error: ``` org.apache.spark.sql.AnalysisException: Resolved attribute(s) col_name#102,data_type#103 missing from col_name#29,data_type#30,comment#31 in operator !Project [col_name#102, data_type#103]. Attribute(s) with the same name appear in the operation: col_name,data_type. 
Please check if the right attribute(s) are used.; !Project [col_name#102, data_type#103] +- LocalRelation [col_name#29, data_type#30, comment#31] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis(CheckAnalysis.scala:51) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis$(CheckAnalysis.scala:50) ``` ### Does this PR introduce _any_ user-facing change? Yes. After the changes, `drop()`/`add()` works as expected: ```scala description.drop("comment").show() +---------------+---------+ | col_name|data_type| +---------------+---------+ | c0| int| | | | | # Partitioning| | |Not partitioned| | +---------------+---------+ ``` ### How was this patch tested? 1. Run new test: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *DataSourceV2SQLSuite" ``` 2. Run existing test suite: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *CatalogedDDLSuite" ``` Closes #31676 from MaxGekk/describe-table-drop-column. Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../catalyst/plans/logical/v2Commands.scala | 8 ++++-- .../analysis/ResolveSessionCatalog.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +-- .../sql-tests/results/describe.sql.out | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 26 ++++++++++++++++++- .../command/PlanResolutionSuite.scala | 8 +++--- 6 files changed, 39 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 7316e615ff8d6..e5c4370e1648e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -345,9 +345,13 @@ object ShowNamespaces { case class DescribeRelation( relation: LogicalPlan, partitionSpec: TablePartitionSpec, - isExtended: Boolean) extends Command { + isExtended: Boolean, + override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = Seq(relation) - override def output: Seq[Attribute] = DescribeCommandSchema.describeTableAttributes() +} + +object DescribeRelation { + def getOutputAttrs: Seq[Attribute] = DescribeCommandSchema.describeTableAttributes() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 5748bc8f1f430..4d17d329f0f24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -200,7 +200,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) AlterTableRenameCommand(oldName.asTableIdentifier, newName.asTableIdentifier, isView) // Use v1 command to describe (temp) view, as v2 catalog doesn't support view yet. 
- case DescribeRelation(ResolvedV1TableOrViewIdentifier(ident), partitionSpec, isExtended) => + case DescribeRelation(ResolvedV1TableOrViewIdentifier(ident), partitionSpec, isExtended, _) => DescribeTableCommand(ident.asTableIdentifier, partitionSpec, isExtended) case DescribeColumn(ResolvedViewIdentifier(ident), column: UnresolvedAttribute, isExtended) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 135de2ad4c5c8..639ade4ec44da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -270,11 +270,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case desc @ DescribeNamespace(ResolvedNamespace(catalog, ns), extended) => DescribeNamespaceExec(desc.output, catalog.asNamespaceCatalog, ns, extended) :: Nil - case desc @ DescribeRelation(r: ResolvedTable, partitionSpec, isExtended) => + case DescribeRelation(r: ResolvedTable, partitionSpec, isExtended, output) => if (partitionSpec.nonEmpty) { throw new AnalysisException("DESCRIBE does not support partition for v2 tables.") } - DescribeTableExec(desc.output, r.table, isExtended) :: Nil + DescribeTableExec(output, r.table, isExtended) :: Nil case desc @ DescribeColumn(_: ResolvedTable, column, isExtended) => column match { diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 3b5d8a1396283..5f88478a239fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -539,7 +539,7 @@ EXPLAIN EXTENDED DESC t struct -- !query output == Parsed Logical Plan == -'DescribeRelation false +'DescribeRelation false, [col_name#x, data_type#x, comment#x] +- 'UnresolvedTableOrView [t], DESCRIBE TABLE, true == Analyzed Logical Plan == diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 2f57298856fb5..ca4dff871ed50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION} import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, LongType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -2576,6 +2576,30 @@ class DataSourceV2SQLSuite } } + test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") { + val tbl = s"${catalogAndNamespace}tbl" + withTable(tbl) { + sql(s"CREATE TABLE $tbl (c0 INT) USING $v2Format") + val description = sql(s"DESCRIBE TABLE $tbl") + val noCommentDataset = description.drop("comment") + val 
expectedSchema = new StructType() + .add( + name = "col_name", + dataType = StringType, + nullable = false, + metadata = new MetadataBuilder().putString("comment", "name of the column").build()) + .add( + name = "data_type", + dataType = StringType, + nullable = false, + metadata = new MetadataBuilder().putString("comment", "data type of the column").build()) + assert(noCommentDataset.schema === expectedSchema) + val isNullDataset = noCommentDataset + .withColumn("is_null", noCommentDataset("col_name").isNull) + assert(isNullDataset.schema === expectedSchema.add("is_null", BooleanType, false)) + } + } + private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 1b090369f2a23..dcc91f38e604f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -875,13 +875,13 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed2, expected2) } else { parsed1 match { - case DescribeRelation(_: ResolvedTable, _, isExtended) => + case DescribeRelation(_: ResolvedTable, _, isExtended, _) => assert(!isExtended) case _ => fail("Expect DescribeTable, but got:\n" + parsed1.treeString) } parsed2 match { - case DescribeRelation(_: ResolvedTable, _, isExtended) => + case DescribeRelation(_: ResolvedTable, _, isExtended, _) => assert(isExtended) case _ => fail("Expect DescribeTable, but got:\n" + parsed2.treeString) } @@ -895,7 +895,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed3, expected3) } else { parsed3 match { - case DescribeRelation(_: ResolvedTable, partitionSpec, isExtended) => + case DescribeRelation(_: ResolvedTable, partitionSpec, isExtended, _) => assert(!isExtended) assert(partitionSpec == Map("a" -> "1")) case _ => fail("Expect DescribeTable, but got:\n" + parsed2.treeString) @@ -1198,7 +1198,7 @@ class PlanResolutionSuite extends AnalysisTest { case AppendData(r: DataSourceV2Relation, _, _, _, _) => assert(r.catalog.exists(_ == catalogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) - case DescribeRelation(r: ResolvedTable, _, _) => + case DescribeRelation(r: ResolvedTable, _, _, _) => assert(r.catalog == catalogIdent) assert(r.identifier.name() == tableIdent) case ShowTableProperties(r: ResolvedTable, _, _) => From 85b50d42586be2f3f19c7d94a8aa297215ebfbc2 Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Mon, 1 Mar 2021 08:39:38 -0600 Subject: [PATCH 59/60] [SPARK-34539][BUILD][INFRA] Remove stand-alone version Zinc server ### What changes were proposed in this pull request? Clean up all Zinc standalone server code and related configuration. ### Why are the changes needed? ![image](https://user-images.githubusercontent.com/1736354/109154790-c1d3e580-77a9-11eb-8cde-835deed6e10e.png) - Zinc is the incremental compiler used to speed up compilation. - The scala-maven-plugin is the Maven plugin used by Spark; one of its functions is to integrate Zinc to enable incremental compilation. - Since Spark v3.0.0 ([SPARK-28759](https://issues.apache.org/jira/browse/SPARK-28759)), the scala-maven-plugin has been upgraded to v4.x, which means the standalone Zinc v0.3.13 server is no longer needed.
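A quick local check of that claim (an illustrative sketch only, assuming a Spark checkout; it is not part of this change):

```bash
# scala-maven-plugin 4.x embeds the Zinc/sbt incremental compiler, so no
# standalone zinc process has to be running for incremental builds.
ps aux | grep -i '[z]inc' || echo "no standalone zinc server running"
./build/mvn -DskipTests -pl core package      # first, full build of core
touch core/src/main/scala/org/apache/spark/SparkContext.scala
./build/mvn -DskipTests -pl core package      # second run only recompiles affected sources
```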
However, we still download, install, and start the standalone Zinc server. We should remove all Zinc standalone server code and all related configuration. See more in [SPARK-34539](https://issues.apache.org/jira/projects/SPARK/issues/SPARK-34539) or the doc [Zinc standalone server is useless after scala-maven-plugin 4.x](https://docs.google.com/document/d/1u4kCHDx7KjVlHGerfmbcKSB0cZo6AD4cBdHSse-SBsM). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Run any mvn build: ./build/mvn -DskipTests clean package -pl core You can see that incremental compilation still works; the "scala-maven-plugin:4.3.0:compile (scala-compile-first)" stage reports incremental compilation info, like: ``` [INFO] --- scala-maven-plugin:4.3.0:testCompile (scala-test-compile-first) spark-core_2.12 --- [INFO] Using incremental compilation using Mixed compile order [INFO] Compiler bridge file: /root/.sbt/1.0/zinc/org.scala-sbt/org.scala-sbt-compiler-bridge_2.12-1.3.1-bin_2.12.10__52.0-1.3.1_20191012T045515.jar [INFO] compiler plugin: BasicArtifact(com.github.ghik,silencer-plugin_2.12.10,1.6.0,null) [INFO] Compiling 303 Scala sources and 27 Java sources to /root/spark/core/target/scala-2.12/test-classes ... ``` Closes #31647 from Yikun/cleanup-zinc. Authored-by: Yikun Jiang Signed-off-by: Sean Owen --- .github/workflows/build_and_test.yml | 12 +++---- .gitignore | 1 - build/mvn | 47 +------------------------ dev/create-release/do-release-docker.sh | 1 - dev/create-release/release-build.sh | 24 +++---------- dev/run-tests.py | 13 +------ docs/building-spark.md | 7 ++-- pom.xml | 1 - 8 files changed, 14 insertions(+), 92 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 6c61281740748..8be24f18e41a9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -99,12 +99,11 @@ jobs: if: ${{ github.event.inputs.target != '' }} run: git merge --progress --ff-only origin/${{ github.event.inputs.target }} # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - - name: Cache Scala, SBT, Maven and Zinc + - name: Cache Scala, SBT and Maven uses: actions/cache@v2 with: path: | build/apache-maven-* - build/zinc-* build/scala-* build/*.jar ~/.sbt @@ -186,12 +185,11 @@ jobs: if: ${{ github.event.inputs.target != '' }} run: git merge --progress --ff-only origin/${{ github.event.inputs.target }} # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - - name: Cache Scala, SBT, Maven and Zinc + - name: Cache Scala, SBT and Maven uses: actions/cache@v2 with: path: | build/apache-maven-* - build/zinc-* build/scala-* build/*.jar ~/.sbt @@ -254,12 +252,11 @@ jobs: if: ${{ github.event.inputs.target != '' }} run: git merge --progress --ff-only origin/${{ github.event.inputs.target }} # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - - name: Cache Scala, SBT, Maven and Zinc + - name: Cache Scala, SBT and Maven uses: actions/cache@v2 with: path: | build/apache-maven-* - build/zinc-* build/scala-* build/*.jar ~/.sbt @@ -297,12 +294,11 @@ jobs: - name: Checkout Spark repository uses: actions/checkout@v2 # Cache local repositories. Note that GitHub Actions cache has a 2G limit.
- - name: Cache Scala, SBT, Maven and Zinc + - name: Cache Scala, SBT and Maven uses: actions/cache@v2 with: path: | build/apache-maven-* - build/zinc-* build/scala-* build/*.jar ~/.sbt diff --git a/.gitignore b/.gitignore index 917eac1e6c882..021af9ba4bba7 100644 --- a/.gitignore +++ b/.gitignore @@ -30,7 +30,6 @@ R/pkg/tests/fulltests/Rplots.pdf build/*.jar build/apache-maven* build/scala* -build/zinc* cache checkpoint conf/*.cmd diff --git a/build/mvn b/build/mvn index 672599a280310..719d7573f4c05 100755 --- a/build/mvn +++ b/build/mvn @@ -91,27 +91,6 @@ install_mvn() { fi } -# Install zinc under the build/ folder -install_zinc() { - local ZINC_VERSION=0.3.15 - ZINC_BIN="$(command -v zinc)" - if [ "$ZINC_BIN" ]; then - local ZINC_DETECTED_VERSION="$(zinc -version | head -n1 | awk '{print $5}')" - fi - - if [ $(version $ZINC_DETECTED_VERSION) -lt $(version $ZINC_VERSION) ]; then - local zinc_path="zinc-${ZINC_VERSION}/bin/zinc" - [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} - - install_app \ - "${TYPESAFE_MIRROR}/zinc/${ZINC_VERSION}" \ - "zinc-${ZINC_VERSION}.tgz" \ - "${zinc_path}" - ZINC_BIN="${_DIR}/${zinc_path}" - fi -} - # Determine the Scala version from the root pom.xml file, set the Scala URL, # and, with that, download the specific version of Scala necessary under # the build/ folder @@ -131,31 +110,12 @@ install_scala() { SCALA_LIBRARY="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-library.jar" } -# Setup healthy defaults for the Zinc port if none were provided from -# the environment -ZINC_PORT=${ZINC_PORT:-"3030"} - -# Install the proper version of Scala, Zinc and Maven for the build -if [ "$(uname -m)" != 'aarch64' ]; then - install_zinc -fi install_scala install_mvn # Reset the current working directory cd "${_CALLING_DIR}" -# Now that zinc is ensured to be installed, check its status and, if its -# not running or just installed, start it -if [ "$(uname -m)" != 'aarch64' ] && [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then - export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} - "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} - "${ZINC_BIN}" -start -port ${ZINC_PORT} \ - -server 127.0.0.1 -idle-timeout 3h \ - -scala-compiler "${SCALA_COMPILER}" \ - -scala-library "${SCALA_LIBRARY}" &>/dev/null -fi - # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} @@ -163,12 +123,7 @@ echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # call the `mvn` command as usual # SPARK-25854 -"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" "$@" MVN_RETCODE=$? -# Try to shut down zinc explicitly if the server is still running. 
-if [ "$(uname -m)" != 'aarch64' ]; then - "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} -fi - exit $MVN_RETCODE diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh index 19a5345e6569c..f1632f01686c7 100755 --- a/dev/create-release/do-release-docker.sh +++ b/dev/create-release/do-release-docker.sh @@ -133,7 +133,6 @@ ASF_PASSWORD=$ASF_PASSWORD GPG_PASSPHRASE=$GPG_PASSPHRASE RELEASE_STEP=$RELEASE_STEP USER=$USER -ZINC_OPTS=${RELEASE_ZINC_OPTS:-"-Xmx4g -XX:ReservedCodeCacheSize=2g"} EOF JAVA_VOL= diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index a39ea6e82bdfe..52665f7ceb50f 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -179,8 +179,6 @@ if [[ "$1" == "package" ]]; then shasum -a 512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION - ZINC_PORT=3035 - # Updated for each binary build make_binary_release() { NAME=$1 @@ -198,17 +196,12 @@ if [[ "$1" == "package" ]]; then R_FLAG="--r" fi - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - ZINC_PORT=$((ZINC_PORT + 1)) - echo "Building binary dist $NAME" cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME ./dev/change-scala-version.sh $SCALA_VERSION - export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" # Write out the VERSION to PySpark version info we rewrite the - into a . and SNAPSHOT @@ -221,8 +214,7 @@ if [[ "$1" == "package" ]]; then echo "Creating distribution" ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz \ - $PIP_FLAG $R_FLAG $FLAGS \ - -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log + $PIP_FLAG $R_FLAG $FLAGS 2>&1 > ../binary-release-$NAME.log cd .. 
if [[ -n $R_FLAG ]]; then @@ -380,14 +372,11 @@ if [[ "$1" == "publish-snapshot" ]]; then echo "$ASF_PASSWORD" >> $tmp_settings echo "" >> $tmp_settings - # Generate random port for Zinc - export ZINC_PORT=$(python -S -c "import random; print(random.randrange(3030,4030))") - - $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean deploy + $MVN --settings $tmp_settings -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean deploy if [[ $PUBLISH_SCALA_2_13 = 1 ]]; then ./dev/change-scala-version.sh 2.13 - $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $SCALA_2_13_PROFILES $PUBLISH_PROFILES clean deploy + $MVN --settings $tmp_settings -DskipTests $SCALA_2_13_PROFILES $PUBLISH_PROFILES clean deploy fi rm $tmp_settings @@ -417,18 +406,15 @@ if [[ "$1" == "publish-release" ]]; then tmp_repo=$(mktemp -d spark-repo-XXXXX) - # Generate random port for Zinc - export ZINC_PORT=$(python -S -c "import random; print(random.randrange(3030,4030))") - if [[ $PUBLISH_SCALA_2_13 = 1 ]]; then ./dev/change-scala-version.sh 2.13 - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests \ + $MVN -Dmaven.repo.local=$tmp_repo -DskipTests \ $SCALA_2_13_PROFILES $PUBLISH_PROFILES clean install fi if [[ $PUBLISH_SCALA_2_12 = 1 ]]; then ./dev/change-scala-version.sh 2.12 - $MVN -DzincPort=$((ZINC_PORT + 2)) -Dmaven.repo.local=$tmp_repo -DskipTests \ + $MVN -Dmaven.repo.local=$tmp_repo -DskipTests \ $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean install fi diff --git a/dev/run-tests.py b/dev/run-tests.py index e54e098551514..83f9f02dafb2d 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -20,7 +20,6 @@ import itertools from argparse import ArgumentParser import os -import random import re import sys import subprocess @@ -257,21 +256,11 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) -def get_zinc_port(): - """ - Get a randomized port on which to start Zinc - """ - return random.randrange(3030, 4030) - - def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" - zinc_port = get_zinc_port() - os.environ["ZINC_PORT"] = "%s" % zinc_port - zinc_flag = "-DzincPort=%s" % zinc_port - flags = [os.path.join(SPARK_HOME, "build", "mvn"), zinc_flag] + flags = [os.path.join(SPARK_HOME, "build", "mvn")] run_cmd(flags + mvn_args) diff --git a/docs/building-spark.md b/docs/building-spark.md index f9599b642d309..8e1c84a37b436 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -51,7 +51,7 @@ You can fix these problems by setting the `MAVEN_OPTS` variable as discussed bef ### build/mvn -Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](https://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. 
As an example, one can build a version of Spark as follows: +Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](https://www.scala-lang.org/)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: ./build/mvn -DskipTests clean package @@ -163,9 +163,8 @@ For the meanings of these two options, please carefully read the [Setting up Mav ## Speeding up Compilation -Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc -(for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for -developers who build with SBT). For more information about how to do this, refer to the +Developers who compile Spark frequently may want to speed up compilation; e.g., by avoiding re-compilation of the +assembly JAR (for developers who build with SBT). For more information about how to do this, refer to the [Useful Developer Tools page](https://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems diff --git a/pom.xml b/pom.xml index 4c300e475700d..e543cf220286a 100644 --- a/pom.xml +++ b/pom.xml @@ -2562,7 +2562,6 @@ true true incremental - true -unchecked -deprecation From 70f6267de6258459537b748842c059f09f1f2aff Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 1 Mar 2021 18:32:32 +0000 Subject: [PATCH 60/60] [SPARK-34560][SQL] Generate unique output attributes in the `SHOW TABLES` logical node ### What changes were proposed in this pull request? In the PR, I propose to generate unique attributes in the logical nodes of the `SHOW TABLES` command. Also, this PR fixes similar issues in other logical nodes: - ShowTableExtended - ShowViews - ShowTableProperties - ShowFunctions - ShowColumns - ShowPartitions - ShowNamespaces ### Why are the changes needed? This fixes the issue which is demonstrated by the example below: ```scala scala> val show1 = sql("SHOW TABLES IN ns1") show1: org.apache.spark.sql.DataFrame = [namespace: string, tableName: string ... 1 more field] scala> val show2 = sql("SHOW TABLES IN ns2") show2: org.apache.spark.sql.DataFrame = [namespace: string, tableName: string ... 1 more field] scala> show1.show +---------+---------+-----------+ |namespace|tableName|isTemporary| +---------+---------+-----------+ | ns1| tbl1| false| +---------+---------+-----------+ scala> show2.show +---------+---------+-----------+ |namespace|tableName|isTemporary| +---------+---------+-----------+ | ns2| tbl2| false| +---------+---------+-----------+ scala> show1.join(show2).where(show1("tableName") =!= show2("tableName")).show org.apache.spark.sql.AnalysisException: Column tableName#17 are ambiguous. It's probably because you joined several Datasets together, and some of these Datasets are the same. This column points to one of the Datasets but Spark is unable to figure out which one. Please alias the Datasets with different names via `Dataset.as` before joining them, and specify the column using qualified name, e.g. `df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. 
You can also set spark.sql.analyzer.failAmbiguousSelfJoin to false to disable this check. at org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin$.apply(DetectAmbiguousSelfJoin.scala:157) ``` ### Does this PR introduce _any_ user-facing change? Yes. After the changes, the example above works as expected: ```scala scala> show1.join(show2).where(show1("tableName") =!= show2("tableName")).show +---------+---------+-----------+---------+---------+-----------+ |namespace|tableName|isTemporary|namespace|tableName|isTemporary| +---------+---------+-----------+---------+---------+-----------+ | ns1| tbl1| false| ns2| tbl2| false| +---------+---------+-----------+---------+---------+-----------+ ``` ### How was this patch tested? By running the new test: ``` $ build/sbt -Phive-2.3 -Phive-thriftserver "test:testOnly *ShowTablesSuite" ``` Closes #31675 from MaxGekk/fix-output-attrs. Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../catalyst/plans/logical/v2Commands.scala | 41 +++++++++++-------- .../command/ShowTablesSuiteBase.scala | 12 ++++++ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index e5c4370e1648e..cae221ba643a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -331,12 +331,14 @@ case class SetNamespaceLocation( case class ShowNamespaces( namespace: LogicalPlan, pattern: Option[String], - override val output: Seq[Attribute] = ShowNamespaces.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = Seq(namespace) } object ShowNamespaces { - val OUTPUT = Seq(AttributeReference("namespace", StringType, nullable = false)()) + def getOutputAttrs: Seq[Attribute] = { + Seq(AttributeReference("namespace", StringType, nullable = false)()) + } } /** @@ -496,12 +498,12 @@ case class RenameTable( case class ShowTables( namespace: LogicalPlan, pattern: Option[String], - override val output: Seq[Attribute] = ShowTables.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = Seq(namespace) } object ShowTables { - val OUTPUT = Seq( + def getOutputAttrs: Seq[Attribute] = Seq( AttributeReference("namespace", StringType, nullable = false)(), AttributeReference("tableName", StringType, nullable = false)(), AttributeReference("isTemporary", BooleanType, nullable = false)()) @@ -514,12 +516,12 @@ case class ShowTableExtended( namespace: LogicalPlan, pattern: String, partitionSpec: Option[PartitionSpec], - override val output: Seq[Attribute] = ShowTableExtended.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = namespace :: Nil } object ShowTableExtended { - val OUTPUT = Seq( + def getOutputAttrs: Seq[Attribute] = Seq( AttributeReference("namespace", StringType, nullable = false)(), AttributeReference("tableName", StringType, nullable = false)(), AttributeReference("isTemporary", BooleanType, nullable = false)(), @@ -535,12 +537,12 @@ object ShowTableExtended { case class ShowViews( namespace: LogicalPlan, pattern: Option[String], - override val output: 
Seq[Attribute] = ShowViews.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = Seq(namespace) } object ShowViews { - val OUTPUT = Seq( + def getOutputAttrs: Seq[Attribute] = Seq( AttributeReference("namespace", StringType, nullable = false)(), AttributeReference("viewName", StringType, nullable = false)(), AttributeReference("isTemporary", BooleanType, nullable = false)()) @@ -576,12 +578,12 @@ case class ShowCurrentNamespace(catalogManager: CatalogManager) extends Command case class ShowTableProperties( table: LogicalPlan, propertyKey: Option[String], - override val output: Seq[Attribute] = ShowTableProperties.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = table :: Nil } object ShowTableProperties { - val OUTPUT: Seq[Attribute] = Seq( + def getOutputAttrs: Seq[Attribute] = Seq( AttributeReference("key", StringType, nullable = false)(), AttributeReference("value", StringType, nullable = false)()) } @@ -646,12 +648,14 @@ case class ShowFunctions( userScope: Boolean, systemScope: Boolean, pattern: Option[String], - override val output: Seq[Attribute] = ShowFunctions.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowFunctions.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = child.toSeq } object ShowFunctions { - val OUTPUT = Seq(AttributeReference("function", StringType, nullable = false)()) + def getOutputAttrs: Seq[Attribute] = { + Seq(AttributeReference("function", StringType, nullable = false)()) + } } /** @@ -763,12 +767,14 @@ case class ShowCreateTable(child: LogicalPlan, asSerde: Boolean = false) extends case class ShowColumns( child: LogicalPlan, namespace: Option[Seq[String]], - override val output: Seq[Attribute] = ShowColumns.OUTPUT) extends Command { + override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = child :: Nil } object ShowColumns { - val OUTPUT: Seq[Attribute] = Seq(AttributeReference("col_name", StringType, nullable = false)()) + def getOutputAttrs: Seq[Attribute] = { + Seq(AttributeReference("col_name", StringType, nullable = false)()) + } } /** @@ -794,13 +800,16 @@ case class TruncatePartition( case class ShowPartitions( table: LogicalPlan, pattern: Option[PartitionSpec], - override val output: Seq[Attribute] = ShowPartitions.OUTPUT) extends V2PartitionCommand { + override val output: Seq[Attribute] = ShowPartitions.getOutputAttrs) + extends V2PartitionCommand { override def children: Seq[LogicalPlan] = table :: Nil override def allowPartialPartitionSpec: Boolean = true } object ShowPartitions { - val OUTPUT = Seq(AttributeReference("partition", StringType, nullable = false)()) + def getOutputAttrs: Seq[Attribute] = { + Seq(AttributeReference("partition", StringType, nullable = false)()) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala index 7af6940dc94fc..06385017bbd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala @@ -106,6 +106,18 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { } } + test("SPARK-34560: unique 
attribute references") { + withNamespaceAndTable("ns1", "tbl1") { t1 => + sql(s"CREATE TABLE $t1 (col INT) $defaultUsing") + val show1 = sql(s"SHOW TABLES IN $catalog.ns1") + withNamespaceAndTable("ns2", "tbl2") { t2 => + sql(s"CREATE TABLE $t2 (col INT) $defaultUsing") + val show2 = sql(s"SHOW TABLES IN $catalog.ns2") + assert(!show1.join(show2).where(show1("tableName") =!= show2("tableName")).isEmpty) + } + } + } + test("change current catalog and namespace with USE statements") { withNamespaceAndTable("ns", "table") { t => sql(s"CREATE TABLE $t (name STRING, id INT) $defaultUsing")