diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ba8e4d69ba755..d21b9d9833e9e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} +import org.apache.spark.storage.BlockManagerId /** * :: DeveloperApi :: @@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( shuffleId, this) + /** + * Stores the location of the list of chosen external shuffle services for handling the + * shuffle merge requests from mappers in this shuffle map stage. + */ + private[spark] var mergerLocs: Seq[BlockManagerId] = Nil + + def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { + if (mergerLocs != null) { + this.mergerLocs = mergerLocs + } + } + + def getMergerLocs: Seq[BlockManagerId] = mergerLocs + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 136da80d48dee..f49cb3c2b8836 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -80,6 +80,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) protected val simplifiedTraceback: Boolean = false @@ -139,6 +140,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (workerMemoryMb.isDefined) { envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. 
When any codes try to release or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 527d0d6d3a48d..33849f6fcb65f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -85,4 +85,8 @@ private[spark] object PythonUtils { def getBroadcastThreshold(sc: JavaSparkContext): Long = { sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) } + + def getPythonAuthSocketTimeout(sc: JavaSparkContext): Long = { + sc.conf.get(org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala index 188d884319644..348a33e129d65 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala @@ -50,4 +50,10 @@ private[spark] object Python { .version("2.4.0") .bytesConf(ByteUnit.MiB) .createOptional + + val PYTHON_AUTH_SOCKET_TIMEOUT = ConfigBuilder("spark.python.authenticate.socketTimeout") + .internal() + .version("3.1.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("15s") } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4bc49514fc5ad..b38d0e5c617b9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1945,4 +1945,51 @@ package object config { .version("3.0.1") .booleanConf .createWithDefault(false) + + private[spark] val PUSH_BASED_SHUFFLE_ENABLED = + ConfigBuilder("spark.shuffle.push.enabled") + .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " + + "conjunction with the server side flag spark.shuffle.server.mergedShuffleFileManagerImpl " + + "which needs to be set with the appropriate " + + "org.apache.spark.network.shuffle.MergedShuffleFileManager implementation for push-based " + + "shuffle to be enabled") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS = + ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations") + .doc("Maximum number of shuffle push merger locations cached for push based shuffle. " + + "Currently, shuffle push merger locations are nothing but external shuffle services " + + "which are responsible for handling pushed blocks and merging them and serving " + + "merged blocks for later shuffle fetch.") + .version("3.1.0") + .intConf + .createWithDefault(500) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO = + ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio") + .doc("The minimum number of shuffle merger locations required to enable push based " + + "shuffle for a stage. This is specified as a ratio of the number of partitions in " + + "the child stage. For example, a reduce stage which has 100 partitions and uses the " + + "default value 0.05 requires at least 5 unique merger locations to enable push based " + + "shuffle. 
Merger locations are currently defined as external shuffle services.") + .version("3.1.0") + .doubleConf + .createWithDefault(0.05) + + private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD = + ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold") + .doc(s"The static threshold for number of shuffle push merger locations should be " + + "available in order to enable push based shuffle for a stage. Note this config " + + s"works in conjunction with ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. " + + "Maximum of spark.shuffle.push.mergersMinStaticThreshold and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} ratio number of mergers needed to " + + "enable push based shuffle for a stage. For eg: with 1000 partitions for the child " + + "stage with spark.shuffle.push.mergersMinStaticThreshold as 5 and " + + s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " + + "at least 50 mergers to enable push based shuffle for that stage.") + .version("3.1.0") + .doubleConf + .createWithDefault(5) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 13b766e654832..6fb0fb93f253b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -249,6 +249,8 @@ private[spark] class DAGScheduler( private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf) + /** * Called by the TaskSetManager to report task's starting. */ @@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler( execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores)) } + /** + * If push based shuffle is enabled, set the shuffle services to be used for the given + * shuffle map stage for block push/merge. + * + * Even with dynamic resource allocation kicking in and significantly reducing the number + * of available active executors, we would still be able to get sufficient shuffle service + * locations for block push/merge by getting the historical locations of past executors. + */ + private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { + // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize + // TODO changes we cannot disable shuffle merge for the retry/reuse cases + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + + logDebug("List of shuffle push merger locations " + + s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + } else { + logInfo("No available merger locations." + + s" Push-based shuffle disabled for $stage (${stage.name})") + } + } + /** Called when stage's parents are available and we can now do its task. 
*/ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") @@ -1281,6 +1310,12 @@ private[spark] class DAGScheduler( stage match { case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + // Only generate merger location for a given shuffle dependency once. This way, even if + // this stage gets retried, it would still be merging blocks using the same set of + // shuffle services. + if (pushBasedShuffleEnabled) { + prepareShuffleServicesForShuffleMapStage(s) + } case s: ResultStage => outputCommitCoordinator.stageStart( stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) @@ -2027,6 +2062,11 @@ private[spark] class DAGScheduler( if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) { executorFailureEpoch(execId) = currentEpoch logInfo(s"Executor lost: $execId (epoch $currentEpoch)") + if (pushBasedShuffleEnabled) { + // Remove fetchFailed host in the shuffle push merger list for push based shuffle + hostToUnregisterOutputs.foreach( + host => blockManagerMaster.removeShufflePushMergerLocation(host)) + } blockManagerMaster.removeExecutor(execId) clearCacheLocs() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index a566d0a04387c..b2acdb3e12a6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.resource.ResourceProfile +import org.apache.spark.storage.BlockManagerId /** * A backend interface for scheduling systems that allows plugging in different ones under @@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend { */ def maxNumConcurrentTasks(rp: ResourceProfile): Int + /** + * Get the list of host locations for push based shuffle + * + * Currently push based shuffle is disabled for both stage retry and stage reuse cases + * (for eg: in the case where few partitions are lost due to failure). Hence this method + * should be invoked only once for a ShuffleDependency. + * @return List of external shuffle services locations + */ + def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index dbcb376905338..f800553c5388b 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils * * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. 
*/ -private[spark] class SocketAuthHelper(conf: SparkConf) { +private[spark] class SocketAuthHelper(val conf: SparkConf) { val secret = Utils.createSecret(conf) diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala index 548fd1b07ddc5..35990b5a59281 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala @@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration import scala.util.Try import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,11 +36,11 @@ import org.apache.spark.util.{ThreadUtils, Utils} * handling one batch of data, with authentication and error handling. * * The socket server can only accept one connection, or close if no connection - * in 15 seconds. + * in configurable amount of seconds (default 15). */ private[spark] abstract class SocketAuthServer[T]( authHelper: SocketAuthHelper, - threadName: String) { + threadName: String) extends Logging { def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName) def this(threadName: String) = this(SparkEnv.get, threadName) @@ -46,19 +48,26 @@ private[spark] abstract class SocketAuthServer[T]( private val promise = Promise[T]() private def startServer(): (Int, String) = { + logTrace("Creating listening socket") val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) - // Close the socket if no connection in 15 seconds - serverSocket.setSoTimeout(15000) + // Close the socket if no connection in the configured seconds + val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt + logTrace(s"Setting timeout to $timeout sec") + serverSocket.setSoTimeout(timeout * 1000) new Thread(threadName) { setDaemon(true) override def run(): Unit = { var sock: Socket = null try { + logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}") sock = serverSocket.accept() + logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}") authHelper.authClient(sock) + logTrace("Client authenticated") promise.complete(Try(handleConnection(sock))) } finally { + logTrace("Closing server") JavaUtils.closeQuietly(serverSocket) JavaUtils.closeQuietly(sock) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 49e32d04d450a..c6a4457d8f910 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -145,4 +145,6 @@ private[spark] object BlockManagerId { def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { blockManagerIdCache.get(id) } + + private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f544d47b8e13c..fe1a5aef9499c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -125,6 +125,26 @@ class BlockManagerMaster( 
driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + /** + * Get a list of unique shuffle service locations where an executor is successfully + * registered in the past for block push/merge with push based shuffle. + */ + def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + driverEndpoint.askSync[Seq[BlockManagerId]]( + GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + } + + /** + * Remove the host from the candidate list of shuffle push mergers. This can be + * triggered if there is a FetchFailedException on the host + * @param host + */ + def removeShufflePushMergerLocation(host: String): Unit = { + driverEndpoint.askSync[Seq[BlockManagerId]](RemoveShufflePushMergerLocation(host)) + } + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index a7532a9870fae..4d565511704d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + // Mapping from host name to shuffle (mergers) services where the current app + // registered an executor in the past. Older hosts are removed when the + // maxRetainedMergerLocations size is reached in favor of newer locations. 
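+  // The map is keyed by host name; each value is a BlockManagerId built with the
+  // special SHUFFLE_MERGER_IDENTIFIER executor id and the external shuffle service
+  // port, so at most one merger entry is kept per host regardless of executor churn.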
+ private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]() + + // Maximum number of merger locations to cache + private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS) + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) @@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint( val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) + logInfo("BlockManagerMasterEndpoint up") // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED) // && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)` @@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint( case GetBlockStatus(blockId, askStorageEndpoints) => context.reply(blockStatus(blockId, askStorageEndpoints)) + case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) => + context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter)) + + case RemoveShufflePushMergerLocation(host) => + context.reply(removeShufflePushMergerLocation(host)) + case IsExecutorAlive(executorId) => context.reply(blockManagerIdByExecutor.contains(executorId)) @@ -360,6 +376,17 @@ class BlockManagerMasterEndpoint( } + private def addMergerLocation(blockManagerId: BlockManagerId): Unit = { + if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) { + val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, + blockManagerId.host, externalShuffleServicePort) + if (shuffleMergerLocations.size >= maxRetainedMergerLocations) { + shuffleMergerLocations -= shuffleMergerLocations.head._1 + } + shuffleMergerLocations(shuffleServerId.host) = shuffleServerId + } + } + private def removeExecutor(execId: String): Unit = { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) @@ -526,6 +553,10 @@ class BlockManagerMasterEndpoint( blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus) + + if (pushBasedShuffleEnabled) { + addMergerLocation(id) + } } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) @@ -657,6 +688,40 @@ class BlockManagerMasterEndpoint( } } + private def getShufflePushMergerLocations( + numMergersNeeded: Int, + hostsToFilter: Set[String]): Seq[BlockManagerId] = { + val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet + val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_)) + val filteredMergersWithExecutors = filteredBlockManagerHosts.map( + BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort)) + // Enough mergers are available as part of active executors list + if (filteredMergersWithExecutors.size >= numMergersNeeded) { + filteredMergersWithExecutors.toSeq + } else { + // Delta mergers added from inactive mergers list to the active mergers list + val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host) + val filteredMergersWithoutExecutors = shuffleMergerLocations.values + .filterNot(x => hostsToFilter.contains(x.host)) + .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host)) + val 
randomFilteredMergersLocations = + if (filteredMergersWithoutExecutors.size > + numMergersNeeded - filteredMergersWithExecutors.size) { + Utils.randomize(filteredMergersWithoutExecutors) + .take(numMergersNeeded - filteredMergersWithExecutors.size) + } else { + filteredMergersWithoutExecutors + } + filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations + } + } + + private def removeShufflePushMergerLocation(host: String): Unit = { + if (shuffleMergerLocations.contains(host)) { + shuffleMergerLocations.remove(host) + } + } + /** * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index bbc076cea9ba8..afe416a55ed0d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster + + case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String]) + extends ToBlockManagerMaster + + case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala index a3a528cddee37..4af48d5b9125c 100644 --- a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala @@ -136,12 +136,53 @@ private[spark] object HadoopFSUtils extends Logging { parallelismMax = 0) (path, leafFiles) }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) }.collect() } finally { sc.setJobDescription(previousJobDescription) } - statusMap.toSeq + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } } // scalastyle:off argcount @@ -291,4 +332,22 @@ private[spark] object HadoopFSUtils extends Logging { resolvedLeafStatuses } // scalastyle:on argcount + + /** A serializable variant of HDFS's BlockLocation. This is required by Hadoop 2.7. 
*/ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. This is required by Hadoop 2.7. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b743ab6507117..71a310a4279ad 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -50,7 +50,7 @@ import com.google.common.net.InetAddresses import org.apache.commons.codec.binary.Hex import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FileUtil, Path, Trash} +import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -269,29 +269,6 @@ private[spark] object Utils extends Logging { file.setExecutable(true, true) } - /** - * Move data to trash if 'spark.sql.truncate.trash.enabled' is true, else - * delete the data permanently. If move data to trash failed fallback to hard deletion. - */ - def moveToTrashOrDelete( - fs: FileSystem, - partitionPath: Path, - isTrashEnabled: Boolean, - hadoopConf: Configuration): Boolean = { - if (isTrashEnabled) { - logDebug(s"Try to move data ${partitionPath.toString} to trash") - val isSuccess = Trash.moveToAppropriateTrash(fs, partitionPath, hadoopConf) - if (!isSuccess) { - logWarning(s"Failed to move data ${partitionPath.toString} to trash. " + - "Fallback to hard deletion") - return fs.delete(partitionPath, true) - } - isSuccess - } else { - fs.delete(partitionPath, true) - } - } - /** * Create a directory given the abstract pathname * @return true, if the directory is successfully created; otherwise, return false. @@ -2541,6 +2518,14 @@ private[spark] object Utils extends Logging { master == "local" || master.startsWith("local[") } + /** + * Push based shuffle can only be enabled when external shuffle service is enabled. + */ + def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = { + conf.get(PUSH_BASED_SHUFFLE_ENABLED) && + (conf.get(IS_TESTING).getOrElse(false) || conf.get(SHUFFLE_SERVICE_ENABLED)) + } + /** * Return whether dynamic allocation is enabled in the given conf. 
*/ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 55280fc578310..144489c5f7922 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L) .set(Network.RPC_ASK_TIMEOUT, "5s") + .set(PUSH_BASED_SHUFFLE_ENABLED, true) } private def makeSortShuffleManager(): SortShuffleManager = { @@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("SPARK-32919: Shuffle push merger locations should be bounded with in" + + " spark.shuffle.push.retainedMergerLocations") { + assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty) + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4) + assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted === + Seq("hostC", "hostD", "hostA", "hostB").sorted) + assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3) + } + + test("SPARK-32919: Prefer active executor locations for shuffle push mergers") { + makeBlockManager(100, "execA", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + makeBlockManager(100, "execB", + transferService = Some(new MockBlockTransferService(10, "hostB"))) + makeBlockManager(100, "execC", + transferService = Some(new MockBlockTransferService(10, "hostC"))) + makeBlockManager(100, "execD", + transferService = Some(new MockBlockTransferService(10, "hostD"))) + makeBlockManager(100, "execE", + transferService = Some(new MockBlockTransferService(10, "hostA"))) + assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4) + + master.removeExecutor("execA") + master.removeExecutor("execE") + + assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3) + assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted === + Seq("hostC", "hostB", "hostD").sorted) + assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted === + Seq("hostB", "hostA", "hostC", "hostD").sorted) + } + test("SPARK-33387 Support ordered shuffle block migration") { val blocks: Seq[ShuffleBlockInfo] = Seq( ShuffleBlockInfo(1, 0L), @@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1))) } - class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { + class MockBlockTransferService( + val maxFailures: Int, + override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService { var numCalls = 0 var tempFileManager: DownloadFileManager = null @@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with 
Matchers with BeforeAndAfterE override def close(): Unit = {} - override def hostName: String = { "MockBlockTransferServiceHost" } - override def port: Int = { 63332 } override def uploadBlock( diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 20624c743bc22..8fb408041ca9d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener import org.apache.spark.util.io.ChunkedByteBufferInputStream @@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { }.getMessage assert(message.contains(expected)) } + + test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" + + " and SHUFFLE_SERVICE_ENABLED are true") { + val conf = new SparkConf() + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, false) + assert(Utils.isPushBasedShuffleEnabled(conf) === false) + conf.set(SHUFFLE_SERVICE_ENABLED, true) + assert(Utils.isPushBasedShuffleEnabled(conf) === true) + } } private class SimpleExtension diff --git a/dev/check-license b/dev/check-license index 0cc17ffe55c67..bd255954d6db4 100755 --- a/dev/check-license +++ b/dev/check-license @@ -67,7 +67,7 @@ mkdir -p "$FWDIR"/lib exit 1 } -mkdir target +mkdir -p target $java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/.rat-excludes -d "$FWDIR" > target/rat-results.txt if [ $? 
-ne 0 ]; then diff --git a/dev/mima b/dev/mima index f324c5c00a45c..d214bb96e09a3 100755 --- a/dev/mima +++ b/dev/mima @@ -25,8 +25,8 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"} -TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" -OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" +TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | grep jar | tail -n1)" +OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | grep jar | tail -n1)" rm -f .generated-mima* diff --git a/docs/css/main.css b/docs/css/main.css index 8168a46f9a437..8b279a157c2b6 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -162,6 +162,7 @@ body .container-wrapper { margin-right: auto; border-radius: 15px; position: relative; + min-height: 100vh; } .title { @@ -264,6 +265,7 @@ a:hover code { max-width: 914px; line-height: 1.6; /* Inspired by Github's wiki style */ padding-left: 30px; + min-height: 100vh; } .dropdown-menu { @@ -325,6 +327,7 @@ a.anchorjs-link:hover { text-decoration: none; } border-bottom-width: 0px; margin-top: 0px; width: 210px; + height: 80%; float: left; position: fixed; overflow-y: scroll; diff --git a/docs/sql-data-sources-generic-options.md b/docs/sql-data-sources-generic-options.md index 6bcf48235bced..2e4fc879a435f 100644 --- a/docs/sql-data-sources-generic-options.md +++ b/docs/sql-data-sources-generic-options.md @@ -119,3 +119,40 @@ To load all files recursively, you can use: {% include_example recursive_file_lookup r/RSparkSQLExample.R %} + +### Modification Time Path Filters + +`modifiedBefore` and `modifiedAfter` are options that can be +applied together or separately in order to achieve greater +granularity over which files may load during a Spark batch query. +(Note that Structured Streaming file sources don't support these options.) + +* `modifiedBefore`: an optional timestamp to only include files with +modification times occurring before the specified time. The provided timestamp +must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) +* `modifiedAfter`: an optional timestamp to only include files with +modification times occurring after the specified time. The provided timestamp +must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + +When a timezone option is not provided, the timestamps will be interpreted according +to the Spark session timezone (`spark.sql.session.timeZone`). + +To load files with paths matching a given modified time range, you can use: + +
+<div class="codetabs">
+
+<div data-lang="scala"  markdown="1">
+{% include_example load_with_modified_time_filter scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
+</div>
+
+<div data-lang="java"  markdown="1">
+{% include_example load_with_modified_time_filter java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
+</div>
+
+<div data-lang="python"  markdown="1">
+{% include_example load_with_modified_time_filter python/sql/datasource.py %}
+</div>
+
+<div data-lang="r"  markdown="1">
+{% include_example load_with_modified_time_filter r/RSparkSQLExample.R %}
+</div>
+</div>
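As a quick orientation for the tabs above (they pull in the language-specific snippets added elsewhere in this patch), a minimal Scala sketch combining both options with an explicit timezone could look like the following; the input directory is illustrative.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("ModifiedTimeFilter").getOrCreate()

// Keep only files modified after modifiedAfter and before modifiedBefore,
// interpreting both timestamps in UTC rather than spark.sql.session.timeZone.
val df = spark.read.format("parquet")
  .option("modifiedAfter", "2020-06-01T05:30:00")
  .option("modifiedBefore", "2020-07-01T05:30:00")
  .option("timeZone", "UTC")
  .load("examples/src/main/resources/dir1")

df.show()
```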
\ No newline at end of file diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fd7208615a09f..870ed0aa0daaa 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -135,6 +135,7 @@ The behavior of some SQL functions can be different under ANSI mode (`spark.sql. - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - `element_at`: This function throws `NoSuchElementException` if key does not exist in map. - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - `parse_url`: This function throws `IllegalArgumentException` if an input string is not a valid url. ### SQL Operators diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 2295225387a33..46e740d78bffb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -147,6 +147,22 @@ private static void runGenericFileSourceOptionsExample(SparkSession spark) { // |file1.parquet| // +-------------+ // $example off:load_with_path_glob_filter$ + // $example on:load_with_modified_time_filter$ + Dataset beforeFilterDF = spark.read().format("parquet") + // Only load files modified before 7/1/2020 at 05:30 + .option("modifiedBefore", "2020-07-01T05:30:00") + // Only load files modified after 6/1/2020 at 05:30 + .option("modifiedAfter", "2020-06-01T05:30:00") + // Interpret both times above relative to CST timezone + .option("timeZone", "CST") + .load("examples/src/main/resources/dir1"); + beforeFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + // $example off:load_with_modified_time_filter$ } private static void runBasicDataSourceExample(SparkSession spark) { diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index eecd8c2d84788..8c146ba0c9455 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -67,6 +67,26 @@ def generic_file_source_options_example(spark): # +-------------+ # $example off:load_with_path_glob_filter$ + # $example on:load_with_modified_time_filter$ + # Only load files modified before 07/1/2050 @ 08:30:00 + df = spark.read.load("examples/src/main/resources/dir1", + format="parquet", modifiedBefore="2050-07-01T08:30:00") + df.show() + # +-------------+ + # | file| + # +-------------+ + # |file1.parquet| + # +-------------+ + # Only load files modified after 06/01/2050 @ 08:30:00 + df = spark.read.load("examples/src/main/resources/dir1", + format="parquet", modifiedAfter="2050-06-01T08:30:00") + df.show() + # +-------------+ + # | file| + # +-------------+ + # +-------------+ + # $example off:load_with_modified_time_filter$ + def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 8685cfb5c05f2..86ad5334248bc 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -144,6 +144,14 @@ df <- read.df("examples/src/main/resources/dir1", "parquet", pathGlobFilter = "* # 1 file1.parquet # $example off:load_with_path_glob_filter$ +# $example on:load_with_modified_time_filter$ +beforeDF <- 
read.df("examples/src/main/resources/dir1", "parquet", modifiedBefore= "2020-07-01T05:30:00") +# file +# 1 file1.parquet +afterDF <- read.df("examples/src/main/resources/dir1", "parquet", modifiedAfter = "2020-06-01T05:30:00") +# file +# $example off:load_with_modified_time_filter$ + # $example on:manual_save_options_orc$ df <- read.df("examples/src/main/resources/users.orc", "orc") write.orc(df, "users_with_options.orc", orc.bloom.filter.columns = "favorite_color", orc.dictionary.key.threshold = 1.0, orc.column.encoding.direct = "name") diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 2c7abfcd335d1..90c0eeb5ba888 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -81,6 +81,27 @@ object SQLDataSourceExample { // |file1.parquet| // +-------------+ // $example off:load_with_path_glob_filter$ + // $example on:load_with_modified_time_filter$ + val beforeFilterDF = spark.read.format("parquet") + // Files modified before 07/01/2020 at 05:30 are allowed + .option("modifiedBefore", "2020-07-01T05:30:00") + .load("examples/src/main/resources/dir1"); + beforeFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // |file1.parquet| + // +-------------+ + val afterFilterDF = spark.read.format("parquet") + // Files modified after 06/01/2020 at 05:30 are allowed + .option("modifiedAfter", "2020-06-01T05:30:00") + .load("examples/src/main/resources/dir1"); + afterFilterDF.show(); + // +-------------+ + // | file| + // +-------------+ + // +-------------+ + // $example off:load_with_modified_time_filter$ } private def runBasicDataSourceExample(spark: SparkSession): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 4b9acd0d39f3f..d086c8cdcc589 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.4.0): * {{{ * DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.4.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *DB2IntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.DB2IntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index f1ffc8f0f3dc7..939a07238934b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., 2019-GA-ubuntu-16.04): * {{{ * MSSQLSERVER_DOCKER_IMAGE_NAME=2019-GA-ubuntu-16.04 - * ./build/sbt -Pdocker-integration-tests "testOnly *MsSqlServerIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests 
+ * "testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 6f96ab33d0fee..68f0dbc057c1f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., mysql:5.7.31): * {{{ * MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31 - * ./build/sbt -Pdocker-integration-tests "testOnly *MySQLIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite" * }}} */ @DockerTest diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index fa13100b5fdc8..0347c98bba2c4 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.tags.DockerTest * To run this test suite for a specific version (e.g., postgres:13.0): * {{{ * POSTGRES_DOCKER_IMAGE_NAME=postgres:13.0 - * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresIntegrationSuite" + * ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} */ @DockerTest diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ad1010da5c104..03ebe0299f63f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * The imputation strategy. Currently only "mean" and "median" are supported. * If "mean", then replace missing values using the mean value of the feature. * If "median", then replace missing values using the approximate median value of the feature. + * If "mode", then replace missing using the most frequent value of the feature. * Default: mean * * @group param */ final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + - s"If ${Imputer.median}, then replace missing values using the median value of the feature.", - ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + s"If ${Imputer.median}, then replace missing values using the median value of the feature. " + + s"If ${Imputer.mode}, then replace missing values using the most frequent value of " + + s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies)) /** @group getParam */ def getStrategy: String = $(strategy) @@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp * For example, if the input column is IntegerType (1, 2, 4, null), * the output will be IntegerType (1, 2, 4, 2) after mean imputation. 
* - * Note that the mean/median value is computed after filtering out missing values. + * Note that the mean/median/mode value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) def setOutputCols(value: Array[String]): this.type = set(outputCols, value) /** - * Imputation strategy. Available options are ["mean", "median"]. + * Imputation strategy. Available options are ["mean", "median", "mode"]. * @group setParam */ @Since("2.2.0") @@ -151,39 +153,42 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) val spark = dataset.sparkSession val (inputColumns, _) = getInOutCols() - val cols = inputColumns.map { inputCol => when(col(inputCol).equalTo($(missingValue)), null) .when(col(inputCol).isNaN, null) .otherwise(col(inputCol)) - .cast("double") + .cast(DoubleType) .as(inputCol) } + val numCols = cols.length val results = $(strategy) match { case Imputer.mean => // Function avg will ignore null automatically. // For a column only containing null, avg will return null. val row = dataset.select(cols.map(avg): _*).head() - Array.range(0, inputColumns.length).map { i => - if (row.isNullAt(i)) { - Double.NaN - } else { - row.getDouble(i) - } - } + Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i)) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) - .map { array => - if (array.isEmpty) { - Double.NaN - } else { - array.head - } - } + .map(_.headOption.getOrElse(Double.NaN)) + + case Imputer.mode => + import spark.implicits._ + // If there is more than one mode, choose the smallest one to keep in line + // with sklearn.impute.SimpleImputer (using scipy.stats.mode). + val modes = dataset.select(cols: _*).flatMap { row => + // Ignore null. + Iterator.range(0, numCols) + .flatMap(i => if (row.isNullAt(i)) None else Some((i, row.getDouble(i)))) + }.toDF("index", "value") + .groupBy("index", "value").agg(negate(count(lit(0))).as("negative_count")) + .groupBy("index").agg(min(struct("negative_count", "value")).as("mode")) + .select("index", "mode.value") + .as[(Int, Double)].collect().toMap + Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN)) } val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1) @@ -212,6 +217,10 @@ object Imputer extends DefaultParamsReadable[Imputer] { /** strategy names that Imputer currently supports. 
*/ private[feature] val mean = "mean" private[feature] val median = "median" + private[feature] val mode = "mode" + + /* Set of strategies that Imputer supports */ + private[feature] val supportedStrategies = Array(mean, median, mode) @Since("2.2.0") override def load(path: String): Imputer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index dfee2b4029c8b..30887f55638f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -28,13 +28,14 @@ import org.apache.spark.sql.types._ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), - (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), - (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), - (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) - )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", - "expected_mean_value2", "expected_median_value2") + val df = spark.createDataFrame(Seq( + (0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 3.0, 10.0, 12.0, 4.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 1.0, 14.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1", + "expected_mean_value2", "expected_median_value2", "expected_mode_value2") val imputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("out1", "out2")) @@ -42,23 +43,25 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for Double with default missing Value NaN") { - val df1 = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 11.0, 11.0, 11.0), - (2, 3.0, 3.0, 3.0), - (3, Double.NaN, 5.0, 3.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df1 = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 11.0, 11.0, 11.0, 11.0), + (2, 3.0, 3.0, 3.0, 3.0), + (3, Double.NaN, 5.0, 3.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer1 = new Imputer() .setInputCol("value") .setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer1, df1) - val df2 = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 12.0, 12.0, 12.0), - (2, Double.NaN, 10.0, 12.0), - (3, 14.0, 14.0, 14.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df2 = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 12.0, 12.0, 12.0, 12.0), + (2, Double.NaN, 10.0, 12.0, 4.0), + (3, 14.0, 14.0, 14.0, 14.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer2 = new Imputer() .setInputCol("value") .setOutputCol("out") @@ -66,12 +69,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, 
Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(true, imputer, df) @@ -79,64 +83,69 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Single Column: Imputer should handle NaNs when computing surrogate value," + " if missingValue is not NaN") { - val df = spark.createDataFrame( Seq( - (0, 1.0, 1.0, 1.0), - (1, 3.0, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 1.0) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0, 1.0) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1.0) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer for Float with missing Value -1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer for Float with missing Value -1.0") { - val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq( + (0, 1.0F, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F, 10.0F) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setMissingValue(-1) ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer should impute null as well as 'missingValue'") { - val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) ImputerSuite.iterateStrategyTest(true, imputer, df) } test("Single Column: Imputer should impute null as well as 'missingValue'") { 
- val rawDf = spark.createDataFrame( Seq( - (0, 4.0, 4.0, 4.0), - (1, 10.0, 10.0, 10.0), - (2, 10.0, 10.0, 10.0), - (3, Double.NaN, 8.0, 10.0), - (4, -1.0, 8.0, 10.0) - )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val rawDf = spark.createDataFrame(Seq( + (0, 4.0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0, 10.0), + (4, -1.0, 8.0, 10.0, 10.0) + )).toDF("id", "rawValue", + "expected_mean_value", "expected_median_value", "expected_mode_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCol("value").setOutputCol("out") ImputerSuite.iterateStrategyTest(false, imputer, df) @@ -187,7 +196,7 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, Double.NaN, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) @@ -205,12 +214,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer throws exception when surrogate cannot be computed") { - val df = spark.createDataFrame( Seq( - (0, Double.NaN, 1.0, 1.0), - (1, Double.NaN, 3.0, 3.0), - (2, Double.NaN, Double.NaN, Double.NaN) - )).toDF("id", "value", "expected_mean_value", "expected_median_value") - Seq("mean", "median").foreach { strategy => + val df = spark.createDataFrame(Seq( + (0, Double.NaN, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", + "expected_mean_value", "expected_median_value", "expected_mode_value") + Seq("mean", "median", "mode").foreach { strategy => val imputer = new Imputer().setInputCol("value").setOutputCol("out") .setStrategy(strategy) withClue("Imputer should fail all the values are invalid") { @@ -223,12 +233,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer input & output column validation") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 1.0, 1.0), (1, Double.NaN, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN) )).toDF("id", "value1", "value2", "value3") - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => withClue("Imputer should fail if inputCols and outputCols are different length") { val e: IllegalArgumentException = intercept[IllegalArgumentException] { val imputer = new Imputer().setStrategy(strategy) @@ -306,13 +316,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with default missing value null") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -327,12 +337,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column Imputer for IntegerType with default missing value null") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (null, 5, 3) - 
)).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (null, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -347,13 +358,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Imputer for IntegerType with missing value -1") { - - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value1", "expected_mean_value1", "expected_median_value1") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value1", + "expected_mean_value1", "expected_median_value1", "expected_mode_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -369,12 +380,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Single Column: Imputer for IntegerType with missing value -1") { - val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( - (1, 1, 1), - (11, 11, 11), - (3, 3, 3), - (-1, 5, 3) - )).toDF("value", "expected_mean_value", "expected_median_value") + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)]( + (1, 1, 1, 1), + (11, 11, 11, 11), + (3, 3, 3, 3), + (-1, 5, 3, 1) + )).toDF("value", + "expected_mean_value", "expected_median_value", "expected_mode_value") val imputer = new Imputer() .setInputCol("value") @@ -402,13 +414,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } test("Compare single/multiple column(s) Imputer in pipeline") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 4.0), (1, 11.0, 12.0), (2, 3.0, Double.NaN), (3, Double.NaN, 14.0) )).toDF("id", "value1", "value2") - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => val multiColsImputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("result1", "result2")) @@ -450,11 +462,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { object ImputerSuite { /** - * Imputation strategy. Available options are ["mean", "median"]. - * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + * Imputation strategy. Available options are ["mean", "median", "mode"]. + * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median", + * "expected_mode". 
*/ def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = { - Seq("mean", "median").foreach { strategy => + Seq("mean", "median", "mode").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) val resultDF = model.transform(df) diff --git a/pom.xml b/pom.xml index 3ae2e7420e154..0ab5a8c5b3efa 100644 --- a/pom.xml +++ b/pom.xml @@ -164,7 +164,6 @@ 3.2.2 2.12.10 2.12 - -Ywarn-unused-import 2.0.0 --test @@ -932,7 +931,7 @@ org.scalatest scalatest_${scala.binary.version} - 3.2.0 + 3.2.3 test @@ -956,14 +955,14 @@ org.mockito mockito-core - 3.1.0 + 3.4.6 test org.jmock jmock-junit4 test - 2.8.4 + 2.12.0 org.scalacheck @@ -974,7 +973,7 @@ junit junit - 4.12 + 4.13.1 test @@ -2499,7 +2498,7 @@ net.alchim31.maven scala-maven-plugin - 4.3.0 + 4.4.0 eclipse-add-source @@ -2538,7 +2537,6 @@ -deprecation -feature -explaintypes - ${scalac.arg.unused-imports} -target:jvm-1.8 @@ -2575,7 +2573,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M3 + 3.0.0-M5 @@ -3262,13 +3260,12 @@ - + scala-2.13 2.13.3 2.13 - -Wconf:cat=unused-imports:e diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 55c87fcb3aaa2..05413b7091ad9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,6 +221,7 @@ object SparkBuild extends PomBuild { Seq( "-Xfatal-warnings", "-deprecation", + "-Ywarn-unused-import", "-P:silencer:globalFilters=.*deprecated.*" //regex to catch deprecation warnings and supress them ) } else { @@ -230,6 +231,8 @@ object SparkBuild extends PomBuild { // see `scalac -Wconf:help` for details "-Wconf:cat=deprecation:wv,any:e", // 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately + // TODO(SPARK-33499): Enable this option when Scala 2.12 is no longer supported. + // "-Wunused:imports", "-Wconf:cat=lint-multiarg-infix:wv", "-Wconf:cat=other-nullary-override:wv", "-Wconf:cat=other-match-analysis&site=org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction.catalogFunction:wv", diff --git a/project/build.properties b/project/build.properties index 5ec1d700fd2a8..c92de941c10be 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=1.4.2 +sbt.version=1.4.4 diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 4039698d39958..9c9ff7fa7844b 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -48,7 +48,7 @@ If you want to install extra dependencies for a specific componenet, you can ins pip install pyspark[sql] -For PySpark with a different Hadoop version, you can install it by using ``HADOOP_VERSION`` environment variables as below: +For PySpark with/without a specific Hadoop version, you can install it by using ``HADOOP_VERSION`` environment variables as below: .. code-block:: bash @@ -68,8 +68,13 @@ It is recommended to use ``-v`` option in ``pip`` to track the installation and HADOOP_VERSION=2.7 pip install pyspark -v -Supported versions of Hadoop are ``HADOOP_VERSION=2.7`` and ``HADOOP_VERSION=3.2`` (default). -Note that this installation of PySpark with a different version of Hadoop is experimental. It can change or be removed between minor releases. 
+Supported values in ``HADOOP_VERSION`` are: + +- ``without``: Spark pre-built with user-provided Apache Hadoop +- ``2.7``: Spark pre-built for Apache Hadoop 2.7 +- ``3.2``: Spark pre-built for Apache Hadoop 3.2 and later (default) + +Note that this installation way of PySpark with/without a specific Hadoop version is experimental. It can change or be removed between minor releases. Using Conda diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 9c9e3f4b3c881..1bd5961e0525a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # data via a socket. # scala's mangled names w/ $ in them require special treatment. self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc) + os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \ + str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index eafa5d90f9ff8..fe2e326dff8be 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret): af, socktype, proto, _, sa = res try: sock = socket.socket(af, socktype, proto) - sock.settimeout(15) + sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))) sock.connect(sa) sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))) _do_server_auth(sockfile, auth_secret) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4d898bd5fffa8..82b9a6db1eb92 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1507,7 +1507,8 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has strategy = Param(Params._dummy(), "strategy", "strategy for imputation. If mean, then replace missing values using the mean " "value of the feature. If median, then replace missing values using the " - "median value of the feature.", + "median value of the feature. If mode, then replace missing using the most " + "frequent value of the feature.", typeConverter=TypeConverters.toString) missingValue = Param(Params._dummy(), "missingValue", @@ -1541,7 +1542,7 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable): numeric type. Currently Imputer does not support categorical features and possibly creates incorrect values for a categorical feature. - Note that the mean/median value is computed after filtering out missing values. + Note that the mean/median/mode value is computed after filtering out missing values. All Null values in the input columns are treated as missing, and so are also imputed. For computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a relative error of `0.001`. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2ed991c87f506..bb31e6a3e09f8 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -125,6 +125,12 @@ def option(self, key, value): * ``pathGlobFilter``: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery. + * ``modifiedBefore``: an optional timestamp to only include files with + modification times occurring before the specified time. 
The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + * ``modifiedAfter``: an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -149,6 +155,12 @@ def options(self, **options): * ``pathGlobFilter``: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery. + * ``modifiedBefore``: an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + * ``modifiedAfter``: an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) """ for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -203,7 +215,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, dropFieldIfAllNull=None, encoding=None, locale=None, pathGlobFilter=None, - recursiveFileLookup=None, allowNonNumericNumbers=None): + recursiveFileLookup=None, allowNonNumericNumbers=None, + modifiedBefore=None, modifiedAfter=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -322,6 +335,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, ``+Infinity`` and ``Infinity``. * ``-INF``: for negative infinity, alias ``-Infinity``. * ``NaN``: for other not-a-numbers, like result of division by zero. + modifiedBefore : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- @@ -344,6 +364,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, locale=locale, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter, allowNonNumericNumbers=allowNonNumericNumbers) if isinstance(path, str): path = [path] @@ -410,6 +431,15 @@ def parquet(self, *paths, **options): disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 
2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') @@ -418,13 +448,18 @@ def parquet(self, *paths, **options): """ mergeSchema = options.get('mergeSchema', None) pathGlobFilter = options.get('pathGlobFilter', None) + modifiedBefore = options.get('modifiedBefore', None) + modifiedAfter = options.get('modifiedAfter', None) recursiveFileLookup = options.get('recursiveFileLookup', None) self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, - recursiveFileLookup=recursiveFileLookup) + recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter) + return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths))) def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, - recursiveFileLookup=None): + recursiveFileLookup=None, modifiedBefore=None, + modifiedAfter=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there @@ -453,6 +488,15 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, recursively scan a directory for files. Using this option disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.text('python/test_support/sql/text-test.txt') @@ -464,7 +508,9 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, """ self._set_opts( wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter, - recursiveFileLookup=recursiveFileLookup) + recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter) + if isinstance(paths, str): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -476,7 +522,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, - pathGlobFilter=None, recursiveFileLookup=None): + pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -631,6 +677,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non recursively scan a directory for files. Using this option disables `partition discovery `_. 
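As a minimal, hedged sketch of how the modifiedBefore/modifiedAfter options documented in this file are meant to be used: the same option keys are accepted by the Scala DataFrameReader, so the example below is written in Scala; the path, timestamps, and local master are placeholders, not part of this change.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Batch-only file filters: keep only files whose modification time falls inside the window.
    val df = spark.read
      .option("modifiedAfter", "2020-06-01T13:00:00")
      .option("modifiedBefore", "2020-07-01T13:00:00")
      .parquet("/path/to/parquet/dir")  // placeholder path
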
# noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore (batch only) : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter (batch only) : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.csv('python/test_support/sql/ages.csv') @@ -652,7 +707,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep, - pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) + pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter) if isinstance(path, str): path = [path] if type(path) == list: @@ -679,7 +735,8 @@ def func(iterator): else: raise TypeError("path can be only string, list or RDD") - def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None): + def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None, + modifiedBefore=None, modifiedAfter=None): """Loads ORC files, returning the result as a :class:`DataFrame`. .. versionadded:: 1.5.0 @@ -701,6 +758,15 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N disables `partition discovery `_. # noqa + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedBefore : an optional timestamp to only include files with + modification times occurring before the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00) + modifiedAfter : an optional timestamp to only include files with + modification times occurring after the specified time. The provided timestamp + must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 
2020-06-01T13:00:00) + Examples -------- >>> df = spark.read.orc('python/test_support/sql/orc_partitioned') @@ -708,6 +774,7 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N [('a', 'bigint'), ('b', 'int'), ('c', 'int')] """ self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, + modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter, recursiveFileLookup=recursiveFileLookup) if isinstance(path, str): path = [path] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index b42bdb9816600..22002bb32004d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.EnumSet -import java.util.concurrent.atomic.{AtomicBoolean} +import java.util.concurrent.atomic.AtomicBoolean import javax.servlet.DispatcherType import scala.concurrent.{ExecutionContext, Future} @@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.UI._ import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{RpcUtils, ThreadUtils} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Abstract Yarn scheduler backend that contains common logic @@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend( /** Attempt ID. This is unset for client-mode schedulers */ private var attemptId: Option[ApplicationAttemptId] = None + private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master + + private val minMergersThresholdRatio = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO) + + private val minMergersStaticThreshold = + conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD) + + private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS) + + private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0) + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + override def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = { + // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage, + // TODO we can use better heuristics to calculate numMergersDesired for a stage. 
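+    // Illustrative arithmetic only (hypothetical numbers; static threshold default assumed):
+    // for a stage with 1000 partitions, 4 tasks per executor and a dynamic-allocation cap of
+    // 200 executors, numMergersDesired = min(max(1, ceil(1000 / 4)), 200) = 200; with the
+    // default mergersMinThresholdRatio of 0.05 and an assumed static threshold of 5,
+    // minMergersNeeded = max(5, floor(200 * 0.05)) = 10, so with fewer than 10 available
+    // mergers this method returns an empty list and the stage falls back to regular shuffle.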
+ val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) { + maxNumExecutors + } else { + numExecutors + } + val tasksPerExecutor = sc.resourceProfileManager + .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf) + val numMergersDesired = math.min( + math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors) + val minMergersNeeded = math.max(minMergersStaticThreshold, + math.floor(numMergersDesired * minMergersThresholdRatio).toInt) + + // Request for numMergersDesired shuffle mergers to BlockManagerMasterEndpoint + // and if it's less than minMergersNeeded, we disable push based shuffle. + val mergerLocations = blockManagerMaster + .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes()) + if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) { + Seq.empty[BlockManagerId] + } else { + logDebug(s"The number of shuffle mergers desired ${numMergersDesired}" + + s" and available locations are ${mergerLocations.length}") + mergerLocations + } + } + /** * Add filters to the SparkUI. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala new file mode 100644 index 0000000000000..c680502cb328f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, GroupingID} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AbstractDataType, DataType, StructType} + +/** + * Object for grouping all error messages of the query compilation. + * Currently it includes all AnalysisExcpetions created and thrown directly in + * org.apache.spark.sql.catalyst.analysis.Analyzer. 
+ */ +object QueryCompilationErrors { + def groupingIDMismatchError(groupingID: GroupingID, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Columns of grouping_id (${groupingID.groupByExprs.mkString(",")}) " + + s"does not match grouping columns (${groupByExprs.mkString(",")})") + } + + def groupingColInvalidError(groupingCol: Expression, groupByExprs: Seq[Expression]): Throwable = { + new AnalysisException( + s"Column of grouping ($groupingCol) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + + def groupingSizeTooLargeError(sizeLimit: Int): Throwable = { + new AnalysisException( + s"Grouping sets size cannot be greater than $sizeLimit") + } + + def unorderablePivotColError(pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot column '$pivotCol'. Pivot columns must be comparable." + ) + } + + def nonLiteralPivotValError(pivotVal: Expression): Throwable = { + new AnalysisException( + s"Literal expressions required for pivot values, found '$pivotVal'") + } + + def pivotValDataTypeMismatchError(pivotVal: Expression, pivotCol: Expression): Throwable = { + new AnalysisException( + s"Invalid pivot value '$pivotVal': " + + s"value data type ${pivotVal.dataType.simpleString} does not match " + + s"pivot column data type ${pivotCol.dataType.catalogString}") + } + + def unsupportedIfNotExistsError(tableName: String): Throwable = { + new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: $tableName") + } + + def nonPartitionColError(partitionName: String): Throwable = { + new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + + def addStaticValToUnknownColError(staticName: String): Throwable = { + new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + } + + def unknownStaticPartitionColError(name: String): Throwable = { + new AnalysisException(s"Unknown static partition column: $name") + } + + def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = { + new AnalysisException( + "Generators are not supported when it's nested in " + + "expressions, but got: " + toPrettySQL(trimmedNestedGenerator)) + } + + def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + new AnalysisException( + s"Only one generator allowed per $clause clause but found " + + generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + } + + def generatorOutsideSelectError(plan: LogicalPlan): Throwable = { + new AnalysisException( + "Generators are not supported outside the SELECT clause, but " + + "got: " + plan.simpleString(SQLConf.get.maxToStringFields)) + } + + def legacyStoreAssignmentPolicyError(): Throwable = { + val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key + new AnalysisException( + "LEGACY store assignment policy is disallowed in Spark data source V2. " + + s"Please set the configuration $configKey to other values.") + } + + def unresolvedUsingColForJoinError( + colName: String, plan: LogicalPlan, side: String): Throwable = { + new AnalysisException( + s"USING column `$colName` cannot be resolved on the $side " + + s"side of the join. 
The $side-side columns: [${plan.output.map(_.name).mkString(", ")}]") + } + + def dataTypeMismatchForDeserializerError( + dataType: DataType, desiredType: String): Throwable = { + val quantifier = if (desiredType.equals("array")) "an" else "a" + new AnalysisException( + s"need $quantifier $desiredType field but got " + dataType.catalogString) + } + + def fieldNumberMismatchForDeserializerError( + schema: StructType, maxOrdinal: Int): Throwable = { + new AnalysisException( + s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.") + } + + def upCastFailureError( + fromStr: String, from: Expression, to: DataType, walkedTypePath: Seq[String]): Throwable = { + new AnalysisException( + s"Cannot up cast $fromStr from " + + s"${from.dataType.catalogString} to ${to.catalogString}.\n" + + s"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + def unsupportedAbstractDataTypeForUpCastError(gotType: AbstractDataType): Throwable = { + new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $gotType") + } + + def outerScopeFailureForNewInstanceError(className: String): Throwable = { + new AnalysisException( + s"Unable to generate an encoder for inner class `$className` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + + def referenceColNotFoundForAlterTableChangesError( + after: TableChange.After, parentName: String): Throwable = { + new AnalysisException( + s"Couldn't find the reference column for $after at $parentName") + } + +} + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8d95d8cf49d45..837686420375a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy} @@ -448,9 +449,7 @@ class Analyzer(override val catalogManager: CatalogManager) e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${groupByExprs.mkString(",")})") + throw QueryCompilationErrors.groupingIDMismatchError(e, groupByExprs) } case e @ Grouping(col: Expression) => val idx = groupByExprs.indexWhere(_.semanticEquals(col)) @@ -458,8 +457,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1L)), ByteType), toPrettySQL(e))() } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${groupByExprs.mkString(",")}") + throw QueryCompilationErrors.groupingColInvalidError(col, groupByExprs) } } } @@ -575,8 +573,7 @@ class Analyzer(override val catalogManager: CatalogManager) val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs) if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) { - throw new AnalysisException( - s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}") + throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8) } // Expand works by setting grouping expressions to null as determined by the @@ -712,8 +709,7 @@ class Analyzer(override val catalogManager: CatalogManager) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { - throw new AnalysisException( - s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) @@ -724,13 +720,10 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => value.foldable } if (!foldable) { - throw new AnalysisException( - s"Literal expressions required for pivot values, found '$value'") + throw QueryCompilationErrors.nonLiteralPivotValError(value) } if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { - throw new AnalysisException(s"Invalid pivot value '$value': " + - s"value data type ${value.dataType.simpleString} does not match " + - s"pivot column data type ${pivotColumn.dataType.catalogString}") + throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn) } Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } @@ -868,9 +861,9 @@ class Analyzer(override val catalogManager: CatalogManager) }.getOrElse(write) case _ => write } - case u @ UnresolvedTable(ident) => + case u @ UnresolvedTable(ident, cmd) => lookupTempView(ident).foreach { _ => - u.failAnalysis(s"${ident.quoted} is a temp view not table.") + u.failAnalysis(s"${ident.quoted} is a temp view. '$cmd' expects a table") } u case u @ UnresolvedTableOrView(ident, allowTempView) => @@ -957,7 +950,7 @@ class Analyzer(override val catalogManager: CatalogManager) SubqueryAlias(catalog.get.name +: ident.namespace :+ ident.name, relation) }.getOrElse(u) - case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) => + case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident), _) => CatalogV2Util.loadTable(catalog, ident) .map(ResolvedTable(catalog.asTableCatalog, ident, _)) .getOrElse(u) @@ -1084,11 +1077,11 @@ class Analyzer(override val catalogManager: CatalogManager) lookupRelation(u.multipartIdentifier, u.options, u.isStreaming) .map(resolveViews).getOrElse(u) - case u @ UnresolvedTable(identifier) => + case u @ UnresolvedTable(identifier, cmd) => lookupTableOrView(identifier).map { case v: ResolvedView => val viewStr = if (v.isTemp) "temp view" else "view" - u.failAnalysis(s"${v.identifier.quoted} is a $viewStr not table.") + u.failAnalysis(s"${v.identifier.quoted} is a $viewStr. 
'$cmd' expects a table.'") case table => table }.getOrElse(u) @@ -1167,8 +1160,7 @@ class Analyzer(override val catalogManager: CatalogManager) case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved => // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { - throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${r.table.name}") + throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) } val partCols = partitionColumnNames(r.table) @@ -1205,8 +1197,7 @@ class Analyzer(override val catalogManager: CatalogManager) partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { case Some(_) => case None => - throw new AnalysisException( - s"PARTITION clause cannot contain a non-partition column name: $partitionName") + throw QueryCompilationErrors.nonPartitionColError(partitionName) } } } @@ -1228,8 +1219,7 @@ class Analyzer(override val catalogManager: CatalogManager) case Some(attr) => attr.name -> staticName case _ => - throw new AnalysisException( - s"Cannot add static value for unknown column: $staticName") + throw QueryCompilationErrors.addStaticValToUnknownColError(staticName) }).toMap val queryColumns = query.output.iterator @@ -1271,7 +1261,7 @@ class Analyzer(override val catalogManager: CatalogManager) // an UnresolvedAttribute. EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) case None => - throw new AnalysisException(s"Unknown static partition column: $name") + throw QueryCompilationErrors.unknownStaticPartitionColError(name) } }.reduce(And) } @@ -2483,23 +2473,19 @@ class Analyzer(override val catalogManager: CatalogManager) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Project(projectList, _) if projectList.count(hasGenerator) > 1 => val generators = projectList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per select clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "select") case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get - throw new AnalysisException("Generators are not supported when it's nested in " + - "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw new AnalysisException("Only one generator allowed per aggregate clause but found " + - generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) + throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") case agg @ Aggregate(groupList, aggList, child) if aggList.forall { case AliasedGenerator(_, _, _) => true @@ -2582,8 +2568,7 @@ class Analyzer(override val catalogManager: CatalogManager) case g: Generate => g case p if p.expressions.exists(hasGenerator) 
=> - throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + - "got: " + p.simpleString(SQLConf.get.maxToStringFields)) + throw QueryCompilationErrors.generatorOutsideSelectError(p) } } @@ -3122,10 +3107,7 @@ class Analyzer(override val catalogManager: CatalogManager) private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { - val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key - throw new AnalysisException(s""" - |"LEGACY" store assignment policy is disallowed in Spark data source V2. - |Please set the configuration $configKey to other values.""".stripMargin) + throw QueryCompilationErrors.legacyStoreAssignmentPolicyError() } } @@ -3138,14 +3120,12 @@ class Analyzer(override val catalogManager: CatalogManager) hint: JoinHint) = { val leftKeys = joinNames.map { keyName => left.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + - s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, left, "left") } } val rightKeys = joinNames.map { keyName => right.output.find(attr => resolver(attr.name, keyName)).getOrElse { - throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " + - s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]") + throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, right, "right") } } val joinPairs = leftKeys.zip(rightKeys) @@ -3208,7 +3188,8 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need an array field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, + "array") } case u: UnresolvedCatalystToExternalMap if u.child.resolved => u.child.dataType match { @@ -3218,7 +3199,7 @@ class Analyzer(override val catalogManager: CatalogManager) ExtractValue(child, fieldName, resolver) } case other => - throw new AnalysisException("need a map field but got " + other.catalogString) + throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, "map") } } validateNestedTupleFields(result) @@ -3227,8 +3208,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + - ", but failed as the number of fields does not line up.") + throw QueryCompilationErrors.fieldNumberMismatchForDeserializerError(schema, maxOrdinal) } /** @@ -3287,10 +3267,7 @@ class Analyzer(override val catalogManager: CatalogManager) case n: NewInstance if n.childrenResolved && !n.resolved => val outer = OuterScopes.getOuterScope(n.cls) if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + - "access to the scope that this class was defined in.\n" + - "Try moving this class out of its parent class.") + throw QueryCompilationErrors.outerScopeFailureForNewInstanceError(n.cls.getName) } n.copy(outerPointer = Some(outer)) } @@ -3306,11 +3283,7 @@ class Analyzer(override val catalogManager: CatalogManager) case l: LambdaVariable => "array element" case e 
=> e.sql } - throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.catalogString} to ${to.catalogString}.\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") + throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath) } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { @@ -3321,8 +3294,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UpCast(child, _, _) if !child.resolved => u case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => - throw new AnalysisException( - s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") + throw QueryCompilationErrors.unsupportedAbstractDataTypeForUpCastError(target) case UpCast(child, target, walkedTypePath) if target == DecimalType && child.dataType.isInstanceOf[DecimalType] => @@ -3501,8 +3473,8 @@ class Analyzer(override val catalogManager: CatalogManager) case Some(colName) => ColumnPosition.after(colName) case None => - throw new AnalysisException("Couldn't find the reference column for " + - s"$after at $parentName") + throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after, + parentName) } case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 452ba80b23441..9998035d65c3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -98,7 +98,7 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") case u: UnresolvedTable => - u.failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}") + u.failAnalysis(s"Table not found for '${u.commandName}': ${u.multipartIdentifier.quoted}") case u: UnresolvedTableOrView => u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 508239077a70e..6fb9bed9625d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -391,6 +391,7 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[CurrentTimeZone]("current_timezone"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 5e19a32968992..531d40f431dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -24,6 +24,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, Alte import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec /** * Resolve [[UnresolvedPartitionSpec]] to [[ResolvedPartitionSpec]] in partition related commands. @@ -33,32 +34,38 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case r @ AlterTableAddPartition( ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _) => - r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema())) + r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema())) case r @ AlterTableDropPartition( ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _, _, _) => - r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema())) + r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema())) } private def resolvePartitionSpecs( - partSpecs: Seq[PartitionSpec], partSchema: StructType): Seq[ResolvedPartitionSpec] = + tableName: String, + partSpecs: Seq[PartitionSpec], + partSchema: StructType): Seq[ResolvedPartitionSpec] = partSpecs.map { case unresolvedPartSpec: UnresolvedPartitionSpec => ResolvedPartitionSpec( - convertToPartIdent(unresolvedPartSpec.spec, partSchema), unresolvedPartSpec.location) + convertToPartIdent(tableName, unresolvedPartSpec.spec, partSchema), + unresolvedPartSpec.location) case resolvedPartitionSpec: ResolvedPartitionSpec => resolvedPartitionSpec } private def convertToPartIdent( - partSpec: TablePartitionSpec, partSchema: StructType): InternalRow = { - val conflictKeys = partSpec.keys.toSeq.diff(partSchema.map(_.name)) - if (conflictKeys.nonEmpty) { - throw new AnalysisException(s"Partition key ${conflictKeys.mkString(",")} not exists") - } + tableName: String, + partitionSpec: TablePartitionSpec, + partSchema: StructType): InternalRow = { + val normalizedSpec = normalizePartitionSpec( + partitionSpec, + partSchema.map(_.name), + tableName, + conf.resolver) val partValues = partSchema.map { part => - val partValue = partSpec.get(part.name).orNull + val partValue = normalizedSpec.get(part.name).orNull if (partValue == null) { null } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index 98bd84fb94bd6..0e883a88f2691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -37,7 +37,9 @@ case class UnresolvedNamespace(multipartIdentifier: Seq[String]) extends LeafNod * Holds the name of a table that has yet to be looked up in a catalog. It will be resolved to * [[ResolvedTable]] during analysis. 
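+ * The `commandName` added below is used only for error reporting; e.g. (illustrative),
+ * running LOAD DATA against a temp view `v` now fails with
+ * "v is a temp view. 'LOAD DATA' expects a table" instead of "v is a temp view not table.".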
*/ -case class UnresolvedTable(multipartIdentifier: Seq[String]) extends LeafNode { +case class UnresolvedTable( + multipartIdentifier: Seq[String], + commandName: String) extends LeafNode { override lazy val resolved: Boolean = false override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index 3189d81289903..ff9c4cf3147d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.IdentityHashMap -import scala.collection.JavaConverters._ - import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} @@ -98,7 +96,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this) proxyExpressionCurrentId += 1 - proxyMap.putAll(e.map(_ -> proxy).toMap.asJava) + // We leverage `IdentityHashMap` so we compare expression keys by reference here. + // So for example if there are one group of common exprs like Seq(common expr 1, + // common expr2, ..., common expr n), we will insert into `proxyMap` some key/value + // pairs like Map(common expr 1 -> proxy(common expr 1), ..., + // common expr n -> proxy(common expr 1)). + e.map(proxyMap.put(_, proxy)) } // Only adding proxy if we find subexpressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 97aacb3f7530c..9953b780ceace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -73,6 +73,21 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { } } +@ExpressionDescription( + usage = "_FUNC_() - Returns the current session local timezone.", + examples = """ + Examples: + > SELECT _FUNC_(); + Asia/Shanghai + """, + group = "datetime_funcs", + since = "3.1.0") +case class CurrentTimeZone() extends LeafExpression with Unevaluable { + override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "current_timezone" +} + /** * Returns the current date at the start of query evaluation. * There is no code generation since this expression should get constant folded by the optimizer. 
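A minimal usage sketch for the new current_timezone SQL function added in this patch, assuming an active local SparkSession named `spark` (the builder call and time zone value are placeholders); the result simply echoes spark.sql.session.timeZone.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
    spark.sql("SELECT current_timezone()").show(false)
    // prints "Asia/Shanghai", i.e. the current session-local time zone
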
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 16e22940495f1..9f92181b34df1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1357,8 +1357,9 @@ object ParseUrl { 1 """, since = "2.0.0") -case class ParseUrl(children: Seq[Expression]) +case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression with ExpectsInputTypes with CodegenFallback { + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) @@ -1404,7 +1405,9 @@ case class ParseUrl(children: Seq[Expression]) try { new URI(url.toString) } catch { - case e: URISyntaxException => null + case e: URISyntaxException if failOnError => + throw new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) + case _: URISyntaxException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c4b9936fa4c4f..9eee7c2b914a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -811,9 +811,12 @@ object CollapseRepartition extends Rule[LogicalPlan] { */ object OptimizeWindowFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), spec) - if spec.orderSpec.nonEmpty && - spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame].frameType == RowFrame => + case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), + WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame)) + if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame && + frameSpecification.lower == UnboundedPreceding && + (frameSpecification.upper == UnboundedFollowing || + frameSpecification.upper == CurrentRow) => we.copy(windowFunction = NthValue(first.child, Literal(1), first.ignoreNulls)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 9aa7e3201ab1b..1f2389176d1e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -75,6 +76,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { val timeExpr = CurrentTimestamp() val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long] val currentTime = Literal.create(timestamp, timeExpr.dataType) + val timezone = Literal.create(SQLConf.get.sessionLocalTimeZone, StringType) 
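+    // Like `currentTime` above, the session-local time zone is captured once per query so that
+    // every CurrentTimeZone() occurrence below is replaced with the same literal value.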
plan transformAllExpressions { case currentDate @ CurrentDate(Some(timeZoneId)) => @@ -84,6 +86,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { DateType) }) case CurrentTimestamp() | Now() => currentTime + case CurrentTimeZone() => timezone } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 79857a63a69b5..ea4baafbacede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1414,8 +1414,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg // So we use LikeAll or NotLikeAll instead. val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) ctx.NOT match { - case null => LikeAll(e, patterns) - case _ => NotLikeAll(e, patterns) + case null => LikeAll(e, patterns.toSeq) + case _ => NotLikeAll(e, patterns.toSeq) } } else { getLikeQuantifierExprs(ctx.expression).reduceLeft(And) @@ -3303,7 +3303,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitLoadData(ctx: LoadDataContext): LogicalPlan = withOrigin(ctx) { LoadData( - child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier), "LOAD DATA"), path = string(ctx.path), isLocal = ctx.LOCAL != null, isOverwrite = ctx.OVERWRITE != null, @@ -3449,7 +3449,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg UnresolvedPartitionSpec(spec, location) } AlterTableAddPartition( - UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + UnresolvedTable( + visitMultipartIdentifier(ctx.multipartIdentifier), + "ALTER TABLE ... ADD PARTITION ..."), specsAndLocs.toSeq, ctx.EXISTS != null) } @@ -3491,7 +3493,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val partSpecs = ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec) .map(spec => UnresolvedPartitionSpec(spec)) AlterTableDropPartition( - UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)), + UnresolvedTable( + visitMultipartIdentifier(ctx.multipartIdentifier), + "ALTER TABLE ... DROP PARTITION ..."), partSpecs.toSeq, ifExists = ctx.EXISTS != null, purge = ctx.PURGE != null, @@ -3720,6 +3724,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case _ => string(ctx.STRING) } val nameParts = visitMultipartIdentifier(ctx.multipartIdentifier) - CommentOnTable(UnresolvedTable(nameParts), comment) + CommentOnTable(UnresolvedTable(nameParts, "COMMENT ON TABLE"), comment) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fcf222c8fdab0..ef974dc176e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2913,18 +2913,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val TRUNCATE_TRASH_ENABLED = - buildConf("spark.sql.truncate.trash.enabled") - .doc("This configuration decides when truncating table, whether data files will be moved " + - "to trash directory or deleted permanently. 
The trash retention time is controlled by " + - "'fs.trash.interval', and in default, the server side configuration value takes " + - "precedence over the client-side one. Note that if 'fs.trash.interval' is non-positive, " + - "this will be a no-op and log a warning message. If the data fails to be moved to " + - "trash, Spark will turn to delete it permanently.") - .version("3.1.0") - .booleanConf - .createWithDefault(false) - val DISABLED_JDBC_CONN_PROVIDER_LIST = buildConf("spark.sql.sources.disabledJdbcConnProviderList") .internal() @@ -3577,8 +3565,6 @@ class SQLConf extends Serializable with Logging { def legacyPathOptionBehavior: Boolean = getConf(SQLConf.LEGACY_PATH_OPTION_BEHAVIOR) - def truncateTrashEnabled: Boolean = getConf(SQLConf.TRUNCATE_TRASH_ENABLED) - def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala new file mode 100644 index 0000000000000..586aa6c59164f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver + +object PartitioningUtils { + /** + * Normalize the column names in partition specification, w.r.t. the real partition column names + * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a + * partition column named `month`, and it's case insensitive, we will normalize `monTh` to + * `month`. 
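An illustrative call of the helper described above (the table name and partition spec are hypothetical); Spark's existing case-insensitive resolver is assumed to be in scope:

  import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
  import org.apache.spark.sql.util.PartitioningUtils

  val normalized = PartitioningUtils.normalizePartitionSpec(
    Map("monTh" -> Some("5")), Seq("month", "year"), "sales", caseInsensitiveResolution)
  // normalized == Map("month" -> Some("5")); an unknown key raises AnalysisException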
+ */ + def normalizePartitionSpec[T]( + partitionSpec: Map[String, T], + partColNames: Seq[String], + tblName: String, + resolver: Resolver): Map[String, T] = { + val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => + val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { + throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") + } + normalizedKey -> value + } + + SchemaUtils.checkColumnNameDuplication( + normalizedPartSpec.map(_._1), "in the partition schema", resolver) + + normalizedPartSpec.toMap + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index a1b6cec24f23f..730574a4b9846 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -943,6 +943,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) } + test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val msg = intercept[IllegalArgumentException] { + evaluateWithoutCodegen( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b&params2=x", "HOST"))) + }.getMessage + assert(msg.contains("Find an invalid url string")) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation( + ParseUrl(Seq("https://a.b.c/index.php?params1=a|b&params2=x", "HOST")), null) + } + } + test("Sentences") { val nullString = Literal.create(null, StringType) checkEvaluation(Sentences(nullString, nullString, nullString), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala index 64b619ca7766b..f8dca266a62d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala @@ -95,4 +95,26 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite { }) assert(proxys.isEmpty) } + + test("SubExprEvaluationRuntime should wrap semantically equal exprs") { + val runtime = new SubExprEvaluationRuntime(1) + + val one = Literal(1) + val two = Literal(2) + def mul: (Literal, Literal) => Expression = + (left: Literal, right: Literal) => Multiply(left, right) + + val mul2_1 = Multiply(mul(one, two), mul(one, two)) + val mul2_2 = Multiply(mul(one, two), mul(one, two)) + + val sqrt = Sqrt(mul2_1) + val sum = Add(mul2_2, sqrt) + val proxyExpressions = runtime.proxyExpressions(Seq(sum)) + val proxys = proxyExpressions.flatMap(_.collect { + case p: ExpressionProxy => p + }) + // ( (one * two) * (one * two) ) + assert(proxys.size == 2) + assert(proxys.forall(_.child.semanticEquals(mul2_1))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala index db0399d2a73ee..82d6757407b51 100644 ---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.catalyst.optimizer import java.time.ZoneId import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -67,4 +69,16 @@ class ComputeCurrentTimeSuite extends PlanTest { assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } + + test("SPARK-33469: Add current_timezone function") { + val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val lits = new scala.collection.mutable.ArrayBuffer[String] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[UTF8String].toString + e + } + assert(lits.size == 1) + assert(lits.head == SQLConf.get.sessionLocalTimeZone) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala index 389aaeafe655f..cf850bbe21ce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala @@ -36,7 +36,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { val b = testRelation.output(1) val c = testRelation.output(2) - test("replace first(col) by nth_value(col, 1)") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND CURRENT ROW") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -52,7 +52,34 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == correctAnswer) } - test("can't replace first(col) by nth_value(col, 1) if the window frame type is range") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + val correctAnswer = testRelation.select( + WindowExpression( + NthValue(a, Literal(1), false), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == correctAnswer) + } + + test("can't replace first by nth_value if frame is not suitable") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, Literal(1), 
CurrentRow)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == inputPlan) + } + + test("can't replace first by nth_value if the window frame type is range") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -63,7 +90,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == inputPlan) } - test("can't replace first(col) by nth_value(col, 1) if the window frame isn't ordered") { + test("can't replace first by nth_value if the window frame isn't ordered") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index f93c0dcf59f4c..bd28484b23f46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1555,15 +1555,15 @@ class DDLParserSuite extends AnalysisTest { test("LOAD DATA INTO table") { comparePlans( parsePlan("LOAD DATA INPATH 'filepath' INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", false, false, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", false, false, None)) comparePlans( parsePlan("LOAD DATA LOCAL INPATH 'filepath' INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, false, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, false, None)) comparePlans( parsePlan("LOAD DATA LOCAL INPATH 'filepath' OVERWRITE INTO TABLE a.b.c"), - LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, true, None)) + LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, true, None)) comparePlans( parsePlan( @@ -1572,7 +1572,7 @@ class DDLParserSuite extends AnalysisTest { |PARTITION(ds='2017-06-10') """.stripMargin), LoadData( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, true, @@ -1674,13 +1674,13 @@ class DDLParserSuite extends AnalysisTest { val parsed2 = parsePlan(sql2) val expected1 = AlterTableAddPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."), Seq( UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"), None)), ifNotExists = true) val expected2 = AlterTableAddPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."), Seq(UnresolvedPartitionSpec(Map("dt" -> "2008-08-08"), Some("loc"))), ifNotExists = false) @@ -1747,7 +1747,7 @@ class DDLParserSuite extends AnalysisTest { assertUnsupported(sql2_view) val expected1_table = AlterTableDropPartition( - UnresolvedTable(Seq("table_name")), + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... 
DROP PARTITION ..."), Seq( UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us")), UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"))), @@ -1763,7 +1763,7 @@ class DDLParserSuite extends AnalysisTest { val sql3_table = "ALTER TABLE a.b.c DROP IF EXISTS PARTITION (ds='2017-06-10')" val expected3_table = AlterTableDropPartition( - UnresolvedTable(Seq("a", "b", "c")), + UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... DROP PARTITION ..."), Seq(UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"))), ifExists = true, purge = false, @@ -2174,7 +2174,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan("COMMENT ON TABLE a.b.c IS 'xYz'"), - CommentOnTable(UnresolvedTable(Seq("a", "b", "c")), "xYz")) + CommentOnTable(UnresolvedTable(Seq("a", "b", "c"), "COMMENT ON TABLE"), "xYz")) } // TODO: ignored by SPARK-31707, restore the test after create table syntax unification diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala index 1c96bdf3afa20..23987e909aa70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala @@ -92,4 +92,8 @@ class InMemoryPartitionTable( override def partitionExists(ident: InternalRow): Boolean = memoryTablePartitions.containsKey(ident) + + override protected def addPartitionKey(key: Seq[Any]): Unit = { + memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3b47271a114e2..c93053abc550a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -160,12 +160,15 @@ class InMemoryTable( } } + protected def addPartitionKey(key: Seq[Any]): Unit = {} + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => val key = getKey(row) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + addPartitionKey(key) }) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 276d5d29bfa2c..b26bc6441b6cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -493,6 +493,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    * <li>`pathGlobFilter`: an optional glob pattern to only include files with paths matching
    * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
    * It does not change the behavior of partition discovery.</li>
+   * <li>`modifiedBefore` (batch only): an optional timestamp to only include files with
+   * modification times occurring before the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
+   * <li>`modifiedAfter` (batch only): an optional timestamp to only include files with
+   * modification times occurring after the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
    * <li>`recursiveFileLookup`: recursively scan a directory for files. Using this option
    * disables partition discovery</li>
    * <li>`allowNonNumericNumbers` (default `true`): allows JSON parser to recognize set of
@@ -750,6 +756,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`pathGlobFilter`: an optional glob pattern to only include files with paths matching
    * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
    * It does not change the behavior of partition discovery.</li>
+   * <li>`modifiedBefore` (batch only): an optional timestamp to only include files with
+   * modification times occurring before the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
+   * <li>`modifiedAfter` (batch only): an optional timestamp to only include files with
+   * modification times occurring after the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
    * <li>`recursiveFileLookup`: recursively scan a directory for files. Using this option
    * disables partition discovery</li>
    *
@@ -781,6 +793,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`pathGlobFilter`: an optional glob pattern to only include files with paths matching
    * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
    * It does not change the behavior of partition discovery.</li>
+   * <li>`modifiedBefore` (batch only): an optional timestamp to only include files with
+   * modification times occurring before the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
+   * <li>`modifiedAfter` (batch only): an optional timestamp to only include files with
+   * modification times occurring after the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
    * <li>`recursiveFileLookup`: recursively scan a directory for files. Using this option
    * disables partition discovery</li>
    *
@@ -814,6 +832,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`pathGlobFilter`: an optional glob pattern to only include files with paths matching
    * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
    * It does not change the behavior of partition discovery.</li>
+   * <li>`modifiedBefore` (batch only): an optional timestamp to only include files with
+   * modification times occurring before the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
+   * <li>`modifiedAfter` (batch only): an optional timestamp to only include files with
+   * modification times occurring after the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
    * <li>`recursiveFileLookup`: recursively scan a directory for files. Using this option
    * disables partition discovery</li>
    *
@@ -880,6 +904,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`pathGlobFilter`: an optional glob pattern to only include files with paths matching
    * the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
    * It does not change the behavior of partition discovery.</li>
+   * <li>`modifiedBefore` (batch only): an optional timestamp to only include files with
+   * modification times occurring before the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
+   * <li>`modifiedAfter` (batch only): an optional timestamp to only include files with
+   * modification times occurring after the specified Time. The provided timestamp
+   * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)</li>
    * <li>`recursiveFileLookup`: recursively scan a directory for files. Using this option
    * disables partition discovery</li>
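A usage sketch for the two new reader options documented above (the path is hypothetical); both are batch-only and compose with `pathGlobFilter`:

  val df = spark.read
    .option("modifiedAfter", "2020-06-01T00:00:00")
    .option("modifiedBefore", "2020-07-01T00:00:00")
    .option("pathGlobFilter", "*.csv")
    .csv("/data/events")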
  • * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index fc62dce5002b1..0b265bfb63e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, Unresol import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.util.PartitioningUtils /** * Analyzes a given set of partitions to generate per-partition statistics, which will be used in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index d550fe270c753..27ad62026c9b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -39,11 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog} import org.apache.spark.sql.connector.catalog.SupportsNamespaces._ -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} // Note: The definition of these commands are based on the ones described in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 206f952fed0ca..bd238948aab02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier, CaseInsensitiveMap} -import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -47,8 +47,8 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 import 
org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.sql.util.SchemaUtils -import org.apache.spark.util.Utils /** * A command to create a table with the same definition of the given existing table. @@ -490,7 +490,6 @@ case class TruncateTableCommand( } val hadoopConf = spark.sessionState.newHadoopConf() val ignorePermissionAcl = SQLConf.get.truncateTableIgnorePermissionAcl - val isTrashEnabled = SQLConf.get.truncateTrashEnabled locations.foreach { location => if (location.isDefined) { val path = new Path(location.get) @@ -515,7 +514,7 @@ case class TruncateTableCommand( } } - Utils.moveToTrashOrDelete(fs, path, isTrashEnabled, hadoopConf) + fs.delete(path, true) // We should keep original permission/acl of the path. // For owner/group, only super-user can set it, for example on HDFS. Because diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index fed9614347f6a..5b0d0606da093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -57,13 +57,10 @@ abstract class PartitioningAwareFileIndex( protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] private val caseInsensitiveMap = CaseInsensitiveMap(parameters) + private val pathFilters = PathFilterFactory.create(caseInsensitiveMap) - protected lazy val pathGlobFilter: Option[GlobFilter] = - caseInsensitiveMap.get("pathGlobFilter").map(new GlobFilter(_)) - - protected def matchGlobPattern(file: FileStatus): Boolean = { - pathGlobFilter.forall(_.accept(file.getPath)) - } + protected def matchPathPattern(file: FileStatus): Boolean = + pathFilters.forall(_.accept(file)) protected lazy val recursiveFileLookup: Boolean = { caseInsensitiveMap.getOrElse("recursiveFileLookup", "false").toBoolean @@ -86,7 +83,7 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchGlobPattern(f) && isNonEmptyFile(f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) case None => // Directory does not exist, or has no children files @@ -135,7 +132,7 @@ abstract class PartitioningAwareFileIndex( } else { leafFiles.values.toSeq } - files.filter(matchGlobPattern) + files.filter(matchPathPattern) } protected def inferPartitioning(): PartitionSpec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 796c23c7337d8..ea437d200eaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import 
org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter} @@ -357,30 +357,6 @@ object PartitioningUtils { getPathFragment(spec, StructType.fromAttributes(partitionColumns)) } - /** - * Normalize the column names in partition specification, w.r.t. the real partition column names - * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a - * partition column named `month`, and it's case insensitive, we will normalize `monTh` to - * `month`. - */ - def normalizePartitionSpec[T]( - partitionSpec: Map[String, T], - partColNames: Seq[String], - tblName: String, - resolver: Resolver): Map[String, T] = { - val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => - val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { - throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") - } - normalizedKey -> value - } - - SchemaUtils.checkColumnNameDuplication( - normalizedPartSpec.map(_._1), "in the partition schema", resolver) - - normalizedPartSpec.toMap - } - /** * Resolves possible type conflicts between partitions by up-casting "lower" types using * [[findWiderTypeForPartitionColumn]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala new file mode 100644 index 0000000000000..c8f23988f93c6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Locale, TimeZone} + +import org.apache.hadoop.fs.{FileStatus, GlobFilter} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String + +trait PathFilterStrategy extends Serializable { + def accept(fileStatus: FileStatus): Boolean +} + +trait StrategyBuilder { + def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] +} + +class PathGlobFilter(filePatten: String) extends PathFilterStrategy { + + private val globFilter = new GlobFilter(filePatten) + + override def accept(fileStatus: FileStatus): Boolean = + globFilter.accept(fileStatus.getPath) +} + +object PathGlobFilter extends StrategyBuilder { + val PARAM_NAME = "pathglobfilter" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map(new PathGlobFilter(_)) + } +} + +/** + * Provide modifiedAfter and modifiedBefore options when + * filtering from a batch-based file data source. + * + * Example Usages + * Load all CSV files modified after date: + * {{{ + * spark.read.format("csv").option("modifiedAfter","2020-06-15T05:00:00").load() + * }}} + * + * Load all CSV files modified before date: + * {{{ + * spark.read.format("csv").option("modifiedBefore","2020-06-15T05:00:00").load() + * }}} + * + * Load all CSV files modified between two dates: + * {{{ + * spark.read.format("csv").option("modifiedAfter","2019-01-15T05:00:00") + * .option("modifiedBefore","2020-06-15T05:00:00").load() + * }}} + */ +abstract class ModifiedDateFilter extends PathFilterStrategy { + + def timeZoneId: String + + protected def localTime(micros: Long): Long = + DateTimeUtils.fromUTCTime(micros, timeZoneId) +} + +object ModifiedDateFilter { + + def getTimeZoneId(options: CaseInsensitiveMap[String]): String = { + options.getOrElse( + DateTimeUtils.TIMEZONE_OPTION.toLowerCase(Locale.ROOT), + SQLConf.get.sessionLocalTimeZone) + } + + def toThreshold(timeString: String, timeZoneId: String, strategy: String): Long = { + val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId) + val ts = UTF8String.fromString(timeString) + DateTimeUtils.stringToTimestamp(ts, timeZone.toZoneId).getOrElse { + throw new AnalysisException( + s"The timestamp provided for the '$strategy' option is invalid. The expected format " + + s"is 'YYYY-MM-DDTHH:mm:ss', but the provided timestamp: $timeString") + } + } +} + +/** + * Filter used to determine whether file was modified before the provided timestamp. 
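For orientation, a sketch of how these strategies compose (option values are illustrative): `PathFilterFactory.create` returns one strategy per recognized option, and a listed file is kept only if every strategy accepts it, which is what `matchPathPattern` in `PartitioningAwareFileIndex` now checks.

  import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
  import org.apache.spark.sql.execution.datasources.PathFilterFactory

  val filters = PathFilterFactory.create(CaseInsensitiveMap(Map(
    "pathGlobFilter" -> "*.json",
    "modifiedAfter" -> "2020-06-01T13:00:00")))
  // a FileStatus `f` survives listing only when filters.forall(_.accept(f))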
+ */ +class ModifiedBeforeFilter(thresholdTime: Long, val timeZoneId: String) + extends ModifiedDateFilter { + + override def accept(fileStatus: FileStatus): Boolean = + // We standardize on microseconds wherever possible + // getModificationTime returns in milliseconds + thresholdTime - localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) > 0 +} + +object ModifiedBeforeFilter extends StrategyBuilder { + import ModifiedDateFilter._ + + val PARAM_NAME = "modifiedbefore" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map { value => + val timeZoneId = getTimeZoneId(parameters) + val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + new ModifiedBeforeFilter(thresholdTime, timeZoneId) + } + } +} + +/** + * Filter used to determine whether file was modified after the provided timestamp. + */ +class ModifiedAfterFilter(thresholdTime: Long, val timeZoneId: String) + extends ModifiedDateFilter { + + override def accept(fileStatus: FileStatus): Boolean = + // getModificationTime returns in milliseconds + // We standardize on microseconds wherever possible + localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) - thresholdTime > 0 +} + +object ModifiedAfterFilter extends StrategyBuilder { + import ModifiedDateFilter._ + + val PARAM_NAME = "modifiedafter" + + override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { + parameters.get(PARAM_NAME).map { value => + val timeZoneId = getTimeZoneId(parameters) + val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + new ModifiedAfterFilter(thresholdTime, timeZoneId) + } + } +} + +object PathFilterFactory { + + private val strategies = + Seq(PathGlobFilter, ModifiedBeforeFilter, ModifiedAfterFilter) + + def create(parameters: CaseInsensitiveMap[String]): Seq[PathFilterStrategy] = { + strategies.flatMap { _.create(parameters) } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3a2a642b870f8..9e65b0ce13693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec import org.apache.spark.sql.util.SchemaUtils /** @@ -386,7 +387,7 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { partColNames: Seq[String], catalogTable: Option[CatalogTable]): InsertIntoStatement = { - val normalizedPartSpec = PartitioningUtils.normalizePartitionSpec( + val normalizedPartSpec = normalizePartitionSpec( insert.partitionSpec, partColNames, tblName, conf.resolver) val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 21abfc2816ee4..e5c29312b80e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,6 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat catalog match { case staging: StagingTableCatalog => AtomicReplaceTableAsSelectExec( + session, staging, ident, parts, @@ -157,6 +158,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat orCreate = orCreate) :: Nil case _ => ReplaceTableAsSelectExec( + session, catalog, ident, parts, @@ -170,9 +172,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil case v2 => - AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil + AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil } case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => @@ -184,14 +186,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil + OverwriteByExpressionExec(session, v2, r, filters, + writeOptions.asOptions, planLater(query)) :: Nil } case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => OverwritePartitionsDynamicExec( - r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil + session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(relation, condition) => relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 560da39314b36..af7721588edeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -37,10 +37,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AppendDataExecV1( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { override protected def run(): Seq[InternalRow] = { - writeWithV1(newWriteBuilder().buildForV1Write()) + writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation)) } } @@ -59,7 +60,8 @@ case class OverwriteByExpressionExecV1( table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] @@ -68,10 +70,10 @@ case class OverwriteByExpressionExecV1( override protected def run(): Seq[InternalRow] = { newWriteBuilder() match { case builder: SupportsTruncate if 
isTruncate(deleteWhere) => - writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) + writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation)) case builder: SupportsOverwrite => - writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) + writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation)) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") @@ -112,9 +114,14 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { trait SupportsV1Write extends SparkPlan { def plan: LogicalPlan - protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = { + protected def writeWithV1( + relation: InsertableRelation, + v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = { + val session = sqlContext.sparkSession // The `plan` is already optimized, we should not analyze and optimize it again. - relation.insert(AlreadyOptimized.dataFrame(sqlContext.sparkSession, plan), overwrite = false) + relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false) + v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r)) + Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1421a9315c3a8..1648134d0a1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -127,6 +128,7 @@ case class AtomicCreateTableAsSelectExec( * ReplaceTableAsSelectStagingExec. */ case class ReplaceTableAsSelectExec( + session: SparkSession, catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -146,6 +148,8 @@ case class ReplaceTableAsSelectExec( // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) catalog.dropTable(ident) } else if (!orCreate) { throw new CannotReplaceMissingTableException(ident) @@ -169,6 +173,7 @@ case class ReplaceTableAsSelectExec( * is left untouched. 
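A behavioral sketch of the uncache step added to the replace-table paths (assumes a configured v2 catalog named `testcat` and the `foo` provider used by the suite below): a cached view built on the old table is invalidated instead of serving stale data.

  spark.sql("CACHE TABLE v AS SELECT id FROM testcat.ns.t")
  spark.sql("REPLACE TABLE testcat.ns.t USING foo AS SELECT id FROM source")
  // the cache entry that referenced the old testcat.ns.t is dropped rather than left dangling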
*/ case class AtomicReplaceTableAsSelectExec( + session: SparkSession, catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -180,6 +185,10 @@ case class AtomicReplaceTableAsSelectExec( override protected def run(): Seq[InternalRow] = { val schema = query.schema.asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) + } val staged = if (orCreate) { catalog.stageCreateOrReplace( ident, schema, partitioning.toArray, properties.asJava) @@ -204,12 +213,16 @@ case class AtomicReplaceTableAsSelectExec( * Rows in the output data set are appended. */ case class AppendDataExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - writeWithV2(newWriteBuilder().buildForBatch()) + val writtenRows = writeWithV2(newWriteBuilder().buildForBatch()) + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -224,7 +237,9 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -234,7 +249,7 @@ case class OverwriteByExpressionExec( } override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => writeWithV2(builder.truncate().buildForBatch()) @@ -244,9 +259,12 @@ case class OverwriteByExpressionExec( case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } + /** * Physical plan node for dynamic partition overwrite into a v2 table. * @@ -257,18 +275,22 @@ case class OverwriteByExpressionExec( * are not modified. */ case class OverwritePartitionsDynamicExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsDynamicOverwrite => writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -370,6 +392,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { Nil } + + protected def uncacheTable( + session: SparkSession, + catalog: TableCatalog, + table: Table, + ident: Identifier): Unit = { + val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true) + } } object DataWritingSparkTask extends Logging { @@ -484,3 +515,4 @@ private[v2] case class DataWritingSparkTaskResult( * Sink progress information collected after commit. 
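Similarly, a sketch of what the recacheByPlan calls above mean for users (again assuming the `testcat` v2 catalog): appends and overwrites through the v2 write path now refresh caches that reference the target table.

  import spark.implicits._
  spark.sql("CACHE TABLE v AS SELECT i FROM testcat.ns.t")
  Seq((2, "b")).toDF("i", "j").writeTo("testcat.ns.t").append()
  spark.table("v").show()   // reflects the appended row instead of stale cached data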
*/ private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 712ed1585bc8a..6f43542fd6595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -23,6 +23,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, ModifiedBeforeFilter} import org.apache.spark.util.Utils /** @@ -32,6 +33,16 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + checkDisallowedOptions(parameters) + + private def checkDisallowedOptions(options: Map[String, String]): Unit = { + Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param => + if (parameters.contains(param)) { + throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources") + } + } + } + val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => Try(str.toInt).toOption.filter(_ > 0).getOrElse { throw new IllegalArgumentException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 77078046dda7c..f48672afb41f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -189,8 +189,8 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) val graphUIDataForNumRowsDroppedByWatermark = new GraphUIData( - "aggregated-num-state-rows-dropped-by-watermark-timeline", - "aggregated-num-state-rows-dropped-by-watermark-histogram", + "aggregated-num-rows-dropped-by-watermark-timeline", + "aggregated-num-rows-dropped-by-watermark-histogram", numRowsDroppedByWatermarkData, minBatchTime, maxBatchTime, @@ -209,33 +209,33 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) {graphUIDataForNumberTotalRows.generateTimelineHtml(jsCollector)} {graphUIDataForNumberTotalRows.generateHistogramHtml(jsCollector)} - - -
-          Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
-          {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)}
-          {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)}
-          Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
-          {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)}
-          {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)}
-          Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
-          {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)}
-          {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)}
+          Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
+          {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)}
+          {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)}
+          Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
+          {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)}
+          {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)}
+          Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
    + + {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} + // scalastyle:on } else { new NodeBuffer() diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index da83df4994d8d..0a54dff3a1cea 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 341 + - Number of queries: 342 - Number of expressions that missing example: 13 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window ## Schema of Built-in Functions @@ -86,6 +86,7 @@ | org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | +| org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimestamp | current_timestamp | SELECT current_timestamp() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6a1378837ea9b..953a58760cd5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1947,6 +1947,14 @@ class DatasetSuite extends QueryTest df.where($"zoo".contains(Array('a', 'b'))), Seq(Row("abc"))) } + + test("SPARK-33469: Add current_timezone function") { + val df = Seq(1).toDF("c") + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Shanghai") { + val timezone = df.selectExpr("current_timezone()").collect().head.getString(0) + assert(timezone == "Asia/Shanghai") + } + } } object AssertExecutionId { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b27c1145181bd..876f62803dc7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -577,38 +577,6 @@ class FileBasedDataSourceSuite extends QueryTest } } - test("Option pathGlobFilter: filter files correctly") { - withTempPath { path => - val dataDir = path.getCanonicalPath - Seq("foo").toDS().write.text(dataDir) - Seq("bar").toDS().write.mode("append").orc(dataDir) - val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir) - checkAnswer(df, Row("foo")) - - // Both glob pattern in option and path should be effective to filter files. 
- val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc") - checkAnswer(df2, Seq.empty) - - val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt") - checkAnswer(df3, Row("foo")) - } - } - - test("Option pathGlobFilter: simple extension filtering should contains partition info") { - withTempPath { path => - val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b") - input.write.partitionBy("b").text(path.getCanonicalPath) - Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1") - - // If we use glob pattern in the path, the partition column won't be shown in the result. - val df = spark.read.text(path.getCanonicalPath + "/*/*.txt") - checkAnswer(df, input.select("a")) - - val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath) - checkAnswer(df2, input) - } - } - test("Option recursiveFileLookup: recursive loading correctly") { val expectedFileList = mutable.ListBuffer[String]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala index 107d0ea47249d..e05c2c09ace2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionsException, PartitionsAlreadyExistException} import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits +import org.apache.spark.sql.internal.SQLConf class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { @@ -159,4 +160,29 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { assert(partTable.asPartitionable.listPartitionIdentifiers(InternalRow.empty).isEmpty) } } + + test("case sensitivity in resolving partition specs") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg = intercept[AnalysisException] { + spark.sql(s"ALTER TABLE $t ADD PARTITION (ID=1) LOCATION 'loc1'") + }.getMessage + assert(errMsg.contains(s"ID is not a valid partition column in table $t")) + } + + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) + .asPartitionable + assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + spark.sql(s"ALTER TABLE $t ADD PARTITION (ID=1) LOCATION 'loc1'") + assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + spark.sql(s"ALTER TABLE $t DROP PARTITION (Id=1)") + assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ddafa1bb5070a..da53936239de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import 
org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class DataSourceV2SQLSuite @@ -43,7 +44,6 @@ class DataSourceV2SQLSuite with AlterTableTests with DatasourceV2SQLBase { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ private val v2Source = classOf[FakeV2Provider].getName override protected val v2Format = v2Source @@ -782,6 +782,84 @@ class DataSourceV2SQLSuite } } + test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") { + Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t => + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"REPLACE TABLE $t USING foo AS SELECT id FROM source") + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty) + } + } + } + } + + test("SPARK-33492: AppendData should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a")).toDF("i", "j").write.saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t) + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil) + } + } + } + + test("SPARK-33492: OverwriteByExpression should refresh cache") { + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + } + } + } + + test("SPARK-33492: OverwritePartitionsDynamic should refresh cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + Seq((1, "a", 1)).toDF("i", "j", "k").write.partitionBy("k").saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions() + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil) + } + } + } + test("Relation: basic") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { @@ -1980,57 +2058,6 @@ class DataSourceV2SQLSuite } } - test("ALTER TABLE
RECOVER PARTITIONS") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t RECOVER PARTITIONS") - } - assert(e.message.contains("ALTER TABLE RECOVER PARTITIONS is only supported with v1 tables")) - } - } - - test("ALTER TABLE ADD PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - - val partTable = catalog("testpart").asTableCatalog - .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] - assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1)))) - - val partMetadata = partTable.loadPartitionMetadata(InternalRow.fromSeq(Seq(1))) - assert(partMetadata.containsKey("location")) - assert(partMetadata.get("location") == "loc") - } - } - - test("ALTER TABLE RENAME PARTITION") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - val e = intercept[AnalysisException] { - sql(s"ALTER TABLE $t PARTITION (id=1) RENAME TO PARTITION (id=2)") - } - assert(e.message.contains("ALTER TABLE RENAME PARTITION is only supported with v1 tables")) - } - } - - test("ALTER TABLE DROP PARTITION") { - val t = "testpart.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)") - spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'") - spark.sql(s"ALTER TABLE $t DROP PARTITION (id=1)") - - val partTable = - catalog("testpart").asTableCatalog.loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) - assert(!partTable.asPartitionable.partitionExists(InternalRow.fromSeq(Seq(1)))) - } - } - test("ALTER TABLE SerDe properties") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -2387,7 +2414,8 @@ class DataSourceV2SQLSuite withTempView("v") { sql("create global temp view v as select 1") val e = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL")) - assert(e.getMessage.contains("global_temp.v is a temp view not table.")) + assert(e.getMessage.contains( + "global_temp.v is a temp view. 
'COMMENT ON TABLE' expects a table")) } } @@ -2513,6 +2541,25 @@ class DataSourceV2SQLSuite } } + test("SPARK-33505: insert into partitioned table") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + sql(s""" + |CREATE TABLE $t (id bigint, city string, data string) + |USING foo + |PARTITIONED BY (id, city)""".stripMargin) + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable] + val expectedPartitionIdent = InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY"))) + assert(!partTable.partitionExists(expectedPartitionIdent)) + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'") + assert(partTable.partitionExists(expectedPartitionIdent)) + // Insert into the existing partition must not fail + sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'def'") + assert(partTable.partitionExists(expectedPartitionIdent)) + } + } + private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 4b52a4cbf4116..cba7dd35fb3bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -24,14 +24,17 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ @@ -145,6 +148,52 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before SparkSession.setDefaultSession(spark) } } + + test("SPARK-33492: append fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").append() + checkAnswer(session.read.table("test"), Row(1, "x") :: Row(2, "y") 
:: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } + + test("SPARK-33492: overwrite fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").overwrite(lit(true)) + checkAnswer(session.read.table("test"), Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } } class V1WriteFallbackSessionCatalogSuite @@ -177,6 +226,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) + tables.put(Identifier.of(Array("default"), name), t) t } } @@ -272,7 +322,7 @@ class InMemoryTableWithV1Fallback( override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table - with SupportsWrite { + with SupportsWrite with SupportsRead { partitioning.foreach { t => if (!t.isInstanceOf[IdentityTransform]) { @@ -281,6 +331,7 @@ class InMemoryTableWithV1Fallback( } override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, TableCapability.TRUNCATE).asJava @@ -338,6 +389,30 @@ class InMemoryTableWithV1Fallback( } } } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new V1ReadFallbackScanBuilder(schema) + + private class V1ReadFallbackScanBuilder(schema: StructType) extends ScanBuilder { + override def build(): Scan = new V1ReadFallbackScan(schema) + } + + private class V1ReadFallbackScan(schema: StructType) extends V1Scan { + override def readSchema(): StructType = schema + override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = + new V1TableScan(context, schema).asInstanceOf[T] + } + + private class V1TableScan( + context: SQLContext, + requiredSchema: StructType) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = context + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect() + context.sparkContext.makeRDD(data) + } + } } /** A rule that fails if a query plan is analyzed twice. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 792f920ee0217..504cc57dc12d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -147,10 +147,10 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { s"'$viewName' is a view not a table") assertAnalysisError( s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')", - s"$viewName is a temp view not table") + s"$viewName is a temp view. 'ALTER TABLE ... ADD PARTITION ...' expects a table") assertAnalysisError( s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')", - s"$viewName is a temp view not table") + s"$viewName is a temp view. 'ALTER TABLE ... DROP PARTITION ...' expects a table") // For the following v2 ALERT TABLE statements, unsupported operations are checked first // before resolving the relations. @@ -175,7 +175,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { val e2 = intercept[AnalysisException] { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage - assert(e2.contains(s"$viewName is a temp view not table")) + assert(e2.contains(s"$viewName is a temp view. 'LOAD DATA' expects a table")) assertNoSuchTable(s"TRUNCATE TABLE $viewName") val e3 = intercept[AnalysisException] { sql(s"SHOW CREATE TABLE $viewName") @@ -214,7 +214,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { e = intercept[AnalysisException] { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage - assert(e.contains("default.testView is a view not table")) + assert(e.contains("default.testView is a view. 
'LOAD DATA' expects a table")) e = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $viewName") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9d0147048dbb8..43a33860d262e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -3104,84 +3104,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(spark.sessionState.catalog.isRegisteredFunction(rand)) } } - - test("SPARK-32481 Move data to trash on truncate table if enabled") { - val trashIntervalKey = "fs.trash.interval" - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - // scalastyle:off hadoopconfiguration - val hadoopConf = spark.sparkContext.hadoopConfiguration - // scalastyle:on hadoopconfiguration - val originalValue = hadoopConf.get(trashIntervalKey, "0") - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - assume( - fs.mkdirs(trashPath) && fs.delete(trashPath, false), - "Trash directory could not be created, skipping.") - assert(!fs.exists(trashPath)) - try { - hadoopConf.set(trashIntervalKey, "5") - sql("TRUNCATE TABLE tab1") - } finally { - hadoopConf.set(trashIntervalKey, originalValue) - } - assert(fs.exists(trashPath)) - fs.delete(trashPath, true) - } - } - } - - test("SPARK-32481 delete data permanently on truncate table if trash interval is non-positive") { - val trashIntervalKey = "fs.trash.interval" - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - // scalastyle:off hadoopconfiguration - val hadoopConf = spark.sparkContext.hadoopConfiguration - // scalastyle:on hadoopconfiguration - val originalValue = hadoopConf.get(trashIntervalKey, "0") - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - assert(!fs.exists(trashPath)) - try { - hadoopConf.set(trashIntervalKey, "0") - sql("TRUNCATE TABLE tab1") - } finally { - hadoopConf.set(trashIntervalKey, originalValue) - } - assert(!fs.exists(trashPath)) - } - } - } - - test("SPARK-32481 Do not move data to trash on truncate table if disabled") { - withTable("tab1") { - withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "false") { - sql("CREATE TABLE tab1 (col INT) USING parquet") - sql("INSERT INTO tab1 SELECT 1") - val hadoopConf = spark.sessionState.newHadoopConf() - val tablePath = new Path(spark.sessionState.catalog - .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get) - - val fs = tablePath.getFileSystem(hadoopConf) - val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current") - val trashPath = Path.mergePaths(trashCurrent, tablePath) - sql("TRUNCATE TABLE tab1") - assert(!fs.exists(trashPath)) - } - } - } } object FakeLocalFsFileSystem { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala new file mode 100644 index 0000000000000..b965a78c9eec0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.test.SharedSparkSession + +class PathFilterStrategySuite extends QueryTest with SharedSparkSession { + + test("SPARK-31962: PathFilterStrategies - modifiedAfter option") { + val options = + CaseInsensitiveMap[String](Map("modifiedAfter" -> "2010-10-01T01:01:00")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[ModifiedAfterFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - modifiedBefore option") { + val options = + CaseInsensitiveMap[String](Map("modifiedBefore" -> "2020-10-01T01:01:00")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[ModifiedBeforeFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - pathGlobFilter option") { + val options = CaseInsensitiveMap[String](Map("pathGlobFilter" -> "*.txt")) + val strategy = PathFilterFactory.create(options) + assert(strategy.head.isInstanceOf[PathGlobFilter]) + assert(strategy.size == 1) + } + + test("SPARK-31962: PathFilterStrategies - no options") { + val options = CaseInsensitiveMap[String](Map.empty) + val strategy = PathFilterFactory.create(options) + assert(strategy.isEmpty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala new file mode 100644 index 0000000000000..1af2adfd8640c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.time.{LocalDateTime, ZoneId, ZoneOffset} +import java.time.format.DateTimeFormatter + +import scala.util.Random + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.util.{stringToFile, DateTimeUtils} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class PathFilterSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("SPARK-31962: modifiedBefore specified" + + " and sharing same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formatTime(curTime))) + } + } + + test("SPARK-31962: modifiedAfter specified" + + " and sharing same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + executeTest(dir, Seq(curTime), 0, modifiedAfter = Some(formatTime(curTime))) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with earlier file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val fileTime = curTime.minusDays(3) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(fileTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore and modifiedAfter option" + + " share same timestamp with later file last modified time.") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime), + modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: when modifiedAfter specified with a past date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(curTime), 1, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: when modifiedBefore specified with a future date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val futureTime = curTime.plusYears(1) + val formattedTime = formatTime(futureTime) + executeTest(dir, Seq(curTime), 1, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: with modifiedBefore option provided using a past date") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + 
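For context while reading these cases: modifiedBefore and modifiedAfter are ordinary DataFrameReader options, taking timestamps in the yyyy-MM-dd'T'HH:mm:ss form that the suite's formatTime helper produces, interpreted in the session time zone (or the timeZone option when given). A rough usage sketch, with made-up paths and timestamps:

    // Hypothetical illustration of the new options; /data/logs is a placeholder path.
    val recent = spark.read
      .format("csv")
      .option("timeZone", "UTC")
      .option("modifiedAfter", "2020-06-01T08:30:00")   // keep only files modified after this instant
      .option("pathGlobFilter", "*.csv")                // composes with the existing glob filter
      .schema("a STRING")
      .load("/data/logs")

    val older = spark.read
      .format("csv")
      .option("modifiedBefore", "2020-06-01T08:30:00")  // keep only files modified before it
      .schema("a STRING")
      .load("/data/logs")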
executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, one valid") { + withTempDir { dir => + val fileTime1 = LocalDateTime.now(ZoneOffset.UTC) + val fileTime2 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val pastTime = fileTime1.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(fileTime1, fileTime2), 1, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, both valid") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val pastTime = curTime.minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(curTime, curTime), 2, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedAfter specified with a past date, multiple files, none valid") { + withTempDir { dir => + val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val pastTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1) + val formattedTime = formatTime(pastTime) + executeTest(dir, Seq(fileTime, fileTime), 0, modifiedAfter = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, both valid") { + withTempDir { dir => + val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1) + val formattedTime = formatTime(futureTime) + executeTest(dir, Seq(fileTime, fileTime), 2, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, one valid") { + withTempDir { dir => + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val fileTime1 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC) + val fileTime2 = curTime.plusDays(3) + val formattedTime = formatTime(curTime) + executeTest(dir, Seq(fileTime1, fileTime2), 1, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore specified with a future date, multiple files, none valid") { + withTempDir { dir => + val fileTime = LocalDateTime.now(ZoneOffset.UTC).minusDays(1) + val formattedTime = formatTime(fileTime) + executeTest(dir, Seq(fileTime, fileTime), 0, modifiedBefore = Some(formattedTime)) + } + } + + test("SPARK-31962: modifiedBefore/modifiedAfter is specified with an invalid date") { + executeTestWithBadOption( + Map("modifiedBefore" -> "2024-05+1 01:00:00"), + Seq("The timestamp provided", "modifiedbefore", "2024-05+1 01:00:00")) + + executeTestWithBadOption( + Map("modifiedAfter" -> "2024-05+1 01:00:00"), + Seq("The timestamp provided", "modifiedafter", "2024-05+1 01:00:00")) + } + + test("SPARK-31962: modifiedBefore/modifiedAfter - empty option") { + executeTestWithBadOption( + Map("modifiedBefore" -> ""), + Seq("The timestamp provided", "modifiedbefore")) + + executeTestWithBadOption( + Map("modifiedAfter" -> ""), + Seq("The timestamp provided", "modifiedafter")) + } + + test("SPARK-31962: modifiedBefore/modifiedAfter filter takes into account local timezone " + + "when specified as an option.") { + Seq("modifiedbefore", "modifiedafter").foreach { filterName => + // CET = UTC + 1 hour, HST = UTC - 10 hours + Seq("CET", "HST").foreach { tzId => + testModifiedDateFilterWithTimezone(tzId, filterName) + } + } + } + + test("Option pathGlobFilter: filter files correctly") { + withTempPath { path => + val dataDir = path.getCanonicalPath + Seq("foo").toDS().write.text(dataDir) + 
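Internally, as PathFilterStrategySuite above exercises, each recognised option is turned into one path-filter strategy by PathFilterFactory. A small sketch of driving the factory directly; that a combined option map yields one strategy per option is my reading of the factory, not something the suite itself asserts:

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
    import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, PathFilterFactory, PathGlobFilter}

    // Presumably one strategy per recognised option, and an empty result when none are set.
    val strategies = PathFilterFactory.create(
      CaseInsensitiveMap[String](Map("pathGlobFilter" -> "*.txt", "modifiedAfter" -> "2020-10-01T01:01:00")))
    assert(strategies.exists(_.isInstanceOf[PathGlobFilter]))
    assert(strategies.exists(_.isInstanceOf[ModifiedAfterFilter]))
    assert(PathFilterFactory.create(CaseInsensitiveMap[String](Map.empty)).isEmpty)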
Seq("bar").toDS().write.mode("append").orc(dataDir) + val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir) + checkAnswer(df, Row("foo")) + + // Both glob pattern in option and path should be effective to filter files. + val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc") + checkAnswer(df2, Seq.empty) + + val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt") + checkAnswer(df3, Row("foo")) + } + } + + test("Option pathGlobFilter: simple extension filtering should contains partition info") { + withTempPath { path => + val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b") + input.write.partitionBy("b").text(path.getCanonicalPath) + Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1") + + // If we use glob pattern in the path, the partition column won't be shown in the result. + val df = spark.read.text(path.getCanonicalPath + "/*/*.txt") + checkAnswer(df, input.select("a")) + + val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath) + checkAnswer(df2, input) + } + } + + private def executeTest( + dir: File, + fileDates: Seq[LocalDateTime], + expectedCount: Long, + modifiedBefore: Option[String] = None, + modifiedAfter: Option[String] = None): Unit = { + fileDates.foreach { fileDate => + val file = createSingleFile(dir) + setFileTime(fileDate, file) + } + + val schema = StructType(Seq(StructField("a", StringType))) + + var dfReader = spark.read.format("csv").option("timeZone", "UTC").schema(schema) + modifiedBefore.foreach { opt => dfReader = dfReader.option("modifiedBefore", opt) } + modifiedAfter.foreach { opt => dfReader = dfReader.option("modifiedAfter", opt) } + + if (expectedCount > 0) { + // without pathGlobFilter + val df1 = dfReader.load(dir.getCanonicalPath) + assert(df1.count() === expectedCount) + + // pathGlobFilter matched + val df2 = dfReader.option("pathGlobFilter", "*.csv").load(dir.getCanonicalPath) + assert(df2.count() === expectedCount) + + // pathGlobFilter mismatched + val df3 = dfReader.option("pathGlobFilter", "*.txt").load(dir.getCanonicalPath) + assert(df3.count() === 0) + } else { + val df = dfReader.load(dir.getCanonicalPath) + assert(df.count() === 0) + } + } + + private def executeTestWithBadOption( + options: Map[String, String], + expectedMsgParts: Seq[String]): Unit = { + withTempDir { dir => + createSingleFile(dir) + val exc = intercept[AnalysisException] { + var dfReader = spark.read.format("csv") + options.foreach { case (key, value) => + dfReader = dfReader.option(key, value) + } + dfReader.load(dir.getCanonicalPath) + } + expectedMsgParts.foreach { msg => assert(exc.getMessage.contains(msg)) } + } + } + + private def testModifiedDateFilterWithTimezone( + timezoneId: String, + filterParamName: String): Unit = { + val curTime = LocalDateTime.now(ZoneOffset.UTC) + val zoneId: ZoneId = DateTimeUtils.getTimeZone(timezoneId).toZoneId + val strategyTimeInMicros = + ModifiedDateFilter.toThreshold( + curTime.toString, + timezoneId, + filterParamName) + val strategyTimeInSeconds = strategyTimeInMicros / 1000 / 1000 + + val curTimeAsSeconds = curTime.atZone(zoneId).toEpochSecond + withClue(s"timezone: $timezoneId / param: $filterParamName,") { + assert(strategyTimeInSeconds === curTimeAsSeconds) + } + } + + private def createSingleFile(dir: File): File = { + val file = new File(dir, "temp" + Random.nextInt(1000000) + ".csv") + stringToFile(file, "text") + } + + private def setFileTime(time: LocalDateTime, file: File): Boolean = { + val sameTime = 
time.toEpochSecond(ZoneOffset.UTC) + file.setLastModified(sameTime * 1000) + } + + private def formatTime(time: LocalDateTime): String = { + time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 9f62ff8301ebc..6085c1f2cccb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -149,6 +149,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { "org.apache.spark.sql.catalyst.expressions.UnixTimestamp", "org.apache.spark.sql.catalyst.expressions.CurrentDate", "org.apache.spark.sql.catalyst.expressions.CurrentTimestamp", + "org.apache.spark.sql.catalyst.expressions.CurrentTimeZone", "org.apache.spark.sql.catalyst.expressions.Now", // Random output without a seed "org.apache.spark.sql.catalyst.expressions.Rand", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index cf9664a9764be..718095003b096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming import java.io.File import java.net.URI +import java.time.{LocalDateTime, ZoneOffset} +import java.time.format.DateTimeFormatter import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable @@ -40,7 +42,6 @@ import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ import org.apache.spark.sql.types.{StructType, _} import org.apache.spark.util.Utils @@ -2054,6 +2055,47 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-31962: file stream source shouldn't allow modifiedBefore/modifiedAfter") { + def formatTime(time: LocalDateTime): String = { + time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss")) + } + + def assertOptionIsNotSupported(options: Map[String, String], path: String): Unit = { + val schema = StructType(Seq(StructField("a", StringType))) + var dsReader = spark.readStream + .format("csv") + .option("timeZone", "UTC") + .schema(schema) + + options.foreach { case (k, v) => dsReader = dsReader.option(k, v) } + + val df = dsReader.load(path) + + testStream(df)( + ExpectFailure[IllegalArgumentException]( + t => assert(t.getMessage.contains("is not allowed in file stream source")), + isFatalError = false) + ) + } + + withTempDir { dir => + // "modifiedBefore" + val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1) + val formattedFutureTime = formatTime(futureTime) + assertOptionIsNotSupported(Map("modifiedBefore" -> formattedFutureTime), dir.getCanonicalPath) + + // "modifiedAfter" + val prevTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1) + val formattedPrevTime = formatTime(prevTime) + assertOptionIsNotSupported(Map("modifiedAfter" -> formattedPrevTime), dir.getCanonicalPath) + + // both + assertOptionIsNotSupported( + Map("modifiedBefore" -> formattedFutureTime, "modifiedAfter" -> formattedPrevTime), + 
dir.getCanonicalPath) + } + } + private def createFile(content: String, src: File, tmp: File): File = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) val finalFile = new File(src, tempFile.getName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index 1a8b28001b8d1..307479db33949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -139,7 +139,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B summaryText should contain ("Aggregated Number Of Total State Rows (?)") summaryText should contain ("Aggregated Number Of Updated State Rows (?)") summaryText should contain ("Aggregated State Memory Used In Bytes (?)") - summaryText should contain ("Aggregated Number Of State Rows Dropped By Watermark (?)") + summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)") } } finally { spark.streams.active.foreach(_.stop()) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 2e9975bcabc3f..f7a4be9591818 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -63,6 +63,10 @@ private[hive] class SparkExecuteStatementOperation( } } + private val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) { + new VariableSubstitution().substitute(statement) + } + private var result: DataFrame = _ // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST. 
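The thriftserver change above precomputes the variable-substituted statement once; the next hunk then uses it so that jobs launched while fetching results are attributed to the statement's job group and show its SQL text in the UI. The shape of that wrapper is the standard SparkContext job-group pattern, roughly (sketch only, with sc and statementId standing in for sqlContext.sparkContext and the operation's id):

    // Attribute any jobs launched while serving rows to the statement's job group,
    // and always clear the group afterwards so later work is not mis-attributed.
    def withStatementJobGroup[T](sc: org.apache.spark.SparkContext,
        statementId: String, description: String)(body: => T): T = {
      try {
        sc.setJobGroup(statementId, description)
        body
      } finally {
        sc.clearJobGroup()
      }
    }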
@@ -126,6 +130,17 @@ private[hive] class SparkExecuteStatementOperation( } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withLocalProperties { + try { + sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) + getNextRowSetInternal(order, maxRowsL) + } finally { + sqlContext.sparkContext.clearJobGroup() + } + } + + private def getNextRowSetInternal( + order: FetchOrientation, + maxRowsL: Long): RowSet = withLocalProperties { log.info(s"Received getNextRowSet request order=${order} and maxRowsL=${maxRowsL} " + s"with ${statementId}") validateDefaultFetchOrientation(order) @@ -306,9 +321,6 @@ private[hive] class SparkExecuteStatementOperation( parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader) } - val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) { - new VariableSubstitution().substitute(statement) - } sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) result = sqlContext.sql(statement) logDebug(result.queryExecution.toString()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 38a8c492d77a7..cf070f4611f3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -52,7 +52,6 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { import HiveExternalCatalogVersionsSuite._ - private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9) private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `spark.test.cache-dir` to a static value like `/tmp/test-spark`, to @@ -60,6 +59,11 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val sparkTestingDir = Option(System.getProperty(SPARK_TEST_CACHE_DIR_SYSTEM_PROPERTY)) .map(new File(_)).getOrElse(Utils.createTempDir(namePrefix = "test-spark")) private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val hiveVersion = if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { + "2.3.7" + } else { + "1.2.1" + } override def afterAll(): Unit = { try { @@ -149,7 +153,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) } - private def prepare(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() + val tempPyFile = File.createTempFile("test", ".py") // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, @@ -199,7 +205,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--conf", s"spark.sql.test.version.index=$index", @@ -211,23 +217,14 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { tempPyFile.delete() } - override def beforeAll(): Unit = { - super.beforeAll() - if (!isTestAtLeastJava9) { - 
prepare() - } - } - test("backward compatibility") { - // TODO SPARK-28704 Test backward compatibility on JDK9+ once we have a version supports JDK9+ - assume(!isTestAtLeastJava9) val args = Seq( "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), "--name", "HiveExternalCatalog backward compatibility test", "--master", "local[2]", "--conf", s"${UI_ENABLED.key}=false", "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", @@ -252,7 +249,9 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // do not throw exception during object initialization. case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version } - versions.filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + versions + .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38()) + .filter(v => v.startsWith("3") || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) } protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 1f15bd685b239..56b871644453b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -904,10 +904,10 @@ class HiveDDLSuite assertAnalysisError( s"ALTER TABLE $oldViewName ADD IF NOT EXISTS PARTITION (a='4', b='8')", - s"$oldViewName is a view not table") + s"$oldViewName is a view. 'ALTER TABLE ... ADD PARTITION ...' expects a table.") assertAnalysisError( s"ALTER TABLE $oldViewName DROP IF EXISTS PARTITION (a='2')", - s"$oldViewName is a view not table") + s"$oldViewName is a view. 'ALTER TABLE ... DROP PARTITION ...' expects a table.") assert(catalog.tableExists(TableIdentifier(tabName))) assert(catalog.tableExists(TableIdentifier(oldViewName)))
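To close the loop on the reworked messages in SQLViewSuite and HiveDDLSuite: the new text names the command that required a table. A minimal sketch of triggering one of them (illustration only, assuming a plain local SparkSession named spark):

    // The rejected command is now spelled out in the error, e.g. for a temp view:
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW v AS SELECT 1 AS a, 2 AS b")
    try {
      spark.sql("ALTER TABLE v ADD IF NOT EXISTS PARTITION (a='4', b='8')")
    } catch {
      case e: org.apache.spark.sql.AnalysisException =>
        // expected to contain: "v is a temp view. 'ALTER TABLE ... ADD PARTITION ...' expects a table"
        println(e.getMessage)
    }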