diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1cb4781f4d0a..34691883bc5a9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2954,6 +2954,9 @@ setMethod("exceptAll", #' @param source a name for external data source. #' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' #' save mode (it is 'error' by default) +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar +#' to Hive's partitioning scheme. #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2965,13 +2968,13 @@ setMethod("exceptAll", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' write.df(df, "myfile", "parquet", "overwrite") +#' write.df(df, "myfile", "parquet", "overwrite", partitionBy = c("col1", "col2")) #' saveDF(df, parquetPath2, "parquet", mode = "append", mergeSchema = TRUE) #' } #' @note write.df since 1.4.0 setMethod("write.df", signature(df = "SparkDataFrame"), - function(df, path = NULL, source = NULL, mode = "error", ...) { + function(df, path = NULL, source = NULL, mode = "error", partitionBy = NULL, ...) { if (!is.null(path) && !is.character(path)) { stop("path should be character, NULL or omitted.") } @@ -2985,8 +2988,18 @@ setMethod("write.df", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) is.character(c)))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } write <- setWriteOptions(write, path = path, mode = mode, ...) write <- handledCallJMethod(write, "save") }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 572dee50127b8..6425c9d26bef3 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -198,8 +198,9 @@ NULL #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. In \code{arrays_zip}, this contains additional -#' Columns of arrays to be merged. +#' options as the JSON data source. Additionally \code{to_json} supports the "pretty" +#' option which enables pretty JSON generation. In \code{arrays_zip}, this contains +#' additional Columns of arrays to be merged. 
#' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d3a9cbae7d808..038fefadaaeff 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -626,6 +626,8 @@ sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-pat sparkConfToSubmitOps[["spark.master"]] <- "--master" sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab" sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal" +sparkConfToSubmitOps[["spark.kerberos.keytab"]] <- "--keytab" +sparkConfToSubmitOps[["spark.kerberos.principal"]] <- "--principal" # Utility function that returns Spark Submit arguments as a string diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a874bfbb58dc7..50eff3755edf8 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2701,8 +2701,16 @@ test_that("read/write text files", { expect_equal(colnames(df2), c("value")) expect_equal(count(df2), count(df) * 2) + df3 <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), + schema = c("key", "value")) + textPath3 <- tempfile(pattern = "textPath3", fileext = ".txt") + write.df(df3, textPath3, "text", mode = "overwrite", partitionBy = "key") + df4 <- read.df(textPath3, "text") + expect_equal(count(df3), count(df4)) + unlink(textPath) unlink(textPath2) + unlink(textPath3) }) test_that("read/write text files - compression option", { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 090363c5f8a3e..ad934947437bc 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -157,8 +157,8 @@ Property Name | Property group | spark-submit equivalent `spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path` `spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options` `spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path` -`spark.yarn.keytab` | Application Properties | `--keytab` -`spark.yarn.principal` | Application Properties | `--principal` +`spark.kerberos.keytab` | Application Properties | `--keytab` +`spark.kerberos.principal` | Application Properties | `--principal` **For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. 
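As a side note on the `to_json` documentation change above: the new "pretty" option is read by `JSONOptions` and applied in `JacksonGenerator` (both changed later in this diff), so it is also reachable from the Scala API. A minimal, illustrative sketch only — the DataFrame `df` and its `id`/`name` columns are assumptions, not part of this patch:

    import org.apache.spark.sql.functions.{col, struct, to_json}

    // "pretty" -> "true" makes JacksonGenerator use Jackson's default pretty printer.
    val prettyJson = df.select(
      to_json(struct(col("id"), col("name")), Map("pretty" -> "true")).as("json"))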
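The `spark.yarn.keytab`/`spark.yarn.principal` properties renamed above become the cluster-manager-agnostic `spark.kerberos.*` keys handled in `SparkConf` and `SparkSubmit` below. A hedged sketch of setting the renamed keys (the principal and keytab path are made up; per the `SparkConf` change, the old `spark.yarn.*` names keep working as deprecated alternates):

    import org.apache.spark.SparkConf

    // Equivalent to passing `--principal` / `--keytab` to spark-submit,
    // which SparkSubmit now maps to these keys for every cluster manager.
    val conf = new SparkConf()
      .set("spark.kerberos.principal", "alice@EXAMPLE.COM")    // made-up principal
      .set("spark.kerberos.keytab", "/path/to/alice.keytab")   // made-up keytab path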
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6c4c5c94cfa28..e0f98f1aca071 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -726,7 +726,11 @@ private[spark] object SparkConf extends Logging { DRIVER_MEMORY_OVERHEAD.key -> Seq( AlternateConfig("spark.yarn.driver.memoryOverhead", "2.3")), EXECUTOR_MEMORY_OVERHEAD.key -> Seq( - AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")) + AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")), + KEYTAB.key -> Seq( + AlternateConfig("spark.yarn.keytab", "2.5")), + PRINCIPAL.key -> Seq( + AlternateConfig("spark.yarn.principal", "2.5")) ) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cf902db8709e7..d5f2865f87281 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -520,6 +520,10 @@ private[spark] class SparkSubmit extends Logging { confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.driver.extraLibraryPath"), + OptionAssigner(args.principal, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = PRINCIPAL.key), + OptionAssigner(args.keytab, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = KEYTAB.key), // Propagate attributes for dependency resolution at the driver side OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), @@ -537,8 +541,6 @@ private[spark] class SparkSubmit extends Logging { OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0998757715457..4cf08a7980f55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -199,8 +199,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull - keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull - principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + keytab = Option(keytab) + .orElse(sparkProperties.get("spark.kerberos.keytab")) + .orElse(sparkProperties.get("spark.yarn.keytab")) + .orNull + principal = Option(principal) + .orElse(sparkProperties.get("spark.kerberos.principal")) + .orElse(sparkProperties.get("spark.yarn.principal")) + .orNull dynamicAllocationEnabled = sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) diff --git 
a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 44d23908146c7..c23a659e76df1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.nio.file.Files -import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -133,9 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - val perms = PosixFilePermissions.fromString("rwx------") - val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), - PosixFilePermissions.asFileAttribute(perms)).toFile() + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath()).toFile() + Utils.chmod700(dbPath) val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index c03a360b91ef8..ad0dd23cb59c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -18,8 +18,6 @@ package org.apache.spark.deploy.history import java.io.File -import java.nio.file.Files -import java.nio.file.attribute.PosixFilePermissions import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ @@ -107,9 +105,8 @@ private class HistoryServerDiskManager( val needed = approximateSize(eventLogSize, isCompressed) makeRoom(needed) - val perms = PosixFilePermissions.fromString("rwx------") - val tmp = Files.createTempDirectory(tmpStoreDir.toPath(), "appstore", - PosixFilePermissions.asFileAttribute(perms)).toFile() + val tmp = Utils.createTempDir(tmpStoreDir.getPath(), "appstore") + Utils.chmod700(tmp) updateUsage(needed) val current = currentUsage.get() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9891b6a2196de..7f6342208350a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -152,11 +152,11 @@ package object config { private[spark] val SHUFFLE_SERVICE_PORT = ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) - private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") + private[spark] val KEYTAB = ConfigBuilder("spark.kerberos.keytab") .doc("Location of user's keytab.") .stringConf.createOptional - private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal") + private[spark] val PRINCIPAL = ConfigBuilder("spark.kerberos.principal") .doc("Name of the Kerberos principal.") .stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694dd189ad..0e221edf3965a 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -49,10 +49,16 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { + /** + * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test + * code we can't assume SparkEnv.get exists. + */ + private lazy val minPartitionsToUseHighlyCompressMapStatus = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > Option(SparkEnv.get) - .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) - .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { + if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f21eee1965761..36aaf67b57298 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -388,10 +388,11 @@ private[spark] class AppStatusListener( job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None update(job, now, last = true) + if (job.status == JobExecutionStatus.SUCCEEDED) { + appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages) + kvstore.write(appSummary) + } } - - appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages) - kvstore.write(appSummary) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -653,13 +654,14 @@ private[spark] class AppStatusListener( if (removeStage) { liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber)) } + if (stage.status == v1.StageStatus.COMPLETE) { + appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) + kvstore.write(appSummary) + } } // remove any dead executors that were not running for any currently active stages deadExecutors.retain((execId, exec) => isExecutorActiveForLiveStages(exec)) - - appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) - kvstore.write(appSummary) } private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 22341467add5c..0fe82ac0cedc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -438,10 +438,8 @@ private[spark] class BlockManager( // stream. channel.close() // TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up - // using a lot of memory here. With encryption, we'll read the whole file into a regular - // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm - // OOM, but might get killed by the OS / cluster manager. We could at least read the tmp - // file as a stream in both cases. + // using a lot of memory here. 
We'll read the whole file into a regular + // byte buffer and OOM. We could at least read the tmp file as a stream. val buffer = securityManager.getIOEncryptionKey() match { case Some(key) => // we need to pass in the size of the unencrypted block @@ -453,7 +451,7 @@ private[spark] class BlockManager( new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator) case None => - ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) + ChunkedByteBuffer.fromFile(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) } putBytes(blockId, buffer, level)(classTag) tmpFile.delete() @@ -726,10 +724,9 @@ private[spark] class BlockManager( */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { // TODO if we change this method to return the ManagedBuffer, then getRemoteValues - // could just use the inputStream on the temp file, rather than memory-mapping the file. + // could just use the inputStream on the temp file, rather than reading the file into memory. // Until then, replication can cause the process to use too much memory and get killed - // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though - // we've read the data to disk. + // even though we've read the data to disk. logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e534c746433f2..aecc2284a9588 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -448,35 +448,35 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, address, e) } - - input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption, also the size of - // block is small (the decompressed block is smaller than maxBytesInFlight) - if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { - val originalInput = input - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - try { + var isStreamCopied: Boolean = false + try { + input = streamWrapper(blockId, in) + // Only copy the stream if it's wrapped by compression or encryption, also the size of + // block is small (the decompressed block is smaller than maxBytesInFlight) + if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { + isStreamCopied = true + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the whole block at once to detect any corruption, which could increase // the memory usage tne potential increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- Utils.copyStream(input, out) - out.close() + Utils.copyStream(input, out, closeStreams = true) input = out.toChunkedByteBuffer.toInputStream(dispose = true) - } catch { - case e: IOException => - buf.release() - if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, address, e) - } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest(address, Array((blockId, size))) - result = null - } - } finally { - // TODO: release the buf here to free memory earlier - originalInput.close() + } + } catch { + case e: IOException => + buf.release() + if (buf.isInstanceOf[FileSegmentManagedBuffer] + || corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, address, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest(address, Array((blockId, size))) + result = null + } + } finally { + // TODO: release the buf here to free memory earlier + if (isStreamCopied) { in.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 39f050f6ca5ad..4aa8d45ec7404 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -19,17 +19,16 @@ package org.apache.spark.util.io import java.io.{File, FileInputStream, InputStream} import java.nio.ByteBuffer -import java.nio.channels.{FileChannel, WritableByteChannel} -import java.nio.file.StandardOpenOption - -import scala.collection.mutable.ListBuffer +import java.nio.channels.WritableByteChannel +import com.google.common.io.ByteStreams import com.google.common.primitives.UnsignedBytes +import org.apache.commons.io.IOUtils import org.apache.spark.SparkEnv import org.apache.spark.internal.config import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.util.ByteArrayWritableChannel +import org.apache.spark.network.util.{ByteArrayWritableChannel, LimitedInputStream} import org.apache.spark.storage.StorageUtils import org.apache.spark.util.Utils @@ -175,30 +174,36 @@ object ChunkedByteBuffer { def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = { data match { case f: FileSegmentManagedBuffer => - map(f.getFile, maxChunkSize, f.getOffset, f.getLength) + fromFile(f.getFile, maxChunkSize, f.getOffset, f.getLength) case other => new ChunkedByteBuffer(other.nioByteBuffer()) } } - def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = { - map(file, maxChunkSize, 0, file.length()) + def fromFile(file: File, maxChunkSize: Int): ChunkedByteBuffer = { + fromFile(file, maxChunkSize, 0, file.length()) } - def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { - Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => - var remaining = length - var pos = offset - val chunks = new ListBuffer[ByteBuffer]() - while (remaining > 0) { - val chunkSize = math.min(remaining, maxChunkSize) - val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize) - pos += chunkSize - remaining -= chunkSize - chunks += chunk - } - new ChunkedByteBuffer(chunks.toArray) + private def fromFile( + file: File, + maxChunkSize: Int, + offset: Long, + length: Long): 
ChunkedByteBuffer = { + // We do *not* memory map the file, because we may end up putting this into the memory store, + // and spark currently is not expecting memory-mapped buffers in the memory store, it conflicts + // with other parts that manage the lifecycle of buffers and dispose them. See SPARK-25422. + val is = new FileInputStream(file) + ByteStreams.skipFully(is, offset) + val in = new LimitedInputStream(is, length) + val chunkSize = math.min(maxChunkSize, length).toInt + val out = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate _) + Utils.tryWithSafeFinally { + IOUtils.copy(in, out) + } { + in.close() + out.close() } + out.toChunkedByteBuffer } }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 354e6386fa60e..2155a0f2b6c21 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,32 +188,4 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } - - test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { - val conf = new SparkConf() - val env = mock(classOf[SparkEnv]) - doReturn(conf).when(env).conf - SparkEnv.set(env) - val sizes = Array.fill[Long](500)(150L) - // Test default value - val status = MapStatus(null, sizes) - assert(status.isInstanceOf[CompressedMapStatus]) - // Test Non-positive values - for (s <- -1 to 0) { - assertThrows[IllegalArgumentException] { - conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) - } - } - // Test positive values - Seq(1, 100, 499, 500, 501).foreach { s => - conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) - if(sizes.length > s) { - assert(status.isInstanceOf[HighlyCompressedMapStatus]) - } else { - assert(status.isInstanceOf[CompressedMapStatus]) - } - } - } }
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e3d67c34d53eb..687f9e46c3285 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -465,7 +465,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s </td> </tr> <tr> -  <td><code>spark.yarn.keytab</code></td> +  <td><code>spark.kerberos.keytab</code></td>   <td>(none)</td>   <td>   The full path to the file that contains the keytab for the principal specified above. This keytab @@ -477,7 +477,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s </td> </tr> <tr> -  <td><code>spark.yarn.principal</code></td> +  <td><code>spark.kerberos.principal</code></td>   <td>(none)</td>   <td>   Principal to be used to login to KDC, while running on secure clusters. Equivalent to the
diff --git a/docs/sparkr.md b/docs/sparkr.md index b4248e8bb21de..55e8f15da17ca 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -70,12 +70,12 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s   <td><code>--master</code></td> </tr> <tr> -  <td><code>spark.yarn.keytab</code></td> +  <td><code>spark.kerberos.keytab</code></td>   <td>Application Properties</td>   <td><code>--keytab</code></td> </tr> <tr> -  <td><code>spark.yarn.principal</code></td> +  <td><code>spark.kerberos.principal</code></td>   <td>Application Properties</td>   <td><code>--principal</code></td> </tr>
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c72fa3d75d67f..6de9de90c62c3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,17 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession   </td> </tr> +<tr> +  <td><code>spark.sql.parquet.writeLegacyFormat</code></td> +  <td>false</td> +  <td> +    If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal values +    will be written in Apache Parquet's fixed-length byte array format, which other systems such as +    Apache Hive and Apache Impala use. If false, the newer format in Parquet will be used. For +    example, decimals will be written in int-based format. If Parquet output is intended for use +    with systems that do not support this newer format, set to true. +  </td> +</tr> </table>
## ORC Files diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index ceb9e318b283b..7b1314bc8c3c0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -134,6 +134,8 @@ private[kafka010] case class InternalKafkaConsumer( /** Reset the internal pre-fetched data. */ def reset(): Unit = { _records = ju.Collections.emptyListIterator() + _nextOffsetInFetchedData = UNKNOWN_OFFSET + _offsetAfterPoll = UNKNOWN_OFFSET } /** @@ -361,8 +363,9 @@ private[kafka010] case class InternalKafkaConsumer( if (offset < fetchedData.offsetAfterPoll) { // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask // the next call to start from `fetchedData.offsetAfterPoll`. + val nextOffsetToFetch = fetchedData.offsetAfterPoll fetchedData.reset() - return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) + return fetchedRecord.withRecord(null, nextOffsetToFetch) } else { // Fetch records from Kafka and update `fetchedData`. fetchData(offset, pollTimeoutMs) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e5f008804ee5b..39c2cde7de40d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -874,6 +874,58 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } } + + test("SPARK-25495: FetchedData.reset should reset all fields") { + val topic = newTopic() + val topicPartition = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + val ds = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_committed") + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .select($"value".as[String]) + + testUtils.withTranscationalProducer { producer => + producer.beginTransaction() + (0 to 3).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + } + testUtils.waitUntilOffsetAppears(topicPartition, 5) + + val q = ds.writeStream.foreachBatch { (ds, epochId) => + if (epochId == 0) { + // Send more message before the tasks of the current batch start reading the current batch + // data, so that the executors will prefetch messages in the next batch and drop them. In + // this case, if we forget to reset `FetchedData._nextOffsetInFetchedData` or + // `FetchedData._offsetAfterPoll` (See SPARK-25495), the next batch will see incorrect + // values and return wrong results hence fail the test. 
+ testUtils.withTranscationalProducer { producer => + producer.beginTransaction() + (4 to 7).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + } + testUtils.waitUntilOffsetAppears(topicPartition, 10) + checkDatasetUnorderly(ds, (0 to 3).map(_.toString): _*) + } else { + checkDatasetUnorderly(ds, (4 to 7).map(_.toString): _*) + } + }.start() + try { + q.processAllAvailable() + } finally { + q.stop() + } + } } diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6da5237d18de4..1c3d9725b285b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2295,7 +2295,9 @@ def to_json(col, options={}): into a JSON string. Throws an exception, in the case of an unsupported type. :param col: name of column containing a struct, an array or a map. - :param options: options to control converting. accepts the same options as the JSON datasource + :param options: options to control converting. accepts the same options as the JSON datasource. + Additionally the function supports the `pretty` option which enables + pretty JSON generation. >>> from pyspark.sql import Row >>> from pyspark.sql.types import * diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9877d6251388b..263cbc56f4799 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -141,8 +141,9 @@ def enableHiveSupport(self): return self.config("spark.sql.catalogImplementation", "hive") def _sparkContext(self, sc): - self._sc = sc - return self + with self._lock: + self._sc = sc + return self @since(2.0) def getOrCreate(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 95e88772ca5de..64a7ceb3fea96 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5846,7 +5846,8 @@ def test_positional_assignment_conf(self): import pandas as pd from pyspark.sql.functions import pandas_udf, PandasUDFType - with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + with self.sql_conf({ + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) def foo(_): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 974344f01d923..8c59f1f999f18 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -97,8 +97,9 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - assign_cols_by_pos = runner_conf.get( - "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + assign_cols_by_name = runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true") + assign_cols_by_name = assign_cols_by_name.lower() == "true" def wrapped(key_series, value_series): import pandas as pd @@ -119,7 +120,7 @@ def wrapped(key_series, value_series): "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) # Assign result columns by schema name if user labeled with strings, else use position - if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns): return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] else: return [(result[result.columns[i]], to_arrow_type(field.dataType)) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index e511f8064e28a..82692334544e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -290,11 +290,13 @@ object DecimalPrecision extends TypeCoercionRule { // potentially loosing 11 digits of the fractional part. Using only the precision needed // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would // become DECIMAL(38, 16), safely having a much lower precision loss. - case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] - && l.dataType.isInstanceOf[IntegralType] => + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && + l.dataType.isInstanceOf[IntegralType] && + SQLConf.get.literalPickMinimumPrecision => b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) - case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] - && r.dataType.isInstanceOf[IntegralType] => + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && + r.dataType.isInstanceOf[IntegralType] && + SQLConf.get.literalPickMinimumPrecision => b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8b69a47036962..7dafebff79874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,15 +300,6 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expression[RegrCount]("regr_count"), - expression[RegrSXX]("regr_sxx"), - expression[RegrSYY]("regr_syy"), - expression[RegrAvgX]("regr_avgx"), - expression[RegrAvgY]("regr_avgy"), - expression[RegrSXY]("regr_sxy"), - expression[RegrSlope]("regr_slope"), - expression[RegrR2]("regr_r2"), - expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 7420b6b57d8e1..a7e09eee617e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + protected class AttributeEquals(val a: Attribute) { override def hashCode(): Int = a match { @@ -39,10 +41,13 @@ object AttributeSet { /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. 
*/ def apply(baseSet: Iterable[Expression]): AttributeSet = { - new AttributeSet( - baseSet - .flatMap(_.references) - .map(new AttributeEquals(_)).toSet) + fromAttributeSets(baseSet.map(_.references)) + } + + /** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */ + def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = { + val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet) + new AttributeSet(baseSet.toSet) } } @@ -94,8 +99,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]): AttributeSet = - new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + def --(other: Traversable[NamedExpression]): AttributeSet = { + other match { + case otherSet: AttributeSet => + new AttributeSet(baseSet -- otherSet.baseSet) + case _ => + new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + } + } /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 773aefc0ac1f9..c215735ab1c98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -85,7 +85,7 @@ abstract class Expression extends TreeNode[Expression] { def nullable: Boolean - def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) + def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references)) /** Returns the result of evaluating this expression on a given input Row */ def eval(input: InternalRow = null): Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala deleted file mode 100644 index d8f4505588ff2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{AbstractDataType, DoubleType} - -/** - * Base trait for all regression functions. 
- */ -trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { - def y: Expression - def x: Expression - - override def children: Seq[Expression] = Seq(y, x) - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - - protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { - assert(aggBufferAttributes.length == exprs.length) - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - exprs - } else { - exprs.zip(aggBufferAttributes).map { case (e, a) => - If(nullableChildren.map(IsNull).reduce(Or), a, e) - } - } - } -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", - since = "2.4.0") -case class RegrCount(y: Expression, x: Expression) - extends CountLike with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) - - override def prettyName: String = "regr_count" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrSXX(y: Expression, x: Expression) - extends CentralMomentAgg(x) with RegrLike { - - override protected def momentOrder = 2 - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), m2) - } - - override def prettyName: String = "regr_sxx" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrSYY(y: Expression, x: Expression) - extends CentralMomentAgg(y) with RegrLike { - - override protected def momentOrder = 2 - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), m2) - } - - override def prettyName: String = "regr_syy" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrAvgX(y: Expression, x: Expression) - extends AverageLike(x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override def prettyName: String = "regr_avgx" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrAvgY(y: Expression, x: Expression) - extends AverageLike(y) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override def prettyName: String = "regr_avgy" -} - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. 
Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrSXY(y: Expression, x: Expression) - extends Covariance(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), ck) - } - - override def prettyName: String = "regr_sxy" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrSlope(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) - } - - override def prettyName: String = "regr_slope" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrR2(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), - If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) - } - - override def prettyName: String = "regr_r2" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. 
Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrIntercept(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), - xAvg - (ck / yMk) * yAvg) - } - - override def prettyName: String = "regr_intercept" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 85bc1cdb43051..9cc7dbadd923a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3088,11 +3088,24 @@ case class ArrayRemove(left: Expression, right: Expression) override def dataType: DataType = left.dataType override def inputTypes: Seq[AbstractDataType] = { - val elementType = left.dataType match { - case t: ArrayType => t.elementType - case _ => AnyDataType + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } - Seq(ArrayType, elementType) } private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @@ -3100,14 +3113,6 @@ case class ArrayRemove(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") - } - } - override def nullSafeEval(arr: Any, value: Any): Any = { val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) var pos = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 47eeb70e00427..64152e04928d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -113,6 +113,11 @@ private[sql] class JSONOptions( } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** + * Generating JSON strings in pretty representation if the parameter is enabled. + */ + val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false) + /** Sets config options on a Jackson [[JsonFactory]]. 
*/ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9b86d865622dc..d02a2be8ddad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -70,7 +70,10 @@ private[sql] class JacksonGenerator( s"Initial type ${dataType.catalogString} must be a ${MapType.simpleString}") } - private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val gen = { + val generator = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + if (options.pretty) generator.useDefaultPrettyPrinter() else generator + } private val lineSeparator: String = options.lineSeparatorInWrite diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7c461895c5e52..07a653f3b5d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -532,12 +532,12 @@ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand - case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) - case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => proj.zip(e.output).filter { case (_, a) => @@ -547,18 +547,18 @@ object ColumnPruning extends Rule[LogicalPlan] { a.copy(child = Expand(newProjects, newOutput, grandChild)) // Prunes the unused columns from child of `DeserializeToObject` - case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => + case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) => d.copy(child = prunedChild(child, d.references)) // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation - case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) => a.copy(child = prunedChild(child, a.references)) - case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + case f @ FlatMapGroupsInPandas(_, _, _, child) if !child.outputSet.subsetOf(f.references) => f.copy(child = prunedChild(child, f.references)) - case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + case e 
@ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) => e.copy(child = prunedChild(child, e.references)) case s @ ScriptTransformation(_, _, _, child, _) - if (child.outputSet -- s.references).nonEmpty => + if !child.outputSet.subsetOf(s.references) => s.copy(child = prunedChild(child, s.references)) // prune unrequired references @@ -579,7 +579,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, _: Distinct) => p // Eliminate unneeded attributes from children of Union. case p @ Project(_, u: Union) => - if ((u.outputSet -- p.references).nonEmpty) { + if (!u.outputSet.subsetOf(p.references)) { val firstChild = u.children.head val newOutput = prunedChild(firstChild, p.references).output // pruning the columns of all children based on the pruned first child. @@ -595,7 +595,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } // Prune unnecessary window expressions - case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) => p.copy(child = w.copy( windowExpressions = w.windowExpressions.filter(p.references.contains))) @@ -611,7 +611,7 @@ object ColumnPruning extends Rule[LogicalPlan] { // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references - if ((child.inputSet -- required).nonEmpty) { + if (!child.inputSet.subsetOf(required)) { val newChildren = child.children.map(c => prunedChild(c, required)) p.copy(child = child.withNewChildren(newChildren)) } else { @@ -621,7 +621,7 @@ object ColumnPruning extends Rule[LogicalPlan] { /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = - if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { + if (!c.outputSet.subsetOf(allReferences)) { Project(c.output.filter(allReferences.contains), c) } else { c diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b1ffdca091461..ca0cea6ba7de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -42,7 +42,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. */ - def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + def references: AttributeSet = AttributeSet.fromAttributeSets(expressions.map(_.references)) /** * The set of all attributes that are input to this operator by its children. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index bb2c5926ae9bb..288a4f34a447e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -42,7 +42,11 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator override def -(key: String): Map[String, T] = { - new CaseInsensitiveMap(originalMap.filterKeys(!_.equalsIgnoreCase(key))) + new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(key))) + } + + override def filterKeys(p: (String) => Boolean): Map[String, T] = { + new CaseInsensitiveMap(originalMap.filter(kv => p(kv._1))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0e0a01def357e..f6c98805bfb15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -451,8 +451,11 @@ object SQLConf { .createWithDefault(10) val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") - .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + - "versions, when converting Parquet schema to Spark SQL schema and vice versa.") + .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + + "values will be written in Apache Parquet's fixed-length byte array format, which other " + + "systems such as Apache Hive and Apache Impala use. If false, the newer format in Parquet " + + "will be used. For example, decimals will be written in int-based format. If Parquet " + + "output is intended for use with systems that do not support this newer format, set to true.") .booleanConf .createWithDefault(false) @@ -1295,15 +1298,15 @@ object SQLConf { .booleanConf .createWithDefault(true) - val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = - buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = + buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") .internal() - .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + - "Pandas DataFrame based on position, regardless of column label type. When false, " + - "columns will be looked up by name if labeled with a string and fallback to use " + - "position if not. This configuration will be deprecated in future releases.") + .doc("When true, columns will be looked up by name if labeled with a string and fallback " + + "to use position if not. When false, a grouped map Pandas UDF will assign columns from " + + "the returned Pandas DataFrame based on position, regardless of column label type. 
" + + "This configuration will be deprecated in future releases.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() @@ -1328,6 +1331,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LITERAL_PICK_MINIMUM_PRECISION = + buildConf("spark.sql.legacy.literal.pickMinimumPrecision") + .internal() + .doc("When integral literal is used in decimal operations, pick a minimum precision " + + "required by the literal if this config is true, to make the resulting precision and/or " + + "scale smaller. This can reduce the possibility of precision lose and/or overflow.") + .booleanConf + .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = buildConf("spark.sql.redaction.options.regex") .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + @@ -1915,13 +1927,15 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) - def pandasGroupedMapAssignColumnssByPosition: Boolean = - getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def pandasGroupedMapAssignColumnsByName: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME) def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 67740c3166471..3081ff935f043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -22,7 +22,6 @@ import org.scalatest.Suite import org.scalatest.Tag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode @@ -57,7 +56,7 @@ trait CodegenInterpretedPlanTest extends PlanTest { * Provides helper methods for comparing plans, but without the overhead of * mandating a FunSuite. */ -trait PlanTestBase extends PredicateHelper { self: Suite => +trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get @@ -174,32 +173,4 @@ trait PlanTestBase extends PredicateHelper { self: Suite => plan1 == plan2 } } - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL - * configurations. 
- */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val conf = SQLConf.get - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (conf.contains(key)) { - Some(conf.getConfString(key)) - } else { - None - } - } - (keys, values).zipped.foreach { (k, v) => - if (SQLConf.staticConfKeys.contains(k)) { - throw new AnalysisException(s"Cannot modify the value of a static config: $k") - } - conf.setConfString(k, v) - } - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConfString(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala new file mode 100644 index 0000000000000..4d869d79ad594 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import java.io.File + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +trait SQLHelper { + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then passes it to `f`. If + * a file/directory is created there by `f`, it will be deleted after `f` returns.
+ */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala new file mode 100644 index 0000000000000..03eed4aaa750b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer + +class CaseInsensitiveMapSuite extends SparkFunSuite { + private def shouldBeSerializable(m: Map[String, String]): Unit = { + new JavaSerializer(new SparkConf()).newInstance().serialize(m) + } + + test("Keys are case insensitive") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foO" -> "bar")) + assert(m("FOO") == "bar") + assert(m("fOo") == "bar") + assert(m("A") == "b") + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after '-' operator") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) - "a" + assert(m == Map("foo" -> "bar")) + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after '+' operator") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) + ("x" -> "y") + assert(m == Map("a" -> "b", "foo" -> "bar", "x" -> "y")) + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after 'filterKeys' method") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")).filterKeys(_ == "foo") + assert(m == Map("foo" -> "bar")) + shouldBeSerializable(m) + } +} diff --git a/sql/core/benchmarks/SortBenchmark-results.txt b/sql/core/benchmarks/SortBenchmark-results.txt new file mode 100644 index 0000000000000..0d00a0c89d02d --- /dev/null +++ b/sql/core/benchmarks/SortBenchmark-results.txt @@ -0,0 +1,17 @@ +================================================================================================ +radix sort +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_162-b12 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +reference TimSort key prefix array 11770 / 11960 2.1 470.8 1.0X +reference Arrays.sort 2106 / 2128 11.9 84.3 5.6X +radix sort one byte 93 / 100 269.7 3.7 126.9X +radix sort two bytes 171 / 179 146.0 6.9 
68.7X +radix sort eight bytes 659 / 664 37.9 26.4 17.9X +radix sort key prefix array 1024 / 1053 24.4 41.0 11.5X + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 533097ac399e9..b1e8fb39ac9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -131,11 +131,8 @@ object ArrowUtils { } else { Nil } - val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { - Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") - } else { - Nil - } - Map(timeZoneConf ++ pandasColsByPosition: _*) + val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + Map(timeZoneConf ++ pandasColsByName: _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca54..4c58e77df485e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3619,6 +3619,8 @@ object functions { * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. + * Additionally the function supports the `pretty` option which enables + * pretty JSON generation. * * @group collection_funcs * @since 2.1.0 @@ -3635,6 +3637,8 @@ object functions { * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. + * Additionally the function supports the `pretty` option which enables + * pretty JSON generation. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql deleted file mode 100644 index 92c7e26e3add2..0000000000000 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql +++ /dev/null @@ -1,56 +0,0 @@ --- --- Licensed to the Apache Software Foundation (ASF) under one or more --- contributor license agreements. See the NOTICE file distributed with --- this work for additional information regarding copyright ownership. --- The ASF licenses this file to You under the Apache License, Version 2.0 --- (the "License"); you may not use this file except in compliance with --- the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. 
--- - -CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES - (101, 1, 1, 1), - (201, 2, 1, 1), - (301, 3, 1, 1), - (401, 4, 1, 11), - (501, 5, 1, null), - (601, 6, null, 1), - (701, 6, null, null), - (102, 1, 2, 2), - (202, 2, 1, 2), - (302, 3, 2, 1), - (402, 4, 2, 12), - (502, 5, 2, null), - (602, 6, null, 2), - (702, 6, null, null), - (103, 1, 3, 3), - (203, 2, 1, 3), - (303, 3, 3, 1), - (403, 4, 3, 13), - (503, 5, 3, null), - (603, 6, null, 3), - (703, 6, null, null), - (104, 1, 4, 4), - (204, 2, 1, 4), - (304, 3, 4, 1), - (404, 4, 4, 14), - (504, 5, 4, null), - (604, 6, null, 4), - (704, 6, null, null), - (800, 7, 1, 1) -as t1(id, px, y, x); - -select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), - regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), - regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) -from t1 group by px order by px; - - -select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out deleted file mode 100644 index d7d009a64bf84..0000000000000 --- a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out +++ /dev/null @@ -1,93 +0,0 @@ --- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 - - --- !query 0 -CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES - (101, 1, 1, 1), - (201, 2, 1, 1), - (301, 3, 1, 1), - (401, 4, 1, 11), - (501, 5, 1, null), - (601, 6, null, 1), - (701, 6, null, null), - (102, 1, 2, 2), - (202, 2, 1, 2), - (302, 3, 2, 1), - (402, 4, 2, 12), - (502, 5, 2, null), - (602, 6, null, 2), - (702, 6, null, null), - (103, 1, 3, 3), - (203, 2, 1, 3), - (303, 3, 3, 1), - (403, 4, 3, 13), - (503, 5, 3, null), - (603, 6, null, 3), - (703, 6, null, null), - (104, 1, 4, 4), - (204, 2, 1, 4), - (304, 3, 4, 1), - (404, 4, 4, 14), - (504, 5, 4, null), - (604, 6, null, 4), - (704, 6, null, null), - (800, 7, 1, 1) -as t1(id, px, y, x) --- !query 0 schema -struct<> --- !query 0 output - - - --- !query 1 -select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), - regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), - regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) -from t1 group by px order by px --- !query 1 schema -struct --- !query 1 output -1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 -2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 -3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 -4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 -5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 -6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 -7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 - - --- !query 2 -select id, regr_count(y,x) over (partition by px) from t1 order by id --- !query 2 schema -struct --- !query 2 output -101 4 -102 4 -103 4 -104 4 -201 4 -202 4 -203 4 -204 4 -301 4 -302 4 -303 4 -304 4 -401 4 -402 4 -403 4 -404 4 -501 0 -502 0 -503 0 -504 0 -601 0 -602 0 -603 0 -604 0 -701 0 -702 0 -703 0 -704 0 -800 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fd71f24935611..88dbae8c21350 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1574,6 +1574,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) ) + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1, 2), 1.23D)"), + Seq( + Row(Seq(1.0, 2.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1, 2), 1.0D)"), + Seq( + Row(Seq(2.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1.0D, 2.0D), 2)"), + Seq( + Row(Seq(1.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1.1D, 1.2D), 1)"), + Seq( + Row(Seq(1.1, 1.2)) + ) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), @@ -1583,10 +1611,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) - val e = intercept[AnalysisException] { + val e1 = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") } - assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |Input to function array_remove should have been array followed by a + |value with same element type, but it's [string, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_remove(array(1, 2), '1')") + } + + val errorMsg2 = + s""" + |Input to function array_remove should have been array followed by a + |value with same element type, but it's [array, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("array_distinct functions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index fe4bf15fa3921..853bc182f2f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -518,4 +518,25 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { jsonDF.select(to_json(from_json($"a", schema))), Seq(Row(json))) } + + test("pretty print - roundtrip from_json -> to_json") { + val json = """[{"book":{"publisher":[{"country":"NL","year":[1981,1986,1999]}]}}]""" + val jsonDF = Seq(json).toDF("root") + val expected = + """[ { + | "book" : { + | "publisher" : [ { + | "country" : "NL", + | "year" : [ 1981, 1986, 1999 ] + | } ] + | } + |} ]""".stripMargin + + checkAnswer( + jsonDF.select( + to_json( + from_json($"root", schema_of_json(lit(json))), + Map("pretty" -> "true"))), + Seq(Row(expected))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8fcebb35a0543..631ab1b7ece7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2849,6 +2849,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val result = ds.flatMap(_.bar).distinct result.rdd.isEmpty } + + test("SPARK-25454: decimal division with negative scale") { + // TODO: completely fix this issue even when LITERAL_PRECISE_PRECISION is true. 
+ withSQLConf(SQLConf.LITERAL_PICK_MINIMUM_PRECISION.key -> "false") { + checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) + } + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index cf9bda2fb1ff1..51a7f9f1ef096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution.benchmark import java.io.File import scala.collection.JavaConverters._ -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnVector -import org.apache.spark.util.Utils /** @@ -37,7 +37,7 @@ import org.apache.spark.util.Utils * To run this: * spark-submit --class */ -object DataSourceReadBenchmark { +object DataSourceReadBenchmark extends SQLHelper { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") // Since `spark.master` always exists, overrides this value @@ -54,27 +54,10 @@ object DataSourceReadBenchmark { spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { val testDf = if (partition.isDefined) { df.write.partitionBy(partition.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 3b7f10783b64c..7cdf653e38697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.benchmark import java.io.File -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import 
org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} -import org.apache.spark.util.Utils /** * Benchmark to measure read performance with Filter pushdown. @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". * }}} */ -object FilterPushdownBenchmark extends BenchmarkBase { +object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper { private val conf = new SparkConf() .setAppName(this.getClass.getSimpleName) @@ -60,28 +60,10 @@ object FilterPushdownBenchmark extends BenchmarkBase { private val spark = SparkSession.builder().config(conf).getOrCreate() - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - private def prepareTable( dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { import spark.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 17619ec5fadc1..958a064402149 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.unsafe.array.LongArray import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.Sorter @@ -28,12 +28,15 @@ import org.apache.spark.util.random.XORShiftRandom /** * Benchmark to measure performance for aggregate primitives. - * To run this: - * build/sbt "sql/test-only *benchmark.SortBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". 
+ * }}} */ -class SortBenchmark extends BenchmarkWithCodegen { +object SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) @@ -54,10 +57,10 @@ class SortBenchmark extends BenchmarkWithCodegen { new LongArray(MemoryBlock.fromLongArray(extended))) } - ignore("sort") { + def sortBenchmark(): Unit = { val size = 25000000 val rand = new XORShiftRandom(123) - val benchmark = new Benchmark("radix sort " + size, size) + val benchmark = new Benchmark("radix sort " + size, size, output = output) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } val buf = new LongArray(MemoryBlock.fromLongArray(array)) @@ -114,20 +117,11 @@ class SortBenchmark extends BenchmarkWithCodegen { timer.stopTiming() } benchmark.run() + } - /* - Running benchmark: radix sort 25000000 - Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic - Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz - - radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X - reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X - radix sort one byte 133 / 137 188.4 5.3 117.2X - radix sort two bytes 255 / 258 98.2 10.2 61.1X - radix sort eight bytes 991 / 997 25.2 39.6 15.7X - radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X - */ + override def benchmark(): Unit = { + runBenchmark("radix sort") { + sortBenchmark() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index 6d319eb723d93..5d1a874999c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -16,21 +16,19 @@ */ package org.apache.spark.sql.execution.datasources.csv -import java.io.File - import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * Benchmark to measure CSV read/write performance. 
* To run this: * spark-submit --class --jars */ -object CSVBenchmarks { +object CSVBenchmarks extends SQLHelper { val conf = new SparkConf() val spark = SparkSession.builder @@ -40,12 +38,6 @@ object CSVBenchmarks { .getOrCreate() import spark.implicits._ - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index e40cb9b50148b..368318ab38cb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -21,16 +21,16 @@ import java.io.File import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. * To run this: * spark-submit --class --jars */ -object JSONBenchmarks { +object JSONBenchmarks extends SQLHelper { val conf = new SparkConf() val spark = SparkSession.builder @@ -40,13 +40,6 @@ object JSONBenchmarks { .getOrCreate() import spark.implicits._ - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def schemaInferring(rowsNum: Int): Unit = { val benchmark = new Benchmark("JSON schema inferring", rowsNum) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala index fe59cb25d5005..cbac1c13cdd33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.util.Utils -abstract class CheckpointFileManagerTests extends SparkFunSuite { +abstract class CheckpointFileManagerTests extends SparkFunSuite with SQLHelper { def createManager(path: Path): CheckpointFileManager @@ -88,12 +88,6 @@ abstract class CheckpointFileManagerTests extends SparkFunSuite { fm.delete(path) // should not throw exception } } - - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } } class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 2fb8f70a20791..6b03d1e5b7662 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.UninterruptibleThread import org.apache.spark.util.Utils @@ -167,18 +166,6 @@ private[sql] trait SQLTestUtilsBase super.withSQLConf(pairs: _*)(f) } - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - /** * Copy file in jar's resource to a temp file, then pass it to `f`. * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 0eab7d1ea8e80..49de007df3828 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive.orc import java.io.File -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils + /** * Benchmark to measure ORC read performance. @@ -34,7 +35,7 @@ import org.apache.spark.util.Utils * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. */ // scalastyle:off line.size.limit -object OrcReadBenchmark { +object OrcReadBenchmark extends SQLHelper { val conf = new SparkConf() conf.set("orc.compression", "snappy") @@ -47,28 +48,10 @@ object OrcReadBenchmark { // Set default configs. Individual cases will change them if necessary. 
spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index a882558551e37..135430f1ef621 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -59,6 +59,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", + "spark.kerberos.keytab", + "spark.kerberos.principal", "spark.ui.filters", "spark.mesos.driver.frameworkId")
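The benchmark and test-suite hunks above all drop their private copies of `withSQLConf`/`withTempPath` in favor of the shared `SQLHelper` trait introduced by this patch. A hypothetical usage sketch (the object name and the chosen conf key are illustrative only, not part of the patch) of how a standalone object can mix in the trait after this change:

{{{
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf

// Illustrative object, not part of the patch.
object SQLHelperUsageSketch extends SQLHelper {
  def main(args: Array[String]): Unit = {
    // Temporarily overrides a SQL conf; the previous value is restored afterwards.
    withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") {
      // body that depends on the legacy Parquet format would run here
    }

    // Supplies a fresh, not-yet-created path and deletes it once the block returns.
    withTempPath { dir =>
      println(s"writing benchmark output under ${dir.getAbsolutePath}")
    }
  }
}
}}}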