diff --git a/build/mvn b/build/mvn index a78b93a685689..c3ab62da36868 100755 --- a/build/mvn +++ b/build/mvn @@ -141,9 +141,13 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + ZINC_JAVA_HOME= + if [ -n "$JAVA_7_HOME" ]; then + ZINC_JAVA_HOME="env JAVA_HOME=$JAVA_7_HOME" + fi export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} - "${ZINC_BIN}" -start -port ${ZINC_PORT} \ + $ZINC_JAVA_HOME "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java index e8d999dd00135..462ca3f6f6d19 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -22,7 +22,7 @@ /** * Base interface for a function used in Dataset's filter function. * - * If the function returns true, the element is discarded in the returned Dataset. + * If the function returns true, the element is included in the returned Dataset. */ public interface FilterFunction extends Serializable { boolean call(T value) throws Exception; diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala index 0f9bac7164162..e19f12fdac090 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/package.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala @@ -22,4 +22,4 @@ package org.apache.spark.api.java * these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's * Java programming guide for more details. */ -package object function +package object function diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 38a21a896e1fe..fc1f3a80239ba 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -23,7 +23,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; /** - * An memory consumer of TaskMemoryManager, which support spilling. + * A memory consumer of {@link TaskMemoryManager} that supports spilling. * * Note: this only supports allocation / spilling of Tungsten memory. */ @@ -45,7 +45,7 @@ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { } /** - * Returns the memory mode, ON_HEAP or OFF_HEAP. + * Returns the memory mode, {@link MemoryMode#ON_HEAP} or {@link MemoryMode#OFF_HEAP}. */ public MemoryMode getMode() { return mode; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index fb6323413e3ea..bfb6a35f5bb93 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -596,6 +596,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) @@ -605,6 +606,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) : Accumulator[T] = sc.accumulator(initialValue, name)(accumulatorParam) @@ -613,6 +615,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks * can "add" values with `add`. Only the master can access the accumuable's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) @@ -622,6 +625,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) : Accumulable[T, R] = sc.accumulable(initialValue, name)(param) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 20ffe1342e509..dd8f5bacb9f6e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -211,9 +211,6 @@ private[storage] class BlockInfoManager extends Logging { * If another task has already locked this block for either reading or writing, then this call * will block until the other locks are released or will return immediately if `blocking = false`. * - * If this is called by a task which already holds the block's exclusive write lock, then this - * method will throw an exception. - * * @param blockId the block to lock. * @param blocking if true (default), this call will block until the lock is acquired. If false, * this call will return immediately if the lock acquisition fails. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 216ec0793492f..fad0404bebc36 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -120,8 +120,14 @@ class StorageLevel private( private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) override def toString: String = { - s"StorageLevel(disk=$useDisk, memory=$useMemory, offheap=$useOffHeap, " + - s"deserialized=$deserialized, replication=$replication)" + val disk = if (useDisk) "disk" else "" + val memory = if (useMemory) "memory" else "" + val heap = if (useOffHeap) "offheap" else "" + val deserialize = if (deserialized) "deserialized" else "" + + val output = + Seq(disk, memory, heap, deserialize, s"$replication replicas").filter(_.nonEmpty) + s"StorageLevel(${output.mkString(", ")})" } override def hashCode(): Int = toInt * 41 + replication diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7e204fa21852c..1a9dbcae8c083 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2344,29 +2344,24 @@ private[spark] class RedirectThread( * the toString method. */ private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](sizeInBytes) + private var pos: Int = 0 + private var isBufferFull = false + private val buffer = new Array[Byte](sizeInBytes) - def write(i: Int): Unit = { - buffer(pos) = i + def write(input: Int): Unit = { + buffer(pos) = input.toByte pos = (pos + 1) % buffer.length + isBufferFull = isBufferFull || (pos == 0) } override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while (line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() + if (!isBufferFull) { + return new String(buffer, 0, pos, StandardCharsets.UTF_8) } - stringBuilder.toString() + + val nonCircularBuffer = new Array[Byte](sizeInBytes) + System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos) + System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos) + new String(nonCircularBuffer, StandardCharsets.UTF_8) } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4aa4854c36f3a..66987498669d4 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} @@ -681,14 +681,37 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, childFile3)) } - test("circular buffer") { + test("circular buffer: if nothing was written to the buffer, display nothing") { + val buffer = new CircularBuffer(4) + assert(buffer.toString === "") + } + + test("circular buffer: if the buffer isn't full, print only the contents written") { + val buffer = new CircularBuffer(10) + val stream = new PrintStream(buffer, true, "UTF-8") + stream.print("test") + assert(buffer.toString === "test") + } + + test("circular buffer: data written == size of the buffer") { + val buffer = new CircularBuffer(4) + val stream = new PrintStream(buffer, true, "UTF-8") + + // fill the buffer to its exact size so that it just hits overflow + stream.print("test") + assert(buffer.toString === "test") + + // add more data to the buffer + stream.print("12") + assert(buffer.toString === "st12") + } + + test("circular buffer: multiple overflow") { val buffer = new CircularBuffer(25) - val stream = new java.io.PrintStream(buffer, true, "UTF-8") + val stream = new PrintStream(buffer, true, "UTF-8") - // scalastyle:off println - stream.println("test circular test circular test circular test circular test circular") - // scalastyle:on println - assert(buffer.toString === "t circular test circular\n") + stream.print("test circular test circular test circular test circular test circular") + assert(buffer.toString === "st circular test circular") } test("nanSafeCompareDoubles") { diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index c50f25d951947..a68fd0285f567 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -29,13 +29,10 @@ object BroadcastTest { val blockSize = if (args.length > 2) args(2) else "4096" - val sparkConf = new SparkConf() - .set("spark.broadcast.blockSize", blockSize) - val spark = SparkSession - .builder - .config(sparkConf) + .builder() .appName("Broadcast Test") + .config("spark.broadcast.blockSize", blockSize) .getOrCreate() val sc = spark.sparkContext diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 7651aade493a0..3fbf8e03339e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -191,6 +191,7 @@ object LDAExample { val spark = SparkSession .builder + .sparkContext(sc) .getOrCreate() import spark.implicits._ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index d3bb7e4398cd3..2d7a01a95d830 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -22,7 +22,6 @@ import java.io.File import com.google.common.io.{ByteStreams, Files} -import org.apache.spark.SparkConf import org.apache.spark.sql._ object HiveFromSpark { @@ -35,8 +34,6 @@ object HiveFromSpark { ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File)) def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("HiveFromSpark") - // When working with Hive, one must instantiate `SparkSession` with Hive support, including // connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined // functions. Users who do not have an existing Hive deployment can still enable Hive support. @@ -45,7 +42,7 @@ object HiveFromSpark { // which defaults to the directory `spark-warehouse` in the current directory that the spark // application is started. val spark = SparkSession.builder - .config(sparkConf) + .appName("HiveFromSpark") .enableHiveSupport() .getOrCreate() val sc = spark.sparkContext diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml index 60e3ff60df065..74a3ee1ce11e2 100644 --- a/external/java8-tests/pom.xml +++ b/external/java8-tests/pom.xml @@ -106,6 +106,7 @@ net.alchim31.maven scala-maven-plugin + ${useZincForJdk8} -source 1.8 diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index ec60991af64ff..5aec692c98e6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging @@ -696,8 +696,8 @@ class DistributedLDAModel private[ml] ( @DeveloperApi @Since("2.0.0") def deleteCheckpointFiles(): Unit = { - val fs = FileSystem.get(sparkSession.sparkContext.hadoopConfiguration) - _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + val hadoopConf = sparkSession.sparkContext.hadoopConfiguration + _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, hadoopConf)) _checkpointFiles = Array.empty[String] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 800430f96c5b1..a7c5f489dea86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} @@ -77,8 +77,8 @@ private[spark] class NodeIdCache( // Indicates whether we can checkpoint private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty - // FileSystem instance for deleting checkpoints as needed - private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration) + // Hadoop Configuration for deleting checkpoints as needed + private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration /** * Update the node index values in the cache. @@ -130,7 +130,9 @@ private[spark] class NodeIdCache( val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we'll manually delete it here. try { - fs.delete(new Path(old.getCheckpointFile.get), true) + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) } catch { case e: IOException => logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + @@ -154,7 +156,9 @@ private[spark] class NodeIdCache( val old = checkpointQueue.dequeue() if (old.getCheckpointFile.isDefined) { try { - fs.delete(new Path(old.getCheckpointFile.get), true) + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) } catch { case e: IOException => logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6e0ed374c7ee1..e43469bf1cf86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. val sc = indexedRowMatrix.rows.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(indexedRowMatrix.rows) } @@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. val sc = coordinateMatrix.entries.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(coordinateMatrix.entries) } @@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. val sc = blockMatrix.blocks.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index b186ca37703d3..e4cc784cfe421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS lr.setMaxIter(optimizer.getNumIterations()) lr.setTol(optimizer.getConvergenceTol()) // Convert our input into a DataFrame - val spark = SparkSession.builder().config(input.context.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(input.context).getOrCreate() val df = spark.createDataFrame(input.map(_.asML)) // Determine if we should cache the DF val handlePersistence = input.getStorageLevel == StorageLevel.NONE diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 452802f043abf..593a86f69ad51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 32e323d080af8..84491181d077f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 510a91b5a77fd..b3546a1ee3677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rootId" -> model.root.index))) @@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { } def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 31ad56dba6aef..c30cc3e2398e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { path: String, weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render @@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 5f939c1a218fb..aa78149699a27 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) @@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 1b66013d543ad..4f07236225cd2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -446,7 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { docConcentration: Vector, topicConcentration: Double, gammaShape: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -470,7 +470,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topicConcentration: Double, gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) Loader.checkSchema[Data](dataFrame.schema) @@ -851,7 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { topicConcentration: Double, iterationTimes: Array[Double], gammaShape: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -887,7 +887,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) val vertexDataFrame = spark.read.parquet(vertexDataPath) val edgeDataFrame = spark.read.parquet(edgeDataPath) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 51077bd630a15..c760ddd6ad40b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -82,7 +82,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 13decefcd6695..c8c2823bbaf04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -149,7 +149,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9bd79aa7c627e..2f52825c6cb01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -609,7 +609,7 @@ object Word2VecModel extends Loader[Word2VecModel] { case class Data(word: String, vector: Array[Float]) def load(sc: SparkContext, path: String): Word2VecModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(Loader.dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -620,7 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] { } def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val vectorSize = model.values.head.length val numWords = model.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 8c0639baeaca4..0f7fbe9556c5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { def save(model: FPGrowthModel[_], path: String): Unit = { val sc = model.freqItemsets.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -123,7 +123,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { def load(sc: SparkContext, path: String): FPGrowthModel[_] = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 10bbcd2a3d924..c13c794775fec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { def save(model: PrefixSpanModel[_], path: String): Unit = { val sc = model.freqSequences.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -640,7 +640,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 5c12c9305b99c..4dd498cd91b4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.internal.Logging @@ -160,21 +161,23 @@ private[mllib] abstract class PeriodicCheckpointer[T]( private def removeCheckpointFile(): Unit = { val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we manually delete it. - val fs = FileSystem.get(sc.hadoopConfiguration) - getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + getCheckpointFiles(old).foreach( + PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) } } private[spark] object PeriodicCheckpointer extends Logging { /** Delete a checkpoint file, and log a warning if deletion fails. */ - def removeCheckpointFile(path: String, fs: FileSystem): Unit = { + def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = { try { - fs.delete(new Path(path), true) + val path = new Path(checkpointFile) + val fs = path.getFileSystem(conf) + fs.delete(path, true) } catch { case e: Exception => logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + - path) + checkpointFile) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 450025f477f19..c642573ccba6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -354,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() import spark.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -365,7 +365,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 215a799b9646a..1cd6f2a8969a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 3c7bbc52446d9..cd90e97cc5388 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -68,7 +68,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 72663188a98ae..a1562384b0a7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -233,13 +233,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Create Parquet data. val nodes = model.topNode.subtreeIterator.toSeq val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _)) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { // Load Parquet data. - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataPath = Loader.dataPath(path) val dataRDD = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index c653b988e21e7..f7d9b22b6f424 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -413,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // SPARK-6120: We do a hacky check here so users understand why save() is failing // when they run the ML guide example. @@ -471,7 +471,7 @@ private[tree] object TreeEnsembleModel extends Logging { sc: SparkContext, path: String, treeAlgo: String): Array[DecisionTreeModel] = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() import spark.implicits._ val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply) val trees = constructTrees(nodes.rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 40d5b4881f839..3558290b23ae0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.Row class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val spark = SparkSession.builder - .master("local[2]") - .appName("ChiSqSelectorSuite") - .getOrCreate() + val spark = this.spark import spark.implicits._ - val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 621c13a8e5ac6..b73dbd62328cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,7 +27,7 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test observed number of buckets and their sizes match expected values") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val datasetSize = 100000 @@ -53,7 +53,7 @@ class QuantileDiscretizerSuite } test("Test transform method on unseen data") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") @@ -82,7 +82,7 @@ class QuantileDiscretizerSuite } test("Verify resulting model has parent") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 59b5edc4013e8..e8ed50acf877c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite { val spark = SparkSession.builder .master("local[2]") .appName("ALSCleanerSuite") + .sparkContext(sc) .getOrCreate() import spark.implicits._ val als = new ALS() @@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite { val pattern = "shuffle_(\\d+)_.+\\.data".r val rddIds = resultingFiles.flatMap { f => pattern.findAllIn(f.getName()).matchData.map { _.group(1) } } - assert(rddIds.toSet.size === 4) + assert(rddIds.size === 4) } finally { sc.stop() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 8cbd652bacf31..d2fa8d0d6335d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val spark = SparkSession.builder + val spark = SparkSession.builder() .master("local[2]") .appName("TreeTests") + .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index e331c75989187..a13e7f63a9296 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.impl -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} @@ -140,9 +140,11 @@ private object PeriodicGraphCheckpointerSuite { // Instead, we check for the presence of the checkpoint files. // This test should continue to work even after this graph.isCheckpointed issue // is fixed (though it can then be simplified and not look for the files). - val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration graph.getCheckpointFiles.foreach { checkpointFile => - assert(!fs.exists(new Path(checkpointFile)), + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), "Graph checkpoint file should have been removed") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala index b2a459a68b5fa..14adf8c29fc6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.impl -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -127,9 +127,11 @@ private object PeriodicRDDCheckpointerSuite { // Instead, we check for the presence of the checkpoint files. // This test should continue to work even after this rdd.isCheckpointed issue // is fixed (though it can then be simplified and not look for the files). - val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + val hadoopConf = rdd.sparkContext.hadoopConfiguration rdd.getCheckpointFile.foreach { checkpointFile => - assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), "RDD checkpoint file should have been removed") } } diff --git a/pom.xml b/pom.xml index fff5560afea2a..60c8c8dc7a727 100644 --- a/pom.xml +++ b/pom.xml @@ -184,6 +184,9 @@ ${java.home} + + true + org.spark_project @@ -2576,6 +2579,42 @@ + + java7 + + env.JAVA_7_HOME + + + false + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + -bootclasspath + ${env.JAVA_7_HOME}/jre/lib/rt.jar + + + + + net.alchim31.maven + scala-maven-plugin + + + -javabootclasspath + ${env.JAVA_7_HOME}/jre/lib/rt.jar + + + + + + + + scala-2.11 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f08ca7001f34d..744f57c5177a3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -277,12 +277,24 @@ object SparkBuild extends PomBuild { // additional discussion and explanation. javacOptions in (Compile, compile) ++= Seq( "-target", javacJVMVersion.value - ), + ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => + if (javacJVMVersion.value == "1.7") { + Seq("-bootclasspath", s"$jdk7/jre/lib/rt.jar") + } else { + Nil + } + }, scalacOptions in Compile ++= Seq( s"-target:jvm-${scalacJVMVersion.value}", "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc - ), + ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => + if (javacJVMVersion.value == "1.7") { + Seq("-javabootclasspath", s"$jdk7/jre/lib/rt.jar") + } else { + Nil + } + }, // Implements -Xfatal-warnings, ignoring deprecation warnings. // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index a457904e7880a..92df19e804374 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -64,6 +64,21 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. note:: Experimental GaussianMixture clustering. + This class performs expectation maximization for multivariate Gaussian + Mixture Models (GMMs). A GMM represents a composite distribution of + independent Gaussian distributions with associated "mixing" weights + specifying each's contribution to the composite. + + Given a set of sample points, this class will maximize the log-likelihood + for a mixture of k Gaussians, iterating until the log-likelihood changes by + less than convergenceTol, or until it has reached the max number of iterations. + While this process is generally guaranteed to converge, it is not guaranteed + to find a global optimum. + + Note: For high-dimensional data (with many features), this algorithm may perform poorly. + This is due to high-dimensional data (a) making it difficult to cluster at all + (based on statistical/theoretical arguments) and (b) numerical issues with + Gaussian distributions. >>> from pyspark.ml.linalg import Vectors @@ -118,8 +133,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", - typeConverter=TypeConverters.toInt) + k = Param(Params._dummy(), "k", "Number of independent Gaussians in the mixture model. " + + "Must be > 1.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -227,15 +242,15 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol .. versionadded:: 1.5.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", + k = Param(Params._dummy(), "k", "The number of clusters to create. Must be > 1.", typeConverter=TypeConverters.toInt) initMode = Param(Params._dummy(), "initMode", - "the initialization algorithm. This can be either \"random\" to " + + "The initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++", typeConverter=TypeConverters.toString) - initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode", - typeConverter=TypeConverters.toInt) + initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -380,11 +395,11 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", + k = Param(Params._dummy(), "k", "The desired number of leaf clusters. Must be > 1.", typeConverter=TypeConverters.toInt) minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize", - "the minimum number of points (if >= 1.0) " + - "or the minimum proportion", + "The minimum number of points (if >= 1.0) or the minimum " + + "proportion of points (if < 1.0) of a divisible cluster.", typeConverter=TypeConverters.toFloat) @keyword_only @@ -661,7 +676,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of topics (clusters) to infer", + k = Param(Params._dummy(), "k", "The number of topics (clusters) to infer. Must be > 1.", typeConverter=TypeConverters.toInt) optimizer = Param(Params._dummy(), "optimizer", "Optimizer or inference algorithm used to estimate the LDA model. " diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 73105f881b464..9208a527d29c3 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -497,6 +497,26 @@ def mode(self, saveMode): self._jwrite = self._jwrite.mode(saveMode) return self + @since(2.0) + def outputMode(self, outputMode): + """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + + Options include: + + * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to + the sink + * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink + every time these is some updates + + .. note:: Experimental. + + >>> writer = sdf.write.outputMode('append') + """ + if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: + raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode) + self._jwrite = self._jwrite.outputMode(outputMode) + return self + @since(1.4) def format(self, source): """Specifies the underlying output data source. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 8238b8e7cde6b..cd75622cedf5e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -201,7 +201,8 @@ def __init__(self, interval): self.interval = interval def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) + return sqlContext._sc._jvm.org.apache.spark.sql.streaming.ProcessingTime.create( + self.interval) def _test(): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1790432edd5dc..0d9dd5ea2a364 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -926,7 +926,7 @@ def test_stream_save_options(self): out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') cq = df.write.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').option('path', out).startStream() + .format('parquet').outputMode('append').option('path', out).startStream() try: self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) @@ -952,8 +952,9 @@ def test_stream_save_options_overwrite(self): fake1 = os.path.join(tmpPath, 'fake1') fake2 = os.path.join(tmpPath, 'fake2') cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ - .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + .queryName('fake_query').outputMode('append') \ + .startStream(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8c8768f50bfde..9ddaf78acf91d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,7 +71,7 @@ def deco(*a, **kw): raise AnalysisException(s.split(': ', 1)[1], stackTrace) if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): raise ParseException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.ContinuousQueryException: '): + if s.startswith('org.apache.spark.sql.streaming.ContinuousQueryException: '): raise ContinuousQueryException(s.split(': ', 1)[1], stackTrace) if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): raise QueryExecutionException(s.split(': ', 1)[1], stackTrace) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 005edda2bee76..771670fa559a0 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -22,6 +22,7 @@ import java.io.File import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils @@ -88,10 +89,23 @@ object Main extends Logging { } val builder = SparkSession.builder.config(conf) - if (SparkSession.hiveClassesArePresent) { - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index af82e7a111fa8..125686030c01f 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,9 +21,11 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer - import org.apache.commons.lang3.StringEscapeUtils +import org.apache.log4j.{Level, LogManager} import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils class ReplSuite extends SparkFunSuite { @@ -99,6 +101,52 @@ class ReplSuite extends SparkFunSuite { System.clearProperty("spark.driver.port") } + test("SPARK-15236: use Hive catalog") { + // turn on the INFO log so that it is possible the code will dump INFO + // entry for using "HiveMetastore" + val rootLogger = LogManager.getRootLogger() + val logLevel = rootLogger.getLevel + rootLogger.setLevel(Level.INFO) + try { + Main.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + val output = runInterpreter("local", + """ + |spark.sql("drop table if exists t_15236") + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + // only when the config is set to hive and + // hive classes are built, we will use hive catalog. + // Then log INFO entry will show things using HiveMetastore + if (SparkSession.hiveClassesArePresent) { + assertContains("HiveMetaStore", output) + } else { + // If hive classes are not built, in-memory catalog will be used + assertDoesNotContain("HiveMetaStore", output) + } + } finally { + rootLogger.setLevel(logLevel) + } + } + + test("SPARK-15236: use in-memory catalog") { + val rootLogger = LogManager.getRootLogger() + val logLevel = rootLogger.getLevel + rootLogger.setLevel(Level.INFO) + try { + Main.conf.set(CATALOG_IMPLEMENTATION.key, "in-memory") + val output = runInterpreter("local", + """ + |spark.sql("drop table if exists t_16236") + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertDoesNotContain("HiveMetaStore", output) + } finally { + rootLogger.setLevel(logLevel) + } + } + test("simple foreach with accumulator") { val output = runInterpreter("local", """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java new file mode 100644 index 0000000000000..41e2582921198 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.InternalOutputModes; + +/** + * :: Experimental :: + * + * OutputMode is used to what data will be written to a streaming sink when there is + * new data available in a streaming DataFrame/Dataset. + * + * @since 2.0.0 + */ +@Experimental +public class OutputMode { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + * + * @since 2.0.0 + */ + public static OutputMode Append() { + return InternalOutputModes.Append$.MODULE$; + } + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + * + * @since 2.0.0 + */ + public static OutputMode Complete() { + return InternalOutputModes.Complete$.MODULE$; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala new file mode 100644 index 0000000000000..153f9f57faf42 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.streaming.OutputMode + +/** + * Internal helper class to generate objects representing various [[OutputMode]]s, + */ +private[sql] object InternalOutputModes { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + */ + case object Append extends OutputMode + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + */ + case object Complete extends OutputMode + + /** + * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will be + * written to the sink every time these is some updates. This output mode can only be used in + * queries that contain aggregations. + */ + case object Update extends OutputMode +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c5f221d7830f3..7b451baaa02b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -309,14 +309,14 @@ trait CheckAnalysis extends PredicateHelper { case s: SimpleCatalogRelation => failAnalysis( s""" - |Please enable Hive support when selecting the regular tables: + |Hive support is required to select over the following tables: |${s.catalogTable.identifier} """.stripMargin) case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) => failAnalysis( s""" - |Please enable Hive support when inserting the regular tables: + |Hive support is required to insert into the following tables: |${s.catalogTable.identifier} """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 387e5552549e6..a5b5b91e4ab3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -290,11 +290,6 @@ object TypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => - a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) - case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => - a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) - case a @ BinaryArithmetic(left @ StringType(), right) => a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 0e08bf013c8d9..8373fa336dd4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, InternalOutputModes} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.OutputMode /** * Analyzes the presence of unsupported operations in a logical plan. @@ -29,8 +30,7 @@ object UnsupportedOperationChecker { def checkForBatch(plan: LogicalPlan): Unit = { plan.foreachUp { case p if p.isStreaming => - throwError( - "Queries with streaming sources must be executed with write.startStream()")(p) + throwError("Queries with streaming sources must be executed with write.startStream()")(p) case _ => } @@ -43,10 +43,10 @@ object UnsupportedOperationChecker { "Queries without streaming sources cannot be executed with write.startStream()")(plan) } - plan.foreachUp { implicit plan => + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan - plan match { + subPlan match { case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + @@ -55,21 +55,6 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") - case Aggregate(_, _, child) if child.isStreaming => - if (outputMode == Append) { - throwError( - "Aggregations are not supported on streaming DataFrames/Datasets in " + - "Append output mode. Consider changing output mode to Update.") - } - val moreStreamingAggregates = child.find { - case Aggregate(_, _, grandchild) if grandchild.isStreaming => true - case _ => false - } - if (moreStreamingAggregates.nonEmpty) { - throwError("Multiple streaming aggregations are not supported with " + - "streaming DataFrames/Datasets") - } - case Join(left, right, joinType, _) => joinType match { @@ -119,10 +104,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if plan.children.forall(_.isStreaming) => + case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Limits are not supported on streaming DataFrames/Datasets") - case Sort(_, _, _) | SortPartitions(_, _) if plan.children.forall(_.isStreaming) => + case Sort(_, _, _) | SortPartitions(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Sorting is not supported on streaming DataFrames/Datasets") case Sample(_, _, _, _, child) if child.isStreaming => @@ -138,6 +123,27 @@ object UnsupportedOperationChecker { case _ => } } + + // Checks related to aggregations + val aggregates = plan.collect { case a @ Aggregate(_, _, _) if a.isStreaming => a } + outputMode match { + case InternalOutputModes.Append if aggregates.nonEmpty => + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + + case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => + throwError( + s"$outputMode output mode not supported when there are no streaming aggregations on " + + s"streaming DataFrames/Datasets")(plan) + + case _ => + } + if (aggregates.size > 1) { + throwError( + "Multiple streaming aggregations are not supported with " + + "streaming DataFrames/Datasets")(plan) + } } private def throwErrorIf( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 489a1c8c3facd..60525794edc5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -22,7 +22,7 @@ import java.io.IOException import scala.collection.mutable import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException @@ -105,8 +105,6 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E } } - private val fs = FileSystem.get(hadoopConfig) - // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- @@ -120,7 +118,9 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E } } else { try { - fs.mkdirs(new Path(dbDefinition.locationUri)) + val location = new Path(dbDefinition.locationUri) + val fs = location.getFileSystem(hadoopConfig) + fs.mkdirs(location) } catch { case e: IOException => throw new SparkException(s"Unable to create database ${dbDefinition.name} as failed " + @@ -147,7 +147,9 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E // Remove the database. val dbDefinition = catalog(db).db try { - fs.delete(new Path(dbDefinition.locationUri), true) + val location = new Path(dbDefinition.locationUri) + val fs = location.getFileSystem(hadoopConfig) + fs.delete(location, true) } catch { case e: IOException => throw new SparkException(s"Unable to drop database ${dbDefinition.name} as failed " + @@ -203,6 +205,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E if (tableDefinition.tableType == CatalogTableType.MANAGED) { val dir = new Path(catalog(db).db.locationUri, table) try { + val fs = dir.getFileSystem(hadoopConfig) fs.mkdirs(dir) } catch { case e: IOException => @@ -223,6 +226,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E if (getTable(db, table).tableType == CatalogTableType.MANAGED) { val dir = new Path(catalog(db).db.locationUri, table) try { + val fs = dir.getFileSystem(hadoopConfig) fs.delete(dir, true) } catch { case e: IOException => @@ -248,6 +252,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E val oldDir = new Path(catalog(db).db.locationUri, oldName) val newDir = new Path(catalog(db).db.locationUri, newName) try { + val fs = oldDir.getFileSystem(hadoopConfig) fs.rename(oldDir, newDir) } catch { case e: IOException => @@ -338,6 +343,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E p.spec.get(col).map(col + "=" + _) }.mkString("/") try { + val fs = tableDir.getFileSystem(hadoopConfig) fs.mkdirs(new Path(tableDir, partitionPath)) } catch { case e: IOException => @@ -373,6 +379,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E p.get(col).map(col + "=" + _) }.mkString("/") try { + val fs = tableDir.getFileSystem(hadoopConfig) fs.delete(new Path(tableDir, partitionPath), true) } catch { case e: IOException => @@ -409,6 +416,7 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E newSpec.get(col).map(col + "=" + _) }.mkString("/") try { + val fs = tableDir.getFileSystem(hadoopConfig) fs.rename(new Path(tableDir, oldPath), new Path(tableDir, newPath)) } catch { case e: IOException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 86883d7593412..9657f26402c01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,6 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import scala.language.existentials +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -560,6 +561,10 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(row: String, expressions: Seq[String]): String = { + if (row == null) { + // Cannot split these expressions because they are not created from a row object. + return expressions.mkString("\n") + } val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { @@ -720,15 +725,23 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder */ - def registerComment(text: String): String = { - val name = freshName("c") - val comment = if (text.contains("\n") || text.contains("\r")) { - text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") + def registerComment(text: => String): String = { + // By default, disable comments in generated code because computing the comments themselves can + // be extremely expensive in certain cases, such as deeply-nested expressions which operate over + // inputs with wide schemas. For more details on the performance issues that motivated this + // flat, see SPARK-15680. + if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = freshName("c") + val comment = if (text.contains("\n") || text.contains("\r")) { + text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") + } else { + s"// $text" + } + placeHolderToComments += (name -> comment) + s"/*$name*/" } else { - s"// $text" + "" } - placeHolderToComments += (name -> comment) - s"/*$name*/" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index e8e2a7bbabcd4..d87e6c76ed734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -434,7 +434,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other :: Nil }.mkString(", ") - /** String representation of this node without any children. */ + /** ONE line description of this node. */ def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java similarity index 69% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala rename to sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java index a4d387eae3c80..e0a54fe30ac7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -15,9 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.analysis +package org.apache.spark.sql.streaming; -sealed trait OutputMode +import org.junit.Test; -case object Append extends OutputMode -case object Update extends OutputMode +public class JavaOutputModeSuite { + + @Test + public void testOutputModes() { + OutputMode o1 = OutputMode.Append(); + assert(o1.toString().toLowerCase().contains("append")); + OutputMode o2 = OutputMode.Complete(); + assert (o2.toString().toLowerCase().contains("complete")); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a63d1770f3255..77ea29ead92cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -182,8 +182,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 22)) + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index aaeee0f2a41c4..378cca3644eab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType /** A dummy command for testing unsupported operations. */ @@ -79,35 +81,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = "commands" :: Nil) - // Aggregates: Not supported on streams in Append mode - assertSupportedInStreamingPlan( - "aggregate - batch with update output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Update) - - assertSupportedInStreamingPlan( - "aggregate - batch with append output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Append) - - assertSupportedInStreamingPlan( - "aggregate - stream with update output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Update) - - assertNotSupportedInStreamingPlan( - "aggregate - stream with append output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Append, - Seq("aggregation", "append output mode")) - // Multiple streaming aggregations not supported def aggExprs(name: String): Seq[NamedExpression] = Seq(Count("*").as(name)) assertSupportedInStreamingPlan( "aggregate - multiple batch aggregations", Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), batchRelation)), - Update) + Append) assertSupportedInStreamingPlan( "aggregate - multiple aggregations but only one streaming aggregation", @@ -209,7 +189,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.intersect(_), streamStreamSupported = false) - // Unary operations testUnaryOperatorInStreamingPlan("sort", Sort(Nil, true, _)) testUnaryOperatorInStreamingPlan("sort partitions", SortPartitions(Nil, _), expectedMsg = "sort") @@ -218,6 +197,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") + // Output modes with aggregation and non-aggregation plans + testOutputMode(Append, shouldSupportAggregation = false) + testOutputMode(Update, shouldSupportAggregation = true) + testOutputMode(Complete, shouldSupportAggregation = true) /* ======================================================================================= @@ -316,6 +299,37 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode) } + def testOutputMode( + outputMode: OutputMode, + shouldSupportAggregation: Boolean): Unit = { + + // aggregation + if (shouldSupportAggregation) { + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + + assertSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode) + + } else { + assertSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode) + + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + } + } + /** * Assert that the logical plan is supported as subplan insider a streaming plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f2ba2dfc086c2..25678e938d846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingA import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Utils /** @@ -77,7 +78,47 @@ final class DataFrameWriter private[sql](df: DataFrame) { case "ignore" => SaveMode.Ignore case "error" | "default" => SaveMode.ErrorIfExists case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + - "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error'.") + } + this + } + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be + * written to the sink + * - `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: OutputMode): DataFrameWriter = { + assertStreaming("outputMode() can only be called on continuous queries") + this.outputMode = outputMode + this + } + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to + * the sink + * - `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink + * every time these is some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: String): DataFrameWriter = { + assertStreaming("outputMode() can only be called on continuous queries") + this.outputMode = outputMode.toLowerCase match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + + "Accepted output modes are 'append' and 'complete'") } this } @@ -319,7 +360,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { checkpointPath.toUri.toString } - val sink = new MemorySink(df.schema) + val sink = new MemorySink(df.schema, outputMode) val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) resultDf.createOrReplaceTempView(queryName) val continuousQuery = df.sparkSession.sessionState.continuousQueryManager.startQuery( @@ -327,6 +368,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { checkpointLocation, df, sink, + outputMode, trigger) continuousQuery } else { @@ -352,7 +394,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { queryName, checkpointLocation, df, - dataSource.createSink(), + dataSource.createSink(outputMode), + outputMode, trigger) } } @@ -708,6 +751,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var mode: SaveMode = SaveMode.ErrorIfExists + private var outputMode: OutputMode = OutputMode.Append + private var trigger: Trigger = ProcessingTime(0L) private var extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8b6662ab1fae6..3a6ec4595e78e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql @@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.streaming.ContinuousQuery import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -93,14 +94,14 @@ private[sql] object Dataset { * to some files on storage systems, using the `read` function available on a `SparkSession`. * {{{ * val people = spark.read.parquet("...").as[Person] // Scala - * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class) // Java + * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java * }}} * * Datasets can also be created through transformations available on existing Datasets. For example, * the following creates a new Dataset by applying a filter on the existing one: * {{{ * val names = people.map(_.name) // in Scala; names is a Dataset[String] - * Dataset names = people.map((Person p) -> p.name, Encoders.STRING)) // in Java 8 + * Dataset names = people.map((Person p) -> p.name, Encoders.STRING)); // in Java 8 * }}} * * Dataset operations can also be untyped, through various domain-specific-language (DSL) @@ -110,7 +111,7 @@ private[sql] object Dataset { * To select a column from the Dataset, use `apply` method in Scala and `col` in Java. * {{{ * val ageCol = people("age") // in Scala - * Column ageCol = people.col("age") // in Java + * Column ageCol = people.col("age"); // in Java * }}} * * Note that the [[Column]] type can also be manipulated through its various functions. @@ -1703,8 +1704,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with a column dropped. - * This is a no-op if schema doesn't contain column name. + * Returns a new [[Dataset]] with a column dropped. This is a no-op if schema doesn't contain + * column name. + * + * This method can only be used to drop top level columns. the colName string is treated + * literally without further interpretation. * * @group untypedrel * @since 2.0.0 @@ -1717,15 +1721,20 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] with columns dropped. * This is a no-op if schema doesn't contain column name(s). * + * This method can only be used to drop top level columns. the colName string is treated literally + * without further interpretation. + * * @group untypedrel * @since 2.0.0 */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val remainingCols = - schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) - if (remainingCols.size == this.schema.size) { + val allColumns = queryExecution.analyzed.output + val remainingCols = allColumns.filter { attribute => + colNames.forall(n => !resolver(attribute.name, n)) + }.map(attribute => Column(attribute)) + if (remainingCols.size == allColumns.size) { toDF() } else { this.select(remainingCols: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0dc70c0b1c7fe..2e14c5d486d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,13 +30,11 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming.ContinuousQueryManager import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -645,7 +643,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** * Returns a [[ContinuousQueryManager]] that allows managing all the - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context. + * [[org.apache.spark.sql.streaming.ContinuousQuery ContinuousQueries]] active on `this` context. * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index f423e7d6b5765..b7ea2a89175fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder /** - * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * A collection of implicit methods for converting common Scala objects into [[Dataset]]s. * * @since 1.6.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 20e22baa351a9..52bedf9dbddae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -40,8 +40,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils @@ -182,7 +183,7 @@ class SparkSession private( /** * :: Experimental :: * Returns a [[ContinuousQueryManager]] that allows managing all the - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this`. + * [[ContinuousQuery ContinuousQueries]] active on `this`. * * @group basic * @since 2.0.0 @@ -694,7 +695,7 @@ object SparkSession { private[this] var userSuppliedContext: Option[SparkContext] = None - private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized { + private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { userSuppliedContext = Option(sparkContext) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index a99bc3bff6eea..6ddb1a7a1f1a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.types.StructType /** * Catalog interface for Spark. To access this, use `SparkSession.catalog`. + * + * @since 2.0.0 */ abstract class Catalog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index 0f7feb8eee7be..33032f07f7bea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -25,6 +25,14 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams // Note: all classes here are expected to be wrapped in Datasets and so must extend // DefinedByConstructorParams for the catalog to be able to create encoders for them. +/** + * A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]]. + * + * @param name name of the database. + * @param description description of the database. + * @param locationUri path (in the form of a uri) to data files. + * @since 2.0.0 + */ class Database( val name: String, @Nullable val description: String, @@ -41,6 +49,16 @@ class Database( } +/** + * A table in Spark, as returned by the `listTables` method in [[Catalog]]. + * + * @param name name of the table. + * @param database name of the database the table belongs to. + * @param description description of the table. + * @param tableType type of the table (e.g. view, table). + * @param isTemporary whether the table is a temporary table. + * @since 2.0.0 + */ class Table( val name: String, @Nullable val database: String, @@ -61,6 +79,17 @@ class Table( } +/** + * A column in Spark, as returned by `listColumns` method in [[Catalog]]. + * + * @param name name of the column. + * @param description description of the column. + * @param dataType data type of the column. + * @param nullable whether the column is nullable. + * @param isPartition whether the column is a partition column. + * @param isBucket whether the column is a bucket column. + * @since 2.0.0 + */ class Column( val name: String, @Nullable val description: String, @@ -83,9 +112,19 @@ class Column( } -// TODO(andrew): should we include the database here? +/** + * A user-defined function in Spark, as returned by `listFunctions` method in [[Catalog]]. + * + * @param name name of the function. + * @param database name of the database the function belongs to. + * @param description description of the function; description can be null. + * @param className the fully qualified class name of the function. + * @param isTemporary whether the function is a temporary function or not. + * @since 2.0.0 + */ class Function( val name: String, + @Nullable val database: String, @Nullable val description: String, val className: String, val isTemporary: Boolean) @@ -94,6 +133,7 @@ class Function( override def toString: String = { "Function[" + s"name='$name', " + + Option(database).map { d => s"database='$d', " }.getOrElse("") + Option(description).map { d => s"description='$d', " }.getOrElse("") + s"className='$className', " + s"isTemporary='$isTemporary']" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e40525287a0a1..7e3e45e56e90a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.ContinuousQuery private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => @@ -201,7 +202,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Used to plan aggregation queries that are computed incrementally as part of a - * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner + * [[ContinuousQuery]]. Currently this rule is injected into the planner * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] */ object StatefulAggregationStrategy extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 2aec9318941a7..cd9ba7c75b91d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -130,6 +130,7 @@ trait CodegenSupport extends SparkPlan { } val evaluateInputs = evaluateVariables(outputVars) // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row ctx.currentVars = outputVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala index 2e74d59c5f5b6..af1fb4c604c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala @@ -106,6 +106,6 @@ case class SortBasedAggregateExec( val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") val outputString = output.mkString("[", ",", "]") - s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" + s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d2dc80a7e42eb..091177959bedb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -599,6 +599,8 @@ case class TungstenAggregate( // create grouping key ctx.currentVars = input + // make sure that the generated code will not be splitted as multiple functions + ctx.INPUT_ROW = null val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) val vectorizedRowKeys = ctx.generateExpressions( @@ -767,9 +769,9 @@ case class TungstenAggregate( val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") val outputString = output.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" + s"Aggregate(key=$keyString, functions=$functionString, output=$outputString)" case Some(fallbackStartsAt) => - s"TungstenAggregateWithControlledFallback $groupingExpressions " + + s"AggregateWithControlledFallback $groupingExpressions " + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index f93c446007422..d617a048130e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -311,8 +311,10 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = restored) } - - val saved = StateStoreSaveExec(groupingAttributes, None, partialMerged2) + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. + val saved = StateStoreSaveExec( + groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b3beb6c85f8ed..93f1ad01bf9aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils @@ -248,15 +249,20 @@ case class DataSource( } /** Returns a sink that can be used to continually write data. */ - def createSink(): Sink = { + def createSink(outputMode: OutputMode): Sink = { providingClass.newInstance() match { - case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns) + case s: StreamSinkProvider => + s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) case parquet: parquet.ParquetFileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) + if (outputMode != OutputMode.Append) { + throw new IllegalArgumentException( + s"Data source $className does not support $outputMode output mode") + } new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 1e5bce4a75978..9c03ab28dd769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -92,20 +92,31 @@ class TextFileFormat extends FileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + assert( + requiredSchema.length <= 1, + "Text data source only produces a single data column named \"value\".") + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + + if (requiredSchema.isEmpty) { + val emptyUnsafeRow = new UnsafeRow(0) + reader.map(_ => emptyUnsafeRow) + } else { + val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + + reader.map { line => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRow + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala index 8a5089c09ec7c..9bca9941ae07e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} -import org.apache.spark.sql.util.ContinuousQueryListener -import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.sql.streaming.ContinuousQueryListener import org.apache.spark.util.ListenerBus /** @@ -32,6 +31,8 @@ import org.apache.spark.util.ListenerBus class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus) extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] { + import ContinuousQueryListener._ + sparkListenerBus.addListener(this) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index fe5f36e1cdeeb..bc0e443ca7a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.OutputMode +import org.apache.spark.sql.{InternalOutputModes, SparkSession} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.streaming.OutputMode /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -54,16 +54,19 @@ class IncrementalExecution private[sql]( /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, + case StateStoreSaveExec(keys, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) + val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), + Some(returnAllStates), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index d5e4dd8f78ac2..4d0283fbef1d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -82,10 +82,14 @@ case class StateStoreRestoreExec( case class StateStoreSaveExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], + returnAllStates: Option[Boolean], child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { override protected def doExecute(): RDD[InternalRow] = { + assert(returnAllStates.nonEmpty, + "Incorrect planning in IncrementalExecution, returnAllStates have not been set") + val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -93,29 +97,57 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + Some(sqlContext.streams.stateStoreCoordinator) + )(saveAndReturnFunc) + } + + override def output: Seq[Attribute] = child.output - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - false - } else { - true - } - } + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that this returns an iterator that pipelines the saving to store with downstream + * processing. + */ + private def saveAndReturnUpdated( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - row - } + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } } } - override def output: Seq[Attribute] = child.output + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that the saving to store is blocking; only after all the rows have been saved + * is the iterator on the update store data is generated. + */ + private def saveAndReturnAll( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + } + store.commit() + store.iterator().map(_._2.asInstanceOf[InternalRow]) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 1c92117bc6254..79fa0bf47a8d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -28,14 +28,12 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.OutputMode import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.util.ContinuousQueryListener -import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} /** @@ -55,6 +53,8 @@ class StreamExecution( val outputMode: OutputMode) extends ContinuousQuery with Logging { + import org.apache.spark.sql.streaming.ContinuousQueryListener._ + /** * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index 569907b369a54..ac510df209f0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging -import org.apache.spark.sql.ProcessingTime +import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{Clock, SystemClock} trait TriggerExecutor { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index f11a3fb969db6..2ec2a3c3c4a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.streaming.OutputMode class ConsoleSink(options: Map[String, String]) extends Sink with Logging { // Number of rows to display, by default 20 rows @@ -52,7 +53,8 @@ class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { def createSink( sqlContext: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new ConsoleSink(parameters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index bcc33ae8c88bd..4496f41615a4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,10 +24,11 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType object MemoryStream { @@ -114,35 +115,49 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { + + private case class AddedData(batchId: Long, data: Array[Row]) + /** An order list of batches that have been written to this [[Sink]]. */ @GuardedBy("this") - private val batches = new ArrayBuffer[Array[Row]]() + private val batches = new ArrayBuffer[AddedData]() /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.flatten + batches.map(_.data).flatten } - def latestBatchId: Option[Int] = synchronized { - if (batches.size == 0) None else Some(batches.size - 1) + def latestBatchId: Option[Long] = synchronized { + batches.lastOption.map(_.batchId) } - def lastBatch: Seq[Row] = synchronized { batches.last } + def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } def toDebugString: String = synchronized { - batches.zipWithIndex.map { case (b, i) => - val dataStr = try b.mkString(" ") catch { + batches.map { case AddedData(batchId, data) => + val dataStr = try data.mkString(" ") catch { case NonFatal(e) => "[Error converting to string]" } - s"$i: $dataStr" + s"$batchId: $dataStr" }.mkString("\n") } override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - if (batchId == batches.size) { - logDebug(s"Committing batch $batchId") - batches.append(data.collect()) + if (latestBatchId.isEmpty || batchId > latestBatchId.get) { + logDebug(s"Committing batch $batchId to $this") + outputMode match { + case InternalOutputModes.Append | InternalOutputModes.Update => + batches.append(AddedData(batchId, data.collect())) + + case InternalOutputModes.Complete => + batches.clear() + batches.append(AddedData(batchId, data.collect())) + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } else { logDebug(s"Skipping already committed batch: $batchId") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index ceb68622752ea..70e17b10ac3cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -125,6 +125,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) new Function( name = funcIdent.identifier, + database = funcIdent.database.orNull, description = null, // for now, this is always undefined className = metadata.getClassName, isTemporary = funcIdent.database.isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 4c7bbf04bc72a..b2db377ec7f8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.AnalyzeTableCommand import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.streaming.{ContinuousQuery, ContinuousQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -142,7 +143,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager /** - * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. + * Interface to start and stop [[ContinuousQuery]]s. */ lazy val continuousQueryManager: ContinuousQueryManager = { new ContinuousQueryManager(sparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 26285bde31ad0..d2077a07f440a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -137,7 +138,8 @@ trait StreamSinkProvider { def createSink( sqlContext: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink + partitionColumns: Seq[String], + outputMode: OutputMode): Sink } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQuery.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQuery.scala index 4d5afe2eb5f4c..451cfd85e3bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQuery.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.SparkSession /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryException.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryException.scala index fec38629d914e..5196c5a537a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryException.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryInfo.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryInfo.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryInfo.scala index b32cdc39a2259..0323cec7ef0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryInfo.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryListener.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryListener.scala index 86ff8849d8025..e6a725821c2a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryListener.scala @@ -15,22 +15,25 @@ * limitations under the License. */ -package org.apache.spark.sql.util +package org.apache.spark.sql.streaming import com.fasterxml.jackson.annotation.JsonTypeInfo import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.util.ContinuousQueryListener._ /** * :: Experimental :: * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]]. * @note The methods are not thread-safe as they may be called from different threads. + * + * @since 2.0.0 */ @Experimental abstract class ContinuousQueryListener { + import ContinuousQueryListener._ + /** * Called when a query is started. * @note This is called synchronously with @@ -38,6 +41,7 @@ abstract class ContinuousQueryListener { * that is, `onQueryStart` will be called on all listeners before * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please * don't block this method as it will block your query. + * @since 2.0.0 */ def onQueryStarted(queryStarted: QueryStarted): Unit @@ -48,10 +52,14 @@ abstract class ContinuousQueryListener { * latest no matter when this method is called. Therefore, the status of [[ContinuousQuery]] * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] * is terminated when you are processing [[QueryProgress]]. + * @since 2.0.0 */ def onQueryProgress(queryProgress: QueryProgress): Unit - /** Called when a query is stopped, with or without error */ + /** + * Called when a query is stopped, with or without error. + * @since 2.0.0 + */ def onQueryTerminated(queryTerminated: QueryTerminated): Unit } @@ -59,6 +67,7 @@ abstract class ContinuousQueryListener { /** * :: Experimental :: * Companion object of [[ContinuousQueryListener]] that defines the listener events. + * @since 2.0.0 */ @Experimental object ContinuousQueryListener { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryManager.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryManager.scala index eab557443d1d3..1bfdd2da4e69e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ContinuousQueryManager.scala @@ -15,27 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode, UnsupportedOperationChecker} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.util.ContinuousQueryListener import org.apache.spark.util.{Clock, SystemClock} /** * :: Experimental :: - * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active - * on a [[SparkSession]]. + * A class to manage all the [[ContinuousQuery]] active on a [[SparkSession]]. * * @since 2.0.0 */ @Experimental -class ContinuousQueryManager(sparkSession: SparkSession) { +class ContinuousQueryManager private[sql] (sparkSession: SparkSession) { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -147,7 +146,7 @@ class ContinuousQueryManager(sparkSession: SparkSession) { /** * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]]. + * [[ContinuousQuery]]. * * @since 2.0.0 */ @@ -175,9 +174,9 @@ class ContinuousQueryManager(sparkSession: SparkSession) { checkpointLocation: String, df: DataFrame, sink: Sink, + outputMode: OutputMode, trigger: Trigger = ProcessingTime(0), - triggerClock: Clock = new SystemClock(), - outputMode: OutputMode = Append): ContinuousQuery = { + triggerClock: Clock = new SystemClock()): ContinuousQuery = { activeQueriesLock.synchronized { if (activeQueries.contains(name)) { throw new IllegalArgumentException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala index 5a9852809c0eb..79ddf01042ef6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{Offset, Sink} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala index 2479e67e369ec..8fccd5b7a3e4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{Offset, Source} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala index 256e8a47a4665..d3fdbac576b60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import java.util.concurrent.TimeUnit @@ -29,9 +29,11 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * :: Experimental :: * Used to indicate how often results should be produced by a [[ContinuousQuery]]. + * + * @since 2.0.0 */ @Experimental -sealed trait Trigger {} +sealed trait Trigger /** * :: Experimental :: @@ -53,6 +55,8 @@ sealed trait Trigger {} * import java.util.concurrent.TimeUnit * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} + * + * @since 2.0.0 */ @Experimental case class ProcessingTime(intervalMs: Long) extends Trigger { @@ -62,6 +66,8 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { /** * :: Experimental :: * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s. + * + * @since 2.0.0 */ @Experimental object ProcessingTime { @@ -73,6 +79,8 @@ object ProcessingTime { * {{{ * df.write.trigger(ProcessingTime("10 seconds")) * }}} + * + * @since 2.0.0 */ def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { @@ -101,6 +109,8 @@ object ProcessingTime { * import scala.concurrent.duration._ * df.write.trigger(ProcessingTime(10.seconds)) * }}} + * + * @since 2.0.0 */ def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) @@ -113,6 +123,8 @@ object ProcessingTime { * {{{ * df.write.trigger(ProcessingTime.create("10 seconds")) * }}} + * + * @since 2.0.0 */ def create(interval: String): ProcessingTime = { apply(interval) @@ -126,6 +138,8 @@ object ProcessingTime { * import java.util.concurrent.TimeUnit * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} + * + * @since 2.0.0 */ def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e2dc4d86395ee..0e18ade09cbe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -609,6 +609,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df("id") == person("id")) } + test("drop top level columns that contains dot") { + val df1 = Seq((1, 2)).toDF("a.b", "a.c") + checkAnswer(df1.drop("a.b"), Row(2)) + + // Creates data set: {"a.b": 1, "a": {"b": 3}} + val df2 = Seq((1)).toDF("a.b").withColumn("a", struct(lit(3) as "b")) + // Not like select(), drop() parses the column name "a.b" literally without interpreting "." + checkAnswer(df2.drop("a.b").select("a.b"), Row(3)) + + // "`" is treated as a normal char here with no interpreting, "`a`b" is a valid column name. + assert(df2.drop("`a.b`").columns.size == 2) + } + + test("drop(name: String) search and drop all top level columns that matchs the name") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((3, 4)).toDF("a", "b") + checkAnswer(df1.join(df2), Row(1, 2, 3, 4)) + // Finds and drops all columns that match the name (case insensitive). + checkAnswer(df1.join(df2).drop("A"), Row(2, 4)) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 0d18a645f6790..52c200796ce41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.streaming.ProcessingTime class ProcessingTimeSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1ddb586d60fd9..91d93022df377 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2483,6 +2483,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-15327: fail to compile generated code with complex data structure") { + withTempDir{ dir => + val json = + """ + |{"h": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "d": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"c": {"b": {"c": [{"e": "adfgd"}], "a": [{"count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "a": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"e": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "g": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"f": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "b": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}}' + | + """.stripMargin + val rdd = sparkContext.parallelize(Array(json)) + spark.read.json(rdd).write.mode("overwrite").parquet(dir.toString) + spark.read.parquet(dir.toString).collect() + } + } + test("SPARK-14986: Outer lateral view with empty generate expression") { checkAnswer( sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 5d45cfb501c73..dd1f59880701d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1179,11 +1179,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { var message = intercept[AnalysisException] { sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") }.getMessage - assert(message.contains("Please enable Hive support when inserting the regular tables")) + assert(message.contains("Hive support is required to insert into the following tables")) message = intercept[AnalysisException] { sql(s"SELECT * FROM $tabName") }.getMessage - assert(message.contains("Please enable Hive support when selecting the regular tables")) + assert(message.contains("Hive support is required to select over the following tables")) } } @@ -1205,11 +1205,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { var message = intercept[AnalysisException] { sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") }.getMessage - assert(message.contains("Please enable Hive support when inserting the regular tables")) + assert(message.contains("Hive support is required to insert into the following tables")) message = intercept[AnalysisException] { sql(s"SELECT * FROM $tabName") }.getMessage - assert(message.contains("Please enable Hive support when selecting the regular tables")) + assert(message.contains("Hive support is required to select over the following tables")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 500d8ff55a9a8..9f35c02d48762 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -446,13 +446,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(BigDecimal("92233720368547758071.2")) + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index b5e51e963f1b6..7b6981f95e9dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec @@ -31,6 +28,7 @@ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.Utils class TextSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reading text file") { verifyFrame(spark.read.format("text").load(testFile)) @@ -126,6 +124,19 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-14343: select partitioning column") { + withTempPath { dir => + val path = dir.getCanonicalPath + val ds1 = spark.range(1).selectExpr("CONCAT('val_', id)") + ds1.write.text(s"$path/part=a") + ds1.write.text(s"$path/part=b") + + checkDataset( + spark.read.format("text").load(path).select($"part"), + Row("a"), Row("b")) + } + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index db32b6b6befb7..97adffa8ce101 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -1,25 +1,25 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.execution.joins import scala.reflect.ClassTag -import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan @@ -44,11 +44,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { */ override def beforeAll(): Unit = { super.beforeAll() - val conf = new SparkConf() - .setMaster("local-cluster[2,1,1024]") - .setAppName("testing") - val sc = new SparkContext(conf) - spark = SparkSession.builder.getOrCreate() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 7f99d303ba08a..00d5e051de357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.{CountDownLatch, TimeUnit} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.ProcessingTime +import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{Clock, ManualClock, SystemClock} class ProcessingTimeExecutorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index cd434f7887db6..aec0312c4003e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.internal import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.{Column, Database, Function, Table} import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ @@ -207,6 +207,14 @@ class CatalogSuite assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(funcNames2.contains("my_temp_func")) + + // Make sure database is set properly. + assert( + spark.catalog.listFunctions("my_db1").collect().map(_.database).toSet == Set("my_db1", null)) + assert( + spark.catalog.listFunctions("my_db2").collect().map(_.database).toSet == Set("my_db2", null)) + + // Remove the function and make sure they no longer appear. dropFunction("my_func1", Some("my_db1")) dropTempFunction("my_temp_func") val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet @@ -248,9 +256,11 @@ class CatalogSuite } test("Function.toString") { - assert(new Function("nama", "commenta", "classNameAh", isTemporary = true).toString == - "Function[name='nama', description='commenta', className='classNameAh', isTemporary='true']") - assert(new Function("nama", null, "classNameAh", isTemporary = false).toString == + assert( + new Function("nama", "databasa", "commenta", "classNameAh", isTemporary = true).toString == + "Function[name='nama', database='databasa', description='commenta', " + + "className='classNameAh', isTemporary='true']") + assert(new Function("nama", null, null, "classNameAh", isTemporary = false).toString == "Function[name='nama', className='classNameAh', isTemporary='false']") } @@ -268,7 +278,7 @@ class CatalogSuite test("catalog classes format in Dataset.show") { val db = new Database("nama", "descripta", "locata") val table = new Table("nama", "databasa", "descripta", "typa", isTemporary = false) - val function = new Function("nama", "descripta", "classa", isTemporary = false) + val function = new Function("nama", "databasa", "descripta", "classa", isTemporary = false) val column = new Column( "nama", "descripta", "typa", nullable = false, isPartition = true, isBucket = true) val dbFields = ScalaReflection.getConstructorParameterValues(db) @@ -277,7 +287,7 @@ class CatalogSuite val columnFields = ScalaReflection.getConstructorParameterValues(column) assert(dbFields == Seq("nama", "descripta", "locata")) assert(tableFields == Seq("nama", "databasa", "descripta", "typa", false)) - assert(functionFields == Seq("nama", "descripta", "classa", false)) + assert(functionFields == Seq("nama", "databasa", "descripta", "classa", false)) assert(columnFields == Seq("nama", "descripta", "typa", false, true, true)) val dbString = CatalogImpl.makeDataset(Seq(db), spark).showString(10) val tableString = CatalogImpl.makeDataset(Seq(table), spark).showString(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryListenerSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryListenerSuite.scala index a144303e4db20..9e58fe8c73f85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryListenerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.util +package org.apache.spark.sql.streaming import java.util.concurrent.ConcurrentLinkedQueue @@ -29,13 +29,14 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated} +import org.apache.spark.sql.streaming.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated} import org.apache.spark.util.JsonProtocol -class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + +class ContinuousQueryListenerSuite extends StreamTest with BeforeAndAfter { import testImplicits._ + import ContinuousQueryListener._ after { spark.streams.active.foreach(_.stop()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index a743cdde408fc..c1e4970b3a877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -28,12 +28,11 @@ import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ContinuousQueryManagerSuite extends StreamTest with BeforeAndAfter { import AwaitTerminationTester._ import testImplicits._ @@ -232,20 +231,20 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = { failAfter(streamingTimeout) { val queries = withClue("Error starting queries") { - datasets.map { ds => + datasets.zipWithIndex.map { case (ds, i) => @volatile var query: StreamExecution = null try { val df = ds.toDF val metadataRoot = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = spark - .streams - .startQuery( - StreamExecution.nextName, - metadataRoot, - df, - new MemorySink(df.schema)) - .asInstanceOf[StreamExecution] + Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath + query = + df.write + .format("memory") + .queryName(s"query$i") + .option("checkpointLocation", metadataRoot) + .outputMode("append") + .startStream() + .asInstanceOf[StreamExecution] } catch { case NonFatal(e) => if (query != null) query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala index ea93ff6ce580e..38b595ccd0385 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkException -import org.apache.spark.sql.StreamTest import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} -import org.apache.spark.sql.test.SharedSQLContext -class ContinuousQuerySuite extends StreamTest with SharedSQLContext { + +class ContinuousQuerySuite extends StreamTest { import AwaitTerminationTester._ import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 3d8dcaf5a5322..1c73208736f78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils -class FileStreamSinkSuite extends StreamTest with SharedSQLContext { +class FileStreamSinkSuite extends StreamTest { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 1d784f1f4ee85..f681b8878d9ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -137,7 +137,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { val valueSchema = new StructType().add("value", StringType) } -class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { +class FileStreamSourceSuite extends FileStreamSourceTest { import testImplicits._ @@ -594,7 +594,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } } -class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { +class FileStreamSourceStressTestSuite extends FileStreamSourceTest { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 4efb7cf52d4a4..1c0fb34dd0191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -23,9 +23,7 @@ import java.util.UUID import scala.util.Random import scala.util.control.NonFatal -import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException, StreamTest} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils /** @@ -38,7 +36,7 @@ import org.apache.spark.util.Utils * * At the end, the resulting files are loaded and the answer is checked. */ -class FileStressSuite extends StreamTest with SharedSQLContext { +class FileStressSuite extends StreamTest { import testImplicits._ testQuietly("fault tolerance stress test - unpartitioned output") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 09c35bbf2c34b..df76499fa2801 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -17,27 +17,132 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils -class MemorySinkSuite extends StreamTest with SharedSQLContext { +class MemorySinkSuite extends StreamTest with BeforeAndAfter { + import testImplicits._ - test("registering as a table") { - testRegisterAsTable() + after { + sqlContext.streams.active.foreach(_.stop()) } - ignore("stress test") { - // Ignore the stress test as it takes several minutes to run - (0 until 1000).foreach(_ => testRegisterAsTable()) + test("directly add data in Append output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Append) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) } - private def testRegisterAsTable(): Unit = { + test("directly add data in Update output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Update) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) + } + + test("directly add data in Complete output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Complete) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) // new data should replace old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 7 to 9) + } + + + test("registering as a table in Append output mode") { val input = MemoryStream[Int] val query = input.toDF().write .format("memory") + .outputMode("append") .queryName("memStream") .startStream() input.addData(1, 2, 3) @@ -56,6 +161,57 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { query.stop() } + test("registering as a table in Complete output mode") { + val input = MemoryStream[Int] + val query = input.toDF() + .groupBy("value") + .count() + .write + .format("memory") + .outputMode("complete") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L)) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L), (4, 1L), (5, 1L), (6, 1L)) + + query.stop() + } + + ignore("stress test") { + // Ignore the stress test as it takes several minutes to run + (0 until 1000).foreach { _ => + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + } + test("error when no name is specified") { val error = intercept[AnalysisException] { val input = MemoryStream[Int] @@ -88,4 +244,15 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { .startStream() } } + + private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = { + checkAnswer( + sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema), + intsToDF(expected)(schema)) + } + + private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { + require(schema.fields.size === 1) + sqlContext.createDataset(seq).toDF(schema.fieldNames.head) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala index 81760d2aa8205..7f2972edea727 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.StreamTest import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext -class MemorySourceStressSuite extends StreamTest with SharedSQLContext { +class MemorySourceStressSuite extends StreamTest { import testImplicits._ test("memory stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index ae89a6887a6d1..9414b1ce4019b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.ManualClock -class StreamSuite extends StreamTest with SharedSQLContext { +class StreamSuite extends StreamTest { import testImplicits._ @@ -235,6 +235,13 @@ class StreamSuite extends StreamTest with SharedSQLContext { spark.experimental.extraStrategies = Nil } } + + test("output mode API in Scala") { + val o1 = OutputMode.Append + assert(o1 === InternalOutputModes.Append) + val o2 = OutputMode.Complete + assert(o2 === InternalOutputModes.Complete) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 1ab562f873341..dd8672aa641d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import java.lang.Thread.UncaughtExceptionHandler @@ -33,11 +33,12 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode} +import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} /** @@ -64,13 +65,11 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds - val outputMode: OutputMode = Append - /** A trait for actions that can be performed while testing a streaming DataFrame. */ trait StreamAction @@ -191,14 +190,17 @@ trait StreamTest extends QueryTest with Timeouts { * Note that if the stream is not explicitly started before an action that requires it to be * running then it will be automatically started before performing any other actions. */ - def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = { + def testStream( + _stream: Dataset[_], + outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = { + val stream = _stream.toDF() var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = new MemorySink(stream.schema) + val sink = new MemorySink(stream.schema, outputMode) @volatile var streamDeathCause: Throwable = null @@ -297,9 +299,9 @@ trait StreamTest extends QueryTest with Timeouts { metadataRoot, stream, sink, + outputMode, trigger, - triggerClock, - outputMode = outputMode) + triggerClock) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { @@ -429,7 +431,7 @@ trait StreamTest extends QueryTest with Timeouts { } } - val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } @@ -523,7 +525,7 @@ trait StreamTest extends QueryTest with Timeouts { case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E]) extends ExpectedBehavior - private val DEFAULT_TEST_TIMEOUT = 1 second + private val DEFAULT_TEST_TIMEOUT = 1.second def test( expectedBehavior: ExpectedBehavior, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 7104d01c4a2a1..1f174aee8ce08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -20,19 +20,18 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.catalyst.analysis.Update +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StreamTest with SharedSQLContext with BeforeAndAfterAll { +class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { override def afterAll(): Unit = { super.afterAll() @@ -41,9 +40,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be import testImplicits._ - override val outputMode = Update - - test("simple count") { + test("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -52,7 +49,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, 3), CheckLastBatch((3, 1)), AddData(inputData, 3, 2), @@ -67,6 +64,46 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be ) } + test("simple count, complete mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Complete)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 2), + CheckLastBatch((3, 1), (2, 1)), + StopStream, + StartStream(), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 2), (2, 2), (1, 1)), + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4), (3, 2), (2, 2), (1, 1)) + ) + } + + test("simple count, append mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val e = intercept[AnalysisException] { + testStream(aggregated, Append)() + } + Seq("append", "not supported").foreach { m => + assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + test("multiple keys") { val inputData = MemoryStream[Int] @@ -76,7 +113,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, 1, 2), CheckLastBatch((1, 2, 1), (2, 3, 1)), AddData(inputData, 1, 2), @@ -101,7 +138,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( StartStream(), AddData(inputData, 1, 2, 3, 4), ExpectFailure[SparkException](), @@ -114,7 +151,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala index 288f6dc59741e..a2aac69064f9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -90,17 +90,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def createSink( spark: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns, outputMode) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } } } -class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -416,6 +417,39 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(e.getMessage == "mode() can only be called on non-continuous queries;") } + test("check outputMode(OutputMode) can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.outputMode(OutputMode.Append)) + Seq("outputmode", "continuous queries").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("check outputMode(string) can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.outputMode("append")) + Seq("outputmode", "continuous queries").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("check outputMode(string) throws exception on unsupported modes") { + def testError(outputMode: String): Unit = { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) + Seq("output mode", "unknown", outputMode).foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + testError("Update") + testError("Xyz") + } + test("check bucketBy() can only be called on non-continuous queries") { val df = spark.read .format("org.apache.spark.sql.streaming.test") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a4bbe96cf8057..d56bede0cc2fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource} import org.apache.spark.sql.expressions.Window @@ -282,15 +282,12 @@ object SetWarehouseLocationTest extends Logging { val hiveWarehouseLocation = Utils.createTempDir() hiveWarehouseLocation.delete() - val conf = new SparkConf() - conf.set("spark.ui.enabled", "false") // We will use the value of spark.sql.warehouse.dir override the // value of hive.metastore.warehouse.dir. - conf.set("spark.sql.warehouse.dir", warehouseLocation.toString) - conf.set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) - - val sparkSession = SparkSession.builder - .config(conf) + val sparkSession = SparkSession.builder() + .config("spark.ui.enabled", "false") + .config("spark.sql.warehouse.dir", warehouseLocation.toString) + .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) .enableHiveSupport() .getOrCreate() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4b51f021bfa0c..2a9b06b75efa1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1560,4 +1560,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT * FROM tbl"), Row(1, "a")) } } + + test("spark-15557 promote string test") { + withTable("tbl") { + sql("CREATE TABLE tbl(c1 string, c2 string)") + sql("insert into tbl values ('3', '2.3')") + checkAnswer( + sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"), + Row(204.0) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"), + Row(234.6) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + c1) * c2 from tbl"), + Row(234.6) + ) + } + } }