diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ba8e4d69ba755..d21b9d9833e9e 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.storage.BlockManagerId
/**
* :: DeveloperApi ::
@@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, this)
+ /**
+ * Stores the locations of the external shuffle services chosen to handle the
+ * shuffle merge requests from mappers in this shuffle map stage.
+ */
+ private[spark] var mergerLocs: Seq[BlockManagerId] = Nil
+
+ def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
+ if (mergerLocs != null) {
+ this.mergerLocs = mergerLocs
+ }
+ }
+
+ def getMergerLocs: Seq[BlockManagerId] = mergerLocs
+
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
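
The accessors above are tiny, so a usage sketch may help. This is illustrative only and assumes spark-internal scope (the BlockManagerId factory is private[spark]); the host names and port 7337 are made up, and the "shuffle-push-merger" executor id matches the sentinel constant added to BlockManagerId later in this patch.

```scala
import org.apache.spark.ShuffleDependency
import org.apache.spark.storage.BlockManagerId

// Record the chosen merger locations on an existing shuffle dependency.
def attachMergers(dep: ShuffleDependency[_, _, _]): Unit = {
  val mergers = Seq(
    BlockManagerId("shuffle-push-merger", "host-1.example.com", 7337),
    BlockManagerId("shuffle-push-merger", "host-2.example.com", 7337))
  dep.setMergerLocs(mergers) // a null argument is silently ignored by the setter
  assert(dep.getMergerLocs.map(_.host) == Seq("host-1.example.com", "host-2.example.com"))
}
```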
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 136da80d48dee..f49cb3c2b8836 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -80,6 +80,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
private val conf = SparkEnv.get.conf
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+ protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val simplifiedTraceback: Boolean = false
@@ -139,6 +140,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (workerMemoryMb.isDefined) {
envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
}
+ envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 527d0d6d3a48d..33849f6fcb65f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -85,4 +85,8 @@ private[spark] object PythonUtils {
def getBroadcastThreshold(sc: JavaSparkContext): Long = {
sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD)
}
+
+ def getPythonAuthSocketTimeout(sc: JavaSparkContext): Long = {
+ sc.conf.get(org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 188d884319644..348a33e129d65 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -50,4 +50,10 @@ private[spark] object Python {
.version("2.4.0")
.bytesConf(ByteUnit.MiB)
.createOptional
+
+ val PYTHON_AUTH_SOCKET_TIMEOUT = ConfigBuilder("spark.python.authenticate.socketTimeout")
+ .internal()
+ .version("3.1.0")
+ .timeConf(TimeUnit.SECONDS)
+ .createWithDefaultString("15s")
}
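
A quick sketch of how the new entry resolves, assuming spark-internal scope (SparkConf.get(ConfigEntry) and the Python config object are both private[spark]). timeConf(TimeUnit.SECONDS) returns the value as a Long count of seconds, which PythonRunner above exports to workers as SPARK_AUTH_SOCKET_TIMEOUT.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT

val conf = new SparkConf()
assert(conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) == 15L)  // default "15s" resolves to 15 seconds

conf.set(PYTHON_AUTH_SOCKET_TIMEOUT.key, "30s")      // any Spark time string is accepted
assert(conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) == 30L)
```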
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 4bc49514fc5ad..b38d0e5c617b9 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1945,4 +1945,51 @@ package object config {
.version("3.0.1")
.booleanConf
.createWithDefault(false)
+
+ private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
+ ConfigBuilder("spark.shuffle.push.enabled")
+ .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +
+ "conjunction with the server side flag spark.shuffle.server.mergedShuffleFileManagerImpl " +
+ "which needs to be set with the appropriate " +
+ "org.apache.spark.network.shuffle.MergedShuffleFileManager implementation for push-based " +
+ "shuffle to be enabled")
+ .version("3.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
+ ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
+ .doc("Maximum number of shuffle push merger locations cached for push based shuffle. " +
+ "Currently, shuffle push merger locations are nothing but external shuffle services " +
+ "which are responsible for handling pushed blocks and merging them and serving " +
+ "merged blocks for later shuffle fetch.")
+ .version("3.1.0")
+ .intConf
+ .createWithDefault(500)
+
+ private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO =
+ ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio")
+ .doc("The minimum number of shuffle merger locations required to enable push based " +
+ "shuffle for a stage. This is specified as a ratio of the number of partitions in " +
+ "the child stage. For example, a reduce stage which has 100 partitions and uses the " +
+ "default value 0.05 requires at least 5 unique merger locations to enable push based " +
+ "shuffle. Merger locations are currently defined as external shuffle services.")
+ .version("3.1.0")
+ .doubleConf
+ .createWithDefault(0.05)
+
+ private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD =
+ ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold")
+ .doc(s"The static threshold for number of shuffle push merger locations should be " +
+ "available in order to enable push based shuffle for a stage. Note this config " +
+ s"works in conjunction with ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. " +
+ "Maximum of spark.shuffle.push.mergersMinStaticThreshold and " +
+ s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} ratio number of mergers needed to " +
+ "enable push based shuffle for a stage. For eg: with 1000 partitions for the child " +
+ "stage with spark.shuffle.push.mergersMinStaticThreshold as 5 and " +
+ s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " +
+ "at least 50 mergers to enable push based shuffle for that stage.")
+ .version("3.1.0")
+ .doubleConf
+ .createWithDefault(5)
}
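
The ratio and static-threshold entries are only described in prose; the helper below is a hypothetical restatement of that rule (it is not part of this patch), checked against the 1000-partition example from the doc text.

```scala
// required mergers = max(mergersMinStaticThreshold, mergersMinThresholdRatio * numPartitions)
def requiredMergers(numPartitions: Int, staticThreshold: Double, ratio: Double): Int =
  math.max(staticThreshold, ratio * numPartitions).ceil.toInt

// 1000 child-stage partitions, static threshold 5, ratio 0.05 => at least 50 mergers needed.
assert(requiredMergers(1000, 5, 0.05) == 50)
```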
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 13b766e654832..6fb0fb93f253b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -249,6 +249,8 @@ private[spark] class DAGScheduler(
private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
+ private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf)
+
/**
* Called by the TaskSetManager to report task's starting.
*/
@@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler(
execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores))
}
+ /**
+ * If push based shuffle is enabled, set the shuffle services to be used for the given
+ * shuffle map stage for block push/merge.
+ *
+ * Even with dynamic resource allocation kicking in and significantly reducing the number
+ * of available active executors, we would still be able to get sufficient shuffle service
+ * locations for block push/merge by getting the historical locations of past executors.
+ */
+ private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
+ // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize
+ // TODO changes we cannot disable shuffle merge for the retry/reuse cases
+ val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
+ stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+
+ if (mergerLocs.nonEmpty) {
+ stage.shuffleDep.setMergerLocs(mergerLocs)
+ logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
+ s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
+
+ logDebug("List of shuffle push merger locations " +
+ s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+ } else {
+ logInfo("No available merger locations." +
+ s" Push-based shuffle disabled for $stage (${stage.name})")
+ }
+ }
+
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
logDebug("submitMissingTasks(" + stage + ")")
@@ -1281,6 +1310,12 @@ private[spark] class DAGScheduler(
stage match {
case s: ShuffleMapStage =>
outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
+ // Only generate the merger locations for a given shuffle dependency once. This way,
+ // even if this stage gets retried, it would still be merging blocks using the same set
+ // of shuffle services.
+ if (pushBasedShuffleEnabled) {
+ prepareShuffleServicesForShuffleMapStage(s)
+ }
case s: ResultStage =>
outputCommitCoordinator.stageStart(
stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
@@ -2027,6 +2062,11 @@ private[spark] class DAGScheduler(
if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) {
executorFailureEpoch(execId) = currentEpoch
logInfo(s"Executor lost: $execId (epoch $currentEpoch)")
+ if (pushBasedShuffleEnabled) {
+ // Remove the fetch-failed host from the shuffle push merger list for push based shuffle
+ hostToUnregisterOutputs.foreach(
+ host => blockManagerMaster.removeShufflePushMergerLocation(host))
+ }
blockManagerMaster.removeExecutor(execId)
clearCacheLocs()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index a566d0a04387c..b2acdb3e12a6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerId
/**
* A backend interface for scheduling systems that allows plugging in different ones under
@@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend {
*/
def maxNumConcurrentTasks(rp: ResourceProfile): Int
+ /**
+ * Get the list of host locations for push based shuffle.
+ *
+ * Currently push based shuffle is disabled for both stage retry and stage reuse cases
+ * (e.g. when a few partitions are lost due to failure). Hence this method
+ * should be invoked only once per ShuffleDependency.
+ * @return List of external shuffle services locations
+ */
+ def getShufflePushMergerLocations(
+ numPartitions: Int,
+ resourceProfileId: Int): Seq[BlockManagerId] = Nil
+
}
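
The trait default returns Nil, so push-based shuffle stays off unless a cluster-manager backend overrides the hook. The override below is a hypothetical sketch (not part of this patch) that assumes the enclosing backend holds a blockManagerMaster reference; a real implementation would also apply the mergersMinThresholdRatio and mergersMinStaticThreshold configs before deciding whether enough mergers are available.

```scala
// Inside a hypothetical cluster-manager backend that extends SchedulerBackend.
override def getShufflePushMergerLocations(
    numPartitions: Int,
    resourceProfileId: Int): Seq[BlockManagerId] = {
  // Delegate to the BlockManagerMaster RPC added in this patch: request at most
  // numPartitions merger hosts and filter out none for this sketch.
  blockManagerMaster.getShufflePushMergerLocations(numPartitions, Set.empty)
}
```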
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index dbcb376905338..f800553c5388b 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -34,7 +34,7 @@ import org.apache.spark.util.Utils
*
* There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
*/
-private[spark] class SocketAuthHelper(conf: SparkConf) {
+private[spark] class SocketAuthHelper(val conf: SparkConf) {
val secret = Utils.createSecret(conf)
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 548fd1b07ddc5..35990b5a59281 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration
import scala.util.Try
import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -34,11 +36,11 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* handling one batch of data, with authentication and error handling.
*
* The socket server can only accept one connection, or close if no connection
- * in 15 seconds.
+ * within a configurable number of seconds (default 15).
*/
private[spark] abstract class SocketAuthServer[T](
authHelper: SocketAuthHelper,
- threadName: String) {
+ threadName: String) extends Logging {
def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
def this(threadName: String) = this(SparkEnv.get, threadName)
@@ -46,19 +48,26 @@ private[spark] abstract class SocketAuthServer[T](
private val promise = Promise[T]()
private def startServer(): (Int, String) = {
+ logTrace("Creating listening socket")
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
- // Close the socket if no connection in 15 seconds
- serverSocket.setSoTimeout(15000)
+ // Close the socket if there is no connection within the configured timeout
+ val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
+ logTrace(s"Setting timeout to $timeout sec")
+ serverSocket.setSoTimeout(timeout * 1000)
new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
+ logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}")
sock = serverSocket.accept()
+ logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}")
authHelper.authClient(sock)
+ logTrace("Client authenticated")
promise.complete(Try(handleConnection(sock)))
} finally {
+ logTrace("Closing server")
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
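
For context, a minimal hypothetical subclass (assuming spark-internal scope, since SocketAuthServer is private[spark]) showing where the configurable timeout applies: if no client connects within spark.python.authenticate.socketTimeout, accept() throws SocketTimeoutException and the sockets are closed in the finally block above.

```scala
import java.net.Socket

import org.apache.spark.SparkEnv
import org.apache.spark.security.SocketAuthServer

// Echoes a single byte back to an authenticated client, then returns.
class EchoServer(env: SparkEnv) extends SocketAuthServer[Unit](env, "echo-server") {
  override def handleConnection(sock: Socket): Unit = {
    val in = sock.getInputStream
    val out = sock.getOutputStream
    out.write(in.read()) // echo one byte
    out.flush()
  }
}
```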
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 49e32d04d450a..c6a4457d8f910 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -145,4 +145,6 @@ private[spark] object BlockManagerId {
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
blockManagerIdCache.get(id)
}
+
+ private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"
}
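
A small sketch (again spark-internal scope) of what a merger location looks like: a plain BlockManagerId whose executor id is the new sentinel and whose port is the external shuffle service port. 7337 is only the conventional default port, not something fixed by this patch.

```scala
import org.apache.spark.storage.BlockManagerId

val mergerId = BlockManagerId(
  BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "host-1.example.com", 7337)
assert(mergerId.executorId == "shuffle-push-merger")
assert(mergerId.host == "host-1.example.com" && mergerId.port == 7337)
```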
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f544d47b8e13c..fe1a5aef9499c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -125,6 +125,26 @@ class BlockManagerMaster(
driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ /**
+ * Get a list of unique shuffle service locations where an executor was successfully
+ * registered in the past, for block push/merge with push based shuffle.
+ */
+ def getShufflePushMergerLocations(
+ numMergersNeeded: Int,
+ hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+ driverEndpoint.askSync[Seq[BlockManagerId]](
+ GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
+ }
+
+ /**
+ * Remove the host from the candidate list of shuffle push mergers. This can be
+ * triggered if there is a FetchFailedException on the host.
+ * @param host the host to be removed from the candidate merger list
+ */
+ def removeShufflePushMergerLocation(host: String): Unit = {
+ driverEndpoint.askSync[Seq[BlockManagerId]](RemoveShufflePushMergerLocation(host))
+ }
+
def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
}
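
Driver-side usage of the two new RPCs, mirroring the BlockManagerSuite tests added below; master is an assumed BlockManagerMaster handle and "badhost" is a made-up host that recently saw fetch failures.

```scala
// Ask for up to 10 merger hosts, excluding a host with recent fetch failures.
val mergers: Seq[BlockManagerId] =
  master.getShufflePushMergerLocations(numMergersNeeded = 10, hostsToFilter = Set("badhost"))

// After a FetchFailedException on that host, drop it from the candidate pool.
master.removeShufflePushMergerLocation("badhost")
```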
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index a7532a9870fae..4d565511704d4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint(
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
+ // Mapping from host name to the shuffle (merger) service where the current app
+ // registered an executor in the past. Older hosts are removed when the
+ // maxRetainedMergerLocations size is reached, in favor of newer locations.
+ private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]()
+
+ // Maximum number of merger locations to cache
+ private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS)
+
private val askThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100)
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
@@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint(
val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf)
+ private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+
logInfo("BlockManagerMasterEndpoint up")
// same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)
// && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)`
@@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint(
case GetBlockStatus(blockId, askStorageEndpoints) =>
context.reply(blockStatus(blockId, askStorageEndpoints))
+ case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) =>
+ context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
+
+ case RemoveShufflePushMergerLocation(host) =>
+ context.reply(removeShufflePushMergerLocation(host))
+
case IsExecutorAlive(executorId) =>
context.reply(blockManagerIdByExecutor.contains(executorId))
@@ -360,6 +376,17 @@ class BlockManagerMasterEndpoint(
}
+ private def addMergerLocation(blockManagerId: BlockManagerId): Unit = {
+ if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) {
+ val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER,
+ blockManagerId.host, externalShuffleServicePort)
+ if (shuffleMergerLocations.size >= maxRetainedMergerLocations) {
+ shuffleMergerLocations -= shuffleMergerLocations.head._1
+ }
+ shuffleMergerLocations(shuffleServerId.host) = shuffleServerId
+ }
+ }
+
private def removeExecutor(execId: String): Unit = {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
@@ -526,6 +553,10 @@ class BlockManagerMasterEndpoint(
blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(),
maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus)
+
+ if (pushBasedShuffleEnabled) {
+ addMergerLocation(id)
+ }
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize,
Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
@@ -657,6 +688,40 @@ class BlockManagerMasterEndpoint(
}
}
+ private def getShufflePushMergerLocations(
+ numMergersNeeded: Int,
+ hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+ val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet
+ val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_))
+ val filteredMergersWithExecutors = filteredBlockManagerHosts.map(
+ BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort))
+ // Enough mergers are available as part of the active executors list
+ if (filteredMergersWithExecutors.size >= numMergersNeeded) {
+ filteredMergersWithExecutors.toSeq
+ } else {
+ // Add the remaining mergers needed from the inactive mergers list to the active list
+ val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host)
+ val filteredMergersWithoutExecutors = shuffleMergerLocations.values
+ .filterNot(x => hostsToFilter.contains(x.host))
+ .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host))
+ val randomFilteredMergersLocations =
+ if (filteredMergersWithoutExecutors.size >
+ numMergersNeeded - filteredMergersWithExecutors.size) {
+ Utils.randomize(filteredMergersWithoutExecutors)
+ .take(numMergersNeeded - filteredMergersWithExecutors.size)
+ } else {
+ filteredMergersWithoutExecutors
+ }
+ filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations
+ }
+ }
+
+ private def removeShufflePushMergerLocation(host: String): Unit = {
+ if (shuffleMergerLocations.contains(host)) {
+ shuffleMergerLocations.remove(host)
+ }
+ }
+
/**
* Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages.
*/
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index bbc076cea9ba8..afe416a55ed0d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages {
case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
+
+ case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String])
+ extends ToBlockManagerMaster
+
+ case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala
index a3a528cddee37..4af48d5b9125c 100644
--- a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala
@@ -136,12 +136,53 @@ private[spark] object HadoopFSUtils extends Logging {
parallelismMax = 0)
(path, leafFiles)
}.iterator
+ }.map { case (path, statuses) =>
+ val serializableStatuses = statuses.map { status =>
+ // Turn FileStatus into SerializableFileStatus so we can send it back to the driver
+ val blockLocations = status match {
+ case f: LocatedFileStatus =>
+ f.getBlockLocations.map { loc =>
+ SerializableBlockLocation(
+ loc.getNames,
+ loc.getHosts,
+ loc.getOffset,
+ loc.getLength)
+ }
+
+ case _ =>
+ Array.empty[SerializableBlockLocation]
+ }
+
+ SerializableFileStatus(
+ status.getPath.toString,
+ status.getLen,
+ status.isDirectory,
+ status.getReplication,
+ status.getBlockSize,
+ status.getModificationTime,
+ status.getAccessTime,
+ blockLocations)
+ }
+ (path.toString, serializableStatuses)
}.collect()
} finally {
sc.setJobDescription(previousJobDescription)
}
- statusMap.toSeq
+ // Turn SerializableFileStatus back into LocatedFileStatus
+ statusMap.map { case (path, serializableStatuses) =>
+ val statuses = serializableStatuses.map { f =>
+ val blockLocations = f.blockLocations.map { loc =>
+ new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length)
+ }
+ new LocatedFileStatus(
+ new FileStatus(
+ f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime,
+ new Path(f.path)),
+ blockLocations)
+ }
+ (new Path(path), statuses)
+ }
}
// scalastyle:off argcount
@@ -291,4 +332,22 @@ private[spark] object HadoopFSUtils extends Logging {
resolvedLeafStatuses
}
// scalastyle:on argcount
+
+ /** A serializable variant of HDFS's BlockLocation. This is required by Hadoop 2.7. */
+ private case class SerializableBlockLocation(
+ names: Array[String],
+ hosts: Array[String],
+ offset: Long,
+ length: Long)
+
+ /** A serializable variant of HDFS's FileStatus. This is required by Hadoop 2.7. */
+ private case class SerializableFileStatus(
+ path: String,
+ length: Long,
+ isDir: Boolean,
+ blockReplication: Short,
+ blockSize: Long,
+ modificationTime: Long,
+ accessTime: Long,
+ blockLocations: Array[SerializableBlockLocation])
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b743ab6507117..71a310a4279ad 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -50,7 +50,7 @@ import com.google.common.net.InetAddresses
import org.apache.commons.codec.binary.Hex
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, FileUtil, Path, Trash}
+import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
@@ -269,29 +269,6 @@ private[spark] object Utils extends Logging {
file.setExecutable(true, true)
}
- /**
- * Move data to trash if 'spark.sql.truncate.trash.enabled' is true, else
- * delete the data permanently. If move data to trash failed fallback to hard deletion.
- */
- def moveToTrashOrDelete(
- fs: FileSystem,
- partitionPath: Path,
- isTrashEnabled: Boolean,
- hadoopConf: Configuration): Boolean = {
- if (isTrashEnabled) {
- logDebug(s"Try to move data ${partitionPath.toString} to trash")
- val isSuccess = Trash.moveToAppropriateTrash(fs, partitionPath, hadoopConf)
- if (!isSuccess) {
- logWarning(s"Failed to move data ${partitionPath.toString} to trash. " +
- "Fallback to hard deletion")
- return fs.delete(partitionPath, true)
- }
- isSuccess
- } else {
- fs.delete(partitionPath, true)
- }
- }
-
/**
* Create a directory given the abstract pathname
* @return true, if the directory is successfully created; otherwise, return false.
@@ -2541,6 +2518,14 @@ private[spark] object Utils extends Logging {
master == "local" || master.startsWith("local[")
}
+ /**
+ * Push based shuffle can only be enabled when external shuffle service is enabled.
+ */
+ def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
+ conf.get(PUSH_BASED_SHUFFLE_ENABLED) &&
+ (conf.get(IS_TESTING).getOrElse(false) || conf.get(SHUFFLE_SERVICE_ENABLED))
+ }
+
/**
* Return whether dynamic allocation is enabled in the given conf.
*/
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 55280fc578310..144489c5f7922 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
.set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)
.set(Network.RPC_ASK_TIMEOUT, "5s")
+ .set(PUSH_BASED_SHUFFLE_ENABLED, true)
}
private def makeSortShuffleManager(): SortShuffleManager = {
@@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
}
}
+ test("SPARK-32919: Shuffle push merger locations should be bounded with in" +
+ " spark.shuffle.push.retainedMergerLocations") {
+ assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty)
+ makeBlockManager(100, "execA",
+ transferService = Some(new MockBlockTransferService(10, "hostA")))
+ makeBlockManager(100, "execB",
+ transferService = Some(new MockBlockTransferService(10, "hostB")))
+ makeBlockManager(100, "execC",
+ transferService = Some(new MockBlockTransferService(10, "hostC")))
+ makeBlockManager(100, "execD",
+ transferService = Some(new MockBlockTransferService(10, "hostD")))
+ makeBlockManager(100, "execE",
+ transferService = Some(new MockBlockTransferService(10, "hostA")))
+ assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4)
+ assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted ===
+ Seq("hostC", "hostD", "hostA", "hostB").sorted)
+ assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3)
+ }
+
+ test("SPARK-32919: Prefer active executor locations for shuffle push mergers") {
+ makeBlockManager(100, "execA",
+ transferService = Some(new MockBlockTransferService(10, "hostA")))
+ makeBlockManager(100, "execB",
+ transferService = Some(new MockBlockTransferService(10, "hostB")))
+ makeBlockManager(100, "execC",
+ transferService = Some(new MockBlockTransferService(10, "hostC")))
+ makeBlockManager(100, "execD",
+ transferService = Some(new MockBlockTransferService(10, "hostD")))
+ makeBlockManager(100, "execE",
+ transferService = Some(new MockBlockTransferService(10, "hostA")))
+ assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4)
+
+ master.removeExecutor("execA")
+ master.removeExecutor("execE")
+
+ assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3)
+ assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted ===
+ Seq("hostC", "hostB", "hostD").sorted)
+ assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted ===
+ Seq("hostB", "hostA", "hostC", "hostD").sorted)
+ }
+
test("SPARK-33387 Support ordered shuffle block migration") {
val blocks: Seq[ShuffleBlockInfo] = Seq(
ShuffleBlockInfo(1, 0L),
@@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1)))
}
- class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
+ class MockBlockTransferService(
+ val maxFailures: Int,
+ override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService {
var numCalls = 0
var tempFileManager: DownloadFileManager = null
@@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
override def close(): Unit = {}
- override def hostName: String = { "MockBlockTransferServiceHost" }
-
override def port: Int = { 63332 }
override def uploadBlock(
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 20624c743bc22..8fb408041ca9d 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.util.io.ChunkedByteBufferInputStream
@@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
}.getMessage
assert(message.contains(expected))
}
+
+ test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" +
+ " and SHUFFLE_SERVICE_ENABLED are true") {
+ val conf = new SparkConf()
+ assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+ conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+ conf.set(IS_TESTING, false)
+ assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+ conf.set(SHUFFLE_SERVICE_ENABLED, true)
+ assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+ }
}
private class SimpleExtension
diff --git a/dev/check-license b/dev/check-license
index 0cc17ffe55c67..bd255954d6db4 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -67,7 +67,7 @@ mkdir -p "$FWDIR"/lib
exit 1
}
-mkdir target
+mkdir -p target
$java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/.rat-excludes -d "$FWDIR" > target/rat-results.txt
if [ $? -ne 0 ]; then
diff --git a/dev/mima b/dev/mima
index f324c5c00a45c..d214bb96e09a3 100755
--- a/dev/mima
+++ b/dev/mima
@@ -25,8 +25,8 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"}
-TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)"
-OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)"
+TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | grep jar | tail -n1)"
+OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | grep jar | tail -n1)"
rm -f .generated-mima*
diff --git a/docs/css/main.css b/docs/css/main.css
index 8168a46f9a437..8b279a157c2b6 100755
--- a/docs/css/main.css
+++ b/docs/css/main.css
@@ -162,6 +162,7 @@ body .container-wrapper {
margin-right: auto;
border-radius: 15px;
position: relative;
+ min-height: 100vh;
}
.title {
@@ -264,6 +265,7 @@ a:hover code {
max-width: 914px;
line-height: 1.6; /* Inspired by Github's wiki style */
padding-left: 30px;
+ min-height: 100vh;
}
.dropdown-menu {
@@ -325,6 +327,7 @@ a.anchorjs-link:hover { text-decoration: none; }
border-bottom-width: 0px;
margin-top: 0px;
width: 210px;
+ height: 80%;
float: left;
position: fixed;
overflow-y: scroll;
diff --git a/docs/sql-data-sources-generic-options.md b/docs/sql-data-sources-generic-options.md
index 6bcf48235bced..2e4fc879a435f 100644
--- a/docs/sql-data-sources-generic-options.md
+++ b/docs/sql-data-sources-generic-options.md
@@ -119,3 +119,40 @@ To load all files recursively, you can use:
{% include_example recursive_file_lookup r/RSparkSQLExample.R %}
+
+### Modification Time Path Filters
+
+`modifiedBefore` and `modifiedAfter` are options that can be
+applied together or separately in order to achieve greater
+granularity over which files are loaded during a Spark batch query.
+(Note that Structured Streaming file sources don't support these options.)
+
+* `modifiedBefore`: an optional timestamp to only include files with
+modification times occurring before the specified time. The provided timestamp
+must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+* `modifiedAfter`: an optional timestamp to only include files with
+modification times occurring after the specified time. The provided timestamp
+must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
+When a timezone option is not provided, the timestamps will be interpreted according
+to the Spark session timezone (`spark.sql.session.timeZone`).
+
+To load files with paths matching a given modified time range, you can use:
+
+
+
+{% include_example load_with_modified_time_filter scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
+
+
+
+{% include_example load_with_modified_time_filter java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
+
+
+
+{% include_example load_with_modified_time_filter python/sql/datasource.py %}
+
+
+
+{% include_example load_with_modified_time_filter r/RSparkSQLExample.R %}
+
+
\ No newline at end of file
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index fd7208615a09f..870ed0aa0daaa 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -135,6 +135,7 @@ The behavior of some SQL functions can be different under ANSI mode (`spark.sql.
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
- `element_at`: This function throws `NoSuchElementException` if key does not exist in map.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+ - `parse_url`: This function throws `IllegalArgumentException` if an input string is not a valid url.
### SQL Operators
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
index 2295225387a33..46e740d78bffb 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
@@ -147,6 +147,22 @@ private static void runGenericFileSourceOptionsExample(SparkSession spark) {
// |file1.parquet|
// +-------------+
// $example off:load_with_path_glob_filter$
+ // $example on:load_with_modified_time_filter$
+ Dataset<Row> beforeFilterDF = spark.read().format("parquet")
+ // Only load files modified before 7/1/2020 at 05:30
+ .option("modifiedBefore", "2020-07-01T05:30:00")
+ // Only load files modified after 6/1/2020 at 05:30
+ .option("modifiedAfter", "2020-06-01T05:30:00")
+ // Interpret both times above relative to CST timezone
+ .option("timeZone", "CST")
+ .load("examples/src/main/resources/dir1");
+ beforeFilterDF.show();
+ // +-------------+
+ // | file|
+ // +-------------+
+ // |file1.parquet|
+ // +-------------+
+ // $example off:load_with_modified_time_filter$
}
private static void runBasicDataSourceExample(SparkSession spark) {
diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py
index eecd8c2d84788..8c146ba0c9455 100644
--- a/examples/src/main/python/sql/datasource.py
+++ b/examples/src/main/python/sql/datasource.py
@@ -67,6 +67,26 @@ def generic_file_source_options_example(spark):
# +-------------+
# $example off:load_with_path_glob_filter$
+ # $example on:load_with_modified_time_filter$
+ # Only load files modified before 07/01/2050 @ 08:30:00
+ df = spark.read.load("examples/src/main/resources/dir1",
+ format="parquet", modifiedBefore="2050-07-01T08:30:00")
+ df.show()
+ # +-------------+
+ # | file|
+ # +-------------+
+ # |file1.parquet|
+ # +-------------+
+ # Only load files modified after 06/01/2050 @ 08:30:00
+ df = spark.read.load("examples/src/main/resources/dir1",
+ format="parquet", modifiedAfter="2050-06-01T08:30:00")
+ df.show()
+ # +-------------+
+ # | file|
+ # +-------------+
+ # +-------------+
+ # $example off:load_with_modified_time_filter$
+
def basic_datasource_example(spark):
# $example on:generic_load_save_functions$
diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R
index 8685cfb5c05f2..86ad5334248bc 100644
--- a/examples/src/main/r/RSparkSQLExample.R
+++ b/examples/src/main/r/RSparkSQLExample.R
@@ -144,6 +144,14 @@ df <- read.df("examples/src/main/resources/dir1", "parquet", pathGlobFilter = "*
# 1 file1.parquet
# $example off:load_with_path_glob_filter$
+# $example on:load_with_modified_time_filter$
+beforeDF <- read.df("examples/src/main/resources/dir1", "parquet", modifiedBefore = "2020-07-01T05:30:00")
+# file
+# 1 file1.parquet
+afterDF <- read.df("examples/src/main/resources/dir1", "parquet", modifiedAfter = "2020-06-01T05:30:00")
+# file
+# $example off:load_with_modified_time_filter$
+
# $example on:manual_save_options_orc$
df <- read.df("examples/src/main/resources/users.orc", "orc")
write.orc(df, "users_with_options.orc", orc.bloom.filter.columns = "favorite_color", orc.dictionary.key.threshold = 1.0, orc.column.encoding.direct = "name")
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index 2c7abfcd335d1..90c0eeb5ba888 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -81,6 +81,27 @@ object SQLDataSourceExample {
// |file1.parquet|
// +-------------+
// $example off:load_with_path_glob_filter$
+ // $example on:load_with_modified_time_filter$
+ val beforeFilterDF = spark.read.format("parquet")
+ // Files modified before 07/01/2020 at 05:30 are allowed
+ .option("modifiedBefore", "2020-07-01T05:30:00")
+ .load("examples/src/main/resources/dir1");
+ beforeFilterDF.show();
+ // +-------------+
+ // | file|
+ // +-------------+
+ // |file1.parquet|
+ // +-------------+
+ val afterFilterDF = spark.read.format("parquet")
+ // Files modified after 06/01/2020 at 05:30 are allowed
+ .option("modifiedAfter", "2020-06-01T05:30:00")
+ .load("examples/src/main/resources/dir1");
+ afterFilterDF.show();
+ // +-------------+
+ // | file|
+ // +-------------+
+ // +-------------+
+ // $example off:load_with_modified_time_filter$
}
private def runBasicDataSourceExample(spark: SparkSession): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
index 4b9acd0d39f3f..d086c8cdcc589 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.tags.DockerTest
* To run this test suite for a specific version (e.g., ibmcom/db2:11.5.4.0):
* {{{
* DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.4.0
- * ./build/sbt -Pdocker-integration-tests "testOnly *DB2IntegrationSuite"
+ * ./build/sbt -Pdocker-integration-tests
+ * "testOnly org.apache.spark.sql.jdbc.DB2IntegrationSuite"
* }}}
*/
@DockerTest
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index f1ffc8f0f3dc7..939a07238934b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest
* To run this test suite for a specific version (e.g., 2019-GA-ubuntu-16.04):
* {{{
* MSSQLSERVER_DOCKER_IMAGE_NAME=2019-GA-ubuntu-16.04
- * ./build/sbt -Pdocker-integration-tests "testOnly *MsSqlServerIntegrationSuite"
+ * ./build/sbt -Pdocker-integration-tests
+ * "testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite"
* }}}
*/
@DockerTest
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 6f96ab33d0fee..68f0dbc057c1f 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.tags.DockerTest
* To run this test suite for a specific version (e.g., mysql:5.7.31):
* {{{
* MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31
- * ./build/sbt -Pdocker-integration-tests "testOnly *MySQLIntegrationSuite"
+ * ./build/sbt -Pdocker-integration-tests
+ * "testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite"
* }}}
*/
@DockerTest
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index fa13100b5fdc8..0347c98bba2c4 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -30,7 +30,8 @@ import org.apache.spark.tags.DockerTest
* To run this test suite for a specific version (e.g., postgres:13.0):
* {{{
* POSTGRES_DOCKER_IMAGE_NAME=postgres:13.0
- * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresIntegrationSuite"
+ * ./build/sbt -Pdocker-integration-tests
+ * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite"
* }}}
*/
@DockerTest
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index ad1010da5c104..03ebe0299f63f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* The imputation strategy. Currently only "mean" and "median" are supported.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
+ * If "mode", then replace missing using the most frequent value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
- s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
- ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
+ s"If ${Imputer.median}, then replace missing values using the median value of the feature. " +
+ s"If ${Imputer.mode}, then replace missing values using the most frequent value of " +
+ s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies))
/** @group getParam */
def getStrategy: String = $(strategy)
@@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* For example, if the input column is IntegerType (1, 2, 4, null),
* the output will be IntegerType (1, 2, 4, 2) after mean imputation.
*
- * Note that the mean/median value is computed after filtering out missing values.
+ * Note that the mean/median/mode value is computed after filtering out missing values.
* All Null values in the input columns are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
*/
@@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
/**
- * Imputation strategy. Available options are ["mean", "median"].
+ * Imputation strategy. Available options are ["mean", "median", "mode"].
* @group setParam
*/
@Since("2.2.0")
@@ -151,39 +153,42 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
val spark = dataset.sparkSession
val (inputColumns, _) = getInOutCols()
-
val cols = inputColumns.map { inputCol =>
when(col(inputCol).equalTo($(missingValue)), null)
.when(col(inputCol).isNaN, null)
.otherwise(col(inputCol))
- .cast("double")
+ .cast(DoubleType)
.as(inputCol)
}
+ val numCols = cols.length
val results = $(strategy) match {
case Imputer.mean =>
// Function avg will ignore null automatically.
// For a column only containing null, avg will return null.
val row = dataset.select(cols.map(avg): _*).head()
- Array.range(0, inputColumns.length).map { i =>
- if (row.isNullAt(i)) {
- Double.NaN
- } else {
- row.getDouble(i)
- }
- }
+ Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i))
case Imputer.median =>
// Function approxQuantile will ignore null automatically.
// For a column only containing null, approxQuantile will return an empty array.
dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
- .map { array =>
- if (array.isEmpty) {
- Double.NaN
- } else {
- array.head
- }
- }
+ .map(_.headOption.getOrElse(Double.NaN))
+
+ case Imputer.mode =>
+ import spark.implicits._
+ // If there is more than one mode, choose the smallest one to keep in line
+ // with sklearn.impute.SimpleImputer (using scipy.stats.mode).
+ val modes = dataset.select(cols: _*).flatMap { row =>
+ // Ignore null.
+ Iterator.range(0, numCols)
+ .flatMap(i => if (row.isNullAt(i)) None else Some((i, row.getDouble(i))))
+ }.toDF("index", "value")
+ .groupBy("index", "value").agg(negate(count(lit(0))).as("negative_count"))
+ .groupBy("index").agg(min(struct("negative_count", "value")).as("mode"))
+ .select("index", "mode.value")
+ .as[(Int, Double)].collect().toMap
+ Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
}
val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
@@ -212,6 +217,10 @@ object Imputer extends DefaultParamsReadable[Imputer] {
/** strategy names that Imputer currently supports. */
private[feature] val mean = "mean"
private[feature] val median = "median"
+ private[feature] val mode = "mode"
+
+ /* Set of strategies that Imputer supports */
+ private[feature] val supportedStrategies = Array(mean, median, mode)
@Since("2.2.0")
override def load(path: String): Imputer = super.load(path)
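
A short end-to-end sketch of the new "mode" strategy; spark is an assumed SparkSession and the column names are made up. When a column has more than one mode, the smallest value wins, matching the comment in transform() above.

```scala
import org.apache.spark.ml.feature.Imputer

val df = spark.createDataFrame(Seq(
  (1.0, 4.0), (1.0, 12.0), (3.0, Double.NaN), (Double.NaN, 4.0)
)).toDF("f1", "f2")

val model = new Imputer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("f1_imputed", "f2_imputed"))
  .setStrategy("mode")
  .fit(df)

// NaNs are replaced with each column's most frequent value: 1.0 for f1, 4.0 for f2.
model.transform(df).show()
```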
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index dfee2b4029c8b..30887f55638f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -28,13 +28,14 @@ import org.apache.spark.sql.types._
class ImputerSuite extends MLTest with DefaultReadWriteTest {
test("Imputer for Double with default missing Value NaN") {
- val df = spark.createDataFrame( Seq(
- (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0),
- (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0),
- (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0),
- (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0)
- )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1",
- "expected_mean_value2", "expected_median_value2")
+ val df = spark.createDataFrame(Seq(
+ (0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0),
+ (1, 11.0, 12.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0),
+ (2, 3.0, Double.NaN, 3.0, 3.0, 3.0, 10.0, 12.0, 4.0),
+ (3, Double.NaN, 14.0, 5.0, 3.0, 1.0, 14.0, 14.0, 14.0)
+ )).toDF("id", "value1", "value2",
+ "expected_mean_value1", "expected_median_value1", "expected_mode_value1",
+ "expected_mean_value2", "expected_median_value2", "expected_mode_value2")
val imputer = new Imputer()
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1", "out2"))
@@ -42,23 +43,25 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Single Column: Imputer for Double with default missing Value NaN") {
- val df1 = spark.createDataFrame( Seq(
- (0, 1.0, 1.0, 1.0),
- (1, 11.0, 11.0, 11.0),
- (2, 3.0, 3.0, 3.0),
- (3, Double.NaN, 5.0, 3.0)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df1 = spark.createDataFrame(Seq(
+ (0, 1.0, 1.0, 1.0, 1.0),
+ (1, 11.0, 11.0, 11.0, 11.0),
+ (2, 3.0, 3.0, 3.0, 3.0),
+ (3, Double.NaN, 5.0, 3.0, 1.0)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer1 = new Imputer()
.setInputCol("value")
.setOutputCol("out")
ImputerSuite.iterateStrategyTest(false, imputer1, df1)
- val df2 = spark.createDataFrame( Seq(
- (0, 4.0, 4.0, 4.0),
- (1, 12.0, 12.0, 12.0),
- (2, Double.NaN, 10.0, 12.0),
- (3, 14.0, 14.0, 14.0)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df2 = spark.createDataFrame(Seq(
+ (0, 4.0, 4.0, 4.0, 4.0),
+ (1, 12.0, 12.0, 12.0, 12.0),
+ (2, Double.NaN, 10.0, 12.0, 4.0),
+ (3, 14.0, 14.0, 14.0, 14.0)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer2 = new Imputer()
.setInputCol("value")
.setOutputCol("out")
@@ -66,12 +69,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
- val df = spark.createDataFrame( Seq(
- (0, 1.0, 1.0, 1.0),
- (1, 3.0, 3.0, 3.0),
- (2, Double.NaN, Double.NaN, Double.NaN),
- (3, -1.0, 2.0, 1.0)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq(
+ (0, 1.0, 1.0, 1.0, 1.0),
+ (1, 3.0, 3.0, 3.0, 3.0),
+ (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN),
+ (3, -1.0, 2.0, 1.0, 1.0)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1.0)
ImputerSuite.iterateStrategyTest(true, imputer, df)
@@ -79,64 +83,69 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
test("Single Column: Imputer should handle NaNs when computing surrogate value," +
" if missingValue is not NaN") {
- val df = spark.createDataFrame( Seq(
- (0, 1.0, 1.0, 1.0),
- (1, 3.0, 3.0, 3.0),
- (2, Double.NaN, Double.NaN, Double.NaN),
- (3, -1.0, 2.0, 1.0)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq(
+ (0, 1.0, 1.0, 1.0, 1.0),
+ (1, 3.0, 3.0, 3.0, 3.0),
+ (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN),
+ (3, -1.0, 2.0, 1.0, 1.0)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer().setInputCol("value").setOutputCol("out")
.setMissingValue(-1.0)
ImputerSuite.iterateStrategyTest(false, imputer, df)
}
test("Imputer for Float with missing Value -1.0") {
- val df = spark.createDataFrame( Seq(
- (0, 1.0F, 1.0F, 1.0F),
- (1, 3.0F, 3.0F, 3.0F),
- (2, 10.0F, 10.0F, 10.0F),
- (3, 10.0F, 10.0F, 10.0F),
- (4, -1.0F, 6.0F, 3.0F)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq(
+ (0, 1.0F, 1.0F, 1.0F, 1.0F),
+ (1, 3.0F, 3.0F, 3.0F, 3.0F),
+ (2, 10.0F, 10.0F, 10.0F, 10.0F),
+ (3, 10.0F, 10.0F, 10.0F, 10.0F),
+ (4, -1.0F, 6.0F, 3.0F, 10.0F)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1)
ImputerSuite.iterateStrategyTest(true, imputer, df)
}
test("Single Column: Imputer for Float with missing Value -1.0") {
- val df = spark.createDataFrame( Seq(
- (0, 1.0F, 1.0F, 1.0F),
- (1, 3.0F, 3.0F, 3.0F),
- (2, 10.0F, 10.0F, 10.0F),
- (3, 10.0F, 10.0F, 10.0F),
- (4, -1.0F, 6.0F, 3.0F)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq(
+ (0, 1.0F, 1.0F, 1.0F, 1.0F),
+ (1, 3.0F, 3.0F, 3.0F, 3.0F),
+ (2, 10.0F, 10.0F, 10.0F, 10.0F),
+ (3, 10.0F, 10.0F, 10.0F, 10.0F),
+ (4, -1.0F, 6.0F, 3.0F, 10.0F)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer().setInputCol("value").setOutputCol("out")
.setMissingValue(-1)
ImputerSuite.iterateStrategyTest(false, imputer, df)
}
test("Imputer should impute null as well as 'missingValue'") {
- val rawDf = spark.createDataFrame( Seq(
- (0, 4.0, 4.0, 4.0),
- (1, 10.0, 10.0, 10.0),
- (2, 10.0, 10.0, 10.0),
- (3, Double.NaN, 8.0, 10.0),
- (4, -1.0, 8.0, 10.0)
- )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
+ val rawDf = spark.createDataFrame(Seq(
+ (0, 4.0, 4.0, 4.0, 4.0),
+ (1, 10.0, 10.0, 10.0, 10.0),
+ (2, 10.0, 10.0, 10.0, 10.0),
+ (3, Double.NaN, 8.0, 10.0, 10.0),
+ (4, -1.0, 8.0, 10.0, 10.0)
+ )).toDF("id", "rawValue",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
ImputerSuite.iterateStrategyTest(true, imputer, df)
}
test("Single Column: Imputer should impute null as well as 'missingValue'") {
- val rawDf = spark.createDataFrame( Seq(
- (0, 4.0, 4.0, 4.0),
- (1, 10.0, 10.0, 10.0),
- (2, 10.0, 10.0, 10.0),
- (3, Double.NaN, 8.0, 10.0),
- (4, -1.0, 8.0, 10.0)
- )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
+ val rawDf = spark.createDataFrame(Seq(
+ (0, 4.0, 4.0, 4.0, 4.0),
+ (1, 10.0, 10.0, 10.0, 10.0),
+ (2, 10.0, 10.0, 10.0, 10.0),
+ (3, Double.NaN, 8.0, 10.0, 10.0),
+ (4, -1.0, 8.0, 10.0, 10.0)
+ )).toDF("id", "rawValue",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
val imputer = new Imputer().setInputCol("value").setOutputCol("out")
ImputerSuite.iterateStrategyTest(false, imputer, df)
@@ -187,7 +196,7 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Imputer throws exception when surrogate cannot be computed") {
- val df = spark.createDataFrame( Seq(
+ val df = spark.createDataFrame(Seq(
(0, Double.NaN, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
@@ -205,12 +214,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Single Column: Imputer throws exception when surrogate cannot be computed") {
- val df = spark.createDataFrame( Seq(
- (0, Double.NaN, 1.0, 1.0),
- (1, Double.NaN, 3.0, 3.0),
- (2, Double.NaN, Double.NaN, Double.NaN)
- )).toDF("id", "value", "expected_mean_value", "expected_median_value")
- Seq("mean", "median").foreach { strategy =>
+ val df = spark.createDataFrame(Seq(
+ (0, Double.NaN, 1.0, 1.0, 1.0),
+ (1, Double.NaN, 3.0, 3.0, 3.0),
+ (2, Double.NaN, Double.NaN, Double.NaN, Double.NaN)
+ )).toDF("id", "value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
+ Seq("mean", "median", "mode").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out")
.setStrategy(strategy)
withClue("Imputer should fail all the values are invalid") {
@@ -223,12 +233,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Imputer input & output column validation") {
- val df = spark.createDataFrame( Seq(
+ val df = spark.createDataFrame(Seq(
(0, 1.0, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
)).toDF("id", "value1", "value2", "value3")
- Seq("mean", "median").foreach { strategy =>
+ Seq("mean", "median", "mode").foreach { strategy =>
withClue("Imputer should fail if inputCols and outputCols are different length") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
val imputer = new Imputer().setStrategy(strategy)
@@ -306,13 +316,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Imputer for IntegerType with default missing value null") {
-
- val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
- (1, 1, 1),
- (11, 11, 11),
- (3, 3, 3),
- (null, 5, 3)
- )).toDF("value1", "expected_mean_value1", "expected_median_value1")
+ val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)](
+ (1, 1, 1, 1),
+ (11, 11, 11, 11),
+ (3, 3, 3, 3),
+ (null, 5, 3, 1)
+ )).toDF("value1",
+ "expected_mean_value1", "expected_median_value1", "expected_mode_value1")
val imputer = new Imputer()
.setInputCols(Array("value1"))
@@ -327,12 +337,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Single Column Imputer for IntegerType with default missing value null") {
- val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
- (1, 1, 1),
- (11, 11, 11),
- (3, 3, 3),
- (null, 5, 3)
- )).toDF("value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)](
+ (1, 1, 1, 1),
+ (11, 11, 11, 11),
+ (3, 3, 3, 3),
+ (null, 5, 3, 1)
+ )).toDF("value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer()
.setInputCol("value")
@@ -347,13 +358,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Imputer for IntegerType with missing value -1") {
-
- val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
- (1, 1, 1),
- (11, 11, 11),
- (3, 3, 3),
- (-1, 5, 3)
- )).toDF("value1", "expected_mean_value1", "expected_median_value1")
+ val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)](
+ (1, 1, 1, 1),
+ (11, 11, 11, 11),
+ (3, 3, 3, 3),
+ (-1, 5, 3, 1)
+ )).toDF("value1",
+ "expected_mean_value1", "expected_median_value1", "expected_mode_value1")
val imputer = new Imputer()
.setInputCols(Array("value1"))
@@ -369,12 +380,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Single Column: Imputer for IntegerType with missing value -1") {
- val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
- (1, 1, 1),
- (11, 11, 11),
- (3, 3, 3),
- (-1, 5, 3)
- )).toDF("value", "expected_mean_value", "expected_median_value")
+ val df = spark.createDataFrame(Seq[(Integer, Integer, Integer, Integer)](
+ (1, 1, 1, 1),
+ (11, 11, 11, 11),
+ (3, 3, 3, 3),
+ (-1, 5, 3, 1)
+ )).toDF("value",
+ "expected_mean_value", "expected_median_value", "expected_mode_value")
val imputer = new Imputer()
.setInputCol("value")
@@ -402,13 +414,13 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
}
test("Compare single/multiple column(s) Imputer in pipeline") {
- val df = spark.createDataFrame( Seq(
+ val df = spark.createDataFrame(Seq(
(0, 1.0, 4.0),
(1, 11.0, 12.0),
(2, 3.0, Double.NaN),
(3, Double.NaN, 14.0)
)).toDF("id", "value1", "value2")
- Seq("mean", "median").foreach { strategy =>
+ Seq("mean", "median", "mode").foreach { strategy =>
val multiColsImputer = new Imputer()
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("result1", "result2"))
@@ -450,11 +462,12 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
object ImputerSuite {
/**
- * Imputation strategy. Available options are ["mean", "median"].
- * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median"
+ * Imputation strategy. Available options are ["mean", "median", "mode"].
+ * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median",
+ * "expected_mode".
*/
def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = {
- Seq("mean", "median").foreach { strategy =>
+ Seq("mean", "median", "mode").foreach { strategy =>
imputer.setStrategy(strategy)
val model = imputer.fit(df)
val resultDF = model.transform(df)
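
Note (not part of the patch): the tests above extend the strategy matrix to "mode". A minimal end-to-end sketch of the new strategy, assuming an active SparkSession named `spark`:

```scala
import org.apache.spark.ml.feature.Imputer

// Toy frame with one missing entry; NaN is the default missingValue.
val df = spark.createDataFrame(Seq(
  (0, 1.0), (1, 3.0), (2, 10.0), (3, 10.0), (4, Double.NaN)
)).toDF("id", "value")

val imputer = new Imputer()
  .setInputCols(Array("value"))
  .setOutputCols(Array("value_out"))
  .setStrategy("mode")             // the strategy these tests exercise

// The NaN in "value" is replaced by 10.0, the most frequent value.
imputer.fit(df).transform(df).show()
```
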
diff --git a/pom.xml b/pom.xml
index 3ae2e7420e154..0ab5a8c5b3efa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -164,7 +164,6 @@
3.2.2
2.12.10
2.12
- -Ywarn-unused-import
2.0.0
--test
@@ -932,7 +931,7 @@
org.scalatest
scalatest_${scala.binary.version}
- 3.2.0
+ 3.2.3
test
@@ -956,14 +955,14 @@
org.mockito
mockito-core
- 3.1.0
+ 3.4.6
test
org.jmock
jmock-junit4
test
- 2.8.4
+ 2.12.0
org.scalacheck
@@ -974,7 +973,7 @@
junit
junit
- 4.12
+ 4.13.1
test
@@ -2499,7 +2498,7 @@
net.alchim31.maven
scala-maven-plugin
- 4.3.0
+ 4.4.0
eclipse-add-source
@@ -2538,7 +2537,6 @@
-deprecation
-feature
-explaintypes
- ${scalac.arg.unused-imports}
-target:jvm-1.8
@@ -2575,7 +2573,7 @@
org.apache.maven.plugins
maven-surefire-plugin
- 3.0.0-M3
+ 3.0.0-M5
@@ -3262,13 +3260,12 @@
-
+
scala-2.13
2.13.3
2.13
- -Wconf:cat=unused-imports:e
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 55c87fcb3aaa2..05413b7091ad9 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -221,6 +221,7 @@ object SparkBuild extends PomBuild {
Seq(
"-Xfatal-warnings",
"-deprecation",
+ "-Ywarn-unused-import",
"-P:silencer:globalFilters=.*deprecated.*" //regex to catch deprecation warnings and supress them
)
} else {
@@ -230,6 +231,8 @@ object SparkBuild extends PomBuild {
// see `scalac -Wconf:help` for details
"-Wconf:cat=deprecation:wv,any:e",
// 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately
+ // TODO(SPARK-33499): Enable this option when Scala 2.12 is no longer supported.
+ // "-Wunused:imports",
"-Wconf:cat=lint-multiarg-infix:wv",
"-Wconf:cat=other-nullary-override:wv",
"-Wconf:cat=other-match-analysis&site=org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction.catalogFunction:wv",
diff --git a/project/build.properties b/project/build.properties
index 5ec1d700fd2a8..c92de941c10be 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=1.4.2
+sbt.version=1.4.4
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 4039698d39958..9c9ff7fa7844b 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -48,7 +48,7 @@ If you want to install extra dependencies for a specific component, you can ins
pip install pyspark[sql]
-For PySpark with a different Hadoop version, you can install it by using ``HADOOP_VERSION`` environment variables as below:
+To install PySpark with/without a specific Hadoop version, use the ``HADOOP_VERSION`` environment variable as below:
.. code-block:: bash
@@ -68,8 +68,13 @@ It is recommended to use ``-v`` option in ``pip`` to track the installation and
HADOOP_VERSION=2.7 pip install pyspark -v
-Supported versions of Hadoop are ``HADOOP_VERSION=2.7`` and ``HADOOP_VERSION=3.2`` (default).
-Note that this installation of PySpark with a different version of Hadoop is experimental. It can change or be removed between minor releases.
+Supported values in ``HADOOP_VERSION`` are:
+
+- ``without``: Spark pre-built with user-provided Apache Hadoop
+- ``2.7``: Spark pre-built for Apache Hadoop 2.7
+- ``3.2``: Spark pre-built for Apache Hadoop 3.2 and later (default)
+
+Note that this way of installing PySpark with/without a specific Hadoop version is experimental. It can change or be removed between minor releases.
Using Conda
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 9c9e3f4b3c881..1bd5961e0525a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -222,6 +222,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# data via a socket.
# scala's mangled names w/ $ in them require special treatment.
self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
+ os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \
+ str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index eafa5d90f9ff8..fe2e326dff8be 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret):
af, socktype, proto, _, sa = res
try:
sock = socket.socket(af, socktype, proto)
- sock.settimeout(15)
+ sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15)))
sock.connect(sa)
sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
_do_server_auth(sockfile, auth_secret)
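
The context.py and java_gateway.py changes above propagate the JVM-side timeout to Python workers via SPARK_AUTH_SOCKET_TIMEOUT. A hedged sketch of raising it from the user side; the config key is assumed to be the one backing PYTHON_AUTH_SOCKET_TIMEOUT and is not shown in this diff:

```scala
import org.apache.spark.sql.SparkSession

// Assumption: PYTHON_AUTH_SOCKET_TIMEOUT is keyed as
// "spark.python.authenticate.socketTimeout"; the key name is not visible in this diff.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("py-auth-socket-timeout-demo")
  .config("spark.python.authenticate.socketTimeout", "30s")   // replaces the fixed 15s
  .getOrCreate()

// PySpark exports the resolved value as SPARK_AUTH_SOCKET_TIMEOUT, so sockets created in
// local_connect_and_auth now time out after 30 seconds instead of the hard-coded 15.
```
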
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4d898bd5fffa8..82b9a6db1eb92 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1507,7 +1507,8 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
strategy = Param(Params._dummy(), "strategy",
"strategy for imputation. If mean, then replace missing values using the mean "
"value of the feature. If median, then replace missing values using the "
- "median value of the feature.",
+ "median value of the feature. If mode, then replace missing using the most "
+ "frequent value of the feature.",
typeConverter=TypeConverters.toString)
missingValue = Param(Params._dummy(), "missingValue",
@@ -1541,7 +1542,7 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
numeric type. Currently Imputer does not support categorical features and
possibly creates incorrect values for a categorical feature.
- Note that the mean/median value is computed after filtering out missing values.
+ Note that the mean/median/mode value is computed after filtering out missing values.
All Null values in the input columns are treated as missing, and so are also imputed. For
computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a
relative error of `0.001`.
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 2ed991c87f506..bb31e6a3e09f8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -125,6 +125,12 @@ def option(self, key, value):
* ``pathGlobFilter``: an optional glob pattern to only include files with paths matching
the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
It does not change the behavior of partition discovery.
+ * ``modifiedBefore``: an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * ``modifiedAfter``: an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
"""
self._jreader = self._jreader.option(key, to_str(value))
return self
@@ -149,6 +155,12 @@ def options(self, **options):
* ``pathGlobFilter``: an optional glob pattern to only include files with paths matching
the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
It does not change the behavior of partition discovery.
+ * ``modifiedBefore``: an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * ``modifiedAfter``: an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
"""
for k in options:
self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -203,7 +215,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
dropFieldIfAllNull=None, encoding=None, locale=None, pathGlobFilter=None,
- recursiveFileLookup=None, allowNonNumericNumbers=None):
+ recursiveFileLookup=None, allowNonNumericNumbers=None,
+ modifiedBefore=None, modifiedAfter=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -322,6 +335,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
``+Infinity`` and ``Infinity``.
* ``-INF``: for negative infinity, alias ``-Infinity``.
* ``NaN``: for other not-a-numbers, like result of division by zero.
+ modifiedBefore : an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ modifiedAfter : an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
Examples
--------
@@ -344,6 +364,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
locale=locale, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
+ modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter,
allowNonNumericNumbers=allowNonNumericNumbers)
if isinstance(path, str):
path = [path]
@@ -410,6 +431,15 @@ def parquet(self, *paths, **options):
disables
`partition discovery `_. # noqa
+ modifiedBefore (batch only) : an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ modifiedAfter (batch only) : an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
Examples
--------
>>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
@@ -418,13 +448,18 @@ def parquet(self, *paths, **options):
"""
mergeSchema = options.get('mergeSchema', None)
pathGlobFilter = options.get('pathGlobFilter', None)
+ modifiedBefore = options.get('modifiedBefore', None)
+ modifiedAfter = options.get('modifiedAfter', None)
recursiveFileLookup = options.get('recursiveFileLookup', None)
self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
- recursiveFileLookup=recursiveFileLookup)
+ recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore,
+ modifiedAfter=modifiedAfter)
+
return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None,
- recursiveFileLookup=None):
+ recursiveFileLookup=None, modifiedBefore=None,
+ modifiedAfter=None):
"""
Loads text files and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -453,6 +488,15 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None,
recursively scan a directory for files. Using this option disables
`partition discovery `_. # noqa
+ modifiedBefore (batch only) : an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ modifiedAfter (batch only) : an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
Examples
--------
>>> df = spark.read.text('python/test_support/sql/text-test.txt')
@@ -464,7 +508,9 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None,
"""
self._set_opts(
wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter,
- recursiveFileLookup=recursiveFileLookup)
+ recursiveFileLookup=recursiveFileLookup, modifiedBefore=modifiedBefore,
+ modifiedAfter=modifiedAfter)
+
if isinstance(paths, str):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
@@ -476,7 +522,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
- pathGlobFilter=None, recursiveFileLookup=None):
+ pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None):
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -631,6 +677,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
recursively scan a directory for files. Using this option disables
`partition discovery `_. # noqa
+ modifiedBefore (batch only) : an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ modifiedAfter (batch only) : an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
Examples
--------
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
@@ -652,7 +707,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep,
- pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup)
+ pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
+ modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter)
if isinstance(path, str):
path = [path]
if type(path) == list:
@@ -679,7 +735,8 @@ def func(iterator):
else:
raise TypeError("path can be only string, list or RDD")
- def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None):
+ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None,
+ modifiedBefore=None, modifiedAfter=None):
"""Loads ORC files, returning the result as a :class:`DataFrame`.
.. versionadded:: 1.5.0
@@ -701,6 +758,15 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N
disables
`partition discovery `_. # noqa
+ modifiedBefore : an optional timestamp to only include files with
+ modification times occurring before the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ modifiedAfter : an optional timestamp to only include files with
+ modification times occurring after the specified time. The provided timestamp
+ must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+
Examples
--------
>>> df = spark.read.orc('python/test_support/sql/orc_partitioned')
@@ -708,6 +774,7 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N
[('a', 'bigint'), ('b', 'int'), ('c', 'int')]
"""
self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
+ modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter,
recursiveFileLookup=recursiveFileLookup)
if isinstance(path, str):
path = [path]
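
The modifiedBefore / modifiedAfter options documented above are ordinary DataFrameReader options, so the same filters apply from the Scala API. A small usage sketch with an illustrative path:

```scala
// Only read Parquet files whose modification time falls after the given timestamp
// (interpreted in the session time zone); modifiedBefore works symmetrically.
val recent = spark.read
  .option("modifiedAfter", "2020-06-01T13:00:00")
  .option("pathGlobFilter", "*.parquet")
  .parquet("/data/events")   // illustrative path
```
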
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index b42bdb9816600..22002bb32004d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -18,7 +18,7 @@
package org.apache.spark.scheduler.cluster
import java.util.EnumSet
-import java.util.concurrent.atomic.{AtomicBoolean}
+import java.util.concurrent.atomic.AtomicBoolean
import javax.servlet.DispatcherType
import scala.concurrent.{ExecutionContext, Future}
@@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
import org.apache.spark.SparkContext
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config.UI._
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.rpc._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
/**
* Abstract Yarn scheduler backend that contains common logic
@@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend(
/** Attempt ID. This is unset for client-mode schedulers */
private var attemptId: Option[ApplicationAttemptId] = None
+ private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master
+
+ private val minMergersThresholdRatio =
+ conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO)
+
+ private val minMergersStaticThreshold =
+ conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD)
+
+ private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS)
+
+ private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0)
+
/**
* Bind to YARN. This *must* be done before calling [[start()]].
*
@@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend(
totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
}
+ override def getShufflePushMergerLocations(
+ numPartitions: Int,
+ resourceProfileId: Int): Seq[BlockManagerId] = {
+ // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage,
+ // TODO we can use better heuristics to calculate numMergersDesired for a stage.
+ val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) {
+ maxNumExecutors
+ } else {
+ numExecutors
+ }
+ val tasksPerExecutor = sc.resourceProfileManager
+ .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf)
+ val numMergersDesired = math.min(
+ math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors)
+ val minMergersNeeded = math.max(minMergersStaticThreshold,
+ math.floor(numMergersDesired * minMergersThresholdRatio).toInt)
+
+ // Request for numMergersDesired shuffle mergers to BlockManagerMasterEndpoint
+ // and if it's less than minMergersNeeded, we disable push based shuffle.
+ val mergerLocations = blockManagerMaster
+ .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes())
+ if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) {
+ Seq.empty[BlockManagerId]
+ } else {
+ logDebug(s"The number of shuffle mergers desired ${numMergersDesired}" +
+ s" and available locations are ${mergerLocations.length}")
+ mergerLocations
+ }
+ }
+
/**
* Add filters to the SparkUI.
*/
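
The heuristic in `getShufflePushMergerLocations` reduces to two bounds; a standalone sketch with the config-driven inputs passed as plain parameters. The arithmetic, including the integer division inside math.ceil, mirrors the patch:

```scala
// Returns (numMergersDesired, minMergersNeeded) as computed above.
def mergerBounds(
    numPartitions: Int,
    tasksPerExecutor: Int,
    maxExecutors: Int,
    minStaticThreshold: Int,
    minThresholdRatio: Double): (Int, Int) = {
  val numMergersDesired = math.min(
    math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors)
  val minMergersNeeded = math.max(minStaticThreshold,
    math.floor(numMergersDesired * minThresholdRatio).toInt)
  (numMergersDesired, minMergersNeeded)
}

// 1000 partitions, 4 tasks/executor, 200 max executors, static floor 5, ratio 0.05:
// mergerBounds(1000, 4, 200, 5, 0.05) == (200, 10). The backend falls back to
// Seq.empty only when fewer than 200 mergers are found AND fewer than 10 are available.
```
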
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala
new file mode 100644
index 0000000000000..c680502cb328f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/QueryCompilationErrors.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{Expression, GroupingID}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{AbstractDataType, DataType, StructType}
+
+/**
+ * Object for grouping all error messages of the query compilation.
+ * Currently it includes all AnalysisExceptions created and thrown directly in
+ * org.apache.spark.sql.catalyst.analysis.Analyzer.
+ */
+object QueryCompilationErrors {
+ def groupingIDMismatchError(groupingID: GroupingID, groupByExprs: Seq[Expression]): Throwable = {
+ new AnalysisException(
+ s"Columns of grouping_id (${groupingID.groupByExprs.mkString(",")}) " +
+ s"does not match grouping columns (${groupByExprs.mkString(",")})")
+ }
+
+ def groupingColInvalidError(groupingCol: Expression, groupByExprs: Seq[Expression]): Throwable = {
+ new AnalysisException(
+ s"Column of grouping ($groupingCol) can't be found " +
+ s"in grouping columns ${groupByExprs.mkString(",")}")
+ }
+
+ def groupingSizeTooLargeError(sizeLimit: Int): Throwable = {
+ new AnalysisException(
+ s"Grouping sets size cannot be greater than $sizeLimit")
+ }
+
+ def unorderablePivotColError(pivotCol: Expression): Throwable = {
+ new AnalysisException(
+ s"Invalid pivot column '$pivotCol'. Pivot columns must be comparable."
+ )
+ }
+
+ def nonLiteralPivotValError(pivotVal: Expression): Throwable = {
+ new AnalysisException(
+ s"Literal expressions required for pivot values, found '$pivotVal'")
+ }
+
+ def pivotValDataTypeMismatchError(pivotVal: Expression, pivotCol: Expression): Throwable = {
+ new AnalysisException(
+ s"Invalid pivot value '$pivotVal': " +
+ s"value data type ${pivotVal.dataType.simpleString} does not match " +
+ s"pivot column data type ${pivotCol.dataType.catalogString}")
+ }
+
+ def unsupportedIfNotExistsError(tableName: String): Throwable = {
+ new AnalysisException(
+ s"Cannot write, IF NOT EXISTS is not supported for table: $tableName")
+ }
+
+ def nonPartitionColError(partitionName: String): Throwable = {
+ new AnalysisException(
+ s"PARTITION clause cannot contain a non-partition column name: $partitionName")
+ }
+
+ def addStaticValToUnknownColError(staticName: String): Throwable = {
+ new AnalysisException(
+ s"Cannot add static value for unknown column: $staticName")
+ }
+
+ def unknownStaticPartitionColError(name: String): Throwable = {
+ new AnalysisException(s"Unknown static partition column: $name")
+ }
+
+ def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = {
+ new AnalysisException(
+ "Generators are not supported when it's nested in " +
+ "expressions, but got: " + toPrettySQL(trimmedNestedGenerator))
+ }
+
+ def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = {
+ new AnalysisException(
+ s"Only one generator allowed per $clause clause but found " +
+ generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
+ }
+
+ def generatorOutsideSelectError(plan: LogicalPlan): Throwable = {
+ new AnalysisException(
+ "Generators are not supported outside the SELECT clause, but " +
+ "got: " + plan.simpleString(SQLConf.get.maxToStringFields))
+ }
+
+ def legacyStoreAssignmentPolicyError(): Throwable = {
+ val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
+ new AnalysisException(
+ "LEGACY store assignment policy is disallowed in Spark data source V2. " +
+ s"Please set the configuration $configKey to other values.")
+ }
+
+ def unresolvedUsingColForJoinError(
+ colName: String, plan: LogicalPlan, side: String): Throwable = {
+ new AnalysisException(
+ s"USING column `$colName` cannot be resolved on the $side " +
+ s"side of the join. The $side-side columns: [${plan.output.map(_.name).mkString(", ")}]")
+ }
+
+ def dataTypeMismatchForDeserializerError(
+ dataType: DataType, desiredType: String): Throwable = {
+ val quantifier = if (desiredType.equals("array")) "an" else "a"
+ new AnalysisException(
+ s"need $quantifier $desiredType field but got " + dataType.catalogString)
+ }
+
+ def fieldNumberMismatchForDeserializerError(
+ schema: StructType, maxOrdinal: Int): Throwable = {
+ new AnalysisException(
+ s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}, " +
+ "but failed as the number of fields does not line up.")
+ }
+
+ def upCastFailureError(
+ fromStr: String, from: Expression, to: DataType, walkedTypePath: Seq[String]): Throwable = {
+ new AnalysisException(
+ s"Cannot up cast $fromStr from " +
+ s"${from.dataType.catalogString} to ${to.catalogString}.\n" +
+ s"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+ "You can either add an explicit cast to the input data or choose a higher precision " +
+ "type of the field in the target object")
+ }
+
+ def unsupportedAbstractDataTypeForUpCastError(gotType: AbstractDataType): Throwable = {
+ new AnalysisException(
+ s"UpCast only support DecimalType as AbstractDataType yet, but got: $gotType")
+ }
+
+ def outerScopeFailureForNewInstanceError(className: String): Throwable = {
+ new AnalysisException(
+ s"Unable to generate an encoder for inner class `$className` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+
+ def referenceColNotFoundForAlterTableChangesError(
+ after: TableChange.After, parentName: String): Throwable = {
+ new AnalysisException(
+ s"Couldn't find the reference column for $after at $parentName")
+ }
+
+}
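
The new object centralizes AnalysisException construction; call sites in Analyzer.scala (next file) shrink to one-line throws. A hedged sketch of the calling pattern, with an invented helper for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.errors.QueryCompilationErrors

// Hypothetical caller (not in the patch): the check stays at the call site while the
// AnalysisException text is built centrally, matching the migration in Analyzer.scala below.
def checkGroupingSetsSize(groupByExprs: Seq[Expression], sizeLimit: Int): Unit = {
  if (groupByExprs.size > sizeLimit) {
    throw QueryCompilationErrors.groupingSizeTooLargeError(sizeLimit)
  }
}
```
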
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8d95d8cf49d45..837686420375a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
@@ -448,9 +449,7 @@ class Analyzer(override val catalogManager: CatalogManager)
e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) {
Alias(gid, toPrettySQL(e))()
} else {
- throw new AnalysisException(
- s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
- s"grouping columns (${groupByExprs.mkString(",")})")
+ throw QueryCompilationErrors.groupingIDMismatchError(e, groupByExprs)
}
case e @ Grouping(col: Expression) =>
val idx = groupByExprs.indexWhere(_.semanticEquals(col))
@@ -458,8 +457,7 @@ class Analyzer(override val catalogManager: CatalogManager)
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1L)), ByteType), toPrettySQL(e))()
} else {
- throw new AnalysisException(s"Column of grouping ($col) can't be found " +
- s"in grouping columns ${groupByExprs.mkString(",")}")
+ throw QueryCompilationErrors.groupingColInvalidError(col, groupByExprs)
}
}
}
@@ -575,8 +573,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
- throw new AnalysisException(
- s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}")
+ throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8)
}
// Expand works by setting grouping expressions to null as determined by the
@@ -712,8 +709,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|| !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
- throw new AnalysisException(
- s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.")
+ throw QueryCompilationErrors.unorderablePivotColError(pivotColumn)
}
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
@@ -724,13 +720,10 @@ class Analyzer(override val catalogManager: CatalogManager)
case _ => value.foldable
}
if (!foldable) {
- throw new AnalysisException(
- s"Literal expressions required for pivot values, found '$value'")
+ throw QueryCompilationErrors.nonLiteralPivotValError(value)
}
if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
- throw new AnalysisException(s"Invalid pivot value '$value': " +
- s"value data type ${value.dataType.simpleString} does not match " +
- s"pivot column data type ${pivotColumn.dataType.catalogString}")
+ throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn)
}
Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
}
@@ -868,9 +861,9 @@ class Analyzer(override val catalogManager: CatalogManager)
}.getOrElse(write)
case _ => write
}
- case u @ UnresolvedTable(ident) =>
+ case u @ UnresolvedTable(ident, cmd) =>
lookupTempView(ident).foreach { _ =>
- u.failAnalysis(s"${ident.quoted} is a temp view not table.")
+ u.failAnalysis(s"${ident.quoted} is a temp view. '$cmd' expects a table")
}
u
case u @ UnresolvedTableOrView(ident, allowTempView) =>
@@ -957,7 +950,7 @@ class Analyzer(override val catalogManager: CatalogManager)
SubqueryAlias(catalog.get.name +: ident.namespace :+ ident.name, relation)
}.getOrElse(u)
- case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) =>
+ case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident), _) =>
CatalogV2Util.loadTable(catalog, ident)
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)
@@ -1084,11 +1077,11 @@ class Analyzer(override val catalogManager: CatalogManager)
lookupRelation(u.multipartIdentifier, u.options, u.isStreaming)
.map(resolveViews).getOrElse(u)
- case u @ UnresolvedTable(identifier) =>
+ case u @ UnresolvedTable(identifier, cmd) =>
lookupTableOrView(identifier).map {
case v: ResolvedView =>
val viewStr = if (v.isTemp) "temp view" else "view"
- u.failAnalysis(s"${v.identifier.quoted} is a $viewStr not table.")
+ u.failAnalysis(s"${v.identifier.quoted} is a $viewStr. '$cmd' expects a table.'")
case table => table
}.getOrElse(u)
@@ -1167,8 +1160,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
- throw new AnalysisException(
- s"Cannot write, IF NOT EXISTS is not supported for table: ${r.table.name}")
+ throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
}
val partCols = partitionColumnNames(r.table)
@@ -1205,8 +1197,7 @@ class Analyzer(override val catalogManager: CatalogManager)
partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
case Some(_) =>
case None =>
- throw new AnalysisException(
- s"PARTITION clause cannot contain a non-partition column name: $partitionName")
+ throw QueryCompilationErrors.nonPartitionColError(partitionName)
}
}
}
@@ -1228,8 +1219,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case Some(attr) =>
attr.name -> staticName
case _ =>
- throw new AnalysisException(
- s"Cannot add static value for unknown column: $staticName")
+ throw QueryCompilationErrors.addStaticValToUnknownColError(staticName)
}).toMap
val queryColumns = query.output.iterator
@@ -1271,7 +1261,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// an UnresolvedAttribute.
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
case None =>
- throw new AnalysisException(s"Unknown static partition column: $name")
+ throw QueryCompilationErrors.unknownStaticPartitionColError(name)
}
}.reduce(And)
}
@@ -2483,23 +2473,19 @@ class Analyzer(override val catalogManager: CatalogManager)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
- throw new AnalysisException("Generators are not supported when it's nested in " +
- "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))
+ throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
val generators = projectList.filter(hasGenerator).map(trimAlias)
- throw new AnalysisException("Only one generator allowed per select clause but found " +
- generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
+ throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "select")
case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
- throw new AnalysisException("Generators are not supported when it's nested in " +
- "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))
+ throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
- throw new AnalysisException("Only one generator allowed per aggregate clause but found " +
- generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
+ throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate")
case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
case AliasedGenerator(_, _, _) => true
@@ -2582,8 +2568,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case g: Generate => g
case p if p.expressions.exists(hasGenerator) =>
- throw new AnalysisException("Generators are not supported outside the SELECT clause, but " +
- "got: " + p.simpleString(SQLConf.get.maxToStringFields))
+ throw QueryCompilationErrors.generatorOutsideSelectError(p)
}
}
@@ -3122,10 +3107,7 @@ class Analyzer(override val catalogManager: CatalogManager)
private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
- val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
- throw new AnalysisException(s"""
- |"LEGACY" store assignment policy is disallowed in Spark data source V2.
- |Please set the configuration $configKey to other values.""".stripMargin)
+ throw QueryCompilationErrors.legacyStoreAssignmentPolicyError()
}
}
@@ -3138,14 +3120,12 @@ class Analyzer(override val catalogManager: CatalogManager)
hint: JoinHint) = {
val leftKeys = joinNames.map { keyName =>
left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
- throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " +
- s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]")
+ throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, left, "left")
}
}
val rightKeys = joinNames.map { keyName =>
right.output.find(attr => resolver(attr.name, keyName)).getOrElse {
- throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " +
- s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]")
+ throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, right, "right")
}
}
val joinPairs = leftKeys.zip(rightKeys)
@@ -3208,7 +3188,8 @@ class Analyzer(override val catalogManager: CatalogManager)
ExtractValue(child, fieldName, resolver)
}
case other =>
- throw new AnalysisException("need an array field but got " + other.catalogString)
+ throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other,
+ "array")
}
case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
u.child.dataType match {
@@ -3218,7 +3199,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ExtractValue(child, fieldName, resolver)
}
case other =>
- throw new AnalysisException("need a map field but got " + other.catalogString)
+ throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, "map")
}
}
validateNestedTupleFields(result)
@@ -3227,8 +3208,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}
private def fail(schema: StructType, maxOrdinal: Int): Unit = {
- throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" +
- ", but failed as the number of fields does not line up.")
+ throw QueryCompilationErrors.fieldNumberMismatchForDeserializerError(schema, maxOrdinal)
}
/**
@@ -3287,10 +3267,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case n: NewInstance if n.childrenResolved && !n.resolved =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
- "access to the scope that this class was defined in.\n" +
- "Try moving this class out of its parent class.")
+ throw QueryCompilationErrors.outerScopeFailureForNewInstanceError(n.cls.getName)
}
n.copy(outerPointer = Some(outer))
}
@@ -3306,11 +3283,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case l: LambdaVariable => "array element"
case e => e.sql
}
- throw new AnalysisException(s"Cannot up cast $fromStr from " +
- s"${from.dataType.catalogString} to ${to.catalogString}.\n" +
- "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
- "You can either add an explicit cast to the input data or choose a higher precision " +
- "type of the field in the target object")
+ throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath)
}
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
@@ -3321,8 +3294,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case u @ UpCast(child, _, _) if !child.resolved => u
case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
- throw new AnalysisException(
- s"UpCast only support DecimalType as AbstractDataType yet, but got: $target")
+ throw QueryCompilationErrors.unsupportedAbstractDataTypeForUpCastError(target)
case UpCast(child, target, walkedTypePath) if target == DecimalType
&& child.dataType.isInstanceOf[DecimalType] =>
@@ -3501,8 +3473,8 @@ class Analyzer(override val catalogManager: CatalogManager)
case Some(colName) =>
ColumnPosition.after(colName)
case None =>
- throw new AnalysisException("Couldn't find the reference column for " +
- s"$after at $parentName")
+ throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after,
+ parentName)
}
case other => other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 452ba80b23441..9998035d65c3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -98,7 +98,7 @@ trait CheckAnalysis extends PredicateHelper {
u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}")
case u: UnresolvedTable =>
- u.failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")
+ u.failAnalysis(s"Table not found for '${u.commandName}': ${u.multipartIdentifier.quoted}")
case u: UnresolvedTableOrView =>
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 508239077a70e..6fb9bed9625d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -391,6 +391,7 @@ object FunctionRegistry {
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
+ expression[CurrentTimeZone]("current_timezone"),
expression[DateDiff]("datediff"),
expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
index 5e19a32968992..531d40f431dee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, Alte
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec
/**
* Resolve [[UnresolvedPartitionSpec]] to [[ResolvedPartitionSpec]] in partition related commands.
@@ -33,32 +34,38 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case r @ AlterTableAddPartition(
ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _) =>
- r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema()))
+ r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema()))
case r @ AlterTableDropPartition(
ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _, _, _) =>
- r.copy(parts = resolvePartitionSpecs(partSpecs, table.partitionSchema()))
+ r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema()))
}
private def resolvePartitionSpecs(
- partSpecs: Seq[PartitionSpec], partSchema: StructType): Seq[ResolvedPartitionSpec] =
+ tableName: String,
+ partSpecs: Seq[PartitionSpec],
+ partSchema: StructType): Seq[ResolvedPartitionSpec] =
partSpecs.map {
case unresolvedPartSpec: UnresolvedPartitionSpec =>
ResolvedPartitionSpec(
- convertToPartIdent(unresolvedPartSpec.spec, partSchema), unresolvedPartSpec.location)
+ convertToPartIdent(tableName, unresolvedPartSpec.spec, partSchema),
+ unresolvedPartSpec.location)
case resolvedPartitionSpec: ResolvedPartitionSpec =>
resolvedPartitionSpec
}
private def convertToPartIdent(
- partSpec: TablePartitionSpec, partSchema: StructType): InternalRow = {
- val conflictKeys = partSpec.keys.toSeq.diff(partSchema.map(_.name))
- if (conflictKeys.nonEmpty) {
- throw new AnalysisException(s"Partition key ${conflictKeys.mkString(",")} not exists")
- }
+ tableName: String,
+ partitionSpec: TablePartitionSpec,
+ partSchema: StructType): InternalRow = {
+ val normalizedSpec = normalizePartitionSpec(
+ partitionSpec,
+ partSchema.map(_.name),
+ tableName,
+ conf.resolver)
val partValues = partSchema.map { part =>
- val partValue = partSpec.get(part.name).orNull
+ val partValue = normalizedSpec.get(part.name).orNull
if (partValue == null) {
null
} else {
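
`normalizePartitionSpec` (called above with the user spec, the declared partition column names, the table name, and the session resolver) maps user-supplied partition keys back to their declared spelling instead of rejecting them as unknown. An illustrative call with made-up values:

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec

// "ID" is normalized to the declared column name "id"; an unknown key such as
// "foo" would instead fail with an error mentioning testcat.ns.tbl.
val normalized = normalizePartitionSpec(
  Map("ID" -> "1"),            // user-supplied spec
  Seq("id", "dt"),             // declared partition column names
  "testcat.ns.tbl",            // table name used in the error message
  caseInsensitiveResolution)
// normalized == Map("id" -> "1")
```
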
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
index 98bd84fb94bd6..0e883a88f2691 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
@@ -37,7 +37,9 @@ case class UnresolvedNamespace(multipartIdentifier: Seq[String]) extends LeafNod
* Holds the name of a table that has yet to be looked up in a catalog. It will be resolved to
* [[ResolvedTable]] during analysis.
*/
-case class UnresolvedTable(multipartIdentifier: Seq[String]) extends LeafNode {
+case class UnresolvedTable(
+ multipartIdentifier: Seq[String],
+ commandName: String) extends LeafNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
index 3189d81289903..ff9c4cf3147d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
@@ -18,8 +18,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.IdentityHashMap
-import scala.collection.JavaConverters._
-
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
@@ -98,7 +96,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
proxyExpressionCurrentId += 1
- proxyMap.putAll(e.map(_ -> proxy).toMap.asJava)
+ // We leverage `IdentityHashMap` so we compare expression keys by reference here.
+ // For example, if there is one group of common exprs like Seq(common expr 1,
+ // common expr 2, ..., common expr n), we will insert into `proxyMap` the key/value
+ // pairs Map(common expr 1 -> proxy(common expr 1), ...,
+ // common expr n -> proxy(common expr 1)).
+ e.map(proxyMap.put(_, proxy))
}
// Only adding proxy if we find subexpressions.
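
The rewritten insertion above avoids building an intermediate Scala `Map`, which would deduplicate keys by `equals` and defeat the reference semantics of `IdentityHashMap`. A small self-contained illustration of the difference:

```scala
import java.util.IdentityHashMap
import scala.collection.JavaConverters._

case class K(v: Int)              // equal by value, distinct by reference
val k1 = K(1)
val k2 = K(1)

val viaScalaMap = new IdentityHashMap[K, String]()
viaScalaMap.putAll(Map(k1 -> "proxy", k2 -> "proxy").asJava)  // Map collapses equal keys
assert(viaScalaMap.size == 1)     // only one of the two references can be looked up

val direct = new IdentityHashMap[K, String]()
Seq(k1, k2).foreach(direct.put(_, "proxy"))                   // mirrors the patched code
assert(direct.size == 2)          // both references resolve to the shared value
```
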
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 97aacb3f7530c..9953b780ceace 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -73,6 +73,21 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {
}
}
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the current session local timezone.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_();
+ Asia/Shanghai
+ """,
+ group = "datetime_funcs",
+ since = "3.1.0")
+case class CurrentTimeZone() extends LeafExpression with Unevaluable {
+ override def nullable: Boolean = false
+ override def dataType: DataType = StringType
+ override def prettyName: String = "current_timezone"
+}
+
/**
* Returns the current date at the start of query evaluation.
* There is no code generation since this expression should get constant folded by the optimizer.
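
A usage sketch for the new function, assuming a SparkSession `spark`; the result depends on the session time zone:

```scala
spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
spark.sql("SELECT current_timezone()").show()
// Returns a single row containing Asia/Shanghai. The expression itself is Unevaluable;
// it is constant-folded to the session time zone by ComputeCurrentTime (finishAnalysis.scala).
```
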
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 16e22940495f1..9f92181b34df1 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1357,8 +1357,9 @@ object ParseUrl {
1
""",
since = "2.0.0")
-case class ParseUrl(children: Seq[Expression])
+case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled)
extends Expression with ExpectsInputTypes with CodegenFallback {
+ def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
@@ -1404,7 +1405,9 @@ case class ParseUrl(children: Seq[Expression])
try {
new URI(url.toString)
} catch {
- case e: URISyntaxException => null
+ case e: URISyntaxException if failOnError =>
+ throw new IllegalArgumentException(s"Found an invalid url string ${url.toString}", e)
+ case _: URISyntaxException => null
}
}
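
A behavior sketch for the new `failOnError` flag, which defaults to `spark.sql.ansi.enabled`:

```scala
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT parse_url('not a url', 'HOST')").show()   // NULL, the previous behavior

spark.conf.set("spark.sql.ansi.enabled", "true")
// The same query is now expected to fail at execution with an IllegalArgumentException
// wrapping the URISyntaxException, per the change above.
```
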
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c4b9936fa4c4f..9eee7c2b914a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -811,9 +811,12 @@ object CollapseRepartition extends Rule[LogicalPlan] {
*/
object OptimizeWindowFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
- case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), spec)
- if spec.orderSpec.nonEmpty &&
- spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame].frameType == RowFrame =>
+ case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _),
+ WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame))
+ if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame &&
+ frameSpecification.lower == UnboundedPreceding &&
+ (frameSpecification.upper == UnboundedFollowing ||
+ frameSpecification.upper == CurrentRow) =>
we.copy(windowFunction = NthValue(first.child, Literal(1), first.ignoreNulls))
}
}
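For reference, a frame shape that now qualifies for the rewrite, sketched with the DataFrame API (the DataFrame `df` and its columns `p`, `ts`, `x` are assumptions for illustration):

{{{
  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.first

  // Rewritten to nth_value(x, 1): ROWS frame from UNBOUNDED PRECEDING to CURRENT ROW.
  val runningFrame = Window.partitionBy("p").orderBy("ts")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
  df.select(first(df("x")).over(runningFrame))

  // Not rewritten: the frame does not start at UNBOUNDED PRECEDING.
  val slidingFrame = Window.partitionBy("p").orderBy("ts")
    .rowsBetween(-1, Window.currentRow)
  df.select(first(df("x")).over(slidingFrame))
}}}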
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 9aa7e3201ab1b..1f2389176d1e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -75,6 +76,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
val timeExpr = CurrentTimestamp()
val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
val currentTime = Literal.create(timestamp, timeExpr.dataType)
+ val timezone = Literal.create(SQLConf.get.sessionLocalTimeZone, StringType)
plan transformAllExpressions {
case currentDate @ CurrentDate(Some(timeZoneId)) =>
@@ -84,6 +86,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
DateType)
})
case CurrentTimestamp() | Now() => currentTime
+ case CurrentTimeZone() => timezone
}
}
}
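A minimal sketch of what ComputeCurrentTime now does with the new expression (it mirrors the suite further below; the time zone value shown is an assumption):

{{{
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentTimeZone}
  import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
  import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

  // CurrentTimeZone() is folded into a string literal holding SQLConf.get.sessionLocalTimeZone,
  // so every reference in a query sees the same value and nothing is evaluated at runtime.
  val before = Project(Seq(Alias(CurrentTimeZone(), "tz")()), LocalRelation())
  val after = ComputeCurrentTime(before.analyze)
  // after: Project [Asia/Shanghai AS tz], assuming that is the session time zone
}}}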
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 79857a63a69b5..ea4baafbacede 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1414,8 +1414,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
// So we use LikeAll or NotLikeAll instead.
val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
ctx.NOT match {
- case null => LikeAll(e, patterns)
- case _ => NotLikeAll(e, patterns)
+ case null => LikeAll(e, patterns.toSeq)
+ case _ => NotLikeAll(e, patterns.toSeq)
}
} else {
getLikeQuantifierExprs(ctx.expression).reduceLeft(And)
@@ -3303,7 +3303,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
*/
override def visitLoadData(ctx: LoadDataContext): LogicalPlan = withOrigin(ctx) {
LoadData(
- child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)),
+ child = UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier), "LOAD DATA"),
path = string(ctx.path),
isLocal = ctx.LOCAL != null,
isOverwrite = ctx.OVERWRITE != null,
@@ -3449,7 +3449,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
UnresolvedPartitionSpec(spec, location)
}
AlterTableAddPartition(
- UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)),
+ UnresolvedTable(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ "ALTER TABLE ... ADD PARTITION ..."),
specsAndLocs.toSeq,
ctx.EXISTS != null)
}
@@ -3491,7 +3493,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
val partSpecs = ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec)
.map(spec => UnresolvedPartitionSpec(spec))
AlterTableDropPartition(
- UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier)),
+ UnresolvedTable(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ "ALTER TABLE ... DROP PARTITION ..."),
partSpecs.toSeq,
ifExists = ctx.EXISTS != null,
purge = ctx.PURGE != null,
@@ -3720,6 +3724,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
case _ => string(ctx.STRING)
}
val nameParts = visitMultipartIdentifier(ctx.multipartIdentifier)
- CommentOnTable(UnresolvedTable(nameParts), comment)
+ CommentOnTable(UnresolvedTable(nameParts, "COMMENT ON TABLE"), comment)
}
}
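The string passed to UnresolvedTable throughout this file only feeds error reporting; a hedged sketch of the intended effect (temp view name and session are illustrative):

{{{
  // The command name travels with the unresolved relation, so an analysis failure such as
  // running the command against a temp view can name the offending statement.
  spark.range(1).createOrReplaceTempView("v")
  spark.sql("COMMENT ON TABLE v IS 'x'")   // analysis error is expected to mention COMMENT ON TABLE
}}}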
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fcf222c8fdab0..ef974dc176e51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2913,18 +2913,6 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
- val TRUNCATE_TRASH_ENABLED =
- buildConf("spark.sql.truncate.trash.enabled")
- .doc("This configuration decides when truncating table, whether data files will be moved " +
- "to trash directory or deleted permanently. The trash retention time is controlled by " +
- "'fs.trash.interval', and in default, the server side configuration value takes " +
- "precedence over the client-side one. Note that if 'fs.trash.interval' is non-positive, " +
- "this will be a no-op and log a warning message. If the data fails to be moved to " +
- "trash, Spark will turn to delete it permanently.")
- .version("3.1.0")
- .booleanConf
- .createWithDefault(false)
-
val DISABLED_JDBC_CONN_PROVIDER_LIST =
buildConf("spark.sql.sources.disabledJdbcConnProviderList")
.internal()
@@ -3577,8 +3565,6 @@ class SQLConf extends Serializable with Logging {
def legacyPathOptionBehavior: Boolean = getConf(SQLConf.LEGACY_PATH_OPTION_BEHAVIOR)
- def truncateTrashEnabled: Boolean = getConf(SQLConf.TRUNCATE_TRASH_ENABLED)
-
def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST)
/** ********************** SQLConf functionality methods ************ */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala
new file mode 100644
index 0000000000000..586aa6c59164f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.Resolver
+
+object PartitioningUtils {
+ /**
+ * Normalize the column names in partition specification, w.r.t. the real partition column names
+ * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
+ * partition column named `month`, and it's case insensitive, we will normalize `monTh` to
+ * `month`.
+ */
+ def normalizePartitionSpec[T](
+ partitionSpec: Map[String, T],
+ partColNames: Seq[String],
+ tblName: String,
+ resolver: Resolver): Map[String, T] = {
+ val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) =>
+ val normalizedKey = partColNames.find(resolver(_, key)).getOrElse {
+ throw new AnalysisException(s"$key is not a valid partition column in table $tblName.")
+ }
+ normalizedKey -> value
+ }
+
+ SchemaUtils.checkColumnNameDuplication(
+ normalizedPartSpec.map(_._1), "in the partition schema", resolver)
+
+ normalizedPartSpec.toMap
+ }
+}
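A small sketch of the relocated helper, which now lives in catalyst so v1 and v2 commands can share it (the table and column names are illustrative):

{{{
  import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
  import org.apache.spark.sql.util.PartitioningUtils

  // `monTh` is normalized to the real column name `month`; an unknown key raises AnalysisException.
  PartitioningUtils.normalizePartitionSpec(
    partitionSpec = Map("monTh" -> "5"),
    partColNames = Seq("year", "month"),
    tblName = "tbl",
    resolver = caseInsensitiveResolution)
  // Map("month" -> "5")
}}}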
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index a1b6cec24f23f..730574a4b9846 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -943,6 +943,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil)
}
+ test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ val msg = intercept[IllegalArgumentException] {
+ evaluateWithoutCodegen(
+ ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST")))
+ }.getMessage
+ assert(msg.contains("Find an invaild url string"))
+ }
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ checkEvaluation(
+ ParseUrl(Seq("https://a.b.c/index.php?params1=a|b¶ms2=x", "HOST")), null)
+ }
+ }
+
test("Sentences") {
val nullString = Literal.create(null, StringType)
checkEvaluation(Sentences(nullString, nullString, nullString), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
index 64b619ca7766b..f8dca266a62d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
@@ -95,4 +95,26 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
})
assert(proxys.isEmpty)
}
+
+ test("SubExprEvaluationRuntime should wrap semantically equal exprs") {
+ val runtime = new SubExprEvaluationRuntime(1)
+
+ val one = Literal(1)
+ val two = Literal(2)
+ def mul: (Literal, Literal) => Expression =
+ (left: Literal, right: Literal) => Multiply(left, right)
+
+ val mul2_1 = Multiply(mul(one, two), mul(one, two))
+ val mul2_2 = Multiply(mul(one, two), mul(one, two))
+
+ val sqrt = Sqrt(mul2_1)
+ val sum = Add(mul2_2, sqrt)
+ val proxyExpressions = runtime.proxyExpressions(Seq(sum))
+ val proxys = proxyExpressions.flatMap(_.collect {
+ case p: ExpressionProxy => p
+ })
+ // ( (one * two) * (one * two) )
+ assert(proxys.size == 2)
+ assert(proxys.forall(_.child.semanticEquals(mul2_1)))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index db0399d2a73ee..82d6757407b51 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.catalyst.optimizer
import java.time.ZoneId
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.unsafe.types.UTF8String
class ComputeCurrentTimeSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -67,4 +69,16 @@ class ComputeCurrentTimeSuite extends PlanTest {
assert(lits(1) >= min && lits(1) <= max)
assert(lits(0) == lits(1))
}
+
+ test("SPARK-33469: Add current_timezone function") {
+ val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
+ val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+ val lits = new scala.collection.mutable.ArrayBuffer[String]
+ plan.transformAllExpressions { case e: Literal =>
+ lits += e.value.asInstanceOf[UTF8String].toString
+ e
+ }
+ assert(lits.size == 1)
+ assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala
index 389aaeafe655f..cf850bbe21ce6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala
@@ -36,7 +36,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest {
val b = testRelation.output(1)
val c = testRelation.output(2)
- test("replace first(col) by nth_value(col, 1)") {
+ test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND CURRENT ROW") {
val inputPlan = testRelation.select(
WindowExpression(
First(a, false).toAggregateExpression(),
@@ -52,7 +52,34 @@ class OptimizeWindowFunctionsSuite extends PlanTest {
assert(optimized == correctAnswer)
}
- test("can't replace first(col) by nth_value(col, 1) if the window frame type is range") {
+ test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") {
+ val inputPlan = testRelation.select(
+ WindowExpression(
+ First(a, false).toAggregateExpression(),
+ WindowSpecDefinition(b :: Nil, c.asc :: Nil,
+ SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))))
+ val correctAnswer = testRelation.select(
+ WindowExpression(
+ NthValue(a, Literal(1), false),
+ WindowSpecDefinition(b :: Nil, c.asc :: Nil,
+ SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))))
+
+ val optimized = Optimize.execute(inputPlan)
+ assert(optimized == correctAnswer)
+ }
+
+ test("can't replace first by nth_value if frame is not suitable") {
+ val inputPlan = testRelation.select(
+ WindowExpression(
+ First(a, false).toAggregateExpression(),
+ WindowSpecDefinition(b :: Nil, c.asc :: Nil,
+ SpecifiedWindowFrame(RowFrame, Literal(1), CurrentRow))))
+
+ val optimized = Optimize.execute(inputPlan)
+ assert(optimized == inputPlan)
+ }
+
+ test("can't replace first by nth_value if the window frame type is range") {
val inputPlan = testRelation.select(
WindowExpression(
First(a, false).toAggregateExpression(),
@@ -63,7 +90,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest {
assert(optimized == inputPlan)
}
- test("can't replace first(col) by nth_value(col, 1) if the window frame isn't ordered") {
+ test("can't replace first by nth_value if the window frame isn't ordered") {
val inputPlan = testRelation.select(
WindowExpression(
First(a, false).toAggregateExpression(),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index f93c0dcf59f4c..bd28484b23f46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -1555,15 +1555,15 @@ class DDLParserSuite extends AnalysisTest {
test("LOAD DATA INTO table") {
comparePlans(
parsePlan("LOAD DATA INPATH 'filepath' INTO TABLE a.b.c"),
- LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", false, false, None))
+ LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", false, false, None))
comparePlans(
parsePlan("LOAD DATA LOCAL INPATH 'filepath' INTO TABLE a.b.c"),
- LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, false, None))
+ LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, false, None))
comparePlans(
parsePlan("LOAD DATA LOCAL INPATH 'filepath' OVERWRITE INTO TABLE a.b.c"),
- LoadData(UnresolvedTable(Seq("a", "b", "c")), "filepath", true, true, None))
+ LoadData(UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"), "filepath", true, true, None))
comparePlans(
parsePlan(
@@ -1572,7 +1572,7 @@ class DDLParserSuite extends AnalysisTest {
|PARTITION(ds='2017-06-10')
""".stripMargin),
LoadData(
- UnresolvedTable(Seq("a", "b", "c")),
+ UnresolvedTable(Seq("a", "b", "c"), "LOAD DATA"),
"filepath",
true,
true,
@@ -1674,13 +1674,13 @@ class DDLParserSuite extends AnalysisTest {
val parsed2 = parsePlan(sql2)
val expected1 = AlterTableAddPartition(
- UnresolvedTable(Seq("a", "b", "c")),
+ UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."),
Seq(
UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")),
UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"), None)),
ifNotExists = true)
val expected2 = AlterTableAddPartition(
- UnresolvedTable(Seq("a", "b", "c")),
+ UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... ADD PARTITION ..."),
Seq(UnresolvedPartitionSpec(Map("dt" -> "2008-08-08"), Some("loc"))),
ifNotExists = false)
@@ -1747,7 +1747,7 @@ class DDLParserSuite extends AnalysisTest {
assertUnsupported(sql2_view)
val expected1_table = AlterTableDropPartition(
- UnresolvedTable(Seq("table_name")),
+ UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP PARTITION ..."),
Seq(
UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us")),
UnresolvedPartitionSpec(Map("dt" -> "2009-09-09", "country" -> "uk"))),
@@ -1763,7 +1763,7 @@ class DDLParserSuite extends AnalysisTest {
val sql3_table = "ALTER TABLE a.b.c DROP IF EXISTS PARTITION (ds='2017-06-10')"
val expected3_table = AlterTableDropPartition(
- UnresolvedTable(Seq("a", "b", "c")),
+ UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... DROP PARTITION ..."),
Seq(UnresolvedPartitionSpec(Map("ds" -> "2017-06-10"))),
ifExists = true,
purge = false,
@@ -2174,7 +2174,7 @@ class DDLParserSuite extends AnalysisTest {
comparePlans(
parsePlan("COMMENT ON TABLE a.b.c IS 'xYz'"),
- CommentOnTable(UnresolvedTable(Seq("a", "b", "c")), "xYz"))
+ CommentOnTable(UnresolvedTable(Seq("a", "b", "c"), "COMMENT ON TABLE"), "xYz"))
}
// TODO: ignored by SPARK-31707, restore the test after create table syntax unification
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
index 1c96bdf3afa20..23987e909aa70 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
@@ -92,4 +92,8 @@ class InMemoryPartitionTable(
override def partitionExists(ident: InternalRow): Boolean =
memoryTablePartitions.containsKey(ident)
+
+ override protected def addPartitionKey(key: Seq[Any]): Unit = {
+ memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 3b47271a114e2..c93053abc550a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -160,12 +160,15 @@ class InMemoryTable(
}
}
+ protected def addPartitionKey(key: Seq[Any]): Unit = {}
+
def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
data.foreach(_.rows.foreach { row =>
val key = getKey(row)
dataMap += dataMap.get(key)
.map(key -> _.withRow(row))
.getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row))
+ addPartitionKey(key)
})
this
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 276d5d29bfa2c..b26bc6441b6cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -493,6 +493,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `pathGlobFilter`: an optional glob pattern to only include files with paths matching
* the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
* It does not change the behavior of partition discovery.
+ * `modifiedBefore` (batch only): an optional timestamp to only include files with
+ * modification times occurring before the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * `modifiedAfter` (batch only): an optional timestamp to only include files with
+ * modification times occurring after the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
* `recursiveFileLookup`: recursively scan a directory for files. Using this option
* disables partition discovery
* `allowNonNumericNumbers` (default `true`): allows JSON parser to recognize set of
@@ -750,6 +756,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `pathGlobFilter`: an optional glob pattern to only include files with paths matching
* the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
* It does not change the behavior of partition discovery.
+ * `modifiedBefore` (batch only): an optional timestamp to only include files with
+ * modification times occurring before the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * `modifiedAfter` (batch only): an optional timestamp to only include files with
+ * modification times occurring after the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
* `recursiveFileLookup`: recursively scan a directory for files. Using this option
* disables partition discovery
*
@@ -781,6 +793,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `pathGlobFilter`: an optional glob pattern to only include files with paths matching
* the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
* It does not change the behavior of partition discovery.
+ * `modifiedBefore` (batch only): an optional timestamp to only include files with
+ * modification times occurring before the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * `modifiedAfter` (batch only): an optional timestamp to only include files with
+ * modification times occurring after the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
* `recursiveFileLookup`: recursively scan a directory for files. Using this option
* disables partition discovery
*
@@ -814,6 +832,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `pathGlobFilter`: an optional glob pattern to only include files with paths matching
* the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
* It does not change the behavior of partition discovery.
+ * `modifiedBefore` (batch only): an optional timestamp to only include files with
+ * modification times occurring before the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * `modifiedAfter` (batch only): an optional timestamp to only include files with
+ * modification times occurring after the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
* `recursiveFileLookup`: recursively scan a directory for files. Using this option
* disables partition discovery
*
@@ -880,6 +904,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `pathGlobFilter`: an optional glob pattern to only include files with paths matching
* the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter.
* It does not change the behavior of partition discovery.
+ * `modifiedBefore` (batch only): an optional timestamp to only include files with
+ * modification times occurring before the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
+ * `modifiedAfter` (batch only): an optional timestamp to only include files with
+ * modification times occurring after the specified time. The provided timestamp
+ * must be in the following form: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
* `recursiveFileLookup`: recursively scan a directory for files. Using this option
* disables partition discovery
*
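A usage sketch for the two new read options (format and path are illustrative; `spark` is an assumed SparkSession):

{{{
  // Only pick up files whose modification time falls inside the window, interpreted in the
  // session time zone unless a `timeZone` option is supplied.
  spark.read
    .format("parquet")
    .option("modifiedAfter", "2020-06-01T00:00:00")
    .option("modifiedBefore", "2020-07-01T00:00:00")
    .load("/data/events")
}}}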
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
index fc62dce5002b1..0b265bfb63e3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, Unresol
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
-import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.util.PartitioningUtils
/**
* Analyzes a given set of partitions to generate per-partition statistics, which will be used in
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index d550fe270c753..27ad62026c9b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -39,11 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.PartitioningUtils
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
// Note: The definition of these commands are based on the ones described in
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 206f952fed0ca..bd238948aab02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier, CaseInsensitiveMap}
-import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -47,8 +47,8 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.PartitioningUtils
import org.apache.spark.sql.util.SchemaUtils
-import org.apache.spark.util.Utils
/**
* A command to create a table with the same definition of the given existing table.
@@ -490,7 +490,6 @@ case class TruncateTableCommand(
}
val hadoopConf = spark.sessionState.newHadoopConf()
val ignorePermissionAcl = SQLConf.get.truncateTableIgnorePermissionAcl
- val isTrashEnabled = SQLConf.get.truncateTrashEnabled
locations.foreach { location =>
if (location.isDefined) {
val path = new Path(location.get)
@@ -515,7 +514,7 @@ case class TruncateTableCommand(
}
}
- Utils.moveToTrashOrDelete(fs, path, isTrashEnabled, hadoopConf)
+ fs.delete(path, true)
// We should keep original permission/acl of the path.
// For owner/group, only super-user can set it, for example on HDFS. Because
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index fed9614347f6a..5b0d0606da093 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -57,13 +57,10 @@ abstract class PartitioningAwareFileIndex(
protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]]
private val caseInsensitiveMap = CaseInsensitiveMap(parameters)
+ private val pathFilters = PathFilterFactory.create(caseInsensitiveMap)
- protected lazy val pathGlobFilter: Option[GlobFilter] =
- caseInsensitiveMap.get("pathGlobFilter").map(new GlobFilter(_))
-
- protected def matchGlobPattern(file: FileStatus): Boolean = {
- pathGlobFilter.forall(_.accept(file.getPath))
- }
+ protected def matchPathPattern(file: FileStatus): Boolean =
+ pathFilters.forall(_.accept(file))
protected lazy val recursiveFileLookup: Boolean = {
caseInsensitiveMap.getOrElse("recursiveFileLookup", "false").toBoolean
@@ -86,7 +83,7 @@ abstract class PartitioningAwareFileIndex(
val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match {
case Some(existingDir) =>
// Directory has children files in it, return them
- existingDir.filter(f => matchGlobPattern(f) && isNonEmptyFile(f))
+ existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f))
case None =>
// Directory does not exist, or has no children files
@@ -135,7 +132,7 @@ abstract class PartitioningAwareFileIndex(
} else {
leafFiles.values.toSeq
}
- files.filter(matchGlobPattern)
+ files.filter(matchPathPattern)
}
protected def inferPartitioning(): PartitionSpec = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 796c23c7337d8..ea437d200eaab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
@@ -357,30 +357,6 @@ object PartitioningUtils {
getPathFragment(spec, StructType.fromAttributes(partitionColumns))
}
- /**
- * Normalize the column names in partition specification, w.r.t. the real partition column names
- * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
- * partition column named `month`, and it's case insensitive, we will normalize `monTh` to
- * `month`.
- */
- def normalizePartitionSpec[T](
- partitionSpec: Map[String, T],
- partColNames: Seq[String],
- tblName: String,
- resolver: Resolver): Map[String, T] = {
- val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) =>
- val normalizedKey = partColNames.find(resolver(_, key)).getOrElse {
- throw new AnalysisException(s"$key is not a valid partition column in table $tblName.")
- }
- normalizedKey -> value
- }
-
- SchemaUtils.checkColumnNameDuplication(
- normalizedPartSpec.map(_._1), "in the partition schema", resolver)
-
- normalizedPartSpec.toMap
- }
-
/**
* Resolves possible type conflicts between partitions by up-casting "lower" types using
* [[findWiderTypeForPartitionColumn]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala
new file mode 100644
index 0000000000000..c8f23988f93c6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.{Locale, TimeZone}
+
+import org.apache.hadoop.fs.{FileStatus, GlobFilter}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.unsafe.types.UTF8String
+
+trait PathFilterStrategy extends Serializable {
+ def accept(fileStatus: FileStatus): Boolean
+}
+
+trait StrategyBuilder {
+ def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy]
+}
+
+class PathGlobFilter(filePatten: String) extends PathFilterStrategy {
+
+ private val globFilter = new GlobFilter(filePatten)
+
+ override def accept(fileStatus: FileStatus): Boolean =
+ globFilter.accept(fileStatus.getPath)
+}
+
+object PathGlobFilter extends StrategyBuilder {
+ val PARAM_NAME = "pathglobfilter"
+
+ override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = {
+ parameters.get(PARAM_NAME).map(new PathGlobFilter(_))
+ }
+}
+
+/**
+ * Provides the modifiedAfter and modifiedBefore options used when
+ * filtering files in a batch-based file data source.
+ *
+ * Example Usages
+ * Load all CSV files modified after date:
+ * {{{
+ * spark.read.format("csv").option("modifiedAfter","2020-06-15T05:00:00").load()
+ * }}}
+ *
+ * Load all CSV files modified before date:
+ * {{{
+ * spark.read.format("csv").option("modifiedBefore","2020-06-15T05:00:00").load()
+ * }}}
+ *
+ * Load all CSV files modified between two dates:
+ * {{{
+ * spark.read.format("csv").option("modifiedAfter","2019-01-15T05:00:00")
+ * .option("modifiedBefore","2020-06-15T05:00:00").load()
+ * }}}
+ */
+abstract class ModifiedDateFilter extends PathFilterStrategy {
+
+ def timeZoneId: String
+
+ protected def localTime(micros: Long): Long =
+ DateTimeUtils.fromUTCTime(micros, timeZoneId)
+}
+
+object ModifiedDateFilter {
+
+ def getTimeZoneId(options: CaseInsensitiveMap[String]): String = {
+ options.getOrElse(
+ DateTimeUtils.TIMEZONE_OPTION.toLowerCase(Locale.ROOT),
+ SQLConf.get.sessionLocalTimeZone)
+ }
+
+ def toThreshold(timeString: String, timeZoneId: String, strategy: String): Long = {
+ val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId)
+ val ts = UTF8String.fromString(timeString)
+ DateTimeUtils.stringToTimestamp(ts, timeZone.toZoneId).getOrElse {
+ throw new AnalysisException(
+ s"The timestamp provided for the '$strategy' option is invalid. The expected format " +
+ s"is 'YYYY-MM-DDTHH:mm:ss', but the provided timestamp: $timeString")
+ }
+ }
+}
+
+/**
+ * Filter used to determine whether file was modified before the provided timestamp.
+ */
+class ModifiedBeforeFilter(thresholdTime: Long, val timeZoneId: String)
+ extends ModifiedDateFilter {
+
+ override def accept(fileStatus: FileStatus): Boolean =
+ // We standardize on microseconds wherever possible
+ // getModificationTime returns in milliseconds
+ thresholdTime - localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) > 0
+}
+
+object ModifiedBeforeFilter extends StrategyBuilder {
+ import ModifiedDateFilter._
+
+ val PARAM_NAME = "modifiedbefore"
+
+ override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = {
+ parameters.get(PARAM_NAME).map { value =>
+ val timeZoneId = getTimeZoneId(parameters)
+ val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME)
+ new ModifiedBeforeFilter(thresholdTime, timeZoneId)
+ }
+ }
+}
+
+/**
+ * Filter used to determine whether file was modified after the provided timestamp.
+ */
+class ModifiedAfterFilter(thresholdTime: Long, val timeZoneId: String)
+ extends ModifiedDateFilter {
+
+ override def accept(fileStatus: FileStatus): Boolean =
+ // getModificationTime returns in milliseconds
+ // We standardize on microseconds wherever possible
+ localTime(DateTimeUtils.millisToMicros(fileStatus.getModificationTime)) - thresholdTime > 0
+}
+
+object ModifiedAfterFilter extends StrategyBuilder {
+ import ModifiedDateFilter._
+
+ val PARAM_NAME = "modifiedafter"
+
+ override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = {
+ parameters.get(PARAM_NAME).map { value =>
+ val timeZoneId = getTimeZoneId(parameters)
+ val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME)
+ new ModifiedAfterFilter(thresholdTime, timeZoneId)
+ }
+ }
+}
+
+object PathFilterFactory {
+
+ private val strategies =
+ Seq(PathGlobFilter, ModifiedBeforeFilter, ModifiedAfterFilter)
+
+ def create(parameters: CaseInsensitiveMap[String]): Seq[PathFilterStrategy] = {
+ strategies.flatMap { _.create(parameters) }
+ }
+}
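A sketch of how the strategies compose (option values are illustrative):

{{{
  import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

  // One strategy per recognized option; PartitioningAwareFileIndex keeps a file only if
  // every returned filter accepts it (see matchPathPattern above).
  val filters = PathFilterFactory.create(CaseInsensitiveMap(Map(
    "pathGlobFilter" -> "*.parquet",
    "modifiedAfter" -> "2020-06-15T05:00:00")))
  // filters contains a PathGlobFilter and a ModifiedAfterFilter
}}}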
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 3a2a642b870f8..9e65b0ce13693 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.{AtomicType, StructType}
+import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec
import org.apache.spark.sql.util.SchemaUtils
/**
@@ -386,7 +387,7 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] {
partColNames: Seq[String],
catalogTable: Option[CatalogTable]): InsertIntoStatement = {
- val normalizedPartSpec = PartitioningUtils.normalizePartitionSpec(
+ val normalizedPartSpec = normalizePartitionSpec(
insert.partitionSpec, partColNames, tblName, conf.resolver)
val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 21abfc2816ee4..e5c29312b80e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -147,6 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
catalog match {
case staging: StagingTableCatalog =>
AtomicReplaceTableAsSelectExec(
+ session,
staging,
ident,
parts,
@@ -157,6 +158,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
orCreate = orCreate) :: Nil
case _ =>
ReplaceTableAsSelectExec(
+ session,
catalog,
ident,
parts,
@@ -170,9 +172,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
- AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil
+ AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil
case v2 =>
- AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil
+ AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil
}
case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
@@ -184,14 +186,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}.toArray
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
- OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil
+ OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil
case v2 =>
- OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil
+ OverwriteByExpressionExec(session, v2, r, filters,
+ writeOptions.asOptions, planLater(query)) :: Nil
}
case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
OverwritePartitionsDynamicExec(
- r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil
+ session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil
case DeleteFromTable(relation, condition) =>
relation match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index 560da39314b36..af7721588edeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -37,10 +37,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class AppendDataExecV1(
table: SupportsWrite,
writeOptions: CaseInsensitiveStringMap,
- plan: LogicalPlan) extends V1FallbackWriters {
+ plan: LogicalPlan,
+ v2Relation: DataSourceV2Relation) extends V1FallbackWriters {
override protected def run(): Seq[InternalRow] = {
- writeWithV1(newWriteBuilder().buildForV1Write())
+ writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation))
}
}
@@ -59,7 +60,8 @@ case class OverwriteByExpressionExecV1(
table: SupportsWrite,
deleteWhere: Array[Filter],
writeOptions: CaseInsensitiveStringMap,
- plan: LogicalPlan) extends V1FallbackWriters {
+ plan: LogicalPlan,
+ v2Relation: DataSourceV2Relation) extends V1FallbackWriters {
private def isTruncate(filters: Array[Filter]): Boolean = {
filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
@@ -68,10 +70,10 @@ case class OverwriteByExpressionExecV1(
override protected def run(): Seq[InternalRow] = {
newWriteBuilder() match {
case builder: SupportsTruncate if isTruncate(deleteWhere) =>
- writeWithV1(builder.truncate().asV1Builder.buildForV1Write())
+ writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation))
case builder: SupportsOverwrite =>
- writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write())
+ writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation))
case _ =>
throw new SparkException(s"Table does not support overwrite by expression: $table")
@@ -112,9 +114,14 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
trait SupportsV1Write extends SparkPlan {
def plan: LogicalPlan
- protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = {
+ protected def writeWithV1(
+ relation: InsertableRelation,
+ v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = {
+ val session = sqlContext.sparkSession
// The `plan` is already optimized, we should not analyze and optimize it again.
- relation.insert(AlreadyOptimized.dataFrame(sqlContext.sparkSession, plan), overwrite = false)
+ relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false)
+ v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r))
+
Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 1421a9315c3a8..1648134d0a1b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -127,6 +128,7 @@ case class AtomicCreateTableAsSelectExec(
* ReplaceTableAsSelectStagingExec.
*/
case class ReplaceTableAsSelectExec(
+ session: SparkSession,
catalog: TableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
@@ -146,6 +148,8 @@ case class ReplaceTableAsSelectExec(
// 2. Writing to the new table fails,
// 3. The table returned by catalog.createTable doesn't support writing.
if (catalog.tableExists(ident)) {
+ val table = catalog.loadTable(ident)
+ uncacheTable(session, catalog, table, ident)
catalog.dropTable(ident)
} else if (!orCreate) {
throw new CannotReplaceMissingTableException(ident)
@@ -169,6 +173,7 @@ case class ReplaceTableAsSelectExec(
* is left untouched.
*/
case class AtomicReplaceTableAsSelectExec(
+ session: SparkSession,
catalog: StagingTableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
@@ -180,6 +185,10 @@ case class AtomicReplaceTableAsSelectExec(
override protected def run(): Seq[InternalRow] = {
val schema = query.schema.asNullable
+ if (catalog.tableExists(ident)) {
+ val table = catalog.loadTable(ident)
+ uncacheTable(session, catalog, table, ident)
+ }
val staged = if (orCreate) {
catalog.stageCreateOrReplace(
ident, schema, partitioning.toArray, properties.asJava)
@@ -204,12 +213,16 @@ case class AtomicReplaceTableAsSelectExec(
* Rows in the output data set are appended.
*/
case class AppendDataExec(
+ session: SparkSession,
table: SupportsWrite,
+ relation: DataSourceV2Relation,
writeOptions: CaseInsensitiveStringMap,
query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
override protected def run(): Seq[InternalRow] = {
- writeWithV2(newWriteBuilder().buildForBatch())
+ val writtenRows = writeWithV2(newWriteBuilder().buildForBatch())
+ session.sharedState.cacheManager.recacheByPlan(session, relation)
+ writtenRows
}
}
@@ -224,7 +237,9 @@ case class AppendDataExec(
* AlwaysTrue to delete all rows.
*/
case class OverwriteByExpressionExec(
+ session: SparkSession,
table: SupportsWrite,
+ relation: DataSourceV2Relation,
deleteWhere: Array[Filter],
writeOptions: CaseInsensitiveStringMap,
query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
@@ -234,7 +249,7 @@ case class OverwriteByExpressionExec(
}
override protected def run(): Seq[InternalRow] = {
- newWriteBuilder() match {
+ val writtenRows = newWriteBuilder() match {
case builder: SupportsTruncate if isTruncate(deleteWhere) =>
writeWithV2(builder.truncate().buildForBatch())
@@ -244,9 +259,12 @@ case class OverwriteByExpressionExec(
case _ =>
throw new SparkException(s"Table does not support overwrite by expression: $table")
}
+ session.sharedState.cacheManager.recacheByPlan(session, relation)
+ writtenRows
}
}
+
/**
* Physical plan node for dynamic partition overwrite into a v2 table.
*
@@ -257,18 +275,22 @@ case class OverwriteByExpressionExec(
* are not modified.
*/
case class OverwritePartitionsDynamicExec(
+ session: SparkSession,
table: SupportsWrite,
+ relation: DataSourceV2Relation,
writeOptions: CaseInsensitiveStringMap,
query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
override protected def run(): Seq[InternalRow] = {
- newWriteBuilder() match {
+ val writtenRows = newWriteBuilder() match {
case builder: SupportsDynamicOverwrite =>
writeWithV2(builder.overwriteDynamicPartitions().buildForBatch())
case _ =>
throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
}
+ session.sharedState.cacheManager.recacheByPlan(session, relation)
+ writtenRows
}
}
@@ -370,6 +392,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
Nil
}
+
+ protected def uncacheTable(
+ session: SparkSession,
+ catalog: TableCatalog,
+ table: Table,
+ ident: Identifier): Unit = {
+ val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+ session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true)
+ }
}
object DataWritingSparkTask extends Logging {
@@ -484,3 +515,4 @@ private[v2] case class DataWritingSparkTaskResult(
* Sink progress information collected after commit.
*/
private[sql] case class StreamWriterCommitProgress(numOutputRows: Long)
+
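The `session` and v2 relation threaded through these plans exist solely so the table cache can be refreshed after a successful write; an illustrative end-to-end effect (catalog and table names are hypothetical):

{{{
  // Assumes a v2 catalog registered as `testcat` and a cached table ns.t.
  spark.table("testcat.ns.t").cache()
  spark.table("testcat.ns.t").count()                    // materializes the cache
  spark.sql("INSERT INTO testcat.ns.t VALUES (1, 'a')")  // AppendDataExec recaches the relation
  spark.table("testcat.ns.t").count()                    // reflects the new row without a manual refresh
}}}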
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
index 712ed1585bc8a..6f43542fd6595 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
@@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, ModifiedBeforeFilter}
import org.apache.spark.util.Utils
/**
@@ -32,6 +33,16 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
+ checkDisallowedOptions(parameters)
+
+ private def checkDisallowedOptions(options: Map[String, String]): Unit = {
+ Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param =>
+ if (parameters.contains(param)) {
+ throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources")
+ }
+ }
+ }
+
val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
Try(str.toInt).toOption.filter(_ > 0).getOrElse {
throw new IllegalArgumentException(
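The modified-date options are batch-only, so the new check rejects them for streaming sources; a hedged sketch (schema, format and path are placeholders):

{{{
  import org.apache.spark.sql.types.StructType

  val schema = new StructType().add("value", "string")
  // Expected to fail with IllegalArgumentException("option 'modifiedafter' is not allowed
  // in file stream sources") once the file stream source parses its options.
  spark.readStream
    .format("text")
    .schema(schema)
    .option("modifiedAfter", "2020-06-15T05:00:00")
    .load("/tmp/stream-input")
}}}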
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala
index 77078046dda7c..f48672afb41f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala
@@ -189,8 +189,8 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
val graphUIDataForNumRowsDroppedByWatermark =
new GraphUIData(
- "aggregated-num-state-rows-dropped-by-watermark-timeline",
- "aggregated-num-state-rows-dropped-by-watermark-histogram",
+ "aggregated-num-rows-dropped-by-watermark-timeline",
+ "aggregated-num-rows-dropped-by-watermark-histogram",
numRowsDroppedByWatermarkData,
minBatchTime,
maxBatchTime,
@@ -209,33 +209,33 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
{graphUIDataForNumberTotalRows.generateTimelineHtml(jsCollector)} |
{graphUIDataForNumberTotalRows.generateHistogramHtml(jsCollector)} |
-
-
-
- Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
-
- |
- {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} |
- {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} |
-
-
-
-
- Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
-
- |
- {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} |
- {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} |
-
-
-
-
- Aggregated Number Of State Rows Dropped By Watermark {SparkUIUtils.tooltip("Aggregated number of state rows dropped by watermark.", "right")}
-
- |
- {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} |
- {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} |
-
+
+
+
+ Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
+
+ |
+ {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} |
+ {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} |
+
+
+
+
+ Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
+
+ |
+ {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} |
+ {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} |
+
+
+
+
+ Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
+
+ |
+ {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} |
+ {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} |
+
// scalastyle:on
} else {
new NodeBuffer()
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index da83df4994d8d..0a54dff3a1cea 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
## Summary
- - Number of queries: 341
+ - Number of queries: 342
- Number of expressions that missing example: 13
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window
## Schema of Built-in Functions
@@ -86,6 +86,7 @@
| org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct<current_catalog():string> |
| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct<current_database():string> |
| org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct<current_date():date> |
+| org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct<current_timezone():string> |
| org.apache.spark.sql.catalyst.expressions.CurrentTimestamp | current_timestamp | SELECT current_timestamp() | struct<current_timestamp():timestamp> |
| org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct |
| org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6a1378837ea9b..953a58760cd5c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1947,6 +1947,14 @@ class DatasetSuite extends QueryTest
df.where($"zoo".contains(Array('a', 'b'))),
Seq(Row("abc")))
}
+
+ test("SPARK-33469: Add current_timezone function") {
+ val df = Seq(1).toDF("c")
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Shanghai") {
+ val timezone = df.selectExpr("current_timezone()").collect().head.getString(0)
+ assert(timezone == "Asia/Shanghai")
+ }
+ }
}
object AssertExecutionId {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index b27c1145181bd..876f62803dc7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -577,38 +577,6 @@ class FileBasedDataSourceSuite extends QueryTest
}
}
- test("Option pathGlobFilter: filter files correctly") {
- withTempPath { path =>
- val dataDir = path.getCanonicalPath
- Seq("foo").toDS().write.text(dataDir)
- Seq("bar").toDS().write.mode("append").orc(dataDir)
- val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir)
- checkAnswer(df, Row("foo"))
-
- // Both glob pattern in option and path should be effective to filter files.
- val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc")
- checkAnswer(df2, Seq.empty)
-
- val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt")
- checkAnswer(df3, Row("foo"))
- }
- }
-
- test("Option pathGlobFilter: simple extension filtering should contains partition info") {
- withTempPath { path =>
- val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b")
- input.write.partitionBy("b").text(path.getCanonicalPath)
- Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1")
-
- // If we use glob pattern in the path, the partition column won't be shown in the result.
- val df = spark.read.text(path.getCanonicalPath + "/*/*.txt")
- checkAnswer(df, input.select("a"))
-
- val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath)
- checkAnswer(df2, input)
- }
- }
-
test("Option recursiveFileLookup: recursive loading correctly") {
val expectedFileList = mutable.ListBuffer[String]()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
index 107d0ea47249d..e05c2c09ace2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionsException, PartitionsAlreadyExistException}
import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits
+import org.apache.spark.sql.internal.SQLConf
class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase {
@@ -159,4 +160,29 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase {
assert(partTable.asPartitionable.listPartitionIdentifiers(InternalRow.empty).isEmpty)
}
}
+
+ test("case sensitivity in resolving partition specs") {
+ val t = "testpart.ns1.ns2.tbl"
+ withTable(t) {
+ spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)")
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val errMsg = intercept[AnalysisException] {
+ spark.sql(s"ALTER TABLE $t ADD PARTITION (ID=1) LOCATION 'loc1'")
+ }.getMessage
+ assert(errMsg.contains(s"ID is not a valid partition column in table $t"))
+ }
+
+ val partTable = catalog("testpart").asTableCatalog
+ .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl"))
+ .asPartitionable
+ assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1))))
+
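+ // With case-insensitive resolution, the differently-cased specs ID/Id should resolve to the partition column id.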
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ spark.sql(s"ALTER TABLE $t ADD PARTITION (ID=1) LOCATION 'loc1'")
+ assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1))))
+ spark.sql(s"ALTER TABLE $t DROP PARTITION (Id=1)")
+ assert(!partTable.partitionExists(InternalRow.fromSeq(Seq(1))))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index ddafa1bb5070a..da53936239de8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
class DataSourceV2SQLSuite
@@ -43,7 +44,6 @@ class DataSourceV2SQLSuite
with AlterTableTests with DatasourceV2SQLBase {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
- import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
private val v2Source = classOf[FakeV2Provider].getName
override protected val v2Format = v2Source
@@ -782,6 +782,84 @@ class DataSourceV2SQLSuite
}
}
+ test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") {
+ Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
+ val view = "view"
+ withTable(t) {
+ withTempView(view) {
+ sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source")
+ sql(s"CACHE TABLE $view AS SELECT id FROM $t")
+ checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source"))
+ checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id"))
+
+ sql(s"REPLACE TABLE $t USING foo AS SELECT id FROM source")
+ assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty)
+ }
+ }
+ }
+ }
+
+ test("SPARK-33492: AppendData should refresh cache") {
+ import testImplicits._
+
+ val t = "testcat.ns.t"
+ val view = "view"
+ withTable(t) {
+ withTempView(view) {
+ Seq((1, "a")).toDF("i", "j").write.saveAsTable(t)
+ sql(s"CACHE TABLE $view AS SELECT i FROM $t")
+ checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil)
+ checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil)
+
+ Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t)
+
+ assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined)
+ checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil)
+ checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil)
+ }
+ }
+ }
+
+ test("SPARK-33492: OverwriteByExpression should refresh cache") {
+ val t = "testcat.ns.t"
+ val view = "view"
+ withTable(t) {
+ withTempView(view) {
+ sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source")
+ sql(s"CACHE TABLE $view AS SELECT id FROM $t")
+ checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source"))
+ checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id"))
+
+ sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')")
+
+ assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined)
+ checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil)
+ checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil)
+ }
+ }
+ }
+
+ test("SPARK-33492: OverwritePartitionsDynamic should refresh cache") {
+ import testImplicits._
+
+ val t = "testcat.ns.t"
+ val view = "view"
+ withTable(t) {
+ withTempView(view) {
+ Seq((1, "a", 1)).toDF("i", "j", "k").write.partitionBy("k") saveAsTable(t)
+ sql(s"CACHE TABLE $view AS SELECT i FROM $t")
+ checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a", 1) :: Nil)
+ checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil)
+
+ Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions()
+
+ assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined)
+ checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil)
+ checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil)
+ }
+ }
+ }
+
test("Relation: basic") {
val t1 = "testcat.ns1.ns2.tbl"
withTable(t1) {
@@ -1980,57 +2058,6 @@ class DataSourceV2SQLSuite
}
}
- test("ALTER TABLE RECOVER PARTITIONS") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo")
- val e = intercept[AnalysisException] {
- sql(s"ALTER TABLE $t RECOVER PARTITIONS")
- }
- assert(e.message.contains("ALTER TABLE RECOVER PARTITIONS is only supported with v1 tables"))
- }
- }
-
- test("ALTER TABLE ADD PARTITION") {
- val t = "testpart.ns1.ns2.tbl"
- withTable(t) {
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)")
- spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'")
-
- val partTable = catalog("testpart").asTableCatalog
- .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable]
- assert(partTable.partitionExists(InternalRow.fromSeq(Seq(1))))
-
- val partMetadata = partTable.loadPartitionMetadata(InternalRow.fromSeq(Seq(1)))
- assert(partMetadata.containsKey("location"))
- assert(partMetadata.get("location") == "loc")
- }
- }
-
- test("ALTER TABLE RENAME PARTITION") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)")
- val e = intercept[AnalysisException] {
- sql(s"ALTER TABLE $t PARTITION (id=1) RENAME TO PARTITION (id=2)")
- }
- assert(e.message.contains("ALTER TABLE RENAME PARTITION is only supported with v1 tables"))
- }
- }
-
- test("ALTER TABLE DROP PARTITION") {
- val t = "testpart.ns1.ns2.tbl"
- withTable(t) {
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY (id)")
- spark.sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'")
- spark.sql(s"ALTER TABLE $t DROP PARTITION (id=1)")
-
- val partTable =
- catalog("testpart").asTableCatalog.loadTable(Identifier.of(Array("ns1", "ns2"), "tbl"))
- assert(!partTable.asPartitionable.partitionExists(InternalRow.fromSeq(Seq(1))))
- }
- }
-
test("ALTER TABLE SerDe properties") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -2387,7 +2414,8 @@ class DataSourceV2SQLSuite
withTempView("v") {
sql("create global temp view v as select 1")
val e = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL"))
- assert(e.getMessage.contains("global_temp.v is a temp view not table."))
+ assert(e.getMessage.contains(
+ "global_temp.v is a temp view. 'COMMENT ON TABLE' expects a table"))
}
}
@@ -2513,6 +2541,25 @@ class DataSourceV2SQLSuite
}
}
+ test("SPARK-33505: insert into partitioned table") {
+ val t = "testpart.ns1.ns2.tbl"
+ withTable(t) {
+ sql(s"""
+ |CREATE TABLE $t (id bigint, city string, data string)
+ |USING foo
+ |PARTITIONED BY (id, city)""".stripMargin)
+ val partTable = catalog("testpart").asTableCatalog
+ .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable]
+ val expectedPartitionIdent = InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY")))
+ assert(!partTable.partitionExists(expectedPartitionIdent))
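+ // The first insert into a non-existent partition should create that partition.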
+ sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'")
+ assert(partTable.partitionExists(expectedPartitionIdent))
+ // Insert into the existing partition must not fail
+ sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'def'")
+ assert(partTable.partitionExists(expectedPartitionIdent))
+ }
+ }
+
private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
val e = intercept[AnalysisException] {
sql(s"$sqlCommand $sqlParams")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 4b52a4cbf4116..cba7dd35fb3bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -24,14 +24,17 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources._
@@ -145,6 +148,52 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
SparkSession.setDefaultSession(spark)
}
}
+
+ test("SPARK-33492: append fallback should refresh cache") {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ try {
+ val session = SparkSession.builder()
+ .master("local[1]")
+ .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName)
+ .getOrCreate()
+ val df = session.createDataFrame(Seq((1, "x")))
+ df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test")
+ session.catalog.cacheTable("test")
+ checkAnswer(session.read.table("test"), Row(1, "x") :: Nil)
+
+ val df2 = session.createDataFrame(Seq((2, "y")))
+ df2.writeTo("test").append()
+ checkAnswer(session.read.table("test"), Row(1, "x") :: Row(2, "y") :: Nil)
+
+ } finally {
+ SparkSession.setActiveSession(spark)
+ SparkSession.setDefaultSession(spark)
+ }
+ }
+
+ test("SPARK-33492: overwrite fallback should refresh cache") {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ try {
+ val session = SparkSession.builder()
+ .master("local[1]")
+ .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName)
+ .getOrCreate()
+ val df = session.createDataFrame(Seq((1, "x")))
+ df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test")
+ session.catalog.cacheTable("test")
+ checkAnswer(session.read.table("test"), Row(1, "x") :: Nil)
+
+ val df2 = session.createDataFrame(Seq((2, "y")))
+ df2.writeTo("test").overwrite(lit(true))
+ checkAnswer(session.read.table("test"), Row(2, "y") :: Nil)
+
+ } finally {
+ SparkSession.setActiveSession(spark)
+ SparkSession.setDefaultSession(spark)
+ }
+ }
}
class V1WriteFallbackSessionCatalogSuite
@@ -177,6 +226,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV
properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
InMemoryV1Provider.tables.put(name, t)
+ tables.put(Identifier.of(Array("default"), name), t)
t
}
}
@@ -272,7 +322,7 @@ class InMemoryTableWithV1Fallback(
override val partitioning: Array[Transform],
override val properties: util.Map[String, String])
extends Table
- with SupportsWrite {
+ with SupportsWrite with SupportsRead {
partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform]) {
@@ -281,6 +331,7 @@ class InMemoryTableWithV1Fallback(
}
override def capabilities: util.Set[TableCapability] = Set(
+ TableCapability.BATCH_READ,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.TRUNCATE).asJava
@@ -338,6 +389,30 @@ class InMemoryTableWithV1Fallback(
}
}
}
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+ new V1ReadFallbackScanBuilder(schema)
+
+ private class V1ReadFallbackScanBuilder(schema: StructType) extends ScanBuilder {
+ override def build(): Scan = new V1ReadFallbackScan(schema)
+ }
+
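+ // A V2 Scan that delegates reading to a V1 TableScan over the in-memory provider data.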
+ private class V1ReadFallbackScan(schema: StructType) extends V1Scan {
+ override def readSchema(): StructType = schema
+ override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T =
+ new V1TableScan(context, schema).asInstanceOf[T]
+ }
+
+ private class V1TableScan(
+ context: SQLContext,
+ requiredSchema: StructType) extends BaseRelation with TableScan {
+ override def sqlContext: SQLContext = context
+ override def schema: StructType = requiredSchema
+ override def buildScan(): RDD[Row] = {
+ val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect()
+ context.sparkContext.makeRDD(data)
+ }
+ }
}
/** A rule that fails if a query plan is analyzed twice. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index 792f920ee0217..504cc57dc12d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -147,10 +147,10 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
s"'$viewName' is a view not a table")
assertAnalysisError(
s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')",
- s"$viewName is a temp view not table")
+ s"$viewName is a temp view. 'ALTER TABLE ... ADD PARTITION ...' expects a table")
assertAnalysisError(
s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')",
- s"$viewName is a temp view not table")
+ s"$viewName is a temp view. 'ALTER TABLE ... DROP PARTITION ...' expects a table")
// For the following v2 ALTER TABLE statements, unsupported operations are checked first
// before resolving the relations.
@@ -175,7 +175,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
val e2 = intercept[AnalysisException] {
sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""")
}.getMessage
- assert(e2.contains(s"$viewName is a temp view not table"))
+ assert(e2.contains(s"$viewName is a temp view. 'LOAD DATA' expects a table"))
assertNoSuchTable(s"TRUNCATE TABLE $viewName")
val e3 = intercept[AnalysisException] {
sql(s"SHOW CREATE TABLE $viewName")
@@ -214,7 +214,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
e = intercept[AnalysisException] {
sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""")
}.getMessage
- assert(e.contains("default.testView is a view not table"))
+ assert(e.contains("default.testView is a view. 'LOAD DATA' expects a table"))
e = intercept[AnalysisException] {
sql(s"TRUNCATE TABLE $viewName")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 9d0147048dbb8..43a33860d262e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -3104,84 +3104,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(spark.sessionState.catalog.isRegisteredFunction(rand))
}
}
-
- test("SPARK-32481 Move data to trash on truncate table if enabled") {
- val trashIntervalKey = "fs.trash.interval"
- withTable("tab1") {
- withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") {
- sql("CREATE TABLE tab1 (col INT) USING parquet")
- sql("INSERT INTO tab1 SELECT 1")
- // scalastyle:off hadoopconfiguration
- val hadoopConf = spark.sparkContext.hadoopConfiguration
- // scalastyle:on hadoopconfiguration
- val originalValue = hadoopConf.get(trashIntervalKey, "0")
- val tablePath = new Path(spark.sessionState.catalog
- .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get)
-
- val fs = tablePath.getFileSystem(hadoopConf)
- val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current")
- val trashPath = Path.mergePaths(trashCurrent, tablePath)
- assume(
- fs.mkdirs(trashPath) && fs.delete(trashPath, false),
- "Trash directory could not be created, skipping.")
- assert(!fs.exists(trashPath))
- try {
- hadoopConf.set(trashIntervalKey, "5")
- sql("TRUNCATE TABLE tab1")
- } finally {
- hadoopConf.set(trashIntervalKey, originalValue)
- }
- assert(fs.exists(trashPath))
- fs.delete(trashPath, true)
- }
- }
- }
-
- test("SPARK-32481 delete data permanently on truncate table if trash interval is non-positive") {
- val trashIntervalKey = "fs.trash.interval"
- withTable("tab1") {
- withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "true") {
- sql("CREATE TABLE tab1 (col INT) USING parquet")
- sql("INSERT INTO tab1 SELECT 1")
- // scalastyle:off hadoopconfiguration
- val hadoopConf = spark.sparkContext.hadoopConfiguration
- // scalastyle:on hadoopconfiguration
- val originalValue = hadoopConf.get(trashIntervalKey, "0")
- val tablePath = new Path(spark.sessionState.catalog
- .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get)
-
- val fs = tablePath.getFileSystem(hadoopConf)
- val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current")
- val trashPath = Path.mergePaths(trashCurrent, tablePath)
- assert(!fs.exists(trashPath))
- try {
- hadoopConf.set(trashIntervalKey, "0")
- sql("TRUNCATE TABLE tab1")
- } finally {
- hadoopConf.set(trashIntervalKey, originalValue)
- }
- assert(!fs.exists(trashPath))
- }
- }
- }
-
- test("SPARK-32481 Do not move data to trash on truncate table if disabled") {
- withTable("tab1") {
- withSQLConf(SQLConf.TRUNCATE_TRASH_ENABLED.key -> "false") {
- sql("CREATE TABLE tab1 (col INT) USING parquet")
- sql("INSERT INTO tab1 SELECT 1")
- val hadoopConf = spark.sessionState.newHadoopConf()
- val tablePath = new Path(spark.sessionState.catalog
- .getTableMetadata(TableIdentifier("tab1")).storage.locationUri.get)
-
- val fs = tablePath.getFileSystem(hadoopConf)
- val trashCurrent = new Path(fs.getHomeDirectory, ".Trash/Current")
- val trashPath = Path.mergePaths(trashCurrent, tablePath)
- sql("TRUNCATE TABLE tab1")
- assert(!fs.exists(trashPath))
- }
- }
- }
}
object FakeLocalFsFileSystem {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala
new file mode 100644
index 0000000000000..b965a78c9eec0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterStrategySuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.test.SharedSparkSession
+
+class PathFilterStrategySuite extends QueryTest with SharedSparkSession {
+
+ test("SPARK-31962: PathFilterStrategies - modifiedAfter option") {
+ val options =
+ CaseInsensitiveMap[String](Map("modifiedAfter" -> "2010-10-01T01:01:00"))
+ val strategy = PathFilterFactory.create(options)
+ assert(strategy.head.isInstanceOf[ModifiedAfterFilter])
+ assert(strategy.size == 1)
+ }
+
+ test("SPARK-31962: PathFilterStrategies - modifiedBefore option") {
+ val options =
+ CaseInsensitiveMap[String](Map("modifiedBefore" -> "2020-10-01T01:01:00"))
+ val strategy = PathFilterFactory.create(options)
+ assert(strategy.head.isInstanceOf[ModifiedBeforeFilter])
+ assert(strategy.size == 1)
+ }
+
+ test("SPARK-31962: PathFilterStrategies - pathGlobFilter option") {
+ val options = CaseInsensitiveMap[String](Map("pathGlobFilter" -> "*.txt"))
+ val strategy = PathFilterFactory.create(options)
+ assert(strategy.head.isInstanceOf[PathGlobFilter])
+ assert(strategy.size == 1)
+ }
+
+ test("SPARK-31962: PathFilterStrategies - no options") {
+ val options = CaseInsensitiveMap[String](Map.empty)
+ val strategy = PathFilterFactory.create(options)
+ assert(strategy.isEmpty)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala
new file mode 100644
index 0000000000000..1af2adfd8640c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PathFilterSuite.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.File
+import java.time.{LocalDateTime, ZoneId, ZoneOffset}
+import java.time.format.DateTimeFormatter
+
+import scala.util.Random
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.util.{stringToFile, DateTimeUtils}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class PathFilterSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ test("SPARK-31962: modifiedBefore specified" +
+ " and sharing same timestamp with file last modified time.") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formatTime(curTime)))
+ }
+ }
+
+ test("SPARK-31962: modifiedAfter specified" +
+ " and sharing same timestamp with file last modified time.") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ executeTest(dir, Seq(curTime), 0, modifiedAfter = Some(formatTime(curTime)))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore and modifiedAfter option" +
+ " share same timestamp with file last modified time.") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val formattedTime = formatTime(curTime)
+ executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime),
+ modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore and modifiedAfter option" +
+ " share same timestamp with earlier file last modified time.") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val fileTime = curTime.minusDays(3)
+ val formattedTime = formatTime(curTime)
+ executeTest(dir, Seq(fileTime), 0, modifiedBefore = Some(formattedTime),
+ modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore and modifiedAfter option" +
+ " share same timestamp with later file last modified time.") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val formattedTime = formatTime(curTime)
+ executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime),
+ modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: when modifiedAfter specified with a past date") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val pastTime = curTime.minusYears(1)
+ val formattedTime = formatTime(pastTime)
+ executeTest(dir, Seq(curTime), 1, modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: when modifiedBefore specified with a future date") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val futureTime = curTime.plusYears(1)
+ val formattedTime = formatTime(futureTime)
+ executeTest(dir, Seq(curTime), 1, modifiedBefore = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: with modifiedBefore option provided using a past date") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val pastTime = curTime.minusYears(1)
+ val formattedTime = formatTime(pastTime)
+ executeTest(dir, Seq(curTime), 0, modifiedBefore = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedAfter specified with a past date, multiple files, one valid") {
+ withTempDir { dir =>
+ val fileTime1 = LocalDateTime.now(ZoneOffset.UTC)
+ val fileTime2 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC)
+ val pastTime = fileTime1.minusYears(1)
+ val formattedTime = formatTime(pastTime)
+ executeTest(dir, Seq(fileTime1, fileTime2), 1, modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedAfter specified with a past date, multiple files, both valid") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val pastTime = curTime.minusYears(1)
+ val formattedTime = formatTime(pastTime)
+ executeTest(dir, Seq(curTime, curTime), 2, modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedAfter specified with a past date, multiple files, none valid") {
+ withTempDir { dir =>
+ val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC)
+ val pastTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1)
+ val formattedTime = formatTime(pastTime)
+ executeTest(dir, Seq(fileTime, fileTime), 0, modifiedAfter = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore specified with a future date, multiple files, both valid") {
+ withTempDir { dir =>
+ val fileTime = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC)
+ val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1)
+ val formattedTime = formatTime(futureTime)
+ executeTest(dir, Seq(fileTime, fileTime), 2, modifiedBefore = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore specified with a future date, multiple files, one valid") {
+ withTempDir { dir =>
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val fileTime1 = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC)
+ val fileTime2 = curTime.plusDays(3)
+ val formattedTime = formatTime(curTime)
+ executeTest(dir, Seq(fileTime1, fileTime2), 1, modifiedBefore = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore specified with a future date, multiple files, none valid") {
+ withTempDir { dir =>
+ val fileTime = LocalDateTime.now(ZoneOffset.UTC).minusDays(1)
+ val formattedTime = formatTime(fileTime)
+ executeTest(dir, Seq(fileTime, fileTime), 0, modifiedBefore = Some(formattedTime))
+ }
+ }
+
+ test("SPARK-31962: modifiedBefore/modifiedAfter is specified with an invalid date") {
+ executeTestWithBadOption(
+ Map("modifiedBefore" -> "2024-05+1 01:00:00"),
+ Seq("The timestamp provided", "modifiedbefore", "2024-05+1 01:00:00"))
+
+ executeTestWithBadOption(
+ Map("modifiedAfter" -> "2024-05+1 01:00:00"),
+ Seq("The timestamp provided", "modifiedafter", "2024-05+1 01:00:00"))
+ }
+
+ test("SPARK-31962: modifiedBefore/modifiedAfter - empty option") {
+ executeTestWithBadOption(
+ Map("modifiedBefore" -> ""),
+ Seq("The timestamp provided", "modifiedbefore"))
+
+ executeTestWithBadOption(
+ Map("modifiedAfter" -> ""),
+ Seq("The timestamp provided", "modifiedafter"))
+ }
+
+ test("SPARK-31962: modifiedBefore/modifiedAfter filter takes into account local timezone " +
+ "when specified as an option.") {
+ Seq("modifiedbefore", "modifiedafter").foreach { filterName =>
+ // CET = UTC + 1 hour, HST = UTC - 10 hours
+ Seq("CET", "HST").foreach { tzId =>
+ testModifiedDateFilterWithTimezone(tzId, filterName)
+ }
+ }
+ }
+
+ test("Option pathGlobFilter: filter files correctly") {
+ withTempPath { path =>
+ val dataDir = path.getCanonicalPath
+ Seq("foo").toDS().write.text(dataDir)
+ Seq("bar").toDS().write.mode("append").orc(dataDir)
+ val df = spark.read.option("pathGlobFilter", "*.txt").text(dataDir)
+ checkAnswer(df, Row("foo"))
+
+ // Both glob pattern in option and path should be effective to filter files.
+ val df2 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*.orc")
+ checkAnswer(df2, Seq.empty)
+
+ val df3 = spark.read.option("pathGlobFilter", "*.txt").text(dataDir + "/*xt")
+ checkAnswer(df3, Row("foo"))
+ }
+ }
+
+ test("Option pathGlobFilter: simple extension filtering should contains partition info") {
+ withTempPath { path =>
+ val input = Seq(("foo", 1), ("oof", 2)).toDF("a", "b")
+ input.write.partitionBy("b").text(path.getCanonicalPath)
+ Seq("bar").toDS().write.mode("append").orc(path.getCanonicalPath + "/b=1")
+
+ // If we use glob pattern in the path, the partition column won't be shown in the result.
+ val df = spark.read.text(path.getCanonicalPath + "/*/*.txt")
+ checkAnswer(df, input.select("a"))
+
+ val df2 = spark.read.option("pathGlobFilter", "*.txt").text(path.getCanonicalPath)
+ checkAnswer(df2, input)
+ }
+ }
+
+ private def executeTest(
+ dir: File,
+ fileDates: Seq[LocalDateTime],
+ expectedCount: Long,
+ modifiedBefore: Option[String] = None,
+ modifiedAfter: Option[String] = None): Unit = {
+ fileDates.foreach { fileDate =>
+ val file = createSingleFile(dir)
+ setFileTime(fileDate, file)
+ }
+
+ val schema = StructType(Seq(StructField("a", StringType)))
+
+ var dfReader = spark.read.format("csv").option("timeZone", "UTC").schema(schema)
+ modifiedBefore.foreach { opt => dfReader = dfReader.option("modifiedBefore", opt) }
+ modifiedAfter.foreach { opt => dfReader = dfReader.option("modifiedAfter", opt) }
+
+ if (expectedCount > 0) {
+ // without pathGlobFilter
+ val df1 = dfReader.load(dir.getCanonicalPath)
+ assert(df1.count() === expectedCount)
+
+ // pathGlobFilter matched
+ val df2 = dfReader.option("pathGlobFilter", "*.csv").load(dir.getCanonicalPath)
+ assert(df2.count() === expectedCount)
+
+ // pathGlobFilter mismatched
+ val df3 = dfReader.option("pathGlobFilter", "*.txt").load(dir.getCanonicalPath)
+ assert(df3.count() === 0)
+ } else {
+ val df = dfReader.load(dir.getCanonicalPath)
+ assert(df.count() === 0)
+ }
+ }
+
+ private def executeTestWithBadOption(
+ options: Map[String, String],
+ expectedMsgParts: Seq[String]): Unit = {
+ withTempDir { dir =>
+ createSingleFile(dir)
+ val exc = intercept[AnalysisException] {
+ var dfReader = spark.read.format("csv")
+ options.foreach { case (key, value) =>
+ dfReader = dfReader.option(key, value)
+ }
+ dfReader.load(dir.getCanonicalPath)
+ }
+ expectedMsgParts.foreach { msg => assert(exc.getMessage.contains(msg)) }
+ }
+ }
+
+ private def testModifiedDateFilterWithTimezone(
+ timezoneId: String,
+ filterParamName: String): Unit = {
+ val curTime = LocalDateTime.now(ZoneOffset.UTC)
+ val zoneId: ZoneId = DateTimeUtils.getTimeZone(timezoneId).toZoneId
+ val strategyTimeInMicros =
+ ModifiedDateFilter.toThreshold(
+ curTime.toString,
+ timezoneId,
+ filterParamName)
+ val strategyTimeInSeconds = strategyTimeInMicros / 1000 / 1000
+
+ val curTimeAsSeconds = curTime.atZone(zoneId).toEpochSecond
+ withClue(s"timezone: $timezoneId / param: $filterParamName,") {
+ assert(strategyTimeInSeconds === curTimeAsSeconds)
+ }
+ }
+
+ private def createSingleFile(dir: File): File = {
+ val file = new File(dir, "temp" + Random.nextInt(1000000) + ".csv")
+ stringToFile(file, "text")
+ }
+
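+ // Sets the file's last-modified time to the given UTC timestamp (second precision); returns whether the update succeeded.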
+ private def setFileTime(time: LocalDateTime, file: File): Boolean = {
+ val sameTime = time.toEpochSecond(ZoneOffset.UTC)
+ file.setLastModified(sameTime * 1000)
+ }
+
+ private def formatTime(time: LocalDateTime): String = {
+ time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 9f62ff8301ebc..6085c1f2cccb0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -149,6 +149,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
"org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
"org.apache.spark.sql.catalyst.expressions.CurrentDate",
"org.apache.spark.sql.catalyst.expressions.CurrentTimestamp",
+ "org.apache.spark.sql.catalyst.expressions.CurrentTimeZone",
"org.apache.spark.sql.catalyst.expressions.Now",
// Random output without a seed
"org.apache.spark.sql.catalyst.expressions.Rand",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index cf9664a9764be..718095003b096 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming
import java.io.File
import java.net.URI
+import java.time.{LocalDateTime, ZoneOffset}
+import java.time.format.DateTimeFormatter
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
@@ -40,7 +42,6 @@ import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{StructType, _}
import org.apache.spark.util.Utils
@@ -2054,6 +2055,47 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
}
}
+ test("SPARK-31962: file stream source shouldn't allow modifiedBefore/modifiedAfter") {
+ def formatTime(time: LocalDateTime): String = {
+ time.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss"))
+ }
+
+ def assertOptionIsNotSupported(options: Map[String, String], path: String): Unit = {
+ val schema = StructType(Seq(StructField("a", StringType)))
+ var dsReader = spark.readStream
+ .format("csv")
+ .option("timeZone", "UTC")
+ .schema(schema)
+
+ options.foreach { case (k, v) => dsReader = dsReader.option(k, v) }
+
+ val df = dsReader.load(path)
+
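+ // Starting the stream should fail because these options are only supported for batch file sources.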
+ testStream(df)(
+ ExpectFailure[IllegalArgumentException](
+ t => assert(t.getMessage.contains("is not allowed in file stream source")),
+ isFatalError = false)
+ )
+ }
+
+ withTempDir { dir =>
+ // "modifiedBefore"
+ val futureTime = LocalDateTime.now(ZoneOffset.UTC).plusYears(1)
+ val formattedFutureTime = formatTime(futureTime)
+ assertOptionIsNotSupported(Map("modifiedBefore" -> formattedFutureTime), dir.getCanonicalPath)
+
+ // "modifiedAfter"
+ val prevTime = LocalDateTime.now(ZoneOffset.UTC).minusYears(1)
+ val formattedPrevTime = formatTime(prevTime)
+ assertOptionIsNotSupported(Map("modifiedAfter" -> formattedPrevTime), dir.getCanonicalPath)
+
+ // both
+ assertOptionIsNotSupported(
+ Map("modifiedBefore" -> formattedFutureTime, "modifiedAfter" -> formattedPrevTime),
+ dir.getCanonicalPath)
+ }
+ }
+
private def createFile(content: String, src: File, tmp: File): File = {
val tempFile = Utils.tempFileWith(new File(tmp, "text"))
val finalFile = new File(src, tempFile.getName)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala
index 1a8b28001b8d1..307479db33949 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala
@@ -139,7 +139,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
summaryText should contain ("Aggregated Number Of Total State Rows (?)")
summaryText should contain ("Aggregated Number Of Updated State Rows (?)")
summaryText should contain ("Aggregated State Memory Used In Bytes (?)")
- summaryText should contain ("Aggregated Number Of State Rows Dropped By Watermark (?)")
+ summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)")
}
} finally {
spark.streams.active.foreach(_.stop())
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 2e9975bcabc3f..f7a4be9591818 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -63,6 +63,10 @@ private[hive] class SparkExecuteStatementOperation(
}
}
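+ // Pre-compute the variable-substituted statement so it can also serve as the job group description in getNextRowSet.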
+ private val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) {
+ new VariableSubstitution().substitute(statement)
+ }
+
private var result: DataFrame = _
// We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST.
@@ -126,6 +130,17 @@ private[hive] class SparkExecuteStatementOperation(
}
def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withLocalProperties {
+ try {
+ sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement)
+ getNextRowSetInternal(order, maxRowsL)
+ } finally {
+ sqlContext.sparkContext.clearJobGroup()
+ }
+ }
+
+ private def getNextRowSetInternal(
+ order: FetchOrientation,
+ maxRowsL: Long): RowSet = withLocalProperties {
log.info(s"Received getNextRowSet request order=${order} and maxRowsL=${maxRowsL} " +
s"with ${statementId}")
validateDefaultFetchOrientation(order)
@@ -306,9 +321,6 @@ private[hive] class SparkExecuteStatementOperation(
parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader)
}
- val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) {
- new VariableSubstitution().substitute(statement)
- }
sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement)
result = sqlContext.sql(statement)
logDebug(result.queryExecution.toString())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
index 38a8c492d77a7..cf070f4611f3b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -52,7 +52,6 @@ import org.apache.spark.util.Utils
@ExtendedHiveTest
class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
import HiveExternalCatalogVersionsSuite._
- private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)
private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse")
private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data")
// For local test, you can set `spark.test.cache-dir` to a static value like `/tmp/test-spark`, to
@@ -60,6 +59,11 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
private val sparkTestingDir = Option(System.getProperty(SPARK_TEST_CACHE_DIR_SYSTEM_PROPERTY))
.map(new File(_)).getOrElse(Utils.createTempDir(namePrefix = "test-spark"))
private val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
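+ // Pick a Hive metastore version compatible with the running JDK: 2.3.7 on Java 9+, 1.2.1 otherwise.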
+ val hiveVersion = if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) {
+ "2.3.7"
+ } else {
+ "1.2.1"
+ }
override def afterAll(): Unit = {
try {
@@ -149,7 +153,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8)
}
- private def prepare(): Unit = {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
val tempPyFile = File.createTempFile("test", ".py")
// scalastyle:off line.size.limit
Files.write(tempPyFile.toPath,
@@ -199,7 +205,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
"--master", "local[2]",
"--conf", s"${UI_ENABLED.key}=false",
"--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false",
- "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1",
+ "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion",
"--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven",
"--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}",
"--conf", s"spark.sql.test.version.index=$index",
@@ -211,23 +217,14 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
tempPyFile.delete()
}
- override def beforeAll(): Unit = {
- super.beforeAll()
- if (!isTestAtLeastJava9) {
- prepare()
- }
- }
-
test("backward compatibility") {
- // TODO SPARK-28704 Test backward compatibility on JDK9+ once we have a version supports JDK9+
- assume(!isTestAtLeastJava9)
val args = Seq(
"--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"),
"--name", "HiveExternalCatalog backward compatibility test",
"--master", "local[2]",
"--conf", s"${UI_ENABLED.key}=false",
"--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false",
- "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1",
+ "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=$hiveVersion",
"--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven",
"--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}",
"--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
@@ -252,7 +249,9 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils {
// do not throw exception during object initialization.
case NonFatal(_) => Seq("3.0.1", "2.4.7") // A temporary fallback to use a specific version
}
- versions.filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38())
+ versions
+ .filter(v => v.startsWith("3") || !TestUtils.isPythonVersionAtLeast38())
+ .filter(v => v.startsWith("3") || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9))
}
protected var spark: SparkSession = _
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 1f15bd685b239..56b871644453b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -904,10 +904,10 @@ class HiveDDLSuite
assertAnalysisError(
s"ALTER TABLE $oldViewName ADD IF NOT EXISTS PARTITION (a='4', b='8')",
- s"$oldViewName is a view not table")
+ s"$oldViewName is a view. 'ALTER TABLE ... ADD PARTITION ...' expects a table.")
assertAnalysisError(
s"ALTER TABLE $oldViewName DROP IF EXISTS PARTITION (a='2')",
- s"$oldViewName is a view not table")
+ s"$oldViewName is a view. 'ALTER TABLE ... DROP PARTITION ...' expects a table.")
assert(catalog.tableExists(TableIdentifier(tabName)))
assert(catalog.tableExists(TableIdentifier(oldViewName)))