From 3fbe20129daf1ebff6e95e1643c326172198242c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 1 Feb 2020 20:50:47 -0800 Subject: [PATCH 0001/1280] [SPARK-27686][DOC][SQL] Update migration guide for make Hive 2.3 dependency by default ### What changes were proposed in this pull request? We have upgraded the built-in Hive from 1.2 to 2.3. This may need to set `spark.sql.hive.metastore.version` and `spark.sql.hive.metastore.jars` according to the version of your Hive metastore. Example: ``` --conf spark.sql.hive.metastore.version=1.2.1 --conf spark.sql.hive.metastore.jars=/root/hive-1.2.1-lib/* ``` Otherwise: ``` org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: Unable to fetch table spark_27686. Invalid method name: 'get_table_req'; at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:110) at org.apache.spark.sql.hive.HiveExternalCatalog.tableExists(HiveExternalCatalog.scala:841) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.tableExists(ExternalCatalogWithListener.scala:146) at org.apache.spark.sql.catalyst.catalog.SessionCatalog.tableExists(SessionCatalog.scala:431) at org.apache.spark.sql.execution.command.CreateDataSourceTableCommand.run(createDataSourceTables.scala:52) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68) at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:79) at org.apache.spark.sql.Dataset.$anonfun$logicalPlan$1(Dataset.scala:226) at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3487) at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$4(SQLExecution.scala:100) at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:160) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:87) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3485) at org.apache.spark.sql.Dataset.(Dataset.scala:226) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:96) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:607) ... 47 elided Caused by: org.apache.hadoop.hive.ql.metadata.HiveException: Unable to fetch table spark_27686. Invalid method name: 'get_table_req' at org.apache.hadoop.hive.ql.metadata.Hive.getTable(Hive.java:1282) at org.apache.spark.sql.hive.client.HiveClientImpl.getRawTableOption(HiveClientImpl.scala:422) at org.apache.spark.sql.hive.client.HiveClientImpl.$anonfun$tableExists$1(HiveClientImpl.scala:436) at scala.runtime.java8.JFunction0$mcZ$sp.apply(JFunction0$mcZ$sp.java:23) at org.apache.spark.sql.hive.client.HiveClientImpl.$anonfun$withHiveState$1(HiveClientImpl.scala:322) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:256) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:305) at org.apache.spark.sql.hive.client.HiveClientImpl.tableExists(HiveClientImpl.scala:436) at org.apache.spark.sql.hive.HiveExternalCatalog.$anonfun$tableExists$1(HiveExternalCatalog.scala:841) at scala.runtime.java8.JFunction0$mcZ$sp.apply(JFunction0$mcZ$sp.java:23) at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:100) ... 
63 more Caused by: org.apache.thrift.TApplicationException: Invalid method name: 'get_table_req' at org.apache.thrift.TServiceClient.receiveBase(TServiceClient.java:79) at org.apache.hadoop.hive.metastore.api.ThriftHiveMetastore$Client.recv_get_table_req(ThriftHiveMetastore.java:1567) at org.apache.hadoop.hive.metastore.api.ThriftHiveMetastore$Client.get_table_req(ThriftHiveMetastore.java:1554) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.getTable(HiveMetaStoreClient.java:1350) at org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient.getTable(SessionHiveMetaStoreClient.java:127) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.RetryingMetaStoreClient.invoke(RetryingMetaStoreClient.java:173) at com.sun.proxy.$Proxy38.getTable(Unknown Source) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient$SynchronizedHandler.invoke(HiveMetaStoreClient.java:2336) at com.sun.proxy.$Proxy38.getTable(Unknown Source) at org.apache.hadoop.hive.ql.metadata.Hive.getTable(Hive.java:1274) ... 74 more ``` ### Why are the changes needed? Improve documentation. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? ```SKIP_API=1 jekyll build```: ![image](https://user-images.githubusercontent.com/5399861/73531432-67a50b80-4455-11ea-9401-5cad12fd3d14.png) Closes #27161 from wangyum/SPARK-27686. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun (cherry picked from commit cd5f03a3ba18ae455f93abc5e5d04f098fa8f046) Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 533c96a0832de..e4d2358b5de09 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -330,6 +330,9 @@ license: | - Since Spark 3.0, `SHOW CREATE TABLE` will always return Spark DDL, even when the given table is a Hive serde table. For Hive DDL, please use `SHOW CREATE TABLE AS SERDE` command instead. + - Since Spark 3.0, we upgraded the built-in Hive from 1.2 to 2.3. This may need to set `spark.sql.hive.metastore.version` and `spark.sql.hive.metastore.jars` according to the version of the Hive metastore. + For example: set `spark.sql.hive.metastore.version` to `1.2.1` and `spark.sql.hive.metastore.jars` to `maven` if your Hive metastore version is 1.2.1. + ## Upgrading from Spark SQL 2.4.4 to 2.4.5 - Since Spark 2.4.5, `TRUNCATE TABLE` command tries to set back original permission and ACLs during re-creating the table/partition paths. To restore the behaviour of earlier versions, set `spark.sql.truncateTable.ignorePermissionAcl.enabled` to `true`. From 2f1fb4c01d0d4bfda17b3262e6f586f4f1a25bac Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 2 Feb 2020 00:44:25 -0800 Subject: [PATCH 0002/1280] [SPARK-30704][INFRA] Use jekyll-redirect-from 0.15.0 instead of the latest ### What changes were proposed in this pull request? This PR aims to pin the version of `jekyll-redirect-from` to 0.15.0. 
This is a release blocker for both Apache Spark 3.0.0 and 2.4.5. ### Why are the changes needed? `jekyll-redirect-from` released 0.16.0 a few days ago and that requires Ruby 2.4.0. - https://github.com/jekyll/jekyll-redirect-from/releases/tag/v0.16.0 ``` $ cd dev/create-release/spark-rm/ $ docker build -t spark:test . ... ERROR: Error installing jekyll-redirect-from: jekyll-redirect-from requires Ruby version >= 2.4.0. ... ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manually do the above command to build `spark-rm` Docker image. ``` ... Successfully installed jekyll-redirect-from-0.15.0 Parsing documentation for jekyll-redirect-from-0.15.0 Installing ri documentation for jekyll-redirect-from-0.15.0 Done installing documentation for jekyll-redirect-from after 0 seconds 1 gem installed Successfully installed rouge-3.15.0 Parsing documentation for rouge-3.15.0 Installing ri documentation for rouge-3.15.0 Done installing documentation for rouge after 4 seconds 1 gem installed Removing intermediate container e0ec7c77b69f ---> 32dec37291c6 ``` Closes #27434 from dongjoon-hyun/SPARK-30704. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 1adf3520e3c753e6df8dccb752e8239de682a09a) Signed-off-by: Dongjoon Hyun --- dev/create-release/spark-rm/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 3ba8e97929613..63451687ee8c2 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -78,7 +78,7 @@ RUN apt-get clean && apt-get update && $APT_INSTALL gnupg ca-certificates && \ # Install tools needed to build the documentation. $APT_INSTALL ruby2.3 ruby2.3-dev mkdocs && \ gem install jekyll --no-rdoc --no-ri -v 3.8.6 && \ - gem install jekyll-redirect-from && \ + gem install jekyll-redirect-from -v 0.15.0 && \ gem install rouge WORKDIR /opt/spark-rm/output From 91f78aee718888fad5677445ba21024263d1037a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 3 Feb 2020 14:08:59 +0800 Subject: [PATCH 0003/1280] [SPARK-30697][SQL] Handle database and namespace exceptions in catalog.isView ### What changes were proposed in this pull request? Adds NoSuchDatabaseException and NoSuchNamespaceException to the `isView` method for SessionCatalog. ### Why are the changes needed? This method prevents specialized resolutions from kicking in within Analysis when using V2 Catalogs if the identifier is a specialized identifier. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added test to DataSourceV2SessionCatalogSuite Closes #27423 from brkyvz/isViewF. 
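As a rough illustration of the resulting behavior (a simplified sketch only; the helper name and signature below are made up, and this is not the actual `SessionCatalog.isView` code), any "object not found" failure during the lookup, whether the missing object is the table, the database, or the namespace, is answered with `false` instead of surfacing an exception to the analyzer:

```scala
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}

// Hypothetical helper: treat a missing table, database or namespace uniformly as "not a view".
def isViewOrFalse(lookupIsView: => Boolean): Boolean =
  try lookupIsView catch {
    case _: NoSuchTableException => false
    case _: NoSuchDatabaseException => false
    case _: NoSuchNamespaceException => false
  }
```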
Authored-by: Burak Yavuz Signed-off-by: Wenchen Fan (cherry picked from commit 2eccfd8a73c4afa30a6aa97c2afd38661f29e24b) Signed-off-by: Wenchen Fan --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 ++ ...SourceV2DataFrameSessionCatalogSuite.scala | 22 +++++++++++++++++++ .../DataSourceV2SQLSessionCatalogSuite.scala | 14 ++++++++++++ 3 files changed, 38 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 45f0ef6c97e70..12f9a61fc2b65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -826,6 +826,8 @@ class SessionCatalog( getTempViewOrPermanentTableMetadata(ident).tableType == CatalogTableType.VIEW } catch { case _: NoSuchTableException => false + case _: NoSuchDatabaseException => false + case _: NoSuchNamespaceException => false } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 4c67888cbdc48..01caf8e2eb115 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -101,6 +101,13 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable new InMemoryTable(name, schema, partitions, properties) } + override def loadTable(ident: Identifier): Table = { + val identToUse = Option(InMemoryTableSessionCatalog.customIdentifierResolution) + .map(_(ident)) + .getOrElse(ident) + super.loadTable(identToUse) + } + override def alterTable(ident: Identifier, changes: TableChange*): Table = { val fullIdent = fullIdentifier(ident) Option(tables.get(fullIdent)) match { @@ -125,6 +132,21 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable } } +object InMemoryTableSessionCatalog { + private var customIdentifierResolution: Identifier => Identifier = _ + + def withCustomIdentifierResolver( + resolver: Identifier => Identifier)( + f: => Unit): Unit = { + try { + customIdentifierResolution = resolver + f + } finally { + customIdentifierResolution = null + } + } +} + private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] extends QueryTest with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala index 27725bcadbcd5..b6997445013e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala @@ -49,4 +49,18 @@ class DataSourceV2SQLSessionCatalogSuite v2Catalog.asInstanceOf[TableCatalog] .loadTable(Identifier.of(Array.empty, nameParts.last)) } + + test("SPARK-30697: catalog.isView doesn't throw an error for specialized identifiers") { + val t1 = "tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + + def idResolver(id: Identifier): Identifier = Identifier.of(Array.empty, id.name()) + + 
InMemoryTableSessionCatalog.withCustomIdentifierResolver(idResolver) { + // The following should not throw AnalysisException. + sql(s"DESCRIBE TABLE ignored.$t1") + } + } + } } From f9b86370cb04b72a4f00cbd4d60873960aa2792c Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sun, 2 Feb 2020 23:37:13 -0800 Subject: [PATCH 0004/1280] [SPARK-29543][SS][FOLLOWUP] Move `spark.sql.streaming.ui.*` configs to StaticSQLConf ### What changes were proposed in this pull request? Put the configs below needed by Structured Streaming UI into StaticSQLConf: - spark.sql.streaming.ui.enabled - spark.sql.streaming.ui.retainedProgressUpdates - spark.sql.streaming.ui.retainedQueries ### Why are the changes needed? Make all SS UI configs consistent with other similar configs in usage and naming. ### Does this PR introduce any user-facing change? Yes, add new static config `spark.sql.streaming.ui.retainedProgressUpdates`. ### How was this patch tested? Existing UT. Closes #27425 from xuanyuanking/SPARK-29543-follow. Authored-by: Yuanjian Li Signed-off-by: Shixiong Zhu (cherry picked from commit a4912cee615314e9578e6ab4eae25f147feacbd5) Signed-off-by: Shixiong Zhu --- .../apache/spark/sql/internal/SQLConf.scala | 16 --------------- .../spark/sql/internal/StaticSQLConf.scala | 20 +++++++++++++++++++ .../spark/sql/internal/SharedState.scala | 15 +++++++------- .../ui/StreamingQueryStatusListener.scala | 10 ++++++---- .../sql/streaming/ui/StreamingQueryTab.scala | 2 +- .../StreamingQueryStatusListenerSuite.scala | 4 ++-- 6 files changed, 37 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04572c38be8dd..3ad3416256c7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1150,18 +1150,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STREAMING_UI_ENABLED = - buildConf("spark.sql.streaming.ui.enabled") - .doc("Whether to run the structured streaming UI for the Spark application.") - .booleanConf - .createWithDefault(true) - - val STREAMING_UI_INACTIVE_QUERY_RETENTION = - buildConf("spark.sql.streaming.ui.numInactiveQueries") - .doc("The number of inactive queries to retain for structured streaming ui.") - .intConf - .createWithDefault(100) - val VARIABLE_SUBSTITUTE_ENABLED = buildConf("spark.sql.variable.substitute") .doc("This enables substitution using syntax like ${var} ${system:var} and ${env:var}.") @@ -2284,10 +2272,6 @@ class SQLConf extends Serializable with Logging { def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) - def isStreamingUIEnabled: Boolean = getConf(STREAMING_UI_ENABLED) - - def streamingUIInactiveQueryRetention: Int = getConf(STREAMING_UI_INACTIVE_QUERY_RETENTION) - def streamingFileCommitProtocolClass: String = getConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS) def fileSinkLogDeletion: Boolean = getConf(FILE_SINK_LOG_DELETION) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 66ac9ddb21aaa..6bc752260a893 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -176,4 +176,24 @@ object StaticSQLConf { .internal() .booleanConf 
.createWithDefault(true) + + val STREAMING_UI_ENABLED = + buildStaticConf("spark.sql.streaming.ui.enabled") + .doc("Whether to run the Structured Streaming Web UI for the Spark application when the " + + "Spark Web UI is enabled.") + .booleanConf + .createWithDefault(true) + + val STREAMING_UI_RETAINED_PROGRESS_UPDATES = + buildStaticConf("spark.sql.streaming.ui.retainedProgressUpdates") + .doc("The number of progress updates to retain for a streaming query for Structured " + + "Streaming UI.") + .intConf + .createWithDefault(100) + + val STREAMING_UI_RETAINED_QUERIES = + buildStaticConf("spark.sql.streaming.ui.retainedQueries") + .doc("The number of inactive queries to retain for Structured Streaming UI.") + .intConf + .createWithDefault(100) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index fefd72dcf1752..5347264d7c50a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -145,13 +145,14 @@ private[sql] class SharedState( * data to show. */ lazy val streamingQueryStatusListener: Option[StreamingQueryStatusListener] = { - val sqlConf = SQLConf.get - if (sqlConf.isStreamingUIEnabled) { - val statusListener = new StreamingQueryStatusListener(sqlConf) - sparkContext.ui.foreach(new StreamingQueryTab(statusListener, _)) - Some(statusListener) - } else { - None + sparkContext.ui.flatMap { ui => + if (conf.get(STREAMING_UI_ENABLED)) { + val statusListener = new StreamingQueryStatusListener(conf) + new StreamingQueryTab(statusListener, ui) + Some(statusListener) + } else { + None + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala index db085dbe87ec4..91815110e0d39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala @@ -24,8 +24,9 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.streaming.{StreamingQueryListener, StreamingQueryProgress} /** @@ -33,7 +34,7 @@ import org.apache.spark.sql.streaming.{StreamingQueryListener, StreamingQueryPro * UI data for both active and inactive query. * TODO: Add support for history server. 
*/ -private[sql] class StreamingQueryStatusListener(sqlConf: SQLConf) extends StreamingQueryListener { +private[sql] class StreamingQueryStatusListener(conf: SparkConf) extends StreamingQueryListener { private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) @@ -45,8 +46,9 @@ private[sql] class StreamingQueryStatusListener(sqlConf: SQLConf) extends Stream private[ui] val activeQueryStatus = new ConcurrentHashMap[UUID, StreamingQueryUIData]() private[ui] val inactiveQueryStatus = new mutable.Queue[StreamingQueryUIData]() - private val streamingProgressRetention = sqlConf.streamingProgressRetention - private val inactiveQueryStatusRetention = sqlConf.streamingUIInactiveQueryRetention + private val streamingProgressRetention = + conf.get(StaticSQLConf.STREAMING_UI_RETAINED_PROGRESS_UPDATES) + private val inactiveQueryStatusRetention = conf.get(StaticSQLConf.STREAMING_UI_RETAINED_QUERIES) override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { activeQueryStatus.putIfAbsent(event.runId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryTab.scala index f909cfd97514e..bb097ffc06912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryTab.scala @@ -34,6 +34,6 @@ private[sql] class StreamingQueryTab( parent.addStaticHandler(StreamingQueryTab.STATIC_RESOURCE_DIR, "/static/sql") } -object StreamingQueryTab { +private[sql] object StreamingQueryTab { private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala index bd74ed340b408..adbb501f9842e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.streaming class StreamingQueryStatusListenerSuite extends StreamTest { test("onQueryStarted, onQueryProgress, onQueryTerminated") { - val listener = new StreamingQueryStatusListener(spark.sqlContext.conf) + val listener = new StreamingQueryStatusListener(spark.sparkContext.conf) // hanlde query started event val id = UUID.randomUUID() @@ -73,7 +73,7 @@ class StreamingQueryStatusListenerSuite extends StreamTest { } test("same query start multiple times") { - val listener = new StreamingQueryStatusListener(spark.sqlContext.conf) + val listener = new StreamingQueryStatusListener(spark.sparkContext.conf) // handle first time start val id = UUID.randomUUID() From a496750cf3d6da210ebbdbaf8e66910798242f2a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 3 Feb 2020 19:57:16 -0800 Subject: [PATCH 0005/1280] [SPARK-30718][BUILD] Exclude jdk.tools dependency from hadoop-yarn-api ### What changes were proposed in this pull request? This PR removes the `jdk.tools:jdk.tools` transitive dependency from `hadoop-yarn-api`. - This is only used in `hadoop-annotation` project in some `*Doclet.java`. ### Why are the changes needed? Although this is not used in Apache Spark, this can cause a resolve failure in JDK11 environment. 
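The patch itself applies a Maven `<exclusion>` in `pom.xml` (see the diff below). For downstream builds that resolve `hadoop-yarn-api` through sbt rather than Maven, a roughly equivalent exclusion would look like the following sketch; the version string is a placeholder and is not taken from this patch:

```scala
// build.sbt sketch; the version below is illustrative only
libraryDependencies += "org.apache.hadoop" % "hadoop-yarn-api" % "3.2.0" exclude("jdk.tools", "jdk.tools")
```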
jdk tools ### Does this PR introduce any user-facing change? No. This is a dev-only change. From developers, this will remove the `Cannot resolve` error in IDE environment. ### How was this patch tested? - Pass the Jenkins in JDK8 - Manually, import the project with JDK11. Closes #27445 from dongjoon-hyun/SPARK-30718. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 41bdb7ad3949d05542abe5ab2b440a51c3a18bce) Signed-off-by: Dongjoon Hyun --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index 7c23444054efc..a8d6ac932bac2 100644 --- a/pom.xml +++ b/pom.xml @@ -1200,6 +1200,10 @@ com.sun.jersey.contribs * + + jdk.tools + jdk.tools + From b4e6a7c74737e96895d723e6d0cf499f7012c9ea Mon Sep 17 00:00:00 2001 From: maryannxue Date: Tue, 4 Feb 2020 12:31:44 +0800 Subject: [PATCH 0006/1280] [SPARK-30717][SQL] AQE subquery map should cache `SubqueryExec` instead of `ExecSubqueryExpression` ### What changes were proposed in this pull request? This PR is to fix a potential bug in AQE where an `ExecSubqueryExpression` could be mistakenly replaced with another `ExecSubqueryExpression` with the same `ListQuery` but a different `child` expression. This is because a ListQuery's id can only identify the ListQuery itself, not the parent expression `InSubquery`, but right now the `subqueryMap` in `InsertAdaptiveSparkPlan` uses the `ListQuery`'s id as key and the corresponding `InSubqueryExec` for the `ListQuery`'s parent expression as value. So the fix uses the corresponding `SubqueryExec` for the `ListQuery` itself as the map's value. ### Why are the changes needed? This logical bug could potentially cause a wrong query plan, which could throw an exception related to unresolved columns. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Passed existing UTs. Closes #27446 from maryannxue/spark-30717. 
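The pitfall is easier to see with a toy model. The classes and maps below are purely illustrative stand-ins, not Spark's `InSubqueryExec`/`SubqueryExec` code: when the cached value also freezes the first caller's child expression, a second IN-subquery that shares the same `ListQuery` id silently inherits the wrong child, whereas caching only the shared subquery plan keeps each parent expression intact:

```scala
import scala.collection.mutable

object SubqueryCacheSketch {
  // Toy stand-in for an IN-subquery expression; only the caching shape matters here.
  case class InSubqueryLike(child: String, subqueryPlan: String)

  // Buggy shape: the cached value also captures the first caller's `child` expression.
  private val buggyCache = mutable.HashMap.empty[Long, InSubqueryLike]
  def planBuggy(child: String, subqueryId: Long): InSubqueryLike =
    buggyCache.getOrElseUpdate(subqueryId, InSubqueryLike(child, s"SubqueryExec#$subqueryId"))

  // Fixed shape: only the shared subquery plan is cached; the parent is rebuilt per call site.
  private val fixedCache = mutable.HashMap.empty[Long, String]
  def planFixed(child: String, subqueryId: Long): InSubqueryLike =
    InSubqueryLike(child, fixedCache.getOrElseUpdate(subqueryId, s"SubqueryExec#$subqueryId"))

  def main(args: Array[String]): Unit = {
    planBuggy("col_a", 1L)
    assert(planBuggy("col_b", 1L).child == "col_a") // second caller silently reuses the first child
    planFixed("col_a", 1L)
    assert(planFixed("col_b", 1L).child == "col_b") // child preserved while the plan is still reused
  }
}
```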
Authored-by: maryannxue Signed-off-by: Wenchen Fan (cherry picked from commit 6097b343baa8e4a8bc7159dc3d394f13b3c9959b) Signed-off-by: Wenchen Fan --- .../adaptive/InsertAdaptiveSparkPlan.scala | 30 ++++++------------- .../adaptive/PlanAdaptiveSubqueries.scala | 23 +++++++++----- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 04696209ce10e..9252827856af4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.adaptive import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningSubquery, ListQuery, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{DynamicPruningSubquery, ListQuery, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.execution.exchange.Exchange @@ -102,36 +101,25 @@ case class InsertAdaptiveSparkPlan( * For each sub-query, generate the adaptive execution plan for each sub-query by applying this * rule, or reuse the execution plan from another sub-query of the same semantics if possible. 
*/ - private def buildSubqueryMap(plan: SparkPlan): mutable.HashMap[Long, ExecSubqueryExpression] = { - val subqueryMap = mutable.HashMap.empty[Long, ExecSubqueryExpression] + private def buildSubqueryMap(plan: SparkPlan): Map[Long, SubqueryExec] = { + val subqueryMap = mutable.HashMap.empty[Long, SubqueryExec] plan.foreach(_.expressions.foreach(_.foreach { case expressions.ScalarSubquery(p, _, exprId) if !subqueryMap.contains(exprId.id) => val executedPlan = compileSubquery(p) verifyAdaptivePlan(executedPlan, p) - val scalarSubquery = execution.ScalarSubquery( - SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) - subqueryMap.put(exprId.id, scalarSubquery) - case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) + val subquery = SubqueryExec(s"subquery${exprId.id}", executedPlan) + subqueryMap.put(exprId.id, subquery) + case expressions.InSubquery(_, ListQuery(query, _, exprId, _)) if !subqueryMap.contains(exprId.id) => val executedPlan = compileSubquery(query) verifyAdaptivePlan(executedPlan, query) - val expr = if (values.length == 1) { - values.head - } else { - CreateNamedStruct( - values.zipWithIndex.flatMap { case (v, index) => - Seq(Literal(s"col_$index"), v) - } - ) - } - val inSubquery = InSubqueryExec(expr, - SubqueryExec(s"subquery#${exprId.id}", executedPlan), exprId) - subqueryMap.put(exprId.id, inSubquery) + val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan) + subqueryMap.put(exprId.id, subquery) case _ => })) - subqueryMap + subqueryMap.toMap } def compileSubquery(plan: LogicalPlan): SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index 91d4359224a6a..f845b6b16ee3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -18,19 +18,28 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.ListQuery +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, ListQuery, Literal} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ExecSubqueryExpression, SparkPlan} +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryExec} -case class PlanAdaptiveSubqueries( - subqueryMap: scala.collection.Map[Long, ExecSubqueryExpression]) extends Rule[SparkPlan] { +case class PlanAdaptiveSubqueries(subqueryMap: Map[Long, SubqueryExec]) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case expressions.ScalarSubquery(_, _, exprId) => - subqueryMap(exprId.id) - case expressions.InSubquery(_, ListQuery(_, _, exprId, _)) => - subqueryMap(exprId.id) + execution.ScalarSubquery(subqueryMap(exprId.id), exprId) + case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) => + val expr = if (values.length == 1) { + values.head + } else { + CreateNamedStruct( + values.zipWithIndex.flatMap { case (v, index) => + Seq(Literal(s"col_$index"), v) + } + ) + } + InSubqueryExec(expr, subqueryMap(exprId.id), exprId) } } } From c8ffaa9c2d86b1b8e807ad6f0570efcd0d7cd3d9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 3 Feb 2020 11:09:25 +0900 Subject: [PATCH 0007/1280] [MINOR][SPARKR][DOCS] Remove duplicate 
@name tags from read.df and read.stream ### What changes were proposed in this pull request? Remove duplicate `name` tags from `read.df` and `read.stream`. ### Why are the changes needed? These tags are already present in https://github.com/apache/spark/blob/1adf3520e3c753e6df8dccb752e8239de682a09a/R/pkg/R/SQLContext.R#L546 and https://github.com/apache/spark/blob/1adf3520e3c753e6df8dccb752e8239de682a09a/R/pkg/R/SQLContext.R#L678 for `read.df` and `read.stream` respectively. As only one `name` tag per block is allowed, this causes build warnings with recent `roxygen2` versions: ``` Warning: [/path/to/spark/R/pkg/R/SQLContext.R:559] name May only use one name per block Warning: [/path/to/spark/R/pkg/R/SQLContext.R:690] name May only use one name per block ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests. Closes #27437 from zero323/roxygen-warnings-names. Authored-by: zero323 Signed-off-by: HyukjinKwon --- R/pkg/R/SQLContext.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f48a334ed6766..c6842912706af 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -556,7 +556,6 @@ tableToDF <- function(tableName) { #' stringSchema <- "name STRING, info MAP" #' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } -#' @name read.df #' @note read.df since 1.4.0 read.df <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { @@ -687,7 +686,6 @@ read.jdbc <- function(url, tableName, #' stringSchema <- "name STRING, info MAP" #' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1) #' } -#' @name read.stream #' @note read.stream since 2.2.0 #' @note experimental read.stream <- function(source = NULL, schema = NULL, ...) { From 5fdea014873535457d6865594eb7ea2cc38fdc25 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 4 Feb 2020 16:33:34 +0900 Subject: [PATCH 0008/1280] [SPARK-26618][SQL][FOLLOWUP] Describe the behavior change of typed `TIMESTAMP`/`DATE` literals ### What changes were proposed in this pull request? In the PR, I propose to update the SQL migration guide, and clarify behavior change of typed `TIMESTAMP` and `DATE` literals for input strings without time zone information - local timestamp and date strings. ### Why are the changes needed? To inform users that the typed literals may change their behavior in Spark 3.0 because of different sources of the default time zone - JVM system time zone in Spark 2.4 and earlier, and `spark.sql.session.timeZone` in Spark 3.0. ### Does this PR introduce any user-facing change? No ### How was this patch tested? N/A Closes #27435 from MaxGekk/timestamp-lit-migration-guide. Authored-by: Maxim Gekk Signed-off-by: HyukjinKwon (cherry picked from commit 0202b675afca65c6615a06805a4d4d12f3f97bdb) Signed-off-by: HyukjinKwon --- docs/sql-migration-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index e4d2358b5de09..a5ef1c2b1d045 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -77,7 +77,7 @@ license: | - Formatting of `TIMESTAMP` and `DATE` literals. - - Creating of typed `TIMESTAMP` and `DATE` literals from strings. Since Spark 3.0, string conversion to typed `TIMESTAMP`/`DATE` literals is performed via casting to `TIMESTAMP`/`DATE` values. 
For example, `TIMESTAMP '2019-12-23 12:59:30'` is semantically equal to `CAST('2019-12-23 12:59:30' AS TIMESTAMP)`. In Spark version 2.4 and earlier, the `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()` functions are used for the same purpose. + - Creating of typed `TIMESTAMP` and `DATE` literals from strings. Since Spark 3.0, string conversion to typed `TIMESTAMP`/`DATE` literals is performed via casting to `TIMESTAMP`/`DATE` values. For example, `TIMESTAMP '2019-12-23 12:59:30'` is semantically equal to `CAST('2019-12-23 12:59:30' AS TIMESTAMP)`. When the input string does not contain information about time zone, the time zone from the SQL config `spark.sql.session.timeZone` is used in that case. In Spark version 2.4 and earlier, the conversion is based on JVM system time zone. The different sources of the default time zone may change the behavior of typed `TIMESTAMP` and `DATE` literals. - In Spark version 2.4 and earlier, invalid time zone ids are silently ignored and replaced by GMT time zone, for example, in the from_utc_timestamp function. Since Spark 3.0, such time zone ids are rejected, and Spark throws `java.time.DateTimeException`. From 0d1984269fc8f81de384529acca2b3f8584d2f1a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 4 Feb 2020 21:17:05 +0800 Subject: [PATCH 0009/1280] [SPARK-30725][SQL] Make legacy SQL configs as internal configs ### What changes were proposed in this pull request? All legacy SQL configs are marked as internal configs. In particular, the following configs are updated as internals: - spark.sql.legacy.sizeOfNull - spark.sql.legacy.replaceDatabricksSparkAvro.enabled - spark.sql.legacy.typeCoercion.datetimeToString.enabled - spark.sql.legacy.looseUpcast - spark.sql.legacy.arrayExistsFollowsThreeValuedLogic ### Why are the changes needed? In general case, users shouldn't change legacy configs, so, they can be marked as internals. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Should be tested by jenkins build and run tests. Closes #27448 from MaxGekk/legacy-internal-sql-conf. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit f2dd082544aeba5978d0c140d0194eedb969d132) Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3ad3416256c7d..b94ddbdc0fc9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1916,6 +1916,7 @@ object SQLConf { .createWithDefault(Deflater.DEFAULT_COMPRESSION) val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .internal() .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. 
" + "The size function returns null for null input if the flag is disabled.") .booleanConf @@ -1923,6 +1924,7 @@ object SQLConf { val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED = buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled") + .internal() .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " + "to the built-in but external Avro data source module for backward compatibility.") .booleanConf @@ -2048,10 +2050,11 @@ object SQLConf { val LEGACY_CAST_DATETIME_TO_STRING = buildConf("spark.sql.legacy.typeCoercion.datetimeToString.enabled") + .internal() .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + "with String") - .booleanConf - .createWithDefault(false) + .booleanConf + .createWithDefault(false) val DEFAULT_CATALOG = buildConf("spark.sql.defaultCatalog") .doc("Name of the default catalog. This will be the current catalog if users have not " + @@ -2071,6 +2074,7 @@ object SQLConf { .createOptional val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast") + .internal() .doc("When true, the upcast will be loose and allows string to atomic types.") .booleanConf .createWithDefault(false) @@ -2083,6 +2087,7 @@ object SQLConf { val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = buildConf("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic") + .internal() .doc("When true, the ArrayExists will follow the three-valued boolean logic.") .booleanConf .createWithDefault(true) From d3de7568f32e298442f07b0a28b2c906de72c797 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 4 Feb 2020 17:22:23 -0800 Subject: [PATCH 0010/1280] [SPARK-25040][SQL][FOLLOWUP] Add legacy config for allowing empty strings for certain types in json parser ### What changes were proposed in this pull request? This is a follow-up for #22787. In #22787 we disallowed empty strings for json parser except for string and binary types. This follow-up adds a legacy config for restoring previous behavior of allowing empty string. ### Why are the changes needed? Adding a legacy config to make migration easy for Spark users. ### Does this PR introduce any user-facing change? Yes. If set this legacy config to true, the users can restore previous behavior prior to Spark 3.0.0. ### How was this patch tested? Unit test. Closes #27456 from viirya/SPARK-25040-followup. Lead-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun (cherry picked from commit 7631275f974d2ecf68cd8394ed683e30be320e56) Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide.md | 2 +- .../sql/catalyst/json/JacksonParser.scala | 14 ++++- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../datasources/json/JsonSuite.scala | 61 ++++++++++++++----- 4 files changed, 68 insertions(+), 17 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index a5ef1c2b1d045..0c47370283736 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -37,7 +37,7 @@ license: | - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`. - - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. 
+ - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType`, `DoubleType`, `DateType` and `TimestampType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. The previous behaviour of allowing empty string can be restored by setting `spark.sql.legacy.json.allowEmptyString.enabled` to `true`. - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index c44025ca8bcfd..76efa574a99ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils @@ -307,6 +308,8 @@ class JacksonParser( } } + private val allowEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_EMPTY_STRING_IN_JSON) + /** * This function throws an exception for failed conversion. For empty string on data types * except for string and binary types, this also throws an exception. @@ -315,7 +318,16 @@ class JacksonParser( parser: JsonParser, dataType: DataType): PartialFunction[JsonToken, R] = { - // SPARK-25040: Disallow empty strings for data types except for string and binary types. + // SPARK-25040: Disallows empty strings for data types except for string and binary types. + // But treats empty strings as null for certain types if the legacy config is enabled. 
+ case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString => + dataType match { + case FloatType | DoubleType | TimestampType | DateType => + throw new RuntimeException( + s"Failed to parse an empty string for data type ${dataType.catalogString}") + case _ => null + } + case VALUE_STRING if parser.getTextLength < 1 => throw new RuntimeException( s"Failed to parse an empty string for data type ${dataType.catalogString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b94ddbdc0fc9a..5ce5692123805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1989,6 +1989,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_ALLOW_EMPTY_STRING_IN_JSON = + buildConf("spark.sql.legacy.json.allowEmptyString.enabled") + .internal() + .doc("When set to true, the parser of JSON data source treats empty strings as null for " + + "some data types such as `IntegerType`.") + .booleanConf + .createWithDefault(false) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index d0e2e001034fb..b20da2266b0f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2436,23 +2436,24 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson } } - test("SPARK-25040: empty strings should be disallowed") { - def failedOnEmptyString(dataType: DataType): Unit = { - val df = spark.read.schema(s"a ${dataType.catalogString}") - .option("mode", "FAILFAST").json(Seq("""{"a":""}""").toDS) - val errMessage = intercept[SparkException] { - df.collect() - }.getMessage - assert(errMessage.contains( - s"Failed to parse an empty string for data type ${dataType.catalogString}")) - } - def emptyString(dataType: DataType, expected: Any): Unit = { - val df = spark.read.schema(s"a ${dataType.catalogString}") - .option("mode", "FAILFAST").json(Seq("""{"a":""}""").toDS) - checkAnswer(df, Row(expected) :: Nil) - } + private def failedOnEmptyString(dataType: DataType): Unit = { + val df = spark.read.schema(s"a ${dataType.catalogString}") + .option("mode", "FAILFAST").json(Seq("""{"a":""}""").toDS) + val errMessage = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMessage.contains( + s"Failed to parse an empty string for data type ${dataType.catalogString}")) + } + private def emptyString(dataType: DataType, expected: Any): Unit = { + val df = spark.read.schema(s"a ${dataType.catalogString}") + .option("mode", "FAILFAST").json(Seq("""{"a":""}""").toDS) + checkAnswer(df, Row(expected) :: Nil) + } + + test("SPARK-25040: empty strings should be disallowed") { failedOnEmptyString(BooleanType) failedOnEmptyString(ByteType) failedOnEmptyString(ShortType) @@ -2471,6 +2472,36 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson emptyString(BinaryType, "".getBytes(StandardCharsets.UTF_8)) } + test("SPARK-25040: allowing empty strings when legacy config is enabled") { + def emptyStringAsNull(dataType: DataType): Unit = { + val df 
= spark.read.schema(s"a ${dataType.catalogString}") + .option("mode", "FAILFAST").json(Seq("""{"a":""}""").toDS) + checkAnswer(df, Row(null) :: Nil) + } + + // Legacy mode prior to Spark 3.0.0 + withSQLConf(SQLConf.LEGACY_ALLOW_EMPTY_STRING_IN_JSON.key -> "true") { + emptyStringAsNull(BooleanType) + emptyStringAsNull(ByteType) + emptyStringAsNull(ShortType) + emptyStringAsNull(IntegerType) + emptyStringAsNull(LongType) + + failedOnEmptyString(FloatType) + failedOnEmptyString(DoubleType) + failedOnEmptyString(TimestampType) + failedOnEmptyString(DateType) + + emptyStringAsNull(DecimalType.SYSTEM_DEFAULT) + emptyStringAsNull(ArrayType(IntegerType)) + emptyStringAsNull(MapType(StringType, IntegerType, true)) + emptyStringAsNull(StructType(StructField("f1", IntegerType, true) :: Nil)) + + emptyString(StringType, "") + emptyString(BinaryType, "".getBytes(StandardCharsets.UTF_8)) + } + } + test("return partial result for bad records") { val schema = "a double, b array, c string, _corrupt_record string" val badRecords = Seq( From d1991c8c1ee7948ab672d702aa89f4ce99d2f485 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 4 Feb 2020 17:26:46 -0800 Subject: [PATCH 0011/1280] Revert "[SPARK-28310][SQL] Support (FIRST_VALUE|LAST_VALUE)(expr[ (IGNORE|RESPECT) NULLS]?) syntax" ### What changes were proposed in this pull request? This reverts commit b89c3de1a439ed7302dd8f44c49b89bb7da2eebe. ### Why are the changes needed? `FIRST_VALUE` is used only for window expression. Please see the discussion on https://github.com/apache/spark/pull/25082 . ### Does this PR introduce any user-facing change? Yes. ### How was this patch tested? Pass the Jenkins. Closes #27458 from dongjoon-hyun/SPARK-28310. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 898716980dce44a4cc09411e72d64c848698cad5) Signed-off-by: Dongjoon Hyun --- docs/sql-keywords.md | 3 --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 11 ++--------- .../sql/catalyst/parser/ExpressionParserSuite.scala | 9 --------- .../catalyst/parser/TableIdentifierParserSuite.scala | 5 ----- 4 files changed, 2 insertions(+), 26 deletions(-) diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md index b18855366bb2b..9e4a3c54100c6 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-keywords.md @@ -119,7 +119,6 @@ Below is a list of all the keywords in Spark SQL. FILTERreservednon-reservedreserved FILEFORMATnon-reservednon-reservednon-reserved FIRSTnon-reservednon-reservednon-reserved - FIRST_VALUEreservednon-reservedreserved FOLLOWINGnon-reservednon-reservednon-reserved FORreservednon-reservedreserved FOREIGNreservednon-reservedreserved @@ -153,7 +152,6 @@ Below is a list of all the keywords in Spark SQL. JOINreservedstrict-non-reservedreserved KEYSnon-reservednon-reservednon-reserved LASTnon-reservednon-reservednon-reserved - LAST_VALUEreservednon-reservedreserved LATERALnon-reservednon-reservedreserved LAZYnon-reservednon-reservednon-reserved LEADINGreservednon-reservedreserved @@ -221,7 +219,6 @@ Below is a list of all the keywords in Spark SQL. 
REPAIRnon-reservednon-reservednon-reserved REPLACEnon-reservednon-reservednon-reserved RESETnon-reservednon-reservednon-reserved - RESPECTnon-reservednon-reservednon-reserved RESTRICTnon-reservednon-reservednon-reserved REVOKEnon-reservednon-reservedreserved RIGHTreservedstrict-non-reservedreserved diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6f2bb7a9a7536..08d5ff53bf2e2 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -771,8 +771,8 @@ primaryExpression | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct - | (FIRST | FIRST_VALUE) '(' expression ((IGNORE | RESPECT) NULLS)? ')' #first - | (LAST | LAST_VALUE) '(' expression ((IGNORE | RESPECT) NULLS)? ')' #last + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position | constant #constantDefault | ASTERISK #star @@ -1120,7 +1120,6 @@ ansiNonReserved | REPAIR | REPLACE | RESET - | RESPECT | RESTRICT | REVOKE | RLIKE @@ -1280,7 +1279,6 @@ nonReserved | FIELDS | FILEFORMAT | FIRST - | FIRST_VALUE | FOLLOWING | FOR | FOREIGN @@ -1310,7 +1308,6 @@ nonReserved | ITEMS | KEYS | LAST - | LAST_VALUE | LATERAL | LAZY | LEADING @@ -1374,7 +1371,6 @@ nonReserved | REPAIR | REPLACE | RESET - | RESPECT | RESTRICT | REVOKE | RLIKE @@ -1531,7 +1527,6 @@ FIELDS: 'FIELDS'; FILTER: 'FILTER'; FILEFORMAT: 'FILEFORMAT'; FIRST: 'FIRST'; -FIRST_VALUE: 'FIRST_VALUE'; FOLLOWING: 'FOLLOWING'; FOR: 'FOR'; FOREIGN: 'FOREIGN'; @@ -1565,7 +1560,6 @@ ITEMS: 'ITEMS'; JOIN: 'JOIN'; KEYS: 'KEYS'; LAST: 'LAST'; -LAST_VALUE: 'LAST_VALUE'; LATERAL: 'LATERAL'; LAZY: 'LAZY'; LEADING: 'LEADING'; @@ -1632,7 +1626,6 @@ RENAME: 'RENAME'; REPAIR: 'REPAIR'; REPLACE: 'REPLACE'; RESET: 'RESET'; -RESPECT: 'RESPECT'; RESTRICT: 'RESTRICT'; REVOKE: 'REVOKE'; RIGHT: 'RIGHT'; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 85efc6accf01f..df012ccf09620 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -771,15 +771,6 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) } - test("Support respect nulls keywords for first_value and last_value") { - assertEqual("first_value(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) - assertEqual("first_value(a respect nulls)", First('a, Literal(false)).toAggregateExpression()) - assertEqual("first_value(a)", First('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) - assertEqual("last_value(a respect nulls)", Last('a, Literal(false)).toAggregateExpression()) - assertEqual("last_value(a)", Last('a, Literal(false)).toAggregateExpression()) - } - test("timestamp literals") { DateTimeTestUtils.outstandingTimezones.foreach { timeZone => 
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone.getID) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 43244b3c0a57d..053d57846ce8d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -369,7 +369,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "fields", "fileformat", "first", - "first_value", "following", "for", "foreign", @@ -403,7 +402,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "join", "keys", "last", - "last_value", "lateral", "lazy", "leading", @@ -467,7 +465,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "repair", "replace", "reset", - "respect", "restrict", "revoke", "right", @@ -562,7 +559,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "except", "false", "fetch", - "first_value", "for", "foreign", "from", @@ -577,7 +573,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "into", "join", "is", - "last_value", "leading", "left", "minute", From d8bc72e89e8221cc7ec17470ae991878d9daecd7 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 5 Feb 2020 16:45:54 +0900 Subject: [PATCH 0012/1280] [SPARK-30733][R][HOTFIX] Fix SparkR tests per testthat and R version upgrade, and disable CRAN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? There are currently the R test failures after upgrading `testthat` to 2.0.0, and R version 3.5.2 as of SPARK-23435. This PR targets to fix the tests and make the tests pass. See the explanations and causes below: ``` test_context.R:49: failure: Check masked functions length(maskedCompletely) not equal to length(namesOfMaskedCompletely). 1/1 mismatches [1] 6 - 4 == 2 test_context.R:53: failure: Check masked functions sort(maskedCompletely, na.last = TRUE) not equal to sort(namesOfMaskedCompletely, na.last = TRUE). 5/6 mismatches x[2]: "endsWith" y[2]: "filter" x[3]: "filter" y[3]: "not" x[4]: "not" y[4]: "sample" x[5]: "sample" y[5]: NA x[6]: "startsWith" y[6]: NA ``` From my cursory look, R base and R's version are mismatched. I fixed accordingly and Jenkins will test it out. ``` test_includePackage.R:31: error: include inside function package or namespace load failed for ���plyr���: package ���plyr��� was installed by an R version with different internals; it needs to be reinstalled for use with this R version Seems it's a package installation issue. Looks like plyr has to be re-installed. ``` From my cursory look, previously installed `plyr` remains and it's not compatible with the new R version. I fixed accordingly and Jenkins will test it out. ``` test_sparkSQL.R:499: warning: SPARK-17811: can create DataFrame containing NA as date and time Your system is mis-configured: ���/etc/localtime��� is not a symlink ``` Seems a env problem. I suppressed the warnings for now. ``` test_sparkSQL.R:499: warning: SPARK-17811: can create DataFrame containing NA as date and time It is strongly recommended to set envionment variable TZ to ���America/Los_Angeles��� (or equivalent) ``` Seems a env problem. I suppressed the warnings for now. 
``` test_sparkSQL.R:1814: error: string operators unable to find an inherited method for function ���startsWith��� for signature ���"character"��� 1: expect_true(startsWith("Hello World", "Hello")) at /home/jenkins/workspace/SparkPullRequestBuilder2/R/pkg/tests/fulltests/test_sparkSQL.R:1814 2: quasi_label(enquo(object), label) 3: eval_bare(get_expr(quo), get_env(quo)) 4: startsWith("Hello World", "Hello") 5: (function (classes, fdef, mtable) { methods <- .findInheritedMethods(classes, fdef, mtable) if (length(methods) == 1L) return(methods[[1L]]) else if (length(methods) == 0L) { cnames <- paste0("\"", vapply(classes, as.character, ""), "\"", collapse = ", ") stop(gettextf("unable to find an inherited method for function %s for signature %s", sQuote(fdefgeneric), sQuote(cnames)), domain = NA) } else stop("Internal error in finding inherited methods; didn't return a unique method", domain = NA) })(list("character"), new("nonstandardGenericFunction", .Data = function (x, prefix) { standardGeneric("startsWith") }, generic = structure("startsWith", package = "SparkR"), package = "SparkR", group = list(), valueClass = character(0), signature = c("x", "prefix"), default = NULL, skeleton = (function (x, prefix) stop("invalid call in method dispatch to 'startsWith' (no default method)", domain = NA))(x, prefix)), ) 6: stop(gettextf("unable to find an inherited method for function %s for signature %s", sQuote(fdefgeneric), sQuote(cnames)), domain = NA) ``` From my cursory look, R base and R's version are mismatched. I fixed accordingly and Jenkins will test it out. Also, this PR causes a CRAN check failure as below: ``` * creating vignettes ... ERROR Error: processing vignette 'sparkr-vignettes.Rmd' failed with diagnostics: package ���htmltools��� was installed by an R version with different internals; it needs to be reinstalled for use with this R version ``` This PR disables it for now. ### Why are the changes needed? To unblock other PRs. ### Does this PR introduce any user-facing change? No. Test only and dev only. ### How was this patch tested? No. I am going to use Jenkins to test. Closes #27460 from HyukjinKwon/r-test-failure. 
Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon (cherry picked from commit e2d984aa1c79eb389cc8d333f656196b17af1c32) Signed-off-by: HyukjinKwon --- R/pkg/tests/fulltests/test_context.R | 3 ++- R/pkg/tests/fulltests/test_includePackage.R | 8 ++++---- R/pkg/tests/fulltests/test_sparkSQL.R | 10 ++++++++-- R/run-tests.sh | 7 ++++--- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index b9139154bc165..6be04b321e985 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -25,7 +25,8 @@ test_that("Check masked functions", { namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", "summary", "transform", "drop", "window", "as.data.frame", "union", "not") - if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { + version <- packageVersion("base") + if (as.numeric(version$major) >= 3 && as.numeric(version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } masked <- conflicts(detail = TRUE)$`package:SparkR` diff --git a/R/pkg/tests/fulltests/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R index 916361ff4c797..1d16b260c4c52 100644 --- a/R/pkg/tests/fulltests/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -27,8 +27,8 @@ rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { # Only run the test if plyr is installed. - if ("plyr" %in% rownames(installed.packages())) { - suppressPackageStartupMessages(library(plyr)) + if ("plyr" %in% rownames(installed.packages()) && + suppressPackageStartupMessages(suppressWarnings(library(plyr, logical.return = TRUE)))) { generateData <- function(x) { suppressPackageStartupMessages(library(plyr)) attach(airquality) @@ -44,8 +44,8 @@ test_that("include inside function", { test_that("use include package", { # Only run the test if plyr is installed. - if ("plyr" %in% rownames(installed.packages())) { - suppressPackageStartupMessages(library(plyr)) + if ("plyr" %in% rownames(installed.packages()) && + suppressPackageStartupMessages(suppressWarnings(library(plyr, logical.return = TRUE)))) { generateData <- function(x) { attach(airquality) result <- transform(Ozone, logOzone = log(Ozone)) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3b3768f7e2715..23fadc4373c3f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -496,7 +496,12 @@ test_that("SPARK-17811: can create DataFrame containing NA as date and time", { expect_true(is.na(DF$date[2])) expect_equal(DF$date[1], as.Date("2016-10-01")) expect_true(is.na(DF$time[2])) - expect_equal(DF$time[1], as.POSIXlt("2016-01-10")) + # Warnings were suppressed due to Jenkins environment issues. 
+ # It shows both warnings as below in Jenkins: + # - Your system is mis-configured: /etc/localtime is not a symlink + # - It is strongly recommended to set environment variable TZ to + # America/Los_Angeles (or equivalent) + suppressWarnings(expect_equal(DF$time[1], as.POSIXlt("2016-01-10"))) }) test_that("create DataFrame with complex types", { @@ -1810,7 +1815,8 @@ test_that("string operators", { expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") - if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { + version <- packageVersion("base") + if (as.numeric(version$major) >= 3 && as.numeric(version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) } diff --git a/R/run-tests.sh b/R/run-tests.sh index 51ca7d600caf0..782b5f5baebaf 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -31,9 +31,10 @@ NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" # Also run the documentation tests for CRAN CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out rm -f $CRAN_CHECK_LOG_FILE - -NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE -FAILED=$((PIPESTATUS[0]||$FAILED)) +# TODO(SPARK-30737) reenable this once packages are correctly installed in Jenkins +# NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE +# FAILED=$((PIPESTATUS[0]||$FAILED)) +touch $CRAN_CHECK_LOG_FILE NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" From 3854ad87c78f2a331f9c9c1a34f9ec281900f8fe Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 5 Feb 2020 16:15:44 +0800 Subject: [PATCH 0013/1280] [SPARK-30594][CORE] Do not post SparkListenerBlockUpdated when updateBlockInfo returns false ### What changes were proposed in this pull request? If `updateBlockInfo` returns false, which means the `BlockManager` will re-register and report all blocks later. So, we may report two times for the same block, which causes `AppStatusListener` to count used memory for two times, too. As a result, the used memory can exceed the total memory. So, this PR changes it to not post `SparkListenerBlockUpdated` when `updateBlockInfo` returns false. And, always clean up used memory whenever `AppStatusListener` receives `SparkListenerBlockManagerAdded`. ### Why are the changes needed? This PR tries to fix negative memory usage in UI (https://user-images.githubusercontent.com/3488126/72131225-95e37e00-33b6-11ea-8708-6e5ed328d1ca.png, see #27144 ). Though, I'm not very sure this is the root cause for #27144 since known information is limited here. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added new tests by xuanyuanking Closes #27306 from Ngone51/fix-possible-negative-memory. 
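To make the accounting issue concrete, here is a minimal standalone Scala sketch (illustrative only: the block size and memory budget are made-up numbers, and no Spark classes are involved) of how reporting the same block twice pushes the tracked used memory past the executor's total:

```scala
object DoubleCountSketch {
  def main(args: Array[String]): Unit = {
    val totalOnHeap = 16L  // hypothetical executor on-heap storage budget, in bytes
    var usedOnHeap = 0L    // what the listener tracks for the UI
    val blockSize = 10L    // hypothetical size of a single cached block

    // First report of the block via SparkListenerBlockUpdated.
    usedOnHeap += blockSize
    // The BlockManager re-registers and reports the very same block again,
    // so the listener adds the same delta a second time.
    usedOnHeap += blockSize

    // Used memory now exceeds the total even though only one 10-byte block exists.
    println(s"used=$usedOnHeap total=$totalOnHeap over=${usedOnHeap > totalOnHeap}")
  }
}
```

The change below avoids this by posting `SparkListenerBlockUpdated` only when `updateBlockInfo` succeeds, and by resetting the used-memory counters whenever a `SparkListenerBlockManagerAdded` event arrives.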
Lead-authored-by: yi.wu Co-authored-by: Yuanjian Li Co-authored-by: wuyi Signed-off-by: Wenchen Fan (cherry picked from commit 30e418a6fe971b4a84c37ca0ae20f1a664b117d3) Signed-off-by: Wenchen Fan --- .../spark/status/AppStatusListener.scala | 9 +++++-- .../org/apache/spark/status/LiveEntity.scala | 2 +- .../storage/BlockManagerMasterEndpoint.scala | 9 +++++-- .../spark/status/AppStatusListenerSuite.scala | 24 +++++++++++++++++++ .../spark/storage/BlockManagerSuite.scala | 17 +++++++++++-- 5 files changed, 54 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index a5850fc2ac4b9..c3f22f32993a8 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -71,7 +71,7 @@ private[spark] class AppStatusListener( // causing too many writes to the underlying store, and other expensive operations). private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() - private val liveExecutors = new HashMap[String, LiveExecutor]() + private[spark] val liveExecutors = new HashMap[String, LiveExecutor]() private val deadExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() @@ -772,6 +772,11 @@ private[spark] class AppStatusListener( event.maxOnHeapMem.foreach { _ => exec.totalOnHeap = event.maxOnHeapMem.get exec.totalOffHeap = event.maxOffHeapMem.get + // SPARK-30594: whenever(first time or re-register) a BlockManager added, all blocks + // from this BlockManager will be reported to driver later. So, we should clean up + // used memory to avoid overlapped count. 
+ exec.usedOnHeap = 0 + exec.usedOffHeap = 0 } exec.isActive = true exec.maxMemory = event.maxMem @@ -1042,7 +1047,7 @@ private[spark] class AppStatusListener( } } - private def updateExecutorMemoryDiskInfo( + private[spark] def updateExecutorMemoryDiskInfo( exec: LiveExecutor, storageLevel: StorageLevel, memoryDelta: Long, diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index e3046dce34e67..2714f30de14f0 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -245,7 +245,7 @@ private class LiveTask( } -private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveEntity { +private[spark] class LiveExecutor(val executorId: String, _addTime: Long) extends LiveEntity { var hostPort: String = null var host: String = null diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 41ef1909cd4c2..d7f7eedc7f33b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -98,8 +98,13 @@ class BlockManagerMasterEndpoint( case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => - context.reply(updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)) - listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) + val isSuccess = updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) + context.reply(isSuccess) + // SPARK-30594: we should not post `SparkListenerBlockUpdated` when updateBlockInfo + // returns false since the block info would be updated again later. 
+ if (isSuccess) { + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) + } case GetLocations(blockId) => context.reply(getLocations(blockId)) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index e7eed7bf4c879..255f91866ef58 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1657,6 +1657,30 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("clean up used memory when BlockManager added") { + val listener = new AppStatusListener(store, conf, true) + // Add block manager at the first time + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded( + time, driver, 42L, Some(43L), Some(44L))) + // Update the memory metrics + listener.updateExecutorMemoryDiskInfo( + listener.liveExecutors(SparkContext.DRIVER_IDENTIFIER), + StorageLevel.MEMORY_AND_DISK, + 10L, + 10L + ) + // Re-add the same block manager again + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded( + time, driver, 42L, Some(43L), Some(44L))) + + check[ExecutorSummaryWrapper](SparkContext.DRIVER_IDENTIFIER) { d => + val memoryMetrics = d.info.memoryMetrics.get + assert(memoryMetrics.usedOffHeapStorageMemory == 0) + assert(memoryMetrics.usedOnHeapStorageMemory == 0) + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 89f00b5a9d902..8d06768a2b284 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -50,7 +50,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerBlockUpdated} import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -71,6 +71,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val allStores = ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null + var liveListenerBus: LiveListenerBus = null val securityMgr = new SecurityManager(new SparkConf(false)) val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) @@ -145,9 +146,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(sc.conf).thenReturn(conf) val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() + liveListenerBus = spy(new LiveListenerBus(conf)) master = spy(new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new 
LiveListenerBus(conf), None, blockManagerInfo)), + liveListenerBus, None, blockManagerInfo)), rpcEnv.setupEndpoint("blockmanagerHeartbeat", new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)) @@ -164,6 +166,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE rpcEnv.awaitTermination() rpcEnv = null master = null + liveListenerBus = null } finally { super.afterEach() } @@ -1693,6 +1696,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(locs(blockIds(0)) == expectedLocs) } + test("SPARK-30594: Do not post SparkListenerBlockUpdated when updateBlockInfo returns false") { + // update block info for non-existent block manager + val updateInfo = UpdateBlockInfo(BlockManagerId("1", "host1", 100), + BlockId("test_1"), StorageLevel.MEMORY_ONLY, 1, 1) + val result = master.driverEndpoint.askSync[Boolean](updateInfo) + + assert(!result) + verify(liveListenerBus, never()).post(SparkListenerBlockUpdated(BlockUpdatedInfo(updateInfo))) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: DownloadFileManager = null From 3ccebcdd45a0def4853c1175c18d8538602d571d Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 5 Feb 2020 17:16:38 +0800 Subject: [PATCH 0014/1280] [SPARK-30506][SQL][DOC] Document for generic file source options/configs ### What changes were proposed in this pull request? Add a new document page named *Generic File Source Options* for *Data Sources* menu and added following sub items: * spark.sql.files.ignoreCorruptFiles * spark.sql.files.ignoreMissingFiles * pathGlobFilter * recursiveFileLookup And here're snapshots of the generated document: doc-1 doc-2 doc-3 doc-4 ### Why are the changes needed? Better guidance for end-user. ### Does this PR introduce any user-facing change? No, added in Spark 3.0. ### How was this patch tested? Pass Jenkins. Closes #27302 from Ngone51/doc-generic-file-source-option. 
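As a quick reference, the Scala example added by this patch boils down to the following sketch, condensed from the `SQLDataSourceExample.scala` changes below (the paths refer to the example resource files added in this diff, and the local `SparkSession` setup is an assumption for running it standalone):

```scala
import org.apache.spark.sql.SparkSession

object GenericFileSourceOptionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sketch").master("local[*]").getOrCreate()

    // Skip files that cannot be read as Parquet instead of failing the job.
    spark.sql("set spark.sql.files.ignoreCorruptFiles=true")
    spark.read.parquet(
      "examples/src/main/resources/dir1/",
      "examples/src/main/resources/dir1/dir2/").show()

    // Load files from nested directories; recursive lookup disables partition inference.
    spark.read.format("parquet")
      .option("recursiveFileLookup", "true")
      .load("examples/src/main/resources/dir1").show()

    // Only include files whose names match the glob pattern.
    spark.read.format("parquet")
      .option("pathGlobFilter", "*.parquet")
      .load("examples/src/main/resources/dir1").show()

    spark.stop()
  }
}
```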
Lead-authored-by: yi.wu Co-authored-by: Yuanjian Li Signed-off-by: Wenchen Fan (cherry picked from commit 5983ad9cc4481e224a7e094de830ef2e816c1fe6) Signed-off-by: Wenchen Fan --- docs/_data/menu-sql.yaml | 2 + docs/sql-data-sources-avro.md | 2 +- docs/sql-data-sources-generic-options.md | 121 ++++++++++++++++++ docs/sql-data-sources-load-save-functions.md | 21 --- docs/sql-data-sources.md | 5 + .../sql/JavaSQLDataSourceExample.java | 48 ++++++- examples/src/main/python/sql/datasource.py | 48 ++++++- examples/src/main/r/RSparkSQLExample.R | 24 +++- .../main/resources/dir1/dir2/file2.parquet | Bin 0 -> 520 bytes .../src/main/resources/dir1/file1.parquet | Bin 0 -> 520 bytes examples/src/main/resources/dir1/file3.json | 1 + .../do_not_read_this.txt | 1 - .../users.orc | Bin 448 -> 0 bytes .../favorite_color=red/users.orc | Bin 402 -> 0 bytes .../examples/sql/SQLDataSourceExample.scala | 48 ++++++- 15 files changed, 282 insertions(+), 39 deletions(-) create mode 100644 docs/sql-data-sources-generic-options.md create mode 100644 examples/src/main/resources/dir1/dir2/file2.parquet create mode 100644 examples/src/main/resources/dir1/file1.parquet create mode 100644 examples/src/main/resources/dir1/file3.json delete mode 100644 examples/src/main/resources/partitioned_users.orc/do_not_read_this.txt delete mode 100644 examples/src/main/resources/partitioned_users.orc/favorite_color=__HIVE_DEFAULT_PARTITION__/users.orc delete mode 100644 examples/src/main/resources/partitioned_users.orc/favorite_color=red/users.orc diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 3e4db7107ec34..241ec399d7bd5 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -24,6 +24,8 @@ subitems: - text: "Generic Load/Save Functions" url: sql-data-sources-load-save-functions.html + - text: "Generic File Source Options" + url: sql-data-sources-generic-options.html - text: Parquet Files url: sql-data-sources-parquet.html - text: ORC Files diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index b0076878e02da..8e6a4079cd5de 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -230,7 +230,7 @@ Data source options of Avro can be set via: ignoreExtension true - The option controls ignoring of files without .avro extensions in read.
If the option is enabled, all files (with and without .avro extension) are loaded.
The option has been deprecated, and it will be removed in the future releases. Please use the general data source option pathGlobFilter for filtering file names. + The option controls ignoring of files without .avro extensions in read.
If the option is enabled, all files (with and without .avro extension) are loaded.
The option has been deprecated, and it will be removed in the future releases. Please use the general data source option pathGlobFilter for filtering file names. read diff --git a/docs/sql-data-sources-generic-options.md b/docs/sql-data-sources-generic-options.md new file mode 100644 index 0000000000000..0cfe2ed1aa891 --- /dev/null +++ b/docs/sql-data-sources-generic-options.md @@ -0,0 +1,121 @@ +--- +layout: global +title: Generic File Source Options +displayTitle: Generic File Source Options +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +* Table of contents +{:toc} + +These generic options/configurations are effective only when using file-based sources: parquet, orc, avro, json, csv, text. + +Please note that the hierarchy of directories used in examples below are: + +{% highlight text %} + +dir1/ + ├── dir2/ + │ └── file2.parquet (schema: , content: "file2.parquet") + └── file1.parquet (schema: , content: "file1.parquet") + └── file3.json (schema: , content: "{'file':'corrupt.json'}") + +{% endhighlight %} + +### Ignore Corrupt Files + +Spark allows you to use `spark.sql.files.ignoreCorruptFiles` to ignore corrupt files while reading data +from files. When set to true, the Spark jobs will continue to run when encountering corrupted files and +the contents that have been read will still be returned. + +To ignore corrupt files while reading data files, you can use: + +
+
+{% include_example ignore_corrupt_files scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example ignore_corrupt_files java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example ignore_corrupt_files python/sql/datasource.py %} +
+ +
+{% include_example ignore_corrupt_files r/RSparkSQLExample.R %} +
+
+ +### Ignore Missing Files + +Spark allows you to use `spark.sql.files.ignoreMissingFiles` to ignore missing files while reading data +from files. Here, missing file really means the deleted file under directory after you construct the +`DataFrame`. When set to true, the Spark jobs will continue to run when encountering missing files and +the contents that have been read will still be returned. + +### Path Global Filter + +`pathGlobFilter` is used to only include files with file names matching the pattern. +The syntax follows org.apache.hadoop.fs.GlobFilter. +It does not change the behavior of partition discovery. + +To load files with paths matching a given glob pattern while keeping the behavior of partition discovery, +you can use: + +
+
+{% include_example load_with_path_glob_filter scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example load_with_path_glob_filter java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example load_with_path_glob_filter python/sql/datasource.py %} +
+ +
+{% include_example load_with_path_glob_filter r/RSparkSQLExample.R %} +
+
+ +### Recursive File Lookup +`recursiveFileLookup` is used to recursively load files and it disables partition inferring. Its default value is `false`. +If data source explicitly specifies the `partitionSpec` when `recursiveFileLookup` is true, exception will be thrown. + +To load all files recursively, you can use: + +
+
+{% include_example recursive_file_lookup scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example recursive_file_lookup java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example recursive_file_lookup python/sql/datasource.py %} +
+ +
+{% include_example recursive_file_lookup r/RSparkSQLExample.R %} +
+
\ No newline at end of file diff --git a/docs/sql-data-sources-load-save-functions.md b/docs/sql-data-sources-load-save-functions.md index 07482137a28a3..a7efb9347ac64 100644 --- a/docs/sql-data-sources-load-save-functions.md +++ b/docs/sql-data-sources-load-save-functions.md @@ -102,27 +102,6 @@ To load a CSV file you can use: -To load files with paths matching a given glob pattern while keeping the behavior of partition discovery, -you can use: - -
-
-{% include_example load_with_path_glob_filter scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} -
- -
-{% include_example load_with_path_glob_filter java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} -
- -
-{% include_example load_with_path_glob_filter python/sql/datasource.py %} -
- -
-{% include_example load_with_path_glob_filter r/RSparkSQLExample.R %} -
-
- The extra options are also used during write operation. For example, you can control bloom filters and dictionary encodings for ORC data sources. The following ORC example will create bloom filter and use dictionary encoding only for `favorite_color`. diff --git a/docs/sql-data-sources.md b/docs/sql-data-sources.md index 079c54060d15d..9396846041709 100644 --- a/docs/sql-data-sources.md +++ b/docs/sql-data-sources.md @@ -33,6 +33,11 @@ goes into specific options that are available for the built-in data sources. * [Save Modes](sql-data-sources-load-save-functions.html#save-modes) * [Saving to Persistent Tables](sql-data-sources-load-save-functions.html#saving-to-persistent-tables) * [Bucketing, Sorting and Partitioning](sql-data-sources-load-save-functions.html#bucketing-sorting-and-partitioning) +* [Generic File Source Options](sql-data-sources-generic-options.html) + * [Ignore Corrupt Files](sql-data-sources-generic-options.html#ignore-corrupt-iles) + * [Ignore Missing Files](sql-data-sources-generic-options.html#ignore-missing-iles) + * [Path Global Filter](sql-data-sources-generic-options.html#path-global-filter) + * [Recursive File Lookup](sql-data-sources-generic-options.html#recursive-file-lookup) * [Parquet Files](sql-data-sources-parquet.html) * [Loading Data Programmatically](sql-data-sources-parquet.html#loading-data-programmatically) * [Partition Discovery](sql-data-sources-parquet.html#partition-discovery) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index b2ce0bc08642a..2295225387a33 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -98,6 +98,7 @@ public static void main(String[] args) { .getOrCreate(); runBasicDataSourceExample(spark); + runGenericFileSourceOptionsExample(spark); runBasicParquetExample(spark); runParquetSchemaMergingExample(spark); runJsonDatasetExample(spark); @@ -106,6 +107,48 @@ public static void main(String[] args) { spark.stop(); } + private static void runGenericFileSourceOptionsExample(SparkSession spark) { + // $example on:ignore_corrupt_files$ + // enable ignore corrupt files + spark.sql("set spark.sql.files.ignoreCorruptFiles=true"); + // dir1/file3.json is corrupt from parquet's view + Dataset testCorruptDF = spark.read().parquet( + "examples/src/main/resources/dir1/", + "examples/src/main/resources/dir1/dir2/"); + testCorruptDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // |file2.parquet| + // +-------------+ + // $example off:ignore_corrupt_files$ + // $example on:recursive_file_lookup$ + Dataset recursiveLoadedDF = spark.read().format("parquet") + .option("recursiveFileLookup", "true") + .load("examples/src/main/resources/dir1"); + recursiveLoadedDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // |file2.parquet| + // +-------------+ + // $example off:recursive_file_lookup$ + spark.sql("set spark.sql.files.ignoreCorruptFiles=false"); + // $example on:load_with_path_glob_filter$ + Dataset testGlobFilterDF = spark.read().format("parquet") + .option("pathGlobFilter", "*.parquet") // json file should be filtered out + .load("examples/src/main/resources/dir1"); + testGlobFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + // $example 
off:load_with_path_glob_filter$ + } + private static void runBasicDataSourceExample(SparkSession spark) { // $example on:generic_load_save_functions$ Dataset usersDF = spark.read().load("examples/src/main/resources/users.parquet"); @@ -123,11 +166,6 @@ private static void runBasicDataSourceExample(SparkSession spark) { .option("header", "true") .load("examples/src/main/resources/people.csv"); // $example off:manual_load_options_csv$ - // $example on:load_with_path_glob_filter$ - Dataset partitionedUsersDF = spark.read().format("orc") - .option("pathGlobFilter", "*.orc") - .load("examples/src/main/resources/partitioned_users.orc"); - // $example off:load_with_path_glob_filter$ // $example on:manual_save_options_orc$ usersDF.write().format("orc") .option("orc.bloom.filter.columns", "favorite_color") diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index 0d78097ea975e..265f135e1e5f2 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -28,6 +28,48 @@ # $example off:schema_merging$ +def generic_file_source_options_example(spark): + # $example on:ignore_corrupt_files$ + # enable ignore corrupt files + spark.sql("set spark.sql.files.ignoreCorruptFiles=true") + # dir1/file3.json is corrupt from parquet's view + test_corrupt_df = spark.read.parquet("examples/src/main/resources/dir1/", + "examples/src/main/resources/dir1/dir2/") + test_corrupt_df.show() + # +-------------+ + # | file| + # +-------------+ + # |file1.parquet| + # |file2.parquet| + # +-------------+ + # $example off:ignore_corrupt_files$ + + # $example on:recursive_file_lookup$ + recursive_loaded_df = spark.read.format("parquet")\ + .option("recursiveFileLookup", "true")\ + .load("examples/src/main/resources/dir1") + recursive_loaded_df.show() + # +-------------+ + # | file| + # +-------------+ + # |file1.parquet| + # |file2.parquet| + # +-------------+ + # $example off:recursive_file_lookup$ + spark.sql("set spark.sql.files.ignoreCorruptFiles=false") + + # $example on:load_with_path_glob_filter$ + df = spark.read.load("examples/src/main/resources/dir1", + format="parquet", pathGlobFilter="*.parquet") + df.show() + # +-------------+ + # | file| + # +-------------+ + # |file1.parquet| + # +-------------+ + # $example off:load_with_path_glob_filter$ + + def basic_datasource_example(spark): # $example on:generic_load_save_functions$ df = spark.read.load("examples/src/main/resources/users.parquet") @@ -57,11 +99,6 @@ def basic_datasource_example(spark): format="csv", sep=":", inferSchema="true", header="true") # $example off:manual_load_options_csv$ - # $example on:load_with_path_glob_filter$ - df = spark.read.load("examples/src/main/resources/partitioned_users.orc", - format="orc", pathGlobFilter="*.orc") - # $example off:load_with_path_glob_filter$ - # $example on:manual_save_options_orc$ df = spark.read.orc("examples/src/main/resources/users.orc") (df.write.format("orc") @@ -233,6 +270,7 @@ def jdbc_dataset_example(spark): .getOrCreate() basic_datasource_example(spark) + generic_file_source_options_example(spark) parquet_example(spark) parquet_schema_merging_example(spark) json_dataset_example(spark) diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index fa083d5542fae..8685cfb5c05f2 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -99,6 +99,26 @@ createOrReplaceTempView(df, "table") df <- sql("SELECT * FROM table") # $example 
off:run_sql$ +# Ignore corrupt files +# $example on:ignore_corrupt_files$ +# enable ignore corrupt files +sql("set spark.sql.files.ignoreCorruptFiles=true") +# dir1/file3.json is corrupt from parquet's view +testCorruptDF <- read.parquet(c("examples/src/main/resources/dir1/", "examples/src/main/resources/dir1/dir2/")) +head(testCorruptDF) +# file +# 1 file1.parquet +# 2 file2.parquet +# $example off:ignore_corrupt_files$ + +# $example on:recursive_file_lookup$ +recursiveLoadedDF <- read.df("examples/src/main/resources/dir1", "parquet", recursiveFileLookup = "true") +head(recursiveLoadedDF) +# file +# 1 file1.parquet +# 2 file2.parquet +# $example off:recursive_file_lookup$ +sql("set spark.sql.files.ignoreCorruptFiles=false") # $example on:generic_load_save_functions$ df <- read.df("examples/src/main/resources/users.parquet") @@ -119,7 +139,9 @@ namesAndAges <- select(df, "name", "age") # $example off:manual_load_options_csv$ # $example on:load_with_path_glob_filter$ -df <- read.df("examples/src/main/resources/partitioned_users.orc", "orc", pathGlobFilter = "*.orc") +df <- read.df("examples/src/main/resources/dir1", "parquet", pathGlobFilter = "*.parquet") +# file +# 1 file1.parquet # $example off:load_with_path_glob_filter$ # $example on:manual_save_options_orc$ diff --git a/examples/src/main/resources/dir1/dir2/file2.parquet b/examples/src/main/resources/dir1/dir2/file2.parquet new file mode 100644 index 0000000000000000000000000000000000000000..d1895bf29b75ce0d6c70009562975645b7699fb9 GIT binary patch literal 520 zcmah{O-sW-5S>OXIi$xeyO2XzXlWrXNkh{nc=0Blibq9c(`;%mN!xrVQu;6aVg4ki z+9Fm^*j@JVX5PG+-Pz@hOMnO>Y@*?%O>~oXk~C8zv6AJwQS}k*!r)IH05seutqz_) zgowuME2Bc$r-y3(sB%d(AVyE4r@OcwKv!cXGyA$p3^s0q&b}CeMEAXgtFK=i**Sv$ zx??7G30N4bp`-@PrgT{@gj`AVBtGqlXH{0|vY=<4aD)SN_$#7XXNLiaa`_^1Rm)h` zlHP83{kl>-=0P+GCQhWH R*c~#{#2LJ`0A~0HJ^)`pl6wFE literal 0 HcmV?d00001 diff --git a/examples/src/main/resources/dir1/file1.parquet b/examples/src/main/resources/dir1/file1.parquet new file mode 100644 index 0000000000000000000000000000000000000000..ad360b16fd898ea301cbbc00de07a2fd40bbbb44 GIT binary patch literal 520 zcmah{O-sW-5S>OXIkd+uyRe6_&|)DjiJ@r|ym%8&#iJs!X*M;Oq-{PFDg77zFn^L$ zZ4oOd>@NFwGjHC^?)37;B|s#RHZkzhCMKudB+axdm#SPfx;}tCIQ^3c00Z|?tBdCk zA!0E5s#Brt(?d0T)VU%CkWSyA$J@9*K+j;D@qx)nhmG1`vM+`p(R**<8kiRcb`Idb z?ij}|0SjX_lr(_ZlrGA-R7-`PYgRtsAssm|g`gm-IEsC-Lp`$ghAFfHbYizdrdgj9&B+7^3J5w16z z`wh_R-J6`{rM2^Dxzcr{Yn?_-n#Goy1aUG+g1|eEe4Y5d=f#fi_<<$^{ID@wqz!d+$2T@;=asU7T literal 0 HcmV?d00001 diff --git a/examples/src/main/resources/dir1/file3.json b/examples/src/main/resources/dir1/file3.json new file mode 100644 index 0000000000000..0490f92d7f317 --- /dev/null +++ b/examples/src/main/resources/dir1/file3.json @@ -0,0 +1 @@ +{"file":"corrupt.json"} diff --git a/examples/src/main/resources/partitioned_users.orc/do_not_read_this.txt b/examples/src/main/resources/partitioned_users.orc/do_not_read_this.txt deleted file mode 100644 index 9c19f2a0449eb..0000000000000 --- a/examples/src/main/resources/partitioned_users.orc/do_not_read_this.txt +++ /dev/null @@ -1 +0,0 @@ -do not read this diff --git a/examples/src/main/resources/partitioned_users.orc/favorite_color=__HIVE_DEFAULT_PARTITION__/users.orc b/examples/src/main/resources/partitioned_users.orc/favorite_color=__HIVE_DEFAULT_PARTITION__/users.orc deleted file mode 100644 index 890395a9281abb71a8444a3c9041155fae6c0f9f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 448 
zcmZ9I&q~8U5XNVBo9(zw?Ydr;SjFZbMIi*dg&ryutq3;s6fbI&s73pSL<`=EPtdpU zMZ_oa0on0VPFPM^G~@eRf!WgU9H=ldim;DS}SrJF{zz-(+ldGL?0J-_Gzeh^9ZY$ja_ PcC+R4_ix5}{f_k!}tAl=uBlmXZBh|u*|g)<(VO4sHM8gz=Q*dq|0Nzq8h zaaAOV-=FcZP}g^L`!CzXuYyvTD7|GlK_F?~v7@;p&h{?)7#xsq1&U6i$Yy!R+-2)| zIep5ni`DEVFPDL!X5f*wD7lr#huY1{`!HK%w-0%^Tx8{AxMtsUURZAsMqw0TTvNsW iQpG=1@s?`%dUXl(hDNx}#smWV{vcA%pHe2{k^2I!Wi|=` diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index d4c05e5ad9944..2c7abfcd335d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -32,6 +32,7 @@ object SQLDataSourceExample { .getOrCreate() runBasicDataSourceExample(spark) + runGenericFileSourceOptionsExample(spark) runBasicParquetExample(spark) runParquetSchemaMergingExample(spark) runJsonDatasetExample(spark) @@ -40,6 +41,48 @@ object SQLDataSourceExample { spark.stop() } + private def runGenericFileSourceOptionsExample(spark: SparkSession): Unit = { + // $example on:ignore_corrupt_files$ + // enable ignore corrupt files + spark.sql("set spark.sql.files.ignoreCorruptFiles=true") + // dir1/file3.json is corrupt from parquet's view + val testCorruptDF = spark.read.parquet( + "examples/src/main/resources/dir1/", + "examples/src/main/resources/dir1/dir2/") + testCorruptDF.show() + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // |file2.parquet| + // +-------------+ + // $example off:ignore_corrupt_files$ + // $example on:recursive_file_lookup$ + val recursiveLoadedDF = spark.read.format("parquet") + .option("recursiveFileLookup", "true") + .load("examples/src/main/resources/dir1") + recursiveLoadedDF.show() + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // |file2.parquet| + // +-------------+ + // $example off:recursive_file_lookup$ + spark.sql("set spark.sql.files.ignoreCorruptFiles=false") + // $example on:load_with_path_glob_filter$ + val testGlobFilterDF = spark.read.format("parquet") + .option("pathGlobFilter", "*.parquet") // json file should be filtered out + .load("examples/src/main/resources/dir1") + testGlobFilterDF.show() + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + // $example off:load_with_path_glob_filter$ + } + private def runBasicDataSourceExample(spark: SparkSession): Unit = { // $example on:generic_load_save_functions$ val usersDF = spark.read.load("examples/src/main/resources/users.parquet") @@ -56,11 +99,6 @@ object SQLDataSourceExample { .option("header", "true") .load("examples/src/main/resources/people.csv") // $example off:manual_load_options_csv$ - // $example on:load_with_path_glob_filter$ - val partitionedUsersDF = spark.read.format("orc") - .option("pathGlobFilter", "*.orc") - .load("examples/src/main/resources/partitioned_users.orc") - // $example off:load_with_path_glob_filter$ // $example on:manual_save_options_orc$ usersDF.write.format("orc") .option("orc.bloom.filter.columns", "favorite_color") From 92f57237871400ab9d499e1174af22a867c01988 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 5 Feb 2020 18:48:45 +0800 Subject: [PATCH 0015/1280] [SPARK-30668][SQL] Support `SimpleDateFormat` patterns in parsing timestamps/dates strings ### What changes were proposed in this pull request? 
In the PR, I propose to partially revert the commit https://github.com/apache/spark/commit/51a6ba0181a013f2b62b47184785a8b6f6a78f12, and provide a legacy parser based on `FastDateFormat` which is compatible to `SimpleDateFormat`. To enable the legacy parser, set `spark.sql.legacy.timeParser.enabled` to `true`. ### Why are the changes needed? To allow users to restore old behavior in parsing timestamps/dates using `SimpleDateFormat` patterns. The main reason for restoring is `DateTimeFormatter`'s patterns are not fully compatible to `SimpleDateFormat` patterns, see https://issues.apache.org/jira/browse/SPARK-30668 ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? - Added new test to `DateFunctionsSuite` - Restored additional test cases in `JsonInferSchemaSuite`. Closes #27441 from MaxGekk/support-simpledateformat. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit 459e757ed40fd1cdd37911d3f57b48b54ca2fff7) Signed-off-by: Wenchen Fan --- docs/sql-migration-guide.md | 4 +- .../sql/catalyst/util/DateFormatter.scala | 35 +++++++- .../catalyst/util/TimestampFormatter.scala | 38 +++++++-- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../catalyst/json/JsonInferSchemaSuite.scala | 79 +++++++++++-------- .../apache/spark/sql/DateFunctionsSuite.scala | 14 ++++ 6 files changed, 136 insertions(+), 44 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 0c47370283736..5a5e802f6a900 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -67,9 +67,7 @@ license: | - Since Spark 3.0, Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days and etc. Spark 3.0 uses Java 8 API classes from the java.time packages that based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed by using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact on the results for dates before October 15, 1582 (Gregorian) and affect on the following Spark 3.0 API: - - CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpose with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. - - - The `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. + - Parsing/formatting of timestamp/date strings. 
This effects on CSV/JSON datasources and on the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions when patterns specified by users is used for parsing and formatting. Since Spark 3.0, the conversions are based on `java.time.format.DateTimeFormatter`, see https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html. New implementation performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. In Spark version 2.4 and earlier, `java.text.SimpleDateFormat` is used for timestamp/date string conversions, and the supported patterns are described in https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html. The old behavior can be restored by setting `spark.sql.legacy.timeParser.enabled` to `true`. - The `weekofyear`, `weekday`, `dayofweek`, `date_trunc`, `from_utc_timestamp`, `to_utc_timestamp`, and `unix_timestamp` functions use java.time API for calculation week number of year, day number of week as well for conversion from/to TimestampType values in UTC time zone. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 7f982b019c8d1..28189b65dee9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql.catalyst.util import java.time.{LocalDate, ZoneId} import java.util.Locale -import DateTimeUtils.{convertSpecialDate, localDateToDays} +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, localDateToDays} +import org.apache.spark.sql.internal.SQLConf sealed trait DateFormatter extends Serializable { def parse(s: String): Int // returns days since epoch @@ -48,17 +51,41 @@ class Iso8601DateFormatter( } } +class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { + @transient + private lazy val format = FastDateFormat.getInstance(pattern, locale) + + override def parse(s: String): Int = { + val milliseconds = format.parse(s).getTime + DateTimeUtils.millisToDays(milliseconds) + } + + override def format(days: Int): String = { + val date = DateTimeUtils.toJavaDate(days) + format.format(date) + } +} + object DateFormatter { - val defaultPattern: String = "uuuu-MM-dd" val defaultLocale: Locale = Locale.US def apply(format: String, zoneId: ZoneId, locale: Locale): DateFormatter = { - new Iso8601DateFormatter(format, zoneId, locale) + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyDateFormatter(format, locale) + } else { + new Iso8601DateFormatter(format, zoneId, locale) + } } def apply(format: String, zoneId: ZoneId): DateFormatter = { apply(format, zoneId, defaultLocale) } - def apply(zoneId: ZoneId): DateFormatter = apply(defaultPattern, zoneId) + def apply(zoneId: ZoneId): DateFormatter = { + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyDateFormatter("yyyy-MM-dd", defaultLocale) + } else { + new Iso8601DateFormatter("uuuu-MM-dd", zoneId, defaultLocale) + } + } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 5be4807083fa3..fe1a4fe710c20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -22,10 +22,14 @@ import java.time._ import java.time.format.DateTimeParseException import java.time.temporal.ChronoField.MICRO_OF_SECOND import java.time.temporal.TemporalQueries -import java.util.Locale +import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit.SECONDS -import DateTimeUtils.convertSpecialTimestamp +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS +import org.apache.spark.sql.catalyst.util.DateTimeUtils.convertSpecialTimestamp +import org.apache.spark.sql.internal.SQLConf sealed trait TimestampFormatter extends Serializable { /** @@ -86,12 +90,32 @@ class FractionTimestampFormatter(zoneId: ZoneId) override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter } +class LegacyTimestampFormatter( + pattern: String, + zoneId: ZoneId, + locale: Locale) extends TimestampFormatter { + + @transient private lazy val format = + FastDateFormat.getInstance(pattern, TimeZone.getTimeZone(zoneId), locale) + + protected def toMillis(s: String): Long = format.parse(s).getTime + + override def parse(s: String): Long = toMillis(s) * MICROS_PER_MILLIS + + override def format(us: Long): String = { + format.format(DateTimeUtils.toJavaTimestamp(us)) + } +} + object TimestampFormatter { - val defaultPattern: String = "uuuu-MM-dd HH:mm:ss" val defaultLocale: Locale = Locale.US def apply(format: String, zoneId: ZoneId, locale: Locale): TimestampFormatter = { - new Iso8601TimestampFormatter(format, zoneId, locale) + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyTimestampFormatter(format, zoneId, locale) + } else { + new Iso8601TimestampFormatter(format, zoneId, locale) + } } def apply(format: String, zoneId: ZoneId): TimestampFormatter = { @@ -99,7 +123,11 @@ object TimestampFormatter { } def apply(zoneId: ZoneId): TimestampFormatter = { - apply(defaultPattern, zoneId, defaultLocale) + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyTimestampFormatter("yyyy-MM-dd HH:mm:ss", zoneId, defaultLocale) + } else { + new Iso8601TimestampFormatter("uuuu-MM-dd HH:mm:ss", zoneId, defaultLocale) + } } def getFractionFormatter(zoneId: ZoneId): TimestampFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5ce5692123805..acc0922e2cee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2159,6 +2159,14 @@ object SQLConf { .checkValue(_ > 0, "The value of spark.sql.addPartitionInBatch.size must be positive") .createWithDefault(100) + val LEGACY_TIME_PARSER_ENABLED = buildConf("spark.sql.legacy.timeParser.enabled") + .internal() + .doc("When set to true, java.text.SimpleDateFormat is used for formatting and parsing " + + "dates/timestamps in a locale-sensitive manner. 
When set to false, classes from " + + "java.time.* packages are used for the same purpose.") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * @@ -2447,6 +2455,8 @@ class SQLConf extends Serializable with Logging { def legacyMsSqlServerNumericMappingEnabled: Boolean = getConf(LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED) + def legacyTimeParserEnabled: Boolean = getConf(SQLConf.LEGACY_TIME_PARSER_ENABLED) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala index a48e61861c158..c2e03bd2c3609 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.json -import com.fasterxml.jackson.core.JsonFactory - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { @@ -41,45 +40,61 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { } test("inferring timestamp type") { - checkTimestampType("yyyy", """{"a": "2018"}""") - checkTimestampType("yyyy=MM", """{"a": "2018=12"}""") - checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""") - checkTimestampType( - "yyyy-MM-dd'T'HH:mm:ss.SSS", - """{"a": "2018-12-02T21:04:00.123"}""") - checkTimestampType( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX", - """{"a": "2018-12-02T21:04:00.123567+01:00"}""") + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkTimestampType("yyyy", """{"a": "2018"}""") + checkTimestampType("yyyy=MM", """{"a": "2018=12"}""") + checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSS", + """{"a": "2018-12-02T21:04:00.123"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX", + """{"a": "2018-12-02T21:04:00.123567+01:00"}""") + } + } } test("prefer decimals over timestamps") { - checkType( - options = Map( - "prefersDecimal" -> "true", - "timestampFormat" -> "yyyyMMdd.HHmmssSSS" - ), - json = """{"a": "20181202.210400123"}""", - dt = DecimalType(17, 9) - ) + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "true", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = DecimalType(17, 9) + ) + } + } } test("skip decimal type inferring") { - checkType( - options = Map( - "prefersDecimal" -> "false", - "timestampFormat" -> "yyyyMMdd.HHmmssSSS" - ), - json = """{"a": "20181202.210400123"}""", - dt = TimestampType - ) + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "false", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = TimestampType + ) + } + } } test("fallback to string type") { - checkType( - options = Map("timestampFormat" -> 
"yyyy,MM,dd.HHmmssSSS"), - json = """{"a": "20181202.210400123"}""", - dt = StringType - ) + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map("timestampFormat" -> "yyyy,MM,dd.HHmmssSSS"), + json = """{"a": "20181202.210400123"}""", + dt = StringType + ) + } + } } test("disable timestamp inferring") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index d7d8c2c52d12b..3b3d3cc3d7a17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -789,4 +789,18 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { Row(Timestamp.valueOf("2015-07-24 07:00:00")), Row(Timestamp.valueOf("2015-07-24 22:00:00")))) } + + test("SPARK-30668: use legacy timestamp parser in to_timestamp") { + def checkTimeZoneParsing(expected: Any): Unit = { + val df = Seq("2020-01-27T20:06:11.847-0800").toDF("ts") + checkAnswer(df.select(to_timestamp(col("ts"), "yyyy-MM-dd'T'HH:mm:ss.SSSz")), + Row(expected)) + } + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + checkTimeZoneParsing(Timestamp.valueOf("2020-01-27 20:06:11.847")) + } + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "false") { + checkTimeZoneParsing(null) + } + } } From fb99ebc47d12e82a2814b6b4b3b97148fbaa8627 Mon Sep 17 00:00:00 2001 From: turbofei Date: Wed, 5 Feb 2020 21:24:02 +0800 Subject: [PATCH 0016/1280] [SPARK-26218][SQL][FOLLOW UP] Fix the corner case when casting float to Integer ### What changes were proposed in this pull request? When spark.sql.ansi.enabled is true, for the statement: ``` select cast(cast(2147483648 as Float) as Integer) //result is 2147483647 ``` Its result is 2147483647 and does not throw `ArithmeticException`. The root cause is that, the below code does not work for some corner cases. https://github.com/apache/spark/blob/94fc0e3235162afc6038019eed6ec546e3d1983e/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala#L129-L141 For example: ![image](https://user-images.githubusercontent.com/6757692/72074911-badfde80-332d-11ea-963e-2db0e43c33e8.png) In this PR, I fix it by comparing Math.floor(x) with Int.MaxValue directly. ### Why are the changes needed? Result corrupt. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added Unit test. Closes #27151 from turboFei/SPARK-26218-follow-up-int-overflow. 
Authored-by: turbofei Signed-off-by: Wenchen Fan (cherry picked from commit 6d507b4a31feb965bf31d104f1a6a2c359b166dc) Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/types/numerics.scala | 16 ++++++++-------- .../sql-tests/results/postgreSQL/float4.sql.out | 5 +++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +++++++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 1ac85360f944f..b5226213effc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -121,10 +121,10 @@ object FloatExactNumeric extends FloatIsFractional { private def overflowException(x: Float, dataType: String) = throw new ArithmeticException(s"Casting $x to $dataType causes overflow") - private val intUpperBound = Int.MaxValue.toFloat - private val intLowerBound = Int.MinValue.toFloat - private val longUpperBound = Long.MaxValue.toFloat - private val longLowerBound = Long.MinValue.toFloat + private val intUpperBound = Int.MaxValue + private val intLowerBound = Int.MinValue + private val longUpperBound = Long.MaxValue + private val longLowerBound = Long.MinValue override def toInt(x: Float): Int = { // When casting floating values to integral types, Spark uses the method `Numeric.toInt` @@ -155,10 +155,10 @@ object DoubleExactNumeric extends DoubleIsFractional { private def overflowException(x: Double, dataType: String) = throw new ArithmeticException(s"Casting $x to $dataType causes overflow") - private val intUpperBound = Int.MaxValue.toDouble - private val intLowerBound = Int.MinValue.toDouble - private val longUpperBound = Long.MaxValue.toDouble - private val longLowerBound = Long.MinValue.toDouble + private val intUpperBound = Int.MaxValue + private val intLowerBound = Int.MinValue + private val longUpperBound = Long.MaxValue + private val longLowerBound = Long.MinValue override def toInt(x: Double): Int = { if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) { diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index ba913789d5623..fe8375c5eab8f 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -322,9 +322,10 @@ struct -- !query SELECT int(float('2147483647')) -- !query schema -struct +struct<> -- !query output -2147483647 +java.lang.ArithmeticException +Casting 2.14748365E9 to int causes overflow -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a6dae9a269740..11f9724e587f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3383,6 +3383,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df, Row(1)) } } + + test("SPARK-26218: Fix the corner case when casting float to Integer") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + intercept[ArithmeticException]( + sql("SELECT CAST(CAST(2147483648 as FLOAT) as Integer)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST(2147483648 as DOUBLE) as Integer)").collect() + ) + } + 
} } case class Foo(bar: Option[String]) From abc93b05272b2853862689b553eec1320d3f5f7f Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 5 Feb 2020 07:54:16 -0800 Subject: [PATCH 0017/1280] [MINOR][DOC] Add migration note for removing `org.apache.spark.ml.image.ImageSchema.readImages` ### What changes were proposed in this pull request? Add migration note for removing `org.apache.spark.ml.image.ImageSchema.readImages` ### Why are the changes needed? ### Does this PR introduce any user-facing change? ### How was this patch tested? Closes #27467 from WeichenXu123/SC-26286. Authored-by: WeichenXu Signed-off-by: Dongjoon Hyun (cherry picked from commit ec70e0708f953f3b22ec17d931ff388d007ac1f6) Signed-off-by: Dongjoon Hyun --- docs/ml-migration-guide.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ml-migration-guide.md b/docs/ml-migration-guide.md index 49f701b2156b3..860c941e6b44b 100644 --- a/docs/ml-migration-guide.md +++ b/docs/ml-migration-guide.md @@ -32,6 +32,7 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. {:.no_toc} * `OneHotEncoder` which is deprecated in 2.3, is removed in 3.0 and `OneHotEncoderEstimator` is now renamed to `OneHotEncoder`. +* `org.apache.spark.ml.image.ImageSchema.readImages` which is deprecated in 2.3, is removed in 3.0, use `spark.read.format('image')` instead. ### Changes of behavior {:.no_toc} From 5561cead40041d2798fab3876ef75bcfbfdb854a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Feb 2020 11:01:53 -0800 Subject: [PATCH 0018/1280] [SPARK-30738][K8S] Use specific image version in "Launcher client dependencies" test ### What changes were proposed in this pull request? This PR use a specific version of docker image instead of `latest`. As of today, when I run K8s integration test locally, this test case fails always. Also, in this PR, I shows two consecutive failures with a dummy change. - https://github.com/apache/spark/pull/27465#issuecomment-582326614 - https://github.com/apache/spark/pull/27465#issuecomment-582329114 ``` - Launcher client dependencies *** FAILED *** ``` After that, I added the patch and K8s Integration test passed. - https://github.com/apache/spark/pull/27465#issuecomment-582361696 ### Why are the changes needed? [SPARK-28465](https://github.com/apache/spark/pull/25222) switched from `v4.0.0-stable-4.0-master-centos-7-x86_64` to `latest` to catch up the API change. However, the API change seems to occur again. We had better use a specific version to prevent accidental failures. ```scala - .withImage("ceph/daemon:v4.0.0-stable-4.0-master-centos-7-x86_64") + .withImage("ceph/daemon:latest") ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass `Launcher client dependencies` test in Jenkins K8s Integration Suite. Or, run K8s Integration test locally. Closes #27465 from dongjoon-hyun/SPARK-K8S-IT. 
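
For reference, a minimal sketch of a pinned container spec along these lines, using the fabric8 `ContainerBuilder` the integration test already relies on (the container name below is just a placeholder for illustration):

```scala
import io.fabric8.kubernetes.api.model.ContainerBuilder

// Pin the dependency-server image to a known-good tag instead of `latest`,
// so an upstream API change cannot silently break the test.
val cephContainer = new ContainerBuilder()
  .withImage("ceph/daemon:v4.0.3-stable-4.0-nautilus-centos-7-x86_64")
  .withImagePullPolicy("Always")
  .withName("ceph-daemon") // placeholder name, not the real test's value
  .build()
```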
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 9d90c8b898d0f043afbcebd901ec866c2883c6ca) Signed-off-by: Dongjoon Hyun --- .../spark/deploy/k8s/integrationtest/DepsTestsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala index 7181774b9f17e..414126862ed69 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala @@ -63,7 +63,7 @@ private[spark] trait DepsTestsSuite { k8sSuite: KubernetesSuite => ).asJava new ContainerBuilder() - .withImage("ceph/daemon:latest") + .withImage("ceph/daemon:v4.0.3-stable-4.0-nautilus-centos-7-x86_64") .withImagePullPolicy("Always") .withName(cName) .withPorts(new ContainerPortBuilder() From 3013d28d0a17e6fcea6e42e7caee23a43dd7e196 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 5 Feb 2020 11:19:42 -0800 Subject: [PATCH 0019/1280] [SPARK-29864][SQL][FOLLOWUP] Reference the config for the old behavior in error message ### What changes were proposed in this pull request? Follow up work for SPARK-29864, reference the config `spark.sql.legacy.fromDayTimeString.enabled` in error message. ### Why are the changes needed? For better usability. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests. Closes #27464 from xuanyuanking/SPARK-29864-follow. Authored-by: Yuanjian Li Signed-off-by: Dongjoon Hyun (cherry picked from commit 4938905a1c047e367c066e39dce8232bfcff14f1) Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/util/IntervalUtils.scala | 9 +++++++-- .../sql-tests/results/ansi/interval.sql.out | 12 +++++------ .../sql-tests/results/interval.sql.out | 12 +++++------ .../results/postgreSQL/interval.sql.out | 20 +++++++++---------- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 7692299a46ef5..2d98384363323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -176,6 +176,9 @@ object IntervalUtils { private val dayTimePatternLegacy = "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$".r + private val fallbackNotice = s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + + "to restore the behavior before Spark 3.0." + /** * Legacy method of parsing a string in a day-time format. It ignores the `from` bound, * and takes into account only the `to` bound by truncating the result. 
For example, @@ -195,7 +198,8 @@ object IntervalUtils { require(input != null, "Interval day-time string must be not null") assert(input.length == input.trim.length) val m = dayTimePatternLegacy.pattern.matcher(input) - require(m.matches, s"Interval string must match day-time format of 'd h:m:s.n': $input") + require(m.matches, s"Interval string must match day-time format of 'd h:m:s.n': $input, " + + s"$fallbackNotice") try { val sign = if (m.group(1) != null && m.group(1) == "-") -1 else 1 @@ -296,7 +300,8 @@ object IntervalUtils { require(regexp.isDefined, s"Cannot support (interval '$input' $from to $to) expression") val pattern = regexp.get.pattern val m = pattern.matcher(input) - require(m.matches, s"Interval string must match day-time format of '$pattern': $input") + require(m.matches, s"Interval string must match day-time format of '$pattern': $input, " + + s"$fallbackNotice") var micros: Long = 0L var days: Int = 0 unitsRange(to, from).foreach { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index ab6130da869c4..f37049064d989 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -320,7 +320,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 20 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 15:40:32.99899999' day to hour @@ -334,7 +334,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 20 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 15:40:32.99899999' day to minute @@ -348,7 +348,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '15:40:32.99899999' hour to minute @@ -362,7 +362,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '15:40.99899999' hour to second @@ -376,7 +376,7 @@ struct<> -- !query output 
org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '15:40' hour to second @@ -390,7 +390,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 20 40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 40:32.99899999' minute to second diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 8f523a35f3c19..94b4f15815ca5 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -314,7 +314,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 20 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 15:40:32.99899999' day to hour @@ -328,7 +328,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 20 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 15:40:32.99899999' day to minute @@ -342,7 +342,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 15:40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '15:40:32.99899999' hour to minute @@ -356,7 +356,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval 
'15:40.99899999' hour to second @@ -370,7 +370,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '15:40' hour to second @@ -384,7 +384,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 20 40:32.99899999(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == select interval '20 40:32.99899999' minute to second diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out index 4bd846d3ff923..62d47410aab65 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out @@ -105,7 +105,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 1 2:03(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03' day to hour @@ -119,7 +119,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 1 2:03:04(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2})$': 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03:04' day to hour @@ -141,7 +141,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 1 2:03:04(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2})$': 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03:04' day to minute @@ -155,7 +155,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d+) (?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT 
interval '1 2:03' day to second @@ -177,7 +177,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 1 2:03(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03' hour to minute @@ -191,7 +191,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 1 2:03:04(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2})$': 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03:04' hour to minute @@ -205,7 +205,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03' hour to second @@ -219,7 +219,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03:04(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03:04' hour to second @@ -233,7 +233,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03' minute to second @@ -247,7 +247,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03:04(line 1, pos 16) +requirement failed: Interval string must match day-time format of '^(?[+|-])?(?\d{1,2}):(?(\d{1,2})(\.(\d{1,9}))?)$': 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0.(line 1, pos 16) == SQL == SELECT interval '1 2:03:04' minute to second From d06a9dfd9098b704c75047161590bb4a32b25286 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Feb 2020 12:36:51 -0800 Subject: [PATCH 0020/1280] [SPARK-30721][SQL][TESTS] Fix DataFrameAggregateSuite when enabling AQE ### What changes were proposed in this pull request? 
update `DataFrameAggregateSuite` to make it pass with AQE ### Why are the changes needed? We don't need to turn off AQE in `DataFrameAggregateSuite` ### Does this PR introduce any user-facing change? no ### How was this patch tested? run `DataFrameAggregateSuite` locally with AQE on. Closes #27451 from cloud-fan/aqe-test. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun (cherry picked from commit 3b26f807a0eb0e59c5123c3f1e2262b712800c0f) Signed-off-by: Dongjoon Hyun --- .../spark/sql/DataFrameAggregateSuite.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index dc1767a6852f6..d7df75fd0e2c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -615,34 +615,33 @@ class DataFrameAggregateSuite extends QueryTest Seq((true, true), (true, false), (false, true), (false, false))) { withSQLConf( (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), - (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString), - (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false")) { - // When enable AQE, the WholeStageCodegenExec is added during QueryStageExec. + (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") // test case for HashAggregate val hashAggDF = df.groupBy("x").agg(c, sum("y")) + hashAggDF.collect() val hashAggPlan = hashAggDF.queryExecution.executedPlan if (wholeStage) { - assert(hashAggPlan.find { + assert(find(hashAggPlan) { case WholeStageCodegenExec(_: HashAggregateExec) => true case _ => false }.isDefined) } else { - assert(hashAggPlan.isInstanceOf[HashAggregateExec]) + assert(stripAQEPlan(hashAggPlan).isInstanceOf[HashAggregateExec]) } - hashAggDF.collect() // test case for ObjectHashAggregate and SortAggregate val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) - val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan + objHashAggOrSortAggDF.collect() + val objHashAggOrSortAggPlan = + stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan) if (useObjectHashAgg) { assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) } else { assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) } - objHashAggOrSortAggDF.collect() } } } From da1caaede6aee381a74f9a19b7d850c459ba215d Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Thu, 6 Feb 2020 13:01:08 +0900 Subject: [PATCH 0021/1280] [SPARK-30737][SPARK-27262][R][BUILD] Reenable CRAN check with UTF-8 encoding to DESCRIPTION ### What changes were proposed in this pull request? This PR proposes to reenable CRAN check disabled at https://github.com/apache/spark/pull/27460. Given the tests https://github.com/apache/spark/pull/27468, seems we should also port https://github.com/apache/spark/pull/23823 together. ### Why are the changes needed? To check CRAN back. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? It was tested at https://github.com/apache/spark/pull/27468 and Jenkins should test it out. Closes #27472 from HyukjinKwon/SPARK-30737. 
Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon (cherry picked from commit b95ccb1d8b726b11435789cdb5882df6643430ed) Signed-off-by: HyukjinKwon --- R/pkg/DESCRIPTION | 1 + R/pkg/tests/fulltests/test_sparkSQL.R | 7 +------ R/run-tests.sh | 7 +++---- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 95d3e52bef3a9..c8cb1c3a992ad 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -62,3 +62,4 @@ Collate: RoxygenNote: 5.0.1 VignetteBuilder: knitr NeedsCompilation: no +Encoding: UTF-8 diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 23fadc4373c3f..c1d277ac84be1 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -496,12 +496,7 @@ test_that("SPARK-17811: can create DataFrame containing NA as date and time", { expect_true(is.na(DF$date[2])) expect_equal(DF$date[1], as.Date("2016-10-01")) expect_true(is.na(DF$time[2])) - # Warnings were suppressed due to Jenkins environment issues. - # It shows both warnings as below in Jenkins: - # - Your system is mis-configured: /etc/localtime is not a symlink - # - It is strongly recommended to set environment variable TZ to - # America/Los_Angeles (or equivalent) - suppressWarnings(expect_equal(DF$time[1], as.POSIXlt("2016-01-10"))) + expect_equal(DF$time[1], as.POSIXlt("2016-01-10")) }) test_that("create DataFrame with complex types", { diff --git a/R/run-tests.sh b/R/run-tests.sh index 782b5f5baebaf..51ca7d600caf0 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -31,10 +31,9 @@ NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" # Also run the documentation tests for CRAN CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out rm -f $CRAN_CHECK_LOG_FILE -# TODO(SPARK-30737) reenable this once packages are correctly installed in Jenkins -# NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE -# FAILED=$((PIPESTATUS[0]||$FAILED)) -touch $CRAN_CHECK_LOG_FILE + +NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE +FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" From baf1a07704a6404313ebf558652e26acf137c5b4 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 6 Feb 2020 12:48:27 +0800 Subject: [PATCH 0022/1280] [SPARK-30729][CORE] Eagerly filter out zombie TaskSetManager before offering resources ### What changes were proposed in this pull request? Eagerly filter out zombie `TaskSetManager` before offering resources to reduce any overhead as possible. And this PR also avoid doing `recomputeLocality` and `addPendingTask` when `TaskSetManager` is zombie. ### Why are the changes needed? Zombie `TaskSetManager` could still exist in Pool's `schedulableQueue` when it has running tasks. Offering resources on a zombie `TaskSetManager` could bring unnecessary overhead and is meaningless. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass Jenkins. Closes #27455 from Ngone51/exclude-zombie-tsm. 
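
The idea, reduced to a standalone sketch (the real change filters `rootPool.getSortedTaskSetQueue` in `TaskSchedulerImpl`, as the diff below shows; `FakeTaskSetManager` is a hypothetical stand-in, not the Spark class):

```scala
// Toy model: zombie task sets are dropped up front, so no resource offers
// are ever built for them.
final case class FakeTaskSetManager(name: String, isZombie: Boolean)

val schedulableQueue = Seq(
  FakeTaskSetManager("TaskSet_0.0", isZombie = false),
  FakeTaskSetManager("TaskSet_1.0", isZombie = true))

val sortedTaskSets = schedulableQueue.filterNot(_.isZombie)
assert(sortedTaskSets.map(_.name) == Seq("TaskSet_0.0"))
```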
Authored-by: yi.wu Signed-off-by: Wenchen Fan (cherry picked from commit aebabf0bed712511eaa8844cab3a0c562219b2d0) Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f25a36c7af22a..6a1d460e6a9d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -430,7 +430,7 @@ private[spark] class TaskSchedulerImpl( val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableResources = shuffledOffers.map(_.resources).toArray val availableCpus = shuffledOffers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue + val sortedTaskSets = rootPool.getSortedTaskSetQueue.filterNot(_.isZombie) for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( taskSet.parent.name, taskSet.name, taskSet.runningTasks)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3b620ec69a9ab..2ce11347ade39 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -229,6 +229,8 @@ private[spark] class TaskSetManager( index: Int, resolveRacks: Boolean = true, speculatable: Boolean = false): Unit = { + // A zombie TaskSetManager may reach here while handling failed task. + if (isZombie) return val pendingTaskSetToAddTo = if (speculatable) pendingSpeculatableTasks else pendingTasks for (loc <- tasks(index).preferredLocations) { loc match { @@ -1082,6 +1084,8 @@ private[spark] class TaskSetManager( } def recomputeLocality(): Unit = { + // A zombie TaskSetManager may reach here while executorLost happens + if (isZombie) return val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) From fbb49664532135e86a7fb83d261ce1f500c041b9 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 6 Feb 2020 13:54:17 +0800 Subject: [PATCH 0023/1280] [SPARK-30612][SQL] Resolve qualified column name with v2 tables ### What changes were proposed in this pull request? This PR fixes the issue where queries with qualified columns like `SELECT t.a FROM t` would fail to resolve for v2 tables. This PR would allow qualified column names in query as following: ```SQL SELECT testcat.ns1.ns2.tbl.foo FROM testcat.ns1.ns2.tbl SELECT ns1.ns2.tbl.foo FROM testcat.ns1.ns2.tbl SELECT ns2.tbl.foo FROM testcat.ns1.ns2.tbl SELECT tbl.foo FROM testcat.ns1.ns2.tbl ``` ### Why are the changes needed? This is a bug because you cannot qualify column names in queries. ### Does this PR introduce any user-facing change? Yes, now users can qualify column names for v2 tables. ### How was this patch tested? Added new tests. Closes #27391 from imback82/qualified_col. 
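
The core matching rule introduced here — a column reference's qualifier must be a suffix of the attribute's full qualifier — can be sketched as standalone toy code (a case-insensitive comparison stands in for the session resolver; this is not the actual `package.scala` implementation):

```scala
// A qualifier like "ns2.tbl" matches an attribute qualified as
// "testcat.ns1.ns2.tbl" because it is a suffix of the full qualifier.
def matchQualifier(short: Seq[String], long: Seq[String]): Boolean =
  long.length >= short.length &&
    long.takeRight(short.length).zip(short).forall {
      case (l, s) => l.equalsIgnoreCase(s)
    }

assert(matchQualifier(Seq("ns2", "tbl"), Seq("testcat", "ns1", "ns2", "tbl")))
assert(!matchQualifier(Seq("ns1", "tbl"), Seq("testcat", "ns1", "ns2", "tbl")))
```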
Authored-by: Terry Kim Signed-off-by: Wenchen Fan (cherry picked from commit c27a616450959a5e984d10bf93b12ac0ced6c94d) Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../expressions/namedExpressions.scala | 2 - .../sql/catalyst/expressions/package.scala | 83 ++++++++++- .../spark/sql/catalyst/identifiers.scala | 16 +- .../plans/logical/basicLogicalOperators.scala | 17 ++- .../spark/sql/catalyst/trees/TreeNode.scala | 3 +- .../AttributeResolutionSuite.scala | 137 ++++++++++++++++++ .../sql/catalyst/trees/TreeNodeSuite.scala | 10 +- .../sql-tests/results/group-by-filter.sql.out | 32 ++-- .../invalid-correlation.sql.out | 4 +- .../sql/connector/DataSourceV2SQLSuite.scala | 38 +++++ .../benchmark/TPCDSQueryBenchmark.scala | 2 +- .../command/PlanResolutionSuite.scala | 65 +++++---- .../sql/hive/HiveMetastoreCatalogSuite.scala | 2 +- 14 files changed, 342 insertions(+), 74 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3fd5039a4f116..56cc2a274bb7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -799,6 +799,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp { case u: UnresolvedRelation => lookupV2Relation(u.multipartIdentifier) + .map(SubqueryAlias(u.multipartIdentifier, _)) .getOrElse(u) case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) => @@ -923,7 +924,9 @@ class Analyzer( case v1Table: V1Table => v1SessionCatalog.getRelation(v1Table.v1Table) case table => - DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + SubqueryAlias( + identifier, + DataSourceV2Relation.create(table, Some(catalog), Some(ident))) } val key = catalog.name +: ident.namespace :+ ident.name Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3362353e2662a..02e90f8458c3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -236,8 +236,6 @@ case class AttributeReference( val qualifier: Seq[String] = Seq.empty[String]) extends Attribute with Unevaluable { - // currently can only handle qualifier of length 2 - require(qualifier.length <= 2) /** * Returns true iff the expression id is the same for both attributes. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 7164b6b82adbc..9f42e643e4cb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -23,7 +23,6 @@ import com.google.common.collect.Maps import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -153,13 +152,19 @@ package object expressions { unique(grouped) } - /** Perform attribute resolution given a name and a resolver. */ - def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + /** Returns true if all qualifiers in `attrs` have 2 or less parts. */ + @transient private val hasTwoOrLessQualifierParts: Boolean = + attrs.forall(_.qualifier.length <= 2) + + /** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. */ + private def matchWithTwoOrLessQualifierParts( + nameParts: Seq[String], + resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Collect matching attributes given a name and a lookup. def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { - candidates.toSeq.flatMap(_.collect { + candidates.getOrElse(Nil).collect { case a if resolver(a.name, name) => a.withName(name) - }) + } } // Find matches for the given name assuming that the 1st two parts are qualifier @@ -204,13 +209,79 @@ package object expressions { // If none of attributes match database.table.column pattern or // `table.column` pattern, we try to resolve it as a column. - val (candidates, nestedFields) = matches match { + matches match { case (Seq(), _) => val name = nameParts.head val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) (attributes, nameParts.tail) case _ => matches } + } + + /** + * Match attributes for the case where at least one qualifier in `attrs` has more than 2 parts. + */ + private def matchWithThreeOrMoreQualifierParts( + nameParts: Seq[String], + resolver: Resolver): (Seq[Attribute], Seq[String]) = { + // Returns true if the `short` qualifier is a subset of the last elements of + // `long` qualifier. For example, Seq("a", "b") is a subset of Seq("a", "a", "b"), + // but not a subset of Seq("a", "b", "b"). + def matchQualifier(short: Seq[String], long: Seq[String]): Boolean = { + (long.length >= short.length) && + long.takeRight(short.length) + .zip(short) + .forall(x => resolver(x._1, x._2)) + } + + // Collect attributes that match the given name and qualifier. + // A match occurs if + // 1) the given name matches the attribute's name according to the resolver. + // 2) the given qualifier is a subset of the attribute's qualifier. + def collectMatches( + name: String, + qualifier: Seq[String], + candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.getOrElse(Nil).collect { + case a if resolver(name, a.name) && matchQualifier(qualifier, a.qualifier) => + a.withName(name) + } + } + + // Iterate each string in `nameParts` in a reverse order and try to match the attributes + // considering the current string as the attribute name. 
For example, if `nameParts` is + // Seq("a", "b", "c"), the match will be performed in the following order: + // 1) name = "c", qualifier = Seq("a", "b") + // 2) name = "b", qualifier = Seq("a") + // 3) name = "a", qualifier = Seq() + // Note that the match is performed in the reverse order in order to match the longest + // qualifier as possible. If a match is found, the remaining portion of `nameParts` + // is also returned as nested fields. + var candidates: Seq[Attribute] = Nil + var nestedFields: Seq[String] = Nil + var i = nameParts.length - 1 + while (i >= 0 && candidates.isEmpty) { + val name = nameParts(i) + candidates = collectMatches( + name, + nameParts.take(i), + direct.get(name.toLowerCase(Locale.ROOT))) + if (candidates.nonEmpty) { + nestedFields = nameParts.takeRight(nameParts.length - i - 1) + } + i -= 1 + } + + (candidates, nestedFields) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + val (candidates, nestedFields) = if (hasTwoOrLessQualifierParts) { + matchWithTwoOrLessQualifierParts(nameParts, resolver) + } else { + matchWithThreeOrMoreQualifierParts(nameParts, resolver) + } def name = UnresolvedAttribute(nameParts).name candidates match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index deceec73dda30..c574a20da0b5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -49,19 +49,21 @@ sealed trait IdentifierWithDatabase { /** * Encapsulates an identifier that is either a alias name or an identifier that has table - * name and optionally a database name. + * name and a qualifier. 
* The SubqueryAlias node keeps track of the qualifier using the information in this structure - * @param identifier - Is an alias name or a table name - * @param database - Is a database name and is optional + * @param name - Is an alias name or a table name + * @param qualifier - Is a qualifier */ -case class AliasIdentifier(identifier: String, database: Option[String]) - extends IdentifierWithDatabase { +case class AliasIdentifier(name: String, qualifier: Seq[String]) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def this(identifier: String) = this(identifier, Seq()) - def this(identifier: String) = this(identifier, None) + override def toString: String = (qualifier :+ name).quoted } object AliasIdentifier { - def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier) + def apply(name: String): AliasIdentifier = new AliasIdentifier(name) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 40db8b6f49dc4..54e5ff7aeb754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.types._ import org.apache.spark.util.random.RandomSampler @@ -849,18 +850,18 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi /** * Aliased subquery. * - * @param name the alias identifier for this subquery. + * @param identifier the alias identifier for this subquery. * @param child the logical plan of this subquery. 
*/ case class SubqueryAlias( - name: AliasIdentifier, + identifier: AliasIdentifier, child: LogicalPlan) extends OrderPreservingUnaryNode { - def alias: String = name.identifier + def alias: String = identifier.name override def output: Seq[Attribute] = { - val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + val qualifierList = identifier.qualifier :+ alias child.output.map(_.withQualifier(qualifierList)) } override def doCanonicalize(): LogicalPlan = child.canonicalized @@ -877,7 +878,13 @@ object SubqueryAlias { identifier: String, database: String, child: LogicalPlan): SubqueryAlias = { - SubqueryAlias(AliasIdentifier(identifier, Some(database)), child) + SubqueryAlias(AliasIdentifier(identifier, Seq(database)), child) + } + + def apply( + multipartIdentifier: Seq[String], + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child) } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ba1eeb38e247e..56a198763b4e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -27,7 +27,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.sql.catalyst.IdentifierWithDatabase +import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase} import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.errors._ @@ -780,6 +780,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case exprId: ExprId => true case field: StructField => true case id: IdentifierWithDatabase => true + case alias: AliasIdentifier => true case join: JoinType => true case spec: BucketSpec => true case catalog: CatalogTable => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala new file mode 100644 index 0000000000000..813a68f68451c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class AttributeResolutionSuite extends SparkFunSuite { + val resolver = caseInsensitiveResolution + + test("basic attribute resolution with namespaces") { + val attrs = Seq( + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t1")), + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3", "t2"))) + + // Try to match attribute reference with name "a" with qualifier "ns1.ns2.t1". + Seq(Seq("t1", "a"), Seq("ns2", "t1", "a"), Seq("ns1", "ns2", "t1", "a")).foreach { nameParts => + attrs.resolve(nameParts, resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + // Non-matching cases + Seq(Seq("ns1", "ns2", "t1"), Seq("ns2", "a")).foreach { nameParts => + assert(attrs.resolve(nameParts, resolver).isEmpty) + } + } + + test("attribute resolution where table and attribute names are the same") { + val attrs = Seq(AttributeReference("t", IntegerType)(qualifier = Seq("ns1", "ns2", "t"))) + // Matching cases + Seq( + Seq("t"), Seq("t", "t"), Seq("ns2", "t", "t"), Seq("ns1", "ns2", "t", "t") + ).foreach { nameParts => + attrs.resolve(nameParts, resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + // Non-matching case + assert(attrs.resolve(Seq("ns1", "ns2", "t"), resolver).isEmpty) + } + + test("attribute resolution ambiguity at the attribute name level") { + val attrs = Seq( + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t1")), + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t2"))) + + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("a"), resolver) + } + assert(ex.getMessage.contains( + "Reference 'a' is ambiguous, could be: ns1.t1.a, ns1.ns2.t2.a.")) + } + + test("attribute resolution ambiguity at the qualifier level") { + val attrs = Seq( + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t")), + AttributeReference("a", IntegerType)(qualifier = Seq("ns2", "ns1", "t"))) + + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("ns1", "t", "a"), resolver) + } + assert(ex.getMessage.contains( + "Reference 'ns1.t.a' is ambiguous, could be: ns1.t.a, ns2.ns1.t.a.")) + } + + test("attribute resolution with nested fields") { + val attrType = StructType(Seq(StructField("aa", IntegerType), StructField("bb", IntegerType))) + val attrs = Seq(AttributeReference("a", attrType)(qualifier = Seq("ns1", "t"))) + + val resolved = attrs.resolve(Seq("ns1", "t", "a", "aa"), resolver) + resolved match { + case Some(Alias(_, name)) => assert(name == "aa") + case _ => fail() + } + + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("ns1", "t", "a", "cc"), resolver) + } + assert(ex.getMessage.contains("No such struct field cc in aa, bb")) + } + + test("attribute resolution with case insensitive resolver") { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))) + attrs.resolve(Seq("Ns1", "T", "A"), caseInsensitiveResolution) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0)) && attr.name == "A") + case _ => fail() + } + } + + test("attribute resolution with case sensitive resolver") { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier 
= Seq("ns1", "t"))) + assert(attrs.resolve(Seq("Ns1", "T", "A"), caseSensitiveResolution).isEmpty) + assert(attrs.resolve(Seq("ns1", "t", "A"), caseSensitiveResolution).isEmpty) + attrs.resolve(Seq("ns1", "t", "a"), caseSensitiveResolution) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + test("attribute resolution should try to match the longest qualifier") { + // We have two attributes: + // 1) "a.b" where "a" is the name and "b" is the nested field. + // 2) "a.b.a" where "b" is the name, left-side "a" is the qualifier and the right-side "a" + // is the nested field. + // When "a.b" is resolved, "b" is tried first as the name, so it is resolved to #2 attribute. + val a1Type = StructType(Seq(StructField("b", IntegerType))) + val a2Type = StructType(Seq(StructField("a", IntegerType))) + val attrs = Seq( + AttributeReference("a", a1Type)(), + AttributeReference("b", a2Type)(qualifier = Seq("a"))) + attrs.resolve(Seq("a", "b"), resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(1))) + case _ => fail() + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 0e094bc06b05f..e72b2e9b1b214 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -433,10 +433,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { // Converts AliasIdentifier to JSON assertJSON( - AliasIdentifier("alias"), + AliasIdentifier("alias", Seq("ns1", "ns2")), JObject( "product-class" -> JString(classOf[AliasIdentifier].getName), - "identifier" -> "alias")) + "name" -> "alias", + "qualifier" -> "[ns1, ns2]")) // Converts SubqueryAlias to JSON assertJSON( @@ -445,8 +446,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { JObject( "class" -> classOf[SubqueryAlias].getName, "num-children" -> 1, - "name" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName), - "identifier" -> "t1"), + "identifier" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName), + "name" -> "t1", + "qualifier" -> JArray(Nil)), "child" -> 0), JObject( "class" -> classOf[JsonTestTreeNode].getName, diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index a032678e90fe8..a4c7c2cf90cd7 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -369,13 +369,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -395,13 +395,13 @@ 
org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -420,13 +420,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x] : +- Distinct : +- Project [dept_id#x] -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -445,13 +445,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x] : +- Distinct : +- Project [dept_id#x] -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 1599634ff9efb..ec7ecf28754ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -110,8 +110,8 @@ struct<> org.apache.spark.sql.AnalysisException Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: Aggregate [min(outer(t2a#x)) AS min(outer())#x] -+- SubqueryAlias `t3` ++- SubqueryAlias t3 +- Project [t3a#x, t3b#x, t3c#x] - +- SubqueryAlias `t3` + +- SubqueryAlias t3 +- LocalRelation [t3a#x, t3b#x, t3c#x] ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 2c8349a0e6a75..eabcb81c50646 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -679,6 +679,44 @@ class DataSourceV2SQLSuite } } + test("qualified column names for v2 tables") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, point struct) USING foo") + sql(s"INSERT INTO $t VALUES (1, (10, 20))") + + checkAnswer( + sql(s"SELECT testcat.ns1.ns2.tbl.id, testcat.ns1.ns2.tbl.point.x FROM $t"), + Row(1, 10)) + checkAnswer(sql(s"SELECT ns1.ns2.tbl.id, ns1.ns2.tbl.point.x FROM $t"), Row(1, 10)) + checkAnswer(sql(s"SELECT ns2.tbl.id, ns2.tbl.point.x FROM $t"), Row(1, 10)) + checkAnswer(sql(s"SELECT tbl.id, tbl.point.x FROM $t"), Row(1, 10)) + + val ex = intercept[AnalysisException] { + sql(s"SELECT ns1.ns2.ns3.tbl.id from $t") + } + assert(ex.getMessage.contains("cannot resolve '`ns1.ns2.ns3.tbl.id`")) + } + } + + test("qualified column names for v1 tables") { + // unset this config to use the default v2 session catalog. + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + + withTable("t") { + sql("CREATE TABLE t USING json AS SELECT 1 AS i") + checkAnswer(sql("select default.t.i from spark_catalog.t"), Row(1)) + checkAnswer(sql("select t.i from spark_catalog.default.t"), Row(1)) + checkAnswer(sql("select default.t.i from spark_catalog.default.t"), Row(1)) + + // catalog name cannot be used for v1 tables. + val ex = intercept[AnalysisException] { + sql(s"select spark_catalog.default.t.i from spark_catalog.default.t") + } + assert(ex.getMessage.contains("cannot resolve '`spark_catalog.default.t.i`")) + } + } + test("InsertInto: append - across catalog") { val t1 = "testcat.ns1.ns2.tbl" val t2 = "testcat2.db.tbl" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index c93d27f02c686..ad3d79760adf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -81,7 +81,7 @@ object TPCDSQueryBenchmark extends SqlBasedBenchmark { val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.analyzed.foreach { case SubqueryAlias(alias, _: LogicalRelation) => - queryRelations.add(alias.identifier) + queryRelations.add(alias.name) case LogicalRelation(_, _, Some(catalogTable), _) => queryRelations.add(catalogTable.identifier.table) case HiveTableRelation(tableMeta, _, _, _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index c0c3cd70fcc9e..88f30353cce94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -885,33 +885,34 @@ class PlanResolutionSuite extends AnalysisTest { val parsed4 = parseAndResolve(sql4) parsed1 match { - case DeleteFromTable(_: DataSourceV2Relation, None) => - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed1.treeString) + case DeleteFromTable(AsDataSourceV2Relation(_), None) => + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed1.treeString) } parsed2 match { case DeleteFromTable( - _: DataSourceV2Relation, + AsDataSourceV2Relation(_), Some(EqualTo(name: UnresolvedAttribute, 
StringLiteral("Robert")))) => assert(name.name == "name") - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed2.treeString) + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed2.treeString) } parsed3 match { case DeleteFromTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Some(EqualTo(name: UnresolvedAttribute, StringLiteral("Robert")))) => assert(name.name == "t.name") - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed3.treeString) + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed3.treeString) } parsed4 match { - case DeleteFromTable(SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + case DeleteFromTable( + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Some(InSubquery(values, query))) => assert(values.size == 1 && values.head.isInstanceOf[UnresolvedAttribute]) assert(values.head.asInstanceOf[UnresolvedAttribute].name == "t.name") query match { - case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", None), + case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") @@ -944,7 +945,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed1 match { case UpdateTable( - _: DataSourceV2Relation, + AsDataSourceV2Relation(_), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), None) => @@ -956,7 +957,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed2 match { case UpdateTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), None) => @@ -968,7 +969,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed3 match { case UpdateTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), Some(EqualTo(p: UnresolvedAttribute, IntegerLiteral(1)))) => @@ -980,14 +981,14 @@ class PlanResolutionSuite extends AnalysisTest { } parsed4 match { - case UpdateTable(SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + case UpdateTable(SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(key: UnresolvedAttribute, IntegerLiteral(32))), Some(InSubquery(values, query))) => assert(key.name == "t.age") assert(values.size == 1 && values.head.isInstanceOf[UnresolvedAttribute]) assert(values.head.asInstanceOf[UnresolvedAttribute].name == "t.name") query match { - case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", None), + case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") @@ -1129,7 +1130,7 @@ class PlanResolutionSuite extends AnalysisTest { case AlterTable(_, _, r: DataSourceV2Relation, _) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) 
- case Project(_, r: DataSourceV2Relation) => + case Project(_, AsDataSourceV2Relation(r)) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) case InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) => @@ -1206,8 +1207,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql1) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1232,8 +1233,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql2) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, @@ -1258,8 +1259,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql3) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(None), UpdateAction(None, updateAssigns)), Seq(InsertAction(None, insertAssigns))) => @@ -1282,8 +1283,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql4) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: Project), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), source: Project), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1311,8 +1312,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql5) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: Project), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), source: Project), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1346,8 +1347,8 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql1) match { case MergeIntoTable( - target: DataSourceV2Relation, - source: DataSourceV2Relation, + AsDataSourceV2Relation(target), + 
AsDataSourceV2Relation(source), _, Seq(DeleteAction(None), UpdateAction(None, updateAssigns)), Seq(InsertAction( @@ -1453,8 +1454,8 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), _: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(_)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(_)), EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute), Seq( DeleteAction(Some(EqualTo(dl: UnresolvedAttribute, StringLiteral("delete")))), @@ -1481,3 +1482,11 @@ class PlanResolutionSuite extends AnalysisTest { } // TODO: add tests for more commands. } + +object AsDataSourceV2Relation { + def unapply(plan: LogicalPlan): Option[DataSourceV2Relation] = plan match { + case SubqueryAlias(_, r: DataSourceV2Relation) => Some(r) + case _ => None + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 20bafd832d0da..b8ef44b096eed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias(AliasIdentifier("vw1", Some("default")), _) => x + case x @ SubqueryAlias(AliasIdentifier("vw1", Seq("default")), _) => x } assert(aliases.size == 1) } From 706c21763e94957efe4a309abd6e13601b8dcaf3 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 6 Feb 2020 15:24:26 +0900 Subject: [PATCH 0024/1280] [SPARK-29108][SQL][TESTS][FOLLOWUP] Comment out no use test case and add 'insert into' statement of window.sql (Part 2) ### What changes were proposed in this pull request? While running the `window_part2.sql` tests, I found that the file lacks the INSERT statements, so its output is empty. I checked PostgreSQL and used https://github.com/postgres/postgres/blob/master/src/test/regress/sql/window.sql as a reference. `window_part1.sql` and `window_part3.sql` already contain the INSERT statements, so we should add them to `window_part2.sql` as well. Only one case references the table `empsalary`, and it throws an `AnalysisException`. ``` -- !query select last(salary) over(order by salary range between 1000 preceding and 1000 following), lag(salary) over(order by salary range between 1000 preceding and 1000 following), salary from empsalary -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException Window Frame specifiedwindowframe(RangeFrame, -1000, 1000) must match the required frame specifiedwindowframe(RowFrame, -1, -1); ``` So we should do the following: 1. Comment out that single case and create a new ticket. 2. Add `INSERT INTO empsalary`. Note: `window_part4.sql` does not use the table `empsalary`. ### Why are the changes needed? Supplementary test data. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? New test case. Closes #27439 from beliefer/add-insert-to-window. 
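As a quick illustration of why the inserted rows matter — a rough spark-shell sketch only, assuming the usual `empsalary` layout from the PostgreSQL regression suite (`depname`, `empno`, `salary`, `enroll_date`; only `enroll_date` is visible in the diff below) and that `window_part2.sql` has already created and populated the table — a windowed aggregate now returns one row per employee instead of an empty result:

```scala
import org.apache.spark.sql.SparkSession

// Rough sketch (assumes the empsalary table created and populated by window_part2.sql):
// with the ten INSERTed rows, per-department running sums are no longer empty.
val spark = SparkSession.builder().appName("empsalary-window-sketch").getOrCreate()
spark.sql(
  """SELECT depname, empno, salary,
    |       sum(salary) OVER (PARTITION BY depname ORDER BY salary, empno) AS dept_running_sum
    |FROM empsalary
    |ORDER BY depname, salary, empno""".stripMargin).show()
```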
Authored-by: beliefer Signed-off-by: HyukjinKwon --- .../inputs/postgreSQL/window_part2.sql | 19 ++++++++++-- .../results/postgreSQL/window_part2.sql.out | 29 ++++++++++++------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql index 395149e48d5c8..ba1acc9f56b4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql @@ -15,6 +15,18 @@ CREATE TABLE empsalary ( enroll_date date ) USING parquet; +INSERT INTO empsalary VALUES + ('develop', 10, 5200, date '2007-08-01'), + ('sales', 1, 5000, date '2006-10-01'), + ('personnel', 5, 3500, date '2007-12-10'), + ('sales', 4, 4800, date '2007-08-08'), + ('personnel', 2, 3900, date '2006-12-23'), + ('develop', 7, 4200, date '2008-01-01'), + ('develop', 9, 4500, date '2008-01-01'), + ('sales', 3, 4800, date '2007-08-01'), + ('develop', 8, 6000, date '2006-10-01'), + ('develop', 11, 5200, date '2007-08-15'); + -- [SPARK-28429] SQL Datetime util function being casted to double instead of timestamp -- CREATE TEMP VIEW v_window AS -- SELECT i, min(i) over (order by i range between '1 day' preceding and '10 days' following) as min_i @@ -99,9 +111,10 @@ FROM tenk1 WHERE unique1 < 10; -- nth_value(salary, 1) over(order by salary range between 1000 preceding and 1000 following), -- salary from empsalary; -select last(salary) over(order by salary range between 1000 preceding and 1000 following), -lag(salary) over(order by salary range between 1000 preceding and 1000 following), -salary from empsalary; +-- [SPARK-30734] AnalysisException that window RangeFrame not match RowFrame +-- select last(salary) over(order by salary range between 1000 preceding and 1000 following), +-- lag(salary) over(order by salary range between 1000 preceding and 1000 following), +-- salary from empsalary; -- [SPARK-27951] ANSI SQL: NTH_VALUE function -- select first_value(salary) over(order by salary range between 1000 following and 3000 following diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out index 0015740a0638e..f41659a196ae1 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out @@ -15,6 +15,24 @@ struct<> +-- !query +INSERT INTO empsalary VALUES + ('develop', 10, 5200, date '2007-08-01'), + ('sales', 1, 5000, date '2006-10-01'), + ('personnel', 5, 3500, date '2007-12-10'), + ('sales', 4, 4800, date '2007-08-08'), + ('personnel', 2, 3900, date '2006-12-23'), + ('develop', 7, 4200, date '2008-01-01'), + ('develop', 9, 4500, date '2008-01-01'), + ('sales', 3, 4800, date '2007-08-01'), + ('develop', 8, 6000, date '2006-10-01'), + ('develop', 11, 5200, date '2007-08-15') +-- !query schema +struct<> +-- !query output + + + -- !query SELECT sum(unique1) over (order by four range between 2 preceding and 1 preceding), unique1, four @@ -72,17 +90,6 @@ struct --- !query output -org.apache.spark.sql.AnalysisException -Window Frame specifiedwindowframe(RangeFrame, -1000, 1000) must match the required frame specifiedwindowframe(RowFrame, -1, -1); - - -- !query select ss.id, ss.y, first(ss.y) over w, From 659f8c8ef549fd040596978478898266d24e88ff Mon Sep 17 00:00:00 2001 From: "yi.wu" 
Date: Thu, 6 Feb 2020 20:34:29 +0800 Subject: [PATCH 0025/1280] [SPARK-27297][DOC][FOLLOW-UP] Improve documentation for various Scala functions ### What changes were proposed in this pull request? Add examples and parameter description for these Scala functions: * transform * exists * forall * aggregate * zip_with * transform_keys * transform_values * map_filter * map_zip_with ### Why are the changes needed? Better documentation for UX. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass Jenkins. Closes #27449 from Ngone51/doc-funcs. Authored-by: yi.wu Signed-off-by: Wenchen Fan (cherry picked from commit 368ee62a5dce83682ccaec92feeea8428af5a8cf) Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/functions.scala | 93 +++++++++++++++++-- 1 file changed, 83 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index da26c5a2f4625..d125581857e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3410,6 +3410,12 @@ object functions { /** * Returns an array of elements after applying a transformation to each element * in the input array. + * {{{ + * df.select(transform(col("i"), x => x + 1)) + * }}} + * + * @param column the input array column + * @param f col => transformed_col, the lambda function to transform the input column * * @group collection_funcs * @since 3.0.0 @@ -3421,6 +3427,13 @@ object functions { /** * Returns an array of elements after applying a transformation to each element * in the input array. + * {{{ + * df.select(transform(col("i"), (x, i) => x + i)) + * }}} + * + * @param column the input array column + * @param f (col, index) => transformed_col, the lambda function to filter the input column + * given the index. Indices start at 0. * * @group collection_funcs * @since 3.0.0 @@ -3431,6 +3444,12 @@ object functions { /** * Returns whether a predicate holds for one or more elements in the array. + * {{{ + * df.select(exists(col("i"), _ % 2 === 0)) + * }}} + * + * @param column the input array column + * @param f col => predicate, the Boolean predicate to check the input column * * @group collection_funcs * @since 3.0.0 @@ -3441,6 +3460,12 @@ object functions { /** * Returns whether a predicate holds for every element in the array. + * {{{ + * df.select(forall(col("i"), x => x % 2 === 0)) + * }}} + * + * @param column the input array column + * @param f col => predicate, the Boolean predicate to check the input column * * @group collection_funcs * @since 3.0.0 @@ -3453,11 +3478,10 @@ object functions { * Returns an array of elements for which a predicate holds in a given array. * {{{ * df.select(filter(col("s"), x => x % 2 === 0)) - * df.selectExpr("filter(col, x -> x % 2 == 0)") * }}} * - * @param column: the input array column - * @param f: col => predicate, the Boolean predicate to filter the input column + * @param column the input array column + * @param f col => predicate, the Boolean predicate to filter the input column * * @group collection_funcs * @since 3.0.0 @@ -3470,11 +3494,10 @@ object functions { * Returns an array of elements for which a predicate holds in a given array. 
* {{{ * df.select(filter(col("s"), (x, i) => i % 2 === 0)) - * df.selectExpr("filter(col, (x, i) -> i % 2 == 0)") * }}} * - * @param column: the input array column - * @param f: (col, index) => predicate, the Boolean predicate to filter the input column + * @param column the input array column + * @param f (col, index) => predicate, the Boolean predicate to filter the input column * given the index. Indices start at 0. * * @group collection_funcs @@ -3488,18 +3511,28 @@ object functions { * Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. The final state is converted into the final result * by applying a finish function. + * {{{ + * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)) + * }}} + * + * @param expr the input array column + * @param initialValue the initial value + * @param merge (combined_value, input_value) => combined_value, the merge function to merge + * an input value to the combined_value + * @param finish combined_value => final_value, the lambda function to convert the combined value + * of all inputs to final result * * @group collection_funcs * @since 3.0.0 */ def aggregate( expr: Column, - zero: Column, + initialValue: Column, merge: (Column, Column) => Column, finish: Column => Column): Column = withExpr { ArrayAggregate( expr.expr, - zero.expr, + initialValue.expr, createLambda(merge), createLambda(finish) ) @@ -3508,17 +3541,31 @@ object functions { /** * Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. + * {{{ + * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)) + * }}} * + * @param expr the input array column + * @param initialValue the initial value + * @param merge (combined_value, input_value) => combined_value, the merge function to merge + * an input value to the combined_value * @group collection_funcs * @since 3.0.0 */ - def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column = - aggregate(expr, zero, merge, c => c) + def aggregate(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = + aggregate(expr, initialValue, merge, c => c) /** * Merge two given arrays, element-wise, into a single array using a function. * If one array is shorter, nulls are appended at the end to match the length of the longer * array, before applying the function. + * {{{ + * df.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)) + * }}} + * + * @param left the left input array column + * @param right the right input array column + * @param f (lCol, rCol) => col, the lambda function to merge two input columns into one column * * @group collection_funcs * @since 3.0.0 @@ -3530,6 +3577,12 @@ object functions { /** * Applies a function to every key-value pair in a map and returns * a map with the results of those applications as the new keys for the pairs. + * {{{ + * df.select(transform_keys(col("i"), (k, v) => k + v)) + * }}} + * + * @param expr the input map column + * @param f (key, value) => new_key, the lambda function to transform the key of input map column * * @group collection_funcs * @since 3.0.0 @@ -3541,6 +3594,13 @@ object functions { /** * Applies a function to every key-value pair in a map and returns * a map with the results of those applications as the new values for the pairs. 
+ * {{{ + * df.select(transform_values(col("i"), (k, v) => k + v)) + * }}} + * + * @param expr the input map column + * @param f (key, value) => new_value, the lambda function to transform the value of input map + * column * * @group collection_funcs * @since 3.0.0 @@ -3551,6 +3611,12 @@ object functions { /** * Returns a map whose key-value pairs satisfy a predicate. + * {{{ + * df.select(map_filter(col("m"), (k, v) => k * 10 === v)) + * }}} + * + * @param expr the input map column + * @param f (key, value) => predicate, the Boolean predicate to filter the input map column * * @group collection_funcs * @since 3.0.0 @@ -3561,6 +3627,13 @@ object functions { /** * Merge two given maps, key-wise into a single map using a function. + * {{{ + * df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)) + * }}} + * + * @param left the left input map column + * @param right the right input map column + * @param f (key, value1, value2) => new_value, the lambda function to merge the map values * * @group collection_funcs * @since 3.0.0 From 4546f128c17b73f8eaaef7524148d588c304a9d4 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 6 Feb 2020 20:53:44 +0800 Subject: [PATCH 0026/1280] [SPARK-26700][CORE][FOLLOWUP] Add config `spark.network.maxRemoteBlockSizeFetchToMem` ### What changes were proposed in this pull request? Add new config `spark.network.maxRemoteBlockSizeFetchToMem` fallback to the old config `spark.maxRemoteBlockSizeFetchToMem`. ### Why are the changes needed? For naming consistency. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests. Closes #27463 from xuanyuanking/SPARK-26700-follow. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan (cherry picked from commit d8613571bc1847775dd5c1945757279234cb388c) Signed-off-by: Wenchen Fan --- core/src/main/scala/org/apache/spark/SparkConf.scala | 3 ++- .../main/scala/org/apache/spark/internal/config/package.scala | 2 +- docs/configuration.md | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 0e0291d2407d1..40915e3904f7e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -684,7 +684,8 @@ private[spark] object SparkConf extends Logging { "spark.yarn.jars" -> Seq( AlternateConfig("spark.yarn.jar", "2.0")), MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key -> Seq( - AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")), + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3"), + AlternateConfig("spark.maxRemoteBlockSizeFetchToMem", "3.0")), LISTENER_BUS_EVENT_QUEUE_CAPACITY.key -> Seq( AlternateConfig("spark.scheduler.listenerbus.eventqueue.size", "2.3")), DRIVER_MEMORY_OVERHEAD.key -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f91f31be2f1ad..02acb6b530737 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -895,7 +895,7 @@ package object config { .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = - ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") + ConfigBuilder("spark.network.maxRemoteBlockSizeFetchToMem") .doc("Remote block will be fetched to disk when size of the block is above this threshold " + "in bytes. 
This is to avoid a giant request takes too much memory. Note this " + "configuration will affect both shuffle fetch and block manager remote block fetch. " + diff --git a/docs/configuration.md b/docs/configuration.md index 2febfe9744d5c..5bd3f3e80cf71 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1810,7 +1810,7 @@ Apart from these, the following properties are also available, and may be useful - spark.maxRemoteBlockSizeFetchToMem + spark.network.maxRemoteBlockSizeFetchToMem 200m Remote block will be fetched to disk when size of the block is above this threshold From b29cb1a82b1a1facf1dd040025db93d998dad4cd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Feb 2020 09:16:14 -0800 Subject: [PATCH 0027/1280] [SPARK-30719][SQL] do not log warning if AQE is intentionally skipped and add a config to force apply ### What changes were proposed in this pull request? Update `InsertAdaptiveSparkPlan` to not log warning if AQE is skipped intentionally. This PR also add a config to not skip AQE. ### Why are the changes needed? It's not a warning at all if we intentionally skip AQE. ### Does this PR introduce any user-facing change? no ### How was this patch tested? run `AdaptiveQueryExecSuite` locally and verify that there is no warning logs. Closes #27452 from cloud-fan/aqe. Authored-by: Wenchen Fan Signed-off-by: Xiao Li (cherry picked from commit 8ce58627ebe4f0372fba9a30d8cd4213611acd9b) Signed-off-by: Xiao Li --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../adaptive/InsertAdaptiveSparkPlan.scala | 83 +++++++++++-------- .../adaptive/AdaptiveQueryExecSuite.scala | 9 ++ 3 files changed, 65 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index acc0922e2cee7..bed8410acaed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -358,6 +358,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION_FORCE_APPLY = buildConf("spark.sql.adaptive.forceApply") + .internal() + .doc("Adaptive query execution is skipped when the query does not have exchanges or " + + "sub-queries. 
By setting this config to true (together with " + + s"'${ADAPTIVE_EXECUTION_ENABLED.key}' enabled), Spark will force apply adaptive query " + + "execution for all supported queries.") + .booleanConf + .createWithDefault(false) + val REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED = buildConf("spark.sql.adaptive.shuffle.reducePostShufflePartitions.enabled") .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is enabled, this enables reducing " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 9252827856af4..621c063e5a7d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -40,49 +40,60 @@ case class InsertAdaptiveSparkPlan( private val conf = adaptiveExecutionContext.session.sessionState.conf - def containShuffle(plan: SparkPlan): Boolean = { - plan.find { - case _: Exchange => true - case s: SparkPlan => !s.requiredChildDistribution.forall(_ == UnspecifiedDistribution) - }.isDefined - } - - def containSubQuery(plan: SparkPlan): Boolean = { - plan.find(_.expressions.exists(_.find { - case _: SubqueryExpression => true - case _ => false - }.isDefined)).isDefined - } - override def apply(plan: SparkPlan): SparkPlan = applyInternal(plan, false) private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { + case _ if !conf.adaptiveExecutionEnabled => plan case _: ExecutedCommandExec => plan - case _ if conf.adaptiveExecutionEnabled && supportAdaptive(plan) - && (isSubquery || containShuffle(plan) || containSubQuery(plan)) => - try { - // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. Fall - // back to non-adaptive mode if adaptive execution is supported in any of the sub-queries. - val subqueryMap = buildSubqueryMap(plan) - val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap) - val preprocessingRules = Seq( - planSubqueriesRule) - // Run pre-processing rules. - val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules) - logDebug(s"Adaptive execution enabled for plan: $plan") - AdaptiveSparkPlanExec(newPlan, adaptiveExecutionContext, preprocessingRules, isSubquery) - } catch { - case SubqueryAdaptiveNotSupportedException(subquery) => - logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " + - s"but is not supported for sub-query: $subquery.") - plan - } - case _ => - if (conf.adaptiveExecutionEnabled) { + case _ if shouldApplyAQE(plan, isSubquery) => + if (supportAdaptive(plan)) { + try { + // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. + // Fall back to non-AQE mode if AQE is not supported in any of the sub-queries. + val subqueryMap = buildSubqueryMap(plan) + val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap) + val preprocessingRules = Seq( + planSubqueriesRule) + // Run pre-processing rules. 
+ val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules) + logDebug(s"Adaptive execution enabled for plan: $plan") + AdaptiveSparkPlanExec(newPlan, adaptiveExecutionContext, preprocessingRules, isSubquery) + } catch { + case SubqueryAdaptiveNotSupportedException(subquery) => + logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " + + s"but is not supported for sub-query: $subquery.") + plan + } + } else { logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " + s"but is not supported for query: $plan.") + plan } - plan + + case _ => plan + } + + // AQE is only useful when the query has exchanges or sub-queries. This method returns true if + // one of the following conditions is satisfied: + // - The config ADAPTIVE_EXECUTION_FORCE_APPLY is true. + // - The input query is from a sub-query. When this happens, it means we've already decided to + // apply AQE for the main query and we must continue to do it. + // - The query contains exchanges. + // - The query may need to add exchanges. It's an overkill to run `EnsureRequirements` here, so + // we just check `SparkPlan.requiredChildDistribution` and see if it's possible that the + // the query needs to add exchanges later. + // - The query contains sub-query. + private def shouldApplyAQE(plan: SparkPlan, isSubquery: Boolean): Boolean = { + conf.getConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) || isSubquery || { + plan.find { + case _: Exchange => true + case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true + case p => p.expressions.exists(_.find { + case _: SubqueryExpression => true + case _ => false + }.isDefined) + }.isDefined + } } private def supportAdaptive(plan: SparkPlan): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 78a1183664749..96e977221e512 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -780,4 +780,13 @@ class AdaptiveQueryExecSuite ) } } + + test("force apply AQE") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } } From 1baee64750b4098ec37be6408906c70674579eb8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Feb 2020 13:33:39 -0800 Subject: [PATCH 0028/1280] [SPARK-27986][SQL][FOLLOWUP] window aggregate function with filter predicate is not supported ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/26656. We don't support window aggregate function with filter predicate yet and we should fail explicitly. Observable metrics has the same issue. This PR fixes it as well. ### Why are the changes needed? If we simply ignore filter predicate when we don't support it, the result is wrong. ### Does this PR introduce any user-facing change? yea, fix the query result. ### How was this patch tested? new tests Closes #27476 from cloud-fan/filter. 
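For reference, a minimal sketch of the behavior enforced here — the query and the expected error text are taken from the `window.sql` test added below, while the `testData` table with `val`/`cate` columns and the app name are assumed context — a window aggregate carrying a FILTER clause now fails analysis explicitly instead of silently dropping the predicate:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

// Minimal sketch: a FILTER clause on a windowed aggregate is rejected at analysis time.
val spark = SparkSession.builder().appName("window-filter-sketch").getOrCreate()
try {
  spark.sql(
    """SELECT val, cate,
      |       count(val) FILTER (WHERE val > 1) OVER (PARTITION BY cate)
      |FROM testData ORDER BY cate, val""".stripMargin).show()
} catch {
  case e: AnalysisException =>
    // Expected message: "window aggregate function with filter predicate is not supported yet."
    println(e.getMessage)
}
```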
Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun (cherry picked from commit 5a4c70b4e2367441ce4260f02d39d3345078f411) Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/analysis/Analyzer.scala | 4 ++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 3 +++ .../analysis/AnalysisErrorSuite.scala | 20 +++++++++++++++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 10 +++++++++- .../resources/sql-tests/inputs/window.sql | 5 +++++ .../sql-tests/results/window.sql.out | 13 +++++++++++- 6 files changed, 51 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 56cc2a274bb7a..75f1aa7185ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2428,6 +2428,10 @@ class Analyzer( } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + case WindowExpression(ae: AggregateExpression, _) if ae.filter.isDefined => + failAnalysis( + "window aggregate function with filter predicate is not supported yet.") + // Extract Windowed AggregateExpression case we @ WindowExpression( ae @ AggregateExpression(function, _, _, _, _), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4ec737fd9b70d..e769e038c960f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -308,6 +308,9 @@ trait CheckAnalysis extends PredicateHelper { case a: AggregateExpression if a.isDistinct => e.failAnalysis( "distinct aggregates are not allowed in observed metrics, but found: " + s.sql) + case a: AggregateExpression if a.filter.isDefined => + e.failAnalysis("aggregates with filter predicate are not allowed in " + + "observed metrics, but found: " + s.sql) case _: Attribute if !seenAggregate => e.failAnalysis (s"attribute ${s.sql} can only be used as an argument to an " + "aggregate function.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7023dbe2a3672..5cc0453135c07 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -164,6 +164,22 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as("window")), "Distinct window functions are not supported" :: Nil) + errorTest( + "window aggregate function with filter predicate", + testRelation2.select( + WindowExpression( + AggregateExpression( + Count(UnresolvedAttribute("b")), + Complete, + isDistinct = false, + filter = Some(UnresolvedAttribute("b") > 1)), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as("window")), + "window aggregate function with filter predicate is not supported" :: Nil + ) + errorTest( "distinct function", CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), @@ -191,12 +207,12 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER predicate specified, but 
aggregate is not an aggregate function" :: Nil) errorTest( - "DISTINCT and FILTER cannot be used in aggregate functions at the same time", + "DISTINCT aggregate function with filter predicate", CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) errorTest( - "FILTER expression is non-deterministic, it cannot be used in aggregate functions", + "non-deterministic filter predicate in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), "FILTER expression is non-deterministic, it cannot be used in aggregate functions" :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5405009c9e208..c747d394b1bc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ @@ -736,5 +736,13 @@ class AnalysisSuite extends AnalysisTest with Matchers { b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil, CollectMetrics("evt1", count :: Nil, tblB)) assertAnalysisError(query, "Multiple definitions of observed metrics" :: "evt1" :: Nil) + + // Aggregate with filter predicate - fail + val sumWithFilter = sum.transform { + case a: AggregateExpression => a.copy(filter = Some(true)) + }.asInstanceOf[NamedExpression] + assertAnalysisError( + CollectMetrics("evt1", sumWithFilter :: Nil, testRelation), + "aggregates with filter predicate are not allowed" :: Nil) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index e25a252418301..3d05dfda6c3fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -120,3 +120,8 @@ SELECT cate, sum(val) OVER (w) FROM testData WHERE val is not null WINDOW w AS (PARTITION BY cate ORDER BY val); + +-- with filter predicate +SELECT val, cate, +count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate) +FROM testData ORDER BY cate, val; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index f795374735f59..625088f90ced9 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 23 +-- Number of queries: 24 -- !query @@ -380,3 +380,14 @@ a 4 b 1 b 3 b 6 + + +-- !query +SELECT val, cate, +count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate) +FROM testData ORDER BY cate, val +-- !query schema +struct<> +-- 
!query output +org.apache.spark.sql.AnalysisException +window aggregate function with filter predicate is not supported yet.; From 6130deb911e6dc65530c5aa92cd27cb0afad92b2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 6 Feb 2020 13:42:09 -0800 Subject: [PATCH 0029/1280] [MINOR][INFRA][3.0] Enable GitHub Action in branch-3.0 ### What changes were proposed in this pull request? This aims to enable `GitHub Action` in `branch-3.0`. ### Why are the changes needed? Currently, it's not enabled. commitlog This will protect `branch-3.0` by monitoring every commits and PR against `branch-3.0`. ### Does this PR introduce any user-facing change? No. This is a dev-only infra. ### How was this patch tested? See the GitHub Action triggering in this PR. Closes #27480 from dongjoon-hyun/GHA-branch-3.0. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .github/workflows/master.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index d53119ad75599..0f80c88eba2f2 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -3,10 +3,10 @@ name: master on: push: branches: - - master + - branch-3.0 pull_request: branches: - - master + - branch-3.0 jobs: build: From 475a8f2e2639fde28787d58b50c1d38f9a9c05ed Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 6 Feb 2020 14:58:53 -0800 Subject: [PATCH 0030/1280] [MINOR][DOC] Fix document UI left menu broken ### What changes were proposed in this pull request? Fix the left menu broken introduced in #25459. ### Why are the changes needed? The `left-menu-wrapper` CSS reused for both ml-guide and sql-programming-guide, the before changes will break the UI. Before: ![image](https://user-images.githubusercontent.com/4833765/73952563-1061d800-493a-11ea-8a75-d802a1534a44.png) ![image](https://user-images.githubusercontent.com/4833765/73952584-18217c80-493a-11ea-85a3-ce5f9875545f.png) ![image](https://user-images.githubusercontent.com/4833765/73952605-21124e00-493a-11ea-8d79-24f4dfec73d9.png) After: ![image](https://user-images.githubusercontent.com/4833765/73952630-2a031f80-493a-11ea-80ff-4630801cfaf4.png) ![image](https://user-images.githubusercontent.com/4833765/73952652-30919700-493a-11ea-9db1-8bb4a3f913b4.png) ![image](https://user-images.githubusercontent.com/4833765/73952671-35eee180-493a-11ea-801b-d50c4397adf2.png) ### Does this PR introduce any user-facing change? Document UI change only. ### How was this patch tested? Local test, screenshot attached below. Closes #27479 from xuanyuanking/doc-ui. Authored-by: Yuanjian Li Signed-off-by: Dongjoon Hyun (cherry picked from commit 4804445327f06ae3a26365d8f110f06ea07eb637) Signed-off-by: Dongjoon Hyun --- docs/css/main.css | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/css/main.css b/docs/css/main.css index e24dff8531f24..dc05d287996be 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -211,8 +211,6 @@ a.anchorjs-link:hover { text-decoration: none; } float: left; position: fixed; overflow-y: scroll; - top: 0; - bottom: 0; } .left-menu { From 683a07d5bad9519df88d1528e1afc143e367a688 Mon Sep 17 00:00:00 2001 From: sharif ahmad Date: Fri, 7 Feb 2020 18:42:16 +0900 Subject: [PATCH 0031/1280] [MINOR][DOCS] Fix typos at python/pyspark/sql/types.py ### What changes were proposed in this pull request? This PR fixes some typos in `python/pyspark/sql/types.py` file. ### Why are the changes needed? To deliver correct wording in documentation and codes. 
### Does this PR introduce any user-facing change? Yes, it fixes some typos in user-facing API documentation. ### How was this patch tested? Locally tested the linter. Closes #27475 from sharifahmad2061/master. Lead-authored-by: sharif ahmad Co-authored-by: Sharif ahmad Signed-off-by: HyukjinKwon (cherry picked from commit dd2f4431f56e02cd06848b02b93b4cf34c97a5d5) Signed-off-by: HyukjinKwon --- python/pyspark/sql/types.py | 40 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8afff77b723a8..a5302e7bfd5ab 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -76,7 +76,7 @@ def json(self): def needConversion(self): """ - Does this type need to conversion between Python object and internal SQL object. + Does this type needs conversion between Python object and internal SQL object. This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. """ @@ -210,17 +210,17 @@ class DecimalType(FractionalType): The precision can be up to 38, the scale must be less or equal to precision. - When create a DecimalType, the default precision and scale is (10, 0). When infer + When creating a DecimalType, the default precision and scale is (10, 0). When inferring schema from decimal.Decimal objects, it will be DecimalType(38, 18). - :param precision: the maximum total number of digits (default: 10) + :param precision: the maximum (i.e. total) number of digits (default: 10) :param scale: the number of digits on right side of dot. (default: 0) """ def __init__(self, precision=10, scale=0): self.precision = precision self.scale = scale - self.hasPrecisionInfo = True # this is public API + self.hasPrecisionInfo = True # this is a public API def simpleString(self): return "decimal(%d,%d)" % (self.precision, self.scale) @@ -457,8 +457,8 @@ class StructType(DataType): This is the data type representing a :class:`Row`. - Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. - A contained :class:`StructField` can be accessed by name or position. + Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s. + A contained :class:`StructField` can be accessed by its name or position. >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] @@ -492,8 +492,8 @@ def __init__(self, fields=None): def add(self, field, data_type=None, nullable=True, metadata=None): """ - Construct a StructType by adding new elements to it to define the schema. The method accepts - either: + Construct a StructType by adding new elements to it, to define the schema. + The method accepts either: a) A single parameter which is a StructField object. b) Between 2 and 4 parameters as (name, data_type, nullable (optional), @@ -676,7 +676,7 @@ def needConversion(self): @classmethod def _cachedSqlType(cls): """ - Cache the sqlType() into class, because it's heavy used in `toInternal`. + Cache the sqlType() into class, because it's heavily used in `toInternal`. """ if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() @@ -693,7 +693,7 @@ def fromInternal(self, obj): def serialize(self, obj): """ - Converts the a user-type object into a SQL datum. + Converts a user-type object into a SQL datum. """ raise NotImplementedError("UDT must implement toInternal().") @@ -760,7 +760,7 @@ def __eq__(self, other): def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. 
The data type string format equals - to :class:`DataType.simpleString`, except that top level struct type can omit + :class:`DataType.simpleString`, except that the top level struct type can omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted @@ -921,7 +921,7 @@ def _parse_datatype_json_value(json_value): # We should be careful here. The size of these types in python depends on C # implementation. We need to make sure that this conversion does not lose any # precision. Also, JVM only support signed types, when converting unsigned types, -# keep in mind that it required 1 more bit when stored as singed types. +# keep in mind that it require 1 more bit when stored as signed types. # # Reference for C integer size, see: # ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types . @@ -959,7 +959,7 @@ def _int_size_to_type(size): if size <= 64: return LongType -# The list of all supported array typecodes is stored here +# The list of all supported array typecodes, is stored here _array_type_mappings = { # Warning: Actual properties for float and double in C is not specified in C. # On almost every system supported by both python and JVM, they are IEEE 754 @@ -995,9 +995,9 @@ def _int_size_to_type(size): _array_type_mappings['c'] = StringType # SPARK-21465: -# In python2, array of 'L' happened to be mistakenly partially supported. To +# In python2, array of 'L' happened to be mistakenly, just partially supported. To # avoid breaking user's code, we should keep this partial support. Below is a -# dirty hacking to keep this partial support and make the unit test passes +# dirty hacking to keep this partial support and pass the unit test. import platform if sys.version_info[0] < 3 and platform.python_implementation() != 'PyPy': if 'L' not in _array_type_mappings.keys(): @@ -1071,7 +1071,7 @@ def _infer_schema(row, names=None): def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ + """ Return whether there is a NullType in `dt` or not """ if isinstance(dt, StructType): return any(_has_nulltype(f.dataType) for f in dt.fields) elif isinstance(dt, ArrayType): @@ -1211,7 +1211,7 @@ def _make_type_verifier(dataType, nullable=True, name=None): This verifier also checks the value of obj against datatype and raises a ValueError if it's not within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is - not checked, so it will become infinity when cast to Java float if it overflows. + not checked, so it will become infinity when cast to Java float, if it overflows. >>> _make_type_verifier(StructType([]))(None) >>> _make_type_verifier(StringType())("") @@ -1433,7 +1433,7 @@ class Row(tuple): ``key in row`` will search through row keys. Row can be used to create a row object by using named arguments. - It is not allowed to omit a named argument to represent the value is + It is not allowed to omit a named argument to represent that the value is None or missing. This should be explicitly set to None in this case. NOTE: As of Spark 3.0.0, Rows created from named arguments no longer have @@ -1524,9 +1524,9 @@ def __new__(cls, *args, **kwargs): def asDict(self, recursive=False): """ - Return as an dict + Return as a dict - :param recursive: turns the nested Row as dict (default: False). 
+ :param recursive: turns the nested Rows to dict (default: False). >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} True From a9d9d834c600a4b9888f8781cec04fcbef29972d Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 8 Feb 2020 02:32:07 +0800 Subject: [PATCH 0032/1280] [SPARK-30752][SQL] Fix `to_utc_timestamp` on daylight saving day ### What changes were proposed in this pull request? - Rewrite the `convertTz` method of `DateTimeUtils` using Java 8 time API - Change types of `convertTz` parameters from `TimeZone` to `ZoneId`. This allows to avoid unnecessary conversions `TimeZone` -> `ZoneId` and performance regressions as a consequence. ### Why are the changes needed? - Fixes incorrect behavior of `to_utc_timestamp` on daylight saving day. For example: ```scala scala> df.select(to_utc_timestamp(lit("2019-11-03T12:00:00"), "Asia/Hong_Kong").as("local UTC")).show +-------------------+ | local UTC| +-------------------+ |2019-11-03 03:00:00| +-------------------+ ``` but the result must be 2019-11-03 04:00:00: Screen Shot 2020-02-06 at 20 09 36 - Simplifies the code, and make it more maintainable - Switches `convertTz` on Proleptic Gregorian calendar used by Java 8 time classes by default. That makes the function consistent to other date-time functions. ### Does this PR introduce any user-facing change? Yes, after the changes `to_utc_timestamp` returns the correct result `2019-11-03 04:00:00`. ### How was this patch tested? - By existing test suite `DateTimeUtilsSuite`, `DateFunctionsSuite` and `DateExpressionsSuite`. - Added `convert time zones on a daylight saving day` to DateFunctionsSuite Closes #27474 from MaxGekk/port-convertTz-on-Java8-api. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit a3e77773cfa03a18d31370acd9a10562ff5312bb) Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 16 ++++------ .../sql/catalyst/util/DateTimeUtils.scala | 28 ++++-------------- .../parquet/VectorizedColumnReader.java | 9 +++--- .../VectorizedParquetRecordReader.java | 6 ++-- .../parquet/ParquetFileFormat.scala | 2 +- .../parquet/ParquetReadSupport.scala | 5 ++-- .../parquet/ParquetRecordMaterializer.scala | 4 +-- .../parquet/ParquetRowConverter.scala | 9 +++--- .../ParquetPartitionReaderFactory.scala | 10 +++---- .../apache/spark/sql/DateFunctionsSuite.scala | 29 +++++++++++++++++-- .../ParquetInteroperabilitySuite.scala | 5 ++-- 11 files changed, 64 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 28f1d34267224..aa2bd5a1273e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1176,14 +1176,12 @@ case class FromUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzClass = classOf[TimeZone].getName + val tzClass = classOf[ZoneId].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$escapedTz");""") - val utcTerm = "tzUTC" - ctx.addImmutableStateIfNotExists(tzClass, utcTerm, - v => s"""$v = $dtu.getTimeZone("UTC");""") + v => s"""$v = $dtu.getZoneId("$escapedTz");""") + val 
utcTerm = "java.time.ZoneOffset.UTC" val eval = left.genCode(ctx) ev.copy(code = code""" |${eval.code} @@ -1382,14 +1380,12 @@ case class ToUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzClass = classOf[TimeZone].getName + val tzClass = classOf[ZoneId].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$escapedTz");""") - val utcTerm = "tzUTC" - ctx.addImmutableStateIfNotExists(tzClass, utcTerm, - v => s"""$v = $dtu.getTimeZone("UTC");""") + v => s"""$v = $dtu.getZoneId("$escapedTz");""") + val utcTerm = "java.time.ZoneOffset.UTC" val eval = left.genCode(ctx) ev.copy(code = code""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6800abb2ae109..8eb560944d4cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -801,27 +801,9 @@ object DateTimeUtils { * mapping, the conversion here may return wrong result, we should make the timestamp * timezone-aware. */ - def convertTz(ts: SQLTimestamp, fromZone: TimeZone, toZone: TimeZone): SQLTimestamp = { - // We always use local timezone to parse or format a timestamp - val localZone = defaultTimeZone() - val utcTs = if (fromZone.getID == localZone.getID) { - ts - } else { - // get the human time using local time zone, that actually is in fromZone. - val localZoneOffsetMs = localZone.getOffset(MICROSECONDS.toMillis(ts)) - val localTsUs = ts + MILLISECONDS.toMicros(localZoneOffsetMs) // in fromZone - val offsetFromLocalMs = getOffsetFromLocalMillis(MICROSECONDS.toMillis(localTsUs), fromZone) - localTsUs - MILLISECONDS.toMicros(offsetFromLocalMs) - } - if (toZone.getID == localZone.getID) { - utcTs - } else { - val toZoneOffsetMs = toZone.getOffset(MICROSECONDS.toMillis(utcTs)) - val localTsUs = utcTs + MILLISECONDS.toMicros(toZoneOffsetMs) // in toZone - // treat it as local timezone, convert to UTC (we could get the expected human time back) - val offsetFromLocalMs = getOffsetFromLocalMillis(MICROSECONDS.toMillis(localTsUs), localZone) - localTsUs - MILLISECONDS.toMicros(offsetFromLocalMs) - } + def convertTz(ts: SQLTimestamp, fromZone: ZoneId, toZone: ZoneId): SQLTimestamp = { + val rebasedDateTime = microsToInstant(ts).atZone(toZone).toLocalDateTime.atZone(fromZone) + instantToMicros(rebasedDateTime.toInstant) } /** @@ -829,7 +811,7 @@ object DateTimeUtils { * representation in their timezone. */ def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - convertTz(time, TimeZoneGMT, getTimeZone(timeZone)) + convertTz(time, ZoneOffset.UTC, getZoneId(timeZone)) } /** @@ -837,7 +819,7 @@ object DateTimeUtils { * string representation in their timezone. 
*/ def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - convertTz(time, getTimeZone(timeZone), TimeZoneGMT) + convertTz(time, getZoneId(timeZone), ZoneOffset.UTC) } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index ba26b57567e64..329465544979d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.time.ZoneId; +import java.time.ZoneOffset; import java.util.Arrays; -import java.util.TimeZone; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesInput; @@ -98,14 +99,14 @@ public class VectorizedColumnReader { private final ColumnDescriptor descriptor; private final OriginalType originalType; // The timezone conversion to apply to int96 timestamps. Null if no conversion. - private final TimeZone convertTz; - private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC(); + private final ZoneId convertTz; + private static final ZoneId UTC = ZoneOffset.UTC; public VectorizedColumnReader( ColumnDescriptor descriptor, OriginalType originalType, PageReader pageReader, - TimeZone convertTz) throws IOException { + ZoneId convertTz) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; this.convertTz = convertTz; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index f02861355c404..7306709a79c34 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.time.ZoneId; import java.util.Arrays; import java.util.List; -import java.util.TimeZone; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; @@ -86,7 +86,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to * workaround incompatibilities between different engines when writing timestamp values. */ - private TimeZone convertTz = null; + private ZoneId convertTz = null; /** * columnBatch object that is used for batch decoding. This is created on first use and triggers @@ -116,7 +116,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { + public VectorizedParquetRecordReader(ZoneId convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? 
MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.capacity = capacity; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f52aaf0140e1d..29dbd8dfbca8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -295,7 +295,7 @@ class ParquetFileFormat val convertTz = if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 69c8bad5f1c83..c05ecf16311ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Locale, Map => JMap, TimeZone} +import java.time.ZoneId +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ @@ -49,7 +50,7 @@ import org.apache.spark.sql.types._ * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]] * to [[prepareForRead()]], but use a private `var` for simplicity. */ -class ParquetReadSupport(val convertTz: Option[TimeZone], +class ParquetReadSupport(val convertTz: Option[ZoneId], enableVectorizedReader: Boolean) extends ReadSupport[InternalRow] with Logging { private var catalystRequestedSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 3098a332d3027..5622169df1281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.TimeZone +import java.time.ZoneId import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -36,7 +36,7 @@ private[parquet] class ParquetRecordMaterializer( parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetToSparkSchemaConverter, - convertTz: Option[TimeZone]) + convertTz: Option[ZoneId]) extends RecordMaterializer[InternalRow] { private val rootConverter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 98ac2ecd2955c..850adae8a6b95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -19,7 +19,7 @@ package 
org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder -import java.util.TimeZone +import java.time.{ZoneId, ZoneOffset} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -125,7 +125,7 @@ private[parquet] class ParquetRowConverter( schemaConverter: ParquetToSparkSchemaConverter, parquetType: GroupType, catalystType: StructType, - convertTz: Option[TimeZone], + convertTz: Option[ZoneId], updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -154,8 +154,6 @@ private[parquet] class ParquetRowConverter( |${catalystType.prettyJson} """.stripMargin) - private[this] val UTC = DateTimeUtils.TimeZoneUTC - /** * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates * converted filed values to the `ordinal`-th cell in `currentRow`. @@ -292,7 +290,8 @@ private[parquet] class ParquetRowConverter( val timeOfDayNanos = buf.getLong val julianDay = buf.getInt val rawTime = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) - val adjTime = convertTz.map(DateTimeUtils.convertTz(rawTime, _, UTC)).getOrElse(rawTime) + val adjTime = convertTz.map(DateTimeUtils.convertTz(rawTime, _, ZoneOffset.UTC)) + .getOrElse(rawTime) updater.setLong(adjTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index b2fc724057eba..047bc74a8d81e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import java.net.URI -import java.util.TimeZone +import java.time.ZoneId import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -117,7 +117,7 @@ case class ParquetPartitionReaderFactory( file: PartitionedFile, buildReaderFunc: ( ParquetInputSplit, InternalRow, TaskAttemptContextImpl, Option[FilterPredicate], - Option[TimeZone]) => RecordReader[Void, T]): RecordReader[Void, T] = { + Option[ZoneId]) => RecordReader[Void, T]): RecordReader[Void, T] = { val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) @@ -156,7 +156,7 @@ case class ParquetPartitionReaderFactory( val convertTz = if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getTimeZone(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) } else { None } @@ -184,7 +184,7 @@ case class ParquetPartitionReaderFactory( partitionValues: InternalRow, hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], - convertTz: Option[TimeZone]): RecordReader[Void, InternalRow] = { + convertTz: Option[ZoneId]): RecordReader[Void, InternalRow] = { logDebug(s"Falling back to parquet-mr") val taskContext = Option(TaskContext.get()) // ParquetRecordReader returns InternalRow @@ -213,7 +213,7 @@ case class ParquetPartitionReaderFactory( partitionValues: InternalRow, hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], - convertTz: Option[TimeZone]): VectorizedParquetRecordReader = { + convertTz: Option[ZoneId]): VectorizedParquetRecordReader = { val taskContext = 
Option(TaskContext.get()) val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 3b3d3cc3d7a17..bb8cdf3cb6de1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.time.Instant -import java.util.Locale +import java.time.{Instant, LocalDateTime} +import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} @@ -803,4 +803,29 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { checkTimeZoneParsing(null) } } + + test("SPARK-30752: convert time zones on a daylight saving day") { + val systemTz = "PST" + val sessionTz = "UTC" + val fromTz = "Asia/Hong_Kong" + val fromTs = "2019-11-03T12:00:00" // daylight saving date in PST + val utsTs = "2019-11-03T04:00:00" + val defaultTz = TimeZone.getDefault + try { + TimeZone.setDefault(DateTimeUtils.getTimeZone(systemTz)) + withSQLConf( + SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionTz) { + val expected = LocalDateTime.parse(utsTs) + .atZone(DateTimeUtils.getZoneId(sessionTz)) + .toInstant + val df = Seq(fromTs).toDF("localTs") + checkAnswer( + df.select(to_utc_timestamp(col("localTs"), fromTz)), + Row(expected)) + } + } finally { + TimeZone.setDefault(defaultTz) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 1ded34f24e436..649a46f190580 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.time.ZoneOffset import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} @@ -145,8 +146,8 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS impalaFileData.map { ts => DateTimeUtils.toJavaTimestamp(DateTimeUtils.convertTz( DateTimeUtils.fromJavaTimestamp(ts), - DateTimeUtils.TimeZoneUTC, - DateTimeUtils.getTimeZone(conf.sessionLocalTimeZone))) + ZoneOffset.UTC, + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))) } } val fullExpectations = (ts ++ impalaExpectations).map(_.toString).sorted.toArray From 73eba319f2ec548fd655a780ee3d240b24b6276d Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 8 Feb 2020 14:10:28 -0800 Subject: [PATCH 0033/1280] [SPARK-28228][SQL] Change the default behavior for name conflict in nested WITH clause ### What changes were proposed in this pull request? This is a follow-up for #25029, in this PR we throw an AnalysisException when name conflict is detected in nested WITH clause. In this way, the config `spark.sql.legacy.ctePrecedence.enabled` should be set explicitly for the expected behavior. ### Why are the changes needed? 
The original change might be risky to end-users, since it changes behavior silently. ### Does this PR introduce any user-facing change? Yes, the config `spark.sql.legacy.ctePrecedence.enabled` is changed to be optional. ### How was this patch tested? New UT. Closes #27454 from xuanyuanking/SPARK-28228-follow. Authored-by: Yuanjian Li Signed-off-by: Dongjoon Hyun (cherry picked from commit 3db3e39f1122350f55f305bee049363621c5894d) Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide.md | 2 +- .../catalyst/analysis/CTESubstitution.scala | 49 ++- .../apache/spark/sql/internal/SQLConf.scala | 6 +- .../sql-tests/inputs/cte-nonlegacy.sql | 2 + .../sql-tests/results/cte-nonlegacy.sql.out | 343 ++++++++++++++++++ .../resources/sql-tests/results/cte.sql.out | 30 +- 6 files changed, 415 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/cte-nonlegacy.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 5a5e802f6a900..be0fe32ded99b 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -101,7 +101,7 @@ license: | - Since Spark 3.0, if files or subdirectories disappear during recursive directory listing (i.e. they appear in an intermediate listing but then cannot be read or listed during later phases of the recursive directory listing, due to either concurrent file deletions or object store consistency issues) then the listing will fail with an exception unless `spark.sql.files.ignoreMissingFiles` is `true` (default `false`). In previous versions, these missing files or subdirectories would be ignored. Note that this change of behavior only applies during initial table file listing (or during `REFRESH TABLE`), not during query execution: the net change is that `spark.sql.files.ignoreMissingFiles` is now obeyed during table file listing / query planning, not only at query execution time. - - Since Spark 3.0, substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`. + - Since Spark 3.0, Spark throws an AnalysisException if name conflict is detected in the nested WITH clause by default. It forces the users to choose the specific substitution order they wanted, which is controlled by `spark.sql.legacy.ctePrecedence.enabled`. If set to false (which is recommended), inner CTE definitions take precedence over outer definitions. For example, set the config to `false`, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `2`, while setting it to `true`, the result is `1` which is the behavior in version 2.4 and earlier. - Since Spark 3.0, the `add_months` function does not adjust the resulting date to a last day of month if the original date is a last day of months. For example, `select add_months(DATE'2019-02-28', 1)` results `2019-03-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when the original date is a last day of months. For example, adding a month to `2019-02-28` results in `2019-03-31`. 
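A minimal SQL sketch of the behaviors described in the migration note above (the exception text matches the updated `cte.sql.out` golden file later in this patch):

```sql
-- Default in Spark 3.0: a name conflict in a nested WITH clause is rejected.
WITH t AS (SELECT 1),
     t2 AS (WITH t AS (SELECT 2) SELECT * FROM t)
SELECT * FROM t2;
-- org.apache.spark.sql.AnalysisException: Name t is ambiguous in nested CTE. ...

-- Recommended: let the inner definition take precedence; the query above then returns 2.
SET spark.sql.legacy.ctePrecedence.enabled=false;

-- Legacy: let the outer definition take precedence (Spark 2.4 behavior); the query above returns 1.
SET spark.sql.legacy.ctePrecedence.enabled=true;
```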
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 60e6bf8db06d7..d2be15d87d023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, With} import org.apache.spark.sql.catalyst.rules.Rule @@ -28,10 +29,54 @@ import org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_ENABLED */ object CTESubstitution extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (SQLConf.get.getConf(LEGACY_CTE_PRECEDENCE_ENABLED)) { + val isLegacy = SQLConf.get.getConf(LEGACY_CTE_PRECEDENCE_ENABLED) + if (isLegacy.isEmpty) { + assertNoNameConflictsInCTE(plan, inTraverse = false) + traverseAndSubstituteCTE(plan, inTraverse = false) + } else if (isLegacy.get) { legacyTraverseAndSubstituteCTE(plan) } else { - traverseAndSubstituteCTE(plan, false) + traverseAndSubstituteCTE(plan, inTraverse = false) + } + } + + /** + * Check the plan to be traversed has naming conflicts in nested CTE or not, traverse through + * child, innerChildren and subquery for the current plan. + */ + private def assertNoNameConflictsInCTE( + plan: LogicalPlan, + inTraverse: Boolean, + cteNames: Set[String] = Set.empty): Unit = { + plan.foreach { + case w @ With(child, relations) => + val newNames = relations.map { + case (cteName, _) => + if (cteNames.contains(cteName)) { + throw new AnalysisException(s"Name $cteName is ambiguous in nested CTE. " + + s"Please set ${LEGACY_CTE_PRECEDENCE_ENABLED.key} to false so that name defined " + + "in inner CTE takes precedence. See more details in SPARK-28228.") + } else { + cteName + } + }.toSet + child.transformExpressions { + case e: SubqueryExpression => + assertNoNameConflictsInCTE(e.plan, inTraverse = true, cteNames ++ newNames) + e + } + w.innerChildren.foreach { p => + assertNoNameConflictsInCTE(p, inTraverse = true, cteNames ++ newNames) + } + + case other if inTraverse => + other.transformExpressions { + case e: SubqueryExpression => + assertNoNameConflictsInCTE(e.plan, inTraverse = true, cteNames) + e + } + + case _ => } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bed8410acaed7..a72bd53188f29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2098,9 +2098,11 @@ object SQLConf { val LEGACY_CTE_PRECEDENCE_ENABLED = buildConf("spark.sql.legacy.ctePrecedence.enabled") .internal() - .doc("When true, outer CTE definitions takes precedence over inner definitions.") + .doc("When true, outer CTE definitions takes precedence over inner definitions. If set to " + + "false, inner CTE definitions take precedence. 
The default value is empty, " + + "AnalysisException is thrown while name conflict is detected in nested CTE.") .booleanConf - .createWithDefault(false) + .createOptional val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = buildConf("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic") diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-nonlegacy.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-nonlegacy.sql new file mode 100644 index 0000000000000..b711bf338ab08 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-nonlegacy.sql @@ -0,0 +1,2 @@ +--SET spark.sql.legacy.ctePrecedence.enabled = false +--IMPORT cte.sql diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out new file mode 100644 index 0000000000000..2d87781193c25 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out @@ -0,0 +1,343 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query +create temporary view t as select * from values 0, 1, 2 as t(id) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view t2 as select * from values 0, 1 as t(id) +-- !query schema +struct<> +-- !query output + + + +-- !query +WITH s AS (SELECT 1 FROM s) SELECT * FROM s +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: s; line 1 pos 25 + + +-- !query +WITH r AS (SELECT (SELECT * FROM r)) +SELECT * FROM r +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: r; line 1 pos 33 + + +-- !query +WITH t AS (SELECT 1 FROM t) SELECT * FROM t +-- !query schema +struct<1:int> +-- !query output +1 +1 +1 + + +-- !query +WITH s1 AS (SELECT 1 FROM s2), s2 AS (SELECT 1 FROM s1) SELECT * FROM s1, s2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: s2; line 1 pos 26 + + +-- !query +WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2 +-- !query schema +struct +-- !query output +0 2 +0 2 +1 2 +1 2 + + +-- !query +WITH CTE1 AS ( + SELECT b.id AS id + FROM T2 a + CROSS JOIN (SELECT id AS id FROM T2) b +) +SELECT t1.id AS c1, + t2.id AS c2 +FROM CTE1 t1 + CROSS JOIN CTE1 t2 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 1 +0 1 +0 1 +0 1 +1 0 +1 0 +1 0 +1 0 +1 1 +1 1 +1 1 +1 1 + + +-- !query +WITH t(x) AS (SELECT 1) +SELECT * FROM t WHERE x = 1 +-- !query schema +struct +-- !query output +1 + + +-- !query +WITH t(x, y) AS (SELECT 1, 2) +SELECT * FROM t WHERE x = 1 AND y = 2 +-- !query schema +struct +-- !query output +1 2 + + +-- !query +WITH t(x, x) AS (SELECT 1, 2) +SELECT * FROM t +-- !query schema +struct +-- !query output +1 2 + + +-- !query +WITH t() AS (SELECT 1) +SELECT * FROM t +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'WITH t()'(line 1, pos 7) + +== SQL == +WITH t() AS (SELECT 1) +-------^^^ +SELECT * FROM t + + +-- !query +WITH + t(x) AS (SELECT 1), + t(x) AS (SELECT 2) +SELECT * FROM t +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +CTE definition can't have duplicate names: 't'.(line 1, pos 0) + +== SQL == +WITH +^^^ + t(x) AS (SELECT 1), + t(x) AS (SELECT 2) +SELECT * FROM t + + +-- !query +WITH t as ( + WITH t2 AS (SELECT 1) + SELECT * FROM t2 +) +SELECT 
* FROM t +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 1) + SELECT * FROM t +) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT ( + WITH t AS (SELECT 1) + SELECT * FROM t +) +-- !query schema +struct +-- !query output +1 + + +-- !query +WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +SELECT * FROM t2 +-- !query schema +struct<2:int> +-- !query output +2 + + +-- !query +WITH + t(c) AS (SELECT 1), + t2 AS ( + SELECT ( + SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) + ) + ) +SELECT * FROM t2 +-- !query schema +struct +-- !query output +2 + + +-- !query +WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2), + t2 AS ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) + SELECT * FROM t2 + ) +SELECT * FROM t2 +-- !query schema +struct<3:int> +-- !query output +3 + + +-- !query +WITH t(c) AS (SELECT 1) +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t +) +-- !query schema +struct +-- !query output +2 + + +-- !query +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) +) +-- !query schema +struct +-- !query output +2 + + +-- !query +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 3) + SELECT * FROM t + ) +) +-- !query schema +struct +-- !query output +3 + + +-- !query +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t +) +-- !query schema +struct +-- !query output +2 + + +-- !query +WITH t AS (SELECT 1) +SELECT ( + SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +) +-- !query schema +struct +-- !query output +2 + + +-- !query +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) +) +-- !query schema +struct +-- !query output +3 + + +-- !query +DROP VIEW IF EXISTS t +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS t2 +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out index 2d87781193c25..1d50aa8f57505 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out @@ -204,9 +204,10 @@ WITH ) SELECT * FROM t2 -- !query schema -struct<2:int> +struct<> -- !query output -2 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. See more details in SPARK-28228.; -- !query @@ -222,9 +223,10 @@ WITH ) SELECT * FROM t2 -- !query schema -struct +struct<> -- !query output -2 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. See more details in SPARK-28228.; -- !query @@ -240,9 +242,10 @@ WITH ) SELECT * FROM t2 -- !query schema -struct<3:int> +struct<> -- !query output -3 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. 
See more details in SPARK-28228.; -- !query @@ -293,9 +296,10 @@ SELECT ( SELECT * FROM t ) -- !query schema -struct +struct<> -- !query output -2 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. See more details in SPARK-28228.; -- !query @@ -307,9 +311,10 @@ SELECT ( ) ) -- !query schema -struct +struct<> -- !query output -2 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. See more details in SPARK-28228.; -- !query @@ -322,9 +327,10 @@ SELECT ( ) ) -- !query schema -struct +struct<> -- !query output -3 +org.apache.spark.sql.AnalysisException +Name t is ambiguous in nested CTE. Please set spark.sql.legacy.ctePrecedence.enabled to false so that name defined in inner CTE takes precedence. See more details in SPARK-28228.; -- !query From 287d93fa1c5232fcbb5c4fe0ddbb4aeca39cd6b9 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 8 Feb 2020 14:28:15 -0800 Subject: [PATCH 0034/1280] [SPARK-29587][DOC][FOLLOWUP] Add `SQL` tab in the `Data Types` page ### What changes were proposed in this pull request? Add the new tab `SQL` in the `Data Types` page. ### Why are the changes needed? New type added in SPARK-29587. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Locally test by Jekyll. ![image](https://user-images.githubusercontent.com/4833765/73908593-2e511d80-48e5-11ea-85a7-6ee451e6b727.png) Closes #27447 from xuanyuanking/SPARK-29587-follow. Authored-by: Yuanjian Li Signed-off-by: Dongjoon Hyun (cherry picked from commit e1cd4d9dc25ac3abe33c07686fc2a7d1f2b5c122) Signed-off-by: Dongjoon Hyun --- docs/sql-ref-datatypes.md | 75 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 0add62b10ed6b..9700608fe8a34 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -631,4 +631,79 @@ from pyspark.sql.types import * + +
+ +The following table shows the type names as well as aliases used in Spark SQL parser for each data type. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeSQL name
BooleanType BOOLEAN
ByteType BYTE, TINYINT
ShortType SHORT, SMALLINT
IntegerType INT, INTEGER
LongType LONG, BIGINT
FloatType FLOAT, REAL
DoubleType DOUBLE
DateType DATE
TimestampType TIMESTAMP
StringType STRING
BinaryType BINARY
DecimalType DECIMAL, DEC, NUMERIC
CalendarIntervalType INTERVAL
ArrayType ARRAY<element_type>
StructType STRUCT<field1_name: field1_type, field2_name: field2_type, ...>
MapType MAP<key_type, value_type>
+
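For reference, the SQL names and aliases in the table above can be used anywhere a data type is expected, for example in DDL and `CAST` expressions. A minimal sketch with hypothetical table and column names (only the type keywords come from the table):

```sql
-- Hypothetical example; the type keywords are the SQL names/aliases documented above.
CREATE TABLE type_demo (id BIGINT, price DECIMAL(10, 2), label STRING, created DATE) USING parquet;
SELECT CAST('42' AS INT) AS i, CAST(1 AS SMALLINT) AS s, CAST('2020-02-09 00:00:00' AS TIMESTAMP) AS ts;
```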
From d5865493ae71e6369e9f3350dd7e694afcf57298 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sun, 9 Feb 2020 19:20:47 +0900 Subject: [PATCH 0035/1280] [SPARK-30510][SQL][DOCS] Publicly document Spark SQL configuration options ### What changes were proposed in this pull request? This PR adds a doc builder for Spark SQL's configuration options. Here's what the new Spark SQL config docs look like ([configuration.html.zip](https://github.com/apache/spark/files/4172109/configuration.html.zip)): ![Screen Shot 2020-02-07 at 12 13 23 PM](https://user-images.githubusercontent.com/1039369/74050007-425b5480-49a3-11ea-818c-42700c54d1fb.png) Compare this to the [current docs](http://spark.apache.org/docs/3.0.0-preview2/configuration.html#spark-sql): ![Screen Shot 2020-02-04 at 4 55 10 PM](https://user-images.githubusercontent.com/1039369/73790828-24a5a980-476f-11ea-998c-12cd613883e8.png) ### Why are the changes needed? There is no visibility into the various Spark SQL configs on [the config docs page](http://spark.apache.org/docs/3.0.0-preview2/configuration.html#spark-sql). ### Does this PR introduce any user-facing change? No, apart from new documentation. ### How was this patch tested? I tested this manually by building the docs and reviewing them in my browser. Closes #27459 from nchammas/SPARK-30510-spark-sql-options. Authored-by: Nicholas Chammas Signed-off-by: HyukjinKwon (cherry picked from commit 339c0f9a623521acd4d66292b3fe3e6c4ec3b108) Signed-off-by: HyukjinKwon --- docs/.gitignore | 1 + docs/configuration.md | 46 ++----- sql/README.md | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 35 +++--- .../spark/sql/api/python/PythonSQLUtils.scala | 7 ++ sql/create-docs.sh | 14 ++- ...en-sql-markdown.py => gen-sql-api-docs.py} | 8 +- sql/gen-sql-config-docs.py | 117 ++++++++++++++++++ 8 files changed, 163 insertions(+), 67 deletions(-) create mode 100644 docs/.gitignore rename sql/{gen-sql-markdown.py => gen-sql-api-docs.py} (96%) create mode 100644 sql/gen-sql-config-docs.py diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000000000..2260493b46ab3 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +sql-configs.html diff --git a/docs/configuration.md b/docs/configuration.md index 5bd3f3e80cf71..1343755f9d87f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2399,47 +2399,15 @@ the driver or executor, or, in the absence of that value, the number of cores av Please refer to the [Security](security.html) page for available options on how to secure different Spark subsystems. -### Spark SQL - -Running the SET -v command will show the entire list of the SQL configuration. - -
-
-{% highlight scala %} -// spark is an existing SparkSession -spark.sql("SET -v").show(numRows = 200, truncate = false) -{% endhighlight %} - -
- -
- -{% highlight java %} -// spark is an existing SparkSession -spark.sql("SET -v").show(200, false); -{% endhighlight %} -
- -
- -{% highlight python %} -# spark is an existing SparkSession -spark.sql("SET -v").show(n=200, truncate=False) -{% endhighlight %} - -
- -
- -{% highlight r %} -sparkR.session() -properties <- sql("SET -v") -showDF(properties, numRows = 200, truncate = FALSE) -{% endhighlight %} +{% for static_file in site.static_files %} + {% if static_file.name == 'sql-configs.html' %} +### Spark SQL -
-
+ {% include_relative sql-configs.html %} + {% break %} + {% endif %} +{% endfor %} ### Spark Streaming diff --git a/sql/README.md b/sql/README.md index 67e3225e2c275..ae5ebd1d75370 100644 --- a/sql/README.md +++ b/sql/README.md @@ -9,4 +9,4 @@ Spark SQL is broken up into four subprojects: - Hive Support (sql/hive) - Includes extensions that allow users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. -Running `./sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. +Running `./sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`, and SQL configuration documentation that gets included as part of `configuration.md` in the main `docs` directory. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a72bd53188f29..e38fe7606c4ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -324,11 +324,11 @@ object SQLConf { .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + "Note that currently statistics are only supported for Hive Metastore tables where the " + - "command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been " + + "command `ANALYZE TABLE COMPUTE STATISTICS noscan` has been " + "run, and file-based data source tables where the statistics are computed directly on " + "the files of data.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(10L * 1024 * 1024) + .createWithDefaultString("10MB") val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor") .internal() @@ -402,7 +402,7 @@ object SQLConf { s"an effect when '${ADAPTIVE_EXECUTION_ENABLED.key}' and " + s"'${REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED.key}' is enabled.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(64 * 1024 * 1024) + .createWithDefaultString("64MB") val SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS = buildConf("spark.sql.adaptive.shuffle.maxNumPostShufflePartitions") @@ -436,7 +436,7 @@ object SQLConf { .doc("Configures the minimum size in bytes for a partition that is considered as a skewed " + "partition in adaptive skewed join.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(64 * 1024 * 1024) + .createWithDefaultString("64MB") val ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR = buildConf("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionFactor") @@ -770,7 +770,7 @@ object SQLConf { val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) - .createWithDefault(5 * 60) + .createWithDefaultString(s"${5 * 60}") // This is only used for the thriftserver val THRIFTSERVER_POOL = buildConf("spark.sql.thriftserver.scheduler.pool") @@ -830,7 +830,7 @@ object SQLConf { .createWithDefault(true) val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") - .doc("The maximum number of buckets allowed. 
Defaults to 100000") + .doc("The maximum number of buckets allowed.") .intConf .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be greater than 0") .createWithDefault(100000) @@ -1022,7 +1022,7 @@ object SQLConf { "This configuration is effective only when using file-based sources such as Parquet, JSON " + "and ORC.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(128 * 1024 * 1024) // parquet.block.size + .createWithDefaultString("128MB") // parquet.block.size val FILES_OPEN_COST_IN_BYTES = buildConf("spark.sql.files.openCostInBytes") .internal() @@ -1161,7 +1161,8 @@ object SQLConf { val VARIABLE_SUBSTITUTE_ENABLED = buildConf("spark.sql.variable.substitute") - .doc("This enables substitution using syntax like ${var} ${system:var} and ${env:var}.") + .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " + + "and `${env:var}`.") .booleanConf .createWithDefault(true) @@ -1171,7 +1172,7 @@ object SQLConf { .doc("Enable two-level aggregate hash map. When enabled, records will first be " + "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " + "2nd-level, larger, slower map when 1st level is full or keys cannot be found. " + - "When disabled, records go directly to the 2nd level. Defaults to true.") + "When disabled, records go directly to the 2nd level.") .booleanConf .createWithDefault(true) @@ -1325,10 +1326,10 @@ object SQLConf { val STREAMING_STOP_TIMEOUT = buildConf("spark.sql.streaming.stopTimeout") - .doc("How long to wait for the streaming execution thread to stop when calling the " + - "streaming query's stop() method in milliseconds. 0 or negative values wait indefinitely.") + .doc("How long to wait in milliseconds for the streaming execution thread to stop when " + + "calling the streaming query's stop() method. 0 or negative values wait indefinitely.") .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(0L) + .createWithDefaultString("0") val STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL = buildConf("spark.sql.streaming.noDataProgressEventInterval") @@ -1611,10 +1612,10 @@ object SQLConf { val PANDAS_UDF_BUFFER_SIZE = buildConf("spark.sql.pandas.udf.buffer.size") .doc( - s"Same as ${BUFFER_SIZE} but only applies to Pandas UDF executions. If it is not set, " + - s"the fallback is ${BUFFER_SIZE}. Note that Pandas execution requires more than 4 bytes. " + - "Lowering this value could make small Pandas UDF batch iterated and pipelined; however, " + - "it might degrade performance. See SPARK-27870.") + s"Same as `${BUFFER_SIZE.key}` but only applies to Pandas UDF executions. If it is not " + + s"set, the fallback is `${BUFFER_SIZE.key}`. Note that Pandas execution requires more " + + "than 4 bytes. Lowering this value could make small Pandas UDF batch iterated and " + + "pipelined; however, it might degrade performance. See SPARK-27870.") .fallbackConf(BUFFER_SIZE) val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = @@ -2039,7 +2040,7 @@ object SQLConf { .checkValue(i => i >= 0 && i <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, "Invalid " + "value for 'spark.sql.maxPlanStringLength'. 
Length must be a valid string length " + "(nonnegative and shorter than the maximum size).") - .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + .createWithDefaultString(s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}") val SET_COMMAND_REJECTS_SPARK_CORE_CONFS = buildConf("spark.sql.legacy.setCommandRejectsSparkCoreConfs") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b232aa18c816e..bf3055d5e3e09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType private[sql] object PythonSQLUtils { @@ -39,6 +40,12 @@ private[sql] object PythonSQLUtils { FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray } + def listSQLConfigs(): Array[(String, String, String)] = { + val conf = new SQLConf() + // Py4J doesn't seem to translate Seq well, so we convert to an Array. + conf.getAllDefinedConfs.toArray + } + /** * Python callable function to read a file in Arrow stream format and create a [[RDD]] * using each serialized ArrowRecordBatch as a partition. diff --git a/sql/create-docs.sh b/sql/create-docs.sh index 4353708d22f7b..44aa877332fd5 100755 --- a/sql/create-docs.sh +++ b/sql/create-docs.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# Script to create SQL API docs. This requires `mkdocs` and to build +# Script to create SQL API and config docs. This requires `mkdocs` and to build # Spark first. After running this script the html docs can be found in # $SPARK_HOME/sql/site @@ -39,14 +39,16 @@ fi pushd "$FWDIR" > /dev/null -# Now create the markdown file rm -fr docs mkdir docs -echo "Generating markdown files for SQL documentation." -"$SPARK_HOME/bin/spark-submit" gen-sql-markdown.py -# Now create the HTML files -echo "Generating HTML files for SQL documentation." +echo "Generating SQL API Markdown files." +"$SPARK_HOME/bin/spark-submit" gen-sql-api-docs.py + +echo "Generating SQL configuration table HTML file." +"$SPARK_HOME/bin/spark-submit" gen-sql-config-docs.py + +echo "Generating HTML files for SQL API documentation." mkdocs build --clean rm -fr docs diff --git a/sql/gen-sql-markdown.py b/sql/gen-sql-api-docs.py similarity index 96% rename from sql/gen-sql-markdown.py rename to sql/gen-sql-api-docs.py index e0529f8310613..4feee7ad52570 100644 --- a/sql/gen-sql-markdown.py +++ b/sql/gen-sql-api-docs.py @@ -15,10 +15,11 @@ # limitations under the License. 
# -import sys import os from collections import namedtuple +from pyspark.java_gateway import launch_gateway + ExpressionInfo = namedtuple( "ExpressionInfo", "className name usage arguments examples note since deprecated") @@ -219,8 +220,7 @@ def generate_sql_markdown(jvm, path): if __name__ == "__main__": - from pyspark.java_gateway import launch_gateway - jvm = launch_gateway().jvm - markdown_file_path = "%s/docs/index.md" % os.path.dirname(sys.argv[0]) + spark_root_dir = os.path.dirname(os.path.dirname(__file__)) + markdown_file_path = os.path.join(spark_root_dir, "sql/docs/index.md") generate_sql_markdown(jvm, markdown_file_path) diff --git a/sql/gen-sql-config-docs.py b/sql/gen-sql-config-docs.py new file mode 100644 index 0000000000000..04f5a850c9980 --- /dev/null +++ b/sql/gen-sql-config-docs.py @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +from collections import namedtuple +from textwrap import dedent + +# To avoid adding a new direct dependency, we import markdown from within mkdocs. +from mkdocs.structure.pages import markdown +from pyspark.java_gateway import launch_gateway + +SQLConfEntry = namedtuple( + "SQLConfEntry", ["name", "default", "description"]) + + +def get_public_sql_configs(jvm): + sql_configs = [ + SQLConfEntry( + name=_sql_config._1(), + default=_sql_config._2(), + description=_sql_config._3(), + ) + for _sql_config in jvm.org.apache.spark.sql.api.python.PythonSQLUtils.listSQLConfigs() + ] + return sql_configs + + +def generate_sql_configs_table(sql_configs, path): + """ + Generates an HTML table at `path` that lists all public SQL + configuration options. + + The table will look something like this: + + ```html + + + + + + + + + + ... + +
Property NameDefaultMeaning
spark.sql.adaptive.enabledfalse

When true, enable adaptive query execution.

+ ``` + """ + value_reference_pattern = re.compile(r"^$") + + with open(path, 'w') as f: + f.write(dedent( + """ + + + """ + )) + for config in sorted(sql_configs, key=lambda x: x.name): + if config.default == "": + default = "(none)" + elif config.default.startswith(" + + + + + """ + .format( + name=config.name, + default=default, + description=markdown.markdown(config.description), + ) + )) + f.write("
Property NameDefaultMeaning
{name}{default}{description}
\n") + + +if __name__ == "__main__": + jvm = launch_gateway().jvm + sql_configs = get_public_sql_configs(jvm) + + spark_root_dir = os.path.dirname(os.path.dirname(__file__)) + sql_configs_table_path = os.path.join(spark_root_dir, "docs/sql-configs.html") + + generate_sql_configs_table(sql_configs, path=sql_configs_table_path) From 00c761d15e2628c8c5e1a9dad5ae122f292eae28 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 9 Feb 2020 14:18:51 -0800 Subject: [PATCH 0036/1280] [SPARK-30684 ][WEBUI][FollowUp] A new approach for SPARK-30684 ### What changes were proposed in this pull request? Simplify the changes for adding metrics description for WholeStageCodegen in https://github.com/apache/spark/pull/27405 ### Why are the changes needed? In https://github.com/apache/spark/pull/27405, the UI changes can be made without using the function `adjustPositionOfOperationName` to adjust the position of operation name and mark as an operation-name class. I suggest we make simpler changes so that it would be easier for future development. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Manual test with the queries provided in https://github.com/apache/spark/pull/27405 ``` sc.parallelize(1 to 10).toDF.sort("value").filter("value > 1").selectExpr("value * 2").show sc.parallelize(1 to 10).toDF.sort("value").filter("value > 1").selectExpr("value * 2").write.format("json").mode("overwrite").save("/tmp/test_output") sc.parallelize(1 to 10).toDF.write.format("json").mode("append").save("/tmp/test_output") ``` ![image](https://user-images.githubusercontent.com/1097932/74073629-e3f09f00-49bf-11ea-90dc-1edb5ca29e5e.png) Closes #27490 from gengliangwang/wholeCodegenUI. Authored-by: Gengliang Wang Signed-off-by: Gengliang Wang (cherry picked from commit b877aac14657832d1b896ea57e06b0d0fd15ee01) Signed-off-by: Gengliang Wang --- .../sql/execution/ui/static/spark-sql-viz.css | 8 ++--- .../sql/execution/ui/static/spark-sql-viz.js | 31 +------------------ .../sql/execution/ui/SparkPlanGraph.scala | 18 +++++------ 3 files changed, 11 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index eff0142dc523f..20188387c9ba4 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -18,6 +18,7 @@ #plan-viz-graph .label { font-weight: normal; text-shadow: none; + color: black; } #plan-viz-graph svg g.cluster rect { @@ -32,13 +33,8 @@ stroke-width: 1px; } -/* This declaration is needed to define the width of rectangles */ -#plan-viz-graph svg text :first-child { - font-weight: bold; -} - /* Highlight the SparkPlan node name */ -#plan-viz-graph svg text .operator-name { +#plan-viz-graph svg text :first-child { font-weight: bold; } diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js index e6ce641a841b3..c8349149439c8 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js @@ -34,7 +34,6 @@ function renderPlanViz() { preprocessGraphLayout(g); var renderer = new dagreD3.render(); renderer(graph, g); - 
adjustPositionOfOperationName(); // Round corners on rectangles svg @@ -82,7 +81,7 @@ function setupTooltipForSparkPlanNode(nodeId) { * and sizes of graph elements, e.g. padding, font style, shape. */ function preprocessGraphLayout(g) { - g.graph().ranksep = "90"; + g.graph().ranksep = "70"; var nodes = g.nodes(); for (var i = 0; i < nodes.length; i++) { var node = g.node(nodes[i]); @@ -129,34 +128,6 @@ function resizeSvg(svg) { .attr("height", height); } - -/* Helper function to adjust the position of operation name and mark as a operation-name class. */ -function adjustPositionOfOperationName() { - $("#plan-viz-graph svg text") - .each(function() { - var tspans = $(this).find("tspan"); - - if (tspans[0].textContent.trim() !== "") { - var isOperationNameOnly = - $(tspans).filter(function(i, n) { - return i !== 0 && n.textContent.trim() !== ""; - }).length === 0; - - if (isOperationNameOnly) { - // If the only text in a node is operation name, - // vertically centering the position of operation name. - var operationName = tspans[0].textContent; - var half = parseInt(tspans.length / 2); - tspans[0].textContent = tspans[half].textContent; - tspans[half].textContent = operationName; - $(tspans[half]).addClass("operator-name"); - } else { - tspans.first().addClass("operator-name"); - } - } - }); -} - /* Helper function to convert attributes to numeric values. */ function toFloat(f) { if (f) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1e767c3c043c3..d31d77840b802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -169,20 +169,18 @@ private[ui] class SparkPlanGraphNode( metric.name + ": " + value } - // If there are metrics, display each entry in a separate line. - // Note: whitespace between two "\n"s is to create an empty line between the name of - // SparkPlan and metrics. If removing it, it won't display the empty line in UI. - builder ++= "\n \n" - if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + builder ++= "\n \n" builder ++= values.mkString("\n") + s""" $id [label="${StringEscapeUtils.escapeJava(builder.toString())}"];""" } else { - // A certain level of height is needed for a rect as a node in a sub-graph - // to avoid layout collapse for sub-graphs. - builder ++= " " + // SPARK-30684: when there is no metrics, add empty lines to increase the height of the node, + // so that there won't be gaps between an edge and a small node. + s""" $id [labelType="html" label="
$name

"];""" } - - s""" $id [label="${StringEscapeUtils.escapeJava(builder.toString())}"];""" } } From 0c5403721cc085c6515a585c39b195f75ef6ac7d Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Sat, 8 Feb 2020 02:47:44 +0800 Subject: [PATCH 0037/1280] [SPARK-30614][SQL] The native ALTER COLUMN syntax should change one property at a time ### What changes were proposed in this pull request? The current ALTER COLUMN syntax allows to change multiple properties at a time: ``` ALTER TABLE table=multipartIdentifier (ALTER | CHANGE) COLUMN? column=multipartIdentifier (TYPE dataType)? (COMMENT comment=STRING)? colPosition? ``` The SQL standard (section 11.12) only allows changing one property at a time. This is also true on other recent SQL systems like [snowflake](https://docs.snowflake.net/manuals/sql-reference/sql/alter-table-column.html) and [redshift](https://docs.aws.amazon.com/redshift/latest/dg/r_ALTER_TABLE.html). (credit to cloud-fan) This PR proposes to change ALTER COLUMN to follow SQL standard, thus allows altering only one column property at a time. Note that ALTER COLUMN syntax being changed here is newly added in Spark 3.0, so it doesn't affect Spark 2.4 behavior. ### Why are the changes needed? To follow SQL standard (and other recent SQL systems) behavior. ### Does this PR introduce any user-facing change? Yes, now the user can update the column properties only one at a time. For example, ``` ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment' ``` should be broken into ``` ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint ALTER TABLE table1 ALTER COLUMN a.b.c COMMENT 'new comment' ``` ### How was this patch tested? Updated existing tests. Closes #27444 from imback82/alter_column_one_at_a_time. Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 14 ++-- .../sql/catalyst/parser/AstBuilder.scala | 78 ++++++++++--------- .../sql/catalyst/parser/DDLParserSuite.scala | 23 +++--- .../analysis/ResolveSessionCatalog.scala | 24 +++--- .../sql-tests/inputs/change-column.sql | 21 +++-- .../sql-tests/results/change-column.sql.out | 46 ++++++++--- .../spark/sql/connector/AlterTableTests.scala | 13 ---- .../sql/execution/command/DDLSuite.scala | 5 +- .../command/PlanResolutionSuite.scala | 47 ++++++++--- 9 files changed, 165 insertions(+), 106 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 08d5ff53bf2e2..563ef69b3b8ae 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -158,12 +158,9 @@ statement SET TBLPROPERTIES tablePropertyList #setTableProperties | ALTER (TABLE | VIEW) multipartIdentifier UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties - | ALTER TABLE table=multipartIdentifier + |ALTER TABLE table=multipartIdentifier (ALTER | CHANGE) COLUMN? column=multipartIdentifier - (TYPE dataType)? commentSpec? colPosition? #alterTableColumn - | ALTER TABLE table=multipartIdentifier - ALTER COLUMN? column=multipartIdentifier - setOrDrop=(SET | DROP) NOT NULL #alterColumnNullability + alterColumnAction? #alterTableAlterColumn | ALTER TABLE table=multipartIdentifier partitionSpec? CHANGE COLUMN? colName=multipartIdentifier colType colPosition? #hiveChangeColumn @@ -983,6 +980,13 @@ number | MINUS? 
BIGDECIMAL_LITERAL #bigDecimalLiteral ; +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + ; + // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. // - Reserved keywords: // Keywords that are reserved and can't be used as identifiers for table, view, column, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e9ad84472904d..6fc65e14868e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2940,55 +2940,61 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Parse a [[AlterTableAlterColumnStatement]] command. + * Parse a [[AlterTableAlterColumnStatement]] command to alter a column's property. * * For example: * {{{ * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint - * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment' + * ALTER TABLE table1 ALTER COLUMN a.b.c SET NOT NULL + * ALTER TABLE table1 ALTER COLUMN a.b.c DROP NOT NULL * ALTER TABLE table1 ALTER COLUMN a.b.c COMMENT 'new comment' + * ALTER TABLE table1 ALTER COLUMN a.b.c FIRST + * ALTER TABLE table1 ALTER COLUMN a.b.c AFTER x * }}} */ - override def visitAlterTableColumn( - ctx: AlterTableColumnContext): LogicalPlan = withOrigin(ctx) { - val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER" - if (ctx.dataType == null && ctx.commentSpec() == null && ctx.colPosition == null) { + override def visitAlterTableAlterColumn( + ctx: AlterTableAlterColumnContext): LogicalPlan = withOrigin(ctx) { + val action = ctx.alterColumnAction + if (action == null) { + val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER" operationNotAllowed( - s"ALTER TABLE table $verb COLUMN requires a TYPE or a COMMENT or a FIRST/AFTER", ctx) + s"ALTER TABLE table $verb COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER", + ctx) + } + + val dataType = if (action.dataType != null) { + Some(typedVisit[DataType](action.dataType)) + } else { + None + } + val nullable = if (action.setOrDrop != null) { + action.setOrDrop.getType match { + case SqlBaseParser.SET => Some(false) + case SqlBaseParser.DROP => Some(true) + } + } else { + None + } + val comment = if (action.commentSpec != null) { + Some(visitCommentSpec(action.commentSpec())) + } else { + None + } + val position = if (action.colPosition != null) { + Some(typedVisit[ColumnPosition](action.colPosition)) + } else { + None } + assert(Seq(dataType, nullable, comment, position).count(_.nonEmpty) == 1) + AlterTableAlterColumnStatement( visitMultipartIdentifier(ctx.table), typedVisit[Seq[String]](ctx.column), - dataType = Option(ctx.dataType).map(typedVisit[DataType]), - nullable = None, - comment = Option(ctx.commentSpec()).map(visitCommentSpec), - position = Option(ctx.colPosition).map(typedVisit[ColumnPosition])) - } - - /** - * Parse a [[AlterTableAlterColumnStatement]] command to change column nullability. 
- * - * For example: - * {{{ - * ALTER TABLE table1 ALTER COLUMN a.b.c SET NOT NULL - * ALTER TABLE table1 ALTER COLUMN a.b.c DROP NOT NULL - * }}} - */ - override def visitAlterColumnNullability(ctx: AlterColumnNullabilityContext): LogicalPlan = { - withOrigin(ctx) { - val nullable = ctx.setOrDrop.getType match { - case SqlBaseParser.SET => false - case SqlBaseParser.DROP => true - } - AlterTableAlterColumnStatement( - visitMultipartIdentifier(ctx.table), - typedVisit[Seq[String]](ctx.column), - dataType = None, - nullable = Some(nullable), - comment = None, - position = None) - } + dataType = dataType, + nullable = nullable, + comment = comment, + position = position) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 56d52571d1cc3..bc7b51f25b20d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -646,17 +646,18 @@ class DDLParserSuite extends AnalysisTest { Some(first()))) } - test("alter table: update column type, comment and position") { - comparePlans( - parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c " + - "TYPE bigint COMMENT 'new comment' AFTER d"), - AlterTableAlterColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - Some(LongType), - None, - Some("new comment"), - Some(after("d")))) + test("alter table: mutiple property changes are not allowed") { + intercept[ParseException] { + parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c " + + "TYPE bigint COMMENT 'new comment'")} + + intercept[ParseException] { + parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c " + + "TYPE bigint COMMENT AFTER d")} + + intercept[ParseException] { + parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c " + + "TYPE bigint COMMENT 'new comment' AFTER d")} } test("alter table: SET/DROP NOT NULL") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 486e7f1f84b46..77d549c28aae5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -77,10 +77,6 @@ class ResolveSessionCatalog( throw new AnalysisException( "ALTER COLUMN with qualified column is only supported with v2 tables.") } - if (a.dataType.isEmpty) { - throw new AnalysisException( - "ALTER COLUMN with v1 tables must specify new data type.") - } if (a.nullable.isDefined) { throw new AnalysisException( "ALTER COLUMN with v1 tables cannot specify NOT NULL.") @@ -92,17 +88,27 @@ class ResolveSessionCatalog( val builder = new MetadataBuilder // Add comment to metadata a.comment.map(c => builder.putString("comment", c)) + val colName = a.column(0) + val dataType = a.dataType.getOrElse { + v1Table.schema.findNestedField(Seq(colName), resolver = conf.resolver) + .map(_._2.dataType) + .getOrElse { + throw new AnalysisException( + s"ALTER COLUMN cannot find column ${quote(colName)} in v1 table. " + + s"Available: ${v1Table.schema.fieldNames.mkString(", ")}") + } + } // Add Hive type string to metadata. 
- val cleanedDataType = HiveStringType.replaceCharType(a.dataType.get) - if (a.dataType.get != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, a.dataType.get.catalogString) + val cleanedDataType = HiveStringType.replaceCharType(dataType) + if (dataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, dataType.catalogString) } val newColumn = StructField( - a.column(0), + colName, cleanedDataType, nullable = true, builder.build()) - AlterTableChangeColumnCommand(tbl.asTableIdentifier, a.column(0), newColumn) + AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName, newColumn) }.getOrElse { val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index dd2fc660b53e3..2b57891cfcbc5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -15,29 +15,34 @@ ALTER TABLE test_change CHANGE a TYPE STRING; DESC test_change; -- Change column position (not supported yet) -ALTER TABLE test_change CHANGE a TYPE INT AFTER b; -ALTER TABLE test_change CHANGE b TYPE STRING FIRST; +ALTER TABLE test_change CHANGE a AFTER b; +ALTER TABLE test_change CHANGE b FIRST; DESC test_change; -- Change column comment -ALTER TABLE test_change CHANGE a TYPE INT COMMENT 'this is column a'; -ALTER TABLE test_change CHANGE b TYPE STRING COMMENT '#*02?`'; -ALTER TABLE test_change CHANGE c TYPE INT COMMENT ''; +ALTER TABLE test_change CHANGE a COMMENT 'this is column a'; +ALTER TABLE test_change CHANGE b COMMENT '#*02?`'; +ALTER TABLE test_change CHANGE c COMMENT ''; DESC test_change; -- Don't change anything. -ALTER TABLE test_change CHANGE a TYPE INT COMMENT 'this is column a'; +ALTER TABLE test_change CHANGE a TYPE INT; +ALTER TABLE test_change CHANGE a COMMENT 'this is column a'; DESC test_change; -- Change a invalid column ALTER TABLE test_change CHANGE invalid_col TYPE INT; DESC test_change; +-- Check case insensitivity. 
+ALTER TABLE test_change CHANGE A COMMENT 'case insensitivity'; +DESC test_change; + -- Change column can't apply to a temporary/global_temporary view CREATE TEMPORARY VIEW temp_view(a, b) AS SELECT 1, "one"; -ALTER TABLE temp_view CHANGE a TYPE INT COMMENT 'this is column a'; +ALTER TABLE temp_view CHANGE a TYPE INT; CREATE GLOBAL TEMPORARY VIEW global_temp_view(a, b) AS SELECT 1, "one"; -ALTER TABLE global_temp.global_temp_view CHANGE a TYPE INT COMMENT 'this is column a'; +ALTER TABLE global_temp.global_temp_view CHANGE a TYPE INT; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 5bb00e028c4b7..b1a32ad1f63e9 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 28 -- !query @@ -27,7 +27,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: ALTER TABLE table CHANGE COLUMN requires a TYPE or a COMMENT or a FIRST/AFTER(line 1, pos 0) +Operation not allowed: ALTER TABLE table CHANGE COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER(line 1, pos 0) == SQL == ALTER TABLE test_change CHANGE a @@ -83,7 +83,7 @@ c int -- !query -ALTER TABLE test_change CHANGE a TYPE INT AFTER b +ALTER TABLE test_change CHANGE a AFTER b -- !query schema struct<> -- !query output @@ -92,7 +92,7 @@ ALTER COLUMN ... FIRST | ALTER is only supported with v2 tables.; -- !query -ALTER TABLE test_change CHANGE b TYPE STRING FIRST +ALTER TABLE test_change CHANGE b FIRST -- !query schema struct<> -- !query output @@ -111,7 +111,7 @@ c int -- !query -ALTER TABLE test_change CHANGE a TYPE INT COMMENT 'this is column a' +ALTER TABLE test_change CHANGE a COMMENT 'this is column a' -- !query schema struct<> -- !query output @@ -119,7 +119,7 @@ struct<> -- !query -ALTER TABLE test_change CHANGE b TYPE STRING COMMENT '#*02?`' +ALTER TABLE test_change CHANGE b COMMENT '#*02?`' -- !query schema struct<> -- !query output @@ -127,7 +127,7 @@ struct<> -- !query -ALTER TABLE test_change CHANGE c TYPE INT COMMENT '' +ALTER TABLE test_change CHANGE c COMMENT '' -- !query schema struct<> -- !query output @@ -145,7 +145,15 @@ c int -- !query -ALTER TABLE test_change CHANGE a TYPE INT COMMENT 'this is column a' +ALTER TABLE test_change CHANGE a TYPE INT +-- !query schema +struct<> +-- !query output + + + +-- !query +ALTER TABLE test_change CHANGE a COMMENT 'this is column a' -- !query schema struct<> -- !query output @@ -181,6 +189,24 @@ b string #*02?` c int +-- !query +ALTER TABLE test_change CHANGE A COMMENT 'case insensitivity' +-- !query schema +struct<> +-- !query output + + + +-- !query +DESC test_change +-- !query schema +struct +-- !query output +a int case insensitivity +b string #*02?` +c int + + -- !query CREATE TEMPORARY VIEW temp_view(a, b) AS SELECT 1, "one" -- !query schema @@ -190,7 +216,7 @@ struct<> -- !query -ALTER TABLE temp_view CHANGE a TYPE INT COMMENT 'this is column a' +ALTER TABLE temp_view CHANGE a TYPE INT -- !query schema struct<> -- !query output @@ -207,7 +233,7 @@ struct<> -- !query -ALTER TABLE global_temp.global_temp_view CHANGE a TYPE INT COMMENT 'this is column a' +ALTER TABLE global_temp.global_temp_view CHANGE a TYPE INT -- !query schema struct<> -- !query output diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 3cdac59c20fc9..420cb01d766a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -651,19 +651,6 @@ trait AlterTableTests extends SharedSparkSession { } } - test("AlterTable: update column type and comment") { - val t = s"${catalogAndNamespace}table_name" - withTable(t) { - sql(s"CREATE TABLE $t (id int) USING $v2Format") - sql(s"ALTER TABLE $t ALTER COLUMN id TYPE bigint COMMENT 'doc'") - - val table = getTableMetadata(t) - - assert(table.name === fullTableName(t)) - assert(table.schema === StructType(Seq(StructField("id", LongType).withComment("doc")))) - } - } - test("AlterTable: update nested column comment") { val t = s"${catalogAndNamespace}table_name" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 913cd80a24c6e..31e00781ae6b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -188,7 +188,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { withTable("t") { sql("CREATE TABLE t(i INT) USING parquet") val e = intercept[AnalysisException] { - sql("ALTER TABLE t ALTER COLUMN i TYPE INT FIRST") + sql("ALTER TABLE t ALTER COLUMN i FIRST") } assert(e.message.contains("ALTER COLUMN ... FIRST | ALTER is only supported with v2 tables")) } @@ -1786,7 +1786,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { column.map(_.metadata).getOrElse(Metadata.empty) } // Ensure that change column will preserve other metadata fields. 
- sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 TYPE INT COMMENT 'this is col1'") + sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 TYPE INT") + sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") assert(getMetadata("col1").getString("comment") == "this is col1") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 88f30353cce94..d439e5b1cd651 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1013,27 +1013,29 @@ class PlanResolutionSuite extends AnalysisTest { Seq("v1Table" -> true, "v2Table" -> false, "testcat.tab" -> false).foreach { case (tblName, useV1Command) => val sql1 = s"ALTER TABLE $tblName ALTER COLUMN i TYPE bigint" - val sql2 = s"ALTER TABLE $tblName ALTER COLUMN i TYPE bigint COMMENT 'new comment'" - val sql3 = s"ALTER TABLE $tblName ALTER COLUMN i COMMENT 'new comment'" + val sql2 = s"ALTER TABLE $tblName ALTER COLUMN i COMMENT 'new comment'" val parsed1 = parseAndResolve(sql1) val parsed2 = parseAndResolve(sql2) val tableIdent = TableIdentifier(tblName, None) if (useV1Command) { + val oldColumn = StructField("i", IntegerType) val newColumn = StructField("i", LongType) val expected1 = AlterTableChangeColumnCommand( tableIdent, "i", newColumn) val expected2 = AlterTableChangeColumnCommand( - tableIdent, "i", newColumn.withComment("new comment")) + tableIdent, "i", oldColumn.withComment("new comment")) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) + val sql3 = s"ALTER TABLE $tblName ALTER COLUMN j COMMENT 'new comment'" val e1 = intercept[AnalysisException] { parseAndResolve(sql3) } - assert(e1.getMessage.contains("ALTER COLUMN with v1 tables must specify new data type")) + assert(e1.getMessage.contains( + "ALTER COLUMN cannot find column j in v1 table. 
Available: i, s")) val sql4 = s"ALTER TABLE $tblName ALTER COLUMN a.b.c TYPE bigint" val e2 = intercept[AnalysisException] { @@ -1051,8 +1053,6 @@ class PlanResolutionSuite extends AnalysisTest { val parsed5 = parseAndResolve(sql5) comparePlans(parsed5, expected5) } else { - val parsed3 = parseAndResolve(sql3) - parsed1 match { case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( @@ -1063,18 +1063,41 @@ class PlanResolutionSuite extends AnalysisTest { parsed2 match { case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( - TableChange.updateColumnType(Array("i"), LongType), TableChange.updateColumnComment(Array("i"), "new comment"))) case _ => fail("expect AlterTable") } + } + } + } - parsed3 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => - assert(changes == Seq( - TableChange.updateColumnComment(Array("i"), "new comment"))) - case _ => fail("expect AlterTable") + test("alter table: alter column action is not specified") { + val e = intercept[AnalysisException] { + parseAndResolve("ALTER TABLE v1Table ALTER COLUMN i") + } + assert(e.getMessage.contains( + "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER")) + } + + test("alter table: alter column case sensitivity for v1 table") { + val tblName = "v1Table" + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val sql = s"ALTER TABLE $tblName ALTER COLUMN I COMMENT 'new comment'" + if (caseSensitive) { + val e = intercept[AnalysisException] { + parseAndResolve(sql) } + assert(e.getMessage.contains( + "ALTER COLUMN cannot find column I in v1 table. Available: i, s")) + } else { + val actual = parseAndResolve(sql) + val expected = AlterTableChangeColumnCommand( + TableIdentifier(tblName, None), + "I", + StructField("I", IntegerType).withComment("new comment")) + comparePlans(actual, expected) } + } } } From b012ff72b64f08e3fcb9e4fbcf04b874711cf5b6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 10 Feb 2020 16:23:44 +0900 Subject: [PATCH 0038/1280] [SPARK-30592][SQL][FOLLOWUP] Add some round-trip test cases ### What changes were proposed in this pull request? Add round-trip tests for CSV and JSON functions as https://github.com/apache/spark/pull/27317#discussion_r376745135 asked. ### Why are the changes needed? improve test coverage ### Does this PR introduce any user-facing change? no ### How was this patch tested? add uts Closes #27510 from yaooqinn/SPARK-30592-F. 
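For reference (not part of this patch), the round trips exercised by the new cases can be reproduced from `spark-shell`; the expected values in the comments are taken from the golden files updated below:

```
// CSV round trip: string -> struct (with an interval field) -> string
spark.sql("SELECT to_csv(from_csv('1, 1 day', 'a INT, b interval'))").collect()
// => Array([1,1 days])

// JSON round trip: string -> struct (with an interval field) -> string
spark.sql("""SELECT to_json(from_json('{"a":"1 days"}', 'a interval'))""").collect()
// => Array([{"a":"1 days"}])
```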
Authored-by: Kent Yao Signed-off-by: HyukjinKwon (cherry picked from commit 58b9ca1e6f7768b23e752dabc30468c06d0e1c57) Signed-off-by: HyukjinKwon --- .../resources/sql-tests/inputs/interval.sql | 14 +++++-- .../sql-tests/results/ansi/interval.sql.out | 38 ++++++++----------- .../sql-tests/results/interval.sql.out | 38 ++++++++----------- 3 files changed, 40 insertions(+), 50 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index fb6c485f619ae..a4e621e9639d4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -222,7 +222,13 @@ select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); -- interval support for csv and json functions -SELECT from_csv('1, 1 day', 'a INT, b interval'); -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)); -SELECT from_json('{"a":"1 days"}', 'a interval'); -SELECT to_json(map('a', interval 25 month 100 day 130 minute)); +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval'); +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index f37049064d989..7fdb4c53d1dcb 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 101 +-- Number of queries: 99 -- !query @@ -988,32 +988,24 @@ integer overflow -- !query -SELECT from_csv('1, 1 day', 'a INT, b interval') --- !query schema -struct> --- !query output -{"a":1,"b":1 days} - - --- !query -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)) --- !query schema -struct --- !query output -2 years 8 months,1 hours 10 minutes - - --- !query -SELECT from_json('{"a":"1 days"}', 'a interval') +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval') -- !query schema -struct> +struct,to_csv(from_csv(1, 1 day)):string,to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes')):string,from_csv(to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes'))):struct> -- !query output -{"a":1 days} +{"a":1,"b":1 days} 1,1 days 2 years 8 months,1 hours 10 minutes {"a":2 years 8 months,"b":1 hours 10 minutes} -- !query -SELECT to_json(map('a', interval 25 month 100 day 130 minute)) +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + 
from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval') -- !query schema -struct +struct,to_json(from_json({"a":"1 days"})):string,to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes')):string,from_json(to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes'))):struct> -- !query output -{"a":"2 years 1 months 100 days 2 hours 10 minutes"} +{"a":1 days} {"a":"1 days"} {"a":"2 years 1 months 100 days 2 hours 10 minutes"} {"a":2 years 1 months 100 days 2 hours 10 minutes} diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 94b4f15815ca5..3c4b4301d0025 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 101 +-- Number of queries: 99 -- !query @@ -969,32 +969,24 @@ integer overflow -- !query -SELECT from_csv('1, 1 day', 'a INT, b interval') --- !query schema -struct> --- !query output -{"a":1,"b":1 days} - - --- !query -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)) --- !query schema -struct --- !query output -2 years 8 months,1 hours 10 minutes - - --- !query -SELECT from_json('{"a":"1 days"}', 'a interval') +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval') -- !query schema -struct> +struct,to_csv(from_csv(1, 1 day)):string,to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes')):string,from_csv(to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes'))):struct> -- !query output -{"a":1 days} +{"a":1,"b":1 days} 1,1 days 2 years 8 months,1 hours 10 minutes {"a":2 years 8 months,"b":1 hours 10 minutes} -- !query -SELECT to_json(map('a', interval 25 month 100 day 130 minute)) +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval') -- !query schema -struct +struct,to_json(from_json({"a":"1 days"})):string,to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes')):string,from_json(to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes'))):struct> -- !query output -{"a":"2 years 1 months 100 days 2 hours 10 minutes"} +{"a":1 days} {"a":"1 days"} {"a":"2 years 1 months 100 days 2 hours 10 minutes"} {"a":2 years 1 months 100 days 2 hours 10 minutes} From dbf17f194d625c6bb8d8fcba5f4913e84000ab6f Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 10 Feb 2020 19:04:49 +0800 Subject: [PATCH 0039/1280] [SPARK-30757][SQL][DOC] Update the doc on TableCatalog.alterTable's behavior ### What changes were proposed in this pull request? This PR updates the documentation on `TableCatalog.alterTable`s behavior on the order by which the requested changes are applied. It now explicitly mentions that the changes are applied in the order given. ### Why are the changes needed? The current documentation on `TableCatalog.alterTable` doesn't mention which order the requested changes are applied. 
It will be useful to explicitly document this behavior so that the user can expect the behavior. For example, `REPLACE COLUMNS` needs to delete columns before adding new columns, and if the order is guaranteed by `alterTable`, it's much easier to work with the catalog API. ### Does this PR introduce any user-facing change? Yes, document change. ### How was this patch tested? Not added (doc changes). Closes #27496 from imback82/catalog_table_alter_table. Authored-by: Terry Kim Signed-off-by: Wenchen Fan (cherry picked from commit 70e545a94d47afb2848c24e81c908d28d41016da) Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/connector/catalog/TableCatalog.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index a69b23bf84d0c..2f102348ec517 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -134,6 +134,8 @@ Table createTable( * Implementations may reject the requested changes. If any change is rejected, none of the * changes should be applied to the table. *

+ * The requested changes must be applied in the order given. + * <p>

* If the catalog supports views and contains a view for the identifier and not a table, this * must throw {@link NoSuchTableException}. * From 7c10a6664e06fc00f9be9704f473135b2cf3e48b Mon Sep 17 00:00:00 2001 From: jiake Date: Mon, 10 Feb 2020 21:48:00 +0800 Subject: [PATCH 0040/1280] [SPARK-30719][SQL] Add unit test to verify the log warning print when intentionally skip AQE ### What changes were proposed in this pull request? This is a follow up in [#27452](https://github.com/apache/spark/pull/27452). Add a unit test to verify whether the log warning is print when intentionally skip AQE. ### Why are the changes needed? Add unit test ### Does this PR introduce any user-facing change? No ### How was this patch tested? adding unit test Closes #27515 from JkSelf/aqeLoggingWarningTest. Authored-by: jiake Signed-off-by: Wenchen Fan (cherry picked from commit 5a240603fd920e3cb5d9ef49c31d46df8a630d8c) Signed-off-by: Wenchen Fan --- .../adaptive/AdaptiveQueryExecSuite.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 96e977221e512..a2071903bea7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -789,4 +789,19 @@ class AdaptiveQueryExecSuite assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) } } + + test("SPARK-30719: do not log warning if intentionally skip AQE") { + val testAppender = new LogAppender("aqe logging warning test when skip") + withLogAppender(testAppender) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + assert(!testAppender.loggingEvents + .exists(msg => msg.getRenderedMessage.contains( + s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + + s" enabled but is not supported for"))) + } } From fd6d1b400630d7fee6d031e6de1fccfb4993778b Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Mon, 10 Feb 2020 23:41:39 +0800 Subject: [PATCH 0041/1280] [SPARK-30326][SQL] Raise exception if analyzer exceed max iterations ### What changes were proposed in this pull request? Enhance RuleExecutor strategy to take different actions when exceeding max iterations. And raise exception if analyzer exceed max iterations. ### Why are the changes needed? Currently, both analyzer and optimizer just log warning message if rule execution exceed max iterations. They should have different behavior. Analyzer should raise exception to indicates the plan is not fixed after max iterations, while optimizer just log warning to keep the current plan. This is more feasible after SPARK-30138 was introduced. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Add test in AnalysisSuite Closes #26977 from Eric5553/EnhanceMaxIterations. 
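As a pointer for users who hit the new error (not part of this patch): the message refers to `SQLConf.ANALYZER_MAX_ITERATIONS`, which is assumed here to resolve to the key `spark.sql.analyzer.maxIterations` with a default of 100. Raising it would look like:

```
// Assumed key name (taken from SQLConf.ANALYZER_MAX_ITERATIONS); adjust if it differs.
spark.conf.set("spark.sql.analyzer.maxIterations", "200")
// or equivalently, per session via SQL:
spark.sql("SET spark.sql.analyzer.maxIterations=200")
```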
Authored-by: Eric Wu <492960551@qq.com> Signed-off-by: Wenchen Fan (cherry picked from commit b2011a295bd78b3693a516e049e90250366b8f52) Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++++++- .../sql/catalyst/optimizer/Optimizer.scala | 5 +++- .../sql/catalyst/rules/RuleExecutor.scala | 27 ++++++++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 25 ++++++++++++++++- 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 75f1aa7185ef3..ce82b3b567b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -176,7 +176,15 @@ class Analyzer( def resolver: Resolver = conf.resolver - protected val fixedPoint = FixedPoint(maxIterations) + /** + * If the plan cannot be resolved within maxIterations, analyzer will throw exception to inform + * user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS. + */ + protected val fixedPoint = + FixedPoint( + maxIterations, + errorOnExceed = true, + maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key) /** * Override to provide additional rules for the "Resolution" batch. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0fdf6b022d885..c90117b4fbbbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -53,7 +53,10 @@ abstract class Optimizer(catalogManager: CatalogManager) "PartitionPruning", "Extract Python UDFs") - protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) + protected def fixedPoint = + FixedPoint( + SQLConf.get.optimizerMaxIterations, + maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key) /** * Defines the default rule batches in the Optimizer. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 287ae0e8e9f67..da5242bee28e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * An execution strategy for rules that indicates the maximum number of executions. If the * execution reaches fix point (i.e. converge) before maxIterations, it will stop. */ - abstract class Strategy { def maxIterations: Int } + abstract class Strategy { + + /** The maximum number of executions. */ + def maxIterations: Int + + /** Whether to throw exception when exceeding the maximum number. */ + def errorOnExceed: Boolean = false + + /** The key of SQLConf setting to tune maxIterations */ + def maxIterationsSetting: String = null + } /** A strategy that is run once and idempotent. */ case object Once extends Strategy { val maxIterations = 1 } @@ -54,7 +64,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * A strategy that runs until fix point or maxIterations times, whichever comes first. 
* Especially, a FixedPoint(1) batch is supposed to run only once. */ - case class FixedPoint(maxIterations: Int) extends Strategy + case class FixedPoint( + override val maxIterations: Int, + override val errorOnExceed: Boolean = false, + override val maxIterationsSetting: String = null) extends Strategy /** A batch of rules. */ protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) @@ -155,8 +168,14 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" - if (Utils.isTesting) { + val endingMsg = if (batch.strategy.maxIterationsSetting == null) { + "." + } else { + s", please set '${batch.strategy.maxIterationsSetting}' to a larger value." + } + val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" + + s"$endingMsg" + if (Utils.isTesting || batch.strategy.errorOnExceed) { throw new TreeNodeException(curPlan, message, null) } else { logWarning(message) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c747d394b1bc2..d38513319388b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,9 +25,10 @@ import org.scalatest.Matchers import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan @@ -745,4 +746,26 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics("evt1", sumWithFilter :: Nil, testRelation), "aggregates with filter predicate are not allowed" :: Nil) } + + test("Analysis exceed max iterations") { + // RuleExecutor only throw exception or log warning when the rule is supposed to run + // more than once. 
+ val maxIterations = 2 + val conf = new SQLConf().copy(SQLConf.ANALYZER_MAX_ITERATIONS -> maxIterations) + val testAnalyzer = new Analyzer( + new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf), conf) + + val plan = testRelation2.select( + $"a" / Literal(2) as "div1", + $"a" / $"b" as "div2", + $"a" / $"c" as "div3", + $"a" / $"d" as "div4", + $"e" / $"e" as "div5") + + val message = intercept[TreeNodeException[LogicalPlan]] { + testAnalyzer.execute(plan) + }.getMessage + assert(message.startsWith(s"Max iterations ($maxIterations) reached for batch Resolution, " + + s"please set '${SQLConf.ANALYZER_MAX_ITERATIONS.key}' to a larger value.")) + } } From ff395a39a5b10a7e71ef61813084bd3cf120280c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 9 Feb 2020 19:45:16 -0800 Subject: [PATCH 0042/1280] Revert "[SPARK-29721][SQL] Prune unnecessary nested fields from Generate without Project This reverts commit a0e63b61e7c5d55ae2a9213b95ab1e87ac7c203c. ### What changes were proposed in this pull request? This reverts the patch at #26978 based on gatorsmile's suggestion. ### Why are the changes needed? Original patch #26978 has not considered a corner case. We may need to put more time on ensuring we can cover all cases. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Unit test. Closes #27504 from viirya/revert-SPARK-29721. Authored-by: Liang-Chi Hsieh Signed-off-by: Xiao Li --- .../optimizer/NestedColumnAliasing.scala | 47 ------------------- .../sql/catalyst/optimizer/Optimizer.scala | 43 ++++++++++------- .../datasources/SchemaPruningSuite.scala | 32 ------------- 3 files changed, 25 insertions(+), 97 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index ea85014a37bd8..43a6006f9b5c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -155,53 +155,6 @@ object NestedColumnAliasing { case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType) case _ => 1 // UDT and others } -} - -/** - * This prunes unnessary nested columns from `Generate` and optional `Project` on top - * of it. - */ -object GeneratorNestedColumnAliasing { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we - // need to prune nested columns through Project and under Generate. The difference is - // when `nestedSchemaPruningEnabled` is on, nested columns will be pruned further at - // file format readers if it is supported. - case Project(projectList, g: Generate) if (SQLConf.get.nestedPruningOnExpressions || - SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) => - // On top on `Generate`, a `Project` that might have nested column accessors. - // We try to get alias maps for both project list and generator's children expressions. 
- NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { - case (nestedFieldToAlias, attrToAliases) => - val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases) - Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) - } - - case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && - canPruneGenerator(g.generator) => - NestedColumnAliasing.getAliasSubMap(g.generator.children).map { - case (nestedFieldToAlias, attrToAliases) => - pruneGenerate(g, nestedFieldToAlias, attrToAliases) - } - - case _ => - None - } - - private def pruneGenerate( - g: Generate, - nestedFieldToAlias: Map[ExtractValue, Alias], - attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { - val newGenerator = g.generator.transform { - case f: ExtractValue if nestedFieldToAlias.contains(f) => - nestedFieldToAlias(f).toAttribute - }.asInstanceOf[Generator] - - // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. - val newGenerate = g.copy(generator = newGenerator) - - NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) - } /** * This is a while-list for pruning nested fields at `Generator`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c90117b4fbbbc..08acac18f48bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -601,24 +601,31 @@ object ColumnPruning extends Rule[LogicalPlan] { s.copy(child = prunedChild(child, s.references)) // prune unrequired references - case p @ Project(_, g: Generate) => - val currP = if (p.references != g.outputSet) { - val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references - val newChild = prunedChild(g.child, requiredAttrs) - val unrequired = g.generator.references -- p.references - val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) - .map(_._2) - p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) - } else { - p - } - // If we can prune nested column on Project + Generate, do it now. - // Otherwise by transforming down to Generate, it could be pruned individually, - // and causes nested column on top Project unable to resolve. - GeneratorNestedColumnAliasing.unapply(currP).getOrElse(currP) - - // prune unrequired nested fields from `Generate`. 
- case GeneratorNestedColumnAliasing(p) => p + case p @ Project(_, g: Generate) if p.references != g.outputSet => + val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references + val newChild = prunedChild(g.child, requiredAttrs) + val unrequired = g.generator.references -- p.references + val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) + .map(_._2) + p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) + + // prune unrequired nested fields + case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions && + NestedColumnAliasing.canPruneGenerator(g.generator) => + NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { + case (nestedFieldToAlias, attrToAliases) => + val newGenerator = g.generator.transform { + case f: ExtractValue if nestedFieldToAlias.contains(f) => + nestedFieldToAlias(f).toAttribute + }.asInstanceOf[Generator] + + // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. + val newGenerate = g.copy(generator = newGenerator) + + val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) + + Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) + }.getOrElse(p) // Eliminate unneeded attributes from right side of a Left Existence Join. case j @ Join(_, right, LeftExistence(_), _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 5977e867f788a..a3d4905e82cee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -301,38 +301,6 @@ abstract class SchemaPruningSuite checkAnswer(query, Row("Y.", 1) :: Row("X.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) } - testSchemaPruning("select explode of nested field of array of struct") { - // Config combinations - val configs = Seq((true, true), (true, false), (false, true), (false, false)) - - configs.foreach { case (nestedPruning, nestedPruningOnExpr) => - withSQLConf( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> nestedPruning.toString, - SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> nestedPruningOnExpr.toString) { - val query1 = spark.table("contacts") - .select(explode(col("friends.first"))) - if (nestedPruning) { - // If `NESTED_SCHEMA_PRUNING_ENABLED` is enabled, - // even disabling `NESTED_PRUNING_ON_EXPRESSIONS`, - // nested schema is still pruned at scan node. 
- checkScan(query1, "struct>>") - } else { - checkScan(query1, "struct>>") - } - checkAnswer(query1, Row("Susan") :: Nil) - - val query2 = spark.table("contacts") - .select(explode(col("friends.first")), col("friends.middle")) - if (nestedPruning) { - checkScan(query2, "struct>>") - } else { - checkScan(query2, "struct>>") - } - checkAnswer(query2, Row("Susan", Array("Z.")) :: Nil) - } - } - } - protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { From d5e4f2e262e78e4a22e6b881bb721c7a8c0a5823 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 10 Feb 2020 10:45:00 -0800 Subject: [PATCH 0043/1280] [SPARK-27946][SQL][FOLLOW-UP] Change doc and error message for SHOW CREATE TABLE ### What changes were proposed in this pull request? This is a follow-up for #24938 to tweak error message and migration doc. ### Why are the changes needed? Making user know workaround if SHOW CREATE TABLE doesn't work for some Hive tables. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing unit tests. Closes #27505 from viirya/SPARK-27946-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Liang-Chi Hsieh (cherry picked from commit acfdb46a60fc06dac0af55951492d74b7073f546) Signed-off-by: Liang-Chi Hsieh --- docs/sql-migration-guide.md | 2 +- .../org/apache/spark/sql/execution/command/tables.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index be0fe32ded99b..26eb5838892b4 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -326,7 +326,7 @@ license: | - Since Spark 3.0, `SHOW TBLPROPERTIES` will cause `AnalysisException` if the table does not exist. In Spark version 2.4 and earlier, this scenario caused `NoSuchTableException`. Also, `SHOW TBLPROPERTIES` on a temporary view will cause `AnalysisException`. In Spark version 2.4 and earlier, it returned an empty result. - - Since Spark 3.0, `SHOW CREATE TABLE` will always return Spark DDL, even when the given table is a Hive serde table. For Hive DDL, please use `SHOW CREATE TABLE AS SERDE` command instead. + - Since Spark 3.0, `SHOW CREATE TABLE` will always return Spark DDL, even when the given table is a Hive serde table. For generating Hive DDL, please use `SHOW CREATE TABLE AS SERDE` command instead. - Since Spark 3.0, we upgraded the built-in Hive from 1.2 to 2.3. This may need to set `spark.sql.hive.metastore.version` and `spark.sql.hive.metastore.jars` according to the version of the Hive metastore. For example: set `spark.sql.hive.metastore.version` to `1.2.1` and `spark.sql.hive.metastore.jars` to `maven` if your Hive metastore version is 1.2.1. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 468ca505cce1f..90dbdf5515d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -1076,7 +1076,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) "Failed to execute SHOW CREATE TABLE against table " + s"${tableMetadata.identifier}, which is created by Hive and uses the " + "following unsupported feature(s)\n" + - tableMetadata.unsupportedFeatures.map(" - " + _).mkString("\n") + tableMetadata.unsupportedFeatures.map(" - " + _).mkString("\n") + ". " + + s"Please use `SHOW CREATE TABLE ${tableMetadata.identifier} AS SERDE` " + + "to show Hive DDL instead." ) } @@ -1086,7 +1088,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) if ("true".equalsIgnoreCase(tableMetadata.properties.getOrElse("transactional", "false"))) { throw new AnalysisException( - "SHOW CREATE TABLE doesn't support transactional Hive table") + "SHOW CREATE TABLE doesn't support transactional Hive table. " + + s"Please use `SHOW CREATE TABLE ${tableMetadata.identifier} AS SERDE` " + + "to show Hive DDL instead.") } convertTableMetadata(tableMetadata) From 3038a81ecdb526b01e80eeb34a7eacc6ac48d360 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 10 Feb 2020 22:16:25 +0100 Subject: [PATCH 0044/1280] [SPARK-30556][SQL][FOLLOWUP] Reset the status changed in SQLExecution.withThreadLocalCaptured ### What changes were proposed in this pull request? Follow up for #27267, reset the status changed in SQLExecution.withThreadLocalCaptured. ### Why are the changes needed? For code safety. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UT. Closes #27516 from xuanyuanking/SPARK-30556-follow. Authored-by: Yuanjian Li Signed-off-by: herman (cherry picked from commit a6b91d2bf727e175d0e175295001db85647539b1) Signed-off-by: herman --- .../apache/spark/sql/execution/SQLExecution.scala | 12 +++++++++++- .../sql/internal/ExecutorSideSQLConfSuite.scala | 10 ++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 995d94ef5eac7..9f177819f6ea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -177,9 +177,19 @@ object SQLExecution { val sc = sparkSession.sparkContext val localProps = Utils.cloneProperties(sc.getLocalProperties) Future { + val originalSession = SparkSession.getActiveSession + val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) sc.setLocalProperties(localProps) - body + val res = body + // reset active session and local props. 
+ sc.setLocalProperties(originalLocalProps) + if (originalSession.nonEmpty) { + SparkSession.setActiveSession(originalSession.get) + } else { + SparkSession.clearActiveSession() + } + res }(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 0cc658c499615..46d0c64592a00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.UUID + import org.scalatest.Assertions._ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext} @@ -144,16 +146,16 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } // set local configuration and assert - val confValue1 = "e" + val confValue1 = UUID.randomUUID().toString() createDataframe(confKey, confValue1).createOrReplaceTempView("m") spark.sparkContext.setLocalProperty(confKey, confValue1) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect.size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect().length == 1) // change the conf value and assert again - val confValue2 = "f" + val confValue2 = UUID.randomUUID().toString() createDataframe(confKey, confValue2).createOrReplaceTempView("n") spark.sparkContext.setLocalProperty(confKey, confValue2) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().length == 1) } } } From 45d834cb8cc2c30f902d0dec1cdf561b993521d0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Feb 2020 14:26:14 -0800 Subject: [PATCH 0045/1280] [SPARK-30779][SS] Fix some API issues found when reviewing Structured Streaming API docs ### What changes were proposed in this pull request? - Fix the scope of `Logging.initializeForcefully` so that it doesn't appear in subclasses' public methods. Right now, `sc.initializeForcefully(false, false)` is allowed to called. - Don't show classes under `org.apache.spark.internal` package in API docs. - Add missing `since` annotation. - Fix the scope of `ArrowUtils` to remove it from the API docs. ### Why are the changes needed? Avoid leaking APIs unintentionally in Spark 3.0.0. ### Does this PR introduce any user-facing change? No. All these changes are to avoid leaking APIs unintentionally in Spark 3.0.0. ### How was this patch tested? Manually generated the API docs and verified the above issues have been fixed. Closes #27528 from zsxwing/audit-ss-apis. 
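To make the first item concrete (a hypothetical snippet, not part of this patch): once `initializeForcefully` is `private[spark]`, user code such as the following no longer compiles against it, and the method disappears from the generated API docs.

```
// Hypothetical user code living outside the org.apache.spark package.
import org.apache.spark.SparkContext

object ScopeCheck {
  def touchLogging(sc: SparkContext): Unit = {
    // Allowed before this change, a compile error afterwards:
    // sc.initializeForcefully(isInterpreter = false, silent = false)
  }
}
```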
Authored-by: Shixiong Zhu Signed-off-by: Xiao Li --- core/src/main/scala/org/apache/spark/internal/Logging.scala | 2 +- project/SparkBuild.scala | 1 + .../sql/connector/read/streaming/ContinuousPartitionReader.java | 2 ++ .../read/streaming/ContinuousPartitionReaderFactory.java | 2 ++ .../spark/sql/connector/read/streaming/ContinuousStream.java | 2 ++ .../spark/sql/connector/read/streaming/MicroBatchStream.java | 2 ++ .../org/apache/spark/sql/connector/read/streaming/Offset.java | 2 ++ .../spark/sql/connector/read/streaming/PartitionOffset.java | 2 ++ .../apache/spark/sql/connector/read/streaming/ReadLimit.java | 1 + .../spark/sql/connector/read/streaming/SparkDataStream.java | 2 ++ .../connector/write/streaming/StreamingDataWriterFactory.java | 2 ++ .../spark/sql/connector/write/streaming/StreamingWrite.java | 2 ++ .../src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala | 2 +- 13 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index 2e4846bec2db4..0c1d9635b6535 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -117,7 +117,7 @@ trait Logging { } // For testing - def initializeForcefully(isInterpreter: Boolean, silent: Boolean): Unit = { + private[spark] def initializeForcefully(isInterpreter: Boolean, silent: Boolean): Unit = { initializeLogging(isInterpreter, silent) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 707c31d2248eb..9d0af3aa8c1b6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -819,6 +819,7 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(f => diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java index 8bd5273bb7d8e..c2ad9ec244a0d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java @@ -22,6 +22,8 @@ /** * A variation on {@link PartitionReader} for use with continuous streaming processing. + * + * @since 3.0.0 */ @Evolving public interface ContinuousPartitionReader extends PartitionReader { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java index 962864da4aad8..385c6f655440f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java @@ -27,6 +27,8 @@ /** * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} * instead of {@link PartitionReader}. 
It's used for continuous streaming processing. + * + * @since 3.0.0 */ @Evolving public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java index ee01a2553ae7a..a84578fe461a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java @@ -23,6 +23,8 @@ /** * A {@link SparkDataStream} for streaming queries with continuous mode. + * + * @since 3.0.0 */ @Evolving public interface ContinuousStream extends SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java index ceab0f75734d3..40ecbf0578ee5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java @@ -25,6 +25,8 @@ /** * A {@link SparkDataStream} for streaming queries with micro-batch mode. + * + * @since 3.0.0 */ @Evolving public interface MicroBatchStream extends SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java index 400de2a659746..efb8ebb684f06 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java @@ -25,6 +25,8 @@ * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. + * + * @since 3.0.0 */ @Evolving public abstract class Offset { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java index 35ad3bbde5cbf..faee230467bea 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java @@ -26,6 +26,8 @@ * provide a method to merge these into a global Offset. * * These offsets must be serializable. 
+ * + * @since 3.0.0 */ @Evolving public interface PartitionOffset extends Serializable { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java index 121ed1ad116f9..36f6e05e365d9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java @@ -27,6 +27,7 @@ * @see SupportsAdmissionControl#latestOffset(Offset, ReadLimit) * @see ReadAllAvailable * @see ReadMaxRows + * @since 3.0.0 */ @Evolving public interface ReadLimit { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java index 1ba0c25ef4466..95703e255ea4e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java @@ -25,6 +25,8 @@ * * Data sources should implement concrete data stream interfaces: * {@link MicroBatchStream} and {@link ContinuousStream}. + * + * @since 3.0.0 */ @Evolving public interface SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java index 9946867e8ea65..0923d07e7e5a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java @@ -33,6 +33,8 @@ * Note that, the writer factory will be serialized and sent to executors, then the data writer * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. + * + * @since 3.0.0 */ @Evolving public interface StreamingDataWriterFactory extends Serializable { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java index 4f930e1c158e5..e3dec3b2ff55e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java @@ -40,6 +40,8 @@ * do it manually in their Spark applications if they want to retry. * * Please refer to the documentation of commit/abort methods for detailed specifications. 
+ * + * @since 3.0.0 */ @Evolving public interface StreamingWrite { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 2da0d1a51cb29..003ce850c926e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -27,7 +27,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -object ArrowUtils { +private[sql] object ArrowUtils { val rootAllocator = new RootAllocator(Long.MaxValue) From b2b7cca6dec575b578f093bc7caa80f1b9d7b170 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 11 Feb 2020 10:03:01 +0900 Subject: [PATCH 0046/1280] [SPARK-30777][PYTHON][TESTS] Fix test failures for Pandas >= 1.0.0 ### What changes were proposed in this pull request? Fix PySpark test failures for using Pandas >= 1.0.0. ### Why are the changes needed? Pandas 1.0.0 has recently been released and has API changes that result in PySpark test failures, this PR fixes the broken tests. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Manually tested with Pandas 1.0.1 and PyArrow 0.16.0 Closes #27529 from BryanCutler/pandas-fix-tests-1.0-SPARK-30777. Authored-by: Bryan Cutler Signed-off-by: HyukjinKwon (cherry picked from commit 07a9885f2792be1353f4a923d649e90bc431cb38) Signed-off-by: HyukjinKwon --- python/pyspark/sql/tests/test_arrow.py | 4 ++-- python/pyspark/sql/tests/test_pandas_grouped_map.py | 6 +++--- python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 98f44dfd29da5..004c79f290213 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -297,9 +297,9 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) + pdf.iloc[0, 7] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted - pdf.ix[1, '2_int_t'] = None + pdf.iloc[1, 1] = None pdf_copy = pdf.copy(deep=True) self.spark.createDataFrame(pdf, schema=self.schema) self.assertTrue(pdf.equals(pdf_copy)) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 51dd07fd7d70c..ff53a0c6f2cf2 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -390,11 +390,11 @@ def rename_pdf(pdf, names): # Function returns a pdf with required column names, but order could be arbitrary using dict def change_col_order(pdf): # Constructing a DataFrame from a dict should result in the same order, - # but use from_items to ensure the pdf column order is different than schema - return pd.DataFrame.from_items([ + # but use OrderedDict to ensure the pdf column order is different than schema + return pd.DataFrame.from_dict(OrderedDict([ ('id', pdf.id), ('u', pdf.v * 2), - ('v', pdf.v)]) + ('v', pdf.v)])) ordered_udf = pandas_udf( change_col_order, diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 974ad560daebf..21679785a769e 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -357,7 +357,7 @@ def test_complex_expressions(self): plus_one(sum_udf(col('v1'))), sum_udf(plus_one(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) expected1 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) @@ -368,7 +368,7 @@ def test_complex_expressions(self): plus_one(sum(col('v1'))), sum(plus_one(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) # Test complex expressions with sql expression, scala pandas UDF and # group aggregate pandas UDF @@ -381,7 +381,7 @@ def test_complex_expressions(self): plus_two(sum_udf(col('v1'))), sum_udf(plus_two(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) expected2 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) @@ -392,7 +392,7 @@ def test_complex_expressions(self): plus_two(sum(col('v1'))), sum(plus_two(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) # Test sequential groupby aggregate result3 = (df.groupby('id') From 8efe367a4ee862b8a85aee8881b0335b34cbba70 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 15:50:03 +0900 Subject: [PATCH 0047/1280] [SPARK-30756][SQL] Fix `ThriftServerWithSparkContextSuite` on spark-branch-3.0-test-sbt-hadoop-2.7-hive-2.3 ### What changes were proposed in this pull request? This PR tries #26710 (comment) way to fix the test. ### Why are the changes needed? To make the tests pass. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Jenkins will test first, and then `on spark-branch-3.0-test-sbt-hadoop-2.7-hive-2.3` will test it out. Closes #27513 from HyukjinKwon/test-SPARK-30756. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9d0af3aa8c1b6..1c5c36ea8eae2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -478,7 +478,8 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite", "org.apache.spark.sql.hive.thriftserver.SparkSQLEnvSuite", "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", - "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite" + "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite", + "org.apache.spark.sql.hive.thriftserver.ThriftServerWithSparkContextSuite" ) private val DEFAULT_TEST_GROUP = "default_test_group" From 1e5766cbdd69080a7fc3881636406945fbd85752 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 17:22:08 +0900 Subject: [PATCH 0048/1280] [SPARK-29462][SQL] The data type of "array()" should be array ### What changes were proposed in this pull request? This brings https://github.com/apache/spark/pull/26324 back. It was reverted basically because, firstly Hive compatibility, and the lack of investigations in other DBMSes and ANSI. - In case of PostgreSQL seems coercing NULL literal to TEXT type. - Presto seems coercing `array() + array(1)` -> array of int. 
- Hive seems `array() + array(1)` -> array of strings Given that, the design choices have been differently made for some reasons. If we pick one of both, seems coercing to array of int makes much more sense. Another investigation was made offline internally. Seems ANSI SQL 2011, section 6.5 "" states: > If ES is specified, then let ET be the element type determined by the context in which ES appears. The declared type DT of ES is Case: > > a) If ES simply contains ARRAY, then ET ARRAY[0]. > > b) If ES simply contains MULTISET, then ET MULTISET. > > ES is effectively replaced by CAST ( ES AS DT ) From reading other related context, doing it to `NullType`. Given the investigation made, choosing to `null` seems correct, and we have a reference Presto now. Therefore, this PR proposes to bring it back. ### Why are the changes needed? When empty array is created, it should be declared as array. ### Does this PR introduce any user-facing change? Yes, `array()` creates `array`. Now `array(1) + array()` can correctly create `array(1)` instead of `array("1")`. ### How was this patch tested? Tested manually Closes #27521 from HyukjinKwon/SPARK-29462. Lead-authored-by: HyukjinKwon Co-authored-by: Aman Omer Signed-off-by: HyukjinKwon (cherry picked from commit 0045be766b949dff23ed72bd559568f17f645ffe) Signed-off-by: HyukjinKwon --- docs/sql-migration-guide.md | 2 ++ .../expressions/complexTypeCreator.scala | 11 ++++++++++- .../org/apache/spark/sql/internal/SQLConf.scala | 9 +++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 +++++++++++++---- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 26eb5838892b4..f98fab5b4c56b 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -215,6 +215,8 @@ license: | For example `SELECT timestamp 'tomorrow';`. - Since Spark 3.0, the `size` function returns `NULL` for the `NULL` input. In Spark version 2.4 and earlier, this function gives `-1` for the same input. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.sizeOfNull` to `true`. + + - Since Spark 3.0, when the `array` function is called without any parameters, it returns an empty array of `NullType`. In Spark version 2.4 and earlier, it returns an empty array of string type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.arrayDefaultToStringType.enabled` to `true`. - Since Spark 3.0, the interval literal syntax does not allow multiple from-to units anymore. For example, `SELECT INTERVAL '1-1' YEAR TO MONTH '2-2' YEAR TO MONTH'` throws parser exception. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 9ce87a4922c01..7335e305bfe55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,10 +45,18 @@ case class CreateArray(children: Seq[Expression]) extends Expression { TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } + private val defaultElementType: DataType = { + if (SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING)) { + StringType + } else { + NullType + } + } + override def dataType: ArrayType = { ArrayType( TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), containsNull = children.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e38fe7606c4ee..b79b767dbb22b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2007,6 +2007,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_ARRAY_DEFAULT_TO_STRING = + buildConf("spark.sql.legacy.arrayDefaultToStringType.enabled") + .internal() + .doc("When set to true, it returns an empty array of string type when the `array` " + + "function is called without any parameters. 
Otherwise, it returns an empty " + + "array of `NullType`") + .booleanConf + .createWithDefault(false) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7fce03658fc16..9e9d8c3e9a7c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3499,12 +3499,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } - test("SPARK-21281 use string types by default if array and map have no argument") { + test("SPARK-21281 use string types by default if map have no argument") { val ds = spark.range(1) var expectedSchema = new StructType() - .add("x", ArrayType(StringType, containsNull = false), nullable = false) - assert(ds.select(array().as("x")).schema == expectedSchema) - expectedSchema = new StructType() .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) assert(ds.select(map().as("x")).schema == expectedSchema) } @@ -3577,6 +3574,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }.getMessage assert(nonFoldableError.contains("The 'escape' parameter must be a string literal")) } + + test("SPARK-29462: Empty array of NullType for array function with no arguments") { + Seq((true, StringType), (false, NullType)).foreach { + case (arrayDefaultToString, expectedType) => + withSQLConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING.key -> arrayDefaultToString.toString) { + val schema = spark.range(1).select(array()).schema + assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[ArrayType]) + val actualType = schema.head.dataType.asInstanceOf[ArrayType].elementType + assert(actualType === expectedType) + } + } + } } object DataFrameFunctionsSuite { From 37edbab7803c35eecd664f72447418e79638024e Mon Sep 17 00:00:00 2001 From: root1 Date: Tue, 11 Feb 2020 20:42:02 +0800 Subject: [PATCH 0049/1280] [SPARK-27545][SQL][DOC] Update the Documentation for CACHE TABLE and UNCACHE TABLE ### What changes were proposed in this pull request? Document updated for `CACHE TABLE` & `UNCACHE TABLE` ### Why are the changes needed? Cache table creates a temp view while caching data using `CACHE TABLE name AS query`. `UNCACHE TABLE` does not remove this temp view. These things were not mentioned in the existing doc for `CACHE TABLE` & `UNCACHE TABLE`. ### Does this PR introduce any user-facing change? Document updated for `CACHE TABLE` & `UNCACHE TABLE` command. ### How was this patch tested? Manually Closes #27090 from iRakson/SPARK-27545. 
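For illustration only (not part of this patch), a minimal Scala sketch of the behaviour described above; the `src` and `cached_src` names are made up:

```
// CACHE TABLE ... AS <query> caches the query result AND registers a temporary view.
spark.range(10).createOrReplaceTempView("src")
spark.sql("CACHE TABLE cached_src AS SELECT * FROM src")  // creates temp view `cached_src` and caches it
spark.sql("SELECT count(*) FROM cached_src").show()       // served from the cache

// UNCACHE TABLE only removes the cache entry; the temp view created above is NOT dropped.
spark.sql("UNCACHE TABLE cached_src")
assert(spark.catalog.tableExists("cached_src"))           // the temp view still exists
spark.catalog.dropTempView("cached_src")                  // drop it explicitly if it is no longer needed
```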
Lead-authored-by: root1 Co-authored-by: iRakson Signed-off-by: Wenchen Fan (cherry picked from commit b20754d9ee033091e2ef4d5bfa2576f946c9df50) Signed-off-by: Wenchen Fan --- docs/sql-ref-syntax-aux-cache-cache-table.md | 3 ++- docs/sql-ref-syntax-aux-cache-uncache-table.md | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/sql-ref-syntax-aux-cache-cache-table.md b/docs/sql-ref-syntax-aux-cache-cache-table.md index ed6ef973466dd..20ade1961ab0b 100644 --- a/docs/sql-ref-syntax-aux-cache-cache-table.md +++ b/docs/sql-ref-syntax-aux-cache-cache-table.md @@ -20,7 +20,8 @@ license: | --- ### Description -`CACHE TABLE` statement caches contents of a table or output of a query with the given storage level. This reduces scanning of the original files in future queries. +`CACHE TABLE` statement caches contents of a table or output of a query with the given storage level. If a query is cached, then a temp view will be created for this query. +This reduces scanning of the original files in future queries. ### Syntax {% highlight sql %} diff --git a/docs/sql-ref-syntax-aux-cache-uncache-table.md b/docs/sql-ref-syntax-aux-cache-uncache-table.md index e0581d0d213df..69e21c258a333 100644 --- a/docs/sql-ref-syntax-aux-cache-uncache-table.md +++ b/docs/sql-ref-syntax-aux-cache-uncache-table.md @@ -21,11 +21,13 @@ license: | ### Description `UNCACHE TABLE` removes the entries and associated data from the in-memory and/or on-disk cache for a given table or view. The -underlying entries should already have been brought to cache by previous `CACHE TABLE` operation. `UNCACHE TABLE` on a non-existent table throws Exception if `IF EXISTS` is not specified. +underlying entries should already have been brought to cache by previous `CACHE TABLE` operation. `UNCACHE TABLE` on a non-existent table throws an exception if `IF EXISTS` is not specified. + ### Syntax {% highlight sql %} UNCACHE TABLE [ IF EXISTS ] table_identifier {% endhighlight %} + ### Parameters

table_identifier
@@ -37,10 +39,12 @@ UNCACHE TABLE [ IF EXISTS ] table_identifier
+ ### Examples {% highlight sql %} UNCACHE TABLE t1; {% endhighlight %} + ### Related Statements * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) From e139bf3c1b541b235b9a8f2750a5245ccc902f8a Mon Sep 17 00:00:00 2001 From: fuwhu Date: Tue, 11 Feb 2020 22:16:44 +0800 Subject: [PATCH 0050/1280] [MINOR][DOC] Add class document for PruneFileSourcePartitions and PruneHiveTablePartitions ### What changes were proposed in this pull request? Add class document for PruneFileSourcePartitions and PruneHiveTablePartitions. ### Why are the changes needed? To describe these two classes. ### Does this PR introduce any user-facing change? no ### How was this patch tested? no Closes #27535 from fuwhu/SPARK-15616-FOLLOW-UP. Authored-by: fuwhu Signed-off-by: Wenchen Fan (cherry picked from commit f1d0dce4848a53831268c80bf7e1e0f47a1f7612) Signed-off-by: Wenchen Fan --- .../datasources/PruneFileSourcePartitions.scala | 13 +++++++++++++ .../hive/execution/PruneHiveTablePartitions.scala | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 1ea19c187e51a..a7129fb14d1a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -26,6 +26,19 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} import org.apache.spark.sql.types.StructType +/** + * Prune the partitions of file source based table using partition filters. Currently, this rule + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] + * with [[FileScan]]. + * + * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding + * statistics will be updated. And the partition filters will be kept in the filters of returned + * logical plan. + * + * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to + * its underlying [[FileScan]]. And the partition filters will be removed in the filters of + * returned logical plan. + */ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { private def getPartitionKeyFiltersAndDataFilters( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index a0349f627d107..da6e4c52cf3a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -30,6 +30,14 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf /** + * Prune hive table partitions using partition filters on [[HiveTableRelation]]. The pruned + * partitions will be kept in [[HiveTableRelation.prunedPartitions]], and the statistics of + * the hive table relation will be updated based on pruned partitions. + * + * This rule is executed in optimization phase, so the statistics can be updated before physical + * planning, which is useful for some spark strategy, eg. 
+ * [[org.apache.spark.sql.execution.SparkStrategies.JoinSelection]]. + * * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source. */ private[sql] class PruneHiveTablePartitions(session: SparkSession) From 79c99d0ee02de79df9700123a13d34afaaf21602 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Feb 2020 00:12:45 +0800 Subject: [PATCH 0051/1280] [SPARK-30783] Exclude hive-service-rpc ### What changes were proposed in this pull request? Exclude hive-service-rpc from build. ### Why are the changes needed? hive-service-rpc 2.3.6 and spark sql's thrift server module have duplicate classes. Leaving hive-service-rpc 2.3.6 in the class path means that spark can pick up classes defined in hive instead of its thrift server module, which can cause hard to debug runtime errors due to class loading order and compilation errors for applications depend on spark. If you compare hive-service-rpc 2.3.6's jar (https://search.maven.org/remotecontent?filepath=org/apache/hive/hive-service-rpc/2.3.6/hive-service-rpc-2.3.6.jar) and spark thrift server's jar (e.g. https://repository.apache.org/content/groups/snapshots/org/apache/spark/spark-hive-thriftserver_2.12/3.0.0-SNAPSHOT/spark-hive-thriftserver_2.12-3.0.0-20200207.021914-364.jar), you will see that all of classes provided by hive-service-rpc-2.3.6.jar are covered by spark thrift server's jar. https://issues.apache.org/jira/browse/SPARK-30783 has output of jar tf for both jars. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing tests. Closes #27533 from yhuai/SPARK-30783. Authored-by: Yin Huai Signed-off-by: Wenchen Fan (cherry picked from commit ea626b6acf0de0ff3b0678372f30ba6f84ae2b09) Signed-off-by: Wenchen Fan --- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 1 - dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 1 - pom.xml | 20 ++++++++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 1b57250c1fb54..4f4d8b1d4a62a 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -87,7 +87,6 @@ hive-jdbc/2.3.6//hive-jdbc-2.3.6.jar hive-llap-common/2.3.6//hive-llap-common-2.3.6.jar hive-metastore/2.3.6//hive-metastore-2.3.6.jar hive-serde/2.3.6//hive-serde-2.3.6.jar -hive-service-rpc/2.3.6//hive-service-rpc-2.3.6.jar hive-shims-0.23/2.3.6//hive-shims-0.23-2.3.6.jar hive-shims-common/2.3.6//hive-shims-common-2.3.6.jar hive-shims-scheduler/2.3.6//hive-shims-scheduler-2.3.6.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index ffd2364a51317..18e4246d63996 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -86,7 +86,6 @@ hive-jdbc/2.3.6//hive-jdbc-2.3.6.jar hive-llap-common/2.3.6//hive-llap-common-2.3.6.jar hive-metastore/2.3.6//hive-metastore-2.3.6.jar hive-serde/2.3.6//hive-serde-2.3.6.jar -hive-service-rpc/2.3.6//hive-service-rpc-2.3.6.jar hive-shims-0.23/2.3.6//hive-shims-0.23-2.3.6.jar hive-shims-common/2.3.6//hive-shims-common-2.3.6.jar hive-shims-scheduler/2.3.6//hive-shims-scheduler-2.3.6.jar diff --git a/pom.xml b/pom.xml index a8d6ac932bac2..925fa28a291a4 100644 --- a/pom.xml +++ b/pom.xml @@ -1452,6 +1452,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1508,6 +1513,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1761,6 
+1771,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1911,6 +1926,11 @@ groovy-all + + + ${hive.group} + hive-service-rpc + org.apache.parquet From a53969617e8770683e675ae6d39e8f8bfd787073 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 09:55:02 -0800 Subject: [PATCH 0052/1280] [SPARK-29462][SQL][DOCS] Add some more context and details in 'spark.sql.defaultUrlStreamHandlerFactory.enabled' documentation ### What changes were proposed in this pull request? This PR adds some more information and context to `spark.sql.defaultUrlStreamHandlerFactory.enabled`. ### Why are the changes needed? It is a bit difficult to understand the documentation of `spark.sql.defaultUrlStreamHandlerFactory.enabled`. ### Does this PR introduce any user-facing change? Nope, internal doc only fix. ### How was this patch tested? Nope. I only tested linter. Closes #27541 from HyukjinKwon/SPARK-29462-followup. Authored-by: HyukjinKwon Signed-off-by: Dongjoon Hyun (cherry picked from commit 99bd59fe29a87bb70485db536b0ae676e7a9d42e) Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/internal/StaticSQLConf.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 6bc752260a893..563e51ed597b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -172,7 +172,13 @@ object StaticSQLConf { val DEFAULT_URL_STREAM_HANDLER_FACTORY_ENABLED = buildStaticConf("spark.sql.defaultUrlStreamHandlerFactory.enabled") - .doc("When true, set FsUrlStreamHandlerFactory to support ADD JAR against HDFS locations") + .doc( + "When true, register Hadoop's FsUrlStreamHandlerFactory to support " + + "ADD JAR against HDFS locations. " + + "It should be disabled when a different stream protocol handler should be registered " + + "to support a particular protocol type, or if Hadoop's FsUrlStreamHandlerFactory " + + "conflicts with other protocol types such as `http` or `https`. See also SPARK-25694 " + + "and HADOOP-14598.") .internal() .booleanConf .createWithDefault(true) From 5199d2f9dcf044f759318457ce3c0a56e00d9537 Mon Sep 17 00:00:00 2001 From: herman Date: Wed, 12 Feb 2020 10:48:29 +0900 Subject: [PATCH 0053/1280] [SPARK-30780][SQL] Empty LocalTableScan should use RDD without partitions ### What changes were proposed in this pull request? This is a small follow-up for https://github.com/apache/spark/pull/27400. This PR makes an empty `LocalTableScanExec` return an `RDD` without partitions. ### Why are the changes needed? It is a bit unexpected that the RDD contains partitions if there is not work to do. It also can save a bit of work when this is used in a more complex plan. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added test to `SparkPlanSuite`. Closes #27530 from hvanhovell/SPARK-30780. 
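As a side note, a tiny Scala sketch (not from this patch) of why an RDD without partitions is preferable to a single empty partition here:

```
// The change boils down to building the scan's RDD with emptyRDD (zero partitions,
// so no tasks are scheduled) instead of parallelize(rows, 1), which keeps one empty
// partition and still launches a task.
val sc = spark.sparkContext
assert(sc.emptyRDD[Int].getNumPartitions == 0)
assert(sc.parallelize(Seq.empty[Int], numSlices = 1).getNumPartitions == 1)
```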
Authored-by: herman Signed-off-by: HyukjinKwon (cherry picked from commit b25359cca3190f6a34dce3c3e49c4d2a80e88bdc) Signed-off-by: HyukjinKwon --- .../spark/sql/execution/LocalTableScanExec.scala | 12 ++++++++---- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../apache/spark/sql/execution/SparkPlanSuite.scala | 4 ++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 1b5115f2e29a3..b452213cd6cc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -45,10 +45,14 @@ case class LocalTableScanExec( } } - private lazy val numParallelism: Int = math.min(math.max(unsafeRows.length, 1), - sqlContext.sparkContext.defaultParallelism) - - private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism) + @transient private lazy val rdd: RDD[InternalRow] = { + if (rows.isEmpty) { + sqlContext.sparkContext.emptyRDD + } else { + val numSlices = math.min(unsafeRows.length, sqlContext.sparkContext.defaultParallelism) + sqlContext.sparkContext.parallelize(unsafeRows, numSlices) + } + } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5d..694e576fcded4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -330,7 +330,7 @@ class DataFrameSuite extends QueryTest testData.select("key").coalesce(1).select("key"), testData.select("key").collect().toSeq) - assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0) } test("convert $\"attribute name\" into unresolved attribute") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index e3bc414516c04..56fff1107ae39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -84,4 +84,8 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-30780 empty LocalTableScan should use RDD without partitions") { + assert(LocalTableScanExec(Nil, Nil).execute().getNumPartitions == 0) + } } From ed6193ad68a55a85f51f3ebda08f53cbcf023a24 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 12 Feb 2020 10:49:46 +0900 Subject: [PATCH 0054/1280] [SPARK-30722][PYTHON][DOCS] Update documentation for Pandas UDF with Python type hints ### What changes were proposed in this pull request? This PR targets to document the Pandas UDF redesign with type hints introduced at SPARK-28264. Mostly self-describing; however, there are few things to note for reviewers. 1. This PR replace the existing documentation of pandas UDFs to the newer redesign to promote the Python type hints. I added some words that Spark 3.0 still keeps the compatibility though. 2. This PR proposes to name non-pandas UDFs as "Pandas Function API" 3. 
SCALAR_ITER become two separate sections to reduce confusion: - `Iterator[pd.Series]` -> `Iterator[pd.Series]` - `Iterator[Tuple[pd.Series, ...]]` -> `Iterator[pd.Series]` 4. I removed some examples that look overkill to me. 5. I also removed some information in the doc, that seems duplicating or too much. ### Why are the changes needed? To document new redesign in pandas UDF. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests should cover. Closes #27466 from HyukjinKwon/SPARK-30722. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon (cherry picked from commit aa6a60530e63ab3bb8b117f8738973d1b26a2cc7) Signed-off-by: HyukjinKwon --- dev/sparktestsupport/modules.py | 1 - docs/sql-pyspark-pandas-with-arrow.md | 233 +++++++---- examples/src/main/python/sql/arrow.py | 258 ++++++------ python/pyspark/sql/pandas/functions.py | 538 +++++++++++-------------- python/pyspark/sql/pandas/group_ops.py | 99 ++++- python/pyspark/sql/pandas/map_ops.py | 6 +- python/pyspark/sql/udf.py | 16 +- 7 files changed, 609 insertions(+), 542 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 40f2ca288d694..391e4bbe1b1f0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -364,7 +364,6 @@ def __hash__(self): "pyspark.sql.avro.functions", "pyspark.sql.pandas.conversion", "pyspark.sql.pandas.map_ops", - "pyspark.sql.pandas.functions", "pyspark.sql.pandas.group_ops", "pyspark.sql.pandas.types", "pyspark.sql.pandas.serializers", diff --git a/docs/sql-pyspark-pandas-with-arrow.md b/docs/sql-pyspark-pandas-with-arrow.md index 7eb8a74547f70..92a515746b607 100644 --- a/docs/sql-pyspark-pandas-with-arrow.md +++ b/docs/sql-pyspark-pandas-with-arrow.md @@ -35,7 +35,7 @@ working with Arrow-enabled data. If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow -is installed and available on all cluster nodes. The current supported version is 0.12.1. +is installed and available on all cluster nodes. The current supported version is 0.15.1+. You can install using pip or conda from the conda-forge channel. See PyArrow [installation](https://arrow.apache.org/docs/python/install.html) for details. @@ -65,132 +65,216 @@ Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) -Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and -Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator -or to wrap the function, no additional configuration is required. Currently, there are two types of -Pandas UDF: Scalar and Grouped Map. +Pandas UDFs are user defined functions that are executed by Spark using +Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas +UDF is defined using the `pandas_udf` as a decorator or to wrap the function, and no additional +configuration is required. A Pandas UDF behaves as a regular PySpark function API in general. -### Scalar +Before Spark 3.0, Pandas UDFs used to be defined with `PandasUDFType`. From Spark 3.0 +with Python 3.6+, you can also use [Python type hints](https://www.python.org/dev/peps/pep-0484). +Using Python type hints are preferred and using `PandasUDFType` will be deprecated in +the future release. -Scalar Pandas UDFs are used for vectorizing scalar operations. 
They can be used with functions such -as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return -a `pandas.Series` of the same length. Internally, Spark will execute a Pandas UDF by splitting -columns into batches and calling the function for each batch as a subset of the data, then -concatenating the results together. +Note that the type hint should use `pandas.Series` in all cases but there is one variant +that `pandas.DataFrame` should be used for its input or output type hint instead when the input +or output column is of `StructType`. The following example shows a Pandas UDF which takes long +column, string column and struct column, and outputs a struct column. It requires the function to +specify the type hints of `pandas.Series` and `pandas.DataFrame` as below: -The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. +

+

+
+{% include_example ser_to_frame_pandas_udf python/sql/arrow.py %} +
+
+

+ +In the following sections, it describes the cominations of the supported type hints. For simplicity, +`pandas.DataFrame` variant is omitted. + +### Series to Series + +The type hint can be expressed as `pandas.Series`, ... -> `pandas.Series`. + +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the given +function takes one or more `pandas.Series` and outputs one `pandas.Series`. The output of the function should +always be of the same length as the input. Internally, PySpark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then concatenating +the results together. + +The following example shows how to create this Pandas UDF that computes the product of 2 columns.
-{% include_example scalar_pandas_udf python/sql/arrow.py %} +{% include_example ser_to_ser_pandas_udf python/sql/arrow.py %}
-### Scalar Iterator +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) + +### Iterator of Series to Iterator of Series + +The type hint can be expressed as `Iterator[pandas.Series]` -> `Iterator[pandas.Series]`. + +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the given +function takes an iterator of `pandas.Series` and outputs an iterator of `pandas.Series`. The output of each +series from the function should always be of the same length as the input. In this case, the created +Pandas UDF requires one input column when the Pandas UDF is called. To use multiple input columns, +a different type hint is required. See Iterator of Multiple Series to Iterator of Series. + +It is useful when the UDF execution requires initializing some states although internally it works +identically as Series to Series case. The pseudocode below illustrates the example. + +{% highlight python %} +@pandas_udf("long") +def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + # Do some expensive initialization with a state + state = very_expensive_initialization() + for x in iterator: + # Use that state for whole iterator. + yield calculate_with_state(x, state) -Scalar iterator (`SCALAR_ITER`) Pandas UDF is the same as scalar Pandas UDF above except that the -underlying Python function takes an iterator of batches as input instead of a single batch and, -instead of returning a single output batch, it yields output batches or returns an iterator of -output batches. -It is useful when the UDF execution requires initializing some states, e.g., loading an machine -learning model file to apply inference to every input batch. +df.select(calculate("value")).show() +{% endhighlight %} -The following example shows how to create scalar iterator Pandas UDFs: +The following example shows how to create this Pandas UDF:
-{% include_example scalar_iter_pandas_udf python/sql/arrow.py %} +{% include_example iter_ser_to_iter_ser_pandas_udf python/sql/arrow.py %}
-### Grouped Map -Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. -Split-apply-combine consists of three steps: -* Split the data into groups by using `DataFrame.groupBy`. -* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The - input data contains all the rows and columns for each group. -* Combine the results into a new `DataFrame`. +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) -To use `groupBy().apply()`, the user needs to define the following: -* A Python function that defines the computation for each group. -* A `StructType` object or a string that defines the schema of the output `DataFrame`. +### Iterator of Multiple Series to Iterator of Series -The column labels of the returned `pandas.DataFrame` must either match the field names in the -defined output schema if specified as strings, or match the field data types by position if not -strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) -on how to label columns when constructing a `pandas.DataFrame`. +The type hint can be expressed as `Iterator[Tuple[pandas.Series, ...]]` -> `Iterator[pandas.Series]`. -Note that all data for a group will be loaded into memory before the function is applied. This can -lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for -[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user -to ensure that the grouped data will fit into the available memory. +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the +given function takes an iterator of a tuple of multiple `pandas.Series` and outputs an iterator of `pandas.Series`. +In this case, the created pandas UDF requires multiple input columns as many as the series in the tuple +when the Pandas UDF is called. It works identically as Iterator of Series to Iterator of Series case except the parameter difference. -The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. +The following example shows how to create this Pandas UDF:
-{% include_example grouped_map_pandas_udf python/sql/arrow.py %} +{% include_example iter_sers_to_iter_ser_pandas_udf python/sql/arrow.py %}
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) + +### Series to Scalar -### Grouped Aggregate +The type hint can be expressed as `pandas.Series`, ... -> `Any`. -Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with `groupBy().agg()` and -[`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). It defines an aggregation from one or more `pandas.Series` -to a scalar value, where each `pandas.Series` represents a column within the group or window. +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF similar +to PySpark's aggregate functions. The given function takes `pandas.Series` and returns a scalar value. +The return type should be a primitive data type, and the returned scalar can be either a python +primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. +`Any` should ideally be a specific scalar type accordingly. -Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory. Also, -only unbounded window is supported with Grouped aggregate Pandas UDFs currently. +This UDF can be also used with `groupBy().agg()` and [`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). +It defines an aggregation from one or more `pandas.Series` to a scalar value, where each `pandas.Series` +represents a column within the group or window. -The following example shows how to use this type of UDF to compute mean with groupBy and window operations: +Note that this type of UDF does not support partial aggregation and all data for a group or window +will be loaded into memory. Also, only unbounded window is supported with Grouped aggregate Pandas +UDFs currently. The following example shows how to use this type of UDF to compute mean with a group-by +and window operations:
-{% include_example grouped_agg_pandas_udf python/sql/arrow.py %} +{% include_example ser_to_scalar_pandas_udf python/sql/arrow.py %}
For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) -### Map Iterator +## Pandas Function APIs + +Pandas function APIs can directly apply a Python native function against the whole DataFrame by +using Pandas instances. Internally it works similarly with Pandas UDFs by Spark using Arrow to transfer +data and Pandas to work with the data, which allows vectorized operations. A Pandas function API behaves +as a regular API under PySpark `DataFrame` in general. + +From Spark 3.0, Grouped map pandas UDF is now categorized as a separate Pandas Function API, +`DataFrame.groupby().applyInPandas()`. It is still possible to use it with `PandasUDFType` +and `DataFrame.groupby().apply()` as it was; however, it is preferred to use +`DataFrame.groupby().applyInPandas()` directly. Using `PandasUDFType` will be deprecated +in the future. + +### Grouped Map + +Grouped map operations with Pandas instances are supported by `DataFrame.groupby().applyInPandas()` +which requires a Python function that takes a `pandas.DataFrame` and return another `pandas.DataFrame`. +It maps each group to each `pandas.DataFrame` in the Python function. + +This API implements the "split-apply-combine" pattern which consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new PySpark `DataFrame`. -Map iterator Pandas UDFs are used to transform data with an iterator of batches. Map iterator -Pandas UDFs can be used with -[`pyspark.sql.DataFrame.mapInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). -It defines a map function that transforms an iterator of `pandas.DataFrame` to another. +To use `groupBy().applyInPandas()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output PySpark `DataFrame`. + +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. -It can return the output of arbitrary length in contrast to the scalar Pandas UDF. It maps an iterator of `pandas.DataFrame`s, -that represents the current `DataFrame`, using the map iterator UDF and returns the result as a `DataFrame`. +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user +to ensure that the grouped data will fit into the available memory. -The following example shows how to create map iterator Pandas UDFs: +The following example shows how to use `groupby().applyInPandas()` to subtract the mean from each value +in the group.
-{% include_example map_iter_pandas_udf python/sql/arrow.py %} +{% include_example grouped_apply_in_pandas python/sql/arrow.py %}
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.DataFrame.mapsInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). +For detailed usage, please see [`pyspark.sql.GroupedData.applyInPandas`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.applyInPandas). +### Map + +Map operations with Pandas instances are supported by `DataFrame.mapInPandas()` which maps an iterator +of `pandas.DataFrame`s to another iterator of `pandas.DataFrame`s that represents the current +PySpark `DataFrame` and returns the result as a PySpark `DataFrame`. The functions takes and outputs +an iterator of `pandas.DataFrame`. It can return the output of arbitrary length in contrast to some +Pandas UDFs although internally it works similarly with Series to Series Pandas UDF. + +The following example shows how to use `mapInPandas()`: + +
+
+{% include_example map_in_pandas python/sql/arrow.py %} +
+
-### Cogrouped Map +For detailed usage, please see [`pyspark.sql.DataFrame.mapsInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). -Cogrouped map Pandas UDFs allow two DataFrames to be cogrouped by a common key and then a python function applied to -each cogroup. They are used with `groupBy().cogroup().apply()` which consists of the following steps: +### Co-grouped Map +Co-grouped map operations with Pandas instances are supported by `DataFrame.groupby().cogroup().applyInPandas()` which +allows two PySpark `DataFrame`s to be cogrouped by a common key and then a Python function applied to each +cogroup. It consists of the following steps: * Shuffle the data such that the groups of each dataframe which share a key are cogrouped together. -* Apply a function to each cogroup. The input of the function is two `pandas.DataFrame` (with an optional Tuple -representing the key). The output of the function is a `pandas.DataFrame`. -* Combine the pandas.DataFrames from all groups into a new `DataFrame`. +* Apply a function to each cogroup. The input of the function is two `pandas.DataFrame` (with an optional tuple +representing the key). The output of the function is a `pandas.DataFrame`. +* Combine the `pandas.DataFrame`s from all groups into a new PySpark `DataFrame`. -To use `groupBy().cogroup().apply()`, the user needs to define the following: +To use `groupBy().cogroup().applyInPandas()`, the user needs to define the following: * A Python function that defines the computation for each cogroup. -* A `StructType` object or a string that defines the schema of the output `DataFrame`. +* A `StructType` object or a string that defines the schema of the output PySpark `DataFrame`. The column labels of the returned `pandas.DataFrame` must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not @@ -201,16 +285,15 @@ Note that all data for a cogroup will be loaded into memory before the function memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied and it is up to the user to ensure that the cogrouped data will fit into the available memory. -The following example shows how to use `groupby().cogroup().apply()` to perform an asof join between two datasets. +The following example shows how to use `groupby().cogroup().applyInPandas()` to perform an asof join between two datasets.
-{% include_example cogrouped_map_pandas_udf python/sql/arrow.py %} +{% include_example cogrouped_apply_in_pandas python/sql/arrow.py %}
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.CoGroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.CoGroupedData.apply). +For detailed usage, please see [`pyspark.sql.PandasCogroupedOps.applyInPandas()`](api/python/pyspark.sql.html#pyspark.sql.PandasCogroupedOps.applyInPandas). ## Usage Notes diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 1c983172d36ef..b7d8467172fab 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -23,12 +23,19 @@ from __future__ import print_function +import sys + from pyspark.sql import SparkSession from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version require_minimum_pandas_version() require_minimum_pyarrow_version() +if sys.version_info < (3, 6): + raise Exception( + "Running this example file requires Python 3.6+; however, " + "your Python version was:\n %s" % sys.version) + def dataframe_with_arrow_example(spark): # $example on:dataframe_with_arrow$ @@ -50,15 +57,45 @@ def dataframe_with_arrow_example(spark): print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) -def scalar_pandas_udf_example(spark): - # $example on:scalar_pandas_udf$ +def ser_to_frame_pandas_udf_example(spark): + # $example on:ser_to_frame_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import pandas_udf + + @pandas_udf("col1 string, col2 long") + def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: + s3['col2'] = s1 + s2.str.len() + return s3 + + # Create a Spark DataFrame that has three columns including a sturct column. + df = spark.createDataFrame( + [[1, "a string", ("a nested string",)]], + "long_col long, string_col string, struct_col struct") + + df.printSchema() + # root + # |-- long_column: long (nullable = true) + # |-- string_column: string (nullable = true) + # |-- struct_column: struct (nullable = true) + # | |-- col1: string (nullable = true) + + df.select(func("long_col", "string_col", "struct_col")).printSchema() + # |-- func(long_col, string_col, struct_col): struct (nullable = true) + # | |-- col1: string (nullable = true) + # | |-- col2: long (nullable = true) + # $example off:ser_to_frame_pandas_udf$$ + + +def ser_to_ser_pandas_udf_example(spark): + # $example on:ser_to_ser_pandas_udf$ import pandas as pd from pyspark.sql.functions import col, pandas_udf from pyspark.sql.types import LongType # Declare the function and create the UDF - def multiply_func(a, b): + def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series: return a * b multiply = pandas_udf(multiply_func, returnType=LongType()) @@ -83,26 +120,27 @@ def multiply_func(a, b): # | 4| # | 9| # +-------------------+ - # $example off:scalar_pandas_udf$ + # $example off:ser_to_ser_pandas_udf$ -def scalar_iter_pandas_udf_example(spark): - # $example on:scalar_iter_pandas_udf$ +def iter_ser_to_iter_ser_pandas_udf_example(spark): + # $example on:iter_ser_to_iter_ser_pandas_udf$ + from typing import Iterator + import pandas as pd - from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType + from pyspark.sql.functions import pandas_udf pdf = pd.DataFrame([1, 2, 3], columns=["x"]) df = spark.createDataFrame(pdf) - # When the UDF is called with a single column that is not StructType, - # the input to the underlying function is an iterator of pd.Series. 
- @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def plus_one(batch_iter): - for x in batch_iter: + # Declare the function and create the UDF + @pandas_udf("long") + def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + for x in iterator: yield x + 1 - df.select(plus_one(col("x"))).show() + df.select(plus_one("x")).show() # +-----------+ # |plus_one(x)| # +-----------+ @@ -110,15 +148,28 @@ def plus_one(batch_iter): # | 3| # | 4| # +-----------+ + # $example off:iter_ser_to_iter_ser_pandas_udf$ + + +def iter_sers_to_iter_ser_pandas_udf_example(spark): + # $example on:iter_sers_to_iter_ser_pandas_udf$ + from typing import Iterator, Tuple + + import pandas as pd - # When the UDF is called with more than one columns, - # the input to the underlying function is an iterator of pd.Series tuple. - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def multiply_two_cols(batch_iter): - for a, b in batch_iter: + from pyspark.sql.functions import pandas_udf + + pdf = pd.DataFrame([1, 2, 3], columns=["x"]) + df = spark.createDataFrame(pdf) + + # Declare the function and create the UDF + @pandas_udf("long") + def multiply_two_cols( + iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]: + for a, b in iterator: yield a * b - df.select(multiply_two_cols(col("x"), col("x"))).show() + df.select(multiply_two_cols("x", "x")).show() # +-----------------------+ # |multiply_two_cols(x, x)| # +-----------------------+ @@ -126,92 +177,32 @@ def multiply_two_cols(batch_iter): # | 4| # | 9| # +-----------------------+ + # $example off:iter_sers_to_iter_ser_pandas_udf$ - # When the UDF is called with a single column that is StructType, - # the input to the underlying function is an iterator of pd.DataFrame. - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def multiply_two_nested_cols(pdf_iter): - for pdf in pdf_iter: - yield pdf["a"] * pdf["b"] - - df.select( - multiply_two_nested_cols( - struct(col("x").alias("a"), col("x").alias("b")) - ).alias("y") - ).show() - # +---+ - # | y| - # +---+ - # | 1| - # | 4| - # | 9| - # +---+ - - # In the UDF, you can initialize some states before processing batches. - # Wrap your code with try/finally or use context managers to ensure - # the release of resources at the end. 
- y_bc = spark.sparkContext.broadcast(1) - - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def plus_y(batch_iter): - y = y_bc.value # initialize states - try: - for x in batch_iter: - yield x + y - finally: - pass # release resources here, if any - - df.select(plus_y(col("x"))).show() - # +---------+ - # |plus_y(x)| - # +---------+ - # | 2| - # | 3| - # | 4| - # +---------+ - # $example off:scalar_iter_pandas_udf$ - - -def grouped_map_pandas_udf_example(spark): - # $example on:grouped_map_pandas_udf$ - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = spark.createDataFrame( - [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ("id", "v")) - - @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) - def subtract_mean(pdf): - # pdf is a pandas.DataFrame - v = pdf.v - return pdf.assign(v=v - v.mean()) - - df.groupby("id").apply(subtract_mean).show() - # +---+----+ - # | id| v| - # +---+----+ - # | 1|-0.5| - # | 1| 0.5| - # | 2|-3.0| - # | 2|-1.0| - # | 2| 4.0| - # +---+----+ - # $example off:grouped_map_pandas_udf$ +def ser_to_scalar_pandas_udf_example(spark): + # $example on:ser_to_scalar_pandas_udf$ + import pandas as pd -def grouped_agg_pandas_udf_example(spark): - # $example on:grouped_agg_pandas_udf$ - from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.functions import pandas_udf from pyspark.sql import Window df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - @pandas_udf("double", PandasUDFType.GROUPED_AGG) - def mean_udf(v): + # Declare the function and create the UDF + @pandas_udf("double") + def mean_udf(v: pd.Series) -> float: return v.mean() + df.select(mean_udf(df['v'])).show() + # +-----------+ + # |mean_udf(v)| + # +-----------+ + # | 4.2| + # +-----------+ + df.groupby("id").agg(mean_udf(df['v'])).show() # +---+-----------+ # | id|mean_udf(v)| @@ -233,37 +224,54 @@ def mean_udf(v): # | 2| 5.0| 6.0| # | 2|10.0| 6.0| # +---+----+------+ - # $example off:grouped_agg_pandas_udf$ + # $example off:ser_to_scalar_pandas_udf$ -def map_iter_pandas_udf_example(spark): - # $example on:map_iter_pandas_udf$ - import pandas as pd +def grouped_apply_in_pandas_example(spark): + # $example on:grouped_apply_in_pandas$ + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) - from pyspark.sql.functions import pandas_udf, PandasUDFType + def subtract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:grouped_apply_in_pandas$ + +def map_in_pandas_example(spark): + # $example on:map_in_pandas$ df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) - @pandas_udf(df.schema, PandasUDFType.MAP_ITER) - def filter_func(batch_iter): - for pdf in batch_iter: + def filter_func(iterator): + for pdf in iterator: yield pdf[pdf.id == 1] - df.mapInPandas(filter_func).show() + df.mapInPandas(filter_func, schema=df.schema).show() # +---+---+ # | id|age| # +---+---+ # | 1| 21| # +---+---+ - # $example off:map_iter_pandas_udf$ + # $example off:map_in_pandas$ -def cogrouped_map_pandas_udf_example(spark): - # $example on:cogrouped_map_pandas_udf$ +def cogrouped_apply_in_pandas_example(spark): + # $example on:cogrouped_apply_in_pandas$ import pandas as pd - from pyspark.sql.functions import pandas_udf, 
PandasUDFType - df1 = spark.createDataFrame( [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ("time", "id", "v1")) @@ -272,11 +280,11 @@ def cogrouped_map_pandas_udf_example(spark): [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2")) - @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) def asof_join(l, r): return pd.merge_asof(l, r, on="time", by="id") - df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() + df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + asof_join, schema="time int, id int, v1 double, v2 string").show() # +--------+---+---+---+ # | time| id| v1| v2| # +--------+---+---+---+ @@ -285,7 +293,7 @@ def asof_join(l, r): # |20000101| 2|2.0| y| # |20000102| 2|4.0| y| # +--------+---+---+---+ - # $example off:cogrouped_map_pandas_udf$ + # $example off:cogrouped_apply_in_pandas$ if __name__ == "__main__": @@ -296,17 +304,21 @@ def asof_join(l, r): print("Running Pandas to/from conversion example") dataframe_with_arrow_example(spark) - print("Running pandas_udf scalar example") - scalar_pandas_udf_example(spark) - print("Running pandas_udf scalar iterator example") - scalar_iter_pandas_udf_example(spark) - print("Running pandas_udf grouped map example") - grouped_map_pandas_udf_example(spark) - print("Running pandas_udf grouped agg example") - grouped_agg_pandas_udf_example(spark) - print("Running pandas_udf map iterator example") - map_iter_pandas_udf_example(spark) - print("Running pandas_udf cogrouped map example") - cogrouped_map_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Frame") + ser_to_frame_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Series") + ser_to_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Iterator of Series to Iterator of Seires") + iter_ser_to_iter_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Iterator of Multiple Series to Iterator of Series") + iter_sers_to_iter_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Scalar") + ser_to_scalar_pandas_udf_example(spark) + print("Running pandas function example: Grouped Map") + grouped_apply_in_pandas_example(spark) + print("Running pandas function example: Map") + map_in_pandas_example(spark) + print("Running pandas function example: Co-grouped Map") + cogrouped_apply_in_pandas_example(spark) spark.stop() diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 30602789a33a9..31aa321bf5826 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -43,303 +43,228 @@ class PandasUDFType(object): @since(2.3) def pandas_udf(f=None, returnType=None, functionType=None): """ - Creates a vectorized user defined function (UDF). + Creates a pandas user defined function (a.k.a. vectorized user defined function). + + Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer + data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF + is defined using the `pandas_udf` as a decorator or to wrap the function, and no + additional configuration is required. A Pandas UDF behaves as a regular PySpark function + API in general. :param f: user-defined function. A python function if used as a standalone function :param returnType: the return type of the user-defined function. 
The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. - Default: SCALAR. - - .. seealso:: :meth:`pyspark.sql.DataFrame.mapInPandas` - .. seealso:: :meth:`pyspark.sql.GroupedData.applyInPandas` - .. seealso:: :meth:`pyspark.sql.PandasCogroupedOps.applyInPandas` - - The function type of the UDF can be one of the following: - - 1. SCALAR - - A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. - If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - - :class:`MapType`, nested :class:`StructType` are currently not supported as output types. - - Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP - >>> @pandas_udf(StringType()) # doctest: +SKIP - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], - ... ("id", "name", "age")) # doctest: +SKIP - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ - >>> @pandas_udf("first string, last string") # doctest: +SKIP - ... def split_expand(n): - ... return n.str.split(expand=True) - >>> df.select(split_expand("name")).show() # doctest: +SKIP - +------------------+ - |split_expand(name)| - +------------------+ - | [John, Doe]| - +------------------+ - - .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input - column, but is the length of an internal batch used for each call to the function. - Therefore, this can be used, for example, to ensure the length of each returned - `pandas.Series`, and can not be used as the column length. - - 2. SCALAR_ITER - - A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the - wrapped Python function takes an iterator of batches as input instead of a single batch and, - instead of returning a single output batch, it yields output batches or explicitly returns an - generator or an iterator of output batches. - It is useful when the UDF execution requires initializing some state, e.g., loading a machine - learning model file to apply inference to every input batch. - - .. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all - batches from one partition, although it is currently implemented this way. - Your code shall not rely on this behavior because it might change in the future for - further optimization, e.g., one invocation processes multiple partitions. - - Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. 
- - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType - >>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP - >>> df = spark.createDataFrame(pdf) # doctest: +SKIP - - When the UDF is called with a single column that is not `StructType`, the input to the - underlying function is an iterator of `pd.Series`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_one(batch_iter): - ... for x in batch_iter: - ... yield x + 1 - ... - >>> df.select(plus_one(col("x"))).show() # doctest: +SKIP - +-----------+ - |plus_one(x)| - +-----------+ - | 2| - | 3| - | 4| - +-----------+ - - When the UDF is called with more than one columns, the input to the underlying function is an - iterator of `pd.Series` tuple. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_cols(batch_iter): - ... for a, b in batch_iter: - ... yield a * b - ... - >>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP - +-----------------------+ - |multiply_two_cols(x, x)| - +-----------------------+ - | 1| - | 4| - | 9| - +-----------------------+ - - When the UDF is called with a single column that is `StructType`, the input to the underlying - function is an iterator of `pd.DataFrame`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_nested_cols(pdf_iter): - ... for pdf in pdf_iter: - ... yield pdf["a"] * pdf["b"] - ... - >>> df.select( - ... multiply_two_nested_cols( - ... struct(col("x").alias("a"), col("x").alias("b")) - ... ).alias("y") - ... ).show() # doctest: +SKIP - +---+ - | y| - +---+ - | 1| - | 4| - | 9| - +---+ - - In the UDF, you can initialize some states before processing batches, wrap your code with - `try ... finally ...` or use context managers to ensure the release of resources at the end - or in case of early termination. - - >>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_y(batch_iter): - ... y = y_bc.value # initialize some state - ... try: - ... for x in batch_iter: - ... yield x + y - ... finally: - ... pass # release resources here, if any - ... - >>> df.select(plus_y(col("x"))).show() # doctest: +SKIP - +---------+ - |plus_y(x)| - +---------+ - | 2| - | 3| - | 4| - +---------+ - - 3. GROUPED_MAP - - A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` - The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match - the field names in the defined returnType schema if specified as strings, or match the - field data types by position if not strings, e.g. integer indices. - The length of the returned `pandas.DataFrame` can be arbitrary. - - Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def normalize(pdf): - ... v = pdf.v - ... 
return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.8320502943378437| - | 2|-0.2773500981126146| - | 2| 1.1094003924504583| - +---+-------------------+ - - Alternatively, the user can define a function that takes two arguments. - In this case, the grouping key(s) will be passed as the first argument and the data will - be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy - data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in - as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. - This is useful when the user does not want to hardcode grouping key(s) in the function. - - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def mean_udf(key, pdf): - ... # key is a tuple of one numpy.int64, which is the value - ... # of 'id' for the current group - ... return pd.DataFrame([key + (pdf.v.mean(),)]) - >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP - +---+---+ - | id| v| - +---+---+ - | 1|1.5| - | 2|6.0| - +---+---+ - >>> @pandas_udf( - ... "id long, `ceil(v / 2)` long, v double", - ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP - >>> def sum_udf(key, pdf): - ... # key is a tuple of two numpy.int64s, which is the values - ... # of 'id' and 'ceil(df.v / 2)' for the current group - ... return pd.DataFrame([key + (pdf.v.sum(),)]) - >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP - +---+-----------+----+ - | id|ceil(v / 2)| v| - +---+-----------+----+ - | 2| 5|10.0| - | 1| 1| 3.0| - | 2| 3| 5.0| - | 2| 2| 3.0| - +---+-----------+----+ - - .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is - recommended to explicitly index the columns by name to ensure the positions are correct, - or alternatively use an `OrderedDict`. - For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or - `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. - - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - - 4. GROUPED_AGG - - A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. - The returned scalar can be either a python primitive type, e.g., `int` or `float` - or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - - :class:`MapType` and :class:`StructType` are currently not supported as output types. - - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and - :class:`pyspark.sql.Window` - - This example shows using grouped aggregated UDFs with groupby: - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... 
return v.mean() - >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP - +---+-----------+ - | id|mean_udf(v)| - +---+-----------+ - | 1| 1.5| - | 2| 6.0| - +---+-----------+ - - This example shows using grouped aggregated UDFs as window functions. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql import Window - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... return v.mean() - >>> w = (Window.partitionBy('id') - ... .orderBy('v') - ... .rowsBetween(-1, 0)) - >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP - +---+----+------+ - | id| v|mean_v| - +---+----+------+ - | 1| 1.0| 1.0| - | 1| 2.0| 1.5| - | 2| 3.0| 3.0| - | 2| 5.0| 4.0| - | 2|10.0| 7.5| - +---+----+------+ - - .. note:: For performance reasons, the input series to window functions are not copied. + Default: SCALAR. + + .. note:: This parameter exists for compatibility. Using Python type hints is encouraged. + + In order to use this API, customarily the below are imported: + + >>> import pandas as pd + >>> from pyspark.sql.functions import pandas_udf + + From Spark 3.0 with Python 3.6+, `Python type hints `_ + detect the function types as below: + + >>> @pandas_udf(IntegerType()) + ... def slen(s: pd.Series) -> pd.Series: + ... return s.str.len() + + Prior to Spark 3.0, the pandas UDF used `functionType` to decide the execution type as below: + + >>> from pyspark.sql.functions import PandasUDFType + >>> from pyspark.sql.types import IntegerType + >>> @pandas_udf(IntegerType(), PandasUDFType.SCALAR) + ... def slen(s): + ... return s.str.len() + + It is preferred to specify type hints for the pandas UDF instead of specifying pandas UDF + type via `functionType` which will be deprecated in the future releases. + + Note that the type hint should use `pandas.Series` in all cases but there is one variant + that `pandas.DataFrame` should be used for its input or output type hint instead when the input + or output column is of :class:`pyspark.sql.types.StructType`. The following example shows + a Pandas UDF which takes long column, string column and struct column, and outputs a struct + column. It requires the function to specify the type hints of `pandas.Series` and + `pandas.DataFrame` as below: + + >>> @pandas_udf("col1 string, col2 long") + >>> def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: + ... s3['col2'] = s1 + s2.str.len() + ... return s3 + ... + >>> # Create a Spark DataFrame that has three columns including a sturct column. + ... df = spark.createDataFrame( + ... [[1, "a string", ("a nested string",)]], + ... "long_col long, string_col string, struct_col struct") + >>> df.printSchema() + root + |-- long_column: long (nullable = true) + |-- string_column: string (nullable = true) + |-- struct_column: struct (nullable = true) + | |-- col1: string (nullable = true) + >>> df.select(func("long_col", "string_col", "struct_col")).printSchema() + |-- func(long_col, string_col, struct_col): struct (nullable = true) + | |-- col1: string (nullable = true) + | |-- col2: long (nullable = true) + + In the following sections, it describes the cominations of the supported type hints. For + simplicity, `pandas.DataFrame` variant is omitted. + + * Series to Series + `pandas.Series`, ... 
-> `pandas.Series` + + The function takes one or more `pandas.Series` and outputs one `pandas.Series`. + The output of the function should always be of the same length as the input. + + >>> @pandas_udf("string") + ... def to_upper(s: pd.Series) -> pd.Series: + ... return s.str.upper() + ... + >>> df = spark.createDataFrame([("John Doe",)], ("name",)) + >>> df.select(to_upper("name")).show() + +--------------+ + |to_upper(name)| + +--------------+ + | JOHN DOE| + +--------------+ + + >>> @pandas_udf("first string, last string") + ... def split_expand(s: pd.Series) -> pd.DataFrame: + ... return s.str.split(expand=True) + ... + >>> df = spark.createDataFrame([("John Doe",)], ("name",)) + >>> df.select(split_expand("name")).show() + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ + + .. note:: The length of the input is not that of the whole input column, but is the + length of an internal batch used for each call to the function. + + * Iterator of Series to Iterator of Series + `Iterator[pandas.Series]` -> `Iterator[pandas.Series]` + + The function takes an iterator of `pandas.Series` and outputs an iterator of + `pandas.Series`. In this case, the created pandas UDF instance requires one input + column when this is called as a PySpark column. The output of each series from + the function should always be of the same length as the input. + + It is useful when the UDF execution + requires initializing some states although internally it works identically as + Series to Series case. The pseudocode below illustrates the example. + + .. highlight:: python + .. code-block:: python + + @pandas_udf("long") + def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + # Do some expensive initialization with a state + state = very_expensive_initialization() + for x in iterator: + # Use that state for whole iterator. + yield calculate_with_state(x, state) + + df.select(calculate("value")).show() + + >>> from typing import Iterator + >>> @pandas_udf("long") + ... def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + ... for s in iterator: + ... yield s + 1 + ... + >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"])) + >>> df.select(plus_one(df.v)).show() + +-----------+ + |plus_one(v)| + +-----------+ + | 2| + | 3| + | 4| + +-----------+ + + .. note:: The length of each series is the length of a batch internally used. + + * Iterator of Multiple Series to Iterator of Series + `Iterator[Tuple[pandas.Series, ...]]` -> `Iterator[pandas.Series]` + + The function takes an iterator of a tuple of multiple `pandas.Series` and outputs an + iterator of `pandas.Series`. In this case, the created pandas UDF instance requires + input columns as many as the series when this is called as a PySpark column. + It works identically as Iterator of Series to Iterator of Series case except + the parameter difference. The output of each series from the function should always + be of the same length as the input. + + >>> from typing import Iterator, Tuple + >>> from pyspark.sql.functions import struct, col + >>> @pandas_udf("long") + ... def multiply(iterator: Iterator[Tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]: + ... for s1, df in iterator: + ... yield s1 * df.v + ... + >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"])) + >>> df.withColumn('output', multiply(col("v"), struct(col("v")))).show() + +---+------+ + | v|output| + +---+------+ + | 1| 1| + | 2| 4| + | 3| 9| + +---+------+ + + .. 
note:: The length of each series is the length of a batch internally used. + + * Series to Scalar + `pandas.Series`, ... -> `Any` + + The function takes `pandas.Series` and returns a scalar value. The `returnType` + should be a primitive data type, and the returned scalar can be either a python primitive + type, e.g., int or float or a numpy data type, e.g., numpy.int64 or numpy.float64. + `Any` should ideally be a specific scalar type accordingly. + + >>> @pandas_udf("double") + ... def mean_udf(v: pd.Series) -> float: + ... return v.mean() + ... + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) + >>> df.groupby("id").agg(mean_udf(df['v'])).show() + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + This UDF can also be used as window functions as below: + + >>> from pyspark.sql import Window + >>> @pandas_udf("double") + ... def mean_udf(v: pd.Series) -> float: + ... return v.mean() + ... + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) + >>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0) + >>> df.withColumn('mean_v', mean_udf("v").over(w)).show() + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.0| + | 1| 2.0| 1.5| + | 2| 3.0| 3.0| + | 2| 5.0| 4.0| + | 2|10.0| 7.5| + +---+----+------+ + + .. note:: For performance reasons, the input series to window functions are not copied. Therefore, mutating the input series is not allowed and will cause incorrect results. For the same reason, users should also not rely on the index of the input series. - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions @@ -348,10 +273,21 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not take keyword arguments on the calling side. .. note:: The data type of returned `pandas.Series` from the user-defined functions should be - matched with defined returnType (see :meth:`types.to_arrow_type` and + matched with defined `returnType` (see :meth:`types.to_arrow_type` and :meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do conversion on returned data. The conversion is not guaranteed to be correct and results should be checked for accuracy by users. + + .. note:: Currently, + :class:`pyspark.sql.types.MapType`, + :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and + nested :class:`pyspark.sql.types.StructType` + are currently not supported as output types. + + .. seealso:: :meth:`pyspark.sql.DataFrame.mapInPandas` + .. seealso:: :meth:`pyspark.sql.GroupedData.applyInPandas` + .. seealso:: :meth:`pyspark.sql.PandasCogroupedOps.applyInPandas` + .. 
seealso:: :meth:`pyspark.sql.UDFRegistration.register` """ # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that @@ -480,25 +416,3 @@ def _create_pandas_udf(f, returnType, evalType): "or three arguments (key, left, right).") return _create_udf(f, returnType, evalType) - - -def _test(): - import doctest - from pyspark.sql import SparkSession - import pyspark.sql.pandas.functions - globs = pyspark.sql.pandas.functions.__dict__.copy() - spark = SparkSession.builder\ - .master("local[4]")\ - .appName("sql.pandas.functions tests")\ - .getOrCreate() - globs['spark'] = spark - (failure_count, test_count) = doctest.testmod( - pyspark.sql.pandas.functions, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - spark.stop() - if failure_count: - sys.exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 3152271ba9df8..b93f0516cadb1 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -88,29 +88,27 @@ def applyInPandas(self, func, schema): to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. - The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. - - .. note:: This function requires a full shuffle. All the data of a group will be loaded - into memory, so the user should be aware of the potential OOM risk if data is skewed - and certain groups are too large to fit in memory. + The `schema` should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. :param func: a Python native function that takes a `pandas.DataFrame`, and outputs a `pandas.DataFrame`. :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - .. note:: Experimental - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf, ceil >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) + ... ("id", "v")) # doctest: +SKIP >>> def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").applyInPandas(normalize, schema="id long, v double").show() - ... # doctest: +SKIP + >>> df.groupby("id").applyInPandas( + ... normalize, schema="id long, v double").show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -121,8 +119,56 @@ def applyInPandas(self, func, schema): | 2| 1.1094003924504583| +---+-------------------+ - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + Alternatively, the user can pass a function that takes two arguments. + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. 
The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key(s) in the function. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> def mean_func(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').applyInPandas( + ... mean_func, schema="id long, v double").show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + >>> def sum_func(key, pdf): + ... # key is a tuple of two numpy.int64s, which is the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... return pd.DataFrame([key + (pdf.v.sum(),)]) + >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas( + ... sum_func, schema="id long, `ceil(v / 2)` long, v double").show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ + + .. note:: This function requires a full shuffle. All the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. note:: Experimental + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ from pyspark.sql import GroupedData from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -176,14 +222,11 @@ def applyInPandas(self, func, schema): `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. - The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. - - .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded - into memory, so the user should be aware of the potential OOM risk if data is skewed - and certain groups are too large to fit in memory. - - .. note:: Experimental + The `schema` should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. :param func: a Python native function that takes two `pandas.DataFrame`\\s, and outputs a `pandas.DataFrame`, or that takes one tuple (grouping keys) and two @@ -191,7 +234,7 @@ def applyInPandas(self, func, schema): :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql.functions import pandas_udf >>> df1 = spark.createDataFrame( ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ... 
("time", "id", "v1")) @@ -232,6 +275,18 @@ def applyInPandas(self, func, schema): |20000102| 1|3.0| x| +--------+---+---+---+ + .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + + .. note:: Experimental + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index 75cacd797f9dd..9835e88c6ac21 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -45,10 +45,10 @@ def mapInPandas(self, func, schema): :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql.functions import pandas_udf >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) - >>> def filter_func(batch_iter): - ... for pdf in batch_iter: + >>> def filter_func(iterator): + ... for pdf in iterator: ... yield pdf[pdf.id == 1] >>> df.mapInPandas(filter_func, df.schema).show() # doctest: +SKIP +---+---+ diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 433c5fc845c59..10546ecacc57f 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -297,17 +297,18 @@ def register(self, name, f, returnType=None): >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP [Row(random_udf()=82)] - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf + >>> @pandas_udf("integer") # doctest: +SKIP + ... def add_one(s: pd.Series) -> pd.Series: + ... return s + 1 ... >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def sum_udf(v): + >>> @pandas_udf("integer") # doctest: +SKIP + ... def sum_udf(v: pd.Series) -> int: ... return v.sum() ... >>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP @@ -414,6 +415,9 @@ def _test(): .appName("sql.udf tests")\ .getOrCreate() globs['spark'] = spark + # Hack to skip the unit tests in register. These are currently being tested in proper tests. + # We should reenable this test once we completely drop Python 2. 
+ del pyspark.sql.udf.UDFRegistration.register (failure_count, test_count) = doctest.testmod( pyspark.sql.udf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) From 6e1b6cc5c55c4d945f59da68d248cc3ef82569d3 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Mon, 10 Feb 2020 10:56:43 -0800 Subject: [PATCH 0055/1280] Revert "[SPARK-30245][SQL] Add cache for Like and RLike when pattern is not static" ### What changes were proposed in this pull request? This reverts commit 8ce7962931680c204e84dd75783b1c943ea9c525. There are variable name conflicts with https://github.com/apache/spark/commit/8aebc80e0e67bcb1aa300b8c8b1a209159237632#diff-39298b470865a4cbc67398a4ea11e767. This can be cleanly ported back to branch-3.0. ### Why are the changes needed? The performance investigation was not thorough enough, and it is not clear whether the change is really beneficial or not. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Jenkins tests. Closes #27514 from HyukjinKwon/revert-cache-PR. Authored-by: HyukjinKwon Signed-off-by: Xiao Li --- .../expressions/regexpExpressions.scala | 21 ++++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index c9ddc70bf5bc6..f84c476ea5807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -177,6 +177,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """) } } else { + val patternStr = ctx.freshName("patternStr") + val compiledPattern = ctx.freshName("compiledPattern") // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'.
@@ -185,17 +187,11 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) } else { escapeChar } - val patternStr = ctx.freshName("patternStr") - val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern") - val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr") - nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => { s""" String $patternStr = $eval2.toString(); - if (!$patternStr.equals($lastPatternStr)) { - $compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar')); - $lastPatternStr = $patternStr; - } + $patternClass $compiledPattern = $patternClass.compile( + $escapeFunc($patternStr, '$newEscapeChar')); ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches(); """ }) @@ -278,16 +274,11 @@ case class RLike(left: Expression, right: Expression) } } else { val rightStr = ctx.freshName("rightStr") - val pattern = ctx.addMutableState(patternClass, "pattern") - val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr") - + val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); - if (!$rightStr.equals($lastRightStr)) { - $pattern = $patternClass.compile($rightStr); - $lastRightStr = $rightStr; - } + $patternClass $pattern = $patternClass.compile($rightStr); ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) From 9c739358487acf3cd2d5171cebdd053c098aeb8c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 11 Feb 2020 10:15:34 -0800 Subject: [PATCH 0056/1280] Revert "[SPARK-30625][SQL] Support `escape` as third parameter of the `like` function In the PR, I propose to revert the commit 8aebc80e0e67bcb1aa300b8c8b1a209159237632. See the concerns https://github.com/apache/spark/pull/27355#issuecomment-584344438 No By existing test suites. Closes #27531 from MaxGekk/revert-like-3-args. 
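For context on this revert: both the reverted three-argument `Like(str, pattern, escape)` form and the restored `Like(left, right, escapeChar)` form in the hunks below ultimately translate the LIKE pattern into a `java.util.regex.Pattern` via `StringUtils.escapeLikeRegex` and `Pattern.compile`. The sketch below is only a rough standalone illustration of that translation; it is not Spark's actual implementation, and `likeToRegex` is a hypothetical helper name.

```scala
import java.util.regex.Pattern

// Minimal sketch of LIKE-with-ESCAPE pattern translation (illustration only,
// not Spark's StringUtils.escapeLikeRegex).
object LikeEscapeSketch {
  def likeToRegex(pattern: String, escapeChar: Char = '\\'): Pattern = {
    val sb = new StringBuilder("(?s)") // let '.' span line breaks, as LIKE does
    val chars = pattern.iterator
    while (chars.hasNext) {
      chars.next() match {
        case c if c == escapeChar && chars.hasNext =>
          sb.append(Pattern.quote(chars.next().toString)) // escaped wildcard is literal
        case '%' => sb.append(".*")                        // '%' matches any sequence
        case '_' => sb.append(".")                         // '_' matches one character
        case c   => sb.append(Pattern.quote(c.toString))   // everything else is literal
      }
    }
    Pattern.compile(sb.toString)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the ESCAPE '/' example in the expression's built-in docs shown in the hunk below.
    val regex = likeToRegex("/%SystemDrive/%//Users%", escapeChar = '/')
    println(regex.matcher("%SystemDrive%/Users/John").matches()) // true
  }
}
```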
Authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/regexpExpressions.scala | 85 ++++++------------- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 15 ---- 4 files changed, 31 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 40998080bc4e3..b4a8bafe22dfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -99,7 +99,7 @@ package object dsl { } def like(other: Expression, escapeChar: Char = '\\'): Expression = - Like(expr, other, Literal(escapeChar.toString)) + Like(expr, other, escapeChar) def rlike(other: Expression): Expression = RLike(expr, other) def contains(other: Expression): Expression = Contains(expr, other) def startsWith(other: Expression): Expression = StartsWith(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f84c476ea5807..32a653dba8fd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -22,7 +22,6 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} @@ -30,19 +29,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends Expression +abstract class StringRegexExpression extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - def str: Expression - def pattern: Expression - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal - private lazy val cache: Pattern = pattern match { + private lazy val cache: Pattern = right match { case Literal(value: String, StringType) => compile(value) case _ => null } @@ -54,9 +51,10 @@ trait StringRegexExpression extends Expression Pattern.compile(escape(str)) } - def nullSafeMatch(input1: Any, input2: Any): Any = { - val s = input2.asInstanceOf[UTF8String].toString - val regex = if (cache == null) compile(s) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { @@ -64,7 +62,7 @@ trait StringRegexExpression extends Expression } } - override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } // scalastyle:off line.contains.tab @@ -109,65 +107,46 @@ trait StringRegexExpression extends Expression true > SELECT '%SystemDrive%/Users/John' _FUNC_ 
'/%SystemDrive/%//Users%' ESCAPE '/'; true - > SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_'); - true """, note = """ Use RLIKE to match with standard regular expressions. """, since = "1.0.0") // scalastyle:on line.contains.tab -case class Like(str: Expression, pattern: Expression, escape: Expression) - extends TernaryExpression with StringRegexExpression { - - def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\")) - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = Seq(str, pattern, escape) +case class Like(left: Expression, right: Expression, escapeChar: Char) + extends StringRegexExpression { - private lazy val escapeChar: Char = if (escape.foldable) { - escape.eval() match { - case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0) - case s => throw new AnalysisException( - s"The 'escape' parameter must be a string literal of one char but it is $s.") - } - } else { - throw new AnalysisException("The 'escape' parameter must be a string literal.") - } + def this(left: Expression, right: Expression) = this(left, right, '\\') override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = escapeChar match { - case '\\' => s"$str LIKE $pattern" - case c => s"$str LIKE $pattern ESCAPE '$c'" - } - - protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - nullSafeMatch(input1, input2) + case '\\' => s"$left LIKE $right" + case c => s"$left LIKE $right ESCAPE '$c'" } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - if (pattern.foldable) { - val patternVal = pattern.eval() - if (patternVal != null) { + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { val regexStr = - StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString())) - val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern", + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + val pattern = ctx.addMutableState(patternClass, "patternLike", v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = str.genCode(ctx) + val eval = left.genCode(ctx) ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches(); + ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } """) } else { @@ -177,8 +156,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """) } } else { - val patternStr = ctx.freshName("patternStr") - val compiledPattern = ctx.freshName("compiledPattern") + val pattern = ctx.freshName("pattern") + val rightStr = ctx.freshName("rightStr") // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. 
@@ -187,12 +166,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) } else { escapeChar } - nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $patternStr = $eval2.toString(); - $patternClass $compiledPattern = $patternClass.compile( - $escapeFunc($patternStr, '$newEscapeChar')); - ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile( + $escapeFunc($rightStr, '$newEscapeChar')); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -231,20 +210,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """, since = "1.0.0") // scalastyle:on line.contains.tab -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def str: Expression = left - override def pattern: Expression = right +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6fc65e14868e0..62e568587fcc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging throw new ParseException("Invalid escape string." 
+ "Escape string must contains only one character.", ctx) } - str + str.charAt(0) }.getOrElse('\\') - invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar))) + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) case SqlBaseParser.RLIKE => invertIfNotDefined(RLike(e, expression(ctx.pattern))) case SqlBaseParser.NULL if ctx.NOT != null => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9e9d8c3e9a7c5..6012678341ccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3560,21 +3560,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(1))) } - test("the like function with the escape parameter") { - val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape") - checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true)) - - val longEscapeError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, '@%')").collect() - }.getMessage - assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char")) - - val nonFoldableError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, escape)").collect() - }.getMessage - assert(nonFoldableError.contains("The 'escape' parameter must be a string literal")) - } - test("SPARK-29462: Empty array of NullType for array function with no arguments") { Seq((true, StringType), (false, NullType)).foreach { case (arrayDefaultToString, expectedType) => From 0608361e5b14dcb08d631701894bd1cea7e39fdd Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 12 Feb 2020 15:19:16 +0900 Subject: [PATCH 0057/1280] [SPARK-30795][SQL] Spark SQL codegen's code() interpolator should treat escapes like Scala's StringContext.s() ### What changes were proposed in this pull request? This PR proposes to make the `code` string interpolator treat escapes the same way as Scala's builtin `StringContext.s()` string interpolator. This will remove the need for an ugly workaround in `Like` expression's codegen. ### Why are the changes needed? The `code()` string interpolator in Spark SQL's code generator should treat escapes like Scala's builtin `StringContext.s()` interpolator, i.e. it should treat escapes in the code parts, and should not treat escapes in the input arguments. For example, ```scala val arg = "This is an argument." val str = s"This is string part 1. $arg This is string part 2." val code = code"This is string part 1. $arg This is string part 2." assert(code.toString == str) ``` We should expect the `code()` interpolator to produce the same result as the `StringContext.s()` interpolator, where only escapes in the string parts should be treated, while the args should be kept verbatim. But in the current implementation, due to the eager folding of code parts and literal input args, the escape treatment is incorrectly done on both code parts and literal args. That causes a problem when an arg contains escape sequences and wants to preserve that in the final produced code string. For example, in `Like` expression's codegen, there's an ugly workaround for this bug: ```scala // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. 
val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { s"""\\\\\\$escapeChar""" } else { escapeChar } ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a new unit test case in `CodeBlockSuite`. Closes #27544 from rednaxelafx/fix-code-string-interpolator. Authored-by: Kris Mok Signed-off-by: HyukjinKwon --- .../sql/catalyst/expressions/codegen/javaCode.scala | 13 +++++++++---- .../catalyst/expressions/regexpExpressions.scala | 13 ++++--------- .../expressions/codegen/CodeBlockSuite.scala | 12 ++++++++++++ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index d9393b9df6bbd..dff258902a0b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -223,6 +223,11 @@ object Block { implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { + /** + * A string interpolator that retains references to the `JavaCode` inputs, and behaves like + * the Scala builtin StringContext.s() interpolator otherwise, i.e. it will treat escapes in + * the code parts, and will not treat escapes in the input arguments. + */ def code(args: Any*): Block = { sc.checkLengths(args) if (sc.parts.length == 0) { @@ -250,7 +255,7 @@ object Block { val inputs = args.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) while (strings.hasNext) { val input = inputs.next input match { @@ -262,7 +267,7 @@ object Block { case _ => buf.append(input) } - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) } codeParts += buf.toString @@ -286,10 +291,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends val strings = codeParts.iterator val inputs = blockInputs.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) while (strings.hasNext) { buf.append(inputs.next) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) } buf.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 32a653dba8fd4..ac620b10cfd2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -158,19 +158,14 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) } else { val pattern = ctx.freshName("pattern") val rightStr = ctx.freshName("rightStr") - // We need double escape to avoid org.codehaus.commons.compiler.CompileException. - // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. - // '\"' will cause exception 'Line break in literal not allowed'. 
- val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { - s"""\\\\\\$escapeChar""" - } else { - escapeChar - } + // We need to escape the escapeChar to make sure the generated code is valid. + // Otherwise we'll hit org.codehaus.commons.compiler.CompileException. + val escapedEscapeChar = StringEscapeUtils.escapeJava(escapeChar.toString) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile( - $escapeFunc($rightStr, '$newEscapeChar')); + $escapeFunc($rightStr, '$escapedEscapeChar')); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 55569b6f2933e..67e3bc69543e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -37,6 +37,18 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) } + test("Code parts should be treated for escapes, but string inputs shouldn't be") { + val strlit = raw"\\" + val code = code"""String s = "foo\\bar" + "$strlit";""" + + val builtin = s"""String s = "foo\\bar" + "$strlit";""" + + val expected = raw"""String s = "foo\bar" + "\\";""" + + assert(builtin == expected) + assert(code.asInstanceOf[CodeBlock].toString == expected) + } + test("Block.stripMargin") { val isNull = JavaCode.isNullVariable("expr1_isNull") val value = JavaCode.variable("expr1", IntegerType) From 7c5d7d78ddc403a3e3701b2e8dc1f4b2885e1a84 Mon Sep 17 00:00:00 2001 From: turbofei Date: Wed, 12 Feb 2020 20:21:52 +0900 Subject: [PATCH 0058/1280] [SPARK-29542][FOLLOW-UP] Keep the description of spark.sql.files.* in tuning guide be consistent with that in SQLConf ### What changes were proposed in this pull request? This pr is a follow up of https://github.com/apache/spark/pull/26200. In this PR, I modify the description of spark.sql.files.* in sql-performance-tuning.md to keep consistent with that in SQLConf. ### Why are the changes needed? To keep consistent with the description in SQLConf. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existed UT. Closes #27545 from turboFei/SPARK-29542-follow-up. Authored-by: turbofei Signed-off-by: HyukjinKwon (cherry picked from commit 8b1839728acaa5e61f542a7332505289726d3162) Signed-off-by: HyukjinKwon --- docs/sql-performance-tuning.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index e289854c7acc7..5a86c0cc31e12 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -67,6 +67,7 @@ that these options will be deprecated in future release as more optimizations ar 134217728 (128 MB) The maximum number of bytes to pack into a single partition when reading files. + This configuration is effective only when using file-based sources such as Parquet, JSON and ORC. @@ -76,7 +77,8 @@ that these options will be deprecated in future release as more optimizations ar The estimated cost to open a file, measured by the number of bytes could be scanned in the same time. This is used when putting multiple files into a partition. 
It is better to over-estimated, then the partitions with small files will be faster than partitions with bigger files (which is - scheduled first). + scheduled first). This configuration is effective only when using file-based sources such as Parquet, + JSON and ORC.
From 2a059e65bae93ddb61f7154d81da3fa0c2dcb669 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 12 Feb 2020 20:12:38 +0800 Subject: [PATCH 0059/1280] [SPARK-30788][SQL] Support `SimpleDateFormat` and `FastDateFormat` as legacy date/timestamp formatters
### What changes were proposed in this pull request? In this PR, I propose to add legacy date/timestamp formatters based on `SimpleDateFormat` and `FastDateFormat`: - `LegacyFastTimestampFormatter` - uses `FastDateFormat` and supports parsing/formatting in microsecond precision. The code was borrowed from Spark 2.4, see https://github.com/apache/spark/pull/26507 & https://github.com/apache/spark/pull/26582 - `LegacySimpleTimestampFormatter` - uses `SimpleDateFormat` and supports the `lenient` mode. When the `lenient` parameter is set to `false`, the parser becomes much stricter in checking its input (see the sketch below).
### Why are the changes needed? Spark 2.4.x uses the following parsers for parsing/formatting date/timestamp strings: - `DateTimeFormat` in the CSV/JSON datasources - `SimpleDateFormat` - used in the JDBC datasource and in partition parsing. - `SimpleDateFormat` in strict mode (`lenient = false`), see https://github.com/apache/spark/blob/branch-2.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L124. It is used by the `date_format`, `from_unixtime`, `unix_timestamp` and `to_unix_timestamp` functions. The PR aims to make Spark 3.0 compatible with Spark 2.4.x in all those cases when `spark.sql.legacy.timeParser.enabled` is set to `true`.
### Does this PR introduce any user-facing change? This shouldn't change behavior with default settings. If `spark.sql.legacy.timeParser.enabled` is set to `true`, users should observe the behavior of Spark 2.4.
### How was this patch tested? - Modified tests in `DateExpressionsSuite` to check the legacy parser - `SimpleDateFormat`. - Added `CSVLegacyTimeParserSuite` and `JsonLegacyTimeParserSuite` to run `CSVSuite` and `JsonSuite` with the legacy parser - `FastDateFormat`.
Closes #27524 from MaxGekk/timestamp-formatter-legacy-fallback.
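To make the `lenient` distinction concrete, here is a minimal, self-contained sketch over the plain JDK `SimpleDateFormat` (not Spark's `LegacySimpleTimestampFormatter` wrapper); the object name, pattern and sample input are illustrative only and not part of the patch:
```
import java.text.{ParseException, SimpleDateFormat}
import java.util.Locale

object LenientVsStrictExample {
  def main(args: Array[String]): Unit = {
    val pattern = "yyyy-MM-dd HH:mm:ss"

    // Lenient (the SimpleDateFormat default): out-of-range fields roll over, so
    // hour 25 on 2015-07-24 silently parses as 01:02:02 on 2015-07-25.
    val lenient = new SimpleDateFormat(pattern, Locale.US)
    lenient.setLenient(true)
    println(lenient.parse("2015-07-24 25:02:02"))

    // Strict (lenient = false): the same input is rejected with a ParseException.
    val strict = new SimpleDateFormat(pattern, Locale.US)
    strict.setLenient(false)
    try strict.parse("2015-07-24 25:02:02")
    catch { case e: ParseException => println(s"Rejected: ${e.getMessage}") }
  }
}
```
When `spark.sql.legacy.timeParser.enabled` is `true`, the datetime expressions in this patch select the strict `SIMPLE_DATE_FORMAT` variant while the CSV/JSON readers select `FAST_DATE_FORMAT`, which is what the modified `DateExpressionsSuite`, `CSVSuite` and `JsonSuite` tests exercise.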
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit c1986204e59f1e8cc4b611d5a578cb248cb74c28) Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/CSVInferSchema.scala | 4 +- .../spark/sql/catalyst/csv/CSVOptions.scala | 4 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 7 +- .../sql/catalyst/csv/UnivocityParser.scala | 7 +- .../expressions/datetimeExpressions.scala | 52 ++- .../spark/sql/catalyst/json/JSONOptions.scala | 4 +- .../sql/catalyst/json/JacksonGenerator.scala | 7 +- .../sql/catalyst/json/JacksonParser.scala | 7 +- .../sql/catalyst/json/JsonInferSchema.scala | 4 +- .../sql/catalyst/util/DateFormatter.scala | 66 ++- .../catalyst/util/TimestampFormatter.scala | 132 +++++- .../org/apache/spark/sql/types/Decimal.scala | 2 +- .../expressions/DateExpressionsSuite.scala | 390 ++++++++++-------- .../org/apache/spark/sql/functions.scala | 7 +- .../resources/test-data/bad_after_good.csv | 2 +- .../resources/test-data/value-malformed.csv | 2 +- .../apache/spark/sql/DateFunctionsSuite.scala | 346 +++++++++------- .../execution/datasources/csv/CSVSuite.scala | 23 +- .../datasources/json/JsonSuite.scala | 7 + 19 files changed, 654 insertions(+), 419 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 03cc3cbdf790a..c6a03183ab45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -24,6 +24,7 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ @@ -32,7 +33,8 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private val timestampParser = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val decimalParser = if (options.locale == Locale.US) { // Special handling the default locale for backward compatibility diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 5e40d74e54f11..8892037e03a7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -146,10 +146,10 @@ class CSVOptions( // A language tag in IETF BCP 47 format val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) - val dateFormat: String = parameters.getOrElse("dateFormat", "uuuu-MM-dd") + val dateFormat: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) val timestampFormat: String = - parameters.getOrElse("timestampFormat", "uuuu-MM-dd'T'HH:mm:ss.SSSXXX") + parameters.getOrElse("timestampFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 05cb91d10868e..00e3d49787db1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -23,6 +23,7 @@ import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ class UnivocityGenerator( @@ -44,11 +45,13 @@ class UnivocityGenerator( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 5510953804025..cd69c21a01976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -86,11 +87,13 @@ class UnivocityParser( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val csvFilters = new CSVFilters(filters, requiredSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index aa2bd5a1273e0..1f4c8c041f8bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -30,9 +30,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -622,13 
+623,15 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti @transient private lazy val formatter: Option[TimestampFormatter] = { if (right.foldable) { - Option(right.eval()).map(format => TimestampFormatter(format.toString, zoneId)) + Option(right.eval()).map { format => + TimestampFormatter(format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + } } else None } override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val tf = if (formatter.isEmpty) { - TimestampFormatter(format.toString, zoneId) + TimestampFormatter(format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } else { formatter.get } @@ -643,10 +646,14 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti }) }.getOrElse { val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $zid) - .format($timestamp))""" + s"""|UTF8String.fromString($tf$$.MODULE$$.apply( + | $format.toString(), + | $zid, + | $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + |.format($timestamp))""".stripMargin }) } } @@ -688,7 +695,7 @@ case class ToUnixTimestamp( copy(timeZoneId = Option(timeZoneId)) def this(time: Expression) = { - this(time, Literal("uuuu-MM-dd HH:mm:ss")) + this(time, Literal(TimestampFormatter.defaultPattern)) } override def prettyName: String = "to_unix_timestamp" @@ -732,7 +739,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Op copy(timeZoneId = Option(timeZoneId)) def this(time: Expression) = { - this(time, Literal("uuuu-MM-dd HH:mm:ss")) + this(time, Literal(TimestampFormatter.defaultPattern)) } def this() = { @@ -758,7 +765,7 @@ abstract class ToTimestamp private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: TimestampFormatter = try { - TimestampFormatter(constFormat.toString, zoneId) + TimestampFormatter(constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } catch { case NonFatal(_) => null } @@ -791,8 +798,8 @@ abstract class ToTimestamp } else { val formatString = f.asInstanceOf[UTF8String].toString try { - TimestampFormatter(formatString, zoneId).parse( - t.asInstanceOf[UTF8String].toString) / downScaleFactor + TimestampFormatter(formatString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + .parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor } catch { case NonFatal(_) => null } @@ -831,13 +838,16 @@ abstract class ToTimestamp } case StringType => val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $zid, $locale) - .parse($string.toString()) / $downScaleFactor; + ${ev.value} = $tf$$.MODULE$$.apply( + $format.toString(), + $zid, + $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + .parse($string.toString()) / $downScaleFactor; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } catch (java.text.ParseException e) { @@ -908,7 +918,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def prettyName: String = "from_unixtime" def 
this(unix: Expression) = { - this(unix, Literal("uuuu-MM-dd HH:mm:ss")) + this(unix, Literal(TimestampFormatter.defaultPattern)) } override def dataType: DataType = StringType @@ -922,7 +932,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: TimestampFormatter = try { - TimestampFormatter(constFormat.toString, zoneId) + TimestampFormatter(constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } catch { case NonFatal(_) => null } @@ -948,8 +958,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(TimestampFormatter(f.toString, zoneId) - .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) + UTF8String.fromString( + TimestampFormatter(f.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -980,13 +991,14 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } } else { val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString($tf$$.MODULE$$.apply($f.toString(), $zid, $locale). - format($seconds * 1000000L)); + ${ev.value} = UTF8String.fromString( + $tf$$.MODULE$$.apply($f.toString(), $zid, $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + .format($seconds * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index cdf4b4689e821..45c4edff47070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -88,10 +88,10 @@ private[sql] class JSONOptions( val zoneId: ZoneId = DateTimeUtils.getZoneId( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - val dateFormat: String = parameters.getOrElse("dateFormat", "uuuu-MM-dd") + val dateFormat: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) val timestampFormat: String = - parameters.getOrElse("timestampFormat", "uuuu-MM-dd'T'HH:mm:ss.SSSXXX") + parameters.getOrElse("timestampFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c63593ea1752..141360ff02117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ /** @@ 
-80,11 +81,13 @@ private[sql] class JacksonGenerator( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 76efa574a99ff..1e408cdb126b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -58,11 +59,13 @@ class JacksonParser( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) /** * Create a converter which converts the JSON documents held by the `JsonParser` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index f030955ee6e7f..82dd6d0da2632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,7 +41,8 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) /** * Infer the type of a collection of json records in three stages: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 28189b65dee9a..2cf82d1cfa177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.util +import java.text.SimpleDateFormat import java.time.{LocalDate, ZoneId} -import java.util.Locale +import java.util.{Date, Locale} import 
org.apache.commons.lang3.time.FastDateFormat @@ -51,41 +52,76 @@ class Iso8601DateFormatter( } } -class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { - @transient - private lazy val format = FastDateFormat.getInstance(pattern, locale) +trait LegacyDateFormatter extends DateFormatter { + def parseToDate(s: String): Date + def formatDate(d: Date): String override def parse(s: String): Int = { - val milliseconds = format.parse(s).getTime + val milliseconds = parseToDate(s).getTime DateTimeUtils.millisToDays(milliseconds) } override def format(days: Int): String = { val date = DateTimeUtils.toJavaDate(days) - format.format(date) + formatDate(date) } } +class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { + @transient + private lazy val fdf = FastDateFormat.getInstance(pattern, locale) + override def parseToDate(s: String): Date = fdf.parse(s) + override def formatDate(d: Date): String = fdf.format(d) +} + +class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { + @transient + private lazy val sdf = new SimpleDateFormat(pattern, locale) + override def parseToDate(s: String): Date = sdf.parse(s) + override def formatDate(d: Date): String = sdf.format(d) +} + object DateFormatter { + import LegacyDateFormats._ + val defaultLocale: Locale = Locale.US - def apply(format: String, zoneId: ZoneId, locale: Locale): DateFormatter = { + def defaultPattern(): String = { + if (SQLConf.get.legacyTimeParserEnabled) "yyyy-MM-dd" else "uuuu-MM-dd" + } + + private def getFormatter( + format: Option[String], + zoneId: ZoneId, + locale: Locale = defaultLocale, + legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT): DateFormatter = { + + val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyDateFormatter(format, locale) + legacyFormat match { + case FAST_DATE_FORMAT => + new LegacyFastDateFormatter(pattern, locale) + case SIMPLE_DATE_FORMAT | LENIENT_SIMPLE_DATE_FORMAT => + new LegacySimpleDateFormatter(pattern, locale) + } } else { - new Iso8601DateFormatter(format, zoneId, locale) + new Iso8601DateFormatter(pattern, zoneId, locale) } } + def apply( + format: String, + zoneId: ZoneId, + locale: Locale, + legacyFormat: LegacyDateFormat): DateFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat) + } + def apply(format: String, zoneId: ZoneId): DateFormatter = { - apply(format, zoneId, defaultLocale) + getFormatter(Some(format), zoneId) } def apply(zoneId: ZoneId): DateFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyDateFormatter("yyyy-MM-dd", defaultLocale) - } else { - new Iso8601DateFormatter("uuuu-MM-dd", zoneId, defaultLocale) - } + getFormatter(None, zoneId) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index fe1a4fe710c20..4893a7ec91cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.util -import java.text.ParseException +import java.text.{ParseException, ParsePosition, SimpleDateFormat} import java.time._ import java.time.format.DateTimeParseException import java.time.temporal.ChronoField.MICRO_OF_SECOND import java.time.temporal.TemporalQueries -import 
java.util.{Locale, TimeZone} +import java.util.{Calendar, GregorianCalendar, Locale, TimeZone} import java.util.concurrent.TimeUnit.SECONDS import org.apache.commons.lang3.time.FastDateFormat -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS -import org.apache.spark.sql.catalyst.util.DateTimeUtils.convertSpecialTimestamp +import org.apache.spark.sql.catalyst.util.DateTimeConstants._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ convertSpecialTimestamp, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.Decimal sealed trait TimestampFormatter extends Serializable { /** @@ -90,44 +91,139 @@ class FractionTimestampFormatter(zoneId: ZoneId) override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter } -class LegacyTimestampFormatter( +/** + * The custom sub-class of `GregorianCalendar` is needed to get access to + * protected `fields` immediately after parsing. We cannot use + * the `get()` method because it performs normalization of the fraction + * part. Accordingly, the `MILLISECOND` field doesn't contain original value. + * + * Also this class allows to set raw value to the `MILLISECOND` field + * directly before formatting. + */ +class MicrosCalendar(tz: TimeZone, digitsInFraction: Int) + extends GregorianCalendar(tz, Locale.US) { + // Converts parsed `MILLISECOND` field to seconds fraction in microsecond precision. + // For example if the fraction pattern is `SSSS` then `digitsInFraction` = 4, and + // if the `MILLISECOND` field was parsed to `1234`. + def getMicros(): SQLTimestamp = { + // Append 6 zeros to the field: 1234 -> 1234000000 + val d = fields(Calendar.MILLISECOND) * MICROS_PER_SECOND + // Take the first 6 digits from `d`: 1234000000 -> 123400 + // The rest contains exactly `digitsInFraction`: `0000` = 10 ^ digitsInFraction + // So, the result is `(1234 * 1000000) / (10 ^ digitsInFraction) + d / Decimal.POW_10(digitsInFraction) + } + + // Converts the seconds fraction in microsecond precision to a value + // that can be correctly formatted according to the specified fraction pattern. + // The method performs operations opposite to `getMicros()`. 
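+ // For example, with the same `SSSS` pattern (digitsInFraction = 4) as above, + // micros = 123400 gives d = 123400 * 10^4 = 1234000000, so the raw + // `MILLISECOND` field is set back to (1234000000 / 1000000).toInt = 1234.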
+ def setMicros(micros: Long): Unit = { + val d = micros * Decimal.POW_10(digitsInFraction) + fields(Calendar.MILLISECOND) = (d / MICROS_PER_SECOND).toInt + } +} + +class LegacyFastTimestampFormatter( pattern: String, zoneId: ZoneId, locale: Locale) extends TimestampFormatter { - @transient private lazy val format = + @transient private lazy val fastDateFormat = FastDateFormat.getInstance(pattern, TimeZone.getTimeZone(zoneId), locale) + @transient private lazy val cal = new MicrosCalendar( + fastDateFormat.getTimeZone, + fastDateFormat.getPattern.count(_ == 'S')) + + def parse(s: String): SQLTimestamp = { + cal.clear() // Clear the calendar because it can be re-used many times + if (!fastDateFormat.parse(s, new ParsePosition(0), cal)) { + throw new IllegalArgumentException(s"'$s' is an invalid timestamp") + } + val micros = cal.getMicros() + cal.set(Calendar.MILLISECOND, 0) + cal.getTimeInMillis * MICROS_PER_MILLIS + micros + } + + def format(timestamp: SQLTimestamp): String = { + cal.setTimeInMillis(Math.floorDiv(timestamp, MICROS_PER_SECOND) * MILLIS_PER_SECOND) + cal.setMicros(Math.floorMod(timestamp, MICROS_PER_SECOND)) + fastDateFormat.format(cal) + } +} - protected def toMillis(s: String): Long = format.parse(s).getTime +class LegacySimpleTimestampFormatter( + pattern: String, + zoneId: ZoneId, + locale: Locale, + lenient: Boolean = true) extends TimestampFormatter { + @transient private lazy val sdf = { + val formatter = new SimpleDateFormat(pattern, locale) + formatter.setTimeZone(TimeZone.getTimeZone(zoneId)) + formatter.setLenient(lenient) + formatter + } - override def parse(s: String): Long = toMillis(s) * MICROS_PER_MILLIS + override def parse(s: String): Long = { + sdf.parse(s).getTime * MICROS_PER_MILLIS + } override def format(us: Long): String = { - format.format(DateTimeUtils.toJavaTimestamp(us)) + val timestamp = DateTimeUtils.toJavaTimestamp(us) + sdf.format(timestamp) } } +object LegacyDateFormats extends Enumeration { + type LegacyDateFormat = Value + val FAST_DATE_FORMAT, SIMPLE_DATE_FORMAT, LENIENT_SIMPLE_DATE_FORMAT = Value +} + object TimestampFormatter { + import LegacyDateFormats._ + val defaultLocale: Locale = Locale.US - def apply(format: String, zoneId: ZoneId, locale: Locale): TimestampFormatter = { + def defaultPattern(): String = s"${DateFormatter.defaultPattern()} HH:mm:ss" + + private def getFormatter( + format: Option[String], + zoneId: ZoneId, + locale: Locale = defaultLocale, + legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT): TimestampFormatter = { + + val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyTimestampFormatter(format, zoneId, locale) + legacyFormat match { + case FAST_DATE_FORMAT => + new LegacyFastTimestampFormatter(pattern, zoneId, locale) + case SIMPLE_DATE_FORMAT => + new LegacySimpleTimestampFormatter(pattern, zoneId, locale, lenient = false) + case LENIENT_SIMPLE_DATE_FORMAT => + new LegacySimpleTimestampFormatter(pattern, zoneId, locale, lenient = true) + } } else { - new Iso8601TimestampFormatter(format, zoneId, locale) + new Iso8601TimestampFormatter(pattern, zoneId, locale) } } + def apply( + format: String, + zoneId: ZoneId, + locale: Locale, + legacyFormat: LegacyDateFormat): TimestampFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat) + } + + def apply(format: String, zoneId: ZoneId, legacyFormat: LegacyDateFormat): TimestampFormatter = { + getFormatter(Some(format), zoneId, defaultLocale, legacyFormat) + } + def apply(format: String, zoneId: 
ZoneId): TimestampFormatter = { - apply(format, zoneId, defaultLocale) + getFormatter(Some(format), zoneId) } def apply(zoneId: ZoneId): TimestampFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyTimestampFormatter("yyyy-MM-dd HH:mm:ss", zoneId, defaultLocale) - } else { - new Iso8601TimestampFormatter("uuuu-MM-dd HH:mm:ss", zoneId, defaultLocale) - } + getFormatter(None, zoneId) } def getFractionFormatter(zoneId: ZoneId): TimestampFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9ce64b09f7870..f32e48e1cc128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -541,7 +541,7 @@ object Decimal { /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 - private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) private val BIG_DEC_ZERO = BigDecimal(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 274d0beebd300..f04149ab7eb29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, Timesta import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -241,41 +242,45 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("DateFormat") { - checkEvaluation( - DateFormatClass(Literal.create(null, TimestampType), Literal("y"), gmtId), - null) - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal.create(null, StringType), gmtId), null) - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal("y"), gmtId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), gmtId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal("H"), gmtId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), gmtId), "13") - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), - Literal("y"), pstId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), pstId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), - Literal("H"), pstId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), pstId), "5") - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), - Literal("y"), jstId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), jstId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), - Literal("H"), jstId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), jstId), "22") - - // SPARK-28072 The codegen path should 
work - checkEvaluation( - expression = DateFormatClass( - BoundReference(ordinal = 0, dataType = TimestampType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - jstId), - expected = "22", - inputRow = InternalRow(DateTimeUtils.fromJavaTimestamp(ts), UTF8String.fromString("H"))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkEvaluation( + DateFormatClass(Literal.create(null, TimestampType), Literal("y"), gmtId), + null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal.create(null, StringType), gmtId), null) + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("y"), gmtId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), gmtId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("H"), gmtId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), gmtId), "13") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("y"), pstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), pstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("H"), pstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), pstId), "5") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("y"), jstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), jstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("H"), jstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), jstId), "22") + + // SPARK-28072 The codegen path should work + checkEvaluation( + expression = DateFormatClass( + BoundReference(ordinal = 0, dataType = TimestampType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + jstId), + expected = "22", + inputRow = InternalRow(DateTimeUtils.fromJavaTimestamp(ts), UTF8String.fromString("H"))) + } + } } test("Hour") { @@ -705,162 +710,189 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val fmt1 = "yyyy-MM-dd HH:mm:ss" - val sdf1 = new SimpleDateFormat(fmt1, Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) - checkEvaluation( - FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId), - sdf1.format(new Timestamp(0))) - checkEvaluation(FromUnixTime( - Literal(1000L), Literal(fmt1), timeZoneId), - sdf1.format(new Timestamp(1000000))) - checkEvaluation( - FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), - sdf2.format(new Timestamp(-1000000))) - checkEvaluation( - FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - 
FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId), - null) - checkEvaluation( - FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId), + sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal(fmt1), timeZoneId), + sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), + sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime( + Literal.create(null, LongType), + Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) - // SPARK-28072 The codegen path for non-literal input should also work - checkEvaluation( - expression = FromUnixTime( - BoundReference(ordinal = 0, dataType = LongType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - timeZoneId), - expected = UTF8String.fromString(sdf1.format(new Timestamp(0))), - inputRow = InternalRow(0L, UTF8String.fromString(fmt1))) + // SPARK-28072 The codegen path for non-literal input should also work + checkEvaluation( + expression = FromUnixTime( + BoundReference(ordinal = 0, dataType = LongType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = UTF8String.fromString(sdf1.format(new Timestamp(0))), + inputRow = InternalRow(0L, UTF8String.fromString(fmt1))) + } + } } } test("unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - sdf3.setTimeZone(TimeZoneGMT) - - withDefaultTimeZone(TimeZoneGMT) { - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) - - val date1 = Date.valueOf("2015-07-24") - checkEvaluation(UnixTimestamp( - Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) - checkEvaluation(UnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - 1000L) - checkEvaluation( - UnixTimestamp( - Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - 1000L) - checkEvaluation( - UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), - -1000L) - checkEvaluation(UnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) - val t1 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - val t2 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - 
checkEvaluation( - UnixTimestamp( - Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - null) - checkEvaluation( - UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp( + Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), + Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + UnixTimestamp( + Literal.create(null, DateType), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + null) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + } + } } } } test("to_unix_timestamp") { - val fmt1 = "yyyy-MM-dd HH:mm:ss" - val sdf1 = new SimpleDateFormat(fmt1, Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - sdf3.setTimeZone(TimeZoneGMT) - - withDefaultTimeZone(TimeZoneGMT) { - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) - - val date1 = Date.valueOf("2015-07-24") - checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L) - 
checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId), - 1000L) - checkEvaluation(ToUnixTimestamp( - Literal(new Timestamp(1000000)), Literal(fmt1)), - 1000L) - checkEvaluation( - ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), - -1000L) - checkEvaluation(ToUnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) - val t1 = ToUnixTimestamp( - CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] - val t2 = ToUnixTimestamp( - CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - checkEvaluation(ToUnixTimestamp( - Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) - checkEvaluation( - ToUnixTimestamp( - Literal.create(null, DateType), Literal(fmt1), timeZoneId), - null) - checkEvaluation(ToUnixTimestamp( - Literal(date1), Literal.create(null, StringType), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId), + 1000L) + checkEvaluation(ToUnixTimestamp( + Literal(new Timestamp(1000000)), Literal(fmt1)), + 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + ToUnixTimestamp( + Literal(sdf2.format(new Timestamp(-1000000))), + Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation(ToUnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) + checkEvaluation( + ToUnixTimestamp( + Literal.create(null, DateType), Literal(fmt1), timeZoneId), + null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType), timeZoneId), + MILLISECONDS.toSeconds( + 
DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + ToUnixTimestamp( + Literal("2015-07-24"), + Literal("not a valid format"), timeZoneId), null) - // SPARK-28072 The codegen path for non-literal input should also work - checkEvaluation( - expression = ToUnixTimestamp( - BoundReference(ordinal = 0, dataType = StringType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - timeZoneId), - expected = 0L, - inputRow = InternalRow( - UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1))) + // SPARK-28072 The codegen path for non-literal input should also work + checkEvaluation( + expression = ToUnixTimestamp( + BoundReference(ordinal = 0, dataType = StringType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = 0L, + inputRow = InternalRow( + UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1))) + } + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d125581857e0b..2d5504ac00ffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf @@ -2881,7 +2882,7 @@ object functions { * @since 1.5.0 */ def from_unixtime(ut: Column): Column = withExpr { - FromUnixTime(ut.expr, Literal("uuuu-MM-dd HH:mm:ss")) + FromUnixTime(ut.expr, Literal(TimestampFormatter.defaultPattern)) } /** @@ -2913,7 +2914,7 @@ object functions { * @since 1.5.0 */ def unix_timestamp(): Column = withExpr { - UnixTimestamp(CurrentTimestamp(), Literal("uuuu-MM-dd HH:mm:ss")) + UnixTimestamp(CurrentTimestamp(), Literal(TimestampFormatter.defaultPattern)) } /** @@ -2927,7 +2928,7 @@ object functions { * @since 1.5.0 */ def unix_timestamp(s: Column): Column = withExpr { - UnixTimestamp(s.expr, Literal("uuuu-MM-dd HH:mm:ss")) + UnixTimestamp(s.expr, Literal(TimestampFormatter.defaultPattern)) } /** diff --git a/sql/core/src/test/resources/test-data/bad_after_good.csv b/sql/core/src/test/resources/test-data/bad_after_good.csv index 4621a7d23714d..1a7c2651a11a7 100644 --- a/sql/core/src/test/resources/test-data/bad_after_good.csv +++ b/sql/core/src/test/resources/test-data/bad_after_good.csv @@ -1,2 +1,2 @@ "good record",1999-08-01 -"bad record",1999-088-01 +"bad record",1999-088_01 diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv index 8945ed73d2e83..6e6f08fca6df8 100644 --- a/sql/core/src/test/resources/test-data/value-malformed.csv +++ b/sql/core/src/test/resources/test-data/value-malformed.csv @@ -1,2 +1,2 @@ -0,2013-111-11 12:13:14 +0,2013-111_11 12:13:14 1,1983-08-04 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala 
index bb8cdf3cb6de1..41d53c959ef99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -96,15 +96,19 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { } test("date format") { - val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") - checkAnswer( - df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), - Row("2015", "2015", "2013")) + checkAnswer( + df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), + Row("2015", "2015", "2013")) - checkAnswer( - df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), - Row("2015", "2015", "2013")) + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + } } test("year") { @@ -525,170 +529,194 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") - checkAnswer( - df.select(from_unixtime(col("a"))), - Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt2)), - Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt3)), - Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr("from_unixtime(a)"), - Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt2')"), - Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt3')"), - Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + 
df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + } } private def secs(millis: Long): Long = TimeUnit.MILLISECONDS.toSeconds(millis) test("unix_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") - checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - - val x1 = "2015-07-24 10:00:00" - val x2 = "2015-25-07 02:02:02" - val x3 = "2015-07-24 25:02:02" - val x4 = "2015-24-07 26:02:02" - val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") - val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") - - val df1 = Seq(x1, x2, x3, x4).toDF("x") - checkAnswer(df1.select(unix_timestamp(col("x"))), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.selectExpr("unix_timestamp(x)"), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.select(unix_timestamp(col("x"), "yyyy-dd-MM HH:mm:ss")), Seq( - Row(null), Row(secs(ts2.getTime)), Row(null), Row(null))) - checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( - Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) - - // invalid format - checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) - - // february - val y1 = "2016-02-29" - val y2 = "2017-02-29" - val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") - val df2 = Seq(y1, y2).toDF("y") - checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( - Row(secs(ts5.getTime)), Row(null))) - - val now = sql("select unix_timestamp()").collect().head.getLong(0) - checkAnswer( - sql(s"select cast ($now as timestamp)"), - Row(new java.util.Date(TimeUnit.SECONDS.toMillis(now)))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val 
ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + + val x1 = "2015-07-24 10:00:00" + val x2 = "2015-25-07 02:02:02" + val x3 = "2015-07-24 25:02:02" + val x4 = "2015-24-07 26:02:02" + val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") + val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") + + val df1 = Seq(x1, x2, x3, x4).toDF("x") + checkAnswer(df1.select(unix_timestamp(col("x"))), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.selectExpr("unix_timestamp(x)"), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.select(unix_timestamp(col("x"), "yyyy-dd-MM HH:mm:ss")), Seq( + Row(null), Row(secs(ts2.getTime)), Row(null), Row(null))) + checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( + Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) + + // invalid format + checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')"), Seq( + Row(null), Row(null), Row(null), Row(null))) + + // february + val y1 = "2016-02-29" + val y2 = "2017-02-29" + val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") + val df2 = Seq(y1, y2).toDF("y") + checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( + Row(secs(ts5.getTime)), Row(null))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer( + sql(s"select cast ($now as timestamp)"), + Row(new java.util.Date(TimeUnit.SECONDS.toMillis(now)))) + } + } } test("to_unix_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") - checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - - val x1 = "2015-07-24 10:00:00" - val x2 = "2015-25-07 02:02:02" - val x3 = "2015-07-24 25:02:02" - val x4 = 
"2015-24-07 26:02:02" - val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") - val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") - - val df1 = Seq(x1, x2, x3, x4).toDF("x") - checkAnswer(df1.selectExpr("to_unix_timestamp(x)"), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( - Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) - - // february - val y1 = "2016-02-29" - val y2 = "2017-02-29" - val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") - val df2 = Seq(y1, y2).toDF("y") - checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( - Row(secs(ts5.getTime)), Row(null))) - - // invalid format - checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + + val x1 = "2015-07-24 10:00:00" + val x2 = "2015-25-07 02:02:02" + val x3 = "2015-07-24 25:02:02" + val x4 = "2015-24-07 26:02:02" + val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") + val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") + + val df1 = Seq(x1, x2, x3, x4).toDF("x") + checkAnswer(df1.selectExpr("to_unix_timestamp(x)"), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( + Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) + + // february + val y1 = "2016-02-29" + val y2 = "2017-02-29" + val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") + val df2 = Seq(y1, y2).toDF("y") + checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( + Row(secs(ts5.getTime)), Row(null))) + + // invalid format + checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')"), Seq( + Row(null), Row(null), Row(null), Row(null))) + } + } } test("to_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts_date1 = Timestamp.valueOf("2015-07-24 00:00:00") - val ts_date2 = Timestamp.valueOf("2015-07-25 00:00:00") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ts1m = Timestamp.valueOf("2015-07-24 10:00:00.5") - val ts2m = Timestamp.valueOf("2015-07-25 02:02:02.6") - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, 
ss2)).toDF("d", "ts", "s", "ss") - - checkAnswer(df.select(to_timestamp(col("ss"))), - df.select(unix_timestamp(col("ss")).cast("timestamp"))) - checkAnswer(df.select(to_timestamp(col("ss"))), Seq( - Row(ts1), Row(ts2))) - checkAnswer(df.select(to_timestamp(col("s"), fmt)), Seq( - Row(ts1m), Row(ts2m))) - checkAnswer(df.select(to_timestamp(col("ts"), fmt)), Seq( - Row(ts1), Row(ts2))) - checkAnswer(df.select(to_timestamp(col("d"), "yyyy-MM-dd")), Seq( - Row(ts_date1), Row(ts_date2))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts_date1 = Timestamp.valueOf("2015-07-24 00:00:00") + val ts_date2 = Timestamp.valueOf("2015-07-25 00:00:00") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ts1m = Timestamp.valueOf("2015-07-24 10:00:00.5") + val ts2m = Timestamp.valueOf("2015-07-25 02:02:02.6") + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + + checkAnswer(df.select(to_timestamp(col("ss"))), + df.select(unix_timestamp(col("ss")).cast("timestamp"))) + checkAnswer(df.select(to_timestamp(col("ss"))), Seq( + Row(ts1), Row(ts2))) + if (legacyParser) { + // In Spark 2.4 and earlier, to_timestamp() parses in seconds precision and cuts off + // the fractional part of seconds. The behavior was changed by SPARK-27438. + val legacyFmt = "yyyy/MM/dd HH:mm:ss" + checkAnswer(df.select(to_timestamp(col("s"), legacyFmt)), Seq( + Row(ts1), Row(ts2))) + } else { + checkAnswer(df.select(to_timestamp(col("s"), fmt)), Seq( + Row(ts1m), Row(ts2m))) + } + checkAnswer(df.select(to_timestamp(col("ts"), fmt)), Seq( + Row(ts1), Row(ts2))) + checkAnswer(df.select(to_timestamp(col("d"), "yyyy-MM-dd")), Seq( + Row(ts_date1), Row(ts_date2))) + } + } } test("datediff") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 97dfbbdb7fd2f..b1105b4a63bba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1182,7 +1182,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, - Row(0, null, "0,2013-111-11 12:13:14") :: + Row(0, null, "0,2013-111_11 12:13:14") :: Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: Nil) @@ -1199,7 +1199,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, - Row(0, "0,2013-111-11 12:13:14", null) :: + Row(0, "0,2013-111_11 12:13:14", null) :: Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: Nil) @@ -1435,7 +1435,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa assert(df.filter($"_corrupt_record".isNull).count() == 1) checkAnswer( df.select(columnNameOfCorruptRecord), - Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil + Row("0,2013-111_11 12:13:14") :: Row(null) :: Nil ) } @@ -2093,7 +2093,7 @@ abstract class CSVSuite 
extends QueryTest with SharedSparkSession with TestCsvDa Seq("csv", "").foreach { reader => withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> reader) { withTempPath { path => - val df = Seq(("0", "2013-111-11")).toDF("a", "b") + val df = Seq(("0", "2013-111_11")).toDF("a", "b") df.write .option("header", "true") .csv(path.getAbsolutePath) @@ -2109,7 +2109,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schemaWithCorrField) .csv(path.getAbsoluteFile.toString) - checkAnswer(readDF, Row(0, null, "0,2013-111-11") :: Nil) + checkAnswer(readDF, Row(0, null, "0,2013-111_11") :: Nil) } } } @@ -2216,7 +2216,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa val readback = spark.read .option("mode", mode) .option("header", true) - .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") .option("multiLine", multiLine) .schema("c0 string, c1 integer, c2 timestamp") .csv(path.getAbsolutePath) @@ -2235,7 +2235,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } test("filters push down - malformed input in PERMISSIVE mode") { - val invalidTs = "2019-123-14 20:35:30" + val invalidTs = "2019-123_14 20:35:30" val invalidRow = s"0,$invalidTs,999" val validTs = "2019-12-14 20:35:30" Seq(true, false).foreach { filterPushdown => @@ -2252,7 +2252,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", "c3") .option("header", true) - .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") .schema("c0 integer, c1 timestamp, c2 integer, c3 string") .csv(path.getAbsolutePath) .where(condition) @@ -2309,3 +2309,10 @@ class CSVv2Suite extends CSVSuite { .sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") } + +class CSVLegacyTimeParserSuite extends CSVSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.LEGACY_TIME_PARSER_ENABLED, true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b20da2266b0f3..7abe818a29d9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2572,3 +2572,10 @@ class JsonV2Suite extends JsonSuite { .sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") } + +class JsonLegacyTimeParserSuite extends JsonSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.LEGACY_TIME_PARSER_ENABLED, true) +} From 258bfcfe4a87fe1d6a0bc27afb97e6b223e420e8 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Thu, 13 Feb 2020 02:00:23 +0800 Subject: [PATCH 0060/1280] [SPARK-30651][SQL] Add detailed information for Aggregate operators in EXPLAIN FORMATTED ### What changes were proposed in this pull request? Currently `EXPLAIN FORMATTED` only report input attributes of HashAggregate/ObjectHashAggregate/SortAggregate, while `EXPLAIN EXTENDED` provides more information of Keys, Functions, etc. This PR enhanced `EXPLAIN FORMATTED` to sync with original explain behavior. ### Why are the changes needed? 
The newly added `EXPLAIN FORMATTED` got less information comparing to the original `EXPLAIN EXTENDED` ### Does this PR introduce any user-facing change? Yes, taking HashAggregate explain result as example. **SQL** ``` EXPLAIN FORMATTED SELECT COUNT(val) + SUM(key) as TOTAL, COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1; ``` **EXPLAIN EXTENDED** ``` == Physical Plan == *(2) HashAggregate(keys=[], functions=[count(val#6), sum(cast(key#5 as bigint)), count(key#5)], output=[TOTAL#62L, count(key) FILTER (WHERE (val > 1))#71L]) +- Exchange SinglePartition, true, [id=#89] +- HashAggregate(keys=[], functions=[partial_count(val#6), partial_sum(cast(key#5 as bigint)), partial_count(key#5) FILTER (WHERE (val#6 > 1))], output=[count#75L, sum#76L, count#77L]) +- *(1) ColumnarToRow +- FileScan parquet default.explain_temp1[key#5,val#6] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/XXX/spark-dev/spark/spark-warehouse/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` **EXPLAIN FORMATTED - BEFORE** ``` == Physical Plan == * HashAggregate (5) +- Exchange (4) +- HashAggregate (3) +- * ColumnarToRow (2) +- Scan parquet default.explain_temp1 (1) ... ... (5) HashAggregate [codegen id : 2] Input: [count#91L, sum#92L, count#93L] ... ... ``` **EXPLAIN FORMATTED - AFTER** ``` == Physical Plan == * HashAggregate (5) +- Exchange (4) +- HashAggregate (3) +- * ColumnarToRow (2) +- Scan parquet default.explain_temp1 (1) ... ... (5) HashAggregate [codegen id : 2] Input: [count#91L, sum#92L, count#93L] Keys: [] Functions: [count(val#6), sum(cast(key#5 as bigint)), count(key#5)] Results: [(count(val#6)#84L + sum(cast(key#5 as bigint))#85L) AS TOTAL#78L, count(key#5)#86L AS count(key) FILTER (WHERE (val > 1))#87L] Output: [TOTAL#78L, count(key) FILTER (WHERE (val > 1))#87L] ... ... ``` ### How was this patch tested? Three tests added in explain.sql for HashAggregate/ObjectHashAggregate/SortAggregate. Closes #27368 from Eric5553/ExplainFormattedAgg. Authored-by: Eric Wu <492960551@qq.com> Signed-off-by: Wenchen Fan (cherry picked from commit 5919bd3b8d3ef3c3e957d8e3e245e00383b979bf) Signed-off-by: Wenchen Fan --- .../aggregate/BaseAggregateExec.scala | 48 ++++ .../aggregate/HashAggregateExec.scala | 2 +- .../aggregate/ObjectHashAggregateExec.scala | 2 +- .../aggregate/SortAggregateExec.scala | 4 +- .../resources/sql-tests/inputs/explain.sql | 22 +- .../sql-tests/results/explain.sql.out | 232 +++++++++++++++++- 6 files changed, 300 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala new file mode 100644 index 0000000000000..0eaa0f53fdacd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} + +/** + * Holds common logic for aggregate operators + */ +trait BaseAggregateExec extends UnaryExecNode { + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def resultExpressions: Seq[NamedExpression] + + override def verboseStringWithOperatorId(): String = { + val inputString = child.output.mkString("[", ", ", "]") + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = aggregateExpressions.mkString("[", ", ", "]") + val aggregateAttributeString = aggregateAttributes.mkString("[", ", ", "]") + val resultString = resultExpressions.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Keys: $keyString + |Functions: $functionString + |Aggregate Attributes: $aggregateAttributeString + |Results: $resultString + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f73e214a6b41f..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,7 +53,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 4376f6b6edd57..3fb58eb2cc8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,7 +67,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index b6e684e62ea5c..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index d5253e3daddb0..497b61c6134a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -5,6 +5,7 @@ CREATE table explain_temp1 (key int, val int) USING PARQUET; CREATE table explain_temp2 (key int, val int) USING PARQUET; CREATE table explain_temp3 (key int, val int) USING PARQUET; +CREATE table explain_temp4 (key int, val string) USING PARQUET; SET spark.sql.codegen.wholeStage = true; @@ -61,7 +62,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0); @@ -93,6 +94,25 @@ EXPLAIN FORMATTED CREATE VIEW explain_view AS SELECT key, val FROM explain_temp1; +-- HashAggregate +EXPLAIN FORMATTED + SELECT + COUNT(val) + SUM(key) as TOTAL, + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1; + +-- ObjectHashAggregate +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key; + +-- SortAggregate +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key; + -- cleanup DROP TABLE explain_temp1; DROP TABLE explain_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 756c14f28a657..bc28d7f87bf00 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 22 -- !query @@ -26,6 +26,14 @@ struct<> +-- !query +CREATE table explain_temp4 (key int, val string) USING PARQUET +-- !query schema +struct<> +-- !query output + + + -- !query SET spark.sql.codegen.wholeStage = true -- !query schema @@ -76,12 +84,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS 
max(val)#x] (8) Exchange Input: [key#x, max(val)#x] @@ -132,12 +148,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] (8) Filter [codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -211,12 +235,20 @@ Input : [key#x, val#x] (10) HashAggregate [codegen id : 3] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] (11) Exchange Input: [key#x, val#x] (12) HashAggregate [codegen id : 4] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] -- !query @@ -413,12 +445,20 @@ Input : [key#x, val#x] (9) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (10) Exchange Input: [max#x] (11) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -450,12 +490,20 @@ Input : [key#x, val#x] (16) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (17) Exchange Input: [max#x] (18) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] -- !query @@ -466,7 +514,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0) -- !query schema @@ -489,7 +537,7 @@ Input: [key#x, val#x] (3) Filter [codegen id : 1] Input : [key#x, val#x] -Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (key#x = Subquery scalar-subquery#x, [id=#x])) +Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (cast(key#x as double) = Subquery scalar-subquery#x, [id=#x])) ===== Subqueries ===== @@ -523,12 +571,20 @@ Input : [key#x, val#x] (8) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (9) Exchange Input: [max#x] (10) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -560,12 +616,20 @@ Input : [key#x, val#x] (15) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_avg(cast(key#x as bigint))] +Aggregate Attributes: [sum#x, count#xL] +Results: [sum#x, count#xL] (16) Exchange -Input: [max#x] +Input: [sum#x, count#xL] (17) HashAggregate [codegen id : 2] -Input: [max#x] +Input: [sum#x, count#xL] +Keys: [] +Functions: [avg(cast(key#x as bigint))] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] -- !query @@ -615,12 +679,20 @@ Input: [key#x] (6) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_avg(cast(key#x as bigint))] +Aggregate Attributes: 
[sum#x, count#xL] +Results: [sum#x, count#xL] (7) Exchange Input: [sum#x, count#xL] (8) HashAggregate [codegen id : 2] Input: [sum#x, count#xL] +Keys: [] +Functions: [avg(cast(key#x as bigint))] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -740,18 +812,30 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 4] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (8) ReusedExchange [Reuses operator id: 6] Output : ArrayBuffer(key#x, max#x) (9) HashAggregate [codegen id : 3] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (10) BroadcastExchange Input: [key#x, max(val)#x] @@ -786,6 +870,144 @@ Output: [] (4) Project +-- !query +EXPLAIN FORMATTED + SELECT + COUNT(val) + SUM(key) as TOTAL, + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1 +-- !query schema +struct +-- !query output +== Physical Plan == +* HashAggregate (5) ++- Exchange (4) + +- HashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp1 (1) + + +(1) Scan parquet default.explain_temp1 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp1] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) HashAggregate +Input: [key#x, val#x] +Keys: [] +Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Aggregate Attributes: [count#xL, sum#xL, count#xL] +Results: [count#xL, sum#xL, count#xL] + +(4) Exchange +Input: [count#xL, sum#xL, count#xL] + +(5) HashAggregate [codegen id : 2] +Input: [count#xL, sum#xL, count#xL] +Keys: [] +Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] +Aggregate Attributes: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] +Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] + + +-- !query +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key +-- !query schema +struct +-- !query output +== Physical Plan == +ObjectHashAggregate (5) ++- Exchange (4) + +- ObjectHashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp4 (1) + + +(1) Scan parquet default.explain_temp4 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp4] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) ObjectHashAggregate +Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_collect_set(val#x, 0, 0)] +Aggregate Attributes: [buf#x] +Results: [key#x, buf#x] + +(4) Exchange +Input: [key#x, buf#x] + +(5) ObjectHashAggregate +Input: [key#x, buf#x] +Keys: [key#x] +Functions: [collect_set(val#x, 0, 0)] +Aggregate Attributes: [collect_set(val#x, 0, 0)#x] +Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x] + + +-- !query +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP 
BY key
+-- !query schema
+struct
+-- !query output
+== Physical Plan ==
+SortAggregate (7)
++- * Sort (6)
+   +- Exchange (5)
+      +- SortAggregate (4)
+         +- * Sort (3)
+            +- * ColumnarToRow (2)
+               +- Scan parquet default.explain_temp4 (1)
+
+
+(1) Scan parquet default.explain_temp4
+Output: [key#x, val#x]
+Batched: true
+Location [not included in comparison]/{warehouse_dir}/explain_temp4]
+ReadSchema: struct
+
+(2) ColumnarToRow [codegen id : 1]
+Input: [key#x, val#x]
+
+(3) Sort [codegen id : 1]
+Input: [key#x, val#x]
+
+(4) SortAggregate
+Input: [key#x, val#x]
+Keys: [key#x]
+Functions: [partial_min(val#x)]
+Aggregate Attributes: [min#x]
+Results: [key#x, min#x]
+
+(5) Exchange
+Input: [key#x, min#x]
+
+(6) Sort [codegen id : 2]
+Input: [key#x, min#x]
+
+(7) SortAggregate
+Input: [key#x, min#x]
+Keys: [key#x]
+Functions: [min(val#x)]
+Aggregate Attributes: [min(val#x)#x]
+Results: [key#x, min(val#x)#x AS min(val)#x]
+
+
 -- !query
 DROP TABLE explain_temp1
 -- !query schema

From a5bf41fc7cbf6c9c3c613e87f37a1cbed64fa32f Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Thu, 13 Feb 2020 02:31:48 +0800
Subject: [PATCH 0061/1280] [SPARK-30760][SQL] Port `millisToDays` and `daysToMillis` on Java 8 time API

### What changes were proposed in this pull request?
In the PR, I propose to rewrite the `millisToDays` and `daysToMillis` of `DateTimeUtils` using the Java 8 time API. I removed `getOffsetFromLocalMillis` from `DateTimeUtils` because it is a private method and is no longer used in Spark SQL.

### Why are the changes needed?
The new implementation is based on the Proleptic Gregorian calendar, which is already used by the other date-time functions. These changes make `millisToDays` and `daysToMillis` consistent with the rest of the Spark SQL API related to date & time operations.

### Does this PR introduce any user-facing change?
Yes, this might affect behavior for old dates before the year 1582.

### How was this patch tested?
By existing test suites `DateTimeUtilsSuite`, `DateFunctionsSuite`, `DateExpressionsSuite`, `SQLQuerySuite` and `HiveResultSuite`.

Closes #27494 from MaxGekk/millis-2-days-java8-api.
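For readers unfamiliar with the Java 8 time API, the conversion described above boils down to the following self-contained Scala sketch. It is illustrative only: the object and method names are assumed, the real `DateTimeUtils` works in microseconds and has additional handling, so treat this as a simplified model rather than the actual Spark code.

```scala
import java.time.{Instant, LocalDate, ZoneId}

object EpochDaySketch {
  // Epoch milliseconds -> number of days since 1970-01-01, as seen in the given time zone.
  def millisToDays(millisUtc: Long, zoneId: ZoneId): Int =
    Instant.ofEpochMilli(millisUtc).atZone(zoneId).toLocalDate.toEpochDay.toInt

  // Days since 1970-01-01 -> epoch milliseconds of local midnight in the given time zone.
  def daysToMillis(days: Int, zoneId: ZoneId): Long =
    LocalDate.ofEpochDay(days).atStartOfDay(zoneId).toInstant.toEpochMilli
}
```

Round-tripping `millisToDays(daysToMillis(d, zid), zid) == d` is the property the `daysToMillis and millisToDays` test in `DateTimeUtilsSuite` exercises below.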
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit aa0d13683cdf9f38f04cc0e73dc8cf63eed29bf4) Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 8 +-- .../sql/catalyst/util/DateTimeUtils.scala | 58 +++++-------------- .../catalyst/csv/UnivocityParserSuite.scala | 3 +- .../expressions/DateExpressionsSuite.scala | 19 +++--- .../catalyst/util/DateTimeUtilsSuite.scala | 34 ++++++----- .../spark/sql/execution/HiveResult.scala | 5 ++ .../sql-tests/results/postgreSQL/date.sql.out | 12 ++-- .../apache/spark/sql/SQLQueryTestSuite.scala | 1 + 8 files changed, 62 insertions(+), 78 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1f4c8c041f8bf..cf91489d8e6b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -135,7 +135,7 @@ case class CurrentBatchTimestamp( def toLiteral: Literal = dataType match { case _: TimestampType => Literal(DateTimeUtils.fromJavaTimestamp(new Timestamp(timestampMs)), TimestampType) - case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, timeZone), DateType) + case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, zoneId), DateType) } } @@ -1332,14 +1332,14 @@ case class MonthsBetween( override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { DateTimeUtils.monthsBetween( - t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], zoneId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (d1, d2, roundOff) => { - s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $zid)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 8eb560944d4cb..01d36f19fc06f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -67,24 +67,22 @@ object DateTimeUtils { // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisUtc: Long): SQLDate = { - millisToDays(millisUtc, defaultTimeZone()) + millisToDays(millisUtc, defaultTimeZone().toZoneId) } - def millisToDays(millisUtc: Long, timeZone: TimeZone): SQLDate = { - // SPARK-6785: use Math.floorDiv so negative number of days (dates before 1970) - // will correctly work as input for function toJavaDate(Int) - val millisLocal = millisUtc + timeZone.getOffset(millisUtc) - Math.floorDiv(millisLocal, MILLIS_PER_DAY).toInt + def millisToDays(millisUtc: Long, zoneId: ZoneId): SQLDate = { + val instant = microsToInstant(Math.multiplyExact(millisUtc, MICROS_PER_MILLIS)) + localDateToDays(LocalDateTime.ofInstant(instant, zoneId).toLocalDate) } // reverse of millisToDays def daysToMillis(days: SQLDate): Long = { - daysToMillis(days, defaultTimeZone()) 
+ daysToMillis(days, defaultTimeZone().toZoneId) } - def daysToMillis(days: SQLDate, timeZone: TimeZone): Long = { - val millisLocal = days.toLong * MILLIS_PER_DAY - millisLocal - getOffsetFromLocalMillis(millisLocal, timeZone) + def daysToMillis(days: SQLDate, zoneId: ZoneId): Long = { + val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant + instantToMicros(instant) / MICROS_PER_MILLIS } // Converts Timestamp to string according to Hive TimestampWritable convention. @@ -587,11 +585,11 @@ object DateTimeUtils { time1: SQLTimestamp, time2: SQLTimestamp, roundOff: Boolean, - timeZone: TimeZone): Double = { + zoneId: ZoneId): Double = { val millis1 = MICROSECONDS.toMillis(time1) val millis2 = MICROSECONDS.toMillis(time2) - val date1 = millisToDays(millis1, timeZone) - val date2 = millisToDays(millis2, timeZone) + val date1 = millisToDays(millis1, zoneId) + val date2 = millisToDays(millis2, zoneId) val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1) val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2) @@ -605,8 +603,8 @@ object DateTimeUtils { } // using milliseconds can cause precision loss with more than 8 digits // we follow Hive's implementation which uses seconds - val secondsInDay1 = MILLISECONDS.toSeconds(millis1 - daysToMillis(date1, timeZone)) - val secondsInDay2 = MILLISECONDS.toSeconds(millis2 - daysToMillis(date2, timeZone)) + val secondsInDay1 = MILLISECONDS.toSeconds(millis1 - daysToMillis(date1, zoneId)) + val secondsInDay2 = MILLISECONDS.toSeconds(millis2 - daysToMillis(date2, zoneId)) val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 val secondsInMonth = DAYS.toSeconds(31) val diff = monthDiff + secondsDiff / secondsInMonth.toDouble @@ -735,8 +733,8 @@ object DateTimeUtils { millis += offset millis - millis % MILLIS_PER_DAY - offset case _ => // Try to truncate date levels - val dDays = millisToDays(millis, timeZone) - daysToMillis(truncDate(dDays, level), timeZone) + val dDays = millisToDays(millis, timeZone.toZoneId) + daysToMillis(truncDate(dDays, level), timeZone.toZoneId) } truncated * MICROS_PER_MILLIS } @@ -768,32 +766,6 @@ object DateTimeUtils { } } - /** - * Lookup the offset for given millis seconds since 1970-01-01 00:00:00 in given timezone. - * TODO: Improve handling of normalization differences. - * TODO: Replace with JSR-310 or similar system - see SPARK-16788 - */ - private[sql] def getOffsetFromLocalMillis(millisLocal: Long, tz: TimeZone): Long = { - var guess = tz.getRawOffset - // the actual offset should be calculated based on milliseconds in UTC - val offset = tz.getOffset(millisLocal - guess) - if (offset != guess) { - guess = tz.getOffset(millisLocal - offset) - if (guess != offset) { - // fallback to do the reverse lookup using java.time.LocalDateTime - // this should only happen near the start or end of DST - val localDate = LocalDate.ofEpochDay(MILLISECONDS.toDays(millisLocal)) - val localTime = LocalTime.ofNanoOfDay(MILLISECONDS.toNanos( - Math.floorMod(millisLocal, MILLIS_PER_DAY))) - val localDateTime = LocalDateTime.of(localDate, localTime) - val millisEpoch = localDateTime.atZone(tz.toZoneId).toInstant.toEpochMilli - - guess = (millisLocal - millisEpoch).toInt - } - } - guess - } - /** * Convert the timestamp `ts` from one timezone to another. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 77a2ca7e4a828..536c76f042d23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.time.ZoneOffset import java.util.{Locale, TimeZone} import org.apache.commons.lang3.time.FastDateFormat @@ -137,7 +138,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { val expectedDate = format.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true) .apply(customDate) - assert(castedDate == DateTimeUtils.millisToDays(expectedDate, TimeZone.getTimeZone("GMT"))) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate, ZoneOffset.UTC)) val timestamp = "2015-01-01 00:00:00" timestampsOptions = new CSVOptions(Map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f04149ab7eb29..39b859af47ca9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -56,9 +56,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val ts = new Timestamp(toMillis(time)) test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis(), TimeZoneGMT) + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis(), ZoneOffset.UTC) val cd = CurrentDate(gmtId).eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis(), TimeZoneGMT) + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis(), ZoneOffset.UTC) assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) val cdjst = CurrentDate(jstId).eval(EmptyRow).asInstanceOf[Int] @@ -499,7 +499,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Valid range of DateType is [0001-01-01, 9999-12-31] val maxMonthInterval = 10000 * 12 checkEvaluation( - AddMonths(Literal(Date.valueOf("0001-01-01")), Literal(maxMonthInterval)), 2933261) + AddMonths(Literal(LocalDate.parse("0001-01-01")), Literal(maxMonthInterval)), + LocalDate.of(10001, 1, 1).toEpochDay.toInt) checkEvaluation( AddMonths(Literal(Date.valueOf("9999-12-31")), Literal(-1 * maxMonthInterval)), -719529) // Test evaluation results between Interpreted mode and Codegen mode @@ -788,7 +789,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), @@ -796,7 +797,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(UnixTimestamp( Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), 
timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId))) val t1 = UnixTimestamp( CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] val t2 = UnixTimestamp( @@ -814,7 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) } @@ -852,7 +853,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( ToUnixTimestamp( Literal(sdf2.format(new Timestamp(-1000000))), @@ -861,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ToUnixTimestamp( Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId))) val t1 = ToUnixTimestamp( CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] val t2 = ToUnixTimestamp( @@ -876,7 +877,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ToUnixTimestamp( Literal(date1), Literal.create(null, StringType), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( ToUnixTimestamp( Literal("2015-07-24"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cabcd3007d1c0..cd0594c775a47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -86,9 +86,13 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { } test("SPARK-6785: java date conversion before and after epoch") { + def format(d: Date): String = { + TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId) + .format(d.getTime * MICROS_PER_MILLIS) + } def checkFromToJavaDate(d1: Date): Unit = { val d2 = toJavaDate(fromJavaDate(d1)) - assert(d2.toString === d1.toString) + assert(format(d2) === format(d1)) } val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) @@ -413,22 +417,22 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("monthsBetween") { val date1 = date(1997, 2, 28, 10, 30, 0) var date2 = date(1996, 10, 30) - assert(monthsBetween(date1, date2, true, TimeZoneUTC) === 3.94959677) - assert(monthsBetween(date1, date2, false, TimeZoneUTC) === 3.9495967741935485) + assert(monthsBetween(date1, date2, true, ZoneOffset.UTC) === 3.94959677) + 
assert(monthsBetween(date1, date2, false, ZoneOffset.UTC) === 3.9495967741935485) Seq(true, false).foreach { roundOff => date2 = date(2000, 2, 28) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === -36) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === -36) date2 = date(2000, 2, 29) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === -36) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === -36) date2 = date(1996, 3, 31) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === 11) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === 11) } val date3 = date(2000, 2, 28, 16, tz = TimeZonePST) val date4 = date(1997, 2, 28, 16, tz = TimeZonePST) - assert(monthsBetween(date3, date4, true, TimeZonePST) === 36.0) - assert(monthsBetween(date3, date4, true, TimeZoneGMT) === 35.90322581) - assert(monthsBetween(date3, date4, false, TimeZoneGMT) === 35.903225806451616) + assert(monthsBetween(date3, date4, true, TimeZonePST.toZoneId) === 36.0) + assert(monthsBetween(date3, date4, true, ZoneOffset.UTC) === 35.90322581) + assert(monthsBetween(date3, date4, false, ZoneOffset.UTC) === 35.903225806451616) } test("from UTC timestamp") { @@ -571,15 +575,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("daysToMillis and millisToDays") { val input = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, 16, tz = TimeZonePST)) - assert(millisToDays(input, TimeZonePST) === 16800) - assert(millisToDays(input, TimeZoneGMT) === 16801) - assert(millisToDays(-1 * MILLIS_PER_DAY + 1, TimeZoneGMT) == -1) + assert(millisToDays(input, TimeZonePST.toZoneId) === 16800) + assert(millisToDays(input, ZoneOffset.UTC) === 16801) + assert(millisToDays(-1 * MILLIS_PER_DAY + 1, ZoneOffset.UTC) == -1) var expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, tz = TimeZonePST)) - assert(daysToMillis(16800, TimeZonePST) === expected) + assert(daysToMillis(16800, TimeZonePST.toZoneId) === expected) expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, tz = TimeZoneGMT)) - assert(daysToMillis(16800, TimeZoneGMT) === expected) + assert(daysToMillis(16800, ZoneOffset.UTC) === expected) // There are some days are skipped entirely in some timezone, skip them here. 
val skipped_days = Map[String, Set[Int]]( @@ -594,7 +598,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => if (!skipped.contains(d)) { - assert(millisToDays(daysToMillis(d, tz), tz) === d, + assert(millisToDays(daysToMillis(d, tz.toZoneId), tz.toZoneId) === d, s"Round trip of ${d} did not work in tz ${tz}") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index bbe47a63f4d61..5a2f16d8e1526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} @@ -67,8 +68,12 @@ object HiveResult { case (null, _) => if (nested) "null" else "NULL" case (b, BooleanType) => b.toString case (d: Date, DateType) => dateFormatter.format(DateTimeUtils.fromJavaDate(d)) + case (ld: LocalDate, DateType) => + dateFormatter.format(DateTimeUtils.localDateToDays(ld)) case (t: Timestamp, TimestampType) => timestampFormatter.format(DateTimeUtils.fromJavaTimestamp(t)) + case (i: Instant, TimestampType) => + timestampFormatter.format(DateTimeUtils.instantToMicros(i)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString case (n, _: NumericType) => n.toString diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out index fd5dc42632176..ed27317121623 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out @@ -800,7 +800,7 @@ SELECT DATE_TRUNC('MILLENNIUM', TIMESTAMP '1970-03-20 04:30:00.00000') -- !query schema struct -- !query output -1001-01-01 00:07:02 +1001-01-01 00:00:00 -- !query @@ -808,7 +808,7 @@ SELECT DATE_TRUNC('MILLENNIUM', DATE '1970-03-20') -- !query schema struct -- !query output -1001-01-01 00:07:02 +1001-01-01 00:00:00 -- !query @@ -840,7 +840,7 @@ SELECT DATE_TRUNC('CENTURY', DATE '0002-02-04') -- !query schema struct -- !query output -0001-01-01 00:07:02 +0001-01-01 00:00:00 -- !query @@ -848,7 +848,7 @@ SELECT DATE_TRUNC('CENTURY', TO_DATE('0055-08-10 BC', 'yyyy-MM-dd G')) -- !query schema struct -- !query output --0099-01-01 00:07:02 +-0099-01-01 00:00:00 -- !query @@ -864,7 +864,7 @@ SELECT DATE_TRUNC('DECADE', DATE '0004-12-25') -- !query schema struct -- !query output -0000-01-01 00:07:02 +0000-01-01 00:00:00 -- !query @@ -872,7 +872,7 @@ SELECT DATE_TRUNC('DECADE', TO_DATE('0002-12-31 BC', 'yyyy-MM-dd G')) -- !query schema struct -- !query output --0010-01-01 00:07:02 +-0010-01-01 00:00:00 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 2e5a9e0b4d45d..6b9e5bbd3c961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -337,6 +337,7 @@ class 
SQLQueryTestSuite extends QueryTest with SharedSparkSession {
         localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
       case _ =>
     }
+    localSparkSession.conf.set(SQLConf.DATETIME_JAVA8API_ENABLED.key, true)
 
     if (configSet.nonEmpty) {
       // Execute the list of set operation in order to add the desired configs

From 82981737f762760c07ae82464b15dc866a2b64e5 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 12 Feb 2020 14:27:18 -0800
Subject: [PATCH 0062/1280] [SPARK-30797][SQL] Set tradition user/group/other permission to ACL entries when setting up ACLs in truncate table

### What changes were proposed in this pull request?
This is a follow-up to the PR #26956. In #26956, the patch proposed to preserve path permissions when truncating a table. When setting up the original ACLs, we need to add the user/group/other permissions as ACL entries too; otherwise, if the path doesn't have default user/group/other ACL entries, the ACL API will fail with the error `Invalid ACL: the user, group and other entries are required.`.

In short, this change makes sure that:

1. Permissions for user/group/other are always included in the ACL entries so that the ACL API accepts them.
2. Other custom ACLs are still kept after TRUNCATE TABLE (#26956 did this).

### Why are the changes needed?
Without this fix, `TRUNCATE TABLE` fails when setting up ACLs if there are no default user/group/other ACL entries.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Updated a unit test. Manually tested on a dev Spark cluster. Set ACLs for a table path without default user/group/other ACL entries:
```
hdfs dfs -setfacl --set 'user:liangchi:rwx,user::rwx,group::r--,other::r--' /user/hive/warehouse/test.db/test_truncate_table
hdfs dfs -getfacl /user/hive/warehouse/test.db/test_truncate_table
# file: /user/hive/warehouse/test.db/test_truncate_table
# owner: liangchi
# group: supergroup
user::rwx
user:liangchi:rwx
group::r--
mask::rwx
other::r--
```
Then run `sql("truncate table test.test_truncate_table")`; it truncates the table normally and preserves the ACLs.

Closes #27548 from viirya/fix-truncate-table-permission.
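For illustration only, the core idea of the fix can be sketched as a stand-alone helper. The object and method names here are assumed and the real logic lives inside `TruncateTableCommand` (see the diff below, which ends with `fs.setAcl(path, aclEntries)`); this is a simplified model of that behavior.

```scala
import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission}

object AclSketch {
  // Rebuilds an ACL spec that FileSystem.setAcl() will accept: keep the named (custom)
  // entries and always append ACCESS-scope user/group/other entries derived from the
  // path's plain permission bits.
  def withBaseEntries(permission: FsPermission, originalAcl: Seq[AclEntry]): JList[AclEntry] = {
    def entry(aclType: AclEntryType, action: FsAction): AclEntry =
      new AclEntry.Builder()
        .setScope(AclEntryScope.ACCESS)
        .setType(aclType)
        .setPermission(action)
        .build()

    val entries = new JArrayList[AclEntry]()
    // Keep only the named (custom) entries from the original ACL.
    originalAcl.filter(_.getName != null).foreach(e => entries.add(e))
    // Always add the three required base entries.
    entries.add(entry(AclEntryType.USER, permission.getUserAction))
    entries.add(entry(AclEntryType.GROUP, permission.getGroupAction))
    entries.add(entry(AclEntryType.OTHER, permission.getOtherAction))
    entries
  }
}
```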
Lead-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun (cherry picked from commit 5b76367a9d0aaca53ce96ab7e555a596567e8335) Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/command/tables.scala | 32 +++++++++++++++++-- .../sql/execution/command/DDLSuite.scala | 21 +++++++++++- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 90dbdf5515d4d..61500b773cd7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.command import java.net.{URI, URISyntaxException} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileContext, FsConstants, Path} -import org.apache.hadoop.fs.permission.{AclEntry, FsPermission} +import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -538,12 +539,27 @@ case class TruncateTableCommand( } } optAcls.foreach { acls => + val aclEntries = acls.asScala.filter(_.getName != null).asJava + + // If the path doesn't have default ACLs, `setAcl` API will throw an error + // as it expects user/group/other permissions must be in ACL entries. + // So we need to add tradition user/group/other permission + // in the form of ACL. + optPermission.map { permission => + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.USER, permission.getUserAction())) + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.GROUP, permission.getGroupAction())) + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.OTHER, permission.getOtherAction())) + } + try { - fs.setAcl(path, acls) + fs.setAcl(path, aclEntries) } catch { case NonFatal(e) => throw new SecurityException( - s"Failed to set original ACL $acls back to " + + s"Failed to set original ACL $aclEntries back to " + s"the created path: $path. Exception: ${e.getMessage}") } } @@ -574,6 +590,16 @@ case class TruncateTableCommand( } Seq.empty[Row] } + + private def newAclEntry( + scope: AclEntryScope, + aclType: AclEntryType, + permission: FsAction): AclEntry = { + new AclEntry.Builder() + .setScope(scope) + .setType(aclType) + .setPermission(permission).build() + } } abstract class DescribeCommandBase extends RunnableCommand { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 31e00781ae6b4..dbf4b09403423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2042,6 +2042,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Set ACL to table path. 
val customAcl = new java.util.ArrayList[AclEntry]() customAcl.add(new AclEntry.Builder() + .setName("test") .setType(AclEntryType.USER) .setScope(AclEntryScope.ACCESS) .setPermission(FsAction.READ).build()) @@ -2061,8 +2062,26 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { if (ignore) { assert(aclEntries.size() == 0) } else { - assert(aclEntries.size() == 1) + assert(aclEntries.size() == 4) assert(aclEntries.get(0) == customAcl.get(0)) + + // Setting ACLs will also set user/group/other permissions + // as ACL entries. + val user = new AclEntry.Builder() + .setType(AclEntryType.USER) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + val group = new AclEntry.Builder() + .setType(AclEntryType.GROUP) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + val other = new AclEntry.Builder() + .setType(AclEntryType.OTHER) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + assert(aclEntries.get(1) == user) + assert(aclEntries.get(2) == group) + assert(aclEntries.get(3) == other) } } } From 8ab6ae3ede96adb093347470a5cbbf17fe8c04e9 Mon Sep 17 00:00:00 2001 From: iRakson Date: Thu, 13 Feb 2020 12:23:40 +0800 Subject: [PATCH 0063/1280] [SPARK-30790] The dataType of map() should be map<null,null> ### What changes were proposed in this pull request? `spark.sql("select map()")` returns {}. After these changes it will return map<null,null>. ### Why are the changes needed? After changes introduced due to #27521, it is important to maintain consistency while using map(). ### Does this PR introduce any user-facing change? Yes. Now map() will give map<null,null> instead of {}. ### How was this patch tested? UT added. Migration guide updated as well. Closes #27542 from iRakson/SPARK-30790. Authored-by: iRakson Signed-off-by: Wenchen Fan (cherry picked from commit 926e3a1efe9e142804fcbf52146b22700640ae1b) Signed-off-by: Wenchen Fan --- docs/sql-migration-guide.md | 2 +- .../expressions/complexTypeCreator.scala | 14 ++++++++--- .../catalyst/util/ArrayBasedMapBuilder.scala | 5 ++-- .../apache/spark/sql/internal/SQLConf.scala | 10 ++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++++++++++++------ 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index f98fab5b4c56b..46b741687363f 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -216,7 +216,7 @@ license: | - Since Spark 3.0, the `size` function returns `NULL` for the `NULL` input. In Spark version 2.4 and earlier, this function gives `-1` for the same input. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.sizeOfNull` to `true`. - - Since Spark 3.0, when the `array` function is called without any parameters, it returns an empty array of `NullType`. In Spark version 2.4 and earlier, it returns an empty array of string type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.arrayDefaultToStringType.enabled` to `true`. + - Since Spark 3.0, when the `array`/`map` function is called without any parameters, it returns an empty collection with `NullType` as element type. In Spark version 2.4 and earlier, it returns an empty collection with `StringType` as element type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.createEmptyCollectionUsingStringType` to `true`. - Since Spark 3.0, the interval literal syntax does not allow multiple from-to units anymore.
For example, `SELECT INTERVAL '1-1' YEAR TO MONTH '2-2' YEAR TO MONTH'` throws parser exception. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 7335e305bfe55..4bd85d304ded2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -46,7 +46,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } private val defaultElementType: DataType = { - if (SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING)) { + if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { StringType } else { NullType @@ -145,6 +145,14 @@ case class CreateMap(children: Seq[Expression]) extends Expression { lazy val keys = children.indices.filter(_ % 2 == 0).map(children) lazy val values = children.indices.filter(_ % 2 != 0).map(children) + private val defaultElementType: DataType = { + if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { + StringType + } else { + NullType + } + } + override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { @@ -167,9 +175,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override lazy val dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index 98934368205ec..37d65309e2b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -29,12 +29,11 @@ import org.apache.spark.unsafe.array.ByteArrayMethods */ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") - assert(keyType != NullType, "map key cannot be null type.") private lazy val keyToIndex = keyType match { // Binary type data is `byte[]`, which can't use `==` to check equality. - case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] => - new java.util.HashMap[Any, Int]() + case _: AtomicType | _: CalendarIntervalType | _: NullType + if !keyType.isInstanceOf[BinaryType] => new java.util.HashMap[Any, Int]() case _ => // for complex types, use interpreted ordering to be able to compare unsafe data with safe // data, e.g. UnsafeRow vs GenericInternalRow. 
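To make the SPARK-30790 behaviour change above concrete, here is a small illustrative session (an editorial sketch, not taken from the patch). It assumes a Spark 3.0 `spark-shell` with a `SparkSession` named `spark`; the type comments paraphrase the patch and the removed test case rather than exact console output.

```scala
// The no-argument map() now defaults its key and value types to NullType.
spark.sql("SELECT map()").schema
// with this change (default): the map() column has type MapType(NullType, NullType, false)
// before, or with the legacy flag on: MapType(StringType, StringType, false)

// Restore the Spark 2.4 behaviour for both array() and map():
spark.conf.set("spark.sql.legacy.createEmptyCollectionUsingStringType", "true")
```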
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b79b767dbb22b..442711db93f0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2007,12 +2007,12 @@ object SQLConf { .booleanConf .createWithDefault(false) - val LEGACY_ARRAY_DEFAULT_TO_STRING = - buildConf("spark.sql.legacy.arrayDefaultToStringType.enabled") + val LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE = + buildConf("spark.sql.legacy.createEmptyCollectionUsingStringType") .internal() - .doc("When set to true, it returns an empty array of string type when the `array` " + - "function is called without any parameters. Otherwise, it returns an empty " + - "array of `NullType`") + .doc("When set to true, Spark returns an empty collection with `StringType` as element " + + "type if the `array`/`map` function is called without any parameters. Otherwise, Spark " + + "returns an empty collection with `NullType` as element type.") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6012678341ccc..f7531ea446015 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3499,13 +3499,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } - test("SPARK-21281 use string types by default if map have no argument") { - val ds = spark.range(1) - var expectedSchema = new StructType() - .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) - assert(ds.select(map().as("x")).schema == expectedSchema) - } - test("SPARK-21281 fails if functions have no argument") { val df = Seq(1).toDF("a") @@ -3563,7 +3556,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-29462: Empty array of NullType for array function with no arguments") { Seq((true, StringType), (false, NullType)).foreach { case (arrayDefaultToString, expectedType) => - withSQLConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING.key -> arrayDefaultToString.toString) { + withSQLConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE.key -> + arrayDefaultToString.toString) { val schema = spark.range(1).select(array()).schema assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[ArrayType]) val actualType = schema.head.dataType.asInstanceOf[ArrayType].elementType @@ -3571,6 +3565,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-30790: Empty map with NullType as key/value type for map function with no argument") { + Seq((true, StringType), (false, NullType)).foreach { + case (mapDefaultToString, expectedType) => + withSQLConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE.key -> + mapDefaultToString.toString) { + val schema = spark.range(1).select(map()).schema + assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[MapType]) + val actualKeyType = schema.head.dataType.asInstanceOf[MapType].keyType + val actualValueType = schema.head.dataType.asInstanceOf[MapType].valueType + assert(actualKeyType === expectedType) + assert(actualValueType === expectedType) + } + } + } } object 
DataFrameFunctionsSuite { From a2c46334d44355f022e0498b8e2de71d7c91a533 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Feb 2020 16:55:45 -0800 Subject: [PATCH 0064/1280] [SPARK-30743][K8S][TESTS] Use JRE instead of JDK in K8S test docker image ### What changes were proposed in this pull request? This PR aims to replace JDK to JRE in K8S integration test docker images. ### Why are the changes needed? This will save some resources and make it sure that we only need JRE at runtime and testing. - https://lists.apache.org/thread.html/3145150b711d7806a86bcd3ab43e18bcd0e4892ab5f11600689ba087%40%3Cdev.spark.apache.org%3E ### Does this PR introduce any user-facing change? No. This is a dev-only test environment. ### How was this patch tested? Pass the Jenkins K8s Integration Test. - https://github.com/apache/spark/pull/27469#issuecomment-582681125 Closes #27469 from dongjoon-hyun/SPARK-30743. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 9d907bc84df2f6f7e1abdb810b761a65ac6ce064) Signed-off-by: Dongjoon Hyun --- .../kubernetes/docker/src/main/dockerfiles/spark/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index b6eeff1cd18a9..a1fc63789bc61 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -15,7 +15,7 @@ # limitations under the License. # -FROM openjdk:8-jdk-slim +FROM openjdk:8-jre-slim ARG spark_uid=185 From 59a13c9b7bc3b3aa5b5bc30a60344f849c0f8012 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 13 Feb 2020 19:32:38 +0800 Subject: [PATCH 0065/1280] [SPARK-30528][SQL] Turn off DPP subquery duplication by default ### What changes were proposed in this pull request? This PR adds a config for Dynamic Partition Pruning subquery duplication and turns it off by default due to its potential performance regression. When planning a DPP filter, it seeks to reuse the broadcast exchange relation if the corresponding join is a BHJ with the filter relation being on the build side, otherwise it will either opt out or plan the filter as an un-reusable subquery duplication based on the cost estimate. However, the cost estimate is not accurate and only takes into account the table scan overhead, thus adding an un-reusable subquery duplication DPP filter can sometimes cause perf regression. This PR turns off the subquery duplication DPP filter by: 1. adding a config `spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly` and setting it `true` by default. 2. removing the existing meaningless config `spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcast` since we always want to reuse broadcast results if possible. ### Why are the changes needed? This is to fix a potential performance regression caused by DPP. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Updated DynamicPartitionPruningSuite to test the new configuration. Closes #27551 from maryannxue/spark-30528. 
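As a brief editorial aside (not part of the patch), the sketch below shows how the new flag interacts with the existing DPP switch; the session setup is assumed and the comments paraphrase the description above.

```scala
// Assumes a SparkSession named `spark` (e.g. spark-shell on Spark 3.0).
// Default after this patch: a DPP filter is planned only when the broadcast
// exchange of the corresponding broadcast hash join can be reused.
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

// Opt back into the previous behaviour, which may also plan the pruning filter
// as a duplicated, non-reusable subquery based on the cost estimate:
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly", "false")
```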
Authored-by: maryannxue Signed-off-by: Wenchen Fan (cherry picked from commit 453d5261b22ebcdd5886e65ab9d0d9857051e76a) Signed-off-by: Wenchen Fan --- .../apache/spark/sql/internal/SQLConf.scala | 12 +- .../sql/dynamicpruning/PartitionPruning.scala | 4 +- .../PlanDynamicPruningFilters.scala | 5 +- .../sql/DynamicPartitionPruningSuite.scala | 183 +++++++----------- .../org/apache/spark/sql/ExplainSuite.scala | 3 +- 5 files changed, 82 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 442711db93f0e..19c94e23e046e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -259,11 +259,11 @@ object SQLConf { .doubleConf .createWithDefault(0.5) - val DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST = - buildConf("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcast") + val DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly") .internal() - .doc("When true, dynamic partition pruning will seek to reuse the broadcast results from " + - "a broadcast hash join operation.") + .doc("When true, dynamic partition pruning will only apply when the broadcast exchange of " + + "a broadcast hash join operation can be reused as the dynamic pruning filter.") .booleanConf .createWithDefault(true) @@ -2303,8 +2303,8 @@ class SQLConf extends Serializable with Logging { def dynamicPartitionPruningFallbackFilterRatio: Double = getConf(DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO) - def dynamicPartitionPruningReuseBroadcast: Boolean = - getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST) + def dynamicPartitionPruningReuseBroadcastOnly: Boolean = + getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY) def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala index 48ba8618f272e..28f8f49d2ce44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala @@ -86,7 +86,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { filteringPlan: LogicalPlan, joinKeys: Seq[Expression], hasBenefit: Boolean): LogicalPlan = { - val reuseEnabled = SQLConf.get.dynamicPartitionPruningReuseBroadcast + val reuseEnabled = SQLConf.get.exchangeReuseEnabled val index = joinKeys.indexOf(filteringKey) if (hasBenefit || reuseEnabled) { // insert a DynamicPruning wrapper to identify the subquery during query planning @@ -96,7 +96,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { filteringPlan, joinKeys, index, - !hasBenefit), + !hasBenefit || SQLConf.get.dynamicPartitionPruningReuseBroadcastOnly), pruningPlan) } else { // abort dynamic partition pruning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala index 1398dc049dd99..be00f728aa3ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala @@ -36,9 +36,6 @@ import org.apache.spark.sql.internal.SQLConf case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[SparkPlan] with PredicateHelper { - private def reuseBroadcast: Boolean = - SQLConf.get.dynamicPartitionPruningReuseBroadcast && SQLConf.get.exchangeReuseEnabled - /** * Identify the shape in which keys of a given plan are broadcasted. */ @@ -59,7 +56,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) sparkSession, sparkSession.sessionState.planner, buildPlan) // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is // the first to be applied (apart from `InsertAdaptiveSparkPlan`). - val canReuseExchange = reuseBroadcast && buildKeys.nonEmpty && + val canReuseExchange = SQLConf.get.exchangeReuseEnabled && buildKeys.nonEmpty && plan.find { case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _) => left.sameResult(sparkPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index e1f9bcc4e008d..f7b51d6f4c8ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -239,7 +239,8 @@ class DynamicPartitionPruningSuite */ test("simple inner join triggers DPP with mock-up tables") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("df1", "df2") { spark.range(1000) .select(col("id"), col("id").as("k")) @@ -271,7 +272,8 @@ class DynamicPartitionPruningSuite */ test("self-join on a partitioned table should not trigger DPP") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("fact") { sql( s""" @@ -302,7 +304,8 @@ class DynamicPartitionPruningSuite */ test("static scan metrics") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("fact", "dim") { spark.range(10) .map { x => Tuple3(x, x + 1, 0) } @@ -370,7 +373,8 @@ class DynamicPartitionPruningSuite test("DPP should not be rewritten as an existential join") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1.5", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( s""" |SELECT * FROM product p WHERE p.store_id NOT IN @@ -395,7 +399,7 @@ class DynamicPartitionPruningSuite */ test("DPP triggers only for certain types of query") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false") { Given("dynamic partition pruning disabled") 
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { val df = sql( @@ -433,7 +437,8 @@ class DynamicPartitionPruningSuite } Given("left-semi join with partition column on the left side") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT * FROM fact_sk f @@ -457,7 +462,8 @@ class DynamicPartitionPruningSuite } Given("right outer join with partition column on the left side") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT * FROM fact_sk f RIGHT OUTER JOIN dim_store s @@ -474,7 +480,8 @@ class DynamicPartitionPruningSuite */ test("filtering ratio policy fallback") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { Given("no stats and selective predicate") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { @@ -543,7 +550,8 @@ class DynamicPartitionPruningSuite */ test("filtering ratio policy with stats when the broadcast pruning is disabled") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { Given("disabling the use of stats in the DPP heuristic") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false") { @@ -613,10 +621,7 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins with non-deterministic probe part") { Given("alias with simple join condition, and non-deterministic query") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -630,10 +635,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -651,10 +653,7 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins with aliases") { Given("alias with simple join condition, using attribute names only") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -674,10 +673,7 @@ class DynamicPartitionPruningSuite } Given("alias with expr as join condition") - 
withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -697,10 +693,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -722,10 +715,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid_d as pid, f.sid_d as sid FROM @@ -754,10 +744,8 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins") { Given("disable broadcast pruning and disable subquery duplication") withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f @@ -777,9 +765,10 @@ class DynamicPartitionPruningSuite Given("disable reuse broadcast results and enable subquery duplication") withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.5") { + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.5", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f @@ -798,52 +787,47 @@ class DynamicPartitionPruningSuite } Given("enable reuse broadcast results and disable query duplication") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { - val df = sql( - """ - |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON f.store_id = s.store_id WHERE s.country = 'DE' - """.stripMargin) + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) - checkPartitionPruningPredicate(df, false, true) + checkPartitionPruningPredicate(df, false, true) - 
checkAnswer(df, - Row(1030, 2, 10, 3) :: - Row(1040, 2, 50, 3) :: - Row(1050, 2, 50, 3) :: - Row(1060, 2, 50, 3) :: Nil - ) + checkAnswer(df, + Row(1030, 2, 10, 3) :: + Row(1040, 2, 50, 3) :: + Row(1050, 2, 50, 3) :: + Row(1060, 2, 50, 3) :: Nil + ) } Given("disable broadcast hash join and disable query duplication") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { - val df = sql( - """ - |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON f.store_id = s.store_id WHERE s.country = 'DE' - """.stripMargin) + withSQLConf( + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) - checkPartitionPruningPredicate(df, false, false) + checkPartitionPruningPredicate(df, false, false) - checkAnswer(df, - Row(1030, 2, 10, 3) :: - Row(1040, 2, 50, 3) :: - Row(1050, 2, 50, 3) :: - Row(1060, 2, 50, 3) :: Nil - ) + checkAnswer(df, + Row(1030, 2, 10, 3) :: + Row(1040, 2, 50, 3) :: + Row(1050, 2, 50, 3) :: + Row(1060, 2, 50, 3) :: Nil + ) } Given("disable broadcast hash join and enable query duplication") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { val df = sql( @@ -865,9 +849,7 @@ class DynamicPartitionPruningSuite } test("broadcast a single key in a HashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -925,9 +907,7 @@ class DynamicPartitionPruningSuite } test("broadcast multiple keys in a LongHashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -962,9 +942,7 @@ class DynamicPartitionPruningSuite } test("broadcast multiple keys in an UnsafeHashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -999,9 +977,7 @@ class DynamicPartitionPruningSuite } test("different broadcast subqueries with identical children") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - 
SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -1073,7 +1049,7 @@ class DynamicPartitionPruningSuite } test("avoid reordering broadcast join keys to match input hash partitioning") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTable("large", "dimTwo", "dimThree") { spark.range(100).select( @@ -1123,9 +1099,7 @@ class DynamicPartitionPruningSuite * duplicated partitioning keys, also used to uniquely identify the dynamic pruning filters. */ test("dynamic partition pruning ambiguity issue across nested joins") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("store", "date", "item") { spark.range(500) .select((($"id" + 30) % 50).as("ss_item_sk"), @@ -1163,9 +1137,7 @@ class DynamicPartitionPruningSuite } test("cleanup any DPP filter that isn't pushed down due to expression id clashes") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(1000).select($"id".as("A"), $"id".as("AA")) .write.partitionBy("A").format(tableFormat).mode("overwrite").saveAsTable("fact") @@ -1186,10 +1158,7 @@ class DynamicPartitionPruningSuite } test("cleanup any DPP filter that isn't pushed down due to non-determinism") { - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -1204,9 +1173,7 @@ class DynamicPartitionPruningSuite } test("join key with multiple references on the filtering plan") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { // when enable AQE, the reusedExchange is inserted when executed. 
withTable("fact", "dim") { @@ -1240,9 +1207,7 @@ class DynamicPartitionPruningSuite } test("Make sure dynamic pruning works on uncorrelated queries") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT d.store_id, @@ -1266,10 +1231,7 @@ class DynamicPartitionPruningSuite test("Plan broadcast pruning only when the broadcast can be reused") { Given("dynamic pruning filter on the build side") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.store_id, f.product_id, f.units_sold FROM fact_np f @@ -1288,10 +1250,7 @@ class DynamicPartitionPruningSuite } Given("dynamic pruning filter on the probe side") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT /*+ BROADCAST(f)*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index d9f4d6d5132ae..b591705274110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -239,7 +239,8 @@ class ExplainSuite extends QueryTest with SharedSparkSession { test("explain formatted - check presence of subquery in case of DPP") { withTable("df1", "df2") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("df1", "df2") { spark.range(1000).select(col("id"), col("id").as("k")) .write From f041aaaf55fb1e907e0e5b0876927ef328664664 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 13 Feb 2020 22:06:24 +0800 Subject: [PATCH 0066/1280] [SPARK-30758][SQL][TESTS] Improve bracketed comments tests ### What changes were proposed in this pull request? Although Spark SQL supports bracketed comments, `SQLQueryTestSuite` can't handle bracketed comments well, so the generated golden files can't display bracketed comments correctly. This PR improves the treatment of bracketed comments and adds three test cases in `PlanParserSuite`. Spark SQL can't support nested bracketed comments yet; https://github.com/apache/spark/pull/27495 is intended to add that support. ### Why are the changes needed? Otherwise the generated golden files can't be displayed correctly. ### Does this PR introduce any user-facing change? No ### How was this patch tested? New UT. Closes #27481 from beliefer/ansi-brancket-comments.
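As an editorial illustration of what the test-harness change above protects (not part of the patch), the snippet below runs a single statement whose bracketed comment contains a semicolon; naively splitting the test file on `;` would break such a query apart. The SELECT text is adapted from the `comments.sql` file in this patch, and a `SparkSession` named `spark` is assumed.

```scala
// Assumes a Spark 3.0 spark-shell session.
val sqlText =
  """/* a bracketed comment that contains a semicolon;
    |   it must not be treated as a statement separator */
    |SELECT 'after multi-line' AS fifth""".stripMargin

spark.sql(sqlText).collect()   // expected: Array([after multi-line])
```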
Authored-by: beliefer Signed-off-by: Wenchen Fan (cherry picked from commit 04604b9899cc43a9726d671061ff305912fdb85f) Signed-off-by: Wenchen Fan --- .../sql-tests/inputs/postgreSQL/comments.sql | 6 +- .../results/postgreSQL/comments.sql.out | 137 ++++-------------- .../apache/spark/sql/SQLQueryTestSuite.scala | 51 ++++++- 3 files changed, 78 insertions(+), 116 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql index 6725ce45e72a5..1a454179ef79f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql @@ -11,17 +11,19 @@ SELECT /* embedded single line */ 'embedded' AS `second`; SELECT /* both embedded and trailing single line */ 'both' AS third; -- trailing single line SELECT 'before multi-line' AS fourth; +--QUERY-DELIMITER-START -- [SPARK-28880] ANSI SQL: Bracketed comments /* This is an example of SQL which should not execute: * select 'multi-line'; */ SELECT 'after multi-line' AS fifth; +--QUERY-DELIMITER-END -- [SPARK-28880] ANSI SQL: Bracketed comments -- -- Nested comments -- - +--QUERY-DELIMITER-START /* SELECT 'trailing' as x1; -- inside block comment */ @@ -44,5 +46,5 @@ Hoo boy. Still two deep... Now just one deep... */ 'deeply nested example' AS sixth; - +--QUERY-DELIMITER-END /* and this is the end of the file */ diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out index 4ea49013a62d1..637c5561bd940 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 7 -- !query @@ -36,129 +36,32 @@ before multi-line -- !query /* This is an example of SQL which should not execute: - * select 'multi-line' --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* This is an example of SQL which should not execute: -^^^ - * select 'multi-line' - - --- !query -*/ + * select 'multi-line'; + */ SELECT 'after multi-line' AS fifth -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -extraneous input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ -SELECT 'after multi-line' AS fifth +after multi-line -- !query /* -SELECT 'trailing' as x1 --- !query schema -struct<> --- !query output 
-org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* -^^^ -SELECT 'trailing' as x1 - - --- !query +SELECT 'trailing' as x1; -- inside block comment */ /* This block comment surrounds a query which itself has a block comment... -SELECT /* embedded single line */ 'embedded' AS x2 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ - -/* This block comment surrounds a query which itself has a block comment... -SELECT /* embedded single line */ 'embedded' AS x2 - - --- !query +SELECT /* embedded single line */ 'embedded' AS x2; */ SELECT -- continued after the following block comments... /* Deeply nested comment. This includes a single apostrophe to make sure we aren't decoding this part as a string. -SELECT 'deep nest' AS n1 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -extraneous input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ - -SELECT -- continued after the following block comments... -/* Deeply nested comment. - This includes a single apostrophe to make sure we aren't decoding this part as a string. -SELECT 'deep nest' AS n1 - - --- !query +SELECT 'deep nest' AS n1; /* Second level of nesting... -SELECT 'deeper nest' as n2 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* Second level of nesting... -^^^ -SELECT 'deeper nest' as n2 - - --- !query +SELECT 'deeper nest' as n2; /* Third level of nesting... 
-SELECT 'deepest nest' as n3 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* Third level of nesting... -^^^ -SELECT 'deepest nest' as n3 - - --- !query +SELECT 'deepest nest' as n3; */ Hoo boy. Still two deep... */ @@ -170,11 +73,27 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) +mismatched input ''embedded'' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 6, pos 34) == SQL == +/* +SELECT 'trailing' as x1; -- inside block comment +*/ + +/* This block comment surrounds a query which itself has a block comment... +SELECT /* embedded single line */ 'embedded' AS x2; +----------------------------------^^^ +*/ + +SELECT -- continued after the following block comments... +/* Deeply nested comment. + This includes a single apostrophe to make sure we aren't decoding this part as a string. +SELECT 'deep nest' AS n1; +/* Second level of nesting... +SELECT 'deeper nest' as n2; +/* Third level of nesting... +SELECT 'deepest nest' as n3; */ -^^^ Hoo boy. Still two deep... */ Now just one deep... diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 6b9e5bbd3c961..da4727f6a98cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import java.io.File import java.util.{Locale, TimeZone} +import java.util.regex.Pattern +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal import org.apache.spark.{SparkConf, SparkException} @@ -62,7 +64,12 @@ import org.apache.spark.tags.ExtendedSQLTest * }}} * * The format for input files is simple: - * 1. A list of SQL queries separated by semicolon. + * 1. A list of SQL queries separated by semicolons by default. If the semicolon cannot effectively + * separate the SQL queries in the test file(e.g. bracketed comments), please use + * --QUERY-DELIMITER-START and --QUERY-DELIMITER-END. Lines starting with + * --QUERY-DELIMITER-START and --QUERY-DELIMITER-END represent the beginning and end of a query, + * respectively. 
Code that is not surrounded by lines that begin with --QUERY-DELIMITER-START + * and --QUERY-DELIMITER-END is still separated by semicolons. * 2. Lines starting with -- are treated as comments and ignored. * 3. Lines starting with --SET are used to specify the configs when running this testing file. You * can set multiple configs in one --SET, using comma to separate them. Or you can use multiple @@ -246,9 +253,15 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { /** Run a test case. */ protected def runTest(testCase: TestCase): Unit = { + def splitWithSemicolon(seq: Seq[String]) = { + seq.mkString("\n").split("(?<=[^\\\\]);") + } val input = fileToString(new File(testCase.inputFile)) - val (comments, code) = input.split("\n").partition(_.trim.startsWith("--")) + val (comments, code) = input.split("\n").partition { line => + val newLine = line.trim + newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER") + } // If `--IMPORT` found, load code from another test case file, then insert them // into the head in this test. @@ -261,10 +274,38 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { } }.flatten + val allCode = importedCode ++ code + val tempQueries = if (allCode.exists(_.trim.startsWith("--QUERY-DELIMITER"))) { + // Although the loop is heavy, only used for bracketed comments test. + val querys = new ArrayBuffer[String] + val otherCodes = new ArrayBuffer[String] + var tempStr = "" + var start = false + for (c <- allCode) { + if (c.trim.startsWith("--QUERY-DELIMITER-START")) { + start = true + querys ++= splitWithSemicolon(otherCodes.toSeq) + otherCodes.clear() + } else if (c.trim.startsWith("--QUERY-DELIMITER-END")) { + start = false + querys += s"\n${tempStr.stripSuffix(";")}" + tempStr = "" + } else if (start) { + tempStr += s"\n$c" + } else { + otherCodes += c + } + } + if (otherCodes.nonEmpty) { + querys ++= splitWithSemicolon(otherCodes.toSeq) + } + querys.toSeq + } else { + splitWithSemicolon(allCode).toSeq + } + // List of SQL queries to run - // note: this is not a robust way to split queries using semicolon, but works for now. - val queries = (importedCode ++ code).mkString("\n").split("(?<=[^\\\\]);") - .map(_.trim).filter(_ != "").toSeq + val queries = tempQueries.map(_.trim).filter(_ != "").toSeq // Fix misplacement when comment is at the end of the query. .map(_.split("\n").filterNot(_.startsWith("--")).mkString("\n")).map(_.trim).filter(_ != "") From 074712e329b347f769f8c009949c7845e95b3212 Mon Sep 17 00:00:00 2001 From: Liang Zhang Date: Thu, 13 Feb 2020 23:55:13 +0800 Subject: [PATCH 0067/1280] [SPARK-30762] Add dtype=float32 support to vector_to_array UDF ### What changes were proposed in this pull request? In this PR, we add a parameter in the python function vector_to_array(col) that allows converting to a column of arrays of Float (32bits) in scala, which would be mapped to a numpy array of dtype=float32. ### Why are the changes needed? In the downstream ML training, using float32 instead of float64 (default) would allow a larger batch size, i.e., allow more data to fit in the memory. ### Does this PR introduce any user-facing change? Yes. Old: `vector_to_array()` only take one param ``` df.select(vector_to_array("colA"), ...) ``` New: `vector_to_array()` can take an additional optional param: `dtype` = "float32" (or "float64") ``` df.select(vector_to_array("colA", "float32"), ...) ``` ### How was this patch tested? Unit test in scala. doctest in python. Closes #27522 from liangz1/udf-float32. 
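For illustration (an editorial sketch, not taken from the patch), the snippet below exercises the new optional `dtype` argument from the Scala side; the input column name follows the patch's tests, while the output aliases are invented here.

```scala
// Assumes a SparkSession named `spark` (e.g. spark-shell) with the Spark ML jars on the classpath.
import spark.implicits._
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors

val df = Seq(Tuple1(Vectors.dense(1.0, 2.0, 3.0))).toDF("vec")

df.select(
  vector_to_array($"vec").alias("f64"),                     // array<double>, default dtype = "float64"
  vector_to_array($"vec", dtype = "float32").alias("f32")   // array<float>, half the memory per element
).printSchema()
```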
Authored-by: Liang Zhang Signed-off-by: WeichenXu (cherry picked from commit 82d0aa37ae521231d8067e473c6ea79a118a115a) Signed-off-by: WeichenXu --- .../scala/org/apache/spark/ml/functions.scala | 34 ++++++++++++++++--- .../org/apache/spark/ml/FunctionsSuite.scala | 33 +++++++++++++++--- python/pyspark/ml/functions.py | 27 ++++++++++++--- 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index 1faf562c4d896..0f03231079866 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.Since -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{SparseVector, Vector} import org.apache.spark.mllib.linalg.{Vector => OldVector} import org.apache.spark.sql.Column import org.apache.spark.sql.functions.udf @@ -27,7 +27,6 @@ import org.apache.spark.sql.functions.udf @Since("3.0.0") object functions { // scalastyle:on - private val vectorToArrayUdf = udf { vec: Any => vec match { case v: Vector => v.toArray @@ -39,10 +38,37 @@ object functions { } }.asNonNullable() + private val vectorToArrayFloatUdf = udf { vec: Any => + vec match { + case v: SparseVector => + val data = new Array[Float](v.size) + v.foreachActive { (index, value) => data(index) = value.toFloat } + data + case v: Vector => v.toArray.map(_.toFloat) + case v: OldVector => v.toArray.map(_.toFloat) + case v => throw new IllegalArgumentException( + "function vector_to_array requires a non-null input argument and input type must be " + + "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + + s"but got ${ if (v == null) "null" else v.getClass.getName }.") + } + }.asNonNullable() + /** * Converts a column of MLlib sparse/dense vectors into a column of dense arrays. - * + * @param v: the column of MLlib sparse/dense vectors + * @param dtype: the desired underlying data type in the returned array + * @return an array<float> if dtype is float32, or array<double> if dtype is float64 * @since 3.0.0 */ - def vector_to_array(v: Column): Column = vectorToArrayUdf(v) + def vector_to_array(v: Column, dtype: String = "float64"): Column = { + if (dtype == "float64") { + vectorToArrayUdf(v) + } else if (dtype == "float32") { + vectorToArrayFloatUdf(v) + } else { + throw new IllegalArgumentException( + s"Unsupported dtype: $dtype. Valid values: float64, float32." 
+ ) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala index 2f5062c689fc7..3dd9a7d8ec85d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala @@ -34,9 +34,8 @@ class FunctionsSuite extends MLTest { (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) ).toDF("vec", "oldVec") - val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) - .as[(Seq[Double], Seq[Double])] - .collect().toSeq + val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) + .as[(Seq[Double], Seq[Double])].collect().toSeq val expected = Seq( (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)), @@ -50,7 +49,6 @@ class FunctionsSuite extends MLTest { (null, null, 0) ).toDF("vec", "oldVec", "label") - for ((colName, valType) <- Seq( ("vec", "null"), ("oldVec", "null"), ("label", "java.lang.Integer"))) { val thrown1 = intercept[SparkException] { @@ -61,5 +59,32 @@ class FunctionsSuite extends MLTest { "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + s"but got ${valType}")) } + + val df3 = Seq( + (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)), + (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) + ).toDF("vec", "oldVec") + val dfArrayFloat = df3.select( + vector_to_array('vec, dtype = "float32"), vector_to_array('oldVec, dtype = "float32")) + + // Check values are correct + val result3 = dfArrayFloat.as[(Seq[Float], Seq[Float])].collect().toSeq + + val expected3 = Seq( + (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)), + (Seq(2.0, 0.0, 3.0), Seq(20.0, 0.0, 30.0)) + ) + assert(result3 === expected3) + + // Check data types are correct + assert(dfArrayFloat.schema.simpleString === + "struct,UDF(oldVec):array>") + + val thrown2 = intercept[IllegalArgumentException] { + df3.select( + vector_to_array('vec, dtype = "float16"), vector_to_array('oldVec, dtype = "float16")) + } + assert(thrown2.getMessage.contains( + s"Unsupported dtype: float16. Valid values: float64, float32.")) } } diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py index 2b4d8ddcd00a8..ec164f34bc4db 100644 --- a/python/pyspark/ml/functions.py +++ b/python/pyspark/ml/functions.py @@ -19,10 +19,15 @@ from pyspark.sql.column import Column, _to_java_column -@since(3.0) -def vector_to_array(col): +@since("3.0.0") +def vector_to_array(col, dtype="float64"): """ Converts a column of MLlib sparse/dense vectors into a column of dense arrays. + :param col: A string of the column name or a Column + :param dtype: The data type of the output array. Valid values: "float64" or "float32". + :return: The converted column of dense arrays. + + .. versionadded:: 3.0.0 >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.functions import vector_to_array @@ -32,14 +37,26 @@ def vector_to_array(col): ... (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]), ... OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))], ... ["vec", "oldVec"]) - >>> df.select(vector_to_array("vec").alias("vec"), - ... vector_to_array("oldVec").alias("oldVec")).collect() + >>> df1 = df.select(vector_to_array("vec").alias("vec"), + ... 
vector_to_array("oldVec").alias("oldVec")) + >>> df1.collect() + [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]), + Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])] + >>> df2 = df.select(vector_to_array("vec", "float32").alias("vec"), + ... vector_to_array("oldVec", "float32").alias("oldVec")) + >>> df2.collect() [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]), Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])] + >>> df1.schema.fields + [StructField(vec,ArrayType(DoubleType,false),false), + StructField(oldVec,ArrayType(DoubleType,false),false)] + >>> df2.schema.fields + [StructField(vec,ArrayType(FloatType,false),false), + StructField(oldVec,ArrayType(FloatType,false),false)] """ sc = SparkContext._active_spark_context return Column( - sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col))) + sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype)) def _test(): From 82258aa4794d15bfe9cfd2b4bced790ef1d35e45 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 13 Feb 2020 10:53:55 -0800 Subject: [PATCH 0068/1280] [SPARK-30703][SQL][DOCS] Add a document for the ANSI mode ### What changes were proposed in this pull request? This pr intends to add a document for the ANSI mode; Screen Shot 2020-02-13 at 8 08 52 Screen Shot 2020-02-13 at 8 09 13 Screen Shot 2020-02-13 at 8 09 26 Screen Shot 2020-02-13 at 8 09 38 ### Why are the changes needed? For better document coverage and usability. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? N/A Closes #27489 from maropu/SPARK-30703. Authored-by: Takeshi Yamamuro Signed-off-by: Gengliang Wang (cherry picked from commit 3c4044ea77fe3b1268b52744cd4f1ae61f17a9a8) Signed-off-by: Gengliang Wang --- docs/_data/menu-sql.yaml | 11 +- ...keywords.md => sql-ref-ansi-compliance.md} | 125 +++++++++++++++++- docs/sql-ref-arithmetic-ops.md | 22 --- 3 files changed, 132 insertions(+), 26 deletions(-) rename docs/{sql-keywords.md => sql-ref-ansi-compliance.md} (82%) delete mode 100644 docs/sql-ref-arithmetic-ops.md diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 241ec399d7bd5..1e343f630f88e 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -80,6 +80,15 @@ url: sql-ref-null-semantics.html - text: NaN Semantics url: sql-ref-nan-semantics.html + - text: ANSI Compliance + url: sql-ref-ansi-compliance.html + subitems: + - text: Arithmetic Operations + url: sql-ref-ansi-compliance.html#arithmetic-operations + - text: Type Conversion + url: sql-ref-ansi-compliance.html#type-conversion + - text: SQL Keywords + url: sql-ref-ansi-compliance.html#sql-keywords - text: SQL Syntax url: sql-ref-syntax.html subitems: @@ -214,5 +223,3 @@ url: sql-ref-syntax-aux-resource-mgmt-list-file.html - text: LIST JAR url: sql-ref-syntax-aux-resource-mgmt-list-jar.html - - text: Arithmetic operations - url: sql-ref-arithmetic-ops.html diff --git a/docs/sql-keywords.md b/docs/sql-ref-ansi-compliance.md similarity index 82% rename from docs/sql-keywords.md rename to docs/sql-ref-ansi-compliance.md index 9e4a3c54100c6..d02383518b967 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-ref-ansi-compliance.md @@ -1,7 +1,7 @@ --- layout: global -title: Spark SQL Keywords -displayTitle: Spark SQL Keywords +title: ANSI Compliance +displayTitle: ANSI Compliance license: | Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. 
See the NOTICE file distributed with @@ -19,6 +19,127 @@ license: | limitations under the License. --- +Spark SQL has two options to comply with the SQL standard: `spark.sql.ansi.enabled` and `spark.sql.storeAssignmentPolicy` (See a table below for details). +When `spark.sql.ansi.enabled` is set to `true`, Spark SQL follows the standard in basic behaviours (e.g., arithmetic operations, type conversion, and SQL parsing). +Moreover, Spark SQL has an independent option to control implicit casting behaviours when inserting rows in a table. +The casting behaviours are defined as store assignment rules in the standard. +When `spark.sql.storeAssignmentPolicy` is set to `ANSI`, Spark SQL complies with the ANSI store assignment rules. + + + + + + + + + + + + + +
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.ansi.enabled</code></td>
+  <td>false</td>
+  <td>
+    When true, Spark tries to conform to the ANSI SQL specification: <br/>
+    1. Spark will throw a runtime exception if an overflow occurs in any operation on integral/decimal field. <br/>
+    2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in the SQL parser.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.storeAssignmentPolicy</code></td>
+  <td>ANSI</td>
+  <td>
+    When inserting a value into a column with different data type, Spark will perform type coercion.
+    Currently, we support 3 policies for the type coercion rules: ANSI, legacy and strict. With ANSI policy,
+    Spark performs the type coercion as per ANSI SQL. In practice, the behavior is mostly the same as PostgreSQL.
+    It disallows certain unreasonable type conversions such as converting string to int or double to boolean.
+    With legacy policy, Spark allows the type coercion as long as it is a valid Cast, which is very loose.
+    e.g. converting string to int or double to boolean is allowed.
+    It is also the only behavior in Spark 2.x and it is compatible with Hive.
+    With strict policy, Spark doesn't allow any possible precision loss or data truncation in type coercion,
+    e.g. converting double to int or decimal to double is not allowed.
+  </td>
+</tr>
+</table>
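Both options are runtime SQL configurations; as a usage sketch, they can also be toggled per session with the `SET` command:

{% highlight sql %}
-- Enable the ANSI behaviours for arithmetic, casting and SQL parsing in this session.
SET spark.sql.ansi.enabled=true;

-- Apply the ANSI store assignment rules to table insertions in this session.
SET spark.sql.storeAssignmentPolicy=ANSI;
{% endhighlight %}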
+ +The following subsections present behaviour changes in arithmetic operations, type conversions, and SQL parsing when the ANSI mode enabled. + +### Arithmetic Operations + +In Spark SQL, arithmetic operations performed on numeric types (with the exception of decimal) are not checked for overflows by default. +This means that in case an operation causes overflows, the result is the same that the same operation returns in a Java/Scala program (e.g., if the sum of 2 integers is higher than the maximum value representable, the result is a negative number). +On the other hand, Spark SQL returns null for decimal overflows. +When `spark.sql.ansi.enabled` is set to `true` and an overflow occurs in numeric and interval arithmetic operations, it throws an arithmetic exception at runtime. + +{% highlight sql %} +-- `spark.sql.ansi.enabled=true` +SELECT 2147483647 + 1; + + java.lang.ArithmeticException: integer overflow + +-- `spark.sql.ansi.enabled=false` +SELECT 2147483647 + 1; + + +----------------+ + |(2147483647 + 1)| + +----------------+ + | -2147483648| + +----------------+ + +{% endhighlight %} + +### Type Conversion + +Spark SQL has three kinds of type conversions: explicit casting, type coercion, and store assignment casting. +When `spark.sql.ansi.enabled` is set to `true`, explicit casting by `CAST` syntax throws a runtime exception for illegal cast patterns defined in the standard, e.g. casts from a string to an integer. +On the other hand, `INSERT INTO` syntax throws an analysis exception when the ANSI mode enabled via `spark.sql.storeAssignmentPolicy=ANSI`. + +Currently, the ANSI mode affects explicit casting and assignment casting only. +In future releases, the behaviour of type coercion might change along with the other two type conversion rules. + +{% highlight sql %} +-- Examples of explicit casting + +-- `spark.sql.ansi.enabled=true` +SELECT CAST('a' AS INT); + + java.lang.NumberFormatException: invalid input syntax for type numeric: a + +SELECT CAST(2147483648L AS INT); + + java.lang.ArithmeticException: Casting 2147483648 to int causes overflow + +-- `spark.sql.ansi.enabled=false` (This is a default behaviour) +SELECT CAST('a' AS INT); + + +--------------+ + |CAST(a AS INT)| + +--------------+ + | null| + +--------------+ + +SELECT CAST(2147483648L AS INT); + + +-----------------------+ + |CAST(2147483648 AS INT)| + +-----------------------+ + | -2147483648| + +-----------------------+ + +-- Examples of store assignment rules +CREATE TABLE t (v INT); + +-- `spark.sql.storeAssignmentPolicy=ANSI` +INSERT INTO t VALUES ('1'); + + org.apache.spark.sql.AnalysisException: Cannot write incompatible data to table '`default`.`t`': + - Cannot safely cast 'v': StringType to IntegerType; + +-- `spark.sql.storeAssignmentPolicy=LEGACY` (This is a legacy behaviour until Spark 2.x) +INSERT INTO t VALUES ('1'); +SELECT * FROM t; + + +---+ + | v| + +---+ + | 1| + +---+ + +{% endhighlight %} + +### SQL Keywords + When `spark.sql.ansi.enabled` is true, Spark SQL will use the ANSI mode parser. In this mode, Spark SQL has two kinds of keywords: * Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc. 
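As an illustrative example of the reserved-keyword rule (`ORDER` is a reserved keyword in the ANSI standard, so it cannot be used as an identifier once the ANSI parser is enabled; the table name below is hypothetical):

{% highlight sql %}
-- `spark.sql.ansi.enabled=true`
-- `order` is reserved and cannot be used as a column name, so this statement fails to parse.
CREATE TABLE t2 (order INT);

-- `spark.sql.ansi.enabled=false` (default)
-- `order` is a non-reserved keyword here, so the same statement is accepted.
CREATE TABLE t2 (order INT);
{% endhighlight %}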
diff --git a/docs/sql-ref-arithmetic-ops.md b/docs/sql-ref-arithmetic-ops.md deleted file mode 100644 index 7bc8ffe31c990..0000000000000 --- a/docs/sql-ref-arithmetic-ops.md +++ /dev/null @@ -1,22 +0,0 @@ ---- -layout: global -title: Arithmetic Operations -displayTitle: Arithmetic Operations -license: | - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. ---- - -Operations performed on numeric types (with the exception of decimal) are not checked for overflow. This means that in case an operation causes an overflow, the result is the same that the same operation returns in a Java/Scala program (eg. if the sum of 2 integers is higher than the maximum value representable, the result is a negative number). From 78bd4b34ca7e0834b8b3878cd74b3f59b46b4f90 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Feb 2020 20:09:24 +0100 Subject: [PATCH 0069/1280] [SPARK-30751][SQL] Combine the skewed readers into one in AQE skew join optimizations ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/26434 This PR use one special shuffle reader for skew join, so that we only have one join after optimization. In order to do that, this PR 1. add a very general `CustomShuffledRowRDD` which support all kind of partition arrangement. 2. move the logic of coalescing shuffle partitions to a util function, and call it during skew join optimization, to totally decouple with the `ReduceNumShufflePartitions` rule. It's too complicated to interfere skew join with `ReduceNumShufflePartitions`, as you need to consider the size of split partitions which don't respect target size already. ### Why are the changes needed? The current skew join optimization has a serious performance issue: the size of the query plan depends on the number and size of skewed partitions. ### Does this PR introduce any user-facing change? no ### How was this patch tested? 
existing tests test UI manually: ![image](https://user-images.githubusercontent.com/3182036/74357390-cfb30480-4dfa-11ea-83f6-825d1b9379ca.png) explain output ``` AdaptiveSparkPlan(isFinalPlan=true) +- OverwriteByExpression org.apache.spark.sql.execution.datasources.noop.NoopTable$403a2ed5, [AlwaysTrue()], org.apache.spark.sql.util.CaseInsensitiveStringMap1f +- *(5) SortMergeJoin(skew=true) [key1#2L], [key2#6L], Inner :- *(3) Sort [key1#2L ASC NULLS FIRST], false, 0 : +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB) : +- ShuffleQueryStage 0 : +- Exchange hashpartitioning(key1#2L, 200), true, [id=#53] : +- *(1) Project [(id#0L % 2) AS key1#2L] : +- *(1) Filter isnotnull((id#0L % 2)) : +- *(1) Range (0, 100000, step=1, splits=6) +- *(4) Sort [key2#6L ASC NULLS FIRST], false, 0 +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB) +- ShuffleQueryStage 1 +- Exchange hashpartitioning(key2#6L, 200), true, [id=#64] +- *(2) Project [((id#4L % 2) + 1) AS key2#6L] +- *(2) Filter isnotnull(((id#4L % 2) + 1)) +- *(2) Range (0, 100000, step=1, splits=6) ``` Closes #27493 from cloud-fan/aqe. Authored-by: Wenchen Fan Signed-off-by: herman (cherry picked from commit a4ceea6868002b88161517b14b94a2006be8af1b) Signed-off-by: herman --- .../spark/sql/execution/ShuffledRowRDD.scala | 23 +- .../adaptive/CustomShuffledRowRDD.scala | 113 +++++++ .../adaptive/OptimizeLocalShuffleReader.scala | 2 +- .../adaptive/OptimizeSkewedJoin.scala | 276 +++++++++++------- .../adaptive/ReduceNumShufflePartitions.scala | 157 ++-------- .../adaptive/ShufflePartitionsCoalescer.scala | 112 +++++++ .../adaptive/SkewedShuffledRowRDD.scala | 78 ----- .../exchange/ShuffleExchangeExec.scala | 21 +- .../execution/joins/SortMergeJoinExec.scala | 13 +- .../ReduceNumShufflePartitionsSuite.scala | 210 +------------ .../ShufflePartitionsCoalescerSuite.scala | 220 ++++++++++++++ .../adaptive/AdaptiveQueryExecSuite.scala | 219 +++++--------- 12 files changed, 741 insertions(+), 703 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index efa493923ccc1..4c19f95796d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -116,7 +116,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], metrics: Map[String, SQLMetric], - specifiedPartitionIndices: Option[Array[(Int, Int)]] = None) + specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) { @@ -126,8 +126,8 @@ class ShuffledRowRDD( private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions - private[this] val partitionStartIndices: Array[Int] = specifiedPartitionIndices match { - case Some(indices) => 
indices.map(_._1) + private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match { + case Some(indices) => indices case None => // When specifiedPartitionStartIndices is not defined, every post-shuffle partition // corresponds to a pre-shuffle partition. @@ -142,15 +142,16 @@ class ShuffledRowRDD( override val partitioner: Option[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { - specifiedPartitionIndices match { - case Some(indices) => - Array.tabulate[Partition](indices.length) { i => - new ShuffledRowRDDPartition(i, indices(i)._1, indices(i)._2) - } - case None => - Array.tabulate[Partition](numPreShufflePartitions) { i => - new ShuffledRowRDDPartition(i, i, i + 1) + assert(partitionStartIndices.length == part.numPartitions) + Array.tabulate[Partition](partitionStartIndices.length) { i => + val startIndex = partitionStartIndices(i) + val endIndex = + if (i < partitionStartIndices.length - 1) { + partitionStartIndices(i + 1) + } else { + numPreShufflePartitions } + new ShuffledRowRDDPartition(i, startIndex, endIndex) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala new file mode 100644 index 0000000000000..5aba57443d632 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} + +sealed trait ShufflePartitionSpec + +// A partition that reads data of one reducer. +case class SinglePartitionSpec(reducerIndex: Int) extends ShufflePartitionSpec + +// A partition that reads data of multiple reducers, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). +case class CoalescedPartitionSpec( + startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec + +// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to +// `endMapIndex` (exclusive). +case class PartialPartitionSpec( + reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends ShufflePartitionSpec + +private final case class CustomShufflePartition( + index: Int, spec: ShufflePartitionSpec) extends Partition + +// TODO: merge this with `ShuffledRowRDD`, and replace `LocalShuffledRowRDD` with this RDD. 
+class CustomShuffledRowRDD( + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric], + partitionSpecs: Array[ShufflePartitionSpec]) + extends RDD[InternalRow](dependency.rdd.context, Nil) { + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](partitionSpecs.length) { i => + CustomShufflePartition(i, partitionSpecs(i)) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + partition.asInstanceOf[CustomShufflePartition].spec match { + case SinglePartitionSpec(reducerIndex) => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + } + + case PartialPartitionSpec(_, startMapIndex, endMapIndex) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + val reader = split.asInstanceOf[CustomShufflePartition].spec match { + case SinglePartitionSpec(reducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + + case PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) => + SparkEnv.get.shuffleManager.getReaderForRange( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + } + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index a8d8f358ab660..e95441e28aafe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -71,7 +71,7 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { plan match { case c @ CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) => LocalShuffleReaderExec( - s, getPartitionStartIndices(s, Some(c.partitionIndices.length))) + s, getPartitionStartIndices(s, Some(c.partitionStartIndices.length))) case s: ShuffleQueryStageExec => LocalShuffleReaderExec(s, getPartitionStartIndices(s, None)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 
74b7fbd317fc8..a716497c274b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.commons.io.FileUtils + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -44,11 +46,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { * partition size * spark.sql.adaptive.skewedPartitionFactor and also larger than * spark.sql.adaptive.skewedPartitionSizeThreshold. */ - private def isSkewed( - stats: MapOutputStatistics, - partitionId: Int, - medianSize: Long): Boolean = { - val size = stats.bytesByPartitionId(partitionId) + private def isSkewed(size: Long, medianSize: Long): Boolean = { size > medianSize * conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR) && size > conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD) } @@ -108,12 +106,12 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { stage.resultOption.get.asInstanceOf[MapOutputStatistics] } - private def supportSplitOnLeftPartition(joinType: JoinType) = { + private def canSplitLeftSide(joinType: JoinType) = { joinType == Inner || joinType == Cross || joinType == LeftSemi || joinType == LeftAnti || joinType == LeftOuter } - private def supportSplitOnRightPartition(joinType: JoinType) = { + private def canSplitRightSide(joinType: JoinType) = { joinType == Inner || joinType == Cross || joinType == RightOuter } @@ -130,17 +128,18 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { * 1. Check whether the shuffle partition is skewed based on the median size * and the skewed partition threshold in origin smj. * 2. Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4). - * And we will split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] + * And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] * based on the map size and the max split number. - * 3. Create the 3 smjs with separately reading the above mapper ranges and then join with - * the Partition0 in right side. - * 4. Finally union the above 3 split smjs and the origin smj. + * 3. Wrap the join left child with a special shuffle reader that reads each mapper range with one + * task, so total 3 tasks. + * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by + * 3 tasks separately. 
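+ * For illustration, if the 5 mappers of the skewed left partition 0 are split at map start
+ * indices [0, 2, 4], the left reader is given [PartialPartitionSpec(0, 0, 2),
+ * PartialPartitionSpec(0, 2, 4), PartialPartitionSpec(0, 4, 5)] and the right reader is given
+ * [SinglePartitionSpec(0), SinglePartitionSpec(0), SinglePartitionSpec(0)], so each of the
+ * 3 tasks joins one disjoint map range of the skewed left partition against the whole right
+ * partition 0.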
*/ def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { - case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, + case smj @ SortMergeJoinExec(_, _, joinType, _, s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _), s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _) - if (supportedJoinTypes.contains(joinType)) => + if supportedJoinTypes.contains(joinType) => val leftStats = getStatistics(left) val rightStats = getStatistics(right) val numPartitions = leftStats.bytesByPartitionId.length @@ -155,61 +154,134 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { |Right side partition size: |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)} """.stripMargin) + val canSplitLeft = canSplitLeftSide(joinType) + val canSplitRight = canSplitRightSide(joinType) + + val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + // This is used to delay the creation of non-skew partitions so that we can potentially + // coalesce them like `ReduceNumShufflePartitions` does. + val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int] + val leftSkewDesc = new SkewDesc + val rightSkewDesc = new SkewDesc + for (partitionIndex <- 0 until numPartitions) { + val leftSize = leftStats.bytesByPartitionId(partitionIndex) + val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft + val rightSize = rightStats.bytesByPartitionId(partitionIndex) + val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight + if (isLeftSkew || isRightSkew) { + if (nonSkewPartitionIndices.nonEmpty) { + // As soon as we see a skew, we'll "flush" out unhandled non-skew partitions. + createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p => + leftSidePartitions += p + rightSidePartitions += p + } + nonSkewPartitionIndices.clear() + } - val skewedPartitions = mutable.HashSet[Int]() - val subJoins = mutable.ArrayBuffer[SparkPlan]() - for (partitionId <- 0 until numPartitions) { - val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize) - val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize) - val leftMapIdStartIndices = if (isLeftSkew && supportSplitOnLeftPartition(joinType)) { - getMapStartIndices(left, partitionId) - } else { - Array(0) - } - val rightMapIdStartIndices = if (isRightSkew && supportSplitOnRightPartition(joinType)) { - getMapStartIndices(right, partitionId) - } else { - Array(0) - } + val leftParts = if (isLeftSkew) { + leftSkewDesc.addPartitionSize(leftSize) + createSkewPartitions( + partitionIndex, + getMapStartIndices(left, partitionIndex), + getNumMappers(left)) + } else { + Seq(SinglePartitionSpec(partitionIndex)) + } - if (leftMapIdStartIndices.length > 1 || rightMapIdStartIndices.length > 1) { - skewedPartitions += partitionId - for (i <- 0 until leftMapIdStartIndices.length; - j <- 0 until rightMapIdStartIndices.length) { - val leftEndMapId = if (i == leftMapIdStartIndices.length - 1) { - getNumMappers(left) - } else { - leftMapIdStartIndices(i + 1) - } - val rightEndMapId = if (j == rightMapIdStartIndices.length - 1) { - getNumMappers(right) - } else { - rightMapIdStartIndices(j + 1) + val rightParts = if (isRightSkew) { + rightSkewDesc.addPartitionSize(rightSize) + createSkewPartitions( + partitionIndex, + getMapStartIndices(right, partitionIndex), + getNumMappers(right)) + } else { + Seq(SinglePartitionSpec(partitionIndex)) + } + + for { + leftSidePartition <- leftParts + rightSidePartition <- 
rightParts + } { + leftSidePartitions += leftSidePartition + rightSidePartitions += rightSidePartition + } + } else { + // Add to `nonSkewPartitionIndices` first, and add real partitions later, in case we can + // coalesce the non-skew partitions. + nonSkewPartitionIndices += partitionIndex + // If this is the last partition, add real partition immediately. + if (partitionIndex == numPartitions - 1) { + createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p => + leftSidePartitions += p + rightSidePartitions += p } - // TODO: we may can optimize the sort merge join to broad cast join after - // obtaining the raw data size of per partition, - val leftSkewedReader = SkewedPartitionReaderExec( - left, partitionId, leftMapIdStartIndices(i), leftEndMapId) - val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId, - rightMapIdStartIndices(j), rightEndMapId) - subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, - s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader), true) + nonSkewPartitionIndices.clear() } } } - logDebug(s"number of skewed partitions is ${skewedPartitions.size}") - if (skewedPartitions.nonEmpty) { - val optimizedSmj = smj.copy( - left = s1.copy(child = PartialShuffleReaderExec(left, skewedPartitions.toSet)), - right = s2.copy(child = PartialShuffleReaderExec(right, skewedPartitions.toSet)), - isPartial = true) - subJoins += optimizedSmj - UnionExec(subJoins) + + logDebug("number of skewed partitions: " + + s"left ${leftSkewDesc.numPartitions}, right ${rightSkewDesc.numPartitions}") + if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) { + val newLeft = SkewJoinShuffleReaderExec( + left, leftSidePartitions.toArray, leftSkewDesc.toString) + val newRight = SkewJoinShuffleReaderExec( + right, rightSidePartitions.toArray, rightSkewDesc.toString) + smj.copy( + left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true) } else { smj } } + private def createNonSkewPartitions( + leftStats: MapOutputStatistics, + rightStats: MapOutputStatistics, + nonSkewPartitionIndices: Seq[Int]): Seq[ShufflePartitionSpec] = { + assert(nonSkewPartitionIndices.nonEmpty) + if (nonSkewPartitionIndices.length == 1) { + Seq(SinglePartitionSpec(nonSkewPartitionIndices.head)) + } else { + val startIndices = ShufflePartitionsCoalescer.coalescePartitions( + Array(leftStats, rightStats), + firstPartitionIndex = nonSkewPartitionIndices.head, + // `lastPartitionIndex` is exclusive. + lastPartitionIndex = nonSkewPartitionIndices.last + 1, + advisoryTargetSize = conf.targetPostShuffleInputSize) + startIndices.indices.map { i => + val startIndex = startIndices(i) + val endIndex = if (i == startIndices.length - 1) { + // `endIndex` is exclusive. + nonSkewPartitionIndices.last + 1 + } else { + startIndices(i + 1) + } + // Do not create `CoalescedPartitionSpec` if only need to read a singe partition. 
+ if (startIndex + 1 == endIndex) { + SinglePartitionSpec(startIndex) + } else { + CoalescedPartitionSpec(startIndex, endIndex) + } + } + } + } + + private def createSkewPartitions( + reducerIndex: Int, + mapStartIndices: Array[Int], + numMappers: Int): Seq[PartialPartitionSpec] = { + mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { + numMappers + } else { + mapStartIndices(i + 1) + } + PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) + } + } + override def apply(plan: SparkPlan): SparkPlan = { if (!conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED)) { return plan @@ -248,79 +320,69 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { } } -/** - * A wrapper of shuffle query stage, which submits one reduce task to read a single - * shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex). - * This is used to increase the parallelism when reading skewed partitions. - * - * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange - * node during canonicalization. - * @param partitionIndex The pre shuffle partition index. - * @param startMapIndex The start map index. - * @param endMapIndex The end map index. - */ -case class SkewedPartitionReaderExec( - child: QueryStageExec, - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int) extends LeafExecNode { +private class SkewDesc { + private[this] var numSkewedPartitions: Int = 0 + private[this] var totalSize: Long = 0 + private[this] var maxSize: Long = 0 + private[this] var minSize: Long = 0 - override def output: Seq[Attribute] = child.output + def numPartitions: Int = numSkewedPartitions - override def outputPartitioning: Partitioning = { - UnknownPartitioning(1) + def addPartitionSize(size: Long): Unit = { + if (numSkewedPartitions == 0) { + maxSize = size + minSize = size + } + numSkewedPartitions += 1 + totalSize += size + if (size > maxSize) maxSize = size + if (size < minSize) minSize = size } - private var cachedSkewedShuffleRDD: SkewedShuffledRowRDD = null - override def doExecute(): RDD[InternalRow] = { - if (cachedSkewedShuffleRDD == null) { - cachedSkewedShuffleRDD = child match { - case stage: ShuffleQueryStageExec => - stage.shuffle.createSkewedShuffleRDD(partitionIndex, startMapIndex, endMapIndex) - case _ => - throw new IllegalStateException("operating on canonicalization plan") - } + override def toString: String = { + if (numSkewedPartitions == 0) { + "no skewed partition" + } else { + val maxSizeStr = FileUtils.byteCountToDisplaySize(maxSize) + val minSizeStr = FileUtils.byteCountToDisplaySize(minSize) + val avgSizeStr = FileUtils.byteCountToDisplaySize(totalSize / numSkewedPartitions) + s"$numSkewedPartitions skewed partitions with " + + s"size(max=$maxSizeStr, min=$minSizeStr, avg=$avgSizeStr)" } - cachedSkewedShuffleRDD } } /** - * A wrapper of shuffle query stage, which skips some partitions when reading the shuffle blocks. + * A wrapper of shuffle query stage, which follows the given partition arrangement. * * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during * canonicalization. - * @param excludedPartitions The partitions to skip when reading. + * @param partitionSpecs The partition specs that defines the arrangement. + * @param skewDesc The description of the skewed partitions. 
*/ -case class PartialShuffleReaderExec( - child: QueryStageExec, - excludedPartitions: Set[Int]) extends UnaryExecNode { +case class SkewJoinShuffleReaderExec( + child: SparkPlan, + partitionSpecs: Array[ShufflePartitionSpec], + skewDesc: String) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = { - UnknownPartitioning(1) + UnknownPartitioning(partitionSpecs.length) } - private def shuffleExchange(): ShuffleExchangeExec = child match { - case stage: ShuffleQueryStageExec => stage.shuffle - case _ => - throw new IllegalStateException("operating on canonicalization plan") - } - - private def getPartitionIndexRanges(): Array[(Int, Int)] = { - val length = shuffleExchange().shuffleDependency.partitioner.numPartitions - (0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray - } + override def stringArgs: Iterator[Any] = Iterator(skewDesc) private var cachedShuffleRDD: RDD[InternalRow] = null - override def doExecute(): RDD[InternalRow] = { + override protected def doExecute(): RDD[InternalRow] = { if (cachedShuffleRDD == null) { - cachedShuffleRDD = if (excludedPartitions.isEmpty) { - child.execute() - } else { - shuffleExchange().createShuffledRDD(Some(getPartitionIndexRanges())) + cachedShuffleRDD = child match { + case stage: ShuffleQueryStageExec => + new CustomShuffledRowRDD( + stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs) + case _ => + throw new IllegalStateException("operating on canonicalization plan") } } cachedShuffleRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala index 2c50b638b4d45..5bbcb14e008d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.adaptive -import scala.collection.mutable.{ArrayBuffer, HashSet} - import org.apache.spark.MapOutputStatistics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -29,24 +27,8 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf /** - * A rule to adjust the post shuffle partitions based on the map output statistics. - * - * The strategy used to determine the number of post-shuffle partitions is described as follows. - * To determine the number of post-shuffle partitions, we have a target input size for a - * post-shuffle partition. Once we have size statistics of all pre-shuffle partitions, we will do - * a pass of those statistics and pack pre-shuffle partitions with continuous indices to a single - * post-shuffle partition until adding another pre-shuffle partition would cause the size of a - * post-shuffle partition to be greater than the target size. 
- * - * For example, we have two stages with the following pre-shuffle partition size statistics: - * stage 1: [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB] - * stage 2: [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB] - * assuming the target input size is 128 MiB, we will have four post-shuffle partitions, - * which are: - * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MiB) - * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MiB) - * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB) - * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB) + * A rule to reduce the post shuffle partitions based on the map output statistics, which can + * avoid many small reduce tasks that hurt performance. */ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { @@ -54,28 +36,21 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.reducePostShufflePartitionsEnabled) { return plan } - // 'SkewedShufflePartitionReader' is added by us, so it's safe to ignore it when changing - // number of reducers. - val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedPartitionReaderExec]) - if (!leafNodes.forall(_.isInstanceOf[QueryStageExec])) { + if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])) { // If not all leaf nodes are query stages, it's not safe to reduce the number of // shuffle partitions, because we may break the assumption that all children of a spark plan // have same number of output partitions. return plan } - def collectShuffles(plan: SparkPlan): Seq[SparkPlan] = plan match { + def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match { case _: LocalShuffleReaderExec => Nil - case p: PartialShuffleReaderExec => Seq(p) + case _: SkewJoinShuffleReaderExec => Nil case stage: ShuffleQueryStageExec => Seq(stage) - case _ => plan.children.flatMap(collectShuffles) + case _ => plan.children.flatMap(collectShuffleStages) } - val shuffles = collectShuffles(plan) - val shuffleStages = shuffles.map { - case PartialShuffleReaderExec(s: ShuffleQueryStageExec, _) => s - case s: ShuffleQueryStageExec => s - } + val shuffleStages = collectShuffleStages(plan) // ShuffleExchanges introduced by repartition do not support changing the number of partitions. // We change the number of partitions in the stage only if all the ShuffleExchanges support it. if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) { @@ -94,110 +69,27 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { // partition) and a result of a SortMergeJoin (multiple partitions). 
val distinctNumPreShufflePartitions = validMetrics.map(stats => stats.bytesByPartitionId.length).distinct - val distinctExcludedPartitions = shuffles.map { - case PartialShuffleReaderExec(_, excludedPartitions) => excludedPartitions - case _: ShuffleQueryStageExec => Set.empty[Int] - }.distinct - if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1 - && distinctExcludedPartitions.length == 1) { - val excludedPartitions = distinctExcludedPartitions.head - val partitionIndices = estimatePartitionStartAndEndIndices( - validMetrics.toArray, excludedPartitions) + if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) { + val partitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions( + validMetrics.toArray, + firstPartitionIndex = 0, + lastPartitionIndex = distinctNumPreShufflePartitions.head, + advisoryTargetSize = conf.targetPostShuffleInputSize, + minNumPartitions = conf.minNumPostShufflePartitions) // This transformation adds new nodes, so we must use `transformUp` here. - // Even for shuffle exchange whose input RDD has 0 partition, we should still update its - // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same - // number of output partitions. - val visitedStages = HashSet.empty[Int] - plan.transformDown { - // Replace `PartialShuffleReaderExec` with `CoalescedShuffleReaderExec`, which keeps the - // "excludedPartition" requirement and also merges some partitions. - case PartialShuffleReaderExec(stage: ShuffleQueryStageExec, _) => - visitedStages.add(stage.id) - CoalescedShuffleReaderExec(stage, partitionIndices) - - // We are doing `transformDown`, so the `ShuffleQueryStageExec` may already be optimized - // and wrapped by `CoalescedShuffleReaderExec`. - case stage: ShuffleQueryStageExec if !visitedStages.contains(stage.id) => - visitedStages.add(stage.id) - CoalescedShuffleReaderExec(stage, partitionIndices) + val stageIds = shuffleStages.map(_.id).toSet + plan.transformUp { + // even for shuffle exchange whose input RDD has 0 partition, we should still update its + // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same + // number of output partitions. + case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) => + CoalescedShuffleReaderExec(stage, partitionStartIndices) } } else { plan } } } - - /** - * Estimates partition start and end indices for post-shuffle partitions based on - * mapOutputStatistics provided by all pre-shuffle stages and skip the omittedPartitions - * already handled in skewed partition optimization. - */ - // visible for testing. - private[sql] def estimatePartitionStartAndEndIndices( - mapOutputStatistics: Array[MapOutputStatistics], - excludedPartitions: Set[Int] = Set.empty): Array[(Int, Int)] = { - val minNumPostShufflePartitions = conf.minNumPostShufflePartitions - excludedPartitions.size - val advisoryTargetPostShuffleInputSize = conf.targetPostShuffleInputSize - // If minNumPostShufflePartitions is defined, it is possible that we need to use a - // value less than advisoryTargetPostShuffleInputSize as the target input size of - // a post shuffle task. - val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum - // The max at here is to make sure that when we have an empty table, we - // only have a single post-shuffle partition. - // There is no particular reason that we pick 16. We just need a number to - // prevent maxPostShuffleInputSize from being set to 0. 
- val maxPostShuffleInputSize = math.max( - math.ceil(totalPostShuffleInputSize / minNumPostShufflePartitions.toDouble).toLong, 16) - val targetPostShuffleInputSize = - math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) - - logInfo( - s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + - s"targetPostShuffleInputSize $targetPostShuffleInputSize.") - - // Make sure we do get the same number of pre-shuffle partitions for those stages. - val distinctNumPreShufflePartitions = - mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct - // The reason that we are expecting a single value of the number of pre-shuffle partitions - // is that when we add Exchanges, we set the number of pre-shuffle partitions - // (i.e. map output partitions) using a static setting, which is the value of - // spark.sql.shuffle.partitions. Even if two input RDDs are having different - // number of partitions, they will have the same number of pre-shuffle partitions - // (i.e. map output partitions). - assert( - distinctNumPreShufflePartitions.length == 1, - "There should be only one distinct value of the number pre-shuffle partitions " + - "among registered Exchange operator.") - - val partitionStartIndices = ArrayBuffer[Int]() - val partitionEndIndices = ArrayBuffer[Int]() - val numPartitions = distinctNumPreShufflePartitions.head - val includedPartitions = (0 until numPartitions).filter(!excludedPartitions.contains(_)) - val firstStartIndex = includedPartitions(0) - partitionStartIndices += firstStartIndex - var postShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId(firstStartIndex)).sum - var i = firstStartIndex - includedPartitions.drop(1).foreach { nextPartitionIndex => - val nextShuffleInputSize = - mapOutputStatistics.map(_.bytesByPartitionId(nextPartitionIndex)).sum - // If nextPartitionIndices is skewed and omitted, or including - // the nextShuffleInputSize would exceed the target partition size, - // then start a new partition. - if (nextPartitionIndex != i + 1 || - (postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize)) { - partitionEndIndices += i + 1 - partitionStartIndices += nextPartitionIndex - // reset postShuffleInputSize. - postShuffleInputSize = nextShuffleInputSize - i = nextPartitionIndex - } else { - postShuffleInputSize += nextShuffleInputSize - i += 1 - } - } - partitionEndIndices += i + 1 - partitionStartIndices.zip(partitionEndIndices).toArray - } } /** @@ -206,15 +98,16 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { * * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during * canonicalization. + * @param partitionStartIndices The start partition indices for the coalesced partitions. 
*/ case class CoalescedShuffleReaderExec( child: SparkPlan, - partitionIndices: Array[(Int, Int)]) extends UnaryExecNode { + partitionStartIndices: Array[Int]) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = { - UnknownPartitioning(partitionIndices.length) + UnknownPartitioning(partitionStartIndices.length) } private var cachedShuffleRDD: ShuffledRowRDD = null @@ -223,7 +116,7 @@ case class CoalescedShuffleReaderExec( if (cachedShuffleRDD == null) { cachedShuffleRDD = child match { case stage: ShuffleQueryStageExec => - stage.shuffle.createShuffledRDD(Some(partitionIndices)) + stage.shuffle.createShuffledRDD(Some(partitionStartIndices)) case _ => throw new IllegalStateException("operating on canonicalization plan") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala new file mode 100644 index 0000000000000..18f0585524aa2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.MapOutputStatistics +import org.apache.spark.internal.Logging + +object ShufflePartitionsCoalescer extends Logging { + + /** + * Coalesce the same range of partitions (`firstPartitionIndex`` to `lastPartitionIndex`, the + * start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that + * all the shuffles have the same number of partitions, and the partitions of same index will be + * read together by one task. + * + * The strategy used to determine the number of coalesced partitions is described as follows. + * To determine the number of coalesced partitions, we have a target size for a coalesced + * partition. Once we have size statistics of all shuffle partitions, we will do + * a pass of those statistics and pack shuffle partitions with continuous indices to a single + * coalesced partition until adding another shuffle partition would cause the size of a + * coalesced partition to be greater than the target size. 
+ * + * For example, we have two shuffles with the following partition size statistics: + * - shuffle 1 (5 partitions): [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB] + * - shuffle 2 (5 partitions): [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB] + * Assuming the target size is 128 MiB, we will have 4 coalesced partitions, which are: + * - coalesced partition 0: shuffle partition 0 (size 110 MiB) + * - coalesced partition 1: shuffle partition 1 (size 30 MiB) + * - coalesced partition 2: shuffle partition 2 (size 170 MiB) + * - coalesced partition 3: shuffle partition 3 and 4 (size 50 MiB) + * + * @return An array of partition indices which represents the coalesced partitions. For example, + * [0, 2, 3] means 3 coalesced partitions: [0, 2), [2, 3), [3, lastPartitionIndex] + */ + def coalescePartitions( + mapOutputStatistics: Array[MapOutputStatistics], + firstPartitionIndex: Int, + lastPartitionIndex: Int, + advisoryTargetSize: Long, + minNumPartitions: Int = 1): Array[Int] = { + // If `minNumPartitions` is very large, it is possible that we need to use a value less than + // `advisoryTargetSize` as the target size of a coalesced task. + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we only have a single + // coalesced partition. + // There is no particular reason that we pick 16. We just need a number to prevent + // `maxTargetSize` from being set to 0. + val maxTargetSize = math.max( + math.ceil(totalPostShuffleInputSize / minNumPartitions.toDouble).toLong, 16) + val targetSize = math.min(maxTargetSize, advisoryTargetSize) + + logInfo(s"advisory target size: $advisoryTargetSize, actual target size $targetSize.") + + // Make sure these shuffles have the same number of partitions. + val distinctNumShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of shuffle partitions + // is that when we add Exchanges, we set the number of shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // `spark.sql.shuffle.partitions`. Even if two input RDDs are having different + // number of partitions, they will have the same number of shuffle partitions + // (i.e. map output partitions). + assert( + distinctNumShufflePartitions.length == 1, + "There should be only one distinct value of the number of shuffle partitions " + + "among registered Exchange operators.") + + val splitPoints = ArrayBuffer[Int]() + splitPoints += firstPartitionIndex + var coalescedSize = 0L + var i = firstPartitionIndex + while (i < lastPartitionIndex) { + // We calculate the total size of i-th shuffle partitions from all shuffles. + var totalSizeOfCurrentPartition = 0L + var j = 0 + while (j < mapOutputStatistics.length) { + totalSizeOfCurrentPartition += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If including the `totalSizeOfCurrentPartition` would exceed the target size, then start a + // new coalesced partition. + if (i > firstPartitionIndex && coalescedSize + totalSizeOfCurrentPartition > targetSize) { + splitPoints += i + // reset postShuffleInputSize. 
+ coalescedSize = totalSizeOfCurrentPartition + } else { + coalescedSize += totalSizeOfCurrentPartition + } + i += 1 + } + + splitPoints.toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala deleted file mode 100644 index 52f793b24aa17..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} - -/** - * The [[Partition]] used by [[SkewedShuffledRowRDD]]. - */ -class SkewedShuffledRowRDDPartition(override val index: Int) extends Partition - -/** - * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used - * in Spark SQL adaptive execution to solve data skew issues. This RDD includes rearranged - * partitions from mappers. - * - * This RDD takes a [[ShuffleDependency]] (`dependency`), a partitionIndex - * and the range of startMapIndex to endMapIndex. - */ -class SkewedShuffledRowRDD( - var dependency: ShuffleDependency[Int, InternalRow, InternalRow], - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int, - metrics: Map[String, SQLMetric]) - extends RDD[InternalRow](dependency.rdd.context, Nil) { - - override def getDependencies: Seq[Dependency[_]] = List(dependency) - - override def getPartitions: Array[Partition] = { - Array(new SkewedShuffledRowRDDPartition(0)) - } - - override def getPreferredLocations(partition: Partition): Seq[String] = { - val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - tracker.getMapLocation(dependency, startMapIndex, endMapIndex) - } - - override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() - // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, - // as well as the `tempMetrics` for basic shuffle metrics. 
- val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) - - val reader = SparkEnv.get.shuffleManager.getReaderForRange( - dependency.shuffleHandle, - startMapIndex, - endMapIndex, - partitionIndex, - partitionIndex + 1, - context, - sqlMetricsReporter) - reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) - } - - override def clearDependencies() { - super.clearDependencies() - dependency = null - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index ffcd6c7783354..4b08da043b83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,11 +30,11 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{LocalShuffledRowRDD, SkewedShuffledRowRDD} +import org.apache.spark.sql.execution.adaptive.LocalShuffledRowRDD import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -49,11 +49,9 @@ case class ShuffleExchangeExec( child: SparkPlan, canChangeNumPartitions: Boolean = true) extends Exchange { - // NOTE: coordinator can be null after serialization/deserialization, - // e.g. it can be null on the Executor side private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) - private lazy val readMetrics = + private[sql] lazy val readMetrics = SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") @@ -90,9 +88,8 @@ case class ShuffleExchangeExec( writeMetrics) } - def createShuffledRDD( - partitionRanges: Option[Array[(Int, Int)]]): ShuffledRowRDD = { - new ShuffledRowRDD(shuffleDependency, readMetrics, partitionRanges) + def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledRowRDD = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices) } def createLocalShuffleRDD( @@ -100,14 +97,6 @@ case class ShuffleExchangeExec( new LocalShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndicesPerMapper) } - def createSkewedShuffleRDD( - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int): SkewedShuffledRowRDD = { - new SkewedShuffledRowRDD(shuffleDependency, - partitionIndex, startMapIndex, endMapIndex, readMetrics) - } - /** * Caches the created ShuffleRowRDD so we can reuse that. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6384aed6a78e0..62eea611556ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{PartialShuffleReaderExec, SkewedPartitionReaderExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -42,11 +41,17 @@ case class SortMergeJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isPartial: Boolean = false) extends BinaryExecNode with CodegenSupport { + isSkewJoin: Boolean = false) extends BinaryExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def nodeName: String = { + if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName + } + + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override def simpleStringWithNodeId(): String = { val opId = ExplainUtils.getOpId(this) s"$nodeName $joinType ($opId)".trim @@ -98,7 +103,9 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = { - if (isPartial) { + if (isSkewJoin) { + // We re-arrange the shuffle partitions to deal with skew join, and the new children + // partitioning doesn't satisfy `HashClusteredDistribution`. 
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil } else { HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala index 04b4d4f29f850..5565a0dd01840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql._ import org.apache.spark.sql.execution.adaptive._ -import org.apache.spark.sql.execution.adaptive.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions} +import org.apache.spark.sql.execution.adaptive.CoalescedShuffleReaderExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -52,212 +52,6 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA } } - private def checkEstimation( - rule: ReduceNumShufflePartitions, - bytesByPartitionIdArray: Array[Array[Long]], - expectedPartitionStartIndices: Array[Int]): Unit = { - val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { - case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) - } - val estimatedPartitionStartIndices = - rule.estimatePartitionStartAndEndIndices(mapOutputStatistics).map(_._1) - assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) - } - - private def createReduceNumShufflePartitionsRule( - advisoryTargetPostShuffleInputSize: Long, - minNumPostShufflePartitions: Int = 1): ReduceNumShufflePartitions = { - val conf = new SQLConf().copy( - SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE -> advisoryTargetPostShuffleInputSize, - SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS -> minNumPostShufflePartitions) - ReduceNumShufflePartitions(conf) - } - - test("test estimatePartitionStartIndices - 1 Exchange") { - val rule = createReduceNumShufflePartitionsRule(100L) - - { - // All bytes per partition are 0. - val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // Some bytes per partition are 0 and total size is less than the target size. - // 1 post-shuffle partition is needed. - val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partitions are needed. - val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) - val expectedPartitionStartIndices = Array[Int](0, 3) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // There are a few large pre-shuffle partitions. 
- val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // All pre-shuffle partitions are larger than the targeted size. - val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // The last pre-shuffle partition is in a single post-shuffle partition. - val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) - val expectedPartitionStartIndices = Array[Int](0, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - } - - test("test estimatePartitionStartIndices - 2 Exchanges") { - val rule = createReduceNumShufflePartitionsRule(100L) - - { - // If there are multiple values of the number of pre-shuffle partitions, - // we should see an assertion error. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) - val mapOutputStatistics = - Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) - intercept[AssertionError](rule.estimatePartitionStartAndEndIndices( - mapOutputStatistics)) - } - - { - // All bytes per partition are 0. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // Some bytes per partition are 0. - // 1 post-shuffle partition is needed. - val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 4 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // There are a few large pre-shuffle partitions. - val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // All pairs of pre-shuffle partitions are larger than the targeted size. 
- val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - } - - test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val rule = createReduceNumShufflePartitionsRule(100L, 2) - - { - // The minimal number of post-shuffle partitions is not enforced because - // the size of data is 0. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // The minimal number of post-shuffle partitions is enforced. - val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) - val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) - val expectedPartitionStartIndices = Array[Int](0, 3) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // The number of post-shuffle partitions is determined by the coordinator. - val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) - val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - } - - /////////////////////////////////////////////////////////////////////////// - // Query tests - /////////////////////////////////////////////////////////////////////////// - val numInputPartitions: Int = 10 def withSparkSession( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala new file mode 100644 index 0000000000000..fcfde83b2ffd5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{MapOutputStatistics, SparkFunSuite} +import org.apache.spark.sql.execution.adaptive.ShufflePartitionsCoalescer + +class ShufflePartitionsCoalescerSuite extends SparkFunSuite { + + private def checkEstimation( + bytesByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int], + targetSize: Long, + minNumPartitions: Int = 1): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { + case (bytesByPartitionId, index) => + new MapOutputStatistics(index, bytesByPartitionId) + } + val estimatedPartitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions( + mapOutputStatistics, + 0, + bytesByPartitionIdArray.head.length, + targetSize, + minNumPartitions) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + } + + test("1 shuffle") { + val targetSize = 100 + + { + // All bytes per partition are 0. + val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // Some bytes per partition are 0 and total size is less than the target size. + // 1 coalesced partition is expected. + val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // 2 coalesced partitions are expected. + val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // There are a few large shuffle partitions. + val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // All shuffle partitions are larger than the targeted size. + val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // The last shuffle partition is in a single coalesced partition. + val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) + val expectedPartitionStartIndices = Array[Int](0, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + } + + test("2 shuffles") { + val targetSize = 100 + + { + // If there are multiple values of the number of shuffle partitions, + // we should see an assertion error. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) + intercept[AssertionError] { + checkEstimation(Array(bytesByPartitionId1, bytesByPartitionId2), Array.empty, targetSize) + } + } + + { + // All bytes per partition are 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // Some bytes per partition are 0. + // 1 coalesced partition is expected. 
+ val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 2 coalesced partition are expected. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 4 coalesced partition are expected. + val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 2 coalesced partition are needed. + val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // There are a few large shuffle partitions. + val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // All pairs of shuffle partitions are larger than the targeted size. + val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + } + + test("enforce minimal number of coalesced partitions") { + val targetSize = 100 + val minNumPartitions = 2 + + { + // The minimal number of coalesced partitions is not enforced because + // the size of data is 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + + { + // The minimal number of coalesced partitions is enforced. + val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) + val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + + { + // The number of coalesced partitions is determined by the algorithm. 
+ val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) + val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a2071903bea7e..4edb35ea30fde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -23,7 +23,7 @@ import java.net.URI import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.internal.SQLConf @@ -594,160 +594,84 @@ class AdaptiveQueryExecSuite .range(0, 1000, 1, 10) .selectExpr("id % 1 as key2", "id as value2") .createOrReplaceTempView("skewData2") - val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( "SELECT key1 FROM skewData1 join skewData2 ON key1 = key2 group by key1") - val innerSmj = findTopLevelSortMergeJoin(innerPlan) - assert(innerSmj.size == 1) // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization - val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan) - assert(innerSmjAfter.size == 1) + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) + assert(innerSmj.size == 1 && !innerSmj.head.isSkewJoin) } } } + // TODO: we need a way to customize data distribution after shuffle, to improve test coverage + // of this case. test("SPARK-29544: adaptive skew join with different join types") { - Seq("false", "true").foreach { reducePostShufflePartitionsEnabled => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100", - SQLConf.REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED.key -> reducePostShufflePartitionsEnabled, - SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 2 as key1", "id as value1") - .createOrReplaceTempView("skewData1") - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 1 as key2", "id as value2") - .createOrReplaceTempView("skewData2") - // skewed inner join optimization - val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val innerSmj = findTopLevelSortMergeJoin(innerPlan) - assert(innerSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // the partition 0 in both left and right side are all skewed. 
- // And divide into 5 splits both in left and right (the max splits number). - // So there are 5 x 5 smjs for partition 0. - // Partition 4 in left side is skewed and is divided into 5 splits. - // The right side of partition 4 is not skewed. - // So there are 5 smjs for partition 4. - // So total (25 + 5 + 1) smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan) - assert(innerSmjAfter.size == 31) - - // skewed left outer join optimization - val (leftPlan, leftAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") - val leftSmj = findTopLevelSortMergeJoin(leftPlan) - assert(leftSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // The partition 0 in both left and right are all skewed. - // The partition 4 in left side is skewed. - // But for left outer join, we don't split the right partition even skewed. - // So the partition 0 in left side is divided into 5 splits(the max split number). - // the partition 4 in left side is divided into 5 splits(the max split number). - // So total (5 + 5 + 1) smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val leftSmjAfter = findTopLevelSortMergeJoin(leftAdaptivePlan) - assert(leftSmjAfter.size == 11) - - // skewed right outer join optimization - val (rightPlan, rightAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") - val rightSmj = findTopLevelSortMergeJoin(rightPlan) - assert(rightSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // The partition 0 in both left and right side are all skewed. - // And the partition 4 in left side is skewed. - // But for right outer join, we don't split the left partition even skewed. - // And divide right side into 5 splits(the max split number) - // So total 6 smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . 
- // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val rightSmjAfter = findTopLevelSortMergeJoin(rightAdaptivePlan) - assert(rightSmjAfter.size == 6) + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100", + SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 2 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = { + assert(joins.size == 1 && joins.head.isSkewJoin) + assert(joins.head.left.collect { + case r: SkewJoinShuffleReaderExec => r + }.head.partitionSpecs.length == expectedNumPartitions) + assert(joins.head.right.collect { + case r: SkewJoinShuffleReaderExec => r + }.head.partitionSpecs.length == expectedNumPartitions) } + + // skewed inner join optimization + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, and divide into 5 splits, so + // 5 x 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, and divide into 5 splits, so + // 5 sub-partitions. + // So total (25 + 1 + 5) partitions. + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) + checkSkewJoin(innerSmj, 25 + 1 + 5) + + // skewed left outer join optimization + val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, but left join can't split right side, + // so only left side is divided into 5 splits, and thus 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, and divide into 5 splits, so + // 5 sub-partitions. + // So total (5 + 1 + 5) partitions. + val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan) + checkSkewJoin(leftSmj, 5 + 1 + 5) + + // skewed right outer join optimization + val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, but right join can't split left side, + // so only right side is divided into 5 splits, and thus 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, but right join can't split left side, so just + // 1 partition. + // So total (5 + 1 + 1) partitions. 
+ val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan) + checkSkewJoin(rightSmj, 5 + 1 + 1) } } } @@ -805,3 +729,4 @@ class AdaptiveQueryExecSuite s" enabled but is not supported for"))) } } + From 5063cd93c8d76a6d9650d1243bf5e1cea8da1d94 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Feb 2020 11:17:27 -0800 Subject: [PATCH 0070/1280] [SPARK-30807][K8S][TESTS] Support Java 11 in K8S integration tests ### What changes were proposed in this pull request? This PR aims to support JDK11 test in K8S integration tests. - This is an update in testing framework instead of individual tests. - This will enable JDK11 runtime test when you didn't installed JDK11 on your local system. ### Why are the changes needed? Apache Spark 3.0.0 adds JDK11 support, but K8s integration tests use JDK8 until now. ### Does this PR introduce any user-facing change? No. This is a dev-only test-related PR. ### How was this patch tested? This is irrelevant to Jenkins UT, but Jenkins K8S IT (JDK8) should pass. - https://github.com/apache/spark/pull/27559#issuecomment-585903489 (JDK8 Passed) And, manually do the following for JDK11 test. ``` $ NO_MANUAL=1 ./dev/make-distribution.sh --r --pip --tgz -Phadoop-3.2 -Pkubernetes $ resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh --java-image-tag 11-jre-slim --spark-tgz $PWD/spark-*.tgz ``` ``` $ docker run -it --rm kubespark/spark:1318DD8A-2B15-4A00-BC69-D0E90CED235B /usr/local/openjdk-11/bin/java --version | tail -n1 OpenJDK 64-Bit Server VM 18.9 (build 11.0.6+10, mixed mode) ``` Closes #27559 from dongjoon-hyun/SPARK-30807. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 859699135cb63b57f5d844e762070065cedb4408) Signed-off-by: Dongjoon Hyun --- .../docker/src/main/dockerfiles/spark/Dockerfile | 3 ++- .../kubernetes/integration-tests/README.md | 15 +++++++++++++-- .../dev/dev-run-integration-tests.sh | 10 ++++++++++ .../kubernetes/integration-tests/pom.xml | 4 ++++ .../scripts/setup-integration-test-env.sh | 14 +++++++++++--- 5 files changed, 40 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index a1fc63789bc61..6ed37fc637b31 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +ARG java_image_tag=8-jre-slim -FROM openjdk:8-jre-slim +FROM openjdk:${java_image_tag} ARG spark_uid=185 diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index d7ad35a175a61..18b91916208d6 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -6,13 +6,17 @@ title: Spark on Kubernetes Integration Tests # Running the Kubernetes Integration Tests Note that the integration test framework is currently being heavily revised and -is subject to change. Note that currently the integration tests only run with Java 8. +is subject to change. The simplest way to run the integration tests is to install and run Minikube, then run the following from this directory: ./dev/dev-run-integration-tests.sh +To run tests with Java 11 instead of Java 8, use `--java-image-tag` to specify the base image. 
+ + ./dev/dev-run-integration-tests.sh --java-image-tag 11-jre-slim + The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: @@ -183,7 +187,14 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro A specific image tag to use, when set assumes images with those tags are already built and available in the specified image repository. When set to N/A (the default) fresh images will be built. - N/A + N/A + + + spark.kubernetes.test.javaImageTag + + A specific OpenJDK base image tag to use, when set uses it instead of 8-jre-slim. + + 8-jre-slim spark.kubernetes.test.imageTagFile diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 1f0a8035cea7b..76d6e1c1e8499 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -23,6 +23,7 @@ DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" SPARK_TGZ="N/A" IMAGE_TAG="N/A" +JAVA_IMAGE_TAG= BASE_IMAGE_NAME= JVM_IMAGE_NAME= PYTHON_IMAGE_NAME= @@ -52,6 +53,10 @@ while (( "$#" )); do IMAGE_TAG="$2" shift ;; + --java-image-tag) + JAVA_IMAGE_TAG="$2" + shift + ;; --deploy-mode) DEPLOY_MODE="$2" shift @@ -120,6 +125,11 @@ properties=( -Dtest.include.tags=$INCLUDE_TAGS ) +if [ -n "$JAVA_IMAGE_TAG" ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.javaImageTag=$JAVA_IMAGE_TAG ) +fi + if [ -n $NAMESPACE ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 8e1043f77db6d..369dfd491826c 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -39,6 +39,7 @@ ${project.build.directory}/spark-dist-unpacked N/A + 8-jre-slim ${project.build.directory}/imageTag.txt minikube docker.io/kubespark @@ -109,6 +110,9 @@ --image-tag ${spark.kubernetes.test.imageTag} + --java-image-tag + ${spark.kubernetes.test.javaImageTag} + --image-tag-output-file ${spark.kubernetes.test.imageTagFile} diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index 9e04b963fc40e..ab906604fce06 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -23,6 +23,7 @@ IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" IMAGE_TAG="N/A" +JAVA_IMAGE_TAG="8-jre-slim" SPARK_TGZ="N/A" # Parse arguments @@ -40,6 +41,10 @@ while (( "$#" )); do IMAGE_TAG="$2" shift ;; + --java-image-tag) + JAVA_IMAGE_TAG="$2" + shift + ;; --image-tag-output-file) IMAGE_TAG_OUTPUT_FILE="$2" shift @@ -82,6 +87,9 @@ then IMAGE_TAG=$(uuidgen); cd $SPARK_INPUT_DIR + # OpenJDK base-image tag (e.g. 
8-jre-slim, 11-jre-slim) + JAVA_IMAGE_TAG_BUILD_ARG="-b java_image_tag=$JAVA_IMAGE_TAG" + # Build PySpark image LANGUAGE_BINDING_BUILD_ARGS="-p $DOCKER_FILE_BASE_PATH/bindings/python/Dockerfile" @@ -95,7 +103,7 @@ then case $DEPLOY_MODE in cloud) # Build images - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; @@ -109,13 +117,13 @@ then docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 From 72720ae144921efc405ef29c52f467c92ad0a3f3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Feb 2020 11:42:00 -0800 Subject: [PATCH 0071/1280] [SPARK-30816][K8S][TESTS] Fix dev-run-integration-tests.sh to ignore empty params ### What changes were proposed in this pull request? This PR aims to fix `dev-run-integration-tests.sh` to ignore empty params correctly. ### Why are the changes needed? The following script runs `mvn` integration test like the following. ``` $ resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh ... build/mvn integration-test -f /Users/dongjoon/APACHE/spark/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-2.12 -Pkubernetes -Pkubernetes-integration-tests -Djava.version=8 -Dspark.kubernetes.test.sparkTgz=N/A -Dspark.kubernetes.test.imageTag=N/A -Dspark.kubernetes.test.imageRepo=docker.io/kubespark -Dspark.kubernetes.test.deployMode=minikube -Dtest.include.tags=k8s -Dspark.kubernetes.test.namespace= -Dspark.kubernetes.test.serviceAccountName= -Dspark.kubernetes.test.kubeConfigContext= -Dspark.kubernetes.test.master= -Dtest.exclude.tags= -Dspark.kubernetes.test.jvmImage=spark -Dspark.kubernetes.test.pythonImage=spark-py -Dspark.kubernetes.test.rImage=spark-r ``` After this PR, the empty parameters like the followings will be skipped like the original design. ``` -Dspark.kubernetes.test.namespace= -Dspark.kubernetes.test.serviceAccountName= -Dspark.kubernetes.test.kubeConfigContext= -Dspark.kubernetes.test.master= -Dtest.exclude.tags= ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass the Jenkins K8S integration test. Closes #27566 from dongjoon-hyun/SPARK-30816. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit 74cd46eb691be5dc1cb1c496eeeaa2614945bd98) Signed-off-by: Dongjoon Hyun --- .../integration-tests/dev/dev-run-integration-tests.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 76d6e1c1e8499..607bb243458a6 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -130,27 +130,27 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.javaImageTag=$JAVA_IMAGE_TAG ) fi -if [ -n $NAMESPACE ]; +if [ -n "$NAMESPACE" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) fi -if [ -n $SERVICE_ACCOUNT ]; +if [ -n "$SERVICE_ACCOUNT" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) fi -if [ -n $CONTEXT ]; +if [ -n "$CONTEXT" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.kubeConfigContext=$CONTEXT ) fi -if [ -n $SPARK_MASTER ]; +if [ -n "$SPARK_MASTER" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) fi -if [ -n $EXCLUDE_TAGS ]; +if [ -n "$EXCLUDE_TAGS" ]; then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi From 4db64ed37c601eb62aa3939d13f4f0e15bc1e4a9 Mon Sep 17 00:00:00 2001 From: Ali Afroozeh Date: Thu, 13 Feb 2020 23:58:55 +0100 Subject: [PATCH 0072/1280] [SPARK-30798][SQL] Scope Session.active in QueryExecution ### What changes were proposed in this pull request? This PR scopes `SparkSession.active` to prevent problems with processing queries with possibly different spark sessions (and different configs). A new method, `withActive` is introduced on `SparkSession` that restores the previous spark session after the block of code is executed. ### Why are the changes needed? `SparkSession.active` is a thread local variable that points to the current thread's spark session. It is important to note that the `SQLConf.get` method depends on `SparkSession.active`. In the current implementation it is possible that `SparkSession.active` points to a different session which causes various problems. Most of these problems arise because part of the query processing is done using the configurations of a different session. For example, when creating a data frame using a new session, i.e., `session.sql("...")`, part of the data frame is constructed using the currently active spark session, which can be a different session from the one used later for processing the query. ### Does this PR introduce any user-facing change? The `withActive` method is introduced on `SparkSession`. ### How was this patch tested? Unit tests (to be added) Closes #27387 from dbaliafroozeh/UseWithActiveSessionInQueryExecution. 
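As a hedged illustration of the `withActive` scoping described above (this sketch is not part of the patch; it assumes code compiled inside the `org.apache.spark.sql` package, since `withActive` and `cloneSession` are `private[sql]`, and the session names are placeholders):

```
import org.apache.spark.sql.SparkSession

// Minimal sketch of the scoping behavior: `spark` is made the active session,
// and `cloned` is a second session with its own copy of the configuration.
val spark = SparkSession.builder().master("local[1]").getOrCreate()
SparkSession.setActiveSession(spark)
val cloned = spark.cloneSession()

cloned.withActive {
  // Inside the block, SparkSession.active (and therefore SQLConf.get)
  // resolves to `cloned`, so query processing picks up its configuration.
  assert(SparkSession.getActiveSession.contains(cloned))
}

// Once the block completes, the previously active session is restored.
assert(SparkSession.getActiveSession.contains(spark))
```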
Authored-by: Ali Afroozeh Signed-off-by: herman (cherry picked from commit e2d3983de78f5c80fac066b7ee8bedd0987110dd) Signed-off-by: herman --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/DataFrameWriterV2.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 36 ++++++++++--------- .../spark/sql/KeyValueGroupedDataset.scala | 5 +-- .../org/apache/spark/sql/SparkSession.scala | 30 +++++++++++----- .../spark/sql/execution/QueryExecution.scala | 16 +++++---- .../spark/sql/execution/SQLExecution.scala | 4 +-- .../streaming/MicroBatchExecution.scala | 4 +-- .../continuous/ContinuousExecution.scala | 2 +- .../spark/sql/internal/CatalogImpl.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 10 ++++++ .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 2 +- .../SparkExecuteStatementOperation.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 2 +- .../hive/execution/HiveComparisonTest.scala | 3 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 17 files changed, 74 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4557219abeb15..fff1f4b636dea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -896,7 +896,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = { val qe = session.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd) + SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) } private def lookupV2Provider(): Option[TableProvider] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index f5dd7613d4103..cf6bde5a2bcb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -226,7 +226,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) private def runCommand(name: String)(command: LogicalPlan): Unit = { val qe = sparkSession.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. 
- SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd) + SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) } private def internalReplace(orCreate: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a1c33f92d17b4..42f35354e864f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -82,18 +82,19 @@ private[sql] object Dataset { dataset } - def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { - val qe = sparkSession.sessionState.executePlan(logicalPlan) - qe.assertAnalyzed() - new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = + sparkSession.withActive { + val qe = sparkSession.sessionState.executePlan(logicalPlan) + qe.assertAnalyzed() + new Dataset[Row](qe, RowEncoder(qe.analyzed.schema)) } /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker) - : DataFrame = { + : DataFrame = sparkSession.withActive { val qe = new QueryExecution(sparkSession, logicalPlan, tracker) qe.assertAnalyzed() - new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + new Dataset[Row](qe, RowEncoder(qe.analyzed.schema)) } } @@ -185,13 +186,12 @@ private[sql] object Dataset { */ @Stable class Dataset[T] private[sql]( - @transient private val _sparkSession: SparkSession, @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends Serializable { @transient lazy val sparkSession: SparkSession = { - if (_sparkSession == null) { + if (queryExecution == null || queryExecution.sparkSession == null) { throw new SparkException( "Dataset transformations and actions can only be invoked by the driver, not inside of" + " other Dataset transformations; for example, dataset1.map(x => dataset2.values.count()" + @@ -199,7 +199,7 @@ class Dataset[T] private[sql]( "performed inside of the dataset1.map transformation. For more information," + " see SPARK-28702.") } - _sparkSession + queryExecution.sparkSession } // A globally unique id of this Dataset. @@ -211,7 +211,7 @@ class Dataset[T] private[sql]( // you wrap it with `withNewExecutionId` if this actions doesn't call other action. def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { - this(sparkSession, sparkSession.sessionState.executePlan(logicalPlan), encoder) + this(sparkSession.sessionState.executePlan(logicalPlan), encoder) } def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { @@ -445,7 +445,7 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) + def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema)) /** * Returns a new Dataset where each record has been mapped on to the specified type. 
The @@ -503,7 +503,9 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def schema: StructType = queryExecution.analyzed.schema + def schema: StructType = sparkSession.withActive { + queryExecution.analyzed.schema + } /** * Prints the schema to the console in a nice tree format. @@ -539,7 +541,7 @@ class Dataset[T] private[sql]( * @group basic * @since 3.0.0 */ - def explain(mode: String): Unit = { + def explain(mode: String): Unit = sparkSession.withActive { // Because temporary views are resolved during analysis when we create a Dataset, and // `ExplainCommand` analyzes input query plan and resolves temporary views again. Using // `ExplainCommand` here will probably output different query plans, compared to the results @@ -1502,7 +1504,7 @@ class Dataset[T] private[sql]( val namedColumns = columns.map(_.withInputType(exprEnc, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) - new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) + new Dataset(execution, ExpressionEncoder.tuple(encoders)) } /** @@ -3472,7 +3474,7 @@ class Dataset[T] private[sql]( * an execution. */ private def withNewExecutionId[U](body: => U): U = { - SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) + SQLExecution.withNewExecutionId(queryExecution)(body) } /** @@ -3481,7 +3483,7 @@ class Dataset[T] private[sql]( * reset. */ private def withNewRDDExecutionId[U](body: => U): U = { - SQLExecution.withNewExecutionId(sparkSession, rddQueryExecution) { + SQLExecution.withNewExecutionId(rddQueryExecution) { rddQueryExecution.executedPlan.resetMetrics() body } @@ -3492,7 +3494,7 @@ class Dataset[T] private[sql]( * user-registered callback functions. */ private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = { - SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) { + SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.resetMetrics() action(qe.executedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 89cc9735e4f6a..76ee297dfca79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -449,10 +449,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) - new Dataset( - sparkSession, - execution, - ExpressionEncoder.tuple(kExprEnc +: encoders)) + new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index abefb348cafc7..1fb97fb4b4cf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -293,8 +293,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkSession.setActiveSession(this) + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive { val encoder = Encoders.product[A] Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder)) } @@ -304,8 +303,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame[A <: 
Product : TypeTag](data: Seq[A]): DataFrame = { - SparkSession.setActiveSession(this) + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = withActive { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -343,7 +341,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val encoder = RowEncoder(schema) @@ -373,7 +371,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -385,7 +383,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive { val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -414,7 +412,7 @@ class SparkSession private( * SELECT * queries will return the columns in an undefined order. * @since 1.6.0 */ - def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = withActive { val attrSeq = getSchema(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq) Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) @@ -599,7 +597,7 @@ class SparkSession private( * * @since 2.0.0 */ - def sql(sqlText: String): DataFrame = { + def sql(sqlText: String): DataFrame = withActive { val tracker = new QueryPlanningTracker val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { sessionState.sqlParser.parsePlan(sqlText) @@ -751,6 +749,20 @@ class SparkSession private( } } + /** + * Execute a block of code with the this session set as the active session, and restore the + * previous session on completion. + */ + private[sql] def withActive[T](block: => T): T = { + // Use the active session thread local directly to make sure we get the session that is actually + // set and not the default session. This to prevent that we promote the default session to the + // active session once we are done. 
+ val old = SparkSession.activeThreadSession.get() + SparkSession.setActiveSession(this) + try block finally { + SparkSession.setActiveSession(old) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 38ef66682c413..53b6b5d82c021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -63,13 +63,12 @@ class QueryExecution( } } - lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) { - SparkSession.setActiveSession(sparkSession) + lazy val analyzed: LogicalPlan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } - lazy val withCachedData: LogicalPlan = { + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, @@ -77,20 +76,20 @@ class QueryExecution( sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone()) } - lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { + lazy val optimizedPlan: LogicalPlan = executePhase(QueryPlanningTracker.OPTIMIZATION) { // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) } - lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { + lazy val sparkPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical plan. QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone()) } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { + lazy val executedPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. QueryExecution.prepareForExecution(preparations, sparkPlan.clone()) @@ -116,6 +115,10 @@ class QueryExecution( QueryExecution.preparations(sparkSession) } + private def executePhase[T](phase: String)(block: => T): T = sparkSession.withActive { + tracker.measurePhase(phase)(block) + } + def simpleString: String = simpleString(false) def simpleString(formatted: Boolean): String = withRedaction { @@ -305,7 +308,6 @@ object QueryExecution { sparkSession: SparkSession, planner: SparkPlanner, plan: LogicalPlan): SparkPlan = { - SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. 
planner.plan(ReturnAnswer(plan)).next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 9f177819f6ea7..59c503e372535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -62,9 +62,9 @@ object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sparkSession: SparkSession, queryExecution: QueryExecution, - name: Option[String] = None)(body: => T): T = { + name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive { + val sparkSession = queryExecution.sparkSession val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) val executionId = SQLExecution.nextExecutionId diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 83bc347e23ed4..45a2ce16183a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -563,11 +563,11 @@ class MicroBatchExecution( } val nextBatch = - new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(lastExecution, RowEncoder(lastExecution.analyzed.schema)) val batchSinkProgress: Option[StreamWriterCommitProgress] = reportTimeTaken("addBatch") { - SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { + SQLExecution.withNewExecutionId(lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) case _: SupportsWrite => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a9b724a73a18e..a109c2171f3d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -252,7 +252,7 @@ class ContinuousExecution( updateStatusMessage("Running") reportTimeTaken("runContinuous") { - SQLExecution.withNewExecutionId(sparkSessionForQuery, lastExecution) { + SQLExecution.withNewExecutionId(lastExecution) { lastExecution.executedPlan.execute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3740b56cb9cbb..d3ef03e9b3b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -520,7 +520,7 @@ private[sql] object CatalogImpl { val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(enc.schema.toAttributes, encoded) val queryExecution = sparkSession.sessionState.executePlan(plan) - new Dataset[T](sparkSession, queryExecution, enc) + new Dataset[T](queryExecution, enc) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 233d67898f909..b0bd612e88d98 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1899,6 +1899,16 @@ class DatasetSuite extends QueryTest val e = intercept[AnalysisException](spark.range(1).tail(-1)) e.getMessage.contains("tail expression must be equal to or greater than 0") } + + test("SparkSession.active should be the same instance after dataset operations") { + val active = SparkSession.getActiveSession.get + val clone = active.cloneSession() + val ds = new Dataset(clone, spark.range(10).queryExecution.logical, Encoders.INT) + + ds.queryExecution.analyzed + + assert(active eq SparkSession.getActiveSession.get) + } } object AssertExecutionId { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index da4727f6a98cb..83285911b3948 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -511,7 +511,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { val df = session.sql(sql) val schema = df.schema.catalogString // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) { + val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) { hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 55b551d0af078..9f4a335309b63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -506,7 +506,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils override lazy val executedPlan = physicalPlan } - SQLExecution.withNewExecutionId(spark, dummyQueryExecution) { + SQLExecution.withNewExecutionId(dummyQueryExecution) { physicalPlan.execute().collect() } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 76d07848f79a9..cf0e5ebf3a2b1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -295,7 +295,7 @@ private[hive] class SparkExecuteStatementOperation( resultList.get.iterator } } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + dataTypes = result.schema.fields.map(_.dataType) } catch { // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 362ac362e9718..12fba0eae6dce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont try { context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) - hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { + hiveResponse = SQLExecution.withNewExecutionId(execution) { hiveResultString(execution.executedPlan) } tableSchema = getResultSetSchema(execution) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 28e1db961f611..8b1f4c92755b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -346,8 +346,7 @@ abstract class HiveComparisonTest val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) def getResult(): Seq[String] = { - SQLExecution.withNewExecutionId( - query.sparkSession, query)(hiveResultString(query.executedPlan)) + SQLExecution.withNewExecutionId(query)(hiveResultString(query.executedPlan)) } try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index cc4592a5caf68..222244a04f5f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -501,7 +501,7 @@ private[hive] class TestHiveSparkSession( // has already set the execution id. if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) { // We don't actually have a `QueryExecution` here, use a fake one instead. - SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation())) { + SQLExecution.withNewExecutionId(new QueryExecution(this, OneRowRelation())) { createCmds.foreach(_()) } } else { From 6001866cea1216da421c5acd71d6fc74228222ac Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Thu, 13 Feb 2020 16:15:00 -0800 Subject: [PATCH 0073/1280] [SPARK-30667][CORE] Add allGather method to BarrierTaskContext ### What changes were proposed in this pull request? The `allGather` method is added to the `BarrierTaskContext`. This method contains the same functionality as the `BarrierTaskContext.barrier` method; it blocks the task until all tasks make the call, at which time they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from the `allGather` the task receives a list of all the messages sent by all the tasks that made the `allGather` call. ### Why are the changes needed? There are many situations where having the tasks communicate in a synchronized way is useful. 
One simple example is if each task needs to start a server to serve requests from one another; first the tasks must find a free port (which cannot be determined beforehand) and then start making requests, but to do so they each must know the ports chosen by the other tasks. An `allGather` method would allow them to inform each other of the port they will run on. ### Does this PR introduce any user-facing change? Yes, a `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs. ### How was this patch tested? Most of the code path is already covered by tests of the `barrier` method, since this PR includes a refactor so that much code is shared by the `barrier` and `allGather` methods. However, a test is added to assert that an all gather on each task's partition ID will return a list of every partition ID. An example through the Python API: ```python >>> from pyspark import BarrierTaskContext >>> >>> def f(iterator): ... context = BarrierTaskContext.get() ... return [context.allGather('{}'.format(context.partitionId()))] ... >>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0] [u'3', u'1', u'0', u'2'] ``` Closes #27395 from sarthfrey/master. Lead-authored-by: sarthfrey-db Co-authored-by: sarthfrey Signed-off-by: Xiangrui Meng (cherry picked from commit 57254c9719f9af9ad985596ed7fbbaafa4052002) Signed-off-by: Xiangrui Meng --- .../org/apache/spark/BarrierCoordinator.scala | 113 +++++++++++-- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++++++++------ .../spark/api/python/PythonRunner.scala | 51 ++++-- .../scheduler/BarrierTaskContextSuite.scala | 74 +++++++++ python/pyspark/taskcontext.py | 49 +++++- python/pyspark/tests/test_taskcontext.py | 20 +++ 6 files changed, 381 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 4e417679ca663..042a2664a0e27 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,12 +17,17 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 - // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() - // call.
+ // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call + private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] + + // The blocking requestMethod called by tasks to sync up for this stage attempt + private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER + // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + def handleRequest( + requester: RpcCallContext, + request: RequestToSync + ): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch + val requestMethod = request.requestMethod + val partitionId = request.partitionId + val allGatherMessage = request match { + case ag: AllGatherRequestToSync => ag.allGatherMessage + case _ => "" + } + + if (requesters.size == 0) { + requestMethodToSync = requestMethod + } + + if (requestMethodToSync != requestMethod) { + requesters.foreach( + _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + + s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + + s"the current synchronized requestMethod `$requestMethodToSync`" + )) + ) + cleanupBarrierStage(barrierId) + } // Require the number of tasks is correctly set from the BarrierTaskContext. require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester + allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -162,6 +196,7 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -173,7 +208,13 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requesters.foreach(_.reply(())) + requestMethodToSync match { + case RequestMethod.BARRIER => + requesters.foreach(_.reply("")) + case RequestMethod.ALL_GATHER => + val json: String = compact(render(allGatherMessages)) + requesters.foreach(_.reply(json)) + } true } else { false @@ -186,6 +227,7 @@ private[spark] class BarrierCoordinator( // messages come from current stage attempt shall fail. barrierEpoch = -1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -199,11 +241,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + case request: RequestToSync => // Get or init the ContextBarrierState correspond to the stage attempt. 
- val barrierId = ContextBarrierId(stageId, stageAttemptId) + val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -216,6 +258,16 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { + def numTasks: Int + def stageId: Int + def stageAttemptId: Int + def taskAttemptId: Long + def barrierEpoch: Int + def partitionId: Int + def requestMethod: RequestMethod.Value +} + /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator */ -private[spark] case class RequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int) extends BarrierCoordinatorMessage +private[spark] case class BarrierRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value +) extends RequestToSync + +/** + * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. 
+ * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER + */ +private[spark] case class AllGatherRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String +) extends RequestToSync + +private[spark] object RequestMethod extends Enumeration { + val BARRIER, ALL_GATHER = Value +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 3d369802f3023..2263538a11676 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,11 +17,19 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.json4s.DefaultFormats +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. 
- * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { + private def getRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptNumber: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String + ): RequestToSync = { + requestMethod match { + case RequestMethod.BARRIER => + BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod) + case RequestMethod.ALL_GATHER => + AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod, allGatherMessage) + } + } + + private def runBarrier( + requestMethod: RequestMethod.Value, + allGatherMessage: String = "" + ): String = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) + var json: String = "" + try { - val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch), + val abortableRpcFuture = barrierCoordinator.askAbortable[String]( + message = getRequestToSync(numTasks, stageId, stageAttemptNumber, + taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } + json + } + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. 
+ * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + runBarrier(RequestMethod.BARRIER) + () + } + + /** + * :: Experimental :: + * Blocks until all tasks in the same stage have reached this routine. Each task passes in + * a message and returns with a list of all the messages passed in by each of those tasks. + * + * CAUTION! The allGather method requires the same precautions as the barrier method + * + * The message is type String rather than Array[Byte] because it is more convenient for + * the user at the cost of worse performance. + */ + @Experimental + @Since("3.0.0") + def allGather(message: String): ArrayBuffer[String] = { + val json = runBarrier(RequestMethod.ALL_GATHER, message) + val jsonArray = parse(json) + implicit val formats = DefaultFormats + ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 658e0d593a167..fa8bf0fc06358 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - input.readInt() match { + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + requestMethod match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - // The barrier() function may wait infinitely, socket shall not timeout - // before the function finishes. - sock.setSoTimeout(0) - barrierAndServe(sock) - + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext.barrier(). + * Gateway to call BarrierTaskContext methods. */ - def barrierAndServe(sock: Socket): Unit = { - require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") - + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." 
+ ) val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - context.asInstanceOf[BarrierTaskContext].barrier() - writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + var result: String = "" + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( + message + ) + result = compact(render(JArray( + messages.map( + (message) => JString(message) + ).toList + ))) + } + writeUTF(result, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -638,6 +664,7 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 + val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index fc8ac38479932..ed38b7f0ecac1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } + test("share messages with allGather() call") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. 
+ Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + messages.toList.iterator + } + // Take a sorted list of all the partitionId messages + val messages = rdd2.collect().head + // All the task partitionIds are shared + for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString) + } + + test("throw exception if we attempt to synchronize with different blocking calls") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + val partitionId = context.partitionId + if (partitionId == 0) { + context.barrier() + } else { + context.allGather(partitionId.toString) + } + Seq(null).iterator + } + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("does not match the current synchronized requestMethod")) + } + + test("successively sync with allGather and barrier") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index d648f63338514..90bd2345ac525 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,9 +16,10 @@ # from __future__ import print_function +import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, UTF8Deserializer +from pyspark.serializers import write_int, write_with_length, UTF8Deserializer class TaskContext(object): @@ -107,18 +108,28 @@ def resources(self): BARRIER_FUNCTION = 1 +ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret): +def _load_from_socket(port, auth_secret, function, all_gather_message=None): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - # The barrier() call may block forever, so no timeout + + # The call may block forever, so no timeout sock.settimeout(None) - # Make a barrier() function call. - write_int(BARRIER_FUNCTION, sockfile) + + if function == BARRIER_FUNCTION: + # Make a barrier() function call. 
+ write_int(function, sockfile) + elif function == ALL_GATHER_FUNCTION: + # Make a all_gather() function call. + write_int(function, sockfile) + write_with_length(all_gather_message.encode("utf-8"), sockfile) + else: + raise ValueError("Unrecognized function type") sockfile.flush() # Collect result. @@ -199,7 +210,33 @@ def barrier(self): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret) + _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) + + def allGather(self, message=""): + """ + .. note:: Experimental + + This function blocks until all tasks in the same stage have reached this routine. + Each task passes in a message and returns with a list of all the messages passed in + by each of those tasks. + + .. warning:: In a barrier stage, each task much have the same number of `allGather()` + calls, in all possible code branches. + Otherwise, you may get the job hanging or a SparkException after timeout. + """ + if not isinstance(message, str): + raise ValueError("Argument `message` must be of type `str`") + elif self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + gathered_items = _load_from_socket( + self._port, + self._secret, + ALL_GATHER_FUNCTION, + message, + ) + return [e for e in json.loads(gathered_items)] def getTaskInfos(self): """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 6095a384679af..f5dbd068387c2 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -134,6 +134,26 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) + def test_all_gather(self): + """ + Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks + within a stage and passes messages properly. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + out = tc.allGather(str(context.partitionId())) + pids = [int(e) for e in out] + return [pids] + + pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] + self.assertTrue(pids == [0, 1, 2, 3]) + def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From eb37aa5595badd79becf4d3d332404cbcdb1b12d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 13 Feb 2020 17:48:19 -0800 Subject: [PATCH 0074/1280] Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext" This reverts commit 6001866cea1216da421c5acd71d6fc74228222ac. 
--- .../org/apache/spark/BarrierCoordinator.scala | 113 ++----------- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++------------ .../spark/api/python/PythonRunner.scala | 51 ++---- .../scheduler/BarrierTaskContextSuite.scala | 74 --------- python/pyspark/taskcontext.py | 49 +----- python/pyspark/tests/test_taskcontext.py | 20 --- 6 files changed, 79 insertions(+), 381 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 042a2664a0e27..4e417679ca663 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,17 +17,12 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 - // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) - // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call - private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] - - // The blocking requestMethod called by tasks to sync up for this stage attempt - private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER - // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest( - requester: RpcCallContext, - request: RequestToSync - ): Unit = synchronized { + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch - val requestMethod = request.requestMethod - val partitionId = request.partitionId - val allGatherMessage = request match { - case ag: AllGatherRequestToSync => ag.allGatherMessage - case _ => "" - } - - if (requesters.size == 0) { - requestMethodToSync = requestMethod - } - - if (requestMethodToSync != requestMethod) { - requesters.foreach( - _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + - s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + - s"the current synchronized requestMethod `$requestMethodToSync`" - )) - ) - cleanupBarrierStage(barrierId) - } // Require the number of tasks is correctly set from the BarrierTaskContext. 
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester - allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -196,7 +162,6 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() - allGatherMessages.clear() cancelTimerTask() } } @@ -208,13 +173,7 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requestMethodToSync match { - case RequestMethod.BARRIER => - requesters.foreach(_.reply("")) - case RequestMethod.ALL_GATHER => - val json: String = compact(render(allGatherMessages)) - requesters.foreach(_.reply(json)) - } + requesters.foreach(_.reply(())) true } else { false @@ -227,7 +186,6 @@ private[spark] class BarrierCoordinator( // messages come from current stage attempt shall fail. barrierEpoch = -1 requesters.clear() - allGatherMessages.clear() cancelTimerTask() } } @@ -241,11 +199,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request: RequestToSync => + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => // Get or init the ContextBarrierState correspond to the stage attempt. - val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) + val barrierId = ContextBarrierId(stageId, stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -258,16 +216,6 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable -private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { - def numTasks: Int - def stageId: Int - def stageAttemptId: Int - def taskAttemptId: Long - def barrierEpoch: Int - def partitionId: Int - def requestMethod: RequestMethod.Value -} - /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls - * @param partitionId ID of the current partition the task is assigned to - * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. */ -private[spark] case class BarrierRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value -) extends RequestToSync - -/** - * A global sync request message from BarrierTaskContext, by `allGather()` call. 
Each request is - * identified by stageId + stageAttemptId + barrierEpoch. - * - * @param numTasks The number of global sync requests the BarrierCoordinator shall receive - * @param stageId ID of current stage - * @param stageAttemptId ID of current stage attempt - * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls - * @param partitionId ID of the current partition the task is assigned to - * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator - * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER - */ -private[spark] case class AllGatherRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String -) extends RequestToSync - -private[spark] object RequestMethod extends Enumeration { - val BARRIER, ALL_GATHER = Value -} +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 2263538a11676..3d369802f3023 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,19 +17,11 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.json4s.DefaultFormats -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -67,31 +59,49 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - private def getRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptNumber: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String - ): RequestToSync = { - requestMethod match { - case RequestMethod.BARRIER => - BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod) - case RequestMethod.ALL_GATHER => - AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod, allGatherMessage) - } - } - - private def runBarrier( - requestMethod: RequestMethod.Value, - allGatherMessage: String = "" - ): String = { - + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. 
Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -108,12 +118,10 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) - var json: String = "" - try { - val abortableRpcFuture = barrierCoordinator.askAbortable[String]( - message = getRequestToSync(numTasks, stageId, stageAttemptNumber, - taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), + val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -125,7 +133,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -155,73 +163,6 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } - json - } - - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. 
- * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { - runBarrier(RequestMethod.BARRIER) - () - } - - /** - * :: Experimental :: - * Blocks until all tasks in the same stage have reached this routine. Each task passes in - * a message and returns with a list of all the messages passed in by each of those tasks. - * - * CAUTION! The allGather method requires the same precautions as the barrier method - * - * The message is type String rather than Array[Byte] because it is more convenient for - * the user at the cost of worse performance. - */ - @Experimental - @Since("3.0.0") - def allGather(message: String): ArrayBuffer[String] = { - val json = runBarrier(RequestMethod.ALL_GATHER, message) - val jsonArray = parse(json) - implicit val formats = DefaultFormats - ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index fa8bf0fc06358..658e0d593a167 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -243,18 +238,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - val requestMethod = input.readInt() - // The BarrierTaskContext function may wait infinitely, socket shall not timeout - // before the function finishes. - sock.setSoTimeout(0) - requestMethod match { + input.readInt() match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - barrierAndServe(requestMethod, sock) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val length = input.readInt() - val message = new Array[Byte](length) - input.readFully(message) - barrierAndServe(requestMethod, sock, new String(message, UTF_8)) + // The barrier() function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + barrierAndServe(sock) + case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -405,31 +395,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext methods. + * Gateway to call BarrierTaskContext.barrier(). */ - def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { - require( - serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call." 
- ) + def barrierAndServe(sock: Socket): Unit = { + require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - var result: String = "" - requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].barrier() - result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( - message - ) - result = compact(render(JArray( - messages.map( - (message) => JString(message) - ).toList - ))) - } - writeUTF(result, out) + context.asInstanceOf[BarrierTaskContext].barrier() + writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -664,7 +638,6 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 - val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index ed38b7f0ecac1..fc8ac38479932 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.File -import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } - test("share messages with allGather() call") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - // Sleep for a random time before global sync. 
- Thread.sleep(Random.nextInt(1000)) - // Pass partitionId message in - val message = context.partitionId().toString - val messages = context.allGather(message) - messages.toList.iterator - } - // Take a sorted list of all the partitionId messages - val messages = rdd2.collect().head - // All the task partitionIds are shared - for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString) - } - - test("throw exception if we attempt to synchronize with different blocking calls") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - val partitionId = context.partitionId - if (partitionId == 0) { - context.barrier() - } else { - context.allGather(partitionId.toString) - } - Seq(null).iterator - } - val error = intercept[SparkException] { - rdd2.collect() - }.getMessage - assert(error.contains("does not match the current synchronized requestMethod")) - } - - test("successively sync with allGather and barrier") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) - context.barrier() - val time1 = System.currentTimeMillis() - // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) - // Pass partitionId message in - val message = context.partitionId().toString - val messages = context.allGather(message) - val time2 = System.currentTimeMillis() - Seq((time1, time2)).iterator - } - val times = rdd2.collect() - // All the tasks shall finish the first round of global sync within a short time slot. - val times1 = times.map(_._1) - assert(times1.max - times1.min <= 1000) - - // All the tasks shall finish the second round of global sync within a short time slot. - val times2 = times.map(_._2) - assert(times2.max - times2.min <= 1000) - } - test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 90bd2345ac525..d648f63338514 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,10 +16,9 @@ # from __future__ import print_function -import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, write_with_length, UTF8Deserializer +from pyspark.serializers import write_int, UTF8Deserializer class TaskContext(object): @@ -108,28 +107,18 @@ def resources(self): BARRIER_FUNCTION = 1 -ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret, function, all_gather_message=None): +def _load_from_socket(port, auth_secret): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - - # The call may block forever, so no timeout + # The barrier() call may block forever, so no timeout sock.settimeout(None) - - if function == BARRIER_FUNCTION: - # Make a barrier() function call. - write_int(function, sockfile) - elif function == ALL_GATHER_FUNCTION: - # Make a all_gather() function call. 
- write_int(function, sockfile) - write_with_length(all_gather_message.encode("utf-8"), sockfile) - else: - raise ValueError("Unrecognized function type") + # Make a barrier() function call. + write_int(BARRIER_FUNCTION, sockfile) sockfile.flush() # Collect result. @@ -210,33 +199,7 @@ def barrier(self): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) - - def allGather(self, message=""): - """ - .. note:: Experimental - - This function blocks until all tasks in the same stage have reached this routine. - Each task passes in a message and returns with a list of all the messages passed in - by each of those tasks. - - .. warning:: In a barrier stage, each task much have the same number of `allGather()` - calls, in all possible code branches. - Otherwise, you may get the job hanging or a SparkException after timeout. - """ - if not isinstance(message, str): - raise ValueError("Argument `message` must be of type `str`") - elif self._port is None or self._secret is None: - raise Exception("Not supported to call barrier() before initialize " + - "BarrierTaskContext.") - else: - gathered_items = _load_from_socket( - self._port, - self._secret, - ALL_GATHER_FUNCTION, - message, - ) - return [e for e in json.loads(gathered_items)] + _load_from_socket(self._port, self._secret) def getTaskInfos(self): """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index f5dbd068387c2..6095a384679af 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -134,26 +134,6 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) - def test_all_gather(self): - """ - Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks - within a stage and passes messages properly. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - out = tc.allGather(str(context.partitionId())) - pids = [int(e) for e in out] - return [pids] - - pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] - self.assertTrue(pids == [0, 1, 2, 3]) - def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From 35539cad17fd2b425ba8f7a7e298e9805541aa73 Mon Sep 17 00:00:00 2001 From: David Toneian Date: Fri, 14 Feb 2020 11:00:35 +0900 Subject: [PATCH 0075/1280] [PYSPARK][DOCS][MINOR] Changed `:func:` to `:attr:` Sphinx roles, fixed links in documentation of `Data{Frame,Stream}{Reader,Writer}` This commit is published into the public domain. ### What changes were proposed in this pull request? This PR fixes the documentation of `DataFrameReader`, `DataFrameWriter`, `DataStreamReader`, and `DataStreamWriter`, where attributes of other classes were misrepresented as functions. Additionally, creation of hyperlinks across modules was fixed in these instances. ### Why are the changes needed? The old state produced documentation that suggested invalid usage of PySpark objects (accessing attributes as though they were callable.) ### Does this PR introduce any user-facing change? No, except for improved documentation. ### How was this patch tested? No test added; documentation build runs through. 
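For illustration only (an editor's sketch, not text from the patch below): Sphinx's `:func:` role marks a callable while `:attr:` marks an attribute, and an explicit `<module.path>` target lets a link resolve across modules. A docstring written with the corrected roles might look roughly like the following; the class and link target shown are illustrative, not copied from the diff.

```python
# Editor's illustrative sketch of the corrected reST roles (not the patched file itself).
class DataFrameWriter(object):
    """
    Interface used to write a :class:`DataFrame` to external storage systems
    (e.g. file systems, key-value stores, etc).
    Use :attr:`DataFrame.write <pyspark.sql.DataFrame.write>` to access this.
    """
```

By default Sphinx renders `:func:`/`:meth:` references with trailing parentheses, which is what made these attributes look callable in the generated docs.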
Closes #27553 from DavidToneian/docfix-DataFrameReader-DataFrameWriter-DataStreamReader-DataStreamWriter. Authored-by: David Toneian Signed-off-by: HyukjinKwon (cherry picked from commit 25db8c71a2100c167b8c2d7a6c540ebc61db9b73) Signed-off-by: HyukjinKwon --- python/pyspark/sql/readwriter.py | 4 ++-- python/pyspark/sql/streaming.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3d3280dbd9943..69660395ad823 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -48,7 +48,7 @@ def _set_opts(self, schema=None, **options): class DataFrameReader(OptionUtils): """ Interface used to load a :class:`DataFrame` from external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`spark.read` + (e.g. file systems, key-value stores, etc). Use :attr:`SparkSession.read` to access this. .. versionadded:: 1.4 @@ -616,7 +616,7 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar class DataFrameWriter(OptionUtils): """ Interface used to write a :class:`DataFrame` to external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` + (e.g. file systems, key-value stores, etc). Use :attr:`DataFrame.write` to access this. .. versionadded:: 1.4 diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index f17a52f6b3dc8..5fced8aca9bdf 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -276,9 +276,9 @@ def resetTerminated(self): class DataStreamReader(OptionUtils): """ - Interface used to load a streaming :class:`DataFrame` from external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` - to access this. + Interface used to load a streaming :class:`DataFrame ` from external + storage systems (e.g. file systems, key-value stores, etc). + Use :attr:`SparkSession.readStream ` to access this. .. note:: Evolving. @@ -750,8 +750,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non class DataStreamWriter(object): """ - Interface used to write a streaming :class:`DataFrame` to external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream` + Interface used to write a streaming :class:`DataFrame ` to external + storage systems (e.g. file systems, key-value stores, etc). + Use :attr:`DataFrame.writeStream ` to access this. .. note:: Evolving. From febe28517df715c7f3c3e6efdc45dc647f811411 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 14 Feb 2020 11:20:55 +0800 Subject: [PATCH 0076/1280] [SPARK-30801][SQL] Subqueries should not be AQE-ed if main query is not ### What changes were proposed in this pull request? This PR makes sure AQE is either enabled or disabled for the entire query, including the main query and all subqueries. Currently there are unsupported queries by AQE, e.g., queries that contain DPP filters. We need to make sure that if the main query is unsupported, none of the sub-queries should apply AQE, otherwise it can lead to performance regressions due to missed opportunity of sub-query reuse. ### Why are the changes needed? To get rid of potential perf regressions when AQE is turned on. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Updated DynamicPartitionPruningSuite: 1. Removed the existing workaround `withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")` 2. 
Added `DynamicPartitionPruningSuiteAEOn` and `DynamicPartitionPruningSuiteAEOff` to enable testing this suite with AQE on and off options 3. Added a check in `checkPartitionPruningPredicate` to verify that the subqueries are always in sync with the main query in terms of whether AQE is applied. Closes #27554 from maryannxue/spark-30801. Authored-by: maryannxue Signed-off-by: Wenchen Fan (cherry picked from commit 0aed77a0155b404e39bc5dbc0579e29e4c7bf887) Signed-off-by: Wenchen Fan --- .../spark/sql/execution/QueryExecution.scala | 19 ++++++++++--- .../sql/DynamicPartitionPruningSuite.scala | 27 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 53b6b5d82c021..9109c05e75853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -274,13 +274,25 @@ object QueryExecution { * are correct, insert whole stage code gen, and try to reduce the work done by reusing exchanges * and subqueries. */ - private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] = + private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] = { + + val sparkSessionWithAdaptiveExecutionOff = + if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { + val session = sparkSession.cloneSession() + session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false) + session + } else { + sparkSession + } + Seq( // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op // as the original plan is hidden behind `AdaptiveSparkPlanExec`. InsertAdaptiveSparkPlan(AdaptiveExecutionContext(sparkSession)), - PlanDynamicPruningFilters(sparkSession), - PlanSubqueries(sparkSession), + // If the following rules apply, it means the main query is not AQE-ed, so we make sure the + // subqueries are not AQE-ed either. 
+ PlanDynamicPruningFilters(sparkSessionWithAdaptiveExecutionOff), + PlanSubqueries(sparkSessionWithAdaptiveExecutionOff), EnsureRequirements(sparkSession.sessionState.conf), ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf, sparkSession.sessionState.columnarRules), @@ -288,6 +300,7 @@ object QueryExecution { ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf) ) + } /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index f7b51d6f4c8ef..baa9f5ecafc68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.SharedSparkSession /** * Test suite for the filtering ratio policy used to trigger dynamic partition pruning (DPP). */ -class DynamicPartitionPruningSuite +abstract class DynamicPartitionPruningSuiteBase extends QueryTest with SharedSparkSession with GivenWhenThen @@ -43,9 +43,14 @@ class DynamicPartitionPruningSuite import testImplicits._ + val adaptiveExecutionOn: Boolean + override def beforeAll(): Unit = { super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, adaptiveExecutionOn) + spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY, true) + val factData = Seq[(Int, Int, Int, Int)]( (1000, 1, 1, 10), (1010, 2, 1, 10), @@ -153,6 +158,8 @@ class DynamicPartitionPruningSuite sql("DROP TABLE IF EXISTS fact_stats") sql("DROP TABLE IF EXISTS dim_stats") } finally { + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED) + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) super.afterAll() } } @@ -195,6 +202,11 @@ class DynamicPartitionPruningSuite fail(s"Invalid child node found in\n$s") } } + + val isMainQueryAdaptive = plan.isInstanceOf[AdaptiveSparkPlanExec] + subqueriesAll(plan).filterNot(subqueryBroadcast.contains).foreach { s => + assert(s.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive) + } } /** @@ -1173,8 +1185,7 @@ class DynamicPartitionPruningSuite } test("join key with multiple references on the filtering plan") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { // when enable AQE, the reusedExchange is inserted when executed. 
withTable("fact", "dim") { spark.range(100).select( @@ -1270,3 +1281,11 @@ class DynamicPartitionPruningSuite } } } + +class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase { + override val adaptiveExecutionOn: Boolean = false +} + +class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase { + override val adaptiveExecutionOn: Boolean = true +} From 1a29f9fd0c2e304f93fcd9bce3cd038ee278c937 Mon Sep 17 00:00:00 2001 From: David Toneian Date: Fri, 14 Feb 2020 13:49:11 +0900 Subject: [PATCH 0077/1280] [SPARK-30823][PYTHON][DOCS] Set `%PYTHONPATH%` when building PySpark documentation on Windows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit is published into the public domain. ### What changes were proposed in this pull request? In analogy to `python/docs/Makefile`, which has > export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.8.1-src.zip) on line 10, this PR adds > set PYTHONPATH=..;..\lib\py4j-0.10.8.1-src.zip to `make2.bat`. Since there is no `realpath` in default installations of Windows, I left the relative paths unresolved. Per the instructions on how to build docs, `make.bat` is supposed to be run from `python/docs` as the working directory, so this should probably not cause issues (`%BUILDDIR%` is a relative path as well.) ### Why are the changes needed? When building the PySpark documentation on Windows, by changing directory to `python/docs` and running `make.bat` (which runs `make2.bat`), the majority of the documentation may not be built if pyspark is not in the default `%PYTHONPATH%`. Sphinx then reports that `pyspark` (and possibly dependencies) cannot be imported. If `pyspark` is in the default `%PYTHONPATH%`, I suppose it is that version of `pyspark` – as opposed to the version found above the `python/docs` directory – that is considered when building the documentation, which may result in documentation that does not correspond to the development version one is trying to build. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manual tests on my Windows 10 machine. Additional tests with other environments very welcome! Closes #27569 from DavidToneian/SPARK-30823. Authored-by: David Toneian Signed-off-by: HyukjinKwon (cherry picked from commit b2134ee73cfad4d78aaf8f0a2898011ac0308e48) Signed-off-by: HyukjinKwon --- python/docs/make2.bat | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/docs/make2.bat b/python/docs/make2.bat index 05d22eb5cdd23..742df373166da 100644 --- a/python/docs/make2.bat +++ b/python/docs/make2.bat @@ -2,6 +2,8 @@ REM Command file for Sphinx documentation +set PYTHONPATH=..;..\lib\py4j-0.10.8.1-src.zip + if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build From 0dcc4df0ca5dba8ae09388b95969080ca28cbe16 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Fri, 14 Feb 2020 16:52:28 +0800 Subject: [PATCH 0078/1280] [SPARK-25990][SQL] ScriptTransformation should handle different data types correctly ### What changes were proposed in this pull request? We should convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. ### Why are the changes needed? 
We may hit below exception without this change: ``` [info] org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1, 192.168.1.6, executor driver): java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to org.apache.hadoop.hive.common.type.HiveDecimal [info] at org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector.getPrimitiveJavaObject(JavaHiveDecimalObjectInspector.java:55) [info] at org.apache.hadoop.hive.serde2.lazy.LazyUtils.writePrimitiveUTF8(LazyUtils.java:321) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.serialize(LazySimpleSerDe.java:292) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.serializeField(LazySimpleSerDe.java:247) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.doSerialize(LazySimpleSerDe.java:231) [info] at org.apache.hadoop.hive.serde2.AbstractEncodingAwareSerDe.serialize(AbstractEncodingAwareSerDe.java:55) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$2(ScriptTransformationExec.scala:300) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$2$adapted(ScriptTransformationExec.scala:281) [info] at scala.collection.Iterator.foreach(Iterator.scala:941) [info] at scala.collection.Iterator.foreach$(Iterator.scala:941) [info] at scala.collection.AbstractIterator.foreach(Iterator.scala:1429) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$1(ScriptTransformationExec.scala:281) [info] at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) [info] at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1932) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.run(ScriptTransformationExec.scala:270) ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added new test. But please note that this test returns different result between Hive1.2 and Hive2.3 due to `HiveDecimal` or `SerDe` difference(don't know the root cause yet). Closes #27556 from Ngone51/script_transform. Lead-authored-by: yi.wu Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan (cherry picked from commit 99b8136a86030411e6bcbd312f40eb2a901ab0f0) Signed-off-by: Wenchen Fan --- .../execution/ScriptTransformationExec.scala | 32 ++++++++----- sql/hive/src/test/resources/test_script.py | 21 +++++++++ .../execution/ScriptTransformationSuite.scala | 46 ++++++++++++++++++- 3 files changed, 85 insertions(+), 14 deletions(-) create mode 100644 sql/hive/src/test/resources/test_script.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index e12f663304e7a..40f7b4e8db7c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -94,9 +94,8 @@ case class ScriptTransformationExec( // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. 
val writerThread = new ScriptTransformationWriterThread( - inputIterator, + inputIterator.map(outputProjection), input.map(_.dataType), - outputProjection, inputSerde, inputSoi, ioschema, @@ -249,16 +248,15 @@ case class ScriptTransformationExec( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], - outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, - @Nullable inputSoi: ObjectInspector, + @Nullable inputSoi: StructObjectInspector, ioschema: HiveScriptIOSchema, outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration - ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + ) extends Thread("Thread-ScriptTransformation-Feed") with HiveInspectors with Logging { setDaemon(true) @@ -278,8 +276,8 @@ private class ScriptTransformationWriterThread( var threwException: Boolean = true val len = inputSchema.length try { - iter.map(outputProjection).foreach { row => - if (inputSerde == null) { + if (inputSerde == null) { + iter.foreach { row => val data = if (len == 0) { ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { @@ -295,10 +293,21 @@ private class ScriptTransformationWriterThread( sb.toString() } outputStream.write(data.getBytes(StandardCharsets.UTF_8)) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) + } + } else { + // Convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. + val hiveData = new Array[Any](inputSchema.length) + val fieldOIs = inputSoi.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.zip(inputSchema).map { case (f, dt) => wrapperFor(f, dt) } + + iter.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + hiveData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, inputSchema(i))) + i += 1 + } + val writable = inputSerde.serialize(hiveData, inputSoi) if (scriptInputWriter != null) { scriptInputWriter.write(writable) } else { @@ -374,14 +383,13 @@ case class HiveScriptIOSchema ( val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { inputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) - .asInstanceOf[ObjectInspector] (serde, objectInspector) } } diff --git a/sql/hive/src/test/resources/test_script.py b/sql/hive/src/test/resources/test_script.py new file mode 100644 index 0000000000000..82ef7b38f0c1b --- /dev/null +++ b/sql/hive/src/test/resources/test_script.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +for line in sys.stdin: + (a, b, c, d, e) = line.split('\t') + sys.stdout.write('\t'.join([a, b, c, d, e])) + sys.stdout.flush() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 80a50c18bcb93..7d01fc53a4099 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.sql.Timestamp + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.Assertions._ import org.scalatest.BeforeAndAfterEach @@ -24,15 +26,18 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton with - BeforeAndAfterEach { +class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton + with BeforeAndAfterEach { import spark.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( @@ -186,6 +191,43 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton wit rowsDf.select("name").collect()) assert(uncaughtExceptionHandler.exception.isEmpty) } + + test("SPARK-25990: TRANSFORM should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + val query = sql( + s""" + |SELECT + |TRANSFORM(a, b, c, d, e) + |USING 'python $scriptFilePath' AS (a, b, c, d, e) + |FROM v + """.stripMargin) + + // In Hive1.2, it does not do well on Decimal conversion. For example, in this case, + // it converts a decimal value's type from Decimal(38, 18) to Decimal(1, 0). So we need + // do extra cast here for Hive1.2. But in Hive2.3, it still keeps the original Decimal type. 
+ val decimalToString: Column => Column = if (HiveUtils.isHive23) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } + checkAnswer(query, identity, df.select( + 'a.cast("string"), + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string")).collect()) + } + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From 79ce79234f02092e22fdd79e859d83f5a174ef95 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 14 Feb 2020 18:20:18 +0800 Subject: [PATCH 0079/1280] [SPARK-30810][SQL] Parses and convert a CSV Dataset having different column from 'value' in csv(dataset) API ### What changes were proposed in this pull request? This PR fixes `DataFrameReader.csv(dataset: Dataset[String])` API to take a `Dataset[String]` originated from a column name different from `value`. This is a long-standing bug started from the very first place. `CSVUtils.filterCommentAndEmpty` assumed the `Dataset[String]` to be originated with `value` column. This PR changes to use the first column name in the schema. ### Why are the changes needed? For `DataFrameReader.csv(dataset: Dataset[String])` to support any `Dataset[String]` as the signature indicates. ### Does this PR introduce any user-facing change? Yes, ```scala val ds = spark.range(2).selectExpr("concat('a,b,', id) AS text").as[String] spark.read.option("header", true).option("inferSchema", true).csv(ds).show() ``` Before: ``` org.apache.spark.sql.AnalysisException: cannot resolve '`value`' given input columns: [text];; 'Filter (length(trim('value, None)) > 0) +- Project [concat(a,b,, cast(id#0L as string)) AS text#2] +- Range (0, 2, step=1, splits=Some(2)) ``` After: ``` +---+---+---+ | a| b| 0| +---+---+---+ | a| b| 1| +---+---+---+ ``` ### How was this patch tested? Unittest was added. Closes #27561 from HyukjinKwon/SPARK-30810. Authored-by: HyukjinKwon Signed-off-by: Wenchen Fan (cherry picked from commit 2a270a731a3b1da9a0fb036d648dd522e5c4d5ad) Signed-off-by: Wenchen Fan --- .../spark/sql/execution/datasources/csv/CSVUtils.scala | 7 ++++--- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 21fabac472f4b..d8b52c503ad34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -33,11 +33,12 @@ object CSVUtils { // with the one below, `filterCommentAndEmpty` but execution path is different. One of them // might have to be removed in the near future if possible. 
import lines.sqlContext.implicits._ - val nonEmptyLines = lines.filter(length(trim($"value")) > 0) + val aliased = lines.toDF("value") + val nonEmptyLines = aliased.filter(length(trim($"value")) > 0) if (options.isCommentSet) { - nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)) + nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).as[String] } else { - nonEmptyLines + nonEmptyLines.as[String] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index b1105b4a63bba..0be0e1e3da3dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2294,6 +2294,13 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } } } + + test("SPARK-30810: parses and convert a CSV Dataset having different column from 'value'") { + val ds = spark.range(2).selectExpr("concat('a,b,', id) AS `a.text`").as[String] + val csv = spark.read.option("header", true).option("inferSchema", true).csv(ds) + assert(csv.schema.fieldNames === Seq("a", "b", "0")) + checkAnswer(csv, Row("a", "b", 1)) + } } class CSVv1Suite extends CSVSuite { From 0a8d7a35e24acbd7af57fe5169691afb8ccd8675 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 14 Feb 2020 22:16:57 +0800 Subject: [PATCH 0080/1280] [SPARK-30766][SQL] Fix the timestamp truncation to the `HOUR` and `DAY` levels ### What changes were proposed in this pull request? In the PR, I propose to use Java 8 time API in timestamp truncations to the levels of `HOUR` and `DAY`. The problem is in the usage of `timeZone.getOffset(millis)` in days/hours truncations where the combined calendar (Julian + Gregorian) is used underneath. ### Why are the changes needed? The change fix wrong truncations. For example, the following truncation to hours should print `0010-01-01 01:00:00` but it outputs wrong timestamp: ```scala Seq("0010-01-01 01:02:03.123456").toDF() .select($"value".cast("timestamp").as("ts")) .select(date_trunc("HOUR", $"ts").cast("string")) .show(false) +------------------------------------+ |CAST(date_trunc(HOUR, ts) AS STRING)| +------------------------------------+ |0010-01-01 01:30:17 | +------------------------------------+ ``` ### Does this PR introduce any user-facing change? Yes. After the changes, the result of the example above is: ```scala +------------------------------------+ |CAST(date_trunc(HOUR, ts) AS STRING)| +------------------------------------+ |0010-01-01 01:00:00 | +------------------------------------+ ``` ### How was this patch tested? - Added new test to `DateFunctionsSuite` - By `DateExpressionsSuite` and `DateTimeUtilsSuite` Closes #27512 from MaxGekk/fix-trunc-old-timestamp. 
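As a cross-check of the Scala example above, a rough PySpark equivalent (illustrative only, not part of this patch; the column name `value` simply mirrors the Scala snippet):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("date-trunc-demo").getOrCreate()

df = (spark.createDataFrame([("0010-01-01 01:02:03.123456",)], ["value"])
      .select(F.col("value").cast("timestamp").alias("ts"))
      .select(F.date_trunc("HOUR", F.col("ts")).cast("string")))

# With this change the truncated value should be 0010-01-01 01:00:00
# rather than a timestamp carrying a spurious sub-hour offset.
df.show(truncate=False)

spark.stop()
```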
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit 7137a6d065edeaab97bf5bf49ffaca3d060a14fe) Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 6 +-- .../sql/catalyst/util/DateTimeUtils.scala | 44 ++++++++++--------- .../catalyst/util/DateTimeUtilsSuite.scala | 39 ++++++++-------- .../apache/spark/sql/DateFunctionsSuite.scala | 13 ++++++ 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index cf91489d8e6b7..adf7251256041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1690,15 +1690,15 @@ case class TruncTimestamp( override def eval(input: InternalRow): Any = { evalHelper(input, minLevel = MIN_LEVEL_OF_TIMESTAMP_TRUNC) { (t: Any, level: Int) => - DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone) + DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, zoneId) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) codeGenHelper(ctx, ev, minLevel = MIN_LEVEL_OF_TIMESTAMP_TRUNC, true) { (date: String, fmt: String) => - s"truncTimestamp($date, $fmt, $tz);" + s"truncTimestamp($date, $fmt, $zid);" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 01d36f19fc06f..ce0c138791f30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -711,32 +711,34 @@ object DateTimeUtils { } } + private def truncToUnit(t: SQLTimestamp, zoneId: ZoneId, unit: ChronoUnit): SQLTimestamp = { + val truncated = microsToInstant(t).atZone(zoneId).truncatedTo(unit) + instantToMicros(truncated.toInstant) + } + /** * Returns the trunc date time from original date time and trunc level. * Trunc level should be generated using `parseTruncLevel()`, should be between 0 and 12. 
*/ - def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = { - if (level == TRUNC_TO_MICROSECOND) return t - var millis = MICROSECONDS.toMillis(t) - val truncated = level match { - case TRUNC_TO_MILLISECOND => millis - case TRUNC_TO_SECOND => - millis - millis % MILLIS_PER_SECOND - case TRUNC_TO_MINUTE => - millis - millis % MILLIS_PER_MINUTE - case TRUNC_TO_HOUR => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_HOUR - offset - case TRUNC_TO_DAY => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_DAY - offset - case _ => // Try to truncate date levels - val dDays = millisToDays(millis, timeZone.toZoneId) - daysToMillis(truncDate(dDays, level), timeZone.toZoneId) + def truncTimestamp(t: SQLTimestamp, level: Int, zoneId: ZoneId): SQLTimestamp = { + level match { + case TRUNC_TO_MICROSECOND => t + case TRUNC_TO_HOUR => truncToUnit(t, zoneId, ChronoUnit.HOURS) + case TRUNC_TO_DAY => truncToUnit(t, zoneId, ChronoUnit.DAYS) + case _ => + val millis = MICROSECONDS.toMillis(t) + val truncated = level match { + case TRUNC_TO_MILLISECOND => millis + case TRUNC_TO_SECOND => + millis - millis % MILLIS_PER_SECOND + case TRUNC_TO_MINUTE => + millis - millis % MILLIS_PER_MINUTE + case _ => // Try to truncate date levels + val dDays = millisToDays(millis, zoneId) + daysToMillis(truncDate(dDays, level), zoneId) + } + truncated * MICROS_PER_MILLIS } - truncated * MICROS_PER_MILLIS } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cd0594c775a47..ff4d8a2457922 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -499,9 +499,9 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { level: Int, expected: String, inputTS: SQLTimestamp, - timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = { + zoneId: ZoneId = defaultZoneId): Unit = { val truncated = - DateTimeUtils.truncTimestamp(inputTS, level, timezone) + DateTimeUtils.truncTimestamp(inputTS, level, zoneId) val expectedTS = toTimestamp(expected, defaultZoneId) assert(truncated === expectedTS.get) } @@ -539,6 +539,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { for (tz <- ALL_TIMEZONES) { withDefaultTimeZone(tz) { + val zid = tz.toZoneId val inputTS = DateTimeUtils.stringToTimestamp( UTF8String.fromString("2015-03-05T09:32:05.359"), defaultZoneId) val inputTS1 = DateTimeUtils.stringToTimestamp( @@ -552,23 +553,23 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { val inputTS5 = DateTimeUtils.stringToTimestamp( UTF8String.fromString("1999-03-29T01:02:03.456789"), defaultZoneId) - testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, tz) - 
testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_DECADE, "1990-01-01", inputTS5.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_CENTURY, "1901-01-01", inputTS5.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MILLENNIUM, "2001-01-01", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_DECADE, "1990-01-01", inputTS5.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_CENTURY, "1901-01-01", inputTS5.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MILLENNIUM, "2001-01-01", inputTS.get, zid) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 41d53c959ef99..ba45b9f9b62df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -856,4 +856,17 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { TimeZone.setDefault(defaultTz) } } + + test("SPARK-30766: date_trunc of old timestamps to hours and days") { + def checkTrunc(level: String, expected: String): Unit = { + val df = Seq("0010-01-01 01:02:03.123456") + .toDF() + .select($"value".cast("timestamp").as("ts")) + .select(date_trunc(level, $"ts").cast("string")) + checkAnswer(df, Row(expected)) + } + + checkTrunc("HOUR", "0010-01-01 01:00:00") + checkTrunc("DAY", "0010-01-01 00:00:00") + } } From 1385fc02ce7d28e6570971e1687e74d245a5533f Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 14 Feb 2020 10:18:08 -0800 Subject: [PATCH 0081/1280] [SPARK-29748][DOCS][FOLLOW-UP] Add a note that the legacy environment variable to set in both executor and driver ### What changes were proposed in this pull request? 
This PR address the comment at https://github.com/apache/spark/pull/26496#discussion_r379194091 and improves the migration guide to explicitly note that the legacy environment variable to set in both executor and driver. ### Why are the changes needed? To clarify this env should be set both in driver and executors. ### Does this PR introduce any user-facing change? Nope. ### How was this patch tested? I checked it via md editor. Closes #27573 from HyukjinKwon/SPARK-29748. Authored-by: HyukjinKwon Signed-off-by: Shixiong Zhu (cherry picked from commit b343757b1bd5d0344b82f36aa4d65ed34f840606) Signed-off-by: Shixiong Zhu --- docs/pyspark-migration-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pyspark-migration-guide.md b/docs/pyspark-migration-guide.md index 8ea4fec75edf8..f7f20389aa694 100644 --- a/docs/pyspark-migration-guide.md +++ b/docs/pyspark-migration-guide.md @@ -87,7 +87,7 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - Since Spark 3.0, `Column.getItem` is fixed such that it does not call `Column.apply`. Consequently, if `Column` is used as an argument to `getItem`, the indexing operator should be used. For example, `map_col.getItem(col('id'))` should be replaced with `map_col[col('id')]`. - - As of Spark 3.0 `Row` field names are no longer sorted alphabetically when constructing with named arguments for Python versions 3.6 and above, and the order of fields will match that as entered. To enable sorted fields by default, as in Spark 2.4, set the environment variable `PYSPARK_ROW_FIELD_SORTING_ENABLED` to "true". For Python versions less than 3.6, the field names will be sorted alphabetically as the only option. + - As of Spark 3.0 `Row` field names are no longer sorted alphabetically when constructing with named arguments for Python versions 3.6 and above, and the order of fields will match that as entered. To enable sorted fields by default, as in Spark 2.4, set the environment variable `PYSPARK_ROW_FIELD_SORTING_ENABLED` to "true" for both executors and driver - this environment variable must be consistent on all executors and driver; otherwise, it may cause failures or incorrect answers. For Python versions less than 3.6, the field names will be sorted alphabetically as the only option. ## Upgrading from PySpark 2.3 to 2.4 From 2824fec9fa57444b7c64edb8226cf75bb87a2e5d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 14 Feb 2020 21:46:01 +0000 Subject: [PATCH 0082/1280] [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable` ### What changes were proposed in this pull request? 1. `InMemoryTable` was flatting the nested columns, and then the flatten columns was used to look up the indices which is not correct. This PR implements partitioned by nested column for `InMemoryTable`. ### Why are the changes needed? This PR implements partitioned by nested column for `InMemoryTable`, so we can test this features in DSv2 ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing unit tests and new tests. Closes #26929 from dbtsai/addTests. 
Authored-by: DB Tsai Signed-off-by: DB Tsai (cherry picked from commit d0f961476031b62bda0d4d41f7248295d651ea92) Signed-off-by: DB Tsai --- .../spark/sql/connector/InMemoryTable.scala | 35 ++++++-- .../spark/sql/DataFrameWriterV2Suite.scala | 86 ++++++++++++++++++- 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index c9e4e0aad5704..0187ae31e2d1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -26,7 +26,7 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} @@ -59,10 +59,30 @@ class InMemoryTable( def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) - private val partIndexes = partFieldNames.map(schema.fieldIndex) + private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => + schema.findNestedField(ref.fieldNames(), includeCollections = false) match { + case Some(_) => ref.fieldNames() + case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") + } + } - private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_)) + private def getKey(row: InternalRow): Seq[Any] = { + def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = { + val index = schema.fieldIndex(fieldNames(0)) + val value = row.toSeq(schema).apply(index) + if (fieldNames.length > 1) { + (value, schema(index).dataType) match { + case (row: InternalRow, nestedSchema: StructType) => + extractor(fieldNames.drop(1), nestedSchema, row) + case (_, dataType) => + throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") + } + } else { + value + } + } + partCols.map(fieldNames => extractor(fieldNames, schema, row)) + } def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => @@ -146,8 +166,10 @@ class InMemoryTable( } private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + val deleteKeys = InMemoryTable.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), filters) dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } @@ -161,7 +183,8 @@ class InMemoryTable( } override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { - dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) } } 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index d49dc58e93ddb..cd157086a8b8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -17,20 +17,24 @@ package org.apache.spark.sql +import java.sql.Timestamp + import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} -import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types.TimestampType import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -550,4 +554,84 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(replaced.partitioning.isEmpty) assert(replaced.properties === defaultOwnership.asJava) } + + test("SPARK-30289 Create: partitioned by nested column") { + val schema = new StructType().add("ts", new StructType() + .add("created", TimestampType) + .add("modified", TimestampType) + .add("timezone", StringType)) + + val data = Seq( + Row(Row(Timestamp.valueOf("2019-06-01 10:00:00"), Timestamp.valueOf("2019-09-02 07:00:00"), + "America/Los_Angeles")), + Row(Row(Timestamp.valueOf("2019-08-26 18:00:00"), Timestamp.valueOf("2019-09-26 18:00:00"), + "America/Los_Angeles")), + Row(Row(Timestamp.valueOf("2018-11-23 18:00:00"), Timestamp.valueOf("2018-12-22 18:00:00"), + "America/New_York"))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) + + df.writeTo("testcat.table_name") + .partitionedBy($"ts.timezone") + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + .asInstanceOf[InMemoryTable] + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone"))))) + checkAnswer(spark.table(table.name), data) + assert(table.dataMap.toArray.length == 2) + assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2) + assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1) + + // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet + // so the following sql will fail. 
+ // sql("DELETE FROM testcat.table_name WHERE ts.timezone = \"America/Los_Angeles\"") + } + + test("SPARK-30289 Create: partitioned by multiple transforms on nested columns") { + spark.table("source") + .withColumn("ts", struct( + lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created", + lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", + lit("America/Los_Angeles") as "timezone")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy( + years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"), + years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified") + ) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq( + YearsTransform(FieldReference(Array("ts", "created"))), + MonthsTransform(FieldReference(Array("ts", "created"))), + DaysTransform(FieldReference(Array("ts", "created"))), + HoursTransform(FieldReference(Array("ts", "created"))), + YearsTransform(FieldReference(Array("ts", "modified"))), + MonthsTransform(FieldReference(Array("ts", "modified"))), + DaysTransform(FieldReference(Array("ts", "modified"))), + HoursTransform(FieldReference(Array("ts", "modified"))))) + } + + test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") { + spark.table("source") + .withColumn("ts", struct( + lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created", + lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", + lit("America/Los_Angeles") as "timezone")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(bucket(4, $"ts.timezone")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType), + Seq(FieldReference(Seq("ts", "timezone")))))) + } } From f7b38fe05e222ed0eed3aca89e90a5690416193a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 15 Feb 2020 19:49:58 +0800 Subject: [PATCH 0083/1280] [SPARK-30826][SQL] Respect reference case in `StringStartsWith` pushed down to parquet ### What changes were proposed in this pull request? In the PR, I propose to convert the attribute name of `StringStartsWith` pushed down to the Parquet datasource to column reference via the `nameToParquetField` map. Similar conversions are performed for other source filters pushed down to parquet. ### Why are the changes needed? This fixes the bug described in [SPARK-30826](https://issues.apache.org/jira/browse/SPARK-30826). The query from an external table: ```sql CREATE TABLE t1 (col STRING) USING parquet OPTIONS (path '$path') ``` created on top of written parquet files by `Seq("42").toDF("COL").write.parquet(path)` returns wrong empty result: ```scala spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'").show +---+ |col| +---+ +---+ ``` ### Does this PR introduce any user-facing change? Yes. After the changes the result is correct for the example above: ```scala spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'").show +---+ |col| +---+ | 42| +---+ ``` ### How was this patch tested? Added a test to `ParquetFilterSuite` Closes #27574 from MaxGekk/parquet-StringStartsWith-case-sens. 
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan (cherry picked from commit 8b73b92aadd685b29ef3d9b33366f5e1fd3dae99) Signed-off-by: Wenchen Fan --- .../datasources/parquet/ParquetFilters.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index b9b86adb438e6..948a120e0d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -591,7 +591,7 @@ class ParquetFilters( case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => - FilterApi.userDefined(binaryColumn(name), + FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldName), new UserDefinedPredicate[Binary] with Serializable { private val strToBinary = Binary.fromReusedByteArray(v.getBytes) private val size = strToBinary.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 286bb1e920266..4e0c1c2dbe601 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1390,6 +1390,27 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } } + + test("SPARK-30826: case insensitivity of StringStartsWith attribute") { + import testImplicits._ + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTable("t1") { + withTempPath { dir => + val path = dir.toURI.toString + Seq("42").toDF("COL").write.parquet(path) + spark.sql( + s""" + |CREATE TABLE t1 (col STRING) + |USING parquet + |OPTIONS (path '$path') + """.stripMargin) + checkAnswer( + spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'"), + Row("42")) + } + } + } + } } class ParquetV1FilterSuite extends ParquetFilterSuite { From 8ed8baa74a6471d929fcc367bff282a87cead7a1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 16 Feb 2020 09:53:12 -0600 Subject: [PATCH 0084/1280] [SPARK-30691][SQL][DOC][FOLLOW-UP] Make link names exactly the same as the side bar names ### What changes were proposed in this pull request? Make link names exactly the same as the side bar names ### Why are the changes needed? Make doc look better ### Does this PR introduce any user-facing change? before: ![image](https://user-images.githubusercontent.com/13592258/74578603-ad300100-4f4a-11ea-8430-11fccf31eab4.png) after: ![image](https://user-images.githubusercontent.com/13592258/74578670-eff1d900-4f4a-11ea-97d8-5908c0e50e95.png) ### How was this patch tested? Manually build and check the docs Closes #27591 from huaxingao/spark-doc-followup. 
Authored-by: Huaxin Gao Signed-off-by: Sean Owen (cherry picked from commit 0a03e7e679771da8556fae72b35edf21ae71ac44) Signed-off-by: Sean Owen --- docs/_data/menu-sql.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 1e343f630f88e..38a5cf61245a6 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -157,12 +157,12 @@ - text: Auxiliary Statements url: sql-ref-syntax-aux.html subitems: - - text: Analyze statement + - text: ANALYZE url: sql-ref-syntax-aux-analyze.html subitems: - text: ANALYZE TABLE url: sql-ref-syntax-aux-analyze-table.html - - text: Caching statements + - text: CACHE url: sql-ref-syntax-aux-cache.html subitems: - text: CACHE TABLE @@ -175,7 +175,7 @@ url: sql-ref-syntax-aux-refresh-table.html - text: REFRESH url: sql-ref-syntax-aux-cache-refresh.md - - text: Describe Commands + - text: DESCRIBE url: sql-ref-syntax-aux-describe.html subitems: - text: DESCRIBE DATABASE @@ -186,7 +186,7 @@ url: sql-ref-syntax-aux-describe-function.html - text: DESCRIBE QUERY url: sql-ref-syntax-aux-describe-query.html - - text: Show commands + - text: SHOW url: sql-ref-syntax-aux-show.html subitems: - text: SHOW COLUMNS @@ -205,14 +205,14 @@ url: sql-ref-syntax-aux-show-partitions.html - text: SHOW CREATE TABLE url: sql-ref-syntax-aux-show-create-table.html - - text: Configuration Management Commands + - text: CONFIGURATION MANAGEMENT url: sql-ref-syntax-aux-conf-mgmt.html subitems: - text: SET url: sql-ref-syntax-aux-conf-mgmt-set.html - text: RESET url: sql-ref-syntax-aux-conf-mgmt-reset.html - - text: Resource Management Commands + - text: RESOURCE MANAGEMENT url: sql-ref-syntax-aux-resource-mgmt.html subitems: - text: ADD FILE From 3c9231e20b70ffc6e3ea61944ae5000f774c613a Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sun, 16 Feb 2020 09:55:03 -0600 Subject: [PATCH 0085/1280] [SPARK-30803][DOCS] Fix the home page link for Scala API document ### What changes were proposed in this pull request? Change the link to the Scala API document. ``` $ git grep "#org.apache.spark.package" docs/_layouts/global.html:
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
  • docs/index.md:* [Spark Scala API (Scaladoc)](api/scala/index.html#org.apache.spark.package) docs/rdd-programming-guide.md:[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). ``` ### Why are the changes needed? The home page link for Scala API document is incorrect after upgrade to 3.0 ### Does this PR introduce any user-facing change? Document UI change only. ### How was this patch tested? Local test, attach screenshots below: Before: ![image](https://user-images.githubusercontent.com/4833765/74335713-c2385300-4dd7-11ea-95d8-f5a3639d2578.png) After: ![image](https://user-images.githubusercontent.com/4833765/74335727-cbc1bb00-4dd7-11ea-89d9-4dcc1310e679.png) Closes #27549 from xuanyuanking/scala-doc. Authored-by: Yuanjian Li Signed-off-by: Sean Owen (cherry picked from commit 01cc852982cd065e08f9a652c14a0514f49fb662) Signed-off-by: Sean Owen --- docs/_layouts/global.html | 2 +- docs/configuration.md | 8 +- docs/graphx-programming-guide.md | 68 +++++++------- docs/index.md | 2 +- docs/ml-advanced.md | 10 +- docs/ml-classification-regression.md | 40 ++++---- docs/ml-clustering.md | 10 +- docs/ml-collaborative-filtering.md | 2 +- docs/ml-datasource.md | 4 +- docs/ml-features.md | 92 +++++++++---------- docs/ml-frequent-pattern-mining.md | 4 +- docs/ml-migration-guide.md | 36 ++++---- docs/ml-pipeline.md | 10 +- docs/ml-statistics.md | 8 +- docs/ml-tuning.md | 18 ++-- docs/mllib-clustering.md | 26 +++--- docs/mllib-collaborative-filtering.md | 4 +- docs/mllib-data-types.md | 48 +++++----- docs/mllib-decision-tree.md | 10 +- docs/mllib-dimensionality-reduction.md | 6 +- docs/mllib-ensembles.md | 10 +- docs/mllib-evaluation-metrics.md | 8 +- docs/mllib-feature-extraction.md | 34 +++---- docs/mllib-frequent-pattern-mining.md | 14 +-- docs/mllib-isotonic-regression.md | 2 +- docs/mllib-linear-methods.md | 22 ++--- docs/mllib-naive-bayes.md | 8 +- docs/mllib-optimization.md | 14 +-- docs/mllib-pmml-model-export.md | 2 +- docs/mllib-statistics.md | 28 +++--- docs/quick-start.md | 2 +- docs/rdd-programming-guide.md | 28 +++--- docs/sql-data-sources-generic-options.md | 2 +- docs/sql-data-sources-jdbc.md | 2 +- docs/sql-data-sources-json.md | 2 +- docs/sql-getting-started.md | 16 ++-- docs/sql-migration-guide.md | 4 +- docs/sql-programming-guide.md | 2 +- docs/sql-ref-syntax-aux-analyze-table.md | 2 +- docs/sql-ref-syntax-aux-cache-refresh.md | 2 +- docs/sql-ref-syntax-aux-refresh-table.md | 2 +- docs/sql-ref-syntax-aux-resource-mgmt.md | 2 +- docs/sql-ref-syntax-aux-show-tables.md | 2 +- docs/sql-ref-syntax-aux-show.md | 2 +- docs/sql-ref-syntax-ddl-drop-database.md | 2 +- docs/sql-ref-syntax-ddl-drop-function.md | 2 +- ...tax-dml-insert-overwrite-directory-hive.md | 2 +- ...f-syntax-dml-insert-overwrite-directory.md | 2 +- docs/sql-ref-syntax-dml.md | 2 +- docs/sql-ref-syntax-qry-select-clusterby.md | 2 +- ...sql-ref-syntax-qry-select-distribute-by.md | 2 +- docs/sql-ref-syntax-qry-select-sortby.md | 2 +- docs/sql-ref-syntax-qry-select.md | 2 +- docs/streaming-custom-receivers.md | 2 +- docs/streaming-kafka-integration.md | 2 +- docs/streaming-kinesis-integration.md | 2 +- docs/streaming-programming-guide.md | 42 ++++----- .../structured-streaming-programming-guide.md | 22 ++--- docs/tuning.md | 2 +- 59 files changed, 355 insertions(+), 355 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d5fb18bfb06c0..d05ac6bbe129d 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -82,7 +82,7 @@