diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a1cb4781f4d0a..34691883bc5a9 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2954,6 +2954,9 @@ setMethod("exceptAll",
#' @param source a name for external data source.
#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore'
#' save mode (it is 'error' by default)
+#' @param partitionBy a name or a list of names of columns to partition the output by on the file
+#' system. If specified, the output is laid out on the file system similar
+#' to Hive's partitioning scheme.
#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
@@ -2965,13 +2968,13 @@ setMethod("exceptAll",
#' sparkR.session()
#' path <- "path/to/file.json"
#' df <- read.json(path)
-#' write.df(df, "myfile", "parquet", "overwrite")
+#' write.df(df, "myfile", "parquet", "overwrite", partitionBy = c("col1", "col2"))
#' saveDF(df, parquetPath2, "parquet", mode = "append", mergeSchema = TRUE)
#' }
#' @note write.df since 1.4.0
setMethod("write.df",
signature(df = "SparkDataFrame"),
- function(df, path = NULL, source = NULL, mode = "error", ...) {
+ function(df, path = NULL, source = NULL, mode = "error", partitionBy = NULL, ...) {
if (!is.null(path) && !is.character(path)) {
stop("path should be character, NULL or omitted.")
}
@@ -2985,8 +2988,18 @@ setMethod("write.df",
if (is.null(source)) {
source <- getDefaultSqlSource()
}
+ cols <- NULL
+ if (!is.null(partitionBy)) {
+ if (!all(sapply(partitionBy, function(c) is.character(c)))) {
+ stop("All partitionBy column names should be characters.")
+ }
+ cols <- as.list(partitionBy)
+ }
write <- callJMethod(df@sdf, "write")
write <- callJMethod(write, "format", source)
+ if (!is.null(cols)) {
+ write <- callJMethod(write, "partitionBy", cols)
+ }
write <- setWriteOptions(write, path = path, mode = mode, ...)
write <- handledCallJMethod(write, "save")
})
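
The new `partitionBy` argument is passed straight through `callJMethod` to `DataFrameWriter.partitionBy` on the JVM side. A minimal Scala sketch of the equivalent writer calls (output path and column names are placeholders):

```scala
import org.apache.spark.sql.SparkSession

// JVM-side equivalent of:
//   write.df(df, "myfile", "parquet", "overwrite", partitionBy = c("col1", "col2"))
val spark = SparkSession.builder().master("local[*]").appName("partitionBy-sketch").getOrCreate()
val df = spark.range(100).selectExpr("id % 2 AS col1", "id % 5 AS col2", "id AS value")
df.write
  .format("parquet")
  .partitionBy("col1", "col2") // Hive-style layout: .../col1=0/col2=3/part-...
  .mode("overwrite")
  .save("myfile") // placeholder output path
```
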
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 572dee50127b8..6425c9d26bef3 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -198,8 +198,9 @@ NULL
#' }
#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
#' additional named properties to control how it is converted, accepts the same
-#' options as the JSON data source. In \code{arrays_zip}, this contains additional
-#' Columns of arrays to be merged.
+#' options as the JSON data source. Additionally \code{to_json} supports the "pretty"
+#' option which enables pretty JSON generation. In \code{arrays_zip}, this contains
+#' additional Columns of arrays to be merged.
#' @name column_collection_functions
#' @rdname column_collection_functions
#' @family collection functions
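
The "pretty" option documented above is forwarded to the JSON data source options. A short Scala sketch of the same option through the Scala API (assumes a Spark build that includes this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{struct, to_json}

val spark = SparkSession.builder().master("local[*]").appName("to_json-pretty").getOrCreate()
import spark.implicits._

val df = Seq((1, "alice"), (2, "bob")).toDF("id", "name")
// Without the option the JSON is emitted on one line; with pretty=true it is indented.
df.select(to_json(struct($"id", $"name"), Map("pretty" -> "true")).as("json"))
  .show(truncate = false)
```
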
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index d3a9cbae7d808..038fefadaaeff 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -626,6 +626,8 @@ sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-pat
sparkConfToSubmitOps[["spark.master"]] <- "--master"
sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab"
sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal"
+sparkConfToSubmitOps[["spark.kerberos.keytab"]] <- "--keytab"
+sparkConfToSubmitOps[["spark.kerberos.principal"]] <- "--principal"
# Utility function that returns Spark Submit arguments as a string
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index a874bfbb58dc7..50eff3755edf8 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2701,8 +2701,16 @@ test_that("read/write text files", {
expect_equal(colnames(df2), c("value"))
expect_equal(count(df2), count(df) * 2)
+ df3 <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),
+ schema = c("key", "value"))
+ textPath3 <- tempfile(pattern = "textPath3", fileext = ".txt")
+ write.df(df3, textPath3, "text", mode = "overwrite", partitionBy = "key")
+ df4 <- read.df(textPath3, "text")
+ expect_equal(count(df3), count(df4))
+
unlink(textPath)
unlink(textPath2)
+ unlink(textPath3)
})
test_that("read/write text files - compression option", {
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd
index 090363c5f8a3e..ad934947437bc 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -157,8 +157,8 @@ Property Name | Property group | spark-submit equivalent
`spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path`
`spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options`
`spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path`
-`spark.yarn.keytab` | Application Properties | `--keytab`
-`spark.yarn.principal` | Application Properties | `--principal`
+`spark.kerberos.keytab` | Application Properties | `--keytab`
+`spark.kerberos.principal` | Application Properties | `--principal`
**For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`.
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 6c4c5c94cfa28..e0f98f1aca071 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -726,7 +726,11 @@ private[spark] object SparkConf extends Logging {
DRIVER_MEMORY_OVERHEAD.key -> Seq(
AlternateConfig("spark.yarn.driver.memoryOverhead", "2.3")),
EXECUTOR_MEMORY_OVERHEAD.key -> Seq(
- AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3"))
+ AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")),
+ KEYTAB.key -> Seq(
+ AlternateConfig("spark.yarn.keytab", "2.5")),
+ PRINCIPAL.key -> Seq(
+ AlternateConfig("spark.yarn.principal", "2.5"))
)
/**
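
With the alternate-config entries above, the old `spark.yarn.*` names continue to work but resolve to the new keys. A small sketch (principal and keytab path are placeholders, and it assumes this patch is applied):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.kerberos.principal", "user@EXAMPLE.COM")   // placeholder principal
  .set("spark.kerberos.keytab", "/path/to/user.keytab")  // placeholder path

// The legacy key is still accepted as a deprecated alternate of the new key.
val legacy = new SparkConf().set("spark.yarn.keytab", "/path/to/user.keytab")
assert(legacy.get("spark.kerberos.keytab") == "/path/to/user.keytab")
```
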
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index cf902db8709e7..d5f2865f87281 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -520,6 +520,10 @@ private[spark] class SparkSubmit extends Logging {
confKey = "spark.driver.extraJavaOptions"),
OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
confKey = "spark.driver.extraLibraryPath"),
+ OptionAssigner(args.principal, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
+ confKey = PRINCIPAL.key),
+ OptionAssigner(args.keytab, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
+ confKey = KEYTAB.key),
// Propagate attributes for dependency resolution at the driver side
OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"),
@@ -537,8 +541,6 @@ private[spark] class SparkSubmit extends Logging {
OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"),
OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"),
- OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"),
- OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"),
// Other options
OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 0998757715457..4cf08a7980f55 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -199,8 +199,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
numExecutors = Option(numExecutors)
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull
- keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
- principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
+ keytab = Option(keytab)
+ .orElse(sparkProperties.get("spark.kerberos.keytab"))
+ .orElse(sparkProperties.get("spark.yarn.keytab"))
+ .orNull
+ principal = Option(principal)
+ .orElse(sparkProperties.get("spark.kerberos.principal"))
+ .orElse(sparkProperties.get("spark.yarn.principal"))
+ .orNull
dynamicAllocationEnabled =
sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase)
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 44d23908146c7..c23a659e76df1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.history
import java.io.{File, FileNotFoundException, IOException}
import java.nio.file.Files
-import java.nio.file.attribute.PosixFilePermissions
import java.util.{Date, ServiceLoader}
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit}
import java.util.zip.{ZipEntry, ZipOutputStream}
@@ -133,9 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
// Visible for testing.
private[history] val listing: KVStore = storePath.map { path =>
- val perms = PosixFilePermissions.fromString("rwx------")
- val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(),
- PosixFilePermissions.asFileAttribute(perms)).toFile()
+ val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath()).toFile()
+ Utils.chmod700(dbPath)
val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION,
AppStatusStore.CURRENT_VERSION, logDir.toString())
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala
index c03a360b91ef8..ad0dd23cb59c8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala
@@ -18,8 +18,6 @@
package org.apache.spark.deploy.history
import java.io.File
-import java.nio.file.Files
-import java.nio.file.attribute.PosixFilePermissions
import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._
@@ -107,9 +105,8 @@ private class HistoryServerDiskManager(
val needed = approximateSize(eventLogSize, isCompressed)
makeRoom(needed)
- val perms = PosixFilePermissions.fromString("rwx------")
- val tmp = Files.createTempDirectory(tmpStoreDir.toPath(), "appstore",
- PosixFilePermissions.asFileAttribute(perms)).toFile()
+ val tmp = Utils.createTempDir(tmpStoreDir.getPath(), "appstore")
+ Utils.chmod700(tmp)
updateUsage(needed)
val current = currentUsage.get()
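
Both history-server call sites now create the directory with plain NIO and then restrict it to the owner via `Utils.chmod700`, rather than passing POSIX permission attributes at creation time. A rough, Spark-independent sketch of that owner-only restriction (the helper name is made up):

```scala
import java.io.File
import java.nio.file.Files

// Owner-only (rwx------) restriction expressed with java.io.File setters,
// which also degrades gracefully where POSIX permission attributes are unavailable.
def restrictToOwner(dir: File): Boolean = {
  dir.setReadable(false, false) && dir.setReadable(true, true) &&
  dir.setWritable(false, false) && dir.setWritable(true, true) &&
  dir.setExecutable(false, false) && dir.setExecutable(true, true)
}

val tmp = Files.createTempDirectory("appstore").toFile // placeholder directory
restrictToOwner(tmp)
```
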
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9891b6a2196de..7f6342208350a 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -152,11 +152,11 @@ package object config {
private[spark] val SHUFFLE_SERVICE_PORT =
ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337)
- private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab")
+ private[spark] val KEYTAB = ConfigBuilder("spark.kerberos.keytab")
.doc("Location of user's keytab.")
.stringConf.createOptional
- private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal")
+ private[spark] val PRINCIPAL = ConfigBuilder("spark.kerberos.principal")
.doc("Name of the Kerberos principal.")
.stringConf.createOptional
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 659694dd189ad..0e221edf3965a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -49,10 +49,16 @@ private[spark] sealed trait MapStatus {
private[spark] object MapStatus {
+ /**
+   * Minimum number of partitions above which [[HighlyCompressedMapStatus]] is used. A bit ugly
+   * here because in test code we can't assume SparkEnv.get exists.
+ */
+ private lazy val minPartitionsToUseHighlyCompressMapStatus = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
+ .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
+
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
- if (uncompressedSizes.length > Option(SparkEnv.get)
- .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
- .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) {
+ if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
HighlyCompressedMapStatus(loc, uncompressedSizes)
} else {
new CompressedMapStatus(loc, uncompressedSizes)
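
The effect of the lazy val is that the SparkEnv/conf lookup runs once per JVM rather than on every MapStatus construction. A toy illustration of the caching pattern (names and the threshold value are made up):

```scala
object ThresholdCache {
  var lookups = 0 // counts how often the "expensive" lookup runs
  private lazy val threshold: Int = { lookups += 1; 2000 } // stand-in for the conf lookup

  def statusKind(numPartitions: Int): String =
    if (numPartitions > threshold) "highly-compressed" else "compressed"
}

ThresholdCache.statusKind(10)
ThresholdCache.statusKind(5000)
assert(ThresholdCache.lookups == 1) // evaluated only on first use
```
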
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
index f21eee1965761..36aaf67b57298 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
@@ -388,10 +388,11 @@ private[spark] class AppStatusListener(
job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None
update(job, now, last = true)
+ if (job.status == JobExecutionStatus.SUCCEEDED) {
+ appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages)
+ kvstore.write(appSummary)
+ }
}
-
- appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages)
- kvstore.write(appSummary)
}
override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
@@ -653,13 +654,14 @@ private[spark] class AppStatusListener(
if (removeStage) {
liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))
}
+ if (stage.status == v1.StageStatus.COMPLETE) {
+ appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1)
+ kvstore.write(appSummary)
+ }
}
// remove any dead executors that were not running for any currently active stages
deadExecutors.retain((execId, exec) => isExecutorActiveForLiveStages(exec))
-
- appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1)
- kvstore.write(appSummary)
}
private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 22341467add5c..0fe82ac0cedc5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -438,10 +438,8 @@ private[spark] class BlockManager(
// stream.
channel.close()
// TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up
- // using a lot of memory here. With encryption, we'll read the whole file into a regular
- // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm
- // OOM, but might get killed by the OS / cluster manager. We could at least read the tmp
- // file as a stream in both cases.
+ // using a lot of memory here. We'll read the whole file into a regular
+ // byte buffer and OOM. We could at least read the tmp file as a stream.
val buffer = securityManager.getIOEncryptionKey() match {
case Some(key) =>
// we need to pass in the size of the unencrypted block
@@ -453,7 +451,7 @@ private[spark] class BlockManager(
new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator)
case None =>
- ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt)
+ ChunkedByteBuffer.fromFile(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt)
}
putBytes(blockId, buffer, level)(classTag)
tmpFile.delete()
@@ -726,10 +724,9 @@ private[spark] class BlockManager(
*/
def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
// TODO if we change this method to return the ManagedBuffer, then getRemoteValues
- // could just use the inputStream on the temp file, rather than memory-mapping the file.
+ // could just use the inputStream on the temp file, rather than reading the file into memory.
// Until then, replication can cause the process to use too much memory and get killed
- // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though
- // we've read the data to disk.
+ // even though we've read the data to disk.
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
var runningFailureCount = 0
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index e534c746433f2..aecc2284a9588 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -448,35 +448,35 @@ final class ShuffleBlockFetcherIterator(
buf.release()
throwFetchFailedException(blockId, address, e)
}
-
- input = streamWrapper(blockId, in)
- // Only copy the stream if it's wrapped by compression or encryption, also the size of
- // block is small (the decompressed block is smaller than maxBytesInFlight)
- if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
- val originalInput = input
- val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
- try {
+ var isStreamCopied: Boolean = false
+ try {
+ input = streamWrapper(blockId, in)
+ // Only copy the stream if it's wrapped by compression or encryption, also the size of
+ // block is small (the decompressed block is smaller than maxBytesInFlight)
+ if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+ isStreamCopied = true
+ val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
// Decompress the whole block at once to detect any corruption, which could increase
            // the memory usage and potentially increase the chance of OOM.
// TODO: manage the memory used here, and spill it into disk in case of OOM.
- Utils.copyStream(input, out)
- out.close()
+ Utils.copyStream(input, out, closeStreams = true)
input = out.toChunkedByteBuffer.toInputStream(dispose = true)
- } catch {
- case e: IOException =>
- buf.release()
- if (buf.isInstanceOf[FileSegmentManagedBuffer]
- || corruptedBlocks.contains(blockId)) {
- throwFetchFailedException(blockId, address, e)
- } else {
- logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
- corruptedBlocks += blockId
- fetchRequests += FetchRequest(address, Array((blockId, size)))
- result = null
- }
- } finally {
- // TODO: release the buf here to free memory earlier
- originalInput.close()
+ }
+ } catch {
+ case e: IOException =>
+ buf.release()
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]
+ || corruptedBlocks.contains(blockId)) {
+ throwFetchFailedException(blockId, address, e)
+ } else {
+              logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+ corruptedBlocks += blockId
+ fetchRequests += FetchRequest(address, Array((blockId, size)))
+ result = null
+ }
+ } finally {
+ // TODO: release the buf here to free memory earlier
+ if (isStreamCopied) {
in.close()
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 39f050f6ca5ad..4aa8d45ec7404 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -19,17 +19,16 @@ package org.apache.spark.util.io
import java.io.{File, FileInputStream, InputStream}
import java.nio.ByteBuffer
-import java.nio.channels.{FileChannel, WritableByteChannel}
-import java.nio.file.StandardOpenOption
-
-import scala.collection.mutable.ListBuffer
+import java.nio.channels.WritableByteChannel
+import com.google.common.io.ByteStreams
import com.google.common.primitives.UnsignedBytes
+import org.apache.commons.io.IOUtils
import org.apache.spark.SparkEnv
import org.apache.spark.internal.config
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.util.ByteArrayWritableChannel
+import org.apache.spark.network.util.{ByteArrayWritableChannel, LimitedInputStream}
import org.apache.spark.storage.StorageUtils
import org.apache.spark.util.Utils
@@ -175,30 +174,36 @@ object ChunkedByteBuffer {
def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = {
data match {
case f: FileSegmentManagedBuffer =>
- map(f.getFile, maxChunkSize, f.getOffset, f.getLength)
+ fromFile(f.getFile, maxChunkSize, f.getOffset, f.getLength)
case other =>
new ChunkedByteBuffer(other.nioByteBuffer())
}
}
- def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = {
- map(file, maxChunkSize, 0, file.length())
+ def fromFile(file: File, maxChunkSize: Int): ChunkedByteBuffer = {
+ fromFile(file, maxChunkSize, 0, file.length())
}
- def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = {
- Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel =>
- var remaining = length
- var pos = offset
- val chunks = new ListBuffer[ByteBuffer]()
- while (remaining > 0) {
- val chunkSize = math.min(remaining, maxChunkSize)
- val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize)
- pos += chunkSize
- remaining -= chunkSize
- chunks += chunk
- }
- new ChunkedByteBuffer(chunks.toArray)
+ private def fromFile(
+ file: File,
+ maxChunkSize: Int,
+ offset: Long,
+ length: Long): ChunkedByteBuffer = {
+ // We do *not* memory map the file, because we may end up putting this into the memory store,
+    // and Spark currently does not expect memory-mapped buffers in the memory store, as that
+    // conflicts with other parts that manage the lifecycle of buffers and dispose them.
+    // See SPARK-25422.
+ val is = new FileInputStream(file)
+ ByteStreams.skipFully(is, offset)
+ val in = new LimitedInputStream(is, length)
+ val chunkSize = math.min(maxChunkSize, length).toInt
+ val out = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate _)
+ Utils.tryWithSafeFinally {
+ IOUtils.copy(in, out)
+ } {
+ in.close()
+ out.close()
}
+ out.toChunkedByteBuffer
}
}
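
The replacement reads the file into heap ByteBuffers of bounded size instead of memory-mapping it, which keeps the buffers compatible with code that assumes it can dispose them. A generic, Spark-independent sketch of the chunked-read idea (not the internal ChunkedByteBuffer API):

```scala
import java.io.{File, FileInputStream}
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

// Read a file fully into heap buffers, each at most maxChunkSize bytes.
def readIntoChunks(file: File, maxChunkSize: Int): Array[ByteBuffer] = {
  val in = new FileInputStream(file)
  try {
    val chunks = ArrayBuffer.empty[ByteBuffer]
    var remaining = file.length()
    while (remaining > 0) {
      val size = math.min(remaining, maxChunkSize.toLong).toInt
      val bytes = new Array[Byte](size)
      var read = 0
      while (read < size) {
        val n = in.read(bytes, read, size - read)
        require(n >= 0, s"unexpected end of file: $file")
        read += n
      }
      chunks += ByteBuffer.wrap(bytes)
      remaining -= size
    }
    chunks.toArray
  } finally {
    in.close()
  }
}
```
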
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 354e6386fa60e..2155a0f2b6c21 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -188,32 +188,4 @@ class MapStatusSuite extends SparkFunSuite {
assert(count === 3000)
}
}
-
- test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") {
- val conf = new SparkConf()
- val env = mock(classOf[SparkEnv])
- doReturn(conf).when(env).conf
- SparkEnv.set(env)
- val sizes = Array.fill[Long](500)(150L)
- // Test default value
- val status = MapStatus(null, sizes)
- assert(status.isInstanceOf[CompressedMapStatus])
- // Test Non-positive values
- for (s <- -1 to 0) {
- assertThrows[IllegalArgumentException] {
- conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s)
- val status = MapStatus(null, sizes)
- }
- }
- // Test positive values
- Seq(1, 100, 499, 500, 501).foreach { s =>
- conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s)
- val status = MapStatus(null, sizes)
- if(sizes.length > s) {
- assert(status.isInstanceOf[HighlyCompressedMapStatus])
- } else {
- assert(status.isInstanceOf[CompressedMapStatus])
- }
- }
- }
}
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index e3d67c34d53eb..687f9e46c3285 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -465,7 +465,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s
Property Name | Default | Meaning |
- spark.yarn.keytab |
+ spark.kerberos.keytab |
(none) |
The full path to the file that contains the keytab for the principal specified above. This keytab
@@ -477,7 +477,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s
|
- spark.yarn.principal |
+ spark.kerberos.principal |
(none) |
Principal to be used to login to KDC, while running on secure clusters. Equivalent to the
diff --git a/docs/sparkr.md b/docs/sparkr.md
index b4248e8bb21de..55e8f15da17ca 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -70,12 +70,12 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s
| --master |
- spark.yarn.keytab |
+ spark.kerberos.keytab |
Application Properties |
--keytab |
- spark.yarn.principal |
+ spark.kerberos.principal |
Application Properties |
--principal |
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index c72fa3d75d67f..6de9de90c62c3 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1004,6 +1004,17 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
+
+ spark.sql.parquet.writeLegacyFormat |
+ false |
+
+    If true, data will be written in the format used by Spark 1.4 and earlier. For example, decimal values
+ will be written in Apache Parquet's fixed-length byte array format, which other systems such as
+ Apache Hive and Apache Impala use. If false, the newer format in Parquet will be used. For
+ example, decimals will be written in int-based format. If Parquet output is intended for use
+ with systems that do not support this newer format, set to true.
+ |
+
## ORC Files
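
A short spark-shell sketch of the option documented above (the output path is a placeholder; `spark` is the session spark-shell provides):

```scala
// Write decimals in the legacy fixed-length byte array layout so that readers
// such as Hive or Impala that expect it can consume the files.
spark.conf.set("spark.sql.parquet.writeLegacyFormat", "true")
spark.sql("SELECT CAST(1.23 AS DECIMAL(10, 2)) AS price, 'book' AS item")
  .write.mode("overwrite").parquet("/tmp/legacy_parquet_example")
```
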
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index ceb9e318b283b..7b1314bc8c3c0 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -134,6 +134,8 @@ private[kafka010] case class InternalKafkaConsumer(
/** Reset the internal pre-fetched data. */
def reset(): Unit = {
_records = ju.Collections.emptyListIterator()
+ _nextOffsetInFetchedData = UNKNOWN_OFFSET
+ _offsetAfterPoll = UNKNOWN_OFFSET
}
/**
@@ -361,8 +363,9 @@ private[kafka010] case class InternalKafkaConsumer(
if (offset < fetchedData.offsetAfterPoll) {
// Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask
// the next call to start from `fetchedData.offsetAfterPoll`.
+ val nextOffsetToFetch = fetchedData.offsetAfterPoll
fetchedData.reset()
- return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll)
+ return fetchedRecord.withRecord(null, nextOffsetToFetch)
} else {
// Fetch records from Kafka and update `fetchedData`.
fetchData(offset, pollTimeoutMs)
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index e5f008804ee5b..39c2cde7de40d 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -874,6 +874,58 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
)
}
}
+
+ test("SPARK-25495: FetchedData.reset should reset all fields") {
+ val topic = newTopic()
+ val topicPartition = new TopicPartition(topic, 0)
+ testUtils.createTopic(topic, partitions = 1)
+
+ val ds = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("kafka.isolation.level", "read_committed")
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .select($"value".as[String])
+
+ testUtils.withTranscationalProducer { producer =>
+ producer.beginTransaction()
+ (0 to 3).foreach { i =>
+ producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+ }
+ producer.commitTransaction()
+ }
+ testUtils.waitUntilOffsetAppears(topicPartition, 5)
+
+ val q = ds.writeStream.foreachBatch { (ds, epochId) =>
+ if (epochId == 0) {
+        // Send more messages before the tasks of the current batch start reading the current batch
+        // data, so that the executors will prefetch messages in the next batch and drop them. In
+        // this case, if we forget to reset `FetchedData._nextOffsetInFetchedData` or
+        // `FetchedData._offsetAfterPoll` (see SPARK-25495), the next batch will see incorrect
+        // values and return wrong results, hence failing the test.
+ testUtils.withTranscationalProducer { producer =>
+ producer.beginTransaction()
+ (4 to 7).foreach { i =>
+ producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+ }
+ producer.commitTransaction()
+ }
+ testUtils.waitUntilOffsetAppears(topicPartition, 10)
+ checkDatasetUnorderly(ds, (0 to 3).map(_.toString): _*)
+ } else {
+ checkDatasetUnorderly(ds, (4 to 7).map(_.toString): _*)
+ }
+ }.start()
+ try {
+ q.processAllAvailable()
+ } finally {
+ q.stop()
+ }
+ }
}
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6da5237d18de4..1c3d9725b285b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2295,7 +2295,9 @@ def to_json(col, options={}):
into a JSON string. Throws an exception, in the case of an unsupported type.
:param col: name of column containing a struct, an array or a map.
- :param options: options to control converting. accepts the same options as the JSON datasource
+    :param options: options to control converting. Accepts the same options as the JSON datasource.
+                    Additionally, the function supports the `pretty` option which enables
+                    pretty JSON generation.
>>> from pyspark.sql import Row
>>> from pyspark.sql.types import *
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 9877d6251388b..263cbc56f4799 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -141,8 +141,9 @@ def enableHiveSupport(self):
return self.config("spark.sql.catalogImplementation", "hive")
def _sparkContext(self, sc):
- self._sc = sc
- return self
+ with self._lock:
+ self._sc = sc
+ return self
@since(2.0)
def getOrCreate(self):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 95e88772ca5de..64a7ceb3fea96 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5846,7 +5846,8 @@ def test_positional_assignment_conf(self):
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
- with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}):
+ with self.sql_conf({
+ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):
@pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
def foo(_):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 974344f01d923..8c59f1f999f18 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -97,8 +97,9 @@ def verify_result_length(*a):
def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
- assign_cols_by_pos = runner_conf.get(
- "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False)
+ assign_cols_by_name = runner_conf.get(
+ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
+ assign_cols_by_name = assign_cols_by_name.lower() == "true"
def wrapped(key_series, value_series):
import pandas as pd
@@ -119,7 +120,7 @@ def wrapped(key_series, value_series):
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
# Assign result columns by schema name if user labeled with strings, else use position
- if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns):
+ if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns):
return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type]
else:
return [(result[result.columns[i]], to_arrow_type(field.dataType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index e511f8064e28a..82692334544e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -290,11 +290,13 @@ object DecimalPrecision extends TypeCoercionRule {
    // potentially losing 11 digits of the fractional part. Using only the precision needed
// by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would
// become DECIMAL(38, 16), safely having a much lower precision loss.
- case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType]
- && l.dataType.isInstanceOf[IntegralType] =>
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
+ l.dataType.isInstanceOf[IntegralType] &&
+ SQLConf.get.literalPickMinimumPrecision =>
b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
- case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType]
- && r.dataType.isInstanceOf[IntegralType] =>
+ case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
+ r.dataType.isInstanceOf[IntegralType] &&
+ SQLConf.get.literalPickMinimumPrecision =>
b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 8b69a47036962..7dafebff79874 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -300,15 +300,6 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
- expression[RegrCount]("regr_count"),
- expression[RegrSXX]("regr_sxx"),
- expression[RegrSYY]("regr_syy"),
- expression[RegrAvgX]("regr_avgx"),
- expression[RegrAvgY]("regr_avgy"),
- expression[RegrSXY]("regr_sxy"),
- expression[RegrSlope]("regr_slope"),
- expression[RegrR2]("regr_r2"),
- expression[RegrIntercept]("regr_intercept"),
// string functions
expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 7420b6b57d8e1..a7e09eee617e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+
protected class AttributeEquals(val a: Attribute) {
override def hashCode(): Int = a match {
@@ -39,10 +41,13 @@ object AttributeSet {
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
def apply(baseSet: Iterable[Expression]): AttributeSet = {
- new AttributeSet(
- baseSet
- .flatMap(_.references)
- .map(new AttributeEquals(_)).toSet)
+ fromAttributeSets(baseSet.map(_.references))
+ }
+
+ /** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */
+ def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = {
+ val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet)
+ new AttributeSet(baseSet.toSet)
}
}
@@ -94,8 +99,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
* Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
* in `other`.
*/
- def --(other: Traversable[NamedExpression]): AttributeSet =
- new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+ def --(other: Traversable[NamedExpression]): AttributeSet = {
+ other match {
+ case otherSet: AttributeSet =>
+ new AttributeSet(baseSet -- otherSet.baseSet)
+ case _ =>
+ new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+ }
+ }
/**
* Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 773aefc0ac1f9..c215735ab1c98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -85,7 +85,7 @@ abstract class Expression extends TreeNode[Expression] {
def nullable: Boolean
- def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
+ def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references))
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: InternalRow = null): Any
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
deleted file mode 100644
index d8f4505588ff2..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
+++ /dev/null
@@ -1,190 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{AbstractDataType, DoubleType}
-
-/**
- * Base trait for all regression functions.
- */
-trait RegrLike extends AggregateFunction with ImplicitCastInputTypes {
- def y: Expression
- def x: Expression
-
- override def children: Seq[Expression] = Seq(y, x)
- override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
-
- protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = {
- assert(aggBufferAttributes.length == exprs.length)
- val nullableChildren = children.filter(_.nullable)
- if (nullableChildren.isEmpty) {
- exprs
- } else {
- exprs.zip(aggBufferAttributes).map { case (e, a) =>
- If(nullableChildren.map(IsNull).reduce(Or), a, e)
- }
- }
- }
-}
-
-
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the number of non-null pairs.",
- since = "2.4.0")
-case class RegrCount(y: Expression, x: Expression)
- extends CountLike with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L))
-
- override def prettyName: String = "regr_count"
-}
-
-
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.",
- since = "2.4.0")
-case class RegrSXX(y: Expression, x: Expression)
- extends CentralMomentAgg(x) with RegrLike {
-
- override protected def momentOrder = 2
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
- }
-
- override def prettyName: String = "regr_sxx"
-}
-
-
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.",
- since = "2.4.0")
-case class RegrSYY(y: Expression, x: Expression)
- extends CentralMomentAgg(y) with RegrLike {
-
- override protected def momentOrder = 2
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
- }
-
- override def prettyName: String = "regr_syy"
-}
-
-
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.",
- since = "2.4.0")
-case class RegrAvgX(y: Expression, x: Expression)
- extends AverageLike(x) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override def prettyName: String = "regr_avgx"
-}
-
-
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.",
- since = "2.4.0")
-case class RegrAvgY(y: Expression, x: Expression)
- extends AverageLike(y) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override def prettyName: String = "regr_avgy"
-}
-
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.",
- since = "2.4.0")
-// scalastyle:on line.size.limit
-case class RegrSXY(y: Expression, x: Expression)
- extends Covariance(y, x) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n === Literal(0.0), Literal.create(null, DoubleType), ck)
- }
-
- override def prettyName: String = "regr_sxy"
-}
-
-
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.",
- since = "2.4.0")
-// scalastyle:on line.size.limit
-case class RegrSlope(y: Expression, x: Expression)
- extends PearsonCorrelation(y, x) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk)
- }
-
- override def prettyName: String = "regr_slope"
-}
-
-
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.",
- since = "2.4.0")
-// scalastyle:on line.size.limit
-case class RegrR2(y: Expression, x: Expression)
- extends PearsonCorrelation(y, x) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
- If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk))
- }
-
- override def prettyName: String = "regr_r2"
-}
-
-
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.",
- since = "2.4.0")
-// scalastyle:on line.size.limit
-case class RegrIntercept(y: Expression, x: Expression)
- extends PearsonCorrelation(y, x) with RegrLike {
-
- override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
-
- override val evaluateExpression: Expression = {
- If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
- xAvg - (ck / yMk) * yAvg)
- }
-
- override def prettyName: String = "regr_intercept"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 85bc1cdb43051..9cc7dbadd923a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3088,11 +3088,24 @@ case class ArrayRemove(left: Expression, right: Expression)
override def dataType: DataType = left.dataType
override def inputTypes: Seq[AbstractDataType] = {
- val elementType = left.dataType match {
- case t: ArrayType => t.elementType
- case _ => AnyDataType
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, hasNull), e2) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+ TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+ case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+ s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
+ s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
}
- Seq(ArrayType, elementType)
}
private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
@@ -3100,14 +3113,6 @@ case class ArrayRemove(left: Expression, right: Expression)
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
- override def checkInputDataTypes(): TypeCheckResult = {
- super.checkInputDataTypes() match {
- case f: TypeCheckResult.TypeCheckFailure => f
- case TypeCheckResult.TypeCheckSuccess =>
- TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
- }
- }
-
override def nullSafeEval(arr: Any, value: Any): Any = {
val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements())
var pos = 0
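
With the rewritten inputTypes/checkInputDataTypes, array_remove should now coerce the array and the value to their tightest common type instead of rejecting the call, while incompatible types get the clearer error message above. A small spark-shell sketch, assuming a build with this change:

```scala
// The int array is widened to array<bigint> so the bigint value can be removed.
spark.sql("SELECT array_remove(array(1, 2, 3), 2L)").show()

// Incompatible element/value types now fail analysis with the new message:
// spark.sql("SELECT array_remove(array(1, 2, 3), 'b')")  // AnalysisException
```
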
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 47eeb70e00427..64152e04928d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -113,6 +113,11 @@ private[sql] class JSONOptions(
}
val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n")
+ /**
+   * Whether to generate JSON strings in a pretty (indented) representation.
+ */
+ val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false)
+
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 9b86d865622dc..d02a2be8ddad6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -70,7 +70,10 @@ private[sql] class JacksonGenerator(
s"Initial type ${dataType.catalogString} must be a ${MapType.simpleString}")
}
- private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ private val gen = {
+ val generator = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ if (options.pretty) generator.useDefaultPrettyPrinter() else generator
+ }
private val lineSeparator: String = options.lineSeparatorInWrite
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7c461895c5e52..07a653f3b5d48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -532,12 +532,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
// Prunes the unused columns from project list of Project/Aggregate/Expand
- case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+ case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
- case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+ case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) =>
p.copy(
child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
- case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+ case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) =>
val newOutput = e.output.filter(a.references.contains(_))
val newProjects = e.projections.map { proj =>
proj.zip(e.output).filter { case (_, a) =>
@@ -547,18 +547,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
a.copy(child = Expand(newProjects, newOutput, grandChild))
// Prunes the unused columns from child of `DeserializeToObject`
- case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty =>
+ case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
d.copy(child = prunedChild(child, d.references))
// Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
- case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
+ case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) =>
a.copy(child = prunedChild(child, a.references))
- case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty =>
+ case f @ FlatMapGroupsInPandas(_, _, _, child) if !child.outputSet.subsetOf(f.references) =>
f.copy(child = prunedChild(child, f.references))
- case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+ case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) =>
e.copy(child = prunedChild(child, e.references))
case s @ ScriptTransformation(_, _, _, child, _)
- if (child.outputSet -- s.references).nonEmpty =>
+ if !child.outputSet.subsetOf(s.references) =>
s.copy(child = prunedChild(child, s.references))
// prune unrequired references
@@ -579,7 +579,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, _: Distinct) => p
// Eliminate unneeded attributes from children of Union.
case p @ Project(_, u: Union) =>
- if ((u.outputSet -- p.references).nonEmpty) {
+ if (!u.outputSet.subsetOf(p.references)) {
val firstChild = u.children.head
val newOutput = prunedChild(firstChild, p.references).output
// pruning the columns of all children based on the pruned first child.
@@ -595,7 +595,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
// Prune unnecessary window expressions
- case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
+ case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
p.copy(child = w.copy(
windowExpressions = w.windowExpressions.filter(p.references.contains)))
@@ -611,7 +611,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
// for all other logical plans that inherits the output from it's children
case p @ Project(_, child) =>
val required = child.references ++ p.references
- if ((child.inputSet -- required).nonEmpty) {
+ if (!child.inputSet.subsetOf(required)) {
val newChildren = child.children.map(c => prunedChild(c, required))
p.copy(child = child.withNewChildren(newChildren))
} else {
@@ -621,7 +621,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
- if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
+ if (!c.outputSet.subsetOf(allReferences)) {
Project(c.output.filter(allReferences.contains), c)
} else {
c
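
The rewrite relies on the identity that `(a -- b).nonEmpty` is the negation of `a.subsetOf(b)`, while `subsetOf` avoids materialising the difference set. Illustrated with plain Scala sets:

```scala
val output = Set("a", "b", "c") // stand-in for an operator's outputSet
val referenced = Set("a", "b")  // stand-in for the parent's references

// Same truth value, but subsetOf can short-circuit without building a new set.
assert((output -- referenced).nonEmpty == !output.subsetOf(referenced))
assert((referenced -- output).isEmpty == referenced.subsetOf(output))
```
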
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b1ffdca091461..ca0cea6ba7de3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -42,7 +42,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* All Attributes that appear in expressions from this operator. Note that this set does not
* include attributes that are implicitly referenced by being passed through to the output tuple.
*/
- def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))
+ def references: AttributeSet = AttributeSet.fromAttributeSets(expressions.map(_.references))
/**
* The set of all attributes that are input to this operator by its children.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
index bb2c5926ae9bb..288a4f34a447e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
@@ -42,7 +42,11 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma
override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator
override def -(key: String): Map[String, T] = {
- new CaseInsensitiveMap(originalMap.filterKeys(!_.equalsIgnoreCase(key)))
+ new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(key)))
+ }
+
+ override def filterKeys(p: (String) => Boolean): Map[String, T] = {
+ new CaseInsensitiveMap(originalMap.filter(kv => p(kv._1)))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0e0a01def357e..f6c98805bfb15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -451,8 +451,11 @@ object SQLConf {
.createWithDefault(10)
val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
- .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " +
- "versions, when converting Parquet schema to Spark SQL schema and vice versa.")
+    .doc("If true, data will be written in the format used by Spark 1.4 and earlier. For example, decimal " +
+ "values will be written in Apache Parquet's fixed-length byte array format, which other " +
+ "systems such as Apache Hive and Apache Impala use. If false, the newer format in Parquet " +
+ "will be used. For example, decimals will be written in int-based format. If Parquet " +
+ "output is intended for use with systems that do not support this newer format, set to true.")
.booleanConf
.createWithDefault(false)
@@ -1295,15 +1298,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
- val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION =
- buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition")
+ val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME =
+ buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName")
.internal()
- .doc("When true, a grouped map Pandas UDF will assign columns from the returned " +
- "Pandas DataFrame based on position, regardless of column label type. When false, " +
- "columns will be looked up by name if labeled with a string and fallback to use " +
- "position if not. This configuration will be deprecated in future releases.")
+ .doc("When true, columns will be looked up by name if labeled with a string and fallback " +
+ "to use position if not. When false, a grouped map Pandas UDF will assign columns from " +
+ "the returned Pandas DataFrame based on position, regardless of column label type. " +
+ "This configuration will be deprecated in future releases.")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
.internal()
@@ -1328,6 +1331,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val LITERAL_PICK_MINIMUM_PRECISION =
+ buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
+ .internal()
+ .doc("When integral literal is used in decimal operations, pick a minimum precision " +
+ "required by the literal if this config is true, to make the resulting precision and/or " +
+ "scale smaller. This can reduce the possibility of precision lose and/or overflow.")
+ .booleanConf
+ .createWithDefault(true)
+
val SQL_OPTIONS_REDACTION_PATTERN =
buildConf("spark.sql.redaction.options.regex")
.doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " +
@@ -1915,13 +1927,15 @@ class SQLConf extends Serializable with Logging {
def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE)
- def pandasGroupedMapAssignColumnssByPosition: Boolean =
- getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION)
+ def pandasGroupedMapAssignColumnsByName: Boolean =
+ getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
+ def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
+
def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)
def continuousStreamingExecutorPollIntervalMs: Long =
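As a hedged usage sketch (assuming a spark-shell session named `spark`), the flag introduced above can be toggled at runtime like any other SQL conf; the SPARK-25454 test further down sets it to false through withSQLConf instead.

  // spark-shell sketch; the conf key is taken from the diff above.
  spark.conf.set("spark.sql.legacy.literal.pickMinimumPrecision", "false")
  spark.sql("select 26393499451 / (1e6 * 1000)").show(truncate = false)
  spark.conf.unset("spark.sql.legacy.literal.pickMinimumPrecision")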
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 67740c3166471..3081ff935f043 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -22,7 +22,6 @@ import org.scalatest.Suite
import org.scalatest.Tag
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
@@ -57,7 +56,7 @@ trait CodegenInterpretedPlanTest extends PlanTest {
* Provides helper methods for comparing plans, but without the overhead of
* mandating a FunSuite.
*/
-trait PlanTestBase extends PredicateHelper { self: Suite =>
+trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
// TODO(gatorsmile): remove this from PlanTest and all the analyzer rules
protected def conf = SQLConf.get
@@ -174,32 +173,4 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
plan1 == plan2
}
}
-
- /**
- * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
- * configurations.
- */
- protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
- val conf = SQLConf.get
- val (keys, values) = pairs.unzip
- val currentValues = keys.map { key =>
- if (conf.contains(key)) {
- Some(conf.getConfString(key))
- } else {
- None
- }
- }
- (keys, values).zipped.foreach { (k, v) =>
- if (SQLConf.staticConfKeys.contains(k)) {
- throw new AnalysisException(s"Cannot modify the value of a static config: $k")
- }
- conf.setConfString(k, v)
- }
- try f finally {
- keys.zip(currentValues).foreach {
- case (key, Some(value)) => conf.setConfString(key, value)
- case (key, None) => conf.unsetConf(key)
- }
- }
- }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala
new file mode 100644
index 0000000000000..4d869d79ad594
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans
+
+import java.io.File
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+trait SQLHelper {
+
+ /**
+ * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+ * configurations.
+ */
+ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val conf = SQLConf.get
+ val (keys, values) = pairs.unzip
+ val currentValues = keys.map { key =>
+ if (conf.contains(key)) {
+ Some(conf.getConfString(key))
+ } else {
+ None
+ }
+ }
+ (keys, values).zipped.foreach { (k, v) =>
+ if (SQLConf.staticConfKeys.contains(k)) {
+ throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+ }
+ conf.setConfString(k, v)
+ }
+ try f finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => conf.setConfString(key, value)
+ case (key, None) => conf.unsetConf(key)
+ }
+ }
+ }
+
+ /**
+ * Generates a temporary path without creating the actual file/directory, then passes it to `f`. If
+ * a file/directory is created there by `f`, it will be deleted after `f` returns.
+ */
+ protected def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path) finally Utils.deleteRecursively(path)
+ }
+}
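A minimal sketch (object name and workload hypothetical) of how a standalone benchmark object in Spark's test sources can now mix in this shared trait instead of re-declaring withSQLConf/withTempPath locally, which is what the benchmark diffs below do.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.catalyst.plans.SQLHelper
  import org.apache.spark.sql.internal.SQLConf

  object MyLocalBenchmark extends SQLHelper {
    private val spark = SparkSession.builder()
      .master("local[1]")
      .appName("MyLocalBenchmark")
      .getOrCreate()

    def main(args: Array[String]): Unit = {
      withTempPath { dir =>
        // Disable the vectorized Parquet reader for this block only; restored afterwards.
        withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
          spark.range(1000).write.parquet(dir.getCanonicalPath)
          println(spark.read.parquet(dir.getCanonicalPath).count())
        }
      }
    }
  }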
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala
new file mode 100644
index 0000000000000..03eed4aaa750b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
+
+class CaseInsensitiveMapSuite extends SparkFunSuite {
+ private def shouldBeSerializable(m: Map[String, String]): Unit = {
+ new JavaSerializer(new SparkConf()).newInstance().serialize(m)
+ }
+
+ test("Keys are case insensitive") {
+ val m = CaseInsensitiveMap(Map("a" -> "b", "foO" -> "bar"))
+ assert(m("FOO") == "bar")
+ assert(m("fOo") == "bar")
+ assert(m("A") == "b")
+ shouldBeSerializable(m)
+ }
+
+ test("CaseInsensitiveMap should be serializable after '-' operator") {
+ val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) - "a"
+ assert(m == Map("foo" -> "bar"))
+ shouldBeSerializable(m)
+ }
+
+ test("CaseInsensitiveMap should be serializable after '+' operator") {
+ val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) + ("x" -> "y")
+ assert(m == Map("a" -> "b", "foo" -> "bar", "x" -> "y"))
+ shouldBeSerializable(m)
+ }
+
+ test("CaseInsensitiveMap should be serializable after 'filterKeys' method") {
+ val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")).filterKeys(_ == "foo")
+ assert(m == Map("foo" -> "bar"))
+ shouldBeSerializable(m)
+ }
+}
diff --git a/sql/core/benchmarks/SortBenchmark-results.txt b/sql/core/benchmarks/SortBenchmark-results.txt
new file mode 100644
index 0000000000000..0d00a0c89d02d
--- /dev/null
+++ b/sql/core/benchmarks/SortBenchmark-results.txt
@@ -0,0 +1,17 @@
+================================================================================================
+radix sort
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_162-b12 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+reference TimSort key prefix array 11770 / 11960 2.1 470.8 1.0X
+reference Arrays.sort 2106 / 2128 11.9 84.3 5.6X
+radix sort one byte 93 / 100 269.7 3.7 126.9X
+radix sort two bytes 171 / 179 146.0 6.9 68.7X
+radix sort eight bytes 659 / 664 37.9 26.4 17.9X
+radix sort key prefix array 1024 / 1053 24.4 41.0 11.5X
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
index 533097ac399e9..b1e8fb39ac9de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
@@ -131,11 +131,8 @@ object ArrowUtils {
} else {
Nil
}
- val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) {
- Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true")
- } else {
- Nil
- }
- Map(timeZoneConf ++ pandasColsByPosition: _*)
+ val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
+ conf.pandasGroupedMapAssignColumnsByName.toString)
+ Map(timeZoneConf ++ pandasColsByName: _*)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 10b67d7a1ca54..4c58e77df485e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3619,6 +3619,8 @@ object functions {
* @param e a column containing a struct, an array or a map.
* @param options options to control how the struct column is converted into a json string.
* accepts the same options as the JSON data source.
+ * Additionally the function supports the `pretty` option which enables
+ * pretty JSON generation.
*
* @group collection_funcs
* @since 2.1.0
@@ -3635,6 +3637,8 @@ object functions {
* @param e a column containing a struct, an array or a map.
* @param options options to control how the struct column is converted into a json string.
* accepts the same options as the JSON data source.
+ * Additionally the function supports the `pretty` option which enables
+ * pretty JSON generation.
*
* @group collection_funcs
* @since 2.1.0
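A hedged Scala usage sketch of the documented `pretty` option (assumes a spark-shell session named `spark`; the column names are arbitrary):

  import org.apache.spark.sql.functions.{lit, struct, to_json}
  import spark.implicits._

  val df = spark.range(1).select(struct(lit(1).as("a"), lit("b").as("b")).as("s"))
  df.select(to_json($"s", Map("pretty" -> "true")).as("json")).show(truncate = false)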
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
deleted file mode 100644
index 92c7e26e3add2..0000000000000
--- a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
+++ /dev/null
@@ -1,56 +0,0 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
- (101, 1, 1, 1),
- (201, 2, 1, 1),
- (301, 3, 1, 1),
- (401, 4, 1, 11),
- (501, 5, 1, null),
- (601, 6, null, 1),
- (701, 6, null, null),
- (102, 1, 2, 2),
- (202, 2, 1, 2),
- (302, 3, 2, 1),
- (402, 4, 2, 12),
- (502, 5, 2, null),
- (602, 6, null, 2),
- (702, 6, null, null),
- (103, 1, 3, 3),
- (203, 2, 1, 3),
- (303, 3, 3, 1),
- (403, 4, 3, 13),
- (503, 5, 3, null),
- (603, 6, null, 3),
- (703, 6, null, null),
- (104, 1, 4, 4),
- (204, 2, 1, 4),
- (304, 3, 4, 1),
- (404, 4, 4, 14),
- (504, 5, 4, null),
- (604, 6, null, 4),
- (704, 6, null, null),
- (800, 7, 1, 1)
-as t1(id, px, y, x);
-
-select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
- regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
- regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
-from t1 group by px order by px;
-
-
-select id, regr_count(y,x) over (partition by px) from t1 order by id;
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
deleted file mode 100644
index d7d009a64bf84..0000000000000
--- a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
+++ /dev/null
@@ -1,93 +0,0 @@
--- Automatically generated by SQLQueryTestSuite
--- Number of queries: 3
-
-
--- !query 0
-CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
- (101, 1, 1, 1),
- (201, 2, 1, 1),
- (301, 3, 1, 1),
- (401, 4, 1, 11),
- (501, 5, 1, null),
- (601, 6, null, 1),
- (701, 6, null, null),
- (102, 1, 2, 2),
- (202, 2, 1, 2),
- (302, 3, 2, 1),
- (402, 4, 2, 12),
- (502, 5, 2, null),
- (602, 6, null, 2),
- (702, 6, null, null),
- (103, 1, 3, 3),
- (203, 2, 1, 3),
- (303, 3, 3, 1),
- (403, 4, 3, 13),
- (503, 5, 3, null),
- (603, 6, null, 3),
- (703, 6, null, null),
- (104, 1, 4, 4),
- (204, 2, 1, 4),
- (304, 3, 4, 1),
- (404, 4, 4, 14),
- (504, 5, 4, null),
- (604, 6, null, 4),
- (704, 6, null, null),
- (800, 7, 1, 1)
-as t1(id, px, y, x)
--- !query 0 schema
-struct<>
--- !query 0 output
-
-
-
--- !query 1
-select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
- regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
- regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
-from t1 group by px order by px
--- !query 1 schema
-struct
--- !query 1 output
-1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4
-2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4
-3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4
-4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4
-5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0
-6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0
-7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1
-
-
--- !query 2
-select id, regr_count(y,x) over (partition by px) from t1 order by id
--- !query 2 schema
-struct
--- !query 2 output
-101 4
-102 4
-103 4
-104 4
-201 4
-202 4
-203 4
-204 4
-301 4
-302 4
-303 4
-304 4
-401 4
-402 4
-403 4
-404 4
-501 0
-502 0
-503 0
-504 0
-601 0
-602 0
-603 0
-604 0
-701 0
-702 0
-703 0
-704 0
-800 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index fd71f24935611..88dbae8c21350 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1574,6 +1574,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null))
)
+ checkAnswer(
+ OneRowRelation().selectExpr("array_remove(array(1, 2), 1.23D)"),
+ Seq(
+ Row(Seq(1.0, 2.0))
+ )
+ )
+
+ checkAnswer(
+ OneRowRelation().selectExpr("array_remove(array(1, 2), 1.0D)"),
+ Seq(
+ Row(Seq(2.0))
+ )
+ )
+
+ checkAnswer(
+ OneRowRelation().selectExpr("array_remove(array(1.0D, 2.0D), 2)"),
+ Seq(
+ Row(Seq(1.0))
+ )
+ )
+
+ checkAnswer(
+ OneRowRelation().selectExpr("array_remove(array(1.1D, 1.2D), 1)"),
+ Seq(
+ Row(Seq(1.1, 1.2))
+ )
+ )
+
checkAnswer(
df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")",
"array_remove(c, \"\")"),
@@ -1583,10 +1611,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null))
)
- val e = intercept[AnalysisException] {
+ val e1 = intercept[AnalysisException] {
Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)")
}
- assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
+ val errorMsg1 =
+ s"""
+ |Input to function array_remove should have been array followed by a
+ |value with same element type, but it's [string, string].
+ """.stripMargin.replace("\n", " ").trim()
+ assert(e1.message.contains(errorMsg1))
+
+ val e2 = intercept[AnalysisException] {
+ OneRowRelation().selectExpr("array_remove(array(1, 2), '1')")
+ }
+
+ val errorMsg2 =
+ s"""
+ |Input to function array_remove should have been array followed by a
+ |value with same element type, but it's [array, string].
+ """.stripMargin.replace("\n", " ").trim()
+ assert(e2.message.contains(errorMsg2))
}
test("array_distinct functions") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index fe4bf15fa3921..853bc182f2f4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -518,4 +518,25 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
jsonDF.select(to_json(from_json($"a", schema))),
Seq(Row(json)))
}
+
+ test("pretty print - roundtrip from_json -> to_json") {
+ val json = """[{"book":{"publisher":[{"country":"NL","year":[1981,1986,1999]}]}}]"""
+ val jsonDF = Seq(json).toDF("root")
+ val expected =
+ """[ {
+ | "book" : {
+ | "publisher" : [ {
+ | "country" : "NL",
+ | "year" : [ 1981, 1986, 1999 ]
+ | } ]
+ | }
+ |} ]""".stripMargin
+
+ checkAnswer(
+ jsonDF.select(
+ to_json(
+ from_json($"root", schema_of_json(lit(json))),
+ Map("pretty" -> "true"))),
+ Seq(Row(expected)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8fcebb35a0543..631ab1b7ece7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2849,6 +2849,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val result = ds.flatMap(_.bar).distinct
result.rdd.isEmpty
}
+
+ test("SPARK-25454: decimal division with negative scale") {
+ // TODO: completely fix this issue even when LITERAL_PICK_MINIMUM_PRECISION is true.
+ withSQLConf(SQLConf.LITERAL_PICK_MINIMUM_PRECISION.key -> "false") {
+ checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
+ }
+ }
}
case class Foo(bar: Option[String])
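As a plain-BigDecimal sanity check of the value asserted above (this only mirrors the arithmetic, not Spark's decimal precision and scale inference):

  val expected = BigDecimal("26393499451") / (BigDecimal("1e6") * BigDecimal("1000"))
  // scala.math.BigDecimal equality ignores scale, so the padded zeros do not matter.
  assert(expected == BigDecimal("26.3934994510000"))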
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
index cf9bda2fb1ff1..51a7f9f1ef096 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql.execution.benchmark
import java.io.File
import scala.collection.JavaConverters._
-import scala.util.{Random, Try}
+import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.spark.util.Utils
/**
@@ -37,7 +37,7 @@ import org.apache.spark.util.Utils
* To run this:
* spark-submit --class <this class> <spark sql test jar>
*/
-object DataSourceReadBenchmark {
+object DataSourceReadBenchmark extends SQLHelper {
val conf = new SparkConf()
.setAppName("DataSourceReadBenchmark")
// Since `spark.master` always exists, overrides this value
@@ -54,27 +54,10 @@ object DataSourceReadBenchmark {
spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
- def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
- def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
- val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
- (keys, values).zipped.foreach(spark.conf.set)
- try f finally {
- keys.zip(currentValues).foreach {
- case (key, Some(value)) => spark.conf.set(key, value)
- case (key, None) => spark.conf.unset(key)
- }
- }
- }
private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
val testDf = if (partition.isDefined) {
df.write.partitionBy(partition.get)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index 3b7f10783b64c..7cdf653e38697 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.benchmark
import java.io.File
-import scala.util.{Random, Try}
+import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.functions.monotonically_increasing_id
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType}
-import org.apache.spark.util.Utils
/**
* Benchmark to measure read performance with Filter pushdown.
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
* Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt".
* }}}
*/
-object FilterPushdownBenchmark extends BenchmarkBase {
+object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper {
private val conf = new SparkConf()
.setAppName(this.getClass.getSimpleName)
@@ -60,28 +60,10 @@ object FilterPushdownBenchmark extends BenchmarkBase {
private val spark = SparkSession.builder().config(conf).getOrCreate()
- def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
- def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
- val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
- (keys, values).zipped.foreach(spark.conf.set)
- try f finally {
- keys.zip(currentValues).foreach {
- case (key, Some(value)) => spark.conf.set(key, value)
- case (key, None) => spark.conf.unset(key)
- }
- }
- }
-
private def prepareTable(
dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
import spark.implicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 17619ec5fadc1..958a064402149 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.benchmark
import java.util.{Arrays, Comparator}
-import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.unsafe.array.LongArray
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.util.collection.Sorter
@@ -28,12 +28,15 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* Benchmark to measure performance for sort primitives.
- * To run this:
- * build/sbt "sql/test-only *benchmark.SortBenchmark"
- *
- * Benchmarks in this file are skipped in normal builds.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class <this class> <spark sql test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/<this class>-results.txt".
+ * }}}
*/
-class SortBenchmark extends BenchmarkWithCodegen {
+object SortBenchmark extends BenchmarkBase {
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
@@ -54,10 +57,10 @@ class SortBenchmark extends BenchmarkWithCodegen {
new LongArray(MemoryBlock.fromLongArray(extended)))
}
- ignore("sort") {
+ def sortBenchmark(): Unit = {
val size = 25000000
val rand = new XORShiftRandom(123)
- val benchmark = new Benchmark("radix sort " + size, size)
+ val benchmark = new Benchmark("radix sort " + size, size, output = output)
benchmark.addTimerCase("reference TimSort key prefix array") { timer =>
val array = Array.tabulate[Long](size * 2) { i => rand.nextLong }
val buf = new LongArray(MemoryBlock.fromLongArray(array))
@@ -114,20 +117,11 @@ class SortBenchmark extends BenchmarkWithCodegen {
timer.stopTiming()
}
benchmark.run()
+ }
- /*
- Running benchmark: radix sort 25000000
- Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic
- Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
-
- radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
- -------------------------------------------------------------------------------------------
- reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X
- reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X
- radix sort one byte 133 / 137 188.4 5.3 117.2X
- radix sort two bytes 255 / 258 98.2 10.2 61.1X
- radix sort eight bytes 991 / 997 25.2 39.6 15.7X
- radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X
- */
+ override def benchmark(): Unit = {
+ runBenchmark("radix sort") {
+ sortBenchmark()
+ }
}
}
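For reference, a minimal sketch of the refactored pattern used above (the `benchmark()`, `runBenchmark`, and `output` names are taken from this diff; the object and the timed workload are hypothetical):

  import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}

  object MyMicroBenchmark extends BenchmarkBase {
    override def benchmark(): Unit = {
      runBenchmark("string building") {
        val n = 100000
        // Threading `output` through lets SPARK_GENERATE_BENCHMARK_FILES=1 write a results file.
        val bm = new Benchmark("string building " + n, n, output = output)
        bm.addCase("StringBuilder") { _ =>
          val sb = new StringBuilder
          var i = 0
          while (i < n) { sb.append(i); i += 1 }
          sb.toString
        }
        bm.run()
      }
    }
  }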
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
index 6d319eb723d93..5d1a874999c09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -16,21 +16,19 @@
*/
package org.apache.spark.sql.execution.datasources.csv
-import java.io.File
-
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
/**
* Benchmark to measure CSV read/write performance.
* To run this:
* spark-submit --class <this class> --jars <spark core test jar> <spark sql test jar>
*/
-object CSVBenchmarks {
+object CSVBenchmarks extends SQLHelper {
val conf = new SparkConf()
val spark = SparkSession.builder
@@ -40,12 +38,6 @@ object CSVBenchmarks {
.getOrCreate()
import spark.implicits._
- def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = {
val benchmark = new Benchmark(s"Parsing quoted values", rowsNum)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
index e40cb9b50148b..368318ab38cb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -21,16 +21,16 @@ import java.io.File
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
/**
* The benchmark aims to measure performance of JSON parsing when the encoding is set and when it isn't.
* To run this:
* spark-submit --class <this class> --jars <spark core test jar> <spark sql test jar>
*/
-object JSONBenchmarks {
+object JSONBenchmarks extends SQLHelper {
val conf = new SparkConf()
val spark = SparkSession.builder
@@ -40,13 +40,6 @@ object JSONBenchmarks {
.getOrCreate()
import spark.implicits._
- def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
-
def schemaInferring(rowsNum: Int): Unit = {
val benchmark = new Benchmark("JSON schema inferring", rowsNum)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
index fe59cb25d5005..cbac1c13cdd33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
@@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.Utils
-abstract class CheckpointFileManagerTests extends SparkFunSuite {
+abstract class CheckpointFileManagerTests extends SparkFunSuite with SQLHelper {
def createManager(path: Path): CheckpointFileManager
@@ -88,12 +88,6 @@ abstract class CheckpointFileManagerTests extends SparkFunSuite {
fm.delete(path) // should not throw exception
}
}
-
- protected def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
}
class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 2fb8f70a20791..6b03d1e5b7662 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.UninterruptibleThread
import org.apache.spark.util.Utils
@@ -167,18 +166,6 @@ private[sql] trait SQLTestUtilsBase
super.withSQLConf(pairs: _*)(f)
}
- /**
- * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
- * a file/directory is created there by `f`, it will be delete after `f` returns.
- *
- * @todo Probably this method should be moved to a more general place
- */
- protected def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
/**
* Copy file in jar's resource to a temp file, then pass it to `f`.
* This function is used to make `f` can use the path of temp file(e.g. file:/), instead of
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
index 0eab7d1ea8e80..49de007df3828 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
@@ -19,14 +19,15 @@ package org.apache.spark.sql.hive.orc
import java.io.File
-import scala.util.{Random, Try}
+import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
+
/**
* Benchmark to measure ORC read performance.
@@ -34,7 +35,7 @@ import org.apache.spark.util.Utils
* This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources.
*/
// scalastyle:off line.size.limit
-object OrcReadBenchmark {
+object OrcReadBenchmark extends SQLHelper {
val conf = new SparkConf()
conf.set("orc.compression", "snappy")
@@ -47,28 +48,10 @@ object OrcReadBenchmark {
// Set default configs. Individual cases will change them if necessary.
spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true")
- def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
- def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
- val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
- (keys, values).zipped.foreach(spark.conf.set)
- try f finally {
- keys.zip(currentValues).foreach {
- case (key, Some(value)) => spark.conf.set(key, value)
- case (key, None) => spark.conf.unset(key)
- }
- }
- }
-
private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName
private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index a882558551e37..135430f1ef621 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -59,6 +59,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time)
"spark.yarn.jars",
"spark.yarn.keytab",
"spark.yarn.principal",
+ "spark.kerberos.keytab",
+ "spark.kerberos.principal",
"spark.ui.filters",
"spark.mesos.driver.frameworkId")