From 6da8ade5f46cac69820ef0f6987806ffa78873f1 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 19 Nov 2020 12:42:33 -0800 Subject: [PATCH 01/38] [SPARK-33045][SQL][FOLLOWUP] Fix build failure with Scala 2.13 ### What changes were proposed in this pull request? Explicitly convert `scala.collection.mutable.Buffer` to `Seq`. In Scala 2.13 `Seq` is an alias of `scala.collection.immutable.Seq` instead of `scala.collection.Seq`. ### Why are the changes needed? Without the change build with Scala 2.13 fails with the following: ``` [error] /home/runner/work/spark/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala:1417:41: type mismatch; [error] found : scala.collection.mutable.Buffer[org.apache.spark.unsafe.types.UTF8String] [error] required: Seq[org.apache.spark.unsafe.types.UTF8String] [error] case null => LikeAll(e, patterns) [error] ^ [error] /home/runner/work/spark/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala:1418:41: type mismatch; [error] found : scala.collection.mutable.Buffer[org.apache.spark.unsafe.types.UTF8String] [error] required: Seq[org.apache.spark.unsafe.types.UTF8String] [error] case _ => NotLikeAll(e, patterns) [error] ^ ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A Closes #30431 from sunchao/SPARK-33045-followup. Authored-by: Chao Sun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 79857a63a69b5..23de8ab09dd0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1414,8 +1414,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg // So we use LikeAll or NotLikeAll instead. val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) ctx.NOT match { - case null => LikeAll(e, patterns) - case _ => NotLikeAll(e, patterns) + case null => LikeAll(e, patterns.toSeq) + case _ => NotLikeAll(e, patterns.toSeq) } } else { getLikeQuantifierExprs(ctx.expression).reduceLeft(And) From 883a213a8f721d19855f7a5696084533da2002f7 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 19 Nov 2020 13:36:45 -0800 Subject: [PATCH 02/38] [MINOR] Structured Streaming statistics page indent fix ### What changes were proposed in this pull request? Structured Streaming statistics page code contains an indentation issue. This PR fixes it. ### Why are the changes needed? Indent fix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests. Closes #30434 from gaborgsomogyi/STAT-INDENT-FIX. 
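To make the Scala 2.13 incompatibility fixed in PATCH 01 above concrete, here is a small self-contained sketch; `LikeAllExample` is a hypothetical stand-in for `LikeAll`, and plain `String` replaces `UTF8String` purely for illustration.

```scala
import scala.collection.mutable

object SeqAliasExample {
  // Hypothetical stand-in for LikeAll(e, patterns): the parameter uses the default Seq alias.
  case class LikeAllExample(patterns: Seq[String])

  def main(args: Array[String]): Unit = {
    val patterns: mutable.Buffer[String] = mutable.Buffer("a%", "b%")

    // Compiles on Scala 2.12, where the default Seq is scala.collection.Seq,
    // but fails on 2.13, where it is scala.collection.immutable.Seq:
    // val broken = LikeAllExample(patterns)

    // Compiles on both versions, which is the fix applied to AstBuilder in this patch:
    val fixed = LikeAllExample(patterns.toSeq)
    println(fixed)
  }
}
```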
Authored-by: Gabor Somogyi Signed-off-by: Dongjoon Hyun --- .../ui/StreamingQueryStatisticsPage.scala | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 77078046dda7c..7d38acfceee81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -209,33 +209,33 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) {graphUIDataForNumberTotalRows.generateTimelineHtml(jsCollector)} {graphUIDataForNumberTotalRows.generateHistogramHtml(jsCollector)} - - -
-
Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
-
- - {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} - {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} - - - -
-
Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
-
- - {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} - {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} - - - -
-
Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
-
- - {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} - {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} - + + +
+
Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
+
+ + {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} + {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} + + + +
+
Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
+
+ + {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} + {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} + + + +
+
Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
+
+ + {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} + // scalastyle:on } else { new NodeBuffer() From 02d410a18c966944c7a46e5bc3006dadf3d579b6 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 20 Nov 2020 13:14:20 +0900 Subject: [PATCH 03/38] [MINOR][DOCS] Document 'without' value for HADOOP_VERSION in pip installation ### What changes were proposed in this pull request? I believe it's self-descriptive. ### Why are the changes needed? To document supported features. ### Does this PR introduce _any_ user-facing change? Yes, the docs are updated. It's master only. ### How was this patch tested? Manually built the docs via `cd python/docs` and `make clean html`: ![Screen Shot 2020-11-20 at 10 59 07 AM](https://user-images.githubusercontent.com/6477701/99748225-7ad9b280-2b1f-11eb-86fd-165012b1bb7c.png) Closes #30436 from HyukjinKwon/minor-doc-fix. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- python/docs/source/getting_started/install.rst | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 4039698d39958..9c9ff7fa7844b 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -48,7 +48,7 @@ If you want to install extra dependencies for a specific componenet, you can ins pip install pyspark[sql] -For PySpark with a different Hadoop version, you can install it by using ``HADOOP_VERSION`` environment variables as below: +For PySpark with/without a specific Hadoop version, you can install it by using ``HADOOP_VERSION`` environment variables as below: .. code-block:: bash @@ -68,8 +68,13 @@ It is recommended to use ``-v`` option in ``pip`` to track the installation and HADOOP_VERSION=2.7 pip install pyspark -v -Supported versions of Hadoop are ``HADOOP_VERSION=2.7`` and ``HADOOP_VERSION=3.2`` (default). -Note that this installation of PySpark with a different version of Hadoop is experimental. It can change or be removed between minor releases. +Supported values in ``HADOOP_VERSION`` are: + +- ``without``: Spark pre-built with user-provided Apache Hadoop +- ``2.7``: Spark pre-built for Apache Hadoop 2.7 +- ``3.2``: Spark pre-built for Apache Hadoop 3.2 and later (default) + +Note that this installation way of PySpark with/without a specific Hadoop version is experimental. It can change or be removed between minor releases. Using Conda From 8218b488035049434271dc9e3bd5af45ffadf0fd Mon Sep 17 00:00:00 2001 From: Venkata krishnan Sowrirajan Date: Fri, 20 Nov 2020 06:00:30 -0600 Subject: [PATCH 04/38] [SPARK-32919][SHUFFLE][TEST-MAVEN][TEST-HADOOP2.7] Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions ### What changes were proposed in this pull request? Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions. This PR includes changes related to `ShuffleMapStage` preparation which is selection of merger locations and initializing them as part of `ShuffleDependency`. 
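For orientation only, a hedged configuration sketch of how an application would opt in once the feature is complete; the `spark.shuffle.push.*` keys are the ones introduced in this patch, everything else is an assumption.

```scala
import org.apache.spark.SparkConf

// Client-side opt-in for push-based shuffle. The server side additionally needs
// spark.shuffle.server.mergedShuffleFileManagerImpl set to a suitable
// MergedShuffleFileManager implementation, as noted in the config docs below.
val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")                  // external shuffle service is required
  .set("spark.shuffle.push.enabled", "true")
  .set("spark.shuffle.push.maxRetainedMergerLocations", "500")   // default, shown for illustration
  .set("spark.shuffle.push.mergersMinThresholdRatio", "0.05")    // default, shown for illustration
```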
Currently this code is not used as some of the changes would come subsequently as part of https://issues.apache.org/jira/browse/SPARK-32917 (shuffle blocks push as part of `ShuffleMapTask`), https://issues.apache.org/jira/browse/SPARK-32918 (support for finalize API) and https://issues.apache.org/jira/browse/SPARK-32920 (finalization of push/merge phase). This is why the tests here are also partial, once these above mentioned changes are raised as PR we will have enough tests for DAGScheduler piece of code as well. ### Why are the changes needed? Added a new API in `SchedulerBackend` to get merger locations for push based shuffle. This is currently implemented for Yarn and other cluster managers can have separate implementations which is why a new API is introduced. ### Does this PR introduce _any_ user-facing change? Yes, user facing config to enable push based shuffle is introduced ### How was this patch tested? Added unit tests partially and some of the changes in DAGScheduler depends on future changes, DAGScheduler tests will be added along with those changes. Lead-authored-by: Venkata krishnan Sowrirajan vsowrirajanlinkedin.com Co-authored-by: Min Shen mshenlinkedin.com Closes #30164 from venkata91/upstream-SPARK-32919. Lead-authored-by: Venkata krishnan Sowrirajan Co-authored-by: Min Shen Signed-off-by: Mridul Muralidharan gmail.com> --- .../scala/org/apache/spark/Dependency.scala | 15 +++++ .../spark/internal/config/package.scala | 47 ++++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 40 ++++++++++++ .../spark/scheduler/SchedulerBackend.scala | 13 ++++ .../apache/spark/storage/BlockManagerId.scala | 2 + .../spark/storage/BlockManagerMaster.scala | 20 ++++++ .../storage/BlockManagerMasterEndpoint.scala | 65 +++++++++++++++++++ .../spark/storage/BlockManagerMessages.scala | 6 ++ .../scala/org/apache/spark/util/Utils.scala | 8 +++ .../spark/storage/BlockManagerSuite.scala | 49 +++++++++++++- .../org/apache/spark/util/UtilsSuite.scala | 12 ++++ .../cluster/YarnSchedulerBackend.scala | 50 ++++++++++++-- 12 files changed, 320 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ba8e4d69ba755..d21b9d9833e9e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} +import org.apache.spark.storage.BlockManagerId /** * :: DeveloperApi :: @@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( shuffleId, this) + /** + * Stores the location of the list of chosen external shuffle services for handling the + * shuffle merge requests from mappers in this shuffle map stage. 
+ */ + private[spark] var mergerLocs: Seq[BlockManagerId] = Nil + + def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { + if (mergerLocs != null) { + this.mergerLocs = mergerLocs + } + } + + def getMergerLocs: Seq[BlockManagerId] = mergerLocs + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4bc49514fc5ad..b38d0e5c617b9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1945,4 +1945,51 @@ package object config { .version("3.0.1") .booleanConf .createWithDefault(false) + + private[spark] val PUSH_BASED_SHUFFLE_ENABLED = + ConfigBuilder("spark.shuffle.push.enabled") + .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " + + "conjunction with the server side flag spark.shuffle.server.mergedShuffleFileManagerImpl " + + "which needs to be set with the appropriate " + + "org.apache.spark.network.shuffle.MergedShuffleFileManager implementation for push-based " + + "shuffle to be enabled") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS = + ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations") + .doc("Maximum number of shuffle push merger locations cached for push based shuffle. " + + "Currently, shuffle push merger locations are nothing but external shuffle services " + + "which are responsible for handling pushed blocks and merging them and serving " + + "merged blocks for later shuffle fetch.") + .version("3.1.0") + .intConf + .createWithDefault(500) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO = + ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio") + .doc("The minimum number of shuffle merger locations required to enable push based " + + "shuffle for a stage. This is specified as a ratio of the number of partitions in " + + "the child stage. For example, a reduce stage which has 100 partitions and uses the " + + "default value 0.05 requires at least 5 unique merger locations to enable push based " + + "shuffle. Merger locations are currently defined as external shuffle services.") + .version("3.1.0") + .doubleConf + .createWithDefault(0.05) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD = + ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold") + .doc(s"The static threshold for number of shuffle push merger locations should be " + + "available in order to enable push based shuffle for a stage. Note this config " + + s"works in conjunction with ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. " + + "Maximum of spark.shuffle.push.mergersMinStaticThreshold and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} ratio number of mergers needed to " + + "enable push based shuffle for a stage. 
For eg: with 1000 partitions for the child " + + "stage with spark.shuffle.push.mergersMinStaticThreshold as 5 and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " + + "at least 50 mergers to enable push based shuffle for that stage.") + .version("3.1.0") + .doubleConf + .createWithDefault(5) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 13b766e654832..6fb0fb93f253b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -249,6 +249,8 @@ private[spark] class DAGScheduler( private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf) + /** * Called by the TaskSetManager to report task's starting. */ @@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler( execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores)) } + /** + * If push based shuffle is enabled, set the shuffle services to be used for the given + * shuffle map stage for block push/merge. + * + * Even with dynamic resource allocation kicking in and significantly reducing the number + * of available active executors, we would still be able to get sufficient shuffle service + * locations for block push/merge by getting the historical locations of past executors. + */ + private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { + // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize + // TODO changes we cannot disable shuffle merge for the retry/reuse cases + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + + logDebug("List of shuffle push merger locations " + + s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + } else { + logInfo("No available merger locations." + + s" Push-based shuffle disabled for $stage (${stage.name})") + } + } + /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") @@ -1281,6 +1310,12 @@ private[spark] class DAGScheduler( stage match { case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + // Only generate merger location for a given shuffle dependency once. This way, even if + // this stage gets retried, it would still be merging blocks using the same set of + // shuffle services. 
+ if (pushBasedShuffleEnabled) { + prepareShuffleServicesForShuffleMapStage(s) + } case s: ResultStage => outputCommitCoordinator.stageStart( stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) @@ -2027,6 +2062,11 @@ private[spark] class DAGScheduler( if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) { executorFailureEpoch(execId) = currentEpoch logInfo(s"Executor lost: $execId (epoch $currentEpoch)") + if (pushBasedShuffleEnabled) { + // Remove fetchFailed host in the shuffle push merger list for push based shuffle + hostToUnregisterOutputs.foreach( + host => blockManagerMaster.removeShufflePushMergerLocation(host)) + } blockManagerMaster.removeExecutor(execId) clearCacheLocs() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index a566d0a04387c..b2acdb3e12a6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.resource.ResourceProfile +import org.apache.spark.storage.BlockManagerId /** * A backend interface for scheduling systems that allows plugging in different ones under @@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend { */ def maxNumConcurrentTasks(rp: ResourceProfile): Int + /** + * Get the list of host locations for push based shuffle + * + * Currently push based shuffle is disabled for both stage retry and stage reuse cases + * (for eg: in the case where few partitions are lost due to failure). Hence this method + * should be invoked only once for a ShuffleDependency. + * @return List of external shuffle services locations + */ + def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 49e32d04d450a..c6a4457d8f910 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -145,4 +145,6 @@ private[spark] object BlockManagerId { def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { blockManagerIdCache.get(id) } + + private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f544d47b8e13c..fe1a5aef9499c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -125,6 +125,26 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + /** + * Get a list of unique shuffle service locations where an executor is successfully + * registered in the past for block push/merge with push based shuffle. + */ + def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + driverEndpoint.askSync[Seq[BlockManagerId]]( + GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + } + + /** + * Remove the host from the candidate list of shuffle push mergers. 
This can be + * triggered if there is a FetchFailedException on the host + * @param host + */ + def removeShufflePushMergerLocation(host: String): Unit = { + driverEndpoint.askSync[Seq[BlockManagerId]](RemoveShufflePushMergerLocation(host)) + } + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index a7532a9870fae..4d565511704d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + // Mapping from host name to shuffle (mergers) services where the current app + // registered an executor in the past. Older hosts are removed when the + // maxRetainedMergerLocations size is reached in favor of newer locations. + private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]() + + // Maximum number of merger locations to cache + private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS) + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) @@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint( val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) + logInfo("BlockManagerMasterEndpoint up") // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED) // && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)` @@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint( case GetBlockStatus(blockId, askStorageEndpoints) => context.reply(blockStatus(blockId, askStorageEndpoints)) + case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) => + context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + + case RemoveShufflePushMergerLocation(host) => + context.reply(removeShufflePushMergerLocation(host)) + case IsExecutorAlive(executorId) => context.reply(blockManagerIdByExecutor.contains(executorId)) @@ -360,6 +376,17 @@ class BlockManagerMasterEndpoint( } + private def addMergerLocation(blockManagerId: BlockManagerId): Unit = { + if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) { + val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, + blockManagerId.host, externalShuffleServicePort) + if (shuffleMergerLocations.size >= maxRetainedMergerLocations) { + shuffleMergerLocations -= shuffleMergerLocations.head._1 + } + shuffleMergerLocations(shuffleServerId.host) = shuffleServerId + } + } + private def removeExecutor(execId: String): Unit = { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) @@ -526,6 +553,10 @@ class BlockManagerMasterEndpoint( blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus) + + if (pushBasedShuffleEnabled) { + 
addMergerLocation(id) + } } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) @@ -657,6 +688,40 @@ class BlockManagerMasterEndpoint( } } + private def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet + val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_)) + val filteredMergersWithExecutors = filteredBlockManagerHosts.map( + BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort)) + // Enough mergers are available as part of active executors list + if (filteredMergersWithExecutors.size >= numMergersNeeded) { + filteredMergersWithExecutors.toSeq + } else { + // Delta mergers added from inactive mergers list to the active mergers list + val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host) + val filteredMergersWithoutExecutors = shuffleMergerLocations.values + .filterNot(x => hostsToFilter.contains(x.host)) + .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host)) + val randomFilteredMergersLocations = + if (filteredMergersWithoutExecutors.size > + numMergersNeeded - filteredMergersWithExecutors.size) { + Utils.randomize(filteredMergersWithoutExecutors) + .take(numMergersNeeded - filteredMergersWithExecutors.size) + } else { + filteredMergersWithoutExecutors + } + filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations + } + } + + private def removeShufflePushMergerLocation(host: String): Unit = { + if (shuffleMergerLocations.contains(host)) { + shuffleMergerLocations.remove(host) + } + } + /** * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index bbc076cea9ba8..afe416a55ed0d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster + + case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String]) + extends ToBlockManagerMaster + + case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b743ab6507117..6ccf65b737c1a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2541,6 +2541,14 @@ private[spark] object Utils extends Logging { master == "local" || master.startsWith("local[") } + /** + * Push based shuffle can only be enabled when external shuffle service is enabled. + */ + def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = { + conf.get(PUSH_BASED_SHUFFLE_ENABLED) && + (conf.get(IS_TESTING).getOrElse(false) || conf.get(SHUFFLE_SERVICE_ENABLED)) + } + /** * Return whether dynamic allocation is enabled in the given conf. 
*/ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 55280fc578310..144489c5f7922 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L) .set(Network.RPC_ASK_TIMEOUT, "5s") + .set(PUSH_BASED_SHUFFLE_ENABLED, true) } private def makeSortShuffleManager(): SortShuffleManager = { @@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("SPARK-32919: Shuffle push merger locations should be bounded with in" + + " spark.shuffle.push.retainedMergerLocations") { + assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty) + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4) + assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted === + Seq("hostC", "hostD", "hostA", "hostB").sorted) + assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3) + } + + test("SPARK-32919: Prefer active executor locations for shuffle push mergers") { + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4) + + master.removeExecutor("execA") + master.removeExecutor("execE") + + assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3) + assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted === + Seq("hostC", "hostB", "hostD").sorted) + assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted === + Seq("hostB", "hostA", "hostC", "hostD").sorted) + } + test("SPARK-33387 Support ordered shuffle block migration") { val blocks: Seq[ShuffleBlockInfo] = Seq( ShuffleBlockInfo(1, 0L), @@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1))) } - class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { + class MockBlockTransferService( + val maxFailures: Int, + override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService { var numCalls = 0 var tempFileManager: DownloadFileManager = null @@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with 
Matchers with BeforeAndAfterE override def close(): Unit = {} - override def hostName: String = { "MockBlockTransferServiceHost" } - override def port: Int = { 63332 } override def uploadBlock( diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 20624c743bc22..8fb408041ca9d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener import org.apache.spark.util.io.ChunkedByteBufferInputStream @@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { }.getMessage assert(message.contains(expected)) } + + test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" + + " and SHUFFLE_SERVICE_ENABLED are true") { + val conf = new SparkConf() + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, false) + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(SHUFFLE_SERVICE_ENABLED, true) + assert(Utils.isPushBasedShuffleEnabled(conf) === true) + } } private class SimpleExtension diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index b42bdb9816600..22002bb32004d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.EnumSet -import java.util.concurrent.atomic.{AtomicBoolean} +import java.util.concurrent.atomic.AtomicBoolean import javax.servlet.DispatcherType import scala.concurrent.{ExecutionContext, Future} @@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.UI._ import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{RpcUtils, ThreadUtils} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Abstract Yarn scheduler backend that contains common logic @@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend( /** Attempt ID. 
This is unset for client-mode schedulers */ private var attemptId: Option[ApplicationAttemptId] = None + private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master + + private val minMergersThresholdRatio = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO) + + private val minMergersStaticThreshold = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD) + + private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS) + + private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0) + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + override def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = { + // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage, + // TODO we can use better heuristics to calculate numMergersDesired for a stage. + val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) { + maxNumExecutors + } else { + numExecutors + } + val tasksPerExecutor = sc.resourceProfileManager + .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf) + val numMergersDesired = math.min( + math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors) + val minMergersNeeded = math.max(minMergersStaticThreshold, + math.floor(numMergersDesired * minMergersThresholdRatio).toInt) + + // Request for numMergersDesired shuffle mergers to BlockManagerMasterEndpoint + // and if it's less than minMergersNeeded, we disable push based shuffle. + val mergerLocations = blockManagerMaster + .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes()) + if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) { + Seq.empty[BlockManagerId] + } else { + logDebug(s"The number of shuffle mergers desired ${numMergersDesired}" + + s" and available locations are ${mergerLocations.length}") + mergerLocations + } + } + /** * Add filters to the SparkUI. */ From 2289389821a23e5b5badabfb4e62c427de2554a5 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 20 Nov 2020 21:27:41 +0900 Subject: [PATCH 05/38] [SPARK-33441][BUILD][FOLLOWUP] Make unused-imports check for SBT specific ### What changes were proposed in this pull request? Move "unused-imports" check config to `SparkBuild.scala` and make it SBT specific. ### Why are the changes needed? Make unused-imports check for SBT specific. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass the Jenkins or GitHub Action Closes #30441 from LuciferYang/SPARK-33441-FOLLOWUP. 
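The general shape of an SBT-only, Scala-version-dependent warning flag is shown below; this is a generic `build.sbt` sketch of the pattern, not the literal `SparkBuild.scala` change that follows.

```scala
// build.sbt sketch: add the unused-import warning only from the SBT build and pick the
// flag spelling that matches the Scala version being compiled.
scalacOptions ++= {
  CrossVersion.partialVersion(scalaVersion.value) match {
    case Some((2, 12)) => Seq("-Ywarn-unused-import")
    case Some((2, 13)) => Seq("-Wconf:cat=unused-imports:e")
    case _             => Seq.empty
  }
}
```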
Authored-by: yangjie01 Signed-off-by: HyukjinKwon --- pom.xml | 5 +---- project/SparkBuild.scala | 3 +++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 3ae2e7420e154..85cf5a00b0b24 100644 --- a/pom.xml +++ b/pom.xml @@ -164,7 +164,6 @@ 3.2.2 2.12.10 2.12 - -Ywarn-unused-import 2.0.0 --test @@ -2538,7 +2537,6 @@ -deprecation -feature -explaintypes - ${scalac.arg.unused-imports} -target:jvm-1.8 @@ -3262,13 +3260,12 @@ - + scala-2.13 2.13.3 2.13 - -Wconf:cat=unused-imports:e diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 55c87fcb3aaa2..05413b7091ad9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,6 +221,7 @@ object SparkBuild extends PomBuild { Seq( "-Xfatal-warnings", "-deprecation", + "-Ywarn-unused-import", "-P:silencer:globalFilters=.*deprecated.*" //regex to catch deprecation warnings and supress them ) } else { @@ -230,6 +231,8 @@ object SparkBuild extends PomBuild { // see `scalac -Wconf:help` for details "-Wconf:cat=deprecation:wv,any:e", // 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately + // TODO(SPARK-33499): Enable this option when Scala 2.12 is no longer supported. + // "-Wunused:imports", "-Wconf:cat=lint-multiarg-infix:wv", "-Wconf:cat=other-nullary-override:wv", "-Wconf:cat=other-match-analysis&site=org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction.catalogFunction:wv", From 870d4095336f29f5bef77b9232d6cb9d025987dd Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 20 Nov 2020 12:53:45 +0000 Subject: [PATCH 06/38] [SPARK-32512][SQL][TESTS][FOLLOWUP] Remove duplicate tests for ALTER TABLE .. PARTITIONS from DataSourceV2SQLSuite ### What changes were proposed in this pull request? Remove tests from `DataSourceV2SQLSuite` that were copied to `AlterTablePartitionV2SQLSuite` by https://github.com/apache/spark/pull/29339. ### Why are the changes needed? - To reduce tests execution time - To improve test maintenance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By running the modified tests: ``` $ build/sbt "test:testOnly *DataSourceV2SQLSuite" $ build/sbt "test:testOnly *AlterTablePartitionV2SQLSuite" ``` Closes #30444 from MaxGekk/dedup-tests-AlterTablePartitionV2SQLSuite. 
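Returning briefly to PATCH 04: the merger-count heuristic added to `YarnSchedulerBackend.getShufflePushMergerLocations` there is easier to follow with concrete numbers. The sketch below is standalone and every input is an assumed value.

```scala
object MergerHeuristicSketch {
  def main(args: Array[String]): Unit = {
    // Assumed inputs, for illustration only.
    val numPartitions = 1000             // reduce-side partitions of the shuffle map stage
    val tasksPerExecutor = 4             // derived from the stage's ResourceProfile
    val maxExecutors = 200               // spark.dynamicAllocation.maxExecutors
    val minMergersThresholdRatio = 0.05  // spark.shuffle.push.mergersMinThresholdRatio default
    val minMergersStaticThreshold = 5    // spark.shuffle.push.mergersMinStaticThreshold default

    val numMergersDesired = math.min(
      math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors)  // 200
    val minMergersNeeded = math.max(minMergersStaticThreshold,
      math.floor(numMergersDesired * minMergersThresholdRatio).toInt)                // 10

    // The backend asks the BlockManagerMaster for numMergersDesired locations and disables
    // push-based shuffle only when it gets back fewer than both of these two numbers.
    println(s"desired=$numMergersDesired, minimum=$minMergersNeeded")
  }
}
```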
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../sql/connector/DataSourceV2SQLSuite.scala | 53 ------------------- 1 file changed, 53 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ddafa1bb5070a..0057415ff6e1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkException import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connector.catalog._ @@ -43,7 +42,6 @@ class DataSourceV2SQLSuite with AlterTableTests with DatasourceV2SQLBase { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ private val v2Source = classOf[FakeV2Provider].getName override protected val v2Format = v2Source @@ -1980,57 +1978,6 @@ class DataSourceV2SQLSuite } } - test("ALTER TABLE RECOVER PARTITIONS") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t RECOVER PARTITIONS") - } - assert(e.message.contains("ALTER TABLE RECOVER PARTITIONS is only supported with v1 tables")) - } - } - - test("ALTER TABLE ADD PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - - val partTable = catalog("testpart").asTableCatalog - .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] - assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) - - val partMetadata = partTable.loadPartitionMetadata(InternalRow.fromSeq(Seq(1))) - assert(partMetadata.containsKey("location")) - assert(partMetadata.get("location") == "loc") - } - } - - test("ALTER TABLE RENAME PARTITION") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t PARTITION (id=1) RENAME TO PARTITION (id=2)") - } - assert(e.message.contains("ALTER TABLE RENAME PARTITION is only supported with v1 tables")) - } - } - - test("ALTER TABLE DROP PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - spark.sql(s"ALTER TABLE $t DROP PARTITION (id=1)") - - val partTable = - catalog("testpart").asTableCatalog.loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) - assert(!partTable.asPartitionable.partitionExists(InternalRow.fromSeq(Seq(1)))) - } - } - test("ALTER TABLE SerDe properties") { val t = "testcat.ns1.ns2.tbl" withTable(t) { From cbc8be24c896ed25be63ef9a111ff015af4fabec Mon Sep 17 00:00:00 2001 From: liucht Date: Fri, 20 Nov 2020 22:19:35 +0900 Subject: [PATCH 07/38] 
[SPARK-33422][DOC] Fix the correct display of left menu item ### What changes were proposed in this pull request? Limit the height of the menu area on the left to display vertical scroll bar ### Why are the changes needed? The bottom menu item cannot be displayed when the left menu tree is long ### Does this PR introduce any user-facing change? Yes, if the menu item shows more, you'll see it by pulling down the vertical scroll bar before: ![image](https://user-images.githubusercontent.com/28332082/98805115-16995d80-2452-11eb-933a-3b72c14bea78.png) after: ![image](https://user-images.githubusercontent.com/28332082/98805418-7e4fa880-2452-11eb-9a9b-8d265078297c.png) ### How was this patch tested? NA Closes #30335 from liucht-inspur/master. Authored-by: liucht Signed-off-by: HyukjinKwon --- docs/css/main.css | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/css/main.css b/docs/css/main.css index 8168a46f9a437..8b279a157c2b6 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -162,6 +162,7 @@ body .container-wrapper { margin-right: auto; border-radius: 15px; position: relative; + min-height: 100vh; } .title { @@ -264,6 +265,7 @@ a:hover code { max-width: 914px; line-height: 1.6; /* Inspired by Github's wiki style */ padding-left: 30px; + min-height: 100vh; } .dropdown-menu { @@ -325,6 +327,7 @@ a.anchorjs-link:hover { text-decoration: none; } border-bottom-width: 0px; margin-top: 0px; width: 210px; + height: 80%; float: left; position: fixed; overflow-y: scroll; From 3384bda453d0e728be311ce458e00d70d2484973 Mon Sep 17 00:00:00 2001 From: ulysses Date: Fri, 20 Nov 2020 13:23:08 +0000 Subject: [PATCH 08/38] [SPARK-33468][SQL] ParseUrl in ANSI mode should fail if input string is not a valid url ### What changes were proposed in this pull request? With `ParseUrl`, instead of return null we throw exception if input string is not a vaild url. ### Why are the changes needed? For ANSI mode. ### Does this PR introduce _any_ user-facing change? Yes, user will get exception if `set spark.sql.ansi.enabled=true`. ### How was this patch tested? Add test. Closes #30399 from ulysses-you/SPARK-33468. Lead-authored-by: ulysses Co-authored-by: ulysses-you Signed-off-by: Wenchen Fan --- docs/sql-ref-ansi-compliance.md | 1 + .../catalyst/expressions/stringExpressions.scala | 7 +++++-- .../expressions/StringExpressionsSuite.scala | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fd7208615a09f..870ed0aa0daaa 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -135,6 +135,7 @@ The behavior of some SQL functions can be different under ANSI mode (`spark.sql. - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - `element_at`: This function throws `NoSuchElementException` if key does not exist in map. - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - `parse_url`: This function throws `IllegalArgumentException` if an input string is not a valid url. 
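A hedged usage sketch of the behavior documented above; it assumes an active `SparkSession` named `spark`, and the `|` character is what makes the URL an invalid `java.net.URI` (the same trick the new test relies on).

```scala
// With ANSI mode off (the default), an unparseable URL yields NULL.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT parse_url('https://a.b.c/index.php?params1=a|b&params2=x', 'HOST')").show()
// -> a single NULL value

// With ANSI mode on, the same query fails with an IllegalArgumentException that wraps
// the underlying URISyntaxException.
spark.conf.set("spark.sql.ansi.enabled", "true")
// spark.sql("SELECT parse_url('https://a.b.c/index.php?params1=a|b&params2=x', 'HOST')").show()
```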
### SQL Operators diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 16e22940495f1..9f92181b34df1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1357,8 +1357,9 @@ object ParseUrl { 1 """, since = "2.0.0") -case class ParseUrl(children: Seq[Expression]) +case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression with ExpectsInputTypes with CodegenFallback { + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) @@ -1404,7 +1405,9 @@ case class ParseUrl(children: Seq[Expression]) try { new URI(url.toString) } catch { - case e: URISyntaxException => null + case e: URISyntaxException if failOnError => + throw new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) + case _: URISyntaxException => null } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index a1b6cec24f23f..730574a4b9846 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -943,6 +943,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) } + test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val msg = intercept[IllegalArgumentException] { + evaluateWithoutCodegen( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST"))) + }.getMessage + assert(msg.contains("Find an invaild url string")) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST")), null) + } + } + test("Sentences") { val nullString = Literal.create(null, StringType) checkEvaluation(Sentences(nullString, nullString, nullString), null) From 47326ac1c6a296a84af76d832061741740ae9f12 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 20 Nov 2020 08:40:14 -0800 Subject: [PATCH 09/38] [SPARK-28704][SQL][TEST] Add back Skiped HiveExternalCatalogVersionsSuite in HiveSparkSubmitSuite at JDK9+ ### What changes were proposed in this pull request? We skip test HiveExternalCatalogVersionsSuite when testing with JAVA_9 or later because our previous version does not support JAVA_9 or later. We now add it back since we have a version supports JAVA_9 or later. ### Why are the changes needed? To recover test coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Check CI logs. Closes #30428 from AngersZhuuuu/SPARK-28704. 
Authored-by: angerszhu Signed-off-by: Dongjoon Hyun --- .../HiveExternalCatalogVersionsSuite.scala | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 38a8c492d77a7..4cafd3e8ca626 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -52,7 +52,6 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { import HiveExternalCatalogVersionsSuite._ - private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9) private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `spark.test.cache-dir` to a static value like `/tmp/test-spark`, to @@ -149,7 +148,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) } - private def prepare(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() + val tempPyFile = File.createTempFile("test", ".py") // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, @@ -199,7 +200,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.7", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--conf", s"spark.sql.test.version.index=$index", @@ -211,23 +212,14 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { tempPyFile.delete() } - override def beforeAll(): Unit = { - super.beforeAll() - if (!isTestAtLeastJava9) { - prepare() - } - } - test("backward compatibility") { - // TODO SPARK-28704 Test backward compatibility on JDK9+ once we have a version supports JDK9+ - assume(!isTestAtLeastJava9) val args = Seq( "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), "--name", "HiveExternalCatalog backward compatibility test", "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.7", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", @@ -252,7 +244,9 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // do not throw exception during object initialization. 
case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version } - versions.filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + versions + .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + .filter(v => v.startsWith("3") || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) } protected var spark: SparkSession = _ From 116b7b72a1980a0768413329f28591f772822827 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 20 Nov 2020 11:35:34 -0600 Subject: [PATCH 10/38] [SPARK-33466][ML][PYTHON] Imputer support mode(most_frequent) strategy ### What changes were proposed in this pull request? impl a new strategy `mode`: replace missing using the most frequent value along each column. ### Why are the changes needed? it is highly scalable, and had been a function in [sklearn.impute.SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html#sklearn.impute.SimpleImputer) for a long time. ### Does this PR introduce _any_ user-facing change? Yes, a new strategy is added ### How was this patch tested? updated testsuites Closes #30397 from zhengruifeng/imputer_max_freq. Lead-authored-by: Ruifeng Zheng Co-authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../org/apache/spark/ml/feature/Imputer.scala | 49 ++-- .../spark/ml/feature/ImputerSuite.scala | 211 ++++++++++-------- python/pyspark/ml/feature.py | 5 +- 3 files changed, 144 insertions(+), 121 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ad1010da5c104..03ebe0299f63f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * The imputation strategy. Currently only "mean" and "median" are supported. * If "mean", then replace missing values using the mean value of the feature. * If "median", then replace missing values using the approximate median value of the feature. + * If "mode", then replace missing using the most frequent value of the feature. * Default: mean * * @group param */ final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + - s"If ${Imputer.median}, then replace missing values using the median value of the feature.", - ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + s"If ${Imputer.median}, then replace missing values using the median value of the feature. " + + s"If ${Imputer.mode}, then replace missing values using the most frequent value of " + + s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies)) /** @group getParam */ def getStrategy: String = $(strategy) @@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * For example, if the input column is IntegerType (1, 2, 4, null), * the output will be IntegerType (1, 2, 4, 2) after mean imputation. * - * Note that the mean/median value is computed after filtering out missing values. + * Note that the mean/median/mode value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. 
For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) def setOutputCols(value: Array[String]): this.type = set(outputCols, value) /** - * Imputation strategy. Available options are ["mean", "median"]. + * Imputation strategy. Available options are ["mean", "median", "mode"]. * @group setParam */ @Since("2.2.0") @@ -151,39 +153,42 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) val spark = dataset.sparkSession val (inputColumns, _) = getInOutCols() - val cols = inputColumns.map { inputCol => when(col(inputCol).equalTo($(missingValue)), null) .when(col(inputCol).isNaN, null) .otherwise(col(inputCol)) - .cast("double") + .cast(DoubleType) .as(inputCol) } + val numCols = cols.length val results = $(strategy) match { case Imputer.mean => // Function avg will ignore null automatically. // For a column only containing null, avg will return null. val row = dataset.select(cols.map(avg): _*).head() - Array.range(0, inputColumns.length).map { i => - if (row.isNullAt(i)) { - Double.NaN - } else { - row.getDouble(i) - } - } + Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i)) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) - .map { array => - if (array.isEmpty) { - Double.NaN - } else { - array.head - } - } + .map(_.headOption.getOrElse(Double.NaN)) + + case Imputer.mode => + import spark.implicits._ + // If there is more than one mode, choose the smallest one to keep in line + // with sklearn.impute.SimpleImputer (using scipy.stats.mode). + val modes = dataset.select(cols: _*).flatMap { row => + // Ignore null. + Iterator.range(0, numCols) + .flatMap(i => if (row.isNullAt(i)) None else Some((i, row.getDouble(i)))) + }.toDF("index", "value") + .groupBy("index", "value").agg(negate(count(lit(0))).as("negative_count")) + .groupBy("index").agg(min(struct("negative_count", "value")).as("mode")) + .select("index", "mode.value") + .as[(Int, Double)].collect().toMap + Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN)) } val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1) @@ -212,6 +217,10 @@ object Imputer extends DefaultParamsReadable[Imputer] { /** strategy names that Imputer currently supports. 
*/ private[feature] val mean = "mean" private[feature] val median = "median" + private[feature] val mode = "mode" + + /* Set of strategies that Imputer supports */ + private[feature] val supportedStrategies = Array(mean, median, mode) @Since("2.2.0") override def load(path: String): Imputer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index dfee2b4029c8b..30887f55638f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -28,13 +28,14 @@ import org.apache.spark.sql.types._ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), - (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), - (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), - (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) - )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", - "expected_mean_value2", "expected_median_value2") + val df = spark.createDataFrame(Seq( + (0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 3.0, 10.0, 12.0, 4.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 1.0, 14.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1", + "expected_mean_value2", "expected_median_value2", "expected_mode_value2") val imputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("out1", "out2")) @@ -42,23 +43,25 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for Double with default missing Value NaN") { - val df1 = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 11.0, 11.0, 11.0), - (2, 3.0, 3.0, 3.0), - (3, Double.NaN, 5.0, 3.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df1 = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 11.0, 11.0, 11.0, 11.0), + (2, 3.0, 3.0, 3.0, 3.0), + (3, Double.NaN, 5.0, 3.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer1 = new Imputer() .setInputCol("value") .setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer1, df1) - val df2 = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 12.0, 12.0, 12.0), - (2, Double.NaN, 10.0, 12.0), - (3, 14.0, 14.0, 14.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df2 = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 12.0, 12.0, 12.0, 12.0), + (2, Double.NaN, 10.0, 12.0, 4.0), + (3, 14.0, 14.0, 14.0, 14.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer2 = new Imputer() .setInputCol("value") .setOutputCol("out") @@ -66,12 +69,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, 
Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(true, imputer, df) @@ -79,64 +83,69 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Single Column: Imputer should handle NaNs when computing surrogate value," + " if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer for Float with missing Value -1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer for Float with missing Value -1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer should impute null as well as 'missingValue'") { - val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer should impute null as well as 'missingValue'") { 
- val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer, df) @@ -187,7 +196,7 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, Double.NaN, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) @@ -205,12 +214,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( - (0, Double.NaN, 1.0, 1.0), - (1, Double.NaN, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") - Seq("mean", "median").foreach { strategy => + val df = spark.createDataFrame(Seq( + (0, Double.NaN, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") + Seq("mean", "median", "mode").foreach { strategy => val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setStrategy(strategy) withClue("Imputer should fail all the values are invalid") { @@ -223,12 +233,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer input & output column validation") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) )).toDF("id", "value1", "value2", "value3") - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => withClue("Imputer should fail if inputCols and outputCols are different length") { val e: IllegalArgumentException = intercept[IllegalArgumentException] { val imputer = new Imputer().setStrategy(strategy) @@ -306,13 +316,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with default missing value null") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -327,12 +337,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column Imputer for IntegerType with default missing value null") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - 
)).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -347,13 +358,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with missing value -1") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -369,12 +380,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for IntegerType with missing value -1") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -402,13 +414,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Compare single/multiple column(s) Imputer in pipeline") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 4.0), (1, 11.0, 12.0), (2, 3.0, Double.NaN), (3, Double.NaN, 14.0) )).toDF("id", "value1", "value2") - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => val multiColsImputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("result1", "result2")) @@ -450,11 +462,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { object ImputerSuite { /** - * Imputation strategy. Available options are ["mean", "median"]. - * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + * Imputation strategy. Available options are ["mean", "median", "mode"]. + * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median", + * "expected_mode". */ def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = { - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) val resultDF = model.transform(df) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4d898bd5fffa8..82b9a6db1eb92 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1507,7 +1507,8 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has strategy = Param(Params._dummy(), "strategy", "strategy for imputation. If mean, then replace missing values using the mean " "value of the feature. If median, then replace missing values using the " - "median value of the feature.", + "median value of the feature. 
If mode, then replace missing using the most " + "frequent value of the feature.", typeConverter=TypeConverters.toString) missingValue = Param(Params._dummy(), "missingValue", @@ -1541,7 +1542,7 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable): numeric type. Currently Imputer does not support categorical features and possibly creates incorrect values for a categorical feature. - Note that the mean/median value is computed after filtering out missing values. + Note that the mean/median/mode value is computed after filtering out missing values. All Null values in the input columns are treated as missing, and so are also imputed. For computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a relative error of `0.001`. From a1a3d5cb02e380156eab320bf6cf512c01b11284 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 20 Nov 2020 10:14:37 -0800 Subject: [PATCH 11/38] [MINOR][TESTS][DOCS] Use fully-qualified class name in docker integration test ### What changes were proposed in this pull request? change ``` ./build/sbt -Pdocker-integration-tests "testOnly *xxxIntegrationSuite" ``` to ``` ./build/sbt -Pdocker-integration-tests "testOnly org.apache.spark.sql.jdbc.xxxIntegrationSuite" ``` ### Why are the changes needed? We only want to start v1 ```xxxIntegrationSuite```, not the newly added```v2.xxxIntegrationSuite```. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually checked Closes #30448 from huaxingao/dockertest. Authored-by: Huaxin Gao Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala | 3 ++- .../apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala | 3 ++- .../org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala | 3 ++- .../org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 4b9acd0d39f3f..d086c8cdcc589 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.4.0): * {{{ * DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.4.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *DB2IntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.DB2IntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index f1ffc8f0f3dc7..939a07238934b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., 2019-GA-ubuntu-16.04): * {{{ * MSSQLSERVER_DOCKER_IMAGE_NAME=2019-GA-ubuntu-16.04 - * ./build/sbt -Pdocker-integration-tests "testOnly *MsSqlServerIntegrationSuite" + * 
./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 6f96ab33d0fee..68f0dbc057c1f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., mysql:5.7.31): * {{{ * MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31 - * ./build/sbt -Pdocker-integration-tests "testOnly *MySQLIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index fa13100b5fdc8..0347c98bba2c4 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., postgres:13.0): * {{{ * POSTGRES_DOCKER_IMAGE_NAME=postgres:13.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} */ @DockerTest From 247977893473f810ffbcda31ee2710e445120e42 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 20 Nov 2020 14:59:56 -0800 Subject: [PATCH 12/38] [SPARK-33492][SQL] DSv2: Append/Overwrite/ReplaceTable should invalidate cache ### What changes were proposed in this pull request? This adds changes in the following places: - logic to also refresh caches referencing the target table in v2 `AppendDataExec`, `OverwriteByExpressionExec`, `OverwritePartitionsDynamicExec`, as well as their v1 fallbacks `AppendDataExecV1` and `OverwriteByExpressionExecV1`. - logic to invalidate caches referencing the target table in v2 `ReplaceTableAsSelectExec` and its atomic version `AtomicReplaceTableAsSelectExec`. These are only supported in v2 at the moment though. In addition to the above, in order to test the v1 write fallback behavior, I extended `InMemoryTableWithV1Fallback` to also support batch reads. ### Why are the changes needed? Currently in DataSource v2 we don't refresh or invalidate caches referencing the target table when the table content is changed by operations such as append, overwrite, or replace table. This is different from DataSource v1, and could potentially cause data correctness issue if the staled caches are queried later. ### Does this PR introduce _any_ user-facing change? Yes. Now When a data source v2 is cached (either directly or indirectly), all the relevant caches will be refreshed or invalidated if the table is replaced. ### How was this patch tested? Added unit tests for the new code path. Closes #30429 from sunchao/SPARK-33492. 
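A minimal sketch of the intended behavior, assuming a session where `testcat` is registered as an
in-memory DSv2 catalog (as in the suite setup); the table, view, and provider names below simply
mirror the added tests and are illustrative rather than part of the patch:

```
CREATE TABLE testcat.ns.t (id bigint, data string) USING foo;
CACHE TABLE v AS SELECT id FROM testcat.ns.t;

INSERT INTO testcat.ns.t VALUES (1, 'a');

-- With this change the cached plan over testcat.ns.t is recached after the append commits,
-- so this query reflects the newly written row instead of returning stale cached data.
SELECT * FROM v;
```

The new tests in `DataSourceV2SQLSuite` and `V1WriteFallbackSuite` cover the corresponding append,
overwrite-by-expression, dynamic partition overwrite, and replace-table paths.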
Authored-by: Chao Sun Signed-off-by: Dongjoon Hyun --- .../datasources/v2/DataSourceV2Strategy.scala | 13 +-- .../datasources/v2/V1FallbackWriters.scala | 21 +++-- .../v2/WriteToDataSourceV2Exec.scala | 38 ++++++++- .../sql/connector/DataSourceV2SQLSuite.scala | 78 ++++++++++++++++++ .../sql/connector/V1WriteFallbackSuite.scala | 79 ++++++++++++++++++- 5 files changed, 212 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 21abfc2816ee4..e5c29312b80e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,6 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat catalog match { case staging: StagingTableCatalog => AtomicReplaceTableAsSelectExec( + session, staging, ident, parts, @@ -157,6 +158,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat orCreate = orCreate) :: Nil case _ => ReplaceTableAsSelectExec( + session, catalog, ident, parts, @@ -170,9 +172,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil case v2 => - AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil + AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil } case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => @@ -184,14 +186,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil + OverwriteByExpressionExec(session, v2, r, filters, + writeOptions.asOptions, planLater(query)) :: Nil } case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => OverwritePartitionsDynamicExec( - r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil + session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(relation, condition) => relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 560da39314b36..af7721588edeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -37,10 +37,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AppendDataExecV1( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends 
V1FallbackWriters { override protected def run(): Seq[InternalRow] = { - writeWithV1(newWriteBuilder().buildForV1Write()) + writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation)) } } @@ -59,7 +60,8 @@ case class OverwriteByExpressionExecV1( table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] @@ -68,10 +70,10 @@ case class OverwriteByExpressionExecV1( override protected def run(): Seq[InternalRow] = { newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => - writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) + writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation)) case builder: SupportsOverwrite => - writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) + writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation)) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") @@ -112,9 +114,14 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { trait SupportsV1Write extends SparkPlan { def plan: LogicalPlan - protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = { + protected def writeWithV1( + relation: InsertableRelation, + v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = { + val session = sqlContext.sparkSession // The `plan` is already optimized, we should not analyze and optimize it again. - relation.insert(AlreadyOptimized.dataFrame(sqlContext.sparkSession, plan), overwrite = false) + relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false) + v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r)) + Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1421a9315c3a8..1648134d0a1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -127,6 +128,7 @@ case class AtomicCreateTableAsSelectExec( * ReplaceTableAsSelectStagingExec. */ case class ReplaceTableAsSelectExec( + session: SparkSession, catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -146,6 +148,8 @@ case class ReplaceTableAsSelectExec( // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. 
if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) catalog.dropTable(ident) } else if (!orCreate) { throw new CannotReplaceMissingTableException(ident) @@ -169,6 +173,7 @@ case class ReplaceTableAsSelectExec( * is left untouched. */ case class AtomicReplaceTableAsSelectExec( + session: SparkSession, catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -180,6 +185,10 @@ case class AtomicReplaceTableAsSelectExec( override protected def run(): Seq[InternalRow] = { val schema = query.schema.asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) + } val staged = if (orCreate) { catalog.stageCreateOrReplace( ident, schema, partitioning.toArray, properties.asJava) @@ -204,12 +213,16 @@ case class AtomicReplaceTableAsSelectExec( * Rows in the output data set are appended. */ case class AppendDataExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - writeWithV2(newWriteBuilder().buildForBatch()) + val writtenRows = writeWithV2(newWriteBuilder().buildForBatch()) + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -224,7 +237,9 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -234,7 +249,7 @@ case class OverwriteByExpressionExec( } override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => writeWithV2(builder.truncate().buildForBatch()) @@ -244,9 +259,12 @@ case class OverwriteByExpressionExec( case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } + /** * Physical plan node for dynamic partition overwrite into a v2 table. * @@ -257,18 +275,22 @@ case class OverwriteByExpressionExec( * are not modified. 
*/ case class OverwritePartitionsDynamicExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsDynamicOverwrite => writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -370,6 +392,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { Nil } + + protected def uncacheTable( + session: SparkSession, + catalog: TableCatalog, + table: Table, + ident: Identifier): Unit = { + val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true) + } } object DataWritingSparkTask extends Logging { @@ -484,3 +515,4 @@ private[v2] case class DataWritingSparkTaskResult( * Sink progress information collected after commit. */ private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 0057415ff6e1d..0e7aec8d80e01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -780,6 +780,84 @@ class DataSourceV2SQLSuite } } + test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") { + Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t => + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"REPLACE TABLE $t USING foo AS SELECT id FROM source") + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty) + } + } + } + } + + test("SPARK-33492: AppendData should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a")).toDF("i", "j").write.saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t) + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil) + } + } + } + + test("SPARK-33492: OverwriteByExpression should refresh cache") { + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") + + 
assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + } + } + } + + test("SPARK-33492: OverwritePartitionsDynamic should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a", 1)).toDF("i", "j", "k").write.partitionBy("k") saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions() + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil) + } + } + } + test("Relation: basic") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 4b52a4cbf4116..cba7dd35fb3bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -24,14 +24,17 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ @@ -145,6 +148,52 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before SparkSession.setDefaultSession(spark) } } + + test("SPARK-33492: append fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").append() + checkAnswer(session.read.table("test"), Row(1, "x") :: Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + 
SparkSession.setDefaultSession(spark) + } + } + + test("SPARK-33492: overwrite fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").overwrite(lit(true)) + checkAnswer(session.read.table("test"), Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } } class V1WriteFallbackSessionCatalogSuite @@ -177,6 +226,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) + tables.put(Identifier.of(Array("default"), name), t) t } } @@ -272,7 +322,7 @@ class InMemoryTableWithV1Fallback( override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table - with SupportsWrite { + with SupportsWrite with SupportsRead { partitioning.foreach { t => if (!t.isInstanceOf[IdentityTransform]) { @@ -281,6 +331,7 @@ class InMemoryTableWithV1Fallback( } override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, TableCapability.TRUNCATE).asJava @@ -338,6 +389,30 @@ class InMemoryTableWithV1Fallback( } } } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new V1ReadFallbackScanBuilder(schema) + + private class V1ReadFallbackScanBuilder(schema: StructType) extends ScanBuilder { + override def build(): Scan = new V1ReadFallbackScan(schema) + } + + private class V1ReadFallbackScan(schema: StructType) extends V1Scan { + override def readSchema(): StructType = schema + override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = + new V1TableScan(context, schema).asInstanceOf[T] + } + + private class V1TableScan( + context: SQLContext, + requiredSchema: StructType) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = context + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect() + context.sparkContext.makeRDD(data) + } + } } /** A rule that fails if a query plan is analyzed twice. */ From de0f50abf407ec972c6a80ae80853a66b24468f4 Mon Sep 17 00:00:00 2001 From: anchovYu Date: Sat, 21 Nov 2020 08:33:39 +0900 Subject: [PATCH 13/38] [SPARK-32670][SQL] Group exception messages in Catalyst Analyzer in one file ### What changes were proposed in this pull request? Group all messages of `AnalysisExcpetions` created and thrown directly in org.apache.spark.sql.catalyst.analysis.Analyzer in one file. * Create a new object: `org.apache.spark.sql.CatalystErrors` with many exception-creating functions. * When the `Analyzer` wants to create and throw a new `AnalysisException`, call functions of `CatalystErrors` ### Why are the changes needed? 
This is the sample PR that groups exception messages together in several files. It will largely help with standardization of error messages and its maintenance. ### Does this PR introduce _any_ user-facing change? No. Error messages remain unchanged. ### How was this patch tested? No new tests - pass all original tests to make sure it doesn't break any existing behavior. ### Naming of exception functions All function names ended with `Error`. * For specific errors like `groupingIDMismatch` and `groupingColInvalid`, directly use them as name, just like `groupingIDMismatchError` and `groupingColInvalidError`. * For generic errors like `dataTypeMismatch`, * if confident with the context, prefix and condition can be added, like `pivotValDataTypeMismatchError` * if not sure about the context, add a `For` suffix of the specific component that this exception is related to, like `dataTypeMismatchForDeserializerError` Closes #29497 from anchovYu/32670. Lead-authored-by: anchovYu Co-authored-by: anchovYu Signed-off-by: HyukjinKwon --- .../spark/sql/QueryCompilationErrors.scala | 164 ++++++++++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 84 +++------ 2 files changed, 192 insertions(+), 56 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala new file mode 100644 index 0000000000000..c680502cb328f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, GroupingID} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AbstractDataType, DataType, StructType} + +/** + * Object for grouping all error messages of the query compilation. + * Currently it includes all AnalysisExcpetions created and thrown directly in + * org.apache.spark.sql.catalyst.analysis.Analyzer. 
+ */ +object QueryCompilationErrors { + def groupingIDMismatchError(groupingID: GroupingID, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Columns of grouping_id (${groupingID.groupByExprs.mkString(",")}) " + + s"does not match grouping columns (${groupByExprs.mkString(",")})") + } + + def groupingColInvalidError(groupingCol: Expression, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Column of grouping ($groupingCol) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + + def groupingSizeTooLargeError(sizeLimit: Int): Throwable = { + new AnalysisException( + s"Grouping sets size cannot be greater than $sizeLimit") + } + + def unorderablePivotColError(pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot column '$pivotCol'. Pivot columns must be comparable." + ) + } + + def nonLiteralPivotValError(pivotVal: Expression): Throwable = { + new AnalysisException( + s"Literal expressions required for pivot values, found '$pivotVal'") + } + + def pivotValDataTypeMismatchError(pivotVal: Expression, pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot value '$pivotVal': " + + s"value data type ${pivotVal.dataType.simpleString} does not match " + + s"pivot column data type ${pivotCol.dataType.catalogString}") + } + + def unsupportedIfNotExistsError(tableName: String): Throwable = { + new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: $tableName") + } + + def nonPartitionColError(partitionName: String): Throwable = { + new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + + def addStaticValToUnknownColError(staticName: String): Throwable = { + new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + } + + def unknownStaticPartitionColError(name: String): Throwable = { + new AnalysisException(s"Unknown static partition column: $name") + } + + def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = { + new AnalysisException( + "Generators are not supported when it's nested in " + + "expressions, but got: " + toPrettySQL(trimmedNestedGenerator)) + } + + def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + new AnalysisException( + s"Only one generator allowed per $clause clause but found " + + generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + } + + def generatorOutsideSelectError(plan: LogicalPlan): Throwable = { + new AnalysisException( + "Generators are not supported outside the SELECT clause, but " + + "got: " + plan.simpleString(SQLConf.get.maxToStringFields)) + } + + def legacyStoreAssignmentPolicyError(): Throwable = { + val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key + new AnalysisException( + "LEGACY store assignment policy is disallowed in Spark data source V2. " + + s"Please set the configuration $configKey to other values.") + } + + def unresolvedUsingColForJoinError( + colName: String, plan: LogicalPlan, side: String): Throwable = { + new AnalysisException( + s"USING column `$colName` cannot be resolved on the $side " + + s"side of the join. 
The $side-side columns: [${plan.output.map(_.name).mkString(", ")}]") + } + + def dataTypeMismatchForDeserializerError( + dataType: DataType, desiredType: String): Throwable = { + val quantifier = if (desiredType.equals("array")) "an" else "a" + new AnalysisException( + s"need $quantifier $desiredType field but got " + dataType.catalogString) + } + + def fieldNumberMismatchForDeserializerError( + schema: StructType, maxOrdinal: Int): Throwable = { + new AnalysisException( + s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.") + } + + def upCastFailureError( + fromStr: String, from: Expression, to: DataType, walkedTypePath: Seq[String]): Throwable = { + new AnalysisException( + s"Cannot up cast $fromStr from " + + s"${from.dataType.catalogString} to ${to.catalogString}.\n" + + s"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + def unsupportedAbstractDataTypeForUpCastError(gotType: AbstractDataType): Throwable = { + new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $gotType") + } + + def outerScopeFailureForNewInstanceError(className: String): Throwable = { + new AnalysisException( + s"Unable to generate an encoder for inner class `$className` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + + def referenceColNotFoundForAlterTableChangesError( + after: TableChange.After, parentName: String): Throwable = { + new AnalysisException( + s"Couldn't find the reference column for $after at $parentName") + } + +} + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8d95d8cf49d45..53c0ff687c6d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy} @@ -448,9 +449,7 @@ class Analyzer(override val catalogManager: CatalogManager) e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${groupByExprs.mkString(",")})") + throw QueryCompilationErrors.groupingIDMismatchError(e, groupByExprs) } case e @ Grouping(col: Expression) => val idx = groupByExprs.indexWhere(_.semanticEquals(col)) @@ -458,8 +457,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1L)), ByteType), toPrettySQL(e))() } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${groupByExprs.mkString(",")}") + throw QueryCompilationErrors.groupingColInvalidError(col, groupByExprs) } } } @@ -575,8 +573,7 @@ class Analyzer(override val catalogManager: CatalogManager) val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs) if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) { - throw new AnalysisException( - s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}") + throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8) } // Expand works by setting grouping expressions to null as determined by the @@ -712,8 +709,7 @@ class Analyzer(override val catalogManager: CatalogManager) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { - throw new AnalysisException( - s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) @@ -724,13 +720,10 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => value.foldable } if (!foldable) { - throw new AnalysisException( - s"Literal expressions required for pivot values, found '$value'") + throw QueryCompilationErrors.nonLiteralPivotValError(value) } if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { - throw new AnalysisException(s"Invalid pivot value '$value': " + - s"value data type ${value.dataType.simpleString} does not match " + - s"pivot column data type ${pivotColumn.dataType.catalogString}") + throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn) } Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } @@ -1167,8 +1160,7 @@ class Analyzer(override val catalogManager: CatalogManager) case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved => // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { - throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${r.table.name}") + throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) } val partCols = partitionColumnNames(r.table) @@ -1205,8 +1197,7 @@ class Analyzer(override val catalogManager: CatalogManager) partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { case Some(_) => case None => - throw new AnalysisException( - s"PARTITION clause cannot contain a non-partition column name: $partitionName") + throw QueryCompilationErrors.nonPartitionColError(partitionName) } } } @@ -1228,8 +1219,7 @@ class Analyzer(override val catalogManager: CatalogManager) case Some(attr) => attr.name -> staticName case _ => - throw new AnalysisException( - s"Cannot add static value for unknown column: $staticName") + throw QueryCompilationErrors.addStaticValToUnknownColError(staticName) }).toMap val queryColumns = query.output.iterator @@ -1271,7 +1261,7 @@ class Analyzer(override val catalogManager: CatalogManager) // an UnresolvedAttribute. 
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) case None => - throw new AnalysisException(s"Unknown static partition column: $name") + throw QueryCompilationErrors.unknownStaticPartitionColError(name) } }.reduce(And) } @@ -2483,23 +2473,19 @@ class Analyzer(override val catalogManager: CatalogManager) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Project(projectList, _) if projectList.count(hasGenerator) > 1 => val generators = projectList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per select clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "select") case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per aggregate clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") case agg @ Aggregate(groupList, aggList, child) if aggList.forall { case AliasedGenerator(_, _, _) => true @@ -2582,8 +2568,7 @@ class Analyzer(override val catalogManager: CatalogManager) case g: Generate => g case p if p.expressions.exists(hasGenerator) => - throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + - "got: " + p.simpleString(SQLConf.get.maxToStringFields)) + throw QueryCompilationErrors.generatorOutsideSelectError(p) } } @@ -3122,10 +3107,7 @@ class Analyzer(override val catalogManager: CatalogManager) private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { - val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key - throw new AnalysisException(s""" - |"LEGACY" store assignment policy is disallowed in Spark data source V2. - |Please set the configuration $configKey to other values.""".stripMargin) + throw QueryCompilationErrors.legacyStoreAssignmentPolicyError() } } @@ -3138,14 +3120,12 @@ class Analyzer(override val catalogManager: CatalogManager) hint: JoinHint) = { val leftKeys = joinNames.map { keyName => left.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + - s"side of the join. 
The left-side columns: [${left.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, left, "left") } } val rightKeys = joinNames.map { keyName => right.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " + - s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, right, "right") } } val joinPairs = leftKeys.zip(rightKeys) @@ -3208,7 +3188,8 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need an array field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, + "array") } case u: UnresolvedCatalystToExternalMap if u.child.resolved => u.child.dataType match { @@ -3218,7 +3199,7 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need a map field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, "map") } } validateNestedTupleFields(result) @@ -3227,8 +3208,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + - ", but failed as the number of fields does not line up.") + throw QueryCompilationErrors.fieldNumberMismatchForDeserializerError(schema, maxOrdinal) } /** @@ -3287,10 +3267,7 @@ class Analyzer(override val catalogManager: CatalogManager) case n: NewInstance if n.childrenResolved && !n.resolved => val outer = OuterScopes.getOuterScope(n.cls) if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + - "access to the scope that this class was defined in.\n" + - "Try moving this class out of its parent class.") + throw QueryCompilationErrors.outerScopeFailureForNewInstanceError(n.cls.getName) } n.copy(outerPointer = Some(outer)) } @@ -3306,11 +3283,7 @@ class Analyzer(override val catalogManager: CatalogManager) case l: LambdaVariable => "array element" case e => e.sql } - throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.catalogString} to ${to.catalogString}.\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") + throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath) } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { @@ -3321,8 +3294,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UpCast(child, _, _) if !child.resolved => u case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => - throw new AnalysisException( - s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") + throw QueryCompilationErrors.unsupportedAbstractDataTypeForUpCastError(target) case UpCast(child, target, walkedTypePath) if target == DecimalType && child.dataType.isInstanceOf[DecimalType] => @@ -3501,8 +3473,8 @@ class Analyzer(override val catalogManager: CatalogManager) case 
Some(colName) => ColumnPosition.after(colName) case None => - throw new AnalysisException("Couldn't find the reference column for " + - s"$after at $parentName") + throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after, + parentName) } case other => other } From 67c6ed90682455dbc866e43709fd9081dfc15ad9 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Sat, 21 Nov 2020 10:27:00 +0900 Subject: [PATCH 14/38] [SPARK-33223][SS][FOLLOWUP] Clarify the meaning of "number of rows dropped by watermark" in SS UI page ### What changes were proposed in this pull request? This PR fixes the representation to clarify the meaning of "number of rows dropped by watermark" in SS UI page. ### Why are the changes needed? `Aggregated Number Of State Rows Dropped By Watermark` says that the dropped rows are from the state, whereas they're not. We say "evicted from the state" for the case, which is "normal" to emit outputs and reduce memory usage of the state. The metric actually represents the number of "input" rows dropped by watermark, and the meaning of "input" is relative to the "stateful operator". That's a bit confusing as we normally think "input" as "input from source" whereas it's not. ### Does this PR introduce _any_ user-facing change? Yes, UI element & tooltip change. ### How was this patch tested? Only text change in UI, so we know how thing will be changed intuitively. Closes #30439 from HeartSaVioR/SPARK-33223-FOLLOWUP. Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Jungtaek Lim (HeartSaVioR) --- .../streaming/ui/StreamingQueryStatisticsPage.scala | 10 +++++----- .../spark/sql/streaming/ui/UISeleniumSuite.scala | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 7d38acfceee81..f48672afb41f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -189,8 +189,8 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) val graphUIDataForNumRowsDroppedByWatermark = new GraphUIData( - "aggregated-num-state-rows-dropped-by-watermark-timeline", - "aggregated-num-state-rows-dropped-by-watermark-histogram", + "aggregated-num-rows-dropped-by-watermark-timeline", + "aggregated-num-rows-dropped-by-watermark-histogram", numRowsDroppedByWatermarkData, minBatchTime, maxBatchTime, @@ -230,11 +230,11 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
-
Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
+
Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
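Beyond the UI, the renamed metric can also be checked programmatically. A minimal sketch, assuming an active streaming query `query` and the `numRowsDroppedByWatermark` field exposed on `StateOperatorProgress` in recent releases:

```scala
// Sums the metric across all stateful operators of the last micro-batch.
// Note: these are *input* rows dropped by the watermark, counted relative to
// each stateful operator, not rows evicted from the state store.
val droppedInLastBatch = query.lastProgress.stateOperators
  .map(_.numRowsDroppedByWatermark)
  .sum
println(s"Rows dropped by watermark in the last micro-batch: $droppedInLastBatch")
```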
- {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} - {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} // scalastyle:on } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index 1a8b28001b8d1..307479db33949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -139,7 +139,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B summaryText should contain ("Aggregated Number Of Total State Rows (?)") summaryText should contain ("Aggregated Number Of Updated State Rows (?)") summaryText should contain ("Aggregated State Memory Used In Bytes (?)") - summaryText should contain ("Aggregated Number Of State Rows Dropped By Watermark (?)") + summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)") } } finally { spark.streams.active.foreach(_.stop()) From 530c0a8e28973c57a5d0deec6b15fc29500b6c00 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 20 Nov 2020 18:41:25 -0800 Subject: [PATCH 15/38] [SPARK-33505][SQL][TESTS] Fix adding new partitions by INSERT INTO `InMemoryPartitionTable` ### What changes were proposed in this pull request? 1. Add a hook method to `addPartitionKey()` of `InMemoryTable` which is called per every row. 2. Override `addPartitionKey()` in `InMemoryPartitionTable`, and add partition key every time when new row is inserted to the table. ### Why are the changes needed? To be able to write unified tests for datasources V1 and V2. Currently, INSERT INTO a V1 table creates partitions but the same doesn't work for the custom catalog `InMemoryPartitionTableCatalog` used in DSv2 tests. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By running the affected test suite `DataSourceV2SQLSuite`. Closes #30449 from MaxGekk/insert-into-InMemoryPartitionTable. 
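As a usage-level sketch of the behaviour this enables (the `testpart` catalog and table names mirror the new test and are illustrative, assuming a session where `testpart` is registered as an `InMemoryPartitionTableCatalog`):

```scala
// After this change, INSERT INTO a DSv2 in-memory partitioned table registers the
// written partition, matching the V1 behaviour the unified tests rely on.
spark.sql(
  """CREATE TABLE testpart.ns1.ns2.tbl (id BIGINT, city STRING, data STRING)
    |USING foo
    |PARTITIONED BY (id, city)""".stripMargin)
spark.sql("INSERT INTO testpart.ns1.ns2.tbl PARTITION (id = 1, city = 'NY') SELECT 'abc'")
// The partition (1, 'NY') now exists in the table's partition map, and inserting
// into it again must not fail.
spark.sql("INSERT INTO testpart.ns1.ns2.tbl PARTITION (id = 1, city = 'NY') SELECT 'def'")
```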
Authored-by: Max Gekk Signed-off-by: Dongjoon Hyun --- .../connector/InMemoryPartitionTable.scala | 4 ++++ .../spark/sql/connector/InMemoryTable.scala | 3 +++ .../sql/connector/DataSourceV2SQLSuite.scala | 21 +++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala index 1c96bdf3afa20..23987e909aa70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala @@ -92,4 +92,8 @@ class InMemoryPartitionTable( override def partitionExists(ident: InternalRow): Boolean = memoryTablePartitions.containsKey(ident) + + override protected def addPartitionKey(key: Seq[Any]): Unit = { + memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3b47271a114e2..c93053abc550a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -160,12 +160,15 @@ class InMemoryTable( } } + protected def addPartitionKey(key: Seq[Any]): Unit = {} + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => val key = getKey(row) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + addPartitionKey(key) }) this } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 0e7aec8d80e01..90df4ee08bfc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkException import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connector.catalog._ @@ -35,6 +36,7 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class DataSourceV2SQLSuite @@ -2538,6 +2540,25 @@ class DataSourceV2SQLSuite } } + test("SPARK-33505: insert into partitioned table") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + sql(s""" + |CREATE TABLE $t (id bigint, city string, data string) + |USING foo + |PARTITIONED BY (id, city)""".stripMargin) + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] + val expectedPartitionIdent = 
InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY"))) + assert(!partTable.partitionExists(expectedPartitionIdent)) + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'") + assert(partTable.partitionExists(expectedPartitionIdent)) + // Insert into the existing partition must not fail + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'def'") + assert(partTable.partitionExists(expectedPartitionIdent)) + } + } + private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") From b623c03456be12169de7d3823f191ae6774e33ce Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 20 Nov 2020 18:45:17 -0800 Subject: [PATCH 16/38] [SPARK-32381][CORE][FOLLOWUP][TEST-HADOOP2.7] Don't remove SerializableFileStatus and SerializableBlockLocation for Hadoop 2.7 ### What changes were proposed in this pull request? Revert the change in #29959 and don't remove `SerializableFileStatus` and `SerializableBlockLocation`. ### Why are the changes needed? In Hadoop 2.7 `FileStatus` and `BlockLocation` are not serializable, so we still need the two wrapper classes. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A Closes #30447 from sunchao/SPARK-32381-followup. Authored-by: Chao Sun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/util/HadoopFSUtils.scala | 61 ++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala index a3a528cddee37..4af48d5b9125c 100644 --- a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala @@ -136,12 +136,53 @@ private[spark] object HadoopFSUtils extends Logging { parallelismMax = 0) (path, leafFiles) }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) }.collect() } finally { sc.setJobDescription(previousJobDescription) } - statusMap.toSeq + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } } // scalastyle:off argcount @@ -291,4 +332,22 @@ private[spark] object HadoopFSUtils extends Logging { resolvedLeafStatuses } // scalastyle:on argcount + + /** A serializable variant of HDFS's BlockLocation. This is required by Hadoop 2.7. 
*/ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. This is required by Hadoop 2.7. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) } From cf7490112ab81cce4a483c2a94368ce3d9d986df Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 20 Nov 2020 19:01:58 -0800 Subject: [PATCH 17/38] Revert "[SPARK-28704][SQL][TEST] Add back Skiped HiveExternalCatalogVersionsSuite in HiveSparkSubmitSuite at JDK9+" This reverts commit 47326ac1c6a296a84af76d832061741740ae9f12. --- .../HiveExternalCatalogVersionsSuite.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 4cafd3e8ca626..38a8c492d77a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -52,6 +52,7 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { import HiveExternalCatalogVersionsSuite._ + private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9) private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `spark.test.cache-dir` to a static value like `/tmp/test-spark`, to @@ -148,9 +149,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) } - override def beforeAll(): Unit = { - super.beforeAll() - + private def prepare(): Unit = { val tempPyFile = File.createTempFile("test", ".py") // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, @@ -200,7 +199,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.7", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--conf", s"spark.sql.test.version.index=$index", @@ -212,14 +211,23 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { tempPyFile.delete() } + override def beforeAll(): Unit = { + super.beforeAll() + if (!isTestAtLeastJava9) { + prepare() + } + } + test("backward compatibility") { + // TODO SPARK-28704 Test backward compatibility on JDK9+ once we have a version supports JDK9+ + assume(!isTestAtLeastJava9) val args = Seq( "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), "--name", "HiveExternalCatalog backward compatibility test", "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.7", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", 
s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", @@ -244,9 +252,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // do not throw exception during object initialization. case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version } - versions - .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) - .filter(v => v.startsWith("3") || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) + versions.filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) } protected var spark: SparkSession = _ From 517b810dfa5076c3d0155d1e134dc93317ec3ec0 Mon Sep 17 00:00:00 2001 From: Gustavo Martin Morcuende Date: Sat, 21 Nov 2020 08:39:16 -0800 Subject: [PATCH 18/38] [SPARK-33463][SQL] Keep Job Id during incremental collect in Spark Thrift Server ### What changes were proposed in this pull request? When enabling **spark.sql.thriftServer.incrementalCollect** Job Ids get lost and tracing queries in Spark Thrift Server ends up being too complicated. ### Why are the changes needed? Because it will make easier tracing Spark Thrift Server queries. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? The current tests are enough. No need of more tests. Closes #30390 from gumartinm/master. Authored-by: Gustavo Martin Morcuende Signed-off-by: Dongjoon Hyun --- .../SparkExecuteStatementOperation.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 2e9975bcabc3f..f7a4be9591818 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -63,6 +63,10 @@ private[hive] class SparkExecuteStatementOperation( } } + private val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) { + new VariableSubstitution().substitute(statement) + } + private var result: DataFrame = _ // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST. 
@@ -126,6 +130,17 @@ private[hive] class SparkExecuteStatementOperation( } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withLocalProperties { + try { + sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) + getNextRowSetInternal(order, maxRowsL) + } finally { + sqlContext.sparkContext.clearJobGroup() + } + } + + private def getNextRowSetInternal( + order: FetchOrientation, + maxRowsL: Long): RowSet = withLocalProperties { log.info(s"Received getNextRowSet request order=${order} and maxRowsL=${maxRowsL} " + s"with ${statementId}") validateDefaultFetchOrientation(order) @@ -306,9 +321,6 @@ private[hive] class SparkExecuteStatementOperation( parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader) } - val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) { - new VariableSubstitution().substitute(statement) - } sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) result = sqlContext.sql(statement) logDebug(result.queryExecution.toString()) From d7f4b2ad50aa7acdb0392bb400fc0c87491c6e45 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sun, 22 Nov 2020 10:29:15 -0800 Subject: [PATCH 19/38] [SPARK-28704][SQL][TEST] Add back Skiped HiveExternalCatalogVersionsSuite in HiveSparkSubmitSuite at JDK9+ ### What changes were proposed in this pull request? We skip test HiveExternalCatalogVersionsSuite when testing with JAVA_9 or later because our previous version does not support JAVA_9 or later. We now add it back since we have a version supports JAVA_9 or later. ### Why are the changes needed? To recover test coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Check CI logs. Closes #30451 from AngersZhuuuu/SPARK-28704. Authored-by: angerszhu Signed-off-by: Dongjoon Hyun --- .../HiveExternalCatalogVersionsSuite.scala | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 38a8c492d77a7..cf070f4611f3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -52,7 +52,6 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { import HiveExternalCatalogVersionsSuite._ - private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9) private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `spark.test.cache-dir` to a static value like `/tmp/test-spark`, to @@ -60,6 +59,11 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val sparkTestingDir = Option(System.getProperty(SPARK_TEST_CACHE_DIR_SYSTEM_PROPERTY)) .map(new File(_)).getOrElse(Utils.createTempDir(namePrefix = "test-spark")) private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val hiveVersion = if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { + "2.3.7" + } else { + "1.2.1" + } override def afterAll(): Unit = { try { @@ -149,7 +153,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) } - private def prepare(): 
Unit = { + override def beforeAll(): Unit = { + super.beforeAll() + val tempPyFile = File.createTempFile("test", ".py") // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, @@ -199,7 +205,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--conf", s"spark.sql.test.version.index=$index", @@ -211,23 +217,14 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { tempPyFile.delete() } - override def beforeAll(): Unit = { - super.beforeAll() - if (!isTestAtLeastJava9) { - prepare() - } - } - test("backward compatibility") { - // TODO SPARK-28704 Test backward compatibility on JDK9+ once we have a version supports JDK9+ - assume(!isTestAtLeastJava9) val args = Seq( "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), "--name", "HiveExternalCatalog backward compatibility test", "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", @@ -252,7 +249,9 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // do not throw exception during object initialization. case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version } - versions.filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + versions + .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + .filter(v => v.startsWith("3") || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) } protected var spark: SparkSession = _ From d338af3101a4c986b5e979e8fdc63b8551e12d29 Mon Sep 17 00:00:00 2001 From: CC Highman Date: Mon, 23 Nov 2020 08:30:41 +0900 Subject: [PATCH 20/38] [SPARK-31962][SQL] Provide modifiedAfter and modifiedBefore options when filtering from a batch-based file data source ### What changes were proposed in this pull request? Two new options, _modifiiedBefore_ and _modifiedAfter_, is provided expecting a value in 'YYYY-MM-DDTHH:mm:ss' format. _PartioningAwareFileIndex_ considers these options during the process of checking for files, just before considering applied _PathFilters_ such as `pathGlobFilter.` In order to filter file results, a new PathFilter class was derived for this purpose. General house-keeping around classes extending PathFilter was performed for neatness. It became apparent support was needed to handle multiple potential path filters. Logic was introduced for this purpose and the associated tests written. ### Why are the changes needed? When loading files from a data source, there can often times be thousands of file within a respective file path. In many cases I've seen, we want to start loading from a folder path and ideally be able to begin loading files having modification dates past a certain point. 
This would mean out of thousands of potential files, only the ones with modification dates greater than the specified timestamp would be considered. This saves a ton of time automatically and reduces significant complexity managing this in code. ### Does this PR introduce _any_ user-facing change? This PR introduces an option that can be used with batch-based Spark file data sources. A documentation update was made to reflect an example and usage of the new data source option. **Example Usages** _Load all CSV files modified after date:_ `spark.read.format("csv").option("modifiedAfter","2020-06-15T05:00:00").load()` _Load all CSV files modified before date:_ `spark.read.format("csv").option("modifiedBefore","2020-06-15T05:00:00").load()` _Load all CSV files modified between two dates:_ `spark.read.format("csv").option("modifiedAfter","2019-01-15T05:00:00").option("modifiedBefore","2020-06-15T05:00:00").load() ` ### How was this patch tested? A handful of unit tests were added to support the positive, negative, and edge case code paths. It's also live in a handful of our Databricks dev environments. (quoted from cchighman) Closes #30411 from HeartSaVioR/SPARK-31962. Lead-authored-by: CC Highman Co-authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Jungtaek Lim (HeartSaVioR) --- docs/sql-data-sources-generic-options.md | 37 +++ .../sql/JavaSQLDataSourceExample.java | 16 + examples/src/main/python/sql/datasource.py | 20 ++ examples/src/main/r/RSparkSQLExample.R | 8 + .../examples/sql/SQLDataSourceExample.scala | 21 ++ python/pyspark/sql/readwriter.py | 81 ++++- .../apache/spark/sql/DataFrameReader.scala | 30 ++ .../PartitioningAwareFileIndex.scala | 13 +- .../execution/datasources/pathFilters.scala | 161 +++++++++ .../streaming/FileStreamOptions.scala | 11 + .../spark/sql/FileBasedDataSourceSuite.scala | 32 -- .../datasources/PathFilterStrategySuite.scala | 54 +++ .../datasources/PathFilterSuite.scala | 307 ++++++++++++++++++ .../sql/streaming/FileStreamSourceSuite.scala | 44 ++- 14 files changed, 787 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala diff --git a/docs/sql-data-sources-generic-options.md b/docs/sql-data-sources-generic-options.md index 6bcf48235bced..2e4fc879a435f 100644 --- a/docs/sql-data-sources-generic-options.md +++ b/docs/sql-data-sources-generic-options.md @@ -119,3 +119,40 @@ To load all files recursively, you can use: {% include_example recursive_file_lookup r/RSparkSQLExample.R %} + +### Modification Time Path Filters + +`modifiedBefore` and `modifiedAfter` are options that can be +applied together or separately in order to achieve greater +granularity over which files may load during a Spark batch query. +(Note that Structured Streaming file sources don't support these options.) + +* `modifiedBefore`: an optional timestamp to only include files with +modification times occurring before the specified time. The provided timestamp +must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) +* `modifiedAfter`: an optional timestamp to only include files with +modification times occurring after the specified time. The provided timestamp +must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 
2020-06-01T13:00:00) + +When a timezone option is not provided, the timestamps will be interpreted according +to the Spark session timezone (`spark.sql.session.timeZone`). + +To load files with paths matching a given modified time range, you can use: + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example load_with_modified_time_filter scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example load_with_modified_time_filter java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example load_with_modified_time_filter python/sql/datasource.py %}
+</div>
+
+<div data-lang="r" markdown="1">
+{% include_example load_with_modified_time_filter r/RSparkSQLExample.R %}
+</div>
+</div>
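A combined usage sketch of the new options in Scala, with an explicit timezone (the path matches the example data directory used above):

```scala
// Load only Parquet files last modified in the first half of 2020.
// Both bounds are interpreted in the given timezone; without the "timeZone"
// option they are interpreted in spark.sql.session.timeZone.
val firstHalf2020 = spark.read.format("parquet")
  .option("modifiedAfter", "2020-01-01T00:00:00")
  .option("modifiedBefore", "2020-07-01T00:00:00")
  .option("timeZone", "America/Los_Angeles")
  .load("examples/src/main/resources/dir1")
```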
\ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 2295225387a33..46e740d78bffb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -147,6 +147,22 @@ private static void runGenericFileSourceOptionsExample(SparkSession spark) { // |file1.parquet| // +-------------+ // $example off:load_with_path_glob_filter$ + // $example on:load_with_modified_time_filter$ + Dataset beforeFilterDF = spark.read().format("parquet") + // Only load files modified before 7/1/2020 at 05:30 + .option("modifiedBefore", "2020-07-01T05:30:00") + // Only load files modified after 6/1/2020 at 05:30 + .option("modifiedAfter", "2020-06-01T05:30:00") + // Interpret both times above relative to CST timezone + .option("timeZone", "CST") + .load("examples/src/main/resources/dir1"); + beforeFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + // $example off:load_with_modified_time_filter$ } private static void runBasicDataSourceExample(SparkSession spark) { diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index eecd8c2d84788..8c146ba0c9455 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -67,6 +67,26 @@ def generic_file_source_options_example(spark): # +-------------+ # $example off:load_with_path_glob_filter$ + # $example on:load_with_modified_time_filter$ + # Only load files modified before 07/1/2050 @ 08:30:00 + df = spark.read.load("examples/src/main/resources/dir1", + format="parquet", modifiedBefore="2050-07-01T08:30:00") + df.show() + # +-------------+ + # | file| + # +-------------+ + # |file1.parquet| + # +-------------+ + # Only load files modified after 06/01/2050 @ 08:30:00 + df = spark.read.load("examples/src/main/resources/dir1", + format="parquet", modifiedAfter="2050-06-01T08:30:00") + df.show() + # +-------------+ + # | file| + # +-------------+ + # +-------------+ + # $example off:load_with_modified_time_filter$ + def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 8685cfb5c05f2..86ad5334248bc 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -144,6 +144,14 @@ df <- read.df("examples/src/main/resources/dir1", "parquet", pathGlobFilter = "* # 1 file1.parquet # $example off:load_with_path_glob_filter$ +# $example on:load_with_modified_time_filter$ +beforeDF <- read.df("examples/src/main/resources/dir1", "parquet", modifiedBefore= "2020-07-01T05:30:00") +# file +# 1 file1.parquet +afterDF <- read.df("examples/src/main/resources/dir1", "parquet", modifiedAfter = "2020-06-01T05:30:00") +# file +# $example off:load_with_modified_time_filter$ + # $example on:manual_save_options_orc$ df <- read.df("examples/src/main/resources/users.orc", "orc") write.orc(df, "users_with_options.orc", orc.bloom.filter.columns = "favorite_color", orc.dictionary.key.threshold = 1.0, orc.column.encoding.direct = "name") diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala 
index 2c7abfcd335d1..90c0eeb5ba888 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -81,6 +81,27 @@ object SQLDataSourceExample { // |file1.parquet| // +-------------+ // $example off:load_with_path_glob_filter$ + // $example on:load_with_modified_time_filter$ + val beforeFilterDF = spark.read.format("parquet") + // Files modified before 07/01/2020 at 05:30 are allowed + .option("modifiedBefore", "2020-07-01T05:30:00") + .load("examples/src/main/resources/dir1"); + beforeFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + val afterFilterDF = spark.read.format("parquet") + // Files modified after 06/01/2020 at 05:30 are allowed + .option("modifiedAfter", "2020-06-01T05:30:00") + .load("examples/src/main/resources/dir1"); + afterFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // +-------------+ + // $example off:load_with_modified_time_filter$ } private def runBasicDataSourceExample(spark: SparkSession): Unit = { diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2ed991c87f506..bb31e6a3e09f8 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -125,6 +125,12 @@ def option(self, key, value): * ``pathGlobFilter``: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery. + * ``modifiedBefore``: an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + * ``modifiedAfter``: an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -149,6 +155,12 @@ def options(self, **options): * ``pathGlobFilter``: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery. + * ``modifiedBefore``: an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + * ``modifiedAfter``: an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) """ for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -203,7 +215,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, dropFieldIfAllNull=None, encoding=None, locale=None, pathGlobFilter=None, - recursiveFileLookup=None, allowNonNumericNumbers=None): + recursiveFileLookup=None, allowNonNumericNumbers=None, + modifiedBefore=None, modifiedAfter=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. 
@@ -322,6 +335,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, ``+Infinity`` and ``Infinity``. * ``-INF``: for negative infinity, alias ``-Infinity``. * ``NaN``: for other not-a-numbers, like result of division by zero. + modifiedBefore : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- @@ -344,6 +364,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, locale=locale, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter, allowNonNumericNumbers=allowNonNumericNumbers) if isinstance(path, str): path = [path] @@ -410,6 +431,15 @@ def parquet(self, *paths, **options): disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') @@ -418,13 +448,18 @@ def parquet(self, *paths, **options): """ mergeSchema = options.get('mergeSchema', None) pathGlobFilter = options.get('pathGlobFilter', None) + modifiedBefore = options.get('modifiedBefore', None) + modifiedAfter = options.get('modifiedAfter', None) recursiveFileLookup = options.get('recursiveFileLookup', None) self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, - recursiveFileLookup=recursiveFileLookup) + recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter) + return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths))) def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, - recursiveFileLookup=None): + recursiveFileLookup=None, modifiedBefore=None, + modifiedAfter=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there @@ -453,6 +488,15 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, recursively scan a directory for files. Using this option disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. 
The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.text('python/test_support/sql/text-test.txt') @@ -464,7 +508,9 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, """ self._set_opts( wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter, - recursiveFileLookup=recursiveFileLookup) + recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter) + if isinstance(paths, str): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -476,7 +522,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, - pathGlobFilter=None, recursiveFileLookup=None): + pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -631,6 +677,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non recursively scan a directory for files. Using this option disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.csv('python/test_support/sql/ages.csv') @@ -652,7 +707,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep, - pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) + pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter) if isinstance(path, str): path = [path] if type(path) == list: @@ -679,7 +735,8 @@ def func(iterator): else: raise TypeError("path can be only string, list or RDD") - def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None): + def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None, + modifiedBefore=None, modifiedAfter=None): """Loads ORC files, returning the result as a :class:`DataFrame`. .. 
versionadded:: 1.5.0 @@ -701,6 +758,15 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.orc('python/test_support/sql/orc_partitioned') @@ -708,6 +774,7 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N [('a', 'bigint'), ('b', 'int'), ('c', 'int')] """ self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter, recursiveFileLookup=recursiveFileLookup) if isinstance(path, str): path = [path] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 276d5d29bfa2c..b26bc6441b6cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -493,6 +493,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `pathGlobFilter`: an optional glob pattern to only include files with paths matching * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. * It does not change the behavior of partition discovery.
  • + *
  • `modifiedBefore` (batch only): an optional timestamp to only include files with + * modification times occurring before the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • + *
  • `modifiedAfter` (batch only): an optional timestamp to only include files with + * modification times occurring after the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • *
  • `recursiveFileLookup`: recursively scan a directory for files. Using this option * disables partition discovery
  • *
  • `allowNonNumericNumbers` (default `true`): allows JSON parser to recognize set of @@ -750,6 +756,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `pathGlobFilter`: an optional glob pattern to only include files with paths matching * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. * It does not change the behavior of partition discovery.
  • + *
  • `modifiedBefore` (batch only): an optional timestamp to only include files with + * modification times occurring before the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • + *
  • `modifiedAfter` (batch only): an optional timestamp to only include files with + * modification times occurring after the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • *
  • `recursiveFileLookup`: recursively scan a directory for files. Using this option * disables partition discovery
  • * @@ -781,6 +793,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `pathGlobFilter`: an optional glob pattern to only include files with paths matching * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. * It does not change the behavior of partition discovery.
  • + *
  • `modifiedBefore` (batch only): an optional timestamp to only include files with + * modification times occurring before the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • + *
  • `modifiedAfter` (batch only): an optional timestamp to only include files with + * modification times occurring after the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • *
  • `recursiveFileLookup`: recursively scan a directory for files. Using this option * disables partition discovery
  • * @@ -814,6 +832,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `pathGlobFilter`: an optional glob pattern to only include files with paths matching * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. * It does not change the behavior of partition discovery.
  • + *
  • `modifiedBefore` (batch only): an optional timestamp to only include files with + * modification times occurring before the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • + *
  • `modifiedAfter` (batch only): an optional timestamp to only include files with + * modification times occurring after the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • *
  • `recursiveFileLookup`: recursively scan a directory for files. Using this option * disables partition discovery
  • * @@ -880,6 +904,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `pathGlobFilter`: an optional glob pattern to only include files with paths matching * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. * It does not change the behavior of partition discovery.
  • + *
  • `modifiedBefore` (batch only): an optional timestamp to only include files with + * modification times occurring before the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • + *
  • `modifiedAfter` (batch only): an optional timestamp to only include files with + * modification times occurring after the specified Time. The provided timestamp + * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
  • *
  • `recursiveFileLookup`: recursively scan a directory for files. Using this option * disables partition discovery
  • * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index fed9614347f6a..5b0d0606da093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -57,13 +57,10 @@ abstract class PartitioningAwareFileIndex( protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] private val caseInsensitiveMap = CaseInsensitiveMap(parameters) + private val pathFilters = PathFilterFactory.create(caseInsensitiveMap) - protected lazy val pathGlobFilter: Option[GlobFilter] = - caseInsensitiveMap.get("pathGlobFilter").map(new GlobFilter(_)) - - protected def matchGlobPattern(file: FileStatus): Boolean = { - pathGlobFilter.forall(_.accept(file.getPath)) - } + protected def matchPathPattern(file: FileStatus): Boolean = + pathFilters.forall(_.accept(file)) protected lazy val recursiveFileLookup: Boolean = { caseInsensitiveMap.getOrElse("recursiveFileLookup", "false").toBoolean @@ -86,7 +83,7 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchGlobPattern(f) && isNonEmptyFile(f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) case None => // Directory does not exist, or has no children files @@ -135,7 +132,7 @@ abstract class PartitioningAwareFileIndex( } else { leafFiles.values.toSeq } - files.filter(matchGlobPattern) + files.filter(matchPathPattern) } protected def inferPartitioning(): PartitionSpec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala new file mode 100644 index 0000000000000..c8f23988f93c6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Locale, TimeZone} + +import org.apache.hadoop.fs.{FileStatus, GlobFilter} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String + +trait PathFilterStrategy extends Serializable { + def accept(fileStatus: FileStatus): Boolean +} + +trait StrategyBuilder { + def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] +} + +class PathGlobFilter(filePatten: String) extends PathFilterStrategy { + + private val globFilter = new GlobFilter(filePatten) + + override def accept(fileStatus: FileStatus): Boolean = + globFilter.accept(fileStatus.getPath) +} + +object PathGlobFilter extends StrategyBuilder { + val PARAM_NAME = "pathglobfilter" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map(new PathGlobFilter(_)) + } +} + +/** + * Provide modifiedAfter and modifiedBefore options when + * filtering from a batch-based file data source. + * + * Example Usages + * Load all CSV files modified after date: + * {{{ + * spark.read.format("csv").option("modifiedAfter","2020-06-15T05:00:00").load() + * }}} + * + * Load all CSV files modified before date: + * {{{ + * spark.read.format("csv").option("modifiedBefore","2020-06-15T05:00:00").load() + * }}} + * + * Load all CSV files modified between two dates: + * {{{ + * spark.read.format("csv").option("modifiedAfter","2019-01-15T05:00:00") + * .option("modifiedBefore","2020-06-15T05:00:00").load() + * }}} + */ +abstract class ModifiedDateFilter extends PathFilterStrategy { + + def timeZoneId: String + + protected def localTime(micros: Long): Long = + DateTimeUtils.fromUTCTime(micros, timeZoneId) +} + +object ModifiedDateFilter { + + def getTimeZoneId(options: CaseInsensitiveMap[String]): String = { + options.getOrElse( + DateTimeUtils.TIMEZONE_OPTION.toLowerCase(Locale.ROOT), + SQLConf.get.sessionLocalTimeZone) + } + + def toThreshold(timeString: String, timeZoneId: String, strategy: String): Long = { + val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId) + val ts = UTF8String.fromString(timeString) + DateTimeUtils.stringToTimestamp(ts, timeZone.toZoneId).getOrElse { + throw new AnalysisException( + s"The timestamp provided for the '$strategy' option is invalid. The expected format " + + s"is 'YYYY-MM-DDTHH:mm:ss', but the provided timestamp: $timeString") + } + } +} + +/** + * Filter used to determine whether file was modified before the provided timestamp. 
+ */ +class ModifiedBeforeFilter(thresholdTime: Long, val timeZoneId: String) + extends ModifiedDateFilter { + + override def accept(fileStatus: FileStatus): Boolean = + // We standardize on microseconds wherever possible + // getModificationTime returns in milliseconds + thresholdTime - localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) > 0 +} + +object ModifiedBeforeFilter extends StrategyBuilder { + import ModifiedDateFilter._ + + val PARAM_NAME = "modifiedbefore" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map { value => + val timeZoneId = getTimeZoneId(parameters) + val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + new ModifiedBeforeFilter(thresholdTime, timeZoneId) + } + } +} + +/** + * Filter used to determine whether file was modified after the provided timestamp. + */ +class ModifiedAfterFilter(thresholdTime: Long, val timeZoneId: String) + extends ModifiedDateFilter { + + override def accept(fileStatus: FileStatus): Boolean = + // getModificationTime returns in milliseconds + // We standardize on microseconds wherever possible + localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) - thresholdTime > 0 +} + +object ModifiedAfterFilter extends StrategyBuilder { + import ModifiedDateFilter._ + + val PARAM_NAME = "modifiedafter" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map { value => + val timeZoneId = getTimeZoneId(parameters) + val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + new ModifiedAfterFilter(thresholdTime, timeZoneId) + } + } +} + +object PathFilterFactory { + + private val strategies = + Seq(PathGlobFilter, ModifiedBeforeFilter, ModifiedAfterFilter) + + def create(parameters: CaseInsensitiveMap[String]): Seq[PathFilterStrategy] = { + strategies.flatMap { _.create(parameters) } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 712ed1585bc8a..6f43542fd6595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -23,6 +23,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, ModifiedBeforeFilter} import org.apache.spark.util.Utils /** @@ -32,6 +33,16 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + checkDisallowedOptions(parameters) + + private def checkDisallowedOptions(options: Map[String, String]): Unit = { + Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param => + if (parameters.contains(param)) { + throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources") + } + } + } + val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => Try(str.toInt).toOption.filter(_ > 0).getOrElse { throw new IllegalArgumentException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 
b27c1145181bd..876f62803dc7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -577,38 +577,6 @@ class FileBasedDataSourceSuite extends QueryTest } } - test("Option pathGlobFilter: filter files correctly") { - withTempPath { path => - val dataDir = path.getCanonicalPath - Seq("foo").toDS().write.text(dataDir) - Seq("bar").toDS().write.mode("append").orc(dataDir) - val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir) - checkAnswer(df, Row("foo")) - - // Both glob pattern in option and path should be effective to filter files. - val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc") - checkAnswer(df2, Seq.empty) - - val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt") - checkAnswer(df3, Row("foo")) - } - } - - test("Option pathGlobFilter: simple extension filtering should contains partition info") { - withTempPath { path => - val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b") - input.write.partitionBy("b").text(path.getCanonicalPath) - Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1") - - // If we use glob pattern in the path, the partition column won't be shown in the result. - val df = spark.read.text(path.getCanonicalPath + "/*/*.txt") - checkAnswer(df, input.select("a")) - - val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath) - checkAnswer(df2, input) - } - } - test("Option recursiveFileLookup: recursive loading correctly") { val expectedFileList = mutable.ListBuffer[String]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala new file mode 100644 index 0000000000000..b965a78c9eec0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.test.SharedSparkSession + +class PathFilterStrategySuite extends QueryTest with SharedSparkSession { + + test("SPARK-31962: PathFilterStrategies - modifiedAfter option") { + val options = + CaseInsensitiveMap[String](Map("modifiedAfter" -> "2010-10-01T01:01:00")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[ModifiedAfterFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - modifiedBefore option") { + val options = + CaseInsensitiveMap[String](Map("modifiedBefore" -> "2020-10-01T01:01:00")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[ModifiedBeforeFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - pathGlobFilter option") { + val options = CaseInsensitiveMap[String](Map("pathGlobFilter" -> "*.txt")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[PathGlobFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - no options") { + val options = CaseInsensitiveMap[String](Map.empty) + val strategy = PathFilterFactory.create(options) + assert(strategy.isEmpty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala new file mode 100644 index 0000000000000..1af2adfd8640c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.time.{LocalDateTime, ZoneId, ZoneOffset} +import java.time.format.DateTimeFormatter + +import scala.util.Random + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.util.{stringToFile, DateTimeUtils} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class PathFilterSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("SPARK-31962: modifiedBefore specified" + + " and sharing same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formatTime(curTime))) + } + } + + test("SPARK-31962: modifiedAfter specified" + + " and sharing same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + executeTest(dir, Seq(curTime), 0, modifiedAfter = Some(formatTime(curTime))) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with earlier file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val fileTime = curTime.minusDays(3) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(fileTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with later file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: when modifiedAfter specified with a past date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(curTime), 1, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: when modifiedBefore specified with a future date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val futureTime = curTime.plusYears(1) + val formattedTime = formatTime(futureTime) + executeTest(dir, Seq(curTime), 1, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: with modifiedBefore option provided using a past date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, one valid") { + withTempDir { dir => + val fileTime1 = LocalDateTime.now(ZoneOffset.UTC) + val fileTime2 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val pastTime = fileTime1.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(fileTime1, 
fileTime2), 1, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, both valid") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(curTime, curTime), 2, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, none valid") { + withTempDir { dir => + val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val pastTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(fileTime, fileTime), 0, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, both valid") { + withTempDir { dir => + val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1) + val formattedTime = formatTime(futureTime) + executeTest(dir, Seq(fileTime, fileTime), 2, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, one valid") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val fileTime1 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val fileTime2 = curTime.plusDays(3) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(fileTime1, fileTime2), 1, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, none valid") { + withTempDir { dir => + val fileTime = LocalDateTime.now(ZoneOffset.UTC).minusDays(1) + val formattedTime = formatTime(fileTime) + executeTest(dir, Seq(fileTime, fileTime), 0, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore/modifiedAfter is specified with an invalid date") { + executeTestWithBadOption( + Map("modifiedBefore" -> "2024-05+1 01:00:00"), + Seq("The timestamp provided", "modifiedbefore", "2024-05+1 01:00:00")) + + executeTestWithBadOption( + Map("modifiedAfter" -> "2024-05+1 01:00:00"), + Seq("The timestamp provided", "modifiedafter", "2024-05+1 01:00:00")) + } + + test("SPARK-31962: modifiedBefore/modifiedAfter - empty option") { + executeTestWithBadOption( + Map("modifiedBefore" -> ""), + Seq("The timestamp provided", "modifiedbefore")) + + executeTestWithBadOption( + Map("modifiedAfter" -> ""), + Seq("The timestamp provided", "modifiedafter")) + } + + test("SPARK-31962: modifiedBefore/modifiedAfter filter takes into account local timezone " + + "when specified as an option.") { + Seq("modifiedbefore", "modifiedafter").foreach { filterName => + // CET = UTC + 1 hour, HST = UTC - 10 hours + Seq("CET", "HST").foreach { tzId => + testModifiedDateFilterWithTimezone(tzId, filterName) + } + } + } + + test("Option pathGlobFilter: filter files correctly") { + withTempPath { path => + val dataDir = path.getCanonicalPath + Seq("foo").toDS().write.text(dataDir) + Seq("bar").toDS().write.mode("append").orc(dataDir) + val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir) + checkAnswer(df, Row("foo")) + + // Both glob pattern in option and path should be effective to filter files. 
+ val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc") + checkAnswer(df2, Seq.empty) + + val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt") + checkAnswer(df3, Row("foo")) + } + } + + test("Option pathGlobFilter: simple extension filtering should contains partition info") { + withTempPath { path => + val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b") + input.write.partitionBy("b").text(path.getCanonicalPath) + Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1") + + // If we use glob pattern in the path, the partition column won't be shown in the result. + val df = spark.read.text(path.getCanonicalPath + "/*/*.txt") + checkAnswer(df, input.select("a")) + + val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath) + checkAnswer(df2, input) + } + } + + private def executeTest( + dir: File, + fileDates: Seq[LocalDateTime], + expectedCount: Long, + modifiedBefore: Option[String] = None, + modifiedAfter: Option[String] = None): Unit = { + fileDates.foreach { fileDate => + val file = createSingleFile(dir) + setFileTime(fileDate, file) + } + + val schema = StructType(Seq(StructField("a", StringType))) + + var dfReader = spark.read.format("csv").option("timeZone", "UTC").schema(schema) + modifiedBefore.foreach { opt => dfReader = dfReader.option("modifiedBefore", opt) } + modifiedAfter.foreach { opt => dfReader = dfReader.option("modifiedAfter", opt) } + + if (expectedCount > 0) { + // without pathGlobFilter + val df1 = dfReader.load(dir.getCanonicalPath) + assert(df1.count() === expectedCount) + + // pathGlobFilter matched + val df2 = dfReader.option("pathGlobFilter", "*.csv").load(dir.getCanonicalPath) + assert(df2.count() === expectedCount) + + // pathGlobFilter mismatched + val df3 = dfReader.option("pathGlobFilter", "*.txt").load(dir.getCanonicalPath) + assert(df3.count() === 0) + } else { + val df = dfReader.load(dir.getCanonicalPath) + assert(df.count() === 0) + } + } + + private def executeTestWithBadOption( + options: Map[String, String], + expectedMsgParts: Seq[String]): Unit = { + withTempDir { dir => + createSingleFile(dir) + val exc = intercept[AnalysisException] { + var dfReader = spark.read.format("csv") + options.foreach { case (key, value) => + dfReader = dfReader.option(key, value) + } + dfReader.load(dir.getCanonicalPath) + } + expectedMsgParts.foreach { msg => assert(exc.getMessage.contains(msg)) } + } + } + + private def testModifiedDateFilterWithTimezone( + timezoneId: String, + filterParamName: String): Unit = { + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val zoneId: ZoneId = DateTimeUtils.getTimeZone(timezoneId).toZoneId + val strategyTimeInMicros = + ModifiedDateFilter.toThreshold( + curTime.toString, + timezoneId, + filterParamName) + val strategyTimeInSeconds = strategyTimeInMicros / 1000 / 1000 + + val curTimeAsSeconds = curTime.atZone(zoneId).toEpochSecond + withClue(s"timezone: $timezoneId / param: $filterParamName,") { + assert(strategyTimeInSeconds === curTimeAsSeconds) + } + } + + private def createSingleFile(dir: File): File = { + val file = new File(dir, "temp" + Random.nextInt(1000000) + ".csv") + stringToFile(file, "text") + } + + private def setFileTime(time: LocalDateTime, file: File): Boolean = { + val sameTime = time.toEpochSecond(ZoneOffset.UTC) + file.setLastModified(sameTime * 1000) + } + + private def formatTime(time: LocalDateTime): String = { + time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss")) + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index cf9664a9764be..718095003b096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming import java.io.File import java.net.URI +import java.time.{LocalDateTime, ZoneOffset} +import java.time.format.DateTimeFormatter import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable @@ -40,7 +42,6 @@ import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ import org.apache.spark.sql.types.{StructType, _} import org.apache.spark.util.Utils @@ -2054,6 +2055,47 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-31962: file stream source shouldn't allow modifiedBefore/modifiedAfter") { + def formatTime(time: LocalDateTime): String = { + time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss")) + } + + def assertOptionIsNotSupported(options: Map[String, String], path: String): Unit = { + val schema = StructType(Seq(StructField("a", StringType))) + var dsReader = spark.readStream + .format("csv") + .option("timeZone", "UTC") + .schema(schema) + + options.foreach { case (k, v) => dsReader = dsReader.option(k, v) } + + val df = dsReader.load(path) + + testStream(df)( + ExpectFailure[IllegalArgumentException]( + t => assert(t.getMessage.contains("is not allowed in file stream source")), + isFatalError = false) + ) + } + + withTempDir { dir => + // "modifiedBefore" + val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1) + val formattedFutureTime = formatTime(futureTime) + assertOptionIsNotSupported(Map("modifiedBefore" -> formattedFutureTime), dir.getCanonicalPath) + + // "modifiedAfter" + val prevTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1) + val formattedPrevTime = formatTime(prevTime) + assertOptionIsNotSupported(Map("modifiedAfter" -> formattedPrevTime), dir.getCanonicalPath) + + // both + assertOptionIsNotSupported( + Map("modifiedBefore" -> formattedFutureTime, "modifiedAfter" -> formattedPrevTime), + dir.getCanonicalPath) + } + } + private def createFile(content: String, src: File, tmp: File): File = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) val finalFile = new File(src, tempFile.getName) From 6d625ccd5b5a76a149e2070df31984610629a295 Mon Sep 17 00:00:00 2001 From: ulysses Date: Sun, 22 Nov 2020 15:36:44 -0800 Subject: [PATCH 21/38] [SPARK-33469][SQL] Add current_timezone function ### What changes were proposed in this pull request? Add a `CurrentTimeZone` function and replace the value at `Optimizer` side. ### Why are the changes needed? Let user get current timezone easily. Then user can call ``` SELECT current_timezone() ``` Presto: https://prestodb.io/docs/current/functions/datetime.html SQL Server: https://docs.microsoft.com/en-us/sql/t-sql/functions/current-timezone-transact-sql?view=sql-server-ver15 ### Does this PR introduce _any_ user-facing change? Yes, a new function. ### How was this patch tested? Add test. Closes #30400 from ulysses-you/SPARK-33469. 
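As a quick illustration of the new function (a hedged sketch, not part of the patch; the exact column header and output formatting may differ by environment), `current_timezone()` simply surfaces the session-local timezone:

```scala
// Illustrative spark-shell session; "Asia/Shanghai" is just an example value.
spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
spark.sql("SELECT current_timezone()").show()
// +------------------+
// |current_timezone()|
// +------------------+
// |     Asia/Shanghai|
// +------------------+
```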
Lead-authored-by: ulysses Co-authored-by: ulysses-you Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 15 +++++++++++++++ .../sql/catalyst/optimizer/finishAnalysis.scala | 3 +++ .../optimizer/ComputeCurrentTimeSuite.scala | 16 +++++++++++++++- .../sql-functions/sql-expression-schema.md | 3 ++- .../org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ .../sql/expressions/ExpressionInfoSuite.scala | 1 + 7 files changed, 45 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 508239077a70e..6fb9bed9625d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -391,6 +391,7 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[CurrentTimeZone]("current_timezone"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 97aacb3f7530c..9953b780ceace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -73,6 +73,21 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { } } +@ExpressionDescription( + usage = "_FUNC_() - Returns the current session local timezone.", + examples = """ + Examples: + > SELECT _FUNC_(); + Asia/Shanghai + """, + group = "datetime_funcs", + since = "3.1.0") +case class CurrentTimeZone() extends LeafExpression with Unevaluable { + override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "current_timezone" +} + /** * Returns the current date at the start of query evaluation. * There is no code generation since this expression should get constant folded by the optimizer. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 9aa7e3201ab1b..1f2389176d1e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -75,6 +76,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { val timeExpr = CurrentTimestamp() val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long] val currentTime = Literal.create(timestamp, timeExpr.dataType) + val timezone = Literal.create(SQLConf.get.sessionLocalTimeZone, StringType) plan transformAllExpressions { case currentDate @ CurrentDate(Some(timeZoneId)) => @@ -84,6 +86,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { DateType) }) case CurrentTimestamp() | Now() => currentTime + case CurrentTimeZone() => timezone } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala index db0399d2a73ee..82d6757407b51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.catalyst.optimizer import java.time.ZoneId import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -67,4 +69,16 @@ class ComputeCurrentTimeSuite extends PlanTest { assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } + + test("SPARK-33469: Add current_timezone function") { + val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val lits = new scala.collection.mutable.ArrayBuffer[String] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[UTF8String].toString + e + } + assert(lits.size == 1) + assert(lits.head == SQLConf.get.sessionLocalTimeZone) + } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index da83df4994d8d..0a54dff3a1cea 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 341 + 
- Number of queries: 342 - Number of expressions that missing example: 13 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window ## Schema of Built-in Functions @@ -86,6 +86,7 @@ | org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | +| org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimestamp | current_timestamp | SELECT current_timestamp() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6a1378837ea9b..953a58760cd5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1947,6 +1947,14 @@ class DatasetSuite extends QueryTest df.where($"zoo".contains(Array('a', 'b'))), Seq(Row("abc"))) } + + test("SPARK-33469: Add current_timezone function") { + val df = Seq(1).toDF("c") + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Shanghai") { + val timezone = df.selectExpr("current_timezone()").collect().head.getString(0) + assert(timezone == "Asia/Shanghai") + } + } } object AssertExecutionId { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 9f62ff8301ebc..6085c1f2cccb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -149,6 +149,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { "org.apache.spark.sql.catalyst.expressions.UnixTimestamp", "org.apache.spark.sql.catalyst.expressions.CurrentDate", "org.apache.spark.sql.catalyst.expressions.CurrentTimestamp", + "org.apache.spark.sql.catalyst.expressions.CurrentTimeZone", "org.apache.spark.sql.catalyst.expressions.Now", // Random output without a seed "org.apache.spark.sql.catalyst.expressions.Rand", From df4a1c2256b71c9a1bd2006819135f56c99a2f21 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 22 Nov 2020 16:40:54 -0800 Subject: [PATCH 22/38] [SPARK-33512][BUILD] Upgrade test libraries ### What changes were proposed in this pull request? This PR aims to update the test libraries. - ScalaTest: 3.2.0 -> 3.2.3 - JUnit: 4.12 -> 4.13.1 - Mockito: 3.1.0 -> 3.4.6 - JMock: 2.8.4 -> 2.12.0 - maven-surefire-plugin: 3.0.0-M3 -> 3.0.0-M5 - scala-maven-plugin: 4.3.0 -> 4.4.0 ### Why are the changes needed? This will make the test frameworks up-to-date for Apache Spark 3.1.0. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #30456 from dongjoon-hyun/SPARK-33512. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pom.xml b/pom.xml index 85cf5a00b0b24..0ab5a8c5b3efa 100644 --- a/pom.xml +++ b/pom.xml @@ -931,7 +931,7 @@ org.scalatest scalatest_${scala.binary.version} - 3.2.0 + 3.2.3 test @@ -955,14 +955,14 @@ org.mockito mockito-core - 3.1.0 + 3.4.6 test org.jmock jmock-junit4 test - 2.8.4 + 2.12.0 org.scalacheck @@ -973,7 +973,7 @@ junit junit - 4.12 + 4.13.1 test @@ -2498,7 +2498,7 @@ net.alchim31.maven scala-maven-plugin - 4.3.0 + 4.4.0 eclipse-add-source @@ -2573,7 +2573,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M3 + 3.0.0-M5 From a45923852342ce3f9454743a71740b09e6efe859 Mon Sep 17 00:00:00 2001 From: William Hyun Date: Mon, 23 Nov 2020 10:38:40 +0900 Subject: [PATCH 23/38] [MINOR][INFRA] Suppress warning in check-license ### What changes were proposed in this pull request? This PR aims to suppress the `File exists` warning in check-license. ### Why are the changes needed? **BEFORE** ``` % dev/check-license Attempting to fetch rat RAT checks passed. % dev/check-license mkdir: target: File exists RAT checks passed. ``` **AFTER** ``` % dev/check-license Attempting to fetch rat RAT checks passed. % dev/check-license RAT checks passed. ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually run dev/check-license twice. Closes #30460 from williamhyun/checklicense. Authored-by: William Hyun Signed-off-by: HyukjinKwon --- dev/check-license | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/check-license b/dev/check-license index 0cc17ffe55c67..bd255954d6db4 100755 --- a/dev/check-license +++ b/dev/check-license @@ -67,7 +67,7 @@ mkdir -p "$FWDIR"/lib exit 1 } -mkdir target +mkdir -p target $java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/.rat-excludes -d "$FWDIR" > target/rat-results.txt if [ $? -ne 0 ]; then From aa78c05edc9cb910cca9fb14f7670559fe00c62d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Nov 2020 10:42:28 +0900 Subject: [PATCH 24/38] [SPARK-33427][SQL][FOLLOWUP] Put key and value into IdentityHashMap sequentially ### What changes were proposed in this pull request? This follow-up fixes an issue when inserting key/value pairs into `IdentityHashMap` in `SubExprEvaluationRuntime`. ### Why are the changes needed? The last commits to #30341 followed a review comment to use `IdentityHashMap`. Because we leverage `IdentityHashMap` to compare keys by reference, we should not convert the expression pairs to a Scala map before inserting them. A Scala map compares keys by equality, so we would lose keys that differ only by reference. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Run benchmark to verify. Closes #30459 from viirya/SPARK-33427-map.
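To make the reference-vs-equality distinction concrete, here is a hedged, self-contained sketch in plain Scala (the `Expr` case class is a stand-in for a Catalyst expression, not the real class):

```scala
import java.util.IdentityHashMap

// Stand-in for a Catalyst expression: equal by value, distinct by reference.
case class Expr(sql: String)

val e1 = Expr("a + b")
val e2 = Expr("a + b")

// A Scala Map deduplicates keys by equality, so only one entry survives.
val byEquality = Map(e1 -> "proxy", e2 -> "proxy")
assert(byEquality.size == 1)

// IdentityHashMap compares keys by reference, so both entries are kept,
// which is the behavior SubExprEvaluationRuntime relies on.
val byReference = new IdentityHashMap[Expr, String]()
byReference.put(e1, "proxy")
byReference.put(e2, "proxy")
assert(byReference.size == 2)
```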
Authored-by: Liang-Chi Hsieh Signed-off-by: HyukjinKwon --- .../SubExprEvaluationRuntime.scala | 9 +++++--- .../SubExprEvaluationRuntimeSuite.scala | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index 3189d81289903..ff9c4cf3147d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.IdentityHashMap -import scala.collection.JavaConverters._ - import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} @@ -98,7 +96,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this) proxyExpressionCurrentId += 1 - proxyMap.putAll(e.map(_ -> proxy).toMap.asJava) + // We leverage `IdentityHashMap` so we compare expression keys by reference here. + // So for example if there are one group of common exprs like Seq(common expr 1, + // common expr2, ..., common expr n), we will insert into `proxyMap` some key/value + // pairs like Map(common expr 1 -> proxy(common expr 1), ..., + // common expr n -> proxy(common expr 1)). + e.map(proxyMap.put(_, proxy)) } // Only adding proxy if we find subexpressions. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala index 64b619ca7766b..f8dca266a62d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala @@ -95,4 +95,26 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite { }) assert(proxys.isEmpty) } + + test("SubExprEvaluationRuntime should wrap semantically equal exprs") { + val runtime = new SubExprEvaluationRuntime(1) + + val one = Literal(1) + val two = Literal(2) + def mul: (Literal, Literal) => Expression = + (left: Literal, right: Literal) => Multiply(left, right) + + val mul2_1 = Multiply(mul(one, two), mul(one, two)) + val mul2_2 = Multiply(mul(one, two), mul(one, two)) + + val sqrt = Sqrt(mul2_1) + val sum = Add(mul2_2, sqrt) + val proxyExpressions = runtime.proxyExpressions(Seq(sum)) + val proxys = proxyExpressions.flatMap(_.collect { + case p: ExpressionProxy => p + }) + // ( (one * two) * (one * two) ) + assert(proxys.size == 2) + assert(proxys.forall(_.child.semanticEquals(mul2_1))) + } } From 0bb911d979955ac59adc39818667b616eb539103 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 23 Nov 2020 15:19:34 +0900 Subject: [PATCH 25/38] [SPARK-33143][PYTHON] Add configurable timeout to python server and client ### What changes were proposed in this pull request? Spark creates local server to serialize several type of data for python. 
The python code tries to connect to the server, immediately after it's created but there are several system calls in between (this may change in each Spark version): * getaddrinfo * socket * settimeout * connect Under some circumstances in heavy user environments these calls can be super slow (more than 15 seconds). These issues must be analyzed one-by-one but since these are system calls the underlying OS and/or DNS servers must be debugged and fixed. This is not trivial task and at the same time data processing must work somehow. In this PR I'm only intended to add a configuration possibility to increase the mentioned timeouts in order to be able to provide temporary workaround. The rootcause analysis is ongoing but I think this can vary in each case. Because the server part doesn't contain huge amount of log entries to with one can measure time, I've added some. ### Why are the changes needed? Provide workaround when localhost python server connection timeout appears. ### Does this PR introduce _any_ user-facing change? Yes, new configuration added. ### How was this patch tested? Existing unit tests + manual test. ``` #Compile Spark echo "spark.io.encryption.enabled true" >> conf/spark-defaults.conf echo "spark.python.authenticate.socketTimeout 10" >> conf/spark-defaults.conf $ ./bin/pyspark Python 3.8.5 (default, Jul 21 2020, 10:48:26) [Clang 11.0.3 (clang-1103.0.32.62)] on darwin Type "help", "copyright", "credits" or "license" for more information. 20/11/20 10:17:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). 20/11/20 10:17:03 WARN SparkEnv: I/O encryption enabled without RPC encryption: keys will be visible on the wire. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 3.1.0-SNAPSHOT /_/ Using Python version 3.8.5 (default, Jul 21 2020 10:48:26) Spark context Web UI available at http://192.168.0.189:4040 Spark context available as 'sc' (master = local[*], app id = local-1605863824276). SparkSession available as 'spark'. >>> sc.setLogLevel("TRACE") >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() 20/11/20 10:17:09 TRACE PythonParallelizeServer: Creating listening socket 20/11/20 10:17:09 TRACE PythonParallelizeServer: Setting timeout to 10 sec 20/11/20 10:17:09 TRACE PythonParallelizeServer: Waiting for connection on port 59726 20/11/20 10:17:09 TRACE PythonParallelizeServer: Connection accepted from address /127.0.0.1:59727 20/11/20 10:17:09 TRACE PythonParallelizeServer: Client authenticated 20/11/20 10:17:09 TRACE PythonParallelizeServer: Closing server ... 20/11/20 10:17:10 TRACE SocketFuncServer: Creating listening socket 20/11/20 10:17:10 TRACE SocketFuncServer: Setting timeout to 10 sec 20/11/20 10:17:10 TRACE SocketFuncServer: Waiting for connection on port 59735 20/11/20 10:17:10 TRACE SocketFuncServer: Connection accepted from address /127.0.0.1:59736 20/11/20 10:17:10 TRACE SocketFuncServer: Client authenticated 20/11/20 10:17:10 TRACE SocketFuncServer: Closing server [[0], [2], [3], [4], [6]] >>> ``` Closes #30389 from gaborgsomogyi/SPARK-33143. 
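For completeness, a hedged sketch of applying the same workaround programmatically instead of through `spark-defaults.conf`; the config key comes from the patch below, while the 60-second value is purely illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Set before the SparkContext starts so the JVM-side socket servers and the
// Python clients (via SPARK_AUTH_SOCKET_TIMEOUT) both pick up the longer timeout.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.python.authenticate.socketTimeout", "60s")
  .getOrCreate()
```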
Lead-authored-by: Gabor Somogyi Co-authored-by: Hyukjin Kwon Co-authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- .../apache/spark/api/python/PythonRunner.scala | 2 ++ .../apache/spark/api/python/PythonUtils.scala | 4 ++++ .../apache/spark/internal/config/Python.scala | 6 ++++++ .../spark/security/SocketAuthHelper.scala | 2 +- .../spark/security/SocketAuthServer.scala | 17 +++++++++++++---- python/pyspark/context.py | 2 ++ python/pyspark/java_gateway.py | 2 +- 7 files changed, 29 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 136da80d48dee..f49cb3c2b8836 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -80,6 +80,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) protected val simplifiedTraceback: Boolean = false @@ -139,6 +140,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (workerMemoryMb.isDefined) { envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. When any codes try to release or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 527d0d6d3a48d..33849f6fcb65f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -85,4 +85,8 @@ private[spark] object PythonUtils { def getBroadcastThreshold(sc: JavaSparkContext): Long = { sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) } + + def getPythonAuthSocketTimeout(sc: JavaSparkContext): Long = { + sc.conf.get(org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala index 188d884319644..348a33e129d65 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala @@ -50,4 +50,10 @@ private[spark] object Python { .version("2.4.0") .bytesConf(ByteUnit.MiB) .createOptional + + val PYTHON_AUTH_SOCKET_TIMEOUT = ConfigBuilder("spark.python.authenticate.socketTimeout") + .internal() + .version("3.1.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("15s") } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index dbcb376905338..f800553c5388b 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils * * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. 
*/ -private[spark] class SocketAuthHelper(conf: SparkConf) { +private[spark] class SocketAuthHelper(val conf: SparkConf) { val secret = Utils.createSecret(conf) diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala index 548fd1b07ddc5..35990b5a59281 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala @@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration import scala.util.Try import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,11 +36,11 @@ import org.apache.spark.util.{ThreadUtils, Utils} * handling one batch of data, with authentication and error handling. * * The socket server can only accept one connection, or close if no connection - * in 15 seconds. + * in configurable amount of seconds (default 15). */ private[spark] abstract class SocketAuthServer[T]( authHelper: SocketAuthHelper, - threadName: String) { + threadName: String) extends Logging { def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName) def this(threadName: String) = this(SparkEnv.get, threadName) @@ -46,19 +48,26 @@ private[spark] abstract class SocketAuthServer[T]( private val promise = Promise[T]() private def startServer(): (Int, String) = { + logTrace("Creating listening socket") val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) - // Close the socket if no connection in 15 seconds - serverSocket.setSoTimeout(15000) + // Close the socket if no connection in the configured seconds + val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt + logTrace(s"Setting timeout to $timeout sec") + serverSocket.setSoTimeout(timeout * 1000) new Thread(threadName) { setDaemon(true) override def run(): Unit = { var sock: Socket = null try { + logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}") sock = serverSocket.accept() + logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}") authHelper.authClient(sock) + logTrace("Client authenticated") promise.complete(Try(handleConnection(sock))) } finally { + logTrace("Closing server") JavaUtils.closeQuietly(serverSocket) JavaUtils.closeQuietly(sock) } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 9c9e3f4b3c881..1bd5961e0525a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # data via a socket. # scala's mangled names w/ $ in them require special treatment. 
self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc) + os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \ + str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index eafa5d90f9ff8..fe2e326dff8be 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret): af, socktype, proto, _, sa = res try: sock = socket.socket(af, socktype, proto) - sock.settimeout(15) + sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))) sock.connect(sa) sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))) _do_server_auth(sockfile, auth_secret) From 84e70362dbf2bbebc7f1a1b734b99952d7e95e4d Mon Sep 17 00:00:00 2001 From: William Hyun Date: Sun, 22 Nov 2020 22:56:59 -0800 Subject: [PATCH 26/38] [SPARK-33510][BUILD] Update SBT to 1.4.4 ### What changes were proposed in this pull request? This PR aims to update SBT from 1.4.2 to 1.4.4. ### Why are the changes needed? This will bring the latest bug fixes. - https://github.com/sbt/sbt/releases/tag/v1.4.3 - https://github.com/sbt/sbt/releases/tag/v1.4.4 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #30453 from williamhyun/sbt143. Authored-by: William Hyun Signed-off-by: Dongjoon Hyun --- dev/mima | 4 ++-- project/build.properties | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/mima b/dev/mima index f324c5c00a45c..d214bb96e09a3 100755 --- a/dev/mima +++ b/dev/mima @@ -25,8 +25,8 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"} -TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" -OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" +TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | grep jar | tail -n1)" +OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | grep jar | tail -n1)" rm -f .generated-mima* diff --git a/project/build.properties b/project/build.properties index 5ec1d700fd2a8..c92de941c10be 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=1.4.2 +sbt.version=1.4.4 From c891e025b8ed34392fbc81e988b75bdbdb268c11 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 23 Nov 2020 17:43:58 +0900 Subject: [PATCH 27/38] Revert "[SPARK-32481][CORE][SQL] Support truncate table to move data to trash" ### What changes were proposed in this pull request? This reverts commit 065f17386d1851d732b4c1badf1ce2e14d0de338, which is not part of any released version. That is, this is an unreleased feature ### Why are the changes needed? I like the concept of Trash, but I think this PR might just resolve a very specific issue by introducing a mechanism without a proper design doc. This could make the usage more complex. I think we need to consider the big picture. Trash directory is an important concept. 
If we decide to introduce it, we should consider all the code paths of Spark SQL that could delete the data, instead of Truncate only. We also need to consider what is the current behavior if the underlying file system does not provide the API `Trash.moveToAppropriateTrash`. Is the exception good? How about the performance when users are using the object store instead of HDFS? Will it impact the GDPR compliance? In sum, I think we should not merge the PR https://github.com/apache/spark/pull/29552 without the design doc and implementation plan. That is why I reverted it before the code freeze of Spark 3.1 ### Does this PR introduce _any_ user-facing change? Reverted the original commit ### How was this patch tested? The existing tests. Closes #30463 from gatorsmile/revertSpark-32481. Authored-by: Xiao Li Signed-off-by: HyukjinKwon --- .../scala/org/apache/spark/util/Utils.scala | 25 +----- .../apache/spark/sql/internal/SQLConf.scala | 14 ---- .../spark/sql/execution/command/tables.scala | 4 +- .../sql/execution/command/DDLSuite.scala | 78 ------------------- 4 files changed, 2 insertions(+), 119 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6ccf65b737c1a..71a310a4279ad 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -50,7 +50,7 @@ import com.google.common.net.InetAddresses import org.apache.commons.codec.binary.Hex import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FileUtil, Path, Trash} +import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -269,29 +269,6 @@ private[spark] object Utils extends Logging { file.setExecutable(true, true) } - /** - * Move data to trash if 'spark.sql.truncate.trash.enabled' is true, else - * delete the data permanently. If move data to trash failed fallback to hard deletion. - */ - def moveToTrashOrDelete( - fs: FileSystem, - partitionPath: Path, - isTrashEnabled: Boolean, - hadoopConf: Configuration): Boolean = { - if (isTrashEnabled) { - logDebug(s"Try to move data ${partitionPath.toString} to trash") - val isSuccess = Trash.moveToAppropriateTrash(fs, partitionPath, hadoopConf) - if (!isSuccess) { - logWarning(s"Failed to move data ${partitionPath.toString} to trash. " + - "Fallback to hard deletion") - return fs.delete(partitionPath, true) - } - isSuccess - } else { - fs.delete(partitionPath, true) - } - } - /** * Create a directory given the abstract pathname * @return true, if the directory is successfully created; otherwise, return false. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fcf222c8fdab0..ef974dc176e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2913,18 +2913,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val TRUNCATE_TRASH_ENABLED = - buildConf("spark.sql.truncate.trash.enabled") - .doc("This configuration decides when truncating table, whether data files will be moved " + - "to trash directory or deleted permanently. 
The trash retention time is controlled by " + - "'fs.trash.interval', and in default, the server side configuration value takes " + - "precedence over the client-side one. Note that if 'fs.trash.interval' is non-positive, " + - "this will be a no-op and log a warning message. If the data fails to be moved to " + - "trash, Spark will turn to delete it permanently.") - .version("3.1.0") - .booleanConf - .createWithDefault(false) - val DISABLED_JDBC_CONN_PROVIDER_LIST = buildConf("spark.sql.sources.disabledJdbcConnProviderList") .internal() @@ -3577,8 +3565,6 @@ class SQLConf extends Serializable with Logging { def legacyPathOptionBehavior: Boolean = getConf(SQLConf.LEGACY_PATH_OPTION_BEHAVIOR) - def truncateTrashEnabled: Boolean = getConf(SQLConf.TRUNCATE_TRASH_ENABLED) - def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 206f952fed0ca..847052cd4fcde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -48,7 +48,6 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils -import org.apache.spark.util.Utils /** * A command to create a table with the same definition of the given existing table. @@ -490,7 +489,6 @@ case class TruncateTableCommand( } val hadoopConf = spark.sessionState.newHadoopConf() val ignorePermissionAcl = SQLConf.get.truncateTableIgnorePermissionAcl - val isTrashEnabled = SQLConf.get.truncateTrashEnabled locations.foreach { location => if (location.isDefined) { val path = new Path(location.get) @@ -515,7 +513,7 @@ case class TruncateTableCommand( } } - Utils.moveToTrashOrDelete(fs, path, isTrashEnabled, hadoopConf) + fs.delete(path, true) // We should keep original permission/acl of the path. // For owner/group, only super-user can set it, for example on HDFS. 
Because diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9d0147048dbb8..43a33860d262e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -3104,84 +3104,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(spark.sessionState.catalog.isRegisteredFunction(rand)) } } - - test("SPARK-32481 Move data to trash on truncate table if enabled") { - val trashIntervalKey = "fs.trash.interval" - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - // scalastyle:off hadoopconfiguration - val hadoopConf = spark.sparkContext.hadoopConfiguration - // scalastyle:on hadoopconfiguration - val originalValue = hadoopConf.get(trashIntervalKey, "0") - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - assume( - fs.mkdirs(trashPath) && fs.delete(trashPath, false), - "Trash directory could not be created, skipping.") - assert(!fs.exists(trashPath)) - try { - hadoopConf.set(trashIntervalKey, "5") - sql("TRUNCATE TABLE tab1") - } finally { - hadoopConf.set(trashIntervalKey, originalValue) - } - assert(fs.exists(trashPath)) - fs.delete(trashPath, true) - } - } - } - - test("SPARK-32481 delete data permanently on truncate table if trash interval is non-positive") { - val trashIntervalKey = "fs.trash.interval" - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - // scalastyle:off hadoopconfiguration - val hadoopConf = spark.sparkContext.hadoopConfiguration - // scalastyle:on hadoopconfiguration - val originalValue = hadoopConf.get(trashIntervalKey, "0") - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - assert(!fs.exists(trashPath)) - try { - hadoopConf.set(trashIntervalKey, "0") - sql("TRUNCATE TABLE tab1") - } finally { - hadoopConf.set(trashIntervalKey, originalValue) - } - assert(!fs.exists(trashPath)) - } - } - } - - test("SPARK-32481 Do not move data to trash on truncate table if disabled") { - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "false") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - val hadoopConf = spark.sessionState.newHadoopConf() - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - sql("TRUNCATE TABLE tab1") - assert(!fs.exists(trashPath)) - } - } - } } object FakeLocalFsFileSystem { From 60f3a730e4e67c3b67d6e45fb18f589ad66b07e6 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 23 Nov 2020 
08:54:00 +0000 Subject: [PATCH 28/38] [SPARK-33515][SQL] Improve exception messages while handling UnresolvedTable ### What changes were proposed in this pull request? This PR proposes to improve the exception messages while `UnresolvedTable` is handled based on this suggestion: https://github.com/apache/spark/pull/30321#discussion_r521127001. Currently, when an identifier is resolved to a view when a table is expected, the following exception message is displayed (e.g., for `COMMENT ON TABLE`): ``` v is a temp view not table. ``` After this PR, the message will be: ``` v is a temp view. 'COMMENT ON TABLE' expects a table. ``` Also, if an identifier is not resolved, the following exception message is currently used: ``` Table not found: t ``` After this PR, the message will be: ``` Table not found for 'COMMENT ON TABLE': t ``` ### Why are the changes needed? To improve the exception message. ### Does this PR introduce _any_ user-facing change? Yes, the exception message will be changed as described above. ### How was this patch tested? Updated existing tests. Closes #30461 from imback82/unresolved_table_message. Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../catalyst/analysis/v2ResolutionPlans.scala | 4 +++- .../spark/sql/catalyst/parser/AstBuilder.scala | 12 ++++++++---- .../sql/catalyst/parser/DDLParserSuite.scala | 18 +++++++++--------- .../sql/connector/DataSourceV2SQLSuite.scala | 3 ++- .../spark/sql/execution/SQLViewSuite.scala | 8 ++++---- .../sql/hive/execution/HiveDDLSuite.scala | 4 ++-- 8 files changed, 34 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 53c0ff687c6d2..837686420375a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -861,9 +861,9 @@ class Analyzer(override val catalogManager: CatalogManager) }.getOrElse(write) case _ => write } - case u @ UnresolvedTable(ident) => + case u @ UnresolvedTable(ident, cmd) => lookupTempView(ident).foreach { _ => - u.failAnalysis(s"${ident.quoted} is a temp view not table.") + u.failAnalysis(s"${ident.quoted} is a temp view. '$cmd' expects a table") } u case u @ UnresolvedTableOrView(ident, allowTempView) => @@ -950,7 +950,7 @@ class Analyzer(override val catalogManager: CatalogManager) SubqueryAlias(catalog.get.name +: ident.namespace :+ ident.name, relation) }.getOrElse(u) - case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) => + case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident), _) => CatalogV2Util.loadTable(catalog, ident) .map(ResolvedTable(catalog.asTableCatalog, ident, _)) .getOrElse(u) @@ -1077,11 +1077,11 @@ class Analyzer(override val catalogManager: CatalogManager) lookupRelation(u.multipartIdentifier, u.options, u.isStreaming) .map(resolveViews).getOrElse(u) - case u @ UnresolvedTable(identifier) => + case u @ UnresolvedTable(identifier, cmd) => lookupTableOrView(identifier).map { case v: ResolvedView => val viewStr = if (v.isTemp) "temp view" else "view" - u.failAnalysis(s"${v.identifier.quoted} is a $viewStr not table.") + u.failAnalysis(s"${v.identifier.quoted} is a $viewStr. 
'$cmd' expects a table.'") case table => table }.getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 452ba80b23441..9998035d65c3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -98,7 +98,7 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") case u: UnresolvedTable => - u.failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}") + u.failAnalysis(s"Table not found for '${u.commandName}': ${u.multipartIdentifier.quoted}") case u: UnresolvedTableOrView => u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index 98bd84fb94bd6..0e883a88f2691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -37,7 +37,9 @@ case class UnresolvedNamespace(multipartIdentifier: Seq[String]) extends LeafNod * Holds the name of a table that has yet to be looked up in a catalog. It will be resolved to * [[ResolvedTable]] during analysis. */ -case class UnresolvedTable(multipartIdentifier: Seq[String]) extends LeafNode { +case class UnresolvedTable( + multipartIdentifier: Seq[String], + commandName: String) extends LeafNode { override lazy val resolved: Boolean = false override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 23de8ab09dd0a..ea4baafbacede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3303,7 +3303,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitLoadData(ctx: LoadDataContext): LogicalPlan = withOrigin(ctx) { LoadData( - child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier), "LOAD DATA"), path = string(ctx.path), isLocal = ctx.LOCAL != null, isOverwrite = ctx.OVERWRITE != null, @@ -3449,7 +3449,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg UnresolvedPartitionSpec(spec, location) } AlterTableAddPartition( - UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + UnresolvedTable( + visitMultipartIdentifier(ctx.multipartIdentifier), + "ALTER TABLE ... ADD PARTITION ..."), specsAndLocs.toSeq, ctx.EXISTS != null) } @@ -3491,7 +3493,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val partSpecs = ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec) .map(spec => UnresolvedPartitionSpec(spec)) AlterTableDropPartition( - UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + UnresolvedTable( + visitMultipartIdentifier(ctx.multipartIdentifier), + "ALTER TABLE ... 
DROP PARTITION ..."), partSpecs.toSeq, ifExists = ctx.EXISTS != null, purge = ctx.PURGE != null, @@ -3720,6 +3724,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case _ => string(ctx.STRING) } val nameParts = visitMultipartIdentifier(ctx.multipartIdentifier) - CommentOnTable(UnresolvedTable(nameParts), comment) + CommentOnTable(UnresolvedTable(nameParts, "COMMENT ON TABLE"), comment) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index f93c0dcf59f4c..bd28484b23f46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1555,15 +1555,15 @@ class DDLParserSuite extends AnalysisTest { test("LOAD DATA INTO table") { comparePlans( parsePlan("LOAD DATA INPATH 'filepath' INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", false, false, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", false, false, None)) comparePlans( parsePlan("LOAD DATA LOCAL INPATH 'filepath' INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, false, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, false, None)) comparePlans( parsePlan("LOAD DATA LOCAL INPATH 'filepath' OVERWRITE INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, true, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, true, None)) comparePlans( parsePlan( @@ -1572,7 +1572,7 @@ class DDLParserSuite extends AnalysisTest { |PARTITION(ds='2017-06-10') """.stripMargin), LoadData( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, true, @@ -1674,13 +1674,13 @@ class DDLParserSuite extends AnalysisTest { val parsed2 = parsePlan(sql2) val expected1 = AlterTableAddPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."), Seq( UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"), None)), ifNotExists = true) val expected2 = AlterTableAddPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."), Seq(UnresolvedPartitionSpec(Map("dt" -> "2008-08-08"), Some("loc"))), ifNotExists = false) @@ -1747,7 +1747,7 @@ class DDLParserSuite extends AnalysisTest { assertUnsupported(sql2_view) val expected1_table = AlterTableDropPartition( - UnresolvedTable(Seq("table_name")), + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP PARTITION ..."), Seq( UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us")), UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"))), @@ -1763,7 +1763,7 @@ class DDLParserSuite extends AnalysisTest { val sql3_table = "ALTER TABLE a.b.c DROP IF EXISTS PARTITION (ds='2017-06-10')" val expected3_table = AlterTableDropPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... 
DROP PARTITION ..."), Seq(UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"))), ifExists = true, purge = false, @@ -2174,7 +2174,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("COMMENT ON TABLE a.b.c IS 'xYz'"), - CommentOnTable(UnresolvedTable(Seq("a", "b", "c")), "xYz")) + CommentOnTable(UnresolvedTable(Seq("a", "b", "c"), "COMMENT ON TABLE"), "xYz")) } // TODO: ignored by SPARK-31707, restore the test after create table syntax unification diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 90df4ee08bfc0..da53936239de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2414,7 +2414,8 @@ class DataSourceV2SQLSuite withTempView("v") { sql("create global temp view v as select 1") val e = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL")) - assert(e.getMessage.contains("global_temp.v is a temp view not table.")) + assert(e.getMessage.contains( + "global_temp.v is a temp view. 'COMMENT ON TABLE' expects a table")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 792f920ee0217..504cc57dc12d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -147,10 +147,10 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { s"'$viewName' is a view not a table") assertAnalysisError( s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')", - s"$viewName is a temp view not table") + s"$viewName is a temp view. 'ALTER TABLE ... ADD PARTITION ...' expects a table") assertAnalysisError( s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')", - s"$viewName is a temp view not table") + s"$viewName is a temp view. 'ALTER TABLE ... DROP PARTITION ...' expects a table") // For the following v2 ALERT TABLE statements, unsupported operations are checked first // before resolving the relations. @@ -175,7 +175,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { val e2 = intercept[AnalysisException] { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage - assert(e2.contains(s"$viewName is a temp view not table")) + assert(e2.contains(s"$viewName is a temp view. 'LOAD DATA' expects a table")) assertNoSuchTable(s"TRUNCATE TABLE $viewName") val e3 = intercept[AnalysisException] { sql(s"SHOW CREATE TABLE $viewName") @@ -214,7 +214,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { e = intercept[AnalysisException] { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage - assert(e.contains("default.testView is a view not table")) + assert(e.contains("default.testView is a view. 
'LOAD DATA' expects a table")) e = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $viewName") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 1f15bd685b239..56b871644453b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -904,10 +904,10 @@ class HiveDDLSuite assertAnalysisError( s"ALTER TABLE $oldViewName ADD IF NOT EXISTS PARTITION (a='4', b='8')", - s"$oldViewName is a view not table") + s"$oldViewName is a view. 'ALTER TABLE ... ADD PARTITION ...' expects a table.") assertAnalysisError( s"ALTER TABLE $oldViewName DROP IF EXISTS PARTITION (a='2')", - s"$oldViewName is a view not table") + s"$oldViewName is a view. 'ALTER TABLE ... DROP PARTITION ...' expects a table.") assert(catalog.tableExists(TableIdentifier(tabName))) assert(catalog.tableExists(TableIdentifier(oldViewName))) From 23e9920b3910e4f05269853429c7f18888cdc7b5 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 23 Nov 2020 09:00:41 +0000 Subject: [PATCH 29/38] [SPARK-33511][SQL] Respect case sensitivity while resolving V2 partition specs ### What changes were proposed in this pull request? 1. Pre-process partition specs in `ResolvePartitionSpec`, and convert partition names according to the partition schema and the SQL config `spark.sql.caseSensitive`. In the PR, I propose to invoke `normalizePartitionSpec` for that. The function is used in DSv1 commands, so, the behavior will be similar to DSv1. 2. Move `normalizePartitionSpec()` from `sql/core/.../datasources/PartitioningUtils` to `sql/catalyst/.../util/PartitioningUtils` to use it in Catalyst's rule `ResolvePartitionSpec` ### Why are the changes needed? DSv1 commands like `ALTER TABLE .. ADD PARTITION` and `ALTER TABLE .. DROP PARTITION` respect the SQL config `spark.sql.caseSensitive` while resolving partition specs. For example: ```sql spark-sql> CREATE TABLE tbl1 (id bigint, data string) USING parquet PARTITIONED BY (id); spark-sql> ALTER TABLE tbl1 ADD PARTITION (ID=1); spark-sql> SHOW PARTITIONS tbl1; id=1 ``` The same command fails on V2 Table catalog with error: ``` AnalysisException: Partition key ID not exists ``` ### Does this PR introduce _any_ user-facing change? Yes. After the changes, partition spec resolution works as for DSv1 (without the exception showed above). ### How was this patch tested? By running `AlterTablePartitionV2SQLSuite`. Closes #30454 from MaxGekk/partition-spec-case-sensitivity. 
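For illustration only, here is a minimal, self-contained Scala sketch of the normalization idea behind `normalizePartitionSpec` (the `Resolver` alias, object name, and error message below are simplified placeholders, not the exact Catalyst API):

```scala
object PartitionSpecNormalizationSketch {
  // Simplified stand-in for Catalyst's Resolver: decides whether two names match.
  type Resolver = (String, String) => Boolean
  val caseSensitiveResolver: Resolver = (a, b) => a == b
  val caseInsensitiveResolver: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Map each user-supplied partition key to the real partition column name,
  // failing if no column matches under the given resolver.
  def normalize(
      spec: Map[String, String],
      partCols: Seq[String],
      resolver: Resolver): Map[String, String] =
    spec.map { case (key, value) =>
      val normalizedKey = partCols.find(resolver(_, key)).getOrElse {
        throw new IllegalArgumentException(s"$key is not a valid partition column")
      }
      normalizedKey -> value
    }

  def main(args: Array[String]): Unit = {
    val spec = Map("ID" -> "1")
    // Case-insensitive: "ID" is normalized to the actual column name "id".
    println(normalize(spec, Seq("id"), caseInsensitiveResolver)) // Map(id -> 1)
    // Case-sensitive: the same spec is rejected, matching the DSv1 behavior.
    // normalize(spec, Seq("id"), caseSensitiveResolver)         // throws
  }
}
```

With `spark.sql.caseSensitive=false` the rule behaves like the case-insensitive resolver above; with `true` it behaves like the case-sensitive one.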
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../analysis/ResolvePartitionSpec.scala | 27 +++++++---- .../spark/sql/util/PartitioningUtils.scala | 47 +++++++++++++++++++ .../command/AnalyzePartitionCommand.scala | 2 +- .../spark/sql/execution/command/ddl.scala | 3 +- .../spark/sql/execution/command/tables.scala | 3 +- .../datasources/PartitioningUtils.scala | 26 +--------- .../sql/execution/datasources/rules.scala | 3 +- .../AlterTablePartitionV2SQLSuite.scala | 26 ++++++++++ 8 files changed, 98 insertions(+), 39 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 5e19a32968992..531d40f431dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, Alte import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec /** * Resolve [[UnresolvedPartitionSpec]] to [[ResolvedPartitionSpec]] in partition related commands. @@ -33,32 +34,38 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case r @ AlterTableAddPartition( ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _) => - r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema())) + r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema())) case r @ AlterTableDropPartition( ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _, _, _) => - r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema())) + r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema())) } private def resolvePartitionSpecs( - partSpecs: Seq[PartitionSpec], partSchema: StructType): Seq[ResolvedPartitionSpec] = + tableName: String, + partSpecs: Seq[PartitionSpec], + partSchema: StructType): Seq[ResolvedPartitionSpec] = partSpecs.map { case unresolvedPartSpec: UnresolvedPartitionSpec => ResolvedPartitionSpec( - convertToPartIdent(unresolvedPartSpec.spec, partSchema), unresolvedPartSpec.location) + convertToPartIdent(tableName, unresolvedPartSpec.spec, partSchema), + unresolvedPartSpec.location) case resolvedPartitionSpec: ResolvedPartitionSpec => resolvedPartitionSpec } private def convertToPartIdent( - partSpec: TablePartitionSpec, partSchema: StructType): InternalRow = { - val conflictKeys = partSpec.keys.toSeq.diff(partSchema.map(_.name)) - if (conflictKeys.nonEmpty) { - throw new AnalysisException(s"Partition key ${conflictKeys.mkString(",")} not exists") - } + tableName: String, + partitionSpec: TablePartitionSpec, + partSchema: StructType): InternalRow = { + val normalizedSpec = normalizePartitionSpec( + partitionSpec, + partSchema.map(_.name), + tableName, + conf.resolver) val partValues = partSchema.map { part => - val partValue = partSpec.get(part.name).orNull + val partValue = normalizedSpec.get(part.name).orNull if (partValue == null) { null } else { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala new file mode 100644 index 0000000000000..586aa6c59164f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver + +object PartitioningUtils { + /** + * Normalize the column names in partition specification, w.r.t. the real partition column names + * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a + * partition column named `month`, and it's case insensitive, we will normalize `monTh` to + * `month`. + */ + def normalizePartitionSpec[T]( + partitionSpec: Map[String, T], + partColNames: Seq[String], + tblName: String, + resolver: Resolver): Map[String, T] = { + val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => + val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { + throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") + } + normalizedKey -> value + } + + SchemaUtils.checkColumnNameDuplication( + normalizedPartSpec.map(_._1), "in the partition schema", resolver) + + normalizedPartSpec.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index fc62dce5002b1..0b265bfb63e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, Unresol import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.util.PartitioningUtils /** * Analyzes a given set of partitions to generate per-partition statistics, which will be used in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index d550fe270c753..27ad62026c9b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -39,11 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog} import org.apache.spark.sql.connector.catalog.SupportsNamespaces._ -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} // Note: The definition of these commands are based on the ones described in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 847052cd4fcde..bd238948aab02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier, CaseInsensitiveMap} -import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.sql.util.SchemaUtils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 796c23c7337d8..ea437d200eaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter} @@ -357,30 +357,6 @@ object PartitioningUtils { getPathFragment(spec, StructType.fromAttributes(partitionColumns)) } - /** - * Normalize 
the column names in partition specification, w.r.t. the real partition column names - * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a - * partition column named `month`, and it's case insensitive, we will normalize `monTh` to - * `month`. - */ - def normalizePartitionSpec[T]( - partitionSpec: Map[String, T], - partColNames: Seq[String], - tblName: String, - resolver: Resolver): Map[String, T] = { - val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => - val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { - throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") - } - normalizedKey -> value - } - - SchemaUtils.checkColumnNameDuplication( - normalizedPartSpec.map(_._1), "in the partition schema", resolver) - - normalizedPartSpec.toMap - } - /** * Resolves possible type conflicts between partitions by up-casting "lower" types using * [[findWiderTypeForPartitionColumn]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3a2a642b870f8..9e65b0ce13693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec import org.apache.spark.sql.util.SchemaUtils /** @@ -386,7 +387,7 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { partColNames: Seq[String], catalogTable: Option[CatalogTable]): InsertIntoStatement = { - val normalizedPartSpec = PartitioningUtils.normalizePartitionSpec( + val normalizedPartSpec = normalizePartitionSpec( insert.partitionSpec, partColNames, tblName, conf.resolver) val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala index 107d0ea47249d..e05c2c09ace2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionsException, PartitionsAlreadyExistException} import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits +import org.apache.spark.sql.internal.SQLConf class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { @@ -159,4 +160,29 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { assert(partTable.asPartitionable.listPartitionIdentifiers(InternalRow.empty).isEmpty) } } + + test("case sensitivity in resolving partition specs") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg = intercept[AnalysisException] { + spark.sql(s"ALTER TABLE $t 
ADD PARTITION (ID=1) LOCATION 'loc1'") + }.getMessage + assert(errMsg.contains(s"ID is not a valid partition column in table $t")) + } + + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) + .asPartitionable + assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + spark.sql(s"ALTER TABLE $t ADD PARTITION (ID=1) LOCATION 'loc1'") + assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + spark.sql(s"ALTER TABLE $t DROP PARTITION (Id=1)") + assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + } + } + } } From f83fcb12543049672a54ef5b582d58817e2ee5d3 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Mon, 23 Nov 2020 14:54:44 +0000 Subject: [PATCH 30/38] [SPARK-33278][SQL][FOLLOWUP] Improve OptimizeWindowFunctions to avoid transfer first to nth_value ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/30178 provided `OptimizeWindowFunctions` used to transfer `first` to `nth_value`. If the window frame is `UNBOUNDED PRECEDING AND CURRENT ROW` or `UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`, `nth_value` has better performance than `first`. But the `OptimizeWindowFunctions` need to exclude other window frame. ### Why are the changes needed? Improve `OptimizeWindowFunctions` to avoid transfer `first` to `nth_value` if the specified window frame isn't `UNBOUNDED PRECEDING AND CURRENT ROW` or `UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? Jenkins test. Closes #30419 from beliefer/SPARK-33278_followup. Lead-authored-by: gengjiaan Co-authored-by: beliefer Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 9 +++-- .../OptimizeWindowFunctionsSuite.scala | 33 +++++++++++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c4b9936fa4c4f..9eee7c2b914a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -811,9 +811,12 @@ object CollapseRepartition extends Rule[LogicalPlan] { */ object OptimizeWindowFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), spec) - if spec.orderSpec.nonEmpty && - spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame].frameType == RowFrame => + case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), + WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame)) + if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame && + frameSpecification.lower == UnboundedPreceding && + (frameSpecification.upper == UnboundedFollowing || + frameSpecification.upper == CurrentRow) => we.copy(windowFunction = NthValue(first.child, Literal(1), first.ignoreNulls)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala index 389aaeafe655f..cf850bbe21ce6 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala @@ -36,7 +36,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { val b = testRelation.output(1) val c = testRelation.output(2) - test("replace first(col) by nth_value(col, 1)") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND CURRENT ROW") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -52,7 +52,34 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == correctAnswer) } - test("can't replace first(col) by nth_value(col, 1) if the window frame type is range") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + val correctAnswer = testRelation.select( + WindowExpression( + NthValue(a, Literal(1), false), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == correctAnswer) + } + + test("can't replace first by nth_value if frame is not suitable") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, Literal(1), CurrentRow)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == inputPlan) + } + + test("can't replace first by nth_value if the window frame type is range") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -63,7 +90,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == inputPlan) } - test("can't replace first(col) by nth_value(col, 1) if the window frame isn't ordered") { + test("can't replace first by nth_value if the window frame isn't ordered") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), From 1bd897cbc4fe30eb8b7740c7232aae87081e8e33 Mon Sep 17 00:00:00 2001 From: Ye Zhou Date: Mon, 23 Nov 2020 15:16:20 -0600 Subject: [PATCH 31/38] [SPARK-32918][SHUFFLE] RPC implementation to support control plane coordination for push-based shuffle ### What changes were proposed in this pull request? This is one of the patches for SPIP SPARK-30602 which is needed for push-based shuffle. Summary of changes: This PR introduces a new RPC to be called within Driver. When the expected shuffle push wait time reaches, Driver will call this RPC to facilitate coordination of shuffle map/reduce stages and notify external shuffle services to finalize shuffle block merge for a given shuffle. Shuffle services also respond back the metadata about a merged shuffle partition back to the caller. ### Why are the changes needed? Refer to the SPIP in SPARK-30602. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? This code snippets won't be called by any existing code and will be tested after the coordinated driver changes gets merged in SPARK-32920. Lead-authored-by: Min Shen mshenlinkedin.com Closes #30163 from zhouyejoe/SPARK-32918. 
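As a rough sketch of how a driver-side caller might use the new RPC once the coordination work in SPARK-32920 is in place (the helper object, host/port values, and println handling are illustrative placeholders; `finalizeShuffleMerge` and `MergeFinalizerListener` are the APIs added in the diff below):

```scala
import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
import org.apache.spark.network.shuffle.protocol.MergeStatuses

// Hypothetical driver-side helper: ask one external shuffle service to finalize
// the shuffle merge for `shuffleId` and handle the returned metadata asynchronously.
object FinalizeShuffleMergeSketch {
  def finalizeOn(client: BlockStoreClient, host: String, port: Int, shuffleId: Int): Unit = {
    client.finalizeShuffleMerge(host, port, shuffleId, new MergeFinalizerListener {
      override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
        // In a real scheduler this is where the merged-partition metadata
        // would be recorded before starting the reducer stage.
        println(s"Finalized shuffle merge for shuffle $shuffleId: $statuses")
      }

      override def onShuffleMergeFailure(e: Throwable): Unit =
        println(s"Failed to finalize shuffle merge for shuffle $shuffleId: ${e.getMessage}")
    })
  }
}
```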
Lead-authored-by: Ye Zhou Co-authored-by: Min Shen Signed-off-by: Mridul Muralidharan gmail.com> --- .../network/shuffle/BlockStoreClient.java | 22 ++++++++++ .../shuffle/ExternalBlockStoreClient.java | 29 +++++++++++++ .../shuffle/MergeFinalizerListener.java | 43 +++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergeFinalizerListener.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 37befcd4b67fa..a6bdc13e93234 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -147,6 +147,8 @@ public void onFailure(Throwable t) { * @param blockIds block ids to be pushed * @param buffers buffers to be pushed * @param listener the listener to receive block push status. + * + * @since 3.1.0 */ public void pushBlocks( String host, @@ -156,4 +158,24 @@ public void pushBlocks( BlockFetchingListener listener) { throw new UnsupportedOperationException(); } + + /** + * Invoked by Spark driver to notify external shuffle services to finalize the shuffle merge + * for a given shuffle. This allows the driver to start the shuffle reducer stage after properly + * finishing the shuffle merge process associated with the shuffle mapper stage. + * + * @param host host of shuffle server + * @param port port of shuffle server. + * @param shuffleId shuffle ID of the shuffle to be finalized + * @param listener the listener to receive MergeStatuses + * + * @since 3.1.0 + */ + public void finalizeShuffleMerge( + String host, + int port, + int shuffleId, + MergeFinalizerListener listener) { + throw new UnsupportedOperationException(); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index eca35ed290467..56c06e640acda 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -158,6 +158,35 @@ public void pushBlocks( } } + @Override + public void finalizeShuffleMerge( + String host, + int port, + int shuffleId, + MergeFinalizerListener listener) { + checkInit(); + try { + TransportClient client = clientFactory.createClient(host, port); + ByteBuffer finalizeShuffleMerge = new FinalizeShuffleMerge(appId, shuffleId).toByteBuffer(); + client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + listener.onShuffleMergeSuccess( + (MergeStatuses) BlockTransferMessage.Decoder.fromByteBuffer(response)); + } + + @Override + public void onFailure(Throwable e) { + listener.onShuffleMergeFailure(e); + } + }); + } catch (Exception e) { + logger.error("Exception while sending finalizeShuffleMerge request to {}:{}", + host, port, e); + listener.onShuffleMergeFailure(e); + } + } + @Override public MetricSet shuffleMetrics() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergeFinalizerListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergeFinalizerListener.java new file 
mode 100644 index 0000000000000..08e13eea9f40d --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergeFinalizerListener.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.EventListener; + +import org.apache.spark.network.shuffle.protocol.MergeStatuses; + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when driver receives the response for the + * finalize shuffle merge request sent to remote shuffle service. + * + * @since 3.1.0 + */ +public interface MergeFinalizerListener extends EventListener { + /** + * Called once upon successful response on finalize shuffle merge on a remote shuffle service. + * The returned {@link MergeStatuses} is passed to the listener for further processing + */ + void onShuffleMergeSuccess(MergeStatuses statuses); + + /** + * Called once upon failure response on finalize shuffle merge on a remote shuffle service. + */ + void onShuffleMergeFailure(Throwable e); +} From 05921814e2349e1acecb14a365e6d47ffb0d68e8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 24 Nov 2020 09:27:44 +0900 Subject: [PATCH 32/38] [SPARK-33479][DOC][FOLLOWUP] DocSearch: Support filtering search results by version ### What changes were proposed in this pull request? In the discussion https://github.com/apache/spark/pull/30292#issuecomment-725613417, we planned to apply a new API key for each Spark release. However, it turns that DocSearch supports crawling multiple URLs from one website and filtering by fact key: https://docsearch.algolia.com/docs/config-file/#using-regular-expressions Thanks to the help from shortcuts, our Spark doc supports multiple version now: https://github.com/algolia/docsearch-configs/pull/2868 This PR is to add the fact key in the search script and update the instruction in the comment. ### Why are the changes needed? To support filtering Spark documentation search results by the current document version. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test Closes #30469 from gengliangwang/apiKeyFollowUp. Authored-by: Gengliang Wang Signed-off-by: Takeshi Yamamuro --- docs/_config.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/_config.yml b/docs/_config.yml index cd341063a1f92..026b3dd804690 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -26,15 +26,20 @@ SCALA_VERSION: "2.12.10" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark -# Before a new release, we should apply a new `apiKey` for the new Spark documentation -# on https://docsearch.algolia.com/. 
Otherwise, after release, the search results are always based -# on the latest documentation(https://spark.apache.org/docs/latest/) even when visiting the -# documentation of previous releases. +# Before a new release, we should: +# 1. update the `version` array for the new Spark documentation +# on https://github.com/algolia/docsearch-configs/blob/master/configs/apache_spark.json. +# 2. update the value of `facetFilters.version` in `algoliaOptions` on the new release branch. +# Otherwise, after release, the search results are always based on the latest documentation +# (https://spark.apache.org/docs/latest/) even when visiting the documentation of previous releases. DOCSEARCH_SCRIPT: | docsearch({ apiKey: 'b18ca3732c502995563043aa17bc6ecb', indexName: 'apache_spark', inputSelector: '#docsearch-input', enhancedSearchInput: true, + algoliaOptions: { + 'facetFilters': ["version:latest"] + }, debug: false // Set debug to true if you want to inspect the dropdown }); From 3ce4ab545bfc28db7df2c559726b887b0c8c33b7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 23 Nov 2020 16:28:43 -0800 Subject: [PATCH 33/38] [SPARK-33513][BUILD] Upgrade to Scala 2.13.4 to improve exhaustivity ### What changes were proposed in this pull request? This PR aims the followings. 1. Upgrade from Scala 2.13.3 to 2.13.4 for Apache Spark 3.1 2. Fix exhaustivity issues in both Scala 2.12/2.13 (Scala 2.13.4 requires this for compilation.) 3. Enforce the improved exhaustive check by using the existing Scala 2.13 GitHub Action compilation job. ### Why are the changes needed? Scala 2.13.4 is a maintenance release for 2.13 line and improves JDK 15 support. - https://github.com/scala/scala/releases/tag/v2.13.4 Also, it improves exhaustivity check. - https://github.com/scala/scala/pull/9140 (Check exhaustivity of pattern matches with "if" guards and custom extractors) - https://github.com/scala/scala/pull/9147 (Check all bindings exhaustively, e.g. tuples components) ### Does this PR introduce _any_ user-facing change? Yep. Although it's a maintenance version change, it's a Scala version change. ### How was this patch tested? Pass the CIs and do the manual testing. - Scala 2.12 CI jobs(GitHub Action/Jenkins UT/Jenkins K8s IT) to check the validity of code change. - Scala 2.13 Compilation job to check the compilation Closes #30455 from dongjoon-hyun/SCALA_3.13. 
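For illustration, a small self-contained sketch (not code from this patch) of a guarded match that the stricter 2.13.4 check can report as non-exhaustive, together with the kind of catch-all branch added throughout the diff below:

```scala
sealed trait Vec
final case class Dense(values: Array[Double]) extends Vec
final case class Sparse(indices: Array[Int], values: Array[Double]) extends Vec

object ExhaustivitySketch {
  def describe(v: Vec, preferDense: Boolean): String = v match {
    // The guard means this case may not match, so 2.13.4 no longer treats
    // Dense as fully covered by it.
    case Dense(values) if preferDense => s"dense(${values.length})"
    case Sparse(indices, _)           => s"sparse(${indices.length})"
    // The added catch-all mirrors the `case v => throw ...` branches in this patch.
    case other =>
      throw new IllegalArgumentException(s"Unknown vector type ${other.getClass}.")
  }

  def main(args: Array[String]): Unit = {
    println(describe(Dense(Array(1.0, 2.0)), preferDense = true))       // dense(2)
    println(describe(Sparse(Array(0), Array(3.0)), preferDense = true)) // sparse(1)
  }
}
```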
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/storage/StorageUtils.scala | 2 +- .../main/scala/org/apache/spark/util/JsonProtocol.scala | 8 ++++---- .../src/main/scala/org/apache/spark/ml/linalg/BLAS.scala | 2 ++ .../org/apache/spark/ml/feature/RFormulaParser.scala | 6 +++++- .../org/apache/spark/ml/feature/StandardScaler.scala | 2 ++ .../org/apache/spark/ml/linalg/JsonMatrixConverter.scala | 2 ++ .../org/apache/spark/ml/linalg/JsonVectorConverter.scala | 2 ++ .../main/scala/org/apache/spark/ml/linalg/VectorUDT.scala | 2 ++ .../spark/ml/optim/aggregator/HingeAggregator.scala | 3 +++ .../spark/ml/optim/aggregator/LogisticAggregator.scala | 3 +++ .../scala/org/apache/spark/ml/util/Instrumentation.scala | 2 ++ .../org/apache/spark/mllib/feature/StandardScaler.scala | 2 ++ .../main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 2 ++ .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 4 ++++ .../apache/spark/mllib/linalg/distributed/RowMatrix.scala | 2 ++ pom.xml | 2 +- .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 2 +- .../mesos/MesosFineGrainedSchedulerBackendSuite.scala | 2 +- .../spark/sql/catalyst/expressions/jsonExpressions.scala | 2 +- .../apache/spark/sql/catalyst/expressions/literals.scala | 4 +++- .../spark/sql/catalyst/expressions/objects/objects.scala | 2 +- .../apache/spark/sql/catalyst/json/JsonInferSchema.scala | 3 +++ .../sql/catalyst/optimizer/StarSchemaDetection.scala | 6 +++--- .../apache/spark/sql/catalyst/optimizer/expressions.scala | 1 + .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 2 ++ .../catalyst/plans/logical/basicLogicalOperators.scala | 2 +- .../apache/spark/sql/catalyst/util/GenericArrayData.scala | 2 +- .../spark/sql/catalyst/planning/ScanOperationSuite.scala | 5 +++++ .../sql/catalyst/util/ArrayDataIndexedSeqSuite.scala | 2 +- .../org/apache/spark/sql/execution/SparkSqlParser.scala | 6 +++--- .../spark/sql/execution/aggregate/BaseAggregateExec.scala | 2 +- .../spark/sql/execution/window/WindowExecBase.scala | 6 ++++++ .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 1 + .../spark/streaming/util/FileBasedWriteAheadLog.scala | 2 +- 35 files changed, 77 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 147731a0fb547..c607fb28b2f56 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -169,7 +169,7 @@ private[spark] class StorageStatus( .getOrElse((0L, 0L)) case _ if !level.useOffHeap => (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) - case _ if level.useOffHeap => + case _ => (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 13f7cb453346f..103965e4860a3 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -757,7 +757,7 @@ private[spark] object JsonProtocol { def taskResourceRequestMapFromJson(json: JValue): Map[String, TaskResourceRequest] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val req = 
taskResourceRequestFromJson(v) (k, req) }.toMap @@ -765,7 +765,7 @@ private[spark] object JsonProtocol { def executorResourceRequestMapFromJson(json: JValue): Map[String, ExecutorResourceRequest] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val req = executorResourceRequestFromJson(v) (k, req) }.toMap @@ -1229,7 +1229,7 @@ private[spark] object JsonProtocol { def resourcesMapFromJson(json: JValue): Map[String, ResourceInformation] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val resourceInfo = ResourceInformation.parseJson(v) (k, resourceInfo) }.toMap @@ -1241,7 +1241,7 @@ private[spark] object JsonProtocol { def mapFromJson(json: JValue): Map[String, String] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, JString(v)) => (k, v) }.toMap + jsonFields.collect { case JField(k, JString(v)) => (k, v) }.toMap } def propertiesFromJson(json: JValue): Properties = { diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 368f177cda828..b6c1b011f004c 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -302,6 +302,8 @@ private[spark] object BLAS extends Serializable { j += 1 prevCol = col } + case _ => + throw new IllegalArgumentException(s"spr doesn't support vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index dbbfd8f329431..c5b28c95eb7c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -286,6 +286,7 @@ private[ml] object RFormulaParser extends RegexParsers { private val pow: Parser[Term] = term ~ "^" ~ "^[1-9]\\d*".r ^^ { case base ~ "^" ~ degree => power(base, degree.toInt) + case t => throw new IllegalArgumentException(s"Invalid term: $t") } | term private val interaction: Parser[Term] = pow * (":" ^^^ { interact _ }) @@ -298,7 +299,10 @@ private[ml] object RFormulaParser extends RegexParsers { private val expr = (sum | term) private val formula: Parser[ParsedRFormula] = - (label ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) } + (label ~ "~" ~ expr) ^^ { + case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) + case t => throw new IllegalArgumentException(s"Invalid term: $t") + } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 7434b1adb2ff2..92dee46ad0055 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -314,6 +314,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { case SparseVector(size, indices, values) => val newValues = transformSparseWithScale(scale, indices, values.clone()) Vectors.sparse(size, indices, newValues) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } case (false, false) => diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala index 0bee643412b3f..8f03a29eb991a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala @@ -74,6 +74,8 @@ private[ml] object JsonMatrixConverter { ("values" -> values.toSeq) ~ ("isTransposed" -> isTransposed) compact(render(jValue)) + case _ => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala index 781e69f8d63db..1b949d75eeaa0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala @@ -57,6 +57,8 @@ private[ml] object JsonVectorConverter { case DenseVector(values) => val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) compact(render(jValue)) + case _ => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index 37f173bc20469..35bbaf5aa1ded 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -45,6 +45,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala index 3d72512563154..0fe1ed231aa83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala @@ -200,6 +200,9 @@ private[ml] class BlockHingeAggregator( case sm: SparseMatrix if !fitIntercept => val gradSumVec = new DenseVector(gradientSumArray) BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec) + + case m => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } if (fitIntercept) gradientSumArray(numFeatures) += vec.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala index 2496c789f8da6..5a516940b9788 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala @@ -504,6 +504,9 @@ private[ml] class BlockLogisticAggregator( case sm: SparseMatrix if !fitIntercept => val gradSumVec = new DenseVector(gradientSumArray) BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec) + + case m => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } if (fitIntercept) gradientSumArray(numFeatures) += vec.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index d4b39e11fd1d7..2215c2b071584 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -192,6 +192,8 @@ private[spark] object Instrumentation { case Failure(NonFatal(e)) => instr.logFailure(e) throw e + case Failure(e) => + throw e case Success(result) => instr.logSuccess() result diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 8f9d6d07a4c36..12a5a0f2b2189 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -167,6 +167,8 @@ class StandardScalerModel @Since("1.3.0") ( val newValues = NewStandardScalerModel .transformSparseWithScale(localScale, indices, values.clone()) Vectors.sparse(size, indices, newValues) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } case _ => vector diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index da486010cfa9e..bd60364326e28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -285,6 +285,8 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 prevCol = col } + case _ => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 2fe415f14032f..9ed9dd0c88c9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -289,6 +289,8 @@ class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index ad79230c7513c..da5d1650694d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -145,6 +145,8 @@ class IndexedRowMatrix @Since("1.0.0") ( .map { case (values, blockColumn) => ((blockRow.toInt, blockColumn), (rowInBlock.toInt, values.zipWithIndex)) } + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } }.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map { case ((blockRow, blockColumn), itr) => @@ -187,6 +189,8 @@ class IndexedRowMatrix @Since("1.0.0") ( Iterator.tabulate(indices.length)(i => MatrixEntry(rowIndex, indices(i), values(i))) case DenseVector(values) => Iterator.tabulate(values.length)(i => MatrixEntry(rowIndex, i, values(i))) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } new CoordinateMatrix(entries, numRows(), numCols()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 07b9d91c1f59b..c618b71ddc5a8 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -748,6 +748,8 @@ class RowMatrix @Since("1.0.0") ( } buf }.flatten + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } }.reduceByKey(_ + _).map { case ((i, j), sim) => diff --git a/pom.xml b/pom.xml index 0ab5a8c5b3efa..e5b1f30edd3be 100644 --- a/pom.xml +++ b/pom.xml @@ -3264,7 +3264,7 @@ scala-2.13 - 2.13.3 + 2.13.4 2.13 diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index b5a360167679e..4620bdb005094 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -313,7 +313,6 @@ trait MesosSchedulerUtils extends Logging { // offer has the required attribute and subsumes the required values for that attribute case (name, requiredValues) => offerAttributes.get(name) match { - case None => false case Some(_) if requiredValues.isEmpty => true // empty value matches presence case Some(scalarValue: Value.Scalar) => // check if provided values is less than equal to the offered values @@ -332,6 +331,7 @@ trait MesosSchedulerUtils extends Logging { // check if the specified value is equal, if multiple values are specified // we succeed if any of them match. requiredValues.contains(textValue.getValue) + case _ => false } } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 67ecf3242f52d..6a6514569cf90 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -178,7 +178,7 @@ class MesosFineGrainedSchedulerBackendSuite val (execInfo, _) = backend.createExecutorInfo( Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) - assert(execInfo.getContainer.getDocker.getForcePullImage.equals(true)) + assert(execInfo.getContainer.getDocker.getForcePullImage) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) assert(portmaps.get(0).getContainerPort.equals(8080)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 39d9eb5a36964..a363615d3afe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -94,7 +94,7 @@ private[this] object JsonPathParser extends RegexParsers { case Success(result, _) => Some(result) - case NoSuccess(msg, next) => + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1e69814673082..810cecff379d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -322,7 +322,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { case (a: Array[Byte], b: Array[Byte]) => util.Arrays.equals(a, b) case (a: ArrayBasedMapData, b: ArrayBasedMapData) => a.keyArray == b.keyArray && a.valueArray == b.valueArray - case (a, b) => a != null && a.equals(b) + case (a: Double, b: Double) if a.isNaN && b.isNaN => true + case (a: Float, b: Float) if a.isNaN && b.isNaN => true + case (a, b) => a != null && a == b } case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9701420e65870..9303df75af503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -981,7 +981,7 @@ case class MapObjects private( (genValue: String) => s"$builder.add($genValue);", s"$builder;" ) - case None => + case _ => // array ( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index de396a4c63458..a39f06628b9ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -190,6 +190,9 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { } case VALUE_TRUE | VALUE_FALSE => BooleanType + + case _ => + throw new SparkException("Malformed JSON") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index b65fc7f7e2bde..bf3fced0ae0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -197,9 +197,9 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { } else { false } - case None => false + case _ => false } - case None => false + case _ => false } case _ => false } @@ -239,7 +239,7 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { case Some(col) if t.outputSet.contains(col) => val stats = t.stats stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) - case None => false + case _ => false } case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 55a45f4410b34..d1eb3b07d3d5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -685,6 +685,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case LeftOuter => newJoin.right.output case RightOuter => newJoin.left.output case FullOuter => 
newJoin.left.output ++ newJoin.right.output + case _ => Nil }) val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { case (attr, _) => missDerivedAttrsSet.contains(attr) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ea4baafbacede..50580b8e335ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -967,6 +967,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) case Some(c) if c.booleanExpression != null => (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new ParseException(s"Unimplemented joinCriteria: $c", ctx) case None if join.NATURAL != null => if (baseJoinType == Cross) { throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f96e07863fa69..c7108ea8ac74b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -362,7 +362,7 @@ case class Join( left.constraints case RightOuter => right.constraints - case FullOuter => + case _ => ExpressionSet() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 81f412c14304d..e46d730afb4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -120,7 +120,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! 
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (!o1.equals(o2)) { + case _ => if (o1.getClass != o2.getClass || o1 != o2) { return false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala index 7790f467a890b..1290f770349e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala @@ -39,6 +39,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colB) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } @@ -50,6 +51,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === colB) assert(filters.size === 1) + case _ => assert(false) } } @@ -65,6 +67,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === aliasId) + case _ => assert(false) } } @@ -81,6 +84,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } @@ -93,6 +97,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index 1e430351b5137..9c3aaea0f7772 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -45,7 +45,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { if (e != null) { elementDt match { // For Nan, etc. 
- case FloatType | DoubleType => assert(seq(i).equals(e)) + case FloatType | DoubleType => assert(seq(i) == e) case _ => assert(seq(i) === e) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 85476bcd21e19..01522257c072d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -868,12 +868,12 @@ class SparkSqlAstBuilder extends AstBuilder { // assert if directory is local when LOCAL keyword is mentioned val scheme = Option(storage.locationUri.get.getScheme) scheme match { - case None => + case Some(pathScheme) if (!pathScheme.equals("file")) => + throw new ParseException("LOCAL is supported only with file: scheme", ctx) + case _ => // force scheme to be file rather than fs.default.name val loc = Some(UriBuilder.fromUri(CatalogUtils.stringToURI(path)).scheme("file").build()) storage = storage.copy(locationUri = loc) - case Some(pathScheme) if (!pathScheme.equals("file")) => - throw new ParseException("LOCAL is supported only with file: scheme", ctx) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index efba51706cf98..c676609bc37e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -91,7 +91,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case Some(exprs) => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index c6b98d48d7dde..9832e5cd74ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -71,6 +71,9 @@ trait WindowExecBase extends UnaryExecNode { case (RowFrame, IntegerLiteral(offset)) => RowBoundOrdering(offset) + case (RowFrame, _) => + sys.error(s"Unhandled bound in windows expressions: $bound") + case (RangeFrame, CurrentRow) => val ordering = RowOrdering.create(orderSpec, child.output) RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) @@ -249,6 +252,9 @@ trait WindowExecBase extends UnaryExecNode { createBoundOrdering(frameType, lower, timeZone), createBoundOrdering(frameType, upper, timeZone)) } + + case _ => + sys.error(s"Unsupported factory: $key") } // Keep track of the number of expressions. This is a side-effect in a map... 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 8ab6e28366753..9213173bbc9ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1039,6 +1039,7 @@ private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case dt => throw new AnalysisException(s"${dt.catalogString} is not supported.") } def toTypeInfo: TypeInfo = dt match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 2e5000159bcb7..d1f9dfb791355 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -293,7 +293,7 @@ private[streaming] object FileBasedWriteAheadLog { val startTime = startTimeStr.toLong val stopTime = stopTimeStr.toLong Some(LogInfo(startTime, stopTime, file.toString)) - case None => + case None | Some(_) => None } }.sortBy { _.startTime } From 8380e00419281cd1b1fc5706d23d5231356a3379 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 23 Nov 2020 19:35:58 -0800 Subject: [PATCH 34/38] [SPARK-33524][SQL][TESTS] Change `InMemoryTable` not to use Tuple.hashCode for `BucketTransform` ### What changes were proposed in this pull request? This PR aims to change `InMemoryTable` not to use `Tuple.hashCode` for `BucketTransform`. ### Why are the changes needed? SPARK-32168 made `InMemoryTable` to handle `BucketTransform` as a hash of `Tuple` which is dependents on Scala versions. - https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala#L159 **Scala 2.12.10** ```scala $ bin/scala Welcome to Scala 2.12.10 (OpenJDK 64-Bit Server VM, Java 1.8.0_272). Type in expressions for evaluation. Or try :help. scala> (1, 1).hashCode res0: Int = -2074071657 ``` **Scala 2.13.3** ```scala Welcome to Scala 2.13.3 (OpenJDK 64-Bit Server VM, Java 1.8.0_272). Type in expressions for evaluation. Or try :help. scala> (1, 1).hashCode val res0: Int = -1669302457 ``` ### Does this PR introduce _any_ user-facing change? Yes. This is a correctness issue. ### How was this patch tested? Pass the UT with both Scala 2.12/2.13. Closes #30477 from dongjoon-hyun/SPARK-33524. 
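To make the hashing difference concrete, here is a minimal, self-contained Scala sketch contrasting bucketing via `Tuple2.hashCode` with the explicit, null-safe combination of the value's and data type's hash codes that the patch below adopts in `InMemoryTable`. The `DataType` objects in the sketch are simplified stand-ins, not Spark's actual classes.

```scala
object StableBucketing {
  // Simplified stand-ins for Spark's DataType objects; only their hashCode matters here.
  sealed trait DataType
  case object IntegerType extends DataType
  case object StringType extends DataType

  // Old approach: bucket derived from Tuple2.hashCode, which the commit message shows
  // differs between Scala 2.12 and 2.13.
  def tupleBasedBucket(value: Any, dataType: DataType, numBuckets: Int): Int =
    ((value, dataType).hashCode() & Integer.MAX_VALUE) % numBuckets

  // Patched approach: combine the value's and the data type's hash codes explicitly
  // (null-safe), avoiding Tuple2.hashCode entirely.
  def stableBucket(value: Any, dataType: DataType, numBuckets: Int): Int = {
    val valueHashCode = if (value == null) 0 else value.hashCode
    ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
  }

  def main(args: Array[String]): Unit = {
    // Same inputs, two bucketing strategies; only the second matches the patched logic.
    println(tupleBasedBucket(1, IntegerType, 4))
    println(stableBucket(1, IntegerType, 4))
    println(stableBucket(null, StringType, 4))
  }
}
```

Only the second strategy corresponds to what `InMemoryTable` computes after this change, which is why the expected `_partition` values in `DataSourceV2SQLSuite` are updated in the diff that follows.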
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/sql/connector/InMemoryTable.scala | 4 +++- .../org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index c93053abc550a..ffff00b54f1b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -156,7 +156,9 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case BucketTransform(numBuckets, ref) => - (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets + val (value, dataType) = extractor(ref.fieldNames, schema, row) + val valueHashCode = if (value == null) 0 else value.hashCode + ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index da53936239de8..dc4abf3eb19cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2511,7 +2511,7 @@ class DataSourceV2SQLSuite checkAnswer( spark.sql(s"SELECT id, data, _partition FROM $t1"), - Seq(Row(1, "a", "3/1"), Row(2, "b", "2/2"), Row(3, "c", "2/3"))) + Seq(Row(1, "a", "3/1"), Row(2, "b", "0/2"), Row(3, "c", "1/3"))) } } @@ -2524,7 +2524,7 @@ class DataSourceV2SQLSuite checkAnswer( spark.sql(s"SELECT index, data, _partition FROM $t1"), - Seq(Row(3, "c", "2/3"), Row(2, "b", "2/2"), Row(1, "a", "3/1"))) + Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) } } From f35e28fea5605de4b28630eb643a821ecd7c8523 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Nov 2020 13:30:06 +0900 Subject: [PATCH 35/38] [SPARK-33523][SQL][TEST] Add predicate related benchmark to SubExprEliminationBenchmark ### What changes were proposed in this pull request? This patch adds predicate related benchmark to `SubExprEliminationBenchmark`. ### Why are the changes needed? We should have a benchmark for subexpression elimination of predicate. ### Does this PR introduce _any_ user-facing change? No, dev only. ### How was this patch tested? Run benchmark locally. Closes #30476 from viirya/SPARK-33523. Authored-by: Liang-Chi Hsieh Signed-off-by: HyukjinKwon --- ...ExprEliminationBenchmark-jdk11-results.txt | 22 +++- .../SubExprEliminationBenchmark-results.txt | 22 +++- .../SubExprEliminationBenchmark.scala | 106 ++++++++++-------- 3 files changed, 90 insertions(+), 60 deletions(-) diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt index 3d2b2e5c8edba..1eb7b534d2194 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt @@ -5,11 +5,21 @@ Benchmark for performance of subexpression elimination Preparing data for benchmarking ... 
OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz -from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 25932 26908 916 0.0 259320042.3 1.0X -subexpressionElimination off, codegen off 26085 26159 65 0.0 260848905.0 1.0X -subexpressionElimination on, codegen on 2860 2939 72 0.0 28603312.9 9.1X -subexpressionElimination on, codegen off 2517 2617 93 0.0 25165157.7 10.3X +from_json as subExpr in Project: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subExprElimination false, codegen: true 26447 27127 605 0.0 264467933.4 1.0X +subExprElimination false, codegen: false 25673 26035 546 0.0 256732419.1 1.0X +subExprElimination true, codegen: true 1384 1448 102 0.0 13842910.3 19.1X +subExprElimination true, codegen: false 1244 1347 123 0.0 12442389.3 21.3X + +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6 +Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz +from_json as subExpr in Filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subexpressionElimination off, codegen on 34631 35449 833 0.0 346309884.0 1.0X +subexpressionElimination off, codegen on 34480 34851 353 0.0 344798490.4 1.0X +subexpressionElimination off, codegen on 16618 16811 291 0.0 166176642.6 2.1X +subexpressionElimination off, codegen on 34316 34667 310 0.0 343157094.7 1.0X diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt index ca2a9c6497500..801f519ca76a1 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt @@ -5,11 +5,21 @@ Benchmark for performance of subexpression elimination Preparing data for benchmarking ... OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz -from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 26503 27622 1937 0.0 265033362.4 1.0X -subexpressionElimination off, codegen off 24920 25376 430 0.0 249196978.2 1.1X -subexpressionElimination on, codegen on 2421 2466 39 0.0 24213606.1 10.9X -subexpressionElimination on, codegen off 2360 2435 87 0.0 23604320.7 11.2X +from_json as subExpr in Project: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subExprElimination false, codegen: true 22767 23240 424 0.0 227665316.7 1.0X +subExprElimination false, codegen: false 22869 23351 465 0.0 228693464.1 1.0X +subExprElimination true, codegen: true 1328 1340 10 0.0 13280056.2 17.1X +subExprElimination true, codegen: false 1248 1276 31 0.0 12476135.1 18.2X + +Preparing data for benchmarking ... 
+OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6 +Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz +from_json as subExpr in Filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subexpressionElimination off, codegen on 37691 38846 1004 0.0 376913767.9 1.0X +subexpressionElimination off, codegen on 37852 39124 1103 0.0 378517745.5 1.0X +subexpressionElimination off, codegen on 22900 23085 202 0.0 229000242.5 1.6X +subexpressionElimination off, codegen on 38298 38598 374 0.0 382978731.3 1.0X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala index 34b4a70d05a25..e26acbcb3cd21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Or} import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -39,7 +41,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { import spark.implicits._ def withFromJson(rowsNum: Int, numIters: Int): Unit = { - val benchmark = new Benchmark("from_json as subExpr", rowsNum, output = output) + val benchmark = new Benchmark("from_json as subExpr in Project", rowsNum, output = output) withTempPath { path => prepareDataInfo(benchmark) @@ -50,57 +52,65 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { from_json('value, schema).getField(s"col$idx") } - // We only benchmark subexpression performance under codegen/non-codegen, so disabling - // json optimization. - benchmark.addCase("subexpressionElimination off, codegen on", numIters) { _ => - withSQLConf( - SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false", - SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", - SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", - SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { - val df = spark.read - .text(path.getAbsolutePath) - .select(cols: _*) - df.collect() + Seq( + ("false", "true", "CODEGEN_ONLY"), + ("false", "false", "NO_CODEGEN"), + ("true", "true", "CODEGEN_ONLY"), + ("true", "false", "NO_CODEGEN") + ).foreach { case (subExprEliminationEnabled, codegenEnabled, codegenFactory) => + // We only benchmark subexpression performance under codegen/non-codegen, so disabling + // json optimization. 
+ val caseName = s"subExprElimination $subExprEliminationEnabled, codegen: $codegenEnabled" + benchmark.addCase(caseName, numIters) { _ => + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactory, + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { + val df = spark.read + .text(path.getAbsolutePath) + .select(cols: _*) + df.write.mode("overwrite").format("noop").save() + } } } - benchmark.addCase("subexpressionElimination off, codegen off", numIters) { _ => - withSQLConf( - SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false", - SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", - SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN", - SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { - val df = spark.read - .text(path.getAbsolutePath) - .select(cols: _*) - df.collect() - } - } + benchmark.run() + } + } - benchmark.addCase("subexpressionElimination on, codegen on", numIters) { _ => - withSQLConf( - SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "true", - SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", - SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", - SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { - val df = spark.read - .text(path.getAbsolutePath) - .select(cols: _*) - df.collect() - } - } + def withFilter(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark("from_json as subExpr in Filter", rowsNum, output = output) - benchmark.addCase("subexpressionElimination on, codegen off", numIters) { _ => - withSQLConf( - SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "true", - SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", - SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN", - SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { - val df = spark.read - .text(path.getAbsolutePath) - .select(cols: _*) - df.collect() + withTempPath { path => + prepareDataInfo(benchmark) + val numCols = 1000 + val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) + + val predicate = (0 until numCols).map { idx => + (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr + }.asInstanceOf[Seq[Expression]].reduce(Or) + + Seq( + ("false", "true", "CODEGEN_ONLY"), + ("false", "false", "NO_CODEGEN"), + ("true", "true", "CODEGEN_ONLY"), + ("true", "false", "NO_CODEGEN") + ).foreach { case (subExprEliminationEnabled, codegenEnabled, codegenFactory) => + // We only benchmark subexpression performance under codegen/non-codegen, so disabling + // json optimization. 
+ val caseName = s"subExprElimination $subExprEliminationEnabled, codegen: $codegenEnabled" + benchmark.addCase("subexpressionElimination off, codegen on", numIters) { _ => + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactory, + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { + val df = spark.read + .text(path.getAbsolutePath) + .where(Column(predicate)) + df.write.mode("overwrite").format("noop").save() + } } } @@ -108,11 +118,11 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { } } - override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val numIters = 3 runBenchmark("Benchmark for performance of subexpression elimination") { withFromJson(100, numIters) + withFilter(100, numIters) } } } From a6555ee59626bbc4ef860c4ff9fcefae0d45b45e Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 24 Nov 2020 08:04:21 +0000 Subject: [PATCH 36/38] [SPARK-33521][SQL] Universal type conversion in resolving V2 partition specs ### What changes were proposed in this pull request? In the PR, I propose to changes the resolver of partition specs used in V2 `ALTER TABLE .. ADD/DROP PARTITION` (at the moment), and re-use `CAST` in conversion partition values to desired types according to the partition schema. ### Why are the changes needed? Currently, the resolver of V2 partition specs supports just a few types: https://github.com/apache/spark/blob/23e9920b3910e4f05269853429c7f18888cdc7b5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala#L72, and fails on other types like date/timestamp. ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? By running `AlterTablePartitionV2SQLSuite` Closes #30474 from MaxGekk/dsv2-partition-value-types. 
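As a rough sketch of the reuse-`Cast` approach (assuming `spark-catalyst` is on the classpath; the object and the `toPartitionValue` helper are illustrative names, not part of the patch), a raw partition-spec string can be converted by wrapping it in a string `Literal` and evaluating a `Cast` to the partition column's type:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{DataType, DateType, IntegerType, StringType, TimestampType}

object PartitionValueCast {
  // Convert a raw partition value using Spark's cast machinery instead of a per-type match.
  def toPartitionValue(raw: String, dataType: DataType, timeZoneId: String): Any =
    Cast(Literal.create(raw, StringType), dataType, Some(timeZoneId)).eval()

  def main(args: Array[String]): Unit = {
    val tz = "America/Los_Angeles"
    println(toPartitionValue("1", IntegerType, tz))                     // 1
    println(toPartitionValue("2020-11-23", DateType, tz))               // days since epoch (Int)
    println(toPartitionValue("2020-11-23 22:13:10", TimestampType, tz)) // microseconds (Long)
  }
}
```

Compared with the removed per-type `match`, this delegates all conversions, including dates and timestamps, to the existing cast logic, which is exactly what the `ResolvePartitionSpec` change below does.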
Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../analysis/ResolvePartitionSpec.scala | 29 +--------- .../AlterTablePartitionV2SQLSuite.scala | 58 +++++++++++++++++++ 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 531d40f431dee..6d061fce06919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement @@ -65,31 +65,8 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { conf.resolver) val partValues = partSchema.map { part => - val partValue = normalizedSpec.get(part.name).orNull - if (partValue == null) { - null - } else { - // TODO: Support other datatypes, such as DateType - part.dataType match { - case _: ByteType => - partValue.toByte - case _: ShortType => - partValue.toShort - case _: IntegerType => - partValue.toInt - case _: LongType => - partValue.toLong - case _: FloatType => - partValue.toFloat - case _: DoubleType => - partValue.toDouble - case _: StringType => - partValue - case _ => - throw new AnalysisException( - s"Type ${part.dataType.typeName} is not supported for partition.") - } - } + val raw = normalizedSpec.get(part.name).orNull + Cast(Literal.create(raw, StringType), part.dataType, Some(conf.sessionLocalTimeZone)).eval() } InternalRow.fromSeq(partValues) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala index e05c2c09ace2a..4cacd5ec2b49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.connector +import java.time.{LocalDate, LocalDateTime} + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { @@ -185,4 +189,58 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { } } } + + test("SPARK-33521: universal type conversions of partition values") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + sql(s""" + |CREATE TABLE $t ( + | part0 tinyint, + | part1 smallint, + | part2 int, + | 
part3 bigint, + | part4 float, + | part5 double, + | part6 string, + | part7 boolean, + | part8 date, + | part9 timestamp + |) USING foo + |PARTITIONED BY (part0, part1, part2, part3, part4, part5, part6, part7, part8, part9) + |""".stripMargin) + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) + .asPartitionable + val expectedPartition = InternalRow.fromSeq(Seq[Any]( + -1, // tinyint + 0, // smallint + 1, // int + 2, // bigint + 3.14F, // float + 3.14D, // double + UTF8String.fromString("abc"), // string + true, // boolean + LocalDate.parse("2020-11-23").toEpochDay, + DateTimeUtils.instantToMicros( + LocalDateTime.parse("2020-11-23T22:13:10.123456").atZone(DateTimeTestUtils.LA).toInstant) + )) + assert(!partTable.partitionExists(expectedPartition)) + val partSpec = """ + | part0 = -1, + | part1 = 0, + | part2 = 1, + | part3 = 2, + | part4 = 3.14, + | part5 = 3.14, + | part6 = 'abc', + | part7 = true, + | part8 = '2020-11-23', + | part9 = '2020-11-23T22:13:10.123456' + |""".stripMargin + sql(s"ALTER TABLE $t ADD PARTITION ($partSpec) LOCATION 'loc1'") + assert(partTable.partitionExists(expectedPartition)) + sql(s" ALTER TABLE $t DROP PARTITION ($partSpec)") + assert(!partTable.partitionExists(expectedPartition)) + } + } } From fdd6c73b3cfac5af30c789c7f70b92367a79f7e7 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 24 Nov 2020 11:06:39 +0000 Subject: [PATCH 37/38] [SPARK-33514][SQL] Migrate TRUNCATE TABLE command to use UnresolvedTable to resolve the identifier ### What changes were proposed in this pull request? This PR proposes to migrate `TRUNCATE TABLE` to use `UnresolvedTable` to resolve the table identifier. This allows consistent resolution rules (temp view first, etc.) to be applied for both v1/v2 commands. More info about the consistent resolution rule proposal can be found in [JIRA](https://issues.apache.org/jira/browse/SPARK-29900) or [proposal doc](https://docs.google.com/document/d/1hvLjGA8y_W_hhilpngXVub1Ebv8RsMap986nENCFnrg/edit?usp=sharing). Note that `TRUNCATE TABLE` works only with v1 tables, and not supported for v2 tables. ### Why are the changes needed? The changes allow consistent resolution behavior when resolving the table identifier. For example, the following is the current behavior: ```scala sql("CREATE TEMPORARY VIEW t AS SELECT 1") sql("CREATE DATABASE db") sql("CREATE TABLE t using csv AS SELECT 1") sql("USE db") sql("TRUNCATE TABLE t") // Succeeds ``` With this PR, `TRUNCATE TABLE` above fails with the following: ``` org.apache.spark.sql.AnalysisException: t is a temp view not table.; line 1 pos 0 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTempViews$$anonfun$apply$7.$anonfun$applyOrElse$42(Analyzer.scala:866) ``` , which is expected since temporary view is resolved first and `TRUNCATE TABLE` doesn't support a temporary view. ### Does this PR introduce _any_ user-facing change? After this PR, `TRUNCATE TABLE` is resolved to a temp view `t` instead of table `db.t` in the above scenario. ### How was this patch tested? Updated existing tests. Closes #30457 from imback82/truncate_table. 
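For illustration only, here is a toy model in plain Scala (not Spark's analyzer or logical plan classes) of why routing the identifier through an `UnresolvedTable` child gives the consistent behavior described above: resolution happens once, temp views take precedence, and the command simply reports what it was handed.

```scala
object TruncateResolutionSketch {
  sealed trait LogicalPlan
  case class UnresolvedTable(nameParts: Seq[String], commandName: String) extends LogicalPlan
  case class ResolvedTable(name: String) extends LogicalPlan
  case class ResolvedTempView(name: String) extends LogicalPlan
  case class TruncateTable(child: LogicalPlan) extends LogicalPlan

  // Temp views shadow tables, mirroring the "temp view first" resolution rule.
  def resolve(plan: LogicalPlan, tempViews: Set[String], tables: Set[String]): LogicalPlan =
    plan match {
      case TruncateTable(u: UnresolvedTable) => TruncateTable(resolve(u, tempViews, tables))
      case UnresolvedTable(Seq(name), _) if tempViews.contains(name) => ResolvedTempView(name)
      case UnresolvedTable(Seq(name), _) if tables.contains(name) => ResolvedTable(name)
      case other => other
    }

  def run(plan: LogicalPlan): String = plan match {
    case TruncateTable(ResolvedTable(name)) => s"truncated table $name"
    case TruncateTable(ResolvedTempView(name)) =>
      s"$name is a temp view. 'TRUNCATE TABLE' expects a table"
    case other => s"cannot run: $other"
  }

  def main(args: Array[String]): Unit = {
    val plan = TruncateTable(UnresolvedTable(Seq("t"), "TRUNCATE TABLE"))
    // With both a temp view and a table named t, the temp view wins and TRUNCATE TABLE fails.
    println(run(resolve(plan, tempViews = Set("t"), tables = Set("t"))))
  }
}
```

In Spark itself, analyzer rules such as `ResolveTempViews` (visible in the stack trace quoted above) play the role of `resolve` here, and the session-catalog/v2 strategies decide whether the resolved relation is supported, as the diff below shows.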
Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/AstBuilder.scala | 6 +++--- .../sql/catalyst/plans/logical/v2Commands.scala | 9 +++++++++ .../spark/sql/catalyst/parser/DDLParserSuite.scala | 6 ++++-- .../catalyst/analysis/ResolveSessionCatalog.scala | 5 ++--- .../datasources/v2/DataSourceV2Strategy.scala | 3 +++ .../spark/sql/connector/DataSourceV2SQLSuite.scala | 4 ++-- .../apache/spark/sql/execution/SQLViewSuite.scala | 13 ++++++++----- .../spark/sql/execution/command/DDLSuite.scala | 10 +++++++--- 8 files changed, 38 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 50580b8e335ff..a4298abd211b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3356,7 +3356,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a [[TruncateTableStatement]] command. + * Create a [[TruncateTable]] command. * * For example: * {{{ @@ -3364,8 +3364,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { - TruncateTableStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), + TruncateTable( + UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier), "TRUNCATE TABLE"), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5bda2b5b8db01..a65b9fc59bd55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -670,3 +670,12 @@ case class LoadData( case class ShowCreateTable(child: LogicalPlan, asSerde: Boolean = false) extends Command { override def children: Seq[LogicalPlan] = child :: Nil } + +/** + * The logical plan of the TRUNCATE TABLE command. 
+ */ +case class TruncateTable( + child: LogicalPlan, + partitionSpec: Option[TablePartitionSpec]) extends Command { + override def children: Seq[LogicalPlan] = child :: Nil +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index bd28484b23f46..997c642276bfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1621,11 +1621,13 @@ class DDLParserSuite extends AnalysisTest { test("TRUNCATE table") { comparePlans( parsePlan("TRUNCATE TABLE a.b.c"), - TruncateTableStatement(Seq("a", "b", "c"), None)) + TruncateTable(UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE"), None)) comparePlans( parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"), - TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10")))) + TruncateTable( + UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE"), + Some(Map("ds" -> "2017-06-10")))) } test("REFRESH TABLE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 303ae47f06b84..726099991a897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -456,10 +456,9 @@ class ResolveSessionCatalog( val name = parseTempViewOrV1Table(tbl, "UNCACHE TABLE") UncacheTableCommand(name.asTableIdentifier, ifExists) - case TruncateTableStatement(tbl, partitionSpec) => - val v1TableName = parseV1Table(tbl, "TRUNCATE TABLE") + case TruncateTable(ResolvedV1TableIdentifier(ident), partitionSpec) => TruncateTableCommand( - v1TableName.asTableIdentifier, + ident.asTableIdentifier, partitionSpec) case ShowPartitionsStatement(tbl, partitionSpec) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index e5c29312b80e7..30d976524bfa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -302,6 +302,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ShowCreateTable(_: ResolvedTable, _) => throw new AnalysisException("SHOW CREATE TABLE is not supported for v2 tables.") + case TruncateTable(_: ResolvedTable, _) => + throw new AnalysisException("TRUNCATE TABLE is not supported for v2 tables.") + case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index dc4abf3eb19cf..9a3fa0c5bd3f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1986,8 +1986,8 @@ class DataSourceV2SQLSuite |PARTITIONED BY (id) """.stripMargin) - testV1Command("TRUNCATE TABLE", t) - testV1Command("TRUNCATE TABLE", s"$t PARTITION(id='1')") + testNotSupportedV2Command("TRUNCATE TABLE", t) + 
testNotSupportedV2Command("TRUNCATE TABLE", s"$t PARTITION(id='1')") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 504cc57dc12d3..edeebde7db726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -176,15 +176,18 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage assert(e2.contains(s"$viewName is a temp view. 'LOAD DATA' expects a table")) - assertNoSuchTable(s"TRUNCATE TABLE $viewName") val e3 = intercept[AnalysisException] { - sql(s"SHOW CREATE TABLE $viewName") + sql(s"TRUNCATE TABLE $viewName") }.getMessage - assert(e3.contains(s"$viewName is a temp view not table or permanent view")) + assert(e3.contains(s"$viewName is a temp view. 'TRUNCATE TABLE' expects a table")) val e4 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + sql(s"SHOW CREATE TABLE $viewName") }.getMessage assert(e4.contains(s"$viewName is a temp view not table or permanent view")) + val e5 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + }.getMessage + assert(e5.contains(s"$viewName is a temp view not table or permanent view")) assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") } } @@ -219,7 +222,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { e = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $viewName") }.getMessage - assert(e.contains(s"Operation not allowed: TRUNCATE TABLE on views: `default`.`testview`")) + assert(e.contains("default.testView is a view. 'TRUNCATE TABLE' expects a table")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 43a33860d262e..07201f9f85b5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2169,11 +2169,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") sql(s"CREATE TABLE my_ext_tab using parquet LOCATION '${tempDir.toURI}'") sql(s"CREATE VIEW my_view AS SELECT 1") - intercept[NoSuchTableException] { + val e1 = intercept[AnalysisException] { sql("TRUNCATE TABLE my_temp_tab") - } + }.getMessage + assert(e1.contains("my_temp_tab is a temp view. 'TRUNCATE TABLE' expects a table")) assertUnsupported("TRUNCATE TABLE my_ext_tab") - assertUnsupported("TRUNCATE TABLE my_view") + val e2 = intercept[AnalysisException] { + sql("TRUNCATE TABLE my_view") + }.getMessage + assert(e2.contains("default.my_view is a view. 'TRUNCATE TABLE' expects a table")) } } } From 048a9821c788b6796d52d1e2a0cd174377ebd0f0 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 24 Nov 2020 09:50:10 -0800 Subject: [PATCH 38/38] [SPARK-33535][INFRA][TESTS] Export LANG to en_US.UTF-8 in run-tests-jenkins script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? It seems that Jenkins tests tasks in many pr have test failed. 
The failed cases include: - `org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.HIVE_CLI_SERVICE_PROTOCOL_V1 get binary type` - `org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.HIVE_CLI_SERVICE_PROTOCOL_V2 get binary type` - `org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.HIVE_CLI_SERVICE_PROTOCOL_V3 get binary type` - `org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.HIVE_CLI_SERVICE_PROTOCOL_V4 get binary type` - `org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.HIVE_CLI_SERVICE_PROTOCOL_V5 get binary type` The error message as follows: ``` Error Messageorg.scalatest.exceptions.TestFailedException: "[?](" did not equal "[�]("Stacktracesbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedException: "[?](" did not equal "[�](" at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472) at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471) at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1231) at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:1295) at org.apache.spark.sql.hive.thriftserver.SparkThriftServerProtocolVersionsSuite.$anonfun$new$26(SparkThriftServerProtocolVersionsSuite.scala:302) ``` But they can pass the GitHub Action, maybe it's related to the `LANG` of the Jenkins build machine, this pr add `export LANG="en_US.UTF-8"` in `run-test-jenkins` script. ### Why are the changes needed? Ensure LANG in Jenkins test process is `en_US.UTF-8` to pass `HIVE_CLI_SERVICE_PROTOCOL_VX` related tests ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Jenkins tests pass Closes #30487 from LuciferYang/SPARK-33535. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- dev/run-tests-jenkins | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c3adc696a5122..c155d4ea3f076 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,6 +26,7 @@ FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" export PATH=/home/anaconda/envs/py36/bin:$PATH +export LANG="en_US.UTF-8" PYTHON_VERSION_CHECK=$(python3 -c 'import sys; print(sys.version_info < (3, 6, 0))') if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then