From c632bdc01f51bb253fa3dc258ffa7fdecf814d35 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Tue, 22 Mar 2016 10:17:08 -0700
Subject: [PATCH 01/26] [SPARK-14029][SQL] Improve BooleanSimplification
 optimization by implementing `Not` canonicalization.

## What changes were proposed in this pull request?

Currently, the **BooleanSimplification** optimization can handle the following cases.
* a && (!a || b ) ==> a && b
* a && (b || !a ) ==> a && b

However, it cannot handle the following cases, because the comparison between their canonicalized forms fails.
* a < 1 && (!(a < 1) || b) ==> (a < 1) && b
* a <= 1 && (!(a <= 1) || b) ==> (a <= 1) && b
* a > 1 && (!(a > 1) || b) ==> (a > 1) && b
* a >= 1 && (!(a >= 1) || b) ==> (a >= 1) && b

This PR implements the cases above, and also the following.
* a < 1 && ((a >= 1) || b ) ==> (a < 1) && b
* a <= 1 && ((a > 1) || b ) ==> (a <= 1) && b
* a > 1 && ((a <= 1) || b) ==> (a > 1) && b
* a >= 1 && ((a < 1) || b) ==> (a >= 1) && b

## How was this patch tested?

Pass the Jenkins tests, including the new test cases in BooleanSimplificationSuite.

Author: Dongjoon Hyun

Closes #11851 from dongjoon-hyun/SPARK-14029.
---
 .../catalyst/expressions/Canonicalize.scala   |  9 ++++++
 .../expressions/ExpressionSetSuite.scala      |  6 ++++
 .../BooleanSimplificationSuite.scala          | 28 +++++++++++++++++++
 3 files changed, 43 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index ae1f6006135bb..07ba7d5e4a849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -71,6 +71,15 @@ object Canonicalize extends {
     case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
     case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)

+    case Not(GreaterThan(l, r)) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+    case Not(GreaterThan(l, r)) => LessThanOrEqual(l, r)
+    case Not(LessThan(l, r)) if l.hashCode() > r.hashCode() => LessThan(r, l)
+    case Not(LessThan(l, r)) => GreaterThanOrEqual(l, r)
+    case Not(GreaterThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
+    case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r)
+    case Not(LessThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+    case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)
+
     case _ => e
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
index 0b350c6a98255..60939ee0eda5d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -74,6 +74,12 @@ class ExpressionSetSuite extends SparkFunSuite {
   setTest(1, aUpper > bUpper, bUpper < aUpper)
   setTest(1, aUpper >= bUpper, bUpper <= aUpper)

+  // `Not` canonicalization
+  setTest(1, Not(aUpper > 1), aUpper <= 1, Not(Literal(1) < aUpper), Literal(1) >= aUpper)
+  setTest(1, Not(aUpper < 1), aUpper >= 1, Not(Literal(1) > aUpper), Literal(1) <= aUpper)
+  setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper)
+  setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper)
+
   test("add to / remove from set") {
     val initialSet = ExpressionSet(aUpper + 1 :: Nil)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 47b79fe462457..2ab31eea8ab38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -99,6 +99,34 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
     checkCondition(('b || !'a ) && 'a, 'b && 'a)
   }

+  test("a < 1 && (!(a < 1) || b)") {
+    checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b)
+    checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b)
+
+    checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b)
+    checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b)
+
+    checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b)
+    checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b)
+
+    checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b)
+    checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b)
+  }
+
+  test("a < 1 && ((a >= 1) || b)") {
+    checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b)
+    checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b)
+
+    checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b)
+    checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b)
+
+    checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b)
+    checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b)
+
+    checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b)
+    checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b)
+  }
+
   test("DeMorgan's law") {
     checkCondition(!('a && 'b), !'a || !'b)

From caea15214571d9b12dcf1553e5c1cc8b83a8ba5b Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 22 Mar 2016 10:18:42 -0700
Subject: [PATCH 02/26] [SPARK-13985][SQL] Deterministic batches with ids

This PR relaxes the requirements of a `Sink` for structured streaming so that it only needs to
support idempotent appending of data. Previously the `Sink` had to be able to transactionally
append data while recording an opaque offset that indicated how far into the stream we had
processed.

In order to do this, a new write-ahead log has been added to stream execution, which records the
offsets that will be present in each batch. The log is created in the newly added
`checkpointLocation`, which defaults to `${spark.sql.streaming.checkpointLocation}/${queryName}`
but can be overridden by setting `checkpointLocation` in `DataFrameWriter`.

In addition to making sinks easier to write, the addition of batchIds and a checkpoint location is
done in anticipation of integration with the `StateStore` (#11645).

Author: Michael Armbrust

Closes #11804 from marmbrus/batchIds.
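As a concrete illustration of the relaxed contract described above, the sketch below shows what an idempotent sink looks like against the new `Sink.addBatch(batchId, data)` interface introduced by this patch. It is illustrative only and not part of the change; the class name and the in-memory storage are hypothetical, but the replay-skipping pattern is the same one the updated `MemorySink` uses further down in this diff.

```scala
// Sketch of an idempotent Sink under the new addBatch(batchId, data) contract.
// Only the Sink trait and DataFrame come from this patch; everything else is illustrative.
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.streaming.Sink

class ExampleIdempotentSink extends Sink {
  // Batches that have already been made durable, indexed by batchId.
  private val committed = new ArrayBuffer[Array[Row]]()

  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    if (batchId == committed.size) {
      // First attempt for this batchId: materialize the deterministic batch and store it.
      committed.append(data.collect())
    } else {
      // Re-delivery after a failure: the batch was already added, so do nothing.
      // Skipping replayed batches is what preserves exactly-once semantics.
    }
  }
}
```

Because batch ids are deterministic and the offsets for each batch are logged before processing starts, a restarted query can safely re-run its last batch and rely on the sink to ignore the duplicate.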
--- .../spark/sql/ContinuousQueryManager.scala | 8 +- .../apache/spark/sql/DataFrameWriter.scala | 11 +- .../org/apache/spark/sql/SinkStatus.scala | 2 +- .../execution/datasources/DataSource.scala | 3 +- .../execution/streaming/CompositeOffset.scala | 12 ++ .../streaming/FileStreamSource.scala | 24 +-- .../execution/streaming/HDFSMetadataLog.scala | 7 +- .../spark/sql/execution/streaming/Sink.scala | 30 +-- .../sql/execution/streaming/Source.scala | 10 +- .../execution/streaming/StreamExecution.scala | 193 ++++++++++++------ .../execution/streaming/StreamProgress.scala | 52 ++--- .../sql/execution/streaming/memory.scala | 85 ++++---- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../org/apache/spark/sql/StreamTest.scala | 18 +- .../ContinuousQueryManagerSuite.scala | 8 +- .../sql/streaming/ContinuousQuerySuite.scala | 11 +- .../DataFrameReaderWriterSuite.scala | 55 +++-- .../sql/streaming/FileStreamSourceSuite.scala | 18 -- .../util/ContinuousQueryListenerSuite.scala | 6 +- 19 files changed, 319 insertions(+), 241 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 0a156ea99a297..fa8219bbed0d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -164,13 +164,17 @@ class ContinuousQueryManager(sqlContext: SQLContext) { } /** Start a query */ - private[sql] def startQuery(name: String, df: DataFrame, sink: Sink): ContinuousQuery = { + private[sql] def startQuery( + name: String, + checkpointLocation: String, + df: DataFrame, + sink: Sink): ContinuousQuery = { activeQueriesLock.synchronized { if (activeQueries.contains(name)) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } - val query = new StreamExecution(sqlContext, name, df.logicalPlan, sink) + val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink) query.start() activeQueries.put(name, query) query diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7ed1c51360f0c..c07bd0e7b7175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,6 +21,8 @@ import java.util.Properties import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -251,8 +253,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { options = extraOptions.toMap, partitionColumns = normalizedParCols.getOrElse(Nil)) + val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) + val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { + new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + }) df.sqlContext.sessionState.continuousQueryManager.startQuery( - extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink()) + queryName, + checkpointLocation, + df, + dataSource.createSink()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala index 
ce21451b2c9c7..5a9852809c0eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala @@ -31,4 +31,4 @@ import org.apache.spark.sql.execution.streaming.{Offset, Sink} @Experimental class SinkStatus private[sql]( val description: String, - val offset: Option[Offset]) + val offset: Offset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index e2a14edc54a10..fac2a64726618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -162,7 +162,8 @@ case class DataSource( paths = files, userSpecifiedSchema = Some(dataSchema), className = className, - options = options.filterKeys(_ != "path")).resolveRelation())) + options = + new CaseInsensitiveMap(options.filterKeys(_ != "path"))).resolveRelation())) } new FileStreamSource( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala index 59a52a3d59d91..e48ac598929ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -52,6 +52,18 @@ case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { case i if i == 0 => 0 case i if i > 0 => 1 } + + /** + * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * sources. + * + * This method is typically used to associate a serialized offset with actual sources (which + * cannot be serialized). + */ + def toStreamProgress(sources: Seq[Source]): StreamProgress = { + assert(sources.size == offsets.size) + new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } + } } object CompositeOffset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 787e93f543963..d13b1a6166798 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -109,20 +109,16 @@ class FileStreamSource( /** * Returns the next batch of data that is available after `start`, if any is available. 
*/ - override def getNextBatch(start: Option[Offset]): Option[Batch] = { + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) - val end = fetchMaxOffset() - val endId = end.offset - - if (startId + 1 <= endId) { - val files = metadataLog.get(Some(startId + 1), endId).map(_._2).flatten - logDebug(s"Return files from batches ${startId + 1}:$endId") - logDebug(s"Streaming ${files.mkString(", ")}") - Some(new Batch(end, dataFrameBuilder(files))) - } - else { - None - } + val endId = end.asInstanceOf[LongOffset].offset + + assert(startId <= endId) + val files = metadataLog.get(Some(startId + 1), endId).map(_._2).flatten + logDebug(s"Return files from batches ${startId + 1}:$endId") + logDebug(s"Streaming ${files.mkString(", ")}") + dataFrameBuilder(files) + } private def fetchAllFiles(): Seq[String] = { @@ -130,4 +126,6 @@ class FileStreamSource( .filterNot(_.getPath.getName.startsWith("_")) .map(_.getPath.toUri.toString) } + + override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index ac2842b6d5df9..298b5d292e8e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -27,6 +27,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission +import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.SQLContext @@ -42,7 +43,9 @@ import org.apache.spark.sql.SQLContext * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing * files in a directory always shows the latest files. */ -class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) extends MetadataLog[T] { +class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) + extends MetadataLog[T] + with Logging { private val metadataPath = new Path(path) @@ -113,6 +116,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) extends try { // Try to commit the batch // It will fail if there is an existing file (someone has committed the batch) + logDebug(s"Attempting to write log #${batchFile(batchId)}") fc.rename(tempPath, batchFile(batchId), Options.Rename.NONE) return } catch { @@ -161,6 +165,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) extends val bytes = IOUtils.toByteArray(input) Some(serializer.deserialize[T](ByteBuffer.wrap(bytes))) } else { + logDebug(s"Unable to find batch $batchMetadataFile") None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index e3b2d2f67ee0c..25015d58f75ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -17,31 +17,19 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.DataFrame + /** - * An interface for systems that can collect the results of a streaming query. 
- * - * When new data is produced by a query, a [[Sink]] must be able to transactionally collect the - * data and update the [[Offset]]. In the case of a failure, the sink will be recreated - * and must be able to return the [[Offset]] for all of the data that is made durable. - * This contract allows Spark to process data with exactly-once semantics, even in the case - * of failures that require the computation to be restarted. + * An interface for systems that can collect the results of a streaming query. In order to preserve + * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same + * batch. */ trait Sink { - /** - * Returns the [[Offset]] for all data that is currently present in the sink, if any. This - * function will be called by Spark when restarting execution in order to determine at which point - * in the input stream computation should be resumed from. - */ - def currentOffset: Option[Offset] /** - * Accepts a new batch of data as well as a [[Offset]] that denotes how far in the input - * data computation has progressed to. When computation restarts after a failure, it is important - * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that - * has been persisted durably. Note that this does not necessarily have to be the - * [[Offset]] for the most recent batch of data that was given to the sink. For example, - * it is valid to buffer data before persisting, as long as the [[Offset]] is stored - * transactionally as data is eventually persisted. + * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if + * this method is called more than once with the same batchId (which will happen in the case of + * failures), then `data` should only be added once. */ - def addBatch(batch: Batch): Unit + def addBatch(batchId: Long, data: DataFrame): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 25922979ac83e..6457f928ed887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -29,8 +30,13 @@ trait Source { /** Returns the schema of the data from this source */ def schema: StructType + /** Returns the maximum available offset for this source. */ + def getOffset: Option[Offset] + /** - * Returns the next batch of data that is available after `start`, if any is available. + * Returns the data that is is between the offsets (`start`, `end`]. When `start` is `None` then + * the batch should begin with the first available record. This method must always return the + * same data for a particular `start` and `end` pair. 
*/ - def getNextBatch(start: Option[Offset]): Option[Batch] + def getBatch(start: Option[Offset], end: Offset): DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 0062b7fc75c4a..c5fefb5346bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -41,6 +43,7 @@ import org.apache.spark.sql.util.ContinuousQueryListener._ class StreamExecution( val sqlContext: SQLContext, override val name: String, + val checkpointRoot: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink) extends ContinuousQuery with Logging { @@ -52,13 +55,28 @@ class StreamExecution( /** Minimum amount of time in between the start of each batch. */ private val minBatchTime = 10 - /** Tracks how much data we have processed from each input source. */ - private[sql] val streamProgress = new StreamProgress + /** + * Tracks how much data we have processed and committed to the sink or state store from each + * input source. + */ + private[sql] var committedOffsets = new StreamProgress + + /** + * Tracks the offsets that are available to be processed, but have not yet be committed to the + * sink. + */ + private var availableOffsets = new StreamProgress + + /** The current batchId or -1 if execution has not yet been initialized. */ + private var currentBatchId: Long = -1 /** All stream sources present the query plan. */ private val sources = logicalPlan.collect { case s: StreamingRelation => s.source } + /** A list of unique sources in the query plan. */ + private val uniqueSources = sources.distinct + /** Defines the internal state of execution */ @volatile private var state: State = INITIALIZED @@ -74,20 +92,34 @@ class StreamExecution( override def run(): Unit = { runBatches() } } + /** + * A write-ahead-log that records the offsets that are present in each batch. In order to ensure + * that a given batch will always consist of the same data, we write to this log *before* any + * processing is done. Thus, the Nth record in this log indicated data that is currently being + * processed and the N-1th entry indicates which offsets have been durably committed to the sink. + */ + private val offsetLog = + new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets")) + /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE /** Returns current status of all the sources. */ override def sourceStatuses: Array[SourceStatus] = { - sources.map(s => new SourceStatus(s.toString, streamProgress.get(s))).toArray + sources.map(s => new SourceStatus(s.toString, availableOffsets.get(s))).toArray } /** Returns current status of the sink. */ - override def sinkStatus: SinkStatus = new SinkStatus(sink.toString, sink.currentOffset) + override def sinkStatus: SinkStatus = + new SinkStatus(sink.toString, committedOffsets.toCompositeOffset(sources)) /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. 
*/ override def exception: Option[ContinuousQueryException] = Option(streamDeathCause) + /** Returns the path of a file with `name` in the checkpoint directory. */ + private def checkpointFile(name: String): String = + new Path(new Path(checkpointRoot), name).toUri.toString + /** * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event * has been posted to all the listeners. @@ -102,7 +134,7 @@ class StreamExecution( * Repeatedly attempts to run batches as data arrives. * * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted - * so that listeners are guaranteed to get former event before the latter. Furthermore, this + * such that listeners are guaranteed to get a start event before a termination. Furthermore, this * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. */ private def runBatches(): Unit = { @@ -118,9 +150,10 @@ class StreamExecution( // While active, repeatedly attempt to run batches. SQLContext.setActive(sqlContext) populateStartOffsets() - logInfo(s"Stream running at $streamProgress") + logDebug(s"Stream running from $committedOffsets to $availableOffsets") while (isActive) { - attemptBatch() + if (dataAvailable) runBatch() + commitAndConstructNextBatch() Thread.sleep(minBatchTime) // TODO: Could be tighter } } catch { @@ -130,7 +163,7 @@ class StreamExecution( this, s"Query $name terminated with exception: ${e.getMessage}", e, - Some(streamProgress.toCompositeOffset(sources))) + Some(committedOffsets.toCompositeOffset(sources))) logError(s"Query $name terminated with error", e) } finally { state = TERMINATED @@ -142,48 +175,99 @@ class StreamExecution( /** * Populate the start offsets to start the execution at the current offsets stored in the sink - * (i.e. avoid reprocessing data that we have already processed). + * (i.e. avoid reprocessing data that we have already processed). This function must be called + * before any processing occurs and will populate the following fields: + * - currentBatchId + * - committedOffsets + * - availableOffsets */ private def populateStartOffsets(): Unit = { - sink.currentOffset match { - case Some(c: CompositeOffset) => - val storedProgress = c.offsets - val sources = logicalPlan collect { - case StreamingRelation(source, _) => source + offsetLog.getLatest() match { + case Some((batchId, nextOffsets)) => + logInfo(s"Resuming continuous query, starting with batch $batchId") + currentBatchId = batchId + 1 + availableOffsets = nextOffsets.toStreamProgress(sources) + logDebug(s"Found possibly uncommitted offsets $availableOffsets") + + offsetLog.get(batchId - 1).foreach { + case lastOffsets => + committedOffsets = lastOffsets.toStreamProgress(sources) + logDebug(s"Resuming with committed offsets: $committedOffsets") } - assert(sources.size == storedProgress.size) - sources.zip(storedProgress).foreach { case (source, offset) => - offset.foreach(streamProgress.update(source, _)) - } case None => // We are starting this stream for the first time. - case _ => throw new IllegalArgumentException("Expected composite offset from sink") + logInfo(s"Starting new continuous query.") + currentBatchId = 0 + commitAndConstructNextBatch() } } /** - * Checks to see if any new data is present in any of the sources. When new data is available, - * a batch is executed and passed to the sink, updating the currentOffsets. + * Returns true if there is any new data available to be processed. 
*/ - private def attemptBatch(): Unit = { + private def dataAvailable: Boolean = { + availableOffsets.exists { + case (source, available) => + committedOffsets + .get(source) + .map(committed => committed < available) + .getOrElse(true) + } + } + + /** + * Queries all of the sources to see if any new data is available. When there is new data the + * batchId counter is incremented and a new log entry is written with the newest offsets. + * + * Note that committing the offsets for a new batch implicitly marks the previous batch as + * finished and thus this method should only be called when all currently available data + * has been written to the sink. + */ + private def commitAndConstructNextBatch(): Boolean = { + // Update committed offsets. + committedOffsets ++= availableOffsets + + // Check to see what new data is available. + val newData = uniqueSources.flatMap(s => s.getOffset.map(o => s -> o)) + availableOffsets ++= newData + + if (dataAvailable) { + assert( + offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + currentBatchId += 1 + logInfo(s"Committed offsets for batch $currentBatchId.") + true + } else { + false + } + } + + /** + * Processes any data available between `availableOffsets` and `committedOffsets`. + */ + private def runBatch(): Unit = { val startTime = System.nanoTime() - // A list of offsets that need to be updated if this batch is successful. - // Populated while walking the tree. - val newOffsets = new ArrayBuffer[(Source, Offset)] + // Request unprocessed data from all sources. + val newData = availableOffsets.flatMap { + case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => + val current = committedOffsets.get(source) + val batch = source.getBatch(current, available) + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + }.toMap + // A list of attributes that will need to be updated. var replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. 
val withNewSources = logicalPlan transform { case StreamingRelation(source, output) => - val prevOffset = streamProgress.get(source) - val newBatch = source.getNextBatch(prevOffset) - - newBatch.map { batch => - newOffsets += ((source, batch.end)) - val newPlan = batch.data.logicalPlan - - assert(output.size == newPlan.output.size) + newData.get(source).map { data => + val newPlan = data.logicalPlan + assert(output.size == newPlan.output.size, + s"Invalid batch: ${output.mkString(",")} != ${newPlan.output.mkString(",")}") replacements ++= output.zip(newPlan.output) newPlan }.getOrElse { @@ -197,35 +281,24 @@ class StreamExecution( case a: Attribute if replacementMap.contains(a) => replacementMap(a) } - if (newOffsets.nonEmpty) { - val optimizerStart = System.nanoTime() - - lastExecution = new QueryExecution(sqlContext, newPlan) - val executedPlan = lastExecution.executedPlan - val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 - logDebug(s"Optimized batch in ${optimizerTime}ms") + val optimizerStart = System.nanoTime() - streamProgress.synchronized { - // Update the offsets and calculate a new composite offset - newOffsets.foreach(streamProgress.update) + lastExecution = new QueryExecution(sqlContext, newPlan) + val executedPlan = lastExecution.executedPlan + val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 + logDebug(s"Optimized batch in ${optimizerTime}ms") - // Construct the batch and send it to the sink. - val batchOffset = streamProgress.toCompositeOffset(sources) - val nextBatch = new Batch(batchOffset, Dataset.newDataFrame(sqlContext, newPlan)) - sink.addBatch(nextBatch) - } - - awaitBatchLock.synchronized { - // Wake up any threads that are waiting for the stream to progress. - awaitBatchLock.notifyAll() - } + val nextBatch = Dataset.newDataFrame(sqlContext, newPlan) + sink.addBatch(currentBatchId - 1, nextBatch) - val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Completed up to $newOffsets in ${batchTime}ms") - postEvent(new QueryProgress(this)) + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() } - logDebug(s"Waiting for data, current: $streamProgress") + val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 + logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") + postEvent(new QueryProgress(this)) } private def postEvent(event: ContinuousQueryListener.Event) { @@ -252,9 +325,7 @@ class StreamExecution( * least the given `Offset`. This method is indented for use primarily when writing tests. 
*/ def awaitOffset(source: Source, newOffset: Offset): Unit = { - def notDone = streamProgress.synchronized { - !streamProgress.contains(source) || streamProgress(source) < newOffset - } + def notDone = !committedOffsets.contains(source) || committedOffsets(source) < newOffset while (notDone) { logInfo(s"Waiting until $newOffset at $source") @@ -297,7 +368,7 @@ class StreamExecution( s""" |=== Continuous Query === |Name: $name - |Current Offsets: $streamProgress + |Current Offsets: $committedOffsets | |Current State: $state |Thread State: ${microBatchThread.getState} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index d45b9bd9838c1..405a5f0387a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -17,55 +17,31 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.mutable +import scala.collection.{immutable, GenTraversableOnce} /** * A helper class that looks like a Map[Source, Offset]. */ -class StreamProgress { - private val currentOffsets = new mutable.HashMap[Source, Offset] +class StreamProgress( + val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) + extends scala.collection.immutable.Map[Source, Offset] { - private[streaming] def update(source: Source, newOffset: Offset): Unit = { - currentOffsets.get(source).foreach(old => - assert(newOffset > old, s"Stream going backwards $newOffset -> $old")) - currentOffsets.put(source, newOffset) + private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + CompositeOffset(source.map(get)) } - private[streaming] def update(newOffset: (Source, Offset)): Unit = - update(newOffset._1, newOffset._2) + override def toString: String = + baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - private[streaming] def apply(source: Source): Offset = currentOffsets(source) - private[streaming] def get(source: Source): Option[Offset] = currentOffsets.get(source) - private[streaming] def contains(source: Source): Boolean = currentOffsets.contains(source) + override def +[B1 >: Offset](kv: (Source, B1)): Map[Source, B1] = baseMap + kv - private[streaming] def ++(updates: Map[Source, Offset]): StreamProgress = { - val updated = new StreamProgress - currentOffsets.foreach(updated.update) - updates.foreach(updated.update) - updated - } + override def get(key: Source): Option[Offset] = baseMap.get(key) - /** - * Used to create a new copy of this [[StreamProgress]]. While this class is currently mutable, - * it should be copied before being passed to user code. 
- */ - private[streaming] def copy(): StreamProgress = { - val copied = new StreamProgress - currentOffsets.foreach(copied.update) - copied - } + override def iterator: Iterator[(Source, Offset)] = baseMap.iterator - private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { - CompositeOffset(source.map(get)) - } + override def -(key: Source): Map[Source, Offset] = baseMap - key - override def toString: String = - currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - - override def equals(other: Any): Boolean = other match { - case s: StreamProgress => currentOffsets == s.currentOffsets - case _ => false + def ++(updates: GenTraversableOnce[(Source, Offset)]): StreamProgress = { + new StreamProgress(baseMap ++ updates) } - - override def hashCode: Int = currentOffsets.hashCode() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a6504cd088b7f..8c2bb4abd5f6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -51,8 +51,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected var currentOffset: LongOffset = new LongOffset(-1) - protected def blockManager = SparkEnv.get.blockManager - def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { @@ -78,25 +76,32 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized { - val newBlocks = - batches.drop( - start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1) - - if (newBlocks.nonEmpty) { - logDebug(s"Running [$start, $currentOffset] on blocks ${newBlocks.mkString(", ")}") - val df = newBlocks - .map(_.toDF()) - .reduceOption(_ unionAll _) - .getOrElse(sqlContext.emptyDataFrame) + override def toString: String = s"MemoryStream[${output.mkString(",")}]" - Some(new Batch(currentOffset, df)) - } else { - None - } + override def getOffset: Option[Offset] = if (batches.isEmpty) { + None + } else { + Some(currentOffset) } - override def toString: String = s"MemoryStream[${output.mkString(",")}]" + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startOrdinal = + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + val newBlocks = batches.slice(startOrdinal, endOrdinal) + + logDebug( + s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") + newBlocks + .map(_.toDF()) + .reduceOption(_ unionAll _) + .getOrElse { + sys.error("No data selected!") + } + } } /** @@ -105,45 +110,29 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) */ class MemorySink(schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ - private var batches = new ArrayBuffer[Batch]() - - /** Used to convert an [[InternalRow]] to an external [[Row]] for comparison in testing. 
*/ - private val externalRowConverter = RowEncoder(schema) - - override def currentOffset: Option[Offset] = synchronized { - batches.lastOption.map(_.end) - } - - override def addBatch(nextBatch: Batch): Unit = synchronized { - nextBatch.data.collect() // 'compute' the batch's data and record the batch - batches.append(nextBatch) - } + private val batches = new ArrayBuffer[Array[Row]]() /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches - .map(_.data) - .reduceOption(_ unionAll _) - .map(_.collect().toSeq) - .getOrElse(Seq.empty) - } - - /** - * Atomically drops the most recent `num` batches and resets the [[StreamProgress]] to the - * corresponding point in the input. This function can be used when testing to simulate data - * that has been lost due to buffering. - */ - def dropBatches(num: Int): Unit = synchronized { - batches.dropRight(num) + batches.flatten } def toDebugString: String = synchronized { - batches.map { b => - val dataStr = try b.data.collect().mkString(" ") catch { + batches.zipWithIndex.map { case (b, i) => + val dataStr = try b.mkString(" ") catch { case NonFatal(e) => "[Error converting to string]" } - s"${b.end}: $dataStr" + s"$i: $dataStr" }.mkString("\n") } + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId == batches.size) { + logDebug(s"Committing batch $batchId") + batches.append(data.collect()) + } else { + logDebug(s"Skipping already committed batch: $batchId") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 70d1a8b071dfb..fd1d77f514a95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -524,6 +524,11 @@ object SQLConf { doc = "When true, the planner will try to find out duplicated exchanges and re-use them.", isPublic = false) + val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation", + defaultValue = None, + doc = "The default location for storing checkpoint data for continuously executing queries.", + isPublic = true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -554,6 +559,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin /** ************************ Spark SQL Params/Hints ******************* */ + def checkpointLocation: String = getConf(CHECKPOINT_LOCATION) + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 81078dc6a0450..f356cde9cf87a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.util.Utils /** * A framework for implementing tests for streaming queries and sources. 
@@ -64,6 +65,12 @@ import org.apache.spark.sql.execution.streaming._ */ trait StreamTest extends QueryTest with Timeouts { + implicit class RichContinuousQuery(cq: ContinuousQuery) { + def stopQuietly(): Unit = quietly { + cq.stop() + } + } + implicit class RichSource(s: Source) { def toDF(): DataFrame = Dataset.newDataFrame(sqlContext, StreamingRelation(s)) @@ -126,8 +133,6 @@ trait StreamTest extends QueryTest with Timeouts { override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" } - case class DropBatches(num: Int) extends StreamAction - /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -202,7 +207,7 @@ trait StreamTest extends QueryTest with Timeouts { }.mkString("\n") def currentOffsets = - if (currentStream != null) currentStream.streamProgress.toString else "not started" + if (currentStream != null) currentStream.committedOffsets.toString else "not started" def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" @@ -266,6 +271,7 @@ trait StreamTest extends QueryTest with Timeouts { } val testThread = Thread.currentThread() + val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath try { startedTest.foreach { action => @@ -276,7 +282,7 @@ trait StreamTest extends QueryTest with Timeouts { currentStream = sqlContext .streams - .startQuery(StreamExecution.nextName, stream, sink) + .startQuery(StreamExecution.nextName, metadataRoot, stream, sink) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { @@ -308,10 +314,6 @@ trait StreamTest extends QueryTest with Timeouts { currentStream = null } - case DropBatches(num) => - verify(currentStream == null, "dropping batches while running leads to corruption") - sink.dropBatches(num) - case ef: ExpectFailure[_] => verify(currentStream != null, "can not expect failure when stream is not running") try failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 45e824ad6353e..54ce98d195e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { @@ -235,9 +236,14 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with @volatile var query: StreamExecution = null try { val df = ds.toDF + val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath query = sqlContext .streams - .startQuery(StreamExecution.nextName, df, new MemorySink(df.schema)) + .startQuery( + StreamExecution.nextName, + metadataRoot, + df, + new MemorySink(df.schema)) .asInstanceOf[StreamExecution] } catch { case NonFatal(e) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala index 84ed017a9d0d4..3be0ea481dc53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -54,7 +54,8 @@ class ContinuousQuerySuite extends StreamTest with SharedSQLContext { TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), AssertOnQuery( - q => q.exception.get.startOffset.get === q.streamProgress.toCompositeOffset(Seq(inputData)), + q => + q.exception.get.startOffset.get === q.committedOffsets.toCompositeOffset(Seq(inputData)), "incorrect start offset on exception") ) } @@ -68,19 +69,19 @@ class ContinuousQuerySuite extends StreamTest with SharedSQLContext { AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), AssertOnQuery(_.sourceStatuses(0).offset === None), AssertOnQuery(_.sinkStatus.description.contains("Memory")), - AssertOnQuery(_.sinkStatus.offset === None), + AssertOnQuery(_.sinkStatus.offset === new CompositeOffset(None :: Nil)), AddData(inputData, 1, 2), CheckAnswer(6, 3), AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))), - AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))), + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))), AddData(inputData, 1, 2), CheckAnswer(6, 3, 6, 3), AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))), - AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))), + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))), AddData(inputData, 0), ExpectFailure[SparkException], AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))), - AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))) + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 0878277811e12..e485aa837b7ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.streaming.test import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, ContinuousQuery, SQLContext, StreamTest} -import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils object LastOptions { var parameters: Map[String, String] = null @@ -41,8 +42,15 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.schema = schema new Source { - override def getNextBatch(start: Option[Offset]): Option[Batch] = None override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + + override def getBatch(start: Option[Offset], end: Offset): 
DataFrame = { + import sqlContext.implicits._ + + Seq[Int]().toDS().toDF() + } } } @@ -53,8 +61,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns new Sink { - override def addBatch(batch: Batch): Unit = {} - override def currentOffset: Option[Offset] = None + override def addBatch(batchId: Long, data: DataFrame): Unit = {} } } } @@ -62,8 +69,10 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ + private def newMetadataDir = Utils.createTempDir("streaming.metadata").getCanonicalPath + after { - sqlContext.streams.active.foreach(_.stop()) + sqlContext.streams.active.foreach(_.stopQuietly()) } test("resolve default source") { @@ -72,8 +81,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream() .write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .startStream() - .stop() + .stopQuietly() } test("resolve full class") { @@ -82,8 +92,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream() .write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .startStream() - .stop() + .stopQuietly() } test("options") { @@ -108,8 +119,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) + .option("checkpointLocation", newMetadataDir) .startStream() - .stop() + .stopQuietly() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -123,38 +135,43 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B df.write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .startStream() - .stop() + .stopQuietly() assert(LastOptions.partitionColumns == Nil) df.write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .partitionBy("a") .startStream() - .stop() + .stopQuietly() assert(LastOptions.partitionColumns == Seq("a")) withSQLConf("spark.sql.caseSensitive" -> "false") { df.write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .partitionBy("A") .startStream() - .stop() + .stopQuietly() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { df.write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .partitionBy("b") .startStream() - .stop() + .stopQuietly() } } test("stream paths") { val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .stream("/test") assert(LastOptions.parameters("path") == "/test") @@ -163,8 +180,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B df.write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .startStream("/test") - .stop() + .stopQuietly() assert(LastOptions.parameters("path") == "/test") } @@ -187,8 +205,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("intOpt", 56) .option("boolOpt", false) .option("doubleOpt", 6.7) + .option("checkpointLocation", newMetadataDir) .startStream("/test") - .stop() + 
.stopQuietly() assert(LastOptions.parameters("intOpt") == "56") assert(LastOptions.parameters("boolOpt") == "false") @@ -204,6 +223,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream("/test") .write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .queryName(name) .startStream() } @@ -215,6 +235,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream("/test") .write .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) .startStream() } @@ -248,9 +269,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } // Should be able to start query with that name after stopping the previous query - q1.stop() + q1.stopQuietly() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stop()) + sqlContext.streams.active.foreach(_.stopQuietly()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 4c18e38db8280..89de15acf506d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -318,16 +318,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("fault tolerance") { - def assertBatch(batch1: Option[Batch], batch2: Option[Batch]): Unit = { - (batch1, batch2) match { - case (Some(b1), Some(b2)) => - assert(b1.end === b2.end) - assert(b1.data.as[String].collect() === b2.data.as[String].collect()) - case (None, None) => - case _ => fail(s"batch ($batch1) is not equal to batch ($batch2)") - } - } - val src = Utils.createTempDir("streaming.src") val tmp = Utils.createTempDir("streaming.tmp") @@ -345,14 +335,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") ) - val textSource2 = createFileStreamSource("text", src.getCanonicalPath) - assert(textSource2.currentOffset === textSource.currentOffset) - assertBatch(textSource2.getNextBatch(None), textSource.getNextBatch(None)) - for (f <- 0L to textSource.currentOffset.offset) { - val offset = LongOffset(f) - assertBatch(textSource2.getNextBatch(Some(offset)), textSource.getNextBatch(Some(offset))) - } - Utils.deleteRecursively(src) Utils.deleteRecursively(tmp) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index 52783281abb00..d04783ecacbb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -61,7 +61,7 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with // The source and sink offsets must be None as this must be called before the // batches have started assert(status.sourceStatuses(0).offset === None) - assert(status.sinkStatus.offset === None) + assert(status.sinkStatus.offset === CompositeOffset(None :: Nil)) // No progress events or termination events assert(listener.progressStatuses.isEmpty) @@ -78,7 +78,7 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with 
assert(status != null) assert(status.active == true) assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) - assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))) + assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) // No termination events assert(listener.terminationStatus === null) @@ -92,7 +92,7 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with assert(status.active === false) // must be inactive by the time onQueryTerm is called assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) - assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))) + assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) } listener.checkAsyncErrors() } From 297c20226d3330309c9165d789749458f8f4ab8e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Mar 2016 11:37:37 -0700 Subject: [PATCH 03/26] [SPARK-14063][SQL] SQLContext.range should return Dataset[java.lang.Long] ## What changes were proposed in this pull request? This patch changed the return type for SQLContext.range from `Dataset[Long]` (Scala primitive) to `Dataset[java.lang.Long]` (Java boxed long). Previously, SPARK-13894 changed the return type of range from `Dataset[Row]` to `Dataset[Long]`. The problem is that due to https://issues.scala-lang.org/browse/SI-4388, Scala compiles primitive types in generics into just Object, i.e. range at bytecode level now just returns `Dataset[Object]`. This is really bad for Java users because they are losing type safety and also need to add a type cast every time they use range. Talked to Jason Zaugg from Lightbend (Typesafe) who suggested the best approach is to return `Dataset[java.lang.Long]`. The downside is that when Scala users want to explicitly type a closure used on the dataset returned by range, they would need to use `java.lang.Long` instead of the Scala `Long`. ## How was this patch tested? The signature change should be covered by existing unit tests and API tests. I also added a new test case in DatasetSuite for range. Author: Reynold Xin Closes #11880 from rxin/SPARK-14063. 
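To make the user-facing difference concrete, here is a small sketch (mirroring the new `DatasetSuite` test added below, not an addition to the patch) of how the boxed element type shows up in Scala code; it assumes an active `sqlContext` with `import sqlContext.implicits._` in scope.

```scala
// range() now returns Dataset[java.lang.Long] rather than Dataset[Long].
val ds: org.apache.spark.sql.Dataset[java.lang.Long] = sqlContext.range(10)

// Closures with inferred parameter types keep working as before.
assert(ds.map(_ + 1).reduce(_ + _) == 55)

// An explicitly typed closure must now name the boxed type.
assert(ds.map { case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
```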
--- .../main/scala/org/apache/spark/sql/SQLContext.scala | 10 +++++----- .../test/org/apache/spark/sql/JavaDataFrameSuite.java | 4 ++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d562f55e9f26a..efaccec262e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -725,7 +725,7 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(end: Long): Dataset[Long] = range(0, end) + def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** * :: Experimental :: @@ -736,7 +736,7 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(start: Long, end: Long): Dataset[Long] = { + def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -749,7 +749,7 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(start: Long, end: Long, step: Long): Dataset[Long] = { + def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -763,8 +763,8 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[Long] = { - new Dataset(this, Range(start, end, step, numPartitions), implicits.newLongEncoder) + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { + new Dataset(this, Range(start, end, step, numPartitions), Encoders.LONG) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index cf764c645f9ee..10ee7d57c7390 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -329,7 +329,7 @@ public void testTextLoad() { @Test public void testCountMinSketch() { - Dataset df = context.range(1000); + Dataset df = context.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -354,7 +354,7 @@ public void testCountMinSketch() { @Test public void testBloomFilter() { - Dataset df = context.range(1000); + Dataset df = context.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 04d3a25fcb4f0..677f84eb60cc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -44,6 +44,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + test("range") { + assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + 
_) == 55) + assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + } + test("SPARK-12404: Datatype Helper Serializability") { val ds = sparkContext.parallelize(( new Timestamp(0), From 7e3423b9c03c9812d404134c3d204c4cfea87721 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 22 Mar 2016 12:11:23 -0700 Subject: [PATCH 04/26] [SPARK-13951][ML][PYTHON] Nested Pipeline persistence Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations. Also: * Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader. * Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests. Author: Joseph K. Bradley Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835 --- python/pyspark/ml/classification.py | 8 +- python/pyspark/ml/clustering.py | 4 +- python/pyspark/ml/feature.py | 89 +++++++++-------- python/pyspark/ml/pipeline.py | 150 ++++++++++++++-------------- python/pyspark/ml/recommendation.py | 4 +- python/pyspark/ml/regression.py | 12 +-- python/pyspark/ml/tests.py | 82 +++++++++++---- python/pyspark/ml/util.py | 89 ++++++++++++++--- python/pyspark/ml/wrapper.py | 37 +++++-- 9 files changed, 300 insertions(+), 175 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 16ad76483de63..8075108114c18 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -38,7 +38,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ Logistic regression. Currently, this class only supports binary classification. @@ -198,7 +198,7 @@ def _checkThresholdConsistency(self): " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) -class LogisticRegressionModel(JavaModel, MLWritable, MLReadable): +class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LogisticRegression. @@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels): @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, MLWritable, MLReadable): + HasRawPredictionCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. Multinomial NB @@ -720,7 +720,7 @@ def getModelType(self): return self.getOrDefault(self.modelType) -class NaiveBayesModel(JavaModel, MLWritable, MLReadable): +class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by NaiveBayes. diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 1cea477acb47d..2db5b82c44543 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -25,7 +25,7 @@ 'KMeans', 'KMeansModel'] -class KMeansModel(JavaModel, MLWritable, MLReadable): +class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. 
@@ -48,7 +48,7 @@ def computeCost(self, dataset): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3182faac0de0f..16cb9d1db3ea7 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only, MLReadable, MLWritable +from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -58,7 +58,7 @@ @inherit_doc -class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -123,7 +123,7 @@ def getThreshold(self): @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -198,7 +198,7 @@ def getSplits(self): @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -324,7 +324,7 @@ def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, MLReadable, MLWritable): +class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -343,7 +343,7 @@ def vocabulary(self): @inherit_doc -class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -415,7 +415,8 @@ def getInverse(self): @inherit_doc -class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -481,7 +482,8 @@ def getScalingVec(self): @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable): +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -529,7 +531,7 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): @inherit_doc -class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -604,7 +606,7 @@ def _create_model(self, java_model): return IDFModel(java_model) -class IDFModel(JavaModel, MLReadable, MLWritable): +class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -676,7 +678,7 @@ def _create_model(self, java_model): return MaxAbsScalerModel(java_model) -class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): +class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -695,7 +697,7 @@ def maxAbs(self): @inherit_doc -class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -802,7 +804,7 @@ def _create_model(self, java_model): return MinMaxScalerModel(java_model) -class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): +class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -830,7 +832,7 @@ def originalMax(self): @inherit_doc @ignore_unicode_prefix -class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -908,7 +910,7 @@ def getN(self): @inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -974,7 +976,7 @@ def getP(self): @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1056,7 +1058,8 @@ def getDropLast(self): @inherit_doc -class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1123,8 +1126,8 @@ def getDegree(self): @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable, - MLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1213,7 +1216,7 @@ def _create_model(self, java_model): @inherit_doc @ignore_unicode_prefix -class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1345,7 +1348,7 @@ def getToLowercase(self): @inherit_doc -class SQLTransformer(JavaTransformer, MLReadable, MLWritable): +class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1406,7 +1409,7 @@ def getStatement(self): @inherit_doc -class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1499,7 +1502,7 @@ def _create_model(self, java_model): return StandardScalerModel(java_model) -class StandardScalerModel(JavaModel, MLReadable, MLWritable): +class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -1526,8 +1529,8 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable, - MLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1591,7 +1594,7 @@ def _create_model(self, java_model): return StringIndexerModel(java_model) -class StringIndexerModel(JavaModel, MLReadable, MLWritable): +class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1610,7 +1613,7 @@ def labels(self): @inherit_doc -class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1664,7 +1667,7 @@ def getLabels(self): return self.getOrDefault(self.labels) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1751,7 +1754,7 @@ def getCaseSensitive(self): @inherit_doc @ignore_unicode_prefix -class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1806,7 +1809,7 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1852,7 +1855,7 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1969,7 +1972,7 @@ def _create_model(self, java_model): return VectorIndexerModel(java_model) -class VectorIndexerModel(JavaModel, MLReadable, MLWritable): +class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1998,7 +2001,7 @@ def categoryMaps(self): @inherit_doc -class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2093,7 +2096,7 @@ def getNames(self): @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, - MLReadable, MLWritable): + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2226,7 +2229,7 @@ def _create_model(self, java_model): return Word2VecModel(java_model) -class Word2VecModel(JavaModel, MLReadable, MLWritable): +class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2257,7 +2260,7 @@ def findSynonyms(self, word, num): @inherit_doc -class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2331,7 +2334,7 @@ def _create_model(self, java_model): return PCAModel(java_model) -class PCAModel(JavaModel, MLReadable, MLWritable): +class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -2360,7 +2363,7 @@ def explainedVariance(self): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2463,7 +2466,7 @@ def _create_model(self, java_model): return RFormulaModel(java_model) -class RFormulaModel(JavaModel, MLReadable, MLWritable): +class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable, - MLWritable): +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -2561,7 +2564,7 @@ def _create_model(self, java_model): return ChiSqSelectorModel(java_model) -class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable): +class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a1658b0a0254b..2b5504bc2966a 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -24,72 +24,31 @@ from pyspark import since from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc -def _stages_java2py(java_stages): - """ - Transforms the parameter Python stages from a list of Java stages. - :param java_stages: An array of Java stages. - :return: An array of Python stages. - """ - - return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages] - - -def _stages_py2java(py_stages, cls): - """ - Transforms the parameter of Python stages to a Java array of Java stages. - :param py_stages: An array of Python stages. - :return: A Java array of Java Stages. - """ - - for stage in py_stages: - assert(isinstance(stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - gateway = SparkContext._gateway - java_stages = gateway.new_array(cls, len(py_stages)) - for idx, stage in enumerate(py_stages): - java_stages[idx] = stage._transfer_stage_to_java() - return java_stages - - @inherit_doc -class PipelineMLWriter(JavaMLWriter, JavaWrapper): +class PipelineMLWriter(JavaMLWriter): """ Private Pipeline utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage - self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid) - self._java_obj.setStages(_stages_py2java(instance.getStages(), cls)) - self._jwrite = self._java_obj.write() + We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java. + """ @inherit_doc class PipelineMLReader(JavaMLReader): """ Private utility class that can load Pipeline instances through their Scala implementation. 
- """ - def load(self, path): - """Load the Pipeline instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - - java_obj = self._jread.load(path) - instance = self._clazz() - instance._resetUid(java_obj.uid()) - instance.setStages(_stages_java2py(java_obj.getStages())) - - return instance + We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java. + """ @inherit_doc -class Pipeline(Estimator): +class Pipeline(Estimator, MLReadable, MLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -206,49 +165,65 @@ def save(self, path): @classmethod @since("2.0.0") def read(cls): - """Returns an JavaMLReader instance for this class.""" + """Returns an MLReader instance for this class.""" return PipelineMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java Pipeline, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Create a new instance of this stage. + py_stage = cls() + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stage.setStages(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java Pipeline. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage + java_stages = gateway.new_array(cls, len(self.getStages())) + for idx, stage in enumerate(self.getStages()): + java_stages[idx] = stage._to_java() + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj.setStages(java_stages) + + return _java_obj @inherit_doc -class PipelineModelMLWriter(JavaMLWriter, JavaWrapper): +class PipelineModelMLWriter(JavaMLWriter): """ Private PipelineModel utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.Transformer - self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel", - instance.uid, - _stages_py2java(instance.stages, cls)) - self._jwrite = self._java_obj.write() + We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements + _to_java. + """ @inherit_doc class PipelineModelMLReader(JavaMLReader): """ Private utility class that can load PipelineModel instances through their Scala implementation. - """ - def load(self, path): - """Load the PipelineModel instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - java_obj = self._jread.load(path) - instance = self._clazz(_stages_java2py(java_obj.stages())) - instance._resetUid(java_obj.uid()) - return instance + We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements + _from_java. + """ @inherit_doc -class PipelineModel(Model): +class PipelineModel(Model, MLReadable, MLWritable): """ Represents a compiled pipeline with transformers and fitted models. 
@@ -294,7 +269,32 @@ def read(cls): return PipelineModelMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java PipelineModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + # Create a new instance of this stage. + py_stage = cls(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java PipelineModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.Transformer + java_stages = gateway.new_array(cls, len(self.stages)) + for idx, stage in enumerate(self.stages): + java_stages[idx] = stage._to_java() + + _java_obj =\ + JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + + return _java_obj diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 2b605e5c5078b..de4c2675ed793 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -27,7 +27,7 @@ @inherit_doc class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -289,7 +289,7 @@ def getNonnegative(self): return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel, MLWritable, MLReadable): +class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by ALS. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6e23393f9102f..664a44bc473ac 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): + HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Linear regression. @@ -118,7 +118,7 @@ def _create_model(self, java_model): return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel, MLWritable, MLReadable): +class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LinearRegression. @@ -154,7 +154,7 @@ def intercept(self): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -249,7 +249,7 @@ def getFeatureIndex(self): return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): +class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. 
note:: Experimental @@ -719,7 +719,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): + HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -857,7 +857,7 @@ def getQuantilesCol(self): return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): +class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by AFTSurvivalRegression. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9783ce7e77bd4..211248e8b2a23 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -47,6 +47,7 @@ from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -517,7 +518,39 @@ def test_logistic_regression(self): except OSError: pass + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaWrapper): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() @@ -527,33 +560,46 @@ def test_pipeline_persistence(self): pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" pl.save(pipeline_path) loaded_pipeline = Pipeline.load(pipeline_path) - self.assertEqual(loaded_pipeline.uid, pl.uid) - self.assertEqual(len(loaded_pipeline.getStages()), 2) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass - [loaded_tf, loaded_pca] = loaded_pipeline.getStages() - self.assertIsInstance(loaded_tf, HashingTF) - self.assertEqual(loaded_tf.uid, tf.uid) - param = loaded_tf.getParam("numFeatures") - self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf 
= HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) - self.assertIsInstance(loaded_pca, PCA) - self.assertEqual(loaded_pca.uid, pca.uid) - self.assertEqual(loaded_pca.getK(), pca.getK()) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) model_path = temp_path + "/pipeline-model" model.save(model_path) loaded_model = PipelineModel.load(model_path) - [model_tf, model_pca] = model.stages - [loaded_model_tf, loaded_model_pca] = loaded_model.stages - self.assertEqual(model_tf.uid, loaded_model_tf.uid) - self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) - - self.assertEqual(model_pca.uid, loaded_model_pca.uid) - self.assertEqual(model_pca.pc, loaded_model_pca.pc) - self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + self._compare_pipelines(model, loaded_model) finally: try: rmtree(temp_path) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 42801c91bbbd3..670385126294d 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -74,18 +74,38 @@ def _randomUID(cls): @inherit_doc -class JavaMLWriter(object): +class MLWriter(object): """ .. note:: Experimental - Utility class that can save ML instances through their Scala implementation. + Utility class that can save ML instances. .. versionadded:: 2.0.0 """ + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + """ + def __init__(self, instance): - instance._transfer_params_to_java() - self._jwrite = instance._java_obj.write() + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() def save(self, path): """Save the ML instance to the input path.""" @@ -109,14 +129,14 @@ class MLWritable(object): """ .. note:: Experimental - Mixin for ML instances that provide JavaMLWriter. + Mixin for ML instances that provide :py:class:`MLWriter`. .. versionadded:: 2.0.0 """ def write(self): """Returns an JavaMLWriter instance for this ML instance.""" - return JavaMLWriter(self) + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" @@ -124,15 +144,41 @@ def save(self, path): @inherit_doc -class JavaMLReader(object): +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): """ .. note:: Experimental - Utility class that can load ML instances through their Scala implementation. 
+ Utility class that can load ML instances. .. versionadded:: 2.0.0 """ + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + """ + def __init__(self, clazz): self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -142,11 +188,10 @@ def load(self, path): if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) java_obj = self._jread.load(path) - instance = self._clazz() - instance._java_obj = java_obj - instance._resetUid(java_obj.uid()) - instance._transfer_params_from_java() - return instance + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) def context(self, sqlContext): """Sets the SQL context to use for loading.""" @@ -164,7 +209,7 @@ def _java_loader_class(cls, clazz): if clazz.__name__ in ("Pipeline", "PipelineModel"): # Remove the last package name "pipeline" for Pipeline and PipelineModel. java_package = ".".join(java_package.split(".")[0:-1]) - return ".".join([java_package, clazz.__name__]) + return java_package + "." + clazz.__name__ @classmethod def _load_java_obj(cls, clazz): @@ -181,7 +226,7 @@ class MLReadable(object): """ .. note:: Experimental - Mixin for instances that provide JavaMLReader. + Mixin for instances that provide :py:class:`MLReader`. .. versionadded:: 2.0.0 """ @@ -189,9 +234,21 @@ class MLReadable(object): @classmethod def read(cls): """Returns an JavaMLReader instance for this class.""" - return JavaMLReader(cls) + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 37dcb23b6776b..35b0eba9267b6 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -95,12 +95,26 @@ def _empty_java_param_map(): """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _transfer_stage_to_java(self): + def _to_java(self): + """ + Transfer this instance's Params to the wrapped Java object, and return the Java object. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method. + + :return: Java object equivalent to this instance. + """ self._transfer_params_to_java() return self._java_obj @staticmethod - def _transfer_stage_from_java(java_stage): + def _from_java(java_stage): + """ + Given a Java object, create and return a Python wrapper of it. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method as a classmethod. + """ def __get_class(clazz): """ Loads Python class from its name. 
@@ -113,13 +127,18 @@ def __get_class(clazz): return m stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. - py_stage = __get_class(stage_name)() - assert(isinstance(py_stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - # Load information from java_stage to the instance. - py_stage._java_obj = java_stage - py_stage._resetUid(java_stage.uid()) - py_stage._transfer_params_from_java() + py_type = __get_class(stage_name) + if issubclass(py_type, JavaWrapper): + # Load information from java_stage to the instance. + py_stage = py_type() + py_stage._java_obj = java_stage + py_stage._resetUid(java_stage.uid()) + py_stage._transfer_params_from_java() + elif hasattr(py_type, "_from_java"): + py_stage = py_type._from_java(java_stage) + else: + raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" + % stage_name) return py_stage From b2b1ad7d4cc3b3469c3d2c841b40b58ed0e34447 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Mar 2016 13:48:03 -0700 Subject: [PATCH 05/26] [SPARK-14060][SQL] Move StringToColumn implicit class into SQLImplicits ## What changes were proposed in this pull request? This patch moves StringToColumn implicit class into SQLImplicits. This was kept in SQLContext.implicits object for binary backward compatibility, in the Spark 1.x series. It makes more sense for this API to be in SQLImplicits since that's the single class that defines all the SQL implicits. ## How was this patch tested? Should be covered by existing unit tests. Author: Reynold Xin Author: Wenchen Fan Closes #11878 from rxin/SPARK-14060. --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 12 ------------ .../scala/org/apache/spark/sql/SQLImplicits.scala | 11 +++++++++++ .../org/apache/spark/sql/test/SQLTestUtils.scala | 7 ------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index efaccec262e3a..c070e867c953d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -339,18 +339,6 @@ class SQLContext private[sql]( @Experimental object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self - - /** - * Converts $"col name" into an [[Column]]. - * - * @since 1.3.0 - */ - // This must live here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index fd814e0f28e97..4aab16b866cf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -36,6 +36,17 @@ abstract class SQLImplicits { protected def _sqlContext: SQLContext + /** + * Converts $"col name" into an [[Column]]. 
+ * + * @since 2.0.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 926fabe611c59..ab3876728bea8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -66,13 +66,6 @@ private[sql] trait SQLTestUtils */ protected object testImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.sqlContext - - // This must live here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } /** From d6dc12ef0146ae409834c78737c116050961f350 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 22 Mar 2016 14:16:51 -0700 Subject: [PATCH 06/26] [SPARK-13449] Naive Bayes wrapper in SparkR ## What changes were proposed in this pull request? This PR continues the work in #11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli. I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes. I removed the preprocess part that omit NA values because we don't know which columns to process. ## How was this patch tested? Test against output from R package e1071's naiveBayes. cc: yanboliang yinxusen Closes #11486 Author: Xusen Yin Author: Xiangrui Meng Closes #11890 from mengxr/SPARK-13449. --- R/pkg/DESCRIPTION | 3 +- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 91 ++++++++++++++++++- R/pkg/inst/tests/testthat/test_mllib.R | 59 ++++++++++++ .../apache/spark/ml/r/NaiveBayesWrapper.scala | 75 +++++++++++++++ 6 files changed, 228 insertions(+), 7 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 0cd0d75df0f70..e26f9a7a2ab6c 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -11,7 +11,8 @@ Depends: R (>= 3.0), methods, Suggests: - testthat + testthat, + e1071 Description: R frontend for Spark License: Apache License (== 2.0) Collate: diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 636d39e1e9cae..5d8a4b1d6ed82 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -15,7 +15,8 @@ exportMethods("glm", "predict", "summary", "kmeans", - "fitted") + "fitted", + "naiveBayes") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6ad71fcb46712..46b115f45e53c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1175,3 +1175,7 @@ setGeneric("kmeans") #' @rdname fitted #' @export setGeneric("fitted") + +#' @rdname naiveBayes +#' @export +setGeneric("naiveBayes", function(formula, data, ...) 
{ standardGeneric("naiveBayes") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 5c0d3dcf3af90..25550193690bb 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -22,6 +22,11 @@ #' @export setClass("PipelineModel", representation(model = "jobj")) +#' @title S4 class that represents a NaiveBayesModel +#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper +#' @export +setClass("NaiveBayesModel", representation(jobj = "jobj")) + #' Fits a generalized linear model #' #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. @@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' @rdname glm #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' data(iris) @@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' @rdname predict #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' model <- glm(y ~ x, trainingData) #' predicted <- predict(model, testData) #' showDF(predicted) @@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"), return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) +#' Make predictions from a naive Bayes model +#' +#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict. +#' +#' @param object A fitted naive Bayes model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- naiveBayes(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "NaiveBayesModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) + #' Get the summary of a model #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). @@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"), #' @rdname summary #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' model <- glm(y ~ x, trainingData) #' summary(model) #'} @@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"), } }) +#' Get the summary of a naive Bayes model +#' +#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary(). +#' +#' @param object A fitted MLlib model +#' @return a list containing 'apriori', the label distribution, and 'tables', conditional +# probabilities given the target label +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- naiveBayes(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "NaiveBayesModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + labels <- callJMethod(jobj, "labels") + apriori <- callJMethod(jobj, "apriori") + apriori <- t(as.matrix(unlist(apriori))) + colnames(apriori) <- unlist(labels) + tables <- callJMethod(jobj, "tables") + tables <- matrix(tables, nrow = length(labels)) + rownames(tables) <- unlist(labels) + colnames(tables) <- unlist(features) + return(list(apriori = apriori, tables = tables)) + }) + #' Fit a k-means model #' #' Fit a k-means model, similarly to R's kmeans(). 
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"), #' @rdname kmeans #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' model <- kmeans(x, centers = 2, algorithm="random") #'} setMethod("kmeans", signature(x = "DataFrame"), @@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"), #' @rdname fitted #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' model <- kmeans(trainingData, 2) #' fitted.model <- fitted(model) #' showDF(fitted.model) @@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"), stop(paste("Unsupported model", modelName, sep = " ")) } }) + +#' Fit a Bernoulli naive Bayes model +#' +#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only +#' categorical features are supported. The input should be a DataFrame of observations instead of a +#' contingency table. +#' +#' @param object A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param data DataFrame for training +#' @param laplace Smoothing parameter +#' @return a fitted naive Bayes model +#' @rdname naiveBayes +#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, infert) +#' model <- naiveBayes(education ~ ., df, laplace = 0) +#'} +setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"), + function(formula, data, laplace = 0, ...) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", + formula, data@sdf, laplace) + return(new("NaiveBayesModel", jobj = jobj)) + }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index e120462964d1e..44b48369ef2b5 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -141,3 +141,62 @@ test_that("kmeans", { cluster <- summary.model$cluster expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) }) + +test_that("naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. So we ignore the frequencies. 
+ # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(sqlContext, t1)) + m <- naiveBayes(Survived ~ ., data = df) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error())) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala new file mode 100644 index 0000000000000..07383d393d637 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.sql.DataFrame + +private[r] class NaiveBayesWrapper private ( + pipeline: PipelineModel, + val labels: Array[String], + val features: Array[String]) { + + import NaiveBayesWrapper._ + + private val naiveBayesModel: NaiveBayesModel = pipeline.stages(1).asInstanceOf[NaiveBayesModel] + + lazy val apriori: Array[Double] = naiveBayesModel.pi.toArray.map(math.exp) + + lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) + + def transform(dataset: DataFrame): DataFrame = { + pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) + } +} + +private[r] object NaiveBayesWrapper { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + .fit(data) + // get labels and feature names from output schema + val schema = rFormula.transform(data).schema + val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline + val naiveBayes = new NaiveBayes() + .setSmoothing(laplace) + .setModelType("bernoulli") + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() + .setStages(Array(rFormula, naiveBayes, idxToStr)) + .fit(data) + new NaiveBayesWrapper(pipeline, labels, features) + } +} From d16710b4c986f0eaf28552ce0e2db33d8c9343b8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Mar 2016 16:41:55 -0700 Subject: [PATCH 07/26] [HOTFIX][SQL] Add a timeout for 'cq.stop' ## What changes were proposed in this pull request? Fix an issue that DataFrameReaderWriterSuite may hang forever. ## How was this patch tested? Existing tests. Author: Shixiong Zhu Closes #11902 from zsxwing/hotfix. 
--- .../test/scala/org/apache/spark/sql/StreamTest.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index f356cde9cf87a..26c597bf349b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import org.scalatest.Assertions import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -67,7 +68,14 @@ trait StreamTest extends QueryTest with Timeouts { implicit class RichContinuousQuery(cq: ContinuousQuery) { def stopQuietly(): Unit = quietly { - cq.stop() + try { + failAfter(10.seconds) { + cq.stop() + } + } catch { + case e: TestFailedDueToTimeoutException => + logError(e.getMessage(), e) + } } } From 4700adb98e4a37c2b0ef7123eca8a9a03bbdbe78 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Mar 2016 16:45:20 -0700 Subject: [PATCH 08/26] [SPARK-13806] [SQL] fix rounding mode of negative float/double ## What changes were proposed in this pull request? Round() in database usually round the number up (away from zero), it's different than Math.round() in Java. For example: ``` scala> java.lang.Math.round(-3.5) res3: Long = -3 ``` In Database, we should return -4.0 in this cases. This PR remove the buggy special case for scale=0. ## How was this patch tested? Add tests for negative values with tie. Author: Davies Liu Closes #11894 from davies/fix_round. --- .../expressions/mathExpressions.scala | 48 ++++++------------- .../expressions/MathFunctionsSuite.scala | 4 ++ 2 files changed, 19 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 12fcc40376e10..e3d1bc127d2e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -748,7 +748,7 @@ case class Round(child: Expression, scale: Expression) if (f.isNaN || f.isInfinite) { f } else { - BigDecimal(f).setScale(_scale, HALF_UP).toFloat + BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat } case DoubleType => val d = input1.asInstanceOf[Double] @@ -804,39 +804,21 @@ case class Round(child: Expression, scale: Expression) s"${ev.value} = ${ce.value};" } case FloatType => // if child eval to NaN or Infinity, just return it. - if (_scale == 0) { - s""" - if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = Math.round(${ce.value}); - }""" - } else { - s""" - if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); - }""" - } + s""" + if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) { + ${ev.value} = ${ce.value}; + } else { + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" case DoubleType => // if child eval to NaN or Infinity, just return it. - if (_scale == 0) { - s""" - if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = Math.round(${ce.value}); - }""" - } else { - s""" - if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); - }""" - } + s""" + if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) { + ${ev.value} = ${ce.value}; + } else { + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" } if (scaleV == null) { // if scale is null, no need to eval its child at all diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index bd674dadd0fcc..27195d3458b8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -553,5 +553,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(Literal.create(null, dataType), Literal.create(null, IntegerType)), null) } + + checkEvaluation(Round(-3.5, 0), -4.0) + checkEvaluation(Round(-0.35, 1), -0.4) + checkEvaluation(Round(-35, -1), -40) } } From 0d51b60443ae78fa46988a6aed2397db9c35f96d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 22 Mar 2016 21:01:52 -0700 Subject: [PATCH 09/26] [SPARK-14072][CORE] Show JVM/OS version information when we run a benchmark program ## What changes were proposed in this pull request? This PR allows us to identify what JVM is used when someone ran a benchmark program. In some cases, a JVM version may affect performance result. Thus, it would be good to show processor information and JVM version information. ``` model name : Intel(R) Xeon(R) CPU E5-2697 v2 2.70GHz JVM information : OpenJDK 64-Bit Server VM, 1.7.0_65-mockbuild_2014_07_14_06_19-b00 Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 981 / 994 10.7 93.5 1.0X SQL Parquet MR 2518 / 2542 4.2 240.1 0.4X ``` ``` model name : Intel(R) Xeon(R) CPU E5-2697 v2 2.70GHz JVM information : IBM J9 VM, pxa6480sr2-20151023_01 (SR2) String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 693 / 740 15.1 66.1 1.0X SQL Parquet MR 2501 / 2562 4.2 238.5 0.3X ``` ## How was this patch tested? Tested by using existing benchmark programs (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: Kazuaki Ishizaki Closes #11893 from kiszk/SPARK-14072. 
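To make the tie-breaking difference fixed by SPARK-13806 above concrete, a REPL-style check (object name is illustrative; only `math.round` and `BigDecimal.setScale` are exercised):

```
import scala.math.BigDecimal.RoundingMode

object RoundingTieCheck {
  def main(args: Array[String]): Unit = {
    // Math.round pushes the half-way case toward positive infinity ...
    println(math.round(-3.5))                                      // -3
    // ... while HALF_UP rounds it away from zero, which is what SQL ROUND() expects.
    println(BigDecimal("-3.5").setScale(0, RoundingMode.HALF_UP))  // -4
    println(BigDecimal("-0.35").setScale(1, RoundingMode.HALF_UP)) // -0.4
    // Positive ties agree either way.
    println(math.round(3.5))                                       // 4
  }
}
```

Removing the `scale == 0` special case therefore only changes results for negative half-way values, which is exactly what the new test cases cover.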
--- .../org/apache/spark/util/Benchmark.scala | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index b562b58f1b6bf..9e40bafd521d7 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -64,6 +64,7 @@ private[spark] class Benchmark( val firstBest = results.head.bestMs // The results are going to be processor specific so it is useful to include that. + println(Benchmark.getJVMOSInfo()) println(Benchmark.getProcessorName()) printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", "Per Row(ns)", "Relative") @@ -91,16 +92,31 @@ private[spark] object Benchmark { * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz" */ def getProcessorName(): String = { - if (SystemUtils.IS_OS_MAC_OSX) { + val cpu = if (SystemUtils.IS_OS_MAC_OSX) { Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) } else if (SystemUtils.IS_OS_LINUX) { Try { val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")) Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo")) + .replaceFirst("model name[\\s*]:[\\s*]", "") }.getOrElse("Unknown processor") } else { System.getenv("PROCESSOR_IDENTIFIER") } + cpu + } + + /** + * This should return a user helpful JVM & OS information. + * This should return something like + * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64" + */ + def getJVMOSInfo(): String = { + val vmName = System.getProperty("java.vm.name") + val runtimeVersion = System.getProperty("java.runtime.version") + val osName = System.getProperty("os.name") + val osVersion = System.getProperty("os.version") + s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}" } /** From 75dc29620e8bf22aa56a55c0f2bc1b85800e84b1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 22 Mar 2016 21:08:11 -0700 Subject: [PATCH 10/26] [SPARK-13401][SQL][TESTS] Fix SQL test warnings. ## What changes were proposed in this pull request? This fix tries to fix several SQL test warnings under the sql/core/src/test directory. The fixed warnings includes "[unchecked]", "[rawtypes]", and "[varargs]". ## How was this patch tested? All existing tests passed. Author: Yong Tang Closes #11857 from yongtang/SPARK-13401. 
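The `getJVMOSInfo` helper added to `Benchmark` in SPARK-14072 above boils down to four standard JVM system properties; a standalone sketch (object name is illustrative):

```
object JvmOsInfoExample {
  // Same lookup as the Benchmark helper: concatenate JVM and OS details
  // into a single header line for benchmark output.
  def describe(): String = {
    val vmName = System.getProperty("java.vm.name")
    val runtimeVersion = System.getProperty("java.runtime.version")
    val osName = System.getProperty("os.name")
    val osVersion = System.getProperty("os.version")
    s"$vmName $runtimeVersion on $osName $osVersion"
  }

  def main(args: Array[String]): Unit = {
    // Prints e.g. "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64"
    println(describe())
  }
}
```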
--- .../datasources/parquet/test/avro/AvroArrayOfArray.java | 1 + .../execution/datasources/parquet/test/avro/AvroMapOfArray.java | 1 + .../datasources/parquet/test/avro/AvroNonNullableArrays.java | 1 + .../sql/execution/datasources/parquet/test/avro/Nested.java | 1 + .../datasources/parquet/test/avro/ParquetAvroCompat.java | 1 + .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 1 + 6 files changed, 6 insertions(+) diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java index ee327827903e5..8de0b06b162c6 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -129,6 +129,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfA } @Override + @SuppressWarnings(value="unchecked") public AvroArrayOfArray build() { try { AvroArrayOfArray record = new AvroArrayOfArray(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java index 727f6a7bf733e..29f3109f83a15 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -129,6 +129,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArr } @Override + @SuppressWarnings(value="unchecked") public AvroMapOfArray build() { try { AvroMapOfArray record = new AvroMapOfArray(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java index 934793f42f9c9..c5522ed1e53e5 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -182,6 +182,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNulla } @Override + @SuppressWarnings(value="unchecked") public AvroNonNullableArrays build() { try { AvroNonNullableArrays record = new AvroNonNullableArrays(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index a7bf4841919c5..f84e3f2d61efb 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -182,6 +182,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Build } @Override + @SuppressWarnings(value="unchecked") public Nested build() { try { Nested record = new Nested(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java 
b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index ef12d193f916c..46fc608398ccf 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -235,6 +235,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroC } @Override + @SuppressWarnings(value="unchecked") public ParquetAvroCompat build() { try { ParquetAvroCompat record = new ParquetAvroCompat(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 4b8b0d9d4f8aa..3bff129ae2294 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -309,6 +309,7 @@ private static Set toSet(List records) { } @SafeVarargs + @SuppressWarnings("varargs") private static Set asSet(T... records) { return toSet(Arrays.asList(records)); } From 1a22cf1e9b6447005c9a329856d734d80a496a06 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 22 Mar 2016 23:07:49 -0700 Subject: [PATCH 11/26] [MINOR][SQL][DOCS] Update `sql/README.md` and remove some unused imports in `sql` module. ## What changes were proposed in this pull request? This PR updates `sql/README.md` according to the latest console output and removes some unused imports in `sql` module. This is done by manually, so there is no guarantee to remove all unused imports. ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #11907 from dongjoon-hyun/update_sql_module. --- sql/README.md | 9 +++++---- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 4 ---- .../main/scala/org/apache/spark/sql/SQLImplicits.scala | 4 ---- .../scala/org/apache/spark/sql/execution/SparkQl.scala | 1 - .../apache/spark/sql/execution/WholeStageCodegen.scala | 1 - .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 1 - .../src/test/scala/org/apache/spark/sql/StreamTest.scala | 1 - .../org/apache/spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/execution/WholeStageCodegenSuite.scala | 1 - .../execution/datasources/FileSourceStrategySuite.scala | 4 ++-- .../datasources/parquet/ParquetReadBenchmark.scala | 1 - .../spark/sql/execution/joins/HashedRelationSuite.scala | 1 - 13 files changed, 9 insertions(+), 23 deletions(-) diff --git a/sql/README.md b/sql/README.md index 9ea271d33d856..b0903980a59f3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -47,7 +47,7 @@ An interactive scala console can be invoked by running `build/sbt hive/console`. From here you can execute queries with HiveQl and manipulate DataFrame by using DSL. ```scala -catalyst$ build/sbt hive/console +$ build/sbt hive/console [info] Starting scala interpreter... import org.apache.spark.sql.catalyst.analysis._ @@ -61,22 +61,23 @@ import org.apache.spark.sql.execution import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.types._ Type in expressions to have them evaluated. Type :help for more information. 
scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.DataFrame = org.apache.spark.sql.DataFrame@74448eed +query: org.apache.spark.sql.DataFrame = [key: int, value: string] ``` Query results are `DataFrames` and can be operated as such. ``` scala> query.collect() -res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... +res0: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these `DataFrames` using the query DSL. ``` scala> query.where(query("key") > 30).select(avg(query("key"))).collect() -res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) +res1: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 01c1fa40dcfbd..ecf4285c46a51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c070e867c953d..542f2f4debbf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,13 +32,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} -import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.execution.datasources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 4aab16b866cf2..c35a969bf031a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,11 +21,7 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into 
[[DataFrame]]s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index 11391bd12acae..ef30ba0cdbf55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.command._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index e3c7d7209af18..5634e5fc5861b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,7 +22,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toCommentSafeString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 855295d5f2db2..a1b45ca7ebd19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{LogicalRDD, Queryable} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.SQLConf abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 26c597bf349b3..62dc492d60456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import org.scalatest.Assertions import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index e9b65539b0d62..bdbcf842ca47d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import 
org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 716c367eae551..6d5be0b5dda12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 2f8129c5da40d..4abc6d6a55ecd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import java.io.{File, FilenameFilter} +import java.io.File import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.Job @@ -28,7 +28,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.execution.{DataSourceScan, PhysicalRDD} +import org.apache.spark.sql.execution.DataSourceScan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cc0cc65d3eb59..cef541f0444be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.SQLContext import org.apache.spark.util.{Benchmark, Utils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index dd20855a81d9a..e19b4ff1e2ff8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,7 +22,6 @@ import 
java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer From 926a93e54b83f1ee596096f3301fef015705b627 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Mar 2016 23:43:09 -0700 Subject: [PATCH 12/26] [SPARK-14088][SQL] Some Dataset API touch-up ## What changes were proposed in this pull request? 1. Deprecated unionAll. It is pretty confusing to have both "union" and "unionAll" when the two do the same thing in Spark but are different in SQL. 2. Rename reduce in KeyValueGroupedDataset to reduceGroups so it is more consistent with rest of the functions in KeyValueGroupedDataset. Also makes it more obvious what "reduce" and "reduceGroups" mean. Previously it was confusing because it could be reducing a Dataset, or just reducing groups. 3. Added a "name" function, which is more natural to name columns than "as" for non-SQL users. 4. Remove "subtract" function since it is just an alias for "except". ## How was this patch tested? All changes should be covered by existing tests. Also added couple test cases to cover "name". Author: Reynold Xin Closes #11908 from rxin/SPARK-14088. --- project/MimaExcludes.scala | 1 + python/pyspark/sql/column.py | 2 ++ python/pyspark/sql/dataframe.py | 14 +++++++-- .../scala/org/apache/spark/sql/Column.scala | 29 +++++++++++++----- .../scala/org/apache/spark/sql/Dataset.scala | 30 +++++++------------ .../spark/sql/KeyValueGroupedDataset.scala | 11 ++----- .../apache/spark/sql/JavaDatasetSuite.java | 4 +-- .../spark/sql/ColumnExpressionSuite.scala | 3 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- 9 files changed, 56 insertions(+), 40 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 68e9c50d60f6a..42eafcb0f52d0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -317,6 +317,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 19ec6fcc5d6dc..43e9baece2de9 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,8 @@ def alias(self, *alias): sc = SparkContext._active_spark_context return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") + @ignore_unicode_prefix @since(1.3) def cast(self, dataType): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7e1854c43be3b..5cfc348a69caf 100644 --- a/python/pyspark/sql/dataframe.py +++ 
b/python/pyspark/sql/dataframe.py @@ -911,14 +911,24 @@ def agg(self, *exprs): """ return self.groupBy().agg(*exprs) + @since(2.0) + def union(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by a distinct. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + @since(1.3) def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. - This is equivalent to `UNION ALL` in SQL. + .. note:: Deprecated in 2.0, use union instead. """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + return self.union(other) @since(1.3) def intersect(self, other): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 622a62abad896..d64736e11110b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -856,7 +856,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def alias(alias: String): Column = as(alias) + def alias(alias: String): Column = name(alias) /** * Gives the column an alias. @@ -871,12 +871,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = withExpr { - expr match { - case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias)() - } - } + def as(alias: String): Column = name(alias) /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -936,6 +931,26 @@ class Column(protected[sql] val expr: Expression) extends Logging { Alias(expr, alias)(explicitMetadata = Some(metadata)) } + /** + * Gives the column a name (alias). + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".name("colB")) + * }}} + * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * + * @group expr_ops + * @since 2.0.0 + */ + def name(alias: String): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } + } + /** * Casts the column to a different data type. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index be0dfe7c3344a..31864d63ab595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1350,20 +1350,24 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) - } + @deprecated("use union()", "2.0.0") + def unionAll(other: Dataset[T]): Dataset[T] = union(other) /** * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. * This is equivalent to `UNION ALL` in SQL. 
* + * To do a SQL-style set union (that does deduplication of elements), use this function followed + * by a [[distinct]]. + * * @group typedrel * @since 2.0.0 */ - def union(other: Dataset[T]): Dataset[T] = unionAll(other) + def union(other: Dataset[T]): Dataset[T] = withTypedPlan { + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, other.logicalPlan)) + } /** * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset. @@ -1393,18 +1397,6 @@ class Dataset[T] private[sql]( Except(logicalPlan, other.logicalPlan) } - /** - * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * - * @group typedrel - * @since 2.0.0 - */ - def subtract(other: Dataset[T]): Dataset[T] = except(other) - /** * Returns a new [[Dataset]] by sampling a fraction of rows. * @@ -1756,7 +1748,7 @@ class Dataset[T] private[sql]( outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } - val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f0f96825e2683..8bb75bf2bf0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -190,7 +190,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def reduce(f: (V, V) => V): Dataset[(K, V)] = { + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) @@ -203,15 +203,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduce(f.call _) + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduceGroups(f.call _) } - // This is here to prevent us from adding overloads that would be ambiguous. 
- @scala.annotation.varargs - private def agg(exprs: Column*): DataFrame = - groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) - private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3bff129ae2294..18f17a85a9dbd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -204,7 +204,7 @@ public Iterator call(Integer key, Iterator values) { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); - Dataset> reduced = grouped.reduce(new ReduceFunction() { + Dataset> reduced = grouped.reduceGroups(new ReduceFunction() { @Override public String call(String v1, String v2) throws Exception { return v1 + v2; @@ -300,7 +300,7 @@ public void testSetOperation() { Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"), unioned.collectAsList()); - Dataset subtracted = ds.subtract(ds2); + Dataset subtracted = ds.except(ds2); Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index c2434e46f7ecd..351b03b38bad1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -105,10 +105,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row("a") :: Nil) } - test("alias") { + test("alias and name") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") assert(df.select(df("a").alias("b")).columns.head === "b") + assert(df.select(df("a").name("b")).columns.head === "b") } test("as propagates metadata") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 677f84eb60cc3..0bcc512d7137d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -305,7 +305,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() - val agged = ds.groupByKey(_.length).reduce(_ + _) + val agged = ds.groupByKey(_.length).reduceGroups(_ + _) checkDataset( agged, From abacf5f258e9bc5c9218ddbee3909dfe5c08d0ea Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 23 Mar 2016 00:00:35 -0700 Subject: [PATCH 13/26] [HOTFIX][SQL] Don't stop ContinuousQuery in quietly ## What changes were proposed in this pull request? Try to fix a flaky hang ## How was this patch tested? Existing Jenkins test Author: Shixiong Zhu Closes #11909 from zsxwing/hotfix2. 
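A usage-level sketch of the SPARK-14088 API changes above (the sample data and object name are illustrative; the method names are the ones the patch introduces or keeps):

```
import org.apache.spark.sql.SQLContext

object DatasetApiTouchUpExample {
  def demo(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    val ds  = Seq("abc", "xyz", "hello").toDS()
    val ds2 = Seq("xyz", "foo").toDS()

    // unionAll is deprecated in favor of union; both keep UNION ALL
    // semantics, so follow with distinct() for SQL-style UNION.
    val unioned = ds.union(ds2)
    val deduped = ds.union(ds2).distinct()

    // reduce on a grouped Dataset is now reduceGroups.
    val byLength = ds.groupByKey(_.length).reduceGroups(_ + _)

    // subtract is removed; except is the remaining spelling.
    val onlyInDs = ds.except(ds2)

    // name is a new alias for as/alias on Column.
    val df = Seq((1, "a")).toDF("key", "value")
    assert(df.select(df("key").name("k")).columns.head == "k")

    println((unioned.count(), deduped.count(), byLength.count(), onlyInDs.count()))
  }
}
```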
--- .../org/apache/spark/sql/StreamTest.scala | 13 ---------- .../DataFrameReaderWriterSuite.scala | 24 +++++++++---------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 62dc492d60456..2dd6416853a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -65,19 +65,6 @@ import org.apache.spark.util.Utils */ trait StreamTest extends QueryTest with Timeouts { - implicit class RichContinuousQuery(cq: ContinuousQuery) { - def stopQuietly(): Unit = quietly { - try { - failAfter(10.seconds) { - cq.stop() - } - } catch { - case e: TestFailedDueToTimeoutException => - logError(e.getMessage(), e) - } - } - } - implicit class RichSource(s: Source) { def toDF(): DataFrame = Dataset.newDataFrame(sqlContext, StreamingRelation(s)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index e485aa837b7ee..c1bab9b577bbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -72,7 +72,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B private def newMetadataDir = Utils.createTempDir("streaming.metadata").getCanonicalPath after { - sqlContext.streams.active.foreach(_.stopQuietly()) + sqlContext.streams.active.foreach(_.stop()) } test("resolve default source") { @@ -83,7 +83,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .startStream() - .stopQuietly() + .stop() } test("resolve full class") { @@ -94,7 +94,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .startStream() - .stopQuietly() + .stop() } test("options") { @@ -121,7 +121,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .options(map) .option("checkpointLocation", newMetadataDir) .startStream() - .stopQuietly() + .stop() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -137,7 +137,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .startStream() - .stopQuietly() + .stop() assert(LastOptions.partitionColumns == Nil) df.write @@ -145,7 +145,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("checkpointLocation", newMetadataDir) .partitionBy("a") .startStream() - .stopQuietly() + .stop() assert(LastOptions.partitionColumns == Seq("a")) withSQLConf("spark.sql.caseSensitive" -> "false") { @@ -154,7 +154,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("checkpointLocation", newMetadataDir) .partitionBy("A") .startStream() - .stopQuietly() + .stop() assert(LastOptions.partitionColumns == Seq("a")) } @@ -164,7 +164,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("checkpointLocation", newMetadataDir) .partitionBy("b") 
.startStream() - .stopQuietly() + .stop() } } @@ -182,7 +182,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .startStream("/test") - .stopQuietly() + .stop() assert(LastOptions.parameters("path") == "/test") } @@ -207,7 +207,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .option("doubleOpt", 6.7) .option("checkpointLocation", newMetadataDir) .startStream("/test") - .stopQuietly() + .stop() assert(LastOptions.parameters("intOpt") == "56") assert(LastOptions.parameters("boolOpt") == "false") @@ -269,9 +269,9 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } // Should be able to start query with that name after stopping the previous query - q1.stopQuietly() + q1.stop() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stopQuietly()) + sqlContext.streams.active.foreach(_.stop()) } } From 4d955cd69452e34f74369e62bc741a5c749905a8 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Mar 2016 10:51:58 +0000 Subject: [PATCH 14/26] [SPARK-14035][MLLIB] Make error message more verbose for mllib NaiveBayesSuite ## What changes were proposed in this pull request? Print more info about failed NaiveBayesSuite tests which have exhibited flakiness. ## How was this patch tested? Ran locally with incorrect check to cause failure. Author: Joseph K. Bradley Closes #11858 from jkbradley/naive-bayes-bug-log. --- .../classification/NaiveBayesSuite.scala | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index cffa1ab700f80..ab54cb06d5aab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -103,17 +104,24 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { piData: Array[Double], thetaData: Array[Array[Double]], model: NaiveBayesModel): Unit = { - def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { - (d1 - d2).abs <= precision - } - val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) - for (i <- modelIndex) { - assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) - } - for (i <- modelIndex) { - for (j <- 0 until thetaData(i._2).length) { - assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + val modelIndex = piData.indices.zip(model.labels.map(_.toInt)) + try { + for (i <- modelIndex) { + assert(math.exp(piData(i._2)) ~== math.exp(model.pi(i._1)) absTol 0.05) + for (j <- thetaData(i._2).indices) { + assert(math.exp(thetaData(i._2)(j)) ~== math.exp(model.theta(i._1)(j)) absTol 0.05) + } } + } catch { + case e: TestFailedException => + def arr2str(a: Array[Double]): String = a.mkString("[", ", ", "]") + def msg(orig: String): String = orig + 
"\nvalidateModelFit:\n" + + " piData: " + arr2str(piData) + "\n" + + " thetaData: " + thetaData.map(arr2str).mkString("\n") + "\n" + + " model.labels: " + arr2str(model.labels) + "\n" + + " model.pi: " + arr2str(model.pi) + "\n" + + " model.theta: " + model.theta.map(arr2str).mkString("\n") + throw e.modifyMessage(_.map(msg)) } } From 7d1175011c976756efcd4e4e4f70a8fd6f287026 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 23 Mar 2016 07:57:03 -0700 Subject: [PATCH 15/26] [SPARK-14074][SPARKR] Specify commit sha1 ID when using install_github to install intr package. ## What changes were proposed in this pull request? In dev/lint-r.R, `install_github` makes our builds depend on a unstable source. This may cause un-expected test failures and then build break. This PR adds a specified commit sha1 ID to `install_github` to get a stable source. ## How was this patch tested? dev/lint-r Author: Sun Rui Closes #11913 from sun-rui/SPARK-14074. --- dev/lint-r.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/lint-r.R b/dev/lint-r.R index 999eef571b824..87ee36d5c9b68 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -27,7 +27,7 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { # Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { - devtools::install_github("jimhester/lintr") + devtools::install_github("jimhester/lintr@a769c0b") } library(lintr) From cde086cb2a9a85406fc18d8e63e46425f614c15f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 24 Mar 2016 00:42:13 +0800 Subject: [PATCH 16/26] [SPARK-13817][SQL][MINOR] Renames Dataset.newDataFrame to Dataset.ofRows ## What changes were proposed in this pull request? This PR does the renaming as suggested by marmbrus in [this comment][1]. ## How was this patch tested? Existing tests. [1]: https://github.com/apache/spark/commit/6d37e1eb90054cdb6323b75fb202f78ece604b15#commitcomment-16654694 Author: Cheng Lian Closes #11889 from liancheng/spark-13817-follow-up. 
--- .../apache/spark/sql/DataFrameReader.scala | 8 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 8 +++--- .../org/apache/spark/sql/SQLContext.scala | 26 +++++++++---------- .../sql/execution/command/commands.scala | 2 +- .../execution/datasources/DataSource.scala | 2 +- .../datasources/InsertIntoDataSource.scala | 2 +- .../InsertIntoHadoopFsRelation.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 8 +++--- .../sql/execution/stat/FrequentItems.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- .../execution/streaming/StreamExecution.scala | 2 +- .../sql/execution/streaming/memory.scala | 2 +- .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/StreamTest.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../spark/sql/hive/execution/commands.scala | 2 +- .../spark/sql/hive/SQLBuilderTest.scala | 2 +- .../execution/AggregationQuerySuite.scala | 2 +- 22 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1d4693f54ff93..704535adaa60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -129,7 +129,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) - Dataset.newDataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation())) + Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())) } /** @@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) - Dataset.newDataFrame(sqlContext, StreamingRelation(dataSource.createSource())) + Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource())) } /** @@ -376,7 +376,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { parsedOptions) } - Dataset.newDataFrame( + Dataset.ofRows( sqlContext, LogicalRDD( schema.toAttributes, @@ -424,7 +424,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - Dataset.newDataFrame(sqlContext, + Dataset.ofRows(sqlContext, sqlContext.sessionState.catalog.lookupRelation( sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 31864d63ab595..ec0b3c78ed72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -53,7 +53,7 @@ private[sql] object Dataset { new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]]) } - def newDataFrame(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { + def ofRows(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { val qe = sqlContext.executePlan(logicalPlan) qe.assertAnalyzed() new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema)) @@ -2322,7 +2322,7 @@ class Dataset[T] private[sql]( /** A 
convenient function to wrap a logical plan and produce a DataFrame. */ @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - Dataset.newDataFrame(sqlContext, logicalPlan) + Dataset.ofRows(sqlContext, logicalPlan) } /** A convenient function to wrap a logical plan and produce a Dataset. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 8bb75bf2bf0c1..07aa1515f3841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -59,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private def groupedData = { new RelationalGroupedDataset( - Dataset.newDataFrame(sqlContext, logicalPlan), + Dataset.ofRows(sqlContext, logicalPlan), groupingAttributes, RelationalGroupedDataset.GroupByType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 521032a8b3a83..91c02053ae1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -52,17 +52,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.newDataFrame( + Dataset.ofRows( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => - Dataset.newDataFrame( + Dataset.ofRows( df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => - Dataset.newDataFrame( + Dataset.ofRows( df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) - Dataset.newDataFrame( + Dataset.ofRows( df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 542f2f4debbf9..853a74c827d47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -351,7 +351,7 @@ class SQLContext private[sql]( val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) - Dataset.newDataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self)) } /** @@ -366,7 +366,7 @@ class SQLContext private[sql]( SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - Dataset.newDataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) + Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } /** @@ -376,7 +376,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { - Dataset.newDataFrame(this, LogicalRelation(baseRelation)) + Dataset.ofRows(this, LogicalRelation(baseRelation)) } /** @@ -431,7 +431,7 @@ class SQLContext private[sql]( rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } val logicalPlan = 
LogicalRDD(schema.toAttributes, catalystRows)(self) - Dataset.newDataFrame(this, logicalPlan) + Dataset.ofRows(this, logicalPlan) } @@ -466,7 +466,7 @@ class SQLContext private[sql]( // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) - Dataset.newDataFrame(this, logicalPlan) + Dataset.ofRows(this, logicalPlan) } /** @@ -494,7 +494,7 @@ class SQLContext private[sql]( */ @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { - Dataset.newDataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } /** @@ -513,7 +513,7 @@ class SQLContext private[sql]( val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) } - Dataset.newDataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) + Dataset.ofRows(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -541,7 +541,7 @@ class SQLContext private[sql]( val className = beanClass.getName val beanInfo = Introspector.getBeanInfo(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - Dataset.newDataFrame(self, LocalRelation(attrSeq, rows.toSeq)) + Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } /** @@ -759,7 +759,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def sql(sqlText: String): DataFrame = { - Dataset.newDataFrame(this, parseSql(sqlText)) + Dataset.ofRows(this, parseSql(sqlText)) } /** @@ -782,7 +782,7 @@ class SQLContext private[sql]( } private def table(tableIdent: TableIdentifier): DataFrame = { - Dataset.newDataFrame(this, sessionState.catalog.lookupRelation(tableIdent)) + Dataset.ofRows(this, sessionState.catalog.lookupRelation(tableIdent)) } /** @@ -794,7 +794,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(): DataFrame = { - Dataset.newDataFrame(this, ShowTablesCommand(None)) + Dataset.ofRows(this, ShowTablesCommand(None)) } /** @@ -806,7 +806,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(databaseName: String): DataFrame = { - Dataset.newDataFrame(this, ShowTablesCommand(Some(databaseName))) + Dataset.ofRows(this, ShowTablesCommand(Some(databaseName))) } /** @@ -871,7 +871,7 @@ class SQLContext private[sql]( schema: StructType): DataFrame = { val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) - Dataset.newDataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + Dataset.ofRows(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index cd769d013786a..59c3ffcf488c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -261,7 +261,7 @@ case class CacheTableCommand( override def run(sqlContext: SQLContext): Seq[Row] = { plan.foreach { logicalPlan => - sqlContext.registerDataFrameAsTable(Dataset.newDataFrame(sqlContext, logicalPlan), tableName) + sqlContext.registerDataFrameAsTable(Dataset.ofRows(sqlContext, logicalPlan), tableName) } sqlContext.cacheTable(tableName) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fac2a64726618..548da86359c26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -154,7 +154,7 @@ case class DataSource( } def dataFrameBuilder(files: Array[String]): DataFrame = { - Dataset.newDataFrame( + Dataset.ofRows( sqlContext, LogicalRelation( DataSource( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 9cf794804d043..37c2c4517ccf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -34,7 +34,7 @@ private[sql] case class InsertIntoDataSource( override def run(sqlContext: SQLContext): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = Dataset.newDataFrame(sqlContext, query) + val data = Dataset.ofRows(sqlContext, query) // Apply the schema of the existing table to the new data. val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 51ec969daf68f..a30b52080fe19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -114,7 +114,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionSet = AttributeSet(partitionColumns) val dataColumns = query.output.filterNot(partitionSet.contains) - val queryExecution = Dataset.newDataFrame(sqlContext, query).queryExecution + val queryExecution = Dataset.ofRows(sqlContext, query).queryExecution SQLExecution.withNewExecutionId(sqlContext, queryExecution) { val relation = WriteRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 7ca0e8859a03e..9e8e0352db644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -101,7 +101,7 @@ case class CreateTempTableUsing( options = options) sqlContext.sessionState.catalog.registerTable( tableIdent, - Dataset.newDataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan) + Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan) Seq.empty[Row] } @@ -116,7 +116,7 @@ case class CreateTempTableUsingAsSelect( query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val df = Dataset.newDataFrame(sqlContext, query) + val df = Dataset.ofRows(sqlContext, query) val dataSource = DataSource( sqlContext, className = provider, @@ -126,7 +126,7 @@ case class CreateTempTableUsingAsSelect( val result = dataSource.write(mode, df) 
sqlContext.sessionState.catalog.registerTable( tableIdent, - Dataset.newDataFrame(sqlContext, LogicalRelation(result)).logicalPlan) + Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan) Seq.empty[Row] } @@ -147,7 +147,7 @@ case class RefreshTable(tableIdent: TableIdentifier) if (isCached) { // Create a data frame to represent the table. // TODO: Use uncacheTable once it supports database name. - val df = Dataset.newDataFrame(sqlContext, logicalPlan) + val df = Dataset.ofRows(sqlContext, logicalPlan) // Uncache the logicalPlan. sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index bccd2a44d9fe9..8c2231335c789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging { StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - Dataset.newDataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) + Dataset.ofRows(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 0a0dccbad1cb1..e0b6709c51d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging { } val schema = StructType(StructField(tableName, StringType) +: headerNames) - Dataset.newDataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + Dataset.ofRows(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c5fefb5346bc7..29b058f2e4062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -288,7 +288,7 @@ class StreamExecution( val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") - val nextBatch = Dataset.newDataFrame(sqlContext, newPlan) + val nextBatch = Dataset.ofRows(sqlContext, newPlan) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 8c2bb4abd5f6d..8bc8bcaa966b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -58,7 +58,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } def toDF()(implicit sqlContext: SQLContext): DataFrame = { - Dataset.newDataFrame(sqlContext, logicalPlan) + Dataset.ofRows(sqlContext, logicalPlan) } def addData(data: A*): Offset = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dd4aa9e93ae4a..304d747d4fffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -917,7 +917,7 @@ object functions { * @since 1.5.0 */ def broadcast(df: DataFrame): DataFrame = { - Dataset.newDataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + Dataset.ofRows(df.sqlContext, BroadcastHint(df.logicalPlan)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f60c5ea759342..e6b7dc9199984 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -956,7 +956,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.newDataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 2dd6416853a2e..4ca739450c607 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -66,7 +66,7 @@ import org.apache.spark.util.Utils trait StreamTest extends QueryTest with Timeouts { implicit class RichSource(s: Source) { - def toDF(): DataFrame = Dataset.newDataFrame(sqlContext, StreamingRelation(s)) + def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s)) def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 4abc6d6a55ecd..1fa15730bc2e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -268,7 +268,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) } - Dataset.newDataFrame(sqlContext, bucketed) + Dataset.ofRows(sqlContext, bucketed) } else { df } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ab3876728bea8..d48358566e38e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -214,7 +214,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. 
*/ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.newDataFrame(sqlContext, plan) + Dataset.ofRows(sqlContext, plan) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index ff6657362013d..226b8e179604d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -247,7 +247,7 @@ case class CreateMetastoreDataSourceAsSelect( createMetastoreTable = true } - val data = Dataset.newDataFrame(hiveContext, query) + val data = Dataset.ofRows(hiveContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. case Some(s) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, s) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index 047e82e411bda..9a63ecb4ca8d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -63,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(generatedSQL), Dataset.newDataFrame(sqlContext, plan)) + checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 5c26aa1a79cf7..81fd71201d338 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -968,7 +968,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. // todo: remove it? - val newActual = Dataset.newDataFrame(sqlContext, actual.logicalPlan) + val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => From 6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 24 Mar 2016 00:51:31 +0800 Subject: [PATCH 17/26] [SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject #### What changes were proposed in this pull request? The PR https://github.com/apache/spark/pull/10541 changed the rule `CollapseProject` by enabling collapsing `Project` into `Aggregate`. It leaves a to-do item to remove the duplicate code. This PR is to finish this to-do item. Also added a test case for covering this change. #### How was this patch tested? Added a new test case. liancheng Could you check if the code refactoring is fine? Thanks! Author: gatorsmile Closes #11427 from gatorsmile/collapseProjectRefactor. 
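For reference, a minimal sketch (not part of this patch) of the behaviour the refactored rule preserves, written with the same Catalyst test DSL and `testRelation` used in `CollapseProjectSuite` below; the alias names are illustrative only.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val testRelation = LocalRelation('a.int, 'b.int)

// Two stacked deterministic projections: the lower one defines alias `c`,
// the upper one only references `c`.
val query = testRelation
  .select(('a + 'b).as('c))
  .select(('c + 1).as('c_plus_1))
  .analyze

// CollapseProject substitutes the alias into the upper projection and drops the
// inner Project, yielding roughly:
//   Project [((a + b) + 1) AS c_plus_1], LocalRelation [a, b]
val collapsed = CollapseProject(query)
```

The non-deterministic guard is unchanged: if the upper projection referenced an alias bound to something like `Rand(10)`, `haveCommonNonDeterministicOutput` returns true and the plan is left untouched, as the new "do not collapse common nondeterministic project and aggregate" test below checks.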
--- .../sql/catalyst/optimizer/Optimizer.scala | 101 ++++++++---------- .../optimizer/CollapseProjectSuite.scala | 26 ++++- 2 files changed, 70 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0840d46e4e5ec..4cfdcf95cb925 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] { object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p @ Project(projectList1, Project(projectList2, child)) => - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = AttributeMap(projectList2.collect { - case a: Alias => (a.toAttribute, a) - }) - - // We only collapse these two Projects if their overlapped expressions are all - // deterministic. - val hasNondeterministic = projectList1.exists(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a).child - }.exists(!_.deterministic)) - - if (hasNondeterministic) { + case p1 @ Project(_, p2: Project) => + if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { + p1 + } else { + p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) + } + case p @ Project(_, agg: Aggregate) => + if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute => aliasMap.getOrElse(a, a) - }).asInstanceOf[Seq[NamedExpression]] - // collapse 2 projects may introduce unnecessary Aliases, trim them here. - val cleanedProjection = substitutedProjection.map(p => - CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] - ) - Project(cleanedProjection, child) + agg.copy(aggregateExpressions = buildCleanedProjectList( + p.projectList, agg.aggregateExpressions)) } + } - // TODO Eliminate duplicate code - // This clause is identical to the one above except that the inner operator is an `Aggregate` - // rather than a `Project`. - case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = AttributeMap(projectList2.collect { - case a: Alias => (a.toAttribute, a) - }) + private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = { + AttributeMap(projectList.collect { + case a: Alias => a.toAttribute -> a + }) + } - // We only collapse these two Projects if their overlapped expressions are all - // deterministic. 
- val hasNondeterministic = projectList1.exists(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a).child - }.exists(!_.deterministic)) + private def haveCommonNonDeterministicOutput( + upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Collapse upper and lower Projects if and only if their overlapped expressions are all + // deterministic. + upper.exists(_.collect { + case a: Attribute if aliases.contains(a) => aliases(a).child + }.exists(!_.deterministic)) + } - if (hasNondeterministic) { - p - } else { - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute => aliasMap.getOrElse(a, a) - }).asInstanceOf[Seq[NamedExpression]] - // collapse 2 projects may introduce unnecessary Aliases, trim them here. - val cleanedProjection = substitutedProjection.map(p => - CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] - ) - agg.copy(aggregateExpressions = cleanedProjection) - } + private def buildCleanedProjectList( + upper: Seq[NamedExpression], + lower: Seq[NamedExpression]): Seq[NamedExpression] = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Substitute any attributes that are produced by the lower projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + val rewrittenUpper = upper.map(_.transform { + case a: Attribute => aliases.getOrElse(a, a) + }) + // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. 
+ rewrittenUpper.map { p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 833f054659bee..587437e9aa81d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) :: - Batch("CollapseProject", Once, CollapseProject) :: Nil + Batch("CollapseProject", Once, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int) @@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse project into aggregate") { + val query = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse common nondeterministic project and aggregate") { + val query = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } } From 3de24ae2ed6c58fc96a7e50832afe42fe7af34fb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Mar 2016 10:15:23 -0700 Subject: [PATCH 18/26] [SPARK-14075] Refactor MemoryStore to be testable independent of BlockManager This patch refactors the `MemoryStore` so that it can be tested without needing to construct / mock an entire `BlockManager`. - The block manager's serialization- and compression-related methods have been moved from `BlockManager` to `SerializerManager`. - `BlockInfoManager `is now passed directly to classes that need it, rather than being passed via the `BlockManager`. - The `MemoryStore` now calls `dropFromMemory` via a new `BlockEvictionHandler` interface rather than directly calling the `BlockManager`. This change helps to enforce a narrow interface between the `MemoryStore` and `BlockManager` functionality and makes this interface easier to mock in tests. - Several of the block unrolling tests have been moved from `BlockManagerSuite` into a new `MemoryStoreSuite`. Author: Josh Rosen Closes #11899 from JoshRosen/reduce-memorystore-blockmanager-coupling. 
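To make the new seam concrete, here is a minimal sketch (not part of this patch) of how a test can now build a `MemoryStore` without constructing or mocking a `BlockManager`. The helper object and method names are hypothetical; the constructor call mirrors the one added to `BlockManager` in the diff below, and the sketch assumes it compiles inside the `org.apache.spark.storage` package (as `MemoryStoreSuite` does) because `BlockEvictionHandler` is `private[storage]`. Any concrete `MemoryManager`, e.g. the `TestMemoryManager` already used by the tests, can be passed in.

```scala
package org.apache.spark.storage

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.memory.MemoryManager
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore}
import org.apache.spark.util.io.ChunkedByteBuffer

object MemoryStoreSketch {
  /** Hypothetical helper: wires a MemoryStore to a stand-in eviction handler. */
  def newTestMemoryStore(conf: SparkConf, memoryManager: MemoryManager): MemoryStore = {
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
    val blockInfoManager = new BlockInfoManager

    // Instead of spilling the evicted block to disk, simply report it as gone,
    // which is all a storage-level unit test needs.
    val evictionHandler = new BlockEvictionHandler {
      override private[storage] def dropFromMemory[T: ClassTag](
          blockId: BlockId,
          data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
        StorageLevel.NONE
      }
    }

    val store = new MemoryStore(
      conf, blockInfoManager, serializerManager, memoryManager, evictionHandler)
    memoryManager.setMemoryStore(store)
    store
  }
}
```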
--- .../spark/unsafe/map/BytesToBytesMap.java | 7 +- .../unsafe/sort/UnsafeExternalSorter.java | 17 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 6 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 5 +- .../spark/serializer/SerializerManager.scala | 90 +++++- .../shuffle/BlockStoreShuffleReader.scala | 5 +- .../apache/spark/storage/BlockManager.scala | 118 ++----- .../storage/BlockManagerManagedBuffer.scala | 6 +- .../spark/storage/memory/MemoryStore.scala | 55 +++- .../collection/ExternalAppendOnlyMap.scala | 7 +- .../util/collection/ExternalSorter.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 32 +- .../map/AbstractBytesToBytesMapSuite.java | 17 +- .../sort/UnsafeExternalSorterSuite.java | 12 +- .../org/apache/spark/DistributedSuite.scala | 3 +- .../BlockStoreShuffleReaderSuite.scala | 22 +- .../spark/storage/BlockManagerSuite.scala | 197 +----------- .../spark/storage/MemoryStoreSuite.scala | 302 ++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 1 + .../UnsafeFixedWidthAggregationMap.java | 8 +- .../sql/execution/UnsafeKVExternalSorter.java | 7 +- .../apache/spark/sql/execution/Window.scala | 1 + .../datasources/WriterContainer.scala | 1 + .../execution/joins/CartesianProduct.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 1 + .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../receiver/ReceivedBlockHandler.scala | 6 +- .../receiver/ReceiverSupervisorImpl.scala | 2 +- .../streaming/ReceivedBlockHandlerSuite.scala | 11 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 9 +- 31 files changed, 555 insertions(+), 402 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index de36814ecca15..9aacb084f648c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -32,6 +32,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -163,12 +164,14 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; private final BlockManager blockManager; + private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; private LinkedList spillWriters = new LinkedList<>(); public BytesToBytesMap( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, int initialCapacity, double loadFactor, long pageSizeBytes, @@ -176,6 +179,7 @@ public BytesToBytesMap( super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; + this.serializerManager = serializerManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -209,6 +213,7 @@ public BytesToBytesMap( this( taskMemoryManager, SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, + SparkEnv.get() != null ? 
SparkEnv.get().serializerManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -271,7 +276,7 @@ private void advanceToNextPage() { } try { Closeables.close(reader, /* swallowIOException = */ false); - reader = spillWriters.getFirst().getReader(blockManager); + reader = spillWriters.getFirst().getReader(serializerManager); recordsInPage = -1; } catch (IOException e) { // Scala iterator does not handle exception diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 927b19c4e8038..ded8f0472b275 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -31,6 +31,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final RecordComparator recordComparator; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; + private final SerializerManager serializerManager; private final TaskContext taskContext; private ShuffleWriteMetrics writeMetrics; @@ -78,6 +80,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, @@ -85,7 +88,8 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( long pageSizeBytes, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + serializerManager, taskContext, recordComparator, prefixComparator, initialSize, + pageSizeBytes, inMemorySorter); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. 
sorter.inMemSorter = null; @@ -95,18 +99,20 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( public static UnsafeExternalSorter create( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes) { - return new UnsafeExternalSorter(taskMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); } private UnsafeExternalSorter( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, @@ -116,6 +122,7 @@ private UnsafeExternalSorter( super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; + this.serializerManager = serializerManager; this.taskContext = taskContext; this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; @@ -412,7 +419,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } if (inMemSorter != null) { readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); @@ -463,7 +470,7 @@ public long spill() throws IOException { } spillWriter.close(); spillWriters.add(spillWriter); - nextUpstream = spillWriter.getReader(blockManager); + nextUpstream = spillWriter.getReader(serializerManager); long released = 0L; synchronized (UnsafeExternalSorter.this) { @@ -549,7 +556,7 @@ public UnsafeSorterIterator getIterator() throws IOException { } else { LinkedList queue = new LinkedList<>(); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - queue.add(spillWriter.getReader(blockManager)); + queue.add(spillWriter.getReader(serializerManager)); } if (inMemSorter != null) { queue.add(inMemSorter.getSortedIterator()); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 20ee1c8eb0c77..1d588c37c5db0 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -22,8 +22,8 @@ import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; /** @@ -46,13 +46,13 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( - BlockManager blockManager, + SerializerManager serializerManager, File file, BlockId blockId) throws IOException { assert (file.length() > 0); final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); try { - this.in = 
blockManager.wrapForCompression(blockId, bs); + this.in = serializerManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 234e21140a1dd..9ba760e8422f4 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; +import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -144,7 +145,7 @@ public File getFile() { return file; } - public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { - return new UnsafeSorterSpillReader(blockManager, file, blockId); + public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException { + return new UnsafeSorterSpillReader(serializerManager, file, blockId); } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index b9f115463a6eb..27e5fa4c2b464 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -17,17 +17,25 @@ package org.apache.spark.serializer +import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream} +import java.nio.ByteBuffer + import scala.reflect.ClassTag import org.apache.spark.SparkConf +import org.apache.spark.io.CompressionCodec +import org.apache.spark.storage._ +import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} /** - * Component that selects which [[Serializer]] to use for shuffles. + * Component which configures serialization and compression for various Spark components, including + * automatic selection of which [[Serializer]] to use for shuffles. */ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { private[this] val kryoSerializer = new KryoSerializer(conf) + private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = { val primitiveClassTags = Set[ClassTag[_]]( ClassTag.Boolean, @@ -44,7 +52,21 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar primitiveClassTags ++ arrayClassTags } - private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] + // Whether to compress broadcast variables that are stored + private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) + // Whether to compress shuffle output that are stored + private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) + // Whether to compress RDD partitions that are stored serialized + private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false) + // Whether to compress shuffle output temporarily spilled to disk + private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + + /* The compression codec to use. 
Note that the "lazy" val is necessary because we want to delay + * the initialization of the compression codec until it is first used. The reason is that a Spark + * program could be using a user-defined codec in a third party jar, which is loaded in + * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + * loaded yet. */ + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) private def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag @@ -68,4 +90,68 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar defaultSerializer } } + + private def shouldCompress(blockId: BlockId): Boolean = { + blockId match { + case _: ShuffleBlockId => compressShuffle + case _: BroadcastBlockId => compressBroadcast + case _: RDDBlockId => compressRdds + case _: TempLocalBlockId => compressShuffleSpill + case _: TempShuffleBlockId => compressShuffle + case _ => false + } + } + + /** + * Wrap an output stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s + } + + /** + * Wrap an input stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s + } + + /** Serializes into a stream. */ + def dataSerializeStream[T: ClassTag]( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[T]): Unit = { + val byteStream = new BufferedOutputStream(outputStream) + val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a chunked byte buffer. */ + def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4) + dataSerializeStream(blockId, byteArrayChunkOutputStream, values) + new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)) + } + + /** + * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = { + dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true)) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. 
+ */ + def dataDeserializeStream[T: ClassTag]( + blockId: BlockId, + inputStream: InputStream): Iterator[T] = { + val stream = new BufferedInputStream(inputStream) + getSerializer(implicitly[ClassTag[T]]) + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator.asInstanceOf[Iterator[T]] + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4054465c0f7fc..637b2dfc193b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { @@ -52,7 +53,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Wrap the streams for compression based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - blockManager.wrapForCompression(blockId, inputStream) + serializerManager.wrapForCompression(blockId, inputStream) } val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 83f8c5c37d136..eebb43e245df7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,7 +18,6 @@ package org.apache.spark.storage import java.io._ -import java.nio.ByteBuffer import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} @@ -30,7 +29,6 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging -import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} @@ -38,11 +36,11 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{Serializer, SerializerInstance, SerializerManager} +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ import org.apache.spark.util._ -import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} +import org.apache.spark.util.io.ChunkedByteBuffer /* Class for returning a fetched block and associated metrics. 
*/ private[spark] class BlockResult( @@ -68,7 +66,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) - extends BlockDataManager with Logging { + extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -80,13 +78,15 @@ private[spark] class BlockManager( new DiskBlockManager(conf, deleteFilesOnStop) } + // Visible for testing private[storage] val blockInfoManager = new BlockInfoManager private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) // Actual storage of where blocks are kept - private[spark] val memoryStore = new MemoryStore(conf, this, memoryManager) + private[spark] val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) private[spark] val diskStore = new DiskStore(conf, diskBlockManager) memoryManager.setMemoryStore(memoryStore) @@ -126,14 +126,6 @@ private[spark] class BlockManager( blockTransferService } - // Whether to compress broadcast variables that are stored - private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) - // Whether to compress shuffle output that are stored - private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) - // Whether to compress RDD partitions that are stored serialized - private val compressRdds = conf.getBoolean("spark.rdd.compress", false) - // Whether to compress shuffle output temporarily spilled to disk - private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) // Max number of failures before this block manager refreshes the block locations from the driver private val maxFailuresBeforeLocationRefresh = conf.getInt("spark.block.failures.beforeLocationRefresh", 5) @@ -152,13 +144,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay - * the initialization of the compression codec until it is first used. The reason is that a Spark - * program could be using a user-defined codec in a third party jar, which is loaded in - * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - * loaded yet. */ - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -286,7 +271,7 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(this, blockId, buffer) + case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) case None => throw new BlockNotFoundException(blockId.toString) } } @@ -422,7 +407,8 @@ private[spark] class BlockManager( val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get } else { - dataDeserialize(blockId, memoryStore.getBytes(blockId).get)(info.classTag) + serializerManager.dataDeserialize( + blockId, memoryStore.getBytes(blockId).get)(info.classTag) } val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) @@ -430,11 +416,11 @@ private[spark] class BlockManager( val iterToReturn: Iterator[Any] = { val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { - val diskValues = dataDeserialize(blockId, diskBytes)(info.classTag) + val diskValues = serializerManager.dataDeserialize(blockId, diskBytes)(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - dataDeserialize(blockId, bytes)(info.classTag) + serializerManager.dataDeserialize(blockId, bytes)(info.classTag) } } val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) @@ -486,7 +472,7 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - dataSerialize(blockId, memoryStore.getValues(blockId).get) + serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get) } else { releaseLock(blockId) throw new SparkException(s"Block $blockId was not found even though it's read-locked") @@ -510,7 +496,8 @@ private[spark] class BlockManager( */ private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { getRemoteBytes(blockId).map { data => - new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.size) + new BlockResult( + serializerManager.dataDeserialize(blockId, data), DataReadMethod.Network, data.size) } } @@ -699,7 +686,8 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + val compressStream: OutputStream => OutputStream = + serializerManager.wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, syncWrites, writeMetrics, blockId) @@ -757,7 +745,7 @@ private[spark] class BlockManager( // Put it in memory first, even if it also has useDisk set to true; // We will drop it to disk later if the memory store can't hold it. 
val putSucceeded = if (level.deserialized) { - val values = dataDeserialize(blockId, bytes)(classTag) + val values = serializerManager.dataDeserialize(blockId, bytes)(classTag) memoryStore.putIterator(blockId, values, level, classTag) match { case Right(_) => true case Left(iter) => @@ -896,7 +884,7 @@ private[spark] class BlockManager( if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") diskStore.put(blockId) { fileOutputStream => - dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -905,7 +893,7 @@ private[spark] class BlockManager( } } else if (level.useDisk) { diskStore.put(blockId) { fileOutputStream => - dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1167,7 +1155,7 @@ private[spark] class BlockManager( * * @return the block's new effective StorageLevel. */ - private[storage] def dropFromMemory[T: ClassTag]( + private[storage] override def dropFromMemory[T: ClassTag]( blockId: BlockId, data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { logInfo(s"Dropping block $blockId from memory") @@ -1181,7 +1169,7 @@ private[spark] class BlockManager( data() match { case Left(elements) => diskStore.put(blockId) { fileOutputStream => - dataSerializeStream( + serializerManager.dataSerializeStream( blockId, fileOutputStream, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) @@ -1264,70 +1252,6 @@ private[spark] class BlockManager( } } - private def shouldCompress(blockId: BlockId): Boolean = { - blockId match { - case _: ShuffleBlockId => compressShuffle - case _: BroadcastBlockId => compressBroadcast - case _: RDDBlockId => compressRdds - case _: TempLocalBlockId => compressShuffleSpill - case _: TempShuffleBlockId => compressShuffle - case _ => false - } - } - - /** - * Wrap an output stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s - } - - /** - * Wrap an input stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s - } - - /** Serializes into a stream. */ - def dataSerializeStream[T: ClassTag]( - blockId: BlockId, - outputStream: OutputStream, - values: Iterator[T]): Unit = { - val byteStream = new BufferedOutputStream(outputStream) - val ser = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() - } - - /** Serializes into a chunked byte buffer. */ - def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { - val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4) - dataSerializeStream(blockId, byteArrayChunkOutputStream, values) - new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. 
- */ - def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = { - dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true)) - } - - /** - * Deserializes a InputStream into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserializeStream[T: ClassTag]( - blockId: BlockId, - inputStream: InputStream): Iterator[T] = { - val stream = new BufferedInputStream(inputStream) - serializerManager.getSerializer(implicitly[ClassTag[T]]) - .newInstance() - .deserializeStream(wrapForCompression(blockId, stream)) - .asIterator.asInstanceOf[Iterator[T]] - } - def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 12594e6a2bc0c..f66f942798550 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -29,19 +29,19 @@ import org.apache.spark.util.io.ChunkedByteBuffer * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( - blockManager: BlockManager, + blockInfoManager: BlockInfoManager, blockId: BlockId, chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { override def retain(): ManagedBuffer = { super.retain() - val locked = blockManager.blockInfoManager.lockForReading(blockId, blocking = false) + val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this } override def release(): ManagedBuffer = { - blockManager.releaseLock(blockId) + blockInfoManager.unlock(blockId) super.release() } } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index d370ee912ab31..90016cbeb8467 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -26,7 +26,8 @@ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.MemoryManager -import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel} +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.ChunkedByteBuffer @@ -44,14 +45,33 @@ private case class SerializedMemoryEntry[T]( size: Long, classTag: ClassTag[T]) extends MemoryEntry[T] +private[storage] trait BlockEvictionHandler { + /** + * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory + * store reaches its limit and needs to free up space. + * + * If `data` is not put on disk, it won't be created. + * + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. 
+ */ + private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel +} + /** * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ private[spark] class MemoryStore( conf: SparkConf, - blockManager: BlockManager, - memoryManager: MemoryManager) + blockInfoManager: BlockInfoManager, + serializerManager: SerializerManager, + memoryManager: MemoryManager, + blockEvictionHandler: BlockEvictionHandler) extends Logging { // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and @@ -117,7 +137,7 @@ private[spark] class MemoryStore( entries.put(blockId, entry) } logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) true } else { false @@ -201,7 +221,7 @@ private[spark] class MemoryStore( val entry = if (level.deserialized) { new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) } else { - val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)(classTag) + val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag) new SerializedMemoryEntry[T](bytes, bytes.size, classTag) } val size = entry.size @@ -237,7 +257,10 @@ private[spark] class MemoryStore( } val bytesOrValues = if (level.deserialized) "values" else "bytes" logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, bytesOrValues, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + blockId, + bytesOrValues, + Utils.bytesToString(size), + Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, @@ -284,7 +307,7 @@ private[spark] class MemoryStore( } if (entry != null) { memoryManager.releaseStorageMemory(entry.size) - logDebug(s"Block $blockId of size ${entry.size} dropped " + + logInfo(s"Block $blockId of size ${entry.size} dropped " + s"from memory (free ${maxMemory - blocksMemoryUsed})") true } else { @@ -339,7 +362,7 @@ private[spark] class MemoryStore( // We don't want to evict blocks which are currently being read, so we need to obtain // an exclusive write lock on blocks which are candidates for eviction. 
We perform a // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: - if (blockManager.blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { + if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { selectedBlocks += blockId freedMemory += pair.getValue.size } @@ -353,20 +376,21 @@ private[spark] class MemoryStore( case SerializedMemoryEntry(buffer, _, _) => Right(buffer) } val newEffectiveStorageLevel = - blockManager.dropFromMemory(blockId, () => data)(entry.classTag) + blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag) if (newEffectiveStorageLevel.isValid) { // The block is still present in at least one store, so release the lock // but don't delete the block info - blockManager.releaseLock(blockId) + blockInfoManager.unlock(blockId) } else { // The block isn't present in any store, so delete the block info so that the // block can be stored again - blockManager.blockInfoManager.removeBlock(blockId) + blockInfoManager.removeBlock(blockId) } } if (freedMemory >= space) { - logInfo(s"${selectedBlocks.size} blocks selected for dropping") + logInfo(s"${selectedBlocks.size} blocks selected for dropping " + + s"(${Utils.bytesToString(freedMemory)} bytes)") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } // This should never be null as only one task should be dropping @@ -376,14 +400,15 @@ private[spark] class MemoryStore( dropBlock(blockId, entry) } } + logInfo(s"After dropping ${selectedBlocks.size} blocks, " + + s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") freedMemory } else { blockId.foreach { id => - logInfo(s"Will not store $id as it would require dropping another block " + - "from the same RDD") + logInfo(s"Will not store $id") } selectedBlocks.foreach { id => - blockManager.releaseLock(id) + blockInfoManager.unlock(id) } 0L } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 531f1c4dd2760..95351e98261d7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator @@ -59,7 +59,8 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager, - context: TaskContext = TaskContext.get()) + context: TaskContext = TaskContext.get(), + serializerManager: SerializerManager = SparkEnv.get.serializerManager) extends Iterable[(K, C)] with Serializable with Logging @@ -458,7 +459,7 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, 
bufferedStream) + val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream) ser.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 8cdc4663e6d6f..561ba22df557f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,6 +108,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val serializerManager = SparkEnv.get.serializerManager private val serInstance = serializer.newInstance() // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided @@ -503,7 +504,7 @@ private[spark] class ExternalSorter[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream) + val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream) serInstance.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 47c695ad4e717..44733dcdafc41 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -70,6 +70,7 @@ public class UnsafeShuffleWriterSuite { final LinkedList spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); + final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -111,7 +112,7 @@ public void setUp() throws IOException { .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); - taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -135,35 +136,6 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( - new Answer() { - @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - InputStream is = (InputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); - } else { - return is; - } - } - } - ); - - when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( - new Answer() { - @Override - public OutputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - OutputStream os = (OutputStream) 
invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); - } else { - return os; - } - } - } - ); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer() { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 6667179b9d30c..449fb45c301e2 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; @@ -42,7 +41,9 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -51,7 +52,6 @@ import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -64,6 +64,9 @@ public abstract class AbstractBytesToBytesMapSuite { private TestMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; + private SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes final LinkedList spillFilesCreated = new LinkedList<>(); @@ -85,7 +88,9 @@ public void setup() { new TestMemoryManager( new SparkConf() .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeap.size", "256mb")); + .set("spark.memory.offHeap.size", "256mb") + .set("spark.shuffle.spill.compress", "false") + .set("spark.shuffle.compress", "false")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); @@ -124,8 +129,6 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -546,8 +549,8 @@ public void failureToGrow() { @Test public void spillInIterator() throws IOException { - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); try { int i; for (i = 0; i < 1024; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index db50e551f256e..a2253d8559640 100644 --- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; @@ -43,14 +42,15 @@ import org.apache.spark.executor.TaskMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -60,6 +60,9 @@ public class UnsafeExternalSorterSuite { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -135,8 +138,6 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -172,6 +173,7 @@ private UnsafeExternalSorter newSorter() throws IOException { return UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -374,6 +376,7 @@ public void forcedSpillingWithoutComparator() throws Exception { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, null, null, @@ -408,6 +411,7 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 2732cd674992d..3dded4d486dcb 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -194,10 +194,11 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService + val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = blockManager.dataDeserialize[Int](blockId, + val deserialized = serializerManager.dataDeserialize[Int](blockId, new ChunkedByteBuffer(bytes.nioByteBuffer())).toList assert(deserialized 
=== (1 to 100).toList) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 08f52c92e1812..dba1172d5fdbd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,14 +20,11 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -77,13 +74,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // can ensure retain() and release() are properly called. val blockManager = mock(classOf[BlockManager]) - // Create a return function to use for the mocked wrapForCompression method that just returns - // the original input stream. - val dummyCompressionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = - invocation.getArguments()(1).asInstanceOf[InputStream] - } - // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() @@ -105,9 +95,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // fetch shuffle data. 
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) - .thenAnswer(dummyCompressionFunction) - managedBuffer } @@ -133,11 +120,18 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext new BaseShuffleHandle(shuffleId, numMaps, dependency) } + val serializerManager = new SerializerManager( + serializer, + new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.spill.compress", "false")) + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, TaskContext.empty(), + serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9419dfaa00648..94f6f877408a5 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1033,138 +1033,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } - test("reserve/release unroll memory") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) - } - - // Reserve - assert(reserveUnrollMemoryForThisTask(100)) - assert(memoryStore.currentUnrollMemoryForThisTask === 100) - assert(reserveUnrollMemoryForThisTask(200)) - assert(memoryStore.currentUnrollMemoryForThisTask === 300) - assert(reserveUnrollMemoryForThisTask(500)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) - assert(!reserveUnrollMemoryForThisTask(1000000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted - // Release - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 700) - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 600) - // Reserve again - assert(reserveUnrollMemoryForThisTask(4400)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) - assert(!reserveUnrollMemoryForThisTask(20000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted - // Release again - memoryStore.releaseUnrollMemoryForThisTask(1000) - assert(memoryStore.currentUnrollMemoryForThisTask === 4000) - memoryStore.releaseUnrollMemoryForThisTask() // release all - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - } - - test("safely unroll blocks") { - store = makeBlockManager(12000) - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - val memoryStore = store.memoryStore - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with all the space in the world. This should succeed. 
- var putResult = - memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(putResult.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } - assert(memoryStore.remove("unroll")) - - // Unroll with not enough space. This should succeed after kicking out someBlock1. - assert(store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)) - assert(store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)) - putResult = - memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(putResult.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(memoryStore.contains("someBlock2")) - assert(!memoryStore.contains("someBlock1")) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } - assert(memoryStore.remove("unroll")) - - // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = - // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. - // In the mean time, however, we kicked out someBlock2 before giving up. - assert(store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)) - putResult = - memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - assert(!memoryStore.contains("someBlock2")) - assert(putResult.isLeft) - bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => - assert(e === a, "putIterator() did not return original values!") - } - // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - } - - test("safely unroll blocks through putIterator") { - store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with plenty of space. This should succeed and cache both blocks. - val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any) - val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any) - assert(memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(result1.isRight) // unroll was successful - assert(result2.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Re-put these two blocks so block manager knows about them too. Otherwise, block manager - // would not know how to drop them from memory later. - memoryStore.remove("b1") - memoryStore.remove("b2") - store.putIterator("b1", smallIterator, memOnly) - store.putIterator("b2", smallIterator, memOnly) - - // Unroll with not enough space. This should succeed but kick out b1 in the process. 
- val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any) - assert(result3.isRight) - assert(!memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.remove("b3") - store.putIterator("b3", smallIterator, memOnly) - - // Unroll huge block with not enough space. This should fail and kick out b2 in the process. - val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, ClassTag.Any) - assert(result4.isLeft) // unroll was unsuccessful - assert(!memoryStore.contains("b1")) - assert(!memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - } - - /** - * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK. - */ test("safely unroll blocks through putIterator (disk)") { store = makeBlockManager(12000) val memAndDisk = StorageLevel.MEMORY_AND_DISK @@ -1203,72 +1071,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - result4.left.get.close() - assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory - } - - test("multiple unrolls by the same thread") { - store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // All unroll memory used is released because putIterator did not return an iterator - assert(memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any).isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any).isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll memory is not released because putIterator returned an iterator - // that still depends on the underlying vector used in the process - assert(memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB3 > 0) - - // The unroll memory owned by this thread builds on top of its value after the previous unrolls - assert(memoryStore.putIterator("b4", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) - - // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) - assert(memoryStore.putIterator("b5", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask - assert(memoryStore.putIterator("b6", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask - assert(memoryStore.putIterator("b7", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) - } - - test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - store.blockInfoManager.lockNewBlockForWriting( - blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) - memoryStore.putBytes(blockId, 13000, () => { - fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") - }) - } - - test("put a small ByteBuffer to MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - var bytes: ChunkedByteBuffer = null - memoryStore.putBytes(blockId, 10000, () => { - bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) - bytes - }) - assert(memoryStore.getSize(blockId) === 10000) } - test("read-locked blocks cannot be evicted from the MemoryStore") { + test("read-locked blocks cannot be evicted from memory") { store = makeBlockManager(12000) val arr = new Array[Byte](4000) // First store a1 and a2, both in memory, and a3, on disk only diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala new file mode 100644 index 0000000000000..b4ab67ca15a88 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.language.implicitConversions +import scala.language.postfixOps +import scala.language.reflectiveCalls +import scala.reflect.ClassTag + +import org.scalatest._ + +import org.apache.spark._ +import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallyUnrolledIterator} +import org.apache.spark.util._ +import org.apache.spark.util.io.ChunkedByteBuffer + +class MemoryStoreSuite + extends SparkFunSuite + with PrivateMethodTester + with BeforeAndAfterEach + with ResetSystemProperties { + + var conf: SparkConf = new SparkConf(false) + .set("spark.test.useCompressedOops", "true") + .set("spark.storage.unrollFraction", "0.4") + .set("spark.storage.unrollMemoryThreshold", "512") + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) + + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) + + override def beforeEach(): Unit = { + super.beforeEach() + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + System.setProperty("os.arch", "amd64") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + } + + def makeMemoryStore(maxMem: Long): (MemoryStore, BlockInfoManager) = { + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val serializerManager = new SerializerManager(serializer, conf) + val blockInfoManager = new BlockInfoManager + val blockEvictionHandler = new BlockEvictionHandler { + var memoryStore: MemoryStore = _ + override private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { + memoryStore.remove(blockId) + StorageLevel.NONE + } + } + val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memManager, blockEvictionHandler) + memManager.setMemoryStore(memoryStore) + blockEvictionHandler.memoryStore = memoryStore + (memoryStore, blockInfoManager) + } + + test("reserve/release unroll memory") { + val (memoryStore, _) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemory === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) + } + + // Reserve + assert(reserveUnrollMemoryForThisTask(100)) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + assert(reserveUnrollMemoryForThisTask(200)) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + assert(reserveUnrollMemoryForThisTask(500)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + assert(!reserveUnrollMemoryForThisTask(1000000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted + // Release + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) + // Reserve again + assert(reserveUnrollMemoryForThisTask(4400)) + 
assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + assert(!reserveUnrollMemoryForThisTask(20000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted + // Release again + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks") { + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + val ct = implicitly[ClassTag[Array[Byte]]] + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with all the space in the world. This should succeed. + var putResult = putIterator("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll with not enough space. This should succeed after kicking out someBlock1. + assert(putIterator("someBlock1", smallList.iterator, ct).isRight) + assert(putIterator("someBlock2", smallList.iterator, ct).isRight) + putResult = putIterator("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(memoryStore.contains("someBlock2")) + assert(!memoryStore.contains("someBlock1")) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = + // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. + // In the meantime, however, we kicked out someBlock2 before giving up. + assert(putIterator("someBlock3", smallList.iterator, ct).isRight) + putResult = putIterator("unroll", bigList.iterator, ClassTag.Any) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + assert(!memoryStore.contains("someBlock2")) + assert(putResult.isLeft) + bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => + assert(e === a, "putIterator() did not return original values!") + } + // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. 
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks through putIterator") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with plenty of space. This should succeed and cache both blocks. + val result1 = putIterator("b1", smallIterator, ClassTag.Any) + val result2 = putIterator("b2", smallIterator, ClassTag.Any) + assert(memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(result1.isRight) // unroll was successful + assert(result2.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Re-put these two blocks so block manager knows about them too. Otherwise, block manager + // would not know how to drop them from memory later. + blockInfoManager.lockForWriting("b1") + memoryStore.remove("b1") + blockInfoManager.removeBlock("b1") + blockInfoManager.lockForWriting("b2") + memoryStore.remove("b2") + blockInfoManager.removeBlock("b2") + putIterator("b1", smallIterator, ClassTag.Any) + putIterator("b2", smallIterator, ClassTag.Any) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + val result3 = putIterator("b3", smallIterator, ClassTag.Any) + assert(result3.isRight) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + blockInfoManager.lockForWriting("b3") + assert(memoryStore.remove("b3")) + blockInfoManager.removeBlock("b3") + putIterator("b3", smallIterator, ClassTag.Any) + + // Unroll huge block with not enough space. This should fail and kick out b2 in the process. 
+ val result4 = putIterator("b4", bigIterator, ClassTag.Any) + assert(result4.isLeft) // unroll was unsuccessful + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + result4.left.get.close() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory + } + + test("multiple unrolls by the same thread") { + val (memoryStore, _) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator( + blockId: BlockId, + iter: Iterator[Any]): Either[PartiallyUnrolledIterator[Any], Long] = { + memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, ClassTag.Any) + } + + // All unroll memory used is released because putIterator did not return an iterator + assert(putIterator("b1", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(putIterator("b2", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Unroll memory is not released because putIterator returned an iterator + // that still depends on the underlying vector used in the process + assert(putIterator("b3", smallIterator).isLeft) + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB3 > 0) + + // The unroll memory owned by this thread builds on top of its value after the previous unrolls + assert(putIterator("b4", smallIterator).isLeft) + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) + + // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) + assert(putIterator("b5", smallIterator).isLeft) + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask + assert(putIterator("b6", smallIterator).isLeft) + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask + assert(putIterator("b7", smallIterator).isLeft) + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) + } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + blockInfoManager.lockNewBlockForWriting( + blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) + memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + } + + test("put a small ByteBuffer to MemoryStore") { + val (memoryStore, _) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + var bytes: ChunkedByteBuffer = null + memoryStore.putBytes(blockId, 10000, () => { + bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) + bytes + }) + assert(memoryStore.getSize(blockId) === 10000) + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d85147e961fa8..aa7fc2121e86c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -67,6 +67,7 @@ public UnsafeExternalRowSorter( sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), sparkEnv.blockManager(), + sparkEnv.serializerManager(), taskContext, new RowComparator(ordering, schema.length()), prefixComparator, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index acf6c583bbb58..8882903bbf8ad 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -241,7 +241,11 @@ public void printPerfMetrics() { */ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { return new UnsafeKVExternalSorter( - groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); + groupingKeySchema, + aggregationBufferSchema, + SparkEnv.get().blockManager(), + SparkEnv.get().serializerManager(), + map.getPageSizeBytes(), + map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9e08675c3e669..d3bfb00b3fa20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -24,6 +24,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import 
org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; @@ -52,14 +53,16 @@ public UnsafeKVExternalSorter( StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes) throws IOException { - this(keySchema, valueSchema, blockManager, pageSizeBytes, null); + this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null); } public UnsafeKVExternalSorter( StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; @@ -77,6 +80,7 @@ public UnsafeKVExternalSorter( sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -116,6 +120,7 @@ public UnsafeKVExternalSorter( sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( taskMemoryManager, blockManager, + serializerManager, taskContext, new KVComparator(ordering, keySchema.length()), prefixComparator, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index a4c0e1c9fba41..270c09aff3f88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -339,6 +339,7 @@ case class Window( sorter = UnsafeExternalSorter.create( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get(), null, null, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index c74ac8a282c2c..233ac263aaafc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -399,6 +399,7 @@ private[sql] class DynamicPartitionWriterContainer( sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes) while (iterator.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fabd2fbe1e0c6..fb65b50da800b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -41,6 +41,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField val sorter = UnsafeExternalSorter.create( context.taskMemoryManager(), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, context, null, null, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index e03bd6a3e7d20..476d93fc2a9ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -120,7 +120,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index a29d55ee25b20..794fe264ead5d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -279,6 +279,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( StructType.fromAttributes(partitionOutput), StructType.fromAttributes(dataOutput), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes) while (iterator.hasNext) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ace67a639c6b8..c56520b1e21e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -115,6 +115,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( assertValid() val hadoopConf = broadcastedHadoopConf.value val blockManager = SparkEnv.get.blockManager + val serializerManager = SparkEnv.get.serializerManager val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockId = partition.blockId @@ -161,7 +162,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - blockManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) + serializerManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 6d4f4b99c175f..85350ff658d66 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} @@ -123,6 +124,7 @@ private[streaming] case class WriteAheadLogBasedStoreResult( */ private[streaming] class WriteAheadLogBasedBlockHandler( blockManager: BlockManager, + serializerManager: SerializerManager, streamId: Int, storageLevel: StorageLevel, conf: SparkConf, @@ -173,10 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = 
Some(arrayBuffer.size.toLong) - blockManager.dataSerialize(blockId, arrayBuffer.iterator) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index e41fd11963ba3..4fb0f8caacbb6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -60,7 +60,7 @@ private[streaming] class ReceiverSupervisorImpl( "Please use streamingContext.checkpoint() to set the checkpoint directory. " + "See documentation for more details.") } - new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + new WriteAheadLogBasedBlockHandler(env.blockManager, env.serializerManager, receiver.streamId, receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) } else { new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 122ca0627f720..4e77cd6347d1b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -60,6 +60,7 @@ class ReceivedBlockHandlerSuite val mapOutputTracker = new MapOutputTrackerMaster(conf) val shuffleManager = new HashShuffleManager(conf) val serializer = new KryoSerializer(conf) + var serializerManager = new SerializerManager(serializer, conf) val manualClock = new ManualClock val blockManagerSize = 10000000 val blockManagerBuffer = new ArrayBuffer[BlockManager]() @@ -156,7 +157,7 @@ class ReceivedBlockHandlerSuite val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) val bytes = reader.read(fileSegment) reader.close() - blockManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList + serializerManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList } loggedData shouldEqual data } @@ -265,7 +266,6 @@ class ReceivedBlockHandlerSuite name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) @@ -335,7 +335,8 @@ class ReceivedBlockHandlerSuite } } - def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator) + def dataToByteBuffer(b: Seq[String]) = + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq @@ -367,8 +368,8 @@ class ReceivedBlockHandlerSuite /** Instantiate a 
WriteAheadLogBasedBlockHandler and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) - val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, - storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, serializerManager, + 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { body(receivedBlockHandler) } finally { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index c4bf42d0f272d..ce5a6e00fb2fe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils @@ -39,6 +40,7 @@ class WriteAheadLogBackedBlockRDDSuite var sparkContext: SparkContext = null var blockManager: BlockManager = null + var serializerManager: SerializerManager = null var dir: File = null override def beforeEach(): Unit = { @@ -58,6 +60,7 @@ class WriteAheadLogBackedBlockRDDSuite super.beforeAll() sparkContext = new SparkContext(conf) blockManager = sparkContext.env.blockManager + serializerManager = sparkContext.env.serializerManager } override def afterAll(): Unit = { @@ -65,6 +68,8 @@ class WriteAheadLogBackedBlockRDDSuite try { sparkContext.stop() System.clearProperty("spark.driver.port") + blockManager = null + serializerManager = null } finally { super.afterAll() } @@ -107,8 +112,6 @@ class WriteAheadLogBackedBlockRDDSuite * It can also test if the partitions that were read from the log were again stored in * block manager. * - * - * * @param numPartitions Number of partitions in RDD * @param numPartitionsInBM Number of partitions to write to the BlockManager. * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager @@ -223,7 +226,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(blockManager.dataSerialize(id, data.iterator).toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments From 48ee16d8012602c75d50aa2a85e26b7de3c48944 Mon Sep 17 00:00:00 2001 From: Ernest Date: Wed, 23 Mar 2016 10:29:36 -0700 Subject: [PATCH 19/26] [SPARK-14055] writeLocksByTask needs to be updated in removeBlock ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14055 When `removeBlock` drops a block's info, the corresponding binding in `writeLocksByTask` must also be removed; otherwise the per-task write-lock bookkeeping keeps a stale reference to the removed block. ## How was this patch tested? Manual tests by running LiveJournalPageRank on a large dataset (the dataset must be large enough to incur RDD partition eviction).
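To make the intent of the one-line change concrete, here is a minimal, hypothetical sketch rather than the real BlockInfoManager (which also tracks read locks, reader counts, and the writer task per block): write locks are kept in a per-task MultiMap, and removing a block must also drop that block's binding so a later release-all-locks pass for the task does not reference a block that no longer exists. The actual patch adds exactly this removal inside BlockInfoManager.removeBlock, as shown in the hunk below.

import scala.collection.mutable

// Simplified, illustrative stand-in for the per-task write-lock registry
// touched by this patch. Names here are assumptions for the sketch only.
object WriteLockRegistrySketch {
  type TaskAttemptId = Long

  private val writeLocksByTask =
    new mutable.HashMap[TaskAttemptId, mutable.Set[String]]
      with mutable.MultiMap[TaskAttemptId, String]

  def lockForWriting(taskId: TaskAttemptId, blockId: String): Unit =
    writeLocksByTask.addBinding(taskId, blockId)

  // The fix: when a block is removed, also remove its write-lock binding so
  // the task's bookkeeping no longer points at a block that has been dropped.
  def removeBlock(taskId: TaskAttemptId, blockId: String): Unit =
    writeLocksByTask.removeBinding(taskId, blockId)

  def main(args: Array[String]): Unit = {
    lockForWriting(0L, "rdd_0_0")
    removeBlock(0L, "rdd_0_0")
    assert(!writeLocksByTask.get(0L).exists(_.contains("rdd_0_0")))
  }
}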
Author: Ernest Closes #11875 from Earne/issue-14055. --- .../main/scala/org/apache/spark/storage/BlockInfoManager.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 94d11c5be5a49..ca53534b61c4a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -421,6 +421,7 @@ private[storage] class BlockInfoManager extends Logging { infos.remove(blockId) blockInfo.readerCount = 0 blockInfo.writerTask = BlockInfo.NO_WRITER + writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) } case None => throw new IllegalArgumentException( From 30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 23 Mar 2016 11:20:44 -0700 Subject: [PATCH 20/26] [SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params ## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah Closes #11663 from sethah/SPARK-13068-tc. --- python/pyspark/ml/classification.py | 20 +- python/pyspark/ml/clustering.py | 14 +- python/pyspark/ml/feature.py | 95 +++++---- python/pyspark/ml/param/__init__.py | 181 ++++++++++++++++-- .../ml/param/_shared_params_code_gen.py | 91 +++++---- python/pyspark/ml/param/shared.py | 58 +++--- python/pyspark/ml/recommendation.py | 25 ++- python/pyspark/ml/regression.py | 25 ++- python/pyspark/ml/tests.py | 83 ++++++-- python/pyspark/ml/tuning.py | 5 +- 10 files changed, 421 insertions(+), 176 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8075108114c18..fdeccf822c0e6 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -20,6 +20,7 @@ from pyspark import since from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) @@ -87,7 +88,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." 
+ - " If threshold and thresholds are both set, they must match.") + " If threshold and thresholds are both set, they must match.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -243,7 +245,7 @@ class TreeClassifierParams(object): impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + - ", ".join(supportedImpurities)) + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString) def __init__(self): super(TreeClassifierParams, self).__init__() @@ -534,7 +536,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + "Supported options: " + ", ".join(GBTParams.supportedLossTypes), + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -652,9 +655,10 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H """ smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") + "default is 1.0", typeConverter=TypeConverters.toFloat) modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + "(case-sensitive). Supported options: multinomial (default) and bernoulli.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -782,11 +786,13 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + - "neurons and output layer of 10 neurons, default is [1, 1].") + "neurons and output layer of 10 neurons, default is [1, 1].", + typeConverter=TypeConverters.toListInt) blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + "matrices. Data is stacked within partitions. If block size is more than " + "remaining data in a partition then it is adjusted to the size of this " + - "data. Recommended size is between 10 and 1000, default is 128.") + "data. Recommended size is between 10 and 1000, default is 128.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 2db5b82c44543..e22d5c8ea4afa 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -87,12 +87,14 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol .. versionadded:: 1.5.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create") + k = Param(Params._dummy(), "k", "number of clusters to create", + typeConverter=TypeConverters.toInt) initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. 
This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + - "to use a parallel variant of k-means++") - initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") + "to use a parallel variant of k-means++", TypeConverters.toString) + initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -227,10 +229,12 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create") + k = Param(Params._dummy(), "k", "number of clusters to create", + typeConverter=TypeConverters.toInt) minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize", "the minimum number of points (if >= 1.0) " + - "or the minimum proportion") + "or the minimum proportion", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 16cb9d1db3ea7..86b53285b5b00 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -83,7 +83,8 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java """ threshold = Param(Params._dummy(), "threshold", - "threshold in binary classification prediction, in range [0, 1]") + "threshold in binary classification prediction, in range [0, 1]", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, threshold=0.0, inputCol=None, outputCol=None): @@ -159,7 +160,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav "range [x,y) except the last bucket, which also includes y. The splits " + "should be strictly increasing. Values at -inf, inf must be explicitly " + "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.") + "specified will be treated as errors.", + typeConverter=TypeConverters.toListFloat) @keyword_only def __init__(self, splits=None, inputCol=None, outputCol=None): @@ -243,15 +245,17 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + " times the term must appear in the document); if this is a double in [0,1), then this " + "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0") + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) minDF = Param( Params._dummy(), "minDF", "Specifies the minimum number of" + " different documents a term must appear in to be included in the vocabulary." + " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0") + " Default 1.0", typeConverter=TypeConverters.toFloat) vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + Params._dummy(), "vocabSize", "max size of the vocabulary. 
Default 1 << 18.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): @@ -375,7 +379,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit """ inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + - "default False.") + "default False.", typeConverter=TypeConverters.toBoolean) @keyword_only def __init__(self, inverse=False, inputCol=None, outputCol=None): @@ -441,8 +445,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada .. versionadded:: 1.5.0 """ - scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + - "it must be MLlib Vector type.") + scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.", + typeConverter=TypeConverters.toVector) @keyword_only def __init__(self, scalingVec=None, inputCol=None, outputCol=None): @@ -564,7 +568,8 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ minDocFreq = Param(Params._dummy(), "minDocFreq", - "minimum of documents in which a term should appear for filtering") + "minimum of documents in which a term should appear for filtering", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): @@ -746,8 +751,10 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav .. versionadded:: 1.6.0 """ - min = Param(Params._dummy(), "min", "Lower bound of the output feature range") - max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + min = Param(Params._dummy(), "min", "Lower bound of the output feature range", + typeConverter=TypeConverters.toFloat) + max = Param(Params._dummy(), "max", "Upper bound of the output feature range", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): @@ -870,7 +877,8 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr .. versionadded:: 1.5.0 """ - n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, n=2, inputCol=None, outputCol=None): @@ -936,7 +944,8 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav .. versionadded:: 1.4.0 """ - p = Param(Params._dummy(), "p", "the p norm value.") + p = Param(Params._dummy(), "p", "the p norm value.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, p=2.0, inputCol=None, outputCol=None): @@ -1018,7 +1027,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. versionadded:: 1.4.0 """ - dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) @keyword_only def __init__(self, dropLast=True, inputCol=None, outputCol=None): @@ -1085,7 +1095,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead .. 
versionadded:: 1.4.0 """ - degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") + degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, degree=2, inputCol=None, outputCol=None): @@ -1163,7 +1174,8 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, Jav # a placeholder to make it appear in the generated doc numBuckets = Param(Params._dummy(), "numBuckets", "Maximum number of buckets (quantiles, or " + - "categories) into which data points are grouped. Must be >= 2. Default 2.") + "categories) into which data points are grouped. Must be >= 2. Default 2.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): @@ -1255,11 +1267,13 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. versionadded:: 1.4.0 """ - minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") + minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)", + typeConverter=TypeConverters.toInt) gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") - pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") + pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing", + TypeConverters.toString) toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " + - "lowercase before tokenizing") + "lowercase before tokenizing", TypeConverters.toBoolean) @keyword_only def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, @@ -1370,7 +1384,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): .. versionadded:: 1.6.0 """ - statement = Param(Params._dummy(), "statement", "SQL statement") + statement = Param(Params._dummy(), "statement", "SQL statement", TypeConverters.toString) @keyword_only def __init__(self, statement=None): @@ -1444,8 +1458,9 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J .. versionadded:: 1.4.0 """ - withMean = Param(Params._dummy(), "withMean", "Center data with mean") - withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") + withMean = Param(Params._dummy(), "withMean", "Center data with mean", TypeConverters.toBoolean) + withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation", + TypeConverters.toBoolean) @keyword_only def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): @@ -1628,7 +1643,8 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, labels = Param(Params._dummy(), "labels", "Optional array of labels specifying index-string mapping." + - " If not provided or if empty, then metadata from inputCol is used instead.") + " If not provided or if empty, then metadata from inputCol is used instead.", + typeConverter=TypeConverters.toListString) @keyword_only def __init__(self, inputCol=None, outputCol=None, labels=None): @@ -1689,9 +1705,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl .. 
versionadded:: 1.6.0 """ - stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out", + typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + - "comparison over the stop words") + "comparison over the stop words", TypeConverters.toBoolean) @keyword_only def __init__(self, inputCol=None, outputCol=None, stopWords=None, @@ -1930,7 +1947,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + "(>= 2). If a feature is found to have > maxCategories values, then " + - "it is declared continuous.") + "it is declared continuous.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, maxCategories=20, inputCol=None, outputCol=None): @@ -2035,11 +2052,12 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J """ indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + - "a vector column. There can be no overlap with names.") + "a vector column. There can be no overlap with names.", + typeConverter=TypeConverters.toListInt) names = Param(Params._dummy(), "names", "An array of feature names to select features from " + "a vector column. These names must be specified by ML " + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " + - "indices.") + "indices.", typeConverter=TypeConverters.toListString) @keyword_only def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): @@ -2147,12 +2165,14 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ vectorSize = Param(Params._dummy(), "vectorSize", - "the dimension of codes after transforming from words") + "the dimension of codes after transforming from words", + typeConverter=TypeConverters.toInt) numPartitions = Param(Params._dummy(), "numPartitions", - "number of partitions for sentences of words") + "number of partitions for sentences of words", + typeConverter=TypeConverters.toInt) minCount = Param(Params._dummy(), "minCount", "the minimum number of times a token must appear to be included in the " + - "word2vec model's vocabulary") + "word2vec model's vocabulary", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, @@ -2293,7 +2313,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab .. versionadded:: 1.5.0 """ - k = Param(Params._dummy(), "k", "the number of principal components") + k = Param(Params._dummy(), "k", "the number of principal components", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, k=None, inputCol=None, outputCol=None): @@ -2425,7 +2446,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM .. versionadded:: 1.5.0 """ - formula = Param(Params._dummy(), "formula", "R model formula") + formula = Param(Params._dummy(), "formula", "R model formula", TypeConverters.toString) @keyword_only def __init__(self, formula=None, featuresCol="features", labelCol="label"): @@ -2511,12 +2532,11 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja .. 
versionadded:: 2.0.0 """ - # a placeholder to make it appear in the generated doc numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", "Number of features that selector will select, ordered by statistics value " + "descending. If the number of features is < numTopFeatures, then this will select " + - "all features.") + "all features.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): @@ -2525,11 +2545,6 @@ def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, la """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self.numTopFeatures = \ - Param(self, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value " + - "descending. If the number of features is < numTopFeatures, then this will " + - "select all features.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index c0f0a71eb6f67..a1265294a1e9e 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -14,31 +14,47 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import array +import sys +if sys.version > '3': + basestring = str + xrange = range + unicode = str from abc import ABCMeta import copy +import numpy as np +import warnings from pyspark import since from pyspark.ml.util import Identifiable +from pyspark.mllib.linalg import DenseVector, Vector -__all__ = ['Param', 'Params'] +__all__ = ['Param', 'Params', 'TypeConverters'] class Param(object): """ A param with self-contained documentation. + Note: `expectedType` is deprecated and will be removed in 2.1. Use typeConverter instead, + as a keyword argument. + .. versionadded:: 1.3.0 """ - def __init__(self, parent, name, doc, expectedType=None): + def __init__(self, parent, name, doc, expectedType=None, typeConverter=None): if not isinstance(parent, Identifiable): raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) self.parent = parent.uid self.name = str(name) self.doc = str(doc) self.expectedType = expectedType + if expectedType is not None: + warnings.warn("expectedType is deprecated and will be removed in 2.1. " + + "Use typeConverter instead, as a keyword argument.") + self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter def _copy_new_parent(self, parent): """Copy the current param to a new parent, must be a dummy param.""" @@ -65,6 +81,146 @@ def __eq__(self, other): return False +class TypeConverters(object): + """ + .. note:: DeveloperApi + + Factory methods for common type conversion functions for `Param.typeConverter`. + + .. 
versionadded:: 2.0.0 + """ + + @staticmethod + def _is_numeric(value): + vtype = type(value) + return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long' + + @staticmethod + def _is_integer(value): + return TypeConverters._is_numeric(value) and float(value).is_integer() + + @staticmethod + def _can_convert_to_list(value): + vtype = type(value) + return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector) + + @staticmethod + def _can_convert_to_string(value): + vtype = type(value) + return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_] + + @staticmethod + def identity(value): + """ + Dummy converter that just returns value. + """ + return value + + @staticmethod + def toList(value): + """ + Convert a value to a list, if possible. + """ + if type(value) == list: + return value + elif type(value) in [np.ndarray, tuple, xrange, array.array]: + return list(value) + elif isinstance(value, Vector): + return list(value.toArray()) + else: + raise TypeError("Could not convert %s to list" % value) + + @staticmethod + def toListFloat(value): + """ + Convert a value to list of floats, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_numeric(v), value)): + return [float(v) for v in value] + raise TypeError("Could not convert %s to list of floats" % value) + + @staticmethod + def toListInt(value): + """ + Convert a value to list of ints, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_integer(v), value)): + return [int(v) for v in value] + raise TypeError("Could not convert %s to list of ints" % value) + + @staticmethod + def toListString(value): + """ + Convert a value to list of strings, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)): + return [TypeConverters.toString(v) for v in value] + raise TypeError("Could not convert %s to list of strings" % value) + + @staticmethod + def toVector(value): + """ + Convert a value to a MLlib Vector, if possible. + """ + if isinstance(value, Vector): + return value + elif TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_numeric(v), value)): + return DenseVector(value) + raise TypeError("Could not convert %s to vector" % value) + + @staticmethod + def toFloat(value): + """ + Convert a value to a float, if possible. + """ + if TypeConverters._is_numeric(value): + return float(value) + else: + raise TypeError("Could not convert %s to float" % value) + + @staticmethod + def toInt(value): + """ + Convert a value to an int, if possible. + """ + if TypeConverters._is_integer(value): + return int(value) + else: + raise TypeError("Could not convert %s to int" % value) + + @staticmethod + def toString(value): + """ + Convert a value to a string, if possible. + """ + if isinstance(value, basestring): + return value + elif type(value) in [np.string_, np.str_]: + return str(value) + elif type(value) == np.unicode_: + return unicode(value) + else: + raise TypeError("Could not convert %s to string type" % type(value)) + + @staticmethod + def toBoolean(value): + """ + Convert a value to a boolean, if possible. 
+ """ + if type(value) == bool: + return value + else: + raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value)) + + class Params(Identifiable): """ Components that take parameters. This also provides an internal @@ -275,23 +431,12 @@ def _set(self, **kwargs): """ for param, value in kwargs.items(): p = getattr(self, param) - if p.expectedType is None or type(value) == p.expectedType or value is None: - self._paramMap[getattr(self, param)] = value - else: + if value is not None: try: - # Try and do "safe" conversions that don't lose information - if p.expectedType == float: - self._paramMap[getattr(self, param)] = float(value) - # Python 3 unified long & int - elif p.expectedType == int and type(value).__name__ == 'long': - self._paramMap[getattr(self, param)] = value - else: - raise Exception( - "Provided type {0} incompatible with type {1} for param {2}" - .format(type(value), p.expectedType, p)) - except ValueError: - raise Exception(("Failed to convert {0} to type {1} for param {2}" - .format(type(value), p.expectedType, p))) + value = p.typeConverter(value) + except TypeError as e: + raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e)) + self._paramMap[p] = value return self def _setDefault(self, **kwargs): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 5e297b8214823..7dd2937db7b82 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -38,7 +38,7 @@ # python _shared_params_code_gen.py > shared.py -def _gen_param_header(name, doc, defaultValueStr, expectedType): +def _gen_param_header(name, doc, defaultValueStr, typeConverter): """ Generates the header part for shared variables @@ -50,7 +50,7 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType): Mixin for param $name: $doc """ - $name = Param(Params._dummy(), "$name", "$doc", $expectedType) + $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter) def __init__(self): super(Has$Name, self).__init__()''' @@ -60,15 +60,14 @@ def __init__(self): self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] - expectedTypeName = str(expectedType) - if expectedType is not None: - expectedTypeName = expectedType.__name__ + if typeConverter is None: + typeConverter = str(None) return template \ .replace("$name", name) \ .replace("$Name", Name) \ .replace("$doc", doc) \ .replace("$defaultValueStr", str(defaultValueStr)) \ - .replace("$expectedType", expectedTypeName) + .replace("$typeConverter", typeConverter) def _gen_param_code(name, doc, defaultValueStr): @@ -105,64 +104,73 @@ def get$Name(self): if __name__ == "__main__": print(header) print("\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n") - print("from pyspark.ml.param import Param, Params\n\n") + print("from pyspark.ml.param import *\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0).", None, int), - ("regParam", "regularization parameter (>= 0).", None, float), - ("featuresCol", "features column name.", "'features'", str), - ("labelCol", "label column name.", "'label'", str), - ("predictionCol", "prediction column name.", "'prediction'", str), + ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"), + ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"), + ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"), + ("labelCol", "label column name.", "'label'", "TypeConverters.toString"), + ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + - "should be treated as confidences, not precise probabilities.", "'probability'", str), + "should be treated as confidences, not precise probabilities.", "'probability'", + "TypeConverters.toString"), ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'", - str), - ("inputCol", "input column name.", None, str), - ("inputCols", "input column names.", None, None), - ("outputCol", "output column name.", "self.uid + '__output'", str), - ("numFeatures", "number of features.", None, int), + "TypeConverters.toString"), + ("inputCol", "input column name.", None, "TypeConverters.toString"), + ("inputCols", "input column names.", None, "TypeConverters.toListString"), + ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int), - ("seed", "random seed.", "hash(type(self).__name__)", int), - ("tol", "the convergence tolerance for iterative algorithms.", None, float), - ("stepSize", "Step size to be used for each iteration of optimization.", None, float), + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, + "TypeConverters.toInt"), + ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), + ("tol", "the convergence tolerance for iterative algorithms.", None, + "TypeConverters.toFloat"), + ("stepSize", "Step size to be used for each iteration of optimization.", None, + "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None, str), + "added later.", None, "TypeConverters.toBoolean"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float), - ("fitIntercept", "whether to fit an intercept term.", "True", bool), + "the penalty is an L2 penalty. 
For alpha = 1, it is an L1 penalty.", "0.0", + "TypeConverters.toFloat"), + ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"), ("standardization", "whether to standardize the training features before fitting the " + - "model.", "True", bool), + "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None, None), + "probability of that class and t is the class' threshold.", None, + "TypeConverters.toListFloat"), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None, str), + "all instance weights as 1.0.", None, "TypeConverters.toString"), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'", str)] + "default value is 'auto'.", "'auto'", "TypeConverters.toString")] code = [] - for name, doc, defaultValueStr, expectedType in shared: - param_code = _gen_param_header(name, doc, defaultValueStr, expectedType) + for name, doc, defaultValueStr, typeConverter in shared: + param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter) code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) decisionTreeParams = [ ("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " + - "depth 1 means 1 internal node + 2 leaf nodes."), + "depth 1 means 1 internal node + 2 leaf nodes.", "TypeConverters.toInt"), ("maxBins", "Max number of bins for" + " discretizing continuous features. Must be >=2 and >= number of categories for any" + - " categorical feature."), + " categorical feature.", "TypeConverters.toInt"), ("minInstancesPerNode", "Minimum number of instances each child must have after split. " + "If a split causes the left or right child to have fewer than minInstancesPerNode, the " + - "split will be discarded as invalid. Should be >= 1."), - ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."), - ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), + "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"), + ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.", + "TypeConverters.toFloat"), + ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", + "TypeConverters.toInt"), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + "Caching can speed up training of deeper trees. 
Users can set how often should the " + - "cache be checkpointed or disable it by setting checkpointInterval.")] + "cache be checkpointed or disable it by setting checkpointInterval.", + "TypeConverters.toBoolean")] decisionTreeCode = '''class DecisionTreeParams(Params): """ @@ -175,9 +183,12 @@ def __init__(self): super(DecisionTreeParams, self).__init__()''' dtParamMethods = "" dummyPlaceholders = "" - paramTemplate = """$name = Param($owner, "$name", "$doc")""" - for name, doc in decisionTreeParams: - variable = paramTemplate.replace("$name", name).replace("$doc", doc) + paramTemplate = """$name = Param($owner, "$name", "$doc", typeConverter=$typeConverterStr)""" + for name, doc, typeConverterStr in decisionTreeParams: + if typeConverterStr is None: + typeConverterStr = str(None) + variable = paramTemplate.replace("$name", name).replace("$doc", doc) \ + .replace("$typeConverterStr", typeConverterStr) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index db4a8a54d4956..83fbd5903963c 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -17,7 +17,7 @@ # DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. -from pyspark.ml.param import Param, Params +from pyspark.ml.param import * class HasMaxIter(Params): @@ -25,7 +25,7 @@ class HasMaxIter(Params): Mixin for param maxIter: max number of iterations (>= 0). """ - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", typeConverter=TypeConverters.toInt) def __init__(self): super(HasMaxIter, self).__init__() @@ -49,7 +49,7 @@ class HasRegParam(Params): Mixin for param regParam: regularization parameter (>= 0). """ - regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float) + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasRegParam, self).__init__() @@ -73,7 +73,7 @@ class HasFeaturesCol(Params): Mixin for param featuresCol: features column name. """ - featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str) + featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasFeaturesCol, self).__init__() @@ -98,7 +98,7 @@ class HasLabelCol(Params): Mixin for param labelCol: label column name. """ - labelCol = Param(Params._dummy(), "labelCol", "label column name.", str) + labelCol = Param(Params._dummy(), "labelCol", "label column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasLabelCol, self).__init__() @@ -123,7 +123,7 @@ class HasPredictionCol(Params): Mixin for param predictionCol: prediction column name. """ - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str) + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasPredictionCol, self).__init__() @@ -148,7 +148,7 @@ class HasProbabilityCol(Params): Mixin for param probabilityCol: Column name for predicted class conditional probabilities. 
Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. """ - probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", typeConverter=TypeConverters.toString) def __init__(self): super(HasProbabilityCol, self).__init__() @@ -173,7 +173,7 @@ class HasRawPredictionCol(Params): Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. """ - rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasRawPredictionCol, self).__init__() @@ -198,7 +198,7 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. """ - inputCol = Param(Params._dummy(), "inputCol", "input column name.", str) + inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasInputCol, self).__init__() @@ -222,7 +222,7 @@ class HasInputCols(Params): Mixin for param inputCols: input column names. """ - inputCols = Param(Params._dummy(), "inputCols", "input column names.", None) + inputCols = Param(Params._dummy(), "inputCols", "input column names.", typeConverter=TypeConverters.toListString) def __init__(self): super(HasInputCols, self).__init__() @@ -246,7 +246,7 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - outputCol = Param(Params._dummy(), "outputCol", "output column name.", str) + outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasOutputCol, self).__init__() @@ -271,7 +271,7 @@ class HasNumFeatures(Params): Mixin for param numFeatures: number of features. """ - numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasNumFeatures, self).__init__() @@ -295,7 +295,7 @@ class HasCheckpointInterval(Params): Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() @@ -319,7 +319,7 @@ class HasSeed(Params): Mixin for param seed: random seed. 
""" - seed = Param(Params._dummy(), "seed", "random seed.", int) + seed = Param(Params._dummy(), "seed", "random seed.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasSeed, self).__init__() @@ -344,7 +344,7 @@ class HasTol(Params): Mixin for param tol: the convergence tolerance for iterative algorithms. """ - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float) + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasTol, self).__init__() @@ -368,7 +368,7 @@ class HasStepSize(Params): Mixin for param stepSize: Step size to be used for each iteration of optimization. """ - stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float) + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasStepSize, self).__init__() @@ -392,7 +392,7 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasHandleInvalid, self).__init__() @@ -416,7 +416,7 @@ class HasElasticNetParam(Params): Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. """ - elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasElasticNetParam, self).__init__() @@ -441,7 +441,7 @@ class HasFitIntercept(Params): Mixin for param fitIntercept: whether to fit an intercept term. """ - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool) + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasFitIntercept, self).__init__() @@ -466,7 +466,7 @@ class HasStandardization(Params): Mixin for param standardization: whether to standardize the training features before fitting the model. 
""" - standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool) + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasStandardization, self).__init__() @@ -491,7 +491,7 @@ class HasThresholds(Params): Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. """ - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat) def __init__(self): super(HasThresholds, self).__init__() @@ -515,7 +515,7 @@ class HasWeightCol(Params): Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. """ - weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) + weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", typeConverter=TypeConverters.toString) def __init__(self): super(HasWeightCol, self).__init__() @@ -539,7 +539,7 @@ class HasSolver(Params): Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. """ - solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", typeConverter=TypeConverters.toString) def __init__(self): super(HasSolver, self).__init__() @@ -564,12 +564,12 @@ class DecisionTreeParams(Params): Mixin for Decision Tree parameters. """ - maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") - maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") - minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. 
Should be >= 1.") - minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") - maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") + maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", typeConverter=TypeConverters.toInt) + maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.", typeConverter=TypeConverters.toInt) + minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.", typeConverter=TypeConverters.toInt) + minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.", typeConverter=TypeConverters.toFloat) + maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", typeConverter=TypeConverters.toInt) + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.", typeConverter=TypeConverters.toBoolean) def __init__(self): diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index de4c2675ed793..7c7a1b67a100e 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -100,16 +100,23 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha .. 
versionadded:: 1.4.0 """ - rank = Param(Params._dummy(), "rank", "rank of the factorization") - numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") - numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks") - implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference") - alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference") - userCol = Param(Params._dummy(), "userCol", "column name for user ids") - itemCol = Param(Params._dummy(), "itemCol", "column name for item ids") - ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings") + rank = Param(Params._dummy(), "rank", "rank of the factorization", + typeConverter=TypeConverters.toInt) + numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks", + typeConverter=TypeConverters.toInt) + numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks", + typeConverter=TypeConverters.toInt) + implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference", + TypeConverters.toBoolean) + alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", + typeConverter=TypeConverters.toFloat) + userCol = Param(Params._dummy(), "userCol", "column name for user ids", TypeConverters.toString) + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids", TypeConverters.toString) + ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", + TypeConverters.toString) nonnegative = Param(Params._dummy(), "nonnegative", - "whether to use nonnegative constraint for least squares") + "whether to use nonnegative constraint for least squares", + TypeConverters.toBoolean) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 664a44bc473ac..898260879d5b3 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -189,10 +189,11 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti isotonic = \ Param(Params._dummy(), "isotonic", "whether the output sequence should be isotonic/increasing (true) or" + - "antitonic/decreasing (false).") + "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean) featureIndex = \ Param(Params._dummy(), "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect otherwise.") + "The index of the feature if featuresCol is a vector column, no effect otherwise.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -278,7 +279,8 @@ class TreeEnsembleParams(DecisionTreeParams): """ subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + - "used for learning each decision tree, in range (0, 1].") + "used for learning each decision tree, in range (0, 1].", + typeConverter=TypeConverters.toFloat) def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -335,11 +337,13 @@ class RandomForestParams(TreeEnsembleParams): """ supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", + typeConverter=TypeConverters.toInt) featureSubsetStrategy = \ Param(Params._dummy(), 
"featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies)) + "options: " + ", ".join(supportedFeatureSubsetStrategies), + typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -653,7 +657,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + "Supported options: " + ", ".join(GBTParams.supportedLossTypes), + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -767,14 +772,16 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi censorCol = Param(Params._dummy(), "censorCol", "censor column name. The value of this column could be 0 or 1. " + "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.") + "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) quantileProbabilities = \ Param(Params._dummy(), "quantileProbabilities", "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.") + "should be in the range (0, 1) and the array should be non-empty.", + typeConverter=TypeConverters.toListFloat) quantilesCol = Param(Params._dummy(), "quantilesCol", "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.") + "corresponding quantileProbabilities if it is set.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 211248e8b2a23..2fa5da7738c1b 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -18,8 +18,11 @@ """ Unit tests for Spark ML Python APIs. """ - +import array import sys +if sys.version > '3': + xrange = range + try: import xmlrunner except ImportError: @@ -36,19 +39,20 @@ from shutil import rmtree import tempfile +import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.feature import * -from pyspark.ml.param import Param, Params +from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.linalg import DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -104,20 +108,65 @@ class ParamTypeConversionTests(PySparkTestCase): Test that param type conversion happens. 
""" - def test_int_to_float(self): - from pyspark.mllib.linalg import Vectors - df = self.sc.parallelize([ - Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF() - lr = LogisticRegression(elasticNetParam=0) - lr.fit(df) - lr.setElasticNetParam(0) - lr.fit(df) - - def test_invalid_to_float(self): - from pyspark.mllib.linalg import Vectors - self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy")) - lr = LogisticRegression(elasticNetParam=0) - self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda")) + def test_int(self): + lr = LogisticRegression(maxIter=5.0) + self.assertEqual(lr.getMaxIter(), 5) + self.assertTrue(type(lr.getMaxIter()) == int) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) + + def test_float(self): + lr = LogisticRegression(tol=1) + self.assertEqual(lr.getTol(), 1.0) + self.assertTrue(type(lr.getTol()) == float) + self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) + + def test_vector(self): + ewp = ElementwiseProduct(scalingVec=[1, 3]) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) + ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) + self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) + + def test_list(self): + l = [0, 1] + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), + array.array('l', l), xrange(2), tuple(l)]: + converted = TypeConverters.toList(lst_like) + self.assertEqual(type(converted), list) + self.assertListEqual(converted, l) + + def test_list_int(self): + for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), + SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + array.array('d', [1.0, 2.0])]: + vs = VectorSlicer(indices=indices) + self.assertListEqual(vs.getIndices(), [1, 2]) + self.assertTrue(all([type(v) == int for v in vs.getIndices()])) + self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) + + def test_list_float(self): + b = Bucketizer(splits=[1, 4]) + self.assertEqual(b.getSplits(), [1.0, 4.0]) + self.assertTrue(all([type(v) == float for v in b.getSplits()])) + self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) + + def test_list_string(self): + for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: + idx_to_string = IndexToString(labels=labels) + self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) + self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) + + def test_string(self): + lr = LogisticRegression() + for col in ['features', u'features', np.str_('features')]: + lr.setFeaturesCol(col) + self.assertEqual(lr.getFeaturesCol(), 'features') + self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) + + def test_bool(self): + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) class PipelineTests(PySparkTestCase): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 77af0094dfca4..a528d22e18ec2 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,7 +20,7 @@ from pyspark import since from pyspark.ml import Estimator, Model -from pyspark.ml.param import Params, Param +from pyspark.ml.param import Params, Param, TypeConverters from 
pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -121,7 +121,8 @@ class CrossValidator(Estimator, HasSeed): evaluator = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") - numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, From 02d9c352c72a16725322678ef174c5c6e9f2c617 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 23 Mar 2016 11:58:43 -0700 Subject: [PATCH 21/26] [SPARK-14092] [SQL] move shouldStop() to end of while loop ## What changes were proposed in this pull request? This PR rolls back some changes in #11274, which introduced a performance regression when doing a simple aggregation on a parquet scan with one integer column. It is not entirely clear how this change introduced such a large impact; it may be related to how the JIT compiler inlines functions (profiling showed very different stats). ## How was this patch tested? Manually ran the parquet reader benchmark. Before this change: ``` Intel(R) Core(TM) i7-4558U CPU 2.80GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 2391 / 3107 43.9 22.8 1.0X ``` After this change: ``` Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU 2.80GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 2032 / 2626 51.6 19.4 1.0X``` Author: Davies Liu Closes #11912 from davies/fix_regression.
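For illustration only, here is a minimal, self-contained sketch (hypothetical names; not the actual generated code) of the loop shape before and after this change: the `shouldStop()` check moves from the loop condition to the end of the loop body, so the hot loop keeps a trivial condition.

```java
// Sketch with hypothetical names (LoopShapeSketch, scanBefore, scanAfter);
// it only contrasts the two loop shapes discussed above.
public class LoopShapeSketch {

  // Stand-in for BufferedRowIterator.shouldStop(): cooperative cancellation signal.
  private static boolean shouldStop() {
    return false;
  }

  // Stand-in for the generated consume(...) call that processes one row.
  private static void consume(int rowId) {
    // no-op in this sketch
  }

  // Before: the cancellation check is folded into the loop condition.
  static void scanBefore(int numRows) {
    int idx = 0;
    while (!shouldStop() && idx < numRows) {
      int rowId = idx++;
      consume(rowId);
    }
  }

  // After: the condition stays simple; the check runs once at the end of the body.
  static void scanAfter(int numRows) {
    int idx = 0;
    while (idx < numRows) {
      int rowId = idx++;
      consume(rowId);
      if (shouldStop()) return;
    }
  }

  public static void main(String[] args) {
    scanBefore(1000);
    scanAfter(1000);
  }
}
```

In the patch below, the same restructuring is applied to the generated loops in `DataSourceScan`, `InputAdapter`, and `Range`.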
--- .../org/apache/spark/sql/execution/ExistingRDD.scala | 8 +++++--- .../apache/spark/sql/execution/WholeStageCodegen.scala | 8 +++++--- .../org/apache/spark/sql/execution/basicOperators.scala | 3 ++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index b4348d39c2b4b..3e2c7997626f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -255,11 +255,12 @@ private[sql] case class DataSourceScan( | $numOutputRows.add(numRows); | } | - | while (!shouldStop() && $idx < numRows) { + | // this loop is very perf sensitive and changes to it should be measured carefully + | while ($idx < numRows) { | int $rowidx = $idx++; | ${consume(ctx, columns1).trim} + | if (shouldStop()) return; | } - | if (shouldStop()) return; | | if (!$input.hasNext()) { | $batch = null; @@ -280,7 +281,7 @@ private[sql] case class DataSourceScan( s""" | private void $scanRows(InternalRow $row) throws java.io.IOException { | boolean firstRow = true; - | while (!shouldStop() && (firstRow || $input.hasNext())) { + | while (firstRow || $input.hasNext()) { | if (firstRow) { | firstRow = false; | } else { @@ -288,6 +289,7 @@ private[sql] case class DataSourceScan( | } | $numOutputRows.add(1); | ${consume(ctx, columns2, inputRow).trim} + | if (shouldStop()) return; | } | }""".stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 5634e5fc5861b..0be0b8032a855 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -103,11 +103,12 @@ trait CodegenSupport extends SparkPlan { * # call child.produce() * initialized = true; * } - * while (!shouldStop() && hashmap.hasNext()) { + * while (hashmap.hasNext()) { * row = hashmap.next(); * # build the aggregation results * # create variables for results * # call consume(), which will call parent.doConsume() + * if (shouldStop()) return; * } */ protected def doProduce(ctx: CodegenContext): String @@ -251,9 +252,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) s""" - | while (!shouldStop() && $input.hasNext()) { + | while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | ${consume(ctx, columns, row).trim} + | if (shouldStop()) return; | } """.stripMargin } @@ -320,7 +322,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** Codegened pipeline for: * ${toCommentSafeString(child.treeString.trim)} */ - class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; ${ctx.declareMutableStates()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6e2a5aa4f97c7..ee3f1d70e1300 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -282,13 +282,14 
@@ case class Range( | } | } | - | while (!$overflow && $checkEnd && !shouldStop()) { + | while (!$overflow && $checkEnd) { | long $value = $number; | $number += ${step}L; | if ($number < $value ^ ${step}L < 0) { | $overflow = true; | } | ${consume(ctx, Seq(ev))} + | if (shouldStop()) return; | } """.stripMargin } From 0a64294fcb4b64bfe095c63c3a494e0f40e22743 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 23 Mar 2016 12:13:32 -0700 Subject: [PATCH 22/26] [SPARK-14015][SQL] Support TimestampType in vectorized parquet reader ## What changes were proposed in this pull request? This PR adds support for TimestampType in the vectorized parquet reader ## How was this patch tested? 1. `VectorizedColumnReader` initially had a gating condition on `primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96)` that made us fall back on parquet-mr for handling timestamps. This condition is now removed. 2. The `ParquetHadoopFsRelationSuite` (that tests for all supported hive types -- including `TimestampType`) fails when the gating condition is removed (https://github.com/apache/spark/pull/11808) and should now pass with this change. Similarly, the `ParquetHiveCompatibilitySuite.SPARK-10177 timestamp` test that fails when the gating condition is removed, should now pass as well. 3. Added tests in `HadoopFsRelationTest` that test both the dictionary encoded and non-encoded versions across all supported datatypes. Author: Sameer Agarwal Closes #11882 from sameeragarwal/timestamp-parquet. --- .../parquet/VectorizedColumnReader.java | 29 ++++++- .../VectorizedParquetRecordReader.java | 13 --- .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 3 +- .../parquet/CatalystRowConverter.scala | 10 +++ .../sql/sources/hadoopFsRelationSuites.scala | 82 +++++++++++-------- 6 files changed, 86 insertions(+), 53 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 2c23ccc357a0b..6cc2fda5871dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -213,6 +213,9 @@ void readBatch(int total, ColumnVector column) throws IOException { case INT64: readLongBatch(rowId, num, column); break; + case INT96: + readBinaryBatch(rowId, num, column); + break; case FLOAT: readFloatBatch(rowId, num, column); break; @@ -249,7 +252,17 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, case BINARY: column.setDictionary(dictionary); break; - + case INT96: + if (column.dataType() == DataTypes.TimestampType) { + for (int i = rowId; i < rowId + num; ++i) { + // TODO: Convert dictionary of Binaries to dictionary of Longs + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, CatalystRowConverter.binaryToSQLTimestamp(v)); + } + } else { + throw new NotImplementedException(); + } + break; case FIXED_LEN_BYTE_ARRAY: // DecimalType written in the legacy mode if (DecimalType.is32BitDecimalType(column.dataType())) { @@ -342,9 +355,19 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
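The TimestampType branches added in this reader all lean on the same 12-byte INT96 layout: a little-endian 8-byte nanoseconds-of-day followed by a little-endian 4-byte Julian day, converted to a microsecond timestamp. A self-contained sketch of that conversion is below; the epoch constant and the microsecond arithmetic are my restatement of what `DateTimeUtils.fromJulianDay` is assumed to compute, while the patch itself simply delegates to `CatalystRowConverter.binaryToSQLTimestamp`.

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Standalone restatement of the INT96 decoding used by the vectorized reader:
// 12 bytes = little-endian 8-byte nanos-of-day, then little-endian 4-byte Julian day.
object Int96TimestampSketch {
  private val JulianDayOfEpoch = 2440588L              // assumed Julian day of 1970-01-01
  private val MicrosPerDay = 24L * 60 * 60 * 1000 * 1000

  def int96ToMicros(bytes: Array[Byte]): Long = {
    require(bytes.length == 12, s"expected 12 bytes, got ${bytes.length}")
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val timeOfDayNanos = buf.getLong                    // first 8 bytes
    val julianDay = buf.getInt                          // last 4 bytes
    (julianDay - JulianDayOfEpoch) * MicrosPerDay + timeOfDayNanos / 1000
  }

  def main(args: Array[String]): Unit = {
    // 1970-01-01T00:00:01 encoded as INT96: 1e9 nanos of day on Julian day 2440588
    val buf = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
    buf.putLong(1000000000L).putInt(2440588)
    println(int96ToMicros(buf.array()))                 // prints 1000000 (microseconds)
  }
}
```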
// TODO: implement remaining type conversions + VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; if (column.isArray()) { - defColumn.readBinarys( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + defColumn.readBinarys(num, column, rowId, maxDefLevel, data); + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, + // Read 12 bytes for INT96 + CatalystRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + } else { + column.putNull(rowId + i); + } + } } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 9ac251391b61a..ab09208d5a0b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -252,26 +252,13 @@ private void initializeInternal() throws IOException { /** * Check that the requested schema is supported. */ - OriginalType[] originalTypes = new OriginalType[requestedSchema.getFieldCount()]; missingColumns = new boolean[requestedSchema.getFieldCount()]; for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { Type t = requestedSchema.getFields().get(i); if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { throw new IOException("Complex types not supported."); } - PrimitiveType primitiveType = t.asPrimitiveType(); - originalTypes[i] = t.getOriginalType(); - - // TODO: Be extremely cautious in what is supported. Expand this. - if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && - originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE && - originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) { - throw new IOException("Unsupported type: " + t); - } - if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { - throw new IOException("Int96 not supported."); - } String[] colPath = requestedSchema.getPaths().get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 689e6a2a6d82f..b1901411351a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -407,7 +407,7 @@ private void reserveInternal(int newCapacity) { type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || - DecimalType.is64BitDecimalType(type)) { + DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index f332e87016692..b1429fe7cb5ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -403,7 +403,8 @@ private void reserveInternal(int newCapacity) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; - } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { + } else if (type instanceof LongType || type instanceof TimestampType || + DecimalType.is64BitDecimalType(type)) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); longData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index de6dd0fe3e6b5..6bf82bee67881 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -659,4 +660,13 @@ private[parquet] object CatalystRowConverter { unscaled = (unscaled << (64 - bits)) >> (64 - bits) unscaled } + + def binaryToSQLTimestamp(binary: Binary): SQLTimestamp = { + assert(binary.length() == 12, s"Timestamps (with nanoseconds) are expected to be stored in" + + s" 12-byte long binaries. 
Found a ${binary.length()}-byte binary instead.") + val buffer = binary.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buffer.getLong + val julianDay = buffer.getInt + DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 7e5506ee4af65..e842caf5bec13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -116,44 +116,56 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes new MyDenseVectorUDT() ).filter(supportsDataType) - for (dataType <- supportedDataTypes) { - test(s"test all data types - $dataType") { - withTempPath { file => - val path = file.getCanonicalPath - - val dataGenerator = RandomDataGenerator.forType( - dataType = dataType, - nullable = true, - new Random(System.nanoTime()) - ).getOrElse { - fail(s"Failed to create data generator for schema $dataType") + try { + for (dataType <- supportedDataTypes) { + for (parquetDictionaryEncodingEnabled <- Seq(true, false)) { + test(s"test all data types - $dataType with parquet.enable.dictionary = " + + s"$parquetDictionaryEncodingEnabled") { + + hadoopConfiguration.setBoolean("parquet.enable.dictionary", + parquetDictionaryEncodingEnabled) + + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + new Random(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = + sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) + + val loadedDF = sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } } - - // Create a DF for the schema with random data. The index field is used to sort the - // DataFrame. This is a workaround for SPARK-10591. 
- val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", dataType, nullable = true) - val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - - df.write - .mode("overwrite") - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .save(path) - - val loadedDF = sqlContext - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .schema(df.schema) - .load(path) - .orderBy("index") - - checkAnswer(loadedDF, df) } } + } finally { + hadoopConfiguration.unset("parquet.enable.dictionary") } test("save()/load() - non-partitioned table - Overwrite") { From 8c826880f5eaa3221c4e9e7d3fece54e821a0b98 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Mar 2016 12:48:05 -0700 Subject: [PATCH 23/26] [SPARK-13809][SQL] State store for streaming aggregations ## What changes were proposed in this pull request? In this PR, I am implementing a new abstraction for management of streaming state data - State Store. It is a key-value store for persisting running aggregates for aggregate operations in streaming dataframes. The motivation and design is discussed here. https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit# ## How was this patch tested? - [x] Unit tests - [x] Cluster tests **Coverage from unit tests** screen shot 2016-03-21 at 3 09 40 pm ## TODO - [x] Fix updates() iterator to avoid duplicate updates for same key - [x] Use Coordinator in ContinuousQueryManager - [x] Plugging in hadoop conf and other confs - [x] Unit tests - [x] StateStore object lifecycle and methods - [x] StateStoreCoordinator communication and logic - [x] StateStoreRDD fault-tolerance - [x] StateStoreRDD preferred location using StateStoreCoordinator - [ ] Cluster tests - [ ] Whether preferred locations are set correctly - [ ] Whether recovery works correctly with distributed storage - [x] Basic performance tests - [x] Docs Author: Tathagata Das Closes #11645 from tdas/state-store. 
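In caller terms, the store lifecycle this patch introduces (and spells out in the provider's scaladoc below) is: get a store at a given version, apply updates and removals, commit to produce the next version, then read back either the full contents or the net updates. A rough sketch against the new `StateStore` trait follows; `runBatch`, the `expired` predicate, and the pass-through update logic are illustrative placeholders, not code from this patch.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.{StateStore, StoreUpdate}

// Caller-side sketch of the per-batch lifecycle: update/remove, commit to get the
// next version, then read the update log (only valid after commit()).
object StateStoreUsageSketch {
  def runBatch(
      store: StateStore,                        // obtained via StateStore.get(storeId, ...)
      input: Iterator[(UnsafeRow, UnsafeRow)],  // (key, new value) pairs for this batch
      expired: UnsafeRow => Boolean): (Long, Iterator[StoreUpdate]) = {
    input.foreach { case (key, value) =>
      // update() is handed the previous value (if any) and returns the value to store
      store.update(key, _ => value)
    }
    store.remove(expired)             // drop keys that are no longer needed
    val newVersion = store.commit()   // makes this batch's state durable as `newVersion`
    (newVersion, store.updates())     // net changes made in this version
  }
}
```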
--- .../spark/sql/ContinuousQueryManager.scala | 3 + .../state/HDFSBackedStateStoreProvider.scala | 584 ++++++++++++++++++ .../streaming/state/StateStore.scala | 247 ++++++++ .../streaming/state/StateStoreConf.scala | 37 ++ .../state/StateStoreCoordinator.scala | 146 +++++ .../streaming/state/StateStoreRDD.scala | 70 +++ .../execution/streaming/state/package.scala | 75 +++ .../apache/spark/sql/internal/SQLConf.scala | 13 + .../state/StateStoreCoordinatorSuite.scala | 123 ++++ .../streaming/state/StateStoreRDDSuite.scala | 192 ++++++ .../streaming/state/StateStoreSuite.scala | 562 +++++++++++++++++ 11 files changed, 2052 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index fa8219bbed0d5..465feeb60412f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.util.ContinuousQueryListener /** @@ -33,6 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener @Experimental class ContinuousQueryManager(sqlContext: SQLContext) { + private[sql] val stateStoreCoordinator = + StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env) private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) private val activeQueries = new mutable.HashMap[String, ContinuousQuery] private val activeQueriesLock = new Object diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala new file mode 100644 index 0000000000000..ee015baf3fae7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -0,0 +1,584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.{DataInputStream, DataOutputStream, IOException} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random +import scala.util.control.NonFatal + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.io.LZ4CompressionCodec +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + + +/** + * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed + * by files in a HDFS-compatible file system. All updates to the store has to be done in sets + * transactionally, and each set of updates increments the store's version. These versions can + * be used to re-execute the updates (by retries in RDD operations) on the correct version of + * the store, and regenerate the store version. + * + * Usage: + * To update the data in the state store, the following order of operations are needed. + * + * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store + * - store.update(...) + * - store.remove(...) + * - store.commit() // commits all the updates to made with version number + * - store.iterator() // key-value data after last commit as an iterator + * - store.updates() // updates made in the last as an iterator + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Multiple attempts to commit the same version of updates may overwrite each other. + * Consistency guarantees depend on whether multiple attempts have the same updates and + * the overwrite semantics of underlying file system. + * - Background maintenance of files ensures that last versions of the store is always recoverable + * to ensure re-executed RDD operations re-apply updates on the correct past version of the + * store. 
+ */ +private[state] class HDFSBackedStateStoreProvider( + val id: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + hadoopConf: Configuration + ) extends StateStoreProvider with Logging { + + type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + + /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ + class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) + extends StateStore { + + /** Trait and classes representing the internal state of the store */ + trait STATE + case object UPDATING extends STATE + case object COMMITTED extends STATE + case object CANCELLED extends STATE + + private val newVersion = version + 1 + private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) + + private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() + + @volatile private var state: STATE = UPDATING + @volatile private var finalDeltaFile: Path = null + + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + + /** + * Update the value of a key using the value generated by the update function. + * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous + * versions of the store data. + */ + override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot update after already committed or cancelled") + val oldValueOption = Option(mapToUpdate.get(key)) + val value = updateFunc(oldValueOption) + mapToUpdate.put(key, value) + + Option(allUpdates.get(key)) match { + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added already, keep it marked as added + allUpdates.put(key, ValueAdded(key, value)) + case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + // Value existed in prev version and updated/removed, mark it as updated + allUpdates.put(key, ValueUpdated(key, value)) + case None => + // There was no prior update, so mark this as added or updated according to its presence + // in previous version. + val update = + if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + allUpdates.put(key, update) + } + writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + } + + /** Remove keys that match the following condition */ + override def remove(condition: UnsafeRow => Boolean): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + val keyIter = mapToUpdate.keySet().iterator() + while (keyIter.hasNext) { + val key = keyIter.next + if (condition(key)) { + keyIter.remove() + + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, KeyRemoved(key)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(KeyRemoved(_)) => + // Remove already in update map, no need to change + } + writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + } + } + } + + /** Commit all the updates that have been made to the store, and return the new version. 
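The `allUpdates` bookkeeping in `update()`/`remove()` above reduces to a small state machine that classifies each touched key relative to the previous committed version, so `updates()` can report exactly one net change per key. A plain-Scala model of those transitions, with stand-in names rather than the provider's classes:

```scala
// Model of the net-change classification: a key added and then removed within the same
// version drops out of the update log entirely; otherwise the latest operation wins.
sealed trait NetChange
case class Added(key: String, value: String)   extends NetChange
case class Updated(key: String, value: String) extends NetChange
case class Removed(key: String)                extends NetChange

object UpdateTrackingSketch {
  /** existedBefore: whether the key was present in the previous committed version. */
  def afterPut(prev: Option[NetChange], key: String, value: String,
               existedBefore: Boolean): NetChange = prev match {
    case Some(Added(_, _))                      => Added(key, value)    // still new this version
    case Some(Updated(_, _)) | Some(Removed(_)) => Updated(key, value)  // existed previously
    case None => if (existedBefore) Updated(key, value) else Added(key, value)
  }

  /** None means the key disappears from the update log (added then removed). */
  def afterRemove(prev: Option[NetChange], key: String): Option[NetChange] = prev match {
    case Some(Added(_, _)) => None
    case Some(Removed(_))  => prev
    case _                 => Some(Removed(key))
  }
}
```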
*/ + override def commit(): Long = { + verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + + try { + finalizeDeltaFile(tempDeltaFileStream) + finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + state = COMMITTED + logInfo(s"Committed version $newVersion for $this") + newVersion + } catch { + case NonFatal(e) => + throw new IllegalStateException( + s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + } + } + + /** Cancel all the updates made on this store. This store will not be usable any more. */ + override def cancel(): Unit = { + state = CANCELLED + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { + fs.delete(tempDeltaFile, true) + } + logInfo("Canceled ") + } + + /** + * Get an iterator of all the store data. This can be called only after committing the + * updates. + */ + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + verify(state == COMMITTED, "Cannot get iterator of store data before comitting") + HDFSBackedStateStoreProvider.this.iterator(newVersion) + } + + /** + * Get an iterator of all the updates made to the store in the current version. + * This can be called only after committing the updates. + */ + override def updates(): Iterator[StoreUpdate] = { + verify(state == COMMITTED, "Cannot get iterator of updates before committing") + allUpdates.values().asScala.toIterator + } + + /** + * Whether all updates have been committed + */ + override def hasCommitted: Boolean = { + state == COMMITTED + } + } + + /** Get the state store for making updates to create a new `version` of the store. */ + override def getStore(version: Long): StateStore = synchronized { + require(version >= 0, "Version cannot be less than 0") + val newMap = new MapType() + if (version > 0) { + newMap.putAll(loadMap(version)) + } + val store = new HDFSBackedStateStore(version, newMap) + logInfo(s"Retrieved version $version of $this for update") + store + } + + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ + override def doMaintenance(): Unit = { + try { + doSnapshot() + cleanup() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $this") + } + } + + override def toString(): String = { + s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + } + + /* Internal classes and methods */ + + private val loadedMaps = new mutable.HashMap[Long, MapType] + private val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private val fs = baseDir.getFileSystem(hadoopConf) + private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + + initialize() + + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + /** Commit a set of updates to the store with the given new version */ + private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + synchronized { + val finalDeltaFile = deltaFile(newVersion) + fs.rename(tempDeltaFile, finalDeltaFile) + loadedMaps.put(newVersion, map) + finalDeltaFile + } + } + + /** + * Get iterator of all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. 
+ */ + private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } else Iterator.empty + } + + /** Get iterator of a specific version of the store */ + private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + loadMap(version).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } + + /** Initialize the store provider */ + private def initialize(): Unit = { + if (!fs.exists(baseDir)) { + fs.mkdirs(baseDir) + } else { + if (!fs.isDirectory(baseDir)) { + throw new IllegalStateException( + s"Cannot use ${id.checkpointLocation} for storing state data for $this as" + + s"$baseDir already exists and is not a directory") + } + } + } + + /** Load the required version of the map data from the backing files */ + private def loadMap(version: Long): MapType = { + if (version <= 0) return new MapType + synchronized { loadedMaps.get(version) }.getOrElse { + val mapFromFile = readSnapshotFile(version).getOrElse { + val prevMap = loadMap(version - 1) + val newMap = new MapType(prevMap) + newMap.putAll(prevMap) + updateFromDeltaFile(version, newMap) + newMap + } + loadedMaps.put(version, mapFromFile) + mapFromFile + } + } + + private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { + + def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + + def writeRemove(key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) + } + + update match { + case ValueAdded(key, value) => + writeUpdate(key, value) + case ValueUpdated(key, value) => + writeUpdate(key, value) + case KeyRemoved(key) => + writeRemove(key) + } + } + + private def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(-1) // Write this magic number to signify end of file + output.close() + } + + private def updateFromDeltaFile(version: Long, map: MapType): Unit = { + val fileToRead = deltaFile(version) + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist") + } + var input: DataInputStream = null + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while(!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + map.remove(keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + } finally { + if (input != null) 
input.close() + } + logInfo(s"Read delta file for version $version of $this from $fileToRead") + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + var output: DataOutputStream = null + Utils.tryWithSafeFinally { + output = compressStream(fs.create(fileToWrite, false)) + val iter = map.entrySet().iterator() + while(iter.hasNext) { + val entry = iter.next() + val keyBytes = entry.getKey.getBytes() + val valueBytes = entry.getValue.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + output.writeInt(-1) + } { + if (output != null) output.close() + } + logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val map = new MapType() + var input: DataInputStream = null + + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while (!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize") + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + logInfo(s"Read snapshot file for version $version of $this from $fileToRead") + Some(map) + } finally { + if (input != null) input.close() + } + } + + + /** Perform a snapshot of the store to allow delta files to be consolidated */ + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { loadedMaps.get(lastVersion) } match { + case Some(map) => + if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is incharge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this", e) + } + } + + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. 
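The file management around `doSnapshot()` and `cleanup()` implies a simple recovery rule: a version is rebuilt from the most recent snapshot at or below it plus every later delta up to that version. A small sketch of that selection logic; `StoreFileSketch` and `filesToReplay` are stand-ins for the provider's private `StoreFile` and `filesForVersion`, and the `$version.delta` / `$version.snapshot` naming follows the provider code.

```scala
// Which files are needed to reconstruct a given store version.
final case class StoreFileSketch(version: Long, name: String, isSnapshot: Boolean)

object RecoveryPlanSketch {
  def filesToReplay(allFiles: Seq[StoreFileSketch], version: Long): Seq[StoreFileSketch] = {
    val sorted = allFiles.sortBy(_.version)
    // latest snapshot at or below the requested version, if any
    val snapshot = sorted.filter(f => f.isSnapshot && f.version <= version).lastOption
    // every delta after that snapshot, up to and including the requested version
    val deltas = sorted.filter { f =>
      !f.isSnapshot && f.version <= version && snapshot.forall(f.version > _.version)
    }
    snapshot.toSeq ++ deltas
  }

  def main(args: Array[String]): Unit = {
    val files = Seq(
      StoreFileSketch(1, "1.delta", isSnapshot = false),
      StoreFileSketch(2, "2.delta", isSnapshot = false),
      StoreFileSketch(2, "2.snapshot", isSnapshot = true),
      StoreFileSketch(3, "3.delta", isSnapshot = false))
    // Rebuilding version 3 only needs 2.snapshot and 3.delta:
    println(filesToReplay(files, 3).map(_.name))
  }
}
```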
+ */ + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain + if (earliestVersionToRetain > 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq + mapsToRemove.foreach(loadedMaps.remove) + } + files.filter(_.version < earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this", e) + } + } + + /** Files needed to recover the given version of the store */ + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaBatchIds = (snapshotFile.version + 1) to version + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version for $this: $deltaFiles" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + /** Fetch all the files that back the store */ + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path for $this") + } + } + } + val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) + logDebug(s"Current set of files for $this: $storeFiles") + storeFiles + } + + private def compressStream(outputStream: DataOutputStream): DataOutputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream) + new DataOutputStream(compressed) + } + + private def decompressStream(inputStream: DataInputStream): DataInputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream) + new DataInputStream(compressed) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala new file mode 100644 index 0000000000000..ca5c864d9e993 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.Timer +import java.util.concurrent.{ScheduledFuture, TimeUnit} + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils + + +/** Unique identifier for a [[StateStore]] */ +case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) + + +/** + * Base trait for a versioned key-value store used for streaming aggregations + */ +trait StateStore { + + /** Unique identifier of the store */ + def id: StateStoreId + + /** Version of the data in this store before committing updates. */ + def version: Long + + /** + * Update the value of a key using the value generated by the update function. + * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous + * versions of the store data. + */ + def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + + /** + * Remove keys that match the following condition. + */ + def remove(condition: UnsafeRow => Boolean): Unit + + /** + * Commit all the updates that have been made to the store, and return the new version. + */ + def commit(): Long + + /** Cancel all the updates that have been made to the store. */ + def cancel(): Unit + + /** + * Iterator of store data after a set of updates have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ + def iterator(): Iterator[(UnsafeRow, UnsafeRow)] + + /** + * Iterator of the updates that have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ + def updates(): Iterator[StoreUpdate] + + /** + * Whether all updates have been committed + */ + def hasCommitted: Boolean +} + + +/** Trait representing a provider of a specific version of a [[StateStore]]. */ +trait StateStoreProvider { + + /** Get the store with the existing version. */ + def getStore(version: Long): StateStore + + /** Optional method for providers to allow for background maintenance */ + def doMaintenance(): Unit = { } +} + + +/** Trait representing updates made to a [[StateStore]]. 
*/ +sealed trait StoreUpdate + +case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + +case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + +case class KeyRemoved(key: UnsafeRow) extends StoreUpdate + + +/** + * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores + * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), + * it also runs a periodic background tasks to do maintenance on the loaded stores. For each + * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of + * the store is the active instance. Accordingly, it either keeps it loaded and performs + * maintenance, or unloads the store. + */ +private[state] object StateStore extends Logging { + + val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval" + val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 + + private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val maintenanceTaskExecutor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task") + + @volatile private var maintenanceTask: ScheduledFuture[_] = null + @volatile private var _coordRef: StateStoreCoordinatorRef = null + + /** Get or create a store associated with the id. */ + def get( + storeId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + version: Long, + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStore = { + require(version >= 0) + val storeProvider = loadedProviders.synchronized { + startMaintenanceIfNeeded() + val provider = loadedProviders.getOrElseUpdate( + storeId, + new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + reportActiveStoreInstance(storeId) + provider + } + storeProvider.getStore(version) + } + + /** Unload a state store provider */ + def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) + } + + /** Whether a state store provider is loaded or not */ + def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeId) + } + + /** Unload and stop all state store providers */ + def stop(): Unit = loadedProviders.synchronized { + loadedProviders.clear() + _coordRef = null + if (maintenanceTask != null) { + maintenanceTask.cancel(false) + maintenanceTask = null + } + logInfo("StateStore stopped") + } + + /** Start the periodic maintenance task if not already started and if Spark active */ + private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { + val env = SparkEnv.get + if (maintenanceTask == null && env != null) { + val periodMs = env.conf.getTimeAsMs( + MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") + val runnable = new Runnable { + override def run(): Unit = { doMaintenance() } + } + maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate( + runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) + logInfo("State Store maintenance task started") + } + } + + /** + * Execute background maintenance task in all the loaded store providers if they are still + * the active instances according to the coordinator. 
+ */ + private def doMaintenance(): Unit = { + logDebug("Doing maintenance") + loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => + try { + if (verifyIfStoreInstanceActive(id)) { + provider.doMaintenance() + } else { + unload(id) + logInfo(s"Unloaded $provider") + } + } catch { + case NonFatal(e) => + logWarning(s"Error managing $provider") + } + } + } + + private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + try { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) + logDebug(s"Reported that the loaded instance $storeId is active") + } catch { + case NonFatal(e) => + logWarning(s"Error reporting active instance of $storeId") + } + } + + private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + try { + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + val verified = + coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) + logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" ) + verified + } catch { + case NonFatal(e) => + logWarning(s"Error verifying active instance of $storeId") + false + } + } + + private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + if (_coordRef == null) { + _coordRef = StateStoreCoordinatorRef.forExecutor(env) + } + logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + Some(_coordRef) + } else { + _coordRef = null + None + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala new file mode 100644 index 0000000000000..cca22a0af823f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.internal.SQLConf + +/** A class that contains configuration parameters for [[StateStore]]s. 
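The two knobs read by `StateStoreConf` correspond to the SQLConf entries added later in this patch. A minimal sketch of how they might be tuned on a `SQLContext` before a streaming query runs; the values shown are arbitrary examples, not recommendations.

```scala
import org.apache.spark.sql.SQLContext

object StateStoreConfSketch {
  // Hypothetical tuning of the state-store settings; keys come from the SQLConf
  // entries introduced by this patch, values are arbitrary examples.
  def tuneStateStore(sqlContext: SQLContext): Unit = {
    // consolidate delta files into a snapshot once this many deltas have accumulated
    sqlContext.setConf("spark.sql.streaming.stateStore.minDeltasForSnapshot", "20")
    // keep at least this many versions of state recoverable after cleanup
    sqlContext.setConf("spark.sql.streaming.stateStore.minBatchesToRetain", "5")
  }
}
```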
*/ +private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { + + def this() = this(new SQLConf) + + import SQLConf._ + + val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + + val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) +} + +private[state] object StateStoreConf { + val empty = new StateStoreConf() +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala new file mode 100644 index 0000000000000..5aa0636850255 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.mutable + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.util.RpcUtils + +/** Trait representing all messages to [[StateStoreCoordinator]] */ +private sealed trait StateStoreCoordinatorMessage extends Serializable + +/** Classes representing messages */ +private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) + extends StateStoreCoordinatorMessage + +private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) + extends StateStoreCoordinatorMessage + +private case class GetLocation(storeId: StateStoreId) + extends StateStoreCoordinatorMessage + +private case class DeactivateInstances(storeRootLocation: String) + extends StateStoreCoordinatorMessage + +private object StopCoordinator + extends StateStoreCoordinatorMessage + +/** Helper object used to create reference to [[StateStoreCoordinator]]. */ +private[sql] object StateStoreCoordinatorRef extends Logging { + + private val endpointName = "StateStoreCoordinator" + + /** + * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as + * executors. 
+ */ + def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + try { + val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) + logInfo("Registered StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(coordinatorRef) + } catch { + case e: IllegalArgumentException => + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(rpcEndpointRef) + } + } + + def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(rpcEndpointRef) + } +} + +/** + * Reference to a [[StateStoreCoordinator]] that can be used to coordinator instances of + * [[StateStore]]s across all the executors, and get their locations for job scheduling. + */ +private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { + + private[state] def reportActiveInstance( + storeId: StateStoreId, + host: String, + executorId: String): Unit = { + rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + } + + /** Verify whether the given executor has the active instance of a state store */ + private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { + rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId)) + } + + /** Get the location of the state store */ + private[state] def getLocation(storeId: StateStoreId): Option[String] = { + rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId)) + } + + /** Deactivate instances related to a set of operator */ + private[state] def deactivateInstances(storeRootLocation: String): Unit = { + rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation)) + } + + private[state] def stop(): Unit = { + rpcEndpointRef.askWithRetry[Boolean](StopCoordinator) + } +} + + +/** + * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, + * and get their locations for job scheduling. 
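Behind the RPC plumbing, the coordinator endpoint below keeps one small map from store id to the executor that most recently reported it. A plain-Scala model of that bookkeeping (no RPC); `StoreId` and the method names here are stand-ins, not Spark classes.

```scala
// Last reporter wins: the most recent executor to report a store id is treated as its
// active instance; deactivation drops every store under a checkpoint location.
object CoordinatorBookkeepingSketch {
  final case class StoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)

  private val instances = scala.collection.mutable.HashMap.empty[StoreId, String] // executorId

  def reportActiveInstance(id: StoreId, executorId: String): Unit =
    instances.put(id, executorId)

  def verifyIfInstanceActive(id: StoreId, executorId: String): Boolean =
    instances.get(id).contains(executorId)

  def deactivateInstances(checkpointLocation: String): Unit =
    instances --= instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq

  def main(args: Array[String]): Unit = {
    val id = StoreId("/tmp/ckpt", operatorId = 0, partitionId = 0)
    reportActiveInstance(id, "exec1")
    reportActiveInstance(id, "exec2")             // exec2 takes over the store
    println(verifyIfInstanceActive(id, "exec1"))  // false
    println(verifyIfInstanceActive(id, "exec2"))  // true
    deactivateInstances("/tmp/ckpt")
    println(verifyIfInstanceActive(id, "exec2"))  // false
  }
}
```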
+ */ +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + + override def receive: PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + instances.put(id, ExecutorCacheTaskLocation(host, executorId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case VerifyIfInstanceActive(id, execId) => + val response = instances.get(id) match { + case Some(location) => location.executorId == execId + case None => false + } + context.reply(response) + + case GetLocation(id) => + context.reply(instances.get(id).map(_.toString)) + + case DeactivateInstances(loc) => + val storeIdsToRemove = + instances.keys.filter(_.checkpointLocation == loc).toSeq + instances --= storeIdsToRemove + context.reply(true) + + case StopCoordinator => + stop() // Stop before replying to ensure that endpoint name has been deregistered + context.reply(true) + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala new file mode 100644 index 0000000000000..3318660895195 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that allows computations to be executed against [[StateStore]]s. It + * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as + * preferred locations. 
+ */ +class StateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(dataRDD.context.hadoopConfiguration)) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + var store: StateStore = null + + Utils.tryWithSafeFinally { + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + val outputIter = storeUpdateFunction(store, inputIter) + assert(store.hasCommitted) + outputIter + } { + if (store != null) store.cancel() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala new file mode 100644 index 0000000000000..b249e37921f09 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.StructType + +package object state { + + implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + + /** Map each partition of a RDD along with data in a [[StateStore]]. 
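From the user side, the implicit added below wires an RDD partition through `StateStoreRDD`. A hedged sketch of a caller; the schemas, checkpoint directory, operator id, and the pass-through update logic are placeholders, not code from this patch. Note that the update function has to commit the store before returning, since `StateStoreRDD.compute` asserts `store.hasCommitted`.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType

// Caller-side sketch: load the store at `batchVersion`, fold this batch's rows into it,
// commit, and emit the full committed state for downstream operators.
object MapPartitionWithStateStoreSketch {
  def updateStateForBatch(
      rows: RDD[(UnsafeRow, UnsafeRow)],
      keySchema: StructType,
      valueSchema: StructType,
      checkpointDir: String,
      batchVersion: Long)(implicit sqlContext: SQLContext): RDD[(UnsafeRow, UnsafeRow)] = {
    rows.mapPartitionWithStateStore(
      storeUpdateFunction = (store: StateStore, iter: Iterator[(UnsafeRow, UnsafeRow)]) => {
        iter.foreach { case (key, value) => store.update(key, _ => value) }
        store.commit()    // produces version batchVersion + 1
        store.iterator()  // key-value state after this batch
      },
      checkpointLocation = checkpointDir,
      operatorId = 0L,
      storeVersion = batchVersion,
      keySchema = keySchema,
      valueSchema = valueSchema)
  }
}
```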
*/ + def mapPartitionWithStateStore[U: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType + )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + + mapPartitionWithStateStore( + storeUpdateFunction, + checkpointLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) + } + + /** Map each partition of a RDD along with data in a [[StateStore]]. */ + private[state] def mapPartitionWithStateStore[U: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + storeCoordinator: Option[StateStoreCoordinatorRef] + ): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new StateStoreRDD( + dataRDD, + cleanedF, + checkpointLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + storeConf, + storeCoordinator) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fd1d77f514a95..863a876afe9c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -524,6 +524,19 @@ object SQLConf { doc = "When true, the planner will try to find out duplicated exchanges and re-use them.", isPublic = false) + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf( + "spark.sql.streaming.stateStore.minDeltasForSnapshot", + defaultValue = Some(10), + doc = "Minimum number of state store delta files that needs to be generated before they " + + "consolidated into snapshots.", + isPublic = false) + + val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf( + "spark.sql.streaming.stateStore.minBatchesToRetain", + defaultValue = Some(2), + doc = "Minimum number of versions of a state store's data to retain after cleaning.", + isPublic = false) + val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation", defaultValue = None, doc = "The default location for storing checkpoint data for continuously executing queries.", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala new file mode 100644 index 0000000000000..c99c2f505f3e4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + import StateStoreCoordinatorSuite._ + + test("report, verify, getLocation") { + withCoordinatorRef(sc) { coordinatorRef => + val id = StateStoreId("x", 0, 0) + + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.getLocation(id) === None) + + coordinatorRef.reportActiveInstance(id, "hostX", "exec1") + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + + coordinatorRef.reportActiveInstance(id, "hostX", "exec2") + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + } + + test("make inactive") { + withCoordinatorRef(sc) { coordinatorRef => + val id1 = StateStoreId("x", 0, 0) + val id2 = StateStoreId("y", 1, 0) + val id3 = StateStoreId("x", 0, 1) + val host = "hostX" + val exec = "exec1" + + coordinatorRef.reportActiveInstance(id1, host, exec) + coordinatorRef.reportActiveInstance(id2, host, exec) + coordinatorRef.reportActiveInstance(id3, host, exec) + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + + } + + coordinatorRef.deactivateInstances("x") + + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinatorRef.getLocation(id1) === None) + assert( + coordinatorRef.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinatorRef.getLocation(id3) === None) + + coordinatorRef.deactivateInstances("y") + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) + assert(coordinatorRef.getLocation(id2) === None) + } + } + + test("multiple references have same underlying coordinator") { + withCoordinatorRef(sc) { coordRef1 => + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + + val id = StateStoreId("x", 0, 0) + + coordRef1.reportActiveInstance(id, "hostX", "exec1") + + eventually(timeout(5 seconds)) { + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + } + } +} + +object StateStoreCoordinatorSuite { + def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null + try { + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + body(coordinatorRef) + } finally { + if (coordinatorRef != null) coordinatorRef.stop() + } + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000000..24cec30fa335c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.nio.file.Files + +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.Utils + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + import StateStoreSuite._ + + after { + StateStore.stop() + } + + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { + quietly { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + test("recovering from files") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion, keySchema, valueSchema) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + } + + test("preferred locations using StateStoreCoordinator") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } + } + } + + test("distributed test") { + quietly { + withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { + sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) + } + + private val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala new file mode 100644 index 0000000000000..22b2f4f75d39e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File + +import scala.collection.mutable +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + + import StateStoreCoordinatorSuite._ + import StateStoreSuite._ + + private val tempDir = Utils.createTempDir().toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + after { + StateStore.stop() + } + + test("update, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator().isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + + // Verify state after updating + update(store, "a", 1) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + 
assert(provider.latestIterator().isEmpty) + + // Make updates, commit and then verify state + update(store, "b", 2) + update(store, "aa", 3) + remove(store, _.startsWith("a")) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) + assert(fileExists(provider, version = 1, isSnapshot = false)) + + assert(getDataFromFiles(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getDataFromFiles(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = new HDFSBackedStateStoreProvider( + store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedStore = reloadedProvider.getStore(1) + update(reloadedStore, "c", 4) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) + assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) + } + + test("updates iterator with all combos of updates and removes") { + val provider = newStoreProvider() + var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { + val store = provider.getStore(currentVersion) + body(store) + currentVersion += 1 + } + + // New data should be seen in updates as value added, even if they had multiple updates + withStore { store => + update(store, "a", 1) + update(store, "aa", 1) + update(store, "aa", 2) + store.commit() + assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + } + + // Multiple updates to same key should be collapsed in the updates as a single value update + // Keys that have not been updated should not appear in the updates + withStore { store => + update(store, "a", 4) + update(store, "a", 6) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Keys added, updated and finally removed before commit should not appear in updates + withStore { store => + update(store, "b", 4) // Added, finally removed + update(store, "bb", 5) // Added, updated, finally removed + update(store, "bb", 6) + remove(store, _.startsWith("b")) + store.commit() + assert(updatesToSet(store.updates()) === Set.empty) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Removed data should be seen in updates as a key removed + // Removed, but re-added data should be seen in updates as a value update + withStore { store => + remove(store, _.startsWith("a")) + update(store, "a", 10) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(rowsToSet(store.iterator()) === Set("a" -> 10)) + } + } + + test("cancel") { + val provider = newStoreProvider() + val store = provider.getStore(0) + update(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + update(store1, "b", 1) + store1.cancel() + assert(getDataFromFiles(provider) === Set("a" -> 1)) + } + + test("getStore with unexpected versions") { + val 
provider = newStoreProvider() + + intercept[IllegalArgumentException] { + provider.getStore(-1) + } + + // Prepare some data in the stoer + val store = provider.getStore(0) + update(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + intercept[IllegalStateException] { + provider.getStore(2) + } + + // Update store version with some data + val store1 = provider.getStore(1) + update(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) + + // Overwrite the version with other data + val store2 = provider.getStore(1) + update(store2, "c", 1) + assert(store2.commit() === 2) + assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) + } + + test("snapshotting") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + var currentVersion = 0 + def updateVersionTo(targetVersion: Int): Unit = { + for (i <- currentVersion + 1 to targetVersion) { + val store = provider.getStore(currentVersion) + update(store, "a", i) + store.commit() + currentVersion += 1 + } + require(currentVersion === targetVersion) + } + + updateVersionTo(2) + require(getDataFromFiles(provider) === Set("a" -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 2)) + + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + updateVersionTo(6) + require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + + val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getDataFromFiles(provider) === Set("a" -> 6), + "snapshotting messed up the data of the final version") + + // After version 20, snapshotting should generate newer snapshot files + updateVersionTo(20) + require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + } + + test("cleaning") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowsToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of 
versions should be retrievable + assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) + assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + } + + + test("corrupted file handling") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10).find( version => + fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) + + // Corrupt snapshot file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion) + } + + // Corrupt delta file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + corruptFile(provider, snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + } + + test("StateStore.get") { + quietly { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + + intercept[IllegalStateException] { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + + // Increase version of the store + val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + assert(store0.version === 0) + update(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) + assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.unload(storeId) + assert(!StateStore.isLoaded(storeId)) + + val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + update(store1, "a", 2) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + } + } + + test("maintenance") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + .set("spark.rpc.numRetries", "1") + val opId = 0 + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, opId, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val provider = new HDFSBackedStateStoreProvider( + storeId, keySchema, valueSchema, storeConf, hadoopConf) + + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) + update(store, "a", i) + store.commit() + } + eventually(timeout(10 
seconds)) { + assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + } + + // Background maintenance should clean up and generate snapshots + eventually(timeout(10 seconds)) { + // Earliest delta file should get cleaned up + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + + // Some snapshots should have been generated + val snapshotVersions = (0 to 20).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + + // If driver decides to deactivate all instances of the store, then this instance + // should be unloaded + coordinatorRef.deactivateInstances(dir) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + + // If some other executor loads the store, then this instance should be unloaded + coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + } + } + + // Verify if instance is unloaded if SparkContext is stopped + require(SparkEnv.get === null) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + } + } + + def getDataFromFiles( + provider: HDFSBackedStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = new HDFSBackedStateStoreProvider( + provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + if (version < 0) { + reloadedProvider.latestIterator().map(rowsToStringInt).toSet + } else { + reloadedProvider.iterator(version).map(rowsToStringInt).toSet + } + } + + def assertMap( + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { + assert(testMapOption.nonEmpty, "no map present") + val convertedMap = testMapOption.get.map(rowsToStringInt) + assert(convertedMap === expectedMap) + } + + def fileExists( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + for (version <- 0 until version.toInt) { + for (isSnapshot <- Seq(false, true)) { + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + if (filePath.exists) filePath.delete() + } + } + } + + def corruptFile( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.delete() + filePath.createNewFile() + } + + def storeLoaded(storeId: StateStoreId): Boolean = { + val method = PrivateMethod[mutable.HashMap[StateStoreId, 
StateStore]]('loadedStores) + val loadedStores = StateStore invokePrivate method() + loadedStores.contains(storeId) + } + + def unloadStore(storeId: StateStoreId): Boolean = { + val method = PrivateMethod('remove) + StateStore invokePrivate method(storeId) + } + + def newStoreProvider( + opId: Long = Random.nextLong, + partition: Int = 0, + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get + ): HDFSBackedStateStoreProvider = { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + new HDFSBackedStateStoreProvider( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + new StateStoreConf(sqlConf), + new Configuration()) + } + + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.remove(row => condition(rowToString(row))) + } + + private def update(store: StateStore, key: String, value: Int): Unit = { + store.update(stringToRow(key), _ => intToRow(value)) + } +} + +private[state] object StateStoreSuite { + + /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ + trait TestUpdate + case class Added(key: String, value: Int) extends TestUpdate + case class Updated(key: String, value: Int) extends TestUpdate + case class Removed(key: String) extends TestUpdate + + val strProj = UnsafeProjection.create(Array[DataType](StringType)) + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + + def stringToRow(s: String): UnsafeRow = { + strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy() + } + + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToString(row: UnsafeRow): String = { + row.getUTF8String(0).toString + } + + def rowToInt(row: UnsafeRow): Int = { + row.getInt(0) + } + + def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { + (rowToInt(row._1), rowToInt(row._2)) + } + + + def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { + (rowToString(row._1), rowToInt(row._2)) + } + + def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } + + def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + iterator.map { _ match { + case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) + case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) + case KeyRemoved(key) => Removed(rowToString(key)) + }}.toSet + } +} From 919bf321987712d9143cae3c4e064fcb077ded1f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 23 Mar 2016 20:51:01 +0100 Subject: [PATCH 24/26] [SPARK-13325][SQL] Create a 64-bit hashcode expression This PR introduces a 64-bit hashcode expression. Such an expression is especially usefull for HyperLogLog++ and other probabilistic datastructures. I have implemented xxHash64 which is a 64-bit hashing algorithm created by Yann Colet and Mathias Westerdahl. This is a high speed (C implementation runs at memory bandwidth) and high quality hashcode. It exploits both Instruction Level Parralellism (for speed) and the multiplication and rotation techniques (for quality) like MurMurHash does. The initial results are promising. 
I have added a CG'ed test to the `HashBenchmark`, and this yields the following results (running from SBT):

Running benchmark: Hash For simple
  Running case: interpreted version
  Running case: codegen version
  Running case: codegen version 64-bit

Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz
Hash For simple:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
interpreted version                 1011 / 1016        132.8           7.5       1.0X
codegen version                     1864 / 1869         72.0          13.9       0.5X
codegen version 64-bit              1614 / 1644         83.2          12.0       0.6X

Running benchmark: Hash For normal
  Running case: interpreted version
  Running case: codegen version
  Running case: codegen version 64-bit

Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz
Hash For normal:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
interpreted version                 2467 / 2475          0.9        1176.1       1.0X
codegen version                     2008 / 2115          1.0         957.5       1.2X
codegen version 64-bit               728 /  758          2.9         347.0       3.4X

Running benchmark: Hash For array
  Running case: interpreted version
  Running case: codegen version
  Running case: codegen version 64-bit

Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz
Hash For array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
interpreted version                 1544 / 1707          0.1       11779.6       1.0X
codegen version                     2728 / 2745          0.0       20815.5       0.6X
codegen version 64-bit              2508 / 2549          0.1       19132.8       0.6X

Running benchmark: Hash For map
  Running case: interpreted version
  Running case: codegen version
  Running case: codegen version 64-bit

Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz
Hash For map:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
interpreted version                 1819 / 1826          0.0      444014.3       1.0X
codegen version                      183 /  194          0.0       44642.9       9.9X
codegen version 64-bit               173 /  174          0.0       42120.9      10.5X

This shows that the algorithm is consistently faster than MurMurHash32 in all cases and up to 3x (!) in the normal case.

I have also added this to HyperLogLog++ and it cuts the processing time of the following code in half:

    val df = sqlContext.range(1<<25).agg(approxCountDistinct("id"))
    df.explain()
    val t = System.nanoTime()
    df.show()
    val ns = System.nanoTime() - t

    // Before
    ns: Long = 5821524302

    // After
    ns: Long = 2836418963

cc cloud-fan (you have been working on hashcodes) / rxin

Author: Herman van Hovell

Closes #11209 from hvanhovell/xxHash.
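For reference, a minimal sketch (illustrative only, not part of this patch) of how the new expression can be exercised once the definitions added in misc.scala below are in place; the wrapper object name is hypothetical:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, XxHash64, XxHash64Function}
    import org.apache.spark.sql.types.{LongType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Hypothetical wrapper; assumes the XxHash64 expression and the
    // XxHash64Function object introduced by this patch.
    object XxHash64Sketch {
      def main(args: Array[String]): Unit = {
        // Interpreted path: hash a single value with the default seed 42L,
        // the same way HyperLogLog++ now calls XxHash64Function.hash.
        val h1: Long = XxHash64Function.hash(UTF8String.fromString("spark"), StringType, 42L)

        // Expression path: hash the first (long) column of an InternalRow.
        val expr = new XxHash64(Seq(BoundReference(0, LongType, nullable = false)))
        val h2: Long = expr.eval(InternalRow(123L)).asInstanceOf[Long]

        println(s"h1=$h1 h2=$h2")
      }
    }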
--- .../spark/sql/catalyst/expressions/XXH64.java | 192 ++++++++++++++ .../aggregate/HyperLogLogPlusPlus.scala | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 238 +++++++++++------- .../sql/catalyst/expressions/XXH64Suite.java | 166 ++++++++++++ .../org/apache/spark/sql/HashBenchmark.scala | 64 +++-- .../spark/sql/HashByteArrayBenchmark.scala | 148 +++++++++++ .../expressions/MiscFunctionsSuite.scala | 13 +- 7 files changed, 713 insertions(+), 110 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java new file mode 100644 index 0000000000000..5f2de266b538f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.SystemClock; + +// scalastyle: off + +/** + * xxHash64. A high quality and fast 64 bit hash code by Yann Colet and Mathias Westerdahl. The + * class below is modelled like its Murmur3_x86_32 cousin. + *
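+ * Each hash method below comes in an instance flavour (using the seed passed to the constructor)
+ * and a static flavour that takes an explicit seed.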

+ * This was largely based on the following (original) C and Java implementations: + * https://github.com/Cyan4973/xxHash/blob/master/xxhash.c + * https://github.com/OpenHFT/Zero-Allocation-Hashing/blob/master/src/main/java/net/openhft/hashing/XxHash_r39.java + * https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/XxHash64.java + */ +// scalastyle: on +public final class XXH64 { + + private static final long PRIME64_1 = 0x9E3779B185EBCA87L; + private static final long PRIME64_2 = 0xC2B2AE3D27D4EB4FL; + private static final long PRIME64_3 = 0x165667B19E3779F9L; + private static final long PRIME64_4 = 0x85EBCA77C2B2AE63L; + private static final long PRIME64_5 = 0x27D4EB2F165667C5L; + + private final long seed; + + public XXH64(long seed) { + super(); + this.seed = seed; + } + + @Override + public String toString() { + return "xxHash64(seed=" + seed + ")"; + } + + public long hashInt(int input) { + return hashInt(input, seed); + } + + public static long hashInt(int input, long seed) { + long hash = seed + PRIME64_5 + 4L; + hash ^= (input & 0xFFFFFFFFL) * PRIME64_1; + hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; + return fmix(hash); + } + + public long hashLong(long input) { + return hashLong(input, seed); + } + + public static long hashLong(long input, long seed) { + long hash = seed + PRIME64_5 + 8L; + hash ^= Long.rotateLeft(input * PRIME64_2, 31) * PRIME64_1; + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + return fmix(hash); + } + + public long hashUnsafeWords(Object base, long offset, int length) { + return hashUnsafeWords(base, offset, length, seed); + } + + public static long hashUnsafeWords(Object base, long offset, int length, long seed) { + assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWords(base, offset, length, seed); + return fmix(hash); + } + + public long hashUnsafeBytes(Object base, long offset, int length) { + return hashUnsafeBytes(base, offset, length, seed); + } + + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + assert (length >= 0) : "lengthInBytes cannot be negative"; + long hash = hashBytesByWords(base, offset, length, seed); + long end = offset + length; + offset += length & -8; + + if (offset + 4L <= end) { + hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; + offset += 4L; + } + + while (offset < end) { + hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash = Long.rotateLeft(hash, 11) * PRIME64_1; + offset++; + } + return fmix(hash); + } + + private static long fmix(long hash) { + hash ^= hash >>> 33; + hash *= PRIME64_2; + hash ^= hash >>> 29; + hash *= PRIME64_3; + hash ^= hash >>> 32; + return hash; + } + + private static long hashBytesByWords(Object base, long offset, int length, long seed) { + long end = offset + length; + long hash; + if (length >= 32) { + long limit = end - 32; + long v1 = seed + PRIME64_1 + PRIME64_2; + long v2 = seed + PRIME64_2; + long v3 = seed; + long v4 = seed - PRIME64_1; + + do { + v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 = Long.rotateLeft(v1, 31); + v1 *= PRIME64_1; + + v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 = Long.rotateLeft(v2, 31); + v2 *= PRIME64_1; + + v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 = Long.rotateLeft(v3, 31); + v3 *= PRIME64_1; + + v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 = Long.rotateLeft(v4, 31); + v4 *= 
PRIME64_1; + + offset += 32L; + } while (offset <= limit); + + hash = Long.rotateLeft(v1, 1) + + Long.rotateLeft(v2, 7) + + Long.rotateLeft(v3, 12) + + Long.rotateLeft(v4, 18); + + v1 *= PRIME64_2; + v1 = Long.rotateLeft(v1, 31); + v1 *= PRIME64_1; + hash ^= v1; + hash = hash * PRIME64_1 + PRIME64_4; + + v2 *= PRIME64_2; + v2 = Long.rotateLeft(v2, 31); + v2 *= PRIME64_1; + hash ^= v2; + hash = hash * PRIME64_1 + PRIME64_4; + + v3 *= PRIME64_2; + v3 = Long.rotateLeft(v3, 31); + v3 *= PRIME64_1; + hash ^= v3; + hash = hash * PRIME64_1 + PRIME64_4; + + v4 *= PRIME64_2; + v4 = Long.rotateLeft(v4, 31); + v4 *= PRIME64_1; + hash ^= v4; + hash = hash * PRIME64_1 + PRIME64_4; + } else { + hash = seed + PRIME64_5; + } + + hash += length; + + long limit = end - 8; + while (offset <= limit) { + long k1 = Platform.getLong(base, offset); + hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + offset += 8L; + } + return hash; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 32bae133608c9..b6bd56cff6b33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -169,7 +169,7 @@ case class HyperLogLogPlusPlus( val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. - val x = MurmurHash.hash64(v) + val x = XxHash64Function.hash(v, child.dataType, 42L) // Determine the index of the register we are going to use. val idx = (x >>> idxShift).toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8f260ad151cd2..e8a3e129b49e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -185,6 +185,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp } } + /** * A function that calculates hash value for a group of expressions. Note that the `seed` argument * is not exposed to users and should only be set inside spark SQL. @@ -213,14 +214,10 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp * `result`. * * Finally we aggregate the hash values for each expression by the same way of struct. - * - * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle - * and bucketing have same data distribution. */ -case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression { - def this(arguments: Seq[Expression]) = this(arguments, 42) - - override def dataType: DataType = IntegerType +abstract class HashExpression[E] extends Expression { + /** Seed of the HashExpression. 
*/ + val seed: E override def foldable: Boolean = children.forall(_.foldable) @@ -234,8 +231,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } - override def prettyName: String = "hash" - override def eval(input: InternalRow): Any = { var hash = seed var i = 0 @@ -247,80 +242,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hash } - private def computeHash(value: Any, dataType: DataType, seed: Int): Int = { - def hashInt(i: Int): Int = Murmur3_x86_32.hashInt(i, seed) - def hashLong(l: Long): Int = Murmur3_x86_32.hashLong(l, seed) - - value match { - case null => seed - case b: Boolean => hashInt(if (b) 1 else 0) - case b: Byte => hashInt(b) - case s: Short => hashInt(s) - case i: Int => hashInt(i) - case l: Long => hashLong(l) - case f: Float => hashInt(java.lang.Float.floatToIntBits(f)) - case d: Double => hashLong(java.lang.Double.doubleToLongBits(d)) - case d: Decimal => - val precision = dataType.asInstanceOf[DecimalType].precision - if (precision <= Decimal.MAX_LONG_DIGITS) { - hashLong(d.toUnscaledLong) - } else { - val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray - Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed) - } - case c: CalendarInterval => Murmur3_x86_32.hashInt(c.months, hashLong(c.microseconds)) - case a: Array[Byte] => - Murmur3_x86_32.hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - Murmur3_x86_32.hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) - - case array: ArrayData => - val elementType = dataType match { - case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType - case ArrayType(et, _) => et - } - var result = seed - var i = 0 - while (i < array.numElements()) { - result = computeHash(array.get(i, elementType), elementType, result) - i += 1 - } - result - - case map: MapData => - val (kt, vt) = dataType match { - case udt: UserDefinedType[_] => - val mapType = udt.sqlType.asInstanceOf[MapType] - mapType.keyType -> mapType.valueType - case MapType(kt, vt, _) => kt -> vt - } - val keys = map.keyArray() - val values = map.valueArray() - var result = seed - var i = 0 - while (i < map.numElements()) { - result = computeHash(keys.get(i, kt), kt, result) - result = computeHash(values.get(i, vt), vt, result) - i += 1 - } - result - - case struct: InternalRow => - val types: Array[DataType] = dataType match { - case udt: UserDefinedType[_] => - udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray - case StructType(fields) => fields.map(_.dataType) - } - var result = seed - var i = 0 - val len = struct.numFields - while (i < len) { - result = computeHash(struct.get(i, types(i)), types(i), result) - i += 1 - } - result - } - } + protected def computeHash(value: Any, dataType: DataType, seed: E): E override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" @@ -332,7 +254,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression }.mkString("\n") s""" - int ${ev.value} = $seed; + ${ctx.javaType(dataType)} ${ev.value} = $seed; $childrenHash """ } @@ -360,7 +282,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression dataType: DataType, result: String, ctx: CodegenContext): String = { - val hasher = classOf[Murmur3_x86_32].getName + val hasher = hasherClassName def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" def hashLong(l: String): String = s"$result = 
$hasher.hashLong($l, $result);" @@ -423,6 +345,125 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } + + protected def hasherClassName: String +} + +/** + * Base class for interpreted hash functions. + */ +abstract class InterpretedHashFunction { + protected def hashInt(i: Int, seed: Long): Long + + protected def hashLong(l: Long, seed: Long): Long + + protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + + def hash(value: Any, dataType: DataType, seed: Long): Long = { + value match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case d: Decimal => + val precision = dataType.asInstanceOf[DecimalType].precision + if (precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(d.toUnscaledLong, seed) + } else { + val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray + hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed) + } + case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) + case a: Array[Byte] => + hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) + case s: UTF8String => + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + + case array: ArrayData => + val elementType = dataType match { + case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType + case ArrayType(et, _) => et + } + var result = seed + var i = 0 + while (i < array.numElements()) { + result = hash(array.get(i, elementType), elementType, result) + i += 1 + } + result + + case map: MapData => + val (kt, vt) = dataType match { + case udt: UserDefinedType[_] => + val mapType = udt.sqlType.asInstanceOf[MapType] + mapType.keyType -> mapType.valueType + case MapType(kt, vt, _) => kt -> vt + } + val keys = map.keyArray() + val values = map.valueArray() + var result = seed + var i = 0 + while (i < map.numElements()) { + result = hash(keys.get(i, kt), kt, result) + result = hash(values.get(i, vt), vt, result) + i += 1 + } + result + + case struct: InternalRow => + val types: Array[DataType] = dataType match { + case udt: UserDefinedType[_] => + udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray + case StructType(fields) => fields.map(_.dataType) + } + var result = seed + var i = 0 + val len = struct.numFields + while (i < len) { + result = hash(struct.get(i, types(i)), types(i), result) + i += 1 + } + result + } + } +} + +/** + * A MurMur3 Hash expression. + * + * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle + * and bucketing have same data distribution. 
+ */ +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def prettyName: String = "hash" + + override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + Murmur3HashFunction.hash(value, dataType, seed).toInt + } +} + +object Murmur3HashFunction extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = { + Murmur3_x86_32.hashInt(i, seed.toInt) + } + + override protected def hashLong(l: Long, seed: Long): Long = { + Murmur3_x86_32.hashLong(l, seed.toInt) + } + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) + } } /** @@ -442,3 +483,30 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { """.stripMargin) } } + +/** + * A xxHash64 64-bit hash expression. + */ +case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] { + def this(arguments: Seq[Expression]) = this(arguments, 42L) + + override def dataType: DataType = LongType + + override def prettyName: String = "xxHash" + + override protected def hasherClassName: String = classOf[XXH64].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { + XxHash64Function.hash(value, dataType, seed) + } +} + +object XxHash64Function extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed) + + override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + XXH64.hashUnsafeBytes(base, offset, len, seed) + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java new file mode 100644 index 0000000000000..711887f02832a --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test the XXH64 function. + *

+ * Test constants were taken from the original implementation and the airlift/slice implementation. + */ +public class XXH64Suite { + + private static final XXH64 hasher = new XXH64(0); + + private static final int SIZE = 101; + private static final long PRIME = 2654435761L; + private static final byte[] BUFFER = new byte[SIZE]; + private static final int TEST_INT = 0x4B1FFF9E; // First 4 bytes in the buffer + private static final long TEST_LONG = 0xDD2F535E4B1FFF9EL; // First 8 bytes in the buffer + + /* Create the test data. */ + static { + long seed = PRIME; + for (int i = 0; i < SIZE; i++) { + BUFFER[i] = (byte) (seed >> 24); + seed *= seed; + } + } + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashInt(TEST_INT)); + Assert.assertEquals(0x9D5FFDFB928AB4BL, XXH64.hashInt(TEST_INT, PRIME)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(0xF74CB1451B32B8CFL, hasher.hashLong(TEST_LONG)); + Assert.assertEquals(0x9C44B77FBCC302C5L, XXH64.hashLong(TEST_LONG, PRIME)); + } + + @Test + public void testKnownByteArrayInputs() { + Assert.assertEquals(0xEF46DB3751D8E999L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 0)); + Assert.assertEquals(0xAC75FDA2929B17EFL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 0, PRIME)); + Assert.assertEquals(0x4FCE394CC88952D8L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1)); + Assert.assertEquals(0x739840CB819FA723L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); + + // These tests currently fail in a big endian environment because the test data and expected + // answers are generated with little endian the assumptions. We could revisit this when Platform + // becomes endian aware. + if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { + Assert.assertEquals(0x9256E58AA397AEF1L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x9D5FFDFB928AB4BL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0xF74CB1451B32B8CFL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x9C44B77FBCC302C5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xCFFA8DB881BC3A3DL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x5B9611585EFCC9CBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0x0EAB543384F878ADL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0xCAA65939306F1E21L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. 
+ Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 5a929f211aaa4..c6a1a2be0d071 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.types._ import org.apache.spark.util.Benchmark /** - * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs the new codegen - * hash expression(Murmur3Hash). + * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs codegened + * hash expressions (Murmur3Hash/xxHash64). 
*/ object HashBenchmark { @@ -63,19 +63,44 @@ object HashBenchmark { } } } + + val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) + benchmark.addCase("codegen version 64-bit") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += getHashCode64b(rows(i)).getInt(0) + i += 1 + } + } + } + benchmark.run() } def main(args: Array[String]): Unit = { - val simple = new StructType().add("i", IntegerType) + val singleInt = new StructType().add("i", IntegerType) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1006 / 1011 133.4 7.5 1.0X + codegen version 1835 / 1839 73.1 13.7 0.5X + codegen version 64-bit 1627 / 1628 82.5 12.1 0.6X + */ + test("single ints", singleInt, 1 << 15, 1 << 14) + + val singleLong = new StructType().add("i", LongType) /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Hash For simple: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - interpreted version 941 / 955 142.6 7.0 1.0X - codegen version 1737 / 1775 77.3 12.9 0.5X + interpreted version 1196 / 1209 112.2 8.9 1.0X + codegen version 2178 / 2181 61.6 16.2 0.5X + codegen version 64-bit 1752 / 1753 76.6 13.1 0.7X */ - test("simple", simple, 1 << 13, 1 << 14) + test("single longs", singleLong, 1 << 15, 1 << 14) val normal = new StructType() .add("null", NullType) @@ -93,11 +118,12 @@ object HashBenchmark { .add("date", DateType) .add("timestamp", TimestampType) /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - interpreted version 2209 / 2271 0.9 1053.4 1.0X - codegen version 1887 / 2018 1.1 899.9 1.2X + interpreted version 2713 / 2715 0.8 1293.5 1.0X + codegen version 2015 / 2018 1.0 960.9 1.3X + codegen version 64-bit 735 / 738 2.9 350.7 3.7X */ test("normal", normal, 1 << 10, 1 << 11) @@ -106,11 +132,12 @@ object HashBenchmark { .add("array", arrayOfInt) .add("arrayOfArray", ArrayType(arrayOfInt)) /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - interpreted version 1481 / 1529 0.1 11301.7 1.0X - codegen version 2591 / 2636 0.1 19771.1 0.6X + interpreted version 1498 / 1499 0.1 11432.1 1.0X + codegen version 2642 / 2643 0.0 20158.4 0.6X + codegen version 64-bit 2421 / 2424 0.1 18472.5 0.6X */ test("array", array, 1 << 8, 1 << 9) @@ -119,11 +146,12 @@ object HashBenchmark { .add("map", mapOfInt) .add("mapOfMap", MapType(IntegerType, mapOfInt)) /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - interpreted version 1820 / 1861 0.0 444347.2 1.0X - codegen version 205 / 223 0.0 49936.5 8.9X + interpreted version 1612 / 1618 0.0 393553.4 1.0X + codegen version 149 / 150 0.0 
36381.2 10.8X + codegen version 64-bit 144 / 145 0.0 35122.1 11.2X */ test("map", map, 1 << 6, 1 << 6) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala new file mode 100644 index 0000000000000..53f21a8442429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Random + +import org.apache.spark.sql.catalyst.expressions.XXH64 +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.util.Benchmark + +/** + * Synthetic benchmark for MurMurHash 3 and xxHash64. + */ +object HashByteArrayBenchmark { + def test(length: Int, seed: Long, numArrays: Int, iters: Int): Unit = { + val random = new Random(seed) + val arrays = Array.fill[Array[Byte]](numArrays) { + val bytes = new Array[Byte](length) + random.nextBytes(bytes) + bytes + } + + val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + benchmark.addCase("Murmur3_x86_32") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numArrays) { + sum += Murmur3_x86_32.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) + i += 1 + } + } + } + + benchmark.addCase("xxHash 64-bit") { _: Int => + for (_ <- 0L until iters) { + var sum = 0L + var i = 0 + while (i < numArrays) { + sum += XXH64.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) + i += 1 + } + } + } + + benchmark.run() + } + + def main(args: Array[String]): Unit = { + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 11 / 12 185.1 5.4 1.0X + xxHash 64-bit 17 / 18 120.0 8.3 0.6X + */ + test(8, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 18 / 18 118.6 8.4 1.0X + xxHash 64-bit 20 / 21 102.5 9.8 0.9X + */ + test(16, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 24 / 24 86.6 11.5 1.0X + xxHash 64-bit 23 / 23 93.2 10.7 1.1X + */ + test(24, 42L, 1 << 10, 1 << 11) + + // Add 31 to all arrays 
to create worse case alignment for xxHash. + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 38 / 39 54.7 18.3 1.0X + xxHash 64-bit 33 / 33 64.4 15.5 1.2X + */ + test(31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 91 / 94 22.9 43.6 1.0X + xxHash 64-bit 68 / 69 30.6 32.7 1.3X + */ + test(64 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 268 / 268 7.8 127.6 1.0X + xxHash 64-bit 108 / 109 19.4 51.6 2.5X + */ + test(256 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 942 / 945 2.2 449.4 1.0X + xxHash 64-bit 276 / 276 7.6 131.4 3.4X + */ + test(1024 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 1839 / 1843 1.1 876.8 1.0X + xxHash 64-bit 445 / 448 4.7 212.1 4.1X + */ + test(2048 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 7307 / 7310 0.3 3484.4 1.0X + xxHash 64-bit 1487 / 1488 1.4 709.1 4.9X + */ + test(8192 + 31, 42L, 1 << 10, 1 << 11) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 60d50baf511d9..f5bafcc6a783e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -76,7 +76,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) - testMurmur3Hash( + testHash( new StructType() .add("null", NullType) .add("boolean", BooleanType) @@ -94,7 +94,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { .add("timestamp", TimestampType) .add("udt", new ExamplePointUDT)) - testMurmur3Hash( + testHash( new StructType() .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) @@ -104,7 +104,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { .add("arrayOfStruct", ArrayType(structOfString)) .add("arrayOfUDT", arrayOfUDT)) - testMurmur3Hash( + testHash( new StructType() .add("mapOfIntAndString", MapType(IntegerType, StringType)) .add("mapOfStringAndArray", 
MapType(StringType, arrayOfString)) @@ -114,7 +114,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { .add("mapOfStructAndString", MapType(structOfString, StringType)) .add("mapOfStruct", MapType(structOfString, structOfString))) - testMurmur3Hash( + testHash( new StructType() .add("structOfString", structOfString) .add("structOfStructOfString", new StructType().add("struct", structOfString)) @@ -124,11 +124,11 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) - private def testMurmur3Hash(inputSchema: StructType): Unit = { + private def testHash(inputSchema: StructType): Unit = { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val encoder = RowEncoder(inputSchema) val seed = scala.util.Random.nextInt() - test(s"murmur3 hash: ${inputSchema.simpleString}") { + test(s"murmur3/xxHash64 hash: ${inputSchema.simpleString}") { for (_ <- 1 to 10) { val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { @@ -136,6 +136,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Only test the interpreted version has same result with codegen version. checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) + checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) } } } From 6bc4be64f86afcb38e4444c80c9400b7b6b745de Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 23 Mar 2016 13:02:40 -0700 Subject: [PATCH 25/26] [SPARK-14078] Streaming Parquet Based FileSink This PR adds a new `Sink` implementation that writes out Parquet files. In order to correctly handle partial failures while maintaining exactly once semantics, the files for each batch are written out to a unique directory and then atomically appended to a metadata log. When a parquet based `DataSource` is initialized for reading, we first check for this log directory and use it instead of file listing when present. Unit tests are added, as well as a stress test that checks the answer after non-deterministic injected failures. Author: Michael Armbrust Closes #11897 from marmbrus/fileSink. 
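
To make the fault-tolerance story concrete, here is a minimal sketch of the idempotent-append contract the sink relies on: each batch is written somewhere fresh first, and its id is recorded in a log only afterwards, so re-delivery of an already committed batch becomes a no-op. The names `InMemoryBatchLog` and `IdempotentSink` are illustrative stand-ins, not the `HDFSMetadataLog` and `FileStreamSink` classes added in the diff below.

```scala
import scala.collection.concurrent.TrieMap

// Illustrative stand-in for a batch metadata log; not the real HDFSMetadataLog API.
class InMemoryBatchLog[T] {
  private val entries = TrieMap.empty[Long, T]

  /** Records metadata for a batch; returns false if the batch was already committed. */
  def add(batchId: Long, metadata: T): Boolean =
    entries.putIfAbsent(batchId, metadata).isEmpty

  def get(batchId: Long): Option[T] = entries.get(batchId)
}

// Sketch of a sink whose addBatch is safe to retry after partial failures.
class IdempotentSink(log: InMemoryBatchLog[Seq[String]]) {
  def addBatch(batchId: Long, writeFiles: () => Seq[String]): Unit = {
    if (log.get(batchId).isDefined) {
      println(s"Skipping already committed batch $batchId")   // replayed batch: no-op
    } else {
      val files = writeFiles()           // write to a unique location first...
      if (!log.add(batchId, files)) {    // ...then atomically record the file list
        println(s"Lost the race while committing batch $batchId")
      }
    }
  }
}
```

A crash between writing the files and recording them can leave orphaned files in the output directory, but the log never contains two copies of the same batch, which is exactly the guarantee described above.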
--- .../apache/spark/sql/ContinuousQuery.scala | 9 ++ .../spark/sql/ContinuousQueryException.scala | 6 +- .../execution/datasources/DataSource.scala | 64 ++++++++- .../execution/streaming/CompositeOffset.scala | 3 + .../execution/streaming/FileStreamSink.scala | 81 +++++++++++ .../streaming/FileStreamSource.scala | 14 +- .../execution/streaming/HDFSMetadataLog.scala | 7 +- .../sql/execution/streaming/LongOffset.scala | 2 + .../sql/execution/streaming/MetadataLog.scala | 2 +- .../execution/streaming/StreamExecution.scala | 18 +++ .../streaming/StreamFileCatalog.scala | 59 ++++++++ .../streaming/HDFSMetadataLogSuite.scala | 2 + .../sql/streaming/FileStreamSinkSuite.scala | 49 +++++++ .../spark/sql/streaming/FileStressSuite.scala | 129 ++++++++++++++++++ 14 files changed, 430 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index eb69804c39b5d..1dc9a6893ebb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -91,6 +91,15 @@ trait ContinuousQuery { */ def awaitTermination(timeoutMs: Long): Boolean + /** + * Blocks until all available data in the source has been processed an committed to the sink. + * This method is intended for testing. Note that in the case of continually arriving data, this + * method may block forever. Additionally, this method is only guranteed to block until data that + * has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]] + * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). + */ + def processAllAvailable(): Unit + /** * Stops the execution of this query if it is running. This method blocks until the threads * performing execution has stopped. 
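
The `processAllAvailable()` contract added above (block until everything the source has already seen is committed, possibly forever if data keeps arriving) is typically implemented with a flag plus wait/notify on a shared lock, which is how `StreamExecution` does it later in this patch. The sketch below only illustrates that pattern; `AvailabilityGate` is an invented name, not a Spark class.

```scala
// Minimal sketch of the flag-plus-wait/notify pattern behind processAllAvailable().
class AvailabilityGate {
  private val lock = new Object
  @volatile private var noNewData = false

  /** Called by the batch loop once the source reports no new offsets. */
  def signalCaughtUp(): Unit = lock.synchronized {
    noNewData = true
    lock.notifyAll()
  }

  /** Called by tests; may block forever if new data keeps arriving. */
  def awaitCaughtUp(): Unit = {
    noNewData = false
    while (!noNewData) {
      // The bounded wait guards against a notification that arrives between
      // the check of the flag and the call to wait().
      lock.synchronized { lock.wait(10000) }
    }
  }
}
```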
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala index 67dd9dbe23726..fec38629d914e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala @@ -32,12 +32,12 @@ import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} */ @Experimental class ContinuousQueryException private[sql]( - val query: ContinuousQuery, + @transient val query: ContinuousQuery, val message: String, val cause: Throwable, val startOffset: Option[Offset] = None, - val endOffset: Option[Offset] = None - ) extends Exception(message, cause) { + val endOffset: Option[Offset] = None) + extends Exception(message, cause) { /** Time when the exception occurred */ val time: Long = System.currentTimeMillis diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 548da86359c26..c66921f4852c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -22,6 +22,7 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -29,7 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils @@ -176,14 +177,41 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(): Sink = { - val datasourceClass = providingClass.newInstance() match { - case s: StreamSinkProvider => s + providingClass.newInstance() match { + case s: StreamSinkProvider => s.createSink(sqlContext, options, partitionColumns) + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + + new FileStreamSink(sqlContext, path, format) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed writing") } + } - datasourceClass.createSink(sqlContext, options, partitionColumns) + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. 
+ */ + def hasMetadata(path: Seq[String]): Boolean = { + path match { + case Seq(singlePath) => + try { + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) + val res = fs.exists(metadataPath) + res + } catch { + case NonFatal(e) => + logWarning(s"Error while looking for metadata directory.") + false + } + case _ => false + } } /** Create a resolved [[BaseRelation]] that can be used to read data from this [[DataSource]] */ @@ -200,6 +228,34 @@ case class DataSource( case (_: RelationProvider, Some(_)) => throw new AnalysisException(s"$className does not allow user-specified schemas.") + // We are reading from the results of a streaming query. Load files from the metadata log + // instead of listing them using HDFS APIs. + case (format: FileFormat, _) + if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => + val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) + val fileCatalog = + new StreamFileCatalog(sqlContext, basePath) + val dataSchema = userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $format at ${fileCatalog.allFiles().mkString(",")}. " + + "It must be specified manually") + } + + HadoopFsRelation( + sqlContext, + fileCatalog, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, + dataSchema = dataSchema, + bucketSpec = None, + format, + options) + + // This is a non-streaming file based datasource. case (format: FileFormat, _) => val allPaths = caseInsensitiveOptions.get("path") ++ paths val globbedPaths = allPaths.flatMap { path => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala index e48ac598929ab..729c8462fed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -64,6 +64,9 @@ case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { assert(sources.size == offsets.size) new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } + + override def toString: String = + offsets.map(_.map(_.toString).getOrElse("-")).mkString("[", ", ", "]") } object CompositeOffset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala new file mode 100644 index 0000000000000..e819e95d61f9a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.sources.FileFormat + +object FileStreamSink { + // The name of the subdirectory that is used to store metadata about which files are valid. + val metadataDir = "_spark_metadata" +} + +/** + * A sink that writes out results to parquet files. Each batch is written out to a unique + * directory. After all of the files in a batch have been succesfully written, the list of + * file paths is appended to the log atomically. In the case of partial failures, some duplicate + * data may be present in the target directory, but only one copy of each file will be present + * in the log. + */ +class FileStreamSink( + sqlContext: SQLContext, + path: String, + fileFormat: FileFormat) extends Sink with Logging { + + private val basePath = new Path(path) + private val logPath = new Path(basePath, FileStreamSink.metadataDir) + private val fileLog = new HDFSMetadataLog[Seq[String]](sqlContext, logPath.toUri.toString) + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (fileLog.get(batchId).isDefined) { + logInfo(s"Skipping already committed batch $batchId") + } else { + val files = writeFiles(data) + if (fileLog.add(batchId, files)) { + logInfo(s"Committed batch $batchId") + } else { + logWarning(s"Race while writing batch $batchId") + } + } + } + + /** Writes the [[DataFrame]] to a UUID-named dir, returning the list of files paths. 
*/ + private def writeFiles(data: DataFrame): Seq[String] = { + val ctx = sqlContext + val outputDir = path + val format = fileFormat + val schema = data.schema + + val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString + data.write.parquet(file) + sqlContext.read + .schema(data.schema) + .parquet(file) + .inputFiles + .map(new Path(_)) + .filterNot(_.getName.startsWith("_")) + .map(_.toUri.toString) + } + + override def toString: String = s"FileSink[$path]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index d13b1a6166798..1b70055f346b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -44,7 +44,7 @@ class FileStreamSource( private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) private val seenFiles = new OpenHashSet[String] - metadataLog.get(None, maxBatchId).foreach { case (batchId, files) => + metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) => files.foreach(seenFiles.add) } @@ -114,18 +114,24 @@ class FileStreamSource( val endId = end.asInstanceOf[LongOffset].offset assert(startId <= endId) - val files = metadataLog.get(Some(startId + 1), endId).map(_._2).flatten - logDebug(s"Return files from batches ${startId + 1}:$endId") + val files = metadataLog.get(Some(startId + 1), Some(endId)).map(_._2).flatten + logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId") logDebug(s"Streaming ${files.mkString(", ")}") dataFrameBuilder(files) } private def fetchAllFiles(): Seq[String] = { - fs.listStatus(new Path(path)) + val startTime = System.nanoTime() + val files = fs.listStatus(new Path(path)) .filterNot(_.getPath.getName.startsWith("_")) .map(_.getPath.toUri.toString) + val endTime = System.nanoTime() + logDebug(s"Listed ${files.size} in ${(endTime.toDouble - startTime) / 1000000}ms") + files } override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) + + override def toString: String = s"FileSource[$path]" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 298b5d292e8e4..f27d23b1cdcdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -170,11 +170,12 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) } } - override def get(startId: Option[Long], endId: Long): Array[(Long, T)] = { - val batchIds = fc.util().listStatus(metadataPath, batchFilesFilter) + override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = { + val files = fc.util().listStatus(metadataPath, batchFilesFilter) + val batchIds = files .map(_.getPath.getName.toLong) .filter { batchId => - batchId <= endId && (startId.isEmpty || batchId >= startId.get) + (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) } batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 008195af38b75..bb176408d8f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -30,4 +30,6 @@ case class LongOffset(offset: Long) extends Offset { def +(increment: Long): LongOffset = new LongOffset(offset + increment) def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) + + override def toString: String = s"#$offset" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala index 3f9896d23ce36..cc70e1d314d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala @@ -42,7 +42,7 @@ trait MetadataLog[T] { * Return metadata for batches between startId (inclusive) and endId (inclusive). If `startId` is * `None`, just return all batches before endId (inclusive). */ - def get(startId: Option[Long], endId: Long): Array[(Long, T)] + def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] /** * Return the latest batch Id and its metadata if exist. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 29b058f2e4062..5abd7eca2c2e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -239,6 +239,12 @@ class StreamExecution( logInfo(s"Committed offsets for batch $currentBatchId.") true } else { + noNewData = true + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + false } } @@ -334,6 +340,18 @@ class StreamExecution( logDebug(s"Unblocked at $newOffset for $source") } + /** A flag to indicate that a batch has completed with no new data available. */ + @volatile private var noNewData = false + + override def processAllAvailable(): Unit = { + noNewData = false + while (!noNewData) { + awaitBatchLock.synchronized { awaitBatchLock.wait(10000) } + if (streamDeathCause != null) { throw streamDeathCause } + } + if (streamDeathCause != null) { throw streamDeathCause } + } + override def awaitTermination(): Unit = { if (state == INITIALIZED) { throw new IllegalStateException("Cannot wait for termination on a query that has not started") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala new file mode 100644 index 0000000000000..b8d69b18450cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources.{FileCatalog, Partition} +import org.apache.spark.sql.types.StructType + +class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog with Logging { + val metadataDirectory = new Path(path, FileStreamSink.metadataDir) + logInfo(s"Reading streaming file log from $metadataDirectory") + val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataDirectory.toUri.toString) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + + override def paths: Seq[Path] = path :: Nil + + override def partitionSpec(): PartitionSpec = PartitionSpec(StructType(Nil), Nil) + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with not partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + override def listFiles(filters: Seq[Expression]): Seq[Partition] = + Partition(InternalRow.empty, allFiles()) :: Nil + + override def getStatus(path: Path): Array[FileStatus] = fs.listStatus(path) + + override def refresh(): Unit = {} + + override def allFiles(): Seq[FileStatus] = { + fs.listStatus(metadataLog.get(None, None).flatMap(_._2).map(new Path(_))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4ddc218455eb2..9ed5686d977c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.test.SharedSQLContext class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { + private implicit def toOption[A](a: A): Option[A] = Option(a) + test("basic") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala new file mode 100644 index 0000000000000..7f316113835ff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class FileStreamSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("unpartitioned writing") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + + val outputDir = Utils.createTempDir("stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir("stream.checkpoint").getCanonicalPath + + val query = + df.write + .format("parquet") + .option("checkpointLocation", checkpointDir) + .startStream(outputDir) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { query.processAllAvailable() } + + val outputDf = sqlContext.read.parquet(outputDir).as[Int] + checkDataset( + outputDf, + 1, 2, 3) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala new file mode 100644 index 0000000000000..5a1bfb3a005c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.File +import java.util.UUID + +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException, StreamTest} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * A stress test for streamign queries that read and write files. This test constists of + * two threads: + * - one that writes out `numRecords` distinct integers to files of random sizes (the total + * number of records is fixed but each files size / creation time is random). + * - another that continually restarts a buggy streaming query (i.e. fails with 5% probability on + * any partition). 
+ * + * At the end, the resulting files are loaded and the answer is checked. + */ +class FileStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("fault tolerance stress test") { + val numRecords = 10000 + val inputDir = Utils.createTempDir("stream.input").getCanonicalPath + val stagingDir = Utils.createTempDir("stream.staging").getCanonicalPath + val outputDir = Utils.createTempDir("stream.output").getCanonicalPath + val checkpoint = Utils.createTempDir("stream.checkpoint").getCanonicalPath + + @volatile + var continue = true + @volatile + var stream: ContinuousQuery = null + + val writer = new Thread("stream writer") { + override def run(): Unit = { + var i = numRecords + while (i > 0) { + val count = Random.nextInt(100) + var j = 0 + var string = "" + while (j < count && i > 0) { + if (i % 10000 == 0) { logError(s"Wrote record $i") } + string = string + i + "\n" + j += 1 + i -= 1 + } + + val uuid = UUID.randomUUID().toString + val fileName = new File(stagingDir, uuid) + stringToFile(fileName, string) + fileName.renameTo(new File(inputDir, uuid)) + val sleep = Random.nextInt(100) + Thread.sleep(sleep) + } + + logError("== DONE WRITING ==") + var done = false + while (!done) { + try { + stream.processAllAvailable() + done = true + } catch { + case NonFatal(_) => + } + } + + continue = false + stream.stop() + } + } + writer.start() + + val input = sqlContext.read.format("text").stream(inputDir) + def startStream(): ContinuousQuery = input + .repartition(5) + .as[String] + .mapPartitions { iter => + val rand = Random.nextInt(100) + if (rand < 5) { sys.error("failure") } + iter.map(_.toLong) + } + .write + .format("parquet") + .option("checkpointLocation", checkpoint) + .startStream(outputDir) + + var failures = 0 + val streamThread = new Thread("stream runner") { + while (continue) { + if (failures % 10 == 0) { logError(s"Query restart #$failures") } + stream = startStream() + + try { + stream.awaitTermination() + } catch { + case ce: ContinuousQueryException => + failures += 1 + } + } + } + + streamThread.join() + + logError(s"Stream restarted $failures times.") + assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + } +} From 5dfc01976bb0d72489620b4f32cc12d620bb6260 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 23 Mar 2016 13:34:22 -0700 Subject: [PATCH 26/26] [SPARK-14014][SQL] Replace existing catalog with SessionCatalog ## What changes were proposed in this pull request? `SessionCatalog`, introduced in #11750, is a catalog that keeps track of temporary functions and tables, and delegates metastore operations to `ExternalCatalog`. This functionality overlaps a lot with the existing `analysis.Catalog`. As of this commit, `SessionCatalog` and `ExternalCatalog` will no longer be dead code. There are still things that need to be done after this patch, namely: - SPARK-14013: Properly implement temporary functions in `SessionCatalog` - SPARK-13879: Decide which DDL/DML commands to support natively in Spark - SPARK-?????: Implement the ones we do want to support through `SessionCatalog`. - SPARK-?????: Merge SQL/HiveContext ## How was this patch tested? This is largely a refactoring task so there are no new tests introduced. The particularly relevant tests are `SessionCatalogSuite` and `ExternalCatalogSuite`. Author: Andrew Or Author: Yin Huai Closes #11836 from andrewor14/use-session-catalog. 
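
For readers following the refactoring, the sketch below shows the layering the patch moves to: temporary tables live in the session-level catalog and shadow metastore tables, while everything else is delegated to an external, metastore-backed catalog. The types here (`ExternalTableStore`, `SessionLevelCatalog`) are simplified illustrations, not the actual `ExternalCatalog`/`SessionCatalog` interfaces.

```scala
import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in for a metastore-backed catalog (not the real ExternalCatalog).
trait ExternalTableStore {
  def lookup(db: String, table: String): Option[String]   // returns a plan description
}

// Sketch of the session-level layer: temp tables first, then the external catalog.
class SessionLevelCatalog(external: ExternalTableStore) {
  private val tempTables = new ConcurrentHashMap[String, String]()
  private var currentDb = "default"

  def setCurrentDatabase(db: String): Unit = { currentDb = db }

  def registerTempTable(name: String, plan: String): Unit = {
    tempTables.put(name, plan)
  }

  /** Unqualified names resolve against temp tables before the current database. */
  def lookupRelation(name: String, db: Option[String] = None): Option[String] = {
    if (db.isEmpty && tempTables.containsKey(name)) Option(tempTables.get(name))
    else external.lookup(db.getOrElse(currentDb), name)
  }
}
```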
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 3 +- project/MimaExcludes.scala | 3 + python/pyspark/sql/context.py | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 20 +- .../spark/sql/catalyst/analysis/Catalog.scala | 218 -------- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../catalyst/catalog/InMemoryCatalog.scala | 35 +- .../sql/catalyst/catalog/SessionCatalog.scala | 123 +++-- .../sql/catalyst/catalog/interface.scala | 2 + .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../sql/catalyst/analysis/AnalysisTest.scala | 23 +- .../analysis/DecimalPrecisionSuite.scala | 25 +- .../catalyst/catalog/CatalogTestCases.scala | 3 +- .../catalog/SessionCatalogSuite.scala | 20 +- .../BooleanSimplificationSuite.scala | 11 +- .../optimizer/EliminateSortsSuite.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 73 ++- .../sql/execution/command/commands.scala | 8 +- .../spark/sql/execution/datasources/ddl.scala | 24 +- .../sql/execution/datasources/rules.scala | 10 +- .../spark/sql/internal/SessionState.scala | 7 +- .../apache/spark/sql/ListTablesSuite.scala | 15 +- .../apache/spark/sql/SQLContextSuite.scala | 9 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +- .../parquet/ParquetQuerySuite.scala | 6 +- .../apache/spark/sql/test/SQLTestUtils.scala | 4 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 3 +- .../sql/hive/thriftserver/CliSuite.scala | 5 +- .../apache/spark/sql/hive/HiveCatalog.scala | 5 +- .../apache/spark/sql/hive/HiveContext.scala | 498 ++++++++++-------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 60 +-- .../spark/sql/hive/HiveSessionCatalog.scala | 104 ++++ .../spark/sql/hive/HiveSessionState.scala | 10 +- .../spark/sql/hive/client/HiveClient.scala | 3 - .../sql/hive/client/HiveClientImpl.scala | 4 - .../hive/execution/CreateTableAsSelect.scala | 4 +- .../hive/execution/CreateViewAsSelect.scala | 4 +- .../hive/execution/InsertIntoHiveTable.scala | 14 +- .../spark/sql/hive/execution/commands.scala | 9 +- .../apache/spark/sql/hive/test/TestHive.scala | 151 ++++-- .../hive/JavaMetastoreDataSourcesSuite.java | 5 +- .../spark/sql/hive/HiveContextSuite.scala | 38 ++ .../sql/hive/HiveMetastoreCatalogSuite.scala | 9 +- .../spark/sql/hive/ListTablesSuite.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 31 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 5 +- .../spark/sql/hive/StatisticsSuite.scala | 3 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 - .../sql/hive/execution/HiveQuerySuite.scala | 16 +- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 24 +- 52 files changed, 919 insertions(+), 783 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 63acbadfa6a16..eef365b42e56d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1817,7 +1817,8 @@ test_that("approxQuantile() on a DataFrame", { test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table not found: blah", retError), TRUE) + expect_equal(grepl("Table not found", retError), TRUE) + 
expect_equal(grepl("blah", retError), TRUE) }) irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 42eafcb0f52d0..915898389ca00 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -562,6 +562,9 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") + ) ++ Seq( + // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") ) ++ Seq( // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9c2f6a3c5660f..4008332c84d0a 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -554,7 +554,7 @@ def tableNames(self, dbName=None): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlContext.tableNames("db") + >>> "table1" in sqlContext.tableNames("default") True """ if dbName is None: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5951a70c4809a..178e9402faa74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -36,23 +37,22 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types._ /** - * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyzer needs only to resolve attribute - * references. + * A trivial [[Analyzer]] with an dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]]. + * Used for testing when all relations are already filled in and the analyzer needs only + * to resolve attribute references. */ object SimpleAnalyzer - extends Analyzer( - EmptyCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true)) +class SimpleAnalyzer(conf: CatalystConf) + extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf) /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and - * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and - * a [[FunctionRegistry]]. 
+ * [[UnresolvedRelation]]s into fully typed objects using information in a + * [[SessionCatalog]] and a [[FunctionRegistry]]. */ class Analyzer( - catalog: Catalog, + catalog: SessionCatalog, registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala deleted file mode 100644 index 2f0a4dbc107aa..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf, TableIdentifier} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} - - -/** - * An interface for looking up relations by name. Used by an [[Analyzer]]. - */ -trait Catalog { - - val conf: CatalystConf - - def tableExists(tableIdent: TableIdentifier): Boolean - - def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan - - def setCurrentDatabase(databaseName: String): Unit = { - throw new UnsupportedOperationException - } - - /** - * Returns tuples of (tableName, isTemporary) for all tables in the given database. - * isTemporary is a Boolean value indicates if a table is a temporary or not. - */ - def getTables(databaseName: Option[String]): Seq[(String, Boolean)] - - def refreshTable(tableIdent: TableIdentifier): Unit - - def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit - - def unregisterTable(tableIdent: TableIdentifier): Unit - - def unregisterAllTables(): Unit - - /** - * Get the table name of TableIdentifier for temporary tables. - */ - protected def getTableName(tableIdent: TableIdentifier): String = { - // It is not allowed to specify database name for temporary tables. - // We check it here and throw exception if database is defined. - if (tableIdent.database.isDefined) { - throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + - "for temporary tables. If the table name has dots (.) 
in it, please quote the " + - "table name with backticks (`).") - } - if (conf.caseSensitiveAnalysis) { - tableIdent.table - } else { - tableIdent.table.toLowerCase - } - } -} - -class SimpleCatalog(val conf: CatalystConf) extends Catalog { - private[this] val tables = new ConcurrentHashMap[String, LogicalPlan] - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - tables.put(getTableName(tableIdent), plan) - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - tables.remove(getTableName(tableIdent)) - } - - override def unregisterAllTables(): Unit = { - tables.clear() - } - - override def tableExists(tableIdent: TableIdentifier): Boolean = { - tables.containsKey(getTableName(tableIdent)) - } - - override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - val tableName = getTableName(tableIdent) - val table = tables.get(tableName) - if (table == null) { - throw new AnalysisException("Table not found: " + tableName) - } - val qualifiedTable = SubqueryAlias(tableName, table) - - // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are - // properly qualified with this alias. - alias - .map(a => SubqueryAlias(a, qualifiedTable)) - .getOrElse(qualifiedTable) - } - - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.keySet().asScala.map(_ -> true).toSeq - } - - override def refreshTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } -} - -/** - * A trait that can be mixed in with other Catalogs allowing specific tables to be overridden with - * new logical plans. This can be used to bind query result to virtual tables, or replace tables - * with in-memory cached versions. Note that the set of overrides is stored in memory and thus - * lost when the JVM exits. - */ -trait OverrideCatalog extends Catalog { - private[this] val overrides = new ConcurrentHashMap[String, LogicalPlan] - - private def getOverriddenTable(tableIdent: TableIdentifier): Option[LogicalPlan] = { - if (tableIdent.database.isDefined) { - None - } else { - Option(overrides.get(getTableName(tableIdent))) - } - } - - abstract override def tableExists(tableIdent: TableIdentifier): Boolean = { - getOverriddenTable(tableIdent) match { - case Some(_) => true - case None => super.tableExists(tableIdent) - } - } - - abstract override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - getOverriddenTable(tableIdent) match { - case Some(table) => - val tableName = getTableName(tableIdent) - val qualifiedTable = SubqueryAlias(tableName, table) - - // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes - // are properly qualified with this alias. 
- alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) - - case None => super.lookupRelation(tableIdent, alias) - } - } - - abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - overrides.keySet().asScala.map(_ -> true).toSeq ++ super.getTables(databaseName) - } - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - overrides.put(getTableName(tableIdent), plan) - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - if (tableIdent.database.isEmpty) { - overrides.remove(getTableName(tableIdent)) - } - } - - override def unregisterAllTables(): Unit = { - overrides.clear() - } -} - -/** - * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyzer needs only to resolve attribute references. - */ -object EmptyCatalog extends Catalog { - - override val conf: CatalystConf = EmptyConf - - override def tableExists(tableIdent: TableIdentifier): Boolean = { - throw new UnsupportedOperationException - } - - override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - throw new UnsupportedOperationException - } - - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - throw new UnsupportedOperationException - } - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - throw new UnsupportedOperationException - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } - - override def unregisterAllTables(): Unit = { - throw new UnsupportedOperationException - } - - override def refreshTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 9518309fbf8ea..e73d367a730e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -34,7 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null) /** - * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. + * Holds the name of a relation that has yet to be looked up in a catalog. 
*/ case class UnresolvedRelation( tableIdentifier: TableIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 7ead1ddebe852..e216fa552804b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -52,37 +52,34 @@ class InMemoryCatalog extends ExternalCatalog { names.filter { funcName => regex.pattern.matcher(funcName).matches() } } - private def existsFunction(db: String, funcName: String): Boolean = { + private def functionExists(db: String, funcName: String): Boolean = { requireDbExists(db) catalog(db).functions.contains(funcName) } - private def existsTable(db: String, table: String): Boolean = { - requireDbExists(db) - catalog(db).tables.contains(table) - } - - private def existsPartition(db: String, table: String, spec: TablePartitionSpec): Boolean = { + private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = { requireTableExists(db, table) catalog(db).tables(table).partitions.contains(spec) } private def requireFunctionExists(db: String, funcName: String): Unit = { - if (!existsFunction(db, funcName)) { - throw new AnalysisException(s"Function '$funcName' does not exist in database '$db'") + if (!functionExists(db, funcName)) { + throw new AnalysisException( + s"Function not found: '$funcName' does not exist in database '$db'") } } private def requireTableExists(db: String, table: String): Unit = { - if (!existsTable(db, table)) { - throw new AnalysisException(s"Table '$table' does not exist in database '$db'") + if (!tableExists(db, table)) { + throw new AnalysisException( + s"Table not found: '$table' does not exist in database '$db'") } } private def requirePartitionExists(db: String, table: String, spec: TablePartitionSpec): Unit = { - if (!existsPartition(db, table, spec)) { + if (!partitionExists(db, table, spec)) { throw new AnalysisException( - s"Partition does not exist in database '$db' table '$table': '$spec'") + s"Partition not found: database '$db' table '$table' does not contain: '$spec'") } } @@ -159,7 +156,7 @@ class InMemoryCatalog extends ExternalCatalog { ignoreIfExists: Boolean): Unit = synchronized { requireDbExists(db) val table = tableDefinition.name.table - if (existsTable(db, table)) { + if (tableExists(db, table)) { if (!ignoreIfExists) { throw new AnalysisException(s"Table '$table' already exists in database '$db'") } @@ -173,7 +170,7 @@ class InMemoryCatalog extends ExternalCatalog { table: String, ignoreIfNotExists: Boolean): Unit = synchronized { requireDbExists(db) - if (existsTable(db, table)) { + if (tableExists(db, table)) { catalog(db).tables.remove(table) } else { if (!ignoreIfNotExists) { @@ -200,13 +197,17 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables(table).table } + override def tableExists(db: String, table: String): Boolean = synchronized { + requireDbExists(db) + catalog(db).tables.contains(table) + } + override def listTables(db: String): Seq[String] = synchronized { requireDbExists(db) catalog(db).tables.keySet.toSeq } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - requireDbExists(db) filterPattern(listTables(db), pattern) } @@ -295,7 +296,7 @@ class InMemoryCatalog extends ExternalCatalog { override def createFunction(db: String, func: 
CatalogFunction): Unit = synchronized { requireDbExists(db) - if (existsFunction(db, func.name.funcName)) { + if (functionExists(db, func.name.funcName)) { throw new AnalysisException(s"Function '$func' already exists in '$db' database") } else { catalog(db).functions.put(func.name.funcName, func) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3ac2bcf7e8d03..34265faa74399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -31,17 +32,34 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary * tables and functions of the Spark Session that it belongs to. */ -class SessionCatalog(externalCatalog: ExternalCatalog) { +class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { import ExternalCatalog._ - private[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan] - private[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction] + def this(externalCatalog: ExternalCatalog) { + this(externalCatalog, new SimpleCatalystConf(true)) + } + + protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan] + protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. - private[this] var currentDb = "default" + protected[this] var currentDb = { + val defaultName = "default" + val defaultDbDefinition = CatalogDatabase(defaultName, "default database", "", Map()) + // Initialize default database if it doesn't already exist + createDatabase(defaultDbDefinition, ignoreIfExists = true) + defaultName + } + + /** + * Format table name, taking into account case sensitivity. 
+ */ + protected[this] def formatTableName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase + } // ---------------------------------------------------------------------------- // Databases @@ -105,8 +123,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = tableDefinition.name.database.getOrElse(currentDb) - val newTableDefinition = tableDefinition.copy( - name = TableIdentifier(tableDefinition.name.table, Some(db))) + val table = formatTableName(tableDefinition.name.table) + val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db))) externalCatalog.createTable(db, newTableDefinition, ignoreIfExists) } @@ -121,8 +139,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def alterTable(tableDefinition: CatalogTable): Unit = { val db = tableDefinition.name.database.getOrElse(currentDb) - val newTableDefinition = tableDefinition.copy( - name = TableIdentifier(tableDefinition.name.table, Some(db))) + val table = formatTableName(tableDefinition.name.table) + val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db))) externalCatalog.alterTable(db, newTableDefinition) } @@ -132,7 +150,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def getTable(name: TableIdentifier): CatalogTable = { val db = name.database.getOrElse(currentDb) - externalCatalog.getTable(db, name.table) + val table = formatTableName(name.table) + externalCatalog.getTable(db, table) } // ------------------------------------------------------------- @@ -146,10 +165,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { name: String, tableDefinition: LogicalPlan, ignoreIfExists: Boolean): Unit = { - if (tempTables.containsKey(name) && !ignoreIfExists) { + val table = formatTableName(name) + if (tempTables.containsKey(table) && !ignoreIfExists) { throw new AnalysisException(s"Temporary table '$name' already exists.") } - tempTables.put(name, tableDefinition) + tempTables.put(table, tableDefinition) } /** @@ -166,11 +186,13 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { throw new AnalysisException("rename does not support moving tables across databases") } val db = oldName.database.getOrElse(currentDb) - if (oldName.database.isDefined || !tempTables.containsKey(oldName.table)) { - externalCatalog.renameTable(db, oldName.table, newName.table) + val oldTableName = formatTableName(oldName.table) + val newTableName = formatTableName(newName.table) + if (oldName.database.isDefined || !tempTables.containsKey(oldTableName)) { + externalCatalog.renameTable(db, oldTableName, newTableName) } else { - val table = tempTables.remove(oldName.table) - tempTables.put(newName.table, table) + val table = tempTables.remove(oldTableName) + tempTables.put(newTableName, table) } } @@ -183,10 +205,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = { val db = name.database.getOrElse(currentDb) - if (name.database.isDefined || !tempTables.containsKey(name.table)) { - externalCatalog.dropTable(db, name.table, ignoreIfNotExists) + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.containsKey(table)) { + externalCatalog.dropTable(db, table, ignoreIfNotExists) } else { - tempTables.remove(name.table) + tempTables.remove(table) } } @@ -199,28 +222,43 @@ class SessionCatalog(externalCatalog: 
ExternalCatalog) { */ def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) val relation = - if (name.database.isDefined || !tempTables.containsKey(name.table)) { - val metadata = externalCatalog.getTable(db, name.table) + if (name.database.isDefined || !tempTables.containsKey(table)) { + val metadata = externalCatalog.getTable(db, table) CatalogRelation(db, metadata, alias) } else { - tempTables.get(name.table) + tempTables.get(table) } - val qualifiedTable = SubqueryAlias(name.table, relation) + val qualifiedTable = SubqueryAlias(table, relation) // If an alias was specified by the lookup, wrap the plan in a subquery so that // attributes are properly qualified with this alias. alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) } /** - * List all tables in the specified database, including temporary tables. + * Return whether a table with the specified name exists. + * + * Note: If a database is explicitly specified, then this will return whether the table + * exists in that particular database instead. In that case, even if there is a temporary + * table with the same name, we will return false if the specified database does not + * contain the table. */ - def listTables(db: String): Seq[TableIdentifier] = { - val dbTables = externalCatalog.listTables(db).map { t => TableIdentifier(t, Some(db)) } - val _tempTables = tempTables.keys().asScala.map { t => TableIdentifier(t) } - dbTables ++ _tempTables + def tableExists(name: TableIdentifier): Boolean = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.containsKey(table)) { + externalCatalog.tableExists(db, table) + } else { + true // it's a temporary table + } } + /** + * List all tables in the specified database, including temporary tables. + */ + def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") + /** * List all matching tables in the specified database, including temporary tables. */ @@ -234,6 +272,19 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { dbTables ++ _tempTables } + /** + * Refresh the cache entry for a metastore table, if any. + */ + def refreshTable(name: TableIdentifier): Unit = { /* no-op */ } + + /** + * Drop all existing temporary tables. + * For testing only. + */ + def clearTempTables(): Unit = { + tempTables.clear() + } + /** * Return a temporary table exactly as it was stored. * For testing only. 
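The temporary-table precedence documented above reduces to a simple rule: a database-qualified identifier always resolves against the external catalog, while an unqualified identifier is checked against the session's temporary tables first. A minimal illustrative sketch of that rule (not part of the patch; the helper name and the plain Set standing in for tempTables are hypothetical, and the lower-casing mirrors formatTableName under case-insensitive analysis):

    import org.apache.spark.sql.catalyst.TableIdentifier

    // Hypothetical helper: does an identifier resolve to a session-local temporary table?
    // Only unqualified names ever see the temp-table map; qualified names go straight to
    // the external catalog.
    def resolvesToTempTable(name: TableIdentifier, tempTableNames: Set[String]): Boolean =
      name.database.isEmpty && tempTableNames.contains(name.table.toLowerCase)

    // resolvesToTempTable(TableIdentifier("tbl3"), Set("tbl3"))              // true
    // resolvesToTempTable(TableIdentifier("tbl3", Some("db2")), Set("tbl3")) // false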
@@ -263,7 +314,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.createPartitions(db, tableName.table, parts, ignoreIfExists) + val table = formatTableName(tableName.table) + externalCatalog.createPartitions(db, table, parts, ignoreIfExists) } /** @@ -275,7 +327,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { parts: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.dropPartitions(db, tableName.table, parts, ignoreIfNotExists) + val table = formatTableName(tableName.table) + externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists) } /** @@ -289,7 +342,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.renamePartitions(db, tableName.table, specs, newSpecs) + val table = formatTableName(tableName.table) + externalCatalog.renamePartitions(db, table, specs, newSpecs) } /** @@ -303,7 +357,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.alterPartitions(db, tableName.table, parts) + val table = formatTableName(tableName.table) + externalCatalog.alterPartitions(db, table, parts) } /** @@ -312,7 +367,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.getPartition(db, tableName.table, spec) + val table = formatTableName(tableName.table) + externalCatalog.getPartition(db, table, spec) } /** @@ -321,7 +377,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog) { */ def listPartitions(tableName: TableIdentifier): Seq[CatalogTablePartition] = { val db = tableName.database.getOrElse(currentDb) - externalCatalog.listPartitions(db, tableName.table) + val table = formatTableName(tableName.table) + externalCatalog.listPartitions(db, table) } // ---------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c4e49614c5c35..34803133f6a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -91,6 +91,8 @@ abstract class ExternalCatalog { def getTable(db: String, table: String): CatalogTable + def tableExists(db: String, table: String): Boolean + def listTables(db: String): Seq[String] def listTables(db: String, pattern: String): Seq[String] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 8b568b6dd6acd..afc2f327df997 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -161,14 +161,10 @@ class AnalysisSuite extends AnalysisTest { } test("resolve 
relations") { - assertAnalysisError( - UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) - + assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq()) checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) - checkAnalysis( UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) - checkAnalysis( UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 39166c4f8ef73..6fa4beed99267 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,26 +18,21 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ trait AnalysisTest extends PlanTest { - val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { - val caseSensitiveConf = new SimpleCatalystConf(caseSensitiveAnalysis = true) - val caseInsensitiveConf = new SimpleCatalystConf(caseSensitiveAnalysis = false) + protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true) + protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) - val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) - val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - - caseSensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) - caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) - - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubqueryAliases :: Nil - } -> - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { + val conf = new SimpleCatalystConf(caseSensitive) + val catalog = new SessionCatalog(new InMemoryCatalog, conf) + catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true) + new Analyzer(catalog, EmptyFunctionRegistry, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 9aa685e1e8f55..31501864a8e13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -30,11 +31,11 @@ import org.apache.spark.sql.types._ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { - val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) - val catalog = new SimpleCatalog(conf) - val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) + private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) + private val catalog = new SessionCatalog(new InMemoryCatalog, conf) + private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) - val relation = LocalRelation( + private val relation = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), @@ -43,15 +44,15 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { AttributeReference("b", DoubleType)() ) - val i: Expression = UnresolvedAttribute("i") - val d1: Expression = UnresolvedAttribute("d1") - val d2: Expression = UnresolvedAttribute("d2") - val u: Expression = UnresolvedAttribute("u") - val f: Expression = UnresolvedAttribute("f") - val b: Expression = UnresolvedAttribute("b") + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") before { - catalog.registerTable(TableIdentifier("table"), relation) + catalog.createTempTable("table", relation, ignoreIfExists = true) } private def checkType(expression: Expression, expectedType: DataType): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index a1ea61920dd68..277c2d717e3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -225,13 +225,14 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("list tables without pattern") { val catalog = newBasicCatalog() + intercept[AnalysisException] { catalog.listTables("unknown_db") } assert(catalog.listTables("db1").toSet == Set.empty) assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) } test("list tables with pattern") { val catalog = newBasicCatalog() - intercept[AnalysisException] { catalog.listTables("unknown_db") } + intercept[AnalysisException] { catalog.listTables("unknown_db", "*") } assert(catalog.listTables("db1", "*").toSet == Set.empty) assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index e1973ee258235..74e995cc5b4b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -397,6 +397,24 @@ class SessionCatalogSuite extends SparkFunSuite { TableIdentifier("tbl1", Some("db2")), alias = Some(alias)) == 
relationWithAlias) } + test("table exists") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) + assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) + assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) + // If database is explicitly specified, do not check temporary tables + val tempTable = Range(1, 10, 1, 10, Seq()) + catalog.createTempTable("tbl3", tempTable, ignoreIfExists = false) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + // If database is not explicitly specified, check the current database + catalog.setCurrentDatabase("db2") + assert(catalog.tableExists(TableIdentifier("tbl1"))) + assert(catalog.tableExists(TableIdentifier("tbl2"))) + assert(catalog.tableExists(TableIdentifier("tbl3"))) + } + test("list tables without pattern") { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable = Range(1, 10, 2, 10, Seq()) @@ -429,7 +447,7 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalog.listTables("db2", "*1").toSet == Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) intercept[AnalysisException] { - catalog.listTables("unknown_db") + catalog.listTables("unknown_db", "*") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 2ab31eea8ab38..e2c76b700f51c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -137,11 +138,11 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } - private val caseInsensitiveAnalyzer = - new Analyzer( - EmptyCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = false)) + private val caseInsensitiveConf = new SimpleCatalystConf(false) + private val caseInsensitiveAnalyzer = new Analyzer( + new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf), + EmptyFunctionRegistry, + caseInsensitiveConf) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { val plan = caseInsensitiveAnalyzer.execute( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index a4c8d1c6d2aa8..3824c675630c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf -import 
org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, SimpleCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._ class EliminateSortsSuite extends PlanTest { val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) - val catalog = new SimpleCatalog(conf) + val catalog = new SessionCatalog(new InMemoryCatalog, conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 853a74c827d47..e413e77bc1349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,13 +25,14 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} @@ -65,13 +66,14 @@ class SQLContext private[sql]( @transient val sparkContext: SparkContext, @transient protected[sql] val cacheManager: CacheManager, @transient private[sql] val listener: SQLListener, - val isRootContext: Boolean) + val isRootContext: Boolean, + @transient private[sql] val externalCatalog: ExternalCatalog) extends Logging with Serializable { self => - def this(sparkContext: SparkContext) = { - this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true) + def this(sc: SparkContext) = { + this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog) } def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) @@ -109,7 +111,8 @@ class SQLContext private[sql]( sparkContext = sparkContext, cacheManager = cacheManager, listener = listener, - isRootContext = false) + isRootContext = false, + externalCatalog = externalCatalog) } /** @@ -186,6 +189,12 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + // Extract `spark.sql.*` entries and put it in our SQLConf. + // Subclasses may additionally set these entries in other confs. 
+ SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) => + setConf(k, v) + } + protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) @@ -199,30 +208,6 @@ class SQLContext private[sql]( sparkContext.addJar(path) } - { - // We extract spark sql settings from SparkContext's conf and put them to - // Spark SQL's conf. - // First, we populate the SQLConf (conf). So, we can make sure that other values using - // those settings in their construction can get the correct settings. - // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version - // and spark.sql.hive.metastore.jars to get correctly constructed. - val properties = new Properties - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) - case _ => - } - // We directly put those settings to conf to avoid of calling setConf, which may have - // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive - // get constructed. If we call setConf directly, the constructed metadataHive may have - // wrong settings, or the construction may fail. - conf.setConf(properties) - // After we have populated SQLConf, we call setConf to populate other confs in the subclass - // (e.g. hiveconf in HiveContext). - properties.asScala.foreach { - case (key, value) => setConf(key, value) - } - } - /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into @@ -683,8 +668,10 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - sessionState.catalog.registerTable( - sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan) + sessionState.catalog.createTempTable( + sessionState.sqlParser.parseTableIdentifier(tableName).table, + df.logicalPlan, + ignoreIfExists = true) } /** @@ -697,7 +684,7 @@ class SQLContext private[sql]( */ def dropTempTable(tableName: String): Unit = { cacheManager.tryUncacheQuery(table(tableName)) - sessionState.catalog.unregisterTable(TableIdentifier(tableName)) + sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true) } /** @@ -824,9 +811,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(): Array[String] = { - sessionState.catalog.getTables(None).map { - case (tableName, _) => tableName - }.toArray + tableNames(sessionState.catalog.getCurrentDatabase) } /** @@ -836,9 +821,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(databaseName: String): Array[String] = { - sessionState.catalog.getTables(Some(databaseName)).map { - case (tableName, _) => tableName - }.toArray + sessionState.catalog.listTables(databaseName).map(_.table).toArray } @transient @@ -1025,4 +1008,18 @@ object SQLContext { } sqlListener.get() } + + /** + * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]]. 
+ */ + private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = { + val properties = new Properties + sparkConf.getAll.foreach { case (key, value) => + if (key.startsWith("spark.sql")) { + properties.setProperty(key, value) + } + } + properties + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 59c3ffcf488c7..964f0a7a7b4e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -339,10 +339,12 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma override def run(sqlContext: SQLContext): Seq[Row] = { // Since we need to return a Seq of rows, we will call getTables directly // instead of calling tables in sqlContext. - val rows = sqlContext.sessionState.catalog.getTables(databaseName).map { - case (tableName, isTemporary) => Row(tableName, isTemporary) + val catalog = sqlContext.sessionState.catalog + val db = databaseName.getOrElse(catalog.getCurrentDatabase) + val rows = catalog.listTables(db).map { t => + val isTemp = t.database.isEmpty + Row(t.table, isTemp) } - rows } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 9e8e0352db644..24923bbb10c74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -93,15 +93,21 @@ case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { + if (tableIdent.database.isDefined) { + throw new AnalysisException( + s"Temporary table '$tableIdent' should not have specified a database") + } + def run(sqlContext: SQLContext): Seq[Row] = { val dataSource = DataSource( sqlContext, userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) - sqlContext.sessionState.catalog.registerTable( - tableIdent, - Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan) + sqlContext.sessionState.catalog.createTempTable( + tableIdent.table, + Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan, + ignoreIfExists = true) Seq.empty[Row] } @@ -115,6 +121,11 @@ case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { + if (tableIdent.database.isDefined) { + throw new AnalysisException( + s"Temporary table '$tableIdent' should not have specified a database") + } + override def run(sqlContext: SQLContext): Seq[Row] = { val df = Dataset.ofRows(sqlContext, query) val dataSource = DataSource( @@ -124,9 +135,10 @@ case class CreateTempTableUsingAsSelect( bucketSpec = None, options = options) val result = dataSource.write(mode, df) - sqlContext.sessionState.catalog.registerTable( - tableIdent, - Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan) + sqlContext.sessionState.catalog.createTempTable( + tableIdent.table, + Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan, + ignoreIfExists = true) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 63f0e4f8c96ac..28ac4583e9b25 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} /** @@ -99,7 +101,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { +private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) + extends (LogicalPlan => Unit) { + def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } def apply(plan: LogicalPlan): Unit = { @@ -139,7 +143,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - r.schema, part.keySet.toSeq, catalog.conf.caseSensitiveAnalysis) + r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis) // Get all input data source relations of the query. val srcRelations = query.collect { @@ -190,7 +194,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis) for { spec <- c.bucketSpec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index e6be0ab3bc420..e5f02caabcca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.internal import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -45,7 +46,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Internal catalog for managing table and database states. */ - lazy val catalog: Catalog = new SimpleCatalog(conf) + lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf) /** * Internal catalog for managing functions registered by the user. 
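The SessionCatalog constructed here is the same one the updated test suites build by hand. As a quick orientation, a minimal sketch of that wiring (illustrative only; `somePlan` is a placeholder for whatever LogicalPlan you want to register):

    import org.apache.spark.sql.catalyst.SimpleCatalystConf
    import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}

    // A session catalog layered over a purely in-memory external catalog; the
    // CatalystConf decides how table names are normalised (case sensitivity).
    val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
    val catalog = new SessionCatalog(new InMemoryCatalog, conf)

    // Temporary tables live only in the SessionCatalog, never in the external catalog:
    // catalog.createTempTable("t", somePlan, ignoreIfExists = true)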
@@ -68,7 +69,7 @@ private[sql] class SessionState(ctx: SQLContext) { DataSourceAnalysis :: (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) - override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog)) + override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2820e4fa23e13..bb54c525cb76d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -33,7 +33,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) } test("get all tables") { @@ -45,20 +46,22 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } - test("getting all Tables with a database name has no impact on returned table names") { + test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("default").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in default").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 2ad92b52c4ff0..2f62ad4850dee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ +class SQLContextSuite extends SparkFunSuite with SharedSparkContext { object DummyRule extends Rule[LogicalPlan] { def apply(p: LogicalPlan): LogicalPlan = p @@ -78,4 +78,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ sqlContext.experimental.extraOptimizations = Seq(DummyRule) assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } + + test("SQLContext can access `spark.sql.*` configs") { + sc.conf.set("spark.sql.with.or.without.you", "my love") + val sqlContext = new SQLContext(sc) + 
assert(sqlContext.getConf("spark.sql.with.or.without.you") == "my love") + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2733ae7d98c1e..bd13474e738fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1397,12 +1397,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4699 case sensitivity SQL query") { - sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) - val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.toDF().registerTempTable("testTable1") - checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) + val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE) + try { + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + rdd.toDF().registerTempTable("testTable1") + checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) + } finally { + sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig) + } } test("SPARK-6145: ORDER BY test for nested fields") { @@ -1676,7 +1680,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .format("parquet") .save(path) - val message = intercept[AnalysisException] { + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { sqlContext.sql( s""" |CREATE TEMPORARY TABLE db.t @@ -1686,9 +1691,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |) """.stripMargin) }.getMessage - assert(message.contains("Specifying database name or other qualifiers are not allowed")) - // If you use backticks to quote the name of a temporary table having dot in it. + // If you use backticks to quote the name then it's OK. 
sqlContext.sql( s""" |CREATE TEMPORARY TABLE `db.t` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f8166c7ddc4da..2f806ebba6f96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -51,7 +51,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { @@ -61,7 +62,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true) } test("self-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48358566e38e..80a85a6615974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -189,8 +189,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + sqlContext.sessionState.catalog.setCurrentDatabase(db) + try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7fe31b0025272..57693284b01df 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -150,7 +150,8 @@ private[hive] object SparkSQLCLIDriver extends Logging { } if (sessionState.database != null) { - SparkSQLEnv.hiveContext.runSqlHive(s"USE ${sessionState.database}") + SparkSQLEnv.hiveContext.sessionState.catalog.setCurrentDatabase( + s"${sessionState.database}") } // Execute -i init files (always in silent mode) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 032965d0d9c28..8e1ebe2937d23 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -193,10 +193,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { ) runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( - "" - -> "OK", - "" - -> "hive_test" + "" -> "hive_test" ) } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala index 491f2aebb4f4a..0722fb02a8f9d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala @@ -85,7 +85,6 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit withClient { getTable(db, table) } } - // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- @@ -182,6 +181,10 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit client.getTable(db, table) } + override def tableExists(db: String, table: String): Boolean = withClient { + client.getTableOption(db, table).isDefined + } + override def listTables(db: String): Seq[String] = withClient { requireDbExists(db) client.listTables(db) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 914f8e9a9893a..ca3ce43591f5f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -28,6 +28,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal @@ -38,7 +39,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql._ @@ -52,6 +53,7 @@ import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.internal.SQLConf.SQLConfEntry._ import org.apache.spark.sql.types._ @@ -67,7 +69,7 @@ private[hive] case class CurrentDatabase(ctx: HiveContext) override def foldable: Boolean = true override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - UTF8String.fromString(ctx.metadataHive.currentDatabase) + UTF8String.fromString(ctx.sessionState.catalog.getCurrentDatabase) } } @@ -81,15 +83,31 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, listener: SQLListener, - @transient private val execHive: HiveClientImpl, - @transient private val metaHive: HiveClient, - isRootContext: Boolean) - extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { + @transient private[hive] val executionHive: HiveClientImpl, + @transient private[hive] val metadataHive: HiveClient, + isRootContext: Boolean, + @transient private[sql] val hiveCatalog: HiveCatalog) + extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging { self => + private def this(sc: 
SparkContext, execHive: HiveClientImpl, metaHive: HiveClient) { + this( + sc, + new CacheManager, + SQLContext.createListenerAndUI(sc), + execHive, + metaHive, + true, + new HiveCatalog(metaHive)) + } + def this(sc: SparkContext) = { - this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), null, null, true) + this( + sc, + HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), + HiveContext.newClientForMetadata(sc.conf, sc.hadoopConfiguration)) } + def this(sc: JavaSparkContext) = this(sc.sc) import org.apache.spark.sql.hive.HiveContext._ @@ -106,9 +124,10 @@ class HiveContext private[hive]( sc = sc, cacheManager = cacheManager, listener = listener, - execHive = executionHive.newSession(), - metaHive = metadataHive.newSession(), - isRootContext = false) + executionHive = executionHive.newSession(), + metadataHive = metadataHive.newSession(), + isRootContext = false, + hiveCatalog = hiveCatalog) } @transient @@ -149,41 +168,6 @@ class HiveContext private[hive]( */ protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) - /** - * The version of the hive client that will be used to communicate with the metastore. Note that - * this does not necessarily need to be the same version of Hive that is used internally by - * Spark SQL for execution. - */ - protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) - - /** - * The location of the jars that should be used to instantiate the HiveMetastoreClient. This - * property can be one of three options: - * - a classpath in the standard format for both hive and hadoop. - * - builtin - attempt to discover the jars that were used to load Spark SQL and use those. This - * option is only valid when using the execution version of Hive. - * - maven - download the correct version of hive on demand from maven. - */ - protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) - - /** - * A comma separated list of class prefixes that should be loaded using the classloader that - * is shared between Spark SQL and a specific version of Hive. An example of classes that should - * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need - * to be shared are those that interact with classes that are already shared. For example, - * custom appenders that are used by log4j. - */ - protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = - getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") - - /** - * A comma separated list of class prefixes that should explicitly be reloaded for each version - * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a - * prefix that typically would be shared (i.e. org.apache.spark.*) - */ - protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = - getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") - /* * hive thrift server use background spark sql thread pool to execute sql queries */ @@ -195,29 +179,6 @@ class HiveContext private[hive]( @transient protected[sql] lazy val substitutor = new VariableSubstitution() - /** - * The copy of the hive client that is used for execution. Currently this must always be - * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the - * client is used for execution related tasks like registering temporary functions or ensuring - * that the ThreadLocal SessionState is correctly populated. 
This copy of Hive is *not* used - * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. - */ - @transient - protected[hive] lazy val executionHive: HiveClientImpl = if (execHive != null) { - execHive - } else { - logInfo(s"Initializing execution hive, version $hiveExecutionVersion") - val loader = new IsolatedClientLoader( - version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - sparkConf = sc.conf, - execJars = Seq(), - hadoopConf = sc.hadoopConfiguration, - config = newTemporaryConfiguration(useInMemoryDerby = true), - isolationOn = false, - baseClassLoader = Utils.getContextOrSparkClassLoader) - loader.createClient().asInstanceOf[HiveClientImpl] - } - /** * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. * - allow SQL11 keywords to be used as identifiers @@ -228,111 +189,6 @@ class HiveContext private[hive]( defaultOverrides() - /** - * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. - * The version of the Hive client that is used here must match the metastore that is configured - * in the hive-site.xml file. - */ - @transient - protected[hive] lazy val metadataHive: HiveClient = if (metaHive != null) { - metaHive - } else { - val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) - - // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options - // into the isolated client loader - val metadataConf = new HiveConf(sc.hadoopConfiguration, classOf[HiveConf]) - - val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") - logInfo("default warehouse location is " + defaultWarehouseLocation) - - // `configure` goes second to override other settings. - val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure - - val isolatedLoader = if (hiveMetastoreJars == "builtin") { - if (hiveExecutionVersion != hiveMetastoreVersion) { - throw new IllegalArgumentException( - "Builtin jars can only be used when hive execution version == hive metastore version. " + - s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") - } - - // We recursively find all jars in the class loader chain, - // starting from the given classLoader. - def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { - case null => Array.empty[URL] - case urlClassLoader: URLClassLoader => - urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) - case other => allJars(other.getParent) - } - - val classLoader = Utils.getContextOrSparkClassLoader - val jars = allJars(classLoader) - if (jars.length == 0) { - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") - } - - logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") - new IsolatedClientLoader( - version = metaVersion, - sparkConf = sc.conf, - execJars = jars.toSeq, - hadoopConf = sc.hadoopConfiguration, - config = allConfig, - isolationOn = true, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } else if (hiveMetastoreJars == "maven") { - // TODO: Support for loading the jars from an already downloaded location. 
- logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = hiveMetastoreVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sc.conf, - hadoopConf = sc.hadoopConfiguration, - config = allConfig, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } else { - // Convert to files and expand any directories. - val jars = - hiveMetastoreJars - .split(File.pathSeparator) - .flatMap { - case path if new File(path).getName() == "*" => - val files = new File(path).getParentFile().listFiles() - if (files == null) { - logWarning(s"Hive jar path '$path' does not exist.") - Nil - } else { - files.filter(_.getName().toLowerCase().endsWith(".jar")) - } - case path => - new File(path) :: Nil - } - .map(_.toURI.toURL) - - logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + - s"using ${jars.mkString(":")}") - new IsolatedClientLoader( - version = metaVersion, - sparkConf = sc.conf, - execJars = jars.toSeq, - hadoopConf = sc.hadoopConfiguration, - config = allConfig, - isolationOn = true, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } - isolatedLoader.createClient() - } - protected[sql] override def parseSql(sql: String): LogicalPlan = { executionHive.withHiveState { super.parseSql(substitutor.substitute(hiveconf, sql)) @@ -432,7 +288,7 @@ class HiveContext private[hive]( // recorded in the Hive metastore. // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - sessionState.catalog.client.alterTable( + sessionState.catalog.alterTable( relation.table.copy( properties = relation.table.properties + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) @@ -459,64 +315,10 @@ class HiveContext private[hive]( setConf(entry.key, entry.stringConverter(value)) } - /** Overridden by child classes that need to set configuration before the client init. */ - protected def configure(): Map[String, String] = { - // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch - // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- - // compatibility when users are trying to connecting to a Hive metastore of lower version, - // because these options are expected to be integral values in lower versions of Hive. - // - // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according - // to their output time units. 
- Seq( - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, - ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, - ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, - ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, - ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, - ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS - ).map { case (confVar, unit) => - confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString - }.toMap - } - /** * SQLConf and HiveConf contracts: * - * 1. create a new SessionState for each HiveContext + * 1. create a new o.a.h.hive.ql.session.SessionState for each HiveContext * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be * set in the SQLConf *as well as* in the HiveConf. @@ -600,7 +402,7 @@ class HiveContext private[hive]( } -private[hive] object HiveContext { +private[hive] object HiveContext extends Logging { /** The version of hive used internally by Spark SQL. 
*/ val hiveExecutionVersion: String = "1.2.1" @@ -666,6 +468,242 @@ private[hive] object HiveContext { defaultValue = Some(true), doc = "When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") + /** + * The version of the hive client that will be used to communicate with the metastore. Note that + * this does not necessarily need to be the same version of Hive that is used internally by + * Spark SQL for execution. + */ + private def hiveMetastoreVersion(conf: SQLConf): String = { + conf.getConf(HIVE_METASTORE_VERSION) + } + + /** + * The location of the jars that should be used to instantiate the HiveMetastoreClient. This + * property can be one of three options: + * - a classpath in the standard format for both hive and hadoop. + * - builtin - attempt to discover the jars that were used to load Spark SQL and use those. This + * option is only valid when using the execution version of Hive. + * - maven - download the correct version of hive on demand from maven. + */ + private def hiveMetastoreJars(conf: SQLConf): String = { + conf.getConf(HIVE_METASTORE_JARS) + } + + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + private def hiveMetastoreSharedPrefixes(conf: SQLConf): Seq[String] = { + conf.getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") + } + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. org.apache.spark.*) + */ + private def hiveMetastoreBarrierPrefixes(conf: SQLConf): Seq[String] = { + conf.getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") + } + + /** + * Configurations needed to create a [[HiveClient]]. + */ + private[hive] def hiveClientConfigurations(hiveconf: HiveConf): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. 
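+    //
+    // Illustrative example (not part of this patch): if hive-site.xml sets
+    // `hive.metastore.client.socket.timeout=600s`, then
+    // `hiveconf.getTimeVar(ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT, TimeUnit.SECONDS)` returns 600,
+    // and the plain numeric string "600" is what an older metastore client receives.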
+ Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } + + /** + * Create a [[HiveClient]] used for execution. + * + * Currently this must always be Hive 13 as this is the version of Hive that is packaged + * with Spark SQL. This copy of the client is used for execution related tasks like + * registering temporary functions or ensuring that the ThreadLocal SessionState is + * correctly populated. This copy of Hive is *not* used for storing persistent metadata, + * and only point to a dummy metastore in a temporary directory. 
+ */ + protected[hive] def newClientForExecution( + conf: SparkConf, + hadoopConf: Configuration): HiveClientImpl = { + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") + val loader = new IsolatedClientLoader( + version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + sparkConf = conf, + execJars = Seq(), + hadoopConf = hadoopConf, + config = newTemporaryConfiguration(useInMemoryDerby = true), + isolationOn = false, + baseClassLoader = Utils.getContextOrSparkClassLoader) + loader.createClient().asInstanceOf[HiveClientImpl] + } + + /** + * Create a [[HiveClient]] used to retrieve metadata from the Hive MetaStore. + * + * The version of the Hive client that is used here must match the metastore that is configured + * in the hive-site.xml file. + */ + private def newClientForMetadata(conf: SparkConf, hadoopConf: Configuration): HiveClient = { + val hiveConf = new HiveConf(hadoopConf, classOf[HiveConf]) + val configurations = hiveClientConfigurations(hiveConf) + newClientForMetadata(conf, hiveConf, hadoopConf, configurations) + } + + protected[hive] def newClientForMetadata( + conf: SparkConf, + hiveConf: HiveConf, + hadoopConf: Configuration, + configurations: Map[String, String]): HiveClient = { + val sqlConf = new SQLConf + sqlConf.setConf(SQLContext.getSQLProperties(conf)) + val hiveMetastoreVersion = HiveContext.hiveMetastoreVersion(sqlConf) + val hiveMetastoreJars = HiveContext.hiveMetastoreJars(sqlConf) + val hiveMetastoreSharedPrefixes = HiveContext.hiveMetastoreSharedPrefixes(sqlConf) + val hiveMetastoreBarrierPrefixes = HiveContext.hiveMetastoreBarrierPrefixes(sqlConf) + val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) + + val defaultWarehouseLocation = hiveConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) + + // `configure` goes second to override other settings. + val allConfig = hiveConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configurations + + val isolatedLoader = if (hiveMetastoreJars == "builtin") { + if (hiveExecutionVersion != hiveMetastoreVersion) { + throw new IllegalArgumentException( + "Builtin jars can only be used when hive execution version == hive metastore version. " + + s"Execution: $hiveExecutionVersion != Metastore: $hiveMetastoreVersion. " + + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") + } + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. 
" + + "Please set spark.sql.hive.metastore.jars.") + } + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") + new IsolatedClientLoader( + version = metaVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else if (hiveMetastoreJars == "maven") { + // TODO: Support for loading the jars from an already downloaded location. + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else { + // Convert to files and expand any directories. + val jars = + hiveMetastoreJars + .split(File.pathSeparator) + .flatMap { + case path if new File(path).getName == "*" => + val files = new File(path).getParentFile.listFiles() + if (files == null) { + logWarning(s"Hive jar path '$path' does not exist.") + Nil + } else { + files.filter(_.getName.toLowerCase.endsWith(".jar")) + } + case path => + new File(path) :: Nil + } + .map(_.toURI.toURL) + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + + s"using ${jars.mkString(":")}") + new IsolatedClientLoader( + version = metaVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } + isolatedLoader.createClient() + } + /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(useInMemoryDerby: Boolean): Map[String, String] = { val withInMemoryMode = if (useInMemoryDerby) "memory:" else "" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 27e4cfc103bee..c7066d73631af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.DataTypeParser @@ -98,27 +98,33 @@ private[hive] object HiveSerDe { } -// TODO: replace this with o.a.s.sql.hive.HiveCatalog once we merge SQLContext and HiveContext +/** + * Legacy catalog for interacting with the Hive metastore. + * + * This is still used for things like creating data source tables, but in the future will be + * cleaned up to integrate more nicely with [[HiveCatalog]]. 
+ */ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) - extends Catalog with Logging { + extends Logging { val conf = hive.conf - /** Usages should lock on `this`. */ - protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) - /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { + private def getCurrentDatabase: String = { + hive.sessionState.catalog.getCurrentDatabase + } + + def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( - tableIdent.database.getOrElse(client.currentDatabase).toLowerCase, + tableIdent.database.getOrElse(getCurrentDatabase).toLowerCase, tableIdent.table.toLowerCase) } private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = { QualifiedTableName( - t.name.database.getOrElse(client.currentDatabase).toLowerCase, + t.name.database.getOrElse(getCurrentDatabase).toLowerCase, t.name.table.toLowerCase) } @@ -194,7 +200,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(tableIdent: TableIdentifier): Unit = { + def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. // Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -408,12 +414,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte new Path(new Path(client.getDatabase(dbName).locationUri), tblName).toString } - override def tableExists(tableIdent: TableIdentifier): Boolean = { - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - client.getTableOption(dbName, tblName).isDefined - } - - override def lookupRelation( + def lookupRelation( tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = { val qualifiedTableName = getQualifiedTableName(tableIdent) @@ -555,12 +556,6 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) } - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val db = databaseName.getOrElse(client.currentDatabase) - - client.listTables(db).map(tableName => (tableName, false)) - } - /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. @@ -716,27 +711,6 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte } } - /** - * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. - * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. - */ - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - throw new UnsupportedOperationException - } - - /** - * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. - * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. 
- */ - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } - - override def unregisterAllTables(): Unit = {} - - override def setCurrentDatabase(databaseName: String): Unit = { - client.setCurrentDatabase(databaseName) - } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala new file mode 100644 index 0000000000000..aa44cba4b5641 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + + +class HiveSessionCatalog( + externalCatalog: HiveCatalog, + client: HiveClient, + context: HiveContext, + conf: SQLConf) + extends SessionCatalog(externalCatalog, conf) { + + override def setCurrentDatabase(db: String): Unit = { + super.setCurrentDatabase(db) + client.setCurrentDatabase(db) + } + + override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = { + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.containsKey(table)) { + val newName = name.copy(table = table) + metastoreCatalog.lookupRelation(newName, alias) + } else { + val relation = tempTables.get(table) + val tableWithQualifiers = SubqueryAlias(table, relation) + // If an alias was specified by the lookup, wrap the plan in a subquery so that + // attributes are properly qualified with this alias. + alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + } + } + + // ---------------------------------------------------------------- + // | Methods and fields for interacting with HiveMetastoreCatalog | + // ---------------------------------------------------------------- + + // Catalog for handling data source tables. TODO: This really doesn't belong here since it is + // essentially a cache for metastore tables. However, it relies on a lot of session-specific + // things so it would be a lot of work to split its functionality between HiveSessionCatalog + // and HiveCatalog. We should still do it at some point... 
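+  //
+  // Illustrative example (not part of this patch): after
+  //   catalog.createTempTable("t", plan, ignoreIfExists = true)
+  // `lookupRelation(TableIdentifier("t"), None)` resolves the temporary plan, while
+  // `lookupRelation(TableIdentifier("t", Some("db")), None)` goes through `metastoreCatalog`.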
+ private val metastoreCatalog = new HiveMetastoreCatalog(client, context) + + val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions + val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables + val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts + + override def refreshTable(name: TableIdentifier): Unit = { + metastoreCatalog.refreshTable(name) + } + + def invalidateTable(name: TableIdentifier): Unit = { + metastoreCatalog.invalidateTable(name) + } + + def invalidateCache(): Unit = { + metastoreCatalog.cachedDataSourceTables.invalidateAll() + } + + def createDataSourceTable( + name: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + metastoreCatalog.createDataSourceTable( + name, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options, isExternal) + } + + def hiveDefaultTableFilePath(name: TableIdentifier): String = { + metastoreCatalog.hiveDefaultTableFilePath(name) + } + + // For testing only + private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + val key = metastoreCatalog.getQualifiedTableName(table) + metastoreCatalog.cachedDataSourceTables.getIfPresent(key) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index d9cd96d66f493..caa7f296ed16a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.execution.{python, SparkPlanner} import org.apache.spark.sql.execution.datasources._ @@ -35,9 +35,11 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) } /** - * A metadata catalog that points to the Hive metastore. + * Internal catalog for managing table and database states. */ - override lazy val catalog = new HiveMetastoreCatalog(ctx.metadataHive, ctx) with OverrideCatalog + override lazy val catalog = { + new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, conf) + } /** * Internal catalog for managing functions registered by the user. @@ -61,7 +63,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) DataSourceAnalysis :: (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) - override val extendedCheckRules = Seq(PreWriteCheck(catalog)) + override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index d214e5288eff0..f4d30358cafa8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -53,9 +53,6 @@ private[hive] trait HiveClient { /** Returns the names of tables in the given database that matches the given pattern. 
*/ def listTables(dbName: String, pattern: String): Seq[String] - /** Returns the name of the active database. */ - def currentDatabase: String - /** Sets the name of current database. */ def setCurrentDatabase(databaseName: String): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 928408c52bd23..e4e15d13df658 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -241,10 +241,6 @@ private[hive] class HiveClientImpl( state.err = stream } - override def currentDatabase: String = withHiveState { - state.getCurrentDatabase - } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { if (getDatabaseOption(databaseName).isDefined) { state.setCurrentDatabase(databaseName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 391e2975d0086..5a61eef0f2439 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -69,10 +69,10 @@ case class CreateTableAsSelect( withFormat } - hiveContext.sessionState.catalog.client.createTable(withSchema, ignoreIfExists = false) + hiveContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation - hiveContext.sessionState.catalog.lookupRelation(tableIdentifier, None) match { + hiveContext.sessionState.catalog.lookupRelation(tableIdentifier) match { case r: MetastoreRelation => r } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 8a1cf2caaaaaa..9ff520da1d41d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -56,7 +56,7 @@ private[hive] case class CreateViewAsSelect( case true if orReplace => // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - hiveContext.sessionState.catalog.client.alertView(prepareTable(sqlContext)) + hiveContext.metadataHive.alertView(prepareTable(sqlContext)) case true => // Handles `CREATE VIEW v0 AS SELECT ...`. 
Throws exception when the target view already @@ -66,7 +66,7 @@ private[hive] case class CreateViewAsSelect( "CREATE OR REPLACE VIEW AS") case false => - hiveContext.sessionState.catalog.client.createView(prepareTable(sqlContext)) + hiveContext.metadataHive.createView(prepareTable(sqlContext)) } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 4ffd868242b86..430fa4616fc2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -45,7 +45,7 @@ case class InsertIntoHiveTable( @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val catalog = sc.sessionState.catalog + @transient private lazy val client = sc.metadataHive def output: Seq[Attribute] = Seq.empty @@ -186,8 +186,8 @@ case class InsertIntoHiveTable( // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - catalog.synchronized { - catalog.client.loadDynamicPartitions( + client.synchronized { + client.loadDynamicPartitions( outputPath.toString, qualifiedTableName, orderedPartitionSpec, @@ -202,12 +202,12 @@ case class InsertIntoHiveTable( // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries // scalastyle:on val oldPart = - catalog.client.getPartitionOption( - catalog.client.getTable(table.databaseName, table.tableName), + client.getPartitionOption( + client.getTable(table.databaseName, table.tableName), partitionSpec) if (oldPart.isEmpty || !ifNotExists) { - catalog.client.loadPartition( + client.loadPartition( outputPath.toString, qualifiedTableName, orderedPartitionSpec, @@ -218,7 +218,7 @@ case class InsertIntoHiveTable( } } } else { - catalog.client.loadTable( + client.loadTable( outputPath.toString, // TODO: URI qualifiedTableName, overwrite, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 226b8e179604d..cd26a68f357c7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -71,7 +71,8 @@ case class DropTable( } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.sessionState.catalog.unregisterTable(TableIdentifier(tableName)) + hiveContext.sessionState.catalog.dropTable( + TableIdentifier(tableName), ignoreIfNotExists = true) Seq.empty[Row] } } @@ -142,7 +143,8 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } @@ -200,7 +202,8 @@ case class CreateMetastoreDataSourceAsSelect( val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> + 
hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 19c05f9cb0d9c..11559030374b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -24,6 +24,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ @@ -35,9 +37,11 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.command.CacheTableCommand +import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -71,10 +75,77 @@ trait TestHiveSingleton { * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { - self => +class TestHiveContext private[hive]( + sc: SparkContext, + cacheManager: CacheManager, + listener: SQLListener, + executionHive: HiveClientImpl, + metadataHive: HiveClient, + isRootContext: Boolean, + hiveCatalog: HiveCatalog, + val warehousePath: File, + val scratchDirPath: File) + extends HiveContext( + sc, + cacheManager, + listener, + executionHive, + metadataHive, + isRootContext, + hiveCatalog) { self => + + // Unfortunately, due to the complex interactions between the construction parameters + // and the limitations in scala constructors, we need many of these constructors to + // provide a shorthand to create a new TestHiveContext with only a SparkContext. + // This is not a great design pattern but it's necessary here. 
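+  //
+  // Illustrative summary (not part of this patch): `new TestHiveContext(sc)` creates temporary
+  // warehouse and scratch directories, builds the execution and metadata Hive clients from them,
+  // and then chains through the private constructors below into the primary constructor above.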
+ + private def this( + sc: SparkContext, + executionHive: HiveClientImpl, + metadataHive: HiveClient, + warehousePath: File, + scratchDirPath: File) { + this( + sc, + new CacheManager, + SQLContext.createListenerAndUI(sc), + executionHive, + metadataHive, + true, + new HiveCatalog(metadataHive), + warehousePath, + scratchDirPath) + } + + private def this(sc: SparkContext, warehousePath: File, scratchDirPath: File) { + this( + sc, + HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), + TestHiveContext.newClientForMetadata( + sc.conf, sc.hadoopConfiguration, warehousePath, scratchDirPath), + warehousePath, + scratchDirPath) + } - import HiveContext._ + def this(sc: SparkContext) { + this( + sc, + Utils.createTempDir(namePrefix = "warehouse"), + TestHiveContext.makeScratchDir()) + } + + override def newSession(): HiveContext = { + new TestHiveContext( + sc = sc, + cacheManager = cacheManager, + listener = listener, + executionHive = executionHive.newSession(), + metadataHive = metadataHive.newSession(), + isRootContext = false, + hiveCatalog = hiveCatalog, + warehousePath = warehousePath, + scratchDirPath = scratchDirPath) + } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. @@ -83,26 +154,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") - - lazy val scratchDirPath = { - val dir = Utils.createTempDir(namePrefix = "scratch-") - dir.delete() - dir - } - - private lazy val temporaryConfig = newTemporaryConfiguration(useInMemoryDerby = false) - - /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = { - super.configure() ++ temporaryConfig ++ Map( - ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, - ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", - ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" - ) - } - val testTempDir = Utils.createTempDir() // For some hive test case which contain ${system:test.tmp.dir} @@ -427,9 +478,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { cacheManager.clearCache() loadedTables.clear() - sessionState.catalog.cachedDataSourceTables.invalidateAll() - sessionState.catalog.client.reset() - sessionState.catalog.unregisterAllTables() + sessionState.catalog.clearTempTables() + sessionState.catalog.invalidateCache() + metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } @@ -448,13 +499,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // Lots of tests fail if we do not change the partition whitelist from the default. runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - configure().foreach { - case (k, v) => - metadataHive.runSqlHive(s"SET $k=$v") - } defaultOverrides() - - runSqlHive("USE default") + sessionState.catalog.setCurrentDatabase("default") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) @@ -490,4 +536,43 @@ private[hive] object TestHiveContext { // Fewer shuffle partitions to speed up testing. SQLConf.SHUFFLE_PARTITIONS.key -> "5" ) + + /** + * Create a [[HiveClient]] used to retrieve metadata from the Hive MetaStore. 
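+   *
+   * This layers the test-specific warehouse, scratch directory, and temporary metastore settings
+   * on top of the configurations produced by [[HiveContext.hiveClientConfigurations]].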
+ */ + private def newClientForMetadata( + conf: SparkConf, + hadoopConf: Configuration, + warehousePath: File, + scratchDirPath: File): HiveClient = { + val hiveConf = new HiveConf(hadoopConf, classOf[HiveConf]) + HiveContext.newClientForMetadata( + conf, + hiveConf, + hadoopConf, + hiveClientConfigurations(hiveConf, warehousePath, scratchDirPath)) + } + + /** + * Configurations needed to create a [[HiveClient]]. + */ + private def hiveClientConfigurations( + hiveconf: HiveConf, + warehousePath: File, + scratchDirPath: File): Map[String, String] = { + HiveContext.hiveClientConfigurations(hiveconf) ++ + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } + + private def makeScratchDir(): File = { + val scratchDir = Utils.createTempDir(namePrefix = "scratch") + scratchDir.delete() + scratchDir + } + } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index bd14a243eaeb4..2fc38e2b2d2e7 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -70,8 +70,9 @@ public void setUp() throws IOException { if (path.exists()) { path.delete(); } - hiveManagedPath = new Path(sqlContext.sessionState().catalog().hiveDefaultTableFilePath( - new TableIdentifier("javaSavedTable"))); + hiveManagedPath = new Path( + sqlContext.sessionState().catalog().hiveDefaultTableFilePath( + new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); if (fs.exists(hiveManagedPath)){ fs.delete(hiveManagedPath, true); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala new file mode 100644 index 0000000000000..fa0c4d92cd527 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala @@ -0,0 +1,38 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.hive.test.TestHive + + +class HiveContextSuite extends SparkFunSuite { + + // TODO: investigate; this passes locally but fails on Jenkins for some reason. 
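+  // The (currently ignored) check below verifies that a `spark.sql.*` entry set in SparkConf
+  // is visible both through SQLConf (`getConf`) and through the metadata Hive client's conf.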
+ ignore("HiveContext can access `spark.sql.*` configs") { + // Avoid creating another SparkContext in the same JVM + val sc = TestHive.sparkContext + require(sc.conf.get("spark.sql.hive.metastore.barrierPrefixes") == + "org.apache.spark.sql.hive.execution.PairSerDe") + assert(TestHive.getConf("spark.sql.hive.metastore.barrierPrefixes") == + "org.apache.spark.sql.hive.execution.PairSerDe") + assert(TestHive.metadataHive.getConf("spark.sql.hive.metastore.barrierPrefixes", "") == + "org.apache.spark.sql.hive.execution.PairSerDe") + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ce7b08ab72f79..42cbfee10ee1f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -21,6 +21,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -83,7 +84,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.client.getTable("default", "t") + val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -114,7 +115,8 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.client.getTable("default", "t") + val hiveTable = + sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -144,7 +146,8 @@ class DataSourceWithHiveMetastoreCatalogSuite |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = sessionState.catalog.client.getTable("default", "t") + val hiveTable = + sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 0a31ac64a20f5..c3b24623d1a79 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -32,14 +32,16 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
- sessionState.catalog.registerTable(TableIdentifier("ListTablesSuiteTable"), df.logicalPlan) + sessionState.catalog.createTempTable( + "ListTablesSuiteTable", df.logicalPlan, ignoreIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 3f3d0692b7b61..7d2a4eb1de7e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -693,13 +693,13 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("SPARK-6024 wide schema support") { withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { withTable("wide_schema") { - withTempDir( tempDir => { + withTempDir { tempDir => // We will need 80 splits for this schema if the threshold is 4000. val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) // Manually create a metastore data source table. sessionState.catalog.createDataSourceTable( - tableIdent = TableIdentifier("wide_schema"), + name = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], bucketSpec = None, @@ -711,7 +711,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val actualSchema = table("wide_schema").schema assert(schema === actualSchema) - }) + } } } } @@ -737,7 +737,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv "spark.sql.sources.schema" -> schema.json, "EXTERNAL" -> "FALSE")) - sessionState.catalog.client.createTable(hiveTable, ignoreIfExists = false) + hiveCatalog.createTable("default", hiveTable, ignoreIfExists = false) invalidateTable(tableName) val actualSchema = table(tableName).schema @@ -752,7 +752,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) invalidateTable(tableName) - val metastoreTable = sessionState.catalog.client.getTable("default", tableName) + val metastoreTable = hiveCatalog.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt @@ -787,7 +787,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .sortBy("c") .saveAsTable(tableName) invalidateTable(tableName) - val metastoreTable = sessionState.catalog.client.getTable("default", tableName) + val metastoreTable = hiveCatalog.getTable("default", tableName) val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val expectedSortByColumns = StructType(df.schema("c") :: Nil) @@ -903,11 +903,11 @@ class 
MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("skip hive metadata on table creation") { - withTempDir(tempPath => { + withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) sessionState.catalog.createDataSourceTable( - tableIdent = TableIdentifier("not_skip_hive_metadata"), + name = TableIdentifier("not_skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], bucketSpec = None, @@ -917,11 +917,11 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // As a proxy for verifying that the table was stored in Hive compatible format, // we verify that each column of the table is of native type StringType. - assert(sessionState.catalog.client.getTable("default", "not_skip_hive_metadata").schema + assert(hiveCatalog.getTable("default", "not_skip_hive_metadata").schema .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == StringType)) sessionState.catalog.createDataSourceTable( - tableIdent = TableIdentifier("skip_hive_metadata"), + name = TableIdentifier("skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], bucketSpec = None, @@ -929,10 +929,11 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv options = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "true"), isExternal = false) - // As a proxy for verifying that the table was stored in SparkSQL format, we verify that - // the table has a column type as array of StringType. - assert(sessionState.catalog.client.getTable("default", "skip_hive_metadata").schema - .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == ArrayType(StringType))) - }) + // As a proxy for verifying that the table was stored in SparkSQL format, + // we verify that the table has a column type as array of StringType. 
+ assert(hiveCatalog.getTable("default", "skip_hive_metadata").schema.forall { c => + HiveMetastoreTypes.toDataType(c.dataType) == ArrayType(StringType) + }) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index d275190744002..3be2269d3f11f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -25,9 +25,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private lazy val df = sqlContext.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { - val metastoreTable = hiveContext.sessionState.catalog.client.getTable(dbName, tableName) - val expectedPath = - hiveContext.sessionState.catalog.client.getDatabase(dbName).locationUri + "/" + tableName + val metastoreTable = hiveContext.hiveCatalog.getTable(dbName, tableName) + val expectedPath = hiveContext.hiveCatalog.getDatabase(dbName).locationUri + "/" + tableName assert(metastoreTable.storage.serdeProperties("path") === expectedPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 151aacbdd1c44..ae026ed4964eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -121,7 +121,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { intercept[UnsupportedOperationException] { hiveContext.analyze("tempTable") } - hiveContext.sessionState.catalog.unregisterTable(TableIdentifier("tempTable")) + hiveContext.sessionState.catalog.dropTable( + TableIdentifier("tempTable"), ignoreIfNotExists = true) } test("estimates the size of a test MetastoreRelation") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 295069228fea1..d59bca4c7ee4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -171,10 +171,6 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(client.listTables("default") === Seq("src")) } - test(s"$version: currentDatabase") { - assert(client.currentDatabase === "default") - } - test(s"$version: getDatabase") { client.getDatabase("default") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5fe85eaef2b55..197a123905d2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -49,6 +49,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -57,11 +58,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } override def afterAll() { - 
TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - sql("DROP TEMPORARY FUNCTION udtf_count2") - super.afterAll() + try { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION udtf_count2") + } finally { + super.afterAll() + } } test("SPARK-4908: concurrent hive native commands") { @@ -1209,7 +1213,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("USE hive_test_db") assert("hive_test_db" == sql("select current_database()").first().getString(0)) - intercept[NoSuchDatabaseException] { + intercept[AnalysisException] { sql("USE not_existing_db") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bc8896d4bd0ba..6199253d34db0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1325,6 +1325,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .format("parquet") .save(path) + // We don't support creating a temporary table while specifying a database val message = intercept[AnalysisException] { sqlContext.sql( s""" @@ -1335,9 +1336,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |) """.stripMargin) }.getMessage - assert(message.contains("Specifying database name or other qualifiers are not allowed")) - // If you use backticks to quote the name of a temporary table having dot in it. + // If you use backticks to quote the name then it's OK. sqlContext.sql( s""" |CREATE TEMPORARY TABLE `db.t` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index cc412241fb4da..92f424bac7ff3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -222,7 +222,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - sessionState.catalog.unregisterTable(TableIdentifier("tmp")) + sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { @@ -232,7 +232,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) } - sessionState.catalog.unregisterTable(TableIdentifier("tmp")) + sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true) } test("self-join") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index bb53179c3cce3..07fe0ccd877d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.DataSourceScan import org.apache.spark.sql.execution.command.ExecutedCommand import 
org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} @@ -425,10 +426,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("Caching converted data source Parquet Relations") { - val _catalog = sessionState.catalog - def checkCached(tableIdentifier: _catalog.QualifiedTableName): Unit = { + def checkCached(tableIdentifier: TableIdentifier): Unit = { // Converted test_parquet should be cached. - sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { + sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK case other => @@ -453,17 +453,17 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifier = _catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = TableIdentifier("test_insert_parquet", Some("default")) // First, make sure the converted test_parquet is not cached. - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. invalidateTable("test_insert_parquet") - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet @@ -476,7 +476,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select a, b from jt").collect()) // Invalidate the cache. invalidateTable("test_insert_parquet") - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. sql( @@ -493,8 +493,8 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifier = _catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + tableIdentifier = TableIdentifier("test_parquet_partitioned_cache_test", Some("default")) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -503,14 +503,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. 
- assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") @@ -526,7 +526,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin).collect()) invalidateTable("test_parquet_partitioned_cache_test") - assert(sessionState.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") }
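
A minimal sketch (not part of the diffs above) of the catalog API surface these test updates converge on. The suite name is hypothetical and the fixture mixins simply mirror the suites touched in this patch; the catalog calls themselves are the ones used in the hunks: TableIdentifier instead of the catalog-internal QualifiedTableName, getCachedDataSourceTable instead of peeking at cachedDataSourceTables.getIfPresent, and dropTable(..., ignoreIfNotExists = true) instead of unregisterTable.

    import org.apache.spark.sql.QueryTest
    import org.apache.spark.sql.catalyst.TableIdentifier
    import org.apache.spark.sql.hive.test.TestHiveSingleton
    import org.apache.spark.sql.test.SQLTestUtils

    // Hypothetical suite for illustration; only the catalog calls are taken from the patch.
    class CatalogApiSketchSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

      test("new catalog API surface used by the updated suites") {
        // A plain TableIdentifier replaces the catalog-specific QualifiedTableName.
        val ident = TableIdentifier("test_insert_parquet", Some("default"))

        // getCachedDataSourceTable replaces cachedDataSourceTables.getIfPresent;
        // an uncached table yields null, as asserted in ParquetMetastoreSuite above.
        assert(hiveContext.sessionState.catalog.getCachedDataSourceTable(ident) === null)

        // dropTable(..., ignoreIfNotExists = true) replaces unregisterTable for temp tables.
        hiveContext.sessionState.catalog.dropTable(
          TableIdentifier("tmp"), ignoreIfNotExists = true)
      }
    }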